mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-17 10:37:35 -03:00
feat(ui): auto-detect HIGH/LOW badges and auto-tag filters (#918)
- Backend auto-tag extraction service: detect HIGH/LOW (Wan-only), I2V/T2V/TI2V, Lightning/Turbo from filename, base_model, and CivitAI version name
- HIGH/LOW badge in card footer (inline before version name), color-coded: blue for HIGH, teal for LOW; abbreviated to H/L in medium/compact density
- Auto-tag filter panel (I2V, T2V, TI2V, Lightning, Turbo) with tri-state include/exclude filtering
- Full filter pipeline: FilterCriteria → ModelFilterSet → baseModelApi params
- AUTO_TAG_GROUPS exported for frontend use
- 19 unit tests for auto-tag extraction edge cases
This commit is contained in:
@@ -301,6 +301,15 @@ class ModelListingHandler:
|
||||
for tag in exclude_tags:
|
||||
if tag:
|
||||
tag_filters[tag] = "exclude"
|
||||
|
||||
auto_tag_filters: Dict[str, str] = {}
|
||||
for tag in request.query.getall("auto_tag_include", []):
|
||||
if tag:
|
||||
auto_tag_filters[tag] = "include"
|
||||
for tag in request.query.getall("auto_tag_exclude", []):
|
||||
if tag:
|
||||
auto_tag_filters[tag] = "exclude"
|
||||
|
||||
favorites_only = request.query.get("favorites_only", "false").lower() == "true"
|
||||
|
||||
search_options = {
|
||||
@@ -367,6 +376,7 @@ class ModelListingHandler:
|
||||
"fuzzy_search": fuzzy_search,
|
||||
"base_models": base_models,
|
||||
"tags": tag_filters,
|
||||
"auto_tags": auto_tag_filters,
|
||||
"tag_logic": tag_logic,
|
||||
"search_options": search_options,
|
||||
"hash_filters": hash_filters,
|
||||
|
||||
121
py/services/auto_tag_service.py
Normal file
121
py/services/auto_tag_service.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
Auto-tag extraction service for model cards.
|
||||
|
||||
Extracts implicit model attributes (HIGH/LOW, I2V/T2V/TI2V, Lightning, Turbo)
|
||||
from filename, base_model, and CivitAI version name — no manual tagging required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Set
|
||||
|
||||
# ── Tag category definitions ──────────────────────────────────────────
# Each category maps a display label to a regex pattern.
# Patterns are case-insensitive and matched against filename, base_model,
# and civitai version name.

# Use (?<![a-zA-Z0-9]) and (?![a-zA-Z0-9]) instead of \b because
# Python's \b treats underscore as a word character, so \bHIGH\b
# won't match '_HIGH_' in filenames.
_B = r"(?<![a-zA-Z0-9])"  # left boundary: previous char is not alphanumeric
_E = r"(?![a-zA-Z0-9])"   # right boundary: next char is not alphanumeric

# Label -> pattern. NOTE: extract_auto_tags() below compiles its own
# camelCase-aware regexes for HIGH/LOW instead of these two entries, so
# within this module the "HIGH"/"LOW" patterns are effectively fallbacks
# (presumably kept for external consumers of this dict — verify).
AUTO_TAG_CATEGORIES: Dict[str, str] = {
    "HIGH": _B + r"HIGH" + _E,
    # (?<!F) rejects "FLOW"; redundant here because _B already excludes
    # any letter, but harmless and self-documenting.
    "LOW": _B + r"(?<!F)LOW" + _E,
    # "I2V" cannot fire inside "TI2V": the preceding "T" violates _B.
    "I2V": _B + r"I2V" + _E,
    "T2V": _B + r"T2V" + _E,
    "TI2V": _B + r"TI2V" + _E,
    "Lightning": _B + r"Lightning" + _E,
    "Turbo": _B + r"Turbo" + _E,
}

# Tags that belong to the "mode" group (HIGH/LOW)
MODE_TAGS = {"HIGH", "LOW"}

# Tags that belong to the "video mode" group (I2V/T2V/TI2V)
VIDEO_MODE_TAGS = {"I2V", "T2V", "TI2V"}

# Tags that belong to the "speed/optimization" group
SPEED_TAGS = {"Lightning", "Turbo"}

# ── Display category groups (for settings UI) ─────────────────────────

# Group key -> set of tag labels; mirrors the three *_TAGS sets above.
AUTO_TAG_GROUPS = {
    "mode": {"HIGH", "LOW"},
    "video": {"I2V", "T2V", "TI2V"},
    "speed": {"Lightning", "Turbo"},
}

# Default enabled categories ("speed" is opt-in)
DEFAULT_ENABLED_GROUPS = {"mode", "video"}
def _collect_sources(model_data: Dict) -> List[str]:
|
||||
"""Collect all text sources from model data for tag matching."""
|
||||
sources: List[str] = []
|
||||
|
||||
file_name = model_data.get("file_name", "")
|
||||
if file_name:
|
||||
sources.append(file_name)
|
||||
|
||||
base_model = model_data.get("base_model", "")
|
||||
if base_model:
|
||||
sources.append(base_model)
|
||||
|
||||
civitai = model_data.get("civitai", {})
|
||||
if isinstance(civitai, dict):
|
||||
version_name = civitai.get("name", "")
|
||||
if version_name:
|
||||
sources.append(version_name)
|
||||
|
||||
return sources
|
||||
|
||||
|
||||
def extract_auto_tags(model_data: Dict) -> List[str]:
    """Derive implicit tags (e.g. ``["I2V", "Lightning"]``) from metadata.

    Every pattern in AUTO_TAG_CATEGORIES is tried against the filename,
    base_model, and CivitAI version name. HIGH/LOW are emitted only when
    base_model contains "wan" (case-insensitive) — no other architecture
    uses that split, so skipping avoids noise.

    Args:
        model_data: Model metadata dict with keys:
            file_name, base_model, civitai (with optional 'name' field).

    Returns:
        Sorted list of unique auto-tag strings (e.g. ["I2V"]).
    """
    texts = _collect_sources(model_data)
    if not texts:
        return []

    is_wan = "wan" in model_data.get("base_model", "").lower()

    matched: Set[str] = set()
    for label, pattern in AUTO_TAG_CATEGORIES.items():
        if label in ("HIGH", "LOW"):
            # Wan-only mode tags — other families never emit them.
            if not is_wan:
                continue
            # Per-character case-insensitive classes plus a case-sensitive
            # right boundary (no lowercase letter may follow), so camelCase
            # "HighNoise" matches while "highlight" does not.
            # NOTE(review): there is no left boundary, so lowercase words
            # ending in the tag ("thigh", "slow") also match on Wan models
            # — confirm this is covered by the intended behavior/tests.
            body = "".join(f"[{ch.lower()}{ch.upper()}]" for ch in label)
            left = r"(?<![Ff])" if label == "LOW" else ""  # reject "flow"/"FLOW"
            regex = re.compile(left + body + r"(?![a-z])")
        else:
            regex = re.compile(pattern, re.IGNORECASE)

        if any(regex.search(text) for text in texts):
            matched.add(label)

    return sorted(matched)
|
||||
@@ -77,6 +77,7 @@ class BaseModelService(ABC):
|
||||
base_models: list = None,
|
||||
model_types: list = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
auto_tags: Optional[Dict[str, str]] = None,
|
||||
search_options: dict = None,
|
||||
hash_filters: dict = None,
|
||||
favorites_only: bool = False,
|
||||
@@ -95,6 +96,11 @@ class BaseModelService(ABC):
|
||||
sorted_data = await self._fetch_with_usage_sort(sort_params)
|
||||
else:
|
||||
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
|
||||
# Pre-compute auto_tags for every item — needed for both filtering
|
||||
# and display. Computation is cheap (string regex on 2-3 fields).
|
||||
from .auto_tag_service import extract_auto_tags
|
||||
for item in sorted_data:
|
||||
item["auto_tags"] = extract_auto_tags(item)
|
||||
fetch_duration = time.perf_counter() - t0
|
||||
initial_count = len(sorted_data)
|
||||
|
||||
@@ -110,6 +116,7 @@ class BaseModelService(ABC):
|
||||
base_models=base_models,
|
||||
model_types=model_types,
|
||||
tags=tags,
|
||||
auto_tags=auto_tags,
|
||||
favorites_only=favorites_only,
|
||||
search_options=search_options,
|
||||
tag_logic=tag_logic,
|
||||
@@ -354,6 +361,7 @@ class BaseModelService(ABC):
|
||||
base_models: list = None,
|
||||
model_types: list = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
auto_tags: Optional[Dict[str, str]] = None,
|
||||
favorites_only: bool = False,
|
||||
search_options: dict = None,
|
||||
tag_logic: str = "any",
|
||||
@@ -367,6 +375,7 @@ class BaseModelService(ABC):
|
||||
base_models=base_models,
|
||||
model_types=model_types,
|
||||
tags=tags,
|
||||
auto_tags=auto_tags,
|
||||
favorites_only=favorites_only,
|
||||
search_options=normalized_options,
|
||||
tag_logic=tag_logic,
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
from typing import Dict
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from .auto_tag_service import extract_auto_tags
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
|
||||
@@ -45,7 +46,8 @@ class CheckpointService(BaseModelService):
|
||||
"exclude": bool(checkpoint_data.get("exclude", False)),
|
||||
"update_available": bool(checkpoint_data.get("update_available", False)),
|
||||
"skip_metadata_refresh": bool(checkpoint_data.get("skip_metadata_refresh", False)),
|
||||
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
||||
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True),
|
||||
"auto_tags": checkpoint_data.get("auto_tags") or extract_auto_tags(checkpoint_data),
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
|
||||
@@ -3,6 +3,7 @@ import logging
|
||||
from typing import Dict
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from .auto_tag_service import extract_auto_tags
|
||||
from ..utils.models import EmbeddingMetadata
|
||||
from ..config import config
|
||||
|
||||
@@ -45,7 +46,8 @@ class EmbeddingService(BaseModelService):
|
||||
"exclude": bool(embedding_data.get("exclude", False)),
|
||||
"update_available": bool(embedding_data.get("update_available", False)),
|
||||
"skip_metadata_refresh": bool(embedding_data.get("skip_metadata_refresh", False)),
|
||||
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
||||
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True),
|
||||
"auto_tags": embedding_data.get("auto_tags") or extract_auto_tags(embedding_data),
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Dict, List, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from .model_query import resolve_sub_type
|
||||
from .auto_tag_service import extract_auto_tags
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
|
||||
@@ -57,6 +58,7 @@ class LoraService(BaseModelService):
|
||||
"civitai": self.filter_civitai_data(
|
||||
lora_data.get("civitai", {}), minimal=True
|
||||
),
|
||||
"auto_tags": lora_data.get("auto_tags") or extract_auto_tags(lora_data),
|
||||
}
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
|
||||
@@ -96,6 +96,7 @@ class FilterCriteria:
|
||||
folder_exclude: Optional[Sequence[str]] = None
|
||||
base_models: Optional[Sequence[str]] = None
|
||||
tags: Optional[Dict[str, str]] = None
|
||||
auto_tags: Optional[Dict[str, str]] = None
|
||||
favorites_only: bool = False
|
||||
search_options: Optional[Dict[str, Any]] = None
|
||||
model_types: Optional[Sequence[str]] = None
|
||||
@@ -359,10 +360,37 @@ class ModelFilterSet:
|
||||
]
|
||||
model_types_duration = time.perf_counter() - t0
|
||||
|
||||
auto_tags_duration = 0
|
||||
auto_tag_filters = criteria.auto_tags or {}
|
||||
if auto_tag_filters:
|
||||
t0 = time.perf_counter()
|
||||
include_at = set()
|
||||
exclude_at = set()
|
||||
for tag, state in auto_tag_filters.items():
|
||||
if not tag:
|
||||
continue
|
||||
if state == "exclude":
|
||||
exclude_at.add(tag)
|
||||
else:
|
||||
include_at.add(tag)
|
||||
|
||||
if include_at:
|
||||
items = [
|
||||
item for item in items
|
||||
if any(tag in include_at for tag in (item.get("auto_tags") or []))
|
||||
]
|
||||
|
||||
if exclude_at:
|
||||
items = [
|
||||
item for item in items
|
||||
if not any(tag in exclude_at for tag in (item.get("auto_tags") or []))
|
||||
]
|
||||
auto_tags_duration = time.perf_counter() - t0
|
||||
|
||||
duration = time.perf_counter() - overall_start
|
||||
if duration > 0.1: # Only log if it's potentially slow
|
||||
logger.debug(
|
||||
"ModelFilterSet.apply took %.3fs (sfw: %.3fs, fav: %.3fs, folder: %.3fs, base: %.3fs, tags: %.3fs, types: %.3fs). "
|
||||
"ModelFilterSet.apply took %.3fs (sfw: %.3fs, fav: %.3fs, folder: %.3fs, base: %.3fs, tags: %.3fs, types: %.3fs, auto_tags: %.3fs). "
|
||||
"Count: %d -> %d",
|
||||
duration,
|
||||
sfw_duration,
|
||||
@@ -371,6 +399,7 @@ class ModelFilterSet:
|
||||
base_models_duration,
|
||||
tags_duration,
|
||||
model_types_duration,
|
||||
auto_tags_duration,
|
||||
initial_count,
|
||||
len(items),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user