mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
feat: Refactor metadata processing to use constants for category keys and improve structure
This commit is contained in:
11
py/metadata_collector/constants.py
Normal file
11
py/metadata_collector/constants.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
"""Constants used by the metadata collector"""
|
||||||
|
|
||||||
|
# Individual category constants
|
||||||
|
MODELS = "models"
|
||||||
|
PROMPTS = "prompts"
|
||||||
|
SAMPLING = "sampling"
|
||||||
|
LORAS = "loras"
|
||||||
|
SIZE = "size"
|
||||||
|
|
||||||
|
# Collection of categories for iteration
|
||||||
|
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE]
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
|
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE
|
||||||
|
|
||||||
class MetadataProcessor:
|
class MetadataProcessor:
|
||||||
"""Process and format collected metadata"""
|
"""Process and format collected metadata"""
|
||||||
|
|
||||||
@@ -9,7 +11,7 @@ class MetadataProcessor:
|
|||||||
primary_sampler = None
|
primary_sampler = None
|
||||||
primary_sampler_id = None
|
primary_sampler_id = None
|
||||||
|
|
||||||
for node_id, sampler_info in metadata.get("sampling", {}).items():
|
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
|
||||||
parameters = sampler_info.get("parameters", {})
|
parameters = sampler_info.get("parameters", {})
|
||||||
denoise = parameters.get("denoise")
|
denoise = parameters.get("denoise")
|
||||||
|
|
||||||
@@ -41,11 +43,11 @@ class MetadataProcessor:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def find_primary_checkpoint(metadata):
|
def find_primary_checkpoint(metadata):
|
||||||
"""Find the primary checkpoint model in the workflow"""
|
"""Find the primary checkpoint model in the workflow"""
|
||||||
if not metadata.get("models"):
|
if not metadata.get(MODELS):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# In most workflows, there's only one checkpoint, so we can just take the first one
|
# In most workflows, there's only one checkpoint, so we can just take the first one
|
||||||
for node_id, model_info in metadata.get("models", {}).items():
|
for node_id, model_info in metadata.get(MODELS, {}).items():
|
||||||
if model_info.get("type") == "checkpoint":
|
if model_info.get("type") == "checkpoint":
|
||||||
return model_info.get("name")
|
return model_info.get("name")
|
||||||
|
|
||||||
@@ -90,18 +92,18 @@ class MetadataProcessor:
|
|||||||
if prompt and primary_sampler_id:
|
if prompt and primary_sampler_id:
|
||||||
# Trace positive prompt
|
# Trace positive prompt
|
||||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive")
|
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive")
|
||||||
if positive_node_id and positive_node_id in metadata.get("prompts", {}):
|
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||||
params["prompt"] = metadata["prompts"][positive_node_id].get("text", "")
|
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||||
|
|
||||||
# Trace negative prompt
|
# Trace negative prompt
|
||||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative")
|
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative")
|
||||||
if negative_node_id and negative_node_id in metadata.get("prompts", {}):
|
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||||
params["negative_prompt"] = metadata["prompts"][negative_node_id].get("text", "")
|
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||||
|
|
||||||
# Check if the sampler itself has size information (from latent_image)
|
# Check if the sampler itself has size information (from latent_image)
|
||||||
if primary_sampler_id in metadata.get("size", {}):
|
if primary_sampler_id in metadata.get(SIZE, {}):
|
||||||
width = metadata["size"][primary_sampler_id].get("width")
|
width = metadata[SIZE][primary_sampler_id].get("width")
|
||||||
height = metadata["size"][primary_sampler_id].get("height")
|
height = metadata[SIZE][primary_sampler_id].get("height")
|
||||||
if width and height:
|
if width and height:
|
||||||
params["size"] = f"{width}x{height}"
|
params["size"] = f"{width}x{height}"
|
||||||
else:
|
else:
|
||||||
@@ -115,9 +117,9 @@ class MetadataProcessor:
|
|||||||
# Limit depth to avoid infinite loops in complex workflows
|
# Limit depth to avoid infinite loops in complex workflows
|
||||||
max_depth = 10
|
max_depth = 10
|
||||||
for _ in range(max_depth):
|
for _ in range(max_depth):
|
||||||
if current_node_id in metadata.get("size", {}):
|
if current_node_id in metadata.get(SIZE, {}):
|
||||||
width = metadata["size"][current_node_id].get("width")
|
width = metadata[SIZE][current_node_id].get("width")
|
||||||
height = metadata["size"][current_node_id].get("height")
|
height = metadata[SIZE][current_node_id].get("height")
|
||||||
if width and height:
|
if width and height:
|
||||||
params["size"] = f"{width}x{height}"
|
params["size"] = f"{width}x{height}"
|
||||||
size_found = True
|
size_found = True
|
||||||
@@ -141,7 +143,7 @@ class MetadataProcessor:
|
|||||||
|
|
||||||
# Extract LoRAs using the standardized format
|
# Extract LoRAs using the standardized format
|
||||||
lora_parts = []
|
lora_parts = []
|
||||||
for node_id, lora_info in metadata.get("loras", {}).items():
|
for node_id, lora_info in metadata.get(LORAS, {}).items():
|
||||||
# Access the lora_list from the standardized format
|
# Access the lora_list from the standardized format
|
||||||
lora_list = lora_info.get("lora_list", [])
|
lora_list = lora_info.get("lora_list", [])
|
||||||
for lora in lora_list:
|
for lora in lora_list:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import time
|
import time
|
||||||
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
|
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
|
||||||
|
from .constants import METADATA_CATEGORIES
|
||||||
|
|
||||||
class MetadataRegistry:
|
class MetadataRegistry:
|
||||||
"""A singleton registry to store and retrieve workflow metadata"""
|
"""A singleton registry to store and retrieve workflow metadata"""
|
||||||
@@ -22,22 +23,21 @@ class MetadataRegistry:
|
|||||||
self.node_cache = {}
|
self.node_cache = {}
|
||||||
|
|
||||||
# Categories we want to track and retrieve from cache
|
# Categories we want to track and retrieve from cache
|
||||||
self.metadata_categories = ["models", "prompts", "sampling", "loras", "size"]
|
self.metadata_categories = METADATA_CATEGORIES
|
||||||
|
|
||||||
def start_collection(self, prompt_id):
|
def start_collection(self, prompt_id):
|
||||||
"""Begin metadata collection for a new prompt"""
|
"""Begin metadata collection for a new prompt"""
|
||||||
self.current_prompt_id = prompt_id
|
self.current_prompt_id = prompt_id
|
||||||
self.executed_nodes = set()
|
self.executed_nodes = set()
|
||||||
self.prompt_metadata[prompt_id] = {
|
self.prompt_metadata[prompt_id] = {
|
||||||
"models": {},
|
category: {} for category in METADATA_CATEGORIES
|
||||||
"prompts": {},
|
}
|
||||||
"sampling": {},
|
# Add additional metadata fields
|
||||||
"loras": {},
|
self.prompt_metadata[prompt_id].update({
|
||||||
"size": {},
|
|
||||||
"execution_order": [],
|
"execution_order": [],
|
||||||
"current_prompt": None, # Will store the prompt object
|
"current_prompt": None, # Will store the prompt object
|
||||||
"timestamp": time.time()
|
"timestamp": time.time()
|
||||||
}
|
})
|
||||||
|
|
||||||
def set_current_prompt(self, prompt):
|
def set_current_prompt(self, prompt):
|
||||||
"""Set the current prompt object reference"""
|
"""Set the current prompt object reference"""
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE
|
||||||
|
|
||||||
|
|
||||||
class NodeMetadataExtractor:
|
class NodeMetadataExtractor:
|
||||||
"""Base class for node-specific metadata extraction"""
|
"""Base class for node-specific metadata extraction"""
|
||||||
@@ -28,7 +30,7 @@ class CheckpointLoaderExtractor(NodeMetadataExtractor):
|
|||||||
|
|
||||||
model_name = inputs.get("ckpt_name")
|
model_name = inputs.get("ckpt_name")
|
||||||
if model_name:
|
if model_name:
|
||||||
metadata["models"][node_id] = {
|
metadata[MODELS][node_id] = {
|
||||||
"name": model_name,
|
"name": model_name,
|
||||||
"type": "checkpoint",
|
"type": "checkpoint",
|
||||||
"node_id": node_id
|
"node_id": node_id
|
||||||
@@ -41,7 +43,7 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor):
|
|||||||
return
|
return
|
||||||
|
|
||||||
text = inputs.get("text", "")
|
text = inputs.get("text", "")
|
||||||
metadata["prompts"][node_id] = {
|
metadata[PROMPTS][node_id] = {
|
||||||
"text": text,
|
"text": text,
|
||||||
"node_id": node_id
|
"node_id": node_id
|
||||||
}
|
}
|
||||||
@@ -57,7 +59,7 @@ class SamplerExtractor(NodeMetadataExtractor):
|
|||||||
if key in inputs:
|
if key in inputs:
|
||||||
sampling_params[key] = inputs[key]
|
sampling_params[key] = inputs[key]
|
||||||
|
|
||||||
metadata["sampling"][node_id] = {
|
metadata[SAMPLING][node_id] = {
|
||||||
"parameters": sampling_params,
|
"parameters": sampling_params,
|
||||||
"node_id": node_id
|
"node_id": node_id
|
||||||
}
|
}
|
||||||
@@ -74,10 +76,10 @@ class SamplerExtractor(NodeMetadataExtractor):
|
|||||||
height = int(samples.shape[2] * 8)
|
height = int(samples.shape[2] * 8)
|
||||||
width = int(samples.shape[3] * 8)
|
width = int(samples.shape[3] * 8)
|
||||||
|
|
||||||
if "size" not in metadata:
|
if SIZE not in metadata:
|
||||||
metadata["size"] = {}
|
metadata[SIZE] = {}
|
||||||
|
|
||||||
metadata["size"][node_id] = {
|
metadata[SIZE][node_id] = {
|
||||||
"width": width,
|
"width": width,
|
||||||
"height": height,
|
"height": height,
|
||||||
"node_id": node_id
|
"node_id": node_id
|
||||||
@@ -95,7 +97,7 @@ class LoraLoaderExtractor(NodeMetadataExtractor):
|
|||||||
strength_model = round(float(inputs.get("strength_model", 1.0)), 2)
|
strength_model = round(float(inputs.get("strength_model", 1.0)), 2)
|
||||||
|
|
||||||
# Use the standardized format with lora_list
|
# Use the standardized format with lora_list
|
||||||
metadata["loras"][node_id] = {
|
metadata[LORAS][node_id] = {
|
||||||
"lora_list": [
|
"lora_list": [
|
||||||
{
|
{
|
||||||
"name": lora_name,
|
"name": lora_name,
|
||||||
@@ -114,10 +116,10 @@ class ImageSizeExtractor(NodeMetadataExtractor):
|
|||||||
width = inputs.get("width", 512)
|
width = inputs.get("width", 512)
|
||||||
height = inputs.get("height", 512)
|
height = inputs.get("height", 512)
|
||||||
|
|
||||||
if "size" not in metadata:
|
if SIZE not in metadata:
|
||||||
metadata["size"] = {}
|
metadata[SIZE] = {}
|
||||||
|
|
||||||
metadata["size"][node_id] = {
|
metadata[SIZE][node_id] = {
|
||||||
"width": width,
|
"width": width,
|
||||||
"height": height,
|
"height": height,
|
||||||
"node_id": node_id
|
"node_id": node_id
|
||||||
@@ -164,7 +166,7 @@ class LoraLoaderManagerExtractor(NodeMetadataExtractor):
|
|||||||
})
|
})
|
||||||
|
|
||||||
if active_loras:
|
if active_loras:
|
||||||
metadata["loras"][node_id] = {
|
metadata[LORAS][node_id] = {
|
||||||
"lora_list": active_loras,
|
"lora_list": active_loras,
|
||||||
"node_id": node_id
|
"node_id": node_id
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user