mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
- Add checkpoint hash parameter parsing to backend routes - Implement checkpoint hash filtering in frontend API client - Add click navigation from recipe modal to checkpoints page - Update checkpoint items to use pointer cursor for better UX Checkpoint items in recipe modal are now clickable and will navigate to the checkpoints page with appropriate hash filtering applied. This improves user workflow when wanting to view checkpoint details from recipes.
113 lines
4.5 KiB
Python
113 lines
4.5 KiB
Python
import logging
|
|
from typing import Dict
|
|
from aiohttp import web
|
|
|
|
from .base_model_routes import BaseModelRoutes
|
|
from .model_route_registrar import ModelRouteRegistrar
|
|
from ..services.checkpoint_service import CheckpointService
|
|
from ..services.service_registry import ServiceRegistry
|
|
from ..config import config
|
|
|
|
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
class CheckpointRoutes(BaseModelRoutes):
    """Checkpoint-specific route controller.

    Registers the shared model routes from ``BaseModelRoutes`` under the
    ``checkpoints`` prefix. It adds endpoints for looking up a checkpoint
    by name and for listing the configured checkpoint and UNet root
    directories.
    """

    def __init__(self):
        """Initialize Checkpoint routes.

        The Checkpoint service is not created here. It is resolved later by
        :meth:`initialize_services`, which :meth:`setup_routes` schedules on
        application startup.
        """
        super().__init__()
        self.template_name = "checkpoints.html"

    async def initialize_services(self):
        """Resolve services from ServiceRegistry and attach them to this controller."""
        checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
        update_service = await ServiceRegistry.get_model_update_service()
        self.service = CheckpointService(checkpoint_scanner, update_service=update_service)
        self.set_model_update_service(update_service)

        # Attach service dependencies
        self.attach_service(self.service)

    def setup_routes(self, app: web.Application):
        """Register Checkpoint routes on ``app``.

        Args:
            app: The aiohttp application to register routes and the startup
                hook on.
        """
        # Schedule service initialization on app startup. The lambda returns
        # the initialize_services() coroutine for the startup signal to run.
        app.on_startup.append(lambda _: self.initialize_services())

        # Setup common routes with 'checkpoints' prefix (includes page route)
        super().setup_routes(app, 'checkpoints')

    def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
        """Register the Checkpoint-only endpoints.

        Args:
            registrar: Registrar used to add prefixed routes.
            prefix: Route prefix substituted into ``{prefix}`` in each path.
        """
        # Checkpoint info by name
        registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_checkpoint_info)

        # Checkpoint roots and Unet roots
        registrar.add_prefixed_route('GET', '/api/lm/{prefix}/checkpoints_roots', prefix, self.get_checkpoints_roots)
        registrar.add_prefixed_route('GET', '/api/lm/{prefix}/unet_roots', prefix, self.get_unet_roots)

    def _validate_civitai_model_type(self, model_type: str) -> bool:
        """Return True if ``model_type`` names a Checkpoint (case-insensitive).

        A missing or empty type (``None`` or ``""``) is treated as "not a
        checkpoint" rather than raising ``AttributeError``.
        """
        if not model_type:
            return False
        return model_type.lower() == 'checkpoint'

    def _get_expected_model_types(self) -> str:
        """Get expected model types string for error messages"""
        return "Checkpoint"

    def _parse_specific_params(self, request: web.Request) -> Dict:
        """Parse Checkpoint-specific query parameters into filter params.

        Recognized query keys:

        * ``checkpoint_hash``: a single hash. It produces
          ``{'hash_filters': {'single_hash': <hash>}}``.
        * ``checkpoint_hashes``: a comma-separated list. It produces
          ``{'hash_filters': {'multiple_hashes': [<hash>, ...]}}`` and is
          used only when no non-empty ``checkpoint_hash`` is given.

        Hashes are lower-cased and whitespace-stripped. Blank entries (for
        example from ``a,,b`` or a trailing comma) are dropped, so they never
        become empty-string filters. If nothing usable remains, no
        ``hash_filters`` key is set.

        Returns:
            A dict of extra params (possibly empty).
        """
        params: Dict = {}

        single_hash = request.query.get('checkpoint_hash', '').strip().lower()
        if single_hash:
            params['hash_filters'] = {'single_hash': single_hash}
            return params

        raw_hashes = request.query.get('checkpoint_hashes', '')
        hashes = [h.strip().lower() for h in raw_hashes.split(',') if h.strip()]
        if hashes:
            params['hash_filters'] = {'multiple_hashes': hashes}

        return params

    async def get_checkpoint_info(self, request: web.Request) -> web.Response:
        """Return detailed information for the checkpoint named in the URL.

        Responds with the service's info payload, or 404 if
        ``get_model_info_by_name`` returns a falsy value. Any exception
        produces a 500 JSON response containing the error message.
        """
        try:
            name = request.match_info.get('name', '')
            checkpoint_info = await self.service.get_model_info_by_name(name)

            if checkpoint_info:
                return web.json_response(checkpoint_info)
            return web.json_response({"error": "Checkpoint not found"}, status=404)

        except Exception as e:
            logger.error("Error in get_checkpoint_info: %s", e, exc_info=True)
            return web.json_response({"error": str(e)}, status=500)

    def _config_roots_response(self, attr_name: str, label: str) -> web.Response:
        """Build the JSON response for a roots-listing endpoint.

        Args:
            attr_name: Name of the ``config`` attribute holding the roots.
            label: Human-readable kind used in the error log line.

        Returns:
            ``{"success": True, "roots": ...}``, or a 500 response with
            ``{"success": False, "error": ...}`` if reading or serializing
            the roots raises.
        """
        try:
            roots = getattr(config, attr_name)
            return web.json_response({
                "success": True,
                "roots": roots
            })
        except Exception as e:
            logger.error("Error getting %s roots: %s", label, e, exc_info=True)
            return web.json_response({
                "success": False,
                "error": str(e)
            }, status=500)

    async def get_checkpoints_roots(self, request: web.Request) -> web.Response:
        """Return the list of checkpoint roots from config"""
        return self._config_roots_response('checkpoints_roots', 'checkpoint')

    async def get_unet_roots(self, request: web.Request) -> web.Response:
        """Return the list of unet roots from config"""
        return self._config_roots_response('unet_roots', 'unet')
|