diff --git a/docs/frontend-filtering-test-matrix.md b/docs/frontend-filtering-test-matrix.md index cd8f48d2..d6efd465 100644 --- a/docs/frontend-filtering-test-matrix.md +++ b/docs/frontend-filtering-test-matrix.md @@ -21,7 +21,7 @@ This matrix captures the scenarios that Phase 3 frontend tests should cover for | ID | Feature | Scenario | LoRAs Expectations | Checkpoints Expectations | Notes | | --- | --- | --- | --- | --- | --- | | F-01 | Search filter | Typing a query updates `pageState.filters.search`, persists to session, and triggers `resetAndReload` on submit | Validate `SearchManager` writes query and reloads via API stub; confirm LoRA cards pass query downstream | Same as LoRAs | Cover `enter` press and clicking search icon | -| F-02 | Tag filter | Selecting a tag chip adds it to filters, applies active styling, and reloads results | Tag stored under `filters.tags`; `FilterManager.applyFilters` persists and triggers `resetAndReload(true)` | Same; ensure base model tag set is scoped to checkpoints dataset | Include removal path | +| F-02 | Tag filter | Selecting a tag chip cycles include ➜ exclude ➜ clear, updates storage, and reloads results | Tag state stored under `filters.tags[tagName] = 'include'|'exclude'`; `FilterManager.applyFilters` persists and triggers `resetAndReload(true)` | Same; ensure base model tag set is scoped to checkpoints dataset | Include removal path | | F-03 | Base model filter | Toggling base model checkboxes updates `filters.baseModel`, persists, and reloads | Ensure only LoRA-supported models show; toggle multi-select | Ensure SDXL/Flux base models appear as expected | Capture UI state restored from storage on next init | | F-04 | Favorites-only | Clicking favorites toggle updates session flag and calls `resetAndReload(true)` | Button gains `.active` class and API called | Same | Verify duplicates badge refresh when active | | F-05 | Sort selection | Changing sort select saves preference (legacy + new format) and reloads | Confirm `PageControls.saveSortPreference` invoked with option and API called | Same with checkpoints-specific defaults | Cover `convertLegacySortFormat` branch | diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index fa68f55f..7268e0ca 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -144,7 +144,28 @@ class ModelListingHandler: fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true" base_models = request.query.getall("base_model", []) - tags = request.query.getall("tag", []) + # Support legacy ?tag=foo plus new ?tag_include/foo & ?tag_exclude parameters + legacy_tags = request.query.getall("tag", []) + if not legacy_tags: + legacy_csv = request.query.get("tags") + if legacy_csv: + legacy_tags = [tag.strip() for tag in legacy_csv.split(",") if tag.strip()] + + include_tags = request.query.getall("tag_include", []) + exclude_tags = request.query.getall("tag_exclude", []) + + tag_filters: Dict[str, str] = {} + for tag in legacy_tags: + if tag: + tag_filters[tag] = "include" + + for tag in include_tags: + if tag: + tag_filters[tag] = "include" + + for tag in exclude_tags: + if tag: + tag_filters[tag] = "exclude" favorites_only = request.query.get("favorites_only", "false").lower() == "true" search_options = { @@ -189,7 +210,7 @@ class ModelListingHandler: "search": search, "fuzzy_search": fuzzy_search, "base_models": base_models, - "tags": tags, + "tags": tag_filters, "search_options": search_options, "hash_filters": hash_filters, "favorites_only": favorites_only, diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py index d8cebb2c..5d76f885 100644 --- a/py/routes/handlers/recipe_handlers.py +++ b/py/routes/handlers/recipe_handlers.py @@ -152,14 +152,31 @@ class RecipeListingHandler: "lora_model": request.query.get("search_lora_model", "true").lower() == "true", } - filters: Dict[str, list[str]] = {} + filters: Dict[str, Any] = {} base_models = request.query.get("base_models") if base_models: filters["base_model"] = base_models.split(",") - tags = request.query.get("tags") - if tags: - filters["tags"] = tags.split(",") + tag_filters: Dict[str, str] = {} + legacy_tags = request.query.get("tags") + if legacy_tags: + for tag in legacy_tags.split(","): + tag = tag.strip() + if tag: + tag_filters[tag] = "include" + + include_tags = request.query.getall("tag_include", []) + for tag in include_tags: + if tag: + tag_filters[tag] = "include" + + exclude_tags = request.query.getall("tag_exclude", []) + for tag in exclude_tags: + if tag: + tag_filters[tag] = "exclude" + + if tag_filters: + filters["tags"] = tag_filters lora_hash = request.query.get("lora_hash") diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index d90c2ed8..d28b8f72 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -59,7 +59,7 @@ class BaseModelService(ABC): search: str = None, fuzzy_search: bool = False, base_models: list = None, - tags: list = None, + tags: Optional[Dict[str, str]] = None, search_options: dict = None, hash_filters: dict = None, favorites_only: bool = False, @@ -149,7 +149,7 @@ class BaseModelService(ABC): data: List[Dict], folder: str = None, base_models: list = None, - tags: list = None, + tags: Optional[Dict[str, str]] = None, favorites_only: bool = False, search_options: dict = None, ) -> List[Dict]: diff --git a/py/services/model_query.py b/py/services/model_query.py index 100f8f8b..d88e9631 100644 --- a/py/services/model_query.py +++ b/py/services/model_query.py @@ -28,7 +28,7 @@ class FilterCriteria: folder: Optional[str] = None base_models: Optional[Sequence[str]] = None - tags: Optional[Sequence[str]] = None + tags: Optional[Dict[str, str]] = None favorites_only: bool = False search_options: Optional[Dict[str, Any]] = None @@ -108,12 +108,30 @@ class ModelFilterSet: base_model_set = set(base_models) items = [item for item in items if item.get("base_model") in base_model_set] - tags = criteria.tags or [] - if tags: - tag_set = set(tags) + tag_filters = criteria.tags or {} + include_tags = set() + exclude_tags = set() + if isinstance(tag_filters, dict): + for tag, state in tag_filters.items(): + if not tag: + continue + if state == "exclude": + exclude_tags.add(tag) + else: + include_tags.add(tag) + else: + include_tags = {tag for tag in tag_filters if tag} + + if include_tags: items = [ item for item in items - if any(tag in tag_set for tag in item.get("tags", [])) + if any(tag in include_tags for tag in (item.get("tags", []) or [])) + ] + + if exclude_tags: + items = [ + item for item in items + if not any(tag in exclude_tags for tag in (item.get("tags", []) or [])) ] return items diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index ea27a924..b5f1ce2f 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -729,10 +729,32 @@ class RecipeScanner: # Filter by tags if 'tags' in filters and filters['tags']: - filtered_data = [ - item for item in filtered_data - if any(tag in item.get('tags', []) for tag in filters['tags']) - ] + tag_spec = filters['tags'] + include_tags = set() + exclude_tags = set() + + if isinstance(tag_spec, dict): + for tag, state in tag_spec.items(): + if not tag: + continue + if state == 'exclude': + exclude_tags.add(tag) + else: + include_tags.add(tag) + else: + include_tags = {tag for tag in tag_spec if tag} + + if include_tags: + filtered_data = [ + item for item in filtered_data + if any(tag in include_tags for tag in (item.get('tags', []) or [])) + ] + + if exclude_tags: + filtered_data = [ + item for item in filtered_data + if not any(tag in exclude_tags for tag in (item.get('tags', []) or [])) + ] # Calculate pagination total_items = len(filtered_data) diff --git a/static/js/api/baseModelApi.js b/static/js/api/baseModelApi.js index 9034a465..9b6c2300 100644 --- a/static/js/api/baseModelApi.js +++ b/static/js/api/baseModelApi.js @@ -806,9 +806,13 @@ export class BaseModelApiClient { params.append('recursive', pageState.searchOptions.recursive ? 'true' : 'false'); if (pageState.filters) { - if (pageState.filters.tags && pageState.filters.tags.length > 0) { - pageState.filters.tags.forEach(tag => { - params.append('tag', tag); + if (pageState.filters.tags && Object.keys(pageState.filters.tags).length > 0) { + Object.entries(pageState.filters.tags).forEach(([tag, state]) => { + if (state === 'include') { + params.append('tag_include', tag); + } else if (state === 'exclude') { + params.append('tag_exclude', tag); + } }); } diff --git a/static/js/api/recipeApi.js b/static/js/api/recipeApi.js index 3b912905..ece9938f 100644 --- a/static/js/api/recipeApi.js +++ b/static/js/api/recipeApi.js @@ -66,8 +66,14 @@ export async function fetchRecipesPage(page = 1, pageSize = 100) { } // Add tag filters - if (pageState.filters?.tags && pageState.filters.tags.length) { - params.append('tags', pageState.filters.tags.join(',')); + if (pageState.filters?.tags && Object.keys(pageState.filters.tags).length) { + Object.entries(pageState.filters.tags).forEach(([tag, state]) => { + if (state === 'include') { + params.append('tag_include', tag); + } else if (state === 'exclude') { + params.append('tag_exclude', tag); + } + }); } } diff --git a/static/js/managers/FilterManager.js b/static/js/managers/FilterManager.js index 31d416ad..1f0500c2 100644 --- a/static/js/managers/FilterManager.js +++ b/static/js/managers/FilterManager.js @@ -12,11 +12,7 @@ export class FilterManager { this.currentPage = options.page || document.body.dataset.page || 'loras'; const pageState = getCurrentPageState(); - this.filters = pageState.filters || { - baseModel: [], - tags: [], - license: {} - }; + this.filters = this.initializeFilters(pageState ? pageState.filters : undefined); this.filterPanel = document.getElementById('filterPanel'); this.filterButton = document.getElementById('filterButton'); @@ -28,6 +24,7 @@ export class FilterManager { // Store this instance in the state if (pageState) { pageState.filterManager = this; + pageState.filters = this.cloneFilters(); } } @@ -111,17 +108,12 @@ export class FilterManager { tagEl.dataset.tag = tagName; tagEl.innerHTML = `${tagName} ${tag.count}`; - // Add click handler to toggle selection and automatically apply + // Add click handler to cycle through tri-state filter and automatically apply tagEl.addEventListener('click', async () => { - tagEl.classList.toggle('active'); - - if (tagEl.classList.contains('active')) { - if (!this.filters.tags.includes(tagName)) { - this.filters.tags.push(tagName); - } - } else { - this.filters.tags = this.filters.tags.filter(t => t !== tagName); - } + const currentState = (this.filters.tags && this.filters.tags[tagName]) || 'none'; + const newState = this.getNextTriStateState(currentState); + this.setTagFilterState(tagName, newState); + this.applyTagElementState(tagEl, newState); this.updateActiveFiltersCount(); @@ -129,6 +121,7 @@ export class FilterManager { await this.applyFilters(false); }); + this.applyTagElementState(tagEl, (this.filters.tags && this.filters.tags[tagName]) || 'none'); tagsContainer.appendChild(tagEl); }); } @@ -310,11 +303,8 @@ export class FilterManager { const modelTags = document.querySelectorAll('.tag-filter'); modelTags.forEach(tag => { const tagName = tag.dataset.tag; - if (this.filters.tags.includes(tagName)) { - tag.classList.add('active'); - } else { - tag.classList.remove('active'); - } + const state = (this.filters.tags && this.filters.tags[tagName]) || 'none'; + this.applyTagElementState(tag, state); }); // Update license tags @@ -322,9 +312,9 @@ export class FilterManager { } updateActiveFiltersCount() { - const totalActiveFilters = this.filters.baseModel.length + - this.filters.tags.length + - (this.filters.license ? Object.keys(this.filters.license).length : 0); + const tagFilterCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0; + const licenseFilterCount = this.filters.license ? Object.keys(this.filters.license).length : 0; + const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount; if (this.activeFiltersCount) { if (totalActiveFilters > 0) { @@ -341,10 +331,11 @@ export class FilterManager { const storageKey = `${this.currentPage}_filters`; // Save filters to localStorage - setStorageItem(storageKey, this.filters); + const filtersSnapshot = this.cloneFilters(); + setStorageItem(storageKey, filtersSnapshot); // Update state with current filters - pageState.filters = { ...this.filters }; + pageState.filters = filtersSnapshot; // Call the appropriate manager's load method based on page type if (this.currentPage === 'recipes' && window.recipeManager) { @@ -359,7 +350,7 @@ export class FilterManager { this.filterButton.classList.add('active'); if (showToastNotification) { const baseModelCount = this.filters.baseModel.length; - const tagsCount = this.filters.tags.length; + const tagsCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0; let message = ''; if (baseModelCount > 0 && tagsCount > 0) { @@ -382,15 +373,16 @@ export class FilterManager { async clearFilters() { // Clear all filters - this.filters = { + this.filters = this.initializeFilters({ + ...this.filters, baseModel: [], - tags: [], - license: {} // Initialize with empty object instead of deleting - }; + tags: {}, + license: {} + }); // Update state const pageState = getCurrentPageState(); - pageState.filters = { ...this.filters }; + pageState.filters = this.cloneFilters(); // Update UI this.updateTagSelections(); @@ -424,15 +416,11 @@ export class FilterManager { if (savedFilters) { try { // Ensure backward compatibility with older filter format - this.filters = { - baseModel: savedFilters.baseModel || [], - tags: savedFilters.tags || [], - license: savedFilters.license || {} - }; + this.filters = this.initializeFilters(savedFilters); // Update state with loaded filters const pageState = getCurrentPageState(); - pageState.filters = { ...this.filters }; + pageState.filters = this.cloneFilters(); this.updateTagSelections(); this.updateActiveFiltersCount(); @@ -447,8 +435,109 @@ export class FilterManager { } hasActiveFilters() { - return this.filters.baseModel.length > 0 || - this.filters.tags.length > 0 || - (this.filters.license && Object.keys(this.filters.license).length > 0); + const tagCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0; + const licenseCount = this.filters.license ? Object.keys(this.filters.license).length : 0; + return this.filters.baseModel.length > 0 || tagCount > 0 || licenseCount > 0; + } + + initializeFilters(existingFilters = {}) { + const source = existingFilters || {}; + return { + ...source, + baseModel: Array.isArray(source.baseModel) ? [...source.baseModel] : [], + tags: this.normalizeTagFilters(source.tags), + license: this.normalizeLicenseFilters(source.license) + }; + } + + normalizeTagFilters(tagFilters) { + if (!tagFilters) { + return {}; + } + + if (Array.isArray(tagFilters)) { + return tagFilters.reduce((acc, tag) => { + if (typeof tag === 'string' && tag.trim().length > 0) { + acc[tag] = 'include'; + } + return acc; + }, {}); + } + + if (typeof tagFilters === 'object') { + const normalized = {}; + Object.entries(tagFilters).forEach(([tag, state]) => { + if (!tag) { + return; + } + const normalizedState = typeof state === 'string' ? state.toLowerCase() : ''; + if (normalizedState === 'include' || normalizedState === 'exclude') { + normalized[tag] = normalizedState; + } + }); + return normalized; + } + + return {}; + } + + normalizeLicenseFilters(licenseFilters) { + if (!licenseFilters || typeof licenseFilters !== 'object') { + return {}; + } + + const normalized = {}; + Object.entries(licenseFilters).forEach(([key, state]) => { + const normalizedState = typeof state === 'string' ? state.toLowerCase() : ''; + if (normalizedState === 'include' || normalizedState === 'exclude') { + normalized[key] = normalizedState; + } + }); + return normalized; + } + + cloneFilters() { + return { + ...this.filters, + baseModel: [...(this.filters.baseModel || [])], + tags: { ...(this.filters.tags || {}) }, + license: { ...(this.filters.license || {}) } + }; + } + + getNextTriStateState(currentState) { + switch (currentState) { + case 'none': + return 'include'; + case 'include': + return 'exclude'; + default: + return 'none'; + } + } + + setTagFilterState(tagName, state) { + if (!this.filters.tags) { + this.filters.tags = {}; + } + + if (state === 'none') { + delete this.filters.tags[tagName]; + } else { + this.filters.tags[tagName] = state; + } + } + + applyTagElementState(element, state) { + if (!element) { + return; + } + + element.classList.remove('active', 'exclude'); + if (state === 'include') { + element.classList.add('active'); + } else if (state === 'exclude') { + element.classList.add('exclude'); + } } } diff --git a/static/js/state/index.js b/static/js/state/index.js index 2d6b183c..7df333c9 100644 --- a/static/js/state/index.js +++ b/static/js/state/index.js @@ -66,18 +66,19 @@ export const state = { activeFolder: getStorageItem(`${MODEL_TYPES.LORA}_activeFolder`), activeLetterFilter: null, previewVersions: loraPreviewVersions, - searchManager: null, - searchOptions: { - filename: true, - modelname: true, - tags: false, - creator: false, - recursive: getStorageItem(`${MODEL_TYPES.LORA}_recursiveSearch`, true), - }, - filters: { - baseModel: [], - tags: [] - }, + searchManager: null, + searchOptions: { + filename: true, + modelname: true, + tags: false, + creator: false, + recursive: getStorageItem(`${MODEL_TYPES.LORA}_recursiveSearch`, true), + }, + filters: { + baseModel: [], + tags: {}, + license: {} + }, bulkMode: false, selectedLoras: new Set(), loraMetadataCache: new Map(), @@ -91,18 +92,19 @@ export const state = { isLoading: false, hasMore: true, sortBy: 'date', - searchManager: null, - searchOptions: { - title: true, - tags: true, - loraName: true, - loraModel: true - }, - filters: { - baseModel: [], - tags: [], - search: '' - }, + searchManager: null, + searchOptions: { + title: true, + tags: true, + loraName: true, + loraModel: true + }, + filters: { + baseModel: [], + tags: {}, + license: {}, + search: '' + }, pageSize: 20, showFavoritesOnly: false, duplicatesMode: false, @@ -117,17 +119,18 @@ export const state = { sortBy: 'name', activeFolder: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_activeFolder`), previewVersions: checkpointPreviewVersions, - searchManager: null, - searchOptions: { - filename: true, - modelname: true, - creator: false, - recursive: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_recursiveSearch`, true), - }, - filters: { - baseModel: [], - tags: [] - }, + searchManager: null, + searchOptions: { + filename: true, + modelname: true, + creator: false, + recursive: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_recursiveSearch`, true), + }, + filters: { + baseModel: [], + tags: {}, + license: {} + }, modelType: 'checkpoint', // 'checkpoint' or 'diffusion_model' bulkMode: false, selectedModels: new Set(), @@ -145,18 +148,19 @@ export const state = { activeFolder: getStorageItem(`${MODEL_TYPES.EMBEDDING}_activeFolder`), activeLetterFilter: null, previewVersions: embeddingPreviewVersions, - searchManager: null, - searchOptions: { - filename: true, - modelname: true, - tags: false, - creator: false, - recursive: getStorageItem(`${MODEL_TYPES.EMBEDDING}_recursiveSearch`, true), - }, - filters: { - baseModel: [], - tags: [] - }, + searchManager: null, + searchOptions: { + filename: true, + modelname: true, + tags: false, + creator: false, + recursive: getStorageItem(`${MODEL_TYPES.EMBEDDING}_recursiveSearch`, true), + }, + filters: { + baseModel: [], + tags: {}, + license: {} + }, bulkMode: false, selectedModels: new Set(), metadataCache: new Map(), diff --git a/tests/frontend/components/pageControls.filtering.test.js b/tests/frontend/components/pageControls.filtering.test.js index 3aa12a51..3954a394 100644 --- a/tests/frontend/components/pageControls.filtering.test.js +++ b/tests/frontend/components/pageControls.filtering.test.js @@ -225,21 +225,31 @@ describe('FilterManager tag and base model filters', () => { tagChip.dispatchEvent(new Event('click', { bubbles: true })); await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1)); - expect(getCurrentPageState().filters.tags).toEqual(['style']); + expect(getCurrentPageState().filters.tags).toEqual({ style: 'include' }); expect(tagChip.classList.contains('active')).toBe(true); expect(document.getElementById('activeFiltersCount').textContent).toBe('1'); expect(document.getElementById('activeFiltersCount').style.display).toBe('inline-flex'); const storageKey = `lora_manager_${pageKey}_filters`; const storedFilters = JSON.parse(localStorage.getItem(storageKey)); - expect(storedFilters.tags).toEqual(['style']); + expect(storedFilters.tags).toEqual({ style: 'include' }); loadMoreWithVirtualScrollMock.mockClear(); tagChip.dispatchEvent(new Event('click', { bubbles: true })); await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1)); - expect(getCurrentPageState().filters.tags).toEqual([]); + expect(getCurrentPageState().filters.tags).toEqual({ style: 'exclude' }); + expect(tagChip.classList.contains('exclude')).toBe(true); + expect(tagChip.classList.contains('active')).toBe(false); + expect(document.getElementById('activeFiltersCount').textContent).toBe('1'); + + loadMoreWithVirtualScrollMock.mockClear(); + + tagChip.dispatchEvent(new Event('click', { bubbles: true })); + await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1)); + + expect(getCurrentPageState().filters.tags).toEqual({}); expect(document.getElementById('activeFiltersCount').style.display).toBe('none'); }); diff --git a/tests/services/test_base_model_service.py b/tests/services/test_base_model_service.py index f57db314..77a2a1b4 100644 --- a/tests/services/test_base_model_service.py +++ b/tests/services/test_base_model_service.py @@ -5,6 +5,7 @@ from py.services.lora_service import LoraService from py.services.checkpoint_service import CheckpointService from py.services.embedding_service import EmbeddingService from py.services.model_query import ( + FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, @@ -126,7 +127,7 @@ async def test_get_paginated_data_uses_injected_collaborators(): search="query", fuzzy_search=True, base_models=["base"], - tags=["tag"], + tags={"tag": "include"}, search_options={"recursive": False}, favorites_only=True, ) @@ -141,7 +142,7 @@ async def test_get_paginated_data_uses_injected_collaborators(): assert call_data == data assert criteria.folder == "root" assert criteria.base_models == ["base"] - assert criteria.tags == ["tag"] + assert criteria.tags == {"tag": "include"} assert criteria.favorites_only is True assert criteria.search_options.get("recursive") is False @@ -234,7 +235,7 @@ async def test_get_paginated_data_filters_and_searches_combination(): folder="root", search="artist", base_models=["v1"], - tags=["tag1"], + tags={"tag1": "include"}, search_options={"creator": True, "tags": True}, favorites_only=True, ) @@ -533,6 +534,36 @@ async def test_get_paginated_data_update_available_only_without_update_service() assert response["total_pages"] == 0 +def test_model_filter_set_handles_include_and_exclude_tag_filters(): + settings = StubSettings({}) + filter_set = ModelFilterSet(settings) + data = [ + {"model_name": "StyleOnly", "tags": ["style"]}, + {"model_name": "StyleAnime", "tags": ["style", "anime"]}, + {"model_name": "AnimeOnly", "tags": ["anime"]}, + ] + + criteria = FilterCriteria(tags={"style": "include", "anime": "exclude"}) + result = filter_set.apply(data, criteria) + + assert [item["model_name"] for item in result] == ["StyleOnly"] + + +def test_model_filter_set_supports_legacy_tag_arrays(): + settings = StubSettings({}) + filter_set = ModelFilterSet(settings) + data = [ + {"model_name": "StyleOnly", "tags": ["style"]}, + {"model_name": "StyleAnime", "tags": ["style", "anime"]}, + {"model_name": "AnimeOnly", "tags": ["anime"]}, + ] + + criteria = FilterCriteria(tags=["style"]) + result = filter_set.apply(data, criteria) + + assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"] + + @pytest.mark.asyncio @pytest.mark.parametrize( "service_cls, extra_fields",