mirror of
https://github.com/justUmen/Bjornulf_custom_nodes.git
synced 2026-03-21 12:42:11 -03:00
46 lines
1.4 KiB
Python
46 lines
1.4 KiB
Python
import torch
|
|
|
|
class ImageMaskCutter:
    """Attach a mask to an image as its alpha channel.

    The class exposes the node-style attributes ``INPUT_TYPES``,
    ``RETURN_TYPES``, ``FUNCTION`` and ``CATEGORY`` (presumably consumed by a
    node-graph host such as ComfyUI -- confirm against the loader).
    """

    @classmethod
    def INPUT_TYPES(s):
        """Declare the two required inputs: an ``IMAGE`` and a ``MASK``."""
        return {
            "required": {
                "image": ("IMAGE",),
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "cut_image"

    CATEGORY = "Bjornulf"

    def cut_image(self, image, mask):
        """Return ``image`` with ``mask`` written into its fourth channel.

        Args:
            image: 4-D tensor. If its last dimension is 3 or 4 it is treated
                as channels-last and permuted to channels-first for
                processing.
            mask: 3-D tensor whose trailing two dimensions must equal the
                image's spatial dimensions. Values are converted to float and
                clamped to ``[0, 1]``.

        Returns:
            A 1-tuple holding the image with the clamped mask as channel
            index 3, permuted so that channels are the last dimension.

        Raises:
            ValueError: If the image's spatial dimensions differ from the
                mask's trailing dimensions.
        """
        print(f"Image shape: {image.shape}")
        print(f"Mask shape: {mask.shape}")

        # Check if image channels are in the last dimension; if so, move them
        # to the second dimension. ``permute`` returns a view, so ``image``
        # still shares storage with the caller's tensor at this point.
        if image.shape[-1] == 3 or image.shape[-1] == 4:
            image = image.permute(0, 3, 1, 2)

        # Ensure image and mask have compatible dimensions.
        if image.shape[2:] != mask.shape[1:]:
            raise ValueError(f"Image and mask must have compatible dimensions. Got image shape {image.shape} and mask shape {mask.shape}")

        # Convert mask to float and ensure it's in the range [0, 1].
        mask = mask.float()
        mask = torch.clamp(mask, 0, 1)

        if image.shape[1] == 3:
            # RGB input: append an alpha channel. ``torch.cat`` allocates a
            # new tensor, so the caller's data is not touched.
            alpha = torch.ones((image.shape[0], 1, image.shape[2], image.shape[3]), device=image.device)
            image = torch.cat([image, alpha], dim=1)
        else:
            # The image is still a view of the caller's tensor here. Copy it
            # so that writing the alpha channel below does not silently
            # overwrite the input (which may be shared with other consumers).
            image = image.clone()

        # Use the mask as the alpha channel.
        image[:, 3:4, :, :] = mask.unsqueeze(1)

        # Move channels back to the last dimension.
        cut_image = image.permute(0, 2, 3, 1)

        return (cut_image,)