refactor: enhance BaseModelService initialization and improve filtering logic

This commit is contained in:
Will Miao
2025-09-21 23:13:30 +08:00
parent 49e03d658b
commit 2d00cfdd31
3 changed files with 552 additions and 174 deletions

View File

@@ -5,98 +5,88 @@ import os
from ..utils.models import BaseModelMetadata
from ..utils.routes_common import ModelRouteUtils
from ..utils.constants import NSFW_LEVELS
from .settings_manager import settings
from ..utils.utils import fuzzy_match
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
from .settings_manager import settings as default_settings
logger = logging.getLogger(__name__)
class BaseModelService(ABC):
"""Base service class for all model types"""
def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]):
"""Initialize the service
def __init__(
self,
model_type: str,
scanner,
metadata_class: Type[BaseModelMetadata],
*,
cache_repository: Optional[ModelCacheRepository] = None,
filter_set: Optional[ModelFilterSet] = None,
search_strategy: Optional[SearchStrategy] = None,
settings_provider: Optional[SettingsProvider] = None,
):
"""Initialize the service.
Args:
model_type: Type of model (lora, checkpoint, etc.)
scanner: Model scanner instance
metadata_class: Metadata class for this model type
model_type: Type of model (lora, checkpoint, etc.).
scanner: Model scanner instance.
metadata_class: Metadata class for this model type.
cache_repository: Custom repository for cache access (primarily for tests).
filter_set: Filter component controlling folder/tag/favorites logic.
search_strategy: Search component for fuzzy/text matching.
settings_provider: Settings object; defaults to the global settings manager.
"""
self.model_type = model_type
self.scanner = scanner
self.metadata_class = metadata_class
self.settings = settings_provider or default_settings
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
self.filter_set = filter_set or ModelFilterSet(self.settings)
self.search_strategy = search_strategy or SearchStrategy()
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
folder: str = None, search: str = None, fuzzy_search: bool = False,
base_models: list = None, tags: list = None,
search_options: dict = None, hash_filters: dict = None,
favorites_only: bool = False, **kwargs) -> Dict:
"""Get paginated and filtered model data
Args:
page: Page number (1-based)
page_size: Number of items per page
sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc'
folder: Folder filter
search: Search term
fuzzy_search: Whether to use fuzzy search
base_models: List of base models to filter by
tags: List of tags to filter by
search_options: Search options dict
hash_filters: Hash filtering options
favorites_only: Filter for favorites only
**kwargs: Additional model-specific filters
Returns:
Dict containing paginated results
"""
cache = await self.scanner.get_cached_data()
async def get_paginated_data(
self,
page: int,
page_size: int,
sort_by: str = 'name',
folder: str = None,
search: str = None,
fuzzy_search: bool = False,
base_models: list = None,
tags: list = None,
search_options: dict = None,
hash_filters: dict = None,
favorites_only: bool = False,
**kwargs,
) -> Dict:
"""Get paginated and filtered model data"""
sort_params = self.cache_repository.parse_sort(sort_by)
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
# Parse sort_by into sort_key and order
if ':' in sort_by:
sort_key, order = sort_by.split(':', 1)
sort_key = sort_key.strip()
order = order.strip().lower()
if order not in ('asc', 'desc'):
order = 'asc'
else:
sort_key = sort_by.strip()
order = 'asc'
# Get default search options if not provided
if search_options is None:
search_options = {
'filename': True,
'modelname': True,
'tags': False,
'recursive': True,
}
# Get the base data set using new sort logic
filtered_data = await cache.get_sorted_data(sort_key, order)
# Apply hash filtering if provided (highest priority)
if hash_filters:
filtered_data = await self._apply_hash_filters(filtered_data, hash_filters)
# Jump to pagination for hash filters
filtered_data = await self._apply_hash_filters(sorted_data, hash_filters)
return self._paginate(filtered_data, page, page_size)
# Apply common filters
filtered_data = await self._apply_common_filters(
filtered_data, folder, base_models, tags, favorites_only, search_options
sorted_data,
folder=folder,
base_models=base_models,
tags=tags,
favorites_only=favorites_only,
search_options=search_options,
)
# Apply search filtering
if search:
filtered_data = await self._apply_search_filters(
filtered_data, search, fuzzy_search, search_options
filtered_data,
search,
fuzzy_search,
search_options,
)
# Apply model-specific filters
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
return self._paginate(filtered_data, page, page_size)
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
"""Apply hash-based filtering"""
@@ -120,113 +110,36 @@ class BaseModelService(ABC):
return data
async def _apply_common_filters(self, data: List[Dict], folder: str = None,
base_models: list = None, tags: list = None,
favorites_only: bool = False, search_options: dict = None) -> List[Dict]:
async def _apply_common_filters(
self,
data: List[Dict],
folder: str = None,
base_models: list = None,
tags: list = None,
favorites_only: bool = False,
search_options: dict = None,
) -> List[Dict]:
"""Apply common filters that work across all model types"""
# Apply SFW filtering if enabled in settings
if settings.get('show_only_sfw', False):
data = [
item for item in data
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
]
# Apply favorites filtering if enabled
if favorites_only:
data = [
item for item in data
if item.get('favorite', False) is True
]
# Apply folder filtering
if folder is not None:
if search_options and search_options.get('recursive', True):
# Recursive folder filtering - include all subfolders
# Ensure we match exact folder or its subfolders by checking path boundaries
if folder == "":
# Empty folder means root - include all items
pass # Don't filter anything
else:
# Add trailing slash to ensure we match folder boundaries correctly
folder_with_separator = folder + "/"
data = [
item for item in data
if (item['folder'] == folder or
item['folder'].startswith(folder_with_separator))
]
else:
# Exact folder filtering
data = [
item for item in data
if item['folder'] == folder
]
# Apply base model filtering
if base_models and len(base_models) > 0:
data = [
item for item in data
if item.get('base_model') in base_models
]
# Apply tag filtering
if tags and len(tags) > 0:
data = [
item for item in data
if any(tag in item.get('tags', []) for tag in tags)
]
return data
normalized_options = self.search_strategy.normalize_options(search_options)
criteria = FilterCriteria(
folder=folder,
base_models=base_models,
tags=tags,
favorites_only=favorites_only,
search_options=normalized_options,
)
return self.filter_set.apply(data, criteria)
async def _apply_search_filters(self, data: List[Dict], search: str,
fuzzy_search: bool, search_options: dict) -> List[Dict]:
async def _apply_search_filters(
self,
data: List[Dict],
search: str,
fuzzy_search: bool,
search_options: dict,
) -> List[Dict]:
"""Apply search filtering"""
search_results = []
for item in data:
# Search by file name
if search_options.get('filename', True):
if fuzzy_search:
if fuzzy_match(item.get('file_name', ''), search):
search_results.append(item)
continue
elif search.lower() in item.get('file_name', '').lower():
search_results.append(item)
continue
# Search by model name
if search_options.get('modelname', True):
if fuzzy_search:
if fuzzy_match(item.get('model_name', ''), search):
search_results.append(item)
continue
elif search.lower() in item.get('model_name', '').lower():
search_results.append(item)
continue
# Search by tags
if search_options.get('tags', False) and 'tags' in item:
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower())
for tag in item['tags']):
search_results.append(item)
continue
# Search by creator
civitai = item.get('civitai')
creator_username = ''
if civitai and isinstance(civitai, dict):
creator = civitai.get('creator')
if creator and isinstance(creator, dict):
creator_username = creator.get('username', '')
if search_options.get('creator', False) and creator_username:
if fuzzy_search:
if fuzzy_match(creator_username, search):
search_results.append(item)
continue
elif search.lower() in creator_username.lower():
search_results.append(item)
continue
return search_results
normalized_options = self.search_strategy.normalize_options(search_options)
return self.search_strategy.apply(data, search, normalized_options, fuzzy_search)
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
"""Apply model-specific filters - to be overridden by subclasses if needed"""

196
py/services/model_query.py Normal file
View File

@@ -0,0 +1,196 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match as default_fuzzy_match
class SettingsProvider(Protocol):
"""Protocol describing the SettingsManager contract used by query helpers."""
def get(self, key: str, default: Any = None) -> Any:
...
@dataclass(frozen=True)
class SortParams:
"""Normalized representation of sorting instructions."""
key: str
order: str
@dataclass(frozen=True)
class FilterCriteria:
"""Container for model list filtering options."""
folder: Optional[str] = None
base_models: Optional[Sequence[str]] = None
tags: Optional[Sequence[str]] = None
favorites_only: bool = False
search_options: Optional[Dict[str, Any]] = None
class ModelCacheRepository:
"""Adapter around scanner cache access and sort normalisation."""
def __init__(self, scanner) -> None:
self._scanner = scanner
async def get_cache(self):
"""Return the underlying cache instance from the scanner."""
return await self._scanner.get_cached_data()
async def fetch_sorted(self, params: SortParams) -> List[Dict[str, Any]]:
"""Fetch cached data pre-sorted according to ``params``."""
cache = await self.get_cache()
return await cache.get_sorted_data(params.key, params.order)
@staticmethod
def parse_sort(sort_by: str) -> SortParams:
"""Parse an incoming sort string into key/order primitives."""
if not sort_by:
return SortParams(key="name", order="asc")
if ":" in sort_by:
raw_key, raw_order = sort_by.split(":", 1)
sort_key = raw_key.strip().lower() or "name"
order = raw_order.strip().lower()
else:
sort_key = sort_by.strip().lower() or "name"
order = "asc"
if order not in ("asc", "desc"):
order = "asc"
return SortParams(key=sort_key, order=order)
class ModelFilterSet:
"""Applies common filtering rules to the model collection."""
def __init__(self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None) -> None:
self._settings = settings
self._nsfw_levels = nsfw_levels or NSFW_LEVELS
def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]:
"""Return items that satisfy the provided criteria."""
items = list(data)
if self._settings.get("show_only_sfw", False):
threshold = self._nsfw_levels.get("R", 0)
items = [
item for item in items
if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold
]
if criteria.favorites_only:
items = [item for item in items if item.get("favorite", False)]
folder = criteria.folder
options = criteria.search_options or {}
recursive = bool(options.get("recursive", True))
if folder is not None:
if recursive:
if folder:
folder_with_sep = f"{folder}/"
items = [
item for item in items
if item.get("folder") == folder or item.get("folder", "").startswith(folder_with_sep)
]
else:
items = [item for item in items if item.get("folder") == folder]
base_models = criteria.base_models or []
if base_models:
base_model_set = set(base_models)
items = [item for item in items if item.get("base_model") in base_model_set]
tags = criteria.tags or []
if tags:
tag_set = set(tags)
items = [
item for item in items
if any(tag in tag_set for tag in item.get("tags", []))
]
return items
class SearchStrategy:
"""Encapsulates text and fuzzy matching behaviour for model queries."""
DEFAULT_OPTIONS: Dict[str, Any] = {
"filename": True,
"modelname": True,
"tags": False,
"recursive": True,
"creator": False,
}
def __init__(self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None) -> None:
self._fuzzy_match = fuzzy_matcher or default_fuzzy_match
def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""Merge provided options with defaults without mutating input."""
normalized = dict(self.DEFAULT_OPTIONS)
if options:
normalized.update(options)
return normalized
def apply(
self,
data: Iterable[Dict[str, Any]],
search_term: str,
options: Dict[str, Any],
fuzzy: bool = False,
) -> List[Dict[str, Any]]:
"""Return items matching the search term using the configured strategy."""
if not search_term:
return list(data)
search_lower = search_term.lower()
results: List[Dict[str, Any]] = []
for item in data:
if options.get("filename", True):
candidate = item.get("file_name", "")
if self._matches(candidate, search_term, search_lower, fuzzy):
results.append(item)
continue
if options.get("modelname", True):
candidate = item.get("model_name", "")
if self._matches(candidate, search_term, search_lower, fuzzy):
results.append(item)
continue
if options.get("tags", False):
tags = item.get("tags", []) or []
if any(self._matches(tag, search_term, search_lower, fuzzy) for tag in tags):
results.append(item)
continue
if options.get("creator", False):
creator_username = ""
civitai = item.get("civitai")
if isinstance(civitai, dict):
creator = civitai.get("creator")
if isinstance(creator, dict):
creator_username = creator.get("username", "")
if creator_username and self._matches(creator_username, search_term, search_lower, fuzzy):
results.append(item)
continue
return results
def _matches(self, candidate: str, search_term: str, search_lower: str, fuzzy: bool) -> bool:
if not candidate:
return False
candidate_lower = candidate.lower()
if fuzzy:
return self._fuzzy_match(candidate, search_term)
return search_lower in candidate_lower