Mirror of https://github.com/idrirap/ComfyUI-Lora-Auto-Trigger-Words.git
synced 2026-03-21 21:22:12 -03:00
initial commit
This commit is contained in:
202
utils.py
Normal file
@@ -0,0 +1,202 @@
import folder_paths
import hashlib
import json
import os
import requests
import shutil


def get_preview_path(name, type):
    # "type" is the folder_paths category the model lives in (e.g. "loras")
    file_name = os.path.splitext(name)[0]
    file_path = folder_paths.get_full_path(type, name)

    if file_path is None:
        print(f"Unable to get path for {type} {name}")
        # return a (has_image, item_image) pair so callers can always unpack two values
        return False, None

    file_path_no_ext = os.path.splitext(file_path)[0]
    item_image = None
    has_image = False
    for ext in ["png", "jpg", "jpeg", "preview.png"]:
        has_image = os.path.isfile(file_path_no_ext + "." + ext)
        if has_image:
            item_image = f"{file_name}.{ext}"
            break

    return has_image, item_image
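# Illustrative example (hypothetical file names): if "mylora.safetensors" sits next
# to "mylora.preview.png" on disk, get_preview_path("mylora.safetensors", "loras")
# would return (True, "mylora.preview.png"); with no sibling image, (False, None).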


def copy_preview_to_temp(file_name):
    if file_name is None:
        return None, None
    base_name = os.path.basename(file_name)
    # drop the first path component (e.g. a "loras/" prefix) before resolving
    lora_less = "/".join(file_name.split("/")[1:])

    file_path = folder_paths.get_full_path("loras", lora_less)

    temp_path = folder_paths.get_temp_directory()
    preview_path = os.path.join(temp_path, "lora_preview")
    if not os.path.isdir(preview_path):
        os.makedirs(preview_path)
    preview_path = os.path.join(preview_path, base_name)

    shutil.copyfile(file_path, preview_path)
    return preview_path, base_name


# add previews in selectors
def populate_items(names, type):
    for idx, item_name in enumerate(names):

        has_image, item_image = get_preview_path(item_name, type)

        # replace each plain name with a dict the selector widget can render
        names[idx] = {
            "content": item_name,
            "image": f"{type}/{item_image}" if has_image else None,
            "type": "loras",
        }
    names.sort(key=lambda i: i["content"].lower())
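# Illustrative result (hypothetical names): populate_items(["b.safetensors",
# "a.safetensors"], "loras") rewrites the list in place, sorted case-insensitively, to
# [{"content": "a.safetensors", "image": None, "type": "loras"},
#  {"content": "b.safetensors", "image": None, "type": "loras"}],
# assuming neither file has a preview image next to it.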


def load_json_from_file(file_path):
    try:
        with open(file_path, 'r') as json_file:
            data = json.load(json_file)
            return data
    except FileNotFoundError:
        print(f"File not found: {file_path}")
        return None
    except json.JSONDecodeError:
        print(f"Error decoding JSON in file: {file_path}")
        return None


def save_dict_to_json(data_dict, file_path):
    try:
        with open(file_path, 'w') as json_file:
            json.dump(data_dict, json_file, indent=4)
            print(f"Data saved to {file_path}")
    except Exception as e:
        print(f"Error saving JSON to file: {e}")


def get_model_version_info(hash_value):
    api_url = f"https://civitai.com/api/v1/model-versions/by-hash/{hash_value}"
    # a timeout keeps a stalled Civitai request from hanging the node indefinitely
    response = requests.get(api_url, timeout=30)

    if response.status_code == 200:
        return response.json()
    else:
        return None


def calculate_sha256(file_path):
    sha256_hash = hashlib.sha256()
    with open(file_path, "rb") as f:
        # hash in 4 KiB chunks to avoid loading the whole model file into memory
        for chunk in iter(lambda: f.read(4096), b""):
            sha256_hash.update(chunk)
    return sha256_hash.hexdigest()
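# Illustrative pipeline (hypothetical path): look up a local LoRA on Civitai by hash.
#   sha = calculate_sha256("/path/to/mylora.safetensors")
#   info = get_model_version_info(sha)
#   words = info.get("trainedWords", []) if info else []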


def load_and_save_tags(lora_name, force_fetch):
    json_tags_path = "./loras_tags.json"
    lora_tags = load_json_from_file(json_tags_path)
    output_tags = lora_tags.get(lora_name, None) if lora_tags is not None else None
    if output_tags is not None:
        output_tags_list = output_tags
    else:
        output_tags_list = []

    lora_path = folder_paths.get_full_path("loras", lora_name)
    if lora_tags is None or force_fetch:  # search on civitai only if no local cache or forced
        print("calculating lora hash")
        LORAsha256 = calculate_sha256(lora_path)
        print("requesting info")
        model_info = get_model_version_info(LORAsha256)
        if model_info is not None:
            if "trainedWords" in model_info:
                print("tags found!")
                if lora_tags is None:
                    lora_tags = {}
                lora_tags[lora_name] = model_info["trainedWords"]
                save_dict_to_json(lora_tags, json_tags_path)
                output_tags_list = model_info["trainedWords"]
        else:
            print("No information found.")
            if lora_tags is None:
                lora_tags = {}
            lora_tags[lora_name] = []
            save_dict_to_json(lora_tags, json_tags_path)

    return output_tags_list
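# Illustrative usage (hypothetical name): load_and_save_tags("mylora.safetensors", False)
# returns the trigger words cached in ./loras_tags.json when present, and only hits
# the Civitai API when the cache file is missing or force_fetch is True.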


def show_list(list_input):
    output = ""
    for i, debug in enumerate(list_input):
        output += f"{i} : {debug}\n"
    return output


def get_metadata(filepath, type):
    filepath = folder_paths.get_full_path(type, filepath)
    with open(filepath, "rb") as file:
        # https://github.com/huggingface/safetensors#format
        # 8 bytes: N, an unsigned little-endian 64-bit integer, containing the size of the header
        header_size = int.from_bytes(file.read(8), "little", signed=False)

        if header_size <= 0:
            raise BufferError("Invalid header size")

        header = file.read(header_size)
        # validate the bytes actually read, not header_size a second time
        if len(header) != header_size:
            raise BufferError("Invalid header")
        header_json = json.loads(header)
        return header_json["__metadata__"] if "__metadata__" in header_json else None
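# Illustrative usage (hypothetical file): get_metadata("mylora.safetensors", "loras")
# returns the "__metadata__" dict of the safetensors header; its values are strings,
# e.g. {"ss_tag_frequency": "{...}"}, which sort_tags_by_frequency below decodes.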


# parse the __metadata__ json looking for trained tags
def sort_tags_by_frequency(meta_tags):
    if meta_tags is None:
        return []
    if "ss_tag_frequency" in meta_tags:
        # ss_tag_frequency is itself a JSON string: dataset name -> {tag: count}
        meta_tags = meta_tags["ss_tag_frequency"]
        meta_tags = json.loads(meta_tags)
        sorted_tags = {}
        for _, dataset in meta_tags.items():
            for tag, count in dataset.items():
                tag = str(tag).strip()
                if tag in sorted_tags:
                    sorted_tags[tag] += count
                else:
                    sorted_tags[tag] = count
        # sort tags by training frequency, most seen tags first
        sorted_tags = dict(sorted(sorted_tags.items(), key=lambda item: item[1], reverse=True))
        return list(sorted_tags.keys())
    else:
        return []
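# Illustrative input (hypothetical counts): with ss_tag_frequency decoding to
# {"dataset": {"1girl": 120, "smile": 80}}, sort_tags_by_frequency returns
# ["1girl", "smile"], counts summed across datasets before sorting.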


def parse_selector(selector, tags_list):
    range_index_list = selector.split(",")
    output = {}
    for range_index in range_index_list:
        # single value
        if range_index.count(":") == 0:
            index = int(range_index)
            output[index] = tags_list[index]

        # actual range
        if range_index.count(":") == 1:
            indexes = range_index.split(":")
            # check empty
            if indexes[0] == "":
                start = 0
            else:
                start = int(indexes[0])
            if indexes[1] == "":
                end = len(tags_list)
            else:
                end = int(indexes[1])
            # check negative
            if start < 0:
                start = len(tags_list) + start
            if end < 0:
                end = len(tags_list) + end
            # merge all
            for i in range(start, end):
                output[i] = tags_list[i]
    return ", ".join(list(output.values()))