feat: add update service dependency and has_update filter

- Pass ModelUpdateService to CheckpointService, EmbeddingService, and LoraService constructors
- Add has_update query parameter filter to model listing handler
- Update BaseModelService to accept optional update_service parameter

These changes enable model update functionality across different model types and provide filtering capability for models with available updates.
This commit is contained in:
Will Miao
2025-10-15 17:25:16 +08:00
parent 5a6ff444b9
commit a5b2e9b0bf
9 changed files with 209 additions and 15 deletions

View File

@@ -20,8 +20,8 @@ class CheckpointRoutes(BaseModelRoutes):
async def initialize_services(self): async def initialize_services(self):
"""Initialize services from ServiceRegistry""" """Initialize services from ServiceRegistry"""
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
self.service = CheckpointService(checkpoint_scanner)
update_service = await ServiceRegistry.get_model_update_service() update_service = await ServiceRegistry.get_model_update_service()
self.service = CheckpointService(checkpoint_scanner, update_service=update_service)
self.set_model_update_service(update_service) self.set_model_update_service(update_service)
# Attach service dependencies # Attach service dependencies

View File

@@ -19,8 +19,8 @@ class EmbeddingRoutes(BaseModelRoutes):
async def initialize_services(self): async def initialize_services(self):
"""Initialize services from ServiceRegistry""" """Initialize services from ServiceRegistry"""
embedding_scanner = await ServiceRegistry.get_embedding_scanner() embedding_scanner = await ServiceRegistry.get_embedding_scanner()
self.service = EmbeddingService(embedding_scanner)
update_service = await ServiceRegistry.get_model_update_service() update_service = await ServiceRegistry.get_model_update_service()
self.service = EmbeddingService(embedding_scanner, update_service=update_service)
self.set_model_update_service(update_service) self.set_model_update_service(update_service)
# Attach service dependencies # Attach service dependencies

View File

@@ -6,7 +6,7 @@ import json
import logging import logging
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional from typing import Awaitable, Callable, Dict, Iterable, List, Mapping, Optional
from aiohttp import web from aiohttp import web
import jinja2 import jinja2
@@ -166,6 +166,11 @@ class ModelListingHandler:
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
pass pass
has_update = request.query.get("has_update", "false")
has_update_filter = (
has_update.lower() in {"1", "true", "yes"} if isinstance(has_update, str) else False
)
return { return {
"page": page, "page": page,
"page_size": page_size, "page_size": page_size,
@@ -178,6 +183,7 @@ class ModelListingHandler:
"search_options": search_options, "search_options": search_options,
"hash_filters": hash_filters, "hash_filters": hash_filters,
"favorites_only": favorites_only, "favorites_only": favorites_only,
"has_update": has_update_filter,
**self._parse_specific_params(request), **self._parse_specific_params(request),
} }

View File

@@ -23,8 +23,8 @@ class LoraRoutes(BaseModelRoutes):
async def initialize_services(self): async def initialize_services(self):
"""Initialize services from ServiceRegistry""" """Initialize services from ServiceRegistry"""
lora_scanner = await ServiceRegistry.get_lora_scanner() lora_scanner = await ServiceRegistry.get_lora_scanner()
self.service = LoraService(lora_scanner)
update_service = await ServiceRegistry.get_model_update_service() update_service = await ServiceRegistry.get_model_update_service()
self.service = LoraService(lora_scanner, update_service=update_service)
self.set_model_update_service(update_service) self.set_model_update_service(update_service)
# Attach service dependencies # Attach service dependencies

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Type import asyncio
from typing import Dict, List, Optional, Type, TYPE_CHECKING
import logging import logging
import os import os
@@ -10,6 +11,9 @@ from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from .model_update_service import ModelUpdateService
class BaseModelService(ABC): class BaseModelService(ABC):
"""Base service class for all model types""" """Base service class for all model types"""
@@ -23,6 +27,7 @@ class BaseModelService(ABC):
filter_set: Optional[ModelFilterSet] = None, filter_set: Optional[ModelFilterSet] = None,
search_strategy: Optional[SearchStrategy] = None, search_strategy: Optional[SearchStrategy] = None,
settings_provider: Optional[SettingsProvider] = None, settings_provider: Optional[SettingsProvider] = None,
update_service: Optional["ModelUpdateService"] = None,
): ):
"""Initialize the service. """Initialize the service.
@@ -34,6 +39,7 @@ class BaseModelService(ABC):
filter_set: Filter component controlling folder/tag/favorites logic. filter_set: Filter component controlling folder/tag/favorites logic.
search_strategy: Search component for fuzzy/text matching. search_strategy: Search component for fuzzy/text matching.
settings_provider: Settings object; defaults to the global settings manager. settings_provider: Settings object; defaults to the global settings manager.
update_service: Service used to determine whether models have remote updates available.
""" """
self.model_type = model_type self.model_type = model_type
self.scanner = scanner self.scanner = scanner
@@ -42,6 +48,7 @@ class BaseModelService(ABC):
self.cache_repository = cache_repository or ModelCacheRepository(scanner) self.cache_repository = cache_repository or ModelCacheRepository(scanner)
self.filter_set = filter_set or ModelFilterSet(self.settings) self.filter_set = filter_set or ModelFilterSet(self.settings)
self.search_strategy = search_strategy or SearchStrategy() self.search_strategy = search_strategy or SearchStrategy()
self.update_service = update_service
async def get_paginated_data( async def get_paginated_data(
self, self,
@@ -56,6 +63,7 @@ class BaseModelService(ABC):
search_options: dict = None, search_options: dict = None,
hash_filters: dict = None, hash_filters: dict = None,
favorites_only: bool = False, favorites_only: bool = False,
has_update: bool = False,
**kwargs, **kwargs,
) -> Dict: ) -> Dict:
"""Get paginated and filtered model data""" """Get paginated and filtered model data"""
@@ -85,6 +93,9 @@ class BaseModelService(ABC):
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs) filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
if has_update:
filtered_data = await self._apply_update_filter(filtered_data)
return self._paginate(filtered_data, page, page_size) return self._paginate(filtered_data, page, page_size)
@@ -145,6 +156,59 @@ class BaseModelService(ABC):
"""Apply model-specific filters - to be overridden by subclasses if needed""" """Apply model-specific filters - to be overridden by subclasses if needed"""
return data return data
async def _apply_update_filter(self, data: List[Dict]) -> List[Dict]:
"""Filter models to those with remote updates available when requested."""
if not data:
return []
if self.update_service is None:
logger.warning(
"Requested has_update filter for %s models but update service is unavailable",
self.model_type,
)
return []
candidates: List[tuple[Dict, int]] = []
for item in data:
model_id = self._extract_model_id(item)
if model_id is not None:
candidates.append((item, model_id))
if not candidates:
return []
tasks = [
self.update_service.has_update(self.model_type, model_id)
for _, model_id in candidates
]
results = await asyncio.gather(*tasks, return_exceptions=True)
filtered: List[Dict] = []
for (item, model_id), result in zip(candidates, results):
if isinstance(result, Exception):
logger.error(
"Failed to resolve update status for model %s (%s): %s",
model_id,
self.model_type,
result,
)
continue
if result:
filtered.append(item)
return filtered
@staticmethod
def _extract_model_id(item: Dict) -> Optional[int]:
civitai = item.get('civitai') if isinstance(item, dict) else None
if not isinstance(civitai, dict):
return None
try:
value = civitai.get('modelId')
if value is None:
return None
return int(value)
except (TypeError, ValueError):
return None
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict: def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
"""Apply pagination to filtered data""" """Apply pagination to filtered data"""
total_items = len(data) total_items = len(data)

View File

@@ -11,13 +11,14 @@ logger = logging.getLogger(__name__)
class CheckpointService(BaseModelService): class CheckpointService(BaseModelService):
"""Checkpoint-specific service implementation""" """Checkpoint-specific service implementation"""
def __init__(self, scanner): def __init__(self, scanner, update_service=None):
"""Initialize Checkpoint service """Initialize Checkpoint service
Args: Args:
scanner: Checkpoint scanner instance scanner: Checkpoint scanner instance
update_service: Optional service for remote update tracking.
""" """
super().__init__("checkpoint", scanner, CheckpointMetadata) super().__init__("checkpoint", scanner, CheckpointMetadata, update_service=update_service)
async def format_response(self, checkpoint_data: Dict) -> Dict: async def format_response(self, checkpoint_data: Dict) -> Dict:
"""Format Checkpoint data for API response""" """Format Checkpoint data for API response"""

View File

@@ -11,13 +11,14 @@ logger = logging.getLogger(__name__)
class EmbeddingService(BaseModelService): class EmbeddingService(BaseModelService):
"""Embedding-specific service implementation""" """Embedding-specific service implementation"""
def __init__(self, scanner): def __init__(self, scanner, update_service=None):
"""Initialize Embedding service """Initialize Embedding service
Args: Args:
scanner: Embedding scanner instance scanner: Embedding scanner instance
update_service: Optional service for remote update tracking.
""" """
super().__init__("embedding", scanner, EmbeddingMetadata) super().__init__("embedding", scanner, EmbeddingMetadata, update_service=update_service)
async def format_response(self, embedding_data: Dict) -> Dict: async def format_response(self, embedding_data: Dict) -> Dict:
"""Format Embedding data for API response""" """Format Embedding data for API response"""

View File

@@ -11,13 +11,14 @@ logger = logging.getLogger(__name__)
class LoraService(BaseModelService): class LoraService(BaseModelService):
"""LoRA-specific service implementation""" """LoRA-specific service implementation"""
def __init__(self, scanner): def __init__(self, scanner, update_service=None):
"""Initialize LoRA service """Initialize LoRA service
Args: Args:
scanner: LoRA scanner instance scanner: LoRA scanner instance
update_service: Optional service for remote update tracking.
""" """
super().__init__("lora", scanner, LoraMetadata) super().__init__("lora", scanner, LoraMetadata, update_service=update_service)
async def format_response(self, lora_data: Dict) -> Dict: async def format_response(self, lora_data: Dict) -> Dict:
"""Format LoRA data for API response""" """Format LoRA data for API response"""

View File

@@ -67,6 +67,19 @@ class StubSearchStrategy:
return list(self.search_result) return list(self.search_result)
class StubUpdateService:
def __init__(self, decisions):
self.decisions = dict(decisions)
self.calls = []
async def has_update(self, model_type, model_id):
self.calls.append((model_type, model_id))
result = self.decisions.get(model_id, False)
if isinstance(result, Exception):
raise result
return result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_paginated_data_uses_injected_collaborators(): async def test_get_paginated_data_uses_injected_collaborators():
data = [ data = [
@@ -272,3 +285,111 @@ async def test_get_paginated_data_paginates_without_search():
assert response["page"] == 2 assert response["page"] == 2
assert response["page_size"] == 2 assert response["page_size"] == 2
assert response["total_pages"] == 3 assert response["total_pages"] == 3
@pytest.mark.asyncio
async def test_get_paginated_data_filters_by_update_status():
items = [
{"model_name": "A", "civitai": {"modelId": 1}},
{"model_name": "B", "civitai": {"modelId": 2}},
{"model_name": "C", "civitai": {"modelId": 3}},
]
repository = StubRepository(items)
filter_set = PassThroughFilterSet()
search_strategy = NoSearchStrategy()
update_service = StubUpdateService({1: True, 2: False, 3: True})
settings = StubSettings({})
service = DummyService(
model_type="stub",
scanner=object(),
metadata_class=BaseModelMetadata,
cache_repository=repository,
filter_set=filter_set,
search_strategy=search_strategy,
settings_provider=settings,
update_service=update_service,
)
response = await service.get_paginated_data(
page=1,
page_size=5,
sort_by="name:asc",
has_update=True,
)
assert update_service.calls == [("stub", 1), ("stub", 2), ("stub", 3)]
assert response["items"] == [items[0], items[2]]
assert response["total"] == 2
assert response["page"] == 1
assert response["page_size"] == 5
assert response["total_pages"] == 1
@pytest.mark.asyncio
async def test_get_paginated_data_has_update_without_service_returns_empty():
items = [
{"model_name": "A", "civitai": {"modelId": 1}},
{"model_name": "B", "civitai": {"modelId": 2}},
]
repository = StubRepository(items)
filter_set = PassThroughFilterSet()
search_strategy = NoSearchStrategy()
settings = StubSettings({})
service = DummyService(
model_type="stub",
scanner=object(),
metadata_class=BaseModelMetadata,
cache_repository=repository,
filter_set=filter_set,
search_strategy=search_strategy,
settings_provider=settings,
)
response = await service.get_paginated_data(
page=1,
page_size=10,
sort_by="name:asc",
has_update=True,
)
assert response["items"] == []
assert response["total"] == 0
assert response["total_pages"] == 0
@pytest.mark.asyncio
async def test_get_paginated_data_skips_items_when_update_check_fails():
items = [
{"model_name": "A", "civitai": {"modelId": 1}},
{"model_name": "B", "civitai": {"modelId": 2}},
]
repository = StubRepository(items)
filter_set = PassThroughFilterSet()
search_strategy = NoSearchStrategy()
update_service = StubUpdateService({1: True, 2: RuntimeError("boom")})
settings = StubSettings({})
service = DummyService(
model_type="stub",
scanner=object(),
metadata_class=BaseModelMetadata,
cache_repository=repository,
filter_set=filter_set,
search_strategy=search_strategy,
settings_provider=settings,
update_service=update_service,
)
response = await service.get_paginated_data(
page=1,
page_size=10,
sort_by="name:asc",
has_update=True,
)
assert update_service.calls == [("stub", 1), ("stub", 2)]
assert response["items"] == [items[0]]
assert response["total"] == 1
assert response["total_pages"] == 1