mirror of
https://github.com/justUmen/Bjornulf_custom_nodes.git
synced 2026-03-21 20:52:11 -03:00
0.51
This commit is contained in:
@@ -231,12 +231,13 @@ class AudioVideoSync:
|
||||
def process_audio(self, audio_tensor, sample_rate, target_duration, original_duration,
|
||||
max_speedup, max_slowdown):
|
||||
"""Process audio to match video duration."""
|
||||
if audio_tensor.dim() == 3:
|
||||
audio_tensor = audio_tensor.squeeze(0)
|
||||
elif audio_tensor.dim() == 1:
|
||||
# Ensure audio tensor has correct dimensions
|
||||
if audio_tensor.dim() == 2:
|
||||
audio_tensor = audio_tensor.unsqueeze(0)
|
||||
elif audio_tensor.dim() == 1:
|
||||
audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
current_duration = audio_tensor.shape[1] / sample_rate
|
||||
current_duration = audio_tensor.shape[-1] / sample_rate
|
||||
|
||||
# Calculate synchronized video duration
|
||||
if target_duration > original_duration:
|
||||
@@ -256,17 +257,17 @@ class AudioVideoSync:
|
||||
# Adjust audio length
|
||||
if current_duration < sync_duration:
|
||||
silence_samples = int((sync_duration - current_duration) * sample_rate)
|
||||
silence = torch.zeros(audio_tensor.shape[0], silence_samples)
|
||||
processed_audio = torch.cat([audio_tensor, silence], dim=1)
|
||||
silence = torch.zeros(audio_tensor.shape[0], audio_tensor.shape[1], silence_samples)
|
||||
processed_audio = torch.cat([audio_tensor, silence], dim=-1)
|
||||
else:
|
||||
required_samples = int(sync_duration * sample_rate)
|
||||
processed_audio = audio_tensor[:, :required_samples]
|
||||
processed_audio = audio_tensor[..., :required_samples]
|
||||
|
||||
return processed_audio, sync_duration
|
||||
|
||||
def save_audio(self, audio_tensor, sample_rate, target_duration, original_duration,
|
||||
max_speedup, max_slowdown):
|
||||
"""Save processed audio to file."""
|
||||
"""Save processed audio to file and return consistent AUDIO format."""
|
||||
timestamp = self.generate_timestamp()
|
||||
output_path = os.path.join(self.sync_audio_dir, f"sync_audio_{timestamp}.wav")
|
||||
|
||||
@@ -275,12 +276,29 @@ class AudioVideoSync:
|
||||
max_speedup, max_slowdown
|
||||
)
|
||||
|
||||
torchaudio.save(output_path, processed_audio, sample_rate)
|
||||
return os.path.abspath(output_path)
|
||||
# Save with proper format
|
||||
torchaudio.save(output_path, processed_audio.squeeze(0), sample_rate)
|
||||
|
||||
# Return consistent AUDIO format
|
||||
return {
|
||||
'waveform': processed_audio,
|
||||
'sample_rate': sample_rate
|
||||
}
|
||||
|
||||
def load_audio_from_path(self, audio_path):
|
||||
"""Load audio from file path."""
|
||||
"""Load audio from file path and format it consistently with AUDIO input."""
|
||||
waveform, sample_rate = torchaudio.load(audio_path)
|
||||
|
||||
# Ensure waveform has 3 dimensions (batch, channels, samples) like AUDIO input
|
||||
if waveform.dim() == 2:
|
||||
waveform = waveform.unsqueeze(0) # Add batch dimension
|
||||
|
||||
# Convert to float32 and normalize to range [0, 1] if needed
|
||||
if waveform.dtype != torch.float32:
|
||||
waveform = waveform.float()
|
||||
if waveform.max() > 1.0:
|
||||
waveform = waveform / 32768.0 # Normalize 16-bit audio
|
||||
|
||||
return {'waveform': waveform, 'sample_rate': sample_rate}
|
||||
|
||||
def extract_frames(self, video_path):
|
||||
@@ -297,7 +315,10 @@ class AudioVideoSync:
|
||||
# Load frames and convert to tensor
|
||||
frames = []
|
||||
frame_files = sorted(os.listdir(temp_dir))
|
||||
transform = transforms.Compose([transforms.ToTensor()])
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda x: x * 255) # Scale to 0-255 range
|
||||
])
|
||||
|
||||
for frame_file in frame_files:
|
||||
image = Image.open(os.path.join(temp_dir, frame_file))
|
||||
@@ -307,6 +328,13 @@ class AudioVideoSync:
|
||||
# Stack frames into a single tensor
|
||||
frames_tensor = torch.stack(frames)
|
||||
|
||||
# Ensure the tensor is in the correct format (B, C, H, W)
|
||||
if frames_tensor.dim() == 3:
|
||||
frames_tensor = frames_tensor.unsqueeze(0)
|
||||
|
||||
# Convert to uint8
|
||||
frames_tensor = frames_tensor.byte()
|
||||
|
||||
# Clean up temporary directory
|
||||
for frame_file in frame_files:
|
||||
os.remove(os.path.join(temp_dir, frame_file))
|
||||
@@ -350,25 +378,35 @@ class AudioVideoSync:
|
||||
sync_video_path = self.create_sync_video(
|
||||
video_path, original_duration, audio_duration, max_speedup, max_slowdown
|
||||
)
|
||||
sync_audio_path = self.save_audio(
|
||||
|
||||
# Process and save audio, getting consistent AUDIO format back
|
||||
sync_audio = self.save_audio(
|
||||
AUDIO['waveform'], AUDIO['sample_rate'], audio_duration,
|
||||
original_duration, max_speedup, max_slowdown
|
||||
)
|
||||
|
||||
# Get sync_audio_path separately
|
||||
sync_audio_path = os.path.join(self.sync_audio_dir, f"sync_audio_{self.generate_timestamp()}.wav")
|
||||
torchaudio.save(sync_audio_path, sync_audio['waveform'].squeeze(0), sync_audio['sample_rate'])
|
||||
|
||||
# Get final properties
|
||||
sync_video_duration, _, sync_frame_count = self.get_video_info(sync_video_path)
|
||||
sync_audio_duration = torchaudio.info(sync_audio_path).num_frames / AUDIO['sample_rate']
|
||||
sync_audio_duration = sync_audio['waveform'].shape[-1] / sync_audio['sample_rate']
|
||||
|
||||
video_frames = self.extract_frames(sync_video_path)
|
||||
|
||||
# Convert video_frames to the format expected by ComfyUI
|
||||
video_frames = video_frames.float() / 255.0
|
||||
video_frames = video_frames.permute(0, 2, 3, 1)
|
||||
|
||||
return (
|
||||
video_frames,
|
||||
AUDIO,
|
||||
sync_audio, # Now returns consistent AUDIO format
|
||||
sync_audio_path,
|
||||
sync_video_path,
|
||||
original_duration, # input_video_duration
|
||||
original_duration,
|
||||
sync_video_duration,
|
||||
audio_duration, # input_audio_duration
|
||||
audio_duration,
|
||||
sync_audio_duration,
|
||||
sync_frame_count
|
||||
)
|
||||
Reference in New Issue
Block a user