From e726c4f44247d6d0c5c4ecf785e97cd345dff4cc Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Mon, 23 Jun 2025 10:55:27 +0800 Subject: [PATCH] feat: enhance metadata extraction for TSC samplers with vae_decode handling --- py/metadata_collector/metadata_processor.py | 5 +++ py/metadata_collector/node_extractors.py | 38 +++++++++++++++++++-- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index 654d9aaa..17d0a482 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -1,5 +1,6 @@ import json import sys +from .constants import IMAGES # Check if running in standalone mode standalone_mode = 'nodes' not in sys.modules @@ -18,6 +19,10 @@ class MetadataProcessor: - metadata: The workflow metadata - downstream_id: Optional ID of a downstream node to help identify the specific primary sampler """ + if downstream_id is None: + if IMAGES in metadata and "first_decode" in metadata[IMAGES]: + downstream_id = metadata[IMAGES]["first_decode"]["node_id"] + # If we have a downstream_id and execution_order, use it to narrow down potential samplers if downstream_id and "execution_order" in metadata: execution_order = metadata["execution_order"] diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index e3c78877..5973de37 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -220,8 +220,32 @@ class KSamplerAdvancedExtractor(NodeMetadataExtractor): class TSCSamplerBaseExtractor(NodeMetadataExtractor): """Base extractor for handling TSC sampler node outputs""" + @staticmethod + def extract(node_id, inputs, outputs, metadata): + # Store vae_decode setting for later use in update + if inputs and "vae_decode" in inputs: + if SAMPLING not in metadata: + metadata[SAMPLING] = {} + + if node_id not in metadata[SAMPLING]: + metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id} + + # Store the vae_decode setting + metadata[SAMPLING][node_id]["vae_decode"] = inputs["vae_decode"] + @staticmethod def update(node_id, outputs, metadata): + # Check if vae_decode was set to "true" + should_save_image = True + if SAMPLING in metadata and node_id in metadata[SAMPLING]: + vae_decode = metadata[SAMPLING][node_id].get("vae_decode") + if vae_decode is not None: + should_save_image = (vae_decode == "true") + + # Skip image saving if vae_decode isn't "true" + if not should_save_image: + return + # Ensure IMAGES category exists if IMAGES not in metadata: metadata[IMAGES] = {} @@ -250,13 +274,23 @@ class TSCSamplerBaseExtractor(NodeMetadataExtractor): class TSCKSamplerExtractor(SamplerExtractor, TSCSamplerBaseExtractor): """Extractor for TSC_KSampler nodes""" - # Extract method is inherited from SamplerExtractor + @staticmethod + def extract(node_id, inputs, outputs, metadata): + # Call parent extract methods + SamplerExtractor.extract(node_id, inputs, outputs, metadata) + TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata) + # Update method is inherited from TSCSamplerBaseExtractor class TSCKSamplerAdvancedExtractor(KSamplerAdvancedExtractor, TSCSamplerBaseExtractor): """Extractor for TSC_KSamplerAdvanced nodes""" - # Extract method is inherited from KSamplerAdvancedExtractor + @staticmethod + def extract(node_id, inputs, outputs, metadata): + # Call parent extract methods + SamplerExtractor.extract(node_id, inputs, outputs, metadata) + TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata) + # Update method is inherited from TSCSamplerBaseExtractor class LoraLoaderExtractor(NodeMetadataExtractor):