feat: add update service dependency and has_update filter

- Pass ModelUpdateService to CheckpointService, EmbeddingService, and LoraService constructors
- Add has_update query parameter filter to model listing handler
- Update BaseModelService to accept optional update_service parameter

These changes enable model update functionality across different model types and provide filtering capability for models with available updates.
This commit is contained in:
Will Miao
2025-10-15 17:25:16 +08:00
parent 5a6ff444b9
commit a5b2e9b0bf
9 changed files with 209 additions and 15 deletions

View File

@@ -67,6 +67,19 @@ class StubSearchStrategy:
return list(self.search_result)
class StubUpdateService:
def __init__(self, decisions):
self.decisions = dict(decisions)
self.calls = []
async def has_update(self, model_type, model_id):
self.calls.append((model_type, model_id))
result = self.decisions.get(model_id, False)
if isinstance(result, Exception):
raise result
return result
@pytest.mark.asyncio
async def test_get_paginated_data_uses_injected_collaborators():
data = [
@@ -272,3 +285,111 @@ async def test_get_paginated_data_paginates_without_search():
assert response["page"] == 2
assert response["page_size"] == 2
assert response["total_pages"] == 3
@pytest.mark.asyncio
async def test_get_paginated_data_filters_by_update_status():
items = [
{"model_name": "A", "civitai": {"modelId": 1}},
{"model_name": "B", "civitai": {"modelId": 2}},
{"model_name": "C", "civitai": {"modelId": 3}},
]
repository = StubRepository(items)
filter_set = PassThroughFilterSet()
search_strategy = NoSearchStrategy()
update_service = StubUpdateService({1: True, 2: False, 3: True})
settings = StubSettings({})
service = DummyService(
model_type="stub",
scanner=object(),
metadata_class=BaseModelMetadata,
cache_repository=repository,
filter_set=filter_set,
search_strategy=search_strategy,
settings_provider=settings,
update_service=update_service,
)
response = await service.get_paginated_data(
page=1,
page_size=5,
sort_by="name:asc",
has_update=True,
)
assert update_service.calls == [("stub", 1), ("stub", 2), ("stub", 3)]
assert response["items"] == [items[0], items[2]]
assert response["total"] == 2
assert response["page"] == 1
assert response["page_size"] == 5
assert response["total_pages"] == 1
@pytest.mark.asyncio
async def test_get_paginated_data_has_update_without_service_returns_empty():
items = [
{"model_name": "A", "civitai": {"modelId": 1}},
{"model_name": "B", "civitai": {"modelId": 2}},
]
repository = StubRepository(items)
filter_set = PassThroughFilterSet()
search_strategy = NoSearchStrategy()
settings = StubSettings({})
service = DummyService(
model_type="stub",
scanner=object(),
metadata_class=BaseModelMetadata,
cache_repository=repository,
filter_set=filter_set,
search_strategy=search_strategy,
settings_provider=settings,
)
response = await service.get_paginated_data(
page=1,
page_size=10,
sort_by="name:asc",
has_update=True,
)
assert response["items"] == []
assert response["total"] == 0
assert response["total_pages"] == 0
@pytest.mark.asyncio
async def test_get_paginated_data_skips_items_when_update_check_fails():
items = [
{"model_name": "A", "civitai": {"modelId": 1}},
{"model_name": "B", "civitai": {"modelId": 2}},
]
repository = StubRepository(items)
filter_set = PassThroughFilterSet()
search_strategy = NoSearchStrategy()
update_service = StubUpdateService({1: True, 2: RuntimeError("boom")})
settings = StubSettings({})
service = DummyService(
model_type="stub",
scanner=object(),
metadata_class=BaseModelMetadata,
cache_repository=repository,
filter_set=filter_set,
search_strategy=search_strategy,
settings_provider=settings,
update_service=update_service,
)
response = await service.get_paginated_data(
page=1,
page_size=10,
sort_by="name:asc",
has_update=True,
)
assert update_service.calls == [("stub", 1), ("stub", 2)]
assert response["items"] == [items[0]]
assert response["total"] == 1
assert response["total_pages"] == 1