mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
- Update get_lora_info() to check both loras_roots and extra_loras_roots - Add fallback logic to return trigger words even if path not in recognized roots - Ensure Trigger Word Toggle node displays trigger words for LoRAs from extra folder paths Fixes issue where LoRAs added from extra folder paths would not show their trigger words in connected Trigger Word Toggle nodes.
341 lines
11 KiB
Python
341 lines
11 KiB
Python
from difflib import SequenceMatcher
|
|
import os
|
|
import re
|
|
from typing import Dict
|
|
from ..services.service_registry import ServiceRegistry
|
|
from ..config import config
|
|
from ..services.settings_manager import get_settings_manager
|
|
import asyncio
|
|
|
|
|
|
def get_lora_info(lora_name):
|
|
"""Get the lora path and trigger words from cache"""
|
|
|
|
async def _get_lora_info_async():
|
|
scanner = await ServiceRegistry.get_lora_scanner()
|
|
cache = await scanner.get_cached_data()
|
|
|
|
for item in cache.raw_data:
|
|
if item.get("file_name") == lora_name:
|
|
file_path = item.get("file_path")
|
|
if file_path:
|
|
# Check all lora roots including extra paths
|
|
all_roots = list(config.loras_roots or []) + list(
|
|
config.extra_loras_roots or []
|
|
)
|
|
for root in all_roots:
|
|
root = root.replace(os.sep, "/")
|
|
if file_path.startswith(root):
|
|
relative_path = os.path.relpath(file_path, root).replace(
|
|
os.sep, "/"
|
|
)
|
|
# Get trigger words from civitai metadata
|
|
civitai = item.get("civitai", {})
|
|
trigger_words = (
|
|
civitai.get("trainedWords", []) if civitai else []
|
|
)
|
|
return relative_path, trigger_words
|
|
# If not found in any root, return path with trigger words from cache
|
|
civitai = item.get("civitai", {})
|
|
trigger_words = civitai.get("trainedWords", []) if civitai else []
|
|
return file_path, trigger_words
|
|
return lora_name, []
|
|
|
|
try:
|
|
# Check if we're already in an event loop
|
|
loop = asyncio.get_running_loop()
|
|
# If we're in a running loop, we need to use a different approach
|
|
# Create a new thread to run the async code
|
|
import concurrent.futures
|
|
|
|
def run_in_thread():
|
|
new_loop = asyncio.new_event_loop()
|
|
asyncio.set_event_loop(new_loop)
|
|
try:
|
|
return new_loop.run_until_complete(_get_lora_info_async())
|
|
finally:
|
|
new_loop.close()
|
|
|
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
future = executor.submit(run_in_thread)
|
|
return future.result()
|
|
|
|
except RuntimeError:
|
|
# No event loop is running, we can use asyncio.run()
|
|
return asyncio.run(_get_lora_info_async())
|
|
|
|
|
|
def get_lora_info_absolute(lora_name):
|
|
"""Get the absolute lora path and trigger words from cache
|
|
|
|
Returns:
|
|
tuple: (absolute_path, trigger_words) where absolute_path is the full
|
|
file system path to the LoRA file, or original lora_name if not found
|
|
"""
|
|
|
|
async def _get_lora_info_absolute_async():
|
|
scanner = await ServiceRegistry.get_lora_scanner()
|
|
cache = await scanner.get_cached_data()
|
|
|
|
for item in cache.raw_data:
|
|
if item.get("file_name") == lora_name:
|
|
file_path = item.get("file_path")
|
|
if file_path:
|
|
# Return absolute path directly
|
|
# Get trigger words from civitai metadata
|
|
civitai = item.get("civitai", {})
|
|
trigger_words = civitai.get("trainedWords", []) if civitai else []
|
|
return file_path, trigger_words
|
|
return lora_name, []
|
|
|
|
try:
|
|
# Check if we're already in an event loop
|
|
loop = asyncio.get_running_loop()
|
|
# If we're in a running loop, we need to use a different approach
|
|
# Create a new thread to run the async code
|
|
import concurrent.futures
|
|
|
|
def run_in_thread():
|
|
new_loop = asyncio.new_event_loop()
|
|
asyncio.set_event_loop(new_loop)
|
|
try:
|
|
return new_loop.run_until_complete(_get_lora_info_absolute_async())
|
|
finally:
|
|
new_loop.close()
|
|
|
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
future = executor.submit(run_in_thread)
|
|
return future.result()
|
|
|
|
except RuntimeError:
|
|
# No event loop is running, we can use asyncio.run()
|
|
return asyncio.run(_get_lora_info_absolute_async())
|
|
|
|
|
|
def fuzzy_match(text: str, pattern: str, threshold: float = 0.85) -> bool:
|
|
"""
|
|
Check if text matches pattern using fuzzy matching.
|
|
Returns True if similarity ratio is above threshold.
|
|
"""
|
|
if not pattern or not text:
|
|
return False
|
|
|
|
# Convert both to lowercase for case-insensitive matching
|
|
text = text.lower()
|
|
pattern = pattern.lower()
|
|
|
|
# Split pattern into words
|
|
search_words = pattern.split()
|
|
|
|
# Check each word
|
|
for word in search_words:
|
|
# First check if word is a substring (faster)
|
|
if word in text:
|
|
continue
|
|
|
|
# If not found as substring, try fuzzy matching
|
|
# Check if any part of the text matches this word
|
|
found_match = False
|
|
for text_part in text.split():
|
|
ratio = SequenceMatcher(None, text_part, word).ratio()
|
|
if ratio >= threshold:
|
|
found_match = True
|
|
break
|
|
|
|
if not found_match:
|
|
return False
|
|
|
|
# All words found either as substrings or fuzzy matches
|
|
return True
|
|
|
|
|
|
def sanitize_folder_name(name: str, replacement: str = "_") -> str:
|
|
"""Sanitize a folder name by removing or replacing invalid characters.
|
|
|
|
Args:
|
|
name: The original folder name.
|
|
replacement: The character to use when replacing invalid characters.
|
|
|
|
Returns:
|
|
A sanitized folder name safe to use across common filesystems.
|
|
"""
|
|
|
|
if not name:
|
|
return ""
|
|
|
|
# Replace invalid characters commonly restricted on Windows and POSIX
|
|
invalid_chars_pattern = r'[<>:"/\\|?*\x00-\x1f]'
|
|
sanitized = re.sub(invalid_chars_pattern, replacement, name)
|
|
|
|
# Trim whitespace introduced during sanitization
|
|
sanitized = sanitized.strip()
|
|
|
|
# Collapse repeated replacement characters to a single instance
|
|
if replacement:
|
|
sanitized = re.sub(f"{re.escape(replacement)}+", replacement, sanitized)
|
|
sanitized = sanitized.strip(replacement)
|
|
|
|
# Remove trailing spaces or periods which are invalid on Windows
|
|
sanitized = sanitized.rstrip(" .")
|
|
|
|
if not sanitized:
|
|
return "unnamed"
|
|
|
|
return sanitized
|
|
|
|
|
|
def calculate_recipe_fingerprint(loras):
|
|
"""
|
|
Calculate a unique fingerprint for a recipe based on its LoRAs.
|
|
|
|
The fingerprint is created by sorting LoRA hashes, filtering invalid entries,
|
|
normalizing strength values to 2 decimal places, and joining in format:
|
|
hash1:strength1|hash2:strength2|...
|
|
|
|
Args:
|
|
loras (list): List of LoRA dictionaries with hash and strength values
|
|
|
|
Returns:
|
|
str: The calculated fingerprint
|
|
"""
|
|
if not loras:
|
|
return ""
|
|
|
|
valid_loras = []
|
|
for lora in loras:
|
|
if lora.get("exclude", False):
|
|
continue
|
|
|
|
hash_value = lora.get("hash", "")
|
|
if isinstance(hash_value, str):
|
|
hash_value = hash_value.lower()
|
|
else:
|
|
hash_value = str(hash_value).lower() if hash_value else ""
|
|
if not hash_value and lora.get("modelVersionId"):
|
|
hash_value = str(lora.get("modelVersionId"))
|
|
|
|
if not hash_value:
|
|
continue
|
|
|
|
# Normalize strength to 2 decimal places (check both strength and weight fields)
|
|
strength_val = lora.get("strength", lora.get("weight", 1.0))
|
|
try:
|
|
strength = round(float(strength_val), 2)
|
|
except (ValueError, TypeError):
|
|
strength = 1.0
|
|
|
|
valid_loras.append((hash_value, strength))
|
|
|
|
# Sort by hash
|
|
valid_loras.sort()
|
|
|
|
# Join in format hash1:strength1|hash2:strength2|...
|
|
fingerprint = "|".join(
|
|
[f"{hash_value}:{strength}" for hash_value, strength in valid_loras]
|
|
)
|
|
|
|
return fingerprint
|
|
|
|
|
|
def calculate_relative_path_for_model(
|
|
model_data: Dict, model_type: str = "lora"
|
|
) -> str:
|
|
"""Calculate relative path for existing model using template from settings
|
|
|
|
Args:
|
|
model_data: Model data from scanner cache
|
|
model_type: Type of model ('lora', 'checkpoint', 'embedding')
|
|
|
|
Returns:
|
|
Relative path string (empty string for flat structure)
|
|
"""
|
|
# Get path template from settings for specific model type
|
|
settings_manager = get_settings_manager()
|
|
path_template = settings_manager.get_download_path_template(model_type)
|
|
|
|
# If template is empty, return empty path (flat structure)
|
|
if not path_template:
|
|
return ""
|
|
|
|
# Get base model name from model metadata
|
|
civitai_data = model_data.get("civitai", {})
|
|
|
|
# For CivitAI models, prefer civitai data only if 'id' exists; for non-CivitAI models, use model_data directly
|
|
if civitai_data and civitai_data.get("id") is not None:
|
|
base_model = model_data.get("base_model", "")
|
|
# Get author from civitai creator data
|
|
creator_info = civitai_data.get("creator") or {}
|
|
author = creator_info.get("username") or "Anonymous"
|
|
else:
|
|
# Fallback to model_data fields for non-CivitAI models
|
|
base_model = model_data.get("base_model", "")
|
|
author = "Anonymous" # Default for non-CivitAI models
|
|
|
|
model_tags = model_data.get("tags", [])
|
|
|
|
# Apply mapping if available
|
|
base_model_mappings = settings_manager.get("base_model_path_mappings", {})
|
|
mapped_base_model = base_model_mappings.get(base_model, base_model)
|
|
|
|
# Convert all tags to lowercase to avoid case sensitivity issues on Windows
|
|
lowercase_tags = [tag.lower() for tag in model_tags if isinstance(tag, str)]
|
|
first_tag = settings_manager.resolve_priority_tag_for_model(
|
|
lowercase_tags, model_type
|
|
)
|
|
|
|
if not first_tag:
|
|
first_tag = "no tags" # Default if no tags available
|
|
|
|
# Format the template with available data
|
|
model_name = sanitize_folder_name(model_data.get("model_name", ""))
|
|
version_name = ""
|
|
|
|
if isinstance(civitai_data, dict):
|
|
version_name = sanitize_folder_name(civitai_data.get("name") or "")
|
|
|
|
formatted_path = path_template
|
|
formatted_path = formatted_path.replace("{base_model}", mapped_base_model)
|
|
formatted_path = formatted_path.replace("{first_tag}", first_tag)
|
|
formatted_path = formatted_path.replace("{author}", author)
|
|
formatted_path = formatted_path.replace("{model_name}", model_name)
|
|
formatted_path = formatted_path.replace("{version_name}", version_name)
|
|
|
|
if model_type == "embedding":
|
|
formatted_path = formatted_path.replace(" ", "_")
|
|
|
|
return formatted_path
|
|
|
|
|
|
def remove_empty_dirs(path):
|
|
"""Recursively remove empty directories starting from the given path.
|
|
|
|
Args:
|
|
path (str): Root directory to start cleaning from
|
|
|
|
Returns:
|
|
int: Number of empty directories removed
|
|
"""
|
|
removed_count = 0
|
|
|
|
if not os.path.isdir(path):
|
|
return removed_count
|
|
|
|
# List all files in directory
|
|
files = os.listdir(path)
|
|
|
|
# Process all subdirectories first
|
|
for file in files:
|
|
full_path = os.path.join(path, file)
|
|
if os.path.isdir(full_path):
|
|
removed_count += remove_empty_dirs(full_path)
|
|
|
|
# Check if directory is now empty (after processing subdirectories)
|
|
if not os.listdir(path):
|
|
try:
|
|
os.rmdir(path)
|
|
removed_count += 1
|
|
except OSError:
|
|
pass
|
|
|
|
return removed_count
|