diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 289e2160..41c27620 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -3,6 +3,18 @@ import os from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES, IS_SAMPLER +def _store_checkpoint_metadata(metadata, node_id, model_name): + """Store checkpoint model information when available.""" + if not model_name: + return + metadata.setdefault(MODELS, {}) + metadata[MODELS][node_id] = { + "name": model_name, + "type": "checkpoint", + "node_id": node_id + } + + class NodeMetadataExtractor: """Base class for node-specific metadata extraction""" @@ -29,12 +41,27 @@ class CheckpointLoaderExtractor(NodeMetadataExtractor): return model_name = inputs.get("ckpt_name") - if model_name: - metadata[MODELS][node_id] = { - "name": model_name, - "type": "checkpoint", - "node_id": node_id - } + _store_checkpoint_metadata(metadata, node_id, model_name) + + +class NunchakuFluxDiTLoaderExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "model_path" not in inputs: + return + + model_name = inputs.get("model_path") + _store_checkpoint_metadata(metadata, node_id, model_name) + + +class NunchakuQwenImageDiTLoaderExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + if not inputs or "model_name" not in inputs: + return + + model_name = inputs.get("model_name") + _store_checkpoint_metadata(metadata, node_id, model_name) class TSCCheckpointLoaderExtractor(NodeMetadataExtractor): @staticmethod @@ -43,12 +70,7 @@ class TSCCheckpointLoaderExtractor(NodeMetadataExtractor): return model_name = inputs.get("ckpt_name") - if model_name: - metadata[MODELS][node_id] = { - "name": model_name, - "type": "checkpoint", - "node_id": node_id - } + _store_checkpoint_metadata(metadata, node_id, model_name) # For loader node has lora_stack input, like Efficient Loader from Efficient Nodes active_loras = [] @@ -660,6 +682,8 @@ NODE_EXTRACTORS = { "comfyLoader": CheckpointLoaderExtractor, # easy comfyLoader "CheckpointLoaderSimpleWithImages": CheckpointLoaderExtractor, # CheckpointLoader|pysssss "TSC_EfficientLoader": TSCCheckpointLoaderExtractor, # Efficient Nodes + "NunchakuFluxDiTLoader": NunchakuFluxDiTLoaderExtractor, # ComfyUI-Nunchaku + "NunchakuQwenImageDiTLoader": NunchakuQwenImageDiTLoaderExtractor, # ComfyUI-Nunchaku "UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor "UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor "LoraLoader": LoraLoaderExtractor,