mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
- Add model_types parameter to ModelListingHandler to support filtering by model type
- Implement get_model_types endpoint in ModelQueryHandler to retrieve available model types
- Register new /api/lm/{prefix}/model-types route for model type queries
- Extend BaseModelService to handle model type filtering in queries
- Support both model_type and civitai_model_type query parameters for backward compatibility
This enables users to filter models by specific types, improving model discovery and organization capabilities.
269 lines
9.0 KiB
Python
269 lines
9.0 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Protocol, Callable
|
|
|
|
from ..utils.constants import NSFW_LEVELS
|
|
from ..utils.utils import fuzzy_match as default_fuzzy_match
|
|
|
|
|
|
# Fallback model type returned by resolve_civitai_model_type() when an entry
# carries no usable type, either under civitai.model.type or under model_type.
DEFAULT_CIVITAI_MODEL_TYPE = "LORA"
|
|
|
|
|
|
def _coerce_to_str(value: Any) -> Optional[str]:
|
|
if value is None:
|
|
return None
|
|
|
|
candidate = str(value).strip()
|
|
return candidate if candidate else None
|
|
|
|
|
|
def normalize_civitai_model_type(value: Any) -> Optional[str]:
    """Return ``value`` as a stripped, lowercase string for comparisons.

    Returns ``None`` when ``value`` is ``None`` or is blank once converted to
    a string and stripped.
    """
    text = _coerce_to_str(value)
    if not text:
        return None
    return text.lower()
|
|
|
|
|
|
def resolve_civitai_model_type(entry: Mapping[str, Any]) -> str:
    """Determine the model type recorded for ``entry``.

    Lookup order:

    1. ``entry["civitai"]["model"]["type"]`` when both nested values are
       mappings and the type is non-blank.
    2. ``entry["model_type"]`` when non-blank.
    3. ``DEFAULT_CIVITAI_MODEL_TYPE`` otherwise, including when ``entry``
       is not a mapping at all.

    The returned string is stripped but keeps its original casing.
    """
    if not isinstance(entry, Mapping):
        return DEFAULT_CIVITAI_MODEL_TYPE

    metadata = entry.get("civitai")
    nested_model = metadata.get("model") if isinstance(metadata, Mapping) else None
    if isinstance(nested_model, Mapping):
        civitai_type = _coerce_to_str(nested_model.get("type"))
        if civitai_type:
            return civitai_type

    # Only consulted when the CivitAI metadata did not yield a type.
    return _coerce_to_str(entry.get("model_type")) or DEFAULT_CIVITAI_MODEL_TYPE
|
|
|
|
|
|
class SettingsProvider(Protocol):
    """Structural type for the settings object the query helpers read from.

    Any object exposing a compatible ``get`` method satisfies this protocol.
    """

    def get(self, key: str, default: Any = None) -> Any:
        """Look up the setting named ``key``; ``default`` is the caller's fallback."""
        ...
|
|
|
|
|
|
@dataclass(frozen=True)
class SortParams:
    """Immutable pair describing how a model list should be sorted."""

    # Name of the field to sort by.
    key: str
    # Sort direction string.
    order: str
|
|
|
|
|
|
@dataclass(frozen=True)
class FilterCriteria:
    """Immutable bundle of the filters a model listing request may carry.

    Every field is optional; the defaults describe "no filtering".
    """

    # Folder to restrict results to; None disables folder filtering.
    folder: Optional[str] = None
    # Accepted base model names.
    base_models: Optional[Sequence[str]] = None
    # Tag name mapped to a state string for that tag.
    tags: Optional[Dict[str, str]] = None
    # When True, only entries flagged as favorite are kept.
    favorites_only: bool = False
    # Free-form search options.
    search_options: Optional[Dict[str, Any]] = None
    # Accepted model types.
    model_types: Optional[Sequence[str]] = None
|
|
|
|
|
|
class ModelCacheRepository:
    """Thin wrapper over a scanner's cache, plus sort-string parsing."""

    def __init__(self, scanner) -> None:
        """Keep a reference to ``scanner``, which supplies the cached data."""
        self._scanner = scanner

    async def get_cache(self):
        """Await and return whatever ``scanner.get_cached_data()`` yields."""
        cache = await self._scanner.get_cached_data()
        return cache

    async def fetch_sorted(self, params: SortParams) -> List[Dict[str, Any]]:
        """Return the cache's data sorted by ``params.key`` / ``params.order``."""
        cache = await self.get_cache()
        sorted_items = await cache.get_sorted_data(params.key, params.order)
        return sorted_items

    @staticmethod
    def parse_sort(sort_by: str) -> SortParams:
        """Turn a ``"key"`` or ``"key:order"`` string into ``SortParams``.

        The key is stripped and lowercased and falls back to ``"name"`` when
        empty. The order is stripped and lowercased; anything other than
        ``"asc"`` or ``"desc"`` (including a missing order) becomes
        ``"asc"``. A falsy ``sort_by`` yields ``name`` / ``asc``.
        """
        if not sort_by:
            return SortParams(key="name", order="asc")

        # Split on the first colon only; ``separator`` is "" when absent.
        raw_key, separator, raw_order = sort_by.partition(":")
        direction = raw_order.strip().lower() if separator else "asc"
        if direction not in ("asc", "desc"):
            direction = "asc"

        field = raw_key.strip().lower() or "name"
        return SortParams(key=field, order=direction)
|
|
|
|
|
|
class ModelFilterSet:
    """Applies common filtering rules to the model collection.

    Filters run in a fixed order: SFW-only setting, favorites, folder,
    base model, included tags, excluded tags, and finally model type.
    """

    def __init__(self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None) -> None:
        """Store the collaborators used by :meth:`apply`.

        Args:
            settings: Object exposing ``get(key, default)``; the
                ``show_only_sfw`` key is read on every ``apply`` call.
            nsfw_levels: Mapping of NSFW level name to numeric value. The
                module-level ``NSFW_LEVELS`` is used when this is omitted
                or empty.
        """
        self._settings = settings
        self._nsfw_levels = nsfw_levels or NSFW_LEVELS

    def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]:
        """Return the items from ``data`` that satisfy ``criteria``.

        Args:
            data: Model entries (dicts). The iterable is consumed once and
                never mutated.
            criteria: Filtering options. ``search_options["recursive"]``
                (default ``True``) controls whether sub-folders match the
                folder filter.

        Returns:
            A new list containing the surviving entries in input order.
        """
        items = list(data)

        if self._settings.get("show_only_sfw", False):
            threshold = self._nsfw_levels.get("R", 0)

            def _is_sfw(item: Dict[str, Any]) -> bool:
                # A missing or zero level counts as safe.
                level = item.get("preview_nsfw_level")
                return not level or level < threshold

            items = [item for item in items if _is_sfw(item)]

        if criteria.favorites_only:
            items = [item for item in items if item.get("favorite", False)]

        folder = criteria.folder
        options = criteria.search_options or {}
        recursive = bool(options.get("recursive", True))
        if folder is not None:
            if recursive:
                # An empty folder in recursive mode means "everything".
                if folder:
                    folder_with_sep = f"{folder}/"

                    def _in_tree(item: Dict[str, Any]) -> bool:
                        item_folder = item.get("folder")
                        if item_folder == folder:
                            return True
                        # The stored folder may be None or another
                        # non-string value; only strings can be descendants.
                        return isinstance(item_folder, str) and item_folder.startswith(folder_with_sep)

                    items = [item for item in items if _in_tree(item)]
            else:
                items = [item for item in items if item.get("folder") == folder]

        base_models = criteria.base_models or []
        if base_models:
            base_model_set = set(base_models)
            items = [item for item in items if item.get("base_model") in base_model_set]

        tag_filters = criteria.tags or {}
        include_tags = set()
        exclude_tags = set()
        if isinstance(tag_filters, dict):
            for tag, state in tag_filters.items():
                if not tag:
                    continue
                # Any state other than "exclude" is treated as an include.
                if state == "exclude":
                    exclude_tags.add(tag)
                else:
                    include_tags.add(tag)
        else:
            # A plain iterable of tag names is treated as include-only.
            include_tags = {tag for tag in tag_filters if tag}

        if include_tags:
            items = [
                item for item in items
                if any(tag in include_tags for tag in (item.get("tags") or []))
            ]

        if exclude_tags:
            items = [
                item for item in items
                if not any(tag in exclude_tags for tag in (item.get("tags") or []))
            ]

        model_types = criteria.model_types or []
        normalized_model_types = set()
        for value in model_types:
            normalized = normalize_civitai_model_type(value)
            if normalized:
                normalized_model_types.add(normalized)
        if normalized_model_types:
            # Both sides are normalized, so the comparison ignores case.
            items = [
                item for item in items
                if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types
            ]

        return items
|
|
|
|
|
|
class SearchStrategy:
    """Text search over model entries, with optional fuzzy matching."""

    # Baseline search options; ``normalize_options`` layers caller values on top.
    DEFAULT_OPTIONS: Dict[str, Any] = {
        "filename": True,
        "modelname": True,
        "tags": False,
        "recursive": True,
        "creator": False,
    }

    def __init__(self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None) -> None:
        """Use ``fuzzy_matcher(candidate, search_term)`` or the module default."""
        self._fuzzy_match = fuzzy_matcher or default_fuzzy_match

    def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Return a fresh dict of defaults overridden by ``options``.

        Neither ``options`` nor ``DEFAULT_OPTIONS`` is modified.
        """
        merged = self.DEFAULT_OPTIONS.copy()
        if options:
            merged.update(options)
        return merged

    def apply(
        self,
        data: Iterable[Dict[str, Any]],
        search_term: str,
        options: Dict[str, Any],
        fuzzy: bool = False,
    ) -> List[Dict[str, Any]]:
        """Return the entries of ``data`` that match ``search_term``.

        Fields are tried in the order file name, model name, tags, creator
        username; each is consulted only when its option is enabled, and an
        entry is kept as soon as one field matches. An empty search term
        returns every entry.
        """
        if not search_term:
            return list(data)

        needle = search_term.lower()

        def hit(text: Any) -> bool:
            return self._matches(text, search_term, needle, fuzzy)

        matched: List[Dict[str, Any]] = []
        for entry in data:
            # ``or`` short-circuits, so later fields are only inspected when
            # earlier ones did not match.
            if (
                (options.get("filename", True) and hit(entry.get("file_name", "")))
                or (options.get("modelname", True) and hit(entry.get("model_name", "")))
                or (
                    options.get("tags", False)
                    and any(hit(tag) for tag in (entry.get("tags", []) or []))
                )
                or (options.get("creator", False) and self._creator_hit(entry, hit))
            ):
                matched.append(entry)

        return matched

    @staticmethod
    def _creator_hit(entry: Dict[str, Any], hit: Callable[[Any], bool]) -> bool:
        """Check ``entry["civitai"]["creator"]["username"]`` against ``hit``."""
        username = ""
        metadata = entry.get("civitai")
        if isinstance(metadata, dict):
            creator = metadata.get("creator")
            if isinstance(creator, dict):
                username = creator.get("username", "")
        return bool(username and hit(username))

    def _matches(self, candidate: str, search_term: str, search_lower: str, fuzzy: bool) -> bool:
        """Test one candidate value against the search term.

        ``None`` and empty candidates never match. Non-string candidates are
        converted with ``str()``. In fuzzy mode the configured matcher
        decides; otherwise a case-insensitive substring test is used.
        """
        if candidate is None:
            return False

        text = candidate if isinstance(candidate, str) else str(candidate)
        if not text:
            return False

        if fuzzy:
            return self._fuzzy_match(text, search_term)
        return search_lower in text.lower()
|