feat(metadata): batch refresh model versions

This commit is contained in:
pixelpaws
2025-10-15 20:47:30 +08:00
parent 0968698804
commit 21a1bc1a01
5 changed files with 258 additions and 37 deletions

View File

@@ -23,8 +23,10 @@ class DummyDownloader:
self.memory_calls.append({"url": url, "use_auth": use_auth})
return True, b"bytes", {"content-type": "image/jpeg"}
async def make_request(self, method, url, use_auth=True):
self.request_calls.append({"method": method, "url": url, "use_auth": use_auth})
async def make_request(self, method, url, use_auth=True, **kwargs):
self.request_calls.append(
{"method": method, "url": url, "use_auth": use_auth, "kwargs": kwargs}
)
return True, {}
@@ -72,7 +74,7 @@ async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader):
}
model_payload = {"description": "desc", "tags": ["tag"], "creator": {"username": "user"}}
async def fake_make_request(method, url, use_auth=True):
async def fake_make_request(method, url, use_auth=True, **kwargs):
if url.endswith("by-hash/hash"):
return True, version_payload.copy()
if url.endswith("/models/123"):
@@ -94,7 +96,7 @@ async def test_get_model_by_hash_enriches_metadata(monkeypatch, downloader):
async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True):
async def fake_make_request(method, url, use_auth=True, **kwargs):
return False, "not found"
downloader.make_request = fake_make_request
@@ -108,7 +110,7 @@ async def test_get_model_by_hash_handles_not_found(monkeypatch, downloader):
async def test_get_model_by_hash_propagates_rate_limit(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True):
async def fake_make_request(method, url, use_auth=True, **kwargs):
return False, RateLimitError("limited", retry_after=4)
downloader.make_request = fake_make_request
@@ -148,7 +150,7 @@ async def test_download_preview_image_failure(monkeypatch, downloader):
async def test_get_model_versions_success(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True):
async def fake_make_request(method, url, use_auth=True, **kwargs):
return True, {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"}
downloader.make_request = fake_make_request
@@ -160,8 +162,44 @@ async def test_get_model_versions_success(monkeypatch, downloader):
assert result == {"modelVersions": [{"id": 1}], "type": "LORA", "name": "Model"}
async def test_get_model_versions_bulk_success(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True, **kwargs):
assert url.endswith("/models")
assert kwargs.get("params") == {"ids": "1,2"}
return True, {
"items": [
{"id": 1, "modelVersions": [{"id": 11}], "type": "LORA", "name": "One"},
{"id": 2, "modelVersions": [], "type": "Checkpoint", "name": "Two"},
]
}
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_model_versions_bulk([1, "2", 2])
assert result == {
1: {"modelVersions": [{"id": 11}], "type": "LORA", "name": "One"},
2: {"modelVersions": [], "type": "Checkpoint", "name": "Two"},
}
async def test_get_model_versions_bulk_handles_errors(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True, **kwargs):
return False, "error"
downloader.make_request = fake_make_request
client = await CivitaiClient.get_instance()
result = await client.get_model_versions_bulk([1, 2])
assert result is None
async def test_get_model_version_by_version_id(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True):
async def fake_make_request(method, url, use_auth=True, **kwargs):
if url.endswith("/model-versions/7"):
return True, {
"modelId": 321,
@@ -219,7 +257,7 @@ async def test_get_model_version_with_model_id_prefers_version_endpoint(monkeypa
"images": [],
}
async def fake_make_request(method, url, use_auth=True):
async def fake_make_request(method, url, use_auth=True, **kwargs):
requests.append(url)
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
@@ -273,7 +311,7 @@ async def test_get_model_version_with_model_id_fallbacks_to_hash(monkeypatch, do
"images": [],
}
async def fake_make_request(method, url, use_auth=True):
async def fake_make_request(method, url, use_auth=True, **kwargs):
requests.append(url)
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
@@ -315,7 +353,7 @@ async def test_get_model_version_with_model_id_builds_from_model_data(monkeypatc
"poi": False,
}
async def fake_make_request(method, url, use_auth=True):
async def fake_make_request(method, url, use_auth=True, **kwargs):
if url.endswith("/models/99"):
return True, copy.deepcopy(model_payload)
if url.endswith("/model-versions/7"):
@@ -345,7 +383,7 @@ async def test_get_model_version_requires_identifier(monkeypatch, downloader):
async def test_get_model_version_info_handles_not_found(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True):
async def fake_make_request(method, url, use_auth=True, **kwargs):
return False, "not found"
downloader.make_request = fake_make_request
@@ -361,7 +399,7 @@ async def test_get_model_version_info_handles_not_found(monkeypatch, downloader)
async def test_get_model_version_info_success(monkeypatch, downloader):
expected = {"id": 55, "images": [{"meta": {"comfy": {"foo": "bar"}, "other": "keep"}}]}
async def fake_make_request(method, url, use_auth=True):
async def fake_make_request(method, url, use_auth=True, **kwargs):
return True, expected
downloader.make_request = fake_make_request
@@ -377,7 +415,7 @@ async def test_get_model_version_info_success(monkeypatch, downloader):
async def test_get_image_info_returns_first_item(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True):
async def fake_make_request(method, url, use_auth=True, **kwargs):
return True, {"items": [{"id": 1}, {"id": 2}]}
downloader.make_request = fake_make_request
@@ -390,7 +428,7 @@ async def test_get_image_info_returns_first_item(monkeypatch, downloader):
async def test_get_image_info_handles_missing(monkeypatch, downloader):
async def fake_make_request(method, url, use_auth=True):
async def fake_make_request(method, url, use_auth=True, **kwargs):
return True, {"items": []}
downloader.make_request = fake_make_request

View File

@@ -15,14 +15,22 @@ class DummyScanner:
class DummyProvider:
def __init__(self, response):
def __init__(self, response, *, support_bulk: bool = True):
self.response = response
self.calls: int = 0
self.bulk_calls: list[list[int]] = []
self.support_bulk = support_bulk
async def get_model_versions(self, model_id):
self.calls += 1
return self.response
async def get_model_versions_bulk(self, model_ids):
if not self.support_bulk:
raise NotImplementedError
self.bulk_calls.append(list(model_ids))
return {model_id: self.response for model_id in model_ids}
@pytest.mark.asyncio
async def test_refresh_persists_versions_and_uses_cache(tmp_path):
@@ -38,14 +46,16 @@ async def test_refresh_persists_versions_and_uses_cache(tmp_path):
await service.refresh_for_model_type("lora", scanner, provider)
record = await service.get_record("lora", 1)
assert provider.calls == 1
assert provider.calls == 0
assert provider.bulk_calls == [[1]]
assert record is not None
assert record.version_ids == [11, 15]
assert record.in_library_version_ids == [11, 15]
assert record.has_update() is False
await service.refresh_for_model_type("lora", scanner, provider)
assert provider.calls == 1, "provider should not be called again within TTL"
assert provider.calls == 0, "provider should not be called again within TTL"
assert provider.bulk_calls == [[1]]
@pytest.mark.asyncio
@@ -60,8 +70,45 @@ async def test_refresh_respects_ignore_flag(tmp_path):
await service.set_should_ignore("lora", 2, True)
provider.calls = 0
provider.bulk_calls = []
await service.refresh_for_model_type("lora", scanner, provider)
assert provider.calls == 0
assert provider.bulk_calls == []
@pytest.mark.asyncio
async def test_refresh_falls_back_when_bulk_not_supported(tmp_path):
db_path = tmp_path / "updates.sqlite"
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
raw_data = [{"civitai": {"modelId": 4, "id": 41}}]
scanner = DummyScanner(raw_data)
provider = DummyProvider({"modelVersions": [{"id": 41}]}, support_bulk=False)
await service.refresh_for_model_type("lora", scanner, provider)
record = await service.get_record("lora", 4)
assert record is not None
assert provider.calls == 1
assert provider.bulk_calls == []
@pytest.mark.asyncio
async def test_refresh_batches_large_collections(tmp_path):
db_path = tmp_path / "updates.sqlite"
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
raw_data = [
{"civitai": {"modelId": idx, "id": idx * 10}}
for idx in range(1, 151)
]
scanner = DummyScanner(raw_data)
provider = DummyProvider({"modelVersions": []})
await service.refresh_for_model_type("lora", scanner, provider)
# Expect two batches: 100 ids and remaining 50 ids
assert len(provider.bulk_calls) == 2
assert len(provider.bulk_calls[0]) == 100
assert len(provider.bulk_calls[1]) == 50
@pytest.mark.asyncio