feat: improve code formatting and readability in model handlers

- Add blank line after module docstring for better PEP 8 compliance
- Reformat long lines to adhere to 88-character limit using Black-style formatting
- Improve string consistency by using double quotes consistently
- Enhance readability of complex list comprehensions and method calls
- Maintain all existing functionality while improving code structure
This commit is contained in:
Will Miao
2026-01-13 22:56:55 +08:00
parent 0c96e8d328
commit bc08a45214
7 changed files with 918 additions and 338 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -25,9 +25,10 @@ logger = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from .model_update_service import ModelUpdateService from .model_update_service import ModelUpdateService
class BaseModelService(ABC): class BaseModelService(ABC):
"""Base service class for all model types""" """Base service class for all model types"""
def __init__( def __init__(
self, self,
model_type: str, model_type: str,
@@ -60,13 +61,14 @@ class BaseModelService(ABC):
self.filter_set = filter_set or ModelFilterSet(self.settings) self.filter_set = filter_set or ModelFilterSet(self.settings)
self.search_strategy = search_strategy or SearchStrategy() self.search_strategy = search_strategy or SearchStrategy()
self.update_service = update_service self.update_service = update_service
async def get_paginated_data( async def get_paginated_data(
self, self,
page: int, page: int,
page_size: int, page_size: int,
sort_by: str = 'name', sort_by: str = "name",
folder: str = None, folder: str = None,
folder_exclude: list = None,
search: str = None, search: str = None,
fuzzy_search: bool = False, fuzzy_search: bool = False,
base_models: list = None, base_models: list = None,
@@ -85,7 +87,7 @@ class BaseModelService(ABC):
sort_params = self.cache_repository.parse_sort(sort_by) sort_params = self.cache_repository.parse_sort(sort_by)
t0 = time.perf_counter() t0 = time.perf_counter()
if sort_params.key == 'usage': if sort_params.key == "usage":
sorted_data = await self._fetch_with_usage_sort(sort_params) sorted_data = await self._fetch_with_usage_sort(sort_params)
else: else:
sorted_data = await self.cache_repository.fetch_sorted(sort_params) sorted_data = await self.cache_repository.fetch_sorted(sort_params)
@@ -99,6 +101,7 @@ class BaseModelService(ABC):
filtered_data = await self._apply_common_filters( filtered_data = await self._apply_common_filters(
sorted_data, sorted_data,
folder=folder, folder=folder,
folder_exclude=folder_exclude,
base_models=base_models, base_models=base_models,
model_types=model_types, model_types=model_types,
tags=tags, tags=tags,
@@ -118,10 +121,14 @@ class BaseModelService(ABC):
# Apply license-based filters # Apply license-based filters
if credit_required is not None: if credit_required is not None:
filtered_data = await self._apply_credit_required_filter(filtered_data, credit_required) filtered_data = await self._apply_credit_required_filter(
filtered_data, credit_required
)
if allow_selling_generated_content is not None: if allow_selling_generated_content is not None:
filtered_data = await self._apply_allow_selling_filter(filtered_data, allow_selling_generated_content) filtered_data = await self._apply_allow_selling_filter(
filtered_data, allow_selling_generated_content
)
filter_duration = time.perf_counter() - t1 filter_duration = time.perf_counter() - t1
post_filter_count = len(filtered_data) post_filter_count = len(filtered_data)
@@ -130,8 +137,7 @@ class BaseModelService(ABC):
if update_available_only: if update_available_only:
annotated_for_filter = await self._annotate_update_flags(filtered_data) annotated_for_filter = await self._annotate_update_flags(filtered_data)
filtered_data = [ filtered_data = [
item for item in annotated_for_filter item for item in annotated_for_filter if item.get("update_available")
if item.get('update_available')
] ]
update_filter_duration = time.perf_counter() - t2 update_filter_duration = time.perf_counter() - t2
final_count = len(filtered_data) final_count = len(filtered_data)
@@ -143,20 +149,27 @@ class BaseModelService(ABC):
t4 = time.perf_counter() t4 = time.perf_counter()
if update_available_only: if update_available_only:
# Items already include update flags thanks to the pre-filter annotation. # Items already include update flags thanks to the pre-filter annotation.
paginated['items'] = list(paginated['items']) paginated["items"] = list(paginated["items"])
else: else:
paginated['items'] = await self._annotate_update_flags( paginated["items"] = await self._annotate_update_flags(
paginated['items'], paginated["items"],
) )
annotate_duration = time.perf_counter() - t4 annotate_duration = time.perf_counter() - t4
overall_duration = time.perf_counter() - overall_start overall_duration = time.perf_counter() - overall_start
logger.debug( logger.debug(
"%s.get_paginated_data took %.3fs (fetch: %.3fs, filter: %.3fs, update_filter: %.3fs, pagination: %.3fs, annotate: %.3fs). " "%s.get_paginated_data took %.3fs (fetch: %.3fs, filter: %.3fs, update_filter: %.3fs, pagination: %.3fs, annotate: %.3fs). "
"Counts: initial=%d, post_filter=%d, final=%d", "Counts: initial=%d, post_filter=%d, final=%d",
self.__class__.__name__, overall_duration, fetch_duration, filter_duration, self.__class__.__name__,
update_filter_duration, pagination_duration, annotate_duration, overall_duration,
initial_count, post_filter_count, final_count fetch_duration,
filter_duration,
update_filter_duration,
pagination_duration,
annotate_duration,
initial_count,
post_filter_count,
final_count,
) )
return paginated return paginated
@@ -167,11 +180,11 @@ class BaseModelService(ABC):
# Map model type to usage stats bucket # Map model type to usage stats bucket
bucket_map = { bucket_map = {
'lora': 'loras', "lora": "loras",
'checkpoint': 'checkpoints', "checkpoint": "checkpoints",
# 'embedding': 'embeddings', # TODO: Enable when embedding usage tracking is implemented # 'embedding': 'embeddings', # TODO: Enable when embedding usage tracking is implemented
} }
bucket_key = bucket_map.get(self.model_type, '') bucket_key = bucket_map.get(self.model_type, "")
usage_stats = UsageStats() usage_stats = UsageStats()
stats = await usage_stats.get_stats() stats = await usage_stats.get_stats()
@@ -179,45 +192,47 @@ class BaseModelService(ABC):
annotated = [] annotated = []
for item in raw_items: for item in raw_items:
sha = (item.get('sha256') or '').lower() sha = (item.get("sha256") or "").lower()
usage_info = usage_bucket.get(sha, {}) if isinstance(usage_bucket, dict) else {} usage_info = (
usage_count = usage_info.get('total', 0) if isinstance(usage_info, dict) else 0 usage_bucket.get(sha, {}) if isinstance(usage_bucket, dict) else {}
annotated.append({**item, 'usage_count': usage_count}) )
usage_count = (
usage_info.get("total", 0) if isinstance(usage_info, dict) else 0
)
annotated.append({**item, "usage_count": usage_count})
reverse = sort_params.order == 'desc' reverse = sort_params.order == "desc"
annotated.sort( annotated.sort(
key=lambda x: (x.get('usage_count', 0), x.get('model_name', '').lower()), key=lambda x: (x.get("usage_count", 0), x.get("model_name", "").lower()),
reverse=reverse reverse=reverse,
) )
return annotated return annotated
async def _apply_hash_filters(
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]: self, data: List[Dict], hash_filters: Dict
) -> List[Dict]:
"""Apply hash-based filtering""" """Apply hash-based filtering"""
single_hash = hash_filters.get('single_hash') single_hash = hash_filters.get("single_hash")
multiple_hashes = hash_filters.get('multiple_hashes') multiple_hashes = hash_filters.get("multiple_hashes")
if single_hash: if single_hash:
# Filter by single hash # Filter by single hash
single_hash = single_hash.lower() single_hash = single_hash.lower()
return [ return [
item for item in data item for item in data if item.get("sha256", "").lower() == single_hash
if item.get('sha256', '').lower() == single_hash
] ]
elif multiple_hashes: elif multiple_hashes:
# Filter by multiple hashes # Filter by multiple hashes
hash_set = set(hash.lower() for hash in multiple_hashes) hash_set = set(hash.lower() for hash in multiple_hashes)
return [ return [item for item in data if item.get("sha256", "").lower() in hash_set]
item for item in data
if item.get('sha256', '').lower() in hash_set
]
return data return data
async def _apply_common_filters( async def _apply_common_filters(
self, self,
data: List[Dict], data: List[Dict],
folder: str = None, folder: str = None,
folder_exclude: list = None,
base_models: list = None, base_models: list = None,
model_types: list = None, model_types: list = None,
tags: Optional[Dict[str, str]] = None, tags: Optional[Dict[str, str]] = None,
@@ -228,6 +243,7 @@ class BaseModelService(ABC):
normalized_options = self.search_strategy.normalize_options(search_options) normalized_options = self.search_strategy.normalize_options(search_options)
criteria = FilterCriteria( criteria = FilterCriteria(
folder=folder, folder=folder,
folder_exclude=folder_exclude,
base_models=base_models, base_models=base_models,
model_types=model_types, model_types=model_types,
tags=tags, tags=tags,
@@ -235,7 +251,7 @@ class BaseModelService(ABC):
search_options=normalized_options, search_options=normalized_options,
) )
return self.filter_set.apply(data, criteria) return self.filter_set.apply(data, criteria)
async def _apply_search_filters( async def _apply_search_filters(
self, self,
data: List[Dict], data: List[Dict],
@@ -245,28 +261,34 @@ class BaseModelService(ABC):
) -> List[Dict]: ) -> List[Dict]:
"""Apply search filtering""" """Apply search filtering"""
normalized_options = self.search_strategy.normalize_options(search_options) normalized_options = self.search_strategy.normalize_options(search_options)
return self.search_strategy.apply(data, search, normalized_options, fuzzy_search) return self.search_strategy.apply(
data, search, normalized_options, fuzzy_search
)
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
"""Apply model-specific filters - to be overridden by subclasses if needed""" """Apply model-specific filters - to be overridden by subclasses if needed"""
return data return data
async def _apply_credit_required_filter(self, data: List[Dict], credit_required: bool) -> List[Dict]: async def _apply_credit_required_filter(
self, data: List[Dict], credit_required: bool
) -> List[Dict]:
"""Apply credit required filtering based on license_flags. """Apply credit required filtering based on license_flags.
Args: Args:
data: List of model data items data: List of model data items
credit_required: credit_required:
- True: Return items where credit is required (allowNoCredit=False) - True: Return items where credit is required (allowNoCredit=False)
- False: Return items where credit is not required (allowNoCredit=True) - False: Return items where credit is not required (allowNoCredit=True)
""" """
filtered_data = [] filtered_data = []
for item in data: for item in data:
license_flags = item.get("license_flags", 127) # Default to all permissions enabled license_flags = item.get(
"license_flags", 127
) # Default to all permissions enabled
# Bit 0 represents allowNoCredit (1 = no credit required, 0 = credit required) # Bit 0 represents allowNoCredit (1 = no credit required, 0 = credit required)
allow_no_credit = bool(license_flags & (1 << 0)) allow_no_credit = bool(license_flags & (1 << 0))
# If credit_required is True, we want items where allowNoCredit is False (credit required) # If credit_required is True, we want items where allowNoCredit is False (credit required)
# If credit_required is False, we want items where allowNoCredit is True (no credit required) # If credit_required is False, we want items where allowNoCredit is True (no credit required)
if credit_required: if credit_required:
@@ -275,26 +297,30 @@ class BaseModelService(ABC):
else: else:
if allow_no_credit: # Credit is not required if allow_no_credit: # Credit is not required
filtered_data.append(item) filtered_data.append(item)
return filtered_data return filtered_data
async def _apply_allow_selling_filter(self, data: List[Dict], allow_selling: bool) -> List[Dict]: async def _apply_allow_selling_filter(
self, data: List[Dict], allow_selling: bool
) -> List[Dict]:
"""Apply allow selling generated content filtering based on license_flags. """Apply allow selling generated content filtering based on license_flags.
Args: Args:
data: List of model data items data: List of model data items
allow_selling: allow_selling:
- True: Return items where selling generated content is allowed (allowCommercialUse contains Image) - True: Return items where selling generated content is allowed (allowCommercialUse contains Image)
- False: Return items where selling generated content is not allowed (allowCommercialUse does not contain Image) - False: Return items where selling generated content is not allowed (allowCommercialUse does not contain Image)
""" """
filtered_data = [] filtered_data = []
for item in data: for item in data:
license_flags = item.get("license_flags", 127) # Default to all permissions enabled license_flags = item.get(
"license_flags", 127
) # Default to all permissions enabled
# Bits 1-4 represent commercial use permissions # Bits 1-4 represent commercial use permissions
# Bit 1 specifically represents Image permission (allowCommercialUse contains Image) # Bit 1 specifically represents Image permission (allowCommercialUse contains Image)
has_image_permission = bool(license_flags & (1 << 1)) has_image_permission = bool(license_flags & (1 << 1))
# If allow_selling is True, we want items where Image permission is granted # If allow_selling is True, we want items where Image permission is granted
# If allow_selling is False, we want items where Image permission is not granted # If allow_selling is False, we want items where Image permission is not granted
if allow_selling: if allow_selling:
@@ -303,7 +329,7 @@ class BaseModelService(ABC):
else: else:
if not has_image_permission: # Selling generated content is not allowed if not has_image_permission: # Selling generated content is not allowed
filtered_data.append(item) filtered_data.append(item)
return filtered_data return filtered_data
async def _annotate_update_flags( async def _annotate_update_flags(
@@ -321,7 +347,7 @@ class BaseModelService(ABC):
if self.update_service is None: if self.update_service is None:
for item in annotated: for item in annotated:
item['update_available'] = False item["update_available"] = False
return annotated return annotated
id_to_items: Dict[int, List[Dict]] = {} id_to_items: Dict[int, List[Dict]] = {}
@@ -329,7 +355,7 @@ class BaseModelService(ABC):
for item in annotated: for item in annotated:
model_id = self._extract_model_id(item) model_id = self._extract_model_id(item)
if model_id is None: if model_id is None:
item['update_available'] = False item["update_available"] = False
continue continue
if model_id not in id_to_items: if model_id not in id_to_items:
id_to_items[model_id] = [] id_to_items[model_id] = []
@@ -405,13 +431,19 @@ class BaseModelService(ABC):
default_flag = bool(resolved.get(model_id, False)) if resolved else False default_flag = bool(resolved.get(model_id, False)) if resolved else False
record = records.get(model_id) if records else None record = records.get(model_id) if records else None
base_highest_versions = ( base_highest_versions = (
self._build_highest_local_versions_by_base(record) if same_base_mode and record else {} self._build_highest_local_versions_by_base(record)
if same_base_mode and record
else {}
) )
for item in items_for_id: for item in items_for_id:
if same_base_mode and record is not None: if same_base_mode and record is not None:
base_model = self._extract_base_model(item) base_model = self._extract_base_model(item)
normalized_base = self._normalize_base_model_name(base_model) normalized_base = self._normalize_base_model_name(base_model)
threshold_version = base_highest_versions.get(normalized_base) if normalized_base else None threshold_version = (
base_highest_versions.get(normalized_base)
if normalized_base
else None
)
if threshold_version is None: if threshold_version is None:
threshold_version = self._extract_version_id(item) threshold_version = self._extract_version_id(item)
flag = record.has_update_for_base( flag = record.has_update_for_base(
@@ -420,17 +452,17 @@ class BaseModelService(ABC):
) )
else: else:
flag = default_flag flag = default_flag
item['update_available'] = flag item["update_available"] = flag
return annotated return annotated
@staticmethod @staticmethod
def _extract_model_id(item: Dict) -> Optional[int]: def _extract_model_id(item: Dict) -> Optional[int]:
civitai = item.get('civitai') if isinstance(item, dict) else None civitai = item.get("civitai") if isinstance(item, dict) else None
if not isinstance(civitai, dict): if not isinstance(civitai, dict):
return None return None
try: try:
value = civitai.get('modelId') value = civitai.get("modelId")
if value is None: if value is None:
return None return None
return int(value) return int(value)
@@ -439,10 +471,10 @@ class BaseModelService(ABC):
@staticmethod @staticmethod
def _extract_version_id(item: Dict) -> Optional[int]: def _extract_version_id(item: Dict) -> Optional[int]:
civitai = item.get('civitai') if isinstance(item, dict) else None civitai = item.get("civitai") if isinstance(item, dict) else None
if not isinstance(civitai, dict): if not isinstance(civitai, dict):
return None return None
value = civitai.get('id') value = civitai.get("id")
if value is None: if value is None:
return None return None
try: try:
@@ -452,7 +484,7 @@ class BaseModelService(ABC):
@staticmethod @staticmethod
def _extract_base_model(item: Dict) -> Optional[str]: def _extract_base_model(item: Dict) -> Optional[str]:
value = item.get('base_model') value = item.get("base_model")
if value is None: if value is None:
return None return None
if isinstance(value, str): if isinstance(value, str):
@@ -489,7 +521,9 @@ class BaseModelService(ABC):
for version in getattr(record, "versions", []): for version in getattr(record, "versions", []):
if not getattr(version, "is_in_library", False): if not getattr(version, "is_in_library", False):
continue continue
normalized_base = self._normalize_base_model_name(getattr(version, "base_model", None)) normalized_base = self._normalize_base_model_name(
getattr(version, "base_model", None)
)
if normalized_base is None: if normalized_base is None:
continue continue
version_id = getattr(version, "version_id", None) version_id = getattr(version, "version_id", None)
@@ -506,25 +540,25 @@ class BaseModelService(ABC):
total_items = len(data) total_items = len(data)
start_idx = (page - 1) * page_size start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items) end_idx = min(start_idx + page_size, total_items)
return { return {
'items': data[start_idx:end_idx], "items": data[start_idx:end_idx],
'total': total_items, "total": total_items,
'page': page, "page": page,
'page_size': page_size, "page_size": page_size,
'total_pages': (total_items + page_size - 1) // page_size "total_pages": (total_items + page_size - 1) // page_size,
} }
@abstractmethod @abstractmethod
async def format_response(self, model_data: Dict) -> Dict: async def format_response(self, model_data: Dict) -> Dict:
"""Format model data for API response - must be implemented by subclasses""" """Format model data for API response - must be implemented by subclasses"""
pass pass
# Common service methods that delegate to scanner # Common service methods that delegate to scanner
async def get_top_tags(self, limit: int = 20) -> List[Dict]: async def get_top_tags(self, limit: int = 20) -> List[Dict]:
"""Get top tags sorted by frequency""" """Get top tags sorted by frequency"""
return await self.scanner.get_top_tags(limit) return await self.scanner.get_top_tags(limit)
async def get_base_models(self, limit: int = 20) -> List[Dict]: async def get_base_models(self, limit: int = 20) -> List[Dict]:
"""Get base models sorted by frequency""" """Get base models sorted by frequency"""
return await self.scanner.get_base_models(limit) return await self.scanner.get_base_models(limit)
@@ -535,62 +569,85 @@ class BaseModelService(ABC):
type_counts: Dict[str, int] = {} type_counts: Dict[str, int] = {}
for entry in cache.raw_data: for entry in cache.raw_data:
normalized_type = normalize_civitai_model_type(resolve_civitai_model_type(entry)) normalized_type = normalize_civitai_model_type(
resolve_civitai_model_type(entry)
)
if not normalized_type or normalized_type not in VALID_LORA_TYPES: if not normalized_type or normalized_type not in VALID_LORA_TYPES:
continue continue
type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1 type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1
sorted_types = sorted( sorted_types = sorted(
[{"type": model_type, "count": count} for model_type, count in type_counts.items()], [
{"type": model_type, "count": count}
for model_type, count in type_counts.items()
],
key=lambda value: value["count"], key=lambda value: value["count"],
reverse=True, reverse=True,
) )
return sorted_types[:limit] return sorted_types[:limit]
def has_hash(self, sha256: str) -> bool: def has_hash(self, sha256: str) -> bool:
"""Check if a model with given hash exists""" """Check if a model with given hash exists"""
return self.scanner.has_hash(sha256) return self.scanner.has_hash(sha256)
def get_path_by_hash(self, sha256: str) -> Optional[str]: def get_path_by_hash(self, sha256: str) -> Optional[str]:
"""Get file path for a model by its hash""" """Get file path for a model by its hash"""
return self.scanner.get_path_by_hash(sha256) return self.scanner.get_path_by_hash(sha256)
def get_hash_by_path(self, file_path: str) -> Optional[str]: def get_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a model by its file path""" """Get hash for a model by its file path"""
return self.scanner.get_hash_by_path(file_path) return self.scanner.get_hash_by_path(file_path)
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False): async def scan_models(
self, force_refresh: bool = False, rebuild_cache: bool = False
):
"""Trigger model scanning""" """Trigger model scanning"""
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache) return await self.scanner.get_cached_data(
force_refresh=force_refresh, rebuild_cache=rebuild_cache
)
async def get_model_info_by_name(self, name: str): async def get_model_info_by_name(self, name: str):
"""Get model information by name""" """Get model information by name"""
return await self.scanner.get_model_info_by_name(name) return await self.scanner.get_model_info_by_name(name)
def get_model_roots(self) -> List[str]: def get_model_roots(self) -> List[str]:
"""Get model root directories""" """Get model root directories"""
return self.scanner.get_model_roots() return self.scanner.get_model_roots()
def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict: def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict:
"""Filter relevant fields from CivitAI data""" """Filter relevant fields from CivitAI data"""
if not data: if not data:
return {} return {}
fields = ["id", "modelId", "name", "trainedWords"] if minimal else [ fields = (
"id", "modelId", "name", "createdAt", "updatedAt", ["id", "modelId", "name", "trainedWords"]
"publishedAt", "trainedWords", "baseModel", "description", if minimal
"model", "images", "customImages", "creator" else [
] "id",
"modelId",
"name",
"createdAt",
"updatedAt",
"publishedAt",
"trainedWords",
"baseModel",
"description",
"model",
"images",
"customImages",
"creator",
]
)
return {k: data[k] for k in fields if k in data} return {k: data[k] for k in fields if k in data}
async def get_folder_tree(self, model_root: str) -> Dict: async def get_folder_tree(self, model_root: str) -> Dict:
"""Get hierarchical folder tree for a specific model root""" """Get hierarchical folder tree for a specific model root"""
cache = await self.scanner.get_cached_data() cache = await self.scanner.get_cached_data()
# Build tree structure from folders # Build tree structure from folders
tree = {} tree = {}
for folder in cache.folders: for folder in cache.folders:
# Check if this folder belongs to the specified model root # Check if this folder belongs to the specified model root
folder_belongs_to_root = False folder_belongs_to_root = False
@@ -598,95 +655,96 @@ class BaseModelService(ABC):
if root == model_root: if root == model_root:
folder_belongs_to_root = True folder_belongs_to_root = True
break break
if not folder_belongs_to_root: if not folder_belongs_to_root:
continue continue
# Split folder path into components # Split folder path into components
parts = folder.split('/') if folder else [] parts = folder.split("/") if folder else []
current_level = tree current_level = tree
for part in parts: for part in parts:
if part not in current_level: if part not in current_level:
current_level[part] = {} current_level[part] = {}
current_level = current_level[part] current_level = current_level[part]
return tree return tree
async def get_unified_folder_tree(self) -> Dict: async def get_unified_folder_tree(self) -> Dict:
"""Get unified folder tree across all model roots""" """Get unified folder tree across all model roots"""
cache = await self.scanner.get_cached_data() cache = await self.scanner.get_cached_data()
# Build unified tree structure by analyzing all relative paths # Build unified tree structure by analyzing all relative paths
unified_tree = {} unified_tree = {}
# Get all model roots for path normalization # Get all model roots for path normalization
model_roots = self.scanner.get_model_roots() model_roots = self.scanner.get_model_roots()
for folder in cache.folders: for folder in cache.folders:
if not folder: # Skip empty folders if not folder: # Skip empty folders
continue continue
# Find which root this folder belongs to by checking the actual file paths # Find which root this folder belongs to by checking the actual file paths
# This is a simplified approach - we'll use the folder as-is since it should already be relative # This is a simplified approach - we'll use the folder as-is since it should already be relative
relative_path = folder relative_path = folder
# Split folder path into components # Split folder path into components
parts = relative_path.split('/') parts = relative_path.split("/")
current_level = unified_tree current_level = unified_tree
for part in parts: for part in parts:
if part not in current_level: if part not in current_level:
current_level[part] = {} current_level[part] = {}
current_level = current_level[part] current_level = current_level[part]
return unified_tree return unified_tree
async def get_model_notes(self, model_name: str) -> Optional[str]: async def get_model_notes(self, model_name: str) -> Optional[str]:
"""Get notes for a specific model file""" """Get notes for a specific model file"""
cache = await self.scanner.get_cached_data() cache = await self.scanner.get_cached_data()
for model in cache.raw_data: for model in cache.raw_data:
if model['file_name'] == model_name: if model["file_name"] == model_name:
return model.get('notes', '') return model.get("notes", "")
return None return None
async def get_model_preview_url(self, model_name: str) -> Optional[str]: async def get_model_preview_url(self, model_name: str) -> Optional[str]:
"""Get the static preview URL for a model file""" """Get the static preview URL for a model file"""
cache = await self.scanner.get_cached_data() cache = await self.scanner.get_cached_data()
for model in cache.raw_data: for model in cache.raw_data:
if model['file_name'] == model_name: if model["file_name"] == model_name:
preview_url = model.get('preview_url') preview_url = model.get("preview_url")
if preview_url: if preview_url:
from ..config import config from ..config import config
return config.get_preview_static_url(preview_url) return config.get_preview_static_url(preview_url)
return '/loras_static/images/no-preview.png' return "/loras_static/images/no-preview.png"
async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]: async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]:
"""Get the Civitai URL for a model file""" """Get the Civitai URL for a model file"""
cache = await self.scanner.get_cached_data() cache = await self.scanner.get_cached_data()
for model in cache.raw_data: for model in cache.raw_data:
if model['file_name'] == model_name: if model["file_name"] == model_name:
civitai_data = model.get('civitai', {}) civitai_data = model.get("civitai", {})
model_id = civitai_data.get('modelId') model_id = civitai_data.get("modelId")
version_id = civitai_data.get('id') version_id = civitai_data.get("id")
if model_id: if model_id:
civitai_url = f"https://civitai.com/models/{model_id}" civitai_url = f"https://civitai.com/models/{model_id}"
if version_id: if version_id:
civitai_url += f"?modelVersionId={version_id}" civitai_url += f"?modelVersionId={version_id}"
return { return {
'civitai_url': civitai_url, "civitai_url": civitai_url,
'model_id': str(model_id), "model_id": str(model_id),
'version_id': str(version_id) if version_id else None "version_id": str(version_id) if version_id else None,
} }
return {'civitai_url': None, 'model_id': None, 'version_id': None} return {"civitai_url": None, "model_id": None, "version_id": None}
async def get_model_metadata(self, file_path: str) -> Optional[Dict]: async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
"""Load full metadata for a single model. """Load full metadata for a single model.
@@ -694,18 +752,21 @@ class BaseModelService(ABC):
Listing/search endpoints return lightweight cache entries; this method performs Listing/search endpoints return lightweight cache entries; this method performs
a lazy read of the on-disk metadata snapshot when callers need full detail. a lazy read of the on-disk metadata snapshot when callers need full detail.
""" """
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.metadata_class) metadata, should_skip = await MetadataManager.load_metadata(
file_path, self.metadata_class
)
if should_skip or metadata is None: if should_skip or metadata is None:
return None return None
return self.filter_civitai_data(metadata.to_dict().get("civitai", {})) return self.filter_civitai_data(metadata.to_dict().get("civitai", {}))
async def get_model_description(self, file_path: str) -> Optional[str]: async def get_model_description(self, file_path: str) -> Optional[str]:
"""Return the stored modelDescription field for a model.""" """Return the stored modelDescription field for a model."""
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.metadata_class) metadata, should_skip = await MetadataManager.load_metadata(
file_path, self.metadata_class
)
if should_skip or metadata is None: if should_skip or metadata is None:
return None return None
return metadata.modelDescription or '' return metadata.modelDescription or ""
@staticmethod @staticmethod
def _parse_search_tokens(search_term: str) -> tuple[List[str], List[str]]: def _parse_search_tokens(search_term: str) -> tuple[List[str], List[str]]:
@@ -743,53 +804,64 @@ class BaseModelService(ABC):
def _relative_path_sort_key(relative_path: str, include_terms: List[str]) -> tuple: def _relative_path_sort_key(relative_path: str, include_terms: List[str]) -> tuple:
"""Sort paths by how well they satisfy the include tokens.""" """Sort paths by how well they satisfy the include tokens."""
path_lower = relative_path.lower() path_lower = relative_path.lower()
prefix_hits = sum(1 for term in include_terms if term and path_lower.startswith(term)) prefix_hits = sum(
match_positions = [path_lower.find(term) for term in include_terms if term and term in path_lower] 1 for term in include_terms if term and path_lower.startswith(term)
)
match_positions = [
path_lower.find(term)
for term in include_terms
if term and term in path_lower
]
first_match_index = min(match_positions) if match_positions else 0 first_match_index = min(match_positions) if match_positions else 0
return (-prefix_hits, first_match_index, len(relative_path), path_lower) return (-prefix_hits, first_match_index, len(relative_path), path_lower)
async def search_relative_paths(
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]: self, search_term: str, limit: int = 15
) -> List[str]:
"""Search model relative file paths for autocomplete functionality""" """Search model relative file paths for autocomplete functionality"""
cache = await self.scanner.get_cached_data() cache = await self.scanner.get_cached_data()
include_terms, exclude_terms = self._parse_search_tokens(search_term) include_terms, exclude_terms = self._parse_search_tokens(search_term)
matching_paths = [] matching_paths = []
# Get model roots for path calculation # Get model roots for path calculation
model_roots = self.scanner.get_model_roots() model_roots = self.scanner.get_model_roots()
for model in cache.raw_data: for model in cache.raw_data:
file_path = model.get('file_path', '') file_path = model.get("file_path", "")
if not file_path: if not file_path:
continue continue
# Calculate relative path from model root # Calculate relative path from model root
relative_path = None relative_path = None
for root in model_roots: for root in model_roots:
# Normalize paths for comparison # Normalize paths for comparison
normalized_root = os.path.normpath(root) normalized_root = os.path.normpath(root)
normalized_file = os.path.normpath(file_path) normalized_file = os.path.normpath(file_path)
if normalized_file.startswith(normalized_root): if normalized_file.startswith(normalized_root):
# Remove root and leading separator to get relative path # Remove root and leading separator to get relative path
relative_path = normalized_file[len(normalized_root):].lstrip(os.sep) relative_path = normalized_file[len(normalized_root) :].lstrip(
os.sep
)
break break
if not relative_path: if not relative_path:
continue continue
relative_lower = relative_path.lower() relative_lower = relative_path.lower()
if self._relative_path_matches_tokens(relative_lower, include_terms, exclude_terms): if self._relative_path_matches_tokens(
relative_lower, include_terms, exclude_terms
):
matching_paths.append(relative_path) matching_paths.append(relative_path)
if len(matching_paths) >= limit * 2: # Get more for better sorting if len(matching_paths) >= limit * 2: # Get more for better sorting
break break
# Sort by relevance (prefix and earliest hits first, then by length and alphabetically) # Sort by relevance (prefix and earliest hits first, then by length and alphabetically)
matching_paths.sort( matching_paths.sort(
key=lambda relative: self._relative_path_sort_key(relative, include_terms) key=lambda relative: self._relative_path_sort_key(relative, include_terms)
) )
return matching_paths[:limit] return matching_paths[:limit]

View File

@@ -1,7 +1,18 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Protocol, Callable from typing import (
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Protocol,
Callable,
)
from ..utils.constants import NSFW_LEVELS from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match as default_fuzzy_match from ..utils.utils import fuzzy_match as default_fuzzy_match
@@ -51,8 +62,7 @@ def resolve_civitai_model_type(entry: Mapping[str, Any]) -> str:
class SettingsProvider(Protocol): class SettingsProvider(Protocol):
"""Protocol describing the SettingsManager contract used by query helpers.""" """Protocol describing the SettingsManager contract used by query helpers."""
def get(self, key: str, default: Any = None) -> Any: def get(self, key: str, default: Any = None) -> Any: ...
...
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -68,6 +78,7 @@ class FilterCriteria:
"""Container for model list filtering options.""" """Container for model list filtering options."""
folder: Optional[str] = None folder: Optional[str] = None
folder_exclude: Optional[Sequence[str]] = None
base_models: Optional[Sequence[str]] = None base_models: Optional[Sequence[str]] = None
tags: Optional[Dict[str, str]] = None tags: Optional[Dict[str, str]] = None
favorites_only: bool = False favorites_only: bool = False
@@ -113,11 +124,15 @@ class ModelCacheRepository:
class ModelFilterSet: class ModelFilterSet:
"""Applies common filtering rules to the model collection.""" """Applies common filtering rules to the model collection."""
def __init__(self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None) -> None: def __init__(
self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None
) -> None:
self._settings = settings self._settings = settings
self._nsfw_levels = nsfw_levels or NSFW_LEVELS self._nsfw_levels = nsfw_levels or NSFW_LEVELS
def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]: def apply(
self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria
) -> List[Dict[str, Any]]:
"""Return items that satisfy the provided criteria.""" """Return items that satisfy the provided criteria."""
overall_start = time.perf_counter() overall_start = time.perf_counter()
items = list(data) items = list(data)
@@ -127,8 +142,10 @@ class ModelFilterSet:
t0 = time.perf_counter() t0 = time.perf_counter()
threshold = self._nsfw_levels.get("R", 0) threshold = self._nsfw_levels.get("R", 0)
items = [ items = [
item for item in items item
if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold for item in items
if not item.get("preview_nsfw_level")
or item.get("preview_nsfw_level") < threshold
] ]
sfw_duration = time.perf_counter() - t0 sfw_duration = time.perf_counter() - t0
else: else:
@@ -142,20 +159,44 @@ class ModelFilterSet:
folder_duration = 0 folder_duration = 0
folder = criteria.folder folder = criteria.folder
folder_exclude = criteria.folder_exclude or []
options = criteria.search_options or {} options = criteria.search_options or {}
recursive = bool(options.get("recursive", True)) recursive = bool(options.get("recursive", True))
# Apply folder exclude filters first
if folder_exclude:
t0 = time.perf_counter()
for exclude_folder in folder_exclude:
if exclude_folder:
# Check exact match OR prefix match (for subfolders)
# Normalize exclude_folder for prefix matching
if not exclude_folder.endswith("/"):
exclude_prefix = f"{exclude_folder}/"
else:
exclude_prefix = exclude_folder
items = [
item
for item in items
if item.get("folder") != exclude_folder
and not item.get("folder", "").startswith(exclude_prefix)
]
folder_duration = time.perf_counter() - t0
# Apply folder include filters
if folder is not None: if folder is not None:
t0 = time.perf_counter() t0 = time.perf_counter()
if recursive: if recursive:
if folder: if folder:
folder_with_sep = f"{folder}/" folder_with_sep = f"{folder}/"
items = [ items = [
item for item in items item
if item.get("folder") == folder or item.get("folder", "").startswith(folder_with_sep) for item in items
if item.get("folder") == folder
or item.get("folder", "").startswith(folder_with_sep)
] ]
else: else:
items = [item for item in items if item.get("folder") == folder] items = [item for item in items if item.get("folder") == folder]
folder_duration = time.perf_counter() - t0 folder_duration = time.perf_counter() - t0 + folder_duration
base_models_duration = 0 base_models_duration = 0
base_models = criteria.base_models or [] base_models = criteria.base_models or []
@@ -183,25 +224,23 @@ class ModelFilterSet:
include_tags = {tag for tag in tag_filters if tag} include_tags = {tag for tag in tag_filters if tag}
if include_tags: if include_tags:
def matches_include(item_tags): def matches_include(item_tags):
if not item_tags and "__no_tags__" in include_tags: if not item_tags and "__no_tags__" in include_tags:
return True return True
return any(tag in include_tags for tag in (item_tags or [])) return any(tag in include_tags for tag in (item_tags or []))
items = [ items = [item for item in items if matches_include(item.get("tags"))]
item for item in items
if matches_include(item.get("tags"))
]
if exclude_tags: if exclude_tags:
def matches_exclude(item_tags): def matches_exclude(item_tags):
if not item_tags and "__no_tags__" in exclude_tags: if not item_tags and "__no_tags__" in exclude_tags:
return True return True
return any(tag in exclude_tags for tag in (item_tags or [])) return any(tag in exclude_tags for tag in (item_tags or []))
items = [ items = [
item for item in items item for item in items if not matches_exclude(item.get("tags"))
if not matches_exclude(item.get("tags"))
] ]
tags_duration = time.perf_counter() - t0 tags_duration = time.perf_counter() - t0
@@ -210,26 +249,35 @@ class ModelFilterSet:
if model_types: if model_types:
t0 = time.perf_counter() t0 = time.perf_counter()
normalized_model_types = { normalized_model_types = {
model_type for model_type in ( model_type
for model_type in (
normalize_civitai_model_type(value) for value in model_types normalize_civitai_model_type(value) for value in model_types
) )
if model_type if model_type
} }
if normalized_model_types: if normalized_model_types:
items = [ items = [
item for item in items item
if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types for item in items
if normalize_civitai_model_type(resolve_civitai_model_type(item))
in normalized_model_types
] ]
model_types_duration = time.perf_counter() - t0 model_types_duration = time.perf_counter() - t0
duration = time.perf_counter() - overall_start duration = time.perf_counter() - overall_start
if duration > 0.1: # Only log if it's potentially slow if duration > 0.1: # Only log if it's potentially slow
logger.debug( logger.debug(
"ModelFilterSet.apply took %.3fs (sfw: %.3fs, fav: %.3fs, folder: %.3fs, base: %.3fs, tags: %.3fs, types: %.3fs). " "ModelFilterSet.apply took %.3fs (sfw: %.3fs, fav: %.3fs, folder: %.3fs, base: %.3fs, tags: %.3fs, types: %.3fs). "
"Count: %d -> %d", "Count: %d -> %d",
duration, sfw_duration, favorites_duration, folder_duration, duration,
base_models_duration, tags_duration, model_types_duration, sfw_duration,
initial_count, len(items) favorites_duration,
folder_duration,
base_models_duration,
tags_duration,
model_types_duration,
initial_count,
len(items),
) )
return items return items
@@ -245,7 +293,9 @@ class SearchStrategy:
"creator": False, "creator": False,
} }
def __init__(self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None) -> None: def __init__(
self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None
) -> None:
self._fuzzy_match = fuzzy_matcher or default_fuzzy_match self._fuzzy_match = fuzzy_matcher or default_fuzzy_match
def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]: def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]:
@@ -284,7 +334,9 @@ class SearchStrategy:
if options.get("tags", False): if options.get("tags", False):
tags = item.get("tags", []) or [] tags = item.get("tags", []) or []
if any(self._matches(tag, search_term, search_lower, fuzzy) for tag in tags): if any(
self._matches(tag, search_term, search_lower, fuzzy) for tag in tags
):
results.append(item) results.append(item)
continue continue
@@ -295,13 +347,17 @@ class SearchStrategy:
creator = civitai.get("creator") creator = civitai.get("creator")
if isinstance(creator, dict): if isinstance(creator, dict):
creator_username = creator.get("username", "") creator_username = creator.get("username", "")
if creator_username and self._matches(creator_username, search_term, search_lower, fuzzy): if creator_username and self._matches(
creator_username, search_term, search_lower, fuzzy
):
results.append(item) results.append(item)
continue continue
return results return results
def _matches(self, candidate: str, search_term: str, search_lower: str, fuzzy: bool) -> bool: def _matches(
self, candidate: str, search_term: str, search_lower: str, fuzzy: bool
) -> bool:
if not isinstance(candidate, str): if not isinstance(candidate, str):
candidate = "" if candidate is None else str(candidate) candidate = "" if candidate is None else str(candidate)

View File

@@ -3,13 +3,16 @@ from py.services.model_query import ModelFilterSet, FilterCriteria
from py.services.recipe_scanner import RecipeScanner from py.services.recipe_scanner import RecipeScanner
from types import SimpleNamespace from types import SimpleNamespace
# Mock settings # Mock settings
class MockSettings: class MockSettings:
def get(self, key, default=None): def get(self, key, default=None):
return default return default
# --- Model Filtering Tests --- # --- Model Filtering Tests ---
def test_model_filter_set_root_recursive_true(): def test_model_filter_set_root_recursive_true():
filter_set = ModelFilterSet(MockSettings()) filter_set = ModelFilterSet(MockSettings())
items = [ items = [
@@ -17,13 +20,14 @@ def test_model_filter_set_root_recursive_true():
{"model_name": "sub_item", "folder": "sub"}, {"model_name": "sub_item", "folder": "sub"},
] ]
criteria = FilterCriteria(folder="", search_options={"recursive": True}) criteria = FilterCriteria(folder="", search_options={"recursive": True})
result = filter_set.apply(items, criteria) result = filter_set.apply(items, criteria)
assert len(result) == 2 assert len(result) == 2
assert any(i["model_name"] == "root_item" for i in result) assert any(i["model_name"] == "root_item" for i in result)
assert any(i["model_name"] == "sub_item" for i in result) assert any(i["model_name"] == "sub_item" for i in result)
def test_model_filter_set_root_recursive_false(): def test_model_filter_set_root_recursive_false():
filter_set = ModelFilterSet(MockSettings()) filter_set = ModelFilterSet(MockSettings())
items = [ items = [
@@ -31,62 +35,185 @@ def test_model_filter_set_root_recursive_false():
{"model_name": "sub_item", "folder": "sub"}, {"model_name": "sub_item", "folder": "sub"},
] ]
criteria = FilterCriteria(folder="", search_options={"recursive": False}) criteria = FilterCriteria(folder="", search_options={"recursive": False})
result = filter_set.apply(items, criteria) result = filter_set.apply(items, criteria)
assert len(result) == 1 assert len(result) == 1
assert result[0]["model_name"] == "root_item" assert result[0]["model_name"] == "root_item"
def test_model_filter_set_folder_exclude_single():
filter_set = ModelFilterSet(MockSettings())
items = [
{"model_name": "item1", "folder": "characters/"},
{"model_name": "item2", "folder": "styles/"},
{"model_name": "item3", "folder": "characters/anime/"},
{"model_name": "item4", "folder": ""},
]
criteria = FilterCriteria(
folder_exclude=["characters/"], search_options={"recursive": True}
)
result = filter_set.apply(items, criteria)
assert len(result) == 2
model_names = {i["model_name"] for i in result}
assert model_names == {"item2", "item4"}
def test_model_filter_set_folder_exclude_multiple():
filter_set = ModelFilterSet(MockSettings())
items = [
{"model_name": "item1", "folder": "characters/"},
{"model_name": "item2", "folder": "styles/"},
{"model_name": "item3", "folder": "concepts/"},
{"model_name": "item4", "folder": "characters/anime/"},
{"model_name": "item5", "folder": ""},
]
criteria = FilterCriteria(
folder_exclude=["characters/", "styles/"], search_options={"recursive": True}
)
result = filter_set.apply(items, criteria)
assert len(result) == 2
model_names = {i["model_name"] for i in result}
assert model_names == {"item3", "item5"}
def test_model_filter_set_folder_exclude_with_include():
filter_set = ModelFilterSet(MockSettings())
items = [
{"model_name": "item1", "folder": "characters/"},
{"model_name": "item2", "folder": "styles/"},
{"model_name": "item3", "folder": "characters/anime/"},
{"model_name": "item4", "folder": "styles/painting/"},
{"model_name": "item5", "folder": "concepts/"},
]
criteria = FilterCriteria(
folder="characters/",
folder_exclude=["characters/anime/"],
search_options={"recursive": True},
)
result = filter_set.apply(items, criteria)
assert len(result) == 1
assert result[0]["model_name"] == "item1"
# --- Recipe Filtering Tests --- # --- Recipe Filtering Tests ---
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_recipe_scanner_root_recursive_true(): async def test_recipe_scanner_root_recursive_true():
# Mock LoraScanner # Mock LoraScanner
class StubLoraScanner: class StubLoraScanner:
async def get_cached_data(self): async def get_cached_data(self):
return SimpleNamespace(raw_data=[]) return SimpleNamespace(raw_data=[])
scanner = RecipeScanner(lora_scanner=StubLoraScanner()) scanner = RecipeScanner(lora_scanner=StubLoraScanner())
# Manually populate cache for testing get_paginated_data logic # Manually populate cache for testing get_paginated_data logic
scanner._cache = SimpleNamespace( scanner._cache = SimpleNamespace(
raw_data=[ raw_data=[
{"id": "r1", "title": "root_recipe", "folder": "", "modified": 1.0, "created_date": 1.0, "loras": []}, {
{"id": "r2", "title": "sub_recipe", "folder": "sub", "modified": 2.0, "created_date": 2.0, "loras": []}, "id": "r1",
"title": "root_recipe",
"folder": "",
"modified": 1.0,
"created_date": 1.0,
"loras": [],
},
{
"id": "r2",
"title": "sub_recipe",
"folder": "sub",
"modified": 2.0,
"created_date": 2.0,
"loras": [],
},
], ],
sorted_by_date=[ sorted_by_date=[
{"id": "r2", "title": "sub_recipe", "folder": "sub", "modified": 2.0, "created_date": 2.0, "loras": []}, {
{"id": "r1", "title": "root_recipe", "folder": "", "modified": 1.0, "created_date": 1.0, "loras": []}, "id": "r2",
"title": "sub_recipe",
"folder": "sub",
"modified": 2.0,
"created_date": 2.0,
"loras": [],
},
{
"id": "r1",
"title": "root_recipe",
"folder": "",
"modified": 1.0,
"created_date": 1.0,
"loras": [],
},
], ],
sorted_by_name=[], sorted_by_name=[],
version_index={} version_index={},
) )
result = await scanner.get_paginated_data(page=1, page_size=10, folder="", recursive=True) result = await scanner.get_paginated_data(
page=1, page_size=10, folder="", recursive=True
)
assert len(result["items"]) == 2 assert len(result["items"]) == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_recipe_scanner_root_recursive_false(): async def test_recipe_scanner_root_recursive_false():
# Mock LoraScanner # Mock LoraScanner
class StubLoraScanner: class StubLoraScanner:
async def get_cached_data(self): async def get_cached_data(self):
return SimpleNamespace(raw_data=[]) return SimpleNamespace(raw_data=[])
scanner = RecipeScanner(lora_scanner=StubLoraScanner()) scanner = RecipeScanner(lora_scanner=StubLoraScanner())
scanner._cache = SimpleNamespace( scanner._cache = SimpleNamespace(
raw_data=[ raw_data=[
{"id": "r1", "title": "root_recipe", "folder": "", "modified": 1.0, "created_date": 1.0, "loras": []}, {
{"id": "r2", "title": "sub_recipe", "folder": "sub", "modified": 2.0, "created_date": 2.0, "loras": []}, "id": "r1",
"title": "root_recipe",
"folder": "",
"modified": 1.0,
"created_date": 1.0,
"loras": [],
},
{
"id": "r2",
"title": "sub_recipe",
"folder": "sub",
"modified": 2.0,
"created_date": 2.0,
"loras": [],
},
], ],
sorted_by_date=[ sorted_by_date=[
{"id": "r2", "title": "sub_recipe", "folder": "sub", "modified": 2.0, "created_date": 2.0, "loras": []}, {
{"id": "r1", "title": "root_recipe", "folder": "", "modified": 1.0, "created_date": 1.0, "loras": []}, "id": "r2",
"title": "sub_recipe",
"folder": "sub",
"modified": 2.0,
"created_date": 2.0,
"loras": [],
},
{
"id": "r1",
"title": "root_recipe",
"folder": "",
"modified": 1.0,
"created_date": 1.0,
"loras": [],
},
], ],
sorted_by_name=[], sorted_by_name=[],
version_index={} version_index={},
) )
result = await scanner.get_paginated_data(page=1, page_size=10, folder="", recursive=False) result = await scanner.get_paginated_data(
page=1, page_size=10, folder="", recursive=False
)
assert len(result["items"]) == 1 assert len(result["items"]) == 1
assert result["items"][0]["id"] == "r1" assert result["items"][0]["id"] == "r1"

View File

@@ -77,11 +77,12 @@ export function useLoraPoolApi() {
params.tagsInclude?.forEach(tag => urlParams.append('tag_include', tag)) params.tagsInclude?.forEach(tag => urlParams.append('tag_include', tag))
params.tagsExclude?.forEach(tag => urlParams.append('tag_exclude', tag)) params.tagsExclude?.forEach(tag => urlParams.append('tag_exclude', tag))
// For now, use first include folder (backend currently supports single folder) // Folder filters
if (params.foldersInclude && params.foldersInclude.length > 0) { if (params.foldersInclude && params.foldersInclude.length > 0) {
urlParams.set('folder', params.foldersInclude[0]) urlParams.set('folder', params.foldersInclude[0])
urlParams.set('recursive', 'true') urlParams.set('recursive', 'true')
} }
params.foldersExclude?.forEach(folder => urlParams.append('folder_exclude', folder))
if (params.noCreditRequired !== undefined) { if (params.noCreditRequired !== undefined) {
urlParams.set('credit_required', String(!params.noCreditRequired)) urlParams.set('credit_required', String(!params.noCreditRequired))

View File

@@ -10870,7 +10870,7 @@ function useLoraPoolApi() {
}); });
}; };
const fetchLoras = async (params) => { const fetchLoras = async (params) => {
var _a, _b, _c; var _a, _b, _c, _d;
isLoading.value = true; isLoading.value = true;
try { try {
const urlParams = new URLSearchParams(); const urlParams = new URLSearchParams();
@@ -10883,6 +10883,7 @@ function useLoraPoolApi() {
urlParams.set("folder", params.foldersInclude[0]); urlParams.set("folder", params.foldersInclude[0]);
urlParams.set("recursive", "true"); urlParams.set("recursive", "true");
} }
(_d = params.foldersExclude) == null ? void 0 : _d.forEach((folder) => urlParams.append("folder_exclude", folder));
if (params.noCreditRequired !== void 0) { if (params.noCreditRequired !== void 0) {
urlParams.set("credit_required", String(!params.noCreditRequired)); urlParams.set("credit_required", String(!params.noCreditRequired));
} }

File diff suppressed because one or more lines are too long