From 76fc9e5a3d74bda6b4ae7f6f7da322d685e5a1ea Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sun, 13 Apr 2025 21:31:01 +0800 Subject: [PATCH] feat: Add WebSocket support for checkpoint download progress and update related components --- py/routes/checkpoints_routes.py | 69 +++++++++++++++++-- py/services/websocket_manager.py | 30 ++++++++ .../js/managers/CheckpointDownloadManager.js | 4 +- 3 files changed, 95 insertions(+), 8 deletions(-) diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index 8cc35555..fa4e20ee 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -53,6 +53,9 @@ class CheckpointsRoutes: app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview) app.router.add_post('/api/checkpoints/download', self.download_checkpoint) app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route + + # Add new WebSocket endpoint for checkpoint progress + app.router.add_get('/ws/checkpoint-progress', ws_manager.handle_checkpoint_connection) async def get_checkpoints(self, request): """Get paginated checkpoint data""" @@ -501,12 +504,66 @@ class CheckpointsRoutes: if self.download_manager is None: self.download_manager = await ServiceRegistry.get_download_manager() - # Use the common download handler with model_type="checkpoint" - return await ModelRouteUtils.handle_download_model( - request=request, - download_manager=self.download_manager, - model_type="checkpoint" - ) + try: + data = await request.json() + + # Create progress callback that uses checkpoint-specific WebSocket + async def progress_callback(progress): + await ws_manager.broadcast_checkpoint_progress({ + 'status': 'progress', + 'progress': progress + }) + + # Check which identifier is provided + download_url = data.get('download_url') + model_hash = data.get('model_hash') + model_version_id = data.get('model_version_id') + + # Validate that at least one identifier is provided + if not any([download_url, model_hash, model_version_id]): + return web.Response( + status=400, + text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'" + ) + + result = await self.download_manager.download_from_civitai( + download_url=download_url, + model_hash=model_hash, + model_version_id=model_version_id, + save_dir=data.get('checkpoint_root'), + relative_path=data.get('relative_path', ''), + progress_callback=progress_callback, + model_type="checkpoint" + ) + + if not result.get('success', False): + error_message = result.get('error', 'Unknown error') + + # Return 401 for early access errors + if 'early access' in error_message.lower(): + logger.warning(f"Early access download failed: {error_message}") + return web.Response( + status=401, + text=f"Early Access Restriction: {error_message}" + ) + + return web.Response(status=500, text=error_message) + + return web.json_response(result) + + except Exception as e: + error_message = str(e) + + # Check if this might be an early access error + if '401' in error_message: + logger.warning(f"Early access error (401): {error_message}") + return web.Response( + status=401, + text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai." + ) + + logger.error(f"Error downloading checkpoint: {error_message}") + return web.Response(status=500, text=error_message) async def get_checkpoint_roots(self, request): """Return the checkpoint root directories""" diff --git a/py/services/websocket_manager.py b/py/services/websocket_manager.py index 8e35f601..c85aa3a2 100644 --- a/py/services/websocket_manager.py +++ b/py/services/websocket_manager.py @@ -10,6 +10,7 @@ class WebSocketManager: def __init__(self): self._websockets: Set[web.WebSocketResponse] = set() self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients + self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress async def handle_connection(self, request: web.Request) -> web.WebSocketResponse: """Handle new WebSocket connection""" @@ -39,6 +40,20 @@ class WebSocketManager: self._init_websockets.discard(ws) return ws + async def handle_checkpoint_connection(self, request: web.Request) -> web.WebSocketResponse: + """Handle new WebSocket connection for checkpoint download progress""" + ws = web.WebSocketResponse() + await ws.prepare(request) + self._checkpoint_websockets.add(ws) + + try: + async for msg in ws: + if msg.type == web.WSMsgType.ERROR: + logger.error(f'Checkpoint WebSocket error: {ws.exception()}') + finally: + self._checkpoint_websockets.discard(ws) + return ws + async def broadcast(self, data: Dict): """Broadcast message to all connected clients""" if not self._websockets: @@ -68,6 +83,17 @@ class WebSocketManager: await ws.send_json(data) except Exception as e: logger.error(f"Error sending initialization progress: {e}") + + async def broadcast_checkpoint_progress(self, data: Dict): + """Broadcast checkpoint download progress to connected clients""" + if not self._checkpoint_websockets: + return + + for ws in self._checkpoint_websockets: + try: + await ws.send_json(data) + except Exception as e: + logger.error(f"Error sending checkpoint progress: {e}") def get_connected_clients_count(self) -> int: """Get number of connected clients""" @@ -77,5 +103,9 @@ class WebSocketManager: """Get number of initialization progress clients""" return len(self._init_websockets) + def get_checkpoint_clients_count(self) -> int: + """Get number of checkpoint progress clients""" + return len(self._checkpoint_websockets) + # Global instance ws_manager = WebSocketManager() \ No newline at end of file diff --git a/static/js/managers/CheckpointDownloadManager.js b/static/js/managers/CheckpointDownloadManager.js index 5a13f116..e0d132b0 100644 --- a/static/js/managers/CheckpointDownloadManager.js +++ b/static/js/managers/CheckpointDownloadManager.js @@ -292,9 +292,9 @@ export class CheckpointDownloadManager { const updateProgress = this.loadingManager.showDownloadProgress(1); updateProgress(0, 0, this.currentVersion.name); - // Setup WebSocket for progress updates + // Setup WebSocket for progress updates using checkpoint-specific endpoint const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; - const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`); + const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/checkpoint-progress`); ws.onmessage = (event) => { const data = JSON.parse(event.data);