diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ed8ebf5 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +__pycache__ \ No newline at end of file diff --git a/README.md b/README.md index 455aac9..c07619d 100644 --- a/README.md +++ b/README.md @@ -1 +1,54 @@ # ComfyUI-Lora-Auto-Trigger-Words + +This project is a fork of https://github.com/Extraltodeus/LoadLoraWithTags +The aim of these custom nodes is to get an _easy_ access to the tags used to trigger a lora. +This project is compatible with Stacked Loras from https://github.com/LucianoCirino/efficiency-nodes-comfyui/releases + +## Install +Some of this project nodes depends on https://github.com/pythongosssss/ComfyUI-Custom-Scripts : +- LoraLoaderAdvanced +- LoraLoaderStackedAdvanced +They get their vanilla equivalents. + +Overall, Custom-Scripts is recommended to be able to know the content of the tag lists with the node `showText` + +## Features +### Main nodes +#### Vanilla vs Advanced +Vanilla refers to nodes that have no lora preview from the menu, nor the lora list. But the features provided are the same. +![image](./images/main.png) +#### Nodes +- LoraLoader (Vanilla or Advanced) +- LoraLoaderStacked (Vanilla or Avanced). The stacked lora input is optional. +Allow to load a lora, either the normal way, or the efficiency-nodes way. +These loaders have two custom outputs: +- civitai_tags_list: a python list of the tags related to this lora on civitai +- meta_tags_list: a python list of the tags used for training the lora embeded in it (if any) +This outputs needs to be filtered by two othere nodes: +- TagsFormater: list in a comprehensible way the available tags +- tagsSelector: allow to filter tags and apply a weight to it. +#### Filtering +The format is simple. It's the same as python list index, but can select multiple index or ranges of indexes separated by comas. +`Ex: 0, 3, 5:8, -8:` +Select a specific list of indexes: `0, 2, 3, 15`... +Select range of indexes: `2:5, 10:15`... +Select a range from the begining to a specific index: `:5` +Select a range from a specific index to the end: `5:` +You can use negative indexes. Like `-1` to select the last tag +By default `:` selects everything + +#### Example of normal workflow +![image](./images/loaderAdvanced.png) + +#### Example of Stacked workflow +![image](./images/loaderStacked.png) + +#### Chaining Selectors and Stacked +Tags selectors can be chained to select differents tags with differents weights `(tags1:0.8), tag2, (tag3:1.1)`. +Lora Stack can also be chained together to load multiple loras into an efficient loaders. +![image](./images/stackingLoras.png) + +### Side nodes I made and kept here +- FusionText: takes two text input and join them together +- Randomizer: takes two couples text+lorastack and return randomly one them +- TextInputBasic: just a text input with two additional input for text chaining diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..fe1e2af --- /dev/null +++ b/__init__.py @@ -0,0 +1,8 @@ +#from .nodes_autotrigger import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as na_NCM, na_NDNM +#from .nodes_utils import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS as nu_NCM, nu_NDNM +from .nodes_autotrigger import NODE_CLASS_MAPPINGS as na_NCM +from .nodes_utils import NODE_CLASS_MAPPINGS as nu_NCM + +NODE_CLASS_MAPPINGS = dict(na_NCM, **nu_NCM) +#NODE_DISPLAY_NAME_MAPPINGS = dict(na_NDNM, **nu_NDNM) +__all__ = ["NODE_CLASS_MAPPINGS"]#, "NODE_DISPLAY_NAME_MAPPINGS"] diff --git a/images/loaderAdvanced.png b/images/loaderAdvanced.png new file mode 100644 index 0000000..13f983a Binary files /dev/null and b/images/loaderAdvanced.png differ diff --git a/images/loaderStacked.png b/images/loaderStacked.png new file mode 100644 index 0000000..b10df80 Binary files /dev/null and b/images/loaderStacked.png differ diff --git a/images/main.png b/images/main.png new file mode 100644 index 0000000..215aaf3 Binary files /dev/null and b/images/main.png differ diff --git a/images/stackingLoras.png b/images/stackingLoras.png new file mode 100644 index 0000000..f9b8c10 Binary files /dev/null and b/images/stackingLoras.png differ diff --git a/nodes_autotrigger.py b/nodes_autotrigger.py new file mode 100644 index 0000000..872b785 --- /dev/null +++ b/nodes_autotrigger.py @@ -0,0 +1,203 @@ +from comfy.sd import load_lora_for_models +from comfy.utils import load_torch_file +import folder_paths + +from .utils import * + +class LoraLoaderVanilla: + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + LORA_LIST = sorted(folder_paths.get_filename_list("loras"), key=str.lower) + return { + "required": { + "model": ("MODEL",), + "clip": ("CLIP", ), + "lora_name": (LORA_LIST, ), + "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.1}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.1}), + "force_fetch": ("BOOLEAN", {"default": False}), + } + } + + RETURN_TYPES = ("MODEL", "CLIP", "LIST", "LIST") + RETURN_NAMES = ("MODEL", "CLIP", "civitai_tags_list", "meta_tags_list") + FUNCTION = "load_lora" + CATEGORY = "autotrigger" + + def load_lora(self, model, clip, lora_name, strength_model, strength_clip, force_fetch): + meta_tags_list = sort_tags_by_frequency(get_metadata(lora_name, "loras")) + output_tags_list = load_and_save_tags(lora_name, force_fetch) + lora_path = folder_paths.get_full_path("loras", lora_name) + lora = None + if self.loaded_lora is not None: + if self.loaded_lora[0] == lora_path: + lora = self.loaded_lora[1] + else: + temp = self.loaded_lora + self.loaded_lora = None + del temp + + if lora is None: + lora = load_torch_file(lora_path, safe_load=True) + self.loaded_lora = (lora_path, lora) + + model_lora, clip_lora = load_lora_for_models(model, clip, lora, strength_model, strength_clip) + + return (model_lora, clip_lora, output_tags_list, meta_tags_list) + +class LoraLoaderStackedVanilla: + @classmethod + def INPUT_TYPES(s): + LORA_LIST = folder_paths.get_filename_list("loras") + return { + "required": { + "lora_name": (LORA_LIST,), + "lora_weight": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + "force_fetch": ("BOOLEAN", {"default": False}), + }, + "optional": { + "lora_stack": ("LORA_STACK", ), + } + } + + RETURN_TYPES = ("LIST", "LIST", "LORA_STACK",) + RETURN_NAMES = ("civitai_tags_list", "meta_tags_list", "LORA_STACK",) + FUNCTION = "set_stack" + #OUTPUT_NODE = False + CATEGORY = "autotrigger" + + def set_stack(self, lora_name, lora_weight, force_fetch, lora_stack=None): + civitai_tags_list = load_and_save_tags(lora_name, force_fetch) + + meta_tags = get_metadata(lora_name, "loras") + meta_tags_list = sort_tags_by_frequency(meta_tags) + + if lora_stack is not None: + lora_stack.append((lora_name,lora_weight,lora_weight,)) + else: + lora_stack = [(lora_name,lora_weight,lora_weight,)] + + return (civitai_tags_list, meta_tags_list, lora_stack) + +class LoraLoaderAdvanced: + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + LORA_LIST = sorted(folder_paths.get_filename_list("loras"), key=str.lower) + populate_items(LORA_LIST, "loras") + return { + "required": { + "model": ("MODEL",), + "clip": ("CLIP", ), + "lora_name": (LORA_LIST, ), + "strength_model": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.1}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.1}), + "force_fetch": ("BOOLEAN", {"default": False}), + "enable_preview": ("BOOLEAN", {"default": False}), + } + } + + RETURN_TYPES = ("MODEL", "CLIP", "LIST", "LIST") + RETURN_NAMES = ("MODEL", "CLIP", "civitai_tags_list", "meta_tags_list") + FUNCTION = "load_lora" + CATEGORY = "autotrigger" + + def load_lora(self, model, clip, lora_name, strength_model, strength_clip, force_fetch, enable_preview): + meta_tags_list = sort_tags_by_frequency(get_metadata(lora_name["content"], "loras")) + output_tags_list = load_and_save_tags(lora_name["content"], force_fetch) + lora_path = folder_paths.get_full_path("loras", lora_name["content"]) + lora = None + if self.loaded_lora is not None: + if self.loaded_lora[0] == lora_path: + lora = self.loaded_lora[1] + else: + temp = self.loaded_lora + self.loaded_lora = None + del temp + + if lora is None: + lora = load_torch_file(lora_path, safe_load=True) + self.loaded_lora = (lora_path, lora) + + model_lora, clip_lora = load_lora_for_models(model, clip, lora, strength_model, strength_clip) + if enable_preview: + _, preview = copy_preview_to_temp(lora_name["image"]) + if preview is not None: + preview_output = { + "filename": preview, + "subfolder": "lora_preview", + "type": "temp" + } + return {"ui": {"images": [preview_output]}, "result": (model_lora, clip_lora, output_tags_list, meta_tags_list)} + + + return (model_lora, clip_lora, output_tags_list, meta_tags_list) + +class LoraLoaderStackedAdvanced: + @classmethod + def INPUT_TYPES(s): + LORA_LIST = folder_paths.get_filename_list("loras") + populate_items(LORA_LIST, "loras") + return { + "required": { + "lora_name": (LORA_LIST,), + "lora_weight": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + "force_fetch": ("BOOLEAN", {"default": False}), + "enable_preview": ("BOOLEAN", {"default": False}), + }, + "optional": { + "lora_stack": ("LORA_STACK", ), + } + } + + RETURN_TYPES = ("LIST", "LIST", "LORA_STACK",) + RETURN_NAMES = ("civitai_tags_list", "meta_tags_list", "LORA_STACK",) + FUNCTION = "set_stack" + #OUTPUT_NODE = False + CATEGORY = "autotrigger" + + def set_stack(self, lora_name, lora_weight, force_fetch, enable_preview, lora_stack=None): + civitai_tags_list = load_and_save_tags(lora_name["content"], force_fetch) + + meta_tags = get_metadata(lora_name["content"], "loras") + meta_tags_list = sort_tags_by_frequency(meta_tags) + + if lora_stack is not None: + lora_stack.append((lora_name["content"],lora_weight,lora_weight,)) + else: + lora_stack = [(lora_name["content"],lora_weight,lora_weight,)] + + if enable_preview: + _, preview = copy_preview_to_temp(lora_name["image"]) + if preview is not None: + preview_output = { + "filename": preview, + "subfolder": "lora_preview", + "type": "temp" + } + return {"ui": {"images": [preview_output]}, "result": (civitai_tags_list, meta_tags_list, lora_stack)} + + return {"result": (civitai_tags_list, meta_tags_list, lora_stack)} + + +# A dictionary that contains all nodes you want to export with their names +# NOTE: names should be globally unique +NODE_CLASS_MAPPINGS = { + "LoraLoaderVanilla": LoraLoaderVanilla, + "LoraLoaderStackedVanilla": LoraLoaderStackedVanilla, + "LoraLoaderAdvanced": LoraLoaderAdvanced, + "LoraLoaderStackedAdvanced": LoraLoaderStackedAdvanced, +} + +# A dictionary that contains the friendly/humanly readable titles for the nodes +NODE_DISPLAY_NAME_MAPPINGS = { + "LoraLoaderVanilla": "LoraLoaderVanilla", + "LoraLoaderStackedVanilla": "LoraLoaderStackedVanilla", + "LoraLoaderAdvanced": "LoraLoaderAdvanced", + "LoraLoaderStackedAdvanced": "LoraLoaderStackedAdvanced", +} diff --git a/nodes_utils.py b/nodes_utils.py new file mode 100644 index 0000000..03a01f6 --- /dev/null +++ b/nodes_utils.py @@ -0,0 +1,141 @@ +import random + +from .utils import * + +class FusionText: + @classmethod + def INPUT_TYPES(s): + return {"required": {"text_1": ("STRING", {"default": "", "forceInput": True}), "text_2": ("STRING", {"default": "", "forceInput": True})}} + RETURN_TYPES = ("STRING",) + FUNCTION = "combine" + CATEGORY = "autotrigger" + + def combine(self, text_1, text_2): + return (text_1 + text_2, ) + + +class Randomizer: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "text_1":("STRING", {"forceInput": True}), + "lora_1":("LORA_STACK", ), + "text_2":("STRING", {"forceInput": True} ), + "lora_2":("LORA_STACK", ), + "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}), + }, + } + + RETURN_TYPES = ("STRING", "LORA_STACK") + RETURN_NAMES = ("text", "lora stack") + FUNCTION = "randomize" + + #OUTPUT_NODE = False + + CATEGORY = "autotrigger" + + def randomize(self, text_1, lora_1, text_2, lora_2, seed): + random.seed(seed) + if random.random() < .5: + return (text_1, lora_1) + return (text_2, lora_2) + +class TextInputBasic: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "text":("STRING", {"default":"", "multiline":True}), + }, + "optional": { + "prefix":("STRING", {"default":"", "forceInput": True}), + "suffix":("STRING", {"default":"", "forceInput": True}), + } + } + + RETURN_TYPES = ("STRING",) + RETURN_NAMES = ("text", ) + FUNCTION = "get_text" + + #OUTPUT_NODE = False + + CATEGORY = "autotrigger" + + def get_text(self, text, prefix="", suffix=""): + return (prefix + text + suffix, ) + + +class TagsSelector: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "tags_list": ("LIST", {"default": []}), + "selector": ("STRING", {"default": ":"}), + "weight": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), + "ensure_coma": ("BOOLEAN", {"default": True}) + }, + "optional": { + "prefix":("STRING", {"default":"", "forceInput": True}), + "suffix":("STRING", {"default":"", "forceInput": True}), + } + } + + RETURN_TYPES = ("STRING",) + FUNCTION = "select_tags" + CATEGORY = "autotrigger" + + def select_tags(self, tags_list, selector, weight, ensure_coma, prefix="", suffix=""): + if weight != 1.0: + tags_list = [f"({tag}:{weight})" for tag in tags_list] + output = parse_selector(selector, tags_list) + if ensure_coma: + striped_prefix = prefix.strip() + striped_suffix = suffix.strip() + if striped_prefix != "" and not striped_prefix.endswith(",") and output != "" and not output.startswith(","): + prefix = striped_prefix + ", " + if output != "" and not output.endswith(",") and striped_suffix != "" and not striped_suffix.startswith(","): + suffix = ", " + striped_suffix + return (prefix + output + suffix, ) + +class TagsFormater: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "tags_list": ("LIST", {"default": []}), + }, + } + + RETURN_TYPES = ("STRING",) + FUNCTION = "format_tags" + CATEGORY = "autotrigger" + + def format_tags(self, tags_list): + output = "" + i = 0 + for tag in tags_list: + output += f'{i} : "{tag}"\n' + i+=1 + + return (output,) + +# A dictionary that contains all nodes you want to export with their names +# NOTE: names should be globally unique +NODE_CLASS_MAPPINGS = { + "Randomizer": Randomizer, + "FusionText": FusionText, + "TextInputBasic": TextInputBasic, + "TagsSelector": TagsSelector, + "TagsFormater": TagsFormater, +} + +# A dictionary that contains the friendly/humanly readable titles for the nodes +NODE_DISPLAY_NAME_MAPPINGS = { + "Randomizer": "Randomizer", + "FusionText": "FusionText", + "TextInputBasic": "TextInputBasic", + "TagsSelector": "TagsSelector", + "TagsFormater": "TagsFormater", +} diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..dfad7a3 --- /dev/null +++ b/utils.py @@ -0,0 +1,202 @@ +import folder_paths +import hashlib +import json +import os +import requests +import shutil + +def get_preview_path(name, type): + file_name = os.path.splitext(name)[0] + file_path = folder_paths.get_full_path(type, name) + + if file_path is None: + print(f"Unable to get path for {type} {name}") + return None + + file_path_no_ext = os.path.splitext(file_path)[0] + item_image=None + for ext in ["png", "jpg", "jpeg", "preview.png"]: + has_image = os.path.isfile(file_path_no_ext + "." + ext) + if has_image: + item_image = f"{file_name}.{ext}" + break + + return has_image, item_image + + +def copy_preview_to_temp(file_name): + if file_name is None: + return None, None + base_name = os.path.basename(file_name) + lora_less = "/".join(file_name.split("/")[1:]) + + file_path = folder_paths.get_full_path("loras", lora_less) + + temp_path = folder_paths.get_temp_directory() + preview_path = os.path.join(temp_path, "lora_preview") + if not os.path.isdir(preview_path) : + os.makedirs(preview_path) + preview_path = os.path.join(preview_path, base_name) + + + shutil.copyfile(file_path, preview_path) + return preview_path, base_name + +# add previews in selectors +def populate_items(names, type): + for idx, item_name in enumerate(names): + + has_image, item_image = get_preview_path(item_name, type) + + names[idx] = { + "content": item_name, + "image": f"{type}/{item_image}" if has_image else None, + "type": "loras", + } + names.sort(key=lambda i: i["content"].lower()) + + +def load_json_from_file(file_path): + try: + with open(file_path, 'r') as json_file: + data = json.load(json_file) + return data + except FileNotFoundError: + print(f"File not found: {file_path}") + return None + except json.JSONDecodeError: + print(f"Error decoding JSON in file: {file_path}") + return None + +def save_dict_to_json(data_dict, file_path): + try: + with open(file_path, 'w') as json_file: + json.dump(data_dict, json_file, indent=4) + print(f"Data saved to {file_path}") + except Exception as e: + print(f"Error saving JSON to file: {e}") + +def get_model_version_info(hash_value): + api_url = f"https://civitai.com/api/v1/model-versions/by-hash/{hash_value}" + response = requests.get(api_url) + + if response.status_code == 200: + return response.json() + else: + return None + +def calculate_sha256(file_path): + sha256_hash = hashlib.sha256() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + sha256_hash.update(chunk) + return sha256_hash.hexdigest() + + +def load_and_save_tags(lora_name, force_fetch): + json_tags_path = "./loras_tags.json" + lora_tags = load_json_from_file(json_tags_path) + output_tags = lora_tags.get(lora_name, None) if lora_tags is not None else None + if output_tags is not None: + output_tags_list = output_tags + else: + output_tags_list = [] + + lora_path = folder_paths.get_full_path("loras", lora_name) + if lora_tags is None or force_fetch: # search on civitai only if no local cache or forced + print("calculating lora hash") + LORAsha256 = calculate_sha256(lora_path) + print("requesting infos") + model_info = get_model_version_info(LORAsha256) + if model_info is not None: + if "trainedWords" in model_info: + print("tags found!") + if lora_tags is None: + lora_tags = {} + lora_tags[lora_name] = model_info["trainedWords"] + save_dict_to_json(lora_tags,json_tags_path) + output_tags_list = model_info["trainedWords"] + else: + print("No informations found.") + if lora_tags is None: + lora_tags = {} + lora_tags[lora_name] = [] + save_dict_to_json(lora_tags,json_tags_path) + + return output_tags_list + +def show_list(list_input): + i = 0 + output = "" + for debug in list_input: + output += f"{i} : {debug}\n" + i+=1 + return output + +def get_metadata(filepath, type): + filepath = folder_paths.get_full_path(type, filepath) + with open(filepath, "rb") as file: + # https://github.com/huggingface/safetensors#format + # 8 bytes: N, an unsigned little-endian 64-bit integer, containing the size of the header + header_size = int.from_bytes(file.read(8), "little", signed=False) + + if header_size <= 0: + raise BufferError("Invalid header size") + + header = file.read(header_size) + if header_size <= 0: + raise BufferError("Invalid header") + header_json = json.loads(header) + return header_json["__metadata__"] if "__metadata__" in header_json else None + +# parse the __metadata__ json looking for trained tags +def sort_tags_by_frequency(meta_tags): + if meta_tags is None: + return [] + if "ss_tag_frequency" in meta_tags: + meta_tags = meta_tags["ss_tag_frequency"] + meta_tags = json.loads(meta_tags) + sorted_tags = {} + for _, dataset in meta_tags.items(): + for tag, count in dataset.items(): + tag = str(tag).strip() + if tag in sorted_tags: + sorted_tags[tag] = sorted_tags[tag] + count + else: + sorted_tags[tag] = count + # sort tags by training frequency. Most seen tags firsts + sorted_tags = dict(sorted(sorted_tags.items(), key=lambda item: item[1], reverse=True)) + return list(sorted_tags.keys()) + else: + return [] + +def parse_selector(selector, tags_list): + range_index_list = selector.split(",") + output = {} + for range_index in range_index_list: + # single value + if range_index.count(":") == 0: + index = int(range_index) + output[index] = tags_list[index] + + # actual range + if range_index.count(":") == 1: + indexes = range_index.split(":") + # check empty + if indexes[0] == "": + start = 0 + else: + start = int(indexes[0]) + if indexes[1] == "": + end = len(tags_list) + else: + end = int(indexes[1]) + # check negative + if start < 0: + start = len(tags_list) + start + if end < 0: + end = len(tags_list) + end + # merge all + for i in range(start, end): + output[i] = tags_list[i] + return ", ".join(list(output.values()))