From 60575b65466657c607dec2a136baf8816a622cb7 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Tue, 1 Apr 2025 08:38:49 +0800 Subject: [PATCH 01/18] checkpoint --- py/workflow/mappers.py | 398 +++++++++++++++--------------------- py/workflow/parser.py | 7 +- refs/output.json | 18 +- refs/prompt.json | 48 +++-- web/comfyui/loras_widget.js | 2 +- 5 files changed, 209 insertions(+), 264 deletions(-) diff --git a/py/workflow/mappers.py b/py/workflow/mappers.py index 2f661521..383b1b4a 100644 --- a/py/workflow/mappers.py +++ b/py/workflow/mappers.py @@ -5,72 +5,102 @@ import logging import os import importlib.util import inspect -from typing import Dict, List, Any, Optional, Union, Type, Callable +from typing import Dict, List, Any, Optional, Union, Type, Callable, Tuple logger = logging.getLogger(__name__) # Global mapper registry -_MAPPER_REGISTRY: Dict[str, 'NodeMapper'] = {} +_MAPPER_REGISTRY: Dict[str, Dict] = {} -class NodeMapper: - """Base class for node mappers that define how to extract information from a specific node type""" +# ============================================================================= +# Mapper Definition Functions +# ============================================================================= + +def create_mapper( + node_type: str, + inputs_to_track: List[str], + transform_func: Callable[[Dict], Any] = None +) -> Dict: + """Create a mapper definition for a node type""" + mapper = { + "node_type": node_type, + "inputs_to_track": inputs_to_track, + "transform": transform_func or (lambda inputs: inputs) + } + return mapper + +def register_mapper(mapper: Dict) -> None: + """Register a node mapper in the global registry""" + _MAPPER_REGISTRY[mapper["node_type"]] = mapper + logger.debug(f"Registered mapper for node type: {mapper['node_type']}") + +def get_mapper(node_type: str) -> Optional[Dict]: + """Get a mapper for the specified node type""" + return _MAPPER_REGISTRY.get(node_type) + +def get_all_mappers() -> Dict[str, Dict]: + """Get all registered mappers""" + return _MAPPER_REGISTRY.copy() + +# ============================================================================= +# Node Processing Function +# ============================================================================= + +def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'WorkflowParser') -> Any: + """Process a node using its mapper and extract relevant information""" + node_type = node_data.get("class_type") + mapper = get_mapper(node_type) - def __init__(self, node_type: str, inputs_to_track: List[str]): - self.node_type = node_type - self.inputs_to_track = inputs_to_track - - def process(self, node_id: str, node_data: Dict, workflow: Dict, parser: 'WorkflowParser') -> Any: # type: ignore - """Process the node and extract relevant information""" - result = {} - for input_name in self.inputs_to_track: - if input_name in node_data.get("inputs", {}): - input_value = node_data["inputs"][input_name] - # Check if input is a reference to another node's output - if isinstance(input_value, list) and len(input_value) == 2: - # Format is [node_id, output_slot] - try: - ref_node_id, output_slot = input_value - # Convert node_id to string if it's an integer - if isinstance(ref_node_id, int): - ref_node_id = str(ref_node_id) - - # Recursively process the referenced node - ref_value = parser.process_node(ref_node_id, workflow) - - # Store the processed value - if ref_value is not None: - result[input_name] = ref_value - else: - # If we couldn't get a value from the reference, store the raw value - result[input_name] = input_value - except Exception as e: - logger.error(f"Error processing reference in node {node_id}, input {input_name}: {e}") - # If we couldn't process the reference, store the raw value - result[input_name] = input_value - else: - # Direct value - result[input_name] = input_value + if not mapper: + return None - # Apply any transformations - return self.transform(result) + result = {} - def transform(self, inputs: Dict) -> Any: - """Transform the extracted inputs - override in subclasses""" - return inputs + # Extract inputs based on the mapper's tracked inputs + for input_name in mapper["inputs_to_track"]: + if input_name in node_data.get("inputs", {}): + input_value = node_data["inputs"][input_name] + + # Check if input is a reference to another node's output + if isinstance(input_value, list) and len(input_value) == 2: + try: + # Format is [node_id, output_slot] + ref_node_id, output_slot = input_value + # Convert node_id to string if it's an integer + if isinstance(ref_node_id, int): + ref_node_id = str(ref_node_id) + + # Recursively process the referenced node + ref_value = parser.process_node(ref_node_id, workflow) + + if ref_value is not None: + result[input_name] = ref_value + else: + # If we couldn't get a value from the reference, store the raw value + result[input_name] = input_value + except Exception as e: + logger.error(f"Error processing reference in node {node_id}, input {input_name}: {e}") + result[input_name] = input_value + else: + # Direct value + result[input_name] = input_value + + # Apply the transform function + try: + return mapper["transform"](result) + except Exception as e: + logger.error(f"Error in transform function for node {node_id} of type {node_type}: {e}") + return result +# ============================================================================= +# Default Mapper Definitions +# ============================================================================= -class KSamplerMapper(NodeMapper): - """Mapper for KSampler nodes""" +def register_default_mappers() -> None: + """Register all default mappers""" - def __init__(self): - super().__init__( - node_type="KSampler", - inputs_to_track=["seed", "steps", "cfg", "sampler_name", "scheduler", - "denoise", "positive", "negative", "latent_image", - "model", "clip_skip"] - ) - - def transform(self, inputs: Dict) -> Dict: + # KSampler mapper + def transform_ksampler(inputs: Dict) -> Dict: result = { "seed": str(inputs.get("seed", "")), "steps": str(inputs.get("steps", "")), @@ -99,70 +129,52 @@ class KSamplerMapper(NodeMapper): result["clip_skip"] = str(inputs.get("clip_skip", "")) return result - - -class EmptyLatentImageMapper(NodeMapper): - """Mapper for EmptyLatentImage nodes""" - def __init__(self): - super().__init__( - node_type="EmptyLatentImage", - inputs_to_track=["width", "height", "batch_size"] - ) + register_mapper(create_mapper( + node_type="KSampler", + inputs_to_track=["seed", "steps", "cfg", "sampler_name", "scheduler", + "denoise", "positive", "negative", "latent_image", + "model", "clip_skip"], + transform_func=transform_ksampler + )) - def transform(self, inputs: Dict) -> Dict: + # EmptyLatentImage mapper + def transform_empty_latent(inputs: Dict) -> Dict: width = inputs.get("width", 0) height = inputs.get("height", 0) return {"width": width, "height": height, "size": f"{width}x{height}"} - - -class EmptySD3LatentImageMapper(NodeMapper): - """Mapper for EmptySD3LatentImage nodes""" - def __init__(self): - super().__init__( - node_type="EmptySD3LatentImage", - inputs_to_track=["width", "height", "batch_size"] - ) + register_mapper(create_mapper( + node_type="EmptyLatentImage", + inputs_to_track=["width", "height", "batch_size"], + transform_func=transform_empty_latent + )) - def transform(self, inputs: Dict) -> Dict: - width = inputs.get("width", 0) - height = inputs.get("height", 0) - return {"width": width, "height": height, "size": f"{width}x{height}"} - - -class CLIPTextEncodeMapper(NodeMapper): - """Mapper for CLIPTextEncode nodes""" + # SD3LatentImage mapper - reuses same transform function as EmptyLatentImage + register_mapper(create_mapper( + node_type="EmptySD3LatentImage", + inputs_to_track=["width", "height", "batch_size"], + transform_func=transform_empty_latent + )) - def __init__(self): - super().__init__( - node_type="CLIPTextEncode", - inputs_to_track=["text", "clip"] - ) - - def transform(self, inputs: Dict) -> Any: - # Simply return the text + # CLIPTextEncode mapper + def transform_clip_text(inputs: Dict) -> Any: return inputs.get("text", "") - - -class LoraLoaderMapper(NodeMapper): - """Mapper for LoraLoader nodes""" - def __init__(self): - super().__init__( - node_type="Lora Loader (LoraManager)", - inputs_to_track=["loras", "lora_stack"] - ) + register_mapper(create_mapper( + node_type="CLIPTextEncode", + inputs_to_track=["text", "clip"], + transform_func=transform_clip_text + )) - def transform(self, inputs: Dict) -> Dict: - # Fallback to loras array if text field doesn't exist or is invalid + # LoraLoader mapper + def transform_lora_loader(inputs: Dict) -> Dict: loras_data = inputs.get("loras", []) lora_stack = inputs.get("lora_stack", {}).get("lora_stack", []) - # Process loras array - filter active entries lora_texts = [] - # Check if loras_data is a list or a dict with __value__ key (new format) + # Process loras array if isinstance(loras_data, dict) and "__value__" in loras_data: loras_list = loras_data["__value__"] elif isinstance(loras_data, list): @@ -172,42 +184,29 @@ class LoraLoaderMapper(NodeMapper): # Process each active lora entry for lora in loras_list: - logger.info(f"Lora: {lora}, active: {lora.get('active')}") if isinstance(lora, dict) and lora.get("active", False): lora_name = lora.get("name", "") strength = lora.get("strength", 1.0) lora_texts.append(f"") - # Process lora_stack if it exists and is a valid format (list of tuples) + # Process lora_stack if valid if lora_stack and isinstance(lora_stack, list): - # If lora_stack is a reference to another node ([node_id, output_slot]), - # we don't process it here as it's already been processed recursively - if len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int): - # This is a reference to another node, already processed - pass - else: - # Format each entry from the stack (assuming it's a list of tuples) + if not (len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int)): for stack_entry in lora_stack: lora_name = stack_entry[0] strength = stack_entry[1] lora_texts.append(f"") - # Join with spaces - combined_text = " ".join(lora_texts) - - return {"loras": combined_text} - - -class LoraStackerMapper(NodeMapper): - """Mapper for LoraStacker nodes""" + return {"loras": " ".join(lora_texts)} - def __init__(self): - super().__init__( - node_type="Lora Stacker (LoraManager)", - inputs_to_track=["loras", "lora_stack"] - ) + register_mapper(create_mapper( + node_type="Lora Loader (LoraManager)", + inputs_to_track=["loras", "lora_stack"], + transform_func=transform_lora_loader + )) - def transform(self, inputs: Dict) -> Dict: + # LoraStacker mapper + def transform_lora_stacker(inputs: Dict) -> Dict: loras_data = inputs.get("loras", []) result_stack = [] @@ -215,25 +214,18 @@ class LoraStackerMapper(NodeMapper): existing_stack = [] lora_stack_input = inputs.get("lora_stack", []) - # Handle different formats of lora_stack if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input: - # Format from another LoraStacker node existing_stack = lora_stack_input["lora_stack"] elif isinstance(lora_stack_input, list): - # Direct list format or reference format [node_id, output_slot] - if len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and isinstance(lora_stack_input[1], int): - # This is likely a reference that was already processed - pass - else: - # Regular list of tuples/entries + if not (len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and + isinstance(lora_stack_input[1], int)): existing_stack = lora_stack_input - # Add existing entries first + # Add existing entries if existing_stack: result_stack.extend(existing_stack) - # Process loras array - filter active entries - # Check if loras_data is a list or a dict with __value__ key (new format) + # Process new loras if isinstance(loras_data, dict) and "__value__" in loras_data: loras_list = loras_data["__value__"] elif isinstance(loras_data, list): @@ -241,7 +233,6 @@ class LoraStackerMapper(NodeMapper): else: loras_list = [] - # Process each active lora entry for lora in loras_list: if isinstance(lora, dict) and lora.get("active", False): lora_name = lora.get("name", "") @@ -249,50 +240,40 @@ class LoraStackerMapper(NodeMapper): result_stack.append((lora_name, strength)) return {"lora_stack": result_stack} - - -class JoinStringsMapper(NodeMapper): - """Mapper for JoinStrings nodes""" - def __init__(self): - super().__init__( - node_type="JoinStrings", - inputs_to_track=["string1", "string2", "delimiter"] - ) + register_mapper(create_mapper( + node_type="Lora Stacker (LoraManager)", + inputs_to_track=["loras", "lora_stack"], + transform_func=transform_lora_stacker + )) - def transform(self, inputs: Dict) -> str: + # JoinStrings mapper + def transform_join_strings(inputs: Dict) -> str: string1 = inputs.get("string1", "") string2 = inputs.get("string2", "") delimiter = inputs.get("delimiter", "") return f"{string1}{delimiter}{string2}" - - -class StringConstantMapper(NodeMapper): - """Mapper for StringConstant and StringConstantMultiline nodes""" - def __init__(self): - super().__init__( - node_type="StringConstantMultiline", - inputs_to_track=["string"] - ) + register_mapper(create_mapper( + node_type="JoinStrings", + inputs_to_track=["string1", "string2", "delimiter"], + transform_func=transform_join_strings + )) - def transform(self, inputs: Dict) -> str: + # StringConstant mapper + def transform_string_constant(inputs: Dict) -> str: return inputs.get("string", "") - - -class TriggerWordToggleMapper(NodeMapper): - """Mapper for TriggerWordToggle nodes""" - def __init__(self): - super().__init__( - node_type="TriggerWord Toggle (LoraManager)", - inputs_to_track=["toggle_trigger_words"] - ) + register_mapper(create_mapper( + node_type="StringConstantMultiline", + inputs_to_track=["string"], + transform_func=transform_string_constant + )) - def transform(self, inputs: Dict) -> str: + # TriggerWordToggle mapper + def transform_trigger_word_toggle(inputs: Dict) -> str: toggle_data = inputs.get("toggle_trigger_words", []) - # check if toggle_words is a list or a dict with __value__ key (new format) if isinstance(toggle_data, dict) and "__value__" in toggle_data: toggle_words = toggle_data["__value__"] elif isinstance(toggle_data, list): @@ -308,28 +289,21 @@ class TriggerWordToggleMapper(NodeMapper): if word and not word.startswith("__dummy"): active_words.append(word) - # Join with commas - result = ", ".join(active_words) - return result - - -class FluxGuidanceMapper(NodeMapper): - """Mapper for FluxGuidance nodes""" + return ", ".join(active_words) - def __init__(self): - super().__init__( - node_type="FluxGuidance", - inputs_to_track=["guidance", "conditioning"] - ) + register_mapper(create_mapper( + node_type="TriggerWord Toggle (LoraManager)", + inputs_to_track=["toggle_trigger_words"], + transform_func=transform_trigger_word_toggle + )) - def transform(self, inputs: Dict) -> Dict: + # FluxGuidance mapper + def transform_flux_guidance(inputs: Dict) -> Dict: result = {} - # Handle guidance parameter if "guidance" in inputs: result["guidance"] = inputs["guidance"] - # Handle conditioning (the prompt text) if "conditioning" in inputs: conditioning = inputs["conditioning"] if isinstance(conditioning, str): @@ -338,42 +312,12 @@ class FluxGuidanceMapper(NodeMapper): result["prompt"] = "Unknown prompt" return result - - -# ============================================================================= -# Mapper Registry Functions -# ============================================================================= - -def register_mapper(mapper: NodeMapper) -> None: - """Register a node mapper in the global registry""" - _MAPPER_REGISTRY[mapper.node_type] = mapper - logger.debug(f"Registered mapper for node type: {mapper.node_type}") - -def get_mapper(node_type: str) -> Optional[NodeMapper]: - """Get a mapper for the specified node type""" - return _MAPPER_REGISTRY.get(node_type) - -def get_all_mappers() -> Dict[str, NodeMapper]: - """Get all registered mappers""" - return _MAPPER_REGISTRY.copy() - -def register_default_mappers() -> None: - """Register all default mappers""" - default_mappers = [ - KSamplerMapper(), - EmptyLatentImageMapper(), - EmptySD3LatentImageMapper(), - CLIPTextEncodeMapper(), - LoraLoaderMapper(), - LoraStackerMapper(), - JoinStringsMapper(), - StringConstantMapper(), - TriggerWordToggleMapper(), - FluxGuidanceMapper() - ] - for mapper in default_mappers: - register_mapper(mapper) + register_mapper(create_mapper( + node_type="FluxGuidance", + inputs_to_track=["guidance", "conditioning"], + transform_func=transform_flux_guidance + )) # ============================================================================= # Extension Loading @@ -383,8 +327,8 @@ def load_extensions(ext_dir: str = None) -> None: """ Load mapper extensions from the specified directory - Each Python file in the directory will be loaded, and any NodeMapper subclasses - defined in those files will be automatically registered. + Extension files should define mappers using the create_mapper function + and then call register_mapper to add them to the registry. """ # Use default path if none provided if ext_dir is None: @@ -410,19 +354,9 @@ def load_extensions(ext_dir: str = None) -> None: if spec and spec.loader: module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - - # Find all NodeMapper subclasses in the module - for name, obj in inspect.getmembers(module): - if (inspect.isclass(obj) and issubclass(obj, NodeMapper) - and obj != NodeMapper and hasattr(obj, 'node_type')): - # Instantiate and register the mapper - mapper = obj() - register_mapper(mapper) - logger.info(f"Loaded extension mapper: {mapper.node_type} from {filename}") - + logger.info(f"Loaded extension module: {filename}") except Exception as e: logger.warning(f"Error loading extension {filename}: {e}") - # Initialize the registry with default mappers register_default_mappers() \ No newline at end of file diff --git a/py/workflow/parser.py b/py/workflow/parser.py index 70b40edd..a289deac 100644 --- a/py/workflow/parser.py +++ b/py/workflow/parser.py @@ -4,7 +4,7 @@ Main workflow parser implementation for ComfyUI import json import logging from typing import Dict, List, Any, Optional, Union, Set -from .mappers import get_mapper, get_all_mappers, load_extensions +from .mappers import get_mapper, get_all_mappers, load_extensions, process_node from .utils import ( load_workflow, save_output, find_node_by_type, trace_model_path @@ -45,10 +45,9 @@ class WorkflowParser: node_type = node_data.get("class_type") result = None - mapper = get_mapper(node_type) - if mapper: + if get_mapper(node_type): try: - result = mapper.process(node_id, node_data, workflow, self) + result = process_node(node_id, node_data, workflow, self) # Cache the result self.node_results_cache[node_id] = result except Exception as e: diff --git a/refs/output.json b/refs/output.json index 9bfc59f5..775870c5 100644 --- a/refs/output.json +++ b/refs/output.json @@ -1,13 +1,11 @@ { "loras": " ", - "gen_params": { - "prompt": "in the style of ck-rw, aorun, scales, makeup, bare shoulders, pointy ears, dress, claws, in the style of cksc, artist:moriimee, in the style of cknc, masterpiece, best quality, good quality, very aesthetic, absurdres, newest, 8K, depth of field, focused subject, close up, stylized, in gold and neon shades, wabi sabi, 1girl, rainbow angel wings, looking at viewer, dynamic angle, from below, from side, relaxing", - "negative_prompt": "bad quality, worst quality, worst detail, sketch ,signature, watermark, patreon logo, nsfw", - "steps": "20", - "sampler": "euler_ancestral", - "cfg_scale": "8", - "seed": "241", - "size": "832x1216", - "clip_skip": "2" - } + "prompt": "in the style of ck-rw, aorun, scales, makeup, bare shoulders, pointy ears, dress, claws, in the style of cksc, artist:moriimee, in the style of cknc, masterpiece, best quality, good quality, very aesthetic, absurdres, newest, 8K, depth of field, focused subject, close up, stylized, in gold and neon shades, wabi sabi, 1girl, rainbow angel wings, looking at viewer, dynamic angle, from below, from side, relaxing", + "negative_prompt": "bad quality, worst quality, worst detail, sketch ,signature, watermark, patreon logo, nsfw", + "steps": "20", + "sampler": "euler_ancestral", + "cfg_scale": "8", + "seed": "241", + "size": "832x1216", + "clip_skip": "2" } \ No newline at end of file diff --git a/refs/prompt.json b/refs/prompt.json index 531b34ca..db03d459 100644 --- a/refs/prompt.json +++ b/refs/prompt.json @@ -1,7 +1,7 @@ { "3": { "inputs": { - "seed": 241, + "seed": 42, "steps": 20, "cfg": 8, "sampler_name": "euler_ancestral", @@ -121,7 +121,7 @@ }, "21": { "inputs": { - "string": "masterpiece, best quality, good quality, very aesthetic, absurdres, newest, 8K, depth of field, focused subject, close up, stylized, in gold and neon shades, wabi sabi, 1girl, rainbow angel wings, looking at viewer, dynamic angle, from below, from side, relaxing", + "string": "masterpiece, best quality, good quality, very awa, newest, highres, absurdres, 1girl, solo, dress, standing, flower, outdoors, water, white flower, pink flower, scenery, reflection, rain, dark, ripples, yellow flower, puddle, colorful, abstract, standing on liquidi¼Œ\nvery Wide Shot, limited palette,", "strip_newlines": false }, "class_type": "StringConstantMultiline", @@ -151,15 +151,19 @@ "group_mode": true, "toggle_trigger_words": [ { - "text": "in the style of ck-rw", + "text": "xxx667_illu", "active": true }, { - "text": "in the style of cksc", + "text": "glowing", "active": true }, { - "text": "artist:moriimee", + "text": "glitch", + "active": true + }, + { + "text": "15546+456868", "active": true }, { @@ -173,7 +177,7 @@ "_isDummy": true } ], - "orinalMessage": "in the style of ck-rw,, in the style of cksc,, artist:moriimee", + "orinalMessage": "xxx667_illu,, glowing,, glitch,, 15546+456868", "trigger_words": [ "56", 2 @@ -186,22 +190,32 @@ }, "56": { "inputs": { - "text": " ", + "text": " ", "loras": [ { - "name": "ck-shadow-circuit-IL-000012", - "strength": 0.78, + "name": "ponyv6_noobE11_2_adamW-000017", + "strength": 0.3, "active": true }, { - "name": "MoriiMee_Gothic_Niji_Style_Illustrious_r1", - "strength": 0.45, - "active": true - }, - { - "name": "ck-nc-cyberpunk-IL-000011", + "name": "XXX667", "strength": 0.4, - "active": false + "active": true + }, + { + "name": "114558v4df2fsdf5", + "strength": 0.6, + "active": true + }, + { + "name": "illustriousXL_stabilizer_v1.23", + "strength": 0.3, + "active": true + }, + { + "name": "mon_monmon2133", + "strength": 0.5, + "active": true }, { "name": "__dummy_item1__", @@ -273,7 +287,7 @@ { "name": "ck-neon-retrowave-IL-000012", "strength": 0.8, - "active": true + "active": false }, { "name": "__dummy_item1__", diff --git a/web/comfyui/loras_widget.js b/web/comfyui/loras_widget.js index d8ea4d29..0e6c0129 100644 --- a/web/comfyui/loras_widget.js +++ b/web/comfyui/loras_widget.js @@ -824,7 +824,7 @@ async function saveRecipeDirectly(widget) { try { // Get the workflow data from the ComfyUI app const prompt = await app.graphToPrompt(); - console.log('Prompt:', prompt.output); + console.log('Prompt:', prompt); // Show loading toast if (app && app.extensionManager && app.extensionManager.toast) { From 195866b00d2c002fb9c7e578e2e6d798300f06e6 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Tue, 1 Apr 2025 16:22:57 +0800 Subject: [PATCH 02/18] Implement KJNodes extension with new mappers and transform functions - Added KJNodes mappers for JoinStrings, StringConstantMultiline, and EmptyLatentImagePresets. - Introduced transform functions to handle string joining, string constants, and dimension extraction with optional inversion. - Registered new mappers and logged successful registration for better traceability. --- py/workflow/ext/kjnodes.py | 81 +++++++ py/workflow/mappers.py | 417 ++++++++++++++++++------------------- 2 files changed, 280 insertions(+), 218 deletions(-) create mode 100644 py/workflow/ext/kjnodes.py diff --git a/py/workflow/ext/kjnodes.py b/py/workflow/ext/kjnodes.py new file mode 100644 index 00000000..76e45edf --- /dev/null +++ b/py/workflow/ext/kjnodes.py @@ -0,0 +1,81 @@ +""" +KJNodes mappers extension for ComfyUI workflow parsing +""" +import logging +import re +from typing import Dict, Any + +logger = logging.getLogger(__name__) + +# Import the mapper registration functions from the parent module +from workflow.mappers import create_mapper, register_mapper + +# ============================================================================= +# Transform Functions +# ============================================================================= + +def transform_join_strings(inputs: Dict) -> str: + """Transform function for JoinStrings nodes""" + string1 = inputs.get("string1", "") + string2 = inputs.get("string2", "") + delimiter = inputs.get("delimiter", "") + return f"{string1}{delimiter}{string2}" + +def transform_string_constant(inputs: Dict) -> str: + """Transform function for StringConstant nodes""" + return inputs.get("string", "") + +def transform_empty_latent_presets(inputs: Dict) -> Dict: + """Transform function for EmptyLatentImagePresets nodes""" + dimensions = inputs.get("dimensions", "") + invert = inputs.get("invert", False) + + # Extract width and height from dimensions string + # Expected format: "width x height (ratio)" or similar + width = 0 + height = 0 + + if dimensions: + # Try to extract dimensions using regex + match = re.search(r'(\d+)\s*x\s*(\d+)', dimensions) + if match: + width = int(match.group(1)) + height = int(match.group(2)) + + # If invert is True, swap width and height + if invert and width and height: + width, height = height, width + + return {"width": width, "height": height, "size": f"{width}x{height}"} + +# ============================================================================= +# Register Mappers +# ============================================================================= + +# Define the mappers for KJNodes +KJNODES_MAPPERS = { + "JoinStrings": { + "inputs_to_track": ["string1", "string2", "delimiter"], + "transform_func": transform_join_strings + }, + "StringConstantMultiline": { + "inputs_to_track": ["string"], + "transform_func": transform_string_constant + }, + "EmptyLatentImagePresets": { + "inputs_to_track": ["dimensions", "invert", "batch_size"], + "transform_func": transform_empty_latent_presets + } +} + +# Register all KJNodes mappers +for node_type, config in KJNODES_MAPPERS.items(): + mapper = create_mapper( + node_type=node_type, + inputs_to_track=config["inputs_to_track"], + transform_func=config["transform_func"] + ) + register_mapper(mapper) + logger.info(f"Registered KJNodes mapper for node type: {node_type}") + +logger.info(f"Loaded KJNodes extension with {len(KJNODES_MAPPERS)} mappers") \ No newline at end of file diff --git a/py/workflow/mappers.py b/py/workflow/mappers.py index 383b1b4a..96fb02a2 100644 --- a/py/workflow/mappers.py +++ b/py/workflow/mappers.py @@ -52,6 +52,7 @@ def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'Workflo mapper = get_mapper(node_type) if not mapper: + logger.warning(f"No mapper found for node type: {node_type}") return None result = {} @@ -93,231 +94,211 @@ def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'Workflo return result # ============================================================================= -# Default Mapper Definitions +# Transform Functions # ============================================================================= -def register_default_mappers() -> None: - """Register all default mappers""" +def transform_ksampler(inputs: Dict) -> Dict: + """Transform function for KSampler nodes""" + result = { + "seed": str(inputs.get("seed", "")), + "steps": str(inputs.get("steps", "")), + "cfg": str(inputs.get("cfg", "")), + "sampler": inputs.get("sampler_name", ""), + "scheduler": inputs.get("scheduler", ""), + } - # KSampler mapper - def transform_ksampler(inputs: Dict) -> Dict: - result = { - "seed": str(inputs.get("seed", "")), - "steps": str(inputs.get("steps", "")), - "cfg": str(inputs.get("cfg", "")), - "sampler": inputs.get("sampler_name", ""), - "scheduler": inputs.get("scheduler", ""), - } + # Process positive prompt + if "positive" in inputs: + result["prompt"] = inputs["positive"] + + # Process negative prompt + if "negative" in inputs: + result["negative_prompt"] = inputs["negative"] + + # Get dimensions from latent image + if "latent_image" in inputs and isinstance(inputs["latent_image"], dict): + width = inputs["latent_image"].get("width", 0) + height = inputs["latent_image"].get("height", 0) + if width and height: + result["size"] = f"{width}x{height}" + + # Add clip_skip if present + if "clip_skip" in inputs: + result["clip_skip"] = str(inputs.get("clip_skip", "")) - # Process positive prompt - if "positive" in inputs: - result["prompt"] = inputs["positive"] - - # Process negative prompt - if "negative" in inputs: - result["negative_prompt"] = inputs["negative"] - - # Get dimensions from latent image - if "latent_image" in inputs and isinstance(inputs["latent_image"], dict): - width = inputs["latent_image"].get("width", 0) - height = inputs["latent_image"].get("height", 0) - if width and height: - result["size"] = f"{width}x{height}" - - # Add clip_skip if present - if "clip_skip" in inputs: - result["clip_skip"] = str(inputs.get("clip_skip", "")) - - return result - - register_mapper(create_mapper( - node_type="KSampler", - inputs_to_track=["seed", "steps", "cfg", "sampler_name", "scheduler", - "denoise", "positive", "negative", "latent_image", - "model", "clip_skip"], - transform_func=transform_ksampler - )) - - # EmptyLatentImage mapper - def transform_empty_latent(inputs: Dict) -> Dict: - width = inputs.get("width", 0) - height = inputs.get("height", 0) - return {"width": width, "height": height, "size": f"{width}x{height}"} - - register_mapper(create_mapper( - node_type="EmptyLatentImage", - inputs_to_track=["width", "height", "batch_size"], - transform_func=transform_empty_latent - )) - - # SD3LatentImage mapper - reuses same transform function as EmptyLatentImage - register_mapper(create_mapper( - node_type="EmptySD3LatentImage", - inputs_to_track=["width", "height", "batch_size"], - transform_func=transform_empty_latent - )) - - # CLIPTextEncode mapper - def transform_clip_text(inputs: Dict) -> Any: - return inputs.get("text", "") - - register_mapper(create_mapper( - node_type="CLIPTextEncode", - inputs_to_track=["text", "clip"], - transform_func=transform_clip_text - )) - - # LoraLoader mapper - def transform_lora_loader(inputs: Dict) -> Dict: - loras_data = inputs.get("loras", []) - lora_stack = inputs.get("lora_stack", {}).get("lora_stack", []) - - lora_texts = [] - - # Process loras array - if isinstance(loras_data, dict) and "__value__" in loras_data: - loras_list = loras_data["__value__"] - elif isinstance(loras_data, list): - loras_list = loras_data - else: - loras_list = [] - - # Process each active lora entry - for lora in loras_list: - if isinstance(lora, dict) and lora.get("active", False): - lora_name = lora.get("name", "") - strength = lora.get("strength", 1.0) - lora_texts.append(f"") - - # Process lora_stack if valid - if lora_stack and isinstance(lora_stack, list): - if not (len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int)): - for stack_entry in lora_stack: - lora_name = stack_entry[0] - strength = stack_entry[1] - lora_texts.append(f"") - - return {"loras": " ".join(lora_texts)} - - register_mapper(create_mapper( - node_type="Lora Loader (LoraManager)", - inputs_to_track=["loras", "lora_stack"], - transform_func=transform_lora_loader - )) - - # LoraStacker mapper - def transform_lora_stacker(inputs: Dict) -> Dict: - loras_data = inputs.get("loras", []) - result_stack = [] - - # Handle existing stack entries - existing_stack = [] - lora_stack_input = inputs.get("lora_stack", []) - - if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input: - existing_stack = lora_stack_input["lora_stack"] - elif isinstance(lora_stack_input, list): - if not (len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and - isinstance(lora_stack_input[1], int)): - existing_stack = lora_stack_input - - # Add existing entries - if existing_stack: - result_stack.extend(existing_stack) - - # Process new loras - if isinstance(loras_data, dict) and "__value__" in loras_data: - loras_list = loras_data["__value__"] - elif isinstance(loras_data, list): - loras_list = loras_data - else: - loras_list = [] - - for lora in loras_list: - if isinstance(lora, dict) and lora.get("active", False): - lora_name = lora.get("name", "") - strength = float(lora.get("strength", 1.0)) - result_stack.append((lora_name, strength)) - - return {"lora_stack": result_stack} - - register_mapper(create_mapper( - node_type="Lora Stacker (LoraManager)", - inputs_to_track=["loras", "lora_stack"], - transform_func=transform_lora_stacker - )) - - # JoinStrings mapper - def transform_join_strings(inputs: Dict) -> str: - string1 = inputs.get("string1", "") - string2 = inputs.get("string2", "") - delimiter = inputs.get("delimiter", "") - return f"{string1}{delimiter}{string2}" - - register_mapper(create_mapper( - node_type="JoinStrings", - inputs_to_track=["string1", "string2", "delimiter"], - transform_func=transform_join_strings - )) - - # StringConstant mapper - def transform_string_constant(inputs: Dict) -> str: - return inputs.get("string", "") - - register_mapper(create_mapper( - node_type="StringConstantMultiline", - inputs_to_track=["string"], - transform_func=transform_string_constant - )) - - # TriggerWordToggle mapper - def transform_trigger_word_toggle(inputs: Dict) -> str: - toggle_data = inputs.get("toggle_trigger_words", []) + return result - if isinstance(toggle_data, dict) and "__value__" in toggle_data: - toggle_words = toggle_data["__value__"] - elif isinstance(toggle_data, list): - toggle_words = toggle_data +def transform_empty_latent(inputs: Dict) -> Dict: + """Transform function for EmptyLatentImage nodes""" + width = inputs.get("width", 0) + height = inputs.get("height", 0) + return {"width": width, "height": height, "size": f"{width}x{height}"} + +def transform_clip_text(inputs: Dict) -> Any: + """Transform function for CLIPTextEncode nodes""" + return inputs.get("text", "") + +def transform_lora_loader(inputs: Dict) -> Dict: + """Transform function for LoraLoader nodes""" + loras_data = inputs.get("loras", []) + lora_stack = inputs.get("lora_stack", {}).get("lora_stack", []) + + lora_texts = [] + + # Process loras array + if isinstance(loras_data, dict) and "__value__" in loras_data: + loras_list = loras_data["__value__"] + elif isinstance(loras_data, list): + loras_list = loras_data + else: + loras_list = [] + + # Process each active lora entry + for lora in loras_list: + if isinstance(lora, dict) and lora.get("active", False): + lora_name = lora.get("name", "") + strength = lora.get("strength", 1.0) + lora_texts.append(f"") + + # Process lora_stack if valid + if lora_stack and isinstance(lora_stack, list): + if not (len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int)): + for stack_entry in lora_stack: + lora_name = stack_entry[0] + strength = stack_entry[1] + lora_texts.append(f"") + + return {"loras": " ".join(lora_texts)} + +def transform_lora_stacker(inputs: Dict) -> Dict: + """Transform function for LoraStacker nodes""" + loras_data = inputs.get("loras", []) + result_stack = [] + + # Handle existing stack entries + existing_stack = [] + lora_stack_input = inputs.get("lora_stack", []) + + if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input: + existing_stack = lora_stack_input["lora_stack"] + elif isinstance(lora_stack_input, list): + if not (len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and + isinstance(lora_stack_input[1], int)): + existing_stack = lora_stack_input + + # Add existing entries + if existing_stack: + result_stack.extend(existing_stack) + + # Process new loras + if isinstance(loras_data, dict) and "__value__" in loras_data: + loras_list = loras_data["__value__"] + elif isinstance(loras_data, list): + loras_list = loras_data + else: + loras_list = [] + + for lora in loras_list: + if isinstance(lora, dict) and lora.get("active", False): + lora_name = lora.get("name", "") + strength = float(lora.get("strength", 1.0)) + result_stack.append((lora_name, strength)) + + return {"lora_stack": result_stack} + +def transform_trigger_word_toggle(inputs: Dict) -> str: + """Transform function for TriggerWordToggle nodes""" + toggle_data = inputs.get("toggle_trigger_words", []) + + if isinstance(toggle_data, dict) and "__value__" in toggle_data: + toggle_words = toggle_data["__value__"] + elif isinstance(toggle_data, list): + toggle_words = toggle_data + else: + toggle_words = [] + + # Filter active trigger words + active_words = [] + for item in toggle_words: + if isinstance(item, dict) and item.get("active", False): + word = item.get("text", "") + if word and not word.startswith("__dummy"): + active_words.append(word) + + return ", ".join(active_words) + +def transform_flux_guidance(inputs: Dict) -> Dict: + """Transform function for FluxGuidance nodes""" + result = {} + + if "guidance" in inputs: + result["guidance"] = inputs["guidance"] + + if "conditioning" in inputs: + conditioning = inputs["conditioning"] + if isinstance(conditioning, str): + result["prompt"] = conditioning else: - toggle_words = [] - - # Filter active trigger words - active_words = [] - for item in toggle_words: - if isinstance(item, dict) and item.get("active", False): - word = item.get("text", "") - if word and not word.startswith("__dummy"): - active_words.append(word) - - return ", ".join(active_words) + result["prompt"] = "Unknown prompt" - register_mapper(create_mapper( - node_type="TriggerWord Toggle (LoraManager)", - inputs_to_track=["toggle_trigger_words"], - transform_func=transform_trigger_word_toggle - )) - - # FluxGuidance mapper - def transform_flux_guidance(inputs: Dict) -> Dict: - result = {} - - if "guidance" in inputs: - result["guidance"] = inputs["guidance"] - - if "conditioning" in inputs: - conditioning = inputs["conditioning"] - if isinstance(conditioning, str): - result["prompt"] = conditioning - else: - result["prompt"] = "Unknown prompt" - - return result - - register_mapper(create_mapper( - node_type="FluxGuidance", - inputs_to_track=["guidance", "conditioning"], - transform_func=transform_flux_guidance - )) + return result + +# ============================================================================= +# Node Mapper Definitions +# ============================================================================= + +# Central definition of all supported node types and their configurations +NODE_MAPPERS = { + # ComfyUI core nodes + "KSampler": { + "inputs_to_track": [ + "seed", "steps", "cfg", "sampler_name", "scheduler", + "denoise", "positive", "negative", "latent_image", + "model", "clip_skip" + ], + "transform_func": transform_ksampler + }, + "EmptyLatentImage": { + "inputs_to_track": ["width", "height", "batch_size"], + "transform_func": transform_empty_latent + }, + "EmptySD3LatentImage": { + "inputs_to_track": ["width", "height", "batch_size"], + "transform_func": transform_empty_latent + }, + "CLIPTextEncode": { + "inputs_to_track": ["text", "clip"], + "transform_func": transform_clip_text + }, + "FluxGuidance": { + "inputs_to_track": ["guidance", "conditioning"], + "transform_func": transform_flux_guidance + }, + # LoraManager nodes + "Lora Loader (LoraManager)": { + "inputs_to_track": ["loras", "lora_stack"], + "transform_func": transform_lora_loader + }, + "Lora Stacker (LoraManager)": { + "inputs_to_track": ["loras", "lora_stack"], + "transform_func": transform_lora_stacker + }, + "TriggerWord Toggle (LoraManager)": { + "inputs_to_track": ["toggle_trigger_words"], + "transform_func": transform_trigger_word_toggle + } +} + +def register_default_mappers() -> None: + """Register all default mappers from the NODE_MAPPERS dictionary""" + for node_type, config in NODE_MAPPERS.items(): + mapper = create_mapper( + node_type=node_type, + inputs_to_track=config["inputs_to_track"], + transform_func=config["transform_func"] + ) + register_mapper(mapper) + logger.info(f"Registered {len(NODE_MAPPERS)} default node mappers") # ============================================================================= # Extension Loading From 27db60ce681a19314426f0fb79395dcb7ad41f33 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Tue, 1 Apr 2025 19:17:43 +0800 Subject: [PATCH 03/18] checkpoint --- .vscode/launch.json | 15 + py/workflow/ext/comfyui_core.py | 169 ++++++++++ py/workflow/ext/example_mapper.py | 54 --- py/workflow/ext/kjnodes.py | 8 + py/workflow/parser.py | 89 ++++- refs/prompt.json | 536 +++++++++++++++++++----------- 6 files changed, 601 insertions(+), 270 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 py/workflow/ext/comfyui_core.py delete mode 100644 py/workflow/ext/example_mapper.py diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..6b76b4fa --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,15 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + } + ] +} \ No newline at end of file diff --git a/py/workflow/ext/comfyui_core.py b/py/workflow/ext/comfyui_core.py new file mode 100644 index 00000000..59713ee8 --- /dev/null +++ b/py/workflow/ext/comfyui_core.py @@ -0,0 +1,169 @@ +""" +ComfyUI Core nodes mappers extension for workflow parsing +""" +import logging +from typing import Dict, Any, List + +logger = logging.getLogger(__name__) + +# Import the mapper registration functions from the parent module +from workflow.mappers import create_mapper, register_mapper + +# ============================================================================= +# Transform Functions +# ============================================================================= + +def transform_random_noise(inputs: Dict) -> Dict: + """Transform function for RandomNoise node""" + return {"seed": str(inputs.get("noise_seed", ""))} + +def transform_ksampler_select(inputs: Dict) -> Dict: + """Transform function for KSamplerSelect node""" + return {"sampler": inputs.get("sampler_name", "")} + +def transform_basic_scheduler(inputs: Dict) -> Dict: + """Transform function for BasicScheduler node""" + result = { + "scheduler": inputs.get("scheduler", ""), + "denoise": str(inputs.get("denoise", "1.0")) + } + + # Get steps from inputs or steps input + if "steps" in inputs: + if isinstance(inputs["steps"], str): + result["steps"] = inputs["steps"] + elif isinstance(inputs["steps"], dict) and "value" in inputs["steps"]: + result["steps"] = str(inputs["steps"]["value"]) + else: + result["steps"] = str(inputs["steps"]) + + return result + +def transform_basic_guider(inputs: Dict) -> Dict: + """Transform function for BasicGuider node""" + result = {} + + # Process conditioning + if "conditioning" in inputs: + if isinstance(inputs["conditioning"], str): + result["prompt"] = inputs["conditioning"] + elif isinstance(inputs["conditioning"], dict): + result["conditioning"] = inputs["conditioning"] + + # Get model information if needed + if "model" in inputs and isinstance(inputs["model"], dict): + if "loras" in inputs["model"]: + result["loras"] = inputs["model"]["loras"] + + return result + +def transform_model_sampling_flux(inputs: Dict) -> Dict: + """Transform function for ModelSamplingFlux - mostly a pass-through node""" + # This node is primarily used for routing, so we mostly pass through values + result = {} + + # Extract any dimensions if present + width = inputs.get("width", 0) + height = inputs.get("height", 0) + if width and height: + result["width"] = width + result["height"] = height + result["size"] = f"{width}x{height}" + + # Pass through model information + if "model" in inputs and isinstance(inputs["model"], dict): + for key, value in inputs["model"].items(): + result[key] = value + + return result + +def transform_sampler_custom_advanced(inputs: Dict) -> Dict: + """Transform function for SamplerCustomAdvanced node""" + result = {} + + # Extract seed from noise + if "noise" in inputs and isinstance(inputs["noise"], dict): + result["seed"] = str(inputs["noise"].get("seed", "")) + + # Extract sampler info + if "sampler" in inputs and isinstance(inputs["sampler"], dict): + sampler = inputs["sampler"].get("sampler", "") + if sampler: + result["sampler"] = sampler + + # Extract scheduler, steps, denoise from sigmas + if "sigmas" in inputs and isinstance(inputs["sigmas"], dict): + sigmas = inputs["sigmas"] + result["scheduler"] = sigmas.get("scheduler", "") + result["steps"] = str(sigmas.get("steps", "")) + result["denoise"] = str(sigmas.get("denoise", "1.0")) + + # Extract prompt and guidance from guider + if "guider" in inputs and isinstance(inputs["guider"], dict): + guider = inputs["guider"] + + # Get prompt from conditioning + if "conditioning" in guider and isinstance(guider["conditioning"], str): + result["prompt"] = guider["conditioning"] + elif "conditioning" in guider and isinstance(guider["conditioning"], dict): + result["guidance"] = guider["conditioning"].get("guidance", "") + result["prompt"] = guider["conditioning"].get("prompt", "") + + if "model" in guider and isinstance(guider["model"], dict): + result["loras"] = guider["model"].get("loras", "") + + # Extract dimensions from latent_image + if "latent_image" in inputs and isinstance(inputs["latent_image"], dict): + latent = inputs["latent_image"] + width = latent.get("width", 0) + height = latent.get("height", 0) + if width and height: + result["width"] = width + result["height"] = height + result["size"] = f"{width}x{height}" + + return result + +# ============================================================================= +# Register Mappers +# ============================================================================= + +# Define the mappers for ComfyUI core nodes not in main mapper +COMFYUI_CORE_MAPPERS = { + "RandomNoise": { + "inputs_to_track": ["noise_seed"], + "transform_func": transform_random_noise + }, + "KSamplerSelect": { + "inputs_to_track": ["sampler_name"], + "transform_func": transform_ksampler_select + }, + "BasicScheduler": { + "inputs_to_track": ["scheduler", "steps", "denoise", "model"], + "transform_func": transform_basic_scheduler + }, + "BasicGuider": { + "inputs_to_track": ["model", "conditioning"], + "transform_func": transform_basic_guider + }, + "ModelSamplingFlux": { + "inputs_to_track": ["max_shift", "base_shift", "width", "height", "model"], + "transform_func": transform_model_sampling_flux + }, + "SamplerCustomAdvanced": { + "inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"], + "transform_func": transform_sampler_custom_advanced + } +} + +# Register all ComfyUI core mappers +for node_type, config in COMFYUI_CORE_MAPPERS.items(): + mapper = create_mapper( + node_type=node_type, + inputs_to_track=config["inputs_to_track"], + transform_func=config["transform_func"] + ) + register_mapper(mapper) + logger.info(f"Registered ComfyUI core mapper for node type: {node_type}") + +logger.info(f"Loaded ComfyUI core extension with {len(COMFYUI_CORE_MAPPERS)} mappers") \ No newline at end of file diff --git a/py/workflow/ext/example_mapper.py b/py/workflow/ext/example_mapper.py deleted file mode 100644 index 652be09e..00000000 --- a/py/workflow/ext/example_mapper.py +++ /dev/null @@ -1,54 +0,0 @@ -""" -Example extension mapper for demonstrating the extension system -""" -from typing import Dict, Any -from ..mappers import NodeMapper - -class ExampleNodeMapper(NodeMapper): - """Example mapper for custom nodes""" - - def __init__(self): - super().__init__( - node_type="ExampleCustomNode", - inputs_to_track=["param1", "param2", "image"] - ) - - def transform(self, inputs: Dict) -> Dict: - """Transform extracted inputs into the desired output format""" - result = {} - - # Extract interesting parameters - if "param1" in inputs: - result["example_param1"] = inputs["param1"] - - if "param2" in inputs: - result["example_param2"] = inputs["param2"] - - # You can process the data in any way needed - return result - - -class VAEMapperExtension(NodeMapper): - """Extension mapper for VAE nodes""" - - def __init__(self): - super().__init__( - node_type="VAELoader", - inputs_to_track=["vae_name"] - ) - - def transform(self, inputs: Dict) -> Dict: - """Extract VAE information""" - vae_name = inputs.get("vae_name", "") - - # Remove path prefix if present - if "/" in vae_name or "\\" in vae_name: - # Get just the filename without path or extension - vae_name = vae_name.replace("\\", "/").split("/")[-1] - vae_name = vae_name.split(".")[0] # Remove extension - - return {"vae": vae_name} - - -# Note: No need to register manually - extensions are automatically registered -# when the extension system loads this file \ No newline at end of file diff --git a/py/workflow/ext/kjnodes.py b/py/workflow/ext/kjnodes.py index 76e45edf..ecab4cfd 100644 --- a/py/workflow/ext/kjnodes.py +++ b/py/workflow/ext/kjnodes.py @@ -48,6 +48,10 @@ def transform_empty_latent_presets(inputs: Dict) -> Dict: return {"width": width, "height": height, "size": f"{width}x{height}"} +def transform_int_constant(inputs: Dict) -> int: + """Transform function for INTConstant nodes""" + return inputs.get("value", 0) + # ============================================================================= # Register Mappers # ============================================================================= @@ -65,6 +69,10 @@ KJNODES_MAPPERS = { "EmptyLatentImagePresets": { "inputs_to_track": ["dimensions", "invert", "batch_size"], "transform_func": transform_empty_latent_presets + }, + "INTConstant": { + "inputs_to_track": ["value"], + "transform_func": transform_int_constant } } diff --git a/py/workflow/parser.py b/py/workflow/parser.py index a289deac..e57d8ce4 100644 --- a/py/workflow/parser.py +++ b/py/workflow/parser.py @@ -59,6 +59,59 @@ class WorkflowParser: self.processed_nodes.remove(node_id) return result + def find_primary_sampler_node(self, workflow: Dict) -> Optional[str]: + """ + Find the primary sampler node in the workflow. + + Priority: + 1. First try to find a SamplerCustomAdvanced node + 2. If not found, look for KSampler nodes with denoise=1.0 + 3. If still not found, use the first KSampler node + + Args: + workflow: The workflow data as a dictionary + + Returns: + The node ID of the primary sampler node, or None if not found + """ + # First check for SamplerCustomAdvanced nodes + sampler_advanced_nodes = [] + ksampler_nodes = [] + + # Scan workflow for sampler nodes + for node_id, node_data in workflow.items(): + node_type = node_data.get("class_type") + + if node_type == "SamplerCustomAdvanced": + sampler_advanced_nodes.append(node_id) + elif node_type == "KSampler": + ksampler_nodes.append(node_id) + + # If we found SamplerCustomAdvanced nodes, return the first one + if sampler_advanced_nodes: + logger.info(f"Found SamplerCustomAdvanced node: {sampler_advanced_nodes[0]}") + return sampler_advanced_nodes[0] + + # If we have KSampler nodes, look for one with denoise=1.0 + if ksampler_nodes: + for node_id in ksampler_nodes: + node_data = workflow[node_id] + inputs = node_data.get("inputs", {}) + denoise = inputs.get("denoise", 0) + + # Check if denoise is 1.0 (allowing for small floating point differences) + if abs(float(denoise) - 1.0) < 0.001: + logger.info(f"Found KSampler node with denoise=1.0: {node_id}") + return node_id + + # If no KSampler with denoise=1.0 found, use the first one + logger.info(f"No KSampler with denoise=1.0 found, using first KSampler: {ksampler_nodes[0]}") + return ksampler_nodes[0] + + # No sampler nodes found + logger.warning("No sampler nodes found in workflow") + return None + def collect_loras_from_model(self, model_input: List, workflow: Dict) -> str: """Collect loras information from the model node chain""" if not isinstance(model_input, list) or len(model_input) != 2: @@ -107,23 +160,23 @@ class WorkflowParser: self.processed_nodes = set() self.node_results_cache = {} - # Find the KSampler node - ksampler_node_id = find_node_by_type(workflow, "KSampler") - if not ksampler_node_id: - logger.warning("No KSampler node found in workflow") + # Find the primary sampler node + sampler_node_id = self.find_primary_sampler_node(workflow) + if not sampler_node_id: + logger.warning("No suitable sampler node found in workflow") return {} - # Start parsing from the KSampler node + # Start parsing from the sampler node result = { "gen_params": {}, "loras": "" } - # Process KSampler node to extract parameters - ksampler_result = self.process_node(ksampler_node_id, workflow) - if ksampler_result: + # Process sampler node to extract parameters + sampler_result = self.process_node(sampler_node_id, workflow) + if sampler_result: # Process the result - for key, value in ksampler_result.items(): + for key, value in sampler_result.items(): # Special handling for the positive prompt from FluxGuidance if key == "positive" and isinstance(value, dict): # Extract guidance value @@ -138,8 +191,8 @@ class WorkflowParser: result["gen_params"][key] = value # Process the positive prompt node if it exists and we don't have a prompt yet - if "prompt" not in result["gen_params"] and "positive" in ksampler_result: - positive_value = ksampler_result.get("positive") + if "prompt" not in result["gen_params"] and "positive" in sampler_result: + positive_value = sampler_result.get("positive") if isinstance(positive_value, str): result["gen_params"]["prompt"] = positive_value @@ -152,11 +205,11 @@ class WorkflowParser: if "guidance" in node_inputs: result["gen_params"]["guidance"] = node_inputs["guidance"] - # Extract loras from the model input of KSampler - ksampler_node = workflow.get(ksampler_node_id, {}) - ksampler_inputs = ksampler_node.get("inputs", {}) - if "model" in ksampler_inputs and isinstance(ksampler_inputs["model"], list): - loras_text = self.collect_loras_from_model(ksampler_inputs["model"], workflow) + # Extract loras from the model input of sampler + sampler_node = workflow.get(sampler_node_id, {}) + sampler_inputs = sampler_node.get("inputs", {}) + if "model" in sampler_inputs and isinstance(sampler_inputs["model"], list): + loras_text = self.collect_loras_from_model(sampler_inputs["model"], workflow) if loras_text: result["loras"] = loras_text @@ -164,9 +217,9 @@ class WorkflowParser: if "cfg" in result["gen_params"]: result["gen_params"]["cfg_scale"] = result["gen_params"].pop("cfg") - # Add clip_skip = 2 to match reference output if not already present + # Add clip_skip = 1 to match reference output if not already present if "clip_skip" not in result["gen_params"]: - result["gen_params"]["clip_skip"] = "2" + result["gen_params"]["clip_skip"] = "1" # Ensure the prompt is a string and not a nested dictionary if "prompt" in result["gen_params"] and isinstance(result["gen_params"]["prompt"], dict): diff --git a/refs/prompt.json b/refs/prompt.json index db03d459..535ddec5 100644 --- a/refs/prompt.json +++ b/refs/prompt.json @@ -1,75 +1,12 @@ { - "3": { - "inputs": { - "seed": 42, - "steps": 20, - "cfg": 8, - "sampler_name": "euler_ancestral", - "scheduler": "karras", - "denoise": 1, - "model": [ - "56", - 0 - ], - "positive": [ - "6", - 0 - ], - "negative": [ - "7", - 0 - ], - "latent_image": [ - "5", - 0 - ] - }, - "class_type": "KSampler", - "_meta": { - "title": "KSampler" - } - }, - "4": { - "inputs": { - "ckpt_name": "il\\waiNSFWIllustrious_v110.safetensors" - }, - "class_type": "CheckpointLoaderSimple", - "_meta": { - "title": "Load Checkpoint" - } - }, - "5": { - "inputs": { - "width": 832, - "height": 1216, - "batch_size": 1 - }, - "class_type": "EmptyLatentImage", - "_meta": { - "title": "Empty Latent Image" - } - }, "6": { "inputs": { "text": [ - "22", + "301", 0 ], "clip": [ - "56", - 1 - ] - }, - "class_type": "CLIPTextEncode", - "_meta": { - "title": "CLIP Text Encode (Prompt)" - } - }, - "7": { - "inputs": { - "text": "bad quality, worst quality, worst detail, sketch ,signature, watermark, patreon logo, nsfw", - "clip": [ - "56", + "299", 1 ] }, @@ -81,12 +18,12 @@ "8": { "inputs": { "samples": [ - "3", - 0 + "13", + 1 ], "vae": [ - "4", - 2 + "10", + 0 ] }, "class_type": "VAEDecode", @@ -94,7 +31,275 @@ "title": "VAE Decode" } }, - "14": { + "10": { + "inputs": { + "vae_name": "flux1\\ae.safetensors" + }, + "class_type": "VAELoader", + "_meta": { + "title": "Load VAE" + } + }, + "11": { + "inputs": { + "clip_name1": "t5xxl_fp8_e4m3fn.safetensors", + "clip_name2": "ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors", + "type": "flux", + "device": "default" + }, + "class_type": "DualCLIPLoader", + "_meta": { + "title": "DualCLIPLoader" + } + }, + "13": { + "inputs": { + "noise": [ + "147", + 0 + ], + "guider": [ + "22", + 0 + ], + "sampler": [ + "16", + 0 + ], + "sigmas": [ + "17", + 0 + ], + "latent_image": [ + "48", + 0 + ] + }, + "class_type": "SamplerCustomAdvanced", + "_meta": { + "title": "SamplerCustomAdvanced" + } + }, + "16": { + "inputs": { + "sampler_name": "dpmpp_2m" + }, + "class_type": "KSamplerSelect", + "_meta": { + "title": "KSamplerSelect" + } + }, + "17": { + "inputs": { + "scheduler": "beta", + "steps": [ + "246", + 0 + ], + "denoise": 1, + "model": [ + "28", + 0 + ] + }, + "class_type": "BasicScheduler", + "_meta": { + "title": "BasicScheduler" + } + }, + "22": { + "inputs": { + "model": [ + "28", + 0 + ], + "conditioning": [ + "29", + 0 + ] + }, + "class_type": "BasicGuider", + "_meta": { + "title": "BasicGuider" + } + }, + "28": { + "inputs": { + "max_shift": 1.1500000000000001, + "base_shift": 0.5, + "width": [ + "48", + 1 + ], + "height": [ + "48", + 2 + ], + "model": [ + "299", + 0 + ] + }, + "class_type": "ModelSamplingFlux", + "_meta": { + "title": "ModelSamplingFlux" + } + }, + "29": { + "inputs": { + "guidance": 3.5, + "conditioning": [ + "6", + 0 + ] + }, + "class_type": "FluxGuidance", + "_meta": { + "title": "FluxGuidance" + } + }, + "48": { + "inputs": { + "resolution": "832x1216 (0.68)", + "batch_size": 1, + "width_override": 0, + "height_override": 0 + }, + "class_type": "SDXLEmptyLatentSizePicker+", + "_meta": { + "title": "🔧 SDXL Empty Latent Size Picker" + } + }, + "65": { + "inputs": { + "unet_name": "flux\\flux1-dev-fp8-e4m3fn.safetensors", + "weight_dtype": "fp8_e4m3fn_fast" + }, + "class_type": "UNETLoader", + "_meta": { + "title": "Load Diffusion Model" + } + }, + "147": { + "inputs": { + "noise_seed": 651532572596956 + }, + "class_type": "RandomNoise", + "_meta": { + "title": "RandomNoise" + } + }, + "148": { + "inputs": { + "wildcard_text": "__some-prompts__", + "populated_text": "A surreal digital artwork showcases a forward-thinking inventor captivated by his intricate mechanical creation through a large magnifying glass. Viewed from an unconventional perspective, the scene reveals an eccentric assembly of gears, springs, and brass instruments within his workshop. Soft, ethereal light radiates from the invention, casting enigmatic shadows on the walls as time appears to bend around its metallic form, invoking a sense of curiosity, wonder, and exhilaration in discovery.", + "mode": "fixed", + "seed": 553084268162351, + "Select to add Wildcard": "Select the Wildcard to add to the text" + }, + "class_type": "ImpactWildcardProcessor", + "_meta": { + "title": "ImpactWildcardProcessor" + } + }, + "151": { + "inputs": { + "text": "A hyper-realistic close-up portrait of a young woman with shoulder-length black hair styled in edgy, futuristic layers, adorned with glowing tips. She wears mecha eyewear with a neon green visor that transitions into iridescent shades of teal and gold. The frame is sleek, with angular edges and fine mechanical detailing. Her expression is fierce and confident, with flawless skin highlighted by the neon reflections. She wears a high-tech bodysuit with integrated LED lines and metallic panels. The background depicts a hazy rendition of The Great Wave off Kanagawa by Hokusai, its powerful waves blending seamlessly with the neon tones, amplifying her intense, defiant aura." + }, + "class_type": "Text Multiline", + "_meta": { + "title": "Text Multiline" + } + }, + "191": { + "inputs": { + "text": "A cinematic, oil painting masterpiece captures the essence of impressionistic surrealism, inspired by Claude Monet. A mysterious woman in a flowing crimson dress stands at the edge of a tranquil lake, where lily pads shimmer under an ethereal, golden twilight. The water’s surface reflects a dreamlike sky, its swirling hues of violet and sapphire melting together like liquid light. The thick, expressive brushstrokes lend depth to the scene, evoking a sense of nostalgia and quiet longing, as if the world itself is caught between reality and a fleeting dream. \nA mesmerizing oil painting masterpiece inspired by Salvador Dalí, blending surrealism with post-impressionist texture. A lone violinist plays atop a melting clock tower, his form distorted by the passage of time. The sky is a cascade of swirling, liquid oranges and deep blues, where floating staircases spiral endlessly into the horizon. The impasto technique gives depth and movement to the surreal elements, making time itself feel fluid, as if the world is dissolving into a dream. \nA stunning impressionistic oil painting evokes the spirit of Edvard Munch, capturing a solitary figure standing on a rain-soaked street, illuminated by the glow of flickering gas lamps. The swirling, chaotic strokes of deep blues and fiery reds reflect the turbulence of emotion, while the blurred reflections in the wet cobblestone suggest a merging of past and present. The faceless figure, draped in a dark overcoat, seems lost in thought, embodying the ephemeral nature of memory and time. \nA breathtaking oil painting masterpiece, inspired by Gustav Klimt, presents a celestial ballroom where faceless dancers swirl in an eternal waltz beneath a gilded, star-speckled sky. Their golden garments shimmer with intricate patterns, blending into the opulent mosaic floor that seems to stretch into infinity. The dreamlike composition, rich in warm amber and deep sapphire hues, captures an otherworldly elegance, as if the dancers are suspended in a moment that transcends time. \nA visionary oil painting inspired by Marc Chagall depicts a dreamlike cityscape where gravity ceases to exist. A couple floats above a crimson-tinted town, their forms dissolving into the swirling strokes of a vast, cerulean sky. The buildings below twist and bend in rhythmic motion, their windows glowing like tiny stars. The thick, textured brushwork conveys a sense of weightlessness and wonder, as if love itself has defied the laws of the universe. \nAn impressionistic oil painting in the style of J.M.W. Turner, depicting a ghostly ship sailing through a sea of swirling golden mist. The waves crash and dissolve into abstract, fiery strokes of orange and deep indigo, blurring the line between ocean and sky. The ship appears almost ethereal, as if drifting between worlds, lost in the ever-changing tides of memory and myth. The dynamic brushstrokes capture the relentless power of nature and the fleeting essence of time. \nA captivating oil painting masterpiece, infused with surrealist impressionism, portrays a grand library where books float midair, their pages unraveling into ribbons of light. The towering shelves twist into the heavens, vanishing into an infinite, starry void. A lone scholar, illuminated by the glow of a suspended lantern, reaches for a book that seems to pulse with life. The scene pulses with mystery, where the impasto textures bring depth to the interplay between knowledge and dreams. \nA luminous impressionistic oil painting captures the melancholic beauty of an abandoned carnival, its faded carousel horses frozen mid-gallop beneath a sky of swirling lavender and gold. The wind carries fragments of forgotten laughter through the empty fairground, where scattered ticket stubs and crumbling banners whisper tales of joy long past. The thick, textured brushstrokes blend nostalgia with an eerie dreamlike quality, as if the carnival exists only in the echoes of memory. \nA surreal oil painting in the spirit of René Magritte, featuring a towering lighthouse that emits not light, but cascading waterfalls from its peak. The swirling sky, painted in deep midnight blues, is punctuated by glowing, crescent moons that defy gravity. A lone figure stands at the water’s edge, gazing up in quiet contemplation, as if caught between wonder and the unknown. The painting’s rich textures and luminous colors create an enigmatic, dreamlike landscape. \nA striking impressionistic oil painting, reminiscent of Van Gogh, portrays a lone traveler on a winding cobblestone path, their silhouette bathed in the golden glow of lantern-lit cherry blossoms. The petals swirl through the night air like glowing embers, blending with the deep, rhythmic strokes of a star-filled indigo sky. The scene captures a feeling of wistful solitude, as if the traveler is walking not only through the city, but through the fleeting nature of time itself." + }, + "class_type": "Text Multiline", + "_meta": { + "title": "Text Multiline" + } + }, + "203": { + "inputs": { + "string1": [ + "289", + 0 + ], + "string2": [ + "293", + 0 + ], + "delimiter": ", " + }, + "class_type": "JoinStrings", + "_meta": { + "title": "Join Strings" + } + }, + "208": { + "inputs": { + "file_path": "", + "dictionary_name": "[filename]", + "label": "TextBatch", + "mode": "automatic", + "index": 0, + "multiline_text": [ + "191", + 0 + ] + }, + "class_type": "Text Load Line From File", + "_meta": { + "title": "Text Load Line From File" + } + }, + "223": { + "inputs": { + "filename": "%time_%seed", + "path": "%date", + "extension": "jpeg", + "steps": [ + "246", + 0 + ], + "cfg": 3.5, + "modelname": "flux_dev", + "sampler_name": "dpmpp_2m", + "scheduler": "beta", + "positive": [ + "203", + 0 + ], + "negative": "", + "width": [ + "48", + 1 + ], + "height": [ + "48", + 2 + ], + "lossless_webp": true, + "quality_jpeg_or_webp": 100, + "optimize_png": false, + "counter": 0, + "denoise": 1, + "clip_skip": 1, + "time_format": "%Y-%m-%d-%H%M%S", + "save_workflow_as_json": false, + "embed_workflow_in_png": false, + "images": [ + "8", + 0 + ] + }, + "class_type": "Image Saver", + "_meta": { + "title": "Image Saver" + } + }, + "226": { "inputs": { "images": [ "8", @@ -106,64 +311,25 @@ "title": "Preview Image" } }, - "19": { + "246": { "inputs": { - "stop_at_clip_layer": -2, - "clip": [ - "4", - 1 - ] + "value": 25 }, - "class_type": "CLIPSetLastLayer", + "class_type": "INTConstant", "_meta": { - "title": "CLIP Set Last Layer" + "title": "Steps" } }, - "21": { - "inputs": { - "string": "masterpiece, best quality, good quality, very awa, newest, highres, absurdres, 1girl, solo, dress, standing, flower, outdoors, water, white flower, pink flower, scenery, reflection, rain, dark, ripples, yellow flower, puddle, colorful, abstract, standing on liquidi¼Œ\nvery Wide Shot, limited palette,", - "strip_newlines": false - }, - "class_type": "StringConstantMultiline", - "_meta": { - "title": "positive" - } - }, - "22": { - "inputs": { - "string1": [ - "55", - 0 - ], - "string2": [ - "21", - 0 - ], - "delimiter": ", " - }, - "class_type": "JoinStrings", - "_meta": { - "title": "Join Strings" - } - }, - "55": { + "289": { "inputs": { "group_mode": true, "toggle_trigger_words": [ { - "text": "xxx667_illu", + "text": "perfection style", "active": true }, { - "text": "glowing", - "active": true - }, - { - "text": "glitch", - "active": true - }, - { - "text": "15546+456868", + "text": "mythp0rt", "active": true }, { @@ -177,9 +343,9 @@ "_isDummy": true } ], - "orinalMessage": "xxx667_illu,, glowing,, glitch,, 15546+456868", + "orinalMessage": "perfection style,, mythp0rt", "trigger_words": [ - "56", + "299", 2 ] }, @@ -188,33 +354,57 @@ "title": "TriggerWord Toggle (LoraManager)" } }, - "56": { + "293": { "inputs": { - "text": " ", + "input": 1, + "text1": [ + "208", + 0 + ], + "text2": [ + "151", + 0 + ] + }, + "class_type": "easy textSwitch", + "_meta": { + "title": "Text Switch" + } + }, + "297": { + "inputs": { + "text": "" + }, + "class_type": "Lora Stacker (LoraManager)", + "_meta": { + "title": "Lora Stacker (LoraManager)" + } + }, + "298": { + "inputs": { + "text": "flux1/testing/matmillerartFLUX.safetensors,0.2,0.2", + "anything": [ + "297", + 0 + ] + }, + "class_type": "easy showAnything", + "_meta": { + "title": "Show Any" + } + }, + "299": { + "inputs": { + "text": " ", "loras": [ { - "name": "ponyv6_noobE11_2_adamW-000017", - "strength": 0.3, + "name": "boFLUX Double Exposure Magic v2", + "strength": 0.8, "active": true }, { - "name": "XXX667", - "strength": 0.4, - "active": true - }, - { - "name": "114558v4df2fsdf5", - "strength": 0.6, - "active": true - }, - { - "name": "illustriousXL_stabilizer_v1.23", - "strength": 0.3, - "active": true - }, - { - "name": "mon_monmon2133", - "strength": 0.5, + "name": "FluxDFaeTasticDetails", + "strength": 0.65, "active": true }, { @@ -231,15 +421,15 @@ } ], "model": [ - "4", + "65", 0 ], "clip": [ - "4", - 1 + "11", + 0 ], "lora_stack": [ - "57", + "297", 0 ] }, @@ -248,64 +438,14 @@ "title": "Lora Loader (LoraManager)" } }, - "57": { + "301": { "inputs": { - "text": "", - "loras": [ - { - "name": "aorunIllstrious", - "strength": "0.90", - "active": false - }, - { - "name": "__dummy_item1__", - "strength": 0, - "active": false, - "_isDummy": true - }, - { - "name": "__dummy_item2__", - "strength": 0, - "active": false, - "_isDummy": true - } - ], - "lora_stack": [ - "59", - 0 - ] + "string": "A hyper-realistic close-up portrait of a young woman with shoulder-length black hair styled in edgy, futuristic layers, adorned with glowing tips. She wears mecha eyewear with a neon green visor that transitions into iridescent shades of teal and gold. The frame is sleek, with angular edges and fine mechanical detailing. Her expression is fierce and confident, with flawless skin highlighted by the neon reflections. She wears a high-tech bodysuit with integrated LED lines and metallic panels. The background depicts a hazy rendition of The Great Wave off Kanagawa by Hokusai, its powerful waves blending seamlessly with the neon tones, amplifying her intense, defiant aura.", + "strip_newlines": true }, - "class_type": "Lora Stacker (LoraManager)", + "class_type": "StringConstantMultiline", "_meta": { - "title": "Lora Stacker (LoraManager)" - } - }, - "59": { - "inputs": { - "text": "", - "loras": [ - { - "name": "ck-neon-retrowave-IL-000012", - "strength": 0.8, - "active": false - }, - { - "name": "__dummy_item1__", - "strength": 0, - "active": false, - "_isDummy": true - }, - { - "name": "__dummy_item2__", - "strength": 0, - "active": false, - "_isDummy": true - } - ] - }, - "class_type": "Lora Stacker (LoraManager)", - "_meta": { - "title": "Lora Stacker (LoraManager)" + "title": "String Constant Multiline" } } } \ No newline at end of file From a8ec5af037534bfd826defa4c25768965e08e3ac Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 2 Apr 2025 06:05:24 +0800 Subject: [PATCH 04/18] checkpoint --- py/routes/recipe_routes.py | 8 ++-- py/workflow/ext/comfyui_core.py | 49 ++++++++++++++---------- py/workflow/mappers.py | 18 ++++++++- py/workflow/parser.py | 68 ++++++++------------------------- refs/prompt.json | 54 +------------------------- 5 files changed, 67 insertions(+), 130 deletions(-) diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 037f5c19..a6e3ef91 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -783,8 +783,8 @@ class RecipeRoutes: # Parse the workflow to extract generation parameters and loras parsed_workflow = self.parser.parse_workflow(workflow_json) - if not parsed_workflow or not parsed_workflow.get("gen_params"): - return web.json_response({"error": "Could not extract generation parameters from workflow"}, status=400) + if not parsed_workflow: + return web.json_response({"error": "Could not extract parameters from workflow"}, status=400) # Get the lora stack from the parsed workflow lora_stack = parsed_workflow.get("loras", "") @@ -880,7 +880,9 @@ class RecipeRoutes: "created_date": time.time(), "base_model": most_common_base_model, "loras": loras_data, - "gen_params": parsed_workflow.get("gen_params", {}), # Use the parsed workflow parameters + "checkpoint": parsed_workflow.get("checkpoint", ""), + "gen_params": {key: value for key, value in parsed_workflow.items() + if key not in ['checkpoint', 'loras']}, "loras_stack": lora_stack # Include the original lora stack string } diff --git a/py/workflow/ext/comfyui_core.py b/py/workflow/ext/comfyui_core.py index 59713ee8..73b56f41 100644 --- a/py/workflow/ext/comfyui_core.py +++ b/py/workflow/ext/comfyui_core.py @@ -52,30 +52,15 @@ def transform_basic_guider(inputs: Dict) -> Dict: # Get model information if needed if "model" in inputs and isinstance(inputs["model"], dict): - if "loras" in inputs["model"]: - result["loras"] = inputs["model"]["loras"] + result["model"] = inputs["model"] return result def transform_model_sampling_flux(inputs: Dict) -> Dict: """Transform function for ModelSamplingFlux - mostly a pass-through node""" # This node is primarily used for routing, so we mostly pass through values - result = {} - # Extract any dimensions if present - width = inputs.get("width", 0) - height = inputs.get("height", 0) - if width and height: - result["width"] = width - result["height"] = height - result["size"] = f"{width}x{height}" - - # Pass through model information - if "model" in inputs and isinstance(inputs["model"], dict): - for key, value in inputs["model"].items(): - result[key] = value - - return result + return inputs["model"] def transform_sampler_custom_advanced(inputs: Dict) -> Dict: """Transform function for SamplerCustomAdvanced node""" @@ -110,6 +95,7 @@ def transform_sampler_custom_advanced(inputs: Dict) -> Dict: result["prompt"] = guider["conditioning"].get("prompt", "") if "model" in guider and isinstance(guider["model"], dict): + result["checkpoint"] = guider["model"].get("checkpoint", "") result["loras"] = guider["model"].get("loras", "") # Extract dimensions from latent_image @@ -124,12 +110,27 @@ def transform_sampler_custom_advanced(inputs: Dict) -> Dict: return result +def transform_unet_loader(inputs: Dict) -> Dict: + """Transform function for UNETLoader node""" + unet_name = inputs.get("unet_name", "") + return {"checkpoint": unet_name} if unet_name else {} + +def transform_checkpoint_loader(inputs: Dict) -> Dict: + """Transform function for CheckpointLoaderSimple node""" + ckpt_name = inputs.get("ckpt_name", "") + return {"checkpoint": ckpt_name} if ckpt_name else {} + # ============================================================================= # Register Mappers # ============================================================================= # Define the mappers for ComfyUI core nodes not in main mapper COMFYUI_CORE_MAPPERS = { + # KSamplers + "SamplerCustomAdvanced": { + "inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"], + "transform_func": transform_sampler_custom_advanced + }, "RandomNoise": { "inputs_to_track": ["noise_seed"], "transform_func": transform_random_noise @@ -150,9 +151,17 @@ COMFYUI_CORE_MAPPERS = { "inputs_to_track": ["max_shift", "base_shift", "width", "height", "model"], "transform_func": transform_model_sampling_flux }, - "SamplerCustomAdvanced": { - "inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"], - "transform_func": transform_sampler_custom_advanced + "UNETLoader": { + "inputs_to_track": ["unet_name"], + "transform_func": transform_unet_loader + }, + "CheckpointLoaderSimple": { + "inputs_to_track": ["ckpt_name"], + "transform_func": transform_checkpoint_loader + }, + "CheckpointLoader": { + "inputs_to_track": ["ckpt_name"], + "transform_func": transform_checkpoint_loader } } diff --git a/py/workflow/mappers.py b/py/workflow/mappers.py index 96fb02a2..f954cf54 100644 --- a/py/workflow/mappers.py +++ b/py/workflow/mappers.py @@ -125,6 +125,15 @@ def transform_ksampler(inputs: Dict) -> Dict: # Add clip_skip if present if "clip_skip" in inputs: result["clip_skip"] = str(inputs.get("clip_skip", "")) + + # Add guidance if present + if "guidance" in inputs: + result["guidance"] = str(inputs.get("guidance", "")) + + # Add model if present + if "model" in inputs: + result["checkpoint"] = inputs.get("model", {}).get("checkpoint", "") + result["loras"] = inputs.get("model", {}).get("loras", "") return result @@ -167,8 +176,13 @@ def transform_lora_loader(inputs: Dict) -> Dict: lora_name = stack_entry[0] strength = stack_entry[1] lora_texts.append(f"") + + result = { + "checkpoint": inputs.get("model", {}).get("checkpoint", ""), + "loras": " ".join(lora_texts) + } - return {"loras": " ".join(lora_texts)} + return result def transform_lora_stacker(inputs: Dict) -> Dict: """Transform function for LoraStacker nodes""" @@ -276,7 +290,7 @@ NODE_MAPPERS = { }, # LoraManager nodes "Lora Loader (LoraManager)": { - "inputs_to_track": ["loras", "lora_stack"], + "inputs_to_track": ["model", "loras", "lora_stack"], "transform_func": transform_lora_loader }, "Lora Stacker (LoraManager)": { diff --git a/py/workflow/parser.py b/py/workflow/parser.py index e57d8ce4..2c913173 100644 --- a/py/workflow/parser.py +++ b/py/workflow/parser.py @@ -166,71 +166,33 @@ class WorkflowParser: logger.warning("No suitable sampler node found in workflow") return {} - # Start parsing from the sampler node - result = { - "gen_params": {}, - "loras": "" - } - # Process sampler node to extract parameters sampler_result = self.process_node(sampler_node_id, workflow) - if sampler_result: - # Process the result - for key, value in sampler_result.items(): - # Special handling for the positive prompt from FluxGuidance - if key == "positive" and isinstance(value, dict): - # Extract guidance value - if "guidance" in value: - result["gen_params"]["guidance"] = value["guidance"] - - # Extract prompt - if "prompt" in value: - result["gen_params"]["prompt"] = value["prompt"] - else: - # Normal handling for other values - result["gen_params"][key] = value + logger.info(f"Sampler result: {sampler_result}") + if not sampler_result: + return {} - # Process the positive prompt node if it exists and we don't have a prompt yet - if "prompt" not in result["gen_params"] and "positive" in sampler_result: - positive_value = sampler_result.get("positive") - if isinstance(positive_value, str): - result["gen_params"]["prompt"] = positive_value - - # Manually check for FluxGuidance if we don't have guidance value - if "guidance" not in result["gen_params"]: - flux_node_id = find_node_by_type(workflow, "FluxGuidance") - if flux_node_id: - # Get the direct input from the node - node_inputs = workflow[flux_node_id].get("inputs", {}) - if "guidance" in node_inputs: - result["gen_params"]["guidance"] = node_inputs["guidance"] - - # Extract loras from the model input of sampler - sampler_node = workflow.get(sampler_node_id, {}) - sampler_inputs = sampler_node.get("inputs", {}) - if "model" in sampler_inputs and isinstance(sampler_inputs["model"], list): - loras_text = self.collect_loras_from_model(sampler_inputs["model"], workflow) - if loras_text: - result["loras"] = loras_text + # Return the sampler result directly - it's already in the format we need + # This simplifies the structure and makes it easier to use in recipe_routes.py # Handle standard ComfyUI names vs our output format - if "cfg" in result["gen_params"]: - result["gen_params"]["cfg_scale"] = result["gen_params"].pop("cfg") + if "cfg" in sampler_result: + sampler_result["cfg_scale"] = sampler_result.pop("cfg") # Add clip_skip = 1 to match reference output if not already present - if "clip_skip" not in result["gen_params"]: - result["gen_params"]["clip_skip"] = "1" + if "clip_skip" not in sampler_result: + sampler_result["clip_skip"] = "1" # Ensure the prompt is a string and not a nested dictionary - if "prompt" in result["gen_params"] and isinstance(result["gen_params"]["prompt"], dict): - if "prompt" in result["gen_params"]["prompt"]: - result["gen_params"]["prompt"] = result["gen_params"]["prompt"]["prompt"] + if "prompt" in sampler_result and isinstance(sampler_result["prompt"], dict): + if "prompt" in sampler_result["prompt"]: + sampler_result["prompt"] = sampler_result["prompt"]["prompt"] # Save the result if requested if output_path: - save_output(result, output_path) + save_output(sampler_result, output_path) - return result + return sampler_result def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dict: @@ -245,4 +207,4 @@ def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dic Dictionary containing extracted parameters """ parser = WorkflowParser() - return parser.parse_workflow(workflow_path, output_path) \ No newline at end of file + return parser.parse_workflow(workflow_path, output_path) \ No newline at end of file diff --git a/refs/prompt.json b/refs/prompt.json index 535ddec5..96f62b0a 100644 --- a/refs/prompt.json +++ b/refs/prompt.json @@ -254,51 +254,6 @@ "title": "Text Load Line From File" } }, - "223": { - "inputs": { - "filename": "%time_%seed", - "path": "%date", - "extension": "jpeg", - "steps": [ - "246", - 0 - ], - "cfg": 3.5, - "modelname": "flux_dev", - "sampler_name": "dpmpp_2m", - "scheduler": "beta", - "positive": [ - "203", - 0 - ], - "negative": "", - "width": [ - "48", - 1 - ], - "height": [ - "48", - 2 - ], - "lossless_webp": true, - "quality_jpeg_or_webp": 100, - "optimize_png": false, - "counter": 0, - "denoise": 1, - "clip_skip": 1, - "time_format": "%Y-%m-%d-%H%M%S", - "save_workflow_as_json": false, - "embed_workflow_in_png": false, - "images": [ - "8", - 0 - ] - }, - "class_type": "Image Saver", - "_meta": { - "title": "Image Saver" - } - }, "226": { "inputs": { "images": [ @@ -325,11 +280,7 @@ "group_mode": true, "toggle_trigger_words": [ { - "text": "perfection style", - "active": true - }, - { - "text": "mythp0rt", + "text": "bo-exposure", "active": true }, { @@ -343,7 +294,7 @@ "_isDummy": true } ], - "orinalMessage": "perfection style,, mythp0rt", + "orinalMessage": "bo-exposure", "trigger_words": [ "299", 2 @@ -382,7 +333,6 @@ }, "298": { "inputs": { - "text": "flux1/testing/matmillerartFLUX.safetensors,0.2,0.2", "anything": [ "297", 0 From 5a93c40b7996f2aea82151c6129f95e47262e2d9 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 2 Apr 2025 10:29:31 +0800 Subject: [PATCH 05/18] Refactor logging levels and improve mapper registration - Changed warning logs to debug logs in CivitaiClient and RecipeScanner for better log granularity. - Updated the mapper registration function name for clarity and adjusted related logging messages. - Enhanced extension loading process to automatically register mappers from NODE_MAPPERS_EXT, improving modularity and maintainability. --- py/services/civitai_client.py | 3 --- py/services/recipe_scanner.py | 8 ++++---- py/workflow/ext/comfyui_core.py | 21 +++------------------ py/workflow/ext/kjnodes.py | 21 +++------------------ py/workflow/mappers.py | 24 +++++++++++++++++------- py/workflow/parser.py | 14 ++++++-------- 6 files changed, 33 insertions(+), 58 deletions(-) diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 88dafaaf..fbd77739 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -234,11 +234,9 @@ class CivitaiClient: if not self._session: return None - logger.info(f"Fetching model version info from Civitai for ID: {model_version_id}") version_info = await self._session.get(f"{self.base_url}/model-versions/{model_version_id}") if not version_info or not version_info.json().get('files'): - logger.warning(f"No files found in version info for ID: {model_version_id}") return None # Get hash from the first file @@ -248,7 +246,6 @@ class CivitaiClient: hash_value = file_info['hashes']['SHA256'].lower() return hash_value - logger.warning(f"No SHA256 hash found in version info for ID: {model_version_id}") return None except Exception as e: logger.error(f"Error getting hash from Civitai: {e}") diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index 3db5f884..51590a3e 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -211,7 +211,7 @@ class RecipeScanner: lora['hash'] = hash_from_civitai metadata_updated = True else: - logger.warning(f"Could not get hash for modelVersionId {model_version_id}") + logger.debug(f"Could not get hash for modelVersionId {model_version_id}") # If has hash but no file_name, look up in lora library if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']): @@ -261,7 +261,7 @@ class RecipeScanner: version_info = await self._civitai_client.get_model_version_info(model_version_id) if not version_info or not version_info.get('files'): - logger.warning(f"No files found in version info for ID: {model_version_id}") + logger.debug(f"No files found in version info for ID: {model_version_id}") return None # Get hash from the first file @@ -269,7 +269,7 @@ class RecipeScanner: if file_info.get('hashes', {}).get('SHA256'): return file_info['hashes']['SHA256'] - logger.warning(f"No SHA256 hash found in version info for ID: {model_version_id}") + logger.debug(f"No SHA256 hash found in version info for ID: {model_version_id}") return None except Exception as e: logger.error(f"Error getting hash from Civitai: {e}") @@ -286,7 +286,7 @@ class RecipeScanner: if version_info and 'name' in version_info: return version_info['name'] - logger.warning(f"No version name found for modelVersionId {model_version_id}") + logger.debug(f"No version name found for modelVersionId {model_version_id}") return None except Exception as e: logger.error(f"Error getting model version name from Civitai: {e}") diff --git a/py/workflow/ext/comfyui_core.py b/py/workflow/ext/comfyui_core.py index 73b56f41..125c29f0 100644 --- a/py/workflow/ext/comfyui_core.py +++ b/py/workflow/ext/comfyui_core.py @@ -6,9 +6,6 @@ from typing import Dict, Any, List logger = logging.getLogger(__name__) -# Import the mapper registration functions from the parent module -from workflow.mappers import create_mapper, register_mapper - # ============================================================================= # Transform Functions # ============================================================================= @@ -121,11 +118,11 @@ def transform_checkpoint_loader(inputs: Dict) -> Dict: return {"checkpoint": ckpt_name} if ckpt_name else {} # ============================================================================= -# Register Mappers +# Node Mapper Definitions # ============================================================================= # Define the mappers for ComfyUI core nodes not in main mapper -COMFYUI_CORE_MAPPERS = { +NODE_MAPPERS_EXT = { # KSamplers "SamplerCustomAdvanced": { "inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"], @@ -163,16 +160,4 @@ COMFYUI_CORE_MAPPERS = { "inputs_to_track": ["ckpt_name"], "transform_func": transform_checkpoint_loader } -} - -# Register all ComfyUI core mappers -for node_type, config in COMFYUI_CORE_MAPPERS.items(): - mapper = create_mapper( - node_type=node_type, - inputs_to_track=config["inputs_to_track"], - transform_func=config["transform_func"] - ) - register_mapper(mapper) - logger.info(f"Registered ComfyUI core mapper for node type: {node_type}") - -logger.info(f"Loaded ComfyUI core extension with {len(COMFYUI_CORE_MAPPERS)} mappers") \ No newline at end of file +} \ No newline at end of file diff --git a/py/workflow/ext/kjnodes.py b/py/workflow/ext/kjnodes.py index ecab4cfd..8ea99d2c 100644 --- a/py/workflow/ext/kjnodes.py +++ b/py/workflow/ext/kjnodes.py @@ -7,9 +7,6 @@ from typing import Dict, Any logger = logging.getLogger(__name__) -# Import the mapper registration functions from the parent module -from workflow.mappers import create_mapper, register_mapper - # ============================================================================= # Transform Functions # ============================================================================= @@ -53,11 +50,11 @@ def transform_int_constant(inputs: Dict) -> int: return inputs.get("value", 0) # ============================================================================= -# Register Mappers +# Node Mapper Definitions # ============================================================================= # Define the mappers for KJNodes -KJNODES_MAPPERS = { +NODE_MAPPERS_EXT = { "JoinStrings": { "inputs_to_track": ["string1", "string2", "delimiter"], "transform_func": transform_join_strings @@ -74,16 +71,4 @@ KJNODES_MAPPERS = { "inputs_to_track": ["value"], "transform_func": transform_int_constant } -} - -# Register all KJNodes mappers -for node_type, config in KJNODES_MAPPERS.items(): - mapper = create_mapper( - node_type=node_type, - inputs_to_track=config["inputs_to_track"], - transform_func=config["transform_func"] - ) - register_mapper(mapper) - logger.info(f"Registered KJNodes mapper for node type: {node_type}") - -logger.info(f"Loaded KJNodes extension with {len(KJNODES_MAPPERS)} mappers") \ No newline at end of file +} \ No newline at end of file diff --git a/py/workflow/mappers.py b/py/workflow/mappers.py index f954cf54..156afb27 100644 --- a/py/workflow/mappers.py +++ b/py/workflow/mappers.py @@ -303,8 +303,8 @@ NODE_MAPPERS = { } } -def register_default_mappers() -> None: - """Register all default mappers from the NODE_MAPPERS dictionary""" +def register_all_mappers() -> None: + """Register all mappers from the NODE_MAPPERS dictionary""" for node_type, config in NODE_MAPPERS.items(): mapper = create_mapper( node_type=node_type, @@ -312,7 +312,7 @@ def register_default_mappers() -> None: transform_func=config["transform_func"] ) register_mapper(mapper) - logger.info(f"Registered {len(NODE_MAPPERS)} default node mappers") + logger.info(f"Registered {len(NODE_MAPPERS)} node mappers") # ============================================================================= # Extension Loading @@ -322,8 +322,8 @@ def load_extensions(ext_dir: str = None) -> None: """ Load mapper extensions from the specified directory - Extension files should define mappers using the create_mapper function - and then call register_mapper to add them to the registry. + Extension files should define a NODE_MAPPERS_EXT dictionary containing mapper configurations. + These will be added to the global NODE_MAPPERS dictionary and registered automatically. """ # Use default path if none provided if ext_dir is None: @@ -349,9 +349,19 @@ def load_extensions(ext_dir: str = None) -> None: if spec and spec.loader: module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - logger.info(f"Loaded extension module: {filename}") + + # Check if the module defines NODE_MAPPERS_EXT + if hasattr(module, 'NODE_MAPPERS_EXT'): + # Add the extension mappers to the global NODE_MAPPERS dictionary + NODE_MAPPERS.update(module.NODE_MAPPERS_EXT) + logger.info(f"Added {len(module.NODE_MAPPERS_EXT)} mappers from extension: {filename}") + else: + logger.warning(f"Extension {filename} does not define NODE_MAPPERS_EXT dictionary") except Exception as e: logger.warning(f"Error loading extension {filename}: {e}") + + # Re-register all mappers after loading extensions + register_all_mappers() # Initialize the registry with default mappers -register_default_mappers() \ No newline at end of file +# register_default_mappers() \ No newline at end of file diff --git a/py/workflow/parser.py b/py/workflow/parser.py index 2c913173..bfae55a2 100644 --- a/py/workflow/parser.py +++ b/py/workflow/parser.py @@ -15,14 +15,13 @@ logger = logging.getLogger(__name__) class WorkflowParser: """Parser for ComfyUI workflows""" - def __init__(self, load_extensions_on_init: bool = True): + def __init__(self): """Initialize the parser with mappers""" self.processed_nodes: Set[str] = set() # Track processed nodes to avoid cycles self.node_results_cache: Dict[str, Any] = {} # Cache for processed node results - # Load extensions if requested - if load_extensions_on_init: - load_extensions() + # Load extensions + load_extensions() def process_node(self, node_id: str, workflow: Dict) -> Any: """Process a single node and extract relevant information""" @@ -89,7 +88,7 @@ class WorkflowParser: # If we found SamplerCustomAdvanced nodes, return the first one if sampler_advanced_nodes: - logger.info(f"Found SamplerCustomAdvanced node: {sampler_advanced_nodes[0]}") + logger.debug(f"Found SamplerCustomAdvanced node: {sampler_advanced_nodes[0]}") return sampler_advanced_nodes[0] # If we have KSampler nodes, look for one with denoise=1.0 @@ -101,11 +100,11 @@ class WorkflowParser: # Check if denoise is 1.0 (allowing for small floating point differences) if abs(float(denoise) - 1.0) < 0.001: - logger.info(f"Found KSampler node with denoise=1.0: {node_id}") + logger.debug(f"Found KSampler node with denoise=1.0: {node_id}") return node_id # If no KSampler with denoise=1.0 found, use the first one - logger.info(f"No KSampler with denoise=1.0 found, using first KSampler: {ksampler_nodes[0]}") + logger.debug(f"No KSampler with denoise=1.0 found, using first KSampler: {ksampler_nodes[0]}") return ksampler_nodes[0] # No sampler nodes found @@ -168,7 +167,6 @@ class WorkflowParser: # Process sampler node to extract parameters sampler_result = self.process_node(sampler_node_id, workflow) - logger.info(f"Sampler result: {sampler_result}") if not sampler_result: return {} From 4933dbfb87cf7a231a7a8be641ee104d015a0604 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 2 Apr 2025 11:14:05 +0800 Subject: [PATCH 06/18] Refactor ExifUtils by removing unused methods and imports - Removed the extract_user_comment and update_user_comment methods to streamline the ExifUtils class. - Cleaned up unnecessary imports and reduced code complexity, focusing on essential functionality for image metadata extraction. --- py/utils/exif_utils.py | 292 +---------------------------------------- 1 file changed, 2 insertions(+), 290 deletions(-) diff --git a/py/utils/exif_utils.py b/py/utils/exif_utils.py index 0a9550f5..c0de350e 100644 --- a/py/utils/exif_utils.py +++ b/py/utils/exif_utils.py @@ -1,51 +1,16 @@ import piexif import json import logging -from typing import Dict, Optional, Any +from typing import Optional from io import BytesIO import os from PIL import Image -import re logger = logging.getLogger(__name__) class ExifUtils: """Utility functions for working with EXIF data in images""" - @staticmethod - def extract_user_comment(image_path: str) -> Optional[str]: - """Extract UserComment field from image EXIF data""" - try: - # First try to open as image to check format - with Image.open(image_path) as img: - if img.format not in ['JPEG', 'TIFF', 'WEBP']: - # For non-JPEG/TIFF/WEBP images, try to get EXIF through PIL - exif = img._getexif() - if exif and piexif.ExifIFD.UserComment in exif: - user_comment = exif[piexif.ExifIFD.UserComment] - if isinstance(user_comment, bytes): - if user_comment.startswith(b'UNICODE\0'): - return user_comment[8:].decode('utf-16be') - return user_comment.decode('utf-8', errors='ignore') - return user_comment - return None - - # For JPEG/TIFF/WEBP, use piexif - exif_dict = piexif.load(image_path) - - if piexif.ExifIFD.UserComment in exif_dict.get('Exif', {}): - user_comment = exif_dict['Exif'][piexif.ExifIFD.UserComment] - if isinstance(user_comment, bytes): - if user_comment.startswith(b'UNICODE\0'): - user_comment = user_comment[8:].decode('utf-16be') - else: - user_comment = user_comment.decode('utf-8', errors='ignore') - return user_comment - return None - - except Exception as e: - return None - @staticmethod def extract_image_metadata(image_path: str) -> Optional[str]: """Extract metadata from image including UserComment or parameters field @@ -103,53 +68,6 @@ class ExifUtils: logger.error(f"Error extracting image metadata: {e}", exc_info=True) return None - @staticmethod - def update_user_comment(image_path: str, user_comment: str) -> str: - """Update UserComment field in image EXIF data""" - try: - # Load the image and its EXIF data - with Image.open(image_path) as img: - # Get original format - img_format = img.format - - # For WebP format, we need a different approach - if img_format == 'WEBP': - # WebP doesn't support standard EXIF through piexif - # We'll use PIL's exif parameter directly - exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + user_comment.encode('utf-16be')}} - exif_bytes = piexif.dump(exif_dict) - - # Save with the exif data - img.save(image_path, format='WEBP', exif=exif_bytes, quality=85) - return image_path - - # For other formats, use the standard approach - try: - exif_dict = piexif.load(img.info.get('exif', b'')) - except: - exif_dict = {'0th':{}, 'Exif':{}, 'GPS':{}, 'Interop':{}, '1st':{}} - - # If no Exif dictionary exists, create one - if 'Exif' not in exif_dict: - exif_dict['Exif'] = {} - - # Update the UserComment field - use UNICODE format - unicode_bytes = user_comment.encode('utf-16be') - user_comment_bytes = b'UNICODE\0' + unicode_bytes - - exif_dict['Exif'][piexif.ExifIFD.UserComment] = user_comment_bytes - - # Convert EXIF dict back to bytes - exif_bytes = piexif.dump(exif_dict) - - # Save the image with updated EXIF data - img.save(image_path, exif=exif_bytes) - - return image_path - except Exception as e: - logger.error(f"Error updating EXIF data in {image_path}: {e}") - return image_path - @staticmethod def update_image_metadata(image_path: str, metadata: str) -> str: """Update metadata in image's EXIF data or parameters fields @@ -394,210 +312,4 @@ class ExifUtils: if isinstance(image_data, str) and os.path.exists(image_data): with open(image_data, 'rb') as f: return f.read(), os.path.splitext(image_data)[1] - return image_data, '.jpg' - - @staticmethod - def _parse_comfyui_workflow(workflow_data: Any) -> Dict[str, Any]: - """ - Parse ComfyUI workflow data and extract relevant generation parameters - - Args: - workflow_data: Raw workflow data (string or dict) - - Returns: - Formatted generation parameters dictionary - """ - try: - # If workflow_data is a string, try to parse it as JSON - if isinstance(workflow_data, str): - try: - workflow_data = json.loads(workflow_data) - except json.JSONDecodeError: - logger.error("Failed to parse workflow data as JSON") - return {} - - # Now workflow_data should be a dictionary - if not isinstance(workflow_data, dict): - logger.error(f"Workflow data is not a dictionary: {type(workflow_data)}") - return {} - - # Initialize parameters dictionary with only the required fields - gen_params = { - "prompt": "", - "negative_prompt": "", - "steps": "", - "sampler": "", - "cfg_scale": "", - "seed": "", - "size": "", - "clip_skip": "" - } - - # First pass: find the KSampler node to get basic parameters and node references - # Store node references to follow for prompts - positive_ref = None - negative_ref = None - - for node_id, node_data in workflow_data.items(): - if not isinstance(node_data, dict): - continue - - # Extract node inputs if available - inputs = node_data.get("inputs", {}) - if not inputs: - continue - - # KSampler nodes contain most generation parameters and references to prompt nodes - if "KSampler" in node_data.get("class_type", ""): - # Extract basic sampling parameters - gen_params["steps"] = inputs.get("steps", "") - gen_params["cfg_scale"] = inputs.get("cfg", "") - gen_params["sampler"] = inputs.get("sampler_name", "") - gen_params["seed"] = inputs.get("seed", "") - if isinstance(gen_params["seed"], list) and len(gen_params["seed"]) > 1: - gen_params["seed"] = gen_params["seed"][1] # Use the actual value if it's a list - - # Get references to positive and negative prompt nodes - positive_ref = inputs.get("positive", "") - negative_ref = inputs.get("negative", "") - - # CLIPSetLastLayer contains clip_skip information - elif "CLIPSetLastLayer" in node_data.get("class_type", ""): - gen_params["clip_skip"] = inputs.get("stop_at_clip_layer", "") - if isinstance(gen_params["clip_skip"], int) and gen_params["clip_skip"] < 0: - # Convert negative layer index to positive clip skip value - gen_params["clip_skip"] = abs(gen_params["clip_skip"]) - - # Look for resolution information - elif "LatentImage" in node_data.get("class_type", "") or "Empty" in node_data.get("class_type", ""): - width = inputs.get("width", 0) - height = inputs.get("height", 0) - if width and height: - gen_params["size"] = f"{width}x{height}" - - # Some nodes have resolution as a string like "832x1216 (0.68)" - resolution = inputs.get("resolution", "") - if isinstance(resolution, str) and "x" in resolution: - gen_params["size"] = resolution.split(" ")[0] # Extract just the dimensions - - # Helper function to follow node references and extract text content - def get_text_from_node_ref(node_ref, workflow_data): - if not node_ref or not isinstance(node_ref, list) or len(node_ref) < 2: - return "" - - node_id, slot_idx = node_ref - - # If we can't find the node, return empty string - if node_id not in workflow_data: - return "" - - node = workflow_data[node_id] - inputs = node.get("inputs", {}) - - # Direct text input in CLIP Text Encode nodes - if "CLIPTextEncode" in node.get("class_type", ""): - text = inputs.get("text", "") - if isinstance(text, str): - return text - elif isinstance(text, list) and len(text) >= 2: - # If text is a reference to another node, follow it - return get_text_from_node_ref(text, workflow_data) - - # Other nodes might have text input with different field names - for field_name, field_value in inputs.items(): - if field_name == "text" and isinstance(field_value, str): - return field_value - elif isinstance(field_value, list) and len(field_value) >= 2 and field_name in ["text"]: - # If it's a reference to another node, follow it - return get_text_from_node_ref(field_value, workflow_data) - - return "" - - # Extract prompts by following references from KSampler node - if positive_ref: - gen_params["prompt"] = get_text_from_node_ref(positive_ref, workflow_data) - - if negative_ref: - gen_params["negative_prompt"] = get_text_from_node_ref(negative_ref, workflow_data) - - # Fallback: if we couldn't extract prompts via references, use the traditional method - if not gen_params["prompt"] or not gen_params["negative_prompt"]: - for node_id, node_data in workflow_data.items(): - if not isinstance(node_data, dict): - continue - - inputs = node_data.get("inputs", {}) - if not inputs: - continue - - if "CLIPTextEncode" in node_data.get("class_type", ""): - # Check for negative prompt nodes - title = node_data.get("_meta", {}).get("title", "").lower() - prompt_text = inputs.get("text", "") - - if isinstance(prompt_text, str): - if "negative" in title and not gen_params["negative_prompt"]: - gen_params["negative_prompt"] = prompt_text - elif prompt_text and not "negative" in title and not gen_params["prompt"]: - gen_params["prompt"] = prompt_text - - return gen_params - - except Exception as e: - logger.error(f"Error parsing ComfyUI workflow: {e}", exc_info=True) - return {} - - @staticmethod - def extract_comfyui_gen_params(image_path: str) -> Dict[str, Any]: - """ - Extract ComfyUI workflow data from PNG images and format for recipe data - Only extracts the specific generation parameters needed for recipes. - - Args: - image_path: Path to the ComfyUI-generated PNG image - - Returns: - Dictionary containing formatted generation parameters - """ - try: - # Check if the file exists and is accessible - if not os.path.exists(image_path): - logger.error(f"Image file not found: {image_path}") - return {} - - # Open the image to extract embedded workflow data - with Image.open(image_path) as img: - workflow_data = None - - # For PNG images, look for the ComfyUI workflow data in PNG chunks - if img.format == 'PNG': - # Check standard metadata fields that might contain workflow - if 'parameters' in img.info: - workflow_data = img.info['parameters'] - elif 'prompt' in img.info: - workflow_data = img.info['prompt'] - else: - # Look for other potential field names that might contain workflow data - for key in img.info: - if isinstance(key, str) and ('workflow' in key.lower() or 'comfy' in key.lower()): - workflow_data = img.info[key] - break - - # If no workflow data found in PNG chunks, try extract_image_metadata as fallback - if not workflow_data: - metadata = ExifUtils.extract_image_metadata(image_path) - if metadata and '{' in metadata and '}' in metadata: - # Try to extract JSON part - json_start = metadata.find('{') - json_end = metadata.rfind('}') + 1 - workflow_data = metadata[json_start:json_end] - - # Parse workflow data if found - if workflow_data: - return ExifUtils._parse_comfyui_workflow(workflow_data) - - return {} - - except Exception as e: - logger.error(f"Error extracting ComfyUI gen params from {image_path}: {e}", exc_info=True) - return {} \ No newline at end of file + return image_data, '.jpg' \ No newline at end of file From 435628ea592d702a944c6da362eceb65f2307048 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 2 Apr 2025 14:13:24 +0800 Subject: [PATCH 07/18] Refactor WorkflowParser by removing unused methods --- py/workflow/parser.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/py/workflow/parser.py b/py/workflow/parser.py index bfae55a2..0a5a02ef 100644 --- a/py/workflow/parser.py +++ b/py/workflow/parser.py @@ -111,33 +111,6 @@ class WorkflowParser: logger.warning("No sampler nodes found in workflow") return None - def collect_loras_from_model(self, model_input: List, workflow: Dict) -> str: - """Collect loras information from the model node chain""" - if not isinstance(model_input, list) or len(model_input) != 2: - return "" - - model_node_id, _ = model_input - # Convert node_id to string if it's an integer - if isinstance(model_node_id, int): - model_node_id = str(model_node_id) - - # Process the model node - model_result = self.process_node(model_node_id, workflow) - - # If this is a Lora Loader node, return the loras text - if model_result and isinstance(model_result, dict) and "loras" in model_result: - return model_result["loras"] - - # If not a lora loader, check the node's inputs for a model connection - node_data = workflow.get(model_node_id, {}) - inputs = node_data.get("inputs", {}) - - # If this node has a model input, follow that path - if "model" in inputs and isinstance(inputs["model"], list): - return self.collect_loras_from_model(inputs["model"], workflow) - - return "" - def parse_workflow(self, workflow_data: Union[str, Dict], output_path: Optional[str] = None) -> Dict: """ Parse the workflow and extract generation parameters From b508f51fcf5736244ce2a82f8f3964d8610b0090 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 2 Apr 2025 14:13:53 +0800 Subject: [PATCH 08/18] checkpoint --- __init__.py | 4 +- py/nodes/save_image.py | 287 +++++++++++++++++++-- refs/jpeg_civitai_exif_userComment_example | 7 +- 3 files changed, 275 insertions(+), 23 deletions(-) diff --git a/__init__.py b/__init__.py index c697e12f..008219c8 100644 --- a/__init__.py +++ b/__init__.py @@ -2,13 +2,13 @@ from .py.lora_manager import LoraManager from .py.nodes.lora_loader import LoraManagerLoader from .py.nodes.trigger_word_toggle import TriggerWordToggle from .py.nodes.lora_stacker import LoraStacker -# from .py.nodes.save_image import SaveImage +from .py.nodes.save_image import SaveImage NODE_CLASS_MAPPINGS = { LoraManagerLoader.NAME: LoraManagerLoader, TriggerWordToggle.NAME: TriggerWordToggle, LoraStacker.NAME: LoraStacker, - # SaveImage.NAME: SaveImage + SaveImage.NAME: SaveImage } WEB_DIRECTORY = "./web/comfyui" diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py index 09fec5e4..63c7b925 100644 --- a/py/nodes/save_image.py +++ b/py/nodes/save_image.py @@ -1,16 +1,43 @@ import json +import os +import asyncio +import re +import numpy as np +import time from server import PromptServer # type: ignore +import folder_paths # type: ignore +from ..services.lora_scanner import LoraScanner +from ..config import config +from ..workflow.parser import WorkflowParser +from PIL import Image, PngImagePlugin +import piexif +from io import BytesIO class SaveImage: NAME = "Save Image (LoraManager)" CATEGORY = "Lora Manager/utils" - DESCRIPTION = "Experimental node to display image preview and print prompt and extra_pnginfo" + DESCRIPTION = "Save images with embedded generation metadata in compatible format" + + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + self.type = "output" + self.prefix_append = "" + self.compress_level = 4 + self.counter = 0 @classmethod def INPUT_TYPES(cls): return { "required": { "image": ("IMAGE",), + "filename_prefix": ("STRING", {"default": "ComfyUI"}), + "file_format": (["png", "jpeg", "webp"],), + }, + "optional": { + "lossless_webp": ("BOOLEAN", {"default": True}), + "quality": ("INT", {"default": 100, "min": 1, "max": 100}), + "save_workflow_json": ("BOOLEAN", {"default": False}), + "add_counter_to_filename": ("BOOLEAN", {"default": True}), }, "hidden": { "prompt": "PROMPT", @@ -18,24 +45,252 @@ class SaveImage: }, } - RETURN_TYPES = ("IMAGE",) - RETURN_NAMES = ("image",) + RETURN_TYPES = ("IMAGE", "STRING") + RETURN_NAMES = ("image", "filename") FUNCTION = "process_image" + OUTPUT_NODE = True - def process_image(self, image, prompt=None, extra_pnginfo=None): - # Print the prompt information - print("SaveImage Node - Prompt:") + async def get_lora_hash(self, lora_name): + """Get the lora hash from cache""" + scanner = await LoraScanner.get_instance() + cache = await scanner.get_cached_data() + + for item in cache.raw_data: + if item.get('file_name') == lora_name: + return item.get('sha256') + return None + + async def format_metadata(self, parsed_workflow): + """Format metadata in the requested format similar to userComment example""" + if not parsed_workflow: + return "" + + # Extract the prompt and negative prompt + prompt = parsed_workflow.get('prompt', '') + negative_prompt = parsed_workflow.get('negative_prompt', '') + + # Extract loras from the prompt if present + loras_text = parsed_workflow.get('loras', '') + lora_hashes = {} + + # If loras are found, add them on a new line after the prompt + if loras_text: + prompt_with_loras = f"{prompt}\n{loras_text}" + + # Extract lora names from the format + lora_matches = re.findall(r']+)>', loras_text) + + # Get hash for each lora + for lora_name, strength in lora_matches: + hash_value = await self.get_lora_hash(lora_name) + if hash_value: + lora_hashes[lora_name] = hash_value + else: + prompt_with_loras = prompt + + # Format the first part (prompt and loras) + metadata_parts = [prompt_with_loras] + + # Add negative prompt + if negative_prompt: + metadata_parts.append(f"Negative prompt: {negative_prompt}") + + # Format the second part (generation parameters) + params = [] + + # Add standard parameters in the correct order + if 'steps' in parsed_workflow: + params.append(f"Steps: {parsed_workflow.get('steps')}") + + if 'sampler' in parsed_workflow: + sampler = parsed_workflow.get('sampler') + # Convert ComfyUI sampler names to user-friendly names + sampler_mapping = { + 'euler': 'Euler', + 'euler_ancestral': 'Euler a', + 'dpm_2': 'DPM2', + 'dpm_2_ancestral': 'DPM2 a', + 'heun': 'Heun', + 'dpm_fast': 'DPM fast', + 'dpm_adaptive': 'DPM adaptive', + 'lms': 'LMS', + 'dpmpp_2s_ancestral': 'DPM++ 2S a', + 'dpmpp_sde': 'DPM++ SDE', + 'dpmpp_sde_gpu': 'DPM++ SDE', + 'dpmpp_2m': 'DPM++ 2M', + 'dpmpp_2m_sde': 'DPM++ 2M SDE', + 'dpmpp_2m_sde_gpu': 'DPM++ 2M SDE', + 'ddim': 'DDIM' + } + sampler_name = sampler_mapping.get(sampler, sampler) + params.append(f"Sampler: {sampler_name}") + + if 'scheduler' in parsed_workflow: + scheduler = parsed_workflow.get('scheduler') + scheduler_mapping = { + 'normal': 'Simple', + 'karras': 'Karras', + 'exponential': 'Exponential', + 'sgm_uniform': 'SGM Uniform', + 'sgm_quadratic': 'SGM Quadratic' + } + scheduler_name = scheduler_mapping.get(scheduler, scheduler) + params.append(f"Schedule type: {scheduler_name}") + + # CFG scale (cfg in parsed_workflow) + if 'cfg_scale' in parsed_workflow: + params.append(f"CFG scale: {parsed_workflow.get('cfg_scale')}") + elif 'cfg' in parsed_workflow: + params.append(f"CFG scale: {parsed_workflow.get('cfg')}") + + # Seed + if 'seed' in parsed_workflow: + params.append(f"Seed: {parsed_workflow.get('seed')}") + + # Size + if 'size' in parsed_workflow: + params.append(f"Size: {parsed_workflow.get('size')}") + + # Model info + if 'checkpoint' in parsed_workflow: + # Extract basename without path + checkpoint = os.path.basename(parsed_workflow.get('checkpoint', '')) + # Remove extension if present + checkpoint = os.path.splitext(checkpoint)[0] + params.append(f"Model: {checkpoint}") + + # Add LoRA hashes if available + if lora_hashes: + lora_hash_parts = [] + for lora_name, hash_value in lora_hashes.items(): + lora_hash_parts.append(f"{lora_name}: {hash_value}") + + if lora_hash_parts: + params.append(f"Lora hashes: \"{', '.join(lora_hash_parts)}\"") + + # Combine all parameters with commas + metadata_parts.append(", ".join(params)) + + # Join all parts with a new line + return "\n".join(metadata_parts) + + def save_images(self, images, filename_prefix, file_format, prompt=None, extra_pnginfo=None, + lossless_webp=True, quality=100, save_workflow_json=False, add_counter_to_filename=True): + """Save images with metadata""" + results = [] + + # Parse the workflow using the WorkflowParser + parser = WorkflowParser() if prompt: - print(json.dumps(prompt, indent=2)) + parsed_workflow = parser.parse_workflow(prompt) else: - print("No prompt information available") + parsed_workflow = {} + + # Get or create metadata asynchronously + metadata = asyncio.run(self.format_metadata(parsed_workflow)) - # Print the extra_pnginfo - print("\nSaveImage Node - Extra PNG Info:") - if extra_pnginfo: - print(json.dumps(extra_pnginfo, indent=2)) - else: - print("No extra PNG info available") + # Process each image + for i, image in enumerate(images): + # Convert the tensor image to numpy array + img = 255. * image.cpu().numpy() + img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8)) + + # Generate filename with counter if needed + if add_counter_to_filename: + filename = f"{filename_prefix}_{self.counter:05d}" + self.counter += 1 + else: + filename = f"{filename_prefix}" + + # Set file extension and prepare saving parameters + if file_format == "png": + filename += ".png" + file_extension = ".png" + save_kwargs = {"optimize": True, "compress_level": self.compress_level} + pnginfo = PngImagePlugin.PngInfo() + elif file_format == "jpeg": + filename += ".jpg" + file_extension = ".jpg" + save_kwargs = {"quality": quality, "optimize": True} + elif file_format == "webp": + filename += ".webp" + file_extension = ".webp" + save_kwargs = {"quality": quality, "lossless": lossless_webp} + + # Full save path + file_path = os.path.join(self.output_dir, filename) + + # Save the image with metadata + try: + if file_format == "png": + if metadata: + pnginfo.add_text("parameters", metadata) + if save_workflow_json and extra_pnginfo is not None: + workflow_json = json.dumps(extra_pnginfo) + pnginfo.add_text("workflow", workflow_json) + save_kwargs["pnginfo"] = pnginfo + img.save(file_path, format="PNG", **save_kwargs) + elif file_format == "jpeg": + # For JPEG, use piexif + if metadata: + try: + exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}} + exif_bytes = piexif.dump(exif_dict) + save_kwargs["exif"] = exif_bytes + except Exception as e: + print(f"Error adding EXIF data: {e}") + img.save(file_path, format="JPEG", **save_kwargs) + elif file_format == "webp": + # For WebP, also use piexif for metadata + if metadata: + try: + exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}} + exif_bytes = piexif.dump(exif_dict) + save_kwargs["exif"] = exif_bytes + except Exception as e: + print(f"Error adding EXIF data: {e}") + img.save(file_path, format="WEBP", **save_kwargs) + + results.append({ + "filename": filename, + "subfolder": "", + "type": self.type + }) + + # Notify UI about saved image + PromptServer.instance.send_sync("image", { + "filename": filename, + "subfolder": "", + "type": self.type, + }) + + except Exception as e: + print(f"Error saving image: {e}") - # Return the image unchanged - return (image,) + return results + + def process_image(self, image, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None, + lossless_webp=True, quality=100, save_workflow_json=False, add_counter_to_filename=True): + """Process and save image with metadata""" + # Make sure the output directory exists + os.makedirs(self.output_dir, exist_ok=True) + + # Convert single image to list for consistent processing + images = [image[0]] if len(image.shape) == 3 else [img for img in image] + + # Save all images + results = self.save_images( + images, + filename_prefix, + file_format, + prompt, + extra_pnginfo, + lossless_webp, + quality, + save_workflow_json, + add_counter_to_filename + ) + + # Return the first saved filename and the original image + filename = results[0]["filename"] if results else "" + return (image, filename) \ No newline at end of file diff --git a/refs/jpeg_civitai_exif_userComment_example b/refs/jpeg_civitai_exif_userComment_example index b9fc40ab..bf5935fa 100644 --- a/refs/jpeg_civitai_exif_userComment_example +++ b/refs/jpeg_civitai_exif_userComment_example @@ -2,13 +2,10 @@ a dynamic and dramatic digital artwork featuring a stylized anthropomorphic whit Negative prompt: Steps: 30, Sampler: Undefined, CFG scale: 3.5, Seed: 90300501, Size: 832x1216, Clip skip: 2, Created Date: 2025-03-05T13:51:18.1770234Z, Civitai resources: [{"type":"checkpoint","modelVersionId":691639,"modelName":"FLUX","modelVersionName":"Dev"},{"type":"lora","weight":0.4,"modelVersionId":1202162,"modelName":"Velvet\u0027s Mythic Fantasy Styles | Flux \u002B Pony \u002B illustrious","modelVersionName":"Flux Gothic Lines"},{"type":"lora","weight":0.8,"modelVersionId":1470588,"modelName":"Velvet\u0027s Mythic Fantasy Styles | Flux \u002B Pony \u002B illustrious","modelVersionName":"Flux Retro"},{"type":"lora","weight":0.75,"modelVersionId":746484,"modelName":"Elden Ring - Yoshitaka Amano","modelVersionName":"V1"},{"type":"lora","weight":0.2,"modelVersionId":914935,"modelName":"Ink-style","modelVersionName":"ink-dynamic"},{"type":"lora","weight":0.2,"modelVersionId":1189379,"modelName":"Painterly Fantasy by ChronoKnight - [FLUX \u0026 IL]","modelVersionName":"FLUX"},{"type":"lora","weight":0.2,"modelVersionId":757030,"modelName":"Mezzotint Artstyle for Flux - by Ethanar","modelVersionName":"V1"}], Civitai metadata: {} -, masterpiece, best quality, good quality, very aesthetic, absurdres, newest, 8K, depth of field, focused subject, dynamic angle, dutch angle, from below, epic half body portrait, gritty, wabi sabi, looking at viewer, woman is a geisha, parted lips, -holographic skin, holofoil glitter, faint, glowing, ethereal, neon hair, glowing hair, otherworldly glow, she is dangerous, - - - +holographic skin, holofoil glitter, faint, glowing, ethereal, neon hair, glowing hair, otherworldly glow, she is dangerous +, , , Negative prompt: score_6, score_5, score_4, bad quality, worst quality, worst detail, sketch, censorship, furry, window, headphones, Steps: 30, Sampler: Euler a, Schedule type: Simple, CFG scale: 7, Seed: 1405717592, Size: 832x1216, Model hash: 1ad6ca7f70, Model: waiNSFWIllustrious_v100, Denoising strength: 0.35, Hires CFG Scale: 5, Hires upscale: 1.3, Hires steps: 20, Hires upscaler: 4x-AnimeSharp, Lora hashes: "ck-shadow-circuit-IL: 88e247aa8c3d, ck-nc-cyberpunk-IL-000011: 935e6755554c, ck-neon-retrowave-IL: edafb9df7da1, ck-yoneyama-mai-IL-000014: 1b9305692a2e", Version: f2.0.1v1.10.1-1.10.1, Diffusion in Low Bits: Automatic (fp16 LoRA) From aec218ba00674096723ba1df2f975056b783aa24 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 2 Apr 2025 15:08:36 +0800 Subject: [PATCH 09/18] Enhance SaveImage class with filename formatting and multiple image support - Updated the INPUT_TYPES to accept multiple images and modified the corresponding processing methods. - Introduced a new format_filename method to handle dynamic filename generation using metadata patterns. - Replaced save_workflow_json with embed_workflow for better clarity in saving workflow metadata. - Improved directory handling and filename generation logic to ensure proper file saving. --- py/nodes/save_image.py | 136 ++++++++++++++++++++++++++++++----------- 1 file changed, 101 insertions(+), 35 deletions(-) diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py index 63c7b925..b9e75a5a 100644 --- a/py/nodes/save_image.py +++ b/py/nodes/save_image.py @@ -3,11 +3,8 @@ import os import asyncio import re import numpy as np -import time -from server import PromptServer # type: ignore import folder_paths # type: ignore from ..services.lora_scanner import LoraScanner -from ..config import config from ..workflow.parser import WorkflowParser from PIL import Image, PngImagePlugin import piexif @@ -25,18 +22,21 @@ class SaveImage: self.compress_level = 4 self.counter = 0 + # Add pattern format regex for filename substitution + pattern_format = re.compile(r"(%[^%]+%)") + @classmethod def INPUT_TYPES(cls): return { "required": { - "image": ("IMAGE",), + "images": ("IMAGE",), "filename_prefix": ("STRING", {"default": "ComfyUI"}), "file_format": (["png", "jpeg", "webp"],), }, "optional": { "lossless_webp": ("BOOLEAN", {"default": True}), "quality": ("INT", {"default": 100, "min": 1, "max": 100}), - "save_workflow_json": ("BOOLEAN", {"default": False}), + "embed_workflow": ("BOOLEAN", {"default": False}), "add_counter_to_filename": ("BOOLEAN", {"default": True}), }, "hidden": { @@ -45,8 +45,8 @@ class SaveImage: }, } - RETURN_TYPES = ("IMAGE", "STRING") - RETURN_NAMES = ("image", "filename") + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("images",) FUNCTION = "process_image" OUTPUT_NODE = True @@ -174,8 +174,73 @@ class SaveImage: # Join all parts with a new line return "\n".join(metadata_parts) + # credit to nkchocoai + # Add format_filename method to handle pattern substitution + def format_filename(self, filename, parsed_workflow): + """Format filename with metadata values""" + if not parsed_workflow: + return filename + + result = re.findall(self.pattern_format, filename) + for segment in result: + parts = segment.replace("%", "").split(":") + key = parts[0] + + if key == "seed" and 'seed' in parsed_workflow: + filename = filename.replace(segment, str(parsed_workflow.get('seed', ''))) + elif key == "width" and 'size' in parsed_workflow: + size = parsed_workflow.get('size', 'x') + w = size.split('x')[0] if isinstance(size, str) else size[0] + filename = filename.replace(segment, str(w)) + elif key == "height" and 'size' in parsed_workflow: + size = parsed_workflow.get('size', 'x') + h = size.split('x')[1] if isinstance(size, str) else size[1] + filename = filename.replace(segment, str(h)) + elif key == "pprompt" and 'prompt' in parsed_workflow: + prompt = parsed_workflow.get('prompt', '').replace("\n", " ") + if len(parts) >= 2: + length = int(parts[1]) + prompt = prompt[:length] + filename = filename.replace(segment, prompt.strip()) + elif key == "nprompt" and 'negative_prompt' in parsed_workflow: + prompt = parsed_workflow.get('negative_prompt', '').replace("\n", " ") + if len(parts) >= 2: + length = int(parts[1]) + prompt = prompt[:length] + filename = filename.replace(segment, prompt.strip()) + elif key == "model" and 'checkpoint' in parsed_workflow: + model = parsed_workflow.get('checkpoint', '') + model = os.path.splitext(os.path.basename(model))[0] + if len(parts) >= 2: + length = int(parts[1]) + model = model[:length] + filename = filename.replace(segment, model) + elif key == "date": + from datetime import datetime + now = datetime.now() + date_table = { + "yyyy": str(now.year), + "MM": str(now.month).zfill(2), + "dd": str(now.day).zfill(2), + "hh": str(now.hour).zfill(2), + "mm": str(now.minute).zfill(2), + "ss": str(now.second).zfill(2), + } + if len(parts) >= 2: + date_format = parts[1] + for k, v in date_table.items(): + date_format = date_format.replace(k, v) + filename = filename.replace(segment, date_format) + else: + date_format = "yyyyMMddhhmmss" + for k, v in date_table.items(): + date_format = date_format.replace(k, v) + filename = filename.replace(segment, date_format) + + return filename + def save_images(self, images, filename_prefix, file_format, prompt=None, extra_pnginfo=None, - lossless_webp=True, quality=100, save_workflow_json=False, add_counter_to_filename=True): + lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True): """Save images with metadata""" results = [] @@ -189,44 +254,54 @@ class SaveImage: # Get or create metadata asynchronously metadata = asyncio.run(self.format_metadata(parsed_workflow)) + # Process filename_prefix with pattern substitution + filename_prefix = self.format_filename(filename_prefix, parsed_workflow) + # Process each image for i, image in enumerate(images): # Convert the tensor image to numpy array img = 255. * image.cpu().numpy() img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8)) + # Create directory if filename_prefix contains path separators + output_path = os.path.join(self.output_dir, filename_prefix) + if not os.path.exists(os.path.dirname(output_path)): + os.makedirs(os.path.dirname(output_path), exist_ok=True) + + # Use folder_paths.get_save_image_path for better counter handling + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( + filename_prefix, self.output_dir, img.width, img.height + ) + # Generate filename with counter if needed if add_counter_to_filename: - filename = f"{filename_prefix}_{self.counter:05d}" - self.counter += 1 - else: - filename = f"{filename_prefix}" - + filename += f"_{counter:05}" + # Set file extension and prepare saving parameters if file_format == "png": - filename += ".png" + file = filename + ".png" file_extension = ".png" save_kwargs = {"optimize": True, "compress_level": self.compress_level} pnginfo = PngImagePlugin.PngInfo() elif file_format == "jpeg": - filename += ".jpg" + file = filename + ".jpg" file_extension = ".jpg" save_kwargs = {"quality": quality, "optimize": True} elif file_format == "webp": - filename += ".webp" + file = filename + ".webp" file_extension = ".webp" save_kwargs = {"quality": quality, "lossless": lossless_webp} # Full save path - file_path = os.path.join(self.output_dir, filename) + file_path = os.path.join(full_output_folder, file) # Save the image with metadata try: if file_format == "png": if metadata: pnginfo.add_text("parameters", metadata) - if save_workflow_json and extra_pnginfo is not None: - workflow_json = json.dumps(extra_pnginfo) + if embed_workflow and extra_pnginfo is not None: + workflow_json = json.dumps(extra_pnginfo["workflow"]) pnginfo.add_text("workflow", workflow_json) save_kwargs["pnginfo"] = pnginfo img.save(file_path, format="PNG", **save_kwargs) @@ -252,31 +327,24 @@ class SaveImage: img.save(file_path, format="WEBP", **save_kwargs) results.append({ - "filename": filename, - "subfolder": "", + "filename": file, + "subfolder": subfolder, "type": self.type }) - # Notify UI about saved image - PromptServer.instance.send_sync("image", { - "filename": filename, - "subfolder": "", - "type": self.type, - }) - except Exception as e: print(f"Error saving image: {e}") return results - def process_image(self, image, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None, - lossless_webp=True, quality=100, save_workflow_json=False, add_counter_to_filename=True): + def process_image(self, images, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None, + lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True): """Process and save image with metadata""" # Make sure the output directory exists os.makedirs(self.output_dir, exist_ok=True) # Convert single image to list for consistent processing - images = [image[0]] if len(image.shape) == 3 else [img for img in image] + images = [images[0]] if len(images.shape) == 3 else [img for img in images] # Save all images results = self.save_images( @@ -287,10 +355,8 @@ class SaveImage: extra_pnginfo, lossless_webp, quality, - save_workflow_json, + embed_workflow, add_counter_to_filename ) - # Return the first saved filename and the original image - filename = results[0]["filename"] if results else "" - return (image, filename) \ No newline at end of file + return (images,) \ No newline at end of file From 234c942f347626c59ec42fe9812ca22c74736fb6 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 2 Apr 2025 17:01:10 +0800 Subject: [PATCH 10/18] Refactor transform functions and update node mappers - Moved and redefined transform functions for KSampler, EmptyLatentImage, CLIPTextEncode, and FluxGuidance to improve organization and maintainability. - Updated NODE_MAPPERS to include new input tracking for clip_skip in KSampler and added new transform functions for LatentUpscale and CLIPSetLastLayer. - Enhanced the transform_sampler_custom_advanced function to handle clip_skip extraction from model inputs. --- py/workflow/ext/comfyui_core.py | 128 +++++++++++++++++++++++++++++++- py/workflow/mappers.py | 95 ++---------------------- 2 files changed, 130 insertions(+), 93 deletions(-) diff --git a/py/workflow/ext/comfyui_core.py b/py/workflow/ext/comfyui_core.py index 125c29f0..5a116d59 100644 --- a/py/workflow/ext/comfyui_core.py +++ b/py/workflow/ext/comfyui_core.py @@ -94,6 +94,7 @@ def transform_sampler_custom_advanced(inputs: Dict) -> Dict: if "model" in guider and isinstance(guider["model"], dict): result["checkpoint"] = guider["model"].get("checkpoint", "") result["loras"] = guider["model"].get("loras", "") + result["clip_skip"] = str(int(guider["model"].get("clip_skip", "-1")) * -1) # Extract dimensions from latent_image if "latent_image" in inputs and isinstance(inputs["latent_image"], dict): @@ -107,6 +108,73 @@ def transform_sampler_custom_advanced(inputs: Dict) -> Dict: return result +def transform_ksampler(inputs: Dict) -> Dict: + """Transform function for KSampler nodes""" + result = { + "seed": str(inputs.get("seed", "")), + "steps": str(inputs.get("steps", "")), + "cfg": str(inputs.get("cfg", "")), + "sampler": inputs.get("sampler_name", ""), + "scheduler": inputs.get("scheduler", ""), + } + + # Process positive prompt + if "positive" in inputs: + result["prompt"] = inputs["positive"] + + # Process negative prompt + if "negative" in inputs: + result["negative_prompt"] = inputs["negative"] + + # Get dimensions from latent image + if "latent_image" in inputs and isinstance(inputs["latent_image"], dict): + width = inputs["latent_image"].get("width", 0) + height = inputs["latent_image"].get("height", 0) + if width and height: + result["size"] = f"{width}x{height}" + + # Add clip_skip if present + if "clip_skip" in inputs: + result["clip_skip"] = str(inputs.get("clip_skip", "")) + + # Add guidance if present + if "guidance" in inputs: + result["guidance"] = str(inputs.get("guidance", "")) + + # Add model if present + if "model" in inputs: + result["checkpoint"] = inputs.get("model", {}).get("checkpoint", "") + result["loras"] = inputs.get("model", {}).get("loras", "") + result["clip_skip"] = str(inputs.get("model", {}).get("clip_skip", -1) * -1) + + return result + +def transform_empty_latent(inputs: Dict) -> Dict: + """Transform function for EmptyLatentImage nodes""" + width = inputs.get("width", 0) + height = inputs.get("height", 0) + return {"width": width, "height": height, "size": f"{width}x{height}"} + +def transform_clip_text(inputs: Dict) -> Any: + """Transform function for CLIPTextEncode nodes""" + return inputs.get("text", "") + +def transform_flux_guidance(inputs: Dict) -> Dict: + """Transform function for FluxGuidance nodes""" + result = {} + + if "guidance" in inputs: + result["guidance"] = inputs["guidance"] + + if "conditioning" in inputs: + conditioning = inputs["conditioning"] + if isinstance(conditioning, str): + result["prompt"] = conditioning + else: + result["prompt"] = "Unknown prompt" + + return result + def transform_unet_loader(inputs: Dict) -> Dict: """Transform function for UNETLoader node""" unet_name = inputs.get("unet_name", "") @@ -117,6 +185,27 @@ def transform_checkpoint_loader(inputs: Dict) -> Dict: ckpt_name = inputs.get("ckpt_name", "") return {"checkpoint": ckpt_name} if ckpt_name else {} +def transform_latent_upscale_by(inputs: Dict) -> Dict: + """Transform function for LatentUpscaleBy node""" + result = {} + + width = inputs["samples"].get("width", 0) * inputs["scale_by"] + height = inputs["samples"].get("height", 0) * inputs["scale_by"] + result["width"] = width + result["height"] = height + result["size"] = f"{width}x{height}" + + return result + +def transform_clip_set_last_layer(inputs: Dict) -> Dict: + """Transform function for CLIPSetLastLayer node""" + result = {} + + if "stop_at_clip_layer" in inputs: + result["clip_skip"] = inputs["stop_at_clip_layer"] + + return result + # ============================================================================= # Node Mapper Definitions # ============================================================================= @@ -128,6 +217,31 @@ NODE_MAPPERS_EXT = { "inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"], "transform_func": transform_sampler_custom_advanced }, + "KSampler": { + "inputs_to_track": [ + "seed", "steps", "cfg", "sampler_name", "scheduler", + "denoise", "positive", "negative", "latent_image", + "model", "clip_skip" + ], + "transform_func": transform_ksampler + }, + # ComfyUI core nodes + "EmptyLatentImage": { + "inputs_to_track": ["width", "height", "batch_size"], + "transform_func": transform_empty_latent + }, + "EmptySD3LatentImage": { + "inputs_to_track": ["width", "height", "batch_size"], + "transform_func": transform_empty_latent + }, + "CLIPTextEncode": { + "inputs_to_track": ["text", "clip"], + "transform_func": transform_clip_text + }, + "FluxGuidance": { + "inputs_to_track": ["guidance", "conditioning"], + "transform_func": transform_flux_guidance + }, "RandomNoise": { "inputs_to_track": ["noise_seed"], "transform_func": transform_random_noise @@ -156,8 +270,16 @@ NODE_MAPPERS_EXT = { "inputs_to_track": ["ckpt_name"], "transform_func": transform_checkpoint_loader }, - "CheckpointLoader": { - "inputs_to_track": ["ckpt_name"], - "transform_func": transform_checkpoint_loader + "LatentUpscale": { + "inputs_to_track": ["width", "height"], + "transform_func": transform_empty_latent + }, + "LatentUpscaleBy": { + "inputs_to_track": ["samples", "scale_by"], + "transform_func": transform_latent_upscale_by + }, + "CLIPSetLastLayer": { + "inputs_to_track": ["clip", "stop_at_clip_layer"], + "transform_func": transform_clip_set_last_layer } } \ No newline at end of file diff --git a/py/workflow/mappers.py b/py/workflow/mappers.py index 156afb27..22811ec6 100644 --- a/py/workflow/mappers.py +++ b/py/workflow/mappers.py @@ -97,55 +97,7 @@ def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'Workflo # Transform Functions # ============================================================================= -def transform_ksampler(inputs: Dict) -> Dict: - """Transform function for KSampler nodes""" - result = { - "seed": str(inputs.get("seed", "")), - "steps": str(inputs.get("steps", "")), - "cfg": str(inputs.get("cfg", "")), - "sampler": inputs.get("sampler_name", ""), - "scheduler": inputs.get("scheduler", ""), - } - - # Process positive prompt - if "positive" in inputs: - result["prompt"] = inputs["positive"] - - # Process negative prompt - if "negative" in inputs: - result["negative_prompt"] = inputs["negative"] - - # Get dimensions from latent image - if "latent_image" in inputs and isinstance(inputs["latent_image"], dict): - width = inputs["latent_image"].get("width", 0) - height = inputs["latent_image"].get("height", 0) - if width and height: - result["size"] = f"{width}x{height}" - - # Add clip_skip if present - if "clip_skip" in inputs: - result["clip_skip"] = str(inputs.get("clip_skip", "")) - # Add guidance if present - if "guidance" in inputs: - result["guidance"] = str(inputs.get("guidance", "")) - - # Add model if present - if "model" in inputs: - result["checkpoint"] = inputs.get("model", {}).get("checkpoint", "") - result["loras"] = inputs.get("model", {}).get("loras", "") - - return result - -def transform_empty_latent(inputs: Dict) -> Dict: - """Transform function for EmptyLatentImage nodes""" - width = inputs.get("width", 0) - height = inputs.get("height", 0) - return {"width": width, "height": height, "size": f"{width}x{height}"} - -def transform_clip_text(inputs: Dict) -> Any: - """Transform function for CLIPTextEncode nodes""" - return inputs.get("text", "") def transform_lora_loader(inputs: Dict) -> Dict: """Transform function for LoraLoader nodes""" @@ -181,6 +133,9 @@ def transform_lora_loader(inputs: Dict) -> Dict: "checkpoint": inputs.get("model", {}).get("checkpoint", ""), "loras": " ".join(lora_texts) } + + if "clip" in inputs: + result["clip_skip"] = inputs["clip"].get("clip_skip", "-1") return result @@ -241,56 +196,16 @@ def transform_trigger_word_toggle(inputs: Dict) -> str: return ", ".join(active_words) -def transform_flux_guidance(inputs: Dict) -> Dict: - """Transform function for FluxGuidance nodes""" - result = {} - - if "guidance" in inputs: - result["guidance"] = inputs["guidance"] - - if "conditioning" in inputs: - conditioning = inputs["conditioning"] - if isinstance(conditioning, str): - result["prompt"] = conditioning - else: - result["prompt"] = "Unknown prompt" - - return result - # ============================================================================= # Node Mapper Definitions # ============================================================================= # Central definition of all supported node types and their configurations NODE_MAPPERS = { - # ComfyUI core nodes - "KSampler": { - "inputs_to_track": [ - "seed", "steps", "cfg", "sampler_name", "scheduler", - "denoise", "positive", "negative", "latent_image", - "model", "clip_skip" - ], - "transform_func": transform_ksampler - }, - "EmptyLatentImage": { - "inputs_to_track": ["width", "height", "batch_size"], - "transform_func": transform_empty_latent - }, - "EmptySD3LatentImage": { - "inputs_to_track": ["width", "height", "batch_size"], - "transform_func": transform_empty_latent - }, - "CLIPTextEncode": { - "inputs_to_track": ["text", "clip"], - "transform_func": transform_clip_text - }, - "FluxGuidance": { - "inputs_to_track": ["guidance", "conditioning"], - "transform_func": transform_flux_guidance - }, + # LoraManager nodes "Lora Loader (LoraManager)": { - "inputs_to_track": ["model", "loras", "lora_stack"], + "inputs_to_track": ["model", "clip", "loras", "lora_stack"], "transform_func": transform_lora_loader }, "Lora Stacker (LoraManager)": { From 0499ca13005d617ef3aa5136bde586b49829eeaf Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 2 Apr 2025 17:02:11 +0800 Subject: [PATCH 11/18] Update process_node function to ignore type checking - Added a type: ignore comment to the process_node function to suppress type checking errors. - Removed the README.md file as it is no longer needed. --- py/workflow/README.md | 149 ----------------------------------------- py/workflow/mappers.py | 2 +- 2 files changed, 1 insertion(+), 150 deletions(-) delete mode 100644 py/workflow/README.md diff --git a/py/workflow/README.md b/py/workflow/README.md deleted file mode 100644 index 0cb78d54..00000000 --- a/py/workflow/README.md +++ /dev/null @@ -1,149 +0,0 @@ -# ComfyUI Workflow Parser - -本模块提供了一个灵活的解析系统,可以从ComfyUI工作流中提取生成参数和LoRA信息。 - -## 设计理念 - -工作流解析器基于以下设计原则: - -1. **模块化**: 每种节点类型由独立的mapper处理 -2. **可扩展性**: 通过扩展系统轻松添加新的节点类型支持 -3. **回溯**: 通过工作流图的模型输入路径跟踪LoRA节点 -4. **灵活性**: 适应不同的ComfyUI工作流结构 - -## 主要组件 - -### 1. NodeMapper - -`NodeMapper`是所有节点映射器的基类,定义了如何从工作流中提取节点信息: - -```python -class NodeMapper: - def __init__(self, node_type: str, inputs_to_track: List[str]): - self.node_type = node_type - self.inputs_to_track = inputs_to_track - - def process(self, node_id: str, node_data: Dict, workflow: Dict, parser) -> Any: - # 处理节点的通用逻辑 - ... - - def transform(self, inputs: Dict) -> Any: - # 由子类覆盖以提供特定转换 - return inputs -``` - -### 2. WorkflowParser - -主要解析类,通过跟踪工作流图来提取参数: - -```python -parser = WorkflowParser() -result = parser.parse_workflow("workflow.json") -``` - -### 3. 扩展系统 - -允许通过添加新的自定义mapper来扩展支持的节点类型: - -```python -# 在py/workflow/ext/中添加自定义mapper模块 -load_extensions() # 自动加载所有扩展 -``` - -## 使用方法 - -### 基本用法 - -```python -from workflow.parser import parse_workflow - -# 解析工作流并保存结果 -result = parse_workflow("workflow.json", "output.json") -``` - -### 自定义解析 - -```python -from workflow.parser import WorkflowParser -from workflow.mappers import register_mapper, load_extensions - -# 加载扩展 -load_extensions() - -# 创建解析器 -parser = WorkflowParser(load_extensions_on_init=False) # 不自动加载扩展 - -# 解析工作流 -result = parser.parse_workflow(workflow_data) -``` - -## 扩展系统 - -### 添加新的节点映射器 - -在`py/workflow/ext/`目录中创建Python文件,定义从`NodeMapper`继承的类: - -```python -# example_mapper.py -from ..mappers import NodeMapper - -class MyCustomNodeMapper(NodeMapper): - def __init__(self): - super().__init__( - node_type="MyCustomNode", # 节点的class_type - inputs_to_track=["param1", "param2"] # 要提取的参数 - ) - - def transform(self, inputs: Dict) -> Any: - # 处理提取的参数 - return { - "custom_param": inputs.get("param1", "default") - } -``` - -扩展系统会自动加载和注册这些映射器。 - -### LoraManager节点说明 - -LoraManager相关节点的处理方式: - -1. **Lora Loader**: 处理`loras`数组,过滤出`active=true`的条目,和`lora_stack`输入 -2. **Lora Stacker**: 处理`loras`数组和已有的`lora_stack`,构建叠加的LoRA -3. **TriggerWord Toggle**: 从`toggle_trigger_words`中提取`active=true`的条目 - -## 输出格式 - -解析器生成的输出格式如下: - -```json -{ - "gen_params": { - "prompt": "...", - "negative_prompt": "", - "steps": "25", - "sampler": "dpmpp_2m", - "scheduler": "beta", - "cfg": "1", - "seed": "48", - "guidance": 3.5, - "size": "896x1152", - "clip_skip": "2" - }, - "loras": " " -} -``` - -## 高级用法 - -### 直接注册映射器 - -```python -from workflow.mappers import register_mapper -from workflow.mappers import NodeMapper - -# 创建自定义映射器 -class CustomMapper(NodeMapper): - # ...实现映射器 - -# 注册映射器 -register_mapper(CustomMapper()) \ No newline at end of file diff --git a/py/workflow/mappers.py b/py/workflow/mappers.py index 22811ec6..33bc9e9b 100644 --- a/py/workflow/mappers.py +++ b/py/workflow/mappers.py @@ -46,7 +46,7 @@ def get_all_mappers() -> Dict[str, Dict]: # Node Processing Function # ============================================================================= -def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'WorkflowParser') -> Any: +def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'WorkflowParser') -> Any: # type: ignore """Process a node using its mapper and extract relevant information""" node_type = node_data.get("class_type") mapper = get_mapper(node_type) From 73686d4146ecbb64ac8eee90581d244db2caf5d5 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Wed, 2 Apr 2025 17:37:16 +0800 Subject: [PATCH 12/18] Enhance modal and settings functionality with default LoRA root selection - Updated modal styles for improved layout and added select control for default LoRA root. - Modified DownloadManager, ImportManager, MoveManager, and SettingsManager to retrieve and set the default LoRA root from storage. - Introduced asynchronous loading of LoRA roots in SettingsManager to dynamically populate the select options. - Improved user experience by allowing users to set a default LoRA root for downloads, imports, and moves. --- static/css/components/modal.css | 40 +++++++++++++++++++--- static/js/managers/DownloadManager.js | 8 ++++- static/js/managers/ImportManager.js | 7 ++++ static/js/managers/MoveManager.js | 7 ++++ static/js/managers/SettingsManager.js | 48 +++++++++++++++++++++++++-- templates/components/modals.html | 20 +++++++++++ 6 files changed, 123 insertions(+), 7 deletions(-) diff --git a/static/css/components/modal.css b/static/css/components/modal.css index 141713ef..b23264b1 100644 --- a/static/css/components/modal.css +++ b/static/css/components/modal.css @@ -341,8 +341,7 @@ body.modal-open { .setting-item { display: flex; - justify-content: space-between; - align-items: flex-start; + flex-direction: column; margin-bottom: var(--space-2); padding: var(--space-1); border-radius: var(--border-radius-xs); @@ -357,7 +356,8 @@ body.modal-open { } .setting-info { - flex: 1; + margin-bottom: var(--space-1); + width: 100%; } .setting-info label { @@ -367,7 +367,39 @@ body.modal-open { } .setting-control { - padding-left: var(--space-2); + width: 100%; + margin-bottom: var(--space-1); +} + +/* Select Control Styles */ +.select-control { + width: 100%; +} + +.select-control select { + width: 100%; + padding: 6px 10px; + border-radius: var(--border-radius-xs); + border: 1px solid var(--border-color); + background-color: var(--lora-surface); + color: var(--text-color); + font-size: 0.95em; +} + +/* Fix dark theme select dropdown text color */ +[data-theme="dark"] .select-control select { + background-color: rgba(30, 30, 30, 0.9); + color: var(--text-color); +} + +[data-theme="dark"] .select-control select option { + background-color: #2d2d2d; + color: var(--text-color); +} + +.select-control select:focus { + border-color: var(--lora-accent); + outline: none; } /* Toggle Switch */ diff --git a/static/js/managers/DownloadManager.js b/static/js/managers/DownloadManager.js index b6ccf89b..c0891bdf 100644 --- a/static/js/managers/DownloadManager.js +++ b/static/js/managers/DownloadManager.js @@ -3,7 +3,7 @@ import { showToast } from '../utils/uiHelpers.js'; import { LoadingManager } from './LoadingManager.js'; import { state } from '../state/index.js'; import { resetAndReload } from '../api/loraApi.js'; - +import { getStorageItem } from '../utils/storageHelpers.js'; export class DownloadManager { constructor() { this.currentVersion = null; @@ -246,6 +246,12 @@ export class DownloadManager { `` ).join(''); + // Set default lora root if available + const defaultRoot = getStorageItem('settings', {}).default_loras_root; + if (defaultRoot && data.roots.includes(defaultRoot)) { + loraRoot.value = defaultRoot; + } + // Initialize folder browser after loading roots this.initializeFolderBrowser(); } catch (error) { diff --git a/static/js/managers/ImportManager.js b/static/js/managers/ImportManager.js index b62e76f1..8a97880f 100644 --- a/static/js/managers/ImportManager.js +++ b/static/js/managers/ImportManager.js @@ -1,6 +1,7 @@ import { modalManager } from './ModalManager.js'; import { showToast } from '../utils/uiHelpers.js'; import { LoadingManager } from './LoadingManager.js'; +import { getStorageItem } from '../utils/storageHelpers.js'; export class ImportManager { constructor() { @@ -779,6 +780,12 @@ export class ImportManager { loraRoot.innerHTML = rootsData.roots.map(root => `` ).join(''); + + // Set default lora root if available + const defaultRoot = getStorageItem('settings', {}).default_loras_root; + if (defaultRoot && rootsData.roots.includes(defaultRoot)) { + loraRoot.value = defaultRoot; + } } // Fetch folders diff --git a/static/js/managers/MoveManager.js b/static/js/managers/MoveManager.js index 532066f5..b98eaa59 100644 --- a/static/js/managers/MoveManager.js +++ b/static/js/managers/MoveManager.js @@ -2,6 +2,7 @@ import { showToast } from '../utils/uiHelpers.js'; import { state } from '../state/index.js'; import { resetAndReload } from '../api/loraApi.js'; import { modalManager } from './ModalManager.js'; +import { getStorageItem } from '../utils/storageHelpers.js'; class MoveManager { constructor() { @@ -87,6 +88,12 @@ class MoveManager { `` ).join(''); + // Set default lora root if available + const defaultRoot = getStorageItem('settings', {}).default_loras_root; + if (defaultRoot && data.roots.includes(defaultRoot)) { + this.loraRootSelect.value = defaultRoot; + } + this.updatePathPreview(); modalManager.showModal('moveModal'); diff --git a/static/js/managers/SettingsManager.js b/static/js/managers/SettingsManager.js index fd14fe54..50617f52 100644 --- a/static/js/managers/SettingsManager.js +++ b/static/js/managers/SettingsManager.js @@ -53,7 +53,7 @@ export class SettingsManager { this.initialized = true; } - loadSettingsToUI() { + async loadSettingsToUI() { // Set frontend settings from state const blurMatureContentCheckbox = document.getElementById('blurMatureContent'); if (blurMatureContentCheckbox) { @@ -65,10 +65,52 @@ export class SettingsManager { // Sync with state (backend will set this via template) state.global.settings.show_only_sfw = showOnlySFWCheckbox.checked; } + + // Load default lora root + await this.loadLoraRoots(); // Backend settings are loaded from the template directly } + async loadLoraRoots() { + try { + const defaultLoraRootSelect = document.getElementById('defaultLoraRoot'); + if (!defaultLoraRootSelect) return; + + // Fetch lora roots + const response = await fetch('/api/lora-roots'); + if (!response.ok) { + throw new Error('Failed to fetch LoRA roots'); + } + + const data = await response.json(); + if (!data.roots || data.roots.length === 0) { + throw new Error('No LoRA roots found'); + } + + // Clear existing options except the first one (No Default) + const noDefaultOption = defaultLoraRootSelect.querySelector('option[value=""]'); + defaultLoraRootSelect.innerHTML = ''; + defaultLoraRootSelect.appendChild(noDefaultOption); + + // Add options for each root + data.roots.forEach(root => { + const option = document.createElement('option'); + option.value = root; + option.textContent = root; + defaultLoraRootSelect.appendChild(option); + }); + + // Set selected value from settings + const defaultRoot = state.global.settings.default_loras_root || ''; + defaultLoraRootSelect.value = defaultRoot; + + } catch (error) { + console.error('Error loading LoRA roots:', error); + showToast('Failed to load LoRA roots: ' + error.message, 'error'); + } + } + toggleSettings() { if (this.isOpen) { modalManager.closeModal('settingsModal'); @@ -81,14 +123,16 @@ export class SettingsManager { async saveSettings() { // Get frontend settings from UI const blurMatureContent = document.getElementById('blurMatureContent').checked; + const showOnlySFW = document.getElementById('showOnlySFW').checked; + const defaultLoraRoot = document.getElementById('defaultLoraRoot').value; // Get backend settings const apiKey = document.getElementById('civitaiApiKey').value; - const showOnlySFW = document.getElementById('showOnlySFW').checked; // Update frontend state and save to localStorage state.global.settings.blurMatureContent = blurMatureContent; state.global.settings.show_only_sfw = showOnlySFW; + state.global.settings.default_loras_root = defaultLoraRoot; // Save settings to localStorage setStorageItem('settings', state.global.settings); diff --git a/templates/components/modals.html b/templates/components/modals.html index 3fdf148d..7b82f4bc 100644 --- a/templates/components/modals.html +++ b/templates/components/modals.html @@ -66,6 +66,26 @@ + + +
+

Folder Settings

+ +
+
+ +
+
+ +
+
+ Set the default LoRA root directory for downloads, imports and moves +
+
+