# Mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git
# Synced 2026-03-23 06:02:11 -03:00 (406 lines, 16 KiB, Python)
import logging
|
|
import os
|
|
from server import PromptServer # type: ignore
|
|
from aiohttp import web
|
|
from ..services.settings_manager import settings
|
|
from ..utils.usage_stats import UsageStats
|
|
from ..utils.lora_metadata import extract_trained_words
|
|
from ..config import config
|
|
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
|
import re
|
|
|
|
logger = logging.getLogger(__name__)

# Module-level state describing a background download job.
download_task = None
is_downloading = False
download_progress = dict(
    total=0,
    completed=0,
    current_model='',
    status='idle',  # one of: idle, running, paused, completed, error
    errors=[],
    last_error=None,
    start_time=None,
    end_time=None,
    processed_models=set(),  # models that have already been processed
    refreshed_models=set(),  # models whose metadata was refreshed
)
|
|
|
|
class MiscRoutes:
    """Miscellaneous HTTP routes for assorted utility endpoints.

    Every handler is a ``@staticmethod`` coroutine that takes an aiohttp
    request and returns an aiohttp response; ``setup_routes`` wires them
    onto an application's router.
    """

    @staticmethod
    def setup_routes(app):
        """Register the miscellaneous routes on ``app.router``."""
        app.router.add_post('/api/settings', MiscRoutes.update_settings)

        # Cache maintenance
        app.router.add_post('/api/clear-cache', MiscRoutes.clear_cache)

        # Usage statistics
        app.router.add_post('/api/update-usage-stats', MiscRoutes.update_usage_stats)
        app.router.add_get('/api/get-usage-stats', MiscRoutes.get_usage_stats)

        # Send LoRA syntax to ComfyUI nodes
        app.router.add_post('/api/update-lora-code', MiscRoutes.update_lora_code)

        # Per-model helpers
        app.router.add_get('/api/trained-words', MiscRoutes.get_trained_words)
        app.router.add_get('/api/model-example-files', MiscRoutes.get_model_example_files)

    @staticmethod
    async def clear_cache(request):
        """Delete every ``.msgpack`` file directly inside the project ``cache`` folder.

        The folder itself, its subdirectories and any non-``.msgpack`` files
        are left in place.

        Returns:
            JSON ``{'success': True, 'message': ..., 'deleted_files': [...]}`` on
            success, ``{'success': True, 'message': 'No cache folder found'}``
            when the folder is absent, or ``{'success': False, 'error': ...}``
            with HTTP 500 on failure. Deletion stops at the first file that
            cannot be removed; files removed before that stay removed.
        """
        try:
            # Project directory = three dirname() steps above this file's path.
            project_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
            cache_folder = os.path.join(project_dir, 'cache')

            if not os.path.exists(cache_folder):
                logger.info("Cache folder does not exist, nothing to clear")
                return web.json_response({'success': True, 'message': 'No cache folder found'})

            # Regular files only; subdirectories are never touched.
            cache_files = [f for f in os.listdir(cache_folder) if os.path.isfile(os.path.join(cache_folder, f))]
            deleted_files = []

            for filename in cache_files:
                if not filename.endswith('.msgpack'):
                    continue
                file_path = os.path.join(cache_folder, filename)
                try:
                    os.remove(file_path)
                    deleted_files.append(filename)
                    logger.info(f"Deleted cache file: {filename}")
                except Exception as e:
                    logger.error(f"Failed to delete {filename}: {e}")
                    return web.json_response({
                        'success': False,
                        'error': f"Failed to delete {filename}: {str(e)}"
                    }, status=500)

            return web.json_response({
                'success': True,
                'message': f"Successfully cleared {len(deleted_files)} cache files",
                'deleted_files': deleted_files
            })

        except Exception as e:
            logger.error(f"Error clearing cache files: {e}", exc_info=True)
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)

    @staticmethod
    async def update_settings(request):
        """Update application settings from a JSON object body.

        Every key/value pair of the body is stored with ``settings.set``.
        A non-empty ``example_images_path`` must name an existing path. The
        request is validated in full before anything is saved, so a rejected
        request leaves all settings unchanged.

        Returns:
            ``{'success': True}`` on success; ``{'success': False, 'error': ...}``
            with HTTP 400 when the body is not a JSON object, or with HTTP 200
            when ``example_images_path`` does not exist; a plain-text HTTP 500
            response on unexpected errors.
        """
        try:
            data = await request.json()

            if not isinstance(data, dict):
                return web.json_response({
                    'success': False,
                    'error': 'Request body must be a JSON object'
                }, status=400)

            # Validate before applying anything so an invalid path cannot leave
            # other keys from the same request partially saved.
            new_path = data.get('example_images_path')
            if new_path:
                if not os.path.exists(new_path):
                    # NOTE: kept at HTTP 200 for compatibility with existing
                    # clients that inspect 'success'; consider 400.
                    return web.json_response({
                        'success': False,
                        'error': f"Path does not exist: {new_path}"
                    })

                # The new path only takes effect after a server restart.
                if settings.get('example_images_path') != new_path:
                    logger.info(f"Example images path changed to {new_path} - server restart required")

            for key, value in data.items():
                settings.set(key, value)

            return web.json_response({'success': True})
        except Exception as e:
            logger.error(f"Error updating settings: {e}", exc_info=True)
            return web.Response(status=500, text=str(e))

    @staticmethod
    async def update_usage_stats(request):
        """Update usage statistics for a prompt execution.

        Expects a JSON object body::

            {"prompt_id": "string"}

        Awaits ``UsageStats().process_execution(prompt_id)`` before replying.

        Returns:
            ``{'success': True}`` on success; HTTP 400 when ``prompt_id`` is
            missing or the body is not a JSON object; HTTP 500 on other errors.
        """
        try:
            data = await request.json()
            # A non-object body (list, string, number) cannot carry prompt_id.
            prompt_id = data.get('prompt_id') if isinstance(data, dict) else None

            if not prompt_id:
                return web.json_response({
                    'success': False,
                    'error': 'Missing prompt_id'
                }, status=400)

            usage_stats = UsageStats()
            await usage_stats.process_execution(prompt_id)

            return web.json_response({
                'success': True
            })

        except Exception as e:
            logger.error(f"Failed to update usage stats: {e}", exc_info=True)
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)

    @staticmethod
    async def get_usage_stats(request):
        """Return current usage statistics.

        Returns:
            ``{'success': True, 'data': <stats>, 'format_version': 2}``, or
            ``{'success': False, 'error': ...}`` with HTTP 500 on failure.
        """
        try:
            usage_stats = UsageStats()
            stats = await usage_stats.get_stats()

            # format_version lets clients detect the stats layout (2 = with history).
            stats_response = {
                'success': True,
                'data': stats,
                'format_version': 2
            }

            return web.json_response(stats_response)

        except Exception as e:
            logger.error(f"Failed to get usage stats: {e}", exc_info=True)
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)

    @staticmethod
    async def update_lora_code(request):
        """Push LoRA code to ComfyUI nodes via ``PromptServer`` events.

        Expects a JSON object body::

            {
                "node_ids": [123, 456],  # optional list of target node ids
                "lora_code": "<lora:modelname:1.0>",
                "mode": "append"  # or "replace"; defaults to "append"
            }

        When ``node_ids`` is absent or null, a single ``lora_code_update``
        event is sent with ``id`` -1 (broadcast). Otherwise, one event is sent
        per listed node id.

        Returns:
            ``{'success': True, 'results': [...]}``, with one entry per send
            (``node_id``, ``success`` and, on failure, ``error``). The top-level
            ``success`` stays True even when individual sends fail, so callers
            should inspect ``results``. Returns HTTP 400 for a missing
            ``lora_code`` or a non-list ``node_ids``, and HTTP 500 on other
            errors.
        """
        try:
            data = await request.json()
            if not isinstance(data, dict):
                # Treated as an empty request; rejected below as missing lora_code.
                data = {}

            node_ids = data.get('node_ids')
            lora_code = data.get('lora_code', '')
            mode = data.get('mode', 'append')

            if not lora_code:
                return web.json_response({
                    'success': False,
                    'error': 'Missing lora_code parameter'
                }, status=400)

            # A string would otherwise be iterated character by character and a
            # scalar would raise TypeError; only a JSON array is meaningful here.
            if node_ids is not None and not isinstance(node_ids, list):
                return web.json_response({
                    'success': False,
                    'error': 'node_ids must be a list'
                }, status=400)

            results = []

            if node_ids is None:
                # Desktop mode: broadcast with id=-1 to all Lora Loader nodes.
                try:
                    PromptServer.instance.send_sync("lora_code_update", {
                        "id": -1,
                        "lora_code": lora_code,
                        "mode": mode
                    })
                    results.append({
                        'node_id': 'broadcast',
                        'success': True
                    })
                except Exception as e:
                    logger.error(f"Error broadcasting lora code: {e}")
                    results.append({
                        'node_id': 'broadcast',
                        'success': False,
                        'error': str(e)
                    })
            else:
                # Browser mode: send to each requested node individually.
                for node_id in node_ids:
                    try:
                        PromptServer.instance.send_sync("lora_code_update", {
                            "id": node_id,
                            "lora_code": lora_code,
                            "mode": mode
                        })
                        results.append({
                            'node_id': node_id,
                            'success': True
                        })
                    except Exception as e:
                        logger.error(f"Error sending lora code to node {node_id}: {e}")
                        results.append({
                            'node_id': node_id,
                            'success': False,
                            'error': str(e)
                        })

            return web.json_response({
                'success': True,
                'results': results
            })

        except Exception as e:
            logger.error(f"Failed to update lora code: {e}", exc_info=True)
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)

    @staticmethod
    async def get_trained_words(request):
        """Return trained words and class tokens read from a ``.safetensors`` file.

        Expects the ``file_path`` query parameter.

        Returns:
            ``{'success': True, 'trained_words': ..., 'class_tokens': ...}`` as
            produced by ``extract_trained_words``; HTTP 400 when ``file_path``
            is missing or does not end in ``.safetensors``; HTTP 404 when it is
            not an existing regular file; HTTP 500 on other errors.
        """
        try:
            file_path = request.query.get('file_path')

            if not file_path:
                return web.json_response({
                    'success': False,
                    'error': 'Missing file_path parameter'
                }, status=400)

            # isfile (not exists): a directory named *.safetensors is not a model.
            if not os.path.isfile(file_path):
                return web.json_response({
                    'success': False,
                    'error': f"File not found: {file_path}"
                }, status=404)

            if not file_path.lower().endswith('.safetensors'):
                return web.json_response({
                    'success': False,
                    'error': 'File is not a safetensors file'
                }, status=400)

            trained_words, class_tokens = await extract_trained_words(file_path)

            return web.json_response({
                'success': True,
                'trained_words': trained_words,
                'class_tokens': class_tokens
            })

        except Exception as e:
            logger.error(f"Failed to get trained words: {e}", exc_info=True)
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)

    @staticmethod
    async def get_model_example_files(request):
        """List example media files stored next to a model file.

        Expects the ``file_path`` query parameter (path to the model file).
        It matches files in the same directory named
        ``<model_name>.example.<index>.<ext>``, with the prefix compared
        case-insensitively, whose extension is a supported image or video
        extension.

        Returns:
            ``{'success': True, 'files': [...]}``. Each entry has ``name``,
            ``path`` (static URL from ``config.get_preview_static_url``),
            ``extension`` and ``is_video``. Entries are ordered by numeric
            index; those with a non-numeric index come last, and ties are
            broken by file name so the order does not depend on directory
            listing order. Returns HTTP 400 when ``file_path`` is missing,
            HTTP 404 when its directory does not exist, and HTTP 500 on other
            errors.
        """
        try:
            file_path = request.query.get('file_path')

            if not file_path:
                return web.json_response({
                    'success': False,
                    'error': 'Missing file_path parameter'
                }, status=400)

            model_dir = os.path.dirname(file_path)
            model_name = os.path.splitext(os.path.basename(file_path))[0]

            # isdir (not exists) so a non-directory yields 404 instead of a
            # NotADirectoryError from listdir.
            if not os.path.isdir(model_dir):
                return web.json_response({
                    'success': False,
                    'error': 'Model directory not found',
                    'files': []
                }, status=404)

            image_exts = SUPPORTED_MEDIA_EXTENSIONS['images']
            video_exts = SUPPORTED_MEDIA_EXTENSIONS['videos']
            pattern_lower = f"{model_name}.example.".lower()
            files = []

            for file in os.listdir(model_dir):
                file_lower = file.lower()
                if not file_lower.startswith(pattern_lower):
                    continue

                file_full_path = os.path.join(model_dir, file)
                if not os.path.isfile(file_full_path):
                    continue

                file_ext = os.path.splitext(file)[1].lower()
                is_video = file_ext in video_exts
                if not (is_video or file_ext in image_exts):
                    continue

                # The index is the text between ".example." and the next dot.
                # Slice the lowered name: its prefix length is exactly
                # len(pattern_lower), even if lowercasing changed string length.
                index_part = file_lower[len(pattern_lower):].split('.')[0]
                try:
                    index = int(index_part)
                except ValueError:
                    # Unparseable index: sort after every numbered file.
                    index = float('inf')

                files.append({
                    'name': file,
                    'path': config.get_preview_static_url(file_full_path),
                    'extension': file_ext,
                    'is_video': is_video,
                    'index': index
                })

            # Name as tie-breaker makes the order independent of listdir order.
            files.sort(key=lambda entry: (entry['index'], entry['name']))
            # 'index' is only a sort key; it is not part of the response.
            for entry in files:
                del entry['index']

            return web.json_response({
                'success': True,
                'files': files
            })

        except Exception as e:
            logger.error(f"Failed to get model example files: {e}", exc_info=True)
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)