feat(ui): merge user tags into auto-tag badges and refresh on tag edit (#918)

- Layer 2 fallback: user tags overlapping with auto-tag categories
  (HIGH/LOW/I2V/T2V/TI2V/Lightning/Turbo) are merged into auto_tags,
  providing manual override when filename-based detection fails.
  Matching is case-insensitive so "high"/"High"/"HIGH" all work.
- Refresh on tag edit: save_metadata and add_tags handlers now return
  recalculated auto_tags in the response; the frontend passes them to
  VirtualScroller.updateSingleItem so badges update immediately without
  requiring a page reload.
- 8 new test cases for Layer 2 fallback and case-insensitive matching.
This commit is contained in:
Will Miao
2026-05-20 22:48:44 +08:00
parent 9ce56dd40c
commit 78303b2a5e
7 changed files with 150 additions and 39 deletions

View File

@@ -788,7 +788,7 @@ class ModelManagementHandler:
metadata_updates = {k: v for k, v in data.items() if k != "file_path"}
await self._metadata_sync.save_metadata_updates(
updated_metadata = await self._metadata_sync.save_metadata_updates(
file_path=file_path,
updates=metadata_updates,
metadata_loader=self._metadata_sync.load_local_metadata,
@@ -799,7 +799,12 @@ class ModelManagementHandler:
cache = await self._service.scanner.get_cached_data()
await cache.resort()
return web.json_response({"success": True})
from ...services.auto_tag_service import extract_auto_tags
auto_tags = extract_auto_tags(updated_metadata)
return web.json_response(
{"success": True, "auto_tags": auto_tags}
)
except Exception as exc:
self._logger.error("Error saving metadata: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
@@ -816,14 +821,16 @@ class ModelManagementHandler:
if not isinstance(new_tags, list):
return web.Response(text="Tags must be a list", status=400)
tags = await self._tag_update_service.add_tags(
tags, auto_tags = await self._tag_update_service.add_tags(
file_path=file_path,
new_tags=new_tags,
metadata_loader=self._metadata_sync.load_local_metadata,
update_cache=self._service.scanner.update_single_model_cache,
)
return web.json_response({"success": True, "tags": tags})
return web.json_response(
{"success": True, "tags": tags, "auto_tags": auto_tags}
)
except Exception as exc:
self._logger.error("Error adding tags: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)

View File

@@ -76,46 +76,64 @@ def _collect_sources(model_data: Dict) -> List[str]:
def extract_auto_tags(model_data: Dict) -> List[str]:
"""Extract auto-detected tags from model metadata.
Matches predefined patterns against filename, base_model, and
CivitAI version name. Returns a sorted, deduplicated list of tag labels.
Uses a two-layer approach:
Layer 1 — Regex-based detection against filename, base_model, and
CivitAI version name.
Layer 2 — Merge in any user-defined tags that overlap with known
auto-tag categories. This provides a manual fallback when
auto-detection fails (e.g. "I2V HN" or unlabeled models).
HIGH/LOW tags are only returned when the base_model indicates a Wan
family model — no other model architecture uses this distinction.
Args:
model_data: Model metadata dict with keys:
file_name, base_model, civitai (with optional 'name' field).
file_name, base_model, civitai (with optional 'name' field),
tags (user-defined tag list, used as fallback).
Returns:
Sorted list of unique auto-tag strings (e.g. ["I2V"]).
"""
sources = _collect_sources(model_data)
if not sources:
return []
base_model = model_data.get("base_model", "")
is_wan = "wan" in base_model.lower()
found: Set[str] = set()
for label, pattern in AUTO_TAG_CATEGORIES.items():
# HIGH/LOW are Wan-specific — skip for non-Wan to avoid noise
if label in ("HIGH", "LOW"):
if not is_wan:
continue
# Use case-insensitive character class + case-sensitive boundary,
# so "HighNoise" (camelCase) matches but "highlight" doesn't.
# Boundary: not followed by lowercase letter (= word has ended).
ci = "".join(f"[{c.lower()}{c.upper()}]" for c in label)
if label == "LOW":
regex = re.compile(r"(?<![Ff])" + ci + r"(?![a-z])")
# ── Layer 1: regex-based detection ────────────────────────────
if sources:
for label, pattern in AUTO_TAG_CATEGORIES.items():
# HIGH/LOW are Wan-specific — skip for non-Wan to avoid noise
if label in ("HIGH", "LOW"):
if not is_wan:
continue
# Use case-insensitive character class + case-sensitive boundary,
# so "HighNoise" (camelCase) matches but "highlight" doesn't.
# Boundary: not followed by lowercase letter (= word has ended).
ci = "".join(f"[{c.lower()}{c.upper()}]" for c in label)
if label == "LOW":
regex = re.compile(r"(?<![Ff])" + ci + r"(?![a-z])")
else:
regex = re.compile(ci + r"(?![a-z])")
else:
regex = re.compile(ci + r"(?![a-z])")
else:
regex = re.compile(pattern, re.IGNORECASE)
for source in sources:
if regex.search(source):
found.add(label)
break
regex = re.compile(pattern, re.IGNORECASE)
for source in sources:
if regex.search(source):
found.add(label)
break
# ── Layer 2: user-defined tags as manual fallback ─────────────
# When auto-detection fails (abbreviated names like "Hi"/"Lo",
# "I2V HN", or unlabeled models), users can add canonical tags
# (HIGH, LOW, I2V, etc.) to the model's regular tags for correct
# badge display and filtering. Matching is case-insensitive so
# "high"/"High"/"HIGH" all resolve to the canonical label.
user_tags = model_data.get("tags")
if user_tags:
label_map = {label.lower(): label for label in AUTO_TAG_CATEGORIES}
for t in user_tags:
canonical = label_map.get(t.lower())
if canonical:
found.add(canonical)
return sorted(found)

View File

@@ -4,7 +4,9 @@ from __future__ import annotations
import os
from typing import Awaitable, Callable, Dict, List, Sequence
from typing import Awaitable, Callable, Dict, List, Sequence, Tuple
from .auto_tag_service import extract_auto_tags
class TagUpdateService:
@@ -20,9 +22,8 @@ class TagUpdateService:
new_tags: Sequence[str],
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
update_cache: Callable[[str, str, Dict[str, object]], Awaitable[bool]],
) -> List[str]:
"""Add tags to a metadata entry while keeping case-insensitive uniqueness."""
) -> Tuple[List[str], List[str]]:
"""Add tags to a metadata entry and return updated tags and auto_tags."""
base, _ = os.path.splitext(file_path)
metadata_path = f"{base}.metadata.json"
metadata = await metadata_loader(metadata_path)
@@ -44,5 +45,6 @@ class TagUpdateService:
await self._metadata_manager.save_metadata(file_path, metadata)
await update_cache(file_path, file_path, metadata)
return existing_tags
auto_tags = extract_auto_tags(metadata)
return existing_tags, auto_tags

View File

@@ -422,8 +422,12 @@ export class BaseModelApiClient {
throw new Error('Failed to save metadata');
}
state.virtualScroller.updateSingleItem(filePath, data);
return response.json();
const result = await response.json();
state.virtualScroller.updateSingleItem(filePath, {
...data,
auto_tags: result.auto_tags,
});
return result;
} finally {
state.loadingManager.hide();
}
@@ -448,7 +452,10 @@ export class BaseModelApiClient {
const result = await response.json();
if (result.success && result.tags) {
state.virtualScroller.updateSingleItem(filePath, { tags: result.tags });
state.virtualScroller.updateSingleItem(filePath, {
tags: result.tags,
auto_tags: result.auto_tags,
});
}
return result;

View File

@@ -255,7 +255,7 @@ def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None:
cache_updates.append(metadata)
return True
tags = asyncio.run(
tags, auto_tags = asyncio.run(
service.add_tags(
file_path=str(tmp_path / "model.safetensors"),
new_tags=["new", "existing"],
@@ -265,5 +265,6 @@ def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None:
)
assert tags == ["existing", "new"]
assert auto_tags == []
assert manager.saved
assert cache_updates

View File

@@ -43,7 +43,7 @@ async def test_tag_update_service_handles_case_insensitive_tags(tmp_path: Path)
return True
# Try to add "Test" (different case) - should not be added since "test" already exists
tags = await service.add_tags(
tags, auto_tags = await service.add_tags(
file_path=str(tmp_path / "model.safetensors"),
new_tags=["Test"],
metadata_loader=loader,
@@ -52,6 +52,7 @@ async def test_tag_update_service_handles_case_insensitive_tags(tmp_path: Path)
# Should still only have "test" (lowercase) in the tags
assert tags == ["test"]
assert auto_tags == [] # no file_name/base_model in metadata, so no auto-detection
assert len(manager.saved) == 1
saved_metadata = manager.saved[0][1]
assert saved_metadata["tags"] == ["test"]
@@ -76,7 +77,7 @@ async def test_tag_update_service_adds_new_tags_in_lowercase(tmp_path: Path) ->
return True
# Add new tags with mixed case
tags = await service.add_tags(
tags, auto_tags = await service.add_tags(
file_path=str(tmp_path / "model.safetensors"),
new_tags=["NewTag", "ANOTHER_TAG"],
metadata_loader=loader,
@@ -87,6 +88,7 @@ async def test_tag_update_service_adds_new_tags_in_lowercase(tmp_path: Path) ->
assert "existing" in tags
assert "newtag" in tags
assert "another_tag" in tags
assert auto_tags == []
assert len(manager.saved) == 1
saved_metadata = manager.saved[0][1]
assert "newtag" in saved_metadata["tags"]

View File

@@ -126,6 +126,80 @@ class TestExtractAutoTags:
})
assert set(result) == {"HIGH", "I2V"}
# ── Layer 2: user-defined tags as manual fallback ───────────
def test_user_tags_fallback_when_detection_fails(self):
result = extract_auto_tags({
"file_name": "BOTH-v1.0",
"base_model": "Wan 2.2",
"civitai": {},
"tags": ["HIGH", "I2V", "T2V"],
})
assert set(result) == {"HIGH", "I2V", "T2V"}
def test_user_tags_augment_partial_detection(self):
result = extract_auto_tags({
"file_name": "wan_i2v_hn_v2",
"base_model": "Wan 2.2 I2V",
"civitai": {},
"tags": ["HIGH"],
})
assert set(result) == {"HIGH", "I2V"}
def test_user_tags_non_auto_tag_ignored(self):
result = extract_auto_tags({
"file_name": "model_v1",
"base_model": "Wan 2.2",
"civitai": {},
"tags": ["HIGH", "character", "style", "nsfw"],
})
assert set(result) == {"HIGH"}
def test_user_tags_overrides_non_wan_gate(self):
result = extract_auto_tags({
"file_name": "flux_model_v1",
"base_model": "Flux.1 D",
"civitai": {},
"tags": ["HIGH", "LOW", "Turbo"],
})
assert set(result) == {"HIGH", "LOW", "Turbo"}
def test_user_tags_no_duplication(self):
result = extract_auto_tags({
"file_name": "wan_i2v_high_v3",
"base_model": "Wan 2.2",
"civitai": {},
"tags": ["HIGH", "I2V"],
})
assert set(result) == {"HIGH", "I2V"}
def test_user_tags_lightning_turbo_manual(self):
result = extract_auto_tags({
"file_name": "sdxl_model_v1",
"base_model": "SDXL",
"civitai": {},
"tags": ["Lightning"],
})
assert set(result) == {"Lightning"}
def test_user_tags_case_insensitive_lowercase(self):
result = extract_auto_tags({
"file_name": "wan_masterpieces_v2",
"base_model": "Wan Video 14B t2v",
"civitai": {},
"tags": ["high"],
})
assert set(result) == {"HIGH", "T2V"}
def test_user_tags_case_insensitive_mixed(self):
result = extract_auto_tags({
"file_name": "model_v1",
"base_model": "SDXL",
"civitai": {},
"tags": ["lightning", "turbo", "i2v"],
})
assert set(result) == {"Lightning", "Turbo", "I2V"}
class TestAutoTagCategories:
def test_all_patterns_compile(self):