mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 06:32:12 -03:00
feat: add custom words autocomplete support for Prompt node
Adds custom words autocomplete functionality similar to comfyui-custom-scripts, with the following features: Backend (Python): - Create CustomWordsService for CSV parsing and priority-based search - Add API endpoints: GET/POST /api/lm/custom-words and GET /api/lm/custom-words/search - Share storage with pysssss plugin (checks for their user/autocomplete.txt first) - Fallback to Lora Manager's user directory for storage Frontend (JavaScript/Vue): - Add 'custom_words' and 'prompt' model types to autocomplete system - Prompt node now supports dual-mode autocomplete: * Type 'emb:' prefix → search embeddings * Type normally → search custom words (no prefix required) - Add AUTOCOMPLETE_TEXT_PROMPT widget type - Update Vue component and composable types Key Features: - CSV format: word[,priority] compatible with danbooru-tags.txt - Priority-based sorting: 20% top priority + prefix + include matches - Preview tooltip for embeddings (not for custom words) - Dynamic endpoint switching based on prefix detection Breaking Changes: - Prompt (LoraManager) node widget type changed from AUTOCOMPLETE_TEXT_EMBEDDINGS to AUTOCOMPLETE_TEXT_PROMPT - Removed standalone web/comfyui/prompt.js (integrated into main widgets) Fixes comfy_dir path calculation by prioritizing folder_paths.base_path from ComfyUI when available, with fallback to computed path.
This commit is contained in:
@@ -15,7 +15,7 @@ class PromptLM:
|
||||
return {
|
||||
"required": {
|
||||
"text": (
|
||||
"AUTOCOMPLETE_TEXT_EMBEDDINGS",
|
||||
"AUTOCOMPLETE_TEXT_PROMPT",
|
||||
{
|
||||
"placeholder": "Enter prompt...",
|
||||
"tooltip": "The text to be encoded.",
|
||||
|
||||
@@ -1201,6 +1201,52 @@ class FileSystemHandler:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class CustomWordsHandler:
|
||||
"""Handler for custom autocomplete words."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
from ...services.custom_words_service import get_custom_words_service
|
||||
self._service = get_custom_words_service()
|
||||
|
||||
async def get_custom_words(self, request: web.Request) -> web.Response:
|
||||
"""Get the content of the custom words file."""
|
||||
try:
|
||||
content = self._service.get_content()
|
||||
return web.Response(text=content, content_type="text/plain")
|
||||
except Exception as exc:
|
||||
logger.error("Error getting custom words: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def update_custom_words(self, request: web.Request) -> web.Response:
|
||||
"""Update the custom words file content."""
|
||||
try:
|
||||
content = await request.text()
|
||||
success = self._service.save_words(content)
|
||||
if success:
|
||||
return web.Response(status=200)
|
||||
else:
|
||||
return web.json_response({"error": "Failed to save custom words"}, status=500)
|
||||
except Exception as exc:
|
||||
logger.error("Error updating custom words: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def search_custom_words(self, request: web.Request) -> web.Response:
|
||||
"""Search custom words with autocomplete."""
|
||||
try:
|
||||
search_term = request.query.get("search", "")
|
||||
limit = int(request.query.get("limit", "20"))
|
||||
|
||||
results = self._service.search_words(search_term, limit)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"words": results
|
||||
})
|
||||
except Exception as exc:
|
||||
logger.error("Error searching custom words: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class NodeRegistryHandler:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1427,6 +1473,7 @@ class MiscHandlerSet:
|
||||
model_library: ModelLibraryHandler,
|
||||
metadata_archive: MetadataArchiveHandler,
|
||||
filesystem: FileSystemHandler,
|
||||
custom_words: CustomWordsHandler,
|
||||
) -> None:
|
||||
self.health = health
|
||||
self.settings = settings
|
||||
@@ -1438,6 +1485,7 @@ class MiscHandlerSet:
|
||||
self.model_library = model_library
|
||||
self.metadata_archive = metadata_archive
|
||||
self.filesystem = filesystem
|
||||
self.custom_words = custom_words
|
||||
|
||||
def to_route_mapping(
|
||||
self,
|
||||
@@ -1465,6 +1513,9 @@ class MiscHandlerSet:
|
||||
"get_model_versions_status": self.model_library.get_model_versions_status,
|
||||
"open_file_location": self.filesystem.open_file_location,
|
||||
"open_settings_location": self.filesystem.open_settings_location,
|
||||
"get_custom_words": self.custom_words.get_custom_words,
|
||||
"update_custom_words": self.custom_words.update_custom_words,
|
||||
"search_custom_words": self.custom_words.search_custom_words,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -42,6 +42,9 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"),
|
||||
RouteDefinition("GET", "/api/lm/model-versions-status", "get_model_versions_status"),
|
||||
RouteDefinition("POST", "/api/lm/settings/open-location", "open_settings_location"),
|
||||
RouteDefinition("GET", "/api/lm/custom-words", "get_custom_words"),
|
||||
RouteDefinition("POST", "/api/lm/custom-words", "update_custom_words"),
|
||||
RouteDefinition("GET", "/api/lm/custom-words/search", "search_custom_words"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from ..services.settings_manager import get_settings_manager
|
||||
from ..services.downloader import get_downloader
|
||||
from ..utils.usage_stats import UsageStats
|
||||
from .handlers.misc_handlers import (
|
||||
CustomWordsHandler,
|
||||
FileSystemHandler,
|
||||
HealthCheckHandler,
|
||||
LoraCodeHandler,
|
||||
@@ -117,6 +118,7 @@ class MiscRoutes:
|
||||
service_registry=self._service_registry_adapter,
|
||||
metadata_provider_factory=self._metadata_provider_factory,
|
||||
)
|
||||
custom_words = CustomWordsHandler()
|
||||
|
||||
return self._handler_set_factory(
|
||||
health=health,
|
||||
@@ -129,6 +131,7 @@ class MiscRoutes:
|
||||
model_library=model_library,
|
||||
metadata_archive=metadata_archive,
|
||||
filesystem=filesystem,
|
||||
custom_words=custom_words,
|
||||
)
|
||||
|
||||
|
||||
|
||||
267
py/services/custom_words_service.py
Normal file
267
py/services/custom_words_service.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Service for managing custom autocomplete words.
|
||||
|
||||
This service provides functionality to parse CSV-formatted custom words,
|
||||
search them with priority-based ranking, and manage storage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WordEntry:
|
||||
"""Represents a single custom word entry."""
|
||||
text: str
|
||||
priority: Optional[int] = None
|
||||
value: Optional[str] = None
|
||||
|
||||
def get_insert_text(self) -> str:
|
||||
"""Get the text to insert when this word is selected."""
|
||||
return self.value if self.value is not None else self.text
|
||||
|
||||
|
||||
class CustomWordsService:
|
||||
"""Service for managing custom autocomplete words.
|
||||
|
||||
This service:
|
||||
- Loads custom words from CSV files (sharing with pysssss plugin)
|
||||
- Parses CSV format: word[,priority] or word[,alias][,priority]
|
||||
- Searches words with priority-based ranking
|
||||
- Caches parsed words for performance
|
||||
"""
|
||||
|
||||
_instance: Optional[CustomWordsService] = None
|
||||
_initialized: bool = False
|
||||
|
||||
def __new__(cls) -> CustomWordsService:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._words_cache: Dict[str, WordEntry] = {}
|
||||
self._file_path: Optional[Path] = None
|
||||
self._initialized = True
|
||||
|
||||
self._determine_file_path()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> CustomWordsService:
|
||||
"""Get the singleton instance of CustomWordsService."""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def _determine_file_path(self) -> None:
|
||||
"""Determine file path for custom words.
|
||||
|
||||
Priority order:
|
||||
1. pysssss plugin's user/autocomplete.txt (if exists)
|
||||
2. Lora Manager's user directory
|
||||
"""
|
||||
try:
|
||||
import folder_paths # type: ignore
|
||||
comfy_dir = Path(folder_paths.base_path)
|
||||
except (ImportError, AttributeError):
|
||||
# Fallback: compute from __file__
|
||||
comfy_dir = Path(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
pysssss_user_dir = comfy_dir / "custom_nodes" / "comfyui-custom-scripts" / "user"
|
||||
|
||||
if pysssss_user_dir.exists():
|
||||
pysssss_file = pysssss_user_dir / "autocomplete.txt"
|
||||
if pysssss_file.exists():
|
||||
self._file_path = pysssss_file
|
||||
logger.info(f"Using pysssss custom words file: {pysssss_file}")
|
||||
return
|
||||
|
||||
# Fallback to Lora Manager's user directory
|
||||
from .settings_manager import get_settings_manager
|
||||
|
||||
settings_manager = get_settings_manager()
|
||||
lm_user_dir = Path(settings_manager._get_user_config_directory())
|
||||
lm_user_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._file_path = lm_user_dir / "autocomplete.txt"
|
||||
logger.info(f"Using Lora Manager custom words file: {self._file_path}")
|
||||
|
||||
def get_file_path(self) -> Optional[Path]:
|
||||
"""Get the current file path for custom words."""
|
||||
return self._file_path
|
||||
|
||||
def load_words(self) -> Dict[str, WordEntry]:
|
||||
"""Load and parse words from the custom words file.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping text to WordEntry objects.
|
||||
"""
|
||||
if self._file_path is None or not self._file_path.exists():
|
||||
self._words_cache = {}
|
||||
return self._words_cache
|
||||
|
||||
try:
|
||||
content = self._file_path.read_text(encoding="utf-8")
|
||||
self._words_cache = self._parse_csv_content(content)
|
||||
logger.debug(f"Loaded {len(self._words_cache)} custom words")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading custom words: {e}", exc_info=True)
|
||||
self._words_cache = {}
|
||||
|
||||
return self._words_cache
|
||||
|
||||
def _parse_csv_content(self, content: str) -> Dict[str, WordEntry]:
|
||||
"""Parse CSV content into word entries.
|
||||
|
||||
Supported formats:
|
||||
- word
|
||||
- word,priority
|
||||
|
||||
Args:
|
||||
content: CSV-formatted string with one word per line.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping text to WordEntry objects.
|
||||
"""
|
||||
words: Dict[str, WordEntry] = {}
|
||||
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
parts = line.split(",")
|
||||
parts = [p.strip() for p in parts if p.strip()]
|
||||
|
||||
if not parts:
|
||||
continue
|
||||
|
||||
text = parts[0]
|
||||
priority = None
|
||||
value = None
|
||||
|
||||
if len(parts) == 2:
|
||||
try:
|
||||
priority = int(parts[1])
|
||||
except ValueError:
|
||||
# Not a priority, could be alias or unknown format
|
||||
pass
|
||||
|
||||
if text:
|
||||
words[text] = WordEntry(text=text, priority=priority, value=value)
|
||||
|
||||
return words
|
||||
|
||||
def search_words(self, search_term: str, limit: int = 20) -> List[str]:
|
||||
"""Search custom words with priority-based ranking.
|
||||
|
||||
Matching priority:
|
||||
1. Words with priority (sorted by priority descending)
|
||||
2. Prefix matches (word starts with search term)
|
||||
3. Include matches (word contains search term)
|
||||
|
||||
Args:
|
||||
search_term: The search term to match against.
|
||||
limit: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of matching word texts.
|
||||
"""
|
||||
words = self._words_cache if self._words_cache else self.load_words()
|
||||
|
||||
if not search_term:
|
||||
term_lower = ""
|
||||
else:
|
||||
term_lower = search_term.lower()
|
||||
|
||||
priority_matches = []
|
||||
prefix_matches = []
|
||||
include_matches = []
|
||||
|
||||
for text, entry in words.items():
|
||||
text_lower = text.lower()
|
||||
pos = text_lower.find(term_lower)
|
||||
|
||||
if pos == -1:
|
||||
continue
|
||||
|
||||
if entry.priority is not None:
|
||||
priority_matches.append((entry, pos))
|
||||
elif pos == 0:
|
||||
prefix_matches.append((entry, pos))
|
||||
else:
|
||||
include_matches.append((entry, pos))
|
||||
|
||||
# Sort priority matches: by priority desc, then by length asc, then alphabetically
|
||||
priority_matches.sort(
|
||||
key=lambda x: (-x[0].priority if x[0].priority else 0, len(x[0].text), x[0].text)
|
||||
)
|
||||
|
||||
# Sort prefix and include matches by position, then length, then alphabetically
|
||||
prefix_matches.sort(key=lambda x: (x[1], len(x[0].text), x[0].text))
|
||||
include_matches.sort(key=lambda x: (x[1], len(x[0].text), x[0].text))
|
||||
|
||||
# Combine results: 20% top priority + all prefix matches + rest of priority + all include
|
||||
top_priority_count = max(1, limit // 5)
|
||||
|
||||
results = (
|
||||
[entry.text for entry, _ in priority_matches[:top_priority_count]]
|
||||
+ [entry.text for entry, _ in prefix_matches]
|
||||
+ [entry.text for entry, _ in priority_matches[top_priority_count:]]
|
||||
+ [entry.text for entry, _ in include_matches]
|
||||
)
|
||||
|
||||
return results[:limit]
|
||||
|
||||
def save_words(self, content: str) -> bool:
|
||||
"""Save custom words content to file.
|
||||
|
||||
Args:
|
||||
content: CSV-formatted content to save.
|
||||
|
||||
Returns:
|
||||
True if save was successful, False otherwise.
|
||||
"""
|
||||
if self._file_path is None:
|
||||
logger.error("No file path configured for custom words")
|
||||
return False
|
||||
|
||||
try:
|
||||
self._file_path.write_text(content, encoding="utf-8")
|
||||
self._words_cache = self._parse_csv_content(content)
|
||||
logger.info(f"Saved {len(self._words_cache)} custom words")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving custom words: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def get_content(self) -> str:
|
||||
"""Get the raw content of the custom words file.
|
||||
|
||||
Returns:
|
||||
The file content as a string, or empty string if file doesn't exist.
|
||||
"""
|
||||
if self._file_path is None or not self._file_path.exists():
|
||||
return ""
|
||||
|
||||
try:
|
||||
return self._file_path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading custom words file: {e}", exc_info=True)
|
||||
return ""
|
||||
|
||||
|
||||
def get_custom_words_service() -> CustomWordsService:
|
||||
"""Factory function to get the CustomWordsService singleton."""
|
||||
return CustomWordsService.get_instance()
|
||||
|
||||
|
||||
__all__ = ["CustomWordsService", "WordEntry", "get_custom_words_service"]
|
||||
Reference in New Issue
Block a user