diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index ab642e01..6f7f7747 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -766,7 +766,15 @@ class ModelDownloadHandler: progress_data = self._ws_manager.get_download_progress(download_id) if progress_data is None: return web.json_response({"success": False, "error": "Download ID not found"}, status=404) - return web.json_response({"success": True, "progress": progress_data.get("progress", 0)}) + return web.json_response( + { + "success": True, + "progress": progress_data.get("progress", 0), + "bytes_downloaded": progress_data.get("bytes_downloaded"), + "total_bytes": progress_data.get("total_bytes"), + "bytes_per_second": progress_data.get("bytes_per_second", 0.0), + } + ) except Exception as exc: self._logger.error("Error getting download progress: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) diff --git a/py/services/download_coordinator.py b/py/services/download_coordinator.py index 4cf866e5..51700793 100644 --- a/py/services/download_coordinator.py +++ b/py/services/download_coordinator.py @@ -5,6 +5,8 @@ from __future__ import annotations import logging from typing import Any, Awaitable, Callable, Dict, Optional +from .downloader import DownloadProgress + logger = logging.getLogger(__name__) @@ -29,14 +31,40 @@ class DownloadCoordinator: download_id = payload.get("download_id") or self._ws_manager.generate_download_id() payload.setdefault("download_id", download_id) - async def progress_callback(progress: Any) -> None: + async def progress_callback(progress: Any, snapshot: Optional[DownloadProgress] = None) -> None: + percent = 0.0 + metrics: Optional[DownloadProgress] = None + + if isinstance(progress, DownloadProgress): + metrics = progress + percent = progress.percent_complete + elif isinstance(snapshot, DownloadProgress): + metrics = snapshot + percent = snapshot.percent_complete + else: + try: + percent = float(progress) + except (TypeError, ValueError): + percent = 0.0 + + payload: Dict[str, Any] = { + "status": "progress", + "progress": round(percent), + "download_id": download_id, + } + + if metrics is not None: + payload.update( + { + "bytes_downloaded": metrics.bytes_downloaded, + "total_bytes": metrics.total_bytes, + "bytes_per_second": metrics.bytes_per_second, + } + ) + await self._ws_manager.broadcast_download_progress( download_id, - { - "status": "progress", - "progress": progress, - "download_id": download_id, - }, + payload, ) model_id = self._parse_optional_int(payload.get("model_id"), "model_id") diff --git a/py/services/download_manager.py b/py/services/download_manager.py index bc20bd29..bcff61ad 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -1,9 +1,10 @@ import logging import os import asyncio +import inspect from collections import OrderedDict import uuid -from typing import Dict, List +from typing import Dict, List, Optional, Tuple from urllib.parse import urlparse from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES @@ -13,7 +14,7 @@ from ..utils.metadata_manager import MetadataManager from .service_registry import ServiceRegistry from .settings_manager import get_settings_manager from .metadata_service import get_default_metadata_provider -from .downloader import get_downloader +from .downloader import get_downloader, DownloadProgress # Download to temporary file first import tempfile @@ -82,7 +83,10 @@ class DownloadManager: 'model_id': model_id, 'model_version_id': model_version_id, 'progress': 0, - 'status': 'queued' + 'status': 'queued', + 'bytes_downloaded': 0, + 'total_bytes': None, + 'bytes_per_second': 0.0, } # Create tracking task @@ -119,11 +123,19 @@ class DownloadManager: # Wrap progress callback to track progress in active_downloads original_callback = progress_callback - async def tracking_callback(progress): + async def tracking_callback(progress, metrics=None): + progress_value, snapshot = self._normalize_progress(progress, metrics) + if task_id in self._active_downloads: - self._active_downloads[task_id]['progress'] = progress + info = self._active_downloads[task_id] + info['progress'] = round(progress_value) + if snapshot is not None: + info['bytes_downloaded'] = snapshot.bytes_downloaded + info['total_bytes'] = snapshot.total_bytes + info['bytes_per_second'] = snapshot.bytes_per_second + if original_callback: - await original_callback(progress) + await self._dispatch_progress(original_callback, snapshot, progress_value) # Acquire semaphore to limit concurrent downloads try: @@ -149,12 +161,14 @@ class DownloadManager: self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed' if not result['success']: self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error') + self._active_downloads[task_id]['bytes_per_second'] = 0.0 return result except asyncio.CancelledError: # Handle cancellation if task_id in self._active_downloads: self._active_downloads[task_id]['status'] = 'cancelled' + self._active_downloads[task_id]['bytes_per_second'] = 0.0 logger.info(f"Download cancelled for task {task_id}") raise except Exception as e: @@ -163,6 +177,7 @@ class DownloadManager: if task_id in self._active_downloads: self._active_downloads[task_id]['status'] = 'failed' self._active_downloads[task_id]['error'] = str(e) + self._active_downloads[task_id]['bytes_per_second'] = 0.0 return {'success': False, 'error': str(e)} finally: # Schedule cleanup of download record after delay @@ -551,7 +566,11 @@ class DownloadManager: success, result = await downloader.download_file( download_url, save_path, # Use full path instead of separate dir and filename - progress_callback=lambda p: self._handle_download_progress(p, progress_callback), + progress_callback=lambda progress, snapshot=None: self._handle_download_progress( + progress, + progress_callback, + snapshot, + ), use_auth=use_auth # Only use authentication for Civitai downloads ) @@ -631,17 +650,33 @@ class DownloadManager: return {'success': False, 'error': str(e)} - async def _handle_download_progress(self, file_progress: float, progress_callback): - """Convert file download progress to overall progress - - Args: - file_progress: Progress of file download (0-100) - progress_callback: Callback function for progress updates - """ - if progress_callback: - # Scale file progress to 3-100 range (after preview download) - overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download - await progress_callback(round(overall_progress)) + async def _handle_download_progress( + self, + progress_update, + progress_callback, + snapshot=None, + ): + """Convert file download progress to overall progress.""" + + if not progress_callback: + return + + file_progress, original_snapshot = self._normalize_progress(progress_update, snapshot) + overall_progress = 3 + (file_progress * 0.97) + overall_progress = max(0.0, min(overall_progress, 100.0)) + rounded_progress = round(overall_progress) + + normalized_snapshot: Optional[DownloadProgress] = None + if original_snapshot is not None: + normalized_snapshot = DownloadProgress( + percent_complete=overall_progress, + bytes_downloaded=original_snapshot.bytes_downloaded, + total_bytes=original_snapshot.total_bytes, + bytes_per_second=original_snapshot.bytes_per_second, + timestamp=original_snapshot.timestamp, + ) + + await self._dispatch_progress(progress_callback, normalized_snapshot, rounded_progress) async def cancel_download(self, download_id: str) -> Dict: """Cancel an active download by download_id @@ -663,6 +698,7 @@ class DownloadManager: # Update status in active downloads if download_id in self._active_downloads: self._active_downloads[download_id]['status'] = 'cancelling' + self._active_downloads[download_id]['bytes_per_second'] = 0.0 # Wait briefly for the task to acknowledge cancellation try: @@ -725,7 +761,53 @@ class DownloadManager: except Exception as e: logger.error(f"Error cancelling download: {e}", exc_info=True) return {'success': False, 'error': str(e)} - + + @staticmethod + def _coerce_progress_value(progress) -> float: + try: + return float(progress) + except (TypeError, ValueError): + return 0.0 + + @classmethod + def _normalize_progress( + cls, + progress, + snapshot: Optional[DownloadProgress] = None, + ) -> Tuple[float, Optional[DownloadProgress]]: + if isinstance(progress, DownloadProgress): + return progress.percent_complete, progress + + if isinstance(snapshot, DownloadProgress): + return snapshot.percent_complete, snapshot + + if isinstance(progress, dict): + if 'percent_complete' in progress: + return cls._coerce_progress_value(progress['percent_complete']), snapshot + if 'progress' in progress: + return cls._coerce_progress_value(progress['progress']), snapshot + + return cls._coerce_progress_value(progress), None + + async def _dispatch_progress( + self, + callback, + snapshot: Optional[DownloadProgress], + progress_value: float, + ) -> None: + try: + if snapshot is not None: + result = callback(snapshot, snapshot) + else: + result = callback(progress_value) + except TypeError: + result = callback(progress_value) + + if inspect.isawaitable(result): + await result + elif asyncio.iscoroutine(result): + await result + async def get_active_downloads(self) -> Dict: """Get information about all active downloads @@ -740,7 +822,10 @@ class DownloadManager: 'model_version_id': info.get('model_version_id'), 'progress': info.get('progress', 0), 'status': info.get('status', 'unknown'), - 'error': info.get('error', None) + 'error': info.get('error', None), + 'bytes_downloaded': info.get('bytes_downloaded', 0), + 'total_bytes': info.get('total_bytes'), + 'bytes_per_second': info.get('bytes_per_second', 0.0), } for task_id, info in self._active_downloads.items() ] diff --git a/py/services/downloader.py b/py/services/downloader.py index 5e1d2f4f..5c2a1f0c 100644 --- a/py/services/downloader.py +++ b/py/services/downloader.py @@ -14,13 +14,26 @@ import os import logging import asyncio import aiohttp -from datetime import datetime -from typing import Optional, Dict, Tuple, Callable, Union +from collections import deque +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Optional, Dict, Tuple, Callable, Union, Awaitable from ..services.settings_manager import get_settings_manager logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class DownloadProgress: + """Snapshot of a download transfer at a moment in time.""" + + percent_complete: float + bytes_downloaded: int + total_bytes: Optional[int] + bytes_per_second: float + timestamp: float + + class Downloader: """Unified downloader for all HTTP/HTTPS downloads in the application.""" @@ -159,7 +172,7 @@ class Downloader: self, url: str, save_path: str, - progress_callback: Optional[Callable[[float], None]] = None, + progress_callback: Optional[Callable[..., Awaitable[None]]] = None, use_auth: bool = False, custom_headers: Optional[Dict[str, str]] = None, allow_resume: bool = True @@ -248,7 +261,16 @@ class Downloader: if allow_resume: os.rename(part_path, save_path) if progress_callback: - await progress_callback(100) + await self._dispatch_progress_callback( + progress_callback, + DownloadProgress( + percent_complete=100.0, + bytes_downloaded=part_size, + total_bytes=actual_size, + bytes_per_second=0.0, + timestamp=datetime.now().timestamp(), + ), + ) return True, save_path # Remove corrupted part file and restart os.remove(part_path) @@ -276,6 +298,8 @@ class Downloader: current_size = resume_offset last_progress_report_time = datetime.now() + progress_samples: deque[tuple[datetime, int]] = deque() + progress_samples.append((last_progress_report_time, current_size)) # Ensure directory exists os.makedirs(os.path.dirname(save_path), exist_ok=True) @@ -289,14 +313,35 @@ class Downloader: # Run blocking file write in executor await loop.run_in_executor(None, f.write, chunk) current_size += len(chunk) - + # Limit progress update frequency to reduce overhead now = datetime.now() time_diff = (now - last_progress_report_time).total_seconds() - - if progress_callback and total_size and time_diff >= 1.0: - progress = (current_size / total_size) * 100 - await progress_callback(progress) + + if progress_callback and time_diff >= 1.0: + progress_samples.append((now, current_size)) + cutoff = now - timedelta(seconds=5) + while progress_samples and progress_samples[0][0] < cutoff: + progress_samples.popleft() + + percent = (current_size / total_size) * 100 if total_size else 0.0 + bytes_per_second = 0.0 + if len(progress_samples) >= 2: + first_time, first_bytes = progress_samples[0] + last_time, last_bytes = progress_samples[-1] + elapsed = (last_time - first_time).total_seconds() + if elapsed > 0: + bytes_per_second = (last_bytes - first_bytes) / elapsed + + progress_snapshot = DownloadProgress( + percent_complete=percent, + bytes_downloaded=current_size, + total_bytes=total_size or None, + bytes_per_second=bytes_per_second, + timestamp=now.timestamp(), + ) + + await self._dispatch_progress_callback(progress_callback, progress_snapshot) last_progress_report_time = now # Download completed successfully @@ -331,7 +376,15 @@ class Downloader: # Ensure 100% progress is reported if progress_callback: - await progress_callback(100) + final_snapshot = DownloadProgress( + percent_complete=100.0, + bytes_downloaded=final_size, + total_bytes=total_size or final_size, + bytes_per_second=0.0, + timestamp=datetime.now().timestamp(), + ) + await self._dispatch_progress_callback(progress_callback, final_snapshot) + return True, save_path @@ -363,7 +416,24 @@ class Downloader: return False, str(e) return False, f"Download failed after {self.max_retries + 1} attempts" - + + async def _dispatch_progress_callback( + self, + progress_callback: Callable[..., Awaitable[None]], + snapshot: DownloadProgress, + ) -> None: + """Invoke a progress callback while preserving backward compatibility.""" + + try: + result = progress_callback(snapshot, snapshot) + except TypeError: + result = progress_callback(snapshot.percent_complete) + + if asyncio.iscoroutine(result): + await result + elif hasattr(result, "__await__"): + await result + async def download_to_memory( self, url: str, diff --git a/py/services/metadata_archive_manager.py b/py/services/metadata_archive_manager.py index 49b22c01..c9e24697 100644 --- a/py/services/metadata_archive_manager.py +++ b/py/services/metadata_archive_manager.py @@ -3,7 +3,7 @@ import logging import asyncio from pathlib import Path from typing import Optional -from .downloader import get_downloader +from .downloader import get_downloader, DownloadProgress logger = logging.getLogger(__name__) @@ -77,9 +77,15 @@ class MetadataArchiveManager: progress_callback("download", f"Downloading from {url}") # Custom progress callback to report download progress - async def download_progress(progress): + async def download_progress(progress, snapshot=None): if progress_callback: - progress_callback("download", f"Downloading archive... {progress:.1f}%") + if isinstance(progress, DownloadProgress): + percent = progress.percent_complete + elif isinstance(snapshot, DownloadProgress): + percent = snapshot.percent_complete + else: + percent = float(progress or 0) + progress_callback("download", f"Downloading archive... {percent:.1f}%") success, result = await downloader.download_file( url=url, diff --git a/py/services/websocket_manager.py b/py/services/websocket_manager.py index 223d8b76..b98dc5ba 100644 --- a/py/services/websocket_manager.py +++ b/py/services/websocket_manager.py @@ -155,11 +155,16 @@ class WebSocketManager: async def broadcast_download_progress(self, download_id: str, data: Dict): """Send progress update to specific download client""" - # Store simplified progress data in memory (only progress percentage) - self._download_progress[download_id] = { + progress_entry = { 'progress': data.get('progress', 0), - 'timestamp': datetime.now() + 'timestamp': datetime.now(), } + + for field in ('bytes_downloaded', 'total_bytes', 'bytes_per_second'): + if field in data: + progress_entry[field] = data[field] + + self._download_progress[download_id] = progress_entry if download_id not in self._download_websockets: logger.debug(f"No WebSocket found for download ID: {download_id}") diff --git a/tests/routes/test_base_model_routes_smoke.py b/tests/routes/test_base_model_routes_smoke.py index bf7b6560..88bff245 100644 --- a/tests/routes/test_base_model_routes_smoke.py +++ b/tests/routes/test_base_model_routes_smoke.py @@ -5,6 +5,7 @@ import sys from pathlib import Path import types +from typing import Optional folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: []) sys.modules.setdefault("folder_paths", folder_paths_stub) @@ -16,6 +17,7 @@ from aiohttp.test_utils import TestClient, TestServer from py.config import config from py.routes.base_model_routes import BaseModelRoutes from py.services import model_file_service +from py.services.downloader import DownloadProgress from py.services.metadata_sync_service import MetadataSyncService from py.services.model_file_service import AutoOrganizeResult from py.services.service_registry import ServiceRegistry @@ -59,12 +61,21 @@ def download_manager_stub(): self.error = None self.cancelled = [] self.active_downloads = {} + self.last_progress_snapshot: Optional[DownloadProgress] = None async def download_from_civitai(self, **kwargs): self.calls.append(kwargs) if self.error is not None: raise self.error - await kwargs["progress_callback"](42) + snapshot = DownloadProgress( + percent_complete=50.0, + bytes_downloaded=5120, + total_bytes=10240, + bytes_per_second=2048.0, + timestamp=0.0, + ) + self.last_progress_snapshot = snapshot + await kwargs["progress_callback"](snapshot) return {"success": True, "path": "/tmp/model.safetensors"} async def cancel_download(self, download_id): @@ -332,7 +343,11 @@ def test_download_model_invokes_download_manager( assert call_args["download_id"] == payload["download_id"] progress = ws_manager.get_download_progress(payload["download_id"]) assert progress is not None - assert progress["progress"] == 42 + expected_progress = round(download_manager_stub.last_progress_snapshot.percent_complete) + assert progress["progress"] == expected_progress + assert progress["bytes_downloaded"] == download_manager_stub.last_progress_snapshot.bytes_downloaded + assert progress["total_bytes"] == download_manager_stub.last_progress_snapshot.total_bytes + assert progress["bytes_per_second"] == download_manager_stub.last_progress_snapshot.bytes_per_second assert "timestamp" in progress progress_response = await client.get( @@ -341,7 +356,13 @@ def test_download_model_invokes_download_manager( progress_payload = await progress_response.json() assert progress_response.status == 200 - assert progress_payload == {"success": True, "progress": 42} + assert progress_payload == { + "success": True, + "progress": expected_progress, + "bytes_downloaded": download_manager_stub.last_progress_snapshot.bytes_downloaded, + "total_bytes": download_manager_stub.last_progress_snapshot.total_bytes, + "bytes_per_second": download_manager_stub.last_progress_snapshot.bytes_per_second, + } ws_manager.cleanup_download_progress(payload["download_id"]) finally: await client.close() diff --git a/tests/services/test_route_support_services.py b/tests/services/test_route_support_services.py index 544333ec..aa10e483 100644 --- a/tests/services/test_route_support_services.py +++ b/tests/services/test_route_support_services.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List import pytest from py.services.download_coordinator import DownloadCoordinator +from py.services.downloader import DownloadProgress from py.services.metadata_sync_service import MetadataSyncService from py.services.preview_asset_service import PreviewAssetService from py.services.tag_update_service import TagUpdateService @@ -191,10 +192,17 @@ def test_download_coordinator_emits_progress() -> None: class DownloadManagerStub: def __init__(self) -> None: self.calls: List[Dict[str, Any]] = [] + self.snapshot = DownloadProgress( + percent_complete=25.0, + bytes_downloaded=256, + total_bytes=1024, + bytes_per_second=128.0, + timestamp=0.0, + ) async def download_from_civitai(self, **kwargs) -> Dict[str, Any]: self.calls.append(kwargs) - await kwargs["progress_callback"](10) + await kwargs["progress_callback"](self.snapshot) return {"success": True} async def cancel_download(self, download_id: str) -> Dict[str, Any]: @@ -216,6 +224,12 @@ def test_download_coordinator_emits_progress() -> None: assert result["success"] is True assert manager_stub.calls assert ws_stub.progress_events + expected_progress = round(manager_stub.snapshot.percent_complete) + first_event = ws_stub.progress_events[0] + assert first_event["progress"] == expected_progress + assert first_event["bytes_downloaded"] == manager_stub.snapshot.bytes_downloaded + assert first_event["total_bytes"] == manager_stub.snapshot.total_bytes + assert first_event["bytes_per_second"] == manager_stub.snapshot.bytes_per_second cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"])) assert cancel_result["success"] is True diff --git a/tests/services/test_websocket_manager.py b/tests/services/test_websocket_manager.py index 9f512aad..9d8052ea 100644 --- a/tests/services/test_websocket_manager.py +++ b/tests/services/test_websocket_manager.py @@ -86,10 +86,29 @@ async def test_broadcast_download_progress_tracks_state(manager): download_id = "abc" manager._download_websockets[download_id] = ws - await manager.broadcast_download_progress(download_id, {"progress": 55}) + await manager.broadcast_download_progress( + download_id, + { + "progress": 55, + "bytes_downloaded": 512, + "total_bytes": 1024, + "bytes_per_second": 128.0, + }, + ) - assert ws.messages == [{"progress": 55}] - assert manager.get_download_progress(download_id)["progress"] == 55 + assert ws.messages == [ + { + "progress": 55, + "bytes_downloaded": 512, + "total_bytes": 1024, + "bytes_per_second": 128.0, + } + ] + stored = manager.get_download_progress(download_id) + assert stored["progress"] == 55 + assert stored["bytes_downloaded"] == 512 + assert stored["total_bytes"] == 1024 + assert stored["bytes_per_second"] == 128.0 async def test_broadcast_download_progress_to_multiple_updates(manager): @@ -107,7 +126,9 @@ async def test_broadcast_download_progress_to_multiple_updates(manager): async def test_broadcast_download_progress_missing_socket(manager): await manager.broadcast_download_progress("missing", {"progress": 30}) # Progress should be stored even without a live websocket - assert manager.get_download_progress("missing")["progress"] == 30 + missing = manager.get_download_progress("missing") + assert missing["progress"] == 30 + assert "bytes_downloaded" not in missing async def test_auto_organize_progress_helpers(manager):