Files
ComfyUI-Lora-Manager/py/services/recipes/persistence_service.py
Will Miao 3f646aa0c9 feat: add recipe root directory and move recipe endpoints
- Add GET /api/lm/recipes/roots endpoint to retrieve recipe root directories
- Add POST /api/lm/recipe/move endpoint to move recipes between directories
- Register new endpoints in route definitions
- Implement error handling for both new endpoints with proper status codes
- Enable recipe management operations for better file organization
2025-12-23 09:13:57 +08:00

517 lines
20 KiB
Python

"""Services encapsulating recipe persistence workflows."""
from __future__ import annotations
import base64
import json
import os
import re
import shutil
import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Optional
from ...config import config
from ...utils.utils import calculate_recipe_fingerprint
from .errors import RecipeNotFoundError, RecipeValidationError
@dataclass(eq=True, frozen=True)
class PersistenceResult:
    """Immutable value returned by every persistence operation.

    Holds the response body produced by the operation together with a
    status code, which is ``200`` unless the producer overrides it.
    """

    # Response body describing the outcome of the operation.
    payload: dict[str, Any]
    # Status code accompanying the payload; success (200) by default.
    status: int = 200
class RecipePersistenceService:
    """Coordinate recipe persistence tasks across storage and caches.

    A recipe is stored as a preview image plus a ``<recipe_id>.recipe.json``
    file inside ``recipe_scanner.recipes_dir`` (or a sub-folder of it after a
    move). Every public coroutine receives the ``recipe_scanner`` explicitly
    and calls its async helpers (``add_recipe``, ``remove_recipe``,
    ``bulk_remove``, ``update_recipe_metadata``, ...) so that the scanner's
    state follows what was written to disk.
    """

    # Upload extensions stored verbatim: no image optimisation and no recipe
    # metadata embedded into the file. Compared case-insensitively.
    _VIDEO_EXTENSIONS = (".mp4", ".webm")

    # Fields accepted by ``update_recipe``; at least one must be present.
    _UPDATABLE_FIELDS = ("title", "tags", "source_path", "preview_nsfw_level")

    # Generation parameters copied from ``raw_metadata`` when ``save_recipe``
    # receives no ``gen_params``.
    _RAW_GEN_PARAM_KEYS = (
        "prompt",
        "negative_prompt",
        "steps",
        "sampler",
        "cfg_scale",
        "seed",
        "size",
        "clip_skip",
    )

    # Transient/local-only checkpoint fields that are never persisted.
    _TRANSIENT_CHECKPOINT_KEYS = (
        "existsLocally",
        "localPath",
        "thumbnailUrl",
        "size",
        "downloadUrl",
    )

    # ``<lora:name:strength>`` tags in a widget's LoRA stack string. The
    # strength group runs up to ``>`` and can therefore contain additional
    # colon-separated values (see ``_parse_lora_strength``).
    _LORA_TAG_PATTERN = re.compile(r"<lora:([^:]+):([^>]+)>")

    def __init__(
        self,
        *,
        exif_utils,
        card_preview_width: int,
        logger,
    ) -> None:
        """Store collaborators.

        Args:
            exif_utils: Helper providing ``optimize_image`` and
                ``append_recipe_metadata``.
            card_preview_width: Target width passed to ``optimize_image``.
            logger: Logger used for non-fatal warnings.
        """
        self._exif_utils = exif_utils
        self._card_preview_width = card_preview_width
        self._logger = logger

    async def save_recipe(
        self,
        *,
        recipe_scanner,
        image_bytes: bytes | None,
        image_base64: str | None,
        name: str | None,
        tags: Iterable[str],
        metadata: Optional[dict[str, Any]],
        extension: str | None = None,
    ) -> PersistenceResult:
        """Persist a user uploaded recipe.

        Writes the preview (optimised to webp unless it is a video) and the
        recipe JSON into ``recipe_scanner.recipes_dir``, embeds the recipe into
        the image for non-video uploads, and registers it with the scanner.

        Args:
            recipe_scanner: Scanner owning the recipes directory.
            image_bytes: Raw upload bytes; take precedence over ``image_base64``.
            image_base64: Base64 payload, optionally prefixed with a data-URL
                header ending in a comma.
            name: Recipe title (required).
            tags: Tags to store; the ``tags`` key is omitted when empty.
            metadata: Recipe metadata (required). ``loras``, ``gen_params``,
                ``raw_metadata``, ``base_model``, ``source_path`` and checkpoint
                fields are read from it; it is not mutated.
            extension: Upload extension; ``.mp4``/``.webm`` (any case) are
                stored without optimisation or metadata embedding.

        Returns:
            PersistenceResult with the new id, both file paths and the ids of
            other recipes sharing the same LoRA fingerprint.

        Raises:
            RecipeValidationError: If required fields or image data are missing.
        """
        missing_fields = []
        if not name:
            missing_fields.append("name")
        if metadata is None:
            missing_fields.append("metadata")
        if missing_fields:
            raise RecipeValidationError(
                f"Missing required fields: {', '.join(missing_fields)}"
            )

        resolved_image_bytes = self._resolve_image_bytes(image_bytes, image_base64)
        recipes_dir = recipe_scanner.recipes_dir
        os.makedirs(recipes_dir, exist_ok=True)
        recipe_id = str(uuid.uuid4())

        # Videos bypass optimisation and metadata embedding; the supplied
        # extension is kept as-is for the file name.
        is_video = self._is_video_extension(extension)
        if is_video:
            optimized_image = resolved_image_bytes
        else:
            optimized_image, extension = self._exif_utils.optimize_image(
                image_data=resolved_image_bytes,
                target_width=self._card_preview_width,
                format="webp",
                quality=85,
                preserve_metadata=True,
            )

        image_path = os.path.normpath(os.path.join(recipes_dir, f"{recipe_id}{extension}"))
        json_path = os.path.normpath(os.path.join(recipes_dir, f"{recipe_id}.recipe.json"))

        # Build the whole record before touching disk so malformed metadata
        # (e.g. a non-numeric LoRA weight) cannot leave an orphaned image.
        loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])]
        checkpoint_entry = self._sanitize_checkpoint_entry(self._extract_checkpoint_entry(metadata))
        gen_params = self._build_gen_params(metadata)
        fingerprint = calculate_recipe_fingerprint(loras_data)
        current_time = time.time()

        recipe_data: Dict[str, Any] = {
            "id": recipe_id,
            "file_path": image_path,
            "title": name,
            "modified": current_time,
            "created_date": current_time,
            "base_model": metadata.get("base_model", ""),
            "loras": loras_data,
            "gen_params": gen_params,
            "fingerprint": fingerprint,
        }
        if checkpoint_entry:
            recipe_data["checkpoint"] = checkpoint_entry
        tags_list = list(tags or ())
        if tags_list:
            recipe_data["tags"] = tags_list
        if metadata.get("source_path"):
            recipe_data["source_path"] = metadata.get("source_path")

        with open(image_path, "wb") as file_obj:
            file_obj.write(optimized_image)
        with open(json_path, "w", encoding="utf-8") as file_obj:
            json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
        if not is_video:
            self._exif_utils.append_recipe_metadata(image_path, recipe_data)

        # Looked up before registering so the new recipe cannot match itself.
        matching_recipes = await self._find_matching_recipes(recipe_scanner, fingerprint, exclude_id=recipe_id)
        await recipe_scanner.add_recipe(recipe_data)
        return PersistenceResult(
            {
                "success": True,
                "recipe_id": recipe_id,
                "image_path": image_path,
                "json_path": json_path,
                "matching_recipes": matching_recipes,
            }
        )

    async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult:
        """Delete an existing recipe's files and remove it from the scanner.

        Raises:
            RecipeNotFoundError: If the recipe's JSON file cannot be located.
        """
        recipe_json_path = await recipe_scanner.get_recipe_json_path(recipe_id)
        if not recipe_json_path or not os.path.exists(recipe_json_path):
            raise RecipeNotFoundError("Recipe not found")
        self._remove_recipe_files(recipe_json_path)
        await recipe_scanner.remove_recipe(recipe_id)
        return PersistenceResult({"success": True, "message": "Recipe deleted successfully"})

    async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: dict[str, Any]) -> PersistenceResult:
        """Update persisted metadata for a recipe.

        Raises:
            RecipeValidationError: If ``updates`` contains none of the
                updatable fields.
            RecipeNotFoundError: If the scanner reports the update failed.
        """
        if not any(key in updates for key in self._UPDATABLE_FIELDS):
            raise RecipeValidationError(
                "At least one field to update must be provided (title or tags or source_path or preview_nsfw_level)"
            )
        success = await recipe_scanner.update_recipe_metadata(recipe_id, updates)
        if not success:
            raise RecipeNotFoundError("Recipe not found or update failed")
        return PersistenceResult({"success": True, "recipe_id": recipe_id, "updates": updates})

    async def move_recipe(self, *, recipe_scanner, recipe_id: str, target_path: str) -> PersistenceResult:
        """Move a recipe's JSON and image into a folder under the recipes root.

        ``target_path`` may be absolute or relative to the recipes root and
        must resolve inside it. If moving the image fails, the JSON file is
        moved back so the recipe is not split across two folders.

        Returns:
            PersistenceResult with the original and new image path, the new
            JSON path and the ``/``-separated folder relative to the root
            (``""`` for the root itself).

        Raises:
            RecipeValidationError: If the target is empty or outside the root.
            RecipeNotFoundError: If the root or the recipe cannot be found, or
                the scanner fails to record the move.
        """
        if not target_path:
            raise RecipeValidationError("Target path is required")
        recipes_root = recipe_scanner.recipes_dir
        if not recipes_root:
            raise RecipeNotFoundError("Recipes directory not found")

        recipes_root = os.path.normpath(recipes_root)
        normalized_target = os.path.normpath(target_path)
        if not os.path.isabs(normalized_target):
            normalized_target = os.path.normpath(os.path.join(recipes_root, normalized_target))
        # Reject targets that escape the root (e.g. "../elsewhere").
        try:
            common_root = os.path.commonpath([normalized_target, recipes_root])
        except ValueError as exc:  # different drives or mixed absolute/relative paths
            raise RecipeValidationError("Invalid target path") from exc
        if common_root != recipes_root:
            raise RecipeValidationError("Target path must be inside the recipes directory")

        recipe_json_path = await recipe_scanner.get_recipe_json_path(recipe_id)
        if not recipe_json_path or not os.path.exists(recipe_json_path):
            raise RecipeNotFoundError("Recipe not found")
        recipe_data = await recipe_scanner.get_recipe_by_id(recipe_id)
        if not recipe_data:
            raise RecipeNotFoundError("Recipe not found")

        original_file_path = recipe_data.get("file_path")
        current_image_path = os.path.normpath(original_file_path) if original_file_path else None

        os.makedirs(normalized_target, exist_ok=True)
        if os.path.normpath(os.path.dirname(recipe_json_path)) == normalized_target:
            return PersistenceResult(
                {
                    "success": True,
                    "message": "Recipe is already in the target folder",
                    "recipe_id": recipe_id,
                    "original_file_path": original_file_path,
                    "new_file_path": original_file_path,
                }
            )

        new_json_path = os.path.normpath(os.path.join(normalized_target, os.path.basename(recipe_json_path)))
        shutil.move(recipe_json_path, new_json_path)

        new_image_path = current_image_path
        if current_image_path:
            target_image_path = os.path.normpath(
                os.path.join(normalized_target, os.path.basename(current_image_path))
            )
            if os.path.exists(current_image_path) and current_image_path != target_image_path:
                try:
                    shutil.move(current_image_path, target_image_path)
                except OSError:
                    # Roll back the JSON move so both files stay together.
                    shutil.move(new_json_path, recipe_json_path)
                    raise
                new_image_path = target_image_path

        relative_folder = os.path.relpath(normalized_target, recipes_root)
        if relative_folder in (".", ""):
            relative_folder = ""
        updates = {
            "file_path": new_image_path or original_file_path,
            "folder": relative_folder.replace(os.path.sep, "/"),
        }
        updated = await recipe_scanner.update_recipe_metadata(recipe_id, updates)
        if not updated:
            raise RecipeNotFoundError("Recipe not found after move")

        return PersistenceResult(
            {
                "success": True,
                "recipe_id": recipe_id,
                "original_file_path": original_file_path,
                "new_file_path": updates["file_path"],
                "json_path": new_json_path,
                "folder": updates["folder"],
            }
        )

    async def reconnect_lora(
        self,
        *,
        recipe_scanner,
        recipe_id: str,
        lora_index: int,
        target_name: str,
    ) -> PersistenceResult:
        """Reconnect a LoRA entry within an existing recipe to a local LoRA.

        After the scanner updates the entry, the recipe is re-embedded into its
        preview image (skipped for video files, as in ``save_recipe``).

        Raises:
            RecipeNotFoundError: If the recipe or the local LoRA is missing.
        """
        recipe_path = await recipe_scanner.get_recipe_json_path(recipe_id)
        if not recipe_path or not os.path.exists(recipe_path):
            raise RecipeNotFoundError("Recipe not found")

        target_lora = await recipe_scanner.get_local_lora(target_name)
        if not target_lora:
            raise RecipeNotFoundError(f"Local LoRA not found with name: {target_name}")

        recipe_data, updated_lora = await recipe_scanner.update_lora_entry(
            recipe_id,
            lora_index,
            target_name=target_name,
            target_lora=target_lora,
        )

        image_path = recipe_data.get("file_path")
        if (
            image_path
            and os.path.exists(image_path)
            and not self._is_video_extension(os.path.splitext(image_path)[1])
        ):
            self._exif_utils.append_recipe_metadata(image_path, recipe_data)

        matching_recipes = await self._find_matching_recipes(
            recipe_scanner,
            recipe_data.get("fingerprint"),
            exclude_id=recipe_id,
        )
        return PersistenceResult(
            {
                "success": True,
                "recipe_id": recipe_id,
                "updated_lora": updated_lora,
                "matching_recipes": matching_recipes,
            }
        )

    async def bulk_delete(
        self,
        *,
        recipe_scanner,
        recipe_ids: Iterable[str],
    ) -> PersistenceResult:
        """Delete multiple recipes in a single request.

        Each recipe is handled independently: failures are reported in the
        ``failed`` list and do not stop the remaining deletions. Only recipes
        whose files were removed are passed to ``recipe_scanner.bulk_remove``.

        Raises:
            RecipeValidationError: If ``recipe_ids`` is empty.
        """
        recipe_ids = list(recipe_ids)
        if not recipe_ids:
            raise RecipeValidationError("No recipe IDs provided")

        deleted_recipes: list[str] = []
        failed_recipes: list[dict[str, Any]] = []
        for recipe_id in recipe_ids:
            recipe_json_path = await recipe_scanner.get_recipe_json_path(recipe_id)
            if not recipe_json_path or not os.path.exists(recipe_json_path):
                failed_recipes.append({"id": recipe_id, "reason": "Recipe not found"})
                continue
            try:
                self._remove_recipe_files(recipe_json_path)
            except Exception as exc:  # per-recipe boundary: report and keep going
                failed_recipes.append({"id": recipe_id, "reason": str(exc)})
                continue
            deleted_recipes.append(recipe_id)

        if deleted_recipes:
            await recipe_scanner.bulk_remove(deleted_recipes)

        return PersistenceResult(
            {
                "success": True,
                "deleted": deleted_recipes,
                "failed": failed_recipes,
                "total_deleted": len(deleted_recipes),
                "total_failed": len(failed_recipes),
            }
        )

    async def save_recipe_from_widget(
        self,
        *,
        recipe_scanner,
        metadata: dict[str, Any],
        image_bytes: bytes,
    ) -> PersistenceResult:
        """Save a recipe constructed from widget metadata.

        ``metadata["loras"]`` is parsed for ``<lora:name:strength>`` tags; each
        LoRA is enriched from ``recipe_scanner.get_local_lora``. Every other
        metadata key except ``checkpoint`` becomes a generation parameter.

        Returns:
            PersistenceResult with the new id, both file paths and the derived
            recipe name.

        Raises:
            RecipeValidationError: If ``metadata`` is empty.
        """
        if not metadata:
            raise RecipeValidationError("No generation metadata found")

        recipes_dir = recipe_scanner.recipes_dir
        os.makedirs(recipes_dir, exist_ok=True)
        recipe_id = str(uuid.uuid4())
        optimized_image, extension = self._exif_utils.optimize_image(
            image_data=image_bytes,
            target_width=self._card_preview_width,
            format="webp",
            quality=85,
            preserve_metadata=True,
        )
        image_path = os.path.normpath(os.path.join(recipes_dir, f"{recipe_id}{extension}"))
        json_path = os.path.normpath(os.path.join(recipes_dir, f"{recipe_id}.recipe.json"))

        lora_stack = metadata.get("loras") or ""
        lora_matches = self._LORA_TAG_PATTERN.findall(lora_stack)
        loras_data = []
        base_model_counts: Dict[str, int] = {}
        for lora_name, raw_strength in lora_matches:
            lora_info = await recipe_scanner.get_local_lora(lora_name)
            civitai = (lora_info.get("civitai") or {}) if lora_info else {}
            loras_data.append(
                {
                    "file_name": lora_name,
                    "strength": self._parse_lora_strength(raw_strength),
                    "hash": (lora_info.get("sha256") or "").lower() if lora_info else "",
                    "modelVersionId": civitai.get("id", 0),
                    "modelName": (civitai.get("model") or {}).get("name", lora_name) if lora_info else "",
                    "modelVersionName": civitai.get("name", ""),
                    "isDeleted": False,
                    "exclude": False,
                }
            )
            if lora_info and "base_model" in lora_info:
                base_model = lora_info["base_model"]
                base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1

        recipe_name = self._derive_recipe_name(lora_matches)
        # Ties resolve to the base model seen first.
        most_common_base_model = (
            max(base_model_counts, key=base_model_counts.__getitem__) if base_model_counts else ""
        )
        current_time = time.time()

        recipe_data = {
            "id": recipe_id,
            "file_path": image_path,
            "title": recipe_name,
            "modified": current_time,
            "created_date": current_time,
            "base_model": most_common_base_model,
            "loras": loras_data,
            "checkpoint": self._sanitize_checkpoint_entry(metadata.get("checkpoint", "")),
            "gen_params": {
                key: value
                for key, value in metadata.items()
                if key not in ("checkpoint", "loras")
            },
            "loras_stack": lora_stack,
        }

        # Files are written only after parsing succeeded, so a bad tag cannot
        # leave an orphaned image behind.
        with open(image_path, "wb") as file_obj:
            file_obj.write(optimized_image)
        with open(json_path, "w", encoding="utf-8") as file_obj:
            json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
        self._exif_utils.append_recipe_metadata(image_path, recipe_data)
        await recipe_scanner.add_recipe(recipe_data)

        return PersistenceResult(
            {
                "success": True,
                "recipe_id": recipe_id,
                "image_path": image_path,
                "json_path": json_path,
                "recipe_name": recipe_name,
            }
        )

    # Helper methods ---------------------------------------------------

    @classmethod
    def _is_video_extension(cls, extension: str | None) -> bool:
        """Return True when *extension* (e.g. ``".MP4"``) is a stored-verbatim video format."""
        return (extension or "").lower() in cls._VIDEO_EXTENSIONS

    def _remove_recipe_files(self, recipe_json_path: str) -> None:
        """Delete a recipe's JSON file and, best-effort, its preview image.

        Failing to read or delete the JSON raises. Once the JSON is gone, the
        recipe is effectively deleted. In that case a failure to remove the image
        is only logged, so callers still update the scanner.
        """
        with open(recipe_json_path, "r", encoding="utf-8") as file_obj:
            recipe_data = json.load(file_obj)
        image_path = recipe_data.get("file_path")
        os.remove(recipe_json_path)
        if image_path and os.path.exists(image_path):
            try:
                os.remove(image_path)
            except OSError as exc:
                self._logger.warning("Failed to remove recipe image %s: %s", image_path, exc)

    def _build_gen_params(self, metadata: dict[str, Any]) -> dict[str, Any]:
        """Return generation parameters for ``save_recipe`` without mutating *metadata*.

        Uses a copy of ``metadata["gen_params"]``. When that is empty and
        ``raw_metadata`` is present, the keys in ``_RAW_GEN_PARAM_KEYS`` are
        copied from it instead, defaulting to ``""``. ``checkpoint`` is dropped
        because it is stored only at the recipe's top level.
        """
        gen_params = dict(metadata.get("gen_params") or {})
        if not gen_params and "raw_metadata" in metadata:
            raw_metadata = metadata.get("raw_metadata") or {}
            gen_params = {key: raw_metadata.get(key, "") for key in self._RAW_GEN_PARAM_KEYS}
        gen_params.pop("checkpoint", None)
        return gen_params

    def _extract_checkpoint_entry(self, metadata: dict[str, Any]) -> Optional[dict[str, Any]]:
        """Pull a checkpoint entry from various metadata locations.

        Checks ``checkpoint``, then ``model``, then ``gen_params["checkpoint"]``.
        Returns the value only if it is a dict, otherwise ``None``.
        """
        checkpoint_entry = metadata.get("checkpoint") or metadata.get("model")
        if not checkpoint_entry:
            gen_params = metadata.get("gen_params") or {}
            checkpoint_entry = gen_params.get("checkpoint")
        return checkpoint_entry if isinstance(checkpoint_entry, dict) else None

    def _sanitize_checkpoint_entry(self, checkpoint_entry: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
        """Remove transient/local-only fields from checkpoint metadata.

        Falsy input yields ``None``; a non-dict value is returned unchanged; a
        dict is shallow-copied with ``_TRANSIENT_CHECKPOINT_KEYS`` removed.
        """
        if not checkpoint_entry:
            return None
        if not isinstance(checkpoint_entry, dict):
            return checkpoint_entry
        pruned = dict(checkpoint_entry)
        for key in self._TRANSIENT_CHECKPOINT_KEYS:
            pruned.pop(key, None)
        return pruned

    def _resolve_image_bytes(self, image_bytes: bytes | None, image_base64: str | None) -> bytes:
        """Return the upload's bytes, decoding ``image_base64`` when needed.

        Anything up to the first comma of ``image_base64`` is treated as a
        data-URL header and discarded.

        Raises:
            RecipeValidationError: If no data is given or decoding fails.
        """
        if image_bytes is not None:
            return image_bytes
        if image_base64:
            try:
                payload = image_base64.split(",", 1)[1] if "," in image_base64 else image_base64
                return base64.b64decode(payload)
            except Exception as exc:  # pragma: no cover - validation guard
                raise RecipeValidationError(f"Invalid base64 image data: {exc}") from exc
        raise RecipeValidationError("No image data provided")

    def _normalise_lora_entry(self, lora: dict[str, Any]) -> dict[str, Any]:
        """Convert an uploaded LoRA description into the stored recipe schema.

        ``file_name`` falls back to the stem of ``localPath``; a missing or
        ``None`` ``weight`` becomes a strength of ``1.0``.
        """
        local_path = lora.get("localPath")
        fallback_name = os.path.splitext(os.path.basename(local_path))[0] if local_path else ""
        weight = lora.get("weight")
        return {
            "file_name": lora.get("file_name", "") or fallback_name,
            "hash": (lora.get("hash") or "").lower(),
            "strength": float(weight if weight is not None else 1.0),
            "modelVersionId": lora.get("id", 0),
            "modelName": lora.get("name", ""),
            "modelVersionName": lora.get("version", ""),
            "isDeleted": lora.get("isDeleted", False),
            "exclude": lora.get("exclude", False),
        }

    async def _find_matching_recipes(
        self,
        recipe_scanner,
        fingerprint: str | None,
        *,
        exclude_id: Optional[str] = None,
    ) -> list[str]:
        """Return ids of recipes sharing *fingerprint*, minus *exclude_id*.

        Always returns a new list. The scanner's result may be a list the scanner
        still references, so it is never modified in place.
        """
        if not fingerprint:
            return []
        matches = await recipe_scanner.find_recipes_by_fingerprint(fingerprint) or []
        return [match for match in matches if exclude_id is None or match != exclude_id]

    @staticmethod
    def _parse_lora_strength(raw_strength: str) -> float:
        """Return the strength from the strength group of a ``<lora:...>`` tag.

        Tags such as ``<lora:name:0.8:0.6>`` carry an extra colon-separated
        value (commonly a clip strength). That value would make ``float()``
        fail, so only the first value is used.

        Raises:
            ValueError: If the first value is not a number.
        """
        return float(raw_strength.split(":", 1)[0].strip())

    def _derive_recipe_name(self, lora_matches: list[tuple[str, str]]) -> str:
        """Build a title like ``name-0.80_other-1.00`` from the first three LoRAs.

        Falls back to ``"recipe"`` when there are no LoRAs.
        """
        recipe_name_parts = [
            f"{name.strip()}-{self._parse_lora_strength(strength):.2f}"
            for name, strength in lora_matches[:3]
        ]
        return "_".join(recipe_name_parts) or "recipe"