diff --git a/py/nodes/trigger_word_toggle.py b/py/nodes/trigger_word_toggle.py index 16b72f55..61bcd123 100644 --- a/py/nodes/trigger_word_toggle.py +++ b/py/nodes/trigger_word_toggle.py @@ -47,10 +47,10 @@ class TriggerWordToggle: trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else "" # Send trigger words to frontend - PromptServer.instance.send_sync("trigger_word_update", { - "id": id, - "message": trigger_words - }) + # PromptServer.instance.send_sync("trigger_word_update", { + # "id": id, + # "message": trigger_words + # }) filtered_triggers = trigger_words diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index 39d0d429..06aff4a5 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -3,8 +3,10 @@ import json import logging from aiohttp import web from typing import Dict +from server import PromptServer # type: ignore from ..utils.routes_common import ModelRouteUtils +from ..nodes.utils import get_lora_info from ..config import config from ..services.websocket_manager import ws_manager @@ -64,6 +66,9 @@ class ApiRoutes: app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files + + # Add the new trigger words route + app.router.add_post('/loramanager/get_trigger_words', routes.get_trigger_words) # Add update check routes UpdateRoutes.setup_routes(app) @@ -1021,4 +1026,35 @@ class ApiRoutes: return web.json_response({ 'success': False, 'error': str(e) + }, status=500) + + async def get_trigger_words(self, request: web.Request) -> web.Response: + """Get trigger words for specified LoRA models""" + try: + json_data = await request.json() + lora_names = json_data.get("lora_names", []) + node_ids = json_data.get("node_ids", []) + + all_trigger_words = [] + for lora_name in lora_names: + _, trigger_words = await get_lora_info(lora_name) + all_trigger_words.extend(trigger_words) + + # Format the trigger words + trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" + + # Send update to all connected trigger word toggle nodes + for node_id in node_ids: + PromptServer.instance.send_sync("trigger_word_update", { + "id": node_id, + "message": trigger_words_text + }) + + return web.json_response({"success": True}) + + except Exception as e: + logger.error(f"Error getting trigger words: {e}") + return web.json_response({ + "success": False, + "error": str(e) }, status=500) \ No newline at end of file diff --git a/py/server_routes.py b/py/server_routes.py new file mode 100644 index 00000000..68ee9749 --- /dev/null +++ b/py/server_routes.py @@ -0,0 +1,26 @@ +from aiohttp import web +from server import PromptServer +from .nodes.utils import get_lora_info + +@PromptServer.instance.routes.post("/loramanager/get_trigger_words") +async def get_trigger_words(request): + json_data = await request.json() + lora_names = json_data.get("lora_names", []) + node_ids = json_data.get("node_ids", []) + + all_trigger_words = [] + for lora_name in lora_names: + _, trigger_words = await get_lora_info(lora_name) + all_trigger_words.extend(trigger_words) + + # Format the trigger words + trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" + + # Send update to all connected trigger word toggle nodes + for node_id in node_ids: + PromptServer.instance.send_sync("trigger_word_update", { + "id": node_id, + "message": trigger_words_text + }) + + return web.json_response({"success": True}) diff --git a/web/comfyui/lora_loader.js b/web/comfyui/lora_loader.js index 62f01f1c..f8784109 100644 --- a/web/comfyui/lora_loader.js +++ b/web/comfyui/lora_loader.js @@ -9,6 +9,57 @@ async function getLorasWidgetModule() { return await dynamicImportByVersion("./loras_widget.js", "./legacy_loras_widget.js"); } +// Function to get connected trigger toggle nodes +function getConnectedTriggerToggleNodes(node) { + const connectedNodes = []; + + // Check if node has outputs + if (node.outputs && node.outputs.length > 0) { + // For each output slot + for (const output of node.outputs) { + // Check if this output has any links + if (output.links && output.links.length > 0) { + // For each link, get the target node + for (const linkId of output.links) { + const link = app.graph.links[linkId]; + if (link) { + const targetNode = app.graph.getNodeById(link.target_id); + if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") { + connectedNodes.push(targetNode.id); + } + } + } + } + } + } + return connectedNodes; +} + +// Function to update trigger words for connected toggle nodes +function updateConnectedTriggerWords(node, text) { + const connectedNodeIds = getConnectedTriggerToggleNodes(node); + if (connectedNodeIds.length > 0) { + // Extract lora names from the text + const loraNames = []; + let match; + // Reset the RegExp object's lastIndex to start from the beginning + LORA_PATTERN.lastIndex = 0; + while ((match = LORA_PATTERN.exec(text)) !== null) { + loraNames.push(match[1]); // match[1] contains the lora name + } + + // Call API to get trigger words + fetch("/loramanager/get_trigger_words", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + lora_names: loraNames, + node_ids: connectedNodeIds + }) + }).catch(err => console.error("Error fetching trigger words:", err)); + } +} + function mergeLoras(lorasText, lorasArr) { const result = []; let match; @@ -99,6 +150,9 @@ app.registerExtension({ newText = newText.replace(/\s+/g, ' ').trim(); inputWidget.value = newText; + + // Add this line to update trigger words when lorasWidget changes cause inputWidget value to change + updateConnectedTriggerWords(node, newText); } finally { isUpdating = false; } @@ -117,6 +171,9 @@ app.registerExtension({ const mergedLoras = mergeLoras(value, currentLoras); node.lorasWidget.value = mergedLoras; + + // Replace the existing trigger word update code with the new function + updateConnectedTriggerWords(node, value); } finally { isUpdating = false; }