diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 93baf58..0000000 --- a/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -__pycache__/__init__.cpython-310.pyc -__pycache__/efficiency_nodes.cpython-310.pyc diff --git a/README.md b/README.md index bb46036..d3cbc52 100644 --- a/README.md +++ b/README.md @@ -1,96 +1,168 @@ -Efficiency Nodes for ComfyUI -======= -### A collection of ComfyUI custom nodes to help streamline workflows and reduce total node count. -## [Direct Download Link](https://github.com/LucianoCirino/efficiency-nodes-comfyui/releases/download/v1.92/efficiency-nodes-comfyui-v192.7z) - -
- Efficient Loader - -- A combination of common initialization nodes. -- Able to load LoRA and Control Net stacks via its 'lora_stack' and 'cnet_stack' inputs. -- Can cache multiple Checkpoint, VAE, and LoRA models. (cache settings found in config file 'node_settings.json') -- Used by the XY Plot node for many of its plot type dependencies. - -

- -

- -
- -
- Ksampler & Ksampler Adv. (Efficient) - -- Modded KSamplers with the ability to live preview generations and/or vae decode images. -- Used for running the XY Plot script. ('sampler_state' = "Script") -- Can be set to re-output their last outputs by force. ('sampler_state' = "Hold") - -

- -         - -

- -
- -
- XY Plot - -- Node that allows users to specify parameters for the Efficient KSampler's to plot on a grid. - -

- -

- -
- -
- Image Overlay - -- Node that allows for flexible image overlaying. Works also with image batches. - -

- -

- -
- -
- SimpleEval Nodes - -- A collection of nodes that allows users to write simple Python expressions for a variety of data types using the "simpleeval" library. - -

- -     - -     - -

- -
- -## **Examples:** - -- HiResFix using the **Efficient Loader**, **Ksampler (Efficient)**, and **HiResFix Script** nodes - -[](https://github.com/LucianoCirino/efficiency-nodes-comfyui/blob/main/workflows/HiRes%20Fix.png) - -- SDXL Refining using the **Eff. SDXL Loader**, and **Ksampler SDXL (Eff.)** - -[](https://github.com/LucianoCirino/efficiency-nodes-comfyui/blob/main/workflows/SDXL%20Base%2BRefine.png) - -- 2D Plotting using the **XY Plot** & **Ksampler (Efficient)** nodes - -[](https://github.com/LucianoCirino/efficiency-nodes-comfyui/blob/main/workflows/XYplot/X-Seeds%20Y-Checkpoints.png) - -[](https://github.com/LucianoCirino/efficiency-nodes-comfyui/blob/main/workflows/XYplot/LoRA%20Plot%20X-ModelStr%20Y-ClipStr.png) - -- Photobashing using the **Image Overlay** node - -[](https://github.com/LucianoCirino/efficiency-nodes-comfyui/blob/main/workflows/Image%20Overlay.png) - -### Dependencies -Dependencies are automatically installed during ComfyUI boot up. - -## **Install:** -To install, drop the "_**efficiency-nodes-comfyui**_" folder into the "_**...\ComfyUI\ComfyUI\custom_nodes**_" directory and restart UI. +Efficiency Nodes for ComfyUI +======= +### A collection of ComfyUI custom nodes to help streamline workflows and reduce total node count. +## [Direct Download Link](https://github.com/LucianoCirino/efficiency-nodes-comfyui/releases/download/v1.92/efficiency-nodes-comfyui-v192.7z) +### Nodes: + +
+ Efficient Loader & Eff. Loader SDXL + + +

+ + +

+
+ +
+ KSampler (Efficient), KSampler Adv. (Efficient), KSampler SDXL (Eff.) + +- Modded KSamplers with the ability to live preview generations and/or vae decode images. +- Feature a special seed box that allows for a clearer management of seeds. (-1 seed to apply the selected seed behavior) +- Can execute a variety of scripts, such as the XY Plot script. To activate the script, simply connect the input connection. + +

+ +       + +       + +

+ +
+ +
+ Script Nodes + +- A group of node's that are used in conjuction with the Efficient KSamplers to execute a variety of 'pre-wired' actions. +- Script nodes can be chained if their input/outputs allow it. Multiple instances of the same Script Node in a chain does nothing. +

+ +

+ +
+ XY Plot + +

+ +

+ +
+ +
+ HighRes-Fix + +

+ +

+ +
+ +
+ Noise Control + +

+ +

+ +
+ +
+ Tiled Upscaler + +

+ +

+ +
+ +
+ AnimateDiff + +

+ +

+ +
+
+ + +
+ Image Overlay + +- Node that allows for flexible image overlaying. Works also with image batches. + +

+ +

+ +
+ +
+ SimpleEval Nodes + +- A collection of nodes that allows users to write simple Python expressions for a variety of data types using the simpleeval library. + +

+ +     + +     + +

+ +
+ +## **Workflow Examples:** + +- HiResFixing with the **HiRes-Fix Script** node + +[](https://github.com/LucianoCirino/efficiency-nodes-comfyui/blob/v2.0/workflows/HiResFix%20Script.png) + +- SDXL Refining using the **Eff. SDXL Loader**, & **Ksampler SDXL (Eff.)** + +[](https://github.com/LucianoCirino/efficiency-nodes-comfyui/blob/v2.0/workflows/SDXL%20Refining%20%26%20Noise%20Control%20Script.png) + +- Comparing LoRA Model & Clip Strenghts via the **XY Plot** node. + +[](https://github.com/LucianoCirino/efficiency-nodes-comfyui/blob/v2.0/workflows/XYPlot%20-%20LoRA%20Model%20vs%20Clip%20Strengths.png) + +- Stacking Scripts: **XY Plot** + **Noise Control** + **HiRes-Fix** + +[](https://github.com/LucianoCirino/efficiency-nodes-comfyui/blob/v2.0/workflows/XYPlot%20-%20Seeds%20vs%20Checkpoints%20%26%20Stacked%20Scripts.png) + +- Stacking Scripts: **AnimateDiff** + **HiRes-Fix** +[](https://github.com/LucianoCirino/efficiency-nodes-comfyui/blob/v2.0/workflows/AnimateDiff%20%26%20HiResFix%20Scripts.gif) + +### Dependencies +Dependencies are automatically installed during ComfyUI boot up. + +## **Install:** +To install, drop the "_**efficiency-nodes-comfyui**_" folder into the "_**...\ComfyUI\ComfyUI\custom_nodes**_" directory and restart UI. diff --git a/efficiency_nodes.py b/efficiency_nodes.py index 0ef9eaf..fb17119 100644 --- a/efficiency_nodes.py +++ b/efficiency_nodes.py @@ -1,5 +1,5 @@ # Efficiency Nodes - A collection of my ComfyUI custom nodes to help streamline workflows and reduce total node count. -# by Luciano Cirino (Discord: TSC#9184) - April 2023 - August 2023 +# by Luciano Cirino (Discord: TSC#9184) - April 2023 - October 2023 # https://github.com/LucianoCirino/efficiency-nodes-comfyui from torch import Tensor @@ -29,8 +29,11 @@ font_path = os.path.join(my_dir, 'arial.ttf') # Append comfy_dir to sys.path & import files sys.path.append(comfy_dir) from nodes import LatentUpscaleBy, KSampler, KSamplerAdvanced, VAEDecode, VAEDecodeTiled, VAEEncode, VAEEncodeTiled, \ - ImageScaleBy, CLIPSetLastLayer, CLIPTextEncode, ControlNetLoader, ControlNetApply, ControlNetApplyAdvanced + ImageScaleBy, CLIPSetLastLayer, CLIPTextEncode, ControlNetLoader, ControlNetApply, ControlNetApplyAdvanced, \ + PreviewImage, MAX_RESOLUTION +from comfy_extras.nodes_upscale_model import UpscaleModelLoader, ImageUpscaleWithModel from comfy_extras.nodes_clip_sdxl import CLIPTextEncodeSDXL, CLIPTextEncodeSDXLRefiner +import comfy.sample import comfy.samplers import comfy.sd import comfy.utils @@ -40,46 +43,46 @@ sys.path.remove(comfy_dir) # Append my_dir to sys.path & import files sys.path.append(my_dir) from tsc_utils import * +from .py import smZ_cfg_denoiser +from .py import smZ_rng_source +from .py import cg_mixed_seed_noise +from .py import city96_latent_upscaler +from .py import ttl_nn_latent_upscaler +from .py import bnk_tiled_samplers +from .py import bnk_adv_encode sys.path.remove(my_dir) # Append custom_nodes_dir to sys.path sys.path.append(custom_nodes_dir) # GLOBALS -MAX_RESOLUTION=8192 REFINER_CFG_OFFSET = 0 #Refiner CFG Offset ######################################################################################################################## # Common function for encoding prompts -def encode_prompts(positive_prompt, negative_prompt, clip, clip_skip, refiner_clip, refiner_clip_skip, ascore, is_sdxl, - empty_latent_width, empty_latent_height, return_type="both"): +def encode_prompts(positive_prompt, negative_prompt, token_normalization, weight_interpretation, clip, clip_skip, + refiner_clip, refiner_clip_skip, ascore, is_sdxl, empty_latent_width, empty_latent_height, + return_type="both"): positive_encoded = negative_encoded = refiner_positive_encoded = refiner_negative_encoded = None # Process base encodings if needed if return_type in ["base", "both"]: - # Base clip skip clip = CLIPSetLastLayer().set_last_layer(clip, clip_skip)[0] - if not is_sdxl: - positive_encoded = CLIPTextEncode().encode(clip, positive_prompt)[0] - negative_encoded = CLIPTextEncode().encode(clip, negative_prompt)[0] - else: - # Encode prompt for base - positive_encoded = CLIPTextEncodeSDXL().encode(clip, empty_latent_width, empty_latent_height, 0, 0, - empty_latent_width, empty_latent_height, positive_prompt, - positive_prompt)[0] - negative_encoded = CLIPTextEncodeSDXL().encode(clip, empty_latent_width, empty_latent_height, 0, 0, - empty_latent_width, empty_latent_height, negative_prompt, - negative_prompt)[0] + + positive_encoded = bnk_adv_encode.AdvancedCLIPTextEncode().encode(clip, positive_prompt, token_normalization, weight_interpretation)[0] + negative_encoded = bnk_adv_encode.AdvancedCLIPTextEncode().encode(clip, negative_prompt, token_normalization, weight_interpretation)[0] + # Process refiner encodings if needed if return_type in ["refiner", "both"] and is_sdxl and refiner_clip and refiner_clip_skip and ascore: - # Refiner clip skip refiner_clip = CLIPSetLastLayer().set_last_layer(refiner_clip, refiner_clip_skip)[0] - # Encode prompt for refiner - refiner_positive_encoded = CLIPTextEncodeSDXLRefiner().encode(refiner_clip, ascore[0], empty_latent_width, - empty_latent_height, positive_prompt)[0] - refiner_negative_encoded = CLIPTextEncodeSDXLRefiner().encode(refiner_clip, ascore[1], empty_latent_width, - empty_latent_height, negative_prompt)[0] + + refiner_positive_encoded = bnk_adv_encode.AdvancedCLIPTextEncode().encode(refiner_clip, positive_prompt, token_normalization, weight_interpretation)[0] + refiner_positive_encoded = bnk_adv_encode.AddCLIPSDXLRParams().encode(refiner_positive_encoded, empty_latent_width, empty_latent_height, ascore[0])[0] + + refiner_negative_encoded = bnk_adv_encode.AdvancedCLIPTextEncode().encode(refiner_clip, negative_prompt, token_normalization, weight_interpretation)[0] + refiner_negative_encoded = bnk_adv_encode.AddCLIPSDXLRParams().encode(refiner_negative_encoded, empty_latent_width, empty_latent_height, ascore[1])[0] + # Return results based on return_type if return_type == "base": return positive_encoded, negative_encoded, clip @@ -100,11 +103,13 @@ class TSC_EfficientLoader: "lora_name": (["None"] + folder_paths.get_filename_list("loras"),), "lora_model_strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), "lora_clip_strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - "positive": ("STRING", {"default": "Positive","multiline": True}), - "negative": ("STRING", {"default": "Negative", "multiline": True}), + "positive": ("STRING", {"default": "CLIP_POSITIVE","multiline": True}), + "negative": ("STRING", {"default": "CLIP_NEGATIVE", "multiline": True}), + "token_normalization": (["none", "mean", "length", "length+mean"],), + "weight_interpretation": (["comfy", "A1111", "compel", "comfy++", "down_weight"],), "empty_latent_width": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), "empty_latent_height": ("INT", {"default": 512, "min": 64, "max": MAX_RESOLUTION, "step": 64}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}, + "batch_size": ("INT", {"default": 1, "min": 1, "max": 262144})}, "optional": {"lora_stack": ("LORA_STACK", ), "cnet_stack": ("CONTROL_NET_STACK",)}, "hidden": { "prompt": "PROMPT", @@ -117,9 +122,9 @@ class TSC_EfficientLoader: CATEGORY = "Efficiency Nodes/Loaders" def efficientloader(self, ckpt_name, vae_name, clip_skip, lora_name, lora_model_strength, lora_clip_strength, - positive, negative, empty_latent_width, empty_latent_height, batch_size, lora_stack=None, - cnet_stack=None, refiner_name="None", ascore=None, prompt=None, my_unique_id=None, - loader_type="regular"): + positive, negative, token_normalization, weight_interpretation, empty_latent_width, + empty_latent_height, batch_size, lora_stack=None, cnet_stack=None, refiner_name="None", + ascore=None, prompt=None, my_unique_id=None, loader_type="regular"): # Clean globally stored objects globals_cleanup(prompt) @@ -164,8 +169,9 @@ class TSC_EfficientLoader: # Encode prompt based on loader_type positive_encoded, negative_encoded, clip, refiner_positive_encoded, refiner_negative_encoded, refiner_clip = \ - encode_prompts(positive, negative, clip, clip_skip, refiner_clip, refiner_clip_skip, ascore, - loader_type == "sdxl", empty_latent_width, empty_latent_height) + encode_prompts(positive, negative, token_normalization, weight_interpretation, clip, clip_skip, + refiner_clip, refiner_clip_skip, ascore, loader_type == "sdxl", + empty_latent_width, empty_latent_height) # Apply ControlNet Stack if given if cnet_stack: @@ -178,7 +184,8 @@ class TSC_EfficientLoader: # Data for XY Plot dependencies = (vae_name, ckpt_name, clip, clip_skip, refiner_name, refiner_clip, refiner_clip_skip, - positive, negative, ascore, empty_latent_width, empty_latent_height, lora_params, cnet_stack) + positive, negative, token_normalization, weight_interpretation, ascore, + empty_latent_width, empty_latent_height, lora_params, cnet_stack) ### Debugging ###print_loaded_objects_entries() @@ -197,14 +204,16 @@ class TSC_EfficientLoaderSDXL(TSC_EfficientLoader): @classmethod def INPUT_TYPES(cls): return {"required": { "base_ckpt_name": (folder_paths.get_filename_list("checkpoints"),), - "base_clip_skip": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}), + "base_clip_skip": ("INT", {"default": -2, "min": -24, "max": -1, "step": 1}), "refiner_ckpt_name": (["None"] + folder_paths.get_filename_list("checkpoints"),), - "refiner_clip_skip": ("INT", {"default": -1, "min": -24, "max": -1, "step": 1}), + "refiner_clip_skip": ("INT", {"default": -2, "min": -24, "max": -1, "step": 1}), "positive_ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}), "negative_ascore": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 1000.0, "step": 0.01}), "vae_name": (["Baked VAE"] + folder_paths.get_filename_list("vae"),), - "positive": ("STRING", {"default": "Positive","multiline": True}), - "negative": ("STRING", {"default": "Negative", "multiline": True}), + "positive": ("STRING", {"default": "CLIP_POSITIVE", "multiline": True}), + "negative": ("STRING", {"default": "CLIP_NEGATIVE", "multiline": True}), + "token_normalization": (["none", "mean", "length", "length+mean"],), + "weight_interpretation": (["comfy", "A1111", "compel", "comfy++", "down_weight"],), "empty_latent_width": ("INT", {"default": 1024, "min": 64, "max": MAX_RESOLUTION, "step": 128}), "empty_latent_height": ("INT", {"default": 1024, "min": 64, "max": MAX_RESOLUTION, "step": 128}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 64})}, @@ -218,15 +227,16 @@ class TSC_EfficientLoaderSDXL(TSC_EfficientLoader): CATEGORY = "Efficiency Nodes/Loaders" def efficientloaderSDXL(self, base_ckpt_name, base_clip_skip, refiner_ckpt_name, refiner_clip_skip, positive_ascore, - negative_ascore, vae_name, positive, negative, empty_latent_width, empty_latent_height, - batch_size, lora_stack=None, cnet_stack=None, prompt=None, my_unique_id=None): + negative_ascore, vae_name, positive, negative, token_normalization, weight_interpretation, + empty_latent_width, empty_latent_height, batch_size, lora_stack=None, cnet_stack=None, + prompt=None, my_unique_id=None): clip_skip = (base_clip_skip, refiner_clip_skip) lora_name = "None" lora_model_strength = lora_clip_strength = 0 return super().efficientloader(base_ckpt_name, vae_name, clip_skip, lora_name, lora_model_strength, lora_clip_strength, - positive, negative, empty_latent_width, empty_latent_height, batch_size, lora_stack=lora_stack, - cnet_stack=cnet_stack, refiner_name=refiner_ckpt_name, ascore=(positive_ascore, negative_ascore), - prompt=prompt, my_unique_id=my_unique_id, loader_type="sdxl") + positive, negative, token_normalization, weight_interpretation, empty_latent_width, empty_latent_height, + batch_size, lora_stack=lora_stack, cnet_stack=cnet_stack, refiner_name=refiner_ckpt_name, + ascore=(positive_ascore, negative_ascore), prompt=prompt, my_unique_id=my_unique_id, loader_type="sdxl") #======================================================================================================================= # TSC Unpack SDXL Tuple @@ -377,21 +387,17 @@ class TSC_Apply_ControlNet_Stack: return (positive, negative, ) + ######################################################################################################################## # TSC KSampler (Efficient) class TSC_KSampler: empty_image = pil2tensor(Image.new('RGBA', (1, 1), (0, 0, 0, 0))) - def __init__(self): - self.output_dir = os.path.join(comfy_dir, 'temp') - self.type = "temp" - @classmethod def INPUT_TYPES(cls): return {"required": - {"sampler_state": (["Sample", "Script", "Hold"], ), - "model": ("MODEL",), + {"model": ("MODEL",), "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0}), @@ -401,8 +407,8 @@ class TSC_KSampler: "negative": ("CONDITIONING",), "latent_image": ("LATENT",), "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "preview_method": (["auto", "latent2rgb", "taesd", "none"],), - "vae_decode": (["true", "true (tiled)", "false", "output only", "output only (tiled)"],), + "preview_method": (["auto", "latent2rgb", "taesd", "vae_decoded_only", "none"],), + "vae_decode": (["true", "true (tiled)", "false"],), }, "optional": { "optional_vae": ("VAE",), "script": ("SCRIPT",),}, @@ -415,8 +421,8 @@ class TSC_KSampler: FUNCTION = "sample" CATEGORY = "Efficiency Nodes/Sampling" - def sample(self, sampler_state, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, - latent_image, preview_method, vae_decode, denoise=1.0, prompt=None, extra_pnginfo=None, my_unique_id=None, + def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + preview_method, vae_decode, denoise=1.0, prompt=None, extra_pnginfo=None, my_unique_id=None, optional_vae=(None,), script=None, add_noise=None, start_at_step=None, end_at_step=None, return_with_leftover_noise=None, sampler_type="regular"): @@ -440,93 +446,6 @@ class TSC_KSampler: def keys_exist_in_script(*keys): return any(key in script for key in keys) if script else False - # If no valid script input connected, error out - if not keys_exist_in_script("xyplot", "hiresfix", "tile") and sampler_state == "Script": - print(f"{error('KSampler(Efficient) Error:')} No valid script input detected") - if sampler_type == "sdxl": - result = (sdxl_tuple, latent_image, vae, TSC_KSampler.empty_image,) - else: - result = (model, positive, negative, latent_image, vae, TSC_KSampler.empty_image,) - return {"ui": {"images": list()}, "result": result} - - #--------------------------------------------------------------------------------------------------------------- - def map_filename(filename): - prefix_len = len(os.path.basename(filename_prefix)) - prefix = filename[:prefix_len + 1] - try: - digits = int(filename[prefix_len + 1:].split('_')[0]) - except: - digits = 0 - return (digits, prefix) - - def compute_vars(images,input): - input = input.replace("%width%", str(images[0].shape[1])) - input = input.replace("%height%", str(images[0].shape[0])) - return input - - def preview_image(images, filename_prefix): - - if images == list(): - return list() - - filename_prefix = compute_vars(images,filename_prefix) - - subfolder = os.path.dirname(os.path.normpath(filename_prefix)) - filename = os.path.basename(os.path.normpath(filename_prefix)) - - full_output_folder = os.path.join(self.output_dir, subfolder) - - try: - counter = max(filter(lambda a: a[1][:-1] == filename and a[1][-1] == "_", - map(map_filename, os.listdir(full_output_folder))))[0] + 1 - except ValueError: - counter = 1 - except FileNotFoundError: - os.makedirs(full_output_folder, exist_ok=True) - counter = 1 - - if not os.path.exists(self.output_dir): - os.makedirs(self.output_dir) - - results = list() - for image in images: - i = 255. * image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - metadata = PngInfo() - if prompt is not None: - metadata.add_text("prompt", json.dumps(prompt)) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata.add_text(x, json.dumps(extra_pnginfo[x])) - file = f"{filename}_{counter:05}_.png" - img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=4) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }); - counter += 1 - return results - - #--------------------------------------------------------------------------------------------------------------- - def get_value_by_id(key: str, my_unique_id): - global last_helds - for value, id_ in last_helds[key]: - if id_ == my_unique_id: - return value - return None - - def update_value_by_id(key: str, my_unique_id, new_value): - global last_helds - - for i, (value, id_) in enumerate(last_helds[key]): - if id_ == my_unique_id: - last_helds[key][i] = (new_value, id_) - return True - - last_helds[key].append((new_value, my_unique_id)) - return True - #--------------------------------------------------------------------------------------------------------------- def vae_decode_latent(vae, samples, vae_decode): return VAEDecodeTiled().decode(vae,samples,512)[0] if "tiled" in vae_decode else VAEDecode().decode(vae,samples)[0] @@ -535,201 +454,271 @@ class TSC_KSampler: return VAEEncodeTiled().encode(vae,pixels,512)[0] if "tiled" in vae_decode else VAEEncode().encode(vae,pixels)[0] # --------------------------------------------------------------------------------------------------------------- - def sample_latent_image(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, + def process_latent_image(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise, sampler_type, add_noise, start_at_step, end_at_step, return_with_leftover_noise, - refiner_model, refiner_positive, refiner_negative, vae, vae_decode, sampler_state): + refiner_model, refiner_positive, refiner_negative, vae, vae_decode, preview_method): - # Sample the latent_image(s) using the Comfy KSampler nodes - if sampler_type == "regular": - samples = KSampler().sample(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, - latent_image, denoise=denoise)[0] + # Store originals + previous_preview_method = global_preview_method() + original_prepare_noise = comfy.sample.prepare_noise + original_KSampler = comfy.samplers.KSampler + original_model_str = str(model) - elif sampler_type == "advanced": - samples = KSamplerAdvanced().sample(model, add_noise, seed, steps, cfg, sampler_name, scheduler, - positive, negative, latent_image, start_at_step, end_at_step, - return_with_leftover_noise, denoise=1.0)[0] + # Initialize output variables + samples = images = gifs = preview = cnet_imgs = None - elif sampler_type == "sdxl": - # Disable refiner if refine_at_step is -1 - if end_at_step == -1: - end_at_step = steps + try: + # Change the global preview method (temporarily) + set_preview_method(preview_method) - # Perform base model sampling - add_noise = return_with_leftover_noise = True - samples = KSamplerAdvanced().sample(model, add_noise, seed, steps, cfg, sampler_name, scheduler, - positive, negative, latent_image, start_at_step, end_at_step, - return_with_leftover_noise, denoise=1.0)[0] + # ------------------------------------------------------------------------------------------------------ + # Check if "noise" exists in the script before main sampling has taken place + if keys_exist_in_script("noise"): + rng_source, cfg_denoiser, add_seed_noise, m_seed, m_weight = script["noise"] + smZ_rng_source.rng_rand_source(rng_source) # this function monkey patches comfy.sample.prepare_noise + if cfg_denoiser: + comfy.samplers.KSampler = smZ_cfg_denoiser.SDKSampler + if add_seed_noise: + comfy.sample.prepare_noise = cg_mixed_seed_noise.get_mixed_noise_function(comfy.sample.prepare_noise, m_seed, m_weight) + else: + m_seed = m_weight = None + else: + rng_source = cfg_denoiser = add_seed_noise = m_seed = m_weight = None - # Perform refiner model sampling - if refiner_model and end_at_step < steps: - add_noise = return_with_leftover_noise = False - samples = KSamplerAdvanced().sample(refiner_model, add_noise, seed, steps, cfg + REFINER_CFG_OFFSET, - sampler_name, scheduler, refiner_positive, refiner_negative, - samples, end_at_step, steps, + # ------------------------------------------------------------------------------------------------------ + # Check if "anim" exists in the script before main sampling has taken place + if keys_exist_in_script("anim"): + if preview_method != "none": + set_preview_method("none") # disable preview method + print(f"{warning('KSampler(Efficient) Warning:')} Live preview disabled for animatediff generations.") + motion_model, beta_schedule, context_options, frame_rate, loop_count, format, pingpong, save_image = script["anim"] + model = AnimateDiffLoaderWithContext().load_mm_and_inject_params(model, motion_model, beta_schedule, context_options)[0] + + # ------------------------------------------------------------------------------------------------------ + # Store run parameters as strings. Load previous stored samples if all parameters match. + latent_image_hash = tensor_to_hash(latent_image["samples"]) + positive_hash = tensor_to_hash(positive[0][0]) + negative_hash = tensor_to_hash(negative[0][0]) + refiner_positive_hash = tensor_to_hash(refiner_positive[0][0]) if refiner_positive is not None else None + refiner_negative_hash = tensor_to_hash(refiner_negative[0][0]) if refiner_negative is not None else None + + # Include motion_model, beta_schedule, and context_options as unique identifiers if they exist. + model_identifier = [original_model_str, motion_model, beta_schedule, context_options] if keys_exist_in_script("anim")\ + else [original_model_str] + + parameters = [model_identifier] + [seed, steps, cfg, sampler_name, scheduler, positive_hash, negative_hash, + latent_image_hash, denoise, sampler_type, add_noise, start_at_step, + end_at_step, return_with_leftover_noise, refiner_model, refiner_positive_hash, + refiner_negative_hash, rng_source, cfg_denoiser, add_seed_noise, m_seed, m_weight] + + # Convert all elements in parameters to strings, except for the hash variable checks + parameters = [str(item) if not isinstance(item, type(latent_image_hash)) else item for item in parameters] + + # Load previous latent if all parameters match, else returns 'None' + samples = load_ksampler_results("latent", my_unique_id, parameters) + + if samples is None: # clear stored images + store_ksampler_results("image", my_unique_id, None) + store_ksampler_results("cnet_img", my_unique_id, None) + + if samples is not None: # do not re-sample + images = load_ksampler_results("image", my_unique_id) + cnet_imgs = True # "True" will denote that it can be loaded provided the preprocessor matches + + # Sample the latent_image(s) using the Comfy KSampler nodes + elif sampler_type == "regular": + samples = KSampler().sample(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, + latent_image, denoise=denoise)[0] if denoise>0 else latent_image + + elif sampler_type == "advanced": + samples = KSamplerAdvanced().sample(model, add_noise, seed, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0)[0] - if sampler_state == "Script": + elif sampler_type == "sdxl": + # Disable refiner if refine_at_step is -1 + if end_at_step == -1: + end_at_step = steps + # Perform base model sampling + add_noise = return_with_leftover_noise = True + samples = KSamplerAdvanced().sample(model, add_noise, seed, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, start_at_step, end_at_step, + return_with_leftover_noise, denoise=1.0)[0] + + # Perform refiner model sampling + if refiner_model and end_at_step < steps: + add_noise = return_with_leftover_noise = False + samples = KSamplerAdvanced().sample(refiner_model, add_noise, seed, steps, cfg + REFINER_CFG_OFFSET, + sampler_name, scheduler, refiner_positive, refiner_negative, + samples, end_at_step, steps, + return_with_leftover_noise, denoise=1.0)[0] + + # Cache the first pass samples in the 'last_helds' dictionary "latent" if not xyplot + if not any(keys_exist_in_script(key) for key in ["xyplot"]): + store_ksampler_results("latent", my_unique_id, samples, parameters) + + # ------------------------------------------------------------------------------------------------------ # Check if "hiresfix" exists in the script after main sampling has taken place if keys_exist_in_script("hiresfix"): # Unpack the tuple from the script's "hiresfix" key - latent_upscale_method, upscale_by, hires_steps, hires_denoise, iterations, upscale_function = script["hiresfix"] - # Iterate for the given number of iterations - for _ in range(iterations): - upscaled_latent_image = upscale_function().upscale(samples, latent_upscale_method, upscale_by)[0] - samples = KSampler().sample(model, seed, hires_steps, cfg, sampler_name, scheduler, - positive, negative, upscaled_latent_image, denoise=hires_denoise)[0] + upscale_type, latent_upscaler, upscale_by, use_same_seed, hires_seed, hires_steps, hires_denoise,\ + iterations, hires_control_net, hires_cnet_strength, preprocessor, preprocessor_imgs, \ + latent_upscale_function, latent_upscale_model, pixel_upscale_model = script["hiresfix"] + # Define hires_seed + hires_seed = seed if use_same_seed else hires_seed + + # Define latent_upscale_model + if latent_upscale_model is None: + latent_upscale_model = model + elif keys_exist_in_script("anim"): + latent_upscale_model = \ + AnimateDiffLoaderWithContext().load_mm_and_inject_params(latent_upscale_model, motion_model, + beta_schedule, context_options)[0] + + # Generate Preprocessor images and Apply Control Net + if hires_control_net is not None: + # Attempt to load previous "cnet_imgs" if previous images were loaded and preprocessor is same + if cnet_imgs is True: + cnet_imgs = load_ksampler_results("cnet_img", my_unique_id, [preprocessor]) + # If cnet_imgs is None, generate new ones + if cnet_imgs is None: + if images is None: + images = vae_decode_latent(vae, samples, vae_decode) + store_ksampler_results("image", my_unique_id, images) + cnet_imgs = AIO_Preprocessor().execute(preprocessor, images)[0] + store_ksampler_results("cnet_img", my_unique_id, cnet_imgs, [preprocessor]) + positive = ControlNetApply().apply_controlnet(positive, hires_control_net, cnet_imgs, hires_cnet_strength)[0] + + # Iterate for the given number of iterations + if upscale_type == "latent": + for _ in range(iterations): + upscaled_latent_image = latent_upscale_function().upscale(samples, latent_upscaler, upscale_by)[0] + samples = KSampler().sample(latent_upscale_model, hires_seed, hires_steps, cfg, sampler_name, scheduler, + positive, negative, upscaled_latent_image, denoise=hires_denoise)[0] + images = None # set to None when samples is updated + elif upscale_type == "pixel": + if images is None: + images = vae_decode_latent(vae, samples, vae_decode) + store_ksampler_results("image", my_unique_id, images) + images = ImageUpscaleWithModel().upscale(pixel_upscale_model, images)[0] + images = ImageScaleBy().upscale(images, "nearest-exact", upscale_by/4)[0] + + # ------------------------------------------------------------------------------------------------------ # Check if "tile" exists in the script after main sampling has taken place if keys_exist_in_script("tile"): # Unpack the tuple from the script's "tile" key - upscale_by, tile_controlnet, tile_size, tiling_strategy, tiling_steps, tiled_denoise,\ - blenderneko_tiled_nodes = script["tile"] - # VAE Decode samples - image = vae_decode_latent(vae, samples, vae_decode) + upscale_by, tile_size, tiling_strategy, tiling_steps, tile_seed, tiled_denoise,\ + tile_controlnet, strength = script["tile"] + + # Decode image, store if first decode + if images is None: + images = vae_decode_latent(vae, samples, vae_decode) + if not any(keys_exist_in_script(key) for key in ["xyplot", "hiresfix"]): + store_ksampler_results("image", my_unique_id, images) + # Upscale image - upscaled_image = ImageScaleBy().upscale(image, "nearest-exact", upscale_by)[0] + upscaled_image = ImageScaleBy().upscale(images, "nearest-exact", upscale_by)[0] upscaled_latent = vae_encode_image(vae, upscaled_image, vae_decode) - # Apply Control Net using upscaled_image and loaded control_net - positive = ControlNetApply().apply_controlnet(positive, tile_controlnet, upscaled_image, 1)[0] + + # If using Control Net, Apply Control Net using upscaled_image and loaded control_net + if tile_controlnet is not None: + positive = ControlNetApply().apply_controlnet(positive, tile_controlnet, upscaled_image, 1)[0] + # Sample latent - TSampler = blenderneko_tiled_nodes.TiledKSampler - samples = TSampler().sample(model, seed-1, tile_size, tile_size, tiling_strategy, tiling_steps, cfg, + TSampler = bnk_tiled_samplers.TiledKSampler + samples = TSampler().sample(model, tile_seed, tile_size, tile_size, tiling_strategy, tiling_steps, cfg, sampler_name, scheduler, positive, negative, upscaled_latent, denoise=tiled_denoise)[0] - return samples + images = None # set to None when samples is updated + + # ------------------------------------------------------------------------------------------------------ + # Check if "anim" exists in the script after the main sampling has taken place + if keys_exist_in_script("anim"): + if images is None: + images = vae_decode_latent(vae, samples, vae_decode) + if not any(keys_exist_in_script(key) for key in ["xyplot", "hiresfix", "tile"]): + store_ksampler_results("image", my_unique_id, images) + gifs = AnimateDiffCombine().generate_gif(images, frame_rate, loop_count, format=format, + pingpong=pingpong, save_image=save_image, prompt=prompt, extra_pnginfo=extra_pnginfo)["ui"]["gifs"] + + # ------------------------------------------------------------------------------------------------------ + + # Decode image if not yet decoded + if "true" in vae_decode: + if images is None: + images = vae_decode_latent(vae, samples, vae_decode) + # Store decoded image as base image of no script is detected + if all(not keys_exist_in_script(key) for key in ["xyplot", "hiresfix", "tile", "anim"]): + store_ksampler_results("image", my_unique_id, images) + + # Append Control Net Images (if exist) + if cnet_imgs is not None and not True: + if preprocessor_imgs and upscale_type == "latent": + if keys_exist_in_script("xyplot"): + print( + f"{warning('HighRes-Fix Warning:')} Preprocessor images auto-disabled when XY Plotting.") + else: + # Resize cnet_imgs if necessary and stack + if images.shape[1:3] != cnet_imgs.shape[1:3]: # comparing height and width + cnet_imgs = quick_resize(cnet_imgs, images.shape) + images = torch.cat([images, cnet_imgs], dim=0) + + # Define preview images + if keys_exist_in_script("anim"): + preview = {"gifs": gifs, "images": list()} + elif preview_method == "none" or (preview_method == "vae_decoded_only" and vae_decode == "false"): + preview = {"images": list()} + elif images is not None: + preview = PreviewImage().save_images(images, prompt=prompt, extra_pnginfo=extra_pnginfo)["ui"] + + # Define a dummy output image + if images is None and vae_decode == "false": + images = TSC_KSampler.empty_image + + finally: + # Restore global changes + set_preview_method(previous_preview_method) + comfy.samplers.KSampler = original_KSampler + comfy.sample.prepare_noise = original_prepare_noise + + return samples, images, gifs, preview # --------------------------------------------------------------------------------------------------------------- # Clean globally stored objects of non-existant nodes globals_cleanup(prompt) - # Init last_preview_images - if get_value_by_id("preview_images", my_unique_id) is None: - last_preview_images = list() - else: - last_preview_images = get_value_by_id("preview_images", my_unique_id) - - # Init last_latent - if get_value_by_id("latent", my_unique_id) is None: - last_latent = latent_image - else: - last_latent = {"samples": None} - last_latent = get_value_by_id("latent", my_unique_id) - - # Init last_output_images - if get_value_by_id("output_images", my_unique_id) == None: - last_output_images = TSC_KSampler.empty_image - else: - last_output_images = get_value_by_id("output_images", my_unique_id) - - # Define filename_prefix - filename_prefix = "KSeff_{}".format(my_unique_id) - # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - # If the sampler state is "Sample" or "Script" without XY Plot - if sampler_state == "Sample" or (sampler_state == "Script" and not keys_exist_in_script("xyplot")): + # If not XY Plotting + if not keys_exist_in_script("xyplot"): - # Store the global preview method - previous_preview_method = global_preview_method() - - # Change the global preview method temporarily during sampling - set_preview_method(preview_method) - - # Define commands arguments to send to front-end via websocket - if preview_method != "none" and "true" in vae_decode: - send_command_to_frontend(startListening=True, maxCount=steps-1, sendBlob=False) - - samples = sample_latent_image(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, - latent_image, denoise, sampler_type, add_noise, start_at_step, end_at_step, - return_with_leftover_noise, refiner_model, refiner_positive, refiner_negative, - vae, vae_decode, sampler_state) - - # Cache samples in the 'last_helds' dictionary "latent" - update_value_by_id("latent", my_unique_id, samples) - - # Define node output images & next Hold's vae_decode behavior - output_images = node_images = get_latest_image() ### - if vae_decode == "false": - update_value_by_id("vae_decode_flag", my_unique_id, True) - if preview_method == "none" or output_images == list(): - output_images = TSC_KSampler.empty_image - else: - update_value_by_id("vae_decode_flag", my_unique_id, False) - decoded_image = vae_decode_latent(vae, samples, vae_decode) - output_images = node_images = decoded_image - - # Cache output images to global 'last_helds' dictionary "output_images" - update_value_by_id("output_images", my_unique_id, output_images) - - # Generate preview_images (PIL) - preview_images = preview_image(node_images, filename_prefix) - - # Cache node preview images to global 'last_helds' dictionary "preview_images" - update_value_by_id("preview_images", my_unique_id, preview_images) - - # Set xy_plot_flag to 'False' and set the stored (if any) XY Plot image tensor to 'None' - update_value_by_id("xy_plot_flag", my_unique_id, False) - update_value_by_id("xy_plot_image", my_unique_id, None) - - if "output only" in vae_decode: - preview_images = list() - - if preview_method != "none": - # Send message to front-end to revoke the last blob image from browser's memory (fixes preview duplication bug) - send_command_to_frontend(startListening=False) + # Process latent image + samples, images, gifs, preview = process_latent_image(model, seed, steps, cfg, sampler_name, scheduler, + positive, negative, latent_image, denoise, sampler_type, add_noise, + start_at_step, end_at_step, return_with_leftover_noise, refiner_model, + refiner_positive, refiner_negative, vae, vae_decode, preview_method) if sampler_type == "sdxl": - result = (sdxl_tuple, samples, vae, output_images,) + result = (sdxl_tuple, samples, vae, images,) else: - result = (model, positive, negative, samples, vae, output_images,) - return result if not preview_images and preview_method != "none" else {"ui": {"images": preview_images}, "result": result} + result = (model, positive, negative, samples, vae, images,) + + if preview is None: + return {"result": result} + else: + return {"ui": preview, "result": result} # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - # If the sampler state is "Hold" - elif sampler_state == "Hold": - output_images = last_output_images - preview_images = last_preview_images if "true" in vae_decode else list() - - if get_value_by_id("vae_decode_flag", my_unique_id): - if "true" in vae_decode or "output only" in vae_decode: - output_images = node_images = vae_decode_latent(vae, last_latent, vae_decode) - update_value_by_id("vae_decode_flag", my_unique_id, False) - update_value_by_id("output_images", my_unique_id, output_images) - preview_images = preview_image(node_images, filename_prefix) - update_value_by_id("preview_images", my_unique_id, preview_images) - if "output only" in vae_decode: - preview_images = list() - - # Check if holding an XY Plot image - elif get_value_by_id("xy_plot_flag", my_unique_id): - # Check if XY Plot node is connected - if keys_exist_in_script("xyplot"): - # Extract the 'xyplot_as_output_image' input parameter from the connected xy_plot - _, _, _, _, _, _, _, xyplot_as_output_image, _, _ = script["xyplot"] - if xyplot_as_output_image == True: - output_images = get_value_by_id("xy_plot_image", my_unique_id) - else: - output_images = get_value_by_id("output_images", my_unique_id) - preview_images = last_preview_images - else: - output_images = last_output_images - preview_images = last_preview_images if "true" in vae_decode else list() - - if sampler_type == "sdxl": - result = (sdxl_tuple, last_latent, vae, output_images,) - else: - result = (model, positive, negative, last_latent, vae, output_images,) - return {"ui": {"images": preview_images}, "result": result} - - # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - # If the sampler state is "Script" with XY Plot - elif sampler_state == "Script" and keys_exist_in_script("xyplot"): + # If XY Plot + elif keys_exist_in_script("xyplot"): # If no vae connected, throw errors if vae == (None,): print(f"{error('KSampler(Efficient) Error:')} VAE input must be connected in order to use the XY Plot script.") + return {"ui": {"images": list()}, - "result": (model, positive, negative, last_latent, vae, TSC_KSampler.empty_image,)} + "result": (model, positive, negative, latent_image, vae, TSC_KSampler.empty_image,)} # If vae_decode is not set to true, print message that changing it to true if "true" not in vae_decode: @@ -823,14 +812,14 @@ class TSC_KSampler: f"\nDisallowed XY_types for this KSampler are: {', '.join(disallowed_XY_types)}.") return {"ui": {"images": list()}, - "result": (model, positive, negative, last_latent, vae, TSC_KSampler.empty_image,)} + "result": (model, positive, negative, latent_image, vae, TSC_KSampler.empty_image,)} #_______________________________________________________________________________________________________ # Unpack Effficient Loader dependencies if dependencies is not None: vae_name, ckpt_name, clip, clip_skip, refiner_name, refiner_clip, refiner_clip_skip,\ - positive_prompt, negative_prompt, ascore, empty_latent_width, empty_latent_height,\ - lora_stack, cnet_stack = dependencies + positive_prompt, negative_prompt, token_normalization, weight_interpretation, ascore,\ + empty_latent_width, empty_latent_height, lora_stack, cnet_stack = dependencies #_______________________________________________________________________________________________________ # Printout XY Plot values to be processed @@ -852,7 +841,7 @@ class TSC_KSampler: else (os.path.basename(v[0]), v[1]) if v[2] is None else (os.path.basename(v[0]),) + v[1:] for v in value] - elif (type_ == "LoRA" or type_ == "LoRA Stacks") and isinstance(value, list): + elif type_ == "LoRA" and isinstance(value, list): # Return only the first Tuple of each inner array return [[(os.path.basename(v[0][0]),) + v[0][1:], "..."] if len(v) > 1 else [(os.path.basename(v[0][0]),) + v[0][1:]] for v in value] @@ -953,7 +942,6 @@ class TSC_KSampler: "Checkpoint", "Refiner", "LoRA", - "LoRA Stacks", "VAE", ] conditioners = { @@ -998,9 +986,6 @@ class TSC_KSampler: # Create a list of tuples with types and values type_value_pairs = [(X_type, X_value.copy()), (Y_type, Y_value.copy())] - # Replace "LoRA Stacks" with "LoRA" - type_value_pairs = [('LoRA' if t == 'LoRA Stacks' else t, v) for t, v in type_value_pairs] - # Iterate over type-value pairs for t, v in type_value_pairs: if t in dict_map: @@ -1043,7 +1028,7 @@ class TSC_KSampler: elif X_type == "Refiner": ckpt_dict = [] lora_dict = [] - elif X_type in ("LoRA", "LoRA Stacks"): + elif X_type == "LoRA": ckpt_dict = [] refn_dict = [] @@ -1064,7 +1049,7 @@ class TSC_KSampler: lora_stack, cnet_stack, var_label, num_label): # Define default max label size limit - max_label_len = 36 + max_label_len = 42 # If var_type is "AddNoise", update 'add_noise' with 'var', and generate text label if var_type == "AddNoise": @@ -1206,7 +1191,7 @@ class TSC_KSampler: text = f"RefClipSkip ({refiner_clip_skip[0]})" elif "LoRA" in var_type: - if not lora_stack or var_type == "LoRA Stacks": + if not lora_stack: lora_stack = var.copy() else: # Updating the first tuple of lora_stack @@ -1216,7 +1201,7 @@ class TSC_KSampler: lora_name, lora_model_wt, lora_clip_wt = lora_stack[0] lora_filename = os.path.splitext(os.path.basename(lora_name))[0] - if var_type == "LoRA" or var_type == "LoRA Stacks": + if var_type == "LoRA": if len(lora_stack) == 1: lora_model_wt = format(float(lora_model_wt), ".2f").rstrip('0').rstrip('.') lora_clip_wt = format(float(lora_clip_wt), ".2f").rstrip('0').rstrip('.') @@ -1339,7 +1324,7 @@ class TSC_KSampler: # Note: Index is held at 0 when Y_type == "Nothing" # Load Checkpoint if required. If Y_type is LoRA, required models will be loaded by load_lora func. - if (X_type == "Checkpoint" and index == 0 and Y_type not in ("LoRA", "LoRA Stacks")): + if (X_type == "Checkpoint" and index == 0 and Y_type != "LoRA"): if lora_stack is None: model, clip, _ = load_checkpoint(ckpt_name, xyplot_id, cache=cache[1]) else: # Load Efficient Loader LoRA @@ -1348,11 +1333,11 @@ class TSC_KSampler: encode = True # Load LoRA if required - elif (X_type in ("LoRA", "LoRA Stacks") and index == 0): + elif (X_type == "LoRA" and index == 0): # Don't cache Checkpoints model, clip = load_lora(lora_stack, ckpt_name, xyplot_id, cache=cache[2]) encode = True - elif Y_type in ("LoRA", "LoRA Stacks"): # X_type must be Checkpoint, so cache those as defined + elif Y_type == "LoRA": # X_type must be Checkpoint, so cache those as defined model, clip = load_lora(lora_stack, ckpt_name, xyplot_id, cache=None, ckpt_cache=cache[1]) encode = True @@ -1381,9 +1366,9 @@ class TSC_KSampler: # Encode base prompt if encode == True: positive, negative, clip = \ - encode_prompts(positive_prompt, negative_prompt, clip, clip_skip, refiner_clip, - refiner_clip_skip, ascore, sampler_type == "sdxl", empty_latent_width, - empty_latent_height, return_type="base") + encode_prompts(positive_prompt, negative_prompt, token_normalization, weight_interpretation, + clip, clip_skip, refiner_clip, refiner_clip_skip, ascore, sampler_type == "sdxl", + empty_latent_width, empty_latent_height, return_type="base") # Apply ControlNet Stack if given if cnet_stack: controlnet_conditioning = TSC_Apply_ControlNet_Stack().apply_cnet_stack(positive, negative, cnet_stack) @@ -1391,9 +1376,9 @@ class TSC_KSampler: if encode_refiner == True: refiner_positive, refiner_negative, refiner_clip = \ - encode_prompts(positive_prompt, negative_prompt, clip, clip_skip, refiner_clip, - refiner_clip_skip, ascore, sampler_type == "sdxl", empty_latent_width, - empty_latent_height, return_type="refiner") + encode_prompts(positive_prompt, negative_prompt, token_normalization, weight_interpretation, + clip, clip_skip, refiner_clip, refiner_clip_skip, ascore, sampler_type == "sdxl", + empty_latent_width, empty_latent_height, return_type="refiner") # Load VAE if required if (X_type == "VAE" and index == 0) or Y_type == "VAE": @@ -1421,19 +1406,17 @@ class TSC_KSampler: latent_list.append(latent) if capsule_result is None: - if preview_method != "none": - send_command_to_frontend(startListening=True, maxCount=steps - 1, sendBlob=False) - samples = sample_latent_image(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, - latent_image, denoise, sampler_type, add_noise, start_at_step, end_at_step, - return_with_leftover_noise, refiner_model, refiner_positive, refiner_negative, - vae, vae_decode, sampler_state) + samples, images, _, _ = process_latent_image(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, + latent_image, denoise, sampler_type, add_noise, start_at_step, + end_at_step, return_with_leftover_noise, refiner_model, + refiner_positive, refiner_negative, vae, vae_decode, preview_method) # Add the latent tensor to the tensors list latent_list.append(samples) - # Decode the latent tensor - image = vae_decode_latent(vae, samples, vae_decode) + # Decode the latent tensor if required + image = images if images is not None else vae_decode_latent(vae, samples, vae_decode) if xy_capsule is not None: xy_capsule.set_result(image, samples) @@ -1464,12 +1447,6 @@ class TSC_KSampler: # Store types in a Tuple for easy function passing types = (X_type, Y_type) - # Store the global preview method - previous_preview_method = global_preview_method() - - # Change the global preview method temporarily during this node's execution - set_preview_method(preview_method) - # Clone original model parameters def clone_or_none(*originals): cloned_items = [] @@ -1522,11 +1499,10 @@ class TSC_KSampler: elif X_type != "Nothing" and Y_type != "Nothing": for Y_index, Y in enumerate(Y_value): - if Y_type == "XY_Capsule" or X_type == "XY_Capsule": - model, clip, refiner_model, refiner_clip = \ - clone_or_none(original_model, original_clip, original_refiner_model, original_refiner_clip) if Y_type == "XY_Capsule" and X_type == "XY_Capsule": + model, clip, refiner_model, refiner_clip = \ + clone_or_none(original_model, original_clip, original_refiner_model, original_refiner_clip) Y.set_x_capsule(X) # Define Y parameters and generate labels @@ -1572,7 +1548,7 @@ class TSC_KSampler: clear_cache_by_exception(xyplot_id, lora_dict=[], refn_dict=[]) elif X_type == "Refiner": clear_cache_by_exception(xyplot_id, ckpt_dict=[], lora_dict=[]) - elif X_type in ("LoRA", "LoRA Stacks"): + elif X_type == "LoRA": clear_cache_by_exception(xyplot_id, ckpt_dict=[], refn_dict=[]) # __________________________________________________________________________________________________________ @@ -1580,7 +1556,7 @@ class TSC_KSampler: def print_plot_variables(X_type, Y_type, X_value, Y_value, add_noise, seed, steps, start_at_step, end_at_step, return_with_leftover_noise, cfg, sampler_name, scheduler, denoise, vae_name, ckpt_name, clip_skip, refiner_name, refiner_clip_skip, ascore, lora_stack, cnet_stack, sampler_type, - num_rows, num_cols, latent_height, latent_width): + num_rows, num_cols, i_height, i_width): print("-" * 40) # Print an empty line followed by a separator line print(f"{xyplot_message('XY Plot Results:')}") @@ -1672,7 +1648,7 @@ class TSC_KSampler: lora_name = lora_wt = lora_model_str = lora_clip_str = None # Check for all possible LoRA types - lora_types = ["LoRA", "LoRA Stacks", "LoRA Batch", "LoRA Wt", "LoRA MStr", "LoRA CStr"] + lora_types = ["LoRA", "LoRA Batch", "LoRA Wt", "LoRA MStr", "LoRA CStr"] if X_type not in lora_types and Y_type not in lora_types: if lora_stack: @@ -1685,7 +1661,7 @@ class TSC_KSampler: else: if X_type in lora_types: value = get_lora_sublist_name(X_type, X_value) - if X_type in ("LoRA", "LoRA Stacks"): + if X_type == "LoRA": lora_name = value lora_model_str = None lora_clip_str = None @@ -1707,7 +1683,7 @@ class TSC_KSampler: if Y_type in lora_types: value = get_lora_sublist_name(Y_type, Y_value) - if Y_type in ("LoRA", "LoRA Stacks"): + if Y_type == "LoRA": lora_name = value lora_model_str = None lora_clip_str = None @@ -1730,13 +1706,13 @@ class TSC_KSampler: return lora_name, lora_wt, lora_model_str, lora_clip_str def get_lora_sublist_name(lora_type, lora_value): - if lora_type in ("LoRA", "LoRA Batch", "LoRA Stacks"): + if lora_type == "LoRA" or lora_type == "LoRA Batch": formatted_sublists = [] for sublist in lora_value: formatted_entries = [] for x in sublist: base_name = os.path.splitext(os.path.basename(str(x[0])))[0] - formatted_str = f"{base_name}({round(x[1], 3)},{round(x[2], 3)})" if lora_type in ("LoRA", "LoRA Stacks") else f"{base_name}" + formatted_str = f"{base_name}({round(x[1], 3)},{round(x[2], 3)})" if lora_type == "LoRA" else f"{base_name}" formatted_entries.append(formatted_str) formatted_sublists.append(f"{', '.join(formatted_entries)}") return "\n ".join(formatted_sublists) @@ -1820,7 +1796,7 @@ class TSC_KSampler: print(f"(X) {X_type}") print(f"(Y) {Y_type}") print(f"img_count: {len(X_value)*len(Y_value)}") - print(f"img_dims: {latent_height} x {latent_width}") + print(f"img_dims: {i_height} x {i_width}") print(f"plot_dim: {num_cols} x {num_rows}") print(f"ckpt: {ckpt_name if ckpt_name is not None else ''}") if clip_skip: @@ -1921,13 +1897,13 @@ class TSC_KSampler: print(f"cnet_end%: {', '.join(cnet_end_pct) if isinstance(cnet_end_pct, list) else cnet_end_pct}") # ______________________________________________________________________________________________________ - def adjusted_font_size(text, initial_font_size, latent_width): + def adjusted_font_size(text, initial_font_size, i_width): font = ImageFont.truetype(str(Path(font_path)), initial_font_size) text_width = font.getlength(text) - if text_width > (latent_width * 0.9): + if text_width > (i_width * 0.9): scaling_factor = 0.9 # A value less than 1 to shrink the font size more aggressively - new_font_size = int(initial_font_size * (latent_width / text_width) * scaling_factor) + new_font_size = int(initial_font_size * (i_width / text_width) * scaling_factor) else: new_font_size = initial_font_size @@ -1935,9 +1911,6 @@ class TSC_KSampler: # ______________________________________________________________________________________________________ - # Disable vae decode on next Hold - update_value_by_id("vae_decode_flag", my_unique_id, False) - def rearrange_list_A(arr, num_cols, num_rows): new_list = [] for i in range(num_rows): @@ -1971,13 +1944,13 @@ class TSC_KSampler: latent_list = rearrange_list_A(latent_list, num_cols, num_rows) # Extract final image dimensions - latent_height, latent_width = latent_list[0]['samples'].shape[2] * 8, latent_list[0]['samples'].shape[3] * 8 + i_height, i_width = image_tensor_list[0].shape[1], image_tensor_list[0].shape[2] # Print XY Plot Results print_plot_variables(X_type, Y_type, X_value, Y_value, add_noise, seed, steps, start_at_step, end_at_step, return_with_leftover_noise, cfg, sampler_name, scheduler, denoise, vae_name, ckpt_name, clip_skip, refiner_name, refiner_clip_skip, ascore, lora_stack, cnet_stack, - sampler_type, num_rows, num_cols, latent_height, latent_width) + sampler_type, num_rows, num_cols, i_height, i_width) # Concatenate the 'samples' and 'noise_mask' tensors along the first dimension (dim=0) keys = latent_list[0].keys() @@ -1988,10 +1961,10 @@ class TSC_KSampler: latent_list = result # Store latent_list as last latent - update_value_by_id("latent", my_unique_id, latent_list) + ###update_value_by_id("latent", my_unique_id, latent_list) # Calculate the dimensions of the white background image - border_size_top = latent_width // 15 + border_size_top = i_width // 15 # Longest Y-label length if len(Y_label) > 0: @@ -2005,28 +1978,28 @@ class TSC_KSampler: if Y_label_orientation == "Vertical": border_size_left = border_size_top else: # Assuming Y_label_orientation is "Horizontal" - # border_size_left is now min(latent_width, latent_height) plus 20% of the difference between the two - border_size_left = min(latent_width, latent_height) + int(0.2 * abs(latent_width - latent_height)) + # border_size_left is now min(i_width, i_height) plus 20% of the difference between the two + border_size_left = min(i_width, i_height) + int(0.2 * abs(i_width - i_height)) border_size_left = int(border_size_left * Y_label_scale) # Modify the border size, background width and x_offset initialization based on Y_type and Y_label_orientation if Y_type == "Nothing": - bg_width = num_cols * latent_width + (num_cols - 1) * grid_spacing + bg_width = num_cols * i_width + (num_cols - 1) * grid_spacing x_offset_initial = 0 else: if Y_label_orientation == "Vertical": - bg_width = num_cols * latent_width + (num_cols - 1) * grid_spacing + 3 * border_size_left + bg_width = num_cols * i_width + (num_cols - 1) * grid_spacing + 3 * border_size_left x_offset_initial = border_size_left * 3 else: # Assuming Y_label_orientation is "Horizontal" - bg_width = num_cols * latent_width + (num_cols - 1) * grid_spacing + border_size_left + bg_width = num_cols * i_width + (num_cols - 1) * grid_spacing + border_size_left x_offset_initial = border_size_left # Modify the background height based on X_type if X_type == "Nothing": - bg_height = num_rows * latent_height + (num_rows - 1) * grid_spacing + bg_height = num_rows * i_height + (num_rows - 1) * grid_spacing y_offset = 0 else: - bg_height = num_rows * latent_height + (num_rows - 1) * grid_spacing + 3 * border_size_top + bg_height = num_rows * i_height + (num_rows - 1) * grid_spacing + 3 * border_size_top y_offset = border_size_top * 3 # Create the white background image @@ -2084,8 +2057,8 @@ class TSC_KSampler: # Add the corresponding Y_value as a label to the left of the image if Y_label_orientation == "Vertical": - initial_font_size = int(48 * latent_width / 512) # Adjusting this to be same as X_label size - font_size = adjusted_font_size(text, initial_font_size, latent_width) + initial_font_size = int(48 * i_width / 512) # Adjusting this to be same as X_label size + font_size = adjusted_font_size(text, initial_font_size, i_width) else: # Assuming Y_label_orientation is "Horizontal" initial_font_size = int(48 * (border_size_left/Y_label_scale) / 512) # Adjusting this to be same as X_label size font_size = adjusted_font_size(text, initial_font_size, int(border_size_left/Y_label_scale)) @@ -2132,17 +2105,11 @@ class TSC_KSampler: xy_plot_image = pil2tensor(background) - # Set xy_plot_flag to 'True' and cache the xy_plot_image tensor - update_value_by_id("xy_plot_image", my_unique_id, xy_plot_image) - update_value_by_id("xy_plot_flag", my_unique_id, True) + # Generate the preview_images + preview_images = PreviewImage().save_images(xy_plot_image)["ui"]["images"] - # Generate the preview_images and cache results - preview_images = preview_image(xy_plot_image, filename_prefix) - update_value_by_id("preview_images", my_unique_id, preview_images) - - # Generate output_images and cache results + # Generate output_images output_images = torch.stack([tensor.squeeze() for tensor in image_tensor_list]) - update_value_by_id("output_images", my_unique_id, output_images) # Set the output_image the same as plot image defined by 'xyplot_as_output_image' if xyplot_as_output_image == True: @@ -2154,13 +2121,6 @@ class TSC_KSampler: print("-" * 40) # Print an empty line followed by a separator line - # Set the preview method back to its original state - set_preview_method(previous_preview_method) - - if preview_method != "none": - # Send message to front-end to revoke the last blob image from browser's memory (fixes preview duplication bug) - send_command_to_frontend(startListening=False) - if sampler_type == "sdxl": sdxl_tuple = original_model, original_clip, original_positive, original_negative,\ original_refiner_model, original_refiner_clip, original_refiner_positive, original_refiner_negative @@ -2176,8 +2136,7 @@ class TSC_KSamplerAdvanced(TSC_KSampler): @classmethod def INPUT_TYPES(cls): return {"required": - {"sampler_state": (["Sample", "Hold", "Script"],), - "model": ("MODEL",), + {"model": ("MODEL",), "add_noise": (["enable", "disable"],), "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), @@ -2204,11 +2163,11 @@ class TSC_KSamplerAdvanced(TSC_KSampler): FUNCTION = "sample_adv" CATEGORY = "Efficiency Nodes/Sampling" - def sample_adv(self, sampler_state, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, + def sample_adv(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, preview_method, vae_decode, prompt=None, extra_pnginfo=None, my_unique_id=None, optional_vae=(None,), script=None): - return super().sample(sampler_state, model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, + return super().sample(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, preview_method, vae_decode, denoise=1.0, prompt=prompt, extra_pnginfo=extra_pnginfo, my_unique_id=my_unique_id, optional_vae=optional_vae, script=script, add_noise=add_noise, start_at_step=start_at_step,end_at_step=end_at_step, return_with_leftover_noise=return_with_leftover_noise,sampler_type="advanced") @@ -2220,8 +2179,7 @@ class TSC_KSamplerSDXL(TSC_KSampler): @classmethod def INPUT_TYPES(cls): return {"required": - {"sampler_state": (["Sample", "Hold", "Script"],), - "sdxl_tuple": ("SDXL_TUPLE",), + {"sdxl_tuple": ("SDXL_TUPLE",), "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0}), @@ -2233,7 +2191,7 @@ class TSC_KSamplerSDXL(TSC_KSampler): "preview_method": (["auto", "latent2rgb", "taesd", "none"],), "vae_decode": (["true", "true (tiled)", "false", "output only", "output only (tiled)"],), }, - "optional": {"optional_vae": ("VAE",),# "refiner_extras": ("REFINER_EXTRAS",), + "optional": {"optional_vae": ("VAE",), "script": ("SCRIPT",),}, "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO", "my_unique_id": "UNIQUE_ID",}, } @@ -2244,38 +2202,17 @@ class TSC_KSamplerSDXL(TSC_KSampler): FUNCTION = "sample_sdxl" CATEGORY = "Efficiency Nodes/Sampling" - def sample_sdxl(self, sampler_state, sdxl_tuple, noise_seed, steps, cfg, sampler_name, scheduler, latent_image, + def sample_sdxl(self, sdxl_tuple, noise_seed, steps, cfg, sampler_name, scheduler, latent_image, start_at_step, refine_at_step, preview_method, vae_decode, prompt=None, extra_pnginfo=None, my_unique_id=None, optional_vae=(None,), refiner_extras=None, script=None): # sdxl_tuple sent through the 'model' channel - # refine_extras sent through the 'positive' channel negative = None - return super().sample(sampler_state, sdxl_tuple, noise_seed, steps, cfg, sampler_name, scheduler, + return super().sample(sdxl_tuple, noise_seed, steps, cfg, sampler_name, scheduler, refiner_extras, negative, latent_image, preview_method, vae_decode, denoise=1.0, prompt=prompt, extra_pnginfo=extra_pnginfo, my_unique_id=my_unique_id, optional_vae=optional_vae, script=script, add_noise=None, start_at_step=start_at_step, end_at_step=refine_at_step, return_with_leftover_noise=None,sampler_type="sdxl") -#======================================================================================================================= -# TSC KSampler SDXL Refiner Extras (DISABLED) -''' -class TSC_SDXL_Refiner_Extras: - - @classmethod - def INPUT_TYPES(cls): - return {"required": {"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), - "cfg": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0}), - "sampler_name": (comfy.samplers.KSampler.SAMPLERS,), - "scheduler": (comfy.samplers.KSampler.SCHEDULERS,)}} - - RETURN_TYPES = ("REFINER_EXTRAS",) - FUNCTION = "pack_refiner_extras" - CATEGORY = "Efficiency Nodes/Misc" - - def pack_refiner_extras(self, seed, cfg, sampler_name, scheduler): - return ((seed, cfg, sampler_name, scheduler),) -''' - ######################################################################################################################## # Common XY Plot Functions/Variables XYPLOT_LIM = 50 #XY Plot default axis size limit @@ -2379,7 +2316,7 @@ class TSC_XYplot: # Check that dependencies are connected for specific plot types encode_types = { "Checkpoint", "Refiner", - "LoRA", "LoRA Stacks", "LoRA Batch", "LoRA Wt", "LoRA MStr", "LoRA CStr", + "LoRA", "LoRA Batch", "LoRA Wt", "LoRA MStr", "LoRA CStr", "Positive Prompt S/R", "Negative Prompt S/R", "AScore+", "AScore-", "Clip Skip", "Clip Skip (Refiner)", @@ -2395,13 +2332,8 @@ class TSC_XYplot: # Check if both X_type and Y_type are special lora_types lora_types = {"LoRA Batch", "LoRA Wt", "LoRA MStr", "LoRA CStr"} if (X_type in lora_types and Y_type not in lora_types) or (Y_type in lora_types and X_type not in lora_types): - print(f"{error('XY Plot Error:')} Both X and Y must be connected to use the 'LoRA Plot' node.") - return (None,) - - # Do not allow LoRA and LoRA Stacks - lora_types = {"LoRA", "LoRA Stacks"} - if (X_type in lora_types and Y_type in lora_types): - print(f"{error('XY Plot Error:')} X and Y input types must be different.") + print( + f"{error('XY Plot Error:')} Both X and Y must be connected to use the 'LoRA Plot' node.") return (None,) # Clean Schedulers from Sampler data (if other type is Scheduler) @@ -3148,7 +3080,7 @@ class TSC_XYplot_LoRA_Stacks: CATEGORY = "Efficiency Nodes/XY Inputs" def xy_value(self, node_state, lora_stack_1=None, lora_stack_2=None, lora_stack_3=None, lora_stack_4=None, lora_stack_5=None): - xy_type = "LoRA Stacks" + xy_type = "LoRA" xy_value = [stack for stack in [lora_stack_1, lora_stack_2, lora_stack_3, lora_stack_4, lora_stack_5] if stack is not None] if not xy_value or not any(xy_value) or node_state == "Disabled": return (None,) @@ -3993,6 +3925,216 @@ class TSC_ImageOverlay: # Return the edited base image return (base_image,) +######################################################################################################################## +# Noise Sources & Seed Variations +# https://github.com/shiimizu/ComfyUI_smZNodes +# https://github.com/chrisgoringe/cg-noise + +# TSC Noise Sources & Variations Script +class TSC_Noise_Control_Script: + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "rng_source": (["cpu", "gpu", "nv"],), + "cfg_denoiser": ("BOOLEAN", {"default": False}), + "add_seed_noise": ("BOOLEAN", {"default": False}), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "weight": ("FLOAT", {"default": 0.015, "min": 0, "max": 1, "step": 0.001})}, + "optional": {"script": ("SCRIPT",)} + } + + RETURN_TYPES = ("SCRIPT",) + RETURN_NAMES = ("SCRIPT",) + FUNCTION = "noise_control" + CATEGORY = "Efficiency Nodes/Scripts" + + def noise_control(self, rng_source, cfg_denoiser, add_seed_noise, seed, weight, script=None): + script = script or {} + script["noise"] = (rng_source, cfg_denoiser, add_seed_noise, seed, weight) + return (script,) + +######################################################################################################################## +# Add controlnet options if have controlnet_aux installed (https://github.com/Fannovel16/comfyui_controlnet_aux) +use_controlnet_widget = preprocessor_widget = (["_"],) +if os.path.exists(os.path.join(custom_nodes_dir, "comfyui_controlnet_aux")): + printout = "Attempting to add Control Net options to the 'HiRes-Fix Script' Node (comfyui_controlnet_aux add-on)..." + #print(f"{message('Efficiency Nodes:')} {printout}", end="", flush=True) + + try: + with suppress_output(): + AIO_Preprocessor = getattr(import_module("comfyui_controlnet_aux.__init__"), 'AIO_Preprocessor') + use_controlnet_widget = ("BOOLEAN", {"default": False}) + preprocessor_widget = AIO_Preprocessor.INPUT_TYPES()["optional"]["preprocessor"] + print(f"\r{message('Efficiency Nodes:')} {printout}{success('Success!')}") + except Exception: + print(f"\r{message('Efficiency Nodes:')} {printout}{error('Failed!')}") + +# TSC HighRes-Fix with model latent upscalers (https://github.com/city96/SD-Latent-Upscaler) +class TSC_HighRes_Fix: + + default_latent_upscalers = LatentUpscaleBy.INPUT_TYPES()["required"]["upscale_method"][0] + + city96_upscale_methods =\ + ["city96." + ver for ver in city96_latent_upscaler.LatentUpscaler.INPUT_TYPES()["required"]["latent_ver"][0]] + city96_scalings_raw = city96_latent_upscaler.LatentUpscaler.INPUT_TYPES()["required"]["scale_factor"][0] + city96_scalings_float = [float(scale) for scale in city96_scalings_raw] + + ttl_nn_upscale_methods = \ + ["ttl_nn." + ver for ver in + ttl_nn_latent_upscaler.NNLatentUpscale.INPUT_TYPES()["required"]["version"][0]] + + latent_upscalers = default_latent_upscalers + city96_upscale_methods + ttl_nn_upscale_methods + pixel_upscalers = folder_paths.get_filename_list("upscale_models") + + @classmethod + def INPUT_TYPES(cls): + + return {"required": {"upscale_type": (["latent","pixel"],), + "hires_ckpt_name": (["(use same)"] + folder_paths.get_filename_list("checkpoints"),), + "latent_upscaler": (cls.latent_upscalers,), + "pixel_upscaler": (cls.pixel_upscalers,), + "upscale_by": ("FLOAT", {"default": 1.25, "min": 0.01, "max": 8.0, "step": 0.05}), + "use_same_seed": ("BOOLEAN", {"default": True}), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "hires_steps": ("INT", {"default": 12, "min": 1, "max": 10000}), + "denoise": ("FLOAT", {"default": .56, "min": 0.00, "max": 1.00, "step": 0.01}), + "iterations": ("INT", {"default": 1, "min": 0, "max": 5, "step": 1}), + "use_controlnet": use_controlnet_widget, + "control_net_name": (folder_paths.get_filename_list("controlnet"),), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "preprocessor": preprocessor_widget, + "preprocessor_imgs": ("BOOLEAN", {"default": False}) + }, + "optional": {"script": ("SCRIPT",)}, + "hidden": {"my_unique_id": "UNIQUE_ID"} + } + + RETURN_TYPES = ("SCRIPT",) + FUNCTION = "hires_fix_script" + CATEGORY = "Efficiency Nodes/Scripts" + + def hires_fix_script(self, upscale_type, hires_ckpt_name, latent_upscaler, pixel_upscaler, upscale_by, + use_same_seed, seed, hires_steps, denoise, iterations, use_controlnet, control_net_name, + strength, preprocessor, preprocessor_imgs, script=None, my_unique_id=None): + latent_upscale_function = None + latent_upscale_model = None + pixel_upscale_model = None + + def float_to_string(num): + if num == int(num): + return "{:.1f}".format(num) + else: + return str(num) + + if iterations > 0 and upscale_by > 0: + if upscale_type == "latent": + # For latent methods from city96 + if latent_upscaler in self.city96_upscale_methods: + # Remove extra characters added + latent_upscaler = latent_upscaler.replace("city96.", "") + + # Set function to city96_latent_upscaler.LatentUpscaler + latent_upscale_function = city96_latent_upscaler.LatentUpscaler + + # Find the nearest valid scaling in city96_scalings_float + nearest_scaling = min(self.city96_scalings_float, key=lambda x: abs(x - upscale_by)) + + # Retrieve the index of the nearest scaling + nearest_scaling_index = self.city96_scalings_float.index(nearest_scaling) + + # Use the index to get the raw string representation + nearest_scaling_raw = self.city96_scalings_raw[nearest_scaling_index] + + upscale_by = float_to_string(upscale_by) + + # Check if the input upscale_by value was different from the nearest valid value + if upscale_by != nearest_scaling_raw: + print(f"{warning('HighRes-Fix Warning:')} " + f"When using 'city96.{latent_upscaler}', 'upscale_by' must be one of {self.city96_scalings_raw}.\n" + f"Rounding to the nearest valid value ({nearest_scaling_raw}).\033[0m") + upscale_by = nearest_scaling_raw + + # For ttl upscale methods + elif latent_upscaler in self.ttl_nn_upscale_methods: + # Remove extra characters added + latent_upscaler = latent_upscaler.replace("ttl_nn.", "") + + # Bound to min/max limits + upscale_by_clamped = min(max(upscale_by, 1), 2) + if upscale_by != upscale_by_clamped: + print(f"{warning('HighRes-Fix Warning:')} " + f"When using 'ttl_nn.{latent_upscaler}', 'upscale_by' must be between 1 and 2.\n" + f"Rounding to the nearest valid value ({upscale_by_clamped}).\033[0m") + upscale_by = upscale_by_clamped + + latent_upscale_function = ttl_nn_latent_upscaler.NNLatentUpscale + + # For default upscale methods + elif latent_upscaler in self.default_latent_upscalers: + latent_upscale_function = LatentUpscaleBy + + else: # Default + latent_upscale_function = LatentUpscaleBy + latent_upscaler = self.default_latent_upscalers[0] + print(f"{warning('HiResFix Script Warning:')} Chosen latent upscale method not found! " + f"defaulting to '{latent_upscaler}'.\n") + + # Load Checkpoint if defined + if hires_ckpt_name == "(use same)": + clear_cache(my_unique_id, 0, "ckpt") + else: + latent_upscale_model, _, _ = \ + load_checkpoint(hires_ckpt_name, my_unique_id, output_vae=False, cache=1, cache_overwrite=True) + + elif upscale_type == "pixel": + pixel_upscale_model = UpscaleModelLoader().load_model(pixel_upscaler)[0] + + control_net = ControlNetLoader().load_controlnet(control_net_name)[0] if use_controlnet is True else None + + # Construct the script output + script = script or {} + script["hiresfix"] = (upscale_type, latent_upscaler, upscale_by, use_same_seed, seed, hires_steps, + denoise, iterations, control_net, strength, preprocessor, preprocessor_imgs, + latent_upscale_function, latent_upscale_model, pixel_upscale_model) + + return (script,) + +######################################################################################################################## +# TSC Tiled Upscaler (https://github.com/BlenderNeko/ComfyUI_TiledKSampler) +class TSC_Tiled_Upscaler: + @classmethod + def INPUT_TYPES(cls): + # Split the list based on the keyword "tile" + cnet_tile_filenames = [name for name in folder_paths.get_filename_list("controlnet") if "tile" in name] + #cnet_other_filenames = [name for name in folder_paths.get_filename_list("controlnet") if "tile" not in name] + + return {"required": {"upscale_by": ("FLOAT", {"default": 1.25, "min": 0.01, "max": 8.0, "step": 0.05}), + "tile_size": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}), + "tiling_strategy": (["random", "random strict", "padded", 'simple', 'none'],), + "tiling_steps": ("INT", {"default": 30, "min": 1, "max": 10000}), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "denoise": ("FLOAT", {"default": .4, "min": 0.0, "max": 1.0, "step": 0.01}), + "use_controlnet": ("BOOLEAN", {"default": False}), + "tile_controlnet": (cnet_tile_filenames,), + "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + }, + "optional": {"script": ("SCRIPT",)}} + + RETURN_TYPES = ("SCRIPT",) + FUNCTION = "tiled_sampling" + CATEGORY = "Efficiency Nodes/Scripts" + + def tiled_sampling(self, upscale_by, tile_size, tiling_strategy, tiling_steps, seed, denoise, + use_controlnet, tile_controlnet, strength, script=None): + if tiling_strategy != 'none': + script = script or {} + tile_controlnet = ControlNetLoader().load_controlnet(tile_controlnet)[0] if use_controlnet else None + + script["tile"] = (upscale_by, tile_size, tiling_strategy, tiling_steps, seed, denoise, tile_controlnet, strength) + return (script,) + ######################################################################################################################## # NODE MAPPING NODE_CLASS_MAPPINGS = { @@ -4006,7 +4148,6 @@ NODE_CLASS_MAPPINGS = { "Apply ControlNet Stack": TSC_Apply_ControlNet_Stack, "Unpack SDXL Tuple": TSC_Unpack_SDXL_Tuple, "Pack SDXL Tuple": TSC_Pack_SDXL_Tuple, - #"Refiner Extras": TSC_SDXL_Refiner_Extras, # MAYBE FUTURE "XY Plot": TSC_XYplot, "XY Input: Seeds++ Batch": TSC_XYplot_SeedsBatch, "XY Input: Add/Return Noise": TSC_XYplot_AddReturnNoise, @@ -4028,146 +4169,54 @@ NODE_CLASS_MAPPINGS = { "XY Input: Manual XY Entry": TSC_XYplot_Manual_XY_Entry, "Manual XY Entry Info": TSC_XYplot_Manual_XY_Entry_Info, "Join XY Inputs of Same Type": TSC_XYplot_JoinInputs, - "Image Overlay": TSC_ImageOverlay + "Image Overlay": TSC_ImageOverlay, + "Noise Control Script": TSC_Noise_Control_Script, + "HighRes-Fix Script": TSC_HighRes_Fix, + "Tiled Upscaler Script": TSC_Tiled_Upscaler } ######################################################################################################################## -# HirRes Fix Script with model latent upscaler (https://github.com/city96/SD-Latent-Upscaler) -city96_latent_upscaler = None -city96_latent_upscaler_path = os.path.join(custom_nodes_dir, "SD-Latent-Upscaler") -if os.path.exists(city96_latent_upscaler_path): - printout = "Adding City96's 'SD-Latent-Upscaler' selections to the 'HighRes-Fix' node..." +# Add AnimateDiff Script based off Kosinkadink's Nodes (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) +if os.path.exists(os.path.join(custom_nodes_dir, "ComfyUI-AnimateDiff-Evolved")): + printout = "Attempting to add 'AnimatedDiff Script' Node (ComfyUI-AnimateDiff-Evolved add-on)..." print(f"{message('Efficiency Nodes:')} {printout}", end="") try: - city96_latent_upscaler = import_module("SD-Latent-Upscaler.comfy_latent_upscaler") - print(f"\r{message('Efficiency Nodes:')} {printout}{success('Success!')}") - except ImportError: - print(f"\r{message('Efficiency Nodes:')} {printout}{error('Failed!')}") - -# TSC HighRes-Fix -class TSC_HighRes_Fix: - - default_upscale_methods = LatentUpscaleBy.INPUT_TYPES()["required"]["upscale_method"][0] - city96_upscale_methods = list() - - if city96_latent_upscaler: - city96_upscale_methods = ["SD-Latent-Upscaler." + ver for ver in city96_latent_upscaler.LatentUpscaler.INPUT_TYPES()["required"]["latent_ver"][0]] - city96_scalings_raw = city96_latent_upscaler.LatentUpscaler.INPUT_TYPES()["required"]["scale_factor"][0] - city96_scalings_float = [float(scale) for scale in city96_scalings_raw] - upscale_methods = default_upscale_methods + city96_upscale_methods - - @classmethod - def INPUT_TYPES(cls): - return {"required": {"latent_upscale_method": (cls.upscale_methods,), - "upscale_by": ("FLOAT", {"default": 1.25, "min": 0.01, "max": 8.0, "step": 0.25}), - "hires_steps": ("INT", {"default": 12, "min": 1, "max": 10000}), - "denoise": ("FLOAT", {"default": .56, "min": 0.0, "max": 1.0, "step": 0.01}), - "iterations": ("INT", {"default": 1, "min": 0, "max": 5, "step": 1}), - }, - "optional": {"script": ("SCRIPT",)}} - - RETURN_TYPES = ("SCRIPT",) - FUNCTION = "hires_fix_script" - CATEGORY = "Efficiency Nodes/Scripts" - - def hires_fix_script(self, latent_upscale_method, upscale_by, hires_steps, denoise, iterations, script=None): - upscale_function = None - - def float_to_string(num): - if num == int(num): - return "{:.1f}".format(num) - else: - return str(num) - - if iterations > 0: - # For latent methods from SD-Latent-Upscaler - if latent_upscale_method in self.city96_upscale_methods: - # Remove extra characters added - latent_upscale_method = latent_upscale_method.replace("SD-Latent-Upscaler.", "") - - # Set function to city96_latent_upscaler.LatentUpscaler - upscale_function = city96_latent_upscaler.LatentUpscaler - - # Find the nearest valid scaling in city96_scalings_float - nearest_scaling = min(self.city96_scalings_float, key=lambda x: abs(x - upscale_by)) - - # Retrieve the index of the nearest scaling - nearest_scaling_index = self.city96_scalings_float.index(nearest_scaling) - - # Use the index to get the raw string representation - nearest_scaling_raw = self.city96_scalings_raw[nearest_scaling_index] - - upscale_by = float_to_string(upscale_by) - - # Check if the input upscale_by value was different from the nearest valid value - if upscale_by != nearest_scaling_raw: - print(f"{warning('HighRes-Fix Warning:')} " - f"When using 'SD-Latent-Upscaler.{latent_upscale_method}', 'upscale_by' must be one of {self.city96_scalings_raw}.\n" - f"Rounding to the nearest valid value ({nearest_scaling_raw}).\033[0m") - upscale_by = nearest_scaling_raw - - # For default upscale methods - elif latent_upscale_method in self.default_upscale_methods: - upscale_function = LatentUpscaleBy - - else: # Default - upscale_function = LatentUpscaleBy - latent_upscale_method = self.default_upscale_methods[0] - print(f"{warning('HiResFix Script Warning:')} Chosen latent upscale method not found! " - f"defaulting to '{latent_upscale_method}'.\n") - - # Construct the script output - script = script or {} - script["hiresfix"] = (latent_upscale_method, upscale_by, hires_steps, denoise, iterations, upscale_function) - - return (script,) - -NODE_CLASS_MAPPINGS.update({"HighRes-Fix Script": TSC_HighRes_Fix}) - -######################################################################################################################## -''' -# Tiled Sampling KSamplers (https://github.com/BlenderNeko/ComfyUI_TiledKSampler) -blenderneko_tiled_ksampler_path = os.path.join(custom_nodes_dir, "ComfyUI_TiledKSampler") -if os.path.exists(blenderneko_tiled_ksampler_path): - printout = "Importing BlenderNeko's 'ComfyUI_TiledKSampler' to enable the 'Tiled Sampling' node..." - print(f"{message('Efficiency Nodes:')} {printout}", end="") - try: - blenderneko_tiled_nodes = import_module("ComfyUI_TiledKSampler.nodes") + module = import_module("ComfyUI-AnimateDiff-Evolved.animatediff.nodes") + AnimateDiffLoaderWithContext = getattr(module, 'AnimateDiffLoaderWithContext') + AnimateDiffCombine = getattr(module, 'AnimateDiffCombine_Deprecated') print(f"\r{message('Efficiency Nodes:')} {printout}{success('Success!')}") - # TSC Tiled Upscaler - class TSC_Tiled_Upscaler: + # TSC AnimatedDiff Script (https://github.com/BlenderNeko/ComfyUI_TiledKSampler) + class TSC_AnimateDiff_Script: @classmethod def INPUT_TYPES(cls): - # Split the list based on the keyword "tile" - cnet_tile_filenames = [name for name in folder_paths.get_filename_list("controlnet") if "tile" in name] - cnet_other_filenames = [name for name in folder_paths.get_filename_list("controlnet") if "tile" not in name] - - return {"required": {"tile_controlnet": (cnet_tile_filenames + cnet_other_filenames,), - "upscale_by": ("FLOAT", {"default": 1.25, "min": 0.01, "max": 8.0, "step": 0.25}), - "tile_size": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}), - "tiling_strategy": (["random", "random strict", "padded", 'simple', 'none'],), - "tiling_steps": ("INT", {"default": 30, "min": 1, "max": 10000}), - "denoise": ("FLOAT", {"default": .56, "min": 0.0, "max": 1.0, "step": 0.01}), - }, - "optional": {"script": ("SCRIPT",)}} + + return {"required": { + "motion_model": AnimateDiffLoaderWithContext.INPUT_TYPES()["required"]["model_name"], + "beta_schedule": AnimateDiffLoaderWithContext.INPUT_TYPES()["required"]["beta_schedule"], + "frame_rate": AnimateDiffCombine.INPUT_TYPES()["required"]["frame_rate"], + "loop_count": AnimateDiffCombine.INPUT_TYPES()["required"]["loop_count"], + "format": AnimateDiffCombine.INPUT_TYPES()["required"]["format"], + "pingpong": AnimateDiffCombine.INPUT_TYPES()["required"]["pingpong"], + "save_image": AnimateDiffCombine.INPUT_TYPES()["required"]["save_image"]}, + "optional": {"context_options": ("CONTEXT_OPTIONS",)} + } RETURN_TYPES = ("SCRIPT",) - FUNCTION = "tiled_sampling" + FUNCTION = "animatediff" CATEGORY = "Efficiency Nodes/Scripts" - def tiled_sampling(self, upscale_by, tile_controlnet, tile_size, tiling_strategy, tiling_steps, denoise, script=None): - if tiling_strategy != 'none': - script = script or {} - script["tile"] = (upscale_by, ControlNetLoader().load_controlnet(tile_controlnet)[0], - tile_size, tiling_strategy, tiling_steps, denoise, blenderneko_tiled_nodes) + def animatediff(self, motion_model, beta_schedule, frame_rate, loop_count, format, pingpong, save_image, + script=None, context_options=None): + script = script or {} + script["anim"] = (motion_model, beta_schedule, context_options, frame_rate, loop_count, format, pingpong, save_image) return (script,) - NODE_CLASS_MAPPINGS.update({"Tiled Upscaler Script": TSC_Tiled_Upscaler}) + NODE_CLASS_MAPPINGS.update({"AnimateDiff Script": TSC_AnimateDiff_Script}) - except ImportError: + except Exception: print(f"\r{message('Efficiency Nodes:')} {printout}{error('Failed!')}") -''' + ######################################################################################################################## # Simpleeval Nodes (https://github.com/danthedeckie/simpleeval) try: diff --git a/js/appearance.js b/js/appearance.js index 5ea8599..1e87498 100644 --- a/js/appearance.js +++ b/js/appearance.js @@ -44,8 +44,10 @@ const NODE_COLORS = { "Manual XY Entry Info": "cyan", "Join XY Inputs of Same Type": "cyan", "Image Overlay": "random", + "Noise Control Script": "none", "HighRes-Fix Script": "yellow", - "Tiled Sampling Script": "none", + "Tiled Upscaler Script": "red", + "AnimateDiff Script": "random", "Evaluate Integers": "pale_blue", "Evaluate Floats": "pale_blue", "Evaluate Strings": "pale_blue", diff --git a/js/gif_preview.js b/js/gif_preview.js new file mode 100644 index 0000000..e0e34e1 --- /dev/null +++ b/js/gif_preview.js @@ -0,0 +1,144 @@ +import { app } from '../../scripts/app.js' +import { api } from '../../scripts/api.js' + +function offsetDOMWidget( + widget, + ctx, + node, + widgetWidth, + widgetY, + height + ) { + const margin = 10 + const elRect = ctx.canvas.getBoundingClientRect() + const transform = new DOMMatrix() + .scaleSelf( + elRect.width / ctx.canvas.width, + elRect.height / ctx.canvas.height + ) + .multiplySelf(ctx.getTransform()) + .translateSelf(0, widgetY + margin) + + const scale = new DOMMatrix().scaleSelf(transform.a, transform.d) + Object.assign(widget.inputEl.style, { + transformOrigin: '0 0', + transform: scale, + left: `${transform.e}px`, + top: `${transform.d + transform.f}px`, + width: `${widgetWidth}px`, + height: `${(height || widget.parent?.inputHeight || 32) - margin}px`, + position: 'absolute', + background: !node.color ? '' : node.color, + color: !node.color ? '' : 'white', + zIndex: 5, //app.graph._nodes.indexOf(node), + }) + } + + export const hasWidgets = (node) => { + if (!node.widgets || !node.widgets?.[Symbol.iterator]) { + return false + } + return true + } + + export const cleanupNode = (node) => { + if (!hasWidgets(node)) { + return + } + + for (const w of node.widgets) { + if (w.canvas) { + w.canvas.remove() + } + if (w.inputEl) { + w.inputEl.remove() + } + // calls the widget remove callback + w.onRemoved?.() + } + } + +const CreatePreviewElement = (name, val, format) => { + const [type] = format.split('/') + const w = { + name, + type, + value: val, + draw: function (ctx, node, widgetWidth, widgetY, height) { + const [cw, ch] = this.computeSize(widgetWidth) + offsetDOMWidget(this, ctx, node, widgetWidth, widgetY, ch) + }, + computeSize: function (_) { + const ratio = this.inputRatio || 1 + const width = Math.max(220, this.parent.size[0]) + return [width, (width / ratio + 10)] + }, + onRemoved: function () { + if (this.inputEl) { + this.inputEl.remove() + } + }, + } + + w.inputEl = document.createElement(type === 'video' ? 'video' : 'img') + w.inputEl.src = w.value + if (type === 'video') { + w.inputEl.setAttribute('type', 'video/webm'); + w.inputEl.autoplay = true + w.inputEl.loop = true + w.inputEl.controls = false; + } + w.inputEl.onload = function () { + w.inputRatio = w.inputEl.naturalWidth / w.inputEl.naturalHeight + } + document.body.appendChild(w.inputEl) + return w + } + +const gif_preview = { + name: 'efficiency.gif_preview', + async beforeRegisterNodeDef(nodeType, nodeData, app) { + switch (nodeData.name) { + case 'KSampler (Efficient)':{ + const onExecuted = nodeType.prototype.onExecuted + nodeType.prototype.onExecuted = function (message) { + const prefix = 'ad_gif_preview_' + const r = onExecuted ? onExecuted.apply(this, message) : undefined + + if (this.widgets) { + const pos = this.widgets.findIndex((w) => w.name === `${prefix}_0`) + if (pos !== -1) { + for (let i = pos; i < this.widgets.length; i++) { + this.widgets[i].onRemoved?.() + } + this.widgets.length = pos + } + if (message?.gifs) { + message.gifs.forEach((params, i) => { + const previewUrl = api.apiURL( + '/view?' + new URLSearchParams(params).toString() + ) + const w = this.addCustomWidget( + CreatePreviewElement(`${prefix}_${i}`, previewUrl, params.format || 'image/gif') + ) + w.parent = this + }) + } + const onRemoved = this.onRemoved + this.onRemoved = () => { + cleanupNode(this) + return onRemoved?.() + } + } + if (message?.gifs && message.gifs.length > 0) { + this.setSize([this.size[0], this.computeSize([this.size[0], this.size[1]])[1]]); + } + return r + } + break + } + } + } +} + +app.registerExtension(gif_preview) diff --git a/js/node_options/addLinks.js b/js/node_options/addLinks.js new file mode 100644 index 0000000..326dc0b --- /dev/null +++ b/js/node_options/addLinks.js @@ -0,0 +1,180 @@ +import { app } from "../../../scripts/app.js"; +import { addMenuHandler } from "./common/utils.js"; +import { addNode } from "./common/utils.js"; + +function createKSamplerEntry(node, samplerType, subNodeType = null, isSDXL = false) { + const samplerLabelMap = { + "Eff": "KSampler (Efficient)", + "Adv": "KSampler Adv. (Efficient)", + "SDXL": "KSampler SDXL (Eff.)" + }; + + const subNodeLabelMap = { + "XYPlot": "XY Plot", + "NoiseControl": "Noise Control Script", + "HiResFix": "HighRes-Fix Script", + "TiledUpscale": "Tiled Upscaler Script", + "AnimateDiff": "AnimateDiff Script" + }; + + const nicknameMap = { + "KSampler (Efficient)": "KSampler", + "KSampler Adv. (Efficient)": "KSampler(Adv)", + "KSampler SDXL (Eff.)": "KSampler", + "XY Plot": "XY Plot", + "Noise Control Script": "NoiseControl", + "HighRes-Fix Script": "HiResFix", + "Tiled Upscaler Script": "TiledUpscale", + "AnimateDiff Script": "AnimateDiff" + }; + + const kSamplerLabel = samplerLabelMap[samplerType]; + const subNodeLabel = subNodeLabelMap[subNodeType]; + + const kSamplerNickname = nicknameMap[kSamplerLabel]; + const subNodeNickname = nicknameMap[subNodeLabel]; + + const contentLabel = subNodeNickname ? `${kSamplerNickname} + ${subNodeNickname}` : kSamplerNickname; + + return { + content: contentLabel, + callback: function() { + const kSamplerNode = addNode(kSamplerLabel, node, { shiftX: node.size[0] + 50 }); + + // Standard connections for all samplers + node.connect(0, kSamplerNode, 0); // MODEL + node.connect(1, kSamplerNode, 1); // CONDITIONING+ + node.connect(2, kSamplerNode, 2); // CONDITIONING- + + // Additional connections for non-SDXL + if (!isSDXL) { + node.connect(3, kSamplerNode, 3); // LATENT + node.connect(4, kSamplerNode, 4); // VAE + } + + if (subNodeLabel) { + const subNode = addNode(subNodeLabel, node, { shiftX: 50, shiftY: node.size[1] + 50 }); + const dependencyIndex = isSDXL ? 3 : 5; + node.connect(dependencyIndex, subNode, 0); + subNode.connect(0, kSamplerNode, dependencyIndex); + } + }, + }; +} + +function createStackerNode(node, type) { + const stackerLabelMap = { + "LoRA": "LoRA Stacker", + "ControlNet": "Control Net Stacker" + }; + + const contentLabel = stackerLabelMap[type]; + + return { + content: contentLabel, + callback: function() { + const stackerNode = addNode(contentLabel, node); + + // Calculate the left shift based on the width of the new node + const shiftX = -(stackerNode.size[0] + 25); + + stackerNode.pos[0] += shiftX; // Adjust the x position of the new node + + // Introduce a Y offset of 200 for ControlNet Stacker node + if (type === "ControlNet") { + stackerNode.pos[1] += 300; + } + + // Connect outputs to the Efficient Loader based on type + if (type === "LoRA") { + stackerNode.connect(0, node, 0); + } else if (type === "ControlNet") { + stackerNode.connect(0, node, 1); + } + }, + }; +} + +function createXYPlotNode(node, type) { + const contentLabel = "XY Plot"; + + return { + content: contentLabel, + callback: function() { + const xyPlotNode = addNode(contentLabel, node); + + // Center the X coordinate of the XY Plot node + const centerXShift = (node.size[0] - xyPlotNode.size[0]) / 2; + xyPlotNode.pos[0] += centerXShift; + + // Adjust the Y position to place it below the loader node + xyPlotNode.pos[1] += node.size[1] + 60; + + // Depending on the node type, connect the appropriate output to the XY Plot node + if (type === "Efficient") { + node.connect(6, xyPlotNode, 0); + } else if (type === "SDXL") { + node.connect(3, xyPlotNode, 0); + } + }, + }; +} + +function getMenuValues(type, node) { + const subNodeTypes = [null, "XYPlot", "NoiseControl", "HiResFix", "TiledUpscale", "AnimateDiff"]; + const excludedSubNodeTypes = ["NoiseControl", "HiResFix", "TiledUpscale", "AnimateDiff"]; // Nodes to exclude from the menu + + const menuValues = []; + + // Add the new node types to the menu first for the correct order + menuValues.push(createStackerNode(node, "LoRA")); + menuValues.push(createStackerNode(node, "ControlNet")); + + for (const subNodeType of subNodeTypes) { + // Skip adding submenu items that are in the excludedSubNodeTypes array + if (!excludedSubNodeTypes.includes(subNodeType)) { + const menuEntry = createKSamplerEntry(node, type === "Efficient" ? "Eff" : "SDXL", subNodeType, type === "SDXL"); + menuValues.push(menuEntry); + } + } + + // Insert the standalone XY Plot option after the KSampler without any subNodeTypes and before any other KSamplers with subNodeTypes + menuValues.splice(3, 0, createXYPlotNode(node, type)); + + return menuValues; +} + +function showAddLinkMenuCommon(value, options, e, menu, node, type) { + const values = getMenuValues(type, node); + new LiteGraph.ContextMenu(values, { + event: e, + callback: null, + parentMenu: menu, + node: node + }); + return false; +} + +// Extension Definition +app.registerExtension({ + name: "efficiency.addLinks", + async beforeRegisterNodeDef(nodeType, nodeData, app) { + const linkTypes = { + "Efficient Loader": "Efficient", + "Eff. Loader SDXL": "SDXL" + }; + + const linkType = linkTypes[nodeData.name]; + + if (linkType) { + addMenuHandler(nodeType, function(insertOption) { + insertOption({ + content: "⛓ Add link...", + has_submenu: true, + callback: (value, options, e, menu, node) => showAddLinkMenuCommon(value, options, e, menu, node, linkType) + }); + }); + } + }, +}); + diff --git a/js/node_options/addScripts.js b/js/node_options/addScripts.js new file mode 100644 index 0000000..fc1de16 --- /dev/null +++ b/js/node_options/addScripts.js @@ -0,0 +1,152 @@ +import { app } from "../../../scripts/app.js"; +import { addMenuHandler } from "./common/utils.js"; +import { addNode } from "./common/utils.js"; + +const connectionMap = { + "KSampler (Efficient)": ["input", 5], + "KSampler Adv. (Efficient)": ["input", 5], + "KSampler SDXL (Eff.)": ["input", 3], + "XY Plot": ["output", 0], + "Noise Control Script": ["input & output", 0], + "HighRes-Fix Script": ["input & output", 0], + "Tiled Upscaler Script": ["input & output", 0], + "AnimateDiff Script": ["output", 0] +}; + + /** + * connect this node output to the input of another node + * @method connect + * @param {number_or_string} slot (could be the number of the slot or the string with the name of the slot) + * @param {LGraphNode} node the target node + * @param {number_or_string} target_slot the input slot of the target node (could be the number of the slot or the string with the name of the slot, or -1 to connect a trigger) + * @return {Object} the link_info is created, otherwise null + LGraphNode.prototype.connect = function(output_slot, target_node, input_slot) + **/ + +function addAndConnectScriptNode(scriptType, selectedNode) { + const selectedNodeType = connectionMap[selectedNode.type]; + const newNodeType = connectionMap[scriptType]; + + // 1. Create the new node without position adjustments + const newNode = addNode(scriptType, selectedNode, { shiftX: 0, shiftY: 0 }); + + // 2. Adjust position of the new node based on conditions + if (newNodeType[0].includes("input") && selectedNodeType[0].includes("output")) { + newNode.pos[0] += selectedNode.size[0] + 50; + } else if (newNodeType[0].includes("output") && selectedNodeType[0].includes("input")) { + newNode.pos[0] -= (newNode.size[0] + 50); + } + + // 3. Logic for connecting the nodes + switch (selectedNodeType[0]) { + case "output": + if (newNodeType[0] === "input & output") { + // For every node that was previously connected to the selectedNode's output + const connectedNodes = selectedNode.getOutputNodes(selectedNodeType[1]); + if (connectedNodes && connectedNodes.length) { + for (let connectedNode of connectedNodes) { + // Disconnect the node from selectedNode's output + selectedNode.disconnectOutput(selectedNodeType[1]); + // Connect the newNode's output to the previously connected node, + // using the appropriate slot based on the type of the connectedNode + const targetSlot = (connectedNode.type in connectionMap) ? connectionMap[connectedNode.type][1] : 0; + newNode.connect(0, connectedNode, targetSlot); + } + } + // Connect selectedNode's output to newNode's input + selectedNode.connect(selectedNodeType[1], newNode, newNodeType[1]); + } + break; + + case "input": + if (newNodeType[0] === "output") { + newNode.connect(0, selectedNode, selectedNodeType[1]); + } else if (newNodeType[0] === "input & output") { + const ogInputNode = selectedNode.getInputNode(selectedNodeType[1]); + if (ogInputNode) { + ogInputNode.connect(0, newNode, 0); + } + newNode.connect(0, selectedNode, selectedNodeType[1]); + } + break; + case "input & output": + if (newNodeType[0] === "output") { + newNode.connect(0, selectedNode, 0); + } else if (newNodeType[0] === "input & output") { + + const connectedNodes = selectedNode.getOutputNodes(0); + if (connectedNodes && connectedNodes.length) { + for (let connectedNode of connectedNodes) { + selectedNode.disconnectOutput(0); + newNode.connect(0, connectedNode, connectedNode.type in connectionMap ? connectionMap[connectedNode.type][1] : 0); + } + } + // Connect selectedNode's output to newNode's input + selectedNode.connect(selectedNodeType[1], newNode, newNodeType[1]); + } + break; + } + + return newNode; +} + +function createScriptEntry(node, scriptType) { + return { + content: scriptType, + callback: function() { + addAndConnectScriptNode(scriptType, node); + }, + }; +} + +function getScriptOptions(nodeType, node) { + const allScriptTypes = [ + "XY Plot", + "Noise Control Script", + "HighRes-Fix Script", + "Tiled Upscaler Script", + "AnimateDiff Script" + ]; + + // Filter script types based on node type + const scriptTypes = allScriptTypes.filter(scriptType => { + const scriptBehavior = connectionMap[scriptType][0]; + + if (connectionMap[nodeType][0] === "output") { + return scriptBehavior.includes("input"); // Includes nodes that are "input" or "input & output" + } else { + return true; + } + }); + + return scriptTypes.map(script => createScriptEntry(node, script)); +} + + +function showAddScriptMenu(_, options, e, menu, node) { + const scriptOptions = getScriptOptions(node.type, node); + new LiteGraph.ContextMenu(scriptOptions, { + event: e, + callback: null, + parentMenu: menu, + node: node + }); + return false; +} + +// Extension Definition +app.registerExtension({ + name: "efficiency.addScripts", + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (connectionMap[nodeData.name]) { + addMenuHandler(nodeType, function(insertOption) { + insertOption({ + content: "📜 Add script...", + has_submenu: true, + callback: showAddScriptMenu + }); + }); + } + }, +}); + diff --git a/js/node_options/addXYinputs.js b/js/node_options/addXYinputs.js new file mode 100644 index 0000000..98b4d43 --- /dev/null +++ b/js/node_options/addXYinputs.js @@ -0,0 +1,89 @@ +import { app } from "../../../scripts/app.js"; +import { addMenuHandler, addNode } from "./common/utils.js"; + +const nodePxOffsets = 80; + +function getXYInputNodes() { + return [ + "XY Input: Seeds++ Batch", + "XY Input: Add/Return Noise", + "XY Input: Steps", + "XY Input: CFG Scale", + "XY Input: Sampler/Scheduler", + "XY Input: Denoise", + "XY Input: VAE", + "XY Input: Prompt S/R", + "XY Input: Aesthetic Score", + "XY Input: Refiner On/Off", + "XY Input: Checkpoint", + "XY Input: Clip Skip", + "XY Input: LoRA", + "XY Input: LoRA Plot", + "XY Input: LoRA Stacks", + "XY Input: Control Net", + "XY Input: Control Net Plot", + "XY Input: Manual XY Entry" + ]; +} + +function showAddXYInputMenu(type, e, menu, node) { + const specialNodes = [ + "XY Input: LoRA Plot", + "XY Input: Control Net Plot", + "XY Input: Manual XY Entry" + ]; + + const values = getXYInputNodes().map(nodeType => { + return { + content: nodeType, + callback: function() { + const newNode = addNode(nodeType, node); + + // Calculate the left shift based on the width of the new node + const shiftX = -(newNode.size[0] + 35); + newNode.pos[0] += shiftX; + + if (specialNodes.includes(nodeType)) { + newNode.pos[1] += 20; + // Connect both outputs to the XY Plot's 2nd and 3rd input. + newNode.connect(0, node, 1); + newNode.connect(1, node, 2); + } else if (type === 'X') { + newNode.pos[1] += 20; + newNode.connect(0, node, 1); // Connect to 2nd input + } else { + newNode.pos[1] += node.size[1] + 45; + newNode.connect(0, node, 2); // Connect to 3rd input + } + } + }; + }); + + new LiteGraph.ContextMenu(values, { + event: e, + callback: null, + parentMenu: menu, + node: node + }); + return false; +} + +app.registerExtension({ + name: "efficiency.addXYinputs", + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name === "XY Plot") { + addMenuHandler(nodeType, function(insertOption) { + insertOption({ + content: "✏️ Add 𝚇 input...", + has_submenu: true, + callback: (value, options, e, menu, node) => showAddXYInputMenu('X', e, menu, node) + }); + insertOption({ + content: "✏️ Add 𝚈 input...", + has_submenu: true, + callback: (value, options, e, menu, node) => showAddXYInputMenu('Y', e, menu, node) + }); + }); + } + }, +}); diff --git a/js/node_options/common/modelInfoDialog.css b/js/node_options/common/modelInfoDialog.css new file mode 100644 index 0000000..9e259dd --- /dev/null +++ b/js/node_options/common/modelInfoDialog.css @@ -0,0 +1,104 @@ +.pysssss-model-info { + color: white; + font-family: sans-serif; + max-width: 90vw; +} +.pysssss-model-content { + display: flex; + flex-direction: column; + overflow: hidden; +} +.pysssss-model-info h2 { + text-align: center; + margin: 0 0 10px 0; +} +.pysssss-model-info p { + margin: 5px 0; +} +.pysssss-model-info a { + color: dodgerblue; +} +.pysssss-model-info a:hover { + text-decoration: underline; +} +.pysssss-model-tags-list { + display: flex; + flex-wrap: wrap; + list-style: none; + gap: 10px; + max-height: 200px; + overflow: auto; + margin: 10px 0; + padding: 0; +} +.pysssss-model-tag { + background-color: rgb(128, 213, 247); + color: #000; + display: flex; + align-items: center; + gap: 5px; + border-radius: 5px; + padding: 2px 5px; + cursor: pointer; +} +.pysssss-model-tag--selected span::before { + content: "✅"; + position: absolute; + background-color: dodgerblue; + left: 0; + top: 0; + right: 0; + bottom: 0; + text-align: center; +} +.pysssss-model-tag:hover { + outline: 2px solid dodgerblue; +} +.pysssss-model-tag p { + margin: 0; +} +.pysssss-model-tag span { + text-align: center; + border-radius: 5px; + background-color: dodgerblue; + color: #fff; + padding: 2px; + position: relative; + min-width: 20px; + overflow: hidden; +} + +.pysssss-model-metadata .comfy-modal-content { + max-width: 100%; +} +.pysssss-model-metadata label { + margin-right: 1ch; + color: #ccc; +} + +.pysssss-model-metadata span { + color: dodgerblue; +} + +.pysssss-preview { + max-width: 50%; + margin-left: 10px; + position: relative; +} +.pysssss-preview img { + max-height: 300px; +} +.pysssss-preview button { + position: absolute; + font-size: 12px; + bottom: 10px; + right: 10px; +} +.pysssss-model-notes { + background-color: rgba(0, 0, 0, 0.25); + padding: 5px; + margin-top: 5px; +} +.pysssss-model-notes:empty { + display: none; +} diff --git a/js/node_options/common/modelInfoDialog.js b/js/node_options/common/modelInfoDialog.js new file mode 100644 index 0000000..d8cf027 --- /dev/null +++ b/js/node_options/common/modelInfoDialog.js @@ -0,0 +1,303 @@ +import { $el, ComfyDialog } from "../../../../scripts/ui.js"; +import { api } from "../../../../scripts/api.js"; +import { addStylesheet } from "./utils.js"; + +addStylesheet(import.meta.url); + +class MetadataDialog extends ComfyDialog { + constructor() { + super(); + + this.element.classList.add("pysssss-model-metadata"); + } + show(metadata) { + super.show( + $el( + "div", + Object.keys(metadata).map((k) => + $el("div", [$el("label", { textContent: k }), $el("span", { textContent: metadata[k] })]) + ) + ) + ); + } +} + +export class ModelInfoDialog extends ComfyDialog { + constructor(name) { + super(); + this.name = name; + this.element.classList.add("pysssss-model-info"); + } + + get customNotes() { + return this.metadata["pysssss.notes"]; + } + + set customNotes(v) { + this.metadata["pysssss.notes"] = v; + } + + get hash() { + return this.metadata["pysssss.sha256"]; + } + + async show(type, value) { + this.type = type; + + const req = api.fetchApi("/pysssss/metadata/" + encodeURIComponent(`${type}/${value}`)); + this.info = $el("div", { style: { flex: "auto" } }); + this.img = $el("img", { style: { display: "none" } }); + this.imgWrapper = $el("div.pysssss-preview", [this.img]); + this.main = $el("main", { style: { display: "flex" } }, [this.info, this.imgWrapper]); + this.content = $el("div.pysssss-model-content", [$el("h2", { textContent: this.name }), this.main]); + + const loading = $el("div", { textContent: "ℹ️ Loading...", parent: this.content }); + + super.show(this.content); + + this.metadata = await (await req).json(); + this.viewMetadata.style.cursor = this.viewMetadata.style.opacity = ""; + this.viewMetadata.removeAttribute("disabled"); + + loading.remove(); + this.addInfo(); + } + + createButtons() { + const btns = super.createButtons(); + this.viewMetadata = $el("button", { + type: "button", + textContent: "View raw metadata", + disabled: "disabled", + style: { + opacity: 0.5, + cursor: "not-allowed", + }, + onclick: (e) => { + if (this.metadata) { + new MetadataDialog().show(this.metadata); + } + }, + }); + + btns.unshift(this.viewMetadata); + return btns; + } + + getNoteInfo() { + function parseNote() { + if (!this.customNotes) return []; + + let notes = []; + // Extract links from notes + const r = new RegExp("(\\bhttps?:\\/\\/[^\\s]+)", "g"); + let end = 0; + let m; + do { + m = r.exec(this.customNotes); + let pos; + let fin = 0; + if (m) { + pos = m.index; + fin = m.index + m[0].length; + } else { + pos = this.customNotes.length; + } + + let pre = this.customNotes.substring(end, pos); + if (pre) { + pre = pre.replaceAll("\n", "
"); + notes.push( + $el("span", { + innerHTML: pre, + }) + ); + } + if (m) { + notes.push( + $el("a", { + href: m[0], + textContent: m[0], + target: "_blank", + }) + ); + } + + end = fin; + } while (m); + return notes; + } + + let textarea; + let notesContainer; + const editText = "✏️ Edit"; + const edit = $el("a", { + textContent: editText, + href: "#", + style: { + float: "right", + color: "greenyellow", + textDecoration: "none", + }, + onclick: async (e) => { + e.preventDefault(); + + if (textarea) { + this.customNotes = textarea.value; + + const resp = await api.fetchApi( + "/pysssss/metadata/notes/" + encodeURIComponent(`${this.type}/${this.name}`), + { + method: "POST", + body: this.customNotes, + } + ); + + if (resp.status !== 200) { + console.error(resp); + alert(`Error saving notes (${req.status}) ${req.statusText}`); + return; + } + + e.target.textContent = editText; + textarea.remove(); + textarea = null; + + notesContainer.replaceChildren(...parseNote.call(this)); + } else { + e.target.textContent = "💾 Save"; + textarea = $el("textarea", { + style: { + width: "100%", + minWidth: "200px", + minHeight: "50px", + }, + textContent: this.customNotes, + }); + e.target.after(textarea); + notesContainer.replaceChildren(); + textarea.style.height = Math.min(textarea.scrollHeight, 300) + "px"; + } + }, + }); + + notesContainer = $el("div.pysssss-model-notes", parseNote.call(this)); + return $el( + "div", + { + style: { display: "contents" }, + }, + [edit, notesContainer] + ); + } + + addInfo() { + this.addInfoEntry("Notes", this.getNoteInfo()); + } + + addInfoEntry(name, value) { + return $el( + "p", + { + parent: this.info, + }, + [ + typeof name === "string" ? $el("label", { textContent: name + ": " }) : name, + typeof value === "string" ? $el("span", { textContent: value }) : value, + ] + ); + } + + async getCivitaiDetails() { + const req = await fetch("https://civitai.com/api/v1/model-versions/by-hash/" + this.hash); + if (req.status === 200) { + return await req.json(); + } else if (req.status === 404) { + throw new Error("Model not found"); + } else { + throw new Error(`Error loading info (${req.status}) ${req.statusText}`); + } + } + + addCivitaiInfo() { + const promise = this.getCivitaiDetails(); + const content = $el("span", { textContent: "ℹ️ Loading..." }); + + this.addInfoEntry( + $el("label", [ + $el("img", { + style: { + width: "18px", + position: "relative", + top: "3px", + margin: "0 5px 0 0", + }, + src: "https://civitai.com/favicon.ico", + }), + $el("span", { textContent: "Civitai: " }), + ]), + content + ); + + return promise + .then((info) => { + content.replaceChildren( + $el("a", { + href: "https://civitai.com/models/" + info.modelId, + textContent: "View " + info.model.name, + target: "_blank", + }) + ); + + if (info.images?.length) { + this.img.src = info.images[0].url; + this.img.style.display = ""; + + this.imgSave = $el("button", { + textContent: "Use as preview", + parent: this.imgWrapper, + onclick: async () => { + // Convert the preview to a blob + const blob = await (await fetch(this.img.src)).blob(); + + // Store it in temp + const name = "temp_preview." + new URL(this.img.src).pathname.split(".")[1]; + const body = new FormData(); + body.append("image", new File([blob], name)); + body.append("overwrite", "true"); + body.append("type", "temp"); + + const resp = await api.fetchApi("/upload/image", { + method: "POST", + body, + }); + + if (resp.status !== 200) { + console.error(resp); + alert(`Error saving preview (${req.status}) ${req.statusText}`); + return; + } + + // Use as preview + await api.fetchApi("/pysssss/save/" + encodeURIComponent(`${this.type}/${this.name}`), { + method: "POST", + body: JSON.stringify({ + filename: name, + type: "temp", + }), + headers: { + "content-type": "application/json", + }, + }); + app.refreshComboInNodes(); + }, + }); + } + + return info; + }) + .catch((err) => { + content.textContent = "⚠️ " + err.message; + }); + } +} diff --git a/js/node_options/common/utils.js b/js/node_options/common/utils.js new file mode 100644 index 0000000..93510df --- /dev/null +++ b/js/node_options/common/utils.js @@ -0,0 +1,94 @@ +import { app } from '../../../../scripts/app.js' +import { $el } from "../../../../scripts/ui.js"; + +export function addStylesheet(url) { + if (url.endsWith(".js")) { + url = url.substr(0, url.length - 2) + "css"; + } + $el("link", { + parent: document.head, + rel: "stylesheet", + type: "text/css", + href: url.startsWith("http") ? url : getUrl(url), + }); +} + +export function getUrl(path, baseUrl) { + if (baseUrl) { + return new URL(path, baseUrl).toString(); + } else { + return new URL("../" + path, import.meta.url).toString(); + } +} + +export async function loadImage(url) { + return new Promise((res, rej) => { + const img = new Image(); + img.onload = res; + img.onerror = rej; + img.src = url; + }); +} + +export function addMenuHandler(nodeType, cb) { + + const GROUPED_MENU_ORDER = { + "🔄 Swap with...": 0, + "⛓ Add link...": 1, + "📜 Add script...": 2, + "🔍 View model info...": 3, + "🌱 Seed behavior...": 4, + "📐 Set Resolution...": 5, + "✏️ Add 𝚇 input...": 6, + "✏️ Add 𝚈 input...": 7 + }; + + const originalGetOpts = nodeType.prototype.getExtraMenuOptions; + + nodeType.prototype.getExtraMenuOptions = function () { + let r = originalGetOpts ? originalGetOpts.apply(this, arguments) || [] : []; + + const insertOption = (option) => { + if (GROUPED_MENU_ORDER.hasOwnProperty(option.content)) { + // Find the right position for the option + let targetPos = r.length; // default to the end + + for (let i = 0; i < r.length; i++) { + if (GROUPED_MENU_ORDER.hasOwnProperty(r[i].content) && + GROUPED_MENU_ORDER[option.content] < GROUPED_MENU_ORDER[r[i].content]) { + targetPos = i; + break; + } + } + // Insert the option at the determined position + r.splice(targetPos, 0, option); + } else { + // If the option is not in the GROUPED_MENU_ORDER, simply add it to the end + r.push(option); + } + }; + + cb.call(this, insertOption); + + return r; + }; +} + +export function findWidgetByName(node, widgetName) { + return node.widgets.find(widget => widget.name === widgetName); +} + +// Utility functions +export function addNode(name, nextTo, options) { + options = { select: true, shiftX: 0, shiftY: 0, before: false, ...(options || {}) }; + const node = LiteGraph.createNode(name); + app.graph.add(node); + node.pos = [ + nextTo.pos[0] + options.shiftX, + nextTo.pos[1] + options.shiftY, + ]; + if (options.select) { + app.canvas.selectNode(node, false); + } + return node; +} diff --git a/js/node_options/modelInfo.js b/js/node_options/modelInfo.js new file mode 100644 index 0000000..d8d69a4 --- /dev/null +++ b/js/node_options/modelInfo.js @@ -0,0 +1,336 @@ +import { app } from "../../../scripts/app.js"; +import { $el } from "../../../scripts/ui.js"; +import { ModelInfoDialog } from "./common/modelInfoDialog.js"; +import { addMenuHandler } from "./common/utils.js"; + +const MAX_TAGS = 500; + +class LoraInfoDialog extends ModelInfoDialog { + getTagFrequency() { + if (!this.metadata.ss_tag_frequency) return []; + + const datasets = JSON.parse(this.metadata.ss_tag_frequency); + const tags = {}; + for (const setName in datasets) { + const set = datasets[setName]; + for (const t in set) { + if (t in tags) { + tags[t] += set[t]; + } else { + tags[t] = set[t]; + } + } + } + + return Object.entries(tags).sort((a, b) => b[1] - a[1]); + } + + getResolutions() { + let res = []; + if (this.metadata.ss_bucket_info) { + const parsed = JSON.parse(this.metadata.ss_bucket_info); + if (parsed?.buckets) { + for (const { resolution, count } of Object.values(parsed.buckets)) { + res.push([count, `${resolution.join("x")} * ${count}`]); + } + } + } + res = res.sort((a, b) => b[0] - a[0]).map((a) => a[1]); + let r = this.metadata.ss_resolution; + if (r) { + const s = r.split(","); + const w = s[0].replace("(", ""); + const h = s[1].replace(")", ""); + res.push(`${w.trim()}x${h.trim()} (Base res)`); + } else if ((r = this.metadata["modelspec.resolution"])) { + res.push(r + " (Base res"); + } + if (!res.length) { + res.push("⚠️ Unknown"); + } + return res; + } + + getTagList(tags) { + return tags.map((t) => + $el( + "li.pysssss-model-tag", + { + dataset: { + tag: t[0], + }, + $: (el) => { + el.onclick = () => { + el.classList.toggle("pysssss-model-tag--selected"); + }; + }, + }, + [ + $el("p", { + textContent: t[0], + }), + $el("span", { + textContent: t[1], + }), + ] + ) + ); + } + + addTags() { + let tags = this.getTagFrequency(); + let hasMore; + if (tags?.length) { + const c = tags.length; + let list; + if (c > MAX_TAGS) { + tags = tags.slice(0, MAX_TAGS); + hasMore = $el("p", [ + $el("span", { textContent: `⚠️ Only showing first ${MAX_TAGS} tags ` }), + $el("a", { + href: "#", + textContent: `Show all ${c}`, + onclick: () => { + list.replaceChildren(...this.getTagList(this.getTagFrequency())); + hasMore.remove(); + }, + }), + ]); + } + list = $el("ol.pysssss-model-tags-list", this.getTagList(tags)); + this.tags = $el("div", [list]); + } else { + this.tags = $el("p", { textContent: "⚠️ No tag frequency metadata found" }); + } + + this.content.append(this.tags); + + if (hasMore) { + this.content.append(hasMore); + } + } + + async addInfo() { + this.addInfoEntry("Name", this.metadata.ss_output_name || "⚠️ Unknown"); + this.addInfoEntry("Base Model", this.metadata.ss_sd_model_name || "⚠️ Unknown"); + this.addInfoEntry("Clip Skip", this.metadata.ss_clip_skip || "⚠️ Unknown"); + + this.addInfoEntry( + "Resolution", + $el( + "select", + this.getResolutions().map((r) => $el("option", { textContent: r })) + ) + ); + + super.addInfo(); + const p = this.addCivitaiInfo(); + this.addTags(); + + const info = await p; + if (info) { + $el( + "p", + { + parent: this.content, + textContent: "Trained Words: ", + }, + [ + $el("pre", { + textContent: info.trainedWords.join(", "), + style: { + whiteSpace: "pre-wrap", + margin: "10px 0", + background: "#222", + padding: "5px", + borderRadius: "5px", + maxHeight: "250px", + overflow: "auto", + }, + }), + ] + ); + $el("div", { + parent: this.content, + innerHTML: info.description, + style: { + maxHeight: "250px", + overflow: "auto", + }, + }); + } + } + + createButtons() { + const btns = super.createButtons(); + + function copyTags(e, tags) { + const textarea = $el("textarea", { + parent: document.body, + style: { + position: "fixed", + }, + textContent: tags.map((el) => el.dataset.tag).join(", "), + }); + textarea.select(); + try { + document.execCommand("copy"); + if (!e.target.dataset.text) { + e.target.dataset.text = e.target.textContent; + } + e.target.textContent = "Copied " + tags.length + " tags"; + setTimeout(() => { + e.target.textContent = e.target.dataset.text; + }, 1000); + } catch (ex) { + prompt("Copy to clipboard: Ctrl+C, Enter", text); + } finally { + document.body.removeChild(textarea); + } + } + + btns.unshift( + $el("button", { + type: "button", + textContent: "Copy Selected", + onclick: (e) => { + copyTags(e, [...this.tags.querySelectorAll(".pysssss-model-tag--selected")]); + }, + }), + $el("button", { + type: "button", + textContent: "Copy All", + onclick: (e) => { + copyTags(e, [...this.tags.querySelectorAll(".pysssss-model-tag")]); + }, + }) + ); + + return btns; + } +} + +class CheckpointInfoDialog extends ModelInfoDialog { + async addInfo() { + super.addInfo(); + const info = await this.addCivitaiInfo(); + if (info) { + this.addInfoEntry("Base Model", info.baseModel || "⚠️ Unknown"); + + $el("div", { + parent: this.content, + innerHTML: info.description, + style: { + maxHeight: "250px", + overflow: "auto", + }, + }); + } + } +} + +const generateNames = (prefix, start, end) => { + const result = []; + if (start < end) { + for (let i = start; i <= end; i++) { + result.push(`${prefix}${i}`); + } + } else { + for (let i = start; i >= end; i--) { + result.push(`${prefix}${i}`); + } + } + return result +} + +// NOTE: Orders reversed so they appear in ascending order +const infoHandler = { + "Efficient Loader": { + "loras": ["lora_name"], + "checkpoints": ["ckpt_name"] + }, + "Eff. Loader SDXL": { + "checkpoints": ["refiner_ckpt_name", "base_ckpt_name"] + }, + "LoRA Stacker": { + "loras": generateNames("lora_name_", 50, 1) + }, + "XY Input: LoRA": { + "loras": generateNames("lora_name_", 50, 1) + }, + "HighRes-Fix Script": { + "checkpoints": ["hires_ckpt_name"] + } +}; + +// Utility functions and other parts of your code remain unchanged + +app.registerExtension({ + name: "efficiency.ModelInfo", + beforeRegisterNodeDef(nodeType) { + const types = infoHandler[nodeType.comfyClass]; + + if (types) { + addMenuHandler(nodeType, function (insertOption) { // Here, we are calling addMenuHandler + let submenuItems = []; // to store submenu items + + const addSubMenuOption = (type, widgetNames) => { + widgetNames.forEach(widgetName => { + const widgetValue = this.widgets.find(w => w.name === widgetName)?.value; + + // Check if widgetValue is "None" + if (!widgetValue || widgetValue === "None") { + return; + } + + let value = widgetValue; + if (value.content) { + value = value.content; + } + const cls = type === "loras" ? LoraInfoDialog : CheckpointInfoDialog; + + const label = widgetName; + + // Push to submenuItems + submenuItems.push({ + content: label, + callback: async () => { + new cls(value).show(type, value); + }, + }); + }); + }; + + if (typeof types === 'object') { + Object.keys(types).forEach(type => { + addSubMenuOption(type, types[type]); + }); + } + + // If we have submenu items, use insertOption + if (submenuItems.length) { + insertOption({ // Using insertOption here + content: "🔍 View model info...", + has_submenu: true, + callback: (value, options, e, menu, node) => { + new LiteGraph.ContextMenu(submenuItems, { + event: e, + callback: null, + parentMenu: menu, + node: node + }); + + return false; // This ensures the original context menu doesn't proceed + } + }); + } + }); + } + }, +}); + + + + + + diff --git a/js/node_options/setResolution.js b/js/node_options/setResolution.js new file mode 100644 index 0000000..cfe2a9b --- /dev/null +++ b/js/node_options/setResolution.js @@ -0,0 +1,88 @@ +// Additional functions and imports +import { app } from "../../../scripts/app.js"; +import { addMenuHandler, findWidgetByName } from "./common/utils.js"; + +// A mapping for resolutions based on the type of the loader +const RESOLUTIONS = { + "Efficient Loader": [ + {width: 512, height: 512}, + {width: 512, height: 768}, + {width: 512, height: 640}, + {width: 640, height: 512}, + {width: 640, height: 768}, + {width: 640, height: 640}, + {width: 768, height: 512}, + {width: 768, height: 768}, + {width: 768, height: 640}, + ], + "Eff. Loader SDXL": [ + {width: 1024, height: 1024}, + {width: 1152, height: 896}, + {width: 896, height: 1152}, + {width: 1216, height: 832}, + {width: 832, height: 1216}, + {width: 1344, height: 768}, + {width: 768, height: 1344}, + {width: 1536, height: 640}, + {width: 640, height: 1536} + ] +}; + +// Function to set the resolution of a node +function setNodeResolution(node, width, height) { + let widthWidget = findWidgetByName(node, "empty_latent_width"); + let heightWidget = findWidgetByName(node, "empty_latent_height"); + + if (widthWidget) { + widthWidget.value = width; + } + + if (heightWidget) { + heightWidget.value = height; + } +} + +// The callback for the resolution submenu +function resolutionMenuCallback(node, width, height) { + return function() { + setNodeResolution(node, width, height); + }; +} + +// Show the set resolution submenu +function showResolutionMenu(value, options, e, menu, node) { + const resolutions = RESOLUTIONS[node.type]; + if (!resolutions) { + return false; + } + + const resolutionOptions = resolutions.map(res => ({ + content: `${res.width} x ${res.height}`, + callback: resolutionMenuCallback(node, res.width, res.height) + })); + + new LiteGraph.ContextMenu(resolutionOptions, { + event: e, + callback: null, + parentMenu: menu, + node: node + }); + + return false; // This ensures the original context menu doesn't proceed +} + +// Extension Definition +app.registerExtension({ + name: "efficiency.SetResolution", + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (["Efficient Loader", "Eff. Loader SDXL"].includes(nodeData.name)) { + addMenuHandler(nodeType, function (insertOption) { + insertOption({ + content: "📐 Set Resolution...", + has_submenu: true, + callback: showResolutionMenu + }); + }); + } + }, +}); diff --git a/js/node_options/swapLoaders.js b/js/node_options/swapLoaders.js new file mode 100644 index 0000000..cfcb2e1 --- /dev/null +++ b/js/node_options/swapLoaders.js @@ -0,0 +1,135 @@ +import { app } from "../../../scripts/app.js"; +import { addMenuHandler } from "./common/utils.js"; +import { findWidgetByName } from "./common/utils.js"; + +function replaceNode(oldNode, newNodeName) { + const newNode = LiteGraph.createNode(newNodeName); + if (!newNode) { + return; + } + app.graph.add(newNode); + + newNode.pos = oldNode.pos.slice(); + newNode.size = oldNode.size.slice(); + + // Transfer widget values + const widgetMapping = { + "ckpt_name": "base_ckpt_name", + "vae_name": "vae_name", + "clip_skip": "base_clip_skip", + "positive": "positive", + "negative": "negative", + "prompt_style": "prompt_style", + "empty_latent_width": "empty_latent_width", + "empty_latent_height": "empty_latent_height", + "batch_size": "batch_size" + }; + + let effectiveWidgetMapping = widgetMapping; + + // Invert the mapping when going from "Eff. Loader SDXL" to "Efficient Loader" + if (oldNode.type === "Eff. Loader SDXL" && newNodeName === "Efficient Loader") { + effectiveWidgetMapping = {}; + for (const [key, value] of Object.entries(widgetMapping)) { + effectiveWidgetMapping[value] = key; + } + } + + oldNode.widgets.forEach(widget => { + const newName = effectiveWidgetMapping[widget.name]; + if (newName) { + const newWidget = findWidgetByName(newNode, newName); + if (newWidget) { + newWidget.value = widget.value; + } + } + }); + + // Hardcoded transfer for specific outputs based on the output names from the nodes in the image + const outputMapping = { + "MODEL": null, // Not present in "Eff. Loader SDXL" + "CONDITIONING+": null, // Not present in "Eff. Loader SDXL" + "CONDITIONING-": null, // Not present in "Eff. Loader SDXL" + "LATENT": "LATENT", + "VAE": "VAE", + "CLIP": null, // Not present in "Eff. Loader SDXL" + "DEPENDENCIES": "DEPENDENCIES" + }; + + // Transfer connections from old node outputs to new node outputs based on the outputMapping + oldNode.outputs.forEach((output, index) => { + if (output && output.links && outputMapping[output.name]) { + const newOutputName = outputMapping[output.name]; + + // If the new node does not have this output, skip + if (newOutputName === null) { + return; + } + + const newOutputIndex = newNode.findOutputSlot(newOutputName); + if (newOutputIndex !== -1) { + output.links.forEach(link => { + const targetLinkInfo = oldNode.graph.links[link]; + if (targetLinkInfo) { + const targetNode = oldNode.graph.getNodeById(targetLinkInfo.target_id); + if (targetNode) { + newNode.connect(newOutputIndex, targetNode, targetLinkInfo.target_slot); + } + } + }); + } + } + }); + + // Remove old node + app.graph.remove(oldNode); +} + +function replaceNodeMenuCallback(currentNode, targetNodeName) { + return function() { + replaceNode(currentNode, targetNodeName); + }; +} + +function showSwapMenu(value, options, e, menu, node) { + const swapOptions = []; + + if (node.type !== "Efficient Loader") { + swapOptions.push({ + content: "Efficient Loader", + callback: replaceNodeMenuCallback(node, "Efficient Loader") + }); + } + + if (node.type !== "Eff. Loader SDXL") { + swapOptions.push({ + content: "Eff. Loader SDXL", + callback: replaceNodeMenuCallback(node, "Eff. Loader SDXL") + }); + } + + new LiteGraph.ContextMenu(swapOptions, { + event: e, + callback: null, + parentMenu: menu, + node: node + }); + + return false; // This ensures the original context menu doesn't proceed +} + +// Extension Definition +app.registerExtension({ + name: "efficiency.SwapLoaders", + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (["Efficient Loader", "Eff. Loader SDXL"].includes(nodeData.name)) { + addMenuHandler(nodeType, function (insertOption) { + insertOption({ + content: "🔄 Swap with...", + has_submenu: true, + callback: showSwapMenu + }); + }); + } + }, +}); diff --git a/js/node_options/swapSamplers.js b/js/node_options/swapSamplers.js new file mode 100644 index 0000000..cf8deaf --- /dev/null +++ b/js/node_options/swapSamplers.js @@ -0,0 +1,191 @@ +import { app } from "../../../scripts/app.js"; +import { addMenuHandler } from "./common/utils.js"; +import { findWidgetByName } from "./common/utils.js"; + +function replaceNode(oldNode, newNodeName) { + // Create new node + const newNode = LiteGraph.createNode(newNodeName); + if (!newNode) { + return; + } + app.graph.add(newNode); + + // Position new node at the same position as the old node + newNode.pos = oldNode.pos.slice(); + + // Define widget mappings + const mappings = { + "KSampler (Efficient) <-> KSampler Adv. (Efficient)": { + seed: "noise_seed", + cfg: "cfg", + sampler_name: "sampler_name", + scheduler: "scheduler", + preview_method: "preview_method", + vae_decode: "vae_decode" + }, + "KSampler (Efficient) <-> KSampler SDXL (Eff.)": { + seed: "noise_seed", + cfg: "cfg", + sampler_name: "sampler_name", + scheduler: "scheduler", + preview_method: "preview_method", + vae_decode: "vae_decode" + }, + "KSampler Adv. (Efficient) <-> KSampler SDXL (Eff.)": { + noise_seed: "noise_seed", + steps: "steps", + cfg: "cfg", + sampler_name: "sampler_name", + scheduler: "scheduler", + start_at_step: "start_at_step", + preview_method: "preview_method", + vae_decode: "vae_decode"} + }; + + const swapKey = `${oldNode.type} <-> ${newNodeName}`; + + let widgetMapping = {}; + + // Check if a reverse mapping is needed + if (!mappings[swapKey]) { + const reverseKey = `${newNodeName} <-> ${oldNode.type}`; + const reverseMapping = mappings[reverseKey]; + if (reverseMapping) { + widgetMapping = Object.entries(reverseMapping).reduce((acc, [key, value]) => { + acc[value] = key; + return acc; + }, {}); + } + } else { + widgetMapping = mappings[swapKey]; + } + + if (oldNode.type === "KSampler (Efficient)" && (newNodeName === "KSampler Adv. (Efficient)" || newNodeName === "KSampler SDXL (Eff.)")) { + const denoise = Math.min(Math.max(findWidgetByName(oldNode, "denoise").value, 0), 1); // Ensure denoise is between 0 and 1 + const steps = Math.min(Math.max(findWidgetByName(oldNode, "steps").value, 0), 10000); // Ensure steps is between 0 and 10000 + + const total_steps = Math.floor(steps / denoise); + const start_at_step = total_steps - steps; + + findWidgetByName(newNode, "steps").value = Math.min(Math.max(total_steps, 0), 10000); // Ensure total_steps is between 0 and 10000 + findWidgetByName(newNode, "start_at_step").value = Math.min(Math.max(start_at_step, 0), 10000); // Ensure start_at_step is between 0 and 10000 + } + else if ((oldNode.type === "KSampler Adv. (Efficient)" || oldNode.type === "KSampler SDXL (Eff.)") && newNodeName === "KSampler (Efficient)") { + const stepsAdv = Math.min(Math.max(findWidgetByName(oldNode, "steps").value, 0), 10000); // Ensure stepsAdv is between 0 and 10000 + const start_at_step = Math.min(Math.max(findWidgetByName(oldNode, "start_at_step").value, 0), 10000); // Ensure start_at_step is between 0 and 10000 + + const denoise = Math.min(Math.max((stepsAdv - start_at_step) / stepsAdv, 0), 1); // Ensure denoise is between 0 and 1 + const stepsTotal = stepsAdv - start_at_step; + + findWidgetByName(newNode, "denoise").value = denoise; + findWidgetByName(newNode, "steps").value = Math.min(Math.max(stepsTotal, 0), 10000); // Ensure stepsTotal is between 0 and 10000 + } + + // Transfer widget values from old node to new node + oldNode.widgets.forEach(widget => { + const newName = widgetMapping[widget.name]; + if (newName) { + const newWidget = findWidgetByName(newNode, newName); + if (newWidget) { + newWidget.value = widget.value; + } + } + }); + + // Determine the starting indices based on the node types + let oldNodeInputStartIndex = 0; + let newNodeInputStartIndex = 0; + let oldNodeOutputStartIndex = 0; + let newNodeOutputStartIndex = 0; + + if (oldNode.type === "KSampler SDXL (Eff.)" || newNodeName === "KSampler SDXL (Eff.)") { + oldNodeInputStartIndex = (oldNode.type === "KSampler SDXL (Eff.)") ? 1 : 3; + newNodeInputStartIndex = (newNodeName === "KSampler SDXL (Eff.)") ? 1 : 3; + oldNodeOutputStartIndex = (oldNode.type === "KSampler SDXL (Eff.)") ? 1 : 3; + newNodeOutputStartIndex = (newNodeName === "KSampler SDXL (Eff.)") ? 1 : 3; + } + + // Transfer connections from old node to new node + oldNode.inputs.slice(oldNodeInputStartIndex).forEach((input, index) => { + if (input && input.link !== null) { + const originLinkInfo = oldNode.graph.links[input.link]; + if (originLinkInfo) { + const originNode = oldNode.graph.getNodeById(originLinkInfo.origin_id); + if (originNode) { + originNode.connect(originLinkInfo.origin_slot, newNode, index + newNodeInputStartIndex); + } + } + } + }); + + oldNode.outputs.slice(oldNodeOutputStartIndex).forEach((output, index) => { + if (output && output.links) { + output.links.forEach(link => { + const targetLinkInfo = oldNode.graph.links[link]; + if (targetLinkInfo) { + const targetNode = oldNode.graph.getNodeById(targetLinkInfo.target_id); + if (targetNode) { + newNode.connect(index + newNodeOutputStartIndex, targetNode, targetLinkInfo.target_slot); + } + } + }); + } + }); + + // Remove old node + app.graph.remove(oldNode); +} + +function replaceNodeMenuCallback(currentNode, targetNodeName) { + return function() { + replaceNode(currentNode, targetNodeName); + }; +} + +function showSwapMenu(value, options, e, menu, node) { + const swapOptions = []; + + if (node.type !== "KSampler (Efficient)") { + swapOptions.push({ + content: "KSampler (Efficient)", + callback: replaceNodeMenuCallback(node, "KSampler (Efficient)") + }); + } + if (node.type !== "KSampler Adv. (Efficient)") { + swapOptions.push({ + content: "KSampler Adv. (Efficient)", + callback: replaceNodeMenuCallback(node, "KSampler Adv. (Efficient)") + }); + } + if (node.type !== "KSampler SDXL (Eff.)") { + swapOptions.push({ + content: "KSampler SDXL (Eff.)", + callback: replaceNodeMenuCallback(node, "KSampler SDXL (Eff.)") + }); + } + + new LiteGraph.ContextMenu(swapOptions, { + event: e, + callback: null, + parentMenu: menu, + node: node + }); + + return false; // This ensures the original context menu doesn't proceed +} + +// Extension Definition +app.registerExtension({ + name: "efficiency.SwapSamplers", + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (["KSampler (Efficient)", "KSampler Adv. (Efficient)", "KSampler SDXL (Eff.)"].includes(nodeData.name)) { + addMenuHandler(nodeType, function (insertOption) { + insertOption({ + content: "🔄 Swap with...", + has_submenu: true, + callback: showSwapMenu + }); + }); + } + }, +}); diff --git a/js/node_options/swapScripts.js b/js/node_options/swapScripts.js new file mode 100644 index 0000000..e2b4c92 --- /dev/null +++ b/js/node_options/swapScripts.js @@ -0,0 +1,100 @@ +import { app } from "../../../scripts/app.js"; +import { addMenuHandler } from "./common/utils.js"; + +function replaceNode(oldNode, newNodeName) { + const newNode = LiteGraph.createNode(newNodeName); + if (!newNode) { + return; + } + app.graph.add(newNode); + + newNode.pos = oldNode.pos.slice(); + + // Transfer connections from old node to new node + // XY Plot and AnimateDiff have only one output + if(["XY Plot", "AnimateDiff Script"].includes(oldNode.type)) { + if (oldNode.outputs[0] && oldNode.outputs[0].links) { + oldNode.outputs[0].links.forEach(link => { + const targetLinkInfo = oldNode.graph.links[link]; + if (targetLinkInfo) { + const targetNode = oldNode.graph.getNodeById(targetLinkInfo.target_id); + if (targetNode) { + newNode.connect(0, targetNode, targetLinkInfo.target_slot); + } + } + }); + } + } else { + // Noise Control Script, HighRes-Fix Script, and Tiled Upscaler Script have 1 input and 1 output at index 0 + if (oldNode.inputs[0] && oldNode.inputs[0].link !== null) { + const originLinkInfo = oldNode.graph.links[oldNode.inputs[0].link]; + if (originLinkInfo) { + const originNode = oldNode.graph.getNodeById(originLinkInfo.origin_id); + if (originNode) { + originNode.connect(originLinkInfo.origin_slot, newNode, 0); + } + } + } + + if (oldNode.outputs[0] && oldNode.outputs[0].links) { + oldNode.outputs[0].links.forEach(link => { + const targetLinkInfo = oldNode.graph.links[link]; + if (targetLinkInfo) { + const targetNode = oldNode.graph.getNodeById(targetLinkInfo.target_id); + if (targetNode) { + newNode.connect(0, targetNode, targetLinkInfo.target_slot); + } + } + }); + } + } + + // Remove old node + app.graph.remove(oldNode); +} + +function replaceNodeMenuCallback(currentNode, targetNodeName) { + return function() { + replaceNode(currentNode, targetNodeName); + }; +} + +function showSwapMenu(value, options, e, menu, node) { + const scriptNodes = [ + "XY Plot", + "Noise Control Script", + "HighRes-Fix Script", + "Tiled Upscaler Script", + "AnimateDiff Script" + ]; + + const swapOptions = scriptNodes.filter(n => n !== node.type).map(n => ({ + content: n, + callback: replaceNodeMenuCallback(node, n) + })); + + new LiteGraph.ContextMenu(swapOptions, { + event: e, + callback: null, + parentMenu: menu, + node: node + }); + + return false; +} + +// Extension Definition +app.registerExtension({ + name: "efficiency.SwapScripts", + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (["XY Plot", "Noise Control Script", "HighRes-Fix Script", "Tiled Upscaler Script", "AnimateDiff Script"].includes(nodeData.name)) { + addMenuHandler(nodeType, function (insertOption) { + insertOption({ + content: "🔄 Swap with...", + has_submenu: true, + callback: showSwapMenu + }); + }); + } + }, +}); diff --git a/js/node_options/swapXYinputs.js b/js/node_options/swapXYinputs.js new file mode 100644 index 0000000..cac2fd5 --- /dev/null +++ b/js/node_options/swapXYinputs.js @@ -0,0 +1,98 @@ +import { app } from "../../../scripts/app.js"; +import { addMenuHandler } from "./common/utils.js"; + +function replaceNode(oldNode, newNodeName) { + const newNode = LiteGraph.createNode(newNodeName); + if (!newNode) { + return; + } + app.graph.add(newNode); + + newNode.pos = oldNode.pos.slice(); + + // Handle the special nodes with two outputs + const nodesWithTwoOutputs = ["XY Input: LoRA Plot", "XY Input: Control Net Plot", "XY Input: Manual XY Entry"]; + let outputCount = nodesWithTwoOutputs.includes(oldNode.type) ? 2 : 1; + + // Transfer output connections from old node to new node + oldNode.outputs.slice(0, outputCount).forEach((output, index) => { + if (output && output.links) { + output.links.forEach(link => { + const targetLinkInfo = oldNode.graph.links[link]; + if (targetLinkInfo) { + const targetNode = oldNode.graph.getNodeById(targetLinkInfo.target_id); + if (targetNode) { + newNode.connect(index, targetNode, targetLinkInfo.target_slot); + } + } + }); + } + }); + + // Remove old node + app.graph.remove(oldNode); +} + +function replaceNodeMenuCallback(currentNode, targetNodeName) { + return function() { + replaceNode(currentNode, targetNodeName); + }; +} + +function showSwapMenu(value, options, e, menu, node) { + const swapOptions = []; + const xyInputNodes = [ + "XY Input: Seeds++ Batch", + "XY Input: Add/Return Noise", + "XY Input: Steps", + "XY Input: CFG Scale", + "XY Input: Sampler/Scheduler", + "XY Input: Denoise", + "XY Input: VAE", + "XY Input: Prompt S/R", + "XY Input: Aesthetic Score", + "XY Input: Refiner On/Off", + "XY Input: Checkpoint", + "XY Input: Clip Skip", + "XY Input: LoRA", + "XY Input: LoRA Plot", + "XY Input: LoRA Stacks", + "XY Input: Control Net", + "XY Input: Control Net Plot", + "XY Input: Manual XY Entry" + ]; + + for (const nodeType of xyInputNodes) { + if (node.type !== nodeType) { + swapOptions.push({ + content: nodeType, + callback: replaceNodeMenuCallback(node, nodeType) + }); + } + } + + new LiteGraph.ContextMenu(swapOptions, { + event: e, + callback: null, + parentMenu: menu, + node: node + }); + + return false; // This ensures the original context menu doesn't proceed +} + +// Extension Definition +app.registerExtension({ + name: "efficiency.swapXYinputs", + async beforeRegisterNodeDef(nodeType, nodeData, app) { + if (nodeData.name.startsWith("XY Input:")) { + addMenuHandler(nodeType, function (insertOption) { + insertOption({ + content: "🔄 Swap with...", + has_submenu: true, + callback: showSwapMenu + }); + }); + } + }, +}); diff --git a/js/previewfix.js b/js/previewfix.js index bdb221e..5da96b7 100644 --- a/js/previewfix.js +++ b/js/previewfix.js @@ -1,13 +1,10 @@ import { app } from "../../scripts/app.js"; +import { api } from "../../scripts/api.js"; -const ext = { +app.registerExtension({ name: "efficiency.previewfix", - ws: null, - maxCount: 0, - currentCount: 0, - sendBlob: false, - startProcessing: false, - lastBlobURL: null, + lastExecutedNodeId: null, + blobsToRevoke: [], // Array to accumulate blob URLs for revocation debug: false, log(...args) { @@ -18,89 +15,53 @@ const ext = { if (this.debug) console.error(...args); }, - async sendBlobDataAsDataURL(blobURL) { - const blob = await fetch(blobURL).then(res => res.blob()); - const reader = new FileReader(); - reader.readAsDataURL(blob); - reader.onloadend = () => this.ws.send(reader.result); - }, + shouldRevokeBlobForNode(nodeId) { + const node = app.graph.getNodeById(nodeId); + + const validTitles = [ + "KSampler (Efficient)", + "KSampler Adv. (Efficient)", + "KSampler SDXL (Eff.)" + ]; - handleCommandMessage(data) { - Object.assign(this, { - maxCount: data.maxCount, - sendBlob: data.sendBlob, - startProcessing: data.startProcessing, - currentCount: 0 - }); - - if (!this.startProcessing && this.lastBlobURL) { - this.log("[BlobURLLogger] Revoking last Blob URL:", this.lastBlobURL); - URL.revokeObjectURL(this.lastBlobURL); - this.lastBlobURL = null; + if (!node || !validTitles.includes(node.title)) { + return false; } + + const getValue = name => ((node.widgets || []).find(w => w.name === name) || {}).value; + return getValue("preview_method") !== "none" && getValue("vae_decode").includes("true"); }, - init() { - this.log("[BlobURLLogger] Initializing..."); - - this.ws = new WebSocket('ws://127.0.0.1:8288'); - - this.ws.addEventListener('open', () => this.log('[BlobURLLogger] WebSocket connection opened.')); - this.ws.addEventListener('error', err => this.error('[BlobURLLogger] WebSocket Error:', err)); - this.ws.addEventListener('message', (event) => { - try { - const data = JSON.parse(event.data); - if (data.maxCount !== undefined && data.sendBlob !== undefined && data.startProcessing !== undefined) { - this.handleCommandMessage(data); - } - } catch (err) { - this.error('[BlobURLLogger] Error parsing JSON:', err); - } - }); - + setup() { + // Intercepting blob creation to store and immediately revoke the last blob URL const originalCreateObjectURL = URL.createObjectURL; URL.createObjectURL = (object) => { - const blobURL = originalCreateObjectURL.call(this, object); - if (blobURL.startsWith('blob:') && this.startProcessing) { + const blobURL = originalCreateObjectURL(object); + if (blobURL.startsWith('blob:')) { this.log("[BlobURLLogger] Blob URL created:", blobURL); - this.lastBlobURL = blobURL; - if (this.sendBlob && this.currentCount < this.maxCount) { - this.sendBlobDataAsDataURL(blobURL); + + // If the current node meets the criteria, add the blob URL to the revocation list + if (this.shouldRevokeBlobForNode(this.lastExecutedNodeId)) { + this.blobsToRevoke.push(blobURL); } - this.currentCount++; } return blobURL; }; - this.log("[BlobURLLogger] Hook attached."); - } -}; - -function toggleWidgetVisibility(node, widgetName, isVisible) { - const widget = node.widgets.find(w => w.name === widgetName); - if (widget) { - widget.visible = isVisible; - node.setDirtyCanvas(true); - } -} - -function handleLoraNameChange(node, loraNameWidget) { - const isNone = loraNameWidget.value === "None"; - toggleWidgetVisibility(node, "lora_model_strength", !isNone); - toggleWidgetVisibility(node, "lora_clip_strength", !isNone); -} - -app.registerExtension({ - ...ext, - nodeCreated(node) { - if (node.getTitle() === "Efficient Loader") { - const loraNameWidget = node.widgets.find(w => w.name === "lora_name"); - if (loraNameWidget) { - handleLoraNameChange(node, loraNameWidget); - loraNameWidget.onChange = function() { - handleLoraNameChange(node, this); - }; + // Listen to the start of the node execution to revoke all accumulated blob URLs + api.addEventListener("executing", ({ detail }) => { + if (this.lastExecutedNodeId !== detail || detail === null) { + this.blobsToRevoke.forEach(blob => { + this.log("[BlobURLLogger] Revoking Blob URL:", blob); + URL.revokeObjectURL(blob); + }); + this.blobsToRevoke = []; // Clear the list after revoking all blobs } - } - } -}); + + // Update the last executed node ID + this.lastExecutedNodeId = detail; + }); + + this.log("[BlobURLLogger] Hook attached."); + }, +}); \ No newline at end of file diff --git a/js/seedcontrol.js b/js/seedcontrol.js index 90509c4..fbae407 100644 --- a/js/seedcontrol.js +++ b/js/seedcontrol.js @@ -1,11 +1,18 @@ import { app } from "../../scripts/app.js"; +import { addMenuHandler } from "./node_options/common/utils.js"; const LAST_SEED_BUTTON_LABEL = '🎲 Randomize / ♻️ Last Queued Seed'; +const SEED_BEHAVIOR_RANDOMIZE = 'Randomize'; +const SEED_BEHAVIOR_INCREMENT = 'Increment'; +const SEED_BEHAVIOR_DECREMENT = 'Decrement'; const NODE_WIDGET_MAP = { "KSampler (Efficient)": "seed", "KSampler Adv. (Efficient)": "noise_seed", - "KSampler SDXL (Eff.)": "noise_seed" + "KSampler SDXL (Eff.)": "noise_seed", + "Noise Control Script": "seed", + "HighRes-Fix Script": "seed", + "Tiled Upscaler Script": "seed" }; const SPECIFIC_WIDTH = 325; // Set to desired width @@ -21,11 +28,9 @@ class SeedControl { this.lastSeed = -1; this.serializedCtx = {}; this.node = node; - this.holdFlag = false; // Flag to track if sampler_state was set to "Hold" - this.usedLastSeedOnHoldRelease = false; // To track if we used the lastSeed after releasing hold + this.seedBehavior = 'randomize'; // Default behavior let controlAfterGenerateIndex; - this.samplerStateWidget = this.node.widgets.find(w => w.name === 'sampler_state'); for (const [i, w] of this.node.widgets.entries()) { if (w.name === seedName) { @@ -41,10 +46,24 @@ class SeedControl { } this.lastSeedButton = this.node.addWidget("button", LAST_SEED_BUTTON_LABEL, null, () => { - if (this.seedWidget.value != -1) { + const isValidValue = Number.isInteger(this.seedWidget.value) && this.seedWidget.value >= min && this.seedWidget.value <= max; + + // Special case: if the current label is the default and seed value is -1 + if (this.lastSeedButton.name === LAST_SEED_BUTTON_LABEL && this.seedWidget.value == -1) { + return; // Do nothing and return early + } + + if (isValidValue && this.seedWidget.value != -1) { + this.lastSeed = this.seedWidget.value; this.seedWidget.value = -1; } else if (this.lastSeed !== -1) { this.seedWidget.value = this.lastSeed; + } else { + this.seedWidget.value = -1; // Set to -1 if the label didn't update due to a seed value issue + } + + if (isValidValue) { + this.updateButtonLabel(); // Update the button label to reflect the change } }, { width: 50, serialize: false }); @@ -60,76 +79,167 @@ class SeedControl { const range = (max - min) / (this.seedWidget.options.step / 10); this.seedWidget.serializeValue = async (node, index) => { - const currentSeed = this.seedWidget.value; - this.serializedCtx = { - wasRandom: currentSeed == -1, - }; - - // Check for the state transition and act accordingly. - if (this.samplerStateWidget) { - if (this.samplerStateWidget.value !== "Hold" && this.holdFlag && !this.usedLastSeedOnHoldRelease) { - this.serializedCtx.seedUsed = this.lastSeed; - this.usedLastSeedOnHoldRelease = true; - this.holdFlag = false; // Reset flag for the next cycle - } + // Check if the button is disabled + if (this.lastSeedButton.disabled) { + return this.seedWidget.value; } - if (!this.usedLastSeedOnHoldRelease) { - if (this.serializedCtx.wasRandom) { - this.serializedCtx.seedUsed = Math.floor(Math.random() * range) * (this.seedWidget.options.step / 10) + min; - } else { - this.serializedCtx.seedUsed = this.seedWidget.value; + const currentSeed = this.seedWidget.value; + this.serializedCtx = { + wasSpecial: currentSeed == -1, + }; + + if (this.serializedCtx.wasSpecial) { + switch (this.seedBehavior) { + case 'increment': + this.serializedCtx.seedUsed = this.lastSeed + 1; + break; + case 'decrement': + this.serializedCtx.seedUsed = this.lastSeed - 1; + break; + default: + this.serializedCtx.seedUsed = Math.floor(Math.random() * range) * (this.seedWidget.options.step / 10) + min; + break; } + + // Ensure the seed value is an integer and remains within the accepted range + this.serializedCtx.seedUsed = Number.isInteger(this.serializedCtx.seedUsed) ? Math.min(Math.max(this.serializedCtx.seedUsed, min), max) : this.seedWidget.value; + + } else { + this.serializedCtx.seedUsed = this.seedWidget.value; } if (node && node.widgets_values) { node.widgets_values[index] = this.serializedCtx.seedUsed; - }else{ + } else { // Update the last seed value and the button's label to show the current seed value this.lastSeed = this.serializedCtx.seedUsed; - this.lastSeedButton.name = `🎲 Randomize / ♻️ ${this.lastSeed}`; + this.updateButtonLabel(); } this.seedWidget.value = this.serializedCtx.seedUsed; - if (this.serializedCtx.wasRandom) { + if (this.serializedCtx.wasSpecial) { this.lastSeed = this.serializedCtx.seedUsed; - this.lastSeedButton.name = `🎲 Randomize / ♻️ ${this.lastSeed}`; - if (this.samplerStateWidget.value === "Hold") { - this.holdFlag = true; - } - } - - if (this.usedLastSeedOnHoldRelease && this.samplerStateWidget.value !== "Hold") { - // Reset the flag to ensure default behavior is restored - this.usedLastSeedOnHoldRelease = false; + this.updateButtonLabel(); } return this.serializedCtx.seedUsed; }; this.seedWidget.afterQueued = () => { - if (this.serializedCtx.wasRandom) { + // Check if the button is disabled + if (this.lastSeedButton.disabled) { + return; // Exit the function immediately + } + + if (this.serializedCtx.wasSpecial) { this.seedWidget.value = -1; } - + // Check if seed has changed to a non -1 value, and if so, update lastSeed if (this.seedWidget.value !== -1) { this.lastSeed = this.seedWidget.value; } - // Update the button's label to show the current last seed value - this.lastSeedButton.name = `🎲 Randomize / ♻️ ${this.lastSeed}`; - + this.updateButtonLabel(); this.serializedCtx = {}; }; } + + setBehavior(behavior) { + this.seedBehavior = behavior; + + // Capture the current seed value as lastSeed and then set the seed widget value to -1 + if (this.seedWidget.value != -1) { + this.lastSeed = this.seedWidget.value; + this.seedWidget.value = -1; + } + + this.updateButtonLabel(); + } + + updateButtonLabel() { + + switch (this.seedBehavior) { + case 'increment': + this.lastSeedButton.name = `➕ Increment / ♻️ ${this.lastSeed === -1 ? "Last Queued Seed" : this.lastSeed}`; + break; + case 'decrement': + this.lastSeedButton.name = `➖ Decrement / ♻️ ${this.lastSeed === -1 ? "Last Queued Seed" : this.lastSeed}`; + break; + default: + this.lastSeedButton.name = `🎲 Randomize / ♻️ ${this.lastSeed === -1 ? "Last Queued Seed" : this.lastSeed}`; + break; + } + } + } +function showSeedBehaviorMenu(value, options, e, menu, node) { + const behaviorOptions = [ + { + content: "🎲 Randomize", + callback: () => { + node.seedControl.setBehavior('randomize'); + } + }, + { + content: "➕ Increment", + callback: () => { + node.seedControl.setBehavior('increment'); + } + }, + { + content: "➖ Decrement", + callback: () => { + node.seedControl.setBehavior('decrement'); + } + } + ]; + + new LiteGraph.ContextMenu(behaviorOptions, { + event: e, + callback: null, + parentMenu: menu, + node: node + }); + + return false; // This ensures the original context menu doesn't proceed +} + +// Extension Definition app.registerExtension({ name: "efficiency.seedcontrol", async beforeRegisterNodeDef(nodeType, nodeData, _app) { if (NODE_WIDGET_MAP[nodeData.name]) { + addMenuHandler(nodeType, function (insertOption) { + // Check conditions before showing the seed behavior option + let showSeedOption = true; + + if (nodeData.name === "Noise Control Script") { + // Check for 'add_seed_noise' widget being false + const addSeedNoiseWidget = this.widgets.find(w => w.name === 'add_seed_noise'); + if (addSeedNoiseWidget && !addSeedNoiseWidget.value) { + showSeedOption = false; + } + } else if (nodeData.name === "HighRes-Fix Script") { + // Check for 'use_same_seed' widget being true + const useSameSeedWidget = this.widgets.find(w => w.name === 'use_same_seed'); + if (useSameSeedWidget && useSameSeedWidget.value) { + showSeedOption = false; + } + } + + if (showSeedOption) { + insertOption({ + content: "🌱 Seed behavior...", + has_submenu: true, + callback: showSeedBehaviorMenu + }); + } + }); + const onNodeCreated = nodeType.prototype.onNodeCreated; nodeType.prototype.onNodeCreated = function () { onNodeCreated ? onNodeCreated.apply(this, []) : undefined; @@ -138,4 +248,4 @@ app.registerExtension({ }; } }, -}); \ No newline at end of file +}); diff --git a/js/widgethider.js b/js/widgethider.js index e0b1d26..83b6d6d 100644 --- a/js/widgethider.js +++ b/js/widgethider.js @@ -1,4 +1,4 @@ -import { app } from "/scripts/app.js"; +import { app } from "../../scripts/app.js"; let origProps = {}; let initialized = false; @@ -11,19 +11,43 @@ const doesInputWithNameExist = (node, name) => { return node.inputs ? node.inputs.some((input) => input.name === name) : false; }; -const WIDGET_HEIGHT = 24; +const HIDDEN_TAG = "tschide"; // Toggle Widget + change size function toggleWidget(node, widget, show = false, suffix = "") { if (!widget || doesInputWithNameExist(node, widget.name)) return; + + // Store the original properties of the widget if not already stored + if (!origProps[widget.name]) { + origProps[widget.name] = { origType: widget.type, origComputeSize: widget.computeSize }; + } + + const origSize = node.size; + + // Set the widget type and computeSize based on the show flag + widget.type = show ? origProps[widget.name].origType : HIDDEN_TAG + suffix; + widget.computeSize = show ? origProps[widget.name].origComputeSize : () => [0, -4]; + + // Recursively handle linked widgets if they exist + widget.linkedWidgets?.forEach(w => toggleWidget(node, w, ":" + widget.name, show)); + + // Calculate the new height for the node based on its computeSize method + const newHeight = node.computeSize()[1]; + node.setSize([node.size[0], newHeight]); +} + +const WIDGET_HEIGHT = 24; +// Use for Multiline Widget Nodes (aka Efficient Loaders) +function toggleWidget_2(node, widget, show = false, suffix = "") { + if (!widget || doesInputWithNameExist(node, widget.name)) return; - const isCurrentlyVisible = widget.type !== "tschide" + suffix; + const isCurrentlyVisible = widget.type !== HIDDEN_TAG + suffix; if (isCurrentlyVisible === show) return; // Early exit if widget is already in the desired state if (!origProps[widget.name]) { origProps[widget.name] = { origType: widget.type, origComputeSize: widget.computeSize }; } - widget.type = show ? origProps[widget.name].origType : "tschide" + suffix; + widget.type = show ? origProps[widget.name].origType : HIDDEN_TAG + suffix; widget.computeSize = show ? origProps[widget.name].origComputeSize : () => [0, -4]; if (initialized){ @@ -278,7 +302,18 @@ const nodeWidgetHandlers = { }, "XY Input: Control Net Plot": { 'plot_type': handleXYInputControlNetPlotPlotType - } + }, + "Noise Control Script": { + 'add_seed_noise': handleNoiseControlScript + }, + "HighRes-Fix Script": { + 'upscale_type': handleHiResFixScript, + 'use_same_seed': handleHiResFixScript, + 'use_controlnet':handleHiResFixScript + }, + "Tiled Upscaler Script": { + 'use_controlnet':handleTiledUpscalerScript + }, }; // In the main function where widgetLogic is called @@ -293,24 +328,142 @@ function widgetLogic(node, widget) { // Efficient Loader Handlers function handleEfficientLoaderLoraName(node, widget) { if (widget.value === 'None') { - toggleWidget(node, findWidgetByName(node, 'lora_model_strength')); - toggleWidget(node, findWidgetByName(node, 'lora_clip_strength')); + toggleWidget_2(node, findWidgetByName(node, 'lora_model_strength')); + toggleWidget_2(node, findWidgetByName(node, 'lora_clip_strength')); } else { - toggleWidget(node, findWidgetByName(node, 'lora_model_strength'), true); - toggleWidget(node, findWidgetByName(node, 'lora_clip_strength'), true); + toggleWidget_2(node, findWidgetByName(node, 'lora_model_strength'), true); + toggleWidget_2(node, findWidgetByName(node, 'lora_clip_strength'), true); } } // Eff. Loader SDXL Handlers function handleEffLoaderSDXLRefinerCkptName(node, widget) { if (widget.value === 'None') { - toggleWidget(node, findWidgetByName(node, 'refiner_clip_skip')); - toggleWidget(node, findWidgetByName(node, 'positive_ascore')); - toggleWidget(node, findWidgetByName(node, 'negative_ascore')); + toggleWidget_2(node, findWidgetByName(node, 'refiner_clip_skip')); + toggleWidget_2(node, findWidgetByName(node, 'positive_ascore')); + toggleWidget_2(node, findWidgetByName(node, 'negative_ascore')); } else { - toggleWidget(node, findWidgetByName(node, 'refiner_clip_skip'), true); - toggleWidget(node, findWidgetByName(node, 'positive_ascore'), true); - toggleWidget(node, findWidgetByName(node, 'negative_ascore'), true); + toggleWidget_2(node, findWidgetByName(node, 'refiner_clip_skip'), true); + toggleWidget_2(node, findWidgetByName(node, 'positive_ascore'), true); + toggleWidget_2(node, findWidgetByName(node, 'negative_ascore'), true); + } +} + +// Noise Control Script Seed Handler +function handleNoiseControlScript(node, widget) { + + function ensureSeedControlExists(callback) { + if (node.seedControl && node.seedControl.lastSeedButton) { + callback(); + } else { + setTimeout(() => ensureSeedControlExists(callback), 0); + } + } + + ensureSeedControlExists(() => { + if (widget.value === false) { + toggleWidget(node, findWidgetByName(node, 'seed')); + toggleWidget(node, findWidgetByName(node, 'weight')); + toggleWidget(node, node.seedControl.lastSeedButton); + node.seedControl.lastSeedButton.disabled = true; // Disable the button + } else { + toggleWidget(node, findWidgetByName(node, 'seed'), true); + toggleWidget(node, findWidgetByName(node, 'weight'), true); + node.seedControl.lastSeedButton.disabled = false; // Enable the button + toggleWidget(node, node.seedControl.lastSeedButton, true); + } + }); + +} + +/// HighRes-Fix Script Handlers +function handleHiResFixScript(node, widget) { + + function ensureSeedControlExists(callback) { + if (node.seedControl && node.seedControl.lastSeedButton) { + callback(); + } else { + setTimeout(() => ensureSeedControlExists(callback), 0); + } + } + + if (findWidgetByName(node, 'upscale_type').value === "latent") { + toggleWidget(node, findWidgetByName(node, 'pixel_upscaler')); + + toggleWidget(node, findWidgetByName(node, 'hires_ckpt_name'), true); + toggleWidget(node, findWidgetByName(node, 'latent_upscaler'), true); + toggleWidget(node, findWidgetByName(node, 'use_same_seed'), true); + toggleWidget(node, findWidgetByName(node, 'hires_steps'), true); + toggleWidget(node, findWidgetByName(node, 'denoise'), true); + toggleWidget(node, findWidgetByName(node, 'iterations'), true); + + ensureSeedControlExists(() => { + if (findWidgetByName(node, 'use_same_seed').value == true) { + toggleWidget(node, findWidgetByName(node, 'seed')); + toggleWidget(node, node.seedControl.lastSeedButton); + node.seedControl.lastSeedButton.disabled = true; // Disable the button + } else { + toggleWidget(node, findWidgetByName(node, 'seed'), true); + node.seedControl.lastSeedButton.disabled = false; // Enable the button + toggleWidget(node, node.seedControl.lastSeedButton, true); + } + }); + + if (findWidgetByName(node, 'use_controlnet').value == '_'){ + toggleWidget(node, findWidgetByName(node, 'use_controlnet')); + toggleWidget(node, findWidgetByName(node, 'control_net_name')); + toggleWidget(node, findWidgetByName(node, 'strength')); + toggleWidget(node, findWidgetByName(node, 'preprocessor')); + toggleWidget(node, findWidgetByName(node, 'preprocessor_imgs')); + } + else{ + toggleWidget(node, findWidgetByName(node, 'use_controlnet'), true); + + if (findWidgetByName(node, 'use_controlnet').value == true){ + toggleWidget(node, findWidgetByName(node, 'control_net_name'), true); + toggleWidget(node, findWidgetByName(node, 'strength'), true); + toggleWidget(node, findWidgetByName(node, 'preprocessor'), true); + toggleWidget(node, findWidgetByName(node, 'preprocessor_imgs'), true); + } + else{ + toggleWidget(node, findWidgetByName(node, 'control_net_name')); + toggleWidget(node, findWidgetByName(node, 'strength')); + toggleWidget(node, findWidgetByName(node, 'preprocessor')); + toggleWidget(node, findWidgetByName(node, 'preprocessor_imgs')); + } + } + + } else if (findWidgetByName(node, 'upscale_type').value === "pixel") { + toggleWidget(node, findWidgetByName(node, 'hires_ckpt_name')); + toggleWidget(node, findWidgetByName(node, 'latent_upscaler')); + toggleWidget(node, findWidgetByName(node, 'use_same_seed')); + toggleWidget(node, findWidgetByName(node, 'hires_steps')); + toggleWidget(node, findWidgetByName(node, 'denoise')); + toggleWidget(node, findWidgetByName(node, 'iterations')); + toggleWidget(node, findWidgetByName(node, 'seed')); + ensureSeedControlExists(() => { + toggleWidget(node, node.seedControl.lastSeedButton); + node.seedControl.lastSeedButton.disabled = true; // Disable the button + }); + toggleWidget(node, findWidgetByName(node, 'use_controlnet')); + toggleWidget(node, findWidgetByName(node, 'control_net_name')); + toggleWidget(node, findWidgetByName(node, 'strength')); + toggleWidget(node, findWidgetByName(node, 'preprocessor')); + toggleWidget(node, findWidgetByName(node, 'preprocessor_imgs')); + + toggleWidget(node, findWidgetByName(node, 'pixel_upscaler'), true); + } +} + +/// Tiled Upscaler Script Handler +function handleTiledUpscalerScript(node, widget) { + if (findWidgetByName(node, 'use_controlnet').value == true){ + toggleWidget(node, findWidgetByName(node, 'tile_controlnet'), true); + toggleWidget(node, findWidgetByName(node, 'strength'), true); + } + else{ + toggleWidget(node, findWidgetByName(node, 'tile_controlnet')); + toggleWidget(node, findWidgetByName(node, 'strength')); } } @@ -437,7 +590,7 @@ app.registerExtension({ } }); } - setTimeout(() => {initialized = true;}, 2000); + setTimeout(() => {initialized = true;}, 500); } }); diff --git a/js/workflowfix.js b/js/workflowfix.js new file mode 100644 index 0000000..50b55af --- /dev/null +++ b/js/workflowfix.js @@ -0,0 +1,142 @@ +// Detect and update Efficiency Nodes from v1.92 to v2.00 changes (Final update?) +import { app } from '../../scripts/app.js' +import { addNode } from "./node_options/common/utils.js"; + +const ext = { + name: "efficiency.WorkflowFix", +}; + +function reloadHiResFixNode(originalNode) { + + // Safeguard against missing 'pos' property + const position = originalNode.pos && originalNode.pos.length === 2 ? { x: originalNode.pos[0], y: originalNode.pos[1] } : { x: 0, y: 0 }; + + // Recreate the node + const newNode = addNode("HighRes-Fix Script", originalNode, position); + + // Transfer input connections from old node to new node + originalNode.inputs.forEach((input, index) => { + if (input && input.link !== null) { + const originLinkInfo = originalNode.graph.links[input.link]; + if (originLinkInfo) { + const originNode = originalNode.graph.getNodeById(originLinkInfo.origin_id); + if (originNode) { + originNode.connect(originLinkInfo.origin_slot, newNode, index); + } + } + } + }); + + // Transfer output connections from old node to new node + originalNode.outputs.forEach((output, index) => { + if (output && output.links) { + output.links.forEach(link => { + const targetLinkInfo = originalNode.graph.links[link]; + if (targetLinkInfo) { + const targetNode = originalNode.graph.getNodeById(targetLinkInfo.target_id); + if (targetNode) { + newNode.connect(index, targetNode, targetLinkInfo.target_slot); + } + } + }); + } + }); + + // Remove the original node after all connections are transferred + originalNode.graph.remove(originalNode); + + return newNode; +} + +ext.loadedGraphNode = function(node, app) { + const originalNode = node; // This line ensures that originalNode refers to the provided node + const kSamplerTypes = [ + "KSampler (Efficient)", + "KSampler Adv. (Efficient)", + "KSampler SDXL (Eff.)" + ]; + + // EFFICIENT LOADER & EFF. LOADER SDXL + /* Changes: + Added "token_normalization" & "weight_interpretation" widget below prompt text boxes, + below code fixes the widget values for empty_latent_width, empty_latent_height, and batch_size + by shifting down by 2 widget values starting from the "token_normalization" widget. + Logic triggers when "token_normalization" is a number instead of a string. + */ + if (node.comfyClass === "Efficient Loader" || node.comfyClass === "Eff. Loader SDXL") { + const tokenWidget = node.widgets.find(w => w.name === "token_normalization"); + const weightWidget = node.widgets.find(w => w.name === "weight_interpretation"); + + if (typeof tokenWidget.value === 'number') { + console.log("[EfficiencyUpdate]", `Fixing '${node.comfyClass}' token and weight widgets:`, node); + const index = node.widgets.indexOf(tokenWidget); + if (index !== -1) { + for (let i = node.widgets.length - 1; i > index + 1; i--) { + node.widgets[i].value = node.widgets[i - 2].value; + } + } + tokenWidget.value = "none"; + weightWidget.value = "comfy"; + } + } + + // KSAMPLER (EFFICIENT), KSAMPLER ADV. (EFFICIENT), & KSAMPLER SDXL (EFF.) + /* Changes: + Removed the "sampler_state" widget which cause all widget values to shift down by a factor of 1. + Fix involves moving all widget values by -1. "vae_decode" value is lost in this process, so in + below fix I manually set it to its default value of "true". + */ + else if (kSamplerTypes.includes(node.comfyClass)) { + + const seedWidgetName = (node.comfyClass === "KSampler (Efficient)") ? "seed" : "noise_seed"; + const stepsWidgetName = (node.comfyClass === "KSampler (Efficient)") ? "steps" : "start_at_step"; + + const seedWidget = node.widgets.find(w => w.name === seedWidgetName); + const stepsWidget = node.widgets.find(w => w.name === stepsWidgetName); + + if (isNaN(seedWidget.value) && isNaN(stepsWidget.value)) { + console.log("[EfficiencyUpdate]", `Fixing '${node.comfyClass}' node widgets:`, node); + for (let i = 0; i < node.widgets.length - 1; i++) { + node.widgets[i].value = node.widgets[i + 1].value; + } + node.widgets[node.widgets.length - 1].value = "true"; + } + } + + // HIGHRES-FIX SCRIPT + /* Changes: + Many new changes where added, so in order to properly update, aquired the values of the original + widgets, reload a new node, transffer the known original values, and transffer connection. + This fix is triggered when the upscale_type widget is neither "latent" or "pixel". + */ + // Check if the current node is "HighRes-Fix Script" and if any of the above fixes were applied + else if (node.comfyClass === "HighRes-Fix Script") { + const upscaleTypeWidget = node.widgets.find(w => w.name === "upscale_type"); + + if (upscaleTypeWidget && upscaleTypeWidget.value !== "latent" && upscaleTypeWidget.value !== "pixel") { + console.log("[EfficiencyUpdate]", "Reloading 'HighRes-Fix Script' node:", node); + + // Reload the node and get the new node instance + const newNode = reloadHiResFixNode(node); + + // Update the widgets of the new node + const targetWidgetNames = ["latent_upscaler", "upscale_by", "hires_steps", "denoise", "iterations"]; + + // Extract the first five values of the original node + const originalValues = originalNode.widgets.slice(0, 5).map(w => w.value); + + targetWidgetNames.forEach((name, index) => { + const widget = newNode.widgets.find(w => w.name === name); + if (widget && originalValues[index] !== undefined) { + if (name === "latent_upscaler" && typeof originalValues[index] === 'string') { + widget.value = originalValues[index].replace("SD-Latent-Upscaler", "city96"); + } else { + widget.value = originalValues[index]; + } + } + }); + } + } +} + +app.registerExtension(ext); \ No newline at end of file diff --git a/py/__init__.py b/py/__init__.py new file mode 100644 index 0000000..4287ca8 --- /dev/null +++ b/py/__init__.py @@ -0,0 +1 @@ +# \ No newline at end of file diff --git a/py/bnk_adv_encode.py b/py/bnk_adv_encode.py new file mode 100644 index 0000000..72d4548 --- /dev/null +++ b/py/bnk_adv_encode.py @@ -0,0 +1,317 @@ +import torch +import numpy as np +import itertools +#from math import gcd + +from comfy import model_management +from comfy.sdxl_clip import SDXLClipModel + +def _grouper(n, iterable): + it = iter(iterable) + while True: + chunk = list(itertools.islice(it, n)) + if not chunk: + return + yield chunk + +def _norm_mag(w, n): + d = w - 1 + return 1 + np.sign(d) * np.sqrt(np.abs(d)**2 / n) + #return np.sign(w) * np.sqrt(np.abs(w)**2 / n) + +def divide_length(word_ids, weights): + sums = dict(zip(*np.unique(word_ids, return_counts=True))) + sums[0] = 1 + weights = [[_norm_mag(w, sums[id]) if id != 0 else 1.0 + for w, id in zip(x, y)] for x, y in zip(weights, word_ids)] + return weights + +def shift_mean_weight(word_ids, weights): + delta = 1 - np.mean([w for x, y in zip(weights, word_ids) for w, id in zip(x,y) if id != 0]) + weights = [[w if id == 0 else w+delta + for w, id in zip(x, y)] for x, y in zip(weights, word_ids)] + return weights + +def scale_to_norm(weights, word_ids, w_max): + top = np.max(weights) + w_max = min(top, w_max) + weights = [[w_max if id == 0 else (w/top) * w_max + for w, id in zip(x, y)] for x, y in zip(weights, word_ids)] + return weights + +def from_zero(weights, base_emb): + weight_tensor = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device) + weight_tensor = weight_tensor.reshape(1,-1,1).expand(base_emb.shape) + return base_emb * weight_tensor + +def mask_word_id(tokens, word_ids, target_id, mask_token): + new_tokens = [[mask_token if wid == target_id else t + for t, wid in zip(x,y)] for x,y in zip(tokens, word_ids)] + mask = np.array(word_ids) == target_id + return (new_tokens, mask) + +def batched_clip_encode(tokens, length, encode_func, num_chunks): + embs = [] + for e in _grouper(32, tokens): + enc, pooled = encode_func(e) + enc = enc.reshape((len(e), length, -1)) + embs.append(enc) + embs = torch.cat(embs) + embs = embs.reshape((len(tokens) // num_chunks, length * num_chunks, -1)) + return embs + +def from_masked(tokens, weights, word_ids, base_emb, length, encode_func, m_token=266): + pooled_base = base_emb[0,length-1:length,:] + wids, inds = np.unique(np.array(word_ids).reshape(-1), return_index=True) + weight_dict = dict((id,w) + for id,w in zip(wids ,np.array(weights).reshape(-1)[inds]) + if w != 1.0) + + if len(weight_dict) == 0: + return torch.zeros_like(base_emb), base_emb[0,length-1:length,:] + + weight_tensor = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device) + weight_tensor = weight_tensor.reshape(1,-1,1).expand(base_emb.shape) + + #m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0) + #TODO: find most suitable masking token here + m_token = (m_token, 1.0) + + ws = [] + masked_tokens = [] + masks = [] + + #create prompts + for id, w in weight_dict.items(): + masked, m = mask_word_id(tokens, word_ids, id, m_token) + masked_tokens.extend(masked) + + m = torch.tensor(m, dtype=base_emb.dtype, device=base_emb.device) + m = m.reshape(1,-1,1).expand(base_emb.shape) + masks.append(m) + + ws.append(w) + + #batch process prompts + embs = batched_clip_encode(masked_tokens, length, encode_func, len(tokens)) + masks = torch.cat(masks) + + embs = (base_emb.expand(embs.shape) - embs) + pooled = embs[0,length-1:length,:] + + embs *= masks + embs = embs.sum(axis=0, keepdim=True) + + pooled_start = pooled_base.expand(len(ws), -1) + ws = torch.tensor(ws).reshape(-1,1).expand(pooled_start.shape) + pooled = (pooled - pooled_start) * (ws - 1) + pooled = pooled.mean(axis=0, keepdim=True) + + return ((weight_tensor - 1) * embs), pooled_base + pooled + +def mask_inds(tokens, inds, mask_token): + clip_len = len(tokens[0]) + inds_set = set(inds) + new_tokens = [[mask_token if i*clip_len + j in inds_set else t + for j, t in enumerate(x)] for i, x in enumerate(tokens)] + return new_tokens + +def down_weight(tokens, weights, word_ids, base_emb, length, encode_func, m_token=266): + w, w_inv = np.unique(weights,return_inverse=True) + + if np.sum(w < 1) == 0: + return base_emb, tokens, base_emb[0,length-1:length,:] + #m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0) + #using the comma token as a masking token seems to work better than aos tokens for SD 1.x + m_token = (m_token, 1.0) + + masked_tokens = [] + + masked_current = tokens + for i in range(len(w)): + if w[i] >= 1: + continue + masked_current = mask_inds(masked_current, np.where(w_inv == i)[0], m_token) + masked_tokens.extend(masked_current) + + embs = batched_clip_encode(masked_tokens, length, encode_func, len(tokens)) + embs = torch.cat([base_emb, embs]) + w = w[w<=1.0] + w_mix = np.diff([0] + w.tolist()) + w_mix = torch.tensor(w_mix, dtype=embs.dtype, device=embs.device).reshape((-1,1,1)) + + weighted_emb = (w_mix * embs).sum(axis=0, keepdim=True) + return weighted_emb, masked_current, weighted_emb[0,length-1:length,:] + +def scale_emb_to_mag(base_emb, weighted_emb): + norm_base = torch.linalg.norm(base_emb) + norm_weighted = torch.linalg.norm(weighted_emb) + embeddings_final = (norm_base / norm_weighted) * weighted_emb + return embeddings_final + +def recover_dist(base_emb, weighted_emb): + fixed_std = (base_emb.std() / weighted_emb.std()) * (weighted_emb - weighted_emb.mean()) + embeddings_final = fixed_std + (base_emb.mean() - fixed_std.mean()) + return embeddings_final + +def A1111_renorm(base_emb, weighted_emb): + embeddings_final = (base_emb.mean() / weighted_emb.mean()) * weighted_emb + return embeddings_final + +def advanced_encode_from_tokens(tokenized, token_normalization, weight_interpretation, encode_func, m_token=266, length=77, w_max=1.0, return_pooled=False, apply_to_pooled=False): + tokens = [[t for t,_,_ in x] for x in tokenized] + weights = [[w for _,w,_ in x] for x in tokenized] + word_ids = [[wid for _,_,wid in x] for x in tokenized] + + #weight normalization + #==================== + + #distribute down/up weights over word lengths + if token_normalization.startswith("length"): + weights = divide_length(word_ids, weights) + + #make mean of word tokens 1 + if token_normalization.endswith("mean"): + weights = shift_mean_weight(word_ids, weights) + + #weight interpretation + #===================== + pooled = None + + if weight_interpretation == "comfy": + weighted_tokens = [[(t,w) for t, w in zip(x, y)] for x, y in zip(tokens, weights)] + weighted_emb, pooled_base = encode_func(weighted_tokens) + pooled = pooled_base + else: + unweighted_tokens = [[(t,1.0) for t, _,_ in x] for x in tokenized] + base_emb, pooled_base = encode_func(unweighted_tokens) + + if weight_interpretation == "A1111": + weighted_emb = from_zero(weights, base_emb) + weighted_emb = A1111_renorm(base_emb, weighted_emb) + pooled = pooled_base + + if weight_interpretation == "compel": + pos_tokens = [[(t,w) if w >= 1.0 else (t,1.0) for t, w in zip(x, y)] for x, y in zip(tokens, weights)] + weighted_emb, _ = encode_func(pos_tokens) + weighted_emb, _, pooled = down_weight(pos_tokens, weights, word_ids, weighted_emb, length, encode_func) + + if weight_interpretation == "comfy++": + weighted_emb, tokens_down, _ = down_weight(unweighted_tokens, weights, word_ids, base_emb, length, encode_func) + weights = [[w if w > 1.0 else 1.0 for w in x] for x in weights] + #unweighted_tokens = [[(t,1.0) for t, _,_ in x] for x in tokens_down] + embs, pooled = from_masked(unweighted_tokens, weights, word_ids, base_emb, length, encode_func) + weighted_emb += embs + + if weight_interpretation == "down_weight": + weights = scale_to_norm(weights, word_ids, w_max) + weighted_emb, _, pooled = down_weight(unweighted_tokens, weights, word_ids, base_emb, length, encode_func) + + if return_pooled: + if apply_to_pooled: + return weighted_emb, pooled + else: + return weighted_emb, pooled_base + return weighted_emb, None + +def encode_token_weights_g(model, token_weight_pairs): + return model.clip_g.encode_token_weights(token_weight_pairs) + +def encode_token_weights_l(model, token_weight_pairs): + l_out, _ = model.clip_l.encode_token_weights(token_weight_pairs) + return l_out, None + +def encode_token_weights(model, token_weight_pairs, encode_func): + if model.layer_idx is not None: + model.cond_stage_model.clip_layer(model.layer_idx) + + model_management.load_model_gpu(model.patcher) + return encode_func(model.cond_stage_model, token_weight_pairs) + +def prepareXL(embs_l, embs_g, pooled, clip_balance): + l_w = 1 - max(0, clip_balance - .5) * 2 + g_w = 1 - max(0, .5 - clip_balance) * 2 + if embs_l is not None: + return torch.cat([embs_l * l_w, embs_g * g_w], dim=-1), pooled + else: + return embs_g, pooled + +def advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5, apply_to_pooled=True): + tokenized = clip.tokenize(text, return_word_ids=True) + if isinstance(tokenized, dict): + embs_l = None + embs_g = None + pooled = None + if 'l' in tokenized and isinstance(clip.cond_stage_model, SDXLClipModel): + embs_l, _ = advanced_encode_from_tokens(tokenized['l'], + token_normalization, + weight_interpretation, + lambda x: encode_token_weights(clip, x, encode_token_weights_l), + w_max=w_max, + return_pooled=False) + if 'g' in tokenized: + embs_g, pooled = advanced_encode_from_tokens(tokenized['g'], + token_normalization, + weight_interpretation, + lambda x: encode_token_weights(clip, x, encode_token_weights_g), + w_max=w_max, + return_pooled=True, + apply_to_pooled=apply_to_pooled) + return prepareXL(embs_l, embs_g, pooled, clip_balance) + else: + return advanced_encode_from_tokens(tokenized, + token_normalization, + weight_interpretation, + lambda x: (clip.encode_from_tokens(x), None), + w_max=w_max) + +######################################################################################################################## +from nodes import MAX_RESOLUTION + +class AdvancedCLIPTextEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "text": ("STRING", {"multiline": True}), + "clip": ("CLIP",), + "token_normalization": (["none", "mean", "length", "length+mean"],), + "weight_interpretation": (["comfy", "A1111", "compel", "comfy++", "down_weight"],), + # "affect_pooled": (["disable", "enable"],), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "conditioning/advanced" + + def encode(self, clip, text, token_normalization, weight_interpretation, affect_pooled='disable'): + embeddings_final, pooled = advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0, + apply_to_pooled=affect_pooled == 'enable') + return ([[embeddings_final, {"pooled_output": pooled}]],) + + +class AddCLIPSDXLRParams: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "conditioning": ("CONDITIONING",), + "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + + CATEGORY = "conditioning/advanced" + + def encode(self, conditioning, width, height, ascore): + c = [] + for t in conditioning: + n = [t[0], t[1].copy()] + n[1]['width'] = width + n[1]['height'] = height + n[1]['aesthetic_score'] = ascore + c.append(n) + return (c,) + diff --git a/py/bnk_tiled_samplers.py b/py/bnk_tiled_samplers.py new file mode 100644 index 0000000..9b48568 --- /dev/null +++ b/py/bnk_tiled_samplers.py @@ -0,0 +1,523 @@ +# https://github.com/BlenderNeko/ComfyUI_TiledKSampler +import sys +import os +import itertools +import numpy as np +from tqdm.auto import tqdm +import torch + +#sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) +import comfy.sd +import comfy.controlnet +import comfy.model_management +import comfy.sample +#from . import tiling +import latent_preview +#import torch +#import itertools +#import numpy as np +MAX_RESOLUTION=8192 + +def grouper(n, iterable): + it = iter(iterable) + while True: + chunk = list(itertools.islice(it, n)) + if not chunk: + return + yield chunk + + +def create_batches(n, iterable): + groups = itertools.groupby(iterable, key=lambda x: (x[1], x[3])) + for _, x in groups: + for y in grouper(n, x): + yield y + + +def get_slice(tensor, h, h_len, w, w_len): + t = tensor.narrow(-2, h, h_len) + t = t.narrow(-1, w, w_len) + return t + + +def set_slice(tensor1, tensor2, h, h_len, w, w_len, mask=None): + if mask is not None: + tensor1[:, :, h:h + h_len, w:w + w_len] = tensor1[:, :, h:h + h_len, w:w + w_len] * (1 - mask) + tensor2 * mask + else: + tensor1[:, :, h:h + h_len, w:w + w_len] = tensor2 + + +def get_tiles_and_masks_simple(steps, latent_shape, tile_height, tile_width): + latent_size_h = latent_shape[-2] + latent_size_w = latent_shape[-1] + tile_size_h = int(tile_height // 8) + tile_size_w = int(tile_width // 8) + + h = np.arange(0, latent_size_h, tile_size_h) + w = np.arange(0, latent_size_w, tile_size_w) + + def create_tile(hs, ws, i, j): + h = int(hs[i]) + w = int(ws[j]) + h_len = min(tile_size_h, latent_size_h - h) + w_len = min(tile_size_w, latent_size_w - w) + return (h, h_len, w, w_len, steps, None) + + passes = [ + [[create_tile(h, w, i, j) for i in range(len(h)) for j in range(len(w))]], + ] + return passes + + +def get_tiles_and_masks_padded(steps, latent_shape, tile_height, tile_width): + batch_size = latent_shape[0] + latent_size_h = latent_shape[-2] + latent_size_w = latent_shape[-1] + + tile_size_h = int(tile_height // 8) + tile_size_h = int((tile_size_h // 4) * 4) + tile_size_w = int(tile_width // 8) + tile_size_w = int((tile_size_w // 4) * 4) + + # masks + mask_h = [0, tile_size_h // 4, tile_size_h - tile_size_h // 4, tile_size_h] + mask_w = [0, tile_size_w // 4, tile_size_w - tile_size_w // 4, tile_size_w] + masks = [[] for _ in range(3)] + for i in range(3): + for j in range(3): + mask = torch.zeros((batch_size, 1, tile_size_h, tile_size_w), dtype=torch.float32, device='cpu') + mask[:, :, mask_h[i]:mask_h[i + 1], mask_w[j]:mask_w[j + 1]] = 1.0 + masks[i].append(mask) + + def create_mask(h_ind, w_ind, h_ind_max, w_ind_max, mask_h, mask_w, h_len, w_len): + mask = masks[1][1] + if not (h_ind == 0 or h_ind == h_ind_max or w_ind == 0 or w_ind == w_ind_max): + return get_slice(mask, 0, h_len, 0, w_len) + mask = mask.clone() + if h_ind == 0 and mask_h: + mask += masks[0][1] + if h_ind == h_ind_max and mask_h: + mask += masks[2][1] + if w_ind == 0 and mask_w: + mask += masks[1][0] + if w_ind == w_ind_max and mask_w: + mask += masks[1][2] + if h_ind == 0 and w_ind == 0 and mask_h and mask_w: + mask += masks[0][0] + if h_ind == 0 and w_ind == w_ind_max and mask_h and mask_w: + mask += masks[0][2] + if h_ind == h_ind_max and w_ind == 0 and mask_h and mask_w: + mask += masks[2][0] + if h_ind == h_ind_max and w_ind == w_ind_max and mask_h and mask_w: + mask += masks[2][2] + return get_slice(mask, 0, h_len, 0, w_len) + + h = np.arange(0, latent_size_h, tile_size_h) + h_shift = np.arange(tile_size_h // 2, latent_size_h - tile_size_h // 2, tile_size_h) + w = np.arange(0, latent_size_w, tile_size_w) + w_shift = np.arange(tile_size_w // 2, latent_size_w - tile_size_h // 2, tile_size_w) + + def create_tile(hs, ws, mask_h, mask_w, i, j): + h = int(hs[i]) + w = int(ws[j]) + h_len = min(tile_size_h, latent_size_h - h) + w_len = min(tile_size_w, latent_size_w - w) + mask = create_mask(i, j, len(hs) - 1, len(ws) - 1, mask_h, mask_w, h_len, w_len) + return (h, h_len, w, w_len, steps, mask) + + passes = [ + [[create_tile(h, w, True, True, i, j) for i in range(len(h)) for j in range(len(w))]], + [[create_tile(h_shift, w, False, True, i, j) for i in range(len(h_shift)) for j in range(len(w))]], + [[create_tile(h, w_shift, True, False, i, j) for i in range(len(h)) for j in range(len(w_shift))]], + [[create_tile(h_shift, w_shift, False, False, i, j) for i in range(len(h_shift)) for j in range(len(w_shift))]], + ] + + return passes + + +def mask_at_boundary(h, h_len, w, w_len, tile_size_h, tile_size_w, latent_size_h, latent_size_w, mask, device='cpu'): + tile_size_h = int(tile_size_h // 8) + tile_size_w = int(tile_size_w // 8) + + if (h_len == tile_size_h or h_len == latent_size_h) and (w_len == tile_size_w or w_len == latent_size_w): + return h, h_len, w, w_len, mask + h_offset = min(0, latent_size_h - (h + tile_size_h)) + w_offset = min(0, latent_size_w - (w + tile_size_w)) + new_mask = torch.zeros((1, 1, tile_size_h, tile_size_w), dtype=torch.float32, device=device) + new_mask[:, :, -h_offset:h_len if h_offset == 0 else tile_size_h, + -w_offset:w_len if w_offset == 0 else tile_size_w] = 1.0 if mask is None else mask + return h + h_offset, tile_size_h, w + w_offset, tile_size_w, new_mask + + +def get_tiles_and_masks_rgrid(steps, latent_shape, tile_height, tile_width, generator): + def calc_coords(latent_size, tile_size, jitter): + tile_coords = int((latent_size + jitter - 1) // tile_size + 1) + tile_coords = [np.clip(tile_size * c - jitter, 0, latent_size) for c in range(tile_coords + 1)] + tile_coords = [(c1, c2 - c1) for c1, c2 in zip(tile_coords, tile_coords[1:])] + return tile_coords + + # calc stuff + batch_size = latent_shape[0] + latent_size_h = latent_shape[-2] + latent_size_w = latent_shape[-1] + tile_size_h = int(tile_height // 8) + tile_size_w = int(tile_width // 8) + + tiles_all = [] + + for s in range(steps): + rands = torch.rand((2,), dtype=torch.float32, generator=generator, device='cpu').numpy() + + jitter_w1 = int(rands[0] * tile_size_w) + jitter_w2 = int(((rands[0] + .5) % 1.0) * tile_size_w) + jitter_h1 = int(rands[1] * tile_size_h) + jitter_h2 = int(((rands[1] + .5) % 1.0) * tile_size_h) + + # calc number of tiles + tiles_h = [ + calc_coords(latent_size_h, tile_size_h, jitter_h1), + calc_coords(latent_size_h, tile_size_h, jitter_h2) + ] + tiles_w = [ + calc_coords(latent_size_w, tile_size_w, jitter_w1), + calc_coords(latent_size_w, tile_size_w, jitter_w2) + ] + + tiles = [] + if s % 2 == 0: + for i, h in enumerate(tiles_h[0]): + for w in tiles_w[i % 2]: + tiles.append((int(h[0]), int(h[1]), int(w[0]), int(w[1]), 1, None)) + else: + for i, w in enumerate(tiles_w[0]): + for h in tiles_h[i % 2]: + tiles.append((int(h[0]), int(h[1]), int(w[0]), int(w[1]), 1, None)) + tiles_all.append(tiles) + return [tiles_all] + +####################### + +def recursion_to_list(obj, attr): + current = obj + yield current + while True: + current = getattr(current, attr, None) + if current is not None: + yield current + else: + return + +def copy_cond(cond): + return [[c1,c2.copy()] for c1,c2 in cond] + +def slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, cond, area): + tile_h_end = tile_h + tile_h_len + tile_w_end = tile_w + tile_w_len + coords = area[0] #h_len, w_len, h, w, + mask = area[1] + if coords is not None: + h_len, w_len, h, w = coords + h_end = h + h_len + w_end = w + w_len + if h < tile_h_end and h_end > tile_h and w < tile_w_end and w_end > tile_w: + new_h = max(0, h - tile_h) + new_w = max(0, w - tile_w) + new_h_end = min(tile_h_end, h_end - tile_h) + new_w_end = min(tile_w_end, w_end - tile_w) + cond[1]['area'] = (new_h_end - new_h, new_w_end - new_w, new_h, new_w) + else: + return (cond, True) + if mask is not None: + new_mask = get_slice(mask, tile_h,tile_h_len,tile_w,tile_w_len) + if new_mask.sum().cpu() == 0.0 and 'mask' in cond[1]: + return (cond, True) + else: + cond[1]['mask'] = new_mask + return (cond, False) + +def slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen): + tile_h_end = tile_h + tile_h_len + tile_w_end = tile_w + tile_w_len + if gligen is None: + return + gligen_type = gligen[0] + gligen_model = gligen[1] + gligen_areas = gligen[2] + + gligen_areas_new = [] + for emb, h_len, w_len, h, w in gligen_areas: + h_end = h + h_len + w_end = w + w_len + if h < tile_h_end and h_end > tile_h and w < tile_w_end and w_end > tile_w: + new_h = max(0, h - tile_h) + new_w = max(0, w - tile_w) + new_h_end = min(tile_h_end, h_end - tile_h) + new_w_end = min(tile_w_end, w_end - tile_w) + gligen_areas_new.append((emb, new_h_end - new_h, new_w_end - new_w, new_h, new_w)) + + if len(gligen_areas_new) == 0: + del cond['gligen'] + else: + cond['gligen'] = (gligen_type, gligen_model, gligen_areas_new) + +def slice_cnet(h, h_len, w, w_len, model:comfy.controlnet.ControlBase, img): + if img is None: + img = model.cond_hint_original + model.cond_hint = get_slice(img, h*8, h_len*8, w*8, w_len*8).to(model.control_model.dtype).to(model.device) + +def slices_T2I(h, h_len, w, w_len, model:comfy.controlnet.ControlBase, img): + model.control_input = None + if img is None: + img = model.cond_hint_original + model.cond_hint = get_slice(img, h*8, h_len*8, w*8, w_len*8).float().to(model.device) + +# TODO: refactor some of the mess + +from PIL import Image + +def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_strategy, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, preview=False): + end_at_step = min(end_at_step, steps) + device = comfy.model_management.get_torch_device() + samples = latent_image["samples"] + noise_mask = latent_image["noise_mask"] if "noise_mask" in latent_image else None + force_full_denoise = return_with_leftover_noise == "enable" + if add_noise == "disable": + noise = torch.zeros(samples.size(), dtype=samples.dtype, layout=samples.layout, device="cpu") + else: + skip = latent_image["batch_index"] if "batch_index" in latent_image else None + noise = comfy.sample.prepare_noise(samples, noise_seed, skip) + + if noise_mask is not None: + noise_mask = comfy.sample.prepare_mask(noise_mask, noise.shape, device='cpu') + + shape = samples.shape + samples = samples.clone() + + tile_width = min(shape[-1] * 8, tile_width) + tile_height = min(shape[2] * 8, tile_height) + + real_model = None + modelPatches, inference_memory = comfy.sample.get_additional_models(positive, negative, model.model_dtype()) + comfy.model_management.load_models_gpu([model] + modelPatches, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory) + real_model = model.model + + sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) + + if tiling_strategy != 'padded': + if noise_mask is not None: + samples += sampler.sigmas[start_at_step].cpu() * noise_mask * model.model.process_latent_out(noise).cpu() + else: + samples += sampler.sigmas[start_at_step].cpu() * model.model.process_latent_out(noise).cpu() + + #cnets + cnets = comfy.sample.get_models_from_cond(positive, 'control') + comfy.sample.get_models_from_cond(negative, 'control') + cnets = [m for m in cnets if isinstance(m, comfy.controlnet.ControlNet)] + cnets = list(set([x for m in cnets for x in recursion_to_list(m, "previous_controlnet")])) + cnet_imgs = [ + torch.nn.functional.interpolate(m.cond_hint_original, (shape[-2] * 8, shape[-1] * 8), mode='nearest-exact').to('cpu') + if m.cond_hint_original.shape[-2] != shape[-2] * 8 or m.cond_hint_original.shape[-1] != shape[-1] * 8 else None + for m in cnets] + + #T2I + T2Is = comfy.sample.get_models_from_cond(positive, 'control') + comfy.sample.get_models_from_cond(negative, 'control') + T2Is = [m for m in T2Is if isinstance(m, comfy.controlnet.T2IAdapter)] + T2Is = [x for m in T2Is for x in recursion_to_list(m, "previous_controlnet")] + T2I_imgs = [ + torch.nn.functional.interpolate(m.cond_hint_original, (shape[-2] * 8, shape[-1] * 8), mode='nearest-exact').to('cpu') + if m.cond_hint_original.shape[-2] != shape[-2] * 8 or m.cond_hint_original.shape[-1] != shape[-1] * 8 or (m.channels_in == 1 and m.cond_hint_original.shape[1] != 1) else None + for m in T2Is + ] + T2I_imgs = [ + torch.mean(img, 1, keepdim=True) if img is not None and m.channels_in == 1 and m.cond_hint_original.shape[1] else img + for m, img in zip(T2Is, T2I_imgs) + ] + + #cond area and mask + spatial_conds_pos = [ + (c[1]['area'] if 'area' in c[1] else None, + comfy.sample.prepare_mask(c[1]['mask'], shape, device) if 'mask' in c[1] else None) + for c in positive + ] + spatial_conds_neg = [ + (c[1]['area'] if 'area' in c[1] else None, + comfy.sample.prepare_mask(c[1]['mask'], shape, device) if 'mask' in c[1] else None) + for c in negative + ] + + #gligen + gligen_pos = [ + c[1]['gligen'] if 'gligen' in c[1] else None + for c in positive + ] + gligen_neg = [ + c[1]['gligen'] if 'gligen' in c[1] else None + for c in negative + ] + + positive_copy = comfy.sample.broadcast_cond(positive, shape[0], device) + negative_copy = comfy.sample.broadcast_cond(negative, shape[0], device) + + gen = torch.manual_seed(noise_seed) + if tiling_strategy == 'random' or tiling_strategy == 'random strict': + tiles = get_tiles_and_masks_rgrid(end_at_step - start_at_step, samples.shape, tile_height, tile_width, gen) + elif tiling_strategy == 'padded': + tiles = get_tiles_and_masks_padded(end_at_step - start_at_step, samples.shape, tile_height, tile_width) + else: + tiles = get_tiles_and_masks_simple(end_at_step - start_at_step, samples.shape, tile_height, tile_width) + + total_steps = sum([num_steps for img_pass in tiles for steps_list in img_pass for _,_,_,_,num_steps,_ in steps_list]) + current_step = [0] + + preview_format = "JPEG" + if preview_format not in ["JPEG", "PNG"]: + preview_format = "JPEG" + previewer = None + if preview: + previewer = latent_preview.get_previewer(device, model.model.latent_format) + + + with tqdm(total=total_steps) as pbar_tqdm: + pbar = comfy.utils.ProgressBar(total_steps) + + def callback(step, x0, x, total_steps): + current_step[0] += 1 + preview_bytes = None + if previewer: + preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0) + pbar.update_absolute(current_step[0], preview=preview_bytes) + pbar_tqdm.update(1) + + if tiling_strategy == "random strict": + samples_next = samples.clone() + for img_pass in tiles: + for i in range(len(img_pass)): + for tile_h, tile_h_len, tile_w, tile_w_len, tile_steps, tile_mask in img_pass[i]: + tiled_mask = None + if noise_mask is not None: + tiled_mask = get_slice(noise_mask, tile_h, tile_h_len, tile_w, tile_w_len).to(device) + if tile_mask is not None: + if tiled_mask is not None: + tiled_mask *= tile_mask.to(device) + else: + tiled_mask = tile_mask.to(device) + + if tiling_strategy == 'padded' or tiling_strategy == 'random strict': + tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask = mask_at_boundary( tile_h, tile_h_len, tile_w, tile_w_len, + tile_height, tile_width, samples.shape[-2], samples.shape[-1], + tiled_mask, device) + + + if tiled_mask is not None and tiled_mask.sum().cpu() == 0.0: + continue + + tiled_latent = get_slice(samples, tile_h, tile_h_len, tile_w, tile_w_len).to(device) + + if tiling_strategy == 'padded': + tiled_noise = get_slice(noise, tile_h, tile_h_len, tile_w, tile_w_len).to(device) + else: + if tiled_mask is None or noise_mask is None: + tiled_noise = torch.zeros_like(tiled_latent) + else: + tiled_noise = get_slice(noise, tile_h, tile_h_len, tile_w, tile_w_len).to(device) * (1 - tiled_mask) + + #TODO: all other condition based stuff like area sets and GLIGEN should also happen here + + #cnets + for m, img in zip(cnets, cnet_imgs): + slice_cnet(tile_h, tile_h_len, tile_w, tile_w_len, m, img) + + #T2I + for m, img in zip(T2Is, T2I_imgs): + slices_T2I(tile_h, tile_h_len, tile_w, tile_w_len, m, img) + + pos = copy_cond(positive_copy) + neg = copy_cond(negative_copy) + + #cond areas + pos = [slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, c, area) for c, area in zip(pos, spatial_conds_pos)] + pos = [c for c, ignore in pos if not ignore] + neg = [slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, c, area) for c, area in zip(neg, spatial_conds_neg)] + neg = [c for c, ignore in neg if not ignore] + + #gligen + for (_, cond), gligen in zip(pos, gligen_pos): + slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen) + for (_, cond), gligen in zip(neg, gligen_neg): + slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen) + + tile_result = sampler.sample(tiled_noise, pos, neg, cfg=cfg, latent_image=tiled_latent, start_step=start_at_step + i * tile_steps, last_step=start_at_step + i*tile_steps + tile_steps, force_full_denoise=force_full_denoise and i+1 == end_at_step - start_at_step, denoise_mask=tiled_mask, callback=callback, disable_pbar=True, seed=noise_seed) + tile_result = tile_result.cpu() + if tiled_mask is not None: + tiled_mask = tiled_mask.cpu() + if tiling_strategy == "random strict": + set_slice(samples_next, tile_result, tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask) + else: + set_slice(samples, tile_result, tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask) + if tiling_strategy == "random strict": + samples = samples_next.clone() + + + comfy.sample.cleanup_additional_models(modelPatches) + + out = latent_image.copy() + out["samples"] = samples.cpu() + return (out, ) + +class TiledKSampler: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "tile_width": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}), + "tile_height": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}), + "tiling_strategy": (["random", "random strict", "padded", 'simple'], ), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}), + "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), + "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "latent_image": ("LATENT", ), + "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "sample" + + CATEGORY = "sampling" + + def sample(self, model, seed, tile_width, tile_height, tiling_strategy, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise): + steps_total = int(steps / denoise) + return sample_common(model, 'enable', seed, tile_width, tile_height, tiling_strategy, steps_total, cfg, sampler_name, scheduler, positive, negative, latent_image, steps_total-steps, steps_total, 'disable', denoise=1.0, preview=True) + +class TiledKSamplerAdvanced: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"model": ("MODEL",), + "add_noise": (["enable", "disable"], ), + "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + "tile_width": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}), + "tile_height": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}), + "tiling_strategy": (["random", "random strict", "padded", 'simple'], ), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}), + "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), + "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "latent_image": ("LATENT", ), + "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}), + "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}), + "return_with_leftover_noise": (["disable", "enable"], ), + "preview": (["disable", "enable"], ), + }} + + RETURN_TYPES = ("LATENT",) + FUNCTION = "sample" + + CATEGORY = "sampling" + + def sample(self, model, add_noise, noise_seed, tile_width, tile_height, tiling_strategy, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, preview, denoise=1.0): + return sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_strategy, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, preview= preview == 'enable') diff --git a/py/cg_mixed_seed_noise.py b/py/cg_mixed_seed_noise.py new file mode 100644 index 0000000..6c4c508 --- /dev/null +++ b/py/cg_mixed_seed_noise.py @@ -0,0 +1,16 @@ +# https://github.com/chrisgoringe/cg-noise +import torch + +def get_mixed_noise_function(original_noise_function, variation_seed, variation_weight): + def prepare_mixed_noise(latent_image:torch.Tensor, seed, batch_inds): + single_image_latent = latent_image[0].unsqueeze_(0) + different_noise = original_noise_function(single_image_latent, variation_seed, batch_inds) + original_noise = original_noise_function(single_image_latent, seed, batch_inds) + if latent_image.shape[0]==1: + mixed_noise = original_noise * (1.0-variation_weight) + different_noise * (variation_weight) + else: + mixed_noise = torch.empty_like(latent_image) + for i in range(latent_image.shape[0]): + mixed_noise[i] = original_noise * (1.0-variation_weight*i) + different_noise * (variation_weight*i) + return mixed_noise + return prepare_mixed_noise diff --git a/py/city96_latent_upscaler.py b/py/city96_latent_upscaler.py new file mode 100644 index 0000000..1d0b9f4 --- /dev/null +++ b/py/city96_latent_upscaler.py @@ -0,0 +1,82 @@ +# https://github.com/city96/SD-Latent-Upscaler +import torch +import torch.nn as nn +from safetensors.torch import load_file +from huggingface_hub import hf_hub_download + +class Upscaler(nn.Module): + """ + Basic NN layout, ported from: + https://github.com/city96/SD-Latent-Upscaler/blob/main/upscaler.py + """ + version = 2.1 # network revision + def head(self): + return [ + nn.Conv2d(self.chan, self.size, kernel_size=self.krn, padding=self.pad), + nn.ReLU(), + nn.Upsample(scale_factor=self.fac, mode="nearest"), + nn.ReLU(), + ] + def core(self): + layers = [] + for _ in range(self.depth): + layers += [ + nn.Conv2d(self.size, self.size, kernel_size=self.krn, padding=self.pad), + nn.ReLU(), + ] + return layers + def tail(self): + return [ + nn.Conv2d(self.size, self.chan, kernel_size=self.krn, padding=self.pad), + ] + + def __init__(self, fac, depth=16): + super().__init__() + self.size = 64 # Conv2d size + self.chan = 4 # in/out channels + self.depth = depth # no. of layers + self.fac = fac # scale factor + self.krn = 3 # kernel size + self.pad = 1 # padding + + self.sequential = nn.Sequential( + *self.head(), + *self.core(), + *self.tail(), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.sequential(x) + + +class LatentUpscaler: + def __init__(self): + pass + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "samples": ("LATENT", ), + "latent_ver": (["v1", "xl"],), + "scale_factor": (["1.25", "1.5", "2.0"],), + } + } + + RETURN_TYPES = ("LATENT",) + FUNCTION = "upscale" + CATEGORY = "latent" + + def upscale(self, samples, latent_ver, scale_factor): + model = Upscaler(scale_factor) + weights = str(hf_hub_download( + repo_id="city96/SD-Latent-Upscaler", + filename=f"latent-upscaler-v{model.version}_SD{latent_ver}-x{scale_factor}.safetensors") + ) + # weights = f"./latent-upscaler-v{model.version}_SD{latent_ver}-x{scale_factor}.safetensors" + + model.load_state_dict(load_file(weights)) + lt = samples["samples"] + lt = model(lt) + del model + return ({"samples": lt},) \ No newline at end of file diff --git a/py/sd15_resizer.pt b/py/sd15_resizer.pt new file mode 100644 index 0000000..f5c9192 Binary files /dev/null and b/py/sd15_resizer.pt differ diff --git a/py/sdxl_resizer.pt b/py/sdxl_resizer.pt new file mode 100644 index 0000000..a611e85 Binary files /dev/null and b/py/sdxl_resizer.pt differ diff --git a/py/smZ_cfg_denoiser.py b/py/smZ_cfg_denoiser.py new file mode 100644 index 0000000..e6a4ce1 --- /dev/null +++ b/py/smZ_cfg_denoiser.py @@ -0,0 +1,321 @@ +# https://github.com/shiimizu/ComfyUI_smZNodes +import comfy +import torch +from typing import List +import comfy.sample +from comfy import model_base, model_management +from comfy.samplers import KSampler, CompVisVDenoiser, KSamplerX0Inpaint +from comfy.k_diffusion.external import CompVisDenoiser +import nodes +import inspect +import functools +import importlib +import os +import re +import itertools + +import torch +from comfy import model_management + +def catenate_conds(conds): + if not isinstance(conds[0], dict): + return torch.cat(conds) + + return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} + + +def subscript_cond(cond, a, b): + if not isinstance(cond, dict): + return cond[a:b] + + return {key: vec[a:b] for key, vec in cond.items()} + + +def pad_cond(tensor, repeats, empty): + if not isinstance(tensor, dict): + return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1)).to(device=tensor.device)], axis=1) + + tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) + return tensor + + +class CFGDenoiser(torch.nn.Module): + """ + Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) + that can take a noisy picture and produce a noise-free picture using two guidances (prompts) + instead of one. Originally, the second prompt is just an empty string, but we use non-empty + negative prompt. + """ + + def __init__(self, model): + super().__init__() + self.inner_model = model + self.model_wrap = None + self.mask = None + self.nmask = None + self.init_latent = None + self.steps = None + """number of steps as specified by user in UI""" + + self.total_steps = None + """expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler""" + + self.step = 0 + self.image_cfg_scale = None + self.padded_cond_uncond = False + self.sampler = None + self.model_wrap = None + self.p = None + self.mask_before_denoising = False + + + def combine_denoised(self, x_out, conds_list, uncond, cond_scale): + denoised_uncond = x_out[-uncond.shape[0]:] + denoised = torch.clone(denoised_uncond) + + for i, conds in enumerate(conds_list): + for cond_index, weight in conds: + denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) + + return denoised + + def combine_denoised_for_edit_model(self, x_out, cond_scale): + out_cond, out_img_cond, out_uncond = x_out.chunk(3) + denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond) + + return denoised + + def get_pred_x0(self, x_in, x_out, sigma): + return x_out + + def update_inner_model(self): + self.model_wrap = None + + c, uc = self.p.get_conds() + self.sampler.sampler_extra_args['cond'] = c + self.sampler.sampler_extra_args['uncond'] = uc + + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): + model_management.throw_exception_if_processing_interrupted() + + is_edit_model = False + + conds_list, tensor = cond + assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + + if self.mask_before_denoising and self.mask is not None: + x = self.init_latent * self.mask + self.nmask * x + + batch_size = len(conds_list) + repeats = [len(conds_list[i]) for i in range(batch_size)] + + if False: + image_uncond = torch.zeros_like(image_cond) + make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm, 'transformer_options': {'from_smZ': True}} # pylint: disable=C3001 + else: + image_uncond = image_cond + if isinstance(uncond, dict): + make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": None, "c_adm": x.c_adm, 'transformer_options': {'from_smZ': True}} + else: + make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": None, "c_adm": x.c_adm, 'transformer_options': {'from_smZ': True}} + + if not is_edit_model: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond]) + else: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) + + skip_uncond = False + + # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it + if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model: + skip_uncond = True + x_in = x_in[:-batch_size] + sigma_in = sigma_in[:-batch_size] + + if tensor.shape[1] == uncond.shape[1] or skip_uncond: + if is_edit_model: + cond_in = catenate_conds([tensor, uncond, uncond]) + elif skip_uncond: + cond_in = tensor + else: + cond_in = catenate_conds([tensor, uncond]) + + x_out = torch.zeros_like(x_in) + for batch_offset in range(0, x_out.shape[0], batch_size): + a = batch_offset + b = a + batch_size + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], **make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) + else: + x_out = torch.zeros_like(x_in) + for batch_offset in range(0, tensor.shape[0], batch_size): + a = batch_offset + b = min(a + batch_size, tensor.shape[0]) + + if not is_edit_model: + c_crossattn = subscript_cond(tensor, a, b) + else: + c_crossattn = torch.cat([tensor[a:b]], uncond) + + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], **make_condition_dict(c_crossattn, image_cond_in[a:b])) + + if not skip_uncond: + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], **make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) + + denoised_image_indexes = [x[0][0] for x in conds_list] + if skip_uncond: + fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) + x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be + + if is_edit_model: + denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) + elif skip_uncond: + denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0) + else: + denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + + if not self.mask_before_denoising and self.mask is not None: + denoised = self.init_latent * self.mask + self.nmask * denoised + + self.step += 1 + del x_out + return denoised + +# ======================================================================== + +def expand(tensor1, tensor2): + def adjust_tensor_shape(tensor_small, tensor_big): + # Calculate replication factor + # -(-a // b) is ceiling of division without importing math.ceil + replication_factor = -(-tensor_big.size(1) // tensor_small.size(1)) + + # Use repeat to extend tensor_small + tensor_small_extended = tensor_small.repeat(1, replication_factor, 1) + + # Take the rows of the extended tensor_small to match tensor_big + tensor_small_matched = tensor_small_extended[:, :tensor_big.size(1), :] + + return tensor_small_matched + + # Check if their second dimensions are different + if tensor1.size(1) != tensor2.size(1): + # Check which tensor has the smaller second dimension and adjust its shape + if tensor1.size(1) < tensor2.size(1): + tensor1 = adjust_tensor_shape(tensor1, tensor2) + else: + tensor2 = adjust_tensor_shape(tensor2, tensor1) + return (tensor1, tensor2) + +def _find_outer_instance(target, target_type): + import inspect + frame = inspect.currentframe() + while frame: + if target in frame.f_locals: + found = frame.f_locals[target] + if isinstance(found, target_type) and found != 1: # steps == 1 + return found + frame = frame.f_back + return None + +# ======================================================================== +def bounded_modulo(number, modulo_value): + return number if number < modulo_value else modulo_value + +def calc_cond(c, current_step): + """Group by smZ conds that may do prompt-editing / regular conds / comfy conds.""" + _cond = [] + # Group by conds from smZ + fn=lambda x : x[1].get("from_smZ", None) is not None + an_iterator = itertools.groupby(c, fn ) + for key, group in an_iterator: + ls=list(group) + # Group by prompt-editing conds + fn2=lambda x : x[1].get("smZid", None) + an_iterator2 = itertools.groupby(ls, fn2) + for key2, group2 in an_iterator2: + ls2=list(group2) + if key2 is not None: + orig_len = ls2[0][1].get('orig_len', 1) + i = bounded_modulo(current_step, orig_len - 1) + _cond = _cond + [ls2[i]] + else: + _cond = _cond + ls2 + return _cond + +class CFGNoisePredictor(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.ksampler = _find_outer_instance('self', comfy.samplers.KSampler) + self.step = 0 + self.orig = comfy.samplers.CFGNoisePredictor(model) #CFGNoisePredictorOrig(model) + self.inner_model = model + self.inner_model2 = CFGDenoiser(model.apply_model) + self.inner_model2.num_timesteps = model.num_timesteps + self.inner_model2.device = self.ksampler.device if hasattr(self.ksampler, "device") else None + self.s_min_uncond = 0.0 + self.alphas_cumprod = model.alphas_cumprod + self.c_adm = None + self.init_cond = None + self.init_uncond = None + self.is_prompt_editing_u = False + self.is_prompt_editing_c = False + + def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}, seed=None): + + cc=calc_cond(cond, self.step) + uu=calc_cond(uncond, self.step) + self.step += 1 + + if (any([p[1].get('from_smZ', False) for p in cc]) or + any([p[1].get('from_smZ', False) for p in uu])): + if model_options.get('transformer_options',None) is None: + model_options['transformer_options'] = {} + model_options['transformer_options']['from_smZ'] = True + + # Only supports one cond + for ix in range(len(cc)): + if cc[ix][1].get('from_smZ', False): + cc = [cc[ix]] + break + for ix in range(len(uu)): + if uu[ix][1].get('from_smZ', False): + uu = [uu[ix]] + break + c=cc[0][1] + u=uu[0][1] + _cc = cc[0][0] + _uu = uu[0][0] + if c.get("adm_encoded", None) is not None: + self.c_adm = torch.cat([c['adm_encoded'], u['adm_encoded']]) + # SDXL. Need to pad with repeats + _cc, _uu = expand(_cc, _uu) + _uu, _cc = expand(_uu, _cc) + x.c_adm = self.c_adm + conds_list = c.get('conds_list', [[(0, 1.0)]]) + image_cond = txt2img_image_conditioning(None, x) + out = self.inner_model2(x, timestep, cond=(conds_list, _cc), uncond=_uu, cond_scale=cond_scale, s_min_uncond=self.s_min_uncond, image_cond=image_cond) + return out + +def txt2img_image_conditioning(sd_model, x, width=None, height=None): + return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) + +# ======================================================================================= + +def set_model_k(self: KSampler): + self.model_denoise = CFGNoisePredictor(self.model) # main change + if ((getattr(self.model, "parameterization", "") == "v") or + (getattr(self.model, "model_type", -1) == model_base.ModelType.V_PREDICTION)): + self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True) + self.model_wrap.parameterization = getattr(self.model, "parameterization", "v") + else: + self.model_wrap = CompVisDenoiser(self.model_denoise, quantize=True) + self.model_wrap.parameterization = getattr(self.model, "parameterization", "eps") + self.model_k = KSamplerX0Inpaint(self.model_wrap) + +class SDKSampler(comfy.samplers.KSampler): + def __init__(self, *args, **kwargs): + super(SDKSampler, self).__init__(*args, **kwargs) + set_model_k(self) \ No newline at end of file diff --git a/py/smZ_rng_source.py b/py/smZ_rng_source.py new file mode 100644 index 0000000..e6105a9 --- /dev/null +++ b/py/smZ_rng_source.py @@ -0,0 +1,140 @@ +# https://github.com/shiimizu/ComfyUI_smZNodes +import numpy as np + +philox_m = [0xD2511F53, 0xCD9E8D57] +philox_w = [0x9E3779B9, 0xBB67AE85] + +two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32) +two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32) + + +def uint32(x): + """Converts (N,) np.uint64 array into (2, N) np.unit32 array.""" + return x.view(np.uint32).reshape(-1, 2).transpose(1, 0) + + +def philox4_round(counter, key): + """A single round of the Philox 4x32 random number generator.""" + + v1 = uint32(counter[0].astype(np.uint64) * philox_m[0]) + v2 = uint32(counter[2].astype(np.uint64) * philox_m[1]) + + counter[0] = v2[1] ^ counter[1] ^ key[0] + counter[1] = v2[0] + counter[2] = v1[1] ^ counter[3] ^ key[1] + counter[3] = v1[0] + + +def philox4_32(counter, key, rounds=10): + """Generates 32-bit random numbers using the Philox 4x32 random number generator. + + Parameters: + counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation). + key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed). + rounds (int): The number of rounds to perform. + + Returns: + numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers. + """ + + for _ in range(rounds - 1): + philox4_round(counter, key) + + key[0] = key[0] + philox_w[0] + key[1] = key[1] + philox_w[1] + + philox4_round(counter, key) + return counter + + +def box_muller(x, y): + """Returns just the first out of two numbers generated by Box–Muller transform algorithm.""" + u = x * two_pow32_inv + two_pow32_inv / 2 + v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2 + + s = np.sqrt(-2.0 * np.log(u)) + + r1 = s * np.sin(v) + return r1.astype(np.float32) + + +class Generator: + """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU""" + + def __init__(self, seed): + self.seed = seed + self.offset = 0 + + def randn(self, shape): + """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform.""" + + n = 1 + for x in shape: + n *= x + + counter = np.zeros((4, n), dtype=np.uint32) + counter[0] = self.offset + counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3] + self.offset += 1 + + key = np.empty(n, dtype=np.uint64) + key.fill(self.seed) + key = uint32(key) + + g = philox4_32(counter, key) + + return box_muller(g[0], g[1]).reshape(shape) # discard g[2] and g[3] + +#======================================================================================================================= +# Monkey Patch "prepare_noise" function +# https://github.com/shiimizu/ComfyUI_smZNodes +import torch +import functools +from comfy.sample import np +import comfy.model_management + +def rng_rand_source(rand_source='cpu'): + device = comfy.model_management.text_encoder_device() + + def prepare_noise(latent_image, seed, noise_inds=None, device='cpu'): + """ + creates random noise given a latent image and a seed. + optional arg skip can be used to skip and discard x number of noise generations for a given seed + """ + generator = torch.Generator(device).manual_seed(seed) + if rand_source == 'nv': + rng = Generator(seed) + if noise_inds is None: + shape = latent_image.size() + if rand_source == 'nv': + return torch.asarray(rng.randn(shape), device=device) + else: + return torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, + device=device) + + unique_inds, inverse = np.unique(noise_inds, return_inverse=True) + noises = [] + for i in range(unique_inds[-1] + 1): + shape = [1] + list(latent_image.size())[1:] + if rand_source == 'nv': + noise = torch.asarray(rng.randn(shape), device=device) + else: + noise = torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, + device=device) + if i in unique_inds: + noises.append(noise) + noises = [noises[i] for i in inverse] + noises = torch.cat(noises, axis=0) + return noises + + if rand_source == 'cpu': + if hasattr(comfy.sample, 'prepare_noise_orig'): + comfy.sample.prepare_noise = comfy.sample.prepare_noise_orig + else: + if not hasattr(comfy.sample, 'prepare_noise_orig'): + comfy.sample.prepare_noise_orig = comfy.sample.prepare_noise + _prepare_noise = functools.partial(prepare_noise, device=device) + comfy.sample.prepare_noise = _prepare_noise + + + diff --git a/py/ttl_nn_latent_upscaler.py b/py/ttl_nn_latent_upscaler.py new file mode 100644 index 0000000..b3d737e --- /dev/null +++ b/py/ttl_nn_latent_upscaler.py @@ -0,0 +1,313 @@ +import torch +#from .latent_resizer import LatentResizer +from comfy import model_management +import os + +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +def normalization(channels): + return nn.GroupNorm(32, channels) + + +def zero_module(module): + for p in module.parameters(): + p.detach().zero_() + return module + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = normalization(in_channels) + self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def attention(self, h_: torch.Tensor) -> torch.Tensor: + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + b, c, h, w = q.shape + q, k, v = map( + lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v) + ) + h_ = nn.functional.scaled_dot_product_attention( + q, k, v + ) # scale is dim ** -0.5 per default + + return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + + def forward(self, x, **kwargs): + h_ = x + h_ = self.attention(h_) + h_ = self.proj_out(h_) + return x + h_ + + +def make_attn(in_channels, attn_kwargs=None): + return AttnBlock(in_channels) + + +class ResBlockEmb(nn.Module): + def __init__( + self, + channels, + emb_channels, + dropout=0, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + kernel_size=3, + exchange_temb_dims=False, + skip_t_emb=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + self.exchange_temb_dims = exchange_temb_dims + + padding = kernel_size // 2 + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + nn.Conv2d(channels, self.out_channels, kernel_size, padding=padding), + ) + + self.skip_t_emb = skip_t_emb + self.emb_out_channels = ( + 2 * self.out_channels if use_scale_shift_norm else self.out_channels + ) + if self.skip_t_emb: + print(f"Skipping timestep embedding in {self.__class__.__name__}") + assert not self.use_scale_shift_norm + self.emb_layers = None + self.exchange_temb_dims = False + else: + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + self.emb_out_channels, + ), + ) + + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + nn.Conv2d( + self.out_channels, + self.out_channels, + kernel_size, + padding=padding, + ) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = nn.Conv2d( + channels, self.out_channels, kernel_size, padding=padding + ) + else: + self.skip_connection = nn.Conv2d(channels, self.out_channels, 1) + + def forward(self, x, emb): + h = self.in_layers(x) + + if self.skip_t_emb: + emb_out = torch.zeros_like(h) + else: + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + if self.exchange_temb_dims: + emb_out = rearrange(emb_out, "b t c ... -> b c t ...") + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class LatentResizer(nn.Module): + def __init__(self, in_blocks=10, out_blocks=10, channels=128, dropout=0, attn=True): + super().__init__() + self.conv_in = nn.Conv2d(4, channels, 3, padding=1) + + self.channels = channels + embed_dim = 32 + self.embed = nn.Sequential( + nn.Linear(1, embed_dim), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim), + ) + + self.in_blocks = nn.ModuleList([]) + for b in range(in_blocks): + if (b == 1 or b == in_blocks - 1) and attn: + self.in_blocks.append(make_attn(channels)) + self.in_blocks.append(ResBlockEmb(channels, embed_dim, dropout)) + + self.out_blocks = nn.ModuleList([]) + for b in range(out_blocks): + if (b == 1 or b == out_blocks - 1) and attn: + self.out_blocks.append(make_attn(channels)) + self.out_blocks.append(ResBlockEmb(channels, embed_dim, dropout)) + + self.norm_out = normalization(channels) + self.conv_out = nn.Conv2d(channels, 4, 3, padding=1) + + @classmethod + def load_model(cls, filename, device="cpu", dtype=torch.float32, dropout=0): + if not 'weights_only' in torch.load.__code__.co_varnames: + weights = torch.load(filename, map_location=torch.device("cpu")) + else: + weights = torch.load(filename, map_location=torch.device("cpu"), weights_only=True) + in_blocks = 0 + out_blocks = 0 + in_tfs = 0 + out_tfs = 0 + channels = weights["conv_in.bias"].shape[0] + for k in weights.keys(): + k = k.split(".") + if k[0] == "in_blocks": + in_blocks = max(in_blocks, int(k[1])) + if k[2] == "q" and k[3] == "weight": + in_tfs += 1 + if k[0] == "out_blocks": + out_blocks = max(out_blocks, int(k[1])) + if k[2] == "q" and k[3] == "weight": + out_tfs += 1 + in_blocks = in_blocks + 1 - in_tfs + out_blocks = out_blocks + 1 - out_tfs + resizer = cls( + in_blocks=in_blocks, + out_blocks=out_blocks, + channels=channels, + dropout=dropout, + attn=(out_tfs != 0), + ) + resizer.load_state_dict(weights) + resizer.eval() + resizer.to(device, dtype=dtype) + return resizer + + def forward(self, x, scale=None, size=None): + if scale is None and size is None: + raise ValueError("Either scale or size needs to be not None") + if scale is not None and size is not None: + raise ValueError("Both scale or size can't be not None") + if scale is not None: + size = (x.shape[-2] * scale, x.shape[-1] * scale) + size = tuple([int(round(i)) for i in size]) + else: + scale = size[-1] / x.shape[-1] + + # Output is the same size as input + if size == x.shape[-2:]: + return x + + scale = torch.tensor([scale - 1], dtype=x.dtype).to(x.device).unsqueeze(0) + emb = self.embed(scale) + + x = self.conv_in(x) + + for b in self.in_blocks: + if isinstance(b, ResBlockEmb): + x = b(x, emb) + else: + x = b(x) + x = F.interpolate(x, size=size, mode="bilinear") + for b in self.out_blocks: + if isinstance(b, ResBlockEmb): + x = b(x, emb) + else: + x = b(x) + + x = self.norm_out(x) + x = F.silu(x) + x = self.conv_out(x) + return x + +######################################################## +class NNLatentUpscale: + """ + Upscales SDXL latent using neural network + """ + + def __init__(self): + self.local_dir = os.path.dirname(os.path.realpath(__file__)) + self.scale_factor = 0.13025 + self.dtype = torch.float32 + if model_management.should_use_fp16(): + self.dtype = torch.float16 + self.weight_path = { + "SDXL": os.path.join(self.local_dir, "sdxl_resizer.pt"), + "SD 1.x": os.path.join(self.local_dir, "sd15_resizer.pt"), + } + self.version = "none" + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "latent": ("LATENT",), + "version": (["SDXL", "SD 1.x"],), + "upscale": ( + "FLOAT", + { + "default": 1.5, + "min": 1.0, + "max": 2.0, + "step": 0.01, + "display": "number", + }, + ), + }, + } + + RETURN_TYPES = ("LATENT",) + + FUNCTION = "upscale" + + CATEGORY = "latent" + + def upscale(self, latent, version, upscale): + device = model_management.get_torch_device() + samples = latent["samples"].to(device=device, dtype=self.dtype) + + if version != self.version: + self.model = LatentResizer.load_model(self.weight_path[version], device, self.dtype) + self.version = version + + self.model.to(device=device) + latent_out = (self.model(self.scale_factor * samples, scale=upscale) / self.scale_factor) + + if self.dtype != torch.float32: + latent_out = latent_out.to(dtype=torch.float32) + + latent_out = latent_out.to(device="cpu") + + self.model.to(device=model_management.vae_offload_device()) + return ({"samples": latent_out},) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 4edd327..1717cc3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ -simpleeval -websockets \ No newline at end of file +simpleeval \ No newline at end of file diff --git a/tsc_utils.py b/tsc_utils.py index ca138d9..d1eab7b 100644 --- a/tsc_utils.py +++ b/tsc_utils.py @@ -36,16 +36,44 @@ loaded_objects = { "lora": [] # ([(lora_name, strength_model, strength_clip)], ckpt_name, lora_model, clip_lora, [id]) } -# Cache for Ksampler (Efficient) Outputs +# Cache for Efficient Ksamplers last_helds = { - "preview_images": [], # (preview_images, id) # Preview Images, stored as a pil image list - "latent": [], # (latent, id) # Latent outputs, stored as a latent tensor list - "output_images": [], # (output_images, id) # Output Images, stored as an image tensor list - "vae_decode_flag": [], # (vae_decode, id) # Boolean to track wether vae-decode during Holds - "xy_plot_flag": [], # (xy_plot_flag, id) # Boolean to track if held images are xy_plot results - "xy_plot_image": [], # (xy_plot_image, id) # XY Plot image stored as an image tensor + "latent": [], # (latent, [parameters], id) # Base sampling latent results + "image": [], # (image, id) # Base sampling image results + "cnet_img": [] # (cnet_img, [parameters], id) # HiRes-Fix control net preprocessor image results } +def load_ksampler_results(key: str, my_unique_id, parameters_list=None): + global last_helds + for data in last_helds[key]: + id_ = data[-1] # ID is always the last element in the tuple + if id_ == my_unique_id: + if parameters_list is not None: + # Ensure tuple has at least 3 elements and match with parameters_list + if len(data) >= 3 and data[1] == parameters_list: + return data[0] + else: + return data[0] + return None + +def store_ksampler_results(key: str, my_unique_id, value, parameters_list=None): + global last_helds + + for i, data in enumerate(last_helds[key]): + id_ = data[-1] # ID will always be the last in the tuple + if id_ == my_unique_id: + # Check if parameters_list is provided or not + updated_data = (value, parameters_list, id_) if parameters_list is not None else (value, id_) + last_helds[key][i] = updated_data + return True + + # If parameters_list is given + if parameters_list is not None: + last_helds[key].append((value, parameters_list, my_unique_id)) + else: + last_helds[key].append((value, my_unique_id)) + return True + # Tensor to PIL (grabbed from WAS Suite) def tensor2pil(image: torch.Tensor) -> Image.Image: return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) @@ -54,6 +82,20 @@ def tensor2pil(image: torch.Tensor) -> Image.Image: def pil2tensor(image: Image.Image) -> torch.Tensor: return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) +# Convert tensor to PIL, resize it, and convert back to tensor +def quick_resize(source_tensor: torch.Tensor, target_shape: tuple) -> torch.Tensor: + resized_images = [] + for img in source_tensor: + resized_pil = tensor2pil(img.squeeze(0)).resize((target_shape[2], target_shape[1]), Image.ANTIALIAS) + resized_images.append(pil2tensor(resized_pil).squeeze(0)) + return torch.stack(resized_images, dim=0) + +# Create a function to compute the hash of a tensor +import hashlib +def tensor_to_hash(tensor): + byte_repr = tensor.cpu().numpy().tobytes() # Convert tensor to bytes + return hashlib.sha256(byte_repr).hexdigest() # Compute hash + # Color coded messages functions MESSAGE_COLOR = "\033[36m" # Cyan XYPLOT_COLOR = "\033[35m" # Purple @@ -154,9 +196,11 @@ def globals_cleanup(prompt): # Step 1: Clean up last_helds for key in list(last_helds.keys()): original_length = len(last_helds[key]) - last_helds[key] = [(value, id) for value, id in last_helds[key] if str(id) in prompt.keys()] - ###if original_length != len(last_helds[key]): - ###print(f'Updated {key} in last_helds: {last_helds[key]}') + last_helds[key] = [ + (*values, id_) + for *values, id_ in last_helds[key] + if str(id_) in prompt.keys() + ] # Step 2: Clean up loaded_objects for key in list(loaded_objects.keys()): @@ -250,6 +294,7 @@ def load_vae(vae_name, id, cache=None, cache_overwrite=False): vae_path = vae_name else: vae_path = folder_paths.get_full_path("vae", vae_name) + sd = comfy.utils.load_torch_file(vae_path) vae = comfy.sd.VAE(sd=sd) @@ -473,20 +518,11 @@ def global_preview_method(): #----------------------------------------------------------------------------------------------------------------------- # Auto install Efficiency Nodes Python package dependencies import subprocess -# Note: This auto-installer attempts to import packages listed in the requirements.txt. -# If the import fails, indicating the package isn't installed, the installer proceeds to install the package. +# Note: This auto-installer installs packages listed in the requirements.txt. # It first checks if python.exe exists inside the ...\ComfyUI_windows_portable\python_embeded directory. # If python.exe is found in that location, it will use this embedded Python version for the installation. -# Otherwise, it uses the Python interpreter that's currently executing the script (via sys.executable) -# to attempt a general pip install of the packages. If any errors occur during installation, an error message is -# printed with the reason for the failure, and the user is directed to manually install the required packages. - -def is_package_installed(pkg_name): - try: - __import__(pkg_name) - return True - except ImportError: - return False +# Otherwise, it uses the Python interpreter that's currently executing the script (via sys.executable) to attempt a general pip install of the packages. +# If any errors occur during installation, the user is directed to manually install the required packages. def install_packages(my_dir): # Compute path to the target site-packages @@ -500,26 +536,41 @@ def install_packages(my_dir): with open(os.path.join(my_dir, 'requirements.txt'), 'r') as f: required_packages = [line.strip() for line in f if line.strip()] - for pkg in required_packages: - if not is_package_installed(pkg): - printout = f"Installing required package '{pkg}'..." - print(f"{message('Efficiency Nodes:')} {printout}", end='', flush=True) + try: + installed_packages = packages(embedded_python_exe if use_embedded else None, versions=False) + + for pkg in required_packages: + if pkg not in installed_packages: + printout = f"Installing required package '{pkg}'..." + print(f"{message('Efficiency Nodes:')} {printout}", end='', flush=True) - try: if use_embedded: # Targeted installation subprocess.check_call([embedded_python_exe, '-m', 'pip', 'install', pkg, '--target=' + target_dir, '--no-warn-script-location', '--disable-pip-version-check'], - stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, timeout=7) + stdout=subprocess.DEVNULL, stderr=subprocess.PIPE) else: # Untargeted installation subprocess.check_call([sys.executable, "-m", "pip", 'install', pkg], - stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, timeout=7) - print(f"\r{message('Efficiency Nodes:')} {printout}{success(' Installed!')}", flush=True) - except Exception as e: - print(f"\r{message('Efficiency Nodes:')} {printout}{error(' Failed!')}", flush=True) - print(f"{warning(str(e))}") - + stdout=subprocess.DEVNULL, stderr=subprocess.PIPE) + + print(f"\r{message('Efficiency Nodes:')} {printout}{success('Installed!')}", flush=True) + + except Exception as e: # This catches all exceptions derived from the base Exception class + print_general_error_message() + +def packages(python_exe=None, versions=False): + try: + if python_exe: + return [(r.decode().split('==')[0] if not versions else r.decode()) for r in + subprocess.check_output([python_exe, '-m', 'pip', 'freeze']).split()] + else: + return [(r.split('==')[0] if not versions else r) for r in + subprocess.getoutput([sys.executable, "-m", "pip", "freeze"]).splitlines()] + except subprocess.CalledProcessError as e: + raise e # re-raise the error to handle it outside + def print_general_error_message(): - print(f"{message('Efficiency Nodes:')} An unexpected error occurred during the package installation process. {error('Failed!')}") + print( + f"\r{message('Efficiency Nodes:')} An unexpected error occurred during the package installation process. {error('Failed!')}") print(warning("Please try manually installing the required packages from the requirements.txt file.")) # Install missing packages @@ -538,85 +589,7 @@ if os.path.exists(destination_dir): shutil.rmtree(destination_dir) #----------------------------------------------------------------------------------------------------------------------- -# Establish a websocket connection to communicate with "efficiency-nodes.js" under: -# ComfyUI\web\extensions\efficiency-nodes-comfyui\ -def handle_websocket_failure(): - global websocket_status - if websocket_status: # Ensures the message is printed only once - websocket_status = False - print(f"\r\033[33mEfficiency Nodes Warning:\033[0m Websocket connection failure." - f"\nEfficient KSampler's live preview images may not clear when vae decoding is set to 'true'.") - -# Initialize websocket related global variables -websocket_status = True -latest_image = list() -connected_client = None - -try: - import websockets - import asyncio - import threading - import base64 - from io import BytesIO - from torchvision import transforms -except ImportError: - handle_websocket_failure() - -async def server_logic(websocket, path): - global latest_image, connected_client, websocket_status - - # If websocket_status is False, set latest_image to an empty list - if not websocket_status: - latest_image = list() - - # Assign the connected client - connected_client = websocket - - try: - async for message in websocket: - # If not a command, treat it as image data - if not message.startswith('{'): - image_data = base64.b64decode(message.split(",")[1]) - image = Image.open(BytesIO(image_data)) - latest_image = pil2tensor(image) - except (websockets.exceptions.ConnectionClosedError, asyncio.exceptions.CancelledError): - handle_websocket_failure() - except Exception: - handle_websocket_failure() - -def run_server(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - start_server = websockets.serve(server_logic, "127.0.0.1", 8288) - loop.run_until_complete(start_server) - loop.run_forever() - except Exception: # Catch all exceptions - handle_websocket_failure() - -def get_latest_image(): - return latest_image - -# Function to send commands to frontend -def send_command_to_frontend(startListening=False, maxCount=0, sendBlob=False): - global connected_client, websocket_status - if connected_client and websocket_status: - try: - asyncio.run(connected_client.send(json.dumps({ - 'startProcessing': startListening, - 'maxCount': maxCount, - 'sendBlob': sendBlob - }))) - except Exception: - handle_websocket_failure() - -# Start the WebSocket server in a separate thread -if websocket_status == True: - server_thread = threading.Thread(target=run_server) - server_thread.daemon = True - server_thread.start() - - +# Other class XY_Capsule: def pre_define_model(self, model, clip, vae): return model, clip, vae @@ -632,3 +605,12 @@ class XY_Capsule: def getLabel(self): return "Unknown" + + + + + + + + + diff --git a/workflows/AnimateDiff & HiResFix Scripts.gif b/workflows/AnimateDiff & HiResFix Scripts.gif new file mode 100644 index 0000000..f382e50 Binary files /dev/null and b/workflows/AnimateDiff & HiResFix Scripts.gif differ diff --git a/workflows/ControlNet (Overview).png b/workflows/ControlNet (Overview).png deleted file mode 100644 index d9cd6ea..0000000 Binary files a/workflows/ControlNet (Overview).png and /dev/null differ diff --git a/workflows/ControlNet.png b/workflows/ControlNet.png deleted file mode 100644 index d619ae4..0000000 Binary files a/workflows/ControlNet.png and /dev/null differ diff --git a/workflows/HiRes Fix (overview).png b/workflows/HiRes Fix (overview).png deleted file mode 100644 index e3748f6..0000000 Binary files a/workflows/HiRes Fix (overview).png and /dev/null differ diff --git a/workflows/HiRes Fix.png b/workflows/HiRes Fix.png deleted file mode 100644 index e27313c..0000000 Binary files a/workflows/HiRes Fix.png and /dev/null differ diff --git a/workflows/HiResFix Script.png b/workflows/HiResFix Script.png new file mode 100644 index 0000000..8c8dbb7 Binary files /dev/null and b/workflows/HiResFix Script.png differ diff --git a/workflows/Image Overlay (overview).png b/workflows/Image Overlay (overview).png deleted file mode 100644 index 5f82e03..0000000 Binary files a/workflows/Image Overlay (overview).png and /dev/null differ diff --git a/workflows/Image Overlay.png b/workflows/Image Overlay.png deleted file mode 100644 index ae2e9c2..0000000 Binary files a/workflows/Image Overlay.png and /dev/null differ diff --git a/workflows/SDXL Base+Refine (Overview).png b/workflows/SDXL Base+Refine (Overview).png deleted file mode 100644 index 0da4f92..0000000 Binary files a/workflows/SDXL Base+Refine (Overview).png and /dev/null differ diff --git a/workflows/SDXL Base+Refine.png b/workflows/SDXL Base+Refine.png deleted file mode 100644 index cdecdce..0000000 Binary files a/workflows/SDXL Base+Refine.png and /dev/null differ diff --git a/workflows/SDXL Refining & Noise Control Script.png b/workflows/SDXL Refining & Noise Control Script.png new file mode 100644 index 0000000..b04b3cb Binary files /dev/null and b/workflows/SDXL Refining & Noise Control Script.png differ diff --git a/workflows/Tiled Upscaler Script.png b/workflows/Tiled Upscaler Script.png new file mode 100644 index 0000000..b6dc3a0 Binary files /dev/null and b/workflows/Tiled Upscaler Script.png differ diff --git a/workflows/XYplot/Manual Entry Notes.txt b/workflows/XY Plot Input Manual Entry Notes.txt similarity index 100% rename from workflows/XYplot/Manual Entry Notes.txt rename to workflows/XY Plot Input Manual Entry Notes.txt diff --git a/workflows/XYPlot - LoRA Model vs Clip Strengths.png b/workflows/XYPlot - LoRA Model vs Clip Strengths.png new file mode 100644 index 0000000..1b9dffd Binary files /dev/null and b/workflows/XYPlot - LoRA Model vs Clip Strengths.png differ diff --git a/workflows/XYPlot - Seeds vs Checkpoints & Stacked Scripts.png b/workflows/XYPlot - Seeds vs Checkpoints & Stacked Scripts.png new file mode 100644 index 0000000..f8cce85 Binary files /dev/null and b/workflows/XYPlot - Seeds vs Checkpoints & Stacked Scripts.png differ diff --git a/workflows/XYplot/LoRA Plot X-ModelStr Y-ClipStr (Overview).png b/workflows/XYplot/LoRA Plot X-ModelStr Y-ClipStr (Overview).png deleted file mode 100644 index 74bcd3e..0000000 Binary files a/workflows/XYplot/LoRA Plot X-ModelStr Y-ClipStr (Overview).png and /dev/null differ diff --git a/workflows/XYplot/LoRA Plot X-ModelStr Y-ClipStr.png b/workflows/XYplot/LoRA Plot X-ModelStr Y-ClipStr.png deleted file mode 100644 index ccfa732..0000000 Binary files a/workflows/XYplot/LoRA Plot X-ModelStr Y-ClipStr.png and /dev/null differ diff --git a/workflows/XYplot/X-Seeds Y-Checkpoints (overview).png b/workflows/XYplot/X-Seeds Y-Checkpoints (overview).png deleted file mode 100644 index 05ac0fc..0000000 Binary files a/workflows/XYplot/X-Seeds Y-Checkpoints (overview).png and /dev/null differ diff --git a/workflows/XYplot/X-Seeds Y-Checkpoints.png b/workflows/XYplot/X-Seeds Y-Checkpoints.png deleted file mode 100644 index a5f7990..0000000 Binary files a/workflows/XYplot/X-Seeds Y-Checkpoints.png and /dev/null differ