mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #525 from willmiao/codex/develop-tests-for-metadatasyncservice-and-modelscanner
Add coverage for metadata sync service and scanner reconciliation
This commit is contained in:
227
tests/services/test_metadata_sync_service.py
Normal file
227
tests/services/test_metadata_sync_service.py
Normal file
@@ -0,0 +1,227 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.metadata_sync_service import MetadataSyncService
|
||||
|
||||
|
||||
class DummySettings:
|
||||
def __init__(self, values: dict | None = None) -> None:
|
||||
self._values = values or {}
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
return self._values.get(key, default)
|
||||
|
||||
|
||||
def build_service(
|
||||
*,
|
||||
settings_values: dict | None = None,
|
||||
default_provider: SimpleNamespace | None = None,
|
||||
provider_selector: AsyncMock | None = None,
|
||||
):
|
||||
metadata_manager = SimpleNamespace(save_metadata=AsyncMock())
|
||||
preview_service = SimpleNamespace(ensure_preview_for_metadata=AsyncMock())
|
||||
settings = DummySettings(settings_values)
|
||||
|
||||
provider = default_provider or SimpleNamespace(
|
||||
get_model_by_hash=AsyncMock(),
|
||||
get_model_version=AsyncMock(),
|
||||
)
|
||||
|
||||
default_provider_factory = AsyncMock(return_value=provider)
|
||||
provider_selector = provider_selector or AsyncMock(return_value=provider)
|
||||
|
||||
service = MetadataSyncService(
|
||||
metadata_manager=metadata_manager,
|
||||
preview_service=preview_service,
|
||||
settings=settings,
|
||||
default_metadata_provider_factory=default_provider_factory,
|
||||
metadata_provider_selector=provider_selector,
|
||||
)
|
||||
|
||||
return SimpleNamespace(
|
||||
service=service,
|
||||
metadata_manager=metadata_manager,
|
||||
preview_service=preview_service,
|
||||
default_provider=provider,
|
||||
default_provider_factory=default_provider_factory,
|
||||
provider_selector=provider_selector,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_model_metadata_merges_and_persists():
|
||||
helpers = build_service()
|
||||
|
||||
local = {
|
||||
"civitai": {"trainedWords": ["alpha"], "creator": {"id": 1}},
|
||||
"modelDescription": "",
|
||||
"tags": [],
|
||||
"model_name": "Local",
|
||||
}
|
||||
remote = {
|
||||
"source": "api",
|
||||
"trainedWords": ["beta"],
|
||||
"model": {
|
||||
"name": "Remote Model",
|
||||
"description": "desc",
|
||||
"tags": ["style"],
|
||||
"creator": {"id": 2},
|
||||
},
|
||||
"baseModel": "sdxl",
|
||||
"images": ["img"],
|
||||
}
|
||||
|
||||
result = await helpers.service.update_model_metadata(
|
||||
"path/to/model.metadata.json",
|
||||
local,
|
||||
remote,
|
||||
helpers.default_provider,
|
||||
)
|
||||
|
||||
assert set(result["civitai"]["trainedWords"]) == {"alpha", "beta"}
|
||||
assert result["model_name"] == "Remote Model"
|
||||
assert result["modelDescription"] == "desc"
|
||||
assert result["tags"] == ["style"]
|
||||
assert result["base_model"] == "SDXL 1.0"
|
||||
|
||||
helpers.preview_service.ensure_preview_for_metadata.assert_awaited_once()
|
||||
helpers.metadata_manager.save_metadata.assert_awaited_once_with(
|
||||
"path/to/model.metadata.json",
|
||||
result,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_update_model_success_updates_cache(tmp_path):
|
||||
helpers = build_service()
|
||||
|
||||
civitai_payload = {
|
||||
"source": "api",
|
||||
"model": {"name": "Remote", "description": "", "tags": ["tag"]},
|
||||
"images": [],
|
||||
"baseModel": "sdxl",
|
||||
}
|
||||
helpers.default_provider.get_model_by_hash.return_value = (civitai_payload, None)
|
||||
|
||||
model_data = {"model_name": "Local", "folder": "root"}
|
||||
update_cache = AsyncMock(return_value=True)
|
||||
|
||||
ok, error = await helpers.service.fetch_and_update_model(
|
||||
sha256="abc",
|
||||
file_path=str(tmp_path / "model.safetensors"),
|
||||
model_data=model_data,
|
||||
update_cache_func=update_cache,
|
||||
)
|
||||
|
||||
assert ok and error is None
|
||||
assert model_data["from_civitai"] is True
|
||||
assert model_data["civitai_deleted"] is False
|
||||
assert "civitai" in model_data
|
||||
|
||||
metadata_path = str(tmp_path / "model.metadata.json")
|
||||
await_args = helpers.metadata_manager.save_metadata.await_args_list
|
||||
assert await_args, "expected metadata to be persisted"
|
||||
assert await_args[-1][0][0] == metadata_path
|
||||
update_cache.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_update_model_handles_missing_remote_metadata(tmp_path):
|
||||
helpers = build_service()
|
||||
helpers.default_provider.get_model_by_hash.return_value = (None, "Model not found")
|
||||
|
||||
model_data = {
|
||||
"model_name": "Local",
|
||||
"folder": "sub",
|
||||
}
|
||||
|
||||
ok, error = await helpers.service.fetch_and_update_model(
|
||||
sha256="missing",
|
||||
file_path=str(tmp_path / "model.safetensors"),
|
||||
model_data=model_data,
|
||||
update_cache_func=AsyncMock(),
|
||||
)
|
||||
|
||||
assert not ok
|
||||
assert "Model not found" in error
|
||||
assert model_data["from_civitai"] is False
|
||||
assert model_data["civitai_deleted"] is True
|
||||
|
||||
helpers.metadata_manager.save_metadata.assert_awaited_once()
|
||||
args, _ = helpers.metadata_manager.save_metadata.await_args
|
||||
assert args[0].endswith("model.safetensors")
|
||||
assert "folder" not in args[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_and_update_model_respects_deleted_without_archive():
|
||||
helpers = build_service(settings_values={"enable_metadata_archive_db": False})
|
||||
|
||||
model_data = {"civitai_deleted": True}
|
||||
update_cache = AsyncMock()
|
||||
|
||||
ok, error = await helpers.service.fetch_and_update_model(
|
||||
sha256="abc",
|
||||
file_path="/tmp/model.safetensors",
|
||||
model_data=model_data,
|
||||
update_cache_func=update_cache,
|
||||
)
|
||||
|
||||
assert not ok
|
||||
assert "metadata archive DB is not enabled" in error
|
||||
helpers.default_provider_factory.assert_not_awaited()
|
||||
update_cache.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relink_metadata_fetches_version_and_updates_sha(tmp_path):
|
||||
provider = SimpleNamespace(
|
||||
get_model_by_hash=AsyncMock(),
|
||||
get_model_version=AsyncMock(
|
||||
return_value={
|
||||
"files": [
|
||||
{
|
||||
"primary": True,
|
||||
"type": "Model",
|
||||
"hashes": {"SHA256": "ABCDEF"},
|
||||
}
|
||||
],
|
||||
"model": {"name": "Remote"},
|
||||
"images": [],
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
helpers = build_service(default_provider=provider)
|
||||
|
||||
metadata = {"model_name": "Local"}
|
||||
result = await helpers.service.relink_metadata(
|
||||
file_path=str(tmp_path / "model.safetensors"),
|
||||
metadata=metadata,
|
||||
model_id=1,
|
||||
model_version_id=2,
|
||||
)
|
||||
|
||||
assert result["sha256"] == "abcdef"
|
||||
helpers.metadata_manager.save_metadata.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relink_metadata_raises_when_version_missing():
|
||||
provider = SimpleNamespace(
|
||||
get_model_by_hash=AsyncMock(),
|
||||
get_model_version=AsyncMock(return_value=None),
|
||||
)
|
||||
|
||||
helpers = build_service(default_provider=provider)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await helpers.service.relink_metadata(
|
||||
file_path="/tmp/model.safetensors",
|
||||
metadata={},
|
||||
model_id=9,
|
||||
model_version_id=None,
|
||||
)
|
||||
|
||||
@@ -379,3 +379,50 @@ async def test_batch_delete_persists_removal(tmp_path: Path, monkeypatch):
|
||||
).fetchone()[0]
|
||||
|
||||
assert remaining == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path):
|
||||
first, _, _ = _create_files(tmp_path)
|
||||
scanner = DummyScanner(tmp_path)
|
||||
|
||||
await scanner._initialize_cache()
|
||||
await scanner.get_cached_data()
|
||||
|
||||
new_file = tmp_path / "three.txt"
|
||||
new_file.write_text("three", encoding="utf-8")
|
||||
(tmp_path / "nested" / "two.txt").unlink()
|
||||
|
||||
await scanner._reconcile_cache()
|
||||
|
||||
cache = await scanner.get_cached_data()
|
||||
cached_paths = {item["file_path"] for item in cache.raw_data}
|
||||
|
||||
assert cached_paths == {
|
||||
_normalize_path(first),
|
||||
_normalize_path(new_file),
|
||||
}
|
||||
assert scanner._hash_index.get_path("hash-three") == _normalize_path(new_file)
|
||||
assert scanner._hash_index.get_path("hash-two") is None
|
||||
assert scanner._tags_count == {"alpha": 1, "beta": 1}
|
||||
assert cache.folders == [""]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_model_files_handles_symlink_loops(tmp_path: Path):
|
||||
scanner = DummyScanner(tmp_path)
|
||||
|
||||
root_file = tmp_path / "root.txt"
|
||||
root_file.write_text("root", encoding="utf-8")
|
||||
|
||||
subdir = tmp_path / "sub"
|
||||
subdir.mkdir()
|
||||
nested_file = subdir / "nested.txt"
|
||||
nested_file.write_text("nested", encoding="utf-8")
|
||||
|
||||
loop_link = subdir / "loop"
|
||||
loop_link.symlink_to(tmp_path)
|
||||
|
||||
count = scanner._count_model_files()
|
||||
|
||||
assert count == 2
|
||||
|
||||
Reference in New Issue
Block a user