feat(downloads): expose throughput metrics in progress APIs

This commit is contained in:
pixelpaws
2025-10-13 14:39:31 +08:00
parent 01bbaa31a8
commit eb76468280
9 changed files with 310 additions and 52 deletions

View File

@@ -766,7 +766,15 @@ class ModelDownloadHandler:
progress_data = self._ws_manager.get_download_progress(download_id) progress_data = self._ws_manager.get_download_progress(download_id)
if progress_data is None: if progress_data is None:
return web.json_response({"success": False, "error": "Download ID not found"}, status=404) return web.json_response({"success": False, "error": "Download ID not found"}, status=404)
return web.json_response({"success": True, "progress": progress_data.get("progress", 0)}) return web.json_response(
{
"success": True,
"progress": progress_data.get("progress", 0),
"bytes_downloaded": progress_data.get("bytes_downloaded"),
"total_bytes": progress_data.get("total_bytes"),
"bytes_per_second": progress_data.get("bytes_per_second", 0.0),
}
)
except Exception as exc: except Exception as exc:
self._logger.error("Error getting download progress: %s", exc, exc_info=True) self._logger.error("Error getting download progress: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500) return web.json_response({"success": False, "error": str(exc)}, status=500)

View File

@@ -5,6 +5,8 @@ from __future__ import annotations
import logging import logging
from typing import Any, Awaitable, Callable, Dict, Optional from typing import Any, Awaitable, Callable, Dict, Optional
from .downloader import DownloadProgress
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -29,14 +31,40 @@ class DownloadCoordinator:
download_id = payload.get("download_id") or self._ws_manager.generate_download_id() download_id = payload.get("download_id") or self._ws_manager.generate_download_id()
payload.setdefault("download_id", download_id) payload.setdefault("download_id", download_id)
async def progress_callback(progress: Any) -> None: async def progress_callback(progress: Any, snapshot: Optional[DownloadProgress] = None) -> None:
percent = 0.0
metrics: Optional[DownloadProgress] = None
if isinstance(progress, DownloadProgress):
metrics = progress
percent = progress.percent_complete
elif isinstance(snapshot, DownloadProgress):
metrics = snapshot
percent = snapshot.percent_complete
else:
try:
percent = float(progress)
except (TypeError, ValueError):
percent = 0.0
payload: Dict[str, Any] = {
"status": "progress",
"progress": round(percent),
"download_id": download_id,
}
if metrics is not None:
payload.update(
{
"bytes_downloaded": metrics.bytes_downloaded,
"total_bytes": metrics.total_bytes,
"bytes_per_second": metrics.bytes_per_second,
}
)
await self._ws_manager.broadcast_download_progress( await self._ws_manager.broadcast_download_progress(
download_id, download_id,
{ payload,
"status": "progress",
"progress": progress,
"download_id": download_id,
},
) )
model_id = self._parse_optional_int(payload.get("model_id"), "model_id") model_id = self._parse_optional_int(payload.get("model_id"), "model_id")

View File

@@ -1,9 +1,10 @@
import logging import logging
import os import os
import asyncio import asyncio
import inspect
from collections import OrderedDict from collections import OrderedDict
import uuid import uuid
from typing import Dict, List from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse from urllib.parse import urlparse
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
@@ -13,7 +14,7 @@ from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry from .service_registry import ServiceRegistry
from .settings_manager import get_settings_manager from .settings_manager import get_settings_manager
from .metadata_service import get_default_metadata_provider from .metadata_service import get_default_metadata_provider
from .downloader import get_downloader from .downloader import get_downloader, DownloadProgress
# Download to temporary file first # Download to temporary file first
import tempfile import tempfile
@@ -82,7 +83,10 @@ class DownloadManager:
'model_id': model_id, 'model_id': model_id,
'model_version_id': model_version_id, 'model_version_id': model_version_id,
'progress': 0, 'progress': 0,
'status': 'queued' 'status': 'queued',
'bytes_downloaded': 0,
'total_bytes': None,
'bytes_per_second': 0.0,
} }
# Create tracking task # Create tracking task
@@ -119,11 +123,19 @@ class DownloadManager:
# Wrap progress callback to track progress in active_downloads # Wrap progress callback to track progress in active_downloads
original_callback = progress_callback original_callback = progress_callback
async def tracking_callback(progress): async def tracking_callback(progress, metrics=None):
progress_value, snapshot = self._normalize_progress(progress, metrics)
if task_id in self._active_downloads: if task_id in self._active_downloads:
self._active_downloads[task_id]['progress'] = progress info = self._active_downloads[task_id]
info['progress'] = round(progress_value)
if snapshot is not None:
info['bytes_downloaded'] = snapshot.bytes_downloaded
info['total_bytes'] = snapshot.total_bytes
info['bytes_per_second'] = snapshot.bytes_per_second
if original_callback: if original_callback:
await original_callback(progress) await self._dispatch_progress(original_callback, snapshot, progress_value)
# Acquire semaphore to limit concurrent downloads # Acquire semaphore to limit concurrent downloads
try: try:
@@ -149,12 +161,14 @@ class DownloadManager:
self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed' self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed'
if not result['success']: if not result['success']:
self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error') self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error')
self._active_downloads[task_id]['bytes_per_second'] = 0.0
return result return result
except asyncio.CancelledError: except asyncio.CancelledError:
# Handle cancellation # Handle cancellation
if task_id in self._active_downloads: if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'cancelled' self._active_downloads[task_id]['status'] = 'cancelled'
self._active_downloads[task_id]['bytes_per_second'] = 0.0
logger.info(f"Download cancelled for task {task_id}") logger.info(f"Download cancelled for task {task_id}")
raise raise
except Exception as e: except Exception as e:
@@ -163,6 +177,7 @@ class DownloadManager:
if task_id in self._active_downloads: if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'failed' self._active_downloads[task_id]['status'] = 'failed'
self._active_downloads[task_id]['error'] = str(e) self._active_downloads[task_id]['error'] = str(e)
self._active_downloads[task_id]['bytes_per_second'] = 0.0
return {'success': False, 'error': str(e)} return {'success': False, 'error': str(e)}
finally: finally:
# Schedule cleanup of download record after delay # Schedule cleanup of download record after delay
@@ -551,7 +566,11 @@ class DownloadManager:
success, result = await downloader.download_file( success, result = await downloader.download_file(
download_url, download_url,
save_path, # Use full path instead of separate dir and filename save_path, # Use full path instead of separate dir and filename
progress_callback=lambda p: self._handle_download_progress(p, progress_callback), progress_callback=lambda progress, snapshot=None: self._handle_download_progress(
progress,
progress_callback,
snapshot,
),
use_auth=use_auth # Only use authentication for Civitai downloads use_auth=use_auth # Only use authentication for Civitai downloads
) )
@@ -631,17 +650,33 @@ class DownloadManager:
return {'success': False, 'error': str(e)} return {'success': False, 'error': str(e)}
async def _handle_download_progress(self, file_progress: float, progress_callback): async def _handle_download_progress(
"""Convert file download progress to overall progress self,
progress_update,
progress_callback,
snapshot=None,
):
"""Convert file download progress to overall progress."""
Args: if not progress_callback:
file_progress: Progress of file download (0-100) return
progress_callback: Callback function for progress updates
""" file_progress, original_snapshot = self._normalize_progress(progress_update, snapshot)
if progress_callback: overall_progress = 3 + (file_progress * 0.97)
# Scale file progress to 3-100 range (after preview download) overall_progress = max(0.0, min(overall_progress, 100.0))
overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download rounded_progress = round(overall_progress)
await progress_callback(round(overall_progress))
normalized_snapshot: Optional[DownloadProgress] = None
if original_snapshot is not None:
normalized_snapshot = DownloadProgress(
percent_complete=overall_progress,
bytes_downloaded=original_snapshot.bytes_downloaded,
total_bytes=original_snapshot.total_bytes,
bytes_per_second=original_snapshot.bytes_per_second,
timestamp=original_snapshot.timestamp,
)
await self._dispatch_progress(progress_callback, normalized_snapshot, rounded_progress)
async def cancel_download(self, download_id: str) -> Dict: async def cancel_download(self, download_id: str) -> Dict:
"""Cancel an active download by download_id """Cancel an active download by download_id
@@ -663,6 +698,7 @@ class DownloadManager:
# Update status in active downloads # Update status in active downloads
if download_id in self._active_downloads: if download_id in self._active_downloads:
self._active_downloads[download_id]['status'] = 'cancelling' self._active_downloads[download_id]['status'] = 'cancelling'
self._active_downloads[download_id]['bytes_per_second'] = 0.0
# Wait briefly for the task to acknowledge cancellation # Wait briefly for the task to acknowledge cancellation
try: try:
@@ -726,6 +762,52 @@ class DownloadManager:
logger.error(f"Error cancelling download: {e}", exc_info=True) logger.error(f"Error cancelling download: {e}", exc_info=True)
return {'success': False, 'error': str(e)} return {'success': False, 'error': str(e)}
@staticmethod
def _coerce_progress_value(progress) -> float:
try:
return float(progress)
except (TypeError, ValueError):
return 0.0
@classmethod
def _normalize_progress(
cls,
progress,
snapshot: Optional[DownloadProgress] = None,
) -> Tuple[float, Optional[DownloadProgress]]:
if isinstance(progress, DownloadProgress):
return progress.percent_complete, progress
if isinstance(snapshot, DownloadProgress):
return snapshot.percent_complete, snapshot
if isinstance(progress, dict):
if 'percent_complete' in progress:
return cls._coerce_progress_value(progress['percent_complete']), snapshot
if 'progress' in progress:
return cls._coerce_progress_value(progress['progress']), snapshot
return cls._coerce_progress_value(progress), None
async def _dispatch_progress(
self,
callback,
snapshot: Optional[DownloadProgress],
progress_value: float,
) -> None:
try:
if snapshot is not None:
result = callback(snapshot, snapshot)
else:
result = callback(progress_value)
except TypeError:
result = callback(progress_value)
if inspect.isawaitable(result):
await result
elif asyncio.iscoroutine(result):
await result
async def get_active_downloads(self) -> Dict: async def get_active_downloads(self) -> Dict:
"""Get information about all active downloads """Get information about all active downloads
@@ -740,7 +822,10 @@ class DownloadManager:
'model_version_id': info.get('model_version_id'), 'model_version_id': info.get('model_version_id'),
'progress': info.get('progress', 0), 'progress': info.get('progress', 0),
'status': info.get('status', 'unknown'), 'status': info.get('status', 'unknown'),
'error': info.get('error', None) 'error': info.get('error', None),
'bytes_downloaded': info.get('bytes_downloaded', 0),
'total_bytes': info.get('total_bytes'),
'bytes_per_second': info.get('bytes_per_second', 0.0),
} }
for task_id, info in self._active_downloads.items() for task_id, info in self._active_downloads.items()
] ]

View File

@@ -14,13 +14,26 @@ import os
import logging import logging
import asyncio import asyncio
import aiohttp import aiohttp
from datetime import datetime from collections import deque
from typing import Optional, Dict, Tuple, Callable, Union from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional, Dict, Tuple, Callable, Union, Awaitable
from ..services.settings_manager import get_settings_manager from ..services.settings_manager import get_settings_manager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class DownloadProgress:
"""Snapshot of a download transfer at a moment in time."""
percent_complete: float
bytes_downloaded: int
total_bytes: Optional[int]
bytes_per_second: float
timestamp: float
class Downloader: class Downloader:
"""Unified downloader for all HTTP/HTTPS downloads in the application.""" """Unified downloader for all HTTP/HTTPS downloads in the application."""
@@ -159,7 +172,7 @@ class Downloader:
self, self,
url: str, url: str,
save_path: str, save_path: str,
progress_callback: Optional[Callable[[float], None]] = None, progress_callback: Optional[Callable[..., Awaitable[None]]] = None,
use_auth: bool = False, use_auth: bool = False,
custom_headers: Optional[Dict[str, str]] = None, custom_headers: Optional[Dict[str, str]] = None,
allow_resume: bool = True allow_resume: bool = True
@@ -248,7 +261,16 @@ class Downloader:
if allow_resume: if allow_resume:
os.rename(part_path, save_path) os.rename(part_path, save_path)
if progress_callback: if progress_callback:
await progress_callback(100) await self._dispatch_progress_callback(
progress_callback,
DownloadProgress(
percent_complete=100.0,
bytes_downloaded=part_size,
total_bytes=actual_size,
bytes_per_second=0.0,
timestamp=datetime.now().timestamp(),
),
)
return True, save_path return True, save_path
# Remove corrupted part file and restart # Remove corrupted part file and restart
os.remove(part_path) os.remove(part_path)
@@ -276,6 +298,8 @@ class Downloader:
current_size = resume_offset current_size = resume_offset
last_progress_report_time = datetime.now() last_progress_report_time = datetime.now()
progress_samples: deque[tuple[datetime, int]] = deque()
progress_samples.append((last_progress_report_time, current_size))
# Ensure directory exists # Ensure directory exists
os.makedirs(os.path.dirname(save_path), exist_ok=True) os.makedirs(os.path.dirname(save_path), exist_ok=True)
@@ -294,9 +318,30 @@ class Downloader:
now = datetime.now() now = datetime.now()
time_diff = (now - last_progress_report_time).total_seconds() time_diff = (now - last_progress_report_time).total_seconds()
if progress_callback and total_size and time_diff >= 1.0: if progress_callback and time_diff >= 1.0:
progress = (current_size / total_size) * 100 progress_samples.append((now, current_size))
await progress_callback(progress) cutoff = now - timedelta(seconds=5)
while progress_samples and progress_samples[0][0] < cutoff:
progress_samples.popleft()
percent = (current_size / total_size) * 100 if total_size else 0.0
bytes_per_second = 0.0
if len(progress_samples) >= 2:
first_time, first_bytes = progress_samples[0]
last_time, last_bytes = progress_samples[-1]
elapsed = (last_time - first_time).total_seconds()
if elapsed > 0:
bytes_per_second = (last_bytes - first_bytes) / elapsed
progress_snapshot = DownloadProgress(
percent_complete=percent,
bytes_downloaded=current_size,
total_bytes=total_size or None,
bytes_per_second=bytes_per_second,
timestamp=now.timestamp(),
)
await self._dispatch_progress_callback(progress_callback, progress_snapshot)
last_progress_report_time = now last_progress_report_time = now
# Download completed successfully # Download completed successfully
@@ -331,7 +376,15 @@ class Downloader:
# Ensure 100% progress is reported # Ensure 100% progress is reported
if progress_callback: if progress_callback:
await progress_callback(100) final_snapshot = DownloadProgress(
percent_complete=100.0,
bytes_downloaded=final_size,
total_bytes=total_size or final_size,
bytes_per_second=0.0,
timestamp=datetime.now().timestamp(),
)
await self._dispatch_progress_callback(progress_callback, final_snapshot)
return True, save_path return True, save_path
@@ -364,6 +417,23 @@ class Downloader:
return False, f"Download failed after {self.max_retries + 1} attempts" return False, f"Download failed after {self.max_retries + 1} attempts"
async def _dispatch_progress_callback(
self,
progress_callback: Callable[..., Awaitable[None]],
snapshot: DownloadProgress,
) -> None:
"""Invoke a progress callback while preserving backward compatibility."""
try:
result = progress_callback(snapshot, snapshot)
except TypeError:
result = progress_callback(snapshot.percent_complete)
if asyncio.iscoroutine(result):
await result
elif hasattr(result, "__await__"):
await result
async def download_to_memory( async def download_to_memory(
self, self,
url: str, url: str,

View File

@@ -3,7 +3,7 @@ import logging
import asyncio import asyncio
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from .downloader import get_downloader from .downloader import get_downloader, DownloadProgress
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -77,9 +77,15 @@ class MetadataArchiveManager:
progress_callback("download", f"Downloading from {url}") progress_callback("download", f"Downloading from {url}")
# Custom progress callback to report download progress # Custom progress callback to report download progress
async def download_progress(progress): async def download_progress(progress, snapshot=None):
if progress_callback: if progress_callback:
progress_callback("download", f"Downloading archive... {progress:.1f}%") if isinstance(progress, DownloadProgress):
percent = progress.percent_complete
elif isinstance(snapshot, DownloadProgress):
percent = snapshot.percent_complete
else:
percent = float(progress or 0)
progress_callback("download", f"Downloading archive... {percent:.1f}%")
success, result = await downloader.download_file( success, result = await downloader.download_file(
url=url, url=url,

View File

@@ -155,12 +155,17 @@ class WebSocketManager:
async def broadcast_download_progress(self, download_id: str, data: Dict): async def broadcast_download_progress(self, download_id: str, data: Dict):
"""Send progress update to specific download client""" """Send progress update to specific download client"""
# Store simplified progress data in memory (only progress percentage) progress_entry = {
self._download_progress[download_id] = {
'progress': data.get('progress', 0), 'progress': data.get('progress', 0),
'timestamp': datetime.now() 'timestamp': datetime.now(),
} }
for field in ('bytes_downloaded', 'total_bytes', 'bytes_per_second'):
if field in data:
progress_entry[field] = data[field]
self._download_progress[download_id] = progress_entry
if download_id not in self._download_websockets: if download_id not in self._download_websockets:
logger.debug(f"No WebSocket found for download ID: {download_id}") logger.debug(f"No WebSocket found for download ID: {download_id}")
return return

View File

@@ -5,6 +5,7 @@ import sys
from pathlib import Path from pathlib import Path
import types import types
from typing import Optional
folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: []) folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: [])
sys.modules.setdefault("folder_paths", folder_paths_stub) sys.modules.setdefault("folder_paths", folder_paths_stub)
@@ -16,6 +17,7 @@ from aiohttp.test_utils import TestClient, TestServer
from py.config import config from py.config import config
from py.routes.base_model_routes import BaseModelRoutes from py.routes.base_model_routes import BaseModelRoutes
from py.services import model_file_service from py.services import model_file_service
from py.services.downloader import DownloadProgress
from py.services.metadata_sync_service import MetadataSyncService from py.services.metadata_sync_service import MetadataSyncService
from py.services.model_file_service import AutoOrganizeResult from py.services.model_file_service import AutoOrganizeResult
from py.services.service_registry import ServiceRegistry from py.services.service_registry import ServiceRegistry
@@ -59,12 +61,21 @@ def download_manager_stub():
self.error = None self.error = None
self.cancelled = [] self.cancelled = []
self.active_downloads = {} self.active_downloads = {}
self.last_progress_snapshot: Optional[DownloadProgress] = None
async def download_from_civitai(self, **kwargs): async def download_from_civitai(self, **kwargs):
self.calls.append(kwargs) self.calls.append(kwargs)
if self.error is not None: if self.error is not None:
raise self.error raise self.error
await kwargs["progress_callback"](42) snapshot = DownloadProgress(
percent_complete=50.0,
bytes_downloaded=5120,
total_bytes=10240,
bytes_per_second=2048.0,
timestamp=0.0,
)
self.last_progress_snapshot = snapshot
await kwargs["progress_callback"](snapshot)
return {"success": True, "path": "/tmp/model.safetensors"} return {"success": True, "path": "/tmp/model.safetensors"}
async def cancel_download(self, download_id): async def cancel_download(self, download_id):
@@ -332,7 +343,11 @@ def test_download_model_invokes_download_manager(
assert call_args["download_id"] == payload["download_id"] assert call_args["download_id"] == payload["download_id"]
progress = ws_manager.get_download_progress(payload["download_id"]) progress = ws_manager.get_download_progress(payload["download_id"])
assert progress is not None assert progress is not None
assert progress["progress"] == 42 expected_progress = round(download_manager_stub.last_progress_snapshot.percent_complete)
assert progress["progress"] == expected_progress
assert progress["bytes_downloaded"] == download_manager_stub.last_progress_snapshot.bytes_downloaded
assert progress["total_bytes"] == download_manager_stub.last_progress_snapshot.total_bytes
assert progress["bytes_per_second"] == download_manager_stub.last_progress_snapshot.bytes_per_second
assert "timestamp" in progress assert "timestamp" in progress
progress_response = await client.get( progress_response = await client.get(
@@ -341,7 +356,13 @@ def test_download_model_invokes_download_manager(
progress_payload = await progress_response.json() progress_payload = await progress_response.json()
assert progress_response.status == 200 assert progress_response.status == 200
assert progress_payload == {"success": True, "progress": 42} assert progress_payload == {
"success": True,
"progress": expected_progress,
"bytes_downloaded": download_manager_stub.last_progress_snapshot.bytes_downloaded,
"total_bytes": download_manager_stub.last_progress_snapshot.total_bytes,
"bytes_per_second": download_manager_stub.last_progress_snapshot.bytes_per_second,
}
ws_manager.cleanup_download_progress(payload["download_id"]) ws_manager.cleanup_download_progress(payload["download_id"])
finally: finally:
await client.close() await client.close()

View File

@@ -7,6 +7,7 @@ from typing import Any, Dict, List
import pytest import pytest
from py.services.download_coordinator import DownloadCoordinator from py.services.download_coordinator import DownloadCoordinator
from py.services.downloader import DownloadProgress
from py.services.metadata_sync_service import MetadataSyncService from py.services.metadata_sync_service import MetadataSyncService
from py.services.preview_asset_service import PreviewAssetService from py.services.preview_asset_service import PreviewAssetService
from py.services.tag_update_service import TagUpdateService from py.services.tag_update_service import TagUpdateService
@@ -191,10 +192,17 @@ def test_download_coordinator_emits_progress() -> None:
class DownloadManagerStub: class DownloadManagerStub:
def __init__(self) -> None: def __init__(self) -> None:
self.calls: List[Dict[str, Any]] = [] self.calls: List[Dict[str, Any]] = []
self.snapshot = DownloadProgress(
percent_complete=25.0,
bytes_downloaded=256,
total_bytes=1024,
bytes_per_second=128.0,
timestamp=0.0,
)
async def download_from_civitai(self, **kwargs) -> Dict[str, Any]: async def download_from_civitai(self, **kwargs) -> Dict[str, Any]:
self.calls.append(kwargs) self.calls.append(kwargs)
await kwargs["progress_callback"](10) await kwargs["progress_callback"](self.snapshot)
return {"success": True} return {"success": True}
async def cancel_download(self, download_id: str) -> Dict[str, Any]: async def cancel_download(self, download_id: str) -> Dict[str, Any]:
@@ -216,6 +224,12 @@ def test_download_coordinator_emits_progress() -> None:
assert result["success"] is True assert result["success"] is True
assert manager_stub.calls assert manager_stub.calls
assert ws_stub.progress_events assert ws_stub.progress_events
expected_progress = round(manager_stub.snapshot.percent_complete)
first_event = ws_stub.progress_events[0]
assert first_event["progress"] == expected_progress
assert first_event["bytes_downloaded"] == manager_stub.snapshot.bytes_downloaded
assert first_event["total_bytes"] == manager_stub.snapshot.total_bytes
assert first_event["bytes_per_second"] == manager_stub.snapshot.bytes_per_second
cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"])) cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"]))
assert cancel_result["success"] is True assert cancel_result["success"] is True

View File

@@ -86,10 +86,29 @@ async def test_broadcast_download_progress_tracks_state(manager):
download_id = "abc" download_id = "abc"
manager._download_websockets[download_id] = ws manager._download_websockets[download_id] = ws
await manager.broadcast_download_progress(download_id, {"progress": 55}) await manager.broadcast_download_progress(
download_id,
{
"progress": 55,
"bytes_downloaded": 512,
"total_bytes": 1024,
"bytes_per_second": 128.0,
},
)
assert ws.messages == [{"progress": 55}] assert ws.messages == [
assert manager.get_download_progress(download_id)["progress"] == 55 {
"progress": 55,
"bytes_downloaded": 512,
"total_bytes": 1024,
"bytes_per_second": 128.0,
}
]
stored = manager.get_download_progress(download_id)
assert stored["progress"] == 55
assert stored["bytes_downloaded"] == 512
assert stored["total_bytes"] == 1024
assert stored["bytes_per_second"] == 128.0
async def test_broadcast_download_progress_to_multiple_updates(manager): async def test_broadcast_download_progress_to_multiple_updates(manager):
@@ -107,7 +126,9 @@ async def test_broadcast_download_progress_to_multiple_updates(manager):
async def test_broadcast_download_progress_missing_socket(manager): async def test_broadcast_download_progress_missing_socket(manager):
await manager.broadcast_download_progress("missing", {"progress": 30}) await manager.broadcast_download_progress("missing", {"progress": 30})
# Progress should be stored even without a live websocket # Progress should be stored even without a live websocket
assert manager.get_download_progress("missing")["progress"] == 30 missing = manager.get_download_progress("missing")
assert missing["progress"] == 30
assert "bytes_downloaded" not in missing
async def test_auto_organize_progress_helpers(manager): async def test_auto_organize_progress_helpers(manager):