mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
feat: Implement download progress WebSocket and enhance download manager with unique IDs
This commit is contained in:
@@ -50,6 +50,7 @@ class ApiRoutes:
|
|||||||
app.router.add_get('/api/loras', routes.get_loras)
|
app.router.add_get('/api/loras', routes.get_loras)
|
||||||
app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
|
app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
|
||||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||||
|
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection) # Add new WebSocket route for download progress
|
||||||
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Add new WebSocket route
|
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection) # Add new WebSocket route
|
||||||
app.router.add_get('/api/lora-roots', routes.get_lora_roots)
|
app.router.add_get('/api/lora-roots', routes.get_lora_roots)
|
||||||
app.router.add_get('/api/folders', routes.get_folders)
|
app.router.add_get('/api/folders', routes.get_folders)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
|
|||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
|
from .settings_manager import settings
|
||||||
|
|
||||||
# Download to temporary file first
|
# Download to temporary file first
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -49,8 +50,7 @@ class DownloadManager:
|
|||||||
|
|
||||||
async def download_from_civitai(self, model_id: str = None,
|
async def download_from_civitai(self, model_id: str = None,
|
||||||
model_version_id: str = None, save_dir: str = None,
|
model_version_id: str = None, save_dir: str = None,
|
||||||
relative_path: str = '', progress_callback=None,
|
relative_path: str = '', progress_callback=None, use_default_paths: bool = False) -> Dict:
|
||||||
model_type: str = None) -> Dict:
|
|
||||||
"""Download model from Civitai
|
"""Download model from Civitai
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -59,18 +59,12 @@ class DownloadManager:
|
|||||||
save_dir: Directory to save the model to
|
save_dir: Directory to save the model to
|
||||||
relative_path: Relative path within save_dir
|
relative_path: Relative path within save_dir
|
||||||
progress_callback: Callback function for progress updates
|
progress_callback: Callback function for progress updates
|
||||||
model_type: Type of model ('lora' or 'checkpoint')
|
use_default_paths: Flag to indicate whether to use default paths
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict with download result
|
Dict with download result
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Update save directory with relative path if provided
|
|
||||||
if relative_path:
|
|
||||||
save_dir = os.path.join(save_dir, relative_path)
|
|
||||||
# Create directory if it doesn't exist
|
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# Get civitai client
|
# Get civitai client
|
||||||
civitai_client = await self._get_civitai_client()
|
civitai_client = await self._get_civitai_client()
|
||||||
|
|
||||||
@@ -80,15 +74,38 @@ class DownloadManager:
|
|||||||
if not version_info:
|
if not version_info:
|
||||||
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
||||||
|
|
||||||
# Infer model_type if not provided
|
model_type_from_info = version_info.get('model', {}).get('type', '').lower()
|
||||||
if model_type is None:
|
if model_type_from_info == 'checkpoint':
|
||||||
model_type_from_info = version_info.get('model', {}).get('type', '').lower()
|
model_type = 'checkpoint'
|
||||||
if model_type_from_info == 'checkpoint':
|
elif model_type_from_info in VALID_LORA_TYPES:
|
||||||
model_type = 'checkpoint'
|
model_type = 'lora'
|
||||||
elif model_type_from_info in VALID_LORA_TYPES:
|
else:
|
||||||
model_type = 'lora'
|
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
|
||||||
else:
|
|
||||||
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
|
# Handle use_default_paths
|
||||||
|
if use_default_paths:
|
||||||
|
# Set save_dir based on model type
|
||||||
|
if model_type == 'checkpoint':
|
||||||
|
default_path = settings.get('default_checkpoint_root')
|
||||||
|
if not default_path:
|
||||||
|
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
|
||||||
|
save_dir = default_path
|
||||||
|
else: # model_type == 'lora'
|
||||||
|
default_path = settings.get('default_lora_root')
|
||||||
|
if not default_path:
|
||||||
|
return {'success': False, 'error': 'Default lora root path not set in settings'}
|
||||||
|
save_dir = default_path
|
||||||
|
|
||||||
|
# Set relative_path to the first tag if available
|
||||||
|
model_tags = version_info.get('model', {}).get('tags', [])
|
||||||
|
if model_tags:
|
||||||
|
relative_path = model_tags[0]
|
||||||
|
|
||||||
|
# Update save directory with relative path if provided
|
||||||
|
if relative_path:
|
||||||
|
save_dir = os.path.join(save_dir, relative_path)
|
||||||
|
# Create directory if it doesn't exist
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
# Check if this is an early access model
|
# Check if this is an early access model
|
||||||
if version_info.get('earlyAccessEndsAt'):
|
if version_info.get('earlyAccessEndsAt'):
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from typing import Set, Dict, Optional
|
from typing import Set, Dict, Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -10,7 +11,7 @@ class WebSocketManager:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._websockets: Set[web.WebSocketResponse] = set()
|
self._websockets: Set[web.WebSocketResponse] = set()
|
||||||
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
|
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
|
||||||
self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress
|
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
|
||||||
|
|
||||||
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||||
"""Handle new WebSocket connection"""
|
"""Handle new WebSocket connection"""
|
||||||
@@ -39,6 +40,39 @@ class WebSocketManager:
|
|||||||
finally:
|
finally:
|
||||||
self._init_websockets.discard(ws)
|
self._init_websockets.discard(ws)
|
||||||
return ws
|
return ws
|
||||||
|
|
||||||
|
async def handle_download_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||||
|
"""Handle new WebSocket connection for download progress"""
|
||||||
|
ws = web.WebSocketResponse()
|
||||||
|
await ws.prepare(request)
|
||||||
|
|
||||||
|
# Get download_id from query parameters
|
||||||
|
download_id = request.query.get('id')
|
||||||
|
|
||||||
|
if not download_id:
|
||||||
|
# Generate a new download ID if not provided
|
||||||
|
download_id = str(uuid4())
|
||||||
|
logger.info(f"Created new download ID: {download_id}")
|
||||||
|
else:
|
||||||
|
logger.info(f"Using provided download ID: {download_id}")
|
||||||
|
|
||||||
|
# Store the websocket with its download ID
|
||||||
|
self._download_websockets[download_id] = ws
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Send the download ID back to the client
|
||||||
|
await ws.send_json({
|
||||||
|
'type': 'download_id',
|
||||||
|
'download_id': download_id
|
||||||
|
})
|
||||||
|
|
||||||
|
async for msg in ws:
|
||||||
|
if msg.type == web.WSMsgType.ERROR:
|
||||||
|
logger.error(f'Download WebSocket error: {ws.exception()}')
|
||||||
|
finally:
|
||||||
|
if download_id in self._download_websockets:
|
||||||
|
del self._download_websockets[download_id]
|
||||||
|
return ws
|
||||||
|
|
||||||
async def broadcast(self, data: Dict):
|
async def broadcast(self, data: Dict):
|
||||||
"""Broadcast message to all connected clients"""
|
"""Broadcast message to all connected clients"""
|
||||||
@@ -70,17 +104,18 @@ class WebSocketManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending initialization progress: {e}")
|
logger.error(f"Error sending initialization progress: {e}")
|
||||||
|
|
||||||
async def broadcast_checkpoint_progress(self, data: Dict):
|
async def broadcast_download_progress(self, download_id: str, data: Dict):
|
||||||
"""Broadcast checkpoint download progress to connected clients"""
|
"""Send progress update to specific download client"""
|
||||||
if not self._checkpoint_websockets:
|
if download_id not in self._download_websockets:
|
||||||
|
logger.debug(f"No WebSocket found for download ID: {download_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
for ws in self._checkpoint_websockets:
|
ws = self._download_websockets[download_id]
|
||||||
try:
|
try:
|
||||||
await ws.send_json(data)
|
await ws.send_json(data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error sending checkpoint progress: {e}")
|
logger.error(f"Error sending download progress: {e}")
|
||||||
|
|
||||||
def get_connected_clients_count(self) -> int:
|
def get_connected_clients_count(self) -> int:
|
||||||
"""Get number of connected clients"""
|
"""Get number of connected clients"""
|
||||||
return len(self._websockets)
|
return len(self._websockets)
|
||||||
@@ -88,10 +123,14 @@ class WebSocketManager:
|
|||||||
def get_init_clients_count(self) -> int:
|
def get_init_clients_count(self) -> int:
|
||||||
"""Get number of initialization progress clients"""
|
"""Get number of initialization progress clients"""
|
||||||
return len(self._init_websockets)
|
return len(self._init_websockets)
|
||||||
|
|
||||||
def get_checkpoint_clients_count(self) -> int:
|
def get_download_clients_count(self) -> int:
|
||||||
"""Get number of checkpoint progress clients"""
|
"""Get number of download progress clients"""
|
||||||
return len(self._checkpoint_websockets)
|
return len(self._download_websockets)
|
||||||
|
|
||||||
|
def generate_download_id(self) -> str:
|
||||||
|
"""Generate a unique download ID"""
|
||||||
|
return str(uuid4())
|
||||||
|
|
||||||
# Global instance
|
# Global instance
|
||||||
ws_manager = WebSocketManager()
|
ws_manager = WebSocketManager()
|
||||||
@@ -12,6 +12,7 @@ from ..services.service_registry import ServiceRegistry
|
|||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from ..services.download_manager import DownloadManager
|
from ..services.download_manager import DownloadManager
|
||||||
|
from ..services.websocket_manager import ws_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -565,13 +566,12 @@ class ModelRouteUtils:
|
|||||||
return web.Response(text=str(e), status=500)
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type=None) -> web.Response:
|
async def handle_download_model(request: web.Request, download_manager: DownloadManager) -> web.Response:
|
||||||
"""Handle model download request
|
"""Handle model download request
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The aiohttp request
|
request: The aiohttp request
|
||||||
download_manager: Instance of DownloadManager
|
download_manager: Instance of DownloadManager
|
||||||
model_type: Type of model ('lora' or 'checkpoint')
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
web.Response: The HTTP response
|
web.Response: The HTTP response
|
||||||
@@ -579,12 +579,15 @@ class ModelRouteUtils:
|
|||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
|
|
||||||
# Create progress callback
|
# Get or generate a download ID
|
||||||
|
download_id = data.get('download_id', ws_manager.generate_download_id())
|
||||||
|
|
||||||
|
# Create progress callback with download ID
|
||||||
async def progress_callback(progress):
|
async def progress_callback(progress):
|
||||||
from ..services.websocket_manager import ws_manager
|
await ws_manager.broadcast_download_progress(download_id, {
|
||||||
await ws_manager.broadcast({
|
|
||||||
'status': 'progress',
|
'status': 'progress',
|
||||||
'progress': progress
|
'progress': progress,
|
||||||
|
'download_id': download_id
|
||||||
})
|
})
|
||||||
|
|
||||||
# Check which identifier is provided
|
# Check which identifier is provided
|
||||||
@@ -598,15 +601,20 @@ class ModelRouteUtils:
|
|||||||
text="Missing required parameter: Please provide 'model_id'"
|
text="Missing required parameter: Please provide 'model_id'"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
use_default_paths = data.get('use_default_paths', False)
|
||||||
|
|
||||||
result = await download_manager.download_from_civitai(
|
result = await download_manager.download_from_civitai(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
model_version_id=model_version_id,
|
model_version_id=model_version_id,
|
||||||
save_dir=data.get('model_root'),
|
save_dir=data.get('model_root'),
|
||||||
relative_path=data.get('relative_path', ''),
|
relative_path=data.get('relative_path', ''),
|
||||||
progress_callback=progress_callback,
|
use_default_paths=use_default_paths,
|
||||||
model_type=model_type
|
progress_callback=progress_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Include download_id in the response
|
||||||
|
result['download_id'] = download_id
|
||||||
|
|
||||||
if not result.get('success', False):
|
if not result.get('success', False):
|
||||||
error_message = result.get('error', 'Unknown error')
|
error_message = result.get('error', 'Unknown error')
|
||||||
|
|
||||||
|
|||||||
@@ -301,13 +301,24 @@ export class CheckpointDownloadManager {
|
|||||||
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
||||||
updateProgress(0, 0, this.currentVersion.name);
|
updateProgress(0, 0, this.currentVersion.name);
|
||||||
|
|
||||||
// Setup WebSocket for progress updates using checkpoint-specific endpoint
|
// Generate a unique ID for this download
|
||||||
|
const downloadId = Date.now().toString();
|
||||||
|
|
||||||
|
// Setup WebSocket for progress updates using download-specific endpoint
|
||||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
|
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`);
|
||||||
|
|
||||||
ws.onmessage = (event) => {
|
ws.onmessage = (event) => {
|
||||||
const data = JSON.parse(event.data);
|
const data = JSON.parse(event.data);
|
||||||
if (data.status === 'progress') {
|
|
||||||
|
// Handle download ID confirmation
|
||||||
|
if (data.type === 'download_id') {
|
||||||
|
console.log(`Connected to checkpoint download progress with ID: ${data.download_id}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only process progress updates for our download
|
||||||
|
if (data.status === 'progress' && data.download_id === downloadId) {
|
||||||
// Update progress display with current progress
|
// Update progress display with current progress
|
||||||
updateProgress(data.progress, 0, this.currentVersion.name);
|
updateProgress(data.progress, 0, this.currentVersion.name);
|
||||||
|
|
||||||
@@ -329,7 +340,7 @@ export class CheckpointDownloadManager {
|
|||||||
// Continue with download even if WebSocket fails
|
// Continue with download even if WebSocket fails
|
||||||
};
|
};
|
||||||
|
|
||||||
// Start download using checkpoint download endpoint
|
// Start download using checkpoint download endpoint with download ID
|
||||||
const response = await fetch('/api/download-model', {
|
const response = await fetch('/api/download-model', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
@@ -337,7 +348,8 @@ export class CheckpointDownloadManager {
|
|||||||
model_id: this.modelId,
|
model_id: this.modelId,
|
||||||
model_version_id: this.currentVersion.id,
|
model_version_id: this.currentVersion.id,
|
||||||
model_root: checkpointRoot,
|
model_root: checkpointRoot,
|
||||||
relative_path: targetFolder
|
relative_path: targetFolder,
|
||||||
|
download_id: downloadId
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -311,13 +311,24 @@ export class DownloadManager {
|
|||||||
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
const updateProgress = this.loadingManager.showDownloadProgress(1);
|
||||||
updateProgress(0, 0, this.currentVersion.name);
|
updateProgress(0, 0, this.currentVersion.name);
|
||||||
|
|
||||||
// Setup WebSocket for progress updates
|
// Generate a unique ID for this download
|
||||||
|
const downloadId = Date.now().toString();
|
||||||
|
|
||||||
|
// Setup WebSocket for progress updates - use download-specific endpoint
|
||||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
|
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`);
|
||||||
|
|
||||||
ws.onmessage = (event) => {
|
ws.onmessage = (event) => {
|
||||||
const data = JSON.parse(event.data);
|
const data = JSON.parse(event.data);
|
||||||
if (data.status === 'progress') {
|
|
||||||
|
// Handle download ID confirmation
|
||||||
|
if (data.type === 'download_id') {
|
||||||
|
console.log(`Connected to download progress with ID: ${data.download_id}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only process progress updates for our download
|
||||||
|
if (data.status === 'progress' && data.download_id === downloadId) {
|
||||||
// Update progress display with current progress
|
// Update progress display with current progress
|
||||||
updateProgress(data.progress, 0, this.currentVersion.name);
|
updateProgress(data.progress, 0, this.currentVersion.name);
|
||||||
|
|
||||||
@@ -339,7 +350,7 @@ export class DownloadManager {
|
|||||||
// Continue with download even if WebSocket fails
|
// Continue with download even if WebSocket fails
|
||||||
};
|
};
|
||||||
|
|
||||||
// Start download
|
// Start download with our download ID
|
||||||
const response = await fetch('/api/download-model', {
|
const response = await fetch('/api/download-model', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
@@ -347,7 +358,8 @@ export class DownloadManager {
|
|||||||
model_id: this.modelId,
|
model_id: this.modelId,
|
||||||
model_version_id: this.currentVersion.id,
|
model_version_id: this.currentVersion.id,
|
||||||
model_root: loraRoot,
|
model_root: loraRoot,
|
||||||
relative_path: targetFolder
|
relative_path: targetFolder,
|
||||||
|
download_id: downloadId
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -358,6 +370,9 @@ export class DownloadManager {
|
|||||||
showToast('Download completed successfully', 'success');
|
showToast('Download completed successfully', 'success');
|
||||||
modalManager.closeModal('downloadModal');
|
modalManager.closeModal('downloadModal');
|
||||||
|
|
||||||
|
// Close WebSocket after download completes
|
||||||
|
ws.close();
|
||||||
|
|
||||||
// Update state and trigger reload with folder update
|
// Update state and trigger reload with folder update
|
||||||
state.activeFolder = targetFolder;
|
state.activeFolder = targetFolder;
|
||||||
await resetAndReload(true); // Pass true to update folders
|
await resetAndReload(true); // Pass true to update folders
|
||||||
|
|||||||
@@ -128,9 +128,12 @@ export class DownloadManager {
|
|||||||
targetPath += '/' + newFolder;
|
targetPath += '/' + newFolder;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Generate a unique ID for this batch download
|
||||||
|
const batchDownloadId = Date.now().toString();
|
||||||
|
|
||||||
// Set up WebSocket for progress updates
|
// Set up WebSocket for progress updates
|
||||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
|
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${batchDownloadId}`);
|
||||||
|
|
||||||
// Show enhanced loading with progress details for multiple items
|
// Show enhanced loading with progress details for multiple items
|
||||||
const updateProgress = this.importManager.loadingManager.showDownloadProgress(
|
const updateProgress = this.importManager.loadingManager.showDownloadProgress(
|
||||||
@@ -145,7 +148,15 @@ export class DownloadManager {
|
|||||||
// Set up progress tracking for current download
|
// Set up progress tracking for current download
|
||||||
ws.onmessage = (event) => {
|
ws.onmessage = (event) => {
|
||||||
const data = JSON.parse(event.data);
|
const data = JSON.parse(event.data);
|
||||||
if (data.status === 'progress') {
|
|
||||||
|
// Handle download ID confirmation
|
||||||
|
if (data.type === 'download_id') {
|
||||||
|
console.log(`Connected to batch download progress with ID: ${data.download_id}`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process progress updates for our current active download
|
||||||
|
if (data.status === 'progress' && data.download_id && data.download_id.startsWith(batchDownloadId)) {
|
||||||
// Update current LoRA progress
|
// Update current LoRA progress
|
||||||
currentLoraProgress = data.progress;
|
currentLoraProgress = data.progress;
|
||||||
|
|
||||||
@@ -188,16 +199,16 @@ export class DownloadManager {
|
|||||||
updateProgress(0, completedDownloads, lora.name);
|
updateProgress(0, completedDownloads, lora.name);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Download the LoRA
|
// Download the LoRA with download ID
|
||||||
const response = await fetch('/api/download-model', {
|
const response = await fetch('/api/download-model', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
download_url: lora.downloadUrl,
|
model_id: lora.modelId,
|
||||||
model_version_id: lora.modelVersionId,
|
model_version_id: lora.id,
|
||||||
model_hash: lora.hash,
|
|
||||||
model_root: loraRoot,
|
model_root: loraRoot,
|
||||||
relative_path: targetPath.replace(loraRoot + '/', '')
|
relative_path: targetPath.replace(loraRoot + '/', ''),
|
||||||
|
download_id: batchDownloadId
|
||||||
})
|
})
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user