mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 08:26:45 -03:00
fix(autocomplete): reduce tag search overhead (#895)
This commit is contained in:
@@ -13,6 +13,7 @@ import contextlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
@@ -2433,6 +2434,7 @@ class CustomWordsHandler:
|
||||
even without category filtering.
|
||||
"""
|
||||
try:
|
||||
started_at = time.perf_counter()
|
||||
search_term = request.query.get("search", "")
|
||||
limit = int(request.query.get("limit", "20"))
|
||||
offset = max(0, int(request.query.get("offset", "0")))
|
||||
@@ -2444,6 +2446,16 @@ class CustomWordsHandler:
|
||||
if category_param:
|
||||
categories = self._parse_category_param(category_param)
|
||||
|
||||
logger.info(
|
||||
"LM custom words request search=%r category_param=%r categories=%s limit=%s offset=%s enriched=%s",
|
||||
search_term,
|
||||
category_param,
|
||||
categories,
|
||||
limit,
|
||||
offset,
|
||||
enriched_param,
|
||||
)
|
||||
|
||||
results = self._service.search_words(
|
||||
search_term,
|
||||
limit,
|
||||
@@ -2452,6 +2464,14 @@ class CustomWordsHandler:
|
||||
enriched=enriched_param,
|
||||
)
|
||||
|
||||
elapsed_ms = (time.perf_counter() - started_at) * 1000
|
||||
logger.info(
|
||||
"LM custom words response search=%r result_count=%s elapsed_ms=%.2f",
|
||||
search_term,
|
||||
len(results),
|
||||
elapsed_ms,
|
||||
)
|
||||
|
||||
return web.json_response({"success": True, "words": results})
|
||||
except Exception as exc:
|
||||
logger.error("Error searching custom words: %s", exc, exc_info=True)
|
||||
|
||||
@@ -7,11 +7,13 @@ with category filtering and enriched results including post counts.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_EMBEDDED_COMMAND_PATTERN = re.compile(r"\s/\w")
|
||||
class CustomWordsService:
|
||||
"""Service for autocomplete via TagFTSIndex.
|
||||
|
||||
@@ -77,10 +79,47 @@ class CustomWordsService:
|
||||
Returns:
|
||||
List of dicts with tag_name, category, and post_count.
|
||||
"""
|
||||
normalized_search = search_term.strip()
|
||||
if not normalized_search:
|
||||
return []
|
||||
|
||||
# Prompt widgets should only send the active token, but guard against
|
||||
# accidental full-prompt queries reaching the FTS path.
|
||||
if (
|
||||
"__" in normalized_search
|
||||
or "," in normalized_search
|
||||
or ">" in normalized_search
|
||||
or "\n" in normalized_search
|
||||
or "\r" in normalized_search
|
||||
or _EMBEDDED_COMMAND_PATTERN.search(normalized_search)
|
||||
):
|
||||
logger.debug("Skipping prompt-like custom words query: %s", normalized_search)
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
"LM custom words service start search=%r categories=%s limit=%s offset=%s enriched=%s",
|
||||
normalized_search,
|
||||
categories,
|
||||
limit,
|
||||
offset,
|
||||
enriched,
|
||||
)
|
||||
|
||||
tag_index = self._get_tag_index()
|
||||
if tag_index is not None:
|
||||
logger.info(
|
||||
"LM custom words service tag_index ready=%s indexing=%s",
|
||||
getattr(tag_index, "is_ready", lambda: "unknown")(),
|
||||
getattr(tag_index, "is_indexing", lambda: "unknown")(),
|
||||
)
|
||||
|
||||
results = tag_index.search(
|
||||
search_term, categories=categories, limit=limit, offset=offset
|
||||
normalized_search, categories=categories, limit=limit, offset=offset
|
||||
)
|
||||
logger.info(
|
||||
"LM custom words service done search=%r result_count=%s",
|
||||
normalized_search,
|
||||
len(results),
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
@@ -450,9 +450,9 @@ class TagFTSIndex:
|
||||
the tag_name, the result will include a "matched_alias" field.
|
||||
|
||||
Ranking is based on a combination of:
|
||||
1. FTS5 bm25 relevance score (how well the text matches)
|
||||
2. Post count (popularity)
|
||||
3. Exact prefix match boost (tag_name starts with query)
|
||||
1. Exact prefix match boost (tag_name starts with query)
|
||||
2. Post count to preserve expected autocomplete ordering
|
||||
3. FTS5 bm25 relevance score as a deterministic tie-breaker
|
||||
|
||||
Args:
|
||||
query: The search query string.
|
||||
@@ -464,6 +464,17 @@ class TagFTSIndex:
|
||||
List of dictionaries with tag_name, category, post_count,
|
||||
rank_score, and optionally matched_alias.
|
||||
"""
|
||||
search_started_at = time.perf_counter()
|
||||
logger.info(
|
||||
"LM tag FTS search start query=%r categories=%s limit=%s offset=%s ready=%s indexing=%s",
|
||||
query,
|
||||
categories,
|
||||
limit,
|
||||
offset,
|
||||
self.is_ready(),
|
||||
self.is_indexing(),
|
||||
)
|
||||
|
||||
# Ensure index is ready (lazy initialization)
|
||||
if not self.ensure_ready():
|
||||
if not self._warned_not_ready:
|
||||
@@ -478,71 +489,44 @@ class TagFTSIndex:
|
||||
if not fts_query:
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
"LM tag FTS search built query=%r fts_query=%r",
|
||||
query,
|
||||
fts_query,
|
||||
)
|
||||
|
||||
query_lower = query.lower().strip()
|
||||
|
||||
try:
|
||||
logger.info("LM tag FTS search waiting_for_lock query=%r", query)
|
||||
with self._lock:
|
||||
logger.info("LM tag FTS search acquired_lock query=%r", query)
|
||||
conn = self._connect(readonly=True)
|
||||
try:
|
||||
# Build the SQL query with bm25 ranking
|
||||
# FTS5 bm25() returns negative scores, lower is better
|
||||
# We use -bm25() to get higher=better scores
|
||||
# Weights: -100.0 for exact matches, 1.0 for others
|
||||
# Add LOG10(post_count) weighting to boost popular tags
|
||||
# Use CASE to boost tag_name prefix matches above alias matches
|
||||
if categories:
|
||||
placeholders = ",".join("?" * len(categories))
|
||||
sql = f"""
|
||||
SELECT t.tag_name, t.category, t.post_count, t.aliases,
|
||||
CASE
|
||||
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
|
||||
ELSE 0
|
||||
END AS is_tag_name_match,
|
||||
bm25(tag_fts, -100.0, 1.0, 1.0) + LOG10(t.post_count + 1) * 10.0 AS rank_score
|
||||
FROM tag_fts
|
||||
JOIN tags t ON tag_fts.rowid = t.rowid
|
||||
WHERE tag_fts.searchable_text MATCH ?
|
||||
AND t.category IN ({placeholders})
|
||||
ORDER BY is_tag_name_match DESC, rank_score DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
# Escape special LIKE characters and add wildcard
|
||||
query_escaped = (
|
||||
query_lower.lstrip("/")
|
||||
.replace("\\", "\\\\")
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
sql, params = self._build_search_statement(
|
||||
query_lower=query_lower,
|
||||
fts_query=fts_query,
|
||||
categories=categories,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
params = (
|
||||
[query_escaped + "%", fts_query]
|
||||
+ categories
|
||||
+ [limit, offset]
|
||||
)
|
||||
else:
|
||||
sql = """
|
||||
SELECT t.tag_name, t.category, t.post_count, t.aliases,
|
||||
CASE
|
||||
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
|
||||
ELSE 0
|
||||
END AS is_tag_name_match,
|
||||
bm25(tag_fts, -100.0, 1.0, 1.0) + LOG10(t.post_count + 1) * 10.0 AS rank_score
|
||||
FROM tag_fts
|
||||
JOIN tags t ON tag_fts.rowid = t.rowid
|
||||
WHERE tag_fts.searchable_text MATCH ?
|
||||
ORDER BY is_tag_name_match DESC, rank_score DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
query_escaped = (
|
||||
query_lower.lstrip("/")
|
||||
.replace("\\", "\\\\")
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
)
|
||||
params = [query_escaped + "%", fts_query, limit, offset]
|
||||
|
||||
logger.info(
|
||||
"LM tag FTS search executing_sql query=%r query_len=%s category_count=%s",
|
||||
query,
|
||||
len(query_lower),
|
||||
len(categories) if categories else 0,
|
||||
)
|
||||
cursor = conn.execute(sql, params)
|
||||
logger.info("LM tag FTS search execute_returned query=%r", query)
|
||||
rows = cursor.fetchall()
|
||||
logger.info(
|
||||
"LM tag FTS search fetchall_returned query=%r row_count=%s",
|
||||
query,
|
||||
len(rows),
|
||||
)
|
||||
results = []
|
||||
for row in cursor.fetchall():
|
||||
for row in rows:
|
||||
result = {
|
||||
"tag_name": row[0],
|
||||
"category": row[1],
|
||||
@@ -564,6 +548,13 @@ class TagFTSIndex:
|
||||
result["matched_alias"] = matched_alias
|
||||
|
||||
results.append(result)
|
||||
elapsed_ms = (time.perf_counter() - search_started_at) * 1000
|
||||
logger.info(
|
||||
"LM tag FTS search done query=%r result_count=%s elapsed_ms=%.2f",
|
||||
query,
|
||||
len(results),
|
||||
elapsed_ms,
|
||||
)
|
||||
return results
|
||||
finally:
|
||||
conn.close()
|
||||
@@ -571,6 +562,62 @@ class TagFTSIndex:
|
||||
logger.debug("Tag FTS search error for query '%s': %s", query, exc)
|
||||
return []
|
||||
|
||||
def _build_search_statement(
|
||||
self,
|
||||
query_lower: str,
|
||||
fts_query: str,
|
||||
categories: Optional[List[int]],
|
||||
limit: int,
|
||||
offset: int,
|
||||
) -> tuple[str, list[object]]:
|
||||
"""Build the SQL statement and params for a tag search."""
|
||||
# Escape special LIKE characters and add wildcard
|
||||
query_escaped = (
|
||||
query_lower.lstrip("/")
|
||||
.replace("\\", "\\\\")
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
)
|
||||
|
||||
# FTS5 bm25() returns negative scores, lower is better.
|
||||
# We use -bm25() to get higher=better scores, but keep post_count as the
|
||||
# primary sort within tag-name prefix matches so autocomplete ordering
|
||||
# remains aligned with the existing popularity-first behavior.
|
||||
if categories:
|
||||
placeholders = ",".join("?" * len(categories))
|
||||
sql = f"""
|
||||
SELECT t.tag_name, t.category, t.post_count, t.aliases,
|
||||
CASE
|
||||
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
|
||||
ELSE 0
|
||||
END AS is_tag_name_match,
|
||||
bm25(tag_fts, -100.0, 1.0, 1.0) AS rank_score
|
||||
FROM tag_fts
|
||||
CROSS JOIN tags t ON t.rowid = tag_fts.rowid
|
||||
WHERE tag_fts.searchable_text MATCH ?
|
||||
AND t.category IN ({placeholders})
|
||||
ORDER BY is_tag_name_match DESC, t.post_count DESC, rank_score DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
params = [query_escaped + "%", fts_query] + categories + [limit, offset]
|
||||
else:
|
||||
sql = """
|
||||
SELECT t.tag_name, t.category, t.post_count, t.aliases,
|
||||
CASE
|
||||
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
|
||||
ELSE 0
|
||||
END AS is_tag_name_match,
|
||||
bm25(tag_fts, -100.0, 1.0, 1.0) AS rank_score
|
||||
FROM tag_fts
|
||||
JOIN tags t ON tag_fts.rowid = t.rowid
|
||||
WHERE tag_fts.searchable_text MATCH ?
|
||||
ORDER BY is_tag_name_match DESC, t.post_count DESC, rank_score DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
params = [query_escaped + "%", fts_query, limit, offset]
|
||||
|
||||
return sql, params
|
||||
|
||||
def _find_matched_alias(
|
||||
self, query: str, tag_name: str, aliases_str: str
|
||||
) -> Optional[str]:
|
||||
|
||||
@@ -126,6 +126,31 @@ describe('AutoComplete widget interactions', () => {
|
||||
expect(caretHelperInstance.getCursorOffset).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('deduplicates duplicate-equivalent query variations before issuing requests', async () => {
|
||||
vi.useFakeTimers();
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({ success: true, words: [] }),
|
||||
});
|
||||
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('Example');
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
document.body.append(input);
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
new AutoComplete(input, 'prompt', { debounceDelay: 0, showPreview: false, minChars: 1 });
|
||||
|
||||
input.value = 'Example';
|
||||
input.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
|
||||
await vi.runAllTimersAsync();
|
||||
await Promise.resolve();
|
||||
|
||||
expect(fetchApiMock).toHaveBeenCalledTimes(1);
|
||||
expect(fetchApiMock).toHaveBeenCalledWith('/lm/custom-words/search?enriched=true&search=Example&limit=100');
|
||||
});
|
||||
|
||||
it('inserts the selected LoRA with usage tip strengths and restores focus', async () => {
|
||||
fetchApiMock.mockImplementation((url) => {
|
||||
if (url.includes('usage-tips-by-path')) {
|
||||
@@ -244,6 +269,55 @@ describe('AutoComplete widget interactions', () => {
|
||||
expect(inputListener).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('shows the full command list when typing a single slash', async () => {
|
||||
const input = document.createElement('textarea');
|
||||
input.value = '/';
|
||||
input.selectionStart = input.value.length;
|
||||
document.body.append(input);
|
||||
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('/');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input,'prompt', { showPreview: false, minChars: 1 });
|
||||
|
||||
input.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
|
||||
const commandNames = autoComplete.items.map((item) => item.command);
|
||||
|
||||
expect(commandNames).toContain('/character');
|
||||
expect(commandNames).toContain('/char');
|
||||
expect(commandNames).toContain('/artist');
|
||||
expect(commandNames).toContain('/general');
|
||||
expect(commandNames).toContain('/copyright');
|
||||
expect(commandNames).toContain('/meta');
|
||||
expect(commandNames).toContain('/species');
|
||||
expect(commandNames).toContain('/lore');
|
||||
expect(commandNames).toContain('/emb');
|
||||
expect(commandNames).toContain('/embedding');
|
||||
expect(commandNames).toContain('/wild');
|
||||
expect(commandNames).toContain('/wildcard');
|
||||
});
|
||||
|
||||
it('renders every command item when slash opens the command list', async () => {
|
||||
const input = document.createElement('textarea');
|
||||
input.value = '/';
|
||||
input.selectionStart = input.value.length;
|
||||
document.body.append(input);
|
||||
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('/');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input, 'prompt', { showPreview: false, minChars: 1 });
|
||||
|
||||
input.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
|
||||
const renderedCommands = autoComplete.contentContainer.querySelectorAll('.lm-autocomplete-command-name');
|
||||
|
||||
expect(renderedCommands).toHaveLength(autoComplete.items.length);
|
||||
});
|
||||
|
||||
it('accepts the selected suggestion with Enter', async () => {
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('example');
|
||||
|
||||
@@ -1095,6 +1169,7 @@ describe('AutoComplete widget interactions', () => {
|
||||
minChars: 1,
|
||||
});
|
||||
|
||||
fetchApiMock.mockClear();
|
||||
input.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
await vi.runAllTimersAsync();
|
||||
await Promise.resolve();
|
||||
@@ -1133,6 +1208,61 @@ describe('AutoComplete widget interactions', () => {
|
||||
expect(input.setSelectionRange).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('does not reopen autocomplete on blur after inserting a wildcard literal', async () => {
|
||||
const input = document.createElement('textarea');
|
||||
input.value = '__flower__,';
|
||||
input.selectionStart = input.value.length;
|
||||
document.body.append(input);
|
||||
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('__flower__,');
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input,'prompt', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
minChars: 1,
|
||||
});
|
||||
|
||||
const hideSpy = vi.spyOn(autoComplete, 'hide');
|
||||
input.dispatchEvent(new Event('blur', { bubbles: true }));
|
||||
|
||||
expect(fetchApiMock).not.toHaveBeenCalled();
|
||||
expect(hideSpy).toHaveBeenCalled();
|
||||
expect(autoComplete.isVisible).toBe(false);
|
||||
});
|
||||
|
||||
it('treats a command after a wildcard literal as the active token', async () => {
|
||||
vi.useFakeTimers();
|
||||
|
||||
fetchApiMock.mockResolvedValue({
|
||||
json: () => Promise.resolve({
|
||||
success: true,
|
||||
words: [{ tag_name: 'flower_field', category: 4, post_count: 1234 }],
|
||||
}),
|
||||
});
|
||||
|
||||
const input = document.createElement('textarea');
|
||||
input.value = '__flower__ /character f';
|
||||
input.selectionStart = input.value.length;
|
||||
document.body.append(input);
|
||||
|
||||
caretHelperInstance.getBeforeCursor.mockReturnValue('__flower__ /character f');
|
||||
caretHelperInstance.getCursorOffset.mockReturnValue({ left: 15, top: 25 });
|
||||
|
||||
const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
|
||||
const autoComplete = new AutoComplete(input,'prompt', {
|
||||
debounceDelay: 0,
|
||||
showPreview: false,
|
||||
minChars: 1,
|
||||
});
|
||||
|
||||
input.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
await vi.runAllTimersAsync();
|
||||
await Promise.resolve();
|
||||
|
||||
expect(autoComplete.getSearchTerm(input.value)).toBe('/character f');
|
||||
});
|
||||
|
||||
it('invalidates stale autocomplete metadata and falls back to delimiter-based matching', async () => {
|
||||
settingGetMock.mockImplementation((key) => {
|
||||
if (key === 'loramanager.autocomplete_append_comma') {
|
||||
|
||||
@@ -94,6 +94,19 @@ class TestCustomWordsService:
|
||||
results = service.search_words("test")
|
||||
assert mock_tag_index.called
|
||||
|
||||
def test_search_words_skips_prompt_like_queries(self):
|
||||
service = CustomWordsService.__new__(CustomWordsService)
|
||||
mock_tag_index = MockTagFTSIndex()
|
||||
|
||||
def mock_get_index():
|
||||
return mock_tag_index
|
||||
|
||||
service._get_tag_index = mock_get_index
|
||||
|
||||
results = service.search_words("__flower__ /character f")
|
||||
|
||||
assert results == []
|
||||
assert mock_tag_index.called is False
|
||||
|
||||
class MockTagFTSIndex:
|
||||
"""Mock TagFTSIndex for testing."""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Tests for TagFTSIndex functionality."""
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from typing import List
|
||||
|
||||
@@ -173,6 +174,40 @@ class TestTagFTSIndexSearch:
|
||||
assert len(results) >= 1
|
||||
assert all(r["category"] in [4, 11] for r in results)
|
||||
|
||||
def test_search_with_category_filter_uses_fts_first_plan(self, populated_fts):
|
||||
"""Category-filtered searches should start from FTS hits, not category scans."""
|
||||
sql, params = populated_fts._build_search_statement(
|
||||
query_lower="f",
|
||||
fts_query="f*",
|
||||
categories=[4, 11],
|
||||
limit=20,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
conn = sqlite3.connect(f"file:{populated_fts.get_database_path()}?mode=ro", uri=True)
|
||||
try:
|
||||
plan_rows = conn.execute(f"EXPLAIN QUERY PLAN {sql}", params).fetchall()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
plan_details = [row[3] for row in plan_rows]
|
||||
assert any(detail.startswith("SCAN tag_fts VIRTUAL TABLE INDEX") for detail in plan_details)
|
||||
assert any("SEARCH t USING INTEGER PRIMARY KEY" in detail for detail in plan_details)
|
||||
assert not any("SEARCH t USING INDEX idx_tags_category" in detail for detail in plan_details)
|
||||
|
||||
def test_search_statement_uses_post_count_as_tie_breaker(self, populated_fts):
|
||||
"""Search ranking should use popularity as a secondary sort key."""
|
||||
sql, _ = populated_fts._build_search_statement(
|
||||
query_lower="f",
|
||||
fts_query="f*",
|
||||
categories=[4, 11],
|
||||
limit=20,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
assert "ORDER BY is_tag_name_match DESC, t.post_count DESC, rank_score DESC" in sql
|
||||
assert "LOG10" not in sql
|
||||
|
||||
def test_search_with_category_filter_excludes_others(self, populated_fts):
|
||||
"""Test that category filter excludes other categories."""
|
||||
# Search for "hi" but only in general category
|
||||
|
||||
@@ -366,6 +366,7 @@ class AutoComplete {
|
||||
this.previewTooltip = null;
|
||||
this.previewTooltipPromise = null;
|
||||
this.searchType = null;
|
||||
this.suppressAutocompleteOnce = false;
|
||||
|
||||
// Virtual scrolling state
|
||||
this.virtualScrollOffset = 0;
|
||||
@@ -505,6 +506,11 @@ class AutoComplete {
|
||||
bindEvents() {
|
||||
// Handle input changes
|
||||
this.onInput = (e) => {
|
||||
if (this.suppressAutocompleteOnce) {
|
||||
this.suppressAutocompleteOnce = false;
|
||||
this.hide();
|
||||
return;
|
||||
}
|
||||
this.handleInput(e.target.value);
|
||||
};
|
||||
this.inputElement.addEventListener('input', this.onInput);
|
||||
@@ -521,6 +527,7 @@ class AutoComplete {
|
||||
const formattedValue = formatAutocompleteTextOnBlur(this.inputElement.value);
|
||||
if (formattedValue !== this.inputElement.value) {
|
||||
this.inputElement.value = formattedValue;
|
||||
this.suppressAutocompleteOnce = true;
|
||||
this.inputElement.dispatchEvent(new Event('input', { bubbles: true }));
|
||||
}
|
||||
}
|
||||
@@ -725,9 +732,24 @@ class AutoComplete {
|
||||
}
|
||||
|
||||
const rawText = beforeCursor.substring(start);
|
||||
const text = rawText.trim();
|
||||
const leadingWhitespaceLength = rawText.length - rawText.trimStart().length;
|
||||
const trimmedStart = start + leadingWhitespaceLength;
|
||||
const text = rawText.trim();
|
||||
|
||||
if (this.modelType === 'prompt') {
|
||||
const tokenRange = this._getPromptTokenRange(rawText, trimmedStart, caretPos);
|
||||
if (tokenRange) {
|
||||
return {
|
||||
start: tokenRange.start,
|
||||
trimmedStart: tokenRange.trimmedStart,
|
||||
end: caretPos,
|
||||
beforeCursor,
|
||||
rawText: tokenRange.rawText,
|
||||
text: tokenRange.text,
|
||||
tokenType: tokenRange.tokenType,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
start,
|
||||
@@ -739,6 +761,73 @@ class AutoComplete {
|
||||
};
|
||||
}
|
||||
|
||||
_getPromptTokenRange(rawText = '', trimmedStart = 0, caretPos = 0) {
|
||||
const trimmedText = rawText.trim();
|
||||
if (!trimmedText) {
|
||||
return {
|
||||
start: trimmedStart,
|
||||
trimmedStart,
|
||||
rawText: '',
|
||||
text: '',
|
||||
tokenType: 'empty',
|
||||
};
|
||||
}
|
||||
|
||||
const commandOffset = trimmedText.startsWith('/')
|
||||
? 0
|
||||
: trimmedText.lastIndexOf(' /');
|
||||
if (commandOffset !== -1) {
|
||||
const normalizedCommandOffset = commandOffset === 0 ? 0 : commandOffset + 1;
|
||||
const commandText = trimmedText.slice(normalizedCommandOffset);
|
||||
const commandStart = trimmedStart + normalizedCommandOffset;
|
||||
return {
|
||||
start: commandStart,
|
||||
trimmedStart: commandStart,
|
||||
rawText: commandText,
|
||||
text: commandText,
|
||||
tokenType: commandText === '/' ? 'empty_command_trigger' : 'command',
|
||||
};
|
||||
}
|
||||
|
||||
const wildcardMatch = trimmedText.match(/(?:^|\s)(__[\w\s.\-+/*\\]+?__)$/);
|
||||
if (wildcardMatch) {
|
||||
const wildcardText = wildcardMatch[1];
|
||||
const wildcardOffset = trimmedText.lastIndexOf(wildcardText);
|
||||
const wildcardStart = trimmedStart + wildcardOffset;
|
||||
return {
|
||||
start: wildcardStart,
|
||||
trimmedStart: wildcardStart,
|
||||
rawText: wildcardText,
|
||||
text: '',
|
||||
tokenType: 'wildcard_literal',
|
||||
};
|
||||
}
|
||||
|
||||
const embeddingOffset = trimmedText.search(/(?:^|\s)emb:[^\s]*$/i);
|
||||
if (embeddingOffset !== -1) {
|
||||
const normalizedEmbeddingOffset = trimmedText.slice(embeddingOffset).startsWith(' ')
|
||||
? embeddingOffset + 1
|
||||
: embeddingOffset;
|
||||
const embeddingText = trimmedText.slice(normalizedEmbeddingOffset);
|
||||
const embeddingStart = trimmedStart + normalizedEmbeddingOffset;
|
||||
return {
|
||||
start: embeddingStart,
|
||||
trimmedStart: embeddingStart,
|
||||
rawText: embeddingText,
|
||||
text: embeddingText,
|
||||
tokenType: 'embedding_literal',
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
start: trimmedStart,
|
||||
trimmedStart,
|
||||
rawText,
|
||||
text: trimmedText,
|
||||
tokenType: 'tag_text',
|
||||
};
|
||||
}
|
||||
|
||||
_getHardBoundaryStart(beforeCursor = '') {
|
||||
const lastComma = beforeCursor.lastIndexOf(',');
|
||||
const lastAngle = beforeCursor.lastIndexOf('>');
|
||||
@@ -890,6 +979,32 @@ class AutoComplete {
|
||||
return Array.from(variations).filter(v => v.length >= this.options.minChars);
|
||||
}
|
||||
|
||||
_normalizeQueryForRequest(term = '') {
|
||||
return term.trim().toLowerCase();
|
||||
}
|
||||
|
||||
_getQueriesToExecute(term = '') {
|
||||
const queryVariations = this._generateQueryVariations(term);
|
||||
const uniqueQueries = [];
|
||||
const seen = new Set();
|
||||
|
||||
for (const query of queryVariations) {
|
||||
const normalized = this._normalizeQueryForRequest(query);
|
||||
if (!normalized || seen.has(normalized)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
seen.add(normalized);
|
||||
uniqueQueries.push(query);
|
||||
|
||||
if (uniqueQueries.length >= 4) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return uniqueQueries;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get display text for an item (without extension for models)
|
||||
* @param {string|Object} item - Item to get display text from
|
||||
@@ -1041,18 +1156,17 @@ class AutoComplete {
|
||||
endpoint = `/lm/${this.modelType}/relative-paths`;
|
||||
}
|
||||
|
||||
// Generate multiple query variations for better matching
|
||||
const queryVariations = this._generateQueryVariations(term);
|
||||
// Generate multiple query variations for better matching, but avoid
|
||||
// sending duplicate-equivalent requests that normalize to the same
|
||||
// backend search term.
|
||||
const queriesToExecute = this._getQueriesToExecute(term);
|
||||
|
||||
if (queryVariations.length === 0) {
|
||||
if (queriesToExecute.length === 0) {
|
||||
this.items = [];
|
||||
this.hide();
|
||||
return;
|
||||
}
|
||||
|
||||
// Limit the number of parallel queries to avoid overwhelming the server
|
||||
const queriesToExecute = queryVariations.slice(0, 4);
|
||||
|
||||
// Execute all queries in parallel
|
||||
const searchPromises = queriesToExecute.map(async (query) => {
|
||||
const url = endpoint.includes('?')
|
||||
@@ -1190,20 +1304,15 @@ class AutoComplete {
|
||||
|
||||
const filterLower = filter.toLowerCase();
|
||||
|
||||
// Get unique commands (avoid duplicates like /char and /character)
|
||||
const seenLabels = new Set();
|
||||
const commands = [];
|
||||
|
||||
for (const [cmd, info] of Object.entries(TAG_COMMANDS)) {
|
||||
if (seenLabels.has(info.label)) continue;
|
||||
|
||||
// Filter out toggle commands that don't meet their condition
|
||||
if (info.type === 'toggle_setting' && info.condition) {
|
||||
if (!info.condition()) continue;
|
||||
}
|
||||
|
||||
if (!filter || cmd.slice(1).startsWith(filterLower)) {
|
||||
seenLabels.add(info.label);
|
||||
commands.push({ command: cmd, ...info });
|
||||
}
|
||||
}
|
||||
@@ -1288,9 +1397,17 @@ class AutoComplete {
|
||||
}
|
||||
|
||||
// Auto-select immediately so accept keys remain stable.
|
||||
// In virtual-scroll mode, calling selectItem() before the dropdown is
|
||||
// visible can see a zero-height container and incorrectly replace the
|
||||
// full command list with a partially virtualized slice.
|
||||
if (this.items.length > 0) {
|
||||
this.selectedIndex = 0;
|
||||
if (this.contentContainer) {
|
||||
this._applyItemSelection(0);
|
||||
} else {
|
||||
this.selectItem(0);
|
||||
}
|
||||
}
|
||||
|
||||
// Update virtual scroll height for virtual scrolling mode
|
||||
if (this.contentContainer) {
|
||||
@@ -1635,8 +1752,7 @@ class AutoComplete {
|
||||
}
|
||||
}
|
||||
|
||||
const queryVariations = this._generateQueryVariations(this.currentSearchTerm);
|
||||
const queriesToExecute = queryVariations.slice(0, 4);
|
||||
const queriesToExecute = this._getQueriesToExecute(this.currentSearchTerm);
|
||||
const offset = this.items.length;
|
||||
|
||||
// Execute all queries in parallel with offset
|
||||
|
||||
Reference in New Issue
Block a user