# Mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git
# Synced 2026-05-07 00:46:44 -03:00 (404 lines, 14 KiB, Python)
"""Managed wildcard loading, search, and text expansion.

Wildcard files are read from the ``wildcards`` folder inside the settings
directory (see :func:`get_wildcards_dir`):

* ``.txt`` files contribute one value per non-empty line, keyed by their
  path relative to the wildcard folder without the extension
  (e.g. ``colors/warm.txt`` -> ``colors/warm``).
* ``.yaml`` / ``.yml`` / ``.json`` files are flattened: nested mapping keys
  are joined with ``/`` and every list of strings becomes one wildcard key.

Keys are normalized to forward slashes and lowercase. :class:`WildcardService`
expands ``__key__`` references (``*`` acts as a glob, and a bare name also
matches ``*/name``) as well as ``{a|b|c}`` option groups, including
``weight::`` prefixes and ``N$$`` / ``N-M$$`` multi-select counts.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import Any, Optional
|
|
|
|
import yaml
|
|
|
|
from ..utils.settings_paths import get_settings_dir
|
|
|
|
logger = logging.getLogger(__name__)

# Wildcard reference of the form "__name__". The name may contain word
# characters, whitespace, ".", "-", "+", "/", "*" and backslash; group 1 is the
# name, matched non-greedily so adjacent references stay separate.
_WILDCARD_PATTERN = re.compile(r"__([\w\s.\-+/*\\]+?)__")
# Innermost "{...}" option group (its content contains no braces), so nested
# groups are resolved from the inside out over repeated passes.
_OPTION_PATTERN = re.compile(r"{([^{}]*?)}")
# Input names consisting of "trigger_words" followed by one or more digits.
_TRIGGER_WORD_PATTERN = re.compile(r"^trigger_words\d+$")
# Leading "<digits/dots>::" weight prefix of an option; used to strip it.
_WEIGHTED_OPTION_PATTERN = re.compile(r"^\s*([0-9.]+)::")
# Optionally negative integer or decimal, e.g. "3", "-2", "0.5".
_NUMERIC_PATTERN = re.compile(r"^-?\d+(\.\d+)?$")
|
|
|
|
|
|
def _normalize_wildcard_key(value: str) -> str:
|
|
return value.replace("\\", "/").strip("/").lower()
|
|
|
|
|
|
def _is_numeric_string(value: str) -> bool:
    """Report whether *value* is an optionally negative integer or decimal."""
    numeric_match = _NUMERIC_PATTERN.match(value)
    return numeric_match is not None
|
|
|
|
|
|
def get_wildcards_dir(create: bool = False) -> str:
    """Return the managed wildcard directory inside the settings folder.

    Args:
        create: When true, the settings folder is requested with
            ``create=True`` and the ``wildcards`` subfolder is created if it
            does not exist yet.

    Returns:
        Path of the ``wildcards`` folder (it may not exist when ``create`` is
        false).
    """
    target = os.path.join(get_settings_dir(create=create), "wildcards")
    if not create:
        return target
    os.makedirs(target, exist_ok=True)
    return target
|
|
|
|
|
|
@dataclass(frozen=True)
class WildcardEntry:
    """Immutable summary of one wildcard key and how many values it holds."""

    # Wildcard key; WildcardService.get_entries() passes the normalized
    # (forward-slash, lowercase) keys of its wildcard dict.
    key: str
    # Number of values stored for the key.
    values_count: int
|
|
|
|
|
|
class WildcardService:
    """Discover wildcard keys and expand wildcard syntax.

    The service is a singleton: ``WildcardService()`` always returns the same
    instance, and ``__init__`` only initialises state the first time. The
    wildcard dictionary is rebuilt lazily whenever the set of wildcard files,
    or any file's size or modification time, changes.
    """

    _instance: Optional["WildcardService"] = None

    def __new__(cls) -> "WildcardService":
        # Reuse the instance created by the first call.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # __init__ runs on every WildcardService() call; set up state once only.
        if getattr(self, "_initialized", False):
            return
        self._initialized = True
        # Signature of the wildcard files the cached dict was built from.
        self._cached_signature: tuple[tuple[str, int, int], ...] | None = None
        self._wildcard_dict: dict[str, list[str]] = {}

    @classmethod
    def get_instance(cls) -> "WildcardService":
        """Return the shared service instance."""
        return cls()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def search_keys(
        self, search_term: str, limit: int = 20, offset: int = 0
    ) -> list[str]:
        """Search wildcard keys for autocomplete.

        Args:
            search_term: Raw user input, normalized like a wildcard key.
            limit: Maximum number of keys to return; non-positive yields ``[]``.
            offset: Number of ranked keys to skip; negative values count as 0.

        Returns:
            Matching keys ordered by descending score, then alphabetically.
        """
        normalized_term = _normalize_wildcard_key(search_term).strip()
        if not normalized_term or limit <= 0:
            return []
        # A negative offset would otherwise slice from the end of the ranking.
        offset = max(offset, 0)

        compact_term = normalized_term.replace("/", "")
        ranked: list[tuple[int, str]] = []
        for key in self.get_wildcard_dict():
            score = self._score_entry(key, normalized_term, compact_term)
            if score is not None:
                ranked.append((score, key))

        ranked.sort(key=lambda item: (-item[0], item[1]))
        return [key for _, key in ranked[offset : offset + limit]]

    def expand_text(self, text: str, seed: int | None = None) -> str:
        """Expand wildcard and dynamic prompt syntax for a text value.

        ``{...}`` option groups and ``__key__`` wildcards are resolved in
        repeated passes (at most 100) so that values which themselves contain
        wildcard syntax are expanded as well.

        Args:
            text: Text to expand. Non-string or empty values are returned as-is.
            seed: Optional seed for a reproducible expansion.

        Returns:
            The expanded text.
        """
        if not isinstance(text, str) or not text:
            return text

        # random.Random(None) seeds from system entropy, like random.Random().
        rng = random.Random(seed)
        wildcard_dict = self.get_wildcard_dict()
        if not wildcard_dict:
            return self._expand_options_only(text, rng)

        current = text
        for _ in range(100):
            after_options, options_replaced = self._replace_options(current, rng)
            current, wildcards_replaced = self._replace_wildcards(
                after_options, rng, wildcard_dict
            )
            if not options_replaced and not wildcards_replaced:
                break
        return current

    def get_wildcard_dict(self) -> dict[str, list[str]]:
        """Return the key -> values mapping, rescanning when files changed.

        The returned dict is the internal cache; callers should treat it as
        read-only.
        """
        signature = self._build_signature()
        if signature != self._cached_signature:
            self._wildcard_dict = self._scan_wildcard_dict()
            self._cached_signature = signature
        return self._wildcard_dict

    def get_entries(self) -> list[WildcardEntry]:
        """Return one :class:`WildcardEntry` per key, sorted by key."""
        return [
            WildcardEntry(key=key, values_count=len(values))
            for key, values in sorted(self.get_wildcard_dict().items())
        ]

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    def _build_signature(self) -> tuple[tuple[str, int, int], ...]:
        """Return sorted ``(relative path, mtime_ns, size)`` for wildcard files."""
        root = get_wildcards_dir(create=False)
        if not os.path.isdir(root):
            return ()

        signature: list[tuple[str, int, int]] = []
        for current_root, _dirs, files in os.walk(root, followlinks=True):
            for file_name in sorted(files):
                if not file_name.lower().endswith((".txt", ".yaml", ".yml", ".json")):
                    continue
                file_path = os.path.join(current_root, file_name)
                try:
                    stat = os.stat(file_path)
                except OSError:
                    continue
                rel_path = os.path.relpath(file_path, root).replace("\\", "/")
                signature.append((rel_path, int(stat.st_mtime_ns), int(stat.st_size)))
        signature.sort()
        return tuple(signature)

    def _scan_wildcard_dict(self) -> dict[str, list[str]]:
        """Load every wildcard file under the wildcard folder into one dict.

        Later files override earlier ones on key collisions.
        """
        root = get_wildcards_dir(create=False)
        if not os.path.isdir(root):
            return {}

        collected: dict[str, list[str]] = {}
        for current_root, dirs, files in os.walk(root, followlinks=True):
            # Sort in place so os.walk descends in a stable order: overrides
            # on key collisions must not depend on filesystem listing order.
            dirs.sort()
            for file_name in sorted(files):
                file_path = os.path.join(current_root, file_name)
                lower_name = file_name.lower()
                try:
                    if lower_name.endswith(".txt"):
                        rel_path = os.path.relpath(file_path, root)
                        key = _normalize_wildcard_key(os.path.splitext(rel_path)[0])
                        values = self._read_txt(file_path)
                        if values:
                            collected[key] = values
                    elif lower_name.endswith((".yaml", ".yml")):
                        payload = self._read_yaml(file_path)
                        self._merge_nested_entries(collected, payload)
                    elif lower_name.endswith(".json"):
                        payload = self._read_json(file_path)
                        self._merge_nested_entries(collected, payload)
                except Exception as exc:  # pragma: no cover - defensive logging
                    logger.warning("Failed to load wildcard file %s: %s", file_path, exc)

        return collected

    def _read_txt(self, file_path: str) -> list[str]:
        """Return the stripped, non-empty lines of a text wildcard file.

        ``utf-8-sig`` drops a leading BOM (common for files saved on Windows)
        that would otherwise stick to the first value.
        """
        try:
            with open(file_path, "r", encoding="utf-8-sig", errors="ignore") as handle:
                return [line.strip() for line in handle.read().splitlines() if line.strip()]
        except OSError as exc:
            logger.warning("Failed to read wildcard txt file %s: %s", file_path, exc)
            return []

    def _read_yaml(self, file_path: str) -> Any:
        """Parse a YAML wildcard file with ``safe_load``; empty files give ``{}``."""
        with open(file_path, "r", encoding="utf-8-sig") as handle:
            return yaml.safe_load(handle) or {}

    def _read_json(self, file_path: str) -> Any:
        """Parse a JSON wildcard file.

        ``utf-8-sig`` is required because ``json`` rejects a leading BOM.
        """
        with open(file_path, "r", encoding="utf-8-sig") as handle:
            return json.load(handle)

    def _merge_nested_entries(
        self, collected: dict[str, list[str]], payload: Any
    ) -> None:
        """Add the flattened entries of *payload* to *collected*, overriding."""
        for key, values in self._flatten_payload(payload):
            collected[key] = values

    def _flatten_payload(
        self, payload: Any, prefix: str = ""
    ) -> list[tuple[str, list[str]]]:
        """Flatten nested mappings into ``(normalized key, values)`` pairs.

        Mapping keys are joined with ``/``; a list leaf keeps its non-empty
        string items (stripped). Any other leaf type is ignored.
        """
        entries: list[tuple[str, list[str]]] = []

        if isinstance(payload, dict):
            for key, value in payload.items():
                next_prefix = f"{prefix}/{key}" if prefix else str(key)
                entries.extend(self._flatten_payload(value, next_prefix))
            return entries

        if isinstance(payload, list):
            normalized_prefix = _normalize_wildcard_key(prefix)
            values = [value.strip() for value in payload if isinstance(value, str) and value.strip()]
            if normalized_prefix and values:
                entries.append((normalized_prefix, values))
            return entries

        return entries

    # ------------------------------------------------------------------
    # Search scoring
    # ------------------------------------------------------------------

    def _score_entry(
        self, key: str, normalized_term: str, compact_term: str
    ) -> int | None:
        """Score how well *key* matches the search term (higher is better).

        Returns ``None`` when the key does not match at all.
        """
        key_compact = key.replace("/", "")
        if key == normalized_term:
            return 5000
        if key.startswith(normalized_term):
            return 4000
        if f"/{normalized_term}" in key:
            return 3500
        if normalized_term in key:
            return 3000
        if compact_term and key_compact.startswith(compact_term):
            return 2500
        if compact_term and compact_term in key_compact:
            return 2000
        return None

    # ------------------------------------------------------------------
    # Option groups: {a|b|c}
    # ------------------------------------------------------------------

    def _expand_options_only(self, text: str, rng: random.Random) -> str:
        """Resolve option groups only (at most 100 passes)."""
        current = text
        for _ in range(100):
            current, replaced = self._replace_options(current, rng)
            if not replaced:
                break
        return current

    def _replace_options(
        self, text: str, rng: random.Random
    ) -> tuple[str, bool]:
        """Resolve every innermost ``{...}`` group once.

        Returns the new text and whether any group was replaced.
        """
        replaced_any = False

        def replace_option(match: re.Match[str]) -> str:
            nonlocal replaced_any
            replacement = self._resolve_option_group(match.group(1), rng)
            replaced_any = True
            return replacement

        return _OPTION_PATTERN.sub(replace_option, text), replaced_any

    def _resolve_option_group(self, group_text: str, rng: random.Random) -> str:
        """Resolve the content of one ``{...}`` option group.

        Supported forms:
            ``a|b|c``          pick one option
            ``2::a|b``         weighted pick (``a`` weight 2, ``b`` default 1)
            ``N$$a|b``         pick N options at distinct positions, joined by " "
            ``N-M$$a|b``       pick a random count in [N, M] (bounds may be reversed)
            ``-M$$a|b``        pick a random count in [1, M]
            ``N$$sep$$a|b``    as above, joined with ``sep``

        The count is capped at the number of options; a count of 0 yields "".
        """
        options = group_text.split("|")
        select_range: tuple[int, int] | None = None
        select_separator = " "

        multi_select_parts = options[0].split("$$")
        if len(multi_select_parts) > 1:
            select_range = self._parse_select_range(multi_select_parts[0])
            if select_range is not None:
                if len(multi_select_parts) == 2:
                    options[0] = multi_select_parts[1]
                else:
                    # "count$$sep$$first"; further "$$" segments are dropped.
                    select_separator = multi_select_parts[1]
                    options[0] = multi_select_parts[2]

        weighted_options = [self._split_weight(option) for option in options]

        if select_range is None:
            selection_count = 1
        else:
            selection_count = rng.randint(select_range[0], select_range[1])

        if selection_count <= 0:
            return ""
        if selection_count == 1:
            return self._weighted_choice(weighted_options, rng)

        indexes = self._select_distinct_indexes(weighted_options, selection_count, rng)
        return select_separator.join(weighted_options[index][1] for index in indexes)

    @staticmethod
    def _parse_select_range(count_spec: str) -> tuple[int, int] | None:
        """Parse the count before ``$$`` into an inclusive ``(low, high)`` range.

        Accepts ``N``, ``N-M`` (either order) and ``-M`` (meaning 1..M).
        Returns ``None`` when *count_spec* is not a count.
        """
        range_match = re.match(r"(\d+)(?:-(\d+))?$", count_spec)
        if range_match:
            start = int(range_match.group(1))
            end_text = range_match.group(2)
            end = int(end_text) if end_text is not None else start
            # Normalize "3-1" so rng.randint never sees low > high.
            return (min(start, end), max(start, end))

        shorthand_match = re.match(r"-(\d+)$", count_spec)
        if shorthand_match:
            end = int(shorthand_match.group(1))
            # "-0" means "at most zero"; keep low <= high for randint.
            return (min(1, end), end)

        return None

    def _split_weight(self, option: str) -> tuple[float, str]:
        """Split an optional ``weight::`` prefix off *option*.

        Returns ``(weight, text)``; the weight defaults to 1.0.
        """
        parts = option.split("::", 1)
        weight_text = parts[0].strip()
        if len(parts) == 2 and _is_numeric_string(weight_text):
            # Use the remainder directly so prefixes the strip regex does not
            # recognise (e.g. "-1::" or "2 ::") are removed from the output too.
            return float(weight_text), parts[1]
        return 1.0, self._strip_weight_prefix(option)

    def _select_distinct_indexes(
        self,
        weighted_options: list[tuple[float, str]],
        count: int,
        rng: random.Random,
    ) -> list[int]:
        """Pick up to *count* distinct option indexes, honouring weights.

        Weighted picks are redrawn when they hit an already-used index. Once
        only non-positive-weight options remain (while some option had a
        positive weight), weighted picking can never reach them, so the rest
        are drawn uniformly instead of looping forever.
        """
        count = min(count, len(weighted_options))
        positive_left = sum(1 for weight, _value in weighted_options if weight > 0)
        has_positive = positive_left > 0

        used: set[int] = set()
        selected: list[int] = []
        while len(selected) < count:
            if has_positive and positive_left == 0:
                unused = [i for i in range(len(weighted_options)) if i not in used]
                picked_index = unused[rng.randrange(len(unused))]
            else:
                picked_index = self._weighted_choice_index(weighted_options, rng)
                if picked_index in used:
                    continue
            used.add(picked_index)
            selected.append(picked_index)
            if weighted_options[picked_index][0] > 0:
                positive_left -= 1
        return selected

    def _weighted_choice(
        self, weighted_options: list[tuple[float, str]], rng: random.Random
    ) -> str:
        """Return the text of one weighted random option."""
        return weighted_options[self._weighted_choice_index(weighted_options, rng)][1]

    def _weighted_choice_index(
        self, weighted_options: list[tuple[float, str]], rng: random.Random
    ) -> int:
        """Return a random index, proportional to weights (negatives count as 0).

        Falls back to a uniform pick when no option has a positive weight.
        """
        total_weight = sum(max(weight, 0.0) for weight, _value in weighted_options)
        if total_weight <= 0:
            return rng.randrange(len(weighted_options))

        threshold = rng.uniform(0, total_weight)
        cumulative = 0.0
        for index, (weight, _value) in enumerate(weighted_options):
            cumulative += max(weight, 0.0)
            if threshold <= cumulative:
                return index
        return len(weighted_options) - 1

    def _strip_weight_prefix(self, value: str) -> str:
        """Remove a leading ``<digits/dots>::`` weight prefix, if present."""
        return _WEIGHTED_OPTION_PATTERN.sub("", value, count=1)

    # ------------------------------------------------------------------
    # Wildcards: __key__
    # ------------------------------------------------------------------

    def _replace_wildcards(
        self,
        text: str,
        rng: random.Random,
        wildcard_dict: dict[str, list[str]],
    ) -> tuple[str, bool]:
        """Replace every resolvable ``__key__`` once; unknown keys stay as-is.

        Returns the new text and whether anything was replaced.
        """
        replaced_any = False

        def replace_match(match: re.Match[str]) -> str:
            nonlocal replaced_any
            replacement = self._resolve_wildcard_match(match.group(1), rng, wildcard_dict)
            if replacement is None:
                return match.group(0)
            replaced_any = True
            return replacement

        return _WILDCARD_PATTERN.sub(replace_match, text), replaced_any

    def _resolve_wildcard_match(
        self,
        raw_key: str,
        rng: random.Random,
        wildcard_dict: dict[str, list[str]],
    ) -> str | None:
        """Pick a random value for a wildcard key, or ``None`` if unresolvable.

        Lookup order: exact key; ``*`` glob over all keys (values of every
        matching key are pooled); for bare names, the same lookup as
        ``*/name``.
        """
        keyword = _normalize_wildcard_key(raw_key)
        if not keyword:
            # e.g. "__/__": an empty key must not fall back to "*", which
            # would match every wildcard.
            return None

        direct_values = wildcard_dict.get(keyword)
        if direct_values:
            return rng.choice(direct_values)

        if "*" in keyword:
            compiled = self._compile_glob(keyword)
            aggregated = [
                value
                for key, values in wildcard_dict.items()
                if compiled.match(key)
                for value in values
            ]
            if aggregated:
                return rng.choice(aggregated)

        if "/" not in keyword:
            # A bare name also matches that name inside any folder.
            fallback_keyword = _normalize_wildcard_key(f"*/{keyword}")
            if fallback_keyword != keyword:
                return self._resolve_wildcard_match(fallback_keyword, rng, wildcard_dict)

        return None

    @staticmethod
    def _compile_glob(keyword: str) -> re.Pattern[str]:
        """Compile a ``*`` glob into an anchored regex.

        Every character other than ``*`` is matched literally, so a ``.`` in a
        key no longer matches arbitrary characters.
        """
        body = ".*".join(re.escape(part) for part in keyword.split("*"))
        return re.compile(f"^{body}$")
|
|
|
|
|
|
def is_trigger_words_input(name: str) -> bool:
    """Return True when *name* is ``trigger_words`` followed by digits."""
    trigger_match = _TRIGGER_WORD_PATTERN.match(name)
    return trigger_match is not None
|
|
|
|
|
|
def get_wildcard_service() -> WildcardService:
    """Return the shared :class:`WildcardService` singleton."""
    service = WildcardService.get_instance()
    return service
|
|
|
|
|
|
# Public API of this module. WildcardEntry is exported because it is the
# return type of WildcardService.get_entries().
__all__ = [
    "WildcardEntry",
    "WildcardService",
    "get_wildcard_service",
    "get_wildcards_dir",
    "is_trigger_words_input",
]
|