mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
refactor: enhance BaseModelService initialization and improve filtering logic
This commit is contained in:
@@ -5,99 +5,89 @@ import os
|
|||||||
|
|
||||||
from ..utils.models import BaseModelMetadata
|
from ..utils.models import BaseModelMetadata
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
from ..utils.constants import NSFW_LEVELS
|
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
|
||||||
from .settings_manager import settings
|
from .settings_manager import settings as default_settings
|
||||||
from ..utils.utils import fuzzy_match
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class BaseModelService(ABC):
|
class BaseModelService(ABC):
|
||||||
"""Base service class for all model types"""
|
"""Base service class for all model types"""
|
||||||
|
|
||||||
def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]):
|
def __init__(
|
||||||
"""Initialize the service
|
self,
|
||||||
|
model_type: str,
|
||||||
|
scanner,
|
||||||
|
metadata_class: Type[BaseModelMetadata],
|
||||||
|
*,
|
||||||
|
cache_repository: Optional[ModelCacheRepository] = None,
|
||||||
|
filter_set: Optional[ModelFilterSet] = None,
|
||||||
|
search_strategy: Optional[SearchStrategy] = None,
|
||||||
|
settings_provider: Optional[SettingsProvider] = None,
|
||||||
|
):
|
||||||
|
"""Initialize the service.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_type: Type of model (lora, checkpoint, etc.)
|
model_type: Type of model (lora, checkpoint, etc.).
|
||||||
scanner: Model scanner instance
|
scanner: Model scanner instance.
|
||||||
metadata_class: Metadata class for this model type
|
metadata_class: Metadata class for this model type.
|
||||||
|
cache_repository: Custom repository for cache access (primarily for tests).
|
||||||
|
filter_set: Filter component controlling folder/tag/favorites logic.
|
||||||
|
search_strategy: Search component for fuzzy/text matching.
|
||||||
|
settings_provider: Settings object; defaults to the global settings manager.
|
||||||
"""
|
"""
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.scanner = scanner
|
self.scanner = scanner
|
||||||
self.metadata_class = metadata_class
|
self.metadata_class = metadata_class
|
||||||
|
self.settings = settings_provider or default_settings
|
||||||
|
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
|
||||||
|
self.filter_set = filter_set or ModelFilterSet(self.settings)
|
||||||
|
self.search_strategy = search_strategy or SearchStrategy()
|
||||||
|
|
||||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
async def get_paginated_data(
|
||||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
self,
|
||||||
base_models: list = None, tags: list = None,
|
page: int,
|
||||||
search_options: dict = None, hash_filters: dict = None,
|
page_size: int,
|
||||||
favorites_only: bool = False, **kwargs) -> Dict:
|
sort_by: str = 'name',
|
||||||
"""Get paginated and filtered model data
|
folder: str = None,
|
||||||
|
search: str = None,
|
||||||
|
fuzzy_search: bool = False,
|
||||||
|
base_models: list = None,
|
||||||
|
tags: list = None,
|
||||||
|
search_options: dict = None,
|
||||||
|
hash_filters: dict = None,
|
||||||
|
favorites_only: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict:
|
||||||
|
"""Get paginated and filtered model data"""
|
||||||
|
sort_params = self.cache_repository.parse_sort(sort_by)
|
||||||
|
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
|
||||||
|
|
||||||
Args:
|
|
||||||
page: Page number (1-based)
|
|
||||||
page_size: Number of items per page
|
|
||||||
sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc'
|
|
||||||
folder: Folder filter
|
|
||||||
search: Search term
|
|
||||||
fuzzy_search: Whether to use fuzzy search
|
|
||||||
base_models: List of base models to filter by
|
|
||||||
tags: List of tags to filter by
|
|
||||||
search_options: Search options dict
|
|
||||||
hash_filters: Hash filtering options
|
|
||||||
favorites_only: Filter for favorites only
|
|
||||||
**kwargs: Additional model-specific filters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict containing paginated results
|
|
||||||
"""
|
|
||||||
cache = await self.scanner.get_cached_data()
|
|
||||||
|
|
||||||
# Parse sort_by into sort_key and order
|
|
||||||
if ':' in sort_by:
|
|
||||||
sort_key, order = sort_by.split(':', 1)
|
|
||||||
sort_key = sort_key.strip()
|
|
||||||
order = order.strip().lower()
|
|
||||||
if order not in ('asc', 'desc'):
|
|
||||||
order = 'asc'
|
|
||||||
else:
|
|
||||||
sort_key = sort_by.strip()
|
|
||||||
order = 'asc'
|
|
||||||
|
|
||||||
# Get default search options if not provided
|
|
||||||
if search_options is None:
|
|
||||||
search_options = {
|
|
||||||
'filename': True,
|
|
||||||
'modelname': True,
|
|
||||||
'tags': False,
|
|
||||||
'recursive': True,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get the base data set using new sort logic
|
|
||||||
filtered_data = await cache.get_sorted_data(sort_key, order)
|
|
||||||
|
|
||||||
# Apply hash filtering if provided (highest priority)
|
|
||||||
if hash_filters:
|
if hash_filters:
|
||||||
filtered_data = await self._apply_hash_filters(filtered_data, hash_filters)
|
filtered_data = await self._apply_hash_filters(sorted_data, hash_filters)
|
||||||
|
|
||||||
# Jump to pagination for hash filters
|
|
||||||
return self._paginate(filtered_data, page, page_size)
|
return self._paginate(filtered_data, page, page_size)
|
||||||
|
|
||||||
# Apply common filters
|
|
||||||
filtered_data = await self._apply_common_filters(
|
filtered_data = await self._apply_common_filters(
|
||||||
filtered_data, folder, base_models, tags, favorites_only, search_options
|
sorted_data,
|
||||||
|
folder=folder,
|
||||||
|
base_models=base_models,
|
||||||
|
tags=tags,
|
||||||
|
favorites_only=favorites_only,
|
||||||
|
search_options=search_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply search filtering
|
|
||||||
if search:
|
if search:
|
||||||
filtered_data = await self._apply_search_filters(
|
filtered_data = await self._apply_search_filters(
|
||||||
filtered_data, search, fuzzy_search, search_options
|
filtered_data,
|
||||||
|
search,
|
||||||
|
fuzzy_search,
|
||||||
|
search_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply model-specific filters
|
|
||||||
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||||
|
|
||||||
return self._paginate(filtered_data, page, page_size)
|
return self._paginate(filtered_data, page, page_size)
|
||||||
|
|
||||||
|
|
||||||
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
||||||
"""Apply hash-based filtering"""
|
"""Apply hash-based filtering"""
|
||||||
single_hash = hash_filters.get('single_hash')
|
single_hash = hash_filters.get('single_hash')
|
||||||
@@ -120,113 +110,36 @@ class BaseModelService(ABC):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def _apply_common_filters(self, data: List[Dict], folder: str = None,
|
async def _apply_common_filters(
|
||||||
base_models: list = None, tags: list = None,
|
self,
|
||||||
favorites_only: bool = False, search_options: dict = None) -> List[Dict]:
|
data: List[Dict],
|
||||||
|
folder: str = None,
|
||||||
|
base_models: list = None,
|
||||||
|
tags: list = None,
|
||||||
|
favorites_only: bool = False,
|
||||||
|
search_options: dict = None,
|
||||||
|
) -> List[Dict]:
|
||||||
"""Apply common filters that work across all model types"""
|
"""Apply common filters that work across all model types"""
|
||||||
# Apply SFW filtering if enabled in settings
|
normalized_options = self.search_strategy.normalize_options(search_options)
|
||||||
if settings.get('show_only_sfw', False):
|
criteria = FilterCriteria(
|
||||||
data = [
|
folder=folder,
|
||||||
item for item in data
|
base_models=base_models,
|
||||||
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
tags=tags,
|
||||||
]
|
favorites_only=favorites_only,
|
||||||
|
search_options=normalized_options,
|
||||||
|
)
|
||||||
|
return self.filter_set.apply(data, criteria)
|
||||||
|
|
||||||
# Apply favorites filtering if enabled
|
async def _apply_search_filters(
|
||||||
if favorites_only:
|
self,
|
||||||
data = [
|
data: List[Dict],
|
||||||
item for item in data
|
search: str,
|
||||||
if item.get('favorite', False) is True
|
fuzzy_search: bool,
|
||||||
]
|
search_options: dict,
|
||||||
|
) -> List[Dict]:
|
||||||
# Apply folder filtering
|
|
||||||
if folder is not None:
|
|
||||||
if search_options and search_options.get('recursive', True):
|
|
||||||
# Recursive folder filtering - include all subfolders
|
|
||||||
# Ensure we match exact folder or its subfolders by checking path boundaries
|
|
||||||
if folder == "":
|
|
||||||
# Empty folder means root - include all items
|
|
||||||
pass # Don't filter anything
|
|
||||||
else:
|
|
||||||
# Add trailing slash to ensure we match folder boundaries correctly
|
|
||||||
folder_with_separator = folder + "/"
|
|
||||||
data = [
|
|
||||||
item for item in data
|
|
||||||
if (item['folder'] == folder or
|
|
||||||
item['folder'].startswith(folder_with_separator))
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
# Exact folder filtering
|
|
||||||
data = [
|
|
||||||
item for item in data
|
|
||||||
if item['folder'] == folder
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply base model filtering
|
|
||||||
if base_models and len(base_models) > 0:
|
|
||||||
data = [
|
|
||||||
item for item in data
|
|
||||||
if item.get('base_model') in base_models
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply tag filtering
|
|
||||||
if tags and len(tags) > 0:
|
|
||||||
data = [
|
|
||||||
item for item in data
|
|
||||||
if any(tag in item.get('tags', []) for tag in tags)
|
|
||||||
]
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
async def _apply_search_filters(self, data: List[Dict], search: str,
|
|
||||||
fuzzy_search: bool, search_options: dict) -> List[Dict]:
|
|
||||||
"""Apply search filtering"""
|
"""Apply search filtering"""
|
||||||
search_results = []
|
normalized_options = self.search_strategy.normalize_options(search_options)
|
||||||
|
return self.search_strategy.apply(data, search, normalized_options, fuzzy_search)
|
||||||
for item in data:
|
|
||||||
# Search by file name
|
|
||||||
if search_options.get('filename', True):
|
|
||||||
if fuzzy_search:
|
|
||||||
if fuzzy_match(item.get('file_name', ''), search):
|
|
||||||
search_results.append(item)
|
|
||||||
continue
|
|
||||||
elif search.lower() in item.get('file_name', '').lower():
|
|
||||||
search_results.append(item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Search by model name
|
|
||||||
if search_options.get('modelname', True):
|
|
||||||
if fuzzy_search:
|
|
||||||
if fuzzy_match(item.get('model_name', ''), search):
|
|
||||||
search_results.append(item)
|
|
||||||
continue
|
|
||||||
elif search.lower() in item.get('model_name', '').lower():
|
|
||||||
search_results.append(item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Search by tags
|
|
||||||
if search_options.get('tags', False) and 'tags' in item:
|
|
||||||
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower())
|
|
||||||
for tag in item['tags']):
|
|
||||||
search_results.append(item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Search by creator
|
|
||||||
civitai = item.get('civitai')
|
|
||||||
creator_username = ''
|
|
||||||
if civitai and isinstance(civitai, dict):
|
|
||||||
creator = civitai.get('creator')
|
|
||||||
if creator and isinstance(creator, dict):
|
|
||||||
creator_username = creator.get('username', '')
|
|
||||||
if search_options.get('creator', False) and creator_username:
|
|
||||||
if fuzzy_search:
|
|
||||||
if fuzzy_match(creator_username, search):
|
|
||||||
search_results.append(item)
|
|
||||||
continue
|
|
||||||
elif search.lower() in creator_username.lower():
|
|
||||||
search_results.append(item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
return search_results
|
|
||||||
|
|
||||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||||
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||||
|
|||||||
196
py/services/model_query.py
Normal file
196
py/services/model_query.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable
|
||||||
|
|
||||||
|
from ..utils.constants import NSFW_LEVELS
|
||||||
|
from ..utils.utils import fuzzy_match as default_fuzzy_match
|
||||||
|
|
||||||
|
|
||||||
|
class SettingsProvider(Protocol):
|
||||||
|
"""Protocol describing the SettingsManager contract used by query helpers."""
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SortParams:
|
||||||
|
"""Normalized representation of sorting instructions."""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
order: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FilterCriteria:
|
||||||
|
"""Container for model list filtering options."""
|
||||||
|
|
||||||
|
folder: Optional[str] = None
|
||||||
|
base_models: Optional[Sequence[str]] = None
|
||||||
|
tags: Optional[Sequence[str]] = None
|
||||||
|
favorites_only: bool = False
|
||||||
|
search_options: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCacheRepository:
|
||||||
|
"""Adapter around scanner cache access and sort normalisation."""
|
||||||
|
|
||||||
|
def __init__(self, scanner) -> None:
|
||||||
|
self._scanner = scanner
|
||||||
|
|
||||||
|
async def get_cache(self):
|
||||||
|
"""Return the underlying cache instance from the scanner."""
|
||||||
|
return await self._scanner.get_cached_data()
|
||||||
|
|
||||||
|
async def fetch_sorted(self, params: SortParams) -> List[Dict[str, Any]]:
|
||||||
|
"""Fetch cached data pre-sorted according to ``params``."""
|
||||||
|
cache = await self.get_cache()
|
||||||
|
return await cache.get_sorted_data(params.key, params.order)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_sort(sort_by: str) -> SortParams:
|
||||||
|
"""Parse an incoming sort string into key/order primitives."""
|
||||||
|
if not sort_by:
|
||||||
|
return SortParams(key="name", order="asc")
|
||||||
|
|
||||||
|
if ":" in sort_by:
|
||||||
|
raw_key, raw_order = sort_by.split(":", 1)
|
||||||
|
sort_key = raw_key.strip().lower() or "name"
|
||||||
|
order = raw_order.strip().lower()
|
||||||
|
else:
|
||||||
|
sort_key = sort_by.strip().lower() or "name"
|
||||||
|
order = "asc"
|
||||||
|
|
||||||
|
if order not in ("asc", "desc"):
|
||||||
|
order = "asc"
|
||||||
|
|
||||||
|
return SortParams(key=sort_key, order=order)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFilterSet:
|
||||||
|
"""Applies common filtering rules to the model collection."""
|
||||||
|
|
||||||
|
def __init__(self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None) -> None:
|
||||||
|
self._settings = settings
|
||||||
|
self._nsfw_levels = nsfw_levels or NSFW_LEVELS
|
||||||
|
|
||||||
|
def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]:
|
||||||
|
"""Return items that satisfy the provided criteria."""
|
||||||
|
items = list(data)
|
||||||
|
|
||||||
|
if self._settings.get("show_only_sfw", False):
|
||||||
|
threshold = self._nsfw_levels.get("R", 0)
|
||||||
|
items = [
|
||||||
|
item for item in items
|
||||||
|
if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold
|
||||||
|
]
|
||||||
|
|
||||||
|
if criteria.favorites_only:
|
||||||
|
items = [item for item in items if item.get("favorite", False)]
|
||||||
|
|
||||||
|
folder = criteria.folder
|
||||||
|
options = criteria.search_options or {}
|
||||||
|
recursive = bool(options.get("recursive", True))
|
||||||
|
if folder is not None:
|
||||||
|
if recursive:
|
||||||
|
if folder:
|
||||||
|
folder_with_sep = f"{folder}/"
|
||||||
|
items = [
|
||||||
|
item for item in items
|
||||||
|
if item.get("folder") == folder or item.get("folder", "").startswith(folder_with_sep)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
items = [item for item in items if item.get("folder") == folder]
|
||||||
|
|
||||||
|
base_models = criteria.base_models or []
|
||||||
|
if base_models:
|
||||||
|
base_model_set = set(base_models)
|
||||||
|
items = [item for item in items if item.get("base_model") in base_model_set]
|
||||||
|
|
||||||
|
tags = criteria.tags or []
|
||||||
|
if tags:
|
||||||
|
tag_set = set(tags)
|
||||||
|
items = [
|
||||||
|
item for item in items
|
||||||
|
if any(tag in tag_set for tag in item.get("tags", []))
|
||||||
|
]
|
||||||
|
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
class SearchStrategy:
|
||||||
|
"""Encapsulates text and fuzzy matching behaviour for model queries."""
|
||||||
|
|
||||||
|
DEFAULT_OPTIONS: Dict[str, Any] = {
|
||||||
|
"filename": True,
|
||||||
|
"modelname": True,
|
||||||
|
"tags": False,
|
||||||
|
"recursive": True,
|
||||||
|
"creator": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None) -> None:
|
||||||
|
self._fuzzy_match = fuzzy_matcher or default_fuzzy_match
|
||||||
|
|
||||||
|
def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
|
"""Merge provided options with defaults without mutating input."""
|
||||||
|
normalized = dict(self.DEFAULT_OPTIONS)
|
||||||
|
if options:
|
||||||
|
normalized.update(options)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
data: Iterable[Dict[str, Any]],
|
||||||
|
search_term: str,
|
||||||
|
options: Dict[str, Any],
|
||||||
|
fuzzy: bool = False,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Return items matching the search term using the configured strategy."""
|
||||||
|
if not search_term:
|
||||||
|
return list(data)
|
||||||
|
|
||||||
|
search_lower = search_term.lower()
|
||||||
|
results: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
for item in data:
|
||||||
|
if options.get("filename", True):
|
||||||
|
candidate = item.get("file_name", "")
|
||||||
|
if self._matches(candidate, search_term, search_lower, fuzzy):
|
||||||
|
results.append(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if options.get("modelname", True):
|
||||||
|
candidate = item.get("model_name", "")
|
||||||
|
if self._matches(candidate, search_term, search_lower, fuzzy):
|
||||||
|
results.append(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if options.get("tags", False):
|
||||||
|
tags = item.get("tags", []) or []
|
||||||
|
if any(self._matches(tag, search_term, search_lower, fuzzy) for tag in tags):
|
||||||
|
results.append(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if options.get("creator", False):
|
||||||
|
creator_username = ""
|
||||||
|
civitai = item.get("civitai")
|
||||||
|
if isinstance(civitai, dict):
|
||||||
|
creator = civitai.get("creator")
|
||||||
|
if isinstance(creator, dict):
|
||||||
|
creator_username = creator.get("username", "")
|
||||||
|
if creator_username and self._matches(creator_username, search_term, search_lower, fuzzy):
|
||||||
|
results.append(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _matches(self, candidate: str, search_term: str, search_lower: str, fuzzy: bool) -> bool:
|
||||||
|
if not candidate:
|
||||||
|
return False
|
||||||
|
|
||||||
|
candidate_lower = candidate.lower()
|
||||||
|
if fuzzy:
|
||||||
|
return self._fuzzy_match(candidate, search_term)
|
||||||
|
return search_lower in candidate_lower
|
||||||
269
tests/services/test_base_model_service.py
Normal file
269
tests/services/test_base_model_service.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.base_model_service import BaseModelService
|
||||||
|
from py.services.model_query import ModelCacheRepository, ModelFilterSet, SearchStrategy, SortParams
|
||||||
|
from py.utils.models import BaseModelMetadata
|
||||||
|
|
||||||
|
|
||||||
|
class StubSettings:
|
||||||
|
def __init__(self, values):
|
||||||
|
self._values = dict(values)
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
return self._values.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
class DummyService(BaseModelService):
|
||||||
|
async def format_response(self, model_data):
|
||||||
|
return model_data
|
||||||
|
|
||||||
|
|
||||||
|
class StubRepository:
|
||||||
|
def __init__(self, data):
|
||||||
|
self._data = list(data)
|
||||||
|
self.parse_sort_calls = []
|
||||||
|
self.fetch_sorted_calls = []
|
||||||
|
|
||||||
|
def parse_sort(self, sort_by):
|
||||||
|
params = ModelCacheRepository.parse_sort(sort_by)
|
||||||
|
self.parse_sort_calls.append(sort_by)
|
||||||
|
return params
|
||||||
|
|
||||||
|
async def fetch_sorted(self, params):
|
||||||
|
self.fetch_sorted_calls.append(params)
|
||||||
|
return list(self._data)
|
||||||
|
|
||||||
|
|
||||||
|
class StubFilterSet:
|
||||||
|
def __init__(self, result):
|
||||||
|
self.result = list(result)
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
def apply(self, data, criteria):
|
||||||
|
self.calls.append((list(data), criteria))
|
||||||
|
return list(self.result)
|
||||||
|
|
||||||
|
|
||||||
|
class StubSearchStrategy:
|
||||||
|
def __init__(self, search_result):
|
||||||
|
self.search_result = list(search_result)
|
||||||
|
self.normalize_calls = []
|
||||||
|
self.apply_calls = []
|
||||||
|
|
||||||
|
def normalize_options(self, options):
|
||||||
|
self.normalize_calls.append(options)
|
||||||
|
normalized = {"recursive": True}
|
||||||
|
if options:
|
||||||
|
normalized.update(options)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
def apply(self, data, search_term, options, fuzzy):
|
||||||
|
self.apply_calls.append((list(data), search_term, options, fuzzy))
|
||||||
|
return list(self.search_result)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_paginated_data_uses_injected_collaborators():
|
||||||
|
data = [
|
||||||
|
{"model_name": "Alpha", "folder": "root"},
|
||||||
|
{"model_name": "Beta", "folder": "root"},
|
||||||
|
]
|
||||||
|
repository = StubRepository(data)
|
||||||
|
filter_set = StubFilterSet([{"model_name": "Filtered"}])
|
||||||
|
search_strategy = StubSearchStrategy([{"model_name": "SearchResult"}])
|
||||||
|
settings = StubSettings({})
|
||||||
|
|
||||||
|
service = DummyService(
|
||||||
|
model_type="stub",
|
||||||
|
scanner=object(),
|
||||||
|
metadata_class=BaseModelMetadata,
|
||||||
|
cache_repository=repository,
|
||||||
|
filter_set=filter_set,
|
||||||
|
search_strategy=search_strategy,
|
||||||
|
settings_provider=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await service.get_paginated_data(
|
||||||
|
page=1,
|
||||||
|
page_size=5,
|
||||||
|
sort_by="name:desc",
|
||||||
|
folder="root",
|
||||||
|
search="query",
|
||||||
|
fuzzy_search=True,
|
||||||
|
base_models=["base"],
|
||||||
|
tags=["tag"],
|
||||||
|
search_options={"recursive": False},
|
||||||
|
favorites_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert repository.parse_sort_calls == ["name:desc"]
|
||||||
|
assert repository.fetch_sorted_calls and isinstance(repository.fetch_sorted_calls[0], SortParams)
|
||||||
|
sort_params = repository.fetch_sorted_calls[0]
|
||||||
|
assert sort_params.key == "name" and sort_params.order == "desc"
|
||||||
|
|
||||||
|
assert filter_set.calls, "FilterSet should be invoked"
|
||||||
|
call_data, criteria = filter_set.calls[0]
|
||||||
|
assert call_data == data
|
||||||
|
assert criteria.folder == "root"
|
||||||
|
assert criteria.base_models == ["base"]
|
||||||
|
assert criteria.tags == ["tag"]
|
||||||
|
assert criteria.favorites_only is True
|
||||||
|
assert criteria.search_options.get("recursive") is False
|
||||||
|
|
||||||
|
assert search_strategy.normalize_calls == [{"recursive": False}, {"recursive": False}]
|
||||||
|
assert search_strategy.apply_calls == [([{"model_name": "Filtered"}], "query", {"recursive": False}, True)]
|
||||||
|
|
||||||
|
assert response["items"] == search_strategy.search_result
|
||||||
|
assert response["total"] == len(search_strategy.search_result)
|
||||||
|
assert response["page"] == 1
|
||||||
|
assert response["page_size"] == 5
|
||||||
|
|
||||||
|
|
||||||
|
class FakeCache:
|
||||||
|
def __init__(self, items):
|
||||||
|
self.items = list(items)
|
||||||
|
|
||||||
|
async def get_sorted_data(self, sort_key, order):
|
||||||
|
if sort_key == "name":
|
||||||
|
data = sorted(self.items, key=lambda x: x["model_name"].lower())
|
||||||
|
if order == "desc":
|
||||||
|
data.reverse()
|
||||||
|
else:
|
||||||
|
data = list(self.items)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class FakeScanner:
|
||||||
|
def __init__(self, cache):
|
||||||
|
self._cache = cache
|
||||||
|
|
||||||
|
async def get_cached_data(self, *_, **__):
|
||||||
|
return self._cache
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_paginated_data_filters_and_searches_combination():
|
||||||
|
items = [
|
||||||
|
{
|
||||||
|
"model_name": "Alpha",
|
||||||
|
"file_name": "alpha.safetensors",
|
||||||
|
"folder": "root/sub",
|
||||||
|
"tags": ["tag1"],
|
||||||
|
"base_model": "v1",
|
||||||
|
"favorite": True,
|
||||||
|
"preview_nsfw_level": 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "Beta",
|
||||||
|
"file_name": "beta.safetensors",
|
||||||
|
"folder": "root",
|
||||||
|
"tags": ["tag2"],
|
||||||
|
"base_model": "v2",
|
||||||
|
"favorite": False,
|
||||||
|
"preview_nsfw_level": 999,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "Gamma",
|
||||||
|
"file_name": "gamma.safetensors",
|
||||||
|
"folder": "root/sub2",
|
||||||
|
"tags": ["tag1", "tag3"],
|
||||||
|
"base_model": "v1",
|
||||||
|
"favorite": True,
|
||||||
|
"preview_nsfw_level": 0,
|
||||||
|
"civitai": {"creator": {"username": "artist"}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
cache = FakeCache(items)
|
||||||
|
scanner = FakeScanner(cache)
|
||||||
|
settings = StubSettings({"show_only_sfw": True})
|
||||||
|
|
||||||
|
service = DummyService(
|
||||||
|
model_type="stub",
|
||||||
|
scanner=scanner,
|
||||||
|
metadata_class=BaseModelMetadata,
|
||||||
|
cache_repository=ModelCacheRepository(scanner),
|
||||||
|
filter_set=ModelFilterSet(settings),
|
||||||
|
search_strategy=SearchStrategy(),
|
||||||
|
settings_provider=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await service.get_paginated_data(
|
||||||
|
page=1,
|
||||||
|
page_size=1,
|
||||||
|
sort_by="name:asc",
|
||||||
|
folder="root",
|
||||||
|
search="artist",
|
||||||
|
base_models=["v1"],
|
||||||
|
tags=["tag1"],
|
||||||
|
search_options={"creator": True, "tags": True},
|
||||||
|
favorites_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response["items"] == [items[2]]
|
||||||
|
assert response["total"] == 1
|
||||||
|
assert response["page"] == 1
|
||||||
|
assert response["page_size"] == 1
|
||||||
|
assert response["total_pages"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
class PassThroughFilterSet:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
def apply(self, data, criteria):
|
||||||
|
self.calls.append(criteria)
|
||||||
|
return list(data)
|
||||||
|
|
||||||
|
|
||||||
|
class NoSearchStrategy:
|
||||||
|
def __init__(self):
|
||||||
|
self.normalize_calls = []
|
||||||
|
self.apply_called = False
|
||||||
|
|
||||||
|
def normalize_options(self, options):
|
||||||
|
self.normalize_calls.append(options)
|
||||||
|
return {"recursive": True}
|
||||||
|
|
||||||
|
def apply(self, *args, **kwargs):
|
||||||
|
self.apply_called = True
|
||||||
|
pytest.fail("Search should not be invoked when no search term is provided")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_paginated_data_paginates_without_search():
|
||||||
|
items = [
|
||||||
|
{"model_name": name, "folder": "root"}
|
||||||
|
for name in ["Alpha", "Beta", "Gamma", "Delta", "Epsilon"]
|
||||||
|
]
|
||||||
|
|
||||||
|
repository = StubRepository(items)
|
||||||
|
filter_set = PassThroughFilterSet()
|
||||||
|
search_strategy = NoSearchStrategy()
|
||||||
|
settings = StubSettings({})
|
||||||
|
|
||||||
|
service = DummyService(
|
||||||
|
model_type="stub",
|
||||||
|
scanner=object(),
|
||||||
|
metadata_class=BaseModelMetadata,
|
||||||
|
cache_repository=repository,
|
||||||
|
filter_set=filter_set,
|
||||||
|
search_strategy=search_strategy,
|
||||||
|
settings_provider=settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await service.get_paginated_data(
|
||||||
|
page=2,
|
||||||
|
page_size=2,
|
||||||
|
sort_by="name:asc",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert repository.parse_sort_calls == ["name:asc"]
|
||||||
|
assert len(repository.fetch_sorted_calls) == 1
|
||||||
|
assert filter_set.calls and filter_set.calls[0].favorites_only is False
|
||||||
|
assert search_strategy.apply_called is False
|
||||||
|
assert response["items"] == items[2:4]
|
||||||
|
assert response["total"] == len(items)
|
||||||
|
assert response["page"] == 2
|
||||||
|
assert response["page_size"] == 2
|
||||||
|
assert response["total_pages"] == 3
|
||||||
Reference in New Issue
Block a user