diff --git a/py/nodes/lora_loader.py b/py/nodes/lora_loader.py index 766a17ec..ab5395c8 100644 --- a/py/nodes/lora_loader.py +++ b/py/nodes/lora_loader.py @@ -5,7 +5,7 @@ from ..services.lora_scanner import LoraScanner from ..config import config import asyncio import os -from .utils import FlexibleOptionalInputType, any_type +from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list logger = logging.getLogger(__name__) @@ -32,48 +32,6 @@ class LoraManagerLoader: RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING) RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras") FUNCTION = "load_loras" - - async def get_lora_info(self, lora_name): - """Get the lora path and trigger words from cache""" - scanner = await LoraScanner.get_instance() - cache = await scanner.get_cached_data() - - for item in cache.raw_data: - if item.get('file_name') == lora_name: - file_path = item.get('file_path') - if file_path: - for root in config.loras_roots: - root = root.replace(os.sep, '/') - if file_path.startswith(root): - relative_path = os.path.relpath(file_path, root).replace(os.sep, '/') - # Get trigger words from civitai metadata - civitai = item.get('civitai', {}) - trigger_words = civitai.get('trainedWords', []) if civitai else [] - return relative_path, trigger_words - return lora_name, [] # Fallback if not found - - def extract_lora_name(self, lora_path): - """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')""" - # Get the basename without extension - basename = os.path.basename(lora_path) - return os.path.splitext(basename)[0] - - def _get_loras_list(self, kwargs): - """Helper to extract loras list from either old or new kwargs format""" - if 'loras' not in kwargs: - return [] - - loras_data = kwargs['loras'] - # Handle new format: {'loras': {'__value__': [...]}} - if isinstance(loras_data, dict) and '__value__' in loras_data: - return loras_data['__value__'] - # Handle old format: {'loras': [...]} - elif isinstance(loras_data, list): - return loras_data - # Unexpected format - else: - logger.warning(f"Unexpected loras format: {type(loras_data)}") - return [] def load_loras(self, model, text, **kwargs): """Loads multiple LoRAs based on the kwargs input and lora_stack.""" @@ -89,14 +47,14 @@ class LoraManagerLoader: model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength) # Extract lora name for trigger words lookup - lora_name = self.extract_lora_name(lora_path) - _, trigger_words = asyncio.run(self.get_lora_info(lora_name)) + lora_name = extract_lora_name(lora_path) + _, trigger_words = asyncio.run(get_lora_info(lora_name)) all_trigger_words.extend(trigger_words) loaded_loras.append(f"{lora_name}: {model_strength}") # Then process loras from kwargs with support for both old and new formats - loras_list = self._get_loras_list(kwargs) + loras_list = get_loras_list(kwargs) for lora in loras_list: if not lora.get('active', False): continue @@ -105,7 +63,7 @@ class LoraManagerLoader: strength = float(lora['strength']) # Get lora path and trigger words - lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name)) + lora_path, trigger_words = asyncio.run(get_lora_info(lora_name)) # Apply the LoRA using the resolved path model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength) diff --git a/py/nodes/lora_stacker.py b/py/nodes/lora_stacker.py index ed6662cb..7f0a015b 100644 --- a/py/nodes/lora_stacker.py +++ b/py/nodes/lora_stacker.py @@ -3,7 +3,7 @@ from ..services.lora_scanner import LoraScanner from ..config import config import asyncio import os -from .utils import FlexibleOptionalInputType, any_type +from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list import logging logger = logging.getLogger(__name__) @@ -29,48 +29,6 @@ class LoraStacker: RETURN_TYPES = ("LORA_STACK", IO.STRING, IO.STRING) RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras") FUNCTION = "stack_loras" - - async def get_lora_info(self, lora_name): - """Get the lora path and trigger words from cache""" - scanner = await LoraScanner.get_instance() - cache = await scanner.get_cached_data() - - for item in cache.raw_data: - if item.get('file_name') == lora_name: - file_path = item.get('file_path') - if file_path: - for root in config.loras_roots: - root = root.replace(os.sep, '/') - if file_path.startswith(root): - relative_path = os.path.relpath(file_path, root).replace(os.sep, '/') - # Get trigger words from civitai metadata - civitai = item.get('civitai', {}) - trigger_words = civitai.get('trainedWords', []) if civitai else [] - return relative_path, trigger_words - return lora_name, [] # Fallback if not found - - def extract_lora_name(self, lora_path): - """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')""" - # Get the basename without extension - basename = os.path.basename(lora_path) - return os.path.splitext(basename)[0] - - def _get_loras_list(self, kwargs): - """Helper to extract loras list from either old or new kwargs format""" - if 'loras' not in kwargs: - return [] - - loras_data = kwargs['loras'] - # Handle new format: {'loras': {'__value__': [...]}} - if isinstance(loras_data, dict) and '__value__' in loras_data: - return loras_data['__value__'] - # Handle old format: {'loras': [...]} - elif isinstance(loras_data, list): - return loras_data - # Unexpected format - else: - logger.warning(f"Unexpected loras format: {type(loras_data)}") - return [] def stack_loras(self, text, **kwargs): """Stacks multiple LoRAs based on the kwargs input without loading them.""" @@ -84,12 +42,12 @@ class LoraStacker: stack.extend(lora_stack) # Get trigger words from existing stack entries for lora_path, _, _ in lora_stack: - lora_name = self.extract_lora_name(lora_path) - _, trigger_words = asyncio.run(self.get_lora_info(lora_name)) + lora_name = extract_lora_name(lora_path) + _, trigger_words = asyncio.run(get_lora_info(lora_name)) all_trigger_words.extend(trigger_words) # Process loras from kwargs with support for both old and new formats - loras_list = self._get_loras_list(kwargs) + loras_list = get_loras_list(kwargs) for lora in loras_list: if not lora.get('active', False): continue @@ -99,7 +57,7 @@ class LoraStacker: clip_strength = model_strength # Using same strength for both as in the original loader # Get lora path and trigger words - lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name)) + lora_path, trigger_words = asyncio.run(get_lora_info(lora_name)) # Add to stack without loading # replace '/' with os.sep to avoid different OS path format diff --git a/py/nodes/utils.py b/py/nodes/utils.py index 89b96c97..1feb1a77 100644 --- a/py/nodes/utils.py +++ b/py/nodes/utils.py @@ -30,4 +30,55 @@ class FlexibleOptionalInputType(dict): return True -any_type = AnyType("*") \ No newline at end of file +any_type = AnyType("*") + +# Common methods extracted from lora_loader.py and lora_stacker.py +import os +import logging +import asyncio +from ..services.lora_scanner import LoraScanner +from ..config import config + +logger = logging.getLogger(__name__) + +async def get_lora_info(lora_name): + """Get the lora path and trigger words from cache""" + scanner = await LoraScanner.get_instance() + cache = await scanner.get_cached_data() + + for item in cache.raw_data: + if item.get('file_name') == lora_name: + file_path = item.get('file_path') + if file_path: + for root in config.loras_roots: + root = root.replace(os.sep, '/') + if file_path.startswith(root): + relative_path = os.path.relpath(file_path, root).replace(os.sep, '/') + # Get trigger words from civitai metadata + civitai = item.get('civitai', {}) + trigger_words = civitai.get('trainedWords', []) if civitai else [] + return relative_path, trigger_words + return lora_name, [] # Fallback if not found + +def extract_lora_name(lora_path): + """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')""" + # Get the basename without extension + basename = os.path.basename(lora_path) + return os.path.splitext(basename)[0] + +def get_loras_list(kwargs): + """Helper to extract loras list from either old or new kwargs format""" + if 'loras' not in kwargs: + return [] + + loras_data = kwargs['loras'] + # Handle new format: {'loras': {'__value__': [...]}} + if isinstance(loras_data, dict) and '__value__' in loras_data: + return loras_data['__value__'] + # Handle old format: {'loras': [...]} + elif isinstance(loras_data, list): + return loras_data + # Unexpected format + else: + logger.warning(f"Unexpected loras format: {type(loras_data)}") + return [] \ No newline at end of file