# Mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git
# Synced 2026-03-22 13:42:12 -03:00 (133 lines, 4.4 KiB, Python)
import asyncio
|
|
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
|
|
from py.services.model_update_service import ModelUpdateService
|
|
|
|
|
|
class DummyScanner:
    """Minimal scanner stand-in that always serves one prebuilt cache object.

    The cache is a ``SimpleNamespace`` whose ``raw_data`` attribute is the
    object handed to the constructor (stored by reference, not copied).
    """

    def __init__(self, raw_data):
        """Wrap *raw_data* in the cache object returned by every lookup."""
        self._cache = SimpleNamespace(raw_data=raw_data)

    async def get_cached_data(self, *args, **kwargs):
        """Return the prebuilt cache; any arguments are accepted and ignored."""
        return self._cache
|
|
|
|
|
|
class DummyProvider:
    """Fake metadata provider that replays one canned response and counts calls.

    Attributes:
        response: Object returned for every requested model id.
        calls: Number of ``get_model_versions`` invocations so far.
        bulk_calls: One list of ids per successful bulk invocation, in call order.
        support_bulk: When false, the bulk endpoint raises ``NotImplementedError``.
    """

    def __init__(self, response, *, support_bulk: bool = True):
        """Store the canned *response* and reset the call counters."""
        self.response = response
        self.calls: int = 0
        self.bulk_calls: list[list[int]] = []
        self.support_bulk = support_bulk

    async def get_model_versions(self, model_id):
        """Count the single-model call and return the canned response."""
        self.calls += 1
        return self.response

    async def get_model_versions_bulk(self, model_ids):
        """Record the requested ids and map each one to the canned response.

        Raises:
            NotImplementedError: If the provider was built with
                ``support_bulk=False``; nothing is recorded in that case.
        """
        if not self.support_bulk:
            raise NotImplementedError
        # Materialise once: a one-shot iterable (generator, iterator) would
        # otherwise be drained by the recording step and yield an empty result.
        requested = list(model_ids)
        self.bulk_calls.append(requested)
        return {model_id: self.response for model_id in requested}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_persists_versions_and_uses_cache(tmp_path):
    """A refresh stores the provider's versions; a second one makes no new calls.

    Two library entries share model id 1, so a single bulk request for ``[1]``
    is expected, and the stored record should list both version ids as known
    and as present in the library.
    """
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=3600)
    raw_data = [
        {"civitai": {"modelId": 1, "id": 11}},
        {"civitai": {"modelId": 1, "id": 15}},
    ]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider({"modelVersions": [{"id": 11}, {"id": 15}]})

    await service.refresh_for_model_type("lora", scanner, provider)
    record = await service.get_record("lora", 1)

    assert provider.calls == 0
    assert provider.bulk_calls == [[1]]
    assert record is not None
    assert record.version_ids == [11, 15]
    assert record.in_library_version_ids == [11, 15]
    assert record.has_update() is False

    # Second refresh with the same one-hour TTL: call counters must not move.
    await service.refresh_for_model_type("lora", scanner, provider)
    assert provider.calls == 0, "provider should not be called again within TTL"
    assert provider.bulk_calls == [[1]]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_respects_ignore_flag(tmp_path):
    """After a model is flagged as ignored, a refresh makes no provider calls."""
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=3600)
    raw_data = [{"civitai": {"modelId": 2, "id": 21}}]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider({"modelVersions": [{"id": 21}, {"id": 22}]})

    # Initial refresh, then flag the model as ignored.
    await service.refresh_for_model_type("lora", scanner, provider)
    await service.set_should_ignore("lora", 2, True)

    # Reset the counters so only the post-flag refresh is measured.
    provider.calls = 0
    provider.bulk_calls = []
    await service.refresh_for_model_type("lora", scanner, provider)
    assert provider.calls == 0
    assert provider.bulk_calls == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_falls_back_when_bulk_not_supported(tmp_path):
    """When the bulk endpoint raises, the per-model endpoint is used instead.

    The provider is built with ``support_bulk=False`` so its bulk method raises
    ``NotImplementedError``; the refresh must still produce a record via exactly
    one single-model call.
    """
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=3600)
    raw_data = [{"civitai": {"modelId": 4, "id": 41}}]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider({"modelVersions": [{"id": 41}]}, support_bulk=False)

    await service.refresh_for_model_type("lora", scanner, provider)
    record = await service.get_record("lora", 4)

    assert record is not None
    assert provider.calls == 1
    assert provider.bulk_calls == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_batches_large_collections(tmp_path):
    """Refreshing 150 distinct models issues bulk requests of 100 and 50 ids."""
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=3600)
    # 150 library entries, each with its own model id (1..150).
    raw_data = [
        {"civitai": {"modelId": idx, "id": idx * 10}}
        for idx in range(1, 151)
    ]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider({"modelVersions": []})

    await service.refresh_for_model_type("lora", scanner, provider)

    # Expect two batches: 100 ids and remaining 50 ids
    assert len(provider.bulk_calls) == 2
    assert len(provider.bulk_calls[0]) == 100
    assert len(provider.bulk_calls[1]) == 50
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_in_library_versions_changes_update_state(tmp_path):
    """Changing the in-library version list flips ``has_update()`` on the record.

    The provider reports versions 31 and 35. With only 31 in the library the
    record reports an update; once 35 is added as well it no longer does.
    """
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=1)
    raw_data = [{"civitai": {"modelId": 3, "id": 31}}]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider({"modelVersions": [{"id": 31}, {"id": 35}]})

    await service.refresh_for_model_type("lora", scanner, provider)
    await service.update_in_library_versions("lora", 3, [31])
    record = await service.get_record("lora", 3)

    assert record is not None
    assert record.has_update() is True

    await service.update_in_library_versions("lora", 3, [31, 35])
    record = await service.get_record("lora", 3)

    # Same guard as the first fetch, so a missing record fails with a clear
    # assertion instead of an AttributeError on ``None``.
    assert record is not None
    assert record.has_update() is False
|