test(routes): cover snake case model id payload

This commit is contained in:
pixelpaws
2025-10-29 07:33:58 +08:00
parent 7770976513
commit de05b59f29
19 changed files with 484 additions and 36 deletions

View File

@@ -28,13 +28,22 @@ class DummyUpdateService:
self.records = records
self.calls = []
async def refresh_for_model_type(self, model_type, scanner, provider, *, force_refresh=False):
async def refresh_for_model_type(
self,
model_type,
scanner,
provider,
*,
force_refresh=False,
target_model_ids=None,
):
self.calls.append(
{
"model_type": model_type,
"scanner": scanner,
"provider": provider,
"force_refresh": force_refresh,
"target_model_ids": target_model_ids,
}
)
return self.records
@@ -152,3 +161,106 @@ async def test_refresh_model_updates_filters_records_without_updates():
assert call["scanner"] is service.scanner
assert call["force_refresh"] is False
assert call["provider"] is not None
assert call["target_model_ids"] is None
@pytest.mark.asyncio
async def test_refresh_model_updates_with_target_ids():
cache = SimpleNamespace(version_index={})
service = DummyService(cache)
record_with_update = ModelUpdateRecord(
model_type="lora",
model_id=1,
versions=[
ModelVersionRecord(
version_id=10,
name="v1",
base_model=None,
released_at=None,
size_bytes=None,
preview_url=None,
is_in_library=False,
should_ignore=False,
)
],
last_checked_at=None,
should_ignore_model=False,
)
update_service = DummyUpdateService({1: record_with_update})
async def metadata_selector(name):
assert name == "civitai_api"
return object()
handler = ModelUpdateHandler(
service=service,
update_service=update_service,
metadata_provider_selector=metadata_selector,
logger=logging.getLogger(__name__),
)
class DummyRequest:
can_read_body = True
query = {}
async def json(self):
return {"modelIds": [1, "2", None]}
response = await handler.refresh_model_updates(DummyRequest())
assert response.status == 200
call = update_service.calls[0]
assert call["target_model_ids"] == [1, 2]
@pytest.mark.asyncio
async def test_refresh_model_updates_accepts_snake_case_ids():
cache = SimpleNamespace(version_index={})
service = DummyService(cache)
record_with_update = ModelUpdateRecord(
model_type="lora",
model_id=3,
versions=[
ModelVersionRecord(
version_id=30,
name="v3",
base_model=None,
released_at=None,
size_bytes=None,
preview_url=None,
is_in_library=False,
should_ignore=False,
)
],
last_checked_at=None,
should_ignore_model=False,
)
update_service = DummyUpdateService({3: record_with_update})
async def metadata_selector(name):
assert name == "civitai_api"
return object()
handler = ModelUpdateHandler(
service=service,
update_service=update_service,
metadata_provider_selector=metadata_selector,
logger=logging.getLogger(__name__),
)
class DummyRequest:
can_read_body = True
query = {}
async def json(self):
return {"model_ids": [3, "4", "abc", None]}
response = await handler.refresh_model_updates(DummyRequest())
assert response.status == 200
call = update_service.calls[0]
assert call["target_model_ids"] == [3, 4]