From 18eb605605b2c373f23764df7549aa54334f5a2e Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Thu, 17 Apr 2025 06:23:31 +0800 Subject: [PATCH] feat: Refactor metadata processing to use constants for category keys and improve structure --- py/metadata_collector/constants.py | 11 ++++++++ py/metadata_collector/metadata_processor.py | 30 +++++++++++---------- py/metadata_collector/metadata_registry.py | 14 +++++----- py/metadata_collector/node_extractors.py | 24 +++++++++-------- 4 files changed, 47 insertions(+), 32 deletions(-) create mode 100644 py/metadata_collector/constants.py diff --git a/py/metadata_collector/constants.py b/py/metadata_collector/constants.py new file mode 100644 index 00000000..0ad55b47 --- /dev/null +++ b/py/metadata_collector/constants.py @@ -0,0 +1,11 @@ +"""Constants used by the metadata collector""" + +# Individual category constants +MODELS = "models" +PROMPTS = "prompts" +SAMPLING = "sampling" +LORAS = "loras" +SIZE = "size" + +# Collection of categories for iteration +METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE] diff --git a/py/metadata_collector/metadata_processor.py b/py/metadata_collector/metadata_processor.py index da55c0c7..9c80e68c 100644 --- a/py/metadata_collector/metadata_processor.py +++ b/py/metadata_collector/metadata_processor.py @@ -1,5 +1,7 @@ import json +from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE + class MetadataProcessor: """Process and format collected metadata""" @@ -9,7 +11,7 @@ class MetadataProcessor: primary_sampler = None primary_sampler_id = None - for node_id, sampler_info in metadata.get("sampling", {}).items(): + for node_id, sampler_info in metadata.get(SAMPLING, {}).items(): parameters = sampler_info.get("parameters", {}) denoise = parameters.get("denoise") @@ -41,11 +43,11 @@ class MetadataProcessor: @staticmethod def find_primary_checkpoint(metadata): """Find the primary checkpoint model in the workflow""" - if not metadata.get("models"): + if not metadata.get(MODELS): return None # In most workflows, there's only one checkpoint, so we can just take the first one - for node_id, model_info in metadata.get("models", {}).items(): + for node_id, model_info in metadata.get(MODELS, {}).items(): if model_info.get("type") == "checkpoint": return model_info.get("name") @@ -90,18 +92,18 @@ class MetadataProcessor: if prompt and primary_sampler_id: # Trace positive prompt positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive") - if positive_node_id and positive_node_id in metadata.get("prompts", {}): - params["prompt"] = metadata["prompts"][positive_node_id].get("text", "") + if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}): + params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "") # Trace negative prompt negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative") - if negative_node_id and negative_node_id in metadata.get("prompts", {}): - params["negative_prompt"] = metadata["prompts"][negative_node_id].get("text", "") + if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}): + params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "") # Check if the sampler itself has size information (from latent_image) - if primary_sampler_id in metadata.get("size", {}): - width = metadata["size"][primary_sampler_id].get("width") - height = metadata["size"][primary_sampler_id].get("height") + if primary_sampler_id in metadata.get(SIZE, {}): + width = metadata[SIZE][primary_sampler_id].get("width") + height = metadata[SIZE][primary_sampler_id].get("height") if width and height: params["size"] = f"{width}x{height}" else: @@ -115,9 +117,9 @@ class MetadataProcessor: # Limit depth to avoid infinite loops in complex workflows max_depth = 10 for _ in range(max_depth): - if current_node_id in metadata.get("size", {}): - width = metadata["size"][current_node_id].get("width") - height = metadata["size"][current_node_id].get("height") + if current_node_id in metadata.get(SIZE, {}): + width = metadata[SIZE][current_node_id].get("width") + height = metadata[SIZE][current_node_id].get("height") if width and height: params["size"] = f"{width}x{height}" size_found = True @@ -141,7 +143,7 @@ class MetadataProcessor: # Extract LoRAs using the standardized format lora_parts = [] - for node_id, lora_info in metadata.get("loras", {}).items(): + for node_id, lora_info in metadata.get(LORAS, {}).items(): # Access the lora_list from the standardized format lora_list = lora_info.get("lora_list", []) for lora in lora_list: diff --git a/py/metadata_collector/metadata_registry.py b/py/metadata_collector/metadata_registry.py index cae9a1a9..9b75351f 100644 --- a/py/metadata_collector/metadata_registry.py +++ b/py/metadata_collector/metadata_registry.py @@ -1,5 +1,6 @@ import time from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor +from .constants import METADATA_CATEGORIES class MetadataRegistry: """A singleton registry to store and retrieve workflow metadata""" @@ -22,22 +23,21 @@ class MetadataRegistry: self.node_cache = {} # Categories we want to track and retrieve from cache - self.metadata_categories = ["models", "prompts", "sampling", "loras", "size"] + self.metadata_categories = METADATA_CATEGORIES def start_collection(self, prompt_id): """Begin metadata collection for a new prompt""" self.current_prompt_id = prompt_id self.executed_nodes = set() self.prompt_metadata[prompt_id] = { - "models": {}, - "prompts": {}, - "sampling": {}, - "loras": {}, - "size": {}, + category: {} for category in METADATA_CATEGORIES + } + # Add additional metadata fields + self.prompt_metadata[prompt_id].update({ "execution_order": [], "current_prompt": None, # Will store the prompt object "timestamp": time.time() - } + }) def set_current_prompt(self, prompt): """Set the current prompt object reference""" diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 612059d6..6fb018b5 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -1,5 +1,7 @@ import os +from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE + class NodeMetadataExtractor: """Base class for node-specific metadata extraction""" @@ -28,7 +30,7 @@ class CheckpointLoaderExtractor(NodeMetadataExtractor): model_name = inputs.get("ckpt_name") if model_name: - metadata["models"][node_id] = { + metadata[MODELS][node_id] = { "name": model_name, "type": "checkpoint", "node_id": node_id @@ -41,7 +43,7 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor): return text = inputs.get("text", "") - metadata["prompts"][node_id] = { + metadata[PROMPTS][node_id] = { "text": text, "node_id": node_id } @@ -57,7 +59,7 @@ class SamplerExtractor(NodeMetadataExtractor): if key in inputs: sampling_params[key] = inputs[key] - metadata["sampling"][node_id] = { + metadata[SAMPLING][node_id] = { "parameters": sampling_params, "node_id": node_id } @@ -74,10 +76,10 @@ class SamplerExtractor(NodeMetadataExtractor): height = int(samples.shape[2] * 8) width = int(samples.shape[3] * 8) - if "size" not in metadata: - metadata["size"] = {} + if SIZE not in metadata: + metadata[SIZE] = {} - metadata["size"][node_id] = { + metadata[SIZE][node_id] = { "width": width, "height": height, "node_id": node_id @@ -95,7 +97,7 @@ class LoraLoaderExtractor(NodeMetadataExtractor): strength_model = round(float(inputs.get("strength_model", 1.0)), 2) # Use the standardized format with lora_list - metadata["loras"][node_id] = { + metadata[LORAS][node_id] = { "lora_list": [ { "name": lora_name, @@ -114,10 +116,10 @@ class ImageSizeExtractor(NodeMetadataExtractor): width = inputs.get("width", 512) height = inputs.get("height", 512) - if "size" not in metadata: - metadata["size"] = {} + if SIZE not in metadata: + metadata[SIZE] = {} - metadata["size"][node_id] = { + metadata[SIZE][node_id] = { "width": width, "height": height, "node_id": node_id @@ -164,7 +166,7 @@ class LoraLoaderManagerExtractor(NodeMetadataExtractor): }) if active_loras: - metadata["loras"][node_id] = { + metadata[LORAS][node_id] = { "lora_list": active_loras, "node_id": node_id }