mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: Implement download progress WebSocket and enhance download manager with unique IDs
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from typing import Set, Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -10,7 +11,7 @@ class WebSocketManager:
|
||||
def __init__(self):
|
||||
self._websockets: Set[web.WebSocketResponse] = set()
|
||||
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
|
||||
self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress
|
||||
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
|
||||
|
||||
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle new WebSocket connection"""
|
||||
@@ -39,6 +40,39 @@ class WebSocketManager:
|
||||
finally:
|
||||
self._init_websockets.discard(ws)
|
||||
return ws
|
||||
|
||||
async def handle_download_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle new WebSocket connection for download progress"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
# Get download_id from query parameters
|
||||
download_id = request.query.get('id')
|
||||
|
||||
if not download_id:
|
||||
# Generate a new download ID if not provided
|
||||
download_id = str(uuid4())
|
||||
logger.info(f"Created new download ID: {download_id}")
|
||||
else:
|
||||
logger.info(f"Using provided download ID: {download_id}")
|
||||
|
||||
# Store the websocket with its download ID
|
||||
self._download_websockets[download_id] = ws
|
||||
|
||||
try:
|
||||
# Send the download ID back to the client
|
||||
await ws.send_json({
|
||||
'type': 'download_id',
|
||||
'download_id': download_id
|
||||
})
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == web.WSMsgType.ERROR:
|
||||
logger.error(f'Download WebSocket error: {ws.exception()}')
|
||||
finally:
|
||||
if download_id in self._download_websockets:
|
||||
del self._download_websockets[download_id]
|
||||
return ws
|
||||
|
||||
async def broadcast(self, data: Dict):
|
||||
"""Broadcast message to all connected clients"""
|
||||
@@ -70,17 +104,18 @@ class WebSocketManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending initialization progress: {e}")
|
||||
|
||||
async def broadcast_checkpoint_progress(self, data: Dict):
|
||||
"""Broadcast checkpoint download progress to connected clients"""
|
||||
if not self._checkpoint_websockets:
|
||||
async def broadcast_download_progress(self, download_id: str, data: Dict):
|
||||
"""Send progress update to specific download client"""
|
||||
if download_id not in self._download_websockets:
|
||||
logger.debug(f"No WebSocket found for download ID: {download_id}")
|
||||
return
|
||||
|
||||
for ws in self._checkpoint_websockets:
|
||||
try:
|
||||
await ws.send_json(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending checkpoint progress: {e}")
|
||||
|
||||
ws = self._download_websockets[download_id]
|
||||
try:
|
||||
await ws.send_json(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending download progress: {e}")
|
||||
|
||||
def get_connected_clients_count(self) -> int:
|
||||
"""Get number of connected clients"""
|
||||
return len(self._websockets)
|
||||
@@ -88,10 +123,14 @@ class WebSocketManager:
|
||||
def get_init_clients_count(self) -> int:
|
||||
"""Get number of initialization progress clients"""
|
||||
return len(self._init_websockets)
|
||||
|
||||
def get_checkpoint_clients_count(self) -> int:
|
||||
"""Get number of checkpoint progress clients"""
|
||||
return len(self._checkpoint_websockets)
|
||||
|
||||
def get_download_clients_count(self) -> int:
|
||||
"""Get number of download progress clients"""
|
||||
return len(self._download_websockets)
|
||||
|
||||
def generate_download_id(self) -> str:
|
||||
"""Generate a unique download ID"""
|
||||
return str(uuid4())
|
||||
|
||||
# Global instance
|
||||
ws_manager = WebSocketManager()
|
||||
Reference in New Issue
Block a user