From e128c80eb12b976a6bb272828cccc84338163e68 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Tue, 23 Sep 2025 14:58:35 +0800 Subject: [PATCH] test(services): add async example image download tests --- ...t_example_images_download_manager_async.py | 228 ++++++++++++++++++ 1 file changed, 228 insertions(+) create mode 100644 tests/services/test_example_images_download_manager_async.py diff --git a/tests/services/test_example_images_download_manager_async.py b/tests/services/test_example_images_download_manager_async.py new file mode 100644 index 00000000..7eef56fb --- /dev/null +++ b/tests/services/test_example_images_download_manager_async.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import asyncio +from types import SimpleNamespace + +import pytest + +from py.services.settings_manager import settings +from py.utils import example_images_download_manager as download_module + + +class RecordingWebSocketManager: + """Collects broadcast payloads for assertions.""" + + def __init__(self) -> None: + self.payloads: list[dict] = [] + + async def broadcast(self, payload: dict) -> None: + self.payloads.append(payload) + + +class StubScanner: + """Scanner double returning predetermined cache contents.""" + + def __init__(self, models: list[dict]) -> None: + self._cache = SimpleNamespace(raw_data=models) + + async def get_cached_data(self): + return self._cache + + +def _patch_scanner(monkeypatch: pytest.MonkeyPatch, scanner: StubScanner) -> None: + async def _get_lora_scanner(cls): + return scanner + + monkeypatch.setattr( + download_module.ServiceRegistry, + "get_lora_scanner", + classmethod(_get_lora_scanner), + ) + + +@pytest.mark.usefixtures("tmp_path") +async def test_start_download_rejects_parallel_runs(monkeypatch: pytest.MonkeyPatch, tmp_path): + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path)) + + model = { + "sha256": "abc123", + "model_name": "Example", + "file_path": str(tmp_path / "example.safetensors"), + "file_name": "example.safetensors", + } + _patch_scanner(monkeypatch, StubScanner([model])) + + started = asyncio.Event() + release = asyncio.Event() + + async def fake_process_local_examples(*_args, **_kwargs): + started.set() + await release.wait() + return True + + async def fake_update_metadata(*_args, **_kwargs): + return True + + async def fake_get_downloader(): + return object() + + monkeypatch.setattr( + download_module.ExampleImagesProcessor, + "process_local_examples", + staticmethod(fake_process_local_examples), + ) + monkeypatch.setattr( + download_module.MetadataUpdater, + "update_metadata_from_local_examples", + staticmethod(fake_update_metadata), + ) + monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader) + + try: + result = await manager.start_download({"model_types": ["lora"], "delay": 0}) + assert result["success"] is True + + await asyncio.wait_for(started.wait(), timeout=1) + + with pytest.raises(download_module.DownloadInProgressError) as exc: + await manager.start_download({"model_types": ["lora"], "delay": 0}) + + snapshot = exc.value.progress_snapshot + assert snapshot["status"] == "running" + assert snapshot["current_model"] == "Example (abc123)" + + statuses = [payload["status"] for payload in ws_manager.payloads] + assert "running" in statuses + + finally: + release.set() + if manager._download_task is not None: + await asyncio.wait_for(manager._download_task, timeout=1) + + +@pytest.mark.usefixtures("tmp_path") +async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, tmp_path): + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path)) + + models = [ + { + "sha256": "hash-one", + "model_name": "Model One", + "file_path": str(tmp_path / "model-one.safetensors"), + "file_name": "model-one.safetensors", + "civitai": {"images": [{"url": "https://example.com/one.png"}]}, + }, + { + "sha256": "hash-two", + "model_name": "Model Two", + "file_path": str(tmp_path / "model-two.safetensors"), + "file_name": "model-two.safetensors", + "civitai": {"images": [{"url": "https://example.com/two.png"}]}, + }, + ] + _patch_scanner(monkeypatch, StubScanner(models)) + + async def fake_process_local_examples(*_args, **_kwargs): + return False + + async def fake_update_metadata(*_args, **_kwargs): + return True + + first_call_started = asyncio.Event() + first_release = asyncio.Event() + second_call_started = asyncio.Event() + call_order: list[str] = [] + + async def fake_download_model_images(model_hash, *_args, **_kwargs): + call_order.append(model_hash) + if len(call_order) == 1: + first_call_started.set() + await first_release.wait() + else: + second_call_started.set() + return True, False + + async def fake_get_downloader(): + class _Downloader: + async def download_to_memory(self, *_a, **_kw): + return True, b"", {} + + return _Downloader() + + monkeypatch.setattr( + download_module.ExampleImagesProcessor, + "process_local_examples", + staticmethod(fake_process_local_examples), + ) + monkeypatch.setattr( + download_module.MetadataUpdater, + "update_metadata_from_local_examples", + staticmethod(fake_update_metadata), + ) + monkeypatch.setattr( + download_module.ExampleImagesProcessor, + "download_model_images", + staticmethod(fake_download_model_images), + ) + monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader) + + original_sleep = download_module.asyncio.sleep + pause_gate = asyncio.Event() + resume_gate = asyncio.Event() + + async def fake_sleep(delay: float): + if delay == 1: + pause_gate.set() + await resume_gate.wait() + else: + await original_sleep(delay) + + monkeypatch.setattr(download_module.asyncio, "sleep", fake_sleep) + + try: + await manager.start_download({"model_types": ["lora"], "delay": 0}) + + await asyncio.wait_for(first_call_started.wait(), timeout=1) + + await manager.pause_download({}) + + first_release.set() + + await asyncio.wait_for(pause_gate.wait(), timeout=1) + assert manager._progress["status"] == "paused" + assert not second_call_started.is_set() + + statuses = [payload["status"] for payload in ws_manager.payloads] + paused_index = statuses.index("paused") + + await asyncio.sleep(0) + assert not second_call_started.is_set() + + await manager.resume_download({}) + resume_gate.set() + + await asyncio.wait_for(second_call_started.wait(), timeout=1) + + if manager._download_task is not None: + await asyncio.wait_for(manager._download_task, timeout=1) + + statuses_after = [payload["status"] for payload in ws_manager.payloads] + running_after = next( + i for i, status in enumerate(statuses_after[paused_index + 1 :], start=paused_index + 1) if status == "running" + ) + assert running_after > paused_index + assert "completed" in statuses_after[running_after:] + assert call_order == ["hash-one", "hash-two"] + + finally: + first_release.set() + resume_gate.set() + if manager._download_task is not None: + await asyncio.wait_for(manager._download_task, timeout=1) + monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep)