From ea9370443d2de82138258bc8b5bed9fbef84bf40 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 23 Jul 2025 16:11:02 +0800 Subject: [PATCH] refactor: Implement download management routes and update API endpoints for LoRA --- py/routes/base_model_routes.py | 113 ++++++++++++++++ py/routes/lora_routes.py | 121 ------------------ py/utils/routes_common.py | 11 +- static/js/api/baseModelApi.js | 6 +- static/js/api/loraApi.js | 2 +- .../ContextMenu/ModelContextMenuMixin.js | 2 +- 6 files changed, 124 insertions(+), 131 deletions(-) diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index 8f5bc3f6..fc512d0c 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +import asyncio import json import logging from aiohttp import web @@ -7,6 +8,7 @@ from typing import Dict import jinja2 from ..utils.routes_common import ModelRouteUtils +from ..services.service_registry import ServiceRegistry from ..services.websocket_manager import ws_manager from ..services.settings_manager import settings from ..config import config @@ -55,6 +57,12 @@ class BaseModelRoutes(ABC): app.router.add_get(f'/api/{prefix}/roots', self.get_model_roots) app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models) app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts) + + # Common Download management + app.router.add_post(f'/api/download-model', self.download_model) + app.router.add_get(f'/api/download-model-get', self.download_model_get) + app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get) + app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress) # CivitAI integration routes app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai) @@ -408,6 +416,111 @@ class BaseModelRoutes(ABC): "success": False, "error": str(e) }, status=500) + + # Download management methods + async def download_model(self, request: web.Request) -> web.Response: + """Handle model download request""" + return await ModelRouteUtils.handle_download_model(request) + + async def download_model_get(self, request: web.Request) -> web.Response: + """Handle model download request via GET method""" + try: + # Extract query parameters + model_id = request.query.get('model_id') + if not model_id: + return web.Response( + status=400, + text="Missing required parameter: Please provide 'model_id'" + ) + + # Get optional parameters + model_version_id = request.query.get('model_version_id') + download_id = request.query.get('download_id') + use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true' + + # Create a data dictionary that mimics what would be received from a POST request + data = { + 'model_id': model_id + } + + # Add optional parameters only if they are provided + if model_version_id: + data['model_version_id'] = model_version_id + + if download_id: + data['download_id'] = download_id + + data['use_default_paths'] = use_default_paths + + # Create a mock request object with the data + future = asyncio.get_event_loop().create_future() + future.set_result(data) + + mock_request = type('MockRequest', (), { + 'json': lambda self=None: future + })() + + # Call the existing download handler + return await ModelRouteUtils.handle_download_model(mock_request) + + except Exception as e: + error_message = str(e) + logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) + return web.Response(status=500, text=error_message) + + async def cancel_download_get(self, request: web.Request) -> web.Response: + """Handle GET request for cancelling a download by download_id""" + try: + download_id = request.query.get('download_id') + if not download_id: + return web.json_response({ + 'success': False, + 'error': 'Download ID is required' + }, status=400) + + # Create a mock request with match_info for compatibility + mock_request = type('MockRequest', (), { + 'match_info': {'download_id': download_id} + })() + return await ModelRouteUtils.handle_cancel_download(mock_request) + except Exception as e: + logger.error(f"Error cancelling download via GET: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + async def get_download_progress(self, request: web.Request) -> web.Response: + """Handle request for download progress by download_id""" + try: + # Get download_id from URL path + download_id = request.match_info.get('download_id') + if not download_id: + return web.json_response({ + 'success': False, + 'error': 'Download ID is required' + }, status=400) + + # Get progress information from websocket manager + from ..services.websocket_manager import ws_manager + progress_data = ws_manager.get_download_progress(download_id) + + if progress_data is None: + return web.json_response({ + 'success': False, + 'error': 'Download ID not found' + }, status=404) + + return web.json_response({ + 'success': True, + 'progress': progress_data.get('progress', 0) + }) + except Exception as e: + logger.error(f"Error getting download progress: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) async def fetch_all_civitai(self, request: web.Request) -> web.Response: """Fetch CivitAI metadata for all models in the background""" diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index 0a9d7dff..fde5586b 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -20,8 +20,6 @@ class LoraRoutes(BaseModelRoutes): # Service will be initialized later via setup_routes self.service = None self.civitai_client = None - self.download_manager = None - self._download_lock = asyncio.Lock() self.template_name = "loras.html" async def initialize_services(self): @@ -29,7 +27,6 @@ class LoraRoutes(BaseModelRoutes): lora_scanner = await ServiceRegistry.get_lora_scanner() self.service = LoraService(lora_scanner) self.civitai_client = await ServiceRegistry.get_civitai_client() - self.download_manager = await ServiceRegistry.get_download_manager() # Initialize parent with the service super().__init__(self.service) @@ -63,21 +60,8 @@ class LoraRoutes(BaseModelRoutes): app.router.add_get(f'/api/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version) app.router.add_get(f'/api/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash) - # Download management - app.router.add_post(f'/api/download-model', self.download_model) - app.router.add_get(f'/api/download-model-get', self.download_model_get) - app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get) - app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress) - # ComfyUI integration app.router.add_post(f'/loramanager/get_trigger_words', self.get_trigger_words) - - # Legacy API compatibility - app.router.add_post(f'/api/delete_model', self.delete_model) - app.router.add_post(f'/api/fetch-civitai', self.fetch_civitai) - app.router.add_post(f'/api/relink-civitai', self.relink_civitai) - app.router.add_post(f'/api/replace_preview', self.replace_preview) - app.router.add_post(f'/api/fetch-all-civitai', self.fetch_all_civitai) def _parse_specific_params(self, request: web.Request) -> Dict: """Parse LoRA-specific parameters""" @@ -358,111 +342,6 @@ class LoraRoutes(BaseModelRoutes): "error": str(e) }, status=500) - # Download management methods - async def download_model(self, request: web.Request) -> web.Response: - """Handle model download request""" - return await ModelRouteUtils.handle_download_model(request, self.download_manager) - - async def download_model_get(self, request: web.Request) -> web.Response: - """Handle model download request via GET method""" - try: - # Extract query parameters - model_id = request.query.get('model_id') - if not model_id: - return web.Response( - status=400, - text="Missing required parameter: Please provide 'model_id'" - ) - - # Get optional parameters - model_version_id = request.query.get('model_version_id') - download_id = request.query.get('download_id') - use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true' - - # Create a data dictionary that mimics what would be received from a POST request - data = { - 'model_id': model_id - } - - # Add optional parameters only if they are provided - if model_version_id: - data['model_version_id'] = model_version_id - - if download_id: - data['download_id'] = download_id - - data['use_default_paths'] = use_default_paths - - # Create a mock request object with the data - future = asyncio.get_event_loop().create_future() - future.set_result(data) - - mock_request = type('MockRequest', (), { - 'json': lambda self=None: future - })() - - # Call the existing download handler - return await ModelRouteUtils.handle_download_model(mock_request, self.download_manager) - - except Exception as e: - error_message = str(e) - logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) - return web.Response(status=500, text=error_message) - - async def cancel_download_get(self, request: web.Request) -> web.Response: - """Handle GET request for cancelling a download by download_id""" - try: - download_id = request.query.get('download_id') - if not download_id: - return web.json_response({ - 'success': False, - 'error': 'Download ID is required' - }, status=400) - - # Create a mock request with match_info for compatibility - mock_request = type('MockRequest', (), { - 'match_info': {'download_id': download_id} - })() - return await ModelRouteUtils.handle_cancel_download(mock_request, self.download_manager) - except Exception as e: - logger.error(f"Error cancelling download via GET: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_download_progress(self, request: web.Request) -> web.Response: - """Handle request for download progress by download_id""" - try: - # Get download_id from URL path - download_id = request.match_info.get('download_id') - if not download_id: - return web.json_response({ - 'success': False, - 'error': 'Download ID is required' - }, status=400) - - # Get progress information from websocket manager - from ..services.websocket_manager import ws_manager - progress_data = ws_manager.get_download_progress(download_id) - - if progress_data is None: - return web.json_response({ - 'success': False, - 'error': 'Download ID not found' - }, status=404) - - return web.json_response({ - 'success': True, - 'progress': progress_data.get('progress', 0) - }) - except Exception as e: - logger.error(f"Error getting download progress: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - # Model management methods async def move_model(self, request: web.Request) -> web.Response: """Handle model move request""" diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 8e5df544..2704b20f 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -566,9 +566,10 @@ class ModelRouteUtils: return web.Response(text=str(e), status=500) @staticmethod - async def handle_download_model(request: web.Request, download_manager: DownloadManager) -> web.Response: + async def handle_download_model(request: web.Request) -> web.Response: """Handle model download request""" try: + download_manager = await ServiceRegistry.get_download_manager() data = await request.json() # Get or generate a download ID @@ -663,17 +664,17 @@ class ModelRouteUtils: }, status=500) @staticmethod - async def handle_cancel_download(request: web.Request, download_manager: DownloadManager) -> web.Response: + async def handle_cancel_download(request: web.Request) -> web.Response: """Handle cancellation of a download task Args: request: The aiohttp request - download_manager: The download manager instance Returns: web.Response: The HTTP response """ try: + download_manager = await ServiceRegistry.get_download_manager() download_id = request.match_info.get('download_id') if not download_id: return web.json_response({ @@ -701,17 +702,17 @@ class ModelRouteUtils: }, status=500) @staticmethod - async def handle_list_downloads(request: web.Request, download_manager: DownloadManager) -> web.Response: + async def handle_list_downloads(request: web.Request) -> web.Response: """Get list of active downloads Args: request: The aiohttp request - download_manager: The download manager instance Returns: web.Response: The HTTP response with list of downloads """ try: + download_manager = await ServiceRegistry.get_download_manager() result = await download_manager.get_active_downloads() return web.json_response(result) except Exception as e: diff --git a/static/js/api/baseModelApi.js b/static/js/api/baseModelApi.js index 43abbedf..e15898a5 100644 --- a/static/js/api/baseModelApi.js +++ b/static/js/api/baseModelApi.js @@ -277,7 +277,7 @@ export async function deleteModel(filePath, modelType = 'lora') { const endpoint = modelType === 'checkpoint' ? '/api/checkpoints/delete' - : '/api/delete_model'; + : '/api/loras/delete'; const response = await fetch(endpoint, { method: 'POST', @@ -454,7 +454,7 @@ export async function refreshSingleModelMetadata(filePath, modelType = 'lora') { const endpoint = modelType === 'checkpoint' ? '/api/checkpoints/fetch-civitai' - : '/api/fetch-civitai'; + : '/api/loras/fetch-civitai'; const response = await fetch(endpoint, { method: 'POST', @@ -557,7 +557,7 @@ export async function uploadPreview(filePath, file, modelType = 'lora', nsfwLeve // Set endpoint based on model type const endpoint = modelType === 'checkpoint' ? '/api/checkpoints/replace-preview' - : '/api/replace_preview'; + : '/api/loras/replace_preview'; const response = await fetch(endpoint, { method: 'POST', diff --git a/static/js/api/loraApi.js b/static/js/api/loraApi.js index 3c6f0c86..9d5dd1bc 100644 --- a/static/js/api/loraApi.js +++ b/static/js/api/loraApi.js @@ -89,7 +89,7 @@ export async function fetchLorasPage(page = 1, pageSize = 100) { export async function fetchCivitai() { return fetchCivitaiMetadata({ modelType: 'lora', - fetchEndpoint: '/api/fetch-all-civitai', + fetchEndpoint: '/api/loras/fetch-all-civitai', resetAndReloadFunction: resetAndReload }); } diff --git a/static/js/components/ContextMenu/ModelContextMenuMixin.js b/static/js/components/ContextMenu/ModelContextMenuMixin.js index 7c0bc0c3..cd58bc0f 100644 --- a/static/js/components/ContextMenu/ModelContextMenuMixin.js +++ b/static/js/components/ContextMenu/ModelContextMenuMixin.js @@ -125,7 +125,7 @@ export const ModelContextMenuMixin = { const endpoint = this.modelType === 'checkpoint' ? '/api/checkpoints/relink-civitai' : - '/api/relink-civitai'; + '/api/loras/relink-civitai'; const response = await fetch(endpoint, { method: 'POST',