mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 13:12:12 -03:00
fix(autocomplete): strip file extensions from model names in search suggestions
Remove .safetensors/.ckpt/.pt/.bin extensions from model names in autocomplete suggestions to improve UX and search relevance:

Frontend (web/comfyui/autocomplete.js):
- Add _getDisplayText() helper to strip extensions from model paths
- Update _matchItem() to match against filename without extension
- Update render() and createItemElement() to display clean names

Backend (py/services/base_model_service.py):
- Add _remove_model_extension() helper method
- Update _relative_path_matches_tokens() to ignore extensions in matching
- Update _relative_path_sort_key() to sort based on names without extensions

Tests (tests/services/test_relative_path_search.py):
- Add tests to verify 's' and 'safe' queries don't match all .safetensors files

Fixes issue where typing 's' would match all .safetensors files and cluttered suggestions with redundant extension names.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
|
||||
import logging
|
||||
import os
|
||||
@@ -821,35 +822,58 @@ class BaseModelService(ABC):
|
||||
|
||||
return include_terms, exclude_terms
|
||||
|
||||
@staticmethod
|
||||
def _remove_model_extension(path: str) -> str:
|
||||
"""Remove model file extension (.safetensors, .ckpt, .pt, .bin) for cleaner matching."""
|
||||
return re.sub(r"\.(safetensors|ckpt|pt|bin)$", "", path, flags=re.IGNORECASE)
|
||||
|
||||
@staticmethod
def _relative_path_matches_tokens(
    path_lower: str, include_terms: List[str], exclude_terms: List[str]
) -> bool:
    """Determine whether a relative path string satisfies include/exclude tokens.

    Matching ignores the model-file extension so that a query like 's'
    does not match every .safetensors file.
    """
    # Compare against the extension-free form of the path.
    candidate = BaseModelService._remove_model_extension(path_lower)

    # Any exclude token present anywhere in the path rejects it outright.
    if any(term and term in candidate for term in exclude_terms):
        return False

    # Every non-empty include token must appear somewhere in the path.
    return all(not term or term in candidate for term in include_terms)
|
||||
|
||||
@staticmethod
def _relative_path_sort_key(relative_path: str, include_terms: List[str]) -> tuple:
    """Rank a relative path by how strongly it matches the include tokens.

    All comparisons use the path without its model extension so ordering
    is not influenced by .safetensors/.ckpt/.pt/.bin suffixes.
    """
    base = BaseModelService._remove_model_extension(relative_path.lower())

    # Paths that begin with more of the query tokens sort first
    # (negated so ascending sort puts the best match on top).
    prefix_hits = sum(1 for term in include_terms if term and base.startswith(term))

    # Earliest position at which any token matches; 0 when none match.
    positions = [base.find(term) for term in include_terms if term and term in base]
    earliest = min(positions, default=0)

    # Shorter, lexicographically earlier names break remaining ties.
    return (-prefix_hits, earliest, len(base), base)
|
||||
|
||||
async def search_relative_paths(
|
||||
self, search_term: str, limit: int = 15, offset: int = 0
|
||||
|
||||
@@ -62,3 +62,42 @@ async def test_search_relative_paths_excludes_tokens():
|
||||
matching = await service.search_relative_paths("flux -detail")
|
||||
|
||||
assert matching == [f"flux{os.sep}keep-me.safetensors"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_does_not_match_extension():
    """Searching for 's' or 'safe' should not match .safetensors extension."""
    entries = [
        {"file_path": "/models/lora1.safetensors"},
        {"file_path": "/models/lora2.safetensors"},
        {"file_path": "/models/special-model.safetensors"},  # 's' in filename
    ]
    service = DummyService("stub", FakeScanner(entries, ["/models"]), BaseModelMetadata)

    # Once extensions are ignored, only 'special-model' contains an 's';
    # the query must not match every .safetensors file.
    matching = await service.search_relative_paths("s")

    assert len(matching) == 1
    assert "special-model" in matching[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_search_safe_does_not_match_all_files():
    """Searching for 'safe' should not match .safetensors extension."""
    entries = [
        {"file_path": "/models/flux.safetensors"},
        {"file_path": "/models/detail.safetensors"},
    ]
    service = DummyService("stub", FakeScanner(entries, ["/models"]), BaseModelMetadata)

    # Neither filename contains 'safe' once the extension is stripped,
    # so the search must come back empty.
    matching = await service.search_relative_paths("safe")

    assert len(matching) == 0
|
||||
|
||||
@@ -689,6 +689,22 @@ class AutoComplete {
|
||||
return Array.from(variations).filter(v => v.length >= this.options.minChars);
|
||||
}
|
||||
|
||||
/**
 * Get display text for an item (without extension for models)
 * @param {string|Object} item - Item to get display text from
 * @returns {string} - Display text without extension
 */
_getDisplayText(item) {
    // Enriched items carry the name in tag_name; anything else is coerced to string.
    const itemText = typeof item === 'object' && item.tag_name ? item.tag_name : String(item);
    // Remove extension for models to avoid matching/displaying .safetensors etc.
    // NOTE(review): the first branch checks `this.searchType === 'embeddings'` while the
    // second checks `this.modelType === 'embeddings'`. When searchType is 'embeddings'
    // the lora-specific stripper runs and the second branch is unreachable for that
    // case — this looks like a typo for `this.modelType`; confirm intended behavior.
    if (this.modelType === 'loras' || this.searchType === 'embeddings') {
        return removeLoraExtension(itemText);
    } else if (this.modelType === 'embeddings') {
        return removeGeneralExtension(itemText);
    }
    return itemText;
}
|
||||
|
||||
/**
|
||||
* Check if an item matches a search term
|
||||
* Supports both string items and enriched items with tag_name property
|
||||
@@ -697,7 +713,7 @@ class AutoComplete {
|
||||
* @returns {Object} - { matched: boolean, isExactMatch: boolean }
|
||||
*/
|
||||
_matchItem(item, searchTerm) {
|
||||
const itemText = typeof item === 'object' && item.tag_name ? item.tag_name : String(item);
|
||||
const itemText = this._getDisplayText(item);
|
||||
const itemTextLower = itemText.toLowerCase();
|
||||
const searchTermLower = searchTerm.toLowerCase();
|
||||
|
||||
@@ -1070,7 +1086,9 @@ class AutoComplete {
|
||||
// to prevent flex layout from breaking up the text
|
||||
const nameSpan = document.createElement('span');
|
||||
nameSpan.className = 'lm-autocomplete-name';
|
||||
nameSpan.innerHTML = this.highlightMatch(displayText, this.currentSearchTerm);
|
||||
// Use display text without extension for cleaner UI
|
||||
const displayTextWithoutExt = this._getDisplayText(displayText);
|
||||
nameSpan.innerHTML = this.highlightMatch(displayTextWithoutExt, this.currentSearchTerm);
|
||||
nameSpan.style.cssText = `
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
@@ -1522,7 +1540,9 @@ class AutoComplete {
|
||||
} else {
|
||||
const nameSpan = document.createElement('span');
|
||||
nameSpan.className = 'lm-autocomplete-name';
|
||||
nameSpan.innerHTML = this.highlightMatch(displayText, this.currentSearchTerm);
|
||||
// Use display text without extension for cleaner UI
|
||||
const displayTextWithoutExt = this._getDisplayText(displayText);
|
||||
nameSpan.innerHTML = this.highlightMatch(displayTextWithoutExt, this.currentSearchTerm);
|
||||
nameSpan.style.cssText = `
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
|
||||
Reference in New Issue
Block a user