mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
feat(metadata): batch refresh model versions
This commit is contained in:
@@ -2,7 +2,7 @@ import asyncio
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Dict, Tuple, List
|
||||
from typing import Optional, Dict, Tuple, List, Sequence
|
||||
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
||||
from .downloader import get_downloader
|
||||
from .errors import RateLimitError
|
||||
@@ -181,6 +181,59 @@ class CivitaiClient:
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model versions: {e}")
|
||||
return None
|
||||
|
||||
async def get_model_versions_bulk(
    self, model_ids: Sequence[int]
) -> Optional[Dict[int, Dict]]:
    """Fetch model metadata for multiple ids using the batch API.

    Returns a mapping of model id -> {'modelVersions', 'type', 'name'},
    ``{}`` when there is nothing to fetch or the response has no items,
    and ``None`` on request failure. RateLimitError propagates.
    """
    # Deduplicate while preserving first-seen order; drop anything that
    # cannot be coerced to an integer id.
    unique_ids: List[int] = []
    seen: set = set()
    for candidate in model_ids:
        try:
            as_int = int(candidate)
        except (TypeError, ValueError):
            continue
        if as_int not in seen:
            seen.add(as_int)
            unique_ids.append(as_int)

    if not unique_ids:
        return {}

    try:
        joined = ",".join(str(mid) for mid in unique_ids)
        success, result = await self._make_request(
            'GET',
            f"{self.base_url}/models",
            use_auth=True,
            params={'ids': joined},
        )
        if not success:
            return None

        if isinstance(result, dict):
            items = result.get('items')
        else:
            items = None
        if not isinstance(items, list):
            return {}

        payload: Dict[int, Dict] = {}
        for entry in items:
            if not isinstance(entry, dict):
                continue
            try:
                entry_id = int(entry.get('id'))
            except (TypeError, ValueError):
                continue
            payload[entry_id] = {
                'modelVersions': entry.get('modelVersions', []),
                'type': entry.get('type', ''),
                'name': entry.get('name', ''),
            }
        return payload
    except RateLimitError:
        raise
    except Exception as exc:
        logger.error(f"Error fetching model versions in bulk: {exc}")
        return None
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version with additional metadata."""
|
||||
|
||||
@@ -53,6 +53,12 @@ class ModelMetadataProvider(ABC):
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||
"""Get all versions of a model with their details"""
|
||||
pass
|
||||
|
||||
async def get_model_versions_bulk(self, model_ids: Sequence[int]) -> Optional[Dict[int, Dict]]:
    """Fetch model versions for multiple model ids when supported.

    Providers without a batch endpoint leave this unimplemented; callers
    should be prepared to catch NotImplementedError and fall back to
    per-model requests.
    """
    raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
@@ -80,6 +86,11 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
||||
|
||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
    """Look up every version of *model_id* via the wrapped Civitai client."""
    versions = await self.client.get_model_versions(model_id)
    return versions
|
||||
|
||||
async def get_model_versions_bulk(self, model_ids: Sequence[int]) -> Optional[Dict[int, Dict]]:
    """Batch-delegate version lookups to the wrapped Civitai client."""
    bulk_result = await self.client.get_model_versions_bulk(model_ids)
    return bulk_result
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
    """Fetch one specific model version through the wrapped client."""
    version = await self.client.get_model_version(model_id, version_id)
    return version
|
||||
@@ -544,7 +555,19 @@ class ModelMetadataProviderManager:
|
||||
"""Get model versions using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
return await provider.get_model_versions(model_id)
|
||||
|
||||
|
||||
async def get_model_versions_bulk(
    self,
    model_ids: Sequence[int],
    provider_name: str = None,
) -> Optional[Dict[int, Dict]]:
    """Resolve a provider and ask it for versions of several models at once.

    Returns ``None`` when the chosen provider has no batch support.
    """
    chosen = self._get_provider(provider_name)
    try:
        result = await chosen.get_model_versions_bulk(model_ids)
    except NotImplementedError:
        return None
    return result
|
||||
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None, provider_name: str = None) -> Optional[Dict]:
|
||||
"""Get specific model version using specified or default provider"""
|
||||
provider = self._get_provider(provider_name)
|
||||
|
||||
@@ -95,6 +95,28 @@ class ModelUpdateService:
|
||||
|
||||
local_versions = await self._collect_local_versions(scanner)
|
||||
results: Dict[int, ModelUpdateRecord] = {}
|
||||
prefetched: Dict[int, Mapping] = {}
|
||||
|
||||
fetch_targets: List[int] = []
|
||||
if metadata_provider and local_versions:
|
||||
now = time.time()
|
||||
async with self._lock:
|
||||
for model_id in local_versions.keys():
|
||||
existing = self._get_record(model_type, model_id)
|
||||
if existing and existing.should_ignore and not force_refresh:
|
||||
continue
|
||||
if force_refresh or not existing or self._is_stale(existing, now):
|
||||
fetch_targets.append(model_id)
|
||||
|
||||
if fetch_targets:
|
||||
try:
|
||||
prefetched = await self._fetch_model_versions_bulk(
|
||||
metadata_provider,
|
||||
fetch_targets,
|
||||
)
|
||||
except NotImplementedError:
|
||||
prefetched = {}
|
||||
|
||||
for model_id, version_ids in local_versions.items():
|
||||
record = await self._refresh_single_model(
|
||||
model_type,
|
||||
@@ -102,6 +124,7 @@ class ModelUpdateService:
|
||||
version_ids,
|
||||
metadata_provider,
|
||||
force_refresh=force_refresh,
|
||||
prefetched_response=prefetched.get(model_id),
|
||||
)
|
||||
if record:
|
||||
results[model_id] = record
|
||||
@@ -201,6 +224,7 @@ class ModelUpdateService:
|
||||
metadata_provider,
|
||||
*,
|
||||
force_refresh: bool = False,
|
||||
prefetched_response: Optional[Mapping] = None,
|
||||
) -> Optional[ModelUpdateRecord]:
|
||||
normalized_local = self._normalize_sequence(local_versions)
|
||||
now = time.time()
|
||||
@@ -223,25 +247,27 @@ class ModelUpdateService:
|
||||
# release lock during network request
|
||||
fetched_versions: List[int] | None = None
|
||||
refresh_succeeded = False
|
||||
response: Optional[Mapping] = None
|
||||
if metadata_provider and should_fetch:
|
||||
try:
|
||||
response = await metadata_provider.get_model_versions(model_id)
|
||||
except RateLimitError:
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - defensive log
|
||||
logger.error(
|
||||
"Failed to fetch versions for model %s (%s): %s",
|
||||
model_id,
|
||||
model_type,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
if response is not None:
|
||||
extracted = self._extract_version_ids(response)
|
||||
if extracted is not None:
|
||||
fetched_versions = extracted
|
||||
refresh_succeeded = True
|
||||
response = prefetched_response
|
||||
if response is None:
|
||||
try:
|
||||
response = await metadata_provider.get_model_versions(model_id)
|
||||
except RateLimitError:
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - defensive log
|
||||
logger.error(
|
||||
"Failed to fetch versions for model %s (%s): %s",
|
||||
model_id,
|
||||
model_type,
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
if response is not None:
|
||||
extracted = self._extract_version_ids(response)
|
||||
if extracted is not None:
|
||||
fetched_versions = extracted
|
||||
refresh_succeeded = True
|
||||
|
||||
async with self._lock:
|
||||
existing = self._get_record(model_type, model_id)
|
||||
@@ -280,6 +306,40 @@ class ModelUpdateService:
|
||||
self._upsert_record(record)
|
||||
return record
|
||||
|
||||
async def _fetch_model_versions_bulk(
|
||||
self,
|
||||
metadata_provider,
|
||||
model_ids: Sequence[int],
|
||||
) -> Dict[int, Mapping]:
|
||||
"""Fetch model metadata in batches of up to 100 ids."""
|
||||
|
||||
BATCH_SIZE = 100
|
||||
normalized = self._normalize_sequence(model_ids)
|
||||
if not normalized:
|
||||
return {}
|
||||
|
||||
aggregated: Dict[int, Mapping] = {}
|
||||
for index in range(0, len(normalized), BATCH_SIZE):
|
||||
chunk = normalized[index : index + BATCH_SIZE]
|
||||
try:
|
||||
response = await metadata_provider.get_model_versions_bulk(chunk)
|
||||
except RateLimitError:
|
||||
raise
|
||||
if response is None:
|
||||
continue
|
||||
if not isinstance(response, Mapping):
|
||||
logger.debug(
|
||||
"Unexpected bulk response type %s from provider %s", type(response), metadata_provider
|
||||
)
|
||||
continue
|
||||
for key, value in response.items():
|
||||
normalized_key = self._normalize_int(key)
|
||||
if normalized_key is None:
|
||||
continue
|
||||
if isinstance(value, Mapping):
|
||||
aggregated[normalized_key] = value
|
||||
return aggregated
|
||||
|
||||
async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]:
|
||||
cache = await scanner.get_cached_data()
|
||||
mapping: Dict[int, set[int]] = {}
|
||||
|
||||
Reference in New Issue
Block a user