feat: Add IMAGES category to constants and enhance metadata handling in node extractors

This commit is contained in:
Will Miao
2025-04-18 07:12:43 +08:00
parent f0203c96ab
commit df6d56ce66
3 changed files with 92 additions and 3 deletions

View File

@@ -6,6 +6,7 @@ PROMPTS = "prompts"
SAMPLING = "sampling" SAMPLING = "sampling"
LORAS = "loras" LORAS = "loras"
SIZE = "size" SIZE = "size"
IMAGES = "images" # Added new category for image results
# Collection of categories for iteration # Collection of categories for iteration
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE] METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES] # Added IMAGES to categories

View File

@@ -1,7 +1,7 @@
import time import time
from nodes import NODE_CLASS_MAPPINGS from nodes import NODE_CLASS_MAPPINGS
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
from .constants import METADATA_CATEGORIES from .constants import METADATA_CATEGORIES, IMAGES
class MetadataRegistry: class MetadataRegistry:
"""A singleton registry to store and retrieve workflow metadata""" """A singleton registry to store and retrieve workflow metadata"""
@@ -23,9 +23,28 @@ class MetadataRegistry:
# Node-level cache for metadata # Node-level cache for metadata
self.node_cache = {} self.node_cache = {}
# Limit the number of stored prompts
self.max_prompt_history = 3
# Categories we want to track and retrieve from cache # Categories we want to track and retrieve from cache
self.metadata_categories = METADATA_CATEGORIES self.metadata_categories = METADATA_CATEGORIES
def _clean_old_prompts(self):
"""Clean up old prompt metadata, keeping only recent ones"""
if len(self.prompt_metadata) <= self.max_prompt_history:
return
# Sort all prompt_ids by timestamp
sorted_prompts = sorted(
self.prompt_metadata.keys(),
key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0)
)
# Remove oldest records
prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history]
for pid in prompts_to_remove:
del self.prompt_metadata[pid]
def start_collection(self, prompt_id): def start_collection(self, prompt_id):
"""Begin metadata collection for a new prompt""" """Begin metadata collection for a new prompt"""
self.current_prompt_id = prompt_id self.current_prompt_id = prompt_id
@@ -39,6 +58,9 @@ class MetadataRegistry:
"current_prompt": None, # Will store the prompt object "current_prompt": None, # Will store the prompt object
"timestamp": time.time() "timestamp": time.time()
}) })
# Clean up old prompt data
self._clean_old_prompts()
def set_current_prompt(self, prompt): def set_current_prompt(self, prompt):
"""Set the current prompt object reference""" """Set the current prompt object reference"""
@@ -177,3 +199,46 @@ class MetadataRegistry:
# Save to cache if we have any metadata for this node # Save to cache if we have any metadata for this node
if any(node_metadata.values()): if any(node_metadata.values()):
self.node_cache[cache_key] = node_metadata self.node_cache[cache_key] = node_metadata
def clear_unused_cache(self):
    """Evict node_cache entries whose node no longer appears in any stored prompt.

    Cache keys are assumed to start with "<node_id>:" — the node id is
    recovered by splitting on the first colon.
    """
    # Gather every node id still referenced by tracked metadata categories.
    live_ids = set()
    for prompt_data in self.prompt_metadata.values():
        for category in self.metadata_categories:
            if category in prompt_data:
                live_ids.update(prompt_data[category])
    # Collect stale keys first, then delete — never mutate while iterating.
    stale_keys = [
        key for key in self.node_cache
        if key.split(':')[0] not in live_ids
    ]
    for key in stale_keys:
        del self.node_cache[key]
def clear_metadata(self, prompt_id=None):
    """Forget metadata for one prompt, or reset everything.

    With a prompt_id, that prompt's record is removed (if present) and the
    node cache is pruned of entries it alone kept alive. With no prompt_id,
    all registry state is reset via _reset().
    """
    if prompt_id is None:
        # Full wipe of registry state.
        self._reset()
        return
    if prompt_id in self.prompt_metadata:
        del self.prompt_metadata[prompt_id]
    # Drop cache entries orphaned by the removal.
    self.clear_unused_cache()
def get_first_decoded_image(self, prompt_id=None):
    """Return the image stored by the first VAEDecode of a prompt, or None.

    Falls back to the current prompt when prompt_id is omitted; returns
    None when the prompt or its first-decode record is unknown.
    """
    target = self.current_prompt_id if prompt_id is None else prompt_id
    if target not in self.prompt_metadata:
        return None
    images = self.prompt_metadata[target].get(IMAGES)
    if images and "first_decode" in images:
        return images["first_decode"]["image"]
    return None

View File

@@ -1,6 +1,6 @@
import os import os
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES
class NodeMetadataExtractor: class NodeMetadataExtractor:
@@ -235,7 +235,28 @@ class UNETLoaderExtractor(NodeMetadataExtractor):
"type": "checkpoint", "type": "checkpoint",
"node_id": node_id "node_id": node_id
} }
class VAEDecodeExtractor(NodeMetadataExtractor):
    """Records the output of the first VAEDecode node under the IMAGES category."""

    @staticmethod
    def extract(node_id, inputs, outputs, metadata):
        # VAEDecode inputs carry nothing worth recording at extraction time.
        pass

    @staticmethod
    def update(node_id, outputs, metadata):
        """Keep only the earliest decode result; later decodes are ignored."""
        images = metadata.setdefault(IMAGES, {})
        if "first_decode" in images:
            return
        images["first_decode"] = {
            "node_id": node_id,
            "image": outputs,
        }
# Registry of node-specific extractors # Registry of node-specific extractors
NODE_EXTRACTORS = { NODE_EXTRACTORS = {
# Sampling # Sampling
@@ -253,5 +274,7 @@ NODE_EXTRACTORS = {
"EmptyLatentImage": ImageSizeExtractor, "EmptyLatentImage": ImageSizeExtractor,
# Flux # Flux
"FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance "FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance
# Image
"VAEDecode": VAEDecodeExtractor, # Added VAEDecode extractor
# Add other nodes as needed # Add other nodes as needed
} }