feat: Implement download progress WebSocket and enhance download manager with unique IDs

This commit is contained in:
Will Miao
2025-07-02 23:48:35 +08:00
parent 49c4a4068b
commit 7a4b5a4667
7 changed files with 160 additions and 57 deletions

View File

@@ -1,6 +1,7 @@
import logging
from aiohttp import web
from typing import Set, Dict, Optional
from uuid import uuid4
logger = logging.getLogger(__name__)
@@ -10,7 +11,7 @@ class WebSocketManager:
def __init__(self):
self._websockets: Set[web.WebSocketResponse] = set()
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection"""
@@ -39,6 +40,39 @@ class WebSocketManager:
finally:
self._init_websockets.discard(ws)
return ws
async def handle_download_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection for download progress"""
ws = web.WebSocketResponse()
await ws.prepare(request)
# Get download_id from query parameters
download_id = request.query.get('id')
if not download_id:
# Generate a new download ID if not provided
download_id = str(uuid4())
logger.info(f"Created new download ID: {download_id}")
else:
logger.info(f"Using provided download ID: {download_id}")
# Store the websocket with its download ID
self._download_websockets[download_id] = ws
try:
# Send the download ID back to the client
await ws.send_json({
'type': 'download_id',
'download_id': download_id
})
async for msg in ws:
if msg.type == web.WSMsgType.ERROR:
logger.error(f'Download WebSocket error: {ws.exception()}')
finally:
if download_id in self._download_websockets:
del self._download_websockets[download_id]
return ws
async def broadcast(self, data: Dict):
"""Broadcast message to all connected clients"""
@@ -70,17 +104,18 @@ class WebSocketManager:
except Exception as e:
logger.error(f"Error sending initialization progress: {e}")
async def broadcast_checkpoint_progress(self, data: Dict):
"""Broadcast checkpoint download progress to connected clients"""
if not self._checkpoint_websockets:
async def broadcast_download_progress(self, download_id: str, data: Dict):
"""Send progress update to specific download client"""
if download_id not in self._download_websockets:
logger.debug(f"No WebSocket found for download ID: {download_id}")
return
for ws in self._checkpoint_websockets:
try:
await ws.send_json(data)
except Exception as e:
logger.error(f"Error sending checkpoint progress: {e}")
ws = self._download_websockets[download_id]
try:
await ws.send_json(data)
except Exception as e:
logger.error(f"Error sending download progress: {e}")
def get_connected_clients_count(self) -> int:
"""Get number of connected clients"""
return len(self._websockets)
@@ -88,10 +123,14 @@ class WebSocketManager:
def get_init_clients_count(self) -> int:
"""Get number of initialization progress clients"""
return len(self._init_websockets)
def get_checkpoint_clients_count(self) -> int:
"""Get number of checkpoint progress clients"""
return len(self._checkpoint_websockets)
def get_download_clients_count(self) -> int:
"""Get number of download progress clients"""
return len(self._download_websockets)
def generate_download_id(self) -> str:
"""Generate a unique download ID"""
return str(uuid4())
# Global instance
ws_manager = WebSocketManager()