feat: improve code formatting and readability in model handlers

- Add blank line after module docstring for better PEP 8 compliance
- Reformat long lines to adhere to the 88-character limit using Black-style formatting (see the sketch below)
- Standardize string literals on double quotes
- Enhance readability of complex list comprehensions and method calls
- Maintain all existing functionality while improving code structure
Author: Will Miao
Date: 2026-01-13 22:56:55 +08:00
Parent: 0c96e8d328
Commit: bc08a45214
7 changed files with 918 additions and 338 deletions
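
For reference, the reformatting applied throughout this commit follows the usual Black conventions: a call that would overflow 88 characters on one line is wrapped so its arguments move to an indented continuation line and the closing parenthesis stands alone. A minimal sketch with illustrative names (not taken from the codebase):

```python
# Illustrative only: the kind of wrapping Black applies in the diff below.
def describe(model_name: str, base_model: str, usage_count: int) -> str:
    return f"{model_name} ({base_model}, used {usage_count}x)"

# Before: a single call that exceeds the 88-character limit.
# summary = describe("some_very_long_model_name_v2_final", "Stable Diffusion XL 1.0", 42)

# After: Black moves the arguments onto their own indented line.
summary = describe(
    "some_very_long_model_name_v2_final", "Stable Diffusion XL 1.0", 42
)
print(summary)
```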

File diff suppressed because it is too large.

View File

@@ -25,9 +25,10 @@ logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from .model_update_service import ModelUpdateService
class BaseModelService(ABC):
"""Base service class for all model types"""
def __init__(
self,
model_type: str,
@@ -60,13 +61,14 @@ class BaseModelService(ABC):
self.filter_set = filter_set or ModelFilterSet(self.settings)
self.search_strategy = search_strategy or SearchStrategy()
self.update_service = update_service
async def get_paginated_data(
self,
page: int,
page_size: int,
sort_by: str = 'name',
sort_by: str = "name",
folder: str = None,
folder_exclude: list = None,
search: str = None,
fuzzy_search: bool = False,
base_models: list = None,
@@ -85,7 +87,7 @@ class BaseModelService(ABC):
sort_params = self.cache_repository.parse_sort(sort_by)
t0 = time.perf_counter()
if sort_params.key == 'usage':
if sort_params.key == "usage":
sorted_data = await self._fetch_with_usage_sort(sort_params)
else:
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
@@ -99,6 +101,7 @@ class BaseModelService(ABC):
filtered_data = await self._apply_common_filters(
sorted_data,
folder=folder,
folder_exclude=folder_exclude,
base_models=base_models,
model_types=model_types,
tags=tags,
@@ -118,10 +121,14 @@ class BaseModelService(ABC):
# Apply license-based filters
if credit_required is not None:
filtered_data = await self._apply_credit_required_filter(filtered_data, credit_required)
filtered_data = await self._apply_credit_required_filter(
filtered_data, credit_required
)
if allow_selling_generated_content is not None:
filtered_data = await self._apply_allow_selling_filter(filtered_data, allow_selling_generated_content)
filtered_data = await self._apply_allow_selling_filter(
filtered_data, allow_selling_generated_content
)
filter_duration = time.perf_counter() - t1
post_filter_count = len(filtered_data)
@@ -130,8 +137,7 @@ class BaseModelService(ABC):
if update_available_only:
annotated_for_filter = await self._annotate_update_flags(filtered_data)
filtered_data = [
item for item in annotated_for_filter
if item.get('update_available')
item for item in annotated_for_filter if item.get("update_available")
]
update_filter_duration = time.perf_counter() - t2
final_count = len(filtered_data)
@@ -143,20 +149,27 @@ class BaseModelService(ABC):
t4 = time.perf_counter()
if update_available_only:
# Items already include update flags thanks to the pre-filter annotation.
paginated['items'] = list(paginated['items'])
paginated["items"] = list(paginated["items"])
else:
paginated['items'] = await self._annotate_update_flags(
paginated['items'],
paginated["items"] = await self._annotate_update_flags(
paginated["items"],
)
annotate_duration = time.perf_counter() - t4
overall_duration = time.perf_counter() - overall_start
logger.debug(
"%s.get_paginated_data took %.3fs (fetch: %.3fs, filter: %.3fs, update_filter: %.3fs, pagination: %.3fs, annotate: %.3fs). "
"Counts: initial=%d, post_filter=%d, final=%d",
self.__class__.__name__, overall_duration, fetch_duration, filter_duration,
update_filter_duration, pagination_duration, annotate_duration,
initial_count, post_filter_count, final_count
self.__class__.__name__,
overall_duration,
fetch_duration,
filter_duration,
update_filter_duration,
pagination_duration,
annotate_duration,
initial_count,
post_filter_count,
final_count,
)
return paginated
@@ -167,11 +180,11 @@ class BaseModelService(ABC):
# Map model type to usage stats bucket
bucket_map = {
'lora': 'loras',
'checkpoint': 'checkpoints',
"lora": "loras",
"checkpoint": "checkpoints",
# 'embedding': 'embeddings', # TODO: Enable when embedding usage tracking is implemented
}
bucket_key = bucket_map.get(self.model_type, '')
bucket_key = bucket_map.get(self.model_type, "")
usage_stats = UsageStats()
stats = await usage_stats.get_stats()
@@ -179,45 +192,47 @@ class BaseModelService(ABC):
annotated = []
for item in raw_items:
sha = (item.get('sha256') or '').lower()
usage_info = usage_bucket.get(sha, {}) if isinstance(usage_bucket, dict) else {}
usage_count = usage_info.get('total', 0) if isinstance(usage_info, dict) else 0
annotated.append({**item, 'usage_count': usage_count})
sha = (item.get("sha256") or "").lower()
usage_info = (
usage_bucket.get(sha, {}) if isinstance(usage_bucket, dict) else {}
)
usage_count = (
usage_info.get("total", 0) if isinstance(usage_info, dict) else 0
)
annotated.append({**item, "usage_count": usage_count})
reverse = sort_params.order == 'desc'
reverse = sort_params.order == "desc"
annotated.sort(
key=lambda x: (x.get('usage_count', 0), x.get('model_name', '').lower()),
reverse=reverse
key=lambda x: (x.get("usage_count", 0), x.get("model_name", "").lower()),
reverse=reverse,
)
return annotated
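
A standalone sketch of the usage ordering performed above: each item gets a usage_count looked up by lowercase sha256, then the list sorts by count with the model name as the secondary key (reverse applies to both, exactly as in the lambda). The literal bucket data here stands in for UsageStats, which is not reproduced:

```python
# Standalone sketch of the usage-based ordering in _fetch_with_usage_sort; the literal
# usage_bucket replaces the UsageStats lookup, everything else mirrors the diff.
usage_bucket = {"aaa111": {"total": 12}, "bbb222": {"total": 3}}

items = [
    {"model_name": "Alpha", "sha256": "BBB222"},
    {"model_name": "Beta", "sha256": "aaa111"},
    {"model_name": "Gamma", "sha256": "ccc333"},  # no recorded usage
]

annotated = []
for item in items:
    sha = (item.get("sha256") or "").lower()
    usage_info = usage_bucket.get(sha, {})
    annotated.append({**item, "usage_count": usage_info.get("total", 0)})

# sort_params.order == "desc" -> reverse=True; ties on usage_count fall back to name.
annotated.sort(
    key=lambda x: (x.get("usage_count", 0), x.get("model_name", "").lower()),
    reverse=True,
)
print([(i["model_name"], i["usage_count"]) for i in annotated])
# [('Beta', 12), ('Alpha', 3), ('Gamma', 0)]
```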
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
async def _apply_hash_filters(
self, data: List[Dict], hash_filters: Dict
) -> List[Dict]:
"""Apply hash-based filtering"""
single_hash = hash_filters.get('single_hash')
multiple_hashes = hash_filters.get('multiple_hashes')
single_hash = hash_filters.get("single_hash")
multiple_hashes = hash_filters.get("multiple_hashes")
if single_hash:
# Filter by single hash
single_hash = single_hash.lower()
return [
item for item in data
if item.get('sha256', '').lower() == single_hash
item for item in data if item.get("sha256", "").lower() == single_hash
]
elif multiple_hashes:
# Filter by multiple hashes
hash_set = set(hash.lower() for hash in multiple_hashes)
return [
item for item in data
if item.get('sha256', '').lower() in hash_set
]
return [item for item in data if item.get("sha256", "").lower() in hash_set]
return data
async def _apply_common_filters(
self,
data: List[Dict],
folder: str = None,
folder_exclude: list = None,
base_models: list = None,
model_types: list = None,
tags: Optional[Dict[str, str]] = None,
@@ -228,6 +243,7 @@ class BaseModelService(ABC):
normalized_options = self.search_strategy.normalize_options(search_options)
criteria = FilterCriteria(
folder=folder,
folder_exclude=folder_exclude,
base_models=base_models,
model_types=model_types,
tags=tags,
@@ -235,7 +251,7 @@ class BaseModelService(ABC):
search_options=normalized_options,
)
return self.filter_set.apply(data, criteria)
async def _apply_search_filters(
self,
data: List[Dict],
@@ -245,28 +261,34 @@ class BaseModelService(ABC):
) -> List[Dict]:
"""Apply search filtering"""
normalized_options = self.search_strategy.normalize_options(search_options)
return self.search_strategy.apply(data, search, normalized_options, fuzzy_search)
return self.search_strategy.apply(
data, search, normalized_options, fuzzy_search
)
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
"""Apply model-specific filters - to be overridden by subclasses if needed"""
return data
async def _apply_credit_required_filter(self, data: List[Dict], credit_required: bool) -> List[Dict]:
async def _apply_credit_required_filter(
self, data: List[Dict], credit_required: bool
) -> List[Dict]:
"""Apply credit required filtering based on license_flags.
Args:
data: List of model data items
credit_required:
credit_required:
- True: Return items where credit is required (allowNoCredit=False)
- False: Return items where credit is not required (allowNoCredit=True)
"""
filtered_data = []
for item in data:
license_flags = item.get("license_flags", 127) # Default to all permissions enabled
license_flags = item.get(
"license_flags", 127
) # Default to all permissions enabled
# Bit 0 represents allowNoCredit (1 = no credit required, 0 = credit required)
allow_no_credit = bool(license_flags & (1 << 0))
# If credit_required is True, we want items where allowNoCredit is False (credit required)
# If credit_required is False, we want items where allowNoCredit is True (no credit required)
if credit_required:
@@ -275,26 +297,30 @@ class BaseModelService(ABC):
else:
if allow_no_credit: # Credit is not required
filtered_data.append(item)
return filtered_data
async def _apply_allow_selling_filter(self, data: List[Dict], allow_selling: bool) -> List[Dict]:
async def _apply_allow_selling_filter(
self, data: List[Dict], allow_selling: bool
) -> List[Dict]:
"""Apply allow selling generated content filtering based on license_flags.
Args:
data: List of model data items
allow_selling:
allow_selling:
- True: Return items where selling generated content is allowed (allowCommercialUse contains Image)
- False: Return items where selling generated content is not allowed (allowCommercialUse does not contain Image)
"""
filtered_data = []
for item in data:
license_flags = item.get("license_flags", 127) # Default to all permissions enabled
license_flags = item.get(
"license_flags", 127
) # Default to all permissions enabled
# Bits 1-4 represent commercial use permissions
# Bit 1 specifically represents Image permission (allowCommercialUse contains Image)
has_image_permission = bool(license_flags & (1 << 1))
# If allow_selling is True, we want items where Image permission is granted
# If allow_selling is False, we want items where Image permission is not granted
if allow_selling:
@@ -303,7 +329,7 @@ class BaseModelService(ABC):
else:
if not has_image_permission: # Selling generated content is not allowed
filtered_data.append(item)
return filtered_data
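
Both license filters above decode individual bits of license_flags; a compact sketch of that decoding, using only the two bits named in the comments and the default of 127 (all permission bits set) from the diff:

```python
# Sketch of the license_flags decoding used by the credit and selling filters above.
# Bit 0 -> allowNoCredit, bit 1 -> allowCommercialUse includes Image; 127 (0b1111111)
# is the default, meaning every permission bit is set.
def decode_license_flags(license_flags: int = 127) -> dict:
    return {
        "allow_no_credit": bool(license_flags & (1 << 0)),
        "allow_selling_images": bool(license_flags & (1 << 1)),
    }

items = [
    {"model_name": "a", "license_flags": 127},        # everything allowed
    {"model_name": "b", "license_flags": 0b0000010},  # credit required, selling allowed
    {"model_name": "c"},                               # missing -> defaults to 127
]

# credit_required=True keeps items whose allowNoCredit bit is clear.
print([
    item["model_name"]
    for item in items
    if not decode_license_flags(item.get("license_flags", 127))["allow_no_credit"]
])  # ['b']
```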
async def _annotate_update_flags(
@@ -321,7 +347,7 @@ class BaseModelService(ABC):
if self.update_service is None:
for item in annotated:
item['update_available'] = False
item["update_available"] = False
return annotated
id_to_items: Dict[int, List[Dict]] = {}
@@ -329,7 +355,7 @@ class BaseModelService(ABC):
for item in annotated:
model_id = self._extract_model_id(item)
if model_id is None:
item['update_available'] = False
item["update_available"] = False
continue
if model_id not in id_to_items:
id_to_items[model_id] = []
@@ -405,13 +431,19 @@ class BaseModelService(ABC):
default_flag = bool(resolved.get(model_id, False)) if resolved else False
record = records.get(model_id) if records else None
base_highest_versions = (
self._build_highest_local_versions_by_base(record) if same_base_mode and record else {}
self._build_highest_local_versions_by_base(record)
if same_base_mode and record
else {}
)
for item in items_for_id:
if same_base_mode and record is not None:
base_model = self._extract_base_model(item)
normalized_base = self._normalize_base_model_name(base_model)
threshold_version = base_highest_versions.get(normalized_base) if normalized_base else None
threshold_version = (
base_highest_versions.get(normalized_base)
if normalized_base
else None
)
if threshold_version is None:
threshold_version = self._extract_version_id(item)
flag = record.has_update_for_base(
@@ -420,17 +452,17 @@ class BaseModelService(ABC):
)
else:
flag = default_flag
item['update_available'] = flag
item["update_available"] = flag
return annotated
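
The same_base_mode branch above compares each item against the highest version already in the library for its base model. A simplified, self-contained sketch of that threshold selection; the record/version objects are replaced with plain tuples, and the normalization here is just lowercasing (the real _normalize_base_model_name is not shown in this diff):

```python
# Simplified sketch of _build_highest_local_versions_by_base: among versions already in
# the library, keep the highest version id per (normalized) base model.
from typing import Dict, List, Tuple


def highest_local_versions_by_base(
    versions: List[Tuple[str, int, bool]],  # (base_model, version_id, is_in_library)
) -> Dict[str, int]:
    highest: Dict[str, int] = {}
    for base_model, version_id, is_in_library in versions:
        if not is_in_library:
            continue
        normalized = base_model.strip().lower()  # stand-in for _normalize_base_model_name
        if version_id > highest.get(normalized, -1):
            highest[normalized] = version_id
    return highest


print(highest_local_versions_by_base([
    ("SDXL 1.0", 101, True),
    ("SDXL 1.0", 205, True),
    ("SDXL 1.0", 300, False),  # newer, but not in the library
    ("Flux.1 D", 410, True),
]))  # {'sdxl 1.0': 205, 'flux.1 d': 410}
```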
@staticmethod
def _extract_model_id(item: Dict) -> Optional[int]:
civitai = item.get('civitai') if isinstance(item, dict) else None
civitai = item.get("civitai") if isinstance(item, dict) else None
if not isinstance(civitai, dict):
return None
try:
value = civitai.get('modelId')
value = civitai.get("modelId")
if value is None:
return None
return int(value)
@@ -439,10 +471,10 @@ class BaseModelService(ABC):
@staticmethod
def _extract_version_id(item: Dict) -> Optional[int]:
civitai = item.get('civitai') if isinstance(item, dict) else None
civitai = item.get("civitai") if isinstance(item, dict) else None
if not isinstance(civitai, dict):
return None
value = civitai.get('id')
value = civitai.get("id")
if value is None:
return None
try:
@@ -452,7 +484,7 @@ class BaseModelService(ABC):
@staticmethod
def _extract_base_model(item: Dict) -> Optional[str]:
value = item.get('base_model')
value = item.get("base_model")
if value is None:
return None
if isinstance(value, str):
@@ -489,7 +521,9 @@ class BaseModelService(ABC):
for version in getattr(record, "versions", []):
if not getattr(version, "is_in_library", False):
continue
normalized_base = self._normalize_base_model_name(getattr(version, "base_model", None))
normalized_base = self._normalize_base_model_name(
getattr(version, "base_model", None)
)
if normalized_base is None:
continue
version_id = getattr(version, "version_id", None)
@@ -506,25 +540,25 @@ class BaseModelService(ABC):
total_items = len(data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
return {
'items': data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
"items": data[start_idx:end_idx],
"total": total_items,
"page": page,
"page_size": page_size,
"total_pages": (total_items + page_size - 1) // page_size,
}
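
The total_pages expression above is the integer ceiling-division idiom; a quick check of the values it produces:

```python
# (total_items + page_size - 1) // page_size rounds up without touching floats.
page_size = 10
for total_items in (0, 9, 10, 11, 23):
    print(total_items, (total_items + page_size - 1) // page_size)
# 0 -> 0, 9 -> 1, 10 -> 1, 11 -> 2, 23 -> 3
```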
@abstractmethod
async def format_response(self, model_data: Dict) -> Dict:
"""Format model data for API response - must be implemented by subclasses"""
pass
# Common service methods that delegate to scanner
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
"""Get top tags sorted by frequency"""
return await self.scanner.get_top_tags(limit)
async def get_base_models(self, limit: int = 20) -> List[Dict]:
"""Get base models sorted by frequency"""
return await self.scanner.get_base_models(limit)
@@ -535,62 +569,85 @@ class BaseModelService(ABC):
type_counts: Dict[str, int] = {}
for entry in cache.raw_data:
normalized_type = normalize_civitai_model_type(resolve_civitai_model_type(entry))
normalized_type = normalize_civitai_model_type(
resolve_civitai_model_type(entry)
)
if not normalized_type or normalized_type not in VALID_LORA_TYPES:
continue
type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1
sorted_types = sorted(
[{"type": model_type, "count": count} for model_type, count in type_counts.items()],
[
{"type": model_type, "count": count}
for model_type, count in type_counts.items()
],
key=lambda value: value["count"],
reverse=True,
)
return sorted_types[:limit]
def has_hash(self, sha256: str) -> bool:
"""Check if a model with given hash exists"""
return self.scanner.has_hash(sha256)
def get_path_by_hash(self, sha256: str) -> Optional[str]:
"""Get file path for a model by its hash"""
return self.scanner.get_path_by_hash(sha256)
def get_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a model by its file path"""
return self.scanner.get_hash_by_path(file_path)
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
async def scan_models(
self, force_refresh: bool = False, rebuild_cache: bool = False
):
"""Trigger model scanning"""
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache)
return await self.scanner.get_cached_data(
force_refresh=force_refresh, rebuild_cache=rebuild_cache
)
async def get_model_info_by_name(self, name: str):
"""Get model information by name"""
return await self.scanner.get_model_info_by_name(name)
def get_model_roots(self) -> List[str]:
"""Get model root directories"""
return self.scanner.get_model_roots()
def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict:
"""Filter relevant fields from CivitAI data"""
if not data:
return {}
fields = ["id", "modelId", "name", "trainedWords"] if minimal else [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images", "customImages", "creator"
]
fields = (
["id", "modelId", "name", "trainedWords"]
if minimal
else [
"id",
"modelId",
"name",
"createdAt",
"updatedAt",
"publishedAt",
"trainedWords",
"baseModel",
"description",
"model",
"images",
"customImages",
"creator",
]
)
return {k: data[k] for k in fields if k in data}
async def get_folder_tree(self, model_root: str) -> Dict:
"""Get hierarchical folder tree for a specific model root"""
cache = await self.scanner.get_cached_data()
# Build tree structure from folders
tree = {}
for folder in cache.folders:
# Check if this folder belongs to the specified model root
folder_belongs_to_root = False
@@ -598,95 +655,96 @@ class BaseModelService(ABC):
if root == model_root:
folder_belongs_to_root = True
break
if not folder_belongs_to_root:
continue
# Split folder path into components
parts = folder.split('/') if folder else []
parts = folder.split("/") if folder else []
current_level = tree
for part in parts:
if part not in current_level:
current_level[part] = {}
current_level = current_level[part]
return tree
async def get_unified_folder_tree(self) -> Dict:
"""Get unified folder tree across all model roots"""
cache = await self.scanner.get_cached_data()
# Build unified tree structure by analyzing all relative paths
unified_tree = {}
# Get all model roots for path normalization
model_roots = self.scanner.get_model_roots()
for folder in cache.folders:
if not folder: # Skip empty folders
continue
# Find which root this folder belongs to by checking the actual file paths
# This is a simplified approach - we'll use the folder as-is since it should already be relative
relative_path = folder
# Split folder path into components
parts = relative_path.split('/')
parts = relative_path.split("/")
current_level = unified_tree
for part in parts:
if part not in current_level:
current_level[part] = {}
current_level = current_level[part]
return unified_tree
async def get_model_notes(self, model_name: str) -> Optional[str]:
"""Get notes for a specific model file"""
cache = await self.scanner.get_cached_data()
for model in cache.raw_data:
if model['file_name'] == model_name:
return model.get('notes', '')
if model["file_name"] == model_name:
return model.get("notes", "")
return None
async def get_model_preview_url(self, model_name: str) -> Optional[str]:
"""Get the static preview URL for a model file"""
cache = await self.scanner.get_cached_data()
for model in cache.raw_data:
if model['file_name'] == model_name:
preview_url = model.get('preview_url')
if model["file_name"] == model_name:
preview_url = model.get("preview_url")
if preview_url:
from ..config import config
return config.get_preview_static_url(preview_url)
return '/loras_static/images/no-preview.png'
return "/loras_static/images/no-preview.png"
async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]:
"""Get the Civitai URL for a model file"""
cache = await self.scanner.get_cached_data()
for model in cache.raw_data:
if model['file_name'] == model_name:
civitai_data = model.get('civitai', {})
model_id = civitai_data.get('modelId')
version_id = civitai_data.get('id')
if model["file_name"] == model_name:
civitai_data = model.get("civitai", {})
model_id = civitai_data.get("modelId")
version_id = civitai_data.get("id")
if model_id:
civitai_url = f"https://civitai.com/models/{model_id}"
if version_id:
civitai_url += f"?modelVersionId={version_id}"
return {
'civitai_url': civitai_url,
'model_id': str(model_id),
'version_id': str(version_id) if version_id else None
"civitai_url": civitai_url,
"model_id": str(model_id),
"version_id": str(version_id) if version_id else None,
}
return {'civitai_url': None, 'model_id': None, 'version_id': None}
return {"civitai_url": None, "model_id": None, "version_id": None}
async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
"""Load full metadata for a single model.
@@ -694,18 +752,21 @@ class BaseModelService(ABC):
Listing/search endpoints return lightweight cache entries; this method performs
a lazy read of the on-disk metadata snapshot when callers need full detail.
"""
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.metadata_class)
metadata, should_skip = await MetadataManager.load_metadata(
file_path, self.metadata_class
)
if should_skip or metadata is None:
return None
return self.filter_civitai_data(metadata.to_dict().get("civitai", {}))
async def get_model_description(self, file_path: str) -> Optional[str]:
"""Return the stored modelDescription field for a model."""
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.metadata_class)
metadata, should_skip = await MetadataManager.load_metadata(
file_path, self.metadata_class
)
if should_skip or metadata is None:
return None
return metadata.modelDescription or ''
return metadata.modelDescription or ""
@staticmethod
def _parse_search_tokens(search_term: str) -> tuple[List[str], List[str]]:
@@ -743,53 +804,64 @@ class BaseModelService(ABC):
def _relative_path_sort_key(relative_path: str, include_terms: List[str]) -> tuple:
"""Sort paths by how well they satisfy the include tokens."""
path_lower = relative_path.lower()
prefix_hits = sum(1 for term in include_terms if term and path_lower.startswith(term))
match_positions = [path_lower.find(term) for term in include_terms if term and term in path_lower]
prefix_hits = sum(
1 for term in include_terms if term and path_lower.startswith(term)
)
match_positions = [
path_lower.find(term)
for term in include_terms
if term and term in path_lower
]
first_match_index = min(match_positions) if match_positions else 0
return (-prefix_hits, first_match_index, len(relative_path), path_lower)
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
async def search_relative_paths(
self, search_term: str, limit: int = 15
) -> List[str]:
"""Search model relative file paths for autocomplete functionality"""
cache = await self.scanner.get_cached_data()
include_terms, exclude_terms = self._parse_search_tokens(search_term)
matching_paths = []
# Get model roots for path calculation
model_roots = self.scanner.get_model_roots()
for model in cache.raw_data:
file_path = model.get('file_path', '')
file_path = model.get("file_path", "")
if not file_path:
continue
# Calculate relative path from model root
relative_path = None
for root in model_roots:
# Normalize paths for comparison
normalized_root = os.path.normpath(root)
normalized_file = os.path.normpath(file_path)
if normalized_file.startswith(normalized_root):
# Remove root and leading separator to get relative path
relative_path = normalized_file[len(normalized_root):].lstrip(os.sep)
relative_path = normalized_file[len(normalized_root) :].lstrip(
os.sep
)
break
if not relative_path:
continue
relative_lower = relative_path.lower()
if self._relative_path_matches_tokens(relative_lower, include_terms, exclude_terms):
if self._relative_path_matches_tokens(
relative_lower, include_terms, exclude_terms
):
matching_paths.append(relative_path)
if len(matching_paths) >= limit * 2: # Get more for better sorting
break
# Sort by relevance (prefix and earliest hits first, then by length and alphabetically)
matching_paths.sort(
key=lambda relative: self._relative_path_sort_key(relative, include_terms)
)
return matching_paths[:limit]
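
To see how the relevance key above ranks autocomplete candidates, here is a standalone run mirroring _relative_path_sort_key (the paths are invented for the example):

```python
# Mirrors _relative_path_sort_key: prefix hits rank first, then earliest match position,
# then shorter paths, then case-insensitive alphabetical order.
def sort_key(relative_path: str, include_terms: list) -> tuple:
    path_lower = relative_path.lower()
    prefix_hits = sum(
        1 for term in include_terms if term and path_lower.startswith(term)
    )
    match_positions = [
        path_lower.find(term) for term in include_terms if term and term in path_lower
    ]
    first_match_index = min(match_positions) if match_positions else 0
    return (-prefix_hits, first_match_index, len(relative_path), path_lower)


paths = [
    "styles/anime/anime_style_v2.safetensors",
    "anime/base.safetensors",
    "characters/anime_girl.safetensors",
]
print(sorted(paths, key=lambda p: sort_key(p, ["anime"])))
# 'anime/base.safetensors' first (prefix hit), then by where "anime" first appears.
```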

View File

@@ -1,7 +1,18 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Protocol, Callable
from typing import (
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Protocol,
Callable,
)
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match as default_fuzzy_match
@@ -51,8 +62,7 @@ def resolve_civitai_model_type(entry: Mapping[str, Any]) -> str:
class SettingsProvider(Protocol):
"""Protocol describing the SettingsManager contract used by query helpers."""
def get(self, key: str, default: Any = None) -> Any:
...
def get(self, key: str, default: Any = None) -> Any: ...
@dataclass(frozen=True)
@@ -68,6 +78,7 @@ class FilterCriteria:
"""Container for model list filtering options."""
folder: Optional[str] = None
folder_exclude: Optional[Sequence[str]] = None
base_models: Optional[Sequence[str]] = None
tags: Optional[Dict[str, str]] = None
favorites_only: bool = False
@@ -113,11 +124,15 @@ class ModelCacheRepository:
class ModelFilterSet:
"""Applies common filtering rules to the model collection."""
def __init__(self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None) -> None:
def __init__(
self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None
) -> None:
self._settings = settings
self._nsfw_levels = nsfw_levels or NSFW_LEVELS
def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]:
def apply(
self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria
) -> List[Dict[str, Any]]:
"""Return items that satisfy the provided criteria."""
overall_start = time.perf_counter()
items = list(data)
@@ -127,8 +142,10 @@ class ModelFilterSet:
t0 = time.perf_counter()
threshold = self._nsfw_levels.get("R", 0)
items = [
item for item in items
if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold
item
for item in items
if not item.get("preview_nsfw_level")
or item.get("preview_nsfw_level") < threshold
]
sfw_duration = time.perf_counter() - t0
else:
@@ -142,20 +159,44 @@ class ModelFilterSet:
folder_duration = 0
folder = criteria.folder
folder_exclude = criteria.folder_exclude or []
options = criteria.search_options or {}
recursive = bool(options.get("recursive", True))
# Apply folder exclude filters first
if folder_exclude:
t0 = time.perf_counter()
for exclude_folder in folder_exclude:
if exclude_folder:
# Check exact match OR prefix match (for subfolders)
# Normalize exclude_folder for prefix matching
if not exclude_folder.endswith("/"):
exclude_prefix = f"{exclude_folder}/"
else:
exclude_prefix = exclude_folder
items = [
item
for item in items
if item.get("folder") != exclude_folder
and not item.get("folder", "").startswith(exclude_prefix)
]
folder_duration = time.perf_counter() - t0
# Apply folder include filters
if folder is not None:
t0 = time.perf_counter()
if recursive:
if folder:
folder_with_sep = f"{folder}/"
items = [
item for item in items
if item.get("folder") == folder or item.get("folder", "").startswith(folder_with_sep)
item
for item in items
if item.get("folder") == folder
or item.get("folder", "").startswith(folder_with_sep)
]
else:
items = [item for item in items if item.get("folder") == folder]
folder_duration = time.perf_counter() - t0
folder_duration = time.perf_counter() - t0 + folder_duration
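
In short, an item is dropped by folder_exclude when its folder matches the excluded folder exactly or lies anywhere beneath it; a minimal sketch of that predicate (the full behavior is exercised by the tests further down):

```python
# Minimal sketch of the folder_exclude semantics applied above.
def excluded(item_folder: str, exclude_folder: str) -> bool:
    prefix = exclude_folder if exclude_folder.endswith("/") else f"{exclude_folder}/"
    return item_folder == exclude_folder or item_folder.startswith(prefix)


items = [
    {"model_name": "a", "folder": "characters"},
    {"model_name": "b", "folder": "characters/anime"},
    {"model_name": "c", "folder": "styles"},
]
print([i["model_name"] for i in items if not excluded(i["folder"], "characters")])
# ['c']
```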
base_models_duration = 0
base_models = criteria.base_models or []
@@ -183,25 +224,23 @@ class ModelFilterSet:
include_tags = {tag for tag in tag_filters if tag}
if include_tags:
def matches_include(item_tags):
if not item_tags and "__no_tags__" in include_tags:
return True
return any(tag in include_tags for tag in (item_tags or []))
items = [
item for item in items
if matches_include(item.get("tags"))
]
items = [item for item in items if matches_include(item.get("tags"))]
if exclude_tags:
def matches_exclude(item_tags):
if not item_tags and "__no_tags__" in exclude_tags:
return True
return any(tag in exclude_tags for tag in (item_tags or []))
items = [
item for item in items
if not matches_exclude(item.get("tags"))
item for item in items if not matches_exclude(item.get("tags"))
]
tags_duration = time.perf_counter() - t0
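
The "__no_tags__" sentinel above deserves a quick illustration: including it matches items that carry no tags at all, in addition to any ordinary include tags (exclude handling is symmetric):

```python
# Sketch of the "__no_tags__" sentinel used by the include filter above.
def matches_include(item_tags, include_tags):
    if not item_tags and "__no_tags__" in include_tags:
        return True
    return any(tag in include_tags for tag in (item_tags or []))


items = [
    {"model_name": "a", "tags": ["anime", "style"]},
    {"model_name": "b", "tags": []},
    {"model_name": "c", "tags": ["photo"]},
]
include_tags = {"anime", "__no_tags__"}
print([i["model_name"] for i in items if matches_include(i["tags"], include_tags)])
# ['a', 'b']
```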
@@ -210,26 +249,35 @@ class ModelFilterSet:
if model_types:
t0 = time.perf_counter()
normalized_model_types = {
model_type for model_type in (
model_type
for model_type in (
normalize_civitai_model_type(value) for value in model_types
)
if model_type
}
if normalized_model_types:
items = [
item for item in items
if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types
item
for item in items
if normalize_civitai_model_type(resolve_civitai_model_type(item))
in normalized_model_types
]
model_types_duration = time.perf_counter() - t0
duration = time.perf_counter() - overall_start
if duration > 0.1: # Only log if it's potentially slow
if duration > 0.1: # Only log if it's potentially slow
logger.debug(
"ModelFilterSet.apply took %.3fs (sfw: %.3fs, fav: %.3fs, folder: %.3fs, base: %.3fs, tags: %.3fs, types: %.3fs). "
"Count: %d -> %d",
duration, sfw_duration, favorites_duration, folder_duration,
base_models_duration, tags_duration, model_types_duration,
initial_count, len(items)
duration,
sfw_duration,
favorites_duration,
folder_duration,
base_models_duration,
tags_duration,
model_types_duration,
initial_count,
len(items),
)
return items
@@ -245,7 +293,9 @@ class SearchStrategy:
"creator": False,
}
def __init__(self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None) -> None:
def __init__(
self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None
) -> None:
self._fuzzy_match = fuzzy_matcher or default_fuzzy_match
def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]:
@@ -284,7 +334,9 @@ class SearchStrategy:
if options.get("tags", False):
tags = item.get("tags", []) or []
if any(self._matches(tag, search_term, search_lower, fuzzy) for tag in tags):
if any(
self._matches(tag, search_term, search_lower, fuzzy) for tag in tags
):
results.append(item)
continue
@@ -295,13 +347,17 @@ class SearchStrategy:
creator = civitai.get("creator")
if isinstance(creator, dict):
creator_username = creator.get("username", "")
if creator_username and self._matches(creator_username, search_term, search_lower, fuzzy):
if creator_username and self._matches(
creator_username, search_term, search_lower, fuzzy
):
results.append(item)
continue
return results
def _matches(self, candidate: str, search_term: str, search_lower: str, fuzzy: bool) -> bool:
def _matches(
self, candidate: str, search_term: str, search_lower: str, fuzzy: bool
) -> bool:
if not isinstance(candidate, str):
candidate = "" if candidate is None else str(candidate)

View File

@@ -3,13 +3,16 @@ from py.services.model_query import ModelFilterSet, FilterCriteria
from py.services.recipe_scanner import RecipeScanner
from types import SimpleNamespace
# Mock settings
class MockSettings:
def get(self, key, default=None):
return default
# --- Model Filtering Tests ---
def test_model_filter_set_root_recursive_true():
filter_set = ModelFilterSet(MockSettings())
items = [
@@ -17,13 +20,14 @@ def test_model_filter_set_root_recursive_true():
{"model_name": "sub_item", "folder": "sub"},
]
criteria = FilterCriteria(folder="", search_options={"recursive": True})
result = filter_set.apply(items, criteria)
assert len(result) == 2
assert any(i["model_name"] == "root_item" for i in result)
assert any(i["model_name"] == "sub_item" for i in result)
def test_model_filter_set_root_recursive_false():
filter_set = ModelFilterSet(MockSettings())
items = [
@@ -31,62 +35,185 @@ def test_model_filter_set_root_recursive_false():
{"model_name": "sub_item", "folder": "sub"},
]
criteria = FilterCriteria(folder="", search_options={"recursive": False})
result = filter_set.apply(items, criteria)
assert len(result) == 1
assert result[0]["model_name"] == "root_item"
def test_model_filter_set_folder_exclude_single():
filter_set = ModelFilterSet(MockSettings())
items = [
{"model_name": "item1", "folder": "characters/"},
{"model_name": "item2", "folder": "styles/"},
{"model_name": "item3", "folder": "characters/anime/"},
{"model_name": "item4", "folder": ""},
]
criteria = FilterCriteria(
folder_exclude=["characters/"], search_options={"recursive": True}
)
result = filter_set.apply(items, criteria)
assert len(result) == 2
model_names = {i["model_name"] for i in result}
assert model_names == {"item2", "item4"}
def test_model_filter_set_folder_exclude_multiple():
filter_set = ModelFilterSet(MockSettings())
items = [
{"model_name": "item1", "folder": "characters/"},
{"model_name": "item2", "folder": "styles/"},
{"model_name": "item3", "folder": "concepts/"},
{"model_name": "item4", "folder": "characters/anime/"},
{"model_name": "item5", "folder": ""},
]
criteria = FilterCriteria(
folder_exclude=["characters/", "styles/"], search_options={"recursive": True}
)
result = filter_set.apply(items, criteria)
assert len(result) == 2
model_names = {i["model_name"] for i in result}
assert model_names == {"item3", "item5"}
def test_model_filter_set_folder_exclude_with_include():
filter_set = ModelFilterSet(MockSettings())
items = [
{"model_name": "item1", "folder": "characters/"},
{"model_name": "item2", "folder": "styles/"},
{"model_name": "item3", "folder": "characters/anime/"},
{"model_name": "item4", "folder": "styles/painting/"},
{"model_name": "item5", "folder": "concepts/"},
]
criteria = FilterCriteria(
folder="characters/",
folder_exclude=["characters/anime/"],
search_options={"recursive": True},
)
result = filter_set.apply(items, criteria)
assert len(result) == 1
assert result[0]["model_name"] == "item1"
# --- Recipe Filtering Tests ---
@pytest.mark.asyncio
async def test_recipe_scanner_root_recursive_true():
# Mock LoraScanner
class StubLoraScanner:
async def get_cached_data(self):
return SimpleNamespace(raw_data=[])
scanner = RecipeScanner(lora_scanner=StubLoraScanner())
# Manually populate cache for testing get_paginated_data logic
scanner._cache = SimpleNamespace(
raw_data=[
{"id": "r1", "title": "root_recipe", "folder": "", "modified": 1.0, "created_date": 1.0, "loras": []},
{"id": "r2", "title": "sub_recipe", "folder": "sub", "modified": 2.0, "created_date": 2.0, "loras": []},
{
"id": "r1",
"title": "root_recipe",
"folder": "",
"modified": 1.0,
"created_date": 1.0,
"loras": [],
},
{
"id": "r2",
"title": "sub_recipe",
"folder": "sub",
"modified": 2.0,
"created_date": 2.0,
"loras": [],
},
],
sorted_by_date=[
{"id": "r2", "title": "sub_recipe", "folder": "sub", "modified": 2.0, "created_date": 2.0, "loras": []},
{"id": "r1", "title": "root_recipe", "folder": "", "modified": 1.0, "created_date": 1.0, "loras": []},
{
"id": "r2",
"title": "sub_recipe",
"folder": "sub",
"modified": 2.0,
"created_date": 2.0,
"loras": [],
},
{
"id": "r1",
"title": "root_recipe",
"folder": "",
"modified": 1.0,
"created_date": 1.0,
"loras": [],
},
],
sorted_by_name=[],
version_index={}
version_index={},
)
result = await scanner.get_paginated_data(page=1, page_size=10, folder="", recursive=True)
result = await scanner.get_paginated_data(
page=1, page_size=10, folder="", recursive=True
)
assert len(result["items"]) == 2
@pytest.mark.asyncio
async def test_recipe_scanner_root_recursive_false():
# Mock LoraScanner
class StubLoraScanner:
async def get_cached_data(self):
return SimpleNamespace(raw_data=[])
scanner = RecipeScanner(lora_scanner=StubLoraScanner())
scanner._cache = SimpleNamespace(
raw_data=[
{"id": "r1", "title": "root_recipe", "folder": "", "modified": 1.0, "created_date": 1.0, "loras": []},
{"id": "r2", "title": "sub_recipe", "folder": "sub", "modified": 2.0, "created_date": 2.0, "loras": []},
{
"id": "r1",
"title": "root_recipe",
"folder": "",
"modified": 1.0,
"created_date": 1.0,
"loras": [],
},
{
"id": "r2",
"title": "sub_recipe",
"folder": "sub",
"modified": 2.0,
"created_date": 2.0,
"loras": [],
},
],
sorted_by_date=[
{"id": "r2", "title": "sub_recipe", "folder": "sub", "modified": 2.0, "created_date": 2.0, "loras": []},
{"id": "r1", "title": "root_recipe", "folder": "", "modified": 1.0, "created_date": 1.0, "loras": []},
{
"id": "r2",
"title": "sub_recipe",
"folder": "sub",
"modified": 2.0,
"created_date": 2.0,
"loras": [],
},
{
"id": "r1",
"title": "root_recipe",
"folder": "",
"modified": 1.0,
"created_date": 1.0,
"loras": [],
},
],
sorted_by_name=[],
version_index={}
version_index={},
)
result = await scanner.get_paginated_data(page=1, page_size=10, folder="", recursive=False)
result = await scanner.get_paginated_data(
page=1, page_size=10, folder="", recursive=False
)
assert len(result["items"]) == 1
assert result["items"][0]["id"] == "r1"

View File

@@ -77,11 +77,12 @@ export function useLoraPoolApi() {
params.tagsInclude?.forEach(tag => urlParams.append('tag_include', tag))
params.tagsExclude?.forEach(tag => urlParams.append('tag_exclude', tag))
// For now, use first include folder (backend currently supports single folder)
// Folder filters
if (params.foldersInclude && params.foldersInclude.length > 0) {
urlParams.set('folder', params.foldersInclude[0])
urlParams.set('recursive', 'true')
}
params.foldersExclude?.forEach(folder => urlParams.append('folder_exclude', folder))
if (params.noCreditRequired !== undefined) {
urlParams.set('credit_required', String(!params.noCreditRequired))
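
On the wire, each excluded folder becomes its own folder_exclude query parameter, which the backend reads back as a list. A small Python sketch of the resulting query string; the endpoint path is illustrative, not taken from the repository:

```python
# Hypothetical request illustrating the repeated folder_exclude query parameters that
# the composable above appends; the /api/loras path is an assumption for the example.
from urllib.parse import urlencode

params = [
    ("folder", "characters"),
    ("recursive", "true"),
    ("folder_exclude", "characters/anime"),
    ("folder_exclude", "characters/sketches"),
    ("credit_required", "false"),
]
print("/api/loras?" + urlencode(params))
# folder_exclude appears twice: ...&folder_exclude=characters%2Fanime&folder_exclude=characters%2Fsketches...
```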

View File

@@ -10870,7 +10870,7 @@ function useLoraPoolApi() {
});
};
const fetchLoras = async (params) => {
var _a, _b, _c;
var _a, _b, _c, _d;
isLoading.value = true;
try {
const urlParams = new URLSearchParams();
@@ -10883,6 +10883,7 @@ function useLoraPoolApi() {
urlParams.set("folder", params.foldersInclude[0]);
urlParams.set("recursive", "true");
}
(_d = params.foldersExclude) == null ? void 0 : _d.forEach((folder) => urlParams.append("folder_exclude", folder));
if (params.noCreditRequired !== void 0) {
urlParams.set("credit_required", String(!params.noCreditRequired));
}

File diff suppressed because one or more lines are too long