feat: Add download progress endpoint and implement progress tracking in WebSocketManager

This commit is contained in:
Will Miao
2025-07-12 10:11:16 +08:00
parent f73b3422a6
commit 0dbb76e8c8
2 changed files with 75 additions and 1 deletions

View File

@@ -2,6 +2,8 @@ import logging
from aiohttp import web
from typing import Set, Dict, Optional
from uuid import uuid4
import asyncio
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
@@ -12,6 +14,8 @@ class WebSocketManager:
self._websockets: Set[web.WebSocketResponse] = set()
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
# Add progress tracking dictionary
self._download_progress: Dict[str, Dict] = {}
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection"""
@@ -69,8 +73,19 @@ class WebSocketManager:
finally:
if download_id in self._download_websockets:
del self._download_websockets[download_id]
# Schedule cleanup of completed downloads after WebSocket disconnection
asyncio.create_task(self._delayed_cleanup(download_id))
return ws
async def _delayed_cleanup(self, download_id: str, delay_seconds: int = 300):
"""Clean up download progress after a delay (5 minutes by default)"""
await asyncio.sleep(delay_seconds)
progress_data = self._download_progress.get(download_id)
if progress_data and progress_data.get('progress', 0) >= 100:
self.cleanup_download_progress(download_id)
logger.debug(f"Delayed cleanup completed for download {download_id}")
async def broadcast(self, data: Dict):
"""Broadcast message to all connected clients"""
if not self._websockets:
@@ -103,6 +118,12 @@ class WebSocketManager:
async def broadcast_download_progress(self, download_id: str, data: Dict):
"""Send progress update to specific download client"""
# Store simplified progress data in memory (only progress percentage)
self._download_progress[download_id] = {
'progress': data.get('progress', 0),
'timestamp': datetime.now()
}
if download_id not in self._download_websockets:
logger.debug(f"No WebSocket found for download ID: {download_id}")
return
@@ -113,6 +134,27 @@ class WebSocketManager:
except Exception as e:
logger.error(f"Error sending download progress: {e}")
def get_download_progress(self, download_id: str) -> Optional[Dict]:
"""Get progress information for a specific download"""
return self._download_progress.get(download_id)
def cleanup_download_progress(self, download_id: str):
"""Remove progress info for a specific download"""
self._download_progress.pop(download_id, None)
def cleanup_old_downloads(self, max_age_hours: int = 24):
"""Clean up old download progress entries"""
cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
to_remove = []
for download_id, progress_data in self._download_progress.items():
if progress_data.get('timestamp', datetime.now()) < cutoff_time:
to_remove.append(download_id)
for download_id in to_remove:
self._download_progress.pop(download_id, None)
logger.debug(f"Cleaned up old download progress for {download_id}")
def get_connected_clients_count(self) -> int:
"""Get number of connected clients"""
return len(self._websockets)