This commit is contained in:
justumen
2024-11-01 09:31:56 +01:00
parent 06da237179
commit 7df528d1d9
16 changed files with 195 additions and 145 deletions

View File

@@ -231,12 +231,13 @@ class AudioVideoSync:
def process_audio(self, audio_tensor, sample_rate, target_duration, original_duration,
max_speedup, max_slowdown):
"""Process audio to match video duration."""
if audio_tensor.dim() == 3:
audio_tensor = audio_tensor.squeeze(0)
elif audio_tensor.dim() == 1:
# Ensure audio tensor has correct dimensions
if audio_tensor.dim() == 2:
audio_tensor = audio_tensor.unsqueeze(0)
elif audio_tensor.dim() == 1:
audio_tensor = audio_tensor.unsqueeze(0).unsqueeze(0)
current_duration = audio_tensor.shape[1] / sample_rate
current_duration = audio_tensor.shape[-1] / sample_rate
# Calculate synchronized video duration
if target_duration > original_duration:
@@ -256,17 +257,17 @@ class AudioVideoSync:
# Adjust audio length
if current_duration < sync_duration:
silence_samples = int((sync_duration - current_duration) * sample_rate)
silence = torch.zeros(audio_tensor.shape[0], silence_samples)
processed_audio = torch.cat([audio_tensor, silence], dim=1)
silence = torch.zeros(audio_tensor.shape[0], audio_tensor.shape[1], silence_samples)
processed_audio = torch.cat([audio_tensor, silence], dim=-1)
else:
required_samples = int(sync_duration * sample_rate)
processed_audio = audio_tensor[:, :required_samples]
processed_audio = audio_tensor[..., :required_samples]
return processed_audio, sync_duration
def save_audio(self, audio_tensor, sample_rate, target_duration, original_duration,
max_speedup, max_slowdown):
"""Save processed audio to file."""
"""Save processed audio to file and return consistent AUDIO format."""
timestamp = self.generate_timestamp()
output_path = os.path.join(self.sync_audio_dir, f"sync_audio_{timestamp}.wav")
@@ -275,12 +276,29 @@ class AudioVideoSync:
max_speedup, max_slowdown
)
torchaudio.save(output_path, processed_audio, sample_rate)
return os.path.abspath(output_path)
# Save with proper format
torchaudio.save(output_path, processed_audio.squeeze(0), sample_rate)
# Return consistent AUDIO format
return {
'waveform': processed_audio,
'sample_rate': sample_rate
}
def load_audio_from_path(self, audio_path):
"""Load audio from file path."""
"""Load audio from file path and format it consistently with AUDIO input."""
waveform, sample_rate = torchaudio.load(audio_path)
# Ensure waveform has 3 dimensions (batch, channels, samples) like AUDIO input
if waveform.dim() == 2:
waveform = waveform.unsqueeze(0) # Add batch dimension
# Convert to float32 and normalize to range [0, 1] if needed
if waveform.dtype != torch.float32:
waveform = waveform.float()
if waveform.max() > 1.0:
waveform = waveform / 32768.0 # Normalize 16-bit audio
return {'waveform': waveform, 'sample_rate': sample_rate}
def extract_frames(self, video_path):
@@ -297,7 +315,10 @@ class AudioVideoSync:
# Load frames and convert to tensor
frames = []
frame_files = sorted(os.listdir(temp_dir))
transform = transforms.Compose([transforms.ToTensor()])
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x * 255) # Scale to 0-255 range
])
for frame_file in frame_files:
image = Image.open(os.path.join(temp_dir, frame_file))
@@ -307,6 +328,13 @@ class AudioVideoSync:
# Stack frames into a single tensor
frames_tensor = torch.stack(frames)
# Ensure the tensor is in the correct format (B, C, H, W)
if frames_tensor.dim() == 3:
frames_tensor = frames_tensor.unsqueeze(0)
# Convert to uint8
frames_tensor = frames_tensor.byte()
# Clean up temporary directory
for frame_file in frame_files:
os.remove(os.path.join(temp_dir, frame_file))
@@ -350,25 +378,35 @@ class AudioVideoSync:
sync_video_path = self.create_sync_video(
video_path, original_duration, audio_duration, max_speedup, max_slowdown
)
sync_audio_path = self.save_audio(
# Process and save audio, getting consistent AUDIO format back
sync_audio = self.save_audio(
AUDIO['waveform'], AUDIO['sample_rate'], audio_duration,
original_duration, max_speedup, max_slowdown
)
# Get sync_audio_path separately
sync_audio_path = os.path.join(self.sync_audio_dir, f"sync_audio_{self.generate_timestamp()}.wav")
torchaudio.save(sync_audio_path, sync_audio['waveform'].squeeze(0), sync_audio['sample_rate'])
# Get final properties
sync_video_duration, _, sync_frame_count = self.get_video_info(sync_video_path)
sync_audio_duration = torchaudio.info(sync_audio_path).num_frames / AUDIO['sample_rate']
sync_audio_duration = sync_audio['waveform'].shape[-1] / sync_audio['sample_rate']
video_frames = self.extract_frames(sync_video_path)
# Convert video_frames to the format expected by ComfyUI
video_frames = video_frames.float() / 255.0
video_frames = video_frames.permute(0, 2, 3, 1)
return (
video_frames,
AUDIO,
sync_audio, # Now returns consistent AUDIO format
sync_audio_path,
sync_video_path,
original_duration, # input_video_duration
original_duration,
sync_video_duration,
audio_duration, # input_audio_duration
audio_duration,
sync_audio_duration,
sync_frame_count
)