feat: add update service dependency and has_update filter

- Pass ModelUpdateService to CheckpointService, EmbeddingService, and LoraService constructors
- Add has_update query parameter filter to model listing handler
- Update BaseModelService to accept optional update_service parameter

These changes enable model update functionality across different model types and provide filtering capability for models with available updates.
This commit is contained in:
Will Miao
2025-10-15 17:25:16 +08:00
parent 5a6ff444b9
commit a5b2e9b0bf
9 changed files with 209 additions and 15 deletions

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Type
import asyncio
from typing import Dict, List, Optional, Type, TYPE_CHECKING
import logging
import os
@@ -10,6 +11,9 @@ from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from .model_update_service import ModelUpdateService
class BaseModelService(ABC):
"""Base service class for all model types"""
@@ -23,6 +27,7 @@ class BaseModelService(ABC):
filter_set: Optional[ModelFilterSet] = None,
search_strategy: Optional[SearchStrategy] = None,
settings_provider: Optional[SettingsProvider] = None,
update_service: Optional["ModelUpdateService"] = None,
):
"""Initialize the service.
@@ -34,6 +39,7 @@ class BaseModelService(ABC):
filter_set: Filter component controlling folder/tag/favorites logic.
search_strategy: Search component for fuzzy/text matching.
settings_provider: Settings object; defaults to the global settings manager.
update_service: Service used to determine whether models have remote updates available.
"""
self.model_type = model_type
self.scanner = scanner
@@ -42,6 +48,7 @@ class BaseModelService(ABC):
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
self.filter_set = filter_set or ModelFilterSet(self.settings)
self.search_strategy = search_strategy or SearchStrategy()
self.update_service = update_service
async def get_paginated_data(
self,
@@ -56,6 +63,7 @@ class BaseModelService(ABC):
search_options: dict = None,
hash_filters: dict = None,
favorites_only: bool = False,
has_update: bool = False,
**kwargs,
) -> Dict:
"""Get paginated and filtered model data"""
@@ -85,6 +93,9 @@ class BaseModelService(ABC):
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
if has_update:
filtered_data = await self._apply_update_filter(filtered_data)
return self._paginate(filtered_data, page, page_size)
@@ -144,6 +155,59 @@ class BaseModelService(ABC):
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
"""Apply model-specific filters - to be overridden by subclasses if needed"""
return data
async def _apply_update_filter(self, data: List[Dict]) -> List[Dict]:
"""Filter models to those with remote updates available when requested."""
if not data:
return []
if self.update_service is None:
logger.warning(
"Requested has_update filter for %s models but update service is unavailable",
self.model_type,
)
return []
candidates: List[tuple[Dict, int]] = []
for item in data:
model_id = self._extract_model_id(item)
if model_id is not None:
candidates.append((item, model_id))
if not candidates:
return []
tasks = [
self.update_service.has_update(self.model_type, model_id)
for _, model_id in candidates
]
results = await asyncio.gather(*tasks, return_exceptions=True)
filtered: List[Dict] = []
for (item, model_id), result in zip(candidates, results):
if isinstance(result, Exception):
logger.error(
"Failed to resolve update status for model %s (%s): %s",
model_id,
self.model_type,
result,
)
continue
if result:
filtered.append(item)
return filtered
@staticmethod
def _extract_model_id(item: Dict) -> Optional[int]:
civitai = item.get('civitai') if isinstance(item, dict) else None
if not isinstance(civitai, dict):
return None
try:
value = civitai.get('modelId')
if value is None:
return None
return int(value)
except (TypeError, ValueError):
return None
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
"""Apply pagination to filtered data"""
@@ -373,4 +437,4 @@ class BaseModelService(ABC):
x.lower() # Then alphabetically
))
return matching_paths[:limit]
return matching_paths[:limit]