Mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git
Synced 2026-03-22 13:42:12 -03:00
607 lines · 20 KiB · Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from pathlib import Path
|
|
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
|
|
from py.services.settings_manager import SettingsManager, get_settings_manager
|
|
from py.utils import example_images_download_manager as download_module
|
|
|
|
|
|
class RecordingWebSocketManager:
    """Websocket-manager double that keeps every broadcast for later inspection.

    Payloads handed to :meth:`broadcast` are appended, in call order, to
    :attr:`payloads`; nothing is sent anywhere.
    """

    def __init__(self) -> None:
        # Ordered log of every payload broadcast so far.
        self.payloads: list[dict] = list()

    async def broadcast(self, payload: dict) -> None:
        """Record *payload* at the end of :attr:`payloads`."""
        log = self.payloads
        log.append(payload)
|
|
|
|
|
|
class StubScanner:
    """Scanner stand-in whose cache is a fixed, caller-supplied model list.

    The list is stored (not copied) as ``raw_data`` on a ``SimpleNamespace``,
    which :meth:`get_cached_data` returns as the "cache".
    """

    def __init__(self, models: list[dict]) -> None:
        self._cache = SimpleNamespace(raw_data=models)

    async def get_cached_data(self):
        """Return the namespace that wraps the stubbed model list."""
        return self._cache

    async def update_single_model_cache(self, _old_path, _new_path, metadata):
        """Replace the first cached entry whose ``file_path`` matches *metadata*'s.

        Matching uses ``metadata["file_path"]``; the old/new path arguments are
        ignored. If no entry matches, the cache is left unchanged. Always
        returns ``True``.
        """
        entries = self._cache.raw_data
        hit = next(
            (
                pos
                for pos, entry in enumerate(entries)
                if entry.get("file_path") == metadata.get("file_path")
            ),
            None,
        )
        if hit is not None:
            entries[hit] = metadata
        return True
|
|
|
|
|
|
def _patch_scanner(monkeypatch: pytest.MonkeyPatch, scanner: StubScanner) -> None:
    """Route ``ServiceRegistry.get_lora_scanner`` to the given *scanner*.

    The replacement is installed as a classmethod through *monkeypatch*, so
    pytest restores the real lookup when the test ends.
    """

    async def _return_stub(cls):
        return scanner

    registry = download_module.ServiceRegistry
    monkeypatch.setattr(registry, "get_lora_scanner", classmethod(_return_stub))
|
|
|
|
|
|
@pytest.mark.usefixtures("tmp_path")
async def test_start_download_rejects_parallel_runs(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
    settings_manager,
):
    """A second ``start_download`` during an active run raises with a snapshot.

    The first run is parked inside ``process_local_examples`` until
    ``unblock_local`` is set, so the second call always overlaps it. The raised
    ``DownloadInProgressError`` must expose the in-flight progress snapshot.
    """
    ws = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws)

    monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))

    only_model = {
        "sha256": "abc123",
        "model_name": "Example",
        "file_path": str(tmp_path / "example.safetensors"),
        "file_name": "example.safetensors",
    }
    _patch_scanner(monkeypatch, StubScanner([only_model]))

    entered_local = asyncio.Event()
    unblock_local = asyncio.Event()

    async def blocking_local_examples(*_args, **_kwargs):
        # Announce that the first run is underway, then hold it open.
        entered_local.set()
        await unblock_local.wait()
        return True

    async def accept_metadata(*_args, **_kwargs):
        return True

    async def dummy_downloader():
        return object()

    replacements = (
        (
            download_module.ExampleImagesProcessor,
            "process_local_examples",
            staticmethod(blocking_local_examples),
        ),
        (
            download_module.MetadataUpdater,
            "update_metadata_from_local_examples",
            staticmethod(accept_metadata),
        ),
        (download_module, "get_downloader", dummy_downloader),
    )
    for target, attr, replacement in replacements:
        monkeypatch.setattr(target, attr, replacement)

    try:
        first = await manager.start_download({"model_types": ["lora"], "delay": 0})
        assert first["success"] is True

        await asyncio.wait_for(entered_local.wait(), timeout=1)

        with pytest.raises(download_module.DownloadInProgressError) as excinfo:
            await manager.start_download({"model_types": ["lora"], "delay": 0})

        progress = excinfo.value.progress_snapshot
        assert progress["status"] == "running"
        assert progress["current_model"] == "Example (abc123)"

        broadcast_statuses = [message["status"] for message in ws.payloads]
        assert "running" in broadcast_statuses

    finally:
        # Let the parked first run finish so no task outlives the test.
        unblock_local.set()
        task = manager._download_task
        if task is not None:
            await asyncio.wait_for(task, timeout=1)
|
|
|
|
|
|
@pytest.mark.usefixtures("tmp_path")
async def test_pause_resume_blocks_processing(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
    settings_manager,
):
    """Pausing prevents the next model from starting until the run resumes.

    Choreography driven by the events below:

    1. The first remote download parks on ``first_release``.
    2. The test pauses the manager, then releases the first download.
    3. A one-second ``asyncio.sleep`` issued afterwards is intercepted: it sets
       ``pause_gate`` and blocks on ``resume_gate``.
    4. While parked there, status must be "paused" and model two not started.
    5. After resuming, model two downloads, and the broadcasts show a
       "running" after the "paused" one, followed later by "completed".
    """
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)

    monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))

    models = [
        {
            "sha256": "hash-one",
            "model_name": "Model One",
            "file_path": str(tmp_path / "model-one.safetensors"),
            "file_name": "model-one.safetensors",
            "civitai": {"images": [{"url": "https://example.com/one.png"}]},
        },
        {
            "sha256": "hash-two",
            "model_name": "Model Two",
            "file_path": str(tmp_path / "model-two.safetensors"),
            "file_name": "model-two.safetensors",
            "civitai": {"images": [{"url": "https://example.com/two.png"}]},
        },
    ]
    _patch_scanner(monkeypatch, StubScanner(models))

    async def fake_process_local_examples(*_args, **_kwargs):
        # No local examples, so every model takes the remote-download path.
        return False

    async def fake_update_metadata(*_args, **_kwargs):
        return True

    first_call_started = asyncio.Event()
    first_release = asyncio.Event()
    second_call_started = asyncio.Event()
    call_order: list[str] = []

    async def fake_download_model_images(model_hash, *_args, **_kwargs):
        call_order.append(model_hash)
        if len(call_order) == 1:
            # Hold the first download open so the test can pause mid-run.
            first_call_started.set()
            await first_release.wait()
        else:
            second_call_started.set()
        return True, False, []

    async def fake_get_downloader():
        class _Downloader:
            async def download_to_memory(self, *_a, **_kw):
                return True, b"", {}

        return _Downloader()

    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "process_local_examples",
        staticmethod(fake_process_local_examples),
    )
    monkeypatch.setattr(
        download_module.MetadataUpdater,
        "update_metadata_from_local_examples",
        staticmethod(fake_update_metadata),
    )
    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "download_model_images_with_tracking",
        staticmethod(fake_download_model_images),
    )
    monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)

    original_sleep = download_module.asyncio.sleep
    pause_gate = asyncio.Event()
    resume_gate = asyncio.Event()

    async def fake_sleep(delay: float, result=None):
        # ``download_module.asyncio`` is the shared ``asyncio`` module, so this
        # replaces ``asyncio.sleep`` process-wide. Mirror the real signature
        # (``sleep(delay, result=None)``) so unrelated callers keep working.
        if delay == 1:
            # NOTE(review): the test assumes the manager's paused wait is a
            # one-second sleep — confirm against DownloadManager if it changes.
            pause_gate.set()
            await resume_gate.wait()
            return result
        return await original_sleep(delay, result)

    monkeypatch.setattr(download_module.asyncio, "sleep", fake_sleep)

    try:
        await manager.start_download({"model_types": ["lora"], "delay": 0})

        await asyncio.wait_for(first_call_started.wait(), timeout=1)

        await manager.pause_download({})

        first_release.set()

        await asyncio.wait_for(pause_gate.wait(), timeout=1)
        assert manager._progress["status"] == "paused"
        assert not second_call_started.is_set()

        statuses = [payload["status"] for payload in ws_manager.payloads]
        # Raises ValueError (failing the test) if no "paused" broadcast was sent.
        paused_index = statuses.index("paused")

        # Yield once to prove the paused task does not advance on its own.
        await asyncio.sleep(0)
        assert not second_call_started.is_set()

        await manager.resume_download({})
        resume_gate.set()

        await asyncio.wait_for(second_call_started.wait(), timeout=1)

        if manager._download_task is not None:
            await asyncio.wait_for(manager._download_task, timeout=1)

        statuses_after = [payload["status"] for payload in ws_manager.payloads]
        running_after = next(
            i
            for i, status in enumerate(statuses_after[paused_index + 1 :], start=paused_index + 1)
            if status == "running"
        )
        assert running_after > paused_index
        assert "completed" in statuses_after[running_after:]
        assert call_order == ["hash-one", "hash-two"]

    finally:
        # Release every gate so the background task can always finish.
        first_release.set()
        resume_gate.set()
        if manager._download_task is not None:
            await asyncio.wait_for(manager._download_task, timeout=1)
        # Restore the real sleep right away; monkeypatch would otherwise only
        # undo it at fixture teardown.
        monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep)
|
|
|
|
|
|
@pytest.mark.usefixtures("tmp_path")
async def test_legacy_folder_migrated_and_skipped(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
    settings_manager,
):
    """A flat ``<root>/<hash>`` folder is moved under the active library and skipped.

    With ``extra`` as the active library, the pre-existing legacy folder must
    end up at ``<root>/extra/<hash>``. Neither local processing nor remote
    download may run for that model, and it must be recorded as processed.
    """
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)

    settings = settings_manager.settings
    monkeypatch.setitem(settings, "example_images_path", str(tmp_path))
    monkeypatch.setitem(settings, "libraries", {"default": {}, "extra": {}})
    monkeypatch.setitem(settings, "active_library", "extra")

    model_hash = "d" * 64
    model_path = tmp_path / "model.safetensors"
    model_path.write_text("data", encoding="utf-8")

    _patch_scanner(
        monkeypatch,
        StubScanner(
            [
                {
                    "sha256": model_hash,
                    "model_name": "Migrated Model",
                    "file_path": str(model_path),
                    "file_name": "model.safetensors",
                    "civitai": {"images": [{"url": "https://example.com/image.png"}]},
                }
            ]
        ),
    )

    # Old layout: example images directly under the root, keyed by hash.
    legacy_folder = tmp_path / model_hash
    legacy_folder.mkdir()
    (legacy_folder / "image_0.png").write_text("data", encoding="utf-8")

    invoked = {"process": False, "download": False}

    async def fake_process_local_examples(*_args, **_kwargs):
        invoked["process"] = True
        return False

    async def fake_download_model_images(*_args, **_kwargs):
        invoked["download"] = True
        return True, False, []

    async def fake_get_downloader():
        class _Downloader:
            async def download_to_memory(self, *_a, **_kw):
                return True, b"", {}

        return _Downloader()

    processor = download_module.ExampleImagesProcessor
    monkeypatch.setattr(processor, "process_local_examples", staticmethod(fake_process_local_examples))
    monkeypatch.setattr(
        processor,
        "download_model_images_with_tracking",
        staticmethod(fake_download_model_images),
    )
    monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)

    try:
        outcome = await manager.start_download({"model_types": ["lora"], "delay": 0})
        assert outcome["success"] is True

        if manager._download_task is not None:
            await asyncio.wait_for(manager._download_task, timeout=1)
    finally:
        pending = manager._download_task
        if pending is not None and not pending.done():
            await asyncio.wait_for(pending, timeout=1)

    migrated_folder = tmp_path / "extra" / model_hash

    assert migrated_folder.exists()
    assert not legacy_folder.exists()
    assert not invoked["process"]
    assert not invoked["download"]
    assert model_hash in manager._progress["processed_models"]
|
|
|
|
|
|
@pytest.mark.usefixtures("tmp_path")
async def test_legacy_progress_file_migrates(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
    settings_manager,
):
    """A root-level ``.download_progress.json`` is moved into the active library.

    The legacy progress file already lists the model as processed, so after the
    run the model must still be processed, no remote download may have been
    attempted, and the progress file must live under ``<root>/extra``.
    """
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)

    monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
    monkeypatch.setitem(settings_manager.settings, "libraries", {"default": {}, "extra": {}})
    monkeypatch.setitem(settings_manager.settings, "active_library", "extra")

    model_hash = "e" * 64
    model_path = tmp_path / "model-two.safetensors"
    model_path.write_text("data", encoding="utf-8")

    legacy_progress = tmp_path / ".download_progress.json"
    legacy_progress.write_text(
        json.dumps({"processed_models": [model_hash], "failed_models": []}),
        encoding="utf-8",
    )

    legacy_folder = tmp_path / model_hash
    legacy_folder.mkdir()
    (legacy_folder / "existing.png").write_text("data", encoding="utf-8")

    model = {
        "sha256": model_hash,
        "model_name": "Legacy Progress Model",
        "file_path": str(model_path),
        "file_name": "model-two.safetensors",
        "civitai": {"images": [{"url": "https://example.com/image.png"}]},
    }

    _patch_scanner(monkeypatch, StubScanner([model]))

    async def fake_process_local_examples(*_args, **_kwargs):
        return False

    remote_download_calls: list[tuple] = []

    async def fake_download_model_images(*args, **_kwargs):
        # Record the attempt before failing. The exception is raised inside the
        # background download task, where the manager's error handling may
        # swallow it, so the recorded call is what the test asserts on.
        remote_download_calls.append(args)
        raise AssertionError("Remote download should not be attempted when progress is migrated")

    async def fake_get_downloader():
        class _Downloader:
            async def download_to_memory(self, *_a, **_kw):
                return True, b"", {}

        return _Downloader()

    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "process_local_examples",
        staticmethod(fake_process_local_examples),
    )
    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "download_model_images_with_tracking",
        staticmethod(fake_download_model_images),
    )
    monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)

    try:
        result = await manager.start_download({"model_types": ["lora"], "delay": 0})
        assert result["success"] is True

        if manager._download_task is not None:
            await asyncio.wait_for(manager._download_task, timeout=1)
    finally:
        # Never leave the background task running if an assertion above fails.
        task = manager._download_task
        if task is not None and not task.done():
            await asyncio.wait_for(task, timeout=1)

    new_progress = (tmp_path / "extra") / ".download_progress.json"

    assert not remote_download_calls
    assert model_hash in manager._progress["processed_models"]
    assert not legacy_progress.exists()
    assert new_progress.exists()
    contents = json.loads(new_progress.read_text(encoding="utf-8"))
    assert model_hash in contents.get("processed_models", [])
|
|
|
|
|
|
@pytest.mark.usefixtures("tmp_path")
async def test_download_remains_in_initial_library(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
    settings_manager,
):
    """A run keeps writing to the library that was active when it started.

    The patched ``process_local_examples`` switches the active library from
    ``LibraryA`` to ``LibraryB`` mid-run. Both the model folder and the
    progress file must still land under ``LibraryA``, and nothing may appear
    under ``LibraryB``.
    """
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)

    settings = settings_manager.settings
    monkeypatch.setitem(settings, "example_images_path", str(tmp_path))
    monkeypatch.setitem(settings, "libraries", {"LibraryA": {}, "LibraryB": {}})
    monkeypatch.setitem(settings, "active_library", "LibraryA")

    current = {"library": "LibraryA"}

    def fake_get_active_library_name(self):
        return current["library"]

    monkeypatch.setattr(SettingsManager, "get_active_library_name", fake_get_active_library_name)

    model_hash = "f" * 64
    model_path = tmp_path / "example-model.safetensors"
    model_path.write_text("data", encoding="utf-8")

    _patch_scanner(
        monkeypatch,
        StubScanner(
            [
                {
                    "sha256": model_hash,
                    "model_name": "Library Switch Model",
                    "file_path": str(model_path),
                    "file_name": "example-model.safetensors",
                }
            ]
        ),
    )

    async def fake_process_local_examples(
        _file_path,
        _file_name,
        _model_name,
        model_dir,
        _optimize,
    ):
        target_dir = Path(model_dir)
        target_dir.mkdir(parents=True, exist_ok=True)
        (target_dir / "local.txt").write_text("data", encoding="utf-8")
        # Switch libraries mid-run; the download must not follow the switch.
        current["library"] = "LibraryB"
        return True

    async def fake_update_metadata(*_args, **_kwargs):
        return True

    async def fake_get_downloader():
        return object()

    for owner, attr, replacement in (
        (
            download_module.ExampleImagesProcessor,
            "process_local_examples",
            staticmethod(fake_process_local_examples),
        ),
        (
            download_module.MetadataUpdater,
            "update_metadata_from_local_examples",
            staticmethod(fake_update_metadata),
        ),
        (download_module, "get_downloader", fake_get_downloader),
    ):
        monkeypatch.setattr(owner, attr, replacement)

    outcome = await manager.start_download({"model_types": ["lora"], "delay": 0})
    assert outcome["success"] is True

    if manager._download_task is not None:
        await asyncio.wait_for(manager._download_task, timeout=1)

    root_a = tmp_path / "LibraryA"
    root_b = tmp_path / "LibraryB"

    assert (root_a / ".download_progress.json").exists()
    assert (root_a / model_hash / "local.txt").exists()
    assert not (root_b / ".download_progress.json").exists()
    assert not (root_b / model_hash).exists()
|
|
|
|
@pytest.mark.usefixtures("tmp_path")
async def test_not_found_example_images_are_cleaned(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path,
    settings_manager,
):
    """A "File not found" image is flagged in metadata while local files stay put.

    One civitai image URL returns "File not found" from the downloader. The
    test expects:

    - a single metadata refresh for the model;
    - the model is recorded as both failed and processed;
    - the missing image's entry is tagged with ``downloadFailed`` /
      ``downloadError: "not_found"``;
    - the pre-existing image files are neither removed nor overwritten.
    """
    ws_manager = RecordingWebSocketManager()
    manager = download_module.DownloadManager(ws_manager=ws_manager)

    images_root = tmp_path / "examples"
    monkeypatch.setitem(settings_manager.settings, "example_images_path", str(images_root))

    model_hash = "f" * 64
    model_path = tmp_path / "missing-model.safetensors"
    model_path.write_text("data", encoding="utf-8")

    missing_url = "https://example.com/missing.png"
    valid_url = "https://example.com/valid.png"

    model_metadata = {
        "sha256": model_hash,
        "model_name": "Missing Example",
        "file_path": str(model_path),
        "file_name": "missing-model.safetensors",
        "civitai": {"images": [{"url": missing_url}, {"url": valid_url}]},
    }

    # Shallow copy: the scanner entry shares the nested "civitai" dict.
    scanner = StubScanner([dict(model_metadata)])
    _patch_scanner(monkeypatch, scanner)

    model_dir = images_root / model_hash
    model_dir.mkdir(parents=True, exist_ok=True)
    (model_dir / "image_0.png").write_bytes(b"first")
    (model_dir / "image_1.png").write_bytes(b"second")

    async def fake_process_local_examples(*_args, **_kwargs):
        return False

    refresh_calls: list[str] = []

    async def fake_refresh(model_hash_arg, *_args, **_kwargs):
        refresh_calls.append(model_hash_arg)
        return True

    async def fake_get_updated_model(model_hash_arg, _scanner):
        assert model_hash_arg == model_hash
        return model_metadata

    async def fake_save_metadata(_path, metadata):
        # Persisting just merges the update into the shared metadata dict.
        model_metadata.update(metadata)
        return True

    class DownloaderStub:
        def __init__(self):
            self.calls: list[str] = []

        async def download_to_memory(self, url, *_args, **_kwargs):
            self.calls.append(url)
            if url != missing_url:
                return True, b"\x89PNG\r\n\x1a\n", {"content-type": "image/png"}
            return False, "File not found", None

    downloader = DownloaderStub()

    async def fake_get_downloader():
        return downloader

    updater = download_module.MetadataUpdater
    monkeypatch.setattr(
        download_module.ExampleImagesProcessor,
        "process_local_examples",
        staticmethod(fake_process_local_examples),
    )
    monkeypatch.setattr(updater, "refresh_model_metadata", staticmethod(fake_refresh))
    monkeypatch.setattr(updater, "get_updated_model", staticmethod(fake_get_updated_model))
    monkeypatch.setattr(
        download_module.MetadataManager,
        "save_metadata",
        staticmethod(fake_save_metadata),
    )
    monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)
    # Pretend the model folder is empty so the remote-download path runs.
    monkeypatch.setattr(download_module, "_model_directory_has_files", lambda _path: False)

    outcome = await manager.start_download({"model_types": ["lora"], "delay": 0, "optimize": False})
    assert outcome["success"] is True

    if manager._download_task is not None:
        await asyncio.wait_for(manager._download_task, timeout=1)

    progress = manager._progress
    assert refresh_calls == [model_hash]
    assert missing_url in downloader.calls
    assert progress["failed_models"] == {model_hash}
    assert model_hash in progress["processed_models"]

    assert model_metadata["civitai"]["images"] == [
        {"url": missing_url, "downloadFailed": True, "downloadError": "not_found"},
        {"url": valid_url},
    ]

    assert sorted(entry.name for entry in model_dir.iterdir()) == ["image_0.png", "image_1.png"]
    assert (model_dir / "image_0.png").read_bytes() == b"first"
    assert (model_dir / "image_1.png").read_bytes() == b"second"
|
|
|
|
|
|
@pytest.fixture
def settings_manager():
    """Expose the object returned by ``get_settings_manager()`` to tests.

    NOTE(review): presumably a shared instance; the tests change its
    ``settings`` mapping only through ``monkeypatch.setitem`` so every change
    is rolled back after each test.
    """
    manager = get_settings_manager()
    return manager
|