refactor(example-images): minimize async lock contention by moving I/O outside critical sections

- Extract progress file loading to async methods to run in executor
- Refactor start_download to reduce lock time by pre-loading data before entering lock
- Improve check_pending_models efficiency with single-pass model collection and async loading
- Add type hints to get_status method
- Add tests for download task callback execution and error handling
This commit is contained in:
Will Miao
2026-02-11 09:24:00 +08:00
parent 94edde7744
commit 6b1e3f06ed
2 changed files with 305 additions and 125 deletions

View File

@@ -203,3 +203,150 @@ async def test_pause_or_resume_without_running_download(
with pytest.raises(download_module.DownloadNotRunningError):
await manager.stop_download(object())
async def test_download_task_callback_executes_on_completion(
monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
"""Test that _handle_download_task_done callback is executed when download completes."""
settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
settings_manager.settings["libraries"] = {"default": {}}
settings_manager.settings["active_library"] = "default"
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
callback_executed = asyncio.Event()
original_callback = manager._handle_download_task_done
def tracking_callback(task, output_dir):
original_callback(task, output_dir)
callback_executed.set()
monkeypatch.setattr(
manager, "_handle_download_task_done", tracking_callback
)
async def fake_download(self, *_args):
# Simulate successful completion
async with self._state_lock:
self._progress["status"] = "completed"
self._is_downloading = False
self._download_task = None
monkeypatch.setattr(
download_module.DownloadManager,
"_download_all_example_images",
fake_download,
)
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
assert result["success"] is True
# Wait for callback to execute
await asyncio.wait_for(callback_executed.wait(), timeout=1)
assert manager._progress["status"] == "completed"
async def test_download_task_callback_handles_errors(
monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
"""Test that _handle_download_task_done properly handles task errors and saves progress."""
settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
settings_manager.settings["libraries"] = {"default": {}}
settings_manager.settings["active_library"] = "default"
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
callback_executed = asyncio.Event()
progress_saved = False
original_save_progress = manager._save_progress
def tracking_save_progress(output_dir):
nonlocal progress_saved
progress_saved = True
return original_save_progress(output_dir)
monkeypatch.setattr(manager, "_save_progress", tracking_save_progress)
original_callback = manager._handle_download_task_done
def tracking_callback(task, output_dir):
original_callback(task, output_dir)
callback_executed.set()
monkeypatch.setattr(
manager, "_handle_download_task_done", tracking_callback
)
async def fake_download_with_error(self, *_args):
raise RuntimeError("Simulated download error")
monkeypatch.setattr(
download_module.DownloadManager,
"_download_all_example_images",
fake_download_with_error,
)
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
assert result["success"] is True
# Wait for callback to execute (it should handle the error)
await asyncio.wait_for(callback_executed.wait(), timeout=1)
# Progress should be saved even on error
assert progress_saved is True
async def test_get_status_returns_correct_state(
monkeypatch: pytest.MonkeyPatch, tmp_path
) -> None:
"""Test that get_status returns the correct download state."""
settings_manager = get_settings_manager()
settings_manager.settings["example_images_path"] = str(tmp_path)
settings_manager.settings["libraries"] = {"default": {}}
settings_manager.settings["active_library"] = "default"
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
# Test idle state
status = await manager.get_status(object())
assert status["success"] is True
assert status["is_downloading"] is False
assert status["status"]["status"] == "idle"
started = asyncio.Event()
release = asyncio.Event()
async def fake_download(self, *_args):
started.set()
await release.wait()
async with self._state_lock:
self._is_downloading = False
self._download_task = None
monkeypatch.setattr(
download_module.DownloadManager,
"_download_all_example_images",
fake_download,
)
# Start download
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
assert result["success"] is True
await asyncio.wait_for(started.wait(), timeout=1)
# Test running state
status = await manager.get_status(object())
assert status["success"] is True
assert status["is_downloading"] is True
assert status["status"]["status"] == "running"
# Cleanup
release.set()
if manager._download_task:
await asyncio.wait_for(manager._download_task, timeout=1)