Fix CFG denoiser.

This commit is contained in:
Vyacheslav Moskalev
2024-03-09 20:28:19 +07:00
parent f6d01dc544
commit 685a3d4a92

View File

@@ -253,10 +253,8 @@ class CFGNoisePredictor(torch.nn.Module):
self.orig = comfy.samplers.CFGNoisePredictor(model) #CFGNoisePredictorOrig(model)
self.inner_model = model
self.inner_model2 = CFGDenoiser(model.apply_model)
self.inner_model2.num_timesteps = model.num_timesteps
self.inner_model2.device = self.ksampler.device if hasattr(self.ksampler, "device") else None
self.s_min_uncond = 0.0
self.alphas_cumprod = model.alphas_cumprod
self.c_adm = None
self.init_cond = None
self.init_uncond = None
@@ -308,12 +306,13 @@ def set_model_k(self: KSampler):
self.model_denoise = CFGNoisePredictor(self.model) # main change
if ((getattr(self.model, "parameterization", "") == "v") or
(getattr(self.model, "model_type", -1) == model_base.ModelType.V_PREDICTION)):
self.model_wrap = wrap_model(self.model_denoise, quantize=True)
self.model_wrap = wrap_model(self.model_denoise)
self.model_wrap.parameterization = getattr(self.model, "parameterization", "v")
else:
self.model_wrap = wrap_model(self.model_denoise, quantize=True)
self.model_wrap = wrap_model(self.model_denoise)
self.model_wrap.parameterization = getattr(self.model, "parameterization", "eps")
self.model_k = KSamplerX0Inpaint(self.model_wrap)
sigmas = self.calculate_sigmas(self.steps)
self.model_k = KSamplerX0Inpaint(self.model_wrap, sigmas)
class SDKSampler(comfy.samplers.KSampler):
def __init__(self, *args, **kwargs):