feat: Implement download progress WebSocket and enhance download manager with unique IDs

This commit is contained in:
Will Miao
2025-07-02 23:48:35 +08:00
parent 49c4a4068b
commit 7a4b5a4667
7 changed files with 160 additions and 57 deletions

View File

@@ -50,6 +50,7 @@ class ApiRoutes:
app.router.add_get('/api/loras', routes.get_loras)
app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) # Add new WebSocket route for download progress
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Add new WebSocket route
app.router.add_get('/api/lora-roots', routes.get_lora_roots)
app.router.add_get('/api/folders', routes.get_folders)

View File

@@ -7,6 +7,7 @@ from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry
from .settings_manager import settings
# Download to temporary file first
import tempfile
@@ -49,8 +50,7 @@ class DownloadManager:
async def download_from_civitai(self, model_id: str = None,
model_version_id: str = None, save_dir: str = None,
relative_path: str = '', progress_callback=None,
model_type: str = None) -> Dict:
relative_path: str = '', progress_callback=None, use_default_paths: bool = False) -> Dict:
"""Download model from Civitai
Args:
@@ -59,18 +59,12 @@ class DownloadManager:
save_dir: Directory to save the model to
relative_path: Relative path within save_dir
progress_callback: Callback function for progress updates
model_type: Type of model ('lora' or 'checkpoint')
use_default_paths: Flag to indicate whether to use default paths
Returns:
Dict with download result
"""
try:
# Update save directory with relative path if provided
if relative_path:
save_dir = os.path.join(save_dir, relative_path)
# Create directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)
# Get civitai client
civitai_client = await self._get_civitai_client()
@@ -80,15 +74,38 @@ class DownloadManager:
if not version_info:
return {'success': False, 'error': 'Failed to fetch model metadata'}
# Infer model_type if not provided
if model_type is None:
model_type_from_info = version_info.get('model', {}).get('type', '').lower()
if model_type_from_info == 'checkpoint':
model_type = 'checkpoint'
elif model_type_from_info in VALID_LORA_TYPES:
model_type = 'lora'
else:
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
model_type_from_info = version_info.get('model', {}).get('type', '').lower()
if model_type_from_info == 'checkpoint':
model_type = 'checkpoint'
elif model_type_from_info in VALID_LORA_TYPES:
model_type = 'lora'
else:
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
# Handle use_default_paths
if use_default_paths:
# Set save_dir based on model type
if model_type == 'checkpoint':
default_path = settings.get('default_checkpoint_root')
if not default_path:
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
save_dir = default_path
else: # model_type == 'lora'
default_path = settings.get('default_lora_root')
if not default_path:
return {'success': False, 'error': 'Default lora root path not set in settings'}
save_dir = default_path
# Set relative_path to the first tag if available
model_tags = version_info.get('model', {}).get('tags', [])
if model_tags:
relative_path = model_tags[0]
# Update save directory with relative path if provided
if relative_path:
save_dir = os.path.join(save_dir, relative_path)
# Create directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)
# Check if this is an early access model
if version_info.get('earlyAccessEndsAt'):

View File

@@ -1,6 +1,7 @@
import logging
from aiohttp import web
from typing import Set, Dict, Optional
from uuid import uuid4
logger = logging.getLogger(__name__)
@@ -10,7 +11,7 @@ class WebSocketManager:
def __init__(self):
self._websockets: Set[web.WebSocketResponse] = set()
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection"""
@@ -39,6 +40,39 @@ class WebSocketManager:
finally:
self._init_websockets.discard(ws)
return ws
async def handle_download_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection for download progress"""
ws = web.WebSocketResponse()
await ws.prepare(request)
# Get download_id from query parameters
download_id = request.query.get('id')
if not download_id:
# Generate a new download ID if not provided
download_id = str(uuid4())
logger.info(f"Created new download ID: {download_id}")
else:
logger.info(f"Using provided download ID: {download_id}")
# Store the websocket with its download ID
self._download_websockets[download_id] = ws
try:
# Send the download ID back to the client
await ws.send_json({
'type': 'download_id',
'download_id': download_id
})
async for msg in ws:
if msg.type == web.WSMsgType.ERROR:
logger.error(f'Download WebSocket error: {ws.exception()}')
finally:
if download_id in self._download_websockets:
del self._download_websockets[download_id]
return ws
async def broadcast(self, data: Dict):
"""Broadcast message to all connected clients"""
@@ -70,17 +104,18 @@ class WebSocketManager:
except Exception as e:
logger.error(f"Error sending initialization progress: {e}")
async def broadcast_checkpoint_progress(self, data: Dict):
"""Broadcast checkpoint download progress to connected clients"""
if not self._checkpoint_websockets:
async def broadcast_download_progress(self, download_id: str, data: Dict):
"""Send progress update to specific download client"""
if download_id not in self._download_websockets:
logger.debug(f"No WebSocket found for download ID: {download_id}")
return
for ws in self._checkpoint_websockets:
try:
await ws.send_json(data)
except Exception as e:
logger.error(f"Error sending checkpoint progress: {e}")
ws = self._download_websockets[download_id]
try:
await ws.send_json(data)
except Exception as e:
logger.error(f"Error sending download progress: {e}")
def get_connected_clients_count(self) -> int:
"""Get number of connected clients"""
return len(self._websockets)
@@ -88,10 +123,14 @@ class WebSocketManager:
def get_init_clients_count(self) -> int:
"""Get number of initialization progress clients"""
return len(self._init_websockets)
def get_checkpoint_clients_count(self) -> int:
"""Get number of checkpoint progress clients"""
return len(self._checkpoint_websockets)
def get_download_clients_count(self) -> int:
"""Get number of download progress clients"""
return len(self._download_websockets)
def generate_download_id(self) -> str:
"""Generate a unique download ID"""
return str(uuid4())
# Global instance
ws_manager = WebSocketManager()

View File

@@ -12,6 +12,7 @@ from ..services.service_registry import ServiceRegistry
from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager
from ..services.download_manager import DownloadManager
from ..services.websocket_manager import ws_manager
logger = logging.getLogger(__name__)
@@ -565,13 +566,12 @@ class ModelRouteUtils:
return web.Response(text=str(e), status=500)
@staticmethod
async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type=None) -> web.Response:
async def handle_download_model(request: web.Request, download_manager: DownloadManager) -> web.Response:
"""Handle model download request
Args:
request: The aiohttp request
download_manager: Instance of DownloadManager
model_type: Type of model ('lora' or 'checkpoint')
Returns:
web.Response: The HTTP response
@@ -579,12 +579,15 @@ class ModelRouteUtils:
try:
data = await request.json()
# Create progress callback
# Get or generate a download ID
download_id = data.get('download_id', ws_manager.generate_download_id())
# Create progress callback with download ID
async def progress_callback(progress):
from ..services.websocket_manager import ws_manager
await ws_manager.broadcast({
await ws_manager.broadcast_download_progress(download_id, {
'status': 'progress',
'progress': progress
'progress': progress,
'download_id': download_id
})
# Check which identifier is provided
@@ -598,15 +601,20 @@ class ModelRouteUtils:
text="Missing required parameter: Please provide 'model_id'"
)
use_default_paths = data.get('use_default_paths', False)
result = await download_manager.download_from_civitai(
model_id=model_id,
model_version_id=model_version_id,
save_dir=data.get('model_root'),
relative_path=data.get('relative_path', ''),
progress_callback=progress_callback,
model_type=model_type
use_default_paths=use_default_paths,
progress_callback=progress_callback
)
# Include download_id in the response
result['download_id'] = download_id
if not result.get('success', False):
error_message = result.get('error', 'Unknown error')

View File

@@ -301,13 +301,24 @@ export class CheckpointDownloadManager {
const updateProgress = this.loadingManager.showDownloadProgress(1);
updateProgress(0, 0, this.currentVersion.name);
// Setup WebSocket for progress updates using checkpoint-specific endpoint
// Generate a unique ID for this download
const downloadId = Date.now().toString();
// Setup WebSocket for progress updates using download-specific endpoint
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`);
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.status === 'progress') {
// Handle download ID confirmation
if (data.type === 'download_id') {
console.log(`Connected to checkpoint download progress with ID: ${data.download_id}`);
return;
}
// Only process progress updates for our download
if (data.status === 'progress' && data.download_id === downloadId) {
// Update progress display with current progress
updateProgress(data.progress, 0, this.currentVersion.name);
@@ -329,7 +340,7 @@ export class CheckpointDownloadManager {
// Continue with download even if WebSocket fails
};
// Start download using checkpoint download endpoint
// Start download using checkpoint download endpoint with download ID
const response = await fetch('/api/download-model', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
@@ -337,7 +348,8 @@ export class CheckpointDownloadManager {
model_id: this.modelId,
model_version_id: this.currentVersion.id,
model_root: checkpointRoot,
relative_path: targetFolder
relative_path: targetFolder,
download_id: downloadId
})
});

View File

@@ -311,13 +311,24 @@ export class DownloadManager {
const updateProgress = this.loadingManager.showDownloadProgress(1);
updateProgress(0, 0, this.currentVersion.name);
// Setup WebSocket for progress updates
// Generate a unique ID for this download
const downloadId = Date.now().toString();
// Setup WebSocket for progress updates - use download-specific endpoint
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`);
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.status === 'progress') {
// Handle download ID confirmation
if (data.type === 'download_id') {
console.log(`Connected to download progress with ID: ${data.download_id}`);
return;
}
// Only process progress updates for our download
if (data.status === 'progress' && data.download_id === downloadId) {
// Update progress display with current progress
updateProgress(data.progress, 0, this.currentVersion.name);
@@ -339,7 +350,7 @@ export class DownloadManager {
// Continue with download even if WebSocket fails
};
// Start download
// Start download with our download ID
const response = await fetch('/api/download-model', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
@@ -347,7 +358,8 @@ export class DownloadManager {
model_id: this.modelId,
model_version_id: this.currentVersion.id,
model_root: loraRoot,
relative_path: targetFolder
relative_path: targetFolder,
download_id: downloadId
})
});
@@ -358,6 +370,9 @@ export class DownloadManager {
showToast('Download completed successfully', 'success');
modalManager.closeModal('downloadModal');
// Close WebSocket after download completes
ws.close();
// Update state and trigger reload with folder update
state.activeFolder = targetFolder;
await resetAndReload(true); // Pass true to update folders

View File

@@ -128,9 +128,12 @@ export class DownloadManager {
targetPath += '/' + newFolder;
}
// Generate a unique ID for this batch download
const batchDownloadId = Date.now().toString();
// Set up WebSocket for progress updates
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${batchDownloadId}`);
// Show enhanced loading with progress details for multiple items
const updateProgress = this.importManager.loadingManager.showDownloadProgress(
@@ -145,7 +148,15 @@ export class DownloadManager {
// Set up progress tracking for current download
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
if (data.status === 'progress') {
// Handle download ID confirmation
if (data.type === 'download_id') {
console.log(`Connected to batch download progress with ID: ${data.download_id}`);
return;
}
// Process progress updates for our current active download
if (data.status === 'progress' && data.download_id && data.download_id.startsWith(batchDownloadId)) {
// Update current LoRA progress
currentLoraProgress = data.progress;
@@ -188,16 +199,16 @@ export class DownloadManager {
updateProgress(0, completedDownloads, lora.name);
try {
// Download the LoRA
// Download the LoRA with download ID
const response = await fetch('/api/download-model', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
download_url: lora.downloadUrl,
model_version_id: lora.modelVersionId,
model_hash: lora.hash,
model_id: lora.modelId,
model_version_id: lora.id,
model_root: loraRoot,
relative_path: targetPath.replace(loraRoot + '/', '')
relative_path: targetPath.replace(loraRoot + '/', ''),
download_id: batchDownloadId
})
});