feat: enhance search with include/exclude tokens and improved sorting

- Add token parsing to support include/exclude search terms using "-" prefix
- Implement token-based matching logic for relative path searches
- Improve search result sorting by prioritizing prefix matches and match position
- Add frontend test for multi-token highlighting with exclusion support
This commit is contained in:
Will Miao
2025-11-21 19:48:43 +08:00
parent 9a789f8f08
commit ea1d1a49c9
4 changed files with 167 additions and 12 deletions

View File

@@ -648,13 +648,55 @@ class BaseModelService(ABC):
return None return None
return metadata.modelDescription or '' return metadata.modelDescription or ''
@staticmethod
def _parse_search_tokens(search_term: str) -> tuple[List[str], List[str]]:
"""Split a search string into include and exclude tokens."""
include_terms: List[str] = []
exclude_terms: List[str] = []
for raw_term in search_term.split():
term = raw_term.strip()
if not term:
continue
if term.startswith("-") and len(term) > 1:
exclude_terms.append(term[1:].lower())
else:
include_terms.append(term.lower())
return include_terms, exclude_terms
@staticmethod
def _relative_path_matches_tokens(
path_lower: str, include_terms: List[str], exclude_terms: List[str]
) -> bool:
"""Determine whether a relative path string satisfies include/exclude tokens."""
if any(term and term in path_lower for term in exclude_terms):
return False
for term in include_terms:
if term and term not in path_lower:
return False
return True
@staticmethod
def _relative_path_sort_key(relative_path: str, include_terms: List[str]) -> tuple:
"""Sort paths by how well they satisfy the include tokens."""
path_lower = relative_path.lower()
prefix_hits = sum(1 for term in include_terms if term and path_lower.startswith(term))
match_positions = [path_lower.find(term) for term in include_terms if term and term in path_lower]
first_match_index = min(match_positions) if match_positions else 0
return (-prefix_hits, first_match_index, len(relative_path), path_lower)
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]: async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
"""Search model relative file paths for autocomplete functionality""" """Search model relative file paths for autocomplete functionality"""
cache = await self.scanner.get_cached_data() cache = await self.scanner.get_cached_data()
include_terms, exclude_terms = self._parse_search_tokens(search_term)
matching_paths = [] matching_paths = []
search_lower = search_term.lower()
# Get model roots for path calculation # Get model roots for path calculation
model_roots = self.scanner.get_model_roots() model_roots = self.scanner.get_model_roots()
@@ -676,17 +718,19 @@ class BaseModelService(ABC):
relative_path = normalized_file[len(normalized_root):].lstrip(os.sep) relative_path = normalized_file[len(normalized_root):].lstrip(os.sep)
break break
if relative_path and search_lower in relative_path.lower(): if not relative_path:
continue
relative_lower = relative_path.lower()
if self._relative_path_matches_tokens(relative_lower, include_terms, exclude_terms):
matching_paths.append(relative_path) matching_paths.append(relative_path)
if len(matching_paths) >= limit * 2: # Get more for better sorting if len(matching_paths) >= limit * 2: # Get more for better sorting
break break
# Sort by relevance (exact matches first, then by length) # Sort by relevance (prefix and earliest hits first, then by length and alphabetically)
matching_paths.sort(key=lambda x: ( matching_paths.sort(
not x.lower().startswith(search_lower), # Exact prefix matches first key=lambda relative: self._relative_path_sort_key(relative, include_terms)
len(x), # Then by length (shorter first) )
x.lower() # Then alphabetically
))
return matching_paths[:limit] return matching_paths[:limit]

View File

@@ -136,4 +136,23 @@ describe('AutoComplete widget interactions', () => {
expect(input.focus).toHaveBeenCalled(); expect(input.focus).toHaveBeenCalled();
expect(input.setSelectionRange).toHaveBeenCalled(); expect(input.setSelectionRange).toHaveBeenCalled();
}); });
it('highlights multiple include tokens while ignoring excluded ones', async () => {
    // Attach a real textarea so the widget can bind its listeners.
    const textarea = document.createElement('textarea');
    document.body.append(textarea);
    const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
    const widget = new AutoComplete(textarea, 'loras', { showPreview: false });

    // Two include tokens plus one excluded token against a path that
    // contains all three substrings.
    const markup = widget.highlightMatch(
        'models/flux/beta-detail.safetensors',
        'flux detail -beta',
    );

    // Exactly the two include tokens are wrapped; "beta" never is.
    const spanCount = (markup.match(/<span/g) || []).length;
    expect(spanCount).toBe(2);
    expect(markup).toContain('flux');
    expect(markup).toContain('detail');
    expect(markup).not.toMatch(/beta<\/span>/i);
});
}); });

View File

@@ -0,0 +1,63 @@
import pytest
from py.services.base_model_service import BaseModelService
from py.utils.models import BaseModelMetadata
class DummyService(BaseModelService):
    """Minimal concrete service used to exercise BaseModelService helpers."""

    async def format_response(self, model_data):
        # Identity pass-through keeps assertions focused on search behavior.
        return model_data
class FakeCache:
    """In-memory stand-in for the scanner's cached model data."""

    def __init__(self, raw_data):
        # Defensive copy: mutating the caller's list must not affect the cache.
        self.raw_data = list(raw_data)
class FakeScanner:
    """Scanner double exposing only the surface search_relative_paths needs."""

    def __init__(self, raw_data, roots):
        self._cache = FakeCache(raw_data)
        self._roots = list(roots)

    async def get_cached_data(self, *_args, **_kwargs):
        # Extra positional/keyword args are accepted and ignored so the
        # double tolerates any call signature the service uses.
        return self._cache

    def get_model_roots(self):
        # Return a copy so callers cannot mutate the stored roots.
        return list(self._roots)
@pytest.mark.asyncio
async def test_search_relative_paths_supports_multiple_tokens():
    """Every returned path must contain all include tokens, best-ranked first."""
    fake_scanner = FakeScanner(
        [
            {"file_path": "/models/flux/detail-model.safetensors"},
            {"file_path": "/models/flux/only-flux.safetensors"},
            {"file_path": "/models/detail/flux-trained.safetensors"},
            {"file_path": "/models/detail/standalone.safetensors"},
        ],
        ["/models"],
    )
    service = DummyService("stub", fake_scanner, BaseModelMetadata)

    results = await service.search_relative_paths("flux detail")

    # Paths missing either token are filtered out; the "flux" prefix
    # match ranks ahead of the mid-path match.
    assert results == [
        "flux/detail-model.safetensors",
        "detail/flux-trained.safetensors",
    ]
@pytest.mark.asyncio
async def test_search_relative_paths_excludes_tokens():
    """A "-token" term must drop any path containing that substring."""
    fake_scanner = FakeScanner(
        [
            {"file_path": "/models/flux/detail-model.safetensors"},
            {"file_path": "/models/flux/keep-me.safetensors"},
        ],
        ["/models"],
    )
    service = DummyService("stub", fake_scanner, BaseModelMetadata)

    results = await service.search_relative_paths("flux -detail")

    assert results == ["flux/keep-me.safetensors"]

View File

@@ -32,6 +32,25 @@ function removeLoraExtension(fileName = '') {
return fileName.replace(/\.(safetensors|ckpt|pt|bin)$/i, ''); return fileName.replace(/\.(safetensors|ckpt|pt|bin)$/i, '');
} }
function parseSearchTokens(term = '') {
    // Split a search string into lowercased include/exclude token lists.
    // A token prefixed with "-" (longer than the bare dash) is an
    // exclusion; a lone "-" stays an include term.
    const include = [];
    const exclude = [];
    for (const rawTerm of term.split(/\s+/)) {
        const token = rawTerm.trim();
        // NOTE: split(/\s+/) yields a leading '' for strings that start
        // with whitespace, so the empty-token guard is required here.
        if (!token) {
            continue;
        }
        if (token.startsWith('-') && token.length > 1) {
            exclude.push(token.slice(1).toLowerCase());
        } else {
            include.push(token.toLowerCase());
        }
    }
    return { include, exclude };
}
function createDefaultBehavior(modelType) { function createDefaultBehavior(modelType) {
return { return {
enablePreview: false, enablePreview: false,
@@ -393,10 +412,20 @@ class AutoComplete {
} }
highlightMatch(text, searchTerm) { highlightMatch(text, searchTerm) {
if (!searchTerm) return text; const { include } = parseSearchTokens(searchTerm);
const sanitizedTokens = include
const regex = new RegExp(`(${searchTerm.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')})`, 'gi'); .filter(Boolean)
return text.replace(regex, '<span style="background-color: rgba(66, 153, 225, 0.3); color: white; padding: 1px 2px; border-radius: 2px;">$1</span>'); .map((token) => token.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'));
if (!sanitizedTokens.length) {
return text;
}
const regex = new RegExp(`(${sanitizedTokens.join('|')})`, 'gi');
return text.replace(
regex,
'<span style="background-color: rgba(66, 153, 225, 0.3); color: white; padding: 1px 2px; border-radius: 2px;">$1</span>',
);
} }
showPreviewForItem(relativePath, itemElement) { showPreviewForItem(relativePath, itemElement) {