mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 06:02:11 -03:00
Add progress tracking to downloads
This commit is contained in:
@@ -59,8 +59,18 @@ class CivitaiClient:
|
||||
|
||||
return headers
|
||||
|
||||
async def _download_file(self, url: str, save_dir: str, default_filename: str) -> Tuple[bool, str]:
|
||||
"""Download file with content-disposition support"""
|
||||
async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
|
||||
"""Download file with content-disposition support and progress tracking
|
||||
|
||||
Args:
|
||||
url: Download URL
|
||||
save_dir: Directory to save the file
|
||||
default_filename: Fallback filename if none provided in headers
|
||||
progress_callback: Optional async callback function for progress updates (0-100)
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (success, save_path or error message)
|
||||
"""
|
||||
session = await self.session
|
||||
try:
|
||||
headers = self._get_request_headers()
|
||||
@@ -76,13 +86,23 @@ class CivitaiClient:
|
||||
|
||||
save_path = os.path.join(save_dir, filename)
|
||||
|
||||
# Stream download to file
|
||||
# Get total file size for progress calculation
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
current_size = 0
|
||||
|
||||
# Stream download to file with progress updates
|
||||
with open(save_path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await response.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
current_size += len(chunk)
|
||||
if progress_callback and total_size:
|
||||
progress = (current_size / total_size) * 100
|
||||
await progress_callback(progress)
|
||||
|
||||
# Ensure 100% progress is reported
|
||||
if progress_callback:
|
||||
await progress_callback(100)
|
||||
|
||||
return True, save_path
|
||||
|
||||
|
||||
@@ -13,7 +13,8 @@ class DownloadManager:
|
||||
self.civitai_client = CivitaiClient()
|
||||
self.file_monitor = file_monitor
|
||||
|
||||
async def download_from_civitai(self, download_url: str, save_dir: str, relative_path: str = '') -> Dict:
|
||||
async def download_from_civitai(self, download_url: str, save_dir: str, relative_path: str = '',
|
||||
progress_callback=None) -> Dict:
|
||||
try:
|
||||
# Update save directory with relative path if provided
|
||||
if relative_path:
|
||||
@@ -27,6 +28,10 @@ class DownloadManager:
|
||||
if not version_info:
|
||||
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
||||
|
||||
# Report initial progress
|
||||
if progress_callback:
|
||||
await progress_callback(0)
|
||||
|
||||
# 2. 获取文件信息
|
||||
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
|
||||
if not file_info:
|
||||
@@ -52,7 +57,8 @@ class DownloadManager:
|
||||
save_dir=save_dir,
|
||||
metadata=metadata,
|
||||
version_info=version_info,
|
||||
relative_path=relative_path
|
||||
relative_path=relative_path,
|
||||
progress_callback=progress_callback
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -62,32 +68,41 @@ class DownloadManager:
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
async def _execute_download(self, download_url: str, save_dir: str,
|
||||
metadata: LoraMetadata, version_info: Dict, relative_path: str) -> Dict:
|
||||
"""执行实际的下载流程,包括预览图和模型文件"""
|
||||
metadata: LoraMetadata, version_info: Dict,
|
||||
relative_path: str, progress_callback=None) -> Dict:
|
||||
"""Execute the actual download process including preview images and model files"""
|
||||
try:
|
||||
save_path = metadata.file_path
|
||||
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
||||
|
||||
# 2. 下载预览图(如果有)
|
||||
# Download preview image if available
|
||||
images = version_info.get('images', [])
|
||||
if images:
|
||||
# Report preview download progress
|
||||
if progress_callback:
|
||||
await progress_callback(5) # 5% progress for starting preview download
|
||||
|
||||
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
|
||||
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
|
||||
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
# 更新元数据中的预览图URL
|
||||
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
|
||||
|
||||
# 3. 下载模型文件
|
||||
# Report preview download completion
|
||||
if progress_callback:
|
||||
await progress_callback(10) # 10% progress after preview download
|
||||
|
||||
# Download model file with progress tracking
|
||||
success, result = await self.civitai_client._download_file(
|
||||
download_url,
|
||||
save_dir,
|
||||
os.path.basename(save_path)
|
||||
os.path.basename(save_path),
|
||||
progress_callback=lambda p: self._handle_download_progress(p, progress_callback)
|
||||
)
|
||||
|
||||
if not success:
|
||||
# 下载失败时清理文件
|
||||
# Clean up files on failure
|
||||
for path in [save_path, metadata_path, metadata.preview_url]:
|
||||
if path and os.path.exists(path):
|
||||
os.remove(path)
|
||||
@@ -110,14 +125,30 @@ class DownloadManager:
|
||||
all_folders.add(relative_path)
|
||||
cache.folders = sorted(list(all_folders))
|
||||
|
||||
# Report 100% completion
|
||||
if progress_callback:
|
||||
await progress_callback(100)
|
||||
|
||||
return {
|
||||
'success': True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _execute_download: {e}", exc_info=True)
|
||||
# 确保清理任何部分下载的文件
|
||||
# Clean up partial downloads
|
||||
for path in [save_path, metadata_path]:
|
||||
if path and os.path.exists(path):
|
||||
os.remove(path)
|
||||
return {'success': False, 'error': str(e)}
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
async def _handle_download_progress(self, file_progress: float, progress_callback):
|
||||
"""Convert file download progress to overall progress
|
||||
|
||||
Args:
|
||||
file_progress: Progress of file download (0-100)
|
||||
progress_callback: Callback function for progress updates
|
||||
"""
|
||||
if progress_callback:
|
||||
# Scale file progress to 10-100 range (after preview download)
|
||||
overall_progress = 10 + (file_progress * 0.9) # 90% of progress for file download
|
||||
await progress_callback(round(overall_progress))
|
||||
Reference in New Issue
Block a user