mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
147 lines
5.8 KiB
Python
147 lines
5.8 KiB
Python
import os
|
|
import json
|
|
import asyncio
|
|
import aiohttp
|
|
from aiohttp import web
|
|
import logging
|
|
from datetime import datetime
|
|
|
|
from ..services.checkpoint_scanner import CheckpointScanner
|
|
from ..config import config
|
|
|
|
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
|
|
|
class CheckpointsRoutes:
    """API routes for checkpoint management.

    Holds a ``CheckpointScanner`` and exposes its cached data through three
    GET endpoints registered by :meth:`setup_routes`.
    """

    # Upper bound applied to the client-supplied ``page_size`` query parameter.
    MAX_PAGE_SIZE = 100

    def __init__(self):
        self.scanner = CheckpointScanner()

    def setup_routes(self, app):
        """Register the checkpoint GET routes on the aiohttp ``app``."""
        router = app.router
        router.add_get('/lora_manager/api/checkpoints', self.get_checkpoints)
        router.add_get('/lora_manager/api/checkpoints/scan', self.scan_checkpoints)
        router.add_get('/lora_manager/api/checkpoints/info/{name}', self.get_checkpoint_info)

    @staticmethod
    def _query_flag(query, key, default):
        """Return True when query parameter ``key`` is 'true' (case-insensitive).

        ``default`` is the string used when the parameter is absent.
        """
        return query.get(key, default).lower() == 'true'

    async def get_checkpoints(self, request):
        """Return one page of checkpoint data as JSON.

        Query parameters read: ``page``, ``page_size``, ``sort``, ``folder``,
        ``search``, ``fuzzy_search``, ``base_model`` (repeatable), ``tag``
        (repeatable), ``search_filename``, ``search_modelname``,
        ``search_tags``, ``recursive``, and either ``hash`` or ``hashes``
        (a JSON-encoded list).

        Responds with 400 when ``page`` or ``page_size`` is not an integer and
        with 500 (``{"error": ...}``) on any unexpected failure.
        """
        try:
            query = request.query

            # Malformed numbers are a client error, not a server failure.
            try:
                page = int(query.get('page', '1'))
                page_size = int(query.get('page_size', '20'))
            except ValueError:
                return web.json_response(
                    {"error": "'page' and 'page_size' must be integers"},
                    status=400,
                )

            # Clamp to sane bounds: pages are 1-based and a page must hold at
            # least one item (a zero size would divide by zero downstream).
            page = max(1, page)
            page_size = max(1, min(page_size, self.MAX_PAGE_SIZE))

            sort_by = query.get('sort', 'name')
            folder = query.get('folder', None)
            search = query.get('search', None)
            fuzzy_search = self._query_flag(query, 'fuzzy_search', 'false')
            base_models = query.getall('base_model', [])
            tags = query.getall('tag', [])

            search_options = {
                'filename': self._query_flag(query, 'search_filename', 'true'),
                'modelname': self._query_flag(query, 'search_modelname', 'true'),
                'tags': self._query_flag(query, 'search_tags', 'false'),
                'recursive': self._query_flag(query, 'recursive', 'false'),
            }

            # A single ``hash`` takes precedence over a ``hashes`` list.
            hash_filters = {}
            if 'hash' in query:
                hash_filters['single_hash'] = query['hash']
            elif 'hashes' in query:
                try:
                    hash_list = json.loads(query['hashes'])
                except (json.JSONDecodeError, TypeError):
                    # Best effort: a malformed list is ignored, not fatal.
                    logger.debug("Ignoring malformed 'hashes' query parameter")
                else:
                    if isinstance(hash_list, list):
                        hash_filters['multiple_hashes'] = hash_list

            result = await self.get_paginated_data(
                page=page,
                page_size=page_size,
                sort_by=sort_by,
                folder=folder,
                search=search,
                fuzzy_search=fuzzy_search,
                base_models=base_models,
                tags=tags,
                search_options=search_options,
                hash_filters=hash_filters,
            )
            return web.json_response(result)

        except Exception as e:
            logger.error(f"Error in get_checkpoints: {e}", exc_info=True)
            return web.json_response({"error": str(e)}, status=500)

    async def get_paginated_data(self, page, page_size, sort_by='name',
                                 folder=None, search=None, fuzzy_search=False,
                                 base_models=None, tags=None,
                                 search_options=None, hash_filters=None):
        """Return one page of filtered checkpoint entries from the scanner cache.

        Simplified implementation: only ``sort_by``, ``folder`` and ``search``
        (with the ``filename`` / ``modelname`` search options) are applied.
        ``fuzzy_search``, ``base_models``, ``tags``, ``hash_filters`` and the
        ``tags`` / ``recursive`` search options are accepted but not yet used.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Items per page; values below 1 are treated as 1.
            sort_by: ``'date'`` selects ``cache.sorted_by_date``; anything
                else selects ``cache.sorted_by_name``.
            folder: When not None, keep entries whose ``'folder'`` equals it.
            search: Case-insensitive substring to look for.
            search_options: Mapping whose ``'filename'`` / ``'modelname'``
                flags choose which fields ``search`` is matched against
                (both default to True when absent).

        Returns:
            dict with keys ``items``, ``total``, ``page``, ``page_size`` and
            ``total_pages``.
        """
        cache = await self.scanner.get_cached_data()

        # Guard against non-positive values: page_size == 0 would raise
        # ZeroDivisionError below and page < 1 would produce a negative slice.
        page = max(1, page)
        page_size = max(1, page_size)

        filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name

        if folder is not None:
            filtered_data = [cp for cp in filtered_data if cp.get('folder') == folder]

        if search:
            options = search_options or {}
            fields = []
            if options.get('filename', True):
                fields.append('file_name')
            if options.get('modelname', True):
                fields.append('model_name')

            needle = search.lower()  # hoisted: invariant across entries
            filtered_data = [
                cp for cp in filtered_data
                # ``or ''`` tolerates entries whose field is missing or None.
                if any(needle in (cp.get(field) or '').lower() for field in fields)
            ]

        total_items = len(filtered_data)
        start_idx = (page - 1) * page_size
        # Slicing clamps the upper bound, so no explicit min() is needed.
        page_items = filtered_data[start_idx:start_idx + page_size]

        return {
            'items': page_items,
            'total': total_items,
            'page': page,
            'page_size': page_size,
            # Ceiling division without floats.
            'total_pages': (total_items + page_size - 1) // page_size,
        }

    async def scan_checkpoints(self, request):
        """Force the scanner to refresh its cache and report the outcome.

        Responds with a success JSON payload, or 500 with ``{"error": ...}``
        if the refresh raises.
        """
        try:
            await self.scanner.get_cached_data(force_refresh=True)
            return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
        except Exception as e:
            logger.error(f"Error in scan_checkpoints: {e}", exc_info=True)
            return web.json_response({"error": str(e)}, status=500)

    async def get_checkpoint_info(self, request):
        """Return the scanner's info for the checkpoint named in the URL.

        The name comes from the ``{name}`` path segment. Responds with 404
        when the scanner returns a falsy result and 500 on unexpected errors.
        """
        try:
            name = request.match_info.get('name', '')
            checkpoint_info = await self.scanner.get_checkpoint_info_by_name(name)

            if not checkpoint_info:
                return web.json_response({"error": "Checkpoint not found"}, status=404)
            return web.json_response(checkpoint_info)

        except Exception as e:
            logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
            return web.json_response({"error": str(e)}, status=500)