From 0dbb76e8c8f7634c5256dd58dcee6d15c4d83efd Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sat, 12 Jul 2025 10:11:16 +0800 Subject: [PATCH] feat: Add download progress endpoint and implement progress tracking in WebSocketManager --- py/routes/api_routes.py | 32 +++++++++++++++++++++++ py/services/websocket_manager.py | 44 +++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index ee4bdb13..3a04815a 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -59,6 +59,7 @@ class ApiRoutes: app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash) app.router.add_post('/api/download-model', routes.download_model) app.router.add_get('/api/download-model-get', routes.download_model_get) # Add new GET endpoint + app.router.add_get('/api/download-progress/{download_id}', routes.get_download_progress) # Add new endpoint for download progress app.router.add_post('/api/move_model', routes.move_model) app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route app.router.add_post('/api/loras/save-metadata', routes.save_metadata) @@ -500,6 +501,37 @@ class ApiRoutes: logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) return web.Response(status=500, text=error_message) + async def get_download_progress(self, request: web.Request) -> web.Response: + """Handle request for download progress by download_id""" + try: + # Get download_id from URL path + download_id = request.match_info.get('download_id') + if not download_id: + return web.json_response({ + 'success': False, + 'error': 'Download ID is required' + }, status=400) + + # Get progress information from websocket manager + progress_data = ws_manager.get_download_progress(download_id) + + if progress_data is None: + return web.json_response({ + 'success': False, + 'error': 'Download ID not found' + }, status=404) + + return web.json_response({ + 'success': True, + 'progress': progress_data.get('progress', 0) + }) + except Exception as e: + logger.error(f"Error getting download progress: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + async def move_model(self, request: web.Request) -> web.Response: """Handle model move request""" try: diff --git a/py/services/websocket_manager.py b/py/services/websocket_manager.py index 1692fe54..b3d70811 100644 --- a/py/services/websocket_manager.py +++ b/py/services/websocket_manager.py @@ -2,6 +2,8 @@ import logging from aiohttp import web from typing import Set, Dict, Optional from uuid import uuid4 +import asyncio +from datetime import datetime, timedelta logger = logging.getLogger(__name__) @@ -12,6 +14,8 @@ class WebSocketManager: self._websockets: Set[web.WebSocketResponse] = set() self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients + # Add progress tracking dictionary + self._download_progress: Dict[str, Dict] = {} async def handle_connection(self, request: web.Request) -> web.WebSocketResponse: """Handle new WebSocket connection""" @@ -69,8 +73,19 @@ class WebSocketManager: finally: if download_id in self._download_websockets: del self._download_websockets[download_id] + + # Schedule cleanup of completed downloads after WebSocket disconnection + asyncio.create_task(self._delayed_cleanup(download_id)) return ws - + + async def _delayed_cleanup(self, download_id: str, delay_seconds: int = 300): + """Clean up download progress after a delay (5 minutes by default)""" + await asyncio.sleep(delay_seconds) + progress_data = self._download_progress.get(download_id) + if progress_data and progress_data.get('progress', 0) >= 100: + self.cleanup_download_progress(download_id) + logger.debug(f"Delayed cleanup completed for download {download_id}") + async def broadcast(self, data: Dict): """Broadcast message to all connected clients""" if not self._websockets: @@ -103,6 +118,12 @@ class WebSocketManager: async def broadcast_download_progress(self, download_id: str, data: Dict): """Send progress update to specific download client""" + # Store simplified progress data in memory (only progress percentage) + self._download_progress[download_id] = { + 'progress': data.get('progress', 0), + 'timestamp': datetime.now() + } + if download_id not in self._download_websockets: logger.debug(f"No WebSocket found for download ID: {download_id}") return @@ -113,6 +134,27 @@ class WebSocketManager: except Exception as e: logger.error(f"Error sending download progress: {e}") + def get_download_progress(self, download_id: str) -> Optional[Dict]: + """Get progress information for a specific download""" + return self._download_progress.get(download_id) + + def cleanup_download_progress(self, download_id: str): + """Remove progress info for a specific download""" + self._download_progress.pop(download_id, None) + + def cleanup_old_downloads(self, max_age_hours: int = 24): + """Clean up old download progress entries""" + cutoff_time = datetime.now() - timedelta(hours=max_age_hours) + to_remove = [] + + for download_id, progress_data in self._download_progress.items(): + if progress_data.get('timestamp', datetime.now()) < cutoff_time: + to_remove.append(download_id) + + for download_id in to_remove: + self._download_progress.pop(download_id, None) + logger.debug(f"Cleaned up old download progress for {download_id}") + def get_connected_clients_count(self) -> int: """Get number of connected clients""" return len(self._websockets)