diff --git a/py/services/tag_fts_index.py b/py/services/tag_fts_index.py index c3f6cc9b..867179db 100644 --- a/py/services/tag_fts_index.py +++ b/py/services/tag_fts_index.py @@ -2,6 +2,10 @@ This module provides fast tag search using SQLite's FTS5 extension, enabling sub-100ms search times for 221k+ Danbooru/e621 tags. + +Supports alias search: when a user searches for an alias (e.g., "miku"), +the system returns the canonical tag (e.g., "hatsune_miku") and indicates +which alias was matched. """ from __future__ import annotations @@ -20,6 +24,9 @@ from ..utils.cache_paths import CacheType, resolve_cache_path_with_migration logger = logging.getLogger(__name__) +# Schema version for tracking migrations +SCHEMA_VERSION = 2 # Version 2: Added aliases support + # Category definitions for Danbooru and e621 CATEGORY_NAMES = { @@ -131,19 +138,25 @@ class TagFTSIndex: conn = self._connect() try: conn.execute("PRAGMA journal_mode=WAL") + + # Check if we need to migrate from old schema + needs_rebuild = self._check_and_migrate_schema(conn) + conn.executescript(""" -- FTS5 virtual table for full-text search + -- searchable_text contains "tag_name alias1 alias2 ..." for alias matching CREATE VIRTUAL TABLE IF NOT EXISTS tag_fts USING fts5( - tag_name, + searchable_text, tokenize='unicode61 remove_diacritics 2' ); - -- Tags table with metadata + -- Tags table with metadata and aliases CREATE TABLE IF NOT EXISTS tags ( rowid INTEGER PRIMARY KEY, tag_name TEXT UNIQUE NOT NULL, category INTEGER NOT NULL DEFAULT 0, - post_count INTEGER NOT NULL DEFAULT 0 + post_count INTEGER NOT NULL DEFAULT 0, + aliases TEXT DEFAULT '' ); -- Indexes for efficient filtering @@ -156,19 +169,77 @@ class TagFTSIndex: value TEXT ); """) + + # Set schema version + conn.execute( + "INSERT OR REPLACE INTO fts_metadata (key, value) VALUES (?, ?)", + ("schema_version", str(SCHEMA_VERSION)) + ) conn.commit() + self._schema_initialized = True + self._needs_rebuild = needs_rebuild logger.debug("Tag FTS index schema initialized at %s", self._db_path) finally: conn.close() except Exception as exc: logger.error("Failed to initialize tag FTS schema: %s", exc) + def _check_and_migrate_schema(self, conn: sqlite3.Connection) -> bool: + """Check schema version and migrate if necessary. + + Returns: + True if the index needs to be rebuilt, False otherwise. + """ + try: + # Check if fts_metadata table exists + cursor = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='fts_metadata'" + ) + if not cursor.fetchone(): + return False # Fresh database, no migration needed + + # Check schema version + cursor = conn.execute( + "SELECT value FROM fts_metadata WHERE key='schema_version'" + ) + row = cursor.fetchone() + if not row: + # Old schema without version, needs rebuild + logger.info("Migrating tag FTS index to schema version %d (adding alias support)", SCHEMA_VERSION) + self._drop_old_tables(conn) + return True + + current_version = int(row[0]) + if current_version < SCHEMA_VERSION: + logger.info("Migrating tag FTS index from version %d to %d", current_version, SCHEMA_VERSION) + self._drop_old_tables(conn) + return True + + return False + except Exception as exc: + logger.warning("Error checking schema version: %s", exc) + return False + + def _drop_old_tables(self, conn: sqlite3.Connection) -> None: + """Drop old tables for schema migration.""" + try: + conn.executescript(""" + DROP TABLE IF EXISTS tag_fts; + DROP TABLE IF EXISTS tags; + """) + conn.commit() + except Exception as exc: + logger.warning("Error dropping old tables: %s", exc) + def build_index(self) -> None: """Build the FTS index from the CSV file. This method parses the danbooru_e621_merged.csv file and creates - the FTS index for fast searching. + the FTS index for fast searching. The CSV format is: + tag_name,category,post_count,aliases + + Where aliases is a comma-separated string (e.g., "miku,vocaloid_miku,39"). """ if self._indexing_in_progress: logger.warning("Tag FTS indexing already in progress, skipping") @@ -201,6 +272,7 @@ class TagFTSIndex: batch_size = 500 rows = [] total_inserted = 0 + tags_with_aliases = 0 with open(self._csv_path, "r", encoding="utf-8") as f: reader = csv.reader(f) @@ -222,7 +294,12 @@ class TagFTSIndex: except (ValueError, IndexError): post_count = 0 - rows.append((tag_name, category, post_count)) + # Parse aliases from column 4 (if present) + aliases = row[3].strip() if len(row) >= 4 else "" + if aliases: + tags_with_aliases += 1 + + rows.append((tag_name, category, post_count, aliases)) if len(rows) >= batch_size: self._insert_batch(conn, rows) @@ -243,10 +320,17 @@ class TagFTSIndex: "INSERT OR REPLACE INTO fts_metadata (key, value) VALUES (?, ?)", ("tag_count", str(total_inserted)) ) + conn.execute( + "INSERT OR REPLACE INTO fts_metadata (key, value) VALUES (?, ?)", + ("schema_version", str(SCHEMA_VERSION)) + ) conn.commit() elapsed = time.time() - start_time - logger.info("Tag FTS index built: %d tags indexed in %.2fs", total_inserted, elapsed) + logger.info( + "Tag FTS index built: %d tags indexed (%d with aliases) in %.2fs", + total_inserted, tags_with_aliases, elapsed + ) finally: conn.close() @@ -258,14 +342,22 @@ class TagFTSIndex: self._indexing_in_progress = False def _insert_batch(self, conn: sqlite3.Connection, rows: List[tuple]) -> None: - """Insert a batch of rows into the database.""" - # Insert into tags table + """Insert a batch of rows into the database. + + Each row is a tuple of (tag_name, category, post_count, aliases). + The FTS searchable_text is built as "tag_name alias1 alias2 ..." for alias matching. + """ + # Insert into tags table (with aliases) conn.executemany( - "INSERT OR IGNORE INTO tags (tag_name, category, post_count) VALUES (?, ?, ?)", + "INSERT OR IGNORE INTO tags (tag_name, category, post_count, aliases) VALUES (?, ?, ?, ?)", rows ) - # Get rowids and insert into FTS table + # Build a map of tag_name -> aliases for FTS insertion + aliases_map = {row[0]: row[3] for row in rows} + + # Get rowids and insert into FTS table with explicit rowid + # to ensure tags.rowid matches tag_fts.rowid for JOINs tag_names = [row[0] for row in rows] placeholders = ",".join("?" * len(tag_names)) cursor = conn.execute( @@ -273,9 +365,27 @@ class TagFTSIndex: tag_names ) - fts_rows = [(tag_name,) for rowid, tag_name in cursor.fetchall()] + # Build FTS rows with (rowid, searchable_text) = (tags.rowid, "tag_name alias1 alias2 ...") + fts_rows = [] + for rowid, tag_name in cursor.fetchall(): + aliases = aliases_map.get(tag_name, "") + if aliases: + # Replace commas with spaces to create searchable text + # Strip "/" prefix from aliases as it's an FTS5 special character + alias_parts = [] + for alias in aliases.split(","): + alias = alias.strip() + if alias.startswith("/"): + alias = alias[1:] # Remove leading slash + if alias: + alias_parts.append(alias) + searchable_text = f"{tag_name} {' '.join(alias_parts)}" if alias_parts else tag_name + else: + searchable_text = tag_name + fts_rows.append((rowid, searchable_text)) + if fts_rows: - conn.executemany("INSERT INTO tag_fts (tag_name) VALUES (?)", fts_rows) + conn.executemany("INSERT INTO tag_fts (rowid, searchable_text) VALUES (?, ?)", fts_rows) def ensure_ready(self) -> bool: """Ensure the index is ready, building if necessary. @@ -289,6 +399,13 @@ class TagFTSIndex: # Check if index already exists and has data self.initialize() if self._schema_initialized: + # Check if schema migration requires rebuild + if getattr(self, "_needs_rebuild", False): + logger.info("Schema migration requires index rebuild") + self._needs_rebuild = False + self.build_index() + return self.is_ready() + count = self.get_indexed_count() if count > 0: self._ready.set() @@ -307,13 +424,17 @@ class TagFTSIndex: ) -> List[Dict]: """Search tags using FTS5 with prefix matching. + Supports alias search: if the query matches an alias rather than + the tag_name, the result will include a "matched_alias" field. + Args: query: The search query string. categories: Optional list of category IDs to filter by. limit: Maximum number of results to return. Returns: - List of dictionaries with tag_name, category, and post_count. + List of dictionaries with tag_name, category, post_count, + and optionally matched_alias. """ # Ensure index is ready (lazy initialization) if not self.ensure_ready(): @@ -333,14 +454,15 @@ class TagFTSIndex: with self._lock: conn = self._connect(readonly=True) try: - # Build the SQL query + # Build the SQL query - now also fetch aliases for matched_alias detection + # Use subquery for category filter to ensure FTS is evaluated first if categories: placeholders = ",".join("?" * len(categories)) sql = f""" - SELECT t.tag_name, t.category, t.post_count + SELECT t.tag_name, t.category, t.post_count, t.aliases FROM tags t - WHERE t.tag_name IN ( - SELECT tag_name FROM tag_fts WHERE tag_fts MATCH ? + WHERE t.rowid IN ( + SELECT rowid FROM tag_fts WHERE searchable_text MATCH ? ) AND t.category IN ({placeholders}) ORDER BY t.post_count DESC @@ -349,11 +471,10 @@ class TagFTSIndex: params = [fts_query] + categories + [limit] else: sql = """ - SELECT t.tag_name, t.category, t.post_count - FROM tags t - WHERE t.tag_name IN ( - SELECT tag_name FROM tag_fts WHERE tag_fts MATCH ? - ) + SELECT t.tag_name, t.category, t.post_count, t.aliases + FROM tag_fts f + JOIN tags t ON f.rowid = t.rowid + WHERE f.searchable_text MATCH ? ORDER BY t.post_count DESC LIMIT ? """ @@ -362,11 +483,18 @@ class TagFTSIndex: cursor = conn.execute(sql, params) results = [] for row in cursor.fetchall(): - results.append({ + result = { "tag_name": row[0], "category": row[1], "post_count": row[2], - }) + } + + # Check if search matched an alias rather than the tag_name + matched_alias = self._find_matched_alias(query, row[0], row[3]) + if matched_alias: + result["matched_alias"] = matched_alias + + results.append(result) return results finally: conn.close() @@ -374,6 +502,59 @@ class TagFTSIndex: logger.debug("Tag FTS search error for query '%s': %s", query, exc) return [] + def _find_matched_alias(self, query: str, tag_name: str, aliases_str: str) -> Optional[str]: + """Find which alias matched the query, if any. + + Args: + query: The original search query. + tag_name: The canonical tag name. + aliases_str: Comma-separated string of aliases. + + Returns: + The matched alias string, or None if the query matched the tag_name directly. + """ + query_lower = query.lower().strip() + if not query_lower: + return None + + # Strip leading "/" from query if present (FTS index strips these) + query_normalized = query_lower.lstrip("/") + + # Check if query matches tag_name prefix (direct match, no alias needed) + if tag_name.lower().startswith(query_normalized): + return None + + # Check aliases first - if query matches an alias or a word within an alias, return it + if aliases_str: + for alias in aliases_str.split(","): + alias = alias.strip() + if not alias: + continue + # Normalize alias for comparison (strip leading slash) + alias_normalized = alias.lower().lstrip("/") + + # Check if alias starts with query + if alias_normalized.startswith(query_normalized): + return alias # Return original alias (with "/" if present) + + # Check if any word within the alias starts with query + # (mirrors FTS5 tokenization which splits on underscores) + alias_words = alias_normalized.replace("_", " ").split() + for word in alias_words: + if word.startswith(query_normalized): + return alias + + # If no alias matched, check if query matches a word in tag_name + # (handles cases like "long_hair" matching "long" - no alias indicator needed) + tag_words = tag_name.lower().replace("_", " ").split() + for word in tag_words: + if word.startswith(query_normalized): + return None + + # Query matched via FTS but not tag_name words or aliases + # This shouldn't normally happen, but return None for safety + return None + def get_indexed_count(self) -> int: """Return the number of tags currently indexed.""" if not self._schema_initialized: @@ -458,14 +639,15 @@ class TagFTSIndex: def _escape_fts_query(self, text: str) -> str: """Escape special FTS5 characters. - FTS5 special characters: " ( ) * : ^ - + FTS5 special characters: " ( ) * : ^ - / We keep * for prefix matching but escape others. """ if not text: return "" # Replace FTS5 special characters with space - special = ['"', "(", ")", "*", ":", "^", "-", "{", "}", "[", "]"] + # Note: "/" is special in FTS5 (column filter syntax), so we strip it + special = ['"', "(", ")", "*", ":", "^", "-", "{", "}", "[", "]", "/"] result = text for char in special: result = result.replace(char, " ") diff --git a/tests/test_tag_fts_index.py b/tests/test_tag_fts_index.py index e0c2e76f..73789352 100644 --- a/tests/test_tag_fts_index.py +++ b/tests/test_tag_fts_index.py @@ -188,6 +188,171 @@ class TestTagFTSIndexSearch: assert len(results) <= 1 +class TestAliasSearch: + """Tests for alias search functionality.""" + + @pytest.fixture + def populated_fts(self, temp_db_path, temp_csv_path): + """Create a populated FTS index.""" + fts = TagFTSIndex(db_path=temp_db_path, csv_path=temp_csv_path) + fts.build_index() + return fts + + def test_search_by_alias_returns_canonical_tag(self, populated_fts): + """Test that searching by alias returns the canonical tag with matched_alias.""" + # Search for "miku" which is an alias for "hatsune_miku" + results = populated_fts.search("miku") + + assert len(results) >= 1 + hatsune_result = next((r for r in results if r["tag_name"] == "hatsune_miku"), None) + assert hatsune_result is not None + assert hatsune_result["matched_alias"] == "miku" + + def test_search_by_canonical_name_no_matched_alias(self, populated_fts): + """Test that searching by canonical name does not set matched_alias.""" + # Search for "hatsune" which directly matches "hatsune_miku" + results = populated_fts.search("hatsune") + + assert len(results) >= 1 + hatsune_result = next((r for r in results if r["tag_name"] == "hatsune_miku"), None) + assert hatsune_result is not None + assert "matched_alias" not in hatsune_result + + def test_search_by_prefix_alias(self, populated_fts): + """Test prefix matching on aliases.""" + # "1girls" is an alias for "1girl" - search by prefix "1gir" + results = populated_fts.search("1gir") + + assert len(results) >= 1 + result = next((r for r in results if r["tag_name"] == "1girl"), None) + assert result is not None + # Should not have matched_alias since "1girl" starts with "1gir" + assert "matched_alias" not in result + + def test_alias_search_with_category_filter(self, populated_fts): + """Test that alias search works with category filtering.""" + # Search for "youmu" (alias for konpaku_youmu) with character category filter + results = populated_fts.search("youmu", categories=[4, 11]) + + assert len(results) >= 1 + result = results[0] + assert result["tag_name"] == "konpaku_youmu" + assert result["matched_alias"] == "youmu" + assert result["category"] in [4, 11] + + def test_tag_without_aliases_still_works(self, populated_fts): + """Test that tags without aliases still work correctly.""" + # "artist_request" has no aliases + results = populated_fts.search("artist_req") + + assert len(results) >= 1 + result = next((r for r in results if r["tag_name"] == "artist_request"), None) + assert result is not None + assert "matched_alias" not in result + + def test_multiple_aliases_first_match_returned(self, populated_fts): + """Test that when multiple aliases could match, the first one is returned.""" + # "highres" has aliases: "high_res,high_resolution,hires" + # Searching for "high_r" should match "high_res" first + results = populated_fts.search("high_r") + + assert len(results) >= 1 + highres_result = next((r for r in results if r["tag_name"] == "highres"), None) + assert highres_result is not None + assert highres_result["matched_alias"] == "high_res" + + def test_search_by_short_alias(self, populated_fts): + """Test searching by a short alias.""" + # "/lh" style short aliases - using "hires" which is short for highres + results = populated_fts.search("hires") + + assert len(results) >= 1 + result = next((r for r in results if r["tag_name"] == "highres"), None) + assert result is not None + assert result["matched_alias"] == "hires" + + def test_search_by_word_within_alias(self, populated_fts): + """Test searching by a word within a compound alias like 'sole_female'.""" + # "sole_female" is an alias for "1girl" + # Searching "female" should match "1girl" with matched_alias "sole_female" + results = populated_fts.search("female") + + assert len(results) >= 1 + result = next((r for r in results if r["tag_name"] == "1girl"), None) + assert result is not None + assert result["matched_alias"] == "sole_female" + + def test_search_by_second_word_in_alias(self, populated_fts): + """Test that searching for second word in underscore-separated alias works.""" + # "female_solo" is an alias for "solo" + # Searching "solo" would match the tag directly, but let's test another case + # "single" is an alias for "solo" - straightforward match + results = populated_fts.search("single") + + assert len(results) >= 1 + result = next((r for r in results if r["tag_name"] == "solo"), None) + assert result is not None + assert result["matched_alias"] == "single" + + +class TestSlashPrefixAliases: + """Tests for slash-prefixed alias search (e.g., /lh for long_hair).""" + + @pytest.fixture + def fts_with_slash_aliases(self, temp_db_path): + """Create an FTS index with slash-prefixed aliases.""" + with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False, encoding="utf-8") as f: + # Format: tag_name,category,post_count,aliases + f.write('long_hair,0,4350743,"/lh,longhair,very_long_hair"\n') + f.write('breasts,0,3439214,"/b,boobs,oppai"\n') + f.write('short_hair,0,1500000,"/sh,shorthair"\n') + csv_path = f.name + + try: + fts = TagFTSIndex(db_path=temp_db_path, csv_path=csv_path) + fts.build_index() + yield fts + finally: + if os.path.exists(csv_path): + os.unlink(csv_path) + + def test_search_slash_alias_with_slash(self, fts_with_slash_aliases): + """Test searching with slash prefix returns correct result.""" + results = fts_with_slash_aliases.search("/lh") + + assert len(results) >= 1 + result = results[0] + assert result["tag_name"] == "long_hair" + assert result["matched_alias"] == "/lh" + + def test_search_slash_alias_without_slash(self, fts_with_slash_aliases): + """Test searching without slash prefix also works.""" + results = fts_with_slash_aliases.search("lh") + + assert len(results) >= 1 + result = results[0] + assert result["tag_name"] == "long_hair" + assert result["matched_alias"] == "/lh" + + def test_search_regular_alias_still_works(self, fts_with_slash_aliases): + """Test that non-slash aliases still work.""" + results = fts_with_slash_aliases.search("longhair") + + assert len(results) >= 1 + result = results[0] + assert result["tag_name"] == "long_hair" + assert result["matched_alias"] == "longhair" + + def test_direct_tag_name_search(self, fts_with_slash_aliases): + """Test that direct tag name search doesn't show alias.""" + results = fts_with_slash_aliases.search("long_hair") + + assert len(results) >= 1 + result = results[0] + assert result["tag_name"] == "long_hair" + assert "matched_alias" not in result + + class TestTagFTSIndexClear: """Tests for clearing the FTS index.""" diff --git a/web/comfyui/autocomplete.js b/web/comfyui/autocomplete.js index e0fe7ebb..11653538 100644 --- a/web/comfyui/autocomplete.js +++ b/web/comfyui/autocomplete.js @@ -841,14 +841,31 @@ class AutoComplete { /** * Render an enriched autocomplete item with category badge and post count * @param {HTMLElement} itemEl - The item element to populate - * @param {Object} itemData - The enriched item data { tag_name, category, post_count } + * @param {Object} itemData - The enriched item data { tag_name, category, post_count, matched_alias? } * @param {string} searchTerm - The current search term for highlighting */ _renderEnrichedItem(itemEl, itemData, searchTerm) { // Create name span with highlighted match const nameSpan = document.createElement('span'); nameSpan.className = 'lm-autocomplete-name'; - nameSpan.innerHTML = this.highlightMatch(itemData.tag_name, searchTerm); + + // If matched via alias, show: "tag_name ← alias" with alias highlighted + if (itemData.matched_alias) { + const tagText = document.createTextNode(itemData.tag_name + ' '); + nameSpan.appendChild(tagText); + + const aliasSpan = document.createElement('span'); + aliasSpan.className = 'lm-matched-alias'; + aliasSpan.innerHTML = '← ' + this.highlightMatch(itemData.matched_alias, searchTerm); + aliasSpan.style.cssText = ` + font-size: 11px; + color: rgba(226, 232, 240, 0.5); + `; + nameSpan.appendChild(aliasSpan); + } else { + nameSpan.innerHTML = this.highlightMatch(itemData.tag_name, searchTerm); + } + nameSpan.style.cssText = ` flex: 1; min-width: 0;