fix(extra-paths): support trigger words for LoRAs in extra folder paths, fixes #860

- Update get_lora_info() to check both loras_roots and extra_loras_roots
- Add fallback logic to return trigger words even if path not in recognized roots
- Ensure Trigger Word Toggle node displays trigger words for LoRAs from extra folder paths

Fixes issue where LoRAs added from extra folder paths would not show their trigger words in connected Trigger Word Toggle nodes.
This commit is contained in:
Will Miao
2026-03-16 09:36:01 +08:00
parent 55a18d401b
commit c89d4dae85

View File

@@ -7,33 +7,47 @@ from ..config import config
from ..services.settings_manager import get_settings_manager from ..services.settings_manager import get_settings_manager
import asyncio import asyncio
def get_lora_info(lora_name):
    """Resolve a LoRA's root-relative path and trigger words from the scanner cache.

    Args:
        lora_name: The LoRA file name as stored in the cache (`file_name` field).

    Returns:
        tuple: (relative_path, trigger_words) when the file lives under a
        configured root; (file_path, trigger_words) when it is cached but sits
        outside every recognized root (keeps trigger words working for extra
        folder paths, see #860); (lora_name, []) when not found in the cache.
    """

    def _trigger_words(item):
        # Trigger words live in the civitai metadata; tolerate a missing/None block.
        civitai = item.get("civitai", {})
        return civitai.get("trainedWords", []) if civitai else []

    async def _get_lora_info_async():
        scanner = await ServiceRegistry.get_lora_scanner()
        cache = await scanner.get_cached_data()
        for item in cache.raw_data:
            if item.get("file_name") != lora_name:
                continue
            file_path = item.get("file_path")
            if not file_path:
                continue
            # Check all lora roots including extra paths (fixes #860)
            all_roots = list(config.loras_roots or []) + list(
                config.extra_loras_roots or []
            )
            for root in all_roots:
                root = root.replace(os.sep, "/")
                # Require a separator boundary so a sibling root that merely
                # shares a prefix (e.g. ".../loras" vs ".../loras2") can't match.
                prefix = root if root.endswith("/") else root + "/"
                if file_path.startswith(prefix):
                    relative_path = os.path.relpath(file_path, root).replace(
                        os.sep, "/"
                    )
                    return relative_path, _trigger_words(item)
            # Not under any recognized root: still surface the trigger words.
            return file_path, _trigger_words(item)
        return lora_name, []

    try:
        # Check if we're already in an event loop; keep the try body minimal so
        # a RuntimeError from the worker below isn't silently swallowed.
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running, we can use asyncio.run()
        return asyncio.run(_get_lora_info_async())

    # A loop is already running: run_until_complete() can't nest, so execute
    # the coroutine on a private loop owned by a worker thread.
    import concurrent.futures

    def run_in_thread():
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        try:
            return new_loop.run_until_complete(_get_lora_info_async())
        finally:
            new_loop.close()

    with concurrent.futures.ThreadPoolExecutor() as executor:
        return executor.submit(run_in_thread).result()
@@ -53,33 +67,34 @@ def get_lora_info(lora_name):
def get_lora_info_absolute(lora_name):
    """Get the absolute lora path and trigger words from cache

    Returns:
        tuple: (absolute_path, trigger_words) where absolute_path is the full
        file system path to the LoRA file, or original lora_name if not found
    """

    async def _get_lora_info_absolute_async():
        scanner = await ServiceRegistry.get_lora_scanner()
        cache = await scanner.get_cached_data()
        for entry in cache.raw_data:
            if entry.get("file_name") != lora_name:
                continue
            path = entry.get("file_path")
            if path:
                # Absolute path is returned as-is; trigger words come from the
                # civitai metadata when it is present.
                meta = entry.get("civitai", {})
                words = meta.get("trainedWords", []) if meta else []
                return path, words
        return lora_name, []

    try:
        # Check if we're already in an event loop
        asyncio.get_running_loop()
        # A loop is running: run_until_complete() can't nest, so hand the
        # coroutine to a fresh loop owned by a worker thread.
        import concurrent.futures

        def _worker():
            private_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(private_loop)
            try:
                return private_loop.run_until_complete(
                    _get_lora_info_absolute_async()
                )
            finally:
                private_loop.close()

        with concurrent.futures.ThreadPoolExecutor() as pool:
            return pool.submit(_worker).result()
    except RuntimeError:
        # No event loop is running, we can use asyncio.run()
        return asyncio.run(_get_lora_info_absolute_async())
def fuzzy_match(text: str, pattern: str, threshold: float = 0.85) -> bool:
    """
    Check if text matches pattern using fuzzy matching.
    Returns True if similarity ratio is above threshold.
    """
    if not text or not pattern:
        return False
    # Case-insensitive comparison throughout.
    haystack = text.lower()
    tokens = haystack.split()
    # Every word of the pattern must be found, verbatim or fuzzily.
    for needle in pattern.lower().split():
        # Cheap substring test first.
        if needle in haystack:
            continue
        # Fall back to per-token similarity scoring against the threshold.
        if not any(
            SequenceMatcher(None, token, needle).ratio() >= threshold
            for token in tokens
        ):
            return False
    # All words found either as substrings or fuzzy matches
    return True
def sanitize_folder_name(name: str, replacement: str = "_") -> str: def sanitize_folder_name(name: str, replacement: str = "_") -> str:
"""Sanitize a folder name by removing or replacing invalid characters. """Sanitize a folder name by removing or replacing invalid characters.
@@ -170,25 +187,25 @@ def sanitize_folder_name(name: str, replacement: str = "_") -> str:
def calculate_recipe_fingerprint(loras): def calculate_recipe_fingerprint(loras):
""" """
Calculate a unique fingerprint for a recipe based on its LoRAs. Calculate a unique fingerprint for a recipe based on its LoRAs.
The fingerprint is created by sorting LoRA hashes, filtering invalid entries, The fingerprint is created by sorting LoRA hashes, filtering invalid entries,
normalizing strength values to 2 decimal places, and joining in format: normalizing strength values to 2 decimal places, and joining in format:
hash1:strength1|hash2:strength2|... hash1:strength1|hash2:strength2|...
Args: Args:
loras (list): List of LoRA dictionaries with hash and strength values loras (list): List of LoRA dictionaries with hash and strength values
Returns: Returns:
str: The calculated fingerprint str: The calculated fingerprint
""" """
if not loras: if not loras:
return "" return ""
valid_loras = [] valid_loras = []
for lora in loras: for lora in loras:
if lora.get("exclude", False): if lora.get("exclude", False):
continue continue
hash_value = lora.get("hash", "") hash_value = lora.get("hash", "")
if isinstance(hash_value, str): if isinstance(hash_value, str):
hash_value = hash_value.lower() hash_value = hash_value.lower()
@@ -206,18 +223,23 @@ def calculate_recipe_fingerprint(loras):
strength = round(float(strength_val), 2) strength = round(float(strength_val), 2)
except (ValueError, TypeError): except (ValueError, TypeError):
strength = 1.0 strength = 1.0
valid_loras.append((hash_value, strength)) valid_loras.append((hash_value, strength))
# Sort by hash # Sort by hash
valid_loras.sort() valid_loras.sort()
# Join in format hash1:strength1|hash2:strength2|... # Join in format hash1:strength1|hash2:strength2|...
fingerprint = "|".join([f"{hash_value}:{strength}" for hash_value, strength in valid_loras]) fingerprint = "|".join(
[f"{hash_value}:{strength}" for hash_value, strength in valid_loras]
)
return fingerprint return fingerprint
def calculate_relative_path_for_model(model_data: Dict, model_type: str = 'lora') -> str:
def calculate_relative_path_for_model(
model_data: Dict, model_type: str = "lora"
) -> str:
"""Calculate relative path for existing model using template from settings """Calculate relative path for existing model using template from settings
Args: Args:
@@ -233,77 +255,80 @@ def calculate_relative_path_for_model(model_data: Dict, model_type: str = 'lora'
# If template is empty, return empty path (flat structure) # If template is empty, return empty path (flat structure)
if not path_template: if not path_template:
return '' return ""
# Get base model name from model metadata # Get base model name from model metadata
civitai_data = model_data.get('civitai', {}) civitai_data = model_data.get("civitai", {})
# For CivitAI models, prefer civitai data only if 'id' exists; for non-CivitAI models, use model_data directly # For CivitAI models, prefer civitai data only if 'id' exists; for non-CivitAI models, use model_data directly
if civitai_data and civitai_data.get('id') is not None: if civitai_data and civitai_data.get("id") is not None:
base_model = model_data.get('base_model', '') base_model = model_data.get("base_model", "")
# Get author from civitai creator data # Get author from civitai creator data
creator_info = civitai_data.get('creator') or {} creator_info = civitai_data.get("creator") or {}
author = creator_info.get('username') or 'Anonymous' author = creator_info.get("username") or "Anonymous"
else: else:
# Fallback to model_data fields for non-CivitAI models # Fallback to model_data fields for non-CivitAI models
base_model = model_data.get('base_model', '') base_model = model_data.get("base_model", "")
author = 'Anonymous' # Default for non-CivitAI models author = "Anonymous" # Default for non-CivitAI models
model_tags = model_data.get('tags', []) model_tags = model_data.get("tags", [])
# Apply mapping if available # Apply mapping if available
base_model_mappings = settings_manager.get('base_model_path_mappings', {}) base_model_mappings = settings_manager.get("base_model_path_mappings", {})
mapped_base_model = base_model_mappings.get(base_model, base_model) mapped_base_model = base_model_mappings.get(base_model, base_model)
# Convert all tags to lowercase to avoid case sensitivity issues on Windows # Convert all tags to lowercase to avoid case sensitivity issues on Windows
lowercase_tags = [tag.lower() for tag in model_tags if isinstance(tag, str)] lowercase_tags = [tag.lower() for tag in model_tags if isinstance(tag, str)]
first_tag = settings_manager.resolve_priority_tag_for_model(lowercase_tags, model_type) first_tag = settings_manager.resolve_priority_tag_for_model(
lowercase_tags, model_type
)
if not first_tag: if not first_tag:
first_tag = 'no tags' # Default if no tags available first_tag = "no tags" # Default if no tags available
# Format the template with available data # Format the template with available data
model_name = sanitize_folder_name(model_data.get('model_name', '')) model_name = sanitize_folder_name(model_data.get("model_name", ""))
version_name = '' version_name = ""
if isinstance(civitai_data, dict): if isinstance(civitai_data, dict):
version_name = sanitize_folder_name(civitai_data.get('name') or '') version_name = sanitize_folder_name(civitai_data.get("name") or "")
formatted_path = path_template formatted_path = path_template
formatted_path = formatted_path.replace('{base_model}', mapped_base_model) formatted_path = formatted_path.replace("{base_model}", mapped_base_model)
formatted_path = formatted_path.replace('{first_tag}', first_tag) formatted_path = formatted_path.replace("{first_tag}", first_tag)
formatted_path = formatted_path.replace('{author}', author) formatted_path = formatted_path.replace("{author}", author)
formatted_path = formatted_path.replace('{model_name}', model_name) formatted_path = formatted_path.replace("{model_name}", model_name)
formatted_path = formatted_path.replace('{version_name}', version_name) formatted_path = formatted_path.replace("{version_name}", version_name)
if model_type == 'embedding': if model_type == "embedding":
formatted_path = formatted_path.replace(' ', '_') formatted_path = formatted_path.replace(" ", "_")
return formatted_path return formatted_path
def remove_empty_dirs(path):
    """Recursively remove empty directories starting from the given path.

    Args:
        path (str): Root directory to start cleaning from

    Returns:
        int: Number of empty directories removed
    """
    if not os.path.isdir(path):
        return 0
    removed = 0
    # Depth-first: clean children before testing whether *path* emptied out.
    for entry in os.listdir(path):
        child = os.path.join(path, entry)
        if os.path.isdir(child):
            removed += remove_empty_dirs(child)
    # Re-list after the recursion — subdirectories may have been deleted.
    if not os.listdir(path):
        try:
            os.rmdir(path)
            removed += 1
        except OSError:
            # Best-effort cleanup: ignore races/permission problems.
            pass
    return removed