feat: implement thread-safe node registry and registration endpoints for Lora nodes

This commit is contained in:
Will Miao
2025-06-26 18:31:14 +08:00
parent ae905c8630
commit eb57e04e95
3 changed files with 191 additions and 0 deletions

View File

@@ -1,5 +1,6 @@
import logging
import os
import threading
from server import PromptServer # type: ignore
from aiohttp import web
from ..services.settings_manager import settings
@@ -27,6 +28,53 @@ download_progress = {
'refreshed_models': set() # Track models that had metadata refreshed
}
# Node registry for tracking active workflow nodes
class NodeRegistry:
"""Thread-safe registry for tracking Lora nodes in active workflows"""
def __init__(self):
self._lock = threading.RLock()
self._current_graph_id = None
self._nodes = {} # node_id -> node_info
def register_node(self, node_id, bgcolor, title, graph_id):
"""Register a node for the current workflow"""
with self._lock:
# If graph_id changed, clear existing registry for new workflow
if self._current_graph_id != graph_id:
self._current_graph_id = graph_id
self._nodes.clear()
logger.info(f"Workflow changed to {graph_id}, cleared node registry")
# Register the node
self._nodes[node_id] = {
'id': node_id,
'bgcolor': bgcolor,
'title': title,
'graph_id': graph_id
}
logger.info(f"Registered node {node_id} ({title}) for workflow {graph_id} with bgcolor {bgcolor}")
def get_registry(self):
"""Get current registry information"""
with self._lock:
return {
'current_graph_id': self._current_graph_id,
'nodes': dict(self._nodes), # Return a copy
'node_count': len(self._nodes)
}
def clear_registry(self):
"""Clear the entire registry"""
with self._lock:
self._current_graph_id = None
self._nodes.clear()
logger.info("Node registry cleared")
# Global registry instance
node_registry = NodeRegistry()
class MiscRoutes:
"""Miscellaneous routes for various utility functions"""
@@ -50,6 +98,10 @@ class MiscRoutes:
# Add new route for getting model example files
app.router.add_get('/api/model-example-files', MiscRoutes.get_model_example_files)
# Node registry endpoints
app.router.add_post('/api/register-node', MiscRoutes.register_node)
app.router.add_get('/api/get-registry', MiscRoutes.get_registry)
@staticmethod
async def clear_cache(request):
@@ -403,3 +455,91 @@ class MiscRoutes:
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def register_node(request):
"""
Register a Lora node for the current workflow
Expects a JSON body with:
{
"node_id": 123,
"bgcolor": "#535",
"title": "Lora Loader (LoraManager)",
"graph_id": "151410b3-7845-4561-aac4-8968574e9ba2"
}
"""
try:
data = await request.json()
# Validate required fields
node_id = data.get('node_id')
bgcolor = data.get('bgcolor')
title = data.get('title')
graph_id = data.get('graph_id')
if node_id is None:
return web.json_response({
'success': False,
'error': 'Missing node_id parameter'
}, status=400)
if not bgcolor:
return web.json_response({
'success': False,
'error': 'Missing bgcolor parameter'
}, status=400)
if not title:
return web.json_response({
'success': False,
'error': 'Missing title parameter'
}, status=400)
if not graph_id:
return web.json_response({
'success': False,
'error': 'Missing graph_id parameter'
}, status=400)
# Validate node_id is an integer
try:
node_id = int(node_id)
except (ValueError, TypeError):
return web.json_response({
'success': False,
'error': 'node_id must be an integer'
}, status=400)
# Register the node
node_registry.register_node(node_id, bgcolor, title, graph_id)
return web.json_response({
'success': True,
'message': f'Node {node_id} registered successfully'
})
except Exception as e:
logger.error(f"Failed to register node: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def get_registry(request):
"""Get current node registry information"""
try:
registry_info = node_registry.get_registry()
return web.json_response({
'success': True,
'data': registry_info
})
except Exception as e:
logger.error(f"Failed to get registry: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)

View File

@@ -195,6 +195,32 @@ app.registerExtension({
isUpdating = false;
}
};
// Register this node with the backend
this.registerNode = async () => {
try {
await fetch('/api/register-node', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
node_id: this.id,
bgcolor: this.bgcolor,
title: this.title,
graph_id: this.graph.id
})
});
} catch (error) {
console.warn('Failed to register node:', error);
}
};
// Ensure the node is registered after creation
// Call registration
setTimeout(() => {
this.registerNode();
}, 0);
});
}
},

View File

@@ -125,6 +125,31 @@ app.registerExtension({
isUpdating = false;
}
};
// Register this node with the backend
this.registerNode = async () => {
try {
await fetch('/api/register-node', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
node_id: this.id,
bgcolor: this.bgcolor,
title: this.title,
graph_id: this.graph.id
})
});
} catch (error) {
console.warn('Failed to register node:', error);
}
};
// Call registration
setTimeout(() => {
this.registerNode();
}, 0);
});
}
},