mirror of
https://github.com/jags111/efficiency-nodes-comfyui.git
synced 2026-03-24 22:52:13 -03:00
Update bnk_tiled_samplers.py
many changes in params and conditions
This commit is contained in:
committed by
GitHub
parent
05db003a26
commit
0e56eb34a0
@@ -11,190 +11,13 @@ import comfy.sd
|
|||||||
import comfy.controlnet
|
import comfy.controlnet
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.sample
|
import comfy.sample
|
||||||
#from . import tiling
|
from . import bnk_tiling as tiling
|
||||||
import latent_preview
|
import latent_preview
|
||||||
#import torch
|
#import torch
|
||||||
#import itertools
|
#import itertools
|
||||||
#import numpy as np
|
#import numpy as np
|
||||||
MAX_RESOLUTION=8192
|
MAX_RESOLUTION=8192
|
||||||
|
|
||||||
def grouper(n, iterable):
|
|
||||||
it = iter(iterable)
|
|
||||||
while True:
|
|
||||||
chunk = list(itertools.islice(it, n))
|
|
||||||
if not chunk:
|
|
||||||
return
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
|
|
||||||
def create_batches(n, iterable):
|
|
||||||
groups = itertools.groupby(iterable, key=lambda x: (x[1], x[3]))
|
|
||||||
for _, x in groups:
|
|
||||||
for y in grouper(n, x):
|
|
||||||
yield y
|
|
||||||
|
|
||||||
|
|
||||||
def get_slice(tensor, h, h_len, w, w_len):
|
|
||||||
t = tensor.narrow(-2, h, h_len)
|
|
||||||
t = t.narrow(-1, w, w_len)
|
|
||||||
return t
|
|
||||||
|
|
||||||
|
|
||||||
def set_slice(tensor1, tensor2, h, h_len, w, w_len, mask=None):
|
|
||||||
if mask is not None:
|
|
||||||
tensor1[:, :, h:h + h_len, w:w + w_len] = tensor1[:, :, h:h + h_len, w:w + w_len] * (1 - mask) + tensor2 * mask
|
|
||||||
else:
|
|
||||||
tensor1[:, :, h:h + h_len, w:w + w_len] = tensor2
|
|
||||||
|
|
||||||
|
|
||||||
def get_tiles_and_masks_simple(steps, latent_shape, tile_height, tile_width):
|
|
||||||
latent_size_h = latent_shape[-2]
|
|
||||||
latent_size_w = latent_shape[-1]
|
|
||||||
tile_size_h = int(tile_height // 8)
|
|
||||||
tile_size_w = int(tile_width // 8)
|
|
||||||
|
|
||||||
h = np.arange(0, latent_size_h, tile_size_h)
|
|
||||||
w = np.arange(0, latent_size_w, tile_size_w)
|
|
||||||
|
|
||||||
def create_tile(hs, ws, i, j):
|
|
||||||
h = int(hs[i])
|
|
||||||
w = int(ws[j])
|
|
||||||
h_len = min(tile_size_h, latent_size_h - h)
|
|
||||||
w_len = min(tile_size_w, latent_size_w - w)
|
|
||||||
return (h, h_len, w, w_len, steps, None)
|
|
||||||
|
|
||||||
passes = [
|
|
||||||
[[create_tile(h, w, i, j) for i in range(len(h)) for j in range(len(w))]],
|
|
||||||
]
|
|
||||||
return passes
|
|
||||||
|
|
||||||
|
|
||||||
def get_tiles_and_masks_padded(steps, latent_shape, tile_height, tile_width):
|
|
||||||
batch_size = latent_shape[0]
|
|
||||||
latent_size_h = latent_shape[-2]
|
|
||||||
latent_size_w = latent_shape[-1]
|
|
||||||
|
|
||||||
tile_size_h = int(tile_height // 8)
|
|
||||||
tile_size_h = int((tile_size_h // 4) * 4)
|
|
||||||
tile_size_w = int(tile_width // 8)
|
|
||||||
tile_size_w = int((tile_size_w // 4) * 4)
|
|
||||||
|
|
||||||
# masks
|
|
||||||
mask_h = [0, tile_size_h // 4, tile_size_h - tile_size_h // 4, tile_size_h]
|
|
||||||
mask_w = [0, tile_size_w // 4, tile_size_w - tile_size_w // 4, tile_size_w]
|
|
||||||
masks = [[] for _ in range(3)]
|
|
||||||
for i in range(3):
|
|
||||||
for j in range(3):
|
|
||||||
mask = torch.zeros((batch_size, 1, tile_size_h, tile_size_w), dtype=torch.float32, device='cpu')
|
|
||||||
mask[:, :, mask_h[i]:mask_h[i + 1], mask_w[j]:mask_w[j + 1]] = 1.0
|
|
||||||
masks[i].append(mask)
|
|
||||||
|
|
||||||
def create_mask(h_ind, w_ind, h_ind_max, w_ind_max, mask_h, mask_w, h_len, w_len):
|
|
||||||
mask = masks[1][1]
|
|
||||||
if not (h_ind == 0 or h_ind == h_ind_max or w_ind == 0 or w_ind == w_ind_max):
|
|
||||||
return get_slice(mask, 0, h_len, 0, w_len)
|
|
||||||
mask = mask.clone()
|
|
||||||
if h_ind == 0 and mask_h:
|
|
||||||
mask += masks[0][1]
|
|
||||||
if h_ind == h_ind_max and mask_h:
|
|
||||||
mask += masks[2][1]
|
|
||||||
if w_ind == 0 and mask_w:
|
|
||||||
mask += masks[1][0]
|
|
||||||
if w_ind == w_ind_max and mask_w:
|
|
||||||
mask += masks[1][2]
|
|
||||||
if h_ind == 0 and w_ind == 0 and mask_h and mask_w:
|
|
||||||
mask += masks[0][0]
|
|
||||||
if h_ind == 0 and w_ind == w_ind_max and mask_h and mask_w:
|
|
||||||
mask += masks[0][2]
|
|
||||||
if h_ind == h_ind_max and w_ind == 0 and mask_h and mask_w:
|
|
||||||
mask += masks[2][0]
|
|
||||||
if h_ind == h_ind_max and w_ind == w_ind_max and mask_h and mask_w:
|
|
||||||
mask += masks[2][2]
|
|
||||||
return get_slice(mask, 0, h_len, 0, w_len)
|
|
||||||
|
|
||||||
h = np.arange(0, latent_size_h, tile_size_h)
|
|
||||||
h_shift = np.arange(tile_size_h // 2, latent_size_h - tile_size_h // 2, tile_size_h)
|
|
||||||
w = np.arange(0, latent_size_w, tile_size_w)
|
|
||||||
w_shift = np.arange(tile_size_w // 2, latent_size_w - tile_size_h // 2, tile_size_w)
|
|
||||||
|
|
||||||
def create_tile(hs, ws, mask_h, mask_w, i, j):
|
|
||||||
h = int(hs[i])
|
|
||||||
w = int(ws[j])
|
|
||||||
h_len = min(tile_size_h, latent_size_h - h)
|
|
||||||
w_len = min(tile_size_w, latent_size_w - w)
|
|
||||||
mask = create_mask(i, j, len(hs) - 1, len(ws) - 1, mask_h, mask_w, h_len, w_len)
|
|
||||||
return (h, h_len, w, w_len, steps, mask)
|
|
||||||
|
|
||||||
passes = [
|
|
||||||
[[create_tile(h, w, True, True, i, j) for i in range(len(h)) for j in range(len(w))]],
|
|
||||||
[[create_tile(h_shift, w, False, True, i, j) for i in range(len(h_shift)) for j in range(len(w))]],
|
|
||||||
[[create_tile(h, w_shift, True, False, i, j) for i in range(len(h)) for j in range(len(w_shift))]],
|
|
||||||
[[create_tile(h_shift, w_shift, False, False, i, j) for i in range(len(h_shift)) for j in range(len(w_shift))]],
|
|
||||||
]
|
|
||||||
|
|
||||||
return passes
|
|
||||||
|
|
||||||
|
|
||||||
def mask_at_boundary(h, h_len, w, w_len, tile_size_h, tile_size_w, latent_size_h, latent_size_w, mask, device='cpu'):
|
|
||||||
tile_size_h = int(tile_size_h // 8)
|
|
||||||
tile_size_w = int(tile_size_w // 8)
|
|
||||||
|
|
||||||
if (h_len == tile_size_h or h_len == latent_size_h) and (w_len == tile_size_w or w_len == latent_size_w):
|
|
||||||
return h, h_len, w, w_len, mask
|
|
||||||
h_offset = min(0, latent_size_h - (h + tile_size_h))
|
|
||||||
w_offset = min(0, latent_size_w - (w + tile_size_w))
|
|
||||||
new_mask = torch.zeros((1, 1, tile_size_h, tile_size_w), dtype=torch.float32, device=device)
|
|
||||||
new_mask[:, :, -h_offset:h_len if h_offset == 0 else tile_size_h,
|
|
||||||
-w_offset:w_len if w_offset == 0 else tile_size_w] = 1.0 if mask is None else mask
|
|
||||||
return h + h_offset, tile_size_h, w + w_offset, tile_size_w, new_mask
|
|
||||||
|
|
||||||
|
|
||||||
def get_tiles_and_masks_rgrid(steps, latent_shape, tile_height, tile_width, generator):
|
|
||||||
def calc_coords(latent_size, tile_size, jitter):
|
|
||||||
tile_coords = int((latent_size + jitter - 1) // tile_size + 1)
|
|
||||||
tile_coords = [np.clip(tile_size * c - jitter, 0, latent_size) for c in range(tile_coords + 1)]
|
|
||||||
tile_coords = [(c1, c2 - c1) for c1, c2 in zip(tile_coords, tile_coords[1:])]
|
|
||||||
return tile_coords
|
|
||||||
|
|
||||||
# calc stuff
|
|
||||||
batch_size = latent_shape[0]
|
|
||||||
latent_size_h = latent_shape[-2]
|
|
||||||
latent_size_w = latent_shape[-1]
|
|
||||||
tile_size_h = int(tile_height // 8)
|
|
||||||
tile_size_w = int(tile_width // 8)
|
|
||||||
|
|
||||||
tiles_all = []
|
|
||||||
|
|
||||||
for s in range(steps):
|
|
||||||
rands = torch.rand((2,), dtype=torch.float32, generator=generator, device='cpu').numpy()
|
|
||||||
|
|
||||||
jitter_w1 = int(rands[0] * tile_size_w)
|
|
||||||
jitter_w2 = int(((rands[0] + .5) % 1.0) * tile_size_w)
|
|
||||||
jitter_h1 = int(rands[1] * tile_size_h)
|
|
||||||
jitter_h2 = int(((rands[1] + .5) % 1.0) * tile_size_h)
|
|
||||||
|
|
||||||
# calc number of tiles
|
|
||||||
tiles_h = [
|
|
||||||
calc_coords(latent_size_h, tile_size_h, jitter_h1),
|
|
||||||
calc_coords(latent_size_h, tile_size_h, jitter_h2)
|
|
||||||
]
|
|
||||||
tiles_w = [
|
|
||||||
calc_coords(latent_size_w, tile_size_w, jitter_w1),
|
|
||||||
calc_coords(latent_size_w, tile_size_w, jitter_w2)
|
|
||||||
]
|
|
||||||
|
|
||||||
tiles = []
|
|
||||||
if s % 2 == 0:
|
|
||||||
for i, h in enumerate(tiles_h[0]):
|
|
||||||
for w in tiles_w[i % 2]:
|
|
||||||
tiles.append((int(h[0]), int(h[1]), int(w[0]), int(w[1]), 1, None))
|
|
||||||
else:
|
|
||||||
for i, w in enumerate(tiles_w[0]):
|
|
||||||
for h in tiles_h[i % 2]:
|
|
||||||
tiles.append((int(h[0]), int(h[1]), int(w[0]), int(w[1]), 1, None))
|
|
||||||
tiles_all.append(tiles)
|
|
||||||
return [tiles_all]
|
|
||||||
|
|
||||||
#######################
|
#######################
|
||||||
|
|
||||||
def recursion_to_list(obj, attr):
|
def recursion_to_list(obj, attr):
|
||||||
@@ -210,7 +33,7 @@ def recursion_to_list(obj, attr):
|
|||||||
def copy_cond(cond):
|
def copy_cond(cond):
|
||||||
return [[c1,c2.copy()] for c1,c2 in cond]
|
return [[c1,c2.copy()] for c1,c2 in cond]
|
||||||
|
|
||||||
def slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, cond, area):
|
def slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, cond, area, device):
|
||||||
tile_h_end = tile_h + tile_h_len
|
tile_h_end = tile_h + tile_h_len
|
||||||
tile_w_end = tile_w + tile_w_len
|
tile_w_end = tile_w + tile_w_len
|
||||||
coords = area[0] #h_len, w_len, h, w,
|
coords = area[0] #h_len, w_len, h, w,
|
||||||
@@ -228,8 +51,8 @@ def slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, cond, area):
|
|||||||
else:
|
else:
|
||||||
return (cond, True)
|
return (cond, True)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
new_mask = get_slice(mask, tile_h,tile_h_len,tile_w,tile_w_len)
|
new_mask = tiling.get_slice(mask, tile_h,tile_h_len,tile_w,tile_w_len)
|
||||||
if new_mask.sum().cpu() == 0.0 and 'mask' in cond[1]:
|
if new_mask.sum().to(device) == 0.0 and 'mask' in cond[1]:
|
||||||
return (cond, True)
|
return (cond, True)
|
||||||
else:
|
else:
|
||||||
cond[1]['mask'] = new_mask
|
cond[1]['mask'] = new_mask
|
||||||
@@ -263,13 +86,13 @@ def slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen):
|
|||||||
def slice_cnet(h, h_len, w, w_len, model:comfy.controlnet.ControlBase, img):
|
def slice_cnet(h, h_len, w, w_len, model:comfy.controlnet.ControlBase, img):
|
||||||
if img is None:
|
if img is None:
|
||||||
img = model.cond_hint_original
|
img = model.cond_hint_original
|
||||||
model.cond_hint = get_slice(img, h*8, h_len*8, w*8, w_len*8).to(model.control_model.dtype).to(model.device)
|
model.cond_hint = tiling.get_slice(img, h*8, h_len*8, w*8, w_len*8).to(model.control_model.dtype).to(model.device)
|
||||||
|
|
||||||
def slices_T2I(h, h_len, w, w_len, model:comfy.controlnet.ControlBase, img):
|
def slices_T2I(h, h_len, w, w_len, model:comfy.controlnet.ControlBase, img):
|
||||||
model.control_input = None
|
model.control_input = None
|
||||||
if img is None:
|
if img is None:
|
||||||
img = model.cond_hint_original
|
img = model.cond_hint_original
|
||||||
model.cond_hint = get_slice(img, h*8, h_len*8, w*8, w_len*8).float().to(model.device)
|
model.cond_hint = tiling.get_slice(img, h*8, h_len*8, w*8, w_len*8).float().to(model.device)
|
||||||
|
|
||||||
# TODO: refactor some of the mess
|
# TODO: refactor some of the mess
|
||||||
|
|
||||||
@@ -297,7 +120,9 @@ def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_
|
|||||||
tile_height = min(shape[2] * 8, tile_height)
|
tile_height = min(shape[2] * 8, tile_height)
|
||||||
|
|
||||||
real_model = None
|
real_model = None
|
||||||
modelPatches, inference_memory = comfy.sample.get_additional_models(positive, negative, model.model_dtype())
|
positive_copy = comfy.sample.convert_cond(positive)
|
||||||
|
negative_copy = comfy.sample.convert_cond(negative)
|
||||||
|
modelPatches, inference_memory = comfy.sample.get_additional_models(positive_copy, negative_copy, model.model_dtype())
|
||||||
comfy.model_management.load_models_gpu([model] + modelPatches, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)
|
comfy.model_management.load_models_gpu([model] + modelPatches, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)
|
||||||
real_model = model.model
|
real_model = model.model
|
||||||
|
|
||||||
@@ -305,13 +130,12 @@ def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_
|
|||||||
|
|
||||||
if tiling_strategy != 'padded':
|
if tiling_strategy != 'padded':
|
||||||
if noise_mask is not None:
|
if noise_mask is not None:
|
||||||
samples += sampler.sigmas[start_at_step].cpu() * noise_mask * model.model.process_latent_out(noise).cpu()
|
samples += sampler.sigmas[start_at_step].cpu() * noise_mask * model.model.process_latent_out(noise)
|
||||||
else:
|
else:
|
||||||
samples += sampler.sigmas[start_at_step].cpu() * model.model.process_latent_out(noise).cpu()
|
samples += sampler.sigmas[start_at_step].cpu() * model.model.process_latent_out(noise)
|
||||||
|
|
||||||
#cnets
|
#cnets
|
||||||
cnets = comfy.sample.get_models_from_cond(positive, 'control') + comfy.sample.get_models_from_cond(negative, 'control')
|
cnets = [c['control'] for (_, c) in positive + negative if 'control' in c and isinstance(c['control'], comfy.controlnet.ControlNet)]
|
||||||
cnets = [m for m in cnets if isinstance(m, comfy.controlnet.ControlNet)]
|
|
||||||
cnets = list(set([x for m in cnets for x in recursion_to_list(m, "previous_controlnet")]))
|
cnets = list(set([x for m in cnets for x in recursion_to_list(m, "previous_controlnet")]))
|
||||||
cnet_imgs = [
|
cnet_imgs = [
|
||||||
torch.nn.functional.interpolate(m.cond_hint_original, (shape[-2] * 8, shape[-1] * 8), mode='nearest-exact').to('cpu')
|
torch.nn.functional.interpolate(m.cond_hint_original, (shape[-2] * 8, shape[-1] * 8), mode='nearest-exact').to('cpu')
|
||||||
@@ -319,8 +143,7 @@ def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_
|
|||||||
for m in cnets]
|
for m in cnets]
|
||||||
|
|
||||||
#T2I
|
#T2I
|
||||||
T2Is = comfy.sample.get_models_from_cond(positive, 'control') + comfy.sample.get_models_from_cond(negative, 'control')
|
T2Is = [c['control'] for (_, c) in positive + negative if 'control' in c and isinstance(c['control'], comfy.controlnet.T2IAdapter)]
|
||||||
T2Is = [m for m in T2Is if isinstance(m, comfy.controlnet.T2IAdapter)]
|
|
||||||
T2Is = [x for m in T2Is for x in recursion_to_list(m, "previous_controlnet")]
|
T2Is = [x for m in T2Is for x in recursion_to_list(m, "previous_controlnet")]
|
||||||
T2I_imgs = [
|
T2I_imgs = [
|
||||||
torch.nn.functional.interpolate(m.cond_hint_original, (shape[-2] * 8, shape[-1] * 8), mode='nearest-exact').to('cpu')
|
torch.nn.functional.interpolate(m.cond_hint_original, (shape[-2] * 8, shape[-1] * 8), mode='nearest-exact').to('cpu')
|
||||||
@@ -354,16 +177,14 @@ def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_
|
|||||||
for c in negative
|
for c in negative
|
||||||
]
|
]
|
||||||
|
|
||||||
positive_copy = comfy.sample.broadcast_cond(positive, shape[0], device)
|
|
||||||
negative_copy = comfy.sample.broadcast_cond(negative, shape[0], device)
|
|
||||||
|
|
||||||
gen = torch.manual_seed(noise_seed)
|
gen = torch.manual_seed(noise_seed)
|
||||||
if tiling_strategy == 'random' or tiling_strategy == 'random strict':
|
if tiling_strategy == 'random' or tiling_strategy == 'random strict':
|
||||||
tiles = get_tiles_and_masks_rgrid(end_at_step - start_at_step, samples.shape, tile_height, tile_width, gen)
|
tiles = tiling.get_tiles_and_masks_rgrid(end_at_step - start_at_step, samples.shape, tile_height, tile_width, gen)
|
||||||
elif tiling_strategy == 'padded':
|
elif tiling_strategy == 'padded':
|
||||||
tiles = get_tiles_and_masks_padded(end_at_step - start_at_step, samples.shape, tile_height, tile_width)
|
tiles = tiling.get_tiles_and_masks_padded(end_at_step - start_at_step, samples.shape, tile_height, tile_width)
|
||||||
else:
|
else:
|
||||||
tiles = get_tiles_and_masks_simple(end_at_step - start_at_step, samples.shape, tile_height, tile_width)
|
tiles = tiling.get_tiles_and_masks_simple(end_at_step - start_at_step, samples.shape, tile_height, tile_width)
|
||||||
|
|
||||||
total_steps = sum([num_steps for img_pass in tiles for steps_list in img_pass for _,_,_,_,num_steps,_ in steps_list])
|
total_steps = sum([num_steps for img_pass in tiles for steps_list in img_pass for _,_,_,_,num_steps,_ in steps_list])
|
||||||
current_step = [0]
|
current_step = [0]
|
||||||
@@ -394,7 +215,7 @@ def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_
|
|||||||
for tile_h, tile_h_len, tile_w, tile_w_len, tile_steps, tile_mask in img_pass[i]:
|
for tile_h, tile_h_len, tile_w, tile_w_len, tile_steps, tile_mask in img_pass[i]:
|
||||||
tiled_mask = None
|
tiled_mask = None
|
||||||
if noise_mask is not None:
|
if noise_mask is not None:
|
||||||
tiled_mask = get_slice(noise_mask, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
|
tiled_mask = tiling.get_slice(noise_mask, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
|
||||||
if tile_mask is not None:
|
if tile_mask is not None:
|
||||||
if tiled_mask is not None:
|
if tiled_mask is not None:
|
||||||
tiled_mask *= tile_mask.to(device)
|
tiled_mask *= tile_mask.to(device)
|
||||||
@@ -402,7 +223,7 @@ def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_
|
|||||||
tiled_mask = tile_mask.to(device)
|
tiled_mask = tile_mask.to(device)
|
||||||
|
|
||||||
if tiling_strategy == 'padded' or tiling_strategy == 'random strict':
|
if tiling_strategy == 'padded' or tiling_strategy == 'random strict':
|
||||||
tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask = mask_at_boundary( tile_h, tile_h_len, tile_w, tile_w_len,
|
tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask = tiling.mask_at_boundary( tile_h, tile_h_len, tile_w, tile_w_len,
|
||||||
tile_height, tile_width, samples.shape[-2], samples.shape[-1],
|
tile_height, tile_width, samples.shape[-2], samples.shape[-1],
|
||||||
tiled_mask, device)
|
tiled_mask, device)
|
||||||
|
|
||||||
@@ -410,15 +231,15 @@ def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_
|
|||||||
if tiled_mask is not None and tiled_mask.sum().cpu() == 0.0:
|
if tiled_mask is not None and tiled_mask.sum().cpu() == 0.0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tiled_latent = get_slice(samples, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
|
tiled_latent = tiling.get_slice(samples, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
|
||||||
|
|
||||||
if tiling_strategy == 'padded':
|
if tiling_strategy == 'padded':
|
||||||
tiled_noise = get_slice(noise, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
|
tiled_noise = tiling.get_slice(noise, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
|
||||||
else:
|
else:
|
||||||
if tiled_mask is None or noise_mask is None:
|
if tiled_mask is None or noise_mask is None:
|
||||||
tiled_noise = torch.zeros_like(tiled_latent)
|
tiled_noise = torch.zeros_like(tiled_latent)
|
||||||
else:
|
else:
|
||||||
tiled_noise = get_slice(noise, tile_h, tile_h_len, tile_w, tile_w_len).to(device) * (1 - tiled_mask)
|
tiled_noise = tiling.get_slice(noise, tile_h, tile_h_len, tile_w, tile_w_len).to(device) * (1 - tiled_mask)
|
||||||
|
|
||||||
#TODO: all other condition based stuff like area sets and GLIGEN should also happen here
|
#TODO: all other condition based stuff like area sets and GLIGEN should also happen here
|
||||||
|
|
||||||
@@ -430,19 +251,19 @@ def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_
|
|||||||
for m, img in zip(T2Is, T2I_imgs):
|
for m, img in zip(T2Is, T2I_imgs):
|
||||||
slices_T2I(tile_h, tile_h_len, tile_w, tile_w_len, m, img)
|
slices_T2I(tile_h, tile_h_len, tile_w, tile_w_len, m, img)
|
||||||
|
|
||||||
pos = copy_cond(positive_copy)
|
pos = [c.copy() for c in positive_copy]#copy_cond(positive_copy)
|
||||||
neg = copy_cond(negative_copy)
|
neg = [c.copy() for c in negative_copy]#copy_cond(negative_copy)
|
||||||
|
|
||||||
#cond areas
|
#cond areas
|
||||||
pos = [slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, c, area) for c, area in zip(pos, spatial_conds_pos)]
|
pos = [slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, c, area, device) for c, area in zip(pos, spatial_conds_pos)]
|
||||||
pos = [c for c, ignore in pos if not ignore]
|
pos = [c for c, ignore in pos if not ignore]
|
||||||
neg = [slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, c, area) for c, area in zip(neg, spatial_conds_neg)]
|
neg = [slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, c, area, device) for c, area in zip(neg, spatial_conds_neg)]
|
||||||
neg = [c for c, ignore in neg if not ignore]
|
neg = [c for c, ignore in neg if not ignore]
|
||||||
|
|
||||||
#gligen
|
#gligen
|
||||||
for (_, cond), gligen in zip(pos, gligen_pos):
|
for cond, gligen in zip(pos, gligen_pos):
|
||||||
slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen)
|
slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen)
|
||||||
for (_, cond), gligen in zip(neg, gligen_neg):
|
for cond, gligen in zip(neg, gligen_neg):
|
||||||
slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen)
|
slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen)
|
||||||
|
|
||||||
tile_result = sampler.sample(tiled_noise, pos, neg, cfg=cfg, latent_image=tiled_latent, start_step=start_at_step + i * tile_steps, last_step=start_at_step + i*tile_steps + tile_steps, force_full_denoise=force_full_denoise and i+1 == end_at_step - start_at_step, denoise_mask=tiled_mask, callback=callback, disable_pbar=True, seed=noise_seed)
|
tile_result = sampler.sample(tiled_noise, pos, neg, cfg=cfg, latent_image=tiled_latent, start_step=start_at_step + i * tile_steps, last_step=start_at_step + i*tile_steps + tile_steps, force_full_denoise=force_full_denoise and i+1 == end_at_step - start_at_step, denoise_mask=tiled_mask, callback=callback, disable_pbar=True, seed=noise_seed)
|
||||||
@@ -450,9 +271,9 @@ def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_
|
|||||||
if tiled_mask is not None:
|
if tiled_mask is not None:
|
||||||
tiled_mask = tiled_mask.cpu()
|
tiled_mask = tiled_mask.cpu()
|
||||||
if tiling_strategy == "random strict":
|
if tiling_strategy == "random strict":
|
||||||
set_slice(samples_next, tile_result, tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask)
|
tiling.set_slice(samples_next, tile_result, tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask)
|
||||||
else:
|
else:
|
||||||
set_slice(samples, tile_result, tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask)
|
tiling.set_slice(samples, tile_result, tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask)
|
||||||
if tiling_strategy == "random strict":
|
if tiling_strategy == "random strict":
|
||||||
samples = samples_next.clone()
|
samples = samples_next.clone()
|
||||||
|
|
||||||
@@ -463,6 +284,7 @@ def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_
|
|||||||
out["samples"] = samples.cpu()
|
out["samples"] = samples.cpu()
|
||||||
return (out, )
|
return (out, )
|
||||||
|
|
||||||
|
+
|
||||||
class TiledKSampler:
|
class TiledKSampler:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
|
|||||||
Reference in New Issue
Block a user