mirror of
https://github.com/justUmen/Bjornulf_custom_nodes.git
synced 2026-03-21 12:42:11 -03:00
80 lines
3.1 KiB
Python
80 lines
3.1 KiB
Python
import os
|
|
import shutil
|
|
import numpy as np
|
|
import soundfile as sf
|
|
import torch
|
|
|
|
class SaveTmpAudio:
    """ComfyUI output node that stores audio as ``./output/tmp_api.<ext>``.

    The node accepts either an in-memory ``AUDIO`` dictionary, which is
    written as a WAV file, or a path to an existing ``.wav``/``.mp3`` file,
    which is copied. When both are supplied, ``audio`` takes precedence.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the two optional inputs of the node."""
        return {
            "optional": {
                "audio": ("AUDIO",),
                "audio_path": ("STRING",),
            },
        }

    FUNCTION = "save_audio"
    RETURN_TYPES = ()
    OUTPUT_NODE = True
    CATEGORY = "Bjornulf"

    @staticmethod
    def _waveform_to_pcm(waveform):
        """Convert a waveform tensor into an array ready for ``sf.write``.

        Args:
            waveform: A ``torch.Tensor`` whose last axis is the sample axis,
                e.g. ``(1, 1, N)`` or ``(1, channels, N)``.

        Returns:
            A NumPy array of shape ``(N,)`` for mono input or
            ``(N, channels)`` for multi-channel input. Floating-point data
            (float32/float64) is converted to ``int16``; other dtypes are
            passed through unchanged.

        Raises:
            TypeError: If ``waveform`` is not a ``torch.Tensor``.
            ValueError: If more than two dimensions remain after squeezing.
        """
        if not isinstance(waveform, torch.Tensor):
            raise TypeError(f"Waveform must be a PyTorch tensor, got {type(waveform)}")

        # detach() lets tensors that still track gradients be converted;
        # squeeze() drops singleton axes (e.g. (1, 1, N) -> (N,)).
        audio_np = waveform.detach().squeeze().cpu().numpy()

        if audio_np.ndim > 2:
            raise ValueError(f"Audio data has too many dimensions after squeezing: {audio_np.shape}")

        if audio_np.ndim == 0:
            # A one-sample waveform squeezes down to a scalar; keep it 1-D.
            audio_np = audio_np.reshape(1)
        elif audio_np.ndim == 2:
            # The tensor is assumed channel-first (channels, samples), while
            # soundfile expects (frames, channels).
            audio_np = audio_np.T

        if audio_np.dtype in (np.float32, np.float64):
            # Clip to [-1, 1] first: without it, out-of-range samples
            # overflow the int16 cast and wrap around.
            audio_np = (np.clip(audio_np, -1.0, 1.0) * 32767).astype(np.int16)

        return audio_np

    def save_audio(self, audio=None, audio_path=None):
        """Save the given audio into ``./output`` and report it to the UI.

        Args:
            audio: Optional dictionary with a ``'waveform'`` tensor and a
                ``'sample_rate'`` entry. Written to ``./output/tmp_api.wav``.
            audio_path: Optional path to a ``.wav`` or ``.mp3`` file. Used
                only when ``audio`` is ``None``; copied to
                ``./output/tmp_api<ext>``.

        Returns:
            ``{"ui": {"audio": [{"filename": <path>, "type": "output"}]}}``.

        Raises:
            ValueError: If neither input is given, ``audio`` is malformed, or
                ``audio_path`` has an unsupported extension.
            TypeError: If the waveform is not a ``torch.Tensor``.
            FileNotFoundError: If ``audio_path`` does not exist.
        """
        if audio is None and audio_path is None:
            raise ValueError("Either 'audio' or 'audio_path' must be provided")

        output_dir = "./output"

        if audio is not None:
            # Case 1: in-memory AUDIO dictionary.
            if not isinstance(audio, dict):
                raise ValueError("AUDIO input must be a dictionary with 'waveform' and 'sample_rate'")
            if 'waveform' not in audio or 'sample_rate' not in audio:
                raise ValueError("AUDIO dictionary must contain 'waveform' and 'sample_rate' keys")

            audio_np = self._waveform_to_pcm(audio['waveform'])
            sample_rate = audio['sample_rate']

            filename = "./output/tmp_api.wav"
            # Create the directory only once the input is known to be valid.
            os.makedirs(output_dir, exist_ok=True)
            sf.write(filename, audio_np, sample_rate)
        else:
            # Case 2: existing audio file on disk.
            if not os.path.exists(audio_path):
                raise FileNotFoundError(f"Audio file not found: {audio_path}")

            ext = os.path.splitext(audio_path)[1].lower()
            if ext not in ('.wav', '.mp3'):
                raise ValueError("audio_path must be a .wav or .mp3 file")

            filename = f"./output/tmp_api{ext}"
            os.makedirs(output_dir, exist_ok=True)
            try:
                shutil.copy(audio_path, filename)
            except shutil.SameFileError:
                # The source already is the destination file; nothing to do.
                pass

        # Return UI information for ComfyUI
        return {"ui": {"audio": [{"filename": filename, "type": "output"}]}}