feat(agent): improve enrich_hf_metadata skill with priority_tags, preview_url fix, civitai.trainedWords

- Add identify_model_type() helper to determine lora/checkpoint/embedding
- Pass priority_tags from user settings to LLM prompt for tag relevance
- SKILL.md: instruct LLM to exclude technical/generic HF tags, cross-reference
  against priority_tags; forbid ['None'] placeholder for trigger words
- post_processor: fix preview_url not updated after download (now writes local
  .webp path to metadata); write trigger words to civitai.trainedWords instead
  of top-level; sanitize ['None']/'null'/'n/a' placeholder values to []
- download_preview() now returns str | None (local path) instead of bool
- Update tests for new return type and nested civitai.trainedWords structure
This commit is contained in:
Will Miao
2026-07-02 22:14:44 +08:00
parent 63785f82b5
commit a8adcaf023
6 changed files with 121 additions and 44 deletions

View File

@@ -32,6 +32,15 @@ logger = logging.getLogger(__name__)
# Helpers
# ---------------------------------------------------------------------------
SCANNER_TYPE_MAP: dict[str, str] = {
"get_lora_scanner": "lora",
"get_checkpoint_scanner": "checkpoint",
"get_embedding_scanner": "embedding",
}
SCANNER_GETTER_NAMES = tuple(SCANNER_TYPE_MAP.keys())
async def _find_scanner_for_model(
model_path: str,
) -> tuple[object, object] | tuple[None, None]:
@@ -44,11 +53,7 @@ async def _find_scanner_for_model(
from ..services.service_registry import ServiceRegistry
normalized = os.path.normpath(model_path)
for getter_name in (
"get_lora_scanner",
"get_checkpoint_scanner",
"get_embedding_scanner",
):
for getter_name in SCANNER_GETTER_NAMES:
getter = getattr(ServiceRegistry, getter_name, None)
if getter is None:
continue
@@ -70,6 +75,38 @@ async def _find_scanner_for_model(
return None, None
async def identify_model_type(model_path: str) -> str:
"""Determine the model type (``\"lora\"``, ``\"checkpoint\"``, or
``\"embedding\"``) for *model_path*.
Iterates all known scanners; the first scanner that claims the path
determines the type. Falls back to ``\"lora\"`` when unknown.
"""
from ..services.service_registry import ServiceRegistry
normalized = os.path.normpath(model_path)
for getter_name in SCANNER_GETTER_NAMES:
getter = getattr(ServiceRegistry, getter_name, None)
if getter is None:
continue
try:
scanner = await getter()
if scanner is None:
continue
cache = await scanner.get_cached_data()
for entry in cache.raw_data:
if os.path.normpath(entry.get("file_path", "")) == normalized:
return SCANNER_TYPE_MAP[getter_name]
except Exception as exc:
logger.debug(
"identify_model_type scanner %s error for %s: %s",
getter_name,
model_path,
exc,
)
return "lora"
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
@@ -153,17 +190,17 @@ async def download_preview(
*,
target_width: int = 480,
quality: int = 85,
) -> bool:
) -> str | None:
"""Download a preview image from *url*, optimise to .webp, and save it.
The output file is placed alongside the model file with a ``.webp``
extension. Returns ``True`` on success.
extension. Returns the local file path on success, ``None`` on failure.
"""
from ..services.downloader import get_downloader
from ..utils.exif_utils import ExifUtils
if not url or not url.strip():
return False
return None
base_name = os.path.splitext(os.path.basename(model_path))[0]
preview_dir = os.path.dirname(model_path)
@@ -187,7 +224,7 @@ async def download_preview(
with open(output_path, "wb") as f:
f.write(optimized_data)
logger.info("Preview downloaded and optimised for %s", model_path)
return True
return output_path
except Exception as exc:
logger.warning("Preview optimisation failed, saving raw: %s", exc)
# Fall through to raw save
@@ -197,11 +234,11 @@ async def download_preview(
ok, _ = await downloader.download_file(url, output_path, use_auth=False)
if ok:
logger.info("Preview downloaded (fallback) for %s", model_path)
return True
return output_path
except Exception as exc:
logger.warning("Preview fallback download failed for %s: %s", model_path, exc)
return False
return None
async def refresh_cache(model_path: str) -> bool:

View File

@@ -334,10 +334,11 @@ class AgentService:
"""Gather variables for the skill's prompt template.
Reads metadata, fetches the HF README (if applicable), lists available
base models, and returns a dict that maps to ``{{variable}}``
placeholders in ``prompt.md``.
base models, loads user priority tags, and returns a dict that maps to
``{{variable}}`` placeholders in ``prompt.md``.
"""
from ...agent_cli import list_base_models
from ...agent_cli import identify_model_type, list_base_models
from ..settings_manager import SettingsManager
context: Dict[str, Any] = {
"model_path": model_path,
@@ -346,6 +347,7 @@ class AgentService:
"readme_content": "",
"current_metadata": {},
"base_models": [],
"priority_tags": "",
}
context["current_metadata"] = {
@@ -371,6 +373,18 @@ class AgentService:
except Exception as exc:
logger.debug("Failed to list base models: %s", exc)
# Determine model type and load the corresponding priority_tags
try:
model_type = await identify_model_type(model_path)
context["model_type"] = model_type
settings = SettingsManager()
priority_config = settings.get_priority_tag_config()
context["priority_tags"] = priority_config.get(model_type, "")
except Exception as exc:
logger.debug("Failed to load priority tags: %s", exc)
context["model_type"] = "lora"
context["priority_tags"] = ""
return context
@staticmethod

View File

@@ -86,14 +86,16 @@ class PostProcessor:
if new_base and self._should_overwrite(current_base, is_hf_model):
updates["base_model"] = new_base
# trainedWords / trigger words
new_triggers = llm_output.get("trigger_words", [])
if isinstance(new_triggers, list):
cleaned = [t.strip() for t in new_triggers if t.strip()]
if cleaned:
current_triggers = metadata.get("trainedWords") or []
if self._should_overwrite_list(current_triggers, is_hf_model):
updates["trainedWords"] = cleaned
cleaned = [t for t in cleaned if t.lower() not in ("none", "null", "n/a")]
current_civitai = metadata.get("civitai") or {}
current_triggers = current_civitai.get("trainedWords") or []
if self._should_overwrite_list(current_triggers, is_hf_model):
civitai_updates = dict(current_civitai)
civitai_updates["trainedWords"] = cleaned
updates["civitai"] = civitai_updates
# modelDescription
new_desc = (llm_output.get("description") or "").strip()
@@ -102,7 +104,7 @@ class PostProcessor:
if self._should_overwrite(current_desc, is_hf_model):
updates["modelDescription"] = new_desc
# tags — merge with existing, deduplicate (case-insensitive)
# tags
new_tags = llm_output.get("tags", [])
if isinstance(new_tags, list) and new_tags:
existing_tags = metadata.get("tags") or []
@@ -114,16 +116,17 @@ class PostProcessor:
updates["metadata_source"] = "agent:enrich_hf_metadata"
updates["llm_enriched_at"] = datetime.now(timezone.utc).isoformat()
# -- Persist updates ------------------------------------------------
preview_remote_url = (llm_output.get("preview_url") or "").strip()
current_preview = metadata.get("preview_url") or ""
if preview_remote_url and not (current_preview and os.path.exists(current_preview)):
local_path = await download_preview(model_path, preview_remote_url)
if local_path:
preview_downloaded = True
updates["preview_url"] = local_path
if updates:
updated_fields = await apply_metadata_updates(model_path, updates)
# -- Download preview -----------------------------------------------
preview_url = (llm_output.get("preview_url") or "").strip()
current_preview = metadata.get("preview_url") or ""
if preview_url and not (current_preview and os.path.exists(current_preview)):
preview_downloaded = await download_preview(model_path, preview_url)
# -- Refresh scanner cache ------------------------------------------
if updated_fields or preview_downloaded:
await refresh_cache(model_path)

View File

@@ -21,6 +21,16 @@ You are an expert assistant for AI image generation models. Your task is to extr
{{current_metadata}}
```
## User Priority Tags Reference
The user has configured the following list of **meaningful tag categories** for this model type (`{{model_type}}`):
```
{{priority_tags}}
```
These are the subjects, styles, and concepts the user considers useful for categorization. Use this list as a **reference** when evaluating tags (see the **tags** section below).
## Available Base Models
The following base models are currently valid in this system:
@@ -37,7 +47,7 @@ The following base models are currently valid in this system:
Extract the following information from the README content above:
### base_model
The base model this LoRA/checkpoint was trained on. Use EXACTLY one of the names from the **Available Base Models** list above. Do not invent new names or use aliases.
The base model this model was trained on. Use EXACTLY one of the names from the **Available Base Models** list above. Do not invent new names or use aliases.
Check the YAML frontmatter (between --- markers) for `base_model:` first, then look at the description text and safetensors metadata. If you cannot determine it, return an empty string.
@@ -46,17 +56,27 @@ The trigger words or activation prompts needed to use this LoRA. Look for:
- `instance_prompt:` in the YAML frontmatter
- Phrases like "trigger word:", "trigger:", "use this prompt:", "activation prompt:"
- Example prompts at the start (usually the first word or phrase before any description)
Return as an array of strings. If none found, return an empty array.
Return as an array of strings. If none found, return an empty array `[]`. **Never** return `["None"]` or any placeholder value — a truly empty list means no trigger words exist.
### description
A concise 1-2 sentence summary of what this model does. Extract from the "Model description" section or the first paragraph. Return empty string if the README is too minimal.
### tags
3-8 relevant tags for categorizing this model. Extract from:
- The YAML frontmatter `tags:` list (often contains excellent categorization tags)
- The model type (e.g. "lora", "checkpoint", "flux", "sdxl")
- The style/subject (e.g. "anime", "photorealistic", "style", "character")
All lowercase, no spaces. Return empty array if none found.
3-8 relevant tags for categorizing this model. **Quality over quantity.**
Sources to consider:
- The YAML frontmatter `tags:` list
- The subject, style, character, or concept the model represents
**Critical filtering rules — apply them strictly:**
1. **Exclude technical/generic tags.** Reject any tag that describes the model's **training methodology, framework, architecture, or modality** rather than its content. Examples to exclude: `text-to-image`, `diffusers`, `lora`, `dreambooth`, `diffusers-training`, `flux`, `sdxl`, `checkpoint`, `pytorch`, `safetensors`, `fine-tuning`, `stable-diffusion`, and any variant of these.
2. **Cross-reference against the priority_tags reference.** Only include a tag if it meaningfully describes what the model actually creates (subject, style, character type) and is semantically close to one of the priority_tags. If none of the README's tags match meaningful categories, prefer returning a smaller set or an empty array over including low-value tags.
3. **All lowercase, no spaces, no hyphens** (use single words like `"photorealistic"`, `"anime"`, `"character"`).
Return empty array if no meaningful content tags remain after filtering.
### preview_url
The URL of the most suitable preview image from the README. Look for image tags (e.g. `![alt](url)`) and the YAML frontmatter `widget:` section (which often has `output.url` fields). Choose the first image that appears to be a generation example (not a logo or diagram). Construct the absolute URL as `https://huggingface.co/{{repo}}/resolve/main/{filename}`. If no suitable image is found, return an empty string.
@@ -87,3 +107,4 @@ Important:
- Only include the JSON object, no other text
- If a field cannot be determined, use an empty string or empty array
- Do not fabricate information not supported by the README
- Never use placeholder values like `"None"` or `"unknown"` for missing data — use empty string or empty array

View File

@@ -227,11 +227,11 @@ class TestApplyMetadataUpdates:
class TestDownloadPreview:
@pytest.mark.asyncio
async def test_empty_url_returns_false(self, tmp_path):
async def test_empty_url_returns_none(self, tmp_path):
mp = tmp_path / "m.safetensors"
mp.write_bytes(b"fake")
assert await download_preview(str(mp), "") is False
assert await download_preview(str(mp), " ") is False
assert await download_preview(str(mp), "") is None
assert await download_preview(str(mp), " ") is None
@pytest.mark.asyncio
async def test_successful_download_and_optimise(self, tmp_path):
@@ -246,12 +246,12 @@ class TestDownloadPreview:
get_dl.return_value = dl
exif.optimize_image.return_value = (b"optimized_webp", {})
result = await download_preview(str(mp), "https://ex.com/i.png")
assert result is True
assert result == str(tmp_path / "t.webp")
assert (tmp_path / "t.webp").exists()
assert (tmp_path / "t.webp").read_bytes() == b"optimized_webp"
@pytest.mark.asyncio
async def test_download_failure_returns_false(self, tmp_path):
async def test_download_failure_returns_none(self, tmp_path):
mp = tmp_path / "t.safetensors"
mp.write_bytes(b"fake")
with mock.patch("py.services.downloader.get_downloader") as get_dl:
@@ -260,7 +260,7 @@ class TestDownloadPreview:
dl.download_file = mock.AsyncMock(return_value=(False, None))
get_dl.return_value = dl
result = await download_preview(str(mp), "https://ex.com/i.png")
assert result is False
assert result is None
assert not (tmp_path / "t.webp").exists()

View File

@@ -44,7 +44,7 @@ class TestProcessDispatch:
mock.patch("py.agent_cli.refresh_cache") as mock_ref,
):
mock_apply.return_value = ["metadata_source"]
mock_dl.return_value = False
mock_dl.return_value = None
result = await processor.process(
skill_name="enrich_hf_metadata",
@@ -155,7 +155,7 @@ class TestEnrichHfMetadata:
llm = {**self.MIN_LLM_OUTPUT, "trigger_words": ["trigger1", "trigger2"]}
with (
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
mock.patch("py.agent_cli.download_preview", return_value=False),
mock.patch("py.agent_cli.download_preview", return_value=None),
mock.patch("py.agent_cli.refresh_cache"),
):
await processor.process(
@@ -165,7 +165,7 @@ class TestEnrichHfMetadata:
metadata={"trainedWords": []},
)
applied = mock_apply.call_args[0][1]
assert applied["trainedWords"] == ["trigger1", "trigger2"]
assert applied["civitai"]["trainedWords"] == ["trigger1", "trigger2"]
# -- description -----------------------------------------------------
@@ -237,7 +237,7 @@ class TestEnrichHfMetadata:
mock.patch("py.agent_cli.download_preview") as mock_dl,
mock.patch("py.agent_cli.refresh_cache"),
):
mock_dl.return_value = True
mock_dl.return_value = "/p.webp"
result = await processor.process(
skill_name="enrich_hf_metadata",
model_path="/p.safetensors",
@@ -246,6 +246,8 @@ class TestEnrichHfMetadata:
)
assert result["preview_downloaded"] is True
mock_dl.assert_awaited_once_with("/p.safetensors", "https://ex.com/img.png")
applied = mock_apply.call_args[0][1]
assert applied["preview_url"] == "/p.webp"
@pytest.mark.asyncio
async def test_preview_skipped_when_exists(self, processor):