checkpoint

This commit is contained in:
Will Miao
2025-04-11 20:22:12 +08:00
parent 1db49a4dd4
commit 0618541527
13 changed files with 793 additions and 276 deletions

View File

@@ -11,6 +11,7 @@ from ..services.civitai_client import CivitaiClient
from ..services.websocket_manager import ws_manager
from ..services.checkpoint_scanner import CheckpointScanner
from ..services.download_manager import DownloadManager
from ..services.service_registry import ServiceRegistry
from ..config import config
from ..services.settings_manager import settings
from ..utils.utils import fuzzy_match
@@ -21,16 +22,24 @@ class CheckpointsRoutes:
"""API routes for checkpoint management"""
def __init__(self):
self.scanner = CheckpointScanner()
self.scanner = None # Will be initialized in setup_routes
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
self.download_manager = DownloadManager()
self.download_manager = None # Will be initialized in setup_routes
self._download_lock = asyncio.Lock()
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
self.download_manager = await ServiceRegistry.get_download_manager()
def setup_routes(self, app):
"""Register routes with the aiohttp app"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
app.router.add_get('/api/checkpoints', self.get_checkpoints)
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
@@ -488,10 +497,9 @@ class CheckpointsRoutes:
async def download_checkpoint(self, request: web.Request) -> web.Response:
"""Handle checkpoint download request"""
async with self._download_lock:
# Initialize DownloadManager with the file monitor if the scanner has one
if not hasattr(self, 'download_manager') or self.download_manager is None:
file_monitor = getattr(self.scanner, 'file_monitor', None)
self.download_manager = DownloadManager(file_monitor)
# Get the download manager from service registry if not already initialized
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
# Use the common download handler with model_type="checkpoint"
return await ModelRouteUtils.handle_download_model(
@@ -503,6 +511,9 @@ class CheckpointsRoutes:
async def get_checkpoint_roots(self, request):
"""Return the checkpoint root directories"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
roots = self.scanner.get_model_roots()
return web.json_response({
"success": True,