fix(nodes): save recipes synchronously from save image

This commit is contained in:
Will Miao
2026-04-23 15:46:57 +08:00
parent ebdbb36271
commit df0e5797d0
2 changed files with 432 additions and 0 deletions

View File

@@ -1,12 +1,17 @@
import json
import os
import re
import time
import uuid
from typing import Any, Dict, Optional
import numpy as np
import folder_paths  # type: ignore
from ..services.service_registry import ServiceRegistry
from ..metadata_collector.metadata_processor import MetadataProcessor
from ..metadata_collector import get_metadata
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils
from ..utils.utils import calculate_recipe_fingerprint
from PIL import Image, PngImagePlugin
import piexif
import logging
@@ -86,6 +91,13 @@ class SaveImageLM:
"tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images.", "tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images.",
}, },
), ),
"save_as_recipe": (
"BOOLEAN",
{
"default": False,
"tooltip": "Also saves each generated image as a LoRA Manager recipe.",
},
),
}, },
"hidden": { "hidden": {
"id": "UNIQUE_ID", "id": "UNIQUE_ID",
@@ -346,6 +358,203 @@ class SaveImageLM:
return filename return filename
@staticmethod
def _get_cached_model_by_name(scanner, name):
cache = getattr(scanner, "_cache", None)
if cache is None or not name:
return None
candidates = [
name,
os.path.basename(name),
os.path.splitext(os.path.basename(name))[0],
]
for model in getattr(cache, "raw_data", []):
file_name = model.get("file_name")
if file_name in candidates:
return model
return None
def _build_recipe_loras(self, recipe_scanner, lora_stack):
lora_matches = re.findall(r"<lora:([^:]+):([^>]+)>", lora_stack or "")
lora_scanner = getattr(recipe_scanner, "_lora_scanner", None)
loras_data = []
base_model_counts = {}
for name, strength in lora_matches:
lora_info = self._get_cached_model_by_name(lora_scanner, name)
civitai = (lora_info or {}).get("civitai") or {}
civitai_model = civitai.get("model") or {}
try:
parsed_strength = float(strength)
except (TypeError, ValueError):
parsed_strength = 1.0
loras_data.append(
{
"file_name": name,
"strength": parsed_strength,
"hash": ((lora_info or {}).get("sha256") or "").lower(),
"modelVersionId": civitai.get("id", 0),
"modelName": civitai_model.get("name", name) if lora_info else "",
"modelVersionName": civitai.get("name", "") if lora_info else "",
"isDeleted": False,
"exclude": False,
}
)
base_model = (lora_info or {}).get("base_model")
if base_model:
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
return lora_matches, loras_data, base_model_counts
def _build_recipe_checkpoint(self, recipe_scanner, checkpoint_raw):
if not isinstance(checkpoint_raw, str) or not checkpoint_raw.strip():
return None
checkpoint_name = checkpoint_raw.strip()
file_name = os.path.splitext(os.path.basename(checkpoint_name))[0]
checkpoint_scanner = getattr(recipe_scanner, "_checkpoint_scanner", None)
checkpoint_info = self._get_cached_model_by_name(
checkpoint_scanner, checkpoint_name
)
if not checkpoint_info:
return {
"type": "checkpoint",
"name": checkpoint_name,
"file_name": file_name,
"hash": self.get_checkpoint_hash(checkpoint_name) or "",
}
civitai = checkpoint_info.get("civitai") or {}
civitai_model = civitai.get("model") or {}
file_path = checkpoint_info.get("file_path") or checkpoint_info.get("path") or ""
cached_file_name = (
checkpoint_info.get("file_name")
or (os.path.splitext(os.path.basename(file_path))[0] if file_path else "")
or file_name
)
return {
"type": "checkpoint",
"modelId": civitai_model.get("id", 0),
"modelVersionId": civitai.get("id", 0),
"name": civitai_model.get("name")
or checkpoint_info.get("model_name")
or checkpoint_name,
"version": civitai.get("name", ""),
"hash": (
checkpoint_info.get("sha256") or checkpoint_info.get("hash") or ""
).lower(),
"file_name": cached_file_name,
"modelName": civitai_model.get("name", ""),
"modelVersionName": civitai.get("name", ""),
"baseModel": checkpoint_info.get("base_model")
or civitai.get("baseModel", ""),
}
@staticmethod
def _derive_recipe_name(lora_matches):
recipe_name_parts = [
f"{name.strip()}-{float(strength):.2f}" for name, strength in lora_matches[:3]
]
return "_".join(recipe_name_parts) or "recipe"
@staticmethod
def _sync_recipe_cache(recipe_scanner, recipe_data, json_path):
cache = getattr(recipe_scanner, "_cache", None)
if cache is not None:
cache.raw_data.append(recipe_data)
cache.sorted_by_name = sorted(
cache.raw_data, key=lambda item: item.get("title", "").lower()
)
cache.sorted_by_date = sorted(
cache.raw_data,
key=lambda item: (
item.get("modified", item.get("created_date", 0)),
item.get("file_path", ""),
),
reverse=True,
)
recipe_scanner._update_folder_metadata(cache)
recipe_scanner._update_fts_index_for_recipe(recipe_data, "add")
recipe_id = str(recipe_data.get("id", ""))
if recipe_id:
recipe_scanner._json_path_map[recipe_id] = json_path
persistent_cache = getattr(recipe_scanner, "_persistent_cache", None)
if persistent_cache:
persistent_cache.update_recipe(recipe_data, json_path)
def _save_image_as_recipe(self, file_path, metadata_dict):
    """Persist the image at *file_path* as a LoRA Manager recipe, synchronously.

    Writes an optimized webp preview and a ``.recipe.json`` next to it in
    the scanner's recipes directory, embeds the recipe metadata into the
    preview, and syncs the scanner's caches.

    Raises:
        ValueError: when *metadata_dict* is empty.
        RuntimeError: when the recipe scanner or its recipes dir is missing.
    """
    if not metadata_dict:
        raise ValueError("No generation metadata found")
    recipe_scanner = ServiceRegistry.get_service_sync("recipe_scanner")
    if recipe_scanner is None:
        raise RuntimeError("Recipe scanner unavailable")
    recipes_dir = recipe_scanner.recipes_dir
    if not recipes_dir:
        raise RuntimeError("Recipes directory unavailable")
    os.makedirs(recipes_dir, exist_ok=True)

    recipe_id = str(uuid.uuid4())
    # Down-scale the saved image into a card-sized webp preview.
    preview_bytes, extension = ExifUtils.optimize_image(
        image_data=file_path,
        target_width=CARD_PREVIEW_WIDTH,
        format="webp",
        quality=85,
        preserve_metadata=True,
    )
    image_path = os.path.normpath(
        os.path.join(recipes_dir, f"{recipe_id}{extension}")
    )
    with open(image_path, "wb") as handle:
        handle.write(preview_bytes)

    lora_stack = metadata_dict.get("loras", "")
    matches, loras_data, base_model_counts = self._build_recipe_loras(
        recipe_scanner, lora_stack
    )
    checkpoint_entry = self._build_recipe_checkpoint(
        recipe_scanner, metadata_dict.get("checkpoint")
    )
    # Prefer the base model most LoRAs agree on; fall back to the checkpoint's.
    dominant_base = ""
    if base_model_counts:
        dominant_base = max(base_model_counts, key=base_model_counts.get)

    now = time.time()
    # Everything except checkpoint/loras becomes generation parameters.
    gen_params = {
        key: value
        for key, value in metadata_dict.items()
        if key not in ["checkpoint", "loras"]
    }
    recipe_data = {
        "id": recipe_id,
        "file_path": image_path,
        "title": self._derive_recipe_name(matches),
        "modified": now,
        "created_date": now,
        "base_model": dominant_base or (checkpoint_entry or {}).get("baseModel", ""),
        "loras": loras_data,
        "gen_params": gen_params,
        "loras_stack": lora_stack,
        "fingerprint": calculate_recipe_fingerprint(loras_data),
    }
    if checkpoint_entry:
        recipe_data["checkpoint"] = checkpoint_entry

    json_path = os.path.normpath(
        os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
    )
    with open(json_path, "w", encoding="utf-8") as handle:
        json.dump(recipe_data, handle, indent=4, ensure_ascii=False)
    ExifUtils.append_recipe_metadata(image_path, recipe_data)
    self._sync_recipe_cache(recipe_scanner, recipe_data, json_path)
def save_images( def save_images(
self, self,
images, images,
@@ -359,6 +568,7 @@ class SaveImageLM:
embed_workflow=False, embed_workflow=False,
save_with_metadata=True, save_with_metadata=True,
add_counter_to_filename=True, add_counter_to_filename=True,
save_as_recipe=False,
): ):
"""Save images with metadata""" """Save images with metadata"""
results = [] results = []
@@ -477,6 +687,14 @@ class SaveImageLM:
img.save(file_path, format="WEBP", **save_kwargs) img.save(file_path, format="WEBP", **save_kwargs)
if save_as_recipe:
try:
self._save_image_as_recipe(file_path, metadata_dict)
except Exception as e:
logger.warning(
"Failed to save image as recipe: %s", e, exc_info=True
)
results.append( results.append(
{"filename": file, "subfolder": subfolder, "type": self.type} {"filename": file, "subfolder": subfolder, "type": self.type}
) )
@@ -499,6 +717,7 @@ class SaveImageLM:
embed_workflow=False, embed_workflow=False,
save_with_metadata=True, save_with_metadata=True,
add_counter_to_filename=True, add_counter_to_filename=True,
save_as_recipe=False,
): ):
"""Process and save image with metadata""" """Process and save image with metadata"""
# Make sure the output directory exists # Make sure the output directory exists
@@ -527,6 +746,7 @@ class SaveImageLM:
embed_workflow, embed_workflow,
save_with_metadata, save_with_metadata,
add_counter_to_filename, add_counter_to_filename,
save_as_recipe,
) )
return { return {

View File

@@ -1,9 +1,11 @@
import json
import os
import numpy as np
import piexif
from PIL import Image
from py.services.service_registry import ServiceRegistry
from py.nodes.save_image import SaveImageLM
@@ -151,3 +153,213 @@ def test_process_image_returns_empty_ui_images_when_save_fails(monkeypatch, tmp_
assert result["result"] == (images,) assert result["result"] == (images,)
assert result["ui"] == {"images": []} assert result["ui"] == {"images": []}
def test_save_image_does_not_save_recipe_by_default(monkeypatch, tmp_path):
    """save_images must not create recipes unless save_as_recipe is enabled."""
    _configure_save_paths(monkeypatch, tmp_path)
    _configure_metadata(monkeypatch, {"prompt": "prompt text", "seed": 123})
    recorded = []

    def _record(self, file_path, metadata_dict):
        recorded.append((file_path, metadata_dict))

    monkeypatch.setattr(SaveImageLM, "_save_image_as_recipe", _record)
    SaveImageLM().save_images([_make_image()], "ComfyUI", "png", id="node-1")
    assert recorded == []
def test_save_image_saves_recipe_when_enabled(monkeypatch, tmp_path):
    """With save_as_recipe=True, each saved image triggers one recipe save."""
    _configure_save_paths(monkeypatch, tmp_path)
    metadata = {"prompt": "prompt text", "seed": 123}
    _configure_metadata(monkeypatch, metadata)
    recorded = []
    monkeypatch.setattr(
        SaveImageLM,
        "_save_image_as_recipe",
        lambda self, path, meta: recorded.append((path, meta)),
    )
    SaveImageLM().save_images(
        [_make_image()], "ComfyUI", "png", id="node-1", save_as_recipe=True
    )
    assert recorded == [(str(tmp_path / "sample_00001_.png"), metadata)]
def test_save_image_saves_recipe_for_each_successful_batch_image(monkeypatch, tmp_path):
    """Every image in a batch gets its own recipe, with incrementing counters."""
    monkeypatch.setattr(
        "folder_paths.get_output_directory", lambda: str(tmp_path), raising=False
    )
    monkeypatch.setattr(
        "folder_paths.get_save_image_path",
        lambda *_a, **_kw: (str(tmp_path), "sample", 7, "", "sample"),
        raising=False,
    )
    metadata = {"prompt": "prompt text", "seed": 123}
    _configure_metadata(monkeypatch, metadata)
    recorded = []
    monkeypatch.setattr(
        SaveImageLM,
        "_save_image_as_recipe",
        lambda self, path, meta: recorded.append((path, meta)),
    )
    SaveImageLM().save_images(
        [_make_image(), _make_image()],
        "ComfyUI",
        "png",
        id="node-1",
        save_as_recipe=True,
    )
    expected = [
        (str(tmp_path / "sample_00007_.png"), metadata),
        (str(tmp_path / "sample_00008_.png"), metadata),
    ]
    assert recorded == expected
def test_save_image_does_not_save_recipe_when_image_save_fails(monkeypatch, tmp_path):
    """A failed image write must not leave a recipe behind."""
    _configure_save_paths(monkeypatch, tmp_path)
    _configure_metadata(monkeypatch, {"prompt": "prompt text", "seed": 123})
    recorded = []

    def _boom(*_args, **_kwargs):
        raise OSError("disk full")

    monkeypatch.setattr(Image.Image, "save", _boom)
    monkeypatch.setattr(
        SaveImageLM,
        "_save_image_as_recipe",
        lambda self, path, meta: recorded.append((path, meta)),
    )
    SaveImageLM().save_images(
        [_make_image()], "ComfyUI", "png", id="node-1", save_as_recipe=True
    )
    assert recorded == []
def test_process_image_keeps_image_result_when_recipe_save_fails(monkeypatch, tmp_path):
    """Recipe failures are swallowed; the image result is still returned."""
    _configure_save_paths(monkeypatch, tmp_path)
    _configure_metadata(monkeypatch, {"prompt": "prompt text", "seed": 123})

    def _fail_recipe(*_args, **_kwargs):
        raise RuntimeError("recipe unavailable")

    monkeypatch.setattr(SaveImageLM, "_save_image_as_recipe", _fail_recipe)
    images = [_make_image()]
    result = SaveImageLM().process_image(images, id="node-1", save_as_recipe=True)
    assert result["result"] == (images,)
    expected_ui = {
        "images": [
            {"filename": "sample_00001_.png", "subfolder": "", "type": "output"}
        ]
    }
    assert result["ui"] == expected_ui
def test_save_image_as_recipe_writes_recipe_without_async_scanner_calls(
    monkeypatch, tmp_path
):
    """_save_image_as_recipe writes preview + json and syncs caches synchronously."""
    _configure_save_paths(monkeypatch, tmp_path)
    source_image = tmp_path / "source.png"
    Image.new("RGB", (16, 16), color=(10, 20, 30)).save(source_image)
    recipes_dir = tmp_path / "recipes"

    class FakeCache:
        def __init__(self, raw_data=None):
            self.raw_data = raw_data or []
            self.sorted_by_name = []
            self.sorted_by_date = []
            self.folders = []
            self.folder_tree = {}

    class FakeModelScanner:
        def __init__(self, raw_data):
            self._cache = FakeCache(raw_data)

    class FakePersistentCache:
        def __init__(self):
            self.updates = []

        def update_recipe(self, recipe_data, json_path):
            self.updates.append((recipe_data, json_path))

    lora_entry = {
        "file_name": "foo",
        "sha256": "ABC123",
        "base_model": "SDXL",
        "civitai": {"id": 456, "name": "Foo v1", "model": {"name": "Foo"}},
    }

    class FakeRecipeScanner:
        def __init__(self):
            self.recipes_dir = str(recipes_dir)
            self._cache = FakeCache([])
            self._json_path_map = {}
            self._persistent_cache = FakePersistentCache()
            self._lora_scanner = FakeModelScanner([lora_entry])
            self._checkpoint_scanner = FakeModelScanner([])
            self.fts_updates = []

        def _update_folder_metadata(self, cache):
            cache.folders = [""]
            cache.folder_tree = {}

        def _update_fts_index_for_recipe(self, recipe_data, operation):
            self.fts_updates.append((recipe_data["id"], operation))

    scanner = FakeRecipeScanner()
    monkeypatch.setitem(ServiceRegistry._services, "recipe_scanner", scanner)

    metadata = {
        "prompt": "prompt text",
        "seed": 123,
        "checkpoint": "model.safetensors",
        "loras": "<lora:foo:0.7>",
    }
    SaveImageLM()._save_image_as_recipe(str(source_image), metadata)

    json_files = list(recipes_dir.glob("*.recipe.json"))
    webp_files = list(recipes_dir.glob("*.webp"))
    assert len(json_files) == 1
    assert len(webp_files) == 1
    assert len(scanner._cache.raw_data) == 1
    assert len(scanner._persistent_cache.updates) == 1

    recipe = json.loads(json_files[0].read_text(encoding="utf-8"))
    assert recipe["file_path"] == os.path.normpath(str(webp_files[0]))
    assert recipe["title"] == "foo-0.70"
    assert recipe["base_model"] == "SDXL"
    first_lora = recipe["loras"][0]
    assert first_lora["hash"] == "abc123"
    assert first_lora["modelVersionId"] == 456
    assert recipe["gen_params"] == {"prompt": "prompt text", "seed": 123}
    assert scanner._json_path_map[recipe["id"]] == os.path.normpath(str(json_files[0]))
    assert scanner.fts_updates == [(recipe["id"], "add")]