Files
ComfyUI-Lora-Manager/tests/utils/test_example_images_download_manager_unit.py
Will Miao 6b1e3f06ed refactor(example-images): minimize async lock contention by moving I/O outside critical sections
- Extract progress file loading to async methods to run in executor
- Refactor start_download to reduce lock time by pre-loading data before entering lock
- Improve check_pending_models efficiency with single-pass model collection and async loading
- Add type hints to get_status method
- Add tests for download task callback execution and error handling
2026-02-11 09:24:00 +08:00

353 lines
12 KiB
Python

from __future__ import annotations
import asyncio
import time
from typing import Any, Dict
import pytest
from py.services.settings_manager import get_settings_manager
from py.utils import example_images_download_manager as download_module
class RecordingWebSocketManager:
    """Minimal WebSocket-manager stand-in that records every broadcast payload."""

    def __init__(self) -> None:
        # Payloads in the order they were broadcast.
        self.payloads: list[Dict[str, Any]] = []

    async def broadcast(self, payload: Dict[str, Any]) -> None:
        """Record a snapshot of *payload*.

        A shallow copy is stored rather than the caller's object, so a caller
        that later mutates or reuses the same dict cannot retroactively change
        what was recorded for an earlier broadcast.
        """
        self.payloads.append(dict(payload))
@pytest.fixture(autouse=True)
def restore_settings() -> None:
    """Snapshot the shared settings before each test and put them back afterwards.

    The snapshot comes from ``settings.copy()``, and the live settings object
    is emptied and refilled in place during teardown.
    """
    settings_manager = get_settings_manager()
    snapshot = settings_manager.settings.copy()
    try:
        yield
    finally:
        # Restore in place so the settings object itself is reused.
        live_settings = settings_manager.settings
        live_settings.clear()
        live_settings.update(snapshot)
async def test_start_download_requires_configured_path(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Test start_download when no example images path is configured.

    A plain request must raise DownloadConfigurationError, while a request
    with ``auto_mode`` set must return a success response that mentions the
    auto download being skipped.
    """
    recorder = RecordingWebSocketManager()
    downloader = download_module.DownloadManager(ws_manager=recorder)

    # Remove the setting entirely so no example images path is configured.
    get_settings_manager().settings.pop("example_images_path", None)

    with pytest.raises(download_module.DownloadConfigurationError) as raised:
        await downloader.start_download({})
    assert "not configured" in str(raised.value)

    auto_response = await downloader.start_download({"auto_mode": True})
    assert auto_response["success"] is True
    assert "skipping auto download" in auto_response["message"]
async def test_start_download_bootstraps_progress_and_task(
    monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
    """Test that start_download marks the manager as downloading and spawns a task.

    The download coroutine is replaced by a stub that parks on an event, so
    the in-flight state can be inspected before the task is allowed to finish.
    """
    settings = get_settings_manager().settings
    for key, value in (
        ("example_images_path", str(tmp_path)),
        ("libraries", {"default": {}}),
        ("active_library", "default"),
    ):
        settings[key] = value

    recorder = RecordingWebSocketManager()
    downloader = download_module.DownloadManager(ws_manager=recorder)
    entered, may_finish = asyncio.Event(), asyncio.Event()

    async def fake_download(
        self, output_dir, optimize, model_types, delay, library_name, force=False
    ):
        # Signal that the task body is running, then wait for the test's go-ahead.
        entered.set()
        await may_finish.wait()
        async with self._state_lock:
            self._progress["status"] = "completed"
            self._download_task = None
            self._is_downloading = False

    monkeypatch.setattr(
        download_module.DownloadManager,
        "_download_all_example_images",
        fake_download,
    )

    response = await downloader.start_download({"model_types": ["lora"], "delay": 0})
    assert response["success"] is True
    assert downloader._is_downloading is True

    await asyncio.wait_for(entered.wait(), timeout=1)
    first_payload = recorder.payloads[0]
    assert first_payload["status"] == "running"

    running_task = downloader._download_task
    assert running_task is not None

    # Let the stub finish and confirm the final state it wrote.
    may_finish.set()
    await asyncio.wait_for(running_task, timeout=1)
    assert downloader._is_downloading is False
    assert downloader._progress["status"] == "completed"
async def test_pause_and_resume_flow(monkeypatch: pytest.MonkeyPatch, tmp_path) -> None:
    """Test that pausing and resuming a running download updates the progress status."""
    settings = get_settings_manager().settings
    settings["example_images_path"] = str(tmp_path)
    settings["libraries"] = {"default": {}}
    settings["active_library"] = "default"

    downloader = download_module.DownloadManager(
        ws_manager=RecordingWebSocketManager()
    )
    entered = asyncio.Event()
    may_finish = asyncio.Event()

    async def fake_download(self, *_args):
        # Keep the download "in flight" until the test releases it.
        entered.set()
        await may_finish.wait()
        async with self._state_lock:
            self._download_task = None
            self._is_downloading = False

    monkeypatch.setattr(
        download_module.DownloadManager,
        "_download_all_example_images",
        fake_download,
    )

    await downloader.start_download({})
    await asyncio.wait_for(entered.wait(), timeout=1)

    paused = await downloader.pause_download(object())
    assert paused == {"success": True, "message": "Download paused"}
    assert downloader._progress["status"] == "paused"

    resumed = await downloader.resume_download(object())
    assert resumed == {"success": True, "message": "Download resumed"}
    assert downloader._progress["status"] == "running"

    running_task = downloader._download_task
    assert running_task is not None

    may_finish.set()
    await asyncio.wait_for(running_task, timeout=1)
async def test_stop_download_transitions_to_stopped(
    monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
    """Test that stop_download moves the status from "stopping" to "stopped".

    The stubbed download coroutine performs the final status transition itself
    and broadcasts it, so both intermediate and final statuses must show up in
    the recorded WebSocket payloads.
    """
    settings = get_settings_manager().settings
    for key, value in (
        ("example_images_path", str(tmp_path)),
        ("libraries", {"default": {}}),
        ("active_library", "default"),
    ):
        settings[key] = value

    recorder = RecordingWebSocketManager()
    downloader = download_module.DownloadManager(ws_manager=recorder)
    entered, may_finish = asyncio.Event(), asyncio.Event()

    async def fake_download(self, *_args):
        entered.set()
        await may_finish.wait()
        async with self._state_lock:
            # Only a requested stop that reached "stopping" ends as "stopped".
            stop_honoured = (
                self._stop_requested and self._progress["status"] == "stopping"
            )
            self._progress["status"] = "stopped" if stop_honoured else "completed"
            self._progress["end_time"] = time.time()
            self._stop_requested = False
        # NOTE(review): the broadcast is placed outside the lock here; the
        # original nesting was not recoverable from the source - confirm.
        await self._broadcast_progress(status=self._progress["status"])
        async with self._state_lock:
            self._download_task = None
            self._is_downloading = False

    monkeypatch.setattr(
        download_module.DownloadManager,
        "_download_all_example_images",
        fake_download,
    )

    await downloader.start_download({})
    await asyncio.wait_for(entered.wait(), timeout=1)

    stop_reply = await downloader.stop_download(object())
    assert stop_reply == {"success": True, "message": "Download stopping"}
    assert downloader._progress["status"] == "stopping"

    running_task = downloader._download_task
    assert running_task is not None

    may_finish.set()
    await asyncio.wait_for(running_task, timeout=1)

    assert downloader._progress["status"] == "stopped"
    assert downloader._is_downloading is False
    assert downloader._stop_requested is False

    seen_statuses = []
    for recorded in recorder.payloads:
        seen_statuses.append(recorded["status"])
    assert "stopping" in seen_statuses
    assert "stopped" in seen_statuses
async def test_pause_or_resume_without_running_download(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Test that pause, resume and stop all raise DownloadNotRunningError when idle."""
    downloader = download_module.DownloadManager(
        ws_manager=RecordingWebSocketManager()
    )

    # Same expectation for each control action, checked in the original order.
    for action in (
        downloader.pause_download,
        downloader.resume_download,
        downloader.stop_download,
    ):
        with pytest.raises(download_module.DownloadNotRunningError):
            await action(object())
async def test_download_task_callback_executes_on_completion(
    monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
    """Test that _handle_download_task_done runs once the download coroutine finishes."""
    settings = get_settings_manager().settings
    settings["example_images_path"] = str(tmp_path)
    settings["libraries"] = {"default": {}}
    settings["active_library"] = "default"

    recorder = RecordingWebSocketManager()
    downloader = download_module.DownloadManager(ws_manager=recorder)

    callback_ran = asyncio.Event()
    real_callback = downloader._handle_download_task_done

    def tracking_callback(task, output_dir):
        # Delegate to the real callback first, then flag that it returned.
        real_callback(task, output_dir)
        callback_ran.set()

    async def fake_download(self, *_args):
        # Finish immediately, leaving the state a successful run would leave.
        async with self._state_lock:
            self._progress["status"] = "completed"
            self._is_downloading = False
            self._download_task = None

    monkeypatch.setattr(
        downloader, "_handle_download_task_done", tracking_callback
    )
    monkeypatch.setattr(
        download_module.DownloadManager,
        "_download_all_example_images",
        fake_download,
    )

    response = await downloader.start_download({"model_types": ["lora"], "delay": 0})
    assert response["success"] is True

    # The event is only set after the wrapped callback has returned.
    await asyncio.wait_for(callback_ran.wait(), timeout=1)
    assert downloader._progress["status"] == "completed"
async def test_download_task_callback_handles_errors(
    monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
    """Test that the task-done callback completes and saves progress when the download raises."""
    settings = get_settings_manager().settings
    settings["example_images_path"] = str(tmp_path)
    settings["libraries"] = {"default": {}}
    settings["active_library"] = "default"

    recorder = RecordingWebSocketManager()
    downloader = download_module.DownloadManager(ws_manager=recorder)

    callback_ran = asyncio.Event()
    save_was_called = False

    real_save_progress = downloader._save_progress
    real_callback = downloader._handle_download_task_done

    def tracking_save_progress(output_dir):
        # Note the call, then defer to the real implementation.
        nonlocal save_was_called
        save_was_called = True
        return real_save_progress(output_dir)

    def tracking_callback(task, output_dir):
        real_callback(task, output_dir)
        callback_ran.set()

    async def fake_download_with_error(self, *_args):
        raise RuntimeError("Simulated download error")

    monkeypatch.setattr(downloader, "_save_progress", tracking_save_progress)
    monkeypatch.setattr(
        downloader, "_handle_download_task_done", tracking_callback
    )
    monkeypatch.setattr(
        download_module.DownloadManager,
        "_download_all_example_images",
        fake_download_with_error,
    )

    response = await downloader.start_download({"model_types": ["lora"], "delay": 0})
    assert response["success"] is True

    # The wrapped callback must return normally despite the failed task.
    await asyncio.wait_for(callback_ran.wait(), timeout=1)

    # Progress must have been saved even though the download raised.
    assert save_was_called is True
async def test_get_status_returns_correct_state(
    monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
    """Test that get_status reports "idle" before a download and "running" during one."""
    settings = get_settings_manager().settings
    for key, value in (
        ("example_images_path", str(tmp_path)),
        ("libraries", {"default": {}}),
        ("active_library", "default"),
    ):
        settings[key] = value

    recorder = RecordingWebSocketManager()
    downloader = download_module.DownloadManager(ws_manager=recorder)

    # Nothing has been started yet, so the manager must report the idle state.
    idle_report = await downloader.get_status(object())
    assert idle_report["success"] is True
    assert idle_report["is_downloading"] is False
    assert idle_report["status"]["status"] == "idle"

    entered, may_finish = asyncio.Event(), asyncio.Event()

    async def fake_download(self, *_args):
        # Stay in flight until the test has inspected the running state.
        entered.set()
        await may_finish.wait()
        async with self._state_lock:
            self._download_task = None
            self._is_downloading = False

    monkeypatch.setattr(
        download_module.DownloadManager,
        "_download_all_example_images",
        fake_download,
    )

    response = await downloader.start_download({"model_types": ["lora"], "delay": 0})
    assert response["success"] is True
    await asyncio.wait_for(entered.wait(), timeout=1)

    # With the stub parked, the manager must report an active download.
    running_report = await downloader.get_status(object())
    assert running_report["success"] is True
    assert running_report["is_downloading"] is True
    assert running_report["status"]["status"] == "running"

    # Release the stub and wait for the task so nothing is left pending.
    may_finish.set()
    leftover_task = downloader._download_task
    if leftover_task:
        await asyncio.wait_for(leftover_task, timeout=1)