import os
import json
import asyncio
import logging
from datetime import datetime

import aiohttp
import jinja2
from aiohttp import web

from ..services.checkpoint_scanner import CheckpointScanner
from ..config import config
from ..services.settings_manager import settings

logger = logging.getLogger(__name__)


class CheckpointsRoutes:
    """API routes for checkpoint management"""

    # Hard upper bound on the number of items the HTTP endpoint returns per page.
    MAX_PAGE_SIZE = 100

    def __init__(self):
        self.scanner = CheckpointScanner()
        # Autoescaping is enabled so values rendered into HTML templates are escaped.
        self.template_env = jinja2.Environment(
            loader=jinja2.FileSystemLoader(config.templates_path),
            autoescape=True
        )

    def setup_routes(self, app):
        """Register routes with the aiohttp app"""
        app.router.add_get('/lora_manager/api/checkpoints', self.get_checkpoints)
        app.router.add_get('/lora_manager/api/checkpoints/scan', self.scan_checkpoints)
        app.router.add_get('/lora_manager/api/checkpoints/info/{name}', self.get_checkpoint_info)

    @staticmethod
    def _query_bool(query, key, default):
        """Return True when ``query[key]`` is 'true' (case-insensitive); use ``default`` when absent."""
        return query.get(key, 'true' if default else 'false').lower() == 'true'

    @staticmethod
    def _parse_hash_filters(query):
        """Build the hash-filter dict from the 'hash' or 'hashes' query parameters.

        'hash' (a single value) takes precedence over 'hashes' (a JSON-encoded list).
        A malformed or non-list 'hashes' value is ignored (best-effort), not an error.
        """
        hash_filters = {}
        if 'hash' in query:
            hash_filters['single_hash'] = query['hash']
        elif 'hashes' in query:
            try:
                hash_list = json.loads(query['hashes'])
            except (json.JSONDecodeError, TypeError):
                logger.debug("Ignoring malformed 'hashes' query parameter")
            else:
                if isinstance(hash_list, list):
                    hash_filters['multiple_hashes'] = hash_list
        return hash_filters

    async def get_checkpoints(self, request):
        """Get paginated checkpoint data as JSON.

        Query parameters:
            page: 1-based page number (default 1); values below 1 are clamped to 1.
            page_size: items per page (default 20), clamped to [1, MAX_PAGE_SIZE].
            sort: 'date' sorts by date; any other value sorts by name (default 'name').
            folder, search, fuzzy_search, base_model (repeatable), tag (repeatable),
            search_filename, search_modelname, search_tags, recursive,
            hash / hashes (JSON list): forwarded to ``get_paginated_data``.

        Returns:
            200 with the paginated result, 400 if page/page_size are not integers,
            500 with an error message on any other failure.
        """
        query = request.query

        # Validate paging up front so bad client input is a 400, not a 500, and so
        # zero/negative values cannot cause negative slicing or division by zero.
        try:
            page = max(int(query.get('page', '1')), 1)
            page_size = min(max(int(query.get('page_size', '20')), 1), self.MAX_PAGE_SIZE)
        except ValueError:
            return web.json_response(
                {"error": "page and page_size must be integers"}, status=400
            )

        try:
            search_options = {
                'filename': self._query_bool(query, 'search_filename', True),
                'modelname': self._query_bool(query, 'search_modelname', True),
                'tags': self._query_bool(query, 'search_tags', False),
                'recursive': self._query_bool(query, 'recursive', False),
            }

            result = await self.get_paginated_data(
                page=page,
                page_size=page_size,
                sort_by=query.get('sort', 'name'),
                folder=query.get('folder', None),
                search=query.get('search', None),
                fuzzy_search=self._query_bool(query, 'fuzzy_search', False),
                base_models=query.getall('base_model', []),
                tags=query.getall('tag', []),
                search_options=search_options,
                hash_filters=self._parse_hash_filters(query),
            )
            return web.json_response(result)
        except Exception as e:
            logger.error("Error in get_checkpoints: %s", e, exc_info=True)
            return web.json_response({"error": str(e)}, status=500)

    async def get_paginated_data(self, page, page_size, sort_by='name', folder=None,
                                 search=None, fuzzy_search=False, base_models=None,
                                 tags=None, search_options=None, hash_filters=None):
        """Get paginated and filtered checkpoint data.

        Args:
            page: 1-based page number; values below 1 are treated as 1.
            page_size: items per page; values below 1 are treated as 1.
            sort_by: 'date' uses the cache's date ordering, anything else name ordering.
            folder: when not None, keep only items whose 'folder' equals it exactly.
            search: when non-empty, case-insensitive substring match against
                'file_name' or 'model_name'.
            fuzzy_search, base_models, tags, search_options, hash_filters:
                accepted for API compatibility but NOT applied yet (simplified
                implementation). TODO: port the full filtering from LoraScanner.

        Returns:
            dict with 'items', 'total', 'page', 'page_size' and 'total_pages'.
        """
        page = max(page, 1)
        page_size = max(page_size, 1)

        cache = await self.scanner.get_cached_data()
        filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name

        if folder is not None:
            filtered_data = [cp for cp in filtered_data if cp.get('folder') == folder]

        if search:
            needle = search.lower()
            # `or ''` tolerates records where the field is missing or None.
            filtered_data = [
                cp for cp in filtered_data
                if needle in (cp.get('file_name') or '').lower()
                or needle in (cp.get('model_name') or '').lower()
            ]

        total_items = len(filtered_data)
        start_idx = (page - 1) * page_size
        return {
            # Slicing past the end simply yields fewer (or zero) items.
            'items': filtered_data[start_idx:start_idx + page_size],
            'total': total_items,
            'page': page,
            'page_size': page_size,
            'total_pages': (total_items + page_size - 1) // page_size,
        }

    async def scan_checkpoints(self, request):
        """Force a rescan of checkpoint files"""
        try:
            await self.scanner.get_cached_data(force_refresh=True)
            return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
        except Exception as e:
            logger.error("Error in scan_checkpoints: %s", e, exc_info=True)
            return web.json_response({"error": str(e)}, status=500)

    async def get_checkpoint_info(self, request):
        """Get detailed information for a specific checkpoint by name.

        Returns 404 when the scanner finds nothing for the name, 500 on errors.
        """
        try:
            name = request.match_info.get('name', '')
            checkpoint_info = await self.scanner.get_checkpoint_info_by_name(name)
            if checkpoint_info:
                return web.json_response(checkpoint_info)
            return web.json_response({"error": "Checkpoint not found"}, status=404)
        except Exception as e:
            logger.error("Error in get_checkpoint_info: %s", e, exc_info=True)
            return web.json_response({"error": str(e)}, status=500)

    def _is_scanner_initializing(self):
        """Return True while the scanner cache is not yet usable for rendering.

        Relies on scanner internals: ``_cache`` and an optional ``_is_initializing``
        flag (presumably set by the scanner's background initialization — verify).
        """
        cache = self.scanner._cache
        return (
            cache is None
            or len(cache.raw_data) == 0
            or bool(getattr(self.scanner, '_is_initializing', False))
        )

    def _render_checkpoints_page(self, request, folders, is_initializing):
        """Render checkpoints.html with the shared template context."""
        template = self.template_env.get_template('checkpoints.html')
        return template.render(
            folders=folders,
            is_initializing=is_initializing,
            settings=settings,  # Pass settings to template
            request=request,  # Pass the request object to the template
        )

    async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
        """Handle GET /checkpoints request.

        Renders a loading page while the scanner cache is initializing (or when
        loading it fails), otherwise the full page with the cached folder list.
        Returns a plain-text 500 response if rendering fails entirely.
        """
        try:
            if self._is_scanner_initializing():
                # Cache still being built: return a page with only the loading indicator.
                rendered = self._render_checkpoints_page(request, folders=[], is_initializing=True)
                logger.info("Checkpoints page is initializing, returning loading page")
            else:
                # Normal flow: use the already-initialized cache data.
                try:
                    cache = await self.scanner.get_cached_data(force_refresh=False)
                    rendered = self._render_checkpoints_page(
                        request, folders=cache.folders, is_initializing=False
                    )
                    logger.debug(
                        "Checkpoints page loaded successfully with %d items", len(cache.raw_data)
                    )
                except Exception as cache_error:
                    logger.error(
                        "Error loading checkpoints cache data: %s", cache_error, exc_info=True
                    )
                    # If the cache cannot be loaded, fall back to the initialization page.
                    rendered = self._render_checkpoints_page(
                        request, folders=[], is_initializing=True
                    )
                    logger.info("Checkpoints cache error, returning initialization page")

            return web.Response(
                text=rendered,
                content_type='text/html'
            )
        except Exception as e:
            logger.error("Error handling checkpoints request: %s", e, exc_info=True)
            return web.Response(
                text="Error loading checkpoints page",
                status=500
            )