This commit is contained in:
justumen
2024-10-28 11:32:37 +01:00
parent 19cb4cbe49
commit 66eb5282c2
18 changed files with 950 additions and 213 deletions


@@ -4,211 +4,371 @@ import os
import subprocess
from datetime import datetime
import math
from PIL import Image
import logging
import torchvision.transforms as transforms
class AudioVideoSync:
"""
ComfyUI custom node for synchronizing audio and video with configurable speed adjustments.
Accepts either a video file or an image sequence as input, and either an audio file path or an AUDIO object.
"""
def __init__(self):
"""Initialize the AudioVideoSync node."""
self.base_dir = "Bjornulf"
self.temp_dir = os.path.join(self.base_dir, "temp_frames")
self.sync_video_dir = os.path.join(self.base_dir, "sync_video")
self.sync_audio_dir = os.path.join(self.base_dir, "sync_audio")
# Create necessary directories
for directory in [self.temp_dir, self.sync_video_dir, self.sync_audio_dir]:
os.makedirs(directory, exist_ok=True)
@classmethod
def INPUT_TYPES(cls):
"""Define input parameters for the node."""
return {
"required": {
"audio": ("AUDIO",),
"video_path": ("STRING", {"default": ""}),
"audio_duration": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 3600.0, "step": 0.001}),
"max_speedup": ("FLOAT", {
"default": 1.5,
"min": 1.0,
"max": 10.0,
"step": 0.1
}),
"max_slowdown": ("FLOAT", {
"default": 0.5,
"min": 0.1,
"max": 1.0,
"step": 0.1
}),
},
"optional": {
"IMAGES": ("IMAGE",),
"AUDIO": ("AUDIO",),
"audio_path": ("STRING", {"default": "", "forceInput": True}),
"audio_duration": ("FLOAT", {
"default": 0.0,
"min": 0.0,
"max": 3600.0,
"step": 0.001
}),
"video_path": ("STRING", {
"default": "",
"forceInput": True
}),
"output_fps": ("FLOAT", {
"default": 30.0,
"min": 1.0,
"max": 120.0,
"step": 0.1
}),
}
}
RETURN_TYPES = ("AUDIO", "STRING", "STRING", "FLOAT", "FLOAT", "INT", "FLOAT", "FLOAT")
RETURN_NAMES = ("sync_audio", "sync_audio_path", "sync_video_path", "video_fps", "video_duration", "sync_video_frame_count", "sync_audio_duration", "sync_video_duration")
RETURN_TYPES = ("IMAGE", "AUDIO", "STRING", "STRING", "FLOAT", "FLOAT", "FLOAT", "FLOAT", "INT")
RETURN_NAMES = ("sync_IMAGES", "sync_AUDIO", "sync_audio_path", "sync_video_path",
"input_video_duration", "sync_video_duration", "input_audio_duration", "sync_audio_duration",
"sync_video_frame_count")
FUNCTION = "sync_audio_video"
CATEGORY = "Bjornulf"
def generate_timestamp(self):
"""Generate a unique timestamp for file naming."""
return datetime.now().strftime("%Y%m%d_%H%M%S")
def validate_audio_input(self, audio):
"""Validate the audio input format."""
if not isinstance(audio, dict) or 'waveform' not in audio or 'sample_rate' not in audio:
raise ValueError("Expected audio input to be a dictionary with 'waveform' and 'sample_rate' keys")
def validate_speed_limits(self, max_speedup, max_slowdown):
"""Validate the speed limit parameters."""
if max_speedup < 1.0:
raise ValueError("max_speedup must be greater than or equal to 1.0")
if max_slowdown > 1.0:
raise ValueError("max_slowdown must be less than or equal to 1.0")
def get_audio_duration(self, audio):
"""Calculate audio duration from audio input."""
if isinstance(audio, dict) and 'waveform' in audio and 'sample_rate' in audio:
return audio['waveform'].shape[-1] / audio['sample_rate']
        else:
            raise ValueError("Invalid audio input format")
def ffprobe_run(self, cmd):
"""Run ffprobe command and return the output."""
result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
return result.stdout.strip()
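    # For reference, a typical call (output value is hypothetical):
    #   self.ffprobe_run(['ffprobe', '-v', 'error',
    #                     '-show_entries', 'format=duration',
    #                     '-of', 'default=noprint_wrappers=1:nokey=1',
    #                     'clip.mp4'])
    #   -> '12.345000'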
def get_video_info(self, video_path):
"""Get video duration, fps, and frame count."""
duration = float(self.ffprobe_run([
'ffprobe', '-v', 'error',
'-show_entries', 'format=duration',
'-of', 'default=noprint_wrappers=1:nokey=1',
video_path
]))
fps_str = self.ffprobe_run([
'ffprobe', '-v', 'error',
'-select_streams', 'v:0',
'-show_entries', 'stream=r_frame_rate',
'-of', 'csv=p=0',
video_path
])
        # Parse fractional rates like '30000/1001' without eval()
        if '/' in fps_str:
            num, den = map(float, fps_str.split('/'))
            fps = num / den
        else:
            fps = float(fps_str)
frame_count = int(self.ffprobe_run([
'ffprobe', '-v', 'error',
'-count_packets',
'-select_streams', 'v:0',
'-show_entries', 'stream=nb_read_packets',
'-of', 'csv=p=0',
video_path
]))
return duration, fps, frame_count
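    # Worked example (hypothetical clip): ffprobe may report r_frame_rate as
    # '30000/1001', which parses to 30000 / 1001 ≈ 29.97 fps; nb_read_packets
    # counts one packet per video frame, so a 10 s clip yields about 300.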
def process_images_to_video(self, IMAGES, fps):
"""Convert image sequence to video."""
timestamp = self.generate_timestamp()
temp_dir = os.path.join(self.temp_dir, f"frames_{timestamp}")
os.makedirs(temp_dir, exist_ok=True)
# Save frames
frame_paths = []
for i, img in enumerate(IMAGES):
if isinstance(img, torch.Tensor):
if img.dim() == 4:
img = img.squeeze(0)
img = (img * 255).byte().cpu().numpy()
img = Image.fromarray(img)
if abs(speed_ratio - 1.0) <= 0.1: # If the difference is less than 10%
copy_cmd = [
'ffmpeg', '-i', video_path, '-c', 'copy', final_output_path
]
subprocess.run(copy_cmd, check=True)
print(f"Video copied without speed adjustment to: {final_output_path}")
frame_path = os.path.join(temp_dir, f"frame_{i:05d}.png")
img.save(frame_path)
frame_paths.append(frame_path)
# Create video
output_path = os.path.join(self.temp_dir, f"video_{timestamp}.mp4")
subprocess.run([
'ffmpeg', '-y',
'-framerate', str(fps),
'-i', os.path.join(temp_dir, 'frame_%05d.png'),
'-c:v', 'libx264',
'-pix_fmt', 'yuv420p',
'-preset', 'medium',
'-crf', '19',
output_path
], check=True)
# Cleanup
for path in frame_paths:
os.remove(path)
os.rmdir(temp_dir)
return output_path
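    # Sketch of the resulting command for a 90-image batch at 30 fps
    # (paths are illustrative):
    #   ffmpeg -y -framerate 30.0 -i Bjornulf/temp_frames/frames_<ts>/frame_%05d.png \
    #          -c:v libx264 -pix_fmt yuv420p -preset medium -crf 19 video_<ts>.mp4
    # which produces a 3.0 s video (90 frames / 30 fps).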
def adjust_video_speed(self, video_path, speed_factor, output_path):
"""Adjust video speed using ffmpeg."""
pts_speed = 1 / speed_factor
subprocess.run([
'ffmpeg', '-y',
'-i', video_path,
'-filter:v', f'setpts={pts_speed}*PTS',
'-an',
'-c:v', 'libx264',
'-preset', 'medium',
'-crf', '19',
output_path
], check=True)
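    # setpts rescales frame timestamps: speed_factor 2.0 gives setpts=0.5*PTS,
    # halving the duration; speed_factor 0.5 gives setpts=2.0*PTS, doubling it.
    # Audio is dropped ('-an') and handled separately by process_audio.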
    def create_sync_video(self, video_path, original_duration, target_duration, max_speedup, max_slowdown):
        """Create synchronized version of the video."""
        timestamp = self.generate_timestamp()
        output_path = os.path.join(self.sync_video_dir, f"sync_video_{timestamp}.mp4")
        if target_duration > original_duration:
            speed_ratio = original_duration / target_duration
            if speed_ratio >= max_slowdown:
                # Slow down video within limits
                self.adjust_video_speed(video_path, speed_ratio, output_path)
            else:
                # Repeat video if slowdown would exceed limit
                repeat_count = math.ceil(target_duration / original_duration)
                concat_file = os.path.join(self.sync_video_dir, f"concat_{timestamp}.txt")
                with open(concat_file, 'w') as f:
                    for _ in range(repeat_count):
                        f.write(f"file '{os.path.abspath(video_path)}'\n")
                subprocess.run([
                    'ffmpeg', '-y',
                    '-f', 'concat',
                    '-safe', '0',
                    '-i', concat_file,
                    '-c', 'copy',
                    output_path
                ], check=True)
                os.remove(concat_file)
        else:
            speed_ratio = original_duration / target_duration
            if abs(speed_ratio - 1.0) <= 0.1:
                # Copy video if speed change is minimal
                subprocess.run([
                    'ffmpeg', '-y',
                    '-i', video_path,
                    '-c', 'copy',
                    output_path
                ], check=True)
            else:
                # Speed up video within limits
                speed = min(speed_ratio, max_speedup)
                self.adjust_video_speed(video_path, speed, output_path)
        return os.path.abspath(output_path)
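    # Worked example (assumed numbers): a 10 s video against 12 s of audio gives
    # speed_ratio = 10/12 ≈ 0.83; with max_slowdown = 0.5 that passes the limit,
    # so the video is slowed to 12 s. Against 25 s of audio the ratio is 0.4,
    # which would exceed the limit, so the clip is repeated ceil(25/10) = 3 times.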
def process_audio(self, audio_tensor, sample_rate, target_duration, original_duration,
max_speedup, max_slowdown):
"""Process audio to match video duration."""
if audio_tensor.dim() == 3:
audio_tensor = audio_tensor.squeeze(0)
elif audio_tensor.dim() == 1:
audio_tensor = audio_tensor.unsqueeze(0)
current_duration = audio_tensor.shape[1] / sample_rate
        # Calculate synchronized video duration (same logic as create_sync_video)
        if target_duration > original_duration:
            speed_ratio = original_duration / target_duration
            if speed_ratio >= max_slowdown:
                sync_duration = target_duration
            else:
                sync_duration = math.ceil(target_duration / original_duration) * original_duration
        else:
            speed_ratio = original_duration / target_duration
            if abs(speed_ratio - 1.0) <= 0.1:
                sync_duration = original_duration
            else:
                speed = min(speed_ratio, max_speedup)
                sync_duration = original_duration / speed
        # Adjust audio length: pad with silence or trim to match sync duration
        if current_duration < sync_duration:
            silence_samples = int((sync_duration - current_duration) * sample_rate)
            silence = torch.zeros(audio_tensor.shape[0], silence_samples)
            processed_audio = torch.cat([audio_tensor, silence], dim=1)
        else:
            required_samples = int(sync_duration * sample_rate)
            processed_audio = audio_tensor[:, :required_samples]
        return processed_audio, sync_duration
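    # Worked example (assumed numbers): with sync_duration = 12 s, a 9.8 s track
    # at 44100 Hz is padded with int(2.2 * 44100) = 97020 zero samples; a 13 s
    # track is trimmed to int(12 * 44100) = 529200 samples.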
def save_audio(self, audio_tensor, sample_rate, target_duration, original_duration,
max_speedup, max_slowdown):
"""Save processed audio to file."""
timestamp = self.generate_timestamp()
output_path = os.path.join(self.sync_audio_dir, f"sync_audio_{timestamp}.wav")
processed_audio, sync_duration = self.process_audio(
audio_tensor, sample_rate, target_duration, original_duration,
max_speedup, max_slowdown
)
torchaudio.save(output_path, processed_audio, sample_rate)
return os.path.abspath(output_path)
def load_audio_from_path(self, audio_path):
"""Load audio from file path."""
waveform, sample_rate = torchaudio.load(audio_path)
return {'waveform': waveform, 'sample_rate': sample_rate}
def extract_frames(self, video_path):
"""Extract all frames of the video as a tensor."""
temp_dir = os.path.join(self.temp_dir, "temp_frames")
os.makedirs(temp_dir, exist_ok=True)
# Extract frames using ffmpeg
subprocess.run([
'ffmpeg', '-i', video_path,
os.path.join(temp_dir, 'frame_%05d.png')
], check=True)
# Load frames and convert to tensor
frames = []
frame_files = sorted(os.listdir(temp_dir))
transform = transforms.Compose([transforms.ToTensor()])
for frame_file in frame_files:
image = Image.open(os.path.join(temp_dir, frame_file))
frame_tensor = transform(image)
frames.append(frame_tensor)
# Stack frames into a single tensor
frames_tensor = torch.stack(frames)
# Clean up temporary directory
for frame_file in frame_files:
os.remove(os.path.join(temp_dir, frame_file))
os.rmdir(temp_dir)
return frames_tensor
def sync_audio_video(self, max_speedup=1.5, max_slowdown=0.5,
AUDIO=None, audio_path="", audio_duration=None,
video_path="", IMAGES=None, output_fps=30.0):
"""Main function to synchronize audio and video."""
self.validate_speed_limits(max_speedup, max_slowdown)
# Handle audio input
if AUDIO is None and not audio_path:
raise ValueError("Either AUDIO or audio_path must be provided")
print(f"sync audio saved to: {filename}")
return os.path.abspath(filename)
if audio_path:
AUDIO = self.load_audio_from_path(audio_path)
self.validate_audio_input(AUDIO)
# Calculate audio duration if not provided
if audio_duration is None or audio_duration == 0.0:
audio_duration = self.get_audio_duration(AUDIO)
logging.info(f"Audio duration: {audio_duration}")
# Process input source
if IMAGES is not None and len(IMAGES) > 0:
video_path = self.process_images_to_video(IMAGES, output_fps)
original_duration = len(IMAGES) / output_fps
video_fps = output_fps
original_frame_count = len(IMAGES)
elif video_path:
original_duration, video_fps, original_frame_count = self.get_video_info(video_path)
else:
raise ValueError("Either video_path or IMAGES must be provided")
# Create synchronized versions
sync_video_path = self.create_sync_video(
video_path, original_duration, audio_duration, max_speedup, max_slowdown
)
sync_audio_path = self.save_audio(
AUDIO['waveform'], AUDIO['sample_rate'], audio_duration,
original_duration, max_speedup, max_slowdown
)
# Get final properties
sync_video_duration, _, sync_frame_count = self.get_video_info(sync_video_path)
sync_audio_duration = torchaudio.info(sync_audio_path).num_frames / AUDIO['sample_rate']
video_frames = self.extract_frames(sync_video_path)
return (
video_frames,
AUDIO,
sync_audio_path,
sync_video_path,
original_duration, # input_video_duration
sync_video_duration,
audio_duration, # input_audio_duration
sync_audio_duration,
sync_frame_count
)
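# Minimal usage sketch outside ComfyUI; assumes ffmpeg/ffprobe are on PATH and
# that "speech.wav" / "input.mp4" are placeholder files supplied by the caller.
if __name__ == "__main__":
    node = AudioVideoSync()
    waveform, sr = torchaudio.load("speech.wav")  # placeholder input
    audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sr}
    (frames, audio_out, sync_audio_path, sync_video_path,
     in_video_dur, sync_video_dur, in_audio_dur, sync_audio_dur,
     frame_count) = node.sync_audio_video(
        max_speedup=1.5, max_slowdown=0.5,
        AUDIO=audio, video_path="input.mp4",  # placeholder input
    )
    print(f"synced: {sync_video_path} ({frame_count} frames, {sync_video_dur:.2f}s)")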