From 89368ad0e4b407dfdab49c2fc505477fb30c67c7 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 14 Feb 2025 14:37:23 +0800 Subject: [PATCH] Add progress tracking to downloads --- routes/api_routes.py | 11 +++++- services/civitai_client.py | 36 ++++++++++++++---- services/download_manager.py | 53 +++++++++++++++++++++------ static/css/style.css | 9 ++--- static/js/managers/DownloadManager.js | 31 ++++++++++++++-- utils/model_utils.py | 1 + 6 files changed, 112 insertions(+), 29 deletions(-) diff --git a/routes/api_routes.py b/routes/api_routes.py index e369644a..1fc291be 100644 --- a/routes/api_routes.py +++ b/routes/api_routes.py @@ -480,10 +480,19 @@ class ApiRoutes: async with self._download_lock: try: data = await request.json() + + # Create progress callback + async def progress_callback(progress): + await ws_manager.broadcast({ + 'status': 'progress', + 'progress': progress + }) + result = await self.download_manager.download_from_civitai( download_url=data.get('download_url'), save_dir=data.get('lora_root'), - relative_path=data.get('relative_path') + relative_path=data.get('relative_path'), + progress_callback=progress_callback # Add progress callback ) return web.json_response(result) except Exception as e: diff --git a/services/civitai_client.py b/services/civitai_client.py index 92702863..fe002ed9 100644 --- a/services/civitai_client.py +++ b/services/civitai_client.py @@ -59,8 +59,18 @@ class CivitaiClient: return headers - async def _download_file(self, url: str, save_dir: str, default_filename: str) -> Tuple[bool, str]: - """Download file with content-disposition support""" + async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]: + """Download file with content-disposition support and progress tracking + + Args: + url: Download URL + save_dir: Directory to save the file + default_filename: Fallback filename if none provided in headers + progress_callback: Optional async callback function for progress updates (0-100) + + Returns: + Tuple[bool, str]: (success, save_path or error message) + """ session = await self.session try: headers = self._get_request_headers() @@ -76,13 +86,23 @@ class CivitaiClient: save_path = os.path.join(save_dir, filename) - # Stream download to file + # Get total file size for progress calculation + total_size = int(response.headers.get('content-length', 0)) + current_size = 0 + + # Stream download to file with progress updates with open(save_path, 'wb') as f: - while True: - chunk = await response.content.read(8192) - if not chunk: - break - f.write(chunk) + async for chunk in response.content.iter_chunked(8192): + if chunk: + f.write(chunk) + current_size += len(chunk) + if progress_callback and total_size: + progress = (current_size / total_size) * 100 + await progress_callback(progress) + + # Ensure 100% progress is reported + if progress_callback: + await progress_callback(100) return True, save_path diff --git a/services/download_manager.py b/services/download_manager.py index b7cfd6e7..4dab4530 100644 --- a/services/download_manager.py +++ b/services/download_manager.py @@ -13,7 +13,8 @@ class DownloadManager: self.civitai_client = CivitaiClient() self.file_monitor = file_monitor - async def download_from_civitai(self, download_url: str, save_dir: str, relative_path: str = '') -> Dict: + async def download_from_civitai(self, download_url: str, save_dir: str, relative_path: str = '', + progress_callback=None) -> Dict: try: # Update save directory with relative path if provided if relative_path: @@ -27,6 +28,10 @@ class DownloadManager: if not version_info: return {'success': False, 'error': 'Failed to fetch model metadata'} + # Report initial progress + if progress_callback: + await progress_callback(0) + # 2. 获取文件信息 file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) if not file_info: @@ -52,7 +57,8 @@ class DownloadManager: save_dir=save_dir, metadata=metadata, version_info=version_info, - relative_path=relative_path + relative_path=relative_path, + progress_callback=progress_callback ) return result @@ -62,32 +68,41 @@ class DownloadManager: return {'success': False, 'error': str(e)} async def _execute_download(self, download_url: str, save_dir: str, - metadata: LoraMetadata, version_info: Dict, relative_path: str) -> Dict: - """执行实际的下载流程,包括预览图和模型文件""" + metadata: LoraMetadata, version_info: Dict, + relative_path: str, progress_callback=None) -> Dict: + """Execute the actual download process including preview images and model files""" try: save_path = metadata.file_path metadata_path = os.path.splitext(save_path)[0] + '.metadata.json' - # 2. 下载预览图(如果有) + # Download preview image if available images = version_info.get('images', []) if images: + # Report preview download progress + if progress_callback: + await progress_callback(5) # 5% progress for starting preview download + preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png' preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext if await self.civitai_client.download_preview_image(images[0]['url'], preview_path): metadata.preview_url = preview_path.replace(os.sep, '/') - # 更新元数据中的预览图URL with open(metadata_path, 'w', encoding='utf-8') as f: json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) - # 3. 下载模型文件 + # Report preview download completion + if progress_callback: + await progress_callback(10) # 10% progress after preview download + + # Download model file with progress tracking success, result = await self.civitai_client._download_file( download_url, save_dir, - os.path.basename(save_path) + os.path.basename(save_path), + progress_callback=lambda p: self._handle_download_progress(p, progress_callback) ) if not success: - # 下载失败时清理文件 + # Clean up files on failure for path in [save_path, metadata_path, metadata.preview_url]: if path and os.path.exists(path): os.remove(path) @@ -110,14 +125,30 @@ class DownloadManager: all_folders.add(relative_path) cache.folders = sorted(list(all_folders)) + # Report 100% completion + if progress_callback: + await progress_callback(100) + return { 'success': True } except Exception as e: logger.error(f"Error in _execute_download: {e}", exc_info=True) - # 确保清理任何部分下载的文件 + # Clean up partial downloads for path in [save_path, metadata_path]: if path and os.path.exists(path): os.remove(path) - return {'success': False, 'error': str(e)} \ No newline at end of file + return {'success': False, 'error': str(e)} + + async def _handle_download_progress(self, file_progress: float, progress_callback): + """Convert file download progress to overall progress + + Args: + file_progress: Progress of file download (0-100) + progress_callback: Callback function for progress updates + """ + if progress_callback: + # Scale file progress to 10-100 range (after preview download) + overall_progress = 10 + (file_progress * 0.9) # 90% of progress for file download + await progress_callback(round(overall_progress)) \ No newline at end of file diff --git a/static/css/style.css b/static/css/style.css index fd28b21c..a35b58cb 100644 --- a/static/css/style.css +++ b/static/css/style.css @@ -18,8 +18,8 @@ /* Z-index Scale */ --z-base: 10; - --z-modal: 30; - --z-overlay: 50; + --z-modal: 1000; /* 更新modal的z-index */ + --z-overlay: 2000; /* 更新overlay的z-index,确保比modal高 */ /* Border Radius */ --border-radius-base: 12px; @@ -274,7 +274,7 @@ body { width: 100%; height: 100%; background: rgba(0, 0, 0, 0.8); - z-index: 1000; + z-index: var(--z-modal); overflow-y: auto; /* 允许模态窗口内容滚动 */ } @@ -665,8 +665,7 @@ body.modal-open { padding: 12px 16px; border-radius: var(--border-radius-sm); box-shadow: 0 4px 16px rgba(0, 0, 0, 0.2); - /* z-index: calc(var(--z-overlay) + 10); */ - z-index: 1000; /* 保证在其他元素之上 */ + z-index: calc(var(--z-overlay) + 10); /* 让toast显示在最上层 */ opacity: 0; transition: transform 0.3s cubic-bezier(0.4, 0, 0.2, 1), opacity 0.3s cubic-bezier(0.4, 0, 0.2, 1); diff --git a/static/js/managers/DownloadManager.js b/static/js/managers/DownloadManager.js index 96726f3d..d5ca2b40 100644 --- a/static/js/managers/DownloadManager.js +++ b/static/js/managers/DownloadManager.js @@ -1,5 +1,6 @@ import { modalManager } from './ModalManager.js'; import { showToast } from '../utils/uiHelpers.js'; +import { LoadingManager } from './LoadingManager.js'; export class DownloadManager { constructor() { @@ -11,6 +12,9 @@ export class DownloadManager { // Add initialization check this.initialized = false; this.selectedFolder = ''; + + // Add LoadingManager instance + this.loadingManager = new LoadingManager(); } showDownloadModal() { @@ -45,6 +49,9 @@ export class DownloadManager { const errorElement = document.getElementById('urlError'); try { + // Show loading while fetching versions + this.loadingManager.showSimpleLoading('Fetching model versions...'); + const modelId = this.extractModelId(url); if (!modelId) { throw new Error('Invalid Civitai URL format'); @@ -68,6 +75,9 @@ export class DownloadManager { this.showVersionStep(); } catch (error) { errorElement.textContent = error.message; + } finally { + // Hide loading when done + this.loadingManager.hide(); } } @@ -155,9 +165,6 @@ export class DownloadManager { return; } - console.log('Selected folder:', this.selectedFolder); // Log selected folder - console.log('New folder:', newFolder); // Log new folder - // Construct relative path let relativePath = ''; if (this.selectedFolder) { @@ -174,7 +181,20 @@ export class DownloadManager { throw new Error('No download URL available'); } - // 只传递必要参数 + // Show loading with progress bar for download + this.loadingManager.show('Downloading LoRA...', 0); + + // Setup WebSocket for progress updates + const ws = new WebSocket(`ws://${window.location.host}/ws/fetch-progress`); + ws.onmessage = (event) => { + const data = JSON.parse(event.data); + if (data.status === 'progress') { + this.loadingManager.setProgress(data.progress); + this.loadingManager.setStatus(`Downloading: ${data.progress}%`); + } + }; + + // Start download const response = await fetch('/api/download-lora', { method: 'POST', headers: { 'Content-Type': 'application/json' }, @@ -195,8 +215,11 @@ export class DownloadManager { // Refresh the grid to show new model window.refreshLoras(false); + } catch (error) { showToast(error.message, 'error'); + } finally { + this.loadingManager.hide(); } } diff --git a/utils/model_utils.py b/utils/model_utils.py index 5a731dcf..593c1339 100644 --- a/utils/model_utils.py +++ b/utils/model_utils.py @@ -7,6 +7,7 @@ BASE_MODEL_MAPPING = { "sdxl": "SDXL", "sd-v2": "SD2.0", "flux1": "Flux.1 D", + "flux.1 d": "Flux.1 D", "illustrious": "IL", "pony": "Pony" }