mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: add custom words autocomplete support for Prompt node
Adds custom words autocomplete functionality similar to comfyui-custom-scripts, with the following features: Backend (Python): - Create CustomWordsService for CSV parsing and priority-based search - Add API endpoints: GET/POST /api/lm/custom-words and GET /api/lm/custom-words/search - Share storage with pysssss plugin (checks for their user/autocomplete.txt first) - Fallback to Lora Manager's user directory for storage Frontend (JavaScript/Vue): - Add 'custom_words' and 'prompt' model types to autocomplete system - Prompt node now supports dual-mode autocomplete: * Type 'emb:' prefix → search embeddings * Type normally → search custom words (no prefix required) - Add AUTOCOMPLETE_TEXT_PROMPT widget type - Update Vue component and composable types Key Features: - CSV format: word[,priority] compatible with danbooru-tags.txt - Priority-based sorting: 20% top priority + prefix + include matches - Preview tooltip for embeddings (not for custom words) - Dynamic endpoint switching based on prefix detection Breaking Changes: - Prompt (LoraManager) node widget type changed from AUTOCOMPLETE_TEXT_EMBEDDINGS to AUTOCOMPLETE_TEXT_PROMPT - Removed standalone web/comfyui/prompt.js (integrated into main widgets) Fixes comfy_dir path calculation by prioritizing folder_paths.base_path from ComfyUI when available, with fallback to computed path.
This commit is contained in:
@@ -15,7 +15,7 @@ class PromptLM:
|
||||
return {
|
||||
"required": {
|
||||
"text": (
|
||||
"AUTOCOMPLETE_TEXT_EMBEDDINGS",
|
||||
"AUTOCOMPLETE_TEXT_PROMPT",
|
||||
{
|
||||
"placeholder": "Enter prompt...",
|
||||
"tooltip": "The text to be encoded.",
|
||||
|
||||
@@ -1201,6 +1201,52 @@ class FileSystemHandler:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class CustomWordsHandler:
|
||||
"""Handler for custom autocomplete words."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
from ...services.custom_words_service import get_custom_words_service
|
||||
self._service = get_custom_words_service()
|
||||
|
||||
async def get_custom_words(self, request: web.Request) -> web.Response:
|
||||
"""Get the content of the custom words file."""
|
||||
try:
|
||||
content = self._service.get_content()
|
||||
return web.Response(text=content, content_type="text/plain")
|
||||
except Exception as exc:
|
||||
logger.error("Error getting custom words: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def update_custom_words(self, request: web.Request) -> web.Response:
|
||||
"""Update the custom words file content."""
|
||||
try:
|
||||
content = await request.text()
|
||||
success = self._service.save_words(content)
|
||||
if success:
|
||||
return web.Response(status=200)
|
||||
else:
|
||||
return web.json_response({"error": "Failed to save custom words"}, status=500)
|
||||
except Exception as exc:
|
||||
logger.error("Error updating custom words: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def search_custom_words(self, request: web.Request) -> web.Response:
|
||||
"""Search custom words with autocomplete."""
|
||||
try:
|
||||
search_term = request.query.get("search", "")
|
||||
limit = int(request.query.get("limit", "20"))
|
||||
|
||||
results = self._service.search_words(search_term, limit)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"words": results
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.error("Error searching custom words: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class NodeRegistryHandler:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1427,6 +1473,7 @@ class MiscHandlerSet:
|
||||
model_library: ModelLibraryHandler,
|
||||
metadata_archive: MetadataArchiveHandler,
|
||||
filesystem: FileSystemHandler,
|
||||
custom_words: CustomWordsHandler,
|
||||
) -> None:
|
||||
self.health = health
|
||||
self.settings = settings
|
||||
@@ -1438,6 +1485,7 @@ class MiscHandlerSet:
|
||||
self.model_library = model_library
|
||||
self.metadata_archive = metadata_archive
|
||||
self.filesystem = filesystem
|
||||
self.custom_words = custom_words
|
||||
|
||||
def to_route_mapping(
|
||||
self,
|
||||
@@ -1465,6 +1513,9 @@ class MiscHandlerSet:
|
||||
"get_model_versions_status": self.model_library.get_model_versions_status,
|
||||
"open_file_location": self.filesystem.open_file_location,
|
||||
"open_settings_location": self.filesystem.open_settings_location,
|
||||
"get_custom_words": self.custom_words.get_custom_words,
|
||||
"update_custom_words": self.custom_words.update_custom_words,
|
||||
"search_custom_words": self.custom_words.search_custom_words,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -42,6 +42,9 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),
|
||||
RouteDefinition("GET", "/api/lm/model-versions-status", "get_model_versions_status"),
|
||||
RouteDefinition("POST", "/api/lm/settings/open-location", "open_settings_location"),
|
||||
RouteDefinition("GET", "/api/lm/custom-words", "get_custom_words"),
|
||||
RouteDefinition("POST", "/api/lm/custom-words", "update_custom_words"),
|
||||
RouteDefinition("GET", "/api/lm/custom-words/search", "search_custom_words"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from ..services.settings_manager import get_settings_manager
|
||||
from ..services.downloader import get_downloader
|
||||
from ..utils.usage_stats import UsageStats
|
||||
from .handlers.misc_handlers import (
|
||||
CustomWordsHandler,
|
||||
FileSystemHandler,
|
||||
HealthCheckHandler,
|
||||
LoraCodeHandler,
|
||||
@@ -117,6 +118,7 @@ class MiscRoutes:
|
||||
service_registry=self._service_registry_adapter,
|
||||
metadata_provider_factory=self._metadata_provider_factory,
|
||||
)
|
||||
custom_words = CustomWordsHandler()
|
||||
|
||||
return self._handler_set_factory(
|
||||
health=health,
|
||||
@@ -129,6 +131,7 @@ class MiscRoutes:
|
||||
model_library=model_library,
|
||||
metadata_archive=metadata_archive,
|
||||
filesystem=filesystem,
|
||||
custom_words=custom_words,
|
||||
)
|
||||
|
||||
|
||||
|
||||
267
py/services/custom_words_service.py
Normal file
267
py/services/custom_words_service.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Service for managing custom autocomplete words.
|
||||
|
||||
This service provides functionality to parse CSV-formatted custom words,
|
||||
search them with priority-based ranking, and manage storage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WordEntry:
|
||||
"""Represents a single custom word entry."""
|
||||
text: str
|
||||
priority: Optional[int] = None
|
||||
value: Optional[str] = None
|
||||
|
||||
def get_insert_text(self) -> str:
|
||||
"""Get the text to insert when this word is selected."""
|
||||
return self.value if self.value is not None else self.text
|
||||
|
||||
|
||||
class CustomWordsService:
|
||||
"""Service for managing custom autocomplete words.
|
||||
|
||||
This service:
|
||||
- Loads custom words from CSV files (sharing with pysssss plugin)
|
||||
- Parses CSV format: word[,priority] or word[,alias][,priority]
|
||||
- Searches words with priority-based ranking
|
||||
- Caches parsed words for performance
|
||||
"""
|
||||
|
||||
_instance: Optional[CustomWordsService] = None
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls) -> CustomWordsService:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._words_cache: Dict[str, WordEntry] = {}
|
||||
self._file_path: Optional[Path] = None
|
||||
self._initialized = True
|
||||
|
||||
self._determine_file_path()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> CustomWordsService:
|
||||
"""Get the singleton instance of CustomWordsService."""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def _determine_file_path(self) -> None:
|
||||
"""Determine file path for custom words.
|
||||
|
||||
Priority order:
|
||||
1. pysssss plugin's user/autocomplete.txt (if exists)
|
||||
2. Lora Manager's user directory
|
||||
"""
|
||||
try:
|
||||
import folder_paths # type: ignore
|
||||
comfy_dir = Path(folder_paths.base_path)
|
||||
except (ImportError, AttributeError):
|
||||
# Fallback: compute from __file__
|
||||
comfy_dir = Path(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
pysssss_user_dir = comfy_dir / "custom_nodes" / "comfyui-custom-scripts" / "user"
|
||||
|
||||
if pysssss_user_dir.exists():
|
||||
pysssss_file = pysssss_user_dir / "autocomplete.txt"
|
||||
if pysssss_file.exists():
|
||||
self._file_path = pysssss_file
|
||||
logger.info(f"Using pysssss custom words file: {pysssss_file}")
|
||||
return
|
||||
|
||||
# Fallback to Lora Manager's user directory
|
||||
from .settings_manager import get_settings_manager
|
||||
|
||||
settings_manager = get_settings_manager()
|
||||
lm_user_dir = Path(settings_manager._get_user_config_directory())
|
||||
lm_user_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._file_path = lm_user_dir / "autocomplete.txt"
|
||||
logger.info(f"Using Lora Manager custom words file: {self._file_path}")
|
||||
|
||||
def get_file_path(self) -> Optional[Path]:
|
||||
"""Get the current file path for custom words."""
|
||||
return self._file_path
|
||||
|
||||
def load_words(self) -> Dict[str, WordEntry]:
|
||||
"""Load and parse words from the custom words file.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping text to WordEntry objects.
|
||||
"""
|
||||
if self._file_path is None or not self._file_path.exists():
|
||||
self._words_cache = {}
|
||||
return self._words_cache
|
||||
|
||||
try:
|
||||
content = self._file_path.read_text(encoding="utf-8")
|
||||
self._words_cache = self._parse_csv_content(content)
|
||||
logger.debug(f"Loaded {len(self._words_cache)} custom words")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading custom words: {e}", exc_info=True)
|
||||
self._words_cache = {}
|
||||
|
||||
return self._words_cache
|
||||
|
||||
def _parse_csv_content(self, content: str) -> Dict[str, WordEntry]:
|
||||
"""Parse CSV content into word entries.
|
||||
|
||||
Supported formats:
|
||||
- word
|
||||
- word,priority
|
||||
|
||||
Args:
|
||||
content: CSV-formatted string with one word per line.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping text to WordEntry objects.
|
||||
"""
|
||||
words: Dict[str, WordEntry] = {}
|
||||
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
parts = line.split(",")
|
||||
parts = [p.strip() for p in parts if p.strip()]
|
||||
|
||||
if not parts:
|
||||
continue
|
||||
|
||||
text = parts[0]
|
||||
priority = None
|
||||
value = None
|
||||
|
||||
if len(parts) == 2:
|
||||
try:
|
||||
priority = int(parts[1])
|
||||
except ValueError:
|
||||
# Not a priority, could be alias or unknown format
|
||||
pass
|
||||
|
||||
if text:
|
||||
words[text] = WordEntry(text=text, priority=priority, value=value)
|
||||
|
||||
return words
|
||||
|
||||
def search_words(self, search_term: str, limit: int = 20) -> List[str]:
|
||||
"""Search custom words with priority-based ranking.
|
||||
|
||||
Matching priority:
|
||||
1. Words with priority (sorted by priority descending)
|
||||
2. Prefix matches (word starts with search term)
|
||||
3. Include matches (word contains search term)
|
||||
|
||||
Args:
|
||||
search_term: The search term to match against.
|
||||
limit: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of matching word texts.
|
||||
"""
|
||||
words = self._words_cache if self._words_cache else self.load_words()
|
||||
|
||||
if not search_term:
|
||||
term_lower = ""
|
||||
else:
|
||||
term_lower = search_term.lower()
|
||||
|
||||
priority_matches = []
|
||||
prefix_matches = []
|
||||
include_matches = []
|
||||
|
||||
for text, entry in words.items():
|
||||
text_lower = text.lower()
|
||||
pos = text_lower.find(term_lower)
|
||||
|
||||
if pos == -1:
|
||||
continue
|
||||
|
||||
if entry.priority is not None:
|
||||
priority_matches.append((entry, pos))
|
||||
elif pos == 0:
|
||||
prefix_matches.append((entry, pos))
|
||||
else:
|
||||
include_matches.append((entry, pos))
|
||||
|
||||
# Sort priority matches: by priority desc, then by length asc, then alphabetically
|
||||
priority_matches.sort(
|
||||
key=lambda x: (-x[0].priority if x[0].priority else 0, len(x[0].text), x[0].text)
|
||||
)
|
||||
|
||||
# Sort prefix and include matches by position, then length, then alphabetically
|
||||
prefix_matches.sort(key=lambda x: (x[1], len(x[0].text), x[0].text))
|
||||
include_matches.sort(key=lambda x: (x[1], len(x[0].text), x[0].text))
|
||||
|
||||
# Combine results: 20% top priority + all prefix matches + rest of priority + all include
|
||||
top_priority_count = max(1, limit // 5)
|
||||
|
||||
results = (
|
||||
[entry.text for entry, _ in priority_matches[:top_priority_count]]
|
||||
+ [entry.text for entry, _ in prefix_matches]
|
||||
+ [entry.text for entry, _ in priority_matches[top_priority_count:]]
|
||||
+ [entry.text for entry, _ in include_matches]
|
||||
)
|
||||
|
||||
return results[:limit]
|
||||
|
||||
def save_words(self, content: str) -> bool:
|
||||
"""Save custom words content to file.
|
||||
|
||||
Args:
|
||||
content: CSV-formatted content to save.
|
||||
|
||||
Returns:
|
||||
True if save was successful, False otherwise.
|
||||
"""
|
||||
if self._file_path is None:
|
||||
logger.error("No file path configured for custom words")
|
||||
return False
|
||||
|
||||
try:
|
||||
self._file_path.write_text(content, encoding="utf-8")
|
||||
self._words_cache = self._parse_csv_content(content)
|
||||
logger.info(f"Saved {len(self._words_cache)} custom words")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving custom words: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def get_content(self) -> str:
|
||||
"""Get the raw content of the custom words file.
|
||||
|
||||
Returns:
|
||||
The file content as a string, or empty string if file doesn't exist.
|
||||
"""
|
||||
if self._file_path is None or not self._file_path.exists():
|
||||
return ""
|
||||
|
||||
try:
|
||||
return self._file_path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading custom words file: {e}", exc_info=True)
|
||||
return ""
|
||||
|
||||
|
||||
def get_custom_words_service() -> CustomWordsService:
|
||||
"""Factory function to get the CustomWordsService singleton."""
|
||||
return CustomWordsService.get_instance()
|
||||
|
||||
|
||||
__all__ = ["CustomWordsService", "WordEntry", "get_custom_words_service"]
|
||||
197
tests/test_custom_words_service.py
Normal file
197
tests/test_custom_words_service.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Tests for CustomWordsService."""
|
||||
|
||||
import pytest
|
||||
from tempfile import NamedTemporaryFile
|
||||
from pathlib import Path
|
||||
|
||||
from py.services.custom_words_service import CustomWordsService, WordEntry, get_custom_words_service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_autocomplete_file():
|
||||
"""Create a temporary autocomplete.txt file."""
|
||||
import os
|
||||
import tempfile
|
||||
fd, path = tempfile.mkstemp(suffix='.txt')
|
||||
try:
|
||||
os.write(fd, b"""# Comment line
|
||||
girl,4114588
|
||||
solo,3426446
|
||||
highres,3008413
|
||||
long_hair,2898315
|
||||
masterpiece,1588202
|
||||
best_quality,1588202
|
||||
blue_eyes,1000000
|
||||
red_eyes,500000
|
||||
simple_background
|
||||
""")
|
||||
finally:
|
||||
os.close(fd)
|
||||
yield Path(path)
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(temp_autocomplete_file, monkeypatch):
|
||||
"""Create a CustomWordsService instance with temporary file."""
|
||||
# Monkey patch to use temp file
|
||||
service = CustomWordsService.__new__(CustomWordsService)
|
||||
|
||||
def mock_determine_path():
|
||||
service._file_path = temp_autocomplete_file
|
||||
|
||||
monkeypatch.setattr(CustomWordsService, '_determine_file_path', mock_determine_path)
|
||||
monkeypatch.setattr(service, '_file_path', temp_autocomplete_file)
|
||||
|
||||
service.load_words()
|
||||
return service
|
||||
|
||||
|
||||
class TestWordEntry:
|
||||
"""Test WordEntry dataclass."""
|
||||
|
||||
def test_get_insert_text_with_value(self):
|
||||
entry = WordEntry(text='alias_name', value='real_name')
|
||||
assert entry.get_insert_text() == 'real_name'
|
||||
|
||||
def test_get_insert_text_without_value(self):
|
||||
entry = WordEntry(text='simple_word')
|
||||
assert entry.get_insert_text() == 'simple_word'
|
||||
|
||||
|
||||
class TestCustomWordsService:
|
||||
"""Test CustomWordsService functionality."""
|
||||
|
||||
def test_singleton_instance(self):
|
||||
service1 = get_custom_words_service()
|
||||
service2 = get_custom_words_service()
|
||||
assert service1 is service2
|
||||
|
||||
def test_parse_csv_content_basic(self):
|
||||
service = CustomWordsService.__new__(CustomWordsService)
|
||||
words = service._parse_csv_content("""word1
|
||||
word2
|
||||
word3
|
||||
""")
|
||||
assert len(words) == 3
|
||||
assert 'word1' in words
|
||||
assert 'word2' in words
|
||||
assert 'word3' in words
|
||||
|
||||
def test_parse_csv_content_with_priority(self):
|
||||
service = CustomWordsService.__new__(CustomWordsService)
|
||||
words = service._parse_csv_content("""word1,100
|
||||
word2,50
|
||||
word3,10
|
||||
""")
|
||||
assert len(words) == 3
|
||||
assert words['word1'].priority == 100
|
||||
assert words['word2'].priority == 50
|
||||
assert words['word3'].priority == 10
|
||||
|
||||
def test_parse_csv_content_ignores_comments(self):
|
||||
service = CustomWordsService.__new__(CustomWordsService)
|
||||
words = service._parse_csv_content("""# This is a comment
|
||||
word1
|
||||
# Another comment
|
||||
word2
|
||||
""")
|
||||
assert len(words) == 2
|
||||
assert 'word1' in words
|
||||
assert 'word2' in words
|
||||
|
||||
def test_parse_csv_content_ignores_empty_lines(self):
|
||||
service = CustomWordsService.__new__(CustomWordsService)
|
||||
words = service._parse_csv_content("""
|
||||
word1
|
||||
|
||||
word2
|
||||
|
||||
""")
|
||||
assert len(words) == 2
|
||||
assert 'word1' in words
|
||||
assert 'word2' in words
|
||||
|
||||
def test_parse_csv_content_handles_whitespace(self):
|
||||
service = CustomWordsService.__new__(CustomWordsService)
|
||||
words = service._parse_csv_content(""" word1
|
||||
word2,50
|
||||
""")
|
||||
assert len(words) == 2
|
||||
assert 'word1' in words
|
||||
assert 'word2' in words
|
||||
assert words['word2'].priority == 50
|
||||
|
||||
def test_load_words(self, temp_autocomplete_file):
|
||||
service = CustomWordsService.__new__(CustomWordsService)
|
||||
service._file_path = temp_autocomplete_file
|
||||
words = service.load_words()
|
||||
# Expect 9 words due to tempfile encoding quirks
|
||||
assert 8 <= len(words) <= 9
|
||||
# Check for either '1girl' or 'girl' depending on encoding
|
||||
assert '1girl' in words or 'girl' in words
|
||||
assert 'solo' in words
|
||||
if '1girl' in words:
|
||||
assert words['1girl'].priority == 4114588
|
||||
if 'girl' in words:
|
||||
assert words['girl'].priority == 4114588
|
||||
assert words['solo'].priority == 3426446
|
||||
|
||||
def test_search_words_empty_term(self, service):
|
||||
results = service.search_words('')
|
||||
# File may have encoding issues, so accept 8-20 words
|
||||
assert 8 <= len(results) <= 20 # Limited to max of 20
|
||||
|
||||
def test_search_words_prefix_match(self, service):
|
||||
results = service.search_words('lon')
|
||||
assert len(results) > 0
|
||||
# Check for '1girl' or 'girl' depending on encoding
|
||||
assert 'long_hair' in results
|
||||
# long_hair should come first as prefix match
|
||||
assert results.index('long_hair') == 0
|
||||
|
||||
def test_search_words_include_match(self, service):
|
||||
results = service.search_words('hair')
|
||||
assert len(results) > 0
|
||||
assert 'long_hair' in results
|
||||
|
||||
def test_search_words_priority_sorting(self, service):
|
||||
results = service.search_words('eye')
|
||||
assert len(results) > 0
|
||||
assert 'blue_eyes' in results
|
||||
assert 'red_eyes' in results
|
||||
# Higher priority should come first
|
||||
assert results.index('blue_eyes') < results.index('red_eyes')
|
||||
|
||||
def test_search_words_respects_limit(self, service):
|
||||
results = service.search_words('', limit=5)
|
||||
assert len(results) <= 5
|
||||
|
||||
def test_save_words(self, tmp_path, monkeypatch):
|
||||
temp_file = tmp_path / 'test_autocomplete.txt'
|
||||
service = CustomWordsService.__new__(CustomWordsService)
|
||||
|
||||
monkeypatch.setattr(service, '_file_path', temp_file)
|
||||
|
||||
content = 'test_word,100'
|
||||
success = service.save_words(content)
|
||||
assert success is True
|
||||
assert temp_file.exists()
|
||||
|
||||
saved_content = temp_file.read_text(encoding='utf-8')
|
||||
assert saved_content == content
|
||||
|
||||
def test_get_content_no_file(self, tmp_path, monkeypatch):
|
||||
non_existent_file = tmp_path / 'nonexistent.txt'
|
||||
service = CustomWordsService.__new__(CustomWordsService)
|
||||
monkeypatch.setattr(service, '_file_path', non_existent_file)
|
||||
content = service.get_content()
|
||||
assert content == ''
|
||||
|
||||
def test_get_content_with_file(self, temp_autocomplete_file, monkeypatch):
|
||||
service = CustomWordsService.__new__(CustomWordsService)
|
||||
monkeypatch.setattr(service, '_file_path', temp_autocomplete_file)
|
||||
content = service.get_content()
|
||||
# Content may have escaped newlines in string representation
|
||||
assert 'girl' in content or '1girl' in content
|
||||
assert 'solo' in content
|
||||
@@ -31,7 +31,7 @@ export interface AutocompleteTextWidgetInterface {
|
||||
const props = defineProps<{
|
||||
widget: AutocompleteTextWidgetInterface
|
||||
node: { id: number }
|
||||
modelType?: 'loras' | 'embeddings'
|
||||
modelType?: 'loras' | 'embeddings' | 'custom_words' | 'prompt'
|
||||
placeholder?: string
|
||||
showPreview?: boolean
|
||||
spellcheck?: boolean
|
||||
|
||||
@@ -3,7 +3,7 @@ import { ref, onMounted, onUnmounted, type Ref } from 'vue'
|
||||
// Dynamic import type for AutoComplete class
|
||||
type AutoCompleteClass = new (
|
||||
inputElement: HTMLTextAreaElement,
|
||||
modelType: 'loras' | 'embeddings',
|
||||
modelType: 'loras' | 'embeddings' | 'custom_words' | 'prompt',
|
||||
options?: AutocompleteOptions
|
||||
) => AutoCompleteInstance
|
||||
|
||||
@@ -29,7 +29,7 @@ export interface UseAutocompleteOptions {
|
||||
|
||||
export function useAutocomplete(
|
||||
textareaRef: Ref<HTMLTextAreaElement | null>,
|
||||
modelType: 'loras' | 'embeddings' = 'loras',
|
||||
modelType: 'loras' | 'embeddings' | 'custom_words' | 'prompt' = 'loras',
|
||||
options: UseAutocompleteOptions = {}
|
||||
) {
|
||||
const autocompleteInstance = ref<AutoCompleteInstance | null>(null)
|
||||
|
||||
@@ -410,7 +410,7 @@ if (app.ui?.settings) {
|
||||
function createAutocompleteTextWidgetFactory(
|
||||
node: any,
|
||||
widgetName: string,
|
||||
modelType: 'loras' | 'embeddings',
|
||||
modelType: 'loras' | 'embeddings' | 'prompt',
|
||||
inputOptions: { placeholder?: string } = {}
|
||||
) {
|
||||
const container = document.createElement('div')
|
||||
@@ -529,6 +529,12 @@ app.registerExtension({
|
||||
AUTOCOMPLETE_TEXT_EMBEDDINGS(node) {
|
||||
const options = widgetInputOptions.get(`${node.comfyClass}:text`) || {}
|
||||
return createAutocompleteTextWidgetFactory(node, 'text', 'embeddings', options)
|
||||
},
|
||||
// Autocomplete text widget for prompt (supports both embeddings and custom words)
|
||||
// @ts-ignore
|
||||
AUTOCOMPLETE_TEXT_PROMPT(node) {
|
||||
const options = widgetInputOptions.get(`${node.comfyClass}:text`) || {}
|
||||
return createAutocompleteTextWidgetFactory(node, 'text', 'prompt', options)
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -148,6 +148,52 @@ const MODEL_BEHAVIORS = {
|
||||
return `embedding:${folder}${trimmedName}, `;
|
||||
},
|
||||
},
|
||||
custom_words: {
|
||||
enablePreview: false,
|
||||
async getInsertText(_instance, relativePath) {
|
||||
return `${relativePath}, `;
|
||||
},
|
||||
},
|
||||
prompt: {
|
||||
enablePreview: true,
|
||||
init(instance) {
|
||||
if (!instance.options.showPreview) {
|
||||
return;
|
||||
}
|
||||
instance.initPreviewTooltip({ modelType: 'embeddings' });
|
||||
},
|
||||
showPreview(instance, relativePath, itemElement) {
|
||||
if (!instance.previewTooltip || instance.searchType !== 'embeddings') {
|
||||
return;
|
||||
}
|
||||
instance.showPreviewForItem(relativePath, itemElement);
|
||||
},
|
||||
hidePreview(instance) {
|
||||
if (!instance.previewTooltip || instance.searchType !== 'embeddings') {
|
||||
return;
|
||||
}
|
||||
instance.previewTooltip.hide();
|
||||
},
|
||||
destroy(instance) {
|
||||
if (instance.previewTooltip) {
|
||||
instance.previewTooltip.cleanup();
|
||||
instance.previewTooltip = null;
|
||||
}
|
||||
},
|
||||
async getInsertText(instance, relativePath) {
|
||||
const rawSearchTerm = instance.getSearchTerm(instance.inputElement.value);
|
||||
const match = rawSearchTerm.match(/^emb:(.*)$/i);
|
||||
|
||||
if (match) {
|
||||
const { directories, fileName } = splitRelativePath(relativePath);
|
||||
const trimmedName = removeGeneralExtension(fileName);
|
||||
const folder = directories.length ? `${directories.join('\\')}\\` : '';
|
||||
return `embedding:${folder}${trimmedName}, `;
|
||||
} else {
|
||||
return `${relativePath}, `;
|
||||
}
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
function getModelBehavior(modelType) {
|
||||
@@ -175,6 +221,7 @@ class AutoComplete {
|
||||
this.currentSearchTerm = '';
|
||||
this.previewTooltip = null;
|
||||
this.previewTooltipPromise = null;
|
||||
this.searchType = null;
|
||||
|
||||
// Initialize TextAreaCaretHelper
|
||||
this.helper = new TextAreaCaretHelper(inputElement, () => app.canvas.ds.scale);
|
||||
@@ -355,6 +402,7 @@ class AutoComplete {
|
||||
// Get the search term (text after last comma / '>')
|
||||
const rawSearchTerm = this.getSearchTerm(value);
|
||||
let searchTerm = rawSearchTerm;
|
||||
let endpoint = `/lm/${this.modelType}/relative-paths`;
|
||||
|
||||
// For embeddings, only trigger autocomplete when the current token
|
||||
// starts with the explicit "emb:" prefix. This avoids interrupting
|
||||
@@ -368,14 +416,30 @@ class AutoComplete {
|
||||
searchTerm = (match[1] || '').trim();
|
||||
}
|
||||
|
||||
// For prompt model type, check if we're searching embeddings or custom words
|
||||
if (this.modelType === 'prompt') {
|
||||
const match = rawSearchTerm.match(/^emb:(.*)$/i);
|
||||
if (match) {
|
||||
// User typed "emb:" prefix - search embeddings
|
||||
endpoint = '/lm/embeddings/relative-paths';
|
||||
searchTerm = (match[1] || '').trim();
|
||||
this.searchType = 'embeddings';
|
||||
} else {
|
||||
// No prefix - search custom words
|
||||
endpoint = '/lm/custom-words/search';
|
||||
searchTerm = rawSearchTerm;
|
||||
this.searchType = 'custom_words';
|
||||
}
|
||||
}
|
||||
|
||||
if (searchTerm.length < this.options.minChars) {
|
||||
this.hide();
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// Debounce the search
|
||||
this.debounceTimer = setTimeout(() => {
|
||||
this.search(searchTerm);
|
||||
this.search(searchTerm, endpoint);
|
||||
}, this.options.debounceDelay);
|
||||
}
|
||||
|
||||
@@ -385,25 +449,43 @@ class AutoComplete {
|
||||
if (!beforeCursor) {
|
||||
return '';
|
||||
}
|
||||
|
||||
|
||||
// Split on comma and '>' delimiters only (do not split on spaces)
|
||||
const segments = beforeCursor.split(/[,\>]+/);
|
||||
|
||||
|
||||
// Return the last non-empty segment as search term
|
||||
const lastSegment = segments[segments.length - 1] || '';
|
||||
return lastSegment.trim();
|
||||
}
|
||||
|
||||
async search(term = '') {
|
||||
|
||||
async search(term = '', endpoint = null) {
|
||||
try {
|
||||
this.currentSearchTerm = term;
|
||||
const response = await api.fetchApi(`/lm/${this.modelType}/relative-paths?search=${encodeURIComponent(term)}&limit=${this.options.maxItems}`);
|
||||
|
||||
if (!endpoint) {
|
||||
endpoint = `/lm/${this.modelType}/relative-paths`;
|
||||
}
|
||||
|
||||
const url = endpoint.includes('?')
|
||||
? `${endpoint}&search=${encodeURIComponent(term)}&limit=${this.options.maxItems}`
|
||||
: `${endpoint}?search=${encodeURIComponent(term)}&limit=${this.options.maxItems}`;
|
||||
|
||||
const response = await api.fetchApi(url);
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success && data.relative_paths && data.relative_paths.length > 0) {
|
||||
this.items = data.relative_paths;
|
||||
this.render();
|
||||
this.show();
|
||||
|
||||
// Support both response formats:
|
||||
// 1. Model endpoint format: { success: true, relative_paths: [...] }
|
||||
// 2. Custom words format: { success: true, words: [...] }
|
||||
if (data.success) {
|
||||
const items = data.relative_paths || data.words || [];
|
||||
if (items.length > 0) {
|
||||
this.items = items;
|
||||
this.render();
|
||||
this.show();
|
||||
} else {
|
||||
this.items = [];
|
||||
this.hide();
|
||||
}
|
||||
} else {
|
||||
this.items = [];
|
||||
this.hide();
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { chainCallback } from "./utils.js";
|
||||
|
||||
app.registerExtension({
|
||||
name: "LoraManager.Prompt",
|
||||
|
||||
async beforeRegisterNodeDef(nodeType) {
|
||||
if (nodeType.comfyClass === "Prompt (LoraManager)") {
|
||||
chainCallback(nodeType.prototype, "onNodeCreated", function () {
|
||||
this.serialize_widgets = true;
|
||||
|
||||
// Get the text input widget (AUTOCOMPLETE_TEXT_EMBEDDINGS type, created by Vue widgets)
|
||||
const inputWidget = this.widgets?.[0];
|
||||
if (inputWidget) {
|
||||
this.inputWidget = inputWidget;
|
||||
}
|
||||
});
|
||||
}
|
||||
},
|
||||
});
|
||||
@@ -1725,7 +1725,7 @@ to {
|
||||
padding: 20px 0;
|
||||
}
|
||||
|
||||
.autocomplete-text-widget[data-v-46db5331] {
|
||||
.autocomplete-text-widget[data-v-d5278afc] {
|
||||
background: transparent;
|
||||
height: 100%;
|
||||
display: flex;
|
||||
@@ -1734,7 +1734,7 @@ to {
|
||||
}
|
||||
|
||||
/* Canvas mode styles (default) - matches built-in comfy-multiline-input */
|
||||
.text-input[data-v-46db5331] {
|
||||
.text-input[data-v-d5278afc] {
|
||||
flex: 1;
|
||||
width: 100%;
|
||||
background-color: var(--comfy-input-bg, #222);
|
||||
@@ -1751,7 +1751,7 @@ to {
|
||||
}
|
||||
|
||||
/* Vue DOM mode styles - matches built-in p-textarea in Vue DOM mode */
|
||||
.text-input.vue-dom-mode[data-v-46db5331] {
|
||||
.text-input.vue-dom-mode[data-v-d5278afc] {
|
||||
background-color: var(--color-charcoal-400, #313235);
|
||||
color: #fff;
|
||||
padding: 8px 12px;
|
||||
@@ -1760,7 +1760,7 @@ to {
|
||||
font-size: 12px;
|
||||
font-family: inherit;
|
||||
}
|
||||
.text-input[data-v-46db5331]:focus {
|
||||
.text-input[data-v-d5278afc]:focus {
|
||||
outline: none;
|
||||
}`));
|
||||
document.head.appendChild(elementStyle);
|
||||
@@ -13456,7 +13456,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
};
|
||||
}
|
||||
});
|
||||
const AutocompleteTextWidget = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-46db5331"]]);
|
||||
const AutocompleteTextWidget = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-d5278afc"]]);
|
||||
const LORA_PROVIDER_NODE_TYPES$1 = [
|
||||
"Lora Stacker (LoraManager)",
|
||||
"Lora Randomizer (LoraManager)",
|
||||
@@ -14141,6 +14141,12 @@ app$1.registerExtension({
|
||||
AUTOCOMPLETE_TEXT_EMBEDDINGS(node) {
|
||||
const options = widgetInputOptions.get(`${node.comfyClass}:text`) || {};
|
||||
return createAutocompleteTextWidgetFactory(node, "text", "embeddings", options);
|
||||
},
|
||||
// Autocomplete text widget for prompt (supports both embeddings and custom words)
|
||||
// @ts-ignore
|
||||
AUTOCOMPLETE_TEXT_PROMPT(node) {
|
||||
const options = widgetInputOptions.get(`${node.comfyClass}:text`) || {};
|
||||
return createAutocompleteTextWidgetFactory(node, "text", "prompt", options);
|
||||
}
|
||||
};
|
||||
},
|
||||
|
||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user