mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
feat: Add conditioning matching to prompts and update metadata handling in node extractors. See #235
This commit is contained in:
@@ -209,6 +209,44 @@ class MetadataProcessor:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def match_conditioning_to_prompts(metadata, sampler_id):
|
||||||
|
"""
|
||||||
|
Match conditioning objects from a sampler to prompts in metadata
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- metadata: The workflow metadata
|
||||||
|
- sampler_id: ID of the sampler node to match
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- Dictionary with 'prompt' and 'negative_prompt' if found
|
||||||
|
"""
|
||||||
|
result = {
|
||||||
|
"prompt": "",
|
||||||
|
"negative_prompt": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
# Check if we have stored conditioning objects for this sampler
|
||||||
|
if sampler_id in metadata.get(PROMPTS, {}) and (
|
||||||
|
"pos_conditioning" in metadata[PROMPTS][sampler_id] or
|
||||||
|
"neg_conditioning" in metadata[PROMPTS][sampler_id]):
|
||||||
|
|
||||||
|
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
|
||||||
|
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
|
||||||
|
|
||||||
|
# Try to match conditioning objects with those stored by CLIPTextEncodeExtractor
|
||||||
|
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
|
||||||
|
if "conditioning" not in prompt_data:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if pos_conditioning is not None and id(prompt_data["conditioning"]) == id(pos_conditioning):
|
||||||
|
result["prompt"] = prompt_data.get("text", "")
|
||||||
|
|
||||||
|
if neg_conditioning is not None and id(prompt_data["conditioning"]) == id(neg_conditioning):
|
||||||
|
result["negative_prompt"] = prompt_data.get("text", "")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract_generation_params(metadata, id=None):
|
def extract_generation_params(metadata, id=None):
|
||||||
"""
|
"""
|
||||||
@@ -261,8 +299,14 @@ class MetadataProcessor:
|
|||||||
params["sampler"] = sampling_params.get("sampler_name")
|
params["sampler"] = sampling_params.get("sampler_name")
|
||||||
params["scheduler"] = sampling_params.get("scheduler")
|
params["scheduler"] = sampling_params.get("scheduler")
|
||||||
|
|
||||||
# Trace connections from the primary sampler
|
# First try to match conditioning objects to prompts (new method)
|
||||||
if prompt and primary_sampler_id:
|
if primary_sampler_id:
|
||||||
|
prompt_results = MetadataProcessor.match_conditioning_to_prompts(metadata, primary_sampler_id)
|
||||||
|
params["prompt"] = prompt_results["prompt"]
|
||||||
|
params["negative_prompt"] = prompt_results["negative_prompt"]
|
||||||
|
|
||||||
|
# If prompts were not found by object matching, fall back to tracing connections
|
||||||
|
if not params["prompt"] and prompt and primary_sampler_id:
|
||||||
# Check if this is a SamplerCustomAdvanced node
|
# Check if this is a SamplerCustomAdvanced node
|
||||||
is_custom_advanced = False
|
is_custom_advanced = False
|
||||||
if prompt.original_prompt and primary_sampler_id in prompt.original_prompt:
|
if prompt.original_prompt and primary_sampler_id in prompt.original_prompt:
|
||||||
|
|||||||
@@ -47,6 +47,13 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor):
|
|||||||
"text": text,
|
"text": text,
|
||||||
"node_id": node_id
|
"node_id": node_id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update(node_id, outputs, metadata):
|
||||||
|
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||||
|
if isinstance(outputs[0], tuple) and len(outputs[0]) > 0:
|
||||||
|
conditioning = outputs[0][0]
|
||||||
|
metadata[PROMPTS][node_id]["conditioning"] = conditioning
|
||||||
|
|
||||||
class SamplerExtractor(NodeMetadataExtractor):
|
class SamplerExtractor(NodeMetadataExtractor):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -64,6 +71,18 @@ class SamplerExtractor(NodeMetadataExtractor):
|
|||||||
"node_id": node_id,
|
"node_id": node_id,
|
||||||
IS_SAMPLER: True # Add sampler flag
|
IS_SAMPLER: True # Add sampler flag
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Store the conditioning objects directly in metadata for later matching
|
||||||
|
pos_conditioning = inputs.get("positive", None)
|
||||||
|
neg_conditioning = inputs.get("negative", None)
|
||||||
|
|
||||||
|
# Save conditioning objects in metadata for later matching
|
||||||
|
if pos_conditioning is not None or neg_conditioning is not None:
|
||||||
|
if node_id not in metadata[PROMPTS]:
|
||||||
|
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||||
|
|
||||||
|
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||||
|
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||||
|
|
||||||
# Extract latent image dimensions if available
|
# Extract latent image dimensions if available
|
||||||
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||||
@@ -103,6 +122,18 @@ class KSamplerAdvancedExtractor(NodeMetadataExtractor):
|
|||||||
IS_SAMPLER: True # Add sampler flag
|
IS_SAMPLER: True # Add sampler flag
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Store the conditioning objects directly in metadata for later matching
|
||||||
|
pos_conditioning = inputs.get("positive", None)
|
||||||
|
neg_conditioning = inputs.get("negative", None)
|
||||||
|
|
||||||
|
# Save conditioning objects in metadata for later matching
|
||||||
|
if pos_conditioning is not None or neg_conditioning is not None:
|
||||||
|
if node_id not in metadata[PROMPTS]:
|
||||||
|
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||||
|
|
||||||
|
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||||
|
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||||
|
|
||||||
# Extract latent image dimensions if available
|
# Extract latent image dimensions if available
|
||||||
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||||
latent = inputs["latent_image"]
|
latent = inputs["latent_image"]
|
||||||
@@ -376,6 +407,13 @@ class CLIPTextEncodeFluxExtractor(NodeMetadataExtractor):
|
|||||||
|
|
||||||
metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value
|
metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update(node_id, outputs, metadata):
|
||||||
|
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||||
|
if isinstance(outputs[0], tuple) and len(outputs[0]) > 0:
|
||||||
|
conditioning = outputs[0][0]
|
||||||
|
metadata[PROMPTS][node_id]["conditioning"] = conditioning
|
||||||
|
|
||||||
class CFGGuiderExtractor(NodeMetadataExtractor):
|
class CFGGuiderExtractor(NodeMetadataExtractor):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract(node_id, inputs, outputs, metadata):
|
def extract(node_id, inputs, outputs, metadata):
|
||||||
@@ -404,6 +442,7 @@ NODE_EXTRACTORS = {
|
|||||||
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
|
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
|
||||||
# Loaders
|
# Loaders
|
||||||
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
|
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
|
||||||
|
"comfyLoader": CheckpointLoaderExtractor, # eeasy comfyLoader
|
||||||
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||||
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||||
"LoraLoader": LoraLoaderExtractor,
|
"LoraLoader": LoraLoaderExtractor,
|
||||||
|
|||||||
Reference in New Issue
Block a user