Efficiency Nodes V2.0

This commit is contained in:
TSC
2023-10-20 15:49:32 -05:00
committed by GitHub
parent 749c42b69b
commit 93eb925686
41 changed files with 5013 additions and 888 deletions

View File

@@ -36,16 +36,44 @@ loaded_objects = {
"lora": [] # ([(lora_name, strength_model, strength_clip)], ckpt_name, lora_model, clip_lora, [id])
}
# Cache for Ksampler (Efficient) Outputs
# Cache for Efficient Ksamplers
last_helds = {
"preview_images": [], # (preview_images, id) # Preview Images, stored as a pil image list
"latent": [], # (latent, id) # Latent outputs, stored as a latent tensor list
"output_images": [], # (output_images, id) # Output Images, stored as an image tensor list
"vae_decode_flag": [], # (vae_decode, id) # Boolean to track wether vae-decode during Holds
"xy_plot_flag": [], # (xy_plot_flag, id) # Boolean to track if held images are xy_plot results
"xy_plot_image": [], # (xy_plot_image, id) # XY Plot image stored as an image tensor
"latent": [], # (latent, [parameters], id) # Base sampling latent results
"image": [], # (image, id) # Base sampling image results
"cnet_img": [] # (cnet_img, [parameters], id) # HiRes-Fix control net preprocessor image results
}
def load_ksampler_results(key: str, my_unique_id, parameters_list=None):
global last_helds
for data in last_helds[key]:
id_ = data[-1] # ID is always the last element in the tuple
if id_ == my_unique_id:
if parameters_list is not None:
# Ensure tuple has at least 3 elements and match with parameters_list
if len(data) >= 3 and data[1] == parameters_list:
return data[0]
else:
return data[0]
return None
def store_ksampler_results(key: str, my_unique_id, value, parameters_list=None):
global last_helds
for i, data in enumerate(last_helds[key]):
id_ = data[-1] # ID will always be the last in the tuple
if id_ == my_unique_id:
# Check if parameters_list is provided or not
updated_data = (value, parameters_list, id_) if parameters_list is not None else (value, id_)
last_helds[key][i] = updated_data
return True
# If parameters_list is given
if parameters_list is not None:
last_helds[key].append((value, parameters_list, my_unique_id))
else:
last_helds[key].append((value, my_unique_id))
return True
# Tensor to PIL (grabbed from WAS Suite)
def tensor2pil(image: torch.Tensor) -> Image.Image:
return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
@@ -54,6 +82,20 @@ def tensor2pil(image: torch.Tensor) -> Image.Image:
def pil2tensor(image: Image.Image) -> torch.Tensor:
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
# Convert tensor to PIL, resize it, and convert back to tensor
def quick_resize(source_tensor: torch.Tensor, target_shape: tuple) -> torch.Tensor:
resized_images = []
for img in source_tensor:
resized_pil = tensor2pil(img.squeeze(0)).resize((target_shape[2], target_shape[1]), Image.ANTIALIAS)
resized_images.append(pil2tensor(resized_pil).squeeze(0))
return torch.stack(resized_images, dim=0)
# Create a function to compute the hash of a tensor
import hashlib
def tensor_to_hash(tensor):
byte_repr = tensor.cpu().numpy().tobytes() # Convert tensor to bytes
return hashlib.sha256(byte_repr).hexdigest() # Compute hash
# Color coded messages functions
MESSAGE_COLOR = "\033[36m" # Cyan
XYPLOT_COLOR = "\033[35m" # Purple
@@ -154,9 +196,11 @@ def globals_cleanup(prompt):
# Step 1: Clean up last_helds
for key in list(last_helds.keys()):
original_length = len(last_helds[key])
last_helds[key] = [(value, id) for value, id in last_helds[key] if str(id) in prompt.keys()]
###if original_length != len(last_helds[key]):
###print(f'Updated {key} in last_helds: {last_helds[key]}')
last_helds[key] = [
(*values, id_)
for *values, id_ in last_helds[key]
if str(id_) in prompt.keys()
]
# Step 2: Clean up loaded_objects
for key in list(loaded_objects.keys()):
@@ -250,6 +294,7 @@ def load_vae(vae_name, id, cache=None, cache_overwrite=False):
vae_path = vae_name
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
@@ -473,20 +518,11 @@ def global_preview_method():
#-----------------------------------------------------------------------------------------------------------------------
# Auto install Efficiency Nodes Python package dependencies
import subprocess
# Note: This auto-installer attempts to import packages listed in the requirements.txt.
# If the import fails, indicating the package isn't installed, the installer proceeds to install the package.
# Note: This auto-installer installs packages listed in the requirements.txt.
# It first checks if python.exe exists inside the ...\ComfyUI_windows_portable\python_embeded directory.
# If python.exe is found in that location, it will use this embedded Python version for the installation.
# Otherwise, it uses the Python interpreter that's currently executing the script (via sys.executable)
# to attempt a general pip install of the packages. If any errors occur during installation, an error message is
# printed with the reason for the failure, and the user is directed to manually install the required packages.
def is_package_installed(pkg_name):
try:
__import__(pkg_name)
return True
except ImportError:
return False
# Otherwise, it uses the Python interpreter that's currently executing the script (via sys.executable) to attempt a general pip install of the packages.
# If any errors occur during installation, the user is directed to manually install the required packages.
def install_packages(my_dir):
# Compute path to the target site-packages
@@ -500,26 +536,41 @@ def install_packages(my_dir):
with open(os.path.join(my_dir, 'requirements.txt'), 'r') as f:
required_packages = [line.strip() for line in f if line.strip()]
for pkg in required_packages:
if not is_package_installed(pkg):
printout = f"Installing required package '{pkg}'..."
print(f"{message('Efficiency Nodes:')} {printout}", end='', flush=True)
try:
installed_packages = packages(embedded_python_exe if use_embedded else None, versions=False)
for pkg in required_packages:
if pkg not in installed_packages:
printout = f"Installing required package '{pkg}'..."
print(f"{message('Efficiency Nodes:')} {printout}", end='', flush=True)
try:
if use_embedded: # Targeted installation
subprocess.check_call([embedded_python_exe, '-m', 'pip', 'install', pkg, '--target=' + target_dir,
'--no-warn-script-location', '--disable-pip-version-check'],
stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, timeout=7)
stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
else: # Untargeted installation
subprocess.check_call([sys.executable, "-m", "pip", 'install', pkg],
stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, timeout=7)
print(f"\r{message('Efficiency Nodes:')} {printout}{success(' Installed!')}", flush=True)
except Exception as e:
print(f"\r{message('Efficiency Nodes:')} {printout}{error(' Failed!')}", flush=True)
print(f"{warning(str(e))}")
stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
print(f"\r{message('Efficiency Nodes:')} {printout}{success('Installed!')}", flush=True)
except Exception as e: # This catches all exceptions derived from the base Exception class
print_general_error_message()
def packages(python_exe=None, versions=False):
try:
if python_exe:
return [(r.decode().split('==')[0] if not versions else r.decode()) for r in
subprocess.check_output([python_exe, '-m', 'pip', 'freeze']).split()]
else:
return [(r.split('==')[0] if not versions else r) for r in
subprocess.getoutput([sys.executable, "-m", "pip", "freeze"]).splitlines()]
except subprocess.CalledProcessError as e:
raise e # re-raise the error to handle it outside
def print_general_error_message():
print(f"{message('Efficiency Nodes:')} An unexpected error occurred during the package installation process. {error('Failed!')}")
print(
f"\r{message('Efficiency Nodes:')} An unexpected error occurred during the package installation process. {error('Failed!')}")
print(warning("Please try manually installing the required packages from the requirements.txt file."))
# Install missing packages
@@ -538,85 +589,7 @@ if os.path.exists(destination_dir):
shutil.rmtree(destination_dir)
#-----------------------------------------------------------------------------------------------------------------------
# Establish a websocket connection to communicate with "efficiency-nodes.js" under:
# ComfyUI\web\extensions\efficiency-nodes-comfyui\
def handle_websocket_failure():
global websocket_status
if websocket_status: # Ensures the message is printed only once
websocket_status = False
print(f"\r\033[33mEfficiency Nodes Warning:\033[0m Websocket connection failure."
f"\nEfficient KSampler's live preview images may not clear when vae decoding is set to 'true'.")
# Initialize websocket related global variables
websocket_status = True
latest_image = list()
connected_client = None
try:
import websockets
import asyncio
import threading
import base64
from io import BytesIO
from torchvision import transforms
except ImportError:
handle_websocket_failure()
async def server_logic(websocket, path):
global latest_image, connected_client, websocket_status
# If websocket_status is False, set latest_image to an empty list
if not websocket_status:
latest_image = list()
# Assign the connected client
connected_client = websocket
try:
async for message in websocket:
# If not a command, treat it as image data
if not message.startswith('{'):
image_data = base64.b64decode(message.split(",")[1])
image = Image.open(BytesIO(image_data))
latest_image = pil2tensor(image)
except (websockets.exceptions.ConnectionClosedError, asyncio.exceptions.CancelledError):
handle_websocket_failure()
except Exception:
handle_websocket_failure()
def run_server():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
start_server = websockets.serve(server_logic, "127.0.0.1", 8288)
loop.run_until_complete(start_server)
loop.run_forever()
except Exception: # Catch all exceptions
handle_websocket_failure()
def get_latest_image():
return latest_image
# Function to send commands to frontend
def send_command_to_frontend(startListening=False, maxCount=0, sendBlob=False):
global connected_client, websocket_status
if connected_client and websocket_status:
try:
asyncio.run(connected_client.send(json.dumps({
'startProcessing': startListening,
'maxCount': maxCount,
'sendBlob': sendBlob
})))
except Exception:
handle_websocket_failure()
# Start the WebSocket server in a separate thread
if websocket_status == True:
server_thread = threading.Thread(target=run_server)
server_thread.daemon = True
server_thread.start()
# Other
class XY_Capsule:
def pre_define_model(self, model, clip, vae):
return model, clip, vae
@@ -632,3 +605,12 @@ class XY_Capsule:
def getLabel(self):
return "Unknown"