feat(agent): improve enrich_hf_metadata skill with priority_tags, preview_url fix, civitai.trainedWords

- Add identify_model_type() helper to determine lora/checkpoint/embedding
- Pass priority_tags from user settings to LLM prompt for tag relevance
- SKILL.md: instruct LLM to exclude technical/generic HF tags, cross-reference
  against priority_tags; forbid ['None'] placeholder for trigger words
- post_processor: fix preview_url not updated after download (now writes local
  .webp path to metadata); write trigger words to civitai.trainedWords instead
  of top-level; sanitize ['None']/'null'/'n/a' placeholder values to []
- download_preview() now returns str | None (local path) instead of bool
- Update tests for new return type and nested civitai.trainedWords structure
This commit is contained in:
Will Miao
2026-07-02 22:14:44 +08:00
parent 63785f82b5
commit a8adcaf023
6 changed files with 121 additions and 44 deletions

View File

@@ -32,6 +32,15 @@ logger = logging.getLogger(__name__)
# Helpers # Helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Maps each ServiceRegistry scanner-getter name to the model type it serves.
SCANNER_TYPE_MAP: dict[str, str] = dict(
    get_lora_scanner="lora",
    get_checkpoint_scanner="checkpoint",
    get_embedding_scanner="embedding",
)

# Getter names in probing order (follows the mapping's insertion order).
SCANNER_GETTER_NAMES = tuple(SCANNER_TYPE_MAP)
async def _find_scanner_for_model( async def _find_scanner_for_model(
model_path: str, model_path: str,
) -> tuple[object, object] | tuple[None, None]: ) -> tuple[object, object] | tuple[None, None]:
@@ -44,11 +53,7 @@ async def _find_scanner_for_model(
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
normalized = os.path.normpath(model_path) normalized = os.path.normpath(model_path)
for getter_name in ( for getter_name in SCANNER_GETTER_NAMES:
"get_lora_scanner",
"get_checkpoint_scanner",
"get_embedding_scanner",
):
getter = getattr(ServiceRegistry, getter_name, None) getter = getattr(ServiceRegistry, getter_name, None)
if getter is None: if getter is None:
continue continue
@@ -70,6 +75,38 @@ async def _find_scanner_for_model(
return None, None return None, None
async def identify_model_type(model_path: str) -> str:
    """Return the model type for *model_path*.

    Each scanner getter named in ``SCANNER_GETTER_NAMES`` is looked up on
    ``ServiceRegistry`` and queried in order.  The first scanner whose cached
    entries contain *model_path* (paths are compared after
    ``os.path.normpath``) decides the result via ``SCANNER_TYPE_MAP``.

    Args:
        model_path: Path of the model file to classify.

    Returns:
        The mapped type (``"lora"``, ``"checkpoint"`` or ``"embedding"``).
        Falls back to ``"lora"`` when the path is empty or no scanner
        claims it.

    Errors raised while querying a scanner are logged at debug level and the
    next scanner is tried.
    """
    from ..services.service_registry import ServiceRegistry

    # An empty path cannot identify a model.  Bail out before normalising:
    # os.path.normpath("") evaluates to ".", which would otherwise compare
    # equal to any cache entry that lacks a "file_path".
    if not model_path:
        return "lora"

    normalized = os.path.normpath(model_path)

    for getter_name in SCANNER_GETTER_NAMES:
        getter = getattr(ServiceRegistry, getter_name, None)
        if getter is None:
            continue

        try:
            scanner = await getter()
            if scanner is None:
                continue

            cache = await scanner.get_cached_data()
            for entry in cache.raw_data:
                entry_path = entry.get("file_path")
                # Skip entries without a usable path instead of letting them
                # normalise to "." and produce a false match.
                if not entry_path:
                    continue
                if os.path.normpath(entry_path) == normalized:
                    return SCANNER_TYPE_MAP[getter_name]
        except Exception as exc:
            # Best effort: a failing scanner must not prevent the remaining
            # scanners from being consulted.
            logger.debug(
                "identify_model_type scanner %s error for %s: %s",
                getter_name,
                model_path,
                exc,
            )

    return "lora"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Public API # Public API
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -153,17 +190,17 @@ async def download_preview(
*, *,
target_width: int = 480, target_width: int = 480,
quality: int = 85, quality: int = 85,
) -> bool: ) -> str | None:
"""Download a preview image from *url*, optimise to .webp, and save it. """Download a preview image from *url*, optimise to .webp, and save it.
The output file is placed alongside the model file with a ``.webp`` The output file is placed alongside the model file with a ``.webp``
extension. Returns ``True`` on success. extension. Returns the local file path on success, ``None`` on failure.
""" """
from ..services.downloader import get_downloader from ..services.downloader import get_downloader
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
if not url or not url.strip(): if not url or not url.strip():
return False return None
base_name = os.path.splitext(os.path.basename(model_path))[0] base_name = os.path.splitext(os.path.basename(model_path))[0]
preview_dir = os.path.dirname(model_path) preview_dir = os.path.dirname(model_path)
@@ -187,7 +224,7 @@ async def download_preview(
with open(output_path, "wb") as f: with open(output_path, "wb") as f:
f.write(optimized_data) f.write(optimized_data)
logger.info("Preview downloaded and optimised for %s", model_path) logger.info("Preview downloaded and optimised for %s", model_path)
return True return output_path
except Exception as exc: except Exception as exc:
logger.warning("Preview optimisation failed, saving raw: %s", exc) logger.warning("Preview optimisation failed, saving raw: %s", exc)
# Fall through to raw save # Fall through to raw save
@@ -197,11 +234,11 @@ async def download_preview(
ok, _ = await downloader.download_file(url, output_path, use_auth=False) ok, _ = await downloader.download_file(url, output_path, use_auth=False)
if ok: if ok:
logger.info("Preview downloaded (fallback) for %s", model_path) logger.info("Preview downloaded (fallback) for %s", model_path)
return True return output_path
except Exception as exc: except Exception as exc:
logger.warning("Preview fallback download failed for %s: %s", model_path, exc) logger.warning("Preview fallback download failed for %s: %s", model_path, exc)
return False return None
async def refresh_cache(model_path: str) -> bool: async def refresh_cache(model_path: str) -> bool:

View File

@@ -334,10 +334,11 @@ class AgentService:
"""Gather variables for the skill's prompt template. """Gather variables for the skill's prompt template.
Reads metadata, fetches the HF README (if applicable), lists available Reads metadata, fetches the HF README (if applicable), lists available
base models, and returns a dict that maps to ``{{variable}}`` base models, loads user priority tags, and returns a dict that maps to
placeholders in ``prompt.md``. ``{{variable}}`` placeholders in ``prompt.md``.
""" """
from ...agent_cli import list_base_models from ...agent_cli import identify_model_type, list_base_models
from ..settings_manager import SettingsManager
context: Dict[str, Any] = { context: Dict[str, Any] = {
"model_path": model_path, "model_path": model_path,
@@ -346,6 +347,7 @@ class AgentService:
"readme_content": "", "readme_content": "",
"current_metadata": {}, "current_metadata": {},
"base_models": [], "base_models": [],
"priority_tags": "",
} }
context["current_metadata"] = { context["current_metadata"] = {
@@ -371,6 +373,18 @@ class AgentService:
except Exception as exc: except Exception as exc:
logger.debug("Failed to list base models: %s", exc) logger.debug("Failed to list base models: %s", exc)
# Determine model type and load the corresponding priority_tags
try:
model_type = await identify_model_type(model_path)
context["model_type"] = model_type
settings = SettingsManager()
priority_config = settings.get_priority_tag_config()
context["priority_tags"] = priority_config.get(model_type, "")
except Exception as exc:
logger.debug("Failed to load priority tags: %s", exc)
context["model_type"] = "lora"
context["priority_tags"] = ""
return context return context
@staticmethod @staticmethod

View File

@@ -86,14 +86,16 @@ class PostProcessor:
if new_base and self._should_overwrite(current_base, is_hf_model): if new_base and self._should_overwrite(current_base, is_hf_model):
updates["base_model"] = new_base updates["base_model"] = new_base
# trainedWords / trigger words
new_triggers = llm_output.get("trigger_words", []) new_triggers = llm_output.get("trigger_words", [])
if isinstance(new_triggers, list): if isinstance(new_triggers, list):
cleaned = [t.strip() for t in new_triggers if t.strip()] cleaned = [t.strip() for t in new_triggers if t.strip()]
if cleaned: cleaned = [t for t in cleaned if t.lower() not in ("none", "null", "n/a")]
current_triggers = metadata.get("trainedWords") or [] current_civitai = metadata.get("civitai") or {}
if self._should_overwrite_list(current_triggers, is_hf_model): current_triggers = current_civitai.get("trainedWords") or []
updates["trainedWords"] = cleaned if self._should_overwrite_list(current_triggers, is_hf_model):
civitai_updates = dict(current_civitai)
civitai_updates["trainedWords"] = cleaned
updates["civitai"] = civitai_updates
# modelDescription # modelDescription
new_desc = (llm_output.get("description") or "").strip() new_desc = (llm_output.get("description") or "").strip()
@@ -102,7 +104,7 @@ class PostProcessor:
if self._should_overwrite(current_desc, is_hf_model): if self._should_overwrite(current_desc, is_hf_model):
updates["modelDescription"] = new_desc updates["modelDescription"] = new_desc
# tags — merge with existing, deduplicate (case-insensitive) # tags
new_tags = llm_output.get("tags", []) new_tags = llm_output.get("tags", [])
if isinstance(new_tags, list) and new_tags: if isinstance(new_tags, list) and new_tags:
existing_tags = metadata.get("tags") or [] existing_tags = metadata.get("tags") or []
@@ -114,16 +116,17 @@ class PostProcessor:
updates["metadata_source"] = "agent:enrich_hf_metadata" updates["metadata_source"] = "agent:enrich_hf_metadata"
updates["llm_enriched_at"] = datetime.now(timezone.utc).isoformat() updates["llm_enriched_at"] = datetime.now(timezone.utc).isoformat()
# -- Persist updates ------------------------------------------------ preview_remote_url = (llm_output.get("preview_url") or "").strip()
current_preview = metadata.get("preview_url") or ""
if preview_remote_url and not (current_preview and os.path.exists(current_preview)):
local_path = await download_preview(model_path, preview_remote_url)
if local_path:
preview_downloaded = True
updates["preview_url"] = local_path
if updates: if updates:
updated_fields = await apply_metadata_updates(model_path, updates) updated_fields = await apply_metadata_updates(model_path, updates)
# -- Download preview -----------------------------------------------
preview_url = (llm_output.get("preview_url") or "").strip()
current_preview = metadata.get("preview_url") or ""
if preview_url and not (current_preview and os.path.exists(current_preview)):
preview_downloaded = await download_preview(model_path, preview_url)
# -- Refresh scanner cache ------------------------------------------ # -- Refresh scanner cache ------------------------------------------
if updated_fields or preview_downloaded: if updated_fields or preview_downloaded:
await refresh_cache(model_path) await refresh_cache(model_path)

View File

@@ -21,6 +21,16 @@ You are an expert assistant for AI image generation models. Your task is to extr
{{current_metadata}} {{current_metadata}}
``` ```
## User Priority Tags Reference
The user has configured the following list of **meaningful tag categories** for this model type (`{{model_type}}`):
```
{{priority_tags}}
```
These are the subjects, styles, and concepts the user considers useful for categorization. Use this list as a **reference** when evaluating tags (see the **tags** section below).
## Available Base Models ## Available Base Models
The following base models are currently valid in this system: The following base models are currently valid in this system:
@@ -37,7 +47,7 @@ The following base models are currently valid in this system:
Extract the following information from the README content above: Extract the following information from the README content above:
### base_model ### base_model
The base model this LoRA/checkpoint was trained on. Use EXACTLY one of the names from the **Available Base Models** list above. Do not invent new names or use aliases. The base model this model was trained on. Use EXACTLY one of the names from the **Available Base Models** list above. Do not invent new names or use aliases.
Check the YAML frontmatter (between --- markers) for `base_model:` first, then look at the description text and safetensors metadata. If you cannot determine it, return an empty string. Check the YAML frontmatter (between --- markers) for `base_model:` first, then look at the description text and safetensors metadata. If you cannot determine it, return an empty string.
@@ -46,17 +56,27 @@ The trigger words or activation prompts needed to use this LoRA. Look for:
- `instance_prompt:` in the YAML frontmatter - `instance_prompt:` in the YAML frontmatter
- Phrases like "trigger word:", "trigger:", "use this prompt:", "activation prompt:" - Phrases like "trigger word:", "trigger:", "use this prompt:", "activation prompt:"
- Example prompts at the start (usually the first word or phrase before any description) - Example prompts at the start (usually the first word or phrase before any description)
Return as an array of strings. If none found, return an empty array. Return as an array of strings. If none found, return an empty array `[]`. **Never** return `["None"]` or any placeholder value — a truly empty list means no trigger words exist.
### description ### description
A concise 1-2 sentence summary of what this model does. Extract from the "Model description" section or the first paragraph. Return empty string if the README is too minimal. A concise 1-2 sentence summary of what this model does. Extract from the "Model description" section or the first paragraph. Return empty string if the README is too minimal.
### tags ### tags
3-8 relevant tags for categorizing this model. Extract from: 3-8 relevant tags for categorizing this model. **Quality over quantity.**
- The YAML frontmatter `tags:` list (often contains excellent categorization tags)
- The model type (e.g. "lora", "checkpoint", "flux", "sdxl") Sources to consider:
- The style/subject (e.g. "anime", "photorealistic", "style", "character") - The YAML frontmatter `tags:` list
All lowercase, no spaces. Return empty array if none found. - The subject, style, character, or concept the model represents
**Critical filtering rules — apply them strictly:**
1. **Exclude technical/generic tags.** Reject any tag that describes the model's **training methodology, framework, architecture, or modality** rather than its content. Examples to exclude: `text-to-image`, `diffusers`, `lora`, `dreambooth`, `diffusers-training`, `flux`, `sdxl`, `checkpoint`, `pytorch`, `safetensors`, `fine-tuning`, `stable-diffusion`, and any variant of these.
2. **Cross-reference against the priority_tags reference.** Only include a tag if it meaningfully describes what the model actually creates (subject, style, character type) and is semantically close to one of the priority_tags. If the priority_tags reference above is empty, skip this cross-reference and rely on rule 1 alone. If none of the README's tags match meaningful categories, prefer returning a smaller set or an empty array over including low-value tags.
3. **All lowercase, no spaces, no hyphens** (use single words like `"photorealistic"`, `"anime"`, `"character"`).
Return empty array if no meaningful content tags remain after filtering.
### preview_url ### preview_url
The URL of the most suitable preview image from the README. Look for image tags (e.g. `![alt](url)`) and the YAML frontmatter `widget:` section (which often has `output.url` fields). Choose the first image that appears to be a generation example (not a logo or diagram). Construct the absolute URL as `https://huggingface.co/{{repo}}/resolve/main/{filename}`. If no suitable image is found, return an empty string. The URL of the most suitable preview image from the README. Look for image tags (e.g. `![alt](url)`) and the YAML frontmatter `widget:` section (which often has `output.url` fields). Choose the first image that appears to be a generation example (not a logo or diagram). Construct the absolute URL as `https://huggingface.co/{{repo}}/resolve/main/{filename}`. If no suitable image is found, return an empty string.
@@ -87,3 +107,4 @@ Important:
- Only include the JSON object, no other text - Only include the JSON object, no other text
- If a field cannot be determined, use an empty string or empty array - If a field cannot be determined, use an empty string or empty array
- Do not fabricate information not supported by the README - Do not fabricate information not supported by the README
- Never use placeholder values like `"None"` or `"unknown"` for missing data — use an empty string or empty array

View File

@@ -227,11 +227,11 @@ class TestApplyMetadataUpdates:
class TestDownloadPreview: class TestDownloadPreview:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_empty_url_returns_false(self, tmp_path): async def test_empty_url_returns_none(self, tmp_path):
mp = tmp_path / "m.safetensors" mp = tmp_path / "m.safetensors"
mp.write_bytes(b"fake") mp.write_bytes(b"fake")
assert await download_preview(str(mp), "") is False assert await download_preview(str(mp), "") is None
assert await download_preview(str(mp), " ") is False assert await download_preview(str(mp), " ") is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_successful_download_and_optimise(self, tmp_path): async def test_successful_download_and_optimise(self, tmp_path):
@@ -246,12 +246,12 @@ class TestDownloadPreview:
get_dl.return_value = dl get_dl.return_value = dl
exif.optimize_image.return_value = (b"optimized_webp", {}) exif.optimize_image.return_value = (b"optimized_webp", {})
result = await download_preview(str(mp), "https://ex.com/i.png") result = await download_preview(str(mp), "https://ex.com/i.png")
assert result is True assert result == str(tmp_path / "t.webp")
assert (tmp_path / "t.webp").exists() assert (tmp_path / "t.webp").exists()
assert (tmp_path / "t.webp").read_bytes() == b"optimized_webp" assert (tmp_path / "t.webp").read_bytes() == b"optimized_webp"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_failure_returns_false(self, tmp_path): async def test_download_failure_returns_none(self, tmp_path):
mp = tmp_path / "t.safetensors" mp = tmp_path / "t.safetensors"
mp.write_bytes(b"fake") mp.write_bytes(b"fake")
with mock.patch("py.services.downloader.get_downloader") as get_dl: with mock.patch("py.services.downloader.get_downloader") as get_dl:
@@ -260,7 +260,7 @@ class TestDownloadPreview:
dl.download_file = mock.AsyncMock(return_value=(False, None)) dl.download_file = mock.AsyncMock(return_value=(False, None))
get_dl.return_value = dl get_dl.return_value = dl
result = await download_preview(str(mp), "https://ex.com/i.png") result = await download_preview(str(mp), "https://ex.com/i.png")
assert result is False assert result is None
assert not (tmp_path / "t.webp").exists() assert not (tmp_path / "t.webp").exists()

View File

@@ -44,7 +44,7 @@ class TestProcessDispatch:
mock.patch("py.agent_cli.refresh_cache") as mock_ref, mock.patch("py.agent_cli.refresh_cache") as mock_ref,
): ):
mock_apply.return_value = ["metadata_source"] mock_apply.return_value = ["metadata_source"]
mock_dl.return_value = False mock_dl.return_value = None
result = await processor.process( result = await processor.process(
skill_name="enrich_hf_metadata", skill_name="enrich_hf_metadata",
@@ -155,7 +155,7 @@ class TestEnrichHfMetadata:
llm = {**self.MIN_LLM_OUTPUT, "trigger_words": ["trigger1", "trigger2"]} llm = {**self.MIN_LLM_OUTPUT, "trigger_words": ["trigger1", "trigger2"]}
with ( with (
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply, mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
mock.patch("py.agent_cli.download_preview", return_value=False), mock.patch("py.agent_cli.download_preview", return_value=None),
mock.patch("py.agent_cli.refresh_cache"), mock.patch("py.agent_cli.refresh_cache"),
): ):
await processor.process( await processor.process(
@@ -165,7 +165,7 @@ class TestEnrichHfMetadata:
metadata={"trainedWords": []}, metadata={"trainedWords": []},
) )
applied = mock_apply.call_args[0][1] applied = mock_apply.call_args[0][1]
assert applied["trainedWords"] == ["trigger1", "trigger2"] assert applied["civitai"]["trainedWords"] == ["trigger1", "trigger2"]
# -- description ----------------------------------------------------- # -- description -----------------------------------------------------
@@ -237,7 +237,7 @@ class TestEnrichHfMetadata:
mock.patch("py.agent_cli.download_preview") as mock_dl, mock.patch("py.agent_cli.download_preview") as mock_dl,
mock.patch("py.agent_cli.refresh_cache"), mock.patch("py.agent_cli.refresh_cache"),
): ):
mock_dl.return_value = True mock_dl.return_value = "/p.webp"
result = await processor.process( result = await processor.process(
skill_name="enrich_hf_metadata", skill_name="enrich_hf_metadata",
model_path="/p.safetensors", model_path="/p.safetensors",
@@ -246,6 +246,8 @@ class TestEnrichHfMetadata:
) )
assert result["preview_downloaded"] is True assert result["preview_downloaded"] is True
mock_dl.assert_awaited_once_with("/p.safetensors", "https://ex.com/img.png") mock_dl.assert_awaited_once_with("/p.safetensors", "https://ex.com/img.png")
applied = mock_apply.call_args[0][1]
assert applied["preview_url"] == "/p.webp"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_preview_skipped_when_exists(self, processor): async def test_preview_skipped_when_exists(self, processor):