mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-07-03 07:51:16 -03:00
Compare commits
14 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
63785f82b5 | ||
|
|
cf898da193 | ||
|
|
fe90f7f9b1 | ||
|
|
8b344ea39f | ||
|
|
8348a0cef8 | ||
|
|
7cf785b72f | ||
|
|
e8913f4481 | ||
|
|
f9c3d8dc97 | ||
|
|
09ca91fc0e | ||
|
|
16f5222efd | ||
|
|
28e7c04b37 | ||
|
|
28f99c46d3 | ||
|
|
205194f4e6 | ||
|
|
402d8b07cf |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -7,6 +7,10 @@ py/run_test.py
|
||||
.vscode/
|
||||
cache/
|
||||
civitai/
|
||||
stats/
|
||||
wildcards/
|
||||
backups/
|
||||
logs/
|
||||
node_modules/
|
||||
coverage/
|
||||
.coverage
|
||||
|
||||
208
docs/agent_skills.md
Normal file
208
docs/agent_skills.md
Normal file
@@ -0,0 +1,208 @@
|
||||
# Agent Skills System
|
||||
|
||||
The LoRA Manager agent skills system enables LLM-powered metadata enrichment and other AI-driven tasks. Users configure their own LLM provider (BYOK), and skills are executed through right-click context menu actions.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────┐
|
||||
│ LoRA Manager Backend │
|
||||
│ │
|
||||
│ ┌──────────────┐ ┌────────────────┐ │
|
||||
│ │ LLMService │───▶│ LLM Provider │ │
|
||||
│ │ (BYOK config, │◀───│ (OpenAI/Ollama │ │
|
||||
│ │ API calls) │ │ /custom) │ │
|
||||
│ └───────┬───────┘ └────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌───────▼───────────────────────┐ │
|
||||
│ │ AgentService │ │
|
||||
│ │ (orchestration: validate │ │
|
||||
│ │ → LLM call → post-process │ │
|
||||
│ │ → WebSocket broadcast) │ │
|
||||
│ └───────┬───────────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌───────▼───────────────────────┐ │
|
||||
│ │ SkillRegistry │ │
|
||||
│ │ ┌─────────────────────────┐ │ │
|
||||
│ │ │ enrich_hf_metadata: │ │ │
|
||||
│ │ │ - skill.yaml │ │ │
|
||||
│ │ │ - prompt.md │ │ │
|
||||
│ │ │ - handler.py │ │ │
|
||||
│ │ └─────────────────────────┘ │ │
|
||||
│ └───────────────────────────────┘ │
|
||||
└──────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### Key Design Principle
|
||||
|
||||
**Skills define *what* to do (prompt + post-processing). The AgentService handles *how* (LLM calls, validation, progress).**
|
||||
|
||||
Skills never call the LLM directly. This keeps BYOK configuration centralized and provider-agnostic.
|
||||
|
||||
## BYOK Configuration
|
||||
|
||||
Users configure their LLM provider in **Settings → AI Provider**:
|
||||
|
||||
| Setting | Description | Example |
|
||||
|---|---|---|
|
||||
| `llm_provider` | Provider type | `openai`, `ollama`, or `custom` |
|
||||
| `llm_api_key` | API key (not needed for local Ollama) | `sk-...` |
|
||||
| `llm_api_base` | Custom API base URL (empty = provider default) | `https://api.openai.com/v1` |
|
||||
| `llm_model` | Model name | `gpt-4o-mini` |
|
||||
|
||||
Environment variable overrides: `LLM_API_KEY`, `LLM_MODEL`, `LLM_API_BASE`, `LLM_PROVIDER`.
|
||||
|
||||
### Supported Providers
|
||||
|
||||
- **OpenAI**: Uses `https://api.openai.com/v1` by default
|
||||
- **Ollama** (local): Uses `http://localhost:11434/v1`, no API key required
|
||||
- **Custom**: Any OpenAI-compatible endpoint (vLLM, LM Studio, etc.) — set `llm_api_base` explicitly
|
||||
|
||||
## Available Skills
|
||||
|
||||
### enrich_hf_metadata
|
||||
|
||||
Enriches HuggingFace-downloaded models with metadata extracted by an LLM from the HF model card.
|
||||
|
||||
**Entry point**: Right-click context menu → "Enrich Metadata (Agent)"
|
||||
|
||||
**What it does**:
|
||||
1. Reads the model's `.metadata.json` to get the `hf_url`
|
||||
2. Fetches the README.md from the HuggingFace repository
|
||||
3. Sends the README + local metadata to the LLM for structured extraction
|
||||
4. Writes extracted fields to `.metadata.json`:
|
||||
- `base_model` — only if current value is empty
|
||||
- `trainedWords` — trigger words (LoRA only, if none exist)
|
||||
- `modelDescription` — concise summary (if none exists)
|
||||
- `tags` — merged with existing tags, deduplicated
|
||||
- `metadata_source` — audit trail: `agent:enrich_hf_metadata`
|
||||
- `llm_enriched_at` — ISO timestamp
|
||||
5. Downloads and optimizes preview image (if LLM found one in the README)
|
||||
6. Updates the scanner cache
|
||||
7. Broadcasts WebSocket progress events
|
||||
|
||||
**Model types**: LoRA, Checkpoint, Embedding
|
||||
|
||||
## Adding a New Skill
|
||||
|
||||
### 1. Create the skill directory
|
||||
|
||||
```
|
||||
py/services/agent/skills/<skill_name>/
|
||||
├── skill.yaml # Skill metadata and schemas
|
||||
├── prompt.md # LLM prompt template
|
||||
└── handler.py # Pre-processing and post-processing
|
||||
```
|
||||
|
||||
### 2. Write skill.yaml
|
||||
|
||||
```yaml
|
||||
name: my_skill
|
||||
title: "My Skill"
|
||||
description: "What this skill does"
|
||||
llm_required: true
|
||||
model_type_filter: ["lora"] # or null for all types
|
||||
input_schema:
|
||||
type: object
|
||||
properties:
|
||||
model_paths:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
required:
|
||||
- model_paths
|
||||
output_schema:
|
||||
type: object
|
||||
properties:
|
||||
# ... JSON schema for LLM output
|
||||
permissions:
|
||||
write_metadata: true
|
||||
write_previews: false
|
||||
network_domains:
|
||||
- "example.com"
|
||||
```
|
||||
|
||||
### 3. Write prompt.md
|
||||
|
||||
Use `{{variable}}` placeholders that will be replaced with data from the `prepare` function:
|
||||
|
||||
```markdown
|
||||
You are an expert assistant...
|
||||
|
||||
Model URL: {{hf_url}}
|
||||
README content:
|
||||
{{readme_content}}
|
||||
|
||||
Current metadata:
|
||||
{{current_metadata}}
|
||||
```
|
||||
|
||||
### 4. Write handler.py
|
||||
|
||||
```python
|
||||
async def prepare(model_path: str, input_data: dict) -> dict:
|
||||
"""Gather context for the LLM prompt. Returns variables for template rendering."""
|
||||
return {
|
||||
"model_path": model_path,
|
||||
# ... other variables used in prompt.md
|
||||
}
|
||||
|
||||
async def post_process(context) -> dict:
|
||||
"""Apply the LLM-extracted data to the model."""
|
||||
llm_response = context.llm_response
|
||||
# ... write metadata, download previews, update cache
|
||||
return {
|
||||
"success": True,
|
||||
"updated_fields": ["base_model", "tags"],
|
||||
"errors": [],
|
||||
}
|
||||
```
|
||||
|
||||
**Important**: Use absolute imports (`from py.utils.metadata_manager import MetadataManager`) because skills are loaded via `importlib.util.spec_from_file_location`, which doesn't support relative imports.
|
||||
|
||||
### 5. Test
|
||||
|
||||
The skill is automatically discovered by `SkillRegistry` on startup. Test with:
|
||||
|
||||
```python
|
||||
pytest tests/services/test_agent_service.py
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
| Method | Path | Description |
|
||||
|---|---|---|
|
||||
| GET | `/api/lm/agent/skills` | List available skills |
|
||||
| POST | `/api/lm/agent/execute/{skill_name}` | Execute a skill (body: `{"model_paths": [...]}`) |
|
||||
| POST | `/api/lm/agent/cancel` | Cancel running skill (stub) |
|
||||
|
||||
## WebSocket Events
|
||||
|
||||
| Type | When | Key fields |
|
||||
|---|---|---|
|
||||
| `agent_progress` | Skill started/processing | `skill`, `status`, `total`, `processed`, `success`, `current_path` |
|
||||
| `agent_progress` | Skill completed | `skill`, `status`, `updated_models`, `errors`, `summary` |
|
||||
| `agent_progress` | Skill error | `skill`, `status`, `error` |
|
||||
|
||||
## Security Model
|
||||
|
||||
Skills declare permissions in `skill.yaml`:
|
||||
- `write_metadata` — can write `.metadata.json` files
|
||||
- `write_previews` — can download/replace preview images
|
||||
- `network_domains` — allowed domains for HTTP requests
|
||||
|
||||
These are declarative constraints checked by `AgentService`. They are defense-in-depth, not a sandbox — the Python process can technically do anything, but the contract is clear and auditable.
|
||||
|
||||
## File Locations
|
||||
|
||||
| Component | Path |
|
||||
|---|---|
|
||||
| LLMService | `py/services/llm_service.py` |
|
||||
| AgentService | `py/services/agent/agent_service.py` |
|
||||
| SkillRegistry | `py/services/agent/skill_registry.py` |
|
||||
| SkillDefinition | `py/services/agent/skill_definition.py` |
|
||||
| Skills directory | `py/services/agent/skills/` |
|
||||
| Route handlers | `py/routes/handlers/agent_handlers.py` |
|
||||
| Frontend manager | `static/js/managers/AgentManager.js` |
|
||||
| Settings UI | `templates/components/modals/settings_modal.html` |
|
||||
| Context menu | `templates/components/context_menu.html` |
|
||||
@@ -105,6 +105,7 @@
|
||||
"removeFromFavorites": "Aus Favoriten entfernen",
|
||||
"viewOnCivitai": "Auf Civitai anzeigen",
|
||||
"notAvailableFromCivitai": "Nicht auf Civitai verfügbar",
|
||||
"viewOnHuggingFace": "Auf Hugging Face ansehen",
|
||||
"sendToWorkflow": "An ComfyUI senden (Klick: Anhängen, Shift+Klick: Ersetzen)",
|
||||
"copyLoRASyntax": "LoRA-Syntax kopieren",
|
||||
"checkpointNameCopied": "Checkpoint-Name kopiert",
|
||||
@@ -656,6 +657,23 @@
|
||||
"proxyPassword": "Passwort (optional)",
|
||||
"proxyPasswordPlaceholder": "passwort",
|
||||
"proxyPasswordHelp": "Passwort für die Proxy-Authentifizierung (falls erforderlich)"
|
||||
},
|
||||
"aiProvider": {
|
||||
"title": "KI-Anbieter",
|
||||
"provider": "Anbieter",
|
||||
"providerHelp": "Wählen Sie Ihren LLM-Anbieter. OpenAI und Ollama verwenden voreingestellte API-Endpunkte. Mit \"Benutzerdefiniert\" können Sie jeden OpenAI-kompatiblen Endpunkt angeben.",
|
||||
"custom": "Benutzerdefiniert (OpenAI-kompatibel)",
|
||||
"apiBase": "API-Basis-URL",
|
||||
"apiBaseHelp": "Die Basis-URL für die LLM-API (z.B. https://api.openai.com/v1). Leer lassen, um die Anbietervoreinstellung zu verwenden.",
|
||||
"apiBasePlaceholder": "https://api.openai.com/v1",
|
||||
"apiKey": "API-Schlüssel",
|
||||
"apiKeyHelp": "Ihr LLM-API-Schlüssel. Wird lokal gespeichert und niemals an einen anderen Server außer Ihrem gewählten LLM-Anbieter gesendet.",
|
||||
"apiKeyPlaceholder": "sk-...",
|
||||
"apiKeyNotSet": "Nicht festgelegt",
|
||||
"apiKeyConfigured": "Konfiguriert",
|
||||
"apiKeySet": "Einrichten",
|
||||
"model": "Modell",
|
||||
"modelHelp": "Der zu verwendende Modellname (z.B. deepseek-v4-flash, gemini-2.5-flash, gemma4:12b). Prüfen Sie Ihren Anbieter auf verfügbare Modelle."
|
||||
}
|
||||
},
|
||||
"loras": {
|
||||
@@ -753,7 +771,8 @@
|
||||
"completed": "Abgeschlossen: {success} verschoben, {skipped} übersprungen, {failures} fehlgeschlagen",
|
||||
"complete": "Automatische Organisation abgeschlossen",
|
||||
"error": "Fehler: {error}"
|
||||
}
|
||||
},
|
||||
"enrichHfAgent": "Metadaten mit KI anreichern"
|
||||
},
|
||||
"contextMenu": {
|
||||
"refreshMetadata": "Civitai-Daten aktualisieren",
|
||||
@@ -777,7 +796,8 @@
|
||||
"shareRecipe": "Rezept teilen",
|
||||
"viewAllLoras": "Alle LoRAs anzeigen",
|
||||
"downloadMissingLoras": "Fehlende LoRAs herunterladen",
|
||||
"deleteRecipe": "Rezept löschen"
|
||||
"deleteRecipe": "Rezept löschen",
|
||||
"enrichHfAgent": "Metadaten mit KI anreichern"
|
||||
}
|
||||
},
|
||||
"recipes": {
|
||||
@@ -1134,7 +1154,10 @@
|
||||
"titleWithType": "{type} von URL herunterladen",
|
||||
"civitaiUrl": "Civitai URL:",
|
||||
"placeholder": "https://civitai.com/models/...",
|
||||
"urlHint": "Geben Sie eine CivitAI- oder CivArchive-URL pro Zeile ein. Unterstützt mehrere URLs für den Batch-Download.",
|
||||
"urlHint": "Geben Sie eine CivitAI-, CivArchive- oder Hugging Face-URL pro Zeile ein. Unterstützt mehrere URLs für den Batch-Download.",
|
||||
"selectHfFiles": "Datei(en) zum Herunterladen aus diesem Repository auswählen:",
|
||||
"selectAll": "Alle auswählen",
|
||||
"fetchingRepoFiles": "Repository-Dateien werden abgerufen...",
|
||||
"locationPreview": "Download-Speicherort Vorschau",
|
||||
"useDefaultPath": "Standardpfad verwenden",
|
||||
"useDefaultPathTooltip": "Wenn aktiviert, werden Dateien automatisch mit konfigurierten Pfadvorlagen organisiert",
|
||||
@@ -1163,7 +1186,9 @@
|
||||
},
|
||||
"errors": {
|
||||
"invalidUrl": "Ungültiges Civitai URL-Format",
|
||||
"noVersions": "Keine Versionen für dieses Modell verfügbar"
|
||||
"noVersions": "Keine Versionen für dieses Modell verfügbar",
|
||||
"mixedSources": "CivitAI- und Hugging Face-URLs können nicht in derselben Charge gemischt werden.",
|
||||
"noModelFiles": "In diesem Repository wurden keine Modelldateien gefunden."
|
||||
},
|
||||
"status": {
|
||||
"preparing": "Download wird vorbereitet...",
|
||||
@@ -1314,6 +1339,8 @@
|
||||
"editVersionName": "Versionsname bearbeiten",
|
||||
"viewOnCivitai": "Auf Civitai anzeigen",
|
||||
"viewOnCivitaiText": "Auf Civitai anzeigen",
|
||||
"viewOnHuggingFace": "Auf Hugging Face ansehen",
|
||||
"viewOnHuggingFaceText": "Auf Hugging Face ansehen",
|
||||
"viewCreatorProfile": "Ersteller-Profil anzeigen",
|
||||
"openFileLocation": "Dateispeicherort öffnen",
|
||||
"sendToWorkflow": "An ComfyUI senden",
|
||||
@@ -1339,7 +1366,10 @@
|
||||
"additionalNotes": "Zusätzliche Notizen",
|
||||
"notesHint": "Enter zum Speichern, Shift+Enter für neue Zeile",
|
||||
"addNotesPlaceholder": "Fügen Sie hier Ihre Notizen hinzu...",
|
||||
"aboutThisVersion": "Über diese Version"
|
||||
"aboutThisVersion": "Über diese Version",
|
||||
"baseModelSearchPlaceholder": "Basismodell suchen…",
|
||||
"baseModelSuggested": "Vorschlag",
|
||||
"baseModelNoMatch": "Keine passenden Basismodelle"
|
||||
},
|
||||
"notes": {
|
||||
"saved": "Notizen erfolgreich gespeichert",
|
||||
@@ -2070,6 +2100,12 @@
|
||||
"moveFailed": "Failed to move item: {message}",
|
||||
"copiedToClipboard": "In die Zwischenablage kopiert",
|
||||
"downloadStarted": "Download gestartet"
|
||||
},
|
||||
"agent": {
|
||||
"llmNotConfigured": "KI-Anbieter nicht konfiguriert. Aktivieren Sie ihn unter Einstellungen → KI-Anbieter.",
|
||||
"enrichStarted": "Metadaten werden mit KI angereichert...",
|
||||
"enrichComplete": "Metadatenanreicherung abgeschlossen: {{summary}}",
|
||||
"enrichFailed": "Metadatenanreicherung fehlgeschlagen: {{error}}"
|
||||
}
|
||||
},
|
||||
"doctor": {
|
||||
|
||||
4302
locales/en.json
4302
locales/en.json
File diff suppressed because it is too large
Load Diff
@@ -105,6 +105,7 @@
|
||||
"removeFromFavorites": "Eliminar de favoritos",
|
||||
"viewOnCivitai": "Ver en Civitai",
|
||||
"notAvailableFromCivitai": "No disponible en Civitai",
|
||||
"viewOnHuggingFace": "Ver en Hugging Face",
|
||||
"sendToWorkflow": "Enviar a ComfyUI (Clic: Añadir, Shift+Clic: Reemplazar)",
|
||||
"copyLoRASyntax": "Copiar sintaxis de LoRA",
|
||||
"checkpointNameCopied": "Nombre del checkpoint copiado",
|
||||
@@ -656,6 +657,23 @@
|
||||
"proxyPassword": "Contraseña (opcional)",
|
||||
"proxyPasswordPlaceholder": "contraseña",
|
||||
"proxyPasswordHelp": "Contraseña para autenticación de proxy (si es necesario)"
|
||||
},
|
||||
"aiProvider": {
|
||||
"title": "Proveedor de IA",
|
||||
"provider": "Proveedor",
|
||||
"providerHelp": "Elija su proveedor de LLM. OpenAI y Ollama usan endpoints predefinidos. Personalizado le permite especificar cualquier endpoint compatible con OpenAI.",
|
||||
"custom": "Personalizado (compatible con OpenAI)",
|
||||
"apiBase": "URL base de la API",
|
||||
"apiBaseHelp": "La URL base para la API LLM (p.ej. https://api.openai.com/v1). Déjelo vacío para usar el valor predeterminado del proveedor.",
|
||||
"apiBasePlaceholder": "https://api.openai.com/v1",
|
||||
"apiKey": "Clave de API",
|
||||
"apiKeyHelp": "Su clave de API del proveedor LLM. Se almacena localmente y nunca se envía a ningún servidor excepto a su proveedor LLM elegido.",
|
||||
"apiKeyPlaceholder": "sk-...",
|
||||
"apiKeyNotSet": "No configurada",
|
||||
"apiKeyConfigured": "Configurada",
|
||||
"apiKeySet": "Configurar",
|
||||
"model": "Modelo",
|
||||
"modelHelp": "El nombre del modelo a usar (p.ej. deepseek-v4-flash, gemini-2.5-flash, gemma4:12b). Consulte a su proveedor para ver los modelos disponibles."
|
||||
}
|
||||
},
|
||||
"loras": {
|
||||
@@ -753,7 +771,8 @@
|
||||
"completed": "Completado: {success} movidos, {skipped} omitidos, {failures} fallidos",
|
||||
"complete": "Auto-organización completada",
|
||||
"error": "Error: {error}"
|
||||
}
|
||||
},
|
||||
"enrichHfAgent": "Enriquecer metadatos (IA)"
|
||||
},
|
||||
"contextMenu": {
|
||||
"refreshMetadata": "Actualizar datos de Civitai",
|
||||
@@ -777,7 +796,8 @@
|
||||
"shareRecipe": "Compartir receta",
|
||||
"viewAllLoras": "Ver todos los LoRAs",
|
||||
"downloadMissingLoras": "Descargar LoRAs faltantes",
|
||||
"deleteRecipe": "Eliminar receta"
|
||||
"deleteRecipe": "Eliminar receta",
|
||||
"enrichHfAgent": "Enriquecer metadatos (IA)"
|
||||
}
|
||||
},
|
||||
"recipes": {
|
||||
@@ -1134,7 +1154,10 @@
|
||||
"titleWithType": "Descargar {type} desde URL",
|
||||
"civitaiUrl": "URL de Civitai:",
|
||||
"placeholder": "https://civitai.com/models/...",
|
||||
"urlHint": "Ingrese una URL de CivitAI o CivArchive por línea. Admite múltiples URLs para descarga por lotes.",
|
||||
"urlHint": "Ingrese una URL de CivitAI, CivArchive o Hugging Face por línea. Admite múltiples URLs para descarga por lotes.",
|
||||
"selectHfFiles": "Seleccione el/los archivo(s) para descargar de este repositorio:",
|
||||
"selectAll": "Seleccionar todo",
|
||||
"fetchingRepoFiles": "Obteniendo archivos del repositorio...",
|
||||
"locationPreview": "Vista previa de ubicación de descarga",
|
||||
"useDefaultPath": "Usar ruta predeterminada",
|
||||
"useDefaultPathTooltip": "Cuando está habilitado, los archivos se organizan automáticamente usando plantillas de rutas configuradas",
|
||||
@@ -1163,7 +1186,9 @@
|
||||
},
|
||||
"errors": {
|
||||
"invalidUrl": "Formato de URL de Civitai inválido",
|
||||
"noVersions": "No hay versiones disponibles para este modelo"
|
||||
"noVersions": "No hay versiones disponibles para este modelo",
|
||||
"mixedSources": "No se pueden mezclar URL de CivitAI y Hugging Face en el mismo lote.",
|
||||
"noModelFiles": "No se encontraron archivos de modelo en este repositorio."
|
||||
},
|
||||
"status": {
|
||||
"preparing": "Preparando descarga...",
|
||||
@@ -1314,6 +1339,8 @@
|
||||
"editVersionName": "Editar nombre de versión",
|
||||
"viewOnCivitai": "Ver en Civitai",
|
||||
"viewOnCivitaiText": "Ver en Civitai",
|
||||
"viewOnHuggingFace": "Ver en Hugging Face",
|
||||
"viewOnHuggingFaceText": "Ver en Hugging Face",
|
||||
"viewCreatorProfile": "Ver perfil del creador",
|
||||
"openFileLocation": "Abrir ubicación del archivo",
|
||||
"sendToWorkflow": "Enviar a ComfyUI",
|
||||
@@ -1339,7 +1366,10 @@
|
||||
"additionalNotes": "Notas adicionales",
|
||||
"notesHint": "Presiona Enter para guardar, Shift+Enter para nueva línea",
|
||||
"addNotesPlaceholder": "Añade tus notas aquí...",
|
||||
"aboutThisVersion": "Acerca de esta versión"
|
||||
"aboutThisVersion": "Acerca de esta versión",
|
||||
"baseModelSearchPlaceholder": "Buscar modelo base…",
|
||||
"baseModelSuggested": "Sugerido",
|
||||
"baseModelNoMatch": "No hay modelos base que coincidan"
|
||||
},
|
||||
"notes": {
|
||||
"saved": "Notas guardadas exitosamente",
|
||||
@@ -2070,6 +2100,12 @@
|
||||
"moveFailed": "Failed to move item: {message}",
|
||||
"copiedToClipboard": "Copiado al portapapeles",
|
||||
"downloadStarted": "Descarga iniciada"
|
||||
},
|
||||
"agent": {
|
||||
"llmNotConfigured": "Proveedor de IA no configurado. Actívelo en Configuración → Proveedor de IA.",
|
||||
"enrichStarted": "Enriqueciendo metadatos con IA...",
|
||||
"enrichComplete": "Enriquecimiento de metadatos completado: {{summary}}",
|
||||
"enrichFailed": "Enriquecimiento de metadatos fallido: {{error}}"
|
||||
}
|
||||
},
|
||||
"doctor": {
|
||||
|
||||
@@ -105,6 +105,7 @@
|
||||
"removeFromFavorites": "Retirer des favoris",
|
||||
"viewOnCivitai": "Voir sur Civitai",
|
||||
"notAvailableFromCivitai": "Non disponible sur Civitai",
|
||||
"viewOnHuggingFace": "Voir sur Hugging Face",
|
||||
"sendToWorkflow": "Envoyer vers ComfyUI (Clic: Ajouter, Maj+Clic: Remplacer)",
|
||||
"copyLoRASyntax": "Copier la syntaxe LoRA",
|
||||
"checkpointNameCopied": "Nom du checkpoint copié",
|
||||
@@ -656,6 +657,23 @@
|
||||
"proxyPassword": "Mot de passe (optionnel)",
|
||||
"proxyPasswordPlaceholder": "mot_de_passe",
|
||||
"proxyPasswordHelp": "Mot de passe pour l'authentification proxy (si nécessaire)"
|
||||
},
|
||||
"aiProvider": {
|
||||
"title": "Fournisseur d'IA",
|
||||
"provider": "Fournisseur",
|
||||
"providerHelp": "Choisissez votre fournisseur LLM. OpenAI et Ollama utilisent des endpoints prédéfinis. Personnalisé vous permet de spécifier n'importe quel endpoint compatible OpenAI.",
|
||||
"custom": "Personnalisé (compatible OpenAI)",
|
||||
"apiBase": "URL de base de l'API",
|
||||
"apiBaseHelp": "L'URL de base pour l'API LLM (ex. https://api.openai.com/v1). Laissez vide pour utiliser le fournisseur par défaut.",
|
||||
"apiBasePlaceholder": "https://api.openai.com/v1",
|
||||
"apiKey": "Clé API",
|
||||
"apiKeyHelp": "Votre clé API du fournisseur LLM. Stockée localement, jamais envoyée à un serveur autre que votre fournisseur LLM choisi.",
|
||||
"apiKeyPlaceholder": "sk-...",
|
||||
"apiKeyNotSet": "Non définie",
|
||||
"apiKeyConfigured": "Configurée",
|
||||
"apiKeySet": "Configurer",
|
||||
"model": "Modèle",
|
||||
"modelHelp": "Le nom du modèle à utiliser (ex. deepseek-v4-flash, gemini-2.5-flash, gemma4:12b). Consultez votre fournisseur pour les modèles disponibles."
|
||||
}
|
||||
},
|
||||
"loras": {
|
||||
@@ -753,7 +771,8 @@
|
||||
"completed": "Terminé : {success} déplacés, {skipped} ignorés, {failures} échecs",
|
||||
"complete": "Auto-organisation terminée",
|
||||
"error": "Erreur : {error}"
|
||||
}
|
||||
},
|
||||
"enrichHfAgent": "Enrichir les métadonnées (IA)"
|
||||
},
|
||||
"contextMenu": {
|
||||
"refreshMetadata": "Actualiser les données Civitai",
|
||||
@@ -777,7 +796,8 @@
|
||||
"shareRecipe": "Partager la recipe",
|
||||
"viewAllLoras": "Voir tous les LoRAs",
|
||||
"downloadMissingLoras": "Télécharger les LoRAs manquants",
|
||||
"deleteRecipe": "Supprimer la recipe"
|
||||
"deleteRecipe": "Supprimer la recipe",
|
||||
"enrichHfAgent": "Enrichir les métadonnées (IA)"
|
||||
}
|
||||
},
|
||||
"recipes": {
|
||||
@@ -1134,7 +1154,10 @@
|
||||
"titleWithType": "Télécharger {type} depuis une URL",
|
||||
"civitaiUrl": "URL Civitai :",
|
||||
"placeholder": "https://civitai.com/models/...",
|
||||
"urlHint": "Entrez une URL CivitAI ou CivArchive par ligne. Prend en charge plusieurs URLs pour le téléchargement par lot.",
|
||||
"urlHint": "Entrez une URL CivitAI, CivArchive ou Hugging Face par ligne. Prend en charge plusieurs URL pour le téléchargement par lot.",
|
||||
"selectHfFiles": "Sélectionnez le(s) fichier(s) à télécharger depuis ce dépôt :",
|
||||
"selectAll": "Tout sélectionner",
|
||||
"fetchingRepoFiles": "Récupération des fichiers du dépôt...",
|
||||
"locationPreview": "Aperçu de l'emplacement de téléchargement",
|
||||
"useDefaultPath": "Utiliser le chemin par défaut",
|
||||
"useDefaultPathTooltip": "Lorsque activé, les fichiers sont automatiquement organisés selon les modèles de chemin configurés",
|
||||
@@ -1163,7 +1186,9 @@
|
||||
},
|
||||
"errors": {
|
||||
"invalidUrl": "Format d'URL Civitai invalide",
|
||||
"noVersions": "Aucune version disponible pour ce modèle"
|
||||
"noVersions": "Aucune version disponible pour ce modèle",
|
||||
"mixedSources": "Impossible de mélanger les URL CivitAI et Hugging Face dans le même lot.",
|
||||
"noModelFiles": "Aucun fichier de modèle trouvé dans ce dépôt."
|
||||
},
|
||||
"status": {
|
||||
"preparing": "Préparation du téléchargement...",
|
||||
@@ -1314,6 +1339,8 @@
|
||||
"editVersionName": "Modifier le nom de la version",
|
||||
"viewOnCivitai": "Voir sur Civitai",
|
||||
"viewOnCivitaiText": "Voir sur Civitai",
|
||||
"viewOnHuggingFace": "Voir sur Hugging Face",
|
||||
"viewOnHuggingFaceText": "Voir sur Hugging Face",
|
||||
"viewCreatorProfile": "Voir le profil du créateur",
|
||||
"openFileLocation": "Ouvrir l'emplacement du fichier",
|
||||
"sendToWorkflow": "Envoyer vers ComfyUI",
|
||||
@@ -1339,7 +1366,10 @@
|
||||
"additionalNotes": "Notes supplémentaires",
|
||||
"notesHint": "Appuyez sur Entrée pour sauvegarder, Maj+Entrée pour nouvelle ligne",
|
||||
"addNotesPlaceholder": "Ajoutez vos notes ici...",
|
||||
"aboutThisVersion": "À propos de cette version"
|
||||
"aboutThisVersion": "À propos de cette version",
|
||||
"baseModelSearchPlaceholder": "Rechercher un modèle de base…",
|
||||
"baseModelSuggested": "Suggéré",
|
||||
"baseModelNoMatch": "Aucun modèle de base correspondant"
|
||||
},
|
||||
"notes": {
|
||||
"saved": "Notes sauvegardées avec succès",
|
||||
@@ -2070,6 +2100,12 @@
|
||||
"moveFailed": "Failed to move item: {message}",
|
||||
"copiedToClipboard": "Copié dans le presse-papiers",
|
||||
"downloadStarted": "Téléchargement démarré"
|
||||
},
|
||||
"agent": {
|
||||
"llmNotConfigured": "Fournisseur d'IA non configuré. Activez-le dans Paramètres → Fournisseur d'IA.",
|
||||
"enrichStarted": "Enrichissement des métadonnées par IA...",
|
||||
"enrichComplete": "Enrichissement des métadonnées terminé : {{summary}}",
|
||||
"enrichFailed": "Échec de l'enrichissement des métadonnées : {{error}}"
|
||||
}
|
||||
},
|
||||
"doctor": {
|
||||
|
||||
@@ -105,6 +105,7 @@
|
||||
"removeFromFavorites": "הסר מהמועדפים",
|
||||
"viewOnCivitai": "הצג ב-Civitai",
|
||||
"notAvailableFromCivitai": "לא זמין מ-Civitai",
|
||||
"viewOnHuggingFace": "צפייה ב-Hugging Face",
|
||||
"sendToWorkflow": "שלח ל-ComfyUI (לחיצה: הוסף, Shift+לחיצה: החלף)",
|
||||
"copyLoRASyntax": "העתק תחביר LoRA",
|
||||
"checkpointNameCopied": "שם Checkpoint הועתק",
|
||||
@@ -656,6 +657,23 @@
|
||||
"proxyPassword": "סיסמה (אופציונלי)",
|
||||
"proxyPasswordPlaceholder": "password",
|
||||
"proxyPasswordHelp": "סיסמה לאימות מול הפרוקסי (אם נדרש)"
|
||||
},
|
||||
"aiProvider": {
|
||||
"title": "ספק AI",
|
||||
"provider": "ספק",
|
||||
"providerHelp": "בחר את ספק ה-LLM שלך. OpenAI ו-Ollama משתמשים בנקודות קצה מוגדרות מראש. מותאם אישית מאפשר לך לציין כל נקודת קצה תואמת OpenAI.",
|
||||
"custom": "מותאם אישית (תואם OpenAI)",
|
||||
"apiBase": "כתובת בסיס API",
|
||||
"apiBaseHelp": "כתובת ה-URL הבסיסית ל-API של LLM (לדוגמה https://api.openai.com/v1). השאר ריק לשימוש בברירת המחדל של הספק.",
|
||||
"apiBasePlaceholder": "https://api.openai.com/v1",
|
||||
"apiKey": "מפתח API",
|
||||
"apiKeyHelp": "מפתח ה-API של ספק ה-LLM שלך. נשמר מקומית, לעולם לא נשלח לשרת כלשהו מלבד ספק ה-LLM שבחרת.",
|
||||
"apiKeyPlaceholder": "sk-...",
|
||||
"apiKeyNotSet": "לא הוגדר",
|
||||
"apiKeyConfigured": "הוגדר",
|
||||
"apiKeySet": "הגדר",
|
||||
"model": "מודל",
|
||||
"modelHelp": "שם המודל לשימוש (לדוגמה deepseek-v4-flash, gemini-2.5-flash, gemma4:12b). בדוק אצל הספק שלך אילו מודלים זמינים."
|
||||
}
|
||||
},
|
||||
"loras": {
|
||||
@@ -753,7 +771,8 @@
|
||||
"completed": "הושלם: {success} הועברו, {skipped} דולגו, {failures} נכשלו",
|
||||
"complete": "ארגון אוטומטי הושלם",
|
||||
"error": "שגיאה: {error}"
|
||||
}
|
||||
},
|
||||
"enrichHfAgent": "העשרת מטא-דאטה (AI)"
|
||||
},
|
||||
"contextMenu": {
|
||||
"refreshMetadata": "רענן נתוני Civitai",
|
||||
@@ -777,7 +796,8 @@
|
||||
"shareRecipe": "שתף מתכון",
|
||||
"viewAllLoras": "הצג את כל ה-LoRAs",
|
||||
"downloadMissingLoras": "הורד LoRAs חסרים",
|
||||
"deleteRecipe": "מחק מתכון"
|
||||
"deleteRecipe": "מחק מתכון",
|
||||
"enrichHfAgent": "העשרת מטא-דאטה (AI)"
|
||||
}
|
||||
},
|
||||
"recipes": {
|
||||
@@ -1134,7 +1154,10 @@
|
||||
"titleWithType": "הורד {type} מכתובת URL",
|
||||
"civitaiUrl": "כתובת URL של Civitai:",
|
||||
"placeholder": "https://civitai.com/models/...",
|
||||
"urlHint": "יש להזין כתובת URL אחת של CivitAI או CivArchive בכל שורה. תומך במספר כתובות URL להורדה בבת אחת.",
|
||||
"urlHint": "יש להזין כתובת URL אחת של CivitAI, CivArchive או Hugging Face בכל שורה. תומך במספר כתובות URL להורדה בקבוצה.",
|
||||
"selectHfFiles": "בחר קבצים להורדה ממאגר זה:",
|
||||
"selectAll": "בחר הכל",
|
||||
"fetchingRepoFiles": "מביא קבצים מהמאגר...",
|
||||
"locationPreview": "תצוגה מקדימה של מיקום ההורדה",
|
||||
"useDefaultPath": "השתמש בנתיב ברירת מחדל",
|
||||
"useDefaultPathTooltip": "כאשר מופעל, קבצים מאורגנים אוטומטית באמצעות תבניות נתיב מוגדרות",
|
||||
@@ -1163,7 +1186,9 @@
|
||||
},
|
||||
"errors": {
|
||||
"invalidUrl": "פורמט URL של Civitai לא חוקי",
|
||||
"noVersions": "אין גרסאות זמינות למודל זה"
|
||||
"noVersions": "אין גרסאות זמינות למודל זה",
|
||||
"mixedSources": "לא ניתן לערבב כתובות URL של CivitAI ו-Hugging Face באותה קבוצה.",
|
||||
"noModelFiles": "לא נמצאו קבצי מודל במאגר זה."
|
||||
},
|
||||
"status": {
|
||||
"preparing": "מכין הורדה...",
|
||||
@@ -1314,6 +1339,8 @@
|
||||
"editVersionName": "ערוך שם גרסה",
|
||||
"viewOnCivitai": "הצג ב-Civitai",
|
||||
"viewOnCivitaiText": "הצג ב-Civitai",
|
||||
"viewOnHuggingFace": "צפייה ב-Hugging Face",
|
||||
"viewOnHuggingFaceText": "צפייה ב-Hugging Face",
|
||||
"viewCreatorProfile": "הצג פרופיל יוצר",
|
||||
"openFileLocation": "פתח מיקום קובץ",
|
||||
"sendToWorkflow": "שלח ל-ComfyUI",
|
||||
@@ -1339,7 +1366,10 @@
|
||||
"additionalNotes": "הערות נוספות",
|
||||
"notesHint": "לחץ Enter לשמירה, Shift+Enter לשורה חדשה",
|
||||
"addNotesPlaceholder": "הוסף את ההערות שלך כאן...",
|
||||
"aboutThisVersion": "אודות גרסה זו"
|
||||
"aboutThisVersion": "אודות גרסה זו",
|
||||
"baseModelSearchPlaceholder": "חפש מודל בסיס…",
|
||||
"baseModelSuggested": "מוצע",
|
||||
"baseModelNoMatch": "אין מודלי בסיס תואמים"
|
||||
},
|
||||
"notes": {
|
||||
"saved": "הערות נשמרו בהצלחה",
|
||||
@@ -2070,6 +2100,12 @@
|
||||
"moveFailed": "Failed to move item: {message}",
|
||||
"copiedToClipboard": "הועתק ללוח",
|
||||
"downloadStarted": "ההורדה החלה"
|
||||
},
|
||||
"agent": {
|
||||
"llmNotConfigured": "ספק AI לא הוגדר. הפעל אותו בהגדרות → ספק AI.",
|
||||
"enrichStarted": "מעשיר מטא-דאטה באמצעות AI...",
|
||||
"enrichComplete": "העשרת מטא-דאטה הושלמה: {{summary}}",
|
||||
"enrichFailed": "העשרת מטא-דאטה נכשלה: {{error}}"
|
||||
}
|
||||
},
|
||||
"doctor": {
|
||||
|
||||
@@ -105,6 +105,7 @@
|
||||
"removeFromFavorites": "お気に入りから削除",
|
||||
"viewOnCivitai": "Civitaiで表示",
|
||||
"notAvailableFromCivitai": "Civitaiでは利用できません",
|
||||
"viewOnHuggingFace": "Hugging Face で見る",
|
||||
"sendToWorkflow": "ComfyUIに送信(クリック:追加、Shift+クリック:置換)",
|
||||
"copyLoRASyntax": "LoRA構文をコピー",
|
||||
"checkpointNameCopied": "checkpointの名前をコピーしました",
|
||||
@@ -656,6 +657,23 @@
|
||||
"proxyPassword": "パスワード(任意)",
|
||||
"proxyPasswordPlaceholder": "パスワード",
|
||||
"proxyPasswordHelp": "プロキシ認証用のパスワード(必要な場合)"
|
||||
},
|
||||
"aiProvider": {
|
||||
"title": "AIプロバイダー",
|
||||
"provider": "プロバイダー",
|
||||
"providerHelp": "LLMプロバイダーを選択してください。OpenAIとOllamaはプリセットのAPIエンドポイントを使用します。カスタムでは任意のOpenAI互換エンドポイントを指定できます。",
|
||||
"custom": "カスタム(OpenAI互換)",
|
||||
"apiBase": "APIベースURL",
|
||||
"apiBaseHelp": "LLM APIのベースURL(例:https://api.openai.com/v1)。空の場合はプロバイダーのデフォルトが使用されます。",
|
||||
"apiBasePlaceholder": "https://api.openai.com/v1",
|
||||
"apiKey": "APIキー",
|
||||
"apiKeyHelp": "LLMプロバイダーのAPIキー。ローカルに保存され、選択したLLMプロバイダー以外のサーバーに送信されることはありません。",
|
||||
"apiKeyPlaceholder": "sk-...",
|
||||
"apiKeyNotSet": "未設定",
|
||||
"apiKeyConfigured": "設定済み",
|
||||
"apiKeySet": "設定",
|
||||
"model": "モデル",
|
||||
"modelHelp": "使用するモデル名(例:deepseek-v4-flash, gemini-2.5-flash, gemma4:12b)。プロバイダーで利用可能なモデルをご確認ください。"
|
||||
}
|
||||
},
|
||||
"loras": {
|
||||
@@ -753,7 +771,8 @@
|
||||
"completed": "完了:{success} 移動、{skipped} スキップ、{failures} 失敗",
|
||||
"complete": "自動整理が完了しました",
|
||||
"error": "エラー:{error}"
|
||||
}
|
||||
},
|
||||
"enrichHfAgent": "メタデータをAIで補完"
|
||||
},
|
||||
"contextMenu": {
|
||||
"refreshMetadata": "Civitaiデータを更新",
|
||||
@@ -777,7 +796,8 @@
|
||||
"shareRecipe": "レシピを共有",
|
||||
"viewAllLoras": "すべてのLoRAを表示",
|
||||
"downloadMissingLoras": "不足しているLoRAをダウンロード",
|
||||
"deleteRecipe": "レシピを削除"
|
||||
"deleteRecipe": "レシピを削除",
|
||||
"enrichHfAgent": "メタデータをAIで補完"
|
||||
}
|
||||
},
|
||||
"recipes": {
|
||||
@@ -1134,7 +1154,10 @@
|
||||
"titleWithType": "URLから{type}をダウンロード",
|
||||
"civitaiUrl": "Civitai URL:",
|
||||
"placeholder": "https://civitai.com/models/...",
|
||||
"urlHint": "1行に1つのCivitAIまたはCivArchive URLを入力してください。複数のURLを一括ダウンロードできます。",
|
||||
"urlHint": "1行に1つのCivitAI、CivArchive、またはHugging Face URLを入力してください。複数のURLを一括ダウンロードできます。",
|
||||
"selectHfFiles": "このリポジトリからダウンロードするファイルを選択してください:",
|
||||
"selectAll": "すべて選択",
|
||||
"fetchingRepoFiles": "リポジトリのファイルを取得中...",
|
||||
"locationPreview": "ダウンロード場所プレビュー",
|
||||
"useDefaultPath": "デフォルトパスを使用",
|
||||
"useDefaultPathTooltip": "有効にすると、設定されたパステンプレートを使用してファイルが自動的に整理されます",
|
||||
@@ -1163,7 +1186,9 @@
|
||||
},
|
||||
"errors": {
|
||||
"invalidUrl": "無効なCivitai URL形式",
|
||||
"noVersions": "このモデルの利用可能なバージョンがありません"
|
||||
"noVersions": "このモデルの利用可能なバージョンがありません",
|
||||
"mixedSources": "同じバッチ内でCivitAIとHugging FaceのURLを混在させることはできません。",
|
||||
"noModelFiles": "このリポジトリにモデルファイルが見つかりませんでした。"
|
||||
},
|
||||
"status": {
|
||||
"preparing": "ダウンロードを準備中...",
|
||||
@@ -1314,6 +1339,8 @@
|
||||
"editVersionName": "バージョン名を編集",
|
||||
"viewOnCivitai": "Civitaiで表示",
|
||||
"viewOnCivitaiText": "Civitaiで表示",
|
||||
"viewOnHuggingFace": "Hugging Face で見る",
|
||||
"viewOnHuggingFaceText": "Hugging Face で見る",
|
||||
"viewCreatorProfile": "作成者プロフィールを表示",
|
||||
"openFileLocation": "ファイルの場所を開く",
|
||||
"sendToWorkflow": "ComfyUI に送信",
|
||||
@@ -1339,7 +1366,10 @@
|
||||
"additionalNotes": "追加メモ",
|
||||
"notesHint": "Enterで保存、Shift+Enterで改行",
|
||||
"addNotesPlaceholder": "メモをここに追加...",
|
||||
"aboutThisVersion": "このバージョンについて"
|
||||
"aboutThisVersion": "このバージョンについて",
|
||||
"baseModelSearchPlaceholder": "ベースモデルを検索…",
|
||||
"baseModelSuggested": "おすすめ",
|
||||
"baseModelNoMatch": "該当するベースモデルがありません"
|
||||
},
|
||||
"notes": {
|
||||
"saved": "メモが正常に保存されました",
|
||||
@@ -2070,6 +2100,12 @@
|
||||
"moveFailed": "Failed to move item: {message}",
|
||||
"copiedToClipboard": "クリップボードにコピーしました",
|
||||
"downloadStarted": "ダウンロードを開始しました"
|
||||
},
|
||||
"agent": {
|
||||
"llmNotConfigured": "AIプロバイダーが設定されていません。設定 → AIプロバイダーで有効にしてください。",
|
||||
"enrichStarted": "AIでメタデータを補完中...",
|
||||
"enrichComplete": "メタデータの補完が完了しました:{{summary}}",
|
||||
"enrichFailed": "メタデータの補完に失敗しました:{{error}}"
|
||||
}
|
||||
},
|
||||
"doctor": {
|
||||
|
||||
@@ -105,6 +105,7 @@
|
||||
"removeFromFavorites": "즐겨찾기에서 제거",
|
||||
"viewOnCivitai": "Civitai에서 보기",
|
||||
"notAvailableFromCivitai": "Civitai에서 사용할 수 없음",
|
||||
"viewOnHuggingFace": "Hugging Face에서 보기",
|
||||
"sendToWorkflow": "ComfyUI로 전송 (클릭: 추가, Shift+클릭: 교체)",
|
||||
"copyLoRASyntax": "LoRA 문법 복사",
|
||||
"checkpointNameCopied": "Checkpoint 이름 복사됨",
|
||||
@@ -656,6 +657,23 @@
|
||||
"proxyPassword": "비밀번호 (선택사항)",
|
||||
"proxyPasswordPlaceholder": "password",
|
||||
"proxyPasswordHelp": "프록시 인증에 필요한 비밀번호 (필요한 경우)"
|
||||
},
|
||||
"aiProvider": {
|
||||
"title": "AI 제공자",
|
||||
"provider": "제공자",
|
||||
"providerHelp": "LLM 제공자를 선택하세요. OpenAI와 Ollama는 사전 설정된 API 엔드포인트를 사용합니다. 사용자 정의를 선택하면 모든 OpenAI 호환 엔드포인트를 지정할 수 있습니다.",
|
||||
"custom": "사용자 정의 (OpenAI 호환)",
|
||||
"apiBase": "API 기본 URL",
|
||||
"apiBaseHelp": "LLM API의 기본 URL입니다 (예: https://api.openai.com/v1). 비워두면 제공자 기본값이 사용됩니다.",
|
||||
"apiBasePlaceholder": "https://api.openai.com/v1",
|
||||
"apiKey": "API 키",
|
||||
"apiKeyHelp": "LLM 제공자의 API 키입니다. 로컬에 저장되며 선택한 LLM 제공자 외의 서버로 전송되지 않습니다.",
|
||||
"apiKeyPlaceholder": "sk-...",
|
||||
"apiKeyNotSet": "설정되지 않음",
|
||||
"apiKeyConfigured": "설정됨",
|
||||
"apiKeySet": "설정",
|
||||
"model": "모델",
|
||||
"modelHelp": "사용할 모델 이름 (예: deepseek-v4-flash, gemini-2.5-flash, gemma4:12b). 제공자에서 사용 가능한 모델을 확인하세요."
|
||||
}
|
||||
},
|
||||
"loras": {
|
||||
@@ -753,7 +771,8 @@
|
||||
"completed": "완료: {success}개 이동, {skipped}개 건너뜀, {failures}개 실패",
|
||||
"complete": "자동 정리 완료",
|
||||
"error": "오류: {error}"
|
||||
}
|
||||
},
|
||||
"enrichHfAgent": "AI로 메타데이터 보강"
|
||||
},
|
||||
"contextMenu": {
|
||||
"refreshMetadata": "Civitai 데이터 새로고침",
|
||||
@@ -777,7 +796,8 @@
|
||||
"shareRecipe": "레시피 공유",
|
||||
"viewAllLoras": "모든 LoRA 보기",
|
||||
"downloadMissingLoras": "누락된 LoRA 다운로드",
|
||||
"deleteRecipe": "레시피 삭제"
|
||||
"deleteRecipe": "레시피 삭제",
|
||||
"enrichHfAgent": "AI로 메타데이터 보강"
|
||||
}
|
||||
},
|
||||
"recipes": {
|
||||
@@ -1134,7 +1154,10 @@
|
||||
"titleWithType": "URL에서 {type} 다운로드",
|
||||
"civitaiUrl": "Civitai URL:",
|
||||
"placeholder": "https://civitai.com/models/...",
|
||||
"urlHint": "한 줄에 하나의 CivitAI 또는 CivArchive URL을 입력하세요. 여러 URL을 일괄 다운로드할 수 있습니다.",
|
||||
"urlHint": "한 줄에 하나의 CivitAI, CivArchive 또는 Hugging Face URL을 입력하세요. 여러 URL을 일괄 다운로드할 수 있습니다.",
|
||||
"selectHfFiles": "이 저장소에서 다운로드할 파일을 선택하세요:",
|
||||
"selectAll": "모두 선택",
|
||||
"fetchingRepoFiles": "저장소 파일을 가져오는 중...",
|
||||
"locationPreview": "다운로드 위치 미리보기",
|
||||
"useDefaultPath": "기본 경로 사용",
|
||||
"useDefaultPathTooltip": "활성화하면 구성된 경로 템플릿을 사용하여 파일이 자동으로 정리됩니다",
|
||||
@@ -1163,7 +1186,9 @@
|
||||
},
|
||||
"errors": {
|
||||
"invalidUrl": "잘못된 Civitai URL 형식",
|
||||
"noVersions": "이 모델에 사용 가능한 버전이 없습니다"
|
||||
"noVersions": "이 모델에 사용 가능한 버전이 없습니다",
|
||||
"mixedSources": "동일한 배치에서 CivitAI와 Hugging Face URL을 혼합할 수 없습니다.",
|
||||
"noModelFiles": "이 저장소에서 모델 파일을 찾을 수 없습니다."
|
||||
},
|
||||
"status": {
|
||||
"preparing": "다운로드 준비 중...",
|
||||
@@ -1314,6 +1339,8 @@
|
||||
"editVersionName": "버전명 편집",
|
||||
"viewOnCivitai": "Civitai에서 보기",
|
||||
"viewOnCivitaiText": "Civitai에서 보기",
|
||||
"viewOnHuggingFace": "Hugging Face에서 보기",
|
||||
"viewOnHuggingFaceText": "Hugging Face에서 보기",
|
||||
"viewCreatorProfile": "제작자 프로필 보기",
|
||||
"openFileLocation": "파일 위치 열기",
|
||||
"sendToWorkflow": "ComfyUI로 보내기",
|
||||
@@ -1339,7 +1366,10 @@
|
||||
"additionalNotes": "추가 메모",
|
||||
"notesHint": "Enter로 저장, Shift+Enter로 줄바꿈",
|
||||
"addNotesPlaceholder": "메모를 여기에 추가하세요...",
|
||||
"aboutThisVersion": "이 버전에 대해"
|
||||
"aboutThisVersion": "이 버전에 대해",
|
||||
"baseModelSearchPlaceholder": "베이스 모델 검색…",
|
||||
"baseModelSuggested": "추천",
|
||||
"baseModelNoMatch": "일치하는 베이스 모델 없음"
|
||||
},
|
||||
"notes": {
|
||||
"saved": "메모가 성공적으로 저장됨",
|
||||
@@ -2070,6 +2100,12 @@
|
||||
"moveFailed": "Failed to move item: {message}",
|
||||
"copiedToClipboard": "클립보드에 복사됨",
|
||||
"downloadStarted": "다운로드 시작됨"
|
||||
},
|
||||
"agent": {
|
||||
"llmNotConfigured": "AI 제공자가 설정되지 않았습니다. 설정 → AI 제공자에서 활성화하세요.",
|
||||
"enrichStarted": "AI로 메타데이터 보강 중...",
|
||||
"enrichComplete": "메타데이터 보강 완료: {{summary}}",
|
||||
"enrichFailed": "메타데이터 보강 실패: {{error}}"
|
||||
}
|
||||
},
|
||||
"doctor": {
|
||||
|
||||
@@ -105,6 +105,7 @@
|
||||
"removeFromFavorites": "Удалить из избранного",
|
||||
"viewOnCivitai": "Посмотреть на Civitai",
|
||||
"notAvailableFromCivitai": "Недоступно на Civitai",
|
||||
"viewOnHuggingFace": "Открыть Hugging Face",
|
||||
"sendToWorkflow": "Отправить в ComfyUI (Клик: Добавить, Shift+Клик: Заменить)",
|
||||
"copyLoRASyntax": "Копировать синтаксис LoRA",
|
||||
"checkpointNameCopied": "Имя checkpoint скопировано",
|
||||
@@ -656,6 +657,23 @@
|
||||
"proxyPassword": "Пароль (необязательно)",
|
||||
"proxyPasswordPlaceholder": "пароль",
|
||||
"proxyPasswordHelp": "Пароль для аутентификации на прокси (если требуется)"
|
||||
},
|
||||
"aiProvider": {
|
||||
"title": "Поставщик ИИ",
|
||||
"provider": "Поставщик",
|
||||
"providerHelp": "Выберите поставщика LLM. OpenAI и Ollama используют предустановленные API-эндпоинты. Пользовательский позволяет указать любой совместимый с OpenAI эндпоинт.",
|
||||
"custom": "Пользовательский (совместимый с OpenAI)",
|
||||
"apiBase": "Базовый URL API",
|
||||
"apiBaseHelp": "Базовый URL для LLM API (например, https://api.openai.com/v1). Оставьте пустым, чтобы использовать значение по умолчанию.",
|
||||
"apiBasePlaceholder": "https://api.openai.com/v1",
|
||||
"apiKey": "API-ключ",
|
||||
"apiKeyHelp": "Ваш API-ключ поставщика LLM. Хранится локально и никогда не отправляется на другие серверы, кроме выбранного поставщика LLM.",
|
||||
"apiKeyPlaceholder": "sk-...",
|
||||
"apiKeyNotSet": "Не задан",
|
||||
"apiKeyConfigured": "Настроен",
|
||||
"apiKeySet": "Настроить",
|
||||
"model": "Модель",
|
||||
"modelHelp": "Имя модели для использования (например, deepseek-v4-flash, gemini-2.5-flash, gemma4:12b). Проверьте доступные модели у вашего поставщика."
|
||||
}
|
||||
},
|
||||
"loras": {
|
||||
@@ -753,7 +771,8 @@
|
||||
"completed": "Завершено: {success} перемещено, {skipped} пропущено, {failures} не удалось",
|
||||
"complete": "Автоматическая организация завершена",
|
||||
"error": "Ошибка: {error}"
|
||||
}
|
||||
},
|
||||
"enrichHfAgent": "Обогатить метаданные (ИИ)"
|
||||
},
|
||||
"contextMenu": {
|
||||
"refreshMetadata": "Обновить данные Civitai",
|
||||
@@ -777,7 +796,8 @@
|
||||
"shareRecipe": "Поделиться рецептом",
|
||||
"viewAllLoras": "Посмотреть все LoRAs",
|
||||
"downloadMissingLoras": "Загрузить отсутствующие LoRAs",
|
||||
"deleteRecipe": "Удалить рецепт"
|
||||
"deleteRecipe": "Удалить рецепт",
|
||||
"enrichHfAgent": "Обогатить метаданные (ИИ)"
|
||||
}
|
||||
},
|
||||
"recipes": {
|
||||
@@ -1134,7 +1154,10 @@
|
||||
"titleWithType": "Скачать {type} по URL",
|
||||
"civitaiUrl": "Civitai URL:",
|
||||
"placeholder": "https://civitai.com/models/...",
|
||||
"urlHint": "Введите один URL CivitAI или CivArchive в каждой строке. Поддерживается пакетная загрузка нескольких URL.",
|
||||
"urlHint": "Введите один URL CivitAI, CivArchive или Hugging Face в каждой строке. Поддерживает несколько URL для пакетной загрузки.",
|
||||
"selectHfFiles": "Выберите файл(ы) для загрузки из этого репозитория:",
|
||||
"selectAll": "Выбрать все",
|
||||
"fetchingRepoFiles": "Получение файлов репозитория...",
|
||||
"locationPreview": "Предпросмотр места загрузки",
|
||||
"useDefaultPath": "Использовать путь по умолчанию",
|
||||
"useDefaultPathTooltip": "При включении файлы автоматически организуются с использованием настроенных шаблонов путей",
|
||||
@@ -1163,7 +1186,9 @@
|
||||
},
|
||||
"errors": {
|
||||
"invalidUrl": "Неверный формат URL Civitai",
|
||||
"noVersions": "Нет доступных версий для этой модели"
|
||||
"noVersions": "Нет доступных версий для этой модели",
|
||||
"mixedSources": "Нельзя смешивать URL-адреса CivitAI и Hugging Face в одном пакете.",
|
||||
"noModelFiles": "В этом репозитории не найдено файлов моделей."
|
||||
},
|
||||
"status": {
|
||||
"preparing": "Подготовка загрузки...",
|
||||
@@ -1314,6 +1339,8 @@
|
||||
"editVersionName": "Редактировать название версии",
|
||||
"viewOnCivitai": "Посмотреть на Civitai",
|
||||
"viewOnCivitaiText": "Посмотреть на Civitai",
|
||||
"viewOnHuggingFace": "Открыть Hugging Face",
|
||||
"viewOnHuggingFaceText": "Открыть Hugging Face",
|
||||
"viewCreatorProfile": "Посмотреть профиль создателя",
|
||||
"openFileLocation": "Открыть расположение файла",
|
||||
"sendToWorkflow": "Отправить в ComfyUI",
|
||||
@@ -1339,7 +1366,10 @@
|
||||
"additionalNotes": "Дополнительные заметки",
|
||||
"notesHint": "Нажмите Enter для сохранения, Shift+Enter для новой строки",
|
||||
"addNotesPlaceholder": "Добавьте ваши заметки здесь...",
|
||||
"aboutThisVersion": "Об этой версии"
|
||||
"aboutThisVersion": "Об этой версии",
|
||||
"baseModelSearchPlaceholder": "Поиск базовой модели…",
|
||||
"baseModelSuggested": "Предполагаемые",
|
||||
"baseModelNoMatch": "Нет подходящих базовых моделей"
|
||||
},
|
||||
"notes": {
|
||||
"saved": "Заметки успешно сохранены",
|
||||
@@ -2070,6 +2100,12 @@
|
||||
"moveFailed": "Failed to move item: {message}",
|
||||
"copiedToClipboard": "Скопировано в буфер обмена",
|
||||
"downloadStarted": "Загрузка начата"
|
||||
},
|
||||
"agent": {
|
||||
"llmNotConfigured": "Поставщик ИИ не настроен. Включите его в Настройки → Поставщик ИИ.",
|
||||
"enrichStarted": "Обогащение метаданных с помощью ИИ...",
|
||||
"enrichComplete": "Обогащение метаданных завершено: {{summary}}",
|
||||
"enrichFailed": "Ошибка обогащения метаданных: {{error}}"
|
||||
}
|
||||
},
|
||||
"doctor": {
|
||||
|
||||
@@ -105,6 +105,7 @@
|
||||
"removeFromFavorites": "从收藏移除",
|
||||
"viewOnCivitai": "在 Civitai 查看",
|
||||
"notAvailableFromCivitai": "Civitai 上不可用",
|
||||
"viewOnHuggingFace": "在 Hugging Face 查看",
|
||||
"sendToWorkflow": "发送到 ComfyUI(点击:追加,Shift+点击:替换)",
|
||||
"copyLoRASyntax": "复制 LoRA 语法",
|
||||
"checkpointNameCopied": "检查点名称已复制",
|
||||
@@ -656,6 +657,23 @@
|
||||
"proxyPassword": "密码 (可选)",
|
||||
"proxyPasswordPlaceholder": "密码",
|
||||
"proxyPasswordHelp": "代理认证的密码 (如果需要)"
|
||||
},
|
||||
"aiProvider": {
|
||||
"title": "AI 提供商",
|
||||
"provider": "提供商",
|
||||
"providerHelp": "选择您的 LLM 提供商。OpenAI 和 Ollama 使用预设的 API 端点。自定义允许您指定任何兼容 OpenAI 的端点。",
|
||||
"custom": "自定义(兼容 OpenAI)",
|
||||
"apiBase": "API 基础地址",
|
||||
"apiBaseHelp": "LLM API 的基础 URL(例如 https://api.openai.com/v1)。留空则使用提供商默认地址。",
|
||||
"apiBasePlaceholder": "https://api.openai.com/v1",
|
||||
"apiKey": "API 密钥",
|
||||
"apiKeyHelp": "您的 LLM 提供商 API 密钥。仅本地存储,不会发送到您选择的 LLM 提供商之外的任何服务器。",
|
||||
"apiKeyPlaceholder": "sk-...",
|
||||
"apiKeyNotSet": "未设置",
|
||||
"apiKeyConfigured": "已配置",
|
||||
"apiKeySet": "设置",
|
||||
"model": "模型",
|
||||
"modelHelp": "要使用的模型名称(例如 deepseek-v4-flash, gemini-2.5-flash, gemma4:12b)。请查看您的提供商支持的可用模型列表。"
|
||||
}
|
||||
},
|
||||
"loras": {
|
||||
@@ -753,7 +771,8 @@
|
||||
"completed": "完成:已移动 {success} 个,跳过 {skipped} 个,失败 {failures} 个",
|
||||
"complete": "自动整理已完成",
|
||||
"error": "错误:{error}"
|
||||
}
|
||||
},
|
||||
"enrichHfAgent": "AI 元数据增强"
|
||||
},
|
||||
"contextMenu": {
|
||||
"refreshMetadata": "刷新 Civitai 数据",
|
||||
@@ -777,7 +796,8 @@
|
||||
"shareRecipe": "分享配方",
|
||||
"viewAllLoras": "查看所有 LoRA",
|
||||
"downloadMissingLoras": "下载缺失的 LoRA",
|
||||
"deleteRecipe": "删除配方"
|
||||
"deleteRecipe": "删除配方",
|
||||
"enrichHfAgent": "AI 元数据增强"
|
||||
}
|
||||
},
|
||||
"recipes": {
|
||||
@@ -1134,7 +1154,10 @@
|
||||
"titleWithType": "从 URL 下载 {type}",
|
||||
"civitaiUrl": "Civitai URL:",
|
||||
"placeholder": "https://civitai.com/models/...",
|
||||
"urlHint": "每行输入一个 CivitAI 或 CivArchive URL。支持批量下载多个 URL。",
|
||||
"urlHint": "每行输入一个 CivitAI、CivArchive 或 Hugging Face URL。支持批量下载多个 URL。",
|
||||
"selectHfFiles": "选择从此仓库下载的文件:",
|
||||
"selectAll": "全选",
|
||||
"fetchingRepoFiles": "正在获取仓库文件...",
|
||||
"locationPreview": "下载位置预览",
|
||||
"useDefaultPath": "使用默认路径",
|
||||
"useDefaultPathTooltip": "启用后,文件将自动按配置的路径模板进行整理",
|
||||
@@ -1163,7 +1186,9 @@
|
||||
},
|
||||
"errors": {
|
||||
"invalidUrl": "无效的 Civitai URL 格式",
|
||||
"noVersions": "此模型没有可用版本"
|
||||
"noVersions": "此模型没有可用版本",
|
||||
"mixedSources": "无法在同一批次中混合使用 CivitAI 和 Hugging Face URL。",
|
||||
"noModelFiles": "在此仓库中未找到模型文件。"
|
||||
},
|
||||
"status": {
|
||||
"preparing": "正在准备下载...",
|
||||
@@ -1314,6 +1339,8 @@
|
||||
"editVersionName": "编辑版本名称",
|
||||
"viewOnCivitai": "在 Civitai 查看",
|
||||
"viewOnCivitaiText": "在 Civitai 查看",
|
||||
"viewOnHuggingFace": "在 Hugging Face 查看",
|
||||
"viewOnHuggingFaceText": "在 Hugging Face 查看",
|
||||
"viewCreatorProfile": "查看创作者主页",
|
||||
"openFileLocation": "打开文件位置",
|
||||
"sendToWorkflow": "发送到 ComfyUI",
|
||||
@@ -1339,7 +1366,10 @@
|
||||
"additionalNotes": "附加备注",
|
||||
"notesHint": "回车保存,Shift+回车换行",
|
||||
"addNotesPlaceholder": "在此添加你的备注...",
|
||||
"aboutThisVersion": "关于此版本"
|
||||
"aboutThisVersion": "关于此版本",
|
||||
"baseModelSearchPlaceholder": "搜索基础模型…",
|
||||
"baseModelSuggested": "推荐",
|
||||
"baseModelNoMatch": "没有匹配的基础模型"
|
||||
},
|
||||
"notes": {
|
||||
"saved": "备注保存成功",
|
||||
@@ -2070,6 +2100,12 @@
|
||||
"moveFailed": "Failed to move item: {message}",
|
||||
"copiedToClipboard": "已复制到剪贴板",
|
||||
"downloadStarted": "下载已开始"
|
||||
},
|
||||
"agent": {
|
||||
"llmNotConfigured": "AI 提供商未配置。请在 设置 → AI 提供商 中进行配置。",
|
||||
"enrichStarted": "正在使用 AI 增强元数据...",
|
||||
"enrichComplete": "元数据增强完成:{{summary}}",
|
||||
"enrichFailed": "元数据增强失败:{{error}}"
|
||||
}
|
||||
},
|
||||
"doctor": {
|
||||
|
||||
@@ -105,6 +105,7 @@
|
||||
"removeFromFavorites": "移除收藏",
|
||||
"viewOnCivitai": "在 Civitai 查看",
|
||||
"notAvailableFromCivitai": "Civitai 不提供",
|
||||
"viewOnHuggingFace": "在 Hugging Face 查看",
|
||||
"sendToWorkflow": "傳送到 ComfyUI(點擊:附加,Shift+點擊:取代)",
|
||||
"copyLoRASyntax": "複製 LoRA 語法",
|
||||
"checkpointNameCopied": "Checkpoint 名稱已複製",
|
||||
@@ -656,6 +657,23 @@
|
||||
"proxyPassword": "密碼(選填)",
|
||||
"proxyPasswordPlaceholder": "password",
|
||||
"proxyPasswordHelp": "代理驗證所需的密碼(如有需要)"
|
||||
},
|
||||
"aiProvider": {
|
||||
"title": "AI 提供者",
|
||||
"provider": "提供者",
|
||||
"providerHelp": "選擇您的 LLM 提供者。OpenAI 和 Ollama 使用預設 API 端點。自訂允許您指定任何相容 OpenAI 的端點。",
|
||||
"custom": "自訂(相容 OpenAI)",
|
||||
"apiBase": "API 基礎位址",
|
||||
"apiBaseHelp": "LLM API 的基礎 URL(例如 https://api.openai.com/v1)。留空則使用提供者預設位址。",
|
||||
"apiBasePlaceholder": "https://api.openai.com/v1",
|
||||
"apiKey": "API 金鑰",
|
||||
"apiKeyHelp": "您的 LLM 提供者 API 金鑰。僅儲存在本地,除了您選擇的 LLM 提供者外,不會發送到任何伺服器。",
|
||||
"apiKeyPlaceholder": "sk-...",
|
||||
"apiKeyNotSet": "未設定",
|
||||
"apiKeyConfigured": "已設定",
|
||||
"apiKeySet": "設定",
|
||||
"model": "模型",
|
||||
"modelHelp": "要使用的模型名稱(例如 deepseek-v4-flash, gemini-2.5-flash, gemma4:12b)。請查看您的提供者支援的可用模型列表。"
|
||||
}
|
||||
},
|
||||
"loras": {
|
||||
@@ -753,7 +771,8 @@
|
||||
"completed": "完成:已移動 {success},已略過 {skipped},失敗 {failures}",
|
||||
"complete": "自動整理完成",
|
||||
"error": "錯誤:{error}"
|
||||
}
|
||||
},
|
||||
"enrichHfAgent": "AI 中繼資料增強"
|
||||
},
|
||||
"contextMenu": {
|
||||
"refreshMetadata": "刷新 Civitai 資料",
|
||||
@@ -777,7 +796,8 @@
|
||||
"shareRecipe": "分享配方",
|
||||
"viewAllLoras": "檢視全部 LoRA",
|
||||
"downloadMissingLoras": "下載缺少的 LoRA",
|
||||
"deleteRecipe": "刪除配方"
|
||||
"deleteRecipe": "刪除配方",
|
||||
"enrichHfAgent": "AI 中繼資料增強"
|
||||
}
|
||||
},
|
||||
"recipes": {
|
||||
@@ -1134,7 +1154,10 @@
|
||||
"titleWithType": "從網址下載 {type}",
|
||||
"civitaiUrl": "Civitai 網址:",
|
||||
"placeholder": "https://civitai.com/models/...",
|
||||
"urlHint": "每行輸入一個 CivitAI 或 CivArchive URL。支援批量下載多個 URL。",
|
||||
"urlHint": "每行輸入一個 CivitAI、CivArchive 或 Hugging Face URL。支援批量下載多個 URL。",
|
||||
"selectHfFiles": "選擇從此倉庫下載的檔案:",
|
||||
"selectAll": "全選",
|
||||
"fetchingRepoFiles": "正在獲取倉庫檔案...",
|
||||
"locationPreview": "下載位置預覽",
|
||||
"useDefaultPath": "使用預設路徑",
|
||||
"useDefaultPathTooltip": "啟用後,檔案將依照設定的路徑範本自動整理",
|
||||
@@ -1163,7 +1186,9 @@
|
||||
},
|
||||
"errors": {
|
||||
"invalidUrl": "Civitai 網址格式無效",
|
||||
"noVersions": "此模型無可用版本"
|
||||
"noVersions": "此模型無可用版本",
|
||||
"mixedSources": "無法在同一批次中混合使用 CivitAI 和 Hugging Face URL。",
|
||||
"noModelFiles": "在此倉庫中未找到模型檔案。"
|
||||
},
|
||||
"status": {
|
||||
"preparing": "準備下載中...",
|
||||
@@ -1314,6 +1339,8 @@
|
||||
"editVersionName": "編輯版本名稱",
|
||||
"viewOnCivitai": "在 Civitai 查看",
|
||||
"viewOnCivitaiText": "在 Civitai 查看",
|
||||
"viewOnHuggingFace": "在 Hugging Face 查看",
|
||||
"viewOnHuggingFaceText": "在 Hugging Face 查看",
|
||||
"viewCreatorProfile": "查看創作者個人檔案",
|
||||
"openFileLocation": "開啟檔案位置",
|
||||
"sendToWorkflow": "傳送到 ComfyUI",
|
||||
@@ -1339,7 +1366,10 @@
|
||||
"additionalNotes": "附加備註",
|
||||
"notesHint": "按 Enter 儲存,Shift+Enter 換行",
|
||||
"addNotesPlaceholder": "在此新增備註...",
|
||||
"aboutThisVersion": "關於此版本"
|
||||
"aboutThisVersion": "關於此版本",
|
||||
"baseModelSearchPlaceholder": "搜尋基礎模型…",
|
||||
"baseModelSuggested": "推薦",
|
||||
"baseModelNoMatch": "沒有符合的基礎模型"
|
||||
},
|
||||
"notes": {
|
||||
"saved": "備註已儲存",
|
||||
@@ -2070,6 +2100,12 @@
|
||||
"moveFailed": "Failed to move item: {message}",
|
||||
"copiedToClipboard": "已複製到剪貼簿",
|
||||
"downloadStarted": "下載已開始"
|
||||
},
|
||||
"agent": {
|
||||
"llmNotConfigured": "AI 提供者尚未設定。請在 設定 → AI 提供者 中進行設定。",
|
||||
"enrichStarted": "正在使用 AI 增強中繼資料...",
|
||||
"enrichComplete": "中繼資料增強完成:{{summary}}",
|
||||
"enrichFailed": "中繼資料增強失敗:{{error}}"
|
||||
}
|
||||
},
|
||||
"doctor": {
|
||||
|
||||
225
py/agent_cli/__init__.py
Normal file
225
py/agent_cli/__init__.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Agent CLI — thin in-process wrappers around LoRA Manager internal services.
|
||||
|
||||
All functions are simple Python async functions that delegate to the
|
||||
appropriate internal service. They use **relative imports** within the
|
||||
``py`` package, so ``sys.modules`` caching works normally and there is no
|
||||
risk of double import or circular dependencies.
|
||||
|
||||
Usage (in-process, primary)::
|
||||
|
||||
from py.agent_cli import list_base_models, read_metadata
|
||||
|
||||
models = await list_base_models()
|
||||
meta = await read_metadata("/path/to/model.safetensors")
|
||||
|
||||
Usage (subprocess, debugging / external)::
|
||||
|
||||
python -m py.agent_cli base-models list
|
||||
python -m py.agent_cli metadata read /path/to/model.safetensors
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _find_scanner_for_model(
|
||||
model_path: str,
|
||||
) -> tuple[object, object] | tuple[None, None]:
|
||||
"""Find the (scanner, cache_entry) responsible for *model_path*.
|
||||
|
||||
Iterates all known scanner types and returns the first one whose cache
|
||||
contains the given path. Returns ``(None, None)`` when no scanner
|
||||
claims the model.
|
||||
"""
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
normalized = os.path.normpath(model_path)
|
||||
for getter_name in (
|
||||
"get_lora_scanner",
|
||||
"get_checkpoint_scanner",
|
||||
"get_embedding_scanner",
|
||||
):
|
||||
getter = getattr(ServiceRegistry, getter_name, None)
|
||||
if getter is None:
|
||||
continue
|
||||
try:
|
||||
scanner = await getter()
|
||||
if scanner is None:
|
||||
continue
|
||||
cache = await scanner.get_cached_data()
|
||||
for entry in cache.raw_data:
|
||||
if os.path.normpath(entry.get("file_path", "")) == normalized:
|
||||
return scanner, entry
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Scanner %s check failed for %s: %s",
|
||||
getter_name,
|
||||
model_path,
|
||||
exc,
|
||||
)
|
||||
return None, None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def list_base_models(limit: int = 0) -> List[str]:
|
||||
"""Return deduplicated base model names from all model caches.
|
||||
|
||||
The result is ordered by frequency (most common first). Pass
|
||||
*limit* = 0 (default) for all models.
|
||||
"""
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
counts: Dict[str, int] = {}
|
||||
for getter_name in (
|
||||
"get_lora_scanner",
|
||||
"get_checkpoint_scanner",
|
||||
"get_embedding_scanner",
|
||||
):
|
||||
getter = getattr(ServiceRegistry, getter_name, None)
|
||||
if getter is None:
|
||||
continue
|
||||
try:
|
||||
scanner = await getter()
|
||||
if scanner is None:
|
||||
continue
|
||||
cache = await scanner.get_cached_data()
|
||||
for entry in cache.raw_data:
|
||||
bm = entry.get("base_model")
|
||||
if bm:
|
||||
counts[bm] = counts.get(bm, 0) + 1
|
||||
except Exception as exc:
|
||||
logger.debug("list_base_models scanner %s error: %s", getter_name, exc)
|
||||
|
||||
sorted_names = [name for name, _ in sorted(counts.items(), key=lambda x: -x[1])]
|
||||
if limit > 0:
|
||||
return sorted_names[:limit]
|
||||
return sorted_names
|
||||
|
||||
|
||||
async def read_metadata(model_path: str) -> Dict[str, Any]:
|
||||
"""Load the full metadata payload for *model_path* from disk.
|
||||
|
||||
Returns an empty dict when the metadata file does not exist or cannot
|
||||
be parsed — never raises.
|
||||
"""
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
|
||||
try:
|
||||
return await MetadataManager.load_metadata_payload(model_path) or {}
|
||||
except Exception as exc:
|
||||
logger.warning("read_metadata failed for %s: %s", model_path, exc)
|
||||
return {}
|
||||
|
||||
|
||||
async def apply_metadata_updates(
|
||||
model_path: str,
|
||||
updates: Dict[str, Any],
|
||||
) -> List[str]:
|
||||
"""Merge *updates* into the model's on-disk metadata and persist.
|
||||
|
||||
Returns the list of field names that actually changed.
|
||||
"""
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
|
||||
metadata = await read_metadata(model_path)
|
||||
updated_fields: List[str] = []
|
||||
for key, value in updates.items():
|
||||
old = metadata.get(key)
|
||||
if old != value:
|
||||
metadata[key] = value
|
||||
updated_fields.append(key)
|
||||
if updated_fields:
|
||||
await MetadataManager.save_metadata(model_path, metadata)
|
||||
return updated_fields
|
||||
|
||||
|
||||
async def download_preview(
|
||||
model_path: str,
|
||||
url: str,
|
||||
*,
|
||||
target_width: int = 480,
|
||||
quality: int = 85,
|
||||
) -> bool:
|
||||
"""Download a preview image from *url*, optimise to .webp, and save it.
|
||||
|
||||
The output file is placed alongside the model file with a ``.webp``
|
||||
extension. Returns ``True`` on success.
|
||||
"""
|
||||
from ..services.downloader import get_downloader
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
|
||||
if not url or not url.strip():
|
||||
return False
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(model_path))[0]
|
||||
preview_dir = os.path.dirname(model_path)
|
||||
output_path = os.path.join(preview_dir, base_name + ".webp")
|
||||
|
||||
downloader = await get_downloader()
|
||||
|
||||
# Try in-memory download + optimise first
|
||||
success, content, _headers = await downloader.download_to_memory(
|
||||
url, use_auth=False,
|
||||
)
|
||||
if success and content:
|
||||
try:
|
||||
optimized_data, _ = ExifUtils.optimize_image(
|
||||
image_data=content,
|
||||
target_width=target_width,
|
||||
format="webp",
|
||||
quality=quality,
|
||||
preserve_metadata=False,
|
||||
)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(optimized_data)
|
||||
logger.info("Preview downloaded and optimised for %s", model_path)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("Preview optimisation failed, saving raw: %s", exc)
|
||||
# Fall through to raw save
|
||||
|
||||
# Fallback: download directly to file
|
||||
try:
|
||||
ok, _ = await downloader.download_file(url, output_path, use_auth=False)
|
||||
if ok:
|
||||
logger.info("Preview downloaded (fallback) for %s", model_path)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("Preview fallback download failed for %s: %s", model_path, exc)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def refresh_cache(model_path: str) -> bool:
|
||||
"""Invalidate and reload the scanner cache entry for *model_path*.
|
||||
|
||||
Returns ``True`` when the model was found and the cache was refreshed.
|
||||
"""
|
||||
scanner, entry = await _find_scanner_for_model(model_path)
|
||||
if scanner is None:
|
||||
logger.warning("refresh_cache: no scanner found for %s", model_path)
|
||||
return False
|
||||
try:
|
||||
metadata = await read_metadata(model_path)
|
||||
if not metadata:
|
||||
logger.warning("refresh_cache: no metadata for %s", model_path)
|
||||
return False
|
||||
await scanner.update_single_model_cache(model_path, model_path, metadata)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("refresh_cache failed for %s: %s", model_path, exc)
|
||||
return False
|
||||
118
py/agent_cli/__main__.py
Normal file
118
py/agent_cli/__main__.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Subprocess entry point for AgentCLI (debugging / external use).
|
||||
|
||||
Usage::
|
||||
|
||||
python -m py.agent_cli base-models list [--limit N]
|
||||
python -m py.agent_cli metadata read <path>
|
||||
python -m py.agent_cli metadata update <path> --json '{...}'
|
||||
python -m py.agent_cli preview download <path> --url <url>
|
||||
python -m py.agent_cli cache refresh <path>
|
||||
|
||||
NOTE: This is an **optional** convenience wrapper. The primary consumer of
|
||||
AgentCLI is the :mod:`AgentService` (in-process). This entry point exists
|
||||
for manual debugging and future integration with subprocess-based agent
|
||||
frameworks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(prog="lmcli", description="LoRA Manager Agent CLI")
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
# base-models list
|
||||
base_models = sub.add_parser("base-models", aliases=["bm"])
|
||||
base_models_cmds = base_models.add_subparsers(dest="subcommand", required=True)
|
||||
base_models_list = base_models_cmds.add_parser("list")
|
||||
base_models_list.add_argument(
|
||||
"--limit", type=int, default=0, help="Max number of models (0 = all)"
|
||||
)
|
||||
|
||||
# metadata read
|
||||
meta = sub.add_parser("metadata", aliases=["md"])
|
||||
meta_cmds = meta.add_subparsers(dest="subcommand", required=True)
|
||||
meta_read = meta_cmds.add_parser("read")
|
||||
meta_read.add_argument("path", type=str, help="Model file path")
|
||||
|
||||
# metadata update
|
||||
meta_update = meta_cmds.add_parser("update")
|
||||
meta_update.add_argument("path", type=str, help="Model file path")
|
||||
meta_update.add_argument(
|
||||
"--json",
|
||||
type=str,
|
||||
required=True,
|
||||
help='JSON object of fields to update, e.g. \'{"base_model": "SDXL 1.0"}\'',
|
||||
)
|
||||
|
||||
# preview download
|
||||
prev = sub.add_parser("preview", aliases=["pv"])
|
||||
prev_cmds = prev.add_subparsers(dest="subcommand", required=True)
|
||||
prev_dl = prev_cmds.add_parser("download")
|
||||
prev_dl.add_argument("path", type=str, help="Model file path")
|
||||
prev_dl.add_argument("--url", type=str, required=True, help="Preview image URL")
|
||||
|
||||
# cache refresh
|
||||
cache = sub.add_parser("cache")
|
||||
cache_cmds = cache.add_subparsers(dest="subcommand", required=True)
|
||||
cache_refresh = cache_cmds.add_parser("refresh")
|
||||
cache_refresh.add_argument("path", type=str, help="Model file path")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
async def _run(args: argparse.Namespace) -> Any:
|
||||
from . import ( # lazy import so startup is fast
|
||||
list_base_models,
|
||||
read_metadata,
|
||||
apply_metadata_updates,
|
||||
download_preview,
|
||||
refresh_cache,
|
||||
)
|
||||
|
||||
cmd = args.command
|
||||
sub = args.subcommand
|
||||
|
||||
if cmd in ("base-models", "bm") and sub == "list":
|
||||
return await list_base_models(limit=args.limit)
|
||||
|
||||
if cmd in ("metadata", "md") and sub == "read":
|
||||
return await read_metadata(args.path)
|
||||
|
||||
if cmd in ("metadata", "md") and sub == "update":
|
||||
updates: Dict[str, Any] = json.loads(args.json)
|
||||
return await apply_metadata_updates(args.path, updates)
|
||||
|
||||
if cmd in ("preview", "pv") and sub == "download":
|
||||
return await download_preview(args.path, args.url)
|
||||
|
||||
if cmd == "cache" and sub == "refresh":
|
||||
return await refresh_cache(args.path)
|
||||
|
||||
raise ValueError(f"Unknown command: {cmd} {sub}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = _build_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
result = asyncio.run(_run(args))
|
||||
# Always print as JSON so callers can parse reliably
|
||||
if isinstance(result, list):
|
||||
for item in result:
|
||||
print(item)
|
||||
elif isinstance(result, dict):
|
||||
json.dump(result, sys.stdout, ensure_ascii=False, indent=2)
|
||||
print()
|
||||
else:
|
||||
print(json.dumps(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
20
py/config.py
20
py/config.py
@@ -8,6 +8,8 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Set, Tuple
|
||||
import logging
|
||||
import json
|
||||
import urllib.parse
|
||||
import sys as _sys
|
||||
import types as _types
|
||||
import time
|
||||
|
||||
from .utils.cache_paths import CacheType, get_cache_file_path, get_legacy_cache_paths
|
||||
@@ -1380,4 +1382,20 @@ class Config:
|
||||
|
||||
|
||||
# Global config instance
|
||||
config = Config()
|
||||
# NOTE: Guard against re-import. When ServiceRegistry.get_lora_scanner() triggers
|
||||
# a fresh import of lora_scanner → config, we must NOT re-execute Config.__init__()
|
||||
# (which re-scans all roots, re-registers libraries, etc.).
|
||||
#
|
||||
# Strategy: store the config instance in a dedicated sentinel module
|
||||
# ('_lm_config_cache') that is NEVER removed from sys.modules (its key does
|
||||
# NOT start with 'py.'), so it survives re-imports of py.* modules.
|
||||
_CONFIG_SENTINEL = "_lm_config_cache"
|
||||
if _CONFIG_SENTINEL in _sys.modules:
|
||||
# Re-import: reuse the existing singleton from the sentinel.
|
||||
config: Config = _sys.modules[_CONFIG_SENTINEL].config # type: ignore[valid-type]
|
||||
else:
|
||||
config: Config = Config()
|
||||
# Register the sentinel so re-imports of py.config find us.
|
||||
_sentinel_mod = _types.ModuleType(_CONFIG_SENTINEL)
|
||||
_sentinel_mod.config = config
|
||||
_sys.modules[_CONFIG_SENTINEL] = _sentinel_mod
|
||||
|
||||
@@ -445,5 +445,12 @@ class LoraManager:
|
||||
scanner.cancel_task()
|
||||
logger.debug("LoRA Manager: Cancelled %s", name)
|
||||
|
||||
# Close shared aiohttp sessions to avoid "Unclosed client session" warnings
|
||||
try:
|
||||
from py.routes.handlers.hf_handlers import close_hf_api_session
|
||||
await close_hf_api_session()
|
||||
except Exception as exc:
|
||||
logger.debug("Error closing HF API session: %s", exc)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during cleanup: {e}", exc_info=True)
|
||||
|
||||
167
py/routes/handlers/agent_handlers.py
Normal file
167
py/routes/handlers/agent_handlers.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""HTTP route handlers for agent skill endpoints.
|
||||
|
||||
These handlers expose the :class:`AgentService` via HTTP, allowing the
|
||||
frontend to list available skills and execute them on selected models.
|
||||
Progress is reported via WebSocket broadcast.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...services.agent import AgentService, AgentProgressReporter
|
||||
from ...services.llm_service import LLMNotConfiguredError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentHandler:
|
||||
"""HTTP handler for agent skill operations."""
|
||||
|
||||
def __init__(self, agent_service: AgentService | None = None) -> None:
|
||||
self._agent_service = agent_service
|
||||
|
||||
async def _ensure_service(self) -> AgentService:
|
||||
if self._agent_service is None:
|
||||
self._agent_service = await AgentService.get_instance()
|
||||
return self._agent_service
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# GET /api/lm/agent/skills
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def get_agent_skills(self, request: web.Request) -> web.Response:
|
||||
"""Return a list of available agent skills."""
|
||||
|
||||
service = await self._ensure_service()
|
||||
skills = await service.list_skills()
|
||||
return web.json_response({"skills": skills})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# POST /api/lm/agent/execute/{skill_name}
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def execute_agent_skill(self, request: web.Request) -> web.Response:
|
||||
"""Execute an agent skill on the provided model paths.
|
||||
|
||||
Request body::
|
||||
|
||||
{"model_paths": ["/path/to/model1.safetensors", ...], "options": {}}
|
||||
|
||||
Returns immediately with a task ID. Execution runs in the
|
||||
background; progress and completion are pushed via WebSocket
|
||||
events of type ``agent_progress``.
|
||||
"""
|
||||
|
||||
skill_name = request.match_info.get("skill_name", "")
|
||||
if not skill_name:
|
||||
return web.json_response(
|
||||
{"error": "Skill name is required"}, status_code=400
|
||||
)
|
||||
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return web.json_response(
|
||||
{"error": "Invalid JSON body"}, status_code=400
|
||||
)
|
||||
|
||||
model_paths = body.get("model_paths", [])
|
||||
if not model_paths or not isinstance(model_paths, list):
|
||||
return web.json_response(
|
||||
{"error": "model_paths must be a non-empty array"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
service = await self._ensure_service()
|
||||
|
||||
# Validate LLM configuration early for skills that need it
|
||||
# (fail fast rather than after starting background work)
|
||||
try:
|
||||
from ...services.llm_service import LLMService
|
||||
|
||||
llm = await LLMService.get_instance()
|
||||
if not llm.is_configured():
|
||||
return web.json_response(
|
||||
{
|
||||
"error": "LLM provider is not configured. "
|
||||
"Enable it in Settings → AI Provider.",
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to check LLM configuration: %s", exc)
|
||||
|
||||
# Launch execution in the background
|
||||
progress_reporter = AgentProgressReporter()
|
||||
logger.info(
|
||||
"Agent skill '%s' starting for %d model(s) in background task",
|
||||
skill_name, len(model_paths),
|
||||
)
|
||||
|
||||
async def _run() -> None:
|
||||
logger.info("_run background task started for skill '%s'", skill_name)
|
||||
try:
|
||||
result = await service.execute_skill(
|
||||
skill_name=skill_name,
|
||||
input_data={"model_paths": model_paths},
|
||||
progress_callback=progress_reporter,
|
||||
)
|
||||
logger.info(
|
||||
"Agent skill '%s' finished: success=%s, summary='%s', errors=%s",
|
||||
skill_name, result.success, result.summary, result.errors,
|
||||
)
|
||||
except LLMNotConfiguredError as exc:
|
||||
logger.warning("Agent skill '%s' not configured: %s", skill_name, exc)
|
||||
await progress_reporter.on_progress(
|
||||
{
|
||||
"type": "agent_progress",
|
||||
"skill": skill_name,
|
||||
"status": "error",
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Agent skill '%s' failed: %s", skill_name, exc, exc_info=True)
|
||||
await progress_reporter.on_progress(
|
||||
{
|
||||
"type": "agent_progress",
|
||||
"skill": skill_name,
|
||||
"status": "error",
|
||||
"error": str(exc),
|
||||
}
|
||||
)
|
||||
|
||||
# Fire and forget — progress comes via WebSocket
|
||||
task = asyncio.create_task(_run())
|
||||
logger.info("Agent skill '%s' background task created (id=%s)", skill_name, task)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"status": "started",
|
||||
"skill": skill_name,
|
||||
"model_count": len(model_paths),
|
||||
}
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# POST /api/lm/agent/cancel
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def cancel_agent_skill(self, request: web.Request) -> web.Response:
|
||||
"""Cancel a running agent skill.
|
||||
|
||||
NOTE: Cancellation is a stub for now — the AgentService processes
|
||||
models sequentially and does not yet support mid-execution
|
||||
cancellation. This endpoint exists for API completeness.
|
||||
"""
|
||||
|
||||
# TODO: implement cooperative cancellation in AgentService
|
||||
return web.json_response(
|
||||
{"status": "acknowledged", "note": "Cancellation not yet implemented"},
|
||||
status_code=200,
|
||||
)
|
||||
417
py/routes/handlers/hf_handlers.py
Normal file
417
py/routes/handlers/hf_handlers.py
Normal file
@@ -0,0 +1,417 @@
|
||||
"""Handlers for Hugging Face model listing and download.
|
||||
|
||||
Minimal MVP implementation — uses direct HTTP to the HF API for file
|
||||
listing and the project's existing aiohttp-based Downloader for
|
||||
downloading. No huggingface_hub dependency required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
|
||||
from ...config import config
|
||||
from ...services.downloader import (
|
||||
DownloadProgress,
|
||||
get_downloader,
|
||||
)
|
||||
from ...services.aria2_downloader import Aria2Downloader
|
||||
from ...services.settings_manager import get_settings_manager
|
||||
from ...services.service_registry import ServiceRegistry
|
||||
from ...services.websocket_manager import ws_manager
|
||||
from ...utils.constants import MODEL_FILE_EXTENSIONS
|
||||
from ...utils.metadata_manager import MetadataManager
|
||||
from ...utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_MODEL_CLASS = LoraMetadata
|
||||
_DEFAULT_SCANNER_GETTER = "get_lora_scanner"
|
||||
|
||||
# Shared aiohttp session for HF API calls (created on first use)
|
||||
_hf_api_session: aiohttp.ClientSession | None = None
|
||||
|
||||
|
||||
async def _get_hf_api_session() -> aiohttp.ClientSession:
|
||||
"""Get or create the shared aiohttp session for HF API calls."""
|
||||
global _hf_api_session # needed because we reassign the module-level name
|
||||
if _hf_api_session is None or _hf_api_session.closed:
|
||||
_hf_api_session = aiohttp.ClientSession(
|
||||
headers={"User-Agent": "ComfyUI-LoRA-Manager/1.0"},
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
)
|
||||
return _hf_api_session
|
||||
|
||||
|
||||
async def close_hf_api_session() -> None:
|
||||
"""Close the shared HF API session, if it was ever created."""
|
||||
global _hf_api_session
|
||||
if _hf_api_session is not None and not _hf_api_session.closed:
|
||||
await _hf_api_session.close()
|
||||
_hf_api_session = None
|
||||
|
||||
|
||||
def _infer_model_type(model_root: str) -> tuple[Any, str]:
|
||||
"""Determine model class and scanner by matching ``model_root`` against the
|
||||
configured root paths for each model type (from ``Config``).
|
||||
|
||||
The ``model_root`` value comes from the frontend's model-root dropdown,
|
||||
which is populated from the current page's scanner roots. By checking
|
||||
which scanner's root list it belongs to, we avoid fragile heuristics
|
||||
like substring-matching path names.
|
||||
"""
|
||||
norm = os.path.normpath(model_root).replace(os.sep, "/")
|
||||
|
||||
# LoRA roots
|
||||
for p in (config.loras_roots or []) + (config.extra_loras_roots or []):
|
||||
if os.path.normpath(p).replace(os.sep, "/") == norm:
|
||||
return LoraMetadata, "get_lora_scanner"
|
||||
|
||||
# Checkpoint / UNet roots
|
||||
for p in (
|
||||
(config.checkpoints_roots or [])
|
||||
+ (config.extra_checkpoints_roots or [])
|
||||
+ (config.unet_roots or [])
|
||||
+ (config.extra_unet_roots or [])
|
||||
):
|
||||
if os.path.normpath(p).replace(os.sep, "/") == norm:
|
||||
return CheckpointMetadata, "get_checkpoint_scanner"
|
||||
|
||||
# Embedding roots
|
||||
for p in (config.embeddings_roots or []) + (config.extra_embeddings_roots or []):
|
||||
if os.path.normpath(p).replace(os.sep, "/") == norm:
|
||||
return EmbeddingMetadata, "get_embedding_scanner"
|
||||
|
||||
# Fallback — should not happen in normal use
|
||||
logger.warning(
|
||||
"Could not determine model type for root '%s'; defaulting to LoRA",
|
||||
model_root,
|
||||
)
|
||||
return _DEFAULT_MODEL_CLASS, _DEFAULT_SCANNER_GETTER
|
||||
|
||||
|
||||
async def _save_hf_metadata(dest_path: str, repo: str, model_root: str) -> None:
|
||||
"""Create a proper .metadata.json and add the model to the scanner cache.
|
||||
|
||||
Uses ``MetadataManager.create_default_metadata()`` which computes the
|
||||
SHA256 hash, extracts safetensors header metadata (base_model), and
|
||||
produces a fully-populated ``LoraMetadata`` (or ``CheckpointMetadata`` /
|
||||
``EmbeddingMetadata``) object. We then overlay HF-specific fields and
|
||||
register the model in the in-memory scanner cache so it appears
|
||||
immediately without a full filesystem walk.
|
||||
"""
|
||||
try:
|
||||
hf_url = f"https://huggingface.co/{repo}"
|
||||
model_class, scanner_getter_name = _infer_model_type(model_root)
|
||||
|
||||
# 1. Create proper metadata (computes SHA256, reads safetensors headers)
|
||||
metadata = await MetadataManager.create_default_metadata(
|
||||
dest_path, model_class=model_class
|
||||
)
|
||||
if metadata is None:
|
||||
logger.warning("create_default_metadata returned None for %s", dest_path)
|
||||
return
|
||||
|
||||
# 2. Overlay HF-specific fields
|
||||
metadata._unknown_fields["hf_url"] = hf_url
|
||||
metadata.from_civitai = False # HF models are not from CivitAI
|
||||
|
||||
# 3. Save metadata atomically
|
||||
await MetadataManager.save_metadata(dest_path, metadata)
|
||||
logger.info("Saved HF metadata (with hf_url) for %s", dest_path)
|
||||
|
||||
# 4. Determine relative folder path for cache
|
||||
# model_root is an absolute path; dest_path is under it
|
||||
folder = ""
|
||||
if os.path.isabs(model_root) and dest_path.startswith(model_root):
|
||||
rel = os.path.relpath(os.path.dirname(dest_path), model_root)
|
||||
folder = rel.replace(os.sep, "/") if rel != "." else ""
|
||||
|
||||
# 5. Add to scanner cache (same as CivitAI's _execute_download does)
|
||||
scanner_getter = getattr(ServiceRegistry, scanner_getter_name, None)
|
||||
if scanner_getter is not None:
|
||||
scanner = await scanner_getter()
|
||||
if scanner is not None:
|
||||
metadata_dict = metadata.to_dict()
|
||||
metadata_dict["hf_url"] = hf_url
|
||||
await scanner.add_model_to_cache(metadata_dict, folder)
|
||||
logger.info("Added %s to scanner cache (folder=%s)", dest_path, folder)
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to save HF metadata for %s: %s", dest_path, exc)
|
||||
|
||||
|
||||
class HfHandler:
|
||||
"""Handle Hugging Face model browsing and download."""
|
||||
|
||||
async def get_hf_repo_files(self, request: web.Request) -> web.Response:
|
||||
"""List model-weight files from a HF repo with real file sizes.
|
||||
|
||||
Uses the HF tree API endpoint which returns accurate file sizes
|
||||
(including LFS-tracked files), unlike the model info endpoint.
|
||||
"""
|
||||
repo = request.query.get("repo", "").strip()
|
||||
if not repo or "/" not in repo:
|
||||
return web.json_response(
|
||||
{"error": "Missing or invalid 'repo' parameter (expected user/repo)"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
url = f"https://huggingface.co/api/models/{repo}/tree/main"
|
||||
|
||||
try:
|
||||
session = await _get_hf_api_session()
|
||||
async with session.get(url) as resp:
|
||||
if resp.status == 404:
|
||||
return web.json_response(
|
||||
{"error": f"Repo '{repo}' not found"}, status=404
|
||||
)
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
return web.json_response(
|
||||
{"error": f"HF API error {resp.status}: {text[:200]}"},
|
||||
status=resp.status,
|
||||
)
|
||||
tree: list[dict[str, Any]] = await resp.json()
|
||||
except Exception as exc:
|
||||
logger.error("Failed to fetch HF repo files: %s", exc)
|
||||
return web.json_response({"error": str(exc)}, status=502)
|
||||
|
||||
files: list[dict[str, Any]] = []
|
||||
for entry in tree:
|
||||
path: str = entry.get("path", "")
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
if ext not in MODEL_FILE_EXTENSIONS:
|
||||
continue
|
||||
size = entry.get("size", 0) or 0
|
||||
if size == 0 and "lfs" in entry:
|
||||
size = entry["lfs"].get("size", 0) or 0
|
||||
files.append({
|
||||
"filename": path,
|
||||
"size": size,
|
||||
})
|
||||
|
||||
files.sort(key=lambda f: f["size"], reverse=True)
|
||||
return web.json_response(files)
|
||||
|
||||
async def download_hf_model(self, request: web.Request) -> web.Response:
|
||||
"""Download a single file from Hugging Face into the model directory.
|
||||
|
||||
POST JSON body::
|
||||
|
||||
{
|
||||
"repo": "dx8152/Flux2-Klein-9B-Consistency",
|
||||
"filename": "Flux2-Klein-9B-consistency-V2.safetensors",
|
||||
"revision": "main",
|
||||
"model_root": "loras",
|
||||
"relative_path": "",
|
||||
"use_default_paths": false,
|
||||
"download_id": "optional-batch-id"
|
||||
}
|
||||
|
||||
If ``download_id`` is provided, real-time progress (bytes, speed,
|
||||
percentage) is broadcast via the WebSocket progress system, matching
|
||||
the CivitAI download experience.
|
||||
|
||||
Respects the ``download_backend`` setting (``aria2`` or ``default``).
|
||||
"""
|
||||
try:
|
||||
payload: dict[str, Any] = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
return web.json_response({"error": "Invalid JSON"}, status=400)
|
||||
|
||||
repo = (payload.get("repo") or "").strip()
|
||||
filename = (payload.get("filename") or "").strip()
|
||||
revision = (payload.get("revision") or "main").strip()
|
||||
model_root = (payload.get("model_root") or "").strip()
|
||||
relative_path = (payload.get("relative_path") or "").strip()
|
||||
use_default_paths = bool(payload.get("use_default_paths", False))
|
||||
download_id: str | None = payload.get("download_id")
|
||||
|
||||
logger.info(
|
||||
"download_hf_model: repo=%s file=%s root=%s download_id=%s",
|
||||
repo, filename, model_root, download_id,
|
||||
)
|
||||
|
||||
if not repo or not filename:
|
||||
return web.json_response(
|
||||
{"error": "Missing required fields: 'repo' and 'filename'"}, status=400
|
||||
)
|
||||
|
||||
# Validate repo format — must be user/repo_name
|
||||
if repo.count("/") != 1 or not re.match(r"^[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+$", repo):
|
||||
return web.json_response({"error": f"Invalid repo format: {repo}"}, status=400)
|
||||
author, repo_name = repo.split("/", 1)
|
||||
if ".." in (author, repo_name) or "." in (author, repo_name):
|
||||
return web.json_response({"error": f"Invalid repo format: {repo}"}, status=400)
|
||||
|
||||
# Validate filename — must not contain path separators or ..
|
||||
if "/" in filename or "\\" in filename or ".." in filename:
|
||||
return web.json_response({"error": "Invalid filename"}, status=400)
|
||||
|
||||
# Validate relative_path — must not be absolute or escape base directory
|
||||
if relative_path:
|
||||
if os.path.isabs(relative_path):
|
||||
return web.json_response({"error": "relative_path must not be absolute"}, status=400)
|
||||
if ".." in relative_path.split("/") or "\\" in relative_path:
|
||||
return web.json_response({"error": "Invalid relative_path"}, status=400)
|
||||
|
||||
# Validate model_root — must not contain path traversal
|
||||
if not os.path.isabs(model_root):
|
||||
# For relative model_root, check it doesn't escape
|
||||
resolved_model_root = os.path.realpath(
|
||||
os.path.join(os.getcwd(), "models", model_root)
|
||||
)
|
||||
else:
|
||||
resolved_model_root = os.path.realpath(model_root)
|
||||
|
||||
# Verify model_root is within a configured scanner root
|
||||
allowed_roots = set()
|
||||
for root_list in (
|
||||
config.loras_roots or [],
|
||||
config.extra_loras_roots or [],
|
||||
config.checkpoints_roots or [],
|
||||
config.extra_checkpoints_roots or [],
|
||||
config.unet_roots or [],
|
||||
config.extra_unet_roots or [],
|
||||
config.embeddings_roots or [],
|
||||
config.extra_embeddings_roots or [],
|
||||
):
|
||||
for r in root_list:
|
||||
allowed_roots.add(os.path.realpath(r))
|
||||
|
||||
if not any(resolved_model_root == root or resolved_model_root.startswith(root + os.sep) for root in allowed_roots):
|
||||
logger.warning("Invalid model_root rejected: %s", model_root)
|
||||
return web.json_response({"error": f"Invalid model_root: {model_root}"}, status=400)
|
||||
|
||||
base_dir = resolved_model_root
|
||||
|
||||
if use_default_paths:
|
||||
target_dir = os.path.join(base_dir, "huggingface", author, repo_name)
|
||||
elif relative_path:
|
||||
target_dir = os.path.join(base_dir, relative_path)
|
||||
else:
|
||||
target_dir = base_dir
|
||||
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
dest_path = os.path.join(target_dir, filename)
|
||||
|
||||
# Resolve symlinks and check for path traversal escape
|
||||
real_dest = os.path.realpath(dest_path)
|
||||
real_base = os.path.realpath(target_dir)
|
||||
if not real_dest.startswith(real_base + os.sep):
|
||||
logger.warning("Path traversal blocked: %s -> %s", dest_path, real_dest)
|
||||
return web.json_response({"error": "Path traversal detected"}, status=400)
|
||||
|
||||
# Check if already exists (simple skip)
|
||||
if os.path.exists(dest_path) and os.path.getsize(dest_path) > 0:
|
||||
logger.info("download_hf_model: file already exists, skipping — %s", dest_path)
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"File already exists: {dest_path}",
|
||||
"path": dest_path,
|
||||
})
|
||||
|
||||
# Build HF resolve URL
|
||||
resolve_url = (
|
||||
f"https://huggingface.co/{repo}/resolve/{revision}/{filename}"
|
||||
)
|
||||
|
||||
# Set up progress callback if download_id is provided
|
||||
progress_callback = None
|
||||
if download_id:
|
||||
|
||||
async def _progress_callback(
|
||||
progress: float | DownloadProgress,
|
||||
snapshot: DownloadProgress | None = None,
|
||||
) -> None:
|
||||
percent = 0.0
|
||||
metrics = snapshot if isinstance(snapshot, DownloadProgress) else None
|
||||
|
||||
if isinstance(progress, DownloadProgress):
|
||||
percent = progress.percent_complete
|
||||
metrics = progress
|
||||
elif isinstance(snapshot, DownloadProgress):
|
||||
percent = snapshot.percent_complete
|
||||
else:
|
||||
percent = float(progress)
|
||||
|
||||
broadcast: dict[str, Any] = {
|
||||
"status": "progress",
|
||||
"progress": round(percent),
|
||||
}
|
||||
if metrics:
|
||||
broadcast["bytes_downloaded"] = metrics.bytes_downloaded
|
||||
broadcast["total_bytes"] = metrics.total_bytes
|
||||
broadcast["bytes_per_second"] = metrics.bytes_per_second
|
||||
|
||||
await ws_manager.broadcast_download_progress(download_id, broadcast)
|
||||
|
||||
progress_callback = _progress_callback
|
||||
|
||||
# Respect download backend setting (aria2 vs default)
|
||||
download_backend = (
|
||||
get_settings_manager().get("download_backend", "default")
|
||||
)
|
||||
|
||||
if download_backend == "aria2":
|
||||
aria2 = await Aria2Downloader.get_instance()
|
||||
aid = download_id or f"hf_{repo}_{filename}"
|
||||
try:
|
||||
hf_success, hf_result = await aria2.download_file(
|
||||
url=resolve_url,
|
||||
save_path=dest_path,
|
||||
download_id=aid,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
if hf_success:
|
||||
await _save_hf_metadata(dest_path, repo, model_root)
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Downloaded to {dest_path}",
|
||||
"path": dest_path,
|
||||
})
|
||||
else:
|
||||
return web.json_response(
|
||||
{"success": False, "error": hf_result or "aria2 download failed"},
|
||||
status=500,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("HF download (aria2) failed: %s", exc)
|
||||
return web.json_response(
|
||||
{"success": False, "error": str(exc)}, status=500
|
||||
)
|
||||
|
||||
# Default: use built-in aiohttp Downloader
|
||||
downloader = await get_downloader()
|
||||
try:
|
||||
success, result = await downloader.download_file(
|
||||
url=resolve_url,
|
||||
save_path=dest_path,
|
||||
use_auth=False,
|
||||
allow_resume=True,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
if success:
|
||||
await _save_hf_metadata(dest_path, repo, model_root)
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Downloaded to {result}",
|
||||
"path": result,
|
||||
})
|
||||
else:
|
||||
return web.json_response(
|
||||
{"success": False, "error": result or "Download failed"},
|
||||
status=500,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("HF download failed: %s", exc)
|
||||
return web.json_response(
|
||||
{"success": False, "error": str(exc)}, status=500
|
||||
)
|
||||
@@ -48,6 +48,8 @@ from ...utils.constants import (
|
||||
SUPPORTED_MEDIA_EXTENSIONS,
|
||||
VALID_LORA_TYPES,
|
||||
)
|
||||
from .hf_handlers import HfHandler
|
||||
from .agent_handlers import AgentHandler
|
||||
from ...utils.civitai_utils import rewrite_preview_url
|
||||
from ...utils.example_images_paths import (
|
||||
find_non_compliant_items_in_example_images_root,
|
||||
@@ -3315,6 +3317,8 @@ class MiscHandlerSet:
|
||||
doctor: DoctorHandler,
|
||||
example_workflows: ExampleWorkflowsHandler,
|
||||
base_model: BaseModelHandlerSet,
|
||||
hf_handler: HfHandler | None = None,
|
||||
agent_handler: AgentHandler | None = None,
|
||||
) -> None:
|
||||
self.health = health
|
||||
self.settings = settings
|
||||
@@ -3333,6 +3337,8 @@ class MiscHandlerSet:
|
||||
self.doctor = doctor
|
||||
self.example_workflows = example_workflows
|
||||
self.base_model = base_model
|
||||
self.hf_handler = hf_handler
|
||||
self.agent_handler = agent_handler
|
||||
|
||||
def to_route_mapping(
|
||||
self,
|
||||
@@ -3378,6 +3384,13 @@ class MiscHandlerSet:
|
||||
"get_supporters": self.supporters.get_supporters,
|
||||
"get_example_workflows": self.example_workflows.get_example_workflows,
|
||||
"get_example_workflow": self.example_workflows.get_example_workflow,
|
||||
# Hugging Face handlers
|
||||
"get_hf_repo_files": self.hf_handler.get_hf_repo_files,
|
||||
"download_hf_model": self.hf_handler.download_hf_model,
|
||||
# Agent skill handlers
|
||||
"get_agent_skills": self.agent_handler.get_agent_skills,
|
||||
"execute_agent_skill": self.agent_handler.execute_agent_skill,
|
||||
"cancel_agent_skill": self.agent_handler.cancel_agent_skill,
|
||||
# Base model handlers
|
||||
"get_base_models": self.base_model.get_base_models,
|
||||
"refresh_base_models": self.base_model.refresh_base_models,
|
||||
|
||||
@@ -203,11 +203,17 @@ class ModelListingHandler:
|
||||
result = await self._service.get_paginated_data(**params)
|
||||
|
||||
format_start = time.perf_counter()
|
||||
formatted_raw = [
|
||||
await self._service.format_response(entry)
|
||||
for entry in result["items"]
|
||||
]
|
||||
# Filter out None entries returned for corrupted cache rows (issue #730).
|
||||
# Note: "total" intentionally remains the pre-filter count to reflect
|
||||
# the true number of models in the cache; corrupted entries are rare
|
||||
# and adjusting total would cause pagination drift on every page.
|
||||
formatted_items = [item for item in formatted_raw if item is not None]
|
||||
formatted_result = {
|
||||
"items": [
|
||||
await self._service.format_response(item)
|
||||
for item in result["items"]
|
||||
],
|
||||
"items": formatted_items,
|
||||
"total": result["total"],
|
||||
"page": result["page"],
|
||||
"page_size": result["page_size"],
|
||||
@@ -238,11 +244,15 @@ class ModelListingHandler:
|
||||
result = await self._service.get_excluded_paginated_data(**params)
|
||||
|
||||
format_start = time.perf_counter()
|
||||
formatted_raw = [
|
||||
await self._service.format_response(entry)
|
||||
for entry in result["items"]
|
||||
]
|
||||
# Filter out None entries returned for corrupted cache rows (issue #730).
|
||||
# "total" stays at the pre-filter count; see get_models for rationale.
|
||||
formatted_items = [item for item in formatted_raw if item is not None]
|
||||
formatted_result = {
|
||||
"items": [
|
||||
await self._service.format_response(item)
|
||||
for item in result["items"]
|
||||
],
|
||||
"items": formatted_items,
|
||||
"total": result["total"],
|
||||
"page": result["page"],
|
||||
"page_size": result["page_size"],
|
||||
@@ -533,8 +543,13 @@ class ModelManagementHandler:
|
||||
if not success:
|
||||
return web.json_response({"success": False, "error": error})
|
||||
|
||||
formatted_metadata = await self._service.format_response(model_data)
|
||||
return web.json_response({"success": True, "metadata": formatted_metadata})
|
||||
formatted = await self._service.format_response(model_data)
|
||||
if formatted is None:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Model entry is corrupted (missing file_path)"},
|
||||
status=500,
|
||||
)
|
||||
return web.json_response({"success": True, "metadata": formatted})
|
||||
except Exception as exc:
|
||||
if is_expected_offline_error(str(exc)):
|
||||
return web.json_response(
|
||||
@@ -1091,10 +1106,12 @@ class ModelQueryHandler:
|
||||
# Sort: originals first, copies last
|
||||
sorted_models = self._sort_duplicate_group(filtered)
|
||||
|
||||
# Format response
|
||||
# Format response, filtering out corrupted entries (issue #730)
|
||||
group = {"hash": sha256, "models": []}
|
||||
for model in sorted_models:
|
||||
group["models"].append(await self._service.format_response(model))
|
||||
formatted = await self._service.format_response(model)
|
||||
if formatted is not None:
|
||||
group["models"].append(formatted)
|
||||
|
||||
# Only include groups with 2+ models after filtering
|
||||
if len(group["models"]) > 1:
|
||||
@@ -1211,9 +1228,9 @@ class ModelQueryHandler:
|
||||
(m for m in cache.raw_data if m["file_path"] == path), None
|
||||
)
|
||||
if model:
|
||||
group["models"].append(
|
||||
await self._service.format_response(model)
|
||||
)
|
||||
formatted = await self._service.format_response(model)
|
||||
if formatted is not None:
|
||||
group["models"].append(formatted)
|
||||
hash_val = self._service.scanner.get_hash_by_filename(filename)
|
||||
if hash_val:
|
||||
main_path = self._service.get_path_by_hash(hash_val)
|
||||
@@ -1223,9 +1240,9 @@ class ModelQueryHandler:
|
||||
None,
|
||||
)
|
||||
if main_model:
|
||||
group["models"].insert(
|
||||
0, await self._service.format_response(main_model)
|
||||
)
|
||||
formatted = await self._service.format_response(main_model)
|
||||
if formatted is not None:
|
||||
group["models"].insert(0, formatted)
|
||||
if group["models"]:
|
||||
result.append(group)
|
||||
return web.json_response(
|
||||
|
||||
@@ -94,6 +94,23 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/delete-model-version", "delete_model_version"
|
||||
),
|
||||
# Hugging Face model endpoints
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/hf-repo-files", "get_hf_repo_files"
|
||||
),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/download-hf-model", "download_hf_model"
|
||||
),
|
||||
# Agent skill endpoints
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/agent/skills", "get_agent_skills"
|
||||
),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/agent/execute/{skill_name}", "execute_agent_skill"
|
||||
),
|
||||
RouteDefinition(
|
||||
"POST", "/api/lm/agent/cancel", "cancel_agent_skill"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -39,6 +39,8 @@ from .handlers.misc_handlers import (
|
||||
build_service_registry_adapter,
|
||||
)
|
||||
from .handlers.base_model_handlers import BaseModelHandlerSet
|
||||
from .handlers.hf_handlers import HfHandler
|
||||
from .handlers.agent_handlers import AgentHandler
|
||||
from .misc_route_registrar import MiscRouteRegistrar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -136,6 +138,8 @@ class MiscRoutes:
|
||||
doctor = DoctorHandler(settings_service=self._settings)
|
||||
example_workflows = ExampleWorkflowsHandler()
|
||||
base_model = BaseModelHandlerSet()
|
||||
hf_handler = HfHandler()
|
||||
agent_handler = AgentHandler()
|
||||
|
||||
return self._handler_set_factory(
|
||||
health=health,
|
||||
@@ -155,6 +159,8 @@ class MiscRoutes:
|
||||
doctor=doctor,
|
||||
example_workflows=example_workflows,
|
||||
base_model=base_model,
|
||||
hf_handler=hf_handler,
|
||||
agent_handler=agent_handler,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -16,6 +16,27 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
NETWORK_EXCEPTIONS = (ClientError, OSError, asyncio.TimeoutError)
|
||||
|
||||
# User-managed directories that live inside the plugin folder (portable
|
||||
# mode) and must survive a Git-based update. ``git clean -fd`` would
|
||||
# otherwise delete them because they are untracked and, in released tags,
|
||||
# not listed in ``.gitignore``. ``-e`` excludes a path from cleaning
|
||||
# regardless of whether it is ignored.
|
||||
_PRESERVE_DIRS = ('settings.json', 'civitai', 'wildcards', 'backups', 'stats', 'logs', 'cache', 'model_cache')
|
||||
|
||||
|
||||
def _clean_excludes() -> List[str]:
|
||||
"""Build the ``-e`` arguments for ``git clean`` from :data:`_PRESERVE_DIRS`."""
|
||||
excludes: List[str] = []
|
||||
for name in _PRESERVE_DIRS:
|
||||
excludes.append('-e')
|
||||
excludes.append(name)
|
||||
# For directories, also exclude nested matches explicitly
|
||||
# (``-e dir`` alone matches the dir entry; ``-e dir/**`` guards
|
||||
# contents under all git versions as defense-in-depth).
|
||||
excludes.append('-e')
|
||||
excludes.append(f'{name}/**')
|
||||
return excludes
|
||||
|
||||
|
||||
class UpdateRoutes:
|
||||
"""Routes for handling plugin update checks"""
|
||||
@@ -365,6 +386,8 @@ class UpdateRoutes:
|
||||
)
|
||||
return False, ""
|
||||
|
||||
clean_excludes = _clean_excludes()
|
||||
|
||||
try:
|
||||
# Open the Git repository
|
||||
repo = git.Repo(plugin_root)
|
||||
@@ -376,8 +399,9 @@ class UpdateRoutes:
|
||||
if nightly:
|
||||
# Reset to discard any local changes
|
||||
repo.git.reset('--hard')
|
||||
# Clean untracked files
|
||||
repo.git.clean('-fd')
|
||||
# Clean untracked files, but preserve user-managed directories
|
||||
# (wildcards, backups, stats, civitai, caches, settings.json).
|
||||
repo.git.clean('-fd', *clean_excludes)
|
||||
|
||||
# Switch to main branch and pull latest
|
||||
main_branch = 'main'
|
||||
@@ -394,8 +418,9 @@ class UpdateRoutes:
|
||||
else:
|
||||
# Reset to discard any local changes
|
||||
repo.git.reset('--hard')
|
||||
# Clean untracked files
|
||||
repo.git.clean('-fd')
|
||||
# Clean untracked files, but preserve user-managed directories
|
||||
# (wildcards, backups, stats, civitai, caches, settings.json).
|
||||
repo.git.clean('-fd', *clean_excludes)
|
||||
|
||||
# Get latest release tag
|
||||
tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True)
|
||||
|
||||
23
py/services/agent/__init__.py
Normal file
23
py/services/agent/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Agent-powered skill system for LoRA Manager.
|
||||
|
||||
This package provides the orchestration layer for LLM/agent-powered features.
|
||||
Skills define *what* to do (prompt template). The :class:`AgentService`
|
||||
handles *how* (LLM calls, context gathering, validation, progress).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .skill_definition import SkillDefinition, SkillPermissions
|
||||
from .skill_registry import SkillRegistry
|
||||
from .agent_service import AgentService, AgentProgressReporter, SkillResult
|
||||
from .post_processor import PostProcessor
|
||||
|
||||
__all__ = [
|
||||
"AgentProgressReporter",
|
||||
"AgentService",
|
||||
"PostProcessor",
|
||||
"SkillDefinition",
|
||||
"SkillPermissions",
|
||||
"SkillRegistry",
|
||||
"SkillResult",
|
||||
]
|
||||
413
py/services/agent/agent_service.py
Normal file
413
py/services/agent/agent_service.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""Agent orchestration service.
|
||||
|
||||
The :class:`AgentService` coordinates skill execution:
|
||||
|
||||
1. Look up the skill in :class:`SkillRegistry`
|
||||
2. Validate input against the skill's ``input_schema``
|
||||
3. Prepare context via :mod:`~py.agent_cli` (read metadata, list base models, fetch HF README)
|
||||
4. If ``llm_required``: call :class:`LLMService` with the rendered prompt
|
||||
5. Post-process via :class:`PostProcessor` (delegates I/O to :mod:`~py.agent_cli`)
|
||||
6. Broadcast progress and completion via :class:`WebSocketManager`
|
||||
|
||||
Skills define *what* to do (prompt template). The AgentService handles *how*
|
||||
(LLM calls, context gathering, validation, progress).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from ..llm_service import LLMService
|
||||
from ..websocket_manager import ws_manager
|
||||
from .post_processor import PostProcessor
|
||||
from .skill_registry import SkillRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentProgressReporter:
|
||||
"""Protocol-compatible progress reporter backed by WebSocket broadcast."""
|
||||
|
||||
async def on_progress(self, payload: Dict[str, Any]) -> None:
|
||||
await ws_manager.broadcast(payload)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillResult:
|
||||
"""Outcome of a skill execution."""
|
||||
|
||||
success: bool
|
||||
updated_models: List[Dict[str, Any]] = field(default_factory=list)
|
||||
errors: List[str] = field(default_factory=list)
|
||||
summary: str = ""
|
||||
|
||||
|
||||
def _validate_schema(data: Any, schema: Dict[str, Any], path: str = "") -> List[str]:
|
||||
"""Minimal JSON schema validator.
|
||||
|
||||
Supports a subset of JSON Schema: ``type``, ``properties``, ``required``,
|
||||
``items``, ``enum``. Returns a list of error messages (empty = valid).
|
||||
"""
|
||||
|
||||
errors: List[str] = []
|
||||
if not schema:
|
||||
return errors
|
||||
|
||||
expected_type = schema.get("type")
|
||||
if expected_type:
|
||||
type_map = {
|
||||
"string": str,
|
||||
"number": (int, float),
|
||||
"integer": int,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
"null": type(None),
|
||||
}
|
||||
expected_py = type_map.get(expected_type)
|
||||
if expected_py is not None and not isinstance(data, expected_py):
|
||||
errors.append(f"{path or 'root'}: expected {expected_type}, got {type(data).__name__}")
|
||||
return errors
|
||||
|
||||
if expected_type == "object" and isinstance(data, dict):
|
||||
properties = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
for req_key in required:
|
||||
if req_key not in data:
|
||||
errors.append(f"{path or 'root'}: missing required property '{req_key}'")
|
||||
for key, value in data.items():
|
||||
if key in properties:
|
||||
errors.extend(_validate_schema(value, properties[key], f"{path}.{key}"))
|
||||
|
||||
if expected_type == "array" and isinstance(data, list):
|
||||
items_schema = schema.get("items")
|
||||
if items_schema:
|
||||
for i, item in enumerate(data):
|
||||
errors.extend(_validate_schema(item, items_schema, f"{path}[{i}]"))
|
||||
|
||||
if "enum" in schema and data not in schema["enum"]:
|
||||
errors.append(f"{path or 'root'}: value '{data}' not in enum {schema['enum']}")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt template rendering
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _render_prompt(template: str, variables: Dict[str, Any]) -> str:
|
||||
"""Render a prompt template with ``{{variable}}`` placeholders.
|
||||
|
||||
Uses simple regex substitution — no Jinja2 dependency needed.
|
||||
"""
|
||||
|
||||
def replace(match: re.Match) -> str:
|
||||
key = match.group(1).strip()
|
||||
value = variables.get(key, "")
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value, ensure_ascii=False, indent=2)
|
||||
return str(value)
|
||||
|
||||
return re.sub(r"\{\{(\w+)\}\}", replace, template)
|
||||
|
||||
|
||||
class AgentService:
|
||||
"""Orchestrate agent skill execution.
|
||||
|
||||
Usage::
|
||||
|
||||
service = await AgentService.get_instance()
|
||||
result = await service.execute_skill(
|
||||
skill_name="enrich_hf_metadata",
|
||||
input_data={"model_paths": ["/path/to/model.safetensors"]},
|
||||
progress_callback=AgentProgressReporter(),
|
||||
)
|
||||
"""
|
||||
|
||||
_instance: Optional["AgentService"] = None
|
||||
_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
skill_registry: Optional[SkillRegistry] = None,
|
||||
llm_service: Optional[LLMService] = None,
|
||||
) -> None:
|
||||
self._registry = skill_registry
|
||||
self._llm_service = llm_service
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> "AgentService":
|
||||
"""Return the lazily-initialised global ``AgentService``."""
|
||||
|
||||
if cls._instance is None:
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls(
|
||||
skill_registry=await SkillRegistry.get_instance(),
|
||||
llm_service=await LLMService.get_instance(),
|
||||
)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""Reset the cached singleton — primarily for tests."""
|
||||
|
||||
cls._instance = None
|
||||
|
||||
async def _ensure_registry(self) -> SkillRegistry:
|
||||
if self._registry is None:
|
||||
self._registry = await SkillRegistry.get_instance()
|
||||
return self._registry
|
||||
|
||||
async def _ensure_llm(self) -> LLMService:
|
||||
if self._llm_service is None:
|
||||
self._llm_service = await LLMService.get_instance()
|
||||
return self._llm_service
|
||||
|
||||
async def list_skills(self) -> List[Dict[str, Any]]:
|
||||
"""Return a JSON-serialisable list of available skills."""
|
||||
|
||||
registry = await self._ensure_registry()
|
||||
return [
|
||||
{
|
||||
"name": s.name,
|
||||
"title": s.title,
|
||||
"description": s.description,
|
||||
"llm_required": s.llm_required,
|
||||
"model_type_filter": s.model_type_filter,
|
||||
}
|
||||
for s in registry.list_skills()
|
||||
]
|
||||
|
||||
async def execute_skill(
|
||||
self,
|
||||
*,
|
||||
skill_name: str,
|
||||
input_data: Dict[str, Any],
|
||||
progress_callback: Optional[AgentProgressReporter] = None,
|
||||
) -> SkillResult:
|
||||
"""Execute an agent skill.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill to execute
|
||||
input_data: Input validated against the skill's ``input_schema``
|
||||
progress_callback: Optional WebSocket progress reporter
|
||||
|
||||
Returns:
|
||||
:class:`SkillResult` with success status and updated model info
|
||||
"""
|
||||
|
||||
registry = await self._ensure_registry()
|
||||
logger.info("execute_skill '%s': looking up skill", skill_name)
|
||||
skill = registry.get_skill(skill_name)
|
||||
if skill is None:
|
||||
return SkillResult(
|
||||
success=False,
|
||||
errors=[f"Skill not found: {skill_name}"],
|
||||
summary=f"Skill '{skill_name}' does not exist",
|
||||
)
|
||||
|
||||
input_errors = _validate_schema(input_data, skill.input_schema)
|
||||
if input_errors:
|
||||
return SkillResult(
|
||||
success=False,
|
||||
errors=input_errors,
|
||||
summary=f"Invalid input: {'; '.join(input_errors)}",
|
||||
)
|
||||
|
||||
model_paths = input_data.get("model_paths", [])
|
||||
if not model_paths:
|
||||
return SkillResult(
|
||||
success=False,
|
||||
errors=["No model_paths provided"],
|
||||
summary="No models to process",
|
||||
)
|
||||
|
||||
total = len(model_paths)
|
||||
processed = 0
|
||||
success_count = 0
|
||||
updated_models: List[Dict[str, Any]] = []
|
||||
errors: List[str] = []
|
||||
post_processor = PostProcessor()
|
||||
|
||||
logger.info("execute_skill '%s': starting with %d model(s)", skill_name, total)
|
||||
await self._emit_progress(
|
||||
progress_callback, skill_name, status="started",
|
||||
total=total, processed=0, success=0,
|
||||
)
|
||||
|
||||
llm = await self._ensure_llm()
|
||||
llm_configured = llm.is_configured() if skill.llm_required else True
|
||||
|
||||
for model_path in model_paths:
|
||||
logger.info(
|
||||
"execute_skill '%s': processing model %d/%d: %s",
|
||||
skill_name, processed + 1, total, model_path,
|
||||
)
|
||||
try:
|
||||
from ...agent_cli import read_metadata
|
||||
metadata = await read_metadata(model_path)
|
||||
|
||||
prompt_vars: Dict[str, Any] = {"model_path": model_path}
|
||||
if skill.llm_required and llm_configured:
|
||||
prompt_vars = await self._build_prompt_context(
|
||||
skill_name, model_path, metadata, registry, llm,
|
||||
)
|
||||
|
||||
llm_response: Optional[Dict[str, Any]] = None
|
||||
if skill.llm_required and llm_configured:
|
||||
prompt_template = registry.load_prompt(skill_name)
|
||||
rendered = _render_prompt(prompt_template, prompt_vars)
|
||||
logger.info(
|
||||
"execute_skill '%s': LLM call for %s (prompt=%d chars)",
|
||||
skill_name, model_path, len(rendered),
|
||||
)
|
||||
llm_response = await llm.chat_completion_json(
|
||||
system_prompt=prompt_vars.get(
|
||||
"system_prompt",
|
||||
"You are a helpful assistant that extracts structured metadata.",
|
||||
),
|
||||
user_prompt=rendered,
|
||||
)
|
||||
|
||||
model_result = await post_processor.process(
|
||||
skill_name=skill_name,
|
||||
model_path=model_path,
|
||||
llm_output=llm_response or {},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
if model_result.get("success", True):
|
||||
success_count += 1
|
||||
uf = model_result.get("updated_fields", [])
|
||||
if uf:
|
||||
updated_models.append({"path": model_path, "updated_fields": uf})
|
||||
else:
|
||||
errors.extend(
|
||||
model_result.get("errors", [model_result.get("error", "Unknown error")])
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Skill %s failed for %s: %s", skill_name, model_path, exc)
|
||||
errors.append(f"{model_path}: {exc}")
|
||||
|
||||
processed += 1
|
||||
await self._emit_progress(
|
||||
progress_callback, skill_name, status="processing",
|
||||
total=total, processed=processed, success=success_count,
|
||||
current_path=model_path,
|
||||
)
|
||||
|
||||
result = SkillResult(
|
||||
success=success_count > 0,
|
||||
updated_models=updated_models,
|
||||
errors=errors,
|
||||
summary=f"Processed {processed}/{total} models, {success_count} succeeded",
|
||||
)
|
||||
|
||||
logger.info("execute_skill '%s': done — %s", skill_name, result.summary)
|
||||
await self._emit_progress(
|
||||
progress_callback, skill_name, status="completed",
|
||||
total=total, processed=processed, success=success_count,
|
||||
updated_models=updated_models, errors=errors, summary=result.summary,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _build_prompt_context(
|
||||
self,
|
||||
skill_name: str,
|
||||
model_path: str,
|
||||
metadata: Dict[str, Any],
|
||||
registry: SkillRegistry,
|
||||
llm: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""Gather variables for the skill's prompt template.
|
||||
|
||||
Reads metadata, fetches the HF README (if applicable), lists available
|
||||
base models, and returns a dict that maps to ``{{variable}}``
|
||||
placeholders in ``prompt.md``.
|
||||
"""
|
||||
from ...agent_cli import list_base_models
|
||||
|
||||
context: Dict[str, Any] = {
|
||||
"model_path": model_path,
|
||||
"hf_url": "",
|
||||
"repo": "",
|
||||
"readme_content": "",
|
||||
"current_metadata": {},
|
||||
"base_models": [],
|
||||
}
|
||||
|
||||
context["current_metadata"] = {
|
||||
"file_name": metadata.get("file_name", ""),
|
||||
"base_model": metadata.get("base_model", ""),
|
||||
"tags": metadata.get("tags", []),
|
||||
"modelDescription": metadata.get("modelDescription", ""),
|
||||
"trainedWords": metadata.get("trainedWords", []),
|
||||
"sha256": (metadata.get("sha256") or "")[:16] + "..." if metadata.get("sha256") else "",
|
||||
"size": metadata.get("size", 0),
|
||||
}
|
||||
|
||||
hf_url = metadata.get("hf_url", "")
|
||||
context["hf_url"] = hf_url
|
||||
repo = self._extract_repo_from_url(hf_url) if hf_url else ""
|
||||
context["repo"] = repo or ""
|
||||
if repo:
|
||||
readme = await self._fetch_readme(repo)
|
||||
context["readme_content"] = readme[:8000] if readme else "(README not available)"
|
||||
|
||||
try:
|
||||
context["base_models"] = await list_base_models()
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to list base models: %s", exc)
|
||||
|
||||
return context
|
||||
|
||||
@staticmethod
|
||||
def _extract_repo_from_url(hf_url: str) -> Optional[str]:
|
||||
"""Extract ``user/repo`` from a HuggingFace URL."""
|
||||
if not hf_url:
|
||||
return None
|
||||
m = re.match(r"https?://huggingface\.co/([^/]+/[^/]+)", hf_url)
|
||||
return m.group(1) if m else None
|
||||
|
||||
@staticmethod
|
||||
async def _fetch_readme(repo: str) -> str:
|
||||
"""Fetch README.md from HuggingFace (tries ``main``, then ``master``)."""
|
||||
async with aiohttp.ClientSession(
|
||||
headers={"User-Agent": "ComfyUI-LoRA-Manager/1.0"},
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
) as session:
|
||||
for branch in ("main", "master"):
|
||||
url = f"https://huggingface.co/{repo}/raw/{branch}/README.md"
|
||||
try:
|
||||
async with session.get(url) as resp:
|
||||
if resp.status == 200:
|
||||
return await resp.text()
|
||||
except Exception as exc:
|
||||
logger.debug("Failed to fetch README from %s: %s", url, exc)
|
||||
return ""
|
||||
|
||||
async def _emit_progress(
|
||||
self,
|
||||
callback: Optional[AgentProgressReporter],
|
||||
skill_name: str,
|
||||
*,
|
||||
status: str,
|
||||
**extra: Any,
|
||||
) -> None:
|
||||
"""Send a progress update via WebSocket (if callback is set)."""
|
||||
payload: Dict[str, Any] = {"type": "agent_progress", "skill": skill_name, "status": status}
|
||||
payload.update(extra)
|
||||
if callback is not None:
|
||||
await callback.on_progress(payload)
|
||||
168
py/services/agent/post_processor.py
Normal file
168
py/services/agent/post_processor.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Post-processing engine for agent skill outputs.
|
||||
|
||||
The :class:`PostProcessor` takes the LLM's structured JSON output and applies
|
||||
it to a model's on-disk metadata via the :mod:`~py.agent_cli` functions.
|
||||
|
||||
It handles all the skill-specific business logic — conditions, transformations,
|
||||
and orchestration of multiple side-effects (write metadata, download preview,
|
||||
refresh cache). All actual I/O is delegated to :mod:`~py.agent_cli`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PostProcessor:
|
||||
"""Deterministic post-processor for agent skill outputs.
|
||||
|
||||
Usage (called by :class:`~py.services.agent.agent_service.AgentService`)::
|
||||
|
||||
processor = PostProcessor()
|
||||
result = await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/path/to/model.safetensors",
|
||||
llm_output={...},
|
||||
metadata={...}, # from agent_cli.read_metadata()
|
||||
)
|
||||
"""
|
||||
|
||||
async def process(
|
||||
self,
|
||||
*,
|
||||
skill_name: str,
|
||||
model_path: str,
|
||||
llm_output: Dict[str, Any],
|
||||
metadata: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Route *llm_output* to the correct skill post-processor.
|
||||
|
||||
Returns a dict with keys ``success`` (bool), ``updated_fields`` (list),
|
||||
``preview_downloaded`` (bool), and ``errors`` (list).
|
||||
"""
|
||||
if skill_name == "enrich_hf_metadata":
|
||||
return await self._process_enrich_hf_metadata(
|
||||
model_path, llm_output, metadata,
|
||||
)
|
||||
return {
|
||||
"success": False,
|
||||
"updated_fields": [],
|
||||
"errors": [f"No post-processor registered for skill: {skill_name}"],
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# enrich_hf_metadata
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _process_enrich_hf_metadata(
|
||||
self,
|
||||
model_path: str,
|
||||
llm_output: Dict[str, Any],
|
||||
metadata: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
from ...agent_cli import (
|
||||
apply_metadata_updates,
|
||||
download_preview,
|
||||
refresh_cache,
|
||||
)
|
||||
|
||||
updated_fields: List[str] = []
|
||||
preview_downloaded = False
|
||||
|
||||
# -- Determine whether this is an HF-sourced model -----------------
|
||||
is_hf_model = not metadata.get("from_civitai", True)
|
||||
|
||||
# -- Collect updates -----------------------------------------------
|
||||
updates: Dict[str, Any] = {}
|
||||
|
||||
# base_model
|
||||
new_base = (llm_output.get("base_model") or "").strip()
|
||||
current_base = metadata.get("base_model", "") or ""
|
||||
if new_base and self._should_overwrite(current_base, is_hf_model):
|
||||
updates["base_model"] = new_base
|
||||
|
||||
# trainedWords / trigger words
|
||||
new_triggers = llm_output.get("trigger_words", [])
|
||||
if isinstance(new_triggers, list):
|
||||
cleaned = [t.strip() for t in new_triggers if t.strip()]
|
||||
if cleaned:
|
||||
current_triggers = metadata.get("trainedWords") or []
|
||||
if self._should_overwrite_list(current_triggers, is_hf_model):
|
||||
updates["trainedWords"] = cleaned
|
||||
|
||||
# modelDescription
|
||||
new_desc = (llm_output.get("description") or "").strip()
|
||||
if new_desc:
|
||||
current_desc = metadata.get("modelDescription", "") or ""
|
||||
if self._should_overwrite(current_desc, is_hf_model):
|
||||
updates["modelDescription"] = new_desc
|
||||
|
||||
# tags — merge with existing, deduplicate (case-insensitive)
|
||||
new_tags = llm_output.get("tags", [])
|
||||
if isinstance(new_tags, list) and new_tags:
|
||||
existing_tags = metadata.get("tags") or []
|
||||
merged = self._merge_tags(existing_tags, new_tags)
|
||||
if len(merged) > len(existing_tags) or is_hf_model:
|
||||
updates["tags"] = merged
|
||||
|
||||
# metadata_source & llm_enriched_at (always set)
|
||||
updates["metadata_source"] = "agent:enrich_hf_metadata"
|
||||
updates["llm_enriched_at"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
# -- Persist updates ------------------------------------------------
|
||||
if updates:
|
||||
updated_fields = await apply_metadata_updates(model_path, updates)
|
||||
|
||||
# -- Download preview -----------------------------------------------
|
||||
preview_url = (llm_output.get("preview_url") or "").strip()
|
||||
current_preview = metadata.get("preview_url") or ""
|
||||
if preview_url and not (current_preview and os.path.exists(current_preview)):
|
||||
preview_downloaded = await download_preview(model_path, preview_url)
|
||||
|
||||
# -- Refresh scanner cache ------------------------------------------
|
||||
if updated_fields or preview_downloaded:
|
||||
await refresh_cache(model_path)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"updated_fields": updated_fields,
|
||||
"preview_downloaded": preview_downloaded,
|
||||
"errors": [],
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _should_overwrite(current_value: str, is_hf_model: bool) -> bool:
|
||||
"""Return ``True`` when a scalar field should be overwritten."""
|
||||
return is_hf_model or not current_value or current_value.lower() in (
|
||||
"", "unknown",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _should_overwrite_list(current_list: List[str], is_hf_model: bool) -> bool:
|
||||
"""Return ``True`` when a list field should be overwritten."""
|
||||
return is_hf_model or not current_list
|
||||
|
||||
@staticmethod
|
||||
def _merge_tags(existing: List[str], new: List[str]) -> List[str]:
|
||||
"""Merge *new* tags into *existing*, all lowercased.
|
||||
|
||||
This matches the behaviour of :class:`TagUpdateService` which
|
||||
normalises every tag to lowercase for case-insensitive dedup.
|
||||
"""
|
||||
merged: List[str] = []
|
||||
seen: set = set()
|
||||
for tag in list(existing) + list(new):
|
||||
t = tag.strip().lower()
|
||||
if t and t not in seen:
|
||||
merged.append(t)
|
||||
seen.add(t)
|
||||
return merged
|
||||
45
py/services/agent/skill_definition.py
Normal file
45
py/services/agent/skill_definition.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Skill definition data structures.
|
||||
|
||||
Each skill is described by a :class:`SkillDefinition` that declares its
|
||||
input/output schemas, whether it needs an LLM call, and what permissions
|
||||
its post-processor has.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SkillPermissions:
|
||||
"""Declarative permission scope for a skill's post-processor.
|
||||
|
||||
These are auditable constraints — the :class:`AgentService` checks them
|
||||
before invoking the handler. They are defense-in-depth, not a sandbox.
|
||||
"""
|
||||
|
||||
write_metadata: bool = True
|
||||
write_previews: bool = True
|
||||
network_domains: Tuple[str, ...] = ()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SkillDefinition:
|
||||
"""Immutable description of an agent skill."""
|
||||
|
||||
name: str
|
||||
title: str
|
||||
description: str
|
||||
llm_required: bool
|
||||
input_schema: Dict[str, Any] = field(default_factory=dict)
|
||||
output_schema: Dict[str, Any] = field(default_factory=dict)
|
||||
model_type_filter: Optional[List[str]] = None
|
||||
permissions: SkillPermissions = field(default_factory=SkillPermissions)
|
||||
|
||||
def applies_to_model_type(self, model_type: str) -> bool:
|
||||
"""Return ``True`` if this skill can run on the given model type."""
|
||||
|
||||
if self.model_type_filter is None:
|
||||
return True
|
||||
return model_type in self.model_type_filter
|
||||
184
py/services/agent/skill_registry.py
Normal file
184
py/services/agent/skill_registry.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Discovery and loading of agent skills.
|
||||
|
||||
Skills live in ``py/services/agent/skills/<name>/`` directories. Each
|
||||
directory must contain a ``SKILL.md`` file with YAML frontmatter::
|
||||
|
||||
---
|
||||
name: my_skill
|
||||
title: "My Skill"
|
||||
description: "What this skill does"
|
||||
llm_required: true
|
||||
---
|
||||
|
||||
Prompt template with ``{{variable}}`` placeholders.
|
||||
|
||||
The registry scans the skills directory on first access and caches results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from .skill_definition import SkillDefinition, SkillPermissions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Directory where built-in skills are stored
|
||||
_SKILLS_DIR = Path(__file__).parent / "skills"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frontmatter parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FRONTMATTER_RE = re.compile(
|
||||
r"^---\s*\n(.*?\n)---\s*\n?(.*)", re.DOTALL
|
||||
)
|
||||
|
||||
|
||||
def _parse_skill_file(path: Path) -> tuple[dict, str]:
|
||||
"""Read a ``SKILL.md`` file and return (frontmatter_dict, body_text).
|
||||
|
||||
Raises ``ValueError`` if the file lacks valid YAML frontmatter.
|
||||
"""
|
||||
text = path.read_text(encoding="utf-8")
|
||||
m = _FRONTMATTER_RE.match(text)
|
||||
if not m:
|
||||
raise ValueError(f"Missing or invalid YAML frontmatter in {path}")
|
||||
frontmatter = yaml.safe_load(m.group(1))
|
||||
if not isinstance(frontmatter, dict):
|
||||
raise ValueError(f"Frontmatter in {path} is not a mapping")
|
||||
body = m.group(2).strip()
|
||||
return frontmatter, body
|
||||
|
||||
|
||||
class SkillRegistry:
|
||||
"""Discover and load agent skills from the filesystem."""
|
||||
|
||||
_instance: Optional["SkillRegistry"] = None
|
||||
_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
def __init__(self, skills_dir: Path = _SKILLS_DIR) -> None:
|
||||
self._skills_dir = skills_dir
|
||||
self._skills: Dict[str, SkillDefinition] = {}
|
||||
self._loaded: bool = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Singleton access
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> "SkillRegistry":
|
||||
"""Return the lazily-initialised global ``SkillRegistry``."""
|
||||
|
||||
if cls._instance is None:
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
registry = cls()
|
||||
registry._discover()
|
||||
cls._instance = registry
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""Reset the cached singleton — primarily for tests."""
|
||||
|
||||
cls._instance = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Discovery
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _discover(self) -> None:
|
||||
"""Scan the skills directory and load all valid skill definitions."""
|
||||
|
||||
self._skills.clear()
|
||||
if not self._skills_dir.is_dir():
|
||||
logger.warning("Skills directory does not exist: %s", self._skills_dir)
|
||||
self._loaded = True
|
||||
return
|
||||
|
||||
for entry in sorted(self._skills_dir.iterdir()):
|
||||
if not entry.is_dir():
|
||||
continue
|
||||
skill_md = entry / "SKILL.md"
|
||||
if not skill_md.exists():
|
||||
continue
|
||||
try:
|
||||
definition = self._load_skill_definition(skill_md)
|
||||
if definition is not None:
|
||||
self._skills[definition.name] = definition
|
||||
logger.debug("Loaded skill: %s", definition.name)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to load skill from %s: %s", skill_md, exc)
|
||||
|
||||
self._loaded = True
|
||||
logger.info("Discovered %d agent skills", len(self._skills))
|
||||
|
||||
def _load_skill_definition(self, path: Path) -> Optional[SkillDefinition]:
|
||||
"""Parse a ``SKILL.md`` frontmatter into a :class:`SkillDefinition`."""
|
||||
|
||||
try:
|
||||
data, _body = _parse_skill_file(path)
|
||||
except (ValueError, yaml.YAMLError) as exc:
|
||||
logger.warning("Failed to parse SKILL.md %s: %s", path, exc)
|
||||
return None
|
||||
|
||||
if "name" not in data:
|
||||
logger.warning("SKILL.md missing required 'name' field: %s", path)
|
||||
return None
|
||||
|
||||
perm_data = data.get("permissions", {})
|
||||
permissions = SkillPermissions(
|
||||
write_metadata=perm_data.get("write_metadata", True),
|
||||
write_previews=perm_data.get("write_previews", True),
|
||||
network_domains=tuple(perm_data.get("network_domains", [])),
|
||||
)
|
||||
|
||||
return SkillDefinition(
|
||||
name=data["name"],
|
||||
title=data.get("title", data["name"]),
|
||||
description=data.get("description", ""),
|
||||
llm_required=data.get("llm_required", False),
|
||||
input_schema=data.get("input_schema", {}),
|
||||
output_schema=data.get("output_schema", {}),
|
||||
model_type_filter=data.get("model_type_filter"),
|
||||
permissions=permissions,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def list_skills(self) -> List[SkillDefinition]:
|
||||
"""Return all discovered skill definitions."""
|
||||
|
||||
if not self._loaded:
|
||||
self._discover()
|
||||
return list(self._skills.values())
|
||||
|
||||
def get_skill(self, name: str) -> Optional[SkillDefinition]:
|
||||
"""Return the skill definition for ``name``, or ``None`` if not found."""
|
||||
|
||||
if not self._loaded:
|
||||
self._discover()
|
||||
return self._skills.get(name)
|
||||
|
||||
def load_prompt(self, name: str) -> str:
|
||||
"""Load and return the prompt template body from a skill's ``SKILL.md``."""
|
||||
|
||||
skill_dir = self._skills_dir / name
|
||||
skill_path = skill_dir / "SKILL.md"
|
||||
if not skill_path.exists():
|
||||
raise FileNotFoundError(f"SKILL.md not found: {skill_path}")
|
||||
try:
|
||||
_frontmatter, body = _parse_skill_file(skill_path)
|
||||
return body
|
||||
except (ValueError, yaml.YAMLError) as exc:
|
||||
raise ValueError(f"Failed to parse prompt from {skill_path}: {exc}") from exc
|
||||
89
py/services/agent/skills/enrich_hf_metadata/SKILL.md
Normal file
89
py/services/agent/skills/enrich_hf_metadata/SKILL.md
Normal file
@@ -0,0 +1,89 @@
|
||||
---
|
||||
name: enrich_hf_metadata
|
||||
title: "Enrich Metadata from HuggingFace"
|
||||
description: >
|
||||
Parse the HuggingFace model card via LLM to extract description, trigger
|
||||
words, base model, tags, and preview image URL.
|
||||
llm_required: true
|
||||
---
|
||||
|
||||
You are an expert assistant for AI image generation models. Your task is to extract structured metadata from a HuggingFace model card (README.md).
|
||||
|
||||
## Model Information
|
||||
|
||||
- **Repository**: {{hf_url}}
|
||||
- **Model file path**: {{model_path}}
|
||||
- **Repository ID**: {{repo}}
|
||||
|
||||
## Current Metadata (may be incomplete)
|
||||
|
||||
```json
|
||||
{{current_metadata}}
|
||||
```
|
||||
|
||||
## Available Base Models
|
||||
|
||||
The following base models are currently valid in this system:
|
||||
{{base_models}}
|
||||
|
||||
## HuggingFace README Content
|
||||
|
||||
```
|
||||
{{readme_content}}
|
||||
```
|
||||
|
||||
## Extraction Instructions
|
||||
|
||||
Extract the following information from the README content above:
|
||||
|
||||
### base_model
|
||||
The base model this LoRA/checkpoint was trained on. Use EXACTLY one of the names from the **Available Base Models** list above. Do not invent new names or use aliases.
|
||||
|
||||
Check the YAML frontmatter (between --- markers) for `base_model:` first, then look at the description text and safetensors metadata. If you cannot determine it, return an empty string.
|
||||
|
||||
### trigger_words
|
||||
The trigger words or activation prompts needed to use this LoRA. Look for:
|
||||
- `instance_prompt:` in the YAML frontmatter
|
||||
- Phrases like "trigger word:", "trigger:", "use this prompt:", "activation prompt:"
|
||||
- Example prompts at the start (usually the first word or phrase before any description)
|
||||
Return as an array of strings. If none found, return an empty array.
|
||||
|
||||
### description
|
||||
A concise 1-2 sentence summary of what this model does. Extract from the "Model description" section or the first paragraph. Return empty string if the README is too minimal.
|
||||
|
||||
### tags
|
||||
3-8 relevant tags for categorizing this model. Extract from:
|
||||
- The YAML frontmatter `tags:` list (often contains excellent categorization tags)
|
||||
- The model type (e.g. "lora", "checkpoint", "flux", "sdxl")
|
||||
- The style/subject (e.g. "anime", "photorealistic", "style", "character")
|
||||
All lowercase, no spaces. Return empty array if none found.
|
||||
|
||||
### preview_url
|
||||
The URL of the most suitable preview image from the README. Look for image tags (e.g. ``) and the YAML frontmatter `widget:` section (which often has `output.url` fields). Choose the first image that appears to be a generation example (not a logo or diagram). Construct the absolute URL as `https://huggingface.co/{{repo}}/resolve/main/{filename}`. If no suitable image is found, return an empty string.
|
||||
|
||||
### confidence
|
||||
Your confidence level in the extracted data:
|
||||
- "high" — most fields were explicitly stated in the README
|
||||
- "medium" — some fields were inferred from context
|
||||
- "low" — most fields are guesses based on limited information
|
||||
|
||||
## Output Format
|
||||
|
||||
Return ONLY a JSON object with exactly these fields (no markdown fences, no extra text):
|
||||
|
||||
```json
|
||||
{
|
||||
"model_path": "{{model_path}}",
|
||||
"base_model": "<canonical name or empty string>",
|
||||
"trigger_words": ["<word1>", "<word2>"],
|
||||
"description": "<1-2 sentence summary>",
|
||||
"tags": ["<tag1>", "<tag2>"],
|
||||
"preview_url": "<image URL or empty string>",
|
||||
"confidence": "<high|medium|low>"
|
||||
}
|
||||
```
|
||||
|
||||
Important:
|
||||
- Only include the JSON object, no other text
|
||||
- If a field cannot be determined, use an empty string or empty array
|
||||
- Do not fabricate information not supported by the README
|
||||
@@ -791,8 +791,12 @@ class BaseModelService(ABC):
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def format_response(self, model_data: Dict) -> Dict:
|
||||
"""Format model data for API response - must be implemented by subclasses"""
|
||||
async def format_response(self, model_data: Dict) -> Optional[Dict]:
|
||||
"""Format model data for API response - must be implemented by subclasses.
|
||||
|
||||
Subclasses should return None for corrupted entries so the handler
|
||||
layer can filter them out. See issue #730.
|
||||
"""
|
||||
pass
|
||||
|
||||
# Common service methods that delegate to scanner
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from .auto_tag_service import extract_auto_tags
|
||||
@@ -21,20 +21,37 @@ class CheckpointService(BaseModelService):
|
||||
"""
|
||||
super().__init__("checkpoint", scanner, CheckpointMetadata, update_service=update_service)
|
||||
|
||||
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||
"""Format Checkpoint data for API response"""
|
||||
async def format_response(self, checkpoint_data: Dict) -> Optional[Dict]:
|
||||
"""Format Checkpoint data for API response.
|
||||
|
||||
Returns None when the entry is missing critical fields (corrupted cache
|
||||
row), so the handler layer can filter it out. See issue #730.
|
||||
"""
|
||||
# Guard against corrupted cache entries missing critical fields
|
||||
file_path = checkpoint_data.get("file_path")
|
||||
if not file_path or not isinstance(file_path, str):
|
||||
logger.warning(
|
||||
"Skipping corrupted checkpoint entry (missing file_path): %s",
|
||||
checkpoint_data.get("file_name", "<unknown>"),
|
||||
)
|
||||
return None
|
||||
|
||||
# Get sub_type from cache entry (new canonical field)
|
||||
sub_type = checkpoint_data.get("sub_type", "checkpoint")
|
||||
|
||||
|
||||
file_name = checkpoint_data.get("file_name") or ""
|
||||
model_name = checkpoint_data.get("model_name") or file_name
|
||||
folder = checkpoint_data.get("folder") or ""
|
||||
|
||||
return {
|
||||
"model_name": checkpoint_data["model_name"],
|
||||
"file_name": checkpoint_data["file_name"],
|
||||
"model_name": model_name,
|
||||
"file_name": file_name,
|
||||
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
|
||||
"base_model": checkpoint_data.get("base_model", ""),
|
||||
"folder": checkpoint_data["folder"],
|
||||
"folder": folder,
|
||||
"sha256": checkpoint_data.get("sha256", ""),
|
||||
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
|
||||
"file_path": file_path.replace(os.sep, "/"),
|
||||
"file_size": checkpoint_data.get("size", 0),
|
||||
"modified": checkpoint_data.get("modified", ""),
|
||||
"tags": checkpoint_data.get("tags", []),
|
||||
@@ -49,6 +66,7 @@ class CheckpointService(BaseModelService):
|
||||
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True),
|
||||
"auto_tags": checkpoint_data.get("auto_tags") or extract_auto_tags(checkpoint_data),
|
||||
"version_count": checkpoint_data.get("version_count"),
|
||||
"hf_url": checkpoint_data.get("hf_url", ""),
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
|
||||
@@ -327,7 +327,7 @@ class CivArchiveClient:
|
||||
if resolved:
|
||||
return resolved, None
|
||||
|
||||
logger.error("Error fetching version of CivArchive model by hash %s", model_hash[:10])
|
||||
logger.debug("Error fetching version of CivArchive model by hash %s", model_hash[:10])
|
||||
return None, "No version data found"
|
||||
|
||||
except RateLimitError:
|
||||
|
||||
@@ -196,6 +196,7 @@ class CivitaiBaseModelService:
|
||||
"ernie": "ERNI",
|
||||
"ernie turbo": "ETRB",
|
||||
"nucleus": "NUCL",
|
||||
"krea 2": "KR2",
|
||||
"svd": "SVD",
|
||||
"ltxv": "LTXV",
|
||||
"ltxv2": "LTV2",
|
||||
@@ -424,6 +425,7 @@ class CivitaiBaseModelService:
|
||||
"Ernie",
|
||||
"Ernie Turbo",
|
||||
"Nucleus",
|
||||
"Krea 2",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
from typing import Dict, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from .auto_tag_service import extract_auto_tags
|
||||
@@ -21,20 +21,37 @@ class EmbeddingService(BaseModelService):
|
||||
"""
|
||||
super().__init__("embedding", scanner, EmbeddingMetadata, update_service=update_service)
|
||||
|
||||
async def format_response(self, embedding_data: Dict) -> Dict:
|
||||
"""Format Embedding data for API response"""
|
||||
async def format_response(self, embedding_data: Dict) -> Optional[Dict]:
|
||||
"""Format Embedding data for API response.
|
||||
|
||||
Returns None when the entry is missing critical fields (corrupted cache
|
||||
row), so the handler layer can filter it out. See issue #730.
|
||||
"""
|
||||
# Guard against corrupted cache entries missing critical fields
|
||||
file_path = embedding_data.get("file_path")
|
||||
if not file_path or not isinstance(file_path, str):
|
||||
logger.warning(
|
||||
"Skipping corrupted embedding entry (missing file_path): %s",
|
||||
embedding_data.get("file_name", "<unknown>"),
|
||||
)
|
||||
return None
|
||||
|
||||
# Get sub_type from cache entry (new canonical field)
|
||||
sub_type = embedding_data.get("sub_type", "embedding")
|
||||
|
||||
|
||||
file_name = embedding_data.get("file_name") or ""
|
||||
model_name = embedding_data.get("model_name") or file_name
|
||||
folder = embedding_data.get("folder") or ""
|
||||
|
||||
return {
|
||||
"model_name": embedding_data["model_name"],
|
||||
"file_name": embedding_data["file_name"],
|
||||
"model_name": model_name,
|
||||
"file_name": file_name,
|
||||
"preview_url": config.get_preview_static_url(embedding_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": embedding_data.get("preview_nsfw_level", 0),
|
||||
"base_model": embedding_data.get("base_model", ""),
|
||||
"folder": embedding_data["folder"],
|
||||
"folder": folder,
|
||||
"sha256": embedding_data.get("sha256", ""),
|
||||
"file_path": embedding_data["file_path"].replace(os.sep, "/"),
|
||||
"file_path": file_path.replace(os.sep, "/"),
|
||||
"file_size": embedding_data.get("size", 0),
|
||||
"modified": embedding_data.get("modified", ""),
|
||||
"tags": embedding_data.get("tags", []),
|
||||
@@ -49,6 +66,7 @@ class EmbeddingService(BaseModelService):
|
||||
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True),
|
||||
"auto_tags": embedding_data.get("auto_tags") or extract_auto_tags(embedding_data),
|
||||
"version_count": embedding_data.get("version_count"),
|
||||
"hf_url": embedding_data.get("hf_url", ""),
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
|
||||
@@ -25,3 +25,21 @@ class ResourceNotFoundError(RuntimeError):
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LLMNotConfiguredError(RuntimeError):
|
||||
"""Raised when an LLM-dependent operation is attempted but no provider is configured."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LLMRateLimitError(RateLimitError):
|
||||
"""Raised when the LLM provider rejects a request due to rate limiting."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LLMResponseError(RuntimeError):
|
||||
"""Raised when the LLM returns an unparseable or schema-invalid response."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
321
py/services/llm_service.py
Normal file
321
py/services/llm_service.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""Centralized LLM API client with BYOK (bring-your-own-key) provider support.
|
||||
|
||||
Reads provider configuration from :class:`SettingsManager` and makes
|
||||
OpenAI-compatible ``/chat/completions`` calls. Supports any provider that
|
||||
implements the OpenAI Chat Completions API surface area (OpenAI, Ollama,
|
||||
vLLM, LM Studio, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .errors import LLMNotConfiguredError, LLMRateLimitError, LLMResponseError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default API base URLs per provider
|
||||
_PROVIDER_DEFAULTS: Dict[str, str] = {
|
||||
"openai": "https://api.openai.com/v1",
|
||||
"ollama": "http://localhost:11434/v1",
|
||||
# "custom" requires an explicit llm_api_base from the user
|
||||
}
|
||||
|
||||
# Request timeout for LLM calls (seconds)
|
||||
_LLM_TIMEOUT = aiohttp.ClientTimeout(total=120)
|
||||
|
||||
|
||||
class LLMService:
|
||||
"""Centralized LLM API client.
|
||||
|
||||
All agent skills call LLMs through this service so that BYOK config,
|
||||
retry logic, and error handling live in one place.
|
||||
"""
|
||||
|
||||
_instance: Optional["LLMService"] = None
|
||||
_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
def __init__(self, settings_service) -> None:
|
||||
self._settings = settings_service
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Singleton access
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls) -> "LLMService":
|
||||
"""Return the lazily-initialised global ``LLMService`` instance."""
|
||||
|
||||
if cls._instance is None:
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
from .settings_manager import get_settings_manager
|
||||
|
||||
cls._instance = cls(get_settings_manager())
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""Reset the cached singleton — primarily for tests."""
|
||||
|
||||
cls._instance = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Configuration helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_config(self) -> Dict[str, Any]:
|
||||
"""Read the current LLM configuration from settings."""
|
||||
|
||||
return {
|
||||
"provider": self._settings.get("llm_provider", "openai"),
|
||||
"api_key": self._settings.get("llm_api_key", ""),
|
||||
"api_base": self._settings.get("llm_api_base", ""),
|
||||
"model": self._settings.get("llm_model", ""),
|
||||
}
|
||||
|
||||
def is_configured(self) -> bool:
|
||||
"""Return ``True`` when the LLM provider is minimally configured.
|
||||
|
||||
A provider is considered configured when ``llm_model`` is set and
|
||||
(for non-Ollama) an API key is configured.
|
||||
"""
|
||||
|
||||
cfg = self._get_config()
|
||||
has_model = bool(cfg["model"])
|
||||
has_key = bool(cfg["api_key"]) or cfg["provider"] == "ollama"
|
||||
return has_model and has_key
|
||||
|
||||
def _resolve_api_base(self, provider: str, api_base: str) -> str:
|
||||
"""Resolve the API base URL for the given provider."""
|
||||
|
||||
if api_base:
|
||||
return api_base.rstrip("/")
|
||||
return _PROVIDER_DEFAULTS.get(provider, "").rstrip("/")
|
||||
|
||||
def _build_headers(self, api_key: str) -> Dict[str, str]:
|
||||
"""Build HTTP headers for the LLM API request."""
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
def _ensure_configured(self) -> Dict[str, Any]:
|
||||
"""Validate configuration and return it, or raise.
|
||||
|
||||
A provider is considered configured when ``llm_model`` is set and
|
||||
(for non-Ollama) an API key is configured.
|
||||
"""
|
||||
|
||||
cfg = self._get_config()
|
||||
has_model = bool(cfg["model"])
|
||||
has_key = bool(cfg["api_key"]) or cfg["provider"] == "ollama"
|
||||
if not (has_model and has_key):
|
||||
parts = []
|
||||
if not has_model:
|
||||
parts.append("No LLM model specified")
|
||||
if not has_key and cfg["provider"] != "ollama":
|
||||
parts.append("No LLM API key configured")
|
||||
detail = "; ".join(parts) if parts else "LLM provider is not configured"
|
||||
raise LLMNotConfiguredError(
|
||||
f"{detail}. Configure it in Settings → AI Provider."
|
||||
)
|
||||
return cfg
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core API call
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
*,
|
||||
messages: List[Dict[str, str]],
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0.3,
|
||||
response_format: Optional[Dict[str, Any]] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
retry_on_rate_limit: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""Call the configured LLM provider's ``/chat/completions`` endpoint.
|
||||
|
||||
Args:
|
||||
messages: OpenAI-format message list
|
||||
model: Override the configured model name
|
||||
temperature: Sampling temperature
|
||||
response_format: Optional ``{"type": "json_object"}`` for structured output
|
||||
max_tokens: Optional max output tokens
|
||||
retry_on_rate_limit: Retry once after a 429 with backoff
|
||||
|
||||
Returns:
|
||||
Dict with ``content`` (str), ``usage`` (dict), ``model`` (str)
|
||||
|
||||
Raises:
|
||||
LLMNotConfiguredError: Provider not enabled / missing config
|
||||
LLMRateLimitError: Rate limited and retry exhausted
|
||||
LLMResponseError: Non-200 response or parse failure
|
||||
"""
|
||||
|
||||
cfg = self._ensure_configured()
|
||||
api_base = self._resolve_api_base(cfg["provider"], cfg["api_base"])
|
||||
url = f"{api_base}/chat/completions"
|
||||
model_name = model or cfg["model"]
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
}
|
||||
if response_format is not None:
|
||||
payload["response_format"] = response_format
|
||||
if max_tokens is not None:
|
||||
payload["max_tokens"] = max_tokens
|
||||
|
||||
headers = self._build_headers(cfg["api_key"])
|
||||
|
||||
attempt = 0
|
||||
max_attempts = 2 if retry_on_rate_limit else 1
|
||||
while attempt < max_attempts:
|
||||
attempt += 1
|
||||
try:
|
||||
async with aiohttp.ClientSession(timeout=_LLM_TIMEOUT) as session:
|
||||
async with session.post(
|
||||
url, json=payload, headers=headers
|
||||
) as resp:
|
||||
if resp.status == 429:
|
||||
if attempt < max_attempts:
|
||||
retry_after = float(
|
||||
resp.headers.get("Retry-After", "5")
|
||||
)
|
||||
logger.warning(
|
||||
"LLM rate limited, retrying after %.1fs",
|
||||
retry_after,
|
||||
)
|
||||
await asyncio.sleep(retry_after)
|
||||
continue
|
||||
raise LLMRateLimitError(
|
||||
f"LLM provider rate limited (HTTP 429)",
|
||||
provider=cfg["provider"],
|
||||
)
|
||||
|
||||
if resp.status != 200:
|
||||
body = await resp.text()
|
||||
raise LLMResponseError(
|
||||
f"LLM API returned HTTP {resp.status}: "
|
||||
f"{body[:500]}"
|
||||
)
|
||||
|
||||
data = await resp.json()
|
||||
|
||||
except aiohttp.ClientError as exc:
|
||||
raise LLMResponseError(f"Network error calling LLM API: {exc}") from exc
|
||||
|
||||
# Parse response
|
||||
try:
|
||||
content = data["choices"][0]["message"]["content"]
|
||||
usage = data.get("usage", {})
|
||||
return {
|
||||
"content": content,
|
||||
"usage": usage,
|
||||
"model": data.get("model", model_name),
|
||||
}
|
||||
except (KeyError, IndexError) as exc:
|
||||
raise LLMResponseError(
|
||||
f"Unexpected LLM response structure: {json.dumps(data)[:500]}"
|
||||
) from exc
|
||||
|
||||
# Should not reach here, but satisfy type checker
|
||||
raise LLMRateLimitError("Rate limit retry exhausted", provider=cfg["provider"])
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Structured output convenience
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def chat_completion_json(
|
||||
self,
|
||||
*,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
model: Optional[str] = None,
|
||||
temperature: float = 0.3,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Call the LLM and return parsed JSON.
|
||||
|
||||
Sends ``response_format: {"type": "json_object"}`` when the provider
|
||||
supports it, and parses the response content as JSON. If parsing
|
||||
fails, retries once with a clarifying system message.
|
||||
|
||||
Args:
|
||||
system_prompt: System-level instructions
|
||||
user_prompt: User-level query
|
||||
model: Override the configured model name
|
||||
temperature: Sampling temperature
|
||||
max_tokens: Optional max output tokens
|
||||
|
||||
Returns:
|
||||
Parsed JSON dict from the LLM response
|
||||
|
||||
Raises:
|
||||
LLMNotConfiguredError: Provider not configured
|
||||
LLMRateLimitError: Rate limited
|
||||
LLMResponseError: JSON parse failure after retry
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
# First attempt with JSON mode
|
||||
result = await self.chat_completion(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
response_format={"type": "json_object"},
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
try:
|
||||
return json.loads(result["content"])
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
logger.warning(
|
||||
"LLM JSON parse failed on first attempt: %s. Retrying.", exc
|
||||
)
|
||||
|
||||
# Retry with explicit instruction to return valid JSON
|
||||
retry_messages = messages + [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": result["content"],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"The previous response could not be parsed as JSON. "
|
||||
"Please respond with ONLY a valid JSON object, no "
|
||||
"markdown fences or extra text."
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
result = await self.chat_completion(
|
||||
messages=retry_messages,
|
||||
model=model,
|
||||
temperature=0.0, # More deterministic for retry
|
||||
response_format={"type": "json_object"},
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
try:
|
||||
return json.loads(result["content"])
|
||||
except (json.JSONDecodeError, TypeError) as exc:
|
||||
raise LLMResponseError(
|
||||
f"LLM response could not be parsed as JSON after retry: {exc}\n"
|
||||
f"Raw content: {result['content'][:500]}"
|
||||
) from exc
|
||||
@@ -24,23 +24,41 @@ class LoraService(BaseModelService):
|
||||
"""
|
||||
super().__init__("lora", scanner, LoraMetadata, update_service=update_service)
|
||||
|
||||
async def format_response(self, lora_data: Dict) -> Dict:
|
||||
"""Format LoRA data for API response"""
|
||||
async def format_response(self, lora_data: Dict) -> Optional[Dict]:
|
||||
"""Format LoRA data for API response.
|
||||
|
||||
Returns None when the entry is missing critical fields (corrupted cache
|
||||
row), so the handler layer can filter it out instead of crashing the
|
||||
whole listing request. See issue #730.
|
||||
"""
|
||||
# Guard against corrupted cache entries missing critical fields
|
||||
file_path = lora_data.get("file_path")
|
||||
if not file_path or not isinstance(file_path, str):
|
||||
logger.warning(
|
||||
"Skipping corrupted LoRA entry (missing file_path): %s",
|
||||
lora_data.get("file_name", "<unknown>"),
|
||||
)
|
||||
return None
|
||||
|
||||
# Resolve sub_type using priority: sub_type > model_type > civitai.model.type > default
|
||||
# Normalize to lowercase for consistent API responses
|
||||
sub_type = resolve_sub_type(lora_data).lower()
|
||||
|
||||
file_name = lora_data.get("file_name") or ""
|
||||
model_name = lora_data.get("model_name") or file_name
|
||||
folder = lora_data.get("folder") or ""
|
||||
|
||||
return {
|
||||
"model_name": lora_data["model_name"],
|
||||
"file_name": lora_data["file_name"],
|
||||
"model_name": model_name,
|
||||
"file_name": file_name,
|
||||
"preview_url": config.get_preview_static_url(
|
||||
lora_data.get("preview_url", "")
|
||||
),
|
||||
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
|
||||
"base_model": lora_data.get("base_model", ""),
|
||||
"folder": lora_data["folder"],
|
||||
"folder": folder,
|
||||
"sha256": lora_data.get("sha256", ""),
|
||||
"file_path": lora_data["file_path"].replace(os.sep, "/"),
|
||||
"file_path": file_path.replace(os.sep, "/"),
|
||||
"file_size": lora_data.get("size", 0),
|
||||
"modified": lora_data.get("modified", ""),
|
||||
"tags": lora_data.get("tags", []),
|
||||
@@ -60,6 +78,7 @@ class LoraService(BaseModelService):
|
||||
),
|
||||
"auto_tags": lora_data.get("auto_tags") or extract_auto_tags(lora_data),
|
||||
"version_count": lora_data.get("version_count"),
|
||||
"hf_url": lora_data.get("hf_url", ""),
|
||||
}
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
|
||||
@@ -248,6 +248,7 @@ class ModelScanner:
|
||||
'civitai': civitai_slim,
|
||||
'civitai_deleted': bool(get_value('civitai_deleted', False)),
|
||||
'skip_metadata_refresh': bool(get_value('skip_metadata_refresh', False)),
|
||||
'hf_url': get_value('hf_url', '') or '',
|
||||
}
|
||||
|
||||
license_source: Dict[str, Any] = {}
|
||||
@@ -476,11 +477,20 @@ class ModelScanner:
|
||||
for tag in adjusted_item.get('tags') or []:
|
||||
tags_count[tag] = tags_count.get(tag, 0) + 1
|
||||
|
||||
# Validate cache entries and check health
|
||||
# Validate cache entries and check health.
|
||||
# Always use the validated/repaired entries — even when there are no
|
||||
# invalid entries, auto_repair may have filled in missing optional
|
||||
# fields (model_name, file_name, folder) with safe defaults on a copied
|
||||
# working_entry. Without this unconditional replacement the repaired
|
||||
# copies are discarded and None values propagate to format_response.
|
||||
# See issue #730.
|
||||
valid_entries, invalid_entries = CacheEntryValidator.validate_batch(
|
||||
adjusted_raw_data, auto_repair=True
|
||||
)
|
||||
|
||||
# Always use the validated entries (repaired copies)
|
||||
adjusted_raw_data = valid_entries
|
||||
|
||||
if invalid_entries:
|
||||
monitor = CacheHealthMonitor()
|
||||
report = monitor.check_health(adjusted_raw_data, auto_repair=True)
|
||||
|
||||
@@ -57,6 +57,7 @@ class PersistentModelCache:
|
||||
"db_checked",
|
||||
"last_checked_at",
|
||||
"hash_status",
|
||||
"hf_url",
|
||||
)
|
||||
_MODEL_UPDATE_COLUMNS: Tuple[str, ...] = _MODEL_COLUMNS[2:]
|
||||
_instances: Dict[str, "PersistentModelCache"] = {}
|
||||
@@ -165,8 +166,8 @@ class PersistentModelCache:
|
||||
|
||||
item = {
|
||||
"file_path": file_path,
|
||||
"file_name": row["file_name"],
|
||||
"model_name": row["model_name"],
|
||||
"file_name": row["file_name"] or "",
|
||||
"model_name": row["model_name"] or "",
|
||||
"folder": row["folder"] or "",
|
||||
"size": row["size"] or 0,
|
||||
"modified": row["modified"] or 0.0,
|
||||
@@ -188,6 +189,7 @@ class PersistentModelCache:
|
||||
"skip_metadata_refresh": bool(row["skip_metadata_refresh"]),
|
||||
"license_flags": int(license_value),
|
||||
"hash_status": row["hash_status"] or "completed",
|
||||
"hf_url": row["hf_url"] or "",
|
||||
}
|
||||
raw_data.append(item)
|
||||
|
||||
@@ -452,6 +454,7 @@ class PersistentModelCache:
|
||||
db_checked INTEGER,
|
||||
last_checked_at REAL,
|
||||
hash_status TEXT,
|
||||
hf_url TEXT DEFAULT '',
|
||||
PRIMARY KEY (model_type, file_path)
|
||||
);
|
||||
|
||||
@@ -500,6 +503,7 @@ class PersistentModelCache:
|
||||
# Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57).
|
||||
"license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}",
|
||||
"hash_status": "TEXT DEFAULT 'completed'",
|
||||
"hf_url": "TEXT DEFAULT ''",
|
||||
}
|
||||
|
||||
for column, definition in required_columns.items():
|
||||
@@ -548,19 +552,19 @@ class PersistentModelCache:
|
||||
return (
|
||||
model_type,
|
||||
item.get("file_path"),
|
||||
item.get("file_name"),
|
||||
item.get("model_name"),
|
||||
item.get("folder"),
|
||||
item.get("file_name") or "",
|
||||
item.get("model_name") or "",
|
||||
item.get("folder") or "",
|
||||
int(item.get("size") or 0),
|
||||
float(item.get("modified") or 0.0),
|
||||
(item.get("sha256") or "").lower() or None,
|
||||
item.get("base_model"),
|
||||
item.get("preview_url"),
|
||||
item.get("base_model") or "",
|
||||
item.get("preview_url") or "",
|
||||
int(item.get("preview_nsfw_level") or 0),
|
||||
1 if item.get("from_civitai", True) else 0,
|
||||
1 if item.get("favorite") else 0,
|
||||
item.get("notes"),
|
||||
item.get("usage_tips"),
|
||||
item.get("notes") or "",
|
||||
item.get("usage_tips") or "",
|
||||
metadata_source,
|
||||
civitai.get("id"),
|
||||
civitai.get("modelId"),
|
||||
@@ -575,6 +579,7 @@ class PersistentModelCache:
|
||||
1 if item.get("db_checked") else 0,
|
||||
float(item.get("last_checked_at") or 0.0),
|
||||
item.get("hash_status", "completed"),
|
||||
item.get("hf_url") or "",
|
||||
)
|
||||
|
||||
def _insert_model_sql(self) -> str:
|
||||
|
||||
@@ -107,6 +107,11 @@ DEFAULT_SETTINGS: Dict[str, Any] = {
|
||||
"backup_retention_count": 5,
|
||||
"use_new_license_icons": True,
|
||||
"group_by_model": False,
|
||||
# AI / LLM provider configuration (BYOK)
|
||||
"llm_provider": "openai", # "openai" | "ollama" | "custom"
|
||||
"llm_api_key": "",
|
||||
"llm_api_base": "", # empty = provider default
|
||||
"llm_model": "", # e.g. "gpt-4o-mini"
|
||||
}
|
||||
|
||||
|
||||
@@ -873,6 +878,23 @@ class SettingsManager:
|
||||
self.settings["civitai_api_key"] = env_api_key
|
||||
self._save_settings()
|
||||
|
||||
# LLM provider overrides
|
||||
llm_env_map = {
|
||||
"LLM_API_KEY": "llm_api_key",
|
||||
"LLM_MODEL": "llm_model",
|
||||
"LLM_API_BASE": "llm_api_base",
|
||||
"LLM_PROVIDER": "llm_provider",
|
||||
}
|
||||
llm_changed = False
|
||||
for env_var, settings_key in llm_env_map.items():
|
||||
env_val = os.environ.get(env_var)
|
||||
if env_val:
|
||||
logger.info("Found %s environment variable", env_var)
|
||||
self.settings[settings_key] = env_val
|
||||
llm_changed = True
|
||||
if llm_changed:
|
||||
self._save_settings()
|
||||
|
||||
def _default_settings_actions(self) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
{
|
||||
@@ -1568,7 +1590,7 @@ class SettingsManager:
|
||||
previous_dir = os.path.dirname(previous_path) or target_dir
|
||||
|
||||
if os.path.abspath(previous_path) != os.path.abspath(target_path):
|
||||
self._copy_model_cache_directory(previous_dir, target_dir)
|
||||
self._migrate_settings_directory_content(previous_dir, target_dir)
|
||||
logger.info("Switching settings file to: %s", target_path)
|
||||
|
||||
self._pending_portable_switch = {"other_path": other_path}
|
||||
@@ -1603,46 +1625,52 @@ class SettingsManager:
|
||||
finally:
|
||||
self._pending_portable_switch = None
|
||||
|
||||
def _copy_model_cache_directory(self, source_dir: str, target_dir: str) -> None:
|
||||
"""Copy model_cache artifacts when switching storage locations."""
|
||||
def _migrate_settings_directory_content(
|
||||
self, source_dir: str, target_dir: str
|
||||
) -> None:
|
||||
"""Migrate settings directory subdirectories when switching storage locations.
|
||||
|
||||
Copies the canonical subdirectories (cache, backups, logs, stats, wildcards)
|
||||
from the old settings directory to the new one. Legacy cache artifacts
|
||||
(model_cache, recipe_cache, etc.) are migrated lazily by
|
||||
``resolve_cache_path_with_migration`` on first access.
|
||||
|
||||
Args:
|
||||
source_dir: The previous settings directory path.
|
||||
target_dir: The new settings directory path.
|
||||
"""
|
||||
|
||||
if not source_dir or not target_dir:
|
||||
return
|
||||
|
||||
source_cache_dir = os.path.join(source_dir, "model_cache")
|
||||
target_cache_dir = os.path.join(target_dir, "model_cache")
|
||||
if os.path.isdir(source_cache_dir) and os.path.abspath(
|
||||
source_cache_dir
|
||||
) != os.path.abspath(target_cache_dir):
|
||||
try:
|
||||
shutil.copytree(
|
||||
source_cache_dir,
|
||||
target_cache_dir,
|
||||
dirs_exist_ok=True,
|
||||
ignore=shutil.ignore_patterns("*.sqlite-shm", "*.sqlite-wal"),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to copy model_cache directory from %s to %s: %s",
|
||||
source_cache_dir,
|
||||
target_cache_dir,
|
||||
exc,
|
||||
)
|
||||
def _copy_dir(name: str) -> None:
|
||||
source = os.path.join(source_dir, name)
|
||||
target = os.path.join(target_dir, name)
|
||||
if os.path.isdir(source) and os.path.abspath(source) != os.path.abspath(
|
||||
target
|
||||
):
|
||||
try:
|
||||
shutil.copytree(
|
||||
source,
|
||||
target,
|
||||
dirs_exist_ok=True,
|
||||
ignore=shutil.ignore_patterns("*.sqlite-shm", "*.sqlite-wal"),
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to copy directory %s from %s to %s: %s",
|
||||
name,
|
||||
source,
|
||||
target,
|
||||
exc,
|
||||
)
|
||||
|
||||
source_cache_file = os.path.join(source_dir, "model_cache.sqlite")
|
||||
target_cache_file = os.path.join(target_dir, "model_cache.sqlite")
|
||||
if os.path.isfile(source_cache_file) and os.path.abspath(
|
||||
source_cache_file
|
||||
) != os.path.abspath(target_cache_file):
|
||||
try:
|
||||
shutil.copy2(source_cache_file, target_cache_file)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Failed to copy model_cache.sqlite from %s to %s: %s",
|
||||
source_cache_file,
|
||||
target_cache_file,
|
||||
exc,
|
||||
)
|
||||
# Managed subdirectories under settings_dir
|
||||
_copy_dir("cache")
|
||||
_copy_dir("backups")
|
||||
_copy_dir("logs")
|
||||
_copy_dir("stats")
|
||||
_copy_dir("wildcards")
|
||||
|
||||
def _get_user_config_directory(self) -> str:
|
||||
"""Return the user configuration directory, falling back to ~/.config."""
|
||||
|
||||
@@ -47,6 +47,20 @@ SUPPORTED_MEDIA_EXTENSIONS = {
|
||||
"videos": [".mp4", ".webm"],
|
||||
}
|
||||
|
||||
# Model weight file extensions recognised by scanners.
|
||||
# This is the union of all scanner extensions (lora, checkpoint, embedding).
|
||||
MODEL_FILE_EXTENSIONS = {
|
||||
".safetensors",
|
||||
".ckpt",
|
||||
".pt",
|
||||
".pt2",
|
||||
".bin",
|
||||
".pth",
|
||||
".pkl",
|
||||
".sft",
|
||||
".gguf",
|
||||
}
|
||||
|
||||
# Valid sub-types for each scanner type
|
||||
VALID_LORA_SUB_TYPES = ["lora", "locon", "dora"]
|
||||
VALID_CHECKPOINT_SUB_TYPES = ["checkpoint", "diffusion_model"]
|
||||
@@ -215,5 +229,6 @@ SUPPORTED_DOWNLOAD_SKIP_BASE_MODELS = frozenset(
|
||||
"Ernie",
|
||||
"Ernie Turbo",
|
||||
"Nucleus",
|
||||
"Krea 2",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -35,6 +35,9 @@ class BaseModelMetadata:
|
||||
metadata_source: Optional[str] = None # Last provider that supplied metadata
|
||||
last_checked_at: float = 0 # Last checked timestamp
|
||||
hash_status: str = "completed" # Hash calculation status: pending | calculating | completed | failed
|
||||
trainedWords: List[str] = field(
|
||||
default_factory=list
|
||||
) # Trigger words / activation prompts (source-agnostic)
|
||||
_unknown_fields: Dict[str, Any] = field(
|
||||
default_factory=dict, repr=False, compare=False
|
||||
) # Store unknown fields
|
||||
@@ -47,6 +50,9 @@ class BaseModelMetadata:
|
||||
if self.tags is None:
|
||||
self.tags = []
|
||||
|
||||
if self.trainedWords is None:
|
||||
self.trainedWords = []
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "BaseModelMetadata":
|
||||
"""Create instance from dictionary"""
|
||||
|
||||
@@ -444,16 +444,161 @@
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.base-model-selector {
|
||||
width: 100%;
|
||||
padding: 3px 5px;
|
||||
/* ── Base Model Search Dropdown ─────────────────────────────────────────── */
|
||||
|
||||
.base-model-search-wrapper {
|
||||
position: relative;
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
z-index: 100;
|
||||
}
|
||||
|
||||
.base-model-search-input-wrapper {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
background: var(--bg-color);
|
||||
border: 1px solid var(--lora-accent);
|
||||
border-radius: var(--border-radius-xs);
|
||||
padding: 0 6px;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.base-model-search-input-wrapper .search-icon {
|
||||
color: var(--text-color);
|
||||
opacity: 0.45;
|
||||
font-size: 12px;
|
||||
flex-shrink: 0;
|
||||
pointer-events: none;
|
||||
/* Reset global .search-icon rules from search-filter.css */
|
||||
position: static;
|
||||
right: auto;
|
||||
top: auto;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
.base-model-search-input {
|
||||
flex: 1;
|
||||
background: transparent;
|
||||
border: none;
|
||||
outline: none;
|
||||
color: var(--text-color);
|
||||
font-size: 0.9em;
|
||||
outline: none;
|
||||
margin-right: var(--space-1);
|
||||
padding: 3px 0;
|
||||
width: 100%;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.base-model-search-input::placeholder {
|
||||
color: var(--text-color);
|
||||
opacity: 0.35;
|
||||
}
|
||||
|
||||
.base-model-dropdown {
|
||||
position: absolute;
|
||||
top: 100%;
|
||||
left: -1px;
|
||||
right: -1px;
|
||||
max-height: 270px;
|
||||
overflow-y: auto;
|
||||
background: var(--bg-color);
|
||||
border: 1px solid var(--lora-border);
|
||||
border-top: none;
|
||||
border-radius: 0 0 var(--border-radius-xs) var(--border-radius-xs);
|
||||
box-shadow: 0 8px 24px rgba(0, 0, 0, 0.22);
|
||||
z-index: 101;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .base-model-dropdown {
|
||||
box-shadow: 0 8px 28px rgba(0, 0, 0, 0.5);
|
||||
}
|
||||
|
||||
/* Dropdown scrollbar styling */
|
||||
.base-model-dropdown::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
}
|
||||
|
||||
.base-model-dropdown::-webkit-scrollbar-thumb {
|
||||
background: var(--lora-border);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.base-model-dropdown::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
/* Section */
|
||||
.base-model-dropdown-section {
|
||||
border-bottom: 1px solid var(--lora-border);
|
||||
}
|
||||
|
||||
.base-model-dropdown-section:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
/* Section header */
|
||||
.base-model-dropdown-header {
|
||||
padding: 5px 10px;
|
||||
font-size: 0.72em;
|
||||
font-weight: 600;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.08em;
|
||||
color: var(--text-color);
|
||||
opacity: 0.5;
|
||||
background: var(--surface-subtle);
|
||||
position: sticky;
|
||||
top: 0;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.base-model-dropdown-header.suggested-header {
|
||||
color: var(--lora-accent);
|
||||
opacity: 1;
|
||||
background: oklch(from var(--lora-accent) l c h / 0.08);
|
||||
}
|
||||
|
||||
.base-model-dropdown-header.suggested-header i {
|
||||
margin-right: 4px;
|
||||
font-size: 0.85em;
|
||||
}
|
||||
|
||||
/* Dropdown items */
|
||||
.base-model-dropdown-item {
|
||||
padding: 5px 12px;
|
||||
cursor: pointer;
|
||||
font-size: 0.9em;
|
||||
color: var(--text-color);
|
||||
transition: background 0.1s;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
.base-model-dropdown-item:hover {
|
||||
background: oklch(from var(--lora-accent) l c h / 0.1);
|
||||
}
|
||||
|
||||
.base-model-dropdown-item.active {
|
||||
background: oklch(from var(--lora-accent) l c h / 0.16);
|
||||
}
|
||||
|
||||
.base-model-dropdown-item.selected {
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.base-model-dropdown-item.selected::after {
|
||||
content: '✓';
|
||||
float: right;
|
||||
color: var(--lora-accent);
|
||||
margin-left: 8px;
|
||||
}
|
||||
|
||||
/* Empty state */
|
||||
.base-model-dropdown-empty {
|
||||
padding: 18px 12px;
|
||||
text-align: center;
|
||||
color: var(--text-color);
|
||||
opacity: 0.4;
|
||||
font-size: 0.88em;
|
||||
}
|
||||
|
||||
.size-wrapper {
|
||||
|
||||
@@ -821,4 +821,66 @@
|
||||
|
||||
[data-theme="dark"] .batch-preview-item {
|
||||
background: var(--lora-surface);
|
||||
}
|
||||
}
|
||||
|
||||
.hf-badge {
|
||||
display: inline-block;
|
||||
padding: 1px 6px;
|
||||
border-radius: 8px;
|
||||
background: oklch(0.55 0.12 250 / 0.15);
|
||||
color: oklch(0.7 0.12 250);
|
||||
font-size: 0.75em;
|
||||
font-weight: 600;
|
||||
margin-left: 4px;
|
||||
}
|
||||
|
||||
|
||||
/* Checkbox inside HF batch preview items */
|
||||
.batch-preview-checkbox {
|
||||
width: 18px;
|
||||
height: 18px;
|
||||
cursor: pointer;
|
||||
accent-color: var(--lora-accent);
|
||||
flex-shrink: 0;
|
||||
padding: 0;
|
||||
border: none;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* Select All toolbar in batch preview */
|
||||
.batch-preview-select-all {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 8px 12px;
|
||||
border-bottom: 1px solid var(--border-color);
|
||||
background: var(--lora-surface);
|
||||
cursor: pointer;
|
||||
position: sticky;
|
||||
top: 0;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.batch-preview-select-all input[type="checkbox"] {
|
||||
width: 18px;
|
||||
height: 18px;
|
||||
cursor: pointer;
|
||||
accent-color: var(--lora-accent);
|
||||
flex-shrink: 0;
|
||||
padding: 0;
|
||||
border: none;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.batch-preview-select-all label {
|
||||
cursor: pointer;
|
||||
font-size: 0.9em;
|
||||
color: var(--text-color);
|
||||
font-weight: 500;
|
||||
margin: 0;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .batch-preview-select-all {
|
||||
background: var(--lora-surface);
|
||||
}
|
||||
|
||||
@@ -190,6 +190,12 @@ export const DOWNLOAD_ENDPOINTS = {
|
||||
exampleImages: '/api/lm/force-download-example-images' // New endpoint for downloading example images
|
||||
};
|
||||
|
||||
// Hugging Face API endpoints
|
||||
export const HF_ENDPOINTS = {
|
||||
repoFiles: '/api/lm/hf-repo-files',
|
||||
download: '/api/lm/download-hf-model',
|
||||
};
|
||||
|
||||
// WebSocket endpoints
|
||||
export const WS_ENDPOINTS = {
|
||||
fetchProgress: '/ws/fetch-progress'
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
getCurrentModelType,
|
||||
isValidModelType,
|
||||
DOWNLOAD_ENDPOINTS,
|
||||
HF_ENDPOINTS,
|
||||
WS_ENDPOINTS
|
||||
} from './apiConfig.js';
|
||||
import { resetAndReload } from './modelApiFactory.js';
|
||||
@@ -1243,6 +1244,48 @@ export class BaseModelApiClient {
|
||||
}
|
||||
}
|
||||
|
||||
async fetchHfRepoFiles(repo, revision = 'main') {
|
||||
try {
|
||||
const params = new URLSearchParams({ repo, revision });
|
||||
const response = await fetch(`${HF_ENDPOINTS.repoFiles}?${params}`);
|
||||
if (!response.ok) {
|
||||
const err = await response.json().catch(() => ({}));
|
||||
throw new Error(err.error || 'Failed to fetch HF repo files');
|
||||
}
|
||||
return await response.json();
|
||||
} catch (error) {
|
||||
console.error('Error fetching HF repo files:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async downloadHfModel({ repo, filename, revision, modelRoot, relativePath, useDefaultPaths, download_id }) {
|
||||
try {
|
||||
const response = await fetch(HF_ENDPOINTS.download, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({
|
||||
repo,
|
||||
filename,
|
||||
revision: revision || 'main',
|
||||
model_root: modelRoot,
|
||||
relative_path: relativePath || '',
|
||||
use_default_paths: useDefaultPaths || false,
|
||||
...(download_id ? { download_id } : {}),
|
||||
})
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(await response.text());
|
||||
}
|
||||
|
||||
return await response.json();
|
||||
} catch (error) {
|
||||
console.error('Error downloading HF model:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
_buildQueryParams(baseParams, pageState) {
|
||||
const params = new URLSearchParams(baseParams);
|
||||
const isExcludedView = pageState.viewMode === 'excluded';
|
||||
|
||||
@@ -274,6 +274,9 @@ export class BulkContextMenu extends BaseContextMenu {
|
||||
case 'resume-metadata-refresh':
|
||||
bulkManager.setSkipMetadataRefresh(false);
|
||||
break;
|
||||
case 'enrich-hf-agent-bulk':
|
||||
this.enrichBulkWithAgent();
|
||||
break;
|
||||
case 'delete-all':
|
||||
bulkManager.showBulkDeleteModal();
|
||||
break;
|
||||
@@ -363,4 +366,66 @@ export class BulkContextMenu extends BaseContextMenu {
|
||||
console.error('Bulk download example images failed:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Enrich metadata for selected models via LLM agent skill.
|
||||
*/
|
||||
async enrichBulkWithAgent() {
|
||||
if (state.selectedModels.size === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { agentManager } = await import('../../managers/AgentManager.js');
|
||||
|
||||
// Check if LLM is configured
|
||||
const configured = await agentManager.isLlmConfigured();
|
||||
if (!configured) {
|
||||
showToast('toast.agent.llmNotConfigured', {}, 'warning');
|
||||
return;
|
||||
}
|
||||
|
||||
const modelPaths = [...state.selectedModels];
|
||||
|
||||
// Connect WebSocket for progress
|
||||
agentManager.connect();
|
||||
|
||||
// Set up one-time completion handler
|
||||
const onComplete = (data) => {
|
||||
const idx = agentManager.completeCallbacks.indexOf(onComplete);
|
||||
if (idx >= 0) agentManager.completeCallbacks.splice(idx, 1);
|
||||
|
||||
if (data.status === 'completed') {
|
||||
showToast(
|
||||
'toast.agent.enrichComplete',
|
||||
{ summary: data.summary || 'Done' },
|
||||
'success'
|
||||
);
|
||||
// Soft reload to reflect updated metadata
|
||||
window.location.reload();
|
||||
} else if (data.status === 'error') {
|
||||
showToast(
|
||||
'toast.agent.enrichFailed',
|
||||
{ error: data.error || 'Unknown error' },
|
||||
'error'
|
||||
);
|
||||
}
|
||||
};
|
||||
agentManager.onComplete(onComplete);
|
||||
|
||||
showToast(
|
||||
'toast.agent.enrichStarted',
|
||||
{ count: modelPaths.length },
|
||||
'info'
|
||||
);
|
||||
|
||||
try {
|
||||
await agentManager.executeSkill('enrich_hf_metadata', modelPaths);
|
||||
} catch (error) {
|
||||
showToast(
|
||||
'toast.agent.enrichFailed',
|
||||
{ error: error.message },
|
||||
'error'
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { BaseContextMenu } from './BaseContextMenu.js';
|
||||
import { ModelContextMenuMixin } from './ModelContextMenuMixin.js';
|
||||
import { getModelApiClient, resetAndReload } from '../../api/modelApiFactory.js';
|
||||
import { copyLoraSyntax, sendLoraToWorkflow, buildLoraSyntax } from '../../utils/uiHelpers.js';
|
||||
import { copyLoraSyntax, sendLoraToWorkflow, buildLoraSyntax, showToast } from '../../utils/uiHelpers.js';
|
||||
import { showExcludeModal, showDeleteModal } from '../../utils/modalUtils.js';
|
||||
import { moveManager } from '../../managers/MoveManager.js';
|
||||
|
||||
@@ -63,6 +63,9 @@ export class LoraContextMenu extends BaseContextMenu {
|
||||
case 'refresh-metadata':
|
||||
getModelApiClient().refreshSingleModelMetadata(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
case 'enrich-hf-agent':
|
||||
this.enrichWithAgent(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
case 'exclude':
|
||||
showExcludeModal(this.currentCard.dataset.filepath);
|
||||
break;
|
||||
@@ -72,6 +75,46 @@ export class LoraContextMenu extends BaseContextMenu {
|
||||
}
|
||||
}
|
||||
|
||||
async enrichWithAgent(filePath) {
|
||||
const { agentManager } = await import('../../managers/AgentManager.js');
|
||||
|
||||
// Check if LLM is configured
|
||||
const configured = await agentManager.isLlmConfigured();
|
||||
if (!configured) {
|
||||
showToast('toast.agent.llmNotConfigured', {}, 'warning');
|
||||
return;
|
||||
}
|
||||
|
||||
// Connect WebSocket for progress
|
||||
agentManager.connect();
|
||||
|
||||
// Set up one-time completion handler
|
||||
const onComplete = (data) => {
|
||||
const idx = agentManager.completeCallbacks.indexOf(onComplete);
|
||||
if (idx >= 0) agentManager.completeCallbacks.splice(idx, 1);
|
||||
|
||||
if (data.status === 'completed') {
|
||||
showToast('toast.agent.enrichComplete', { summary: data.summary || 'Done' }, 'success');
|
||||
// Soft reload to reflect updated metadata
|
||||
if (typeof resetAndReload === 'function') {
|
||||
resetAndReload();
|
||||
}
|
||||
} else if (data.status === 'error') {
|
||||
showToast('toast.agent.enrichFailed', { error: data.error || 'Unknown error' }, 'error');
|
||||
}
|
||||
};
|
||||
agentManager.onComplete(onComplete);
|
||||
|
||||
// Show progress toast
|
||||
showToast('toast.agent.enrichStarted', {}, 'info');
|
||||
|
||||
try {
|
||||
await agentManager.executeSkill('enrich_hf_metadata', [filePath]);
|
||||
} catch (error) {
|
||||
showToast('toast.agent.enrichFailed', { error: error.message }, 'error');
|
||||
}
|
||||
}
|
||||
|
||||
sendLoraToWorkflow(replaceMode) {
|
||||
const card = this.currentCard;
|
||||
const usageTips = JSON.parse(card.dataset.usage_tips || '{}');
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { showToast, openCivitai, copyToClipboard, copyLoraSyntax, sendLoraToWorkflow, sendEmbeddingToWorkflow, openExampleImagesFolder, buildLoraSyntax, sendModelPathToWorkflow } from '../../utils/uiHelpers.js';
|
||||
import { showToast, openCivitai, openHuggingFace, copyToClipboard, copyLoraSyntax, sendLoraToWorkflow, sendEmbeddingToWorkflow, openExampleImagesFolder, buildLoraSyntax, sendModelPathToWorkflow } from '../../utils/uiHelpers.js';
|
||||
import { state, getCurrentPageState } from '../../state/index.js';
|
||||
import { showModelModal } from './ModelModal.js';
|
||||
import { toggleShowcase } from './showcase/ShowcaseView.js';
|
||||
@@ -66,6 +66,8 @@ function handleModelCardEvent_internal(event, modelType) {
|
||||
event.stopPropagation();
|
||||
if (card.dataset.from_civitai === 'true') {
|
||||
openCivitai(card.dataset.filepath);
|
||||
} else if (card.dataset.hf_url) {
|
||||
openHuggingFace(card.dataset.hf_url);
|
||||
}
|
||||
return true; // Stop propagation
|
||||
}
|
||||
@@ -313,6 +315,7 @@ async function showModelModalFromCard(card, modelType) {
|
||||
modified: card.dataset.modified,
|
||||
file_size: parseInt(card.dataset.file_size || '0'),
|
||||
from_civitai: card.dataset.from_civitai === 'true',
|
||||
hf_url: card.dataset.hf_url || '',
|
||||
base_model: card.dataset.base_model,
|
||||
notes: card.dataset.notes || '',
|
||||
favorite: card.dataset.favorite === 'true',
|
||||
@@ -401,6 +404,7 @@ function showExampleAccessModal(card, modelType) {
|
||||
modified: card.dataset.modified,
|
||||
file_size: card.dataset.file_size,
|
||||
from_civitai: card.dataset.from_civitai === 'true',
|
||||
hf_url: card.dataset.hf_url || '',
|
||||
base_model: card.dataset.base_model,
|
||||
notes: card.dataset.notes,
|
||||
favorite: card.dataset.favorite === 'true',
|
||||
@@ -467,6 +471,7 @@ export function createModelCard(model, modelType) {
|
||||
card.dataset.base_model = model.base_model || 'Unknown';
|
||||
card.dataset.favorite = model.favorite ? 'true' : 'false';
|
||||
card.dataset.exclude = model.exclude ? 'true' : 'false';
|
||||
card.dataset.hf_url = model.hf_url || '';
|
||||
const hasUpdateAvailable = Boolean(model.update_available);
|
||||
card.dataset.update_available = hasUpdateAvailable ? 'true' : 'false';
|
||||
card.dataset.skip_metadata_refresh = model.skip_metadata_refresh ? 'true' : 'false';
|
||||
@@ -578,7 +583,10 @@ export function createModelCard(model, modelType) {
|
||||
translate('modelCard.actions.addToFavorites', {}, 'Add to favorites');
|
||||
const globeTitle = model.from_civitai ?
|
||||
translate('modelCard.actions.viewOnCivitai', {}, 'View on Civitai') :
|
||||
translate('modelCard.actions.notAvailableFromCivitai', {}, 'Not available from Civitai');
|
||||
model.hf_url ?
|
||||
translate('modelCard.actions.viewOnHuggingFace', {}, 'View on Hugging Face') :
|
||||
translate('modelCard.actions.notAvailableFromCivitai', {}, 'Not available from Civitai');
|
||||
const globeEnabled = model.from_civitai || !!model.hf_url;
|
||||
let sendTitle;
|
||||
let copyTitle;
|
||||
if (modelType === MODEL_TYPES.LORA) {
|
||||
@@ -603,7 +611,7 @@ export function createModelCard(model, modelType) {
|
||||
</i>
|
||||
<i class="fas fa-globe"
|
||||
title="${globeTitle}"
|
||||
${!model.from_civitai ? 'style="opacity: 0.5; cursor: not-allowed"' : ''}>
|
||||
${!globeEnabled ? 'style="opacity: 0.5; cursor: not-allowed"' : ''}>
|
||||
</i>
|
||||
<i class="fas fa-paper-plane"
|
||||
title="${sendTitle}">
|
||||
|
||||
@@ -3,9 +3,75 @@
|
||||
* Handles model metadata editing functionality - General version
|
||||
*/
|
||||
|
||||
import { BASE_MODEL_CATEGORIES } from '../../utils/constants.js';
|
||||
import { BASE_MODEL_CATEGORIES, getMergedBaseModels } from '../../utils/constants.js';
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
import { getModelApiClient } from '../../api/modelApiFactory.js';
|
||||
import { translate } from '../../utils/i18nHelpers.js';
|
||||
|
||||
// ── Filename-based base model inference ──────────────────────────────────────
|
||||
// Rules are ordered by specificity — first match wins for dedup.
|
||||
// Each rule checks the filename (lowercased) for a regex pattern and suggests
|
||||
// the associated base model values.
|
||||
|
||||
const BASE_MODEL_FILENAME_RULES = [
|
||||
{ pattern: /flux\.?\s*2\s*klein/i, models: ['Flux.2 Klein 9B', 'Flux.2 Klein 9B-base', 'Flux.2 Klein 4B', 'Flux.2 Klein 4B-base'] },
|
||||
{ pattern: /flux\.?\s*2/i, models: ['Flux.2 D', 'Flux.2 Klein 9B', 'Flux.2 Klein 4B'] },
|
||||
{ pattern: /flux\.?\s*1\s*(dev|d)\b/i, models: ['Flux.1 D'] },
|
||||
{ pattern: /flux\.?\s*1\s*(schnell|s)\b/i, models: ['Flux.1 S'] },
|
||||
{ pattern: /flux/i, models: ['Flux.1 D', 'Flux.1 S', 'Flux.2 D'] },
|
||||
{ pattern: /sdxl/i, models: ['SDXL 1.0', 'SDXL Lightning', 'SDXL Hyper'] },
|
||||
{ pattern: /sd\s*1[._-\s]?5/i, models: ['SD 1.5'] },
|
||||
{ pattern: /sd\s*1[._-\s]?4/i, models: ['SD 1.4'] },
|
||||
{ pattern: /sd\s*1/i, models: ['SD 1.5', 'SD 1.4', 'SD 1.5 LCM', 'SD 1.5 Hyper'] },
|
||||
{ pattern: /sd\s*3[._-\s]?5/i, models: ['SD 3.5', 'SD 3.5 Medium', 'SD 3.5 Large', 'SD 3.5 Large Turbo'] },
|
||||
{ pattern: /sd\s*3/i, models: ['SD 3', 'SD 3.5'] },
|
||||
{ pattern: /wan\s*\.?\s*video/i, models: ['Wan Video', 'Wan Video 1.3B t2v', 'Wan Video 14B t2v', 'Wan Video 14B i2v 480p', 'Wan Video 14B i2v 720p'] },
|
||||
{ pattern: /hunyuan\s*\.?\s*video/i, models: ['Hunyuan Video'] },
|
||||
{ pattern: /ltxv/i, models: ['LTXV', 'LTXV2', 'LTXV 2.3'] },
|
||||
{ pattern: /cogvideo/i, models: ['CogVideoX'] },
|
||||
{ pattern: /pony/i, models: ['Pony', 'Pony V7'] },
|
||||
{ pattern: /illustrious/i, models: ['Illustrious'] },
|
||||
{ pattern: /noobai/i, models: ['NoobAI'] },
|
||||
{ pattern: /pixart/i, models: ['PixArt a', 'PixArt E'] },
|
||||
{ pattern: /aura\s*\.?\s*flow/i, models: ['AuraFlow'] },
|
||||
{ pattern: /kolors/i, models: ['Kolors'] },
|
||||
{ pattern: /hunyuan\s*1/i, models: ['Hunyuan 1'] },
|
||||
{ pattern: /lumina/i, models: ['Lumina'] },
|
||||
{ pattern: /hidream/i, models: ['HiDream'] },
|
||||
{ pattern: /qwen/i, models: ['Qwen'] },
|
||||
{ pattern: /chroma/i, models: ['Chroma'] },
|
||||
{ pattern: /anima/i, models: ['Anima'] },
|
||||
{ pattern: /sd\s*2[._-\s]?[01]/i, models: ['SD 2.0', 'SD 2.1'] },
|
||||
{ pattern: /mochi/i, models: ['Mochi'] },
|
||||
{ pattern: /svd/i, models: ['SVD'] },
|
||||
{ pattern: /zimage/i, models: ['ZImageTurbo', 'ZImageBase'] },
|
||||
{ pattern: /nucleus/i, models: ['Nucleus'] },
|
||||
{ pattern: /krea/i, models: ['Flux.1 Krea', 'Krea 2'] },
|
||||
{ pattern: /ernie/i, models: ['Ernie', 'Ernie Turbo'] },
|
||||
];
|
||||
|
||||
/**
|
||||
* Infer likely base model(s) from a filename + model name string.
|
||||
* Returns a deduplicated array in match-priority order.
|
||||
* @param {string} filename
|
||||
* @returns {string[]}
|
||||
*/
|
||||
function inferBaseModelsFromFilename(filename) {
|
||||
if (!filename || typeof filename !== 'string') return [];
|
||||
const seen = new Set();
|
||||
const results = [];
|
||||
for (const rule of BASE_MODEL_FILENAME_RULES) {
|
||||
if (rule.pattern.test(filename)) {
|
||||
for (const model of rule.models) {
|
||||
if (!seen.has(model)) {
|
||||
seen.add(model);
|
||||
results.push(model);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve the active file path for the currently open model modal.
|
||||
@@ -226,7 +292,9 @@ export function setupModelNameEditing(filePath) {
|
||||
}
|
||||
|
||||
/**
|
||||
* Set up base model editing functionality
|
||||
* Set up base model editing functionality with searchable dropdown
|
||||
* Shows filename-inferred suggestions at the top, supports keyboard navigation,
|
||||
* and allows typing custom values.
|
||||
* @param {string} filePath - File path
|
||||
*/
|
||||
export function setupBaseModelEditing(filePath) {
|
||||
@@ -257,98 +325,251 @@ export function setupBaseModelEditing(filePath) {
|
||||
// Store the original value to check for changes later
|
||||
const originalValue = baseModelContent.textContent.trim();
|
||||
|
||||
// Create dropdown selector to replace the base model content
|
||||
const currentValue = originalValue;
|
||||
const dropdown = document.createElement('select');
|
||||
dropdown.className = 'base-model-selector';
|
||||
// ── Build the full option list ────────────────────────────────────────
|
||||
const allModels = []; // { value, label, category }
|
||||
const categorizedModels = new Set();
|
||||
|
||||
// Flag to track if a change was made
|
||||
let valueChanged = false;
|
||||
|
||||
// Add options from BASE_MODEL_CATEGORIES constants
|
||||
const baseModelCategories = BASE_MODEL_CATEGORIES;
|
||||
|
||||
// Create option groups for better organization
|
||||
Object.entries(baseModelCategories).forEach(([category, models]) => {
|
||||
const group = document.createElement('optgroup');
|
||||
group.label = category;
|
||||
|
||||
Object.entries(BASE_MODEL_CATEGORIES).forEach(([category, models]) => {
|
||||
models.forEach(model => {
|
||||
const option = document.createElement('option');
|
||||
option.value = model;
|
||||
option.textContent = model;
|
||||
option.selected = model === currentValue;
|
||||
group.appendChild(option);
|
||||
allModels.push({ value: model, label: model, category });
|
||||
categorizedModels.add(model);
|
||||
});
|
||||
});
|
||||
|
||||
const mergedModels = getMergedBaseModels();
|
||||
const uncategorizedModels = mergedModels.filter(model => !categorizedModels.has(model));
|
||||
if (uncategorizedModels.length > 0) {
|
||||
uncategorizedModels.forEach(model => {
|
||||
allModels.push({ value: model, label: model, category: 'Other (API)' });
|
||||
});
|
||||
}
|
||||
|
||||
// ── Filename-based inference ──────────────────────────────────────────
|
||||
const fileName = (document.querySelector('.file-name-content')?.textContent || '') + ' ' +
|
||||
(document.querySelector('.model-name-content')?.textContent || '');
|
||||
const inferredModels = inferBaseModelsFromFilename(fileName);
|
||||
const inferredSet = new Set(inferredModels);
|
||||
|
||||
// ── Build search widget DOM ───────────────────────────────────────────
|
||||
const wrapper = document.createElement('div');
|
||||
wrapper.className = 'base-model-search-wrapper';
|
||||
|
||||
// Search input row
|
||||
const inputWrapper = document.createElement('div');
|
||||
inputWrapper.className = 'base-model-search-input-wrapper';
|
||||
const searchIcon = document.createElement('i');
|
||||
searchIcon.className = 'fas fa-search search-icon';
|
||||
searchIcon.setAttribute('aria-hidden', 'true');
|
||||
inputWrapper.appendChild(searchIcon);
|
||||
const searchInput = document.createElement('input');
|
||||
searchInput.type = 'text';
|
||||
searchInput.className = 'base-model-search-input';
|
||||
searchInput.placeholder = translate('modals.model.metadata.baseModelSearchPlaceholder', {}, 'Search base model…');
|
||||
searchInput.autocomplete = 'off';
|
||||
searchInput.spellcheck = false;
|
||||
inputWrapper.appendChild(searchInput);
|
||||
wrapper.appendChild(inputWrapper);
|
||||
|
||||
// Dropdown list
|
||||
const dropdown = document.createElement('div');
|
||||
dropdown.className = 'base-model-dropdown';
|
||||
wrapper.appendChild(dropdown);
|
||||
|
||||
// ── Render ────────────────────────────────────────────────────────────
|
||||
function renderDropdown(filterText) {
|
||||
const lowerFilter = (filterText || '').toLowerCase().trim();
|
||||
dropdown.innerHTML = '';
|
||||
let hasVisibleItems = false;
|
||||
const fragment = document.createDocumentFragment();
|
||||
|
||||
// 1. Suggested section (filename-inferred, filtered by search)
|
||||
let suggestedToShow = inferredModels;
|
||||
if (lowerFilter) {
|
||||
suggestedToShow = inferredModels.filter(m =>
|
||||
m.toLowerCase().includes(lowerFilter)
|
||||
);
|
||||
}
|
||||
|
||||
if (suggestedToShow.length > 0) {
|
||||
const section = document.createElement('div');
|
||||
section.className = 'base-model-dropdown-section';
|
||||
|
||||
const header = document.createElement('div');
|
||||
header.className = 'base-model-dropdown-header suggested-header';
|
||||
header.innerHTML = '<i class="fas fa-star" aria-hidden="true"></i> ' +
|
||||
translate('modals.model.metadata.baseModelSuggested', {}, 'Suggested');
|
||||
section.appendChild(header);
|
||||
|
||||
suggestedToShow.forEach(model => {
|
||||
const item = document.createElement('div');
|
||||
item.className = 'base-model-dropdown-item';
|
||||
if (model === originalValue) item.classList.add('selected');
|
||||
item.dataset.value = model;
|
||||
item.textContent = model;
|
||||
section.appendChild(item);
|
||||
hasVisibleItems = true;
|
||||
});
|
||||
|
||||
fragment.appendChild(section);
|
||||
}
|
||||
|
||||
// 2. Categorized options (deduplicated against suggestions)
|
||||
const categoryMap = {};
|
||||
allModels.forEach(m => {
|
||||
if (inferredSet.has(m.value)) return; // already shown in Suggested
|
||||
if (lowerFilter && !m.label.toLowerCase().includes(lowerFilter)) return;
|
||||
if (!categoryMap[m.category]) categoryMap[m.category] = [];
|
||||
categoryMap[m.category].push(m);
|
||||
});
|
||||
|
||||
dropdown.appendChild(group);
|
||||
Object.entries(categoryMap).forEach(([category, items]) => {
|
||||
if (items.length === 0) return;
|
||||
const section = document.createElement('div');
|
||||
section.className = 'base-model-dropdown-section';
|
||||
|
||||
const header = document.createElement('div');
|
||||
header.className = 'base-model-dropdown-header';
|
||||
header.textContent = category;
|
||||
section.appendChild(header);
|
||||
|
||||
items.forEach(m => {
|
||||
const item = document.createElement('div');
|
||||
item.className = 'base-model-dropdown-item';
|
||||
if (m.value === originalValue) item.classList.add('selected');
|
||||
item.dataset.value = m.value;
|
||||
item.textContent = m.label;
|
||||
section.appendChild(item);
|
||||
hasVisibleItems = true;
|
||||
});
|
||||
|
||||
fragment.appendChild(section);
|
||||
});
|
||||
|
||||
// 3. Empty state
|
||||
if (!hasVisibleItems) {
|
||||
const empty = document.createElement('div');
|
||||
empty.className = 'base-model-dropdown-empty';
|
||||
empty.textContent = translate('modals.model.metadata.baseModelNoMatch', {}, 'No matching base models');
|
||||
fragment.appendChild(empty);
|
||||
}
|
||||
|
||||
dropdown.appendChild(fragment);
|
||||
|
||||
// Scroll the selected item into view
|
||||
const selected = dropdown.querySelector('.base-model-dropdown-item.selected');
|
||||
if (selected) {
|
||||
selected.scrollIntoView({ block: 'nearest' });
|
||||
}
|
||||
}
|
||||
|
||||
// Initial render — show everything
|
||||
renderDropdown('');
|
||||
|
||||
// ── Events ────────────────────────────────────────────────────────────
|
||||
let filterTimeout;
|
||||
searchInput.addEventListener('input', () => {
|
||||
clearTimeout(filterTimeout);
|
||||
filterTimeout = setTimeout(() => renderDropdown(searchInput.value), 50);
|
||||
});
|
||||
|
||||
// Replace content with dropdown
|
||||
// Click to select
|
||||
dropdown.addEventListener('click', (e) => {
|
||||
const item = e.target.closest('.base-model-dropdown-item');
|
||||
if (!item) return;
|
||||
baseModelContent.textContent = item.dataset.value;
|
||||
cleanup();
|
||||
const finalValue = baseModelContent.textContent.trim();
|
||||
if (finalValue !== originalValue) {
|
||||
saveBaseModel(
|
||||
getActiveModalFilePath(baseModelContent.dataset.filePath),
|
||||
originalValue
|
||||
);
|
||||
}
|
||||
});
|
||||
|
||||
// Replace content with search widget
|
||||
baseModelContent.style.display = 'none';
|
||||
baseModelDisplay.insertBefore(dropdown, editBtn);
|
||||
|
||||
// Hide edit button during editing
|
||||
editBtn.style.display = 'none';
|
||||
baseModelDisplay.insertBefore(wrapper, editBtn);
|
||||
searchInput.focus();
|
||||
|
||||
// Focus the dropdown
|
||||
dropdown.focus();
|
||||
|
||||
// Handle dropdown change
|
||||
dropdown.addEventListener('change', function() {
|
||||
const selectedModel = this.value;
|
||||
baseModelContent.textContent = selectedModel;
|
||||
|
||||
// Mark that a change was made if the value differs from original
|
||||
if (selectedModel !== originalValue) {
|
||||
valueChanged = true;
|
||||
} else {
|
||||
valueChanged = false;
|
||||
// ── Cleanup ───────────────────────────────────────────────────────────
|
||||
function cleanup() {
|
||||
if (wrapper.parentNode === baseModelDisplay) {
|
||||
baseModelDisplay.removeChild(wrapper);
|
||||
}
|
||||
});
|
||||
|
||||
// Function to save changes and exit edit mode
|
||||
const saveAndExit = function() {
|
||||
// Check if dropdown still exists and remove it
|
||||
if (dropdown && dropdown.parentNode === baseModelDisplay) {
|
||||
baseModelDisplay.removeChild(dropdown);
|
||||
}
|
||||
|
||||
// Show the content and edit button
|
||||
baseModelContent.style.display = '';
|
||||
editBtn.style.display = '';
|
||||
|
||||
// Remove editing class
|
||||
baseModelDisplay.classList.remove('editing');
|
||||
|
||||
// Only save if the value has actually changed
|
||||
if (valueChanged || baseModelContent.textContent.trim() !== originalValue) {
|
||||
const resolvedPath = getActiveModalFilePath(baseModelContent.dataset.filePath);
|
||||
saveBaseModel(resolvedPath, originalValue);
|
||||
}
|
||||
|
||||
// Remove this event listener
|
||||
document.removeEventListener('click', outsideClickHandler);
|
||||
};
|
||||
}
|
||||
|
||||
// Handle outside clicks to save and exit
|
||||
// Outside click → save typed/custom value if any
|
||||
const outsideClickHandler = function(e) {
|
||||
// If click is outside the dropdown and base model display
|
||||
if (!baseModelDisplay.contains(e.target)) {
|
||||
saveAndExit();
|
||||
if (wrapper.contains(e.target)) return;
|
||||
|
||||
// If user typed a custom value (not just empty), apply it
|
||||
const typedValue = searchInput.value.trim();
|
||||
if (typedValue) {
|
||||
baseModelContent.textContent = typedValue;
|
||||
}
|
||||
cleanup();
|
||||
const finalValue = baseModelContent.textContent.trim();
|
||||
if (finalValue !== originalValue) {
|
||||
saveBaseModel(
|
||||
getActiveModalFilePath(baseModelContent.dataset.filePath),
|
||||
originalValue
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Add delayed event listener for outside clicks
|
||||
// Defer listener to avoid the opening click itself
|
||||
setTimeout(() => {
|
||||
document.addEventListener('click', outsideClickHandler);
|
||||
}, 0);
|
||||
|
||||
// Also handle dropdown blur event
|
||||
dropdown.addEventListener('blur', function(e) {
|
||||
// Only save if the related target is not the edit button or inside the baseModelDisplay
|
||||
if (!baseModelDisplay.contains(e.relatedTarget)) {
|
||||
saveAndExit();
|
||||
// Keyboard navigation
|
||||
searchInput.addEventListener('keydown', function onKeydown(e) {
|
||||
const items = Array.from(dropdown.querySelectorAll('.base-model-dropdown-item'));
|
||||
const activeIdx = items.findIndex(el => el.classList.contains('active'));
|
||||
|
||||
if (e.key === 'ArrowDown') {
|
||||
e.preventDefault();
|
||||
items.forEach(el => el.classList.remove('active'));
|
||||
const next = Math.min(activeIdx + 1, items.length - 1);
|
||||
if (items[next]) {
|
||||
items[next].classList.add('active');
|
||||
items[next].scrollIntoView({ block: 'nearest' });
|
||||
}
|
||||
} else if (e.key === 'ArrowUp') {
|
||||
e.preventDefault();
|
||||
items.forEach(el => el.classList.remove('active'));
|
||||
const prev = Math.max(activeIdx - 1, 0);
|
||||
if (items[prev]) {
|
||||
items[prev].classList.add('active');
|
||||
items[prev].scrollIntoView({ block: 'nearest' });
|
||||
}
|
||||
} else if (e.key === 'Enter') {
|
||||
e.preventDefault();
|
||||
const activeItem = items.find(el => el.classList.contains('active'));
|
||||
if (activeItem) {
|
||||
activeItem.click();
|
||||
} else if (searchInput.value.trim()) {
|
||||
// Custom value typed
|
||||
baseModelContent.textContent = searchInput.value.trim();
|
||||
cleanup();
|
||||
const finalValue = baseModelContent.textContent.trim();
|
||||
if (finalValue !== originalValue) {
|
||||
saveBaseModel(
|
||||
getActiveModalFilePath(baseModelContent.dataset.filePath),
|
||||
originalValue
|
||||
);
|
||||
}
|
||||
}
|
||||
} else if (e.key === 'Escape') {
|
||||
e.preventDefault();
|
||||
baseModelContent.textContent = originalValue;
|
||||
cleanup();
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -360,6 +360,11 @@ export async function showModelModal(model, modelType) {
|
||||
const viewOnCivitaiAction = modelWithFullData.from_civitai ? `
|
||||
<div class="civitai-view" title="${translate('modals.model.actions.viewOnCivitai', {}, 'View on Civitai')}" data-action="view-civitai" data-filepath="${escapedFilePathAttr}">
|
||||
<i class="fas fa-globe"></i> ${translate('modals.model.actions.viewOnCivitaiText', {}, 'View on Civitai')}
|
||||
</div>`.trim() : '';
|
||||
const escapedHfUrl = modelWithFullData.hf_url ? escapeAttribute(modelWithFullData.hf_url) : '';
|
||||
const viewOnHuggingFaceAction = escapedHfUrl ? `
|
||||
<div class="civitai-view" title="${translate('modals.model.actions.viewOnHuggingFace', {}, 'View on Hugging Face')}" data-action="view-huggingface" data-hf-url="${escapedHfUrl}">
|
||||
<i class="fas fa-globe"></i> ${translate('modals.model.actions.viewOnHuggingFaceText', {}, 'View on Hugging Face')}
|
||||
</div>`.trim() : '';
|
||||
const creatorInfoAction = modelWithFullData.civitai?.creator ? `
|
||||
<div class="creator-info" data-username="${modelWithFullData.civitai.creator.username}" data-action="view-creator" title="${translate('modals.model.actions.viewCreatorProfile', {}, 'View Creator Profile')}">
|
||||
@@ -377,6 +382,9 @@ export async function showModelModal(model, modelType) {
|
||||
if (viewOnCivitaiAction) {
|
||||
creatorActionItems.push(indentMarkup(viewOnCivitaiAction, 24));
|
||||
}
|
||||
if (viewOnHuggingFaceAction) {
|
||||
creatorActionItems.push(indentMarkup(viewOnHuggingFaceAction, 24));
|
||||
}
|
||||
if (creatorInfoAction) {
|
||||
creatorActionItems.push(indentMarkup(creatorInfoAction, 24));
|
||||
}
|
||||
@@ -869,6 +877,11 @@ function setupEventHandlers(filePath, modelType) {
|
||||
case 'view-civitai':
|
||||
openCivitai(target.dataset.filepath);
|
||||
break;
|
||||
case 'view-huggingface':
|
||||
if (target.dataset.hfUrl) {
|
||||
window.open(target.dataset.hfUrl, '_blank', 'noopener,noreferrer');
|
||||
}
|
||||
break;
|
||||
case 'view-creator':
|
||||
const username = target.dataset.username;
|
||||
if (username) {
|
||||
|
||||
196
static/js/managers/AgentManager.js
Normal file
196
static/js/managers/AgentManager.js
Normal file
@@ -0,0 +1,196 @@
|
||||
/**
|
||||
* AgentManager — WebSocket listener for agent skill progress events.
|
||||
*
|
||||
* Connects to the generic WebSocket endpoint and filters for
|
||||
* `type: "agent_progress"` messages. Dispatches progress and completion
|
||||
* events to registered callbacks.
|
||||
*/
|
||||
class AgentManager {
|
||||
constructor() {
|
||||
this.websocket = null;
|
||||
this.progressCallbacks = [];
|
||||
this.completeCallbacks = [];
|
||||
this.errorCallbacks = [];
|
||||
this.connected = false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect to the WebSocket endpoint for agent progress events.
|
||||
* Safe to call multiple times — won't reconnect if already connected.
|
||||
*/
|
||||
connect() {
|
||||
if (this.connected && this.websocket?.readyState === WebSocket.OPEN) {
|
||||
return;
|
||||
}
|
||||
|
||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||
try {
|
||||
this.websocket = new WebSocket(
|
||||
`${wsProtocol}${window.location.host}/ws/fetch-progress`
|
||||
);
|
||||
} catch (e) {
|
||||
console.error('AgentManager: Failed to create WebSocket:', e);
|
||||
return;
|
||||
}
|
||||
|
||||
this.websocket.onopen = () => {
|
||||
this.connected = true;
|
||||
console.debug('AgentManager: WebSocket connected');
|
||||
};
|
||||
|
||||
this.websocket.onmessage = (event) => {
|
||||
try {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.type !== 'agent_progress') return;
|
||||
this._dispatch(data);
|
||||
} catch (e) {
|
||||
// Not JSON or wrong format — ignore
|
||||
}
|
||||
};
|
||||
|
||||
this.websocket.onerror = (error) => {
|
||||
console.error('AgentManager: WebSocket error:', error);
|
||||
this.connected = false;
|
||||
};
|
||||
|
||||
this.websocket.onclose = () => {
|
||||
this.connected = false;
|
||||
console.debug('AgentManager: WebSocket closed');
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispatch a parsed agent event to the appropriate callbacks.
|
||||
* @param {Object} data - The parsed WebSocket message
|
||||
*/
|
||||
_dispatch(data) {
|
||||
const { status, skill } = data;
|
||||
|
||||
if (status === 'error') {
|
||||
this.errorCallbacks.forEach((cb) => {
|
||||
try {
|
||||
cb(data);
|
||||
} catch (e) {
|
||||
console.error('AgentManager error callback failed:', e);
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
if (status === 'completed') {
|
||||
this.completeCallbacks.forEach((cb) => {
|
||||
try {
|
||||
cb(data);
|
||||
} catch (e) {
|
||||
console.error('AgentManager complete callback failed:', e);
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// started, processing — general progress
|
||||
this.progressCallbacks.forEach((cb) => {
|
||||
try {
|
||||
cb(data);
|
||||
} catch (e) {
|
||||
console.error('AgentManager progress callback failed:', e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a callback for progress events (started, processing).
|
||||
* @param {Function} callback - Receives the event data
|
||||
*/
|
||||
onProgress(callback) {
|
||||
this.progressCallbacks.push(callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a callback for completion events.
|
||||
* @param {Function} callback - Receives the event data
|
||||
*/
|
||||
onComplete(callback) {
|
||||
this.completeCallbacks.push(callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a callback for error events.
|
||||
* @param {Function} callback - Receives the event data
|
||||
*/
|
||||
onError(callback) {
|
||||
this.errorCallbacks.push(callback);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear all registered callbacks.
|
||||
*/
|
||||
clearCallbacks() {
|
||||
this.progressCallbacks = [];
|
||||
this.completeCallbacks = [];
|
||||
this.errorCallbacks = [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute an agent skill on the provided model paths.
|
||||
*
|
||||
* @param {string} skillName - The skill to execute
|
||||
* @param {string[]} modelPaths - Model file paths to process
|
||||
* @returns {Promise<Object>} The response JSON
|
||||
*/
|
||||
async executeSkill(skillName, modelPaths) {
|
||||
const response = await fetch(
|
||||
`/api/lm/agent/execute/${encodeURIComponent(skillName)}`,
|
||||
{
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
body: JSON.stringify({ model_paths: modelPaths }),
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => ({}));
|
||||
throw new Error(
|
||||
errorData.error || `HTTP ${response.status}: ${response.statusText}`
|
||||
);
|
||||
}
|
||||
|
||||
return response.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the LLM provider is configured.
|
||||
*
|
||||
* Returns true when both an API key and a model name are set.
|
||||
*
|
||||
* @returns {Promise<boolean>}
|
||||
*/
|
||||
async isLlmConfigured() {
|
||||
try {
|
||||
const response = await fetch('/api/lm/settings');
|
||||
if (!response.ok) return false;
|
||||
const data = await response.json();
|
||||
const provider = data.settings?.llm_provider;
|
||||
const hasModel = !!data.settings?.llm_model;
|
||||
const hasKey = !!data.settings?.llm_api_key;
|
||||
return hasModel && (hasKey || provider === 'ollama');
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the list of available agent skills.
|
||||
*
|
||||
* @returns {Promise<Array>}
|
||||
*/
|
||||
async listSkills() {
|
||||
const response = await fetch('/api/lm/agent/skills');
|
||||
if (!response.ok) return [];
|
||||
const data = await response.json();
|
||||
return data.skills || [];
|
||||
}
|
||||
}
|
||||
|
||||
// Export as singleton
|
||||
export const agentManager = new AgentManager();
|
||||
@@ -7,6 +7,7 @@ import { getStorageItem, setStorageItem } from '../utils/storageHelpers.js';
|
||||
import { FolderTreeManager } from '../components/FolderTreeManager.js';
|
||||
import { translate } from '../utils/i18nHelpers.js';
|
||||
import { extractCivitaiModelUrlParts } from '../utils/civitaiUtils.js';
|
||||
import { formatFileSize } from '../utils/formatters.js';
|
||||
|
||||
export class DownloadManager {
|
||||
constructor() {
|
||||
@@ -27,6 +28,10 @@ export class DownloadManager {
|
||||
this.isBatchMode = false;
|
||||
this.editingBatchIndex = -1;
|
||||
|
||||
// HF download state
|
||||
this.hfRepoId = null;
|
||||
this.hfSelectedFiles = [];
|
||||
|
||||
this.loadingManager = new LoadingManager();
|
||||
this.folderTreeManager = new FolderTreeManager();
|
||||
this.folderClickHandler = null;
|
||||
@@ -44,6 +49,8 @@ export class DownloadManager {
|
||||
this.handleToggleDefaultPath = this.toggleDefaultPath.bind(this);
|
||||
this.handleBackToUrlFromBatch = this.backToUrlFromBatch.bind(this);
|
||||
this.handleNextFromBatch = this.nextFromBatch.bind(this);
|
||||
|
||||
|
||||
}
|
||||
|
||||
showDownloadModal() {
|
||||
@@ -99,6 +106,8 @@ export class DownloadManager {
|
||||
|
||||
// Default path toggle handler
|
||||
document.getElementById('useDefaultPath').addEventListener('change', this.handleToggleDefaultPath);
|
||||
|
||||
|
||||
}
|
||||
|
||||
updateModalLabels() {
|
||||
@@ -160,6 +169,10 @@ export class DownloadManager {
|
||||
|
||||
// Reset default path toggle
|
||||
this.loadDefaultPathSetting();
|
||||
|
||||
// Reset HF state
|
||||
this.hfRepoId = null;
|
||||
this.hfSelectedFiles = [];
|
||||
}
|
||||
|
||||
async retrieveVersionsForModel(modelId, source = null) {
|
||||
@@ -180,6 +193,29 @@ export class DownloadManager {
|
||||
return;
|
||||
}
|
||||
|
||||
// Detect URL types — all URLs must share the same source type
|
||||
const urlTypes = urls.map(u => DownloadManager.detectUrlType(u));
|
||||
const isHf = urlTypes.every(t => t && (t.type === 'hf-resolve' || t.type === 'hf-repo'));
|
||||
const isCivitai = urlTypes.every(t => t && t.type === 'civitai');
|
||||
|
||||
if (!isHf && !isCivitai) {
|
||||
const allValid = urlTypes.every(t => t !== null);
|
||||
if (!allValid) {
|
||||
errorElement.textContent = translate('modals.download.errors.invalidUrl');
|
||||
return;
|
||||
}
|
||||
// Mixed sources not supported in one batch
|
||||
if (urls.length > 1) {
|
||||
errorElement.textContent = translate('modals.download.errors.mixedSources');
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if (isHf) {
|
||||
return this._validateAndFetchHf(urls, errorElement);
|
||||
}
|
||||
|
||||
// --- Original CivitAI flow below ---
|
||||
if (urls.length === 1) {
|
||||
this.isBatchMode = false;
|
||||
try {
|
||||
@@ -271,6 +307,112 @@ export class DownloadManager {
|
||||
this.showBatchPreviewStep();
|
||||
}
|
||||
|
||||
// ---- Hugging Face download flow ----
|
||||
|
||||
async _validateAndFetchHf(urls, errorElement) {
|
||||
if (urls.length === 1) {
|
||||
const info = DownloadManager.detectUrlType(urls[0]);
|
||||
// Direct file resolve URL → skip file selection, go to location
|
||||
if (info.type === 'hf-resolve') {
|
||||
this.isBatchMode = false;
|
||||
this.hfRepoId = info.repo;
|
||||
this.hfSelectedFiles = [info.filename];
|
||||
this.source = 'huggingface';
|
||||
this.proceedToLocation();
|
||||
return;
|
||||
}
|
||||
// Repo URL → fetch file list and convert to batch items
|
||||
try {
|
||||
this.loadingManager.showSimpleLoading(translate('modals.download.fetchingRepoFiles'));
|
||||
const files = await this.apiClient.fetchHfRepoFiles(info.repo);
|
||||
if (!files || files.length === 0) {
|
||||
throw new Error(translate('modals.download.errors.noModelFiles'));
|
||||
}
|
||||
this.isBatchMode = true;
|
||||
this.batchModels = [];
|
||||
this.source = 'huggingface';
|
||||
for (const file of files) {
|
||||
this.batchModels.push({
|
||||
url: urls[0],
|
||||
source: 'huggingface',
|
||||
repo: info.repo,
|
||||
filename: file.filename,
|
||||
revision: 'main',
|
||||
displayName: file.filename,
|
||||
fileSizeBytes: file.size,
|
||||
selectedVersion: true,
|
||||
versions: [],
|
||||
checked: false,
|
||||
error: null,
|
||||
});
|
||||
}
|
||||
this.showBatchPreviewStep();
|
||||
} catch (err) {
|
||||
errorElement.textContent = err.message;
|
||||
} finally {
|
||||
this.loadingManager.hide();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Multiple HF URLs → batch mode: flatten all files from all repos
|
||||
this.isBatchMode = true;
|
||||
this.batchModels = [];
|
||||
this.source = 'huggingface';
|
||||
this.loadingManager.showSimpleLoading(translate('modals.download.fetchingRepoFiles'));
|
||||
|
||||
for (const url of urls) {
|
||||
const info = DownloadManager.detectUrlType(url);
|
||||
if (!info) {
|
||||
this.batchModels.push({ url, error: 'Invalid URL', versions: [], selectedVersion: null });
|
||||
continue;
|
||||
}
|
||||
if (info.type === 'hf-resolve') {
|
||||
this.batchModels.push({
|
||||
url,
|
||||
source: 'huggingface',
|
||||
repo: info.repo,
|
||||
filename: info.filename,
|
||||
revision: info.revision || 'main',
|
||||
displayName: info.filename,
|
||||
selectedVersion: true,
|
||||
versions: [],
|
||||
checked: false,
|
||||
error: null,
|
||||
});
|
||||
} else if (info.type === 'hf-repo') {
|
||||
try {
|
||||
const files = await this.apiClient.fetchHfRepoFiles(info.repo);
|
||||
if (!files || files.length === 0) {
|
||||
this.batchModels.push({ url, error: 'No model files found', versions: [], selectedVersion: null });
|
||||
continue;
|
||||
}
|
||||
// Flatten: create one batch item per file, all checked by default
|
||||
for (const file of files) {
|
||||
this.batchModels.push({
|
||||
url,
|
||||
source: 'huggingface',
|
||||
repo: info.repo,
|
||||
filename: file.filename,
|
||||
revision: 'main',
|
||||
displayName: file.filename,
|
||||
fileSizeBytes: file.size,
|
||||
selectedVersion: true,
|
||||
versions: [],
|
||||
checked: false,
|
||||
error: null,
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
this.batchModels.push({ url, error: err.message, versions: [], selectedVersion: null });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.loadingManager.hide();
|
||||
this.showBatchPreviewStep();
|
||||
}
|
||||
|
||||
async fetchVersionsForCurrentModel() {
|
||||
const errorElement = document.getElementById('urlError');
|
||||
if (errorElement) {
|
||||
@@ -311,6 +453,60 @@ export class DownloadManager {
|
||||
return { modelId: null, modelVersionId: null, source: null };
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect the source type of a download URL.
|
||||
* @param {string} url
|
||||
* @returns {{ type: string, repo?: string, filename?: string, revision?: string } | null}
|
||||
* type: 'civitai' | 'civarchive' | 'hf-resolve' | 'hf-repo' | 'direct-http'
|
||||
*/
|
||||
static detectUrlType(url) {
|
||||
const trimmed = url.trim();
|
||||
if (!trimmed) return null;
|
||||
|
||||
// CivitAI
|
||||
if (/civitai\.com\/models\//i.test(trimmed) || /civitaiarchive|civarchive/i.test(trimmed)) {
|
||||
// Will be parsed by existing CivitAI logic
|
||||
return { type: 'civitai' };
|
||||
}
|
||||
|
||||
// Hugging Face resolve URL → direct file
|
||||
const hfResolveMatch = trimmed.match(/huggingface\.co\/([^/\s]+\/[^/\s]+)\/resolve\/([^/\s]+)\/(.+)/i);
|
||||
if (hfResolveMatch) {
|
||||
return {
|
||||
type: 'hf-resolve',
|
||||
repo: hfResolveMatch[1],
|
||||
revision: hfResolveMatch[2],
|
||||
filename: hfResolveMatch[3],
|
||||
};
|
||||
}
|
||||
|
||||
// Hugging Face repo URL (huggingface.co/user/repo or bare user/repo path)
|
||||
// Require huggingface.co prefix for full URLs; bare user/repo only without ://
|
||||
const hfRepoMatch = trimmed.match(
|
||||
trimmed.includes('://')
|
||||
? /^https?:\/\/huggingface\.co\/([a-zA-Z0-9_.-]+\/[a-zA-Z0-9_.-]+)(?:\/?$|$)/
|
||||
: /^([a-zA-Z0-9_.-]+\/[a-zA-Z0-9_.-]+)$/
|
||||
);
|
||||
if (hfRepoMatch) {
|
||||
// Reject path-traversal patterns like "../.." or "user/.."
|
||||
const parts = hfRepoMatch[1].split('/');
|
||||
if (parts.some(p => p === '.' || p === '..')) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
type: 'hf-repo',
|
||||
repo: hfRepoMatch[1],
|
||||
};
|
||||
}
|
||||
|
||||
// Direct HTTP(S) URL (non-HF)
|
||||
if (/^https?:\/\//i.test(trimmed)) {
|
||||
return { type: 'direct-http' };
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
extractModelId(url) {
|
||||
const result = DownloadManager.parseModelUrl(url);
|
||||
this.modelVersionId = result.modelVersionId;
|
||||
@@ -559,8 +755,8 @@ export class DownloadManager {
|
||||
return;
|
||||
}
|
||||
|
||||
// In single-URL mode, validate version selection
|
||||
if (!this.isBatchMode) {
|
||||
// In single-URL mode, validate version selection (skip for HF)
|
||||
if (!this.isBatchMode && this.source !== 'huggingface') {
|
||||
if (!this.currentVersion) {
|
||||
showToast('toast.loras.pleaseSelectVersion', {}, 'error');
|
||||
return;
|
||||
@@ -784,6 +980,77 @@ export class DownloadManager {
|
||||
}
|
||||
}
|
||||
|
||||
async _downloadHfSingle({ modelRoot, targetFolder, useDefaultPaths }) {
|
||||
modalManager.closeModal('downloadModal');
|
||||
this.loadingManager.restoreProgressBar();
|
||||
const totalFiles = this.hfSelectedFiles.length;
|
||||
const updateProgress = this.loadingManager.showDownloadProgress(totalFiles);
|
||||
|
||||
try {
|
||||
let completedDownloads = 0;
|
||||
for (let i = 0; i < totalFiles; i++) {
|
||||
const filename = this.hfSelectedFiles[i];
|
||||
updateProgress(0, completedDownloads, filename);
|
||||
this.loadingManager.setStatus(`Downloading ${filename}...`);
|
||||
|
||||
const downloadId = Date.now().toString() + '_' + i;
|
||||
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
|
||||
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`);
|
||||
|
||||
try {
|
||||
await new Promise((resolve, reject) => {
|
||||
ws.onopen = resolve;
|
||||
ws.onerror = reject;
|
||||
});
|
||||
|
||||
// Capture completed count at WS creation time so progress
|
||||
// updates arriving after completedDownloads increments still
|
||||
// show the correct "N / total" position.
|
||||
const snapshotCompleted = completedDownloads;
|
||||
ws.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.status === 'progress') {
|
||||
const metrics = {
|
||||
bytesDownloaded: data.bytes_downloaded,
|
||||
totalBytes: data.total_bytes,
|
||||
bytesPerSecond: data.bytes_per_second,
|
||||
};
|
||||
updateProgress(data.progress, snapshotCompleted, filename, metrics);
|
||||
}
|
||||
};
|
||||
|
||||
const response = await this.apiClient.downloadHfModel({
|
||||
repo: this.hfRepoId,
|
||||
filename,
|
||||
revision: 'main',
|
||||
modelRoot,
|
||||
relativePath: targetFolder,
|
||||
useDefaultPaths,
|
||||
download_id: downloadId,
|
||||
});
|
||||
|
||||
if (response?.success) {
|
||||
completedDownloads++;
|
||||
updateProgress(100, completedDownloads, filename);
|
||||
}
|
||||
} finally {
|
||||
ws.close();
|
||||
}
|
||||
}
|
||||
|
||||
showToast('toast.loras.downloadCompleted', {}, 'success');
|
||||
// Reload page data — model is already in scanner cache via backend
|
||||
await resetAndReload(true);
|
||||
return true;
|
||||
} catch (error) {
|
||||
console.error('Failed to download HF model:', error);
|
||||
showToast('toast.downloads.downloadError', { message: error?.message }, 'error');
|
||||
return false;
|
||||
} finally {
|
||||
this.loadingManager.hide();
|
||||
}
|
||||
}
|
||||
|
||||
updatePathSelectionUI() {
|
||||
const manualSelection = document.getElementById('manualPathSelection');
|
||||
|
||||
@@ -812,13 +1079,19 @@ export class DownloadManager {
|
||||
document.querySelectorAll('.download-step').forEach(step => step.style.display = 'none');
|
||||
document.getElementById('batchPreviewStep').style.display = 'block';
|
||||
|
||||
const validCount = this.batchModels.filter(m => !m.error && m.selectedVersion).length;
|
||||
const validCount = this.batchModels.filter(m => {
|
||||
if (m.error) return false;
|
||||
if (m.source === 'huggingface') return m.checked !== false;
|
||||
return m.selectedVersion;
|
||||
}).length;
|
||||
document.getElementById('downloadModalTitle').textContent =
|
||||
translate('modals.download.titleWithType', { type: this.apiClient.apiConfig.config.displayName }) +
|
||||
` (${validCount})`;
|
||||
|
||||
const list = document.getElementById('batchPreviewList');
|
||||
list.innerHTML = this.batchModels.map((item, index) => {
|
||||
const hasHfItems = this.batchModels.some(m => m.source === 'huggingface' && !m.error);
|
||||
|
||||
let itemsHtml = this.batchModels.map((item, index) => {
|
||||
if (item.error) {
|
||||
return `
|
||||
<div class="batch-preview-item batch-preview-error" data-index="${index}">
|
||||
@@ -837,6 +1110,30 @@ export class DownloadManager {
|
||||
}
|
||||
|
||||
const ver = item.selectedVersion;
|
||||
|
||||
// HF batch item rendering with checkbox
|
||||
if (item.source === 'huggingface') {
|
||||
const hfSize = item.fileSizeBytes
|
||||
? formatFileSize(item.fileSizeBytes)
|
||||
: '?';
|
||||
return `
|
||||
<div class="batch-preview-item" data-index="${index}">
|
||||
<input type="checkbox" class="batch-preview-checkbox"
|
||||
data-index="${index}" ${item.checked !== false ? 'checked' : ''} />
|
||||
<div class="batch-preview-info">
|
||||
<div class="batch-preview-name">${item.displayName || item.filename || `HF #${index}`} <span class="hf-badge">HF</span></div>
|
||||
<div class="batch-preview-meta">
|
||||
<span>${hfSize}</span>
|
||||
<span>${item.repo || ''}</span>
|
||||
</div>
|
||||
</div>
|
||||
<button class="batch-preview-remove" data-index="${index}" title="${translate('common.actions.remove', {}, 'Remove')}">
|
||||
<i class="fas fa-times"></i>
|
||||
</button>
|
||||
</div>
|
||||
`;
|
||||
}
|
||||
|
||||
const firstImage = ver?.images?.find(img => !img.url.endsWith('.mp4'));
|
||||
const thumbnailUrl = firstImage ? firstImage.url : '/loras_static/images/no-preview.png';
|
||||
const fileSize = ver?.modelSizeKB
|
||||
@@ -866,6 +1163,21 @@ export class DownloadManager {
|
||||
`;
|
||||
}).join('');
|
||||
|
||||
// Prepend select-all toolbar if there are HF items with checkboxes
|
||||
if (hasHfItems) {
|
||||
const allChecked = this.batchModels
|
||||
.filter(m => m.source === 'huggingface' && !m.error)
|
||||
.every(m => m.checked !== false);
|
||||
itemsHtml = `
|
||||
<div class="batch-preview-select-all">
|
||||
<input type="checkbox" id="batchSelectAll" ${allChecked ? 'checked' : ''} />
|
||||
<label for="batchSelectAll">${translate('modals.download.selectAll', {}, 'Select All')}</label>
|
||||
</div>
|
||||
` + itemsHtml;
|
||||
}
|
||||
|
||||
list.innerHTML = itemsHtml;
|
||||
|
||||
list.onclick = (e) => {
|
||||
const removeBtn = e.target.closest('.batch-preview-remove');
|
||||
if (removeBtn) {
|
||||
@@ -881,6 +1193,59 @@ export class DownloadManager {
|
||||
}
|
||||
};
|
||||
|
||||
// Checkbox handler for HF batch items
|
||||
const checkboxes = list.querySelectorAll('.batch-preview-checkbox');
|
||||
checkboxes.forEach(cb => {
|
||||
cb.addEventListener('change', (e) => {
|
||||
const idx = parseInt(e.target.dataset.index);
|
||||
if (this.batchModels[idx]) {
|
||||
this.batchModels[idx].checked = e.target.checked;
|
||||
}
|
||||
// Update valid count in title and Next button
|
||||
const checkedCount = this.batchModels.filter(
|
||||
m => !m.error && m.checked !== false
|
||||
).length;
|
||||
document.getElementById('downloadModalTitle').textContent =
|
||||
translate('modals.download.titleWithType', { type: this.apiClient.apiConfig.config.displayName }) +
|
||||
` (${checkedCount})`;
|
||||
const nextBtn = document.getElementById('nextFromBatchBtn');
|
||||
nextBtn.disabled = checkedCount === 0;
|
||||
nextBtn.classList.toggle('disabled', checkedCount === 0);
|
||||
// Update select-all checkbox state
|
||||
const selectAll = document.getElementById('batchSelectAll');
|
||||
if (selectAll) {
|
||||
const hfItems = this.batchModels.filter(m => m.source === 'huggingface' && !m.error);
|
||||
selectAll.checked = hfItems.length > 0 && hfItems.every(m => m.checked !== false);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Select-all handler
|
||||
const selectAll = document.getElementById('batchSelectAll');
|
||||
if (selectAll) {
|
||||
selectAll.addEventListener('change', (e) => {
|
||||
const checked = e.target.checked;
|
||||
const hfCheckboxes = list.querySelectorAll('.batch-preview-checkbox');
|
||||
hfCheckboxes.forEach(cb => {
|
||||
cb.checked = checked;
|
||||
const idx = parseInt(cb.dataset.index);
|
||||
if (this.batchModels[idx]) {
|
||||
this.batchModels[idx].checked = checked;
|
||||
}
|
||||
});
|
||||
// Update valid count in title and Next button
|
||||
const checkedCount = this.batchModels.filter(
|
||||
m => !m.error && m.checked !== false
|
||||
).length;
|
||||
document.getElementById('downloadModalTitle').textContent =
|
||||
translate('modals.download.titleWithType', { type: this.apiClient.apiConfig.config.displayName }) +
|
||||
` (${checkedCount})`;
|
||||
const nextBtn = document.getElementById('nextFromBatchBtn');
|
||||
nextBtn.disabled = checkedCount === 0;
|
||||
nextBtn.classList.toggle('disabled', checkedCount === 0);
|
||||
});
|
||||
}
|
||||
|
||||
const nextBtn = document.getElementById('nextFromBatchBtn');
|
||||
nextBtn.disabled = validCount === 0;
|
||||
nextBtn.classList.toggle('disabled', validCount === 0);
|
||||
@@ -903,7 +1268,12 @@ export class DownloadManager {
|
||||
}
|
||||
|
||||
nextFromBatch() {
|
||||
const validModels = this.batchModels.filter(m => !m.error && m.selectedVersion);
|
||||
// For HF items, respect the checked flag; for CivitAI items, use selectedVersion
|
||||
const validModels = this.batchModels.filter(m => {
|
||||
if (m.error) return false;
|
||||
if (m.source === 'huggingface') return m.checked !== false;
|
||||
return m.selectedVersion;
|
||||
});
|
||||
if (validModels.length === 0) return;
|
||||
this.proceedToLocation();
|
||||
}
|
||||
@@ -953,6 +1323,15 @@ export class DownloadManager {
|
||||
targetFolder = this.folderTreeManager.getSelectedPath();
|
||||
}
|
||||
if (!this.isBatchMode) {
|
||||
// Single-item download
|
||||
if (this.source === 'huggingface') {
|
||||
return this._downloadHfSingle({
|
||||
modelRoot,
|
||||
targetFolder,
|
||||
useDefaultPaths,
|
||||
});
|
||||
}
|
||||
|
||||
const fileParams = this.selectedFile ? {
|
||||
type: this.selectedFile.type || 'Model',
|
||||
format: this.selectedFile.metadata?.format || 'SafeTensor',
|
||||
@@ -974,7 +1353,13 @@ export class DownloadManager {
|
||||
}
|
||||
|
||||
// Batch download mode
|
||||
const downloadItems = this.batchModels.filter(m => !m.error && m.selectedVersion && !m.selectedVersion.existsLocally);
|
||||
const downloadItems = this.batchModels.filter(m => {
|
||||
if (m.error) return false;
|
||||
if (!m.selectedVersion) return false;
|
||||
// HF items have selectedVersion as a boolean marker + checked flag
|
||||
if (m.source === 'huggingface') return m.checked !== false;
|
||||
return !m.selectedVersion.existsLocally;
|
||||
});
|
||||
if (downloadItems.length === 0) {
|
||||
showToast('toast.loras.downloadCompleted', {}, 'info');
|
||||
modalManager.closeModal('downloadModal');
|
||||
@@ -999,7 +1384,7 @@ export class DownloadManager {
|
||||
|
||||
if (data.status === 'progress' && data.download_id?.startsWith(batchDownloadId)) {
|
||||
const current = downloadItems[completedDownloads + failedDownloads];
|
||||
const name = current?.selectedVersion?.name || `#${completedDownloads + failedDownloads + 1}`;
|
||||
const name = current?.selectedVersion?.name || current?.displayName || current?.filename || `#${completedDownloads + failedDownloads + 1}`;
|
||||
const metrics = {
|
||||
bytesDownloaded: data.bytes_downloaded,
|
||||
totalBytes: data.total_bytes,
|
||||
@@ -1016,22 +1401,59 @@ export class DownloadManager {
|
||||
|
||||
for (let i = 0; i < downloadItems.length; i++) {
|
||||
const item = downloadItems[i];
|
||||
const ver = item.selectedVersion;
|
||||
const name = ver?.name || `Model #${item.modelId}`;
|
||||
const name = item.displayName || item.filename || (item.selectedVersion?.name || `Model #${item.modelId}`);
|
||||
const isHf = item.source === 'huggingface';
|
||||
|
||||
updateProgress(0, completedDownloads, name);
|
||||
loadingManager.setStatus(`${i + 1}/${downloadItems.length}: ${name}`);
|
||||
|
||||
try {
|
||||
const response = await this.apiClient.downloadModel(
|
||||
item.modelId,
|
||||
ver.id,
|
||||
modelRoot,
|
||||
targetFolder,
|
||||
useDefaultPaths,
|
||||
batchDownloadId,
|
||||
item.source
|
||||
);
|
||||
let response;
|
||||
if (isHf) {
|
||||
// Per-file WebSocket for real-time progress
|
||||
const downloadId = Date.now().toString() + '_hf_' + i;
|
||||
const wsHf = new WebSocket(`${wsProtocol}${window.location.host}/ws/download-progress?id=${downloadId}`);
|
||||
try {
|
||||
await new Promise((resolve, reject) => {
|
||||
wsHf.onopen = resolve;
|
||||
wsHf.onerror = reject;
|
||||
});
|
||||
const snapshotCompleted = completedDownloads;
|
||||
wsHf.onmessage = (event) => {
|
||||
const data = JSON.parse(event.data);
|
||||
if (data.status === 'progress') {
|
||||
const metrics = {
|
||||
bytesDownloaded: data.bytes_downloaded,
|
||||
totalBytes: data.total_bytes,
|
||||
bytesPerSecond: data.bytes_per_second,
|
||||
};
|
||||
updateProgress(data.progress, snapshotCompleted, name, metrics);
|
||||
}
|
||||
};
|
||||
|
||||
response = await this.apiClient.downloadHfModel({
|
||||
repo: item.repo,
|
||||
filename: item.filename,
|
||||
revision: item.revision || 'main',
|
||||
modelRoot,
|
||||
relativePath: targetFolder,
|
||||
useDefaultPaths,
|
||||
download_id: downloadId,
|
||||
});
|
||||
} finally {
|
||||
wsHf.close();
|
||||
}
|
||||
} else {
|
||||
response = await this.apiClient.downloadModel(
|
||||
item.modelId,
|
||||
item.selectedVersion.id,
|
||||
modelRoot,
|
||||
targetFolder,
|
||||
useDefaultPaths,
|
||||
batchDownloadId,
|
||||
item.source
|
||||
);
|
||||
}
|
||||
|
||||
if (!response.success) {
|
||||
failedDownloads++;
|
||||
|
||||
@@ -827,6 +827,23 @@ export class SettingsManager {
|
||||
|
||||
// Update API key status display (do NOT pre-fill the input)
|
||||
this.updateApiKeyStatus();
|
||||
this.updateLlmApiKeyStatus();
|
||||
|
||||
// AI Provider settings
|
||||
const llmProviderSelect = document.getElementById('llmProvider');
|
||||
if (llmProviderSelect) {
|
||||
llmProviderSelect.value = state.global.settings.llm_provider || 'openai';
|
||||
}
|
||||
|
||||
const llmApiBaseInput = document.getElementById('llmApiBase');
|
||||
if (llmApiBaseInput) {
|
||||
llmApiBaseInput.value = state.global.settings.llm_api_base || '';
|
||||
}
|
||||
|
||||
const llmModelInput = document.getElementById('llmModel');
|
||||
if (llmModelInput) {
|
||||
llmModelInput.value = state.global.settings.llm_model || '';
|
||||
}
|
||||
|
||||
const civitaiHostSelect = document.getElementById('civitaiHost');
|
||||
if (civitaiHostSelect) {
|
||||
@@ -2931,42 +2948,70 @@ export class SettingsManager {
|
||||
}
|
||||
}
|
||||
|
||||
editApiKey() {
|
||||
const statusEl = document.getElementById('civitaiApiKeyStatus');
|
||||
updateLlmApiKeyStatus() {
|
||||
const hasKey = !!state.global.settings.llm_api_key;
|
||||
const statusText = document.getElementById('llmApiKeyStatusText');
|
||||
const actionBtn = document.getElementById('llmApiKeyActionBtn');
|
||||
if (!statusText || !actionBtn) return;
|
||||
|
||||
if (hasKey) {
|
||||
statusText.classList.remove('api-key-status--unconfigured');
|
||||
statusText.classList.add('api-key-status--configured');
|
||||
statusText.innerHTML = '<i class="fas fa-check-circle text-success"></i> '
|
||||
+ translate('settings.aiProvider.apiKeyConfigured', {}, 'Configured');
|
||||
actionBtn.textContent = translate('common.actions.change', {}, 'Change');
|
||||
} else {
|
||||
statusText.classList.remove('api-key-status--configured');
|
||||
statusText.classList.add('api-key-status--unconfigured');
|
||||
statusText.innerHTML = '<i class="fas fa-times-circle text-error"></i> '
|
||||
+ translate('settings.aiProvider.apiKeyNotSet', {}, 'Not set');
|
||||
actionBtn.textContent = translate('settings.aiProvider.apiKeySet', {}, 'Set up');
|
||||
}
|
||||
}
|
||||
|
||||
editApiKey(settingsKey = 'civitai_api_key', inputId = 'civitaiApiKey') {
|
||||
const statusId = inputId + 'Status';
|
||||
const editId = inputId + 'Edit';
|
||||
const statusEl = document.getElementById(statusId);
|
||||
if (statusEl) statusEl.classList.add('is-hidden');
|
||||
const editContainer = document.getElementById('civitaiApiKeyEdit');
|
||||
const editContainer = document.getElementById(editId);
|
||||
if (editContainer) editContainer.classList.remove('is-hidden');
|
||||
// Focus the input
|
||||
const input = document.getElementById('civitaiApiKey');
|
||||
const input = document.getElementById(inputId);
|
||||
if (input) {
|
||||
input.value = ''; // Never pre-fill the secret
|
||||
setTimeout(() => input.focus(), 50);
|
||||
}
|
||||
}
|
||||
|
||||
cancelEditApiKey(silent) {
|
||||
const editContainer = document.getElementById('civitaiApiKeyEdit');
|
||||
cancelEditApiKey(silent, inputId = 'civitaiApiKey') {
|
||||
const editId = inputId + 'Edit';
|
||||
const statusId = inputId + 'Status';
|
||||
const editContainer = document.getElementById(editId);
|
||||
if (editContainer) editContainer.classList.add('is-hidden');
|
||||
const statusContainer = document.getElementById('civitaiApiKeyStatus');
|
||||
const statusContainer = document.getElementById(statusId);
|
||||
if (statusContainer) statusContainer.classList.remove('is-hidden');
|
||||
// Clear any typed value
|
||||
const input = document.getElementById('civitaiApiKey');
|
||||
const input = document.getElementById(inputId);
|
||||
if (input) input.value = '';
|
||||
if (!silent) {
|
||||
this.updateApiKeyStatus();
|
||||
if (inputId === 'civitaiApiKey') {
|
||||
this.updateApiKeyStatus();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async saveApiKey() {
|
||||
const input = document.getElementById('civitaiApiKey');
|
||||
async saveApiKey(settingsKey = 'civitai_api_key', inputId = 'civitaiApiKey') {
|
||||
const input = document.getElementById(inputId);
|
||||
if (!input) return;
|
||||
|
||||
const value = input.value.trim();
|
||||
|
||||
try {
|
||||
await this.saveSetting('civitai_api_key', value);
|
||||
await this.saveSetting(settingsKey, value);
|
||||
const labelName = settingsKey === 'civitai_api_key' ? 'CivitAI API Key' : 'LLM API Key';
|
||||
showToast('toast.settings.settingsUpdated',
|
||||
{ setting: 'CivitAI API Key' }, 'success');
|
||||
{ setting: labelName }, 'success');
|
||||
} catch (error) {
|
||||
showToast('toast.settings.settingSaveFailed',
|
||||
{ message: error.message }, 'error');
|
||||
@@ -2974,9 +3019,13 @@ export class SettingsManager {
|
||||
}
|
||||
|
||||
// Update the in-memory flag so the UI reflects the change
|
||||
state.global.settings.civitai_api_key_set = !!value;
|
||||
this.cancelEditApiKey(true);
|
||||
this.updateApiKeyStatus();
|
||||
if (settingsKey === 'civitai_api_key') {
|
||||
state.global.settings.civitai_api_key_set = !!value;
|
||||
}
|
||||
this.cancelEditApiKey(true, inputId);
|
||||
if (inputId === 'civitaiApiKey') {
|
||||
this.updateApiKeyStatus();
|
||||
}
|
||||
}
|
||||
|
||||
toggleInputVisibility(button) {
|
||||
|
||||
@@ -55,6 +55,10 @@ const DEFAULT_SETTINGS_BASE = Object.freeze({
|
||||
strip_lora_on_copy: false,
|
||||
use_new_license_icons: true,
|
||||
group_by_model: false,
|
||||
llm_provider: 'openai',
|
||||
llm_api_key: '',
|
||||
llm_api_base: '',
|
||||
llm_model: '',
|
||||
});
|
||||
|
||||
export function createDefaultSettings() {
|
||||
|
||||
@@ -70,6 +70,7 @@ export const BASE_MODELS = {
|
||||
ERNIE_TURBO: "Ernie Turbo",
|
||||
NUCLEUS: "Nucleus",
|
||||
PONY_V7: "Pony V7",
|
||||
KREA_2: "Krea 2",
|
||||
// Default
|
||||
UNKNOWN: "Other"
|
||||
};
|
||||
@@ -197,6 +198,7 @@ export const BASE_MODEL_ABBREVIATIONS = {
|
||||
[BASE_MODELS.ERNIE]: 'ERNI',
|
||||
[BASE_MODELS.ERNIE_TURBO]: 'ETRB',
|
||||
[BASE_MODELS.NUCLEUS]: 'NUCL',
|
||||
[BASE_MODELS.KREA_2]: 'KR2',
|
||||
|
||||
// Default
|
||||
[BASE_MODELS.UNKNOWN]: 'OTH'
|
||||
@@ -401,6 +403,7 @@ export const BASE_MODEL_CATEGORIES = {
|
||||
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
|
||||
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI, BASE_MODELS.ANIMA,
|
||||
BASE_MODELS.ERNIE, BASE_MODELS.ERNIE_TURBO, BASE_MODELS.NUCLEUS,
|
||||
BASE_MODELS.KREA_2,
|
||||
BASE_MODELS.UNKNOWN
|
||||
]
|
||||
};
|
||||
|
||||
@@ -319,6 +319,15 @@ export function openCivitai(filePath) {
|
||||
openCivitaiByMetadata(civitaiId, versionId, modelName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Open a Hugging Face model page in a new tab
|
||||
* @param {string} hfUrl - The Hugging Face URL
|
||||
*/
|
||||
export function openHuggingFace(hfUrl) {
|
||||
if (!hfUrl) return;
|
||||
window.open(hfUrl, '_blank', 'noopener,noreferrer');
|
||||
}
|
||||
|
||||
/**
|
||||
* Dynamically positions the search options panel and filter panel
|
||||
* based on the current layout and folder tags container height
|
||||
|
||||
@@ -12,6 +12,9 @@
|
||||
<div class="context-menu-item" data-action="check-updates">
|
||||
<i class="fas fa-bell"></i> <span>{{ t('loras.contextMenu.checkUpdates') }}</span>
|
||||
</div>
|
||||
<div class="context-menu-item" data-action="enrich-hf-agent">
|
||||
<i class="fas fa-wand-magic-sparkles"></i> <span>{{ t('loras.contextMenu.enrichHfAgent') }}</span>
|
||||
</div>
|
||||
<div class="context-menu-item" data-action="relink-civitai">
|
||||
<i class="fas fa-link"></i> <span>{{ t('loras.contextMenu.relinkCivitai') }}</span>
|
||||
</div>
|
||||
@@ -83,6 +86,9 @@
|
||||
<div class="context-menu-item" data-action="resume-metadata-refresh">
|
||||
<i class="fas fa-redo"></i> <span>{{ t('loras.bulkOperations.resumeMetadataRefresh') }}</span>
|
||||
</div>
|
||||
<div class="context-menu-item" data-action="enrich-hf-agent-bulk">
|
||||
<i class="fas fa-wand-magic-sparkles"></i> <span>{{ t('loras.bulkOperations.enrichHfAgent') }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="context-menu-section" data-section="workflow">
|
||||
<div class="context-menu-section-header">{{ t('loras.bulkOperations.sections.workflow') }}</div>
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
<div class="error-message" id="urlError"></div>
|
||||
<div class="input-hint">
|
||||
<i class="fas fa-info-circle"></i>
|
||||
<span>{{ t('modals.download.urlHint') }}</span>
|
||||
<span id="urlHint">{{ t('modals.download.urlHint') }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div class="modal-actions">
|
||||
|
||||
@@ -144,6 +144,96 @@
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- AI Provider Configuration (BYOK) -->
|
||||
<div class="settings-subsection">
|
||||
<div class="settings-subsection-header">
|
||||
<h4>{{ t('settings.aiProvider.title') }}</h4>
|
||||
</div>
|
||||
<div class="setting-item">
|
||||
<div class="setting-row">
|
||||
<div class="setting-info">
|
||||
<label for="llmProvider">{{ t('settings.aiProvider.provider') }}</label>
|
||||
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.aiProvider.providerHelp') }}"></i>
|
||||
</div>
|
||||
<div class="setting-control select-control">
|
||||
<select id="llmProvider" onchange="settingsManager.saveSelectSetting('llmProvider', 'llm_provider')">
|
||||
<option value="openai">OpenAI</option>
|
||||
<option value="ollama">Ollama (local)</option>
|
||||
<option value="custom">{{ t('settings.aiProvider.custom') }}</option>
|
||||
</select>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="setting-item">
|
||||
<div class="setting-row">
|
||||
<div class="setting-info">
|
||||
<label for="llmApiBase">{{ t('settings.aiProvider.apiBase') }}</label>
|
||||
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.aiProvider.apiBaseHelp') }}"></i>
|
||||
</div>
|
||||
<div class="setting-control">
|
||||
<div class="text-input-wrapper">
|
||||
<input type="text" id="llmApiBase"
|
||||
value="{{ settings.get('llm_api_base', '') }}"
|
||||
placeholder="{{ t('settings.aiProvider.apiBasePlaceholder') }}"
|
||||
onblur="settingsManager.saveInputSetting('llmApiBase', 'llm_api_base')"
|
||||
onkeydown="if(event.key === 'Enter') { this.blur(); }" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="setting-item api-key-item">
|
||||
<div class="setting-row">
|
||||
<div class="setting-info">
|
||||
<label>{{ t('settings.aiProvider.apiKey') }}</label>
|
||||
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.aiProvider.apiKeyHelp') }}"></i>
|
||||
</div>
|
||||
<div class="setting-control">
|
||||
<div id="llmApiKeyStatus" class="api-key-status">
|
||||
<span id="llmApiKeyStatusText" class="api-key-status-text api-key-status--unconfigured">
|
||||
<i class="fas fa-times-circle text-error"></i>
|
||||
{{ t('settings.aiProvider.apiKeyNotSet') }}
|
||||
</span>
|
||||
<button type="button" class="secondary-btn" id="llmApiKeyActionBtn" onclick="settingsManager.editApiKey('llm_api_key', 'llmApiKey')">
|
||||
{{ t('settings.aiProvider.apiKeySet') }}
|
||||
</button>
|
||||
</div>
|
||||
<div id="llmApiKeyEdit" class="api-key-edit is-hidden">
|
||||
<div class="api-key-input">
|
||||
<input type="text"
|
||||
id="llmApiKey"
|
||||
class="api-key-masked"
|
||||
placeholder="{{ t('settings.aiProvider.apiKeyPlaceholder') }}"
|
||||
autocomplete="off"
|
||||
data-mask="css" />
|
||||
<button type="button" class="toggle-visibility">
|
||||
<i class="fas fa-eye"></i>
|
||||
</button>
|
||||
</div>
|
||||
<button type="button" class="primary-btn" onclick="settingsManager.saveApiKey('llm_api_key', 'llmApiKey')">{{ t('common.actions.save') }}</button>
|
||||
<button type="button" class="secondary-btn" onclick="settingsManager.cancelEditApiKey(true, 'llmApiKey')">{{ t('common.actions.cancel') }}</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="setting-item">
|
||||
<div class="setting-row">
|
||||
<div class="setting-info">
|
||||
<label for="llmModel">{{ t('settings.aiProvider.model') }}</label>
|
||||
<i class="fas fa-info-circle info-icon" data-tooltip="{{ t('settings.aiProvider.modelHelp') }}"></i>
|
||||
</div>
|
||||
<div class="setting-control">
|
||||
<div class="text-input-wrapper">
|
||||
<input type="text" id="llmModel"
|
||||
value="{{ settings.get('llm_model', '') }}"
|
||||
placeholder="e.g. gpt-4o-mini"
|
||||
onblur="settingsManager.saveInputSetting('llmModel', 'llm_model')"
|
||||
onkeydown="if(event.key === 'Enter') { this.blur(); }" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="settings-subsection">
|
||||
<div class="settings-subsection-header">
|
||||
<h4>{{ t('settings.sections.downloads') }}</h4>
|
||||
|
||||
0
tests/agent_cli/__init__.py
Normal file
0
tests/agent_cli/__init__.py
Normal file
317
tests/agent_cli/test_agent_cli.py
Normal file
317
tests/agent_cli/test_agent_cli.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""Tests for the AgentCLI module (py/agent_cli/).
|
||||
|
||||
All tests mock the underlying services (scanner, MetadataManager, downloader)
|
||||
since the AgentCLI is a thin delegation layer.
|
||||
|
||||
Mock targets must match where imports are resolved inside each function
|
||||
(lazy imports via ``from X import Y`` inside function body).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.agent_cli import (
|
||||
list_base_models,
|
||||
read_metadata,
|
||||
apply_metadata_updates,
|
||||
download_preview,
|
||||
refresh_cache,
|
||||
)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# Helpers
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class MockCache:
|
||||
def __init__(self, raw_data: list[dict] | None = None):
|
||||
self.raw_data = raw_data or []
|
||||
|
||||
|
||||
class MockScanner:
|
||||
"""Simulates a ModelScanner for testing."""
|
||||
|
||||
def __init__(self, raw_data: list[dict] | None = None):
|
||||
self._raw_data = raw_data or []
|
||||
self.update_single_model_cache = mock.AsyncMock(return_value=True)
|
||||
|
||||
async def get_cached_data(self):
|
||||
return MockCache(self._raw_data)
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# list_base_models -- imports ServiceRegistry internally
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestListBaseModels:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_cache(self):
|
||||
scanner = MockScanner([])
|
||||
with mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=scanner),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
):
|
||||
result = await list_base_models()
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_merges_all_scanners(self):
|
||||
lora_scanner = MockScanner([
|
||||
{"base_model": "SDXL 1.0"},
|
||||
{"base_model": "Flux.1 D"},
|
||||
{"base_model": "SDXL 1.0"},
|
||||
])
|
||||
ckpt_scanner = MockScanner([
|
||||
{"base_model": "SDXL 1.0"},
|
||||
{"base_model": "SD 1.5"},
|
||||
])
|
||||
with mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=lora_scanner),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=ckpt_scanner),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
):
|
||||
result = await list_base_models()
|
||||
assert result == ["SDXL 1.0", "Flux.1 D", "SD 1.5"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_limit(self):
|
||||
scanner = MockScanner([
|
||||
{"base_model": "A"}, {"base_model": "B"}, {"base_model": "C"},
|
||||
])
|
||||
with mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=scanner),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
):
|
||||
result = await list_base_models(limit=2)
|
||||
assert result == ["A", "B"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_scanners_return_none(self):
|
||||
with mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=None),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
):
|
||||
result = await list_base_models()
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_empty_or_missing_base_model(self):
|
||||
scanner = MockScanner([
|
||||
{"base_model": "SDXL 1.0"},
|
||||
{"file_name": "foo.safetensors"}, # no base_model key
|
||||
{"base_model": ""}, # empty
|
||||
])
|
||||
with mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=scanner),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
):
|
||||
result = await list_base_models()
|
||||
assert result == ["SDXL 1.0"]
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# read_metadata -- imports MetadataManager from py.utils.metadata_manager
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestReadMetadata:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delegates_to_metadata_manager(self):
|
||||
fake = {"file_name": "test", "base_model": "SDXL 1.0"}
|
||||
with mock.patch("py.utils.metadata_manager.MetadataManager") as mm:
|
||||
mm.load_metadata_payload = mock.AsyncMock(return_value=fake)
|
||||
result = await read_metadata("/p.safetensors")
|
||||
assert result == fake
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_returns_empty_dict(self):
|
||||
with mock.patch("py.utils.metadata_manager.MetadataManager") as mm:
|
||||
mm.load_metadata_payload = mock.AsyncMock(side_effect=ValueError("x"))
|
||||
result = await read_metadata("/p.safetensors")
|
||||
assert result == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_coerces_to_empty_dict(self):
|
||||
with mock.patch("py.utils.metadata_manager.MetadataManager") as mm:
|
||||
mm.load_metadata_payload = mock.AsyncMock(return_value=None)
|
||||
result = await read_metadata("/p.safetensors")
|
||||
assert result == {}
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# apply_metadata_updates -- uses read_metadata + MetadataManager.save_metadata
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestApplyMetadataUpdates:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_updates_field(self):
|
||||
with (
|
||||
mock.patch("py.agent_cli.read_metadata") as mock_read,
|
||||
mock.patch("py.utils.metadata_manager.MetadataManager") as mm,
|
||||
):
|
||||
mock_read.return_value = {"base_model": "", "tags": []}
|
||||
mm.save_metadata = mock.AsyncMock(return_value=True)
|
||||
updated = await apply_metadata_updates(
|
||||
"/p.safetensors", {"base_model": "Flux.1 D"}
|
||||
)
|
||||
assert updated == ["base_model"]
|
||||
mm.save_metadata.assert_awaited_once_with(
|
||||
"/p.safetensors", {"base_model": "Flux.1 D", "tags": []},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noop_when_value_unchanged(self):
|
||||
with (
|
||||
mock.patch("py.agent_cli.read_metadata") as mock_read,
|
||||
mock.patch("py.utils.metadata_manager.MetadataManager") as mm,
|
||||
):
|
||||
mock_read.return_value = {"base_model": "Flux.1 D"}
|
||||
updated = await apply_metadata_updates(
|
||||
"/p.safetensors", {"base_model": "Flux.1 D"}
|
||||
)
|
||||
assert updated == []
|
||||
mm.save_metadata.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_fields(self):
|
||||
with (
|
||||
mock.patch("py.agent_cli.read_metadata") as mock_read,
|
||||
mock.patch("py.utils.metadata_manager.MetadataManager") as mm,
|
||||
):
|
||||
mm.save_metadata = mock.AsyncMock(return_value=True)
|
||||
mock_read.return_value = {
|
||||
"base_model": "", "modelDescription": "", "tags": [],
|
||||
}
|
||||
updated = await apply_metadata_updates(
|
||||
"/p.safetensors",
|
||||
{"base_model": "SDXL 1.0", "modelDescription": "A", "tags": ["flux"]},
|
||||
)
|
||||
assert sorted(updated) == sorted(["base_model", "modelDescription", "tags"])
|
||||
saved = mm.save_metadata.call_args[0][1]
|
||||
assert saved["base_model"] == "SDXL 1.0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_updates_noop(self):
|
||||
with (
|
||||
mock.patch("py.agent_cli.read_metadata"),
|
||||
mock.patch("py.utils.metadata_manager.MetadataManager") as mm,
|
||||
):
|
||||
updated = await apply_metadata_updates("/p.safetensors", {})
|
||||
assert updated == []
|
||||
mm.save_metadata.assert_not_called()
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# download_preview -- imports get_downloader + ExifUtils
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestDownloadPreview:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_url_returns_false(self, tmp_path):
|
||||
mp = tmp_path / "m.safetensors"
|
||||
mp.write_bytes(b"fake")
|
||||
assert await download_preview(str(mp), "") is False
|
||||
assert await download_preview(str(mp), " ") is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_download_and_optimise(self, tmp_path):
|
||||
mp = tmp_path / "t.safetensors"
|
||||
mp.write_bytes(b"fake")
|
||||
with (
|
||||
mock.patch("py.services.downloader.get_downloader") as get_dl,
|
||||
mock.patch("py.utils.exif_utils.ExifUtils") as exif,
|
||||
):
|
||||
dl = mock.AsyncMock()
|
||||
dl.download_to_memory = mock.AsyncMock(return_value=(True, b"raw", {}))
|
||||
get_dl.return_value = dl
|
||||
exif.optimize_image.return_value = (b"optimized_webp", {})
|
||||
result = await download_preview(str(mp), "https://ex.com/i.png")
|
||||
assert result is True
|
||||
assert (tmp_path / "t.webp").exists()
|
||||
assert (tmp_path / "t.webp").read_bytes() == b"optimized_webp"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_failure_returns_false(self, tmp_path):
|
||||
mp = tmp_path / "t.safetensors"
|
||||
mp.write_bytes(b"fake")
|
||||
with mock.patch("py.services.downloader.get_downloader") as get_dl:
|
||||
dl = mock.AsyncMock()
|
||||
dl.download_to_memory = mock.AsyncMock(return_value=(False, None, {}))
|
||||
dl.download_file = mock.AsyncMock(return_value=(False, None))
|
||||
get_dl.return_value = dl
|
||||
result = await download_preview(str(mp), "https://ex.com/i.png")
|
||||
assert result is False
|
||||
assert not (tmp_path / "t.webp").exists()
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# refresh_cache -- uses _find_scanner_for_model (ServiceRegistry)
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestRefreshCache:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_found_and_refreshed(self):
|
||||
scanner = MockScanner([{"file_path": "/some/path.safetensors"}])
|
||||
with (
|
||||
mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=scanner),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
),
|
||||
mock.patch("py.agent_cli.read_metadata") as mock_read,
|
||||
):
|
||||
mock_read.return_value = {"base_model": "SDXL 1.0"}
|
||||
result = await refresh_cache("/some/path.safetensors")
|
||||
assert result is True
|
||||
scanner.update_single_model_cache.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_in_any_scanner(self):
|
||||
scanner = MockScanner([])
|
||||
with mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=scanner),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
):
|
||||
result = await refresh_cache("/nonexistent/path.safetensors")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_metadata_returns_false(self):
|
||||
scanner = MockScanner([{"file_path": "/some/path.safetensors"}])
|
||||
with (
|
||||
mock.patch(
|
||||
"py.services.service_registry.ServiceRegistry",
|
||||
get_lora_scanner=mock.AsyncMock(return_value=scanner),
|
||||
get_checkpoint_scanner=mock.AsyncMock(return_value=None),
|
||||
get_embedding_scanner=mock.AsyncMock(return_value=None),
|
||||
),
|
||||
mock.patch("py.agent_cli.read_metadata") as mock_read,
|
||||
):
|
||||
mock_read.return_value = {}
|
||||
result = await refresh_cache("/some/path.safetensors")
|
||||
assert result is False
|
||||
103
tests/frontend/utils/hfUrlDetection.test.js
Normal file
103
tests/frontend/utils/hfUrlDetection.test.js
Normal file
@@ -0,0 +1,103 @@
|
||||
import { describe, it, expect } from 'vitest';
|
||||
import { DownloadManager } from '../../../static/js/managers/DownloadManager.js';
|
||||
|
||||
describe('DownloadManager.detectUrlType — HF URL detection', () => {
|
||||
|
||||
it('detects HF resolve URL with file', () => {
|
||||
const result = DownloadManager.detectUrlType(
|
||||
'https://huggingface.co/dx8152/Flux2-Klein-9B-Consistency/resolve/main/Flux2-Klein-9B-consistency-V2.safetensors'
|
||||
);
|
||||
expect(result).toEqual({
|
||||
type: 'hf-resolve',
|
||||
repo: 'dx8152/Flux2-Klein-9B-Consistency',
|
||||
revision: 'main',
|
||||
filename: 'Flux2-Klein-9B-consistency-V2.safetensors',
|
||||
});
|
||||
});
|
||||
|
||||
it('detects HF resolve URL with subdirectory file', () => {
|
||||
const result = DownloadManager.detectUrlType(
|
||||
'https://huggingface.co/user/repo/resolve/main/subdir/model.safetensors'
|
||||
);
|
||||
expect(result).toEqual({
|
||||
type: 'hf-resolve',
|
||||
repo: 'user/repo',
|
||||
revision: 'main',
|
||||
filename: 'subdir/model.safetensors',
|
||||
});
|
||||
});
|
||||
|
||||
it('detects HF repo URL (full URL)', () => {
|
||||
const result = DownloadManager.detectUrlType(
|
||||
'https://huggingface.co/dx8152/Flux2-Klein-9B-Consistency'
|
||||
);
|
||||
expect(result).toEqual({
|
||||
type: 'hf-repo',
|
||||
repo: 'dx8152/Flux2-Klein-9B-Consistency',
|
||||
});
|
||||
});
|
||||
|
||||
it('detects HF repo URL (bare user/repo)', () => {
|
||||
const result = DownloadManager.detectUrlType('dx8152/Flux2-Klein-9B-Consistency');
|
||||
expect(result).toEqual({
|
||||
type: 'hf-repo',
|
||||
repo: 'dx8152/Flux2-Klein-9B-Consistency',
|
||||
});
|
||||
});
|
||||
|
||||
it('detects HF repo URL with trailing slash', () => {
|
||||
const result = DownloadManager.detectUrlType(
|
||||
'https://huggingface.co/user/repo/'
|
||||
);
|
||||
expect(result).toEqual({
|
||||
type: 'hf-repo',
|
||||
repo: 'user/repo',
|
||||
});
|
||||
});
|
||||
|
||||
it('detects CivitAI URL', () => {
|
||||
const result = DownloadManager.detectUrlType(
|
||||
'https://civitai.com/models/123/some-model'
|
||||
);
|
||||
expect(result).toEqual({ type: 'civitai' });
|
||||
});
|
||||
|
||||
it('detects CivArchive URL', () => {
|
||||
const result = DownloadManager.detectUrlType(
|
||||
'https://civarchive.com/models/456'
|
||||
);
|
||||
expect(result).toEqual({ type: 'civitai' });
|
||||
});
|
||||
|
||||
it('detects direct HTTP URL', () => {
|
||||
const result = DownloadManager.detectUrlType(
|
||||
'https://example.com/file.zip'
|
||||
);
|
||||
expect(result).toEqual({ type: 'direct-http' });
|
||||
});
|
||||
|
||||
it('returns null for invalid input', () => {
|
||||
expect(DownloadManager.detectUrlType('')).toBeNull();
|
||||
expect(DownloadManager.detectUrlType(' ')).toBeNull();
|
||||
});
|
||||
|
||||
it('returns null for unrecognized path', () => {
|
||||
expect(DownloadManager.detectUrlType('justrandomtext')).toBeNull();
|
||||
});
|
||||
|
||||
it('prefers HF resolve over repo when both match', () => {
|
||||
const result = DownloadManager.detectUrlType(
|
||||
'https://huggingface.co/user/repo/resolve/main/file.safetensors'
|
||||
);
|
||||
expect(result?.type).toBe('hf-resolve');
|
||||
});
|
||||
|
||||
it('prefers CivitAI over HF when both match', () => {
|
||||
// CivitAI check comes first in detectUrlType
|
||||
// This URL should be detected as CivitAI, not HF
|
||||
const result = DownloadManager.detectUrlType(
|
||||
'https://civitai.com/models/123?huggingface.co/test/repo'
|
||||
);
|
||||
expect(result?.type).toBe('civitai');
|
||||
});
|
||||
});
|
||||
@@ -201,6 +201,45 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner):
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_list_models_filters_out_corrupted_entries(mock_service, mock_scanner):
|
||||
"""Corrupted cache entries (format_response returns None) must not appear
|
||||
in the response items nor cause a 500. See issue #730.
|
||||
"""
|
||||
mock_service.paginated_items = [
|
||||
{"file_path": "/tmp/good.safetensors", "name": "Good"},
|
||||
{"file_path": None, "name": "Corrupted"}, # triggers None from format_response
|
||||
{"file_path": "/tmp/also_good.safetensors", "name": "AlsoGood"},
|
||||
]
|
||||
|
||||
# Override format_response to return None for corrupted entries
|
||||
original_format = mock_service.format_response
|
||||
|
||||
async def conditional_format(item):
|
||||
if item.get("file_path") is None:
|
||||
return None
|
||||
return await original_format(item)
|
||||
|
||||
mock_service.format_response = conditional_format
|
||||
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
try:
|
||||
response = await client.get("/api/lm/test-models/list")
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
# Only the 2 non-corrupted entries should appear
|
||||
assert len(payload["items"]) == 2
|
||||
assert payload["items"][0]["name"] == "Good"
|
||||
assert payload["items"][1]["name"] == "AlsoGood"
|
||||
# None should never appear in the items list
|
||||
assert None not in payload["items"]
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_model_types_endpoint_returns_counts(mock_service, mock_scanner):
|
||||
mock_service.model_types = [
|
||||
{"type": "LoRa", "count": 3},
|
||||
|
||||
@@ -59,3 +59,180 @@ async def test_get_nightly_version_network_error_logs_warning(monkeypatch, caplo
|
||||
assert changelog == []
|
||||
assert "Unable to reach GitHub for nightly version" in caplog.text
|
||||
assert "Traceback" not in caplog.text
|
||||
|
||||
|
||||
def test_clean_excludes_covers_user_data_dirs():
|
||||
"""git clean must receive -e excludes for every user-managed dir."""
|
||||
excludes = update_routes._clean_excludes()
|
||||
assert "-e" in excludes # at least one exclude flag present
|
||||
for name in update_routes._PRESERVE_DIRS:
|
||||
assert name in excludes
|
||||
assert f"{name}/**" in excludes
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perform_git_update_preserves_user_dirs(monkeypatch, tmp_path):
|
||||
"""``git clean`` must be called with -e excludes for user data dirs.
|
||||
|
||||
Regression test for portable-mode updates wiping wildcards/, stats/,
|
||||
backups/, etc. because ``git clean -fd`` removed untracked, non-ignored
|
||||
directories.
|
||||
"""
|
||||
calls = []
|
||||
|
||||
class FakeGit:
|
||||
def reset(self, *args, **kwargs):
|
||||
calls.append(("reset", args))
|
||||
|
||||
def clean(self, *args, **kwargs):
|
||||
calls.append(("clean", args))
|
||||
|
||||
def checkout(self, *args, **kwargs):
|
||||
calls.append(("checkout", args))
|
||||
|
||||
class FakeRemote:
|
||||
def fetch(self):
|
||||
calls.append(("fetch", ()))
|
||||
|
||||
def pull(self, *args, **kwargs):
|
||||
calls.append(("pull", args))
|
||||
|
||||
class FakeRemotes:
|
||||
origin = FakeRemote()
|
||||
|
||||
class FakeCommit:
|
||||
hexsha = "abcdef123456"
|
||||
|
||||
class FakeHeads:
|
||||
def __getitem__(self, name):
|
||||
class Head:
|
||||
def checkout(self_inner):
|
||||
calls.append(("head-checkout", (name,)))
|
||||
return Head()
|
||||
|
||||
class FakeBranches:
|
||||
names = ["main"]
|
||||
|
||||
def __iter__(self):
|
||||
class B:
|
||||
name = "main"
|
||||
return iter([B()])
|
||||
|
||||
class FakeRepo:
|
||||
def __init__(self, path):
|
||||
calls.append(("repo", (path,)))
|
||||
|
||||
git = FakeGit()
|
||||
remotes = FakeRemotes()
|
||||
head = type("H", (), {"commit": FakeCommit()})()
|
||||
branches = FakeBranches()
|
||||
heads = FakeHeads()
|
||||
|
||||
def create_head(self, name, ref):
|
||||
calls.append(("create_head", (name, ref)))
|
||||
|
||||
class FakeGitModule:
|
||||
class Repo:
|
||||
def __new__(cls, path):
|
||||
return FakeRepo(path)
|
||||
|
||||
class exc:
|
||||
class GitError(Exception):
|
||||
pass
|
||||
|
||||
import builtins
|
||||
|
||||
real_import = builtins.__import__
|
||||
|
||||
def fake_import(name, *args, **kwargs):
|
||||
if name == "git":
|
||||
return FakeGitModule
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(builtins, "__import__", fake_import)
|
||||
|
||||
success, version = await update_routes.UpdateRoutes._perform_git_update(
|
||||
str(tmp_path), nightly=True
|
||||
)
|
||||
|
||||
assert success is True
|
||||
clean_calls = [c for c in calls if c[0] == "clean"]
|
||||
assert len(clean_calls) == 1
|
||||
clean_args = clean_calls[0][1]
|
||||
# Every preserved dir must be excluded via -e
|
||||
for name in update_routes._PRESERVE_DIRS:
|
||||
assert name in clean_args, f"{name} missing from git clean excludes"
|
||||
assert f"{name}/**" in clean_args, f"{name}/** missing from git clean excludes"
|
||||
# Ensure there's an -e before each name occurrence
|
||||
idx = clean_args.index(name)
|
||||
assert clean_args[idx - 1] == "-e"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perform_git_update_stable_preserves_user_dirs(monkeypatch, tmp_path):
|
||||
"""Stable (tag) update path must also pass -e excludes to git clean."""
|
||||
calls = []
|
||||
|
||||
class FakeGit:
|
||||
def reset(self, *args, **kwargs):
|
||||
calls.append(("reset", args))
|
||||
|
||||
def clean(self, *args, **kwargs):
|
||||
calls.append(("clean", args))
|
||||
|
||||
def checkout(self, *args, **kwargs):
|
||||
calls.append(("checkout", args))
|
||||
|
||||
class FakeRemote:
|
||||
def fetch(self):
|
||||
calls.append(("fetch", ()))
|
||||
|
||||
class FakeRemotes:
|
||||
origin = FakeRemote()
|
||||
|
||||
class FakeCommit:
|
||||
committed_datetime = "2026-01-01"
|
||||
|
||||
class FakeTag:
|
||||
name = "v9.9.9"
|
||||
commit = FakeCommit()
|
||||
|
||||
class FakeRepo:
|
||||
def __init__(self, path):
|
||||
calls.append(("repo", (path,)))
|
||||
|
||||
git = FakeGit()
|
||||
remotes = FakeRemotes()
|
||||
tags = [FakeTag()]
|
||||
|
||||
class FakeGitModule:
|
||||
class Repo:
|
||||
def __new__(cls, path):
|
||||
return FakeRepo(path)
|
||||
|
||||
class exc:
|
||||
class GitError(Exception):
|
||||
pass
|
||||
|
||||
import builtins
|
||||
|
||||
real_import = builtins.__import__
|
||||
|
||||
def fake_import(name, *args, **kwargs):
|
||||
if name == "git":
|
||||
return FakeGitModule
|
||||
return real_import(name, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(builtins, "__import__", fake_import)
|
||||
|
||||
success, version = await update_routes.UpdateRoutes._perform_git_update(
|
||||
str(tmp_path), nightly=False
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert version == "v9.9.9"
|
||||
clean_calls = [c for c in calls if c[0] == "clean"]
|
||||
assert len(clean_calls) == 1
|
||||
clean_args = clean_calls[0][1]
|
||||
for name in update_routes._PRESERVE_DIRS:
|
||||
assert name in clean_args, f"{name} missing from git clean excludes (stable)"
|
||||
|
||||
237
tests/services/test_llm_service.py
Normal file
237
tests/services/test_llm_service.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Tests for the LLMService."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.llm_service import LLMService
|
||||
from py.services.errors import LLMNotConfiguredError, LLMRateLimitError, LLMResponseError
|
||||
|
||||
|
||||
class MockSettings:
|
||||
"""Minimal settings mock for LLMService tests."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._data = {
|
||||
"llm_enabled": False,
|
||||
"llm_provider": "openai",
|
||||
"llm_api_key": "",
|
||||
"llm_api_base": "",
|
||||
"llm_model": "",
|
||||
}
|
||||
self._data.update(kwargs)
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self._data.get(key, default)
|
||||
|
||||
|
||||
class MockResponse:
|
||||
"""Mock aiohttp response."""
|
||||
|
||||
def __init__(self, status, json_data=None, text_data="", headers=None):
|
||||
self.status = status
|
||||
self._json_data = json_data
|
||||
self._text_data = text_data
|
||||
self.headers = headers or {}
|
||||
|
||||
async def json(self):
|
||||
return self._json_data
|
||||
|
||||
async def text(self):
|
||||
return self._text_data
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
class MockSession:
|
||||
"""Mock aiohttp ClientSession."""
|
||||
|
||||
def __init__(self, response):
|
||||
self._response = response
|
||||
self.closed = False
|
||||
|
||||
def post(self, url, json=None, headers=None):
|
||||
self.last_url = url
|
||||
self.last_json = json
|
||||
self.last_headers = headers
|
||||
return self._response
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_service():
|
||||
"""Create an LLMService with mock settings."""
|
||||
LLMService.reset_instance()
|
||||
settings = MockSettings(
|
||||
llm_enabled=True,
|
||||
llm_provider="openai",
|
||||
llm_api_key="sk-test-key",
|
||||
llm_api_base="",
|
||||
llm_model="gpt-4o-mini",
|
||||
)
|
||||
return LLMService(settings)
|
||||
|
||||
|
||||
class TestLLMServiceConfiguration:
|
||||
def test_is_configured_when_enabled_with_key_and_model(self, llm_service):
|
||||
assert llm_service.is_configured() is True
|
||||
|
||||
def test_not_configured_when_disabled(self):
|
||||
settings = MockSettings(
|
||||
llm_enabled=False, llm_api_key="sk-test", llm_model="gpt-4o"
|
||||
)
|
||||
service = LLMService(settings)
|
||||
# Lenient: model + API key is treated as configured even without
|
||||
# the toggle, because the user clearly intends to use the feature.
|
||||
assert service.is_configured() is True
|
||||
|
||||
def test_not_configured_without_model(self):
|
||||
settings = MockSettings(llm_enabled=True, llm_api_key="sk-test", llm_model="")
|
||||
service = LLMService(settings)
|
||||
assert service.is_configured() is False
|
||||
|
||||
def test_not_configured_without_api_key_for_openai(self):
|
||||
settings = MockSettings(llm_enabled=True, llm_api_key="", llm_model="gpt-4o")
|
||||
service = LLMService(settings)
|
||||
assert service.is_configured() is False
|
||||
|
||||
def test_ollama_configured_without_api_key(self):
|
||||
settings = MockSettings(
|
||||
llm_enabled=True, llm_provider="ollama", llm_api_key="", llm_model="llama3"
|
||||
)
|
||||
service = LLMService(settings)
|
||||
assert service.is_configured() is True
|
||||
|
||||
def test_resolve_api_base_openai_default(self, llm_service):
|
||||
assert llm_service._resolve_api_base("openai", "") == "https://api.openai.com/v1"
|
||||
|
||||
def test_resolve_api_base_ollama_default(self, llm_service):
|
||||
assert llm_service._resolve_api_base("ollama", "") == "http://localhost:11434/v1"
|
||||
|
||||
def test_resolve_api_base_custom_override(self, llm_service):
|
||||
assert llm_service._resolve_api_base("custom", "https://my.api.com/v1/") == "https://my.api.com/v1"
|
||||
|
||||
def test_ensure_configured_raises_when_disabled(self):
|
||||
settings = MockSettings(llm_enabled=False)
|
||||
service = LLMService(settings)
|
||||
with pytest.raises(LLMNotConfiguredError):
|
||||
service._ensure_configured()
|
||||
|
||||
def test_ensure_configured_raises_without_model(self):
|
||||
settings = MockSettings(llm_enabled=True, llm_api_key="sk-test", llm_model="")
|
||||
service = LLMService(settings)
|
||||
with pytest.raises(LLMNotConfiguredError):
|
||||
service._ensure_configured()
|
||||
|
||||
|
||||
class TestLLMServiceChatCompletion:
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_success(self, llm_service):
|
||||
mock_response = MockResponse(
|
||||
200,
|
||||
json_data={
|
||||
"choices": [{"message": {"content": "Hello!"}}],
|
||||
"usage": {"total_tokens": 10},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await llm_service.chat_completion(
|
||||
messages=[{"role": "user", "content": "Hi"}],
|
||||
)
|
||||
|
||||
assert result["content"] == "Hello!"
|
||||
assert result["usage"]["total_tokens"] == 10
|
||||
assert result["model"] == "gpt-4o-mini"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_raises_on_not_configured(self):
|
||||
settings = MockSettings(llm_enabled=False)
|
||||
service = LLMService(settings)
|
||||
with pytest.raises(LLMNotConfiguredError):
|
||||
await service.chat_completion(messages=[])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_raises_on_http_error(self, llm_service):
|
||||
mock_response = MockResponse(500, text_data="Internal Server Error")
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(LLMResponseError, match="HTTP 500"):
|
||||
await llm_service.chat_completion(messages=[])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_raises_on_rate_limit(self, llm_service):
|
||||
mock_response = MockResponse(429, text_data="Rate limited", headers={"Retry-After": "0"})
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(LLMRateLimitError):
|
||||
await llm_service.chat_completion(
|
||||
messages=[], retry_on_rate_limit=False
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_raises_on_bad_response_structure(self, llm_service):
|
||||
mock_response = MockResponse(200, json_data={"unexpected": "data"})
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(LLMResponseError, match="Unexpected LLM response"):
|
||||
await llm_service.chat_completion(messages=[])
|
||||
|
||||
|
||||
class TestLLMServiceChatCompletionJson:
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_json_parses_json(self, llm_service):
|
||||
mock_response = MockResponse(
|
||||
200,
|
||||
json_data={
|
||||
"choices": [{"message": {"content": '{"key": "value"}'}}],
|
||||
"usage": {},
|
||||
"model": "gpt-4o-mini",
|
||||
},
|
||||
)
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await llm_service.chat_completion_json(
|
||||
system_prompt="You are helpful.",
|
||||
user_prompt="Return JSON.",
|
||||
)
|
||||
|
||||
assert result == {"key": "value"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_json_raises_on_non_json(self, llm_service):
|
||||
# First attempt: non-JSON; second attempt (retry): also non-JSON
|
||||
mock_response = MockResponse(
|
||||
200,
|
||||
json_data={
|
||||
"choices": [{"message": {"content": "not json at all"}}],
|
||||
"usage": {},
|
||||
},
|
||||
)
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with mock.patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(LLMResponseError, match="could not be parsed as JSON"):
|
||||
await llm_service.chat_completion_json(
|
||||
system_prompt="test",
|
||||
user_prompt="test",
|
||||
)
|
||||
313
tests/services/test_post_processor.py
Normal file
313
tests/services/test_post_processor.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""Tests for the PostProcessor (py/services/agent/post_processor.py).
|
||||
|
||||
PostProcessor delegates all I/O to AgentCLI — these tests mock AgentCLI
|
||||
functions and verify the business logic (conditions, merges, dispatch).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.agent.post_processor import PostProcessor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def processor():
|
||||
return PostProcessor()
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# process() — routing
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestProcessDispatch:
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_skill_returns_error(self, processor):
|
||||
result = await processor.process(
|
||||
skill_name="nonexistent",
|
||||
model_path="/p.safetensors",
|
||||
llm_output={},
|
||||
metadata={},
|
||||
)
|
||||
assert result["success"] is False
|
||||
assert "nonexistent" in result["errors"][0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enrich_hf_metadata_routes_correctly(self, processor):
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview") as mock_dl,
|
||||
mock.patch("py.agent_cli.refresh_cache") as mock_ref,
|
||||
):
|
||||
mock_apply.return_value = ["metadata_source"]
|
||||
mock_dl.return_value = False
|
||||
|
||||
result = await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output={},
|
||||
metadata={"from_civitai": True},
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# enrich_hf_metadata — field-level logic
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestEnrichHfMetadata:
|
||||
"""Business logic tests for the enrich_hf_metadata post-processor."""
|
||||
|
||||
MIN_LLM_OUTPUT = {
|
||||
"base_model": "",
|
||||
"trigger_words": [],
|
||||
"description": "",
|
||||
"tags": [],
|
||||
"preview_url": "",
|
||||
"confidence": "low",
|
||||
}
|
||||
|
||||
# -- base_model ------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_model_overwrites_empty(self, processor):
|
||||
"""Empty current base_model → new value is applied."""
|
||||
llm = {**self.MIN_LLM_OUTPUT, "base_model": "Flux.1 D"}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"base_model": ""},
|
||||
)
|
||||
applied = mock_apply.call_args[0][1]
|
||||
assert applied["base_model"] == "Flux.1 D"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_model_does_not_overwrite_existing_civitai(self, processor):
|
||||
"""Existing base_model from CivitAI → not overwritten."""
|
||||
llm = {**self.MIN_LLM_OUTPUT, "base_model": "Flux.1 D"}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"base_model": "SDXL 1.0", "from_civitai": True},
|
||||
)
|
||||
# apply IS called (metadata_source, llm_enriched_at) but base_model not in it
|
||||
applied = mock_apply.call_args[0][1]
|
||||
assert "base_model" not in applied
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_model_overwrites_existing_hf_model(self, processor):
|
||||
"""Existing base_model from HF → overwritten (LLM is more reliable)."""
|
||||
llm = {**self.MIN_LLM_OUTPUT, "base_model": "Flux.1 D"}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"base_model": "SD 1.5", "from_civitai": False},
|
||||
)
|
||||
applied = mock_apply.call_args[0][1]
|
||||
assert applied["base_model"] == "Flux.1 D"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_base_model_skipped_when_llm_empty(self, processor):
|
||||
"""LLM returns empty base_model → nothing written."""
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=self.MIN_LLM_OUTPUT,
|
||||
metadata={"base_model": ""},
|
||||
)
|
||||
applied = mock_apply.call_args[0][1]
|
||||
assert "base_model" not in applied
|
||||
|
||||
# -- trigger_words ---------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_words_merged(self, processor):
|
||||
"""New trigger words written when current list is empty."""
|
||||
llm = {**self.MIN_LLM_OUTPUT, "trigger_words": ["trigger1", "trigger2"]}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"trainedWords": []},
|
||||
)
|
||||
applied = mock_apply.call_args[0][1]
|
||||
assert applied["trainedWords"] == ["trigger1", "trigger2"]
|
||||
|
||||
# -- description -----------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_description_set_when_empty(self, processor):
|
||||
llm = {**self.MIN_LLM_OUTPUT, "description": "A model description"}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"modelDescription": ""},
|
||||
)
|
||||
assert "modelDescription" in mock_apply.call_args[0][1]
|
||||
|
||||
# -- tags ------------------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tags_merged_and_deduplicated(self, processor):
|
||||
llm = {**self.MIN_LLM_OUTPUT, "tags": ["flux", "lora", "STYLE"]}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"tags": ["anime"], "from_civitai": False},
|
||||
)
|
||||
merged = mock_apply.call_args[0][1]["tags"]
|
||||
assert "anime" in merged
|
||||
assert "flux" in merged
|
||||
assert "style" in merged # lowercased
|
||||
# "lora" and "STYLE" → "lora" and "style"
|
||||
assert len(merged) == 4 # anime, flux, lora, style
|
||||
|
||||
# -- metadata_source & llm_enriched_at --------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_fields_always_set(self, processor):
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=self.MIN_LLM_OUTPUT,
|
||||
metadata={},
|
||||
)
|
||||
applied = mock_apply.call_args[0][1]
|
||||
assert applied["metadata_source"] == "agent:enrich_hf_metadata"
|
||||
assert "llm_enriched_at" in applied
|
||||
|
||||
# -- preview download ------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_downloaded_when_url_provided(self, processor):
|
||||
llm = {**self.MIN_LLM_OUTPUT, "preview_url": "https://ex.com/img.png"}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates") as mock_apply,
|
||||
mock.patch("py.agent_cli.download_preview") as mock_dl,
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
):
|
||||
mock_dl.return_value = True
|
||||
result = await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={},
|
||||
)
|
||||
assert result["preview_downloaded"] is True
|
||||
mock_dl.assert_awaited_once_with("/p.safetensors", "https://ex.com/img.png")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_skipped_when_exists(self, processor):
|
||||
"""If current_preview file exists on disk, skip download."""
|
||||
llm = {**self.MIN_LLM_OUTPUT, "preview_url": "https://ex.com/img.png"}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates"),
|
||||
mock.patch("py.agent_cli.download_preview") as mock_dl,
|
||||
mock.patch("py.agent_cli.refresh_cache"),
|
||||
mock.patch("os.path.exists", return_value=True),
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"preview_url": "/existing/preview.webp"},
|
||||
)
|
||||
mock_dl.assert_not_called()
|
||||
|
||||
# -- cache refresh ---------------------------------------------------
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_refreshed_when_updates_applied(self, processor):
|
||||
llm = {**self.MIN_LLM_OUTPUT, "base_model": "Flux.1 D"}
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates", return_value=["base_model"]),
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache") as mock_ref,
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=llm,
|
||||
metadata={"base_model": ""},
|
||||
)
|
||||
mock_ref.assert_awaited_once_with("/p.safetensors")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_not_refreshed_when_nothing_changed(self, processor):
|
||||
with (
|
||||
mock.patch("py.agent_cli.apply_metadata_updates", return_value=[]),
|
||||
mock.patch("py.agent_cli.download_preview", return_value=False),
|
||||
mock.patch("py.agent_cli.refresh_cache") as mock_ref,
|
||||
):
|
||||
await processor.process(
|
||||
skill_name="enrich_hf_metadata",
|
||||
model_path="/p.safetensors",
|
||||
llm_output=self.MIN_LLM_OUTPUT,
|
||||
metadata={"base_model": ""},
|
||||
)
|
||||
mock_ref.assert_not_called()
|
||||
|
||||
|
||||
# ======================================================================
|
||||
# Unit: _merge_tags
|
||||
# ======================================================================
|
||||
|
||||
|
||||
class TestMergeTags:
|
||||
def test_deduplicates_case_insensitive(self):
|
||||
existing = ["anime", "Flux"]
|
||||
new = ["flux", "LORA", "anime"]
|
||||
result = PostProcessor._merge_tags(existing, new)
|
||||
# All tags are lowercased (matching TagUpdateService behaviour)
|
||||
assert result == ["anime", "flux", "lora"]
|
||||
@@ -199,8 +199,107 @@ class TestEmbeddingServiceFormatResponse:
|
||||
"from_civitai": True,
|
||||
"civitai": {},
|
||||
}
|
||||
|
||||
|
||||
result = await embedding_service.format_response(embedding_data)
|
||||
|
||||
|
||||
assert result["sub_type"] == "embedding"
|
||||
assert "model_type" not in result # Removed in refactoring
|
||||
|
||||
|
||||
class TestFormatResponseCorruptedEntries:
|
||||
"""Test format_response handles corrupted cache entries gracefully (issue #730).
|
||||
|
||||
When cache rows have None/missing critical fields (e.g. from a partially
|
||||
written or legacy DB), format_response must NOT raise KeyError/AttributeError.
|
||||
Instead it returns None so the handler layer can filter the bad entry out
|
||||
instead of failing the entire listing request.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_scanner(self):
|
||||
scanner = MagicMock()
|
||||
scanner._hash_index = MagicMock()
|
||||
return scanner
|
||||
|
||||
@pytest.fixture
|
||||
def lora_service(self, mock_scanner):
|
||||
return LoraService(mock_scanner)
|
||||
|
||||
@pytest.fixture
|
||||
def checkpoint_service(self, mock_scanner):
|
||||
return CheckpointService(mock_scanner)
|
||||
|
||||
@pytest.fixture
|
||||
def embedding_service(self, mock_scanner):
|
||||
return EmbeddingService(mock_scanner)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lora_returns_none_on_missing_file_path(self, lora_service):
|
||||
"""format_response returns None when file_path is missing (corrupted row)."""
|
||||
lora_data = {
|
||||
"model_name": "Test LoRA",
|
||||
"file_name": "test_lora",
|
||||
"file_path": None, # corrupted: missing file_path
|
||||
"folder": "",
|
||||
"sha256": "abc123",
|
||||
"tags": [],
|
||||
"from_civitai": True,
|
||||
"civitai": {},
|
||||
}
|
||||
result = await lora_service.format_response(lora_data)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lora_handles_none_model_name_gracefully(self, lora_service):
|
||||
"""format_response should not crash when model_name is None (legacy DB row)."""
|
||||
lora_data = {
|
||||
"model_name": None, # NULL from old DB row
|
||||
"file_name": "test_lora",
|
||||
"file_path": "/models/test_lora.safetensors",
|
||||
"folder": "",
|
||||
"sha256": "abc123",
|
||||
"tags": [],
|
||||
"from_civitai": True,
|
||||
"civitai": {},
|
||||
}
|
||||
result = await lora_service.format_response(lora_data)
|
||||
# Should not raise; model_name falls back to file_name
|
||||
assert result is not None
|
||||
assert result["model_name"] == "test_lora"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkpoint_returns_none_on_missing_file_path(self, checkpoint_service):
|
||||
"""format_response returns None when file_path is missing (corrupted row)."""
|
||||
checkpoint_data = {
|
||||
"model_name": "Test",
|
||||
"file_name": "test",
|
||||
"file_path": "", # empty string == corrupted
|
||||
"folder": "",
|
||||
"sha256": "abc",
|
||||
"tags": [],
|
||||
"from_civitai": True,
|
||||
"civitai": {},
|
||||
"sub_type": "checkpoint",
|
||||
}
|
||||
result = await checkpoint_service.format_response(checkpoint_data)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embedding_handles_none_fields_gracefully(self, embedding_service):
|
||||
"""format_response should not crash when optional fields are None."""
|
||||
embedding_data = {
|
||||
"model_name": None,
|
||||
"file_name": None,
|
||||
"file_path": "/models/test.pt",
|
||||
"folder": None,
|
||||
"sha256": "abc",
|
||||
"tags": [],
|
||||
"from_civitai": True,
|
||||
"civitai": {},
|
||||
"sub_type": "embedding",
|
||||
}
|
||||
result = await embedding_service.format_response(embedding_data)
|
||||
assert result is not None
|
||||
assert result["file_path"] == "/models/test.pt"
|
||||
# model_name falls back to file_name which falls back to ""
|
||||
assert result["model_name"] == ""
|
||||
|
||||
@@ -200,52 +200,97 @@ def _setup_storage_paths(tmp_path, monkeypatch):
|
||||
return project_root, user_dir, user_settings_path
|
||||
|
||||
|
||||
def _populate_cache(root_dir, marker_name, db_text):
|
||||
cache_dir = root_dir / "model_cache"
|
||||
cache_dir.mkdir(exist_ok=True)
|
||||
marker_file = cache_dir / marker_name
|
||||
marker_file.write_text(marker_name, encoding="utf-8")
|
||||
(root_dir / "model_cache.sqlite").write_text(db_text, encoding="utf-8")
|
||||
def _populate_settings_dir(root_dir):
|
||||
"""Create test data for all managed subdirectories under a settings directory."""
|
||||
(root_dir / "cache" / "symlink").mkdir(parents=True, exist_ok=True)
|
||||
(root_dir / "cache" / "symlink" / "symlink_map.json").write_text(
|
||||
'{"migrated": true}', encoding="utf-8"
|
||||
)
|
||||
(root_dir / "backups").mkdir(parents=True, exist_ok=True)
|
||||
(root_dir / "backups" / "backup_test.zip").write_text(
|
||||
"backup", encoding="utf-8"
|
||||
)
|
||||
(root_dir / "logs").mkdir(parents=True, exist_ok=True)
|
||||
(root_dir / "logs" / "session.log").write_text("log", encoding="utf-8")
|
||||
(root_dir / "stats").mkdir(parents=True, exist_ok=True)
|
||||
(root_dir / "stats" / "stats.json").write_text(
|
||||
'{"stats": true}', encoding="utf-8"
|
||||
)
|
||||
(root_dir / "wildcards").mkdir(parents=True, exist_ok=True)
|
||||
(root_dir / "wildcards" / "test.txt").write_text("wildcard", encoding="utf-8")
|
||||
|
||||
|
||||
def test_switch_to_portable_mode_copies_cache(tmp_path, monkeypatch):
|
||||
def test_switch_to_portable_mode_copies_subdirectories(tmp_path, monkeypatch):
|
||||
project_root, user_dir, user_settings = _setup_storage_paths(tmp_path, monkeypatch)
|
||||
_populate_cache(user_dir, "user_marker.txt", "user_db")
|
||||
_populate_settings_dir(user_dir)
|
||||
|
||||
manager = SettingsManager()
|
||||
|
||||
manager.set("use_portable_settings", True)
|
||||
|
||||
assert manager.settings_file == str(project_root / "settings.json")
|
||||
marker_copy = project_root / "model_cache" / "user_marker.txt"
|
||||
assert marker_copy.read_text(encoding="utf-8") == "user_marker.txt"
|
||||
assert (project_root / "model_cache.sqlite").read_text(
|
||||
# Managed subdirectories should all be migrated
|
||||
assert (
|
||||
project_root / "cache" / "symlink" / "symlink_map.json"
|
||||
).read_text(encoding="utf-8") == '{"migrated": true}'
|
||||
assert (
|
||||
project_root / "backups" / "backup_test.zip"
|
||||
).read_text(encoding="utf-8") == "backup"
|
||||
assert (project_root / "logs" / "session.log").read_text(
|
||||
encoding="utf-8"
|
||||
) == "user_db"
|
||||
) == "log"
|
||||
assert (project_root / "stats" / "stats.json").read_text(
|
||||
encoding="utf-8"
|
||||
) == '{"stats": true}'
|
||||
assert (project_root / "wildcards" / "test.txt").read_text(
|
||||
encoding="utf-8"
|
||||
) == "wildcard"
|
||||
assert user_settings.exists()
|
||||
|
||||
|
||||
def test_switching_back_to_user_config_moves_cache(tmp_path, monkeypatch):
|
||||
def test_switching_back_to_user_config_moves_subdirectories(tmp_path, monkeypatch):
|
||||
project_root, user_dir, user_settings = _setup_storage_paths(tmp_path, monkeypatch)
|
||||
_populate_cache(user_dir, "user_marker.txt", "user_db")
|
||||
_populate_settings_dir(user_dir)
|
||||
|
||||
manager = SettingsManager()
|
||||
manager.set("use_portable_settings", True)
|
||||
|
||||
project_cache_dir = project_root / "model_cache"
|
||||
project_cache_dir.mkdir(exist_ok=True)
|
||||
(project_cache_dir / "project_marker.txt").write_text(
|
||||
"project_marker", encoding="utf-8"
|
||||
# Populate project-root managed subdirectories
|
||||
(project_root / "cache" / "model").mkdir(parents=True, exist_ok=True)
|
||||
(project_root / "cache" / "model" / "default.sqlite").write_text(
|
||||
"project_db", encoding="utf-8"
|
||||
)
|
||||
(project_root / "backups" / "project_backup.zip").write_text(
|
||||
"project_backup", encoding="utf-8"
|
||||
)
|
||||
(project_root / "logs" / "project.log").write_text(
|
||||
"project_log", encoding="utf-8"
|
||||
)
|
||||
(project_root / "stats" / "project_stats.json").write_text(
|
||||
'{"project": true}', encoding="utf-8"
|
||||
)
|
||||
(project_root / "wildcards" / "project.txt").write_text(
|
||||
"project_wildcard", encoding="utf-8"
|
||||
)
|
||||
(project_root / "model_cache.sqlite").write_text("project_db", encoding="utf-8")
|
||||
|
||||
manager.set("use_portable_settings", False)
|
||||
|
||||
assert manager.settings_file == str(user_settings)
|
||||
assert (user_dir / "model_cache" / "project_marker.txt").read_text(
|
||||
assert (user_dir / "cache" / "model" / "default.sqlite").read_text(
|
||||
encoding="utf-8"
|
||||
) == "project_marker"
|
||||
assert (user_dir / "model_cache.sqlite").read_text(encoding="utf-8") == "project_db"
|
||||
) == "project_db"
|
||||
assert (user_dir / "backups" / "project_backup.zip").read_text(
|
||||
encoding="utf-8"
|
||||
) == "project_backup"
|
||||
assert (user_dir / "logs" / "project.log").read_text(
|
||||
encoding="utf-8"
|
||||
) == "project_log"
|
||||
assert (user_dir / "stats" / "project_stats.json").read_text(
|
||||
encoding="utf-8"
|
||||
) == '{"project": true}'
|
||||
assert (user_dir / "wildcards" / "project.txt").read_text(
|
||||
encoding="utf-8"
|
||||
) == "project_wildcard"
|
||||
|
||||
|
||||
def test_download_path_template_parses_json_string(manager):
|
||||
|
||||
88
tests/services/test_skill_registry.py
Normal file
88
tests/services/test_skill_registry.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Tests for the SkillRegistry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.agent.skill_registry import SkillRegistry
|
||||
from py.services.agent.skill_definition import SkillDefinition, SkillPermissions
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def registry():
|
||||
"""Create a SkillRegistry with the real skills directory."""
|
||||
SkillRegistry.reset_instance()
|
||||
reg = SkillRegistry()
|
||||
reg._discover()
|
||||
return reg
|
||||
|
||||
|
||||
class TestSkillRegistryDiscovery:
|
||||
def test_discovers_enrich_hf_metadata_skill(self, registry):
|
||||
skills = registry.list_skills()
|
||||
assert len(skills) >= 1
|
||||
skill = registry.get_skill("enrich_hf_metadata")
|
||||
assert skill is not None
|
||||
assert skill.name == "enrich_hf_metadata"
|
||||
assert skill.llm_required is True
|
||||
|
||||
def test_skill_has_correct_model_type_filter(self, registry):
|
||||
skill = registry.get_skill("enrich_hf_metadata")
|
||||
# model_type_filter was removed from SKILL.md — defaults to None (all types)
|
||||
assert skill.model_type_filter is None
|
||||
|
||||
def test_skill_has_permissions(self, registry):
|
||||
skill = registry.get_skill("enrich_hf_metadata")
|
||||
assert skill.permissions.write_metadata is True
|
||||
assert skill.permissions.write_previews is True
|
||||
# network_domains defaults to () since permissions block was removed
|
||||
|
||||
def test_get_skill_returns_none_for_unknown(self, registry):
|
||||
assert registry.get_skill("nonexistent_skill") is None
|
||||
|
||||
|
||||
class TestSkillRegistryLoading:
|
||||
def test_load_prompt_returns_content(self, registry):
|
||||
prompt = registry.load_prompt("enrich_hf_metadata")
|
||||
assert isinstance(prompt, str)
|
||||
assert len(prompt) > 100
|
||||
assert "base_model" in prompt
|
||||
assert "trigger_words" in prompt
|
||||
|
||||
def test_load_prompt_raises_for_unknown_skill(self, registry):
|
||||
with pytest.raises((FileNotFoundError, ValueError)):
|
||||
registry.load_prompt("nonexistent")
|
||||
|
||||
|
||||
class TestSkillDefinition:
|
||||
def test_applies_to_model_type_with_filter(self):
|
||||
sd = SkillDefinition(
|
||||
name="test",
|
||||
title="Test",
|
||||
description="",
|
||||
llm_required=False,
|
||||
model_type_filter=["lora"],
|
||||
)
|
||||
assert sd.applies_to_model_type("lora") is True
|
||||
assert sd.applies_to_model_type("checkpoint") is False
|
||||
|
||||
def test_applies_to_model_type_without_filter(self):
|
||||
sd = SkillDefinition(
|
||||
name="test",
|
||||
title="Test",
|
||||
description="",
|
||||
llm_required=False,
|
||||
model_type_filter=None,
|
||||
)
|
||||
assert sd.applies_to_model_type("lora") is True
|
||||
assert sd.applies_to_model_type("checkpoint") is True
|
||||
|
||||
|
||||
class TestSkillPermissions:
|
||||
def test_defaults(self):
|
||||
sp = SkillPermissions()
|
||||
assert sp.write_metadata is True
|
||||
assert sp.write_previews is True
|
||||
assert sp.network_domains == ()
|
||||
Reference in New Issue
Block a user