feat(prompt): expand wildcards at runtime (#895)

This commit is contained in:
Will Miao
2026-04-15 20:42:27 +08:00
parent 6d0d9600a7
commit 62247bdd87
15 changed files with 831 additions and 31 deletions

View File

@@ -1,15 +1,34 @@
from __future__ import annotations
from typing import Any from typing import Any
import inspect import inspect
from ..services.wildcard_service import get_wildcard_service, is_trigger_words_input
class _AllContainer:
"""Container that accepts any key for dynamic input validation."""
def __contains__(self, item): class _PromptOptionalInputs:
return True """Lookup that preserves explicit optional inputs and dynamic trigger slots."""
def __getitem__(self, key): def __init__(self, explicit_inputs: dict[str, tuple[str, dict[str, Any]]]) -> None:
return ("STRING", {"forceInput": True}) self._explicit_inputs = explicit_inputs
def __contains__(self, item: object) -> bool:
if not isinstance(item, str):
return False
return item in self._explicit_inputs or is_trigger_words_input(item)
def __getitem__(self, key: str) -> tuple[str, dict[str, Any]]:
if key in self._explicit_inputs:
return self._explicit_inputs[key]
if is_trigger_words_input(key):
return (
"STRING",
{
"forceInput": True,
"tooltip": "Trigger words to prepend. Connect to add more inputs.",
},
)
raise KeyError(key)
class PromptLM: class PromptLM:
@@ -20,12 +39,19 @@ class PromptLM:
DESCRIPTION = ( DESCRIPTION = (
"Encodes a text prompt using a CLIP model into an embedding that can be used " "Encodes a text prompt using a CLIP model into an embedding that can be used "
"to guide the diffusion model towards generating specific images. " "to guide the diffusion model towards generating specific images. "
"Supports dynamic trigger words inputs." "Supports dynamic trigger words inputs and runtime wildcard expansion."
) )
@classmethod @classmethod
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
dyn_inputs = { optional_inputs: dict[str, tuple[str, dict[str, Any]]] = {
"seed": (
"INT",
{
"forceInput": True,
"tooltip": "Optional seed for wildcard generation. Leave unconnected for non-deterministic wildcard expansion.",
},
),
"trigger_words1": ( "trigger_words1": (
"STRING", "STRING",
{ {
@@ -35,10 +61,9 @@ class PromptLM:
), ),
} }
# Bypass validation for dynamic inputs during graph execution
stack = inspect.stack() stack = inspect.stack()
if len(stack) > 2 and stack[2].function == "get_input_info": if len(stack) > 2 and stack[2].function == "get_input_info":
dyn_inputs = _AllContainer() optional_inputs = _PromptOptionalInputs(optional_inputs) # type: ignore[assignment]
return { return {
"required": { "required": {
@@ -46,8 +71,8 @@ class PromptLM:
"AUTOCOMPLETE_TEXT_PROMPT,STRING", "AUTOCOMPLETE_TEXT_PROMPT,STRING",
{ {
"widgetType": "AUTOCOMPLETE_TEXT_PROMPT", "widgetType": "AUTOCOMPLETE_TEXT_PROMPT",
"placeholder": "Enter prompt... /char, /artist for quick tag search", "placeholder": "Enter prompt... /char, /artist, /wild for quick search",
"tooltip": "The text to be encoded.", "tooltip": "The text to be encoded. Wildcard references inserted with /wild are expanded at runtime.",
}, },
), ),
"clip": ( "clip": (
@@ -55,7 +80,7 @@ class PromptLM:
{"tooltip": "The CLIP model used for encoding the text."}, {"tooltip": "The CLIP model used for encoding the text."},
), ),
}, },
"optional": dyn_inputs, "optional": optional_inputs,
} }
RETURN_TYPES = ("CONDITIONING", "STRING") RETURN_TYPES = ("CONDITIONING", "STRING")
@@ -65,20 +90,26 @@ class PromptLM:
) )
FUNCTION = "encode" FUNCTION = "encode"
def encode(self, text: str, clip: Any, **kwargs): def encode(
# Collect all trigger words from dynamic inputs self,
text: str,
clip: Any,
seed: int | None = None,
**kwargs: Any,
):
expanded_text = get_wildcard_service().expand_text(text, seed=seed)
trigger_words = [] trigger_words = []
for key, value in kwargs.items(): for key, value in kwargs.items():
if key.startswith("trigger_words") and value: if is_trigger_words_input(key) and value:
trigger_words.append(value) trigger_words.append(value)
# Build final prompt
if trigger_words: if trigger_words:
prompt = ", ".join(trigger_words + [text]) prompt = ", ".join(trigger_words + [expanded_text])
else: else:
prompt = text prompt = expanded_text
from nodes import CLIPTextEncode # type: ignore from nodes import CLIPTextEncode # type: ignore
conditioning = CLIPTextEncode().encode(clip, prompt)[0] conditioning = CLIPTextEncode().encode(clip, prompt)[0]
return (conditioning, prompt) return (conditioning, prompt)

View File

@@ -1,10 +1,15 @@
from __future__ import annotations
from ..services.wildcard_service import get_wildcard_service
class TextLM: class TextLM:
"""A simple text node with autocomplete support.""" """A simple text node with autocomplete support."""
NAME = "Text (LoraManager)" NAME = "Text (LoraManager)"
CATEGORY = "Lora Manager/utils" CATEGORY = "Lora Manager/utils"
DESCRIPTION = ( DESCRIPTION = (
"A simple text input node with autocomplete support for tags and styles." "A simple text input node with autocomplete support for tags, styles, and wildcard expansion."
) )
@classmethod @classmethod
@@ -15,8 +20,17 @@ class TextLM:
"AUTOCOMPLETE_TEXT_PROMPT,STRING", "AUTOCOMPLETE_TEXT_PROMPT,STRING",
{ {
"widgetType": "AUTOCOMPLETE_TEXT_PROMPT", "widgetType": "AUTOCOMPLETE_TEXT_PROMPT",
"placeholder": "Enter text... /char, /artist for quick tag search", "placeholder": "Enter text... /char, /artist, /wild for quick search",
"tooltip": "The text output.", "tooltip": "The text output. Wildcard references inserted with /wild are expanded at runtime.",
},
),
},
"optional": {
"seed": (
"INT",
{
"forceInput": True,
"tooltip": "Optional seed for wildcard generation. Leave unconnected for non-deterministic wildcard expansion.",
}, },
), ),
}, },
@@ -24,10 +38,8 @@ class TextLM:
RETURN_TYPES = ("STRING",) RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("STRING",) RETURN_NAMES = ("STRING",)
OUTPUT_TOOLTIPS = ( OUTPUT_TOOLTIPS = ("The text output.",)
"The text output.",
)
FUNCTION = "process" FUNCTION = "process"
def process(self, text: str): def process(self, text: str, seed: int | None = None):
return (text,) return (get_wildcard_service().expand_text(text, seed=seed),)

View File

@@ -2489,6 +2489,30 @@ class CustomWordsHandler:
return None return None
class WildcardsHandler:
"""Handler for wildcard autocomplete search."""
def __init__(self, *, service=None) -> None:
if service is None:
from ...services.wildcard_service import get_wildcard_service
service = get_wildcard_service()
self._service = service
async def search_wildcards(self, request: web.Request) -> web.Response:
"""Search managed wildcard keys for autocomplete."""
try:
search_term = request.query.get("search", "")
limit = min(int(request.query.get("limit", "20")), 100)
offset = max(0, int(request.query.get("offset", "0")))
results = self._service.search_keys(search_term, limit=limit, offset=offset)
return web.json_response({"success": True, "words": results})
except Exception as exc:
logger.error("Error searching wildcards: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
class NodeRegistryHandler: class NodeRegistryHandler:
def __init__( def __init__(
self, self,
@@ -2717,6 +2741,7 @@ class MiscHandlerSet:
backup: BackupHandler, backup: BackupHandler,
filesystem: FileSystemHandler, filesystem: FileSystemHandler,
custom_words: CustomWordsHandler, custom_words: CustomWordsHandler,
wildcards: WildcardsHandler,
supporters: SupportersHandler, supporters: SupportersHandler,
doctor: DoctorHandler, doctor: DoctorHandler,
example_workflows: ExampleWorkflowsHandler, example_workflows: ExampleWorkflowsHandler,
@@ -2734,6 +2759,7 @@ class MiscHandlerSet:
self.backup = backup self.backup = backup
self.filesystem = filesystem self.filesystem = filesystem
self.custom_words = custom_words self.custom_words = custom_words
self.wildcards = wildcards
self.supporters = supporters self.supporters = supporters
self.doctor = doctor self.doctor = doctor
self.example_workflows = example_workflows self.example_workflows = example_workflows
@@ -2775,6 +2801,7 @@ class MiscHandlerSet:
"open_settings_location": self.filesystem.open_settings_location, "open_settings_location": self.filesystem.open_settings_location,
"open_backup_location": self.filesystem.open_backup_location, "open_backup_location": self.filesystem.open_backup_location,
"search_custom_words": self.custom_words.search_custom_words, "search_custom_words": self.custom_words.search_custom_words,
"search_wildcards": self.wildcards.search_wildcards,
"get_supporters": self.supporters.get_supporters, "get_supporters": self.supporters.get_supporters,
"get_example_workflows": self.example_workflows.get_example_workflows, "get_example_workflows": self.example_workflows.get_example_workflows,
"get_example_workflow": self.example_workflows.get_example_workflow, "get_example_workflow": self.example_workflows.get_example_workflow,

View File

@@ -30,6 +30,7 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("POST", "/api/lm/settings/libraries/activate", "activate_library"), RouteDefinition("POST", "/api/lm/settings/libraries/activate", "activate_library"),
RouteDefinition("GET", "/api/lm/health-check", "health_check"), RouteDefinition("GET", "/api/lm/health-check", "health_check"),
RouteDefinition("GET", "/api/lm/supporters", "get_supporters"), RouteDefinition("GET", "/api/lm/supporters", "get_supporters"),
RouteDefinition("GET", "/api/lm/wildcards/search", "search_wildcards"),
RouteDefinition("POST", "/api/lm/open-file-location", "open_file_location"), RouteDefinition("POST", "/api/lm/open-file-location", "open_file_location"),
RouteDefinition("POST", "/api/lm/update-usage-stats", "update_usage_stats"), RouteDefinition("POST", "/api/lm/update-usage-stats", "update_usage_stats"),
RouteDefinition("GET", "/api/lm/get-usage-stats", "get_usage_stats"), RouteDefinition("GET", "/api/lm/get-usage-stats", "get_usage_stats"),

View File

@@ -35,6 +35,7 @@ from .handlers.misc_handlers import (
SupportersHandler, SupportersHandler,
TrainedWordsHandler, TrainedWordsHandler,
UsageStatsHandler, UsageStatsHandler,
WildcardsHandler,
build_service_registry_adapter, build_service_registry_adapter,
) )
from .handlers.base_model_handlers import BaseModelHandlerSet from .handlers.base_model_handlers import BaseModelHandlerSet
@@ -130,6 +131,7 @@ class MiscRoutes:
metadata_provider_factory=self._metadata_provider_factory, metadata_provider_factory=self._metadata_provider_factory,
) )
custom_words = CustomWordsHandler() custom_words = CustomWordsHandler()
wildcards = WildcardsHandler()
supporters = SupportersHandler() supporters = SupportersHandler()
doctor = DoctorHandler(settings_service=self._settings) doctor = DoctorHandler(settings_service=self._settings)
example_workflows = ExampleWorkflowsHandler() example_workflows = ExampleWorkflowsHandler()
@@ -148,6 +150,7 @@ class MiscRoutes:
backup=backup, backup=backup,
filesystem=filesystem, filesystem=filesystem,
custom_words=custom_words, custom_words=custom_words,
wildcards=wildcards,
supporters=supporters, supporters=supporters,
doctor=doctor, doctor=doctor,
example_workflows=example_workflows, example_workflows=example_workflows,

View File

@@ -0,0 +1,403 @@
"""Managed wildcard loading, search, and text expansion."""
from __future__ import annotations
import json
import logging
import os
import random
import re
from dataclasses import dataclass
from typing import Any, Optional
import yaml
from ..utils.settings_paths import get_settings_dir
logger = logging.getLogger(__name__)
_WILDCARD_PATTERN = re.compile(r"__([\w\s.\-+/*\\]+?)__")
_OPTION_PATTERN = re.compile(r"{([^{}]*?)}")
_TRIGGER_WORD_PATTERN = re.compile(r"^trigger_words\d+$")
_WEIGHTED_OPTION_PATTERN = re.compile(r"^\s*([0-9.]+)::")
_NUMERIC_PATTERN = re.compile(r"^-?\d+(\.\d+)?$")
def _normalize_wildcard_key(value: str) -> str:
return value.replace("\\", "/").strip("/").lower()
def _is_numeric_string(value: str) -> bool:
return bool(_NUMERIC_PATTERN.match(value))
def get_wildcards_dir(create: bool = False) -> str:
"""Return the managed wildcard directory inside the settings folder."""
settings_dir = get_settings_dir(create=create)
wildcards_dir = os.path.join(settings_dir, "wildcards")
if create:
os.makedirs(wildcards_dir, exist_ok=True)
return wildcards_dir
@dataclass(frozen=True)
class WildcardEntry:
key: str
values_count: int
class WildcardService:
"""Discover wildcard keys and expand wildcard syntax."""
_instance: Optional["WildcardService"] = None
def __new__(cls) -> "WildcardService":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
if getattr(self, "_initialized", False):
return
self._initialized = True
self._cached_signature: tuple[tuple[str, int, int], ...] | None = None
self._wildcard_dict: dict[str, list[str]] = {}
@classmethod
def get_instance(cls) -> "WildcardService":
return cls()
def search_keys(
self, search_term: str, limit: int = 20, offset: int = 0
) -> list[str]:
"""Search wildcard keys for autocomplete."""
normalized_term = _normalize_wildcard_key(search_term).strip()
if not normalized_term:
return []
ranked: list[tuple[int, str]] = []
compact_term = normalized_term.replace("/", "")
for key in self.get_wildcard_dict().keys():
score = self._score_entry(key, normalized_term, compact_term)
if score is not None:
ranked.append((score, key))
ranked.sort(key=lambda item: (-item[0], item[1]))
keys = [key for _, key in ranked]
return keys[offset : offset + limit]
def expand_text(self, text: str, seed: int | None = None) -> str:
"""Expand wildcard and dynamic prompt syntax for a text value."""
if not isinstance(text, str) or not text:
return text
rng = random.Random(seed) if seed is not None else random.Random()
wildcard_dict = self.get_wildcard_dict()
if not wildcard_dict:
return self._expand_options_only(text, rng)
current = text
remaining_depth = 100
while remaining_depth > 0:
remaining_depth -= 1
after_options, options_replaced = self._replace_options(current, rng)
current, wildcards_replaced = self._replace_wildcards(
after_options, rng, wildcard_dict
)
if not options_replaced and not wildcards_replaced:
break
return current
def get_wildcard_dict(self) -> dict[str, list[str]]:
signature = self._build_signature()
if signature != self._cached_signature:
self._wildcard_dict = self._scan_wildcard_dict()
self._cached_signature = signature
return self._wildcard_dict
def get_entries(self) -> list[WildcardEntry]:
return [
WildcardEntry(key=key, values_count=len(values))
for key, values in sorted(self.get_wildcard_dict().items())
]
def _build_signature(self) -> tuple[tuple[str, int, int], ...]:
root = get_wildcards_dir(create=False)
if not os.path.isdir(root):
return ()
signature: list[tuple[str, int, int]] = []
for current_root, _dirs, files in os.walk(root, followlinks=True):
for file_name in sorted(files):
if not file_name.lower().endswith((".txt", ".yaml", ".yml", ".json")):
continue
file_path = os.path.join(current_root, file_name)
try:
stat = os.stat(file_path)
except OSError:
continue
rel_path = os.path.relpath(file_path, root).replace("\\", "/")
signature.append((rel_path, int(stat.st_mtime_ns), int(stat.st_size)))
signature.sort()
return tuple(signature)
def _scan_wildcard_dict(self) -> dict[str, list[str]]:
root = get_wildcards_dir(create=False)
if not os.path.isdir(root):
return {}
collected: dict[str, list[str]] = {}
for current_root, _dirs, files in os.walk(root, followlinks=True):
for file_name in sorted(files):
file_path = os.path.join(current_root, file_name)
lower_name = file_name.lower()
try:
if lower_name.endswith(".txt"):
rel_path = os.path.relpath(file_path, root)
key = _normalize_wildcard_key(os.path.splitext(rel_path)[0])
values = self._read_txt(file_path)
if values:
collected[key] = values
elif lower_name.endswith((".yaml", ".yml")):
payload = self._read_yaml(file_path)
self._merge_nested_entries(collected, payload)
elif lower_name.endswith(".json"):
payload = self._read_json(file_path)
self._merge_nested_entries(collected, payload)
except Exception as exc: # pragma: no cover - defensive logging
logger.warning("Failed to load wildcard file %s: %s", file_path, exc)
return collected
def _read_txt(self, file_path: str) -> list[str]:
try:
with open(file_path, "r", encoding="utf-8", errors="ignore") as handle:
return [line.strip() for line in handle.read().splitlines() if line.strip()]
except OSError as exc:
logger.warning("Failed to read wildcard txt file %s: %s", file_path, exc)
return []
def _read_yaml(self, file_path: str) -> Any:
with open(file_path, "r", encoding="utf-8") as handle:
return yaml.safe_load(handle) or {}
def _read_json(self, file_path: str) -> Any:
with open(file_path, "r", encoding="utf-8") as handle:
return json.load(handle)
def _merge_nested_entries(
self, collected: dict[str, list[str]], payload: Any
) -> None:
for key, values in self._flatten_payload(payload):
collected[key] = values
def _flatten_payload(
self, payload: Any, prefix: str = ""
) -> list[tuple[str, list[str]]]:
entries: list[tuple[str, list[str]]] = []
if isinstance(payload, dict):
for key, value in payload.items():
next_prefix = f"{prefix}/{key}" if prefix else str(key)
entries.extend(self._flatten_payload(value, next_prefix))
return entries
if isinstance(payload, list):
normalized_prefix = _normalize_wildcard_key(prefix)
values = [value.strip() for value in payload if isinstance(value, str) and value.strip()]
if normalized_prefix and values:
entries.append((normalized_prefix, values))
return entries
return entries
def _score_entry(
self, key: str, normalized_term: str, compact_term: str
) -> int | None:
key_compact = key.replace("/", "")
if key == normalized_term:
return 5000
if key.startswith(normalized_term):
return 4000
if f"/{normalized_term}" in key:
return 3500
if normalized_term in key:
return 3000
if compact_term and key_compact.startswith(compact_term):
return 2500
if compact_term and compact_term in key_compact:
return 2000
return None
def _expand_options_only(self, text: str, rng: random.Random) -> str:
current = text
remaining_depth = 100
while remaining_depth > 0:
remaining_depth -= 1
current, replaced = self._replace_options(current, rng)
if not replaced:
break
return current
def _replace_options(
self, text: str, rng: random.Random
) -> tuple[str, bool]:
replaced_any = False
def replace_option(match: re.Match[str]) -> str:
nonlocal replaced_any
replacement = self._resolve_option_group(match.group(1), rng)
replaced_any = True
return replacement
return _OPTION_PATTERN.sub(replace_option, text), replaced_any
def _resolve_option_group(self, group_text: str, rng: random.Random) -> str:
options = group_text.split("|")
multi_select_pattern = options[0].split("$$")
select_range: tuple[int, int] | None = None
select_separator = " "
if len(multi_select_pattern) > 1:
count_spec = multi_select_pattern[0]
range_match = re.match(r"(\d+)(-(\d+))?$", count_spec)
shorthand_match = re.match(r"-(\d+)$", count_spec)
if range_match:
start_text = range_match.group(1)
end_text = range_match.group(3)
if end_text is not None and _is_numeric_string(start_text) and _is_numeric_string(end_text):
select_range = (int(start_text), int(end_text))
elif _is_numeric_string(start_text):
value = int(start_text)
select_range = (value, value)
elif shorthand_match:
end_text = shorthand_match.group(1)
if _is_numeric_string(end_text):
select_range = (1, int(end_text))
if select_range is not None and len(multi_select_pattern) == 2:
options[0] = multi_select_pattern[1]
elif select_range is not None and len(multi_select_pattern) >= 3:
select_separator = multi_select_pattern[1]
options[0] = multi_select_pattern[2]
weighted_options: list[tuple[float, str]] = []
for option in options:
weight = 1.0
parts = option.split("::", 1)
if len(parts) == 2 and _is_numeric_string(parts[0].strip()):
weight = float(parts[0].strip())
weighted_options.append((weight, option))
if select_range is None:
selection_count = 1
else:
selection_count = rng.randint(select_range[0], select_range[1])
if selection_count <= 1:
return self._strip_weight_prefix(self._weighted_choice(weighted_options, rng))
selection_count = min(selection_count, len(weighted_options))
selected: list[str] = []
used_indexes: set[int] = set()
while len(selected) < selection_count:
picked_index = self._weighted_choice_index(weighted_options, rng)
if picked_index in used_indexes:
if len(used_indexes) == len(weighted_options):
break
continue
used_indexes.add(picked_index)
selected.append(
self._strip_weight_prefix(weighted_options[picked_index][1])
)
return select_separator.join(selected)
def _weighted_choice(
self, weighted_options: list[tuple[float, str]], rng: random.Random
) -> str:
return weighted_options[self._weighted_choice_index(weighted_options, rng)][1]
def _weighted_choice_index(
self, weighted_options: list[tuple[float, str]], rng: random.Random
) -> int:
total_weight = sum(max(weight, 0.0) for weight, _value in weighted_options)
if total_weight <= 0:
return rng.randrange(len(weighted_options))
threshold = rng.uniform(0, total_weight)
cumulative = 0.0
for index, (weight, _value) in enumerate(weighted_options):
cumulative += max(weight, 0.0)
if threshold <= cumulative:
return index
return len(weighted_options) - 1
def _strip_weight_prefix(self, value: str) -> str:
return _WEIGHTED_OPTION_PATTERN.sub("", value, count=1)
def _replace_wildcards(
self,
text: str,
rng: random.Random,
wildcard_dict: dict[str, list[str]],
) -> tuple[str, bool]:
replaced_any = False
def replace_match(match: re.Match[str]) -> str:
nonlocal replaced_any
replacement = self._resolve_wildcard_match(match.group(1), rng, wildcard_dict)
if replacement is None:
return match.group(0)
replaced_any = True
return replacement
return _WILDCARD_PATTERN.sub(replace_match, text), replaced_any
def _resolve_wildcard_match(
self,
raw_key: str,
rng: random.Random,
wildcard_dict: dict[str, list[str]],
) -> str | None:
keyword = _normalize_wildcard_key(raw_key)
if keyword in wildcard_dict:
return rng.choice(wildcard_dict[keyword])
if "*" in keyword:
regex_pattern = keyword.replace("*", ".*").replace("+", r"\+")
compiled = re.compile(f"^{regex_pattern}$")
aggregated: list[str] = []
for key, values in wildcard_dict.items():
if compiled.match(key):
aggregated.extend(values)
if aggregated:
return rng.choice(aggregated)
if "/" not in keyword:
fallback_keyword = _normalize_wildcard_key(f"*/{keyword}")
if fallback_keyword != keyword:
return self._resolve_wildcard_match(fallback_keyword, rng, wildcard_dict)
return None
def is_trigger_words_input(name: str) -> bool:
return bool(_TRIGGER_WORD_PATTERN.match(name))
def get_wildcard_service() -> WildcardService:
return WildcardService.get_instance()
__all__ = [
"WildcardService",
"get_wildcard_service",
"get_wildcards_dir",
"is_trigger_words_input",
]

View File

@@ -1073,6 +1073,66 @@ describe('AutoComplete widget interactions', () => {
expect(fetchApiMock).toHaveBeenCalledWith('/lm/custom-words/search?enriched=true&search=cat&limit=100'); expect(fetchApiMock).toHaveBeenCalledWith('/lm/custom-words/search?enriched=true&search=cat&limit=100');
}); });
it('searches wildcard keys when using the /wild command', async () => {
vi.useFakeTimers();
fetchApiMock.mockResolvedValue({
json: () => Promise.resolve({ success: true, words: ['animals/cat'] }),
});
caretHelperInstance.getBeforeCursor.mockReturnValue('/wild cat');
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
const input = document.createElement('textarea');
input.value = '/wild cat';
input.selectionStart = input.value.length;
document.body.append(input);
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
const autoComplete = new AutoComplete(input,'prompt', {
debounceDelay: 0,
showPreview: false,
minChars: 1,
});
input.dispatchEvent(new Event('input', { bubbles: true }));
await vi.runAllTimersAsync();
await Promise.resolve();
expect(fetchApiMock).toHaveBeenCalledWith('/lm/wildcards/search?search=cat&limit=100');
expect(autoComplete.searchType).toBe('wildcards');
expect(autoComplete.items).toEqual(['animals/cat']);
});
it('inserts wildcard references when accepting a /wild result', async () => {
caretHelperInstance.getBeforeCursor.mockReturnValue('/wild animals/cat');
const input = document.createElement('textarea');
input.value = '/wild animals/cat';
input.selectionStart = input.value.length;
input.focus = vi.fn();
input.setSelectionRange = vi.fn();
document.body.append(input);
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
const autoComplete = new AutoComplete(input,'prompt', {
debounceDelay: 0,
showPreview: false,
minChars: 1,
});
autoComplete.searchType = 'wildcards';
autoComplete.activeCommand = { type: 'wildcard', label: 'Wildcards' };
autoComplete.items = ['animals/cat'];
autoComplete.selectedIndex = 0;
await autoComplete.insertSelection('animals/cat');
expect(input.value).toBe('__animals/cat__,');
expect(input.focus).toHaveBeenCalled();
expect(input.setSelectionRange).toHaveBeenCalled();
});
it('invalidates stale autocomplete metadata and falls back to delimiter-based matching', async () => { it('invalidates stale autocomplete metadata and falls back to delimiter-based matching', async () => {
settingGetMock.mockImplementation((key) => { settingGetMock.mockImplementation((key) => {
if (key === 'loramanager.autocomplete_append_comma') { if (key === 'loramanager.autocomplete_append_comma') {

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
from py.nodes.prompt import PromptLM
from py.nodes.text import TextLM
def test_text_lm_expands_wildcards_before_output(monkeypatch):
node = TextLM()
expand_calls = []
class StubService:
def expand_text(self, text, seed=None):
expand_calls.append((text, seed))
return "expanded text"
monkeypatch.setattr("py.nodes.text.get_wildcard_service", lambda: StubService())
assert node.process("__flower__", seed=9) == ("expanded text",)
assert expand_calls == [("__flower__", 9)]
def test_prompt_lm_expands_before_appending_trigger_words(monkeypatch):
node = PromptLM()
class StubService:
def expand_text(self, text, seed=None):
assert text == "__flower__"
assert seed == 42
return "rose"
class StubEncoder:
def encode(self, clip, prompt):
assert clip == "clip"
assert prompt == "artist style, rose"
return ("conditioning",)
monkeypatch.setattr("py.nodes.prompt.get_wildcard_service", lambda: StubService())
monkeypatch.setattr("nodes.CLIPTextEncode", lambda: StubEncoder(), raising=False)
result = node.encode("__flower__", "clip", seed=42, trigger_words1="artist style")
assert result == ("conditioning", "artist style, rose")
def test_prompt_lm_input_types_expose_input_only_seed():
input_types = PromptLM.INPUT_TYPES()
seed_type, seed_options = input_types["optional"]["seed"]
assert seed_type == "INT"
assert seed_options["forceInput"] is True
assert "wildcard generation" in seed_options["tooltip"]
def test_text_lm_input_types_expose_input_only_seed():
input_types = TextLM.INPUT_TYPES()
seed_type, seed_options = input_types["optional"]["seed"]
assert seed_type == "INT"
assert seed_options["forceInput"] is True
assert "wildcard generation" in seed_options["tooltip"]

View File

@@ -0,0 +1,45 @@
from __future__ import annotations
import json
import pytest
from py.routes.handlers.misc_handlers import WildcardsHandler
class FakeRequest:
def __init__(self, query=None):
self.query = query or {}
@pytest.mark.asyncio
async def test_search_wildcards_returns_results():
class StubService:
def search_keys(self, search_term, limit, offset):
assert search_term == "cat"
assert limit == 25
assert offset == 2
return ["animals/cat"]
handler = WildcardsHandler(service=StubService())
response = await handler.search_wildcards(
FakeRequest(query={"search": "cat", "limit": "25", "offset": "2"})
)
payload = json.loads(response.text)
assert response.status == 200
assert payload == {"success": True, "words": ["animals/cat"]}
@pytest.mark.asyncio
async def test_search_wildcards_handles_errors():
class StubService:
def search_keys(self, search_term, limit, offset):
raise RuntimeError("boom")
handler = WildcardsHandler(service=StubService())
response = await handler.search_wildcards(FakeRequest(query={"search": "cat"}))
payload = json.loads(response.text)
assert response.status == 500
assert payload["error"] == "boom"

View File

@@ -0,0 +1,123 @@
from __future__ import annotations
import json
from py.services.wildcard_service import WildcardService
def _make_service(monkeypatch, tmp_path):
settings_dir = tmp_path / "settings"
settings_dir.mkdir()
monkeypatch.setattr(
"py.services.wildcard_service.get_settings_dir",
lambda create=True: str(settings_dir),
)
service = WildcardService()
service._cached_signature = None
service._wildcard_dict = {}
return service, settings_dir / "wildcards"
def test_search_keys_returns_empty_when_directory_missing(monkeypatch, tmp_path):
service, _wildcards_dir = _make_service(monkeypatch, tmp_path)
assert service.search_keys("cat") == []
def test_search_keys_loads_txt_yaml_and_json(monkeypatch, tmp_path):
service, wildcards_dir = _make_service(monkeypatch, tmp_path)
wildcards_dir.mkdir()
(wildcards_dir / "animals").mkdir()
(wildcards_dir / "animals" / "cat.txt").write_text("tabby\nblack cat\n", encoding="utf-8")
(wildcards_dir / "colors.yaml").write_text(
"palette:\n warm:\n - red\n - orange\n",
encoding="utf-8",
)
(wildcards_dir / "artists.json").write_text(
json.dumps({"illustrators/digital": ["alice", "bob"]}),
encoding="utf-8",
)
assert service.search_keys("cat") == ["animals/cat"]
assert service.search_keys("warm") == ["palette/warm"]
assert service.search_keys("digital") == ["illustrators/digital"]
def test_search_keys_prefers_exact_and_prefix_matches(monkeypatch, tmp_path):
service, wildcards_dir = _make_service(monkeypatch, tmp_path)
wildcards_dir.mkdir()
(wildcards_dir / "animals").mkdir()
(wildcards_dir / "animals" / "cat.txt").write_text("tabby\n", encoding="utf-8")
(wildcards_dir / "animals" / "catgirl.txt").write_text("heroine\n", encoding="utf-8")
(wildcards_dir / "fantasy_cat.txt").write_text("beast\n", encoding="utf-8")
results = service.search_keys("cat")
assert results == ["animals/cat", "animals/catgirl", "fantasy_cat"]
def test_search_keys_supports_offset_and_limit(monkeypatch, tmp_path):
service, wildcards_dir = _make_service(monkeypatch, tmp_path)
wildcards_dir.mkdir()
for name in ("cat", "catgirl", "catmaid"):
(wildcards_dir / f"{name}.txt").write_text("x\n", encoding="utf-8")
assert service.search_keys("cat", limit=1, offset=1) == ["catgirl"]
def test_expand_text_resolves_nested_wildcards(monkeypatch, tmp_path):
service, wildcards_dir = _make_service(monkeypatch, tmp_path)
wildcards_dir.mkdir()
(wildcards_dir / "flower.txt").write_text("rose\n__color__ lily\n", encoding="utf-8")
(wildcards_dir / "color.txt").write_text("red\nblue\n", encoding="utf-8")
expanded = service.expand_text("__flower__", seed=7)
assert expanded in {"rose", "red lily", "blue lily"}
assert "__" not in expanded
def test_expand_text_resolves_dynamic_prompt_and_multi_select(monkeypatch, tmp_path):
service, wildcards_dir = _make_service(monkeypatch, tmp_path)
wildcards_dir.mkdir()
expanded = service.expand_text("{2$$, $$red|blue|green}", seed=3)
assert expanded.count(", ") == 1
assert set(expanded.split(", ")).issubset({"red", "blue", "green"})
def test_expand_text_resolves_wildcard_glob(monkeypatch, tmp_path):
service, wildcards_dir = _make_service(monkeypatch, tmp_path)
wildcards_dir.mkdir()
(wildcards_dir / "animals").mkdir()
(wildcards_dir / "animals" / "cat.txt").write_text("tabby\n", encoding="utf-8")
(wildcards_dir / "animals" / "dog.txt").write_text("retriever\n", encoding="utf-8")
expanded = service.expand_text("__animals/*__", seed=1)
assert expanded in {"tabby", "retriever"}
def test_expand_text_is_deterministic_with_seed(monkeypatch, tmp_path):
service, wildcards_dir = _make_service(monkeypatch, tmp_path)
wildcards_dir.mkdir()
(wildcards_dir / "color.txt").write_text("red\nblue\ngreen\n", encoding="utf-8")
first = service.expand_text("__color__", seed=123)
second = service.expand_text("__color__", seed=123)
assert first == second
def test_expand_text_leaves_unresolved_reference_visible(monkeypatch, tmp_path):
service, wildcards_dir = _make_service(monkeypatch, tmp_path)
wildcards_dir.mkdir()
assert service.expand_text("__missing__", seed=1) == "__missing__"

View File

@@ -425,7 +425,7 @@ function shouldBypassAutocompleteWidgetMigration(
} }
const originalWidgetsInputs = Object.values(inputDefs).filter((input: any) => const originalWidgetsInputs = Object.values(inputDefs).filter((input: any) =>
widgetNames.has(input.name) || input.forceInput widgetNames.has(input.name)
) )
const widgetIndexHasForceInput = originalWidgetsInputs.flatMap((input: any) => const widgetIndexHasForceInput = originalWidgetsInputs.flatMap((input: any) =>

View File

@@ -1,6 +1,12 @@
import { api } from "../../scripts/api.js"; import { api } from "../../scripts/api.js";
import { app } from "../../scripts/app.js"; import { app } from "../../scripts/app.js";
import { TextAreaCaretHelper } from "./textarea_caret_helper.js"; import { TextAreaCaretHelper } from "./textarea_caret_helper.js";
import {
WILDCARD_COMMANDS,
getWildcardInsertText,
getWildcardSearchEndpoint,
isWildcardCommand,
} from "./autocomplete_wildcards.js";
import { import {
getAutocompleteAppendCommaPreference, getAutocompleteAppendCommaPreference,
getAutocompleteAutoFormatPreference, getAutocompleteAutoFormatPreference,
@@ -22,6 +28,7 @@ const TAG_COMMANDS = {
'/lore': { categories: [15], label: 'Lore' }, '/lore': { categories: [15], label: 'Lore' },
'/emb': { type: 'embedding', label: 'Embeddings' }, '/emb': { type: 'embedding', label: 'Embeddings' },
'/embedding': { type: 'embedding', label: 'Embeddings' }, '/embedding': { type: 'embedding', label: 'Embeddings' },
...WILDCARD_COMMANDS,
// Autocomplete toggle commands - only show one based on current state // Autocomplete toggle commands - only show one based on current state
'/ac': { '/ac': {
type: 'toggle_setting', type: 'toggle_setting',
@@ -314,6 +321,8 @@ const MODEL_BEHAVIORS = {
const trimmedName = removeGeneralExtension(fileName); const trimmedName = removeGeneralExtension(fileName);
const folder = directories.length ? `${directories.join('/')}/` : ''; const folder = directories.length ? `${directories.join('/')}/` : '';
return formatAutocompleteInsertion(`embedding:${folder}${trimmedName}`); return formatAutocompleteInsertion(`embedding:${folder}${trimmedName}`);
} else if (instance.searchType === 'wildcards' || isWildcardCommand(instance.activeCommand)) {
return formatAutocompleteInsertion(getWildcardInsertText(relativePath));
} else { } else {
let tagText = relativePath; let tagText = relativePath;
@@ -658,6 +667,9 @@ class AutoComplete {
// /emb or /embedding command // /emb or /embedding command
endpoint = '/lm/embeddings/relative-paths'; endpoint = '/lm/embeddings/relative-paths';
this.searchType = 'embeddings'; this.searchType = 'embeddings';
} else if (isWildcardCommand(commandResult.command)) {
endpoint = getWildcardSearchEndpoint();
this.searchType = 'wildcards';
} else { } else {
// Category filter command // Category filter command
const categories = commandResult.command.categories.join(','); const categories = commandResult.command.categories.join(',');
@@ -1611,6 +1623,8 @@ class AutoComplete {
if (this.modelType === 'prompt') { if (this.modelType === 'prompt') {
if (this.searchType === 'embeddings') { if (this.searchType === 'embeddings') {
endpoint = '/lm/embeddings/relative-paths'; endpoint = '/lm/embeddings/relative-paths';
} else if (this.searchType === 'wildcards') {
endpoint = getWildcardSearchEndpoint();
} else if (this.searchType === 'custom_words') { } else if (this.searchType === 'custom_words') {
if (this.activeCommand?.categories) { if (this.activeCommand?.categories) {
const categories = this.activeCommand.categories.join(','); const categories = this.activeCommand.categories.join(',');

View File

@@ -0,0 +1,20 @@
export const WILDCARD_COMMANDS = {
'/wild': { type: 'wildcard', label: 'Wildcards' },
'/wildcard': { type: 'wildcard', label: 'Wildcards' },
};
export function isWildcardCommand(command) {
return command?.type === 'wildcard';
}
export function getWildcardSearchEndpoint() {
return '/lm/wildcards/search';
}
export function getWildcardInsertText(relativePath = '') {
const trimmed = typeof relativePath === 'string' ? relativePath.trim() : '';
if (!trimmed) {
return '';
}
return `__${trimmed}__`;
}

View File

@@ -15557,7 +15557,7 @@ function shouldBypassAutocompleteWidgetMigration(node, widgetValues) {
return false; return false;
} }
const originalWidgetsInputs = Object.values(inputDefs).filter( const originalWidgetsInputs = Object.values(inputDefs).filter(
(input) => widgetNames.has(input.name) || input.forceInput (input) => widgetNames.has(input.name)
); );
const widgetIndexHasForceInput = originalWidgetsInputs.flatMap( const widgetIndexHasForceInput = originalWidgetsInputs.flatMap(
(input) => input.control_after_generate ? [!!input.forceInput, false] : [!!input.forceInput] (input) => input.control_after_generate ? [!!input.forceInput, false] : [!!input.forceInput]

File diff suppressed because one or more lines are too long