# Mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git
# Synced 2026-03-25 07:05:43 -03:00 (313 lines, 12 KiB, Python)
import logging
import os

from server import PromptServer  # type: ignore
from aiohttp import web

from ..services.settings_manager import settings
from ..utils.usage_stats import UsageStats
from ..utils.lora_metadata import extract_trained_words
logger = logging.getLogger(__name__)

# Download status tracking.
# NOTE(review): none of this state is referenced by the code visible in this
# module; presumably other modules import it -- confirm before removing.
download_task = None
is_downloading = False
download_progress = {
    'total': 0,
    'completed': 0,
    'current_model': '',
    'status': 'idle',  # idle, running, paused, completed, error
    'errors': [],
    'last_error': None,
    'start_time': None,
    'end_time': None,
    'processed_models': set(),  # Track models that have been processed
    'refreshed_models': set(),  # Track models that had metadata refreshed
}
class MiscRoutes:
    """Miscellaneous routes for various utility functions.

    All handlers are static coroutines that take the incoming request and
    return a response built with ``aiohttp.web``. Unless noted otherwise,
    failures are reported as JSON ``{'success': False, 'error': <message>}``.
    """

    # clear_cache only removes files whose name ends with this suffix.
    CACHE_FILE_SUFFIX = '.msgpack'

    @staticmethod
    def setup_routes(app):
        """Register miscellaneous routes on ``app.router``."""
        app.router.add_post('/api/settings', MiscRoutes.update_settings)

        # Cache clearing
        app.router.add_post('/api/clear-cache', MiscRoutes.clear_cache)

        # Usage stats routes
        app.router.add_post('/api/update-usage-stats', MiscRoutes.update_usage_stats)
        app.router.add_get('/api/get-usage-stats', MiscRoutes.get_usage_stats)

        # Lora code update endpoint
        app.router.add_post('/api/update-lora-code', MiscRoutes.update_lora_code)

        # Trained words lookup
        app.router.add_get('/api/trained-words', MiscRoutes.get_trained_words)

    @staticmethod
    def _error_response(message, status):
        """Build the JSON error payload shared by the handlers."""
        return web.json_response({'success': False, 'error': message}, status=status)

    @staticmethod
    async def _read_json_object(request):
        """Parse the request body as a JSON object.

        Returns:
            ``(data, None)`` when the body is a JSON object, otherwise
            ``(None, response)`` where ``response`` is a 400 error to return
            to the client.
        """
        try:
            data = await request.json()
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError both derive from
            # ValueError; a malformed body is a client error, not a 500.
            return None, MiscRoutes._error_response(f"Invalid JSON body: {e}", 400)

        if not isinstance(data, dict):
            return None, MiscRoutes._error_response('Request body must be a JSON object', 400)
        return data, None

    @staticmethod
    def _list_cache_files(cache_folder):
        """Return the names of regular cache files directly inside ``cache_folder``.

        Only names ending in ``CACHE_FILE_SUFFIX`` are returned; sub-directories
        are skipped. Order follows ``os.listdir``.
        """
        return [
            name
            for name in os.listdir(cache_folder)
            if name.endswith(MiscRoutes.CACHE_FILE_SUFFIX)
            and os.path.isfile(os.path.join(cache_folder, name))
        ]

    @staticmethod
    async def clear_cache(request):
        """Clear all cache files from the cache folder.

        Deletes every ``.msgpack`` file directly inside ``<project>/cache``;
        the folder itself is kept. Stops at the first file that cannot be
        deleted and answers 500, reporting which files were already removed.
        """
        try:
            # The cache folder lives three directory levels above this file.
            project_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
            cache_folder = os.path.join(project_dir, 'cache')

            # isdir (not exists): a stray regular file named "cache" must not
            # fall through to os.listdir and surface as a 500.
            if not os.path.isdir(cache_folder):
                logger.info("Cache folder does not exist, nothing to clear")
                return web.json_response({'success': True, 'message': 'No cache folder found'})

            deleted_files = []
            for filename in MiscRoutes._list_cache_files(cache_folder):
                try:
                    os.remove(os.path.join(cache_folder, filename))
                except OSError as e:
                    logger.error(f"Failed to delete {filename}: {e}")
                    return web.json_response({
                        'success': False,
                        'error': f"Failed to delete {filename}: {str(e)}",
                        # Let the client know the cache is partially cleared.
                        'deleted_files': deleted_files
                    }, status=500)
                deleted_files.append(filename)
                logger.info(f"Deleted cache file: {filename}")

            return web.json_response({
                'success': True,
                'message': f"Successfully cleared {len(deleted_files)} cache files",
                'deleted_files': deleted_files
            })

        except Exception as e:
            logger.error(f"Error clearing cache files: {e}", exc_info=True)
            return MiscRoutes._error_response(str(e), 500)

    @staticmethod
    def _validate_settings(data):
        """Return an error message if ``data`` holds an invalid setting, else ``None``.

        Only ``example_images_path`` is validated: when set to a truthy value
        it must be a string naming an existing path.
        """
        path = data.get('example_images_path')
        if not path:
            return None
        if not isinstance(path, str):
            # os.path.exists treats ints (and therefore bools) as file
            # descriptors, so e.g. ``true`` would wrongly count as existing.
            return 'example_images_path must be a string'
        if not os.path.exists(path):
            return f"Path does not exist: {path}"
        return None

    @staticmethod
    async def update_settings(request):
        """Update application settings.

        Expects a JSON object mapping setting names to values. The whole
        payload is validated before anything is stored, so a rejected request
        leaves the settings untouched.
        """
        try:
            data, bad_request = await MiscRoutes._read_json_object(request)
            if bad_request is not None:
                return bad_request

            problem = MiscRoutes._validate_settings(data)
            if problem is not None:
                # Sent with the default status, as before, so clients that read
                # the JSON 'success' flag keep working.
                return web.json_response({'success': False, 'error': problem})

            for key, value in data.items():
                if key == 'example_images_path' and value:
                    # Path changed - server restart required for new path to take effect
                    if settings.get('example_images_path') != value:
                        logger.info(f"Example images path changed to {value} - server restart required")

                settings.set(key, value)

            return web.json_response({'success': True})
        except Exception as e:
            logger.error(f"Error updating settings: {e}", exc_info=True)
            # Plain-text 500 kept for compatibility with existing clients.
            return web.Response(status=500, text=str(e))

    @staticmethod
    async def update_usage_stats(request):
        """
        Update usage statistics based on a prompt_id

        Expects a JSON body with:
        {
            "prompt_id": "string"
        }

        Answers 400 when the body is not a JSON object or prompt_id is missing.
        """
        try:
            data, bad_request = await MiscRoutes._read_json_object(request)
            if bad_request is not None:
                return bad_request

            prompt_id = data.get('prompt_id')
            if not prompt_id:
                return MiscRoutes._error_response('Missing prompt_id', 400)

            # Awaited so the response is only sent once processing finished.
            await UsageStats().process_execution(prompt_id)

            return web.json_response({'success': True})

        except Exception as e:
            logger.error(f"Failed to update usage stats: {e}", exc_info=True)
            return MiscRoutes._error_response(str(e), 500)

    @staticmethod
    async def get_usage_stats(request):
        """Get current usage statistics.

        Responds with ``{'success': True, 'data': <stats>, 'format_version': 2}``.
        """
        try:
            stats = await UsageStats().get_stats()

            return web.json_response({
                'success': True,
                'data': stats,
                # Version marker so clients can handle format changes
                # (2 = the format with history).
                'format_version': 2
            })

        except Exception as e:
            logger.error(f"Failed to get usage stats: {e}", exc_info=True)
            return MiscRoutes._error_response(str(e), 500)

    @staticmethod
    def _send_lora_code(message_id, reported_id, lora_code, mode, failure_log):
        """Send one "lora_code_update" message and return its result entry.

        ``message_id`` goes into the message, ``reported_id`` into the result
        entry, and ``failure_log`` is the log prefix used when sending raises.
        """
        try:
            PromptServer.instance.send_sync("lora_code_update", {
                "id": message_id,
                "lora_code": lora_code,
                "mode": mode
            })
        except Exception as e:
            logger.error(f"{failure_log}: {e}")
            return {'node_id': reported_id, 'success': False, 'error': str(e)}
        return {'node_id': reported_id, 'success': True}

    @staticmethod
    async def update_lora_code(request):
        """
        Update Lora code in ComfyUI nodes

        Expects a JSON body with:
        {
            "node_ids": [123, 456], # Optional - List of node IDs to update (for browser mode)
            "lora_code": "<lora:modelname:1.0>", # The Lora code to send
            "mode": "append" # or "replace" - whether to append or replace existing code
        }

        Responds with one entry per target in "results". The top-level
        "success" is true only if every send succeeded.
        """
        try:
            data, bad_request = await MiscRoutes._read_json_object(request)
            if bad_request is not None:
                return bad_request

            node_ids = data.get('node_ids')
            lora_code = data.get('lora_code', '')
            mode = data.get('mode', 'append')

            if not lora_code:
                return MiscRoutes._error_response('Missing lora_code parameter', 400)

            # A string would otherwise be iterated character by character.
            if node_ids is not None and not isinstance(node_ids, list):
                return MiscRoutes._error_response('node_ids must be a list', 400)

            if node_ids is None:
                # Desktop mode: no specific node_ids provided, so broadcast
                # with id=-1 to all Lora Loader nodes.
                results = [MiscRoutes._send_lora_code(
                    -1, 'broadcast', lora_code, mode,
                    "Error broadcasting lora code"
                )]
            else:
                # Browser mode: send to specific nodes
                results = [
                    MiscRoutes._send_lora_code(
                        node_id, node_id, lora_code, mode,
                        f"Error sending lora code to node {node_id}"
                    )
                    for node_id in node_ids
                ]

            return web.json_response({
                # Previously always True, which hid failed sends from callers.
                'success': all(entry['success'] for entry in results),
                'results': results
            })

        except Exception as e:
            logger.error(f"Failed to update lora code: {e}", exc_info=True)
            return MiscRoutes._error_response(str(e), 500)

    @staticmethod
    async def get_trained_words(request):
        """
        Get trained words and class tokens for a safetensors file

        Expects a query parameter:
            file_path: Path to the safetensors file

        Answers 400 when file_path is missing or is not a .safetensors path,
        and 404 when no such file exists.
        """
        try:
            file_path = request.query.get('file_path')

            if not file_path:
                return MiscRoutes._error_response('Missing file_path parameter', 400)

            # Extension first: the existence check below must not reveal
            # whether arbitrary non-model paths exist on the server.
            if not file_path.lower().endswith('.safetensors'):
                return MiscRoutes._error_response('File is not a safetensors file', 400)

            # isfile (not exists) so a directory is reported as missing.
            if not os.path.isfile(file_path):
                return MiscRoutes._error_response(f"File not found: {file_path}", 404)

            trained_words, class_tokens = await extract_trained_words(file_path)

            return web.json_response({
                'success': True,
                'trained_words': trained_words,
                'class_tokens': class_tokens
            })

        except Exception as e:
            logger.error(f"Failed to get trained words: {e}", exc_info=True)
            return MiscRoutes._error_response(str(e), 500)