feat(autocomplete): add Danbooru/e621 tag search with category filtering

- Add TagFTSIndex service for fast SQLite FTS5-based tag search (221k+ tags)
- Implement command-mode autocomplete: /char, /artist, /general, /meta, etc.
- Support category filtering via category IDs or names
- Return enriched results with post counts and category badges
- Add UI styling for category badges and command list dropdown
This commit is contained in:
Will Miao
2026-01-26 13:51:45 +08:00
parent 6142b3dc0c
commit 42f35be9d3
8 changed files with 223183 additions and 33 deletions

View File

@@ -0,0 +1,69 @@
# Danbooru/E621 Tag Categories Reference
Reference for the category values used in the `danbooru_e621_merged.csv` tag file.
## Category Value Mapping
### Danbooru Categories
| Value | Description |
|-------|-------------|
| 0 | General |
| 1 | Artist |
| 2 | *(unused)* |
| 3 | Copyright |
| 4 | Character |
| 5 | Meta |
### e621 Categories
| Value | Description |
|-------|-------------|
| 6 | *(unused)* |
| 7 | General |
| 8 | Artist |
| 9 | Contributor |
| 10 | Copyright |
| 11 | Character |
| 12 | Species |
| 13 | *(unused)* |
| 14 | Meta |
| 15 | Lore |
## Danbooru Category Colors
| Description | Normal Color | Hover Color |
|-------------|--------------|-------------|
| General | #009be6 | #4bb4ff |
| Artist | #ff8a8b | #ffc3c3 |
| Copyright | #c797ff | #ddc9fb |
| Character | #35c64a | #93e49a |
| Meta | #ead084 | #f7e7c3 |
## CSV Column Structure
Each row in the merged CSV file contains 4 columns:
| Column | Description | Example |
|--------|-------------|---------|
| 1 | Tag name | `1girl`, `highres`, `solo` |
| 2 | Category value (0-15) | `0`, `5`, `7` |
| 3 | Post count | `6008644`, `5256195` |
| 4 | Aliases (comma-separated, quoted) | `"1girls,sole_female"`, empty string |
### Sample Data
```
1girl,0,6008644,"1girls,sole_female"
highres,5,5256195,"high_res,high_resolution,hires"
solo,0,5000954,"alone,female_solo,single,solo_female"
long_hair,0,4350743,"/lh,longhair"
mammal,12,3437444,"cetancodont,cetancodontamorph,feralmammal"
anthro,7,3381927,"adult_anthro,anhtro,antho,anthro_horse"
skirt,0,1557883,
```
## Source
- [PR #312: Add danbooru_e621_merged.csv](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete/pull/312)
- [DraconicDragon/dbr-e621-lists-archive](https://github.com/DraconicDragon/dbr-e621-lists-archive)

View File

@@ -1231,12 +1231,31 @@ class CustomWordsHandler:
return web.json_response({"error": str(exc)}, status=500)
async def search_custom_words(self, request: web.Request) -> web.Response:
"""Search custom words with autocomplete."""
"""Search custom words with autocomplete.
Query parameters:
search: The search term to match against.
limit: Maximum number of results to return (default: 20).
category: Optional category filter. Can be:
- A category name (e.g., "character", "artist", "general")
- Comma-separated category IDs (e.g., "4,11" for character)
enriched: If "true", return enriched results with category and post_count
even without category filtering.
"""
try:
search_term = request.query.get("search", "")
limit = int(request.query.get("limit", "20"))
category_param = request.query.get("category", "")
enriched_param = request.query.get("enriched", "").lower() == "true"
results = self._service.search_words(search_term, limit)
# Parse category parameter
categories = None
if category_param:
categories = self._parse_category_param(category_param)
results = self._service.search_words(
search_term, limit, categories=categories, enriched=enriched_param
)
return web.json_response({
"success": True,
@@ -1246,6 +1265,37 @@ class CustomWordsHandler:
logger.error("Error searching custom words: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
def _parse_category_param(self, param: str) -> list[int] | None:
"""Parse category parameter into list of category IDs.
Args:
param: Category parameter value (name or comma-separated IDs).
Returns:
List of category IDs, or None if parsing fails.
"""
from ...services.tag_fts_index import CATEGORY_NAME_TO_IDS
param = param.strip().lower()
if not param:
return None
# Try to parse as category name first
if param in CATEGORY_NAME_TO_IDS:
return CATEGORY_NAME_TO_IDS[param]
# Try to parse as comma-separated integers
try:
category_ids = []
for part in param.split(","):
part = part.strip()
if part:
category_ids.append(int(part))
return category_ids if category_ids else None
except ValueError:
logger.debug("Invalid category parameter: %s", param)
return None
class NodeRegistryHandler:
def __init__(

View File

@@ -2,6 +2,9 @@
This service provides functionality to parse CSV-formatted custom words,
search them with priority-based ranking, and manage storage.
It also integrates with TagFTSIndex to search the Danbooru/e621 tag database
for comprehensive autocomplete suggestions with category filtering.
"""
from __future__ import annotations
@@ -10,7 +13,7 @@ import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Any, Optional
from typing import List, Dict, Any, Optional, Union
logger = logging.getLogger(__name__)
@@ -35,6 +38,7 @@ class CustomWordsService:
- Parses CSV format: word[,priority] or word[,alias][,priority]
- Searches words with priority-based ranking
- Caches parsed words for performance
- Integrates with TagFTSIndex for Danbooru/e621 tag search
"""
_instance: Optional[CustomWordsService] = None
@@ -51,6 +55,7 @@ class CustomWordsService:
self._words_cache: Dict[str, WordEntry] = {}
self._file_path: Optional[Path] = None
self._tag_index: Optional[Any] = None # Lazy-loaded TagFTSIndex
self._initialized = True
self._determine_file_path()
@@ -98,6 +103,17 @@ class CustomWordsService:
"""Get the current file path for custom words."""
return self._file_path
def _get_tag_index(self):
"""Get or create the TagFTSIndex instance (lazy initialization)."""
if self._tag_index is None:
try:
from .tag_fts_index import get_tag_fts_index
self._tag_index = get_tag_fts_index()
except Exception as e:
logger.warning(f"Failed to initialize TagFTSIndex: {e}")
self._tag_index = None
return self._tag_index
def load_words(self) -> Dict[str, WordEntry]:
"""Load and parse words from the custom words file.
@@ -160,10 +176,20 @@ class CustomWordsService:
return words
def search_words(self, search_term: str, limit: int = 20) -> List[str]:
def search_words(
self,
search_term: str,
limit: int = 20,
categories: Optional[List[int]] = None,
enriched: bool = False
) -> Union[List[str], List[Dict[str, Any]]]:
"""Search custom words with priority-based ranking.
Matching priority:
When categories are provided or enriched is True, uses TagFTSIndex to search
the Danbooru/e621 tag database and returns enriched results with category
and post_count.
Matching priority (for custom words):
1. Words with priority (sorted by priority descending)
2. Prefix matches (word starts with search term)
3. Include matches (word contains search term)
@@ -171,10 +197,29 @@ class CustomWordsService:
Args:
search_term: The search term to match against.
limit: Maximum number of results to return.
categories: Optional list of category IDs to filter by.
When provided, searches TagFTSIndex instead of custom words.
enriched: If True, return enriched results even without category filtering.
Returns:
List of matching word texts.
List of matching word texts (when categories is None and enriched is False), or
List of dicts with tag_name, category, post_count (when categories is provided
or enriched is True).
"""
# Use TagFTSIndex when categories are specified or when explicitly requested
tag_index = self._get_tag_index()
if tag_index is not None:
# Search the tag database
results = tag_index.search(search_term, categories=categories, limit=limit)
if results:
# If categories were specified or enriched requested, return enriched results
if categories is not None or enriched:
return results
# Otherwise, convert to simple string list for backward compatibility
return [r["tag_name"] for r in results]
# Fall through to custom words if no tag results
# Fall back to custom words search
words = self._words_cache if self._words_cache else self.load_words()
if not search_term:
@@ -212,14 +257,18 @@ class CustomWordsService:
# Combine results: 20% top priority + all prefix matches + rest of priority + all include
top_priority_count = max(1, limit // 5)
results = (
text_results = (
[entry.text for entry, _ in priority_matches[:top_priority_count]]
+ [entry.text for entry, _ in prefix_matches]
+ [entry.text for entry, _ in priority_matches[top_priority_count:]]
+ [entry.text for entry, _ in include_matches]
)
return results[:limit]
# If categories were requested but tag index failed, return empty enriched format
if categories is not None:
return [{"tag_name": t, "category": 0, "post_count": 0} for t in text_results[:limit]]
return text_results[:limit]
def save_words(self, content: str) -> bool:
"""Save custom words content to file.

View File

@@ -0,0 +1,504 @@
"""SQLite FTS5-based full-text search index for tags.
This module provides fast tag search using SQLite's FTS5 extension,
enabling sub-100ms search times for 221k+ Danbooru/e621 tags.
"""
from __future__ import annotations
import csv
import logging
import os
import re
import sqlite3
import threading
import time
from pathlib import Path
from typing import Dict, List, Optional, Set
from ..utils.settings_paths import get_settings_dir
logger = logging.getLogger(__name__)
# Category definitions for Danbooru and e621
CATEGORY_NAMES = {
# Danbooru categories
0: "general",
1: "artist",
3: "copyright",
4: "character",
5: "meta",
# e621 categories
7: "general",
8: "artist",
10: "copyright",
11: "character",
12: "species",
14: "meta",
15: "lore",
}
# Map category names to their IDs (for filtering)
CATEGORY_NAME_TO_IDS = {
"general": [0, 7],
"artist": [1, 8],
"copyright": [3, 10],
"character": [4, 11],
"meta": [5, 14],
"species": [12],
"lore": [15],
}
class TagFTSIndex:
    """SQLite FTS5-based full-text search index for tags.

    Provides fast prefix-based search across the Danbooru/e621 tag database.
    Supports category-based filtering and returns enriched results with
    post counts and category information.
    """

    # Database file created inside the settings directory by default.
    _DEFAULT_FILENAME = "tag_fts.sqlite"
    # Source CSV shipped in the package's refs/ directory.
    _CSV_FILENAME = "danbooru_e621_merged.csv"
def __init__(self, db_path: Optional[str] = None, csv_path: Optional[str] = None) -> None:
"""Initialize the FTS index.
Args:
db_path: Optional path to the SQLite database file.
If not provided, uses the default location in settings directory.
csv_path: Optional path to the CSV file containing tag data.
If not provided, looks in the refs/ directory.
"""
self._db_path = db_path or self._resolve_default_db_path()
self._csv_path = csv_path or self._resolve_default_csv_path()
self._lock = threading.Lock()
self._ready = threading.Event()
self._indexing_in_progress = False
self._schema_initialized = False
self._warned_not_ready = False
# Ensure directory exists
try:
directory = os.path.dirname(self._db_path)
if directory:
os.makedirs(directory, exist_ok=True)
except Exception as exc:
logger.warning("Could not create FTS index directory %s: %s", directory, exc)
    def _resolve_default_db_path(self) -> str:
        """Resolve the default database path.

        The LORA_MANAGER_TAG_FTS_DB environment variable, when set,
        overrides the default settings-directory location entirely.
        """
        override = os.environ.get("LORA_MANAGER_TAG_FTS_DB")
        if override:
            return override
        try:
            settings_dir = get_settings_dir(create=True)
        except Exception as exc:
            # Degrade gracefully: an unusable settings dir should not make
            # tag search impossible, just relocate the database.
            logger.warning("Falling back to current directory for FTS index: %s", exc)
            settings_dir = "."
        return os.path.join(settings_dir, self._DEFAULT_FILENAME)
def _resolve_default_csv_path(self) -> str:
"""Resolve the default CSV file path."""
# Look for the CSV in the refs/ directory relative to the package
package_dir = Path(__file__).parent.parent.parent
csv_path = package_dir / "refs" / self._CSV_FILENAME
return str(csv_path)
    def get_database_path(self) -> str:
        """Return the resolved database path."""
        return self._db_path

    def get_csv_path(self) -> str:
        """Return the resolved CSV path."""
        return self._csv_path

    def is_ready(self) -> bool:
        """Check if the FTS index is ready for queries."""
        # The event is set once build_index completes or ensure_ready
        # finds an already-populated database on disk.
        return self._ready.is_set()

    def is_indexing(self) -> bool:
        """Check if indexing is currently in progress."""
        return self._indexing_in_progress
    def initialize(self) -> None:
        """Initialize the database schema.

        Idempotent: the schema is created at most once per instance, with
        a double-checked flag under the lock so concurrent callers do not
        run the DDL twice.
        """
        if self._schema_initialized:
            return
        with self._lock:
            if self._schema_initialized:
                return
            try:
                conn = self._connect()
                try:
                    # WAL allows concurrent readers while indexing writes.
                    conn.execute("PRAGMA journal_mode=WAL")
                    conn.executescript("""
                        -- FTS5 virtual table for full-text search
                        CREATE VIRTUAL TABLE IF NOT EXISTS tag_fts USING fts5(
                            tag_name,
                            tokenize='unicode61 remove_diacritics 2'
                        );
                        -- Tags table with metadata
                        CREATE TABLE IF NOT EXISTS tags (
                            rowid INTEGER PRIMARY KEY,
                            tag_name TEXT UNIQUE NOT NULL,
                            category INTEGER NOT NULL DEFAULT 0,
                            post_count INTEGER NOT NULL DEFAULT 0
                        );
                        -- Indexes for efficient filtering
                        CREATE INDEX IF NOT EXISTS idx_tags_category ON tags(category);
                        CREATE INDEX IF NOT EXISTS idx_tags_post_count ON tags(post_count DESC);
                        -- Index version tracking
                        CREATE TABLE IF NOT EXISTS fts_metadata (
                            key TEXT PRIMARY KEY,
                            value TEXT
                        );
                    """)
                    conn.commit()
                    self._schema_initialized = True
                    logger.debug("Tag FTS index schema initialized at %s", self._db_path)
                finally:
                    conn.close()
            except Exception as exc:
                # Leave _schema_initialized False; callers treat that as
                # "index unavailable" rather than crashing.
                logger.error("Failed to initialize tag FTS schema: %s", exc)
    def build_index(self) -> None:
        """Build the FTS index from the CSV file.

        This method parses the danbooru_e621_merged.csv file and creates
        the FTS index for fast searching.  Existing data is wiped first,
        so this is a full rebuild, performed in one transaction.

        NOTE(review): column 4 of the CSV (aliases) is not indexed here,
        so alias-based lookups will not match — confirm this is intended.
        """
        # Guard flag check is not atomic; assumes callers funnel through a
        # single code path (ensure_ready) — TODO confirm.
        if self._indexing_in_progress:
            logger.warning("Tag FTS indexing already in progress, skipping")
            return
        if not os.path.exists(self._csv_path):
            logger.warning("CSV file not found at %s, cannot build tag index", self._csv_path)
            return
        self._indexing_in_progress = True
        self._ready.clear()
        start_time = time.time()
        try:
            self.initialize()
            if not self._schema_initialized:
                logger.error("Cannot build tag FTS index: schema not initialized")
                return
            with self._lock:
                conn = self._connect()
                try:
                    conn.execute("BEGIN")
                    # Clear existing data (full rebuild semantics).
                    conn.execute("DELETE FROM tag_fts")
                    conn.execute("DELETE FROM tags")
                    # Parse CSV and insert in batches to bound memory use.
                    batch_size = 500
                    rows = []
                    total_inserted = 0
                    with open(self._csv_path, "r", encoding="utf-8") as f:
                        reader = csv.reader(f)
                        for row in reader:
                            # Need at least name, category, post_count.
                            if len(row) < 3:
                                continue
                            tag_name = row[0].strip()
                            if not tag_name:
                                continue
                            # Malformed numeric fields degrade to 0 rather
                            # than aborting the whole build.
                            try:
                                category = int(row[1])
                            except (ValueError, IndexError):
                                category = 0
                            try:
                                post_count = int(row[2])
                            except (ValueError, IndexError):
                                post_count = 0
                            rows.append((tag_name, category, post_count))
                            if len(rows) >= batch_size:
                                self._insert_batch(conn, rows)
                                total_inserted += len(rows)
                                rows = []
                    # Insert remaining rows
                    if rows:
                        self._insert_batch(conn, rows)
                        total_inserted += len(rows)
                    # Update metadata.
                    # NOTE(review): total_inserted counts CSV rows appended,
                    # including duplicates dropped by INSERT OR IGNORE, so
                    # tag_count may overstate the true row count — verify.
                    conn.execute(
                        "INSERT OR REPLACE INTO fts_metadata (key, value) VALUES (?, ?)",
                        ("last_build_time", str(time.time()))
                    )
                    conn.execute(
                        "INSERT OR REPLACE INTO fts_metadata (key, value) VALUES (?, ?)",
                        ("tag_count", str(total_inserted))
                    )
                    conn.commit()
                    elapsed = time.time() - start_time
                    logger.info("Tag FTS index built: %d tags indexed in %.2fs", total_inserted, elapsed)
                finally:
                    # Closing with an uncommitted transaction rolls it back,
                    # so a failed build leaves no partial data committed.
                    conn.close()
            self._ready.set()
        except Exception as exc:
            logger.error("Failed to build tag FTS index: %s", exc, exc_info=True)
        finally:
            self._indexing_in_progress = False
def _insert_batch(self, conn: sqlite3.Connection, rows: List[tuple]) -> None:
"""Insert a batch of rows into the database."""
# Insert into tags table
conn.executemany(
"INSERT OR IGNORE INTO tags (tag_name, category, post_count) VALUES (?, ?, ?)",
rows
)
# Get rowids and insert into FTS table
tag_names = [row[0] for row in rows]
placeholders = ",".join("?" * len(tag_names))
cursor = conn.execute(
f"SELECT rowid, tag_name FROM tags WHERE tag_name IN ({placeholders})",
tag_names
)
fts_rows = [(tag_name,) for rowid, tag_name in cursor.fetchall()]
if fts_rows:
conn.executemany("INSERT INTO tag_fts (tag_name) VALUES (?)", fts_rows)
def ensure_ready(self) -> bool:
"""Ensure the index is ready, building if necessary.
Returns:
True if the index is ready, False otherwise.
"""
if self.is_ready():
return True
# Check if index already exists and has data
self.initialize()
if self._schema_initialized:
count = self.get_indexed_count()
if count > 0:
self._ready.set()
logger.debug("Tag FTS index already populated with %d tags", count)
return True
# Build the index
self.build_index()
return self.is_ready()
    def search(
        self,
        query: str,
        categories: Optional[List[int]] = None,
        limit: int = 20
    ) -> List[Dict]:
        """Search tags using FTS5 with prefix matching.

        Args:
            query: The search query string.
            categories: Optional list of category IDs to filter by.
            limit: Maximum number of results to return.

        Returns:
            List of dictionaries with tag_name, category, and post_count,
            ordered by post_count descending.  Empty list on any failure.
        """
        # Ensure index is ready (lazy initialization); warn only once so
        # repeated queries against a missing index do not spam the log.
        if not self.ensure_ready():
            if not self._warned_not_ready:
                logger.debug("Tag FTS index not ready, returning empty results")
                self._warned_not_ready = True
            return []
        if not query or not query.strip():
            return []
        fts_query = self._build_fts_query(query)
        if not fts_query:
            return []
        try:
            with self._lock:
                conn = self._connect(readonly=True)
                try:
                    # FTS finds matching names; the tags table supplies
                    # category/post_count for filtering and ranking.
                    # The placeholder list is built from len(categories)
                    # only — the values themselves go in as parameters.
                    if categories:
                        placeholders = ",".join("?" * len(categories))
                        sql = f"""
                        SELECT t.tag_name, t.category, t.post_count
                        FROM tags t
                        WHERE t.tag_name IN (
                            SELECT tag_name FROM tag_fts WHERE tag_fts MATCH ?
                        )
                        AND t.category IN ({placeholders})
                        ORDER BY t.post_count DESC
                        LIMIT ?
                        """
                        params = [fts_query] + categories + [limit]
                    else:
                        sql = """
                        SELECT t.tag_name, t.category, t.post_count
                        FROM tags t
                        WHERE t.tag_name IN (
                            SELECT tag_name FROM tag_fts WHERE tag_fts MATCH ?
                        )
                        ORDER BY t.post_count DESC
                        LIMIT ?
                        """
                        params = [fts_query, limit]
                    cursor = conn.execute(sql, params)
                    results = []
                    for row in cursor.fetchall():
                        results.append({
                            "tag_name": row[0],
                            "category": row[1],
                            "post_count": row[2],
                        })
                    return results
                finally:
                    conn.close()
        except Exception as exc:
            # Search failures (bad FTS syntax, missing db) degrade to "no
            # results" at debug level rather than surfacing to the caller.
            logger.debug("Tag FTS search error for query '%s': %s", query, exc)
            return []
def get_indexed_count(self) -> int:
"""Return the number of tags currently indexed."""
if not self._schema_initialized:
return 0
try:
with self._lock:
conn = self._connect(readonly=True)
try:
cursor = conn.execute("SELECT COUNT(*) FROM tags")
result = cursor.fetchone()
return result[0] if result else 0
finally:
conn.close()
except Exception:
return 0
def clear(self) -> bool:
"""Clear all data from the FTS index.
Returns:
True if successful, False otherwise.
"""
try:
with self._lock:
conn = self._connect()
try:
conn.execute("DELETE FROM tag_fts")
conn.execute("DELETE FROM tags")
conn.commit()
self._ready.clear()
return True
finally:
conn.close()
except Exception as exc:
logger.error("Failed to clear tag FTS index: %s", exc)
return False
# Internal helpers
def _connect(self, readonly: bool = False) -> sqlite3.Connection:
"""Create a database connection."""
uri = False
path = self._db_path
if readonly:
if not os.path.exists(path):
raise FileNotFoundError(path)
path = f"file:{path}?mode=ro"
uri = True
conn = sqlite3.connect(path, check_same_thread=False, uri=uri)
conn.row_factory = sqlite3.Row
return conn
def _build_fts_query(self, query: str) -> str:
"""Build an FTS5 query string with prefix matching.
Args:
query: The user's search query.
Returns:
FTS5 query string.
"""
# Split query into words and clean them
words = query.lower().split()
if not words:
return ""
# Escape and add prefix wildcard to each word
prefix_terms = []
for word in words:
escaped = self._escape_fts_query(word)
if escaped:
# Add prefix wildcard for substring-like matching
prefix_terms.append(f"{escaped}*")
if not prefix_terms:
return ""
# Combine terms with implicit AND (all words must match)
return " ".join(prefix_terms)
def _escape_fts_query(self, text: str) -> str:
"""Escape special FTS5 characters.
FTS5 special characters: " ( ) * : ^ -
We keep * for prefix matching but escape others.
"""
if not text:
return ""
# Replace FTS5 special characters with space
special = ['"', "(", ")", "*", ":", "^", "-", "{", "}", "[", "]"]
result = text
for char in special:
result = result.replace(char, " ")
# Collapse multiple spaces and strip
result = re.sub(r"\s+", " ", result).strip()
return result
# Singleton instance, created lazily by get_tag_fts_index().
_tag_fts_index: Optional[TagFTSIndex] = None
# Guards first construction of the singleton across threads.
_tag_fts_lock = threading.Lock()


def get_tag_fts_index() -> TagFTSIndex:
    """Get the singleton TagFTSIndex instance.

    Double-checked locking: the unlocked fast path avoids lock overhead
    once the instance exists, while the locked re-check prevents two
    threads from constructing it concurrently.
    """
    global _tag_fts_index
    if _tag_fts_index is None:
        with _tag_fts_lock:
            if _tag_fts_index is None:
                _tag_fts_index = TagFTSIndex()
    return _tag_fts_index
__all__ = [
"TagFTSIndex",
"get_tag_fts_index",
"CATEGORY_NAMES",
"CATEGORY_NAME_TO_IDS",
]

221787
refs/danbooru_e621_merged.csv Normal file

File diff suppressed because one or more lines are too long

262
tests/test_tag_fts_index.py Normal file
View File

@@ -0,0 +1,262 @@
"""Tests for TagFTSIndex functionality."""
import os
import tempfile
from typing import List
import pytest
from py.services.tag_fts_index import (
TagFTSIndex,
CATEGORY_NAMES,
CATEGORY_NAME_TO_IDS,
)
@pytest.fixture
def temp_db_path():
    """Create a temporary database path."""
    # delete=False so the path survives closing the handle; TagFTSIndex
    # reopens the file by path.
    with tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) as f:
        path = f.name
    yield path
    # Cleanup: remove the database plus any SQLite WAL side files.
    if os.path.exists(path):
        os.unlink(path)
    for suffix in ["-wal", "-shm"]:
        wal_path = path + suffix
        if os.path.exists(wal_path):
            os.unlink(wal_path)
@pytest.fixture
def temp_csv_path():
    """Create a temporary CSV file with test data."""
    with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False, encoding="utf-8") as f:
        # Write test data in the same format as danbooru_e621_merged.csv
        # Format: tag_name,category,post_count,aliases
        # Rows span both Danbooru (0-5) and e621 (7-15) categories so
        # category-filter tests have positive and negative cases.
        f.write('1girl,0,6008644,"1girls,sole_female"\n')
        f.write('highres,5,5256195,"high_res,high_resolution,hires"\n')
        f.write('solo,0,5000954,"alone,female_solo,single"\n')
        f.write('hatsune_miku,4,500000,"miku"\n')
        f.write('konpaku_youmu,4,150000,"youmu"\n')
        f.write('artist_request,1,100000,""\n')
        f.write('touhou,3,300000,"touhou_project"\n')
        f.write('mammal,12,3437444,"cetancodont"\n')
        f.write('anthro,7,3381927,"anthropomorphic"\n')
        f.write('hi_res,14,3116617,"high_res"\n')
        path = f.name
    yield path
    # Cleanup
    if os.path.exists(path):
        os.unlink(path)
class TestTagFTSIndexBasic:
    """Basic tests for TagFTSIndex initialization and schema."""

    def test_initialize_creates_tables(self, temp_db_path, temp_csv_path):
        """Test that initialization creates required tables."""
        fts = TagFTSIndex(db_path=temp_db_path, csv_path=temp_csv_path)
        fts.initialize()
        # The flag is only set after the CREATE TABLE script succeeds.
        assert fts._schema_initialized is True

    def test_get_database_path(self, temp_db_path, temp_csv_path):
        """Test get_database_path returns correct path."""
        fts = TagFTSIndex(db_path=temp_db_path, csv_path=temp_csv_path)
        assert fts.get_database_path() == temp_db_path

    def test_get_csv_path(self, temp_db_path, temp_csv_path):
        """Test get_csv_path returns correct path."""
        fts = TagFTSIndex(db_path=temp_db_path, csv_path=temp_csv_path)
        assert fts.get_csv_path() == temp_csv_path

    def test_is_ready_initially_false(self, temp_db_path, temp_csv_path):
        """Test that is_ready returns False before building index."""
        fts = TagFTSIndex(db_path=temp_db_path, csv_path=temp_csv_path)
        assert fts.is_ready() is False
class TestTagFTSIndexBuild:
    """Tests for building the FTS index."""

    def test_build_index_from_csv(self, temp_db_path, temp_csv_path):
        """Test building index from CSV file."""
        fts = TagFTSIndex(db_path=temp_db_path, csv_path=temp_csv_path)
        fts.build_index()
        assert fts.is_ready() is True
        # The fixture CSV contains exactly 10 rows.
        assert fts.get_indexed_count() == 10

    def test_build_index_nonexistent_csv(self, temp_db_path):
        """Test that build_index handles missing CSV gracefully."""
        fts = TagFTSIndex(db_path=temp_db_path, csv_path="/nonexistent/path.csv")
        fts.build_index()
        # Missing CSV must not raise — it just leaves the index empty.
        assert fts.is_ready() is False
        assert fts.get_indexed_count() == 0

    def test_ensure_ready_builds_index(self, temp_db_path, temp_csv_path):
        """Test that ensure_ready builds index if not ready."""
        fts = TagFTSIndex(db_path=temp_db_path, csv_path=temp_csv_path)
        # Initially not ready
        assert fts.is_ready() is False
        # ensure_ready should build the index
        result = fts.ensure_ready()
        assert result is True
        assert fts.is_ready() is True
class TestTagFTSIndexSearch:
    """Tests for searching the FTS index."""

    @pytest.fixture
    def populated_fts(self, temp_db_path, temp_csv_path):
        """Create a populated FTS index."""
        fts = TagFTSIndex(db_path=temp_db_path, csv_path=temp_csv_path)
        fts.build_index()
        return fts

    def test_search_basic(self, populated_fts):
        """Test basic search functionality."""
        results = populated_fts.search("1girl")
        assert len(results) >= 1
        assert any(r["tag_name"] == "1girl" for r in results)

    def test_search_prefix(self, populated_fts):
        """Test prefix matching."""
        results = populated_fts.search("hatsu")
        assert len(results) >= 1
        assert any(r["tag_name"] == "hatsune_miku" for r in results)

    def test_search_returns_enriched_results(self, populated_fts):
        """Test that search returns enriched results with category and post_count."""
        # "miku" matches the second token of "hatsune_miku" because the
        # unicode61 tokenizer splits on underscores.
        results = populated_fts.search("miku")
        assert len(results) >= 1
        result = results[0]
        assert "tag_name" in result
        assert "category" in result
        assert "post_count" in result
        assert result["tag_name"] == "hatsune_miku"
        assert result["category"] == 4  # Character category
        assert result["post_count"] == 500000

    def test_search_with_category_filter(self, populated_fts):
        """Test searching with category filter."""
        # Search for character tags only (categories 4 and 11)
        results = populated_fts.search("konpaku", categories=[4, 11])
        assert len(results) >= 1
        assert all(r["category"] in [4, 11] for r in results)

    def test_search_with_category_filter_excludes_others(self, populated_fts):
        """Test that category filter excludes other categories."""
        # Search for "hi" but only in general category
        results = populated_fts.search("hi", categories=[0, 7])
        # Should not include "highres" (meta category 5) or "hi_res" (meta category 14)
        assert all(r["category"] in [0, 7] for r in results)

    def test_search_empty_query_returns_empty(self, populated_fts):
        """Test that empty query returns empty results."""
        results = populated_fts.search("")
        assert results == []

    def test_search_no_matches_returns_empty(self, populated_fts):
        """Test that query with no matches returns empty results."""
        results = populated_fts.search("zzzznonexistent")
        assert results == []

    def test_search_results_sorted_by_post_count(self, populated_fts):
        """Test that results are sorted by post_count descending."""
        results = populated_fts.search("1girl", limit=10)
        # Verify results are sorted by post_count descending
        post_counts = [r["post_count"] for r in results]
        assert post_counts == sorted(post_counts, reverse=True)

    def test_search_limit(self, populated_fts):
        """Test search result limiting."""
        results = populated_fts.search("girl", limit=1)
        assert len(results) <= 1
class TestTagFTSIndexClear:
    """Tests for clearing the FTS index."""

    def test_clear_removes_all_data(self, temp_db_path, temp_csv_path):
        """Test that clear removes all indexed data."""
        fts = TagFTSIndex(db_path=temp_db_path, csv_path=temp_csv_path)
        fts.build_index()
        assert fts.get_indexed_count() > 0
        fts.clear()
        # clear() must also reset readiness so a rebuild is triggered.
        assert fts.get_indexed_count() == 0
        assert fts.is_ready() is False
class TestCategoryMappings:
    """Tests for category name mappings."""

    def test_category_names_complete(self):
        """Test that CATEGORY_NAMES includes all expected categories."""
        # All used Danbooru (0-5) and e621 (7-15) IDs; 2, 6, 13 are unused.
        expected_categories = [0, 1, 3, 4, 5, 7, 8, 10, 11, 12, 14, 15]
        for cat in expected_categories:
            assert cat in CATEGORY_NAMES

    def test_category_name_to_ids_complete(self):
        """Test that CATEGORY_NAME_TO_IDS includes all expected names."""
        expected_names = ["general", "artist", "copyright", "character", "meta", "species", "lore"]
        for name in expected_names:
            assert name in CATEGORY_NAME_TO_IDS
            assert isinstance(CATEGORY_NAME_TO_IDS[name], list)
            assert len(CATEGORY_NAME_TO_IDS[name]) > 0

    def test_category_name_to_ids_includes_both_platforms(self):
        """Test that category mappings include both Danbooru and e621 IDs where applicable."""
        # General should have both Danbooru (0) and e621 (7)
        assert 0 in CATEGORY_NAME_TO_IDS["general"]
        assert 7 in CATEGORY_NAME_TO_IDS["general"]
        # Character should have both Danbooru (4) and e621 (11)
        assert 4 in CATEGORY_NAME_TO_IDS["character"]
        assert 11 in CATEGORY_NAME_TO_IDS["character"]
class TestFTSQueryBuilding:
    """Tests for FTS query building."""

    @pytest.fixture
    def fts_instance(self, temp_db_path, temp_csv_path):
        """Create an FTS instance for testing (no index build needed)."""
        return TagFTSIndex(db_path=temp_db_path, csv_path=temp_csv_path)

    def test_build_fts_query_simple(self, fts_instance):
        """Test FTS query building with simple query."""
        query = fts_instance._build_fts_query("test")
        assert query == "test*"

    def test_build_fts_query_multiple_words(self, fts_instance):
        """Test FTS query building with multiple words."""
        # Space-joined terms give FTS5's implicit AND semantics.
        query = fts_instance._build_fts_query("test query")
        assert query == "test* query*"

    def test_build_fts_query_escapes_special_chars(self, fts_instance):
        """Test that special characters are escaped."""
        query = fts_instance._build_fts_query("test:query")
        # Colon should be replaced with space
        assert ":" not in query

    def test_build_fts_query_empty_returns_empty(self, fts_instance):
        """Test that empty query returns empty string."""
        query = fts_instance._build_fts_query("")
        assert query == ""

View File

@@ -3,6 +3,46 @@ import { app } from "../../scripts/app.js";
import { TextAreaCaretHelper } from "./textarea_caret_helper.js";
import { getPromptCustomWordsAutocompletePreference } from "./settings.js";
// Command definitions for category filtering
// Command definitions for category filtering.
// Keys are slash-commands typed in the prompt; `categories` lists the
// Danbooru/e621 category IDs to filter tag search by, while
// `type: 'embedding'` switches the lookup to the embeddings endpoint.
const TAG_COMMANDS = {
    '/character': { categories: [4, 11], label: 'Character' },
    '/char': { categories: [4, 11], label: 'Character' },
    '/artist': { categories: [1, 8], label: 'Artist' },
    '/general': { categories: [0, 7], label: 'General' },
    '/copyright': { categories: [3, 10], label: 'Copyright' },
    '/meta': { categories: [5, 14], label: 'Meta' },
    '/species': { categories: [12], label: 'Species' },
    '/lore': { categories: [15], label: 'Lore' },
    '/emb': { type: 'embedding', label: 'Embeddings' },
    '/embedding': { type: 'embedding', label: 'Embeddings' },
};
// Category display information
// Category display information, keyed by numeric category ID.
// `bg` is the badge background, `text` the badge foreground; Danbooru and
// e621 IDs for the same concept (e.g. 0/7 General) share one palette.
const CATEGORY_INFO = {
    0: { bg: 'rgba(0, 155, 230, 0.2)', text: '#4bb4ff', label: 'General' },
    1: { bg: 'rgba(255, 138, 139, 0.2)', text: '#ffc3c3', label: 'Artist' },
    3: { bg: 'rgba(199, 151, 255, 0.2)', text: '#ddc9fb', label: 'Copyright' },
    4: { bg: 'rgba(53, 198, 74, 0.2)', text: '#93e49a', label: 'Character' },
    5: { bg: 'rgba(234, 208, 132, 0.2)', text: '#f7e7c3', label: 'Meta' },
    7: { bg: 'rgba(0, 155, 230, 0.2)', text: '#4bb4ff', label: 'General' },
    8: { bg: 'rgba(255, 138, 139, 0.2)', text: '#ffc3c3', label: 'Artist' },
    10: { bg: 'rgba(199, 151, 255, 0.2)', text: '#ddc9fb', label: 'Copyright' },
    11: { bg: 'rgba(53, 198, 74, 0.2)', text: '#93e49a', label: 'Character' },
    12: { bg: 'rgba(237, 137, 54, 0.2)', text: '#f6ad55', label: 'Species' },
    14: { bg: 'rgba(234, 208, 132, 0.2)', text: '#f7e7c3', label: 'Meta' },
    15: { bg: 'rgba(72, 187, 120, 0.2)', text: '#68d391', label: 'Lore' },
};
// Format post count with K/M suffix
function formatPostCount(count) {
if (count >= 1000000) {
return (count / 1000000).toFixed(1).replace(/\.0$/, '') + 'M';
} else if (count >= 1000) {
return (count / 1000).toFixed(1).replace(/\.0$/, '') + 'K';
}
return count.toString();
}
function parseUsageTipNumber(value) {
if (typeof value === 'number' && Number.isFinite(value)) {
return value;
@@ -224,6 +264,10 @@ class AutoComplete {
this.previewTooltipPromise = null;
this.searchType = null;
// Command mode state
this.activeCommand = null; // Current active command (e.g., { categories: [4, 11], label: 'Character' })
this.showingCommands = false; // Whether showing command list dropdown
// Initialize TextAreaCaretHelper
this.helper = new TextAreaCaretHelper(inputElement, () => app.canvas.ds.scale);
@@ -425,11 +469,43 @@ class AutoComplete {
endpoint = '/lm/embeddings/relative-paths';
searchTerm = (match[1] || '').trim();
this.searchType = 'embeddings';
this.activeCommand = null;
this.showingCommands = false;
} else if (getPromptCustomWordsAutocompletePreference()) {
// Setting enabled - allow custom words search
endpoint = '/lm/custom-words/search';
searchTerm = rawSearchTerm;
this.searchType = 'custom_words';
// Setting enabled - check for command mode
const commandResult = this._parseCommandInput(rawSearchTerm);
if (commandResult.showCommands) {
// Show command list dropdown
this.showingCommands = true;
this.activeCommand = null;
this.searchType = 'commands';
this._showCommandList(commandResult.commandFilter);
return;
} else if (commandResult.command) {
// Command is active, use filtered search
this.showingCommands = false;
this.activeCommand = commandResult.command;
searchTerm = commandResult.searchTerm;
if (commandResult.command.type === 'embedding') {
// /emb or /embedding command
endpoint = '/lm/embeddings/relative-paths';
this.searchType = 'embeddings';
} else {
// Category filter command
const categories = commandResult.command.categories.join(',');
endpoint = `/lm/custom-words/search?category=${categories}`;
this.searchType = 'custom_words';
}
} else {
// No command - regular custom words search with enriched results
this.showingCommands = false;
this.activeCommand = null;
endpoint = '/lm/custom-words/search?enriched=true';
searchTerm = rawSearchTerm;
this.searchType = 'custom_words';
}
} else {
// Setting disabled - no autocomplete for non-emb: terms
this.hide();
@@ -501,24 +577,220 @@ class AutoComplete {
this.hide();
}
}
/**
* Parse command input to detect command mode
* @param {string} rawInput - Raw input text
* @returns {Object} - { showCommands, commandFilter, command, searchTerm }
*/
_parseCommandInput(rawInput) {
const trimmed = rawInput.trim();
// Check if input starts with "/"
if (!trimmed.startsWith('/')) {
return { showCommands: false, command: null, searchTerm: trimmed };
}
// Split into potential command and search term
const spaceIndex = trimmed.indexOf(' ');
if (spaceIndex === -1) {
// Still typing command (e.g., "/cha")
const partialCommand = trimmed.toLowerCase();
// Check for exact command match
if (TAG_COMMANDS[partialCommand]) {
return {
showCommands: false,
command: TAG_COMMANDS[partialCommand],
searchTerm: '',
};
}
// Show command suggestions
return {
showCommands: true,
commandFilter: partialCommand.slice(1), // Remove leading "/"
command: null,
searchTerm: '',
};
}
// Command with search term (e.g., "/char miku")
const commandPart = trimmed.slice(0, spaceIndex).toLowerCase();
const searchPart = trimmed.slice(spaceIndex + 1).trim();
if (TAG_COMMANDS[commandPart]) {
return {
showCommands: false,
command: TAG_COMMANDS[commandPart],
searchTerm: searchPart,
};
}
// Unknown command, treat as regular search
return { showCommands: false, command: null, searchTerm: trimmed };
}
/**
* Show the command list dropdown
* @param {string} filter - Optional filter for commands
*/
_showCommandList(filter = '') {
const filterLower = filter.toLowerCase();
// Get unique commands (avoid duplicates like /char and /character)
const seenLabels = new Set();
const commands = [];
for (const [cmd, info] of Object.entries(TAG_COMMANDS)) {
if (seenLabels.has(info.label)) continue;
if (!filter || cmd.slice(1).startsWith(filterLower)) {
seenLabels.add(info.label);
commands.push({ command: cmd, ...info });
}
}
if (commands.length === 0) {
this.hide();
return;
}
this.items = commands;
this._renderCommandList();
this.show();
}
/**
 * Render the command list dropdown.
 *
 * Expects `this.items` to hold command descriptors ({ command, label, ... })
 * as prepared by `_showCommandList`. Each row shows the literal command on
 * the left and its human-readable label on the right; hovering selects the
 * row and clicking inserts the command into the input.
 */
_renderCommandList() {
// Rebuild the dropdown from scratch and reset the keyboard selection.
this.dropdown.innerHTML = '';
this.selectedIndex = -1;
this.items.forEach((item, index) => {
const itemEl = document.createElement('div');
itemEl.className = 'comfy-autocomplete-item comfy-autocomplete-command';
// Left side: the command text itself (e.g. "/char").
const cmdSpan = document.createElement('span');
cmdSpan.className = 'lm-autocomplete-command-name';
cmdSpan.textContent = item.command;
// Right side: the human-readable label (e.g. "Character").
const labelSpan = document.createElement('span');
labelSpan.className = 'lm-autocomplete-command-label';
labelSpan.textContent = item.label;
itemEl.appendChild(cmdSpan);
itemEl.appendChild(labelSpan);
itemEl.style.cssText = `
padding: 8px 12px;
cursor: pointer;
color: rgba(226, 232, 240, 0.8);
border-bottom: 1px solid rgba(226, 232, 240, 0.1);
transition: all 0.2s ease;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
display: flex;
justify-content: space-between;
align-items: center;
gap: 12px;
`;
// Hovering a row moves the keyboard selection onto it.
itemEl.addEventListener('mouseenter', () => {
this.selectItem(index);
});
// Clicking a row inserts the command text into the input element.
itemEl.addEventListener('click', () => {
this._insertCommand(item.command);
});
this.dropdown.appendChild(itemEl);
});
// Remove border from last item
if (this.dropdown.lastChild) {
this.dropdown.lastChild.style.borderBottom = 'none';
}
// Auto-select first item (deferred so the DOM nodes are in place first).
if (this.items.length > 0) {
setTimeout(() => this.selectItem(0), 100);
}
}
/**
* Insert a command into the input
* @param {string} command - The command to insert (e.g., "/char")
*/
_insertCommand(command) {
const currentValue = this.inputElement.value;
const caretPos = this.getCaretPosition();
// Find the start of the current command being typed
const beforeCursor = currentValue.substring(0, caretPos);
const segments = beforeCursor.split(/[,\>]+/);
const lastSegment = segments[segments.length - 1];
const commandStartPos = caretPos - lastSegment.length;
// Insert command with trailing space
const insertText = command + ' ';
const newValue = currentValue.substring(0, commandStartPos) + insertText + currentValue.substring(caretPos);
const newCaretPos = commandStartPos + insertText.length;
this.inputElement.value = newValue;
// Trigger input event
const event = new Event('input', { bubbles: true });
this.inputElement.dispatchEvent(event);
this.hide();
// Focus and position cursor
this.inputElement.focus();
this.inputElement.setSelectionRange(newCaretPos, newCaretPos);
}
render() {
this.dropdown.innerHTML = '';
this.selectedIndex = -1;
// Early return if no items to prevent empty dropdown
if (!this.items || this.items.length === 0) {
return;
}
this.items.forEach((relativePath, index) => {
// Check if items are enriched (have tag_name, category, post_count)
const isEnriched = this.items[0] && typeof this.items[0] === 'object' && 'tag_name' in this.items[0];
this.items.forEach((itemData, index) => {
const item = document.createElement('div');
item.className = 'comfy-autocomplete-item';
// Create highlighted content
const highlightedContent = this.highlightMatch(relativePath, this.currentSearchTerm);
item.innerHTML = highlightedContent;
// Get the display text and path for insertion
const displayText = isEnriched ? itemData.tag_name : itemData;
const insertPath = isEnriched ? itemData.tag_name : itemData;
if (isEnriched) {
// Render enriched item with category badge and post count
this._renderEnrichedItem(item, itemData, this.currentSearchTerm);
} else {
// Create highlighted content for simple items, wrapped in a span
// to prevent flex layout from breaking up the text
const nameSpan = document.createElement('span');
nameSpan.className = 'lm-autocomplete-name';
nameSpan.innerHTML = this.highlightMatch(displayText, this.currentSearchTerm);
nameSpan.style.cssText = `
flex: 1;
min-width: 0;
overflow: hidden;
text-overflow: ellipsis;
`;
item.appendChild(nameSpan);
}
// Apply item styles with new color scheme
item.style.cssText = `
padding: 8px 12px;
@@ -530,37 +802,102 @@ class AutoComplete {
overflow: hidden;
text-overflow: ellipsis;
position: relative;
display: flex;
justify-content: space-between;
align-items: center;
gap: 8px;
`;
// Hover and selection handlers
item.addEventListener('mouseenter', () => {
this.selectItem(index);
});
item.addEventListener('mouseleave', () => {
this.hidePreview();
});
// Click handler
item.addEventListener('click', () => {
this.insertSelection(relativePath);
this.insertSelection(insertPath);
});
this.dropdown.appendChild(item);
});
// Remove border from last item
if (this.dropdown.lastChild) {
this.dropdown.lastChild.style.borderBottom = 'none';
}
// Auto-select the first item with a small delay
if (this.items.length > 0) {
setTimeout(() => {
this.selectItem(0);
}, 100); // 50ms delay
this.selectItem(0);
}, 100);
}
}
/**
 * Render an enriched autocomplete item with category badge and post count.
 *
 * Layout: the tag name (with the matched portion highlighted) on the left,
 * and a meta area on the right holding an optional post count plus an
 * optional colored category badge from CATEGORY_INFO.
 * @param {HTMLElement} itemEl - The item element to populate
 * @param {Object} itemData - The enriched item data { tag_name, category, post_count }
 * @param {string} searchTerm - The current search term for highlighting
 */
_renderEnrichedItem(itemEl, itemData, searchTerm) {
// Create name span with highlighted match
const nameSpan = document.createElement('span');
nameSpan.className = 'lm-autocomplete-name';
nameSpan.innerHTML = this.highlightMatch(itemData.tag_name, searchTerm);
nameSpan.style.cssText = `
flex: 1;
min-width: 0;
overflow: hidden;
text-overflow: ellipsis;
`;
// Create meta container for count and badge
const metaSpan = document.createElement('span');
metaSpan.className = 'lm-autocomplete-meta';
metaSpan.style.cssText = `
display: flex;
align-items: center;
gap: 8px;
flex-shrink: 0;
`;
// Add post count (omitted when zero/unknown to reduce visual noise).
if (itemData.post_count > 0) {
const countSpan = document.createElement('span');
countSpan.className = 'lm-autocomplete-count';
countSpan.textContent = formatPostCount(itemData.post_count);
countSpan.style.cssText = `
font-size: 11px;
color: rgba(226, 232, 240, 0.5);
`;
metaSpan.appendChild(countSpan);
}
// Add category badge; categories without a CATEGORY_INFO entry get none.
const categoryInfo = CATEGORY_INFO[itemData.category];
if (categoryInfo) {
const badgeSpan = document.createElement('span');
badgeSpan.className = 'lm-autocomplete-category';
badgeSpan.textContent = categoryInfo.label;
badgeSpan.style.cssText = `
font-size: 10px;
padding: 2px 6px;
border-radius: 10px;
background: ${categoryInfo.bg};
color: ${categoryInfo.text};
white-space: nowrap;
`;
metaSpan.appendChild(badgeSpan);
}
itemEl.appendChild(nameSpan);
itemEl.appendChild(metaSpan);
}
highlightMatch(text, searchTerm) {
const { include } = parseSearchTokens(searchTerm);
@@ -655,10 +992,11 @@ class AutoComplete {
this.dropdown.style.display = 'none';
this.isVisible = false;
this.selectedIndex = -1;
this.showingCommands = false;
// Hide preview tooltip
this.hidePreview();
// Clear selection styles from all items
const items = this.dropdown.querySelectorAll('.comfy-autocomplete-item');
items.forEach(item => {
@@ -715,7 +1053,17 @@ class AutoComplete {
case 'Enter':
e.preventDefault();
if (this.selectedIndex >= 0 && this.selectedIndex < this.items.length) {
this.insertSelection(this.items[this.selectedIndex]);
if (this.showingCommands) {
// Insert command
this._insertCommand(this.items[this.selectedIndex].command);
} else {
// Insert selection (handle enriched items)
const selectedItem = this.items[this.selectedIndex];
const insertPath = typeof selectedItem === 'object' && 'tag_name' in selectedItem
? selectedItem.tag_name
: selectedItem;
this.insertSelection(insertPath);
}
}
break;

View File

@@ -575,3 +575,84 @@ body.lm-lora-reordering * {
border-top: 1px solid rgba(255, 255, 255, 0.05);
margin: 6px 0;
}
/* Autocomplete styling */
.lm-autocomplete-name {
flex: 1;
min-width: 0;
overflow: hidden;
text-overflow: ellipsis;
}
.lm-autocomplete-meta {
display: flex;
align-items: center;
gap: 8px;
flex-shrink: 0;
}
.lm-autocomplete-count {
font-size: 11px;
color: rgba(226, 232, 240, 0.5);
}
.lm-autocomplete-category {
font-size: 10px;
padding: 2px 6px;
border-radius: 10px;
white-space: nowrap;
}
/* Category-specific badge colors.
   Covers the merged Danbooru/e621 category set; the e621-only categories
   (Contributor, Species, Lore) use complementary tones. */
.lm-autocomplete-category--general {
background: rgba(0, 155, 230, 0.2);
color: #4bb4ff;
}
.lm-autocomplete-category--artist {
background: rgba(255, 138, 139, 0.2);
color: #ffc3c3;
}
.lm-autocomplete-category--contributor {
background: rgba(160, 174, 192, 0.2);
color: #cbd5e0;
}
.lm-autocomplete-category--copyright {
background: rgba(199, 151, 255, 0.2);
color: #ddc9fb;
}
.lm-autocomplete-category--character {
background: rgba(53, 198, 74, 0.2);
color: #93e49a;
}
.lm-autocomplete-category--meta {
background: rgba(234, 208, 132, 0.2);
color: #f7e7c3;
}
.lm-autocomplete-category--species {
background: rgba(237, 137, 54, 0.2);
color: #f6ad55;
}
.lm-autocomplete-category--lore {
background: rgba(72, 187, 120, 0.2);
color: #68d391;
}
/* Command list styling */
.comfy-autocomplete-command {
display: flex;
justify-content: space-between;
align-items: center;
gap: 12px;
}
.lm-autocomplete-command-name {
font-family: 'Consolas', 'Monaco', monospace;
color: rgba(66, 153, 225, 0.9);
}
.lm-autocomplete-command-label {
font-size: 12px;
color: rgba(226, 232, 240, 0.6);
}