mirror of https://github.com/justUmen/Bjornulf_custom_nodes.git
synced 2026-03-21 12:42:11 -03:00
83 lines · 3.0 KiB · Python
import comfy
from comfy import model_management
import torch
import comfy.utils


class ImageUpscaleWithModelTransparency:
    """Upscale an image with a model while preserving transparency.

    The RGB channels are run through the upscale model's tiled scaler;
    a fourth (alpha) channel, when present, is resized separately with
    bilinear interpolation and re-attached to the upscaled RGB result.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "upscale_model": ("UPSCALE_MODEL",),
                "image": ("IMAGE",),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"
    CATEGORY = "image/upscaling"

    def upscale(self, upscale_model, image):
        """Return a 1-tuple with the upscaled image (alpha kept if present)."""
        device = model_management.get_torch_device()

        # A 4-channel image carries transparency that the model cannot
        # process directly, so split it off and handle it on its own.
        has_alpha = image.shape[-1] == 4
        rgb = image[..., :3] if has_alpha else image
        alpha = image[..., 3:4] if has_alpha else None

        # Rough VRAM estimate: model weights + tiled working set + input tensor.
        needed = model_management.module_size(upscale_model.model)
        needed += (512 * 512 * 3) * rgb.element_size() * max(upscale_model.scale, 1.0) * 384.0
        needed += rgb.nelement() * rgb.element_size()
        if has_alpha:
            # Account for the interpolated alpha buffer as well.
            needed += alpha.nelement() * alpha.element_size() * max(upscale_model.scale, 1.0) ** 2
        model_management.free_memory(needed, device)

        upscale_model.to(device)

        # Model expects channels-first layout.
        src = rgb.movedim(-1, -3).to(device)

        tile, overlap = 512, 32
        while True:
            try:
                steps = src.shape[0] * comfy.utils.get_tiled_scale_steps(
                    src.shape[3], src.shape[2],
                    tile_x=tile, tile_y=tile, overlap=overlap,
                )
                progress = comfy.utils.ProgressBar(steps)
                scaled = comfy.utils.tiled_scale(
                    src, lambda a: upscale_model(a),
                    tile_x=tile, tile_y=tile, overlap=overlap,
                    upscale_amount=upscale_model.scale, pbar=progress,
                )
                break
            except model_management.OOM_EXCEPTION as e:
                # Halve the tile size and retry; give up below a sane minimum.
                tile //= 2
                if tile < 128:
                    raise e

        upscale_model.to("cpu")
        scaled = torch.clamp(scaled.movedim(-3, -1), min=0, max=1.0)

        if not has_alpha:
            return (scaled,)

        # Bilinear resize keeps soft transparency edges smooth.
        alpha_big = torch.nn.functional.interpolate(
            alpha.movedim(-1, -3).to(device),
            size=(scaled.shape[1], scaled.shape[2]),
            mode='bilinear',
            align_corners=False,
        ).movedim(-3, -1).to(scaled.device)
        alpha_big = torch.clamp(alpha_big, min=0, max=1.0)

        return (torch.cat([scaled, alpha_big], dim=-1),)