diff --git a/tests/services/test_metadata_sync_service.py b/tests/services/test_metadata_sync_service.py new file mode 100644 index 00000000..236f0ac4 --- /dev/null +++ b/tests/services/test_metadata_sync_service.py @@ -0,0 +1,227 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from py.services.metadata_sync_service import MetadataSyncService + + +class DummySettings: + def __init__(self, values: dict | None = None) -> None: + self._values = values or {} + + def get(self, key: str, default=None): + return self._values.get(key, default) + + +def build_service( + *, + settings_values: dict | None = None, + default_provider: SimpleNamespace | None = None, + provider_selector: AsyncMock | None = None, +): + metadata_manager = SimpleNamespace(save_metadata=AsyncMock()) + preview_service = SimpleNamespace(ensure_preview_for_metadata=AsyncMock()) + settings = DummySettings(settings_values) + + provider = default_provider or SimpleNamespace( + get_model_by_hash=AsyncMock(), + get_model_version=AsyncMock(), + ) + + default_provider_factory = AsyncMock(return_value=provider) + provider_selector = provider_selector or AsyncMock(return_value=provider) + + service = MetadataSyncService( + metadata_manager=metadata_manager, + preview_service=preview_service, + settings=settings, + default_metadata_provider_factory=default_provider_factory, + metadata_provider_selector=provider_selector, + ) + + return SimpleNamespace( + service=service, + metadata_manager=metadata_manager, + preview_service=preview_service, + default_provider=provider, + default_provider_factory=default_provider_factory, + provider_selector=provider_selector, + ) + + +@pytest.mark.asyncio +async def test_update_model_metadata_merges_and_persists(): + helpers = build_service() + + local = { + "civitai": {"trainedWords": ["alpha"], "creator": {"id": 1}}, + "modelDescription": "", + "tags": [], + "model_name": "Local", + } + remote = { + "source": "api", + "trainedWords": ["beta"], + "model": { + "name": "Remote Model", + "description": "desc", + "tags": ["style"], + "creator": {"id": 2}, + }, + "baseModel": "sdxl", + "images": ["img"], + } + + result = await helpers.service.update_model_metadata( + "path/to/model.metadata.json", + local, + remote, + helpers.default_provider, + ) + + assert set(result["civitai"]["trainedWords"]) == {"alpha", "beta"} + assert result["model_name"] == "Remote Model" + assert result["modelDescription"] == "desc" + assert result["tags"] == ["style"] + assert result["base_model"] == "SDXL 1.0" + + helpers.preview_service.ensure_preview_for_metadata.assert_awaited_once() + helpers.metadata_manager.save_metadata.assert_awaited_once_with( + "path/to/model.metadata.json", + result, + ) + + +@pytest.mark.asyncio +async def test_fetch_and_update_model_success_updates_cache(tmp_path): + helpers = build_service() + + civitai_payload = { + "source": "api", + "model": {"name": "Remote", "description": "", "tags": ["tag"]}, + "images": [], + "baseModel": "sdxl", + } + helpers.default_provider.get_model_by_hash.return_value = (civitai_payload, None) + + model_data = {"model_name": "Local", "folder": "root"} + update_cache = AsyncMock(return_value=True) + + ok, error = await helpers.service.fetch_and_update_model( + sha256="abc", + file_path=str(tmp_path / "model.safetensors"), + model_data=model_data, + update_cache_func=update_cache, + ) + + assert ok and error is None + assert model_data["from_civitai"] is True + assert model_data["civitai_deleted"] is False + assert "civitai" in model_data + + metadata_path = str(tmp_path / "model.metadata.json") + await_args = helpers.metadata_manager.save_metadata.await_args_list + assert await_args, "expected metadata to be persisted" + assert await_args[-1][0][0] == metadata_path + update_cache.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_fetch_and_update_model_handles_missing_remote_metadata(tmp_path): + helpers = build_service() + helpers.default_provider.get_model_by_hash.return_value = (None, "Model not found") + + model_data = { + "model_name": "Local", + "folder": "sub", + } + + ok, error = await helpers.service.fetch_and_update_model( + sha256="missing", + file_path=str(tmp_path / "model.safetensors"), + model_data=model_data, + update_cache_func=AsyncMock(), + ) + + assert not ok + assert "Model not found" in error + assert model_data["from_civitai"] is False + assert model_data["civitai_deleted"] is True + + helpers.metadata_manager.save_metadata.assert_awaited_once() + args, _ = helpers.metadata_manager.save_metadata.await_args + assert args[0].endswith("model.safetensors") + assert "folder" not in args[1] + + +@pytest.mark.asyncio +async def test_fetch_and_update_model_respects_deleted_without_archive(): + helpers = build_service(settings_values={"enable_metadata_archive_db": False}) + + model_data = {"civitai_deleted": True} + update_cache = AsyncMock() + + ok, error = await helpers.service.fetch_and_update_model( + sha256="abc", + file_path="/tmp/model.safetensors", + model_data=model_data, + update_cache_func=update_cache, + ) + + assert not ok + assert "metadata archive DB is not enabled" in error + helpers.default_provider_factory.assert_not_awaited() + update_cache.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_relink_metadata_fetches_version_and_updates_sha(tmp_path): + provider = SimpleNamespace( + get_model_by_hash=AsyncMock(), + get_model_version=AsyncMock( + return_value={ + "files": [ + { + "primary": True, + "type": "Model", + "hashes": {"SHA256": "ABCDEF"}, + } + ], + "model": {"name": "Remote"}, + "images": [], + } + ), + ) + + helpers = build_service(default_provider=provider) + + metadata = {"model_name": "Local"} + result = await helpers.service.relink_metadata( + file_path=str(tmp_path / "model.safetensors"), + metadata=metadata, + model_id=1, + model_version_id=2, + ) + + assert result["sha256"] == "abcdef" + helpers.metadata_manager.save_metadata.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_relink_metadata_raises_when_version_missing(): + provider = SimpleNamespace( + get_model_by_hash=AsyncMock(), + get_model_version=AsyncMock(return_value=None), + ) + + helpers = build_service(default_provider=provider) + + with pytest.raises(ValueError): + await helpers.service.relink_metadata( + file_path="/tmp/model.safetensors", + metadata={}, + model_id=9, + model_version_id=None, + ) + diff --git a/tests/services/test_model_scanner.py b/tests/services/test_model_scanner.py index dba82c41..a077caeb 100644 --- a/tests/services/test_model_scanner.py +++ b/tests/services/test_model_scanner.py @@ -379,3 +379,50 @@ async def test_batch_delete_persists_removal(tmp_path: Path, monkeypatch): ).fetchone()[0] assert remaining == 0 + + +@pytest.mark.asyncio +async def test_reconcile_cache_adds_new_files_and_updates_hash_index(tmp_path: Path): + first, _, _ = _create_files(tmp_path) + scanner = DummyScanner(tmp_path) + + await scanner._initialize_cache() + await scanner.get_cached_data() + + new_file = tmp_path / "three.txt" + new_file.write_text("three", encoding="utf-8") + (tmp_path / "nested" / "two.txt").unlink() + + await scanner._reconcile_cache() + + cache = await scanner.get_cached_data() + cached_paths = {item["file_path"] for item in cache.raw_data} + + assert cached_paths == { + _normalize_path(first), + _normalize_path(new_file), + } + assert scanner._hash_index.get_path("hash-three") == _normalize_path(new_file) + assert scanner._hash_index.get_path("hash-two") is None + assert scanner._tags_count == {"alpha": 1, "beta": 1} + assert cache.folders == [""] + + +@pytest.mark.asyncio +async def test_count_model_files_handles_symlink_loops(tmp_path: Path): + scanner = DummyScanner(tmp_path) + + root_file = tmp_path / "root.txt" + root_file.write_text("root", encoding="utf-8") + + subdir = tmp_path / "sub" + subdir.mkdir() + nested_file = subdir / "nested.txt" + nested_file.write_text("nested", encoding="utf-8") + + loop_link = subdir / "loop" + loop_link.symlink_to(tmp_path) + + count = scanner._count_model_files() + + assert count == 2