From ab22d16bad1c1f4097e412ca014852b6e5ac602f Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 2 Jul 2025 19:21:25 +0800 Subject: [PATCH] feat: Rename download endpoint from /api/download-lora to /api/download-model and update related logic --- py/routes/api_routes.py | 6 +++--- py/routes/checkpoints_routes.py | 5 ----- py/services/download_manager.py | 15 ++++++++++++--- py/utils/routes_common.py | 10 +++------- static/js/managers/CheckpointDownloadManager.js | 4 ++-- static/js/managers/DownloadManager.js | 4 ++-- static/js/managers/import/DownloadManager.js | 4 ++-- static/js/utils/routes.js | 2 +- 8 files changed, 25 insertions(+), 25 deletions(-) diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index 37562ad2..29747105 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -56,7 +56,7 @@ class ApiRoutes: app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions) app.router.add_get('/api/civitai/model/version/{modelVersionId}', routes.get_civitai_model_by_version) app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash) - app.router.add_post('/api/download-lora', routes.download_lora) + app.router.add_post('/api/download-model', routes.download_model) app.router.add_post('/api/move_model', routes.move_model) app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route app.router.add_post('/api/loras/save-metadata', routes.save_metadata) @@ -436,8 +436,8 @@ class ApiRoutes: "error": str(e) }, status=500) - async def download_lora(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_download_model(request, self.download_manager, model_type="lora") + async def download_model(self, request: web.Request) -> web.Response: + return await ModelRouteUtils.handle_download_model(request, self.download_manager) async def move_model(self, request: web.Request) -> web.Response: diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index 873139f2..4edb4456 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -54,7 +54,6 @@ class CheckpointsRoutes: app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai) app.router.add_post('/api/checkpoints/relink-civitai', self.relink_civitai) # Add new relink endpoint app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview) - app.router.add_post('/api/checkpoints/download', self.download_checkpoint) app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route app.router.add_post('/api/checkpoints/rename', self.rename_checkpoint) # Add new rename endpoint @@ -539,10 +538,6 @@ class CheckpointsRoutes: """Handle preview image replacement for checkpoints""" return await ModelRouteUtils.handle_replace_preview(request, self.scanner) - async def download_checkpoint(self, request: web.Request) -> web.Response: - """Handle checkpoint download request""" - return await ModelRouteUtils.handle_download_model(request, self.download_manager, model_type="checkpoint") - async def get_checkpoint_roots(self, request): """Return the checkpoint root directories""" try: diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 85b8bf52..a8048a26 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -1,10 +1,9 @@ import logging import os -import json import asyncio from typing import Dict from ..utils.models import LoraMetadata, CheckpointMetadata -from ..utils.constants import CARD_PREVIEW_WIDTH +from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager from .service_registry import ServiceRegistry @@ -51,7 +50,7 @@ class DownloadManager: async def download_from_civitai(self, model_id: str = None, model_version_id: str = None, save_dir: str = None, relative_path: str = '', progress_callback=None, - model_type: str = "lora") -> Dict: + model_type: str = None) -> Dict: """Download model from Civitai Args: @@ -81,6 +80,16 @@ class DownloadManager: if not version_info: return {'success': False, 'error': 'Failed to fetch model metadata'} + # Infer model_type if not provided + if model_type is None: + model_type_from_info = version_info.get('model', {}).get('type', '').lower() + if model_type_from_info == 'checkpoint': + model_type = 'checkpoint' + elif model_type_from_info in VALID_LORA_TYPES: + model_type = 'lora' + else: + return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'} + # Check if this is an early access model if version_info.get('earlyAccessEndsAt'): early_access_date = version_info.get('earlyAccessEndsAt', '') diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 992ad812..6d8af69f 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -564,7 +564,7 @@ class ModelRouteUtils: return web.Response(text=str(e), status=500) @staticmethod - async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response: + async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type=None) -> web.Response: """Handle model download request Args: @@ -597,14 +597,10 @@ class ModelRouteUtils: text="Missing required parameter: Please provide 'model_id'" ) - # Use the correct root directory based on model type - root_key = 'checkpoint_root' if model_type == 'checkpoint' else 'lora_root' - save_dir = data.get(root_key) - result = await download_manager.download_from_civitai( model_id=model_id, model_version_id=model_version_id, - save_dir=save_dir, + save_dir=data.get('model_root'), relative_path=data.get('relative_path', ''), progress_callback=progress_callback, model_type=model_type @@ -636,7 +632,7 @@ class ModelRouteUtils: text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com." ) - logger.error(f"Error downloading {model_type}: {error_message}") + logger.error(f"Error downloading model: {error_message}") return web.Response(status=500, text=error_message) @staticmethod diff --git a/static/js/managers/CheckpointDownloadManager.js b/static/js/managers/CheckpointDownloadManager.js index 0f69a484..49da7226 100644 --- a/static/js/managers/CheckpointDownloadManager.js +++ b/static/js/managers/CheckpointDownloadManager.js @@ -330,13 +330,13 @@ export class CheckpointDownloadManager { }; // Start download using checkpoint download endpoint - const response = await fetch('/api/checkpoints/download', { + const response = await fetch('/api/download-model', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ model_id: this.modelId, model_version_id: this.currentVersion.id, - checkpoint_root: checkpointRoot, + model_root: checkpointRoot, relative_path: targetFolder }) }); diff --git a/static/js/managers/DownloadManager.js b/static/js/managers/DownloadManager.js index 17f58f3d..38809e40 100644 --- a/static/js/managers/DownloadManager.js +++ b/static/js/managers/DownloadManager.js @@ -340,13 +340,13 @@ export class DownloadManager { }; // Start download - const response = await fetch('/api/download-lora', { + const response = await fetch('/api/download-model', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ model_id: this.modelId, model_version_id: this.currentVersion.id, - lora_root: loraRoot, + model_root: loraRoot, relative_path: targetFolder }) }); diff --git a/static/js/managers/import/DownloadManager.js b/static/js/managers/import/DownloadManager.js index 536ae3e0..0a21a805 100644 --- a/static/js/managers/import/DownloadManager.js +++ b/static/js/managers/import/DownloadManager.js @@ -189,14 +189,14 @@ export class DownloadManager { try { // Download the LoRA - const response = await fetch('/api/download-lora', { + const response = await fetch('/api/download-model', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ download_url: lora.downloadUrl, model_version_id: lora.modelVersionId, model_hash: lora.hash, - lora_root: loraRoot, + model_root: loraRoot, relative_path: targetPath.replace(loraRoot + '/', '') }) }); diff --git a/static/js/utils/routes.js b/static/js/utils/routes.js index b9d94f7b..9dba2b2c 100644 --- a/static/js/utils/routes.js +++ b/static/js/utils/routes.js @@ -7,7 +7,7 @@ export const apiRoutes = { delete: (id) => `/api/loras/${id}`, update: (id) => `/api/loras/${id}`, civitai: (id) => `/api/loras/${id}/civitai`, - download: '/api/download-lora', + download: '/api/download-model', move: '/api/move-lora', scan: '/api/scan-loras' },