mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-07-05 17:01:16 -03:00
feat(agent): improve enrich_hf_metadata skill with priority_tags, preview_url fix, civitai.trainedWords
- Add identify_model_type() helper to determine lora/checkpoint/embedding - Pass priority_tags from user settings to LLM prompt for tag relevance - SKILL.md: instruct LLM to exclude technical/generic HF tags, cross-reference against priority_tags; forbid ['None'] placeholder for trigger words - post_processor: fix preview_url not updated after download (now writes local .webp path to metadata); write trigger words to civitai.trainedWords instead of top-level; sanitize ['None']/'null'/'n/a' placeholder values to [] - download_preview() now returns str | None (local path) instead of bool - Update tests for new return type and nested civitai.trainedWords structure
This commit is contained in:
@@ -32,6 +32,15 @@ logger = logging.getLogger(__name__)
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SCANNER_TYPE_MAP: dict[str, str] = {
|
||||
"get_lora_scanner": "lora",
|
||||
"get_checkpoint_scanner": "checkpoint",
|
||||
"get_embedding_scanner": "embedding",
|
||||
}
|
||||
|
||||
SCANNER_GETTER_NAMES = tuple(SCANNER_TYPE_MAP.keys())
|
||||
|
||||
|
||||
async def _find_scanner_for_model(
|
||||
model_path: str,
|
||||
) -> tuple[object, object] | tuple[None, None]:
|
||||
@@ -44,11 +53,7 @@ async def _find_scanner_for_model(
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
normalized = os.path.normpath(model_path)
|
||||
for getter_name in (
|
||||
"get_lora_scanner",
|
||||
"get_checkpoint_scanner",
|
||||
"get_embedding_scanner",
|
||||
):
|
||||
for getter_name in SCANNER_GETTER_NAMES:
|
||||
getter = getattr(ServiceRegistry, getter_name, None)
|
||||
if getter is None:
|
||||
continue
|
||||
@@ -70,6 +75,38 @@ async def _find_scanner_for_model(
|
||||
return None, None
|
||||
|
||||
|
||||
async def identify_model_type(model_path: str) -> str:
|
||||
"""Determine the model type (``\"lora\"``, ``\"checkpoint\"``, or
|
||||
``\"embedding\"``) for *model_path*.
|
||||
|
||||
Iterates all known scanners; the first scanner that claims the path
|
||||
determines the type. Falls back to ``\"lora\"`` when unknown.
|
||||
"""
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
normalized = os.path.normpath(model_path)
|
||||
for getter_name in SCANNER_GETTER_NAMES:
|
||||
getter = getattr(ServiceRegistry, getter_name, None)
|
||||
if getter is None:
|
||||
continue
|
||||
try:
|
||||
scanner = await getter()
|
||||
if scanner is None:
|
||||
continue
|
||||
cache = await scanner.get_cached_data()
|
||||
for entry in cache.raw_data:
|
||||
if os.path.normpath(entry.get("file_path", "")) == normalized:
|
||||
return SCANNER_TYPE_MAP[getter_name]
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"identify_model_type scanner %s error for %s: %s",
|
||||
getter_name,
|
||||
model_path,
|
||||
exc,
|
||||
)
|
||||
return "lora"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -153,17 +190,17 @@ async def download_preview(
|
||||
*,
|
||||
target_width: int = 480,
|
||||
quality: int = 85,
|
||||
) -> bool:
|
||||
) -> str | None:
|
||||
"""Download a preview image from *url*, optimise to .webp, and save it.
|
||||
|
||||
The output file is placed alongside the model file with a ``.webp``
|
||||
extension. Returns ``True`` on success.
|
||||
extension. Returns the local file path on success, ``None`` on failure.
|
||||
"""
|
||||
from ..services.downloader import get_downloader
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
|
||||
if not url or not url.strip():
|
||||
return False
|
||||
return None
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(model_path))[0]
|
||||
preview_dir = os.path.dirname(model_path)
|
||||
@@ -187,7 +224,7 @@ async def download_preview(
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(optimized_data)
|
||||
logger.info("Preview downloaded and optimised for %s", model_path)
|
||||
return True
|
||||
return output_path
|
||||
except Exception as exc:
|
||||
logger.warning("Preview optimisation failed, saving raw: %s", exc)
|
||||
# Fall through to raw save
|
||||
@@ -197,11 +234,11 @@ async def download_preview(
|
||||
ok, _ = await downloader.download_file(url, output_path, use_auth=False)
|
||||
if ok:
|
||||
logger.info("Preview downloaded (fallback) for %s", model_path)
|
||||
return True
|
||||
return output_path
|
||||
except Exception as exc:
|
||||
logger.warning("Preview fallback download failed for %s: %s", model_path, exc)
|
||||
|
||||
return False
|
||||
return None
|
||||
|
||||
|
||||
async def refresh_cache(model_path: str) -> bool:
|
||||
|
||||
@@ -334,10 +334,11 @@ class AgentService:
|
||||
"""Gather variables for the skill's prompt template.
|
||||
|
||||
Reads metadata, fetches the HF README (if applicable), lists available
|
||||
base models, and returns a dict that maps to ``{{variable}}``
|
||||
placeholders in ``prompt.md``.
|
||||
base models, loads user priority tags, and returns a dict that maps to
|
||||
``{{variable}}`` placeholders in ``prompt.md``.
|
||||
"""
|
||||
from ...agent_cli import list_base_models
|
||||
from ...agent_cli import identify_model_type, list_base_models
|
||||
from ..settings_manager import SettingsManager
|
||||
|
||||
context: Dict[str, Any] = {
|
||||
"model_path": model_path,
|
||||
@@ -346,6 +347,7 @@ class AgentService:
|
||||
"readme_content": "",
|
||||
"current_metadata": {},
|
||||
"base_models": [],
|
||||
"priority_tags": "",
|
||||
}
|
||||
|
||||
context["current_metadata"] = {
|
||||
@@ -371,6 +373,18 @@ class AgentService:
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to list base models: %s", exc)
|
||||
|
||||
# Determine model type and load the corresponding priority_tags
|
||||
try:
|
||||
model_type = await identify_model_type(model_path)
|
||||
context["model_type"] = model_type
|
||||
settings = SettingsManager()
|
||||
priority_config = settings.get_priority_tag_config()
|
||||
context["priority_tags"] = priority_config.get(model_type, "")
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to load priority tags: %s", exc)
|
||||
context["model_type"] = "lora"
|
||||
context["priority_tags"] = ""
|
||||
|
||||
return context
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -86,14 +86,16 @@ class PostProcessor:
|
||||
if new_base and self._should_overwrite(current_base, is_hf_model):
|
||||
updates["base_model"] = new_base
|
||||
|
||||
# trainedWords / trigger words
|
||||
new_triggers = llm_output.get("trigger_words", [])
|
||||
if isinstance(new_triggers, list):
|
||||
cleaned = [t.strip() for t in new_triggers if t.strip()]
|
||||
if cleaned:
|
||||
current_triggers = metadata.get("trainedWords") or []
|
||||
if self._should_overwrite_list(current_triggers, is_hf_model):
|
||||
updates["trainedWords"] = cleaned
|
||||
cleaned = [t for t in cleaned if t.lower() not in ("none", "null", "n/a")]
|
||||
current_civitai = metadata.get("civitai") or {}
|
||||
current_triggers = current_civitai.get("trainedWords") or []
|
||||
if self._should_overwrite_list(current_triggers, is_hf_model):
|
||||
civitai_updates = dict(current_civitai)
|
||||
civitai_updates["trainedWords"] = cleaned
|
||||
updates["civitai"] = civitai_updates
|
||||
|
||||
# modelDescription
|
||||
new_desc = (llm_output.get("description") or "").strip()
|
||||
@@ -102,7 +104,7 @@ class PostProcessor:
|
||||
if self._should_overwrite(current_desc, is_hf_model):
|
||||
updates["modelDescription"] = new_desc
|
||||
|
||||
# tags — merge with existing, deduplicate (case-insensitive)
|
||||
# tags
|
||||
new_tags = llm_output.get("tags", [])
|
||||
if isinstance(new_tags, list) and new_tags:
|
||||
existing_tags = metadata.get("tags") or []
|
||||
@@ -114,16 +116,17 @@ class PostProcessor:
|
||||
updates["metadata_source"] = "agent:enrich_hf_metadata"
|
||||
updates["llm_enriched_at"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# -- Persist updates ------------------------------------------------
|
||||
preview_remote_url = (llm_output.get("preview_url") or "").strip()
|
||||
current_preview = metadata.get("preview_url") or ""
|
||||
if preview_remote_url and not (current_preview and os.path.exists(current_preview)):
|
||||
local_path = await download_preview(model_path, preview_remote_url)
|
||||
if local_path:
|
||||
preview_downloaded = True
|
||||
updates["preview_url"] = local_path
|
||||
|
||||
if updates:
|
||||
updated_fields = await apply_metadata_updates(model_path, updates)
|
||||
|
||||
# -- Download preview -----------------------------------------------
|
||||
preview_url = (llm_output.get("preview_url") or "").strip()
|
||||
current_preview = metadata.get("preview_url") or ""
|
||||
if preview_url and not (current_preview and os.path.exists(current_preview)):
|
||||
preview_downloaded = await download_preview(model_path, preview_url)
|
||||
|
||||
# -- Refresh scanner cache ------------------------------------------
|
||||
if updated_fields or preview_downloaded:
|
||||
await refresh_cache(model_path)
|
||||
|
||||
@@ -21,6 +21,16 @@ You are an expert assistant for AI image generation models. Your task is to extr
|
||||
{{current_metadata}}
|
||||
```
|
||||
|
||||
## User Priority Tags Reference
|
||||
|
||||
The user has configured the following list of **meaningful tag categories** for this model type (`{{model_type}}`):
|
||||
|
||||
```
|
||||
{{priority_tags}}
|
||||
```
|
||||
|
||||
These are the subjects, styles, and concepts the user considers useful for categorization. Use this list as a **reference** when evaluating tags (see the **tags** section below).
|
||||
|
||||
## Available Base Models
|
||||
|
||||
The following base models are currently valid in this system:
|
||||
@@ -37,7 +47,7 @@ The following base models are currently valid in this system:
|
||||
Extract the following information from the README content above:
|
||||
|
||||
### base_model
|
||||
The base model this LoRA/checkpoint was trained on. Use EXACTLY one of the names from the **Available Base Models** list above. Do not invent new names or use aliases.
|
||||
The base model this model was trained on. Use EXACTLY one of the names from the **Available Base Models** list above. Do not invent new names or use aliases.
|
||||
|
||||
Check the YAML frontmatter (between --- markers) for `base_model:` first, then look at the description text and safetensors metadata. If you cannot determine it, return an empty string.
|
||||
|
||||
@@ -46,17 +56,27 @@ The trigger words or activation prompts needed to use this LoRA. Look for:
|
||||
- `instance_prompt:` in the YAML frontmatter
|
||||
- Phrases like "trigger word:", "trigger:", "use this prompt:", "activation prompt:"
|
||||
- Example prompts at the start (usually the first word or phrase before any description)
|
||||
Return as an array of strings. If none found, return an empty array.
|
||||
Return as an array of strings. If none found, return an empty array `[]`. **Never** return `["None"]` or any placeholder value — a truly empty list means no trigger words exist.
|
||||
|
||||
### description
|
||||
A concise 1-2 sentence summary of what this model does. Extract from the "Model description" section or the first paragraph. Return empty string if the README is too minimal.
|
||||
|
||||
### tags
|
||||
3-8 relevant tags for categorizing this model. Extract from:
|
||||
- The YAML frontmatter `tags:` list (often contains excellent categorization tags)
|
||||
- The model type (e.g. "lora", "checkpoint", "flux", "sdxl")
|
||||
- The style/subject (e.g. "anime", "photorealistic", "style", "character")
|
||||
All lowercase, no spaces. Return empty array if none found.
|
||||
3-8 relevant tags for categorizing this model. **Quality over quantity.**
|
||||
|
||||
Sources to consider:
|
||||
- The YAML frontmatter `tags:` list
|
||||
- The subject, style, character, or concept the model represents
|
||||
|
||||
**Critical filtering rules — apply them strictly:**
|
||||
|
||||
1. **Exclude technical/generic tags.** Reject any tag that describes the model's **training methodology, framework, architecture, or modality** rather than its content. Examples to exclude: `text-to-image`, `diffusers`, `lora`, `dreambooth`, `diffusers-training`, `flux`, `sdxl`, `checkpoint`, `pytorch`, `safetensors`, `fine-tuning`, `stable-diffusion`, and any variant of these.
|
||||
|
||||
2. **Cross-reference against the priority_tags reference.** Only include a tag if it meaningfully describes what the model actually creates (subject, style, character type) and is semantically close to one of the priority_tags. If none of the README's tags match meaningful categories, prefer returning a smaller set or an empty array over including low-value tags.
|
||||
|
||||
3. **All lowercase, no spaces, no hyphens** (use single words like `"photorealistic"`, `"anime"`, `"character"`).
|
||||
|
||||
Return empty array if no meaningful content tags remain after filtering.
|
||||
|
||||
### preview_url
|
||||
The URL of the most suitable preview image from the README. Look for image tags (e.g. ``) and the YAML frontmatter `widget:` section (which often has `output.url` fields). Choose the first image that appears to be a generation example (not a logo or diagram). Construct the absolute URL as `https://huggingface.co/{{repo}}/resolve/main/{filename}`. If no suitable image is found, return an empty string.
|
||||
@@ -87,3 +107,4 @@ Important:
|
||||
- Only include the JSON object, no other text
|
||||
- If a field cannot be determined, use an empty string or empty array
|
||||
- Do not fabricate information not supported by the README
|
||||
- Never use placeholder values like `"None"` or `"unknown"` for missing data — use empty string or empty array
|
||||
|
||||
@@ -227,11 +227,11 @@ class TestApplyMetadataUpdates:
|
||||
class TestDownloadPreview:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_url_returns_false(self, tmp_path):
|
||||
async def test_empty_url_returns_none(self, tmp_path):
|
||||
mp = tmp_path / "m.safetensors"
|
||||
mp.write_bytes(b"fake")
|
||||
assert await download_preview(str(mp), "") is False
|
||||
assert await download_preview(str(mp), " ") is False
|
||||
assert await download_preview(str(mp), "") is None
|
||||
assert await download_preview(str(mp), " ") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_download_and_optimise(self, tmp_path):
|
||||
@@ -246,12 +246,12 @@ class TestDownloadPreview:
|
||||
get_dl.return_value = dl
|
||||
exif.optimize_image.return_value = (b"optimized_webp", {})
|
||||
result = await download_preview(str(mp), "https://ex.com/i.png")
|
||||
assert result is True
|
||||
assert result == str(tmp_path / "t.webp")
|
||||
assert (tmp_path / "t.webp").exists()
|
||||
assert (tmp_path / "t.webp").read_bytes() == b"optimized_webp"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_failure_returns_false(self, tmp_path):
|
||||
async def test_download_failure_returns_none(self, tmp_path):
|
||||
mp = tmp_path / "t.safetensors"
|
||||
mp.write_bytes(b"fake")
|
||||
with mock.patch("py.services.downloader.get_downloader") as get_dl:
|
||||
@@ -260,7 +260,7 @@ class TestDownloadPreview:
|
||||
dl.download_file = mock.AsyncMock(return_value=(False, None))
|
||||
get_dl.return_value = dl
|
||||
result = await download_preview(str(mp), "https://ex.com/i.png")
|
||||
assert result is False
|
||||
assert result is None
|
||||
assert not (tmp_path / "t.webp").exists()
|
||||
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ class TestProcessDispatch:
|
||||
mock.patch("py.agent_cli.refresh_cache") as mock_ref,
|
||||
):
|
||||
mock_apply.return_value = ["metadata_source"]
|
||||
mock_dl.return_value = False
|
||||
mock_dl.return_value = None
|
||||
|
||||
result = await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
@@ -155,7 +155,7 @@ class TestEnrichHfMetadata:
|
||||
llm = {**self.MIN_LLM_OUTPUT, "trigger_words": ["trigger1", "trigger2"]}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.download_preview", return_value=None),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
@@ -165,7 +165,7 @@ class TestEnrichHfMetadata:
|
||||
metadata={"trainedWords": []},
|
||||
)
|
||||
applied = mock_apply.call_args[0][1]
|
||||
assert applied["trainedWords"] == ["trigger1", "trigger2"]
|
||||
assert applied["civitai"]["trainedWords"] == ["trigger1", "trigger2"]
|
||||
|
||||
# -- description -----------------------------------------------------
|
||||
|
||||
@@ -237,7 +237,7 @@ class TestEnrichHfMetadata:
|
||||
mock.patch("py.agent_cli.download_preview") as mock_dl,
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
mock_dl.return_value = True
|
||||
mock_dl.return_value = "/p.webp"
|
||||
result = await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
@@ -246,6 +246,8 @@ class TestEnrichHfMetadata:
|
||||
)
|
||||
assert result["preview_downloaded"] is True
|
||||
mock_dl.assert_awaited_once_with("/p.safetensors", "https://ex.com/img.png")
|
||||
applied = mock_apply.call_args[0][1]
|
||||
assert applied["preview_url"] == "/p.webp"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_skipped_when_exists(self, processor):
|
||||
|
||||
Reference in New Issue
Block a user