feat(ui): auto-detect HIGH/LOW badges and auto-tag filters (#918)

- Backend auto-tag extraction service: detect HIGH/LOW (Wan-only), I2V/T2V/TI2V,
  Lightning/Turbo from filename, base_model, and CivitAI version name
- HIGH/LOW badge in card footer (inline before version name), color-coded:
  blue for HIGH, teal for LOW; abbreviated to H/L in medium/compact density
- Auto-tag filter panel (I2V, T2V, TI2V, Lightning, Turbo) with tri-state
  include/exclude filtering
- Full filter pipeline: FilterCriteria → ModelFilterSet → baseModelApi params
- AUTO_TAG_GROUPS exported for frontend use
- 19 unit tests for auto-tag extraction edge cases
This commit is contained in:
Will Miao
2026-05-17 17:45:12 +08:00
parent a74cbe7aa2
commit cc20d3b992
23 changed files with 524 additions and 8 deletions

View File

@@ -301,6 +301,15 @@ class ModelListingHandler:
for tag in exclude_tags:
if tag:
tag_filters[tag] = "exclude"
auto_tag_filters: Dict[str, str] = {}
for tag in request.query.getall("auto_tag_include", []):
if tag:
auto_tag_filters[tag] = "include"
for tag in request.query.getall("auto_tag_exclude", []):
if tag:
auto_tag_filters[tag] = "exclude"
favorites_only = request.query.get("favorites_only", "false").lower() == "true"
search_options = {
@@ -367,6 +376,7 @@ class ModelListingHandler:
"fuzzy_search": fuzzy_search,
"base_models": base_models,
"tags": tag_filters,
"auto_tags": auto_tag_filters,
"tag_logic": tag_logic,
"search_options": search_options,
"hash_filters": hash_filters,

View File

@@ -0,0 +1,121 @@
"""
Auto-tag extraction service for model cards.
Extracts implicit model attributes (HIGH/LOW, I2V/T2V/TI2V, Lightning, Turbo)
from filename, base_model, and CivitAI version name — no manual tagging required.
"""
from __future__ import annotations
import re
from typing import Dict, List, Set
# ── Tag category definitions ──────────────────────────────────────────
# AUTO_TAG_CATEGORIES maps a display label to the regex source used to
# detect that label in a model's filename, base_model, and CivitAI version
# name. The sources carry no inline flags; extract_auto_tags compiles them
# with re.IGNORECASE.
#
# Word edges are spelled out as explicit look-arounds rather than \b:
# Python's \b counts "_" as a word character, so \bHIGH\b would miss the
# very common "_HIGH_" filename form.
_B = r"(?<![a-zA-Z0-9])"  # left edge: no letter or digit directly before
_E = r"(?![a-zA-Z0-9])"  # right edge: no letter or digit directly after

AUTO_TAG_CATEGORIES: Dict[str, str] = {
    "HIGH": f"{_B}HIGH{_E}",
    # NOTE: _B already rejects any letter before the match, so the extra
    # (?<!F) guard cannot reject anything further; kept for compatibility.
    "LOW": f"{_B}(?<!F)LOW{_E}",
    "I2V": f"{_B}I2V{_E}",
    "T2V": f"{_B}T2V{_E}",
    "TI2V": f"{_B}TI2V{_E}",
    "Lightning": f"{_B}Lightning{_E}",
    "Turbo": f"{_B}Turbo{_E}",
}
# ── Tag families ──────────────────────────────────────────────────────
# "mode" family: HIGH / LOW
MODE_TAGS: Set[str] = {"HIGH", "LOW"}
# "video mode" family: I2V / T2V / TI2V
VIDEO_MODE_TAGS: Set[str] = {"I2V", "T2V", "TI2V"}
# "speed/optimization" family
SPEED_TAGS: Set[str] = {"Lightning", "Turbo"}

# ── Display category groups (for settings UI) ─────────────────────────
# Each group holds its own copy of the family set, so mutating a group
# never alters the family constant above (and vice versa).
AUTO_TAG_GROUPS: Dict[str, Set[str]] = {
    "mode": set(MODE_TAGS),
    "video": set(VIDEO_MODE_TAGS),
    "speed": set(SPEED_TAGS),
}

# Groups enabled by default ("speed" is not among them)
DEFAULT_ENABLED_GROUPS: Set[str] = {"mode", "video"}
def _collect_sources(model_data: Dict) -> List[str]:
    """Collect the text fields that auto-tag patterns are matched against.

    Reads, in this order, ``file_name``, ``base_model`` and the ``name``
    entry of the ``civitai`` dict.

    Args:
        model_data: Model metadata dict. All keys are optional; ``civitai``
            is only consulted when it is a dict.

    Returns:
        The non-empty string values found, in the order listed above.
        Missing, empty, and non-string values are skipped so that callers
        can hand every element straight to ``re.search`` without a
        ``TypeError`` on malformed metadata.
    """
    civitai = model_data.get("civitai")
    candidates = (
        model_data.get("file_name"),
        model_data.get("base_model"),
        # A null or non-dict "civitai" entry simply contributes nothing.
        civitai.get("name") if isinstance(civitai, dict) else None,
    )
    return [text for text in candidates if isinstance(text, str) and text]
def _build_mode_pattern(label: str) -> "re.Pattern[str]":
    """Compile the case-aware detector for a Wan mode label (HIGH / LOW).

    The label's letters match in either case, but the boundaries are
    case-sensitive so that camelCase names are split correctly:

    * Left edge — the match must start a word: either no letter directly
      precedes it ("_high", "22LOW"), or its first letter is an uppercase
      hump ("wanHigh", "I2VHigh"). This rejects words that merely end in
      the label, such as "slow", "yellow", "flow" or "thigh".
    * Right edge — the match must not be followed by a lowercase letter,
      so "HighNoise" and "HIGH_" match while "highlight" does not.
    """
    upper, lower = label[0].upper(), label[0].lower()
    head = (
        rf"(?:(?<![A-Za-z])[{lower}{upper}]"  # start of a word, any case
        rf"|(?<=[a-z]){upper}"  # hump after lowercase: wanHIGH, wanHigh
        rf"|{upper}(?=[a-z]))"  # hump after an acronym: I2VHigh
    )
    tail = "".join(f"[{c.lower()}{c.upper()}]" for c in label[1:])
    return re.compile(head + tail + r"(?![a-z])")


# Compiled once at import instead of being rebuilt for every model on
# every call.
_MODE_TAG_PATTERNS = {
    label: _build_mode_pattern(label) for label in ("HIGH", "LOW")
}


def extract_auto_tags(model_data: Dict) -> List[str]:
    """Extract auto-detected tags from model metadata.

    Every label in ``AUTO_TAG_CATEGORIES`` is tested against the text
    sources returned by ``_collect_sources`` (filename, base_model, CivitAI
    version name); a label is reported when any source matches.

    HIGH/LOW are only reported when ``base_model`` contains "wan"
    (case-insensitive), and they use the case-aware patterns in
    ``_MODE_TAG_PATTERNS`` rather than the generic category regex. All
    other labels are matched case-insensitively with their category regex.

    Args:
        model_data: Model metadata dict with optional keys ``file_name``,
            ``base_model`` and ``civitai`` (a dict with an optional
            ``name``). A missing or null ``base_model`` is treated as
            non-Wan.

    Returns:
        Sorted list of unique auto-tag labels (e.g. ``["HIGH", "I2V"]``);
        empty when there is no text to inspect.
    """
    sources = _collect_sources(model_data)
    if not sources:
        return []

    # .get(key, "") still yields None when the key is present with a null
    # value, so normalise before calling .lower().
    base_model = model_data.get("base_model") or ""
    is_wan = isinstance(base_model, str) and "wan" in base_model.lower()

    found: Set[str] = set()
    for label, pattern in AUTO_TAG_CATEGORIES.items():
        mode_regex = _MODE_TAG_PATTERNS.get(label)
        if mode_regex is not None:
            # HIGH/LOW are Wan-specific — skip for non-Wan to avoid noise.
            if not is_wan:
                continue
            regex = mode_regex
        else:
            regex = re.compile(pattern, re.IGNORECASE)
        if any(regex.search(source) for source in sources):
            found.add(label)
    return sorted(found)

View File

@@ -77,6 +77,7 @@ class BaseModelService(ABC):
base_models: list = None,
model_types: list = None,
tags: Optional[Dict[str, str]] = None,
auto_tags: Optional[Dict[str, str]] = None,
search_options: dict = None,
hash_filters: dict = None,
favorites_only: bool = False,
@@ -95,6 +96,11 @@ class BaseModelService(ABC):
sorted_data = await self._fetch_with_usage_sort(sort_params)
else:
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
# Pre-compute auto_tags for every item — needed for both filtering
# and display. Computation is cheap (string regex on 2-3 fields).
from .auto_tag_service import extract_auto_tags
for item in sorted_data:
item["auto_tags"] = extract_auto_tags(item)
fetch_duration = time.perf_counter() - t0
initial_count = len(sorted_data)
@@ -110,6 +116,7 @@ class BaseModelService(ABC):
base_models=base_models,
model_types=model_types,
tags=tags,
auto_tags=auto_tags,
favorites_only=favorites_only,
search_options=search_options,
tag_logic=tag_logic,
@@ -354,6 +361,7 @@ class BaseModelService(ABC):
base_models: list = None,
model_types: list = None,
tags: Optional[Dict[str, str]] = None,
auto_tags: Optional[Dict[str, str]] = None,
favorites_only: bool = False,
search_options: dict = None,
tag_logic: str = "any",
@@ -367,6 +375,7 @@ class BaseModelService(ABC):
base_models=base_models,
model_types=model_types,
tags=tags,
auto_tags=auto_tags,
favorites_only=favorites_only,
search_options=normalized_options,
tag_logic=tag_logic,

View File

@@ -3,6 +3,7 @@ import logging
from typing import Dict
from .base_model_service import BaseModelService
from .auto_tag_service import extract_auto_tags
from ..utils.models import CheckpointMetadata
from ..config import config
@@ -45,7 +46,8 @@ class CheckpointService(BaseModelService):
"exclude": bool(checkpoint_data.get("exclude", False)),
"update_available": bool(checkpoint_data.get("update_available", False)),
"skip_metadata_refresh": bool(checkpoint_data.get("skip_metadata_refresh", False)),
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True),
"auto_tags": checkpoint_data.get("auto_tags") or extract_auto_tags(checkpoint_data),
}
def find_duplicate_hashes(self) -> Dict:

View File

@@ -3,6 +3,7 @@ import logging
from typing import Dict
from .base_model_service import BaseModelService
from .auto_tag_service import extract_auto_tags
from ..utils.models import EmbeddingMetadata
from ..config import config
@@ -45,7 +46,8 @@ class EmbeddingService(BaseModelService):
"exclude": bool(embedding_data.get("exclude", False)),
"update_available": bool(embedding_data.get("update_available", False)),
"skip_metadata_refresh": bool(embedding_data.get("skip_metadata_refresh", False)),
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True),
"auto_tags": embedding_data.get("auto_tags") or extract_auto_tags(embedding_data),
}
def find_duplicate_hashes(self) -> Dict:

View File

@@ -5,6 +5,7 @@ from typing import Dict, List, Optional
from .base_model_service import BaseModelService
from .model_query import resolve_sub_type
from .auto_tag_service import extract_auto_tags
from ..utils.models import LoraMetadata
from ..config import config
@@ -57,6 +58,7 @@ class LoraService(BaseModelService):
"civitai": self.filter_civitai_data(
lora_data.get("civitai", {}), minimal=True
),
"auto_tags": lora_data.get("auto_tags") or extract_auto_tags(lora_data),
}
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:

View File

@@ -96,6 +96,7 @@ class FilterCriteria:
folder_exclude: Optional[Sequence[str]] = None
base_models: Optional[Sequence[str]] = None
tags: Optional[Dict[str, str]] = None
auto_tags: Optional[Dict[str, str]] = None
favorites_only: bool = False
search_options: Optional[Dict[str, Any]] = None
model_types: Optional[Sequence[str]] = None
@@ -359,10 +360,37 @@ class ModelFilterSet:
]
model_types_duration = time.perf_counter() - t0
auto_tags_duration = 0
auto_tag_filters = criteria.auto_tags or {}
if auto_tag_filters:
t0 = time.perf_counter()
include_at = set()
exclude_at = set()
for tag, state in auto_tag_filters.items():
if not tag:
continue
if state == "exclude":
exclude_at.add(tag)
else:
include_at.add(tag)
if include_at:
items = [
item for item in items
if any(tag in include_at for tag in (item.get("auto_tags") or []))
]
if exclude_at:
items = [
item for item in items
if not any(tag in exclude_at for tag in (item.get("auto_tags") or []))
]
auto_tags_duration = time.perf_counter() - t0
duration = time.perf_counter() - overall_start
if duration > 0.1: # Only log if it's potentially slow
logger.debug(
"ModelFilterSet.apply took %.3fs (sfw: %.3fs, fav: %.3fs, folder: %.3fs, base: %.3fs, tags: %.3fs, types: %.3fs). "
"ModelFilterSet.apply took %.3fs (sfw: %.3fs, fav: %.3fs, folder: %.3fs, base: %.3fs, tags: %.3fs, types: %.3fs, auto_tags: %.3fs). "
"Count: %d -> %d",
duration,
sfw_duration,
@@ -371,6 +399,7 @@ class ModelFilterSet:
base_models_duration,
tags_duration,
model_types_duration,
auto_tags_duration,
initial_count,
len(items),
)