refactor: Enhance checkpoint download functionality with new modal and manager integration

This commit is contained in:
Will Miao
2025-04-11 18:25:37 +08:00
parent 3df96034a1
commit 1db49a4dd4
10 changed files with 699 additions and 43 deletions

View File

@@ -3,12 +3,14 @@ import json
import jinja2
from aiohttp import web
import logging
import asyncio
from ..utils.routes_common import ModelRouteUtils
from ..utils.constants import NSFW_LEVELS
from ..services.civitai_client import CivitaiClient
from ..services.websocket_manager import ws_manager
from ..services.checkpoint_scanner import CheckpointScanner
from ..services.download_manager import DownloadManager
from ..config import config
from ..services.settings_manager import settings
from ..utils.utils import fuzzy_match
@@ -24,6 +26,8 @@ class CheckpointsRoutes:
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
self.download_manager = DownloadManager()
self._download_lock = asyncio.Lock()
def setup_routes(self, app):
"""Register routes with the aiohttp app"""
@@ -34,11 +38,13 @@ class CheckpointsRoutes:
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info)
app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots)
# Add new routes for model management similar to LoRA routes
app.router.add_post('/api/checkpoints/delete', self.delete_model)
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
async def get_checkpoints(self, request):
"""Get paginated checkpoint data"""
@@ -478,3 +484,33 @@ class CheckpointsRoutes:
async def replace_preview(self, request: web.Request) -> web.Response:
"""Handle preview image replacement for checkpoints"""
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
async def download_checkpoint(self, request: web.Request) -> web.Response:
"""Handle checkpoint download request"""
async with self._download_lock:
# Initialize DownloadManager with the file monitor if the scanner has one
if not hasattr(self, 'download_manager') or self.download_manager is None:
file_monitor = getattr(self.scanner, 'file_monitor', None)
self.download_manager = DownloadManager(file_monitor)
# Use the common download handler with model_type="checkpoint"
return await ModelRouteUtils.handle_download_model(
request=request,
download_manager=self.download_manager,
model_type="checkpoint"
)
async def get_checkpoint_roots(self, request):
"""Return the checkpoint root directories"""
try:
roots = self.scanner.get_model_roots()
return web.json_response({
"success": True,
"roots": roots
})
except Exception as e:
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)