mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-07-05 17:01:16 -03:00
feat(agent): improve enrich_hf_metadata skill with priority_tags, preview_url fix, civitai.trainedWords
- Add identify_model_type() helper to determine lora/checkpoint/embedding - Pass priority_tags from user settings to LLM prompt for tag relevance - SKILL.md: instruct LLM to exclude technical/generic HF tags, cross-reference against priority_tags; forbid ['None'] placeholder for trigger words - post_processor: fix preview_url not updated after download (now writes local .webp path to metadata); write trigger words to civitai.trainedWords instead of top-level; sanitize ['None']/'null'/'n/a' placeholder values to [] - download_preview() now returns str | None (local path) instead of bool - Update tests for new return type and nested civitai.trainedWords structure
This commit is contained in:
@@ -32,6 +32,15 @@ logger = logging.getLogger(__name__)
|
|||||||
# Helpers
|
# Helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
SCANNER_TYPE_MAP: dict[str, str] = {
|
||||||
|
"get_lora_scanner": "lora",
|
||||||
|
"get_checkpoint_scanner": "checkpoint",
|
||||||
|
"get_embedding_scanner": "embedding",
|
||||||
|
}
|
||||||
|
|
||||||
|
SCANNER_GETTER_NAMES = tuple(SCANNER_TYPE_MAP.keys())
|
||||||
|
|
||||||
|
|
||||||
async def _find_scanner_for_model(
|
async def _find_scanner_for_model(
|
||||||
model_path: str,
|
model_path: str,
|
||||||
) -> tuple[object, object] | tuple[None, None]:
|
) -> tuple[object, object] | tuple[None, None]:
|
||||||
@@ -44,11 +53,7 @@ async def _find_scanner_for_model(
|
|||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
|
||||||
normalized = os.path.normpath(model_path)
|
normalized = os.path.normpath(model_path)
|
||||||
for getter_name in (
|
for getter_name in SCANNER_GETTER_NAMES:
|
||||||
"get_lora_scanner",
|
|
||||||
"get_checkpoint_scanner",
|
|
||||||
"get_embedding_scanner",
|
|
||||||
):
|
|
||||||
getter = getattr(ServiceRegistry, getter_name, None)
|
getter = getattr(ServiceRegistry, getter_name, None)
|
||||||
if getter is None:
|
if getter is None:
|
||||||
continue
|
continue
|
||||||
@@ -70,6 +75,38 @@ async def _find_scanner_for_model(
|
|||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
async def identify_model_type(model_path: str) -> str:
|
||||||
|
"""Determine the model type (``\"lora\"``, ``\"checkpoint\"``, or
|
||||||
|
``\"embedding\"``) for *model_path*.
|
||||||
|
|
||||||
|
Iterates all known scanners; the first scanner that claims the path
|
||||||
|
determines the type. Falls back to ``\"lora\"`` when unknown.
|
||||||
|
"""
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
|
||||||
|
normalized = os.path.normpath(model_path)
|
||||||
|
for getter_name in SCANNER_GETTER_NAMES:
|
||||||
|
getter = getattr(ServiceRegistry, getter_name, None)
|
||||||
|
if getter is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
scanner = await getter()
|
||||||
|
if scanner is None:
|
||||||
|
continue
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
for entry in cache.raw_data:
|
||||||
|
if os.path.normpath(entry.get("file_path", "")) == normalized:
|
||||||
|
return SCANNER_TYPE_MAP[getter_name]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(
|
||||||
|
"identify_model_type scanner %s error for %s: %s",
|
||||||
|
getter_name,
|
||||||
|
model_path,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return "lora"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Public API
|
# Public API
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -153,17 +190,17 @@ async def download_preview(
|
|||||||
*,
|
*,
|
||||||
target_width: int = 480,
|
target_width: int = 480,
|
||||||
quality: int = 85,
|
quality: int = 85,
|
||||||
) -> bool:
|
) -> str | None:
|
||||||
"""Download a preview image from *url*, optimise to .webp, and save it.
|
"""Download a preview image from *url*, optimise to .webp, and save it.
|
||||||
|
|
||||||
The output file is placed alongside the model file with a ``.webp``
|
The output file is placed alongside the model file with a ``.webp``
|
||||||
extension. Returns ``True`` on success.
|
extension. Returns the local file path on success, ``None`` on failure.
|
||||||
"""
|
"""
|
||||||
from ..services.downloader import get_downloader
|
from ..services.downloader import get_downloader
|
||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
|
|
||||||
if not url or not url.strip():
|
if not url or not url.strip():
|
||||||
return False
|
return None
|
||||||
|
|
||||||
base_name = os.path.splitext(os.path.basename(model_path))[0]
|
base_name = os.path.splitext(os.path.basename(model_path))[0]
|
||||||
preview_dir = os.path.dirname(model_path)
|
preview_dir = os.path.dirname(model_path)
|
||||||
@@ -187,7 +224,7 @@ async def download_preview(
|
|||||||
with open(output_path, "wb") as f:
|
with open(output_path, "wb") as f:
|
||||||
f.write(optimized_data)
|
f.write(optimized_data)
|
||||||
logger.info("Preview downloaded and optimised for %s", model_path)
|
logger.info("Preview downloaded and optimised for %s", model_path)
|
||||||
return True
|
return output_path
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Preview optimisation failed, saving raw: %s", exc)
|
logger.warning("Preview optimisation failed, saving raw: %s", exc)
|
||||||
# Fall through to raw save
|
# Fall through to raw save
|
||||||
@@ -197,11 +234,11 @@ async def download_preview(
|
|||||||
ok, _ = await downloader.download_file(url, output_path, use_auth=False)
|
ok, _ = await downloader.download_file(url, output_path, use_auth=False)
|
||||||
if ok:
|
if ok:
|
||||||
logger.info("Preview downloaded (fallback) for %s", model_path)
|
logger.info("Preview downloaded (fallback) for %s", model_path)
|
||||||
return True
|
return output_path
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Preview fallback download failed for %s: %s", model_path, exc)
|
logger.warning("Preview fallback download failed for %s: %s", model_path, exc)
|
||||||
|
|
||||||
return False
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def refresh_cache(model_path: str) -> bool:
|
async def refresh_cache(model_path: str) -> bool:
|
||||||
|
|||||||
@@ -334,10 +334,11 @@ class AgentService:
|
|||||||
"""Gather variables for the skill's prompt template.
|
"""Gather variables for the skill's prompt template.
|
||||||
|
|
||||||
Reads metadata, fetches the HF README (if applicable), lists available
|
Reads metadata, fetches the HF README (if applicable), lists available
|
||||||
base models, and returns a dict that maps to ``{{variable}}``
|
base models, loads user priority tags, and returns a dict that maps to
|
||||||
placeholders in ``prompt.md``.
|
``{{variable}}`` placeholders in ``prompt.md``.
|
||||||
"""
|
"""
|
||||||
from ...agent_cli import list_base_models
|
from ...agent_cli import identify_model_type, list_base_models
|
||||||
|
from ..settings_manager import SettingsManager
|
||||||
|
|
||||||
context: Dict[str, Any] = {
|
context: Dict[str, Any] = {
|
||||||
"model_path": model_path,
|
"model_path": model_path,
|
||||||
@@ -346,6 +347,7 @@ class AgentService:
|
|||||||
"readme_content": "",
|
"readme_content": "",
|
||||||
"current_metadata": {},
|
"current_metadata": {},
|
||||||
"base_models": [],
|
"base_models": [],
|
||||||
|
"priority_tags": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
context["current_metadata"] = {
|
context["current_metadata"] = {
|
||||||
@@ -371,6 +373,18 @@ class AgentService:
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.debug("Failed to list base models: %s", exc)
|
logger.debug("Failed to list base models: %s", exc)
|
||||||
|
|
||||||
|
# Determine model type and load the corresponding priority_tags
|
||||||
|
try:
|
||||||
|
model_type = await identify_model_type(model_path)
|
||||||
|
context["model_type"] = model_type
|
||||||
|
settings = SettingsManager()
|
||||||
|
priority_config = settings.get_priority_tag_config()
|
||||||
|
context["priority_tags"] = priority_config.get(model_type, "")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("Failed to load priority tags: %s", exc)
|
||||||
|
context["model_type"] = "lora"
|
||||||
|
context["priority_tags"] = ""
|
||||||
|
|
||||||
return context
|
return context
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -86,14 +86,16 @@ class PostProcessor:
|
|||||||
if new_base and self._should_overwrite(current_base, is_hf_model):
|
if new_base and self._should_overwrite(current_base, is_hf_model):
|
||||||
updates["base_model"] = new_base
|
updates["base_model"] = new_base
|
||||||
|
|
||||||
# trainedWords / trigger words
|
|
||||||
new_triggers = llm_output.get("trigger_words", [])
|
new_triggers = llm_output.get("trigger_words", [])
|
||||||
if isinstance(new_triggers, list):
|
if isinstance(new_triggers, list):
|
||||||
cleaned = [t.strip() for t in new_triggers if t.strip()]
|
cleaned = [t.strip() for t in new_triggers if t.strip()]
|
||||||
if cleaned:
|
cleaned = [t for t in cleaned if t.lower() not in ("none", "null", "n/a")]
|
||||||
current_triggers = metadata.get("trainedWords") or []
|
current_civitai = metadata.get("civitai") or {}
|
||||||
if self._should_overwrite_list(current_triggers, is_hf_model):
|
current_triggers = current_civitai.get("trainedWords") or []
|
||||||
updates["trainedWords"] = cleaned
|
if self._should_overwrite_list(current_triggers, is_hf_model):
|
||||||
|
civitai_updates = dict(current_civitai)
|
||||||
|
civitai_updates["trainedWords"] = cleaned
|
||||||
|
updates["civitai"] = civitai_updates
|
||||||
|
|
||||||
# modelDescription
|
# modelDescription
|
||||||
new_desc = (llm_output.get("description") or "").strip()
|
new_desc = (llm_output.get("description") or "").strip()
|
||||||
@@ -102,7 +104,7 @@ class PostProcessor:
|
|||||||
if self._should_overwrite(current_desc, is_hf_model):
|
if self._should_overwrite(current_desc, is_hf_model):
|
||||||
updates["modelDescription"] = new_desc
|
updates["modelDescription"] = new_desc
|
||||||
|
|
||||||
# tags — merge with existing, deduplicate (case-insensitive)
|
# tags
|
||||||
new_tags = llm_output.get("tags", [])
|
new_tags = llm_output.get("tags", [])
|
||||||
if isinstance(new_tags, list) and new_tags:
|
if isinstance(new_tags, list) and new_tags:
|
||||||
existing_tags = metadata.get("tags") or []
|
existing_tags = metadata.get("tags") or []
|
||||||
@@ -114,16 +116,17 @@ class PostProcessor:
|
|||||||
updates["metadata_source"] = "agent:enrich_hf_metadata"
|
updates["metadata_source"] = "agent:enrich_hf_metadata"
|
||||||
updates["llm_enriched_at"] = datetime.now(timezone.utc).isoformat()
|
updates["llm_enriched_at"] = datetime.now(timezone.utc).isoformat()
|
||||||
|
|
||||||
# -- Persist updates ------------------------------------------------
|
preview_remote_url = (llm_output.get("preview_url") or "").strip()
|
||||||
|
current_preview = metadata.get("preview_url") or ""
|
||||||
|
if preview_remote_url and not (current_preview and os.path.exists(current_preview)):
|
||||||
|
local_path = await download_preview(model_path, preview_remote_url)
|
||||||
|
if local_path:
|
||||||
|
preview_downloaded = True
|
||||||
|
updates["preview_url"] = local_path
|
||||||
|
|
||||||
if updates:
|
if updates:
|
||||||
updated_fields = await apply_metadata_updates(model_path, updates)
|
updated_fields = await apply_metadata_updates(model_path, updates)
|
||||||
|
|
||||||
# -- Download preview -----------------------------------------------
|
|
||||||
preview_url = (llm_output.get("preview_url") or "").strip()
|
|
||||||
current_preview = metadata.get("preview_url") or ""
|
|
||||||
if preview_url and not (current_preview and os.path.exists(current_preview)):
|
|
||||||
preview_downloaded = await download_preview(model_path, preview_url)
|
|
||||||
|
|
||||||
# -- Refresh scanner cache ------------------------------------------
|
# -- Refresh scanner cache ------------------------------------------
|
||||||
if updated_fields or preview_downloaded:
|
if updated_fields or preview_downloaded:
|
||||||
await refresh_cache(model_path)
|
await refresh_cache(model_path)
|
||||||
|
|||||||
@@ -21,6 +21,16 @@ You are an expert assistant for AI image generation models. Your task is to extr
|
|||||||
{{current_metadata}}
|
{{current_metadata}}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## User Priority Tags Reference
|
||||||
|
|
||||||
|
The user has configured the following list of **meaningful tag categories** for this model type (`{{model_type}}`):
|
||||||
|
|
||||||
|
```
|
||||||
|
{{priority_tags}}
|
||||||
|
```
|
||||||
|
|
||||||
|
These are the subjects, styles, and concepts the user considers useful for categorization. Use this list as a **reference** when evaluating tags (see the **tags** section below).
|
||||||
|
|
||||||
## Available Base Models
|
## Available Base Models
|
||||||
|
|
||||||
The following base models are currently valid in this system:
|
The following base models are currently valid in this system:
|
||||||
@@ -37,7 +47,7 @@ The following base models are currently valid in this system:
|
|||||||
Extract the following information from the README content above:
|
Extract the following information from the README content above:
|
||||||
|
|
||||||
### base_model
|
### base_model
|
||||||
The base model this LoRA/checkpoint was trained on. Use EXACTLY one of the names from the **Available Base Models** list above. Do not invent new names or use aliases.
|
The base model this model was trained on. Use EXACTLY one of the names from the **Available Base Models** list above. Do not invent new names or use aliases.
|
||||||
|
|
||||||
Check the YAML frontmatter (between --- markers) for `base_model:` first, then look at the description text and safetensors metadata. If you cannot determine it, return an empty string.
|
Check the YAML frontmatter (between --- markers) for `base_model:` first, then look at the description text and safetensors metadata. If you cannot determine it, return an empty string.
|
||||||
|
|
||||||
@@ -46,17 +56,27 @@ The trigger words or activation prompts needed to use this LoRA. Look for:
|
|||||||
- `instance_prompt:` in the YAML frontmatter
|
- `instance_prompt:` in the YAML frontmatter
|
||||||
- Phrases like "trigger word:", "trigger:", "use this prompt:", "activation prompt:"
|
- Phrases like "trigger word:", "trigger:", "use this prompt:", "activation prompt:"
|
||||||
- Example prompts at the start (usually the first word or phrase before any description)
|
- Example prompts at the start (usually the first word or phrase before any description)
|
||||||
Return as an array of strings. If none found, return an empty array.
|
Return as an array of strings. If none found, return an empty array `[]`. **Never** return `["None"]` or any placeholder value — a truly empty list means no trigger words exist.
|
||||||
|
|
||||||
### description
|
### description
|
||||||
A concise 1-2 sentence summary of what this model does. Extract from the "Model description" section or the first paragraph. Return empty string if the README is too minimal.
|
A concise 1-2 sentence summary of what this model does. Extract from the "Model description" section or the first paragraph. Return empty string if the README is too minimal.
|
||||||
|
|
||||||
### tags
|
### tags
|
||||||
3-8 relevant tags for categorizing this model. Extract from:
|
3-8 relevant tags for categorizing this model. **Quality over quantity.**
|
||||||
- The YAML frontmatter `tags:` list (often contains excellent categorization tags)
|
|
||||||
- The model type (e.g. "lora", "checkpoint", "flux", "sdxl")
|
Sources to consider:
|
||||||
- The style/subject (e.g. "anime", "photorealistic", "style", "character")
|
- The YAML frontmatter `tags:` list
|
||||||
All lowercase, no spaces. Return empty array if none found.
|
- The subject, style, character, or concept the model represents
|
||||||
|
|
||||||
|
**Critical filtering rules — apply them strictly:**
|
||||||
|
|
||||||
|
1. **Exclude technical/generic tags.** Reject any tag that describes the model's **training methodology, framework, architecture, or modality** rather than its content. Examples to exclude: `text-to-image`, `diffusers`, `lora`, `dreambooth`, `diffusers-training`, `flux`, `sdxl`, `checkpoint`, `pytorch`, `safetensors`, `fine-tuning`, `stable-diffusion`, and any variant of these.
|
||||||
|
|
||||||
|
2. **Cross-reference against the priority_tags reference.** Only include a tag if it meaningfully describes what the model actually creates (subject, style, character type) and is semantically close to one of the priority_tags. If none of the README's tags match meaningful categories, prefer returning a smaller set or an empty array over including low-value tags.
|
||||||
|
|
||||||
|
3. **All lowercase, no spaces, no hyphens** (use single words like `"photorealistic"`, `"anime"`, `"character"`).
|
||||||
|
|
||||||
|
Return empty array if no meaningful content tags remain after filtering.
|
||||||
|
|
||||||
### preview_url
|
### preview_url
|
||||||
The URL of the most suitable preview image from the README. Look for image tags (e.g. ``) and the YAML frontmatter `widget:` section (which often has `output.url` fields). Choose the first image that appears to be a generation example (not a logo or diagram). Construct the absolute URL as `https://huggingface.co/{{repo}}/resolve/main/{filename}`. If no suitable image is found, return an empty string.
|
The URL of the most suitable preview image from the README. Look for image tags (e.g. ``) and the YAML frontmatter `widget:` section (which often has `output.url` fields). Choose the first image that appears to be a generation example (not a logo or diagram). Construct the absolute URL as `https://huggingface.co/{{repo}}/resolve/main/{filename}`. If no suitable image is found, return an empty string.
|
||||||
@@ -87,3 +107,4 @@ Important:
|
|||||||
- Only include the JSON object, no other text
|
- Only include the JSON object, no other text
|
||||||
- If a field cannot be determined, use an empty string or empty array
|
- If a field cannot be determined, use an empty string or empty array
|
||||||
- Do not fabricate information not supported by the README
|
- Do not fabricate information not supported by the README
|
||||||
|
- Never use placeholder values like `"None"` or `"unknown"` for missing data — use empty string or empty array
|
||||||
|
|||||||
@@ -227,11 +227,11 @@ class TestApplyMetadataUpdates:
|
|||||||
class TestDownloadPreview:
|
class TestDownloadPreview:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_empty_url_returns_false(self, tmp_path):
|
async def test_empty_url_returns_none(self, tmp_path):
|
||||||
mp = tmp_path / "m.safetensors"
|
mp = tmp_path / "m.safetensors"
|
||||||
mp.write_bytes(b"fake")
|
mp.write_bytes(b"fake")
|
||||||
assert await download_preview(str(mp), "") is False
|
assert await download_preview(str(mp), "") is None
|
||||||
assert await download_preview(str(mp), " ") is False
|
assert await download_preview(str(mp), " ") is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_successful_download_and_optimise(self, tmp_path):
|
async def test_successful_download_and_optimise(self, tmp_path):
|
||||||
@@ -246,12 +246,12 @@ class TestDownloadPreview:
|
|||||||
get_dl.return_value = dl
|
get_dl.return_value = dl
|
||||||
exif.optimize_image.return_value = (b"optimized_webp", {})
|
exif.optimize_image.return_value = (b"optimized_webp", {})
|
||||||
result = await download_preview(str(mp), "https://ex.com/i.png")
|
result = await download_preview(str(mp), "https://ex.com/i.png")
|
||||||
assert result is True
|
assert result == str(tmp_path / "t.webp")
|
||||||
assert (tmp_path / "t.webp").exists()
|
assert (tmp_path / "t.webp").exists()
|
||||||
assert (tmp_path / "t.webp").read_bytes() == b"optimized_webp"
|
assert (tmp_path / "t.webp").read_bytes() == b"optimized_webp"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_download_failure_returns_false(self, tmp_path):
|
async def test_download_failure_returns_none(self, tmp_path):
|
||||||
mp = tmp_path / "t.safetensors"
|
mp = tmp_path / "t.safetensors"
|
||||||
mp.write_bytes(b"fake")
|
mp.write_bytes(b"fake")
|
||||||
with mock.patch("py.services.downloader.get_downloader") as get_dl:
|
with mock.patch("py.services.downloader.get_downloader") as get_dl:
|
||||||
@@ -260,7 +260,7 @@ class TestDownloadPreview:
|
|||||||
dl.download_file = mock.AsyncMock(return_value=(False, None))
|
dl.download_file = mock.AsyncMock(return_value=(False, None))
|
||||||
get_dl.return_value = dl
|
get_dl.return_value = dl
|
||||||
result = await download_preview(str(mp), "https://ex.com/i.png")
|
result = await download_preview(str(mp), "https://ex.com/i.png")
|
||||||
assert result is False
|
assert result is None
|
||||||
assert not (tmp_path / "t.webp").exists()
|
assert not (tmp_path / "t.webp").exists()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ class TestProcessDispatch:
|
|||||||
mock.patch("py.agent_cli.refresh_cache") as mock_ref,
|
mock.patch("py.agent_cli.refresh_cache") as mock_ref,
|
||||||
):
|
):
|
||||||
mock_apply.return_value = ["metadata_source"]
|
mock_apply.return_value = ["metadata_source"]
|
||||||
mock_dl.return_value = False
|
mock_dl.return_value = None
|
||||||
|
|
||||||
result = await processor.process(
|
result = await processor.process(
|
||||||
skill_name="enrich_hf_metadata",
|
skill_name="enrich_hf_metadata",
|
||||||
@@ -155,7 +155,7 @@ class TestEnrichHfMetadata:
|
|||||||
llm = {**self.MIN_LLM_OUTPUT, "trigger_words": ["trigger1", "trigger2"]}
|
llm = {**self.MIN_LLM_OUTPUT, "trigger_words": ["trigger1", "trigger2"]}
|
||||||
with (
|
with (
|
||||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
mock.patch("py.agent_cli.download_preview", return_value=None),
|
||||||
mock.patch("py.agent_cli.refresh_cache"),
|
mock.patch("py.agent_cli.refresh_cache"),
|
||||||
):
|
):
|
||||||
await processor.process(
|
await processor.process(
|
||||||
@@ -165,7 +165,7 @@ class TestEnrichHfMetadata:
|
|||||||
metadata={"trainedWords": []},
|
metadata={"trainedWords": []},
|
||||||
)
|
)
|
||||||
applied = mock_apply.call_args[0][1]
|
applied = mock_apply.call_args[0][1]
|
||||||
assert applied["trainedWords"] == ["trigger1", "trigger2"]
|
assert applied["civitai"]["trainedWords"] == ["trigger1", "trigger2"]
|
||||||
|
|
||||||
# -- description -----------------------------------------------------
|
# -- description -----------------------------------------------------
|
||||||
|
|
||||||
@@ -237,7 +237,7 @@ class TestEnrichHfMetadata:
|
|||||||
mock.patch("py.agent_cli.download_preview") as mock_dl,
|
mock.patch("py.agent_cli.download_preview") as mock_dl,
|
||||||
mock.patch("py.agent_cli.refresh_cache"),
|
mock.patch("py.agent_cli.refresh_cache"),
|
||||||
):
|
):
|
||||||
mock_dl.return_value = True
|
mock_dl.return_value = "/p.webp"
|
||||||
result = await processor.process(
|
result = await processor.process(
|
||||||
skill_name="enrich_hf_metadata",
|
skill_name="enrich_hf_metadata",
|
||||||
model_path="/p.safetensors",
|
model_path="/p.safetensors",
|
||||||
@@ -246,6 +246,8 @@ class TestEnrichHfMetadata:
|
|||||||
)
|
)
|
||||||
assert result["preview_downloaded"] is True
|
assert result["preview_downloaded"] is True
|
||||||
mock_dl.assert_awaited_once_with("/p.safetensors", "https://ex.com/img.png")
|
mock_dl.assert_awaited_once_with("/p.safetensors", "https://ex.com/img.png")
|
||||||
|
applied = mock_apply.call_args[0][1]
|
||||||
|
assert applied["preview_url"] == "/p.webp"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_preview_skipped_when_exists(self, processor):
|
async def test_preview_skipped_when_exists(self, processor):
|
||||||
|
|||||||
Reference in New Issue
Block a user