From bbf7295c3240057f0dd8c4b432a217b77ef7369c Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 19 Apr 2025 21:42:01 +0300 Subject: [PATCH 1/3] Prevent duplicates of root folders when using symlinks --- py/config.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/py/config.py b/py/config.py index a081f472..83926aa0 100644 --- a/py/config.py +++ b/py/config.py @@ -99,21 +99,29 @@ class Config: def _init_lora_paths(self) -> List[str]: """Initialize and validate LoRA paths from ComfyUI settings""" - paths = sorted(set(path.replace(os.sep, "/") - for path in folder_paths.get_folder_paths("loras") - if os.path.exists(path)), key=lambda p: p.lower()) - print("Found LoRA roots:", "\n - " + "\n - ".join(paths)) + raw_paths = folder_paths.get_folder_paths("loras") - if not paths: + # Normalize and resolve symlinks, store mapping from resolved -> original + path_map = {} + for path in raw_paths: + if os.path.exists(path): + real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') + path_map[real_path] = path_map.get(real_path, path) # preserve first seen + + # Now sort and use only the deduplicated real paths + unique_paths = sorted(path_map.values(), key=lambda p: p.lower()) + print("Found LoRA roots:", "\n - " + "\n - ".join(unique_paths)) + + if not unique_paths: raise ValueError("No valid loras folders found in ComfyUI configuration") - # 初始化路径映射 - for path in paths: - real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') - if real_path != path: - self.add_path_mapping(path, real_path) + for original_path in unique_paths: + real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/') + if real_path != original_path: + self.add_path_mapping(original_path, real_path) - return paths + return unique_paths + def get_preview_static_url(self, preview_path: str) -> str: """Convert local preview path to static URL""" From 9bb9e7b64dd7b9c9763abcfa8db08332e33a48e2 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sun, 20 Apr 2025 21:35:36 +0800 Subject: [PATCH 2/3] refactor: Extract common methods for Lora handling into utils.py and update references in lora_loader.py and lora_stacker.py --- py/nodes/lora_loader.py | 52 ++++----------------------------------- py/nodes/lora_stacker.py | 52 ++++----------------------------------- py/nodes/utils.py | 53 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 62 insertions(+), 95 deletions(-) diff --git a/py/nodes/lora_loader.py b/py/nodes/lora_loader.py index 766a17ec..ab5395c8 100644 --- a/py/nodes/lora_loader.py +++ b/py/nodes/lora_loader.py @@ -5,7 +5,7 @@ from ..services.lora_scanner import LoraScanner from ..config import config import asyncio import os -from .utils import FlexibleOptionalInputType, any_type +from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list logger = logging.getLogger(__name__) @@ -32,48 +32,6 @@ class LoraManagerLoader: RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING) RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras") FUNCTION = "load_loras" - - async def get_lora_info(self, lora_name): - """Get the lora path and trigger words from cache""" - scanner = await LoraScanner.get_instance() - cache = await scanner.get_cached_data() - - for item in cache.raw_data: - if item.get('file_name') == lora_name: - file_path = item.get('file_path') - if file_path: - for root in config.loras_roots: - root = root.replace(os.sep, '/') - if file_path.startswith(root): - relative_path = os.path.relpath(file_path, root).replace(os.sep, '/') - # Get trigger words from civitai metadata - civitai = item.get('civitai', {}) - trigger_words = civitai.get('trainedWords', []) if civitai else [] - return relative_path, trigger_words - return lora_name, [] # Fallback if not found - - def extract_lora_name(self, lora_path): - """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')""" - # Get the basename without extension - basename = os.path.basename(lora_path) - return os.path.splitext(basename)[0] - - def _get_loras_list(self, kwargs): - """Helper to extract loras list from either old or new kwargs format""" - if 'loras' not in kwargs: - return [] - - loras_data = kwargs['loras'] - # Handle new format: {'loras': {'__value__': [...]}} - if isinstance(loras_data, dict) and '__value__' in loras_data: - return loras_data['__value__'] - # Handle old format: {'loras': [...]} - elif isinstance(loras_data, list): - return loras_data - # Unexpected format - else: - logger.warning(f"Unexpected loras format: {type(loras_data)}") - return [] def load_loras(self, model, text, **kwargs): """Loads multiple LoRAs based on the kwargs input and lora_stack.""" @@ -89,14 +47,14 @@ class LoraManagerLoader: model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength) # Extract lora name for trigger words lookup - lora_name = self.extract_lora_name(lora_path) - _, trigger_words = asyncio.run(self.get_lora_info(lora_name)) + lora_name = extract_lora_name(lora_path) + _, trigger_words = asyncio.run(get_lora_info(lora_name)) all_trigger_words.extend(trigger_words) loaded_loras.append(f"{lora_name}: {model_strength}") # Then process loras from kwargs with support for both old and new formats - loras_list = self._get_loras_list(kwargs) + loras_list = get_loras_list(kwargs) for lora in loras_list: if not lora.get('active', False): continue @@ -105,7 +63,7 @@ class LoraManagerLoader: strength = float(lora['strength']) # Get lora path and trigger words - lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name)) + lora_path, trigger_words = asyncio.run(get_lora_info(lora_name)) # Apply the LoRA using the resolved path model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength) diff --git a/py/nodes/lora_stacker.py b/py/nodes/lora_stacker.py index ed6662cb..7f0a015b 100644 --- a/py/nodes/lora_stacker.py +++ b/py/nodes/lora_stacker.py @@ -3,7 +3,7 @@ from ..services.lora_scanner import LoraScanner from ..config import config import asyncio import os -from .utils import FlexibleOptionalInputType, any_type +from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list import logging logger = logging.getLogger(__name__) @@ -29,48 +29,6 @@ class LoraStacker: RETURN_TYPES = ("LORA_STACK", IO.STRING, IO.STRING) RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras") FUNCTION = "stack_loras" - - async def get_lora_info(self, lora_name): - """Get the lora path and trigger words from cache""" - scanner = await LoraScanner.get_instance() - cache = await scanner.get_cached_data() - - for item in cache.raw_data: - if item.get('file_name') == lora_name: - file_path = item.get('file_path') - if file_path: - for root in config.loras_roots: - root = root.replace(os.sep, '/') - if file_path.startswith(root): - relative_path = os.path.relpath(file_path, root).replace(os.sep, '/') - # Get trigger words from civitai metadata - civitai = item.get('civitai', {}) - trigger_words = civitai.get('trainedWords', []) if civitai else [] - return relative_path, trigger_words - return lora_name, [] # Fallback if not found - - def extract_lora_name(self, lora_path): - """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')""" - # Get the basename without extension - basename = os.path.basename(lora_path) - return os.path.splitext(basename)[0] - - def _get_loras_list(self, kwargs): - """Helper to extract loras list from either old or new kwargs format""" - if 'loras' not in kwargs: - return [] - - loras_data = kwargs['loras'] - # Handle new format: {'loras': {'__value__': [...]}} - if isinstance(loras_data, dict) and '__value__' in loras_data: - return loras_data['__value__'] - # Handle old format: {'loras': [...]} - elif isinstance(loras_data, list): - return loras_data - # Unexpected format - else: - logger.warning(f"Unexpected loras format: {type(loras_data)}") - return [] def stack_loras(self, text, **kwargs): """Stacks multiple LoRAs based on the kwargs input without loading them.""" @@ -84,12 +42,12 @@ class LoraStacker: stack.extend(lora_stack) # Get trigger words from existing stack entries for lora_path, _, _ in lora_stack: - lora_name = self.extract_lora_name(lora_path) - _, trigger_words = asyncio.run(self.get_lora_info(lora_name)) + lora_name = extract_lora_name(lora_path) + _, trigger_words = asyncio.run(get_lora_info(lora_name)) all_trigger_words.extend(trigger_words) # Process loras from kwargs with support for both old and new formats - loras_list = self._get_loras_list(kwargs) + loras_list = get_loras_list(kwargs) for lora in loras_list: if not lora.get('active', False): continue @@ -99,7 +57,7 @@ class LoraStacker: clip_strength = model_strength # Using same strength for both as in the original loader # Get lora path and trigger words - lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name)) + lora_path, trigger_words = asyncio.run(get_lora_info(lora_name)) # Add to stack without loading # replace '/' with os.sep to avoid different OS path format diff --git a/py/nodes/utils.py b/py/nodes/utils.py index 89b96c97..1feb1a77 100644 --- a/py/nodes/utils.py +++ b/py/nodes/utils.py @@ -30,4 +30,55 @@ class FlexibleOptionalInputType(dict): return True -any_type = AnyType("*") \ No newline at end of file +any_type = AnyType("*") + +# Common methods extracted from lora_loader.py and lora_stacker.py +import os +import logging +import asyncio +from ..services.lora_scanner import LoraScanner +from ..config import config + +logger = logging.getLogger(__name__) + +async def get_lora_info(lora_name): + """Get the lora path and trigger words from cache""" + scanner = await LoraScanner.get_instance() + cache = await scanner.get_cached_data() + + for item in cache.raw_data: + if item.get('file_name') == lora_name: + file_path = item.get('file_path') + if file_path: + for root in config.loras_roots: + root = root.replace(os.sep, '/') + if file_path.startswith(root): + relative_path = os.path.relpath(file_path, root).replace(os.sep, '/') + # Get trigger words from civitai metadata + civitai = item.get('civitai', {}) + trigger_words = civitai.get('trainedWords', []) if civitai else [] + return relative_path, trigger_words + return lora_name, [] # Fallback if not found + +def extract_lora_name(lora_path): + """Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')""" + # Get the basename without extension + basename = os.path.basename(lora_path) + return os.path.splitext(basename)[0] + +def get_loras_list(kwargs): + """Helper to extract loras list from either old or new kwargs format""" + if 'loras' not in kwargs: + return [] + + loras_data = kwargs['loras'] + # Handle new format: {'loras': {'__value__': [...]}} + if isinstance(loras_data, dict) and '__value__' in loras_data: + return loras_data['__value__'] + # Handle old format: {'loras': [...]} + elif isinstance(loras_data, list): + return loras_data + # Unexpected format + else: + logger.warning(f"Unexpected loras format: {type(loras_data)}") + return [] \ No newline at end of file From e70fd73bdde07dda485a3821741592eb55f95fbf Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sun, 20 Apr 2025 22:27:53 +0800 Subject: [PATCH 3/3] feat: Implement trigger words API and update frontend integration for LoraManager. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/43 --- py/nodes/trigger_word_toggle.py | 8 ++--- py/routes/api_routes.py | 36 +++++++++++++++++++++ py/server_routes.py | 26 +++++++++++++++ web/comfyui/lora_loader.js | 57 +++++++++++++++++++++++++++++++++ 4 files changed, 123 insertions(+), 4 deletions(-) create mode 100644 py/server_routes.py diff --git a/py/nodes/trigger_word_toggle.py b/py/nodes/trigger_word_toggle.py index 16b72f55..61bcd123 100644 --- a/py/nodes/trigger_word_toggle.py +++ b/py/nodes/trigger_word_toggle.py @@ -47,10 +47,10 @@ class TriggerWordToggle: trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else "" # Send trigger words to frontend - PromptServer.instance.send_sync("trigger_word_update", { - "id": id, - "message": trigger_words - }) + # PromptServer.instance.send_sync("trigger_word_update", { + # "id": id, + # "message": trigger_words + # }) filtered_triggers = trigger_words diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index 39d0d429..06aff4a5 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -3,8 +3,10 @@ import json import logging from aiohttp import web from typing import Dict +from server import PromptServer # type: ignore from ..utils.routes_common import ModelRouteUtils +from ..nodes.utils import get_lora_info from ..config import config from ..services.websocket_manager import ws_manager @@ -64,6 +66,9 @@ class ApiRoutes: app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files + + # Add the new trigger words route + app.router.add_post('/loramanager/get_trigger_words', routes.get_trigger_words) # Add update check routes UpdateRoutes.setup_routes(app) @@ -1021,4 +1026,35 @@ class ApiRoutes: return web.json_response({ 'success': False, 'error': str(e) + }, status=500) + + async def get_trigger_words(self, request: web.Request) -> web.Response: + """Get trigger words for specified LoRA models""" + try: + json_data = await request.json() + lora_names = json_data.get("lora_names", []) + node_ids = json_data.get("node_ids", []) + + all_trigger_words = [] + for lora_name in lora_names: + _, trigger_words = await get_lora_info(lora_name) + all_trigger_words.extend(trigger_words) + + # Format the trigger words + trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" + + # Send update to all connected trigger word toggle nodes + for node_id in node_ids: + PromptServer.instance.send_sync("trigger_word_update", { + "id": node_id, + "message": trigger_words_text + }) + + return web.json_response({"success": True}) + + except Exception as e: + logger.error(f"Error getting trigger words: {e}") + return web.json_response({ + "success": False, + "error": str(e) }, status=500) \ No newline at end of file diff --git a/py/server_routes.py b/py/server_routes.py new file mode 100644 index 00000000..68ee9749 --- /dev/null +++ b/py/server_routes.py @@ -0,0 +1,26 @@ +from aiohttp import web +from server import PromptServer +from .nodes.utils import get_lora_info + +@PromptServer.instance.routes.post("/loramanager/get_trigger_words") +async def get_trigger_words(request): + json_data = await request.json() + lora_names = json_data.get("lora_names", []) + node_ids = json_data.get("node_ids", []) + + all_trigger_words = [] + for lora_name in lora_names: + _, trigger_words = await get_lora_info(lora_name) + all_trigger_words.extend(trigger_words) + + # Format the trigger words + trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" + + # Send update to all connected trigger word toggle nodes + for node_id in node_ids: + PromptServer.instance.send_sync("trigger_word_update", { + "id": node_id, + "message": trigger_words_text + }) + + return web.json_response({"success": True}) diff --git a/web/comfyui/lora_loader.js b/web/comfyui/lora_loader.js index 62f01f1c..f8784109 100644 --- a/web/comfyui/lora_loader.js +++ b/web/comfyui/lora_loader.js @@ -9,6 +9,57 @@ async function getLorasWidgetModule() { return await dynamicImportByVersion("./loras_widget.js", "./legacy_loras_widget.js"); } +// Function to get connected trigger toggle nodes +function getConnectedTriggerToggleNodes(node) { + const connectedNodes = []; + + // Check if node has outputs + if (node.outputs && node.outputs.length > 0) { + // For each output slot + for (const output of node.outputs) { + // Check if this output has any links + if (output.links && output.links.length > 0) { + // For each link, get the target node + for (const linkId of output.links) { + const link = app.graph.links[linkId]; + if (link) { + const targetNode = app.graph.getNodeById(link.target_id); + if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") { + connectedNodes.push(targetNode.id); + } + } + } + } + } + } + return connectedNodes; +} + +// Function to update trigger words for connected toggle nodes +function updateConnectedTriggerWords(node, text) { + const connectedNodeIds = getConnectedTriggerToggleNodes(node); + if (connectedNodeIds.length > 0) { + // Extract lora names from the text + const loraNames = []; + let match; + // Reset the RegExp object's lastIndex to start from the beginning + LORA_PATTERN.lastIndex = 0; + while ((match = LORA_PATTERN.exec(text)) !== null) { + loraNames.push(match[1]); // match[1] contains the lora name + } + + // Call API to get trigger words + fetch("/loramanager/get_trigger_words", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + lora_names: loraNames, + node_ids: connectedNodeIds + }) + }).catch(err => console.error("Error fetching trigger words:", err)); + } +} + function mergeLoras(lorasText, lorasArr) { const result = []; let match; @@ -99,6 +150,9 @@ app.registerExtension({ newText = newText.replace(/\s+/g, ' ').trim(); inputWidget.value = newText; + + // Add this line to update trigger words when lorasWidget changes cause inputWidget value to change + updateConnectedTriggerWords(node, newText); } finally { isUpdating = false; } @@ -117,6 +171,9 @@ app.registerExtension({ const mergedLoras = mergeLoras(value, currentLoras); node.lorasWidget.value = mergedLoras; + + // Replace the existing trigger word update code with the new function + updateConnectedTriggerWords(node, value); } finally { isUpdating = false; }