feat(metadata): batch refresh model versions

This commit is contained in:
pixelpaws
2025-10-15 20:47:30 +08:00
parent 0968698804
commit 21a1bc1a01
5 changed files with 258 additions and 37 deletions

View File

@@ -95,6 +95,28 @@ class ModelUpdateService:
local_versions = await self._collect_local_versions(scanner)
results: Dict[int, ModelUpdateRecord] = {}
prefetched: Dict[int, Mapping] = {}
fetch_targets: List[int] = []
if metadata_provider and local_versions:
now = time.time()
async with self._lock:
for model_id in local_versions.keys():
existing = self._get_record(model_type, model_id)
if existing and existing.should_ignore and not force_refresh:
continue
if force_refresh or not existing or self._is_stale(existing, now):
fetch_targets.append(model_id)
if fetch_targets:
try:
prefetched = await self._fetch_model_versions_bulk(
metadata_provider,
fetch_targets,
)
except NotImplementedError:
prefetched = {}
for model_id, version_ids in local_versions.items():
record = await self._refresh_single_model(
model_type,
@@ -102,6 +124,7 @@ class ModelUpdateService:
version_ids,
metadata_provider,
force_refresh=force_refresh,
prefetched_response=prefetched.get(model_id),
)
if record:
results[model_id] = record
@@ -201,6 +224,7 @@ class ModelUpdateService:
metadata_provider,
*,
force_refresh: bool = False,
prefetched_response: Optional[Mapping] = None,
) -> Optional[ModelUpdateRecord]:
normalized_local = self._normalize_sequence(local_versions)
now = time.time()
@@ -223,25 +247,27 @@ class ModelUpdateService:
# release lock during network request
fetched_versions: List[int] | None = None
refresh_succeeded = False
response: Optional[Mapping] = None
if metadata_provider and should_fetch:
try:
response = await metadata_provider.get_model_versions(model_id)
except RateLimitError:
raise
except Exception as exc: # pragma: no cover - defensive log
logger.error(
"Failed to fetch versions for model %s (%s): %s",
model_id,
model_type,
exc,
exc_info=True,
)
else:
if response is not None:
extracted = self._extract_version_ids(response)
if extracted is not None:
fetched_versions = extracted
refresh_succeeded = True
response = prefetched_response
if response is None:
try:
response = await metadata_provider.get_model_versions(model_id)
except RateLimitError:
raise
except Exception as exc: # pragma: no cover - defensive log
logger.error(
"Failed to fetch versions for model %s (%s): %s",
model_id,
model_type,
exc,
exc_info=True,
)
if response is not None:
extracted = self._extract_version_ids(response)
if extracted is not None:
fetched_versions = extracted
refresh_succeeded = True
async with self._lock:
existing = self._get_record(model_type, model_id)
@@ -280,6 +306,40 @@ class ModelUpdateService:
self._upsert_record(record)
return record
async def _fetch_model_versions_bulk(
self,
metadata_provider,
model_ids: Sequence[int],
) -> Dict[int, Mapping]:
"""Fetch model metadata in batches of up to 100 ids."""
BATCH_SIZE = 100
normalized = self._normalize_sequence(model_ids)
if not normalized:
return {}
aggregated: Dict[int, Mapping] = {}
for index in range(0, len(normalized), BATCH_SIZE):
chunk = normalized[index : index + BATCH_SIZE]
try:
response = await metadata_provider.get_model_versions_bulk(chunk)
except RateLimitError:
raise
if response is None:
continue
if not isinstance(response, Mapping):
logger.debug(
"Unexpected bulk response type %s from provider %s", type(response), metadata_provider
)
continue
for key, value in response.items():
normalized_key = self._normalize_int(key)
if normalized_key is None:
continue
if isinstance(value, Mapping):
aggregated[normalized_key] = value
return aggregated
async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]:
cache = await scanner.get_cached_data()
mapping: Dict[int, set[int]] = {}