mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(downloads): expose throughput metrics in progress APIs
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
import inspect
|
||||
from collections import OrderedDict
|
||||
import uuid
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
|
||||
@@ -13,7 +14,7 @@ from ..utils.metadata_manager import MetadataManager
|
||||
from .service_registry import ServiceRegistry
|
||||
from .settings_manager import get_settings_manager
|
||||
from .metadata_service import get_default_metadata_provider
|
||||
from .downloader import get_downloader
|
||||
from .downloader import get_downloader, DownloadProgress
|
||||
|
||||
# Download to temporary file first
|
||||
import tempfile
|
||||
@@ -82,7 +83,10 @@ class DownloadManager:
|
||||
'model_id': model_id,
|
||||
'model_version_id': model_version_id,
|
||||
'progress': 0,
|
||||
'status': 'queued'
|
||||
'status': 'queued',
|
||||
'bytes_downloaded': 0,
|
||||
'total_bytes': None,
|
||||
'bytes_per_second': 0.0,
|
||||
}
|
||||
|
||||
# Create tracking task
|
||||
@@ -119,11 +123,19 @@ class DownloadManager:
|
||||
|
||||
# Wrap progress callback to track progress in active_downloads
|
||||
original_callback = progress_callback
|
||||
async def tracking_callback(progress):
|
||||
async def tracking_callback(progress, metrics=None):
|
||||
progress_value, snapshot = self._normalize_progress(progress, metrics)
|
||||
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['progress'] = progress
|
||||
info = self._active_downloads[task_id]
|
||||
info['progress'] = round(progress_value)
|
||||
if snapshot is not None:
|
||||
info['bytes_downloaded'] = snapshot.bytes_downloaded
|
||||
info['total_bytes'] = snapshot.total_bytes
|
||||
info['bytes_per_second'] = snapshot.bytes_per_second
|
||||
|
||||
if original_callback:
|
||||
await original_callback(progress)
|
||||
await self._dispatch_progress(original_callback, snapshot, progress_value)
|
||||
|
||||
# Acquire semaphore to limit concurrent downloads
|
||||
try:
|
||||
@@ -149,12 +161,14 @@ class DownloadManager:
|
||||
self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed'
|
||||
if not result['success']:
|
||||
self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error')
|
||||
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
||||
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
# Handle cancellation
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'cancelled'
|
||||
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
||||
logger.info(f"Download cancelled for task {task_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -163,6 +177,7 @@ class DownloadManager:
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'failed'
|
||||
self._active_downloads[task_id]['error'] = str(e)
|
||||
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
||||
return {'success': False, 'error': str(e)}
|
||||
finally:
|
||||
# Schedule cleanup of download record after delay
|
||||
@@ -551,7 +566,11 @@ class DownloadManager:
|
||||
success, result = await downloader.download_file(
|
||||
download_url,
|
||||
save_path, # Use full path instead of separate dir and filename
|
||||
progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
|
||||
progress_callback=lambda progress, snapshot=None: self._handle_download_progress(
|
||||
progress,
|
||||
progress_callback,
|
||||
snapshot,
|
||||
),
|
||||
use_auth=use_auth # Only use authentication for Civitai downloads
|
||||
)
|
||||
|
||||
@@ -631,17 +650,33 @@ class DownloadManager:
|
||||
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
async def _handle_download_progress(self, file_progress: float, progress_callback):
|
||||
"""Convert file download progress to overall progress
|
||||
|
||||
Args:
|
||||
file_progress: Progress of file download (0-100)
|
||||
progress_callback: Callback function for progress updates
|
||||
"""
|
||||
if progress_callback:
|
||||
# Scale file progress to 3-100 range (after preview download)
|
||||
overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download
|
||||
await progress_callback(round(overall_progress))
|
||||
async def _handle_download_progress(
|
||||
self,
|
||||
progress_update,
|
||||
progress_callback,
|
||||
snapshot=None,
|
||||
):
|
||||
"""Convert file download progress to overall progress."""
|
||||
|
||||
if not progress_callback:
|
||||
return
|
||||
|
||||
file_progress, original_snapshot = self._normalize_progress(progress_update, snapshot)
|
||||
overall_progress = 3 + (file_progress * 0.97)
|
||||
overall_progress = max(0.0, min(overall_progress, 100.0))
|
||||
rounded_progress = round(overall_progress)
|
||||
|
||||
normalized_snapshot: Optional[DownloadProgress] = None
|
||||
if original_snapshot is not None:
|
||||
normalized_snapshot = DownloadProgress(
|
||||
percent_complete=overall_progress,
|
||||
bytes_downloaded=original_snapshot.bytes_downloaded,
|
||||
total_bytes=original_snapshot.total_bytes,
|
||||
bytes_per_second=original_snapshot.bytes_per_second,
|
||||
timestamp=original_snapshot.timestamp,
|
||||
)
|
||||
|
||||
await self._dispatch_progress(progress_callback, normalized_snapshot, rounded_progress)
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict:
|
||||
"""Cancel an active download by download_id
|
||||
@@ -663,6 +698,7 @@ class DownloadManager:
|
||||
# Update status in active downloads
|
||||
if download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]['status'] = 'cancelling'
|
||||
self._active_downloads[download_id]['bytes_per_second'] = 0.0
|
||||
|
||||
# Wait briefly for the task to acknowledge cancellation
|
||||
try:
|
||||
@@ -725,7 +761,53 @@ class DownloadManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _coerce_progress_value(progress) -> float:
|
||||
try:
|
||||
return float(progress)
|
||||
except (TypeError, ValueError):
|
||||
return 0.0
|
||||
|
||||
@classmethod
|
||||
def _normalize_progress(
|
||||
cls,
|
||||
progress,
|
||||
snapshot: Optional[DownloadProgress] = None,
|
||||
) -> Tuple[float, Optional[DownloadProgress]]:
|
||||
if isinstance(progress, DownloadProgress):
|
||||
return progress.percent_complete, progress
|
||||
|
||||
if isinstance(snapshot, DownloadProgress):
|
||||
return snapshot.percent_complete, snapshot
|
||||
|
||||
if isinstance(progress, dict):
|
||||
if 'percent_complete' in progress:
|
||||
return cls._coerce_progress_value(progress['percent_complete']), snapshot
|
||||
if 'progress' in progress:
|
||||
return cls._coerce_progress_value(progress['progress']), snapshot
|
||||
|
||||
return cls._coerce_progress_value(progress), None
|
||||
|
||||
async def _dispatch_progress(
|
||||
self,
|
||||
callback,
|
||||
snapshot: Optional[DownloadProgress],
|
||||
progress_value: float,
|
||||
) -> None:
|
||||
try:
|
||||
if snapshot is not None:
|
||||
result = callback(snapshot, snapshot)
|
||||
else:
|
||||
result = callback(progress_value)
|
||||
except TypeError:
|
||||
result = callback(progress_value)
|
||||
|
||||
if inspect.isawaitable(result):
|
||||
await result
|
||||
elif asyncio.iscoroutine(result):
|
||||
await result
|
||||
|
||||
async def get_active_downloads(self) -> Dict:
|
||||
"""Get information about all active downloads
|
||||
|
||||
@@ -740,7 +822,10 @@ class DownloadManager:
|
||||
'model_version_id': info.get('model_version_id'),
|
||||
'progress': info.get('progress', 0),
|
||||
'status': info.get('status', 'unknown'),
|
||||
'error': info.get('error', None)
|
||||
'error': info.get('error', None),
|
||||
'bytes_downloaded': info.get('bytes_downloaded', 0),
|
||||
'total_bytes': info.get('total_bytes'),
|
||||
'bytes_per_second': info.get('bytes_per_second', 0.0),
|
||||
}
|
||||
for task_id, info in self._active_downloads.items()
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user