From 7a4b5a466719ab7caff62d54d9f3501b0a3c83d4 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 2 Jul 2025 23:48:35 +0800 Subject: [PATCH] feat: Implement download progress WebSocket and enhance download manager with unique IDs --- py/routes/api_routes.py | 1 + py/services/download_manager.py | 53 ++++++++++----- py/services/websocket_manager.py | 67 +++++++++++++++---- py/utils/routes_common.py | 24 ++++--- .../js/managers/CheckpointDownloadManager.js | 22 ++++-- static/js/managers/DownloadManager.js | 25 +++++-- static/js/managers/import/DownloadManager.js | 25 +++++-- 7 files changed, 160 insertions(+), 57 deletions(-) diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index 29747105..7dfb3e8b 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -50,6 +50,7 @@ class ApiRoutes: app.router.add_get('/api/loras', routes.get_loras) app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai) app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) + app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) # Add new WebSocket route for download progress app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Add new WebSocket route app.router.add_get('/api/lora-roots', routes.get_lora_roots) app.router.add_get('/api/folders', routes.get_folders) diff --git a/py/services/download_manager.py b/py/services/download_manager.py index a8048a26..fc3b6448 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -7,6 +7,7 @@ from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager from .service_registry import ServiceRegistry +from .settings_manager import settings # Download to temporary file first import tempfile @@ -49,8 +50,7 @@ class DownloadManager: async def download_from_civitai(self, model_id: str = None, model_version_id: str = None, save_dir: str = None, - relative_path: str = '', progress_callback=None, - model_type: str = None) -> Dict: + relative_path: str = '', progress_callback=None, use_default_paths: bool = False) -> Dict: """Download model from Civitai Args: @@ -59,18 +59,12 @@ class DownloadManager: save_dir: Directory to save the model to relative_path: Relative path within save_dir progress_callback: Callback function for progress updates - model_type: Type of model ('lora' or 'checkpoint') + use_default_paths: Flag to indicate whether to use default paths Returns: Dict with download result """ try: - # Update save directory with relative path if provided - if relative_path: - save_dir = os.path.join(save_dir, relative_path) - # Create directory if it doesn't exist - os.makedirs(save_dir, exist_ok=True) - # Get civitai client civitai_client = await self._get_civitai_client() @@ -80,15 +74,38 @@ class DownloadManager: if not version_info: return {'success': False, 'error': 'Failed to fetch model metadata'} - # Infer model_type if not provided - if model_type is None: - model_type_from_info = version_info.get('model', {}).get('type', '').lower() - if model_type_from_info == 'checkpoint': - model_type = 'checkpoint' - elif model_type_from_info in VALID_LORA_TYPES: - model_type = 'lora' - else: - return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'} + model_type_from_info = version_info.get('model', {}).get('type', '').lower() + if model_type_from_info == 'checkpoint': + model_type = 'checkpoint' + elif model_type_from_info in VALID_LORA_TYPES: + model_type = 'lora' + else: + return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'} + + # Handle use_default_paths + if use_default_paths: + # Set save_dir based on model type + if model_type == 'checkpoint': + default_path = settings.get('default_checkpoint_root') + if not default_path: + return {'success': False, 'error': 'Default checkpoint root path not set in settings'} + save_dir = default_path + else: # model_type == 'lora' + default_path = settings.get('default_lora_root') + if not default_path: + return {'success': False, 'error': 'Default lora root path not set in settings'} + save_dir = default_path + + # Set relative_path to the first tag if available + model_tags = version_info.get('model', {}).get('tags', []) + if model_tags: + relative_path = model_tags[0] + + # Update save directory with relative path if provided + if relative_path: + save_dir = os.path.join(save_dir, relative_path) + # Create directory if it doesn't exist + os.makedirs(save_dir, exist_ok=True) # Check if this is an early access model if version_info.get('earlyAccessEndsAt'): diff --git a/py/services/websocket_manager.py b/py/services/websocket_manager.py index 1887ee1f..958f0e38 100644 --- a/py/services/websocket_manager.py +++ b/py/services/websocket_manager.py @@ -1,6 +1,7 @@ import logging from aiohttp import web from typing import Set, Dict, Optional +from uuid import uuid4 logger = logging.getLogger(__name__) @@ -10,7 +11,7 @@ class WebSocketManager: def __init__(self): self._websockets: Set[web.WebSocketResponse] = set() self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients - self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress + self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients async def handle_connection(self, request: web.Request) -> web.WebSocketResponse: """Handle new WebSocket connection""" @@ -39,6 +40,39 @@ class WebSocketManager: finally: self._init_websockets.discard(ws) return ws + + async def handle_download_connection(self, request: web.Request) -> web.WebSocketResponse: + """Handle new WebSocket connection for download progress""" + ws = web.WebSocketResponse() + await ws.prepare(request) + + # Get download_id from query parameters + download_id = request.query.get('id') + + if not download_id: + # Generate a new download ID if not provided + download_id = str(uuid4()) + logger.info(f"Created new download ID: {download_id}") + else: + logger.info(f"Using provided download ID: {download_id}") + + # Store the websocket with its download ID + self._download_websockets[download_id] = ws + + try: + # Send the download ID back to the client + await ws.send_json({ + 'type': 'download_id', + 'download_id': download_id + }) + + async for msg in ws: + if msg.type == web.WSMsgType.ERROR: + logger.error(f'Download WebSocket error: {ws.exception()}') + finally: + if download_id in self._download_websockets: + del self._download_websockets[download_id] + return ws async def broadcast(self, data: Dict): """Broadcast message to all connected clients""" @@ -70,17 +104,18 @@ class WebSocketManager: except Exception as e: logger.error(f"Error sending initialization progress: {e}") - async def broadcast_checkpoint_progress(self, data: Dict): - """Broadcast checkpoint download progress to connected clients""" - if not self._checkpoint_websockets: + async def broadcast_download_progress(self, download_id: str, data: Dict): + """Send progress update to specific download client""" + if download_id not in self._download_websockets: + logger.debug(f"No WebSocket found for download ID: {download_id}") return - for ws in self._checkpoint_websockets: - try: - await ws.send_json(data) - except Exception as e: - logger.error(f"Error sending checkpoint progress: {e}") - + ws = self._download_websockets[download_id] + try: + await ws.send_json(data) + except Exception as e: + logger.error(f"Error sending download progress: {e}") + def get_connected_clients_count(self) -> int: """Get number of connected clients""" return len(self._websockets) @@ -88,10 +123,14 @@ class WebSocketManager: def get_init_clients_count(self) -> int: """Get number of initialization progress clients""" return len(self._init_websockets) - - def get_checkpoint_clients_count(self) -> int: - """Get number of checkpoint progress clients""" - return len(self._checkpoint_websockets) + + def get_download_clients_count(self) -> int: + """Get number of download progress clients""" + return len(self._download_websockets) + + def generate_download_id(self) -> str: + """Generate a unique download ID""" + return str(uuid4()) # Global instance ws_manager = WebSocketManager() \ No newline at end of file diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index 34f5b233..e13b5cff 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -12,6 +12,7 @@ from ..services.service_registry import ServiceRegistry from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager from ..services.download_manager import DownloadManager +from ..services.websocket_manager import ws_manager logger = logging.getLogger(__name__) @@ -565,13 +566,12 @@ class ModelRouteUtils: return web.Response(text=str(e), status=500) @staticmethod - async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type=None) -> web.Response: + async def handle_download_model(request: web.Request, download_manager: DownloadManager) -> web.Response: """Handle model download request Args: request: The aiohttp request download_manager: Instance of DownloadManager - model_type: Type of model ('lora' or 'checkpoint') Returns: web.Response: The HTTP response @@ -579,12 +579,15 @@ class ModelRouteUtils: try: data = await request.json() - # Create progress callback + # Get or generate a download ID + download_id = data.get('download_id', ws_manager.generate_download_id()) + + # Create progress callback with download ID async def progress_callback(progress): - from ..services.websocket_manager import ws_manager - await ws_manager.broadcast({ + await ws_manager.broadcast_download_progress(download_id, { 'status': 'progress', - 'progress': progress + 'progress': progress, + 'download_id': download_id }) # Check which identifier is provided @@ -598,15 +601,20 @@ class ModelRouteUtils: text="Missing required parameter: Please provide 'model_id'" ) + use_default_paths = data.get('use_default_paths', False) + result = await download_manager.download_from_civitai( model_id=model_id, model_version_id=model_version_id, save_dir=data.get('model_root'), relative_path=data.get('relative_path', ''), - progress_callback=progress_callback, - model_type=model_type + use_default_paths=use_default_paths, + progress_callback=progress_callback ) + # Include download_id in the response + result['download_id'] = download_id + if not result.get('success', False): error_message = result.get('error', 'Unknown error') diff --git a/static/js/managers/CheckpointDownloadManager.js b/static/js/managers/CheckpointDownloadManager.js index c9167cfa..7545e617 100644 --- a/static/js/managers/CheckpointDownloadManager.js +++ b/static/js/managers/CheckpointDownloadManager.js @@ -301,13 +301,24 @@ export class CheckpointDownloadManager { const updateProgress = this.loadingManager.showDownloadProgress(1); updateProgress(0, 0, this.currentVersion.name); - // Setup WebSocket for progress updates using checkpoint-specific endpoint + // Generate a unique ID for this download + const downloadId = Date.now().toString(); + + // Setup WebSocket for progress updates using download-specific endpoint const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; - const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); + const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`); ws.onmessage = (event) => { const data = JSON.parse(event.data); - if (data.status === 'progress') { + + // Handle download ID confirmation + if (data.type === 'download_id') { + console.log(`Connected to checkpoint download progress with ID: ${data.download_id}`); + return; + } + + // Only process progress updates for our download + if (data.status === 'progress' && data.download_id === downloadId) { // Update progress display with current progress updateProgress(data.progress, 0, this.currentVersion.name); @@ -329,7 +340,7 @@ export class CheckpointDownloadManager { // Continue with download even if WebSocket fails }; - // Start download using checkpoint download endpoint + // Start download using checkpoint download endpoint with download ID const response = await fetch('/api/download-model', { method: 'POST', headers: { 'Content-Type': 'application/json' }, @@ -337,7 +348,8 @@ export class CheckpointDownloadManager { model_id: this.modelId, model_version_id: this.currentVersion.id, model_root: checkpointRoot, - relative_path: targetFolder + relative_path: targetFolder, + download_id: downloadId }) }); diff --git a/static/js/managers/DownloadManager.js b/static/js/managers/DownloadManager.js index 38809e40..b87a0e6b 100644 --- a/static/js/managers/DownloadManager.js +++ b/static/js/managers/DownloadManager.js @@ -311,13 +311,24 @@ export class DownloadManager { const updateProgress = this.loadingManager.showDownloadProgress(1); updateProgress(0, 0, this.currentVersion.name); - // Setup WebSocket for progress updates + // Generate a unique ID for this download + const downloadId = Date.now().toString(); + + // Setup WebSocket for progress updates - use download-specific endpoint const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; - const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); + const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`); ws.onmessage = (event) => { const data = JSON.parse(event.data); - if (data.status === 'progress') { + + // Handle download ID confirmation + if (data.type === 'download_id') { + console.log(`Connected to download progress with ID: ${data.download_id}`); + return; + } + + // Only process progress updates for our download + if (data.status === 'progress' && data.download_id === downloadId) { // Update progress display with current progress updateProgress(data.progress, 0, this.currentVersion.name); @@ -339,7 +350,7 @@ export class DownloadManager { // Continue with download even if WebSocket fails }; - // Start download + // Start download with our download ID const response = await fetch('/api/download-model', { method: 'POST', headers: { 'Content-Type': 'application/json' }, @@ -347,7 +358,8 @@ export class DownloadManager { model_id: this.modelId, model_version_id: this.currentVersion.id, model_root: loraRoot, - relative_path: targetFolder + relative_path: targetFolder, + download_id: downloadId }) }); @@ -358,6 +370,9 @@ export class DownloadManager { showToast('Download completed successfully', 'success'); modalManager.closeModal('downloadModal'); + // Close WebSocket after download completes + ws.close(); + // Update state and trigger reload with folder update state.activeFolder = targetFolder; await resetAndReload(true); // Pass true to update folders diff --git a/static/js/managers/import/DownloadManager.js b/static/js/managers/import/DownloadManager.js index 0a21a805..00e5d497 100644 --- a/static/js/managers/import/DownloadManager.js +++ b/static/js/managers/import/DownloadManager.js @@ -128,9 +128,12 @@ export class DownloadManager { targetPath += '/' + newFolder; } + // Generate a unique ID for this batch download + const batchDownloadId = Date.now().toString(); + // Set up WebSocket for progress updates const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; - const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); + const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${batchDownloadId}`); // Show enhanced loading with progress details for multiple items const updateProgress = this.importManager.loadingManager.showDownloadProgress( @@ -145,7 +148,15 @@ export class DownloadManager { // Set up progress tracking for current download ws.onmessage = (event) => { const data = JSON.parse(event.data); - if (data.status === 'progress') { + + // Handle download ID confirmation + if (data.type === 'download_id') { + console.log(`Connected to batch download progress with ID: ${data.download_id}`); + return; + } + + // Process progress updates for our current active download + if (data.status === 'progress' && data.download_id && data.download_id.startsWith(batchDownloadId)) { // Update current LoRA progress currentLoraProgress = data.progress; @@ -188,16 +199,16 @@ export class DownloadManager { updateProgress(0, completedDownloads, lora.name); try { - // Download the LoRA + // Download the LoRA with download ID const response = await fetch('/api/download-model', { method: 'POST', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify({ - download_url: lora.downloadUrl, - model_version_id: lora.modelVersionId, - model_hash: lora.hash, + model_id: lora.modelId, + model_version_id: lora.id, model_root: loraRoot, - relative_path: targetPath.replace(loraRoot + '/', '') + relative_path: targetPath.replace(loraRoot + '/', ''), + download_id: batchDownloadId }) });