mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
test(routes): cover snake case model id payload
This commit is contained in:
@@ -28,13 +28,22 @@ class DummyUpdateService:
|
||||
self.records = records
|
||||
self.calls = []
|
||||
|
||||
async def refresh_for_model_type(self, model_type, scanner, provider, *, force_refresh=False):
|
||||
async def refresh_for_model_type(
|
||||
self,
|
||||
model_type,
|
||||
scanner,
|
||||
provider,
|
||||
*,
|
||||
force_refresh=False,
|
||||
target_model_ids=None,
|
||||
):
|
||||
self.calls.append(
|
||||
{
|
||||
"model_type": model_type,
|
||||
"scanner": scanner,
|
||||
"provider": provider,
|
||||
"force_refresh": force_refresh,
|
||||
"target_model_ids": target_model_ids,
|
||||
}
|
||||
)
|
||||
return self.records
|
||||
@@ -152,3 +161,106 @@ async def test_refresh_model_updates_filters_records_without_updates():
|
||||
assert call["scanner"] is service.scanner
|
||||
assert call["force_refresh"] is False
|
||||
assert call["provider"] is not None
|
||||
assert call["target_model_ids"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_model_updates_with_target_ids():
|
||||
cache = SimpleNamespace(version_index={})
|
||||
service = DummyService(cache)
|
||||
|
||||
record_with_update = ModelUpdateRecord(
|
||||
model_type="lora",
|
||||
model_id=1,
|
||||
versions=[
|
||||
ModelVersionRecord(
|
||||
version_id=10,
|
||||
name="v1",
|
||||
base_model=None,
|
||||
released_at=None,
|
||||
size_bytes=None,
|
||||
preview_url=None,
|
||||
is_in_library=False,
|
||||
should_ignore=False,
|
||||
)
|
||||
],
|
||||
last_checked_at=None,
|
||||
should_ignore_model=False,
|
||||
)
|
||||
|
||||
update_service = DummyUpdateService({1: record_with_update})
|
||||
|
||||
async def metadata_selector(name):
|
||||
assert name == "civitai_api"
|
||||
return object()
|
||||
|
||||
handler = ModelUpdateHandler(
|
||||
service=service,
|
||||
update_service=update_service,
|
||||
metadata_provider_selector=metadata_selector,
|
||||
logger=logging.getLogger(__name__),
|
||||
)
|
||||
|
||||
class DummyRequest:
|
||||
can_read_body = True
|
||||
query = {}
|
||||
|
||||
async def json(self):
|
||||
return {"modelIds": [1, "2", None]}
|
||||
|
||||
response = await handler.refresh_model_updates(DummyRequest())
|
||||
assert response.status == 200
|
||||
|
||||
call = update_service.calls[0]
|
||||
assert call["target_model_ids"] == [1, 2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_model_updates_accepts_snake_case_ids():
|
||||
cache = SimpleNamespace(version_index={})
|
||||
service = DummyService(cache)
|
||||
|
||||
record_with_update = ModelUpdateRecord(
|
||||
model_type="lora",
|
||||
model_id=3,
|
||||
versions=[
|
||||
ModelVersionRecord(
|
||||
version_id=30,
|
||||
name="v3",
|
||||
base_model=None,
|
||||
released_at=None,
|
||||
size_bytes=None,
|
||||
preview_url=None,
|
||||
is_in_library=False,
|
||||
should_ignore=False,
|
||||
)
|
||||
],
|
||||
last_checked_at=None,
|
||||
should_ignore_model=False,
|
||||
)
|
||||
|
||||
update_service = DummyUpdateService({3: record_with_update})
|
||||
|
||||
async def metadata_selector(name):
|
||||
assert name == "civitai_api"
|
||||
return object()
|
||||
|
||||
handler = ModelUpdateHandler(
|
||||
service=service,
|
||||
update_service=update_service,
|
||||
metadata_provider_selector=metadata_selector,
|
||||
logger=logging.getLogger(__name__),
|
||||
)
|
||||
|
||||
class DummyRequest:
|
||||
can_read_body = True
|
||||
query = {}
|
||||
|
||||
async def json(self):
|
||||
return {"model_ids": [3, "4", "abc", None]}
|
||||
|
||||
response = await handler.refresh_model_updates(DummyRequest())
|
||||
assert response.status == 200
|
||||
|
||||
call = update_service.calls[0]
|
||||
assert call["target_model_ids"] == [3, 4]
|
||||
|
||||
@@ -143,6 +143,47 @@ async def test_refresh_persists_versions_and_uses_cache(tmp_path):
|
||||
assert provider.bulk_calls == [[1]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_filters_to_requested_models(tmp_path):
|
||||
db_path = tmp_path / "updates.sqlite"
|
||||
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
|
||||
raw_data = [
|
||||
{"civitai": {"modelId": 1, "id": 11}},
|
||||
{"civitai": {"modelId": 2, "id": 21}},
|
||||
]
|
||||
scanner = DummyScanner(raw_data)
|
||||
provider = DummyProvider({"modelVersions": []})
|
||||
|
||||
result = await service.refresh_for_model_type(
|
||||
"lora",
|
||||
scanner,
|
||||
provider,
|
||||
target_model_ids=[2],
|
||||
)
|
||||
|
||||
assert list(result.keys()) == [2]
|
||||
assert provider.bulk_calls == [[2]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_returns_empty_when_targets_missing(tmp_path):
|
||||
db_path = tmp_path / "updates.sqlite"
|
||||
service = ModelUpdateService(str(db_path), ttl_seconds=3600)
|
||||
raw_data = [{"civitai": {"modelId": 1, "id": 11}}]
|
||||
scanner = DummyScanner(raw_data)
|
||||
provider = DummyProvider({"modelVersions": []})
|
||||
|
||||
result = await service.refresh_for_model_type(
|
||||
"lora",
|
||||
scanner,
|
||||
provider,
|
||||
target_model_ids=[5],
|
||||
)
|
||||
|
||||
assert result == {}
|
||||
assert provider.bulk_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_respects_ignore_flag(tmp_path):
|
||||
db_path = tmp_path / "updates.sqlite"
|
||||
|
||||
Reference in New Issue
Block a user