diff --git a/py/services/metadata_sync_service.py b/py/services/metadata_sync_service.py index 0a228079..5ce79e3d 100644 --- a/py/services/metadata_sync_service.py +++ b/py/services/metadata_sync_service.py @@ -167,41 +167,101 @@ class MetadataSyncService: metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" enable_archive = self._settings.get("enable_metadata_archive_db", False) + previous_source = model_data.get("metadata_source") or (model_data.get("civitai") or {}).get("source") try: + provider_attempts: list[tuple[Optional[str], MetadataProviderProtocol]] = [] + sqlite_attempted = False + if model_data.get("civitai_deleted") is True: - if not enable_archive or model_data.get("db_checked") is True: + if previous_source in (None, "civarchive"): + try: + provider_attempts.append(("civarchive_api", await self._get_provider("civarchive_api"))) + except Exception as exc: # pragma: no cover - provider resolution fault + logger.debug("Unable to resolve civarchive provider: %s", exc) + + if enable_archive and model_data.get("db_checked") is not True: + try: + provider_attempts.append(("sqlite", await self._get_provider("sqlite"))) + except Exception as exc: # pragma: no cover - provider resolution fault + logger.debug("Unable to resolve sqlite provider: %s", exc) + + if not provider_attempts: if not enable_archive: error_msg = "CivitAI model is deleted and metadata archive DB is not enabled" - else: + elif model_data.get("db_checked") is True: error_msg = "CivitAI model is deleted and not found in metadata archive DB" - return (False, error_msg) - metadata_provider = await self._get_provider("sqlite") + else: + error_msg = "CivitAI model is deleted and no archive provider is available" + return False, error_msg else: - metadata_provider = await self._get_default_provider() + provider_attempts.append((None, await self._get_default_provider())) - civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256) + civitai_metadata: Optional[Dict[str, Any]] = None + metadata_provider: Optional[MetadataProviderProtocol] = None + provider_used: Optional[str] = None + last_error: Optional[str] = None - if not civitai_metadata: - if error == "Model not found": + for provider_name, provider in provider_attempts: + try: + civitai_metadata_candidate, error = await provider.get_model_by_hash(sha256) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Provider %s failed for hash %s: %s", provider_name, sha256, exc) + civitai_metadata_candidate, error = None, str(exc) + + if provider_name == "sqlite": + sqlite_attempted = True + + if civitai_metadata_candidate: + civitai_metadata = civitai_metadata_candidate + metadata_provider = provider + provider_used = provider_name + break + + last_error = error or last_error + + if civitai_metadata is None or metadata_provider is None: + if sqlite_attempted: + model_data["db_checked"] = True + + if last_error == "Model not found": model_data["from_civitai"] = False model_data["civitai_deleted"] = True - model_data["db_checked"] = enable_archive + model_data["db_checked"] = sqlite_attempted or (enable_archive and model_data.get("db_checked", False)) model_data["last_checked_at"] = datetime.now().timestamp() data_to_save = model_data.copy() data_to_save.pop("folder", None) await self._metadata_manager.save_metadata(file_path, data_to_save) + default_error = ( + "CivitAI model is deleted and metadata archive DB is not enabled" + if model_data.get("civitai_deleted") and not enable_archive + else "CivitAI model is deleted and not found in metadata archive DB" + if model_data.get("civitai_deleted") and (model_data.get("db_checked") is True or sqlite_attempted) + else "No provider returned metadata" + ) + error_msg = ( - f"Error fetching metadata: {error} (model_name={model_data.get('model_name', '')})" + f"Error fetching metadata: {last_error or default_error} " + f"(model_name={model_data.get('model_name', '')})" ) logger.error(error_msg) return False, error_msg model_data["from_civitai"] = True model_data["civitai_deleted"] = civitai_metadata.get("source") == "archive_db" or civitai_metadata.get("source") == "civarchive" - model_data["db_checked"] = enable_archive and civitai_metadata.get("source") == "archive_db" + model_data["db_checked"] = enable_archive and ( + civitai_metadata.get("source") == "archive_db" or sqlite_attempted + ) + source = civitai_metadata.get("source") or "civitai_api" + if source == "api": + source = "civitai_api" + elif provider_used == "civarchive_api" and source != "civarchive": + source = "civarchive" + elif provider_used == "sqlite": + source = "archive_db" + model_data["metadata_source"] = source model_data["last_checked_at"] = datetime.now().timestamp() local_metadata = model_data.copy() diff --git a/py/utils/models.py b/py/utils/models.py index 159146d5..4caffa3e 100644 --- a/py/utils/models.py +++ b/py/utils/models.py @@ -25,6 +25,7 @@ class BaseModelMetadata: favorite: bool = False # Whether the model is a favorite exclude: bool = False # Whether to exclude this model from the cache db_checked: bool = False # Whether checked in archive DB + metadata_source: Optional[str] = None # Last provider that supplied metadata last_checked_at: float = 0 # Last checked timestamp _unknown_fields: Dict[str, Any] = field(default_factory=dict, repr=False, compare=False) # Store unknown fields diff --git a/tests/services/test_metadata_sync_service.py b/tests/services/test_metadata_sync_service.py index 470259f6..cd3ade58 100644 --- a/tests/services/test_metadata_sync_service.py +++ b/tests/services/test_metadata_sync_service.py @@ -32,6 +32,8 @@ def build_service( get_model_by_hash=AsyncMock(), get_model_version=AsyncMock(), ) + if default_provider is None: + provider.get_model_by_hash.return_value = (None, None) default_provider_factory = AsyncMock(return_value=provider) provider_selector = provider_selector or AsyncMock(return_value=provider) @@ -138,6 +140,7 @@ async def test_fetch_and_update_model_success_updates_cache(tmp_path): assert model_data["from_civitai"] is True assert model_data["civitai_deleted"] is False assert "civitai" in model_data + assert model_data["metadata_source"] == "civitai_api" helpers.metadata_manager.hydrate_model_data.assert_not_awaited() assert model_data["hydrated"] is True @@ -219,6 +222,124 @@ async def test_fetch_and_update_model_respects_deleted_without_archive(): update_cache.assert_not_awaited() +@pytest.mark.asyncio +async def test_fetch_and_update_model_prefers_civarchive_for_deleted_models(tmp_path): + default_provider = SimpleNamespace( + get_model_by_hash=AsyncMock(), + get_model_version=AsyncMock(), + ) + civarchive_provider = SimpleNamespace( + get_model_by_hash=AsyncMock( + return_value=( + { + "source": "civarchive", + "model": {"name": "Recovered", "description": "", "tags": []}, + "images": [], + "baseModel": "sdxl", + }, + None, + ) + ), + get_model_version=AsyncMock(), + ) + + async def select_provider(name: str): + return civarchive_provider if name == "civarchive_api" else default_provider + + provider_selector = AsyncMock(side_effect=select_provider) + helpers = build_service( + settings_values={"enable_metadata_archive_db": False}, + default_provider=default_provider, + provider_selector=provider_selector, + ) + + model_path = tmp_path / "model.safetensors" + model_data = { + "civitai_deleted": True, + "metadata_source": "civarchive", + "civitai": {"source": "civarchive"}, + "file_path": str(model_path), + } + update_cache = AsyncMock() + + ok, error = await helpers.service.fetch_and_update_model( + sha256="deadbeef", + file_path=str(model_path), + model_data=model_data, + update_cache_func=update_cache, + ) + + assert ok + assert error is None + provider_selector.assert_awaited_with("civarchive_api") + helpers.default_provider_factory.assert_not_awaited() + civarchive_provider.get_model_by_hash.assert_awaited_once_with("deadbeef") + update_cache.assert_awaited() + assert model_data["metadata_source"] == "civarchive" + helpers.metadata_manager.save_metadata.assert_awaited() + + +@pytest.mark.asyncio +async def test_fetch_and_update_model_falls_back_to_sqlite_after_civarchive_failure(tmp_path): + default_provider = SimpleNamespace( + get_model_by_hash=AsyncMock(), + get_model_version=AsyncMock(), + ) + civarchive_provider = SimpleNamespace( + get_model_by_hash=AsyncMock(return_value=(None, "Model not found")), + get_model_version=AsyncMock(), + ) + sqlite_payload = { + "source": "archive_db", + "model": {"name": "Recovered", "description": "", "tags": []}, + "images": [], + "baseModel": "sdxl", + } + sqlite_provider = SimpleNamespace( + get_model_by_hash=AsyncMock(return_value=(sqlite_payload, None)), + get_model_version=AsyncMock(), + ) + + async def select_provider(name: str): + if name == "civarchive_api": + return civarchive_provider + if name == "sqlite": + return sqlite_provider + return default_provider + + provider_selector = AsyncMock(side_effect=select_provider) + helpers = build_service( + settings_values={"enable_metadata_archive_db": True}, + default_provider=default_provider, + provider_selector=provider_selector, + ) + + model_path = tmp_path / "model.safetensors" + model_data = { + "civitai_deleted": True, + "db_checked": False, + "file_path": str(model_path), + } + update_cache = AsyncMock() + + ok, error = await helpers.service.fetch_and_update_model( + sha256="cafe", + file_path=str(model_path), + model_data=model_data, + update_cache_func=update_cache, + ) + + assert ok and error is None + assert civarchive_provider.get_model_by_hash.await_count == 1 + assert sqlite_provider.get_model_by_hash.await_count == 1 + assert model_data["metadata_source"] == "archive_db" + assert model_data["db_checked"] is True + assert provider_selector.await_args_list[0].args == ("civarchive_api",) + assert provider_selector.await_args_list[1].args == ("sqlite",) + update_cache.assert_awaited() + helpers.metadata_manager.save_metadata.assert_awaited() + + @pytest.mark.asyncio async def test_relink_metadata_fetches_version_and_updates_sha(tmp_path): provider = SimpleNamespace(