mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 13:42:12 -03:00
- Pass ModelUpdateService to CheckpointService, EmbeddingService, and LoraService constructors - Add has_update query parameter filter to model listing handler - Update BaseModelService to accept optional update_service parameter These changes enable model update functionality across different model types and provide filtering capability for models with available updates.
396 lines
11 KiB
Python
396 lines
11 KiB
Python
import pytest
|
|
|
|
from py.services.base_model_service import BaseModelService
|
|
from py.services.model_query import (
|
|
ModelCacheRepository,
|
|
ModelFilterSet,
|
|
SearchStrategy,
|
|
SortParams,
|
|
)
|
|
from py.utils.models import BaseModelMetadata
|
|
|
|
|
|
class StubSettings:
|
|
def __init__(self, values):
|
|
self._values = dict(values)
|
|
|
|
def get(self, key, default=None):
|
|
return self._values.get(key, default)
|
|
|
|
|
|
class DummyService(BaseModelService):
|
|
async def format_response(self, model_data):
|
|
return model_data
|
|
|
|
|
|
class StubRepository:
|
|
def __init__(self, data):
|
|
self._data = list(data)
|
|
self.parse_sort_calls = []
|
|
self.fetch_sorted_calls = []
|
|
|
|
def parse_sort(self, sort_by):
|
|
params = ModelCacheRepository.parse_sort(sort_by)
|
|
self.parse_sort_calls.append(sort_by)
|
|
return params
|
|
|
|
async def fetch_sorted(self, params):
|
|
self.fetch_sorted_calls.append(params)
|
|
return list(self._data)
|
|
|
|
|
|
class StubFilterSet:
|
|
def __init__(self, result):
|
|
self.result = list(result)
|
|
self.calls = []
|
|
|
|
def apply(self, data, criteria):
|
|
self.calls.append((list(data), criteria))
|
|
return list(self.result)
|
|
|
|
|
|
class StubSearchStrategy:
|
|
def __init__(self, search_result):
|
|
self.search_result = list(search_result)
|
|
self.normalize_calls = []
|
|
self.apply_calls = []
|
|
|
|
def normalize_options(self, options):
|
|
self.normalize_calls.append(options)
|
|
normalized = {"recursive": True}
|
|
if options:
|
|
normalized.update(options)
|
|
return normalized
|
|
|
|
def apply(self, data, search_term, options, fuzzy):
|
|
self.apply_calls.append((list(data), search_term, options, fuzzy))
|
|
return list(self.search_result)
|
|
|
|
|
|
class StubUpdateService:
|
|
def __init__(self, decisions):
|
|
self.decisions = dict(decisions)
|
|
self.calls = []
|
|
|
|
async def has_update(self, model_type, model_id):
|
|
self.calls.append((model_type, model_id))
|
|
result = self.decisions.get(model_id, False)
|
|
if isinstance(result, Exception):
|
|
raise result
|
|
return result
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_paginated_data_uses_injected_collaborators():
|
|
data = [
|
|
{"model_name": "Alpha", "folder": "root"},
|
|
{"model_name": "Beta", "folder": "root"},
|
|
]
|
|
repository = StubRepository(data)
|
|
filter_set = StubFilterSet([{"model_name": "Filtered"}])
|
|
search_strategy = StubSearchStrategy([{"model_name": "SearchResult"}])
|
|
settings = StubSettings({})
|
|
|
|
service = DummyService(
|
|
model_type="stub",
|
|
scanner=object(),
|
|
metadata_class=BaseModelMetadata,
|
|
cache_repository=repository,
|
|
filter_set=filter_set,
|
|
search_strategy=search_strategy,
|
|
settings_provider=settings,
|
|
)
|
|
|
|
response = await service.get_paginated_data(
|
|
page=1,
|
|
page_size=5,
|
|
sort_by="name:desc",
|
|
folder="root",
|
|
search="query",
|
|
fuzzy_search=True,
|
|
base_models=["base"],
|
|
tags=["tag"],
|
|
search_options={"recursive": False},
|
|
favorites_only=True,
|
|
)
|
|
|
|
assert repository.parse_sort_calls == ["name:desc"]
|
|
assert repository.fetch_sorted_calls and isinstance(repository.fetch_sorted_calls[0], SortParams)
|
|
sort_params = repository.fetch_sorted_calls[0]
|
|
assert sort_params.key == "name" and sort_params.order == "desc"
|
|
|
|
assert filter_set.calls, "FilterSet should be invoked"
|
|
call_data, criteria = filter_set.calls[0]
|
|
assert call_data == data
|
|
assert criteria.folder == "root"
|
|
assert criteria.base_models == ["base"]
|
|
assert criteria.tags == ["tag"]
|
|
assert criteria.favorites_only is True
|
|
assert criteria.search_options.get("recursive") is False
|
|
|
|
assert search_strategy.normalize_calls == [{"recursive": False}, {"recursive": False}]
|
|
assert search_strategy.apply_calls == [([{"model_name": "Filtered"}], "query", {"recursive": False}, True)]
|
|
|
|
assert response["items"] == search_strategy.search_result
|
|
assert response["total"] == len(search_strategy.search_result)
|
|
assert response["page"] == 1
|
|
assert response["page_size"] == 5
|
|
|
|
|
|
class FakeCache:
|
|
def __init__(self, items):
|
|
self.items = list(items)
|
|
|
|
async def get_sorted_data(self, sort_key, order):
|
|
if sort_key == "name":
|
|
data = sorted(self.items, key=lambda x: x["model_name"].lower())
|
|
if order == "desc":
|
|
data.reverse()
|
|
else:
|
|
data = list(self.items)
|
|
return data
|
|
|
|
|
|
class FakeScanner:
|
|
def __init__(self, cache):
|
|
self._cache = cache
|
|
|
|
async def get_cached_data(self, *_, **__):
|
|
return self._cache
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_paginated_data_filters_and_searches_combination():
|
|
items = [
|
|
{
|
|
"model_name": "Alpha",
|
|
"file_name": "alpha.safetensors",
|
|
"folder": "root/sub",
|
|
"tags": ["tag1"],
|
|
"base_model": "v1",
|
|
"favorite": True,
|
|
"preview_nsfw_level": 0,
|
|
},
|
|
{
|
|
"model_name": "Beta",
|
|
"file_name": "beta.safetensors",
|
|
"folder": "root",
|
|
"tags": ["tag2"],
|
|
"base_model": "v2",
|
|
"favorite": False,
|
|
"preview_nsfw_level": 999,
|
|
},
|
|
{
|
|
"model_name": "Gamma",
|
|
"file_name": "gamma.safetensors",
|
|
"folder": "root/sub2",
|
|
"tags": ["tag1", "tag3"],
|
|
"base_model": "v1",
|
|
"favorite": True,
|
|
"preview_nsfw_level": 0,
|
|
"civitai": {"creator": {"username": "artist"}},
|
|
},
|
|
]
|
|
|
|
cache = FakeCache(items)
|
|
scanner = FakeScanner(cache)
|
|
settings = StubSettings({"show_only_sfw": True})
|
|
|
|
service = DummyService(
|
|
model_type="stub",
|
|
scanner=scanner,
|
|
metadata_class=BaseModelMetadata,
|
|
cache_repository=ModelCacheRepository(scanner),
|
|
filter_set=ModelFilterSet(settings),
|
|
search_strategy=SearchStrategy(),
|
|
settings_provider=settings,
|
|
)
|
|
|
|
response = await service.get_paginated_data(
|
|
page=1,
|
|
page_size=1,
|
|
sort_by="name:asc",
|
|
folder="root",
|
|
search="artist",
|
|
base_models=["v1"],
|
|
tags=["tag1"],
|
|
search_options={"creator": True, "tags": True},
|
|
favorites_only=True,
|
|
)
|
|
|
|
assert response["items"] == [items[2]]
|
|
assert response["total"] == 1
|
|
assert response["page"] == 1
|
|
assert response["page_size"] == 1
|
|
assert response["total_pages"] == 1
|
|
|
|
|
|
class PassThroughFilterSet:
|
|
def __init__(self):
|
|
self.calls = []
|
|
|
|
def apply(self, data, criteria):
|
|
self.calls.append(criteria)
|
|
return list(data)
|
|
|
|
|
|
class NoSearchStrategy:
|
|
def __init__(self):
|
|
self.normalize_calls = []
|
|
self.apply_called = False
|
|
|
|
def normalize_options(self, options):
|
|
self.normalize_calls.append(options)
|
|
return {"recursive": True}
|
|
|
|
def apply(self, *args, **kwargs):
|
|
self.apply_called = True
|
|
pytest.fail("Search should not be invoked when no search term is provided")
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_paginated_data_paginates_without_search():
|
|
items = [
|
|
{"model_name": name, "folder": "root"}
|
|
for name in ["Alpha", "Beta", "Gamma", "Delta", "Epsilon"]
|
|
]
|
|
|
|
repository = StubRepository(items)
|
|
filter_set = PassThroughFilterSet()
|
|
search_strategy = NoSearchStrategy()
|
|
settings = StubSettings({})
|
|
|
|
service = DummyService(
|
|
model_type="stub",
|
|
scanner=object(),
|
|
metadata_class=BaseModelMetadata,
|
|
cache_repository=repository,
|
|
filter_set=filter_set,
|
|
search_strategy=search_strategy,
|
|
settings_provider=settings,
|
|
)
|
|
|
|
response = await service.get_paginated_data(
|
|
page=2,
|
|
page_size=2,
|
|
sort_by="name:asc",
|
|
)
|
|
|
|
assert repository.parse_sort_calls == ["name:asc"]
|
|
assert len(repository.fetch_sorted_calls) == 1
|
|
assert filter_set.calls and filter_set.calls[0].favorites_only is False
|
|
assert search_strategy.apply_called is False
|
|
assert response["items"] == items[2:4]
|
|
assert response["total"] == len(items)
|
|
assert response["page"] == 2
|
|
assert response["page_size"] == 2
|
|
assert response["total_pages"] == 3
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_paginated_data_filters_by_update_status():
|
|
items = [
|
|
{"model_name": "A", "civitai": {"modelId": 1}},
|
|
{"model_name": "B", "civitai": {"modelId": 2}},
|
|
{"model_name": "C", "civitai": {"modelId": 3}},
|
|
]
|
|
repository = StubRepository(items)
|
|
filter_set = PassThroughFilterSet()
|
|
search_strategy = NoSearchStrategy()
|
|
update_service = StubUpdateService({1: True, 2: False, 3: True})
|
|
settings = StubSettings({})
|
|
|
|
service = DummyService(
|
|
model_type="stub",
|
|
scanner=object(),
|
|
metadata_class=BaseModelMetadata,
|
|
cache_repository=repository,
|
|
filter_set=filter_set,
|
|
search_strategy=search_strategy,
|
|
settings_provider=settings,
|
|
update_service=update_service,
|
|
)
|
|
|
|
response = await service.get_paginated_data(
|
|
page=1,
|
|
page_size=5,
|
|
sort_by="name:asc",
|
|
has_update=True,
|
|
)
|
|
|
|
assert update_service.calls == [("stub", 1), ("stub", 2), ("stub", 3)]
|
|
assert response["items"] == [items[0], items[2]]
|
|
assert response["total"] == 2
|
|
assert response["page"] == 1
|
|
assert response["page_size"] == 5
|
|
assert response["total_pages"] == 1
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_paginated_data_has_update_without_service_returns_empty():
|
|
items = [
|
|
{"model_name": "A", "civitai": {"modelId": 1}},
|
|
{"model_name": "B", "civitai": {"modelId": 2}},
|
|
]
|
|
repository = StubRepository(items)
|
|
filter_set = PassThroughFilterSet()
|
|
search_strategy = NoSearchStrategy()
|
|
settings = StubSettings({})
|
|
|
|
service = DummyService(
|
|
model_type="stub",
|
|
scanner=object(),
|
|
metadata_class=BaseModelMetadata,
|
|
cache_repository=repository,
|
|
filter_set=filter_set,
|
|
search_strategy=search_strategy,
|
|
settings_provider=settings,
|
|
)
|
|
|
|
response = await service.get_paginated_data(
|
|
page=1,
|
|
page_size=10,
|
|
sort_by="name:asc",
|
|
has_update=True,
|
|
)
|
|
|
|
assert response["items"] == []
|
|
assert response["total"] == 0
|
|
assert response["total_pages"] == 0
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_paginated_data_skips_items_when_update_check_fails():
|
|
items = [
|
|
{"model_name": "A", "civitai": {"modelId": 1}},
|
|
{"model_name": "B", "civitai": {"modelId": 2}},
|
|
]
|
|
repository = StubRepository(items)
|
|
filter_set = PassThroughFilterSet()
|
|
search_strategy = NoSearchStrategy()
|
|
update_service = StubUpdateService({1: True, 2: RuntimeError("boom")})
|
|
settings = StubSettings({})
|
|
|
|
service = DummyService(
|
|
model_type="stub",
|
|
scanner=object(),
|
|
metadata_class=BaseModelMetadata,
|
|
cache_repository=repository,
|
|
filter_set=filter_set,
|
|
search_strategy=search_strategy,
|
|
settings_provider=settings,
|
|
update_service=update_service,
|
|
)
|
|
|
|
response = await service.get_paginated_data(
|
|
page=1,
|
|
page_size=10,
|
|
sort_by="name:asc",
|
|
has_update=True,
|
|
)
|
|
|
|
assert update_service.calls == [("stub", 1), ("stub", 2)]
|
|
assert response["items"] == [items[0]]
|
|
assert response["total"] == 1
|
|
assert response["total_pages"] == 1
|