mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
fix(autocomplete): reduce tag search overhead (#895)
This commit is contained in:
@@ -13,6 +13,7 @@ import contextlib
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
import re
|
import re
|
||||||
@@ -2433,6 +2434,7 @@ class CustomWordsHandler:
|
|||||||
even without category filtering.
|
even without category filtering.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
started_at = time.perf_counter()
|
||||||
search_term = request.query.get("search", "")
|
search_term = request.query.get("search", "")
|
||||||
limit = int(request.query.get("limit", "20"))
|
limit = int(request.query.get("limit", "20"))
|
||||||
offset = max(0, int(request.query.get("offset", "0")))
|
offset = max(0, int(request.query.get("offset", "0")))
|
||||||
@@ -2444,6 +2446,16 @@ class CustomWordsHandler:
|
|||||||
if category_param:
|
if category_param:
|
||||||
categories = self._parse_category_param(category_param)
|
categories = self._parse_category_param(category_param)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"LM custom words request search=%r category_param=%r categories=%s limit=%s offset=%s enriched=%s",
|
||||||
|
search_term,
|
||||||
|
category_param,
|
||||||
|
categories,
|
||||||
|
limit,
|
||||||
|
offset,
|
||||||
|
enriched_param,
|
||||||
|
)
|
||||||
|
|
||||||
results = self._service.search_words(
|
results = self._service.search_words(
|
||||||
search_term,
|
search_term,
|
||||||
limit,
|
limit,
|
||||||
@@ -2452,6 +2464,14 @@ class CustomWordsHandler:
|
|||||||
enriched=enriched_param,
|
enriched=enriched_param,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elapsed_ms = (time.perf_counter() - started_at) * 1000
|
||||||
|
logger.info(
|
||||||
|
"LM custom words response search=%r result_count=%s elapsed_ms=%.2f",
|
||||||
|
search_term,
|
||||||
|
len(results),
|
||||||
|
elapsed_ms,
|
||||||
|
)
|
||||||
|
|
||||||
return web.json_response({"success": True, "words": results})
|
return web.json_response({"success": True, "words": results})
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Error searching custom words: %s", exc, exc_info=True)
|
logger.error("Error searching custom words: %s", exc, exc_info=True)
|
||||||
|
|||||||
@@ -7,11 +7,13 @@ with category filtering and enriched results including post counts.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_EMBEDDED_COMMAND_PATTERN = re.compile(r"\s/\w")
|
||||||
class CustomWordsService:
|
class CustomWordsService:
|
||||||
"""Service for autocomplete via TagFTSIndex.
|
"""Service for autocomplete via TagFTSIndex.
|
||||||
|
|
||||||
@@ -77,10 +79,47 @@ class CustomWordsService:
|
|||||||
Returns:
|
Returns:
|
||||||
List of dicts with tag_name, category, and post_count.
|
List of dicts with tag_name, category, and post_count.
|
||||||
"""
|
"""
|
||||||
|
normalized_search = search_term.strip()
|
||||||
|
if not normalized_search:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Prompt widgets should only send the active token, but guard against
|
||||||
|
# accidental full-prompt queries reaching the FTS path.
|
||||||
|
if (
|
||||||
|
"__" in normalized_search
|
||||||
|
or "," in normalized_search
|
||||||
|
or ">" in normalized_search
|
||||||
|
or "\n" in normalized_search
|
||||||
|
or "\r" in normalized_search
|
||||||
|
or _EMBEDDED_COMMAND_PATTERN.search(normalized_search)
|
||||||
|
):
|
||||||
|
logger.debug("Skipping prompt-like custom words query: %s", normalized_search)
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"LM custom words service start search=%r categories=%s limit=%s offset=%s enriched=%s",
|
||||||
|
normalized_search,
|
||||||
|
categories,
|
||||||
|
limit,
|
||||||
|
offset,
|
||||||
|
enriched,
|
||||||
|
)
|
||||||
|
|
||||||
tag_index = self._get_tag_index()
|
tag_index = self._get_tag_index()
|
||||||
if tag_index is not None:
|
if tag_index is not None:
|
||||||
|
logger.info(
|
||||||
|
"LM custom words service tag_index ready=%s indexing=%s",
|
||||||
|
getattr(tag_index, "is_ready", lambda: "unknown")(),
|
||||||
|
getattr(tag_index, "is_indexing", lambda: "unknown")(),
|
||||||
|
)
|
||||||
|
|
||||||
results = tag_index.search(
|
results = tag_index.search(
|
||||||
search_term, categories=categories, limit=limit, offset=offset
|
normalized_search, categories=categories, limit=limit, offset=offset
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"LM custom words service done search=%r result_count=%s",
|
||||||
|
normalized_search,
|
||||||
|
len(results),
|
||||||
)
|
)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
@@ -450,9 +450,9 @@ class TagFTSIndex:
|
|||||||
the tag_name, the result will include a "matched_alias" field.
|
the tag_name, the result will include a "matched_alias" field.
|
||||||
|
|
||||||
Ranking is based on a combination of:
|
Ranking is based on a combination of:
|
||||||
1. FTS5 bm25 relevance score (how well the text matches)
|
1. Exact prefix match boost (tag_name starts with query)
|
||||||
2. Post count (popularity)
|
2. Post count to preserve expected autocomplete ordering
|
||||||
3. Exact prefix match boost (tag_name starts with query)
|
3. FTS5 bm25 relevance score as a deterministic tie-breaker
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: The search query string.
|
query: The search query string.
|
||||||
@@ -464,6 +464,17 @@ class TagFTSIndex:
|
|||||||
List of dictionaries with tag_name, category, post_count,
|
List of dictionaries with tag_name, category, post_count,
|
||||||
rank_score, and optionally matched_alias.
|
rank_score, and optionally matched_alias.
|
||||||
"""
|
"""
|
||||||
|
search_started_at = time.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
"LM tag FTS search start query=%r categories=%s limit=%s offset=%s ready=%s indexing=%s",
|
||||||
|
query,
|
||||||
|
categories,
|
||||||
|
limit,
|
||||||
|
offset,
|
||||||
|
self.is_ready(),
|
||||||
|
self.is_indexing(),
|
||||||
|
)
|
||||||
|
|
||||||
# Ensure index is ready (lazy initialization)
|
# Ensure index is ready (lazy initialization)
|
||||||
if not self.ensure_ready():
|
if not self.ensure_ready():
|
||||||
if not self._warned_not_ready:
|
if not self._warned_not_ready:
|
||||||
@@ -478,71 +489,44 @@ class TagFTSIndex:
|
|||||||
if not fts_query:
|
if not fts_query:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"LM tag FTS search built query=%r fts_query=%r",
|
||||||
|
query,
|
||||||
|
fts_query,
|
||||||
|
)
|
||||||
|
|
||||||
query_lower = query.lower().strip()
|
query_lower = query.lower().strip()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info("LM tag FTS search waiting_for_lock query=%r", query)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
logger.info("LM tag FTS search acquired_lock query=%r", query)
|
||||||
conn = self._connect(readonly=True)
|
conn = self._connect(readonly=True)
|
||||||
try:
|
try:
|
||||||
# Build the SQL query with bm25 ranking
|
sql, params = self._build_search_statement(
|
||||||
# FTS5 bm25() returns negative scores, lower is better
|
query_lower=query_lower,
|
||||||
# We use -bm25() to get higher=better scores
|
fts_query=fts_query,
|
||||||
# Weights: -100.0 for exact matches, 1.0 for others
|
categories=categories,
|
||||||
# Add LOG10(post_count) weighting to boost popular tags
|
limit=limit,
|
||||||
# Use CASE to boost tag_name prefix matches above alias matches
|
offset=offset,
|
||||||
if categories:
|
|
||||||
placeholders = ",".join("?" * len(categories))
|
|
||||||
sql = f"""
|
|
||||||
SELECT t.tag_name, t.category, t.post_count, t.aliases,
|
|
||||||
CASE
|
|
||||||
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
|
|
||||||
ELSE 0
|
|
||||||
END AS is_tag_name_match,
|
|
||||||
bm25(tag_fts, -100.0, 1.0, 1.0) + LOG10(t.post_count + 1) * 10.0 AS rank_score
|
|
||||||
FROM tag_fts
|
|
||||||
JOIN tags t ON tag_fts.rowid = t.rowid
|
|
||||||
WHERE tag_fts.searchable_text MATCH ?
|
|
||||||
AND t.category IN ({placeholders})
|
|
||||||
ORDER BY is_tag_name_match DESC, rank_score DESC
|
|
||||||
LIMIT ? OFFSET ?
|
|
||||||
"""
|
|
||||||
# Escape special LIKE characters and add wildcard
|
|
||||||
query_escaped = (
|
|
||||||
query_lower.lstrip("/")
|
|
||||||
.replace("\\", "\\\\")
|
|
||||||
.replace("%", "\\%")
|
|
||||||
.replace("_", "\\_")
|
|
||||||
)
|
)
|
||||||
params = (
|
|
||||||
[query_escaped + "%", fts_query]
|
|
||||||
+ categories
|
|
||||||
+ [limit, offset]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sql = """
|
|
||||||
SELECT t.tag_name, t.category, t.post_count, t.aliases,
|
|
||||||
CASE
|
|
||||||
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
|
|
||||||
ELSE 0
|
|
||||||
END AS is_tag_name_match,
|
|
||||||
bm25(tag_fts, -100.0, 1.0, 1.0) + LOG10(t.post_count + 1) * 10.0 AS rank_score
|
|
||||||
FROM tag_fts
|
|
||||||
JOIN tags t ON tag_fts.rowid = t.rowid
|
|
||||||
WHERE tag_fts.searchable_text MATCH ?
|
|
||||||
ORDER BY is_tag_name_match DESC, rank_score DESC
|
|
||||||
LIMIT ? OFFSET ?
|
|
||||||
"""
|
|
||||||
query_escaped = (
|
|
||||||
query_lower.lstrip("/")
|
|
||||||
.replace("\\", "\\\\")
|
|
||||||
.replace("%", "\\%")
|
|
||||||
.replace("_", "\\_")
|
|
||||||
)
|
|
||||||
params = [query_escaped + "%", fts_query, limit, offset]
|
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"LM tag FTS search executing_sql query=%r query_len=%s category_count=%s",
|
||||||
|
query,
|
||||||
|
len(query_lower),
|
||||||
|
len(categories) if categories else 0,
|
||||||
|
)
|
||||||
cursor = conn.execute(sql, params)
|
cursor = conn.execute(sql, params)
|
||||||
|
logger.info("LM tag FTS search execute_returned query=%r", query)
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
logger.info(
|
||||||
|
"LM tag FTS search fetchall_returned query=%r row_count=%s",
|
||||||
|
query,
|
||||||
|
len(rows),
|
||||||
|
)
|
||||||
results = []
|
results = []
|
||||||
for row in cursor.fetchall():
|
for row in rows:
|
||||||
result = {
|
result = {
|
||||||
"tag_name": row[0],
|
"tag_name": row[0],
|
||||||
"category": row[1],
|
"category": row[1],
|
||||||
@@ -564,6 +548,13 @@ class TagFTSIndex:
|
|||||||
result["matched_alias"] = matched_alias
|
result["matched_alias"] = matched_alias
|
||||||
|
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
elapsed_ms = (time.perf_counter() - search_started_at) * 1000
|
||||||
|
logger.info(
|
||||||
|
"LM tag FTS search done query=%r result_count=%s elapsed_ms=%.2f",
|
||||||
|
query,
|
||||||
|
len(results),
|
||||||
|
elapsed_ms,
|
||||||
|
)
|
||||||
return results
|
return results
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
@@ -571,6 +562,62 @@ class TagFTSIndex:
|
|||||||
logger.debug("Tag FTS search error for query '%s': %s", query, exc)
|
logger.debug("Tag FTS search error for query '%s': %s", query, exc)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def _build_search_statement(
|
||||||
|
self,
|
||||||
|
query_lower: str,
|
||||||
|
fts_query: str,
|
||||||
|
categories: Optional[List[int]],
|
||||||
|
limit: int,
|
||||||
|
offset: int,
|
||||||
|
) -> tuple[str, list[object]]:
|
||||||
|
"""Build the SQL statement and params for a tag search."""
|
||||||
|
# Escape special LIKE characters and add wildcard
|
||||||
|
query_escaped = (
|
||||||
|
query_lower.lstrip("/")
|
||||||
|
.replace("\\", "\\\\")
|
||||||
|
.replace("%", "\\%")
|
||||||
|
.replace("_", "\\_")
|
||||||
|
)
|
||||||
|
|
||||||
|
# FTS5 bm25() returns negative scores, lower is better.
|
||||||
|
# We use -bm25() to get higher=better scores, but keep post_count as the
|
||||||
|
# primary sort within tag-name prefix matches so autocomplete ordering
|
||||||
|
# remains aligned with the existing popularity-first behavior.
|
||||||
|
if categories:
|
||||||
|
placeholders = ",".join("?" * len(categories))
|
||||||
|
sql = f"""
|
||||||
|
SELECT t.tag_name, t.category, t.post_count, t.aliases,
|
||||||
|
CASE
|
||||||
|
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
|
||||||
|
ELSE 0
|
||||||
|
END AS is_tag_name_match,
|
||||||
|
bm25(tag_fts, -100.0, 1.0, 1.0) AS rank_score
|
||||||
|
FROM tag_fts
|
||||||
|
CROSS JOIN tags t ON t.rowid = tag_fts.rowid
|
||||||
|
WHERE tag_fts.searchable_text MATCH ?
|
||||||
|
AND t.category IN ({placeholders})
|
||||||
|
ORDER BY is_tag_name_match DESC, t.post_count DESC, rank_score DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
"""
|
||||||
|
params = [query_escaped + "%", fts_query] + categories + [limit, offset]
|
||||||
|
else:
|
||||||
|
sql = """
|
||||||
|
SELECT t.tag_name, t.category, t.post_count, t.aliases,
|
||||||
|
CASE
|
||||||
|
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
|
||||||
|
ELSE 0
|
||||||
|
END AS is_tag_name_match,
|
||||||
|
bm25(tag_fts, -100.0, 1.0, 1.0) AS rank_score
|
||||||
|
FROM tag_fts
|
||||||
|
JOIN tags t ON tag_fts.rowid = t.rowid
|
||||||
|
WHERE tag_fts.searchable_text MATCH ?
|
||||||
|
ORDER BY is_tag_name_match DESC, t.post_count DESC, rank_score DESC
|
||||||
|
LIMIT ? OFFSET ?
|
||||||
|
"""
|
||||||
|
params = [query_escaped + "%", fts_query, limit, offset]
|
||||||
|
|
||||||
|
return sql, params
|
||||||
|
|
||||||
def _find_matched_alias(
|
def _find_matched_alias(
|
||||||
self, query: str, tag_name: str, aliases_str: str
|
self, query: str, tag_name: str, aliases_str: str
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
|
|||||||
@@ -126,6 +126,31 @@ describe('AutoComplete widget interactions', () => {
|
|||||||
expect(caretHelperInstance.getCursorOffset).toHaveBeenCalled();
|
expect(caretHelperInstance.getCursorOffset).toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('deduplicates duplicate-equivalent query variations before issuing requests', async () => {
|
||||||
|
vi.useFakeTimers();
|
||||||
|
|
||||||
|
fetchApiMock.mockResolvedValue({
|
||||||
|
json: () => Promise.resolve({ success: true, words: [] }),
|
||||||
|
});
|
||||||
|
|
||||||
|
caretHelperInstance.getBeforeCursor.mockReturnValue('Example');
|
||||||
|
|
||||||
|
const input = document.createElement('textarea');
|
||||||
|
document.body.append(input);
|
||||||
|
|
||||||
|
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||||
|
new AutoComplete(input, 'prompt', { debounceDelay: 0, showPreview: false, minChars: 1 });
|
||||||
|
|
||||||
|
input.value = 'Example';
|
||||||
|
input.dispatchEvent(new Event('input', { bubbles: true }));
|
||||||
|
|
||||||
|
await vi.runAllTimersAsync();
|
||||||
|
await Promise.resolve();
|
||||||
|
|
||||||
|
expect(fetchApiMock).toHaveBeenCalledTimes(1);
|
||||||
|
expect(fetchApiMock).toHaveBeenCalledWith('/lm/custom-words/search?enriched=true&search=Example&limit=100');
|
||||||
|
});
|
||||||
|
|
||||||
it('inserts the selected LoRA with usage tip strengths and restores focus', async () => {
|
it('inserts the selected LoRA with usage tip strengths and restores focus', async () => {
|
||||||
fetchApiMock.mockImplementation((url) => {
|
fetchApiMock.mockImplementation((url) => {
|
||||||
if (url.includes('usage-tips-by-path')) {
|
if (url.includes('usage-tips-by-path')) {
|
||||||
@@ -244,6 +269,55 @@ describe('AutoComplete widget interactions', () => {
|
|||||||
expect(inputListener).not.toHaveBeenCalled();
|
expect(inputListener).not.toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('shows the full command list when typing a single slash', async () => {
|
||||||
|
const input = document.createElement('textarea');
|
||||||
|
input.value = '/';
|
||||||
|
input.selectionStart = input.value.length;
|
||||||
|
document.body.append(input);
|
||||||
|
|
||||||
|
caretHelperInstance.getBeforeCursor.mockReturnValue('/');
|
||||||
|
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||||
|
|
||||||
|
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||||
|
const autoComplete = new AutoComplete(input,'prompt', { showPreview: false, minChars: 1 });
|
||||||
|
|
||||||
|
input.dispatchEvent(new Event('input', { bubbles: true }));
|
||||||
|
|
||||||
|
const commandNames = autoComplete.items.map((item) => item.command);
|
||||||
|
|
||||||
|
expect(commandNames).toContain('/character');
|
||||||
|
expect(commandNames).toContain('/char');
|
||||||
|
expect(commandNames).toContain('/artist');
|
||||||
|
expect(commandNames).toContain('/general');
|
||||||
|
expect(commandNames).toContain('/copyright');
|
||||||
|
expect(commandNames).toContain('/meta');
|
||||||
|
expect(commandNames).toContain('/species');
|
||||||
|
expect(commandNames).toContain('/lore');
|
||||||
|
expect(commandNames).toContain('/emb');
|
||||||
|
expect(commandNames).toContain('/embedding');
|
||||||
|
expect(commandNames).toContain('/wild');
|
||||||
|
expect(commandNames).toContain('/wildcard');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('renders every command item when slash opens the command list', async () => {
|
||||||
|
const input = document.createElement('textarea');
|
||||||
|
input.value = '/';
|
||||||
|
input.selectionStart = input.value.length;
|
||||||
|
document.body.append(input);
|
||||||
|
|
||||||
|
caretHelperInstance.getBeforeCursor.mockReturnValue('/');
|
||||||
|
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||||
|
|
||||||
|
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||||
|
const autoComplete = new AutoComplete(input, 'prompt', { showPreview: false, minChars: 1 });
|
||||||
|
|
||||||
|
input.dispatchEvent(new Event('input', { bubbles: true }));
|
||||||
|
|
||||||
|
const renderedCommands = autoComplete.contentContainer.querySelectorAll('.lm-autocomplete-command-name');
|
||||||
|
|
||||||
|
expect(renderedCommands).toHaveLength(autoComplete.items.length);
|
||||||
|
});
|
||||||
|
|
||||||
it('accepts the selected suggestion with Enter', async () => {
|
it('accepts the selected suggestion with Enter', async () => {
|
||||||
caretHelperInstance.getBeforeCursor.mockReturnValue('example');
|
caretHelperInstance.getBeforeCursor.mockReturnValue('example');
|
||||||
|
|
||||||
@@ -1095,6 +1169,7 @@ describe('AutoComplete widget interactions', () => {
|
|||||||
minChars: 1,
|
minChars: 1,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
fetchApiMock.mockClear();
|
||||||
input.dispatchEvent(new Event('input', { bubbles: true }));
|
input.dispatchEvent(new Event('input', { bubbles: true }));
|
||||||
await vi.runAllTimersAsync();
|
await vi.runAllTimersAsync();
|
||||||
await Promise.resolve();
|
await Promise.resolve();
|
||||||
@@ -1133,6 +1208,61 @@ describe('AutoComplete widget interactions', () => {
|
|||||||
expect(input.setSelectionRange).toHaveBeenCalled();
|
expect(input.setSelectionRange).toHaveBeenCalled();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it('does not reopen autocomplete on blur after inserting a wildcard literal', async () => {
|
||||||
|
const input = document.createElement('textarea');
|
||||||
|
input.value = '__flower__,';
|
||||||
|
input.selectionStart = input.value.length;
|
||||||
|
document.body.append(input);
|
||||||
|
|
||||||
|
caretHelperInstance.getBeforeCursor.mockReturnValue('__flower__,');
|
||||||
|
|
||||||
|
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||||
|
const autoComplete = new AutoComplete(input,'prompt', {
|
||||||
|
debounceDelay: 0,
|
||||||
|
showPreview: false,
|
||||||
|
minChars: 1,
|
||||||
|
});
|
||||||
|
|
||||||
|
const hideSpy = vi.spyOn(autoComplete, 'hide');
|
||||||
|
input.dispatchEvent(new Event('blur', { bubbles: true }));
|
||||||
|
|
||||||
|
expect(fetchApiMock).not.toHaveBeenCalled();
|
||||||
|
expect(hideSpy).toHaveBeenCalled();
|
||||||
|
expect(autoComplete.isVisible).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('treats a command after a wildcard literal as the active token', async () => {
|
||||||
|
vi.useFakeTimers();
|
||||||
|
|
||||||
|
fetchApiMock.mockResolvedValue({
|
||||||
|
json: () => Promise.resolve({
|
||||||
|
success: true,
|
||||||
|
words: [{ tag_name: 'flower_field', category: 4, post_count: 1234 }],
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
const input = document.createElement('textarea');
|
||||||
|
input.value = '__flower__ /character f';
|
||||||
|
input.selectionStart = input.value.length;
|
||||||
|
document.body.append(input);
|
||||||
|
|
||||||
|
caretHelperInstance.getBeforeCursor.mockReturnValue('__flower__ /character f');
|
||||||
|
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||||
|
|
||||||
|
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||||
|
const autoComplete = new AutoComplete(input,'prompt', {
|
||||||
|
debounceDelay: 0,
|
||||||
|
showPreview: false,
|
||||||
|
minChars: 1,
|
||||||
|
});
|
||||||
|
|
||||||
|
input.dispatchEvent(new Event('input', { bubbles: true }));
|
||||||
|
await vi.runAllTimersAsync();
|
||||||
|
await Promise.resolve();
|
||||||
|
|
||||||
|
expect(autoComplete.getSearchTerm(input.value)).toBe('/character f');
|
||||||
|
});
|
||||||
|
|
||||||
it('invalidates stale autocomplete metadata and falls back to delimiter-based matching', async () => {
|
it('invalidates stale autocomplete metadata and falls back to delimiter-based matching', async () => {
|
||||||
settingGetMock.mockImplementation((key) => {
|
settingGetMock.mockImplementation((key) => {
|
||||||
if (key === 'loramanager.autocomplete_append_comma') {
|
if (key === 'loramanager.autocomplete_append_comma') {
|
||||||
|
|||||||
@@ -94,6 +94,19 @@ class TestCustomWordsService:
|
|||||||
results = service.search_words("test")
|
results = service.search_words("test")
|
||||||
assert mock_tag_index.called
|
assert mock_tag_index.called
|
||||||
|
|
||||||
|
def test_search_words_skips_prompt_like_queries(self):
|
||||||
|
service = CustomWordsService.__new__(CustomWordsService)
|
||||||
|
mock_tag_index = MockTagFTSIndex()
|
||||||
|
|
||||||
|
def mock_get_index():
|
||||||
|
return mock_tag_index
|
||||||
|
|
||||||
|
service._get_tag_index = mock_get_index
|
||||||
|
|
||||||
|
results = service.search_words("__flower__ /character f")
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
assert mock_tag_index.called is False
|
||||||
|
|
||||||
class MockTagFTSIndex:
|
class MockTagFTSIndex:
|
||||||
"""Mock TagFTSIndex for testing."""
|
"""Mock TagFTSIndex for testing."""
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Tests for TagFTSIndex functionality."""
|
"""Tests for TagFTSIndex functionality."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sqlite3
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@@ -173,6 +174,40 @@ class TestTagFTSIndexSearch:
|
|||||||
assert len(results) >= 1
|
assert len(results) >= 1
|
||||||
assert all(r["category"] in [4, 11] for r in results)
|
assert all(r["category"] in [4, 11] for r in results)
|
||||||
|
|
||||||
|
def test_search_with_category_filter_uses_fts_first_plan(self, populated_fts):
|
||||||
|
"""Category-filtered searches should start from FTS hits, not category scans."""
|
||||||
|
sql, params = populated_fts._build_search_statement(
|
||||||
|
query_lower="f",
|
||||||
|
fts_query="f*",
|
||||||
|
categories=[4, 11],
|
||||||
|
limit=20,
|
||||||
|
offset=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
conn = sqlite3.connect(f"file:{populated_fts.get_database_path()}?mode=ro", uri=True)
|
||||||
|
try:
|
||||||
|
plan_rows = conn.execute(f"EXPLAIN QUERY PLAN {sql}", params).fetchall()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
plan_details = [row[3] for row in plan_rows]
|
||||||
|
assert any(detail.startswith("SCAN tag_fts VIRTUAL TABLE INDEX") for detail in plan_details)
|
||||||
|
assert any("SEARCH t USING INTEGER PRIMARY KEY" in detail for detail in plan_details)
|
||||||
|
assert not any("SEARCH t USING INDEX idx_tags_category" in detail for detail in plan_details)
|
||||||
|
|
||||||
|
def test_search_statement_uses_post_count_as_tie_breaker(self, populated_fts):
|
||||||
|
"""Search ranking should use popularity as a secondary sort key."""
|
||||||
|
sql, _ = populated_fts._build_search_statement(
|
||||||
|
query_lower="f",
|
||||||
|
fts_query="f*",
|
||||||
|
categories=[4, 11],
|
||||||
|
limit=20,
|
||||||
|
offset=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "ORDER BY is_tag_name_match DESC, t.post_count DESC, rank_score DESC" in sql
|
||||||
|
assert "LOG10" not in sql
|
||||||
|
|
||||||
def test_search_with_category_filter_excludes_others(self, populated_fts):
|
def test_search_with_category_filter_excludes_others(self, populated_fts):
|
||||||
"""Test that category filter excludes other categories."""
|
"""Test that category filter excludes other categories."""
|
||||||
# Search for "hi" but only in general category
|
# Search for "hi" but only in general category
|
||||||
|
|||||||
@@ -366,6 +366,7 @@ class AutoComplete {
|
|||||||
this.previewTooltip = null;
|
this.previewTooltip = null;
|
||||||
this.previewTooltipPromise = null;
|
this.previewTooltipPromise = null;
|
||||||
this.searchType = null;
|
this.searchType = null;
|
||||||
|
this.suppressAutocompleteOnce = false;
|
||||||
|
|
||||||
// Virtual scrolling state
|
// Virtual scrolling state
|
||||||
this.virtualScrollOffset = 0;
|
this.virtualScrollOffset = 0;
|
||||||
@@ -505,6 +506,11 @@ class AutoComplete {
|
|||||||
bindEvents() {
|
bindEvents() {
|
||||||
// Handle input changes
|
// Handle input changes
|
||||||
this.onInput = (e) => {
|
this.onInput = (e) => {
|
||||||
|
if (this.suppressAutocompleteOnce) {
|
||||||
|
this.suppressAutocompleteOnce = false;
|
||||||
|
this.hide();
|
||||||
|
return;
|
||||||
|
}
|
||||||
this.handleInput(e.target.value);
|
this.handleInput(e.target.value);
|
||||||
};
|
};
|
||||||
this.inputElement.addEventListener('input', this.onInput);
|
this.inputElement.addEventListener('input', this.onInput);
|
||||||
@@ -521,6 +527,7 @@ class AutoComplete {
|
|||||||
const formattedValue = formatAutocompleteTextOnBlur(this.inputElement.value);
|
const formattedValue = formatAutocompleteTextOnBlur(this.inputElement.value);
|
||||||
if (formattedValue !== this.inputElement.value) {
|
if (formattedValue !== this.inputElement.value) {
|
||||||
this.inputElement.value = formattedValue;
|
this.inputElement.value = formattedValue;
|
||||||
|
this.suppressAutocompleteOnce = true;
|
||||||
this.inputElement.dispatchEvent(new Event('input', { bubbles: true }));
|
this.inputElement.dispatchEvent(new Event('input', { bubbles: true }));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -725,9 +732,24 @@ class AutoComplete {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const rawText = beforeCursor.substring(start);
|
const rawText = beforeCursor.substring(start);
|
||||||
const text = rawText.trim();
|
|
||||||
const leadingWhitespaceLength = rawText.length - rawText.trimStart().length;
|
const leadingWhitespaceLength = rawText.length - rawText.trimStart().length;
|
||||||
const trimmedStart = start + leadingWhitespaceLength;
|
const trimmedStart = start + leadingWhitespaceLength;
|
||||||
|
const text = rawText.trim();
|
||||||
|
|
||||||
|
if (this.modelType === 'prompt') {
|
||||||
|
const tokenRange = this._getPromptTokenRange(rawText, trimmedStart, caretPos);
|
||||||
|
if (tokenRange) {
|
||||||
|
return {
|
||||||
|
start: tokenRange.start,
|
||||||
|
trimmedStart: tokenRange.trimmedStart,
|
||||||
|
end: caretPos,
|
||||||
|
beforeCursor,
|
||||||
|
rawText: tokenRange.rawText,
|
||||||
|
text: tokenRange.text,
|
||||||
|
tokenType: tokenRange.tokenType,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
start,
|
start,
|
||||||
@@ -739,6 +761,73 @@ class AutoComplete {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_getPromptTokenRange(rawText = '', trimmedStart = 0, caretPos = 0) {
|
||||||
|
const trimmedText = rawText.trim();
|
||||||
|
if (!trimmedText) {
|
||||||
|
return {
|
||||||
|
start: trimmedStart,
|
||||||
|
trimmedStart,
|
||||||
|
rawText: '',
|
||||||
|
text: '',
|
||||||
|
tokenType: 'empty',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const commandOffset = trimmedText.startsWith('/')
|
||||||
|
? 0
|
||||||
|
: trimmedText.lastIndexOf(' /');
|
||||||
|
if (commandOffset !== -1) {
|
||||||
|
const normalizedCommandOffset = commandOffset === 0 ? 0 : commandOffset + 1;
|
||||||
|
const commandText = trimmedText.slice(normalizedCommandOffset);
|
||||||
|
const commandStart = trimmedStart + normalizedCommandOffset;
|
||||||
|
return {
|
||||||
|
start: commandStart,
|
||||||
|
trimmedStart: commandStart,
|
||||||
|
rawText: commandText,
|
||||||
|
text: commandText,
|
||||||
|
tokenType: commandText === '/' ? 'empty_command_trigger' : 'command',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const wildcardMatch = trimmedText.match(/(?:^|\s)(__[\w\s.\-+/*\\]+?__)$/);
|
||||||
|
if (wildcardMatch) {
|
||||||
|
const wildcardText = wildcardMatch[1];
|
||||||
|
const wildcardOffset = trimmedText.lastIndexOf(wildcardText);
|
||||||
|
const wildcardStart = trimmedStart + wildcardOffset;
|
||||||
|
return {
|
||||||
|
start: wildcardStart,
|
||||||
|
trimmedStart: wildcardStart,
|
||||||
|
rawText: wildcardText,
|
||||||
|
text: '',
|
||||||
|
tokenType: 'wildcard_literal',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const embeddingOffset = trimmedText.search(/(?:^|\s)emb:[^\s]*$/i);
|
||||||
|
if (embeddingOffset !== -1) {
|
||||||
|
const normalizedEmbeddingOffset = trimmedText.slice(embeddingOffset).startsWith(' ')
|
||||||
|
? embeddingOffset + 1
|
||||||
|
: embeddingOffset;
|
||||||
|
const embeddingText = trimmedText.slice(normalizedEmbeddingOffset);
|
||||||
|
const embeddingStart = trimmedStart + normalizedEmbeddingOffset;
|
||||||
|
return {
|
||||||
|
start: embeddingStart,
|
||||||
|
trimmedStart: embeddingStart,
|
||||||
|
rawText: embeddingText,
|
||||||
|
text: embeddingText,
|
||||||
|
tokenType: 'embedding_literal',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
start: trimmedStart,
|
||||||
|
trimmedStart,
|
||||||
|
rawText,
|
||||||
|
text: trimmedText,
|
||||||
|
tokenType: 'tag_text',
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
_getHardBoundaryStart(beforeCursor = '') {
|
_getHardBoundaryStart(beforeCursor = '') {
|
||||||
const lastComma = beforeCursor.lastIndexOf(',');
|
const lastComma = beforeCursor.lastIndexOf(',');
|
||||||
const lastAngle = beforeCursor.lastIndexOf('>');
|
const lastAngle = beforeCursor.lastIndexOf('>');
|
||||||
@@ -890,6 +979,32 @@ class AutoComplete {
|
|||||||
return Array.from(variations).filter(v => v.length >= this.options.minChars);
|
return Array.from(variations).filter(v => v.length >= this.options.minChars);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_normalizeQueryForRequest(term = '') {
|
||||||
|
return term.trim().toLowerCase();
|
||||||
|
}
|
||||||
|
|
||||||
|
_getQueriesToExecute(term = '') {
|
||||||
|
const queryVariations = this._generateQueryVariations(term);
|
||||||
|
const uniqueQueries = [];
|
||||||
|
const seen = new Set();
|
||||||
|
|
||||||
|
for (const query of queryVariations) {
|
||||||
|
const normalized = this._normalizeQueryForRequest(query);
|
||||||
|
if (!normalized || seen.has(normalized)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
seen.add(normalized);
|
||||||
|
uniqueQueries.push(query);
|
||||||
|
|
||||||
|
if (uniqueQueries.length >= 4) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return uniqueQueries;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get display text for an item (without extension for models)
|
* Get display text for an item (without extension for models)
|
||||||
* @param {string|Object} item - Item to get display text from
|
* @param {string|Object} item - Item to get display text from
|
||||||
@@ -1041,18 +1156,17 @@ class AutoComplete {
|
|||||||
endpoint = `/lm/${this.modelType}/relative-paths`;
|
endpoint = `/lm/${this.modelType}/relative-paths`;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate multiple query variations for better matching
|
// Generate multiple query variations for better matching, but avoid
|
||||||
const queryVariations = this._generateQueryVariations(term);
|
// sending duplicate-equivalent requests that normalize to the same
|
||||||
|
// backend search term.
|
||||||
|
const queriesToExecute = this._getQueriesToExecute(term);
|
||||||
|
|
||||||
if (queryVariations.length === 0) {
|
if (queriesToExecute.length === 0) {
|
||||||
this.items = [];
|
this.items = [];
|
||||||
this.hide();
|
this.hide();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Limit the number of parallel queries to avoid overwhelming the server
|
|
||||||
const queriesToExecute = queryVariations.slice(0, 4);
|
|
||||||
|
|
||||||
// Execute all queries in parallel
|
// Execute all queries in parallel
|
||||||
const searchPromises = queriesToExecute.map(async (query) => {
|
const searchPromises = queriesToExecute.map(async (query) => {
|
||||||
const url = endpoint.includes('?')
|
const url = endpoint.includes('?')
|
||||||
@@ -1190,20 +1304,15 @@ class AutoComplete {
|
|||||||
|
|
||||||
const filterLower = filter.toLowerCase();
|
const filterLower = filter.toLowerCase();
|
||||||
|
|
||||||
// Get unique commands (avoid duplicates like /char and /character)
|
|
||||||
const seenLabels = new Set();
|
|
||||||
const commands = [];
|
const commands = [];
|
||||||
|
|
||||||
for (const [cmd, info] of Object.entries(TAG_COMMANDS)) {
|
for (const [cmd, info] of Object.entries(TAG_COMMANDS)) {
|
||||||
if (seenLabels.has(info.label)) continue;
|
|
||||||
|
|
||||||
// Filter out toggle commands that don't meet their condition
|
// Filter out toggle commands that don't meet their condition
|
||||||
if (info.type === 'toggle_setting' && info.condition) {
|
if (info.type === 'toggle_setting' && info.condition) {
|
||||||
if (!info.condition()) continue;
|
if (!info.condition()) continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!filter || cmd.slice(1).startsWith(filterLower)) {
|
if (!filter || cmd.slice(1).startsWith(filterLower)) {
|
||||||
seenLabels.add(info.label);
|
|
||||||
commands.push({ command: cmd, ...info });
|
commands.push({ command: cmd, ...info });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1288,9 +1397,17 @@ class AutoComplete {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Auto-select immediately so accept keys remain stable.
|
// Auto-select immediately so accept keys remain stable.
|
||||||
|
// In virtual-scroll mode, calling selectItem() before the dropdown is
|
||||||
|
// visible can see a zero-height container and incorrectly replace the
|
||||||
|
// full command list with a partially virtualized slice.
|
||||||
if (this.items.length > 0) {
|
if (this.items.length > 0) {
|
||||||
|
this.selectedIndex = 0;
|
||||||
|
if (this.contentContainer) {
|
||||||
|
this._applyItemSelection(0);
|
||||||
|
} else {
|
||||||
this.selectItem(0);
|
this.selectItem(0);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Update virtual scroll height for virtual scrolling mode
|
// Update virtual scroll height for virtual scrolling mode
|
||||||
if (this.contentContainer) {
|
if (this.contentContainer) {
|
||||||
@@ -1635,8 +1752,7 @@ class AutoComplete {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const queryVariations = this._generateQueryVariations(this.currentSearchTerm);
|
const queriesToExecute = this._getQueriesToExecute(this.currentSearchTerm);
|
||||||
const queriesToExecute = queryVariations.slice(0, 4);
|
|
||||||
const offset = this.items.length;
|
const offset = this.items.length;
|
||||||
|
|
||||||
// Execute all queries in parallel with offset
|
// Execute all queries in parallel with offset
|
||||||
|
|||||||
Reference in New Issue
Block a user