mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
- Add set_version_update_ignore endpoint to toggle ignore status for specific versions
- Add get_model_versions endpoint to retrieve version details with optional refresh
- Update serialization to include version-specific data and preview overrides
- Modify database schema to support version-level ignore tracking
- Improve error handling for rate limiting and missing models

These changes enable granular control over version updates and provide better visibility into model version status.
242 lines
7.8 KiB
Python
242 lines
7.8 KiB
Python
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
|
|
from py.services.model_update_service import ModelUpdateService
|
|
|
|
|
|
class DummyScanner:
    """Scanner stand-in that serves a fixed, pre-built cache object."""

    def __init__(self, raw_data):
        """Wrap *raw_data* in a cache namespace with an empty version index."""
        self._cache = SimpleNamespace(raw_data=raw_data, version_index={})

    async def get_cached_data(self, *args, **kwargs):
        """Return the stored cache; any arguments are accepted and ignored."""
        return self._cache
|
|
|
|
|
|
class DummyProvider:
    """Provider stand-in that returns one canned response and records calls.

    Attributes are inspected by the tests:
    ``calls`` counts single-model lookups and ``bulk_calls`` stores the id
    list passed to each bulk lookup.
    """

    def __init__(self, response, *, support_bulk: bool = True):
        """Store the canned *response* and whether bulk lookups are allowed."""
        self.response = response
        self.calls: int = 0
        self.bulk_calls: list[list[int]] = []
        self.support_bulk = support_bulk

    async def get_model_versions(self, model_id):
        """Count the call and return the canned response for any *model_id*."""
        self.calls += 1
        return self.response

    async def get_model_versions_bulk(self, model_ids):
        """Record the requested ids and map each one to the canned response.

        Raises:
            NotImplementedError: if the provider was built with
                ``support_bulk=False``.
        """
        if not self.support_bulk:
            raise NotImplementedError
        # Materialize once so a one-shot iterable (e.g. a generator) is not
        # exhausted by the bookkeeping before the result dict is built.
        requested_ids = list(model_ids)
        self.bulk_calls.append(requested_ids)
        return {model_id: self.response for model_id in requested_ids}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_persists_versions_and_uses_cache(tmp_path):
    """Refresh stores the remote versions and is not repeated within the TTL."""
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=3600)
    # Two local files that belong to the same model (modelId 1).
    raw_data = [
        {"civitai": {"modelId": 1, "id": 11}},
        {"civitai": {"modelId": 1, "id": 15}},
    ]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider(
        {
            "modelVersions": [
                {
                    "id": 11,
                    "name": "v1",
                    "baseModel": "SD15",
                    "publishedAt": "2024-01-01T00:00:00Z",
                    "files": [{"sizeKB": 1024}],
                    "images": [{"url": "https://example.com/1.png"}],
                },
                {
                    "id": 15,
                    "name": "v1.5",
                    "baseModel": "SD15",
                    "publishedAt": "2024-02-01T00:00:00Z",
                    "files": [{"sizeKB": 512}],
                    "images": [{"url": "https://example.com/2.png"}],
                },
            ]
        }
    )

    await service.refresh_for_model_type("lora", scanner, provider)
    record = await service.get_record("lora", 1)

    # Only the bulk endpoint should have been used, once, for model 1.
    assert provider.calls == 0
    assert provider.bulk_calls == [[1]]
    assert record is not None
    assert record.version_ids == [11, 15]
    assert record.in_library_version_ids == [11, 15]
    assert [version.name for version in record.versions] == ["v1", "v1.5"]
    assert record.should_ignore_model is False
    # Every remote version is already in the library, so nothing to update.
    assert record.has_update() is False

    # A second refresh must leave the provider call counters unchanged.
    await service.refresh_for_model_type("lora", scanner, provider)
    assert provider.calls == 0, "provider should not be called again within TTL"
    assert provider.bulk_calls == [[1]]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_respects_ignore_flag(tmp_path):
    """A refresh after ignoring a model makes no provider calls and keeps the flag."""
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=3600)
    raw_data = [{"civitai": {"modelId": 2, "id": 21}}]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider(
        {
            "modelVersions": [
                {"id": 21, "files": [], "images": []},
                {"id": 22, "files": [], "images": []},
            ]
        }
    )

    await service.refresh_for_model_type("lora", scanner, provider)
    await service.set_should_ignore("lora", 2, True)

    # Reset the counters so only the second refresh is measured.
    provider.calls = 0
    provider.bulk_calls = []
    await service.refresh_for_model_type("lora", scanner, provider)
    assert provider.calls == 0
    assert provider.bulk_calls == []
    record = await service.get_record("lora", 2)
    assert record is not None
    assert record.should_ignore_model is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_falls_back_when_bulk_not_supported(tmp_path):
    """When the bulk lookup raises NotImplementedError, one per-model call is made."""
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=3600)
    raw_data = [{"civitai": {"modelId": 4, "id": 41}}]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider(
        {"modelVersions": [{"id": 41, "files": [], "images": []}]},
        support_bulk=False,
    )

    await service.refresh_for_model_type("lora", scanner, provider)
    record = await service.get_record("lora", 4)

    assert record is not None
    # Exactly one single-model lookup, and no bulk call was recorded.
    assert provider.calls == 1
    assert provider.bulk_calls == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_batches_large_collections(tmp_path):
    """Refreshing 150 distinct models issues bulk lookups of 100 and 50 ids."""
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=3600)
    # 150 models with ids 1..150, each with a single local version.
    raw_data = [
        {"civitai": {"modelId": idx, "id": idx * 10}}
        for idx in range(1, 151)
    ]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider({"modelVersions": []})

    await service.refresh_for_model_type("lora", scanner, provider)

    # Expect two batches: 100 ids and remaining 50 ids
    assert len(provider.bulk_calls) == 2
    assert len(provider.bulk_calls[0]) == 100
    assert len(provider.bulk_calls[1]) == 50
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_in_library_versions_changes_update_state(tmp_path):
    """``has_update()`` follows the in-library version list set on the record."""
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=1)
    raw_data = [{"civitai": {"modelId": 3, "id": 31}}]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider(
        {
            "modelVersions": [
                {"id": 31, "files": [], "images": []},
                {"id": 35, "files": [], "images": []},
            ]
        }
    )

    await service.refresh_for_model_type("lora", scanner, provider)
    # Only version 31 is in the library, so remote version 35 is an update.
    await service.update_in_library_versions("lora", 3, [31])
    record = await service.get_record("lora", 3)

    assert record is not None
    assert record.has_update() is True

    # Once version 35 is also in the library, no update remains.
    await service.update_in_library_versions("lora", 3, [31, 35])
    record = await service.get_record("lora", 3)

    # Fail with a clear assertion rather than an AttributeError on None.
    assert record is not None
    assert record.has_update() is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_version_ignore_blocks_update_flag(tmp_path):
    """Ignoring the only version missing locally makes ``has_update()`` False."""
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=1)
    raw_data = [{"civitai": {"modelId": 5, "id": 51}}]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider(
        {
            "modelVersions": [
                {"id": 51, "files": [], "images": []},
                {"id": 55, "files": [], "images": []},
            ]
        }
    )

    await service.refresh_for_model_type("lora", scanner, provider)
    record = await service.get_record("lora", 5)
    assert record is not None
    # Version 55 is remote-only, so an update is reported at first.
    assert record.has_update() is True

    await service.set_version_should_ignore("lora", 5, 55, True)
    record = await service.get_record("lora", 5)
    assert record is not None
    assert record.has_update() is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_rewrites_remote_preview_urls(tmp_path):
    """The stored preview URL is a rewritten form of one of the remote image URLs.

    The expected value matches the second image (``nsfwLevel`` 1) with
    ``original=true`` replaced by ``width=450,optimized=true``.
    """
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=1)
    raw_data = [{"civitai": {"modelId": 7, "id": 71}}]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider(
        {
            "modelVersions": [
                {
                    "id": 71,
                    "files": [],
                    "images": [
                        {
                            "url": "https://image.civitai.com/high/original=true/sample.png",
                            "nsfwLevel": 6,
                            "type": "image",
                        },
                        {
                            "url": "https://image.civitai.com/safe/original=true/preview.png",
                            "nsfwLevel": 1,
                            "type": "image",
                        },
                    ],
                }
            ]
        }
    )

    await service.refresh_for_model_type("lora", scanner, provider)
    record = await service.get_record("lora", 7)

    assert record is not None
    assert record.versions
    preview_url = record.versions[0].preview_url
    assert (
        preview_url
        == "https://image.civitai.com/safe/width=450,optimized=true/preview.png"
    )
|