mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
- Fix config.py: save and restore the main paths while processing extra folder paths, so that _prepare_checkpoint_paths no longer overwrites checkpoints_roots and unet_roots.
- Fix lora_manager.py: apply library settings during initialization so that extra folder paths are loaded in ComfyUI plugin mode.
- Fix checkpoint_routes.py: merge the checkpoints/unet roots with the extra paths in the API endpoints.
- Add logging for extra folder paths.

Fixes an issue where extra folder paths were not recognized for checkpoints and unet models.
133 lines
5.6 KiB
Python
133 lines
5.6 KiB
Python
import logging
|
|
from typing import Dict, List, Set
|
|
from aiohttp import web
|
|
|
|
from .base_model_routes import BaseModelRoutes
|
|
from .model_route_registrar import ModelRouteRegistrar
|
|
from ..services.checkpoint_service import CheckpointService
|
|
from ..services.service_registry import ServiceRegistry
|
|
from ..config import config
|
|
|
|
# Module-level logger; the route handlers in this module report failures through it.
logger = logging.getLogger(__name__)
|
|
|
|
class CheckpointRoutes(BaseModelRoutes):
    """Checkpoint-specific route controller.

    Registers the shared model routes under the ``checkpoints`` prefix plus
    checkpoint-only endpoints: info lookup by name and the checkpoint / UNet
    root listings (primary roots merged with extra folder paths).
    """

    def __init__(self):
        """Initialize Checkpoint routes; the service is attached in ``initialize_services``."""
        super().__init__()
        self.template_name = "checkpoints.html"

    async def initialize_services(self):
        """Initialize services from ServiceRegistry.

        Builds a ``CheckpointService`` from the registry's checkpoint scanner
        and model update service, registers the update service on this
        controller, and attaches the checkpoint service.
        """
        checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
        update_service = await ServiceRegistry.get_model_update_service()
        self.service = CheckpointService(checkpoint_scanner, update_service=update_service)
        self.set_model_update_service(update_service)

        # Attach service dependencies
        self.attach_service(self.service)

    def setup_routes(self, app: web.Application):
        """Register Checkpoint routes on ``app``.

        Service construction is deferred to application startup because the
        ServiceRegistry getters are coroutines.
        """
        # The startup signal passes the app to each receiver; the lambda ignores
        # it and returns the initialize_services() coroutine for the signal to await.
        app.on_startup.append(lambda _: self.initialize_services())

        # Setup common routes with 'checkpoints' prefix (includes page route)
        super().setup_routes(app, 'checkpoints')

    def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
        """Register the Checkpoint-only routes under ``prefix``."""
        # Checkpoint info by name
        registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_checkpoint_info)

        # Checkpoint roots and Unet roots
        registrar.add_prefixed_route('GET', '/api/lm/{prefix}/checkpoints_roots', prefix, self.get_checkpoints_roots)
        registrar.add_prefixed_route('GET', '/api/lm/{prefix}/unet_roots', prefix, self.get_unet_roots)

    def _validate_civitai_model_type(self, model_type: str) -> bool:
        """Return True when the CivitAI model type is a Checkpoint (case-insensitive)."""
        return model_type.lower() == 'checkpoint'

    def _get_expected_model_types(self) -> str:
        """Return the expected model type string used in error messages."""
        return "Checkpoint"

    def _parse_specific_params(self, request: web.Request) -> Dict:
        """Parse Checkpoint-specific query parameters.

        ``checkpoint_hash`` (single value) takes precedence over
        ``checkpoint_hashes`` (comma-separated). Hashes are lower-cased and
        stripped of surrounding whitespace; empty entries in the
        comma-separated list (e.g. from ``"a, b,"``) are discarded, since an
        empty hash can never identify a model.

        Returns:
            A dict containing ``hash_filters`` only when one of the hash
            parameters is present in the query; otherwise an empty dict.
        """
        params: Dict = {}

        if 'checkpoint_hash' in request.query:
            params['hash_filters'] = {
                'single_hash': request.query['checkpoint_hash'].strip().lower()
            }
        elif 'checkpoint_hashes' in request.query:
            raw_hashes = request.query['checkpoint_hashes'].split(',')
            params['hash_filters'] = {
                'multiple_hashes': [h.strip().lower() for h in raw_hashes if h.strip()]
            }

        return params

    async def get_checkpoint_info(self, request: web.Request) -> web.Response:
        """Return detailed information for a checkpoint looked up by name.

        Responds 404 when the service finds nothing, and 500 with the error
        text if the lookup raises.
        """
        try:
            name = request.match_info.get('name', '')
            checkpoint_info = await self.service.get_model_info_by_name(name)

            if checkpoint_info:
                return web.json_response(checkpoint_info)
            return web.json_response({"error": "Checkpoint not found"}, status=404)

        except Exception as e:
            logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
            return web.json_response({"error": str(e)}, status=500)

    @staticmethod
    def _merge_unique_roots(*root_lists) -> List[str]:
        """Concatenate root lists, dropping falsy entries and duplicates.

        First-occurrence order is preserved, so roots from earlier lists (the
        primary roots) stay ahead of later ones (the extra roots). A ``None``
        list is treated as empty.
        """
        seen: Set[str] = set()
        unique_roots: List[str] = []
        for roots in root_lists:
            for root in roots or []:
                if root and root not in seen:
                    seen.add(root)
                    unique_roots.append(root)
        return unique_roots

    async def get_checkpoints_roots(self, request: web.Request) -> web.Response:
        """Return the checkpoint roots from config, including extra paths."""
        try:
            roots = self._merge_unique_roots(
                config.checkpoints_roots,
                config.extra_checkpoints_roots,
            )
            return web.json_response({
                "success": True,
                "roots": roots,
            })
        except Exception as e:
            logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
            return web.json_response({
                "success": False,
                "error": str(e),
            }, status=500)

    async def get_unet_roots(self, request: web.Request) -> web.Response:
        """Return the unet roots from config, including extra paths."""
        try:
            roots = self._merge_unique_roots(
                config.unet_roots,
                config.extra_unet_roots,
            )
            return web.json_response({
                "success": True,
                "roots": roots,
            })
        except Exception as e:
            logger.error(f"Error getting unet roots: {e}", exc_info=True)
            return web.json_response({
                "success": False,
                "error": str(e),
            }, status=500)
|