# Files
# efficiency-nodes-comfyui/py/smZ_cfg_denoiser.py
# 2023-11-01 13:57:57 +05:30
#
# 322 lines
# 14 KiB
# Python

# https://github.com/shiimizu/ComfyUI_smZNodes
import comfy
import torch
from typing import List
import comfy.sample
from comfy import model_base, model_management
from comfy.samplers import KSampler, KSamplerX0Inpaint
#from comfy.k_diffusion.external import CompVisDenoiser
import nodes
import inspect
import functools
import importlib
import os
import re
import itertools
import torch
from comfy import model_management
def catenate_conds(conds):
    """Concatenate a sequence of conditionings with ``torch.cat``.

    The type of the first element decides how ``conds`` is treated:

    * not a dict -- the whole sequence is handed to ``torch.cat`` as-is;
    * a dict -- a new dict is returned that has the keys of the first
      element, each mapped to the concatenation of that key's values taken
      from every element of ``conds``.
    """
    first = conds[0]
    if isinstance(first, dict):
        return {name: torch.cat([entry[name] for entry in conds]) for name in first.keys()}
    return torch.cat(conds)
def subscript_cond(cond, a, b):
    """Return the ``[a:b]`` slice of a conditioning.

    A dict conditioning is sliced value by value into a new dict with the
    same keys; anything else is sliced directly.
    """
    if isinstance(cond, dict):
        return {name: value[a:b] for name, value in cond.items()}
    return cond[a:b]
def pad_cond(tensor, repeats, empty):
    """Append ``repeats`` copies of ``empty`` to ``tensor`` along dimension 1.

    ``empty`` is tiled to ``(tensor.shape[0], repeats, 1)`` times its own
    shape, moved to ``tensor``'s device and concatenated after ``tensor``.

    When ``tensor`` is a dict, only its ``'crossattn'`` entry is padded; the
    dict is updated in place and the same dict object is returned.
    """
    if isinstance(tensor, dict):
        tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
        return tensor

    padding = empty.repeat((tensor.shape[0], repeats, 1)).to(device=tensor.device)
    return torch.cat([tensor, padding], axis=1)
class CFGDenoiser(torch.nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for a stable diffusion model (specifically for the unet)
    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
    instead of one. Originally, the second prompt is just an empty string, but we use a non-empty
    negative prompt.
    """

    def __init__(self, model):
        """Store ``model`` (the callable invoked by ``forward``) and reset all state."""
        super().__init__()
        self.inner_model = model
        self.model_wrap = None
        self.mask = None
        self.nmask = None
        self.init_latent = None
        # Number of steps as specified by user in UI.
        self.steps = None
        # Expected number of calls to denoiser calculated from self.steps and
        # specifics of the selected sampler.
        self.total_steps = None
        # Number of completed forward() calls.
        self.step = 0
        self.image_cfg_scale = None
        self.padded_cond_uncond = False
        self.sampler = None
        self.p = None
        self.mask_before_denoising = False

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        """Blend cond and uncond predictions with classifier-free guidance.

        The last ``uncond.shape[0]`` rows of ``x_out`` are the uncond
        predictions.  For image ``i``, every ``(cond_index, weight)`` pair in
        ``conds_list[i]`` adds
        ``(x_out[cond_index] - uncond_i) * weight * cond_scale`` to a clone of
        the uncond prediction.  ``x_out`` itself is not modified.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def combine_denoised_for_edit_model(self, x_out, cond_scale):
        """Combine the three equal chunks of ``x_out`` (cond, image-cond, uncond).

        Uses ``cond_scale`` for the cond/image-cond difference and
        ``self.image_cfg_scale`` for the image-cond/uncond difference.
        """
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)

        return denoised

    def get_pred_x0(self, x_in, x_out, sigma):
        """Return ``x_out`` unchanged; ``x_in`` and ``sigma`` are ignored."""
        return x_out

    def update_inner_model(self):
        """Drop the cached ``model_wrap`` and refresh the sampler's cond/uncond from ``self.p``."""
        self.model_wrap = None

        c, uc = self.p.get_conds()
        self.sampler.sampler_extra_args['cond'] = c
        self.sampler.sampler_extra_args['uncond'] = uc

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        """Denoise ``x`` with classifier-free guidance.

        Args:
            x: latent batch.  If it carries a ``c_adm`` attribute, that value
                is forwarded to the inner model as ``c_adm``; otherwise
                ``None`` is forwarded.
            sigma: per-image sigma values, indexed like ``x``.
            uncond: negative conditioning.
            cond: ``(conds_list, tensor)`` where ``conds_list[i]`` is the list
                of ``(cond_index, weight)`` pairs for image ``i`` and
                ``tensor`` holds the stacked positive conditionings.
            cond_scale: guidance scale.
            s_min_uncond: on odd steps, when ``sigma[0]`` is below this value,
                the uncond pass is skipped and the cond result is reused.
            image_cond: image conditioning, sliced per batch and handed to the
                condition-dict builder (which currently ignores it and always
                sends ``c_concat=None``).

        Returns:
            The guided denoised batch, optionally blended with
            ``self.init_latent`` through ``self.mask`` / ``self.nmask``.
        """
        model_management.throw_exception_if_processing_interrupted()

        # Edit-model support is switched off in this port; the corresponding
        # branches are kept so the structure stays recognisable.
        is_edit_model = False

        conds_list, tensor = cond
        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

        # Read c_adm *before* x may be replaced by the masked latent below:
        # the replacement tensor is a fresh object without that attribute.
        c_adm = getattr(x, "c_adm", None)

        if self.mask_before_denoising and self.mask is not None:
            x = self.init_latent * self.mask + self.nmask * x

        batch_size = len(conds_list)
        repeats = [len(conds) for conds in conds_list]

        image_uncond = image_cond
        # NOTE(review): the dict form of uncond is accepted here, but the
        # `.shape` accesses further down expect tensors -- confirm before
        # relying on dict conditionings.
        if isinstance(uncond, dict):
            def make_condition_dict(c_crossattn, c_concat):
                return {**c_crossattn, "c_concat": None, "c_adm": c_adm, 'transformer_options': {'from_smZ': True}}
        else:
            def make_condition_dict(c_crossattn, c_concat):
                return {"c_crossattn": c_crossattn, "c_concat": None, "c_adm": c_adm, 'transformer_options': {'from_smZ': True}}

        def repeat_per_cond(values):
            # One stacked copy of values[i] per conditioning of image i.
            return [torch.stack([values[i] for _ in range(n)]) for i, n in enumerate(repeats)]

        if not is_edit_model:
            x_in = torch.cat(repeat_per_cond(x) + [x])
            sigma_in = torch.cat(repeat_per_cond(sigma) + [sigma])
            image_cond_in = torch.cat(repeat_per_cond(image_cond) + [image_uncond])
        else:
            x_in = torch.cat(repeat_per_cond(x) + [x] + [x])
            sigma_in = torch.cat(repeat_per_cond(sigma) + [sigma] + [sigma])
            image_cond_in = torch.cat(repeat_per_cond(image_cond) + [image_uncond] + [torch.zeros_like(self.init_latent)])

        skip_uncond = False

        # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
        if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
            skip_uncond = True
            x_in = x_in[:-batch_size]
            sigma_in = sigma_in[:-batch_size]

        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
            # cond and uncond have the same length along dim 1 (or uncond is
            # skipped), so everything can go through one concatenated batch.
            if is_edit_model:
                cond_in = catenate_conds([tensor, uncond, uncond])
            elif skip_uncond:
                cond_in = tensor
            else:
                cond_in = catenate_conds([tensor, uncond])

            x_out = torch.zeros_like(x_in)
            for a in range(0, x_out.shape[0], batch_size):
                b = a + batch_size
                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], **make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
        else:
            # Lengths differ: run the cond batches first, then uncond separately.
            x_out = torch.zeros_like(x_in)
            for a in range(0, tensor.shape[0], batch_size):
                b = min(a + batch_size, tensor.shape[0])

                if not is_edit_model:
                    c_crossattn = subscript_cond(tensor, a, b)
                else:
                    # NOTE(review): unreachable while is_edit_model is False.
                    # `uncond` is passed as the `dim` argument of torch.cat
                    # here -- confirm the intent before enabling edit models.
                    c_crossattn = torch.cat([tensor[a:b]], uncond)

                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], **make_condition_dict(c_crossattn, image_cond_in[a:b]))

            if not skip_uncond:
                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], **make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))

        # Index of the first conditioning of every image.
        denoised_image_indexes = [conds[0][0] for conds in conds_list]
        if skip_uncond:
            fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
            x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be

        if is_edit_model:
            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
        elif skip_uncond:
            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
        else:
            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

        if not self.mask_before_denoising and self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised

        self.step += 1
        del x_out
        return denoised
# ========================================================================
def expand(tensor1, tensor2):
    """Make two 3-D tensors agree on the size of dimension 1.

    The tensor that is shorter along dimension 1 is tiled along that
    dimension and then truncated to the longer tensor's length.  When both
    already agree, the inputs are returned untouched.

    Returns:
        ``(tensor1, tensor2)`` with equal ``size(1)``.
    """
    def tile_to_length(short, target_len):
        # Ceiling division without math.ceil: how many copies are needed
        # to reach (or overshoot) target_len.
        copies = -(-target_len // short.size(1))
        return short.repeat(1, copies, 1)[:, :target_len, :]

    len1 = tensor1.size(1)
    len2 = tensor2.size(1)
    if len1 < len2:
        tensor1 = tile_to_length(tensor1, len2)
    elif len1 > len2:
        tensor2 = tile_to_length(tensor2, len1)
    return (tensor1, tensor2)
def _find_outer_instance(target, target_type):
    """Search the call stack for a local variable by name and type.

    Starting at this function's own frame and walking outwards through the
    callers, return the first local named ``target`` whose value is an
    instance of ``target_type`` and does not compare equal to ``1``.
    Returns ``None`` when no frame has such a local.
    """
    import inspect

    current = inspect.currentframe()
    while current is not None:
        scope = current.f_locals
        if target in scope:
            candidate = scope[target]
            # Values equal to 1 are skipped (original note: "steps == 1").
            if isinstance(candidate, target_type) and candidate != 1:
                return candidate
        current = current.f_back
    return None
# ========================================================================
def bounded_modulo(number, modulo_value):
    """Cap ``number`` at ``modulo_value``.

    Despite the name no modulo is taken: ``number`` is returned when it is
    strictly smaller than ``modulo_value``, otherwise ``modulo_value``.
    """
    if number < modulo_value:
        return number
    return modulo_value
def calc_cond(c, current_step):
    """Group by smZ conds that may do prompt-editing / regular conds / comfy conds.

    ``c`` is a sequence of entries whose second item is a dict of options.
    Consecutive entries are grouped first by whether they carry a
    ``"from_smZ"`` option and then by their ``"smZid"`` option:

    * a run without ``"smZid"`` is copied to the result unchanged;
    * a run sharing an ``"smZid"`` contributes exactly one entry, chosen by
      ``current_step`` capped at ``orig_len - 1`` (``orig_len`` is read from
      the first entry of the run and defaults to 1).

    Returns a new list; ``c`` itself is not modified.
    """
    def is_from_smz(entry):
        return entry[1].get("from_smZ", None) is not None

    def schedule_id(entry):
        return entry[1].get("smZid", None)

    selected = []
    for _, smz_run in itertools.groupby(c, is_from_smz):
        for run_id, members in itertools.groupby(list(smz_run), schedule_id):
            members = list(members)
            if run_id is None:
                selected.extend(members)
                continue
            # Prompt-editing run: pick the entry scheduled for this step.
            orig_len = members[0][1].get('orig_len', 1)
            index = bounded_modulo(current_step, orig_len - 1)
            selected.append(members[index])
    return selected
class CFGNoisePredictor(torch.nn.Module):
    """Noise predictor that routes sampling through ``CFGDenoiser``.

    It keeps a step counter so that step-dependent conditionings (see
    ``calc_cond``) can be resolved on every ``apply_model`` call.
    """

    def __init__(self, model):
        """Wrap ``model``.

        ``model`` must expose ``apply_model``, ``num_timesteps`` and
        ``alphas_cumprod``.  The sampler that is constructing this object is
        looked up on the call stack (a local named ``self`` that is a
        ``comfy.samplers.KSampler``); its ``device`` is copied to the inner
        denoiser when available.
        """
        super().__init__()
        self.ksampler = _find_outer_instance('self', comfy.samplers.KSampler)
        self.step = 0
        self.orig = comfy.samplers.CFGNoisePredictor(model)  # CFGNoisePredictorOrig(model)
        self.inner_model = model
        self.inner_model2 = CFGDenoiser(model.apply_model)
        self.inner_model2.num_timesteps = model.num_timesteps
        self.inner_model2.device = self.ksampler.device if hasattr(self.ksampler, "device") else None
        self.s_min_uncond = 0.0
        self.alphas_cumprod = model.alphas_cumprod
        self.c_adm = None
        self.init_cond = None
        self.init_uncond = None
        self.is_prompt_editing_u = False
        self.is_prompt_editing_c = False

    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options=None, seed=None):
        """Run one guided denoising step.

        Args:
            x: latent batch; a ``c_adm`` attribute is attached to it before
                it is handed to the inner denoiser.
            timestep: passed to the inner denoiser as its sigma argument.
            cond, uncond: sequences of ``(tensor, options_dict)`` entries.
            cond_scale: guidance scale.
            cond_concat, seed: accepted for interface compatibility; unused.
            model_options: optional dict.  When any selected entry is marked
                ``from_smZ``, ``model_options['transformer_options']['from_smZ']``
                is set to ``True`` on the dict that was passed in.

        Returns:
            The output of the inner ``CFGDenoiser``.
        """
        # A fresh dict per call: the previous `{}` default was shared between
        # calls and got mutated below.
        if model_options is None:
            model_options = {}

        cc = calc_cond(cond, self.step)
        uu = calc_cond(uncond, self.step)
        self.step += 1

        if (any(p[1].get('from_smZ', False) for p in cc) or
                any(p[1].get('from_smZ', False) for p in uu)):
            if model_options.get('transformer_options', None) is None:
                model_options['transformer_options'] = {}
            model_options['transformer_options']['from_smZ'] = True

        # Only supports one cond: keep the first smZ entry when there is one,
        # otherwise the lists are left alone and their first entry is used.
        for entry in cc:
            if entry[1].get('from_smZ', False):
                cc = [entry]
                break
        for entry in uu:
            if entry[1].get('from_smZ', False):
                uu = [entry]
                break

        c = cc[0][1]
        u = uu[0][1]
        _cc = cc[0][0]
        _uu = uu[0][0]
        if c.get("adm_encoded", None) is not None:
            self.c_adm = torch.cat([c['adm_encoded'], u['adm_encoded']])
            # SDXL. Need to pad with repeats.  One call is enough: expand()
            # lengthens whichever of the two is shorter.
            _cc, _uu = expand(_cc, _uu)
        x.c_adm = self.c_adm
        conds_list = c.get('conds_list', [[(0, 1.0)]])
        image_cond = txt2img_image_conditioning(None, x)
        out = self.inner_model2(x, timestep, cond=(conds_list, _cc), uncond=_uu, cond_scale=cond_scale, s_min_uncond=self.s_min_uncond, image_cond=image_cond)
        return out
def txt2img_image_conditioning(sd_model, x, width=None, height=None):
    """Return a dummy all-zero image conditioning for ``x``.

    The result has shape ``(x.shape[0], 5, 1, 1)`` and the dtype and device of
    ``x``.  ``sd_model``, ``width`` and ``height`` are ignored.
    """
    batch = x.shape[0]
    placeholder = x.new_zeros(batch, 5, 1, 1, dtype=x.dtype, device=x.device)
    return placeholder
# =======================================================================================
def set_model_k(self: KSampler):
    """Install the smZ noise predictor and its wrappers on a sampler.

    Sets ``self.model_denoise``, ``self.model_wrap`` and ``self.model_k``.
    The V-prediction wrapper is chosen when the model's ``parameterization``
    is ``"v"`` or its ``model_type`` is ``ModelType.V_PREDICTION``; otherwise
    the eps wrapper is used.
    """
    # NOTE(review): CompVisVDenoiser / CompVisDenoiser are not imported in the
    # visible part of this file (the k_diffusion import at the top is
    # commented out) -- confirm they are in scope before calling this.
    self.model_denoise = CFGNoisePredictor(self.model)  # main change

    uses_v_prediction = (
        getattr(self.model, "parameterization", "") == "v"
        or getattr(self.model, "model_type", -1) == model_base.ModelType.V_PREDICTION
    )
    if uses_v_prediction:
        wrapper_cls, fallback_parameterization = CompVisVDenoiser, "v"
    else:
        wrapper_cls, fallback_parameterization = CompVisDenoiser, "eps"

    self.model_wrap = wrapper_cls(self.model_denoise, quantize=True)
    self.model_wrap.parameterization = getattr(self.model, "parameterization", fallback_parameterization)
    self.model_k = KSamplerX0Inpaint(self.model_wrap)
class SDKSampler(comfy.samplers.KSampler):
    """``KSampler`` whose model wrappers are replaced by ``set_model_k``."""

    def __init__(self, *args, **kwargs):
        """Initialise the base sampler, then swap in the smZ model wrappers."""
        super().__init__(*args, **kwargs)
        set_model_k(self)