mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
feat: Add download progress endpoint and implement progress tracking in WebSocketManager
This commit is contained in:
@@ -59,6 +59,7 @@ class ApiRoutes:
|
|||||||
app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash)
|
app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash)
|
||||||
app.router.add_post('/api/download-model', routes.download_model)
|
app.router.add_post('/api/download-model', routes.download_model)
|
||||||
app.router.add_get('/api/download-model-get', routes.download_model_get) # Add new GET endpoint
|
app.router.add_get('/api/download-model-get', routes.download_model_get) # Add new GET endpoint
|
||||||
|
app.router.add_get('/api/download-progress/{download_id}', routes.get_download_progress) # Add new endpoint for download progress
|
||||||
app.router.add_post('/api/move_model', routes.move_model)
|
app.router.add_post('/api/move_model', routes.move_model)
|
||||||
app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route
|
app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route
|
||||||
app.router.add_post('/api/loras/save-metadata', routes.save_metadata)
|
app.router.add_post('/api/loras/save-metadata', routes.save_metadata)
|
||||||
@@ -500,6 +501,37 @@ class ApiRoutes:
|
|||||||
logger.error(f"Error downloading model via GET: {error_message}", exc_info=True)
|
logger.error(f"Error downloading model via GET: {error_message}", exc_info=True)
|
||||||
return web.Response(status=500, text=error_message)
|
return web.Response(status=500, text=error_message)
|
||||||
|
|
||||||
|
async def get_download_progress(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle request for download progress by download_id"""
|
||||||
|
try:
|
||||||
|
# Get download_id from URL path
|
||||||
|
download_id = request.match_info.get('download_id')
|
||||||
|
if not download_id:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Download ID is required'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Get progress information from websocket manager
|
||||||
|
progress_data = ws_manager.get_download_progress(download_id)
|
||||||
|
|
||||||
|
if progress_data is None:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Download ID not found'
|
||||||
|
}, status=404)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'progress': progress_data.get('progress', 0)
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting download progress: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
async def move_model(self, request: web.Request) -> web.Response:
|
async def move_model(self, request: web.Request) -> web.Response:
|
||||||
"""Handle model move request"""
|
"""Handle model move request"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ import logging
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from typing import Set, Dict, Optional
|
from typing import Set, Dict, Optional
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -12,6 +14,8 @@ class WebSocketManager:
|
|||||||
self._websockets: Set[web.WebSocketResponse] = set()
|
self._websockets: Set[web.WebSocketResponse] = set()
|
||||||
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
|
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
|
||||||
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
|
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
|
||||||
|
# Add progress tracking dictionary
|
||||||
|
self._download_progress: Dict[str, Dict] = {}
|
||||||
|
|
||||||
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||||
"""Handle new WebSocket connection"""
|
"""Handle new WebSocket connection"""
|
||||||
@@ -69,8 +73,19 @@ class WebSocketManager:
|
|||||||
finally:
|
finally:
|
||||||
if download_id in self._download_websockets:
|
if download_id in self._download_websockets:
|
||||||
del self._download_websockets[download_id]
|
del self._download_websockets[download_id]
|
||||||
|
|
||||||
|
# Schedule cleanup of completed downloads after WebSocket disconnection
|
||||||
|
asyncio.create_task(self._delayed_cleanup(download_id))
|
||||||
return ws
|
return ws
|
||||||
|
|
||||||
|
async def _delayed_cleanup(self, download_id: str, delay_seconds: int = 300):
|
||||||
|
"""Clean up download progress after a delay (5 minutes by default)"""
|
||||||
|
await asyncio.sleep(delay_seconds)
|
||||||
|
progress_data = self._download_progress.get(download_id)
|
||||||
|
if progress_data and progress_data.get('progress', 0) >= 100:
|
||||||
|
self.cleanup_download_progress(download_id)
|
||||||
|
logger.debug(f"Delayed cleanup completed for download {download_id}")
|
||||||
|
|
||||||
async def broadcast(self, data: Dict):
|
async def broadcast(self, data: Dict):
|
||||||
"""Broadcast message to all connected clients"""
|
"""Broadcast message to all connected clients"""
|
||||||
if not self._websockets:
|
if not self._websockets:
|
||||||
@@ -103,6 +118,12 @@ class WebSocketManager:
|
|||||||
|
|
||||||
async def broadcast_download_progress(self, download_id: str, data: Dict):
|
async def broadcast_download_progress(self, download_id: str, data: Dict):
|
||||||
"""Send progress update to specific download client"""
|
"""Send progress update to specific download client"""
|
||||||
|
# Store simplified progress data in memory (only progress percentage)
|
||||||
|
self._download_progress[download_id] = {
|
||||||
|
'progress': data.get('progress', 0),
|
||||||
|
'timestamp': datetime.now()
|
||||||
|
}
|
||||||
|
|
||||||
if download_id not in self._download_websockets:
|
if download_id not in self._download_websockets:
|
||||||
logger.debug(f"No WebSocket found for download ID: {download_id}")
|
logger.debug(f"No WebSocket found for download ID: {download_id}")
|
||||||
return
|
return
|
||||||
@@ -113,6 +134,27 @@ class WebSocketManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending download progress: {e}")
|
logger.error(f"Error sending download progress: {e}")
|
||||||
|
|
||||||
|
def get_download_progress(self, download_id: str) -> Optional[Dict]:
|
||||||
|
"""Get progress information for a specific download"""
|
||||||
|
return self._download_progress.get(download_id)
|
||||||
|
|
||||||
|
def cleanup_download_progress(self, download_id: str):
|
||||||
|
"""Remove progress info for a specific download"""
|
||||||
|
self._download_progress.pop(download_id, None)
|
||||||
|
|
||||||
|
def cleanup_old_downloads(self, max_age_hours: int = 24):
|
||||||
|
"""Clean up old download progress entries"""
|
||||||
|
cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
|
||||||
|
to_remove = []
|
||||||
|
|
||||||
|
for download_id, progress_data in self._download_progress.items():
|
||||||
|
if progress_data.get('timestamp', datetime.now()) < cutoff_time:
|
||||||
|
to_remove.append(download_id)
|
||||||
|
|
||||||
|
for download_id in to_remove:
|
||||||
|
self._download_progress.pop(download_id, None)
|
||||||
|
logger.debug(f"Cleaned up old download progress for {download_id}")
|
||||||
|
|
||||||
def get_connected_clients_count(self) -> int:
|
def get_connected_clients_count(self) -> int:
|
||||||
"""Get number of connected clients"""
|
"""Get number of connected clients"""
|
||||||
return len(self._websockets)
|
return len(self._websockets)
|
||||||
|
|||||||
Reference in New Issue
Block a user