diff --git a/py/metadata_collector/metadata_registry.py b/py/metadata_collector/metadata_registry.py index 9b75351f..434f1eb1 100644 --- a/py/metadata_collector/metadata_registry.py +++ b/py/metadata_collector/metadata_registry.py @@ -1,4 +1,5 @@ import time +from nodes import NODE_CLASS_MAPPINGS from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor from .constants import METADATA_CATEGORIES @@ -78,11 +79,18 @@ class MetadataRegistry: if node_id in executed_nodes: continue - class_type = node_data.get("class_type") - if not class_type: + # Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS) + prompt_class_type = node_data.get("class_type") + if not prompt_class_type: continue - # Create cache key + # Convert to actual class name (which is what we use in our cache) + class_type = prompt_class_type + if prompt_class_type in NODE_CLASS_MAPPINGS: + class_obj = NODE_CLASS_MAPPINGS[prompt_class_type] + class_type = class_obj.__name__ + + # Create cache key using the actual class name cache_key = f"{node_id}:{class_type}" # Check if this node type is relevant for metadata collection diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 0d599c93..bdda89ad 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -184,6 +184,20 @@ class FluxGuidanceExtractor(NodeMetadataExtractor): metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id} metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value + +class UNETLoaderExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "unet_name" not in inputs: + return + + model_name = inputs.get("unet_name") + if model_name: + metadata[MODELS][node_id] = { + "name": model_name, + "type": "checkpoint", + "node_id": node_id + } # Registry of node-specific extractors NODE_EXTRACTORS = { @@ -194,7 +208,7 @@ NODE_EXTRACTORS = { "EmptyLatentImage": ImageSizeExtractor, "LoraManagerLoader": LoraLoaderManagerExtractor, "SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced - "UNETLoader": CheckpointLoaderExtractor, # Add UNETLoader + "UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor "FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance # Add other nodes as needed }