mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 14:12:11 -03:00
feat: enhance node registration and management with support for multiple nodes and improved UI elements. Fixes #220
This commit is contained in:
@@ -1,32 +1,20 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import asyncio
|
||||
from server import PromptServer # type: ignore
|
||||
from aiohttp import web
|
||||
from ..services.settings_manager import settings
|
||||
from ..utils.usage_stats import UsageStats
|
||||
from ..utils.lora_metadata import extract_trained_words
|
||||
from ..config import config
|
||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Download status tracking
|
||||
download_task = None
|
||||
is_downloading = False
|
||||
download_progress = {
|
||||
'total': 0,
|
||||
'completed': 0,
|
||||
'current_model': '',
|
||||
'status': 'idle', # idle, running, paused, completed, error
|
||||
'errors': [],
|
||||
'last_error': None,
|
||||
'start_time': None,
|
||||
'end_time': None,
|
||||
'processed_models': set(), # Track models that have been processed
|
||||
'refreshed_models': set() # Track models that had metadata refreshed
|
||||
}
|
||||
standalone_mode = 'nodes' not in sys.modules
|
||||
|
||||
# Node registry for tracking active workflow nodes
|
||||
class NodeRegistry:
|
||||
@@ -34,33 +22,45 @@ class NodeRegistry:
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.RLock()
|
||||
self._current_graph_id = None
|
||||
self._nodes = {} # node_id -> node_info
|
||||
self._registry_updated = threading.Event()
|
||||
|
||||
def register_node(self, node_id, bgcolor, title, graph_id):
|
||||
"""Register a node for the current workflow"""
|
||||
def register_nodes(self, nodes):
|
||||
"""Register multiple nodes at once, replacing existing registry"""
|
||||
with self._lock:
|
||||
# If graph_id changed, clear existing registry for new workflow
|
||||
if self._current_graph_id != graph_id:
|
||||
self._current_graph_id = graph_id
|
||||
self._nodes.clear()
|
||||
logger.info(f"Workflow changed to {graph_id}, cleared node registry")
|
||||
# Clear existing registry
|
||||
self._nodes.clear()
|
||||
|
||||
# Register the node
|
||||
self._nodes[node_id] = {
|
||||
'id': node_id,
|
||||
'bgcolor': bgcolor,
|
||||
'title': title,
|
||||
'graph_id': graph_id
|
||||
}
|
||||
# Register all new nodes
|
||||
for node in nodes:
|
||||
node_id = node['node_id']
|
||||
node_type = node.get('type', '')
|
||||
|
||||
# Convert node type name to integer
|
||||
type_id = NODE_TYPES.get(node_type, 0) # 0 for unknown types
|
||||
|
||||
# Handle null bgcolor with default color
|
||||
bgcolor = node.get('bgcolor')
|
||||
if bgcolor is None:
|
||||
bgcolor = DEFAULT_NODE_COLOR
|
||||
|
||||
self._nodes[node_id] = {
|
||||
'id': node_id,
|
||||
'bgcolor': bgcolor,
|
||||
'title': node.get('title'),
|
||||
'type': type_id,
|
||||
'type_name': node_type
|
||||
}
|
||||
|
||||
logger.info(f"Registered node {node_id} ({title}) for workflow {graph_id} with bgcolor {bgcolor}")
|
||||
logger.debug(f"Registered {len(nodes)} nodes in registry")
|
||||
|
||||
# Signal that registry has been updated
|
||||
self._registry_updated.set()
|
||||
|
||||
def get_registry(self):
|
||||
"""Get current registry information"""
|
||||
with self._lock:
|
||||
return {
|
||||
'current_graph_id': self._current_graph_id,
|
||||
'nodes': dict(self._nodes), # Return a copy
|
||||
'node_count': len(self._nodes)
|
||||
}
|
||||
@@ -68,9 +68,13 @@ class NodeRegistry:
|
||||
def clear_registry(self):
|
||||
"""Clear the entire registry"""
|
||||
with self._lock:
|
||||
self._current_graph_id = None
|
||||
self._nodes.clear()
|
||||
logger.info("Node registry cleared")
|
||||
|
||||
def wait_for_update(self, timeout=1.0):
|
||||
"""Wait for registry update with timeout"""
|
||||
self._registry_updated.clear()
|
||||
return self._registry_updated.wait(timeout)
|
||||
|
||||
# Global registry instance
|
||||
node_registry = NodeRegistry()
|
||||
@@ -100,7 +104,7 @@ class MiscRoutes:
|
||||
app.router.add_get('/api/model-example-files', MiscRoutes.get_model_example_files)
|
||||
|
||||
# Node registry endpoints
|
||||
app.router.add_post('/api/register-node', MiscRoutes.register_node)
|
||||
app.router.add_post('/api/register-nodes', MiscRoutes.register_nodes)
|
||||
app.router.add_get('/api/get-registry', MiscRoutes.get_registry)
|
||||
|
||||
@staticmethod
|
||||
@@ -135,10 +139,6 @@ class MiscRoutes:
|
||||
'error': f"Failed to delete {filename}: {str(e)}"
|
||||
}, status=500)
|
||||
|
||||
# If we want to completely remove the cache folder too (optional,
|
||||
# but we'll keep the folder structure in place here)
|
||||
# shutil.rmtree(cache_folder)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f"Successfully cleared {len(deleted_files)} cache files",
|
||||
@@ -457,70 +457,68 @@ class MiscRoutes:
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def register_node(request):
|
||||
async def register_nodes(request):
|
||||
"""
|
||||
Register a Lora node for the current workflow
|
||||
Register multiple Lora nodes at once
|
||||
|
||||
Expects a JSON body with:
|
||||
{
|
||||
"node_id": 123,
|
||||
"bgcolor": "#535",
|
||||
"title": "Lora Loader (LoraManager)",
|
||||
"graph_id": "151410b3-7845-4561-aac4-8968574e9ba2"
|
||||
"nodes": [
|
||||
{
|
||||
"node_id": 123,
|
||||
"bgcolor": "#535",
|
||||
"title": "Lora Loader (LoraManager)"
|
||||
},
|
||||
...
|
||||
]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
|
||||
# Validate required fields
|
||||
node_id = data.get('node_id')
|
||||
bgcolor = data.get('bgcolor')
|
||||
title = data.get('title')
|
||||
graph_id = data.get('graph_id')
|
||||
nodes = data.get('nodes', [])
|
||||
|
||||
if node_id is None:
|
||||
if not isinstance(nodes, list):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing node_id parameter'
|
||||
'error': 'nodes must be a list'
|
||||
}, status=400)
|
||||
|
||||
if not bgcolor:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing bgcolor parameter'
|
||||
}, status=400)
|
||||
# Validate each node
|
||||
for i, node in enumerate(nodes):
|
||||
if not isinstance(node, dict):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f'Node {i} must be an object'
|
||||
}, status=400)
|
||||
|
||||
node_id = node.get('node_id')
|
||||
if node_id is None:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f'Node {i} missing node_id parameter'
|
||||
}, status=400)
|
||||
|
||||
# Validate node_id is an integer
|
||||
try:
|
||||
node['node_id'] = int(node_id)
|
||||
except (ValueError, TypeError):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f'Node {i} node_id must be an integer'
|
||||
}, status=400)
|
||||
|
||||
if not title:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing title parameter'
|
||||
}, status=400)
|
||||
|
||||
if not graph_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing graph_id parameter'
|
||||
}, status=400)
|
||||
|
||||
# Validate node_id is an integer
|
||||
try:
|
||||
node_id = int(node_id)
|
||||
except (ValueError, TypeError):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'node_id must be an integer'
|
||||
}, status=400)
|
||||
|
||||
# Register the node
|
||||
node_registry.register_node(node_id, bgcolor, title, graph_id)
|
||||
# Register all nodes
|
||||
node_registry.register_nodes(nodes)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f'Node {node_id} registered successfully'
|
||||
'message': f'{len(nodes)} nodes registered successfully'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register node: {e}", exc_info=True)
|
||||
logger.error(f"Failed to register nodes: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
@@ -528,8 +526,46 @@ class MiscRoutes:
|
||||
|
||||
@staticmethod
|
||||
async def get_registry(request):
|
||||
"""Get current node registry information"""
|
||||
"""Get current node registry information by refreshing from frontend"""
|
||||
try:
|
||||
# Check if running in standalone mode
|
||||
if standalone_mode:
|
||||
logger.warning("Registry refresh not available in standalone mode")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Standalone Mode Active',
|
||||
'message': 'Cannot interact with ComfyUI in standalone mode.'
|
||||
}, status=503)
|
||||
|
||||
# Send message to frontend to refresh registry
|
||||
try:
|
||||
PromptServer.instance.send_sync("lora_registry_refresh", {})
|
||||
logger.debug("Sent registry refresh request to frontend")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send registry refresh message: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Communication Error',
|
||||
'message': f'Failed to communicate with ComfyUI frontend: {str(e)}'
|
||||
}, status=500)
|
||||
|
||||
# Wait for registry update with timeout
|
||||
def wait_for_registry():
|
||||
return node_registry.wait_for_update(timeout=1.0)
|
||||
|
||||
# Run the wait in a thread to avoid blocking the event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
registry_updated = await loop.run_in_executor(None, wait_for_registry)
|
||||
|
||||
if not registry_updated:
|
||||
logger.warning("Registry refresh timeout after 1 second")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Timeout Error',
|
||||
'message': 'Registry refresh timeout - ComfyUI frontend may not be responsive'
|
||||
}, status=408)
|
||||
|
||||
# Get updated registry
|
||||
registry_info = node_registry.get_registry()
|
||||
|
||||
return web.json_response({
|
||||
@@ -541,5 +577,6 @@ class MiscRoutes:
|
||||
logger.error(f"Failed to get registry: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
'error': 'Internal Error',
|
||||
'message': str(e)
|
||||
}, status=500)
|
||||
|
||||
@@ -7,6 +7,15 @@ NSFW_LEVELS = {
|
||||
"Blocked": 32, # Probably not actually visible through the API without being logged in on model owner account?
|
||||
}
|
||||
|
||||
# Node type constants
|
||||
NODE_TYPES = {
|
||||
"Lora Loader (LoraManager)": 1,
|
||||
"Lora Stacker (LoraManager)": 2
|
||||
}
|
||||
|
||||
# Default ComfyUI node color when bgcolor is null
|
||||
DEFAULT_NODE_COLOR = "#353535"
|
||||
|
||||
# preview extensions
|
||||
PREVIEW_EXTENSIONS = [
|
||||
'.webp',
|
||||
|
||||
Reference in New Issue
Block a user