Files
Bjornulf_custom_nodes/save_tmp_audio.py
justumen 39dfb0220a 0.77
2025-03-19 17:36:25 +01:00

80 lines
3.1 KiB
Python

import os
import shutil
import numpy as np
import soundfile as sf
import torch
class SaveTmpAudio:
@classmethod
def INPUT_TYPES(cls):
return {
"optional": {
"audio": ("AUDIO",),
"audio_path": ("STRING",),
},
}
FUNCTION = "save_audio"
RETURN_TYPES = ()
OUTPUT_NODE = True
CATEGORY = "Bjornulf"
def save_audio(self, audio=None, audio_path=None):
# Ensure the output directory exists
os.makedirs("./output", exist_ok=True)
# Check if neither input is provided
if audio is None and audio_path is None:
raise ValueError("Either 'audio' or 'audio_path' must be provided")
# Case 1: Handle AUDIO input if provided
if audio is not None:
# Validate that audio is a dictionary with required keys
if not isinstance(audio, dict):
raise ValueError("AUDIO input must be a dictionary with 'waveform' and 'sample_rate'")
if 'waveform' not in audio or 'sample_rate' not in audio:
raise ValueError("AUDIO dictionary must contain 'waveform' and 'sample_rate' keys")
# Extract waveform and sample rate
waveform = audio['waveform']
sample_rate = audio['sample_rate']
# Ensure waveform is a PyTorch tensor
if not isinstance(waveform, torch.Tensor):
raise TypeError(f"Waveform must be a PyTorch tensor, got {type(waveform)}")
# Squeeze waveform to remove extra dimensions (e.g., from (1, 1, N) to (N,))
waveform = waveform.squeeze()
# Convert to NumPy array
audio_np = waveform.cpu().numpy()
# Validate the shape after squeezing
if audio_np.ndim > 2:
raise ValueError(f"Audio data has too many dimensions after squeezing: {audio_np.shape}")
# Scale floating-point data to 16-bit integers (assuming range [-1, 1])
if audio_np.dtype in (np.float32, np.float64):
audio_np = (audio_np * 32767).astype(np.int16)
# Save as WAV file
filename = "./output/tmp_api.wav"
sf.write(filename, audio_np, sample_rate)
# Case 2: Handle audio_path input if audio is not provided
elif audio_path is not None:
# Verify the file exists
if not os.path.exists(audio_path):
raise FileNotFoundError(f"Audio file not found: {audio_path}")
# Check for supported file extensions
ext = os.path.splitext(audio_path)[1].lower()
if ext not in ('.wav', '.mp3'):
raise ValueError("audio_path must be a .wav or .mp3 file")
# Copy the file to the output directory
filename = f"./output/tmp_api{ext}"
shutil.copy(audio_path, filename)
# Return UI information for ComfyUI
return {"ui": {"audio": [{"filename": filename, "type": "output"}]}}