diff --git a/py/nodes/lora_loader.py b/py/nodes/lora_loader.py index 8d45e19c..f49eb34b 100644 --- a/py/nodes/lora_loader.py +++ b/py/nodes/lora_loader.py @@ -2,7 +2,8 @@ import logging from nodes import LoraLoader from comfy.comfy_types import IO # type: ignore import asyncio -from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list, nunchaku_load_lora +from ..utils.utils import get_lora_info +from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora logger = logging.getLogger(__name__) diff --git a/py/nodes/lora_stacker.py b/py/nodes/lora_stacker.py index 50a8c510..34d8adc0 100644 --- a/py/nodes/lora_stacker.py +++ b/py/nodes/lora_stacker.py @@ -1,7 +1,9 @@ from comfy.comfy_types import IO # type: ignore import asyncio import os -from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list +from ..utils.utils import get_lora_info +from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list + import logging logger = logging.getLogger(__name__) diff --git a/py/nodes/utils.py b/py/nodes/utils.py index 33a3c972..604b8fe7 100644 --- a/py/nodes/utils.py +++ b/py/nodes/utils.py @@ -41,30 +41,9 @@ import torch import safetensors.torch from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft from diffusers.loaders import FluxLoraLoaderMixin -from ..services.lora_scanner import LoraScanner -from ..config import config logger = logging.getLogger(__name__) -async def get_lora_info(lora_name): - """Get the lora path and trigger words from cache""" - scanner = await LoraScanner.get_instance() - cache = await scanner.get_cached_data() - - for item in cache.raw_data: - if item.get('file_name') == lora_name: - file_path = item.get('file_path') - if file_path: - for root in config.loras_roots: - root = root.replace(os.sep, '/') - if file_path.startswith(root): - relative_path = os.path.relpath(file_path, root).replace(os.sep, '/') - # Get trigger words from civitai metadata - civitai = item.get('civitai', {}) - trigger_words = civitai.get('trainedWords', []) if civitai else [] - return relative_path, trigger_words - return lora_name, [] # Fallback if not found - def extract_lora_name(lora_path): """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')""" # Get the basename without extension diff --git a/py/nodes/wanvideo_lora_select.py b/py/nodes/wanvideo_lora_select.py index 056ed2da..44d2d25c 100644 --- a/py/nodes/wanvideo_lora_select.py +++ b/py/nodes/wanvideo_lora_select.py @@ -1,7 +1,8 @@ from comfy.comfy_types import IO # type: ignore import asyncio import folder_paths # type: ignore -from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list +from ..utils.utils import get_lora_info +from .utils import FlexibleOptionalInputType, any_type, get_loras_list import logging logger = logging.getLogger(__name__) diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index 38734511..26a15ea1 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -6,7 +6,7 @@ from typing import Dict from server import PromptServer # type: ignore from ..utils.routes_common import ModelRouteUtils -from ..nodes.utils import get_lora_info +from ..utils.utils import get_lora_info from ..config import config from ..services.websocket_manager import ws_manager diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index 2e18674d..59ff58b5 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -1,11 +1,7 @@ -import json import os import logging import asyncio -import shutil -import time -import re -from typing import List, Dict, Optional, Set +from typing import List, Dict, Optional from ..utils.models import LoraMetadata from ..config import config @@ -14,7 +10,6 @@ from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to Mo from .settings_manager import settings from ..utils.constants import NSFW_LEVELS from ..utils.utils import fuzzy_match -from .service_registry import ServiceRegistry import sys logger = logging.getLogger(__name__) diff --git a/py/utils/utils.py b/py/utils/utils.py index 994d8c54..710065f2 100644 --- a/py/utils/utils.py +++ b/py/utils/utils.py @@ -1,8 +1,29 @@ from difflib import SequenceMatcher import requests import tempfile -import re +import os from bs4 import BeautifulSoup +from ..services.service_registry import ServiceRegistry +from ..config import config + +async def get_lora_info(lora_name): + """Get the lora path and trigger words from cache""" + scanner = await ServiceRegistry.get_lora_scanner() + cache = await scanner.get_cached_data() + + for item in cache.raw_data: + if item.get('file_name') == lora_name: + file_path = item.get('file_path') + if file_path: + for root in config.loras_roots: + root = root.replace(os.sep, '/') + if file_path.startswith(root): + relative_path = os.path.relpath(file_path, root).replace(os.sep, '/') + # Get trigger words from civitai metadata + civitai = item.get('civitai', {}) + trigger_words = civitai.get('trainedWords', []) if civitai else [] + return relative_path, trigger_words + return lora_name, [] def download_twitter_image(url): """Download image from a URL containing twitter:image meta tag diff --git a/standalone.py b/standalone.py index 73d98db2..b27128ba 100644 --- a/standalone.py +++ b/standalone.py @@ -3,6 +3,26 @@ import os import sys import json +# Create mock modules for py/nodes directory - add this before any other imports +def mock_nodes_directory(): + """Create mock modules for all Python files in the py/nodes directory""" + nodes_dir = os.path.join(os.path.dirname(__file__), 'py', 'nodes') + if os.path.exists(nodes_dir): + # Create a mock module for the nodes package itself + sys.modules['py.nodes'] = type('MockNodesModule', (), {}) + + # Create mock modules for all Python files in the nodes directory + for file in os.listdir(nodes_dir): + if file.endswith('.py') and file != '__init__.py': + module_name = file[:-3] # Remove .py extension + full_module_name = f'py.nodes.{module_name}' + # Create empty module object + sys.modules[full_module_name] = type(f'Mock{module_name.capitalize()}Module', (), {}) + print(f"Created mock module for: {full_module_name}") + +# Run the mocking function before any other imports +mock_nodes_directory() + # Create mock folder_paths module BEFORE any other imports class MockFolderPaths: @staticmethod