mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: improve code formatting and readability in model handlers
- Add blank line after module docstring for better PEP 8 compliance - Reformat long lines to adhere to 88-character limit using Black-style formatting - Improve string consistency by using double quotes consistently - Enhance readability of complex list comprehensions and method calls - Maintain all existing functionality while improving code structure
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -25,9 +25,10 @@ logger = logging.getLogger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from .model_update_service import ModelUpdateService
|
||||
|
||||
|
||||
class BaseModelService(ABC):
|
||||
"""Base service class for all model types"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type: str,
|
||||
@@ -60,13 +61,14 @@ class BaseModelService(ABC):
|
||||
self.filter_set = filter_set or ModelFilterSet(self.settings)
|
||||
self.search_strategy = search_strategy or SearchStrategy()
|
||||
self.update_service = update_service
|
||||
|
||||
|
||||
async def get_paginated_data(
|
||||
self,
|
||||
page: int,
|
||||
page_size: int,
|
||||
sort_by: str = 'name',
|
||||
sort_by: str = "name",
|
||||
folder: str = None,
|
||||
folder_exclude: list = None,
|
||||
search: str = None,
|
||||
fuzzy_search: bool = False,
|
||||
base_models: list = None,
|
||||
@@ -85,7 +87,7 @@ class BaseModelService(ABC):
|
||||
|
||||
sort_params = self.cache_repository.parse_sort(sort_by)
|
||||
t0 = time.perf_counter()
|
||||
if sort_params.key == 'usage':
|
||||
if sort_params.key == "usage":
|
||||
sorted_data = await self._fetch_with_usage_sort(sort_params)
|
||||
else:
|
||||
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
|
||||
@@ -99,6 +101,7 @@ class BaseModelService(ABC):
|
||||
filtered_data = await self._apply_common_filters(
|
||||
sorted_data,
|
||||
folder=folder,
|
||||
folder_exclude=folder_exclude,
|
||||
base_models=base_models,
|
||||
model_types=model_types,
|
||||
tags=tags,
|
||||
@@ -118,10 +121,14 @@ class BaseModelService(ABC):
|
||||
|
||||
# Apply license-based filters
|
||||
if credit_required is not None:
|
||||
filtered_data = await self._apply_credit_required_filter(filtered_data, credit_required)
|
||||
|
||||
filtered_data = await self._apply_credit_required_filter(
|
||||
filtered_data, credit_required
|
||||
)
|
||||
|
||||
if allow_selling_generated_content is not None:
|
||||
filtered_data = await self._apply_allow_selling_filter(filtered_data, allow_selling_generated_content)
|
||||
filtered_data = await self._apply_allow_selling_filter(
|
||||
filtered_data, allow_selling_generated_content
|
||||
)
|
||||
filter_duration = time.perf_counter() - t1
|
||||
post_filter_count = len(filtered_data)
|
||||
|
||||
@@ -130,8 +137,7 @@ class BaseModelService(ABC):
|
||||
if update_available_only:
|
||||
annotated_for_filter = await self._annotate_update_flags(filtered_data)
|
||||
filtered_data = [
|
||||
item for item in annotated_for_filter
|
||||
if item.get('update_available')
|
||||
item for item in annotated_for_filter if item.get("update_available")
|
||||
]
|
||||
update_filter_duration = time.perf_counter() - t2
|
||||
final_count = len(filtered_data)
|
||||
@@ -143,20 +149,27 @@ class BaseModelService(ABC):
|
||||
t4 = time.perf_counter()
|
||||
if update_available_only:
|
||||
# Items already include update flags thanks to the pre-filter annotation.
|
||||
paginated['items'] = list(paginated['items'])
|
||||
paginated["items"] = list(paginated["items"])
|
||||
else:
|
||||
paginated['items'] = await self._annotate_update_flags(
|
||||
paginated['items'],
|
||||
paginated["items"] = await self._annotate_update_flags(
|
||||
paginated["items"],
|
||||
)
|
||||
annotate_duration = time.perf_counter() - t4
|
||||
|
||||
|
||||
overall_duration = time.perf_counter() - overall_start
|
||||
logger.debug(
|
||||
"%s.get_paginated_data took %.3fs (fetch: %.3fs, filter: %.3fs, update_filter: %.3fs, pagination: %.3fs, annotate: %.3fs). "
|
||||
"Counts: initial=%d, post_filter=%d, final=%d",
|
||||
self.__class__.__name__, overall_duration, fetch_duration, filter_duration,
|
||||
update_filter_duration, pagination_duration, annotate_duration,
|
||||
initial_count, post_filter_count, final_count
|
||||
self.__class__.__name__,
|
||||
overall_duration,
|
||||
fetch_duration,
|
||||
filter_duration,
|
||||
update_filter_duration,
|
||||
pagination_duration,
|
||||
annotate_duration,
|
||||
initial_count,
|
||||
post_filter_count,
|
||||
final_count,
|
||||
)
|
||||
return paginated
|
||||
|
||||
@@ -167,11 +180,11 @@ class BaseModelService(ABC):
|
||||
|
||||
# Map model type to usage stats bucket
|
||||
bucket_map = {
|
||||
'lora': 'loras',
|
||||
'checkpoint': 'checkpoints',
|
||||
"lora": "loras",
|
||||
"checkpoint": "checkpoints",
|
||||
# 'embedding': 'embeddings', # TODO: Enable when embedding usage tracking is implemented
|
||||
}
|
||||
bucket_key = bucket_map.get(self.model_type, '')
|
||||
bucket_key = bucket_map.get(self.model_type, "")
|
||||
|
||||
usage_stats = UsageStats()
|
||||
stats = await usage_stats.get_stats()
|
||||
@@ -179,45 +192,47 @@ class BaseModelService(ABC):
|
||||
|
||||
annotated = []
|
||||
for item in raw_items:
|
||||
sha = (item.get('sha256') or '').lower()
|
||||
usage_info = usage_bucket.get(sha, {}) if isinstance(usage_bucket, dict) else {}
|
||||
usage_count = usage_info.get('total', 0) if isinstance(usage_info, dict) else 0
|
||||
annotated.append({**item, 'usage_count': usage_count})
|
||||
sha = (item.get("sha256") or "").lower()
|
||||
usage_info = (
|
||||
usage_bucket.get(sha, {}) if isinstance(usage_bucket, dict) else {}
|
||||
)
|
||||
usage_count = (
|
||||
usage_info.get("total", 0) if isinstance(usage_info, dict) else 0
|
||||
)
|
||||
annotated.append({**item, "usage_count": usage_count})
|
||||
|
||||
reverse = sort_params.order == 'desc'
|
||||
reverse = sort_params.order == "desc"
|
||||
annotated.sort(
|
||||
key=lambda x: (x.get('usage_count', 0), x.get('model_name', '').lower()),
|
||||
reverse=reverse
|
||||
key=lambda x: (x.get("usage_count", 0), x.get("model_name", "").lower()),
|
||||
reverse=reverse,
|
||||
)
|
||||
return annotated
|
||||
|
||||
|
||||
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
||||
async def _apply_hash_filters(
|
||||
self, data: List[Dict], hash_filters: Dict
|
||||
) -> List[Dict]:
|
||||
"""Apply hash-based filtering"""
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
single_hash = hash_filters.get("single_hash")
|
||||
multiple_hashes = hash_filters.get("multiple_hashes")
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower()
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() == single_hash
|
||||
item for item in data if item.get("sha256", "").lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes)
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
return [item for item in data if item.get("sha256", "").lower() in hash_set]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def _apply_common_filters(
|
||||
self,
|
||||
data: List[Dict],
|
||||
folder: str = None,
|
||||
folder_exclude: list = None,
|
||||
base_models: list = None,
|
||||
model_types: list = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
@@ -228,6 +243,7 @@ class BaseModelService(ABC):
|
||||
normalized_options = self.search_strategy.normalize_options(search_options)
|
||||
criteria = FilterCriteria(
|
||||
folder=folder,
|
||||
folder_exclude=folder_exclude,
|
||||
base_models=base_models,
|
||||
model_types=model_types,
|
||||
tags=tags,
|
||||
@@ -235,7 +251,7 @@ class BaseModelService(ABC):
|
||||
search_options=normalized_options,
|
||||
)
|
||||
return self.filter_set.apply(data, criteria)
|
||||
|
||||
|
||||
async def _apply_search_filters(
|
||||
self,
|
||||
data: List[Dict],
|
||||
@@ -245,28 +261,34 @@ class BaseModelService(ABC):
|
||||
) -> List[Dict]:
|
||||
"""Apply search filtering"""
|
||||
normalized_options = self.search_strategy.normalize_options(search_options)
|
||||
return self.search_strategy.apply(data, search, normalized_options, fuzzy_search)
|
||||
|
||||
return self.search_strategy.apply(
|
||||
data, search, normalized_options, fuzzy_search
|
||||
)
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||
return data
|
||||
|
||||
async def _apply_credit_required_filter(self, data: List[Dict], credit_required: bool) -> List[Dict]:
|
||||
async def _apply_credit_required_filter(
|
||||
self, data: List[Dict], credit_required: bool
|
||||
) -> List[Dict]:
|
||||
"""Apply credit required filtering based on license_flags.
|
||||
|
||||
|
||||
Args:
|
||||
data: List of model data items
|
||||
credit_required:
|
||||
credit_required:
|
||||
- True: Return items where credit is required (allowNoCredit=False)
|
||||
- False: Return items where credit is not required (allowNoCredit=True)
|
||||
"""
|
||||
filtered_data = []
|
||||
for item in data:
|
||||
license_flags = item.get("license_flags", 127) # Default to all permissions enabled
|
||||
|
||||
license_flags = item.get(
|
||||
"license_flags", 127
|
||||
) # Default to all permissions enabled
|
||||
|
||||
# Bit 0 represents allowNoCredit (1 = no credit required, 0 = credit required)
|
||||
allow_no_credit = bool(license_flags & (1 << 0))
|
||||
|
||||
|
||||
# If credit_required is True, we want items where allowNoCredit is False (credit required)
|
||||
# If credit_required is False, we want items where allowNoCredit is True (no credit required)
|
||||
if credit_required:
|
||||
@@ -275,26 +297,30 @@ class BaseModelService(ABC):
|
||||
else:
|
||||
if allow_no_credit: # Credit is not required
|
||||
filtered_data.append(item)
|
||||
|
||||
|
||||
return filtered_data
|
||||
|
||||
async def _apply_allow_selling_filter(self, data: List[Dict], allow_selling: bool) -> List[Dict]:
|
||||
async def _apply_allow_selling_filter(
|
||||
self, data: List[Dict], allow_selling: bool
|
||||
) -> List[Dict]:
|
||||
"""Apply allow selling generated content filtering based on license_flags.
|
||||
|
||||
|
||||
Args:
|
||||
data: List of model data items
|
||||
allow_selling:
|
||||
allow_selling:
|
||||
- True: Return items where selling generated content is allowed (allowCommercialUse contains Image)
|
||||
- False: Return items where selling generated content is not allowed (allowCommercialUse does not contain Image)
|
||||
"""
|
||||
filtered_data = []
|
||||
for item in data:
|
||||
license_flags = item.get("license_flags", 127) # Default to all permissions enabled
|
||||
|
||||
license_flags = item.get(
|
||||
"license_flags", 127
|
||||
) # Default to all permissions enabled
|
||||
|
||||
# Bits 1-4 represent commercial use permissions
|
||||
# Bit 1 specifically represents Image permission (allowCommercialUse contains Image)
|
||||
has_image_permission = bool(license_flags & (1 << 1))
|
||||
|
||||
|
||||
# If allow_selling is True, we want items where Image permission is granted
|
||||
# If allow_selling is False, we want items where Image permission is not granted
|
||||
if allow_selling:
|
||||
@@ -303,7 +329,7 @@ class BaseModelService(ABC):
|
||||
else:
|
||||
if not has_image_permission: # Selling generated content is not allowed
|
||||
filtered_data.append(item)
|
||||
|
||||
|
||||
return filtered_data
|
||||
|
||||
async def _annotate_update_flags(
|
||||
@@ -321,7 +347,7 @@ class BaseModelService(ABC):
|
||||
|
||||
if self.update_service is None:
|
||||
for item in annotated:
|
||||
item['update_available'] = False
|
||||
item["update_available"] = False
|
||||
return annotated
|
||||
|
||||
id_to_items: Dict[int, List[Dict]] = {}
|
||||
@@ -329,7 +355,7 @@ class BaseModelService(ABC):
|
||||
for item in annotated:
|
||||
model_id = self._extract_model_id(item)
|
||||
if model_id is None:
|
||||
item['update_available'] = False
|
||||
item["update_available"] = False
|
||||
continue
|
||||
if model_id not in id_to_items:
|
||||
id_to_items[model_id] = []
|
||||
@@ -405,13 +431,19 @@ class BaseModelService(ABC):
|
||||
default_flag = bool(resolved.get(model_id, False)) if resolved else False
|
||||
record = records.get(model_id) if records else None
|
||||
base_highest_versions = (
|
||||
self._build_highest_local_versions_by_base(record) if same_base_mode and record else {}
|
||||
self._build_highest_local_versions_by_base(record)
|
||||
if same_base_mode and record
|
||||
else {}
|
||||
)
|
||||
for item in items_for_id:
|
||||
if same_base_mode and record is not None:
|
||||
base_model = self._extract_base_model(item)
|
||||
normalized_base = self._normalize_base_model_name(base_model)
|
||||
threshold_version = base_highest_versions.get(normalized_base) if normalized_base else None
|
||||
threshold_version = (
|
||||
base_highest_versions.get(normalized_base)
|
||||
if normalized_base
|
||||
else None
|
||||
)
|
||||
if threshold_version is None:
|
||||
threshold_version = self._extract_version_id(item)
|
||||
flag = record.has_update_for_base(
|
||||
@@ -420,17 +452,17 @@ class BaseModelService(ABC):
|
||||
)
|
||||
else:
|
||||
flag = default_flag
|
||||
item['update_available'] = flag
|
||||
item["update_available"] = flag
|
||||
|
||||
return annotated
|
||||
|
||||
@staticmethod
|
||||
def _extract_model_id(item: Dict) -> Optional[int]:
|
||||
civitai = item.get('civitai') if isinstance(item, dict) else None
|
||||
civitai = item.get("civitai") if isinstance(item, dict) else None
|
||||
if not isinstance(civitai, dict):
|
||||
return None
|
||||
try:
|
||||
value = civitai.get('modelId')
|
||||
value = civitai.get("modelId")
|
||||
if value is None:
|
||||
return None
|
||||
return int(value)
|
||||
@@ -439,10 +471,10 @@ class BaseModelService(ABC):
|
||||
|
||||
@staticmethod
|
||||
def _extract_version_id(item: Dict) -> Optional[int]:
|
||||
civitai = item.get('civitai') if isinstance(item, dict) else None
|
||||
civitai = item.get("civitai") if isinstance(item, dict) else None
|
||||
if not isinstance(civitai, dict):
|
||||
return None
|
||||
value = civitai.get('id')
|
||||
value = civitai.get("id")
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
@@ -452,7 +484,7 @@ class BaseModelService(ABC):
|
||||
|
||||
@staticmethod
|
||||
def _extract_base_model(item: Dict) -> Optional[str]:
|
||||
value = item.get('base_model')
|
||||
value = item.get("base_model")
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
@@ -489,7 +521,9 @@ class BaseModelService(ABC):
|
||||
for version in getattr(record, "versions", []):
|
||||
if not getattr(version, "is_in_library", False):
|
||||
continue
|
||||
normalized_base = self._normalize_base_model_name(getattr(version, "base_model", None))
|
||||
normalized_base = self._normalize_base_model_name(
|
||||
getattr(version, "base_model", None)
|
||||
)
|
||||
if normalized_base is None:
|
||||
continue
|
||||
version_id = getattr(version, "version_id", None)
|
||||
@@ -506,25 +540,25 @@ class BaseModelService(ABC):
|
||||
total_items = len(data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
|
||||
return {
|
||||
'items': data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
"items": data[start_idx:end_idx],
|
||||
"total": total_items,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total_pages": (total_items + page_size - 1) // page_size,
|
||||
}
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def format_response(self, model_data: Dict) -> Dict:
|
||||
"""Format model data for API response - must be implemented by subclasses"""
|
||||
pass
|
||||
|
||||
|
||||
# Common service methods that delegate to scanner
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get top tags sorted by frequency"""
|
||||
return await self.scanner.get_top_tags(limit)
|
||||
|
||||
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get base models sorted by frequency"""
|
||||
return await self.scanner.get_base_models(limit)
|
||||
@@ -535,62 +569,85 @@ class BaseModelService(ABC):
|
||||
|
||||
type_counts: Dict[str, int] = {}
|
||||
for entry in cache.raw_data:
|
||||
normalized_type = normalize_civitai_model_type(resolve_civitai_model_type(entry))
|
||||
normalized_type = normalize_civitai_model_type(
|
||||
resolve_civitai_model_type(entry)
|
||||
)
|
||||
if not normalized_type or normalized_type not in VALID_LORA_TYPES:
|
||||
continue
|
||||
type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1
|
||||
|
||||
sorted_types = sorted(
|
||||
[{"type": model_type, "count": count} for model_type, count in type_counts.items()],
|
||||
[
|
||||
{"type": model_type, "count": count}
|
||||
for model_type, count in type_counts.items()
|
||||
],
|
||||
key=lambda value: value["count"],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
return sorted_types[:limit]
|
||||
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
"""Check if a model with given hash exists"""
|
||||
return self.scanner.has_hash(sha256)
|
||||
|
||||
|
||||
def get_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a model by its hash"""
|
||||
return self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
|
||||
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a model by its file path"""
|
||||
return self.scanner.get_hash_by_path(file_path)
|
||||
|
||||
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
|
||||
|
||||
async def scan_models(
|
||||
self, force_refresh: bool = False, rebuild_cache: bool = False
|
||||
):
|
||||
"""Trigger model scanning"""
|
||||
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache)
|
||||
|
||||
return await self.scanner.get_cached_data(
|
||||
force_refresh=force_refresh, rebuild_cache=rebuild_cache
|
||||
)
|
||||
|
||||
async def get_model_info_by_name(self, name: str):
|
||||
"""Get model information by name"""
|
||||
return await self.scanner.get_model_info_by_name(name)
|
||||
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get model root directories"""
|
||||
return self.scanner.get_model_roots()
|
||||
|
||||
|
||||
def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict:
|
||||
"""Filter relevant fields from CivitAI data"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
fields = ["id", "modelId", "name", "trainedWords"] if minimal else [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images", "customImages", "creator"
|
||||
]
|
||||
fields = (
|
||||
["id", "modelId", "name", "trainedWords"]
|
||||
if minimal
|
||||
else [
|
||||
"id",
|
||||
"modelId",
|
||||
"name",
|
||||
"createdAt",
|
||||
"updatedAt",
|
||||
"publishedAt",
|
||||
"trainedWords",
|
||||
"baseModel",
|
||||
"description",
|
||||
"model",
|
||||
"images",
|
||||
"customImages",
|
||||
"creator",
|
||||
]
|
||||
)
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
|
||||
|
||||
async def get_folder_tree(self, model_root: str) -> Dict:
|
||||
"""Get hierarchical folder tree for a specific model root"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
|
||||
# Build tree structure from folders
|
||||
tree = {}
|
||||
|
||||
|
||||
for folder in cache.folders:
|
||||
# Check if this folder belongs to the specified model root
|
||||
folder_belongs_to_root = False
|
||||
@@ -598,95 +655,96 @@ class BaseModelService(ABC):
|
||||
if root == model_root:
|
||||
folder_belongs_to_root = True
|
||||
break
|
||||
|
||||
|
||||
if not folder_belongs_to_root:
|
||||
continue
|
||||
|
||||
|
||||
# Split folder path into components
|
||||
parts = folder.split('/') if folder else []
|
||||
parts = folder.split("/") if folder else []
|
||||
current_level = tree
|
||||
|
||||
|
||||
for part in parts:
|
||||
if part not in current_level:
|
||||
current_level[part] = {}
|
||||
current_level = current_level[part]
|
||||
|
||||
|
||||
return tree
|
||||
|
||||
|
||||
async def get_unified_folder_tree(self) -> Dict:
|
||||
"""Get unified folder tree across all model roots"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
|
||||
# Build unified tree structure by analyzing all relative paths
|
||||
unified_tree = {}
|
||||
|
||||
|
||||
# Get all model roots for path normalization
|
||||
model_roots = self.scanner.get_model_roots()
|
||||
|
||||
|
||||
for folder in cache.folders:
|
||||
if not folder: # Skip empty folders
|
||||
continue
|
||||
|
||||
|
||||
# Find which root this folder belongs to by checking the actual file paths
|
||||
# This is a simplified approach - we'll use the folder as-is since it should already be relative
|
||||
relative_path = folder
|
||||
|
||||
|
||||
# Split folder path into components
|
||||
parts = relative_path.split('/')
|
||||
parts = relative_path.split("/")
|
||||
current_level = unified_tree
|
||||
|
||||
|
||||
for part in parts:
|
||||
if part not in current_level:
|
||||
current_level[part] = {}
|
||||
current_level = current_level[part]
|
||||
|
||||
|
||||
return unified_tree
|
||||
|
||||
async def get_model_notes(self, model_name: str) -> Optional[str]:
|
||||
"""Get notes for a specific model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
return model.get('notes', '')
|
||||
|
||||
if model["file_name"] == model_name:
|
||||
return model.get("notes", "")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_model_preview_url(self, model_name: str) -> Optional[str]:
|
||||
"""Get the static preview URL for a model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
preview_url = model.get('preview_url')
|
||||
if model["file_name"] == model_name:
|
||||
preview_url = model.get("preview_url")
|
||||
if preview_url:
|
||||
from ..config import config
|
||||
|
||||
return config.get_preview_static_url(preview_url)
|
||||
|
||||
return '/loras_static/images/no-preview.png'
|
||||
|
||||
|
||||
return "/loras_static/images/no-preview.png"
|
||||
|
||||
async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]:
|
||||
"""Get the Civitai URL for a model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
civitai_data = model.get('civitai', {})
|
||||
model_id = civitai_data.get('modelId')
|
||||
version_id = civitai_data.get('id')
|
||||
|
||||
if model["file_name"] == model_name:
|
||||
civitai_data = model.get("civitai", {})
|
||||
model_id = civitai_data.get("modelId")
|
||||
version_id = civitai_data.get("id")
|
||||
|
||||
if model_id:
|
||||
civitai_url = f"https://civitai.com/models/{model_id}"
|
||||
if version_id:
|
||||
civitai_url += f"?modelVersionId={version_id}"
|
||||
|
||||
|
||||
return {
|
||||
'civitai_url': civitai_url,
|
||||
'model_id': str(model_id),
|
||||
'version_id': str(version_id) if version_id else None
|
||||
"civitai_url": civitai_url,
|
||||
"model_id": str(model_id),
|
||||
"version_id": str(version_id) if version_id else None,
|
||||
}
|
||||
|
||||
return {'civitai_url': None, 'model_id': None, 'version_id': None}
|
||||
|
||||
return {"civitai_url": None, "model_id": None, "version_id": None}
|
||||
|
||||
async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
|
||||
"""Load full metadata for a single model.
|
||||
@@ -694,18 +752,21 @@ class BaseModelService(ABC):
|
||||
Listing/search endpoints return lightweight cache entries; this method performs
|
||||
a lazy read of the on-disk metadata snapshot when callers need full detail.
|
||||
"""
|
||||
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.metadata_class)
|
||||
metadata, should_skip = await MetadataManager.load_metadata(
|
||||
file_path, self.metadata_class
|
||||
)
|
||||
if should_skip or metadata is None:
|
||||
return None
|
||||
return self.filter_civitai_data(metadata.to_dict().get("civitai", {}))
|
||||
|
||||
|
||||
async def get_model_description(self, file_path: str) -> Optional[str]:
|
||||
"""Return the stored modelDescription field for a model."""
|
||||
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.metadata_class)
|
||||
metadata, should_skip = await MetadataManager.load_metadata(
|
||||
file_path, self.metadata_class
|
||||
)
|
||||
if should_skip or metadata is None:
|
||||
return None
|
||||
return metadata.modelDescription or ''
|
||||
return metadata.modelDescription or ""
|
||||
|
||||
@staticmethod
|
||||
def _parse_search_tokens(search_term: str) -> tuple[List[str], List[str]]:
|
||||
@@ -743,53 +804,64 @@ class BaseModelService(ABC):
|
||||
def _relative_path_sort_key(relative_path: str, include_terms: List[str]) -> tuple:
|
||||
"""Sort paths by how well they satisfy the include tokens."""
|
||||
path_lower = relative_path.lower()
|
||||
prefix_hits = sum(1 for term in include_terms if term and path_lower.startswith(term))
|
||||
match_positions = [path_lower.find(term) for term in include_terms if term and term in path_lower]
|
||||
prefix_hits = sum(
|
||||
1 for term in include_terms if term and path_lower.startswith(term)
|
||||
)
|
||||
match_positions = [
|
||||
path_lower.find(term)
|
||||
for term in include_terms
|
||||
if term and term in path_lower
|
||||
]
|
||||
first_match_index = min(match_positions) if match_positions else 0
|
||||
|
||||
return (-prefix_hits, first_match_index, len(relative_path), path_lower)
|
||||
|
||||
|
||||
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
|
||||
async def search_relative_paths(
|
||||
self, search_term: str, limit: int = 15
|
||||
) -> List[str]:
|
||||
"""Search model relative file paths for autocomplete functionality"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
include_terms, exclude_terms = self._parse_search_tokens(search_term)
|
||||
|
||||
|
||||
matching_paths = []
|
||||
|
||||
|
||||
# Get model roots for path calculation
|
||||
model_roots = self.scanner.get_model_roots()
|
||||
|
||||
|
||||
for model in cache.raw_data:
|
||||
file_path = model.get('file_path', '')
|
||||
file_path = model.get("file_path", "")
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
|
||||
# Calculate relative path from model root
|
||||
relative_path = None
|
||||
for root in model_roots:
|
||||
# Normalize paths for comparison
|
||||
normalized_root = os.path.normpath(root)
|
||||
normalized_file = os.path.normpath(file_path)
|
||||
|
||||
|
||||
if normalized_file.startswith(normalized_root):
|
||||
# Remove root and leading separator to get relative path
|
||||
relative_path = normalized_file[len(normalized_root):].lstrip(os.sep)
|
||||
relative_path = normalized_file[len(normalized_root) :].lstrip(
|
||||
os.sep
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
if not relative_path:
|
||||
continue
|
||||
|
||||
relative_lower = relative_path.lower()
|
||||
if self._relative_path_matches_tokens(relative_lower, include_terms, exclude_terms):
|
||||
if self._relative_path_matches_tokens(
|
||||
relative_lower, include_terms, exclude_terms
|
||||
):
|
||||
matching_paths.append(relative_path)
|
||||
|
||||
|
||||
if len(matching_paths) >= limit * 2: # Get more for better sorting
|
||||
break
|
||||
|
||||
|
||||
# Sort by relevance (prefix and earliest hits first, then by length and alphabetically)
|
||||
matching_paths.sort(
|
||||
key=lambda relative: self._relative_path_sort_key(relative, include_terms)
|
||||
)
|
||||
|
||||
|
||||
return matching_paths[:limit]
|
||||
|
||||
@@ -1,7 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Protocol, Callable
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Protocol,
|
||||
Callable,
|
||||
)
|
||||
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.utils import fuzzy_match as default_fuzzy_match
|
||||
@@ -51,8 +62,7 @@ def resolve_civitai_model_type(entry: Mapping[str, Any]) -> str:
|
||||
class SettingsProvider(Protocol):
|
||||
"""Protocol describing the SettingsManager contract used by query helpers."""
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
...
|
||||
def get(self, key: str, default: Any = None) -> Any: ...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -68,6 +78,7 @@ class FilterCriteria:
|
||||
"""Container for model list filtering options."""
|
||||
|
||||
folder: Optional[str] = None
|
||||
folder_exclude: Optional[Sequence[str]] = None
|
||||
base_models: Optional[Sequence[str]] = None
|
||||
tags: Optional[Dict[str, str]] = None
|
||||
favorites_only: bool = False
|
||||
@@ -113,11 +124,15 @@ class ModelCacheRepository:
|
||||
class ModelFilterSet:
|
||||
"""Applies common filtering rules to the model collection."""
|
||||
|
||||
def __init__(self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None) -> None:
|
||||
def __init__(
|
||||
self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None
|
||||
) -> None:
|
||||
self._settings = settings
|
||||
self._nsfw_levels = nsfw_levels or NSFW_LEVELS
|
||||
|
||||
def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]:
|
||||
def apply(
|
||||
self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Return items that satisfy the provided criteria."""
|
||||
overall_start = time.perf_counter()
|
||||
items = list(data)
|
||||
@@ -127,8 +142,10 @@ class ModelFilterSet:
|
||||
t0 = time.perf_counter()
|
||||
threshold = self._nsfw_levels.get("R", 0)
|
||||
items = [
|
||||
item for item in items
|
||||
if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold
|
||||
item
|
||||
for item in items
|
||||
if not item.get("preview_nsfw_level")
|
||||
or item.get("preview_nsfw_level") < threshold
|
||||
]
|
||||
sfw_duration = time.perf_counter() - t0
|
||||
else:
|
||||
@@ -142,20 +159,44 @@ class ModelFilterSet:
|
||||
|
||||
folder_duration = 0
|
||||
folder = criteria.folder
|
||||
folder_exclude = criteria.folder_exclude or []
|
||||
options = criteria.search_options or {}
|
||||
recursive = bool(options.get("recursive", True))
|
||||
|
||||
# Apply folder exclude filters first
|
||||
if folder_exclude:
|
||||
t0 = time.perf_counter()
|
||||
for exclude_folder in folder_exclude:
|
||||
if exclude_folder:
|
||||
# Check exact match OR prefix match (for subfolders)
|
||||
# Normalize exclude_folder for prefix matching
|
||||
if not exclude_folder.endswith("/"):
|
||||
exclude_prefix = f"{exclude_folder}/"
|
||||
else:
|
||||
exclude_prefix = exclude_folder
|
||||
items = [
|
||||
item
|
||||
for item in items
|
||||
if item.get("folder") != exclude_folder
|
||||
and not item.get("folder", "").startswith(exclude_prefix)
|
||||
]
|
||||
folder_duration = time.perf_counter() - t0
|
||||
|
||||
# Apply folder include filters
|
||||
if folder is not None:
|
||||
t0 = time.perf_counter()
|
||||
if recursive:
|
||||
if folder:
|
||||
folder_with_sep = f"{folder}/"
|
||||
items = [
|
||||
item for item in items
|
||||
if item.get("folder") == folder or item.get("folder", "").startswith(folder_with_sep)
|
||||
item
|
||||
for item in items
|
||||
if item.get("folder") == folder
|
||||
or item.get("folder", "").startswith(folder_with_sep)
|
||||
]
|
||||
else:
|
||||
items = [item for item in items if item.get("folder") == folder]
|
||||
folder_duration = time.perf_counter() - t0
|
||||
folder_duration = time.perf_counter() - t0 + folder_duration
|
||||
|
||||
base_models_duration = 0
|
||||
base_models = criteria.base_models or []
|
||||
@@ -183,25 +224,23 @@ class ModelFilterSet:
|
||||
include_tags = {tag for tag in tag_filters if tag}
|
||||
|
||||
if include_tags:
|
||||
|
||||
def matches_include(item_tags):
|
||||
if not item_tags and "__no_tags__" in include_tags:
|
||||
return True
|
||||
return any(tag in include_tags for tag in (item_tags or []))
|
||||
|
||||
items = [
|
||||
item for item in items
|
||||
if matches_include(item.get("tags"))
|
||||
]
|
||||
items = [item for item in items if matches_include(item.get("tags"))]
|
||||
|
||||
if exclude_tags:
|
||||
|
||||
def matches_exclude(item_tags):
|
||||
if not item_tags and "__no_tags__" in exclude_tags:
|
||||
return True
|
||||
return any(tag in exclude_tags for tag in (item_tags or []))
|
||||
|
||||
items = [
|
||||
item for item in items
|
||||
if not matches_exclude(item.get("tags"))
|
||||
item for item in items if not matches_exclude(item.get("tags"))
|
||||
]
|
||||
tags_duration = time.perf_counter() - t0
|
||||
|
||||
@@ -210,26 +249,35 @@ class ModelFilterSet:
|
||||
if model_types:
|
||||
t0 = time.perf_counter()
|
||||
normalized_model_types = {
|
||||
model_type for model_type in (
|
||||
model_type
|
||||
for model_type in (
|
||||
normalize_civitai_model_type(value) for value in model_types
|
||||
)
|
||||
if model_type
|
||||
}
|
||||
if normalized_model_types:
|
||||
items = [
|
||||
item for item in items
|
||||
if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types
|
||||
item
|
||||
for item in items
|
||||
if normalize_civitai_model_type(resolve_civitai_model_type(item))
|
||||
in normalized_model_types
|
||||
]
|
||||
model_types_duration = time.perf_counter() - t0
|
||||
|
||||
duration = time.perf_counter() - overall_start
|
||||
if duration > 0.1: # Only log if it's potentially slow
|
||||
if duration > 0.1: # Only log if it's potentially slow
|
||||
logger.debug(
|
||||
"ModelFilterSet.apply took %.3fs (sfw: %.3fs, fav: %.3fs, folder: %.3fs, base: %.3fs, tags: %.3fs, types: %.3fs). "
|
||||
"Count: %d -> %d",
|
||||
duration, sfw_duration, favorites_duration, folder_duration,
|
||||
base_models_duration, tags_duration, model_types_duration,
|
||||
initial_count, len(items)
|
||||
duration,
|
||||
sfw_duration,
|
||||
favorites_duration,
|
||||
folder_duration,
|
||||
base_models_duration,
|
||||
tags_duration,
|
||||
model_types_duration,
|
||||
initial_count,
|
||||
len(items),
|
||||
)
|
||||
return items
|
||||
|
||||
@@ -245,7 +293,9 @@ class SearchStrategy:
|
||||
"creator": False,
|
||||
}
|
||||
|
||||
def __init__(self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None) -> None:
|
||||
def __init__(
|
||||
self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None
|
||||
) -> None:
|
||||
self._fuzzy_match = fuzzy_matcher or default_fuzzy_match
|
||||
|
||||
def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
@@ -284,7 +334,9 @@ class SearchStrategy:
|
||||
|
||||
if options.get("tags", False):
|
||||
tags = item.get("tags", []) or []
|
||||
if any(self._matches(tag, search_term, search_lower, fuzzy) for tag in tags):
|
||||
if any(
|
||||
self._matches(tag, search_term, search_lower, fuzzy) for tag in tags
|
||||
):
|
||||
results.append(item)
|
||||
continue
|
||||
|
||||
@@ -295,13 +347,17 @@ class SearchStrategy:
|
||||
creator = civitai.get("creator")
|
||||
if isinstance(creator, dict):
|
||||
creator_username = creator.get("username", "")
|
||||
if creator_username and self._matches(creator_username, search_term, search_lower, fuzzy):
|
||||
if creator_username and self._matches(
|
||||
creator_username, search_term, search_lower, fuzzy
|
||||
):
|
||||
results.append(item)
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
def _matches(self, candidate: str, search_term: str, search_lower: str, fuzzy: bool) -> bool:
|
||||
def _matches(
|
||||
self, candidate: str, search_term: str, search_lower: str, fuzzy: bool
|
||||
) -> bool:
|
||||
if not isinstance(candidate, str):
|
||||
candidate = "" if candidate is None else str(candidate)
|
||||
|
||||
|
||||
@@ -3,13 +3,16 @@ from py.services.model_query import ModelFilterSet, FilterCriteria
|
||||
from py.services.recipe_scanner import RecipeScanner
|
||||
from types import SimpleNamespace
|
||||
|
||||
|
||||
# Mock settings
|
||||
class MockSettings:
|
||||
def get(self, key, default=None):
|
||||
return default
|
||||
|
||||
|
||||
# --- Model Filtering Tests ---
|
||||
|
||||
|
||||
def test_model_filter_set_root_recursive_true():
|
||||
filter_set = ModelFilterSet(MockSettings())
|
||||
items = [
|
||||
@@ -17,13 +20,14 @@ def test_model_filter_set_root_recursive_true():
|
||||
{"model_name": "sub_item", "folder": "sub"},
|
||||
]
|
||||
criteria = FilterCriteria(folder="", search_options={"recursive": True})
|
||||
|
||||
|
||||
result = filter_set.apply(items, criteria)
|
||||
|
||||
|
||||
assert len(result) == 2
|
||||
assert any(i["model_name"] == "root_item" for i in result)
|
||||
assert any(i["model_name"] == "sub_item" for i in result)
|
||||
|
||||
|
||||
def test_model_filter_set_root_recursive_false():
|
||||
filter_set = ModelFilterSet(MockSettings())
|
||||
items = [
|
||||
@@ -31,62 +35,185 @@ def test_model_filter_set_root_recursive_false():
|
||||
{"model_name": "sub_item", "folder": "sub"},
|
||||
]
|
||||
criteria = FilterCriteria(folder="", search_options={"recursive": False})
|
||||
|
||||
|
||||
result = filter_set.apply(items, criteria)
|
||||
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["model_name"] == "root_item"
|
||||
|
||||
|
||||
def test_model_filter_set_folder_exclude_single():
|
||||
filter_set = ModelFilterSet(MockSettings())
|
||||
items = [
|
||||
{"model_name": "item1", "folder": "characters/"},
|
||||
{"model_name": "item2", "folder": "styles/"},
|
||||
{"model_name": "item3", "folder": "characters/anime/"},
|
||||
{"model_name": "item4", "folder": ""},
|
||||
]
|
||||
criteria = FilterCriteria(
|
||||
folder_exclude=["characters/"], search_options={"recursive": True}
|
||||
)
|
||||
|
||||
result = filter_set.apply(items, criteria)
|
||||
|
||||
assert len(result) == 2
|
||||
model_names = {i["model_name"] for i in result}
|
||||
assert model_names == {"item2", "item4"}
|
||||
|
||||
|
||||
def test_model_filter_set_folder_exclude_multiple():
|
||||
filter_set = ModelFilterSet(MockSettings())
|
||||
items = [
|
||||
{"model_name": "item1", "folder": "characters/"},
|
||||
{"model_name": "item2", "folder": "styles/"},
|
||||
{"model_name": "item3", "folder": "concepts/"},
|
||||
{"model_name": "item4", "folder": "characters/anime/"},
|
||||
{"model_name": "item5", "folder": ""},
|
||||
]
|
||||
criteria = FilterCriteria(
|
||||
folder_exclude=["characters/", "styles/"], search_options={"recursive": True}
|
||||
)
|
||||
|
||||
result = filter_set.apply(items, criteria)
|
||||
|
||||
assert len(result) == 2
|
||||
model_names = {i["model_name"] for i in result}
|
||||
assert model_names == {"item3", "item5"}
|
||||
|
||||
|
||||
def test_model_filter_set_folder_exclude_with_include():
|
||||
filter_set = ModelFilterSet(MockSettings())
|
||||
items = [
|
||||
{"model_name": "item1", "folder": "characters/"},
|
||||
{"model_name": "item2", "folder": "styles/"},
|
||||
{"model_name": "item3", "folder": "characters/anime/"},
|
||||
{"model_name": "item4", "folder": "styles/painting/"},
|
||||
{"model_name": "item5", "folder": "concepts/"},
|
||||
]
|
||||
criteria = FilterCriteria(
|
||||
folder="characters/",
|
||||
folder_exclude=["characters/anime/"],
|
||||
search_options={"recursive": True},
|
||||
)
|
||||
|
||||
result = filter_set.apply(items, criteria)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["model_name"] == "item1"
|
||||
|
||||
|
||||
# --- Recipe Filtering Tests ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recipe_scanner_root_recursive_true():
|
||||
# Mock LoraScanner
|
||||
class StubLoraScanner:
|
||||
async def get_cached_data(self):
|
||||
return SimpleNamespace(raw_data=[])
|
||||
|
||||
|
||||
scanner = RecipeScanner(lora_scanner=StubLoraScanner())
|
||||
# Manually populate cache for testing get_paginated_data logic
|
||||
scanner._cache = SimpleNamespace(
|
||||
raw_data=[
|
||||
{"id": "r1", "title": "root_recipe", "folder": "", "modified": 1.0, "created_date": 1.0, "loras": []},
|
||||
{"id": "r2", "title": "sub_recipe", "folder": "sub", "modified": 2.0, "created_date": 2.0, "loras": []},
|
||||
{
|
||||
"id": "r1",
|
||||
"title": "root_recipe",
|
||||
"folder": "",
|
||||
"modified": 1.0,
|
||||
"created_date": 1.0,
|
||||
"loras": [],
|
||||
},
|
||||
{
|
||||
"id": "r2",
|
||||
"title": "sub_recipe",
|
||||
"folder": "sub",
|
||||
"modified": 2.0,
|
||||
"created_date": 2.0,
|
||||
"loras": [],
|
||||
},
|
||||
],
|
||||
sorted_by_date=[
|
||||
{"id": "r2", "title": "sub_recipe", "folder": "sub", "modified": 2.0, "created_date": 2.0, "loras": []},
|
||||
{"id": "r1", "title": "root_recipe", "folder": "", "modified": 1.0, "created_date": 1.0, "loras": []},
|
||||
{
|
||||
"id": "r2",
|
||||
"title": "sub_recipe",
|
||||
"folder": "sub",
|
||||
"modified": 2.0,
|
||||
"created_date": 2.0,
|
||||
"loras": [],
|
||||
},
|
||||
{
|
||||
"id": "r1",
|
||||
"title": "root_recipe",
|
||||
"folder": "",
|
||||
"modified": 1.0,
|
||||
"created_date": 1.0,
|
||||
"loras": [],
|
||||
},
|
||||
],
|
||||
sorted_by_name=[],
|
||||
version_index={}
|
||||
version_index={},
|
||||
)
|
||||
|
||||
result = await scanner.get_paginated_data(page=1, page_size=10, folder="", recursive=True)
|
||||
|
||||
|
||||
result = await scanner.get_paginated_data(
|
||||
page=1, page_size=10, folder="", recursive=True
|
||||
)
|
||||
|
||||
assert len(result["items"]) == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recipe_scanner_root_recursive_false():
|
||||
# Mock LoraScanner
|
||||
class StubLoraScanner:
|
||||
async def get_cached_data(self):
|
||||
return SimpleNamespace(raw_data=[])
|
||||
|
||||
|
||||
scanner = RecipeScanner(lora_scanner=StubLoraScanner())
|
||||
scanner._cache = SimpleNamespace(
|
||||
raw_data=[
|
||||
{"id": "r1", "title": "root_recipe", "folder": "", "modified": 1.0, "created_date": 1.0, "loras": []},
|
||||
{"id": "r2", "title": "sub_recipe", "folder": "sub", "modified": 2.0, "created_date": 2.0, "loras": []},
|
||||
{
|
||||
"id": "r1",
|
||||
"title": "root_recipe",
|
||||
"folder": "",
|
||||
"modified": 1.0,
|
||||
"created_date": 1.0,
|
||||
"loras": [],
|
||||
},
|
||||
{
|
||||
"id": "r2",
|
||||
"title": "sub_recipe",
|
||||
"folder": "sub",
|
||||
"modified": 2.0,
|
||||
"created_date": 2.0,
|
||||
"loras": [],
|
||||
},
|
||||
],
|
||||
sorted_by_date=[
|
||||
{"id": "r2", "title": "sub_recipe", "folder": "sub", "modified": 2.0, "created_date": 2.0, "loras": []},
|
||||
{"id": "r1", "title": "root_recipe", "folder": "", "modified": 1.0, "created_date": 1.0, "loras": []},
|
||||
{
|
||||
"id": "r2",
|
||||
"title": "sub_recipe",
|
||||
"folder": "sub",
|
||||
"modified": 2.0,
|
||||
"created_date": 2.0,
|
||||
"loras": [],
|
||||
},
|
||||
{
|
||||
"id": "r1",
|
||||
"title": "root_recipe",
|
||||
"folder": "",
|
||||
"modified": 1.0,
|
||||
"created_date": 1.0,
|
||||
"loras": [],
|
||||
},
|
||||
],
|
||||
sorted_by_name=[],
|
||||
version_index={}
|
||||
version_index={},
|
||||
)
|
||||
|
||||
result = await scanner.get_paginated_data(page=1, page_size=10, folder="", recursive=False)
|
||||
|
||||
|
||||
result = await scanner.get_paginated_data(
|
||||
page=1, page_size=10, folder="", recursive=False
|
||||
)
|
||||
|
||||
assert len(result["items"]) == 1
|
||||
assert result["items"][0]["id"] == "r1"
|
||||
|
||||
@@ -77,11 +77,12 @@ export function useLoraPoolApi() {
|
||||
params.tagsInclude?.forEach(tag => urlParams.append('tag_include', tag))
|
||||
params.tagsExclude?.forEach(tag => urlParams.append('tag_exclude', tag))
|
||||
|
||||
// For now, use first include folder (backend currently supports single folder)
|
||||
// Folder filters
|
||||
if (params.foldersInclude && params.foldersInclude.length > 0) {
|
||||
urlParams.set('folder', params.foldersInclude[0])
|
||||
urlParams.set('recursive', 'true')
|
||||
}
|
||||
params.foldersExclude?.forEach(folder => urlParams.append('folder_exclude', folder))
|
||||
|
||||
if (params.noCreditRequired !== undefined) {
|
||||
urlParams.set('credit_required', String(!params.noCreditRequired))
|
||||
|
||||
@@ -10870,7 +10870,7 @@ function useLoraPoolApi() {
|
||||
});
|
||||
};
|
||||
const fetchLoras = async (params) => {
|
||||
var _a, _b, _c;
|
||||
var _a, _b, _c, _d;
|
||||
isLoading.value = true;
|
||||
try {
|
||||
const urlParams = new URLSearchParams();
|
||||
@@ -10883,6 +10883,7 @@ function useLoraPoolApi() {
|
||||
urlParams.set("folder", params.foldersInclude[0]);
|
||||
urlParams.set("recursive", "true");
|
||||
}
|
||||
(_d = params.foldersExclude) == null ? void 0 : _d.forEach((folder) => urlParams.append("folder_exclude", folder));
|
||||
if (params.noCreditRequired !== void 0) {
|
||||
urlParams.set("credit_required", String(!params.noCreditRequired));
|
||||
}
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user