Files
ComfyUI-Lora-Manager/py/services/recipes/analysis_service.py
Will Miao 283730cf38 fix(import): discover LoRA + checkpoint from modelVersionIds when API meta is null
When CivitAI image API returns meta=null and modelVersionIds at root
level, the import flow now:

- Injects modelVersionIds + browsingLevel into a minimal metadata dict
  so the parser can discover LoRAs and checkpoints (both import-from-url
  and analyze-image paths)
- Adds checkpoint dedup + fallback in the parser's modelVersionIds
  handler to avoid duplicate API calls
- Runs EXIF extraction unconditionally in analyze-image path, then
  merges with API metadata (fixes gen params loss)
- Propagates preview_nsfw_level through all three import paths:
  import-from-url, analyze-image (UI Import), and batch-import,
  plus the frontend save flow
2026-06-27 17:05:38 +08:00

520 lines
20 KiB
Python

"""Services responsible for recipe metadata analysis."""
from __future__ import annotations
import asyncio
import base64
import io
import os
import tempfile
from dataclasses import dataclass
from typing import Any, Callable, Optional
import numpy as np
from PIL import Image
from ...utils.utils import calculate_recipe_fingerprint
from ...utils.civitai_utils import extract_civitai_image_id, rewrite_preview_url
from ...recipes.enrichment import RecipeEnricher
from .errors import (
RecipeDownloadError,
RecipeNotFoundError,
RecipeServiceError,
RecipeValidationError,
)
@dataclass(frozen=True)
class AnalysisResult:
    """Return payload from analysis operations.

    The dataclass is frozen, but ``payload`` is an ordinary dict that callers
    in this module still fill in after construction (for example adding
    ``gen_params`` or ``preview_nsfw_level``).
    """
    # Response body. Keys set in this module include ``loras``, ``error``,
    # ``gen_params``, ``image_base64``, ``is_video``, ``extension``,
    # ``fingerprint`` and ``matching_recipes``.
    payload: dict[str, Any]
    # Status code for the response; presumably an HTTP status (TODO confirm
    # with the route layer). RecipeAnalysisService never sets it explicitly,
    # so payloads carrying an ``error`` key are also returned with 200.
    status: int = 200
class RecipeAnalysisService:
    """Extract recipe metadata from various image sources.

    Collaborators are injected rather than imported. Their use in this class:

    * ``exif_utils.extract_image_metadata(path)`` reads embedded metadata.
    * ``recipe_parser_factory.create_parser(metadata)`` returns a parser with
      an awaitable ``parse_metadata(metadata, recipe_scanner=...)``, or a
      falsy value when no parser matches.
    * ``downloader_factory()`` is awaited to obtain a downloader whose
      awaitable ``download_file(url, path, use_auth=False)`` returns a
      ``(success, detail)`` pair.
    * ``metadata_collector``, ``metadata_processor_cls`` and
      ``metadata_registry_cls`` are only needed by
      :meth:`analyze_widget_metadata`.
    """

    # ``nsfwLevel`` names accepted as a fallback when the image-info response
    # carries no usable integer ``browsingLevel``.
    _NSFW_LEVEL_NAMES = ("PG", "PG13", "R", "X", "XXX", "Blocked")

    # Keys whose presence marks a metadata dict as describing a recipe.
    _RECIPE_FIELDS = frozenset(
        {
            "prompt",
            "negative_prompt",
            "resources",
            "hashes",
            "params",
            "generationData",
            "Workflow",
            "prompt_type",
            "positive",
            "negative",
            # modelVersionIds is injected from the root level of Civitai's
            # image API response when meta is null; it lists the version IDs
            # of all models (checkpoint + LoRAs) used for the image.
            "modelVersionIds",
        }
    )

    # Extensions treated as video for non-Civitai URLs.
    _VIDEO_EXTENSIONS = (".mp4", ".webm")

    def __init__(
        self,
        *,
        exif_utils,
        recipe_parser_factory,
        downloader_factory: Callable[[], Any],
        metadata_collector: Optional[Callable[[], Any]] = None,
        metadata_processor_cls: Optional[type] = None,
        metadata_registry_cls: Optional[type] = None,
        standalone_mode: bool = False,
        logger,
    ) -> None:
        self._exif_utils = exif_utils
        self._recipe_parser_factory = recipe_parser_factory
        self._downloader_factory = downloader_factory
        self._metadata_collector = metadata_collector
        self._metadata_processor_cls = metadata_processor_cls
        self._metadata_registry_cls = metadata_registry_cls
        self._standalone_mode = standalone_mode
        self._logger = logger

    async def analyze_uploaded_image(
        self,
        *,
        image_bytes: bytes | None,
        recipe_scanner,
    ) -> AnalysisResult:
        """Analyze an uploaded image payload.

        The bytes are written to a temporary file so the EXIF helper can read
        them; the file is always removed afterwards.

        Raises:
            RecipeValidationError: if ``image_bytes`` is empty or ``None``.
        """
        if not image_bytes:
            raise RecipeValidationError("No image data provided")
        temp_path = self._write_temp_file(image_bytes)
        try:
            # Blocking file I/O: run it off the event loop, as the other
            # analysis paths already do.
            metadata = await asyncio.to_thread(
                self._exif_utils.extract_image_metadata, temp_path
            )
            if not metadata:
                return AnalysisResult(
                    {"error": "No metadata found in this image", "loras": []}
                )
            return await self._parse_metadata(
                metadata,
                recipe_scanner=recipe_scanner,
                image_path=None,
                include_image_base64=False,
            )
        finally:
            self._safe_cleanup(temp_path)

    async def analyze_remote_image(
        self,
        *,
        url: str | None,
        recipe_scanner,
        civitai_client,
    ) -> AnalysisResult:
        """Analyze an image accessible via URL, including Civitai integration.

        For Civitai image URLs, the image-info API supplies the media URL plus
        model and NSFW metadata. For other URLs the file is downloaded
        directly. Embedded (EXIF) metadata is then read from still images and
        combined with any API metadata. Successful Civitai results are finally
        enriched with :class:`RecipeEnricher`.

        Raises:
            RecipeValidationError: if ``url`` is empty.
            RecipeServiceError: if ``civitai_client`` is ``None``.
            RecipeDownloadError: if the Civitai lookup or the main download
                fails.
        """
        if not url:
            raise RecipeValidationError("No URL provided")
        if civitai_client is None:
            raise RecipeServiceError("Civitai client unavailable")

        temp_path: Optional[str] = None
        image_info: Optional[dict[str, Any]] = None
        original_url: Optional[str] = None
        metadata: Optional[dict[str, Any]] = None
        preview_nsfw_level: Optional[int] = None
        is_video = False
        extension = ".jpg"  # Default
        try:
            civitai_image_id = extract_civitai_image_id(url)
            if civitai_image_id:
                image_info = await civitai_client.get_image_info(
                    civitai_image_id, source_url=url
                )
                if not image_info:
                    raise RecipeDownloadError(
                        "Failed to fetch image information from Civitai"
                    )
                original_url = image_info.get("url")
                if not original_url:
                    raise RecipeDownloadError("No image URL found in Civitai response")
                media_type = image_info.get("type")
                is_video = media_type == "video"
                # Use optimized preview URLs if possible.
                rewritten_url, _ = rewrite_preview_url(
                    original_url, media_type=media_type
                )
                download_url = rewritten_url or original_url
                if is_video:
                    # Take the extension from the URL path (query/fragment
                    # removed).
                    url_path = download_url.split("?")[0].split("#")[0]
                    extension = os.path.splitext(url_path)[1].lower() or ".mp4"
            else:
                # Basic extension detection for non-Civitai URLs.
                download_url = url
                url_path = url.split("?")[0].split("#")[0]
                extension = os.path.splitext(url_path)[1].lower()
                if extension in self._VIDEO_EXTENSIONS:
                    is_video = True
                else:
                    extension = ".jpg"

            temp_path = self._create_temp_path(suffix=extension)
            await self._download_image(download_url, temp_path)

            if image_info is not None:
                metadata, preview_nsfw_level = self._build_civitai_metadata(
                    image_info
                )

            # Always read embedded metadata from still images. API metadata
            # (modelVersionIds, browsingLevel) carries no generation params,
            # so gating this on the API result would lose prompts, sampler,
            # steps, etc.
            exif_metadata = None
            if not is_video:
                exif_metadata = await self._extract_embedded_metadata(
                    temp_path,
                    original_url=original_url,
                    downloaded_url=download_url,
                )

            result = await self._parse_remote_metadata(
                metadata,
                exif_metadata,
                recipe_scanner=recipe_scanner,
                image_path=temp_path,
                is_video=is_video,
                extension=extension,
            )

            if image_info is not None and not result.payload.get("error"):
                await self._enrich_civitai_result(
                    result,
                    url=url,
                    civitai_client=civitai_client,
                    api_metadata=metadata,
                    preview_nsfw_level=preview_nsfw_level,
                )
            return result
        finally:
            self._safe_cleanup(temp_path)

    async def analyze_local_image(
        self,
        *,
        file_path: str | None,
        recipe_scanner,
    ) -> AnalysisResult:
        """Analyze a file already present on disk.

        Surrounding whitespace and one layer of double and/or single quotes
        (as left by "copy as path" tools) are stripped before normalizing.

        Raises:
            RecipeValidationError: if ``file_path`` is empty.
            RecipeNotFoundError: if the path is not an existing file.
        """
        if not file_path:
            raise RecipeValidationError("No file path provided")
        normalized_path = os.path.normpath(
            file_path.strip().strip('"').strip("'")
        )
        if not os.path.isfile(normalized_path):
            raise RecipeNotFoundError("File not found")
        metadata = await asyncio.to_thread(
            self._exif_utils.extract_image_metadata, normalized_path
        )
        if not metadata:
            return self._metadata_not_found_response(normalized_path)
        return await self._parse_metadata(
            metadata,
            recipe_scanner=recipe_scanner,
            image_path=normalized_path,
            include_image_base64=True,
        )

    async def analyze_widget_metadata(self, *, recipe_scanner) -> AnalysisResult:
        """Analyze the most recent generation metadata for widget saves.

        Returns the processed metadata dict and the latest generated image
        encoded as PNG bytes. ``recipe_scanner`` is accepted for signature
        parity with the other analysis methods but is not used.

        Raises:
            RecipeValidationError: if metadata collection is unavailable, no
                metadata was collected, no recent image exists, or the image
                data cannot be converted.
        """
        if self._metadata_collector is None or self._metadata_processor_cls is None:
            raise RecipeValidationError("Metadata collection not available")
        raw_metadata = self._metadata_collector()
        metadata_dict = self._metadata_processor_cls.to_dict(raw_metadata)
        if not metadata_dict:
            raise RecipeValidationError("No generation metadata found")
        latest_image = None
        # The registry lookup is skipped in standalone mode.
        if not self._standalone_mode and self._metadata_registry_cls is not None:
            metadata_registry = self._metadata_registry_cls()
            latest_image = metadata_registry.get_first_decoded_image()
        if latest_image is None:
            raise RecipeValidationError(
                "No recent images found to use for recipe. Try generating an image first."
            )
        image_bytes = self._convert_tensor_to_png_bytes(latest_image)
        if image_bytes is None:
            raise RecipeValidationError(
                "Cannot handle this data shape from metadata registry"
            )
        return AnalysisResult(
            {
                "metadata": metadata_dict,
                "image_bytes": image_bytes,
            }
        )

    # Internal helpers -------------------------------------------------

    def _has_recipe_fields(self, metadata: dict[str, Any]) -> bool:
        """Return True when ``metadata`` has at least one recipe-related key."""
        return any(field in metadata for field in self._RECIPE_FIELDS)

    @staticmethod
    def _resolve_browsing_level(image_info: dict[str, Any]) -> Any:
        """Return the image's NSFW browsing level from an image-info response.

        Prefers the integer ``browsingLevel`` when it is positive; otherwise
        maps a known ``nsfwLevel`` name through ``NSFW_LEVELS`` (0 for names
        that mapping lacks). Returns ``None`` when neither is usable.
        """
        browsing_level = image_info.get("browsingLevel")
        if isinstance(browsing_level, int) and browsing_level > 0:
            return browsing_level
        nsfw_level_name = image_info.get("nsfwLevel")
        if (
            isinstance(nsfw_level_name, str)
            and nsfw_level_name in RecipeAnalysisService._NSFW_LEVEL_NAMES
        ):
            from ...utils.constants import NSFW_LEVELS

            return NSFW_LEVELS.get(nsfw_level_name, 0)
        return None

    def _build_civitai_metadata(
        self, image_info: dict[str, Any]
    ) -> tuple[Optional[dict[str, Any]], Optional[int]]:
        """Build parser input from a Civitai image-info response.

        Returns ``(metadata, preview_nsfw_level)``:

        * ``metadata``: a copy of ``meta`` (unwrapping a nested ``meta.meta``
          dict), plus root-level ``modelVersionIds`` and the resolved
          ``browsingLevel``. When meta is null but ``modelVersionIds`` exist,
          a minimal dict is created so parsers can still discover LoRAs and
          checkpoints. It is ``None`` when there is no dict, or when the dict
          lacks recipe fields (embedded metadata is then parsed on its own).
        * ``preview_nsfw_level``: the positive integer browsing level, or
          ``None``. It is returned separately so it survives even when
          ``metadata`` is ``None``.
        """
        raw_meta = image_info.get("meta")
        if isinstance(raw_meta, dict) and isinstance(raw_meta.get("meta"), dict):
            raw_meta = raw_meta["meta"]
        # Copy so the (possibly cached) API response is never mutated.
        metadata: Optional[dict[str, Any]] = (
            dict(raw_meta) if isinstance(raw_meta, dict) else None
        )

        # Civitai returns modelVersionIds at the root level, not in meta.
        model_version_ids = image_info.get("modelVersionIds")
        if model_version_ids:
            if metadata is None:
                metadata = {}
            metadata["modelVersionIds"] = model_version_ids

        browsing_level = self._resolve_browsing_level(image_info)
        if metadata is not None and browsing_level is not None:
            metadata["browsingLevel"] = browsing_level

        if metadata is not None and not self._has_recipe_fields(metadata):
            self._logger.debug(
                "Civitai API metadata lacks recipe fields, will extract from EXIF"
            )
            metadata = None

        preview_nsfw_level = (
            browsing_level
            if isinstance(browsing_level, int) and browsing_level > 0
            else None
        )
        return metadata, preview_nsfw_level

    async def _extract_embedded_metadata(
        self,
        image_path: str,
        *,
        original_url: Optional[str],
        downloaded_url: str,
    ) -> Any:
        """Read embedded metadata from a downloaded image, with a fallback.

        If the downloaded copy has none and ``original_url`` differs from
        ``downloaded_url`` (an optimized preview was fetched instead), the
        original is downloaded and read instead. That download is best effort:
        a ``RecipeDownloadError`` is logged and the empty result is returned,
        because the main image is already available.
        """
        exif_metadata = await asyncio.to_thread(
            self._exif_utils.extract_image_metadata, image_path
        )
        if exif_metadata or not original_url or original_url == downloaded_url:
            return exif_metadata

        self._logger.debug(
            "Optimized image lacks embedded metadata, "
            "falling back to original image: %s",
            original_url,
        )
        orig_temp_path = self._create_temp_path(suffix=".png")
        try:
            await self._download_image(original_url, orig_temp_path)
            return await asyncio.to_thread(
                self._exif_utils.extract_image_metadata, orig_temp_path
            )
        except RecipeDownloadError as exc:
            self._logger.debug(
                "Original image download failed, continuing without it: %s", exc
            )
            return exif_metadata
        finally:
            self._safe_cleanup(orig_temp_path)

    async def _parse_remote_metadata(
        self,
        api_metadata: Optional[dict[str, Any]],
        exif_metadata: Any,
        *,
        recipe_scanner,
        image_path: str,
        is_video: bool,
        extension: str,
    ) -> AnalysisResult:
        """Parse API metadata and embedded metadata of a downloaded image.

        Without API metadata, the embedded payload (string or dict) is parsed
        on its own, which is the same path uploads and local files take.
        Merging only dicts here would reduce a string payload to ``{}`` and
        discard its LoRAs and prompt.

        With API metadata, the API dict is laid over any dict-shaped embedded
        metadata and parsed, so API LoRAs and checkpoint win. Generation params
        parsed from a string payload then fill keys missing from the API
        result (API values take priority).
        """
        parse_options: dict[str, Any] = {
            "recipe_scanner": recipe_scanner,
            "image_path": image_path,
            "include_image_base64": True,
            "is_video": is_video,
            "extension": extension,
        }
        if api_metadata is None:
            return await self._parse_metadata(exif_metadata or {}, **parse_options)

        exif_parsed: Optional[dict[str, Any]] = None
        if isinstance(exif_metadata, str):
            exif_parser = self._recipe_parser_factory.create_parser(exif_metadata)
            if exif_parser:
                parsed = await exif_parser.parse_metadata(
                    exif_metadata, recipe_scanner=recipe_scanner
                )
                if parsed and not parsed.get("error"):
                    exif_parsed = parsed

        merged: dict[str, Any] = {}
        if isinstance(exif_metadata, dict):
            merged.update(exif_metadata)
        merged.update(api_metadata)
        result = await self._parse_metadata(merged, **parse_options)

        if exif_parsed and not result.payload.get("error"):
            merged_gen_params = {
                **(exif_parsed.get("gen_params") or {}),
                **(result.payload.get("gen_params") or {}),
            }
            if merged_gen_params:
                result.payload["gen_params"] = merged_gen_params
        return result

    async def _enrich_civitai_result(
        self,
        result: AnalysisResult,
        *,
        url: str,
        civitai_client,
        api_metadata: Optional[dict[str, Any]],
        preview_nsfw_level: Optional[int],
    ) -> None:
        """Enrich a successful Civitai analysis result in place.

        ``api_metadata`` is the dict built from the image-info response, not
        ``image_info["meta"]``, which is null for images whose data lives only
        at the root level. No ``model_version_id`` is derived from
        ``modelVersionIds[0]``: that array mixes checkpoints, LoRAs and other
        types without ordering guarantees, and the parser already resolved
        them.
        """
        payload = result.payload
        recipe_for_enrich = {
            "gen_params": payload.get("gen_params", {}),
            "loras": payload.get("loras", []),
            "base_model": payload.get("base_model", "") or "",
            "checkpoint": payload.get("checkpoint") or payload.get("model"),
            "source_path": url,
        }
        await RecipeEnricher.enrich_recipe(
            recipe=recipe_for_enrich,
            civitai_client=civitai_client,
            request_params=None,
            prefetched_civitai_meta_raw=api_metadata,
            prefetched_model_version_id=None,
        )
        payload["gen_params"] = recipe_for_enrich["gen_params"]
        if recipe_for_enrich.get("checkpoint"):
            payload["checkpoint"] = recipe_for_enrich["checkpoint"]
        if recipe_for_enrich.get("base_model"):
            payload["base_model"] = recipe_for_enrich["base_model"]
        # Enables NSFW blurring of the preview image.
        if preview_nsfw_level is not None:
            payload["preview_nsfw_level"] = preview_nsfw_level

    async def _parse_metadata(
        self,
        metadata: Any,
        *,
        recipe_scanner,
        image_path: Optional[str],
        include_image_base64: bool,
        is_video: bool = False,
        extension: str = ".jpg",
    ) -> AnalysisResult:
        """Run the matching recipe parser and attach fingerprint matches.

        ``metadata`` is whatever the parser factory accepts (a dict or the raw
        embedded string). When no parser matches, an error payload is
        returned. When parsing yields an error with no LoRAs, the parser
        result is returned as-is. Otherwise a LoRA fingerprint is computed and
        recipes sharing it are listed under ``matching_recipes``.
        """
        parser = self._recipe_parser_factory.create_parser(metadata)
        if parser is None:
            # Provide a more specific message when there is no metadata at all.
            if not metadata:
                error_msg = "This image does not contain any generation metadata (prompt, models, or parameters)"
            else:
                error_msg = "No parser found for this image"
            payload = {"error": error_msg, "loras": []}
            if include_image_base64 and image_path:
                payload["image_base64"] = self._encode_file(image_path)
                payload["is_video"] = is_video
                payload["extension"] = extension
            return AnalysisResult(payload)

        result = await parser.parse_metadata(metadata, recipe_scanner=recipe_scanner)
        if include_image_base64 and image_path:
            result["image_base64"] = self._encode_file(image_path)
            result["is_video"] = is_video
            result["extension"] = extension
        if "error" in result and not result.get("loras"):
            return AnalysisResult(result)

        fingerprint = calculate_recipe_fingerprint(result.get("loras", []))
        result["fingerprint"] = fingerprint
        matching_recipes: list[str] = []
        if fingerprint:
            matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(
                fingerprint
            )
        result["matching_recipes"] = matching_recipes
        return AnalysisResult(result)

    async def _download_image(self, url: str, temp_path: str) -> None:
        """Download ``url`` to ``temp_path`` without authentication.

        Raises:
            RecipeDownloadError: when the downloader reports failure.
        """
        downloader = await self._downloader_factory()
        success, result = await downloader.download_file(url, temp_path, use_auth=False)
        if not success:
            raise RecipeDownloadError(f"Failed to download image from URL: {result}")

    def _metadata_not_found_response(self, path: str) -> AnalysisResult:
        """Build the "no metadata" payload, embedding the image if it exists."""
        payload: dict[str, Any] = {
            "error": "No metadata found in this image",
            "loras": [],
        }
        if os.path.exists(path):
            payload["image_base64"] = self._encode_file(path)
        return AnalysisResult(payload)

    def _write_temp_file(self, data: bytes) -> str:
        """Write ``data`` to a new ``.jpg`` temp file and return its path."""
        with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
            temp_file.write(data)
            return temp_file.name

    def _create_temp_path(self, suffix: str = ".jpg") -> str:
        """Create an empty temp file with ``suffix`` and return its path."""
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            return temp_file.name

    def _safe_cleanup(self, path: Optional[str]) -> None:
        """Delete ``path`` if it exists, logging (not raising) on failure."""
        if path and os.path.exists(path):
            try:
                os.unlink(path)
            except Exception as exc:  # pragma: no cover - defensive logging
                self._logger.error("Error deleting temporary file: %s", exc)

    def _encode_file(self, path: str) -> str:
        """Return the file's contents as a base64 ASCII string."""
        with open(path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    @staticmethod
    def _tensor_to_uint8_image(
        latest_image: Any, logger: Any = None
    ) -> Optional[np.ndarray]:
        """Convert registry image data to an ``H x W x C`` uint8 array.

        ``latest_image`` may be a torch tensor, an array-like, or a tuple whose
        first element is one of those. Leading axes beyond three are reduced
        by taking index 0 (for example, the first image of a batch).

        Float data is assumed to hold values nominally in [0, 1] (the former
        ``max() <= 1.0`` check implied this). It is scaled by 255 and clipped,
        so out-of-range values saturate instead of wrapping in the uint8 cast,
        and float16 is handled like float32/float64.

        Returns ``None`` for an empty tuple or ``None`` element, and for
        results that are not uint8 with 3 (RGB) or 4 (RGBA) channels.
        """
        import sys

        if isinstance(latest_image, tuple):
            if not latest_image or latest_image[0] is None:
                return None
            image = latest_image[0]
        else:
            image = latest_image

        if logger is not None and hasattr(image, "shape"):
            logger.debug(
                "Tensor shape: %s, dtype: %s",
                image.shape,
                getattr(image, "dtype", None),
            )

        # An object can only be a torch.Tensor if torch is already imported,
        # so look it up instead of importing it. Plain arrays then work even
        # where torch is unavailable.
        torch = sys.modules.get("torch")
        if torch is not None and isinstance(image, torch.Tensor):
            image_np = image.detach().cpu().numpy()
        else:
            image_np = np.asarray(image)

        while image_np.ndim > 3:
            image_np = image_np[0]
        if np.issubdtype(image_np.dtype, np.floating):
            image_np = np.clip(image_np * 255.0, 0, 255).astype(np.uint8)
        if (
            image_np.dtype != np.uint8
            or image_np.ndim != 3
            or image_np.shape[2] not in (3, 4)
        ):
            return None
        return image_np

    def _convert_tensor_to_png_bytes(self, latest_image: Any) -> Optional[bytes]:
        """Encode registry image data as PNG bytes, or return ``None``.

        Conversion errors are logged rather than raised; callers treat
        ``None`` as an unsupported data shape.
        """
        try:
            image_np = self._tensor_to_uint8_image(latest_image, logger=self._logger)
            if image_np is None:
                return None
            img_byte_arr = io.BytesIO()
            Image.fromarray(image_np).save(img_byte_arr, format="PNG")
            return img_byte_arr.getvalue()
        except Exception as exc:  # pragma: no cover - defensive logging path
            self._logger.error("Error processing image data: %s", exc, exc_info=True)
            return None