diff --git a/py/nodes/lora_loader.py b/py/nodes/lora_loader.py index aae8d882..d0bde635 100644 --- a/py/nodes/lora_loader.py +++ b/py/nodes/lora_loader.py @@ -1,3 +1,4 @@ +import logging from nodes import LoraLoader from comfy.comfy_types import IO # type: ignore from ..services.lora_scanner import LoraScanner @@ -6,6 +7,8 @@ import asyncio import os from .utils import FlexibleOptionalInputType, any_type +logger = logging.getLogger(__name__) + class LoraManagerLoader: NAME = "Lora Loader (LoraManager)" CATEGORY = "Lora Manager/loaders" @@ -55,6 +58,23 @@ class LoraManagerLoader: basename = os.path.basename(lora_path) return os.path.splitext(basename)[0] + def _get_loras_list(self, kwargs): + """Helper to extract loras list from either old or new kwargs format""" + if 'loras' not in kwargs: + return [] + + loras_data = kwargs['loras'] + # Handle new format: {'loras': {'__value__': [...]}} + if isinstance(loras_data, dict) and '__value__' in loras_data: + return loras_data['__value__'] + # Handle old format: {'loras': [...]} + elif isinstance(loras_data, list): + return loras_data + # Unexpected format + else: + logger.warning(f"Unexpected loras format: {type(loras_data)}") + return [] + def load_loras(self, model, clip, text, **kwargs): """Loads multiple LoRAs based on the kwargs input and lora_stack.""" loaded_loras = [] @@ -74,24 +94,24 @@ class LoraManagerLoader: all_trigger_words.extend(trigger_words) loaded_loras.append(f"{lora_name}: {model_strength}") - # Then process loras from kwargs - if 'loras' in kwargs: - for lora in kwargs['loras']: - if not lora.get('active', False): - continue - - lora_name = lora['name'] - strength = float(lora['strength']) + # Then process loras from kwargs with support for both old and new formats + loras_list = self._get_loras_list(kwargs) + for lora in loras_list: + if not lora.get('active', False): + continue - # Get lora path and trigger words - lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name)) - - # Apply the LoRA using the resolved path - model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength) - loaded_loras.append(f"{lora_name}: {strength}") - - # Add trigger words to collection - all_trigger_words.extend(trigger_words) + lora_name = lora['name'] + strength = float(lora['strength']) + + # Get lora path and trigger words + lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name)) + + # Apply the LoRA using the resolved path + model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength) + loaded_loras.append(f"{lora_name}: {strength}") + + # Add trigger words to collection + all_trigger_words.extend(trigger_words) # use ',, ' to separate trigger words for group mode trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" diff --git a/py/nodes/lora_stacker.py b/py/nodes/lora_stacker.py index 7909cfbf..b264bd27 100644 --- a/py/nodes/lora_stacker.py +++ b/py/nodes/lora_stacker.py @@ -4,6 +4,9 @@ from ..config import config import asyncio import os from .utils import FlexibleOptionalInputType, any_type +import logging + +logger = logging.getLogger(__name__) class LoraStacker: NAME = "Lora Stacker (LoraManager)" @@ -52,6 +55,23 @@ class LoraStacker: basename = os.path.basename(lora_path) return os.path.splitext(basename)[0] + def _get_loras_list(self, kwargs): + """Helper to extract loras list from either old or new kwargs format""" + if 'loras' not in kwargs: + return [] + + loras_data = kwargs['loras'] + # Handle new format: {'loras': {'__value__': [...]}} + if isinstance(loras_data, dict) and '__value__' in loras_data: + return loras_data['__value__'] + # Handle old format: {'loras': [...]} + elif isinstance(loras_data, list): + return loras_data + # Unexpected format + else: + logger.warning(f"Unexpected loras format: {type(loras_data)}") + return [] + def stack_loras(self, text, **kwargs): """Stacks multiple LoRAs based on the kwargs input without loading them.""" stack = [] @@ -67,24 +87,25 @@ class LoraStacker: _, trigger_words = asyncio.run(self.get_lora_info(lora_name)) all_trigger_words.extend(trigger_words) - if 'loras' in kwargs: - for lora in kwargs['loras']: - if not lora.get('active', False): - continue - - lora_name = lora['name'] - model_strength = float(lora['strength']) - clip_strength = model_strength # Using same strength for both as in the original loader + # Process loras from kwargs with support for both old and new formats + loras_list = self._get_loras_list(kwargs) + for lora in loras_list: + if not lora.get('active', False): + continue - # Get lora path and trigger words - lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name)) - - # Add to stack without loading - # replace '/' with os.sep to avoid different OS path format - stack.append((lora_path.replace('/', os.sep), model_strength, clip_strength)) - - # Add trigger words to collection - all_trigger_words.extend(trigger_words) + lora_name = lora['name'] + model_strength = float(lora['strength']) + clip_strength = model_strength # Using same strength for both as in the original loader + + # Get lora path and trigger words + lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name)) + + # Add to stack without loading + # replace '/' with os.sep to avoid different OS path format + stack.append((lora_path.replace('/', os.sep), model_strength, clip_strength)) + + # Add trigger words to collection + all_trigger_words.extend(trigger_words) # use ',, ' to separate trigger words for group mode trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" diff --git a/py/nodes/trigger_word_toggle.py b/py/nodes/trigger_word_toggle.py index bdd771dd..16b72f55 100644 --- a/py/nodes/trigger_word_toggle.py +++ b/py/nodes/trigger_word_toggle.py @@ -2,6 +2,10 @@ import json import re from server import PromptServer # type: ignore from .utils import FlexibleOptionalInputType, any_type +import logging + +logger = logging.getLogger(__name__) + class TriggerWordToggle: NAME = "TriggerWord Toggle (LoraManager)" @@ -24,8 +28,24 @@ class TriggerWordToggle: RETURN_NAMES = ("filtered_trigger_words",) FUNCTION = "process_trigger_words" + def _get_toggle_data(self, kwargs, key='toggle_trigger_words'): + """Helper to extract data from either old or new kwargs format""" + if key not in kwargs: + return None + + data = kwargs[key] + # Handle new format: {'key': {'__value__': ...}} + if isinstance(data, dict) and '__value__' in data: + return data['__value__'] + # Handle old format: {'key': ...} + else: + return data + def process_trigger_words(self, id, group_mode, **kwargs): - trigger_words = kwargs.get("trigger_words", "") + # Handle both old and new formats for trigger_words + trigger_words_data = self._get_toggle_data(kwargs, 'trigger_words') + trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else "" + # Send trigger words to frontend PromptServer.instance.send_sync("trigger_word_update", { "id": id, @@ -34,11 +54,10 @@ class TriggerWordToggle: filtered_triggers = trigger_words - if 'toggle_trigger_words' in kwargs: + # Get toggle data with support for both formats + trigger_data = self._get_toggle_data(kwargs, 'toggle_trigger_words') + if trigger_data: try: - # Get trigger word toggle data - trigger_data = kwargs['toggle_trigger_words'] - # Convert to list if it's a JSON string if isinstance(trigger_data, str): trigger_data = json.loads(trigger_data) @@ -72,6 +91,6 @@ class TriggerWordToggle: filtered_triggers = "" except Exception as e: - print(f"Error processing trigger words: {e}") + logger.error(f"Error processing trigger words: {e}") return (filtered_triggers,) \ No newline at end of file