Merge pull request #573 from willmiao/codex/add-batch-model-version-retrieval

feat: batch model update refresh using Civitai bulk API
This commit is contained in:
pixelpaws
2025-10-15 20:55:53 +08:00
committed by GitHub
5 changed files with 258 additions and 37 deletions

View File

@@ -2,7 +2,7 @@ import asyncio
import copy import copy
import logging import logging
import os import os
from typing import Optional, Dict, Tuple, List from typing import Optional, Dict, Tuple, List, Sequence
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
from .downloader import get_downloader from .downloader import get_downloader
from .errors import RateLimitError from .errors import RateLimitError
@@ -182,6 +182,59 @@ class CivitaiClient:
logger.error(f"Error fetching model versions: {e}") logger.error(f"Error fetching model versions: {e}")
return None return None
async def get_model_versions_bulk(
    self, model_ids: Sequence[int]
) -> Optional[Dict[int, Dict]]:
    """Query the Civitai batch endpoint for several model ids at once.

    Ids are coerced to int and de-duplicated (first-seen order kept);
    non-numeric entries are silently dropped. Returns a mapping of
    model id -> summary dict, {} when there is nothing to fetch or the
    payload carries no usable items, and None on request failure or an
    unexpected error. RateLimitError always propagates to the caller.
    """
    # Ordered de-duplication via dict keys; int() failures skip the entry.
    unique_ids: Dict[int, None] = {}
    for candidate in model_ids:
        try:
            unique_ids.setdefault(int(candidate), None)
        except (TypeError, ValueError):
            continue
    if not unique_ids:
        return {}
    id_strings = [str(mid) for mid in unique_ids]
    try:
        success, result = await self._make_request(
            'GET',
            f"{self.base_url}/models",
            use_auth=True,
            params={'ids': ",".join(id_strings)},
        )
        if not success:
            return None
        items = result.get('items') if isinstance(result, dict) else None
        if not isinstance(items, list):
            return {}
        summaries: Dict[int, Dict] = {}
        for entry in items:
            # Skip malformed items and items whose id is not numeric.
            if not isinstance(entry, dict):
                continue
            try:
                entry_id = int(entry.get('id'))
            except (TypeError, ValueError):
                continue
            summaries[entry_id] = {
                'modelVersions': entry.get('modelVersions', []),
                'type': entry.get('type', ''),
                'name': entry.get('name', ''),
            }
        return summaries
    except RateLimitError:
        raise
    except Exception as exc:
        logger.error(f"Error fetching model versions in bulk: {exc}")
        return None
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
"""Get specific model version with additional metadata.""" """Get specific model version with additional metadata."""
try: try:

View File

@@ -54,6 +54,12 @@ class ModelMetadataProvider(ABC):
"""Get all versions of a model with their details""" """Get all versions of a model with their details"""
pass pass
async def get_model_versions_bulk(self, model_ids: Sequence[int]) -> Optional[Dict[int, Dict]]:
    """Batch variant of get_model_versions; providers without batch support leave this unimplemented."""
    raise NotImplementedError
@abstractmethod @abstractmethod
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
"""Get specific model version with additional metadata""" """Get specific model version with additional metadata"""
@@ -81,6 +87,11 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
async def get_model_versions(self, model_id: str) -> Optional[Dict]: async def get_model_versions(self, model_id: str) -> Optional[Dict]:
return await self.client.get_model_versions(model_id) return await self.client.get_model_versions(model_id)
async def get_model_versions_bulk(self, model_ids: Sequence[int]) -> Optional[Dict[int, Dict]]:
    """Forward the bulk version lookup straight to the underlying Civitai client."""
    return await self.client.get_model_versions_bulk(model_ids)
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
return await self.client.get_model_version(model_id, version_id) return await self.client.get_model_version(model_id, version_id)
@@ -545,6 +556,18 @@ class ModelMetadataProviderManager:
provider = self._get_provider(provider_name) provider = self._get_provider(provider_name)
return await provider.get_model_versions(model_id) return await provider.get_model_versions(model_id)
async def get_model_versions_bulk(
    self,
    model_ids: Sequence[int],
    provider_name: str = None,
) -> Optional[Dict[int, Dict]]:
    """Resolve a provider and run the batch lookup; None when the provider has no batch support."""
    selected = self._get_provider(provider_name)
    try:
        result = await selected.get_model_versions_bulk(model_ids)
    except NotImplementedError:
        # Provider opted out of the bulk API; callers fall back to per-model fetches.
        return None
    return result
async def get_model_version(self, model_id: int = None, version_id: int = None, provider_name: str = None) -> Optional[Dict]: async def get_model_version(self, model_id: int = None, version_id: int = None, provider_name: str = None) -> Optional[Dict]:
"""Get specific model version using specified or default provider""" """Get specific model version using specified or default provider"""
provider = self._get_provider(provider_name) provider = self._get_provider(provider_name)

View File

@@ -95,6 +95,28 @@ class ModelUpdateService:
local_versions = await self._collect_local_versions(scanner) local_versions = await self._collect_local_versions(scanner)
results: Dict[int, ModelUpdateRecord] = {} results: Dict[int, ModelUpdateRecord] = {}
prefetched: Dict[int, Mapping] = {}
fetch_targets: List[int] = []
if metadata_provider and local_versions:
now = time.time()
async with self._lock:
for model_id in local_versions.keys():
existing = self._get_record(model_type, model_id)
if existing and existing.should_ignore and not force_refresh:
continue
if force_refresh or not existing or self._is_stale(existing, now):
fetch_targets.append(model_id)
if fetch_targets:
try:
prefetched = await self._fetch_model_versions_bulk(
metadata_provider,
fetch_targets,
)
except NotImplementedError:
prefetched = {}
for model_id, version_ids in local_versions.items(): for model_id, version_ids in local_versions.items():
record = await self._refresh_single_model( record = await self._refresh_single_model(
model_type, model_type,
@@ -102,6 +124,7 @@ class ModelUpdateService:
version_ids, version_ids,
metadata_provider, metadata_provider,
force_refresh=force_refresh, force_refresh=force_refresh,
prefetched_response=prefetched.get(model_id),
) )
if record: if record:
results[model_id] = record results[model_id] = record
@@ -201,6 +224,7 @@ class ModelUpdateService:
metadata_provider, metadata_provider,
*, *,
force_refresh: bool = False, force_refresh: bool = False,
prefetched_response: Optional[Mapping] = None,
) -> Optional[ModelUpdateRecord]: ) -> Optional[ModelUpdateRecord]:
normalized_local = self._normalize_sequence(local_versions) normalized_local = self._normalize_sequence(local_versions)
now = time.time() now = time.time()
@@ -223,7 +247,10 @@ class ModelUpdateService:
# release lock during network request # release lock during network request
fetched_versions: List[int] | None = None fetched_versions: List[int] | None = None
refresh_succeeded = False refresh_succeeded = False
response: Optional[Mapping] = None
if metadata_provider and should_fetch: if metadata_provider and should_fetch:
response = prefetched_response
if response is None:
try: try:
response = await metadata_provider.get_model_versions(model_id) response = await metadata_provider.get_model_versions(model_id)
except RateLimitError: except RateLimitError:
@@ -236,7 +263,6 @@ class ModelUpdateService:
exc, exc,
exc_info=True, exc_info=True,
) )
else:
if response is not None: if response is not None:
extracted = self._extract_version_ids(response) extracted = self._extract_version_ids(response)
if extracted is not None: if extracted is not None:
@@ -280,6 +306,40 @@ class ModelUpdateService:
self._upsert_record(record) self._upsert_record(record)
return record return record
async def _fetch_model_versions_bulk(
self,
metadata_provider,
model_ids: Sequence[int],
) -> Dict[int, Mapping]:
"""Fetch model metadata in batches of up to 100 ids."""
BATCH_SIZE = 100
normalized = self._normalize_sequence(model_ids)
if not normalized:
return {}
aggregated: Dict[int, Mapping] = {}
for index in range(0, len(normalized), BATCH_SIZE):
chunk = normalized[index : index + BATCH_SIZE]
try:
response = await metadata_provider.get_model_versions_bulk(chunk)
except RateLimitError:
raise
if response is None:
continue
if not isinstance(response, Mapping):
logger.debug(
"Unexpected bulk response type %s from provider %s", type(response), metadata_provider
)
continue
for key, value in response.items():
normalized_key = self._normalize_int(key)
if normalized_key is None:
continue
if isinstance(value, Mapping):
aggregated[normalized_key] = value
return aggregated
async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]: async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]:
cache = await scanner.get_cached_data() cache = await scanner.get_cached_data()
mapping: Dict[int, set[int]] = {} mapping: Dict[int, set[int]] = {}

View File

@@ -23,8 +23,10 @@ class DummyDownloader:
self.memory_calls.append({"url": url, "use_auth": use_auth}) self.memory_calls.append({"url": url, "use_auth": use_auth})
return True, b"bytes", {"content-type": "image/jpeg"} return True, b"bytes", {"content-type": "image/jpeg"}
async def make_request(self, method, url, use_auth=True): async def make_request(self, method, url, use_auth=True, **kwargs):
self.request_calls.append({"method": method, "url": url, "use_auth": use_auth}) self.request_calls.append(
{"method": method, "url": url, "use_auth": use_auth, "kwargs": kwargs}
)
return True, {} return True, {}
@@ -72,7 +74,7 @@ async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader):
} }
model_payload = {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}} model_payload = {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}}
async def fake_make_request(method, url, use_auth=True): async def fake_make_request(method, url, use_auth=True, **kwargs):
if url.endswith("by-hash/hash"): if url.endswith("by-hash/hash"):
return True, version_payload.copy() return True, version_payload.copy()
if url.endswith("/models/123"): if url.endswith("/models/123"):
@@ -94,7 +96,7 @@ async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader):
async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader): async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True): async def fake_make_request(method, url, use_auth=True, **kwargs):
return False, "not found" return False, "not found"
downloader.make_request = fake_make_request downloader.make_request = fake_make_request
@@ -108,7 +110,7 @@ async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader):
async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader): async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True): async def fake_make_request(method, url, use_auth=True, **kwargs):
return False, RateLimitError("limited", retry_after=4) return False, RateLimitError("limited", retry_after=4)
downloader.make_request = fake_make_request downloader.make_request = fake_make_request
@@ -148,7 +150,7 @@ async def test_download_preview_image_failure(monkeypatch, downloader):
async def test_get_model_versions_success(monkeypatch, downloader): async def test_get_model_versions_success(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True): async def fake_make_request(method, url, use_auth=True, **kwargs):
return True, {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"} return True, {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"}
downloader.make_request = fake_make_request downloader.make_request = fake_make_request
@@ -160,8 +162,44 @@ async def test_get_model_versions_success(monkeypatch, downloader):
assert result == {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"} assert result == {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"}
async def test_get_model_versions_bulk_success(monkeypatch, downloader):
    async def fake_make_request(method, url, use_auth=True, **kwargs):
        # The bulk endpoint must receive deduplicated, normalized ids as one query param.
        assert url.endswith("/models")
        assert kwargs.get("params") == {"ids": "1,2"}
        return True, {
            "items": [
                {"id": 1, "modelVersions": [{"id": 11}], "type": "LORA", "name": "One"},
                {"id": 2, "modelVersions": [], "type": "Checkpoint", "name": "Two"},
            ]
        }

    downloader.make_request = fake_make_request
    client = await CivitaiClient.get_instance()

    expected = {
        1: {"modelVersions": [{"id": 11}], "type": "LORA", "name": "One"},
        2: {"modelVersions": [], "type": "Checkpoint", "name": "Two"},
    }
    assert await client.get_model_versions_bulk([1, "2", 2]) == expected
async def test_get_model_versions_bulk_handles_errors(monkeypatch, downloader):
    async def failing_make_request(method, url, use_auth=True, **kwargs):
        return False, "error"

    downloader.make_request = failing_make_request
    client = await CivitaiClient.get_instance()
    # A failed request maps to None (not {}), signalling "fetch failed" upstream.
    assert await client.get_model_versions_bulk([1, 2]) is None
async def test_get_model_version_by_version_id(monkeypatch, downloader): async def test_get_model_version_by_version_id(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True): async def fake_make_request(method, url, use_auth=True, **kwargs):
if url.endswith("/model-versions/7"): if url.endswith("/model-versions/7"):
return True, { return True, {
"modelId": 321, "modelId": 321,
@@ -219,7 +257,7 @@ async def test_get_model_version_with_model_id_prefers_version_endpoint(monkeypa
"images": [], "images": [],
} }
async def fake_make_request(method, url, use_auth=True): async def fake_make_request(method, url, use_auth=True, **kwargs):
requests.append(url) requests.append(url)
if url.endswith("/models/99"): if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload) return True, copy.deepcopy(model_payload)
@@ -273,7 +311,7 @@ async def test_get_model_version_with_model_id_fallbacks_to_hash(monkeypatch, do
"images": [], "images": [],
} }
async def fake_make_request(method, url, use_auth=True): async def fake_make_request(method, url, use_auth=True, **kwargs):
requests.append(url) requests.append(url)
if url.endswith("/models/99"): if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload) return True, copy.deepcopy(model_payload)
@@ -315,7 +353,7 @@ async def test_get_model_version_with_model_id_builds_from_model_data(monkeypatc
"poi": False, "poi": False,
} }
async def fake_make_request(method, url, use_auth=True): async def fake_make_request(method, url, use_auth=True, **kwargs):
if url.endswith("/models/99"): if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload) return True, copy.deepcopy(model_payload)
if url.endswith("/model-versions/7"): if url.endswith("/model-versions/7"):
@@ -345,7 +383,7 @@ async def test_get_model_version_requires_identifier(monkeypatch, downloader):
async def test_get_model_version_info_handles_not_found(monkeypatch, downloader): async def test_get_model_version_info_handles_not_found(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True): async def fake_make_request(method, url, use_auth=True, **kwargs):
return False, "not found" return False, "not found"
downloader.make_request = fake_make_request downloader.make_request = fake_make_request
@@ -361,7 +399,7 @@ async def test_get_model_version_info_handles_not_found(monkeypatch, downloader)
async def test_get_model_version_info_success(monkeypatch, downloader): async def test_get_model_version_info_success(monkeypatch, downloader):
expected = {"id": 55, "images": [{"meta": {"comfy": {"foo": "bar"}, "other": "keep"}}]} expected = {"id": 55, "images": [{"meta": {"comfy": {"foo": "bar"}, "other": "keep"}}]}
async def fake_make_request(method, url, use_auth=True): async def fake_make_request(method, url, use_auth=True, **kwargs):
return True, expected return True, expected
downloader.make_request = fake_make_request downloader.make_request = fake_make_request
@@ -377,7 +415,7 @@ async def test_get_model_version_info_success(monkeypatch, downloader):
async def test_get_image_info_returns_first_item(monkeypatch, downloader): async def test_get_image_info_returns_first_item(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True): async def fake_make_request(method, url, use_auth=True, **kwargs):
return True, {"items": [{"id": 1}, {"id": 2}]} return True, {"items": [{"id": 1}, {"id": 2}]}
downloader.make_request = fake_make_request downloader.make_request = fake_make_request
@@ -390,7 +428,7 @@ async def test_get_image_info_returns_first_item(monkeypatch, downloader):
async def test_get_image_info_handles_missing(monkeypatch, downloader): async def test_get_image_info_handles_missing(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True): async def fake_make_request(method, url, use_auth=True, **kwargs):
return True, {"items": []} return True, {"items": []}
downloader.make_request = fake_make_request downloader.make_request = fake_make_request

View File

@@ -15,14 +15,22 @@ class DummyScanner:
class DummyProvider: class DummyProvider:
def __init__(self, response): def __init__(self, response, *, support_bulk: bool = True):
self.response = response self.response = response
self.calls: int = 0 self.calls: int = 0
self.bulk_calls: list[list[int]] = []
self.support_bulk = support_bulk
async def get_model_versions(self, model_id): async def get_model_versions(self, model_id):
self.calls += 1 self.calls += 1
return self.response return self.response
async def get_model_versions_bulk(self, model_ids):
    """Record the requested batch and echo the canned response per id (or simulate no bulk support)."""
    if not self.support_bulk:
        raise NotImplementedError
    requested = list(model_ids)
    self.bulk_calls.append(requested)
    # Every id maps to the same canned response object, matching the real
    # provider contract of one entry per requested id.
    return dict.fromkeys(requested, self.response)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_refresh_persists_versions_and_uses_cache(tmp_path): async def test_refresh_persists_versions_and_uses_cache(tmp_path):
@@ -38,14 +46,16 @@ async def test_refresh_persists_versions_and_uses_cache(tmp_path):
await service.refresh_for_model_type("lora", scanner, provider) await service.refresh_for_model_type("lora", scanner, provider)
record = await service.get_record("lora", 1) record = await service.get_record("lora", 1)
assert provider.calls == 1 assert provider.calls == 0
assert provider.bulk_calls == [[1]]
assert record is not None assert record is not None
assert record.version_ids == [11, 15] assert record.version_ids == [11, 15]
assert record.in_library_version_ids == [11, 15] assert record.in_library_version_ids == [11, 15]
assert record.has_update() is False assert record.has_update() is False
await service.refresh_for_model_type("lora", scanner, provider) await service.refresh_for_model_type("lora", scanner, provider)
assert provider.calls == 1, "provider should not be called again within TTL" assert provider.calls == 0, "provider should not be called again within TTL"
assert provider.bulk_calls == [[1]]
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -60,8 +70,45 @@ async def test_refresh_respects_ignore_flag(tmp_path):
await service.set_should_ignore("lora", 2, True) await service.set_should_ignore("lora", 2, True)
provider.calls = 0 provider.calls = 0
provider.bulk_calls = []
await service.refresh_for_model_type("lora", scanner, provider) await service.refresh_for_model_type("lora", scanner, provider)
assert provider.calls == 0 assert provider.calls == 0
assert provider.bulk_calls == []
@pytest.mark.asyncio
async def test_refresh_falls_back_when_bulk_not_supported(tmp_path):
    service = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=3600)
    scanner = DummyScanner([{"civitai": {"modelId": 4, "id": 41}}])
    provider = DummyProvider({"modelVersions": [{"id": 41}]}, support_bulk=False)

    await service.refresh_for_model_type("lora", scanner, provider)

    # Bulk path unsupported -> per-model fallback used exactly once, no bulk calls.
    assert await service.get_record("lora", 4) is not None
    assert provider.calls == 1
    assert provider.bulk_calls == []
@pytest.mark.asyncio
async def test_refresh_batches_large_collections(tmp_path):
    service = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=3600)
    scanner = DummyScanner(
        [{"civitai": {"modelId": idx, "id": idx * 10}} for idx in range(1, 151)]
    )
    provider = DummyProvider({"modelVersions": []})

    await service.refresh_for_model_type("lora", scanner, provider)

    # 150 ids split into one full batch of 100 plus the 50 leftover ids.
    assert [len(batch) for batch in provider.bulk_calls] == [100, 50]
@pytest.mark.asyncio @pytest.mark.asyncio