Merge branch 'Main' into feat/xy_capsule

This commit is contained in:
Dr.Lt.Data
2023-08-31 12:59:32 +09:00
31 changed files with 3569 additions and 1594 deletions

View File

@@ -31,12 +31,13 @@ from comfy.cli_args import args
# Cache for Efficiency Node models
loaded_objects = {
"ckpt": [], # (ckpt_name, ckpt_model, clip, bvae, [id])
"refn": [], # (ckpt_name, ckpt_model, clip, bvae, [id])
"vae": [], # (vae_name, vae, [id])
"lora": [] # ([(lora_name, strength_model, strength_clip)], ckpt_name, lora_model, clip_lora, [id])
}
# Cache for Ksampler (Efficient) Outputs
last_helds: dict[str, list] = {
last_helds = {
"preview_images": [], # (preview_images, id) # Preview Images, stored as a pil image list
"latent": [], # (latent, id) # Latent outputs, stored as a latent tensor list
"output_images": [], # (output_images, id) # Output Images, stored as an image tensor list
@@ -53,6 +54,29 @@ def tensor2pil(image: torch.Tensor) -> Image.Image:
def pil2tensor(image: Image.Image) -> torch.Tensor:
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
# Color coded messages functions
MESSAGE_COLOR = "\033[36m" # Cyan
XYPLOT_COLOR = "\033[35m" # Purple
SUCCESS_COLOR = "\033[92m" # Green
WARNING_COLOR = "\033[93m" # Yellow
ERROR_COLOR = "\033[91m" # Red
INFO_COLOR = "\033[90m" # Gray
def format_message(text, color_code):
RESET_COLOR = "\033[0m"
return f"{color_code}{text}{RESET_COLOR}"
def message(text):
return format_message(text, MESSAGE_COLOR)
def warning(text):
return format_message(text, WARNING_COLOR)
def error(text):
return format_message(text, ERROR_COLOR)
def success(text):
return format_message(text, SUCCESS_COLOR)
def xyplot_message(text):
return format_message(text, XYPLOT_COLOR)
def info(text):
return format_message(text, INFO_COLOR)
def extract_node_info(prompt, id, indirect_key=None):
# Convert ID to string
id = str(id)
@@ -95,7 +119,7 @@ def print_loaded_objects_entries(id=None, prompt=None, show_id=False):
else:
print(f"\033[36mModels Cache: \nnode_id:{int(id)}\033[0m")
entries_found = False
for key in ["ckpt", "vae", "lora"]:
for key in ["ckpt", "refn", "vae", "lora"]:
entries_with_id = loaded_objects[key] if id is None else [entry for entry in loaded_objects[key] if id in entry[-1]]
if not entries_with_id: # If no entries with the chosen ID, print None and skip this key
continue
@@ -103,13 +127,15 @@ def print_loaded_objects_entries(id=None, prompt=None, show_id=False):
print(f"{key.capitalize()}:")
for i, entry in enumerate(entries_with_id, 1): # Start numbering from 1
if key == "lora":
lora_models_info = ', '.join(f"{os.path.splitext(os.path.basename(name))[0]}({round(strength_model, 2)},{round(strength_clip, 2)})" for name, strength_model, strength_clip in entry[0])
base_ckpt_name = os.path.splitext(os.path.basename(entry[1]))[0] # Split logic for base_ckpt
if id is None:
associated_ids = ', '.join(map(str, entry[-1])) # Gather all associated ids
print(f" [{i}] base_ckpt: {base_ckpt_name}, lora(mod,clip): {lora_models_info} (ids: {associated_ids})")
print(f" [{i}] base_ckpt: {base_ckpt_name} (ids: {associated_ids})")
else:
print(f" [{i}] base_ckpt: {base_ckpt_name}, lora(mod,clip): {lora_models_info}")
print(f" [{i}] base_ckpt: {base_ckpt_name}")
for name, strength_model, strength_clip in entry[0]:
lora_model_info = f"{os.path.splitext(os.path.basename(name))[0]}({round(strength_model, 2)},{round(strength_clip, 2)})"
print(f" lora(mod,clip): {lora_model_info}")
else:
name_without_ext = os.path.splitext(os.path.basename(entry[0]))[0]
if id is None:
@@ -146,59 +172,54 @@ def globals_cleanup(prompt):
loaded_objects[key].remove(tup)
###print(f'Deleted tuple at index {i} in {key} in loaded_objects because its id array became empty.')
def load_checkpoint(ckpt_name, id, output_vae=True, cache=None, cache_overwrite=False):
"""
Searches for tuple index that contains ckpt_name in "ckpt" array of loaded_objects.
If found, extracts the model, clip, and vae from the loaded_objects.
If not found, loads the checkpoint, extracts the model, clip, and vae.
The id parameter represents the node ID and is used for caching models for the XY Plot node.
If the cache limit is reached for a specific id, clears the cache and returns the loaded model, clip, and vae without adding a new entry.
If there is cache space, adds the id to the ids list if it's not already there.
If there is cache space and the checkpoint was not found in loaded_objects, adds a new entry to loaded_objects.
Parameters:
- ckpt_name: name of the checkpoint to load.
- id: an identifier for caching models for specific nodes.
- output_vae: boolean, if True loads the VAE too.
- cache (optional): an integer that specifies how many checkpoint entries with a given id can exist in loaded_objects. Defaults to None.
"""
def load_checkpoint(ckpt_name, id, output_vae=True, cache=None, cache_overwrite=False, ckpt_type="ckpt"):
global loaded_objects
for entry in loaded_objects["ckpt"]:
# Create copies of the arguments right at the start
ckpt_name = ckpt_name.copy() if isinstance(ckpt_name, (list, dict, set)) else ckpt_name
# Check if the type is valid
if ckpt_type not in ["ckpt", "refn"]:
raise ValueError(f"Invalid checkpoint type: {ckpt_type}")
for entry in loaded_objects[ckpt_type]:
if entry[0] == ckpt_name:
_, model, clip, vae, ids = entry
cache_full = cache and len([entry for entry in loaded_objects["ckpt"] if id in entry[-1]]) >= cache
cache_full = cache and len([entry for entry in loaded_objects[ckpt_type] if id in entry[-1]]) >= cache
if cache_full:
clear_cache(id, cache, "ckpt")
clear_cache(id, cache, ckpt_type)
elif id not in ids:
ids.append(id)
return model, clip, vae
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
if os.path.isabs(ckpt_name):
ckpt_path = ckpt_name
else:
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
with suppress_output():
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae, output_clip=True,
embedding_directory=folder_paths.get_folder_paths("embeddings"))
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
model = out[0]
clip = out[1]
vae = out[2] # bvae
vae = out[2] if output_vae else None # Load VAE from the checkpoint path only if output_vae is True
if cache:
if len([entry for entry in loaded_objects["ckpt"] if id in entry[-1]]) < cache:
loaded_objects["ckpt"].append((ckpt_name, model, clip, vae, [id]))
cache_list = [entry for entry in loaded_objects[ckpt_type] if id in entry[-1]]
if len(cache_list) < cache:
loaded_objects[ckpt_type].append((ckpt_name, model, clip, vae, [id]))
else:
clear_cache(id, cache, "ckpt")
clear_cache(id, cache, ckpt_type)
if cache_overwrite:
# Find the first entry with the id, remove the id from the entry's id list
for e in loaded_objects["ckpt"]:
for e in loaded_objects[ckpt_type]:
if id in e[-1]:
e[-1].remove(id)
# If the id list becomes empty, remove the entry from the "ckpt" list
# If the id list becomes empty, remove the entry from the ckpt_type list
if not e[-1]:
loaded_objects["ckpt"].remove(e)
loaded_objects[ckpt_type].remove(e)
break
loaded_objects["ckpt"].append((ckpt_name, model, clip, vae, [id]))
loaded_objects[ckpt_type].append((ckpt_name, model, clip, vae, [id]))
return model, clip, vae
@@ -209,21 +230,11 @@ def get_bvae_by_ckpt_name(ckpt_name):
return None # return None if no match is found
def load_vae(vae_name, id, cache=None, cache_overwrite=False):
"""
Extracts the vae with a given name from the "vae" array in loaded_objects.
If the vae is not found, creates a new VAE object with the given name and adds it to the "vae" array.
Also stores the id parameter, which is used for caching models specifically for nodes with the given ID.
If the cache limit is reached for a specific id, returns the loaded vae without adding id or making a new entry in loaded_objects.
If there is cache space, and the id is not in the ids list, adds the id to the ids list.
If there is cache space, and the vae was not found in loaded_objects, adds a new entry to the loaded_objects.
Parameters:
- vae_name: name of the VAE to load.
- id (optional): an identifier for caching models for specific nodes. Defaults to None.
- cache (optional): an integer that specifies how many vae entries with a given id can exist in loaded_objects. Defaults to None.
"""
global loaded_objects
# Create copies of the argument right at the start
vae_name = vae_name.copy() if isinstance(vae_name, (list, dict, set)) else vae_name
for i, entry in enumerate(loaded_objects["vae"]):
if entry[0] == vae_name:
vae, ids = entry[1], entry[2]
@@ -235,7 +246,10 @@ def load_vae(vae_name, id, cache=None, cache_overwrite=False):
clear_cache(id, cache, "vae")
return vae
vae_path = folder_paths.get_full_path("vae", vae_name)
if os.path.isabs(vae_name):
vae_path = vae_name
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
vae = comfy.sd.VAE(ckpt_path=vae_path)
if cache:
@@ -257,23 +271,14 @@ def load_vae(vae_name, id, cache=None, cache_overwrite=False):
return vae
def load_lora(lora_params, ckpt_name, id, cache=None, ckpt_cache=None, cache_overwrite=False):
"""
Extracts the Lora model with a given name from the "lora" array in loaded_objects.
If the Lora model is not found or strength values changed or model changed, creates a new Lora object with the given name and adds it to the "lora" array.
Also stores the id parameter, which is used for caching models specifically for nodes with the given ID.
If the cache limit is reached for a specific id, clears the cache and returns the loaded Lora model and clip without adding a new entry.
If there is cache space, adds the id to the ids list if it's not already there.
If there is cache space and the Lora model was not found in loaded_objects, adds a new entry to loaded_objects.
Parameters:
- lora_params: A list of tuples, where each tuple contains lora_name, strength_model, strength_clip.
- ckpt_name: name of the checkpoint from which the Lora model is created.
- id: an identifier for caching models for specific nodes.
- cache (optional): an integer that specifies how many Lora entries with a given id can exist in loaded_objects. Defaults to None.
"""
global loaded_objects
# Create copies of the arguments right at the start
lora_params = lora_params.copy() if isinstance(lora_params, (list, dict, set)) else lora_params
ckpt_name = ckpt_name.copy() if isinstance(ckpt_name, (list, dict, set)) else ckpt_name
for entry in loaded_objects["lora"]:
# Convert to sets and compare
if set(entry[0]) == set(lora_params) and entry[1] == ckpt_name:
@@ -304,7 +309,11 @@ def load_lora(lora_params, ckpt_name, id, cache=None, ckpt_cache=None, cache_ove
return ckpt, clip
lora_name, strength_model, strength_clip = lora_params[0]
lora_path = folder_paths.get_full_path("loras", lora_name)
if os.path.isabs(lora_name):
lora_path = lora_name
else:
lora_path = folder_paths.get_full_path("loras", lora_name)
lora_model, lora_clip = comfy.sd.load_lora_for_models(ckpt, clip, comfy.utils.load_torch_file(lora_path), strength_model, strength_clip)
# Call the function again with the new lora_model and lora_clip and the remaining tuples
@@ -336,7 +345,7 @@ def load_lora(lora_params, ckpt_name, id, cache=None, ckpt_cache=None, cache_ove
def clear_cache(id, cache, dict_name):
"""
Clear the cache for a specific id in a specific dictionary (either "ckpt" or "vae").
Clear the cache for a specific id in a specific dictionary.
If the cache limit is reached for a specific id, deletes the id from the oldest entry.
If the id array of the entry becomes empty, deletes the entry.
"""
@@ -353,16 +362,18 @@ def clear_cache(id, cache, dict_name):
# Update the id_associated_entries
id_associated_entries = [entry for entry in loaded_objects[dict_name] if id in entry[-1]]
def clear_cache_by_exception(node_id, vae_dict=None, ckpt_dict=None, lora_dict=None):
def clear_cache_by_exception(node_id, vae_dict=None, ckpt_dict=None, lora_dict=None, refn_dict=None):
global loaded_objects
dict_mapping = {
"vae_dict": "vae",
"ckpt_dict": "ckpt",
"lora_dict": "lora"
"lora_dict": "lora",
"refn_dict": "refn"
}
for arg_name, arg_val in {"vae_dict": vae_dict, "ckpt_dict": ckpt_dict, "lora_dict": lora_dict}.items():
for arg_name, arg_val in {"vae_dict": vae_dict, "ckpt_dict": ckpt_dict, "lora_dict": lora_dict, "refn_dict": refn_dict}.items():
if arg_val is None:
continue
@@ -401,7 +412,8 @@ def get_cache_numbers(node_name):
vae_cache = int(model_cache_settings.get('vae', 1))
ckpt_cache = int(model_cache_settings.get('ckpt', 1))
lora_cache = int(model_cache_settings.get('lora', 1))
return vae_cache, ckpt_cache, lora_cache
refn_cache = int(model_cache_settings.get('ckpt', 1))
return vae_cache, ckpt_cache, lora_cache, refn_cache,
def print_last_helds(id=None):
print("\n" + "-" * 40) # Print an empty line followed by a separator line
@@ -509,19 +521,33 @@ def packages(python_exe=None, versions=False):
install_packages(my_dir)
#-----------------------------------------------------------------------------------------------------------------------
# Auto install efficiency nodes web extension '\js\efficiency_nodes.js' to 'ComfyUI\web\extensions'
# Auto install efficiency nodes web extensions '\js\' to 'ComfyUI\web\extensions'
import shutil
# Source and destination paths
source_path = os.path.join(my_dir, 'js', 'efficiency_nodes.js')
# Source and destination directories
source_dir = os.path.join(my_dir, 'js')
destination_dir = os.path.join(comfy_dir, 'web', 'extensions', 'efficiency-nodes-comfyui')
destination_path = os.path.join(destination_dir, 'efficiency_nodes.js')
# Create the destination directory if it doesn't exist
os.makedirs(destination_dir, exist_ok=True)
# Copy the file
shutil.copy2(source_path, destination_path)
# Get a list of all .js files in the source directory
source_files = [f for f in os.listdir(source_dir) if f.endswith('.js')]
# Clear files in the destination directory that aren't in the source directory
for file_name in os.listdir(destination_dir):
if file_name not in source_files and file_name.endswith('.js'):
file_path = os.path.join(destination_dir, file_name)
os.unlink(file_path)
# Iterate over all files in the source directory for copying
for file_name in source_files:
# Full paths for source and destination
source_path = os.path.join(source_dir, file_name)
destination_path = os.path.join(destination_dir, file_name)
# Directly copy the file (this will overwrite if the file already exists)
shutil.copy2(source_path, destination_path)
#-----------------------------------------------------------------------------------------------------------------------
# Establish a websocket connection to communicate with "efficiency-nodes.js" under: