mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 06:02:11 -03:00
test(routes): cover snake case model id payload
This commit is contained in:
@@ -28,13 +28,22 @@ class DummyUpdateService:
|
||||
self.records = records
|
||||
self.calls = []
|
||||
|
||||
async def refresh_for_model_type(self, model_type, scanner, provider, *, force_refresh=False):
|
||||
async def refresh_for_model_type(
|
||||
self,
|
||||
model_type,
|
||||
scanner,
|
||||
provider,
|
||||
*,
|
||||
force_refresh=False,
|
||||
target_model_ids=None,
|
||||
):
|
||||
self.calls.append(
|
||||
{
|
||||
"model_type": model_type,
|
||||
"scanner": scanner,
|
||||
"provider": provider,
|
||||
"force_refresh": force_refresh,
|
||||
"target_model_ids": target_model_ids,
|
||||
}
|
||||
)
|
||||
return self.records
|
||||
@@ -152,3 +161,106 @@ async def test_refresh_model_updates_filters_records_without_updates():
|
||||
assert call["scanner"] is service.scanner
|
||||
assert call["force_refresh"] is False
|
||||
assert call["provider"] is not None
|
||||
assert call["target_model_ids"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_model_updates_with_target_ids():
|
||||
cache = SimpleNamespace(version_index={})
|
||||
service = DummyService(cache)
|
||||
|
||||
record_with_update = ModelUpdateRecord(
|
||||
model_type="lora",
|
||||
model_id=1,
|
||||
versions=[
|
||||
ModelVersionRecord(
|
||||
version_id=10,
|
||||
name="v1",
|
||||
base_model=None,
|
||||
released_at=None,
|
||||
size_bytes=None,
|
||||
preview_url=None,
|
||||
is_in_library=False,
|
||||
should_ignore=False,
|
||||
)
|
||||
],
|
||||
last_checked_at=None,
|
||||
should_ignore_model=False,
|
||||
)
|
||||
|
||||
update_service = DummyUpdateService({1: record_with_update})
|
||||
|
||||
async def metadata_selector(name):
|
||||
assert name == "civitai_api"
|
||||
return object()
|
||||
|
||||
handler = ModelUpdateHandler(
|
||||
service=service,
|
||||
update_service=update_service,
|
||||
metadata_provider_selector=metadata_selector,
|
||||
logger=logging.getLogger(__name__),
|
||||
)
|
||||
|
||||
class DummyRequest:
|
||||
can_read_body = True
|
||||
query = {}
|
||||
|
||||
async def json(self):
|
||||
return {"modelIds": [1, "2", None]}
|
||||
|
||||
response = await handler.refresh_model_updates(DummyRequest())
|
||||
assert response.status == 200
|
||||
|
||||
call = update_service.calls[0]
|
||||
assert call["target_model_ids"] == [1, 2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_model_updates_accepts_snake_case_ids():
|
||||
cache = SimpleNamespace(version_index={})
|
||||
service = DummyService(cache)
|
||||
|
||||
record_with_update = ModelUpdateRecord(
|
||||
model_type="lora",
|
||||
model_id=3,
|
||||
versions=[
|
||||
ModelVersionRecord(
|
||||
version_id=30,
|
||||
name="v3",
|
||||
base_model=None,
|
||||
released_at=None,
|
||||
size_bytes=None,
|
||||
preview_url=None,
|
||||
is_in_library=False,
|
||||
should_ignore=False,
|
||||
)
|
||||
],
|
||||
last_checked_at=None,
|
||||
should_ignore_model=False,
|
||||
)
|
||||
|
||||
update_service = DummyUpdateService({3: record_with_update})
|
||||
|
||||
async def metadata_selector(name):
|
||||
assert name == "civitai_api"
|
||||
return object()
|
||||
|
||||
handler = ModelUpdateHandler(
|
||||
service=service,
|
||||
update_service=update_service,
|
||||
metadata_provider_selector=metadata_selector,
|
||||
logger=logging.getLogger(__name__),
|
||||
)
|
||||
|
||||
class DummyRequest:
|
||||
can_read_body = True
|
||||
query = {}
|
||||
|
||||
async def json(self):
|
||||
return {"model_ids": [3, "4", "abc", None]}
|
||||
|
||||
response = await handler.refresh_model_updates(DummyRequest())
|
||||
assert response.status == 200
|
||||
|
||||
call = update_service.calls[0]
|
||||
assert call["target_model_ids"] == [3, 4]
|
||||
|
||||
Reference in New Issue
Block a user