diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 6e931b2d..5438dd29 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -788,7 +788,7 @@ class ModelManagementHandler: metadata_updates = {k: v for k, v in data.items() if k != "file_path"} - await self._metadata_sync.save_metadata_updates( + updated_metadata = await self._metadata_sync.save_metadata_updates( file_path=file_path, updates=metadata_updates, metadata_loader=self._metadata_sync.load_local_metadata, @@ -799,7 +799,12 @@ class ModelManagementHandler: cache = await self._service.scanner.get_cached_data() await cache.resort() - return web.json_response({"success": True}) + from ...services.auto_tag_service import extract_auto_tags + auto_tags = extract_auto_tags(updated_metadata) + + return web.json_response( + {"success": True, "auto_tags": auto_tags} + ) except Exception as exc: self._logger.error("Error saving metadata: %s", exc, exc_info=True) return web.Response(text=str(exc), status=500) @@ -816,14 +821,16 @@ class ModelManagementHandler: if not isinstance(new_tags, list): return web.Response(text="Tags must be a list", status=400) - tags = await self._tag_update_service.add_tags( + tags, auto_tags = await self._tag_update_service.add_tags( file_path=file_path, new_tags=new_tags, metadata_loader=self._metadata_sync.load_local_metadata, update_cache=self._service.scanner.update_single_model_cache, ) - return web.json_response({"success": True, "tags": tags}) + return web.json_response( + {"success": True, "tags": tags, "auto_tags": auto_tags} + ) except Exception as exc: self._logger.error("Error adding tags: %s", exc, exc_info=True) return web.Response(text=str(exc), status=500) diff --git a/py/services/auto_tag_service.py b/py/services/auto_tag_service.py index 545cf52e..89c0f966 100644 --- a/py/services/auto_tag_service.py +++ b/py/services/auto_tag_service.py @@ -76,46 +76,64 @@ def _collect_sources(model_data: Dict) -> List[str]: def extract_auto_tags(model_data: Dict) -> List[str]: """Extract auto-detected tags from model metadata. - Matches predefined patterns against filename, base_model, and - CivitAI version name. Returns a sorted, deduplicated list of tag labels. + Uses a two-layer approach: + Layer 1 — Regex-based detection against filename, base_model, and + CivitAI version name. + Layer 2 — Merge in any user-defined tags that overlap with known + auto-tag categories. This provides a manual fallback when + auto-detection fails (e.g. "I2V HN" or unlabeled models). HIGH/LOW tags are only returned when the base_model indicates a Wan family model — no other model architecture uses this distinction. Args: model_data: Model metadata dict with keys: - file_name, base_model, civitai (with optional 'name' field). + file_name, base_model, civitai (with optional 'name' field), + tags (user-defined tag list, used as fallback). Returns: Sorted list of unique auto-tag strings (e.g. ["I2V"]). """ sources = _collect_sources(model_data) - if not sources: - return [] - base_model = model_data.get("base_model", "") is_wan = "wan" in base_model.lower() found: Set[str] = set() - for label, pattern in AUTO_TAG_CATEGORIES.items(): - # HIGH/LOW are Wan-specific — skip for non-Wan to avoid noise - if label in ("HIGH", "LOW"): - if not is_wan: - continue - # Use case-insensitive character class + case-sensitive boundary, - # so "HighNoise" (camelCase) matches but "highlight" doesn't. - # Boundary: not followed by lowercase letter (= word has ended). - ci = "".join(f"[{c.lower()}{c.upper()}]" for c in label) - if label == "LOW": - regex = re.compile(r"(? List[str]: - """Add tags to a metadata entry while keeping case-insensitive uniqueness.""" - + ) -> Tuple[List[str], List[str]]: + """Add tags to a metadata entry and return updated tags and auto_tags.""" base, _ = os.path.splitext(file_path) metadata_path = f"{base}.metadata.json" metadata = await metadata_loader(metadata_path) @@ -44,5 +45,6 @@ class TagUpdateService: await self._metadata_manager.save_metadata(file_path, metadata) await update_cache(file_path, file_path, metadata) - return existing_tags + auto_tags = extract_auto_tags(metadata) + return existing_tags, auto_tags diff --git a/static/js/api/baseModelApi.js b/static/js/api/baseModelApi.js index bc5ec074..750c2f21 100644 --- a/static/js/api/baseModelApi.js +++ b/static/js/api/baseModelApi.js @@ -422,8 +422,12 @@ export class BaseModelApiClient { throw new Error('Failed to save metadata'); } - state.virtualScroller.updateSingleItem(filePath, data); - return response.json(); + const result = await response.json(); + state.virtualScroller.updateSingleItem(filePath, { + ...data, + auto_tags: result.auto_tags, + }); + return result; } finally { state.loadingManager.hide(); } @@ -448,7 +452,10 @@ export class BaseModelApiClient { const result = await response.json(); if (result.success && result.tags) { - state.virtualScroller.updateSingleItem(filePath, { tags: result.tags }); + state.virtualScroller.updateSingleItem(filePath, { + tags: result.tags, + auto_tags: result.auto_tags, + }); } return result; diff --git a/tests/services/test_route_support_services.py b/tests/services/test_route_support_services.py index 9f95b3f3..39929832 100644 --- a/tests/services/test_route_support_services.py +++ b/tests/services/test_route_support_services.py @@ -255,7 +255,7 @@ def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None: cache_updates.append(metadata) return True - tags = asyncio.run( + tags, auto_tags = asyncio.run( service.add_tags( file_path=str(tmp_path / "model.safetensors"), new_tags=["new", "existing"], @@ -265,5 +265,6 @@ def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None: ) assert tags == ["existing", "new"] + assert auto_tags == [] assert manager.saved assert cache_updates diff --git a/tests/services/test_tag_case_sensitivity.py b/tests/services/test_tag_case_sensitivity.py index 3396d6b3..6f021589 100644 --- a/tests/services/test_tag_case_sensitivity.py +++ b/tests/services/test_tag_case_sensitivity.py @@ -43,7 +43,7 @@ async def test_tag_update_service_handles_case_insensitive_tags(tmp_path: Path) return True # Try to add "Test" (different case) - should not be added since "test" already exists - tags = await service.add_tags( + tags, auto_tags = await service.add_tags( file_path=str(tmp_path / "model.safetensors"), new_tags=["Test"], metadata_loader=loader, @@ -52,6 +52,7 @@ async def test_tag_update_service_handles_case_insensitive_tags(tmp_path: Path) # Should still only have "test" (lowercase) in the tags assert tags == ["test"] + assert auto_tags == [] # no file_name/base_model in metadata, so no auto-detection assert len(manager.saved) == 1 saved_metadata = manager.saved[0][1] assert saved_metadata["tags"] == ["test"] @@ -76,7 +77,7 @@ async def test_tag_update_service_adds_new_tags_in_lowercase(tmp_path: Path) -> return True # Add new tags with mixed case - tags = await service.add_tags( + tags, auto_tags = await service.add_tags( file_path=str(tmp_path / "model.safetensors"), new_tags=["NewTag", "ANOTHER_TAG"], metadata_loader=loader, @@ -87,6 +88,7 @@ async def test_tag_update_service_adds_new_tags_in_lowercase(tmp_path: Path) -> assert "existing" in tags assert "newtag" in tags assert "another_tag" in tags + assert auto_tags == [] assert len(manager.saved) == 1 saved_metadata = manager.saved[0][1] assert "newtag" in saved_metadata["tags"] diff --git a/tests/test_auto_tag_service.py b/tests/test_auto_tag_service.py index e7c3fb45..84df141d 100644 --- a/tests/test_auto_tag_service.py +++ b/tests/test_auto_tag_service.py @@ -126,6 +126,80 @@ class TestExtractAutoTags: }) assert set(result) == {"HIGH", "I2V"} + # ── Layer 2: user-defined tags as manual fallback ─────────── + + def test_user_tags_fallback_when_detection_fails(self): + result = extract_auto_tags({ + "file_name": "BOTH-v1.0", + "base_model": "Wan 2.2", + "civitai": {}, + "tags": ["HIGH", "I2V", "T2V"], + }) + assert set(result) == {"HIGH", "I2V", "T2V"} + + def test_user_tags_augment_partial_detection(self): + result = extract_auto_tags({ + "file_name": "wan_i2v_hn_v2", + "base_model": "Wan 2.2 I2V", + "civitai": {}, + "tags": ["HIGH"], + }) + assert set(result) == {"HIGH", "I2V"} + + def test_user_tags_non_auto_tag_ignored(self): + result = extract_auto_tags({ + "file_name": "model_v1", + "base_model": "Wan 2.2", + "civitai": {}, + "tags": ["HIGH", "character", "style", "nsfw"], + }) + assert set(result) == {"HIGH"} + + def test_user_tags_overrides_non_wan_gate(self): + result = extract_auto_tags({ + "file_name": "flux_model_v1", + "base_model": "Flux.1 D", + "civitai": {}, + "tags": ["HIGH", "LOW", "Turbo"], + }) + assert set(result) == {"HIGH", "LOW", "Turbo"} + + def test_user_tags_no_duplication(self): + result = extract_auto_tags({ + "file_name": "wan_i2v_high_v3", + "base_model": "Wan 2.2", + "civitai": {}, + "tags": ["HIGH", "I2V"], + }) + assert set(result) == {"HIGH", "I2V"} + + def test_user_tags_lightning_turbo_manual(self): + result = extract_auto_tags({ + "file_name": "sdxl_model_v1", + "base_model": "SDXL", + "civitai": {}, + "tags": ["Lightning"], + }) + assert set(result) == {"Lightning"} + + def test_user_tags_case_insensitive_lowercase(self): + result = extract_auto_tags({ + "file_name": "wan_masterpieces_v2", + "base_model": "Wan Video 14B t2v", + "civitai": {}, + "tags": ["high"], + }) + assert set(result) == {"HIGH", "T2V"} + + def test_user_tags_case_insensitive_mixed(self): + result = extract_auto_tags({ + "file_name": "model_v1", + "base_model": "SDXL", + "civitai": {}, + "tags": ["lightning", "turbo", "i2v"], + }) + assert set(result) == {"Lightning", "Turbo", "I2V"} + class TestAutoTagCategories: def test_all_patterns_compile(self):