fix(autocomplete): strip file extensions from model names in search suggestions

Remove .safetensors/.ckpt/.pt/.bin extensions from model names in autocomplete
suggestions to improve UX and search relevance:

Frontend (web/comfyui/autocomplete.js):
- Add _getDisplayText() helper to strip extensions from model paths
- Update _matchItem() to match against filename without extension
- Update render() and createItemElement() to display clean names

Backend (py/services/base_model_service.py):
- Add _remove_model_extension() helper method
- Update _relative_path_matches_tokens() to ignore extensions in matching
- Update _relative_path_sort_key() to sort based on names without extensions

Tests (tests/services/test_relative_path_search.py):
- Add tests to verify 's' and 'safe' queries don't match all .safetensors files

Fixes issue where typing 's' would match all .safetensors files and cluttered
suggestions with redundant extension names.
This commit is contained in:
Will Miao
2026-03-07 23:07:10 +08:00
parent a802a89ff9
commit 43f6bfab36
3 changed files with 95 additions and 12 deletions

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
import asyncio
import re
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
import logging
import os
@@ -821,35 +822,58 @@ class BaseModelService(ABC):
return include_terms, exclude_terms
@staticmethod
def _remove_model_extension(path: str) -> str:
"""Remove model file extension (.safetensors, .ckpt, .pt, .bin) for cleaner matching."""
return re.sub(r"\.(safetensors|ckpt|pt|bin)$", "", path, flags=re.IGNORECASE)
@staticmethod
def _relative_path_matches_tokens(
path_lower: str, include_terms: List[str], exclude_terms: List[str]
) -> bool:
"""Determine whether a relative path string satisfies include/exclude tokens."""
if any(term and term in path_lower for term in exclude_terms):
"""Determine whether a relative path string satisfies include/exclude tokens.
Matches against the path without extension to avoid matching .safetensors
when searching for 's'.
"""
# Use path without extension for matching
path_for_matching = BaseModelService._remove_model_extension(path_lower)
if any(term and term in path_for_matching for term in exclude_terms):
return False
for term in include_terms:
if term and term not in path_lower:
if term and term not in path_for_matching:
return False
return True
@staticmethod
def _relative_path_sort_key(relative_path: str, include_terms: List[str]) -> tuple:
"""Sort paths by how well they satisfy the include tokens."""
path_lower = relative_path.lower()
"""Sort paths by how well they satisfy the include tokens.
Sorts based on path without extension for consistent ordering.
"""
# Use path without extension for sorting
path_for_sorting = BaseModelService._remove_model_extension(
relative_path.lower()
)
prefix_hits = sum(
1 for term in include_terms if term and path_lower.startswith(term)
1 for term in include_terms if term and path_for_sorting.startswith(term)
)
match_positions = [
path_lower.find(term)
path_for_sorting.find(term)
for term in include_terms
if term and term in path_lower
if term and term in path_for_sorting
]
first_match_index = min(match_positions) if match_positions else 0
return (-prefix_hits, first_match_index, len(relative_path), path_lower)
return (
-prefix_hits,
first_match_index,
len(path_for_sorting),
path_for_sorting,
)
async def search_relative_paths(
self, search_term: str, limit: int = 15, offset: int = 0

View File

@@ -62,3 +62,42 @@ async def test_search_relative_paths_excludes_tokens():
matching = await service.search_relative_paths("flux -detail")
assert matching == [f"flux{os.sep}keep-me.safetensors"]
@pytest.mark.asyncio
async def test_search_does_not_match_extension():
"""Searching for 's' or 'safe' should not match .safetensors extension."""
scanner = FakeScanner(
[
{"file_path": "/models/lora1.safetensors"},
{"file_path": "/models/lora2.safetensors"},
{"file_path": "/models/special-model.safetensors"}, # 's' in filename
],
["/models"],
)
service = DummyService("stub", scanner, BaseModelMetadata)
# Searching for 's' should only match 'special-model', not all .safetensors
matching = await service.search_relative_paths("s")
# Should only match 'special-model' because 's' is in the filename
assert len(matching) == 1
assert "special-model" in matching[0]
@pytest.mark.asyncio
async def test_search_safe_does_not_match_all_files():
"""Searching for 'safe' should not match .safetensors extension."""
scanner = FakeScanner(
[
{"file_path": "/models/flux.safetensors"},
{"file_path": "/models/detail.safetensors"},
],
["/models"],
)
service = DummyService("stub", scanner, BaseModelMetadata)
# Searching for 'safe' should return nothing (no file has 'safe' in its name)
matching = await service.search_relative_paths("safe")
assert len(matching) == 0

View File

@@ -689,6 +689,22 @@ class AutoComplete {
return Array.from(variations).filter(v => v.length >= this.options.minChars);
}
/**
* Get display text for an item (without extension for models)
* @param {string|Object} item - Item to get display text from
* @returns {string} - Display text without extension
*/
_getDisplayText(item) {
const itemText = typeof item === 'object' && item.tag_name ? item.tag_name : String(item);
// Remove extension for models to avoid matching/displaying .safetensors etc.
if (this.modelType === 'loras' || this.searchType === 'embeddings') {
return removeLoraExtension(itemText);
} else if (this.modelType === 'embeddings') {
return removeGeneralExtension(itemText);
}
return itemText;
}
/**
* Check if an item matches a search term
* Supports both string items and enriched items with tag_name property
@@ -697,7 +713,7 @@ class AutoComplete {
* @returns {Object} - { matched: boolean, isExactMatch: boolean }
*/
_matchItem(item, searchTerm) {
const itemText = typeof item === 'object' && item.tag_name ? item.tag_name : String(item);
const itemText = this._getDisplayText(item);
const itemTextLower = itemText.toLowerCase();
const searchTermLower = searchTerm.toLowerCase();
@@ -1070,7 +1086,9 @@ class AutoComplete {
// to prevent flex layout from breaking up the text
const nameSpan = document.createElement('span');
nameSpan.className = 'lm-autocomplete-name';
nameSpan.innerHTML = this.highlightMatch(displayText, this.currentSearchTerm);
// Use display text without extension for cleaner UI
const displayTextWithoutExt = this._getDisplayText(displayText);
nameSpan.innerHTML = this.highlightMatch(displayTextWithoutExt, this.currentSearchTerm);
nameSpan.style.cssText = `
flex: 1;
min-width: 0;
@@ -1522,7 +1540,9 @@ class AutoComplete {
} else {
const nameSpan = document.createElement('span');
nameSpan.className = 'lm-autocomplete-name';
nameSpan.innerHTML = this.highlightMatch(displayText, this.currentSearchTerm);
// Use display text without extension for cleaner UI
const displayTextWithoutExt = this._getDisplayText(displayText);
nameSpan.innerHTML = this.highlightMatch(displayTextWithoutExt, this.currentSearchTerm);
nameSpan.style.cssText = `
flex: 1;
min-width: 0;