diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py
index c7f9de64..84db592b 100644
--- a/py/services/base_model_service.py
+++ b/py/services/base_model_service.py
@@ -648,13 +648,55 @@ class BaseModelService(ABC):
             return None
         return metadata.modelDescription or ''
 
+    @staticmethod
+    def _parse_search_tokens(search_term: str) -> tuple[List[str], List[str]]:
+        """Split a search string into include and exclude tokens."""
+        include_terms: List[str] = []
+        exclude_terms: List[str] = []
+
+        for raw_term in search_term.split():
+            term = raw_term.strip()
+            if not term:
+                continue
+
+            if term.startswith("-") and len(term) > 1:
+                exclude_terms.append(term[1:].lower())
+            else:
+                include_terms.append(term.lower())
+
+        return include_terms, exclude_terms
+
+    @staticmethod
+    def _relative_path_matches_tokens(
+        path_lower: str, include_terms: List[str], exclude_terms: List[str]
+    ) -> bool:
+        """Determine whether a relative path string satisfies include/exclude tokens."""
+        if any(term and term in path_lower for term in exclude_terms):
+            return False
+
+        for term in include_terms:
+            if term and term not in path_lower:
+                return False
+
+        return True
+
+    @staticmethod
+    def _relative_path_sort_key(relative_path: str, include_terms: List[str]) -> tuple:
+        """Sort paths by how well they satisfy the include tokens."""
+        path_lower = relative_path.lower()
+        prefix_hits = sum(1 for term in include_terms if term and path_lower.startswith(term))
+        match_positions = [path_lower.find(term) for term in include_terms if term and term in path_lower]
+        first_match_index = min(match_positions) if match_positions else 0
+
+        return (-prefix_hits, first_match_index, len(relative_path), path_lower)
+
     async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
         """Search model relative file paths for autocomplete functionality"""
         cache = await self.scanner.get_cached_data()
 
+        include_terms, exclude_terms = self._parse_search_tokens(search_term)
         matching_paths = []
-        search_lower = search_term.lower()
 
         # Get model roots for path calculation
         model_roots = self.scanner.get_model_roots()
 
@@ -676,17 +718,19 @@ class BaseModelService(ABC):
                     relative_path = normalized_file[len(normalized_root):].lstrip(os.sep)
                     break
 
-            if relative_path and search_lower in relative_path.lower():
+            if not relative_path:
+                continue
+
+            relative_lower = relative_path.lower()
+            if self._relative_path_matches_tokens(relative_lower, include_terms, exclude_terms):
                 matching_paths.append(relative_path)
                 if len(matching_paths) >= limit * 2:  # Get more for better sorting
                     break
 
-        # Sort by relevance (exact matches first, then by length)
-        matching_paths.sort(key=lambda x: (
-            not x.lower().startswith(search_lower),  # Exact prefix matches first
-            len(x),  # Then by length (shorter first)
-            x.lower()  # Then alphabetically
-        ))
+        # Sort by relevance (prefix and earliest hits first, then by length and alphabetically)
+        matching_paths.sort(
+            key=lambda relative: self._relative_path_sort_key(relative, include_terms)
+        )
 
         return matching_paths[:limit]
 
diff --git a/tests/frontend/components/autocomplete.behavior.test.js b/tests/frontend/components/autocomplete.behavior.test.js
index d0497be7..16a79ff0 100644
--- a/tests/frontend/components/autocomplete.behavior.test.js
+++ b/tests/frontend/components/autocomplete.behavior.test.js
@@ -136,4 +136,23 @@ describe('AutoComplete widget interactions', () => {
     expect(input.focus).toHaveBeenCalled();
     expect(input.setSelectionRange).toHaveBeenCalled();
   });
+
+  it('highlights multiple include tokens while ignoring excluded ones', async () => {
+    const input = document.createElement('textarea');
+    document.body.append(input);
+
+    const { AutoComplete } = await import(AUTOCOMPLETE_MODULE);
+    const autoComplete = new AutoComplete(input, 'loras', { showPreview: false });
+
+    const highlighted = autoComplete.highlightMatch(
+      'models/flux/beta-detail.safetensors',
+      'flux detail -beta',
+    );
+
+    const highlightCount = (highlighted.match(/<mark>/g) || []).length;
+
+    expect(highlightCount).toBe(2);
+    expect(highlighted).toMatch(/<mark>flux<\/mark>/i);
+    expect(highlighted).not.toMatch(/<mark>beta<\/mark>/i);
+  });
 });
diff --git a/tests/services/test_relative_path_search.py b/tests/services/test_relative_path_search.py
new file mode 100644
index 00000000..e26c98ea
--- /dev/null
+++ b/tests/services/test_relative_path_search.py
@@ -0,0 +1,63 @@
+import pytest
+
+from py.services.base_model_service import BaseModelService
+from py.utils.models import BaseModelMetadata
+
+
+class DummyService(BaseModelService):
+    async def format_response(self, model_data):
+        return model_data
+
+
+class FakeCache:
+    def __init__(self, raw_data):
+        self.raw_data = list(raw_data)
+
+
+class FakeScanner:
+    def __init__(self, raw_data, roots):
+        self._cache = FakeCache(raw_data)
+        self._roots = list(roots)
+
+    async def get_cached_data(self, *_args, **_kwargs):
+        return self._cache
+
+    def get_model_roots(self):
+        return list(self._roots)
+
+
+@pytest.mark.asyncio
+async def test_search_relative_paths_supports_multiple_tokens():
+    scanner = FakeScanner(
+        [
+            {"file_path": "/models/flux/detail-model.safetensors"},
+            {"file_path": "/models/flux/only-flux.safetensors"},
+            {"file_path": "/models/detail/flux-trained.safetensors"},
+            {"file_path": "/models/detail/standalone.safetensors"},
+        ],
+        ["/models"],
+    )
+    service = DummyService("stub", scanner, BaseModelMetadata)
+
+    matching = await service.search_relative_paths("flux detail")
+
+    assert matching == [
+        "flux/detail-model.safetensors",
+        "detail/flux-trained.safetensors",
+    ]
+
+
+@pytest.mark.asyncio
+async def test_search_relative_paths_excludes_tokens():
+    scanner = FakeScanner(
+        [
+            {"file_path": "/models/flux/detail-model.safetensors"},
+            {"file_path": "/models/flux/keep-me.safetensors"},
+        ],
+        ["/models"],
+    )
+    service = DummyService("stub", scanner, BaseModelMetadata)
+
+    matching = await service.search_relative_paths("flux -detail")
+
+    assert matching == ["flux/keep-me.safetensors"]
diff --git a/web/comfyui/autocomplete.js b/web/comfyui/autocomplete.js
index aa0d510c..574cd037 100644
--- a/web/comfyui/autocomplete.js
+++ b/web/comfyui/autocomplete.js
@@ -32,6 +32,25 @@ function removeLoraExtension(fileName = '') {
     return fileName.replace(/\.(safetensors|ckpt|pt|bin)$/i, '');
 }
 
+function parseSearchTokens(term = '') {
+    const include = [];
+    const exclude = [];
+
+    term.split(/\s+/).forEach((rawTerm) => {
+        const token = rawTerm.trim();
+        if (!token) {
+            return;
+        }
+        if (token.startsWith('-') && token.length > 1) {
+            exclude.push(token.slice(1).toLowerCase());
+        } else {
+            include.push(token.toLowerCase());
+        }
+    });
+
+    return { include, exclude };
+}
+
 function createDefaultBehavior(modelType) {
     return {
         enablePreview: false,
@@ -393,10 +412,20 @@ class AutoComplete {
    }
 
     highlightMatch(text, searchTerm) {
-        if (!searchTerm) return text;
-
-        const regex = new RegExp(`(${searchTerm.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')})`, 'gi');
-        return text.replace(regex, '<mark>$1</mark>');
+        const { include } = parseSearchTokens(searchTerm);
+        const sanitizedTokens = include
+            .filter(Boolean)
+            .map((token) => token.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'));
+
+        if (!sanitizedTokens.length) {
+            return text;
+        }
+
+        const regex = new RegExp(`(${sanitizedTokens.join('|')})`, 'gi');
+        return text.replace(
+            regex,
+            '<mark>$1</mark>',
+        );
     }
 
     showPreviewForItem(relativePath, itemElement) {