test(routes): cover snake case model id payload

This commit is contained in:
pixelpaws
2025-10-29 07:33:58 +08:00
parent 7770976513
commit de05b59f29
19 changed files with 484 additions and 36 deletions

View File

@@ -28,13 +28,22 @@ class DummyUpdateService:
self.records = records
self.calls = []
async def refresh_for_model_type(self, model_type, scanner, provider, *, force_refresh=False):
async def refresh_for_model_type(
self,
model_type,
scanner,
provider,
*,
force_refresh=False,
target_model_ids=None,
):
self.calls.append(
{
"model_type": model_type,
"scanner": scanner,
"provider": provider,
"force_refresh": force_refresh,
"target_model_ids": target_model_ids,
}
)
return self.records
@@ -152,3 +161,106 @@ async def test_refresh_model_updates_filters_records_without_updates():
assert call["scanner"] is service.scanner
assert call["force_refresh"] is False
assert call["provider"] is not None
assert call["target_model_ids"] is None
@pytest.mark.asyncio
async def test_refresh_model_updates_with_target_ids():
cache = SimpleNamespace(version_index={})
service = DummyService(cache)
record_with_update = ModelUpdateRecord(
model_type="lora",
model_id=1,
versions=[
ModelVersionRecord(
version_id=10,
name="v1",
base_model=None,
released_at=None,
size_bytes=None,
preview_url=None,
is_in_library=False,
should_ignore=False,
)
],
last_checked_at=None,
should_ignore_model=False,
)
update_service = DummyUpdateService({1: record_with_update})
async def metadata_selector(name):
assert name == "civitai_api"
return object()
handler = ModelUpdateHandler(
service=service,
update_service=update_service,
metadata_provider_selector=metadata_selector,
logger=logging.getLogger(__name__),
)
class DummyRequest:
can_read_body = True
query = {}
async def json(self):
return {"modelIds": [1, "2", None]}
response = await handler.refresh_model_updates(DummyRequest())
assert response.status == 200
call = update_service.calls[0]
assert call["target_model_ids"] == [1, 2]
@pytest.mark.asyncio
async def test_refresh_model_updates_accepts_snake_case_ids():
cache = SimpleNamespace(version_index={})
service = DummyService(cache)
record_with_update = ModelUpdateRecord(
model_type="lora",
model_id=3,
versions=[
ModelVersionRecord(
version_id=30,
name="v3",
base_model=None,
released_at=None,
size_bytes=None,
preview_url=None,
is_in_library=False,
should_ignore=False,
)
],
last_checked_at=None,
should_ignore_model=False,
)
update_service = DummyUpdateService({3: record_with_update})
async def metadata_selector(name):
assert name == "civitai_api"
return object()
handler = ModelUpdateHandler(
service=service,
update_service=update_service,
metadata_provider_selector=metadata_selector,
logger=logging.getLogger(__name__),
)
class DummyRequest:
can_read_body = True
query = {}
async def json(self):
return {"model_ids": [3, "4", "abc", None]}
response = await handler.refresh_model_updates(DummyRequest())
assert response.status == 200
call = update_service.calls[0]
assert call["target_model_ids"] == [3, 4]

View File

@@ -143,6 +143,47 @@ async def test_refresh_persists_versions_and_uses_cache(tmp_path):
assert provider.bulk_calls == [[1]]
@pytest.mark.asyncio
async def test_refresh_filters_to_requested_models(tmp_path):
db_path = tmp_path / "updates.sqlite"
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
raw_data = [
{"civitai": {"modelId": 1, "id": 11}},
{"civitai": {"modelId": 2, "id": 21}},
]
scanner = DummyScanner(raw_data)
provider = DummyProvider({"modelVersions": []})
result = await service.refresh_for_model_type(
"lora",
scanner,
provider,
target_model_ids=[2],
)
assert list(result.keys()) == [2]
assert provider.bulk_calls == [[2]]
@pytest.mark.asyncio
async def test_refresh_returns_empty_when_targets_missing(tmp_path):
db_path = tmp_path / "updates.sqlite"
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
raw_data = [{"civitai": {"modelId": 1, "id": 11}}]
scanner = DummyScanner(raw_data)
provider = DummyProvider({"modelVersions": []})
result = await service.refresh_for_model_type(
"lora",
scanner,
provider,
target_model_ids=[5],
)
assert result == {}
assert provider.bulk_calls == []
@pytest.mark.asyncio
async def test_refresh_respects_ignore_flag(tmp_path):
db_path = tmp_path / "updates.sqlite"