mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
Add progress tracking to downloads
This commit is contained in:
@@ -480,10 +480,19 @@ class ApiRoutes:
|
|||||||
async with self._download_lock:
|
async with self._download_lock:
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
|
|
||||||
|
# Create progress callback
|
||||||
|
async def progress_callback(progress):
|
||||||
|
await ws_manager.broadcast({
|
||||||
|
'status': 'progress',
|
||||||
|
'progress': progress
|
||||||
|
})
|
||||||
|
|
||||||
result = await self.download_manager.download_from_civitai(
|
result = await self.download_manager.download_from_civitai(
|
||||||
download_url=data.get('download_url'),
|
download_url=data.get('download_url'),
|
||||||
save_dir=data.get('lora_root'),
|
save_dir=data.get('lora_root'),
|
||||||
relative_path=data.get('relative_path')
|
relative_path=data.get('relative_path'),
|
||||||
|
progress_callback=progress_callback # Add progress callback
|
||||||
)
|
)
|
||||||
return web.json_response(result)
|
return web.json_response(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -59,8 +59,18 @@ class CivitaiClient:
|
|||||||
|
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
async def _download_file(self, url: str, save_dir: str, default_filename: str) -> Tuple[bool, str]:
|
async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
|
||||||
"""Download file with content-disposition support"""
|
"""Download file with content-disposition support and progress tracking
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: Download URL
|
||||||
|
save_dir: Directory to save the file
|
||||||
|
default_filename: Fallback filename if none provided in headers
|
||||||
|
progress_callback: Optional async callback function for progress updates (0-100)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, str]: (success, save_path or error message)
|
||||||
|
"""
|
||||||
session = await self.session
|
session = await self.session
|
||||||
try:
|
try:
|
||||||
headers = self._get_request_headers()
|
headers = self._get_request_headers()
|
||||||
@@ -76,13 +86,23 @@ class CivitaiClient:
|
|||||||
|
|
||||||
save_path = os.path.join(save_dir, filename)
|
save_path = os.path.join(save_dir, filename)
|
||||||
|
|
||||||
# Stream download to file
|
# Get total file size for progress calculation
|
||||||
|
total_size = int(response.headers.get('content-length', 0))
|
||||||
|
current_size = 0
|
||||||
|
|
||||||
|
# Stream download to file with progress updates
|
||||||
with open(save_path, 'wb') as f:
|
with open(save_path, 'wb') as f:
|
||||||
while True:
|
async for chunk in response.content.iter_chunked(8192):
|
||||||
chunk = await response.content.read(8192)
|
if chunk:
|
||||||
if not chunk:
|
f.write(chunk)
|
||||||
break
|
current_size += len(chunk)
|
||||||
f.write(chunk)
|
if progress_callback and total_size:
|
||||||
|
progress = (current_size / total_size) * 100
|
||||||
|
await progress_callback(progress)
|
||||||
|
|
||||||
|
# Ensure 100% progress is reported
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback(100)
|
||||||
|
|
||||||
return True, save_path
|
return True, save_path
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,8 @@ class DownloadManager:
|
|||||||
self.civitai_client = CivitaiClient()
|
self.civitai_client = CivitaiClient()
|
||||||
self.file_monitor = file_monitor
|
self.file_monitor = file_monitor
|
||||||
|
|
||||||
async def download_from_civitai(self, download_url: str, save_dir: str, relative_path: str = '') -> Dict:
|
async def download_from_civitai(self, download_url: str, save_dir: str, relative_path: str = '',
|
||||||
|
progress_callback=None) -> Dict:
|
||||||
try:
|
try:
|
||||||
# Update save directory with relative path if provided
|
# Update save directory with relative path if provided
|
||||||
if relative_path:
|
if relative_path:
|
||||||
@@ -27,6 +28,10 @@ class DownloadManager:
|
|||||||
if not version_info:
|
if not version_info:
|
||||||
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
||||||
|
|
||||||
|
# Report initial progress
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback(0)
|
||||||
|
|
||||||
# 2. 获取文件信息
|
# 2. 获取文件信息
|
||||||
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
|
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
|
||||||
if not file_info:
|
if not file_info:
|
||||||
@@ -52,7 +57,8 @@ class DownloadManager:
|
|||||||
save_dir=save_dir,
|
save_dir=save_dir,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
version_info=version_info,
|
version_info=version_info,
|
||||||
relative_path=relative_path
|
relative_path=relative_path,
|
||||||
|
progress_callback=progress_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -62,32 +68,41 @@ class DownloadManager:
|
|||||||
return {'success': False, 'error': str(e)}
|
return {'success': False, 'error': str(e)}
|
||||||
|
|
||||||
async def _execute_download(self, download_url: str, save_dir: str,
|
async def _execute_download(self, download_url: str, save_dir: str,
|
||||||
metadata: LoraMetadata, version_info: Dict, relative_path: str) -> Dict:
|
metadata: LoraMetadata, version_info: Dict,
|
||||||
"""执行实际的下载流程,包括预览图和模型文件"""
|
relative_path: str, progress_callback=None) -> Dict:
|
||||||
|
"""Execute the actual download process including preview images and model files"""
|
||||||
try:
|
try:
|
||||||
save_path = metadata.file_path
|
save_path = metadata.file_path
|
||||||
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
||||||
|
|
||||||
# 2. 下载预览图(如果有)
|
# Download preview image if available
|
||||||
images = version_info.get('images', [])
|
images = version_info.get('images', [])
|
||||||
if images:
|
if images:
|
||||||
|
# Report preview download progress
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback(5) # 5% progress for starting preview download
|
||||||
|
|
||||||
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
|
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
|
||||||
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
|
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
|
||||||
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
|
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
|
||||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||||
# 更新元数据中的预览图URL
|
|
||||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||||
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
|
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
# 3. 下载模型文件
|
# Report preview download completion
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback(10) # 10% progress after preview download
|
||||||
|
|
||||||
|
# Download model file with progress tracking
|
||||||
success, result = await self.civitai_client._download_file(
|
success, result = await self.civitai_client._download_file(
|
||||||
download_url,
|
download_url,
|
||||||
save_dir,
|
save_dir,
|
||||||
os.path.basename(save_path)
|
os.path.basename(save_path),
|
||||||
|
progress_callback=lambda p: self._handle_download_progress(p, progress_callback)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
# 下载失败时清理文件
|
# Clean up files on failure
|
||||||
for path in [save_path, metadata_path, metadata.preview_url]:
|
for path in [save_path, metadata_path, metadata.preview_url]:
|
||||||
if path and os.path.exists(path):
|
if path and os.path.exists(path):
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
@@ -110,14 +125,30 @@ class DownloadManager:
|
|||||||
all_folders.add(relative_path)
|
all_folders.add(relative_path)
|
||||||
cache.folders = sorted(list(all_folders))
|
cache.folders = sorted(list(all_folders))
|
||||||
|
|
||||||
|
# Report 100% completion
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback(100)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'success': True
|
'success': True
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in _execute_download: {e}", exc_info=True)
|
logger.error(f"Error in _execute_download: {e}", exc_info=True)
|
||||||
# 确保清理任何部分下载的文件
|
# Clean up partial downloads
|
||||||
for path in [save_path, metadata_path]:
|
for path in [save_path, metadata_path]:
|
||||||
if path and os.path.exists(path):
|
if path and os.path.exists(path):
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
return {'success': False, 'error': str(e)}
|
return {'success': False, 'error': str(e)}
|
||||||
|
|
||||||
|
async def _handle_download_progress(self, file_progress: float, progress_callback):
|
||||||
|
"""Convert file download progress to overall progress
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_progress: Progress of file download (0-100)
|
||||||
|
progress_callback: Callback function for progress updates
|
||||||
|
"""
|
||||||
|
if progress_callback:
|
||||||
|
# Scale file progress to 10-100 range (after preview download)
|
||||||
|
overall_progress = 10 + (file_progress * 0.9) # 90% of progress for file download
|
||||||
|
await progress_callback(round(overall_progress))
|
||||||
@@ -18,8 +18,8 @@
|
|||||||
|
|
||||||
/* Z-index Scale */
|
/* Z-index Scale */
|
||||||
--z-base: 10;
|
--z-base: 10;
|
||||||
--z-modal: 30;
|
--z-modal: 1000; /* 更新modal的z-index */
|
||||||
--z-overlay: 50;
|
--z-overlay: 2000; /* 更新overlay的z-index,确保比modal高 */
|
||||||
|
|
||||||
/* Border Radius */
|
/* Border Radius */
|
||||||
--border-radius-base: 12px;
|
--border-radius-base: 12px;
|
||||||
@@ -274,7 +274,7 @@ body {
|
|||||||
width: 100%;
|
width: 100%;
|
||||||
height: 100%;
|
height: 100%;
|
||||||
background: rgba(0, 0, 0, 0.8);
|
background: rgba(0, 0, 0, 0.8);
|
||||||
z-index: 1000;
|
z-index: var(--z-modal);
|
||||||
overflow-y: auto; /* 允许模态窗口内容滚动 */
|
overflow-y: auto; /* 允许模态窗口内容滚动 */
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -665,8 +665,7 @@ body.modal-open {
|
|||||||
padding: 12px 16px;
|
padding: 12px 16px;
|
||||||
border-radius: var(--border-radius-sm);
|
border-radius: var(--border-radius-sm);
|
||||||
box-shadow: 0 4px 16px rgba(0, 0, 0, 0.2);
|
box-shadow: 0 4px 16px rgba(0, 0, 0, 0.2);
|
||||||
/* z-index: calc(var(--z-overlay) + 10); */
|
z-index: calc(var(--z-overlay) + 10); /* 让toast显示在最上层 */
|
||||||
z-index: 1000; /* 保证在其他元素之上 */
|
|
||||||
opacity: 0;
|
opacity: 0;
|
||||||
transition: transform 0.3s cubic-bezier(0.4, 0, 0.2, 1),
|
transition: transform 0.3s cubic-bezier(0.4, 0, 0.2, 1),
|
||||||
opacity 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
opacity 0.3s cubic-bezier(0.4, 0, 0.2, 1);
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import { modalManager } from './ModalManager.js';
|
import { modalManager } from './ModalManager.js';
|
||||||
import { showToast } from '../utils/uiHelpers.js';
|
import { showToast } from '../utils/uiHelpers.js';
|
||||||
|
import { LoadingManager } from './LoadingManager.js';
|
||||||
|
|
||||||
export class DownloadManager {
|
export class DownloadManager {
|
||||||
constructor() {
|
constructor() {
|
||||||
@@ -11,6 +12,9 @@ export class DownloadManager {
|
|||||||
// Add initialization check
|
// Add initialization check
|
||||||
this.initialized = false;
|
this.initialized = false;
|
||||||
this.selectedFolder = '';
|
this.selectedFolder = '';
|
||||||
|
|
||||||
|
// Add LoadingManager instance
|
||||||
|
this.loadingManager = new LoadingManager();
|
||||||
}
|
}
|
||||||
|
|
||||||
showDownloadModal() {
|
showDownloadModal() {
|
||||||
@@ -45,6 +49,9 @@ export class DownloadManager {
|
|||||||
const errorElement = document.getElementById('urlError');
|
const errorElement = document.getElementById('urlError');
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
// Show loading while fetching versions
|
||||||
|
this.loadingManager.showSimpleLoading('Fetching model versions...');
|
||||||
|
|
||||||
const modelId = this.extractModelId(url);
|
const modelId = this.extractModelId(url);
|
||||||
if (!modelId) {
|
if (!modelId) {
|
||||||
throw new Error('Invalid Civitai URL format');
|
throw new Error('Invalid Civitai URL format');
|
||||||
@@ -68,6 +75,9 @@ export class DownloadManager {
|
|||||||
this.showVersionStep();
|
this.showVersionStep();
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
errorElement.textContent = error.message;
|
errorElement.textContent = error.message;
|
||||||
|
} finally {
|
||||||
|
// Hide loading when done
|
||||||
|
this.loadingManager.hide();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,9 +165,6 @@ export class DownloadManager {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
console.log('Selected folder:', this.selectedFolder); // Log selected folder
|
|
||||||
console.log('New folder:', newFolder); // Log new folder
|
|
||||||
|
|
||||||
// Construct relative path
|
// Construct relative path
|
||||||
let relativePath = '';
|
let relativePath = '';
|
||||||
if (this.selectedFolder) {
|
if (this.selectedFolder) {
|
||||||
@@ -174,7 +181,20 @@ export class DownloadManager {
|
|||||||
throw new Error('No download URL available');
|
throw new Error('No download URL available');
|
||||||
}
|
}
|
||||||
|
|
||||||
// 只传递必要参数
|
// Show loading with progress bar for download
|
||||||
|
this.loadingManager.show('Downloading LoRA...', 0);
|
||||||
|
|
||||||
|
// Setup WebSocket for progress updates
|
||||||
|
const ws = new WebSocket(`ws://${window.location.host}/ws/fetch-progress`);
|
||||||
|
ws.onmessage = (event) => {
|
||||||
|
const data = JSON.parse(event.data);
|
||||||
|
if (data.status === 'progress') {
|
||||||
|
this.loadingManager.setProgress(data.progress);
|
||||||
|
this.loadingManager.setStatus(`Downloading: ${data.progress}%`);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Start download
|
||||||
const response = await fetch('/api/download-lora', {
|
const response = await fetch('/api/download-lora', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: { 'Content-Type': 'application/json' },
|
headers: { 'Content-Type': 'application/json' },
|
||||||
@@ -195,8 +215,11 @@ export class DownloadManager {
|
|||||||
|
|
||||||
// Refresh the grid to show new model
|
// Refresh the grid to show new model
|
||||||
window.refreshLoras(false);
|
window.refreshLoras(false);
|
||||||
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
showToast(error.message, 'error');
|
showToast(error.message, 'error');
|
||||||
|
} finally {
|
||||||
|
this.loadingManager.hide();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ BASE_MODEL_MAPPING = {
|
|||||||
"sdxl": "SDXL",
|
"sdxl": "SDXL",
|
||||||
"sd-v2": "SD2.0",
|
"sd-v2": "SD2.0",
|
||||||
"flux1": "Flux.1 D",
|
"flux1": "Flux.1 D",
|
||||||
|
"flux.1 d": "Flux.1 D",
|
||||||
"illustrious": "IL",
|
"illustrious": "IL",
|
||||||
"pony": "Pony"
|
"pony": "Pony"
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user