mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: add update service dependency and has_update filter
- Pass ModelUpdateService to CheckpointService, EmbeddingService, and LoraService constructors - Add has_update query parameter filter to model listing handler - Update BaseModelService to accept optional update_service parameter These changes enable model update functionality across different model types and provide filtering capability for models with available updates.
This commit is contained in:
@@ -67,6 +67,19 @@ class StubSearchStrategy:
|
||||
return list(self.search_result)
|
||||
|
||||
|
||||
class StubUpdateService:
|
||||
def __init__(self, decisions):
|
||||
self.decisions = dict(decisions)
|
||||
self.calls = []
|
||||
|
||||
async def has_update(self, model_type, model_id):
|
||||
self.calls.append((model_type, model_id))
|
||||
result = self.decisions.get(model_id, False)
|
||||
if isinstance(result, Exception):
|
||||
raise result
|
||||
return result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_paginated_data_uses_injected_collaborators():
|
||||
data = [
|
||||
@@ -272,3 +285,111 @@ async def test_get_paginated_data_paginates_without_search():
|
||||
assert response["page"] == 2
|
||||
assert response["page_size"] == 2
|
||||
assert response["total_pages"] == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_paginated_data_filters_by_update_status():
|
||||
items = [
|
||||
{"model_name": "A", "civitai": {"modelId": 1}},
|
||||
{"model_name": "B", "civitai": {"modelId": 2}},
|
||||
{"model_name": "C", "civitai": {"modelId": 3}},
|
||||
]
|
||||
repository = StubRepository(items)
|
||||
filter_set = PassThroughFilterSet()
|
||||
search_strategy = NoSearchStrategy()
|
||||
update_service = StubUpdateService({1: True, 2: False, 3: True})
|
||||
settings = StubSettings({})
|
||||
|
||||
service = DummyService(
|
||||
model_type="stub",
|
||||
scanner=object(),
|
||||
metadata_class=BaseModelMetadata,
|
||||
cache_repository=repository,
|
||||
filter_set=filter_set,
|
||||
search_strategy=search_strategy,
|
||||
settings_provider=settings,
|
||||
update_service=update_service,
|
||||
)
|
||||
|
||||
response = await service.get_paginated_data(
|
||||
page=1,
|
||||
page_size=5,
|
||||
sort_by="name:asc",
|
||||
has_update=True,
|
||||
)
|
||||
|
||||
assert update_service.calls == [("stub", 1), ("stub", 2), ("stub", 3)]
|
||||
assert response["items"] == [items[0], items[2]]
|
||||
assert response["total"] == 2
|
||||
assert response["page"] == 1
|
||||
assert response["page_size"] == 5
|
||||
assert response["total_pages"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_paginated_data_has_update_without_service_returns_empty():
|
||||
items = [
|
||||
{"model_name": "A", "civitai": {"modelId": 1}},
|
||||
{"model_name": "B", "civitai": {"modelId": 2}},
|
||||
]
|
||||
repository = StubRepository(items)
|
||||
filter_set = PassThroughFilterSet()
|
||||
search_strategy = NoSearchStrategy()
|
||||
settings = StubSettings({})
|
||||
|
||||
service = DummyService(
|
||||
model_type="stub",
|
||||
scanner=object(),
|
||||
metadata_class=BaseModelMetadata,
|
||||
cache_repository=repository,
|
||||
filter_set=filter_set,
|
||||
search_strategy=search_strategy,
|
||||
settings_provider=settings,
|
||||
)
|
||||
|
||||
response = await service.get_paginated_data(
|
||||
page=1,
|
||||
page_size=10,
|
||||
sort_by="name:asc",
|
||||
has_update=True,
|
||||
)
|
||||
|
||||
assert response["items"] == []
|
||||
assert response["total"] == 0
|
||||
assert response["total_pages"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_paginated_data_skips_items_when_update_check_fails():
|
||||
items = [
|
||||
{"model_name": "A", "civitai": {"modelId": 1}},
|
||||
{"model_name": "B", "civitai": {"modelId": 2}},
|
||||
]
|
||||
repository = StubRepository(items)
|
||||
filter_set = PassThroughFilterSet()
|
||||
search_strategy = NoSearchStrategy()
|
||||
update_service = StubUpdateService({1: True, 2: RuntimeError("boom")})
|
||||
settings = StubSettings({})
|
||||
|
||||
service = DummyService(
|
||||
model_type="stub",
|
||||
scanner=object(),
|
||||
metadata_class=BaseModelMetadata,
|
||||
cache_repository=repository,
|
||||
filter_set=filter_set,
|
||||
search_strategy=search_strategy,
|
||||
settings_provider=settings,
|
||||
update_service=update_service,
|
||||
)
|
||||
|
||||
response = await service.get_paginated_data(
|
||||
page=1,
|
||||
page_size=10,
|
||||
sort_by="name:asc",
|
||||
has_update=True,
|
||||
)
|
||||
|
||||
assert update_service.calls == [("stub", 1), ("stub", 2)]
|
||||
assert response["items"] == [items[0]]
|
||||
assert response["total"] == 1
|
||||
assert response["total_pages"] == 1
|
||||
|
||||
Reference in New Issue
Block a user