mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-06-27 13:11:17 -03:00
When refreshing updates with a folder filter, versions already present in other folders were excluded from the is_in_library check, so they appeared as available updates. When the user then tried to download one, the global check found that the file already existed and returned 'model already exists'. Fix: when folder_path is provided, also collect the cross-folder version set, and use the union (folder-filtered + cross-folder) for is_in_library in both _build_record_from_remote and _merge_with_local_versions.
624 lines
20 KiB
Python
import logging
|
|
import sqlite3
|
|
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
|
|
from py.services.errors import ResourceNotFoundError
|
|
from py.services.model_update_service import (
|
|
ModelUpdateRecord,
|
|
ModelUpdateService,
|
|
ModelVersionRecord,
|
|
)
|
|
|
|
|
|
class DummyScanner:
    """In-memory stand-in for a model scanner.

    ``raw_data`` is exposed through ``get_cached_data()`` as the
    ``raw_data`` attribute of a namespace that also carries an empty
    ``version_index``. Cancellation is a plain boolean flag.
    """

    def __init__(self, raw_data):
        self._cancelled = False
        self._cache = SimpleNamespace(raw_data=raw_data, version_index={})

    def is_cancelled(self) -> bool:
        """Report whether the cancellation flag is currently set."""
        flag = self._cancelled
        return flag

    def reset_cancellation(self) -> None:
        """Clear the cancellation flag."""
        self._cancelled = False

    async def get_cached_data(self, *_args, **_kwargs):
        """Return the same cache namespace regardless of the arguments passed."""
        cache = self._cache
        return cache
|
|
|
|
|
|
class DummyProvider:
    """Fake metadata provider that answers every model with one canned response.

    Single lookups are counted in ``calls``. Each bulk lookup appends the
    requested ids to ``bulk_calls``. With ``support_bulk=False`` the bulk
    endpoint raises ``NotImplementedError`` and records nothing.
    """

    def __init__(self, response, *, support_bulk: bool = True):
        self.response = response
        self.calls: int = 0
        self.bulk_calls: list[list[int]] = []
        self.support_bulk = support_bulk

    async def get_model_versions(self, model_id):
        """Count the call and return the canned response."""
        self.calls += 1
        return self.response

    async def get_model_versions_bulk(self, model_ids):
        """Record the requested ids and map each one to the canned response.

        Raises:
            NotImplementedError: if the provider was built with
                ``support_bulk=False``.
        """
        if not self.support_bulk:
            raise NotImplementedError
        # Materialise once: ``model_ids`` may be a one-shot iterable. Iterating
        # it a second time would otherwise produce an empty response mapping.
        ids = list(model_ids)
        self.bulk_calls.append(ids)
        return {model_id: self.response for model_id in ids}
|
|
|
|
|
|
class NotFoundProvider:
    """Fake provider for which no model exists remotely.

    Bulk lookups record the requested ids and return an empty mapping.
    Single lookups are counted and always raise ``ResourceNotFoundError``.
    """

    def __init__(self):
        self.bulk_calls: list[list[int]] = []
        self.calls = 0

    async def get_model_versions_bulk(self, model_ids):
        """Record the ids and report that none of them were found."""
        self.bulk_calls.append([*model_ids])
        return dict()

    async def get_model_versions(self, model_id):
        """Count the call, then signal that the model does not exist."""
        self.calls = self.calls + 1
        raise ResourceNotFoundError("Resource not found")
|
|
|
|
|
|
def make_version(version_id, *, in_library, base_model=None, should_ignore=False):
    """Build a ``ModelVersionRecord`` carrying only the fields these tests vary.

    ``name``, ``released_at``, ``size_bytes`` and ``preview_url`` are all
    left as ``None``.
    """
    unset_fields = dict.fromkeys(("name", "released_at", "size_bytes", "preview_url"))
    return ModelVersionRecord(
        version_id=version_id,
        base_model=base_model,
        is_in_library=in_library,
        should_ignore=should_ignore,
        **unset_fields,
    )
|
|
|
|
|
|
def make_record(*versions, should_ignore_model=False):
    """Wrap ``versions`` in a never-checked ``lora`` update record for model 999."""
    fields = {
        "model_type": "lora",
        "model_id": 999,
        "versions": [*versions],
        "last_checked_at": None,
        "should_ignore_model": should_ignore_model,
    }
    return ModelUpdateRecord(**fields)
|
|
|
|
|
|
def test_extract_size_bytes_prefers_primary_model_file(tmp_path):
    """The version size comes from the primary Model file, not the first file listed.

    NOTE(review): the Model file's ``"primary"`` value is the *string*
    ``"True"``, whereas the sibling test uses the boolean. Presumably this
    checks that non-bool truthy flags are honoured. Confirm it is intentional.
    """
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"))

    training_file = {"sizeKB": 2018.0400390625, "type": "Training Data", "primary": False}
    model_file = {
        "sizeKB": 1152322.3515625,
        "type": "Model",
        "primary": "True",
    }
    payload = {
        "modelVersions": [
            {"id": 42, "files": [training_file, model_file], "images": []},
        ]
    }

    extracted = svc._extract_versions(payload)

    assert extracted is not None
    expected_bytes = int(1152322.3515625 * 1024)
    assert extracted[0].size_bytes == expected_bytes
|
|
|
|
|
|
def test_extract_size_bytes_falls_back_without_primary(tmp_path):
    """With no primary file of type "Model", the size of the file flagged primary is used.

    Here the primary file is "Training Data" (2048 KB). The test expects its
    size rather than the non-primary "Archive" file's size.
    """
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"))

    primary_training = {
        "sizeKB": 2048,
        "type": "Training Data",
        "primary": True,
    }
    archive = {"sizeKB": 1024, "type": "Archive", "primary": False}
    payload = {"modelVersions": [{"id": 43, "files": [primary_training, archive], "images": []}]}

    extracted = svc._extract_versions(payload)

    assert extracted is not None
    assert extracted[0].size_bytes == int(2048 * 1024)
|
|
|
|
|
|
def test_has_update_requires_newer_version_than_library():
    """Neither an older remote version nor an ignored newer one counts as an update."""
    owned = make_version(5, in_library=True)
    older = make_version(4, in_library=False)
    newer_but_ignored = make_version(8, in_library=False, should_ignore=True)

    assert make_record(owned, older, newer_but_ignored).has_update() is False
|
|
|
|
|
|
def test_has_update_detects_newer_remote_version():
    """A non-ignored remote version above the owned one is reported as an update."""
    owned = make_version(5, in_library=True)
    newer = make_version(7, in_library=False)
    ignored = make_version(6, in_library=False, should_ignore=True)

    assert make_record(owned, newer, ignored).has_update() is True
|
|
|
|
|
|
def test_has_update_for_base_matches_same_base_model():
    """A newer version sharing the owned version's base model counts as an update."""
    versions = [
        make_version(5, in_library=True, base_model="Pony"),
        make_version(6, in_library=False, base_model="Pony"),
        make_version(7, in_library=False, base_model="Flux.1"),
    ]

    assert make_record(*versions).has_update_for_base(5, "Pony") is True
|
|
|
|
|
|
def test_has_update_for_base_rejects_other_base_models():
    """A newer version on a different base model is not an update for this base."""
    versions = [
        make_version(10, in_library=True, base_model="Flux"),
        make_version(20, in_library=False, base_model="SDXL"),
    ]

    assert make_record(*versions).has_update_for_base(10, "Flux") is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_persists_versions_and_uses_cache(tmp_path):
    """A refresh stores remote versions, and a second refresh within the TTL makes no provider calls."""
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=3600)
    scanner = DummyScanner(
        [
            {"civitai": {"modelId": 1, "id": 11}},
            {"civitai": {"modelId": 1, "id": 15}},
        ]
    )
    first_release = {
        "id": 11,
        "name": "v1",
        "baseModel": "SD15",
        "publishedAt": "2024-01-01T00:00:00Z",
        "files": [{"sizeKB": 1024}],
        "images": [{"url": "https://example.com/1.png"}],
    }
    second_release = {
        "id": 15,
        "name": "v1.5",
        "baseModel": "SD15",
        "publishedAt": "2024-02-01T00:00:00Z",
        "files": [{"sizeKB": 512}],
        "images": [{"url": "https://example.com/2.png"}],
    }
    provider = DummyProvider({"modelVersions": [first_release, second_release]})

    await svc.refresh_for_model_type("lora", scanner, provider)
    stored = await svc.get_record("lora", 1)

    # The first refresh goes through the bulk endpoint only.
    assert provider.calls == 0
    assert provider.bulk_calls == [[1]]
    assert stored is not None
    assert stored.version_ids == [11, 15]
    assert stored.in_library_version_ids == [11, 15]
    assert [v.name for v in stored.versions] == ["v1", "v1.5"]
    assert stored.should_ignore_model is False
    assert stored.has_update() is False

    # The second refresh within the TTL must not reach the provider at all.
    await svc.refresh_for_model_type("lora", scanner, provider)
    assert provider.calls == 0, "provider should not be called again within TTL"
    assert provider.bulk_calls == [[1]]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_filters_to_requested_models(tmp_path):
    """Only models listed in ``target_model_ids`` are fetched and returned."""
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=3600)
    scanner = DummyScanner(
        [
            {"civitai": {"modelId": 1, "id": 11}},
            {"civitai": {"modelId": 2, "id": 21}},
        ]
    )
    provider = DummyProvider({"modelVersions": []})

    refreshed = await svc.refresh_for_model_type(
        "lora", scanner, provider, target_model_ids=[2]
    )

    assert [*refreshed] == [2]
    assert provider.bulk_calls == [[2]]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_returns_empty_when_targets_missing(tmp_path):
    """Targeting a model that is not in the library returns nothing and makes no bulk call."""
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=3600)
    scanner = DummyScanner([{"civitai": {"modelId": 1, "id": 11}}])
    provider = DummyProvider({"modelVersions": []})

    refreshed = await svc.refresh_for_model_type(
        "lora", scanner, provider, target_model_ids=[5]
    )

    assert refreshed == {}
    assert provider.bulk_calls == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_respects_ignore_flag(tmp_path):
    """After a model is flagged as ignored, later refreshes do not query the provider."""
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=3600)
    scanner = DummyScanner([{"civitai": {"modelId": 2, "id": 21}}])
    remote_versions = [{"id": vid, "files": [], "images": []} for vid in (21, 22)]
    provider = DummyProvider({"modelVersions": remote_versions})

    await svc.refresh_for_model_type("lora", scanner, provider)
    await svc.set_should_ignore("lora", 2, True)

    # Reset the counters so only the post-ignore refresh is observed.
    provider.calls, provider.bulk_calls = 0, []
    await svc.refresh_for_model_type("lora", scanner, provider)
    assert provider.calls == 0
    assert provider.bulk_calls == []

    stored = await svc.get_record("lora", 2)
    assert stored is not None
    assert stored.should_ignore_model is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_marks_model_ignored_when_remote_missing(tmp_path):
    """A model missing remotely is stored as ignored, keeping its local version and a check time."""
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=3600)
    provider = NotFoundProvider()

    await svc.refresh_for_model_type(
        "lora", DummyScanner([{"civitai": {"modelId": 5, "id": 51}}]), provider
    )
    stored = await svc.get_record("lora", 5)

    # The bulk lookup returns nothing, and exactly one single lookup follows.
    assert provider.bulk_calls == [[5]]
    assert provider.calls == 1
    assert stored is not None
    assert stored.should_ignore_model is True
    assert stored.in_library_version_ids == [51]
    assert stored.last_checked_at is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_logs_info_for_missing_remote(tmp_path, caplog):
    """The single-lookup message for a missing remote model is logged at INFO level."""
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=3600)
    scanner = DummyScanner([{"civitai": {"modelId": 6, "id": 61}}])

    with caplog.at_level(logging.INFO, logger="py.services.model_update_service"):
        await svc.refresh_for_model_type("lora", scanner, NotFoundProvider())

    relevant = [
        entry
        for entry in caplog.records
        if "Single lookup for model" in entry.getMessage()
    ]
    assert relevant, "expected single lookup log entry"
    assert all(entry.levelno == logging.INFO for entry in relevant)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_falls_back_when_bulk_not_supported(tmp_path):
    """A provider without a bulk endpoint is queried one model at a time."""
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=3600)
    provider = DummyProvider(
        {"modelVersions": [{"id": 41, "files": [], "images": []}]},
        support_bulk=False,
    )

    await svc.refresh_for_model_type(
        "lora", DummyScanner([{"civitai": {"modelId": 4, "id": 41}}]), provider
    )
    stored = await svc.get_record("lora", 4)

    assert stored is not None
    assert provider.calls == 1
    assert provider.bulk_calls == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_batches_large_collections(tmp_path):
    """150 library models are fetched in two bulk batches: 100 ids, then 50."""
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=3600)
    library = [{"civitai": {"modelId": n, "id": 10 * n}} for n in range(1, 151)]
    provider = DummyProvider({"modelVersions": []})

    await svc.refresh_for_model_type("lora", DummyScanner(library), provider)

    # Expect two batches: 100 ids and remaining 50 ids
    assert [len(batch) for batch in provider.bulk_calls] == [100, 50]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_in_library_versions_changes_update_state(tmp_path):
    """Changing the set of owned versions flips the update flag."""
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=1)
    scanner = DummyScanner([{"civitai": {"modelId": 3, "id": 31}}])
    provider = DummyProvider(
        {"modelVersions": [{"id": vid, "files": [], "images": []} for vid in (31, 35)]}
    )
    await svc.refresh_for_model_type("lora", scanner, provider)

    # Only 31 owned, so 35 is still available.
    await svc.update_in_library_versions("lora", 3, [31])
    stored = await svc.get_record("lora", 3)
    assert stored is not None
    assert stored.has_update() is True

    # Owning both versions leaves nothing newer.
    await svc.update_in_library_versions("lora", 3, [31, 35])
    assert (await svc.get_record("lora", 3)).has_update() is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_version_ignore_blocks_update_flag(tmp_path):
    """Ignoring the only newer version clears the update flag."""
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=1)
    scanner = DummyScanner([{"civitai": {"modelId": 5, "id": 51}}])
    provider = DummyProvider(
        {"modelVersions": [{"id": vid, "files": [], "images": []} for vid in (51, 55)]}
    )
    await svc.refresh_for_model_type("lora", scanner, provider)

    before = await svc.get_record("lora", 5)
    assert before is not None
    assert before.has_update() is True

    await svc.set_version_should_ignore("lora", 5, 55, True)

    after = await svc.get_record("lora", 5)
    assert after is not None
    assert after.has_update() is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_has_updates_bulk_returns_mapping(tmp_path):
    """``has_updates_bulk`` returns one entry per distinct id; unknown model 42 maps to False."""
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=3600)
    scanner = DummyScanner([{"civitai": {"modelId": 9, "id": 91}}])
    provider = DummyProvider(
        {"modelVersions": [{"id": vid, "files": [], "images": []} for vid in (91, 92)]}
    )
    await svc.refresh_for_model_type("lora", scanner, provider)

    flags = await svc.has_updates_bulk("lora", [9, 9, 42])

    assert flags == {9: True, 42: False}
    assert await svc.has_update("lora", 9) is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_has_updates_bulk_handles_more_than_sqlite_max_variables(tmp_path):
    """Bulk query with >999 model IDs must not raise 'too many SQL variables'."""
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=3600)

    model_ids = list(range(1, 1201))

    # Seed model 1 with one version directly in the database.
    # ``with conn`` only commits the transaction; it does not close the
    # connection. Close it explicitly so no handle on the DB file leaks.
    conn = sqlite3.connect(str(db_path))
    try:
        with conn:
            conn.execute(
                "INSERT INTO model_update_status (model_id, model_type) VALUES (?, ?)",
                (1, "lora"),
            )
            conn.execute(
                "INSERT INTO model_update_versions (model_id, version_id, sort_index, name) VALUES (?, ?, ?, ?)",
                (1, 10, 0, "v1"),
            )
    finally:
        conn.close()

    mapping = await service.has_updates_bulk("lora", model_ids)

    # Every requested id gets an entry; only the seeded model reports an update.
    assert mapping[1] is True
    assert len(mapping) == len(model_ids)
    assert all(v is False for k, v in mapping.items() if k != 1)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_records_bulk_handles_more_than_sqlite_max_variables(tmp_path):
    """Bulk record fetch with >999 model IDs must not raise 'too many SQL variables'."""
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=3600)

    model_ids = list(range(1, 1201))

    # Seed model 1 with one version directly in the database.
    # ``with conn`` only commits the transaction; it does not close the
    # connection. Close it explicitly so no handle on the DB file leaks.
    conn = sqlite3.connect(str(db_path))
    try:
        with conn:
            conn.execute(
                "INSERT INTO model_update_status (model_id, model_type) VALUES (?, ?)",
                (1, "lora"),
            )
            conn.execute(
                "INSERT INTO model_update_versions (model_id, version_id, sort_index, name) VALUES (?, ?, ?, ?)",
                (1, 10, 0, "v1"),
            )
    finally:
        conn.close()

    records = await service.get_records_bulk("lora", model_ids)

    # Only the seeded model has a stored record.
    assert 1 in records
    assert records[1].model_id == 1
    assert len(records) == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_allows_duplicate_version_ids_across_models(tmp_path):
    """Two models may report the same version id; a row is stored for each."""
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=0)
    raw_data = [
        {"civitai": {"modelId": 1, "id": 42}},
        {"civitai": {"modelId": 2, "id": 42}},
    ]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider(
        {
            "modelVersions": [
                {
                    "id": 42,
                    "name": "shared",
                    "baseModel": "SD15",
                    "publishedAt": "2024-03-01T00:00:00Z",
                    "files": [{"sizeKB": 256}],
                    "images": [],
                }
            ]
        }
    )

    results = await service.refresh_for_model_type("lora", scanner, provider)

    assert set(results.keys()) == {1, 2}
    assert results[1].version_ids == [42]
    assert results[2].version_ids == [42]

    # A sqlite3 connection used as a context manager is not closed on exit,
    # so close it explicitly after the read.
    conn = sqlite3.connect(str(db_path))
    try:
        count = conn.execute(
            "SELECT COUNT(*) FROM model_update_versions WHERE version_id = 42"
        ).fetchone()[0]
    finally:
        conn.close()

    assert count == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_rewrites_remote_preview_urls(tmp_path):
    """Civitai preview URLs captured during a refresh are rewritten.

    Both candidate images point at ``/original=true/`` renditions on
    image.civitai.com. The stored preview URL must be set, and it must no
    longer reference that full-size segment.
    """
    db_path = tmp_path / "updates.sqlite"
    service = ModelUpdateService(str(db_path), ttl_seconds=1)
    raw_data = [{"civitai": {"modelId": 7, "id": 71}}]
    scanner = DummyScanner(raw_data)
    provider = DummyProvider(
        {
            "modelVersions": [
                {
                    "id": 71,
                    "files": [],
                    "images": [
                        {
                            "url": "https://image.civitai.com/high/original=true/sample.png",
                            "nsfwLevel": 6,
                            "type": "image",
                        },
                        {
                            "url": "https://image.civitai.com/safe/original=true/preview.png",
                            "nsfwLevel": 1,
                            "type": "image",
                        },
                    ],
                }
            ]
        }
    )

    await service.refresh_for_model_type("lora", scanner, provider)
    record = await service.get_record("lora", 7)

    assert record is not None
    assert record.versions
    preview_url = record.versions[0].preview_url

    # Without these assertions the test verified nothing about the rewrite.
    # NOTE(review): assumes the rewrite replaces the ``/original=true/``
    # segment and keeps the Civitai host. Confirm against the service's
    # preview-URL rewriting helper and tighten to an exact URL if possible.
    assert preview_url is not None
    assert preview_url.startswith("https://image.civitai.com/")
    assert "/original=true/" not in preview_url
|
|
|
|
@pytest.mark.asyncio
async def test_update_in_library_versions_populates_metadata(tmp_path):
    """Passing ``version_info`` stores that version's metadata along with the in-library flag."""
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"))
    info = {
        "id": 123,
        "name": "v1.0",
        "baseModel": "SD 1.5",
        "publishedAt": "2024-03-01T00:00:00Z",
        "files": [{"sizeKB": 1024, "type": "Model", "primary": True}],
        "images": [{"url": "https://example.com/preview.png"}],
    }

    await svc.update_in_library_versions("lora", 1, [123], version_info=info)
    stored = await svc.get_record("lora", 1)

    assert stored is not None
    assert len(stored.versions) == 1
    (only,) = stored.versions
    assert only.version_id == 123
    assert only.name == "v1.0"
    assert only.base_model == "SD 1.5"
    # 1024 KB reported by the API -> bytes.
    assert only.size_bytes == 1024 * 1024
    assert only.preview_url == "https://example.com/preview.png"
    assert only.is_in_library is True
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_refresh_folder_filter_considers_cross_folder_versions(tmp_path):
    """A folder-filtered refresh still treats versions in other folders as owned.

    This keeps a version stored elsewhere from showing up as an available
    update, while a genuinely new remote version still does.
    """
    svc = ModelUpdateService(str(tmp_path / "updates.sqlite"), ttl_seconds=0)
    # Model 1 is split across folders: version 11 in folder_a, version 15 in folder_b.
    scanner = DummyScanner(
        [
            {"civitai": {"modelId": 1, "id": 11}, "folder": "folder_a"},
            {"civitai": {"modelId": 1, "id": 15}, "folder": "folder_b"},
        ]
    )
    # The remote offers both local versions plus a new version 20.
    provider = DummyProvider(
        {"modelVersions": [{"id": vid, "files": [], "images": []} for vid in (11, 15, 20)]}
    )

    await svc.refresh_for_model_type(
        "lora", scanner, provider, folder_path="folder_a",
    )
    stored = await svc.get_record("lora", 1)
    assert stored is not None

    def version(wanted_id):
        return next(v for v in stored.versions if v.version_id == wanted_id)

    # 15 lives outside the filtered folder but must still count as owned.
    assert version(15).is_in_library is True
    # 20 is not present in any local folder.
    assert version(20).is_in_library is False
    # 20 is newer than the highest owned version (15).
    assert stored.has_update() is True
|
|
|