From 43f6bfab36ae06e22f51e0f9c9c2a9e47ace51e9 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Sat, 7 Mar 2026 23:07:10 +0800 Subject: [PATCH] fix(autocomplete): strip file extensions from model names in search suggestions Remove .safetensors/.ckpt/.pt/.bin extensions from model names in autocomplete suggestions to improve UX and search relevance: Frontend (web/comfyui/autocomplete.js): - Add _getDisplayText() helper to strip extensions from model paths - Update _matchItem() to match against filename without extension - Update render() and createItemElement() to display clean names Backend (py/services/base_model_service.py): - Add _remove_model_extension() helper method - Update _relative_path_matches_tokens() to ignore extensions in matching - Update _relative_path_sort_key() to sort based on names without extensions Tests (tests/services/test_relative_path_search.py): - Add tests to verify 's' and 'safe' queries don't match all .safetensors files Fixes issue where typing 's' would match all .safetensors files and cluttered suggestions with redundant extension names. --- py/services/base_model_service.py | 42 ++++++++++++++++----- tests/services/test_relative_path_search.py | 39 +++++++++++++++++++ web/comfyui/autocomplete.js | 26 +++++++++++-- 3 files changed, 95 insertions(+), 12 deletions(-) diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index 60ae81d8..c08e0dad 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod import asyncio +import re from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING import logging import os @@ -821,35 +822,58 @@ class BaseModelService(ABC): return include_terms, exclude_terms + @staticmethod + def _remove_model_extension(path: str) -> str: + """Remove model file extension (.safetensors, .ckpt, .pt, .bin) for cleaner matching.""" + return re.sub(r"\.(safetensors|ckpt|pt|bin)$", "", path, flags=re.IGNORECASE) + @staticmethod def _relative_path_matches_tokens( path_lower: str, include_terms: List[str], exclude_terms: List[str] ) -> bool: - """Determine whether a relative path string satisfies include/exclude tokens.""" - if any(term and term in path_lower for term in exclude_terms): + """Determine whether a relative path string satisfies include/exclude tokens. + + Matches against the path without extension to avoid matching .safetensors + when searching for 's'. + """ + # Use path without extension for matching + path_for_matching = BaseModelService._remove_model_extension(path_lower) + + if any(term and term in path_for_matching for term in exclude_terms): return False for term in include_terms: - if term and term not in path_lower: + if term and term not in path_for_matching: return False return True @staticmethod def _relative_path_sort_key(relative_path: str, include_terms: List[str]) -> tuple: - """Sort paths by how well they satisfy the include tokens.""" - path_lower = relative_path.lower() + """Sort paths by how well they satisfy the include tokens. + + Sorts based on path without extension for consistent ordering. + """ + # Use path without extension for sorting + path_for_sorting = BaseModelService._remove_model_extension( + relative_path.lower() + ) prefix_hits = sum( - 1 for term in include_terms if term and path_lower.startswith(term) + 1 for term in include_terms if term and path_for_sorting.startswith(term) ) match_positions = [ - path_lower.find(term) + path_for_sorting.find(term) for term in include_terms - if term and term in path_lower + if term and term in path_for_sorting ] first_match_index = min(match_positions) if match_positions else 0 - return (-prefix_hits, first_match_index, len(relative_path), path_lower) + return ( + -prefix_hits, + first_match_index, + len(path_for_sorting), + path_for_sorting, + ) async def search_relative_paths( self, search_term: str, limit: int = 15, offset: int = 0 diff --git a/tests/services/test_relative_path_search.py b/tests/services/test_relative_path_search.py index 0e10039c..1c68ef8f 100644 --- a/tests/services/test_relative_path_search.py +++ b/tests/services/test_relative_path_search.py @@ -62,3 +62,42 @@ async def test_search_relative_paths_excludes_tokens(): matching = await service.search_relative_paths("flux -detail") assert matching == [f"flux{os.sep}keep-me.safetensors"] + + +@pytest.mark.asyncio +async def test_search_does_not_match_extension(): + """Searching for 's' or 'safe' should not match .safetensors extension.""" + scanner = FakeScanner( + [ + {"file_path": "/models/lora1.safetensors"}, + {"file_path": "/models/lora2.safetensors"}, + {"file_path": "/models/special-model.safetensors"}, # 's' in filename + ], + ["/models"], + ) + service = DummyService("stub", scanner, BaseModelMetadata) + + # Searching for 's' should only match 'special-model', not all .safetensors + matching = await service.search_relative_paths("s") + + # Should only match 'special-model' because 's' is in the filename + assert len(matching) == 1 + assert "special-model" in matching[0] + + +@pytest.mark.asyncio +async def test_search_safe_does_not_match_all_files(): + """Searching for 'safe' should not match .safetensors extension.""" + scanner = FakeScanner( + [ + {"file_path": "/models/flux.safetensors"}, + {"file_path": "/models/detail.safetensors"}, + ], + ["/models"], + ) + service = DummyService("stub", scanner, BaseModelMetadata) + + # Searching for 'safe' should return nothing (no file has 'safe' in its name) + matching = await service.search_relative_paths("safe") + + assert len(matching) == 0 diff --git a/web/comfyui/autocomplete.js b/web/comfyui/autocomplete.js index 4c6bedfc..a39d917d 100644 --- a/web/comfyui/autocomplete.js +++ b/web/comfyui/autocomplete.js @@ -689,6 +689,22 @@ class AutoComplete { return Array.from(variations).filter(v => v.length >= this.options.minChars); } + /** + * Get display text for an item (without extension for models) + * @param {string|Object} item - Item to get display text from + * @returns {string} - Display text without extension + */ + _getDisplayText(item) { + const itemText = typeof item === 'object' && item.tag_name ? item.tag_name : String(item); + // Remove extension for models to avoid matching/displaying .safetensors etc. + if (this.modelType === 'loras' || this.searchType === 'embeddings') { + return removeLoraExtension(itemText); + } else if (this.modelType === 'embeddings') { + return removeGeneralExtension(itemText); + } + return itemText; + } + /** * Check if an item matches a search term * Supports both string items and enriched items with tag_name property @@ -697,7 +713,7 @@ class AutoComplete { * @returns {Object} - { matched: boolean, isExactMatch: boolean } */ _matchItem(item, searchTerm) { - const itemText = typeof item === 'object' && item.tag_name ? item.tag_name : String(item); + const itemText = this._getDisplayText(item); const itemTextLower = itemText.toLowerCase(); const searchTermLower = searchTerm.toLowerCase(); @@ -1070,7 +1086,9 @@ class AutoComplete { // to prevent flex layout from breaking up the text const nameSpan = document.createElement('span'); nameSpan.className = 'lm-autocomplete-name'; - nameSpan.innerHTML = this.highlightMatch(displayText, this.currentSearchTerm); + // Use display text without extension for cleaner UI + const displayTextWithoutExt = this._getDisplayText(displayText); + nameSpan.innerHTML = this.highlightMatch(displayTextWithoutExt, this.currentSearchTerm); nameSpan.style.cssText = ` flex: 1; min-width: 0; @@ -1522,7 +1540,9 @@ class AutoComplete { } else { const nameSpan = document.createElement('span'); nameSpan.className = 'lm-autocomplete-name'; - nameSpan.innerHTML = this.highlightMatch(displayText, this.currentSearchTerm); + // Use display text without extension for cleaner UI + const displayTextWithoutExt = this._getDisplayText(displayText); + nameSpan.innerHTML = this.highlightMatch(displayTextWithoutExt, this.currentSearchTerm); nameSpan.style.cssText = ` flex: 1; min-width: 0;