import json
import os

from .constants import IMAGES
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IS_SAMPLER

# Check if running in standalone mode
# NOTE(review): the second clause is true whenever HF_HUB_DISABLE_TELEMETRY is
# unset (its default "0" compares equal to "0"), so standalone_mode is True
# unless that variable is set to something other than "0" - confirm intent.
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"


class MetadataProcessor:
    """Process and format collected metadata"""

    @staticmethod
    def find_primary_sampler(metadata, downstream_id=None):
        """
        Find the primary KSampler node that executed before the given downstream node

        Parameters:
        - metadata: The workflow metadata
        - downstream_id: Optional ID of a downstream node to help identify the specific primary sampler

        Returns:
        - Tuple (sampler_node_id, sampler_info); both are None when no sampler is found
        """
        sampling = metadata.get(SAMPLING, {})

        # Default the downstream node to the first decode node, when one was recorded
        if downstream_id is None:
            if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
                downstream_id = metadata[IMAGES]["first_decode"]["node_id"]

        # If we have a downstream_id and execution_order, use it to narrow down potential samplers
        if downstream_id and "execution_order" in metadata:
            execution_order = metadata["execution_order"]

            if downstream_id in execution_order:
                downstream_index = execution_order.index(downstream_id)
                executed_before = execution_order[:downstream_index]

                # Only nodes flagged with IS_SAMPLER count as true sampler nodes
                candidate_samplers = {
                    node_id: sampling[node_id]
                    for node_id in executed_before
                    if node_id in sampling and sampling[node_id].get(IS_SAMPLER, False)
                }

                if candidate_samplers:
                    # Collect every candidate that satisfies at least one "primary" criterion
                    potential_samplers = set()

                    # Criterion 1: SamplerCustomAdvanced nodes
                    prompt = metadata.get("current_prompt")
                    if prompt and prompt.original_prompt:
                        for node_id in candidate_samplers:
                            node_info = prompt.original_prompt.get(node_id, {})
                            if node_info.get("class_type") == "SamplerCustomAdvanced":
                                potential_samplers.add(node_id)

                    # Criterion 2: samplers with add_noise="enable" (KSamplerAdvanced)
                    # Criterion 3: the single sampler with the highest denoise value
                    max_denoise = -1
                    high_denoise_id = None
                    for node_id, sampler_info in candidate_samplers.items():
                        parameters = sampler_info.get("parameters", {})
                        if parameters.get("add_noise") == "enable":
                            potential_samplers.add(node_id)
                        denoise = parameters.get("denoise")
                        if denoise is not None and denoise > max_denoise:
                            max_denoise = denoise
                            high_denoise_id = node_id
                    if high_denoise_id:
                        potential_samplers.add(high_denoise_id)

                    # Prefer the most recent potential primary sampler (closest to downstream node)
                    for node_id in reversed(executed_before):
                        if node_id in potential_samplers:
                            return node_id, candidate_samplers[node_id]

                    # If no potential sampler found from our criteria, return the most recent sampler
                    for node_id in reversed(executed_before):
                        if node_id in candidate_samplers:
                            return node_id, candidate_samplers[node_id]

        # If no downstream_id provided or no suitable sampler found, fall back to original logic

        # First, check for SamplerCustomAdvanced
        prompt = metadata.get("current_prompt")
        if prompt and prompt.original_prompt:
            for node_id, node_info in prompt.original_prompt.items():
                if node_info.get("class_type") == "SamplerCustomAdvanced":
                    # Check if the node is in SAMPLING and has IS_SAMPLER flag
                    if node_id in sampling and sampling[node_id].get(IS_SAMPLER, False):
                        return node_id, sampling[node_id]

        # Next, check for KSamplerAdvanced with add_noise="enable" using IS_SAMPLER flag
        for node_id, sampler_info in sampling.items():
            if not sampler_info.get(IS_SAMPLER, False):
                continue
            if sampler_info.get("parameters", {}).get("add_noise") == "enable":
                return node_id, sampler_info

        # If no specialized sampler found, find the sampler with highest denoise value
        primary_sampler = None
        primary_sampler_id = None
        max_denoise = -1
        for node_id, sampler_info in sampling.items():
            if not sampler_info.get(IS_SAMPLER, False):
                continue
            denoise = sampler_info.get("parameters", {}).get("denoise")
            if denoise is not None and denoise > max_denoise:
                max_denoise = denoise
                primary_sampler = sampler_info
                primary_sampler_id = node_id

        return primary_sampler_id, primary_sampler

    @staticmethod
    def trace_node_input(prompt, node_id, input_name, target_class=None, max_depth=10):
        """
        Trace an input connection from a node to find the source node

        Parameters:
        - prompt: The prompt object containing node connections
        - node_id: ID of the starting node
        - input_name: Name of the input to trace
        - target_class: Optional class name to search for (e.g., "CLIPTextEncode")
        - max_depth: Maximum depth to follow the node chain to prevent infinite loops

        Returns:
        - node_id of the found node, or None if not found
        """
        if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt:
            return None

        nodes = prompt.original_prompt
        current_node_id = node_id
        current_input = input_name

        # If we're just tracing to origin (no target_class), keep track of the last valid node
        last_valid_node = None

        for _ in range(max_depth):
            node_inputs = nodes[current_node_id].get("inputs", {})

            if current_input not in node_inputs:
                # We've reached a node without the specified input - this is our origin node
                # if we're not looking for a specific target_class
                return current_node_id if not target_class else None

            input_value = node_inputs[current_input]

            # Input connections are formatted as [node_id, output_index]
            if not (isinstance(input_value, list) and len(input_value) >= 2):
                # We've reached a node with no further connections
                return last_valid_node if not target_class else None

            found_node_id = input_value[0]  # Connected node_id

            # A link may point at a node that is absent from the prompt; treat such a
            # node as having no class and no inputs instead of raising KeyError.
            found_node = nodes.get(found_node_id, {})

            # If we're looking for a specific node class
            if target_class and found_node.get("class_type") == target_class:
                return found_node_id

            # If we're not looking for a specific class, update the last valid node
            if not target_class:
                last_valid_node = found_node_id

            # Continue tracing through intermediate nodes
            current_node_id = found_node_id

            # For most conditioning nodes, the input we want to follow is named "conditioning"
            if "conditioning" in found_node.get("inputs", {}):
                current_input = "conditioning"
            else:
                # If there's no "conditioning" input, return the current node
                # if we're not looking for a specific target_class
                return found_node_id if not target_class else None

        # If we've reached max depth without finding target_class
        return last_valid_node if not target_class else None

    @staticmethod
    def trace_model_path(metadata, prompt, start_node_id):
        """
        Trace the model connection path upstream to find the checkpoint

        Parameters:
        - metadata: The workflow metadata
        - prompt: The prompt object containing node connections
        - start_node_id: ID of the node to start tracing from

        Returns:
        - node_id of the first node on the path registered in metadata[MODELS] with
          type "checkpoint", or None if the path dead-ends or exceeds 50 hops
        """
        if not prompt or not prompt.original_prompt:
            return None

        current_node_id = start_node_id
        max_depth = 50

        for depth in range(max_depth):
            # Check if current node is a registered checkpoint in our metadata
            # This handles cached nodes correctly because metadata contains info for all nodes in the graph
            if current_node_id in metadata.get(MODELS, {}):
                if metadata[MODELS][current_node_id].get("type") == "checkpoint":
                    return current_node_id

            if current_node_id not in prompt.original_prompt:
                return None

            node = prompt.original_prompt[current_node_id]
            inputs = node.get("inputs", {})
            class_type = node.get("class_type", "")

            # Determine which input to follow next
            next_input_name = "model"

            # Special handling for initial node
            if depth == 0 and class_type == "SamplerCustomAdvanced":
                next_input_name = "guider"

            # If the specific input doesn't exist, try generic 'model'
            if next_input_name not in inputs:
                if "model" in inputs:
                    next_input_name = "model"
                elif "basic_pipe" in inputs:
                    # Handle pipe nodes like FromBasicPipe by following the pipeline
                    next_input_name = "basic_pipe"
                else:
                    # Dead end - no model input to follow
                    return None

            # Get connected node
            input_val = inputs[next_input_name]
            if isinstance(input_val, list) and len(input_val) > 0:
                current_node_id = input_val[0]
            else:
                return None

        return None

    @staticmethod
    def find_primary_checkpoint(metadata, downstream_id=None, primary_sampler_id=None):
        """
        Find the primary checkpoint model in the workflow

        Parameters:
        - metadata: The workflow metadata
        - downstream_id: Optional ID of a downstream node to help identify the specific primary sampler
        - primary_sampler_id: Optional ID of the primary sampler if already known

        Returns:
        - The "name" of the checkpoint entry, or None if no checkpoint is registered
        """
        if not metadata.get(MODELS):
            return None

        # Method 1: Topology-based tracing (More accurate for complex workflows)
        # First, find the primary sampler if not provided
        if not primary_sampler_id:
            primary_sampler_id, _ = MetadataProcessor.find_primary_sampler(metadata, downstream_id)

        if primary_sampler_id:
            prompt = metadata.get("current_prompt")
            if prompt:
                # Trace back from the sampler to find the checkpoint
                checkpoint_id = MetadataProcessor.trace_model_path(metadata, prompt, primary_sampler_id)
                if checkpoint_id and checkpoint_id in metadata.get(MODELS, {}):
                    return metadata[MODELS][checkpoint_id].get("name")

        # Method 2: Fallback to the first available checkpoint (Original behavior)
        # In most simple workflows, there's only one checkpoint, so we can just take the first one
        for model_info in metadata.get(MODELS, {}).values():
            if model_info.get("type") == "checkpoint":
                return model_info.get("name")

        return None

    @staticmethod
    def match_conditioning_to_prompts(metadata, sampler_id):
        """
        Match conditioning objects from a sampler to prompts in metadata

        Matching is by object identity: a conditioning object stored for the sampler is
        compared against the conditioning objects stored for every entry in metadata[PROMPTS].

        Parameters:
        - metadata: The workflow metadata
        - sampler_id: ID of the sampler node to match

        Returns:
        - Dictionary with 'prompt' and 'negative_prompt' if found
        """
        result = {
            "prompt": "",
            "negative_prompt": ""
        }

        prompts = metadata.get(PROMPTS, {})

        # Check if we have stored conditioning objects for this sampler
        if sampler_id not in prompts:
            return result
        sampler_entry = prompts[sampler_id]
        if "pos_conditioning" not in sampler_entry and "neg_conditioning" not in sampler_entry:
            return result

        pos_conditioning = sampler_entry.get("pos_conditioning")
        neg_conditioning = sampler_entry.get("neg_conditioning")

        # Helper function to recursively find prompt text for a conditioning object
        def find_prompt_text_for_conditioning(conditioning_obj, is_positive=True, _seen=frozenset()):
            if conditioning_obj is None:
                return ""

            # Remember every object already on the lookup chain so that a node whose
            # "original" conditioning leads back to an object we are already resolving
            # cannot cause unbounded recursion.
            seen = _seen | {id(conditioning_obj)}

            if is_positive:
                encoded_key, text_key, orig_key = "positive_encoded", "positive_text", "orig_pos_cond"
            else:
                encoded_key, text_key, orig_key = "negative_encoded", "negative_text", "orig_neg_cond"

            # Try to match conditioning objects with those stored by extractors
            for prompt_data in prompts.values():
                # For nodes with single conditioning output
                if "conditioning" in prompt_data and prompt_data["conditioning"] is conditioning_obj:
                    return prompt_data.get("text", "")

                # For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
                if encoded_key in prompt_data and prompt_data[encoded_key] is conditioning_obj:
                    if text_key in prompt_data:
                        return prompt_data[text_key]
                    orig_conditioning = prompt_data.get(orig_key, None)
                    if orig_conditioning is not None and id(orig_conditioning) not in seen:
                        # Recursively find the prompt text for the original conditioning
                        return find_prompt_text_for_conditioning(orig_conditioning, is_positive, seen)

            return ""

        # Find prompt texts using the helper function
        result["prompt"] = find_prompt_text_for_conditioning(pos_conditioning, is_positive=True)
        result["negative_prompt"] = find_prompt_text_for_conditioning(neg_conditioning, is_positive=False)

        return result

    @staticmethod
    def extract_generation_params(metadata, id=None):
        """
        Extract generation parameters from metadata using node relationships

        Parameters:
        - metadata: The workflow metadata
        - id: Optional ID of a downstream node to help identify the specific primary sampler

        Returns:
        - Dictionary of generation parameters (prompt, negative_prompt, seed, steps, cfg_scale,
          sampler, scheduler, checkpoint, loras, size, clip_skip, plus "guidance" when any
          sampling node recorded one)
        """
        params = {
            "prompt": "",
            "negative_prompt": "",
            "seed": None,
            "steps": None,
            "cfg_scale": None,
            # "guidance": None,  # Add guidance parameter
            "sampler": None,
            "scheduler": None,
            "checkpoint": None,
            "loras": "",
            "size": None,
            "clip_skip": None
        }

        # Get the prompt object for node relationship tracing
        prompt = metadata.get("current_prompt")

        # Find the primary KSampler node
        primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id)

        # Pass primary_sampler_id to avoid redundant calculation
        checkpoint = MetadataProcessor.find_primary_checkpoint(metadata, id, primary_sampler_id)
        if checkpoint:
            params["checkpoint"] = checkpoint

        # Check if guidance parameter exists in any sampling node
        for sampler_info in metadata.get(SAMPLING, {}).values():
            parameters = sampler_info.get("parameters", {})
            if "guidance" in parameters and parameters["guidance"] is not None:
                params["guidance"] = parameters["guidance"]
                break

        if primary_sampler:
            # Extract sampling parameters
            sampling_params = primary_sampler.get("parameters", {})

            # Handle both seed and noise_seed
            seed = sampling_params.get("seed")
            if seed is None:
                seed = sampling_params.get("noise_seed")
            params["seed"] = seed
            params["steps"] = sampling_params.get("steps")
            params["cfg_scale"] = sampling_params.get("cfg")
            params["sampler"] = sampling_params.get("sampler_name")
            params["scheduler"] = sampling_params.get("scheduler")

            if prompt and primary_sampler_id:
                # Check if this is a SamplerCustomAdvanced node
                is_custom_advanced = False
                if prompt.original_prompt and primary_sampler_id in prompt.original_prompt:
                    is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced"

                if is_custom_advanced:
                    # For SamplerCustomAdvanced, use the new handler method
                    MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
                else:
                    # For standard samplers, match conditioning objects to prompts
                    prompt_results = MetadataProcessor.match_conditioning_to_prompts(metadata, primary_sampler_id)
                    params["prompt"] = prompt_results["prompt"]
                    params["negative_prompt"] = prompt_results["negative_prompt"]

                    # If prompts were still not found, fall back to tracing connections
                    if not params["prompt"]:
                        # Trace the positive input back to its origin node
                        positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", max_depth=10)
                        if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
                            params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
                        else:
                            # If that origin has no recorded prompt, try to find CLIPTextEncodeFlux
                            positive_flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncodeFlux", max_depth=10)
                            if positive_flux_node_id and positive_flux_node_id in metadata.get(PROMPTS, {}):
                                params["prompt"] = metadata[PROMPTS][positive_flux_node_id].get("text", "")

                        # Trace the negative input back to its origin node
                        negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", max_depth=10)
                        if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
                            params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")

                    # For SamplerCustom, handle any additional parameters
                    MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)

                # Size extraction is same for all sampler types
                # Check if the sampler itself has size information (from latent_image)
                if primary_sampler_id in metadata.get(SIZE, {}):
                    width = metadata[SIZE][primary_sampler_id].get("width")
                    height = metadata[SIZE][primary_sampler_id].get("height")
                    if width and height:
                        params["size"] = f"{width}x{height}"

        # Extract LoRAs using the standardized format
        lora_parts = []
        for lora_info in metadata.get(LORAS, {}).values():
            # Access the lora_list from the standardized format
            for lora in lora_info.get("lora_list", []):
                name = lora.get("name", "unknown")
                strength = lora.get("strength", 1.0)
                # Emit one <lora:name:strength> tag per LoRA
                lora_parts.append(f"<lora:{name}:{strength}>")

        params["loras"] = " ".join(lora_parts)

        # Set default clip_skip value
        params["clip_skip"] = "1"  # Common default

        return params

    @staticmethod
    def to_dict(metadata, id=None):
        """
        Convert extracted metadata to the ComfyUI output.json format

        Parameters:
        - metadata: The workflow metadata
        - id: Optional ID of a downstream node to help identify the specific primary sampler

        Returns:
        - Dictionary of parameters with every non-None value converted to a string;
          an empty dictionary when standalone_mode is set
        """
        if standalone_mode:
            # Return empty dictionary in standalone mode
            return {}

        params = MetadataProcessor.extract_generation_params(metadata, id)

        # Convert all values to strings to match output.json format
        return {
            key: (str(value) if value is not None else value)
            for key, value in params.items()
        }

    @staticmethod
    def to_json(metadata, id=None):
        """Convert metadata to JSON string"""
        params = MetadataProcessor.to_dict(metadata, id)
        return json.dumps(params, indent=4)

    @staticmethod
    def handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params):
        """
        Handle parameter extraction for SamplerCustomAdvanced nodes

        Updates params in place; nothing is returned. Inputs that the sampler node does
        not have ("sigmas", "sampler", "guider") are skipped, so calling this for a
        sampler without those inputs leaves params untouched.

        Parameters:
        - metadata: The workflow metadata
        - prompt: The prompt object containing node connections
        - primary_sampler_id: ID of the SamplerCustomAdvanced node
        - params: Parameters dictionary to update
        """
        if not prompt.original_prompt or primary_sampler_id not in prompt.original_prompt:
            return

        sampler_inputs = prompt.original_prompt[primary_sampler_id].get("inputs", {})
        sampling = metadata.get(SAMPLING, {})
        prompts = metadata.get(PROMPTS, {})

        # 1. Trace sigmas input to find BasicScheduler (only if sigmas input exists)
        if "sigmas" in sampler_inputs:
            scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", None, max_depth=5)
            if scheduler_node_id and scheduler_node_id in sampling:
                scheduler_params = sampling[scheduler_node_id].get("parameters", {})
                params["steps"] = scheduler_params.get("steps")
                params["scheduler"] = scheduler_params.get("scheduler")

        # 2. Trace sampler input to find KSamplerSelect (only if sampler input exists)
        if "sampler" in sampler_inputs:
            sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
            if sampler_node_id and sampler_node_id in sampling:
                sampler_params = sampling[sampler_node_id].get("parameters", {})
                params["sampler"] = sampler_params.get("sampler_name")

        # 3. Trace guider input for CFGGuider and CLIPTextEncode
        if "guider" in sampler_inputs:
            guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
            if guider_node_id and guider_node_id in prompt.original_prompt:
                # Check if the guider node is a CFGGuider
                if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider":
                    # Extract cfg value from the CFGGuider
                    if guider_node_id in sampling:
                        cfg_params = sampling[guider_node_id].get("parameters", {})
                        params["cfg_scale"] = cfg_params.get("cfg")

                    # Find CLIPTextEncode for positive prompt
                    positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10)
                    if positive_node_id and positive_node_id in prompts:
                        params["prompt"] = prompts[positive_node_id].get("text", "")

                    # Find CLIPTextEncode for negative prompt
                    negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10)
                    if negative_node_id and negative_node_id in prompts:
                        params["negative_prompt"] = prompts[negative_node_id].get("text", "")
                else:
                    # Other guiders: follow the single "conditioning" input to its origin
                    positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
                    if positive_node_id and positive_node_id in prompts:
                        params["prompt"] = prompts[positive_node_id].get("text", "")