From df0e5797d0dcba965f8a98c56f69db7950bccbbf Mon Sep 17 00:00:00 2001 From: Will Miao Date: Thu, 23 Apr 2026 15:46:57 +0800 Subject: [PATCH] fix(nodes): save recipes synchronously from save image --- py/nodes/save_image.py | 220 +++++++++++++++++++++++++++++++++ tests/nodes/test_save_image.py | 212 +++++++++++++++++++++++++++++++ 2 files changed, 432 insertions(+) diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py index a2cff148..20aec150 100644 --- a/py/nodes/save_image.py +++ b/py/nodes/save_image.py @@ -1,12 +1,17 @@ import json import os import re +import time +import uuid from typing import Any, Dict, Optional import numpy as np import folder_paths # type: ignore from ..services.service_registry import ServiceRegistry from ..metadata_collector.metadata_processor import MetadataProcessor from ..metadata_collector import get_metadata +from ..utils.constants import CARD_PREVIEW_WIDTH +from ..utils.exif_utils import ExifUtils +from ..utils.utils import calculate_recipe_fingerprint from PIL import Image, PngImagePlugin import piexif import logging @@ -86,6 +91,13 @@ class SaveImageLM: "tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images.", }, ), + "save_as_recipe": ( + "BOOLEAN", + { + "default": False, + "tooltip": "Also saves each generated image as a LoRA Manager recipe.", + }, + ), }, "hidden": { "id": "UNIQUE_ID", @@ -346,6 +358,203 @@ class SaveImageLM: return filename + @staticmethod + def _get_cached_model_by_name(scanner, name): + cache = getattr(scanner, "_cache", None) + if cache is None or not name: + return None + + candidates = [ + name, + os.path.basename(name), + os.path.splitext(os.path.basename(name))[0], + ] + for model in getattr(cache, "raw_data", []): + file_name = model.get("file_name") + if file_name in candidates: + return model + return None + + def _build_recipe_loras(self, recipe_scanner, lora_stack): + lora_matches = re.findall(r"]+)>", lora_stack or "") + lora_scanner = getattr(recipe_scanner, "_lora_scanner", None) + loras_data = [] + base_model_counts = {} + + for name, strength in lora_matches: + lora_info = self._get_cached_model_by_name(lora_scanner, name) + civitai = (lora_info or {}).get("civitai") or {} + civitai_model = civitai.get("model") or {} + try: + parsed_strength = float(strength) + except (TypeError, ValueError): + parsed_strength = 1.0 + + loras_data.append( + { + "file_name": name, + "strength": parsed_strength, + "hash": ((lora_info or {}).get("sha256") or "").lower(), + "modelVersionId": civitai.get("id", 0), + "modelName": civitai_model.get("name", name) if lora_info else "", + "modelVersionName": civitai.get("name", "") if lora_info else "", + "isDeleted": False, + "exclude": False, + } + ) + + base_model = (lora_info or {}).get("base_model") + if base_model: + base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 + + return lora_matches, loras_data, base_model_counts + + def _build_recipe_checkpoint(self, recipe_scanner, checkpoint_raw): + if not isinstance(checkpoint_raw, str) or not checkpoint_raw.strip(): + return None + + checkpoint_name = checkpoint_raw.strip() + file_name = os.path.splitext(os.path.basename(checkpoint_name))[0] + checkpoint_scanner = getattr(recipe_scanner, "_checkpoint_scanner", None) + checkpoint_info = self._get_cached_model_by_name( + checkpoint_scanner, checkpoint_name + ) + + if not checkpoint_info: + return { + "type": "checkpoint", + "name": checkpoint_name, + "file_name": file_name, + "hash": self.get_checkpoint_hash(checkpoint_name) or "", + } + + civitai = checkpoint_info.get("civitai") or {} + civitai_model = civitai.get("model") or {} + file_path = checkpoint_info.get("file_path") or checkpoint_info.get("path") or "" + cached_file_name = ( + checkpoint_info.get("file_name") + or (os.path.splitext(os.path.basename(file_path))[0] if file_path else "") + or file_name + ) + + return { + "type": "checkpoint", + "modelId": civitai_model.get("id", 0), + "modelVersionId": civitai.get("id", 0), + "name": civitai_model.get("name") + or checkpoint_info.get("model_name") + or checkpoint_name, + "version": civitai.get("name", ""), + "hash": ( + checkpoint_info.get("sha256") or checkpoint_info.get("hash") or "" + ).lower(), + "file_name": cached_file_name, + "modelName": civitai_model.get("name", ""), + "modelVersionName": civitai.get("name", ""), + "baseModel": checkpoint_info.get("base_model") + or civitai.get("baseModel", ""), + } + + @staticmethod + def _derive_recipe_name(lora_matches): + recipe_name_parts = [ + f"{name.strip()}-{float(strength):.2f}" for name, strength in lora_matches[:3] + ] + return "_".join(recipe_name_parts) or "recipe" + + @staticmethod + def _sync_recipe_cache(recipe_scanner, recipe_data, json_path): + cache = getattr(recipe_scanner, "_cache", None) + if cache is not None: + cache.raw_data.append(recipe_data) + cache.sorted_by_name = sorted( + cache.raw_data, key=lambda item: item.get("title", "").lower() + ) + cache.sorted_by_date = sorted( + cache.raw_data, + key=lambda item: ( + item.get("modified", item.get("created_date", 0)), + item.get("file_path", ""), + ), + reverse=True, + ) + recipe_scanner._update_folder_metadata(cache) + recipe_scanner._update_fts_index_for_recipe(recipe_data, "add") + + recipe_id = str(recipe_data.get("id", "")) + if recipe_id: + recipe_scanner._json_path_map[recipe_id] = json_path + persistent_cache = getattr(recipe_scanner, "_persistent_cache", None) + if persistent_cache: + persistent_cache.update_recipe(recipe_data, json_path) + + def _save_image_as_recipe(self, file_path, metadata_dict): + if not metadata_dict: + raise ValueError("No generation metadata found") + + recipe_scanner = ServiceRegistry.get_service_sync("recipe_scanner") + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + recipes_dir = recipe_scanner.recipes_dir + if not recipes_dir: + raise RuntimeError("Recipes directory unavailable") + os.makedirs(recipes_dir, exist_ok=True) + + recipe_id = str(uuid.uuid4()) + optimized_image, extension = ExifUtils.optimize_image( + image_data=file_path, + target_width=CARD_PREVIEW_WIDTH, + format="webp", + quality=85, + preserve_metadata=True, + ) + image_path = os.path.normpath(os.path.join(recipes_dir, f"{recipe_id}{extension}")) + with open(image_path, "wb") as file_obj: + file_obj.write(optimized_image) + + lora_stack = metadata_dict.get("loras", "") + lora_matches, loras_data, base_model_counts = self._build_recipe_loras( + recipe_scanner, lora_stack + ) + checkpoint_entry = self._build_recipe_checkpoint( + recipe_scanner, metadata_dict.get("checkpoint") + ) + most_common_base_model = ( + max(base_model_counts.items(), key=lambda item: item[1])[0] + if base_model_counts + else "" + ) + current_time = time.time() + recipe_data = { + "id": recipe_id, + "file_path": image_path, + "title": self._derive_recipe_name(lora_matches), + "modified": current_time, + "created_date": current_time, + "base_model": most_common_base_model + or (checkpoint_entry or {}).get("baseModel", ""), + "loras": loras_data, + "gen_params": { + key: value + for key, value in metadata_dict.items() + if key not in ["checkpoint", "loras"] + }, + "loras_stack": lora_stack, + "fingerprint": calculate_recipe_fingerprint(loras_data), + } + if checkpoint_entry: + recipe_data["checkpoint"] = checkpoint_entry + + json_path = os.path.normpath( + os.path.join(recipes_dir, f"{recipe_id}.recipe.json") + ) + with open(json_path, "w", encoding="utf-8") as file_obj: + json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) + + ExifUtils.append_recipe_metadata(image_path, recipe_data) + self._sync_recipe_cache(recipe_scanner, recipe_data, json_path) + def save_images( self, images, @@ -359,6 +568,7 @@ class SaveImageLM: embed_workflow=False, save_with_metadata=True, add_counter_to_filename=True, + save_as_recipe=False, ): """Save images with metadata""" results = [] @@ -477,6 +687,14 @@ class SaveImageLM: img.save(file_path, format="WEBP", **save_kwargs) + if save_as_recipe: + try: + self._save_image_as_recipe(file_path, metadata_dict) + except Exception as e: + logger.warning( + "Failed to save image as recipe: %s", e, exc_info=True + ) + results.append( {"filename": file, "subfolder": subfolder, "type": self.type} ) @@ -499,6 +717,7 @@ class SaveImageLM: embed_workflow=False, save_with_metadata=True, add_counter_to_filename=True, + save_as_recipe=False, ): """Process and save image with metadata""" # Make sure the output directory exists @@ -527,6 +746,7 @@ class SaveImageLM: embed_workflow, save_with_metadata, add_counter_to_filename, + save_as_recipe, ) return { diff --git a/tests/nodes/test_save_image.py b/tests/nodes/test_save_image.py index 0ab2928b..4b567211 100644 --- a/tests/nodes/test_save_image.py +++ b/tests/nodes/test_save_image.py @@ -1,9 +1,11 @@ import json +import os import numpy as np import piexif from PIL import Image +from py.services.service_registry import ServiceRegistry from py.nodes.save_image import SaveImageLM @@ -151,3 +153,213 @@ def test_process_image_returns_empty_ui_images_when_save_fails(monkeypatch, tmp_ assert result["result"] == (images,) assert result["ui"] == {"images": []} + + +def test_save_image_does_not_save_recipe_by_default(monkeypatch, tmp_path): + _configure_save_paths(monkeypatch, tmp_path) + _configure_metadata(monkeypatch, {"prompt": "prompt text", "seed": 123}) + + calls = [] + monkeypatch.setattr( + SaveImageLM, + "_save_image_as_recipe", + lambda self, file_path, metadata_dict: calls.append((file_path, metadata_dict)), + ) + + node = SaveImageLM() + node.save_images([_make_image()], "ComfyUI", "png", id="node-1") + + assert calls == [] + + +def test_save_image_saves_recipe_when_enabled(monkeypatch, tmp_path): + _configure_save_paths(monkeypatch, tmp_path) + metadata_dict = {"prompt": "prompt text", "seed": 123} + _configure_metadata(monkeypatch, metadata_dict) + + calls = [] + monkeypatch.setattr( + SaveImageLM, + "_save_image_as_recipe", + lambda self, file_path, metadata_dict: calls.append((file_path, metadata_dict)), + ) + + node = SaveImageLM() + node.save_images( + [_make_image()], + "ComfyUI", + "png", + id="node-1", + save_as_recipe=True, + ) + + assert calls == [(str(tmp_path / "sample_00001_.png"), metadata_dict)] + + +def test_save_image_saves_recipe_for_each_successful_batch_image(monkeypatch, tmp_path): + monkeypatch.setattr("folder_paths.get_output_directory", lambda: str(tmp_path), raising=False) + monkeypatch.setattr( + "folder_paths.get_save_image_path", + lambda *_args, **_kwargs: (str(tmp_path), "sample", 7, "", "sample"), + raising=False, + ) + metadata_dict = {"prompt": "prompt text", "seed": 123} + _configure_metadata(monkeypatch, metadata_dict) + + calls = [] + monkeypatch.setattr( + SaveImageLM, + "_save_image_as_recipe", + lambda self, file_path, metadata_dict: calls.append((file_path, metadata_dict)), + ) + + node = SaveImageLM() + node.save_images( + [_make_image(), _make_image()], + "ComfyUI", + "png", + id="node-1", + save_as_recipe=True, + ) + + assert calls == [ + (str(tmp_path / "sample_00007_.png"), metadata_dict), + (str(tmp_path / "sample_00008_.png"), metadata_dict), + ] + + +def test_save_image_does_not_save_recipe_when_image_save_fails(monkeypatch, tmp_path): + _configure_save_paths(monkeypatch, tmp_path) + _configure_metadata(monkeypatch, {"prompt": "prompt text", "seed": 123}) + + def _raise_save_error(*args, **kwargs): + raise OSError("disk full") + + calls = [] + monkeypatch.setattr(Image.Image, "save", _raise_save_error) + monkeypatch.setattr( + SaveImageLM, + "_save_image_as_recipe", + lambda self, file_path, metadata_dict: calls.append((file_path, metadata_dict)), + ) + + node = SaveImageLM() + node.save_images( + [_make_image()], + "ComfyUI", + "png", + id="node-1", + save_as_recipe=True, + ) + + assert calls == [] + + +def test_process_image_keeps_image_result_when_recipe_save_fails(monkeypatch, tmp_path): + _configure_save_paths(monkeypatch, tmp_path) + _configure_metadata(monkeypatch, {"prompt": "prompt text", "seed": 123}) + + def _raise_recipe_error(*args, **kwargs): + raise RuntimeError("recipe unavailable") + + monkeypatch.setattr(SaveImageLM, "_save_image_as_recipe", _raise_recipe_error) + + images = [_make_image()] + node = SaveImageLM() + + result = node.process_image(images, id="node-1", save_as_recipe=True) + + assert result["result"] == (images,) + assert result["ui"] == { + "images": [{"filename": "sample_00001_.png", "subfolder": "", "type": "output"}] + } + + +def test_save_image_as_recipe_writes_recipe_without_async_scanner_calls( + monkeypatch, tmp_path +): + _configure_save_paths(monkeypatch, tmp_path) + source_image = tmp_path / "source.png" + Image.new("RGB", (16, 16), color=(10, 20, 30)).save(source_image) + recipes_dir = tmp_path / "recipes" + + class _Cache: + def __init__(self, raw_data=None): + self.raw_data = raw_data or [] + self.sorted_by_name = [] + self.sorted_by_date = [] + self.folders = [] + self.folder_tree = {} + + class _ModelScanner: + def __init__(self, raw_data): + self._cache = _Cache(raw_data) + + class _PersistentCache: + def __init__(self): + self.updates = [] + + def update_recipe(self, recipe_data, json_path): + self.updates.append((recipe_data, json_path)) + + class _RecipeScanner: + def __init__(self): + self.recipes_dir = str(recipes_dir) + self._cache = _Cache([]) + self._json_path_map = {} + self._persistent_cache = _PersistentCache() + self._lora_scanner = _ModelScanner( + [ + { + "file_name": "foo", + "sha256": "ABC123", + "base_model": "SDXL", + "civitai": { + "id": 456, + "name": "Foo v1", + "model": {"name": "Foo"}, + }, + } + ] + ) + self._checkpoint_scanner = _ModelScanner([]) + self.fts_updates = [] + + def _update_folder_metadata(self, cache): + cache.folders = [""] + cache.folder_tree = {} + + def _update_fts_index_for_recipe(self, recipe_data, operation): + self.fts_updates.append((recipe_data["id"], operation)) + + scanner = _RecipeScanner() + monkeypatch.setitem(ServiceRegistry._services, "recipe_scanner", scanner) + + node = SaveImageLM() + node._save_image_as_recipe( + str(source_image), + { + "prompt": "prompt text", + "seed": 123, + "checkpoint": "model.safetensors", + "loras": "", + }, + ) + + recipe_files = list(recipes_dir.glob("*.recipe.json")) + preview_files = list(recipes_dir.glob("*.webp")) + + assert len(recipe_files) == 1 + assert len(preview_files) == 1 + assert len(scanner._cache.raw_data) == 1 + assert len(scanner._persistent_cache.updates) == 1 + + recipe = json.loads(recipe_files[0].read_text(encoding="utf-8")) + assert recipe["file_path"] == os.path.normpath(str(preview_files[0])) + assert recipe["title"] == "foo-0.70" + assert recipe["base_model"] == "SDXL" + assert recipe["loras"][0]["hash"] == "abc123" + assert recipe["loras"][0]["modelVersionId"] == 456 + assert recipe["gen_params"] == {"prompt": "prompt text", "seed": 123} + assert scanner._json_path_map[recipe["id"]] == os.path.normpath(str(recipe_files[0])) + assert scanner.fts_updates == [(recipe["id"], "add")]