feat: Add mock module creation for py/nodes directory to prevent loading modules from the nodes directory

This commit is contained in:
Will Miao
2025-06-30 20:19:37 +08:00
parent afe23ad6b7
commit cad5fb3fba
8 changed files with 51 additions and 32 deletions

View File

@@ -2,7 +2,8 @@ import logging
from nodes import LoraLoader
from comfy.comfy_types import IO # type: ignore
import asyncio
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list, nunchaku_load_lora
from ..utils.utils import get_lora_info
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
logger = logging.getLogger(__name__)

View File

@@ -1,7 +1,9 @@
from comfy.comfy_types import IO # type: ignore
import asyncio
import os
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
from ..utils.utils import get_lora_info
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list
import logging
logger = logging.getLogger(__name__)

View File

@@ -41,30 +41,9 @@ import torch
import safetensors.torch
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
from diffusers.loaders import FluxLoraLoaderMixin
from ..services.lora_scanner import LoraScanner
from ..config import config
logger = logging.getLogger(__name__)
async def get_lora_info(lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def extract_lora_name(lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension

View File

@@ -1,7 +1,8 @@
from comfy.comfy_types import IO # type: ignore
import asyncio
import folder_paths # type: ignore
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
from ..utils.utils import get_lora_info
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
import logging
logger = logging.getLogger(__name__)

View File

@@ -6,7 +6,7 @@ from typing import Dict
from server import PromptServer # type: ignore
from ..utils.routes_common import ModelRouteUtils
from ..nodes.utils import get_lora_info
from ..utils.utils import get_lora_info
from ..config import config
from ..services.websocket_manager import ws_manager

View File

@@ -1,11 +1,7 @@
import json
import os
import logging
import asyncio
import shutil
import time
import re
from typing import List, Dict, Optional, Set
from typing import List, Dict, Optional
from ..utils.models import LoraMetadata
from ..config import config
@@ -14,7 +10,6 @@ from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to Mo
from .settings_manager import settings
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match
from .service_registry import ServiceRegistry
import sys
logger = logging.getLogger(__name__)

View File

@@ -1,8 +1,29 @@
from difflib import SequenceMatcher
import requests
import tempfile
import re
import os
from bs4 import BeautifulSoup
from ..services.service_registry import ServiceRegistry
from ..config import config
async def get_lora_info(lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await ServiceRegistry.get_lora_scanner()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, []
def download_twitter_image(url):
"""Download image from a URL containing twitter:image meta tag

View File

@@ -3,6 +3,26 @@ import os
import sys
import json
# Create mock modules for py/nodes directory - add this before any other imports
def mock_nodes_directory():
"""Create mock modules for all Python files in the py/nodes directory"""
nodes_dir = os.path.join(os.path.dirname(__file__), 'py', 'nodes')
if os.path.exists(nodes_dir):
# Create a mock module for the nodes package itself
sys.modules['py.nodes'] = type('MockNodesModule', (), {})
# Create mock modules for all Python files in the nodes directory
for file in os.listdir(nodes_dir):
if file.endswith('.py') and file != '__init__.py':
module_name = file[:-3] # Remove .py extension
full_module_name = f'py.nodes.{module_name}'
# Create empty module object
sys.modules[full_module_name] = type(f'Mock{module_name.capitalize()}Module', (), {})
print(f"Created mock module for: {full_module_name}")
# Run the mocking function before any other imports
mock_nodes_directory()
# Create mock folder_paths module BEFORE any other imports
class MockFolderPaths:
@staticmethod