feat(agent): add LLM-powered metadata enrichment system with AgentCLI and PostProcessor

Introduce an agent skill framework for LLM-driven metadata enrichment:

- AgentCLI (py/agent_cli/): in-process wrappers around internal services
  using standard relative imports, eliminating the need for sys.path hacks
- LLMService: centralized BYOK (bring-your-own-key) LLM client supporting
  OpenAI, Ollama, and custom OpenAI-compatible endpoints
- PostProcessor: deterministic engine that applies LLM output via AgentCLI
  (replaces old handler.py + _BASE_MODEL_ALIASES approach)
- SkillRegistry: filesystem-based skill discovery (skill.yaml + prompt.md)
- AgentService: orchestrates skill execution with WebSocket progress
- Frontend AgentManager: WebSocket listeners, skill execution, config UI
- Context menu entries (single + bulk) for "Enrich Metadata (Agent)"
- Settings UI for AI Provider configuration (BYOK)
- Full i18n support across 9 locales

Bug fixes found during review:
- aiohttp.web.json_response: status_code= -> status=
- settings_modal cancelEditApiKey: wrong argument position
- AgentManager.isLlmConfigured: allow Ollama without API key
- PostProcessor._merge_tags: lowercase all tags to match TagUpdateService
Will Miao
2026-07-02 20:51:11 +08:00
parent fe90f7f9b1
commit cf898da193
44 changed files with 5937 additions and 2180 deletions

docs/agent_skills.md (new file, 208 lines)

@@ -0,0 +1,208 @@
# Agent Skills System
The LoRA Manager agent skills system enables LLM-powered metadata enrichment and other AI-driven tasks. Users configure their own LLM provider (BYOK), and skills are executed through right-click context menu actions.
## Architecture
```
┌──────────────────────────────────────────────┐
│             LoRA Manager Backend             │
│                                              │
│  ┌───────────────┐    ┌────────────────┐     │
│  │  LLMService   │───▶│  LLM Provider  │     │
│  │ (BYOK config, │◀───│ (OpenAI/Ollama │     │
│  │  API calls)   │    │  /custom)      │     │
│  └───────┬───────┘    └────────────────┘     │
│          │                                   │
│  ┌───────▼───────────────────────┐           │
│  │ AgentService                  │           │
│  │ (orchestration: validate      │           │
│  │  → LLM call → post-process    │           │
│  │  → WebSocket broadcast)       │           │
│  └───────┬───────────────────────┘           │
│          │                                   │
│  ┌───────▼───────────────────────┐           │
│  │ SkillRegistry                 │           │
│  │ ┌─────────────────────────┐   │           │
│  │ │ enrich_hf_metadata:     │   │           │
│  │ │  - skill.yaml           │   │           │
│  │ │  - prompt.md            │   │           │
│  │ │  - handler.py           │   │           │
│  │ └─────────────────────────┘   │           │
│  └───────────────────────────────┘           │
└──────────────────────────────────────────────┘
```
### Key Design Principle
**Skills define *what* to do (prompt + post-processing). The AgentService handles *how* (LLM calls, validation, progress).**
Skills never call the LLM directly. This keeps BYOK configuration centralized and provider-agnostic.
## BYOK Configuration
Users configure their LLM provider in **Settings → AI Provider**:
| Setting | Description | Example |
|---|---|---|
| `llm_provider` | Provider type | `openai`, `ollama`, or `custom` |
| `llm_api_key` | API key (not needed for local Ollama) | `sk-...` |
| `llm_api_base` | Custom API base URL (empty = provider default) | `https://api.openai.com/v1` |
| `llm_model` | Model name | `gpt-4o-mini` |
Environment variable overrides: `LLM_API_KEY`, `LLM_MODEL`, `LLM_API_BASE`, `LLM_PROVIDER`.
### Supported Providers
- **OpenAI**: Uses `https://api.openai.com/v1` by default
- **Ollama** (local): Uses `http://localhost:11434/v1`, no API key required
- **Custom**: Any OpenAI-compatible endpoint (vLLM, LM Studio, etc.) — set `llm_api_base` explicitly
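
The way settings, environment overrides, and provider defaults combine can be sketched as follows. This is a minimal illustration, not the actual `LLMService` code: `resolve_llm_config` and `PROVIDER_DEFAULTS` are hypothetical names, and both the precedence (a non-empty environment variable wins over the saved setting) and the fallback to `openai` when no provider is set are assumptions.

```python
import os
from typing import Dict, Mapping, Optional

# Default base URLs from the provider list above; "custom" has no default.
PROVIDER_DEFAULTS: Dict[str, str] = {
    "openai": "https://api.openai.com/v1",
    "ollama": "http://localhost:11434/v1",
}


def resolve_llm_config(
    settings: Mapping[str, str],
    env: Optional[Mapping[str, str]] = None,
) -> Dict[str, str]:
    """Merge saved settings with environment overrides (hypothetical helper)."""
    env = os.environ if env is None else env

    def pick(setting_key: str, env_key: str) -> str:
        # Assumption: a non-empty environment variable overrides the setting.
        return (env.get(env_key) or settings.get(setting_key) or "").strip()

    # Assumption: fall back to "openai" when no provider is configured.
    provider = pick("llm_provider", "LLM_PROVIDER").lower() or "openai"
    api_base = pick("llm_api_base", "LLM_API_BASE") or PROVIDER_DEFAULTS.get(provider, "")
    if not api_base:
        raise ValueError("llm_api_base must be set explicitly for a custom provider")
    return {
        "provider": provider,
        "api_key": pick("llm_api_key", "LLM_API_KEY"),  # may be empty for Ollama
        "api_base": api_base,
        "model": pick("llm_model", "LLM_MODEL"),
    }
```

Passing `env` explicitly keeps the helper easy to test; leaving it out reads from the real process environment.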
## Available Skills
### enrich_hf_metadata
Enriches HuggingFace-downloaded models with metadata extracted by an LLM from the HF model card.
**Entry point**: Right-click context menu → "Enrich Metadata (Agent)"
**What it does**:
1. Reads the model's `.metadata.json` to get the `hf_url`
2. Fetches the README.md from the HuggingFace repository
3. Sends the README + local metadata to the LLM for structured extraction
4. Writes extracted fields to `.metadata.json`:
   - `base_model`: only if the current value is empty
   - `trainedWords`: trigger words (LoRA only, if none exist)
   - `modelDescription`: concise summary (if none exists)
   - `tags`: merged with existing tags, deduplicated
   - `metadata_source`: audit trail, set to `agent:enrich_hf_metadata`
   - `llm_enriched_at`: ISO timestamp
5. Downloads and optimizes preview image (if LLM found one in the README)
6. Updates the scanner cache
7. Broadcasts WebSocket progress events
**Model types**: LoRA, Checkpoint, Embedding
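
The write rules in step 4 can be sketched as a pure function. This is an illustration under stated assumptions, not the actual post-processing code: `merge_enrichment` is a hypothetical name, the metadata is assumed to be a flat dict with the field names listed above, and tags are lowercased on merge, as the commit message describes for `PostProcessor._merge_tags`.

```python
from datetime import datetime, timezone
from typing import Any, Dict, List


def merge_enrichment(
    current: Dict[str, Any],
    extracted: Dict[str, Any],
    model_type: str = "lora",
) -> Dict[str, Any]:
    """Return only the fields that should be written (hypothetical helper)."""
    updates: Dict[str, Any] = {}

    # base_model: fill only when the current value is empty.
    if not current.get("base_model") and extracted.get("base_model"):
        updates["base_model"] = extracted["base_model"]

    # trainedWords: LoRA only, and only if none exist yet.
    if model_type == "lora" and not current.get("trainedWords") and extracted.get("trainedWords"):
        updates["trainedWords"] = list(extracted["trainedWords"])

    # modelDescription: only if none exists.
    if not current.get("modelDescription") and extracted.get("modelDescription"):
        updates["modelDescription"] = extracted["modelDescription"]

    # tags: union of existing and new tags, lowercased, order preserved.
    merged: List[str] = []
    for tag in [*current.get("tags", []), *extracted.get("tags", [])]:
        norm = str(tag).strip().lower()
        if norm and norm not in merged:
            merged.append(norm)
    if merged != current.get("tags", []):
        updates["tags"] = merged

    # Audit fields are added only when something actually changed.
    if updates:
        updates["metadata_source"] = "agent:enrich_hf_metadata"
        updates["llm_enriched_at"] = datetime.now(timezone.utc).isoformat()
    return updates
```

Returning only the changed fields keeps the write step idempotent: running the skill twice on the same model produces an empty update the second time.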
## Adding a New Skill
### 1. Create the skill directory
```
py/services/agent/skills/<skill_name>/
├── skill.yaml # Skill metadata and schemas
├── prompt.md # LLM prompt template
└── handler.py # Pre-processing and post-processing
```
### 2. Write skill.yaml
```yaml
name: my_skill
title: "My Skill"
description: "What this skill does"
llm_required: true
model_type_filter: ["lora"]  # or null for all types
input_schema:
  type: object
  properties:
    model_paths:
      type: array
      items:
        type: string
  required:
    - model_paths
output_schema:
  type: object
  properties:
    # ... JSON schema for LLM output
permissions:
  write_metadata: true
  write_previews: false
  network_domains:
    - "example.com"
```
### 3. Write prompt.md
Use `{{variable}}` placeholders that will be replaced with data from the `prepare` function:
```markdown
You are an expert assistant...
Model URL: {{hf_url}}
README content:
{{readme_content}}
Current metadata:
{{current_metadata}}
```
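
Placeholder substitution of this kind can be sketched in a few lines. This is not the project's renderer: `render_prompt` is a hypothetical helper, and leaving unknown placeholders untouched is an assumption made here so that missing variables are easy to spot in the rendered prompt.

```python
import re
from typing import Any, Mapping

# Matches {{name}} with optional whitespace inside the braces.
_PLACEHOLDER = re.compile(r"\{\{\s*(\w+)\s*\}\}")


def render_prompt(template: str, variables: Mapping[str, Any]) -> str:
    """Replace {{name}} placeholders; unknown names are left untouched."""

    def substitute(match: re.Match) -> str:
        name = match.group(1)
        return str(variables[name]) if name in variables else match.group(0)

    # A function replacement avoids backslash processing in the inserted
    # text, which matters because README content can contain any characters.
    return _PLACEHOLDER.sub(substitute, template)
```

For example, `render_prompt("Model URL: {{hf_url}}", {"hf_url": url})` inserts the URL as-is, whatever characters it contains.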
### 4. Write handler.py
```python
async def prepare(model_path: str, input_data: dict) -> dict:
    """Gather context for the LLM prompt. Returns variables for template rendering."""
    return {
        "model_path": model_path,
        # ... other variables used in prompt.md
    }


async def post_process(context) -> dict:
    """Apply the LLM-extracted data to the model."""
    llm_response = context.llm_response
    # ... write metadata, download previews, update cache
    return {
        "success": True,
        "updated_fields": ["base_model", "tags"],
        "errors": [],
    }
```
**Important**: Use absolute imports (`from py.utils.metadata_manager import MetadataManager`) because skills are loaded via `importlib.util.spec_from_file_location`, which doesn't support relative imports.
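
To see why relative imports fail, here is a rough sketch of file-based loading with the standard library. It is not the actual `SkillRegistry` code: `load_handler` and the `skill_<name>` module name are hypothetical. Under the assumption that the file is loaded under a top-level name like this, the module has no parent package, so `from . import x` raises `ImportError`.

```python
import importlib.util
from pathlib import Path
from types import ModuleType


def load_handler(skill_dir: Path) -> ModuleType:
    """Load <skill_dir>/handler.py as a standalone module (hypothetical helper)."""
    handler_path = skill_dir / "handler.py"
    # The module name is arbitrary; it is not registered under any package.
    spec = importlib.util.spec_from_file_location(
        f"skill_{skill_dir.name}", handler_path
    )
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load handler from {handler_path}")
    module = importlib.util.module_from_spec(spec)
    # Executes the file's top-level code; relative imports inside it fail
    # because the module has no known parent package.
    spec.loader.exec_module(module)
    return module
```

Absolute imports such as `from py.utils.metadata_manager import MetadataManager` keep working because they resolve through `sys.path` rather than through the loaded module's package.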
### 5. Test
The skill is automatically discovered by `SkillRegistry` on startup. Run the agent service tests with:
```bash
pytest tests/services/test_agent_service.py
```
## API Endpoints
| Method | Path | Description |
|---|---|---|
| GET | `/api/lm/agent/skills` | List available skills |
| POST | `/api/lm/agent/execute/{skill_name}` | Execute a skill (body: `{"model_paths": [...]}`) |
| POST | `/api/lm/agent/cancel` | Cancel running skill (stub) |
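
A request to the execute endpoint can be built with the standard library, as sketched below. The path and body shape come from the table above; `build_execute_request` is a hypothetical helper, and the base URL in the usage note is an assumption that depends on where the server is running.

```python
import json
from typing import Iterable
from urllib.parse import quote
from urllib.request import Request


def build_execute_request(
    base_url: str, skill_name: str, model_paths: Iterable[str]
) -> Request:
    """Build (but do not send) a POST request that executes a skill."""
    body = json.dumps({"model_paths": list(model_paths)}).encode("utf-8")
    url = f"{base_url.rstrip('/')}/api/lm/agent/execute/{quote(skill_name, safe='')}"
    return Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
```

The request can then be sent with `urllib.request.urlopen(build_execute_request("http://127.0.0.1:8188", "enrich_hf_metadata", paths))`, where the host and port are placeholders for the actual server address.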
## WebSocket Events
| Type | When | Key fields |
|---|---|---|
| `agent_progress` | Skill started/processing | `skill`, `status`, `total`, `processed`, `success`, `current_path` |
| `agent_progress` | Skill completed | `skill`, `status`, `updated_models`, `errors`, `summary` |
| `agent_progress` | Skill error | `skill`, `status`, `error` |
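
Because all three rows share the `agent_progress` type, a consumer has to tell them apart by their fields. The sketch below does so using only the key fields listed in the table. `classify_agent_event` is a hypothetical helper; it assumes the event type arrives in a `type` key and that `error` and `summary` appear only on error and completion events respectively.

```python
from typing import Any, Mapping, Optional


def classify_agent_event(event: Mapping[str, Any]) -> Optional[str]:
    """Map an event to "progress", "completed" or "error" (hypothetical helper)."""
    if event.get("type") != "agent_progress":
        return None  # not an agent event
    if "error" in event:
        return "error"
    # Completion events carry a summary and the list of updated models.
    if "summary" in event or "updated_models" in event:
        return "completed"
    return "progress"
```

If the actual `status` values are known, dispatching on `status` directly would be simpler and more robust than inspecting which fields are present.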
## Security Model
Skills declare permissions in `skill.yaml`:
- `write_metadata` — can write `.metadata.json` files
- `write_previews` — can download/replace preview images
- `network_domains` — allowed domains for HTTP requests
These are declarative constraints checked by `AgentService`. They are defense-in-depth, not a sandbox — the Python process can technically do anything, but the contract is clear and auditable.
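
A `network_domains` check could look like the sketch below. It is an illustration, not the check `AgentService` performs: `is_url_allowed` is a hypothetical helper, and treating subdomains of an allowed domain as allowed is an assumption.

```python
from typing import Iterable
from urllib.parse import urlsplit


def is_url_allowed(url: str, allowed_domains: Iterable[str]) -> bool:
    """Return True when the URL's host is an allowed domain or a subdomain of one."""
    host = (urlsplit(url).hostname or "").lower()
    if not host:
        return False
    for domain in allowed_domains:
        domain = domain.strip().lower().lstrip(".")
        # Match on a label boundary so "evilexample.com" does not pass
        # for "example.com".
        if domain and (host == domain or host.endswith("." + domain)):
            return True
    return False
```

Comparing parsed hostnames rather than raw URL substrings avoids false matches on paths or query strings that merely contain an allowed domain.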
## File Locations
| Component | Path |
|---|---|
| LLMService | `py/services/llm_service.py` |
| AgentService | `py/services/agent/agent_service.py` |
| SkillRegistry | `py/services/agent/skill_registry.py` |
| SkillDefinition | `py/services/agent/skill_definition.py` |
| Skills directory | `py/services/agent/skills/` |
| Route handlers | `py/routes/handlers/agent_handlers.py` |
| Frontend manager | `static/js/managers/AgentManager.js` |
| Settings UI | `templates/components/modals/settings_modal.html` |
| Context menu | `templates/components/context_menu.html` |


@@ -657,6 +657,23 @@
"proxyPassword": "Passwort (optional)",
"proxyPasswordPlaceholder": "passwort",
"proxyPasswordHelp": "Passwort für die Proxy-Authentifizierung (falls erforderlich)"
},
"aiProvider": {
"title": "KI-Anbieter",
"provider": "Anbieter",
"providerHelp": "Wählen Sie Ihren LLM-Anbieter. OpenAI und Ollama verwenden voreingestellte API-Endpunkte. Mit \"Benutzerdefiniert\" können Sie jeden OpenAI-kompatiblen Endpunkt angeben.",
"custom": "Benutzerdefiniert (OpenAI-kompatibel)",
"apiBase": "API-Basis-URL",
"apiBaseHelp": "Die Basis-URL für die LLM-API (z.B. https://api.openai.com/v1). Leer lassen, um die Anbietervoreinstellung zu verwenden.",
"apiBasePlaceholder": "https://api.openai.com/v1",
"apiKey": "API-Schlüssel",
"apiKeyHelp": "Ihr LLM-API-Schlüssel. Wird lokal gespeichert und niemals an einen anderen Server außer Ihrem gewählten LLM-Anbieter gesendet.",
"apiKeyPlaceholder": "sk-...",
"apiKeyNotSet": "Nicht festgelegt",
"apiKeyConfigured": "Konfiguriert",
"apiKeySet": "Einrichten",
"model": "Modell",
"modelHelp": "Der zu verwendende Modellname (z.B. deepseek-v4-flash, gemini-2.5-flash, gemma4:12b). Prüfen Sie Ihren Anbieter auf verfügbare Modelle."
}
},
"loras": {
@@ -754,7 +771,8 @@
"completed": "Abgeschlossen: {success} verschoben, {skipped} übersprungen, {failures} fehlgeschlagen",
"complete": "Automatische Organisation abgeschlossen",
"error": "Fehler: {error}"
}
},
"enrichHfAgent": "Metadaten mit KI anreichern"
},
"contextMenu": {
"refreshMetadata": "Civitai-Daten aktualisieren",
@@ -778,7 +796,8 @@
"shareRecipe": "Rezept teilen",
"viewAllLoras": "Alle LoRAs anzeigen",
"downloadMissingLoras": "Fehlende LoRAs herunterladen",
"deleteRecipe": "Rezept löschen"
"deleteRecipe": "Rezept löschen",
"enrichHfAgent": "Metadaten mit KI anreichern"
}
},
"recipes": {
@@ -2081,6 +2100,12 @@
"moveFailed": "Failed to move item: {message}",
"copiedToClipboard": "In die Zwischenablage kopiert",
"downloadStarted": "Download gestartet"
},
"agent": {
"llmNotConfigured": "KI-Anbieter nicht konfiguriert. Aktivieren Sie ihn unter Einstellungen → KI-Anbieter.",
"enrichStarted": "Metadaten werden mit KI angereichert...",
"enrichComplete": "Metadatenanreicherung abgeschlossen: {{summary}}",
"enrichFailed": "Metadatenanreicherung fehlgeschlagen: {{error}}"
}
},
"doctor": {

File diff suppressed because it is too large


@@ -657,6 +657,23 @@
"proxyPassword": "Contraseña (opcional)",
"proxyPasswordPlaceholder": "contraseña",
"proxyPasswordHelp": "Contraseña para autenticación de proxy (si es necesario)"
},
"aiProvider": {
"title": "Proveedor de IA",
"provider": "Proveedor",
"providerHelp": "Elija su proveedor de LLM. OpenAI y Ollama usan endpoints predefinidos. Personalizado le permite especificar cualquier endpoint compatible con OpenAI.",
"custom": "Personalizado (compatible con OpenAI)",
"apiBase": "URL base de la API",
"apiBaseHelp": "La URL base para la API LLM (p.ej. https://api.openai.com/v1). Déjelo vacío para usar el valor predeterminado del proveedor.",
"apiBasePlaceholder": "https://api.openai.com/v1",
"apiKey": "Clave de API",
"apiKeyHelp": "Su clave de API del proveedor LLM. Se almacena localmente y nunca se envía a ningún servidor excepto a su proveedor LLM elegido.",
"apiKeyPlaceholder": "sk-...",
"apiKeyNotSet": "No configurada",
"apiKeyConfigured": "Configurada",
"apiKeySet": "Configurar",
"model": "Modelo",
"modelHelp": "El nombre del modelo a usar (p.ej. deepseek-v4-flash, gemini-2.5-flash, gemma4:12b). Consulte a su proveedor para ver los modelos disponibles."
}
},
"loras": {
@@ -754,7 +771,8 @@
"completed": "Completado: {success} movidos, {skipped} omitidos, {failures} fallidos",
"complete": "Auto-organización completada",
"error": "Error: {error}"
}
},
"enrichHfAgent": "Enriquecer metadatos (IA)"
},
"contextMenu": {
"refreshMetadata": "Actualizar datos de Civitai",
@@ -778,7 +796,8 @@
"shareRecipe": "Compartir receta",
"viewAllLoras": "Ver todos los LoRAs",
"downloadMissingLoras": "Descargar LoRAs faltantes",
"deleteRecipe": "Eliminar receta"
"deleteRecipe": "Eliminar receta",
"enrichHfAgent": "Enriquecer metadatos (IA)"
}
},
"recipes": {
@@ -2081,6 +2100,12 @@
"moveFailed": "Failed to move item: {message}",
"copiedToClipboard": "Copiado al portapapeles",
"downloadStarted": "Descarga iniciada"
},
"agent": {
"llmNotConfigured": "Proveedor de IA no configurado. Actívelo en Configuración → Proveedor de IA.",
"enrichStarted": "Enriqueciendo metadatos con IA...",
"enrichComplete": "Enriquecimiento de metadatos completado: {{summary}}",
"enrichFailed": "Enriquecimiento de metadatos fallido: {{error}}"
}
},
"doctor": {


@@ -657,6 +657,23 @@
"proxyPassword": "Mot de passe (optionnel)",
"proxyPasswordPlaceholder": "mot_de_passe",
"proxyPasswordHelp": "Mot de passe pour l'authentification proxy (si nécessaire)"
},
"aiProvider": {
"title": "Fournisseur d'IA",
"provider": "Fournisseur",
"providerHelp": "Choisissez votre fournisseur LLM. OpenAI et Ollama utilisent des endpoints prédéfinis. Personnalisé vous permet de spécifier n'importe quel endpoint compatible OpenAI.",
"custom": "Personnalisé (compatible OpenAI)",
"apiBase": "URL de base de l'API",
"apiBaseHelp": "L'URL de base pour l'API LLM (ex. https://api.openai.com/v1). Laissez vide pour utiliser le fournisseur par défaut.",
"apiBasePlaceholder": "https://api.openai.com/v1",
"apiKey": "Clé API",
"apiKeyHelp": "Votre clé API du fournisseur LLM. Stockée localement, jamais envoyée à un serveur autre que votre fournisseur LLM choisi.",
"apiKeyPlaceholder": "sk-...",
"apiKeyNotSet": "Non définie",
"apiKeyConfigured": "Configurée",
"apiKeySet": "Configurer",
"model": "Modèle",
"modelHelp": "Le nom du modèle à utiliser (ex. deepseek-v4-flash, gemini-2.5-flash, gemma4:12b). Consultez votre fournisseur pour les modèles disponibles."
}
},
"loras": {
@@ -754,7 +771,8 @@
"completed": "Terminé : {success} déplacés, {skipped} ignorés, {failures} échecs",
"complete": "Auto-organisation terminée",
"error": "Erreur : {error}"
}
},
"enrichHfAgent": "Enrichir les métadonnées (IA)"
},
"contextMenu": {
"refreshMetadata": "Actualiser les données Civitai",
@@ -778,7 +796,8 @@
"shareRecipe": "Partager la recipe",
"viewAllLoras": "Voir tous les LoRAs",
"downloadMissingLoras": "Télécharger les LoRAs manquants",
"deleteRecipe": "Supprimer la recipe"
"deleteRecipe": "Supprimer la recipe",
"enrichHfAgent": "Enrichir les métadonnées (IA)"
}
},
"recipes": {
@@ -2081,6 +2100,12 @@
"moveFailed": "Failed to move item: {message}",
"copiedToClipboard": "Copié dans le presse-papiers",
"downloadStarted": "Téléchargement démarré"
},
"agent": {
"llmNotConfigured": "Fournisseur d'IA non configuré. Activez-le dans Paramètres → Fournisseur d'IA.",
"enrichStarted": "Enrichissement des métadonnées par IA...",
"enrichComplete": "Enrichissement des métadonnées terminé : {{summary}}",
"enrichFailed": "Échec de l'enrichissement des métadonnées : {{error}}"
}
},
"doctor": {


@@ -657,6 +657,23 @@
"proxyPassword": "סיסמה (אופציונלי)",
"proxyPasswordPlaceholder": "password",
"proxyPasswordHelp": "סיסמה לאימות מול הפרוקסי (אם נדרש)"
},
"aiProvider": {
"title": "ספק AI",
"provider": "ספק",
"providerHelp": "בחר את ספק ה-LLM שלך. OpenAI ו-Ollama משתמשים בנקודות קצה מוגדרות מראש. מותאם אישית מאפשר לך לציין כל נקודת קצה תואמת OpenAI.",
"custom": "מותאם אישית (תואם OpenAI)",
"apiBase": "כתובת בסיס API",
"apiBaseHelp": "כתובת ה-URL הבסיסית ל-API של LLM (לדוגמה https://api.openai.com/v1). השאר ריק לשימוש בברירת המחדל של הספק.",
"apiBasePlaceholder": "https://api.openai.com/v1",
"apiKey": "מפתח API",
"apiKeyHelp": "מפתח ה-API של ספק ה-LLM שלך. נשמר מקומית, לעולם לא נשלח לשרת כלשהו מלבד ספק ה-LLM שבחרת.",
"apiKeyPlaceholder": "sk-...",
"apiKeyNotSet": "לא הוגדר",
"apiKeyConfigured": "הוגדר",
"apiKeySet": "הגדר",
"model": "מודל",
"modelHelp": "שם המודל לשימוש (לדוגמה deepseek-v4-flash, gemini-2.5-flash, gemma4:12b). בדוק אצל הספק שלך אילו מודלים זמינים."
}
},
"loras": {
@@ -754,7 +771,8 @@
"completed": "הושלם: {success} הועברו, {skipped} דולגו, {failures} נכשלו",
"complete": "ארגון אוטומטי הושלם",
"error": "שגיאה: {error}"
}
},
"enrichHfAgent": "העשרת מטא-דאטה (AI)"
},
"contextMenu": {
"refreshMetadata": "רענן נתוני Civitai",
@@ -778,7 +796,8 @@
"shareRecipe": "שתף מתכון",
"viewAllLoras": "הצג את כל ה-LoRAs",
"downloadMissingLoras": "הורד LoRAs חסרים",
"deleteRecipe": "מחק מתכון"
"deleteRecipe": "מחק מתכון",
"enrichHfAgent": "העשרת מטא-דאטה (AI)"
}
},
"recipes": {
@@ -2081,6 +2100,12 @@
"moveFailed": "Failed to move item: {message}",
"copiedToClipboard": "הועתק ללוח",
"downloadStarted": "ההורדה החלה"
},
"agent": {
"llmNotConfigured": "ספק AI לא הוגדר. הפעל אותו בהגדרות → ספק AI.",
"enrichStarted": "מעשיר מטא-דאטה באמצעות AI...",
"enrichComplete": "העשרת מטא-דאטה הושלמה: {{summary}}",
"enrichFailed": "העשרת מטא-דאטה נכשלה: {{error}}"
}
},
"doctor": {


@@ -657,6 +657,23 @@
"proxyPassword": "パスワード(任意)",
"proxyPasswordPlaceholder": "パスワード",
"proxyPasswordHelp": "プロキシ認証用のパスワード(必要な場合)"
},
"aiProvider": {
"title": "AIプロバイダー",
"provider": "プロバイダー",
"providerHelp": "LLMプロバイダーを選択してください。OpenAIとOllamaはプリセットのAPIエンドポイントを使用します。カスタムでは任意のOpenAI互換エンドポイントを指定できます。",
"custom": "カスタム(OpenAI互換)",
"apiBase": "APIベースURL",
"apiBaseHelp": "LLM APIのベースURL(https://api.openai.com/v1)。空の場合はプロバイダーのデフォルトが使用されます。",
"apiBasePlaceholder": "https://api.openai.com/v1",
"apiKey": "APIキー",
"apiKeyHelp": "LLMプロバイダーのAPIキー。ローカルに保存され、選択したLLMプロバイダー以外のサーバーに送信されることはありません。",
"apiKeyPlaceholder": "sk-...",
"apiKeyNotSet": "未設定",
"apiKeyConfigured": "設定済み",
"apiKeySet": "設定",
"model": "モデル",
"modelHelp": "使用するモデル名(deepseek-v4-flash, gemini-2.5-flash, gemma4:12b)。プロバイダーで利用可能なモデルをご確認ください。"
}
},
"loras": {
@@ -754,7 +771,8 @@
"completed": "完了:{success} 移動、{skipped} スキップ、{failures} 失敗",
"complete": "自動整理が完了しました",
"error": "エラー:{error}"
}
},
"enrichHfAgent": "メタデータをAIで補完"
},
"contextMenu": {
"refreshMetadata": "Civitaiデータを更新",
@@ -778,7 +796,8 @@
"shareRecipe": "レシピを共有",
"viewAllLoras": "すべてのLoRAを表示",
"downloadMissingLoras": "不足しているLoRAをダウンロード",
"deleteRecipe": "レシピを削除"
"deleteRecipe": "レシピを削除",
"enrichHfAgent": "メタデータをAIで補完"
}
},
"recipes": {
@@ -2081,6 +2100,12 @@
"moveFailed": "Failed to move item: {message}",
"copiedToClipboard": "クリップボードにコピーしました",
"downloadStarted": "ダウンロードを開始しました"
},
"agent": {
"llmNotConfigured": "AIプロバイダーが設定されていません。設定 → AIプロバイダーで有効にしてください。",
"enrichStarted": "AIでメタデータを補完中...",
"enrichComplete": "メタデータの補完が完了しました:{{summary}}",
"enrichFailed": "メタデータの補完に失敗しました:{{error}}"
}
},
"doctor": {


@@ -657,6 +657,23 @@
"proxyPassword": "비밀번호 (선택사항)",
"proxyPasswordPlaceholder": "password",
"proxyPasswordHelp": "프록시 인증에 필요한 비밀번호 (필요한 경우)"
},
"aiProvider": {
"title": "AI 제공자",
"provider": "제공자",
"providerHelp": "LLM 제공자를 선택하세요. OpenAI와 Ollama는 사전 설정된 API 엔드포인트를 사용합니다. 사용자 정의를 선택하면 모든 OpenAI 호환 엔드포인트를 지정할 수 있습니다.",
"custom": "사용자 정의 (OpenAI 호환)",
"apiBase": "API 기본 URL",
"apiBaseHelp": "LLM API의 기본 URL입니다 (예: https://api.openai.com/v1). 비워두면 제공자 기본값이 사용됩니다.",
"apiBasePlaceholder": "https://api.openai.com/v1",
"apiKey": "API 키",
"apiKeyHelp": "LLM 제공자의 API 키입니다. 로컬에 저장되며 선택한 LLM 제공자 외의 서버로 전송되지 않습니다.",
"apiKeyPlaceholder": "sk-...",
"apiKeyNotSet": "설정되지 않음",
"apiKeyConfigured": "설정됨",
"apiKeySet": "설정",
"model": "모델",
"modelHelp": "사용할 모델 이름 (예: deepseek-v4-flash, gemini-2.5-flash, gemma4:12b). 제공자에서 사용 가능한 모델을 확인하세요."
}
},
"loras": {
@@ -754,7 +771,8 @@
"completed": "완료: {success}개 이동, {skipped}개 건너뜀, {failures}개 실패",
"complete": "자동 정리 완료",
"error": "오류: {error}"
}
},
"enrichHfAgent": "AI로 메타데이터 보강"
},
"contextMenu": {
"refreshMetadata": "Civitai 데이터 새로고침",
@@ -778,7 +796,8 @@
"shareRecipe": "레시피 공유",
"viewAllLoras": "모든 LoRA 보기",
"downloadMissingLoras": "누락된 LoRA 다운로드",
"deleteRecipe": "레시피 삭제"
"deleteRecipe": "레시피 삭제",
"enrichHfAgent": "AI로 메타데이터 보강"
}
},
"recipes": {
@@ -2081,6 +2100,12 @@
"moveFailed": "Failed to move item: {message}",
"copiedToClipboard": "클립보드에 복사됨",
"downloadStarted": "다운로드 시작됨"
},
"agent": {
"llmNotConfigured": "AI 제공자가 설정되지 않았습니다. 설정 → AI 제공자에서 활성화하세요.",
"enrichStarted": "AI로 메타데이터 보강 중...",
"enrichComplete": "메타데이터 보강 완료: {{summary}}",
"enrichFailed": "메타데이터 보강 실패: {{error}}"
}
},
"doctor": {


@@ -657,6 +657,23 @@
"proxyPassword": "Пароль (необязательно)",
"proxyPasswordPlaceholder": "пароль",
"proxyPasswordHelp": "Пароль для аутентификации на прокси (если требуется)"
},
"aiProvider": {
"title": "Поставщик ИИ",
"provider": "Поставщик",
"providerHelp": "Выберите поставщика LLM. OpenAI и Ollama используют предустановленные API-эндпоинты. Пользовательский позволяет указать любой совместимый с OpenAI эндпоинт.",
"custom": "Пользовательский (совместимый с OpenAI)",
"apiBase": "Базовый URL API",
"apiBaseHelp": "Базовый URL для LLM API (например, https://api.openai.com/v1). Оставьте пустым, чтобы использовать значение по умолчанию.",
"apiBasePlaceholder": "https://api.openai.com/v1",
"apiKey": "API-ключ",
"apiKeyHelp": "Ваш API-ключ поставщика LLM. Хранится локально и никогда не отправляется на другие серверы, кроме выбранного поставщика LLM.",
"apiKeyPlaceholder": "sk-...",
"apiKeyNotSet": "Не задан",
"apiKeyConfigured": "Настроен",
"apiKeySet": "Настроить",
"model": "Модель",
"modelHelp": "Имя модели для использования (например, deepseek-v4-flash, gemini-2.5-flash, gemma4:12b). Проверьте доступные модели у вашего поставщика."
}
},
"loras": {
@@ -754,7 +771,8 @@
"completed": "Завершено: {success} перемещено, {skipped} пропущено, {failures} не удалось",
"complete": "Автоматическая организация завершена",
"error": "Ошибка: {error}"
}
},
"enrichHfAgent": "Обогатить метаданные (ИИ)"
},
"contextMenu": {
"refreshMetadata": "Обновить данные Civitai",
@@ -778,7 +796,8 @@
"shareRecipe": "Поделиться рецептом",
"viewAllLoras": "Посмотреть все LoRAs",
"downloadMissingLoras": "Загрузить отсутствующие LoRAs",
"deleteRecipe": "Удалить рецепт"
"deleteRecipe": "Удалить рецепт",
"enrichHfAgent": "Обогатить метаданные (ИИ)"
}
},
"recipes": {
@@ -2081,6 +2100,12 @@
"moveFailed": "Failed to move item: {message}",
"copiedToClipboard": "Скопировано в буфер обмена",
"downloadStarted": "Загрузка начата"
},
"agent": {
"llmNotConfigured": "Поставщик ИИ не настроен. Включите его в Настройки → Поставщик ИИ.",
"enrichStarted": "Обогащение метаданных с помощью ИИ...",
"enrichComplete": "Обогащение метаданных завершено: {{summary}}",
"enrichFailed": "Ошибка обогащения метаданных: {{error}}"
}
},
"doctor": {


@@ -657,6 +657,23 @@
"proxyPassword": "密码 (可选)",
"proxyPasswordPlaceholder": "密码",
"proxyPasswordHelp": "代理认证的密码 (如果需要)"
},
"aiProvider": {
"title": "AI 提供商",
"provider": "提供商",
"providerHelp": "选择您的 LLM 提供商。OpenAI 和 Ollama 使用预设的 API 端点。自定义允许您指定任何兼容 OpenAI 的端点。",
"custom": "自定义(兼容 OpenAI)",
"apiBase": "API 基础地址",
"apiBaseHelp": "LLM API 的基础 URL(例如 https://api.openai.com/v1)。留空则使用提供商默认地址。",
"apiBasePlaceholder": "https://api.openai.com/v1",
"apiKey": "API 密钥",
"apiKeyHelp": "您的 LLM 提供商 API 密钥。仅本地存储,不会发送到您选择的 LLM 提供商之外的任何服务器。",
"apiKeyPlaceholder": "sk-...",
"apiKeyNotSet": "未设置",
"apiKeyConfigured": "已配置",
"apiKeySet": "设置",
"model": "模型",
"modelHelp": "要使用的模型名称(例如 deepseek-v4-flash, gemini-2.5-flash, gemma4:12b)。请查看您的提供商支持的可用模型列表。"
}
},
"loras": {
@@ -754,7 +771,8 @@
"completed": "完成:已移动 {success} 个,跳过 {skipped} 个,失败 {failures} 个",
"complete": "自动整理已完成",
"error": "错误:{error}"
}
},
"enrichHfAgent": "AI 元数据增强"
},
"contextMenu": {
"refreshMetadata": "刷新 Civitai 数据",
@@ -778,7 +796,8 @@
"shareRecipe": "分享配方",
"viewAllLoras": "查看所有 LoRA",
"downloadMissingLoras": "下载缺失的 LoRA",
"deleteRecipe": "删除配方"
"deleteRecipe": "删除配方",
"enrichHfAgent": "AI 元数据增强"
}
},
"recipes": {
@@ -2081,6 +2100,12 @@
"moveFailed": "Failed to move item: {message}",
"copiedToClipboard": "已复制到剪贴板",
"downloadStarted": "下载已开始"
},
"agent": {
"llmNotConfigured": "AI 提供商未配置。请在 设置 → AI 提供商 中进行配置。",
"enrichStarted": "正在使用 AI 增强元数据...",
"enrichComplete": "元数据增强完成:{{summary}}",
"enrichFailed": "元数据增强失败:{{error}}"
}
},
"doctor": {


@@ -657,6 +657,23 @@
"proxyPassword": "密碼(選填)",
"proxyPasswordPlaceholder": "password",
"proxyPasswordHelp": "代理驗證所需的密碼(如有需要)"
},
"aiProvider": {
"title": "AI 提供者",
"provider": "提供者",
"providerHelp": "選擇您的 LLM 提供者。OpenAI 和 Ollama 使用預設 API 端點。自訂允許您指定任何相容 OpenAI 的端點。",
"custom": "自訂(相容 OpenAI)",
"apiBase": "API 基礎位址",
"apiBaseHelp": "LLM API 的基礎 URL(例如 https://api.openai.com/v1)。留空則使用提供者預設位址。",
"apiBasePlaceholder": "https://api.openai.com/v1",
"apiKey": "API 金鑰",
"apiKeyHelp": "您的 LLM 提供者 API 金鑰。僅儲存在本地,除了您選擇的 LLM 提供者外,不會發送到任何伺服器。",
"apiKeyPlaceholder": "sk-...",
"apiKeyNotSet": "未設定",
"apiKeyConfigured": "已設定",
"apiKeySet": "設定",
"model": "模型",
"modelHelp": "要使用的模型名稱(例如 deepseek-v4-flash, gemini-2.5-flash, gemma4:12b)。請查看您的提供者支援的可用模型列表。"
}
},
"loras": {
@@ -754,7 +771,8 @@
"completed": "完成:已移動 {success},已略過 {skipped},失敗 {failures}",
"complete": "自動整理完成",
"error": "錯誤:{error}"
}
},
"enrichHfAgent": "AI 中繼資料增強"
},
"contextMenu": {
"refreshMetadata": "刷新 Civitai 資料",
@@ -778,7 +796,8 @@
"shareRecipe": "分享配方",
"viewAllLoras": "檢視全部 LoRA",
"downloadMissingLoras": "下載缺少的 LoRA",
"deleteRecipe": "刪除配方"
"deleteRecipe": "刪除配方",
"enrichHfAgent": "AI 中繼資料增強"
}
},
"recipes": {
@@ -2081,6 +2100,12 @@
"moveFailed": "Failed to move item: {message}",
"copiedToClipboard": "已複製到剪貼簿",
"downloadStarted": "下載已開始"
},
"agent": {
"llmNotConfigured": "AI 提供者尚未設定。請在 設定 → AI 提供者 中進行設定。",
"enrichStarted": "正在使用 AI 增強中繼資料...",
"enrichComplete": "中繼資料增強完成:{{summary}}",
"enrichFailed": "中繼資料增強失敗:{{error}}"
}
},
"doctor": {

py/agent_cli/__init__.py (new file, 225 lines)

@@ -0,0 +1,225 @@
"""Agent CLI: thin in-process wrappers around LoRA Manager internal services.

All functions are simple Python async functions that delegate to the
appropriate internal service. They use **relative imports** within the
``py`` package, so ``sys.modules`` caching works normally and there is no
risk of double import or circular dependencies.

Usage (in-process, primary)::

    from py.agent_cli import list_base_models, read_metadata
    models = await list_base_models()
    meta = await read_metadata("/path/to/model.safetensors")

Usage (subprocess, debugging / external)::

    python -m py.agent_cli base-models list
    python -m py.agent_cli metadata read /path/to/model.safetensors
"""

from __future__ import annotations

import asyncio
import logging
import os
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _find_scanner_for_model(
    model_path: str,
) -> tuple[object, object] | tuple[None, None]:
    """Find the (scanner, cache_entry) responsible for *model_path*.

    Iterates all known scanner types and returns the first one whose cache
    contains the given path. Returns ``(None, None)`` when no scanner
    claims the model.
    """
    from ..services.service_registry import ServiceRegistry

    normalized = os.path.normpath(model_path)
    for getter_name in (
        "get_lora_scanner",
        "get_checkpoint_scanner",
        "get_embedding_scanner",
    ):
        getter = getattr(ServiceRegistry, getter_name, None)
        if getter is None:
            continue
        try:
            scanner = await getter()
            if scanner is None:
                continue
            cache = await scanner.get_cached_data()
            for entry in cache.raw_data:
                if os.path.normpath(entry.get("file_path", "")) == normalized:
                    return scanner, entry
        except Exception as exc:
            logger.debug(
                "Scanner %s check failed for %s: %s",
                getter_name,
                model_path,
                exc,
            )
    return None, None

# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def list_base_models(limit: int = 0) -> List[str]:
    """Return deduplicated base model names from all model caches.

    The result is ordered by frequency (most common first). Pass
    *limit* = 0 (default) for all models.
    """
    from ..services.service_registry import ServiceRegistry

    counts: Dict[str, int] = {}
    for getter_name in (
        "get_lora_scanner",
        "get_checkpoint_scanner",
        "get_embedding_scanner",
    ):
        getter = getattr(ServiceRegistry, getter_name, None)
        if getter is None:
            continue
        try:
            scanner = await getter()
            if scanner is None:
                continue
            cache = await scanner.get_cached_data()
            for entry in cache.raw_data:
                bm = entry.get("base_model")
                if bm:
                    counts[bm] = counts.get(bm, 0) + 1
        except Exception as exc:
            logger.debug("list_base_models scanner %s error: %s", getter_name, exc)

    sorted_names = [name for name, _ in sorted(counts.items(), key=lambda x: -x[1])]
    if limit > 0:
        return sorted_names[:limit]
    return sorted_names

async def read_metadata(model_path: str) -> Dict[str, Any]:
    """Load the full metadata payload for *model_path* from disk.

    Returns an empty dict when the metadata file does not exist or cannot
    be parsed; it never raises.
    """
    from ..utils.metadata_manager import MetadataManager

    try:
        return await MetadataManager.load_metadata_payload(model_path) or {}
    except Exception as exc:
        logger.warning("read_metadata failed for %s: %s", model_path, exc)
        return {}

async def apply_metadata_updates(
    model_path: str,
    updates: Dict[str, Any],
) -> List[str]:
    """Merge *updates* into the model's on-disk metadata and persist.

    Returns the list of field names that actually changed.
    """
    from ..utils.metadata_manager import MetadataManager

    metadata = await read_metadata(model_path)
    updated_fields: List[str] = []
    for key, value in updates.items():
        old = metadata.get(key)
        if old != value:
            metadata[key] = value
            updated_fields.append(key)
    if updated_fields:
        await MetadataManager.save_metadata(model_path, metadata)
    return updated_fields

async def download_preview(
    model_path: str,
    url: str,
    *,
    target_width: int = 480,
    quality: int = 85,
) -> bool:
    """Download a preview image from *url*, optimise to .webp, and save it.

    The output file is placed alongside the model file with a ``.webp``
    extension. Returns ``True`` on success.
    """
    from ..services.downloader import get_downloader
    from ..utils.exif_utils import ExifUtils

    if not url or not url.strip():
        return False

    base_name = os.path.splitext(os.path.basename(model_path))[0]
    preview_dir = os.path.dirname(model_path)
    output_path = os.path.join(preview_dir, base_name + ".webp")

    downloader = await get_downloader()

    # Try in-memory download + optimise first
    success, content, _headers = await downloader.download_to_memory(
        url, use_auth=False,
    )
    if success and content:
        try:
            optimized_data, _ = ExifUtils.optimize_image(
                image_data=content,
                target_width=target_width,
                format="webp",
                quality=quality,
                preserve_metadata=False,
            )
            with open(output_path, "wb") as f:
                f.write(optimized_data)
            logger.info("Preview downloaded and optimised for %s", model_path)
            return True
        except Exception as exc:
            logger.warning("Preview optimisation failed, saving raw: %s", exc)
            # Fall through to raw save

    # Fallback: download directly to file
    try:
        ok, _ = await downloader.download_file(url, output_path, use_auth=False)
        if ok:
            logger.info("Preview downloaded (fallback) for %s", model_path)
            return True
    except Exception as exc:
        logger.warning("Preview fallback download failed for %s: %s", model_path, exc)
    return False

async def refresh_cache(model_path: str) -> bool:
    """Invalidate and reload the scanner cache entry for *model_path*.

    Returns ``True`` when the model was found and the cache was refreshed.
    """
    scanner, entry = await _find_scanner_for_model(model_path)
    if scanner is None:
        logger.warning("refresh_cache: no scanner found for %s", model_path)
        return False
    try:
        metadata = await read_metadata(model_path)
        if not metadata:
            logger.warning("refresh_cache: no metadata for %s", model_path)
            return False
        await scanner.update_single_model_cache(model_path, model_path, metadata)
        return True
    except Exception as exc:
        logger.warning("refresh_cache failed for %s: %s", model_path, exc)
        return False

py/agent_cli/__main__.py (new file, 118 lines)

@@ -0,0 +1,118 @@
"""Subprocess entry point for AgentCLI (debugging / external use).
Usage::
python -m py.agent_cli base-models list [--limit N]
python -m py.agent_cli metadata read <path>
python -m py.agent_cli metadata update <path> --json '{...}'
python -m py.agent_cli preview download <path> --url <url>
python -m py.agent_cli cache refresh <path>
NOTE: This is an **optional** convenience wrapper. The primary consumer of
AgentCLI is the :class:`AgentService` (in-process). This entry point exists
for manual debugging and future integration with subprocess-based agent
frameworks.
"""
from __future__ import annotations
import argparse
import asyncio
import json
import sys
from typing import Any, Dict, List
def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(prog="lmcli", description="LoRA Manager Agent CLI")
sub = parser.add_subparsers(dest="command", required=True)
# base-models list
base_models = sub.add_parser("base-models", aliases=["bm"])
base_models_cmds = base_models.add_subparsers(dest="subcommand", required=True)
base_models_list = base_models_cmds.add_parser("list")
base_models_list.add_argument(
"--limit", type=int, default=0, help="Max number of models (0 = all)"
)
# metadata read
meta = sub.add_parser("metadata", aliases=["md"])
meta_cmds = meta.add_subparsers(dest="subcommand", required=True)
meta_read = meta_cmds.add_parser("read")
meta_read.add_argument("path", type=str, help="Model file path")
# metadata update
meta_update = meta_cmds.add_parser("update")
meta_update.add_argument("path", type=str, help="Model file path")
meta_update.add_argument(
"--json",
type=str,
required=True,
help='JSON object of fields to update, e.g. \'{"base_model": "SDXL 1.0"}\'',
)
# preview download
prev = sub.add_parser("preview", aliases=["pv"])
prev_cmds = prev.add_subparsers(dest="subcommand", required=True)
prev_dl = prev_cmds.add_parser("download")
prev_dl.add_argument("path", type=str, help="Model file path")
prev_dl.add_argument("--url", type=str, required=True, help="Preview image URL")
# cache refresh
cache = sub.add_parser("cache")
cache_cmds = cache.add_subparsers(dest="subcommand", required=True)
cache_refresh = cache_cmds.add_parser("refresh")
cache_refresh.add_argument("path", type=str, help="Model file path")
return parser
async def _run(args: argparse.Namespace) -> Any:
from . import ( # lazy import so startup is fast
list_base_models,
read_metadata,
apply_metadata_updates,
download_preview,
refresh_cache,
)
cmd = args.command
sub = args.subcommand
if cmd in ("base-models", "bm") and sub == "list":
return await list_base_models(limit=args.limit)
if cmd in ("metadata", "md") and sub == "read":
return await read_metadata(args.path)
if cmd in ("metadata", "md") and sub == "update":
updates: Dict[str, Any] = json.loads(args.json)
return await apply_metadata_updates(args.path, updates)
if cmd in ("preview", "pv") and sub == "download":
return await download_preview(args.path, args.url)
if cmd == "cache" and sub == "refresh":
return await refresh_cache(args.path)
raise ValueError(f"Unknown command: {cmd} {sub}")
def main() -> None:
parser = _build_parser()
args = parser.parse_args()
result = asyncio.run(_run(args))
# Print list results one item per line; dicts and scalars are emitted as JSON
if isinstance(result, list):
for item in result:
print(item)
elif isinstance(result, dict):
json.dump(result, sys.stdout, ensure_ascii=False, indent=2)
print()
else:
print(json.dumps(result))
if __name__ == "__main__":
main()
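
`_run` above compares `args.command` against both the full name and its alias (for example `("metadata", "md")`). That is needed because argparse stores whichever spelling the user typed in the `dest` attribute. A minimal sketch of this behaviour, using a reduced parser with the hypothetical program name `lmcli-sketch`:

```python
import argparse
import json


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="lmcli-sketch")
    sub = parser.add_subparsers(dest="command", required=True)
    # "md" is an alias; args.command holds the spelling that was typed.
    meta = sub.add_parser("metadata", aliases=["md"])
    meta_cmds = meta.add_subparsers(dest="subcommand", required=True)
    update = meta_cmds.add_parser("update")
    update.add_argument("path")
    update.add_argument("--json", required=True)
    return parser


def parse_update(argv):
    args = build_parser().parse_args(argv)
    return args.command, args.subcommand, args.path, json.loads(args.json)
```

Called with `["md", "update", "/m.safetensors", "--json", '{"base_model": "SDXL 1.0"}']`, the first element of the returned tuple is `"md"`, not `"metadata"`.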

@@ -8,6 +8,8 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple
import logging
import json
import urllib.parse
import sys as _sys
import types as _types
import time
from .utils.cache_paths import CacheType, get_cache_file_path, get_legacy_cache_paths
@@ -1380,4 +1382,20 @@ class Config:
# Global config instance
config = Config()
# NOTE: Guard against re-import. When ServiceRegistry.get_lora_scanner() triggers
# a fresh import of lora_scanner → config, we must NOT re-execute Config.__init__()
# (which re-scans all roots, re-registers libraries, etc.).
#
# Strategy: store the config instance in a dedicated sentinel module
# ('_lm_config_cache') that is NEVER removed from sys.modules (its key does
# NOT start with 'py.'), so it survives re-imports of py.* modules.
_CONFIG_SENTINEL = "_lm_config_cache"
if _CONFIG_SENTINEL in _sys.modules:
# Re-import: reuse the existing singleton from the sentinel.
config: Config = _sys.modules[_CONFIG_SENTINEL].config # type: ignore[valid-type]
else:
config: Config = Config()
# Register the sentinel so re-imports of py.config find us.
_sentinel_mod = _types.ModuleType(_CONFIG_SENTINEL)
_sentinel_mod.config = config
_sys.modules[_CONFIG_SENTINEL] = _sentinel_mod

@@ -445,5 +445,12 @@ class LoraManager:
scanner.cancel_task()
logger.debug("LoRA Manager: Cancelled %s", name)
# Close shared aiohttp sessions to avoid "Unclosed client session" warnings
try:
from py.routes.handlers.hf_handlers import close_hf_api_session
await close_hf_api_session()
except Exception as exc:
logger.debug("Error closing HF API session: %s", exc)
except Exception as e:
logger.error(f"Error during cleanup: {e}", exc_info=True)

@@ -0,0 +1,167 @@
"""HTTP route handlers for agent skill endpoints.
These handlers expose the :class:`AgentService` via HTTP, allowing the
frontend to list available skills and execute them on selected models.
Progress is reported via WebSocket broadcast.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Dict
from aiohttp import web
from ...services.agent import AgentService, AgentProgressReporter
from ...services.llm_service import LLMNotConfiguredError
logger = logging.getLogger(__name__)
class AgentHandler:
"""HTTP handler for agent skill operations."""
def __init__(self, agent_service: AgentService | None = None) -> None:
self._agent_service = agent_service
async def _ensure_service(self) -> AgentService:
if self._agent_service is None:
self._agent_service = await AgentService.get_instance()
return self._agent_service
# ------------------------------------------------------------------
# GET /api/lm/agent/skills
# ------------------------------------------------------------------
async def get_agent_skills(self, request: web.Request) -> web.Response:
"""Return a list of available agent skills."""
service = await self._ensure_service()
skills = await service.list_skills()
return web.json_response({"skills": skills})
# ------------------------------------------------------------------
# POST /api/lm/agent/execute/{skill_name}
# ------------------------------------------------------------------
async def execute_agent_skill(self, request: web.Request) -> web.Response:
"""Execute an agent skill on the provided model paths.
Request body::
{"model_paths": ["/path/to/model1.safetensors", ...], "options": {}}
Returns immediately with a task ID. Execution runs in the
background; progress and completion are pushed via WebSocket
events of type ``agent_progress``.
"""
skill_name = request.match_info.get("skill_name", "")
if not skill_name:
return web.json_response(
{"error": "Skill name is required"}, status=400
)
try:
body = await request.json()
except Exception:
return web.json_response(
{"error": "Invalid JSON body"}, status=400
)
model_paths = body.get("model_paths", [])
if not model_paths or not isinstance(model_paths, list):
return web.json_response(
{"error": "model_paths must be a non-empty array"},
status=400,
)
service = await self._ensure_service()
# Validate LLM configuration early for skills that need it
# (fail fast rather than after starting background work)
try:
from ...services.llm_service import LLMService
llm = await LLMService.get_instance()
if not llm.is_configured():
return web.json_response(
{
"error": "LLM provider is not configured. "
"Enable it in Settings → AI Provider.",
},
status=400,
)
except Exception as exc:
logger.error("Failed to check LLM configuration: %s", exc)
# Launch execution in the background
progress_reporter = AgentProgressReporter()
logger.info(
"Agent skill '%s' starting for %d model(s) in background task",
skill_name, len(model_paths),
)
async def _run() -> None:
logger.info("_run background task started for skill '%s'", skill_name)
try:
result = await service.execute_skill(
skill_name=skill_name,
input_data={"model_paths": model_paths},
progress_callback=progress_reporter,
)
logger.info(
"Agent skill '%s' finished: success=%s, summary='%s', errors=%s",
skill_name, result.success, result.summary, result.errors,
)
except LLMNotConfiguredError as exc:
logger.warning("Agent skill '%s' not configured: %s", skill_name, exc)
await progress_reporter.on_progress(
{
"type": "agent_progress",
"skill": skill_name,
"status": "error",
"error": str(exc),
}
)
except Exception as exc:
logger.error("Agent skill '%s' failed: %s", skill_name, exc, exc_info=True)
await progress_reporter.on_progress(
{
"type": "agent_progress",
"skill": skill_name,
"status": "error",
"error": str(exc),
}
)
# Fire and forget — progress comes via WebSocket
task = asyncio.create_task(_run())
logger.info("Agent skill '%s' background task created (name=%s)", skill_name, task.get_name())
return web.json_response(
{
"status": "started",
"skill": skill_name,
"model_count": len(model_paths),
}
)
# ------------------------------------------------------------------
# POST /api/lm/agent/cancel
# ------------------------------------------------------------------
async def cancel_agent_skill(self, request: web.Request) -> web.Response:
"""Cancel a running agent skill.
NOTE: Cancellation is a stub for now — the AgentService processes
models sequentially and does not yet support mid-execution
cancellation. This endpoint exists for API completeness.
"""
# TODO: implement cooperative cancellation in AgentService
return web.json_response(
{"status": "acknowledged", "note": "Cancellation not yet implemented"},
status=200,
)

@@ -49,6 +49,14 @@ async def _get_hf_api_session() -> aiohttp.ClientSession:
return _hf_api_session
async def close_hf_api_session() -> None:
"""Close the shared HF API session, if it was ever created."""
global _hf_api_session
if _hf_api_session is not None and not _hf_api_session.closed:
await _hf_api_session.close()
_hf_api_session = None
def _infer_model_type(model_root: str) -> tuple[Any, str]:
"""Determine model class and scanner by matching ``model_root`` against the
configured root paths for each model type (from ``Config``).

@@ -49,6 +49,7 @@ from ...utils.constants import (
VALID_LORA_TYPES,
)
from .hf_handlers import HfHandler
from .agent_handlers import AgentHandler
from ...utils.civitai_utils import rewrite_preview_url
from ...utils.example_images_paths import (
find_non_compliant_items_in_example_images_root,
@@ -3317,6 +3318,7 @@ class MiscHandlerSet:
example_workflows: ExampleWorkflowsHandler,
base_model: BaseModelHandlerSet,
hf_handler: HfHandler | None = None,
agent_handler: AgentHandler | None = None,
) -> None:
self.health = health
self.settings = settings
@@ -3336,6 +3338,7 @@ class MiscHandlerSet:
self.example_workflows = example_workflows
self.base_model = base_model
self.hf_handler = hf_handler
self.agent_handler = agent_handler
def to_route_mapping(
self,
@@ -3384,6 +3387,10 @@ class MiscHandlerSet:
# Hugging Face handlers
"get_hf_repo_files": self.hf_handler.get_hf_repo_files,
"download_hf_model": self.hf_handler.download_hf_model,
# Agent skill handlers
"get_agent_skills": self.agent_handler.get_agent_skills,
"execute_agent_skill": self.agent_handler.execute_agent_skill,
"cancel_agent_skill": self.agent_handler.cancel_agent_skill,
# Base model handlers
"get_base_models": self.base_model.get_base_models,
"refresh_base_models": self.base_model.refresh_base_models,

@@ -101,6 +101,16 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition(
"POST", "/api/lm/download-hf-model", "download_hf_model"
),
# Agent skill endpoints
RouteDefinition(
"GET", "/api/lm/agent/skills", "get_agent_skills"
),
RouteDefinition(
"POST", "/api/lm/agent/execute/{skill_name}", "execute_agent_skill"
),
RouteDefinition(
"POST", "/api/lm/agent/cancel", "cancel_agent_skill"
),
)

@@ -40,6 +40,7 @@ from .handlers.misc_handlers import (
)
from .handlers.base_model_handlers import BaseModelHandlerSet
from .handlers.hf_handlers import HfHandler
from .handlers.agent_handlers import AgentHandler
from .misc_route_registrar import MiscRouteRegistrar
logger = logging.getLogger(__name__)
@@ -138,6 +139,7 @@ class MiscRoutes:
example_workflows = ExampleWorkflowsHandler()
base_model = BaseModelHandlerSet()
hf_handler = HfHandler()
agent_handler = AgentHandler()
return self._handler_set_factory(
health=health,
@@ -158,6 +160,7 @@ class MiscRoutes:
example_workflows=example_workflows,
base_model=base_model,
hf_handler=hf_handler,
agent_handler=agent_handler,
)

@@ -0,0 +1,23 @@
"""Agent-powered skill system for LoRA Manager.
This package provides the orchestration layer for LLM/agent-powered features.
Skills define *what* to do (prompt template). The :class:`AgentService`
handles *how* (LLM calls, context gathering, validation, progress).
"""
from __future__ import annotations
from .skill_definition import SkillDefinition, SkillPermissions
from .skill_registry import SkillRegistry
from .agent_service import AgentService, AgentProgressReporter, SkillResult
from .post_processor import PostProcessor
__all__ = [
"AgentProgressReporter",
"AgentService",
"PostProcessor",
"SkillDefinition",
"SkillPermissions",
"SkillRegistry",
"SkillResult",
]
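
The split described in the package docstring above (a skill supplies a prompt template, the service fills it in) can be illustrated with a small stand-alone renderer for `{{variable}}` placeholders. This is a sketch under assumptions: `render_prompt` is a hypothetical name, and tolerating whitespace inside the braces is a choice made here for illustration.

```python
import json
import re
from typing import Any, Dict

# Matches {{name}} and {{ name }}; the optional whitespace is an assumption.
_PLACEHOLDER = re.compile(r"\{\{\s*(\w+)\s*\}\}")


def render_prompt(template: str, variables: Dict[str, Any]) -> str:
    def replace(match):
        value = variables.get(match.group(1), "")
        if isinstance(value, (dict, list)):
            # Structured values are embedded as JSON text.
            return json.dumps(value, ensure_ascii=False)
        return str(value)

    return _PLACEHOLDER.sub(replace, template)
```

Unknown placeholders render as an empty string in this sketch, so a typo in a template fails silently; a stricter variant could raise instead.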

@@ -0,0 +1,413 @@
"""Agent orchestration service.
The :class:`AgentService` coordinates skill execution:
1. Look up the skill in :class:`SkillRegistry`
2. Validate input against the skill's ``input_schema``
3. Prepare context via :mod:`~py.agent_cli` (read metadata, list base models, fetch HF README)
4. If ``llm_required``: call :class:`LLMService` with the rendered prompt
5. Post-process via :class:`PostProcessor` (delegates I/O to :mod:`~py.agent_cli`)
6. Broadcast progress and completion via :class:`WebSocketManager`
Skills define *what* to do (prompt template). The AgentService handles *how*
(LLM calls, context gathering, validation, progress).
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
import aiohttp
from ..llm_service import LLMService
from ..websocket_manager import ws_manager
from .post_processor import PostProcessor
from .skill_registry import SkillRegistry
logger = logging.getLogger(__name__)
class AgentProgressReporter:
"""Protocol-compatible progress reporter backed by WebSocket broadcast."""
async def on_progress(self, payload: Dict[str, Any]) -> None:
await ws_manager.broadcast(payload)
@dataclass
class SkillResult:
"""Outcome of a skill execution."""
success: bool
updated_models: List[Dict[str, Any]] = field(default_factory=list)
errors: List[str] = field(default_factory=list)
summary: str = ""
def _validate_schema(data: Any, schema: Dict[str, Any], path: str = "") -> List[str]:
"""Minimal JSON schema validator.
Supports a subset of JSON Schema: ``type``, ``properties``, ``required``,
``items``, ``enum``. Returns a list of error messages (empty = valid).
"""
errors: List[str] = []
if not schema:
return errors
expected_type = schema.get("type")
if expected_type:
type_map = {
"string": str,
"number": (int, float),
"integer": int,
"boolean": bool,
"array": list,
"object": dict,
"null": type(None),
}
expected_py = type_map.get(expected_type)
if expected_py is not None and not isinstance(data, expected_py):
errors.append(f"{path or 'root'}: expected {expected_type}, got {type(data).__name__}")
return errors
if expected_type == "object" and isinstance(data, dict):
properties = schema.get("properties", {})
required = schema.get("required", [])
for req_key in required:
if req_key not in data:
errors.append(f"{path or 'root'}: missing required property '{req_key}'")
for key, value in data.items():
if key in properties:
errors.extend(_validate_schema(value, properties[key], f"{path}.{key}"))
if expected_type == "array" and isinstance(data, list):
items_schema = schema.get("items")
if items_schema:
for i, item in enumerate(data):
errors.extend(_validate_schema(item, items_schema, f"{path}[{i}]"))
if "enum" in schema and data not in schema["enum"]:
errors.append(f"{path or 'root'}: value '{data}' not in enum {schema['enum']}")
return errors
# ------------------------------------------------------------------
# Prompt template rendering
# ------------------------------------------------------------------
def _render_prompt(template: str, variables: Dict[str, Any]) -> str:
"""Render a prompt template with ``{{variable}}`` placeholders.
Uses simple regex substitution — no Jinja2 dependency needed.
"""
def replace(match: re.Match) -> str:
key = match.group(1).strip()
value = variables.get(key, "")
if isinstance(value, (dict, list)):
return json.dumps(value, ensure_ascii=False, indent=2)
return str(value)
return re.sub(r"\{\{\s*(\w+)\s*\}\}", replace, template)
class AgentService:
"""Orchestrate agent skill execution.
Usage::
service = await AgentService.get_instance()
result = await service.execute_skill(
skill_name="enrich_hf_metadata",
input_data={"model_paths": ["/path/to/model.safetensors"]},
progress_callback=AgentProgressReporter(),
)
"""
_instance: Optional["AgentService"] = None
_lock: asyncio.Lock = asyncio.Lock()
def __init__(
self,
*,
skill_registry: Optional[SkillRegistry] = None,
llm_service: Optional[LLMService] = None,
) -> None:
self._registry = skill_registry
self._llm_service = llm_service
@classmethod
async def get_instance(cls) -> "AgentService":
"""Return the lazily-initialised global ``AgentService``."""
if cls._instance is None:
async with cls._lock:
if cls._instance is None:
cls._instance = cls(
skill_registry=await SkillRegistry.get_instance(),
llm_service=await LLMService.get_instance(),
)
return cls._instance
@classmethod
def reset_instance(cls) -> None:
"""Reset the cached singleton — primarily for tests."""
cls._instance = None
async def _ensure_registry(self) -> SkillRegistry:
if self._registry is None:
self._registry = await SkillRegistry.get_instance()
return self._registry
async def _ensure_llm(self) -> LLMService:
if self._llm_service is None:
self._llm_service = await LLMService.get_instance()
return self._llm_service
async def list_skills(self) -> List[Dict[str, Any]]:
"""Return a JSON-serialisable list of available skills."""
registry = await self._ensure_registry()
return [
{
"name": s.name,
"title": s.title,
"description": s.description,
"llm_required": s.llm_required,
"model_type_filter": s.model_type_filter,
}
for s in registry.list_skills()
]
async def execute_skill(
self,
*,
skill_name: str,
input_data: Dict[str, Any],
progress_callback: Optional[AgentProgressReporter] = None,
) -> SkillResult:
"""Execute an agent skill.
Args:
skill_name: Name of the skill to execute
input_data: Input validated against the skill's ``input_schema``
progress_callback: Optional WebSocket progress reporter
Returns:
:class:`SkillResult` with success status and updated model info
"""
registry = await self._ensure_registry()
logger.info("execute_skill '%s': looking up skill", skill_name)
skill = registry.get_skill(skill_name)
if skill is None:
return SkillResult(
success=False,
errors=[f"Skill not found: {skill_name}"],
summary=f"Skill '{skill_name}' does not exist",
)
input_errors = _validate_schema(input_data, skill.input_schema)
if input_errors:
return SkillResult(
success=False,
errors=input_errors,
summary=f"Invalid input: {'; '.join(input_errors)}",
)
model_paths = input_data.get("model_paths", [])
if not model_paths:
return SkillResult(
success=False,
errors=["No model_paths provided"],
summary="No models to process",
)
total = len(model_paths)
processed = 0
success_count = 0
updated_models: List[Dict[str, Any]] = []
errors: List[str] = []
post_processor = PostProcessor()
logger.info("execute_skill '%s': starting with %d model(s)", skill_name, total)
await self._emit_progress(
progress_callback, skill_name, status="started",
total=total, processed=0, success=0,
)
llm = await self._ensure_llm()
llm_configured = llm.is_configured() if skill.llm_required else True
for model_path in model_paths:
logger.info(
"execute_skill '%s': processing model %d/%d: %s",
skill_name, processed + 1, total, model_path,
)
try:
from ...agent_cli import read_metadata
metadata = await read_metadata(model_path)
prompt_vars: Dict[str, Any] = {"model_path": model_path}
if skill.llm_required and llm_configured:
prompt_vars = await self._build_prompt_context(
skill_name, model_path, metadata, registry, llm,
)
llm_response: Optional[Dict[str, Any]] = None
if skill.llm_required and llm_configured:
prompt_template = registry.load_prompt(skill_name)
rendered = _render_prompt(prompt_template, prompt_vars)
logger.info(
"execute_skill '%s': LLM call for %s (prompt=%d chars)",
skill_name, model_path, len(rendered),
)
llm_response = await llm.chat_completion_json(
system_prompt=prompt_vars.get(
"system_prompt",
"You are a helpful assistant that extracts structured metadata.",
),
user_prompt=rendered,
)
model_result = await post_processor.process(
skill_name=skill_name,
model_path=model_path,
llm_output=llm_response or {},
metadata=metadata,
)
if model_result.get("success", True):
success_count += 1
uf = model_result.get("updated_fields", [])
if uf:
updated_models.append({"path": model_path, "updated_fields": uf})
else:
errors.extend(
model_result.get("errors", [model_result.get("error", "Unknown error")])
)
except Exception as exc:
logger.error("Skill %s failed for %s: %s", skill_name, model_path, exc)
errors.append(f"{model_path}: {exc}")
processed += 1
await self._emit_progress(
progress_callback, skill_name, status="processing",
total=total, processed=processed, success=success_count,
current_path=model_path,
)
result = SkillResult(
success=success_count > 0,
updated_models=updated_models,
errors=errors,
summary=f"Processed {processed}/{total} models, {success_count} succeeded",
)
logger.info("execute_skill '%s': done — %s", skill_name, result.summary)
await self._emit_progress(
progress_callback, skill_name, status="completed",
total=total, processed=processed, success=success_count,
updated_models=updated_models, errors=errors, summary=result.summary,
)
return result
async def _build_prompt_context(
self,
skill_name: str,
model_path: str,
metadata: Dict[str, Any],
registry: SkillRegistry,
llm: Any,
) -> Dict[str, Any]:
"""Gather variables for the skill's prompt template.
Reads metadata, fetches the HF README (if applicable), lists available
base models, and returns a dict that maps to ``{{variable}}``
placeholders in ``prompt.md``.
"""
from ...agent_cli import list_base_models
context: Dict[str, Any] = {
"model_path": model_path,
"hf_url": "",
"repo": "",
"readme_content": "",
"current_metadata": {},
"base_models": [],
}
context["current_metadata"] = {
"file_name": metadata.get("file_name", ""),
"base_model": metadata.get("base_model", ""),
"tags": metadata.get("tags", []),
"modelDescription": metadata.get("modelDescription", ""),
"trainedWords": metadata.get("trainedWords", []),
"sha256": (metadata.get("sha256") or "")[:16] + "..." if metadata.get("sha256") else "",
"size": metadata.get("size", 0),
}
hf_url = metadata.get("hf_url", "")
context["hf_url"] = hf_url
repo = self._extract_repo_from_url(hf_url) if hf_url else ""
context["repo"] = repo or ""
if repo:
readme = await self._fetch_readme(repo)
context["readme_content"] = readme[:8000] if readme else "(README not available)"
try:
context["base_models"] = await list_base_models()
except Exception as exc:
logger.debug("Failed to list base models: %s", exc)
return context
@staticmethod
def _extract_repo_from_url(hf_url: str) -> Optional[str]:
"""Extract ``user/repo`` from a HuggingFace URL."""
if not hf_url:
return None
m = re.match(r"https?://huggingface\.co/([^/]+/[^/]+)", hf_url)
return m.group(1) if m else None
@staticmethod
async def _fetch_readme(repo: str) -> str:
"""Fetch README.md from HuggingFace (tries ``main``, then ``master``)."""
async with aiohttp.ClientSession(
headers={"User-Agent": "ComfyUI-LoRA-Manager/1.0"},
timeout=aiohttp.ClientTimeout(total=30),
) as session:
for branch in ("main", "master"):
url = f"https://huggingface.co/{repo}/raw/{branch}/README.md"
try:
async with session.get(url) as resp:
if resp.status == 200:
return await resp.text()
except Exception as exc:
logger.debug("Failed to fetch README from %s: %s", url, exc)
return ""
async def _emit_progress(
self,
callback: Optional[AgentProgressReporter],
skill_name: str,
*,
status: str,
**extra: Any,
) -> None:
"""Send a progress update via WebSocket (if callback is set)."""
payload: Dict[str, Any] = {"type": "agent_progress", "skill": skill_name, "status": status}
payload.update(extra)
if callback is not None:
await callback.on_progress(payload)
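
Step 2 of the orchestration (validating input against a skill's `input_schema`) relies on a deliberately small subset of JSON Schema. The sketch below is a condensed stand-alone variant covering only `type`, `required`, `properties`, and `items`; the function name `validate` and the example schema in the note after it are hypothetical, since the real `skill.yaml` is not shown here.

```python
from typing import Any, Dict, List

_TYPES = {"string": str, "array": list, "object": dict}


def validate(data: Any, schema: Dict[str, Any], path: str = "root") -> List[str]:
    """Return a list of error messages; an empty list means valid."""
    expected = _TYPES.get(schema.get("type", ""))
    if expected is not None and not isinstance(data, expected):
        return [f"{path}: expected {schema['type']}, got {type(data).__name__}"]
    errors: List[str] = []
    if isinstance(data, dict):
        for key in schema.get("required", []):
            if key not in data:
                errors.append(f"{path}: missing required property '{key}'")
        for key, sub_schema in schema.get("properties", {}).items():
            if key in data:
                errors.extend(validate(data[key], sub_schema, f"{path}.{key}"))
    if isinstance(data, list) and "items" in schema:
        for index, item in enumerate(data):
            errors.extend(validate(item, schema["items"], f"{path}[{index}]"))
    return errors
```

With a schema requiring `model_paths` to be an array of strings, an empty dict yields one "missing required property" error and a list containing a non-string yields an error whose path points at the offending index.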

@@ -0,0 +1,168 @@
"""Post-processing engine for agent skill outputs.
The :class:`PostProcessor` takes the LLM's structured JSON output and applies
it to a model's on-disk metadata via the :mod:`~py.agent_cli` functions.
It handles all the skill-specific business logic — conditions, transformations,
and orchestration of multiple side-effects (write metadata, download preview,
refresh cache). All actual I/O is delegated to :mod:`~py.agent_cli`.
"""
from __future__ import annotations
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
class PostProcessor:
"""Deterministic post-processor for agent skill outputs.
Usage (called by :class:`~py.services.agent.agent_service.AgentService`)::
processor = PostProcessor()
result = await processor.process(
skill_name="enrich_hf_metadata",
model_path="/path/to/model.safetensors",
llm_output={...},
metadata={...}, # from agent_cli.read_metadata()
)
"""
async def process(
self,
*,
skill_name: str,
model_path: str,
llm_output: Dict[str, Any],
metadata: Dict[str, Any],
) -> Dict[str, Any]:
"""Route *llm_output* to the correct skill post-processor.
Returns a dict with keys ``success`` (bool), ``updated_fields`` (list),
``preview_downloaded`` (bool), and ``errors`` (list).
"""
if skill_name == "enrich_hf_metadata":
return await self._process_enrich_hf_metadata(
model_path, llm_output, metadata,
)
return {
"success": False,
"updated_fields": [],
"preview_downloaded": False,
"errors": [f"No post-processor registered for skill: {skill_name}"],
}
# ------------------------------------------------------------------
# enrich_hf_metadata
# ------------------------------------------------------------------
async def _process_enrich_hf_metadata(
self,
model_path: str,
llm_output: Dict[str, Any],
metadata: Dict[str, Any],
) -> Dict[str, Any]:
from ...agent_cli import (
apply_metadata_updates,
download_preview,
refresh_cache,
)
updated_fields: List[str] = []
preview_downloaded = False
# -- Determine whether this is an HF-sourced model -----------------
is_hf_model = not metadata.get("from_civitai", True)
# -- Collect updates -----------------------------------------------
updates: Dict[str, Any] = {}
# base_model
new_base = (llm_output.get("base_model") or "").strip()
current_base = metadata.get("base_model", "") or ""
if new_base and self._should_overwrite(current_base, is_hf_model):
updates["base_model"] = new_base
# trainedWords / trigger words
new_triggers = llm_output.get("trigger_words", [])
if isinstance(new_triggers, list):
cleaned = [t.strip() for t in new_triggers if isinstance(t, str) and t.strip()]
if cleaned:
current_triggers = metadata.get("trainedWords") or []
if self._should_overwrite_list(current_triggers, is_hf_model):
updates["trainedWords"] = cleaned
# modelDescription
new_desc = (llm_output.get("description") or "").strip()
if new_desc:
current_desc = metadata.get("modelDescription", "") or ""
if self._should_overwrite(current_desc, is_hf_model):
updates["modelDescription"] = new_desc
# tags — merge with existing, deduplicate (case-insensitive)
new_tags = llm_output.get("tags", [])
if isinstance(new_tags, list) and new_tags:
existing_tags = metadata.get("tags") or []
merged = self._merge_tags(existing_tags, new_tags)
if len(merged) > len(existing_tags) or is_hf_model:
updates["tags"] = merged
# metadata_source & llm_enriched_at (always set)
updates["metadata_source"] = "agent:enrich_hf_metadata"
updates["llm_enriched_at"] = datetime.now(timezone.utc).isoformat()
# -- Persist updates ------------------------------------------------
if updates:
updated_fields = await apply_metadata_updates(model_path, updates)
# -- Download preview -----------------------------------------------
preview_url = (llm_output.get("preview_url") or "").strip()
current_preview = metadata.get("preview_url") or ""
if preview_url and not (current_preview and os.path.exists(current_preview)):
preview_downloaded = await download_preview(model_path, preview_url)
# -- Refresh scanner cache ------------------------------------------
if updated_fields or preview_downloaded:
await refresh_cache(model_path)
return {
"success": True,
"updated_fields": updated_fields,
"preview_downloaded": preview_downloaded,
"errors": [],
}
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
@staticmethod
def _should_overwrite(current_value: str, is_hf_model: bool) -> bool:
"""Return ``True`` when a scalar field should be overwritten."""
return is_hf_model or not current_value or current_value.lower() in (
"", "unknown",
)
@staticmethod
def _should_overwrite_list(current_list: List[str], is_hf_model: bool) -> bool:
"""Return ``True`` when a list field should be overwritten."""
return is_hf_model or not current_list
@staticmethod
def _merge_tags(existing: List[str], new: List[str]) -> List[str]:
"""Merge *new* tags into *existing*, all lowercased.
This matches the behaviour of :class:`TagUpdateService` which
normalises every tag to lowercase for case-insensitive dedup.
"""
merged: List[str] = []
seen: set = set()
for tag in list(existing) + list(new):
t = tag.strip().lower() if isinstance(tag, str) else ""
if t and t not in seen:
merged.append(t)
seen.add(t)
return merged
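
The tag-merge rule used by the post-processor above (strip, lowercase, keep first occurrence, preserve order) is easy to check in isolation. The sketch below restates it as a free function; `merge_tags` is a hypothetical stand-alone name used only for this example.

```python
from typing import Iterable, List


def merge_tags(existing: Iterable[str], new: Iterable[str]) -> List[str]:
    merged: List[str] = []
    seen = set()
    for tag in list(existing) + list(new):
        normalised = tag.strip().lower()
        # Skip blanks and anything already seen (case-insensitively).
        if normalised and normalised not in seen:
            merged.append(normalised)
            seen.add(normalised)
    return merged


print(merge_tags(["Anime", "style "], ["anime", "SDXL", " "]))
# → ['anime', 'style', 'sdxl']
```

Note that existing tags are rewritten to lowercase as well, so the merged list can differ from the stored one even when the LLM contributes nothing new.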

@@ -0,0 +1,45 @@
"""Skill definition data structures.
Each skill is described by a :class:`SkillDefinition` that declares its
input/output schemas, whether it needs an LLM call, and what permissions
its post-processor has.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
@dataclass(frozen=True)
class SkillPermissions:
"""Declarative permission scope for a skill's post-processor.
These are auditable, declarative constraints intended to be checked before
a skill's post-processor runs. They are defense-in-depth, not a sandbox.
"""
write_metadata: bool = True
write_previews: bool = True
network_domains: Tuple[str, ...] = ()
@dataclass(frozen=True)
class SkillDefinition:
"""Immutable description of an agent skill."""
name: str
title: str
description: str
llm_required: bool
input_schema: Dict[str, Any] = field(default_factory=dict)
output_schema: Dict[str, Any] = field(default_factory=dict)
model_type_filter: Optional[List[str]] = None
permissions: SkillPermissions = field(default_factory=SkillPermissions)
def applies_to_model_type(self, model_type: str) -> bool:
"""Return ``True`` if this skill can run on the given model type."""
if self.model_type_filter is None:
return True
return model_type in self.model_type_filter
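
The `model_type_filter` semantics above distinguish between "no filter" (`None`, the skill applies everywhere) and an explicit list. A reduced sketch with a hypothetical `DemoSkill` dataclass, mirroring only the fields needed to show the difference:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class DemoSkill:
    name: str
    model_type_filter: Optional[List[str]] = None

    def applies_to_model_type(self, model_type: str) -> bool:
        # None means "no restriction"; a list restricts to its members.
        if self.model_type_filter is None:
            return True
        return model_type in self.model_type_filter
```

One consequence worth keeping in mind: an empty list is not the same as `None`. With `model_type_filter=[]` the skill matches no model type at all.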

@@ -0,0 +1,184 @@
"""Discovery and loading of agent skills.
Skills live in ``py/services/agent/skills/<name>/`` directories. Each
directory must contain:
- ``skill.yaml``: metadata (name, title, description, schemas, permissions)
- ``prompt.md``: LLM prompt template with ``{{variable}}`` placeholders
Post-processing is handled centrally by :class:`PostProcessor`; a per-skill
``handler.py`` is no longer required.
The registry scans the skills directory on first access and caches results.
"""
from __future__ import annotations
import asyncio
import importlib
import importlib.util
import logging
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
import yaml
from .skill_definition import SkillDefinition, SkillPermissions
logger = logging.getLogger(__name__)
# Directory where built-in skills are stored
_SKILLS_DIR = Path(__file__).parent / "skills"
class SkillRegistry:
"""Discover and load agent skills from the filesystem."""
_instance: Optional["SkillRegistry"] = None
_lock: asyncio.Lock = asyncio.Lock()
def __init__(self, skills_dir: Path = _SKILLS_DIR) -> None:
self._skills_dir = skills_dir
self._skills: Dict[str, SkillDefinition] = {}
self._loaded: bool = False
# ------------------------------------------------------------------
# Singleton access
# ------------------------------------------------------------------
@classmethod
async def get_instance(cls) -> "SkillRegistry":
"""Return the lazily-initialised global ``SkillRegistry``."""
if cls._instance is None:
async with cls._lock:
if cls._instance is None:
registry = cls()
registry._discover()
cls._instance = registry
return cls._instance
@classmethod
def reset_instance(cls) -> None:
"""Reset the cached singleton — primarily for tests."""
cls._instance = None
# ------------------------------------------------------------------
# Discovery
# ------------------------------------------------------------------
def _discover(self) -> None:
"""Scan the skills directory and load all valid skill definitions."""
self._skills.clear()
if not self._skills_dir.is_dir():
logger.warning("Skills directory does not exist: %s", self._skills_dir)
self._loaded = True
return
for entry in sorted(self._skills_dir.iterdir()):
if not entry.is_dir():
continue
skill_yaml = entry / "skill.yaml"
if not skill_yaml.exists():
continue
try:
definition = self._load_skill_yaml(skill_yaml)
if definition is not None:
self._skills[definition.name] = definition
logger.debug("Loaded skill: %s", definition.name)
except Exception as exc:
logger.warning("Failed to load skill from %s: %s", skill_yaml, exc)
self._loaded = True
logger.info("Discovered %d agent skills", len(self._skills))
def _load_skill_yaml(self, path: Path) -> Optional[SkillDefinition]:
"""Parse a skill.yaml file into a :class:`SkillDefinition`."""
with open(path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
if not data or "name" not in data:
logger.warning("skill.yaml missing required 'name' field: %s", path)
return None
# Parse permissions
perm_data = data.get("permissions", {})
permissions = SkillPermissions(
write_metadata=perm_data.get("write_metadata", True),
write_previews=perm_data.get("write_previews", True),
network_domains=tuple(perm_data.get("network_domains", [])),
)
return SkillDefinition(
name=data["name"],
title=data.get("title", data["name"]),
description=data.get("description", ""),
llm_required=data.get("llm_required", False),
input_schema=data.get("input_schema", {}),
output_schema=data.get("output_schema", {}),
model_type_filter=data.get("model_type_filter"),
permissions=permissions,
)
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def list_skills(self) -> List[SkillDefinition]:
"""Return all discovered skill definitions."""
if not self._loaded:
self._discover()
return list(self._skills.values())
def get_skill(self, name: str) -> Optional[SkillDefinition]:
"""Return the skill definition for ``name``, or ``None`` if not found."""
if not self._loaded:
self._discover()
return self._skills.get(name)
def load_prompt(self, name: str) -> str:
"""Load and return the prompt template for a skill."""
skill_dir = self._skills_dir / name
prompt_path = skill_dir / "prompt.md"
if not prompt_path.exists():
raise FileNotFoundError(f"Prompt template not found: {prompt_path}")
with open(prompt_path, "r", encoding="utf-8") as f:
return f.read()
def load_handler(self, name: str) -> Dict[str, Callable]:
"""Dynamically import a skill's handler module and return its functions.
Returns a dict that always contains ``post_process`` and, when the handler
defines it, ``prepare`` (optional pre-LLM data gathering).
"""
skill_dir = self._skills_dir / name
handler_path = skill_dir / "handler.py"
if not handler_path.exists():
raise FileNotFoundError(f"Handler not found: {handler_path}")
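# Load the module by file path with importlib.
# Important: use a fully-qualified module name so that ``__package__`` is set
# and relative imports inside the handler resolve against the package hierarchy.
# Absolute imports (e.g. ``from py.utils.metadata_manager import MetadataManager``)
# resolve through ``sys.path`` regardless of the name chosen here.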
module_name = f"py.services.agent.skills.{name}.handler"
spec = importlib.util.spec_from_file_location(module_name, handler_path)
if spec is None or spec.loader is None:
raise ImportError(f"Cannot load handler module from {handler_path}")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
result: Dict[str, Callable] = {}
if hasattr(module, "prepare"):
result["prepare"] = module.prepare
if hasattr(module, "post_process"):
result["post_process"] = module.post_process
else:
raise AttributeError(
f"Skill handler {name} is missing required 'post_process' function"
)
return result
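
The file-path loading used by `load_handler` can be tried outside the project. The following is a minimal sketch of the same importlib recipe, run against a throwaway `handler.py` in a temporary directory. The helper name `load_module_from_path` and the module name `demo_skill_handler` are made up for the example. Unlike the method above, the sketch also registers the module in `sys.modules` before executing it, which is what the importlib documentation recommends when importing a source file directly.

```python
import importlib.util
import sys
import tempfile
from pathlib import Path
from types import ModuleType


def load_module_from_path(module_name: str, path: Path) -> ModuleType:
    """Import a Python source file under the given module name."""
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {path}")
    module = importlib.util.module_from_spec(spec)
    # Register first so the module is visible while its own body runs.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind.
        sys.modules.pop(module_name, None)
        raise
    return module


with tempfile.TemporaryDirectory() as tmp:
    handler_path = Path(tmp) / "handler.py"
    handler_path.write_text(
        "async def post_process(result):\n    return result\n",
        encoding="utf-8",
    )
    handler = load_module_from_path("demo_skill_handler", handler_path)
    has_post_process = hasattr(handler, "post_process")
    has_prepare = hasattr(handler, "prepare")

print(has_post_process, has_prepare)
```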


@@ -0,0 +1 @@
# Agent skills package — each subdirectory is a skill.


@@ -0,0 +1,77 @@
You are an expert assistant for AI image generation models. Your task is to extract structured metadata from a HuggingFace model card (README.md).
## Model Information
- **Repository**: {{hf_url}}
- **Model file path**: {{model_path}}
- **Repository ID**: {{repo}}
## Current Metadata (may be incomplete)
```json
{{current_metadata}}
```
## HuggingFace README Content
```
{{readme_content}}
```
## Extraction Instructions
Extract the following information from the README content above:
### base_model
The base model this LoRA/checkpoint was trained on. Use EXACTLY one of the names from the **Available Base Models** list below. Do not invent new names or use aliases.
Available Base Models:
{{base_models}}
Check the YAML frontmatter (between --- markers) for `base_model:` first, then look at the description text and safetensors metadata. If you cannot determine it, return an empty string.
### trigger_words
The trigger words or activation prompts needed to use this LoRA. Look for:
- `instance_prompt:` in the YAML frontmatter
- Phrases like "trigger word:", "trigger:", "use this prompt:", "activation prompt:"
- Example prompts at the start (usually the first word or phrase before any description)
Return as an array of strings. If none found, return an empty array.
### description
A concise 1-2 sentence summary of what this model does. Extract from the "Model description" section or the first paragraph. Return empty string if the README is too minimal.
### tags
3-8 relevant tags for categorizing this model. Extract from:
- The YAML frontmatter `tags:` list (often contains excellent categorization tags)
- The model type (e.g. "lora", "checkpoint", "flux", "sdxl")
- The style/subject (e.g. "anime", "photorealistic", "style", "character")
All lowercase, no spaces. Return empty array if none found.
### preview_url
The URL of the most suitable preview image from the README. Look for image tags (e.g. `![alt](url)`) and the YAML frontmatter `widget:` section (which often has `output.url` fields). Choose the first image that appears to be a generation example (not a logo or diagram). Construct the absolute URL as `https://huggingface.co/{{repo}}/resolve/main/{filename}`. If no suitable image is found, return an empty string.
### confidence
Your confidence level in the extracted data:
- "high" — most fields were explicitly stated in the README
- "medium" — some fields were inferred from context
- "low" — most fields are guesses based on limited information
## Output Format
Return ONLY a JSON object with exactly these fields (no markdown fences, no extra text):
{
"model_path": "{{model_path}}",
"base_model": "<canonical name or empty string>",
"trigger_words": ["<word1>", "<word2>"],
"description": "<1-2 sentence summary>",
"tags": ["<tag1>", "<tag2>"],
"preview_url": "<image URL or empty string>",
"confidence": "<high|medium|low>"
}
Important:
- Only include the JSON object, no other text
- If a field cannot be determined, use an empty string or empty array
- Do not fabricate information not supported by the README
- For base_model, the YAML frontmatter often has `base_model:` with a HuggingFace repo name like "black-forest-labs/FLUX.1-dev" — map this to "Flux.1 D"
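
The diff does not show how the `{{variable}}` placeholders in this template are filled in. One possible approach, sketched below under that caveat, is a small regex substitution. The function name `render_prompt` is hypothetical and is not part of this commit. The sketch only touches double-brace placeholders, so the literal `{filename}` in the preview URL instruction is left untouched, and a missing variable raises rather than silently leaving a placeholder in the prompt.

```python
import re
from typing import Mapping

# Matches {{name}} with optional inner whitespace; single braces are ignored.
_PLACEHOLDER = re.compile(r"\{\{\s*(\w+)\s*\}\}")


def render_prompt(template: str, variables: Mapping[str, str]) -> str:
    """Replace every {{name}} placeholder with its value (hypothetical helper)."""

    def _substitute(match: re.Match) -> str:
        key = match.group(1)
        if key not in variables:
            raise KeyError(f"Missing prompt variable: {key}")
        return str(variables[key])

    return _PLACEHOLDER.sub(_substitute, template)


template = "Construct `https://huggingface.co/{{repo}}/resolve/main/{filename}`."
rendered = render_prompt(template, {"repo": "black-forest-labs/FLUX.1-dev"})
print(rendered)
```

A full template engine would also work, but it would need the README content and JSON metadata escaped or marked raw, since user-supplied text may itself contain brace sequences.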


@@ -0,0 +1,47 @@
name: enrich_hf_metadata
title: "Enrich Metadata from HuggingFace"
description: >
Parse the HuggingFace model card via LLM to extract description, trigger
words, base model, tags, and preview image URL. Updates .metadata.json
and downloads the preview thumbnail.
llm_required: true
model_type_filter: ["lora", "checkpoint", "embedding"]
input_schema:
type: object
properties:
model_paths:
type: array
items:
type: string
required:
- model_paths
output_schema:
type: object
properties:
model_path:
type: string
base_model:
type: string
trigger_words:
type: array
items:
type: string
description:
type: string
tags:
type: array
items:
type: string
preview_url:
type: string
confidence:
type: string
enum: ["high", "medium", "low"]
required:
- model_path
- confidence
permissions:
write_metadata: true
write_previews: true
network_domains:
- "huggingface.co"
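
To make the `output_schema` above concrete, here is a hand-rolled check of the same constraints using only the standard library. It is a sketch, not the validation this commit performs: the function name `validate_enrichment_output` and the error wording are assumptions, and a JSON Schema validator library would be the more general choice.

```python
from typing import Any, Dict, List

REQUIRED_FIELDS = ("model_path", "confidence")
CONFIDENCE_LEVELS = ("high", "medium", "low")
STRING_FIELDS = ("model_path", "base_model", "description", "preview_url", "confidence")
STRING_LIST_FIELDS = ("trigger_words", "tags")


def validate_enrichment_output(result: Dict[str, Any]) -> List[str]:
    """Return a list of problems; an empty list means the result passed."""
    problems: List[str] = []
    for name in REQUIRED_FIELDS:
        if name not in result:
            problems.append(f"missing required field: {name}")
    for name in STRING_FIELDS:
        if name in result and not isinstance(result[name], str):
            problems.append(f"{name} must be a string")
    for name in STRING_LIST_FIELDS:
        value = result.get(name, [])
        if not isinstance(value, list) or not all(isinstance(v, str) for v in value):
            problems.append(f"{name} must be a list of strings")
    confidence = result.get("confidence")
    if isinstance(confidence, str) and confidence not in CONFIDENCE_LEVELS:
        problems.append(f"confidence must be one of {CONFIDENCE_LEVELS}")
    return problems


good = {"model_path": "example.safetensors", "confidence": "high", "tags": ["lora"]}
bad = {"confidence": "certain", "tags": "anime"}
print(validate_enrichment_output(good))
print(validate_enrichment_output(bad))
```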


@@ -25,3 +25,21 @@ class ResourceNotFoundError(RuntimeError):
pass
class LLMNotConfiguredError(RuntimeError):
"""Raised when an LLM-dependent operation is attempted but no provider is configured."""
pass
class LLMRateLimitError(RateLimitError):
"""Raised when the LLM provider rejects a request due to rate limiting."""
pass
class LLMResponseError(RuntimeError):
"""Raised when the LLM returns an unparseable or schema-invalid response."""
pass

321
py/services/llm_service.py Normal file

@@ -0,0 +1,321 @@
"""Centralized LLM API client with BYOK (bring-your-own-key) provider support.
Reads provider configuration from :class:`SettingsManager` and makes
OpenAI-compatible ``/chat/completions`` calls. Supports any provider that
implements the OpenAI Chat Completions API surface area (OpenAI, Ollama,
vLLM, LM Studio, etc.).
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any, Dict, List, Optional
import aiohttp
from .errors import LLMNotConfiguredError, LLMRateLimitError, LLMResponseError
logger = logging.getLogger(__name__)
# Default API base URLs per provider
_PROVIDER_DEFAULTS: Dict[str, str] = {
"openai": "https://api.openai.com/v1",
"ollama": "http://localhost:11434/v1",
# "custom" requires an explicit llm_api_base from the user
}
# Request timeout for LLM calls (seconds)
_LLM_TIMEOUT = aiohttp.ClientTimeout(total=120)
class LLMService:
"""Centralized LLM API client.
All agent skills call LLMs through this service so that BYOK config,
retry logic, and error handling live in one place.
"""
_instance: Optional["LLMService"] = None
_lock: asyncio.Lock = asyncio.Lock()
def __init__(self, settings_service) -> None:
self._settings = settings_service
# ------------------------------------------------------------------
# Singleton access
# ------------------------------------------------------------------
@classmethod
async def get_instance(cls) -> "LLMService":
"""Return the lazily-initialised global ``LLMService`` instance."""
if cls._instance is None:
async with cls._lock:
if cls._instance is None:
from .settings_manager import get_settings_manager
cls._instance = cls(get_settings_manager())
return cls._instance
@classmethod
def reset_instance(cls) -> None:
"""Reset the cached singleton — primarily for tests."""
cls._instance = None
# ------------------------------------------------------------------
# Configuration helpers
# ------------------------------------------------------------------
def _get_config(self) -> Dict[str, Any]:
"""Read the current LLM configuration from settings."""
return {
"provider": self._settings.get("llm_provider", "openai"),
"api_key": self._settings.get("llm_api_key", ""),
"api_base": self._settings.get("llm_api_base", ""),
"model": self._settings.get("llm_model", ""),
}
def is_configured(self) -> bool:
"""Return ``True`` when the LLM provider is minimally configured.
A provider is considered configured when ``llm_model`` is set and
(for non-Ollama) an API key is configured.
"""
cfg = self._get_config()
has_model = bool(cfg["model"])
has_key = bool(cfg["api_key"]) or cfg["provider"] == "ollama"
return has_model and has_key
def _resolve_api_base(self, provider: str, api_base: str) -> str:
"""Resolve the API base URL for the given provider."""
if api_base:
return api_base.rstrip("/")
return _PROVIDER_DEFAULTS.get(provider, "").rstrip("/")
def _build_headers(self, api_key: str) -> Dict[str, str]:
"""Build HTTP headers for the LLM API request."""
headers = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def _ensure_configured(self) -> Dict[str, Any]:
"""Validate configuration and return it, or raise.
A provider is considered configured when ``llm_model`` is set and
(for non-Ollama) an API key is configured.
"""
cfg = self._get_config()
has_model = bool(cfg["model"])
has_key = bool(cfg["api_key"]) or cfg["provider"] == "ollama"
if not (has_model and has_key):
parts = []
if not has_model:
parts.append("No LLM model specified")
if not has_key and cfg["provider"] != "ollama":
parts.append("No LLM API key configured")
detail = "; ".join(parts) if parts else "LLM provider is not configured"
raise LLMNotConfiguredError(
f"{detail}. Configure it in Settings → AI Provider."
)
return cfg
# ------------------------------------------------------------------
# Core API call
# ------------------------------------------------------------------
async def chat_completion(
self,
*,
messages: List[Dict[str, str]],
model: Optional[str] = None,
temperature: float = 0.3,
response_format: Optional[Dict[str, Any]] = None,
max_tokens: Optional[int] = None,
retry_on_rate_limit: bool = True,
) -> Dict[str, Any]:
"""Call the configured LLM provider's ``/chat/completions`` endpoint.
Args:
messages: OpenAI-format message list
model: Override the configured model name
temperature: Sampling temperature
response_format: Optional ``{"type": "json_object"}`` for structured output
max_tokens: Optional max output tokens
retry_on_rate_limit: Retry once after a 429, sleeping for the ``Retry-After`` value (default 5 s)
Returns:
Dict with ``content`` (str), ``usage`` (dict), ``model`` (str)
Raises:
LLMNotConfiguredError: Model name or API key missing from settings
LLMRateLimitError: Rate limited and retry exhausted
LLMResponseError: Non-200 response or parse failure
"""
cfg = self._ensure_configured()
api_base = self._resolve_api_base(cfg["provider"], cfg["api_base"])
url = f"{api_base}/chat/completions"
model_name = model or cfg["model"]
payload: Dict[str, Any] = {
"model": model_name,
"messages": messages,
"temperature": temperature,
}
if response_format is not None:
payload["response_format"] = response_format
if max_tokens is not None:
payload["max_tokens"] = max_tokens
headers = self._build_headers(cfg["api_key"])
attempt = 0
max_attempts = 2 if retry_on_rate_limit else 1
while attempt < max_attempts:
attempt += 1
try:
async with aiohttp.ClientSession(timeout=_LLM_TIMEOUT) as session:
async with session.post(
url, json=payload, headers=headers
) as resp:
if resp.status == 429:
if attempt < max_attempts:
retry_after = float(
resp.headers.get("Retry-After", "5")
)
logger.warning(
"LLM rate limited, retrying after %.1fs",
retry_after,
)
await asyncio.sleep(retry_after)
continue
raise LLMRateLimitError(
"LLM provider rate limited (HTTP 429)",
provider=cfg["provider"],
)
if resp.status != 200:
body = await resp.text()
raise LLMResponseError(
f"LLM API returned HTTP {resp.status}: "
f"{body[:500]}"
)
data = await resp.json()
except aiohttp.ClientError as exc:
raise LLMResponseError(f"Network error calling LLM API: {exc}") from exc
# Parse response
try:
content = data["choices"][0]["message"]["content"]
usage = data.get("usage", {})
return {
"content": content,
"usage": usage,
"model": data.get("model", model_name),
}
except (KeyError, IndexError) as exc:
raise LLMResponseError(
f"Unexpected LLM response structure: {json.dumps(data)[:500]}"
) from exc
# Should not reach here, but satisfy type checker
raise LLMRateLimitError("Rate limit retry exhausted", provider=cfg["provider"])
# ------------------------------------------------------------------
# Structured output convenience
# ------------------------------------------------------------------
async def chat_completion_json(
self,
*,
system_prompt: str,
user_prompt: str,
model: Optional[str] = None,
temperature: float = 0.3,
max_tokens: Optional[int] = None,
) -> Dict[str, Any]:
"""Call the LLM and return parsed JSON.
Always sends ``response_format: {"type": "json_object"}`` and parses the
response content as JSON. If parsing fails, retries once with a follow-up
user message that asks for a bare JSON object.
Args:
system_prompt: System-level instructions
user_prompt: User-level query
model: Override the configured model name
temperature: Sampling temperature
max_tokens: Optional max output tokens
Returns:
Parsed JSON dict from the LLM response
Raises:
LLMNotConfiguredError: Provider not configured
LLMRateLimitError: Rate limited
LLMResponseError: JSON parse failure after retry
"""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
]
# First attempt with JSON mode
result = await self.chat_completion(
messages=messages,
model=model,
temperature=temperature,
response_format={"type": "json_object"},
max_tokens=max_tokens,
)
try:
return json.loads(result["content"])
except (json.JSONDecodeError, TypeError) as exc:
logger.warning(
"LLM JSON parse failed on first attempt: %s. Retrying.", exc
)
# Retry with explicit instruction to return valid JSON
retry_messages = messages + [
{
"role": "assistant",
"content": result["content"],
},
{
"role": "user",
"content": (
"The previous response could not be parsed as JSON. "
"Please respond with ONLY a valid JSON object, no "
"markdown fences or extra text."
),
},
]
result = await self.chat_completion(
messages=retry_messages,
model=model,
temperature=0.0, # More deterministic for retry
response_format={"type": "json_object"},
max_tokens=max_tokens,
)
try:
return json.loads(result["content"])
except (json.JSONDecodeError, TypeError) as exc:
raise LLMResponseError(
f"LLM response could not be parsed as JSON after retry: {exc}\n"
f"Raw content: {result['content'][:500]}"
) from exc
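
`chat_completion` converts the `Retry-After` header with `float(...)`. HTTP also allows that header to carry a date instead of a number of seconds, in which case `float` raises `ValueError`. The sketch below shows one way to accept both forms. The helper `parse_retry_after`, its 5 second default (taken from the code above), and the 60 second cap are assumptions for illustration, not part of this commit.

```python
import math
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Optional


def parse_retry_after(
    value: Optional[str], default: float = 5.0, cap: float = 60.0
) -> float:
    """Turn a Retry-After header into a bounded sleep duration in seconds."""
    if not value:
        return default
    try:
        seconds = float(value)  # delta-seconds form, e.g. "2"
    except ValueError:
        try:
            target = parsedate_to_datetime(value)  # HTTP-date form
        except (TypeError, ValueError):
            return default
        if target.tzinfo is None:
            target = target.replace(tzinfo=timezone.utc)
        seconds = (target - datetime.now(timezone.utc)).total_seconds()
    if math.isnan(seconds):
        return default
    # Never sleep a negative time, and never longer than the cap.
    return min(max(seconds, 0.0), cap)


print(parse_retry_after("2"))
print(parse_retry_after(None))
print(parse_retry_after("Wed, 21 Oct 2015 07:28:00 GMT"))  # date in the past
```

Capping the delay keeps a misbehaving provider from stalling a skill run for longer than the overall request timeout would suggest.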


@@ -107,6 +107,11 @@ DEFAULT_SETTINGS: Dict[str, Any] = {
"backup_retention_count": 5,
"use_new_license_icons": True,
"group_by_model": False,
# AI / LLM provider configuration (BYOK)
"llm_provider": "openai", # "openai" | "ollama" | "custom"
"llm_api_key": "",
"llm_api_base": "", # empty = provider default
"llm_model": "", # e.g. "gpt-4o-mini"
}
@@ -873,6 +878,23 @@ class SettingsManager:
self.settings["civitai_api_key"] = env_api_key
self._save_settings()
# LLM provider overrides
llm_env_map = {
"LLM_API_KEY": "llm_api_key",
"LLM_MODEL": "llm_model",
"LLM_API_BASE": "llm_api_base",
"LLM_PROVIDER": "llm_provider",
}
llm_changed = False
for env_var, settings_key in llm_env_map.items():
env_val = os.environ.get(env_var)
if env_val:
logger.info("Found %s environment variable", env_var)
self.settings[settings_key] = env_val
llm_changed = True
if llm_changed:
self._save_settings()
def _default_settings_actions(self) -> List[Dict[str, Any]]:
return [
{
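
The override loop in the hunk above can be illustrated as a pure function. This sketch copies the environment-variable map from the diff; the function name `apply_llm_env_overrides` is hypothetical, and a plain mapping stands in for `os.environ` so the example has no side effects.

```python
from typing import Dict, Mapping

LLM_ENV_MAP = {
    "LLM_API_KEY": "llm_api_key",
    "LLM_MODEL": "llm_model",
    "LLM_API_BASE": "llm_api_base",
    "LLM_PROVIDER": "llm_provider",
}


def apply_llm_env_overrides(
    settings: Dict[str, str], environ: Mapping[str, str]
) -> bool:
    """Copy non-empty LLM_* variables into settings; return True if any applied."""
    changed = False
    for env_var, settings_key in LLM_ENV_MAP.items():
        value = environ.get(env_var)
        if value:  # unset and empty variables are both skipped
            settings[settings_key] = value
            changed = True
    return changed


settings = {"llm_provider": "openai", "llm_model": "", "llm_api_key": ""}
fake_environ = {"LLM_PROVIDER": "ollama", "LLM_MODEL": ""}
changed = apply_llm_env_overrides(settings, fake_environ)
print(changed, settings)
```

As in the diff, an override counts as a change even when the value matches what is already stored, and an empty variable is ignored rather than clearing the setting. Because the settings are saved afterwards, a key supplied through `LLM_API_KEY` ends up in the settings file.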


@@ -35,6 +35,9 @@ class BaseModelMetadata:
metadata_source: Optional[str] = None # Last provider that supplied metadata
last_checked_at: float = 0 # Last checked timestamp
hash_status: str = "completed" # Hash calculation status: pending | calculating | completed | failed
trainedWords: List[str] = field(
default_factory=list
) # Trigger words / activation prompts (source-agnostic)
_unknown_fields: Dict[str, Any] = field(
default_factory=dict, repr=False, compare=False
) # Store unknown fields
@@ -47,6 +50,9 @@ class BaseModelMetadata:
if self.tags is None:
self.tags = []
if self.trainedWords is None:
self.trainedWords = []
@classmethod
def from_dict(cls, data: Dict) -> "BaseModelMetadata":
"""Create instance from dictionary"""


@@ -274,6 +274,9 @@ export class BulkContextMenu extends BaseContextMenu {
case 'resume-metadata-refresh':
bulkManager.setSkipMetadataRefresh(false);
break;
case 'enrich-hf-agent-bulk':
this.enrichBulkWithAgent();
break;
case 'delete-all':
bulkManager.showBulkDeleteModal();
break;
@@ -363,4 +366,66 @@ export class BulkContextMenu extends BaseContextMenu {
console.error('Bulk download example images failed:', error);
}
}
/**
* Enrich metadata for selected models via LLM agent skill.
*/
async enrichBulkWithAgent() {
if (state.selectedModels.size === 0) {
return;
}
const { agentManager } = await import('../../managers/AgentManager.js');
// Check if LLM is configured
const configured = await agentManager.isLlmConfigured();
if (!configured) {
showToast('toast.agent.llmNotConfigured', {}, 'warning');
return;
}
const modelPaths = [...state.selectedModels];
// Connect WebSocket for progress
agentManager.connect();
// Set up one-time completion handler
const onComplete = (data) => {
const idx = agentManager.completeCallbacks.indexOf(onComplete);
if (idx >= 0) agentManager.completeCallbacks.splice(idx, 1);
if (data.status === 'completed') {
showToast(
'toast.agent.enrichComplete',
{ summary: data.summary || 'Done' },
'success'
);
// Full page reload to reflect updated metadata
window.location.reload();
} else if (data.status === 'error') {
showToast(
'toast.agent.enrichFailed',
{ error: data.error || 'Unknown error' },
'error'
);
}
};
agentManager.onComplete(onComplete);
showToast(
'toast.agent.enrichStarted',
{ count: modelPaths.length },
'info'
);
try {
await agentManager.executeSkill('enrich_hf_metadata', modelPaths);
} catch (error) {
showToast(
'toast.agent.enrichFailed',
{ error: error.message },
'error'
);
}
}
}


@@ -1,7 +1,7 @@
import { BaseContextMenu } from './BaseContextMenu.js';
import { ModelContextMenuMixin } from './ModelContextMenuMixin.js';
import { getModelApiClient, resetAndReload } from '../../api/modelApiFactory.js';
import { copyLoraSyntax, sendLoraToWorkflow, buildLoraSyntax } from '../../utils/uiHelpers.js';
import { copyLoraSyntax, sendLoraToWorkflow, buildLoraSyntax, showToast } from '../../utils/uiHelpers.js';
import { showExcludeModal, showDeleteModal } from '../../utils/modalUtils.js';
import { moveManager } from '../../managers/MoveManager.js';
@@ -63,6 +63,9 @@ export class LoraContextMenu extends BaseContextMenu {
case 'refresh-metadata':
getModelApiClient().refreshSingleModelMetadata(this.currentCard.dataset.filepath);
break;
case 'enrich-hf-agent':
this.enrichWithAgent(this.currentCard.dataset.filepath);
break;
case 'exclude':
showExcludeModal(this.currentCard.dataset.filepath);
break;
@@ -72,6 +75,46 @@ export class LoraContextMenu extends BaseContextMenu {
}
}
async enrichWithAgent(filePath) {
const { agentManager } = await import('../../managers/AgentManager.js');
// Check if LLM is configured
const configured = await agentManager.isLlmConfigured();
if (!configured) {
showToast('toast.agent.llmNotConfigured', {}, 'warning');
return;
}
// Connect WebSocket for progress
agentManager.connect();
// Set up one-time completion handler
const onComplete = (data) => {
const idx = agentManager.completeCallbacks.indexOf(onComplete);
if (idx >= 0) agentManager.completeCallbacks.splice(idx, 1);
if (data.status === 'completed') {
showToast('toast.agent.enrichComplete', { summary: data.summary || 'Done' }, 'success');
// Soft reload to reflect updated metadata
if (typeof resetAndReload === 'function') {
resetAndReload();
}
} else if (data.status === 'error') {
showToast('toast.agent.enrichFailed', { error: data.error || 'Unknown error' }, 'error');
}
};
agentManager.onComplete(onComplete);
// Show progress toast
showToast('toast.agent.enrichStarted', {}, 'info');
try {
await agentManager.executeSkill('enrich_hf_metadata', [filePath]);
} catch (error) {
showToast('toast.agent.enrichFailed', { error: error.message }, 'error');
}
}
sendLoraToWorkflow(replaceMode) {
const card = this.currentCard;
const usageTips = JSON.parse(card.dataset.usage_tips || '{}');


@@ -0,0 +1,196 @@
/**
* AgentManager — WebSocket listener for agent skill progress events.
*
* Connects to the generic WebSocket endpoint and filters for
* `type: "agent_progress"` messages. Dispatches progress, completion,
* and error events to registered callbacks.
*/
class AgentManager {
constructor() {
this.websocket = null;
this.progressCallbacks = [];
this.completeCallbacks = [];
this.errorCallbacks = [];
this.connected = false;
}
/**
* Connect to the WebSocket endpoint for agent progress events.
* Safe to call multiple times — won't reconnect if already connected.
*/
connect() {
if (this.connected && this.websocket?.readyState === WebSocket.OPEN) {
return;
}
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
try {
this.websocket = new WebSocket(
`${wsProtocol}${window.location.host}/ws/fetch-progress`
);
} catch (e) {
console.error('AgentManager: Failed to create WebSocket:', e);
return;
}
this.websocket.onopen = () => {
this.connected = true;
console.debug('AgentManager: WebSocket connected');
};
this.websocket.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
if (data.type !== 'agent_progress') return;
this._dispatch(data);
} catch (e) {
// Not JSON or wrong format — ignore
}
};
this.websocket.onerror = (error) => {
console.error('AgentManager: WebSocket error:', error);
this.connected = false;
};
this.websocket.onclose = () => {
this.connected = false;
console.debug('AgentManager: WebSocket closed');
};
}
/**
* Dispatch a parsed agent event to the appropriate callbacks.
* @param {Object} data - The parsed WebSocket message
*/
_dispatch(data) {
const { status, skill } = data;
if (status === 'error') {
this.errorCallbacks.forEach((cb) => {
try {
cb(data);
} catch (e) {
console.error('AgentManager error callback failed:', e);
}
});
return;
}
if (status === 'completed') {
this.completeCallbacks.forEach((cb) => {
try {
cb(data);
} catch (e) {
console.error('AgentManager complete callback failed:', e);
}
});
return;
}
// started, processing — general progress
this.progressCallbacks.forEach((cb) => {
try {
cb(data);
} catch (e) {
console.error('AgentManager progress callback failed:', e);
}
});
}
/**
* Register a callback for progress events (started, processing).
* @param {Function} callback - Receives the event data
*/
onProgress(callback) {
this.progressCallbacks.push(callback);
}
/**
* Register a callback for completion events.
* @param {Function} callback - Receives the event data
*/
onComplete(callback) {
this.completeCallbacks.push(callback);
}
/**
* Register a callback for error events.
* @param {Function} callback - Receives the event data
*/
onError(callback) {
this.errorCallbacks.push(callback);
}
/**
* Clear all registered callbacks.
*/
clearCallbacks() {
this.progressCallbacks = [];
this.completeCallbacks = [];
this.errorCallbacks = [];
}
/**
* Execute an agent skill on the provided model paths.
*
* @param {string} skillName - The skill to execute
* @param {string[]} modelPaths - Model file paths to process
* @returns {Promise<Object>} The response JSON
*/
async executeSkill(skillName, modelPaths) {
const response = await fetch(
`/api/lm/agent/execute/${encodeURIComponent(skillName)}`,
{
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ model_paths: modelPaths }),
}
);
if (!response.ok) {
const errorData = await response.json().catch(() => ({}));
throw new Error(
errorData.error || `HTTP ${response.status}: ${response.statusText}`
);
}
return response.json();
}
/**
* Check if the LLM provider is configured.
*
* Returns true when a model name is set and either an API key is set or
* the provider is Ollama (which needs no key).
*
* @returns {Promise<boolean>}
*/
async isLlmConfigured() {
try {
const response = await fetch('/api/lm/settings');
if (!response.ok) return false;
const data = await response.json();
const provider = data.settings?.llm_provider;
const hasModel = !!data.settings?.llm_model;
const hasKey = !!data.settings?.llm_api_key;
return hasModel && (hasKey || provider === 'ollama');
} catch {
return false;
}
}
/**
* Get the list of available agent skills.
*
* @returns {Promise<Array>}
*/
async listSkills() {
const response = await fetch('/api/lm/agent/skills');
if (!response.ok) return [];
const data = await response.json();
return data.skills || [];
}
}
// Export as singleton
export const agentManager = new AgentManager();
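
For scripting or testing outside the browser, the request that `executeSkill` sends can be reproduced from Python. The sketch below only builds the request and does not send it. The path and JSON body come from the code above; the base URL `http://127.0.0.1:8188`, the model path, and the helper name `build_execute_request` are assumptions, and `quote(..., safe="")` only roughly mirrors `encodeURIComponent`.

```python
import json
from typing import List
from urllib.parse import quote
from urllib.request import Request


def build_execute_request(
    base_url: str, skill_name: str, model_paths: List[str]
) -> Request:
    """Build the POST request for the agent execute endpoint (not sent here)."""
    url = f"{base_url.rstrip('/')}/api/lm/agent/execute/{quote(skill_name, safe='')}"
    body = json.dumps({"model_paths": model_paths}).encode("utf-8")
    return Request(
        url,
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_execute_request(
    "http://127.0.0.1:8188",  # assumed host and port
    "enrich_hf_metadata",
    ["/models/loras/example.safetensors"],  # illustrative path
)
print(request.get_method(), request.full_url)
```

Passing the result to `urllib.request.urlopen` would send it. Progress and completion are reported separately over the WebSocket as `agent_progress` messages, so the HTTP response alone does not signal that enrichment has finished.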


@@ -827,6 +827,23 @@ export class SettingsManager {
// Update API key status display (do NOT pre-fill the input)
this.updateApiKeyStatus();
this.updateLlmApiKeyStatus();
// AI Provider settings
const llmProviderSelect = document.getElementById('llmProvider');
if (llmProviderSelect) {
llmProviderSelect.value = state.global.settings.llm_provider || 'openai';
}
const llmApiBaseInput = document.getElementById('llmApiBase');
if (llmApiBaseInput) {
llmApiBaseInput.value = state.global.settings.llm_api_base || '';
}
const llmModelInput = document.getElementById('llmModel');
if (llmModelInput) {
llmModelInput.value = state.global.settings.llm_model || '';
}
const civitaiHostSelect = document.getElementById('civitaiHost');
if (civitaiHostSelect) {
@@ -2931,42 +2948,70 @@ export class SettingsManager {
}
}
editApiKey() {
const statusEl = document.getElementById('civitaiApiKeyStatus');
updateLlmApiKeyStatus() {
const hasKey = !!state.global.settings.llm_api_key;
const statusText = document.getElementById('llmApiKeyStatusText');
const actionBtn = document.getElementById('llmApiKeyActionBtn');
if (!statusText || !actionBtn) return;
if (hasKey) {
statusText.classList.remove('api-key-status--unconfigured');
statusText.classList.add('api-key-status--configured');
statusText.innerHTML = '<i class="fas fa-check-circle text-success"></i> '
+ translate('settings.aiProvider.apiKeyConfigured', {}, 'Configured');
actionBtn.textContent = translate('common.actions.change', {}, 'Change');
} else {
statusText.classList.remove('api-key-status--configured');
statusText.classList.add('api-key-status--unconfigured');
statusText.innerHTML = '<i class="fas fa-times-circle text-error"></i> '
+ translate('settings.aiProvider.apiKeyNotSet', {}, 'Not set');
actionBtn.textContent = translate('settings.aiProvider.apiKeySet', {}, 'Set up');
}
}
+    editApiKey(settingsKey = 'civitai_api_key', inputId = 'civitaiApiKey') {
+        const statusId = inputId + 'Status';
+        const editId = inputId + 'Edit';
+        const statusEl = document.getElementById(statusId);
         if (statusEl) statusEl.classList.add('is-hidden');
-        const editContainer = document.getElementById('civitaiApiKeyEdit');
+        const editContainer = document.getElementById(editId);
         if (editContainer) editContainer.classList.remove('is-hidden');
         // Focus the input
-        const input = document.getElementById('civitaiApiKey');
+        const input = document.getElementById(inputId);
         if (input) {
             input.value = ''; // Never pre-fill the secret
             setTimeout(() => input.focus(), 50);
         }
     }
-    cancelEditApiKey(silent) {
-        const editContainer = document.getElementById('civitaiApiKeyEdit');
+    cancelEditApiKey(silent, inputId = 'civitaiApiKey') {
+        const editId = inputId + 'Edit';
+        const statusId = inputId + 'Status';
+        const editContainer = document.getElementById(editId);
         if (editContainer) editContainer.classList.add('is-hidden');
-        const statusContainer = document.getElementById('civitaiApiKeyStatus');
+        const statusContainer = document.getElementById(statusId);
         if (statusContainer) statusContainer.classList.remove('is-hidden');
         // Clear any typed value
-        const input = document.getElementById('civitaiApiKey');
+        const input = document.getElementById(inputId);
         if (input) input.value = '';
         if (!silent) {
-            this.updateApiKeyStatus();
+            if (inputId === 'civitaiApiKey') {
+                this.updateApiKeyStatus();
+            }
         }
     }
-    async saveApiKey() {
-        const input = document.getElementById('civitaiApiKey');
+    async saveApiKey(settingsKey = 'civitai_api_key', inputId = 'civitaiApiKey') {
+        const input = document.getElementById(inputId);
         if (!input) return;
         const value = input.value.trim();
         try {
-            await this.saveSetting('civitai_api_key', value);
+            await this.saveSetting(settingsKey, value);
+            const labelName = settingsKey === 'civitai_api_key' ? 'CivitAI API Key' : 'LLM API Key';
             showToast('toast.settings.settingsUpdated',
-                { setting: 'CivitAI API Key' }, 'success');
+                { setting: labelName }, 'success');
         } catch (error) {
             showToast('toast.settings.settingSaveFailed',
                 { message: error.message }, 'error');
@@ -2974,9 +3019,13 @@ export class SettingsManager {
         }
         // Update the in-memory flag so the UI reflects the change
-        state.global.settings.civitai_api_key_set = !!value;
-        this.cancelEditApiKey(true);
-        this.updateApiKeyStatus();
+        if (settingsKey === 'civitai_api_key') {
+            state.global.settings.civitai_api_key_set = !!value;
+        }
+        this.cancelEditApiKey(true, inputId);
+        if (inputId === 'civitaiApiKey') {
+            this.updateApiKeyStatus();
+        }
     }
     toggleInputVisibility(button) {


@@ -55,6 +55,10 @@ const DEFAULT_SETTINGS_BASE = Object.freeze({
strip_lora_on_copy: false,
use_new_license_icons: true,
group_by_model: false,
llm_provider: 'openai',
llm_api_key: '',
llm_api_base: '',
llm_model: '',
});
export function createDefaultSettings() {


@@ -12,6 +12,9 @@
<div class="context-menu-item" data-action="check-updates">
<i class="fas fa-bell"></i> <span>{{ t('loras.contextMenu.checkUpdates') }}</span>
</div>
<div class="context-menu-item" data-action="enrich-hf-agent">
<i class="fas fa-wand-magic-sparkles"></i> <span>{{ t('loras.contextMenu.enrichHfAgent') }}</span>
</div>
<div class="context-menu-item" data-action="relink-civitai">
<i class="fas fa-link"></i> <span>{{ t('loras.contextMenu.relinkCivitai') }}</span>
</div>
@@ -83,6 +86,9 @@
<div class="context-menu-item" data-action="resume-metadata-refresh">
<i class="fas fa-redo"></i> <span>{{ t('loras.bulkOperations.resumeMetadataRefresh') }}</span>
</div>
<div class="context-menu-item" data-action="enrich-hf-agent-bulk">
<i class="fas fa-wand-magic-sparkles"></i> <span>{{ t('loras.bulkOperations.enrichHfAgent') }}</span>
</div>
</div>
<div class="context-menu-section" data-section="workflow">
<div class="context-menu-section-header">{{ t('loras.bulkOperations.sections.workflow') }}</div>


@@ -144,6 +144,96 @@
</div>
</div>
<!-- AI Provider Configuration (BYOK) -->
<div class="settings-subsection">
<div class="settings-subsection-header">
<h4>{{ t('settings.aiProvider.title') }}</h4>
</div>
<div class="setting-item">
<div class="setting-row">
<div class="setting-info">
<label for="llmProvider">{{ t('settings.aiProvider.provider') }}</label>
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.aiProvider.providerHelp') }}"></i>
</div>
<div class="setting-control select-control">
<select id="llmProvider" onchange="settingsManager.saveSelectSetting('llmProvider', 'llm_provider')">
<option value="openai">OpenAI</option>
<option value="ollama">Ollama (local)</option>
<option value="custom">{{ t('settings.aiProvider.custom') }}</option>
</select>
</div>
</div>
</div>
<div class="setting-item">
<div class="setting-row">
<div class="setting-info">
<label for="llmApiBase">{{ t('settings.aiProvider.apiBase') }}</label>
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.aiProvider.apiBaseHelp') }}"></i>
</div>
<div class="setting-control">
<div class="text-input-wrapper">
<input type="text" id="llmApiBase"
value="{{ settings.get('llm_api_base', '') }}"
placeholder="{{ t('settings.aiProvider.apiBasePlaceholder') }}"
onblur="settingsManager.saveInputSetting('llmApiBase', 'llm_api_base')"
onkeydown="if(event.key === 'Enter') { this.blur(); }" />
</div>
</div>
</div>
</div>
<div class="setting-item api-key-item">
<div class="setting-row">
<div class="setting-info">
<label>{{ t('settings.aiProvider.apiKey') }}</label>
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.aiProvider.apiKeyHelp') }}"></i>
</div>
<div class="setting-control">
<div id="llmApiKeyStatus" class="api-key-status">
<span id="llmApiKeyStatusText" class="api-key-status-text api-key-status--unconfigured">
<i class="fas fa-times-circle text-error"></i>
{{ t('settings.aiProvider.apiKeyNotSet') }}
</span>
<button type="button" class="secondary-btn" id="llmApiKeyActionBtn" onclick="settingsManager.editApiKey('llm_api_key', 'llmApiKey')">
{{ t('settings.aiProvider.apiKeySet') }}
</button>
</div>
<div id="llmApiKeyEdit" class="api-key-edit is-hidden">
<div class="api-key-input">
<input type="text"
id="llmApiKey"
class="api-key-masked"
placeholder="{{ t('settings.aiProvider.apiKeyPlaceholder') }}"
autocomplete="off"
data-mask="css" />
<button type="button" class="toggle-visibility">
<i class="fas fa-eye"></i>
</button>
</div>
<button type="button" class="primary-btn" onclick="settingsManager.saveApiKey('llm_api_key', 'llmApiKey')">{{ t('common.actions.save') }}</button>
<button type="button" class="secondary-btn" onclick="settingsManager.cancelEditApiKey(true, 'llmApiKey')">{{ t('common.actions.cancel') }}</button>
</div>
</div>
</div>
</div>
<div class="setting-item">
<div class="setting-row">
<div class="setting-info">
<label for="llmModel">{{ t('settings.aiProvider.model') }}</label>
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.aiProvider.modelHelp') }}"></i>
</div>
<div class="setting-control">
<div class="text-input-wrapper">
<input type="text" id="llmModel"
value="{{ settings.get('llm_model', '') }}"
placeholder="e.g. gpt-4o-mini"
onblur="settingsManager.saveInputSetting('llmModel', 'llm_model')"
onkeydown="if(event.key === 'Enter') { this.blur(); }" />
</div>
</div>
</div>
</div>
</div>
<div class="settings-subsection">
<div class="settings-subsection-header">
<h4>{{ t('settings.sections.downloads') }}</h4>


@@ -0,0 +1,317 @@
"""Tests for the AgentCLI module (py/agent_cli/).
All tests mock the underlying services (scanner, MetadataManager, downloader)
since the AgentCLI is a thin delegation layer.
Mock targets must match where imports are resolved inside each function
(lazy imports via ``from X import Y`` inside function body).
"""
from __future__ import annotations
from unittest import mock
import pytest
from py.agent_cli import (
list_base_models,
read_metadata,
apply_metadata_updates,
download_preview,
refresh_cache,
)
# ======================================================================
# Helpers
# ======================================================================
class MockCache:
def __init__(self, raw_data: list[dict] | None = None):
self.raw_data = raw_data or []
class MockScanner:
"""Simulates a ModelScanner for testing."""
def __init__(self, raw_data: list[dict] | None = None):
self._raw_data = raw_data or []
self.update_single_model_cache = mock.AsyncMock(return_value=True)
async def get_cached_data(self):
return MockCache(self._raw_data)
# ======================================================================
# list_base_models -- imports ServiceRegistry internally
# ======================================================================
class TestListBaseModels:
@pytest.mark.asyncio
async def test_empty_cache(self):
scanner = MockScanner([])
with mock.patch(
"py.services.service_registry.ServiceRegistry",
get_lora_scanner=mock.AsyncMock(return_value=scanner),
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
get_embedding_scanner=mock.AsyncMock(return_value=None),
):
result = await list_base_models()
assert result == []
@pytest.mark.asyncio
async def test_merges_all_scanners(self):
lora_scanner = MockScanner([
{"base_model": "SDXL 1.0"},
{"base_model": "Flux.1 D"},
{"base_model": "SDXL 1.0"},
])
ckpt_scanner = MockScanner([
{"base_model": "SDXL 1.0"},
{"base_model": "SD 1.5"},
])
with mock.patch(
"py.services.service_registry.ServiceRegistry",
get_lora_scanner=mock.AsyncMock(return_value=lora_scanner),
get_checkpoint_scanner=mock.AsyncMock(return_value=ckpt_scanner),
get_embedding_scanner=mock.AsyncMock(return_value=None),
):
result = await list_base_models()
assert result == ["SDXL 1.0", "Flux.1 D", "SD 1.5"]
@pytest.mark.asyncio
async def test_limit(self):
scanner = MockScanner([
{"base_model": "A"}, {"base_model": "B"}, {"base_model": "C"},
])
with mock.patch(
"py.services.service_registry.ServiceRegistry",
get_lora_scanner=mock.AsyncMock(return_value=scanner),
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
get_embedding_scanner=mock.AsyncMock(return_value=None),
):
result = await list_base_models(limit=2)
assert result == ["A", "B"]
@pytest.mark.asyncio
async def test_all_scanners_return_none(self):
with mock.patch(
"py.services.service_registry.ServiceRegistry",
get_lora_scanner=mock.AsyncMock(return_value=None),
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
get_embedding_scanner=mock.AsyncMock(return_value=None),
):
result = await list_base_models()
assert result == []
@pytest.mark.asyncio
async def test_skips_empty_or_missing_base_model(self):
scanner = MockScanner([
{"base_model": "SDXL 1.0"},
{"file_name": "foo.safetensors"}, # no base_model key
{"base_model": ""}, # empty
])
with mock.patch(
"py.services.service_registry.ServiceRegistry",
get_lora_scanner=mock.AsyncMock(return_value=scanner),
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
get_embedding_scanner=mock.AsyncMock(return_value=None),
):
result = await list_base_models()
assert result == ["SDXL 1.0"]
# ======================================================================
# read_metadata -- imports MetadataManager from py.utils.metadata_manager
# ======================================================================
class TestReadMetadata:
@pytest.mark.asyncio
async def test_delegates_to_metadata_manager(self):
fake = {"file_name": "test", "base_model": "SDXL 1.0"}
with mock.patch("py.utils.metadata_manager.MetadataManager") as mm:
mm.load_metadata_payload = mock.AsyncMock(return_value=fake)
result = await read_metadata("/p.safetensors")
assert result == fake
@pytest.mark.asyncio
async def test_exception_returns_empty_dict(self):
with mock.patch("py.utils.metadata_manager.MetadataManager") as mm:
mm.load_metadata_payload = mock.AsyncMock(side_effect=ValueError("x"))
result = await read_metadata("/p.safetensors")
assert result == {}
@pytest.mark.asyncio
async def test_none_coerces_to_empty_dict(self):
with mock.patch("py.utils.metadata_manager.MetadataManager") as mm:
mm.load_metadata_payload = mock.AsyncMock(return_value=None)
result = await read_metadata("/p.safetensors")
assert result == {}
# ======================================================================
# apply_metadata_updates -- uses read_metadata + MetadataManager.save_metadata
# ======================================================================
class TestApplyMetadataUpdates:
@pytest.mark.asyncio
async def test_updates_field(self):
with (
mock.patch("py.agent_cli.read_metadata") as mock_read,
mock.patch("py.utils.metadata_manager.MetadataManager") as mm,
):
mock_read.return_value = {"base_model": "", "tags": []}
mm.save_metadata = mock.AsyncMock(return_value=True)
updated = await apply_metadata_updates(
"/p.safetensors", {"base_model": "Flux.1 D"}
)
assert updated == ["base_model"]
mm.save_metadata.assert_awaited_once_with(
"/p.safetensors", {"base_model": "Flux.1 D", "tags": []},
)
@pytest.mark.asyncio
async def test_noop_when_value_unchanged(self):
with (
mock.patch("py.agent_cli.read_metadata") as mock_read,
mock.patch("py.utils.metadata_manager.MetadataManager") as mm,
):
mock_read.return_value = {"base_model": "Flux.1 D"}
updated = await apply_metadata_updates(
"/p.safetensors", {"base_model": "Flux.1 D"}
)
assert updated == []
mm.save_metadata.assert_not_called()
@pytest.mark.asyncio
async def test_multiple_fields(self):
with (
mock.patch("py.agent_cli.read_metadata") as mock_read,
mock.patch("py.utils.metadata_manager.MetadataManager") as mm,
):
mm.save_metadata = mock.AsyncMock(return_value=True)
mock_read.return_value = {
"base_model": "", "modelDescription": "", "tags": [],
}
updated = await apply_metadata_updates(
"/p.safetensors",
{"base_model": "SDXL 1.0", "modelDescription": "A", "tags": ["flux"]},
)
assert sorted(updated) == sorted(["base_model", "modelDescription", "tags"])
saved = mm.save_metadata.call_args[0][1]
assert saved["base_model"] == "SDXL 1.0"
@pytest.mark.asyncio
async def test_empty_updates_noop(self):
with (
mock.patch("py.agent_cli.read_metadata"),
mock.patch("py.utils.metadata_manager.MetadataManager") as mm,
):
updated = await apply_metadata_updates("/p.safetensors", {})
assert updated == []
mm.save_metadata.assert_not_called()
# ======================================================================
# download_preview -- imports get_downloader + ExifUtils
# ======================================================================
class TestDownloadPreview:
@pytest.mark.asyncio
async def test_empty_url_returns_false(self, tmp_path):
mp = tmp_path / "m.safetensors"
mp.write_bytes(b"fake")
assert await download_preview(str(mp), "") is False
assert await download_preview(str(mp), " ") is False
@pytest.mark.asyncio
async def test_successful_download_and_optimise(self, tmp_path):
mp = tmp_path / "t.safetensors"
mp.write_bytes(b"fake")
with (
mock.patch("py.services.downloader.get_downloader") as get_dl,
mock.patch("py.utils.exif_utils.ExifUtils") as exif,
):
dl = mock.AsyncMock()
dl.download_to_memory = mock.AsyncMock(return_value=(True, b"raw", {}))
get_dl.return_value = dl
exif.optimize_image.return_value = (b"optimized_webp", {})
result = await download_preview(str(mp), "https://ex.com/i.png")
assert result is True
assert (tmp_path / "t.webp").exists()
assert (tmp_path / "t.webp").read_bytes() == b"optimized_webp"
@pytest.mark.asyncio
async def test_download_failure_returns_false(self, tmp_path):
mp = tmp_path / "t.safetensors"
mp.write_bytes(b"fake")
with mock.patch("py.services.downloader.get_downloader") as get_dl:
dl = mock.AsyncMock()
dl.download_to_memory = mock.AsyncMock(return_value=(False, None, {}))
dl.download_file = mock.AsyncMock(return_value=(False, None))
get_dl.return_value = dl
result = await download_preview(str(mp), "https://ex.com/i.png")
assert result is False
assert not (tmp_path / "t.webp").exists()
# ======================================================================
# refresh_cache -- uses _find_scanner_for_model (ServiceRegistry)
# ======================================================================
class TestRefreshCache:
@pytest.mark.asyncio
async def test_found_and_refreshed(self):
scanner = MockScanner([{"file_path": "/some/path.safetensors"}])
with (
mock.patch(
"py.services.service_registry.ServiceRegistry",
get_lora_scanner=mock.AsyncMock(return_value=scanner),
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
get_embedding_scanner=mock.AsyncMock(return_value=None),
),
mock.patch("py.agent_cli.read_metadata") as mock_read,
):
mock_read.return_value = {"base_model": "SDXL 1.0"}
result = await refresh_cache("/some/path.safetensors")
assert result is True
scanner.update_single_model_cache.assert_awaited_once()
@pytest.mark.asyncio
async def test_not_found_in_any_scanner(self):
scanner = MockScanner([])
with mock.patch(
"py.services.service_registry.ServiceRegistry",
get_lora_scanner=mock.AsyncMock(return_value=scanner),
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
get_embedding_scanner=mock.AsyncMock(return_value=None),
):
result = await refresh_cache("/nonexistent/path.safetensors")
assert result is False
@pytest.mark.asyncio
async def test_no_metadata_returns_false(self):
scanner = MockScanner([{"file_path": "/some/path.safetensors"}])
with (
mock.patch(
"py.services.service_registry.ServiceRegistry",
get_lora_scanner=mock.AsyncMock(return_value=scanner),
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
get_embedding_scanner=mock.AsyncMock(return_value=None),
),
mock.patch("py.agent_cli.read_metadata") as mock_read,
):
mock_read.return_value = {}
result = await refresh_cache("/some/path.safetensors")
assert result is False
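# ======================================================================
# Illustrative sketch -- why the patches above target the defining module
# ======================================================================
# A minimal, self-contained demonstration using only the standard library.
# It assumes, as the module docstring states, that AgentCLI functions import
# their collaborators inside the function body. The helper names below are
# hypothetical and exist only for this sketch; json.dumps stands in for a
# real service such as MetadataManager.
from json import dumps as _eager_dumps  # bound once, when this module loads
from unittest import mock


def _lazy_dumps(value):
    # Lazy import: the name is looked up on the json module at call time,
    # so a patch applied to "json.dumps" is visible here.
    from json import dumps
    return dumps(value)


def _eager_dumps_call(value):
    # Eager import: this module already holds a reference to the original
    # function, so patching "json.dumps" does not affect it.
    return _eager_dumps(value)


def test_patch_reaches_lazy_import_only():
    with mock.patch("json.dumps", return_value="patched"):
        assert _lazy_dumps({"a": 1}) == "patched"
        assert _eager_dumps_call({"a": 1}) == '{"a": 1}'
    # Outside the context manager the original function is restored.
    assert _lazy_dumps({"a": 1}) == '{"a": 1}'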


@@ -0,0 +1,237 @@
"""Tests for the LLMService."""
from __future__ import annotations
import asyncio
import json
from unittest import mock
import pytest
from py.services.llm_service import LLMService
from py.services.errors import LLMNotConfiguredError, LLMRateLimitError, LLMResponseError
class MockSettings:
"""Minimal settings mock for LLMService tests."""
def __init__(self, **kwargs):
self._data = {
"llm_enabled": False,
"llm_provider": "openai",
"llm_api_key": "",
"llm_api_base": "",
"llm_model": "",
}
self._data.update(kwargs)
def get(self, key, default=None):
return self._data.get(key, default)
class MockResponse:
"""Mock aiohttp response."""
def __init__(self, status, json_data=None, text_data="", headers=None):
self.status = status
self._json_data = json_data
self._text_data = text_data
self.headers = headers or {}
async def json(self):
return self._json_data
async def text(self):
return self._text_data
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
class MockSession:
"""Mock aiohttp ClientSession."""
def __init__(self, response):
self._response = response
self.closed = False
def post(self, url, json=None, headers=None):
self.last_url = url
self.last_json = json
self.last_headers = headers
return self._response
async def __aenter__(self):
return self
async def __aexit__(self, *args):
pass
@pytest.fixture
def llm_service():
"""Create an LLMService with mock settings."""
LLMService.reset_instance()
settings = MockSettings(
llm_enabled=True,
llm_provider="openai",
llm_api_key="sk-test-key",
llm_api_base="",
llm_model="gpt-4o-mini",
)
return LLMService(settings)
class TestLLMServiceConfiguration:
def test_is_configured_when_enabled_with_key_and_model(self, llm_service):
assert llm_service.is_configured() is True
def test_configured_when_toggle_off_but_key_and_model_set(self):
settings = MockSettings(
llm_enabled=False, llm_api_key="sk-test", llm_model="gpt-4o"
)
service = LLMService(settings)
# Lenient: model + API key is treated as configured even without
# the toggle, because the user clearly intends to use the feature.
assert service.is_configured() is True
def test_not_configured_without_model(self):
settings = MockSettings(llm_enabled=True, llm_api_key="sk-test", llm_model="")
service = LLMService(settings)
assert service.is_configured() is False
def test_not_configured_without_api_key_for_openai(self):
settings = MockSettings(llm_enabled=True, llm_api_key="", llm_model="gpt-4o")
service = LLMService(settings)
assert service.is_configured() is False
def test_ollama_configured_without_api_key(self):
settings = MockSettings(
llm_enabled=True, llm_provider="ollama", llm_api_key="", llm_model="llama3"
)
service = LLMService(settings)
assert service.is_configured() is True
def test_resolve_api_base_openai_default(self, llm_service):
assert llm_service._resolve_api_base("openai", "") == "https://api.openai.com/v1"
def test_resolve_api_base_ollama_default(self, llm_service):
assert llm_service._resolve_api_base("ollama", "") == "http://localhost:11434/v1"
def test_resolve_api_base_custom_override(self, llm_service):
assert llm_service._resolve_api_base("custom", "https://my.api.com/v1/") == "https://my.api.com/v1"
def test_ensure_configured_raises_when_disabled(self):
settings = MockSettings(llm_enabled=False)
service = LLMService(settings)
with pytest.raises(LLMNotConfiguredError):
service._ensure_configured()
def test_ensure_configured_raises_without_model(self):
settings = MockSettings(llm_enabled=True, llm_api_key="sk-test", llm_model="")
service = LLMService(settings)
with pytest.raises(LLMNotConfiguredError):
service._ensure_configured()
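# Illustrative sketch -- the configuration rules asserted above.
# These are hypothetical stand-alone helpers, not LLMService's actual code.
# They encode only what TestLLMServiceConfiguration checks. Behaviour for
# cases the tests do not cover (for example a "custom" provider without an
# API key, or an explicit base URL combined with "openai") is an assumption.
_SKETCH_DEFAULT_API_BASES = {
    "openai": "https://api.openai.com/v1",
    "ollama": "http://localhost:11434/v1",
}


def _sketch_resolve_api_base(provider, api_base):
    # An explicit base URL wins; a trailing slash is stripped.
    if api_base:
        return api_base.rstrip("/")
    # Otherwise fall back to the provider default (empty if unknown).
    return _SKETCH_DEFAULT_API_BASES.get(provider, "")


def _sketch_is_configured(provider, api_key, model):
    # A model name is always required.
    if not model:
        return False
    # Ollama runs locally and needs no key; other providers do.
    return provider == "ollama" or bool(api_key)


def test_sketch_matches_asserted_rules():
    assert _sketch_resolve_api_base("openai", "") == "https://api.openai.com/v1"
    assert _sketch_resolve_api_base("ollama", "") == "http://localhost:11434/v1"
    assert _sketch_resolve_api_base("custom", "https://my.api.com/v1/") == "https://my.api.com/v1"
    assert _sketch_is_configured("openai", "sk-test", "gpt-4o") is True
    assert _sketch_is_configured("openai", "sk-test", "") is False
    assert _sketch_is_configured("openai", "", "gpt-4o") is False
    assert _sketch_is_configured("ollama", "", "llama3") is True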
class TestLLMServiceChatCompletion:
@pytest.mark.asyncio
async def test_chat_completion_success(self, llm_service):
mock_response = MockResponse(
200,
json_data={
"choices": [{"message": {"content": "Hello!"}}],
"usage": {"total_tokens": 10},
"model": "gpt-4o-mini",
},
)
mock_session = MockSession(mock_response)
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
result = await llm_service.chat_completion(
messages=[{"role": "user", "content": "Hi"}],
)
assert result["content"] == "Hello!"
assert result["usage"]["total_tokens"] == 10
assert result["model"] == "gpt-4o-mini"
@pytest.mark.asyncio
async def test_chat_completion_raises_on_not_configured(self):
settings = MockSettings(llm_enabled=False)
service = LLMService(settings)
with pytest.raises(LLMNotConfiguredError):
await service.chat_completion(messages=[])
@pytest.mark.asyncio
async def test_chat_completion_raises_on_http_error(self, llm_service):
mock_response = MockResponse(500, text_data="Internal Server Error")
mock_session = MockSession(mock_response)
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
with pytest.raises(LLMResponseError, match="HTTP 500"):
await llm_service.chat_completion(messages=[])
@pytest.mark.asyncio
async def test_chat_completion_raises_on_rate_limit(self, llm_service):
mock_response = MockResponse(429, text_data="Rate limited", headers={"Retry-After": "0"})
mock_session = MockSession(mock_response)
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
with pytest.raises(LLMRateLimitError):
await llm_service.chat_completion(
messages=[], retry_on_rate_limit=False
)
@pytest.mark.asyncio
async def test_chat_completion_raises_on_bad_response_structure(self, llm_service):
mock_response = MockResponse(200, json_data={"unexpected": "data"})
mock_session = MockSession(mock_response)
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
with pytest.raises(LLMResponseError, match="Unexpected LLM response"):
await llm_service.chat_completion(messages=[])
class TestLLMServiceChatCompletionJson:
@pytest.mark.asyncio
async def test_chat_completion_json_parses_json(self, llm_service):
mock_response = MockResponse(
200,
json_data={
"choices": [{"message": {"content": '{"key": "value"}'}}],
"usage": {},
"model": "gpt-4o-mini",
},
)
mock_session = MockSession(mock_response)
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
result = await llm_service.chat_completion_json(
system_prompt="You are helpful.",
user_prompt="Return JSON.",
)
assert result == {"key": "value"}
@pytest.mark.asyncio
async def test_chat_completion_json_raises_on_non_json(self, llm_service):
# First attempt: non-JSON; second attempt (retry): also non-JSON
mock_response = MockResponse(
200,
json_data={
"choices": [{"message": {"content": "not json at all"}}],
"usage": {},
},
)
mock_session = MockSession(mock_response)
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
with pytest.raises(LLMResponseError, match="could not be parsed as JSON"):
await llm_service.chat_completion_json(
system_prompt="test",
user_prompt="test",
)


@@ -0,0 +1,313 @@
"""Tests for the PostProcessor (py/services/agent/post_processor.py).
PostProcessor delegates all I/O to AgentCLI — these tests mock AgentCLI
functions and verify the business logic (conditions, merges, dispatch).
"""
from __future__ import annotations
from datetime import datetime, timezone
from unittest import mock
import pytest
from py.services.agent.post_processor import PostProcessor
@pytest.fixture
def processor():
return PostProcessor()
# ======================================================================
# process() — routing
# ======================================================================
class TestProcessDispatch:
@pytest.mark.asyncio
async def test_unknown_skill_returns_error(self, processor):
result = await processor.process(
skill_name="nonexistent",
model_path="/p.safetensors",
llm_output={},
metadata={},
)
assert result["success"] is False
assert "nonexistent" in result["errors"][0]
@pytest.mark.asyncio
async def test_enrich_hf_metadata_routes_correctly(self, processor):
with (
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
mock.patch("py.agent_cli.download_preview") as mock_dl,
mock.patch("py.agent_cli.refresh_cache") as mock_ref,
):
mock_apply.return_value = ["metadata_source"]
mock_dl.return_value = False
result = await processor.process(
skill_name="enrich_hf_metadata",
model_path="/p.safetensors",
llm_output={},
metadata={"from_civitai": True},
)
assert result["success"] is True
# ======================================================================
# enrich_hf_metadata — field-level logic
# ======================================================================
class TestEnrichHfMetadata:
"""Business logic tests for the enrich_hf_metadata post-processor."""
MIN_LLM_OUTPUT = {
"base_model": "",
"trigger_words": [],
"description": "",
"tags": [],
"preview_url": "",
"confidence": "low",
}
# -- base_model ------------------------------------------------------
@pytest.mark.asyncio
async def test_base_model_overwrites_empty(self, processor):
"""Empty current base_model → new value is applied."""
llm = {**self.MIN_LLM_OUTPUT, "base_model": "Flux.1 D"}
with (
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
mock.patch("py.agent_cli.download_preview", return_value=False),
mock.patch("py.agent_cli.refresh_cache"),
):
await processor.process(
skill_name="enrich_hf_metadata",
model_path="/p.safetensors",
llm_output=llm,
metadata={"base_model": ""},
)
applied = mock_apply.call_args[0][1]
assert applied["base_model"] == "Flux.1 D"
@pytest.mark.asyncio
async def test_base_model_does_not_overwrite_existing_civitai(self, processor):
"""Existing base_model from CivitAI → not overwritten."""
llm = {**self.MIN_LLM_OUTPUT, "base_model": "Flux.1 D"}
with (
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
mock.patch("py.agent_cli.download_preview", return_value=False),
mock.patch("py.agent_cli.refresh_cache"),
):
await processor.process(
skill_name="enrich_hf_metadata",
model_path="/p.safetensors",
llm_output=llm,
metadata={"base_model": "SDXL 1.0", "from_civitai": True},
)
# apply_metadata_updates is still called (for metadata_source and
# llm_enriched_at), but base_model must not be among the updates
applied = mock_apply.call_args[0][1]
assert "base_model" not in applied
@pytest.mark.asyncio
async def test_base_model_overwrites_existing_hf_model(self, processor):
"""Existing base_model from HF → overwritten (LLM is more reliable)."""
llm = {**self.MIN_LLM_OUTPUT, "base_model": "Flux.1 D"}
with (
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
mock.patch("py.agent_cli.download_preview", return_value=False),
mock.patch("py.agent_cli.refresh_cache"),
):
await processor.process(
skill_name="enrich_hf_metadata",
model_path="/p.safetensors",
llm_output=llm,
metadata={"base_model": "SD 1.5", "from_civitai": False},
)
applied = mock_apply.call_args[0][1]
assert applied["base_model"] == "Flux.1 D"
@pytest.mark.asyncio
async def test_base_model_skipped_when_llm_empty(self, processor):
"""LLM returns empty base_model → nothing written."""
with (
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
mock.patch("py.agent_cli.download_preview", return_value=False),
mock.patch("py.agent_cli.refresh_cache"),
):
await processor.process(
skill_name="enrich_hf_metadata",
model_path="/p.safetensors",
llm_output=self.MIN_LLM_OUTPUT,
metadata={"base_model": ""},
)
applied = mock_apply.call_args[0][1]
assert "base_model" not in applied
# -- trigger_words ---------------------------------------------------
@pytest.mark.asyncio
async def test_trigger_words_written_when_empty(self, processor):
"""New trigger words written when current list is empty."""
llm = {**self.MIN_LLM_OUTPUT, "trigger_words": ["trigger1", "trigger2"]}
with (
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
mock.patch("py.agent_cli.download_preview", return_value=False),
mock.patch("py.agent_cli.refresh_cache"),
):
await processor.process(
skill_name="enrich_hf_metadata",
model_path="/p.safetensors",
llm_output=llm,
metadata={"trainedWords": []},
)
applied = mock_apply.call_args[0][1]
assert applied["trainedWords"] == ["trigger1", "trigger2"]
# -- description -----------------------------------------------------
@pytest.mark.asyncio
async def test_description_set_when_empty(self, processor):
llm = {**self.MIN_LLM_OUTPUT, "description": "A model description"}
with (
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
mock.patch("py.agent_cli.download_preview", return_value=False),
mock.patch("py.agent_cli.refresh_cache"),
):
await processor.process(
skill_name="enrich_hf_metadata",
model_path="/p.safetensors",
llm_output=llm,
metadata={"modelDescription": ""},
)
assert "modelDescription" in mock_apply.call_args[0][1]
# -- tags ------------------------------------------------------------
@pytest.mark.asyncio
async def test_tags_merged_and_deduplicated(self, processor):
llm = {**self.MIN_LLM_OUTPUT, "tags": ["flux", "lora", "STYLE"]}
with (
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
mock.patch("py.agent_cli.download_preview", return_value=False),
mock.patch("py.agent_cli.refresh_cache"),
):
await processor.process(
skill_name="enrich_hf_metadata",
model_path="/p.safetensors",
llm_output=llm,
metadata={"tags": ["anime"], "from_civitai": False},
)
merged = mock_apply.call_args[0][1]["tags"]
assert "anime" in merged
assert "flux" in merged
assert "style" in merged # lowercased
# "STYLE" is lowercased to "style"; "lora" is added unchanged
assert len(merged) == 4 # anime, flux, lora, style
# -- metadata_source & llm_enriched_at --------------------------------
@pytest.mark.asyncio
async def test_audit_fields_always_set(self, processor):
with (
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
mock.patch("py.agent_cli.download_preview", return_value=False),
mock.patch("py.agent_cli.refresh_cache"),
):
await processor.process(
skill_name="enrich_hf_metadata",
model_path="/p.safetensors",
llm_output=self.MIN_LLM_OUTPUT,
metadata={},
)
applied = mock_apply.call_args[0][1]
assert applied["metadata_source"] == "agent:enrich_hf_metadata"
assert "llm_enriched_at" in applied
# -- preview download ------------------------------------------------
@pytest.mark.asyncio
async def test_preview_downloaded_when_url_provided(self, processor):
llm = {**self.MIN_LLM_OUTPUT, "preview_url": "https://ex.com/img.png"}
with (
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
mock.patch("py.agent_cli.download_preview") as mock_dl,
mock.patch("py.agent_cli.refresh_cache"),
):
mock_dl.return_value = True
result = await processor.process(
skill_name="enrich_hf_metadata",
model_path="/p.safetensors",
llm_output=llm,
metadata={},
)
assert result["preview_downloaded"] is True
mock_dl.assert_awaited_once_with("/p.safetensors", "https://ex.com/img.png")
@pytest.mark.asyncio
async def test_preview_skipped_when_exists(self, processor):
"""If current_preview file exists on disk, skip download."""
llm = {**self.MIN_LLM_OUTPUT, "preview_url": "https://ex.com/img.png"}
with (
mock.patch("py.agent_cli.apply_metadata_updates"),
mock.patch("py.agent_cli.download_preview") as mock_dl,
mock.patch("py.agent_cli.refresh_cache"),
mock.patch("os.path.exists", return_value=True),
):
await processor.process(
skill_name="enrich_hf_metadata",
model_path="/p.safetensors",
llm_output=llm,
metadata={"preview_url": "/existing/preview.webp"},
)
mock_dl.assert_not_called()
# -- cache refresh ---------------------------------------------------
@pytest.mark.asyncio
async def test_cache_refreshed_when_updates_applied(self, processor):
llm = {**self.MIN_LLM_OUTPUT, "base_model": "Flux.1 D"}
with (
mock.patch("py.agent_cli.apply_metadata_updates", return_value=["base_model"]),
mock.patch("py.agent_cli.download_preview", return_value=False),
mock.patch("py.agent_cli.refresh_cache") as mock_ref,
):
await processor.process(
skill_name="enrich_hf_metadata",
model_path="/p.safetensors",
llm_output=llm,
metadata={"base_model": ""},
)
mock_ref.assert_awaited_once_with("/p.safetensors")

    @pytest.mark.asyncio
    async def test_cache_not_refreshed_when_nothing_changed(self, processor):
        with (
            mock.patch("py.agent_cli.apply_metadata_updates", return_value=[]),
            mock.patch("py.agent_cli.download_preview", return_value=False),
            mock.patch("py.agent_cli.refresh_cache") as mock_ref,
        ):
            await processor.process(
                skill_name="enrich_hf_metadata",
                model_path="/p.safetensors",
                llm_output=self.MIN_LLM_OUTPUT,
                metadata={"base_model": ""},
            )
        mock_ref.assert_not_called()


# ======================================================================
# Unit: _merge_tags
# ======================================================================


class TestMergeTags:
    def test_deduplicates_case_insensitive(self):
        existing = ["anime", "Flux"]
        new = ["flux", "LORA", "anime"]
        result = PostProcessor._merge_tags(existing, new)
        # All tags are lowercased (matching TagUpdateService behaviour)
        assert result == ["anime", "flux", "lora"]
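
The assertion above pins down three properties of `_merge_tags`: every tag is lowercased, duplicates are dropped regardless of case, and first-seen order is kept. A minimal sketch of a function with that behaviour follows. It is not the project's implementation, and the name `merge_tags_sketch` is hypothetical; it only illustrates one way to satisfy the test.

```python
from typing import Iterable, List


def merge_tags_sketch(existing: Iterable[str], new: Iterable[str]) -> List[str]:
    """Hypothetical stand-in for PostProcessor._merge_tags (not the real code).

    Lowercases every tag, drops case-insensitive duplicates, and keeps the
    order in which each tag is first seen (existing tags before new ones).
    """
    merged: List[str] = []
    seen = set()
    for tag in [*existing, *new]:
        normalized = tag.lower()
        if normalized not in seen:
            seen.add(normalized)
            merged.append(normalized)
    return merged
```

Tracking seen tags in a set keeps the merge linear in the number of tags, while the separate list preserves ordering, which the test compares with `==` on a list.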


@@ -0,0 +1,91 @@
"""Tests for the SkillRegistry."""

from __future__ import annotations

from pathlib import Path

import pytest

from py.services.agent.skill_registry import SkillRegistry
from py.services.agent.skill_definition import SkillDefinition, SkillPermissions


@pytest.fixture
def registry():
    """Create a SkillRegistry with the real skills directory."""
    SkillRegistry.reset_instance()
    reg = SkillRegistry()
    reg._discover()
    return reg


class TestSkillRegistryDiscovery:
    def test_discovers_enrich_hf_metadata_skill(self, registry):
        skills = registry.list_skills()
        assert len(skills) >= 1
        skill = registry.get_skill("enrich_hf_metadata")
        assert skill is not None
        assert skill.name == "enrich_hf_metadata"
        assert skill.llm_required is True

    def test_skill_has_correct_model_type_filter(self, registry):
        skill = registry.get_skill("enrich_hf_metadata")
        assert skill.model_type_filter == ["lora", "checkpoint", "embedding"]

    def test_skill_has_permissions(self, registry):
        skill = registry.get_skill("enrich_hf_metadata")
        assert skill.permissions.write_metadata is True
        assert skill.permissions.write_previews is True
        assert "huggingface.co" in skill.permissions.network_domains

    def test_get_skill_returns_none_for_unknown(self, registry):
        assert registry.get_skill("nonexistent_skill") is None


class TestSkillRegistryLoading:
    def test_load_prompt_returns_content(self, registry):
        prompt = registry.load_prompt("enrich_hf_metadata")
        assert isinstance(prompt, str)
        assert len(prompt) > 100
        assert "base_model" in prompt
        assert "trigger_words" in prompt

    def test_load_prompt_raises_for_unknown_skill(self, registry):
        with pytest.raises(FileNotFoundError):
            registry.load_prompt("nonexistent")

    def test_load_handler_raises_when_handler_missing(self, registry):
        with pytest.raises(FileNotFoundError):
            registry.load_handler("enrich_hf_metadata")


class TestSkillDefinition:
    def test_applies_to_model_type_with_filter(self):
        sd = SkillDefinition(
            name="test",
            title="Test",
            description="",
            llm_required=False,
            model_type_filter=["lora"],
        )
        assert sd.applies_to_model_type("lora") is True
        assert sd.applies_to_model_type("checkpoint") is False

    def test_applies_to_model_type_without_filter(self):
        sd = SkillDefinition(
            name="test",
            title="Test",
            description="",
            llm_required=False,
            model_type_filter=None,
        )
        assert sd.applies_to_model_type("lora") is True
        assert sd.applies_to_model_type("checkpoint") is True


class TestSkillPermissions:
    def test_defaults(self):
        sp = SkillPermissions()
        assert sp.write_metadata is True
        assert sp.write_previews is True
        assert sp.network_domains == ()
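
The `TestSkillDefinition` and `TestSkillPermissions` cases imply a small data model: a `None` filter matches every model type, a list filter matches only its members, and permissions default to both write flags enabled with an empty tuple of network domains. A minimal sketch of dataclasses with that behaviour might look as follows. The class names `SkillPermissionsSketch` and `SkillDefinitionSketch` are hypothetical, and the real `skill_definition.py` may be structured differently (for instance, it may carry more fields).

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass(frozen=True)
class SkillPermissionsSketch:
    """Hypothetical mirror of SkillPermissions, with the defaults the tests expect."""

    write_metadata: bool = True
    write_previews: bool = True
    network_domains: Tuple[str, ...] = ()


@dataclass
class SkillDefinitionSketch:
    """Hypothetical mirror of SkillDefinition (only the fields the tests use)."""

    name: str
    title: str
    description: str
    llm_required: bool
    model_type_filter: Optional[List[str]] = None
    permissions: SkillPermissionsSketch = field(default_factory=SkillPermissionsSketch)

    def applies_to_model_type(self, model_type: str) -> bool:
        # No filter means the skill applies to every model type.
        if self.model_type_filter is None:
            return True
        return model_type in self.model_type_filter
```

Using an immutable tuple for `network_domains` avoids the shared-mutable-default pitfall and matches the `== ()` comparison in `test_defaults`.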