This commit is contained in:
justumen
2025-02-16 20:58:39 +01:00
parent dfcf429e5c
commit 3ebd5cbb92
22 changed files with 967 additions and 132 deletions

View File

@@ -4,6 +4,7 @@ import os
import numpy as np
import tempfile
import wave
import subprocess # Added for ffmpeg
try:
import faster_whisper
@@ -25,6 +26,7 @@ class SpeechToText:
"optional": {
"AUDIO": ("AUDIO",),
"audio_path": ("STRING", {"default": None, "forceInput": True}),
"video_path": ("STRING", {"default": None, "forceInput": True}),
}
}
@@ -35,26 +37,25 @@ class SpeechToText:
def tensor_to_wav(self, audio_tensor, sample_rate):
"""Convert audio tensor to temporary WAV file"""
# Convert tensor to numpy array
audio_data = audio_tensor.squeeze().numpy()
# Create temporary file
if audio_data.ndim == 2:
audio_data = np.mean(audio_data, axis=0)
elif audio_data.ndim > 2:
raise ValueError(f"Unsupported audio tensor shape: {audio_data.shape}")
temp_file = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
# Write WAV file
with wave.open(temp_file.name, 'wb') as wav_file:
wav_file.setnchannels(1) # Mono audio
wav_file.setsampwidth(2) # 2 bytes per sample
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(sample_rate)
# Convert float32 to int16
audio_data = (audio_data * 32767).astype(np.int16)
wav_file.writeframes(audio_data.tobytes())
return temp_file.name
def load_local_model(self, model_size):
"""Load the local Whisper model if not already loaded"""
if not WHISPER_AVAILABLE:
return False, "faster-whisper not installed. Install with: pip install faster-whisper"
@@ -68,7 +69,6 @@ class SpeechToText:
return False, f"Error loading model: {str(e)}"
def transcribe_local(self, audio_path, model_size):
"""Transcribe audio using local Whisper model"""
success, message = self.load_local_model(model_size)
if not success:
return False, message, None
@@ -83,23 +83,47 @@ class SpeechToText:
except Exception as e:
return False, f"Error during local transcription: {str(e)}", None
def transcribe_audio(self, model_size, AUDIO=None, audio_path=None):
def transcribe_audio(self, model_size, AUDIO=None, audio_path=None, video_path=None):
transcript = "No valid audio input provided"
detected_language = ""
temp_wav_path = None
temp_audio_path = None
try:
# Determine which audio source to use
if AUDIO is not None:
# Convert tensor audio data to WAV file
# Check video input first
if video_path and os.path.exists(video_path):
try:
# Create temp file for extracted audio
temp_audio = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
temp_audio.close()
temp_audio_path = temp_audio.name
# FFmpeg command to extract audio
command = [
'ffmpeg',
'-i', video_path,
'-vn',
'-acodec', 'pcm_s16le',
'-ar', '16000',
'-ac', '1',
'-y',
temp_audio_path
]
subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
audio_to_process = temp_audio_path
except subprocess.CalledProcessError as e:
return (f"FFmpeg error: {e.stderr.decode()}", "", "")
except Exception as e:
return (f"Error extracting audio: {str(e)}", "", "")
elif AUDIO is not None:
waveform = AUDIO['waveform']
sample_rate = AUDIO['sample_rate']
temp_wav_path = self.tensor_to_wav(waveform, sample_rate)
audio_to_process = temp_wav_path
elif audio_path is not None and os.path.exists(audio_path):
elif audio_path and os.path.exists(audio_path):
audio_to_process = audio_path
else:
return ("No valid audio input provided", "")
return ("No valid audio input provided", "", "")
if audio_to_process:
success, result, lang = self.transcribe_local(audio_to_process, model_size)
@@ -107,11 +131,12 @@ class SpeechToText:
detected_language = lang if success else ""
finally:
# Clean up temporary file if it was created
# Cleanup temporary files
if temp_wav_path and os.path.exists(temp_wav_path):
os.unlink(temp_wav_path)
if temp_audio_path and os.path.exists(temp_audio_path):
os.unlink(temp_audio_path)
#Create detected_language_name based on detected_language, en = English, es = Spanish, fr = French, de = German, etc...
language_map = {
"ar": "Arabic", "cs": "Czech", "de": "German", "en": "English",
"es": "Spanish", "fr": "French", "hi": "Hindi", "hu": "Hungarian",
@@ -121,4 +146,4 @@ class SpeechToText:
}
detected_language_name = language_map.get(detected_language, "Unknown")
return (transcript, detected_language,detected_language_name)
return (transcript, detected_language, detected_language_name)