mirror of
https://github.com/jags111/efficiency-nodes-comfyui.git
synced 2026-03-21 13:12:13 -03:00
360 lines
18 KiB
Python
360 lines
18 KiB
Python
# https://github.com/BlenderNeko/ComfyUI_TiledKSampler
|
|
import sys
|
|
import os
|
|
import itertools
|
|
import numpy as np
|
|
from tqdm.auto import tqdm
|
|
import torch
|
|
|
|
#sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
|
import comfy.sd
|
|
import comfy.controlnet
|
|
import comfy.model_management
|
|
import comfy.sample
|
|
from . import bnk_tiling as tiling
|
|
import latent_preview
|
|
#import torch
|
|
#import itertools
|
|
#import numpy as np
|
|
MAX_RESOLUTION=8192
|
|
|
|
#######################
|
|
|
|
def get_device():
|
|
return comfy.model_management.get_torch_device()
|
|
|
|
def recursion_to_list(obj, attr):
|
|
current = obj
|
|
yield current
|
|
while True:
|
|
current = getattr(current, attr, None)
|
|
if current is not None:
|
|
yield current
|
|
else:
|
|
return
|
|
|
|
def copy_cond(cond):
|
|
return [[c1,c2.copy()] for c1,c2 in cond]
|
|
|
|
def slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, cond, area, device):
|
|
tile_h_end = tile_h + tile_h_len
|
|
tile_w_end = tile_w + tile_w_len
|
|
coords = area[0] #h_len, w_len, h, w,
|
|
mask = area[1]
|
|
if coords is not None:
|
|
h_len, w_len, h, w = coords
|
|
h_end = h + h_len
|
|
w_end = w + w_len
|
|
if h < tile_h_end and h_end > tile_h and w < tile_w_end and w_end > tile_w:
|
|
new_h = max(0, h - tile_h)
|
|
new_w = max(0, w - tile_w)
|
|
new_h_end = min(tile_h_end, h_end - tile_h)
|
|
new_w_end = min(tile_w_end, w_end - tile_w)
|
|
cond[1]['area'] = (new_h_end - new_h, new_w_end - new_w, new_h, new_w)
|
|
else:
|
|
return (cond, True)
|
|
if mask is not None:
|
|
new_mask = tiling.get_slice(mask, tile_h,tile_h_len,tile_w,tile_w_len)
|
|
if new_mask.sum().to(device) == 0.0 and 'mask' in cond[1]:
|
|
return (cond, True)
|
|
else:
|
|
cond[1]['mask'] = new_mask
|
|
return (cond, False)
|
|
|
|
def slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen):
|
|
tile_h_end = tile_h + tile_h_len
|
|
tile_w_end = tile_w + tile_w_len
|
|
if gligen is None:
|
|
return
|
|
gligen_type = gligen[0]
|
|
gligen_model = gligen[1]
|
|
gligen_areas = gligen[2]
|
|
|
|
gligen_areas_new = []
|
|
for emb, h_len, w_len, h, w in gligen_areas:
|
|
h_end = h + h_len
|
|
w_end = w + w_len
|
|
if h < tile_h_end and h_end > tile_h and w < tile_w_end and w_end > tile_w:
|
|
new_h = max(0, h - tile_h)
|
|
new_w = max(0, w - tile_w)
|
|
new_h_end = min(tile_h_end, h_end - tile_h)
|
|
new_w_end = min(tile_w_end, w_end - tile_w)
|
|
gligen_areas_new.append((emb, new_h_end - new_h, new_w_end - new_w, new_h, new_w))
|
|
|
|
if len(gligen_areas_new) == 0:
|
|
del cond['gligen']
|
|
else:
|
|
cond['gligen'] = (gligen_type, gligen_model, gligen_areas_new)
|
|
|
|
def slice_cnet(h, h_len, w, w_len, model:comfy.controlnet.ControlBase, img):
|
|
if img is None:
|
|
img = model.cond_hint_original
|
|
model.cond_hint = tiling.get_slice(img, h*8, h_len*8, w*8, w_len*8).to(model.control_model.dtype).to(get_device())
|
|
|
|
def slices_T2I(h, h_len, w, w_len, model:comfy.controlnet.ControlBase, img):
|
|
model.control_input = None
|
|
if img is None:
|
|
img = model.cond_hint_original
|
|
model.cond_hint = tiling.get_slice(img, h*8, h_len*8, w*8, w_len*8).float().to(get_device())
|
|
|
|
# TODO: refactor some of the mess
|
|
|
|
from PIL import Image
|
|
|
|
def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_strategy, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, preview=False):
|
|
end_at_step = min(end_at_step, steps)
|
|
device = get_device()
|
|
samples = latent_image["samples"]
|
|
noise_mask = latent_image["noise_mask"] if "noise_mask" in latent_image else None
|
|
force_full_denoise = return_with_leftover_noise == "enable"
|
|
if add_noise == "disable":
|
|
noise = torch.zeros(samples.size(), dtype=samples.dtype, layout=samples.layout, device="cpu")
|
|
else:
|
|
skip = latent_image["batch_index"] if "batch_index" in latent_image else None
|
|
noise = comfy.sample.prepare_noise(samples, noise_seed, skip)
|
|
|
|
if noise_mask is not None:
|
|
noise_mask = comfy.sample.prepare_mask(noise_mask, noise.shape, device='cpu')
|
|
|
|
shape = samples.shape
|
|
samples = samples.clone()
|
|
|
|
tile_width = min(shape[-1] * 8, tile_width)
|
|
tile_height = min(shape[2] * 8, tile_height)
|
|
|
|
# real_model = None
|
|
# positive_copy = comfy.sample.convert_cond(positive)
|
|
# negative_copy = comfy.sample.convert_cond(negative)
|
|
#change made due to changes in omfy.sampler's 'convert cond' has been removed.. check if any error persists.
|
|
conds0 = \
|
|
{"positive": comfy.sampler_helpers.convert_cond(positive),
|
|
"negative": comfy.sampler_helpers.convert_cond(negative)}
|
|
|
|
conds = {}
|
|
for k in conds0:
|
|
conds[k] = list(map(lambda a: a.copy(), conds0[k]))
|
|
|
|
modelPatches, inference_memory = comfy.sampler_helpers.get_additional_models(conds, model.model_dtype())
|
|
comfy.model_management.load_models_gpu([model] + modelPatches, model.memory_required(noise.shape) + inference_memory)
|
|
real_model = model.model
|
|
#smaller changes
|
|
model_patcher = comfy.model_patcher.ModelPatcher(model.model, load_device=device, offload_device=comfy.model_management.unet_offload_device())
|
|
sampler = comfy.samplers.KSampler(model_patcher, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
|
#sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
|
|
|
if tiling_strategy != 'padded':
|
|
if noise_mask is not None:
|
|
samples += sampler.sigmas[start_at_step].cpu() * noise_mask * model.model.process_latent_out(noise)
|
|
else:
|
|
samples += sampler.sigmas[start_at_step].cpu() * model.model.process_latent_out(noise)
|
|
|
|
#cnets
|
|
cnets = [c['control'] for (_, c) in positive + negative if 'control' in c and isinstance(c['control'], comfy.controlnet.ControlNet)]
|
|
cnets = list(set([x for m in cnets for x in recursion_to_list(m, "previous_controlnet")]))
|
|
cnet_imgs = [
|
|
torch.nn.functional.interpolate(m.cond_hint_original, (shape[-2] * 8, shape[-1] * 8), mode='nearest-exact').to('cpu')
|
|
if m.cond_hint_original.shape[-2] != shape[-2] * 8 or m.cond_hint_original.shape[-1] != shape[-1] * 8 else None
|
|
for m in cnets]
|
|
|
|
#T2I
|
|
T2Is = [c['control'] for (_, c) in positive + negative if 'control' in c and isinstance(c['control'], comfy.controlnet.T2IAdapter)]
|
|
T2Is = [x for m in T2Is for x in recursion_to_list(m, "previous_controlnet")]
|
|
T2I_imgs = [
|
|
torch.nn.functional.interpolate(m.cond_hint_original, (shape[-2] * 8, shape[-1] * 8), mode='nearest-exact').to('cpu')
|
|
if m.cond_hint_original.shape[-2] != shape[-2] * 8 or m.cond_hint_original.shape[-1] != shape[-1] * 8 or (m.channels_in == 1 and m.cond_hint_original.shape[1] != 1) else None
|
|
for m in T2Is
|
|
]
|
|
T2I_imgs = [
|
|
torch.mean(img, 1, keepdim=True) if img is not None and m.channels_in == 1 and m.cond_hint_original.shape[1] else img
|
|
for m, img in zip(T2Is, T2I_imgs)
|
|
]
|
|
|
|
#cond area and mask
|
|
spatial_conds_pos = [
|
|
(c[1]['area'] if 'area' in c[1] else None,
|
|
comfy.sample.prepare_mask(c[1]['mask'], shape, device) if 'mask' in c[1] else None)
|
|
for c in positive
|
|
]
|
|
spatial_conds_neg = [
|
|
(c[1]['area'] if 'area' in c[1] else None,
|
|
comfy.sample.prepare_mask(c[1]['mask'], shape, device) if 'mask' in c[1] else None)
|
|
for c in negative
|
|
]
|
|
|
|
#gligen
|
|
gligen_pos = [
|
|
c[1]['gligen'] if 'gligen' in c[1] else None
|
|
for c in positive
|
|
]
|
|
gligen_neg = [
|
|
c[1]['gligen'] if 'gligen' in c[1] else None
|
|
for c in negative
|
|
]
|
|
|
|
|
|
gen = torch.manual_seed(noise_seed)
|
|
if tiling_strategy == 'random' or tiling_strategy == 'random strict':
|
|
tiles = tiling.get_tiles_and_masks_rgrid(end_at_step - start_at_step, samples.shape, tile_height, tile_width, gen)
|
|
elif tiling_strategy == 'padded':
|
|
tiles = tiling.get_tiles_and_masks_padded(end_at_step - start_at_step, samples.shape, tile_height, tile_width)
|
|
else:
|
|
tiles = tiling.get_tiles_and_masks_simple(end_at_step - start_at_step, samples.shape, tile_height, tile_width)
|
|
|
|
total_steps = sum([num_steps for img_pass in tiles for steps_list in img_pass for _,_,_,_,num_steps,_ in steps_list])
|
|
current_step = [0]
|
|
|
|
preview_format = "JPEG"
|
|
if preview_format not in ["JPEG", "PNG"]:
|
|
preview_format = "JPEG"
|
|
previewer = None
|
|
if preview:
|
|
previewer = latent_preview.get_previewer(device, model.model.latent_format)
|
|
|
|
|
|
with tqdm(total=total_steps) as pbar_tqdm:
|
|
pbar = comfy.utils.ProgressBar(total_steps)
|
|
|
|
def callback(step, x0, x, total_steps):
|
|
current_step[0] += 1
|
|
preview_bytes = None
|
|
if previewer:
|
|
preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
|
|
pbar.update_absolute(current_step[0], preview=preview_bytes)
|
|
pbar_tqdm.update(1)
|
|
|
|
if tiling_strategy == "random strict":
|
|
samples_next = samples.clone()
|
|
for img_pass in tiles:
|
|
for i in range(len(img_pass)):
|
|
for tile_h, tile_h_len, tile_w, tile_w_len, tile_steps, tile_mask in img_pass[i]:
|
|
tiled_mask = None
|
|
if noise_mask is not None:
|
|
tiled_mask = tiling.get_slice(noise_mask, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
|
|
if tile_mask is not None:
|
|
if tiled_mask is not None:
|
|
tiled_mask *= tile_mask.to(device)
|
|
else:
|
|
tiled_mask = tile_mask.to(device)
|
|
|
|
if tiling_strategy == 'padded' or tiling_strategy == 'random strict':
|
|
tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask = tiling.mask_at_boundary( tile_h, tile_h_len, tile_w, tile_w_len,
|
|
tile_height, tile_width, samples.shape[-2], samples.shape[-1],
|
|
tiled_mask, device)
|
|
|
|
|
|
if tiled_mask is not None and tiled_mask.sum().cpu() == 0.0:
|
|
continue
|
|
|
|
tiled_latent = tiling.get_slice(samples, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
|
|
|
|
if tiling_strategy == 'padded':
|
|
tiled_noise = tiling.get_slice(noise, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
|
|
else:
|
|
if tiled_mask is None or noise_mask is None:
|
|
tiled_noise = torch.zeros_like(tiled_latent)
|
|
else:
|
|
tiled_noise = tiling.get_slice(noise, tile_h, tile_h_len, tile_w, tile_w_len).to(device) * (1 - tiled_mask)
|
|
|
|
#TODO: all other condition based stuff like area sets and GLIGEN should also happen here
|
|
|
|
#cnets
|
|
for m, img in zip(cnets, cnet_imgs):
|
|
slice_cnet(tile_h, tile_h_len, tile_w, tile_w_len, m, img)
|
|
|
|
#T2I
|
|
for m, img in zip(T2Is, T2I_imgs):
|
|
slices_T2I(tile_h, tile_h_len, tile_w, tile_w_len, m, img)
|
|
|
|
pos = [c.copy() for c in positive]#copy_cond(positive_copy)
|
|
neg = [c.copy() for c in negative]#copy_cond(negative_copy)
|
|
|
|
#cond areas
|
|
pos = [slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, c, area, device) for c, area in zip(pos, spatial_conds_pos)]
|
|
pos = [c for c, ignore in pos if not ignore]
|
|
neg = [slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, c, area, device) for c, area in zip(neg, spatial_conds_neg)]
|
|
neg = [c for c, ignore in neg if not ignore]
|
|
|
|
#gligen
|
|
for cond, gligen in zip(pos, gligen_pos):
|
|
slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen)
|
|
for cond, gligen in zip(neg, gligen_neg):
|
|
slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen)
|
|
|
|
tile_result = sampler.sample(tiled_noise, pos, neg, cfg=cfg, latent_image=tiled_latent, start_step=start_at_step + i * tile_steps, last_step=start_at_step + i*tile_steps + tile_steps, force_full_denoise=force_full_denoise and i+1 == end_at_step - start_at_step, denoise_mask=tiled_mask, callback=callback, disable_pbar=True, seed=noise_seed)
|
|
tile_result = tile_result.cpu()
|
|
if tiled_mask is not None:
|
|
tiled_mask = tiled_mask.cpu()
|
|
if tiling_strategy == "random strict":
|
|
tiling.set_slice(samples_next, tile_result, tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask)
|
|
else:
|
|
tiling.set_slice(samples, tile_result, tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask)
|
|
if tiling_strategy == "random strict":
|
|
samples = samples_next.clone()
|
|
|
|
|
|
comfy.sampler_helpers.cleanup_additional_models(modelPatches)
|
|
|
|
out = latent_image.copy()
|
|
out["samples"] = samples.cpu()
|
|
return (out, )
|
|
|
|
|
|
class TiledKSampler:
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {"required":
|
|
{"model": ("MODEL",),
|
|
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
|
"tile_width": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}),
|
|
"tile_height": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}),
|
|
"tiling_strategy": (["random", "random strict", "padded", 'simple'], ),
|
|
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
|
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
|
|
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
|
|
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
|
|
"positive": ("CONDITIONING", ),
|
|
"negative": ("CONDITIONING", ),
|
|
"latent_image": ("LATENT", ),
|
|
"denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
|
}}
|
|
|
|
RETURN_TYPES = ("LATENT",)
|
|
FUNCTION = "sample"
|
|
|
|
CATEGORY = "sampling"
|
|
|
|
def sample(self, model, seed, tile_width, tile_height, tiling_strategy, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise):
|
|
steps_total = int(steps / denoise)
|
|
return sample_common(model, 'enable', seed, tile_width, tile_height, tiling_strategy, steps_total, cfg, sampler_name, scheduler, positive, negative, latent_image, steps_total-steps, steps_total, 'disable', denoise=1.0, preview=True)
|
|
|
|
class TiledKSamplerAdvanced:
|
|
@classmethod
|
|
def INPUT_TYPES(s):
|
|
return {"required":
|
|
{"model": ("MODEL",),
|
|
"add_noise": (["enable", "disable"], ),
|
|
"noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
|
|
"tile_width": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}),
|
|
"tile_height": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}),
|
|
"tiling_strategy": (["random", "random strict", "padded", 'simple'], ),
|
|
"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
|
|
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
|
|
"sampler_name": (comfy.samplers.KSampler.SAMPLERS, ),
|
|
"scheduler": (comfy.samplers.KSampler.SCHEDULERS, ),
|
|
"positive": ("CONDITIONING", ),
|
|
"negative": ("CONDITIONING", ),
|
|
"latent_image": ("LATENT", ),
|
|
"start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
|
|
"end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
|
|
"return_with_leftover_noise": (["disable", "enable"], ),
|
|
"preview": (["disable", "enable"], ),
|
|
}}
|
|
|
|
RETURN_TYPES = ("LATENT",)
|
|
FUNCTION = "sample"
|
|
|
|
CATEGORY = "sampling"
|
|
|
|
def sample(self, model, add_noise, noise_seed, tile_width, tile_height, tiling_strategy, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, preview, denoise=1.0):
|
|
return sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_strategy, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, preview= preview == 'enable')
|