fix(autocomplete): reduce tag search overhead (#895)

This commit is contained in:
Will Miao
2026-04-15 20:42:33 +08:00
parent 62247bdd87
commit 4514ca94b7
7 changed files with 475 additions and 75 deletions

View File

@@ -13,6 +13,7 @@ import contextlib
import io
import json
import logging
import time
import os
import platform
import re
@@ -2433,6 +2434,7 @@ class CustomWordsHandler:
even without category filtering.
"""
try:
started_at = time.perf_counter()
search_term = request.query.get("search", "")
limit = int(request.query.get("limit", "20"))
offset = max(0, int(request.query.get("offset", "0")))
@@ -2444,6 +2446,16 @@ class CustomWordsHandler:
if category_param:
categories = self._parse_category_param(category_param)
logger.info(
"LM custom words request search=%r category_param=%r categories=%s limit=%s offset=%s enriched=%s",
search_term,
category_param,
categories,
limit,
offset,
enriched_param,
)
results = self._service.search_words(
search_term,
limit,
@@ -2452,6 +2464,14 @@ class CustomWordsHandler:
enriched=enriched_param,
)
elapsed_ms = (time.perf_counter() - started_at) * 1000
logger.info(
"LM custom words response search=%r result_count=%s elapsed_ms=%.2f",
search_term,
len(results),
elapsed_ms,
)
return web.json_response({"success": True, "words": results})
except Exception as exc:
logger.error("Error searching custom words: %s", exc, exc_info=True)

View File

@@ -7,11 +7,13 @@ with category filtering and enriched results including post counts.
from __future__ import annotations
import logging
import re
from typing import List, Dict, Any, Optional
logger = logging.getLogger(__name__)
_EMBEDDED_COMMAND_PATTERN = re.compile(r"\s/\w")
class CustomWordsService:
"""Service for autocomplete via TagFTSIndex.
@@ -77,10 +79,47 @@ class CustomWordsService:
Returns:
List of dicts with tag_name, category, and post_count.
"""
normalized_search = search_term.strip()
if not normalized_search:
return []
# Prompt widgets should only send the active token, but guard against
# accidental full-prompt queries reaching the FTS path.
if (
"__" in normalized_search
or "," in normalized_search
or ">" in normalized_search
or "\n" in normalized_search
or "\r" in normalized_search
or _EMBEDDED_COMMAND_PATTERN.search(normalized_search)
):
logger.debug("Skipping prompt-like custom words query: %s", normalized_search)
return []
logger.info(
"LM custom words service start search=%r categories=%s limit=%s offset=%s enriched=%s",
normalized_search,
categories,
limit,
offset,
enriched,
)
tag_index = self._get_tag_index()
if tag_index is not None:
logger.info(
"LM custom words service tag_index ready=%s indexing=%s",
getattr(tag_index, "is_ready", lambda: "unknown")(),
getattr(tag_index, "is_indexing", lambda: "unknown")(),
)
results = tag_index.search(
search_term, categories=categories, limit=limit, offset=offset
normalized_search, categories=categories, limit=limit, offset=offset
)
logger.info(
"LM custom words service done search=%r result_count=%s",
normalized_search,
len(results),
)
return results

View File

@@ -450,9 +450,9 @@ class TagFTSIndex:
the tag_name, the result will include a "matched_alias" field.
Ranking is based on a combination of:
1. FTS5 bm25 relevance score (how well the text matches)
2. Post count (popularity)
3. Exact prefix match boost (tag_name starts with query)
1. Exact prefix match boost (tag_name starts with query)
2. Post count to preserve expected autocomplete ordering
3. FTS5 bm25 relevance score as a deterministic tie-breaker
Args:
query: The search query string.
@@ -464,6 +464,17 @@ class TagFTSIndex:
List of dictionaries with tag_name, category, post_count,
rank_score, and optionally matched_alias.
"""
search_started_at = time.perf_counter()
logger.info(
"LM tag FTS search start query=%r categories=%s limit=%s offset=%s ready=%s indexing=%s",
query,
categories,
limit,
offset,
self.is_ready(),
self.is_indexing(),
)
# Ensure index is ready (lazy initialization)
if not self.ensure_ready():
if not self._warned_not_ready:
@@ -478,71 +489,44 @@ class TagFTSIndex:
if not fts_query:
return []
logger.info(
"LM tag FTS search built query=%r fts_query=%r",
query,
fts_query,
)
query_lower = query.lower().strip()
try:
logger.info("LM tag FTS search waiting_for_lock query=%r", query)
with self._lock:
logger.info("LM tag FTS search acquired_lock query=%r", query)
conn = self._connect(readonly=True)
try:
# Build the SQL query with bm25 ranking
# FTS5 bm25() returns negative scores, lower is better
# We use -bm25() to get higher=better scores
# Weights: -100.0 for exact matches, 1.0 for others
# Add LOG10(post_count) weighting to boost popular tags
# Use CASE to boost tag_name prefix matches above alias matches
if categories:
placeholders = ",".join("?" * len(categories))
sql = f"""
SELECT t.tag_name, t.category, t.post_count, t.aliases,
CASE
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
ELSE 0
END AS is_tag_name_match,
bm25(tag_fts, -100.0, 1.0, 1.0) + LOG10(t.post_count + 1) * 10.0 AS rank_score
FROM tag_fts
JOIN tags t ON tag_fts.rowid = t.rowid
WHERE tag_fts.searchable_text MATCH ?
AND t.category IN ({placeholders})
ORDER BY is_tag_name_match DESC, rank_score DESC
LIMIT ? OFFSET ?
"""
# Escape special LIKE characters and add wildcard
query_escaped = (
query_lower.lstrip("/")
.replace("\\", "\\\\")
.replace("%", "\\%")
.replace("_", "\\_")
)
params = (
[query_escaped + "%", fts_query]
+ categories
+ [limit, offset]
)
else:
sql = """
SELECT t.tag_name, t.category, t.post_count, t.aliases,
CASE
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
ELSE 0
END AS is_tag_name_match,
bm25(tag_fts, -100.0, 1.0, 1.0) + LOG10(t.post_count + 1) * 10.0 AS rank_score
FROM tag_fts
JOIN tags t ON tag_fts.rowid = t.rowid
WHERE tag_fts.searchable_text MATCH ?
ORDER BY is_tag_name_match DESC, rank_score DESC
LIMIT ? OFFSET ?
"""
query_escaped = (
query_lower.lstrip("/")
.replace("\\", "\\\\")
.replace("%", "\\%")
.replace("_", "\\_")
)
params = [query_escaped + "%", fts_query, limit, offset]
sql, params = self._build_search_statement(
query_lower=query_lower,
fts_query=fts_query,
categories=categories,
limit=limit,
offset=offset,
)
logger.info(
"LM tag FTS search executing_sql query=%r query_len=%s category_count=%s",
query,
len(query_lower),
len(categories) if categories else 0,
)
cursor = conn.execute(sql, params)
logger.info("LM tag FTS search execute_returned query=%r", query)
rows = cursor.fetchall()
logger.info(
"LM tag FTS search fetchall_returned query=%r row_count=%s",
query,
len(rows),
)
results = []
for row in cursor.fetchall():
for row in rows:
result = {
"tag_name": row[0],
"category": row[1],
@@ -564,6 +548,13 @@ class TagFTSIndex:
result["matched_alias"] = matched_alias
results.append(result)
elapsed_ms = (time.perf_counter() - search_started_at) * 1000
logger.info(
"LM tag FTS search done query=%r result_count=%s elapsed_ms=%.2f",
query,
len(results),
elapsed_ms,
)
return results
finally:
conn.close()
@@ -571,6 +562,62 @@ class TagFTSIndex:
logger.debug("Tag FTS search error for query '%s': %s", query, exc)
return []
def _build_search_statement(
self,
query_lower: str,
fts_query: str,
categories: Optional[List[int]],
limit: int,
offset: int,
) -> tuple[str, list[object]]:
"""Build the SQL statement and params for a tag search."""
# Escape special LIKE characters and add wildcard
query_escaped = (
query_lower.lstrip("/")
.replace("\\", "\\\\")
.replace("%", "\\%")
.replace("_", "\\_")
)
# FTS5 bm25() returns negative scores, lower is better.
# We use -bm25() to get higher=better scores, but keep post_count as the
# primary sort within tag-name prefix matches so autocomplete ordering
# remains aligned with the existing popularity-first behavior.
if categories:
placeholders = ",".join("?" * len(categories))
sql = f"""
SELECT t.tag_name, t.category, t.post_count, t.aliases,
CASE
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
ELSE 0
END AS is_tag_name_match,
bm25(tag_fts, -100.0, 1.0, 1.0) AS rank_score
FROM tag_fts
CROSS JOIN tags t ON t.rowid = tag_fts.rowid
WHERE tag_fts.searchable_text MATCH ?
AND t.category IN ({placeholders})
ORDER BY is_tag_name_match DESC, t.post_count DESC, rank_score DESC
LIMIT ? OFFSET ?
"""
params = [query_escaped + "%", fts_query] + categories + [limit, offset]
else:
sql = """
SELECT t.tag_name, t.category, t.post_count, t.aliases,
CASE
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
ELSE 0
END AS is_tag_name_match,
bm25(tag_fts, -100.0, 1.0, 1.0) AS rank_score
FROM tag_fts
JOIN tags t ON tag_fts.rowid = t.rowid
WHERE tag_fts.searchable_text MATCH ?
ORDER BY is_tag_name_match DESC, t.post_count DESC, rank_score DESC
LIMIT ? OFFSET ?
"""
params = [query_escaped + "%", fts_query, limit, offset]
return sql, params
def _find_matched_alias(
self, query: str, tag_name: str, aliases_str: str
) -> Optional[str]:

View File

@@ -126,6 +126,31 @@ describe('AutoComplete widget interactions', () => {
expect(caretHelperInstance.getCursorOffset).toHaveBeenCalled();
});
it('deduplicates duplicate-equivalent query variations before issuing requests', async () => {
  // Variations of "Example" that normalize to the same backend search term
  // must collapse into a single fetch instead of parallel duplicate requests.
  vi.useFakeTimers();
  fetchApiMock.mockResolvedValue({
    json: () => Promise.resolve({ success: true, words: [] }),
  });
  caretHelperInstance.getBeforeCursor.mockReturnValue('Example');
  const input = document.createElement('textarea');
  document.body.append(input);
  const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
  new AutoComplete(input, 'prompt', { debounceDelay: 0, showPreview: false, minChars: 1 });
  input.value = 'Example';
  input.dispatchEvent(new Event('input', { bubbles: true }));
  await vi.runAllTimersAsync();
  await Promise.resolve();
  // Exactly one request, carrying the single deduplicated search term.
  expect(fetchApiMock).toHaveBeenCalledTimes(1);
  expect(fetchApiMock).toHaveBeenCalledWith('/lm/custom-words/search?enriched=true&search=Example&limit=100');
});
it('inserts the selected LoRA with usage tip strengths and restores focus', async () => {
fetchApiMock.mockImplementation((url) => {
if (url.includes('usage-tips-by-path')) {
@@ -244,6 +269,55 @@ describe('AutoComplete widget interactions', () => {
expect(inputListener).not.toHaveBeenCalled();
});
it('shows the full command list when typing a single slash', async () => {
  // A bare "/" should surface every registered tag command, including alias
  // pairs such as /char + /character and /wild + /wildcard.
  const input = document.createElement('textarea');
  input.value = '/';
  input.selectionStart = input.value.length;
  document.body.append(input);
  caretHelperInstance.getBeforeCursor.mockReturnValue('/');
  caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
  const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
  const autoComplete = new AutoComplete(input,'prompt', { showPreview: false, minChars: 1 });
  input.dispatchEvent(new Event('input', { bubbles: true }));
  const commandNames = autoComplete.items.map((item) => item.command);
  expect(commandNames).toContain('/character');
  expect(commandNames).toContain('/char');
  expect(commandNames).toContain('/artist');
  expect(commandNames).toContain('/general');
  expect(commandNames).toContain('/copyright');
  expect(commandNames).toContain('/meta');
  expect(commandNames).toContain('/species');
  expect(commandNames).toContain('/lore');
  expect(commandNames).toContain('/emb');
  expect(commandNames).toContain('/embedding');
  expect(commandNames).toContain('/wild');
  expect(commandNames).toContain('/wildcard');
});
it('renders every command item when slash opens the command list', async () => {
  // Guards against virtual scrolling rendering only a partial slice of the
  // command dropdown when it first opens.
  const input = document.createElement('textarea');
  input.value = '/';
  input.selectionStart = input.value.length;
  document.body.append(input);
  caretHelperInstance.getBeforeCursor.mockReturnValue('/');
  caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
  const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
  const autoComplete = new AutoComplete(input, 'prompt', { showPreview: false, minChars: 1 });
  input.dispatchEvent(new Event('input', { bubbles: true }));
  // Every item in the model must have a corresponding rendered DOM node.
  const renderedCommands = autoComplete.contentContainer.querySelectorAll('.lm-autocomplete-command-name');
  expect(renderedCommands).toHaveLength(autoComplete.items.length);
});
it('accepts the selected suggestion with Enter', async () => {
caretHelperInstance.getBeforeCursor.mockReturnValue('example');
@@ -1095,6 +1169,7 @@ describe('AutoComplete widget interactions', () => {
minChars: 1,
});
fetchApiMock.mockClear();
input.dispatchEvent(new Event('input', { bubbles: true }));
await vi.runAllTimersAsync();
await Promise.resolve();
@@ -1133,6 +1208,61 @@ describe('AutoComplete widget interactions', () => {
expect(input.setSelectionRange).toHaveBeenCalled();
});
it('does not reopen autocomplete on blur after inserting a wildcard literal', async () => {
  // Blur-time reformatting dispatches a synthetic input event; that event
  // must not kick off a new search for the already-completed "__flower__,".
  const input = document.createElement('textarea');
  input.value = '__flower__,';
  input.selectionStart = input.value.length;
  document.body.append(input);
  caretHelperInstance.getBeforeCursor.mockReturnValue('__flower__,');
  const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
  const autoComplete = new AutoComplete(input,'prompt', {
    debounceDelay: 0,
    showPreview: false,
    minChars: 1,
  });
  const hideSpy = vi.spyOn(autoComplete, 'hide');
  input.dispatchEvent(new Event('blur', { bubbles: true }));
  // No backend request, and the dropdown stays hidden.
  expect(fetchApiMock).not.toHaveBeenCalled();
  expect(hideSpy).toHaveBeenCalled();
  expect(autoComplete.isVisible).toBe(false);
});
it('treats a command after a wildcard literal as the active token', async () => {
  // For "__flower__ /character f" the active token is the trailing command
  // fragment, not the whole prompt including the wildcard literal.
  vi.useFakeTimers();
  fetchApiMock.mockResolvedValue({
    json: () => Promise.resolve({
      success: true,
      words: [{ tag_name: 'flower_field', category: 4, post_count: 1234 }],
    }),
  });
  const input = document.createElement('textarea');
  input.value = '__flower__ /character f';
  input.selectionStart = input.value.length;
  document.body.append(input);
  caretHelperInstance.getBeforeCursor.mockReturnValue('__flower__ /character f');
  caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
  const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
  const autoComplete = new AutoComplete(input,'prompt', {
    debounceDelay: 0,
    showPreview: false,
    minChars: 1,
  });
  input.dispatchEvent(new Event('input', { bubbles: true }));
  await vi.runAllTimersAsync();
  await Promise.resolve();
  expect(autoComplete.getSearchTerm(input.value)).toBe('/character f');
});
it('invalidates stale autocomplete metadata and falls back to delimiter-based matching', async () => {
settingGetMock.mockImplementation((key) => {
if (key === 'loramanager.autocomplete_append_comma') {

View File

@@ -94,6 +94,19 @@ class TestCustomWordsService:
results = service.search_words("test")
assert mock_tag_index.called
def test_search_words_skips_prompt_like_queries(self):
    """Prompt-shaped input (wildcard + command) must short-circuit before FTS."""
    index_stub = MockTagFTSIndex()
    service = CustomWordsService.__new__(CustomWordsService)
    service._get_tag_index = lambda: index_stub
    # A full-prompt query is rejected with an empty result set...
    assert service.search_words("__flower__ /character f") == []
    # ...and the FTS index is never consulted.
    assert index_stub.called is False
class MockTagFTSIndex:
"""Mock TagFTSIndex for testing."""

View File

@@ -1,6 +1,7 @@
"""Tests for TagFTSIndex functionality."""
import os
import sqlite3
import tempfile
from typing import List
@@ -173,6 +174,40 @@ class TestTagFTSIndexSearch:
assert len(results) >= 1
assert all(r["category"] in [4, 11] for r in results)
def test_search_with_category_filter_uses_fts_first_plan(self, populated_fts):
    """Category-filtered searches should start from FTS hits, not category scans."""
    sql, params = populated_fts._build_search_statement(
        query_lower="f",
        fts_query="f*",
        categories=[4, 11],
        limit=20,
        offset=0,
    )
    db_uri = f"file:{populated_fts.get_database_path()}?mode=ro"
    conn = sqlite3.connect(db_uri, uri=True)
    try:
        plan_rows = conn.execute(f"EXPLAIN QUERY PLAN {sql}", params).fetchall()
    finally:
        conn.close()
    details = [row[3] for row in plan_rows]
    # The FTS virtual table must drive the join, with tags looked up by rowid.
    assert any(detail.startswith("SCAN tag_fts VIRTUAL TABLE INDEX") for detail in details)
    assert any("SEARCH t USING INTEGER PRIMARY KEY" in detail for detail in details)
    # The category index must not be the entry point of the plan.
    assert not any("SEARCH t USING INDEX idx_tags_category" in detail for detail in details)
def test_search_statement_uses_post_count_as_tie_breaker(self, populated_fts):
    """Search ranking should use popularity as a secondary sort key."""
    sql, _params = populated_fts._build_search_statement(
        query_lower="f",
        fts_query="f*",
        categories=[4, 11],
        limit=20,
        offset=0,
    )
    expected_order = "ORDER BY is_tag_name_match DESC, t.post_count DESC, rank_score DESC"
    assert expected_order in sql
    # The old popularity-boost term must be gone from the ranking expression.
    assert "LOG10" not in sql
def test_search_with_category_filter_excludes_others(self, populated_fts):
"""Test that category filter excludes other categories."""
# Search for "hi" but only in general category

View File

@@ -366,6 +366,7 @@ class AutoComplete {
this.previewTooltip = null;
this.previewTooltipPromise = null;
this.searchType = null;
this.suppressAutocompleteOnce = false;
// Virtual scrolling state
this.virtualScrollOffset = 0;
@@ -505,6 +506,11 @@ class AutoComplete {
bindEvents() {
// Handle input changes
this.onInput = (e) => {
if (this.suppressAutocompleteOnce) {
this.suppressAutocompleteOnce = false;
this.hide();
return;
}
this.handleInput(e.target.value);
};
this.inputElement.addEventListener('input', this.onInput);
@@ -521,6 +527,7 @@ class AutoComplete {
const formattedValue = formatAutocompleteTextOnBlur(this.inputElement.value);
if (formattedValue !== this.inputElement.value) {
this.inputElement.value = formattedValue;
this.suppressAutocompleteOnce = true;
this.inputElement.dispatchEvent(new Event('input', { bubbles: true }));
}
}
@@ -725,9 +732,24 @@ class AutoComplete {
}
const rawText = beforeCursor.substring(start);
const text = rawText.trim();
const leadingWhitespaceLength = rawText.length - rawText.trimStart().length;
const trimmedStart = start + leadingWhitespaceLength;
const text = rawText.trim();
if (this.modelType === 'prompt') {
const tokenRange = this._getPromptTokenRange(rawText, trimmedStart, caretPos);
if (tokenRange) {
return {
start: tokenRange.start,
trimmedStart: tokenRange.trimmedStart,
end: caretPos,
beforeCursor,
rawText: tokenRange.rawText,
text: tokenRange.text,
tokenType: tokenRange.tokenType,
};
}
}
return {
start,
@@ -739,6 +761,73 @@ class AutoComplete {
};
}
_getPromptTokenRange(rawText = '', trimmedStart = 0, caretPos = 0) {
  // Resolve the "active token" inside the prompt text before the caret, so
  // autocomplete only searches the fragment being edited rather than the
  // whole prompt. NOTE(review): `caretPos` is accepted but never read here —
  // the caller supplies the range end itself; confirm it is kept for parity.
  const trimmedText = rawText.trim();
  if (!trimmedText) {
    // Nothing typed after the last boundary: report an empty token.
    return {
      start: trimmedStart,
      trimmedStart,
      rawText: '',
      text: '',
      tokenType: 'empty',
    };
  }
  // Commands: either the whole token starts with "/", or a " /command"
  // appears later in the text (e.g. after a wildcard literal).
  const commandOffset = trimmedText.startsWith('/')
    ? 0
    : trimmedText.lastIndexOf(' /');
  if (commandOffset !== -1) {
    // When found mid-text, skip the space that precedes the slash.
    const normalizedCommandOffset = commandOffset === 0 ? 0 : commandOffset + 1;
    const commandText = trimmedText.slice(normalizedCommandOffset);
    const commandStart = trimmedStart + normalizedCommandOffset;
    return {
      start: commandStart,
      trimmedStart: commandStart,
      rawText: commandText,
      text: commandText,
      // A lone "/" opens the full command list; anything longer filters it.
      tokenType: commandText === '/' ? 'empty_command_trigger' : 'command',
    };
  }
  // Completed wildcard literal ("__name__") anchored at the caret: treat it
  // as already-inserted content with an empty search text.
  const wildcardMatch = trimmedText.match(/(?:^|\s)(__[\w\s.\-+/*\\]+?__)$/);
  if (wildcardMatch) {
    const wildcardText = wildcardMatch[1];
    // The regex is end-anchored, so lastIndexOf finds that same occurrence.
    const wildcardOffset = trimmedText.lastIndexOf(wildcardText);
    const wildcardStart = trimmedStart + wildcardOffset;
    return {
      start: wildcardStart,
      trimmedStart: wildcardStart,
      rawText: wildcardText,
      text: '',
      tokenType: 'wildcard_literal',
    };
  }
  // Embedding literal ("emb:...") anchored at the caret.
  const embeddingOffset = trimmedText.search(/(?:^|\s)emb:[^\s]*$/i);
  if (embeddingOffset !== -1) {
    // The match may include the leading space; step past it if present.
    const normalizedEmbeddingOffset = trimmedText.slice(embeddingOffset).startsWith(' ')
      ? embeddingOffset + 1
      : embeddingOffset;
    const embeddingText = trimmedText.slice(normalizedEmbeddingOffset);
    const embeddingStart = trimmedStart + normalizedEmbeddingOffset;
    return {
      start: embeddingStart,
      trimmedStart: embeddingStart,
      rawText: embeddingText,
      text: embeddingText,
      tokenType: 'embedding_literal',
    };
  }
  // Default: plain tag text — search the whole trimmed fragment.
  return {
    start: trimmedStart,
    trimmedStart,
    rawText,
    text: trimmedText,
    tokenType: 'tag_text',
  };
}
_getHardBoundaryStart(beforeCursor = '') {
const lastComma = beforeCursor.lastIndexOf(',');
const lastAngle = beforeCursor.lastIndexOf('>');
@@ -890,6 +979,32 @@ class AutoComplete {
return Array.from(variations).filter(v => v.length >= this.options.minChars);
}
_normalizeQueryForRequest(term = '') {
return term.trim().toLowerCase();
}
_getQueriesToExecute(term = '') {
const queryVariations = this._generateQueryVariations(term);
const uniqueQueries = [];
const seen = new Set();
for (const query of queryVariations) {
const normalized = this._normalizeQueryForRequest(query);
if (!normalized || seen.has(normalized)) {
continue;
}
seen.add(normalized);
uniqueQueries.push(query);
if (uniqueQueries.length >= 4) {
break;
}
}
return uniqueQueries;
}
/**
* Get display text for an item (without extension for models)
* @param {string|Object} item - Item to get display text from
@@ -1041,18 +1156,17 @@ class AutoComplete {
endpoint = `/lm/${this.modelType}/relative-paths`;
}
// Generate multiple query variations for better matching
const queryVariations = this._generateQueryVariations(term);
// Generate multiple query variations for better matching, but avoid
// sending duplicate-equivalent requests that normalize to the same
// backend search term.
const queriesToExecute = this._getQueriesToExecute(term);
if (queryVariations.length === 0) {
if (queriesToExecute.length === 0) {
this.items = [];
this.hide();
return;
}
// Limit the number of parallel queries to avoid overwhelming the server
const queriesToExecute = queryVariations.slice(0, 4);
// Execute all queries in parallel
const searchPromises = queriesToExecute.map(async (query) => {
const url = endpoint.includes('?')
@@ -1190,20 +1304,15 @@ class AutoComplete {
const filterLower = filter.toLowerCase();
// Get unique commands (avoid duplicates like /char and /character)
const seenLabels = new Set();
const commands = [];
for (const [cmd, info] of Object.entries(TAG_COMMANDS)) {
if (seenLabels.has(info.label)) continue;
// Filter out toggle commands that don't meet their condition
if (info.type === 'toggle_setting' && info.condition) {
if (!info.condition()) continue;
}
if (!filter || cmd.slice(1).startsWith(filterLower)) {
seenLabels.add(info.label);
commands.push({ command: cmd, ...info });
}
}
@@ -1288,8 +1397,16 @@ class AutoComplete {
}
// Auto-select immediately so accept keys remain stable.
// In virtual-scroll mode, calling selectItem() before the dropdown is
// visible can see a zero-height container and incorrectly replace the
// full command list with a partially virtualized slice.
if (this.items.length > 0) {
this.selectItem(0);
this.selectedIndex = 0;
if (this.contentContainer) {
this._applyItemSelection(0);
} else {
this.selectItem(0);
}
}
// Update virtual scroll height for virtual scrolling mode
@@ -1635,8 +1752,7 @@ class AutoComplete {
}
}
const queryVariations = this._generateQueryVariations(this.currentSearchTerm);
const queriesToExecute = queryVariations.slice(0, 4);
const queriesToExecute = this._getQueriesToExecute(this.currentSearchTerm);
const offset = this.items.length;
// Execute all queries in parallel with offset