mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
feat: Add mock module creation for py/nodes directory to prevent loading modules from the nodes directory
This commit is contained in:
@@ -2,7 +2,8 @@ import logging
|
|||||||
from nodes import LoraLoader
|
from nodes import LoraLoader
|
||||||
from comfy.comfy_types import IO # type: ignore
|
from comfy.comfy_types import IO # type: ignore
|
||||||
import asyncio
|
import asyncio
|
||||||
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list, nunchaku_load_lora
|
from ..utils.utils import get_lora_info
|
||||||
|
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
from comfy.comfy_types import IO # type: ignore
|
from comfy.comfy_types import IO # type: ignore
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
from ..utils.utils import get_lora_info
|
||||||
|
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -41,30 +41,9 @@ import torch
|
|||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
|
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
|
||||||
from diffusers.loaders import FluxLoraLoaderMixin
|
from diffusers.loaders import FluxLoraLoaderMixin
|
||||||
from ..services.lora_scanner import LoraScanner
|
|
||||||
from ..config import config
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
async def get_lora_info(lora_name):
|
|
||||||
"""Get the lora path and trigger words from cache"""
|
|
||||||
scanner = await LoraScanner.get_instance()
|
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
|
|
||||||
for item in cache.raw_data:
|
|
||||||
if item.get('file_name') == lora_name:
|
|
||||||
file_path = item.get('file_path')
|
|
||||||
if file_path:
|
|
||||||
for root in config.loras_roots:
|
|
||||||
root = root.replace(os.sep, '/')
|
|
||||||
if file_path.startswith(root):
|
|
||||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
|
||||||
# Get trigger words from civitai metadata
|
|
||||||
civitai = item.get('civitai', {})
|
|
||||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
|
||||||
return relative_path, trigger_words
|
|
||||||
return lora_name, [] # Fallback if not found
|
|
||||||
|
|
||||||
def extract_lora_name(lora_path):
|
def extract_lora_name(lora_path):
|
||||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||||
# Get the basename without extension
|
# Get the basename without extension
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
from comfy.comfy_types import IO # type: ignore
|
from comfy.comfy_types import IO # type: ignore
|
||||||
import asyncio
|
import asyncio
|
||||||
import folder_paths # type: ignore
|
import folder_paths # type: ignore
|
||||||
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
from ..utils.utils import get_lora_info
|
||||||
|
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from typing import Dict
|
|||||||
from server import PromptServer # type: ignore
|
from server import PromptServer # type: ignore
|
||||||
|
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
from ..nodes.utils import get_lora_info
|
from ..utils.utils import get_lora_info
|
||||||
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..services.websocket_manager import ws_manager
|
from ..services.websocket_manager import ws_manager
|
||||||
|
|||||||
@@ -1,11 +1,7 @@
|
|||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
import shutil
|
from typing import List, Dict, Optional
|
||||||
import time
|
|
||||||
import re
|
|
||||||
from typing import List, Dict, Optional, Set
|
|
||||||
|
|
||||||
from ..utils.models import LoraMetadata
|
from ..utils.models import LoraMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
@@ -14,7 +10,6 @@ from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to Mo
|
|||||||
from .settings_manager import settings
|
from .settings_manager import settings
|
||||||
from ..utils.constants import NSFW_LEVELS
|
from ..utils.constants import NSFW_LEVELS
|
||||||
from ..utils.utils import fuzzy_match
|
from ..utils.utils import fuzzy_match
|
||||||
from .service_registry import ServiceRegistry
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -1,8 +1,29 @@
|
|||||||
from difflib import SequenceMatcher
|
from difflib import SequenceMatcher
|
||||||
import requests
|
import requests
|
||||||
import tempfile
|
import tempfile
|
||||||
import re
|
import os
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
from ..config import config
|
||||||
|
|
||||||
|
async def get_lora_info(lora_name):
|
||||||
|
"""Get the lora path and trigger words from cache"""
|
||||||
|
scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
|
||||||
|
for item in cache.raw_data:
|
||||||
|
if item.get('file_name') == lora_name:
|
||||||
|
file_path = item.get('file_path')
|
||||||
|
if file_path:
|
||||||
|
for root in config.loras_roots:
|
||||||
|
root = root.replace(os.sep, '/')
|
||||||
|
if file_path.startswith(root):
|
||||||
|
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||||
|
# Get trigger words from civitai metadata
|
||||||
|
civitai = item.get('civitai', {})
|
||||||
|
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||||
|
return relative_path, trigger_words
|
||||||
|
return lora_name, []
|
||||||
|
|
||||||
def download_twitter_image(url):
|
def download_twitter_image(url):
|
||||||
"""Download image from a URL containing twitter:image meta tag
|
"""Download image from a URL containing twitter:image meta tag
|
||||||
|
|||||||
@@ -3,6 +3,26 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
# Create mock modules for py/nodes directory - add this before any other imports
|
||||||
|
def mock_nodes_directory():
|
||||||
|
"""Create mock modules for all Python files in the py/nodes directory"""
|
||||||
|
nodes_dir = os.path.join(os.path.dirname(__file__), 'py', 'nodes')
|
||||||
|
if os.path.exists(nodes_dir):
|
||||||
|
# Create a mock module for the nodes package itself
|
||||||
|
sys.modules['py.nodes'] = type('MockNodesModule', (), {})
|
||||||
|
|
||||||
|
# Create mock modules for all Python files in the nodes directory
|
||||||
|
for file in os.listdir(nodes_dir):
|
||||||
|
if file.endswith('.py') and file != '__init__.py':
|
||||||
|
module_name = file[:-3] # Remove .py extension
|
||||||
|
full_module_name = f'py.nodes.{module_name}'
|
||||||
|
# Create empty module object
|
||||||
|
sys.modules[full_module_name] = type(f'Mock{module_name.capitalize()}Module', (), {})
|
||||||
|
print(f"Created mock module for: {full_module_name}")
|
||||||
|
|
||||||
|
# Run the mocking function before any other imports
|
||||||
|
mock_nodes_directory()
|
||||||
|
|
||||||
# Create mock folder_paths module BEFORE any other imports
|
# Create mock folder_paths module BEFORE any other imports
|
||||||
class MockFolderPaths:
|
class MockFolderPaths:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user