mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
Refactor LoRA handling in LoraLoader, LoraStacker, and TriggerWordToggle
- Introduced logging to track unexpected formats in LoRA and trigger word data. - Refactored LoRA processing to support both old and new kwargs formats in LoraLoader and LoraStacker. - Enhanced trigger word processing to handle different data formats in TriggerWordToggle. - Improved code readability and maintainability by extracting common logic into helper methods.
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from nodes import LoraLoader
|
from nodes import LoraLoader
|
||||||
from comfy.comfy_types import IO # type: ignore
|
from comfy.comfy_types import IO # type: ignore
|
||||||
from ..services.lora_scanner import LoraScanner
|
from ..services.lora_scanner import LoraScanner
|
||||||
@@ -6,6 +7,8 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
from .utils import FlexibleOptionalInputType, any_type
|
from .utils import FlexibleOptionalInputType, any_type
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class LoraManagerLoader:
|
class LoraManagerLoader:
|
||||||
NAME = "Lora Loader (LoraManager)"
|
NAME = "Lora Loader (LoraManager)"
|
||||||
CATEGORY = "Lora Manager/loaders"
|
CATEGORY = "Lora Manager/loaders"
|
||||||
@@ -55,6 +58,23 @@ class LoraManagerLoader:
|
|||||||
basename = os.path.basename(lora_path)
|
basename = os.path.basename(lora_path)
|
||||||
return os.path.splitext(basename)[0]
|
return os.path.splitext(basename)[0]
|
||||||
|
|
||||||
|
def _get_loras_list(self, kwargs):
|
||||||
|
"""Helper to extract loras list from either old or new kwargs format"""
|
||||||
|
if 'loras' not in kwargs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
loras_data = kwargs['loras']
|
||||||
|
# Handle new format: {'loras': {'__value__': [...]}}
|
||||||
|
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
||||||
|
return loras_data['__value__']
|
||||||
|
# Handle old format: {'loras': [...]}
|
||||||
|
elif isinstance(loras_data, list):
|
||||||
|
return loras_data
|
||||||
|
# Unexpected format
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||||
|
return []
|
||||||
|
|
||||||
def load_loras(self, model, clip, text, **kwargs):
|
def load_loras(self, model, clip, text, **kwargs):
|
||||||
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
|
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
|
||||||
loaded_loras = []
|
loaded_loras = []
|
||||||
@@ -74,24 +94,24 @@ class LoraManagerLoader:
|
|||||||
all_trigger_words.extend(trigger_words)
|
all_trigger_words.extend(trigger_words)
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||||
|
|
||||||
# Then process loras from kwargs
|
# Then process loras from kwargs with support for both old and new formats
|
||||||
if 'loras' in kwargs:
|
loras_list = self._get_loras_list(kwargs)
|
||||||
for lora in kwargs['loras']:
|
for lora in loras_list:
|
||||||
if not lora.get('active', False):
|
if not lora.get('active', False):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
lora_name = lora['name']
|
lora_name = lora['name']
|
||||||
strength = float(lora['strength'])
|
strength = float(lora['strength'])
|
||||||
|
|
||||||
# Get lora path and trigger words
|
# Get lora path and trigger words
|
||||||
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||||
|
|
||||||
# Apply the LoRA using the resolved path
|
# Apply the LoRA using the resolved path
|
||||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength)
|
model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength)
|
||||||
loaded_loras.append(f"{lora_name}: {strength}")
|
loaded_loras.append(f"{lora_name}: {strength}")
|
||||||
|
|
||||||
# Add trigger words to collection
|
# Add trigger words to collection
|
||||||
all_trigger_words.extend(trigger_words)
|
all_trigger_words.extend(trigger_words)
|
||||||
|
|
||||||
# use ',, ' to separate trigger words for group mode
|
# use ',, ' to separate trigger words for group mode
|
||||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||||
|
|||||||
@@ -4,6 +4,9 @@ from ..config import config
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from .utils import FlexibleOptionalInputType, any_type
|
from .utils import FlexibleOptionalInputType, any_type
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class LoraStacker:
|
class LoraStacker:
|
||||||
NAME = "Lora Stacker (LoraManager)"
|
NAME = "Lora Stacker (LoraManager)"
|
||||||
@@ -52,6 +55,23 @@ class LoraStacker:
|
|||||||
basename = os.path.basename(lora_path)
|
basename = os.path.basename(lora_path)
|
||||||
return os.path.splitext(basename)[0]
|
return os.path.splitext(basename)[0]
|
||||||
|
|
||||||
|
def _get_loras_list(self, kwargs):
|
||||||
|
"""Helper to extract loras list from either old or new kwargs format"""
|
||||||
|
if 'loras' not in kwargs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
loras_data = kwargs['loras']
|
||||||
|
# Handle new format: {'loras': {'__value__': [...]}}
|
||||||
|
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
||||||
|
return loras_data['__value__']
|
||||||
|
# Handle old format: {'loras': [...]}
|
||||||
|
elif isinstance(loras_data, list):
|
||||||
|
return loras_data
|
||||||
|
# Unexpected format
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||||
|
return []
|
||||||
|
|
||||||
def stack_loras(self, text, **kwargs):
|
def stack_loras(self, text, **kwargs):
|
||||||
"""Stacks multiple LoRAs based on the kwargs input without loading them."""
|
"""Stacks multiple LoRAs based on the kwargs input without loading them."""
|
||||||
stack = []
|
stack = []
|
||||||
@@ -67,24 +87,25 @@ class LoraStacker:
|
|||||||
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||||
all_trigger_words.extend(trigger_words)
|
all_trigger_words.extend(trigger_words)
|
||||||
|
|
||||||
if 'loras' in kwargs:
|
# Process loras from kwargs with support for both old and new formats
|
||||||
for lora in kwargs['loras']:
|
loras_list = self._get_loras_list(kwargs)
|
||||||
if not lora.get('active', False):
|
for lora in loras_list:
|
||||||
continue
|
if not lora.get('active', False):
|
||||||
|
continue
|
||||||
|
|
||||||
lora_name = lora['name']
|
lora_name = lora['name']
|
||||||
model_strength = float(lora['strength'])
|
model_strength = float(lora['strength'])
|
||||||
clip_strength = model_strength # Using same strength for both as in the original loader
|
clip_strength = model_strength # Using same strength for both as in the original loader
|
||||||
|
|
||||||
# Get lora path and trigger words
|
# Get lora path and trigger words
|
||||||
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||||
|
|
||||||
# Add to stack without loading
|
# Add to stack without loading
|
||||||
# replace '/' with os.sep to avoid different OS path format
|
# replace '/' with os.sep to avoid different OS path format
|
||||||
stack.append((lora_path.replace('/', os.sep), model_strength, clip_strength))
|
stack.append((lora_path.replace('/', os.sep), model_strength, clip_strength))
|
||||||
|
|
||||||
# Add trigger words to collection
|
# Add trigger words to collection
|
||||||
all_trigger_words.extend(trigger_words)
|
all_trigger_words.extend(trigger_words)
|
||||||
|
|
||||||
# use ',, ' to separate trigger words for group mode
|
# use ',, ' to separate trigger words for group mode
|
||||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||||
|
|||||||
@@ -2,6 +2,10 @@ import json
|
|||||||
import re
|
import re
|
||||||
from server import PromptServer # type: ignore
|
from server import PromptServer # type: ignore
|
||||||
from .utils import FlexibleOptionalInputType, any_type
|
from .utils import FlexibleOptionalInputType, any_type
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TriggerWordToggle:
|
class TriggerWordToggle:
|
||||||
NAME = "TriggerWord Toggle (LoraManager)"
|
NAME = "TriggerWord Toggle (LoraManager)"
|
||||||
@@ -24,8 +28,24 @@ class TriggerWordToggle:
|
|||||||
RETURN_NAMES = ("filtered_trigger_words",)
|
RETURN_NAMES = ("filtered_trigger_words",)
|
||||||
FUNCTION = "process_trigger_words"
|
FUNCTION = "process_trigger_words"
|
||||||
|
|
||||||
|
def _get_toggle_data(self, kwargs, key='toggle_trigger_words'):
|
||||||
|
"""Helper to extract data from either old or new kwargs format"""
|
||||||
|
if key not in kwargs:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = kwargs[key]
|
||||||
|
# Handle new format: {'key': {'__value__': ...}}
|
||||||
|
if isinstance(data, dict) and '__value__' in data:
|
||||||
|
return data['__value__']
|
||||||
|
# Handle old format: {'key': ...}
|
||||||
|
else:
|
||||||
|
return data
|
||||||
|
|
||||||
def process_trigger_words(self, id, group_mode, **kwargs):
|
def process_trigger_words(self, id, group_mode, **kwargs):
|
||||||
trigger_words = kwargs.get("trigger_words", "")
|
# Handle both old and new formats for trigger_words
|
||||||
|
trigger_words_data = self._get_toggle_data(kwargs, 'trigger_words')
|
||||||
|
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
|
||||||
|
|
||||||
# Send trigger words to frontend
|
# Send trigger words to frontend
|
||||||
PromptServer.instance.send_sync("trigger_word_update", {
|
PromptServer.instance.send_sync("trigger_word_update", {
|
||||||
"id": id,
|
"id": id,
|
||||||
@@ -34,11 +54,10 @@ class TriggerWordToggle:
|
|||||||
|
|
||||||
filtered_triggers = trigger_words
|
filtered_triggers = trigger_words
|
||||||
|
|
||||||
if 'toggle_trigger_words' in kwargs:
|
# Get toggle data with support for both formats
|
||||||
|
trigger_data = self._get_toggle_data(kwargs, 'toggle_trigger_words')
|
||||||
|
if trigger_data:
|
||||||
try:
|
try:
|
||||||
# Get trigger word toggle data
|
|
||||||
trigger_data = kwargs['toggle_trigger_words']
|
|
||||||
|
|
||||||
# Convert to list if it's a JSON string
|
# Convert to list if it's a JSON string
|
||||||
if isinstance(trigger_data, str):
|
if isinstance(trigger_data, str):
|
||||||
trigger_data = json.loads(trigger_data)
|
trigger_data = json.loads(trigger_data)
|
||||||
@@ -72,6 +91,6 @@ class TriggerWordToggle:
|
|||||||
filtered_triggers = ""
|
filtered_triggers = ""
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error processing trigger words: {e}")
|
logger.error(f"Error processing trigger words: {e}")
|
||||||
|
|
||||||
return (filtered_triggers,)
|
return (filtered_triggers,)
|
||||||
Reference in New Issue
Block a user