diff --git a/py/routes/misc_routes.py b/py/routes/misc_routes.py index 73c5dd7e..4da2e188 100644 --- a/py/routes/misc_routes.py +++ b/py/routes/misc_routes.py @@ -1,5 +1,6 @@ import logging import os +import threading from server import PromptServer # type: ignore from aiohttp import web from ..services.settings_manager import settings @@ -27,6 +28,53 @@ download_progress = { 'refreshed_models': set() # Track models that had metadata refreshed } +# Node registry for tracking active workflow nodes +class NodeRegistry: + """Thread-safe registry for tracking Lora nodes in active workflows""" + + def __init__(self): + self._lock = threading.RLock() + self._current_graph_id = None + self._nodes = {} # node_id -> node_info + + def register_node(self, node_id, bgcolor, title, graph_id): + """Register a node for the current workflow""" + with self._lock: + # If graph_id changed, clear existing registry for new workflow + if self._current_graph_id != graph_id: + self._current_graph_id = graph_id + self._nodes.clear() + logger.info(f"Workflow changed to {graph_id}, cleared node registry") + + # Register the node + self._nodes[node_id] = { + 'id': node_id, + 'bgcolor': bgcolor, + 'title': title, + 'graph_id': graph_id + } + + logger.info(f"Registered node {node_id} ({title}) for workflow {graph_id} with bgcolor {bgcolor}") + + def get_registry(self): + """Get current registry information""" + with self._lock: + return { + 'current_graph_id': self._current_graph_id, + 'nodes': dict(self._nodes), # Return a copy + 'node_count': len(self._nodes) + } + + def clear_registry(self): + """Clear the entire registry""" + with self._lock: + self._current_graph_id = None + self._nodes.clear() + logger.info("Node registry cleared") + +# Global registry instance +node_registry = NodeRegistry() + class MiscRoutes: """Miscellaneous routes for various utility functions""" @@ -50,6 +98,10 @@ class MiscRoutes: # Add new route for getting model example files app.router.add_get('/api/model-example-files', MiscRoutes.get_model_example_files) + + # Node registry endpoints + app.router.add_post('/api/register-node', MiscRoutes.register_node) + app.router.add_get('/api/get-registry', MiscRoutes.get_registry) @staticmethod async def clear_cache(request): @@ -403,3 +455,91 @@ class MiscRoutes: 'success': False, 'error': str(e) }, status=500) + + @staticmethod + async def register_node(request): + """ + Register a Lora node for the current workflow + + Expects a JSON body with: + { + "node_id": 123, + "bgcolor": "#535", + "title": "Lora Loader (LoraManager)", + "graph_id": "151410b3-7845-4561-aac4-8968574e9ba2" + } + """ + try: + data = await request.json() + + # Validate required fields + node_id = data.get('node_id') + bgcolor = data.get('bgcolor') + title = data.get('title') + graph_id = data.get('graph_id') + + if node_id is None: + return web.json_response({ + 'success': False, + 'error': 'Missing node_id parameter' + }, status=400) + + if not bgcolor: + return web.json_response({ + 'success': False, + 'error': 'Missing bgcolor parameter' + }, status=400) + + if not title: + return web.json_response({ + 'success': False, + 'error': 'Missing title parameter' + }, status=400) + + if not graph_id: + return web.json_response({ + 'success': False, + 'error': 'Missing graph_id parameter' + }, status=400) + + # Validate node_id is an integer + try: + node_id = int(node_id) + except (ValueError, TypeError): + return web.json_response({ + 'success': False, + 'error': 'node_id must be an integer' + }, status=400) + + # Register the node + node_registry.register_node(node_id, bgcolor, title, graph_id) + + return web.json_response({ + 'success': True, + 'message': f'Node {node_id} registered successfully' + }) + + except Exception as e: + logger.error(f"Failed to register node: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + @staticmethod + async def get_registry(request): + """Get current node registry information""" + try: + registry_info = node_registry.get_registry() + + return web.json_response({ + 'success': True, + 'data': registry_info + }) + + except Exception as e: + logger.error(f"Failed to get registry: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) diff --git a/web/comfyui/lora_loader.js b/web/comfyui/lora_loader.js index 6c285c67..fb5ab52c 100644 --- a/web/comfyui/lora_loader.js +++ b/web/comfyui/lora_loader.js @@ -195,6 +195,32 @@ app.registerExtension({ isUpdating = false; } }; + + // Register this node with the backend + this.registerNode = async () => { + try { + await fetch('/api/register-node', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + node_id: this.id, + bgcolor: this.bgcolor, + title: this.title, + graph_id: this.graph.id + }) + }); + } catch (error) { + console.warn('Failed to register node:', error); + } + }; + + // Ensure the node is registered after creation + // Call registration + setTimeout(() => { + this.registerNode(); + }, 0); }); } }, diff --git a/web/comfyui/lora_stacker.js b/web/comfyui/lora_stacker.js index bd73b53e..f6fb3d95 100644 --- a/web/comfyui/lora_stacker.js +++ b/web/comfyui/lora_stacker.js @@ -125,6 +125,31 @@ app.registerExtension({ isUpdating = false; } }; + + // Register this node with the backend + this.registerNode = async () => { + try { + await fetch('/api/register-node', { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + node_id: this.id, + bgcolor: this.bgcolor, + title: this.title, + graph_id: this.graph.id + }) + }); + } catch (error) { + console.warn('Failed to register node:', error); + } + }; + + // Call registration + setTimeout(() => { + this.registerNode(); + }, 0); }); } },