diff --git a/py/smZ_cfg_denoiser.py b/py/smZ_cfg_denoiser.py index c970138..3fb8628 100644 --- a/py/smZ_cfg_denoiser.py +++ b/py/smZ_cfg_denoiser.py @@ -4,7 +4,7 @@ import torch from typing import List import comfy.sample from comfy import model_base, model_management -from comfy.samplers import KSampler, KSamplerX0Inpaint +from comfy.samplers import KSampler, KSamplerX0Inpaint, wrap_model #from comfy.k_diffusion.external import CompVisDenoiser import nodes import inspect @@ -308,10 +308,10 @@ def set_model_k(self: KSampler): self.model_denoise = CFGNoisePredictor(self.model) # main change if ((getattr(self.model, "parameterization", "") == "v") or (getattr(self.model, "model_type", -1) == model_base.ModelType.V_PREDICTION)): - self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True) + self.model_wrap = wrap_model(self.model_denoise, quantize=True) self.model_wrap.parameterization = getattr(self.model, "parameterization", "v") else: - self.model_wrap = CompVisDenoiser(self.model_denoise, quantize=True) + self.model_wrap = wrap_model(self.model_denoise, quantize=True) self.model_wrap.parameterization = getattr(self.model, "parameterization", "eps") self.model_k = KSamplerX0Inpaint(self.model_wrap)