mirror of
https://github.com/jags111/efficiency-nodes-comfyui.git
synced 2026-03-21 21:22:13 -03:00
Merge pull request #182 from LucianoCirino/XY-Input-LoRA-Stacks-Fix
XY Input LoRA Stacks Fix
This commit is contained in:
@@ -852,7 +852,7 @@ class TSC_KSampler:
|
|||||||
else (os.path.basename(v[0]), v[1]) if v[2] is None
|
else (os.path.basename(v[0]), v[1]) if v[2] is None
|
||||||
else (os.path.basename(v[0]),) + v[1:] for v in value]
|
else (os.path.basename(v[0]),) + v[1:] for v in value]
|
||||||
|
|
||||||
elif type_ == "LoRA" and isinstance(value, list):
|
elif (type_ == "LoRA" or type_ == "LoRA Stacks") and isinstance(value, list):
|
||||||
# Return only the first Tuple of each inner array
|
# Return only the first Tuple of each inner array
|
||||||
return [[(os.path.basename(v[0][0]),) + v[0][1:], "..."] if len(v) > 1
|
return [[(os.path.basename(v[0][0]),) + v[0][1:], "..."] if len(v) > 1
|
||||||
else [(os.path.basename(v[0][0]),) + v[0][1:]] for v in value]
|
else [(os.path.basename(v[0][0]),) + v[0][1:]] for v in value]
|
||||||
@@ -953,6 +953,7 @@ class TSC_KSampler:
|
|||||||
"Checkpoint",
|
"Checkpoint",
|
||||||
"Refiner",
|
"Refiner",
|
||||||
"LoRA",
|
"LoRA",
|
||||||
|
"LoRA Stacks",
|
||||||
"VAE",
|
"VAE",
|
||||||
]
|
]
|
||||||
conditioners = {
|
conditioners = {
|
||||||
@@ -997,6 +998,9 @@ class TSC_KSampler:
|
|||||||
# Create a list of tuples with types and values
|
# Create a list of tuples with types and values
|
||||||
type_value_pairs = [(X_type, X_value.copy()), (Y_type, Y_value.copy())]
|
type_value_pairs = [(X_type, X_value.copy()), (Y_type, Y_value.copy())]
|
||||||
|
|
||||||
|
# Replace "LoRA Stacks" with "LoRA"
|
||||||
|
type_value_pairs = [('LoRA' if t == 'LoRA Stacks' else t, v) for t, v in type_value_pairs]
|
||||||
|
|
||||||
# Iterate over type-value pairs
|
# Iterate over type-value pairs
|
||||||
for t, v in type_value_pairs:
|
for t, v in type_value_pairs:
|
||||||
if t in dict_map:
|
if t in dict_map:
|
||||||
@@ -1039,7 +1043,7 @@ class TSC_KSampler:
|
|||||||
elif X_type == "Refiner":
|
elif X_type == "Refiner":
|
||||||
ckpt_dict = []
|
ckpt_dict = []
|
||||||
lora_dict = []
|
lora_dict = []
|
||||||
elif X_type == "LoRA":
|
elif X_type in ("LoRA", "LoRA Stacks"):
|
||||||
ckpt_dict = []
|
ckpt_dict = []
|
||||||
refn_dict = []
|
refn_dict = []
|
||||||
|
|
||||||
@@ -1202,7 +1206,7 @@ class TSC_KSampler:
|
|||||||
text = f"RefClipSkip ({refiner_clip_skip[0]})"
|
text = f"RefClipSkip ({refiner_clip_skip[0]})"
|
||||||
|
|
||||||
elif "LoRA" in var_type:
|
elif "LoRA" in var_type:
|
||||||
if not lora_stack:
|
if not lora_stack or var_type == "LoRA Stacks":
|
||||||
lora_stack = var.copy()
|
lora_stack = var.copy()
|
||||||
else:
|
else:
|
||||||
# Updating the first tuple of lora_stack
|
# Updating the first tuple of lora_stack
|
||||||
@@ -1212,7 +1216,7 @@ class TSC_KSampler:
|
|||||||
lora_name, lora_model_wt, lora_clip_wt = lora_stack[0]
|
lora_name, lora_model_wt, lora_clip_wt = lora_stack[0]
|
||||||
lora_filename = os.path.splitext(os.path.basename(lora_name))[0]
|
lora_filename = os.path.splitext(os.path.basename(lora_name))[0]
|
||||||
|
|
||||||
if var_type == "LoRA":
|
if var_type == "LoRA" or var_type == "LoRA Stacks":
|
||||||
if len(lora_stack) == 1:
|
if len(lora_stack) == 1:
|
||||||
lora_model_wt = format(float(lora_model_wt), ".2f").rstrip('0').rstrip('.')
|
lora_model_wt = format(float(lora_model_wt), ".2f").rstrip('0').rstrip('.')
|
||||||
lora_clip_wt = format(float(lora_clip_wt), ".2f").rstrip('0').rstrip('.')
|
lora_clip_wt = format(float(lora_clip_wt), ".2f").rstrip('0').rstrip('.')
|
||||||
@@ -1335,7 +1339,7 @@ class TSC_KSampler:
|
|||||||
# Note: Index is held at 0 when Y_type == "Nothing"
|
# Note: Index is held at 0 when Y_type == "Nothing"
|
||||||
|
|
||||||
# Load Checkpoint if required. If Y_type is LoRA, required models will be loaded by load_lora func.
|
# Load Checkpoint if required. If Y_type is LoRA, required models will be loaded by load_lora func.
|
||||||
if (X_type == "Checkpoint" and index == 0 and Y_type != "LoRA"):
|
if (X_type == "Checkpoint" and index == 0 and Y_type not in ("LoRA", "LoRA Stacks")):
|
||||||
if lora_stack is None:
|
if lora_stack is None:
|
||||||
model, clip, _ = load_checkpoint(ckpt_name, xyplot_id, cache=cache[1])
|
model, clip, _ = load_checkpoint(ckpt_name, xyplot_id, cache=cache[1])
|
||||||
else: # Load Efficient Loader LoRA
|
else: # Load Efficient Loader LoRA
|
||||||
@@ -1344,11 +1348,11 @@ class TSC_KSampler:
|
|||||||
encode = True
|
encode = True
|
||||||
|
|
||||||
# Load LoRA if required
|
# Load LoRA if required
|
||||||
elif (X_type == "LoRA" and index == 0):
|
elif (X_type in ("LoRA", "LoRA Stacks") and index == 0):
|
||||||
# Don't cache Checkpoints
|
# Don't cache Checkpoints
|
||||||
model, clip = load_lora(lora_stack, ckpt_name, xyplot_id, cache=cache[2])
|
model, clip = load_lora(lora_stack, ckpt_name, xyplot_id, cache=cache[2])
|
||||||
encode = True
|
encode = True
|
||||||
elif Y_type == "LoRA": # X_type must be Checkpoint, so cache those as defined
|
elif Y_type in ("LoRA", "LoRA Stacks"): # X_type must be Checkpoint, so cache those as defined
|
||||||
model, clip = load_lora(lora_stack, ckpt_name, xyplot_id,
|
model, clip = load_lora(lora_stack, ckpt_name, xyplot_id,
|
||||||
cache=None, ckpt_cache=cache[1])
|
cache=None, ckpt_cache=cache[1])
|
||||||
encode = True
|
encode = True
|
||||||
@@ -1568,7 +1572,7 @@ class TSC_KSampler:
|
|||||||
clear_cache_by_exception(xyplot_id, lora_dict=[], refn_dict=[])
|
clear_cache_by_exception(xyplot_id, lora_dict=[], refn_dict=[])
|
||||||
elif X_type == "Refiner":
|
elif X_type == "Refiner":
|
||||||
clear_cache_by_exception(xyplot_id, ckpt_dict=[], lora_dict=[])
|
clear_cache_by_exception(xyplot_id, ckpt_dict=[], lora_dict=[])
|
||||||
elif X_type == "LoRA":
|
elif X_type in ("LoRA", "LoRA Stacks"):
|
||||||
clear_cache_by_exception(xyplot_id, ckpt_dict=[], refn_dict=[])
|
clear_cache_by_exception(xyplot_id, ckpt_dict=[], refn_dict=[])
|
||||||
|
|
||||||
# __________________________________________________________________________________________________________
|
# __________________________________________________________________________________________________________
|
||||||
@@ -1668,7 +1672,7 @@ class TSC_KSampler:
|
|||||||
lora_name = lora_wt = lora_model_str = lora_clip_str = None
|
lora_name = lora_wt = lora_model_str = lora_clip_str = None
|
||||||
|
|
||||||
# Check for all possible LoRA types
|
# Check for all possible LoRA types
|
||||||
lora_types = ["LoRA", "LoRA Batch", "LoRA Wt", "LoRA MStr", "LoRA CStr"]
|
lora_types = ["LoRA", "LoRA Stacks", "LoRA Batch", "LoRA Wt", "LoRA MStr", "LoRA CStr"]
|
||||||
|
|
||||||
if X_type not in lora_types and Y_type not in lora_types:
|
if X_type not in lora_types and Y_type not in lora_types:
|
||||||
if lora_stack:
|
if lora_stack:
|
||||||
@@ -1681,7 +1685,7 @@ class TSC_KSampler:
|
|||||||
else:
|
else:
|
||||||
if X_type in lora_types:
|
if X_type in lora_types:
|
||||||
value = get_lora_sublist_name(X_type, X_value)
|
value = get_lora_sublist_name(X_type, X_value)
|
||||||
if X_type == "LoRA":
|
if X_type in ("LoRA", "LoRA Stacks"):
|
||||||
lora_name = value
|
lora_name = value
|
||||||
lora_model_str = None
|
lora_model_str = None
|
||||||
lora_clip_str = None
|
lora_clip_str = None
|
||||||
@@ -1703,7 +1707,7 @@ class TSC_KSampler:
|
|||||||
|
|
||||||
if Y_type in lora_types:
|
if Y_type in lora_types:
|
||||||
value = get_lora_sublist_name(Y_type, Y_value)
|
value = get_lora_sublist_name(Y_type, Y_value)
|
||||||
if Y_type == "LoRA":
|
if Y_type in ("LoRA", "LoRA Stacks"):
|
||||||
lora_name = value
|
lora_name = value
|
||||||
lora_model_str = None
|
lora_model_str = None
|
||||||
lora_clip_str = None
|
lora_clip_str = None
|
||||||
@@ -1726,13 +1730,13 @@ class TSC_KSampler:
|
|||||||
return lora_name, lora_wt, lora_model_str, lora_clip_str
|
return lora_name, lora_wt, lora_model_str, lora_clip_str
|
||||||
|
|
||||||
def get_lora_sublist_name(lora_type, lora_value):
|
def get_lora_sublist_name(lora_type, lora_value):
|
||||||
if lora_type == "LoRA" or lora_type == "LoRA Batch":
|
if lora_type in ("LoRA", "LoRA Batch", "LoRA Stacks"):
|
||||||
formatted_sublists = []
|
formatted_sublists = []
|
||||||
for sublist in lora_value:
|
for sublist in lora_value:
|
||||||
formatted_entries = []
|
formatted_entries = []
|
||||||
for x in sublist:
|
for x in sublist:
|
||||||
base_name = os.path.splitext(os.path.basename(str(x[0])))[0]
|
base_name = os.path.splitext(os.path.basename(str(x[0])))[0]
|
||||||
formatted_str = f"{base_name}({round(x[1], 3)},{round(x[2], 3)})" if lora_type == "LoRA" else f"{base_name}"
|
formatted_str = f"{base_name}({round(x[1], 3)},{round(x[2], 3)})" if lora_type in ("LoRA", "LoRA Stacks") else f"{base_name}"
|
||||||
formatted_entries.append(formatted_str)
|
formatted_entries.append(formatted_str)
|
||||||
formatted_sublists.append(f"{', '.join(formatted_entries)}")
|
formatted_sublists.append(f"{', '.join(formatted_entries)}")
|
||||||
return "\n ".join(formatted_sublists)
|
return "\n ".join(formatted_sublists)
|
||||||
@@ -2375,7 +2379,7 @@ class TSC_XYplot:
|
|||||||
# Check that dependencies are connected for specific plot types
|
# Check that dependencies are connected for specific plot types
|
||||||
encode_types = {
|
encode_types = {
|
||||||
"Checkpoint", "Refiner",
|
"Checkpoint", "Refiner",
|
||||||
"LoRA", "LoRA Batch", "LoRA Wt", "LoRA MStr", "LoRA CStr",
|
"LoRA", "LoRA Stacks", "LoRA Batch", "LoRA Wt", "LoRA MStr", "LoRA CStr",
|
||||||
"Positive Prompt S/R", "Negative Prompt S/R",
|
"Positive Prompt S/R", "Negative Prompt S/R",
|
||||||
"AScore+", "AScore-",
|
"AScore+", "AScore-",
|
||||||
"Clip Skip", "Clip Skip (Refiner)",
|
"Clip Skip", "Clip Skip (Refiner)",
|
||||||
@@ -2391,8 +2395,13 @@ class TSC_XYplot:
|
|||||||
# Check if both X_type and Y_type are special lora_types
|
# Check if both X_type and Y_type are special lora_types
|
||||||
lora_types = {"LoRA Batch", "LoRA Wt", "LoRA MStr", "LoRA CStr"}
|
lora_types = {"LoRA Batch", "LoRA Wt", "LoRA MStr", "LoRA CStr"}
|
||||||
if (X_type in lora_types and Y_type not in lora_types) or (Y_type in lora_types and X_type not in lora_types):
|
if (X_type in lora_types and Y_type not in lora_types) or (Y_type in lora_types and X_type not in lora_types):
|
||||||
print(
|
print(f"{error('XY Plot Error:')} Both X and Y must be connected to use the 'LoRA Plot' node.")
|
||||||
f"{error('XY Plot Error:')} Both X and Y must be connected to use the 'LoRA Plot' node.")
|
return (None,)
|
||||||
|
|
||||||
|
# Do not allow LoRA and LoRA Stacks
|
||||||
|
lora_types = {"LoRA", "LoRA Stacks"}
|
||||||
|
if (X_type in lora_types and Y_type in lora_types):
|
||||||
|
print(f"{error('XY Plot Error:')} X and Y input types must be different.")
|
||||||
return (None,)
|
return (None,)
|
||||||
|
|
||||||
# Clean Schedulers from Sampler data (if other type is Scheduler)
|
# Clean Schedulers from Sampler data (if other type is Scheduler)
|
||||||
@@ -3139,7 +3148,7 @@ class TSC_XYplot_LoRA_Stacks:
|
|||||||
CATEGORY = "Efficiency Nodes/XY Inputs"
|
CATEGORY = "Efficiency Nodes/XY Inputs"
|
||||||
|
|
||||||
def xy_value(self, node_state, lora_stack_1=None, lora_stack_2=None, lora_stack_3=None, lora_stack_4=None, lora_stack_5=None):
|
def xy_value(self, node_state, lora_stack_1=None, lora_stack_2=None, lora_stack_3=None, lora_stack_4=None, lora_stack_5=None):
|
||||||
xy_type = "LoRA"
|
xy_type = "LoRA Stacks"
|
||||||
xy_value = [stack for stack in [lora_stack_1, lora_stack_2, lora_stack_3, lora_stack_4, lora_stack_5] if stack is not None]
|
xy_value = [stack for stack in [lora_stack_1, lora_stack_2, lora_stack_3, lora_stack_4, lora_stack_5] if stack is not None]
|
||||||
if not xy_value or not any(xy_value) or node_state == "Disabled":
|
if not xy_value or not any(xy_value) or node_state == "Disabled":
|
||||||
return (None,)
|
return (None,)
|
||||||
|
|||||||
Reference in New Issue
Block a user