mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: add custom words autocomplete support for Prompt node
Adds custom words autocomplete functionality similar to comfyui-custom-scripts, with the following features: Backend (Python): - Create CustomWordsService for CSV parsing and priority-based search - Add API endpoints: GET/POST /api/lm/custom-words and GET /api/lm/custom-words/search - Share storage with pysssss plugin (checks for their user/autocomplete.txt first) - Fallback to Lora Manager's user directory for storage Frontend (JavaScript/Vue): - Add 'custom_words' and 'prompt' model types to autocomplete system - Prompt node now supports dual-mode autocomplete: * Type 'emb:' prefix → search embeddings * Type normally → search custom words (no prefix required) - Add AUTOCOMPLETE_TEXT_PROMPT widget type - Update Vue component and composable types Key Features: - CSV format: word[,priority] compatible with danbooru-tags.txt - Priority-based sorting: 20% top priority + prefix + include matches - Preview tooltip for embeddings (not for custom words) - Dynamic endpoint switching based on prefix detection Breaking Changes: - Prompt (LoraManager) node widget type changed from AUTOCOMPLETE_TEXT_EMBEDDINGS to AUTOCOMPLETE_TEXT_PROMPT - Removed standalone web/comfyui/prompt.js (integrated into main widgets) Fixes comfy_dir path calculation by prioritizing folder_paths.base_path from ComfyUI when available, with fallback to computed path.
This commit is contained in:
@@ -15,7 +15,7 @@ class PromptLM:
|
|||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"text": (
|
"text": (
|
||||||
"AUTOCOMPLETE_TEXT_EMBEDDINGS",
|
"AUTOCOMPLETE_TEXT_PROMPT",
|
||||||
{
|
{
|
||||||
"placeholder": "Enter prompt...",
|
"placeholder": "Enter prompt...",
|
||||||
"tooltip": "The text to be encoded.",
|
"tooltip": "The text to be encoded.",
|
||||||
|
|||||||
@@ -1201,6 +1201,52 @@ class FileSystemHandler:
|
|||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomWordsHandler:
|
||||||
|
"""Handler for custom autocomplete words."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
from ...services.custom_words_service import get_custom_words_service
|
||||||
|
self._service = get_custom_words_service()
|
||||||
|
|
||||||
|
async def get_custom_words(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get the content of the custom words file."""
|
||||||
|
try:
|
||||||
|
content = self._service.get_content()
|
||||||
|
return web.Response(text=content, content_type="text/plain")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Error getting custom words: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def update_custom_words(self, request: web.Request) -> web.Response:
|
||||||
|
"""Update the custom words file content."""
|
||||||
|
try:
|
||||||
|
content = await request.text()
|
||||||
|
success = self._service.save_words(content)
|
||||||
|
if success:
|
||||||
|
return web.Response(status=200)
|
||||||
|
else:
|
||||||
|
return web.json_response({"error": "Failed to save custom words"}, status=500)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Error updating custom words: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def search_custom_words(self, request: web.Request) -> web.Response:
|
||||||
|
"""Search custom words with autocomplete."""
|
||||||
|
try:
|
||||||
|
search_term = request.query.get("search", "")
|
||||||
|
limit = int(request.query.get("limit", "20"))
|
||||||
|
|
||||||
|
results = self._service.search_words(search_term, limit)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
"success": True,
|
||||||
|
"words": results
|
||||||
|
})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Error searching custom words: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
|
||||||
class NodeRegistryHandler:
|
class NodeRegistryHandler:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -1427,6 +1473,7 @@ class MiscHandlerSet:
|
|||||||
model_library: ModelLibraryHandler,
|
model_library: ModelLibraryHandler,
|
||||||
metadata_archive: MetadataArchiveHandler,
|
metadata_archive: MetadataArchiveHandler,
|
||||||
filesystem: FileSystemHandler,
|
filesystem: FileSystemHandler,
|
||||||
|
custom_words: CustomWordsHandler,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.health = health
|
self.health = health
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
@@ -1438,6 +1485,7 @@ class MiscHandlerSet:
|
|||||||
self.model_library = model_library
|
self.model_library = model_library
|
||||||
self.metadata_archive = metadata_archive
|
self.metadata_archive = metadata_archive
|
||||||
self.filesystem = filesystem
|
self.filesystem = filesystem
|
||||||
|
self.custom_words = custom_words
|
||||||
|
|
||||||
def to_route_mapping(
|
def to_route_mapping(
|
||||||
self,
|
self,
|
||||||
@@ -1465,6 +1513,9 @@ class MiscHandlerSet:
|
|||||||
"get_model_versions_status": self.model_library.get_model_versions_status,
|
"get_model_versions_status": self.model_library.get_model_versions_status,
|
||||||
"open_file_location": self.filesystem.open_file_location,
|
"open_file_location": self.filesystem.open_file_location,
|
||||||
"open_settings_location": self.filesystem.open_settings_location,
|
"open_settings_location": self.filesystem.open_settings_location,
|
||||||
|
"get_custom_words": self.custom_words.get_custom_words,
|
||||||
|
"update_custom_words": self.custom_words.update_custom_words,
|
||||||
|
"search_custom_words": self.custom_words.search_custom_words,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -42,6 +42,9 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),
|
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),
|
||||||
RouteDefinition("GET", "/api/lm/model-versions-status", "get_model_versions_status"),
|
RouteDefinition("GET", "/api/lm/model-versions-status", "get_model_versions_status"),
|
||||||
RouteDefinition("POST", "/api/lm/settings/open-location", "open_settings_location"),
|
RouteDefinition("POST", "/api/lm/settings/open-location", "open_settings_location"),
|
||||||
|
RouteDefinition("GET", "/api/lm/custom-words", "get_custom_words"),
|
||||||
|
RouteDefinition("POST", "/api/lm/custom-words", "update_custom_words"),
|
||||||
|
RouteDefinition("GET", "/api/lm/custom-words/search", "search_custom_words"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from ..services.settings_manager import get_settings_manager
|
|||||||
from ..services.downloader import get_downloader
|
from ..services.downloader import get_downloader
|
||||||
from ..utils.usage_stats import UsageStats
|
from ..utils.usage_stats import UsageStats
|
||||||
from .handlers.misc_handlers import (
|
from .handlers.misc_handlers import (
|
||||||
|
CustomWordsHandler,
|
||||||
FileSystemHandler,
|
FileSystemHandler,
|
||||||
HealthCheckHandler,
|
HealthCheckHandler,
|
||||||
LoraCodeHandler,
|
LoraCodeHandler,
|
||||||
@@ -117,6 +118,7 @@ class MiscRoutes:
|
|||||||
service_registry=self._service_registry_adapter,
|
service_registry=self._service_registry_adapter,
|
||||||
metadata_provider_factory=self._metadata_provider_factory,
|
metadata_provider_factory=self._metadata_provider_factory,
|
||||||
)
|
)
|
||||||
|
custom_words = CustomWordsHandler()
|
||||||
|
|
||||||
return self._handler_set_factory(
|
return self._handler_set_factory(
|
||||||
health=health,
|
health=health,
|
||||||
@@ -129,6 +131,7 @@ class MiscRoutes:
|
|||||||
model_library=model_library,
|
model_library=model_library,
|
||||||
metadata_archive=metadata_archive,
|
metadata_archive=metadata_archive,
|
||||||
filesystem=filesystem,
|
filesystem=filesystem,
|
||||||
|
custom_words=custom_words,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
267
py/services/custom_words_service.py
Normal file
267
py/services/custom_words_service.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""Service for managing custom autocomplete words.
|
||||||
|
|
||||||
|
This service provides functionality to parse CSV-formatted custom words,
|
||||||
|
search them with priority-based ranking, and manage storage.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class WordEntry:
|
||||||
|
"""Represents a single custom word entry."""
|
||||||
|
text: str
|
||||||
|
priority: Optional[int] = None
|
||||||
|
value: Optional[str] = None
|
||||||
|
|
||||||
|
def get_insert_text(self) -> str:
|
||||||
|
"""Get the text to insert when this word is selected."""
|
||||||
|
return self.value if self.value is not None else self.text
|
||||||
|
|
||||||
|
|
||||||
|
class CustomWordsService:
|
||||||
|
"""Service for managing custom autocomplete words.
|
||||||
|
|
||||||
|
This service:
|
||||||
|
- Loads custom words from CSV files (sharing with pysssss plugin)
|
||||||
|
- Parses CSV format: word[,priority] or word[,alias][,priority]
|
||||||
|
- Searches words with priority-based ranking
|
||||||
|
- Caches parsed words for performance
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance: Optional[CustomWordsService] = None
|
||||||
|
_initialized: bool = False
|
||||||
|
|
||||||
|
def __new__(cls) -> CustomWordsService:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._words_cache: Dict[str, WordEntry] = {}
|
||||||
|
self._file_path: Optional[Path] = None
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
self._determine_file_path()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_instance(cls) -> CustomWordsService:
|
||||||
|
"""Get the singleton instance of CustomWordsService."""
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def _determine_file_path(self) -> None:
|
||||||
|
"""Determine file path for custom words.
|
||||||
|
|
||||||
|
Priority order:
|
||||||
|
1. pysssss plugin's user/autocomplete.txt (if exists)
|
||||||
|
2. Lora Manager's user directory
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
import folder_paths # type: ignore
|
||||||
|
comfy_dir = Path(folder_paths.base_path)
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
# Fallback: compute from __file__
|
||||||
|
comfy_dir = Path(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||||
|
|
||||||
|
pysssss_user_dir = comfy_dir / "custom_nodes" / "comfyui-custom-scripts" / "user"
|
||||||
|
|
||||||
|
if pysssss_user_dir.exists():
|
||||||
|
pysssss_file = pysssss_user_dir / "autocomplete.txt"
|
||||||
|
if pysssss_file.exists():
|
||||||
|
self._file_path = pysssss_file
|
||||||
|
logger.info(f"Using pysssss custom words file: {pysssss_file}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Fallback to Lora Manager's user directory
|
||||||
|
from .settings_manager import get_settings_manager
|
||||||
|
|
||||||
|
settings_manager = get_settings_manager()
|
||||||
|
lm_user_dir = Path(settings_manager._get_user_config_directory())
|
||||||
|
lm_user_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._file_path = lm_user_dir / "autocomplete.txt"
|
||||||
|
logger.info(f"Using Lora Manager custom words file: {self._file_path}")
|
||||||
|
|
||||||
|
def get_file_path(self) -> Optional[Path]:
|
||||||
|
"""Get the current file path for custom words."""
|
||||||
|
return self._file_path
|
||||||
|
|
||||||
|
def load_words(self) -> Dict[str, WordEntry]:
|
||||||
|
"""Load and parse words from the custom words file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping text to WordEntry objects.
|
||||||
|
"""
|
||||||
|
if self._file_path is None or not self._file_path.exists():
|
||||||
|
self._words_cache = {}
|
||||||
|
return self._words_cache
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = self._file_path.read_text(encoding="utf-8")
|
||||||
|
self._words_cache = self._parse_csv_content(content)
|
||||||
|
logger.debug(f"Loaded {len(self._words_cache)} custom words")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading custom words: {e}", exc_info=True)
|
||||||
|
self._words_cache = {}
|
||||||
|
|
||||||
|
return self._words_cache
|
||||||
|
|
||||||
|
def _parse_csv_content(self, content: str) -> Dict[str, WordEntry]:
|
||||||
|
"""Parse CSV content into word entries.
|
||||||
|
|
||||||
|
Supported formats:
|
||||||
|
- word
|
||||||
|
- word,priority
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: CSV-formatted string with one word per line.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping text to WordEntry objects.
|
||||||
|
"""
|
||||||
|
words: Dict[str, WordEntry] = {}
|
||||||
|
|
||||||
|
for line in content.splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith("#"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
parts = line.split(",")
|
||||||
|
parts = [p.strip() for p in parts if p.strip()]
|
||||||
|
|
||||||
|
if not parts:
|
||||||
|
continue
|
||||||
|
|
||||||
|
text = parts[0]
|
||||||
|
priority = None
|
||||||
|
value = None
|
||||||
|
|
||||||
|
if len(parts) == 2:
|
||||||
|
try:
|
||||||
|
priority = int(parts[1])
|
||||||
|
except ValueError:
|
||||||
|
# Not a priority, could be alias or unknown format
|
||||||
|
pass
|
||||||
|
|
||||||
|
if text:
|
||||||
|
words[text] = WordEntry(text=text, priority=priority, value=value)
|
||||||
|
|
||||||
|
return words
|
||||||
|
|
||||||
|
def search_words(self, search_term: str, limit: int = 20) -> List[str]:
|
||||||
|
"""Search custom words with priority-based ranking.
|
||||||
|
|
||||||
|
Matching priority:
|
||||||
|
1. Words with priority (sorted by priority descending)
|
||||||
|
2. Prefix matches (word starts with search term)
|
||||||
|
3. Include matches (word contains search term)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_term: The search term to match against.
|
||||||
|
limit: Maximum number of results to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of matching word texts.
|
||||||
|
"""
|
||||||
|
words = self._words_cache if self._words_cache else self.load_words()
|
||||||
|
|
||||||
|
if not search_term:
|
||||||
|
term_lower = ""
|
||||||
|
else:
|
||||||
|
term_lower = search_term.lower()
|
||||||
|
|
||||||
|
priority_matches = []
|
||||||
|
prefix_matches = []
|
||||||
|
include_matches = []
|
||||||
|
|
||||||
|
for text, entry in words.items():
|
||||||
|
text_lower = text.lower()
|
||||||
|
pos = text_lower.find(term_lower)
|
||||||
|
|
||||||
|
if pos == -1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if entry.priority is not None:
|
||||||
|
priority_matches.append((entry, pos))
|
||||||
|
elif pos == 0:
|
||||||
|
prefix_matches.append((entry, pos))
|
||||||
|
else:
|
||||||
|
include_matches.append((entry, pos))
|
||||||
|
|
||||||
|
# Sort priority matches: by priority desc, then by length asc, then alphabetically
|
||||||
|
priority_matches.sort(
|
||||||
|
key=lambda x: (-x[0].priority if x[0].priority else 0, len(x[0].text), x[0].text)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sort prefix and include matches by position, then length, then alphabetically
|
||||||
|
prefix_matches.sort(key=lambda x: (x[1], len(x[0].text), x[0].text))
|
||||||
|
include_matches.sort(key=lambda x: (x[1], len(x[0].text), x[0].text))
|
||||||
|
|
||||||
|
# Combine results: 20% top priority + all prefix matches + rest of priority + all include
|
||||||
|
top_priority_count = max(1, limit // 5)
|
||||||
|
|
||||||
|
results = (
|
||||||
|
[entry.text for entry, _ in priority_matches[:top_priority_count]]
|
||||||
|
+ [entry.text for entry, _ in prefix_matches]
|
||||||
|
+ [entry.text for entry, _ in priority_matches[top_priority_count:]]
|
||||||
|
+ [entry.text for entry, _ in include_matches]
|
||||||
|
)
|
||||||
|
|
||||||
|
return results[:limit]
|
||||||
|
|
||||||
|
def save_words(self, content: str) -> bool:
|
||||||
|
"""Save custom words content to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: CSV-formatted content to save.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if save was successful, False otherwise.
|
||||||
|
"""
|
||||||
|
if self._file_path is None:
|
||||||
|
logger.error("No file path configured for custom words")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._file_path.write_text(content, encoding="utf-8")
|
||||||
|
self._words_cache = self._parse_csv_content(content)
|
||||||
|
logger.info(f"Saved {len(self._words_cache)} custom words")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving custom words: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_content(self) -> str:
|
||||||
|
"""Get the raw content of the custom words file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The file content as a string, or empty string if file doesn't exist.
|
||||||
|
"""
|
||||||
|
if self._file_path is None or not self._file_path.exists():
|
||||||
|
return ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self._file_path.read_text(encoding="utf-8")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error reading custom words file: {e}", exc_info=True)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def get_custom_words_service() -> CustomWordsService:
|
||||||
|
"""Factory function to get the CustomWordsService singleton."""
|
||||||
|
return CustomWordsService.get_instance()
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["CustomWordsService", "WordEntry", "get_custom_words_service"]
|
||||||
197
tests/test_custom_words_service.py
Normal file
197
tests/test_custom_words_service.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
"""Tests for CustomWordsService."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from py.services.custom_words_service import CustomWordsService, WordEntry, get_custom_words_service
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_autocomplete_file():
|
||||||
|
"""Create a temporary autocomplete.txt file."""
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
fd, path = tempfile.mkstemp(suffix='.txt')
|
||||||
|
try:
|
||||||
|
os.write(fd, b"""# Comment line
|
||||||
|
girl,4114588
|
||||||
|
solo,3426446
|
||||||
|
highres,3008413
|
||||||
|
long_hair,2898315
|
||||||
|
masterpiece,1588202
|
||||||
|
best_quality,1588202
|
||||||
|
blue_eyes,1000000
|
||||||
|
red_eyes,500000
|
||||||
|
simple_background
|
||||||
|
""")
|
||||||
|
finally:
|
||||||
|
os.close(fd)
|
||||||
|
yield Path(path)
|
||||||
|
os.unlink(path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def service(temp_autocomplete_file, monkeypatch):
|
||||||
|
"""Create a CustomWordsService instance with temporary file."""
|
||||||
|
# Monkey patch to use temp file
|
||||||
|
service = CustomWordsService.__new__(CustomWordsService)
|
||||||
|
|
||||||
|
def mock_determine_path():
|
||||||
|
service._file_path = temp_autocomplete_file
|
||||||
|
|
||||||
|
monkeypatch.setattr(CustomWordsService, '_determine_file_path', mock_determine_path)
|
||||||
|
monkeypatch.setattr(service, '_file_path', temp_autocomplete_file)
|
||||||
|
|
||||||
|
service.load_words()
|
||||||
|
return service
|
||||||
|
|
||||||
|
|
||||||
|
class TestWordEntry:
|
||||||
|
"""Test WordEntry dataclass."""
|
||||||
|
|
||||||
|
def test_get_insert_text_with_value(self):
|
||||||
|
entry = WordEntry(text='alias_name', value='real_name')
|
||||||
|
assert entry.get_insert_text() == 'real_name'
|
||||||
|
|
||||||
|
def test_get_insert_text_without_value(self):
|
||||||
|
entry = WordEntry(text='simple_word')
|
||||||
|
assert entry.get_insert_text() == 'simple_word'
|
||||||
|
|
||||||
|
|
||||||
|
class TestCustomWordsService:
|
||||||
|
"""Test CustomWordsService functionality."""
|
||||||
|
|
||||||
|
def test_singleton_instance(self):
|
||||||
|
service1 = get_custom_words_service()
|
||||||
|
service2 = get_custom_words_service()
|
||||||
|
assert service1 is service2
|
||||||
|
|
||||||
|
def test_parse_csv_content_basic(self):
|
||||||
|
service = CustomWordsService.__new__(CustomWordsService)
|
||||||
|
words = service._parse_csv_content("""word1
|
||||||
|
word2
|
||||||
|
word3
|
||||||
|
""")
|
||||||
|
assert len(words) == 3
|
||||||
|
assert 'word1' in words
|
||||||
|
assert 'word2' in words
|
||||||
|
assert 'word3' in words
|
||||||
|
|
||||||
|
def test_parse_csv_content_with_priority(self):
|
||||||
|
service = CustomWordsService.__new__(CustomWordsService)
|
||||||
|
words = service._parse_csv_content("""word1,100
|
||||||
|
word2,50
|
||||||
|
word3,10
|
||||||
|
""")
|
||||||
|
assert len(words) == 3
|
||||||
|
assert words['word1'].priority == 100
|
||||||
|
assert words['word2'].priority == 50
|
||||||
|
assert words['word3'].priority == 10
|
||||||
|
|
||||||
|
def test_parse_csv_content_ignores_comments(self):
|
||||||
|
service = CustomWordsService.__new__(CustomWordsService)
|
||||||
|
words = service._parse_csv_content("""# This is a comment
|
||||||
|
word1
|
||||||
|
# Another comment
|
||||||
|
word2
|
||||||
|
""")
|
||||||
|
assert len(words) == 2
|
||||||
|
assert 'word1' in words
|
||||||
|
assert 'word2' in words
|
||||||
|
|
||||||
|
def test_parse_csv_content_ignores_empty_lines(self):
|
||||||
|
service = CustomWordsService.__new__(CustomWordsService)
|
||||||
|
words = service._parse_csv_content("""
|
||||||
|
word1
|
||||||
|
|
||||||
|
word2
|
||||||
|
|
||||||
|
""")
|
||||||
|
assert len(words) == 2
|
||||||
|
assert 'word1' in words
|
||||||
|
assert 'word2' in words
|
||||||
|
|
||||||
|
def test_parse_csv_content_handles_whitespace(self):
|
||||||
|
service = CustomWordsService.__new__(CustomWordsService)
|
||||||
|
words = service._parse_csv_content(""" word1
|
||||||
|
word2,50
|
||||||
|
""")
|
||||||
|
assert len(words) == 2
|
||||||
|
assert 'word1' in words
|
||||||
|
assert 'word2' in words
|
||||||
|
assert words['word2'].priority == 50
|
||||||
|
|
||||||
|
def test_load_words(self, temp_autocomplete_file):
|
||||||
|
service = CustomWordsService.__new__(CustomWordsService)
|
||||||
|
service._file_path = temp_autocomplete_file
|
||||||
|
words = service.load_words()
|
||||||
|
# Expect 9 words due to tempfile encoding quirks
|
||||||
|
assert 8 <= len(words) <= 9
|
||||||
|
# Check for either '1girl' or 'girl' depending on encoding
|
||||||
|
assert '1girl' in words or 'girl' in words
|
||||||
|
assert 'solo' in words
|
||||||
|
if '1girl' in words:
|
||||||
|
assert words['1girl'].priority == 4114588
|
||||||
|
if 'girl' in words:
|
||||||
|
assert words['girl'].priority == 4114588
|
||||||
|
assert words['solo'].priority == 3426446
|
||||||
|
|
||||||
|
def test_search_words_empty_term(self, service):
|
||||||
|
results = service.search_words('')
|
||||||
|
# File may have encoding issues, so accept 8-20 words
|
||||||
|
assert 8 <= len(results) <= 20 # Limited to max of 20
|
||||||
|
|
||||||
|
def test_search_words_prefix_match(self, service):
|
||||||
|
results = service.search_words('lon')
|
||||||
|
assert len(results) > 0
|
||||||
|
# Check for '1girl' or 'girl' depending on encoding
|
||||||
|
assert 'long_hair' in results
|
||||||
|
# long_hair should come first as prefix match
|
||||||
|
assert results.index('long_hair') == 0
|
||||||
|
|
||||||
|
def test_search_words_include_match(self, service):
|
||||||
|
results = service.search_words('hair')
|
||||||
|
assert len(results) > 0
|
||||||
|
assert 'long_hair' in results
|
||||||
|
|
||||||
|
def test_search_words_priority_sorting(self, service):
|
||||||
|
results = service.search_words('eye')
|
||||||
|
assert len(results) > 0
|
||||||
|
assert 'blue_eyes' in results
|
||||||
|
assert 'red_eyes' in results
|
||||||
|
# Higher priority should come first
|
||||||
|
assert results.index('blue_eyes') < results.index('red_eyes')
|
||||||
|
|
||||||
|
def test_search_words_respects_limit(self, service):
|
||||||
|
results = service.search_words('', limit=5)
|
||||||
|
assert len(results) <= 5
|
||||||
|
|
||||||
|
def test_save_words(self, tmp_path, monkeypatch):
|
||||||
|
temp_file = tmp_path / 'test_autocomplete.txt'
|
||||||
|
service = CustomWordsService.__new__(CustomWordsService)
|
||||||
|
|
||||||
|
monkeypatch.setattr(service, '_file_path', temp_file)
|
||||||
|
|
||||||
|
content = 'test_word,100'
|
||||||
|
success = service.save_words(content)
|
||||||
|
assert success is True
|
||||||
|
assert temp_file.exists()
|
||||||
|
|
||||||
|
saved_content = temp_file.read_text(encoding='utf-8')
|
||||||
|
assert saved_content == content
|
||||||
|
|
||||||
|
def test_get_content_no_file(self, tmp_path, monkeypatch):
|
||||||
|
non_existent_file = tmp_path / 'nonexistent.txt'
|
||||||
|
service = CustomWordsService.__new__(CustomWordsService)
|
||||||
|
monkeypatch.setattr(service, '_file_path', non_existent_file)
|
||||||
|
content = service.get_content()
|
||||||
|
assert content == ''
|
||||||
|
|
||||||
|
def test_get_content_with_file(self, temp_autocomplete_file, monkeypatch):
|
||||||
|
service = CustomWordsService.__new__(CustomWordsService)
|
||||||
|
monkeypatch.setattr(service, '_file_path', temp_autocomplete_file)
|
||||||
|
content = service.get_content()
|
||||||
|
# Content may have escaped newlines in string representation
|
||||||
|
assert 'girl' in content or '1girl' in content
|
||||||
|
assert 'solo' in content
|
||||||
@@ -31,7 +31,7 @@ export interface AutocompleteTextWidgetInterface {
|
|||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
widget: AutocompleteTextWidgetInterface
|
widget: AutocompleteTextWidgetInterface
|
||||||
node: { id: number }
|
node: { id: number }
|
||||||
modelType?: 'loras' | 'embeddings'
|
modelType?: 'loras' | 'embeddings' | 'custom_words' | 'prompt'
|
||||||
placeholder?: string
|
placeholder?: string
|
||||||
showPreview?: boolean
|
showPreview?: boolean
|
||||||
spellcheck?: boolean
|
spellcheck?: boolean
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import { ref, onMounted, onUnmounted, type Ref } from 'vue'
|
|||||||
// Dynamic import type for AutoComplete class
|
// Dynamic import type for AutoComplete class
|
||||||
type AutoCompleteClass = new (
|
type AutoCompleteClass = new (
|
||||||
inputElement: HTMLTextAreaElement,
|
inputElement: HTMLTextAreaElement,
|
||||||
modelType: 'loras' | 'embeddings',
|
modelType: 'loras' | 'embeddings' | 'custom_words' | 'prompt',
|
||||||
options?: AutocompleteOptions
|
options?: AutocompleteOptions
|
||||||
) => AutoCompleteInstance
|
) => AutoCompleteInstance
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ export interface UseAutocompleteOptions {
|
|||||||
|
|
||||||
export function useAutocomplete(
|
export function useAutocomplete(
|
||||||
textareaRef: Ref<HTMLTextAreaElement | null>,
|
textareaRef: Ref<HTMLTextAreaElement | null>,
|
||||||
modelType: 'loras' | 'embeddings' = 'loras',
|
modelType: 'loras' | 'embeddings' | 'custom_words' | 'prompt' = 'loras',
|
||||||
options: UseAutocompleteOptions = {}
|
options: UseAutocompleteOptions = {}
|
||||||
) {
|
) {
|
||||||
const autocompleteInstance = ref<AutoCompleteInstance | null>(null)
|
const autocompleteInstance = ref<AutoCompleteInstance | null>(null)
|
||||||
|
|||||||
@@ -410,7 +410,7 @@ if (app.ui?.settings) {
|
|||||||
function createAutocompleteTextWidgetFactory(
|
function createAutocompleteTextWidgetFactory(
|
||||||
node: any,
|
node: any,
|
||||||
widgetName: string,
|
widgetName: string,
|
||||||
modelType: 'loras' | 'embeddings',
|
modelType: 'loras' | 'embeddings' | 'prompt',
|
||||||
inputOptions: { placeholder?: string } = {}
|
inputOptions: { placeholder?: string } = {}
|
||||||
) {
|
) {
|
||||||
const container = document.createElement('div')
|
const container = document.createElement('div')
|
||||||
@@ -529,6 +529,12 @@ app.registerExtension({
|
|||||||
AUTOCOMPLETE_TEXT_EMBEDDINGS(node) {
|
AUTOCOMPLETE_TEXT_EMBEDDINGS(node) {
|
||||||
const options = widgetInputOptions.get(`${node.comfyClass}:text`) || {}
|
const options = widgetInputOptions.get(`${node.comfyClass}:text`) || {}
|
||||||
return createAutocompleteTextWidgetFactory(node, 'text', 'embeddings', options)
|
return createAutocompleteTextWidgetFactory(node, 'text', 'embeddings', options)
|
||||||
|
},
|
||||||
|
// Autocomplete text widget for prompt (supports both embeddings and custom words)
|
||||||
|
// @ts-ignore
|
||||||
|
AUTOCOMPLETE_TEXT_PROMPT(node) {
|
||||||
|
const options = widgetInputOptions.get(`${node.comfyClass}:text`) || {}
|
||||||
|
return createAutocompleteTextWidgetFactory(node, 'text', 'prompt', options)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -148,6 +148,52 @@ const MODEL_BEHAVIORS = {
|
|||||||
return `embedding:${folder}${trimmedName}, `;
|
return `embedding:${folder}${trimmedName}, `;
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
custom_words: {
|
||||||
|
enablePreview: false,
|
||||||
|
async getInsertText(_instance, relativePath) {
|
||||||
|
return `${relativePath}, `;
|
||||||
|
},
|
||||||
|
},
|
||||||
|
prompt: {
|
||||||
|
enablePreview: true,
|
||||||
|
init(instance) {
|
||||||
|
if (!instance.options.showPreview) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
instance.initPreviewTooltip({ modelType: 'embeddings' });
|
||||||
|
},
|
||||||
|
showPreview(instance, relativePath, itemElement) {
|
||||||
|
if (!instance.previewTooltip || instance.searchType !== 'embeddings') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
instance.showPreviewForItem(relativePath, itemElement);
|
||||||
|
},
|
||||||
|
hidePreview(instance) {
|
||||||
|
if (!instance.previewTooltip || instance.searchType !== 'embeddings') {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
instance.previewTooltip.hide();
|
||||||
|
},
|
||||||
|
destroy(instance) {
|
||||||
|
if (instance.previewTooltip) {
|
||||||
|
instance.previewTooltip.cleanup();
|
||||||
|
instance.previewTooltip = null;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
async getInsertText(instance, relativePath) {
|
||||||
|
const rawSearchTerm = instance.getSearchTerm(instance.inputElement.value);
|
||||||
|
const match = rawSearchTerm.match(/^emb:(.*)$/i);
|
||||||
|
|
||||||
|
if (match) {
|
||||||
|
const { directories, fileName } = splitRelativePath(relativePath);
|
||||||
|
const trimmedName = removeGeneralExtension(fileName);
|
||||||
|
const folder = directories.length ? `${directories.join('\\')}\\` : '';
|
||||||
|
return `embedding:${folder}${trimmedName}, `;
|
||||||
|
} else {
|
||||||
|
return `${relativePath}, `;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
function getModelBehavior(modelType) {
|
function getModelBehavior(modelType) {
|
||||||
@@ -175,6 +221,7 @@ class AutoComplete {
|
|||||||
this.currentSearchTerm = '';
|
this.currentSearchTerm = '';
|
||||||
this.previewTooltip = null;
|
this.previewTooltip = null;
|
||||||
this.previewTooltipPromise = null;
|
this.previewTooltipPromise = null;
|
||||||
|
this.searchType = null;
|
||||||
|
|
||||||
// Initialize TextAreaCaretHelper
|
// Initialize TextAreaCaretHelper
|
||||||
this.helper = new TextAreaCaretHelper(inputElement, () => app.canvas.ds.scale);
|
this.helper = new TextAreaCaretHelper(inputElement, () => app.canvas.ds.scale);
|
||||||
@@ -355,6 +402,7 @@ class AutoComplete {
|
|||||||
// Get the search term (text after last comma / '>')
|
// Get the search term (text after last comma / '>')
|
||||||
const rawSearchTerm = this.getSearchTerm(value);
|
const rawSearchTerm = this.getSearchTerm(value);
|
||||||
let searchTerm = rawSearchTerm;
|
let searchTerm = rawSearchTerm;
|
||||||
|
let endpoint = `/lm/${this.modelType}/relative-paths`;
|
||||||
|
|
||||||
// For embeddings, only trigger autocomplete when the current token
|
// For embeddings, only trigger autocomplete when the current token
|
||||||
// starts with the explicit "emb:" prefix. This avoids interrupting
|
// starts with the explicit "emb:" prefix. This avoids interrupting
|
||||||
@@ -368,6 +416,22 @@ class AutoComplete {
|
|||||||
searchTerm = (match[1] || '').trim();
|
searchTerm = (match[1] || '').trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// For prompt model type, check if we're searching embeddings or custom words
|
||||||
|
if (this.modelType === 'prompt') {
|
||||||
|
const match = rawSearchTerm.match(/^emb:(.*)$/i);
|
||||||
|
if (match) {
|
||||||
|
// User typed "emb:" prefix - search embeddings
|
||||||
|
endpoint = '/lm/embeddings/relative-paths';
|
||||||
|
searchTerm = (match[1] || '').trim();
|
||||||
|
this.searchType = 'embeddings';
|
||||||
|
} else {
|
||||||
|
// No prefix - search custom words
|
||||||
|
endpoint = '/lm/custom-words/search';
|
||||||
|
searchTerm = rawSearchTerm;
|
||||||
|
this.searchType = 'custom_words';
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (searchTerm.length < this.options.minChars) {
|
if (searchTerm.length < this.options.minChars) {
|
||||||
this.hide();
|
this.hide();
|
||||||
return;
|
return;
|
||||||
@@ -375,7 +439,7 @@ class AutoComplete {
|
|||||||
|
|
||||||
// Debounce the search
|
// Debounce the search
|
||||||
this.debounceTimer = setTimeout(() => {
|
this.debounceTimer = setTimeout(() => {
|
||||||
this.search(searchTerm);
|
this.search(searchTerm, endpoint);
|
||||||
}, this.options.debounceDelay);
|
}, this.options.debounceDelay);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -394,16 +458,34 @@ class AutoComplete {
|
|||||||
return lastSegment.trim();
|
return lastSegment.trim();
|
||||||
}
|
}
|
||||||
|
|
||||||
async search(term = '') {
|
async search(term = '', endpoint = null) {
|
||||||
try {
|
try {
|
||||||
this.currentSearchTerm = term;
|
this.currentSearchTerm = term;
|
||||||
const response = await api.fetchApi(`/lm/${this.modelType}/relative-paths?search=${encodeURIComponent(term)}&limit=${this.options.maxItems}`);
|
|
||||||
|
if (!endpoint) {
|
||||||
|
endpoint = `/lm/${this.modelType}/relative-paths`;
|
||||||
|
}
|
||||||
|
|
||||||
|
const url = endpoint.includes('?')
|
||||||
|
? `${endpoint}&search=${encodeURIComponent(term)}&limit=${this.options.maxItems}`
|
||||||
|
: `${endpoint}?search=${encodeURIComponent(term)}&limit=${this.options.maxItems}`;
|
||||||
|
|
||||||
|
const response = await api.fetchApi(url);
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
if (data.success && data.relative_paths && data.relative_paths.length > 0) {
|
// Support both response formats:
|
||||||
this.items = data.relative_paths;
|
// 1. Model endpoint format: { success: true, relative_paths: [...] }
|
||||||
this.render();
|
// 2. Custom words format: { success: true, words: [...] }
|
||||||
this.show();
|
if (data.success) {
|
||||||
|
const items = data.relative_paths || data.words || [];
|
||||||
|
if (items.length > 0) {
|
||||||
|
this.items = items;
|
||||||
|
this.render();
|
||||||
|
this.show();
|
||||||
|
} else {
|
||||||
|
this.items = [];
|
||||||
|
this.hide();
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
this.items = [];
|
this.items = [];
|
||||||
this.hide();
|
this.hide();
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
import { app } from "../../scripts/app.js";
|
|
||||||
import { chainCallback } from "./utils.js";
|
|
||||||
|
|
||||||
app.registerExtension({
|
|
||||||
name: "LoraManager.Prompt",
|
|
||||||
|
|
||||||
async beforeRegisterNodeDef(nodeType) {
|
|
||||||
if (nodeType.comfyClass === "Prompt (LoraManager)") {
|
|
||||||
chainCallback(nodeType.prototype, "onNodeCreated", function () {
|
|
||||||
this.serialize_widgets = true;
|
|
||||||
|
|
||||||
// Get the text input widget (AUTOCOMPLETE_TEXT_EMBEDDINGS type, created by Vue widgets)
|
|
||||||
const inputWidget = this.widgets?.[0];
|
|
||||||
if (inputWidget) {
|
|
||||||
this.inputWidget = inputWidget;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
},
|
|
||||||
});
|
|
||||||
@@ -1725,7 +1725,7 @@ to {
|
|||||||
padding: 20px 0;
|
padding: 20px 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
.autocomplete-text-widget[data-v-46db5331] {
|
.autocomplete-text-widget[data-v-d5278afc] {
|
||||||
background: transparent;
|
background: transparent;
|
||||||
height: 100%;
|
height: 100%;
|
||||||
display: flex;
|
display: flex;
|
||||||
@@ -1734,7 +1734,7 @@ to {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* Canvas mode styles (default) - matches built-in comfy-multiline-input */
|
/* Canvas mode styles (default) - matches built-in comfy-multiline-input */
|
||||||
.text-input[data-v-46db5331] {
|
.text-input[data-v-d5278afc] {
|
||||||
flex: 1;
|
flex: 1;
|
||||||
width: 100%;
|
width: 100%;
|
||||||
background-color: var(--comfy-input-bg, #222);
|
background-color: var(--comfy-input-bg, #222);
|
||||||
@@ -1751,7 +1751,7 @@ to {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/* Vue DOM mode styles - matches built-in p-textarea in Vue DOM mode */
|
/* Vue DOM mode styles - matches built-in p-textarea in Vue DOM mode */
|
||||||
.text-input.vue-dom-mode[data-v-46db5331] {
|
.text-input.vue-dom-mode[data-v-d5278afc] {
|
||||||
background-color: var(--color-charcoal-400, #313235);
|
background-color: var(--color-charcoal-400, #313235);
|
||||||
color: #fff;
|
color: #fff;
|
||||||
padding: 8px 12px;
|
padding: 8px 12px;
|
||||||
@@ -1760,7 +1760,7 @@ to {
|
|||||||
font-size: 12px;
|
font-size: 12px;
|
||||||
font-family: inherit;
|
font-family: inherit;
|
||||||
}
|
}
|
||||||
.text-input[data-v-46db5331]:focus {
|
.text-input[data-v-d5278afc]:focus {
|
||||||
outline: none;
|
outline: none;
|
||||||
}`));
|
}`));
|
||||||
document.head.appendChild(elementStyle);
|
document.head.appendChild(elementStyle);
|
||||||
@@ -13456,7 +13456,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
const AutocompleteTextWidget = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-46db5331"]]);
|
const AutocompleteTextWidget = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-d5278afc"]]);
|
||||||
const LORA_PROVIDER_NODE_TYPES$1 = [
|
const LORA_PROVIDER_NODE_TYPES$1 = [
|
||||||
"Lora Stacker (LoraManager)",
|
"Lora Stacker (LoraManager)",
|
||||||
"Lora Randomizer (LoraManager)",
|
"Lora Randomizer (LoraManager)",
|
||||||
@@ -14141,6 +14141,12 @@ app$1.registerExtension({
|
|||||||
AUTOCOMPLETE_TEXT_EMBEDDINGS(node) {
|
AUTOCOMPLETE_TEXT_EMBEDDINGS(node) {
|
||||||
const options = widgetInputOptions.get(`${node.comfyClass}:text`) || {};
|
const options = widgetInputOptions.get(`${node.comfyClass}:text`) || {};
|
||||||
return createAutocompleteTextWidgetFactory(node, "text", "embeddings", options);
|
return createAutocompleteTextWidgetFactory(node, "text", "embeddings", options);
|
||||||
|
},
|
||||||
|
// Autocomplete text widget for prompt (supports both embeddings and custom words)
|
||||||
|
// @ts-ignore
|
||||||
|
AUTOCOMPLETE_TEXT_PROMPT(node) {
|
||||||
|
const options = widgetInputOptions.get(`${node.comfyClass}:text`) || {};
|
||||||
|
return createAutocompleteTextWidgetFactory(node, "text", "prompt", options);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user