mirror of
https://github.com/jags111/efficiency-nodes-comfyui.git
synced 2026-03-26 07:35:45 -03:00
Efficiency Nodes V2.0
This commit is contained in:
210
tsc_utils.py
210
tsc_utils.py
@@ -36,16 +36,44 @@ loaded_objects = {
|
||||
"lora": [] # ([(lora_name, strength_model, strength_clip)], ckpt_name, lora_model, clip_lora, [id])
|
||||
}
|
||||
|
||||
# Cache for Ksampler (Efficient) Outputs
|
||||
# Cache for Efficient Ksamplers
|
||||
last_helds = {
|
||||
"preview_images": [], # (preview_images, id) # Preview Images, stored as a pil image list
|
||||
"latent": [], # (latent, id) # Latent outputs, stored as a latent tensor list
|
||||
"output_images": [], # (output_images, id) # Output Images, stored as an image tensor list
|
||||
"vae_decode_flag": [], # (vae_decode, id) # Boolean to track wether vae-decode during Holds
|
||||
"xy_plot_flag": [], # (xy_plot_flag, id) # Boolean to track if held images are xy_plot results
|
||||
"xy_plot_image": [], # (xy_plot_image, id) # XY Plot image stored as an image tensor
|
||||
"latent": [], # (latent, [parameters], id) # Base sampling latent results
|
||||
"image": [], # (image, id) # Base sampling image results
|
||||
"cnet_img": [] # (cnet_img, [parameters], id) # HiRes-Fix control net preprocessor image results
|
||||
}
|
||||
|
||||
def load_ksampler_results(key: str, my_unique_id, parameters_list=None):
|
||||
global last_helds
|
||||
for data in last_helds[key]:
|
||||
id_ = data[-1] # ID is always the last element in the tuple
|
||||
if id_ == my_unique_id:
|
||||
if parameters_list is not None:
|
||||
# Ensure tuple has at least 3 elements and match with parameters_list
|
||||
if len(data) >= 3 and data[1] == parameters_list:
|
||||
return data[0]
|
||||
else:
|
||||
return data[0]
|
||||
return None
|
||||
|
||||
def store_ksampler_results(key: str, my_unique_id, value, parameters_list=None):
|
||||
global last_helds
|
||||
|
||||
for i, data in enumerate(last_helds[key]):
|
||||
id_ = data[-1] # ID will always be the last in the tuple
|
||||
if id_ == my_unique_id:
|
||||
# Check if parameters_list is provided or not
|
||||
updated_data = (value, parameters_list, id_) if parameters_list is not None else (value, id_)
|
||||
last_helds[key][i] = updated_data
|
||||
return True
|
||||
|
||||
# If parameters_list is given
|
||||
if parameters_list is not None:
|
||||
last_helds[key].append((value, parameters_list, my_unique_id))
|
||||
else:
|
||||
last_helds[key].append((value, my_unique_id))
|
||||
return True
|
||||
|
||||
# Tensor to PIL (grabbed from WAS Suite)
|
||||
def tensor2pil(image: torch.Tensor) -> Image.Image:
|
||||
return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
|
||||
@@ -54,6 +82,20 @@ def tensor2pil(image: torch.Tensor) -> Image.Image:
|
||||
def pil2tensor(image: Image.Image) -> torch.Tensor:
|
||||
return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
|
||||
|
||||
# Convert tensor to PIL, resize it, and convert back to tensor
|
||||
def quick_resize(source_tensor: torch.Tensor, target_shape: tuple) -> torch.Tensor:
|
||||
resized_images = []
|
||||
for img in source_tensor:
|
||||
resized_pil = tensor2pil(img.squeeze(0)).resize((target_shape[2], target_shape[1]), Image.ANTIALIAS)
|
||||
resized_images.append(pil2tensor(resized_pil).squeeze(0))
|
||||
return torch.stack(resized_images, dim=0)
|
||||
|
||||
# Create a function to compute the hash of a tensor
|
||||
import hashlib
|
||||
def tensor_to_hash(tensor):
|
||||
byte_repr = tensor.cpu().numpy().tobytes() # Convert tensor to bytes
|
||||
return hashlib.sha256(byte_repr).hexdigest() # Compute hash
|
||||
|
||||
# Color coded messages functions
|
||||
MESSAGE_COLOR = "\033[36m" # Cyan
|
||||
XYPLOT_COLOR = "\033[35m" # Purple
|
||||
@@ -154,9 +196,11 @@ def globals_cleanup(prompt):
|
||||
# Step 1: Clean up last_helds
|
||||
for key in list(last_helds.keys()):
|
||||
original_length = len(last_helds[key])
|
||||
last_helds[key] = [(value, id) for value, id in last_helds[key] if str(id) in prompt.keys()]
|
||||
###if original_length != len(last_helds[key]):
|
||||
###print(f'Updated {key} in last_helds: {last_helds[key]}')
|
||||
last_helds[key] = [
|
||||
(*values, id_)
|
||||
for *values, id_ in last_helds[key]
|
||||
if str(id_) in prompt.keys()
|
||||
]
|
||||
|
||||
# Step 2: Clean up loaded_objects
|
||||
for key in list(loaded_objects.keys()):
|
||||
@@ -250,6 +294,7 @@ def load_vae(vae_name, id, cache=None, cache_overwrite=False):
|
||||
vae_path = vae_name
|
||||
else:
|
||||
vae_path = folder_paths.get_full_path("vae", vae_name)
|
||||
|
||||
sd = comfy.utils.load_torch_file(vae_path)
|
||||
vae = comfy.sd.VAE(sd=sd)
|
||||
|
||||
@@ -473,20 +518,11 @@ def global_preview_method():
|
||||
#-----------------------------------------------------------------------------------------------------------------------
|
||||
# Auto install Efficiency Nodes Python package dependencies
|
||||
import subprocess
|
||||
# Note: This auto-installer attempts to import packages listed in the requirements.txt.
|
||||
# If the import fails, indicating the package isn't installed, the installer proceeds to install the package.
|
||||
# Note: This auto-installer installs packages listed in the requirements.txt.
|
||||
# It first checks if python.exe exists inside the ...\ComfyUI_windows_portable\python_embeded directory.
|
||||
# If python.exe is found in that location, it will use this embedded Python version for the installation.
|
||||
# Otherwise, it uses the Python interpreter that's currently executing the script (via sys.executable)
|
||||
# to attempt a general pip install of the packages. If any errors occur during installation, an error message is
|
||||
# printed with the reason for the failure, and the user is directed to manually install the required packages.
|
||||
|
||||
def is_package_installed(pkg_name):
|
||||
try:
|
||||
__import__(pkg_name)
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
# Otherwise, it uses the Python interpreter that's currently executing the script (via sys.executable) to attempt a general pip install of the packages.
|
||||
# If any errors occur during installation, the user is directed to manually install the required packages.
|
||||
|
||||
def install_packages(my_dir):
|
||||
# Compute path to the target site-packages
|
||||
@@ -500,26 +536,41 @@ def install_packages(my_dir):
|
||||
with open(os.path.join(my_dir, 'requirements.txt'), 'r') as f:
|
||||
required_packages = [line.strip() for line in f if line.strip()]
|
||||
|
||||
for pkg in required_packages:
|
||||
if not is_package_installed(pkg):
|
||||
printout = f"Installing required package '{pkg}'..."
|
||||
print(f"{message('Efficiency Nodes:')} {printout}", end='', flush=True)
|
||||
try:
|
||||
installed_packages = packages(embedded_python_exe if use_embedded else None, versions=False)
|
||||
|
||||
for pkg in required_packages:
|
||||
if pkg not in installed_packages:
|
||||
printout = f"Installing required package '{pkg}'..."
|
||||
print(f"{message('Efficiency Nodes:')} {printout}", end='', flush=True)
|
||||
|
||||
try:
|
||||
if use_embedded: # Targeted installation
|
||||
subprocess.check_call([embedded_python_exe, '-m', 'pip', 'install', pkg, '--target=' + target_dir,
|
||||
'--no-warn-script-location', '--disable-pip-version-check'],
|
||||
stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, timeout=7)
|
||||
stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
|
||||
else: # Untargeted installation
|
||||
subprocess.check_call([sys.executable, "-m", "pip", 'install', pkg],
|
||||
stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, timeout=7)
|
||||
print(f"\r{message('Efficiency Nodes:')} {printout}{success(' Installed!')}", flush=True)
|
||||
except Exception as e:
|
||||
print(f"\r{message('Efficiency Nodes:')} {printout}{error(' Failed!')}", flush=True)
|
||||
print(f"{warning(str(e))}")
|
||||
|
||||
stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
|
||||
|
||||
print(f"\r{message('Efficiency Nodes:')} {printout}{success('Installed!')}", flush=True)
|
||||
|
||||
except Exception as e: # This catches all exceptions derived from the base Exception class
|
||||
print_general_error_message()
|
||||
|
||||
def packages(python_exe=None, versions=False):
|
||||
try:
|
||||
if python_exe:
|
||||
return [(r.decode().split('==')[0] if not versions else r.decode()) for r in
|
||||
subprocess.check_output([python_exe, '-m', 'pip', 'freeze']).split()]
|
||||
else:
|
||||
return [(r.split('==')[0] if not versions else r) for r in
|
||||
subprocess.getoutput([sys.executable, "-m", "pip", "freeze"]).splitlines()]
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise e # re-raise the error to handle it outside
|
||||
|
||||
def print_general_error_message():
|
||||
print(f"{message('Efficiency Nodes:')} An unexpected error occurred during the package installation process. {error('Failed!')}")
|
||||
print(
|
||||
f"\r{message('Efficiency Nodes:')} An unexpected error occurred during the package installation process. {error('Failed!')}")
|
||||
print(warning("Please try manually installing the required packages from the requirements.txt file."))
|
||||
|
||||
# Install missing packages
|
||||
@@ -538,85 +589,7 @@ if os.path.exists(destination_dir):
|
||||
shutil.rmtree(destination_dir)
|
||||
|
||||
#-----------------------------------------------------------------------------------------------------------------------
|
||||
# Establish a websocket connection to communicate with "efficiency-nodes.js" under:
|
||||
# ComfyUI\web\extensions\efficiency-nodes-comfyui\
|
||||
def handle_websocket_failure():
|
||||
global websocket_status
|
||||
if websocket_status: # Ensures the message is printed only once
|
||||
websocket_status = False
|
||||
print(f"\r\033[33mEfficiency Nodes Warning:\033[0m Websocket connection failure."
|
||||
f"\nEfficient KSampler's live preview images may not clear when vae decoding is set to 'true'.")
|
||||
|
||||
# Initialize websocket related global variables
|
||||
websocket_status = True
|
||||
latest_image = list()
|
||||
connected_client = None
|
||||
|
||||
try:
|
||||
import websockets
|
||||
import asyncio
|
||||
import threading
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from torchvision import transforms
|
||||
except ImportError:
|
||||
handle_websocket_failure()
|
||||
|
||||
async def server_logic(websocket, path):
|
||||
global latest_image, connected_client, websocket_status
|
||||
|
||||
# If websocket_status is False, set latest_image to an empty list
|
||||
if not websocket_status:
|
||||
latest_image = list()
|
||||
|
||||
# Assign the connected client
|
||||
connected_client = websocket
|
||||
|
||||
try:
|
||||
async for message in websocket:
|
||||
# If not a command, treat it as image data
|
||||
if not message.startswith('{'):
|
||||
image_data = base64.b64decode(message.split(",")[1])
|
||||
image = Image.open(BytesIO(image_data))
|
||||
latest_image = pil2tensor(image)
|
||||
except (websockets.exceptions.ConnectionClosedError, asyncio.exceptions.CancelledError):
|
||||
handle_websocket_failure()
|
||||
except Exception:
|
||||
handle_websocket_failure()
|
||||
|
||||
def run_server():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
start_server = websockets.serve(server_logic, "127.0.0.1", 8288)
|
||||
loop.run_until_complete(start_server)
|
||||
loop.run_forever()
|
||||
except Exception: # Catch all exceptions
|
||||
handle_websocket_failure()
|
||||
|
||||
def get_latest_image():
|
||||
return latest_image
|
||||
|
||||
# Function to send commands to frontend
|
||||
def send_command_to_frontend(startListening=False, maxCount=0, sendBlob=False):
|
||||
global connected_client, websocket_status
|
||||
if connected_client and websocket_status:
|
||||
try:
|
||||
asyncio.run(connected_client.send(json.dumps({
|
||||
'startProcessing': startListening,
|
||||
'maxCount': maxCount,
|
||||
'sendBlob': sendBlob
|
||||
})))
|
||||
except Exception:
|
||||
handle_websocket_failure()
|
||||
|
||||
# Start the WebSocket server in a separate thread
|
||||
if websocket_status == True:
|
||||
server_thread = threading.Thread(target=run_server)
|
||||
server_thread.daemon = True
|
||||
server_thread.start()
|
||||
|
||||
|
||||
# Other
|
||||
class XY_Capsule:
|
||||
def pre_define_model(self, model, clip, vae):
|
||||
return model, clip, vae
|
||||
@@ -632,3 +605,12 @@ class XY_Capsule:
|
||||
|
||||
def getLabel(self):
|
||||
return "Unknown"
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user