feat: enhance metadata extraction for TSC samplers with vae_decode handling

This commit is contained in:
Will Miao
2025-06-23 10:55:27 +08:00
parent 402318e586
commit e726c4f442
2 changed files with 41 additions and 2 deletions

View File

@@ -1,5 +1,6 @@
import json
import sys
from .constants import IMAGES
# Check if running in standalone mode
standalone_mode = 'nodes' not in sys.modules
@@ -18,6 +19,10 @@ class MetadataProcessor:
- metadata: The workflow metadata
- downstream_id: Optional ID of a downstream node to help identify the specific primary sampler
"""
if downstream_id is None:
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
downstream_id = metadata[IMAGES]["first_decode"]["node_id"]
# If we have a downstream_id and execution_order, use it to narrow down potential samplers
if downstream_id and "execution_order" in metadata:
execution_order = metadata["execution_order"]

View File

@@ -220,8 +220,32 @@ class KSamplerAdvancedExtractor(NodeMetadataExtractor):
class TSCSamplerBaseExtractor(NodeMetadataExtractor):
"""Base extractor for handling TSC sampler node outputs"""
@staticmethod
def extract(node_id, inputs, outputs, metadata):
# Store vae_decode setting for later use in update
if inputs and "vae_decode" in inputs:
if SAMPLING not in metadata:
metadata[SAMPLING] = {}
if node_id not in metadata[SAMPLING]:
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
# Store the vae_decode setting
metadata[SAMPLING][node_id]["vae_decode"] = inputs["vae_decode"]
@staticmethod
def update(node_id, outputs, metadata):
# Check if vae_decode was set to "true"
should_save_image = True
if SAMPLING in metadata and node_id in metadata[SAMPLING]:
vae_decode = metadata[SAMPLING][node_id].get("vae_decode")
if vae_decode is not None:
should_save_image = (vae_decode == "true")
# Skip image saving if vae_decode isn't "true"
if not should_save_image:
return
# Ensure IMAGES category exists
if IMAGES not in metadata:
metadata[IMAGES] = {}
@@ -250,13 +274,23 @@ class TSCSamplerBaseExtractor(NodeMetadataExtractor):
class TSCKSamplerExtractor(SamplerExtractor, TSCSamplerBaseExtractor):
"""Extractor for TSC_KSampler nodes"""
# Extract method is inherited from SamplerExtractor
@staticmethod
def extract(node_id, inputs, outputs, metadata):
# Call parent extract methods
SamplerExtractor.extract(node_id, inputs, outputs, metadata)
TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata)
# Update method is inherited from TSCSamplerBaseExtractor
class TSCKSamplerAdvancedExtractor(KSamplerAdvancedExtractor, TSCSamplerBaseExtractor):
"""Extractor for TSC_KSamplerAdvanced nodes"""
# Extract method is inherited from KSamplerAdvancedExtractor
@staticmethod
def extract(node_id, inputs, outputs, metadata):
# Call parent extract methods
SamplerExtractor.extract(node_id, inputs, outputs, metadata)
TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata)
# Update method is inherited from TSCSamplerBaseExtractor
class LoraLoaderExtractor(NodeMetadataExtractor):