mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 22:22:11 -03:00
209 lines
8.7 KiB
Python
209 lines
8.7 KiB
Python
import os
|
||
import json
|
||
import asyncio
|
||
import aiohttp
|
||
import jinja2
|
||
from aiohttp import web
|
||
import logging
|
||
from datetime import datetime
|
||
|
||
from ..services.checkpoint_scanner import CheckpointScanner
|
||
from ..config import config
|
||
from ..services.settings_manager import settings
|
||
|
||
# Module-scoped logger, named after this module so log output can be filtered per module.
logger = logging.getLogger(name=__name__)
|
||
|
||
class CheckpointsRoutes:
    """API routes for checkpoint management.

    Serves JSON endpoints under ``/lora_manager/api/checkpoints`` backed by a
    ``CheckpointScanner``. Also provides an HTML page handler
    (``handle_checkpoints_page``) rendered from the ``checkpoints.html`` template.
    """

    # Upper bound applied to the ``page_size`` query parameter.
    MAX_PAGE_SIZE = 100

    def __init__(self):
        self.scanner = CheckpointScanner()
        # Templates are loaded from the configured templates directory;
        # autoescape=True HTML-escapes values substituted into templates.
        self.template_env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(config.templates_path),
            autoescape=True,
        )

    def setup_routes(self, app):
        """Register the checkpoint JSON API routes with the aiohttp ``app``.

        Note: ``handle_checkpoints_page`` is not registered here.
        """
        app.router.add_get('/lora_manager/api/checkpoints', self.get_checkpoints)
        app.router.add_get('/lora_manager/api/checkpoints/scan', self.scan_checkpoints)
        app.router.add_get('/lora_manager/api/checkpoints/info/{name}', self.get_checkpoint_info)

    @staticmethod
    def _query_flag(query, key, default):
        """Return True when ``query[key]`` (or ``default`` if absent) equals 'true', ignoring case."""
        return query.get(key, default).lower() == 'true'

    async def get_checkpoints(self, request):
        """Return one page of checkpoint data as JSON.

        Query parameters:
            page: 1-based page number (default 1; values below 1 become 1).
            page_size: items per page (default 20; clamped to 1..MAX_PAGE_SIZE).
            sort: 'date' selects date ordering, anything else name ordering.
            folder, search: optional folder / text filters.
            fuzzy_search, search_filename, search_modelname, search_tags,
            recursive: boolean flags ('true' / anything else).
            base_model, tag: may be repeated.
            hash: single hash filter, or
            hashes: JSON-encoded list of hashes (ignored if malformed).

        Responds with status 400 when ``page`` or ``page_size`` is not an
        integer, and 500 (with the error text) on any other failure.
        """
        try:
            query = request.query

            try:
                # Clamp so zero/negative values cannot yield negative slice
                # indices or a division by zero when computing total_pages.
                page = max(1, int(query.get('page', '1')))
                page_size = max(1, min(int(query.get('page_size', '20')), self.MAX_PAGE_SIZE))
            except ValueError:
                return web.json_response(
                    {"error": "page and page_size must be integers"}, status=400
                )

            sort_by = query.get('sort', 'name')
            folder = query.get('folder', None)
            search = query.get('search', None)
            fuzzy_search = self._query_flag(query, 'fuzzy_search', 'false')
            base_models = query.getall('base_model', [])
            tags = query.getall('tag', [])

            search_options = {
                'filename': self._query_flag(query, 'search_filename', 'true'),
                'modelname': self._query_flag(query, 'search_modelname', 'true'),
                'tags': self._query_flag(query, 'search_tags', 'false'),
                'recursive': self._query_flag(query, 'recursive', 'false'),
            }

            # 'hash' takes precedence over 'hashes' when both are supplied.
            hash_filters = {}
            if 'hash' in query:
                hash_filters['single_hash'] = query['hash']
            elif 'hashes' in query:
                try:
                    hash_list = json.loads(query['hashes'])
                except (json.JSONDecodeError, TypeError):
                    # Best-effort: a malformed value just disables hash filtering.
                    logger.debug("Ignoring malformed 'hashes' query parameter")
                else:
                    if isinstance(hash_list, list):
                        hash_filters['multiple_hashes'] = hash_list

            result = await self.get_paginated_data(
                page=page,
                page_size=page_size,
                sort_by=sort_by,
                folder=folder,
                search=search,
                fuzzy_search=fuzzy_search,
                base_models=base_models,
                tags=tags,
                search_options=search_options,
                hash_filters=hash_filters,
            )
            return web.json_response(result)

        except Exception as e:
            logger.error(f"Error in get_checkpoints: {e}", exc_info=True)
            return web.json_response({"error": str(e)}, status=500)

    async def get_paginated_data(self, page, page_size, sort_by='name',
                                 folder=None, search=None, fuzzy_search=False,
                                 base_models=None, tags=None,
                                 search_options=None, hash_filters=None):
        """Return a page of cached checkpoints, filtered by folder and search text.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Items per page; values below 1 are treated as 1.
            sort_by: 'date' uses the cache's date ordering, anything else the
                name ordering.
            folder: When not None, keep only items whose 'folder' equals it.
            search: When truthy, keep items whose 'file_name' or 'model_name'
                contains it, case-insensitively.
            fuzzy_search, base_models, tags, search_options, hash_filters:
                Accepted for API compatibility but not applied by this
                simplified implementation yet.

        Returns:
            dict with keys 'items', 'total', 'page', 'page_size', 'total_pages'.
        """
        # TODO: port the full filtering logic from LoraScanner.get_paginated_data.
        page = max(1, page)
        page_size = max(1, page_size)

        cache = await self.scanner.get_cached_data()
        items = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name

        if folder is not None:
            items = [cp for cp in items if cp['folder'] == folder]

        if search:
            needle = search.lower()  # hoisted: lowercase the query once, not per item
            items = [
                cp for cp in items
                if needle in cp['file_name'].lower() or needle in cp['model_name'].lower()
            ]

        total_items = len(items)
        start_idx = (page - 1) * page_size

        return {
            'items': items[start_idx:start_idx + page_size],
            'total': total_items,
            'page': page,
            'page_size': page_size,
            'total_pages': (total_items + page_size - 1) // page_size,
        }

    async def scan_checkpoints(self, request):
        """Force a rescan of checkpoint files by refreshing the scanner cache."""
        try:
            await self.scanner.get_cached_data(force_refresh=True)
            return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
        except Exception as e:
            logger.error(f"Error in scan_checkpoints: {e}", exc_info=True)
            return web.json_response({"error": str(e)}, status=500)

    async def get_checkpoint_info(self, request):
        """Return details for the checkpoint named in the URL, or 404 if not found."""
        try:
            name = request.match_info.get('name', '')
            checkpoint_info = await self.scanner.get_checkpoint_info_by_name(name)

            if checkpoint_info:
                return web.json_response(checkpoint_info)
            return web.json_response({"error": "Checkpoint not found"}, status=404)

        except Exception as e:
            logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
            return web.json_response({"error": str(e)}, status=500)

    def _render_checkpoints_page(self, request, folders, is_initializing):
        """Render ``checkpoints.html`` with the given folder list and loading flag."""
        template = self.template_env.get_template('checkpoints.html')
        return template.render(
            folders=folders,
            is_initializing=is_initializing,
            settings=settings,  # exposed to the template
            request=request,    # exposed to the template
        )

    async def handle_checkpoints_page(self, request: "web.Request") -> "web.Response":
        """Handle GET /checkpoints request.

        Renders a loading page (no folders, ``is_initializing=True``) in two
        cases: while the scanner cache is not ready, or if reading it fails.
        Otherwise renders the page with the cached folder list. Returns a
        plain-text 500 response on any other error.
        """
        try:
            # Check the cache initialization state; the conditions are tailored
            # to how the scanner's initialize_in_background behaves.
            scanner = self.scanner
            is_initializing = (
                scanner._cache is None
                or len(scanner._cache.raw_data) == 0
                or getattr(scanner, '_is_initializing', False)
            )

            if is_initializing:
                # Still initializing: return a page that only shows the loading indicator.
                rendered = self._render_checkpoints_page(request, [], True)
                logger.info("Checkpoints page is initializing, returning loading page")
            else:
                # Normal path: use the already-initialized cache data.
                try:
                    cache = await scanner.get_cached_data(force_refresh=False)
                    rendered = self._render_checkpoints_page(request, cache.folders, False)
                    logger.debug(f"Checkpoints page loaded successfully with {len(cache.raw_data)} items")
                except Exception as cache_error:
                    logger.error(f"Error loading checkpoints cache data: {cache_error}", exc_info=True)
                    # If reading the cache fails, fall back to the initialization page.
                    rendered = self._render_checkpoints_page(request, [], True)
                    logger.info("Checkpoints cache error, returning initialization page")

            return web.Response(text=rendered, content_type='text/html')

        except Exception as e:
            logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
            return web.Response(text="Error loading checkpoints page", status=500)
|