mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-07-03 07:51:16 -03:00
feat(agent): add LLM-powered metadata enrichment system with AgentCLI and PostProcessor
Introduce an agent skill framework for LLM-driven metadata enrichment: - AgentCLI (py/agent_cli/): in-process wrappers around internal services using standard relative imports, eliminating the need for sys.path hacks - LLMService: centralized BYOK (bring-your-own-key) LLM client supporting OpenAI, Ollama, and custom OpenAI-compatible endpoints - PostProcessor: deterministic engine that applies LLM output via AgentCLI (replaces old handler.py + _BASE_MODEL_ALIASES approach) - SkillRegistry: filesystem-based skill discovery (skill.yaml + prompt.md) - AgentService: orchestrates skill execution with WebSocket progress - Frontend AgentManager: WebSocket listeners, skill execution, config UI - Context menu entries (single + bulk) for "Enrich Metadata (Agent)" - Settings UI for AI Provider configuration (BYOK) - Full i18n support across 9 locales Bug fixes found during review: - aiohttp.web.json_response: status_code= -> status= - settings_modal cancelEditApiKey: wrong argument position - AgentManager.isLlmConfigured: allow Ollama without API key - PostProcessor._merge_tags: lowercase all tags to match TagUpdateService
This commit is contained in:
23
py/services/agent/__init__.py
Normal file
23
py/services/agent/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Agent-powered skill system for LoRA Manager.
|
||||
|
||||
This package provides the orchestration layer for LLM/agent-powered features.
|
||||
Skills define *what* to do (prompt template). The :class:`AgentService`
|
||||
handles *how* (LLM calls, context gathering, validation, progress).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .skill_definition import SkillDefinition, SkillPermissions
|
||||
from .skill_registry import SkillRegistry
|
||||
from .agent_service import AgentService, AgentProgressReporter, SkillResult
|
||||
from .post_processor import PostProcessor
|
||||
|
||||
__all__ = [
|
||||
"AgentProgressReporter",
|
||||
"AgentService",
|
||||
"PostProcessor",
|
||||
"SkillDefinition",
|
||||
"SkillPermissions",
|
||||
"SkillRegistry",
|
||||
"SkillResult",
|
||||
]
|
||||
413
py/services/agent/agent_service.py
Normal file
413
py/services/agent/agent_service.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""Agent orchestration service.
|
||||
|
||||
The :class:`AgentService` coordinates skill execution:
|
||||
|
||||
1. Look up the skill in :class:`SkillRegistry`
|
||||
2. Validate input against the skill's ``input_schema``
|
||||
3. Prepare context via :mod:`~py.agent_cli` (read metadata, list base models, fetch HF README)
|
||||
4. If ``llm_required``: call :class:`LLMService` with the rendered prompt
|
||||
5. Post-process via :class:`PostProcessor` (delegates I/O to :mod:`~py.agent_cli`)
|
||||
6. Broadcast progress and completion via :class:`WebSocketManager`
|
||||
|
||||
Skills define *what* to do (prompt template). The AgentService handles *how*
|
||||
(LLM calls, context gathering, validation, progress).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..llm_service import LLMService
|
||||
from ..websocket_manager import ws_manager
|
||||
from .post_processor import PostProcessor
|
||||
from .skill_registry import SkillRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentProgressReporter:
|
||||
"""Protocol-compatible progress reporter backed by WebSocket broadcast."""
|
||||
|
||||
async def on_progress(self, payload: Dict[str, Any]) -> None:
|
||||
await ws_manager.broadcast(payload)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillResult:
|
||||
"""Outcome of a skill execution."""
|
||||
|
||||
success: bool
|
||||
updated_models: List[Dict[str, Any]] = field(default_factory=list)
|
||||
errors: List[str] = field(default_factory=list)
|
||||
summary: str = ""
|
||||
|
||||
|
||||
def _validate_schema(data: Any, schema: Dict[str, Any], path: str = "") -> List[str]:
|
||||
"""Minimal JSON schema validator.
|
||||
|
||||
Supports a subset of JSON Schema: ``type``, ``properties``, ``required``,
|
||||
``items``, ``enum``. Returns a list of error messages (empty = valid).
|
||||
"""
|
||||
|
||||
errors: List[str] = []
|
||||
if not schema:
|
||||
return errors
|
||||
|
||||
expected_type = schema.get("type")
|
||||
if expected_type:
|
||||
type_map = {
|
||||
"string": str,
|
||||
"number": (int, float),
|
||||
"integer": int,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
"null": type(None),
|
||||
}
|
||||
expected_py = type_map.get(expected_type)
|
||||
if expected_py is not None and not isinstance(data, expected_py):
|
||||
errors.append(f"{path or 'root'}: expected {expected_type}, got {type(data).__name__}")
|
||||
return errors
|
||||
|
||||
if expected_type == "object" and isinstance(data, dict):
|
||||
properties = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
for req_key in required:
|
||||
if req_key not in data:
|
||||
errors.append(f"{path or 'root'}: missing required property '{req_key}'")
|
||||
for key, value in data.items():
|
||||
if key in properties:
|
||||
errors.extend(_validate_schema(value, properties[key], f"{path}.{key}"))
|
||||
|
||||
if expected_type == "array" and isinstance(data, list):
|
||||
items_schema = schema.get("items")
|
||||
if items_schema:
|
||||
for i, item in enumerate(data):
|
||||
errors.extend(_validate_schema(item, items_schema, f"{path}[{i}]"))
|
||||
|
||||
if "enum" in schema and data not in schema["enum"]:
|
||||
errors.append(f"{path or 'root'}: value '{data}' not in enum {schema['enum']}")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt template rendering
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _render_prompt(template: str, variables: Dict[str, Any]) -> str:
|
||||
"""Render a prompt template with ``{{variable}}`` placeholders.
|
||||
|
||||
Uses simple regex substitution — no Jinja2 dependency needed.
|
||||
"""
|
||||
|
||||
def replace(match: re.Match) -> str:
|
||||
key = match.group(1).strip()
|
||||
value = variables.get(key, "")
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value, ensure_ascii=False, indent=2)
|
||||
return str(value)
|
||||
|
||||
return re.sub(r"\{\{(\w+)\}\}", replace, template)
|
||||
|
||||
|
||||
class AgentService:
|
||||
"""Orchestrate agent skill execution.
|
||||
|
||||
Usage::
|
||||
|
||||
service = await AgentService.get_instance()
|
||||
result = await service.execute_skill(
|
||||
skill_name="enrich_hf_metadata",
|
||||
input_data={"model_paths": ["/path/to/model.safetensors"]},
|
||||
progress_callback=AgentProgressReporter(),
|
||||
)
|
||||
"""
|
||||
|
||||
_instance: Optional["AgentService"] = None
|
||||
_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
skill_registry: Optional[SkillRegistry] = None,
|
||||
llm_service: Optional[LLMService] = None,
|
||||
) -> None:
|
||||
self._registry = skill_registry
|
||||
self._llm_service = llm_service
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> "AgentService":
|
||||
"""Return the lazily-initialised global ``AgentService``."""
|
||||
|
||||
if cls._instance is None:
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls(
|
||||
skill_registry=await SkillRegistry.get_instance(),
|
||||
llm_service=await LLMService.get_instance(),
|
||||
)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""Reset the cached singleton — primarily for tests."""
|
||||
|
||||
cls._instance = None
|
||||
|
||||
async def _ensure_registry(self) -> SkillRegistry:
|
||||
if self._registry is None:
|
||||
self._registry = await SkillRegistry.get_instance()
|
||||
return self._registry
|
||||
|
||||
async def _ensure_llm(self) -> LLMService:
|
||||
if self._llm_service is None:
|
||||
self._llm_service = await LLMService.get_instance()
|
||||
return self._llm_service
|
||||
|
||||
async def list_skills(self) -> List[Dict[str, Any]]:
|
||||
"""Return a JSON-serialisable list of available skills."""
|
||||
|
||||
registry = await self._ensure_registry()
|
||||
return [
|
||||
{
|
||||
"name": s.name,
|
||||
"title": s.title,
|
||||
"description": s.description,
|
||||
"llm_required": s.llm_required,
|
||||
"model_type_filter": s.model_type_filter,
|
||||
}
|
||||
for s in registry.list_skills()
|
||||
]
|
||||
|
||||
async def execute_skill(
|
||||
self,
|
||||
*,
|
||||
skill_name: str,
|
||||
input_data: Dict[str, Any],
|
||||
progress_callback: Optional[AgentProgressReporter] = None,
|
||||
) -> SkillResult:
|
||||
"""Execute an agent skill.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill to execute
|
||||
input_data: Input validated against the skill's ``input_schema``
|
||||
progress_callback: Optional WebSocket progress reporter
|
||||
|
||||
Returns:
|
||||
:class:`SkillResult` with success status and updated model info
|
||||
"""
|
||||
|
||||
registry = await self._ensure_registry()
|
||||
logger.info("execute_skill '%s': looking up skill", skill_name)
|
||||
skill = registry.get_skill(skill_name)
|
||||
if skill is None:
|
||||
return SkillResult(
|
||||
success=False,
|
||||
errors=[f"Skill not found: {skill_name}"],
|
||||
summary=f"Skill '{skill_name}' does not exist",
|
||||
)
|
||||
|
||||
input_errors = _validate_schema(input_data, skill.input_schema)
|
||||
if input_errors:
|
||||
return SkillResult(
|
||||
success=False,
|
||||
errors=input_errors,
|
||||
summary=f"Invalid input: {'; '.join(input_errors)}",
|
||||
)
|
||||
|
||||
model_paths = input_data.get("model_paths", [])
|
||||
if not model_paths:
|
||||
return SkillResult(
|
||||
success=False,
|
||||
errors=["No model_paths provided"],
|
||||
summary="No models to process",
|
||||
)
|
||||
|
||||
total = len(model_paths)
|
||||
processed = 0
|
||||
success_count = 0
|
||||
updated_models: List[Dict[str, Any]] = []
|
||||
errors: List[str] = []
|
||||
post_processor = PostProcessor()
|
||||
|
||||
logger.info("execute_skill '%s': starting with %d model(s)", skill_name, total)
|
||||
await self._emit_progress(
|
||||
progress_callback, skill_name, status="started",
|
||||
total=total, processed=0, success=0,
|
||||
)
|
||||
|
||||
llm = await self._ensure_llm()
|
||||
llm_configured = llm.is_configured() if skill.llm_required else True
|
||||
|
||||
for model_path in model_paths:
|
||||
logger.info(
|
||||
"execute_skill '%s': processing model %d/%d: %s",
|
||||
skill_name, processed + 1, total, model_path,
|
||||
)
|
||||
try:
|
||||
from ...agent_cli import read_metadata
|
||||
metadata = await read_metadata(model_path)
|
||||
|
||||
prompt_vars: Dict[str, Any] = {"model_path": model_path}
|
||||
if skill.llm_required and llm_configured:
|
||||
prompt_vars = await self._build_prompt_context(
|
||||
skill_name, model_path, metadata, registry, llm,
|
||||
)
|
||||
|
||||
llm_response: Optional[Dict[str, Any]] = None
|
||||
if skill.llm_required and llm_configured:
|
||||
prompt_template = registry.load_prompt(skill_name)
|
||||
rendered = _render_prompt(prompt_template, prompt_vars)
|
||||
logger.info(
|
||||
"execute_skill '%s': LLM call for %s (prompt=%d chars)",
|
||||
skill_name, model_path, len(rendered),
|
||||
)
|
||||
llm_response = await llm.chat_completion_json(
|
||||
system_prompt=prompt_vars.get(
|
||||
"system_prompt",
|
||||
"You are a helpful assistant that extracts structured metadata.",
|
||||
),
|
||||
user_prompt=rendered,
|
||||
)
|
||||
|
||||
model_result = await post_processor.process(
|
||||
skill_name=skill_name,
|
||||
model_path=model_path,
|
||||
llm_output=llm_response or {},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
if model_result.get("success", True):
|
||||
success_count += 1
|
||||
uf = model_result.get("updated_fields", [])
|
||||
if uf:
|
||||
updated_models.append({"path": model_path, "updated_fields": uf})
|
||||
else:
|
||||
errors.extend(
|
||||
model_result.get("errors", [model_result.get("error", "Unknown error")])
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Skill %s failed for %s: %s", skill_name, model_path, exc)
|
||||
errors.append(f"{model_path}: {exc}")
|
||||
|
||||
processed += 1
|
||||
await self._emit_progress(
|
||||
progress_callback, skill_name, status="processing",
|
||||
total=total, processed=processed, success=success_count,
|
||||
current_path=model_path,
|
||||
)
|
||||
|
||||
result = SkillResult(
|
||||
success=success_count > 0,
|
||||
updated_models=updated_models,
|
||||
errors=errors,
|
||||
summary=f"Processed {processed}/{total} models, {success_count} succeeded",
|
||||
)
|
||||
|
||||
logger.info("execute_skill '%s': done — %s", skill_name, result.summary)
|
||||
await self._emit_progress(
|
||||
progress_callback, skill_name, status="completed",
|
||||
total=total, processed=processed, success=success_count,
|
||||
updated_models=updated_models, errors=errors, summary=result.summary,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _build_prompt_context(
|
||||
self,
|
||||
skill_name: str,
|
||||
model_path: str,
|
||||
metadata: Dict[str, Any],
|
||||
registry: SkillRegistry,
|
||||
llm: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Gather variables for the skill's prompt template.
|
||||
|
||||
Reads metadata, fetches the HF README (if applicable), lists available
|
||||
base models, and returns a dict that maps to ``{{variable}}``
|
||||
placeholders in ``prompt.md``.
|
||||
"""
|
||||
from ...agent_cli import list_base_models
|
||||
|
||||
context: Dict[str, Any] = {
|
||||
"model_path": model_path,
|
||||
"hf_url": "",
|
||||
"repo": "",
|
||||
"readme_content": "",
|
||||
"current_metadata": {},
|
||||
"base_models": [],
|
||||
}
|
||||
|
||||
context["current_metadata"] = {
|
||||
"file_name": metadata.get("file_name", ""),
|
||||
"base_model": metadata.get("base_model", ""),
|
||||
"tags": metadata.get("tags", []),
|
||||
"modelDescription": metadata.get("modelDescription", ""),
|
||||
"trainedWords": metadata.get("trainedWords", []),
|
||||
"sha256": (metadata.get("sha256") or "")[:16] + "..." if metadata.get("sha256") else "",
|
||||
"size": metadata.get("size", 0),
|
||||
}
|
||||
|
||||
hf_url = metadata.get("hf_url", "")
|
||||
context["hf_url"] = hf_url
|
||||
repo = self._extract_repo_from_url(hf_url) if hf_url else ""
|
||||
context["repo"] = repo or ""
|
||||
if repo:
|
||||
readme = await self._fetch_readme(repo)
|
||||
context["readme_content"] = readme[:8000] if readme else "(README not available)"
|
||||
|
||||
try:
|
||||
context["base_models"] = await list_base_models()
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to list base models: %s", exc)
|
||||
|
||||
return context
|
||||
|
||||
@staticmethod
|
||||
def _extract_repo_from_url(hf_url: str) -> Optional[str]:
|
||||
"""Extract ``user/repo`` from a HuggingFace URL."""
|
||||
if not hf_url:
|
||||
return None
|
||||
m = re.match(r"https?://huggingface\.co/([^/]+/[^/]+)", hf_url)
|
||||
return m.group(1) if m else None
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_readme(repo: str) -> str:
|
||||
"""Fetch README.md from HuggingFace (tries ``main``, then ``master``)."""
|
||||
async with aiohttp.ClientSession(
|
||||
headers={"User-Agent": "ComfyUI-LoRA-Manager/1.0"},
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
) as session:
|
||||
for branch in ("main", "master"):
|
||||
url = f"https://huggingface.co/{repo}/raw/{branch}/README.md"
|
||||
try:
|
||||
async with session.get(url) as resp:
|
||||
if resp.status == 200:
|
||||
return await resp.text()
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to fetch README from %s: %s", url, exc)
|
||||
return ""
|
||||
|
||||
async def _emit_progress(
|
||||
self,
|
||||
callback: Optional[AgentProgressReporter],
|
||||
skill_name: str,
|
||||
*,
|
||||
status: str,
|
||||
**extra: Any,
|
||||
) -> None:
|
||||
"""Send a progress update via WebSocket (if callback is set)."""
|
||||
payload: Dict[str, Any] = {"type": "agent_progress", "skill": skill_name, "status": status}
|
||||
payload.update(extra)
|
||||
if callback is not None:
|
||||
await callback.on_progress(payload)
|
||||
168
py/services/agent/post_processor.py
Normal file
168
py/services/agent/post_processor.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Post-processing engine for agent skill outputs.
|
||||
|
||||
The :class:`PostProcessor` takes the LLM's structured JSON output and applies
|
||||
it to a model's on-disk metadata via the :mod:`~py.agent_cli` functions.
|
||||
|
||||
It handles all the skill-specific business logic — conditions, transformations,
|
||||
and orchestration of multiple side-effects (write metadata, download preview,
|
||||
refresh cache). All actual I/O is delegated to :mod:`~py.agent_cli`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostProcessor:
|
||||
"""Deterministic post-processor for agent skill outputs.
|
||||
|
||||
Usage (called by :class:`~py.services.agent.agent_service.AgentService`)::
|
||||
|
||||
processor = PostProcessor()
|
||||
result = await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/path/to/model.safetensors",
|
||||
llm_output={...},
|
||||
metadata={...}, # from agent_cli.read_metadata()
|
||||
)
|
||||
"""
|
||||
|
||||
async def process(
|
||||
self,
|
||||
*,
|
||||
skill_name: str,
|
||||
model_path: str,
|
||||
llm_output: Dict[str, Any],
|
||||
metadata: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Route *llm_output* to the correct skill post-processor.
|
||||
|
||||
Returns a dict with keys ``success`` (bool), ``updated_fields`` (list),
|
||||
``preview_downloaded`` (bool), and ``errors`` (list).
|
||||
"""
|
||||
if skill_name == "enrich_hf_metadata":
|
||||
return await self._process_enrich_hf_metadata(
|
||||
model_path, llm_output, metadata,
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"updated_fields": [],
|
||||
"errors": [f"No post-processor registered for skill: {skill_name}"],
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# enrich_hf_metadata
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _process_enrich_hf_metadata(
|
||||
self,
|
||||
model_path: str,
|
||||
llm_output: Dict[str, Any],
|
||||
metadata: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
from ...agent_cli import (
|
||||
apply_metadata_updates,
|
||||
download_preview,
|
||||
refresh_cache,
|
||||
)
|
||||
|
||||
updated_fields: List[str] = []
|
||||
preview_downloaded = False
|
||||
|
||||
# -- Determine whether this is an HF-sourced model -----------------
|
||||
is_hf_model = not metadata.get("from_civitai", True)
|
||||
|
||||
# -- Collect updates -----------------------------------------------
|
||||
updates: Dict[str, Any] = {}
|
||||
|
||||
# base_model
|
||||
new_base = (llm_output.get("base_model") or "").strip()
|
||||
current_base = metadata.get("base_model", "") or ""
|
||||
if new_base and self._should_overwrite(current_base, is_hf_model):
|
||||
updates["base_model"] = new_base
|
||||
|
||||
# trainedWords / trigger words
|
||||
new_triggers = llm_output.get("trigger_words", [])
|
||||
if isinstance(new_triggers, list):
|
||||
cleaned = [t.strip() for t in new_triggers if t.strip()]
|
||||
if cleaned:
|
||||
current_triggers = metadata.get("trainedWords") or []
|
||||
if self._should_overwrite_list(current_triggers, is_hf_model):
|
||||
updates["trainedWords"] = cleaned
|
||||
|
||||
# modelDescription
|
||||
new_desc = (llm_output.get("description") or "").strip()
|
||||
if new_desc:
|
||||
current_desc = metadata.get("modelDescription", "") or ""
|
||||
if self._should_overwrite(current_desc, is_hf_model):
|
||||
updates["modelDescription"] = new_desc
|
||||
|
||||
# tags — merge with existing, deduplicate (case-insensitive)
|
||||
new_tags = llm_output.get("tags", [])
|
||||
if isinstance(new_tags, list) and new_tags:
|
||||
existing_tags = metadata.get("tags") or []
|
||||
merged = self._merge_tags(existing_tags, new_tags)
|
||||
if len(merged) > len(existing_tags) or is_hf_model:
|
||||
updates["tags"] = merged
|
||||
|
||||
# metadata_source & llm_enriched_at (always set)
|
||||
updates["metadata_source"] = "agent:enrich_hf_metadata"
|
||||
updates["llm_enriched_at"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# -- Persist updates ------------------------------------------------
|
||||
if updates:
|
||||
updated_fields = await apply_metadata_updates(model_path, updates)
|
||||
|
||||
# -- Download preview -----------------------------------------------
|
||||
preview_url = (llm_output.get("preview_url") or "").strip()
|
||||
current_preview = metadata.get("preview_url") or ""
|
||||
if preview_url and not (current_preview and os.path.exists(current_preview)):
|
||||
preview_downloaded = await download_preview(model_path, preview_url)
|
||||
|
||||
# -- Refresh scanner cache ------------------------------------------
|
||||
if updated_fields or preview_downloaded:
|
||||
await refresh_cache(model_path)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"updated_fields": updated_fields,
|
||||
"preview_downloaded": preview_downloaded,
|
||||
"errors": [],
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _should_overwrite(current_value: str, is_hf_model: bool) -> bool:
|
||||
"""Return ``True`` when a scalar field should be overwritten."""
|
||||
return is_hf_model or not current_value or current_value.lower() in (
|
||||
"", "unknown",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _should_overwrite_list(current_list: List[str], is_hf_model: bool) -> bool:
|
||||
"""Return ``True`` when a list field should be overwritten."""
|
||||
return is_hf_model or not current_list
|
||||
|
||||
@staticmethod
|
||||
def _merge_tags(existing: List[str], new: List[str]) -> List[str]:
|
||||
"""Merge *new* tags into *existing*, all lowercased.
|
||||
|
||||
This matches the behaviour of :class:`TagUpdateService` which
|
||||
normalises every tag to lowercase for case-insensitive dedup.
|
||||
"""
|
||||
merged: List[str] = []
|
||||
seen: set = set()
|
||||
for tag in list(existing) + list(new):
|
||||
t = tag.strip().lower()
|
||||
if t and t not in seen:
|
||||
merged.append(t)
|
||||
seen.add(t)
|
||||
return merged
|
||||
45
py/services/agent/skill_definition.py
Normal file
45
py/services/agent/skill_definition.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Skill definition data structures.
|
||||
|
||||
Each skill is described by a :class:`SkillDefinition` that declares its
|
||||
input/output schemas, whether it needs an LLM call, and what permissions
|
||||
its post-processor has.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SkillPermissions:
|
||||
"""Declarative permission scope for a skill's post-processor.
|
||||
|
||||
These are auditable constraints — the :class:`AgentService` checks them
|
||||
before invoking the handler. They are defense-in-depth, not a sandbox.
|
||||
"""
|
||||
|
||||
write_metadata: bool = True
|
||||
write_previews: bool = True
|
||||
network_domains: Tuple[str, ...] = ()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SkillDefinition:
|
||||
"""Immutable description of an agent skill."""
|
||||
|
||||
name: str
|
||||
title: str
|
||||
description: str
|
||||
llm_required: bool
|
||||
input_schema: Dict[str, Any] = field(default_factory=dict)
|
||||
output_schema: Dict[str, Any] = field(default_factory=dict)
|
||||
model_type_filter: Optional[List[str]] = None
|
||||
permissions: SkillPermissions = field(default_factory=SkillPermissions)
|
||||
|
||||
def applies_to_model_type(self, model_type: str) -> bool:
|
||||
"""Return ``True`` if this skill can run on the given model type."""
|
||||
|
||||
if self.model_type_filter is None:
|
||||
return True
|
||||
return model_type in self.model_type_filter
|
||||
184
py/services/agent/skill_registry.py
Normal file
184
py/services/agent/skill_registry.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Discovery and loading of agent skills.
|
||||
|
||||
Skills live in ``py/services/agent/skills/<name>/`` directories. Each
|
||||
directory must contain:
|
||||
|
||||
- ``skill.yaml`` — metadata (name, title, description, schemas, permissions)
|
||||
- ``prompt.md`` — LLM system prompt template (Jinja2-style ``{{variable}}`` placeholders)
|
||||
- ``handler.py`` — async ``prepare`` and ``post_process`` functions
|
||||
|
||||
The registry scans the skills directory on first access and caches results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from .skill_definition import SkillDefinition, SkillPermissions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Directory where built-in skills are stored
|
||||
_SKILLS_DIR = Path(__file__).parent / "skills"
|
||||
|
||||
|
||||
class SkillRegistry:
|
||||
"""Discover and load agent skills from the filesystem."""
|
||||
|
||||
_instance: Optional["SkillRegistry"] = None
|
||||
_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
def __init__(self, skills_dir: Path = _SKILLS_DIR) -> None:
|
||||
self._skills_dir = skills_dir
|
||||
self._skills: Dict[str, SkillDefinition] = {}
|
||||
self._loaded: bool = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Singleton access
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> "SkillRegistry":
|
||||
"""Return the lazily-initialised global ``SkillRegistry``."""
|
||||
|
||||
if cls._instance is None:
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
registry = cls()
|
||||
registry._discover()
|
||||
cls._instance = registry
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""Reset the cached singleton — primarily for tests."""
|
||||
|
||||
cls._instance = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Discovery
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _discover(self) -> None:
|
||||
"""Scan the skills directory and load all valid skill definitions."""
|
||||
|
||||
self._skills.clear()
|
||||
if not self._skills_dir.is_dir():
|
||||
logger.warning("Skills directory does not exist: %s", self._skills_dir)
|
||||
self._loaded = True
|
||||
return
|
||||
|
||||
for entry in sorted(self._skills_dir.iterdir()):
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
skill_yaml = entry / "skill.yaml"
|
||||
if not skill_yaml.exists():
|
||||
continue
|
||||
try:
|
||||
definition = self._load_skill_yaml(skill_yaml)
|
||||
if definition is not None:
|
||||
self._skills[definition.name] = definition
|
||||
logger.debug("Loaded skill: %s", definition.name)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to load skill from %s: %s", skill_yaml, exc)
|
||||
|
||||
self._loaded = True
|
||||
logger.info("Discovered %d agent skills", len(self._skills))
|
||||
|
||||
def _load_skill_yaml(self, path: Path) -> Optional[SkillDefinition]:
|
||||
"""Parse a skill.yaml file into a :class:`SkillDefinition`."""
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
if not data or "name" not in data:
|
||||
logger.warning("skill.yaml missing required 'name' field: %s", path)
|
||||
return None
|
||||
|
||||
# Parse permissions
|
||||
perm_data = data.get("permissions", {})
|
||||
permissions = SkillPermissions(
|
||||
write_metadata=perm_data.get("write_metadata", True),
|
||||
write_previews=perm_data.get("write_previews", True),
|
||||
network_domains=tuple(perm_data.get("network_domains", [])),
|
||||
)
|
||||
|
||||
return SkillDefinition(
|
||||
name=data["name"],
|
||||
title=data.get("title", data["name"]),
|
||||
description=data.get("description", ""),
|
||||
llm_required=data.get("llm_required", False),
|
||||
input_schema=data.get("input_schema", {}),
|
||||
output_schema=data.get("output_schema", {}),
|
||||
model_type_filter=data.get("model_type_filter"),
|
||||
permissions=permissions,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_skills(self) -> List[SkillDefinition]:
|
||||
"""Return all discovered skill definitions."""
|
||||
|
||||
if not self._loaded:
|
||||
self._discover()
|
||||
return list(self._skills.values())
|
||||
|
||||
def get_skill(self, name: str) -> Optional[SkillDefinition]:
|
||||
"""Return the skill definition for ``name``, or ``None`` if not found."""
|
||||
|
||||
if not self._loaded:
|
||||
self._discover()
|
||||
return self._skills.get(name)
|
||||
|
||||
def load_prompt(self, name: str) -> str:
|
||||
"""Load and return the prompt template for a skill."""
|
||||
|
||||
skill_dir = self._skills_dir / name
|
||||
prompt_path = skill_dir / "prompt.md"
|
||||
if not prompt_path.exists():
|
||||
raise FileNotFoundError(f"Prompt template not found: {prompt_path}")
|
||||
with open(prompt_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
def load_handler(self, name: str) -> Dict[str, Callable]:
|
||||
"""Dynamically import a skill's handler module and return its functions.
|
||||
|
||||
Returns a dict with ``prepare`` and ``post_process`` callables.
|
||||
``prepare`` may be absent (the skill doesn't need pre-LLM data gathering).
|
||||
"""
|
||||
|
||||
skill_dir = self._skills_dir / name
|
||||
handler_path = skill_dir / "handler.py"
|
||||
if not handler_path.exists():
|
||||
raise FileNotFoundError(f"Handler not found: {handler_path}")
|
||||
|
||||
# Use importlib to load the module by file path
|
||||
# Important: use a fully-qualified module name so that absolute imports
|
||||
# (e.g. ``from py.utils.metadata_manager import MetadataManager``) resolve correctly.
|
||||
module_name = f"py.services.agent.skills.{name}.handler"
|
||||
spec = importlib.util.spec_from_file_location(module_name, handler_path)
|
||||
if spec is None or spec.loader is None:
|
||||
raise ImportError(f"Cannot load handler module from {handler_path}")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
result: Dict[str, Callable] = {}
|
||||
if hasattr(module, "prepare"):
|
||||
result["prepare"] = module.prepare
|
||||
if hasattr(module, "post_process"):
|
||||
result["post_process"] = module.post_process
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"Skill handler {name} is missing required 'post_process' function"
|
||||
)
|
||||
return result
|
||||
1
py/services/agent/skills/__init__.py
Normal file
1
py/services/agent/skills/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Agent skills package — each subdirectory is a skill.
|
||||
77
py/services/agent/skills/enrich_hf_metadata/prompt.md
Normal file
77
py/services/agent/skills/enrich_hf_metadata/prompt.md
Normal file
@@ -0,0 +1,77 @@
|
||||
You are an expert assistant for AI image generation models. Your task is to extract structured metadata from a HuggingFace model card (README.md).
|
||||
|
||||
## Model Information
|
||||
|
||||
- **Repository**: {{hf_url}}
|
||||
- **Model file path**: {{model_path}}
|
||||
- **Repository ID**: {{repo}}
|
||||
|
||||
## Current Metadata (may be incomplete)
|
||||
|
||||
```json
|
||||
{{current_metadata}}
|
||||
```
|
||||
|
||||
## HuggingFace README Content
|
||||
|
||||
```
|
||||
{{readme_content}}
|
||||
```
|
||||
|
||||
## Extraction Instructions
|
||||
|
||||
Extract the following information from the README content above:
|
||||
|
||||
### base_model
|
||||
The base model this LoRA/checkpoint was trained on. Use EXACTLY one of the names from the **Available Base Models** list below. Do not invent new names or use aliases.
|
||||
|
||||
Available Base Models:
|
||||
{{base_models}}
|
||||
|
||||
Check the YAML frontmatter (between --- markers) for `base_model:` first, then look at the description text and safetensors metadata. If you cannot determine it, return an empty string.
|
||||
|
||||
### trigger_words
|
||||
The trigger words or activation prompts needed to use this LoRA. Look for:
|
||||
- `instance_prompt:` in the YAML frontmatter
|
||||
- Phrases like "trigger word:", "trigger:", "use this prompt:", "activation prompt:"
|
||||
- Example prompts at the start (usually the first word or phrase before any description)
|
||||
Return as an array of strings. If none found, return an empty array.
|
||||
|
||||
### description
|
||||
A concise 1-2 sentence summary of what this model does. Extract from the "Model description" section or the first paragraph. Return empty string if the README is too minimal.
|
||||
|
||||
### tags
|
||||
3-8 relevant tags for categorizing this model. Extract from:
|
||||
- The YAML frontmatter `tags:` list (often contains excellent categorization tags)
|
||||
- The model type (e.g. "lora", "checkpoint", "flux", "sdxl")
|
||||
- The style/subject (e.g. "anime", "photorealistic", "style", "character")
|
||||
All lowercase, no spaces. Return empty array if none found.
|
||||
|
||||
### preview_url
|
||||
The URL of the most suitable preview image from the README. Look for image tags (e.g. ``) and the YAML frontmatter `widget:` section (which often has `output.url` fields). Choose the first image that appears to be a generation example (not a logo or diagram). Construct the absolute URL as `https://huggingface.co/{{repo}}/resolve/main/{filename}`. If no suitable image is found, return an empty string.
|
||||
|
||||
### confidence
|
||||
Your confidence level in the extracted data:
|
||||
- "high" — most fields were explicitly stated in the README
|
||||
- "medium" — some fields were inferred from context
|
||||
- "low" — most fields are guesses based on limited information
|
||||
|
||||
## Output Format
|
||||
|
||||
Return ONLY a JSON object with exactly these fields (no markdown fences, no extra text):
|
||||
|
||||
{
|
||||
"model_path": "{{model_path}}",
|
||||
"base_model": "<canonical name or empty string>",
|
||||
"trigger_words": ["<word1>", "<word2>"],
|
||||
"description": "<1-2 sentence summary>",
|
||||
"tags": ["<tag1>", "<tag2>"],
|
||||
"preview_url": "<image URL or empty string>",
|
||||
"confidence": "<high|medium|low>"
|
||||
}
|
||||
|
||||
Important:
|
||||
- Only include the JSON object, no other text
|
||||
- If a field cannot be determined, use an empty string or empty array
|
||||
- Do not fabricate information not supported by the README
|
||||
- For base_model, the YAML frontmatter often has `base_model:` with a HuggingFace repo name like "black-forest-labs/FLUX.1-dev" — map this to "Flux.1 D"
|
||||
47
py/services/agent/skills/enrich_hf_metadata/skill.yaml
Normal file
47
py/services/agent/skills/enrich_hf_metadata/skill.yaml
Normal file
@@ -0,0 +1,47 @@
|
||||
name: enrich_hf_metadata
|
||||
title: "Enrich Metadata from HuggingFace"
|
||||
description: >
|
||||
Parse the HuggingFace model card via LLM to extract description, trigger
|
||||
words, base model, tags, and preview image URL. Updates .metadata.json
|
||||
and downloads the preview thumbnail.
|
||||
llm_required: true
|
||||
model_type_filter: ["lora", "checkpoint", "embedding"]
|
||||
input_schema:
|
||||
type: object
|
||||
properties:
|
||||
model_paths:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required:
|
||||
- model_paths
|
||||
output_schema:
|
||||
type: object
|
||||
properties:
|
||||
model_path:
|
||||
type: string
|
||||
base_model:
|
||||
type: string
|
||||
trigger_words:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description:
|
||||
type: string
|
||||
tags:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
preview_url:
|
||||
type: string
|
||||
confidence:
|
||||
type: string
|
||||
enum: ["high", "medium", "low"]
|
||||
required:
|
||||
- model_path
|
||||
- confidence
|
||||
permissions:
|
||||
write_metadata: true
|
||||
write_previews: true
|
||||
network_domains:
|
||||
- "huggingface.co"
|
||||
Reference in New Issue
Block a user