From c2f599b4ff2aa634f628ed9e7d53e2cfbc4fa1d9 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 17 Apr 2025 22:05:40 +0800 Subject: [PATCH] feat: Update node extractors to include UNETLoaderExtractor and enhance metadata handling for guidance parameters --- py/metadata_collector/metadata_registry.py | 14 +++++++++++--- py/metadata_collector/node_extractors.py | 16 +++++++++++++++- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/py/metadata_collector/metadata_registry.py b/py/metadata_collector/metadata_registry.py index 9b75351f..434f1eb1 100644 --- a/py/metadata_collector/metadata_registry.py +++ b/py/metadata_collector/metadata_registry.py @@ -1,4 +1,5 @@ import time +from nodes import NODE_CLASS_MAPPINGS from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor from .constants import METADATA_CATEGORIES @@ -78,11 +79,18 @@ class MetadataRegistry: if node_id in executed_nodes: continue - class_type = node_data.get("class_type") - if not class_type: + # Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS) + prompt_class_type = node_data.get("class_type") + if not prompt_class_type: continue - # Create cache key + # Convert to actual class name (which is what we use in our cache) + class_type = prompt_class_type + if prompt_class_type in NODE_CLASS_MAPPINGS: + class_obj = NODE_CLASS_MAPPINGS[prompt_class_type] + class_type = class_obj.__name__ + + # Create cache key using the actual class name cache_key = f"{node_id}:{class_type}" # Check if this node type is relevant for metadata collection diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 0d599c93..bdda89ad 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -184,6 +184,20 @@ class FluxGuidanceExtractor(NodeMetadataExtractor): metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id} metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value + +class UNETLoaderExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "unet_name" not in inputs: + return + + model_name = inputs.get("unet_name") + if model_name: + metadata[MODELS][node_id] = { + "name": model_name, + "type": "checkpoint", + "node_id": node_id + } # Registry of node-specific extractors NODE_EXTRACTORS = { @@ -194,7 +208,7 @@ NODE_EXTRACTORS = { "EmptyLatentImage": ImageSizeExtractor, "LoraManagerLoader": LoraLoaderManagerExtractor, "SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced - "UNETLoader": CheckpointLoaderExtractor, # Add UNETLoader + "UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor "FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance # Add other nodes as needed }