mirror of
https://github.com/justUmen/Bjornulf_custom_nodes.git
synced 2026-03-21 12:42:11 -03:00
206 lines
9.0 KiB
Python
206 lines
9.0 KiB
Python
import torch
|
|
|
|
class SplitImageGrid:
    """ComfyUI node that cuts an image batch into a grid of overlapping tiles.

    Tiles are produced in row-major order. Each tile covers one grid cell
    (``H // rows`` by ``W // columns`` pixels) extended by ``_OVERLAP`` pixels
    on every side, clipped at the image border. ``ReassembleImageGrid`` uses the
    same overlap to trim tiles back to their cells. When ``H`` or ``W`` is not
    divisible by the grid, the leftover bottom rows / right columns are only
    covered up to the overlap by the last row / column of tiles.
    """

    # Extra context pixels on each side of a tile; must match ReassembleImageGrid.
    _OVERLAP = 2
    # Number of part_* output sockets, hence the maximum number of grid cells.
    _MAX_PARTS = 9

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image": ("IMAGE",),
                "rows": ("INT", {"default": 1, "min": 1, "max": 9}),
                "columns": ("INT", {"default": 1, "min": 1, "max": 9}),
                "MODIFIED_part_index": ("INT", {"default": 1, "min": 1, "max": 9}),
            }
        }

    RETURN_TYPES = ["IMAGE"] * 9 + ["INT", "INT", "IMAGE", "INT"]
    RETURN_NAMES = [f"part_{i}" for i in range(1, 10)] + ["rows", "columns", "MODIFIED_part", "MODIFIED_part_index"]
    FUNCTION = "split"
    CATEGORY = "image"

    def split(self, image, rows, columns, MODIFIED_part_index):
        """Cut ``image`` into ``rows`` x ``columns`` overlapping tiles.

        Args:
            image: Tensor indexed ``[batch, height, width, channels]`` (the
                shape is unpacked that way below).
            rows: Number of grid rows.
            columns: Number of grid columns; ``rows * columns`` must be <= 9.
            MODIFIED_part_index: 1-based index of the tile that is also
                returned as ``MODIFIED_part``.

        Returns:
            A 13-tuple: 9 tiles in row-major order (``None`` for unused
            slots), then ``rows``, ``columns``, the selected tile and
            ``MODIFIED_part_index``, matching ``RETURN_TYPES``.

        Raises:
            ValueError: if the grid has more cells than there are part
                outputs, the image is smaller than the grid, or
                ``MODIFIED_part_index`` lies outside the grid.
        """
        _, height, width, _ = image.shape
        cell_count = rows * columns

        # More cells than output sockets would shift extra tiles into the
        # rows/columns/MODIFIED_part outputs, so refuse such grids.
        if cell_count > self._MAX_PARTS:
            raise ValueError(
                f"A {rows}x{columns} grid has {cell_count} cells, but at most "
                f"{self._MAX_PARTS} parts can be output"
            )
        if height < rows or width < columns:
            raise ValueError(
                f"Image of size {height}x{width} is too small for a {rows}x{columns} grid"
            )

        modified_index = MODIFIED_part_index - 1
        if not 0 <= modified_index < cell_count:
            raise ValueError(f"MODIFIED_part_index {MODIFIED_part_index} is out of range for {rows}x{columns} grid")

        part_height = height // rows
        part_width = width // columns
        overlap = self._OVERLAP

        parts = []
        for r in range(rows):
            # Extend each tile by `overlap` pixels, clipped to the image.
            top = max(0, r * part_height - overlap)
            bottom = min(height, (r + 1) * part_height + overlap)
            for c in range(columns):
                left = max(0, c * part_width - overlap)
                right = min(width, (c + 1) * part_width + overlap)
                parts.append(image[:, top:bottom, left:right, :])

        modified_part = parts[modified_index]
        # Unused part outputs are reported as None.
        parts.extend([None] * (self._MAX_PARTS - len(parts)))

        return tuple(parts + [rows, columns, modified_part, MODIFIED_part_index])
|
|
|
|
class ReassembleImageGrid:
    """ComfyUI node that pastes (possibly processed) grid tiles back onto an image.

    Companion of ``SplitImageGrid``: every part is expected to cover the same
    region the splitter cut for its cell, i.e. the cell plus ``_OVERLAP``
    pixels on each side, clipped at the image border. The overlap margins are
    trimmed so that only the cell itself is written back. Cells without a
    connected part, and any remainder rows/columns when the image size is not
    divisible by the grid, keep the pixels of ``original``.
    """

    # Overlap pixels per tile side; must match SplitImageGrid.
    _OVERLAP = 2
    # Number of part_* inputs.
    _MAX_PARTS = 9

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "original": ("IMAGE",),
                "rows": ("INT", {"default": 1, "min": 1, "max": 10}),
                "columns": ("INT", {"default": 1, "min": 1, "max": 10}),
            },
            "optional": {
                "part_1": ("IMAGE",),
                "part_2": ("IMAGE",),
                "part_3": ("IMAGE",),
                "part_4": ("IMAGE",),
                "part_5": ("IMAGE",),
                "part_6": ("IMAGE",),
                "part_7": ("IMAGE",),
                "part_8": ("IMAGE",),
                "part_9": ("IMAGE",),
                "MODIFIED_part": ("IMAGE",),
                "MODIFIED_part_index": ("INT", {"default": 0, "min": 0, "max": 9}),
                "reference_video_part_index": ("INT", {"default": 0, "min": 0, "max": 9}),
                "auto_resize": ("BOOLEAN", {"default": True}),  # Resize mismatched parts instead of raising
            }
        }

    RETURN_TYPES = ["IMAGE"]
    RETURN_NAMES = ["image"]
    FUNCTION = "reassemble"
    CATEGORY = "image"

    def repeat_frames(self, tensor, k):
        """Tile ``tensor`` ``k`` times along dim 0; returned unchanged when ``k <= 1``.

        NOTE(review): ``Tensor.repeat`` tiles the whole clip (A B C A B C); it
        does not duplicate each frame (A A B B C C). If parts are frame-
        interpolated versions of the original, ``repeat_interleave`` may be
        what is intended -- confirm before relying on temporal alignment.
        """
        return tensor.repeat(k, 1, 1, 1) if k > 1 else tensor

    def adjust_frame_count_with_repeat(self, tensor, B_ref, B_original):
        """Bring ``tensor`` to ``B_ref`` frames, preferring whole-clip repetition.

        If ``B_ref`` is within one frame of an integer multiple ``k`` of
        ``B_original``, the clip is tiled ``k`` times and then truncated, or
        padded with copies of its last frame, to exactly ``B_ref`` frames.
        Otherwise frames are resampled by ``adjust_frame_count``.

        Raises:
            ValueError: if ``B_original`` is zero.
        """
        if B_original == 0:
            raise ValueError("Original frame count is zero")
        k = round(B_ref / B_original)
        if k > 0 and abs(B_ref - k * B_original) <= 1:
            repeated = self.repeat_frames(tensor, k)
            if repeated.shape[0] > B_ref:
                repeated = repeated[:B_ref]
            elif repeated.shape[0] < B_ref:
                pad_size = B_ref - repeated.shape[0]
                last_frame = repeated[-1:].repeat(pad_size, 1, 1, 1)
                repeated = torch.cat([repeated, last_frame], dim=0)
            return repeated
        return self.adjust_frame_count(tensor, B_ref)

    def adjust_frame_count(self, tensor, target_B):
        """Resample ``tensor`` to ``target_B`` frames by picking evenly spaced indices.

        Frames are repeated when upsampling and skipped when downsampling; the
        tensor is returned as-is when the count already matches.
        """
        B = tensor.shape[0]
        if B == target_B:
            return tensor
        indices = torch.linspace(0, B - 1, steps=target_B).round().long()
        indices = indices.clamp(0, B - 1)
        return tensor[indices]

    def resize_tensor(self, tensor, target_height, target_width):
        """Bilinearly resize a ``[batch, height, width, channels]`` tensor.

        ``F.interpolate`` works on ``[batch, channels, height, width]``, so the
        tensor is permuted before and after resizing.
        """
        import torch.nn.functional as F

        tensor_BCHW = tensor.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
        resized = F.interpolate(
            tensor_BCHW,
            size=(target_height, target_width),
            mode='bilinear',
            align_corners=False
        )
        return resized.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]

    @staticmethod
    def _tile_bounds(index, cell, total, overlap):
        """Return the ``[start, end)`` span SplitImageGrid cuts for cell ``index`` on one axis."""
        return max(0, index * cell - overlap), min(total, (index + 1) * cell + overlap)

    def reassemble(self, original, rows, columns, part_1=None, part_2=None, part_3=None,
                   part_4=None, part_5=None, part_6=None, part_7=None, part_8=None,
                   part_9=None, MODIFIED_part=None, MODIFIED_part_index=0,
                   reference_video_part_index=0, auto_resize=True):
        """Write the provided tiles back into a copy of ``original``.

        Args:
            original: Image batch indexed ``[batch, height, width, channels]``
                (as unpacked below); cells without a part keep these pixels.
            rows, columns: Grid dimensions used when the parts were split.
            part_1 ... part_9: Optional tiles in row-major order, each expected
                to cover its cell plus the split overlap.
            MODIFIED_part: Optional tile that replaces
                ``part_<MODIFIED_part_index>`` when that index is > 0.
            reference_video_part_index: 1-based index of a part whose frame
                count the output follows; 0 keeps the frame count of ``original``.
            auto_resize: Resize parts whose spatial size differs from the split
                tile size; when False such parts raise instead.

        Returns:
            A 1-tuple containing the reassembled image batch.

        Raises:
            ValueError: on an invalid MODIFIED/reference index, a missing
                reference part, a part outside the ``rows x columns`` grid, a
                size mismatch with ``auto_resize`` disabled, or a frame-count
                mismatch.
        """
        B, H, W, C = original.shape
        part_height = H // rows
        part_width = W // columns
        overlap = self._OVERLAP
        cell_count = rows * columns
        parts = [part_1, part_2, part_3, part_4, part_5, part_6, part_7, part_8, part_9]

        # MODIFIED_part takes precedence over the regular input at its index.
        if MODIFIED_part is not None and MODIFIED_part_index > 0:
            index = MODIFIED_part_index - 1
            if index >= len(parts):
                raise ValueError(f"Invalid MODIFIED_part_index: {MODIFIED_part_index}")
            parts[index] = MODIFIED_part

        # Align frame counts of the original and all parts to the reference part.
        if reference_video_part_index > 0:
            ref_index = reference_video_part_index - 1
            if ref_index >= len(parts) or parts[ref_index] is None:
                raise ValueError(f"Reference part {reference_video_part_index} is not provided")
            B_ref = parts[ref_index].shape[0]
            original = self.adjust_frame_count_with_repeat(original, B_ref, B)
            for i, part in enumerate(parts):
                if part is not None and i != ref_index:
                    parts[i] = self.adjust_frame_count_with_repeat(part, B_ref, B)
        else:
            B_ref = B

        # Clone so the caller's tensor is never modified in place.
        reassembled = original.clone()

        for i, part in enumerate(parts, start=1):
            if part is None:
                continue
            if i > cell_count:
                raise ValueError(
                    f"Part {i} is connected, but a {rows}x{columns} grid only has {cell_count} cells"
                )
            row, col = divmod(i - 1, columns)

            # Region this tile occupied when SplitImageGrid cut it out
            # (cell + overlap on every side, clipped at the borders).
            src_top, src_bottom = self._tile_bounds(row, part_height, H, overlap)
            src_left, src_right = self._tile_bounds(col, part_width, W, overlap)
            expected_h = src_bottom - src_top
            expected_w = src_right - src_left

            # Resize BEFORE trimming so the overlap margins scale with the part.
            if part.shape[1] != expected_h or part.shape[2] != expected_w:
                if not auto_resize:
                    raise ValueError(
                        f"Part {i} has incorrect shape. Expected ({expected_h}, {expected_w}, {C}), "
                        f"got {tuple(part.shape[1:])}"
                    )
                print(f"Resizing part {i} from {tuple(part.shape[1:3])} to ({expected_h}, {expected_w})")
                part = self.resize_tensor(part, expected_h, expected_w)

            if part.shape[0] != B_ref:
                raise ValueError(f"Cropped part {i} has incorrect frame count. Expected {B_ref}, got {part.shape[0]}")

            h_start = row * part_height
            w_start = col * part_width
            # Offset of the cell inside the tile = overlap actually added on top/left.
            crop_top = h_start - src_top
            crop_left = w_start - src_left
            cropped_part = part[:, crop_top:crop_top + part_height, crop_left:crop_left + part_width, :]

            reassembled[:, h_start:h_start + part_height, w_start:w_start + part_width, :] = cropped_part

        return (reassembled,)