From a5b2e9b0bf753e3f9d6a922750fa1f0bb3764e2e Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 15 Oct 2025 17:25:16 +0800 Subject: [PATCH] feat: add update service dependency and has_update filter - Pass ModelUpdateService to CheckpointService, EmbeddingService, and LoraService constructors - Add has_update query parameter filter to model listing handler - Update BaseModelService to accept optional update_service parameter These changes enable model update functionality across different model types and provide filtering capability for models with available updates. --- py/routes/checkpoint_routes.py | 4 +- py/routes/embedding_routes.py | 2 +- py/routes/handlers/model_handlers.py | 8 +- py/routes/lora_routes.py | 2 +- py/services/base_model_service.py | 68 +++++++++++- py/services/checkpoint_service.py | 7 +- py/services/embedding_service.py | 5 +- py/services/lora_service.py | 7 +- tests/services/test_base_model_service.py | 121 ++++++++++++++++++++++ 9 files changed, 209 insertions(+), 15 deletions(-) diff --git a/py/routes/checkpoint_routes.py b/py/routes/checkpoint_routes.py index 5a17d79a..16ebd338 100644 --- a/py/routes/checkpoint_routes.py +++ b/py/routes/checkpoint_routes.py @@ -20,8 +20,8 @@ class CheckpointRoutes(BaseModelRoutes): async def initialize_services(self): """Initialize services from ServiceRegistry""" checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() - self.service = CheckpointService(checkpoint_scanner) update_service = await ServiceRegistry.get_model_update_service() + self.service = CheckpointService(checkpoint_scanner, update_service=update_service) self.set_model_update_service(update_service) # Attach service dependencies @@ -95,4 +95,4 @@ class CheckpointRoutes(BaseModelRoutes): return web.json_response({ "success": False, "error": str(e) - }, status=500) \ No newline at end of file + }, status=500) diff --git a/py/routes/embedding_routes.py b/py/routes/embedding_routes.py index 80b15525..5268dc4a 100644 --- a/py/routes/embedding_routes.py +++ b/py/routes/embedding_routes.py @@ -19,8 +19,8 @@ class EmbeddingRoutes(BaseModelRoutes): async def initialize_services(self): """Initialize services from ServiceRegistry""" embedding_scanner = await ServiceRegistry.get_embedding_scanner() - self.service = EmbeddingService(embedding_scanner) update_service = await ServiceRegistry.get_model_update_service() + self.service = EmbeddingService(embedding_scanner, update_service=update_service) self.set_model_update_service(update_service) # Attach service dependencies diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 4df9941a..b1254e0b 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -6,7 +6,7 @@ import json import logging import os from dataclasses import dataclass -from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional +from typing import Awaitable, Callable, Dict, Iterable, List, Mapping, Optional from aiohttp import web import jinja2 @@ -166,6 +166,11 @@ class ModelListingHandler: except (json.JSONDecodeError, TypeError): pass + has_update = request.query.get("has_update", "false") + has_update_filter = ( + has_update.lower() in {"1", "true", "yes"} if isinstance(has_update, str) else False + ) + return { "page": page, "page_size": page_size, @@ -178,6 +183,7 @@ class ModelListingHandler: "search_options": search_options, "hash_filters": hash_filters, "favorites_only": favorites_only, + "has_update": has_update_filter, **self._parse_specific_params(request), } diff --git a/py/routes/lora_routes.py b/py/routes/lora_routes.py index c966cd5f..c53e88a4 100644 --- a/py/routes/lora_routes.py +++ b/py/routes/lora_routes.py @@ -23,8 +23,8 @@ class LoraRoutes(BaseModelRoutes): async def initialize_services(self): """Initialize services from ServiceRegistry""" lora_scanner = await ServiceRegistry.get_lora_scanner() - self.service = LoraService(lora_scanner) update_service = await ServiceRegistry.get_model_update_service() + self.service = LoraService(lora_scanner, update_service=update_service) self.set_model_update_service(update_service) # Attach service dependencies diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index 728b789d..e17a4dab 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Type +import asyncio +from typing import Dict, List, Optional, Type, TYPE_CHECKING import logging import os @@ -10,6 +11,9 @@ from .settings_manager import get_settings_manager logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from .model_update_service import ModelUpdateService + class BaseModelService(ABC): """Base service class for all model types""" @@ -23,6 +27,7 @@ class BaseModelService(ABC): filter_set: Optional[ModelFilterSet] = None, search_strategy: Optional[SearchStrategy] = None, settings_provider: Optional[SettingsProvider] = None, + update_service: Optional["ModelUpdateService"] = None, ): """Initialize the service. @@ -34,6 +39,7 @@ class BaseModelService(ABC): filter_set: Filter component controlling folder/tag/favorites logic. search_strategy: Search component for fuzzy/text matching. settings_provider: Settings object; defaults to the global settings manager. + update_service: Service used to determine whether models have remote updates available. """ self.model_type = model_type self.scanner = scanner @@ -42,6 +48,7 @@ class BaseModelService(ABC): self.cache_repository = cache_repository or ModelCacheRepository(scanner) self.filter_set = filter_set or ModelFilterSet(self.settings) self.search_strategy = search_strategy or SearchStrategy() + self.update_service = update_service async def get_paginated_data( self, @@ -56,6 +63,7 @@ class BaseModelService(ABC): search_options: dict = None, hash_filters: dict = None, favorites_only: bool = False, + has_update: bool = False, **kwargs, ) -> Dict: """Get paginated and filtered model data""" @@ -85,6 +93,9 @@ class BaseModelService(ABC): filtered_data = await self._apply_specific_filters(filtered_data, **kwargs) + if has_update: + filtered_data = await self._apply_update_filter(filtered_data) + return self._paginate(filtered_data, page, page_size) @@ -144,6 +155,59 @@ class BaseModelService(ABC): async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: """Apply model-specific filters - to be overridden by subclasses if needed""" return data + + async def _apply_update_filter(self, data: List[Dict]) -> List[Dict]: + """Filter models to those with remote updates available when requested.""" + if not data: + return [] + if self.update_service is None: + logger.warning( + "Requested has_update filter for %s models but update service is unavailable", + self.model_type, + ) + return [] + + candidates: List[tuple[Dict, int]] = [] + for item in data: + model_id = self._extract_model_id(item) + if model_id is not None: + candidates.append((item, model_id)) + + if not candidates: + return [] + + tasks = [ + self.update_service.has_update(self.model_type, model_id) + for _, model_id in candidates + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + + filtered: List[Dict] = [] + for (item, model_id), result in zip(candidates, results): + if isinstance(result, Exception): + logger.error( + "Failed to resolve update status for model %s (%s): %s", + model_id, + self.model_type, + result, + ) + continue + if result: + filtered.append(item) + return filtered + + @staticmethod + def _extract_model_id(item: Dict) -> Optional[int]: + civitai = item.get('civitai') if isinstance(item, dict) else None + if not isinstance(civitai, dict): + return None + try: + value = civitai.get('modelId') + if value is None: + return None + return int(value) + except (TypeError, ValueError): + return None def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict: """Apply pagination to filtered data""" @@ -373,4 +437,4 @@ class BaseModelService(ABC): x.lower() # Then alphabetically )) - return matching_paths[:limit] \ No newline at end of file + return matching_paths[:limit] diff --git a/py/services/checkpoint_service.py b/py/services/checkpoint_service.py index 2f7b8a96..55cb13cd 100644 --- a/py/services/checkpoint_service.py +++ b/py/services/checkpoint_service.py @@ -11,13 +11,14 @@ logger = logging.getLogger(__name__) class CheckpointService(BaseModelService): """Checkpoint-specific service implementation""" - def __init__(self, scanner): + def __init__(self, scanner, update_service=None): """Initialize Checkpoint service Args: scanner: Checkpoint scanner instance + update_service: Optional service for remote update tracking. """ - super().__init__("checkpoint", scanner, CheckpointMetadata) + super().__init__("checkpoint", scanner, CheckpointMetadata, update_service=update_service) async def format_response(self, checkpoint_data: Dict) -> Dict: """Format Checkpoint data for API response""" @@ -46,4 +47,4 @@ class CheckpointService(BaseModelService): def find_duplicate_filenames(self) -> Dict: """Find Checkpoints with conflicting filenames""" - return self.scanner._hash_index.get_duplicate_filenames() \ No newline at end of file + return self.scanner._hash_index.get_duplicate_filenames() diff --git a/py/services/embedding_service.py b/py/services/embedding_service.py index 46396fc5..30a742df 100644 --- a/py/services/embedding_service.py +++ b/py/services/embedding_service.py @@ -11,13 +11,14 @@ logger = logging.getLogger(__name__) class EmbeddingService(BaseModelService): """Embedding-specific service implementation""" - def __init__(self, scanner): + def __init__(self, scanner, update_service=None): """Initialize Embedding service Args: scanner: Embedding scanner instance + update_service: Optional service for remote update tracking. """ - super().__init__("embedding", scanner, EmbeddingMetadata) + super().__init__("embedding", scanner, EmbeddingMetadata, update_service=update_service) async def format_response(self, embedding_data: Dict) -> Dict: """Format Embedding data for API response""" diff --git a/py/services/lora_service.py b/py/services/lora_service.py index 551c4d3c..5a0e27f6 100644 --- a/py/services/lora_service.py +++ b/py/services/lora_service.py @@ -11,13 +11,14 @@ logger = logging.getLogger(__name__) class LoraService(BaseModelService): """LoRA-specific service implementation""" - def __init__(self, scanner): + def __init__(self, scanner, update_service=None): """Initialize LoRA service Args: scanner: LoRA scanner instance + update_service: Optional service for remote update tracking. """ - super().__init__("lora", scanner, LoraMetadata) + super().__init__("lora", scanner, LoraMetadata, update_service=update_service) async def format_response(self, lora_data: Dict) -> Dict: """Format LoRA data for API response""" @@ -178,4 +179,4 @@ class LoraService(BaseModelService): def find_duplicate_filenames(self) -> Dict: """Find LoRAs with conflicting filenames""" - return self.scanner._hash_index.get_duplicate_filenames() \ No newline at end of file + return self.scanner._hash_index.get_duplicate_filenames() diff --git a/tests/services/test_base_model_service.py b/tests/services/test_base_model_service.py index ea6a126b..23b0b1d7 100644 --- a/tests/services/test_base_model_service.py +++ b/tests/services/test_base_model_service.py @@ -67,6 +67,19 @@ class StubSearchStrategy: return list(self.search_result) +class StubUpdateService: + def __init__(self, decisions): + self.decisions = dict(decisions) + self.calls = [] + + async def has_update(self, model_type, model_id): + self.calls.append((model_type, model_id)) + result = self.decisions.get(model_id, False) + if isinstance(result, Exception): + raise result + return result + + @pytest.mark.asyncio async def test_get_paginated_data_uses_injected_collaborators(): data = [ @@ -272,3 +285,111 @@ async def test_get_paginated_data_paginates_without_search(): assert response["page"] == 2 assert response["page_size"] == 2 assert response["total_pages"] == 3 + + +@pytest.mark.asyncio +async def test_get_paginated_data_filters_by_update_status(): + items = [ + {"model_name": "A", "civitai": {"modelId": 1}}, + {"model_name": "B", "civitai": {"modelId": 2}}, + {"model_name": "C", "civitai": {"modelId": 3}}, + ] + repository = StubRepository(items) + filter_set = PassThroughFilterSet() + search_strategy = NoSearchStrategy() + update_service = StubUpdateService({1: True, 2: False, 3: True}) + settings = StubSettings({}) + + service = DummyService( + model_type="stub", + scanner=object(), + metadata_class=BaseModelMetadata, + cache_repository=repository, + filter_set=filter_set, + search_strategy=search_strategy, + settings_provider=settings, + update_service=update_service, + ) + + response = await service.get_paginated_data( + page=1, + page_size=5, + sort_by="name:asc", + has_update=True, + ) + + assert update_service.calls == [("stub", 1), ("stub", 2), ("stub", 3)] + assert response["items"] == [items[0], items[2]] + assert response["total"] == 2 + assert response["page"] == 1 + assert response["page_size"] == 5 + assert response["total_pages"] == 1 + + +@pytest.mark.asyncio +async def test_get_paginated_data_has_update_without_service_returns_empty(): + items = [ + {"model_name": "A", "civitai": {"modelId": 1}}, + {"model_name": "B", "civitai": {"modelId": 2}}, + ] + repository = StubRepository(items) + filter_set = PassThroughFilterSet() + search_strategy = NoSearchStrategy() + settings = StubSettings({}) + + service = DummyService( + model_type="stub", + scanner=object(), + metadata_class=BaseModelMetadata, + cache_repository=repository, + filter_set=filter_set, + search_strategy=search_strategy, + settings_provider=settings, + ) + + response = await service.get_paginated_data( + page=1, + page_size=10, + sort_by="name:asc", + has_update=True, + ) + + assert response["items"] == [] + assert response["total"] == 0 + assert response["total_pages"] == 0 + + +@pytest.mark.asyncio +async def test_get_paginated_data_skips_items_when_update_check_fails(): + items = [ + {"model_name": "A", "civitai": {"modelId": 1}}, + {"model_name": "B", "civitai": {"modelId": 2}}, + ] + repository = StubRepository(items) + filter_set = PassThroughFilterSet() + search_strategy = NoSearchStrategy() + update_service = StubUpdateService({1: True, 2: RuntimeError("boom")}) + settings = StubSettings({}) + + service = DummyService( + model_type="stub", + scanner=object(), + metadata_class=BaseModelMetadata, + cache_repository=repository, + filter_set=filter_set, + search_strategy=search_strategy, + settings_provider=settings, + update_service=update_service, + ) + + response = await service.get_paginated_data( + page=1, + page_size=10, + sort_by="name:asc", + has_update=True, + ) + + assert update_service.calls == [("stub", 1), ("stub", 2)] + assert response["items"] == [items[0]] + assert response["total"] == 1 + assert response["total_pages"] == 1