mirror of
https://github.com/justUmen/Bjornulf_custom_nodes.git
synced 2026-03-21 20:52:11 -03:00
76 lines
2.8 KiB
Python
76 lines
2.8 KiB
Python
import numpy as np
|
|
from PIL import Image
|
|
import torch
|
|
import torchvision.transforms as transforms
|
|
|
|
class GreenScreenToTransparency:
|
|
@classmethod
|
|
def INPUT_TYPES(cls):
|
|
return {
|
|
"required": {
|
|
"image": ("IMAGE", {}),
|
|
"threshold": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01}),
|
|
},
|
|
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
|
}
|
|
|
|
FUNCTION = "remove_green_screen"
|
|
RETURN_TYPES = ("IMAGE",)
|
|
OUTPUT_NODE = True
|
|
CATEGORY = "Bjornulf"
|
|
|
|
def remove_green_screen(self, image, threshold=0.1, prompt=None, extra_pnginfo=None):
|
|
# Ensure the input image is on CPU and convert to numpy array
|
|
image_np = image.cpu().numpy()
|
|
|
|
# Check if the image is in the format [batch, height, width, channel]
|
|
if image_np.ndim == 4:
|
|
# If so, we'll process each image in the batch
|
|
processed_images = []
|
|
for img in image_np:
|
|
processed_img = self._process_single_image(img, threshold)
|
|
processed_images.append(processed_img)
|
|
|
|
# Stack the processed images back into a batch
|
|
processed_batch = np.stack(processed_images)
|
|
# Convert to torch tensor
|
|
processed_tensor = torch.from_numpy(processed_batch)
|
|
else:
|
|
# If it's a single image, process it directly
|
|
processed_np = self._process_single_image(image_np, threshold)
|
|
# Add batch dimension if it was originally present
|
|
if image.dim() == 4:
|
|
processed_np = np.expand_dims(processed_np, axis=0)
|
|
# Convert to torch tensor
|
|
processed_tensor = torch.from_numpy(processed_np)
|
|
|
|
# Update metadata if needed
|
|
if extra_pnginfo is not None:
|
|
extra_pnginfo["green_screen_removed"] = True
|
|
|
|
return (processed_tensor, prompt, extra_pnginfo)
|
|
|
|
def _process_single_image(self, img, threshold):
|
|
# Convert to PIL Image
|
|
pil_img = Image.fromarray((img * 255).astype(np.uint8))
|
|
|
|
# Convert the image to RGBA mode
|
|
pil_img = pil_img.convert("RGBA")
|
|
|
|
# Get image data as numpy array
|
|
data = np.array(pil_img)
|
|
|
|
# Create a mask for green pixels
|
|
r, g, b, a = data[:,:,0], data[:,:,1], data[:,:,2], data[:,:,3]
|
|
mask = (g > r + threshold * 255) & (g > b + threshold * 255)
|
|
|
|
# Set alpha channel to 0 for green pixels
|
|
data[:,:,3] = np.where(mask, 0, a)
|
|
|
|
# Create a new image with the updated data
|
|
result = Image.fromarray(data)
|
|
|
|
# Convert back to numpy and normalize
|
|
processed_np = np.array(result).astype(np.float32) / 255.0
|
|
|
|
return processed_np |