From 2d00cfdd319fb9d5750db77e7231cef680fc3c05 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Sun, 21 Sep 2025 23:13:30 +0800 Subject: [PATCH] refactor: enhance BaseModelService initialization and improve filtering logic --- py/services/base_model_service.py | 261 +++++++-------------- py/services/model_query.py | 196 ++++++++++++++++ tests/services/test_base_model_service.py | 269 ++++++++++++++++++++++ 3 files changed, 552 insertions(+), 174 deletions(-) create mode 100644 py/services/model_query.py create mode 100644 tests/services/test_base_model_service.py diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index ed1fc930..0b4aaf99 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -5,98 +5,88 @@ import os from ..utils.models import BaseModelMetadata from ..utils.routes_common import ModelRouteUtils -from ..utils.constants import NSFW_LEVELS -from .settings_manager import settings -from ..utils.utils import fuzzy_match +from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider +from .settings_manager import settings as default_settings logger = logging.getLogger(__name__) class BaseModelService(ABC): """Base service class for all model types""" - def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]): - """Initialize the service - + def __init__( + self, + model_type: str, + scanner, + metadata_class: Type[BaseModelMetadata], + *, + cache_repository: Optional[ModelCacheRepository] = None, + filter_set: Optional[ModelFilterSet] = None, + search_strategy: Optional[SearchStrategy] = None, + settings_provider: Optional[SettingsProvider] = None, + ): + """Initialize the service. + Args: - model_type: Type of model (lora, checkpoint, etc.) - scanner: Model scanner instance - metadata_class: Metadata class for this model type + model_type: Type of model (lora, checkpoint, etc.). + scanner: Model scanner instance. + metadata_class: Metadata class for this model type. + cache_repository: Custom repository for cache access (primarily for tests). + filter_set: Filter component controlling folder/tag/favorites logic. + search_strategy: Search component for fuzzy/text matching. + settings_provider: Settings object; defaults to the global settings manager. """ self.model_type = model_type self.scanner = scanner self.metadata_class = metadata_class + self.settings = settings_provider or default_settings + self.cache_repository = cache_repository or ModelCacheRepository(scanner) + self.filter_set = filter_set or ModelFilterSet(self.settings) + self.search_strategy = search_strategy or SearchStrategy() - async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name', - folder: str = None, search: str = None, fuzzy_search: bool = False, - base_models: list = None, tags: list = None, - search_options: dict = None, hash_filters: dict = None, - favorites_only: bool = False, **kwargs) -> Dict: - """Get paginated and filtered model data - - Args: - page: Page number (1-based) - page_size: Number of items per page - sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc' - folder: Folder filter - search: Search term - fuzzy_search: Whether to use fuzzy search - base_models: List of base models to filter by - tags: List of tags to filter by - search_options: Search options dict - hash_filters: Hash filtering options - favorites_only: Filter for favorites only - **kwargs: Additional model-specific filters - - Returns: - Dict containing paginated results - """ - cache = await self.scanner.get_cached_data() + async def get_paginated_data( + self, + page: int, + page_size: int, + sort_by: str = 'name', + folder: str = None, + search: str = None, + fuzzy_search: bool = False, + base_models: list = None, + tags: list = None, + search_options: dict = None, + hash_filters: dict = None, + favorites_only: bool = False, + **kwargs, + ) -> Dict: + """Get paginated and filtered model data""" + sort_params = self.cache_repository.parse_sort(sort_by) + sorted_data = await self.cache_repository.fetch_sorted(sort_params) - # Parse sort_by into sort_key and order - if ':' in sort_by: - sort_key, order = sort_by.split(':', 1) - sort_key = sort_key.strip() - order = order.strip().lower() - if order not in ('asc', 'desc'): - order = 'asc' - else: - sort_key = sort_by.strip() - order = 'asc' - - # Get default search options if not provided - if search_options is None: - search_options = { - 'filename': True, - 'modelname': True, - 'tags': False, - 'recursive': True, - } - - # Get the base data set using new sort logic - filtered_data = await cache.get_sorted_data(sort_key, order) - - # Apply hash filtering if provided (highest priority) if hash_filters: - filtered_data = await self._apply_hash_filters(filtered_data, hash_filters) - - # Jump to pagination for hash filters + filtered_data = await self._apply_hash_filters(sorted_data, hash_filters) return self._paginate(filtered_data, page, page_size) - - # Apply common filters + filtered_data = await self._apply_common_filters( - filtered_data, folder, base_models, tags, favorites_only, search_options + sorted_data, + folder=folder, + base_models=base_models, + tags=tags, + favorites_only=favorites_only, + search_options=search_options, ) - - # Apply search filtering + if search: filtered_data = await self._apply_search_filters( - filtered_data, search, fuzzy_search, search_options + filtered_data, + search, + fuzzy_search, + search_options, ) - - # Apply model-specific filters + filtered_data = await self._apply_specific_filters(filtered_data, **kwargs) - + return self._paginate(filtered_data, page, page_size) + async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]: """Apply hash-based filtering""" @@ -120,113 +110,36 @@ class BaseModelService(ABC): return data - async def _apply_common_filters(self, data: List[Dict], folder: str = None, - base_models: list = None, tags: list = None, - favorites_only: bool = False, search_options: dict = None) -> List[Dict]: + async def _apply_common_filters( + self, + data: List[Dict], + folder: str = None, + base_models: list = None, + tags: list = None, + favorites_only: bool = False, + search_options: dict = None, + ) -> List[Dict]: """Apply common filters that work across all model types""" - # Apply SFW filtering if enabled in settings - if settings.get('show_only_sfw', False): - data = [ - item for item in data - if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R'] - ] - - # Apply favorites filtering if enabled - if favorites_only: - data = [ - item for item in data - if item.get('favorite', False) is True - ] - - # Apply folder filtering - if folder is not None: - if search_options and search_options.get('recursive', True): - # Recursive folder filtering - include all subfolders - # Ensure we match exact folder or its subfolders by checking path boundaries - if folder == "": - # Empty folder means root - include all items - pass # Don't filter anything - else: - # Add trailing slash to ensure we match folder boundaries correctly - folder_with_separator = folder + "/" - data = [ - item for item in data - if (item['folder'] == folder or - item['folder'].startswith(folder_with_separator)) - ] - else: - # Exact folder filtering - data = [ - item for item in data - if item['folder'] == folder - ] - - # Apply base model filtering - if base_models and len(base_models) > 0: - data = [ - item for item in data - if item.get('base_model') in base_models - ] - - # Apply tag filtering - if tags and len(tags) > 0: - data = [ - item for item in data - if any(tag in item.get('tags', []) for tag in tags) - ] - - return data + normalized_options = self.search_strategy.normalize_options(search_options) + criteria = FilterCriteria( + folder=folder, + base_models=base_models, + tags=tags, + favorites_only=favorites_only, + search_options=normalized_options, + ) + return self.filter_set.apply(data, criteria) - async def _apply_search_filters(self, data: List[Dict], search: str, - fuzzy_search: bool, search_options: dict) -> List[Dict]: + async def _apply_search_filters( + self, + data: List[Dict], + search: str, + fuzzy_search: bool, + search_options: dict, + ) -> List[Dict]: """Apply search filtering""" - search_results = [] - - for item in data: - # Search by file name - if search_options.get('filename', True): - if fuzzy_search: - if fuzzy_match(item.get('file_name', ''), search): - search_results.append(item) - continue - elif search.lower() in item.get('file_name', '').lower(): - search_results.append(item) - continue - - # Search by model name - if search_options.get('modelname', True): - if fuzzy_search: - if fuzzy_match(item.get('model_name', ''), search): - search_results.append(item) - continue - elif search.lower() in item.get('model_name', '').lower(): - search_results.append(item) - continue - - # Search by tags - if search_options.get('tags', False) and 'tags' in item: - if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) - for tag in item['tags']): - search_results.append(item) - continue - - # Search by creator - civitai = item.get('civitai') - creator_username = '' - if civitai and isinstance(civitai, dict): - creator = civitai.get('creator') - if creator and isinstance(creator, dict): - creator_username = creator.get('username', '') - if search_options.get('creator', False) and creator_username: - if fuzzy_search: - if fuzzy_match(creator_username, search): - search_results.append(item) - continue - elif search.lower() in creator_username.lower(): - search_results.append(item) - continue - - return search_results + normalized_options = self.search_strategy.normalize_options(search_options) + return self.search_strategy.apply(data, search, normalized_options, fuzzy_search) async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: """Apply model-specific filters - to be overridden by subclasses if needed""" diff --git a/py/services/model_query.py b/py/services/model_query.py new file mode 100644 index 00000000..08ca652f --- /dev/null +++ b/py/services/model_query.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable + +from ..utils.constants import NSFW_LEVELS +from ..utils.utils import fuzzy_match as default_fuzzy_match + + +class SettingsProvider(Protocol): + """Protocol describing the SettingsManager contract used by query helpers.""" + + def get(self, key: str, default: Any = None) -> Any: + ... + + +@dataclass(frozen=True) +class SortParams: + """Normalized representation of sorting instructions.""" + + key: str + order: str + + +@dataclass(frozen=True) +class FilterCriteria: + """Container for model list filtering options.""" + + folder: Optional[str] = None + base_models: Optional[Sequence[str]] = None + tags: Optional[Sequence[str]] = None + favorites_only: bool = False + search_options: Optional[Dict[str, Any]] = None + + +class ModelCacheRepository: + """Adapter around scanner cache access and sort normalisation.""" + + def __init__(self, scanner) -> None: + self._scanner = scanner + + async def get_cache(self): + """Return the underlying cache instance from the scanner.""" + return await self._scanner.get_cached_data() + + async def fetch_sorted(self, params: SortParams) -> List[Dict[str, Any]]: + """Fetch cached data pre-sorted according to ``params``.""" + cache = await self.get_cache() + return await cache.get_sorted_data(params.key, params.order) + + @staticmethod + def parse_sort(sort_by: str) -> SortParams: + """Parse an incoming sort string into key/order primitives.""" + if not sort_by: + return SortParams(key="name", order="asc") + + if ":" in sort_by: + raw_key, raw_order = sort_by.split(":", 1) + sort_key = raw_key.strip().lower() or "name" + order = raw_order.strip().lower() + else: + sort_key = sort_by.strip().lower() or "name" + order = "asc" + + if order not in ("asc", "desc"): + order = "asc" + + return SortParams(key=sort_key, order=order) + + +class ModelFilterSet: + """Applies common filtering rules to the model collection.""" + + def __init__(self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None) -> None: + self._settings = settings + self._nsfw_levels = nsfw_levels or NSFW_LEVELS + + def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]: + """Return items that satisfy the provided criteria.""" + items = list(data) + + if self._settings.get("show_only_sfw", False): + threshold = self._nsfw_levels.get("R", 0) + items = [ + item for item in items + if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold + ] + + if criteria.favorites_only: + items = [item for item in items if item.get("favorite", False)] + + folder = criteria.folder + options = criteria.search_options or {} + recursive = bool(options.get("recursive", True)) + if folder is not None: + if recursive: + if folder: + folder_with_sep = f"{folder}/" + items = [ + item for item in items + if item.get("folder") == folder or item.get("folder", "").startswith(folder_with_sep) + ] + else: + items = [item for item in items if item.get("folder") == folder] + + base_models = criteria.base_models or [] + if base_models: + base_model_set = set(base_models) + items = [item for item in items if item.get("base_model") in base_model_set] + + tags = criteria.tags or [] + if tags: + tag_set = set(tags) + items = [ + item for item in items + if any(tag in tag_set for tag in item.get("tags", [])) + ] + + return items + + +class SearchStrategy: + """Encapsulates text and fuzzy matching behaviour for model queries.""" + + DEFAULT_OPTIONS: Dict[str, Any] = { + "filename": True, + "modelname": True, + "tags": False, + "recursive": True, + "creator": False, + } + + def __init__(self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None) -> None: + self._fuzzy_match = fuzzy_matcher or default_fuzzy_match + + def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]: + """Merge provided options with defaults without mutating input.""" + normalized = dict(self.DEFAULT_OPTIONS) + if options: + normalized.update(options) + return normalized + + def apply( + self, + data: Iterable[Dict[str, Any]], + search_term: str, + options: Dict[str, Any], + fuzzy: bool = False, + ) -> List[Dict[str, Any]]: + """Return items matching the search term using the configured strategy.""" + if not search_term: + return list(data) + + search_lower = search_term.lower() + results: List[Dict[str, Any]] = [] + + for item in data: + if options.get("filename", True): + candidate = item.get("file_name", "") + if self._matches(candidate, search_term, search_lower, fuzzy): + results.append(item) + continue + + if options.get("modelname", True): + candidate = item.get("model_name", "") + if self._matches(candidate, search_term, search_lower, fuzzy): + results.append(item) + continue + + if options.get("tags", False): + tags = item.get("tags", []) or [] + if any(self._matches(tag, search_term, search_lower, fuzzy) for tag in tags): + results.append(item) + continue + + if options.get("creator", False): + creator_username = "" + civitai = item.get("civitai") + if isinstance(civitai, dict): + creator = civitai.get("creator") + if isinstance(creator, dict): + creator_username = creator.get("username", "") + if creator_username and self._matches(creator_username, search_term, search_lower, fuzzy): + results.append(item) + continue + + return results + + def _matches(self, candidate: str, search_term: str, search_lower: str, fuzzy: bool) -> bool: + if not candidate: + return False + + candidate_lower = candidate.lower() + if fuzzy: + return self._fuzzy_match(candidate, search_term) + return search_lower in candidate_lower diff --git a/tests/services/test_base_model_service.py b/tests/services/test_base_model_service.py new file mode 100644 index 00000000..4acfcc49 --- /dev/null +++ b/tests/services/test_base_model_service.py @@ -0,0 +1,269 @@ +import pytest + +from py.services.base_model_service import BaseModelService +from py.services.model_query import ModelCacheRepository, ModelFilterSet, SearchStrategy, SortParams +from py.utils.models import BaseModelMetadata + + +class StubSettings: + def __init__(self, values): + self._values = dict(values) + + def get(self, key, default=None): + return self._values.get(key, default) + + +class DummyService(BaseModelService): + async def format_response(self, model_data): + return model_data + + +class StubRepository: + def __init__(self, data): + self._data = list(data) + self.parse_sort_calls = [] + self.fetch_sorted_calls = [] + + def parse_sort(self, sort_by): + params = ModelCacheRepository.parse_sort(sort_by) + self.parse_sort_calls.append(sort_by) + return params + + async def fetch_sorted(self, params): + self.fetch_sorted_calls.append(params) + return list(self._data) + + +class StubFilterSet: + def __init__(self, result): + self.result = list(result) + self.calls = [] + + def apply(self, data, criteria): + self.calls.append((list(data), criteria)) + return list(self.result) + + +class StubSearchStrategy: + def __init__(self, search_result): + self.search_result = list(search_result) + self.normalize_calls = [] + self.apply_calls = [] + + def normalize_options(self, options): + self.normalize_calls.append(options) + normalized = {"recursive": True} + if options: + normalized.update(options) + return normalized + + def apply(self, data, search_term, options, fuzzy): + self.apply_calls.append((list(data), search_term, options, fuzzy)) + return list(self.search_result) + + +@pytest.mark.asyncio +async def test_get_paginated_data_uses_injected_collaborators(): + data = [ + {"model_name": "Alpha", "folder": "root"}, + {"model_name": "Beta", "folder": "root"}, + ] + repository = StubRepository(data) + filter_set = StubFilterSet([{"model_name": "Filtered"}]) + search_strategy = StubSearchStrategy([{"model_name": "SearchResult"}]) + settings = StubSettings({}) + + service = DummyService( + model_type="stub", + scanner=object(), + metadata_class=BaseModelMetadata, + cache_repository=repository, + filter_set=filter_set, + search_strategy=search_strategy, + settings_provider=settings, + ) + + response = await service.get_paginated_data( + page=1, + page_size=5, + sort_by="name:desc", + folder="root", + search="query", + fuzzy_search=True, + base_models=["base"], + tags=["tag"], + search_options={"recursive": False}, + favorites_only=True, + ) + + assert repository.parse_sort_calls == ["name:desc"] + assert repository.fetch_sorted_calls and isinstance(repository.fetch_sorted_calls[0], SortParams) + sort_params = repository.fetch_sorted_calls[0] + assert sort_params.key == "name" and sort_params.order == "desc" + + assert filter_set.calls, "FilterSet should be invoked" + call_data, criteria = filter_set.calls[0] + assert call_data == data + assert criteria.folder == "root" + assert criteria.base_models == ["base"] + assert criteria.tags == ["tag"] + assert criteria.favorites_only is True + assert criteria.search_options.get("recursive") is False + + assert search_strategy.normalize_calls == [{"recursive": False}, {"recursive": False}] + assert search_strategy.apply_calls == [([{"model_name": "Filtered"}], "query", {"recursive": False}, True)] + + assert response["items"] == search_strategy.search_result + assert response["total"] == len(search_strategy.search_result) + assert response["page"] == 1 + assert response["page_size"] == 5 + + +class FakeCache: + def __init__(self, items): + self.items = list(items) + + async def get_sorted_data(self, sort_key, order): + if sort_key == "name": + data = sorted(self.items, key=lambda x: x["model_name"].lower()) + if order == "desc": + data.reverse() + else: + data = list(self.items) + return data + + +class FakeScanner: + def __init__(self, cache): + self._cache = cache + + async def get_cached_data(self, *_, **__): + return self._cache + + +@pytest.mark.asyncio +async def test_get_paginated_data_filters_and_searches_combination(): + items = [ + { + "model_name": "Alpha", + "file_name": "alpha.safetensors", + "folder": "root/sub", + "tags": ["tag1"], + "base_model": "v1", + "favorite": True, + "preview_nsfw_level": 0, + }, + { + "model_name": "Beta", + "file_name": "beta.safetensors", + "folder": "root", + "tags": ["tag2"], + "base_model": "v2", + "favorite": False, + "preview_nsfw_level": 999, + }, + { + "model_name": "Gamma", + "file_name": "gamma.safetensors", + "folder": "root/sub2", + "tags": ["tag1", "tag3"], + "base_model": "v1", + "favorite": True, + "preview_nsfw_level": 0, + "civitai": {"creator": {"username": "artist"}}, + }, + ] + + cache = FakeCache(items) + scanner = FakeScanner(cache) + settings = StubSettings({"show_only_sfw": True}) + + service = DummyService( + model_type="stub", + scanner=scanner, + metadata_class=BaseModelMetadata, + cache_repository=ModelCacheRepository(scanner), + filter_set=ModelFilterSet(settings), + search_strategy=SearchStrategy(), + settings_provider=settings, + ) + + response = await service.get_paginated_data( + page=1, + page_size=1, + sort_by="name:asc", + folder="root", + search="artist", + base_models=["v1"], + tags=["tag1"], + search_options={"creator": True, "tags": True}, + favorites_only=True, + ) + + assert response["items"] == [items[2]] + assert response["total"] == 1 + assert response["page"] == 1 + assert response["page_size"] == 1 + assert response["total_pages"] == 1 + + +class PassThroughFilterSet: + def __init__(self): + self.calls = [] + + def apply(self, data, criteria): + self.calls.append(criteria) + return list(data) + + +class NoSearchStrategy: + def __init__(self): + self.normalize_calls = [] + self.apply_called = False + + def normalize_options(self, options): + self.normalize_calls.append(options) + return {"recursive": True} + + def apply(self, *args, **kwargs): + self.apply_called = True + pytest.fail("Search should not be invoked when no search term is provided") + + +@pytest.mark.asyncio +async def test_get_paginated_data_paginates_without_search(): + items = [ + {"model_name": name, "folder": "root"} + for name in ["Alpha", "Beta", "Gamma", "Delta", "Epsilon"] + ] + + repository = StubRepository(items) + filter_set = PassThroughFilterSet() + search_strategy = NoSearchStrategy() + settings = StubSettings({}) + + service = DummyService( + model_type="stub", + scanner=object(), + metadata_class=BaseModelMetadata, + cache_repository=repository, + filter_set=filter_set, + search_strategy=search_strategy, + settings_provider=settings, + ) + + response = await service.get_paginated_data( + page=2, + page_size=2, + sort_by="name:asc", + ) + + assert repository.parse_sort_calls == ["name:asc"] + assert len(repository.fetch_sorted_calls) == 1 + assert filter_set.calls and filter_set.calls[0].favorites_only is False + assert search_strategy.apply_called is False + assert response["items"] == items[2:4] + assert response["total"] == len(items) + assert response["page"] == 2 + assert response["page_size"] == 2 + assert response["total_pages"] == 3