mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
Add support for SamplerCustomAdvanced node in metadata extraction
This commit is contained in:
@@ -11,7 +11,16 @@ class MetadataProcessor:
|
|||||||
primary_sampler = None
|
primary_sampler = None
|
||||||
primary_sampler_id = None
|
primary_sampler_id = None
|
||||||
|
|
||||||
# First, check for KSamplerAdvanced with add_noise="enable"
|
# First, check for SamplerCustomAdvanced
|
||||||
|
prompt = metadata.get("current_prompt")
|
||||||
|
if prompt and prompt.original_prompt:
|
||||||
|
for node_id, node_info in prompt.original_prompt.items():
|
||||||
|
if node_info.get("class_type") == "SamplerCustomAdvanced":
|
||||||
|
# Found a SamplerCustomAdvanced node
|
||||||
|
if node_id in metadata.get(SAMPLING, {}):
|
||||||
|
return node_id, metadata[SAMPLING][node_id]
|
||||||
|
|
||||||
|
# Next, check for KSamplerAdvanced with add_noise="enable"
|
||||||
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
|
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
|
||||||
parameters = sampler_info.get("parameters", {})
|
parameters = sampler_info.get("parameters", {})
|
||||||
add_noise = parameters.get("add_noise")
|
add_noise = parameters.get("add_noise")
|
||||||
@@ -22,7 +31,7 @@ class MetadataProcessor:
|
|||||||
primary_sampler_id = node_id
|
primary_sampler_id = node_id
|
||||||
break
|
break
|
||||||
|
|
||||||
# If no KSamplerAdvanced found, fall back to traditional KSampler with denoise=1
|
# If no specialized sampler found, fall back to traditional KSampler with denoise=1
|
||||||
if primary_sampler is None:
|
if primary_sampler is None:
|
||||||
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
|
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
|
||||||
parameters = sampler_info.get("parameters", {})
|
parameters = sampler_info.get("parameters", {})
|
||||||
@@ -152,22 +161,60 @@ class MetadataProcessor:
|
|||||||
|
|
||||||
# Trace connections from the primary sampler
|
# Trace connections from the primary sampler
|
||||||
if prompt and primary_sampler_id:
|
if prompt and primary_sampler_id:
|
||||||
# Trace positive prompt - look specifically for CLIPTextEncode
|
# Check if this is a SamplerCustomAdvanced node
|
||||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10)
|
is_custom_advanced = False
|
||||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
if prompt.original_prompt and primary_sampler_id in prompt.original_prompt:
|
||||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced"
|
||||||
|
|
||||||
# Find any FluxGuidance nodes in the positive conditioning path
|
if is_custom_advanced:
|
||||||
flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5)
|
# For SamplerCustomAdvanced, trace specific inputs
|
||||||
if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}):
|
|
||||||
flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {})
|
# 1. Trace sigmas input to find BasicScheduler
|
||||||
params["guidance"] = flux_params.get("guidance")
|
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", "BasicScheduler", max_depth=5)
|
||||||
|
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
|
||||||
|
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
|
||||||
|
params["steps"] = scheduler_params.get("steps")
|
||||||
|
params["scheduler"] = scheduler_params.get("scheduler")
|
||||||
|
|
||||||
|
# 2. Trace sampler input to find KSamplerSelect
|
||||||
|
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
|
||||||
|
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
|
||||||
|
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
|
||||||
|
params["sampler"] = sampler_params.get("sampler_name")
|
||||||
|
|
||||||
|
# 3. Trace guider input for FluxGuidance and CLIPTextEncode
|
||||||
|
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
|
||||||
|
if guider_node_id:
|
||||||
|
# Look for FluxGuidance along the guider path
|
||||||
|
flux_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", "FluxGuidance", max_depth=5)
|
||||||
|
if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}):
|
||||||
|
flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {})
|
||||||
|
params["guidance"] = flux_params.get("guidance")
|
||||||
|
|
||||||
|
# Find CLIPTextEncode for positive prompt (through conditioning)
|
||||||
|
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", "CLIPTextEncode", max_depth=10)
|
||||||
|
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||||
|
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||||
|
|
||||||
# Trace negative prompt - look specifically for CLIPTextEncode
|
else:
|
||||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10)
|
# Original tracing for standard samplers
|
||||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
# Trace positive prompt - look specifically for CLIPTextEncode
|
||||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10)
|
||||||
|
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||||
|
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||||
|
|
||||||
|
# Find any FluxGuidance nodes in the positive conditioning path
|
||||||
|
flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5)
|
||||||
|
if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}):
|
||||||
|
flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {})
|
||||||
|
params["guidance"] = flux_params.get("guidance")
|
||||||
|
|
||||||
|
# Trace negative prompt - look specifically for CLIPTextEncode
|
||||||
|
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10)
|
||||||
|
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||||
|
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||||
|
|
||||||
|
# Size extraction is same for all sampler types
|
||||||
# Check if the sampler itself has size information (from latent_image)
|
# Check if the sampler itself has size information (from latent_image)
|
||||||
if primary_sampler_id in metadata.get(SIZE, {}):
|
if primary_sampler_id in metadata.get(SIZE, {}):
|
||||||
width = metadata[SIZE][primary_sampler_id].get("width")
|
width = metadata[SIZE][primary_sampler_id].get("width")
|
||||||
|
|||||||
@@ -257,12 +257,85 @@ class VAEDecodeExtractor(NodeMetadataExtractor):
|
|||||||
if "first_decode" not in metadata[IMAGES]:
|
if "first_decode" not in metadata[IMAGES]:
|
||||||
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
|
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
|
||||||
|
|
||||||
|
class KSamplerSelectExtractor(NodeMetadataExtractor):
|
||||||
|
@staticmethod
|
||||||
|
def extract(node_id, inputs, outputs, metadata):
|
||||||
|
if not inputs or "sampler_name" not in inputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
sampling_params = {}
|
||||||
|
if "sampler_name" in inputs:
|
||||||
|
sampling_params["sampler_name"] = inputs["sampler_name"]
|
||||||
|
|
||||||
|
metadata[SAMPLING][node_id] = {
|
||||||
|
"parameters": sampling_params,
|
||||||
|
"node_id": node_id
|
||||||
|
}
|
||||||
|
|
||||||
|
class BasicSchedulerExtractor(NodeMetadataExtractor):
|
||||||
|
@staticmethod
|
||||||
|
def extract(node_id, inputs, outputs, metadata):
|
||||||
|
if not inputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
sampling_params = {}
|
||||||
|
for key in ["scheduler", "steps", "denoise"]:
|
||||||
|
if key in inputs:
|
||||||
|
sampling_params[key] = inputs[key]
|
||||||
|
|
||||||
|
metadata[SAMPLING][node_id] = {
|
||||||
|
"parameters": sampling_params,
|
||||||
|
"node_id": node_id
|
||||||
|
}
|
||||||
|
|
||||||
|
class SamplerCustomAdvancedExtractor(NodeMetadataExtractor):
|
||||||
|
@staticmethod
|
||||||
|
def extract(node_id, inputs, outputs, metadata):
|
||||||
|
if not inputs:
|
||||||
|
return
|
||||||
|
|
||||||
|
sampling_params = {}
|
||||||
|
|
||||||
|
# Handle noise.seed as seed
|
||||||
|
if "noise" in inputs and inputs["noise"] is not None and hasattr(inputs["noise"], "seed"):
|
||||||
|
noise = inputs["noise"]
|
||||||
|
sampling_params["seed"] = noise.seed
|
||||||
|
|
||||||
|
metadata[SAMPLING][node_id] = {
|
||||||
|
"parameters": sampling_params,
|
||||||
|
"node_id": node_id
|
||||||
|
}
|
||||||
|
|
||||||
|
# Extract latent image dimensions if available
|
||||||
|
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||||
|
latent = inputs["latent_image"]
|
||||||
|
if isinstance(latent, dict) and "samples" in latent:
|
||||||
|
# Extract dimensions from latent tensor
|
||||||
|
samples = latent["samples"]
|
||||||
|
if hasattr(samples, "shape") and len(samples.shape) >= 3:
|
||||||
|
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
|
||||||
|
# Multiply by 8 to get actual pixel dimensions
|
||||||
|
height = int(samples.shape[2] * 8)
|
||||||
|
width = int(samples.shape[3] * 8)
|
||||||
|
|
||||||
|
if SIZE not in metadata:
|
||||||
|
metadata[SIZE] = {}
|
||||||
|
|
||||||
|
metadata[SIZE][node_id] = {
|
||||||
|
"width": width,
|
||||||
|
"height": height,
|
||||||
|
"node_id": node_id
|
||||||
|
}
|
||||||
|
|
||||||
# Registry of node-specific extractors
|
# Registry of node-specific extractors
|
||||||
NODE_EXTRACTORS = {
|
NODE_EXTRACTORS = {
|
||||||
# Sampling
|
# Sampling
|
||||||
"KSampler": SamplerExtractor,
|
"KSampler": SamplerExtractor,
|
||||||
"KSamplerAdvanced": KSamplerAdvancedExtractor, # Add KSamplerAdvanced
|
"KSamplerAdvanced": KSamplerAdvancedExtractor,
|
||||||
"SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced
|
"SamplerCustomAdvanced": SamplerCustomAdvancedExtractor, # Updated to use dedicated extractor
|
||||||
|
# Sampling Selectors
|
||||||
|
"KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect
|
||||||
|
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
|
||||||
# Loaders
|
# Loaders
|
||||||
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
|
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
|
||||||
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||||
|
|||||||
Reference in New Issue
Block a user