feat(downloads): expose throughput metrics in progress APIs

This commit is contained in:
pixelpaws
2025-10-13 14:39:31 +08:00
parent 01bbaa31a8
commit eb76468280
9 changed files with 310 additions and 52 deletions

View File

@@ -766,7 +766,15 @@ class ModelDownloadHandler:
progress_data = self._ws_manager.get_download_progress(download_id)
if progress_data is None:
return web.json_response({"success": False, "error": "Download ID not found"}, status=404)
return web.json_response({"success": True, "progress": progress_data.get("progress", 0)})
return web.json_response(
{
"success": True,
"progress": progress_data.get("progress", 0),
"bytes_downloaded": progress_data.get("bytes_downloaded"),
"total_bytes": progress_data.get("total_bytes"),
"bytes_per_second": progress_data.get("bytes_per_second", 0.0),
}
)
except Exception as exc:
self._logger.error("Error getting download progress: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)

View File

@@ -5,6 +5,8 @@ from __future__ import annotations
import logging
from typing import Any, Awaitable, Callable, Dict, Optional
from .downloader import DownloadProgress
logger = logging.getLogger(__name__)
@@ -29,14 +31,40 @@ class DownloadCoordinator:
download_id = payload.get("download_id") or self._ws_manager.generate_download_id()
payload.setdefault("download_id", download_id)
async def progress_callback(progress: Any) -> None:
async def progress_callback(progress: Any, snapshot: Optional[DownloadProgress] = None) -> None:
percent = 0.0
metrics: Optional[DownloadProgress] = None
if isinstance(progress, DownloadProgress):
metrics = progress
percent = progress.percent_complete
elif isinstance(snapshot, DownloadProgress):
metrics = snapshot
percent = snapshot.percent_complete
else:
try:
percent = float(progress)
except (TypeError, ValueError):
percent = 0.0
payload: Dict[str, Any] = {
"status": "progress",
"progress": round(percent),
"download_id": download_id,
}
if metrics is not None:
payload.update(
{
"bytes_downloaded": metrics.bytes_downloaded,
"total_bytes": metrics.total_bytes,
"bytes_per_second": metrics.bytes_per_second,
}
)
await self._ws_manager.broadcast_download_progress(
download_id,
{
"status": "progress",
"progress": progress,
"download_id": download_id,
},
payload,
)
model_id = self._parse_optional_int(payload.get("model_id"), "model_id")

View File

@@ -1,9 +1,10 @@
import logging
import os
import asyncio
import inspect
from collections import OrderedDict
import uuid
from typing import Dict, List
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
@@ -13,7 +14,7 @@ from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry
from .settings_manager import get_settings_manager
from .metadata_service import get_default_metadata_provider
from .downloader import get_downloader
from .downloader import get_downloader, DownloadProgress
# Download to temporary file first
import tempfile
@@ -82,7 +83,10 @@ class DownloadManager:
'model_id': model_id,
'model_version_id': model_version_id,
'progress': 0,
'status': 'queued'
'status': 'queued',
'bytes_downloaded': 0,
'total_bytes': None,
'bytes_per_second': 0.0,
}
# Create tracking task
@@ -119,11 +123,19 @@ class DownloadManager:
# Wrap progress callback to track progress in active_downloads
original_callback = progress_callback
async def tracking_callback(progress):
async def tracking_callback(progress, metrics=None):
progress_value, snapshot = self._normalize_progress(progress, metrics)
if task_id in self._active_downloads:
self._active_downloads[task_id]['progress'] = progress
info = self._active_downloads[task_id]
info['progress'] = round(progress_value)
if snapshot is not None:
info['bytes_downloaded'] = snapshot.bytes_downloaded
info['total_bytes'] = snapshot.total_bytes
info['bytes_per_second'] = snapshot.bytes_per_second
if original_callback:
await original_callback(progress)
await self._dispatch_progress(original_callback, snapshot, progress_value)
# Acquire semaphore to limit concurrent downloads
try:
@@ -149,12 +161,14 @@ class DownloadManager:
self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed'
if not result['success']:
self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error')
self._active_downloads[task_id]['bytes_per_second'] = 0.0
return result
except asyncio.CancelledError:
# Handle cancellation
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'cancelled'
self._active_downloads[task_id]['bytes_per_second'] = 0.0
logger.info(f"Download cancelled for task {task_id}")
raise
except Exception as e:
@@ -163,6 +177,7 @@ class DownloadManager:
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'failed'
self._active_downloads[task_id]['error'] = str(e)
self._active_downloads[task_id]['bytes_per_second'] = 0.0
return {'success': False, 'error': str(e)}
finally:
# Schedule cleanup of download record after delay
@@ -551,7 +566,11 @@ class DownloadManager:
success, result = await downloader.download_file(
download_url,
save_path, # Use full path instead of separate dir and filename
progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
progress_callback=lambda progress, snapshot=None: self._handle_download_progress(
progress,
progress_callback,
snapshot,
),
use_auth=use_auth # Only use authentication for Civitai downloads
)
@@ -631,17 +650,33 @@ class DownloadManager:
return {'success': False, 'error': str(e)}
async def _handle_download_progress(self, file_progress: float, progress_callback):
"""Convert file download progress to overall progress
Args:
file_progress: Progress of file download (0-100)
progress_callback: Callback function for progress updates
"""
if progress_callback:
# Scale file progress to 3-100 range (after preview download)
overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download
await progress_callback(round(overall_progress))
async def _handle_download_progress(
    self,
    progress_update,
    progress_callback,
    snapshot=None,
):
    """Convert file download progress to overall progress.

    File progress is scaled onto the 3-100 portion of the overall progress
    range, clamped to 0-100, and forwarded to ``progress_callback`` together
    with the transfer metrics when a snapshot is available.

    Args:
        progress_update: File download progress in any shape accepted by
            ``_normalize_progress`` (number, ``DownloadProgress`` or dict).
        progress_callback: Callback function for progress updates. When it
            is falsy the update is dropped.
        snapshot: Optional ``DownloadProgress`` accompanying
            ``progress_update``.
    """
    if not progress_callback:
        return

    # Local import keeps this method self-contained.
    from dataclasses import replace

    file_progress, original_snapshot = self._normalize_progress(progress_update, snapshot)

    # 97% of the overall range is allotted to the file transfer itself.
    overall_progress = max(0.0, min(3 + (file_progress * 0.97), 100.0))

    normalized_snapshot: Optional[DownloadProgress] = None
    if original_snapshot is not None:
        # replace() carries every other field over unchanged, so fields added
        # to DownloadProgress later are not silently dropped here.
        normalized_snapshot = replace(original_snapshot, percent_complete=overall_progress)

    await self._dispatch_progress(progress_callback, normalized_snapshot, round(overall_progress))
async def cancel_download(self, download_id: str) -> Dict:
"""Cancel an active download by download_id
@@ -663,6 +698,7 @@ class DownloadManager:
# Update status in active downloads
if download_id in self._active_downloads:
self._active_downloads[download_id]['status'] = 'cancelling'
self._active_downloads[download_id]['bytes_per_second'] = 0.0
# Wait briefly for the task to acknowledge cancellation
try:
@@ -725,7 +761,53 @@ class DownloadManager:
except Exception as e:
logger.error(f"Error cancelling download: {e}", exc_info=True)
return {'success': False, 'error': str(e)}
@staticmethod
def _coerce_progress_value(progress) -> float:
    """Best-effort conversion of a raw progress value to a finite float.

    Args:
        progress: Any value ``float()`` may accept (number, numeric string).

    Returns:
        The converted value, or ``0.0`` when the input cannot be converted
        or converts to NaN or infinity.
    """
    import math

    try:
        value = float(progress)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers integers too large to represent as a float.
        return 0.0
    # The result is later passed to round(), which raises on NaN and infinity.
    return value if math.isfinite(value) else 0.0
@classmethod
def _normalize_progress(
    cls,
    progress,
    snapshot: Optional[DownloadProgress] = None,
) -> Tuple[float, Optional[DownloadProgress]]:
    """Split a progress update into a percentage and an optional snapshot.

    Args:
        progress: A ``DownloadProgress``, a dict carrying a
            ``percent_complete`` or ``progress`` key, or a value that can be
            converted to ``float``.
        snapshot: Optional ``DownloadProgress`` accompanying ``progress``.
            It is used only when ``progress`` is not itself a
            ``DownloadProgress``.

    Returns:
        ``(percent, snapshot)``. The second element is a ``DownloadProgress``
        when one was supplied in either argument, otherwise ``None``.
    """
    if isinstance(progress, DownloadProgress):
        return progress.percent_complete, progress
    if isinstance(snapshot, DownloadProgress):
        return snapshot.percent_complete, snapshot

    # Past this point ``snapshot`` is known not to be a DownloadProgress, so
    # it is never handed back: callers read metric attributes off the result.
    if isinstance(progress, dict):
        for key in ('percent_complete', 'progress'):
            if key in progress:
                return cls._coerce_progress_value(progress[key]), None

    return cls._coerce_progress_value(progress), None
async def _dispatch_progress(
    self,
    callback,
    snapshot: Optional[DownloadProgress],
    progress_value: float,
) -> None:
    """Invoke a progress callback with the richest arguments it accepts.

    When a snapshot is available and the callback accepts two positional
    arguments, it is called as ``callback(snapshot, snapshot)``. Otherwise it
    is called with the numeric ``progress_value`` only. An awaitable result
    is awaited.

    Args:
        callback: Sync or async callable taking ``(progress)`` or
            ``(progress, snapshot)``.
        snapshot: Transfer metrics to forward, or ``None`` when only a
            percentage is known.
        progress_value: Numeric progress passed to single-argument callbacks
            and whenever no snapshot is available.

    Raises:
        Any exception raised by the callback; the callback is invoked
        exactly once.
    """
    args = (progress_value,)
    if snapshot is not None:
        # Decide from the signature instead of catching TypeError around the
        # call: that would also catch a TypeError raised inside a synchronous
        # callback and invoke the callback a second time.
        try:
            inspect.signature(callback).bind(snapshot, snapshot)
        except (TypeError, ValueError):
            # Single-argument or non-introspectable callback: keep the
            # numeric form.
            pass
        else:
            args = (snapshot, snapshot)

    result = callback(*args)
    # isawaitable() already covers coroutine objects.
    if inspect.isawaitable(result):
        await result
async def get_active_downloads(self) -> Dict:
"""Get information about all active downloads
@@ -740,7 +822,10 @@ class DownloadManager:
'model_version_id': info.get('model_version_id'),
'progress': info.get('progress', 0),
'status': info.get('status', 'unknown'),
'error': info.get('error', None)
'error': info.get('error', None),
'bytes_downloaded': info.get('bytes_downloaded', 0),
'total_bytes': info.get('total_bytes'),
'bytes_per_second': info.get('bytes_per_second', 0.0),
}
for task_id, info in self._active_downloads.items()
]

View File

@@ -14,13 +14,26 @@ import os
import logging
import asyncio
import aiohttp
from datetime import datetime
from typing import Optional, Dict, Tuple, Callable, Union
from collections import deque
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional, Dict, Tuple, Callable, Union, Awaitable
from ..services.settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class DownloadProgress:
    """Snapshot of a download transfer at a moment in time.

    Instances are immutable (``frozen=True``); a changed value requires
    building a new instance.
    """

    # Completion percentage; the downloader reports 0.0 while the total size
    # is unknown and 100.0 on completion.
    percent_complete: float
    # Number of bytes transferred so far.
    bytes_downloaded: int
    # Expected total size in bytes, or None when it is not known.
    total_bytes: Optional[int]
    # Transfer rate in bytes per second; the downloader reports 0.0 in its
    # final snapshots.
    bytes_per_second: float
    # Time the snapshot was taken, as produced by ``datetime.timestamp()``.
    timestamp: float
class Downloader:
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
@@ -159,7 +172,7 @@ class Downloader:
self,
url: str,
save_path: str,
progress_callback: Optional[Callable[[float], None]] = None,
progress_callback: Optional[Callable[..., Awaitable[None]]] = None,
use_auth: bool = False,
custom_headers: Optional[Dict[str, str]] = None,
allow_resume: bool = True
@@ -248,7 +261,16 @@ class Downloader:
if allow_resume:
os.rename(part_path, save_path)
if progress_callback:
await progress_callback(100)
await self._dispatch_progress_callback(
progress_callback,
DownloadProgress(
percent_complete=100.0,
bytes_downloaded=part_size,
total_bytes=actual_size,
bytes_per_second=0.0,
timestamp=datetime.now().timestamp(),
),
)
return True, save_path
# Remove corrupted part file and restart
os.remove(part_path)
@@ -276,6 +298,8 @@ class Downloader:
current_size = resume_offset
last_progress_report_time = datetime.now()
progress_samples: deque[tuple[datetime, int]] = deque()
progress_samples.append((last_progress_report_time, current_size))
# Ensure directory exists
os.makedirs(os.path.dirname(save_path), exist_ok=True)
@@ -289,14 +313,35 @@ class Downloader:
# Run blocking file write in executor
await loop.run_in_executor(None, f.write, chunk)
current_size += len(chunk)
# Limit progress update frequency to reduce overhead
now = datetime.now()
time_diff = (now - last_progress_report_time).total_seconds()
if progress_callback and total_size and time_diff >= 1.0:
progress = (current_size / total_size) * 100
await progress_callback(progress)
if progress_callback and time_diff >= 1.0:
progress_samples.append((now, current_size))
cutoff = now - timedelta(seconds=5)
while progress_samples and progress_samples[0][0] < cutoff:
progress_samples.popleft()
percent = (current_size / total_size) * 100 if total_size else 0.0
bytes_per_second = 0.0
if len(progress_samples) >= 2:
first_time, first_bytes = progress_samples[0]
last_time, last_bytes = progress_samples[-1]
elapsed = (last_time - first_time).total_seconds()
if elapsed > 0:
bytes_per_second = (last_bytes - first_bytes) / elapsed
progress_snapshot = DownloadProgress(
percent_complete=percent,
bytes_downloaded=current_size,
total_bytes=total_size or None,
bytes_per_second=bytes_per_second,
timestamp=now.timestamp(),
)
await self._dispatch_progress_callback(progress_callback, progress_snapshot)
last_progress_report_time = now
# Download completed successfully
@@ -331,7 +376,15 @@ class Downloader:
# Ensure 100% progress is reported
if progress_callback:
await progress_callback(100)
final_snapshot = DownloadProgress(
percent_complete=100.0,
bytes_downloaded=final_size,
total_bytes=total_size or final_size,
bytes_per_second=0.0,
timestamp=datetime.now().timestamp(),
)
await self._dispatch_progress_callback(progress_callback, final_snapshot)
return True, save_path
@@ -363,7 +416,24 @@ class Downloader:
return False, str(e)
return False, f"Download failed after {self.max_retries + 1} attempts"
async def _dispatch_progress_callback(
    self,
    progress_callback: Callable[..., Awaitable[None]],
    snapshot: DownloadProgress,
) -> None:
    """Invoke a progress callback while preserving backward compatibility.

    Callbacks that accept two positional arguments are called as
    ``progress_callback(snapshot, snapshot)``. Single-argument callbacks
    receive ``snapshot.percent_complete`` only. An awaitable result is
    awaited.

    Args:
        progress_callback: Sync or async callable taking ``(progress)`` or
            ``(progress, snapshot)``.
        snapshot: The transfer state to report.

    Raises:
        Any exception raised by the callback; the callback is invoked
        exactly once.
    """
    # Local import keeps this method self-contained.
    import inspect

    # Decide from the signature instead of catching TypeError around the
    # call: that would also catch a TypeError raised inside a synchronous
    # callback and invoke the callback a second time.
    try:
        inspect.signature(progress_callback).bind(snapshot, snapshot)
    except (TypeError, ValueError):
        # Single-argument or non-introspectable callback.
        args = (snapshot.percent_complete,)
    else:
        args = (snapshot, snapshot)

    result = progress_callback(*args)
    if inspect.isawaitable(result):
        await result
async def download_to_memory(
self,
url: str,

View File

@@ -3,7 +3,7 @@ import logging
import asyncio
from pathlib import Path
from typing import Optional
from .downloader import get_downloader
from .downloader import get_downloader, DownloadProgress
logger = logging.getLogger(__name__)
@@ -77,9 +77,15 @@ class MetadataArchiveManager:
progress_callback("download", f"Downloading from {url}")
# Custom progress callback to report download progress
async def download_progress(progress):
async def download_progress(progress, snapshot=None):
if progress_callback:
progress_callback("download", f"Downloading archive... {progress:.1f}%")
if isinstance(progress, DownloadProgress):
percent = progress.percent_complete
elif isinstance(snapshot, DownloadProgress):
percent = snapshot.percent_complete
else:
percent = float(progress or 0)
progress_callback("download", f"Downloading archive... {percent:.1f}%")
success, result = await downloader.download_file(
url=url,

View File

@@ -155,11 +155,16 @@ class WebSocketManager:
async def broadcast_download_progress(self, download_id: str, data: Dict):
"""Send progress update to specific download client"""
# Store simplified progress data in memory (progress percentage plus any transfer metrics supplied)
self._download_progress[download_id] = {
progress_entry = {
'progress': data.get('progress', 0),
'timestamp': datetime.now()
'timestamp': datetime.now(),
}
for field in ('bytes_downloaded', 'total_bytes', 'bytes_per_second'):
if field in data:
progress_entry[field] = data[field]
self._download_progress[download_id] = progress_entry
if download_id not in self._download_websockets:
logger.debug(f"No WebSocket found for download ID: {download_id}")

View File

@@ -5,6 +5,7 @@ import sys
from pathlib import Path
import types
from typing import Optional
folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: [])
sys.modules.setdefault("folder_paths", folder_paths_stub)
@@ -16,6 +17,7 @@ from aiohttp.test_utils import TestClient, TestServer
from py.config import config
from py.routes.base_model_routes import BaseModelRoutes
from py.services import model_file_service
from py.services.downloader import DownloadProgress
from py.services.metadata_sync_service import MetadataSyncService
from py.services.model_file_service import AutoOrganizeResult
from py.services.service_registry import ServiceRegistry
@@ -59,12 +61,21 @@ def download_manager_stub():
self.error = None
self.cancelled = []
self.active_downloads = {}
self.last_progress_snapshot: Optional[DownloadProgress] = None
async def download_from_civitai(self, **kwargs):
self.calls.append(kwargs)
if self.error is not None:
raise self.error
await kwargs["progress_callback"](42)
snapshot = DownloadProgress(
percent_complete=50.0,
bytes_downloaded=5120,
total_bytes=10240,
bytes_per_second=2048.0,
timestamp=0.0,
)
self.last_progress_snapshot = snapshot
await kwargs["progress_callback"](snapshot)
return {"success": True, "path": "/tmp/model.safetensors"}
async def cancel_download(self, download_id):
@@ -332,7 +343,11 @@ def test_download_model_invokes_download_manager(
assert call_args["download_id"] == payload["download_id"]
progress = ws_manager.get_download_progress(payload["download_id"])
assert progress is not None
assert progress["progress"] == 42
expected_progress = round(download_manager_stub.last_progress_snapshot.percent_complete)
assert progress["progress"] == expected_progress
assert progress["bytes_downloaded"] == download_manager_stub.last_progress_snapshot.bytes_downloaded
assert progress["total_bytes"] == download_manager_stub.last_progress_snapshot.total_bytes
assert progress["bytes_per_second"] == download_manager_stub.last_progress_snapshot.bytes_per_second
assert "timestamp" in progress
progress_response = await client.get(
@@ -341,7 +356,13 @@ def test_download_model_invokes_download_manager(
progress_payload = await progress_response.json()
assert progress_response.status == 200
assert progress_payload == {"success": True, "progress": 42}
assert progress_payload == {
"success": True,
"progress": expected_progress,
"bytes_downloaded": download_manager_stub.last_progress_snapshot.bytes_downloaded,
"total_bytes": download_manager_stub.last_progress_snapshot.total_bytes,
"bytes_per_second": download_manager_stub.last_progress_snapshot.bytes_per_second,
}
ws_manager.cleanup_download_progress(payload["download_id"])
finally:
await client.close()

View File

@@ -7,6 +7,7 @@ from typing import Any, Dict, List
import pytest
from py.services.download_coordinator import DownloadCoordinator
from py.services.downloader import DownloadProgress
from py.services.metadata_sync_service import MetadataSyncService
from py.services.preview_asset_service import PreviewAssetService
from py.services.tag_update_service import TagUpdateService
@@ -191,10 +192,17 @@ def test_download_coordinator_emits_progress() -> None:
class DownloadManagerStub:
def __init__(self) -> None:
self.calls: List[Dict[str, Any]] = []
self.snapshot = DownloadProgress(
percent_complete=25.0,
bytes_downloaded=256,
total_bytes=1024,
bytes_per_second=128.0,
timestamp=0.0,
)
async def download_from_civitai(self, **kwargs) -> Dict[str, Any]:
self.calls.append(kwargs)
await kwargs["progress_callback"](10)
await kwargs["progress_callback"](self.snapshot)
return {"success": True}
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
@@ -216,6 +224,12 @@ def test_download_coordinator_emits_progress() -> None:
assert result["success"] is True
assert manager_stub.calls
assert ws_stub.progress_events
expected_progress = round(manager_stub.snapshot.percent_complete)
first_event = ws_stub.progress_events[0]
assert first_event["progress"] == expected_progress
assert first_event["bytes_downloaded"] == manager_stub.snapshot.bytes_downloaded
assert first_event["total_bytes"] == manager_stub.snapshot.total_bytes
assert first_event["bytes_per_second"] == manager_stub.snapshot.bytes_per_second
cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"]))
assert cancel_result["success"] is True

View File

@@ -86,10 +86,29 @@ async def test_broadcast_download_progress_tracks_state(manager):
download_id = "abc"
manager._download_websockets[download_id] = ws
await manager.broadcast_download_progress(download_id, {"progress": 55})
await manager.broadcast_download_progress(
download_id,
{
"progress": 55,
"bytes_downloaded": 512,
"total_bytes": 1024,
"bytes_per_second": 128.0,
},
)
assert ws.messages == [{"progress": 55}]
assert manager.get_download_progress(download_id)["progress"] == 55
assert ws.messages == [
{
"progress": 55,
"bytes_downloaded": 512,
"total_bytes": 1024,
"bytes_per_second": 128.0,
}
]
stored = manager.get_download_progress(download_id)
assert stored["progress"] == 55
assert stored["bytes_downloaded"] == 512
assert stored["total_bytes"] == 1024
assert stored["bytes_per_second"] == 128.0
async def test_broadcast_download_progress_to_multiple_updates(manager):
@@ -107,7 +126,9 @@ async def test_broadcast_download_progress_to_multiple_updates(manager):
async def test_broadcast_download_progress_missing_socket(manager):
await manager.broadcast_download_progress("missing", {"progress": 30})
# Progress should be stored even without a live websocket
assert manager.get_download_progress("missing")["progress"] == 30
missing = manager.get_download_progress("missing")
assert missing["progress"] == 30
assert "bytes_downloaded" not in missing
async def test_auto_organize_progress_helpers(manager):