mirror of
https://github.com/jags111/efficiency-nodes-comfyui.git
synced 2026-03-24 22:52:13 -03:00
Efficiency Nodes V2.0
This commit is contained in:
321
py/smZ_cfg_denoiser.py
Normal file
321
py/smZ_cfg_denoiser.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# https://github.com/shiimizu/ComfyUI_smZNodes
|
||||
import comfy
|
||||
import torch
|
||||
from typing import List
|
||||
import comfy.sample
|
||||
from comfy import model_base, model_management
|
||||
from comfy.samplers import KSampler, CompVisVDenoiser, KSamplerX0Inpaint
|
||||
from comfy.k_diffusion.external import CompVisDenoiser
|
||||
import nodes
|
||||
import inspect
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
from comfy import model_management
|
||||
|
||||
def catenate_conds(conds):
|
||||
if not isinstance(conds[0], dict):
|
||||
return torch.cat(conds)
|
||||
|
||||
return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
|
||||
|
||||
|
||||
def subscript_cond(cond, a, b):
|
||||
if not isinstance(cond, dict):
|
||||
return cond[a:b]
|
||||
|
||||
return {key: vec[a:b] for key, vec in cond.items()}
|
||||
|
||||
|
||||
def pad_cond(tensor, repeats, empty):
|
||||
if not isinstance(tensor, dict):
|
||||
return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1)).to(device=tensor.device)], axis=1)
|
||||
|
||||
tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
|
||||
return tensor
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
|
||||
"""
|
||||
Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
|
||||
that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
|
||||
instead of one. Originally, the second prompt is just an empty string, but we use non-empty
|
||||
negative prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.inner_model = model
|
||||
self.model_wrap = None
|
||||
self.mask = None
|
||||
self.nmask = None
|
||||
self.init_latent = None
|
||||
self.steps = None
|
||||
"""number of steps as specified by user in UI"""
|
||||
|
||||
self.total_steps = None
|
||||
"""expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""
|
||||
|
||||
self.step = 0
|
||||
self.image_cfg_scale = None
|
||||
self.padded_cond_uncond = False
|
||||
self.sampler = None
|
||||
self.model_wrap = None
|
||||
self.p = None
|
||||
self.mask_before_denoising = False
|
||||
|
||||
|
||||
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
|
||||
denoised_uncond = x_out[-uncond.shape[0]:]
|
||||
denoised = torch.clone(denoised_uncond)
|
||||
|
||||
for i, conds in enumerate(conds_list):
|
||||
for cond_index, weight in conds:
|
||||
denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)
|
||||
|
||||
return denoised
|
||||
|
||||
def combine_denoised_for_edit_model(self, x_out, cond_scale):
|
||||
out_cond, out_img_cond, out_uncond = x_out.chunk(3)
|
||||
denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
|
||||
|
||||
return denoised
|
||||
|
||||
def get_pred_x0(self, x_in, x_out, sigma):
|
||||
return x_out
|
||||
|
||||
def update_inner_model(self):
|
||||
self.model_wrap = None
|
||||
|
||||
c, uc = self.p.get_conds()
|
||||
self.sampler.sampler_extra_args['cond'] = c
|
||||
self.sampler.sampler_extra_args['uncond'] = uc
|
||||
|
||||
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
|
||||
model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
is_edit_model = False
|
||||
|
||||
conds_list, tensor = cond
|
||||
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
|
||||
|
||||
if self.mask_before_denoising and self.mask is not None:
|
||||
x = self.init_latent * self.mask + self.nmask * x
|
||||
|
||||
batch_size = len(conds_list)
|
||||
repeats = [len(conds_list[i]) for i in range(batch_size)]
|
||||
|
||||
if False:
|
||||
image_uncond = torch.zeros_like(image_cond)
|
||||
make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm, 'transformer_options': {'from_smZ': True}} # pylint: disable=C3001
|
||||
else:
|
||||
image_uncond = image_cond
|
||||
if isinstance(uncond, dict):
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": None, "c_adm": x.c_adm, 'transformer_options': {'from_smZ': True}}
|
||||
else:
|
||||
make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": None, "c_adm": x.c_adm, 'transformer_options': {'from_smZ': True}}
|
||||
|
||||
if not is_edit_model:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
|
||||
else:
|
||||
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
|
||||
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
|
||||
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])
|
||||
|
||||
skip_uncond = False
|
||||
|
||||
# alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
|
||||
if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
|
||||
skip_uncond = True
|
||||
x_in = x_in[:-batch_size]
|
||||
sigma_in = sigma_in[:-batch_size]
|
||||
|
||||
if tensor.shape[1] == uncond.shape[1] or skip_uncond:
|
||||
if is_edit_model:
|
||||
cond_in = catenate_conds([tensor, uncond, uncond])
|
||||
elif skip_uncond:
|
||||
cond_in = tensor
|
||||
else:
|
||||
cond_in = catenate_conds([tensor, uncond])
|
||||
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, x_out.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = a + batch_size
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], **make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
|
||||
else:
|
||||
x_out = torch.zeros_like(x_in)
|
||||
for batch_offset in range(0, tensor.shape[0], batch_size):
|
||||
a = batch_offset
|
||||
b = min(a + batch_size, tensor.shape[0])
|
||||
|
||||
if not is_edit_model:
|
||||
c_crossattn = subscript_cond(tensor, a, b)
|
||||
else:
|
||||
c_crossattn = torch.cat([tensor[a:b]], uncond)
|
||||
|
||||
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], **make_condition_dict(c_crossattn, image_cond_in[a:b]))
|
||||
|
||||
if not skip_uncond:
|
||||
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], **make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))
|
||||
|
||||
denoised_image_indexes = [x[0][0] for x in conds_list]
|
||||
if skip_uncond:
|
||||
fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
|
||||
x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be
|
||||
|
||||
if is_edit_model:
|
||||
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
|
||||
elif skip_uncond:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
|
||||
else:
|
||||
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
|
||||
|
||||
if not self.mask_before_denoising and self.mask is not None:
|
||||
denoised = self.init_latent * self.mask + self.nmask * denoised
|
||||
|
||||
self.step += 1
|
||||
del x_out
|
||||
return denoised
|
||||
|
||||
# ========================================================================
|
||||
|
||||
def expand(tensor1, tensor2):
|
||||
def adjust_tensor_shape(tensor_small, tensor_big):
|
||||
# Calculate replication factor
|
||||
# -(-a // b) is ceiling of division without importing math.ceil
|
||||
replication_factor = -(-tensor_big.size(1) // tensor_small.size(1))
|
||||
|
||||
# Use repeat to extend tensor_small
|
||||
tensor_small_extended = tensor_small.repeat(1, replication_factor, 1)
|
||||
|
||||
# Take the rows of the extended tensor_small to match tensor_big
|
||||
tensor_small_matched = tensor_small_extended[:, :tensor_big.size(1), :]
|
||||
|
||||
return tensor_small_matched
|
||||
|
||||
# Check if their second dimensions are different
|
||||
if tensor1.size(1) != tensor2.size(1):
|
||||
# Check which tensor has the smaller second dimension and adjust its shape
|
||||
if tensor1.size(1) < tensor2.size(1):
|
||||
tensor1 = adjust_tensor_shape(tensor1, tensor2)
|
||||
else:
|
||||
tensor2 = adjust_tensor_shape(tensor2, tensor1)
|
||||
return (tensor1, tensor2)
|
||||
|
||||
def _find_outer_instance(target, target_type):
|
||||
import inspect
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if target in frame.f_locals:
|
||||
found = frame.f_locals[target]
|
||||
if isinstance(found, target_type) and found != 1: # steps == 1
|
||||
return found
|
||||
frame = frame.f_back
|
||||
return None
|
||||
|
||||
# ========================================================================
|
||||
def bounded_modulo(number, modulo_value):
|
||||
return number if number < modulo_value else modulo_value
|
||||
|
||||
def calc_cond(c, current_step):
|
||||
"""Group by smZ conds that may do prompt-editing / regular conds / comfy conds."""
|
||||
_cond = []
|
||||
# Group by conds from smZ
|
||||
fn=lambda x : x[1].get("from_smZ", None) is not None
|
||||
an_iterator = itertools.groupby(c, fn )
|
||||
for key, group in an_iterator:
|
||||
ls=list(group)
|
||||
# Group by prompt-editing conds
|
||||
fn2=lambda x : x[1].get("smZid", None)
|
||||
an_iterator2 = itertools.groupby(ls, fn2)
|
||||
for key2, group2 in an_iterator2:
|
||||
ls2=list(group2)
|
||||
if key2 is not None:
|
||||
orig_len = ls2[0][1].get('orig_len', 1)
|
||||
i = bounded_modulo(current_step, orig_len - 1)
|
||||
_cond = _cond + [ls2[i]]
|
||||
else:
|
||||
_cond = _cond + ls2
|
||||
return _cond
|
||||
|
||||
class CFGNoisePredictor(torch.nn.Module):
|
||||
def __init__(self, model):
|
||||
super().__init__()
|
||||
self.ksampler = _find_outer_instance('self', comfy.samplers.KSampler)
|
||||
self.step = 0
|
||||
self.orig = comfy.samplers.CFGNoisePredictor(model) #CFGNoisePredictorOrig(model)
|
||||
self.inner_model = model
|
||||
self.inner_model2 = CFGDenoiser(model.apply_model)
|
||||
self.inner_model2.num_timesteps = model.num_timesteps
|
||||
self.inner_model2.device = self.ksampler.device if hasattr(self.ksampler, "device") else None
|
||||
self.s_min_uncond = 0.0
|
||||
self.alphas_cumprod = model.alphas_cumprod
|
||||
self.c_adm = None
|
||||
self.init_cond = None
|
||||
self.init_uncond = None
|
||||
self.is_prompt_editing_u = False
|
||||
self.is_prompt_editing_c = False
|
||||
|
||||
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None):
|
||||
|
||||
cc=calc_cond(cond, self.step)
|
||||
uu=calc_cond(uncond, self.step)
|
||||
self.step += 1
|
||||
|
||||
if (any([p[1].get('from_smZ', False) for p in cc]) or
|
||||
any([p[1].get('from_smZ', False) for p in uu])):
|
||||
if model_options.get('transformer_options',None) is None:
|
||||
model_options['transformer_options'] = {}
|
||||
model_options['transformer_options']['from_smZ'] = True
|
||||
|
||||
# Only supports one cond
|
||||
for ix in range(len(cc)):
|
||||
if cc[ix][1].get('from_smZ', False):
|
||||
cc = [cc[ix]]
|
||||
break
|
||||
for ix in range(len(uu)):
|
||||
if uu[ix][1].get('from_smZ', False):
|
||||
uu = [uu[ix]]
|
||||
break
|
||||
c=cc[0][1]
|
||||
u=uu[0][1]
|
||||
_cc = cc[0][0]
|
||||
_uu = uu[0][0]
|
||||
if c.get("adm_encoded", None) is not None:
|
||||
self.c_adm = torch.cat([c['adm_encoded'], u['adm_encoded']])
|
||||
# SDXL. Need to pad with repeats
|
||||
_cc, _uu = expand(_cc, _uu)
|
||||
_uu, _cc = expand(_uu, _cc)
|
||||
x.c_adm = self.c_adm
|
||||
conds_list = c.get('conds_list', [[(0, 1.0)]])
|
||||
image_cond = txt2img_image_conditioning(None, x)
|
||||
out = self.inner_model2(x, timestep, cond=(conds_list, _cc), uncond=_uu, cond_scale=cond_scale, s_min_uncond=self.s_min_uncond, image_cond=image_cond)
|
||||
return out
|
||||
|
||||
def txt2img_image_conditioning(sd_model, x, width=None, height=None):
|
||||
return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
|
||||
|
||||
# =======================================================================================
|
||||
|
||||
def set_model_k(self: KSampler):
|
||||
self.model_denoise = CFGNoisePredictor(self.model) # main change
|
||||
if ((getattr(self.model, "parameterization", "") == "v") or
|
||||
(getattr(self.model, "model_type", -1) == model_base.ModelType.V_PREDICTION)):
|
||||
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
|
||||
self.model_wrap.parameterization = getattr(self.model, "parameterization", "v")
|
||||
else:
|
||||
self.model_wrap = CompVisDenoiser(self.model_denoise, quantize=True)
|
||||
self.model_wrap.parameterization = getattr(self.model, "parameterization", "eps")
|
||||
self.model_k = KSamplerX0Inpaint(self.model_wrap)
|
||||
|
||||
class SDKSampler(comfy.samplers.KSampler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SDKSampler, self).__init__(*args, **kwargs)
|
||||
set_model_k(self)
|
||||
Reference in New Issue
Block a user