Files
ComfyUI-Lora-Manager/py/routes/handlers/hf_handlers.py
Will Miao cf898da193 feat(agent): add LLM-powered metadata enrichment system with AgentCLI and PostProcessor
Introduce an agent skill framework for LLM-driven metadata enrichment:

- AgentCLI (py/agent_cli/): in-process wrappers around internal services
  using standard relative imports, eliminating the need for sys.path hacks
- LLMService: centralized BYOK (bring-your-own-key) LLM client supporting
  OpenAI, Ollama, and custom OpenAI-compatible endpoints
- PostProcessor: deterministic engine that applies LLM output via AgentCLI
  (replaces old handler.py + _BASE_MODEL_ALIASES approach)
- SkillRegistry: filesystem-based skill discovery (skill.yaml + prompt.md)
- AgentService: orchestrates skill execution with WebSocket progress
- Frontend AgentManager: WebSocket listeners, skill execution, config UI
- Context menu entries (single + bulk) for "Enrich Metadata (Agent)"
- Settings UI for AI Provider configuration (BYOK)
- Full i18n support across 9 locales

Bug fixes found during review:
- aiohttp.web.json_response: status_code= -> status=
- settings_modal cancelEditApiKey: wrong argument position
- AgentManager.isLlmConfigured: allow Ollama without API key
- PostProcessor._merge_tags: lowercase all tags to match TagUpdateService
2026-07-02 21:27:01 +08:00

418 lines
17 KiB
Python

"""Handlers for Hugging Face model listing and download.
Minimal MVP implementation — uses direct HTTP to the HF API for file
listing and the project's existing aiohttp-based Downloader for
downloading. No huggingface_hub dependency required.
"""
from __future__ import annotations
import json
import logging
import os
import re
from typing import Any
import aiohttp
from aiohttp import web
from ...config import config
from ...services.downloader import (
DownloadProgress,
get_downloader,
)
from ...services.aria2_downloader import Aria2Downloader
from ...services.settings_manager import get_settings_manager
from ...services.service_registry import ServiceRegistry
from ...services.websocket_manager import ws_manager
from ...utils.constants import MODEL_FILE_EXTENSIONS
from ...utils.metadata_manager import MetadataManager
from ...utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Fallback model class and ServiceRegistry getter name returned by
# _infer_model_type() when a root matches none of the configured roots.
_DEFAULT_MODEL_CLASS = LoraMetadata
_DEFAULT_SCANNER_GETTER = "get_lora_scanner"
# Shared aiohttp session for HF API calls (created on first use by
# _get_hf_api_session() and released by close_hf_api_session()).
_hf_api_session: aiohttp.ClientSession | None = None
async def _get_hf_api_session() -> aiohttp.ClientSession:
    """Return the module-wide session used for HF API calls.

    The session is built lazily: a new one is created when none exists yet
    or when the previous one has been closed; otherwise the existing one is
    reused.
    """
    global _hf_api_session  # the module-level name is rebound below

    existing = _hf_api_session
    if existing is not None and not existing.closed:
        return existing

    default_headers = {"User-Agent": "ComfyUI-LoRA-Manager/1.0"}
    _hf_api_session = aiohttp.ClientSession(
        headers=default_headers,
        timeout=aiohttp.ClientTimeout(total=30),
    )
    return _hf_api_session
async def close_hf_api_session() -> None:
    """Release the shared HF API session.

    Closes the session when one exists and is still open, then drops the
    module-level reference. Safe to call when no session was ever created.
    """
    global _hf_api_session  # the module-level name is rebound below

    session = _hf_api_session
    needs_close = session is not None and not session.closed
    if needs_close:
        await session.close()
    _hf_api_session = None
def _infer_model_type(model_root: str) -> tuple[Any, str]:
    """Map ``model_root`` to a metadata class and a scanner getter name.

    ``model_root`` is compared (after ``os.path.normpath`` and conversion to
    forward slashes) against the root lists exposed by ``config``. LoRA roots
    are checked first, then checkpoint/UNet roots, then embedding roots.
    Matching on the configured roots avoids guessing the type from
    substrings of the path.

    Returns:
        ``(model_class, scanner_getter_name)``. When no configured root
        matches, a warning is logged and the LoRA defaults are returned.
    """

    def _as_posix(path: str) -> str:
        # Normalise and use "/" so both sides compare in the same form.
        return os.path.normpath(path).replace(os.sep, "/")

    wanted = _as_posix(model_root)

    # LoRA roots
    lora_roots = (config.loras_roots or []) + (config.extra_loras_roots or [])
    if any(_as_posix(root) == wanted for root in lora_roots):
        return LoraMetadata, "get_lora_scanner"

    # Checkpoint / UNet roots
    checkpoint_roots = (
        (config.checkpoints_roots or [])
        + (config.extra_checkpoints_roots or [])
        + (config.unet_roots or [])
        + (config.extra_unet_roots or [])
    )
    if any(_as_posix(root) == wanted for root in checkpoint_roots):
        return CheckpointMetadata, "get_checkpoint_scanner"

    # Embedding roots
    embedding_roots = (config.embeddings_roots or []) + (
        config.extra_embeddings_roots or []
    )
    if any(_as_posix(root) == wanted for root in embedding_roots):
        return EmbeddingMetadata, "get_embedding_scanner"

    # Nothing matched: fall back to the LoRA defaults and leave a trace.
    logger.warning(
        "Could not determine model type for root '%s'; defaulting to LoRA",
        model_root,
    )
    return _DEFAULT_MODEL_CLASS, _DEFAULT_SCANNER_GETTER
def _relative_model_folder(dest_path: str, model_root: str) -> str:
    """Return the "/"-separated folder of ``dest_path`` below ``model_root``.

    Containment is decided per path component (``os.path.commonpath``), so a
    sibling directory that merely shares a string prefix with ``model_root``
    is not treated as being inside it. Both the lexical and the
    symlink-resolved forms of each path are tried, because the caller may
    have built ``dest_path`` from a resolved root while ``model_root`` is
    still the unresolved spelling (or the other way round).

    Returns ``""`` when the file sits directly in the root, when
    ``model_root`` is not absolute, or when no containment can be shown.
    """
    if not os.path.isabs(model_root):
        return ""

    parent_dir = os.path.dirname(dest_path)
    root_forms = (os.path.normpath(model_root), os.path.realpath(model_root))
    parent_forms = (os.path.normpath(parent_dir), os.path.realpath(parent_dir))

    for root in root_forms:
        for parent in parent_forms:
            try:
                common = os.path.commonpath([root, parent])
            except ValueError:
                # Mixed absolute/relative paths or different drives.
                continue
            if os.path.normcase(common) != os.path.normcase(root):
                continue
            rel = os.path.relpath(parent, root)
            return "" if rel == "." else rel.replace(os.sep, "/")
    return ""


async def _save_hf_metadata(dest_path: str, repo: str, model_root: str) -> None:
    """Write metadata for a downloaded HF file and register it with its scanner.

    Steps:
      1. Build default metadata via ``MetadataManager.create_default_metadata``
         using the model class inferred from ``model_root``.
      2. Overlay the HF-specific fields (``hf_url``; ``from_civitai = False``).
      3. Persist it with ``MetadataManager.save_metadata``.
      4. Add the entry to the matching scanner's cache, together with the
         file's folder relative to ``model_root``.

    This is best-effort: any failure is logged as a warning (with traceback)
    and never propagated, so a finished download is not reported as failed
    because of a metadata problem.

    Args:
        dest_path: Path of the downloaded model file.
        repo: HF repo id (``owner/name``) the file came from.
        model_root: Root directory the download was requested for.
    """
    try:
        hf_url = f"https://huggingface.co/{repo}"
        model_class, scanner_getter_name = _infer_model_type(model_root)

        # 1. Create the base metadata object for the file.
        metadata = await MetadataManager.create_default_metadata(
            dest_path, model_class=model_class
        )
        if metadata is None:
            logger.warning("create_default_metadata returned None for %s", dest_path)
            return

        # 2. Overlay HF-specific fields
        metadata._unknown_fields["hf_url"] = hf_url
        metadata.from_civitai = False  # HF models are not from CivitAI

        # 3. Save metadata
        await MetadataManager.save_metadata(dest_path, metadata)
        logger.info("Saved HF metadata (with hf_url) for %s", dest_path)

        # 4. Folder of the file relative to its root, for the scanner cache.
        folder = _relative_model_folder(dest_path, model_root)

        # 5. Add to the scanner cache so the model shows up without a rescan.
        scanner_getter = getattr(ServiceRegistry, scanner_getter_name, None)
        if scanner_getter is None:
            return
        scanner = await scanner_getter()
        if scanner is None:
            return
        metadata_dict = metadata.to_dict()
        metadata_dict["hf_url"] = hf_url
        await scanner.add_model_to_cache(metadata_dict, folder)
        logger.info("Added %s to scanner cache (folder=%s)", dest_path, folder)
    except Exception as exc:
        # Deliberate best-effort boundary; keep the traceback for diagnosis.
        logger.warning(
            "Failed to save HF metadata for %s: %s", dest_path, exc, exc_info=True
        )
class HfHandler:
    """Handle Hugging Face model browsing and download."""

    # ``owner/name``: two non-empty segments of letters, digits, "_", "." or "-".
    _REPO_ID_PATTERN = re.compile(r"[a-zA-Z0-9_.-]+/[a-zA-Z0-9_.-]+")

    @classmethod
    def _is_valid_repo(cls, repo: str) -> bool:
        """Return True when ``repo`` is a well-formed ``owner/name`` id."""
        if not cls._REPO_ID_PATTERN.fullmatch(repo):
            return False
        # "." and ".." satisfy the character class but are path navigation.
        return all(part not in (".", "..") for part in repo.split("/"))

    @staticmethod
    def _is_within(path: str, root: str) -> bool:
        """Return True when ``path`` equals ``root`` or lies below it.

        Compares whole path components, so ``/a/bc`` is not inside ``/a/b``.
        """
        try:
            common = os.path.commonpath([root, path])
        except ValueError:
            # Mixed absolute/relative paths, different drives, or empty input.
            return False
        return os.path.normcase(common) == os.path.normcase(root)

    @staticmethod
    def _text_field(payload: dict[str, Any], key: str, default: str = "") -> str:
        """Return ``payload[key]`` stripped; ``default`` if missing, blank or not a string."""
        value = payload.get(key)
        if not isinstance(value, str):
            return default
        return value.strip() or default

    async def get_hf_repo_files(self, request: web.Request) -> web.Response:
        """List model-weight files of a HF repo together with their sizes.

        Reads the ``repo`` query parameter (``owner/name``), queries the HF
        tree endpoint for the ``main`` revision and keeps only entries whose
        extension is in ``MODEL_FILE_EXTENSIONS``. When an entry has no
        top-level size, the size from its ``lfs`` record is used.

        Returns:
            JSON list of ``{"filename", "size"}`` sorted by size, largest
            first. Errors are returned as ``{"error": ...}`` with status 400
            (bad ``repo``), 404 (unknown repo), the upstream status (other
            HF errors) or 502 (request failure / unexpected payload).
        """
        repo = request.query.get("repo", "").strip()
        # The value is interpolated into the API URL, so apply the same
        # strict validation the download endpoint uses.
        if not self._is_valid_repo(repo):
            return web.json_response(
                {"error": "Missing or invalid 'repo' parameter (expected user/repo)"},
                status=400,
            )

        url = f"https://huggingface.co/api/models/{repo}/tree/main"
        try:
            session = await _get_hf_api_session()
            async with session.get(url) as resp:
                if resp.status == 404:
                    return web.json_response(
                        {"error": f"Repo '{repo}' not found"}, status=404
                    )
                if resp.status != 200:
                    text = await resp.text()
                    return web.json_response(
                        {"error": f"HF API error {resp.status}: {text[:200]}"},
                        status=resp.status,
                    )
                tree = await resp.json()
        except Exception as exc:
            logger.error("Failed to fetch HF repo files: %s", exc)
            return web.json_response({"error": str(exc)}, status=502)

        if not isinstance(tree, list):
            logger.error(
                "Unexpected HF tree response for %s: %s", repo, type(tree).__name__
            )
            return web.json_response(
                {"error": "Unexpected response from HF API"}, status=502
            )

        files: list[dict[str, Any]] = []
        for entry in tree:
            if not isinstance(entry, dict):
                continue
            path = entry.get("path") or ""
            if not isinstance(path, str):
                continue
            ext = os.path.splitext(path)[1].lower()
            if ext not in MODEL_FILE_EXTENSIONS:
                continue
            size = entry.get("size") or 0
            if not size:
                # Fall back to the size recorded in the entry's LFS info.
                lfs = entry.get("lfs")
                if isinstance(lfs, dict):
                    size = lfs.get("size") or 0
            files.append({"filename": path, "size": size})

        files.sort(key=lambda f: f["size"], reverse=True)
        return web.json_response(files)

    async def download_hf_model(self, request: web.Request) -> web.Response:
        """Download a single file from Hugging Face into the model directory.

        POST JSON body::

            {
                "repo": "dx8152/Flux2-Klein-9B-Consistency",
                "filename": "Flux2-Klein-9B-consistency-V2.safetensors",
                "revision": "main",
                "model_root": "loras",
                "relative_path": "",
                "use_default_paths": false,
                "download_id": "optional-batch-id"
            }

        If ``download_id`` is provided, real-time progress (bytes, speed,
        percentage) is broadcast via the WebSocket progress system, matching
        the CivitAI download experience.

        Respects the ``download_backend`` setting (``aria2`` or ``default``).

        Returns:
            ``{"success": True, "message", "path"}`` on success or when a
            non-empty file already exists at the destination;
            ``{"error": ...}`` with status 400 for invalid input;
            ``{"success": False, "error": ...}`` with status 500 when the
            directory cannot be created or the download fails.
        """
        # Function-scope import keeps this block self-contained.
        from urllib.parse import quote

        try:
            payload = await request.json()
        except ValueError:
            # Covers json.JSONDecodeError and body decoding errors.
            return web.json_response({"error": "Invalid JSON"}, status=400)
        if not isinstance(payload, dict):
            return web.json_response({"error": "Invalid JSON"}, status=400)

        repo = self._text_field(payload, "repo")
        filename = self._text_field(payload, "filename")
        revision = self._text_field(payload, "revision", "main")
        model_root = self._text_field(payload, "model_root")
        relative_path = self._text_field(payload, "relative_path")
        use_default_paths = bool(payload.get("use_default_paths", False))
        raw_download_id = payload.get("download_id")
        download_id: str | None = str(raw_download_id) if raw_download_id else None

        logger.info(
            "download_hf_model: repo=%s file=%s root=%s download_id=%s",
            repo, filename, model_root, download_id,
        )

        if not repo or not filename:
            return web.json_response(
                {"error": "Missing required fields: 'repo' and 'filename'"}, status=400
            )

        # Validate repo format — must be user/repo_name
        if not self._is_valid_repo(repo):
            return web.json_response({"error": f"Invalid repo format: {repo}"}, status=400)
        author, repo_name = repo.split("/", 1)

        # Validate filename — must not contain path separators or ..
        if "/" in filename or "\\" in filename or ".." in filename:
            return web.json_response({"error": "Invalid filename"}, status=400)

        # The revision becomes one URL path segment; "." / ".." would be
        # treated as path navigation even after percent-encoding.
        if revision in (".", ".."):
            return web.json_response({"error": "Invalid revision"}, status=400)

        # Validate relative_path — must not be absolute or escape base directory
        if relative_path:
            if os.path.isabs(relative_path):
                return web.json_response(
                    {"error": "relative_path must not be absolute"}, status=400
                )
            if ".." in relative_path.split("/") or "\\" in relative_path:
                return web.json_response({"error": "Invalid relative_path"}, status=400)

        # Resolve model_root; a relative value is taken below <cwd>/models.
        if os.path.isabs(model_root):
            resolved_model_root = os.path.realpath(model_root)
        else:
            resolved_model_root = os.path.realpath(
                os.path.join(os.getcwd(), "models", model_root)
            )

        # Verify model_root is within a configured scanner root
        configured_root_lists = (
            config.loras_roots,
            config.extra_loras_roots,
            config.checkpoints_roots,
            config.extra_checkpoints_roots,
            config.unet_roots,
            config.extra_unet_roots,
            config.embeddings_roots,
            config.extra_embeddings_roots,
        )
        allowed_roots = {
            os.path.realpath(root)
            for root_list in configured_root_lists
            for root in (root_list or [])
        }
        if not any(self._is_within(resolved_model_root, root) for root in allowed_roots):
            logger.warning("Invalid model_root rejected: %s", model_root)
            return web.json_response(
                {"error": f"Invalid model_root: {model_root}"}, status=400
            )

        base_dir = resolved_model_root
        if use_default_paths:
            target_dir = os.path.join(base_dir, "huggingface", author, repo_name)
        elif relative_path:
            target_dir = os.path.join(base_dir, relative_path)
        else:
            target_dir = base_dir

        # The target directory must stay lexically below the model root.
        # os.path.join() discards base_dir when the second part carries its
        # own drive (e.g. "D:foo" on Windows), which isabs() does not flag.
        # Symlinked sub-folders inside the root are intentionally not
        # resolved here, so they keep working.
        if not self._is_within(os.path.abspath(target_dir), base_dir):
            logger.warning("Path traversal blocked: %s escapes %s", target_dir, base_dir)
            return web.json_response({"error": "Invalid relative_path"}, status=400)

        dest_path = os.path.join(target_dir, filename)

        # Resolve symlinks and check the file itself stays inside target_dir
        real_dest = os.path.realpath(dest_path)
        real_base = os.path.realpath(target_dir)
        if real_dest == real_base or not self._is_within(real_dest, real_base):
            logger.warning("Path traversal blocked: %s -> %s", dest_path, real_dest)
            return web.json_response({"error": "Path traversal detected"}, status=400)

        # Only touch the filesystem once the request is fully validated.
        try:
            os.makedirs(target_dir, exist_ok=True)
        except OSError as exc:
            logger.error("Could not create target directory %s: %s", target_dir, exc)
            return web.json_response(
                {"success": False, "error": str(exc)}, status=500
            )

        # Check if already exists (simple skip)
        if os.path.exists(dest_path) and os.path.getsize(dest_path) > 0:
            logger.info("download_hf_model: file already exists, skipping — %s", dest_path)
            return web.json_response({
                "success": True,
                "message": f"File already exists: {dest_path}",
                "path": dest_path,
            })

        # Build HF resolve URL. The revision is encoded as a single path
        # segment ("/" included) and the filename is percent-encoded so
        # characters such as "#", "?" or "%" cannot alter the URL.
        resolve_url = (
            f"https://huggingface.co/{repo}/resolve/"
            f"{quote(revision, safe='')}/{quote(filename)}"
        )

        # Set up progress callback if download_id is provided
        progress_callback = None
        if download_id:

            async def _progress_callback(
                progress: float | DownloadProgress,
                snapshot: DownloadProgress | None = None,
            ) -> None:
                # Either argument may carry the detailed DownloadProgress;
                # the first positional one wins when both do.
                if isinstance(progress, DownloadProgress):
                    metrics = progress
                elif isinstance(snapshot, DownloadProgress):
                    metrics = snapshot
                else:
                    metrics = None

                if metrics is not None:
                    percent = metrics.percent_complete
                else:
                    percent = float(progress)

                broadcast: dict[str, Any] = {
                    "status": "progress",
                    "progress": round(percent),
                }
                if metrics is not None:
                    broadcast["bytes_downloaded"] = metrics.bytes_downloaded
                    broadcast["total_bytes"] = metrics.total_bytes
                    broadcast["bytes_per_second"] = metrics.bytes_per_second
                await ws_manager.broadcast_download_progress(download_id, broadcast)

            progress_callback = _progress_callback

        # Respect download backend setting (aria2 vs default)
        download_backend = get_settings_manager().get("download_backend", "default")
        use_aria2 = download_backend == "aria2"
        log_label = "HF download (aria2)" if use_aria2 else "HF download"

        try:
            if use_aria2:
                aria2 = await Aria2Downloader.get_instance()
                success, result = await aria2.download_file(
                    url=resolve_url,
                    save_path=dest_path,
                    download_id=download_id or f"hf_{repo}_{filename}",
                    progress_callback=progress_callback,
                )
                saved_path = dest_path
                fallback_error = "aria2 download failed"
            else:
                # Default: use built-in aiohttp Downloader
                downloader = await get_downloader()
                success, result = await downloader.download_file(
                    url=resolve_url,
                    save_path=dest_path,
                    use_auth=False,
                    allow_resume=True,
                    progress_callback=progress_callback,
                )
                saved_path = result
                fallback_error = "Download failed"
        except Exception as exc:
            logger.error("%s failed: %s", log_label, exc)
            return web.json_response(
                {"success": False, "error": str(exc)}, status=500
            )

        if not success:
            return web.json_response(
                {"success": False, "error": result or fallback_error},
                status=500,
            )

        # An absolute model_root is passed through as received. A relative
        # one is replaced by its resolved directory, since the unresolved
        # value is not an absolute path that could match a root.
        metadata_root = model_root if os.path.isabs(model_root) else base_dir
        await _save_hf_metadata(dest_path, repo, metadata_root)
        return web.json_response({
            "success": True,
            "message": f"Downloaded to {saved_path}",
            "path": saved_path,
        })