feat: Refactor metadata processing to use constants for category keys and improve structure

This commit is contained in:
Will Miao
2025-04-17 06:23:31 +08:00
parent 4fdc88e9e1
commit 18eb605605
4 changed files with 47 additions and 32 deletions

View File

@@ -0,0 +1,11 @@
"""Constants used by the metadata collector.

Central definition of the category keys under which workflow metadata is
stored (e.g. ``metadata[MODELS][node_id]``), so collectors and processors
never hard-code the raw strings.
"""
# Individual category constants (dictionary keys in the metadata mapping)
MODELS = "models"
PROMPTS = "prompts"
SAMPLING = "sampling"
LORAS = "loras"
SIZE = "size"
# Collection of categories for iteration.
# A tuple (not a list) so the shared module-level constant cannot be
# mutated accidentally by any consumer that holds a reference to it.
METADATA_CATEGORIES = (MODELS, PROMPTS, SAMPLING, LORAS, SIZE)

View File

@@ -1,5 +1,7 @@
import json import json
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE
class MetadataProcessor: class MetadataProcessor:
"""Process and format collected metadata""" """Process and format collected metadata"""
@@ -9,7 +11,7 @@ class MetadataProcessor:
primary_sampler = None primary_sampler = None
primary_sampler_id = None primary_sampler_id = None
for node_id, sampler_info in metadata.get("sampling", {}).items(): for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
parameters = sampler_info.get("parameters", {}) parameters = sampler_info.get("parameters", {})
denoise = parameters.get("denoise") denoise = parameters.get("denoise")
@@ -41,11 +43,11 @@ class MetadataProcessor:
@staticmethod @staticmethod
def find_primary_checkpoint(metadata): def find_primary_checkpoint(metadata):
"""Find the primary checkpoint model in the workflow""" """Find the primary checkpoint model in the workflow"""
if not metadata.get("models"): if not metadata.get(MODELS):
return None return None
# In most workflows, there's only one checkpoint, so we can just take the first one # In most workflows, there's only one checkpoint, so we can just take the first one
for node_id, model_info in metadata.get("models", {}).items(): for node_id, model_info in metadata.get(MODELS, {}).items():
if model_info.get("type") == "checkpoint": if model_info.get("type") == "checkpoint":
return model_info.get("name") return model_info.get("name")
@@ -90,18 +92,18 @@ class MetadataProcessor:
if prompt and primary_sampler_id: if prompt and primary_sampler_id:
# Trace positive prompt # Trace positive prompt
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive") positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive")
if positive_node_id and positive_node_id in metadata.get("prompts", {}): if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata["prompts"][positive_node_id].get("text", "") params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
# Trace negative prompt # Trace negative prompt
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative") negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative")
if negative_node_id and negative_node_id in metadata.get("prompts", {}): if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
params["negative_prompt"] = metadata["prompts"][negative_node_id].get("text", "") params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
# Check if the sampler itself has size information (from latent_image) # Check if the sampler itself has size information (from latent_image)
if primary_sampler_id in metadata.get("size", {}): if primary_sampler_id in metadata.get(SIZE, {}):
width = metadata["size"][primary_sampler_id].get("width") width = metadata[SIZE][primary_sampler_id].get("width")
height = metadata["size"][primary_sampler_id].get("height") height = metadata[SIZE][primary_sampler_id].get("height")
if width and height: if width and height:
params["size"] = f"{width}x{height}" params["size"] = f"{width}x{height}"
else: else:
@@ -115,9 +117,9 @@ class MetadataProcessor:
# Limit depth to avoid infinite loops in complex workflows # Limit depth to avoid infinite loops in complex workflows
max_depth = 10 max_depth = 10
for _ in range(max_depth): for _ in range(max_depth):
if current_node_id in metadata.get("size", {}): if current_node_id in metadata.get(SIZE, {}):
width = metadata["size"][current_node_id].get("width") width = metadata[SIZE][current_node_id].get("width")
height = metadata["size"][current_node_id].get("height") height = metadata[SIZE][current_node_id].get("height")
if width and height: if width and height:
params["size"] = f"{width}x{height}" params["size"] = f"{width}x{height}"
size_found = True size_found = True
@@ -141,7 +143,7 @@ class MetadataProcessor:
# Extract LoRAs using the standardized format # Extract LoRAs using the standardized format
lora_parts = [] lora_parts = []
for node_id, lora_info in metadata.get("loras", {}).items(): for node_id, lora_info in metadata.get(LORAS, {}).items():
# Access the lora_list from the standardized format # Access the lora_list from the standardized format
lora_list = lora_info.get("lora_list", []) lora_list = lora_info.get("lora_list", [])
for lora in lora_list: for lora in lora_list:

View File

@@ -1,5 +1,6 @@
import time import time
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
from .constants import METADATA_CATEGORIES
class MetadataRegistry: class MetadataRegistry:
"""A singleton registry to store and retrieve workflow metadata""" """A singleton registry to store and retrieve workflow metadata"""
@@ -22,22 +23,21 @@ class MetadataRegistry:
self.node_cache = {} self.node_cache = {}
# Categories we want to track and retrieve from cache # Categories we want to track and retrieve from cache
self.metadata_categories = ["models", "prompts", "sampling", "loras", "size"] self.metadata_categories = METADATA_CATEGORIES
def start_collection(self, prompt_id): def start_collection(self, prompt_id):
"""Begin metadata collection for a new prompt""" """Begin metadata collection for a new prompt"""
self.current_prompt_id = prompt_id self.current_prompt_id = prompt_id
self.executed_nodes = set() self.executed_nodes = set()
self.prompt_metadata[prompt_id] = { self.prompt_metadata[prompt_id] = {
"models": {}, category: {} for category in METADATA_CATEGORIES
"prompts": {}, }
"sampling": {}, # Add additional metadata fields
"loras": {}, self.prompt_metadata[prompt_id].update({
"size": {},
"execution_order": [], "execution_order": [],
"current_prompt": None, # Will store the prompt object "current_prompt": None, # Will store the prompt object
"timestamp": time.time() "timestamp": time.time()
} })
def set_current_prompt(self, prompt): def set_current_prompt(self, prompt):
"""Set the current prompt object reference""" """Set the current prompt object reference"""

View File

@@ -1,5 +1,7 @@
import os import os
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE
class NodeMetadataExtractor: class NodeMetadataExtractor:
"""Base class for node-specific metadata extraction""" """Base class for node-specific metadata extraction"""
@@ -28,7 +30,7 @@ class CheckpointLoaderExtractor(NodeMetadataExtractor):
model_name = inputs.get("ckpt_name") model_name = inputs.get("ckpt_name")
if model_name: if model_name:
metadata["models"][node_id] = { metadata[MODELS][node_id] = {
"name": model_name, "name": model_name,
"type": "checkpoint", "type": "checkpoint",
"node_id": node_id "node_id": node_id
@@ -41,7 +43,7 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor):
return return
text = inputs.get("text", "") text = inputs.get("text", "")
metadata["prompts"][node_id] = { metadata[PROMPTS][node_id] = {
"text": text, "text": text,
"node_id": node_id "node_id": node_id
} }
@@ -57,7 +59,7 @@ class SamplerExtractor(NodeMetadataExtractor):
if key in inputs: if key in inputs:
sampling_params[key] = inputs[key] sampling_params[key] = inputs[key]
metadata["sampling"][node_id] = { metadata[SAMPLING][node_id] = {
"parameters": sampling_params, "parameters": sampling_params,
"node_id": node_id "node_id": node_id
} }
@@ -74,10 +76,10 @@ class SamplerExtractor(NodeMetadataExtractor):
height = int(samples.shape[2] * 8) height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8) width = int(samples.shape[3] * 8)
if "size" not in metadata: if SIZE not in metadata:
metadata["size"] = {} metadata[SIZE] = {}
metadata["size"][node_id] = { metadata[SIZE][node_id] = {
"width": width, "width": width,
"height": height, "height": height,
"node_id": node_id "node_id": node_id
@@ -95,7 +97,7 @@ class LoraLoaderExtractor(NodeMetadataExtractor):
strength_model = round(float(inputs.get("strength_model", 1.0)), 2) strength_model = round(float(inputs.get("strength_model", 1.0)), 2)
# Use the standardized format with lora_list # Use the standardized format with lora_list
metadata["loras"][node_id] = { metadata[LORAS][node_id] = {
"lora_list": [ "lora_list": [
{ {
"name": lora_name, "name": lora_name,
@@ -114,10 +116,10 @@ class ImageSizeExtractor(NodeMetadataExtractor):
width = inputs.get("width", 512) width = inputs.get("width", 512)
height = inputs.get("height", 512) height = inputs.get("height", 512)
if "size" not in metadata: if SIZE not in metadata:
metadata["size"] = {} metadata[SIZE] = {}
metadata["size"][node_id] = { metadata[SIZE][node_id] = {
"width": width, "width": width,
"height": height, "height": height,
"node_id": node_id "node_id": node_id
@@ -164,7 +166,7 @@ class LoraLoaderManagerExtractor(NodeMetadataExtractor):
}) })
if active_loras: if active_loras:
metadata["loras"][node_id] = { metadata[LORAS][node_id] = {
"lora_list": active_loras, "lora_list": active_loras,
"node_id": node_id "node_id": node_id
} }