Files
ComfyUI-Lora-Manager/py/routes/checkpoint_routes.py
Will Miao a2b81ea099 refactor: Implement base model routes and services for LoRA and Checkpoint
- Added BaseModelRoutes class to handle common routes and logic for model types.
- Created CheckpointRoutes class inheriting from BaseModelRoutes for checkpoint-specific routes.
- Implemented CheckpointService class for handling checkpoint-related data and operations.
- Developed LoraService class for managing LoRA-specific functionalities.
- Introduced ModelServiceFactory to manage service and route registrations for different model types.
- Established methods for fetching, filtering, and formatting model data across services.
- Integrated CivitAI metadata handling within model routes and services.
- Added pagination and filtering capabilities for model data retrieval.
2025-07-23 14:39:02 +08:00

170 lines
7.9 KiB
Python

import jinja2
import logging
from aiohttp import web
from .base_model_routes import BaseModelRoutes
from ..services.checkpoint_service import CheckpointService
from ..services.service_registry import ServiceRegistry
from ..config import config
from ..services.settings_manager import settings
logger = logging.getLogger(__name__)
class CheckpointRoutes(BaseModelRoutes):
"""Checkpoint-specific route controller"""
def __init__(self):
"""Initialize Checkpoint routes with Checkpoint service"""
# Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
self.service = CheckpointService(checkpoint_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Initialize parent with the service
super().__init__(self.service)
def setup_routes(self, app: web.Application):
"""Setup Checkpoint routes"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
# Setup common routes with 'checkpoints' prefix
super().setup_routes(app, 'checkpoints')
def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup Checkpoint-specific routes"""
# Checkpoint page route
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
# Checkpoint-specific CivitAI integration
app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint)
# Checkpoint info by name
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
"""Handle GET /checkpoints request"""
try:
# Check if the CheckpointScanner is initializing
# It's initializing if the cache object doesn't exist yet,
# OR if the scanner explicitly says it's initializing (background task running).
is_initializing = (
self.service.scanner._cache is None or
(hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing)
)
if is_initializing:
# If still initializing, return loading page
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[], # Empty folder list
is_initializing=True, # New flag
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
logger.info("Checkpoints page is initializing, returning loading page")
else:
# Normal flow - get initialized cache data
try:
cache = await self.service.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=cache.folders,
is_initializing=False,
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
except Exception as cache_error:
logger.error(f"Error loading checkpoints cache data: {cache_error}")
# If getting cache fails, also show initialization page
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Checkpoints cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
return web.Response(
text="Error loading checkpoints page",
status=500
)
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
"""Get detailed information for a specific checkpoint by name"""
try:
name = request.match_info.get('name', '')
checkpoint_info = await self.service.get_model_info_by_name(name)
if checkpoint_info:
return web.json_response(checkpoint_info)
else:
return web.json_response({"error": "Checkpoint not found"}, status=404)
except Exception as e:
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai checkpoint model with local availability info"""
try:
model_id = request.match_info['model_id']
response = await self.civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be Checkpoint
if model_type.lower() != 'checkpoint':
return web.json_response({
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model' and file.get('primary') == True), None)
# If no primary file found, try to find any model file
if not model_file:
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching checkpoint model versions: {e}")
return web.Response(status=500, text=str(e))