diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 5912dc07..5c692be2 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -2,7 +2,7 @@ import asyncio import copy import logging import os -from typing import Optional, Dict, Tuple, List +from typing import Optional, Dict, Tuple, List, Sequence from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager from .downloader import get_downloader from .errors import RateLimitError @@ -181,6 +181,59 @@ class CivitaiClient: except Exception as e: logger.error(f"Error fetching model versions: {e}") return None + + async def get_model_versions_bulk( + self, model_ids: Sequence[int] + ) -> Optional[Dict[int, Dict]]: + """Fetch model metadata for multiple ids using the batch API.""" + + deduped: Dict[int, None] = {} + for raw_id in model_ids: + try: + normalized = int(raw_id) + except (TypeError, ValueError): + continue + deduped.setdefault(normalized, None) + + normalized_ids = [str(model_id) for model_id in deduped.keys()] + if not normalized_ids: + return {} + + try: + query = ",".join(normalized_ids) + success, result = await self._make_request( + 'GET', + f"{self.base_url}/models", + use_auth=True, + params={'ids': query}, + ) + if not success: + return None + + items = result.get('items') if isinstance(result, dict) else None + if not isinstance(items, list): + return {} + + payload: Dict[int, Dict] = {} + for item in items: + if not isinstance(item, dict): + continue + model_id = item.get('id') + try: + normalized_id = int(model_id) + except (TypeError, ValueError): + continue + payload[normalized_id] = { + 'modelVersions': item.get('modelVersions', []), + 'type': item.get('type', ''), + 'name': item.get('name', ''), + } + return payload + except RateLimitError: + raise + except Exception as exc: + logger.error(f"Error fetching model versions in bulk: {exc}") + return None async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: """Get specific model version with additional metadata.""" diff --git a/py/services/model_metadata_provider.py b/py/services/model_metadata_provider.py index 2c580c59..9cc06300 100644 --- a/py/services/model_metadata_provider.py +++ b/py/services/model_metadata_provider.py @@ -53,6 +53,12 @@ class ModelMetadataProvider(ABC): async def get_model_versions(self, model_id: str) -> Optional[Dict]: """Get all versions of a model with their details""" pass + + async def get_model_versions_bulk( + self, model_ids: Sequence[int] + ) -> Optional[Dict[int, Dict]]: + """Fetch model versions for multiple model ids when supported.""" + raise NotImplementedError @abstractmethod async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: @@ -80,6 +86,11 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider): async def get_model_versions(self, model_id: str) -> Optional[Dict]: return await self.client.get_model_versions(model_id) + + async def get_model_versions_bulk( + self, model_ids: Sequence[int] + ) -> Optional[Dict[int, Dict]]: + return await self.client.get_model_versions_bulk(model_ids) async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: return await self.client.get_model_version(model_id, version_id) @@ -544,7 +555,19 @@ class ModelMetadataProviderManager: """Get model versions using specified or default provider""" provider = self._get_provider(provider_name) return await provider.get_model_versions(model_id) - + + async def get_model_versions_bulk( + self, + model_ids: Sequence[int], + provider_name: str = None, + ) -> Optional[Dict[int, Dict]]: + """Fetch model versions for multiple model ids when supported by provider.""" + provider = self._get_provider(provider_name) + try: + return await provider.get_model_versions_bulk(model_ids) + except NotImplementedError: + return None + async def get_model_version(self, model_id: int = None, version_id: int = None, provider_name: str = None) -> Optional[Dict]: """Get specific model version using specified or default provider""" provider = self._get_provider(provider_name) diff --git a/py/services/model_update_service.py b/py/services/model_update_service.py index 3973ce85..4f89e6e9 100644 --- a/py/services/model_update_service.py +++ b/py/services/model_update_service.py @@ -95,6 +95,28 @@ class ModelUpdateService: local_versions = await self._collect_local_versions(scanner) results: Dict[int, ModelUpdateRecord] = {} + prefetched: Dict[int, Mapping] = {} + + fetch_targets: List[int] = [] + if metadata_provider and local_versions: + now = time.time() + async with self._lock: + for model_id in local_versions.keys(): + existing = self._get_record(model_type, model_id) + if existing and existing.should_ignore and not force_refresh: + continue + if force_refresh or not existing or self._is_stale(existing, now): + fetch_targets.append(model_id) + + if fetch_targets: + try: + prefetched = await self._fetch_model_versions_bulk( + metadata_provider, + fetch_targets, + ) + except NotImplementedError: + prefetched = {} + for model_id, version_ids in local_versions.items(): record = await self._refresh_single_model( model_type, @@ -102,6 +124,7 @@ class ModelUpdateService: version_ids, metadata_provider, force_refresh=force_refresh, + prefetched_response=prefetched.get(model_id), ) if record: results[model_id] = record @@ -201,6 +224,7 @@ class ModelUpdateService: metadata_provider, *, force_refresh: bool = False, + prefetched_response: Optional[Mapping] = None, ) -> Optional[ModelUpdateRecord]: normalized_local = self._normalize_sequence(local_versions) now = time.time() @@ -223,25 +247,27 @@ class ModelUpdateService: # release lock during network request fetched_versions: List[int] | None = None refresh_succeeded = False + response: Optional[Mapping] = None if metadata_provider and should_fetch: - try: - response = await metadata_provider.get_model_versions(model_id) - except RateLimitError: - raise - except Exception as exc: # pragma: no cover - defensive log - logger.error( - "Failed to fetch versions for model %s (%s): %s", - model_id, - model_type, - exc, - exc_info=True, - ) - else: - if response is not None: - extracted = self._extract_version_ids(response) - if extracted is not None: - fetched_versions = extracted - refresh_succeeded = True + response = prefetched_response + if response is None: + try: + response = await metadata_provider.get_model_versions(model_id) + except RateLimitError: + raise + except Exception as exc: # pragma: no cover - defensive log + logger.error( + "Failed to fetch versions for model %s (%s): %s", + model_id, + model_type, + exc, + exc_info=True, + ) + if response is not None: + extracted = self._extract_version_ids(response) + if extracted is not None: + fetched_versions = extracted + refresh_succeeded = True async with self._lock: existing = self._get_record(model_type, model_id) @@ -280,6 +306,40 @@ class ModelUpdateService: self._upsert_record(record) return record + async def _fetch_model_versions_bulk( + self, + metadata_provider, + model_ids: Sequence[int], + ) -> Dict[int, Mapping]: + """Fetch model metadata in batches of up to 100 ids.""" + + BATCH_SIZE = 100 + normalized = self._normalize_sequence(model_ids) + if not normalized: + return {} + + aggregated: Dict[int, Mapping] = {} + for index in range(0, len(normalized), BATCH_SIZE): + chunk = normalized[index : index + BATCH_SIZE] + try: + response = await metadata_provider.get_model_versions_bulk(chunk) + except RateLimitError: + raise + if response is None: + continue + if not isinstance(response, Mapping): + logger.debug( + "Unexpected bulk response type %s from provider %s", type(response), metadata_provider + ) + continue + for key, value in response.items(): + normalized_key = self._normalize_int(key) + if normalized_key is None: + continue + if isinstance(value, Mapping): + aggregated[normalized_key] = value + return aggregated + async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]: cache = await scanner.get_cached_data() mapping: Dict[int, set[int]] = {} diff --git a/tests/services/test_civitai_client.py b/tests/services/test_civitai_client.py index 2421a07d..e657fe8f 100644 --- a/tests/services/test_civitai_client.py +++ b/tests/services/test_civitai_client.py @@ -23,8 +23,10 @@ class DummyDownloader: self.memory_calls.append({"url": url, "use_auth": use_auth}) return True, b"bytes", {"content-type": "image/jpeg"} - async def make_request(self, method, url, use_auth=True): - self.request_calls.append({"method": method, "url": url, "use_auth": use_auth}) + async def make_request(self, method, url, use_auth=True, **kwargs): + self.request_calls.append( + {"method": method, "url": url, "use_auth": use_auth, "kwargs": kwargs} + ) return True, {} @@ -72,7 +74,7 @@ async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader): } model_payload = {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}} - async def fake_make_request(method, url, use_auth=True): + async def fake_make_request(method, url, use_auth=True, **kwargs): if url.endswith("by-hash/hash"): return True, version_payload.copy() if url.endswith("/models/123"): @@ -94,7 +96,7 @@ async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader): async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader): - async def fake_make_request(method, url, use_auth=True): + async def fake_make_request(method, url, use_auth=True, **kwargs): return False, "not found" downloader.make_request = fake_make_request @@ -108,7 +110,7 @@ async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader): async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader): - async def fake_make_request(method, url, use_auth=True): + async def fake_make_request(method, url, use_auth=True, **kwargs): return False, RateLimitError("limited", retry_after=4) downloader.make_request = fake_make_request @@ -148,7 +150,7 @@ async def test_download_preview_image_failure(monkeypatch, downloader): async def test_get_model_versions_success(monkeypatch, downloader): - async def fake_make_request(method, url, use_auth=True): + async def fake_make_request(method, url, use_auth=True, **kwargs): return True, {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"} downloader.make_request = fake_make_request @@ -160,8 +162,44 @@ async def test_get_model_versions_success(monkeypatch, downloader): assert result == {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"} +async def test_get_model_versions_bulk_success(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True, **kwargs): + assert url.endswith("/models") + assert kwargs.get("params") == {"ids": "1,2"} + return True, { + "items": [ + {"id": 1, "modelVersions": [{"id": 11}], "type": "LORA", "name": "One"}, + {"id": 2, "modelVersions": [], "type": "Checkpoint", "name": "Two"}, + ] + } + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result = await client.get_model_versions_bulk([1, "2", 2]) + + assert result == { + 1: {"modelVersions": [{"id": 11}], "type": "LORA", "name": "One"}, + 2: {"modelVersions": [], "type": "Checkpoint", "name": "Two"}, + } + + +async def test_get_model_versions_bulk_handles_errors(monkeypatch, downloader): + async def fake_make_request(method, url, use_auth=True, **kwargs): + return False, "error" + + downloader.make_request = fake_make_request + + client = await CivitaiClient.get_instance() + + result = await client.get_model_versions_bulk([1, 2]) + + assert result is None + + async def test_get_model_version_by_version_id(monkeypatch, downloader): - async def fake_make_request(method, url, use_auth=True): + async def fake_make_request(method, url, use_auth=True, **kwargs): if url.endswith("/model-versions/7"): return True, { "modelId": 321, @@ -219,7 +257,7 @@ async def test_get_model_version_with_model_id_prefers_version_endpoint(monkeypa "images": [], } - async def fake_make_request(method, url, use_auth=True): + async def fake_make_request(method, url, use_auth=True, **kwargs): requests.append(url) if url.endswith("/models/99"): return True, copy.deepcopy(model_payload) @@ -273,7 +311,7 @@ async def test_get_model_version_with_model_id_fallbacks_to_hash(monkeypatch, do "images": [], } - async def fake_make_request(method, url, use_auth=True): + async def fake_make_request(method, url, use_auth=True, **kwargs): requests.append(url) if url.endswith("/models/99"): return True, copy.deepcopy(model_payload) @@ -315,7 +353,7 @@ async def test_get_model_version_with_model_id_builds_from_model_data(monkeypatc "poi": False, } - async def fake_make_request(method, url, use_auth=True): + async def fake_make_request(method, url, use_auth=True, **kwargs): if url.endswith("/models/99"): return True, copy.deepcopy(model_payload) if url.endswith("/model-versions/7"): @@ -345,7 +383,7 @@ async def test_get_model_version_requires_identifier(monkeypatch, downloader): async def test_get_model_version_info_handles_not_found(monkeypatch, downloader): - async def fake_make_request(method, url, use_auth=True): + async def fake_make_request(method, url, use_auth=True, **kwargs): return False, "not found" downloader.make_request = fake_make_request @@ -361,7 +399,7 @@ async def test_get_model_version_info_handles_not_found(monkeypatch, downloader) async def test_get_model_version_info_success(monkeypatch, downloader): expected = {"id": 55, "images": [{"meta": {"comfy": {"foo": "bar"}, "other": "keep"}}]} - async def fake_make_request(method, url, use_auth=True): + async def fake_make_request(method, url, use_auth=True, **kwargs): return True, expected downloader.make_request = fake_make_request @@ -377,7 +415,7 @@ async def test_get_model_version_info_success(monkeypatch, downloader): async def test_get_image_info_returns_first_item(monkeypatch, downloader): - async def fake_make_request(method, url, use_auth=True): + async def fake_make_request(method, url, use_auth=True, **kwargs): return True, {"items": [{"id": 1}, {"id": 2}]} downloader.make_request = fake_make_request @@ -390,7 +428,7 @@ async def test_get_image_info_returns_first_item(monkeypatch, downloader): async def test_get_image_info_handles_missing(monkeypatch, downloader): - async def fake_make_request(method, url, use_auth=True): + async def fake_make_request(method, url, use_auth=True, **kwargs): return True, {"items": []} downloader.make_request = fake_make_request diff --git a/tests/services/test_model_update_service.py b/tests/services/test_model_update_service.py index 3cabcd38..b9e59e65 100644 --- a/tests/services/test_model_update_service.py +++ b/tests/services/test_model_update_service.py @@ -15,14 +15,22 @@ class DummyScanner: class DummyProvider: - def __init__(self, response): + def __init__(self, response, *, support_bulk: bool = True): self.response = response self.calls: int = 0 + self.bulk_calls: list[list[int]] = [] + self.support_bulk = support_bulk async def get_model_versions(self, model_id): self.calls += 1 return self.response + async def get_model_versions_bulk(self, model_ids): + if not self.support_bulk: + raise NotImplementedError + self.bulk_calls.append(list(model_ids)) + return {model_id: self.response for model_id in model_ids} + @pytest.mark.asyncio async def test_refresh_persists_versions_and_uses_cache(tmp_path): @@ -38,14 +46,16 @@ async def test_refresh_persists_versions_and_uses_cache(tmp_path): await service.refresh_for_model_type("lora", scanner, provider) record = await service.get_record("lora", 1) - assert provider.calls == 1 + assert provider.calls == 0 + assert provider.bulk_calls == [[1]] assert record is not None assert record.version_ids == [11, 15] assert record.in_library_version_ids == [11, 15] assert record.has_update() is False await service.refresh_for_model_type("lora", scanner, provider) - assert provider.calls == 1, "provider should not be called again within TTL" + assert provider.calls == 0, "provider should not be called again within TTL" + assert provider.bulk_calls == [[1]] @pytest.mark.asyncio @@ -60,8 +70,45 @@ async def test_refresh_respects_ignore_flag(tmp_path): await service.set_should_ignore("lora", 2, True) provider.calls = 0 + provider.bulk_calls = [] await service.refresh_for_model_type("lora", scanner, provider) assert provider.calls == 0 + assert provider.bulk_calls == [] + + +@pytest.mark.asyncio +async def test_refresh_falls_back_when_bulk_not_supported(tmp_path): + db_path = tmp_path / "updates.sqlite" + service = ModelUpdateService(str(db_path), ttl_seconds=3600) + raw_data = [{"civitai": {"modelId": 4, "id": 41}}] + scanner = DummyScanner(raw_data) + provider = DummyProvider({"modelVersions": [{"id": 41}]}, support_bulk=False) + + await service.refresh_for_model_type("lora", scanner, provider) + record = await service.get_record("lora", 4) + + assert record is not None + assert provider.calls == 1 + assert provider.bulk_calls == [] + + +@pytest.mark.asyncio +async def test_refresh_batches_large_collections(tmp_path): + db_path = tmp_path / "updates.sqlite" + service = ModelUpdateService(str(db_path), ttl_seconds=3600) + raw_data = [ + {"civitai": {"modelId": idx, "id": idx * 10}} + for idx in range(1, 151) + ] + scanner = DummyScanner(raw_data) + provider = DummyProvider({"modelVersions": []}) + + await service.refresh_for_model_type("lora", scanner, provider) + + # Expect two batches: 100 ids and remaining 50 ids + assert len(provider.bulk_calls) == 2 + assert len(provider.bulk_calls[0]) == 100 + assert len(provider.bulk_calls[1]) == 50 @pytest.mark.asyncio