mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 14:12:11 -03:00
feat: Improve Nunchaku LoRA loading with copy_with_ctx support and add unit tests. see #733
This commit is contained in:
@@ -36,6 +36,7 @@ any_type = AnyType("*")
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
import sys
|
||||
import folder_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -98,25 +99,37 @@ def to_diffusers(input_lora):
|
||||
|
||||
def nunchaku_load_lora(model, lora_name, lora_strength):
|
||||
"""Load a Flux LoRA for Nunchaku model"""
|
||||
model_wrapper = model.model.diffusion_model
|
||||
transformer = model_wrapper.model
|
||||
|
||||
# Save the transformer temporarily
|
||||
model_wrapper.model = None
|
||||
ret_model = model.clone()
|
||||
ret_model_wrapper = ret_model.model.diffusion_model
|
||||
|
||||
# Restore the model and set it for the copy
|
||||
model_wrapper.model = transformer
|
||||
ret_model_wrapper.model = transformer
|
||||
|
||||
# Get full path to the LoRA file. Allow both direct paths and registered LoRA names.
|
||||
lora_path = lora_name if os.path.isfile(lora_name) else folder_paths.get_full_path("loras", lora_name)
|
||||
if not lora_path or not os.path.isfile(lora_path):
|
||||
logger.warning("Skipping LoRA '%s' because it could not be found", lora_name)
|
||||
return model
|
||||
|
||||
ret_model_wrapper.loras.append((lora_path, lora_strength))
|
||||
model_wrapper = model.model.diffusion_model
|
||||
|
||||
# Try to find copy_with_ctx in the same module as ComfyFluxWrapper
|
||||
module_name = model_wrapper.__class__.__module__
|
||||
module = sys.modules.get(module_name)
|
||||
copy_with_ctx = getattr(module, "copy_with_ctx", None)
|
||||
|
||||
if copy_with_ctx is not None:
|
||||
# New logic using copy_with_ctx from ComfyUI-nunchaku 1.1.0+
|
||||
ret_model_wrapper, ret_model = copy_with_ctx(model_wrapper)
|
||||
ret_model_wrapper.loras = [*model_wrapper.loras, (lora_path, lora_strength)]
|
||||
else:
|
||||
# Fallback to legacy logic
|
||||
logger.warning("Please upgrade ComfyUI-nunchaku to 1.1.0 or above for better LoRA support. Falling back to legacy loading logic.")
|
||||
transformer = model_wrapper.model
|
||||
|
||||
# Save the transformer temporarily
|
||||
model_wrapper.model = None
|
||||
ret_model = copy.deepcopy(model) # copy everything except the model
|
||||
ret_model_wrapper = ret_model.model.diffusion_model
|
||||
|
||||
# Restore the model and set it for the copy
|
||||
model_wrapper.model = transformer
|
||||
ret_model_wrapper.model = transformer
|
||||
ret_model_wrapper.loras.append((lora_path, lora_strength))
|
||||
|
||||
# Convert the LoRA to diffusers format
|
||||
sd = to_diffusers(lora_path)
|
||||
|
||||
Reference in New Issue
Block a user