Files
Bjornulf_custom_nodes/split_image.py
justumen 39dfb0220a 0.77
2025-03-19 17:36:25 +01:00

206 lines
9.0 KiB
Python

import torch
class SplitImageGrid:
    """Split an image batch into a rows x columns grid of overlapping tiles.

    Tiles are numbered 1..rows*columns in row-major order and exposed on the
    fixed ``part_1``..``part_9`` outputs; outputs beyond the grid are ``None``.
    Each tile extends ``OVERLAP`` pixels past its nominal cell on every side,
    clamped at the image borders. When H or W is not divisible by rows/columns,
    the remainder pixels on the bottom/right edge (beyond the overlap of the
    last tile) are not contained in any tile.
    """

    # Pixels each tile extends beyond its nominal cell (clamped at the borders).
    # ReassembleImageGrid crops with the same value; keep the two in sync.
    OVERLAP = 2
    # Number of fixed part_N output sockets.
    _MAX_PARTS = 9

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "rows": ("INT", {"default": 1, "min": 1, "max": 9}),
                "columns": ("INT", {"default": 1, "min": 1, "max": 9}),
                "MODIFIED_part_index": ("INT", {"default": 1, "min": 1, "max": 9}),
            }
        }

    RETURN_TYPES = ["IMAGE"] * _MAX_PARTS + ["INT", "INT", "IMAGE", "INT"]
    RETURN_NAMES = [f"part_{i}" for i in range(1, _MAX_PARTS + 1)] + ["rows", "columns", "MODIFIED_part", "MODIFIED_part_index"]
    FUNCTION = "split"
    CATEGORY = "image"

    def split(self, image, rows, columns, MODIFIED_part_index):
        """Split ``image`` into overlapping grid tiles.

        Args:
            image: image batch unpacked as (B, H, W, C).
            rows: number of grid rows (>= 1).
            columns: number of grid columns (>= 1).
            MODIFIED_part_index: 1-based index of the tile also returned on
                the ``MODIFIED_part`` output.

        Returns:
            A 13-tuple: part_1..part_9 (``None`` past rows*columns), rows,
            columns, the selected tile, and MODIFIED_part_index.

        Raises:
            ValueError: if the grid has more tiles than there are part
                outputs, or MODIFIED_part_index is outside 1..rows*columns.
        """
        n_parts = rows * columns
        # The return tuple must line up with RETURN_TYPES; more than 9 tiles
        # would shift rows/columns/MODIFIED_* onto the wrong sockets.
        if n_parts > self._MAX_PARTS:
            raise ValueError(
                f"A {rows}x{columns} grid has {n_parts} parts, but this node "
                f"only provides {self._MAX_PARTS} part outputs"
            )
        if not 1 <= MODIFIED_part_index <= n_parts:
            raise ValueError(f"MODIFIED_part_index {MODIFIED_part_index} is out of range for {rows}x{columns} grid")

        _, H, W, _ = image.shape
        part_height = H // rows
        part_width = W // columns
        O = self.OVERLAP

        parts = []
        for r in range(rows):
            # Vertical bounds are shared by every tile in this row.
            h_start = max(0, r * part_height - O)
            h_end = min(H, (r + 1) * part_height + O)
            for c in range(columns):
                w_start = max(0, c * part_width - O)
                w_end = min(W, (c + 1) * part_width + O)
                parts.append(image[:, h_start:h_end, w_start:w_end, :])

        MODIFIED_part = parts[MODIFIED_part_index - 1]
        # Unused part sockets are filled with None.
        parts.extend([None] * (self._MAX_PARTS - n_parts))
        return (*parts, rows, columns, MODIFIED_part, MODIFIED_part_index)
class ReassembleImageGrid:
    """Paste (possibly modified) grid tiles back into a copy of an image batch.

    Parts are numbered 1..9 in row-major order over a rows x columns grid and
    are expected to carry the same overlap margins SplitImageGrid adds
    (``OVERLAP`` pixels on every side that is not an image border). Those
    margins are cropped away and only each tile's nominal cell is written.
    Cells with no part keep the pixels of ``original``.
    """

    # Must match SplitImageGrid's overlap so tile margins are cropped exactly.
    OVERLAP = 2

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "original": ("IMAGE",),
                "rows": ("INT", {"default": 1, "min": 1, "max": 10}),
                "columns": ("INT", {"default": 1, "min": 1, "max": 10}),
            },
            "optional": {
                "part_1": ("IMAGE",),
                "part_2": ("IMAGE",),
                "part_3": ("IMAGE",),
                "part_4": ("IMAGE",),
                "part_5": ("IMAGE",),
                "part_6": ("IMAGE",),
                "part_7": ("IMAGE",),
                "part_8": ("IMAGE",),
                "part_9": ("IMAGE",),
                "MODIFIED_part": ("IMAGE",),
                "MODIFIED_part_index": ("INT", {"default": 0, "min": 0, "max": 9}),
                "reference_video_part_index": ("INT", {"default": 0, "min": 0, "max": 9}),
                "auto_resize": ("BOOLEAN", {"default": True}),  # Resize mis-sized parts instead of raising
            }
        }

    RETURN_TYPES = ["IMAGE"]
    RETURN_NAMES = ["image"]
    FUNCTION = "reassemble"
    CATEGORY = "image"

    def repeat_frames(self, tensor, k):
        """Tile ``tensor`` k times along the batch (frame) dimension.

        NOTE(review): this repeats the whole clip (frames 0..B-1, 0..B-1, ...)
        rather than duplicating each frame in place; for B > 1 that differs
        from the temporal stretch done by ``adjust_frame_count``. Confirm this
        is intended.
        """
        return tensor.repeat(k, 1, 1, 1) if k > 1 else tensor

    def adjust_frame_count_with_repeat(self, tensor, B_ref, B_original):
        """Bring ``tensor`` to ``B_ref`` frames.

        If B_ref is within one frame of an integer multiple k of B_original,
        the clip is tiled k times and then trimmed, or padded with its last
        frame, to exactly B_ref. Otherwise it falls back to
        ``adjust_frame_count``.

        Raises:
            ValueError: if B_original is zero.
        """
        if B_original == 0:
            raise ValueError("Original frame count is zero")
        k = round(B_ref / B_original)
        if k > 0 and abs(B_ref - k * B_original) <= 1:
            repeated = self.repeat_frames(tensor, k)
            if repeated.shape[0] > B_ref:
                repeated = repeated[:B_ref]
            elif repeated.shape[0] < B_ref:
                pad_size = B_ref - repeated.shape[0]
                last_frame = repeated[-1:].repeat(pad_size, 1, 1, 1)
                repeated = torch.cat([repeated, last_frame], dim=0)
            return repeated
        return self.adjust_frame_count(tensor, B_ref)

    def adjust_frame_count(self, tensor, target_B):
        """Resample ``tensor`` to ``target_B`` frames.

        Frame indices are spread evenly over the source clip, so frames are
        repeated when stretching and skipped when shrinking.
        """
        B = tensor.shape[0]
        if B == target_B:
            return tensor
        indices = torch.linspace(0, B - 1, steps=target_B).round().long()
        indices = indices.clamp(0, B - 1)
        return tensor[indices]

    def resize_tensor(self, tensor, target_height, target_width):
        """Bilinearly resize a (B, H, W, C) tensor to (B, target_height, target_width, C)."""
        import torch.nn.functional as F
        # F.interpolate expects channels-first (B, C, H, W).
        tensor_BCHW = tensor.permute(0, 3, 1, 2)
        resized = F.interpolate(
            tensor_BCHW,
            size=(target_height, target_width),
            mode='bilinear',
            align_corners=False
        )
        return resized.permute(0, 2, 3, 1)

    def _extract_cell(self, part, i, row, col, H, W, part_height, part_width, C, auto_resize):
        """Return the nominal (part_height x part_width) cell of tile ``i`` without its overlap margins."""
        O = self.OVERLAP
        # Reproduce the exact slice bounds SplitImageGrid used for this cell,
        # including the clamping at the image borders.
        top = row * part_height
        left = col * part_width
        h_start = max(0, top - O)
        h_end = min(H, top + part_height + O)
        w_start = max(0, left - O)
        w_end = min(W, left + part_width + O)
        expected_h = h_end - h_start
        expected_w = w_end - w_start
        if part.shape[1] != expected_h or part.shape[2] != expected_w:
            if not auto_resize:
                raise ValueError(f"Part {i} has incorrect shape. Expected ({expected_h}, {expected_w}, {C}), got {tuple(part.shape[1:])}")
            # Resize the whole tile, margins included, so that e.g. an
            # upscaled tile maps back onto its overlapped footprint before
            # cropping.
            print(f"Resizing part {i} from {tuple(part.shape[1:3])} to ({expected_h}, {expected_w})")
            part = self.resize_tensor(part, expected_h, expected_w)
        # Drop the margins on all four sides (those at image borders are empty).
        crop_top = top - h_start
        crop_left = left - w_start
        return part[:, crop_top:crop_top + part_height, crop_left:crop_left + part_width, :]

    def reassemble(self, original, rows, columns, part_1=None, part_2=None, part_3=None,
                   part_4=None, part_5=None, part_6=None, part_7=None, part_8=None,
                   part_9=None, MODIFIED_part=None, MODIFIED_part_index=0,
                   reference_video_part_index=0, auto_resize=True):
        """Rebuild an image batch from (possibly modified) grid tiles.

        Args:
            original: base image batch unpacked as (B, H, W, C); cells without
                a part keep these pixels.
            rows, columns: grid layout used when splitting.
            part_1..part_9: optional tiles in row-major order, including the
                overlap margins added by the split.
            MODIFIED_part: optional tile that replaces part_<MODIFIED_part_index>.
            MODIFIED_part_index: 1-based target for MODIFIED_part; 0 ignores it.
            reference_video_part_index: 1-based index of the part whose frame
                count the original and all other parts are adjusted to; 0 keeps
                the original frame count.
            auto_resize: resize tiles whose size does not match the expected
                overlapped size; when False such tiles raise ValueError.

        Returns:
            A 1-tuple containing the reassembled image batch.

        Raises:
            ValueError: on an invalid MODIFIED/reference index, a missing
                reference part, a part outside the grid, a mis-sized part
                (auto_resize=False), or a frame-count mismatch.
        """
        B, H, W, C = original.shape
        part_height = H // rows
        part_width = W // columns

        parts = [part_1, part_2, part_3, part_4, part_5, part_6, part_7, part_8, part_9]

        if MODIFIED_part is not None and MODIFIED_part_index > 0:
            index = MODIFIED_part_index - 1
            if index >= len(parts):
                raise ValueError(f"Invalid MODIFIED_part_index: {MODIFIED_part_index}")
            parts[index] = MODIFIED_part

        if reference_video_part_index > 0:
            ref_index = reference_video_part_index - 1
            if ref_index >= len(parts) or parts[ref_index] is None:
                raise ValueError(f"Reference part {reference_video_part_index} is not provided")
            B_ref = parts[ref_index].shape[0]
            original = self.adjust_frame_count_with_repeat(original, B_ref, B)
            # Each part is adjusted from its own frame count, which may differ
            # from the original's if it was modified.
            parts = [
                part if part is None or i == ref_index
                else self.adjust_frame_count_with_repeat(part, B_ref, part.shape[0])
                for i, part in enumerate(parts)
            ]
        else:
            B_ref = B

        # Clone so the caller's tensor is not modified in place.
        reassembled = original.clone()
        n_cells = rows * columns
        for i, part in enumerate(parts, start=1):
            if part is None:
                continue
            if i > n_cells:
                raise ValueError(f"Part {i} is outside the {rows}x{columns} grid")
            row, col = divmod(i - 1, columns)
            cell = self._extract_cell(part, i, row, col, H, W, part_height, part_width, C, auto_resize)
            if cell.shape[0] != B_ref:
                raise ValueError(f"Cropped part {i} has incorrect frame count. Expected {B_ref}, got {cell.shape[0]}")
            h_start = row * part_height
            w_start = col * part_width
            reassembled[:, h_start:h_start + part_height, w_start:w_start + part_width, :] = cell
        return (reassembled,)