Refactor LoRA handling in LoraLoader, LoraStacker, and TriggerWordToggle

- Introduced logging to track unexpected formats in LoRA and trigger word data.
- Refactored LoRA processing to support both old and new kwargs formats in LoraLoader and LoraStacker.
- Enhanced trigger word processing to handle different data formats in TriggerWordToggle.
- Improved code readability and maintainability by extracting common logic into helper methods.
This commit is contained in:
Will Miao
2025-03-22 15:56:37 +08:00
parent a31712ad1f
commit e7dffbbb1e
3 changed files with 100 additions and 40 deletions

View File

@@ -1,3 +1,4 @@
import logging
from nodes import LoraLoader from nodes import LoraLoader
from comfy.comfy_types import IO # type: ignore from comfy.comfy_types import IO # type: ignore
from ..services.lora_scanner import LoraScanner from ..services.lora_scanner import LoraScanner
@@ -6,6 +7,8 @@ import asyncio
import os import os
from .utils import FlexibleOptionalInputType, any_type from .utils import FlexibleOptionalInputType, any_type
logger = logging.getLogger(__name__)
class LoraManagerLoader: class LoraManagerLoader:
NAME = "Lora Loader (LoraManager)" NAME = "Lora Loader (LoraManager)"
CATEGORY = "Lora Manager/loaders" CATEGORY = "Lora Manager/loaders"
@@ -55,6 +58,23 @@ class LoraManagerLoader:
basename = os.path.basename(lora_path) basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0] return os.path.splitext(basename)[0]
def _get_loras_list(self, kwargs):
"""Helper to extract loras list from either old or new kwargs format"""
if 'loras' not in kwargs:
return []
loras_data = kwargs['loras']
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
return loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
return loras_data
# Unexpected format
else:
logger.warning(f"Unexpected loras format: {type(loras_data)}")
return []
def load_loras(self, model, clip, text, **kwargs): def load_loras(self, model, clip, text, **kwargs):
"""Loads multiple LoRAs based on the kwargs input and lora_stack.""" """Loads multiple LoRAs based on the kwargs input and lora_stack."""
loaded_loras = [] loaded_loras = []
@@ -74,24 +94,24 @@ class LoraManagerLoader:
all_trigger_words.extend(trigger_words) all_trigger_words.extend(trigger_words)
loaded_loras.append(f"{lora_name}: {model_strength}") loaded_loras.append(f"{lora_name}: {model_strength}")
# Then process loras from kwargs # Then process loras from kwargs with support for both old and new formats
if 'loras' in kwargs: loras_list = self._get_loras_list(kwargs)
for lora in kwargs['loras']: for lora in loras_list:
if not lora.get('active', False): if not lora.get('active', False):
continue continue
lora_name = lora['name'] lora_name = lora['name']
strength = float(lora['strength']) strength = float(lora['strength'])
# Get lora path and trigger words # Get lora path and trigger words
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name)) lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
# Apply the LoRA using the resolved path # Apply the LoRA using the resolved path
model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength) model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength)
loaded_loras.append(f"{lora_name}: {strength}") loaded_loras.append(f"{lora_name}: {strength}")
# Add trigger words to collection # Add trigger words to collection
all_trigger_words.extend(trigger_words) all_trigger_words.extend(trigger_words)
# use ',, ' to separate trigger words for group mode # use ',, ' to separate trigger words for group mode
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""

View File

@@ -4,6 +4,9 @@ from ..config import config
import asyncio import asyncio
import os import os
from .utils import FlexibleOptionalInputType, any_type from .utils import FlexibleOptionalInputType, any_type
import logging
logger = logging.getLogger(__name__)
class LoraStacker: class LoraStacker:
NAME = "Lora Stacker (LoraManager)" NAME = "Lora Stacker (LoraManager)"
@@ -52,6 +55,23 @@ class LoraStacker:
basename = os.path.basename(lora_path) basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0] return os.path.splitext(basename)[0]
def _get_loras_list(self, kwargs):
"""Helper to extract loras list from either old or new kwargs format"""
if 'loras' not in kwargs:
return []
loras_data = kwargs['loras']
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
return loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
return loras_data
# Unexpected format
else:
logger.warning(f"Unexpected loras format: {type(loras_data)}")
return []
def stack_loras(self, text, **kwargs): def stack_loras(self, text, **kwargs):
"""Stacks multiple LoRAs based on the kwargs input without loading them.""" """Stacks multiple LoRAs based on the kwargs input without loading them."""
stack = [] stack = []
@@ -67,24 +87,25 @@ class LoraStacker:
_, trigger_words = asyncio.run(self.get_lora_info(lora_name)) _, trigger_words = asyncio.run(self.get_lora_info(lora_name))
all_trigger_words.extend(trigger_words) all_trigger_words.extend(trigger_words)
if 'loras' in kwargs: # Process loras from kwargs with support for both old and new formats
for lora in kwargs['loras']: loras_list = self._get_loras_list(kwargs)
if not lora.get('active', False): for lora in loras_list:
continue if not lora.get('active', False):
continue
lora_name = lora['name'] lora_name = lora['name']
model_strength = float(lora['strength']) model_strength = float(lora['strength'])
clip_strength = model_strength # Using same strength for both as in the original loader clip_strength = model_strength # Using same strength for both as in the original loader
# Get lora path and trigger words # Get lora path and trigger words
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name)) lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
# Add to stack without loading # Add to stack without loading
# replace '/' with os.sep to avoid different OS path format # replace '/' with os.sep to avoid different OS path format
stack.append((lora_path.replace('/', os.sep), model_strength, clip_strength)) stack.append((lora_path.replace('/', os.sep), model_strength, clip_strength))
# Add trigger words to collection # Add trigger words to collection
all_trigger_words.extend(trigger_words) all_trigger_words.extend(trigger_words)
# use ',, ' to separate trigger words for group mode # use ',, ' to separate trigger words for group mode
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""

View File

@@ -2,6 +2,10 @@ import json
import re import re
from server import PromptServer # type: ignore from server import PromptServer # type: ignore
from .utils import FlexibleOptionalInputType, any_type from .utils import FlexibleOptionalInputType, any_type
import logging
logger = logging.getLogger(__name__)
class TriggerWordToggle: class TriggerWordToggle:
NAME = "TriggerWord Toggle (LoraManager)" NAME = "TriggerWord Toggle (LoraManager)"
@@ -24,8 +28,24 @@ class TriggerWordToggle:
RETURN_NAMES = ("filtered_trigger_words",) RETURN_NAMES = ("filtered_trigger_words",)
FUNCTION = "process_trigger_words" FUNCTION = "process_trigger_words"
def _get_toggle_data(self, kwargs, key='toggle_trigger_words'):
"""Helper to extract data from either old or new kwargs format"""
if key not in kwargs:
return None
data = kwargs[key]
# Handle new format: {'key': {'__value__': ...}}
if isinstance(data, dict) and '__value__' in data:
return data['__value__']
# Handle old format: {'key': ...}
else:
return data
def process_trigger_words(self, id, group_mode, **kwargs): def process_trigger_words(self, id, group_mode, **kwargs):
trigger_words = kwargs.get("trigger_words", "") # Handle both old and new formats for trigger_words
trigger_words_data = self._get_toggle_data(kwargs, 'trigger_words')
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
# Send trigger words to frontend # Send trigger words to frontend
PromptServer.instance.send_sync("trigger_word_update", { PromptServer.instance.send_sync("trigger_word_update", {
"id": id, "id": id,
@@ -34,11 +54,10 @@ class TriggerWordToggle:
filtered_triggers = trigger_words filtered_triggers = trigger_words
if 'toggle_trigger_words' in kwargs: # Get toggle data with support for both formats
trigger_data = self._get_toggle_data(kwargs, 'toggle_trigger_words')
if trigger_data:
try: try:
# Get trigger word toggle data
trigger_data = kwargs['toggle_trigger_words']
# Convert to list if it's a JSON string # Convert to list if it's a JSON string
if isinstance(trigger_data, str): if isinstance(trigger_data, str):
trigger_data = json.loads(trigger_data) trigger_data = json.loads(trigger_data)
@@ -72,6 +91,6 @@ class TriggerWordToggle:
filtered_triggers = "" filtered_triggers = ""
except Exception as e: except Exception as e:
print(f"Error processing trigger words: {e}") logger.error(f"Error processing trigger words: {e}")
return (filtered_triggers,) return (filtered_triggers,)