Merge pull request #127 from slouffka/main

Fix CFG denoiser.
This commit is contained in:
VALADI K JAGANATHAN
2024-03-25 19:11:14 +05:30
committed by GitHub

View File

@@ -253,10 +253,8 @@ class CFGNoisePredictor(torch.nn.Module):
self.orig = comfy.samplers.CFGNoisePredictor(model) #CFGNoisePredictorOrig(model) self.orig = comfy.samplers.CFGNoisePredictor(model) #CFGNoisePredictorOrig(model)
self.inner_model = model self.inner_model = model
self.inner_model2 = CFGDenoiser(model.apply_model) self.inner_model2 = CFGDenoiser(model.apply_model)
self.inner_model2.num_timesteps = model.num_timesteps
self.inner_model2.device = self.ksampler.device if hasattr(self.ksampler, "device") else None self.inner_model2.device = self.ksampler.device if hasattr(self.ksampler, "device") else None
self.s_min_uncond = 0.0 self.s_min_uncond = 0.0
self.alphas_cumprod = model.alphas_cumprod
self.c_adm = None self.c_adm = None
self.init_cond = None self.init_cond = None
self.init_uncond = None self.init_uncond = None
@@ -308,12 +306,13 @@ def set_model_k(self: KSampler):
self.model_denoise = CFGNoisePredictor(self.model) # main change self.model_denoise = CFGNoisePredictor(self.model) # main change
if ((getattr(self.model, "parameterization", "") == "v") or if ((getattr(self.model, "parameterization", "") == "v") or
(getattr(self.model, "model_type", -1) == model_base.ModelType.V_PREDICTION)): (getattr(self.model, "model_type", -1) == model_base.ModelType.V_PREDICTION)):
self.model_wrap = wrap_model(self.model_denoise, quantize=True) self.model_wrap = wrap_model(self.model_denoise)
self.model_wrap.parameterization = getattr(self.model, "parameterization", "v") self.model_wrap.parameterization = getattr(self.model, "parameterization", "v")
else: else:
self.model_wrap = wrap_model(self.model_denoise, quantize=True) self.model_wrap = wrap_model(self.model_denoise)
self.model_wrap.parameterization = getattr(self.model, "parameterization", "eps") self.model_wrap.parameterization = getattr(self.model, "parameterization", "eps")
self.model_k = KSamplerX0Inpaint(self.model_wrap) sigmas = self.calculate_sigmas(self.steps)
self.model_k = KSamplerX0Inpaint(self.model_wrap, sigmas)
class SDKSampler(comfy.samplers.KSampler): class SDKSampler(comfy.samplers.KSampler):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):