mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Add file_path as a tie-breaker for all sort modes in ModelCache, BaseModelService, LoraService, and RecipeCache to ensure deterministic ordering when primary keys are identical. Resolves issue #859.
932 lines
34 KiB
Python
932 lines
34 KiB
Python
from abc import ABC, abstractmethod
|
|
import asyncio
|
|
import re
|
|
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
|
|
import logging
|
|
import os
|
|
import time
|
|
|
|
from ..utils.constants import VALID_LORA_SUB_TYPES, VALID_CHECKPOINT_SUB_TYPES
|
|
from ..utils.models import BaseModelMetadata
|
|
from ..utils.metadata_manager import MetadataManager
|
|
from ..utils.usage_stats import UsageStats
|
|
from .model_query import (
|
|
FilterCriteria,
|
|
ModelCacheRepository,
|
|
ModelFilterSet,
|
|
SearchStrategy,
|
|
SettingsProvider,
|
|
normalize_sub_type,
|
|
resolve_sub_type,
|
|
)
|
|
from .settings_manager import get_settings_manager
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
if TYPE_CHECKING:
|
|
from .model_update_service import ModelUpdateService
|
|
|
|
|
|
class BaseModelService(ABC):
|
|
"""Base service class for all model types"""
|
|
|
|
def __init__(
|
|
self,
|
|
model_type: str,
|
|
scanner,
|
|
metadata_class: Type[BaseModelMetadata],
|
|
*,
|
|
cache_repository: Optional[ModelCacheRepository] = None,
|
|
filter_set: Optional[ModelFilterSet] = None,
|
|
search_strategy: Optional[SearchStrategy] = None,
|
|
settings_provider: Optional[SettingsProvider] = None,
|
|
update_service: Optional["ModelUpdateService"] = None,
|
|
):
|
|
"""Initialize the service.
|
|
|
|
Args:
|
|
model_type: Type of model (lora, checkpoint, etc.).
|
|
scanner: Model scanner instance.
|
|
metadata_class: Metadata class for this model type.
|
|
cache_repository: Custom repository for cache access (primarily for tests).
|
|
filter_set: Filter component controlling folder/tag/favorites logic.
|
|
search_strategy: Search component for fuzzy/text matching.
|
|
settings_provider: Settings object; defaults to the global settings manager.
|
|
update_service: Service used to determine whether models have remote updates available.
|
|
"""
|
|
self.model_type = model_type
|
|
self.scanner = scanner
|
|
self.metadata_class = metadata_class
|
|
self.settings = settings_provider or get_settings_manager()
|
|
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
|
|
self.filter_set = filter_set or ModelFilterSet(self.settings)
|
|
self.search_strategy = search_strategy or SearchStrategy()
|
|
self.update_service = update_service
|
|
|
|
async def get_paginated_data(
|
|
self,
|
|
page: int,
|
|
page_size: int,
|
|
sort_by: str = "name",
|
|
folder: str = None,
|
|
folder_include: list = None,
|
|
folder_exclude: list = None,
|
|
search: str = None,
|
|
fuzzy_search: bool = False,
|
|
base_models: list = None,
|
|
model_types: list = None,
|
|
tags: Optional[Dict[str, str]] = None,
|
|
search_options: dict = None,
|
|
hash_filters: dict = None,
|
|
favorites_only: bool = False,
|
|
update_available_only: bool = False,
|
|
credit_required: Optional[bool] = None,
|
|
allow_selling_generated_content: Optional[bool] = None,
|
|
tag_logic: str = "any",
|
|
**kwargs,
|
|
) -> Dict:
|
|
"""Get paginated and filtered model data"""
|
|
overall_start = time.perf_counter()
|
|
|
|
sort_params = self.cache_repository.parse_sort(sort_by)
|
|
t0 = time.perf_counter()
|
|
if sort_params.key == "usage":
|
|
sorted_data = await self._fetch_with_usage_sort(sort_params)
|
|
else:
|
|
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
|
|
fetch_duration = time.perf_counter() - t0
|
|
initial_count = len(sorted_data)
|
|
|
|
t1 = time.perf_counter()
|
|
if hash_filters:
|
|
filtered_data = await self._apply_hash_filters(sorted_data, hash_filters)
|
|
else:
|
|
filtered_data = await self._apply_common_filters(
|
|
sorted_data,
|
|
folder=folder,
|
|
folder_include=folder_include,
|
|
folder_exclude=folder_exclude,
|
|
base_models=base_models,
|
|
model_types=model_types,
|
|
tags=tags,
|
|
favorites_only=favorites_only,
|
|
search_options=search_options,
|
|
tag_logic=tag_logic,
|
|
)
|
|
|
|
if search:
|
|
filtered_data = await self._apply_search_filters(
|
|
filtered_data,
|
|
search,
|
|
fuzzy_search,
|
|
search_options,
|
|
)
|
|
|
|
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
|
|
|
# Apply license-based filters
|
|
if credit_required is not None:
|
|
filtered_data = await self._apply_credit_required_filter(
|
|
filtered_data, credit_required
|
|
)
|
|
|
|
if allow_selling_generated_content is not None:
|
|
filtered_data = await self._apply_allow_selling_filter(
|
|
filtered_data, allow_selling_generated_content
|
|
)
|
|
filter_duration = time.perf_counter() - t1
|
|
post_filter_count = len(filtered_data)
|
|
|
|
annotated_for_filter: Optional[List[Dict]] = None
|
|
t2 = time.perf_counter()
|
|
if update_available_only:
|
|
annotated_for_filter = await self._annotate_update_flags(filtered_data)
|
|
filtered_data = [
|
|
item for item in annotated_for_filter if item.get("update_available")
|
|
]
|
|
update_filter_duration = time.perf_counter() - t2
|
|
final_count = len(filtered_data)
|
|
|
|
t3 = time.perf_counter()
|
|
paginated = self._paginate(filtered_data, page, page_size)
|
|
pagination_duration = time.perf_counter() - t3
|
|
|
|
t4 = time.perf_counter()
|
|
if update_available_only:
|
|
# Items already include update flags thanks to the pre-filter annotation.
|
|
paginated["items"] = list(paginated["items"])
|
|
else:
|
|
paginated["items"] = await self._annotate_update_flags(
|
|
paginated["items"],
|
|
)
|
|
annotate_duration = time.perf_counter() - t4
|
|
|
|
overall_duration = time.perf_counter() - overall_start
|
|
logger.debug(
|
|
"%s.get_paginated_data took %.3fs (fetch: %.3fs, filter: %.3fs, update_filter: %.3fs, pagination: %.3fs, annotate: %.3fs). "
|
|
"Counts: initial=%d, post_filter=%d, final=%d",
|
|
self.__class__.__name__,
|
|
overall_duration,
|
|
fetch_duration,
|
|
filter_duration,
|
|
update_filter_duration,
|
|
pagination_duration,
|
|
annotate_duration,
|
|
initial_count,
|
|
post_filter_count,
|
|
final_count,
|
|
)
|
|
return paginated
|
|
|
|
async def _fetch_with_usage_sort(self, sort_params):
|
|
"""Fetch data sorted by usage count (desc/asc)."""
|
|
cache = await self.cache_repository.get_cache()
|
|
raw_items = cache.raw_data or []
|
|
|
|
# Map model type to usage stats bucket
|
|
bucket_map = {
|
|
"lora": "loras",
|
|
"checkpoint": "checkpoints",
|
|
# 'embedding': 'embeddings', # TODO: Enable when embedding usage tracking is implemented
|
|
}
|
|
bucket_key = bucket_map.get(self.model_type, "")
|
|
|
|
usage_stats = UsageStats()
|
|
stats = await usage_stats.get_stats()
|
|
usage_bucket = stats.get(bucket_key, {}) if bucket_key else {}
|
|
|
|
annotated = []
|
|
for item in raw_items:
|
|
sha = (item.get("sha256") or "").lower()
|
|
usage_info = (
|
|
usage_bucket.get(sha, {}) if isinstance(usage_bucket, dict) else {}
|
|
)
|
|
usage_count = (
|
|
usage_info.get("total", 0) if isinstance(usage_info, dict) else 0
|
|
)
|
|
annotated.append({**item, "usage_count": usage_count})
|
|
|
|
reverse = sort_params.order == "desc"
|
|
annotated.sort(
|
|
key=lambda x: (
|
|
x.get("usage_count", 0),
|
|
x.get("model_name", "").lower(),
|
|
x.get("file_path", "").lower()
|
|
),
|
|
reverse=reverse,
|
|
)
|
|
return annotated
|
|
|
|
async def _apply_hash_filters(
|
|
self, data: List[Dict], hash_filters: Dict
|
|
) -> List[Dict]:
|
|
"""Apply hash-based filtering"""
|
|
single_hash = hash_filters.get("single_hash")
|
|
multiple_hashes = hash_filters.get("multiple_hashes")
|
|
|
|
if single_hash:
|
|
# Filter by single hash
|
|
single_hash = single_hash.lower()
|
|
return [
|
|
item for item in data if item.get("sha256", "").lower() == single_hash
|
|
]
|
|
elif multiple_hashes:
|
|
# Filter by multiple hashes
|
|
hash_set = set(hash.lower() for hash in multiple_hashes)
|
|
return [item for item in data if item.get("sha256", "").lower() in hash_set]
|
|
|
|
return data
|
|
|
|
async def _apply_common_filters(
|
|
self,
|
|
data: List[Dict],
|
|
folder: str = None,
|
|
folder_include: list = None,
|
|
folder_exclude: list = None,
|
|
base_models: list = None,
|
|
model_types: list = None,
|
|
tags: Optional[Dict[str, str]] = None,
|
|
favorites_only: bool = False,
|
|
search_options: dict = None,
|
|
tag_logic: str = "any",
|
|
) -> List[Dict]:
|
|
"""Apply common filters that work across all model types"""
|
|
normalized_options = self.search_strategy.normalize_options(search_options)
|
|
criteria = FilterCriteria(
|
|
folder=folder,
|
|
folder_include=folder_include,
|
|
folder_exclude=folder_exclude,
|
|
base_models=base_models,
|
|
model_types=model_types,
|
|
tags=tags,
|
|
favorites_only=favorites_only,
|
|
search_options=normalized_options,
|
|
tag_logic=tag_logic,
|
|
)
|
|
return self.filter_set.apply(data, criteria)
|
|
|
|
async def _apply_search_filters(
|
|
self,
|
|
data: List[Dict],
|
|
search: str,
|
|
fuzzy_search: bool,
|
|
search_options: dict,
|
|
) -> List[Dict]:
|
|
"""Apply search filtering"""
|
|
normalized_options = self.search_strategy.normalize_options(search_options)
|
|
return self.search_strategy.apply(
|
|
data, search, normalized_options, fuzzy_search
|
|
)
|
|
|
|
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
|
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
|
return data
|
|
|
|
async def _apply_credit_required_filter(
|
|
self, data: List[Dict], credit_required: bool
|
|
) -> List[Dict]:
|
|
"""Apply credit required filtering based on license_flags.
|
|
|
|
Args:
|
|
data: List of model data items
|
|
credit_required:
|
|
- True: Return items where credit is required (allowNoCredit=False)
|
|
- False: Return items where credit is not required (allowNoCredit=True)
|
|
"""
|
|
filtered_data = []
|
|
for item in data:
|
|
license_flags = item.get(
|
|
"license_flags", 127
|
|
) # Default to all permissions enabled
|
|
|
|
# Bit 0 represents allowNoCredit (1 = no credit required, 0 = credit required)
|
|
allow_no_credit = bool(license_flags & (1 << 0))
|
|
|
|
# If credit_required is True, we want items where allowNoCredit is False (credit required)
|
|
# If credit_required is False, we want items where allowNoCredit is True (no credit required)
|
|
if credit_required:
|
|
if not allow_no_credit: # Credit is required
|
|
filtered_data.append(item)
|
|
else:
|
|
if allow_no_credit: # Credit is not required
|
|
filtered_data.append(item)
|
|
|
|
return filtered_data
|
|
|
|
async def _apply_allow_selling_filter(
|
|
self, data: List[Dict], allow_selling: bool
|
|
) -> List[Dict]:
|
|
"""Apply allow selling generated content filtering based on license_flags.
|
|
|
|
Args:
|
|
data: List of model data items
|
|
allow_selling:
|
|
- True: Return items where selling generated content is allowed (allowCommercialUse contains Image)
|
|
- False: Return items where selling generated content is not allowed (allowCommercialUse does not contain Image)
|
|
"""
|
|
filtered_data = []
|
|
for item in data:
|
|
license_flags = item.get(
|
|
"license_flags", 127
|
|
) # Default to all permissions enabled
|
|
|
|
# Bits 1-4 represent commercial use permissions
|
|
# Bit 1 specifically represents Image permission (allowCommercialUse contains Image)
|
|
has_image_permission = bool(license_flags & (1 << 1))
|
|
|
|
# If allow_selling is True, we want items where Image permission is granted
|
|
# If allow_selling is False, we want items where Image permission is not granted
|
|
if allow_selling:
|
|
if has_image_permission: # Selling generated content is allowed
|
|
filtered_data.append(item)
|
|
else:
|
|
if not has_image_permission: # Selling generated content is not allowed
|
|
filtered_data.append(item)
|
|
|
|
return filtered_data
|
|
|
|
async def _annotate_update_flags(
|
|
self,
|
|
items: List[Dict],
|
|
) -> List[Dict]:
|
|
"""Attach an update_available flag to each response item.
|
|
|
|
Items without a civitai model id default to False.
|
|
"""
|
|
if not items:
|
|
return []
|
|
|
|
annotated = [dict(item) for item in items]
|
|
|
|
if self.update_service is None:
|
|
for item in annotated:
|
|
item["update_available"] = False
|
|
return annotated
|
|
|
|
id_to_items: Dict[int, List[Dict]] = {}
|
|
ordered_ids: List[int] = []
|
|
for item in annotated:
|
|
model_id = self._extract_model_id(item)
|
|
if model_id is None:
|
|
item["update_available"] = False
|
|
continue
|
|
if model_id not in id_to_items:
|
|
id_to_items[model_id] = []
|
|
ordered_ids.append(model_id)
|
|
id_to_items[model_id].append(item)
|
|
|
|
if not ordered_ids:
|
|
return annotated
|
|
|
|
strategy_value = self.settings.get("update_flag_strategy")
|
|
if isinstance(strategy_value, str) and strategy_value.strip():
|
|
strategy = strategy_value.strip().lower()
|
|
else:
|
|
strategy = "same_base"
|
|
same_base_mode = strategy == "same_base"
|
|
|
|
# Check user setting for hiding early access updates
|
|
hide_early_access = False
|
|
try:
|
|
hide_early_access = bool(
|
|
self.settings.get("hide_early_access_updates", False)
|
|
)
|
|
except Exception:
|
|
hide_early_access = False
|
|
|
|
records = None
|
|
resolved: Optional[Dict[int, bool]] = None
|
|
if same_base_mode:
|
|
record_method = getattr(self.update_service, "get_records_bulk", None)
|
|
if callable(record_method):
|
|
try:
|
|
records = await record_method(self.model_type, ordered_ids)
|
|
resolved = {
|
|
model_id: record.has_update(hide_early_access=hide_early_access)
|
|
for model_id, record in records.items()
|
|
}
|
|
except Exception as exc:
|
|
logger.error(
|
|
"Failed to resolve update records in bulk for %s models (%s): %s",
|
|
self.model_type,
|
|
ordered_ids,
|
|
exc,
|
|
exc_info=True,
|
|
)
|
|
records = None
|
|
resolved = None
|
|
|
|
if resolved is None:
|
|
bulk_method = getattr(self.update_service, "has_updates_bulk", None)
|
|
if callable(bulk_method):
|
|
try:
|
|
resolved = await bulk_method(
|
|
self.model_type,
|
|
ordered_ids,
|
|
hide_early_access=hide_early_access,
|
|
)
|
|
except Exception as exc:
|
|
logger.error(
|
|
"Failed to resolve update status in bulk for %s models (%s): %s",
|
|
self.model_type,
|
|
ordered_ids,
|
|
exc,
|
|
exc_info=True,
|
|
)
|
|
resolved = None
|
|
|
|
if resolved is None:
|
|
tasks = [
|
|
self.update_service.has_update(
|
|
self.model_type, model_id, hide_early_access=hide_early_access
|
|
)
|
|
for model_id in ordered_ids
|
|
]
|
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
resolved = {}
|
|
for model_id, result in zip(ordered_ids, results):
|
|
if isinstance(result, Exception):
|
|
logger.error(
|
|
"Failed to resolve update status for model %s (%s): %s",
|
|
model_id,
|
|
self.model_type,
|
|
result,
|
|
)
|
|
continue
|
|
resolved[model_id] = bool(result)
|
|
|
|
for model_id, items_for_id in id_to_items.items():
|
|
default_flag = bool(resolved.get(model_id, False)) if resolved else False
|
|
record = records.get(model_id) if records else None
|
|
base_highest_versions = (
|
|
self._build_highest_local_versions_by_base(record)
|
|
if same_base_mode and record
|
|
else {}
|
|
)
|
|
for item in items_for_id:
|
|
if same_base_mode and record is not None:
|
|
base_model = self._extract_base_model(item)
|
|
normalized_base = self._normalize_base_model_name(base_model)
|
|
threshold_version = (
|
|
base_highest_versions.get(normalized_base)
|
|
if normalized_base
|
|
else None
|
|
)
|
|
if threshold_version is None:
|
|
threshold_version = self._extract_version_id(item)
|
|
flag = record.has_update_for_base(
|
|
threshold_version,
|
|
base_model,
|
|
hide_early_access=hide_early_access,
|
|
)
|
|
else:
|
|
flag = default_flag
|
|
item["update_available"] = flag
|
|
|
|
return annotated
|
|
|
|
@staticmethod
|
|
def _extract_model_id(item: Dict) -> Optional[int]:
|
|
civitai = item.get("civitai") if isinstance(item, dict) else None
|
|
if not isinstance(civitai, dict):
|
|
return None
|
|
try:
|
|
value = civitai.get("modelId")
|
|
if value is None:
|
|
return None
|
|
return int(value)
|
|
except (TypeError, ValueError):
|
|
return None
|
|
|
|
@staticmethod
|
|
def _extract_version_id(item: Dict) -> Optional[int]:
|
|
civitai = item.get("civitai") if isinstance(item, dict) else None
|
|
if not isinstance(civitai, dict):
|
|
return None
|
|
value = civitai.get("id")
|
|
if value is None:
|
|
return None
|
|
try:
|
|
return int(value)
|
|
except (TypeError, ValueError):
|
|
return None
|
|
|
|
@staticmethod
|
|
def _extract_base_model(item: Dict) -> Optional[str]:
|
|
value = item.get("base_model")
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, str):
|
|
candidate = value.strip()
|
|
else:
|
|
try:
|
|
candidate = str(value).strip()
|
|
except Exception:
|
|
return None
|
|
return candidate if candidate else None
|
|
|
|
@staticmethod
|
|
def _normalize_base_model_name(value: Optional[str]) -> Optional[str]:
|
|
"""Return a lowercased, trimmed base model name for comparison."""
|
|
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, str):
|
|
candidate = value.strip()
|
|
else:
|
|
try:
|
|
candidate = str(value).strip()
|
|
except Exception:
|
|
return None
|
|
return candidate.lower() if candidate else None
|
|
|
|
def _build_highest_local_versions_by_base(self, record) -> Dict[str, int]:
|
|
"""Return the highest local version id known for each normalized base model."""
|
|
|
|
if record is None:
|
|
return {}
|
|
|
|
highest_by_base: Dict[str, int] = {}
|
|
for version in getattr(record, "versions", []):
|
|
if not getattr(version, "is_in_library", False):
|
|
continue
|
|
normalized_base = self._normalize_base_model_name(
|
|
getattr(version, "base_model", None)
|
|
)
|
|
if normalized_base is None:
|
|
continue
|
|
version_id = getattr(version, "version_id", None)
|
|
if version_id is None:
|
|
continue
|
|
current_max = highest_by_base.get(normalized_base)
|
|
if current_max is None or version_id > current_max:
|
|
highest_by_base[normalized_base] = version_id
|
|
|
|
return highest_by_base
|
|
|
|
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
|
"""Apply pagination to filtered data"""
|
|
total_items = len(data)
|
|
start_idx = (page - 1) * page_size
|
|
end_idx = min(start_idx + page_size, total_items)
|
|
|
|
return {
|
|
"items": data[start_idx:end_idx],
|
|
"total": total_items,
|
|
"page": page,
|
|
"page_size": page_size,
|
|
"total_pages": (total_items + page_size - 1) // page_size,
|
|
}
|
|
|
|
@abstractmethod
|
|
async def format_response(self, model_data: Dict) -> Dict:
|
|
"""Format model data for API response - must be implemented by subclasses"""
|
|
pass
|
|
|
|
# Common service methods that delegate to scanner
|
|
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
|
|
"""Get top tags sorted by frequency"""
|
|
return await self.scanner.get_top_tags(limit)
|
|
|
|
async def get_base_models(self, limit: int = 20) -> List[Dict]:
|
|
"""Get base models sorted by frequency"""
|
|
return await self.scanner.get_base_models(limit)
|
|
|
|
async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]:
|
|
"""Get counts of sub-types present in the cache."""
|
|
cache = await self.scanner.get_cached_data()
|
|
|
|
type_counts: Dict[str, int] = {}
|
|
for entry in cache.raw_data:
|
|
normalized_type = normalize_sub_type(resolve_sub_type(entry))
|
|
if not normalized_type:
|
|
continue
|
|
|
|
# Filter by valid sub-types based on scanner type
|
|
if (
|
|
self.model_type == "lora"
|
|
and normalized_type not in VALID_LORA_SUB_TYPES
|
|
):
|
|
continue
|
|
if (
|
|
self.model_type == "checkpoint"
|
|
and normalized_type not in VALID_CHECKPOINT_SUB_TYPES
|
|
):
|
|
continue
|
|
|
|
type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1
|
|
|
|
sorted_types = sorted(
|
|
[
|
|
{"type": model_type, "count": count}
|
|
for model_type, count in type_counts.items()
|
|
],
|
|
key=lambda value: value["count"],
|
|
reverse=True,
|
|
)
|
|
|
|
return sorted_types[:limit]
|
|
|
|
def has_hash(self, sha256: str) -> bool:
|
|
"""Check if a model with given hash exists"""
|
|
return self.scanner.has_hash(sha256)
|
|
|
|
def get_path_by_hash(self, sha256: str) -> Optional[str]:
|
|
"""Get file path for a model by its hash"""
|
|
return self.scanner.get_path_by_hash(sha256)
|
|
|
|
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
|
"""Get hash for a model by its file path"""
|
|
return self.scanner.get_hash_by_path(file_path)
|
|
|
|
async def scan_models(
|
|
self, force_refresh: bool = False, rebuild_cache: bool = False
|
|
):
|
|
"""Trigger model scanning"""
|
|
return await self.scanner.get_cached_data(
|
|
force_refresh=force_refresh, rebuild_cache=rebuild_cache
|
|
)
|
|
|
|
async def get_model_info_by_name(self, name: str):
|
|
"""Get model information by name"""
|
|
return await self.scanner.get_model_info_by_name(name)
|
|
|
|
def get_model_roots(self) -> List[str]:
|
|
"""Get model root directories"""
|
|
return self.scanner.get_model_roots()
|
|
|
|
def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict:
|
|
"""Filter relevant fields from CivitAI data"""
|
|
if not data:
|
|
return {}
|
|
|
|
fields = (
|
|
["id", "modelId", "name", "trainedWords"]
|
|
if minimal
|
|
else [
|
|
"id",
|
|
"modelId",
|
|
"name",
|
|
"createdAt",
|
|
"updatedAt",
|
|
"publishedAt",
|
|
"trainedWords",
|
|
"baseModel",
|
|
"description",
|
|
"model",
|
|
"images",
|
|
"customImages",
|
|
"creator",
|
|
]
|
|
)
|
|
return {k: data[k] for k in fields if k in data}
|
|
|
|
async def get_folder_tree(self, model_root: str) -> Dict:
|
|
"""Get hierarchical folder tree for a specific model root"""
|
|
cache = await self.scanner.get_cached_data()
|
|
|
|
# Build tree structure from folders
|
|
tree = {}
|
|
|
|
for folder in cache.folders:
|
|
# Check if this folder belongs to the specified model root
|
|
folder_belongs_to_root = False
|
|
for root in self.scanner.get_model_roots():
|
|
if root == model_root:
|
|
folder_belongs_to_root = True
|
|
break
|
|
|
|
if not folder_belongs_to_root:
|
|
continue
|
|
|
|
# Split folder path into components
|
|
parts = folder.split("/") if folder else []
|
|
current_level = tree
|
|
|
|
for part in parts:
|
|
if part not in current_level:
|
|
current_level[part] = {}
|
|
current_level = current_level[part]
|
|
|
|
return tree
|
|
|
|
async def get_unified_folder_tree(self) -> Dict:
|
|
"""Get unified folder tree across all model roots"""
|
|
cache = await self.scanner.get_cached_data()
|
|
|
|
# Build unified tree structure by analyzing all relative paths
|
|
unified_tree = {}
|
|
|
|
# Get all model roots for path normalization
|
|
model_roots = self.scanner.get_model_roots()
|
|
|
|
for folder in cache.folders:
|
|
if not folder: # Skip empty folders
|
|
continue
|
|
|
|
# Find which root this folder belongs to by checking the actual file paths
|
|
# This is a simplified approach - we'll use the folder as-is since it should already be relative
|
|
relative_path = folder
|
|
|
|
# Split folder path into components
|
|
parts = relative_path.split("/")
|
|
current_level = unified_tree
|
|
|
|
for part in parts:
|
|
if part not in current_level:
|
|
current_level[part] = {}
|
|
current_level = current_level[part]
|
|
|
|
return unified_tree
|
|
|
|
async def get_model_notes(self, model_name: str) -> Optional[str]:
|
|
"""Get notes for a specific model file"""
|
|
cache = await self.scanner.get_cached_data()
|
|
|
|
for model in cache.raw_data:
|
|
if model["file_name"] == model_name:
|
|
return model.get("notes", "")
|
|
|
|
return None
|
|
|
|
async def get_model_preview_url(self, model_name: str) -> Optional[str]:
|
|
"""Get the static preview URL for a model file"""
|
|
cache = await self.scanner.get_cached_data()
|
|
|
|
for model in cache.raw_data:
|
|
if model["file_name"] == model_name:
|
|
preview_url = model.get("preview_url")
|
|
if preview_url:
|
|
from ..config import config
|
|
|
|
return config.get_preview_static_url(preview_url)
|
|
|
|
return "/loras_static/images/no-preview.png"
|
|
|
|
async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]:
|
|
"""Get the Civitai URL for a model file"""
|
|
cache = await self.scanner.get_cached_data()
|
|
|
|
for model in cache.raw_data:
|
|
if model["file_name"] == model_name:
|
|
civitai_data = model.get("civitai", {})
|
|
model_id = civitai_data.get("modelId")
|
|
version_id = civitai_data.get("id")
|
|
|
|
if model_id:
|
|
civitai_url = f"https://civitai.com/models/{model_id}"
|
|
if version_id:
|
|
civitai_url += f"?modelVersionId={version_id}"
|
|
|
|
return {
|
|
"civitai_url": civitai_url,
|
|
"model_id": str(model_id),
|
|
"version_id": str(version_id) if version_id else None,
|
|
}
|
|
|
|
return {"civitai_url": None, "model_id": None, "version_id": None}
|
|
|
|
async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
|
|
"""Load full metadata for a single model.
|
|
|
|
Listing/search endpoints return lightweight cache entries; this method performs
|
|
a lazy read of the on-disk metadata snapshot when callers need full detail.
|
|
"""
|
|
metadata, should_skip = await MetadataManager.load_metadata(
|
|
file_path, self.metadata_class
|
|
)
|
|
if should_skip or metadata is None:
|
|
return None
|
|
return self.filter_civitai_data(metadata.to_dict().get("civitai", {}))
|
|
|
|
async def get_model_description(self, file_path: str) -> Optional[str]:
|
|
"""Return the stored modelDescription field for a model."""
|
|
metadata, should_skip = await MetadataManager.load_metadata(
|
|
file_path, self.metadata_class
|
|
)
|
|
if should_skip or metadata is None:
|
|
return None
|
|
return metadata.modelDescription or ""
|
|
|
|
@staticmethod
|
|
def _parse_search_tokens(search_term: str) -> tuple[List[str], List[str]]:
|
|
"""Split a search string into include and exclude tokens."""
|
|
include_terms: List[str] = []
|
|
exclude_terms: List[str] = []
|
|
|
|
for raw_term in search_term.split():
|
|
term = raw_term.strip()
|
|
if not term:
|
|
continue
|
|
|
|
if term.startswith("-") and len(term) > 1:
|
|
exclude_terms.append(term[1:].lower())
|
|
else:
|
|
include_terms.append(term.lower())
|
|
|
|
return include_terms, exclude_terms
|
|
|
|
@staticmethod
|
|
def _remove_model_extension(path: str) -> str:
|
|
"""Remove model file extension (.safetensors, .ckpt, .pt, .bin) for cleaner matching."""
|
|
return re.sub(r"\.(safetensors|ckpt|pt|bin)$", "", path, flags=re.IGNORECASE)
|
|
|
|
@staticmethod
|
|
def _relative_path_matches_tokens(
|
|
path_lower: str, include_terms: List[str], exclude_terms: List[str]
|
|
) -> bool:
|
|
"""Determine whether a relative path string satisfies include/exclude tokens.
|
|
|
|
Matches against the path without extension to avoid matching .safetensors
|
|
when searching for 's'.
|
|
"""
|
|
# Use path without extension for matching
|
|
path_for_matching = BaseModelService._remove_model_extension(path_lower)
|
|
|
|
if any(term and term in path_for_matching for term in exclude_terms):
|
|
return False
|
|
|
|
for term in include_terms:
|
|
if term and term not in path_for_matching:
|
|
return False
|
|
|
|
return True
|
|
|
|
@staticmethod
|
|
def _relative_path_sort_key(relative_path: str, include_terms: List[str]) -> tuple:
|
|
"""Sort paths by how well they satisfy the include tokens.
|
|
|
|
Sorts based on path without extension for consistent ordering.
|
|
"""
|
|
# Use path without extension for sorting
|
|
path_for_sorting = BaseModelService._remove_model_extension(
|
|
relative_path.lower()
|
|
)
|
|
prefix_hits = sum(
|
|
1 for term in include_terms if term and path_for_sorting.startswith(term)
|
|
)
|
|
match_positions = [
|
|
path_for_sorting.find(term)
|
|
for term in include_terms
|
|
if term and term in path_for_sorting
|
|
]
|
|
first_match_index = min(match_positions) if match_positions else 0
|
|
|
|
return (
|
|
-prefix_hits,
|
|
first_match_index,
|
|
len(path_for_sorting),
|
|
path_for_sorting,
|
|
)
|
|
|
|
async def search_relative_paths(
|
|
self, search_term: str, limit: int = 15, offset: int = 0
|
|
) -> List[str]:
|
|
"""Search model relative file paths for autocomplete functionality"""
|
|
cache = await self.scanner.get_cached_data()
|
|
include_terms, exclude_terms = self._parse_search_tokens(search_term)
|
|
|
|
matching_paths = []
|
|
|
|
# Get model roots for path calculation
|
|
model_roots = self.scanner.get_model_roots()
|
|
|
|
# Collect all matching paths first (needed for proper sorting and offset)
|
|
for model in cache.raw_data:
|
|
file_path = model.get("file_path", "")
|
|
if not file_path:
|
|
continue
|
|
|
|
# Calculate relative path from model root
|
|
relative_path = None
|
|
for root in model_roots:
|
|
# Normalize paths for comparison
|
|
normalized_root = os.path.normpath(root)
|
|
normalized_file = os.path.normpath(file_path)
|
|
|
|
if normalized_file.startswith(normalized_root):
|
|
# Remove root and leading separator to get relative path
|
|
relative_path = normalized_file[len(normalized_root) :].lstrip(
|
|
os.sep
|
|
)
|
|
break
|
|
|
|
if not relative_path:
|
|
continue
|
|
|
|
relative_lower = relative_path.lower()
|
|
if self._relative_path_matches_tokens(
|
|
relative_lower, include_terms, exclude_terms
|
|
):
|
|
matching_paths.append(relative_path)
|
|
|
|
# Sort by relevance (prefix and earliest hits first, then by length and alphabetically)
|
|
matching_paths.sort(
|
|
key=lambda relative: self._relative_path_sort_key(relative, include_terms)
|
|
)
|
|
|
|
# Apply offset and limit
|
|
start = min(offset, len(matching_paths))
|
|
end = min(start + limit, len(matching_paths))
|
|
return matching_paths[start:end]
|