feat: Add download progress endpoint and implement progress tracking in WebSocketManager

This commit is contained in:
Will Miao
2025-07-12 10:11:16 +08:00
parent f73b3422a6
commit 0dbb76e8c8
2 changed files with 75 additions and 1 deletions

View File

@@ -59,6 +59,7 @@ class ApiRoutes:
app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash) app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash)
app.router.add_post('/api/download-model', routes.download_model) app.router.add_post('/api/download-model', routes.download_model)
app.router.add_get('/api/download-model-get', routes.download_model_get) # Add new GET endpoint app.router.add_get('/api/download-model-get', routes.download_model_get) # Add new GET endpoint
app.router.add_get('/api/download-progress/{download_id}', routes.get_download_progress) # Add new endpoint for download progress
app.router.add_post('/api/move_model', routes.move_model) app.router.add_post('/api/move_model', routes.move_model)
app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route
app.router.add_post('/api/loras/save-metadata', routes.save_metadata) app.router.add_post('/api/loras/save-metadata', routes.save_metadata)
@@ -500,6 +501,37 @@ class ApiRoutes:
logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) logger.error(f"Error downloading model via GET: {error_message}", exc_info=True)
return web.Response(status=500, text=error_message) return web.Response(status=500, text=error_message)
async def get_download_progress(self, request: web.Request) -> web.Response:
"""Handle request for download progress by download_id"""
try:
# Get download_id from URL path
download_id = request.match_info.get('download_id')
if not download_id:
return web.json_response({
'success': False,
'error': 'Download ID is required'
}, status=400)
# Get progress information from websocket manager
progress_data = ws_manager.get_download_progress(download_id)
if progress_data is None:
return web.json_response({
'success': False,
'error': 'Download ID not found'
}, status=404)
return web.json_response({
'success': True,
'progress': progress_data.get('progress', 0)
})
except Exception as e:
logger.error(f"Error getting download progress: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def move_model(self, request: web.Request) -> web.Response: async def move_model(self, request: web.Request) -> web.Response:
"""Handle model move request""" """Handle model move request"""
try: try:

View File

@@ -2,6 +2,8 @@ import logging
from aiohttp import web from aiohttp import web
from typing import Set, Dict, Optional from typing import Set, Dict, Optional
from uuid import uuid4 from uuid import uuid4
import asyncio
from datetime import datetime, timedelta
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -12,6 +14,8 @@ class WebSocketManager:
self._websockets: Set[web.WebSocketResponse] = set() self._websockets: Set[web.WebSocketResponse] = set()
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
# Add progress tracking dictionary
self._download_progress: Dict[str, Dict] = {}
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse: async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection""" """Handle new WebSocket connection"""
@@ -69,8 +73,19 @@ class WebSocketManager:
finally: finally:
if download_id in self._download_websockets: if download_id in self._download_websockets:
del self._download_websockets[download_id] del self._download_websockets[download_id]
# Schedule cleanup of completed downloads after WebSocket disconnection
asyncio.create_task(self._delayed_cleanup(download_id))
return ws return ws
async def _delayed_cleanup(self, download_id: str, delay_seconds: int = 300):
"""Clean up download progress after a delay (5 minutes by default)"""
await asyncio.sleep(delay_seconds)
progress_data = self._download_progress.get(download_id)
if progress_data and progress_data.get('progress', 0) >= 100:
self.cleanup_download_progress(download_id)
logger.debug(f"Delayed cleanup completed for download {download_id}")
async def broadcast(self, data: Dict): async def broadcast(self, data: Dict):
"""Broadcast message to all connected clients""" """Broadcast message to all connected clients"""
if not self._websockets: if not self._websockets:
@@ -103,6 +118,12 @@ class WebSocketManager:
async def broadcast_download_progress(self, download_id: str, data: Dict): async def broadcast_download_progress(self, download_id: str, data: Dict):
"""Send progress update to specific download client""" """Send progress update to specific download client"""
# Store simplified progress data in memory (only progress percentage)
self._download_progress[download_id] = {
'progress': data.get('progress', 0),
'timestamp': datetime.now()
}
if download_id not in self._download_websockets: if download_id not in self._download_websockets:
logger.debug(f"No WebSocket found for download ID: {download_id}") logger.debug(f"No WebSocket found for download ID: {download_id}")
return return
@@ -113,6 +134,27 @@ class WebSocketManager:
except Exception as e: except Exception as e:
logger.error(f"Error sending download progress: {e}") logger.error(f"Error sending download progress: {e}")
def get_download_progress(self, download_id: str) -> Optional[Dict]:
"""Get progress information for a specific download"""
return self._download_progress.get(download_id)
def cleanup_download_progress(self, download_id: str):
"""Remove progress info for a specific download"""
self._download_progress.pop(download_id, None)
def cleanup_old_downloads(self, max_age_hours: int = 24):
"""Clean up old download progress entries"""
cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
to_remove = []
for download_id, progress_data in self._download_progress.items():
if progress_data.get('timestamp', datetime.now()) < cutoff_time:
to_remove.append(download_id)
for download_id in to_remove:
self._download_progress.pop(download_id, None)
logger.debug(f"Cleaned up old download progress for {download_id}")
def get_connected_clients_count(self) -> int: def get_connected_clients_count(self) -> int:
"""Get number of connected clients""" """Get number of connected clients"""
return len(self._websockets) return len(self._websockets)