mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
Merge pull request #573 from willmiao/codex/add-batch-model-version-retrieval
feat: batch model update refresh using Civitai bulk API
This commit is contained in:
@@ -2,7 +2,7 @@ import asyncio
|
|||||||
import copy
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Optional, Dict, Tuple, List
|
from typing import Optional, Dict, Tuple, List, Sequence
|
||||||
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
||||||
from .downloader import get_downloader
|
from .downloader import get_downloader
|
||||||
from .errors import RateLimitError
|
from .errors import RateLimitError
|
||||||
@@ -182,6 +182,59 @@ class CivitaiClient:
|
|||||||
logger.error(f"Error fetching model versions: {e}")
|
logger.error(f"Error fetching model versions: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_model_versions_bulk(
|
||||||
|
self, model_ids: Sequence[int]
|
||||||
|
) -> Optional[Dict[int, Dict]]:
|
||||||
|
"""Fetch model metadata for multiple ids using the batch API."""
|
||||||
|
|
||||||
|
deduped: Dict[int, None] = {}
|
||||||
|
for raw_id in model_ids:
|
||||||
|
try:
|
||||||
|
normalized = int(raw_id)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
deduped.setdefault(normalized, None)
|
||||||
|
|
||||||
|
normalized_ids = [str(model_id) for model_id in deduped.keys()]
|
||||||
|
if not normalized_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
query = ",".join(normalized_ids)
|
||||||
|
success, result = await self._make_request(
|
||||||
|
'GET',
|
||||||
|
f"{self.base_url}/models",
|
||||||
|
use_auth=True,
|
||||||
|
params={'ids': query},
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
|
return None
|
||||||
|
|
||||||
|
items = result.get('items') if isinstance(result, dict) else None
|
||||||
|
if not isinstance(items, list):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
payload: Dict[int, Dict] = {}
|
||||||
|
for item in items:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
model_id = item.get('id')
|
||||||
|
try:
|
||||||
|
normalized_id = int(model_id)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
payload[normalized_id] = {
|
||||||
|
'modelVersions': item.get('modelVersions', []),
|
||||||
|
'type': item.get('type', ''),
|
||||||
|
'name': item.get('name', ''),
|
||||||
|
}
|
||||||
|
return payload
|
||||||
|
except RateLimitError:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(f"Error fetching model versions in bulk: {exc}")
|
||||||
|
return None
|
||||||
|
|
||||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||||
"""Get specific model version with additional metadata."""
|
"""Get specific model version with additional metadata."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -54,6 +54,12 @@ class ModelMetadataProvider(ABC):
|
|||||||
"""Get all versions of a model with their details"""
|
"""Get all versions of a model with their details"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def get_model_versions_bulk(
|
||||||
|
self, model_ids: Sequence[int]
|
||||||
|
) -> Optional[Dict[int, Dict]]:
|
||||||
|
"""Fetch model versions for multiple model ids when supported."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||||
"""Get specific model version with additional metadata"""
|
"""Get specific model version with additional metadata"""
|
||||||
@@ -81,6 +87,11 @@ class CivitaiModelMetadataProvider(ModelMetadataProvider):
|
|||||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||||
return await self.client.get_model_versions(model_id)
|
return await self.client.get_model_versions(model_id)
|
||||||
|
|
||||||
|
async def get_model_versions_bulk(
|
||||||
|
self, model_ids: Sequence[int]
|
||||||
|
) -> Optional[Dict[int, Dict]]:
|
||||||
|
return await self.client.get_model_versions_bulk(model_ids)
|
||||||
|
|
||||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||||
return await self.client.get_model_version(model_id, version_id)
|
return await self.client.get_model_version(model_id, version_id)
|
||||||
|
|
||||||
@@ -545,6 +556,18 @@ class ModelMetadataProviderManager:
|
|||||||
provider = self._get_provider(provider_name)
|
provider = self._get_provider(provider_name)
|
||||||
return await provider.get_model_versions(model_id)
|
return await provider.get_model_versions(model_id)
|
||||||
|
|
||||||
|
async def get_model_versions_bulk(
|
||||||
|
self,
|
||||||
|
model_ids: Sequence[int],
|
||||||
|
provider_name: str = None,
|
||||||
|
) -> Optional[Dict[int, Dict]]:
|
||||||
|
"""Fetch model versions for multiple model ids when supported by provider."""
|
||||||
|
provider = self._get_provider(provider_name)
|
||||||
|
try:
|
||||||
|
return await provider.get_model_versions_bulk(model_ids)
|
||||||
|
except NotImplementedError:
|
||||||
|
return None
|
||||||
|
|
||||||
async def get_model_version(self, model_id: int = None, version_id: int = None, provider_name: str = None) -> Optional[Dict]:
|
async def get_model_version(self, model_id: int = None, version_id: int = None, provider_name: str = None) -> Optional[Dict]:
|
||||||
"""Get specific model version using specified or default provider"""
|
"""Get specific model version using specified or default provider"""
|
||||||
provider = self._get_provider(provider_name)
|
provider = self._get_provider(provider_name)
|
||||||
|
|||||||
@@ -95,6 +95,28 @@ class ModelUpdateService:
|
|||||||
|
|
||||||
local_versions = await self._collect_local_versions(scanner)
|
local_versions = await self._collect_local_versions(scanner)
|
||||||
results: Dict[int, ModelUpdateRecord] = {}
|
results: Dict[int, ModelUpdateRecord] = {}
|
||||||
|
prefetched: Dict[int, Mapping] = {}
|
||||||
|
|
||||||
|
fetch_targets: List[int] = []
|
||||||
|
if metadata_provider and local_versions:
|
||||||
|
now = time.time()
|
||||||
|
async with self._lock:
|
||||||
|
for model_id in local_versions.keys():
|
||||||
|
existing = self._get_record(model_type, model_id)
|
||||||
|
if existing and existing.should_ignore and not force_refresh:
|
||||||
|
continue
|
||||||
|
if force_refresh or not existing or self._is_stale(existing, now):
|
||||||
|
fetch_targets.append(model_id)
|
||||||
|
|
||||||
|
if fetch_targets:
|
||||||
|
try:
|
||||||
|
prefetched = await self._fetch_model_versions_bulk(
|
||||||
|
metadata_provider,
|
||||||
|
fetch_targets,
|
||||||
|
)
|
||||||
|
except NotImplementedError:
|
||||||
|
prefetched = {}
|
||||||
|
|
||||||
for model_id, version_ids in local_versions.items():
|
for model_id, version_ids in local_versions.items():
|
||||||
record = await self._refresh_single_model(
|
record = await self._refresh_single_model(
|
||||||
model_type,
|
model_type,
|
||||||
@@ -102,6 +124,7 @@ class ModelUpdateService:
|
|||||||
version_ids,
|
version_ids,
|
||||||
metadata_provider,
|
metadata_provider,
|
||||||
force_refresh=force_refresh,
|
force_refresh=force_refresh,
|
||||||
|
prefetched_response=prefetched.get(model_id),
|
||||||
)
|
)
|
||||||
if record:
|
if record:
|
||||||
results[model_id] = record
|
results[model_id] = record
|
||||||
@@ -201,6 +224,7 @@ class ModelUpdateService:
|
|||||||
metadata_provider,
|
metadata_provider,
|
||||||
*,
|
*,
|
||||||
force_refresh: bool = False,
|
force_refresh: bool = False,
|
||||||
|
prefetched_response: Optional[Mapping] = None,
|
||||||
) -> Optional[ModelUpdateRecord]:
|
) -> Optional[ModelUpdateRecord]:
|
||||||
normalized_local = self._normalize_sequence(local_versions)
|
normalized_local = self._normalize_sequence(local_versions)
|
||||||
now = time.time()
|
now = time.time()
|
||||||
@@ -223,25 +247,27 @@ class ModelUpdateService:
|
|||||||
# release lock during network request
|
# release lock during network request
|
||||||
fetched_versions: List[int] | None = None
|
fetched_versions: List[int] | None = None
|
||||||
refresh_succeeded = False
|
refresh_succeeded = False
|
||||||
|
response: Optional[Mapping] = None
|
||||||
if metadata_provider and should_fetch:
|
if metadata_provider and should_fetch:
|
||||||
try:
|
response = prefetched_response
|
||||||
response = await metadata_provider.get_model_versions(model_id)
|
if response is None:
|
||||||
except RateLimitError:
|
try:
|
||||||
raise
|
response = await metadata_provider.get_model_versions(model_id)
|
||||||
except Exception as exc: # pragma: no cover - defensive log
|
except RateLimitError:
|
||||||
logger.error(
|
raise
|
||||||
"Failed to fetch versions for model %s (%s): %s",
|
except Exception as exc: # pragma: no cover - defensive log
|
||||||
model_id,
|
logger.error(
|
||||||
model_type,
|
"Failed to fetch versions for model %s (%s): %s",
|
||||||
exc,
|
model_id,
|
||||||
exc_info=True,
|
model_type,
|
||||||
)
|
exc,
|
||||||
else:
|
exc_info=True,
|
||||||
if response is not None:
|
)
|
||||||
extracted = self._extract_version_ids(response)
|
if response is not None:
|
||||||
if extracted is not None:
|
extracted = self._extract_version_ids(response)
|
||||||
fetched_versions = extracted
|
if extracted is not None:
|
||||||
refresh_succeeded = True
|
fetched_versions = extracted
|
||||||
|
refresh_succeeded = True
|
||||||
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
existing = self._get_record(model_type, model_id)
|
existing = self._get_record(model_type, model_id)
|
||||||
@@ -280,6 +306,40 @@ class ModelUpdateService:
|
|||||||
self._upsert_record(record)
|
self._upsert_record(record)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
|
async def _fetch_model_versions_bulk(
|
||||||
|
self,
|
||||||
|
metadata_provider,
|
||||||
|
model_ids: Sequence[int],
|
||||||
|
) -> Dict[int, Mapping]:
|
||||||
|
"""Fetch model metadata in batches of up to 100 ids."""
|
||||||
|
|
||||||
|
BATCH_SIZE = 100
|
||||||
|
normalized = self._normalize_sequence(model_ids)
|
||||||
|
if not normalized:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
aggregated: Dict[int, Mapping] = {}
|
||||||
|
for index in range(0, len(normalized), BATCH_SIZE):
|
||||||
|
chunk = normalized[index : index + BATCH_SIZE]
|
||||||
|
try:
|
||||||
|
response = await metadata_provider.get_model_versions_bulk(chunk)
|
||||||
|
except RateLimitError:
|
||||||
|
raise
|
||||||
|
if response is None:
|
||||||
|
continue
|
||||||
|
if not isinstance(response, Mapping):
|
||||||
|
logger.debug(
|
||||||
|
"Unexpected bulk response type %s from provider %s", type(response), metadata_provider
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
for key, value in response.items():
|
||||||
|
normalized_key = self._normalize_int(key)
|
||||||
|
if normalized_key is None:
|
||||||
|
continue
|
||||||
|
if isinstance(value, Mapping):
|
||||||
|
aggregated[normalized_key] = value
|
||||||
|
return aggregated
|
||||||
|
|
||||||
async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]:
|
async def _collect_local_versions(self, scanner) -> Dict[int, List[int]]:
|
||||||
cache = await scanner.get_cached_data()
|
cache = await scanner.get_cached_data()
|
||||||
mapping: Dict[int, set[int]] = {}
|
mapping: Dict[int, set[int]] = {}
|
||||||
|
|||||||
@@ -23,8 +23,10 @@ class DummyDownloader:
|
|||||||
self.memory_calls.append({"url": url, "use_auth": use_auth})
|
self.memory_calls.append({"url": url, "use_auth": use_auth})
|
||||||
return True, b"bytes", {"content-type": "image/jpeg"}
|
return True, b"bytes", {"content-type": "image/jpeg"}
|
||||||
|
|
||||||
async def make_request(self, method, url, use_auth=True):
|
async def make_request(self, method, url, use_auth=True, **kwargs):
|
||||||
self.request_calls.append({"method": method, "url": url, "use_auth": use_auth})
|
self.request_calls.append(
|
||||||
|
{"method": method, "url": url, "use_auth": use_auth, "kwargs": kwargs}
|
||||||
|
)
|
||||||
return True, {}
|
return True, {}
|
||||||
|
|
||||||
|
|
||||||
@@ -72,7 +74,7 @@ async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader):
|
|||||||
}
|
}
|
||||||
model_payload = {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}}
|
model_payload = {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}}
|
||||||
|
|
||||||
async def fake_make_request(method, url, use_auth=True):
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
if url.endswith("by-hash/hash"):
|
if url.endswith("by-hash/hash"):
|
||||||
return True, version_payload.copy()
|
return True, version_payload.copy()
|
||||||
if url.endswith("/models/123"):
|
if url.endswith("/models/123"):
|
||||||
@@ -94,7 +96,7 @@ async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader):
|
|||||||
|
|
||||||
|
|
||||||
async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader):
|
async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader):
|
||||||
async def fake_make_request(method, url, use_auth=True):
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
return False, "not found"
|
return False, "not found"
|
||||||
|
|
||||||
downloader.make_request = fake_make_request
|
downloader.make_request = fake_make_request
|
||||||
@@ -108,7 +110,7 @@ async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader):
|
|||||||
|
|
||||||
|
|
||||||
async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader):
|
async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader):
|
||||||
async def fake_make_request(method, url, use_auth=True):
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
return False, RateLimitError("limited", retry_after=4)
|
return False, RateLimitError("limited", retry_after=4)
|
||||||
|
|
||||||
downloader.make_request = fake_make_request
|
downloader.make_request = fake_make_request
|
||||||
@@ -148,7 +150,7 @@ async def test_download_preview_image_failure(monkeypatch, downloader):
|
|||||||
|
|
||||||
|
|
||||||
async def test_get_model_versions_success(monkeypatch, downloader):
|
async def test_get_model_versions_success(monkeypatch, downloader):
|
||||||
async def fake_make_request(method, url, use_auth=True):
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
return True, {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"}
|
return True, {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"}
|
||||||
|
|
||||||
downloader.make_request = fake_make_request
|
downloader.make_request = fake_make_request
|
||||||
@@ -160,8 +162,44 @@ async def test_get_model_versions_success(monkeypatch, downloader):
|
|||||||
assert result == {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"}
|
assert result == {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_model_versions_bulk_success(monkeypatch, downloader):
|
||||||
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
|
assert url.endswith("/models")
|
||||||
|
assert kwargs.get("params") == {"ids": "1,2"}
|
||||||
|
return True, {
|
||||||
|
"items": [
|
||||||
|
{"id": 1, "modelVersions": [{"id": 11}], "type": "LORA", "name": "One"},
|
||||||
|
{"id": 2, "modelVersions": [], "type": "Checkpoint", "name": "Two"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
downloader.make_request = fake_make_request
|
||||||
|
|
||||||
|
client = await CivitaiClient.get_instance()
|
||||||
|
|
||||||
|
result = await client.get_model_versions_bulk([1, "2", 2])
|
||||||
|
|
||||||
|
assert result == {
|
||||||
|
1: {"modelVersions": [{"id": 11}], "type": "LORA", "name": "One"},
|
||||||
|
2: {"modelVersions": [], "type": "Checkpoint", "name": "Two"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_model_versions_bulk_handles_errors(monkeypatch, downloader):
|
||||||
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
|
return False, "error"
|
||||||
|
|
||||||
|
downloader.make_request = fake_make_request
|
||||||
|
|
||||||
|
client = await CivitaiClient.get_instance()
|
||||||
|
|
||||||
|
result = await client.get_model_versions_bulk([1, 2])
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
async def test_get_model_version_by_version_id(monkeypatch, downloader):
|
async def test_get_model_version_by_version_id(monkeypatch, downloader):
|
||||||
async def fake_make_request(method, url, use_auth=True):
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
if url.endswith("/model-versions/7"):
|
if url.endswith("/model-versions/7"):
|
||||||
return True, {
|
return True, {
|
||||||
"modelId": 321,
|
"modelId": 321,
|
||||||
@@ -219,7 +257,7 @@ async def test_get_model_version_with_model_id_prefers_version_endpoint(monkeypa
|
|||||||
"images": [],
|
"images": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def fake_make_request(method, url, use_auth=True):
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
requests.append(url)
|
requests.append(url)
|
||||||
if url.endswith("/models/99"):
|
if url.endswith("/models/99"):
|
||||||
return True, copy.deepcopy(model_payload)
|
return True, copy.deepcopy(model_payload)
|
||||||
@@ -273,7 +311,7 @@ async def test_get_model_version_with_model_id_fallbacks_to_hash(monkeypatch, do
|
|||||||
"images": [],
|
"images": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
async def fake_make_request(method, url, use_auth=True):
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
requests.append(url)
|
requests.append(url)
|
||||||
if url.endswith("/models/99"):
|
if url.endswith("/models/99"):
|
||||||
return True, copy.deepcopy(model_payload)
|
return True, copy.deepcopy(model_payload)
|
||||||
@@ -315,7 +353,7 @@ async def test_get_model_version_with_model_id_builds_from_model_data(monkeypatc
|
|||||||
"poi": False,
|
"poi": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def fake_make_request(method, url, use_auth=True):
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
if url.endswith("/models/99"):
|
if url.endswith("/models/99"):
|
||||||
return True, copy.deepcopy(model_payload)
|
return True, copy.deepcopy(model_payload)
|
||||||
if url.endswith("/model-versions/7"):
|
if url.endswith("/model-versions/7"):
|
||||||
@@ -345,7 +383,7 @@ async def test_get_model_version_requires_identifier(monkeypatch, downloader):
|
|||||||
|
|
||||||
|
|
||||||
async def test_get_model_version_info_handles_not_found(monkeypatch, downloader):
|
async def test_get_model_version_info_handles_not_found(monkeypatch, downloader):
|
||||||
async def fake_make_request(method, url, use_auth=True):
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
return False, "not found"
|
return False, "not found"
|
||||||
|
|
||||||
downloader.make_request = fake_make_request
|
downloader.make_request = fake_make_request
|
||||||
@@ -361,7 +399,7 @@ async def test_get_model_version_info_handles_not_found(monkeypatch, downloader)
|
|||||||
async def test_get_model_version_info_success(monkeypatch, downloader):
|
async def test_get_model_version_info_success(monkeypatch, downloader):
|
||||||
expected = {"id": 55, "images": [{"meta": {"comfy": {"foo": "bar"}, "other": "keep"}}]}
|
expected = {"id": 55, "images": [{"meta": {"comfy": {"foo": "bar"}, "other": "keep"}}]}
|
||||||
|
|
||||||
async def fake_make_request(method, url, use_auth=True):
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
return True, expected
|
return True, expected
|
||||||
|
|
||||||
downloader.make_request = fake_make_request
|
downloader.make_request = fake_make_request
|
||||||
@@ -377,7 +415,7 @@ async def test_get_model_version_info_success(monkeypatch, downloader):
|
|||||||
|
|
||||||
|
|
||||||
async def test_get_image_info_returns_first_item(monkeypatch, downloader):
|
async def test_get_image_info_returns_first_item(monkeypatch, downloader):
|
||||||
async def fake_make_request(method, url, use_auth=True):
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
return True, {"items": [{"id": 1}, {"id": 2}]}
|
return True, {"items": [{"id": 1}, {"id": 2}]}
|
||||||
|
|
||||||
downloader.make_request = fake_make_request
|
downloader.make_request = fake_make_request
|
||||||
@@ -390,7 +428,7 @@ async def test_get_image_info_returns_first_item(monkeypatch, downloader):
|
|||||||
|
|
||||||
|
|
||||||
async def test_get_image_info_handles_missing(monkeypatch, downloader):
|
async def test_get_image_info_handles_missing(monkeypatch, downloader):
|
||||||
async def fake_make_request(method, url, use_auth=True):
|
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||||
return True, {"items": []}
|
return True, {"items": []}
|
||||||
|
|
||||||
downloader.make_request = fake_make_request
|
downloader.make_request = fake_make_request
|
||||||
|
|||||||
@@ -15,14 +15,22 @@ class DummyScanner:
|
|||||||
|
|
||||||
|
|
||||||
class DummyProvider:
|
class DummyProvider:
|
||||||
def __init__(self, response):
|
def __init__(self, response, *, support_bulk: bool = True):
|
||||||
self.response = response
|
self.response = response
|
||||||
self.calls: int = 0
|
self.calls: int = 0
|
||||||
|
self.bulk_calls: list[list[int]] = []
|
||||||
|
self.support_bulk = support_bulk
|
||||||
|
|
||||||
async def get_model_versions(self, model_id):
|
async def get_model_versions(self, model_id):
|
||||||
self.calls += 1
|
self.calls += 1
|
||||||
return self.response
|
return self.response
|
||||||
|
|
||||||
|
async def get_model_versions_bulk(self, model_ids):
|
||||||
|
if not self.support_bulk:
|
||||||
|
raise NotImplementedError
|
||||||
|
self.bulk_calls.append(list(model_ids))
|
||||||
|
return {model_id: self.response for model_id in model_ids}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_refresh_persists_versions_and_uses_cache(tmp_path):
|
async def test_refresh_persists_versions_and_uses_cache(tmp_path):
|
||||||
@@ -38,14 +46,16 @@ async def test_refresh_persists_versions_and_uses_cache(tmp_path):
|
|||||||
await service.refresh_for_model_type("lora", scanner, provider)
|
await service.refresh_for_model_type("lora", scanner, provider)
|
||||||
record = await service.get_record("lora", 1)
|
record = await service.get_record("lora", 1)
|
||||||
|
|
||||||
assert provider.calls == 1
|
assert provider.calls == 0
|
||||||
|
assert provider.bulk_calls == [[1]]
|
||||||
assert record is not None
|
assert record is not None
|
||||||
assert record.version_ids == [11, 15]
|
assert record.version_ids == [11, 15]
|
||||||
assert record.in_library_version_ids == [11, 15]
|
assert record.in_library_version_ids == [11, 15]
|
||||||
assert record.has_update() is False
|
assert record.has_update() is False
|
||||||
|
|
||||||
await service.refresh_for_model_type("lora", scanner, provider)
|
await service.refresh_for_model_type("lora", scanner, provider)
|
||||||
assert provider.calls == 1, "provider should not be called again within TTL"
|
assert provider.calls == 0, "provider should not be called again within TTL"
|
||||||
|
assert provider.bulk_calls == [[1]]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -60,8 +70,45 @@ async def test_refresh_respects_ignore_flag(tmp_path):
|
|||||||
await service.set_should_ignore("lora", 2, True)
|
await service.set_should_ignore("lora", 2, True)
|
||||||
|
|
||||||
provider.calls = 0
|
provider.calls = 0
|
||||||
|
provider.bulk_calls = []
|
||||||
await service.refresh_for_model_type("lora", scanner, provider)
|
await service.refresh_for_model_type("lora", scanner, provider)
|
||||||
assert provider.calls == 0
|
assert provider.calls == 0
|
||||||
|
assert provider.bulk_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_falls_back_when_bulk_not_supported(tmp_path):
|
||||||
|
db_path = tmp_path / "updates.sqlite"
|
||||||
|
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
|
||||||
|
raw_data = [{"civitai": {"modelId": 4, "id": 41}}]
|
||||||
|
scanner = DummyScanner(raw_data)
|
||||||
|
provider = DummyProvider({"modelVersions": [{"id": 41}]}, support_bulk=False)
|
||||||
|
|
||||||
|
await service.refresh_for_model_type("lora", scanner, provider)
|
||||||
|
record = await service.get_record("lora", 4)
|
||||||
|
|
||||||
|
assert record is not None
|
||||||
|
assert provider.calls == 1
|
||||||
|
assert provider.bulk_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_refresh_batches_large_collections(tmp_path):
|
||||||
|
db_path = tmp_path / "updates.sqlite"
|
||||||
|
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
|
||||||
|
raw_data = [
|
||||||
|
{"civitai": {"modelId": idx, "id": idx * 10}}
|
||||||
|
for idx in range(1, 151)
|
||||||
|
]
|
||||||
|
scanner = DummyScanner(raw_data)
|
||||||
|
provider = DummyProvider({"modelVersions": []})
|
||||||
|
|
||||||
|
await service.refresh_for_model_type("lora", scanner, provider)
|
||||||
|
|
||||||
|
# Expect two batches: 100 ids and remaining 50 ids
|
||||||
|
assert len(provider.bulk_calls) == 2
|
||||||
|
assert len(provider.bulk_calls[0]) == 100
|
||||||
|
assert len(provider.bulk_calls[1]) == 50
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Reference in New Issue
Block a user