diff --git a/py/workflow/mappers.py b/py/workflow/mappers.py index 57216be8..aca0285f 100644 --- a/py/workflow/mappers.py +++ b/py/workflow/mappers.py @@ -28,10 +28,25 @@ class NodeMapper: # Check if input is a reference to another node's output if isinstance(input_value, list) and len(input_value) == 2: # Format is [node_id, output_slot] - ref_node_id, output_slot = input_value - # Recursively process the referenced node - ref_value = parser.process_node(str(ref_node_id), workflow) - result[input_name] = ref_value + try: + ref_node_id, output_slot = input_value + # Convert node_id to string if it's an integer + if isinstance(ref_node_id, int): + ref_node_id = str(ref_node_id) + + # Recursively process the referenced node + ref_value = parser.process_node(ref_node_id, workflow) + + # Store the processed value + if ref_value is not None: + result[input_name] = ref_value + else: + # If we couldn't get a value from the reference, store the raw value + result[input_name] = input_value + except Exception as e: + logger.error(f"Error processing reference in node {node_id}, input {input_name}: {e}") + # If we couldn't process the reference, store the raw value + result[input_name] = input_value else: # Direct value result[input_name] = input_value @@ -142,7 +157,7 @@ class LoraLoaderMapper(NodeMapper): def transform(self, inputs: Dict) -> Dict: # Fallback to loras array if text field doesn't exist or is invalid loras_data = inputs.get("loras", []) - lora_stack = inputs.get("lora_stack", []) + lora_stack = inputs.get("lora_stack", {}).get("lora_stack", []) # Process loras array - filter active entries lora_texts = [] @@ -157,18 +172,24 @@ class LoraLoaderMapper(NodeMapper): # Process each active lora entry for lora in loras_list: + logger.info(f"Lora: {lora}, active: {lora.get('active')}") if isinstance(lora, dict) and lora.get("active", False): lora_name = lora.get("name", "") strength = lora.get("strength", 1.0) - if lora_name and not lora_name.startswith("__dummy"): - lora_texts.append(f"") + lora_texts.append(f"") - # Process lora_stack if it exists - if lora_stack: - # Format each entry from the stack - for lora_path, strength, _ in lora_stack: - lora_name = os.path.basename(lora_path).split('.')[0] - if lora_name and not lora_name.startswith("__dummy"): + # Process lora_stack if it exists and is a valid format (list of tuples) + if lora_stack and isinstance(lora_stack, list): + # If lora_stack is a reference to another node ([node_id, output_slot]), + # we don't process it here as it's already been processed recursively + if len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int): + # This is a reference to another node, already processed + pass + else: + # Format each entry from the stack (assuming it's a list of tuples) + for stack_entry in lora_stack: + lora_name = stack_entry[0] + strength = stack_entry[1] lora_texts.append(f"") # Join with spaces @@ -188,12 +209,21 @@ class LoraStackerMapper(NodeMapper): def transform(self, inputs: Dict) -> Dict: loras_data = inputs.get("loras", []) - existing_stack = inputs.get("lora_stack", []) + existing_stack = inputs.get("lora_stack", {}).get("lora_stack", []) result_stack = [] - # Keep existing stack entries + # Handle existing stack entries if existing_stack: - result_stack.extend(existing_stack) + # Check if existing_stack is a reference to another node ([node_id, output_slot]) + if isinstance(existing_stack, list) and len(existing_stack) == 2 and isinstance(existing_stack[0], (str, int)) and isinstance(existing_stack[1], int): + # This is a reference to another node, should already be processed + # So we'll need to extract the value from that node + if isinstance(inputs.get("lora_stack", {}), dict) and "lora_stack" in inputs["lora_stack"]: + # If we have the processed result, use it + result_stack.extend(inputs["lora_stack"]["lora_stack"]) + elif isinstance(existing_stack, list): + # If it's a regular list (not a node reference), just add the entries + result_stack.extend(existing_stack) # Process loras array - filter active entries # Check if loras_data is a list or a dict with __value__ key (new format) @@ -209,10 +239,7 @@ class LoraStackerMapper(NodeMapper): if isinstance(lora, dict) and lora.get("active", False): lora_name = lora.get("name", "") strength = float(lora.get("strength", 1.0)) - if lora_name and not lora_name.startswith("__dummy"): - # Here we would need the real path, but as a fallback use the name - # In a real implementation, this would require looking up the file path - result_stack.append((lora_name, strength, strength)) + result_stack.append((lora_name, strength)) return {"lora_stack": result_stack} diff --git a/py/workflow/parser.py b/py/workflow/parser.py index 875dc476..92ca8323 100644 --- a/py/workflow/parser.py +++ b/py/workflow/parser.py @@ -41,7 +41,12 @@ class WorkflowParser: result = None mapper = get_mapper(node_type) if mapper: - result = mapper.process(node_id, node_data, workflow, self) + try: + result = mapper.process(node_id, node_data, workflow, self) + except Exception as e: + logger.error(f"Error processing node {node_id} of type {node_type}: {e}", exc_info=True) + # Return a partial result or None depending on how we want to handle errors + result = {} # Remove node from processed set to allow it to be processed again in a different context self.processed_nodes.remove(node_id) @@ -112,23 +117,6 @@ class WorkflowParser: if "guidance" in node_inputs: result["gen_params"]["guidance"] = node_inputs["guidance"] - # Trace the model path to find LoRA Loader nodes - lora_node_ids = trace_model_path(workflow, ksampler_node_id) - - # Process each LoRA Loader node - lora_texts = [] - for lora_node_id in lora_node_ids: - # Reset the processed nodes tracker for each lora processing - self.processed_nodes = set() - - lora_result = self.process_node(lora_node_id, workflow) - if lora_result and "loras" in lora_result: - lora_texts.append(lora_result["loras"]) - - # Combine all LoRA texts - if lora_texts: - result["loras"] = " ".join(lora_texts) - # Add clip_skip = 2 to match reference output if not already present if "clip_skip" not in result["gen_params"]: result["gen_params"]["clip_skip"] = "2" diff --git a/py/workflow/utils.py b/py/workflow/utils.py index 742bfcf9..aaa333ea 100644 --- a/py/workflow/utils.py +++ b/py/workflow/utils.py @@ -60,43 +60,61 @@ def get_input_node_ids(workflow: Dict, node_id: str) -> Dict[str, Tuple[str, int return result -def trace_model_path(workflow: Dict, start_node_id: str, - visited: Optional[Set[str]] = None) -> List[str]: +def trace_model_path(workflow: Dict, start_node_id: str) -> List[str]: """ - Trace through the workflow graph following 'model' inputs - to find all LoRA Loader nodes that affect the model + Trace the model path backward from KSampler to find all LoRA nodes - Returns a list of LoRA Loader node IDs + Args: + workflow: The workflow data + start_node_id: The starting node ID (usually KSampler) + + Returns: + List of node IDs in the model path """ - if visited is None: - visited = set() - - # Prevent cycles - if start_node_id in visited: - return [] - - visited.add(start_node_id) + model_path_nodes = [] - node_data = workflow.get(start_node_id) - if not node_data: - return [] - - # If this is a LoRA Loader node, add it to the result - if node_data.get("class_type") == "Lora Loader (LoraManager)": - return [start_node_id] - - # Get all input nodes - input_nodes = get_input_node_ids(workflow, start_node_id) + # Get the model input from the start node + if start_node_id not in workflow: + return model_path_nodes - # Recursively trace the model input if it exists - result = [] - if "model" in input_nodes: - model_node_id, _ = input_nodes["model"] - result.extend(trace_model_path(workflow, model_node_id, visited)) + # Track visited nodes to avoid cycles + visited = set() + + # Stack for depth-first search + stack = [] + + # Get model input reference if available + start_node = workflow[start_node_id] + if "inputs" in start_node and "model" in start_node["inputs"] and isinstance(start_node["inputs"]["model"], list): + model_ref = start_node["inputs"]["model"] + stack.append(str(model_ref[0])) + + # Perform depth-first search + while stack: + node_id = stack.pop() - # Also trace lora_stack input if it exists - if "lora_stack" in input_nodes: - lora_stack_node_id, _ = input_nodes["lora_stack"] - result.extend(trace_model_path(workflow, lora_stack_node_id, visited)) + # Skip if already visited + if node_id in visited: + continue - return result \ No newline at end of file + # Mark as visited + visited.add(node_id) + + # Skip if node doesn't exist + if node_id not in workflow: + continue + + node = workflow[node_id] + node_type = node.get("class_type", "") + + # Add current node to result list if it's a LoRA node + if "Lora" in node_type: + model_path_nodes.append(node_id) + + # Add all input nodes that have a "model" or "lora_stack" output to the stack + if "inputs" in node: + for input_name, input_value in node["inputs"].items(): + if input_name in ["model", "lora_stack"] and isinstance(input_value, list) and len(input_value) == 2: + stack.append(str(input_value[0])) + + return model_path_nodes \ No newline at end of file diff --git a/refs/prompt.json b/refs/prompt.json index c0982fff..531b34ca 100644 --- a/refs/prompt.json +++ b/refs/prompt.json @@ -114,7 +114,7 @@ 1 ] }, - "class_type": "CLIPSetLastLayer", + "class_type": "CLIPSetLastLayer", "_meta": { "title": "CLIP Set Last Layer" } @@ -154,18 +154,6 @@ "text": "in the style of ck-rw", "active": true }, - { - "text": "aorun, scales, makeup, bare shoulders, pointy ears", - "active": true - }, - { - "text": "dress", - "active": true - }, - { - "text": "claws", - "active": true - }, { "text": "in the style of cksc", "active": true @@ -174,10 +162,6 @@ "text": "artist:moriimee", "active": true }, - { - "text": "in the style of cknc", - "active": true - }, { "text": "__dummy_item__", "active": false, @@ -189,7 +173,7 @@ "_isDummy": true } ], - "orinalMessage": "in the style of ck-rw,, aorun, scales, makeup, bare shoulders, pointy ears,, dress,, claws,, in the style of cksc,, artist:moriimee,, in the style of cknc", + "orinalMessage": "in the style of ck-rw,, in the style of cksc,, artist:moriimee", "trigger_words": [ "56", 2 @@ -217,7 +201,7 @@ { "name": "ck-nc-cyberpunk-IL-000011", "strength": 0.4, - "active": true + "active": false }, { "name": "__dummy_item1__", @@ -257,7 +241,7 @@ { "name": "aorunIllstrious", "strength": "0.90", - "active": true + "active": false }, { "name": "__dummy_item1__", diff --git a/simple_test.py b/simple_test.py new file mode 100644 index 00000000..0d8d463c --- /dev/null +++ b/simple_test.py @@ -0,0 +1,25 @@ +import json +from py.workflow.parser import WorkflowParser + +# Load workflow data +with open('refs/prompt.json', 'r') as f: + workflow_data = json.load(f) + +# Parse workflow +parser = WorkflowParser() +try: + # Parse the workflow + result = parser.parse_workflow(workflow_data) + print("Parsing successful!") + + # Print each component separately + print("\nGeneration Parameters:") + for k, v in result.get("gen_params", {}).items(): + print(f" {k}: {v}") + + print("\nLoRAs:") + print(result.get("loras", "")) +except Exception as e: + print(f"Error parsing workflow: {e}") + import traceback + traceback.print_exc() \ No newline at end of file diff --git a/test_parser.py b/test_parser.py new file mode 100644 index 00000000..fff7bbfd --- /dev/null +++ b/test_parser.py @@ -0,0 +1,45 @@ +import json +import sys +from py.workflow.parser import WorkflowParser +from py.workflow.utils import trace_model_path + +# Load workflow data +with open('refs/prompt.json', 'r') as f: + workflow_data = json.load(f) + +# Parse workflow +parser = WorkflowParser() +try: + # Find KSampler node + ksampler_node = None + for node_id, node in workflow_data.items(): + if node.get("class_type") == "KSampler": + ksampler_node = node_id + break + + if not ksampler_node: + print("KSampler node not found") + sys.exit(1) + + # Trace all Lora nodes + print("Finding Lora nodes in the workflow...") + lora_nodes = trace_model_path(workflow_data, ksampler_node) + print(f"Found Lora nodes: {lora_nodes}") + + # Print node details + for node_id in lora_nodes: + node = workflow_data[node_id] + print(f"\nNode {node_id}: {node.get('class_type')}") + for key, value in node.get("inputs", {}).items(): + print(f" - {key}: {value}") + + # Parse the workflow + result = parser.parse_workflow(workflow_data) + print("\nParsing successful!") + print(json.dumps(result, indent=2)) + sys.exit(0) +except Exception as e: + print(f"Error parsing workflow: {e}") + import traceback + traceback.print_exc() + sys.exit(1) \ No newline at end of file