mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 06:32:12 -03:00
feat(downloads): expose throughput metrics in progress APIs
This commit is contained in:
@@ -766,7 +766,15 @@ class ModelDownloadHandler:
|
|||||||
progress_data = self._ws_manager.get_download_progress(download_id)
|
progress_data = self._ws_manager.get_download_progress(download_id)
|
||||||
if progress_data is None:
|
if progress_data is None:
|
||||||
return web.json_response({"success": False, "error": "Download ID not found"}, status=404)
|
return web.json_response({"success": False, "error": "Download ID not found"}, status=404)
|
||||||
return web.json_response({"success": True, "progress": progress_data.get("progress", 0)})
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"progress": progress_data.get("progress", 0),
|
||||||
|
"bytes_downloaded": progress_data.get("bytes_downloaded"),
|
||||||
|
"total_bytes": progress_data.get("total_bytes"),
|
||||||
|
"bytes_per_second": progress_data.get("bytes_per_second", 0.0),
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._logger.error("Error getting download progress: %s", exc, exc_info=True)
|
self._logger.error("Error getting download progress: %s", exc, exc_info=True)
|
||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||||
|
|
||||||
|
from .downloader import DownloadProgress
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -29,14 +31,40 @@ class DownloadCoordinator:
|
|||||||
download_id = payload.get("download_id") or self._ws_manager.generate_download_id()
|
download_id = payload.get("download_id") or self._ws_manager.generate_download_id()
|
||||||
payload.setdefault("download_id", download_id)
|
payload.setdefault("download_id", download_id)
|
||||||
|
|
||||||
async def progress_callback(progress: Any) -> None:
|
async def progress_callback(progress: Any, snapshot: Optional[DownloadProgress] = None) -> None:
|
||||||
|
percent = 0.0
|
||||||
|
metrics: Optional[DownloadProgress] = None
|
||||||
|
|
||||||
|
if isinstance(progress, DownloadProgress):
|
||||||
|
metrics = progress
|
||||||
|
percent = progress.percent_complete
|
||||||
|
elif isinstance(snapshot, DownloadProgress):
|
||||||
|
metrics = snapshot
|
||||||
|
percent = snapshot.percent_complete
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
percent = float(progress)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
percent = 0.0
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"status": "progress",
|
||||||
|
"progress": round(percent),
|
||||||
|
"download_id": download_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if metrics is not None:
|
||||||
|
payload.update(
|
||||||
|
{
|
||||||
|
"bytes_downloaded": metrics.bytes_downloaded,
|
||||||
|
"total_bytes": metrics.total_bytes,
|
||||||
|
"bytes_per_second": metrics.bytes_per_second,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
await self._ws_manager.broadcast_download_progress(
|
await self._ws_manager.broadcast_download_progress(
|
||||||
download_id,
|
download_id,
|
||||||
{
|
payload,
|
||||||
"status": "progress",
|
|
||||||
"progress": progress,
|
|
||||||
"download_id": download_id,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_id = self._parse_optional_int(payload.get("model_id"), "model_id")
|
model_id = self._parse_optional_int(payload.get("model_id"), "model_id")
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||||
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
|
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
|
||||||
@@ -13,7 +14,7 @@ from ..utils.metadata_manager import MetadataManager
|
|||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
from .settings_manager import get_settings_manager
|
from .settings_manager import get_settings_manager
|
||||||
from .metadata_service import get_default_metadata_provider
|
from .metadata_service import get_default_metadata_provider
|
||||||
from .downloader import get_downloader
|
from .downloader import get_downloader, DownloadProgress
|
||||||
|
|
||||||
# Download to temporary file first
|
# Download to temporary file first
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -82,7 +83,10 @@ class DownloadManager:
|
|||||||
'model_id': model_id,
|
'model_id': model_id,
|
||||||
'model_version_id': model_version_id,
|
'model_version_id': model_version_id,
|
||||||
'progress': 0,
|
'progress': 0,
|
||||||
'status': 'queued'
|
'status': 'queued',
|
||||||
|
'bytes_downloaded': 0,
|
||||||
|
'total_bytes': None,
|
||||||
|
'bytes_per_second': 0.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create tracking task
|
# Create tracking task
|
||||||
@@ -119,11 +123,19 @@ class DownloadManager:
|
|||||||
|
|
||||||
# Wrap progress callback to track progress in active_downloads
|
# Wrap progress callback to track progress in active_downloads
|
||||||
original_callback = progress_callback
|
original_callback = progress_callback
|
||||||
async def tracking_callback(progress):
|
async def tracking_callback(progress, metrics=None):
|
||||||
|
progress_value, snapshot = self._normalize_progress(progress, metrics)
|
||||||
|
|
||||||
if task_id in self._active_downloads:
|
if task_id in self._active_downloads:
|
||||||
self._active_downloads[task_id]['progress'] = progress
|
info = self._active_downloads[task_id]
|
||||||
|
info['progress'] = round(progress_value)
|
||||||
|
if snapshot is not None:
|
||||||
|
info['bytes_downloaded'] = snapshot.bytes_downloaded
|
||||||
|
info['total_bytes'] = snapshot.total_bytes
|
||||||
|
info['bytes_per_second'] = snapshot.bytes_per_second
|
||||||
|
|
||||||
if original_callback:
|
if original_callback:
|
||||||
await original_callback(progress)
|
await self._dispatch_progress(original_callback, snapshot, progress_value)
|
||||||
|
|
||||||
# Acquire semaphore to limit concurrent downloads
|
# Acquire semaphore to limit concurrent downloads
|
||||||
try:
|
try:
|
||||||
@@ -149,12 +161,14 @@ class DownloadManager:
|
|||||||
self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed'
|
self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed'
|
||||||
if not result['success']:
|
if not result['success']:
|
||||||
self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error')
|
self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error')
|
||||||
|
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
# Handle cancellation
|
# Handle cancellation
|
||||||
if task_id in self._active_downloads:
|
if task_id in self._active_downloads:
|
||||||
self._active_downloads[task_id]['status'] = 'cancelled'
|
self._active_downloads[task_id]['status'] = 'cancelled'
|
||||||
|
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
||||||
logger.info(f"Download cancelled for task {task_id}")
|
logger.info(f"Download cancelled for task {task_id}")
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -163,6 +177,7 @@ class DownloadManager:
|
|||||||
if task_id in self._active_downloads:
|
if task_id in self._active_downloads:
|
||||||
self._active_downloads[task_id]['status'] = 'failed'
|
self._active_downloads[task_id]['status'] = 'failed'
|
||||||
self._active_downloads[task_id]['error'] = str(e)
|
self._active_downloads[task_id]['error'] = str(e)
|
||||||
|
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
||||||
return {'success': False, 'error': str(e)}
|
return {'success': False, 'error': str(e)}
|
||||||
finally:
|
finally:
|
||||||
# Schedule cleanup of download record after delay
|
# Schedule cleanup of download record after delay
|
||||||
@@ -551,7 +566,11 @@ class DownloadManager:
|
|||||||
success, result = await downloader.download_file(
|
success, result = await downloader.download_file(
|
||||||
download_url,
|
download_url,
|
||||||
save_path, # Use full path instead of separate dir and filename
|
save_path, # Use full path instead of separate dir and filename
|
||||||
progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
|
progress_callback=lambda progress, snapshot=None: self._handle_download_progress(
|
||||||
|
progress,
|
||||||
|
progress_callback,
|
||||||
|
snapshot,
|
||||||
|
),
|
||||||
use_auth=use_auth # Only use authentication for Civitai downloads
|
use_auth=use_auth # Only use authentication for Civitai downloads
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -631,17 +650,33 @@ class DownloadManager:
|
|||||||
|
|
||||||
return {'success': False, 'error': str(e)}
|
return {'success': False, 'error': str(e)}
|
||||||
|
|
||||||
async def _handle_download_progress(self, file_progress: float, progress_callback):
|
async def _handle_download_progress(
|
||||||
"""Convert file download progress to overall progress
|
self,
|
||||||
|
progress_update,
|
||||||
|
progress_callback,
|
||||||
|
snapshot=None,
|
||||||
|
):
|
||||||
|
"""Convert file download progress to overall progress."""
|
||||||
|
|
||||||
Args:
|
if not progress_callback:
|
||||||
file_progress: Progress of file download (0-100)
|
return
|
||||||
progress_callback: Callback function for progress updates
|
|
||||||
"""
|
file_progress, original_snapshot = self._normalize_progress(progress_update, snapshot)
|
||||||
if progress_callback:
|
overall_progress = 3 + (file_progress * 0.97)
|
||||||
# Scale file progress to 3-100 range (after preview download)
|
overall_progress = max(0.0, min(overall_progress, 100.0))
|
||||||
overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download
|
rounded_progress = round(overall_progress)
|
||||||
await progress_callback(round(overall_progress))
|
|
||||||
|
normalized_snapshot: Optional[DownloadProgress] = None
|
||||||
|
if original_snapshot is not None:
|
||||||
|
normalized_snapshot = DownloadProgress(
|
||||||
|
percent_complete=overall_progress,
|
||||||
|
bytes_downloaded=original_snapshot.bytes_downloaded,
|
||||||
|
total_bytes=original_snapshot.total_bytes,
|
||||||
|
bytes_per_second=original_snapshot.bytes_per_second,
|
||||||
|
timestamp=original_snapshot.timestamp,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._dispatch_progress(progress_callback, normalized_snapshot, rounded_progress)
|
||||||
|
|
||||||
async def cancel_download(self, download_id: str) -> Dict:
|
async def cancel_download(self, download_id: str) -> Dict:
|
||||||
"""Cancel an active download by download_id
|
"""Cancel an active download by download_id
|
||||||
@@ -663,6 +698,7 @@ class DownloadManager:
|
|||||||
# Update status in active downloads
|
# Update status in active downloads
|
||||||
if download_id in self._active_downloads:
|
if download_id in self._active_downloads:
|
||||||
self._active_downloads[download_id]['status'] = 'cancelling'
|
self._active_downloads[download_id]['status'] = 'cancelling'
|
||||||
|
self._active_downloads[download_id]['bytes_per_second'] = 0.0
|
||||||
|
|
||||||
# Wait briefly for the task to acknowledge cancellation
|
# Wait briefly for the task to acknowledge cancellation
|
||||||
try:
|
try:
|
||||||
@@ -726,6 +762,52 @@ class DownloadManager:
|
|||||||
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
||||||
return {'success': False, 'error': str(e)}
|
return {'success': False, 'error': str(e)}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce_progress_value(progress) -> float:
|
||||||
|
try:
|
||||||
|
return float(progress)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _normalize_progress(
|
||||||
|
cls,
|
||||||
|
progress,
|
||||||
|
snapshot: Optional[DownloadProgress] = None,
|
||||||
|
) -> Tuple[float, Optional[DownloadProgress]]:
|
||||||
|
if isinstance(progress, DownloadProgress):
|
||||||
|
return progress.percent_complete, progress
|
||||||
|
|
||||||
|
if isinstance(snapshot, DownloadProgress):
|
||||||
|
return snapshot.percent_complete, snapshot
|
||||||
|
|
||||||
|
if isinstance(progress, dict):
|
||||||
|
if 'percent_complete' in progress:
|
||||||
|
return cls._coerce_progress_value(progress['percent_complete']), snapshot
|
||||||
|
if 'progress' in progress:
|
||||||
|
return cls._coerce_progress_value(progress['progress']), snapshot
|
||||||
|
|
||||||
|
return cls._coerce_progress_value(progress), None
|
||||||
|
|
||||||
|
async def _dispatch_progress(
|
||||||
|
self,
|
||||||
|
callback,
|
||||||
|
snapshot: Optional[DownloadProgress],
|
||||||
|
progress_value: float,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
if snapshot is not None:
|
||||||
|
result = callback(snapshot, snapshot)
|
||||||
|
else:
|
||||||
|
result = callback(progress_value)
|
||||||
|
except TypeError:
|
||||||
|
result = callback(progress_value)
|
||||||
|
|
||||||
|
if inspect.isawaitable(result):
|
||||||
|
await result
|
||||||
|
elif asyncio.iscoroutine(result):
|
||||||
|
await result
|
||||||
|
|
||||||
async def get_active_downloads(self) -> Dict:
|
async def get_active_downloads(self) -> Dict:
|
||||||
"""Get information about all active downloads
|
"""Get information about all active downloads
|
||||||
|
|
||||||
@@ -740,7 +822,10 @@ class DownloadManager:
|
|||||||
'model_version_id': info.get('model_version_id'),
|
'model_version_id': info.get('model_version_id'),
|
||||||
'progress': info.get('progress', 0),
|
'progress': info.get('progress', 0),
|
||||||
'status': info.get('status', 'unknown'),
|
'status': info.get('status', 'unknown'),
|
||||||
'error': info.get('error', None)
|
'error': info.get('error', None),
|
||||||
|
'bytes_downloaded': info.get('bytes_downloaded', 0),
|
||||||
|
'total_bytes': info.get('total_bytes'),
|
||||||
|
'bytes_per_second': info.get('bytes_per_second', 0.0),
|
||||||
}
|
}
|
||||||
for task_id, info in self._active_downloads.items()
|
for task_id, info in self._active_downloads.items()
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -14,13 +14,26 @@ import os
|
|||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from datetime import datetime
|
from collections import deque
|
||||||
from typing import Optional, Dict, Tuple, Callable, Union
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional, Dict, Tuple, Callable, Union, Awaitable
|
||||||
from ..services.settings_manager import get_settings_manager
|
from ..services.settings_manager import get_settings_manager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DownloadProgress:
|
||||||
|
"""Snapshot of a download transfer at a moment in time."""
|
||||||
|
|
||||||
|
percent_complete: float
|
||||||
|
bytes_downloaded: int
|
||||||
|
total_bytes: Optional[int]
|
||||||
|
bytes_per_second: float
|
||||||
|
timestamp: float
|
||||||
|
|
||||||
|
|
||||||
class Downloader:
|
class Downloader:
|
||||||
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
|
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
|
||||||
|
|
||||||
@@ -159,7 +172,7 @@ class Downloader:
|
|||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
save_path: str,
|
save_path: str,
|
||||||
progress_callback: Optional[Callable[[float], None]] = None,
|
progress_callback: Optional[Callable[..., Awaitable[None]]] = None,
|
||||||
use_auth: bool = False,
|
use_auth: bool = False,
|
||||||
custom_headers: Optional[Dict[str, str]] = None,
|
custom_headers: Optional[Dict[str, str]] = None,
|
||||||
allow_resume: bool = True
|
allow_resume: bool = True
|
||||||
@@ -248,7 +261,16 @@ class Downloader:
|
|||||||
if allow_resume:
|
if allow_resume:
|
||||||
os.rename(part_path, save_path)
|
os.rename(part_path, save_path)
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(100)
|
await self._dispatch_progress_callback(
|
||||||
|
progress_callback,
|
||||||
|
DownloadProgress(
|
||||||
|
percent_complete=100.0,
|
||||||
|
bytes_downloaded=part_size,
|
||||||
|
total_bytes=actual_size,
|
||||||
|
bytes_per_second=0.0,
|
||||||
|
timestamp=datetime.now().timestamp(),
|
||||||
|
),
|
||||||
|
)
|
||||||
return True, save_path
|
return True, save_path
|
||||||
# Remove corrupted part file and restart
|
# Remove corrupted part file and restart
|
||||||
os.remove(part_path)
|
os.remove(part_path)
|
||||||
@@ -276,6 +298,8 @@ class Downloader:
|
|||||||
|
|
||||||
current_size = resume_offset
|
current_size = resume_offset
|
||||||
last_progress_report_time = datetime.now()
|
last_progress_report_time = datetime.now()
|
||||||
|
progress_samples: deque[tuple[datetime, int]] = deque()
|
||||||
|
progress_samples.append((last_progress_report_time, current_size))
|
||||||
|
|
||||||
# Ensure directory exists
|
# Ensure directory exists
|
||||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||||
@@ -294,9 +318,30 @@ class Downloader:
|
|||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
time_diff = (now - last_progress_report_time).total_seconds()
|
time_diff = (now - last_progress_report_time).total_seconds()
|
||||||
|
|
||||||
if progress_callback and total_size and time_diff >= 1.0:
|
if progress_callback and time_diff >= 1.0:
|
||||||
progress = (current_size / total_size) * 100
|
progress_samples.append((now, current_size))
|
||||||
await progress_callback(progress)
|
cutoff = now - timedelta(seconds=5)
|
||||||
|
while progress_samples and progress_samples[0][0] < cutoff:
|
||||||
|
progress_samples.popleft()
|
||||||
|
|
||||||
|
percent = (current_size / total_size) * 100 if total_size else 0.0
|
||||||
|
bytes_per_second = 0.0
|
||||||
|
if len(progress_samples) >= 2:
|
||||||
|
first_time, first_bytes = progress_samples[0]
|
||||||
|
last_time, last_bytes = progress_samples[-1]
|
||||||
|
elapsed = (last_time - first_time).total_seconds()
|
||||||
|
if elapsed > 0:
|
||||||
|
bytes_per_second = (last_bytes - first_bytes) / elapsed
|
||||||
|
|
||||||
|
progress_snapshot = DownloadProgress(
|
||||||
|
percent_complete=percent,
|
||||||
|
bytes_downloaded=current_size,
|
||||||
|
total_bytes=total_size or None,
|
||||||
|
bytes_per_second=bytes_per_second,
|
||||||
|
timestamp=now.timestamp(),
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._dispatch_progress_callback(progress_callback, progress_snapshot)
|
||||||
last_progress_report_time = now
|
last_progress_report_time = now
|
||||||
|
|
||||||
# Download completed successfully
|
# Download completed successfully
|
||||||
@@ -331,7 +376,15 @@ class Downloader:
|
|||||||
|
|
||||||
# Ensure 100% progress is reported
|
# Ensure 100% progress is reported
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(100)
|
final_snapshot = DownloadProgress(
|
||||||
|
percent_complete=100.0,
|
||||||
|
bytes_downloaded=final_size,
|
||||||
|
total_bytes=total_size or final_size,
|
||||||
|
bytes_per_second=0.0,
|
||||||
|
timestamp=datetime.now().timestamp(),
|
||||||
|
)
|
||||||
|
await self._dispatch_progress_callback(progress_callback, final_snapshot)
|
||||||
|
|
||||||
|
|
||||||
return True, save_path
|
return True, save_path
|
||||||
|
|
||||||
@@ -364,6 +417,23 @@ class Downloader:
|
|||||||
|
|
||||||
return False, f"Download failed after {self.max_retries + 1} attempts"
|
return False, f"Download failed after {self.max_retries + 1} attempts"
|
||||||
|
|
||||||
|
async def _dispatch_progress_callback(
|
||||||
|
self,
|
||||||
|
progress_callback: Callable[..., Awaitable[None]],
|
||||||
|
snapshot: DownloadProgress,
|
||||||
|
) -> None:
|
||||||
|
"""Invoke a progress callback while preserving backward compatibility."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = progress_callback(snapshot, snapshot)
|
||||||
|
except TypeError:
|
||||||
|
result = progress_callback(snapshot.percent_complete)
|
||||||
|
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
await result
|
||||||
|
elif hasattr(result, "__await__"):
|
||||||
|
await result
|
||||||
|
|
||||||
async def download_to_memory(
|
async def download_to_memory(
|
||||||
self,
|
self,
|
||||||
url: str,
|
url: str,
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import logging
|
|||||||
import asyncio
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from .downloader import get_downloader
|
from .downloader import get_downloader, DownloadProgress
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -77,9 +77,15 @@ class MetadataArchiveManager:
|
|||||||
progress_callback("download", f"Downloading from {url}")
|
progress_callback("download", f"Downloading from {url}")
|
||||||
|
|
||||||
# Custom progress callback to report download progress
|
# Custom progress callback to report download progress
|
||||||
async def download_progress(progress):
|
async def download_progress(progress, snapshot=None):
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback("download", f"Downloading archive... {progress:.1f}%")
|
if isinstance(progress, DownloadProgress):
|
||||||
|
percent = progress.percent_complete
|
||||||
|
elif isinstance(snapshot, DownloadProgress):
|
||||||
|
percent = snapshot.percent_complete
|
||||||
|
else:
|
||||||
|
percent = float(progress or 0)
|
||||||
|
progress_callback("download", f"Downloading archive... {percent:.1f}%")
|
||||||
|
|
||||||
success, result = await downloader.download_file(
|
success, result = await downloader.download_file(
|
||||||
url=url,
|
url=url,
|
||||||
|
|||||||
@@ -155,12 +155,17 @@ class WebSocketManager:
|
|||||||
|
|
||||||
async def broadcast_download_progress(self, download_id: str, data: Dict):
|
async def broadcast_download_progress(self, download_id: str, data: Dict):
|
||||||
"""Send progress update to specific download client"""
|
"""Send progress update to specific download client"""
|
||||||
# Store simplified progress data in memory (only progress percentage)
|
progress_entry = {
|
||||||
self._download_progress[download_id] = {
|
|
||||||
'progress': data.get('progress', 0),
|
'progress': data.get('progress', 0),
|
||||||
'timestamp': datetime.now()
|
'timestamp': datetime.now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for field in ('bytes_downloaded', 'total_bytes', 'bytes_per_second'):
|
||||||
|
if field in data:
|
||||||
|
progress_entry[field] = data[field]
|
||||||
|
|
||||||
|
self._download_progress[download_id] = progress_entry
|
||||||
|
|
||||||
if download_id not in self._download_websockets:
|
if download_id not in self._download_websockets:
|
||||||
logger.debug(f"No WebSocket found for download ID: {download_id}")
|
logger.debug(f"No WebSocket found for download ID: {download_id}")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import sys
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import types
|
import types
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: [])
|
folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: [])
|
||||||
sys.modules.setdefault("folder_paths", folder_paths_stub)
|
sys.modules.setdefault("folder_paths", folder_paths_stub)
|
||||||
@@ -16,6 +17,7 @@ from aiohttp.test_utils import TestClient, TestServer
|
|||||||
from py.config import config
|
from py.config import config
|
||||||
from py.routes.base_model_routes import BaseModelRoutes
|
from py.routes.base_model_routes import BaseModelRoutes
|
||||||
from py.services import model_file_service
|
from py.services import model_file_service
|
||||||
|
from py.services.downloader import DownloadProgress
|
||||||
from py.services.metadata_sync_service import MetadataSyncService
|
from py.services.metadata_sync_service import MetadataSyncService
|
||||||
from py.services.model_file_service import AutoOrganizeResult
|
from py.services.model_file_service import AutoOrganizeResult
|
||||||
from py.services.service_registry import ServiceRegistry
|
from py.services.service_registry import ServiceRegistry
|
||||||
@@ -59,12 +61,21 @@ def download_manager_stub():
|
|||||||
self.error = None
|
self.error = None
|
||||||
self.cancelled = []
|
self.cancelled = []
|
||||||
self.active_downloads = {}
|
self.active_downloads = {}
|
||||||
|
self.last_progress_snapshot: Optional[DownloadProgress] = None
|
||||||
|
|
||||||
async def download_from_civitai(self, **kwargs):
|
async def download_from_civitai(self, **kwargs):
|
||||||
self.calls.append(kwargs)
|
self.calls.append(kwargs)
|
||||||
if self.error is not None:
|
if self.error is not None:
|
||||||
raise self.error
|
raise self.error
|
||||||
await kwargs["progress_callback"](42)
|
snapshot = DownloadProgress(
|
||||||
|
percent_complete=50.0,
|
||||||
|
bytes_downloaded=5120,
|
||||||
|
total_bytes=10240,
|
||||||
|
bytes_per_second=2048.0,
|
||||||
|
timestamp=0.0,
|
||||||
|
)
|
||||||
|
self.last_progress_snapshot = snapshot
|
||||||
|
await kwargs["progress_callback"](snapshot)
|
||||||
return {"success": True, "path": "/tmp/model.safetensors"}
|
return {"success": True, "path": "/tmp/model.safetensors"}
|
||||||
|
|
||||||
async def cancel_download(self, download_id):
|
async def cancel_download(self, download_id):
|
||||||
@@ -332,7 +343,11 @@ def test_download_model_invokes_download_manager(
|
|||||||
assert call_args["download_id"] == payload["download_id"]
|
assert call_args["download_id"] == payload["download_id"]
|
||||||
progress = ws_manager.get_download_progress(payload["download_id"])
|
progress = ws_manager.get_download_progress(payload["download_id"])
|
||||||
assert progress is not None
|
assert progress is not None
|
||||||
assert progress["progress"] == 42
|
expected_progress = round(download_manager_stub.last_progress_snapshot.percent_complete)
|
||||||
|
assert progress["progress"] == expected_progress
|
||||||
|
assert progress["bytes_downloaded"] == download_manager_stub.last_progress_snapshot.bytes_downloaded
|
||||||
|
assert progress["total_bytes"] == download_manager_stub.last_progress_snapshot.total_bytes
|
||||||
|
assert progress["bytes_per_second"] == download_manager_stub.last_progress_snapshot.bytes_per_second
|
||||||
assert "timestamp" in progress
|
assert "timestamp" in progress
|
||||||
|
|
||||||
progress_response = await client.get(
|
progress_response = await client.get(
|
||||||
@@ -341,7 +356,13 @@ def test_download_model_invokes_download_manager(
|
|||||||
progress_payload = await progress_response.json()
|
progress_payload = await progress_response.json()
|
||||||
|
|
||||||
assert progress_response.status == 200
|
assert progress_response.status == 200
|
||||||
assert progress_payload == {"success": True, "progress": 42}
|
assert progress_payload == {
|
||||||
|
"success": True,
|
||||||
|
"progress": expected_progress,
|
||||||
|
"bytes_downloaded": download_manager_stub.last_progress_snapshot.bytes_downloaded,
|
||||||
|
"total_bytes": download_manager_stub.last_progress_snapshot.total_bytes,
|
||||||
|
"bytes_per_second": download_manager_stub.last_progress_snapshot.bytes_per_second,
|
||||||
|
}
|
||||||
ws_manager.cleanup_download_progress(payload["download_id"])
|
ws_manager.cleanup_download_progress(payload["download_id"])
|
||||||
finally:
|
finally:
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import Any, Dict, List
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from py.services.download_coordinator import DownloadCoordinator
|
from py.services.download_coordinator import DownloadCoordinator
|
||||||
|
from py.services.downloader import DownloadProgress
|
||||||
from py.services.metadata_sync_service import MetadataSyncService
|
from py.services.metadata_sync_service import MetadataSyncService
|
||||||
from py.services.preview_asset_service import PreviewAssetService
|
from py.services.preview_asset_service import PreviewAssetService
|
||||||
from py.services.tag_update_service import TagUpdateService
|
from py.services.tag_update_service import TagUpdateService
|
||||||
@@ -191,10 +192,17 @@ def test_download_coordinator_emits_progress() -> None:
|
|||||||
class DownloadManagerStub:
|
class DownloadManagerStub:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.calls: List[Dict[str, Any]] = []
|
self.calls: List[Dict[str, Any]] = []
|
||||||
|
self.snapshot = DownloadProgress(
|
||||||
|
percent_complete=25.0,
|
||||||
|
bytes_downloaded=256,
|
||||||
|
total_bytes=1024,
|
||||||
|
bytes_per_second=128.0,
|
||||||
|
timestamp=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
async def download_from_civitai(self, **kwargs) -> Dict[str, Any]:
|
async def download_from_civitai(self, **kwargs) -> Dict[str, Any]:
|
||||||
self.calls.append(kwargs)
|
self.calls.append(kwargs)
|
||||||
await kwargs["progress_callback"](10)
|
await kwargs["progress_callback"](self.snapshot)
|
||||||
return {"success": True}
|
return {"success": True}
|
||||||
|
|
||||||
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||||
@@ -216,6 +224,12 @@ def test_download_coordinator_emits_progress() -> None:
|
|||||||
assert result["success"] is True
|
assert result["success"] is True
|
||||||
assert manager_stub.calls
|
assert manager_stub.calls
|
||||||
assert ws_stub.progress_events
|
assert ws_stub.progress_events
|
||||||
|
expected_progress = round(manager_stub.snapshot.percent_complete)
|
||||||
|
first_event = ws_stub.progress_events[0]
|
||||||
|
assert first_event["progress"] == expected_progress
|
||||||
|
assert first_event["bytes_downloaded"] == manager_stub.snapshot.bytes_downloaded
|
||||||
|
assert first_event["total_bytes"] == manager_stub.snapshot.total_bytes
|
||||||
|
assert first_event["bytes_per_second"] == manager_stub.snapshot.bytes_per_second
|
||||||
|
|
||||||
cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"]))
|
cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"]))
|
||||||
assert cancel_result["success"] is True
|
assert cancel_result["success"] is True
|
||||||
|
|||||||
@@ -86,10 +86,29 @@ async def test_broadcast_download_progress_tracks_state(manager):
|
|||||||
download_id = "abc"
|
download_id = "abc"
|
||||||
manager._download_websockets[download_id] = ws
|
manager._download_websockets[download_id] = ws
|
||||||
|
|
||||||
await manager.broadcast_download_progress(download_id, {"progress": 55})
|
await manager.broadcast_download_progress(
|
||||||
|
download_id,
|
||||||
|
{
|
||||||
|
"progress": 55,
|
||||||
|
"bytes_downloaded": 512,
|
||||||
|
"total_bytes": 1024,
|
||||||
|
"bytes_per_second": 128.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
assert ws.messages == [{"progress": 55}]
|
assert ws.messages == [
|
||||||
assert manager.get_download_progress(download_id)["progress"] == 55
|
{
|
||||||
|
"progress": 55,
|
||||||
|
"bytes_downloaded": 512,
|
||||||
|
"total_bytes": 1024,
|
||||||
|
"bytes_per_second": 128.0,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
stored = manager.get_download_progress(download_id)
|
||||||
|
assert stored["progress"] == 55
|
||||||
|
assert stored["bytes_downloaded"] == 512
|
||||||
|
assert stored["total_bytes"] == 1024
|
||||||
|
assert stored["bytes_per_second"] == 128.0
|
||||||
|
|
||||||
|
|
||||||
async def test_broadcast_download_progress_to_multiple_updates(manager):
|
async def test_broadcast_download_progress_to_multiple_updates(manager):
|
||||||
@@ -107,7 +126,9 @@ async def test_broadcast_download_progress_to_multiple_updates(manager):
|
|||||||
async def test_broadcast_download_progress_missing_socket(manager):
|
async def test_broadcast_download_progress_missing_socket(manager):
|
||||||
await manager.broadcast_download_progress("missing", {"progress": 30})
|
await manager.broadcast_download_progress("missing", {"progress": 30})
|
||||||
# Progress should be stored even without a live websocket
|
# Progress should be stored even without a live websocket
|
||||||
assert manager.get_download_progress("missing")["progress"] == 30
|
missing = manager.get_download_progress("missing")
|
||||||
|
assert missing["progress"] == 30
|
||||||
|
assert "bytes_downloaded" not in missing
|
||||||
|
|
||||||
|
|
||||||
async def test_auto_organize_progress_helpers(manager):
|
async def test_auto_organize_progress_helpers(manager):
|
||||||
|
|||||||
Reference in New Issue
Block a user