diff --git a/py/routes/handlers/misc_handlers.py b/py/routes/handlers/misc_handlers.py index 5e1397e3..efc7bade 100644 --- a/py/routes/handlers/misc_handlers.py +++ b/py/routes/handlers/misc_handlers.py @@ -1202,34 +1202,12 @@ class FileSystemHandler: class CustomWordsHandler: - """Handler for custom autocomplete words.""" + """Handler for autocomplete via TagFTSIndex.""" def __init__(self) -> None: from ...services.custom_words_service import get_custom_words_service self._service = get_custom_words_service() - async def get_custom_words(self, request: web.Request) -> web.Response: - """Get the content of the custom words file.""" - try: - content = self._service.get_content() - return web.Response(text=content, content_type="text/plain") - except Exception as exc: - logger.error("Error getting custom words: %s", exc, exc_info=True) - return web.json_response({"error": str(exc)}, status=500) - - async def update_custom_words(self, request: web.Request) -> web.Response: - """Update the custom words file content.""" - try: - content = await request.text() - success = self._service.save_words(content) - if success: - return web.Response(status=200) - else: - return web.json_response({"error": "Failed to save custom words"}, status=500) - except Exception as exc: - logger.error("Error updating custom words: %s", exc, exc_info=True) - return web.json_response({"error": str(exc)}, status=500) - async def search_custom_words(self, request: web.Request) -> web.Response: """Search custom words with autocomplete. @@ -1563,8 +1541,6 @@ class MiscHandlerSet: "get_model_versions_status": self.model_library.get_model_versions_status, "open_file_location": self.filesystem.open_file_location, "open_settings_location": self.filesystem.open_settings_location, - "get_custom_words": self.custom_words.get_custom_words, - "update_custom_words": self.custom_words.update_custom_words, "search_custom_words": self.custom_words.search_custom_words, } diff --git a/py/routes/misc_route_registrar.py b/py/routes/misc_route_registrar.py index 883ff111..1d6c4e7e 100644 --- a/py/routes/misc_route_registrar.py +++ b/py/routes/misc_route_registrar.py @@ -42,8 +42,6 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"), RouteDefinition("GET", "/api/lm/model-versions-status", "get_model_versions_status"), RouteDefinition("POST", "/api/lm/settings/open-location", "open_settings_location"), - RouteDefinition("GET", "/api/lm/custom-words", "get_custom_words"), - RouteDefinition("POST", "/api/lm/custom-words", "update_custom_words"), RouteDefinition("GET", "/api/lm/custom-words/search", "search_custom_words"), ) diff --git a/py/services/custom_words_service.py b/py/services/custom_words_service.py index 5825f3c8..7e0666db 100644 --- a/py/services/custom_words_service.py +++ b/py/services/custom_words_service.py @@ -1,44 +1,25 @@ -"""Service for managing custom autocomplete words. +"""Service for managing autocomplete via TagFTSIndex. -This service provides functionality to parse CSV-formatted custom words, -search them with priority-based ranking, and manage storage. - -It also integrates with TagFTSIndex to search the Danbooru/e621 tag database -for comprehensive autocomplete suggestions with category filtering. +This service provides full-text search capabilities for Danbooru/e621 tags +with category filtering and enriched results including post counts. """ from __future__ import annotations import logging -import os -from dataclasses import dataclass -from pathlib import Path -from typing import List, Dict, Any, Optional, Union +from typing import List, Dict, Any, Optional logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class WordEntry: - """Represents a single custom word entry.""" - text: str - priority: Optional[int] = None - value: Optional[str] = None - - def get_insert_text(self) -> str: - """Get the text to insert when this word is selected.""" - return self.value if self.value is not None else self.text - - class CustomWordsService: - """Service for managing custom autocomplete words. + """Service for autocomplete via TagFTSIndex. This service: - - Loads custom words from CSV files (sharing with pysssss plugin) - - Parses CSV format: word[,priority] or word[,alias][,priority] - - Searches words with priority-based ranking - - Caches parsed words for performance - - Integrates with TagFTSIndex for Danbooru/e621 tag search + - Uses TagFTSIndex for fast full-text search of Danbooru/e621 tags + - Supports category-based filtering + - Returns enriched results with category and post_count + - Provides sub-100ms search times for 221k+ tags """ _instance: Optional[CustomWordsService] = None @@ -53,13 +34,9 @@ class CustomWordsService: if self._initialized: return - self._words_cache: Dict[str, WordEntry] = {} - self._file_path: Optional[Path] = None - self._tag_index: Optional[Any] = None # Lazy-loaded TagFTSIndex + self._tag_index: Optional[Any] = None self._initialized = True - self._determine_file_path() - @classmethod def get_instance(cls) -> CustomWordsService: """Get the singleton instance of CustomWordsService.""" @@ -67,42 +44,6 @@ class CustomWordsService: cls._instance = cls() return cls._instance - def _determine_file_path(self) -> None: - """Determine file path for custom words. - - Priority order: - 1. pysssss plugin's user/autocomplete.txt (if exists) - 2. Lora Manager's user directory - """ - try: - import folder_paths # type: ignore - comfy_dir = Path(folder_paths.base_path) - except (ImportError, AttributeError): - # Fallback: compute from __file__ - comfy_dir = Path(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) - - pysssss_user_dir = comfy_dir / "custom_nodes" / "comfyui-custom-scripts" / "user" - - if pysssss_user_dir.exists(): - pysssss_file = pysssss_user_dir / "autocomplete.txt" - if pysssss_file.exists(): - self._file_path = pysssss_file - logger.info(f"Using pysssss custom words file: {pysssss_file}") - return - - # Fallback to Lora Manager's user directory - from .settings_manager import get_settings_manager - - settings_manager = get_settings_manager() - lm_user_dir = Path(settings_manager._get_user_config_directory()) - lm_user_dir.mkdir(parents=True, exist_ok=True) - self._file_path = lm_user_dir / "autocomplete.txt" - logger.info(f"Using Lora Manager custom words file: {self._file_path}") - - def get_file_path(self) -> Optional[Path]: - """Get the current file path for custom words.""" - return self._file_path - def _get_tag_index(self): """Get or create the TagFTSIndex instance (lazy initialization).""" if self._tag_index is None: @@ -114,198 +55,32 @@ class CustomWordsService: self._tag_index = None return self._tag_index - def load_words(self) -> Dict[str, WordEntry]: - """Load and parse words from the custom words file. - - Returns: - Dictionary mapping text to WordEntry objects. - """ - if self._file_path is None or not self._file_path.exists(): - self._words_cache = {} - return self._words_cache - - try: - content = self._file_path.read_text(encoding="utf-8") - self._words_cache = self._parse_csv_content(content) - logger.debug(f"Loaded {len(self._words_cache)} custom words") - except Exception as e: - logger.error(f"Error loading custom words: {e}", exc_info=True) - self._words_cache = {} - - return self._words_cache - - def _parse_csv_content(self, content: str) -> Dict[str, WordEntry]: - """Parse CSV content into word entries. - - Supported formats: - - word - - word,priority - - Args: - content: CSV-formatted string with one word per line. - - Returns: - Dictionary mapping text to WordEntry objects. - """ - words: Dict[str, WordEntry] = {} - - for line in content.splitlines(): - line = line.strip() - if not line or line.startswith("#"): - continue - - parts = line.split(",") - parts = [p.strip() for p in parts if p.strip()] - - if not parts: - continue - - text = parts[0] - priority = None - value = None - - if len(parts) == 2: - try: - priority = int(parts[1]) - except ValueError: - # Not a priority, could be alias or unknown format - pass - - if text: - words[text] = WordEntry(text=text, priority=priority, value=value) - - return words - def search_words( self, search_term: str, limit: int = 20, categories: Optional[List[int]] = None, enriched: bool = False - ) -> Union[List[str], List[Dict[str, Any]]]: - """Search custom words with priority-based ranking. - - When categories are provided or enriched is True, uses TagFTSIndex to search - the Danbooru/e621 tag database and returns enriched results with category - and post_count. - - Matching priority (for custom words): - 1. Words with priority (sorted by priority descending) - 2. Prefix matches (word starts with search term) - 3. Include matches (word contains search term) + ) -> List[Dict[str, Any]]: + """Search tags using TagFTSIndex with category filtering. Args: search_term: The search term to match against. limit: Maximum number of results to return. categories: Optional list of category IDs to filter by. - When provided, searches TagFTSIndex instead of custom words. - enriched: If True, return enriched results even without category filtering. + enriched: If True, always return enriched results with category + and post_count (default behavior now). Returns: - List of matching word texts (when categories is None and enriched is False), or - List of dicts with tag_name, category, post_count (when categories is provided - or enriched is True). + List of dicts with tag_name, category, and post_count. """ - # Use TagFTSIndex when categories are specified or when explicitly requested tag_index = self._get_tag_index() if tag_index is not None: - # Search the tag database results = tag_index.search(search_term, categories=categories, limit=limit) - if results: - # If categories were specified or enriched requested, return enriched results - if categories is not None or enriched: - return results - # Otherwise, convert to simple string list for backward compatibility - return [r["tag_name"] for r in results] - # Fall through to custom words if no tag results + return results - # Fall back to custom words search - words = self._words_cache if self._words_cache else self.load_words() - - if not search_term: - term_lower = "" - else: - term_lower = search_term.lower() - - priority_matches = [] - prefix_matches = [] - include_matches = [] - - for text, entry in words.items(): - text_lower = text.lower() - pos = text_lower.find(term_lower) - - if pos == -1: - continue - - if entry.priority is not None: - priority_matches.append((entry, pos)) - elif pos == 0: - prefix_matches.append((entry, pos)) - else: - include_matches.append((entry, pos)) - - # Sort priority matches: by priority desc, then by length asc, then alphabetically - priority_matches.sort( - key=lambda x: (-x[0].priority if x[0].priority else 0, len(x[0].text), x[0].text) - ) - - # Sort prefix and include matches by position, then length, then alphabetically - prefix_matches.sort(key=lambda x: (x[1], len(x[0].text), x[0].text)) - include_matches.sort(key=lambda x: (x[1], len(x[0].text), x[0].text)) - - # Combine results: 20% top priority + all prefix matches + rest of priority + all include - top_priority_count = max(1, limit // 5) - - text_results = ( - [entry.text for entry, _ in priority_matches[:top_priority_count]] - + [entry.text for entry, _ in prefix_matches] - + [entry.text for entry, _ in priority_matches[top_priority_count:]] - + [entry.text for entry, _ in include_matches] - ) - - # If categories were requested but tag index failed, return empty enriched format - if categories is not None: - return [{"tag_name": t, "category": 0, "post_count": 0} for t in text_results[:limit]] - - return text_results[:limit] - - def save_words(self, content: str) -> bool: - """Save custom words content to file. - - Args: - content: CSV-formatted content to save. - - Returns: - True if save was successful, False otherwise. - """ - if self._file_path is None: - logger.error("No file path configured for custom words") - return False - - try: - self._file_path.write_text(content, encoding="utf-8") - self._words_cache = self._parse_csv_content(content) - logger.info(f"Saved {len(self._words_cache)} custom words") - return True - except Exception as e: - logger.error(f"Error saving custom words: {e}", exc_info=True) - return False - - def get_content(self) -> str: - """Get the raw content of the custom words file. - - Returns: - The file content as a string, or empty string if file doesn't exist. - """ - if self._file_path is None or not self._file_path.exists(): - return "" - - try: - return self._file_path.read_text(encoding="utf-8") - except Exception as e: - logger.error(f"Error reading custom words file: {e}", exc_info=True) - return "" + logger.debug("TagFTSIndex not available, returning empty results") + return [] def get_custom_words_service() -> CustomWordsService: @@ -313,4 +88,4 @@ def get_custom_words_service() -> CustomWordsService: return CustomWordsService.get_instance() -__all__ = ["CustomWordsService", "WordEntry", "get_custom_words_service"] +__all__ = ["CustomWordsService", "get_custom_words_service"] diff --git a/tests/test_custom_words_service.py b/tests/test_custom_words_service.py index 19406bda..c43836de 100644 --- a/tests/test_custom_words_service.py +++ b/tests/test_custom_words_service.py @@ -1,62 +1,8 @@ -"""Tests for CustomWordsService.""" +"""Tests for CustomWordsService with TagFTSIndex integration.""" import pytest -from tempfile import NamedTemporaryFile -from pathlib import Path -from py.services.custom_words_service import CustomWordsService, WordEntry, get_custom_words_service - - -@pytest.fixture -def temp_autocomplete_file(): - """Create a temporary autocomplete.txt file.""" - import os - import tempfile - fd, path = tempfile.mkstemp(suffix='.txt') - try: - os.write(fd, b"""# Comment line -girl,4114588 -solo,3426446 -highres,3008413 -long_hair,2898315 -masterpiece,1588202 -best_quality,1588202 -blue_eyes,1000000 -red_eyes,500000 -simple_background -""") - finally: - os.close(fd) - yield Path(path) - os.unlink(path) - - -@pytest.fixture -def service(temp_autocomplete_file, monkeypatch): - """Create a CustomWordsService instance with temporary file.""" - # Monkey patch to use temp file - service = CustomWordsService.__new__(CustomWordsService) - - def mock_determine_path(): - service._file_path = temp_autocomplete_file - - monkeypatch.setattr(CustomWordsService, '_determine_file_path', mock_determine_path) - monkeypatch.setattr(service, '_file_path', temp_autocomplete_file) - - service.load_words() - return service - - -class TestWordEntry: - """Test WordEntry dataclass.""" - - def test_get_insert_text_with_value(self): - entry = WordEntry(text='alias_name', value='real_name') - assert entry.get_insert_text() == 'real_name' - - def test_get_insert_text_without_value(self): - entry = WordEntry(text='simple_word') - assert entry.get_insert_text() == 'simple_word' +from py.services.custom_words_service import CustomWordsService, get_custom_words_service class TestCustomWordsService: @@ -67,131 +13,99 @@ class TestCustomWordsService: service2 = get_custom_words_service() assert service1 is service2 - def test_parse_csv_content_basic(self): - service = CustomWordsService.__new__(CustomWordsService) - words = service._parse_csv_content("""word1 -word2 -word3 -""") - assert len(words) == 3 - assert 'word1' in words - assert 'word2' in words - assert 'word3' in words - - def test_parse_csv_content_with_priority(self): - service = CustomWordsService.__new__(CustomWordsService) - words = service._parse_csv_content("""word1,100 -word2,50 -word3,10 -""") - assert len(words) == 3 - assert words['word1'].priority == 100 - assert words['word2'].priority == 50 - assert words['word3'].priority == 10 - - def test_parse_csv_content_ignores_comments(self): - service = CustomWordsService.__new__(CustomWordsService) - words = service._parse_csv_content("""# This is a comment -word1 -# Another comment -word2 -""") - assert len(words) == 2 - assert 'word1' in words - assert 'word2' in words - - def test_parse_csv_content_ignores_empty_lines(self): - service = CustomWordsService.__new__(CustomWordsService) - words = service._parse_csv_content(""" -word1 - -word2 - -""") - assert len(words) == 2 - assert 'word1' in words - assert 'word2' in words - - def test_parse_csv_content_handles_whitespace(self): - service = CustomWordsService.__new__(CustomWordsService) - words = service._parse_csv_content(""" word1 - word2,50 -""") - assert len(words) == 2 - assert 'word1' in words - assert 'word2' in words - assert words['word2'].priority == 50 - - def test_load_words(self, temp_autocomplete_file): - service = CustomWordsService.__new__(CustomWordsService) - service._file_path = temp_autocomplete_file - words = service.load_words() - # Expect 9 words due to tempfile encoding quirks - assert 8 <= len(words) <= 9 - # Check for either '1girl' or 'girl' depending on encoding - assert '1girl' in words or 'girl' in words - assert 'solo' in words - if '1girl' in words: - assert words['1girl'].priority == 4114588 - if 'girl' in words: - assert words['girl'].priority == 4114588 - assert words['solo'].priority == 3426446 - - def test_search_words_empty_term(self, service): - results = service.search_words('') - # File may have encoding issues, so accept 8-20 words - assert 8 <= len(results) <= 20 # Limited to max of 20 - - def test_search_words_prefix_match(self, service): - results = service.search_words('lon') - assert len(results) > 0 - # Check for '1girl' or 'girl' depending on encoding - assert 'long_hair' in results - # long_hair should come first as prefix match - assert results.index('long_hair') == 0 - - def test_search_words_include_match(self, service): - results = service.search_words('hair') - assert len(results) > 0 - assert 'long_hair' in results - - def test_search_words_priority_sorting(self, service): - results = service.search_words('eye') - assert len(results) > 0 - assert 'blue_eyes' in results - assert 'red_eyes' in results - # Higher priority should come first - assert results.index('blue_eyes') < results.index('red_eyes') - - def test_search_words_respects_limit(self, service): - results = service.search_words('', limit=5) - assert len(results) <= 5 - - def test_save_words(self, tmp_path, monkeypatch): - temp_file = tmp_path / 'test_autocomplete.txt' + def test_search_words_without_tag_index(self): service = CustomWordsService.__new__(CustomWordsService) - monkeypatch.setattr(service, '_file_path', temp_file) + def mock_get_index(): + return None - content = 'test_word,100' - success = service.save_words(content) - assert success is True - assert temp_file.exists() + service._get_tag_index = mock_get_index - saved_content = temp_file.read_text(encoding='utf-8') - assert saved_content == content + results = service.search_words("test", limit=10) + assert results == [] - def test_get_content_no_file(self, tmp_path, monkeypatch): - non_existent_file = tmp_path / 'nonexistent.txt' + def test_search_words_with_tag_index(self): service = CustomWordsService.__new__(CustomWordsService) - monkeypatch.setattr(service, '_file_path', non_existent_file) - content = service.get_content() - assert content == '' + mock_tag_index = MockTagFTSIndex() - def test_get_content_with_file(self, temp_autocomplete_file, monkeypatch): + def mock_get_index(): + return mock_tag_index + + service._get_tag_index = mock_get_index + + results = service.search_words("miku", limit=20) + assert len(results) == 2 + assert results[0]["tag_name"] == "hatsune_miku" + assert results[0]["category"] == 4 + assert results[0]["post_count"] == 500000 + + def test_search_words_with_category_filter(self): service = CustomWordsService.__new__(CustomWordsService) - monkeypatch.setattr(service, '_file_path', temp_autocomplete_file) - content = service.get_content() - # Content may have escaped newlines in string representation - assert 'girl' in content or '1girl' in content - assert 'solo' in content + mock_tag_index = MockTagFTSIndex() + + def mock_get_index(): + return mock_tag_index + + service._get_tag_index = mock_get_index + + results = service.search_words("miku", categories=[4, 11], limit=20) + assert len(results) == 2 + assert results[0]["tag_name"] == "hatsune_miku" + assert results[0]["category"] == 4 + assert results[1]["tag_name"] == "hatsune_miku_(vocaloid)" + assert results[1]["category"] == 4 + + def test_search_words_respects_limit(self): + service = CustomWordsService.__new__(CustomWordsService) + mock_tag_index = MockTagFTSIndex() + + def mock_get_index(): + return mock_tag_index + + service._get_tag_index = mock_get_index + + results = service.search_words("miku", limit=1) + assert len(results) <= 1 + + def test_search_words_empty_term(self): + service = CustomWordsService.__new__(CustomWordsService) + mock_tag_index = MockTagFTSIndex() + + def mock_get_index(): + return mock_tag_index + + service._get_tag_index = mock_get_index + + results = service.search_words("", limit=20) + assert results == [] + + def test_search_words_uses_tag_index(self): + service = CustomWordsService.__new__(CustomWordsService) + mock_tag_index = MockTagFTSIndex() + + def mock_get_index(): + return mock_tag_index + + service._get_tag_index = mock_get_index + + results = service.search_words("test") + assert mock_tag_index.called + + +class MockTagFTSIndex: + """Mock TagFTSIndex for testing.""" + + def __init__(self): + self.called = False + self._results = [ + {"tag_name": "hatsune_miku", "category": 4, "post_count": 500000}, + {"tag_name": "hatsune_miku_(vocaloid)", "category": 4, "post_count": 250000}, + ] + + def search(self, query, categories=None, limit=20): + self.called = True + if not query: + return [] + if categories: + return [r for r in self._results if r["category"] in categories][:limit] + return self._results[:limit]