mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
refactor(example-images): minimize async lock contention by moving I/O outside critical sections
- Extract progress file loading to async methods to run in executor - Refactor start_download to reduce lock time by pre-loading data before entering lock - Improve check_pending_models efficiency with single-pass model collection and async loading - Add type hints to get_status method - Add tests for download task callback execution and error handling
This commit is contained in:
@@ -203,3 +203,150 @@ async def test_pause_or_resume_without_running_download(
|
||||
|
||||
with pytest.raises(download_module.DownloadNotRunningError):
|
||||
await manager.stop_download(object())
|
||||
|
||||
|
||||
async def test_download_task_callback_executes_on_completion(
|
||||
monkeypatch: pytest.MonkeyPatch, tmp_path
|
||||
) -> None:
|
||||
"""Test that _handle_download_task_done callback is executed when download completes."""
|
||||
settings_manager = get_settings_manager()
|
||||
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||
settings_manager.settings["libraries"] = {"default": {}}
|
||||
settings_manager.settings["active_library"] = "default"
|
||||
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
callback_executed = asyncio.Event()
|
||||
original_callback = manager._handle_download_task_done
|
||||
|
||||
def tracking_callback(task, output_dir):
|
||||
original_callback(task, output_dir)
|
||||
callback_executed.set()
|
||||
|
||||
monkeypatch.setattr(
|
||||
manager, "_handle_download_task_done", tracking_callback
|
||||
)
|
||||
|
||||
async def fake_download(self, *_args):
|
||||
# Simulate successful completion
|
||||
async with self._state_lock:
|
||||
self._progress["status"] = "completed"
|
||||
self._is_downloading = False
|
||||
self._download_task = None
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_module.DownloadManager,
|
||||
"_download_all_example_images",
|
||||
fake_download,
|
||||
)
|
||||
|
||||
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
|
||||
assert result["success"] is True
|
||||
|
||||
# Wait for callback to execute
|
||||
await asyncio.wait_for(callback_executed.wait(), timeout=1)
|
||||
assert manager._progress["status"] == "completed"
|
||||
|
||||
|
||||
async def test_download_task_callback_handles_errors(
|
||||
monkeypatch: pytest.MonkeyPatch, tmp_path
|
||||
) -> None:
|
||||
"""Test that _handle_download_task_done properly handles task errors and saves progress."""
|
||||
settings_manager = get_settings_manager()
|
||||
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||
settings_manager.settings["libraries"] = {"default": {}}
|
||||
settings_manager.settings["active_library"] = "default"
|
||||
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
callback_executed = asyncio.Event()
|
||||
progress_saved = False
|
||||
original_save_progress = manager._save_progress
|
||||
|
||||
def tracking_save_progress(output_dir):
|
||||
nonlocal progress_saved
|
||||
progress_saved = True
|
||||
return original_save_progress(output_dir)
|
||||
|
||||
monkeypatch.setattr(manager, "_save_progress", tracking_save_progress)
|
||||
|
||||
original_callback = manager._handle_download_task_done
|
||||
|
||||
def tracking_callback(task, output_dir):
|
||||
original_callback(task, output_dir)
|
||||
callback_executed.set()
|
||||
|
||||
monkeypatch.setattr(
|
||||
manager, "_handle_download_task_done", tracking_callback
|
||||
)
|
||||
|
||||
async def fake_download_with_error(self, *_args):
|
||||
raise RuntimeError("Simulated download error")
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_module.DownloadManager,
|
||||
"_download_all_example_images",
|
||||
fake_download_with_error,
|
||||
)
|
||||
|
||||
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
|
||||
assert result["success"] is True
|
||||
|
||||
# Wait for callback to execute (it should handle the error)
|
||||
await asyncio.wait_for(callback_executed.wait(), timeout=1)
|
||||
# Progress should be saved even on error
|
||||
assert progress_saved is True
|
||||
|
||||
|
||||
async def test_get_status_returns_correct_state(
|
||||
monkeypatch: pytest.MonkeyPatch, tmp_path
|
||||
) -> None:
|
||||
"""Test that get_status returns the correct download state."""
|
||||
settings_manager = get_settings_manager()
|
||||
settings_manager.settings["example_images_path"] = str(tmp_path)
|
||||
settings_manager.settings["libraries"] = {"default": {}}
|
||||
settings_manager.settings["active_library"] = "default"
|
||||
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
# Test idle state
|
||||
status = await manager.get_status(object())
|
||||
assert status["success"] is True
|
||||
assert status["is_downloading"] is False
|
||||
assert status["status"]["status"] == "idle"
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def fake_download(self, *_args):
|
||||
started.set()
|
||||
await release.wait()
|
||||
async with self._state_lock:
|
||||
self._is_downloading = False
|
||||
self._download_task = None
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_module.DownloadManager,
|
||||
"_download_all_example_images",
|
||||
fake_download,
|
||||
)
|
||||
|
||||
# Start download
|
||||
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
|
||||
assert result["success"] is True
|
||||
|
||||
await asyncio.wait_for(started.wait(), timeout=1)
|
||||
|
||||
# Test running state
|
||||
status = await manager.get_status(object())
|
||||
assert status["success"] is True
|
||||
assert status["is_downloading"] is True
|
||||
assert status["status"]["status"] == "running"
|
||||
|
||||
# Cleanup
|
||||
release.set()
|
||||
if manager._download_task:
|
||||
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||
|
||||
Reference in New Issue
Block a user