mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-07-02 15:31:17 -03:00
Introduce an agent skill framework for LLM-driven metadata enrichment: - AgentCLI (py/agent_cli/): in-process wrappers around internal services using standard relative imports, eliminating the need for sys.path hacks - LLMService: centralized BYOK (bring-your-own-key) LLM client supporting OpenAI, Ollama, and custom OpenAI-compatible endpoints - PostProcessor: deterministic engine that applies LLM output via AgentCLI (replaces old handler.py + _BASE_MODEL_ALIASES approach) - SkillRegistry: filesystem-based skill discovery (skill.yaml + prompt.md) - AgentService: orchestrates skill execution with WebSocket progress - Frontend AgentManager: WebSocket listeners, skill execution, config UI - Context menu entries (single + bulk) for "Enrich Metadata (Agent)" - Settings UI for AI Provider configuration (BYOK) - Full i18n support across 9 locales Bug fixes found during review: - aiohttp.web.json_response: status_code= -> status= - settings_modal cancelEditApiKey: wrong argument position - AgentManager.isLlmConfigured: allow Ollama without API key - PostProcessor._merge_tags: lowercase all tags to match TagUpdateService
414 lines
14 KiB
Python
414 lines
14 KiB
Python
"""Agent orchestration service.
|
|
|
|
The :class:`AgentService` coordinates skill execution:
|
|
|
|
1. Look up the skill in :class:`SkillRegistry`
|
|
2. Validate input against the skill's ``input_schema``
|
|
3. Prepare context via :mod:`~py.agent_cli` (read metadata, list base models, fetch HF README)
|
|
4. If ``llm_required``: call :class:`LLMService` with the rendered prompt
|
|
5. Post-process via :class:`PostProcessor` (delegates I/O to :mod:`~py.agent_cli`)
|
|
6. Broadcast progress and completion via :class:`WebSocketManager`
|
|
|
|
Skills define *what* to do (prompt template). The AgentService handles *how*
|
|
(LLM calls, context gathering, validation, progress).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import aiohttp
|
|
|
|
from ..llm_service import LLMService
|
|
from ..websocket_manager import ws_manager
|
|
from .post_processor import PostProcessor
|
|
from .skill_registry import SkillRegistry
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AgentProgressReporter:
    """Forward agent progress payloads to all WebSocket clients.

    Implements the single ``on_progress`` coroutine that
    ``AgentService._emit_progress`` awaits for each progress event.
    """

    async def on_progress(self, payload: Dict[str, Any]) -> None:
        """Broadcast *payload*, unmodified, through the shared ``ws_manager``."""
        broadcast = ws_manager.broadcast
        await broadcast(payload)
|
|
|
|
|
|
@dataclass
class SkillResult:
    """Aggregate outcome returned by ``AgentService.execute_skill``."""

    # Overall status; execute_skill sets it True when at least one model succeeded.
    success: bool
    # Entries shaped like {"path": ..., "updated_fields": [...]} for changed models.
    updated_models: List[Dict[str, Any]] = field(default_factory=list)
    # Human-readable error messages gathered while running the skill.
    errors: List[str] = field(default_factory=list)
    # One-line description of what happened.
    summary: str = field(default="")
|
|
|
|
|
|
def _validate_schema(data: Any, schema: Dict[str, Any], path: str = "") -> List[str]:
|
|
"""Minimal JSON schema validator.
|
|
|
|
Supports a subset of JSON Schema: ``type``, ``properties``, ``required``,
|
|
``items``, ``enum``. Returns a list of error messages (empty = valid).
|
|
"""
|
|
|
|
errors: List[str] = []
|
|
if not schema:
|
|
return errors
|
|
|
|
expected_type = schema.get("type")
|
|
if expected_type:
|
|
type_map = {
|
|
"string": str,
|
|
"number": (int, float),
|
|
"integer": int,
|
|
"boolean": bool,
|
|
"array": list,
|
|
"object": dict,
|
|
"null": type(None),
|
|
}
|
|
expected_py = type_map.get(expected_type)
|
|
if expected_py is not None and not isinstance(data, expected_py):
|
|
errors.append(f"{path or 'root'}: expected {expected_type}, got {type(data).__name__}")
|
|
return errors
|
|
|
|
if expected_type == "object" and isinstance(data, dict):
|
|
properties = schema.get("properties", {})
|
|
required = schema.get("required", [])
|
|
for req_key in required:
|
|
if req_key not in data:
|
|
errors.append(f"{path or 'root'}: missing required property '{req_key}'")
|
|
for key, value in data.items():
|
|
if key in properties:
|
|
errors.extend(_validate_schema(value, properties[key], f"{path}.{key}"))
|
|
|
|
if expected_type == "array" and isinstance(data, list):
|
|
items_schema = schema.get("items")
|
|
if items_schema:
|
|
for i, item in enumerate(data):
|
|
errors.extend(_validate_schema(item, items_schema, f"{path}[{i}]"))
|
|
|
|
if "enum" in schema and data not in schema["enum"]:
|
|
errors.append(f"{path or 'root'}: value '{data}' not in enum {schema['enum']}")
|
|
|
|
return errors
|
|
|
|
|
|
# ------------------------------------------------------------------
|
|
# Prompt template rendering
|
|
# ------------------------------------------------------------------
|
|
|
|
|
|
def _render_prompt(template: str, variables: Dict[str, Any]) -> str:
|
|
"""Render a prompt template with ``{{variable}}`` placeholders.
|
|
|
|
Uses simple regex substitution — no Jinja2 dependency needed.
|
|
"""
|
|
|
|
def replace(match: re.Match) -> str:
|
|
key = match.group(1).strip()
|
|
value = variables.get(key, "")
|
|
if isinstance(value, (dict, list)):
|
|
return json.dumps(value, ensure_ascii=False, indent=2)
|
|
return str(value)
|
|
|
|
return re.sub(r"\{\{(\w+)\}\}", replace, template)
|
|
|
|
|
|
class AgentService:
    """Orchestrate agent skill execution.

    Usage::

        service = await AgentService.get_instance()
        result = await service.execute_skill(
            skill_name="enrich_hf_metadata",
            input_data={"model_paths": ["/path/to/model.safetensors"]},
            progress_callback=AgentProgressReporter(),
        )
    """

    _instance: Optional["AgentService"] = None
    # NOTE(review): this lock is created at import time. On Python < 3.10 an
    # asyncio.Lock binds to the loop current at construction — confirm the
    # supported Python version if get_instance() may run on several loops.
    _lock: asyncio.Lock = asyncio.Lock()

    # Maximum number of README characters passed into the prompt context.
    _README_CHAR_LIMIT = 8000

    def __init__(
        self,
        *,
        skill_registry: Optional[SkillRegistry] = None,
        llm_service: Optional[LLMService] = None,
    ) -> None:
        """Create a service; dependencies left as ``None`` are resolved lazily.

        Args:
            skill_registry: Registry used to look up skills. When ``None``,
                the global ``SkillRegistry`` instance is fetched on first use.
            llm_service: LLM client. When ``None``, the global ``LLMService``
                instance is fetched on first use.
        """
        self._registry = skill_registry
        self._llm_service = llm_service

    @classmethod
    async def get_instance(cls) -> "AgentService":
        """Return the lazily-initialised global ``AgentService``.

        The instance is checked, then re-checked under ``cls._lock``, so
        concurrent first callers construct only one instance.
        """
        if cls._instance is None:
            async with cls._lock:
                if cls._instance is None:
                    cls._instance = cls(
                        skill_registry=await SkillRegistry.get_instance(),
                        llm_service=await LLMService.get_instance(),
                    )
        return cls._instance

    @classmethod
    def reset_instance(cls) -> None:
        """Reset the cached singleton — primarily for tests."""
        cls._instance = None

    async def _ensure_registry(self) -> SkillRegistry:
        """Return the skill registry, fetching the global one if unset."""
        if self._registry is None:
            self._registry = await SkillRegistry.get_instance()
        return self._registry

    async def _ensure_llm(self) -> LLMService:
        """Return the LLM service, fetching the global one if unset."""
        if self._llm_service is None:
            self._llm_service = await LLMService.get_instance()
        return self._llm_service

    async def list_skills(self) -> List[Dict[str, Any]]:
        """Return a JSON-serialisable list of available skills."""
        registry = await self._ensure_registry()
        return [
            {
                "name": s.name,
                "title": s.title,
                "description": s.description,
                "llm_required": s.llm_required,
                "model_type_filter": s.model_type_filter,
            }
            for s in registry.list_skills()
        ]

    async def execute_skill(
        self,
        *,
        skill_name: str,
        input_data: Dict[str, Any],
        progress_callback: Optional[AgentProgressReporter] = None,
    ) -> SkillResult:
        """Execute an agent skill.

        Each path in ``input_data["model_paths"]`` is processed independently:
        a failure for one model is recorded in ``errors`` and processing
        continues with the next one.

        Args:
            skill_name: Name of the skill to execute
            input_data: Input validated against the skill's ``input_schema``
            progress_callback: Optional WebSocket progress reporter

        Returns:
            :class:`SkillResult` with success status and updated model info.
            ``success`` is True when at least one model succeeded. Unknown
            skills, invalid input and empty ``model_paths`` return a failed
            result without emitting progress events.
        """
        registry = await self._ensure_registry()
        logger.info("execute_skill '%s': looking up skill", skill_name)
        skill = registry.get_skill(skill_name)
        if skill is None:
            return SkillResult(
                success=False,
                errors=[f"Skill not found: {skill_name}"],
                summary=f"Skill '{skill_name}' does not exist",
            )

        input_errors = _validate_schema(input_data, skill.input_schema)
        if input_errors:
            return SkillResult(
                success=False,
                errors=input_errors,
                summary=f"Invalid input: {'; '.join(input_errors)}",
            )

        model_paths = input_data.get("model_paths", [])
        if not model_paths:
            return SkillResult(
                success=False,
                errors=["No model_paths provided"],
                summary="No models to process",
            )

        total = len(model_paths)
        processed = 0
        success_count = 0
        updated_models: List[Dict[str, Any]] = []
        errors: List[str] = []
        post_processor = PostProcessor()

        logger.info("execute_skill '%s': starting with %d model(s)", skill_name, total)
        await self._emit_progress(
            progress_callback, skill_name, status="started",
            total=total, processed=0, success=0,
        )

        llm = await self._ensure_llm()
        llm_configured = llm.is_configured() if skill.llm_required else True
        use_llm = skill.llm_required and llm_configured
        if skill.llm_required and not llm_configured:
            # The skill still runs, but the post-processor only receives an
            # empty LLM output; surface that instead of failing silently.
            logger.warning(
                "execute_skill '%s': skill requires an LLM but none is configured; "
                "post-processing will receive empty LLM output",
                skill_name,
            )

        # Loaded lazily (inside the per-model try) and reused for later models.
        prompt_template: Optional[str] = None

        for model_path in model_paths:
            logger.info(
                "execute_skill '%s': processing model %d/%d: %s",
                skill_name, processed + 1, total, model_path,
            )
            try:
                from ...agent_cli import read_metadata
                metadata = await read_metadata(model_path)

                prompt_vars: Dict[str, Any] = {"model_path": model_path}
                llm_response: Optional[Dict[str, Any]] = None
                if use_llm:
                    prompt_vars = await self._build_prompt_context(
                        skill_name, model_path, metadata, registry, llm,
                    )
                    if prompt_template is None:
                        prompt_template = registry.load_prompt(skill_name)
                    rendered = _render_prompt(prompt_template, prompt_vars)
                    logger.info(
                        "execute_skill '%s': LLM call for %s (prompt=%d chars)",
                        skill_name, model_path, len(rendered),
                    )
                    llm_response = await llm.chat_completion_json(
                        system_prompt=prompt_vars.get(
                            "system_prompt",
                            "You are a helpful assistant that extracts structured metadata.",
                        ),
                        user_prompt=rendered,
                    )

                model_result = await post_processor.process(
                    skill_name=skill_name,
                    model_path=model_path,
                    llm_output=llm_response or {},
                    metadata=metadata,
                )

                if model_result.get("success", True):
                    success_count += 1
                    uf = model_result.get("updated_fields", [])
                    if uf:
                        updated_models.append({"path": model_path, "updated_fields": uf})
                else:
                    # Always record at least one message for a failed model,
                    # even if the post-processor left "errors" empty or unset.
                    model_errors = model_result.get("errors") or [
                        model_result.get("error") or "Unknown error"
                    ]
                    if isinstance(model_errors, str):
                        model_errors = [model_errors]
                    errors.extend(model_errors)

            except Exception as exc:
                # logger.exception keeps the traceback for debugging.
                logger.exception("Skill %s failed for %s: %s", skill_name, model_path, exc)
                errors.append(f"{model_path}: {exc}")

            processed += 1
            await self._emit_progress(
                progress_callback, skill_name, status="processing",
                total=total, processed=processed, success=success_count,
                current_path=model_path,
            )

        result = SkillResult(
            success=success_count > 0,
            updated_models=updated_models,
            errors=errors,
            summary=f"Processed {processed}/{total} models, {success_count} succeeded",
        )

        logger.info("execute_skill '%s': done — %s", skill_name, result.summary)
        await self._emit_progress(
            progress_callback, skill_name, status="completed",
            total=total, processed=processed, success=success_count,
            updated_models=updated_models, errors=errors, summary=result.summary,
        )

        return result

    async def _build_prompt_context(
        self,
        skill_name: str,
        model_path: str,
        metadata: Dict[str, Any],
        registry: SkillRegistry,
        llm: Any,
    ) -> Dict[str, Any]:
        """Gather variables for the skill's prompt template.

        Summarises the model metadata, fetches the HF README when the
        metadata carries a HuggingFace URL, and lists the available base
        models. The returned keys map to ``{{variable}}`` placeholders in
        ``prompt.md``. ``skill_name``, ``registry`` and ``llm`` are accepted
        for interface compatibility but not used here.
        """
        from ...agent_cli import list_base_models

        sha256 = metadata.get("sha256") or ""
        # ``or ""`` so an explicit ``None`` never renders as the text "None".
        hf_url = metadata.get("hf_url") or ""
        repo = self._extract_repo_from_url(hf_url) or ""

        context: Dict[str, Any] = {
            "model_path": model_path,
            "hf_url": hf_url,
            "repo": repo,
            "readme_content": "",
            "current_metadata": {
                "file_name": metadata.get("file_name", ""),
                "base_model": metadata.get("base_model", ""),
                "tags": metadata.get("tags", []),
                "modelDescription": metadata.get("modelDescription", ""),
                "trainedWords": metadata.get("trainedWords", []),
                # Truncated hash: enough to identify, short enough for a prompt.
                "sha256": f"{sha256[:16]}..." if sha256 else "",
                "size": metadata.get("size", 0),
            },
            "base_models": [],
        }

        if repo:
            readme = await self._fetch_readme(repo)
            context["readme_content"] = (
                readme[: self._README_CHAR_LIMIT] if readme else "(README not available)"
            )

        try:
            context["base_models"] = await list_base_models()
        except Exception as exc:
            # Best effort: the prompt can still be rendered without the list.
            logger.debug("Failed to list base models: %s", exc)

        return context

    @staticmethod
    def _extract_repo_from_url(hf_url: str) -> Optional[str]:
        """Extract ``user/repo`` from a HuggingFace URL.

        Accepts ``http``/``https`` and an optional ``www.`` prefix; stops at
        the next ``/``, ``?`` or ``#`` so query strings and fragments are not
        mistaken for part of the repository name. Returns ``None`` when the
        URL is empty or not a ``huggingface.co`` URL.
        """
        if not hf_url:
            return None
        m = re.match(
            r"https?://(?:www\.)?huggingface\.co/([^/?#\s]+/[^/?#\s]+)",
            hf_url.strip(),
        )
        return m.group(1) if m else None

    @staticmethod
    async def _fetch_readme(repo: str) -> str:
        """Fetch README.md from HuggingFace (tries ``main``, then ``master``).

        Returns the README text from the first branch answering HTTP 200, or
        an empty string when neither branch does or the requests fail.
        """
        async with aiohttp.ClientSession(
            headers={"User-Agent": "ComfyUI-LoRA-Manager/1.0"},
            timeout=aiohttp.ClientTimeout(total=30),
        ) as session:
            for branch in ("main", "master"):
                url = f"https://huggingface.co/{repo}/raw/{branch}/README.md"
                try:
                    async with session.get(url) as resp:
                        if resp.status == 200:
                            return await resp.text()
                except Exception as exc:
                    logger.debug("Failed to fetch README from %s: %s", url, exc)
        return ""

    async def _emit_progress(
        self,
        callback: Optional[AgentProgressReporter],
        skill_name: str,
        *,
        status: str,
        **extra: Any,
    ) -> None:
        """Send a progress update via *callback*, if one was provided.

        The payload always carries ``type``, ``skill`` and ``status``; any
        additional keyword arguments are merged in as extra fields.
        """
        payload: Dict[str, Any] = {"type": "agent_progress", "skill": skill_name, "status": status}
        payload.update(extra)
        if callback is not None:
            await callback.on_progress(payload)
|