mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
test(downloads): cover pause and resume flows
This commit is contained in:
117
tests/services/test_download_coordinator.py
Normal file
117
tests/services/test_download_coordinator.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from py.services.download_coordinator import DownloadCoordinator
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubWebSocketManager:
|
||||
progress: Dict[str, Dict[str, Any]] = field(default_factory=dict)
|
||||
broadcasts: List[Tuple[str, Dict[str, Any]]] = field(default_factory=list)
|
||||
|
||||
def generate_download_id(self) -> str:
|
||||
return "generated"
|
||||
|
||||
def get_download_progress(self, download_id: str) -> Dict[str, Any] | None:
|
||||
return self.progress.get(download_id)
|
||||
|
||||
async def broadcast_download_progress(self, download_id: str, payload: Dict[str, Any]) -> None:
|
||||
self.broadcasts.append((download_id, payload))
|
||||
|
||||
|
||||
async def test_pause_download_broadcasts_cached_state():
|
||||
ws_manager = StubWebSocketManager(
|
||||
progress={
|
||||
"dl": {
|
||||
"progress": 45,
|
||||
"bytes_downloaded": 1024,
|
||||
"total_bytes": 2048,
|
||||
"bytes_per_second": 256.0,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
download_manager = AsyncMock()
|
||||
download_manager.pause_download = AsyncMock(return_value={"success": True})
|
||||
|
||||
async def factory():
|
||||
return download_manager
|
||||
|
||||
coordinator = DownloadCoordinator(ws_manager=ws_manager, download_manager_factory=factory)
|
||||
|
||||
result = await coordinator.pause_download("dl")
|
||||
|
||||
assert result == {"success": True}
|
||||
assert ws_manager.broadcasts == [
|
||||
(
|
||||
"dl",
|
||||
{
|
||||
"status": "paused",
|
||||
"progress": 45,
|
||||
"download_id": "dl",
|
||||
"message": "Download paused by user",
|
||||
"bytes_downloaded": 1024,
|
||||
"total_bytes": 2048,
|
||||
"bytes_per_second": 0.0,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
async def test_resume_download_broadcasts_cached_state():
|
||||
ws_manager = StubWebSocketManager(
|
||||
progress={
|
||||
"dl": {
|
||||
"progress": 75,
|
||||
"bytes_downloaded": 2048,
|
||||
"total_bytes": 4096,
|
||||
"bytes_per_second": 512.0,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
download_manager = AsyncMock()
|
||||
download_manager.resume_download = AsyncMock(return_value={"success": True})
|
||||
|
||||
async def factory():
|
||||
return download_manager
|
||||
|
||||
coordinator = DownloadCoordinator(ws_manager=ws_manager, download_manager_factory=factory)
|
||||
|
||||
result = await coordinator.resume_download("dl")
|
||||
|
||||
assert result == {"success": True}
|
||||
assert ws_manager.broadcasts == [
|
||||
(
|
||||
"dl",
|
||||
{
|
||||
"status": "downloading",
|
||||
"progress": 75,
|
||||
"download_id": "dl",
|
||||
"message": "Download resumed by user",
|
||||
"bytes_downloaded": 2048,
|
||||
"total_bytes": 4096,
|
||||
"bytes_per_second": 512.0,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
async def test_pause_download_does_not_broadcast_on_failure():
|
||||
ws_manager = StubWebSocketManager()
|
||||
|
||||
download_manager = AsyncMock()
|
||||
download_manager.pause_download = AsyncMock(return_value={"success": False, "error": "nope"})
|
||||
|
||||
async def factory():
|
||||
return download_manager
|
||||
|
||||
coordinator = DownloadCoordinator(ws_manager=ws_manager, download_manager_factory=factory)
|
||||
|
||||
result = await coordinator.pause_download("dl")
|
||||
|
||||
assert result == {"success": False, "error": "nope"}
|
||||
assert ws_manager.broadcasts == []
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
@@ -398,6 +399,67 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
||||
assert dummy_scanner.calls # ensure cache updated
|
||||
|
||||
|
||||
async def test_pause_download_updates_state():
|
||||
manager = DownloadManager()
|
||||
|
||||
download_id = "dl"
|
||||
manager._download_tasks[download_id] = object()
|
||||
pause_event = asyncio.Event()
|
||||
pause_event.set()
|
||||
manager._pause_events[download_id] = pause_event
|
||||
manager._active_downloads[download_id] = {
|
||||
"status": "downloading",
|
||||
"bytes_per_second": 42.0,
|
||||
}
|
||||
|
||||
result = await manager.pause_download(download_id)
|
||||
|
||||
assert result == {"success": True, "message": "Download paused successfully"}
|
||||
assert download_id in manager._pause_events
|
||||
assert manager._pause_events[download_id].is_set() is False
|
||||
assert manager._active_downloads[download_id]["status"] == "paused"
|
||||
assert manager._active_downloads[download_id]["bytes_per_second"] == 0.0
|
||||
|
||||
|
||||
async def test_pause_download_rejects_unknown_task():
|
||||
manager = DownloadManager()
|
||||
|
||||
result = await manager.pause_download("missing")
|
||||
|
||||
assert result == {"success": False, "error": "Download task not found"}
|
||||
|
||||
|
||||
async def test_resume_download_sets_event_and_status():
|
||||
manager = DownloadManager()
|
||||
|
||||
download_id = "dl"
|
||||
pause_event = asyncio.Event()
|
||||
manager._pause_events[download_id] = pause_event
|
||||
manager._active_downloads[download_id] = {
|
||||
"status": "paused",
|
||||
"bytes_per_second": 0.0,
|
||||
}
|
||||
|
||||
result = await manager.resume_download(download_id)
|
||||
|
||||
assert result == {"success": True, "message": "Download resumed successfully"}
|
||||
assert manager._pause_events[download_id].is_set() is True
|
||||
assert manager._active_downloads[download_id]["status"] == "downloading"
|
||||
|
||||
|
||||
async def test_resume_download_rejects_when_not_paused():
|
||||
manager = DownloadManager()
|
||||
|
||||
download_id = "dl"
|
||||
pause_event = asyncio.Event()
|
||||
pause_event.set()
|
||||
manager._pause_events[download_id] = pause_event
|
||||
|
||||
result = await manager.resume_download(download_id)
|
||||
|
||||
assert result == {"success": False, "error": "Download is not paused"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
|
||||
Reference in New Issue
Block a user