mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Update metadata registry to remove cache entries when node metadata becomes empty instead of keeping stale data. This prevents accumulation of unused cache entries and ensures cache only contains valid metadata. Added test case to verify cache behavior when LoRA configurations are removed.
278 lines
12 KiB
Python
278 lines
12 KiB
Python
import time
|
|
from nodes import NODE_CLASS_MAPPINGS
|
|
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
|
|
from .constants import METADATA_CATEGORIES, IMAGES
|
|
|
|
class MetadataRegistry:
    """A singleton registry to store and retrieve workflow metadata.

    Metadata is collected per prompt id in ``prompt_metadata``: one dict per
    prompt holding a bucket for each metadata category plus the
    ``execution_order``, ``current_prompt`` and ``timestamp`` fields.

    ``node_cache`` keeps the last metadata seen for each
    ``"<node_id>:<class name>"`` pair. This lets nodes that did not execute
    in a run (for example because ComfyUI reused their cached outputs) still
    contribute metadata.
    """

    _instance = None

    def __new__(cls):
        # Create the shared instance on first construction. Every later
        # ``MetadataRegistry()`` call returns that same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._reset()
        return cls._instance

    def _reset(self):
        """Initialise (or wipe) all registry state."""
        self.current_prompt_id = None
        self.current_prompt = None
        self.metadata = {}
        # prompt_id -> collected metadata dict for that prompt
        self.prompt_metadata = {}
        # Node ids recorded during the prompt currently being collected
        self.executed_nodes = set()

        # Node-level cache for metadata, keyed by "<node_id>:<class name>"
        self.node_cache = {}

        # Limit the number of stored prompts
        self.max_prompt_history = 3

        # Categories we want to track and retrieve from cache
        self.metadata_categories = METADATA_CATEGORIES

    def _clean_old_prompts(self):
        """Drop the oldest prompts (by ``timestamp``) beyond ``max_prompt_history``."""
        excess = len(self.prompt_metadata) - self.max_prompt_history
        if excess <= 0:
            return

        # Oldest first. A missing timestamp counts as oldest.
        oldest_first = sorted(
            self.prompt_metadata,
            key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0),
        )
        for pid in oldest_first[:excess]:
            del self.prompt_metadata[pid]

    def start_collection(self, prompt_id):
        """Begin metadata collection for a new prompt.

        Makes *prompt_id* the current prompt and clears the executed-node set.
        Creates an empty bucket for every tracked category, then trims old
        prompts.
        """
        self.current_prompt_id = prompt_id
        self.executed_nodes = set()

        prompt_data = {category: {} for category in self.metadata_categories}
        prompt_data.update({
            "execution_order": [],
            "current_prompt": None,  # Will store the prompt object
            "timestamp": time.time(),
        })
        self.prompt_metadata[prompt_id] = prompt_data

        # Clean up old prompt data
        self._clean_old_prompts()

    def set_current_prompt(self, prompt):
        """Set the current prompt object reference.

        The prompt is also stored in the current prompt's metadata, when that
        metadata exists, so that relationships can be traced later.
        """
        self.current_prompt = prompt
        if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata:
            self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt

    def get_metadata(self, prompt_id=None):
        """Get collected metadata for a prompt.

        Args:
            prompt_id: Prompt to look up. Defaults to the current prompt.

        Returns:
            The live metadata dict for the prompt, or ``{}`` if unknown.
            Before returning, nodes from the stored prompt object's
            ``original_prompt`` that did not execute in that run are
            back-filled from ``node_cache``.
        """
        key = prompt_id if prompt_id is not None else self.current_prompt_id
        if key not in self.prompt_metadata:
            return {}

        original_prompt = self._original_prompt_of(self.prompt_metadata[key])
        # _fill_missing_metadata is a no-op for a falsy original_prompt.
        self._fill_missing_metadata(key, original_prompt)

        return self.prompt_metadata.get(key, {})

    def _fill_missing_metadata(self, prompt_id, original_prompt):
        """Fill missing metadata from ``node_cache`` for non-executed nodes.

        Args:
            prompt_id: Prompt whose metadata is completed in place. It must be
                present in ``prompt_metadata``.
            original_prompt: Mapping of node id -> node data carrying a
                ``"class_type"`` entry. A falsy value means nothing to do.
        """
        if not original_prompt:
            return

        metadata = self.prompt_metadata[prompt_id]
        # Use the executed nodes of *this* prompt's run, not whichever run is
        # currently in progress.
        executed_nodes = self._executed_nodes_for(prompt_id)

        for node_id, node_data in original_prompt.items():
            if node_id in executed_nodes:
                continue

            prompt_class_type = node_data.get("class_type")
            if not prompt_class_type:
                continue

            # The cache is keyed by the Python class name, so translate the
            # prompt's registry key when it is known.
            class_type = self._resolve_class_name(prompt_class_type) or prompt_class_type

            # Only node types with an extractor produce metadata.
            if class_type not in NODE_EXTRACTORS:
                continue

            cached_data = self.node_cache.get(self._cache_key(node_id, class_type))
            if not cached_data:
                continue

            for category in self.metadata_categories:
                cached_bucket = cached_data.get(category, {})
                if node_id in cached_bucket:
                    # Never overwrite metadata collected in this run.
                    metadata.setdefault(category, {}).setdefault(
                        node_id, cached_bucket[node_id]
                    )

    def record_node_execution(self, node_id, class_type, inputs, outputs):
        """Record information about a node's execution.

        Marks the node as executed for the current prompt and runs its
        extractor (``GenericNodeExtractor`` when none is registered). Refreshes
        the node's cache entry afterwards. List-valued inputs are reduced to
        their first element.

        Does nothing when no prompt is being collected, or when the current
        prompt's metadata has been cleared.
        """
        prompt_data = self._current_prompt_data()
        if prompt_data is None:
            return

        # Add to execution order and mark as executed
        if node_id not in self.executed_nodes:
            self.executed_nodes.add(node_id)
            prompt_data["execution_order"].append(node_id)

        # For non-empty list inputs, keep only the first value (the common case).
        processed_inputs = {
            name: values[0] if isinstance(values, list) and values else values
            for name, values in (inputs or {}).items()
        }

        extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
        extractor.extract(node_id, processed_inputs, outputs, prompt_data)

        # Cache this node's metadata
        self._cache_node_metadata(node_id, class_type)

    def update_node_execution(self, node_id, class_type, outputs):
        """Update node metadata with output information.

        Calls the node extractor's ``update`` hook, if it has one, and then
        refreshes the node's cache entry. Does nothing when there is no live
        current prompt.
        """
        prompt_data = self._current_prompt_data()
        if prompt_data is None:
            return

        extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
        if hasattr(extractor, "update"):
            extractor.update(node_id, outputs, prompt_data)

        # Update the cached metadata for this node
        self._cache_node_metadata(node_id, class_type)

    def _cache_node_metadata(self, node_id, class_type):
        """Snapshot one node's metadata from the current prompt into ``node_cache``.

        The snapshot is shallow: per-node values are shared with
        ``prompt_metadata``. When the node has no metadata in any category, its
        cache entry is removed instead, so stale data is not kept.
        """
        if not node_id or not class_type:
            return
        prompt_data = self._current_prompt_data()
        if prompt_data is None:
            return

        node_metadata = {
            category: {node_id: prompt_data[category][node_id]}
            for category in self.metadata_categories
            if category in prompt_data and node_id in prompt_data[category]
        }

        cache_key = self._cache_key(node_id, class_type)
        if node_metadata:
            self.node_cache[cache_key] = node_metadata
        else:
            self.node_cache.pop(cache_key, None)

    def clear_unused_cache(self):
        """Remove ``node_cache`` entries whose node id is absent from every stored prompt."""
        active_node_ids = set()
        for prompt_data in self.prompt_metadata.values():
            for category in self.metadata_categories:
                if category in prompt_data:
                    # Cache keys are built with str formatting, so compare as str.
                    active_node_ids.update(str(nid) for nid in prompt_data[category])

        stale_keys = [
            cache_key
            for cache_key in self.node_cache
            if self._node_id_from_cache_key(cache_key) not in active_node_ids
        ]
        for cache_key in stale_keys:
            del self.node_cache[cache_key]

    def clear_metadata(self, prompt_id=None):
        """Clear metadata for one prompt (then prune the cache), or reset everything."""
        if prompt_id is not None:
            if prompt_id in self.prompt_metadata:
                del self.prompt_metadata[prompt_id]
                # Clean up cache after removing prompt
                self.clear_unused_cache()
        else:
            # Reset all data
            self._reset()

    def get_first_decoded_image(self, prompt_id=None):
        """Get the first decoded image result.

        Looks first at ``metadata[IMAGES]["first_decode"]["image"]``. Failing
        that, it uses the cached image of the first ``VAEDecode`` node in the
        stored original prompt; this covers a VAEDecode that ComfyUI served
        from its cache instead of executing it.

        Returns:
            The first element when the image is a non-empty list or tuple,
            otherwise the stored value as-is (e.g. a tensor). Returns None when
            nothing is found.
        """
        key = prompt_id if prompt_id is not None else self.current_prompt_id
        if key not in self.prompt_metadata:
            return None

        metadata = self.prompt_metadata[key]
        images = metadata.get(IMAGES)
        if images and "first_decode" in images:
            return self._unwrap_image(images["first_decode"]["image"])

        original_prompt = self._original_prompt_of(metadata)
        if not original_prompt:
            return None

        for node_id, node_data in original_prompt.items():
            class_type = node_data.get("class_type")
            if not class_type or self._resolve_class_name(class_type) != "VAEDecode":
                continue
            cached_data = self.node_cache.get(self._cache_key(node_id, "VAEDecode"))
            if cached_data and node_id in cached_data.get(IMAGES, {}):
                return self._unwrap_image(cached_data[IMAGES][node_id]["image"])

        return None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _current_prompt_data(self):
        """Return the metadata dict of the prompt being collected, or None.

        Returns None both when no collection has started and when the current
        prompt's metadata has since been removed (e.g. by ``clear_metadata``).
        """
        if not self.current_prompt_id:
            return None
        return self.prompt_metadata.get(self.current_prompt_id)

    def _executed_nodes_for(self, prompt_id):
        """Return the node ids that executed during *prompt_id*'s run."""
        prompt_data = self.prompt_metadata.get(prompt_id, {})
        executed = set(prompt_data.get("execution_order", ()))
        if prompt_id == self.current_prompt_id:
            # Include the live set for the run in progress.
            executed |= self.executed_nodes
        return executed

    @staticmethod
    def _cache_key(node_id, class_type):
        """Build the ``node_cache`` key for a node id / class name pair."""
        return f"{node_id}:{class_type}"

    @staticmethod
    def _node_id_from_cache_key(cache_key):
        """Recover the node id from a ``node_cache`` key.

        Splits on the *last* ``:`` because node ids are not guaranteed to be
        colon-free (e.g. ``"12:0"``). The suffix is a class ``__name__`` and
        normally contains no colon.
        """
        return cache_key.rsplit(":", 1)[0]

    @staticmethod
    def _resolve_class_name(prompt_class_type):
        """Map a prompt ``class_type`` to its Python class name, or None if unregistered."""
        if prompt_class_type in NODE_CLASS_MAPPINGS:
            return NODE_CLASS_MAPPINGS[prompt_class_type].__name__
        return None

    @staticmethod
    def _original_prompt_of(prompt_data):
        """Return ``original_prompt`` of the stored ``current_prompt`` object, if any."""
        prompt_obj = prompt_data.get("current_prompt")
        if prompt_obj and hasattr(prompt_obj, "original_prompt"):
            return prompt_obj.original_prompt
        return None

    @staticmethod
    def _unwrap_image(image_data):
        """Return the first element of a non-empty list/tuple, else the value unchanged."""
        if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
            return image_data[0]
        return image_data
|