mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
fix(autocomplete): reduce tag search overhead (#895)
This commit is contained in:
@@ -7,11 +7,13 @@ with category filtering and enriched results including post counts.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_EMBEDDED_COMMAND_PATTERN = re.compile(r"\s/\w")
|
||||
class CustomWordsService:
|
||||
"""Service for autocomplete via TagFTSIndex.
|
||||
|
||||
@@ -77,10 +79,47 @@ class CustomWordsService:
|
||||
Returns:
|
||||
List of dicts with tag_name, category, and post_count.
|
||||
"""
|
||||
normalized_search = search_term.strip()
|
||||
if not normalized_search:
|
||||
return []
|
||||
|
||||
# Prompt widgets should only send the active token, but guard against
|
||||
# accidental full-prompt queries reaching the FTS path.
|
||||
if (
|
||||
"__" in normalized_search
|
||||
or "," in normalized_search
|
||||
or ">" in normalized_search
|
||||
or "\n" in normalized_search
|
||||
or "\r" in normalized_search
|
||||
or _EMBEDDED_COMMAND_PATTERN.search(normalized_search)
|
||||
):
|
||||
logger.debug("Skipping prompt-like custom words query: %s", normalized_search)
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
"LM custom words service start search=%r categories=%s limit=%s offset=%s enriched=%s",
|
||||
normalized_search,
|
||||
categories,
|
||||
limit,
|
||||
offset,
|
||||
enriched,
|
||||
)
|
||||
|
||||
tag_index = self._get_tag_index()
|
||||
if tag_index is not None:
|
||||
logger.info(
|
||||
"LM custom words service tag_index ready=%s indexing=%s",
|
||||
getattr(tag_index, "is_ready", lambda: "unknown")(),
|
||||
getattr(tag_index, "is_indexing", lambda: "unknown")(),
|
||||
)
|
||||
|
||||
results = tag_index.search(
|
||||
search_term, categories=categories, limit=limit, offset=offset
|
||||
normalized_search, categories=categories, limit=limit, offset=offset
|
||||
)
|
||||
logger.info(
|
||||
"LM custom words service done search=%r result_count=%s",
|
||||
normalized_search,
|
||||
len(results),
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
@@ -450,9 +450,9 @@ class TagFTSIndex:
|
||||
the tag_name, the result will include a "matched_alias" field.
|
||||
|
||||
Ranking is based on a combination of:
|
||||
1. FTS5 bm25 relevance score (how well the text matches)
|
||||
2. Post count (popularity)
|
||||
3. Exact prefix match boost (tag_name starts with query)
|
||||
1. Exact prefix match boost (tag_name starts with query)
|
||||
2. Post count to preserve expected autocomplete ordering
|
||||
3. FTS5 bm25 relevance score as a deterministic tie-breaker
|
||||
|
||||
Args:
|
||||
query: The search query string.
|
||||
@@ -464,6 +464,17 @@ class TagFTSIndex:
|
||||
List of dictionaries with tag_name, category, post_count,
|
||||
rank_score, and optionally matched_alias.
|
||||
"""
|
||||
search_started_at = time.perf_counter()
|
||||
logger.info(
|
||||
"LM tag FTS search start query=%r categories=%s limit=%s offset=%s ready=%s indexing=%s",
|
||||
query,
|
||||
categories,
|
||||
limit,
|
||||
offset,
|
||||
self.is_ready(),
|
||||
self.is_indexing(),
|
||||
)
|
||||
|
||||
# Ensure index is ready (lazy initialization)
|
||||
if not self.ensure_ready():
|
||||
if not self._warned_not_ready:
|
||||
@@ -478,71 +489,44 @@ class TagFTSIndex:
|
||||
if not fts_query:
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
"LM tag FTS search built query=%r fts_query=%r",
|
||||
query,
|
||||
fts_query,
|
||||
)
|
||||
|
||||
query_lower = query.lower().strip()
|
||||
|
||||
try:
|
||||
logger.info("LM tag FTS search waiting_for_lock query=%r", query)
|
||||
with self._lock:
|
||||
logger.info("LM tag FTS search acquired_lock query=%r", query)
|
||||
conn = self._connect(readonly=True)
|
||||
try:
|
||||
# Build the SQL query with bm25 ranking
|
||||
# FTS5 bm25() returns negative scores, lower is better
|
||||
# We use -bm25() to get higher=better scores
|
||||
# Weights: -100.0 for exact matches, 1.0 for others
|
||||
# Add LOG10(post_count) weighting to boost popular tags
|
||||
# Use CASE to boost tag_name prefix matches above alias matches
|
||||
if categories:
|
||||
placeholders = ",".join("?" * len(categories))
|
||||
sql = f"""
|
||||
SELECT t.tag_name, t.category, t.post_count, t.aliases,
|
||||
CASE
|
||||
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
|
||||
ELSE 0
|
||||
END AS is_tag_name_match,
|
||||
bm25(tag_fts, -100.0, 1.0, 1.0) + LOG10(t.post_count + 1) * 10.0 AS rank_score
|
||||
FROM tag_fts
|
||||
JOIN tags t ON tag_fts.rowid = t.rowid
|
||||
WHERE tag_fts.searchable_text MATCH ?
|
||||
AND t.category IN ({placeholders})
|
||||
ORDER BY is_tag_name_match DESC, rank_score DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
# Escape special LIKE characters and add wildcard
|
||||
query_escaped = (
|
||||
query_lower.lstrip("/")
|
||||
.replace("\\", "\\\\")
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
)
|
||||
params = (
|
||||
[query_escaped + "%", fts_query]
|
||||
+ categories
|
||||
+ [limit, offset]
|
||||
)
|
||||
else:
|
||||
sql = """
|
||||
SELECT t.tag_name, t.category, t.post_count, t.aliases,
|
||||
CASE
|
||||
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
|
||||
ELSE 0
|
||||
END AS is_tag_name_match,
|
||||
bm25(tag_fts, -100.0, 1.0, 1.0) + LOG10(t.post_count + 1) * 10.0 AS rank_score
|
||||
FROM tag_fts
|
||||
JOIN tags t ON tag_fts.rowid = t.rowid
|
||||
WHERE tag_fts.searchable_text MATCH ?
|
||||
ORDER BY is_tag_name_match DESC, rank_score DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
query_escaped = (
|
||||
query_lower.lstrip("/")
|
||||
.replace("\\", "\\\\")
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
)
|
||||
params = [query_escaped + "%", fts_query, limit, offset]
|
||||
sql, params = self._build_search_statement(
|
||||
query_lower=query_lower,
|
||||
fts_query=fts_query,
|
||||
categories=categories,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"LM tag FTS search executing_sql query=%r query_len=%s category_count=%s",
|
||||
query,
|
||||
len(query_lower),
|
||||
len(categories) if categories else 0,
|
||||
)
|
||||
cursor = conn.execute(sql, params)
|
||||
logger.info("LM tag FTS search execute_returned query=%r", query)
|
||||
rows = cursor.fetchall()
|
||||
logger.info(
|
||||
"LM tag FTS search fetchall_returned query=%r row_count=%s",
|
||||
query,
|
||||
len(rows),
|
||||
)
|
||||
results = []
|
||||
for row in cursor.fetchall():
|
||||
for row in rows:
|
||||
result = {
|
||||
"tag_name": row[0],
|
||||
"category": row[1],
|
||||
@@ -564,6 +548,13 @@ class TagFTSIndex:
|
||||
result["matched_alias"] = matched_alias
|
||||
|
||||
results.append(result)
|
||||
elapsed_ms = (time.perf_counter() - search_started_at) * 1000
|
||||
logger.info(
|
||||
"LM tag FTS search done query=%r result_count=%s elapsed_ms=%.2f",
|
||||
query,
|
||||
len(results),
|
||||
elapsed_ms,
|
||||
)
|
||||
return results
|
||||
finally:
|
||||
conn.close()
|
||||
@@ -571,6 +562,62 @@ class TagFTSIndex:
|
||||
logger.debug("Tag FTS search error for query '%s': %s", query, exc)
|
||||
return []
|
||||
|
||||
def _build_search_statement(
|
||||
self,
|
||||
query_lower: str,
|
||||
fts_query: str,
|
||||
categories: Optional[List[int]],
|
||||
limit: int,
|
||||
offset: int,
|
||||
) -> tuple[str, list[object]]:
|
||||
"""Build the SQL statement and params for a tag search."""
|
||||
# Escape special LIKE characters and add wildcard
|
||||
query_escaped = (
|
||||
query_lower.lstrip("/")
|
||||
.replace("\\", "\\\\")
|
||||
.replace("%", "\\%")
|
||||
.replace("_", "\\_")
|
||||
)
|
||||
|
||||
# FTS5 bm25() returns negative scores, lower is better.
|
||||
# We use -bm25() to get higher=better scores, but keep post_count as the
|
||||
# primary sort within tag-name prefix matches so autocomplete ordering
|
||||
# remains aligned with the existing popularity-first behavior.
|
||||
if categories:
|
||||
placeholders = ",".join("?" * len(categories))
|
||||
sql = f"""
|
||||
SELECT t.tag_name, t.category, t.post_count, t.aliases,
|
||||
CASE
|
||||
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
|
||||
ELSE 0
|
||||
END AS is_tag_name_match,
|
||||
bm25(tag_fts, -100.0, 1.0, 1.0) AS rank_score
|
||||
FROM tag_fts
|
||||
CROSS JOIN tags t ON t.rowid = tag_fts.rowid
|
||||
WHERE tag_fts.searchable_text MATCH ?
|
||||
AND t.category IN ({placeholders})
|
||||
ORDER BY is_tag_name_match DESC, t.post_count DESC, rank_score DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
params = [query_escaped + "%", fts_query] + categories + [limit, offset]
|
||||
else:
|
||||
sql = """
|
||||
SELECT t.tag_name, t.category, t.post_count, t.aliases,
|
||||
CASE
|
||||
WHEN t.tag_name LIKE ? ESCAPE '\\' THEN 1
|
||||
ELSE 0
|
||||
END AS is_tag_name_match,
|
||||
bm25(tag_fts, -100.0, 1.0, 1.0) AS rank_score
|
||||
FROM tag_fts
|
||||
JOIN tags t ON tag_fts.rowid = t.rowid
|
||||
WHERE tag_fts.searchable_text MATCH ?
|
||||
ORDER BY is_tag_name_match DESC, t.post_count DESC, rank_score DESC
|
||||
LIMIT ? OFFSET ?
|
||||
"""
|
||||
params = [query_escaped + "%", fts_query, limit, offset]
|
||||
|
||||
return sql, params
|
||||
|
||||
def _find_matched_alias(
|
||||
self, query: str, tag_name: str, aliases_str: str
|
||||
) -> Optional[str]:
|
||||
|
||||
Reference in New Issue
Block a user