mirror of
https://github.com/justUmen/Bjornulf_custom_nodes.git
synced 2026-03-25 06:22:13 -03:00
v1.1.7
This commit is contained in:
83
upscaler_transparency.py
Normal file
83
upscaler_transparency.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import comfy
|
||||
from comfy import model_management
|
||||
import torch
|
||||
import comfy.utils
|
||||
|
||||
class ImageUpscaleWithModelTransparency:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "upscale_model": ("UPSCALE_MODEL",),
|
||||
"image": ("IMAGE",),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "upscale"
|
||||
|
||||
CATEGORY = "image/upscaling"
|
||||
|
||||
def upscale(self, upscale_model, image):
|
||||
device = model_management.get_torch_device()
|
||||
|
||||
# Check if image has alpha channel (4 channels)
|
||||
has_alpha = image.shape[-1] == 4
|
||||
|
||||
if has_alpha:
|
||||
# Split RGB and alpha channels
|
||||
rgb_image = image[..., :3]
|
||||
alpha_channel = image[..., 3:4]
|
||||
else:
|
||||
rgb_image = image
|
||||
alpha_channel = None
|
||||
|
||||
# Calculate memory requirements (based on RGB channels)
|
||||
memory_required = model_management.module_size(upscale_model.model)
|
||||
memory_required += (512 * 512 * 3) * rgb_image.element_size() * max(upscale_model.scale, 1.0) * 384.0
|
||||
memory_required += rgb_image.nelement() * rgb_image.element_size()
|
||||
|
||||
# Add memory for alpha channel processing if present
|
||||
if has_alpha:
|
||||
memory_required += alpha_channel.nelement() * alpha_channel.element_size() * max(upscale_model.scale, 1.0) ** 2
|
||||
|
||||
model_management.free_memory(memory_required, device)
|
||||
|
||||
upscale_model.to(device)
|
||||
|
||||
# Upscale RGB channels
|
||||
in_img = rgb_image.movedim(-1,-3).to(device)
|
||||
|
||||
tile = 512
|
||||
overlap = 32
|
||||
|
||||
oom = True
|
||||
while oom:
|
||||
try:
|
||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
upscaled_rgb = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||
oom = False
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
tile //= 2
|
||||
if tile < 128:
|
||||
raise e
|
||||
|
||||
upscale_model.to("cpu")
|
||||
upscaled_rgb = torch.clamp(upscaled_rgb.movedim(-3,-1), min=0, max=1.0)
|
||||
|
||||
# Handle alpha channel if present
|
||||
if has_alpha:
|
||||
# Upscale alpha channel using simple interpolation
|
||||
alpha_upscaled = torch.nn.functional.interpolate(
|
||||
alpha_channel.movedim(-1,-3).to(device),
|
||||
size=(upscaled_rgb.shape[1], upscaled_rgb.shape[2]),
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
).movedim(-3,-1).to(upscaled_rgb.device)
|
||||
|
||||
# Clamp alpha values to [0, 1]
|
||||
alpha_upscaled = torch.clamp(alpha_upscaled, min=0, max=1.0)
|
||||
|
||||
# Combine RGB and alpha channels
|
||||
result = torch.cat([upscaled_rgb, alpha_upscaled], dim=-1)
|
||||
else:
|
||||
result = upscaled_rgb
|
||||
|
||||
return (result,)
|
||||
Reference in New Issue
Block a user