mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(downloads): expose throughput metrics in progress APIs
This commit is contained in:
@@ -5,6 +5,7 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
import types
|
||||
from typing import Optional
|
||||
|
||||
folder_paths_stub = types.SimpleNamespace(get_folder_paths=lambda *_: [])
|
||||
sys.modules.setdefault("folder_paths", folder_paths_stub)
|
||||
@@ -16,6 +17,7 @@ from aiohttp.test_utils import TestClient, TestServer
|
||||
from py.config import config
|
||||
from py.routes.base_model_routes import BaseModelRoutes
|
||||
from py.services import model_file_service
|
||||
from py.services.downloader import DownloadProgress
|
||||
from py.services.metadata_sync_service import MetadataSyncService
|
||||
from py.services.model_file_service import AutoOrganizeResult
|
||||
from py.services.service_registry import ServiceRegistry
|
||||
@@ -59,12 +61,21 @@ def download_manager_stub():
|
||||
self.error = None
|
||||
self.cancelled = []
|
||||
self.active_downloads = {}
|
||||
self.last_progress_snapshot: Optional[DownloadProgress] = None
|
||||
|
||||
async def download_from_civitai(self, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
if self.error is not None:
|
||||
raise self.error
|
||||
await kwargs["progress_callback"](42)
|
||||
snapshot = DownloadProgress(
|
||||
percent_complete=50.0,
|
||||
bytes_downloaded=5120,
|
||||
total_bytes=10240,
|
||||
bytes_per_second=2048.0,
|
||||
timestamp=0.0,
|
||||
)
|
||||
self.last_progress_snapshot = snapshot
|
||||
await kwargs["progress_callback"](snapshot)
|
||||
return {"success": True, "path": "/tmp/model.safetensors"}
|
||||
|
||||
async def cancel_download(self, download_id):
|
||||
@@ -332,7 +343,11 @@ def test_download_model_invokes_download_manager(
|
||||
assert call_args["download_id"] == payload["download_id"]
|
||||
progress = ws_manager.get_download_progress(payload["download_id"])
|
||||
assert progress is not None
|
||||
assert progress["progress"] == 42
|
||||
expected_progress = round(download_manager_stub.last_progress_snapshot.percent_complete)
|
||||
assert progress["progress"] == expected_progress
|
||||
assert progress["bytes_downloaded"] == download_manager_stub.last_progress_snapshot.bytes_downloaded
|
||||
assert progress["total_bytes"] == download_manager_stub.last_progress_snapshot.total_bytes
|
||||
assert progress["bytes_per_second"] == download_manager_stub.last_progress_snapshot.bytes_per_second
|
||||
assert "timestamp" in progress
|
||||
|
||||
progress_response = await client.get(
|
||||
@@ -341,7 +356,13 @@ def test_download_model_invokes_download_manager(
|
||||
progress_payload = await progress_response.json()
|
||||
|
||||
assert progress_response.status == 200
|
||||
assert progress_payload == {"success": True, "progress": 42}
|
||||
assert progress_payload == {
|
||||
"success": True,
|
||||
"progress": expected_progress,
|
||||
"bytes_downloaded": download_manager_stub.last_progress_snapshot.bytes_downloaded,
|
||||
"total_bytes": download_manager_stub.last_progress_snapshot.total_bytes,
|
||||
"bytes_per_second": download_manager_stub.last_progress_snapshot.bytes_per_second,
|
||||
}
|
||||
ws_manager.cleanup_download_progress(payload["download_id"])
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any, Dict, List
|
||||
import pytest
|
||||
|
||||
from py.services.download_coordinator import DownloadCoordinator
|
||||
from py.services.downloader import DownloadProgress
|
||||
from py.services.metadata_sync_service import MetadataSyncService
|
||||
from py.services.preview_asset_service import PreviewAssetService
|
||||
from py.services.tag_update_service import TagUpdateService
|
||||
@@ -191,10 +192,17 @@ def test_download_coordinator_emits_progress() -> None:
|
||||
class DownloadManagerStub:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Dict[str, Any]] = []
|
||||
self.snapshot = DownloadProgress(
|
||||
percent_complete=25.0,
|
||||
bytes_downloaded=256,
|
||||
total_bytes=1024,
|
||||
bytes_per_second=128.0,
|
||||
timestamp=0.0,
|
||||
)
|
||||
|
||||
async def download_from_civitai(self, **kwargs) -> Dict[str, Any]:
|
||||
self.calls.append(kwargs)
|
||||
await kwargs["progress_callback"](10)
|
||||
await kwargs["progress_callback"](self.snapshot)
|
||||
return {"success": True}
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||
@@ -216,6 +224,12 @@ def test_download_coordinator_emits_progress() -> None:
|
||||
assert result["success"] is True
|
||||
assert manager_stub.calls
|
||||
assert ws_stub.progress_events
|
||||
expected_progress = round(manager_stub.snapshot.percent_complete)
|
||||
first_event = ws_stub.progress_events[0]
|
||||
assert first_event["progress"] == expected_progress
|
||||
assert first_event["bytes_downloaded"] == manager_stub.snapshot.bytes_downloaded
|
||||
assert first_event["total_bytes"] == manager_stub.snapshot.total_bytes
|
||||
assert first_event["bytes_per_second"] == manager_stub.snapshot.bytes_per_second
|
||||
|
||||
cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"]))
|
||||
assert cancel_result["success"] is True
|
||||
|
||||
@@ -86,10 +86,29 @@ async def test_broadcast_download_progress_tracks_state(manager):
|
||||
download_id = "abc"
|
||||
manager._download_websockets[download_id] = ws
|
||||
|
||||
await manager.broadcast_download_progress(download_id, {"progress": 55})
|
||||
await manager.broadcast_download_progress(
|
||||
download_id,
|
||||
{
|
||||
"progress": 55,
|
||||
"bytes_downloaded": 512,
|
||||
"total_bytes": 1024,
|
||||
"bytes_per_second": 128.0,
|
||||
},
|
||||
)
|
||||
|
||||
assert ws.messages == [{"progress": 55}]
|
||||
assert manager.get_download_progress(download_id)["progress"] == 55
|
||||
assert ws.messages == [
|
||||
{
|
||||
"progress": 55,
|
||||
"bytes_downloaded": 512,
|
||||
"total_bytes": 1024,
|
||||
"bytes_per_second": 128.0,
|
||||
}
|
||||
]
|
||||
stored = manager.get_download_progress(download_id)
|
||||
assert stored["progress"] == 55
|
||||
assert stored["bytes_downloaded"] == 512
|
||||
assert stored["total_bytes"] == 1024
|
||||
assert stored["bytes_per_second"] == 128.0
|
||||
|
||||
|
||||
async def test_broadcast_download_progress_to_multiple_updates(manager):
|
||||
@@ -107,7 +126,9 @@ async def test_broadcast_download_progress_to_multiple_updates(manager):
|
||||
async def test_broadcast_download_progress_missing_socket(manager):
|
||||
await manager.broadcast_download_progress("missing", {"progress": 30})
|
||||
# Progress should be stored even without a live websocket
|
||||
assert manager.get_download_progress("missing")["progress"] == 30
|
||||
missing = manager.get_download_progress("missing")
|
||||
assert missing["progress"] == 30
|
||||
assert "bytes_downloaded" not in missing
|
||||
|
||||
|
||||
async def test_auto_organize_progress_helpers(manager):
|
||||
|
||||
Reference in New Issue
Block a user