feat: Enhance KSamplerAdvancedExtractor to include additional sampling parameters and update metadata processing

This commit is contained in:
Will Miao
2025-04-18 05:29:36 +08:00
parent c2f599b4ff
commit bccabe40c0
2 changed files with 68 additions and 9 deletions

View File

@@ -85,6 +85,43 @@ class SamplerExtractor(NodeMetadataExtractor):
"node_id": node_id
}
class KSamplerAdvancedExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
sampling_params = {}
for key in ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]:
if key in inputs:
sampling_params[key] = inputs[key]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
if isinstance(latent, dict) and "samples" in latent:
# Extract dimensions from latent tensor
samples = latent["samples"]
if hasattr(samples, "shape") and len(samples.shape) >= 3:
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
# Multiply by 8 to get actual pixel dimensions
height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class LoraLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
@@ -201,14 +238,20 @@ class UNETLoaderExtractor(NodeMetadataExtractor):
# Registry of node-specific extractors
NODE_EXTRACTORS = {
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
"CLIPTextEncode": CLIPTextEncodeExtractor,
# Sampling
"KSampler": SamplerExtractor,
"LoraLoader": LoraLoaderExtractor,
"EmptyLatentImage": ImageSizeExtractor,
"LoraManagerLoader": LoraLoaderManagerExtractor,
"KSamplerAdvanced": KSamplerAdvancedExtractor, # Add KSamplerAdvanced
"SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced
# Loaders
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
"LoraLoader": LoraLoaderExtractor,
"LoraManagerLoader": LoraLoaderManagerExtractor,
# Conditioning
"CLIPTextEncode": CLIPTextEncodeExtractor,
# Latent
"EmptyLatentImage": ImageSizeExtractor,
# Flux
"FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance
# Add other nodes as needed
}