mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
Compare commits
4 Commits
2eef629821
...
b31fae4e51
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b31fae4e51 | ||
|
|
c6e5467907 | ||
|
|
df0e5797d0 | ||
|
|
ebdbb36271 |
@@ -353,49 +353,100 @@ class MetadataProcessor:
|
||||
# Check if we have stored conditioning objects for this sampler
|
||||
if sampler_id in metadata.get(PROMPTS, {}) and (
|
||||
"pos_conditioning" in metadata[PROMPTS][sampler_id] or
|
||||
"neg_conditioning" in metadata[PROMPTS][sampler_id]):
|
||||
|
||||
"neg_conditioning" in metadata[PROMPTS][sampler_id]
|
||||
):
|
||||
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
|
||||
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
|
||||
|
||||
# Helper function to recursively find prompt text for a conditioning object
|
||||
def find_prompt_text_for_conditioning(conditioning_obj, is_positive=True):
|
||||
def extend_unique(target, values):
|
||||
for value in values:
|
||||
if value and value not in target:
|
||||
target.append(value)
|
||||
|
||||
# Helper function to recursively find prompt texts for a conditioning object.
|
||||
# Transform nodes can map one output conditioning to multiple source conditionings.
|
||||
def find_prompt_texts_for_conditioning(
|
||||
conditioning_obj, is_positive=True, visited=None
|
||||
):
|
||||
if conditioning_obj is None:
|
||||
return ""
|
||||
return []
|
||||
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
conditioning_id = id(conditioning_obj)
|
||||
if conditioning_id in visited:
|
||||
return []
|
||||
visited.add(conditioning_id)
|
||||
|
||||
prompt_texts = []
|
||||
|
||||
# Try to match conditioning objects with those stored by extractors
|
||||
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
|
||||
# For nodes with single conditioning output
|
||||
if "conditioning" in prompt_data:
|
||||
if id(prompt_data["conditioning"]) == id(conditioning_obj):
|
||||
return prompt_data.get("text", "")
|
||||
if not isinstance(prompt_data, dict):
|
||||
continue
|
||||
|
||||
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
|
||||
if is_positive and "positive_encoded" in prompt_data:
|
||||
if id(prompt_data["positive_encoded"]) == id(conditioning_obj):
|
||||
if "positive_text" in prompt_data:
|
||||
return prompt_data["positive_text"]
|
||||
# For CLIP text nodes with a single conditioning output.
|
||||
if id(prompt_data.get("conditioning")) == conditioning_id:
|
||||
text = prompt_data.get("text", "")
|
||||
if text:
|
||||
extend_unique(prompt_texts, [text])
|
||||
|
||||
# Generic provenance for passthrough/transform/combine nodes.
|
||||
for source in prompt_data.get("conditioning_sources", []):
|
||||
if id(source.get("output")) != conditioning_id:
|
||||
continue
|
||||
for input_conditioning in source.get("inputs", []):
|
||||
extend_unique(
|
||||
prompt_texts,
|
||||
find_prompt_texts_for_conditioning(
|
||||
input_conditioning, is_positive, visited
|
||||
),
|
||||
)
|
||||
|
||||
# For nodes with separate pos_conditioning and neg_conditioning outputs
|
||||
# like TSC_EfficientLoader and existing ControlNet-style metadata.
|
||||
if (
|
||||
is_positive
|
||||
and id(prompt_data.get("positive_encoded")) == conditioning_id
|
||||
):
|
||||
if prompt_data.get("positive_text"):
|
||||
extend_unique(prompt_texts, [prompt_data["positive_text"]])
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_pos_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=True)
|
||||
extend_unique(
|
||||
prompt_texts,
|
||||
find_prompt_texts_for_conditioning(
|
||||
prompt_data.get("orig_pos_cond"),
|
||||
is_positive=True,
|
||||
visited=visited,
|
||||
),
|
||||
)
|
||||
|
||||
if not is_positive and "negative_encoded" in prompt_data:
|
||||
if id(prompt_data["negative_encoded"]) == id(conditioning_obj):
|
||||
if "negative_text" in prompt_data:
|
||||
return prompt_data["negative_text"]
|
||||
if (
|
||||
not is_positive
|
||||
and id(prompt_data.get("negative_encoded")) == conditioning_id
|
||||
):
|
||||
if prompt_data.get("negative_text"):
|
||||
extend_unique(prompt_texts, [prompt_data["negative_text"]])
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_neg_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=False)
|
||||
extend_unique(
|
||||
prompt_texts,
|
||||
find_prompt_texts_for_conditioning(
|
||||
prompt_data.get("orig_neg_cond"),
|
||||
is_positive=False,
|
||||
visited=visited,
|
||||
),
|
||||
)
|
||||
|
||||
return ""
|
||||
return prompt_texts
|
||||
|
||||
# Find prompt texts using the helper function
|
||||
result["prompt"] = find_prompt_text_for_conditioning(pos_conditioning, is_positive=True)
|
||||
result["negative_prompt"] = find_prompt_text_for_conditioning(neg_conditioning, is_positive=False)
|
||||
result["prompt"] = ", ".join(
|
||||
find_prompt_texts_for_conditioning(pos_conditioning, is_positive=True)
|
||||
)
|
||||
result["negative_prompt"] = ", ".join(
|
||||
find_prompt_texts_for_conditioning(neg_conditioning, is_positive=False)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -163,6 +163,251 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor):
|
||||
conditioning = outputs[0][0]
|
||||
metadata[PROMPTS][node_id]["conditioning"] = conditioning
|
||||
|
||||
|
||||
class MyOriginalWaifuTextExtractor(NodeMetadataExtractor):
|
||||
"""Extractor for ComfyUI-MyOriginalWaifu TextProvider nodes."""
|
||||
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
positive_text = inputs.get("positive", "")
|
||||
negative_text = inputs.get("negative", "")
|
||||
|
||||
if positive_text or negative_text:
|
||||
metadata[PROMPTS][node_id] = {
|
||||
"positive_text": positive_text,
|
||||
"negative_text": negative_text,
|
||||
"node_id": node_id,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
output_tuple = _first_output_tuple(outputs)
|
||||
if not output_tuple or len(output_tuple) < 2:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
prompt_metadata["positive_text"] = output_tuple[0]
|
||||
prompt_metadata["negative_text"] = output_tuple[1]
|
||||
|
||||
|
||||
class MyOriginalWaifuClipExtractor(NodeMetadataExtractor):
|
||||
"""Extractor for ComfyUI-MyOriginalWaifu ClipProvider nodes."""
|
||||
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
positive_text = inputs.get("positive", "")
|
||||
negative_text = inputs.get("negative", "")
|
||||
|
||||
if positive_text or negative_text:
|
||||
metadata[PROMPTS][node_id] = {
|
||||
"positive_text": positive_text,
|
||||
"negative_text": negative_text,
|
||||
"node_id": node_id,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
output_tuple = _first_output_tuple(outputs)
|
||||
if not output_tuple or len(output_tuple) < 2:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
prompt_metadata["positive_encoded"] = output_tuple[0]
|
||||
prompt_metadata["negative_encoded"] = output_tuple[1]
|
||||
|
||||
|
||||
def _ensure_prompt_metadata(metadata, node_id):
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
return metadata[PROMPTS][node_id]
|
||||
|
||||
|
||||
def _first_output_tuple(outputs):
|
||||
if not outputs or not isinstance(outputs, list) or len(outputs) == 0:
|
||||
return None
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, tuple):
|
||||
return first_output
|
||||
return None
|
||||
|
||||
|
||||
def _record_conditioning_source(
|
||||
metadata, node_id, output_conditioning, input_conditionings
|
||||
):
|
||||
if output_conditioning is None:
|
||||
return
|
||||
|
||||
sources = [
|
||||
conditioning for conditioning in input_conditionings if conditioning is not None
|
||||
]
|
||||
if not sources:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
prompt_metadata.setdefault("conditioning_sources", []).append(
|
||||
{
|
||||
"output": output_conditioning,
|
||||
"inputs": sources,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_variable_name(inputs):
|
||||
for key in ("key", "name", "variable_name", "tag", "text"):
|
||||
value = inputs.get(key)
|
||||
if isinstance(value, str) and value:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _get_node_variable_name(metadata, node_id, inputs):
|
||||
variable_name = _get_variable_name(inputs)
|
||||
if variable_name:
|
||||
return variable_name
|
||||
|
||||
prompt = metadata.get("current_prompt")
|
||||
original_prompt = getattr(prompt, "original_prompt", None)
|
||||
if not original_prompt or node_id not in original_prompt:
|
||||
return None
|
||||
|
||||
node_data = original_prompt[node_id]
|
||||
variable_name = _get_variable_name(node_data.get("inputs", {}))
|
||||
if variable_name:
|
||||
return variable_name
|
||||
|
||||
widgets_values = node_data.get("widgets_values", [])
|
||||
if widgets_values and isinstance(widgets_values[0], str):
|
||||
return widgets_values[0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class ControlNetApplyAdvancedExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
if inputs.get("positive") is not None:
|
||||
prompt_metadata["orig_pos_cond"] = inputs["positive"]
|
||||
if inputs.get("negative") is not None:
|
||||
prompt_metadata["orig_neg_cond"] = inputs["negative"]
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
output_tuple = _first_output_tuple(outputs)
|
||||
if not output_tuple:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
positive_input = prompt_metadata.get("orig_pos_cond")
|
||||
negative_input = prompt_metadata.get("orig_neg_cond")
|
||||
|
||||
if len(output_tuple) >= 1:
|
||||
prompt_metadata["positive_encoded"] = output_tuple[0]
|
||||
_record_conditioning_source(
|
||||
metadata, node_id, output_tuple[0], [positive_input]
|
||||
)
|
||||
if len(output_tuple) >= 2:
|
||||
prompt_metadata["negative_encoded"] = output_tuple[1]
|
||||
_record_conditioning_source(
|
||||
metadata, node_id, output_tuple[1], [negative_input]
|
||||
)
|
||||
|
||||
|
||||
class ConditioningCombineExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
input_conditionings = []
|
||||
for input_name in inputs:
|
||||
if (
|
||||
input_name.startswith("conditioning")
|
||||
and inputs[input_name] is not None
|
||||
):
|
||||
input_conditionings.append(inputs[input_name])
|
||||
|
||||
if input_conditionings:
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
prompt_metadata["orig_conditionings"] = input_conditionings
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
output_tuple = _first_output_tuple(outputs)
|
||||
if not output_tuple or len(output_tuple) < 1:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
output_conditioning = output_tuple[0]
|
||||
prompt_metadata["conditioning"] = output_conditioning
|
||||
_record_conditioning_source(
|
||||
metadata,
|
||||
node_id,
|
||||
output_conditioning,
|
||||
prompt_metadata.get("orig_conditionings", []),
|
||||
)
|
||||
|
||||
|
||||
class SetNodeExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
variable_name = _get_node_variable_name(metadata, node_id, inputs)
|
||||
conditioning = inputs.get("CONDITIONING")
|
||||
if conditioning is None:
|
||||
conditioning = inputs.get("conditioning")
|
||||
if conditioning is None:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
prompt_metadata["conditioning"] = conditioning
|
||||
if variable_name:
|
||||
prompt_metadata["variable_name"] = variable_name
|
||||
metadata[PROMPTS].setdefault("__conditioning_variables__", {})[
|
||||
variable_name
|
||||
] = conditioning
|
||||
|
||||
|
||||
class GetNodeExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
variable_name = _get_node_variable_name(metadata, node_id, inputs or {})
|
||||
if variable_name:
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
prompt_metadata["variable_name"] = variable_name
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
output_tuple = _first_output_tuple(outputs)
|
||||
if not output_tuple or len(output_tuple) < 1:
|
||||
return
|
||||
|
||||
prompt_metadata = _ensure_prompt_metadata(metadata, node_id)
|
||||
output_conditioning = output_tuple[0]
|
||||
prompt_metadata["conditioning"] = output_conditioning
|
||||
|
||||
variable_name = prompt_metadata.get("variable_name")
|
||||
if not variable_name:
|
||||
return
|
||||
|
||||
input_conditioning = metadata[PROMPTS].get("__conditioning_variables__", {}).get(
|
||||
variable_name
|
||||
)
|
||||
_record_conditioning_source(
|
||||
metadata, node_id, output_conditioning, [input_conditioning]
|
||||
)
|
||||
|
||||
# Base Sampler Extractor to reduce code redundancy
|
||||
class BaseSamplerExtractor(NodeMetadataExtractor):
|
||||
"""Base extractor for sampler nodes with common functionality"""
|
||||
@@ -798,6 +1043,12 @@ NODE_EXTRACTORS = {
|
||||
"smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes
|
||||
"CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack
|
||||
"PCTextEncode": CLIPTextEncodeExtractor, # From https://github.com/asagi4/comfyui-prompt-control
|
||||
"TextProvider": MyOriginalWaifuTextExtractor, # ComfyUI-MyOriginalWaifu
|
||||
"ClipProvider": MyOriginalWaifuClipExtractor, # ComfyUI-MyOriginalWaifu
|
||||
"ControlNetApplyAdvanced": ControlNetApplyAdvancedExtractor,
|
||||
"ConditioningCombine": ConditioningCombineExtractor,
|
||||
"SetNode": SetNodeExtractor,
|
||||
"GetNode": GetNodeExtractor,
|
||||
# Latent
|
||||
"EmptyLatentImage": ImageSizeExtractor,
|
||||
# Flux
|
||||
|
||||
@@ -1,12 +1,17 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
import numpy as np
|
||||
import folder_paths # type: ignore
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||
from ..metadata_collector import get_metadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.utils import calculate_recipe_fingerprint
|
||||
from PIL import Image, PngImagePlugin
|
||||
import piexif
|
||||
import logging
|
||||
@@ -86,6 +91,13 @@ class SaveImageLM:
|
||||
"tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images.",
|
||||
},
|
||||
),
|
||||
"save_as_recipe": (
|
||||
"BOOLEAN",
|
||||
{
|
||||
"default": False,
|
||||
"tooltip": "Also saves each generated image as a LoRA Manager recipe.",
|
||||
},
|
||||
),
|
||||
},
|
||||
"hidden": {
|
||||
"id": "UNIQUE_ID",
|
||||
@@ -346,6 +358,203 @@ class SaveImageLM:
|
||||
|
||||
return filename
|
||||
|
||||
@staticmethod
|
||||
def _get_cached_model_by_name(scanner, name):
|
||||
cache = getattr(scanner, "_cache", None)
|
||||
if cache is None or not name:
|
||||
return None
|
||||
|
||||
candidates = [
|
||||
name,
|
||||
os.path.basename(name),
|
||||
os.path.splitext(os.path.basename(name))[0],
|
||||
]
|
||||
for model in getattr(cache, "raw_data", []):
|
||||
file_name = model.get("file_name")
|
||||
if file_name in candidates:
|
||||
return model
|
||||
return None
|
||||
|
||||
def _build_recipe_loras(self, recipe_scanner, lora_stack):
|
||||
lora_matches = re.findall(r"<lora:([^:]+):([^>]+)>", lora_stack or "")
|
||||
lora_scanner = getattr(recipe_scanner, "_lora_scanner", None)
|
||||
loras_data = []
|
||||
base_model_counts = {}
|
||||
|
||||
for name, strength in lora_matches:
|
||||
lora_info = self._get_cached_model_by_name(lora_scanner, name)
|
||||
civitai = (lora_info or {}).get("civitai") or {}
|
||||
civitai_model = civitai.get("model") or {}
|
||||
try:
|
||||
parsed_strength = float(strength)
|
||||
except (TypeError, ValueError):
|
||||
parsed_strength = 1.0
|
||||
|
||||
loras_data.append(
|
||||
{
|
||||
"file_name": name,
|
||||
"strength": parsed_strength,
|
||||
"hash": ((lora_info or {}).get("sha256") or "").lower(),
|
||||
"modelVersionId": civitai.get("id", 0),
|
||||
"modelName": civitai_model.get("name", name) if lora_info else "",
|
||||
"modelVersionName": civitai.get("name", "") if lora_info else "",
|
||||
"isDeleted": False,
|
||||
"exclude": False,
|
||||
}
|
||||
)
|
||||
|
||||
base_model = (lora_info or {}).get("base_model")
|
||||
if base_model:
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
|
||||
return lora_matches, loras_data, base_model_counts
|
||||
|
||||
def _build_recipe_checkpoint(self, recipe_scanner, checkpoint_raw):
|
||||
if not isinstance(checkpoint_raw, str) or not checkpoint_raw.strip():
|
||||
return None
|
||||
|
||||
checkpoint_name = checkpoint_raw.strip()
|
||||
file_name = os.path.splitext(os.path.basename(checkpoint_name))[0]
|
||||
checkpoint_scanner = getattr(recipe_scanner, "_checkpoint_scanner", None)
|
||||
checkpoint_info = self._get_cached_model_by_name(
|
||||
checkpoint_scanner, checkpoint_name
|
||||
)
|
||||
|
||||
if not checkpoint_info:
|
||||
return {
|
||||
"type": "checkpoint",
|
||||
"name": checkpoint_name,
|
||||
"file_name": file_name,
|
||||
"hash": self.get_checkpoint_hash(checkpoint_name) or "",
|
||||
}
|
||||
|
||||
civitai = checkpoint_info.get("civitai") or {}
|
||||
civitai_model = civitai.get("model") or {}
|
||||
file_path = checkpoint_info.get("file_path") or checkpoint_info.get("path") or ""
|
||||
cached_file_name = (
|
||||
checkpoint_info.get("file_name")
|
||||
or (os.path.splitext(os.path.basename(file_path))[0] if file_path else "")
|
||||
or file_name
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "checkpoint",
|
||||
"modelId": civitai_model.get("id", 0),
|
||||
"modelVersionId": civitai.get("id", 0),
|
||||
"name": civitai_model.get("name")
|
||||
or checkpoint_info.get("model_name")
|
||||
or checkpoint_name,
|
||||
"version": civitai.get("name", ""),
|
||||
"hash": (
|
||||
checkpoint_info.get("sha256") or checkpoint_info.get("hash") or ""
|
||||
).lower(),
|
||||
"file_name": cached_file_name,
|
||||
"modelName": civitai_model.get("name", ""),
|
||||
"modelVersionName": civitai.get("name", ""),
|
||||
"baseModel": checkpoint_info.get("base_model")
|
||||
or civitai.get("baseModel", ""),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _derive_recipe_name(lora_matches):
|
||||
recipe_name_parts = [
|
||||
f"{name.strip()}-{float(strength):.2f}" for name, strength in lora_matches[:3]
|
||||
]
|
||||
return "_".join(recipe_name_parts) or "recipe"
|
||||
|
||||
@staticmethod
|
||||
def _sync_recipe_cache(recipe_scanner, recipe_data, json_path):
|
||||
cache = getattr(recipe_scanner, "_cache", None)
|
||||
if cache is not None:
|
||||
cache.raw_data.append(recipe_data)
|
||||
cache.sorted_by_name = sorted(
|
||||
cache.raw_data, key=lambda item: item.get("title", "").lower()
|
||||
)
|
||||
cache.sorted_by_date = sorted(
|
||||
cache.raw_data,
|
||||
key=lambda item: (
|
||||
item.get("modified", item.get("created_date", 0)),
|
||||
item.get("file_path", ""),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
recipe_scanner._update_folder_metadata(cache)
|
||||
recipe_scanner._update_fts_index_for_recipe(recipe_data, "add")
|
||||
|
||||
recipe_id = str(recipe_data.get("id", ""))
|
||||
if recipe_id:
|
||||
recipe_scanner._json_path_map[recipe_id] = json_path
|
||||
persistent_cache = getattr(recipe_scanner, "_persistent_cache", None)
|
||||
if persistent_cache:
|
||||
persistent_cache.update_recipe(recipe_data, json_path)
|
||||
|
||||
def _save_image_as_recipe(self, file_path, metadata_dict):
|
||||
if not metadata_dict:
|
||||
raise ValueError("No generation metadata found")
|
||||
|
||||
recipe_scanner = ServiceRegistry.get_service_sync("recipe_scanner")
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipes_dir = recipe_scanner.recipes_dir
|
||||
if not recipes_dir:
|
||||
raise RuntimeError("Recipes directory unavailable")
|
||||
os.makedirs(recipes_dir, exist_ok=True)
|
||||
|
||||
recipe_id = str(uuid.uuid4())
|
||||
optimized_image, extension = ExifUtils.optimize_image(
|
||||
image_data=file_path,
|
||||
target_width=CARD_PREVIEW_WIDTH,
|
||||
format="webp",
|
||||
quality=85,
|
||||
preserve_metadata=True,
|
||||
)
|
||||
image_path = os.path.normpath(os.path.join(recipes_dir, f"{recipe_id}{extension}"))
|
||||
with open(image_path, "wb") as file_obj:
|
||||
file_obj.write(optimized_image)
|
||||
|
||||
lora_stack = metadata_dict.get("loras", "")
|
||||
lora_matches, loras_data, base_model_counts = self._build_recipe_loras(
|
||||
recipe_scanner, lora_stack
|
||||
)
|
||||
checkpoint_entry = self._build_recipe_checkpoint(
|
||||
recipe_scanner, metadata_dict.get("checkpoint")
|
||||
)
|
||||
most_common_base_model = (
|
||||
max(base_model_counts.items(), key=lambda item: item[1])[0]
|
||||
if base_model_counts
|
||||
else ""
|
||||
)
|
||||
current_time = time.time()
|
||||
recipe_data = {
|
||||
"id": recipe_id,
|
||||
"file_path": image_path,
|
||||
"title": self._derive_recipe_name(lora_matches),
|
||||
"modified": current_time,
|
||||
"created_date": current_time,
|
||||
"base_model": most_common_base_model
|
||||
or (checkpoint_entry or {}).get("baseModel", ""),
|
||||
"loras": loras_data,
|
||||
"gen_params": {
|
||||
key: value
|
||||
for key, value in metadata_dict.items()
|
||||
if key not in ["checkpoint", "loras"]
|
||||
},
|
||||
"loras_stack": lora_stack,
|
||||
"fingerprint": calculate_recipe_fingerprint(loras_data),
|
||||
}
|
||||
if checkpoint_entry:
|
||||
recipe_data["checkpoint"] = checkpoint_entry
|
||||
|
||||
json_path = os.path.normpath(
|
||||
os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
|
||||
)
|
||||
with open(json_path, "w", encoding="utf-8") as file_obj:
|
||||
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
||||
|
||||
ExifUtils.append_recipe_metadata(image_path, recipe_data)
|
||||
self._sync_recipe_cache(recipe_scanner, recipe_data, json_path)
|
||||
|
||||
def save_images(
|
||||
self,
|
||||
images,
|
||||
@@ -359,6 +568,7 @@ class SaveImageLM:
|
||||
embed_workflow=False,
|
||||
save_with_metadata=True,
|
||||
add_counter_to_filename=True,
|
||||
save_as_recipe=False,
|
||||
):
|
||||
"""Save images with metadata"""
|
||||
results = []
|
||||
@@ -477,6 +687,14 @@ class SaveImageLM:
|
||||
|
||||
img.save(file_path, format="WEBP", **save_kwargs)
|
||||
|
||||
if save_as_recipe:
|
||||
try:
|
||||
self._save_image_as_recipe(file_path, metadata_dict)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to save image as recipe: %s", e, exc_info=True
|
||||
)
|
||||
|
||||
results.append(
|
||||
{"filename": file, "subfolder": subfolder, "type": self.type}
|
||||
)
|
||||
@@ -499,6 +717,7 @@ class SaveImageLM:
|
||||
embed_workflow=False,
|
||||
save_with_metadata=True,
|
||||
add_counter_to_filename=True,
|
||||
save_as_recipe=False,
|
||||
):
|
||||
"""Process and save image with metadata"""
|
||||
# Make sure the output directory exists
|
||||
@@ -527,6 +746,7 @@ class SaveImageLM:
|
||||
embed_workflow,
|
||||
save_with_metadata,
|
||||
add_counter_to_filename,
|
||||
save_as_recipe,
|
||||
)
|
||||
|
||||
return {
|
||||
|
||||
@@ -177,6 +177,383 @@ def test_attention_bias_clip_text_encode_prompts_are_collected(metadata_registry
|
||||
assert prompt_results["negative_prompt"] == "low quality"
|
||||
|
||||
|
||||
def test_myoriginalwaifu_text_provider_uses_processed_prompt_outputs(
|
||||
metadata_registry, monkeypatch
|
||||
):
|
||||
prompt_graph = {
|
||||
"text_provider": {
|
||||
"class_type": "TextProvider",
|
||||
"inputs": {
|
||||
"positive": "raw positive",
|
||||
"negative": "raw negative",
|
||||
},
|
||||
},
|
||||
"encode_pos": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {"text": ["text_provider", 0], "clip": ["clip", 0]},
|
||||
},
|
||||
"encode_neg": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {"text": ["text_provider", 1], "clip": ["clip", 0]},
|
||||
},
|
||||
"sampler": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"seed": 123,
|
||||
"steps": 20,
|
||||
"cfg": 7.0,
|
||||
"sampler_name": "Euler",
|
||||
"scheduler": "karras",
|
||||
"denoise": 1.0,
|
||||
"positive": ["encode_pos", 0],
|
||||
"negative": ["encode_neg", 0],
|
||||
"latent_image": {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))},
|
||||
},
|
||||
},
|
||||
}
|
||||
prompt = SimpleNamespace(original_prompt=prompt_graph)
|
||||
|
||||
pos_conditioning = object()
|
||||
neg_conditioning = object()
|
||||
|
||||
monkeypatch.setattr(metadata_processor, "standalone_mode", False)
|
||||
|
||||
metadata_registry.start_collection("prompt-myoriginalwaifu-text")
|
||||
metadata_registry.set_current_prompt(prompt)
|
||||
|
||||
metadata_registry.record_node_execution(
|
||||
"text_provider",
|
||||
"TextProvider",
|
||||
{"positive": "raw positive", "negative": "raw negative"},
|
||||
None,
|
||||
)
|
||||
metadata_registry.update_node_execution(
|
||||
"text_provider",
|
||||
"TextProvider",
|
||||
[("processed positive", "processed negative")],
|
||||
)
|
||||
metadata_registry.record_node_execution(
|
||||
"encode_pos", "CLIPTextEncode", {"text": "processed positive"}, None
|
||||
)
|
||||
metadata_registry.update_node_execution(
|
||||
"encode_pos", "CLIPTextEncode", [(pos_conditioning,)]
|
||||
)
|
||||
metadata_registry.record_node_execution(
|
||||
"encode_neg", "CLIPTextEncode", {"text": "processed negative"}, None
|
||||
)
|
||||
metadata_registry.update_node_execution(
|
||||
"encode_neg", "CLIPTextEncode", [(neg_conditioning,)]
|
||||
)
|
||||
metadata_registry.record_node_execution(
|
||||
"sampler",
|
||||
"KSampler",
|
||||
{
|
||||
"seed": 123,
|
||||
"steps": 20,
|
||||
"cfg": 7.0,
|
||||
"sampler_name": "Euler",
|
||||
"scheduler": "karras",
|
||||
"denoise": 1.0,
|
||||
"positive": pos_conditioning,
|
||||
"negative": neg_conditioning,
|
||||
"latent_image": {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))},
|
||||
},
|
||||
None,
|
||||
)
|
||||
|
||||
metadata = metadata_registry.get_metadata("prompt-myoriginalwaifu-text")
|
||||
params = MetadataProcessor.extract_generation_params(metadata)
|
||||
|
||||
assert metadata[PROMPTS]["text_provider"]["positive_text"] == "processed positive"
|
||||
assert metadata[PROMPTS]["text_provider"]["negative_text"] == "processed negative"
|
||||
assert params["prompt"] == "processed positive"
|
||||
assert params["negative_prompt"] == "processed negative"
|
||||
|
||||
|
||||
def test_myoriginalwaifu_clip_provider_prompts_are_collected_without_clip_text_encode(
|
||||
metadata_registry, monkeypatch
|
||||
):
|
||||
prompt_graph = {
|
||||
"clip_provider": {
|
||||
"class_type": "ClipProvider",
|
||||
"inputs": {
|
||||
"positive": "direct positive",
|
||||
"negative": "direct negative",
|
||||
"clip": ["clip", 0],
|
||||
},
|
||||
},
|
||||
"sampler": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"seed": 123,
|
||||
"steps": 20,
|
||||
"cfg": 7.0,
|
||||
"sampler_name": "Euler",
|
||||
"scheduler": "karras",
|
||||
"denoise": 1.0,
|
||||
"positive": ["clip_provider", 0],
|
||||
"negative": ["clip_provider", 1],
|
||||
"latent_image": {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))},
|
||||
},
|
||||
},
|
||||
}
|
||||
prompt = SimpleNamespace(original_prompt=prompt_graph)
|
||||
|
||||
pos_conditioning = object()
|
||||
neg_conditioning = object()
|
||||
|
||||
monkeypatch.setattr(metadata_processor, "standalone_mode", False)
|
||||
|
||||
metadata_registry.start_collection("prompt-myoriginalwaifu-clip")
|
||||
metadata_registry.set_current_prompt(prompt)
|
||||
|
||||
metadata_registry.record_node_execution(
|
||||
"clip_provider",
|
||||
"ClipProvider",
|
||||
{"positive": "direct positive", "negative": "direct negative"},
|
||||
None,
|
||||
)
|
||||
metadata_registry.update_node_execution(
|
||||
"clip_provider", "ClipProvider", [(pos_conditioning, neg_conditioning)]
|
||||
)
|
||||
metadata_registry.record_node_execution(
|
||||
"sampler",
|
||||
"KSampler",
|
||||
{
|
||||
"seed": 123,
|
||||
"steps": 20,
|
||||
"cfg": 7.0,
|
||||
"sampler_name": "Euler",
|
||||
"scheduler": "karras",
|
||||
"denoise": 1.0,
|
||||
"positive": pos_conditioning,
|
||||
"negative": neg_conditioning,
|
||||
"latent_image": {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))},
|
||||
},
|
||||
None,
|
||||
)
|
||||
|
||||
metadata = metadata_registry.get_metadata("prompt-myoriginalwaifu-clip")
|
||||
params = MetadataProcessor.extract_generation_params(metadata)
|
||||
|
||||
assert params["prompt"] == "direct positive"
|
||||
assert params["negative_prompt"] == "direct negative"
|
||||
|
||||
|
||||
def test_conditioning_provenance_recovers_combined_controlnet_prompts(
|
||||
metadata_registry, monkeypatch
|
||||
):
|
||||
import types
|
||||
|
||||
prompt_graph = {
|
||||
"encode_wd": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {"text": "wd14 tags", "clip": ["clip", 0]},
|
||||
},
|
||||
"encode_manual": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {"text": "manual tags", "clip": ["clip", 0]},
|
||||
},
|
||||
"encode_neg": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {"text": "low quality", "clip": ["clip", 0]},
|
||||
},
|
||||
"combine": {
|
||||
"class_type": "ConditioningCombine",
|
||||
"inputs": {
|
||||
"conditioning_1": ["encode_wd", 0],
|
||||
"conditioning_2": ["encode_manual", 0],
|
||||
},
|
||||
},
|
||||
"controlnet": {
|
||||
"class_type": "ControlNetApplyAdvanced",
|
||||
"inputs": {
|
||||
"positive": ["combine", 0],
|
||||
"negative": ["encode_neg", 0],
|
||||
},
|
||||
},
|
||||
"sampler": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"seed": 123,
|
||||
"steps": 20,
|
||||
"cfg": 7.0,
|
||||
"sampler_name": "Euler",
|
||||
"scheduler": "karras",
|
||||
"denoise": 1.0,
|
||||
"positive": ["controlnet", 0],
|
||||
"negative": ["controlnet", 1],
|
||||
"latent_image": {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))},
|
||||
},
|
||||
},
|
||||
}
|
||||
prompt = SimpleNamespace(original_prompt=prompt_graph)
|
||||
|
||||
wd_conditioning = object()
|
||||
manual_conditioning = object()
|
||||
negative_conditioning = object()
|
||||
combined_conditioning = object()
|
||||
controlnet_positive = object()
|
||||
controlnet_negative = object()
|
||||
|
||||
monkeypatch.setattr(metadata_processor, "standalone_mode", False)
|
||||
|
||||
metadata_registry.start_collection("prompt-provenance")
|
||||
metadata_registry.set_current_prompt(prompt)
|
||||
|
||||
metadata_registry.record_node_execution(
|
||||
"encode_wd", "CLIPTextEncode", {"text": "wd14 tags"}, None
|
||||
)
|
||||
metadata_registry.update_node_execution(
|
||||
"encode_wd", "CLIPTextEncode", [(wd_conditioning,)]
|
||||
)
|
||||
metadata_registry.record_node_execution(
|
||||
"encode_manual", "CLIPTextEncode", {"text": "manual tags"}, None
|
||||
)
|
||||
metadata_registry.update_node_execution(
|
||||
"encode_manual", "CLIPTextEncode", [(manual_conditioning,)]
|
||||
)
|
||||
metadata_registry.record_node_execution(
|
||||
"encode_neg", "CLIPTextEncode", {"text": "low quality"}, None
|
||||
)
|
||||
metadata_registry.update_node_execution(
|
||||
"encode_neg", "CLIPTextEncode", [(negative_conditioning,)]
|
||||
)
|
||||
metadata_registry.record_node_execution(
|
||||
"combine",
|
||||
"ConditioningCombine",
|
||||
{
|
||||
"conditioning_1": wd_conditioning,
|
||||
"conditioning_2": manual_conditioning,
|
||||
},
|
||||
None,
|
||||
)
|
||||
metadata_registry.update_node_execution(
|
||||
"combine", "ConditioningCombine", [(combined_conditioning,)]
|
||||
)
|
||||
metadata_registry.record_node_execution(
|
||||
"controlnet",
|
||||
"ControlNetApplyAdvanced",
|
||||
{
|
||||
"positive": combined_conditioning,
|
||||
"negative": negative_conditioning,
|
||||
},
|
||||
None,
|
||||
)
|
||||
metadata_registry.update_node_execution(
|
||||
"controlnet",
|
||||
"ControlNetApplyAdvanced",
|
||||
[(controlnet_positive, controlnet_negative)],
|
||||
)
|
||||
metadata_registry.record_node_execution(
|
||||
"sampler",
|
||||
"KSampler",
|
||||
{
|
||||
"seed": 123,
|
||||
"steps": 20,
|
||||
"cfg": 7.0,
|
||||
"sampler_name": "Euler",
|
||||
"scheduler": "karras",
|
||||
"denoise": 1.0,
|
||||
"positive": controlnet_positive,
|
||||
"negative": controlnet_negative,
|
||||
"latent_image": {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))},
|
||||
},
|
||||
None,
|
||||
)
|
||||
|
||||
metadata = metadata_registry.get_metadata("prompt-provenance")
|
||||
params = MetadataProcessor.extract_generation_params(metadata)
|
||||
|
||||
assert params["prompt"] == "wd14 tags, manual tags"
|
||||
assert params["negative_prompt"] == "low quality"
|
||||
|
||||
|
||||
def test_conditioning_provenance_recovers_kj_set_get_prompts(
|
||||
metadata_registry, monkeypatch
|
||||
):
|
||||
import types
|
||||
|
||||
prompt_graph = {
|
||||
"encode_pos": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {"text": "from set node", "clip": ["clip", 0]},
|
||||
},
|
||||
"set_positive": {
|
||||
"class_type": "SetNode",
|
||||
"inputs": {"CONDITIONING": ["encode_pos", 0], "name": "positive"},
|
||||
},
|
||||
"get_positive": {
|
||||
"class_type": "GetNode",
|
||||
"inputs": {"name": "positive"},
|
||||
},
|
||||
"sampler": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"seed": 123,
|
||||
"steps": 20,
|
||||
"cfg": 7.0,
|
||||
"sampler_name": "Euler",
|
||||
"scheduler": "karras",
|
||||
"denoise": 1.0,
|
||||
"positive": ["get_positive", 0],
|
||||
"negative": ["encode_pos", 0],
|
||||
"latent_image": {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))},
|
||||
},
|
||||
},
|
||||
}
|
||||
prompt = SimpleNamespace(original_prompt=prompt_graph)
|
||||
|
||||
original_conditioning = object()
|
||||
get_conditioning = object()
|
||||
|
||||
monkeypatch.setattr(metadata_processor, "standalone_mode", False)
|
||||
|
||||
metadata_registry.start_collection("prompt-kj-get")
|
||||
metadata_registry.set_current_prompt(prompt)
|
||||
|
||||
metadata_registry.record_node_execution(
|
||||
"encode_pos", "CLIPTextEncode", {"text": "from set node"}, None
|
||||
)
|
||||
metadata_registry.update_node_execution(
|
||||
"encode_pos", "CLIPTextEncode", [(original_conditioning,)]
|
||||
)
|
||||
metadata_registry.record_node_execution(
|
||||
"set_positive",
|
||||
"SetNode",
|
||||
{"CONDITIONING": original_conditioning, "name": "positive"},
|
||||
None,
|
||||
)
|
||||
metadata_registry.record_node_execution(
|
||||
"get_positive", "GetNode", {"name": "positive"}, None
|
||||
)
|
||||
metadata_registry.update_node_execution(
|
||||
"get_positive", "GetNode", [(get_conditioning,)]
|
||||
)
|
||||
metadata_registry.record_node_execution(
|
||||
"sampler",
|
||||
"KSampler",
|
||||
{
|
||||
"seed": 123,
|
||||
"steps": 20,
|
||||
"cfg": 7.0,
|
||||
"sampler_name": "Euler",
|
||||
"scheduler": "karras",
|
||||
"denoise": 1.0,
|
||||
"positive": get_conditioning,
|
||||
"negative": original_conditioning,
|
||||
"latent_image": {"samples": types.SimpleNamespace(shape=(1, 4, 16, 16))},
|
||||
},
|
||||
None,
|
||||
)
|
||||
|
||||
metadata = metadata_registry.get_metadata("prompt-kj-get")
|
||||
params = MetadataProcessor.extract_generation_params(metadata)
|
||||
|
||||
assert params["prompt"] == "from set node"
|
||||
assert params["negative_prompt"] == "from set node"
|
||||
|
||||
|
||||
def test_sampler_custom_advanced_recovers_prompt_text_through_guidance_nodes(metadata_registry, monkeypatch):
|
||||
import types
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import piexif
|
||||
from PIL import Image
|
||||
|
||||
from py.services.service_registry import ServiceRegistry
|
||||
from py.nodes.save_image import SaveImageLM
|
||||
|
||||
|
||||
@@ -151,3 +153,213 @@ def test_process_image_returns_empty_ui_images_when_save_fails(monkeypatch, tmp_
|
||||
|
||||
assert result["result"] == (images,)
|
||||
assert result["ui"] == {"images": []}
|
||||
|
||||
|
||||
def test_save_image_does_not_save_recipe_by_default(monkeypatch, tmp_path):
|
||||
_configure_save_paths(monkeypatch, tmp_path)
|
||||
_configure_metadata(monkeypatch, {"prompt": "prompt text", "seed": 123})
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr(
|
||||
SaveImageLM,
|
||||
"_save_image_as_recipe",
|
||||
lambda self, file_path, metadata_dict: calls.append((file_path, metadata_dict)),
|
||||
)
|
||||
|
||||
node = SaveImageLM()
|
||||
node.save_images([_make_image()], "ComfyUI", "png", id="node-1")
|
||||
|
||||
assert calls == []
|
||||
|
||||
|
||||
def test_save_image_saves_recipe_when_enabled(monkeypatch, tmp_path):
|
||||
_configure_save_paths(monkeypatch, tmp_path)
|
||||
metadata_dict = {"prompt": "prompt text", "seed": 123}
|
||||
_configure_metadata(monkeypatch, metadata_dict)
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr(
|
||||
SaveImageLM,
|
||||
"_save_image_as_recipe",
|
||||
lambda self, file_path, metadata_dict: calls.append((file_path, metadata_dict)),
|
||||
)
|
||||
|
||||
node = SaveImageLM()
|
||||
node.save_images(
|
||||
[_make_image()],
|
||||
"ComfyUI",
|
||||
"png",
|
||||
id="node-1",
|
||||
save_as_recipe=True,
|
||||
)
|
||||
|
||||
assert calls == [(str(tmp_path / "sample_00001_.png"), metadata_dict)]
|
||||
|
||||
|
||||
def test_save_image_saves_recipe_for_each_successful_batch_image(monkeypatch, tmp_path):
|
||||
monkeypatch.setattr("folder_paths.get_output_directory", lambda: str(tmp_path), raising=False)
|
||||
monkeypatch.setattr(
|
||||
"folder_paths.get_save_image_path",
|
||||
lambda *_args, **_kwargs: (str(tmp_path), "sample", 7, "", "sample"),
|
||||
raising=False,
|
||||
)
|
||||
metadata_dict = {"prompt": "prompt text", "seed": 123}
|
||||
_configure_metadata(monkeypatch, metadata_dict)
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr(
|
||||
SaveImageLM,
|
||||
"_save_image_as_recipe",
|
||||
lambda self, file_path, metadata_dict: calls.append((file_path, metadata_dict)),
|
||||
)
|
||||
|
||||
node = SaveImageLM()
|
||||
node.save_images(
|
||||
[_make_image(), _make_image()],
|
||||
"ComfyUI",
|
||||
"png",
|
||||
id="node-1",
|
||||
save_as_recipe=True,
|
||||
)
|
||||
|
||||
assert calls == [
|
||||
(str(tmp_path / "sample_00007_.png"), metadata_dict),
|
||||
(str(tmp_path / "sample_00008_.png"), metadata_dict),
|
||||
]
|
||||
|
||||
|
||||
def test_save_image_does_not_save_recipe_when_image_save_fails(monkeypatch, tmp_path):
|
||||
_configure_save_paths(monkeypatch, tmp_path)
|
||||
_configure_metadata(monkeypatch, {"prompt": "prompt text", "seed": 123})
|
||||
|
||||
def _raise_save_error(*args, **kwargs):
|
||||
raise OSError("disk full")
|
||||
|
||||
calls = []
|
||||
monkeypatch.setattr(Image.Image, "save", _raise_save_error)
|
||||
monkeypatch.setattr(
|
||||
SaveImageLM,
|
||||
"_save_image_as_recipe",
|
||||
lambda self, file_path, metadata_dict: calls.append((file_path, metadata_dict)),
|
||||
)
|
||||
|
||||
node = SaveImageLM()
|
||||
node.save_images(
|
||||
[_make_image()],
|
||||
"ComfyUI",
|
||||
"png",
|
||||
id="node-1",
|
||||
save_as_recipe=True,
|
||||
)
|
||||
|
||||
assert calls == []
|
||||
|
||||
|
||||
def test_process_image_keeps_image_result_when_recipe_save_fails(monkeypatch, tmp_path):
|
||||
_configure_save_paths(monkeypatch, tmp_path)
|
||||
_configure_metadata(monkeypatch, {"prompt": "prompt text", "seed": 123})
|
||||
|
||||
def _raise_recipe_error(*args, **kwargs):
|
||||
raise RuntimeError("recipe unavailable")
|
||||
|
||||
monkeypatch.setattr(SaveImageLM, "_save_image_as_recipe", _raise_recipe_error)
|
||||
|
||||
images = [_make_image()]
|
||||
node = SaveImageLM()
|
||||
|
||||
result = node.process_image(images, id="node-1", save_as_recipe=True)
|
||||
|
||||
assert result["result"] == (images,)
|
||||
assert result["ui"] == {
|
||||
"images": [{"filename": "sample_00001_.png", "subfolder": "", "type": "output"}]
|
||||
}
|
||||
|
||||
|
||||
def test_save_image_as_recipe_writes_recipe_without_async_scanner_calls(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
_configure_save_paths(monkeypatch, tmp_path)
|
||||
source_image = tmp_path / "source.png"
|
||||
Image.new("RGB", (16, 16), color=(10, 20, 30)).save(source_image)
|
||||
recipes_dir = tmp_path / "recipes"
|
||||
|
||||
class _Cache:
|
||||
def __init__(self, raw_data=None):
|
||||
self.raw_data = raw_data or []
|
||||
self.sorted_by_name = []
|
||||
self.sorted_by_date = []
|
||||
self.folders = []
|
||||
self.folder_tree = {}
|
||||
|
||||
class _ModelScanner:
|
||||
def __init__(self, raw_data):
|
||||
self._cache = _Cache(raw_data)
|
||||
|
||||
class _PersistentCache:
|
||||
def __init__(self):
|
||||
self.updates = []
|
||||
|
||||
def update_recipe(self, recipe_data, json_path):
|
||||
self.updates.append((recipe_data, json_path))
|
||||
|
||||
class _RecipeScanner:
|
||||
def __init__(self):
|
||||
self.recipes_dir = str(recipes_dir)
|
||||
self._cache = _Cache([])
|
||||
self._json_path_map = {}
|
||||
self._persistent_cache = _PersistentCache()
|
||||
self._lora_scanner = _ModelScanner(
|
||||
[
|
||||
{
|
||||
"file_name": "foo",
|
||||
"sha256": "ABC123",
|
||||
"base_model": "SDXL",
|
||||
"civitai": {
|
||||
"id": 456,
|
||||
"name": "Foo v1",
|
||||
"model": {"name": "Foo"},
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
self._checkpoint_scanner = _ModelScanner([])
|
||||
self.fts_updates = []
|
||||
|
||||
def _update_folder_metadata(self, cache):
|
||||
cache.folders = [""]
|
||||
cache.folder_tree = {}
|
||||
|
||||
def _update_fts_index_for_recipe(self, recipe_data, operation):
|
||||
self.fts_updates.append((recipe_data["id"], operation))
|
||||
|
||||
scanner = _RecipeScanner()
|
||||
monkeypatch.setitem(ServiceRegistry._services, "recipe_scanner", scanner)
|
||||
|
||||
node = SaveImageLM()
|
||||
node._save_image_as_recipe(
|
||||
str(source_image),
|
||||
{
|
||||
"prompt": "prompt text",
|
||||
"seed": 123,
|
||||
"checkpoint": "model.safetensors",
|
||||
"loras": "<lora:foo:0.7>",
|
||||
},
|
||||
)
|
||||
|
||||
recipe_files = list(recipes_dir.glob("*.recipe.json"))
|
||||
preview_files = list(recipes_dir.glob("*.webp"))
|
||||
|
||||
assert len(recipe_files) == 1
|
||||
assert len(preview_files) == 1
|
||||
assert len(scanner._cache.raw_data) == 1
|
||||
assert len(scanner._persistent_cache.updates) == 1
|
||||
|
||||
recipe = json.loads(recipe_files[0].read_text(encoding="utf-8"))
|
||||
assert recipe["file_path"] == os.path.normpath(str(preview_files[0]))
|
||||
assert recipe["title"] == "foo-0.70"
|
||||
assert recipe["base_model"] == "SDXL"
|
||||
assert recipe["loras"][0]["hash"] == "abc123"
|
||||
assert recipe["loras"][0]["modelVersionId"] == 456
|
||||
assert recipe["gen_params"] == {"prompt": "prompt text", "seed": 123}
|
||||
assert scanner._json_path_map[recipe["id"]] == os.path.normpath(str(recipe_files[0]))
|
||||
assert scanner.fts_updates == [(recipe["id"], "add")]
|
||||
|
||||
@@ -5,6 +5,7 @@ import LoraRandomizerWidget from '@/components/LoraRandomizerWidget.vue'
|
||||
import LoraCyclerWidget from '@/components/LoraCyclerWidget.vue'
|
||||
import JsonDisplayWidget from '@/components/JsonDisplayWidget.vue'
|
||||
import AutocompleteTextWidget from '@/components/AutocompleteTextWidget.vue'
|
||||
import { createVueWidgetCleanup } from './vue-widget-cleanup'
|
||||
import type { LoraPoolConfig, RandomizerConfig, CyclerConfig } from './composables/types'
|
||||
import {
|
||||
setupModeChangeHandler,
|
||||
@@ -66,6 +67,12 @@ function forwardMiddleMouseToCanvas(container: HTMLElement) {
|
||||
}
|
||||
|
||||
const vueApps = new Map<number, VueApp>()
|
||||
let autocompleteTextWidgetInstanceId = 0
|
||||
|
||||
export function createAutocompleteTextWidgetInstanceId() {
|
||||
autocompleteTextWidgetInstanceId += 1
|
||||
return autocompleteTextWidgetInstanceId
|
||||
}
|
||||
|
||||
// Cache for dynamically loaded addLorasWidget module
|
||||
let addLorasWidgetCache: any = null
|
||||
@@ -562,8 +569,9 @@ function createAutocompleteTextWidgetFactory(
|
||||
inputOptions: { placeholder?: string } = {}
|
||||
) {
|
||||
const metadataWidgetName = `__lm_autocomplete_meta_${widgetName}`
|
||||
const instanceId = createAutocompleteTextWidgetInstanceId()
|
||||
const container = document.createElement('div')
|
||||
container.id = `autocomplete-text-widget-${node.id}-${widgetName}`
|
||||
container.id = `autocomplete-text-widget-${instanceId}`
|
||||
container.style.width = '100%'
|
||||
container.style.height = '100%'
|
||||
container.style.display = 'flex'
|
||||
@@ -644,17 +652,12 @@ function createAutocompleteTextWidgetFactory(
|
||||
})
|
||||
|
||||
vueApp.mount(container)
|
||||
// Use a unique key combining node.id and widget name to avoid collisions
|
||||
const appKey = node.id * 100000 + widgetName.charCodeAt(0)
|
||||
const appKey = instanceId
|
||||
vueApps.set(appKey, vueApp)
|
||||
|
||||
widget.onRemove = () => {
|
||||
const vueApp = vueApps.get(appKey)
|
||||
if (vueApp) {
|
||||
vueApp.unmount()
|
||||
widget.onRemove = createVueWidgetCleanup(vueApp, () => {
|
||||
vueApps.delete(appKey)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return { widget }
|
||||
}
|
||||
|
||||
15
vue-widgets/src/vue-widget-cleanup.ts
Normal file
15
vue-widgets/src/vue-widget-cleanup.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
import type { App as VueApp } from 'vue'
|
||||
|
||||
export function createVueWidgetCleanup(vueApp: VueApp, onCleanup?: () => void) {
|
||||
let didUnmount = false
|
||||
|
||||
return () => {
|
||||
if (didUnmount) {
|
||||
return
|
||||
}
|
||||
|
||||
vueApp.unmount()
|
||||
didUnmount = true
|
||||
onCleanup?.()
|
||||
}
|
||||
}
|
||||
38
vue-widgets/tests/unit/vueWidgetCleanup.test.ts
Normal file
38
vue-widgets/tests/unit/vueWidgetCleanup.test.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
import { createVueWidgetCleanup } from '@/vue-widget-cleanup'
|
||||
|
||||
describe('createVueWidgetCleanup', () => {
|
||||
it('cleans up only the Vue app bound to the widget remove handler', () => {
|
||||
const firstCleanup = vi.fn()
|
||||
const secondCleanup = vi.fn()
|
||||
const firstApp = { unmount: vi.fn() }
|
||||
const secondApp = { unmount: vi.fn() }
|
||||
|
||||
const removeFirst = createVueWidgetCleanup(firstApp as any, firstCleanup)
|
||||
const removeSecond = createVueWidgetCleanup(secondApp as any, secondCleanup)
|
||||
|
||||
removeFirst()
|
||||
|
||||
expect(firstApp.unmount).toHaveBeenCalledTimes(1)
|
||||
expect(firstCleanup).toHaveBeenCalledTimes(1)
|
||||
expect(secondApp.unmount).not.toHaveBeenCalled()
|
||||
expect(secondCleanup).not.toHaveBeenCalled()
|
||||
|
||||
removeSecond()
|
||||
|
||||
expect(secondApp.unmount).toHaveBeenCalledTimes(1)
|
||||
expect(secondCleanup).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('is idempotent when ComfyUI calls the remove handler more than once', () => {
|
||||
const cleanup = vi.fn()
|
||||
const app = { unmount: vi.fn() }
|
||||
const remove = createVueWidgetCleanup(app as any, cleanup)
|
||||
|
||||
remove()
|
||||
remove()
|
||||
|
||||
expect(app.unmount).toHaveBeenCalledTimes(1)
|
||||
expect(cleanup).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
@@ -14933,6 +14933,17 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
}
|
||||
});
|
||||
const AutocompleteTextWidget = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-76ce0f19"]]);
|
||||
function createVueWidgetCleanup(vueApp, onCleanup) {
|
||||
let didUnmount = false;
|
||||
return () => {
|
||||
if (didUnmount) {
|
||||
return;
|
||||
}
|
||||
vueApp.unmount();
|
||||
didUnmount = true;
|
||||
onCleanup == null ? void 0 : onCleanup();
|
||||
};
|
||||
}
|
||||
const LORA_PROVIDER_NODE_TYPES$1 = [
|
||||
"Lora Stacker (LoraManager)",
|
||||
"Lora Randomizer (LoraManager)",
|
||||
@@ -15274,6 +15285,11 @@ function forwardMiddleMouseToCanvas(container) {
|
||||
});
|
||||
}
|
||||
const vueApps = /* @__PURE__ */ new Map();
|
||||
let autocompleteTextWidgetInstanceId = 0;
|
||||
function createAutocompleteTextWidgetInstanceId() {
|
||||
autocompleteTextWidgetInstanceId += 1;
|
||||
return autocompleteTextWidgetInstanceId;
|
||||
}
|
||||
let addLorasWidgetCache = null;
|
||||
function createLoraPoolWidget(node) {
|
||||
const container = document.createElement("div");
|
||||
@@ -15653,8 +15669,9 @@ if ((_a = app$1.ui) == null ? void 0 : _a.settings) {
|
||||
function createAutocompleteTextWidgetFactory(node, widgetName, modelType, inputOptions = {}) {
|
||||
var _a2, _b, _c;
|
||||
const metadataWidgetName = `__lm_autocomplete_meta_${widgetName}`;
|
||||
const instanceId = createAutocompleteTextWidgetInstanceId();
|
||||
const container = document.createElement("div");
|
||||
container.id = `autocomplete-text-widget-${node.id}-${widgetName}`;
|
||||
container.id = `autocomplete-text-widget-${instanceId}`;
|
||||
container.style.width = "100%";
|
||||
container.style.height = "100%";
|
||||
container.style.display = "flex";
|
||||
@@ -15721,15 +15738,11 @@ function createAutocompleteTextWidgetFactory(node, widgetName, modelType, inputO
|
||||
ripple: false
|
||||
});
|
||||
vueApp.mount(container);
|
||||
const appKey = node.id * 1e5 + widgetName.charCodeAt(0);
|
||||
const appKey = instanceId;
|
||||
vueApps.set(appKey, vueApp);
|
||||
widget.onRemove = () => {
|
||||
const vueApp2 = vueApps.get(appKey);
|
||||
if (vueApp2) {
|
||||
vueApp2.unmount();
|
||||
widget.onRemove = createVueWidgetCleanup(vueApp, () => {
|
||||
vueApps.delete(appKey);
|
||||
}
|
||||
};
|
||||
});
|
||||
return { widget };
|
||||
}
|
||||
app$1.registerExtension({
|
||||
@@ -15834,4 +15847,7 @@ app$1.registerExtension({
|
||||
}
|
||||
}
|
||||
});
|
||||
export {
|
||||
createAutocompleteTextWidgetInstanceId
|
||||
};
|
||||
//# sourceMappingURL=lora-manager-widgets.js.map
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user