Files
ComfyUI-Lora-Manager/tests/services/test_example_images_download_manager_async.py

607 lines
20 KiB
Python

from __future__ import annotations
import asyncio
import json
from pathlib import Path
from types import SimpleNamespace
import pytest
from py.services.settings_manager import SettingsManager, get_settings_manager
from py.utils import example_images_download_manager as download_module
class RecordingWebSocketManager:
    """In-memory stand-in for the websocket manager.

    Every payload handed to :meth:`broadcast` is kept, in call order, on
    :attr:`payloads` so tests can inspect what the download manager emitted.
    """

    def __init__(self) -> None:
        # Ordered history of every broadcast payload.
        self.payloads: list[dict] = []

    async def broadcast(self, payload: dict) -> None:
        """Record *payload* instead of sending it anywhere."""
        self.payloads += [payload]
class StubScanner:
    """Scanner double whose cache holds a fixed list of model dicts.

    The list handed to the constructor is stored (not copied) as ``raw_data``
    on a ``SimpleNamespace``, so in-place updates are visible to whoever built
    the list.
    """

    def __init__(self, models: list[dict]) -> None:
        self._cache = SimpleNamespace(raw_data=models)

    async def get_cached_data(self):
        """Return the namespace wrapping the model list."""
        return self._cache

    async def update_single_model_cache(self, _old_path, _new_path, metadata):
        """Swap in *metadata* for the first entry sharing its ``file_path``.

        The path arguments are ignored. If no entry matches, the cache is left
        untouched. The call always reports success with ``True``.
        """
        entries = self._cache.raw_data
        position = next(
            (
                idx
                for idx, entry in enumerate(entries)
                if entry.get("file_path") == metadata.get("file_path")
            ),
            None,
        )
        if position is not None:
            entries[position] = metadata
        return True
def _patch_scanner(monkeypatch: pytest.MonkeyPatch, scanner: StubScanner) -> None:
    """Make ``ServiceRegistry.get_lora_scanner`` resolve to *scanner*.

    The replacement is installed through *monkeypatch*, so it is reverted when
    the test finishes.
    """

    @classmethod
    async def _fake_get_lora_scanner(cls):
        return scanner

    registry = download_module.ServiceRegistry
    monkeypatch.setattr(registry, "get_lora_scanner", _fake_get_lora_scanner)
async def test_start_download_rejects_parallel_runs(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
    settings_manager,
):
    """A second ``start_download`` while a run is active must raise.

    The first run is parked inside ``process_local_examples`` until ``release``
    is set, so the second call is guaranteed to overlap it. The raised
    ``DownloadInProgressError`` must carry a snapshot of the live progress.
    """
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)
    monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))

    model = {
        "sha256": "abc123",
        "model_name": "Example",
        "file_path": str(tmp_path / "example.safetensors"),
        "file_name": "example.safetensors",
    }
    _patch_scanner(monkeypatch, StubScanner([model]))

    started = asyncio.Event()  # set once the first run reaches local processing
    release = asyncio.Event()  # lets the first run finish

    async def fake_process_local_examples(*_args, **_kwargs):
        started.set()
        await release.wait()
        return True

    async def fake_update_metadata(*_args, **_kwargs):
        return True

    async def fake_get_downloader():
        return object()

    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "process_local_examples",
        staticmethod(fake_process_local_examples),
    )
    monkeypatch.setattr(
        download_module.MetadataUpdater,
        "update_metadata_from_local_examples",
        staticmethod(fake_update_metadata),
    )
    monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)

    try:
        result = await manager.start_download({"model_types": ["lora"], "delay": 0})
        assert result["success"] is True

        # Wait until the first run is provably in flight before retrying.
        await asyncio.wait_for(started.wait(), timeout=1)

        with pytest.raises(download_module.DownloadInProgressError) as exc:
            await manager.start_download({"model_types": ["lora"], "delay": 0})

        snapshot = exc.value.progress_snapshot
        assert snapshot["status"] == "running"
        assert snapshot["current_model"] == "Example (abc123)"

        statuses = [payload["status"] for payload in ws_manager.payloads]
        assert "running" in statuses
    finally:
        # Always unblock and drain the background task so it cannot leak into
        # later tests.
        release.set()
        if manager._download_task is not None:
            await asyncio.wait_for(manager._download_task, timeout=1)
async def test_pause_resume_blocks_processing(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
    settings_manager,
):
    """Pausing holds the queue between models; resuming lets it finish.

    The first model's remote download is held open while ``pause_download`` is
    issued. Once it is released, the second model must not start until
    ``resume_download`` is called. A "running" broadcast must follow the
    "paused" one, and the run must end with "completed".
    """
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)
    monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))

    models = [
        {
            "sha256": "hash-one",
            "model_name": "Model One",
            "file_path": str(tmp_path / "model-one.safetensors"),
            "file_name": "model-one.safetensors",
            "civitai": {"images": [{"url": "https://example.com/one.png"}]},
        },
        {
            "sha256": "hash-two",
            "model_name": "Model Two",
            "file_path": str(tmp_path / "model-two.safetensors"),
            "file_name": "model-two.safetensors",
            "civitai": {"images": [{"url": "https://example.com/two.png"}]},
        },
    ]
    _patch_scanner(monkeypatch, StubScanner(models))

    async def fake_process_local_examples(*_args, **_kwargs):
        # No local examples, so each model takes the remote download path.
        return False

    async def fake_update_metadata(*_args, **_kwargs):
        return True

    first_call_started = asyncio.Event()
    first_release = asyncio.Event()
    second_call_started = asyncio.Event()
    call_order: list[str] = []

    async def fake_download_model_images(model_hash, *_args, **_kwargs):
        call_order.append(model_hash)
        if len(call_order) == 1:
            # Hold the first model open so the pause lands mid-download.
            first_call_started.set()
            await first_release.wait()
        else:
            second_call_started.set()
        return True, False, []

    async def fake_get_downloader():
        class _Downloader:
            async def download_to_memory(self, *_a, **_kw):
                return True, b"", {}

        return _Downloader()

    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "process_local_examples",
        staticmethod(fake_process_local_examples),
    )
    monkeypatch.setattr(
        download_module.MetadataUpdater,
        "update_metadata_from_local_examples",
        staticmethod(fake_update_metadata),
    )
    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "download_model_images_with_tracking",
        staticmethod(fake_download_model_images),
    )
    monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)

    original_sleep = download_module.asyncio.sleep
    pause_gate = asyncio.Event()  # set when a 1-second sleep is requested
    resume_gate = asyncio.Event()  # releases that sleep

    async def fake_sleep(delay: float):
        # Intercept only 1-second sleeps (presumably the manager's pause-poll
        # interval - TODO confirm against DownloadManager). Every other sleep
        # behaves normally.
        if delay == 1:
            pause_gate.set()
            await resume_gate.wait()
        else:
            await original_sleep(delay)

    # NOTE(review): ``download_module.asyncio`` is presumably the shared stdlib
    # module, so this patch is visible to every coroutine during the test.
    monkeypatch.setattr(download_module.asyncio, "sleep", fake_sleep)

    try:
        await manager.start_download({"model_types": ["lora"], "delay": 0})
        await asyncio.wait_for(first_call_started.wait(), timeout=1)

        await manager.pause_download({})
        first_release.set()
        await asyncio.wait_for(pause_gate.wait(), timeout=1)

        assert manager._progress["status"] == "paused"
        assert not second_call_started.is_set()
        statuses = [payload["status"] for payload in ws_manager.payloads]
        paused_index = statuses.index("paused")

        # Yield once more to show the paused loop does not advance on its own.
        await asyncio.sleep(0)
        assert not second_call_started.is_set()

        await manager.resume_download({})
        resume_gate.set()
        await asyncio.wait_for(second_call_started.wait(), timeout=1)
        if manager._download_task is not None:
            await asyncio.wait_for(manager._download_task, timeout=1)

        statuses_after = [payload["status"] for payload in ws_manager.payloads]
        # Use a default so a missing "running" broadcast fails as an assertion
        # instead of StopIteration escaping the coroutine (RuntimeError).
        running_after = next(
            (
                index
                for index in range(paused_index + 1, len(statuses_after))
                if statuses_after[index] == "running"
            ),
            None,
        )
        assert running_after is not None, (
            f"no 'running' broadcast after pause: {statuses_after}"
        )
        assert "completed" in statuses_after[running_after:]
        assert call_order == ["hash-one", "hash-two"]
    finally:
        first_release.set()
        resume_gate.set()
        # Only drain a task that is still running. Re-awaiting a finished task
        # would re-raise its exception here and mask the original failure.
        if manager._download_task is not None and not manager._download_task.done():
            await asyncio.wait_for(manager._download_task, timeout=1)
        # Restore eagerly once the task is drained; monkeypatch would also undo
        # this at teardown.
        monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep)
async def test_legacy_folder_migrated_and_skipped(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
    settings_manager,
):
    """A legacy ``<root>/<hash>`` folder is moved under the active library.

    With ``active_library`` set to ``"extra"``, an image folder sitting directly
    under ``example_images_path`` must end up at ``<root>/extra/<hash>``. The
    model is recorded as processed, and neither local processing nor a remote
    download is attempted.
    """
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)
    monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
    monkeypatch.setitem(settings_manager.settings, "libraries", {"default": {}, "extra": {}})
    monkeypatch.setitem(settings_manager.settings, "active_library", "extra")

    model_hash = "d" * 64
    model_path = tmp_path / "model.safetensors"
    model_path.write_text("data", encoding="utf-8")
    model = {
        "sha256": model_hash,
        "model_name": "Migrated Model",
        "file_path": str(model_path),
        "file_name": "model.safetensors",
        "civitai": {"images": [{"url": "https://example.com/image.png"}]},
    }
    _patch_scanner(monkeypatch, StubScanner([model]))

    # Pre-library layout: images stored directly under the root, keyed by hash.
    legacy_folder = tmp_path / model_hash
    legacy_folder.mkdir()
    (legacy_folder / "image_0.png").write_text("data", encoding="utf-8")

    process_called = False
    download_called = False

    async def fake_process_local_examples(*_args, **_kwargs):
        nonlocal process_called
        process_called = True
        return False

    async def fake_download_model_images(*_args, **_kwargs):
        nonlocal download_called
        download_called = True
        return True, False, []

    async def fake_get_downloader():
        class _Downloader:
            async def download_to_memory(self, *_a, **_kw):
                return True, b"", {}

        return _Downloader()

    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "process_local_examples",
        staticmethod(fake_process_local_examples),
    )
    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "download_model_images_with_tracking",
        staticmethod(fake_download_model_images),
    )
    monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)

    try:
        result = await manager.start_download({"model_types": ["lora"], "delay": 0})
        assert result["success"] is True
        if manager._download_task is not None:
            await asyncio.wait_for(manager._download_task, timeout=1)
    finally:
        # Drain a still-running task if an assertion above failed early.
        if manager._download_task is not None and not manager._download_task.done():
            await asyncio.wait_for(manager._download_task, timeout=1)

    library_root = tmp_path / "extra"
    migrated_folder = library_root / model_hash
    assert migrated_folder.exists()
    assert not legacy_folder.exists()
    assert not process_called
    assert not download_called
    assert model_hash in manager._progress["processed_models"]
async def test_legacy_progress_file_migrates(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
    settings_manager,
):
    """A root-level ``.download_progress.json`` moves into the active library.

    The legacy progress file already lists the model as processed, so no
    remote download may be attempted. Afterwards the legacy file is gone and
    ``<root>/extra/.download_progress.json`` still records the model.
    """
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)
    monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
    monkeypatch.setitem(settings_manager.settings, "libraries", {"default": {}, "extra": {}})
    monkeypatch.setitem(settings_manager.settings, "active_library", "extra")

    model_hash = "e" * 64
    model_path = tmp_path / "model-two.safetensors"
    model_path.write_text("data", encoding="utf-8")

    legacy_progress = tmp_path / ".download_progress.json"
    legacy_progress.write_text(
        json.dumps({"processed_models": [model_hash], "failed_models": []}),
        encoding="utf-8",
    )
    legacy_folder = tmp_path / model_hash
    legacy_folder.mkdir()
    (legacy_folder / "existing.png").write_text("data", encoding="utf-8")

    model = {
        "sha256": model_hash,
        "model_name": "Legacy Progress Model",
        "file_path": str(model_path),
        "file_name": "model-two.safetensors",
        "civitai": {"images": [{"url": "https://example.com/image.png"}]},
    }
    _patch_scanner(monkeypatch, StubScanner([model]))

    async def fake_process_local_examples(*_args, **_kwargs):
        return False

    async def fake_download_model_images(*_args, **_kwargs):
        raise AssertionError("Remote download should not be attempted when progress is migrated")

    async def fake_get_downloader():
        class _Downloader:
            async def download_to_memory(self, *_a, **_kw):
                return True, b"", {}

        return _Downloader()

    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "process_local_examples",
        staticmethod(fake_process_local_examples),
    )
    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "download_model_images_with_tracking",
        staticmethod(fake_download_model_images),
    )
    monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)

    try:
        result = await manager.start_download({"model_types": ["lora"], "delay": 0})
        assert result["success"] is True
        if manager._download_task is not None:
            await asyncio.wait_for(manager._download_task, timeout=1)
    finally:
        # Drain a still-running task if an assertion above failed early.
        if manager._download_task is not None and not manager._download_task.done():
            await asyncio.wait_for(manager._download_task, timeout=1)

    new_progress = (tmp_path / "extra") / ".download_progress.json"
    assert model_hash in manager._progress["processed_models"]
    assert not legacy_progress.exists()
    assert new_progress.exists()
    contents = json.loads(new_progress.read_text(encoding="utf-8"))
    assert model_hash in contents.get("processed_models", [])
async def test_download_remains_in_initial_library(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
    settings_manager,
):
    """A run keeps writing to the library that was active when it started.

    The fake local-processing step switches the reported active library from
    ``LibraryA`` to ``LibraryB`` mid-run. The progress file and the model
    folder must still land under ``LibraryA``, and nothing may appear under
    ``LibraryB``.
    """
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)
    monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
    monkeypatch.setitem(settings_manager.settings, "libraries", {"LibraryA": {}, "LibraryB": {}})
    monkeypatch.setitem(settings_manager.settings, "active_library", "LibraryA")

    # Mutable holder lets the fake below flip the active library mid-run.
    state = {"active": "LibraryA"}

    def fake_get_active_library_name(self):
        return state["active"]

    monkeypatch.setattr(SettingsManager, "get_active_library_name", fake_get_active_library_name)

    model_hash = "f" * 64
    model_path = tmp_path / "example-model.safetensors"
    model_path.write_text("data", encoding="utf-8")
    model = {
        "sha256": model_hash,
        "model_name": "Library Switch Model",
        "file_path": str(model_path),
        "file_name": "example-model.safetensors",
    }
    _patch_scanner(monkeypatch, StubScanner([model]))

    async def fake_process_local_examples(
        _file_path,
        _file_name,
        _model_name,
        model_dir,
        _optimize,
    ):
        Path(model_dir).mkdir(parents=True, exist_ok=True)
        (Path(model_dir) / "local.txt").write_text("data", encoding="utf-8")
        # Simulate the user switching libraries while the run is in progress.
        state["active"] = "LibraryB"
        return True

    async def fake_update_metadata(*_args, **_kwargs):
        return True

    async def fake_get_downloader():
        return object()

    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "process_local_examples",
        staticmethod(fake_process_local_examples),
    )
    monkeypatch.setattr(
        download_module.MetadataUpdater,
        "update_metadata_from_local_examples",
        staticmethod(fake_update_metadata),
    )
    monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)

    try:
        result = await manager.start_download({"model_types": ["lora"], "delay": 0})
        assert result["success"] is True
        if manager._download_task is not None:
            await asyncio.wait_for(manager._download_task, timeout=1)
    finally:
        # Drain a still-running task if an assertion above failed early.
        if manager._download_task is not None and not manager._download_task.done():
            await asyncio.wait_for(manager._download_task, timeout=1)

    library_a_root = tmp_path / "LibraryA"
    library_b_root = tmp_path / "LibraryB"
    progress_file = library_a_root / ".download_progress.json"
    model_dir = library_a_root / model_hash
    assert progress_file.exists()
    assert (model_dir / "local.txt").exists()
    assert not (library_b_root / ".download_progress.json").exists()
    assert not (library_b_root / model_hash).exists()
async def test_not_found_example_images_are_cleaned(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
    settings_manager,
):
    """A remote image reported as not found is flagged in metadata.

    The downloader answers "File not found" for one of two image URLs. The
    test expects:

    * metadata to be refreshed once;
    * the model to be recorded as both failed and processed;
    * the missing entry to gain ``downloadFailed``/``downloadError`` markers
      while the valid entry is unchanged.

    Pre-existing image files on disk must be left byte-for-byte unchanged.
    """
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)
    images_root = tmp_path / "examples"
    monkeypatch.setitem(settings_manager.settings, "example_images_path", str(images_root))

    model_hash = "f" * 64
    model_path = tmp_path / "missing-model.safetensors"
    model_path.write_text("data", encoding="utf-8")
    missing_url = "https://example.com/missing.png"
    valid_url = "https://example.com/valid.png"
    model_metadata = {
        "sha256": model_hash,
        "model_name": "Missing Example",
        "file_path": str(model_path),
        "file_name": "missing-model.safetensors",
        "civitai": {
            "images": [
                {"url": missing_url},
                {"url": valid_url},
            ]
        },
    }
    # NOTE(review): shallow copy - the nested "civitai" dict is shared with
    # model_metadata.
    scanner = StubScanner([model_metadata.copy()])
    _patch_scanner(monkeypatch, scanner)

    model_dir = images_root / model_hash
    model_dir.mkdir(parents=True, exist_ok=True)
    (model_dir / "image_0.png").write_bytes(b"first")
    (model_dir / "image_1.png").write_bytes(b"second")

    async def fake_process_local_examples(*_args, **_kwargs):
        return False

    refresh_calls: list[str] = []

    async def fake_refresh(model_hash_arg, *_args, **_kwargs):
        refresh_calls.append(model_hash_arg)
        return True

    async def fake_get_updated_model(model_hash_arg, _scanner):
        assert model_hash_arg == model_hash
        return model_metadata

    async def fake_save_metadata(_path, metadata):
        # Fold saved metadata back into the dict the assertions inspect.
        model_metadata.update(metadata)
        return True

    class DownloaderStub:
        def __init__(self):
            self.calls: list[str] = []

        async def download_to_memory(self, url, *_args, **_kwargs):
            self.calls.append(url)
            if url == missing_url:
                return False, "File not found", None
            return True, b"\x89PNG\r\n\x1a\n", {"content-type": "image/png"}

    downloader = DownloaderStub()

    async def fake_get_downloader():
        return downloader

    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "process_local_examples",
        staticmethod(fake_process_local_examples),
    )
    monkeypatch.setattr(
        download_module.MetadataUpdater,
        "refresh_model_metadata",
        staticmethod(fake_refresh),
    )
    monkeypatch.setattr(
        download_module.MetadataUpdater,
        "get_updated_model",
        staticmethod(fake_get_updated_model),
    )
    monkeypatch.setattr(
        download_module.MetadataManager,
        "save_metadata",
        staticmethod(fake_save_metadata),
    )
    monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)
    # Report the (pre-seeded) model directory as empty to force the remote path.
    monkeypatch.setattr(download_module, "_model_directory_has_files", lambda _path: False)

    try:
        result = await manager.start_download({"model_types": ["lora"], "delay": 0, "optimize": False})
        assert result["success"] is True
        if manager._download_task is not None:
            await asyncio.wait_for(manager._download_task, timeout=1)
    finally:
        # Drain a still-running task if an assertion above failed early.
        if manager._download_task is not None and not manager._download_task.done():
            await asyncio.wait_for(manager._download_task, timeout=1)

    assert refresh_calls == [model_hash]
    assert missing_url in downloader.calls
    assert manager._progress["failed_models"] == {model_hash}
    assert model_hash in manager._progress["processed_models"]

    remaining_images = model_metadata["civitai"]["images"]
    assert remaining_images == [
        {"url": missing_url, "downloadFailed": True, "downloadError": "not_found"},
        {"url": valid_url},
    ]

    files = sorted(p.name for p in model_dir.iterdir())
    assert files == ["image_0.png", "image_1.png"]
    assert (model_dir / "image_0.png").read_bytes() == b"first"
    assert (model_dir / "image_1.png").read_bytes() == b"second"
@pytest.fixture
def settings_manager():
    """Provide the settings manager returned by ``get_settings_manager()``.

    Tests adjust its ``settings`` mapping through ``monkeypatch.setitem`` so
    the changes are rolled back after each test.
    """
    manager = get_settings_manager()
    return manager