mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 21:52:11 -03:00
feat: Implement download progress WebSocket and enhance download manager with unique IDs
This commit is contained in:
@@ -50,6 +50,7 @@ class ApiRoutes:
|
||||
app.router.add_get('/api/loras', routes.get_loras)
|
||||
app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) # Add new WebSocket route for download progress
|
||||
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Add new WebSocket route
|
||||
app.router.add_get('/api/lora-roots', routes.get_lora_roots)
|
||||
app.router.add_get('/api/folders', routes.get_folders)
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .service_registry import ServiceRegistry
|
||||
from .settings_manager import settings
|
||||
|
||||
# Download to temporary file first
|
||||
import tempfile
|
||||
@@ -49,8 +50,7 @@ class DownloadManager:
|
||||
|
||||
async def download_from_civitai(self, model_id: str = None,
|
||||
model_version_id: str = None, save_dir: str = None,
|
||||
relative_path: str = '', progress_callback=None,
|
||||
model_type: str = None) -> Dict:
|
||||
relative_path: str = '', progress_callback=None, use_default_paths: bool = False) -> Dict:
|
||||
"""Download model from Civitai
|
||||
|
||||
Args:
|
||||
@@ -59,18 +59,12 @@ class DownloadManager:
|
||||
save_dir: Directory to save the model to
|
||||
relative_path: Relative path within save_dir
|
||||
progress_callback: Callback function for progress updates
|
||||
model_type: Type of model ('lora' or 'checkpoint')
|
||||
use_default_paths: Flag to indicate whether to use default paths
|
||||
|
||||
Returns:
|
||||
Dict with download result
|
||||
"""
|
||||
try:
|
||||
# Update save directory with relative path if provided
|
||||
if relative_path:
|
||||
save_dir = os.path.join(save_dir, relative_path)
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Get civitai client
|
||||
civitai_client = await self._get_civitai_client()
|
||||
|
||||
@@ -80,15 +74,38 @@ class DownloadManager:
|
||||
if not version_info:
|
||||
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
||||
|
||||
# Infer model_type if not provided
|
||||
if model_type is None:
|
||||
model_type_from_info = version_info.get('model', {}).get('type', '').lower()
|
||||
if model_type_from_info == 'checkpoint':
|
||||
model_type = 'checkpoint'
|
||||
elif model_type_from_info in VALID_LORA_TYPES:
|
||||
model_type = 'lora'
|
||||
else:
|
||||
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
|
||||
model_type_from_info = version_info.get('model', {}).get('type', '').lower()
|
||||
if model_type_from_info == 'checkpoint':
|
||||
model_type = 'checkpoint'
|
||||
elif model_type_from_info in VALID_LORA_TYPES:
|
||||
model_type = 'lora'
|
||||
else:
|
||||
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
|
||||
|
||||
# Handle use_default_paths
|
||||
if use_default_paths:
|
||||
# Set save_dir based on model type
|
||||
if model_type == 'checkpoint':
|
||||
default_path = settings.get('default_checkpoint_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
|
||||
save_dir = default_path
|
||||
else: # model_type == 'lora'
|
||||
default_path = settings.get('default_lora_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default lora root path not set in settings'}
|
||||
save_dir = default_path
|
||||
|
||||
# Set relative_path to the first tag if available
|
||||
model_tags = version_info.get('model', {}).get('tags', [])
|
||||
if model_tags:
|
||||
relative_path = model_tags[0]
|
||||
|
||||
# Update save directory with relative path if provided
|
||||
if relative_path:
|
||||
save_dir = os.path.join(save_dir, relative_path)
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Check if this is an early access model
|
||||
if version_info.get('earlyAccessEndsAt'):
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from typing import Set, Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -10,7 +11,7 @@ class WebSocketManager:
|
||||
def __init__(self):
|
||||
self._websockets: Set[web.WebSocketResponse] = set()
|
||||
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
|
||||
self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress
|
||||
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
|
||||
|
||||
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle new WebSocket connection"""
|
||||
@@ -39,6 +40,39 @@ class WebSocketManager:
|
||||
finally:
|
||||
self._init_websockets.discard(ws)
|
||||
return ws
|
||||
|
||||
async def handle_download_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle new WebSocket connection for download progress"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
# Get download_id from query parameters
|
||||
download_id = request.query.get('id')
|
||||
|
||||
if not download_id:
|
||||
# Generate a new download ID if not provided
|
||||
download_id = str(uuid4())
|
||||
logger.info(f"Created new download ID: {download_id}")
|
||||
else:
|
||||
logger.info(f"Using provided download ID: {download_id}")
|
||||
|
||||
# Store the websocket with its download ID
|
||||
self._download_websockets[download_id] = ws
|
||||
|
||||
try:
|
||||
# Send the download ID back to the client
|
||||
await ws.send_json({
|
||||
'type': 'download_id',
|
||||
'download_id': download_id
|
||||
})
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == web.WSMsgType.ERROR:
|
||||
logger.error(f'Download WebSocket error: {ws.exception()}')
|
||||
finally:
|
||||
if download_id in self._download_websockets:
|
||||
del self._download_websockets[download_id]
|
||||
return ws
|
||||
|
||||
async def broadcast(self, data: Dict):
|
||||
"""Broadcast message to all connected clients"""
|
||||
@@ -70,17 +104,18 @@ class WebSocketManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending initialization progress: {e}")
|
||||
|
||||
async def broadcast_checkpoint_progress(self, data: Dict):
|
||||
"""Broadcast checkpoint download progress to connected clients"""
|
||||
if not self._checkpoint_websockets:
|
||||
async def broadcast_download_progress(self, download_id: str, data: Dict):
|
||||
"""Send progress update to specific download client"""
|
||||
if download_id not in self._download_websockets:
|
||||
logger.debug(f"No WebSocket found for download ID: {download_id}")
|
||||
return
|
||||
|
||||
for ws in self._checkpoint_websockets:
|
||||
try:
|
||||
await ws.send_json(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending checkpoint progress: {e}")
|
||||
|
||||
ws = self._download_websockets[download_id]
|
||||
try:
|
||||
await ws.send_json(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending download progress: {e}")
|
||||
|
||||
def get_connected_clients_count(self) -> int:
|
||||
"""Get number of connected clients"""
|
||||
return len(self._websockets)
|
||||
@@ -88,10 +123,14 @@ class WebSocketManager:
|
||||
def get_init_clients_count(self) -> int:
|
||||
"""Get number of initialization progress clients"""
|
||||
return len(self._init_websockets)
|
||||
|
||||
def get_checkpoint_clients_count(self) -> int:
|
||||
"""Get number of checkpoint progress clients"""
|
||||
return len(self._checkpoint_websockets)
|
||||
|
||||
def get_download_clients_count(self) -> int:
|
||||
"""Get number of download progress clients"""
|
||||
return len(self._download_websockets)
|
||||
|
||||
def generate_download_id(self) -> str:
|
||||
"""Generate a unique download ID"""
|
||||
return str(uuid4())
|
||||
|
||||
# Global instance
|
||||
ws_manager = WebSocketManager()
|
||||
@@ -12,6 +12,7 @@ from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from ..services.download_manager import DownloadManager
|
||||
from ..services.websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -565,13 +566,12 @@ class ModelRouteUtils:
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type=None) -> web.Response:
|
||||
async def handle_download_model(request: web.Request, download_manager: DownloadManager) -> web.Response:
|
||||
"""Handle model download request
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
download_manager: Instance of DownloadManager
|
||||
model_type: Type of model ('lora' or 'checkpoint')
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
@@ -579,12 +579,15 @@ class ModelRouteUtils:
|
||||
try:
|
||||
data = await request.json()
|
||||
|
||||
# Create progress callback
|
||||
# Get or generate a download ID
|
||||
download_id = data.get('download_id', ws_manager.generate_download_id())
|
||||
|
||||
# Create progress callback with download ID
|
||||
async def progress_callback(progress):
|
||||
from ..services.websocket_manager import ws_manager
|
||||
await ws_manager.broadcast({
|
||||
await ws_manager.broadcast_download_progress(download_id, {
|
||||
'status': 'progress',
|
||||
'progress': progress
|
||||
'progress': progress,
|
||||
'download_id': download_id
|
||||
})
|
||||
|
||||
# Check which identifier is provided
|
||||
@@ -598,15 +601,20 @@ class ModelRouteUtils:
|
||||
text="Missing required parameter: Please provide 'model_id'"
|
||||
)
|
||||
|
||||
use_default_paths = data.get('use_default_paths', False)
|
||||
|
||||
result = await download_manager.download_from_civitai(
|
||||
model_id=model_id,
|
||||
model_version_id=model_version_id,
|
||||
save_dir=data.get('model_root'),
|
||||
relative_path=data.get('relative_path', ''),
|
||||
progress_callback=progress_callback,
|
||||
model_type=model_type
|
||||
use_default_paths=use_default_paths,
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
|
||||
# Include download_id in the response
|
||||
result['download_id'] = download_id
|
||||
|
||||
if not result.get('success', False):
|
||||
error_message = result.get('error', 'Unknown error')
|
||||
|
||||
|
||||
@@ -301,13 +301,24 @@ export class CheckpointDownloadManager {
|
||||
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
||||
updateProgress(0, 0, this.currentVersion.name);
|
||||
|
||||
// Setup WebSocket for progress updates using checkpoint-specific endpoint
|
||||
// Generate a unique ID for this download
|
||||
const downloadId = Date.now().toString();
|
||||
|
||||
// Setup WebSocket for progress updates using download-specific endpoint
|
||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
|
||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`);
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.status === 'progress') {
|
||||
|
||||
// Handle download ID confirmation
|
||||
if (data.type === 'download_id') {
|
||||
console.log(`Connected to checkpoint download progress with ID: ${data.download_id}`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Only process progress updates for our download
|
||||
if (data.status === 'progress' && data.download_id === downloadId) {
|
||||
// Update progress display with current progress
|
||||
updateProgress(data.progress, 0, this.currentVersion.name);
|
||||
|
||||
@@ -329,7 +340,7 @@ export class CheckpointDownloadManager {
|
||||
// Continue with download even if WebSocket fails
|
||||
};
|
||||
|
||||
// Start download using checkpoint download endpoint
|
||||
// Start download using checkpoint download endpoint with download ID
|
||||
const response = await fetch('/api/download-model', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
@@ -337,7 +348,8 @@ export class CheckpointDownloadManager {
|
||||
model_id: this.modelId,
|
||||
model_version_id: this.currentVersion.id,
|
||||
model_root: checkpointRoot,
|
||||
relative_path: targetFolder
|
||||
relative_path: targetFolder,
|
||||
download_id: downloadId
|
||||
})
|
||||
});
|
||||
|
||||
|
||||
@@ -311,13 +311,24 @@ export class DownloadManager {
|
||||
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
||||
updateProgress(0, 0, this.currentVersion.name);
|
||||
|
||||
// Setup WebSocket for progress updates
|
||||
// Generate a unique ID for this download
|
||||
const downloadId = Date.now().toString();
|
||||
|
||||
// Setup WebSocket for progress updates - use download-specific endpoint
|
||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
|
||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`);
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.status === 'progress') {
|
||||
|
||||
// Handle download ID confirmation
|
||||
if (data.type === 'download_id') {
|
||||
console.log(`Connected to download progress with ID: ${data.download_id}`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Only process progress updates for our download
|
||||
if (data.status === 'progress' && data.download_id === downloadId) {
|
||||
// Update progress display with current progress
|
||||
updateProgress(data.progress, 0, this.currentVersion.name);
|
||||
|
||||
@@ -339,7 +350,7 @@ export class DownloadManager {
|
||||
// Continue with download even if WebSocket fails
|
||||
};
|
||||
|
||||
// Start download
|
||||
// Start download with our download ID
|
||||
const response = await fetch('/api/download-model', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
@@ -347,7 +358,8 @@ export class DownloadManager {
|
||||
model_id: this.modelId,
|
||||
model_version_id: this.currentVersion.id,
|
||||
model_root: loraRoot,
|
||||
relative_path: targetFolder
|
||||
relative_path: targetFolder,
|
||||
download_id: downloadId
|
||||
})
|
||||
});
|
||||
|
||||
@@ -358,6 +370,9 @@ export class DownloadManager {
|
||||
showToast('Download completed successfully', 'success');
|
||||
modalManager.closeModal('downloadModal');
|
||||
|
||||
// Close WebSocket after download completes
|
||||
ws.close();
|
||||
|
||||
// Update state and trigger reload with folder update
|
||||
state.activeFolder = targetFolder;
|
||||
await resetAndReload(true); // Pass true to update folders
|
||||
|
||||
@@ -128,9 +128,12 @@ export class DownloadManager {
|
||||
targetPath += '/' + newFolder;
|
||||
}
|
||||
|
||||
// Generate a unique ID for this batch download
|
||||
const batchDownloadId = Date.now().toString();
|
||||
|
||||
// Set up WebSocket for progress updates
|
||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
|
||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${batchDownloadId}`);
|
||||
|
||||
// Show enhanced loading with progress details for multiple items
|
||||
const updateProgress = this.importManager.loadingManager.showDownloadProgress(
|
||||
@@ -145,7 +148,15 @@ export class DownloadManager {
|
||||
// Set up progress tracking for current download
|
||||
ws.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.status === 'progress') {
|
||||
|
||||
// Handle download ID confirmation
|
||||
if (data.type === 'download_id') {
|
||||
console.log(`Connected to batch download progress with ID: ${data.download_id}`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Process progress updates for our current active download
|
||||
if (data.status === 'progress' && data.download_id && data.download_id.startsWith(batchDownloadId)) {
|
||||
// Update current LoRA progress
|
||||
currentLoraProgress = data.progress;
|
||||
|
||||
@@ -188,16 +199,16 @@ export class DownloadManager {
|
||||
updateProgress(0, completedDownloads, lora.name);
|
||||
|
||||
try {
|
||||
// Download the LoRA
|
||||
// Download the LoRA with download ID
|
||||
const response = await fetch('/api/download-model', {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
download_url: lora.downloadUrl,
|
||||
model_version_id: lora.modelVersionId,
|
||||
model_hash: lora.hash,
|
||||
model_id: lora.modelId,
|
||||
model_version_id: lora.id,
|
||||
model_root: loraRoot,
|
||||
relative_path: targetPath.replace(loraRoot + '/', '')
|
||||
relative_path: targetPath.replace(loraRoot + '/', ''),
|
||||
download_id: batchDownloadId
|
||||
})
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user