mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-06-10 04:49:24 -03:00
feat(ui): merge user tags into auto-tag badges and refresh on tag edit (#918)
- Layer 2 fallback: user tags overlapping with auto-tag categories (HIGH/LOW/I2V/T2V/TI2V/Lightning/Turbo) are merged into auto_tags, providing manual override when filename-based detection fails. Matching is case-insensitive so "high"/"High"/"HIGH" all work. - Refresh on tag edit: save_metadata and add_tags handlers now return recalculated auto_tags in the response; the frontend passes them to VirtualScroller.updateSingleItem so badges update immediately without requiring a page reload. - 8 new test cases for Layer 2 fallback and case-insensitive matching.
This commit is contained in:
@@ -788,7 +788,7 @@ class ModelManagementHandler:
|
|||||||
|
|
||||||
metadata_updates = {k: v for k, v in data.items() if k != "file_path"}
|
metadata_updates = {k: v for k, v in data.items() if k != "file_path"}
|
||||||
|
|
||||||
await self._metadata_sync.save_metadata_updates(
|
updated_metadata = await self._metadata_sync.save_metadata_updates(
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
updates=metadata_updates,
|
updates=metadata_updates,
|
||||||
metadata_loader=self._metadata_sync.load_local_metadata,
|
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||||
@@ -799,7 +799,12 @@ class ModelManagementHandler:
|
|||||||
cache = await self._service.scanner.get_cached_data()
|
cache = await self._service.scanner.get_cached_data()
|
||||||
await cache.resort()
|
await cache.resort()
|
||||||
|
|
||||||
return web.json_response({"success": True})
|
from ...services.auto_tag_service import extract_auto_tags
|
||||||
|
auto_tags = extract_auto_tags(updated_metadata)
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{"success": True, "auto_tags": auto_tags}
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._logger.error("Error saving metadata: %s", exc, exc_info=True)
|
self._logger.error("Error saving metadata: %s", exc, exc_info=True)
|
||||||
return web.Response(text=str(exc), status=500)
|
return web.Response(text=str(exc), status=500)
|
||||||
@@ -816,14 +821,16 @@ class ModelManagementHandler:
|
|||||||
if not isinstance(new_tags, list):
|
if not isinstance(new_tags, list):
|
||||||
return web.Response(text="Tags must be a list", status=400)
|
return web.Response(text="Tags must be a list", status=400)
|
||||||
|
|
||||||
tags = await self._tag_update_service.add_tags(
|
tags, auto_tags = await self._tag_update_service.add_tags(
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
new_tags=new_tags,
|
new_tags=new_tags,
|
||||||
metadata_loader=self._metadata_sync.load_local_metadata,
|
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||||
update_cache=self._service.scanner.update_single_model_cache,
|
update_cache=self._service.scanner.update_single_model_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
return web.json_response({"success": True, "tags": tags})
|
return web.json_response(
|
||||||
|
{"success": True, "tags": tags, "auto_tags": auto_tags}
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._logger.error("Error adding tags: %s", exc, exc_info=True)
|
self._logger.error("Error adding tags: %s", exc, exc_info=True)
|
||||||
return web.Response(text=str(exc), status=500)
|
return web.Response(text=str(exc), status=500)
|
||||||
|
|||||||
@@ -76,28 +76,32 @@ def _collect_sources(model_data: Dict) -> List[str]:
|
|||||||
def extract_auto_tags(model_data: Dict) -> List[str]:
|
def extract_auto_tags(model_data: Dict) -> List[str]:
|
||||||
"""Extract auto-detected tags from model metadata.
|
"""Extract auto-detected tags from model metadata.
|
||||||
|
|
||||||
Matches predefined patterns against filename, base_model, and
|
Uses a two-layer approach:
|
||||||
CivitAI version name. Returns a sorted, deduplicated list of tag labels.
|
Layer 1 — Regex-based detection against filename, base_model, and
|
||||||
|
CivitAI version name.
|
||||||
|
Layer 2 — Merge in any user-defined tags that overlap with known
|
||||||
|
auto-tag categories. This provides a manual fallback when
|
||||||
|
auto-detection fails (e.g. "I2V HN" or unlabeled models).
|
||||||
|
|
||||||
HIGH/LOW tags are only returned when the base_model indicates a Wan
|
HIGH/LOW tags are only returned when the base_model indicates a Wan
|
||||||
family model — no other model architecture uses this distinction.
|
family model — no other model architecture uses this distinction.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_data: Model metadata dict with keys:
|
model_data: Model metadata dict with keys:
|
||||||
file_name, base_model, civitai (with optional 'name' field).
|
file_name, base_model, civitai (with optional 'name' field),
|
||||||
|
tags (user-defined tag list, used as fallback).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Sorted list of unique auto-tag strings (e.g. ["I2V"]).
|
Sorted list of unique auto-tag strings (e.g. ["I2V"]).
|
||||||
"""
|
"""
|
||||||
sources = _collect_sources(model_data)
|
sources = _collect_sources(model_data)
|
||||||
if not sources:
|
|
||||||
return []
|
|
||||||
|
|
||||||
base_model = model_data.get("base_model", "")
|
base_model = model_data.get("base_model", "")
|
||||||
is_wan = "wan" in base_model.lower()
|
is_wan = "wan" in base_model.lower()
|
||||||
|
|
||||||
found: Set[str] = set()
|
found: Set[str] = set()
|
||||||
|
|
||||||
|
# ── Layer 1: regex-based detection ────────────────────────────
|
||||||
|
if sources:
|
||||||
for label, pattern in AUTO_TAG_CATEGORIES.items():
|
for label, pattern in AUTO_TAG_CATEGORIES.items():
|
||||||
# HIGH/LOW are Wan-specific — skip for non-Wan to avoid noise
|
# HIGH/LOW are Wan-specific — skip for non-Wan to avoid noise
|
||||||
if label in ("HIGH", "LOW"):
|
if label in ("HIGH", "LOW"):
|
||||||
@@ -118,4 +122,18 @@ def extract_auto_tags(model_data: Dict) -> List[str]:
|
|||||||
found.add(label)
|
found.add(label)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# ── Layer 2: user-defined tags as manual fallback ─────────────
|
||||||
|
# When auto-detection fails (abbreviated names like "Hi"/"Lo",
|
||||||
|
# "I2V HN", or unlabeled models), users can add canonical tags
|
||||||
|
# (HIGH, LOW, I2V, etc.) to the model's regular tags for correct
|
||||||
|
# badge display and filtering. Matching is case-insensitive so
|
||||||
|
# "high"/"High"/"HIGH" all resolve to the canonical label.
|
||||||
|
user_tags = model_data.get("tags")
|
||||||
|
if user_tags:
|
||||||
|
label_map = {label.lower(): label for label in AUTO_TAG_CATEGORIES}
|
||||||
|
for t in user_tags:
|
||||||
|
canonical = label_map.get(t.lower())
|
||||||
|
if canonical:
|
||||||
|
found.add(canonical)
|
||||||
|
|
||||||
return sorted(found)
|
return sorted(found)
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from typing import Awaitable, Callable, Dict, List, Sequence
|
from typing import Awaitable, Callable, Dict, List, Sequence, Tuple
|
||||||
|
|
||||||
|
from .auto_tag_service import extract_auto_tags
|
||||||
|
|
||||||
|
|
||||||
class TagUpdateService:
|
class TagUpdateService:
|
||||||
@@ -20,9 +22,8 @@ class TagUpdateService:
|
|||||||
new_tags: Sequence[str],
|
new_tags: Sequence[str],
|
||||||
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||||
update_cache: Callable[[str, str, Dict[str, object]], Awaitable[bool]],
|
update_cache: Callable[[str, str, Dict[str, object]], Awaitable[bool]],
|
||||||
) -> List[str]:
|
) -> Tuple[List[str], List[str]]:
|
||||||
"""Add tags to a metadata entry while keeping case-insensitive uniqueness."""
|
"""Add tags to a metadata entry and return updated tags and auto_tags."""
|
||||||
|
|
||||||
base, _ = os.path.splitext(file_path)
|
base, _ = os.path.splitext(file_path)
|
||||||
metadata_path = f"{base}.metadata.json"
|
metadata_path = f"{base}.metadata.json"
|
||||||
metadata = await metadata_loader(metadata_path)
|
metadata = await metadata_loader(metadata_path)
|
||||||
@@ -44,5 +45,6 @@ class TagUpdateService:
|
|||||||
await self._metadata_manager.save_metadata(file_path, metadata)
|
await self._metadata_manager.save_metadata(file_path, metadata)
|
||||||
await update_cache(file_path, file_path, metadata)
|
await update_cache(file_path, file_path, metadata)
|
||||||
|
|
||||||
return existing_tags
|
auto_tags = extract_auto_tags(metadata)
|
||||||
|
return existing_tags, auto_tags
|
||||||
|
|
||||||
|
|||||||
@@ -422,8 +422,12 @@ export class BaseModelApiClient {
|
|||||||
throw new Error('Failed to save metadata');
|
throw new Error('Failed to save metadata');
|
||||||
}
|
}
|
||||||
|
|
||||||
state.virtualScroller.updateSingleItem(filePath, data);
|
const result = await response.json();
|
||||||
return response.json();
|
state.virtualScroller.updateSingleItem(filePath, {
|
||||||
|
...data,
|
||||||
|
auto_tags: result.auto_tags,
|
||||||
|
});
|
||||||
|
return result;
|
||||||
} finally {
|
} finally {
|
||||||
state.loadingManager.hide();
|
state.loadingManager.hide();
|
||||||
}
|
}
|
||||||
@@ -448,7 +452,10 @@ export class BaseModelApiClient {
|
|||||||
const result = await response.json();
|
const result = await response.json();
|
||||||
|
|
||||||
if (result.success && result.tags) {
|
if (result.success && result.tags) {
|
||||||
state.virtualScroller.updateSingleItem(filePath, { tags: result.tags });
|
state.virtualScroller.updateSingleItem(filePath, {
|
||||||
|
tags: result.tags,
|
||||||
|
auto_tags: result.auto_tags,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
|||||||
@@ -255,7 +255,7 @@ def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None:
|
|||||||
cache_updates.append(metadata)
|
cache_updates.append(metadata)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
tags = asyncio.run(
|
tags, auto_tags = asyncio.run(
|
||||||
service.add_tags(
|
service.add_tags(
|
||||||
file_path=str(tmp_path / "model.safetensors"),
|
file_path=str(tmp_path / "model.safetensors"),
|
||||||
new_tags=["new", "existing"],
|
new_tags=["new", "existing"],
|
||||||
@@ -265,5 +265,6 @@ def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert tags == ["existing", "new"]
|
assert tags == ["existing", "new"]
|
||||||
|
assert auto_tags == []
|
||||||
assert manager.saved
|
assert manager.saved
|
||||||
assert cache_updates
|
assert cache_updates
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ async def test_tag_update_service_handles_case_insensitive_tags(tmp_path: Path)
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
# Try to add "Test" (different case) - should not be added since "test" already exists
|
# Try to add "Test" (different case) - should not be added since "test" already exists
|
||||||
tags = await service.add_tags(
|
tags, auto_tags = await service.add_tags(
|
||||||
file_path=str(tmp_path / "model.safetensors"),
|
file_path=str(tmp_path / "model.safetensors"),
|
||||||
new_tags=["Test"],
|
new_tags=["Test"],
|
||||||
metadata_loader=loader,
|
metadata_loader=loader,
|
||||||
@@ -52,6 +52,7 @@ async def test_tag_update_service_handles_case_insensitive_tags(tmp_path: Path)
|
|||||||
|
|
||||||
# Should still only have "test" (lowercase) in the tags
|
# Should still only have "test" (lowercase) in the tags
|
||||||
assert tags == ["test"]
|
assert tags == ["test"]
|
||||||
|
assert auto_tags == [] # no file_name/base_model in metadata, so no auto-detection
|
||||||
assert len(manager.saved) == 1
|
assert len(manager.saved) == 1
|
||||||
saved_metadata = manager.saved[0][1]
|
saved_metadata = manager.saved[0][1]
|
||||||
assert saved_metadata["tags"] == ["test"]
|
assert saved_metadata["tags"] == ["test"]
|
||||||
@@ -76,7 +77,7 @@ async def test_tag_update_service_adds_new_tags_in_lowercase(tmp_path: Path) ->
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
# Add new tags with mixed case
|
# Add new tags with mixed case
|
||||||
tags = await service.add_tags(
|
tags, auto_tags = await service.add_tags(
|
||||||
file_path=str(tmp_path / "model.safetensors"),
|
file_path=str(tmp_path / "model.safetensors"),
|
||||||
new_tags=["NewTag", "ANOTHER_TAG"],
|
new_tags=["NewTag", "ANOTHER_TAG"],
|
||||||
metadata_loader=loader,
|
metadata_loader=loader,
|
||||||
@@ -87,6 +88,7 @@ async def test_tag_update_service_adds_new_tags_in_lowercase(tmp_path: Path) ->
|
|||||||
assert "existing" in tags
|
assert "existing" in tags
|
||||||
assert "newtag" in tags
|
assert "newtag" in tags
|
||||||
assert "another_tag" in tags
|
assert "another_tag" in tags
|
||||||
|
assert auto_tags == []
|
||||||
assert len(manager.saved) == 1
|
assert len(manager.saved) == 1
|
||||||
saved_metadata = manager.saved[0][1]
|
saved_metadata = manager.saved[0][1]
|
||||||
assert "newtag" in saved_metadata["tags"]
|
assert "newtag" in saved_metadata["tags"]
|
||||||
|
|||||||
@@ -126,6 +126,80 @@ class TestExtractAutoTags:
|
|||||||
})
|
})
|
||||||
assert set(result) == {"HIGH", "I2V"}
|
assert set(result) == {"HIGH", "I2V"}
|
||||||
|
|
||||||
|
# ── Layer 2: user-defined tags as manual fallback ───────────
|
||||||
|
|
||||||
|
def test_user_tags_fallback_when_detection_fails(self):
|
||||||
|
result = extract_auto_tags({
|
||||||
|
"file_name": "BOTH-v1.0",
|
||||||
|
"base_model": "Wan 2.2",
|
||||||
|
"civitai": {},
|
||||||
|
"tags": ["HIGH", "I2V", "T2V"],
|
||||||
|
})
|
||||||
|
assert set(result) == {"HIGH", "I2V", "T2V"}
|
||||||
|
|
||||||
|
def test_user_tags_augment_partial_detection(self):
|
||||||
|
result = extract_auto_tags({
|
||||||
|
"file_name": "wan_i2v_hn_v2",
|
||||||
|
"base_model": "Wan 2.2 I2V",
|
||||||
|
"civitai": {},
|
||||||
|
"tags": ["HIGH"],
|
||||||
|
})
|
||||||
|
assert set(result) == {"HIGH", "I2V"}
|
||||||
|
|
||||||
|
def test_user_tags_non_auto_tag_ignored(self):
|
||||||
|
result = extract_auto_tags({
|
||||||
|
"file_name": "model_v1",
|
||||||
|
"base_model": "Wan 2.2",
|
||||||
|
"civitai": {},
|
||||||
|
"tags": ["HIGH", "character", "style", "nsfw"],
|
||||||
|
})
|
||||||
|
assert set(result) == {"HIGH"}
|
||||||
|
|
||||||
|
def test_user_tags_overrides_non_wan_gate(self):
|
||||||
|
result = extract_auto_tags({
|
||||||
|
"file_name": "flux_model_v1",
|
||||||
|
"base_model": "Flux.1 D",
|
||||||
|
"civitai": {},
|
||||||
|
"tags": ["HIGH", "LOW", "Turbo"],
|
||||||
|
})
|
||||||
|
assert set(result) == {"HIGH", "LOW", "Turbo"}
|
||||||
|
|
||||||
|
def test_user_tags_no_duplication(self):
|
||||||
|
result = extract_auto_tags({
|
||||||
|
"file_name": "wan_i2v_high_v3",
|
||||||
|
"base_model": "Wan 2.2",
|
||||||
|
"civitai": {},
|
||||||
|
"tags": ["HIGH", "I2V"],
|
||||||
|
})
|
||||||
|
assert set(result) == {"HIGH", "I2V"}
|
||||||
|
|
||||||
|
def test_user_tags_lightning_turbo_manual(self):
|
||||||
|
result = extract_auto_tags({
|
||||||
|
"file_name": "sdxl_model_v1",
|
||||||
|
"base_model": "SDXL",
|
||||||
|
"civitai": {},
|
||||||
|
"tags": ["Lightning"],
|
||||||
|
})
|
||||||
|
assert set(result) == {"Lightning"}
|
||||||
|
|
||||||
|
def test_user_tags_case_insensitive_lowercase(self):
|
||||||
|
result = extract_auto_tags({
|
||||||
|
"file_name": "wan_masterpieces_v2",
|
||||||
|
"base_model": "Wan Video 14B t2v",
|
||||||
|
"civitai": {},
|
||||||
|
"tags": ["high"],
|
||||||
|
})
|
||||||
|
assert set(result) == {"HIGH", "T2V"}
|
||||||
|
|
||||||
|
def test_user_tags_case_insensitive_mixed(self):
|
||||||
|
result = extract_auto_tags({
|
||||||
|
"file_name": "model_v1",
|
||||||
|
"base_model": "SDXL",
|
||||||
|
"civitai": {},
|
||||||
|
"tags": ["lightning", "turbo", "i2v"],
|
||||||
|
})
|
||||||
|
assert set(result) == {"Lightning", "Turbo", "I2V"}
|
||||||
|
|
||||||
|
|
||||||
class TestAutoTagCategories:
|
class TestAutoTagCategories:
|
||||||
def test_all_patterns_compile(self):
|
def test_all_patterns_compile(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user