From df6d56ce667be5d734810f075deb5a5ccc253e38 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Fri, 18 Apr 2025 07:12:43 +0800 Subject: [PATCH] feat: Add IMAGES category to constants and enhance metadata handling in node extractors --- py/metadata_collector/constants.py | 3 +- py/metadata_collector/metadata_registry.py | 67 +++++++++++++++++++++- py/metadata_collector/node_extractors.py | 25 +++++++- 3 files changed, 92 insertions(+), 3 deletions(-) diff --git a/py/metadata_collector/constants.py b/py/metadata_collector/constants.py index 0ad55b47..c1109580 100644 --- a/py/metadata_collector/constants.py +++ b/py/metadata_collector/constants.py @@ -6,6 +6,7 @@ PROMPTS = "prompts" SAMPLING = "sampling" LORAS = "loras" SIZE = "size" +IMAGES = "images" # Added new category for image results # Collection of categories for iteration -METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE] +METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES] # Added IMAGES to categories diff --git a/py/metadata_collector/metadata_registry.py b/py/metadata_collector/metadata_registry.py index 434f1eb1..bcf2284a 100644 --- a/py/metadata_collector/metadata_registry.py +++ b/py/metadata_collector/metadata_registry.py @@ -1,7 +1,7 @@ import time from nodes import NODE_CLASS_MAPPINGS from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor -from .constants import METADATA_CATEGORIES +from .constants import METADATA_CATEGORIES, IMAGES class MetadataRegistry: """A singleton registry to store and retrieve workflow metadata""" @@ -23,9 +23,28 @@ class MetadataRegistry: # Node-level cache for metadata self.node_cache = {} + # Limit the number of stored prompts + self.max_prompt_history = 3 + # Categories we want to track and retrieve from cache self.metadata_categories = METADATA_CATEGORIES + def _clean_old_prompts(self): + """Clean up old prompt metadata, keeping only recent ones""" + if len(self.prompt_metadata) <= self.max_prompt_history: + return + + # Sort all prompt_ids by timestamp + sorted_prompts = sorted( + self.prompt_metadata.keys(), + key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0) + ) + + # Remove oldest records + prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history] + for pid in prompts_to_remove: + del self.prompt_metadata[pid] + def start_collection(self, prompt_id): """Begin metadata collection for a new prompt""" self.current_prompt_id = prompt_id @@ -39,6 +58,9 @@ class MetadataRegistry: "current_prompt": None, # Will store the prompt object "timestamp": time.time() }) + + # Clean up old prompt data + self._clean_old_prompts() def set_current_prompt(self, prompt): """Set the current prompt object reference""" @@ -177,3 +199,46 @@ class MetadataRegistry: # Save to cache if we have any metadata for this node if any(node_metadata.values()): self.node_cache[cache_key] = node_metadata + + def clear_unused_cache(self): + """Clean up node_cache entries that are no longer in use""" + # Collect all node_ids currently in prompt_metadata + active_node_ids = set() + for prompt_data in self.prompt_metadata.values(): + for category in self.metadata_categories: + if category in prompt_data: + active_node_ids.update(prompt_data[category].keys()) + + # Find cache keys that are no longer needed + keys_to_remove = [] + for cache_key in self.node_cache: + node_id = cache_key.split(':')[0] + if node_id not in active_node_ids: + keys_to_remove.append(cache_key) + + # Remove cache entries that are no longer needed + for key in keys_to_remove: + del self.node_cache[key] + + def clear_metadata(self, prompt_id=None): + """Clear metadata for a specific prompt or reset all data""" + if prompt_id is not None: + if prompt_id in self.prompt_metadata: + del self.prompt_metadata[prompt_id] + # Clean up cache after removing prompt + self.clear_unused_cache() + else: + # Reset all data + self._reset() + + def get_first_decoded_image(self, prompt_id=None): + """Get the first decoded image result""" + key = prompt_id if prompt_id is not None else self.current_prompt_id + if key not in self.prompt_metadata: + return None + + metadata = self.prompt_metadata[key] + if IMAGES in metadata and "first_decode" in metadata[IMAGES]: + return metadata[IMAGES]["first_decode"]["image"] + + return None diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 210ab29e..cca73095 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -1,6 +1,6 @@ import os -from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE +from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES class NodeMetadataExtractor: @@ -235,7 +235,28 @@ class UNETLoaderExtractor(NodeMetadataExtractor): "type": "checkpoint", "node_id": node_id } + +class VAEDecodeExtractor(NodeMetadataExtractor): + @staticmethod + def extract(node_id, inputs, outputs, metadata): + pass + @staticmethod + def update(node_id, outputs, metadata): + # Check if we already have a first VAEDecode result + if IMAGES in metadata and "first_decode" in metadata[IMAGES]: + return + + # Ensure IMAGES category exists + if IMAGES not in metadata: + metadata[IMAGES] = {} + + # Save reference to the first VAEDecode result + metadata[IMAGES]["first_decode"] = { + "node_id": node_id, + "image": outputs + } + # Registry of node-specific extractors NODE_EXTRACTORS = { # Sampling @@ -253,5 +274,7 @@ NODE_EXTRACTORS = { "EmptyLatentImage": ImageSizeExtractor, # Flux "FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance + # Image + "VAEDecode": VAEDecodeExtractor, # Added VAEDecode extractor # Add other nodes as needed }