Merge pull request #651 from willmiao/tag-filtering-with-include-exclude-states, see #622

feat: implement tag filtering with include/exclude states
This commit is contained in:
pixelpaws
2025-11-08 12:01:13 +08:00
committed by GitHub
12 changed files with 338 additions and 116 deletions

View File

@@ -21,7 +21,7 @@ This matrix captures the scenarios that Phase 3 frontend tests should cover for
| ID | Feature | Scenario | LoRAs Expectations | Checkpoints Expectations | Notes | | ID | Feature | Scenario | LoRAs Expectations | Checkpoints Expectations | Notes |
| --- | --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- | --- |
| F-01 | Search filter | Typing a query updates `pageState.filters.search`, persists to session, and triggers `resetAndReload` on submit | Validate `SearchManager` writes query and reloads via API stub; confirm LoRA cards pass query downstream | Same as LoRAs | Cover `enter` press and clicking search icon | | F-01 | Search filter | Typing a query updates `pageState.filters.search`, persists to session, and triggers `resetAndReload` on submit | Validate `SearchManager` writes query and reloads via API stub; confirm LoRA cards pass query downstream | Same as LoRAs | Cover `enter` press and clicking search icon |
| F-02 | Tag filter | Selecting a tag chip adds it to filters, applies active styling, and reloads results | Tag stored under `filters.tags`; `FilterManager.applyFilters` persists and triggers `resetAndReload(true)` | Same; ensure base model tag set is scoped to checkpoints dataset | Include removal path | | F-02 | Tag filter | Selecting a tag chip cycles include ➜ exclude ➜ clear, updates storage, and reloads results | Tag state stored under `filters.tags[tagName] = 'include'|'exclude'`; `FilterManager.applyFilters` persists and triggers `resetAndReload(true)` | Same; ensure base model tag set is scoped to checkpoints dataset | Include removal path |
| F-03 | Base model filter | Toggling base model checkboxes updates `filters.baseModel`, persists, and reloads | Ensure only LoRA-supported models show; toggle multi-select | Ensure SDXL/Flux base models appear as expected | Capture UI state restored from storage on next init | | F-03 | Base model filter | Toggling base model checkboxes updates `filters.baseModel`, persists, and reloads | Ensure only LoRA-supported models show; toggle multi-select | Ensure SDXL/Flux base models appear as expected | Capture UI state restored from storage on next init |
| F-04 | Favorites-only | Clicking favorites toggle updates session flag and calls `resetAndReload(true)` | Button gains `.active` class and API called | Same | Verify duplicates badge refresh when active | | F-04 | Favorites-only | Clicking favorites toggle updates session flag and calls `resetAndReload(true)` | Button gains `.active` class and API called | Same | Verify duplicates badge refresh when active |
| F-05 | Sort selection | Changing sort select saves preference (legacy + new format) and reloads | Confirm `PageControls.saveSortPreference` invoked with option and API called | Same with checkpoints-specific defaults | Cover `convertLegacySortFormat` branch | | F-05 | Sort selection | Changing sort select saves preference (legacy + new format) and reloads | Confirm `PageControls.saveSortPreference` invoked with option and API called | Same with checkpoints-specific defaults | Cover `convertLegacySortFormat` branch |

View File

@@ -144,7 +144,28 @@ class ModelListingHandler:
fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true" fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true"
base_models = request.query.getall("base_model", []) base_models = request.query.getall("base_model", [])
tags = request.query.getall("tag", []) # Support legacy ?tag=foo plus new ?tag_include/foo & ?tag_exclude parameters
legacy_tags = request.query.getall("tag", [])
if not legacy_tags:
legacy_csv = request.query.get("tags")
if legacy_csv:
legacy_tags = [tag.strip() for tag in legacy_csv.split(",") if tag.strip()]
include_tags = request.query.getall("tag_include", [])
exclude_tags = request.query.getall("tag_exclude", [])
tag_filters: Dict[str, str] = {}
for tag in legacy_tags:
if tag:
tag_filters[tag] = "include"
for tag in include_tags:
if tag:
tag_filters[tag] = "include"
for tag in exclude_tags:
if tag:
tag_filters[tag] = "exclude"
favorites_only = request.query.get("favorites_only", "false").lower() == "true" favorites_only = request.query.get("favorites_only", "false").lower() == "true"
search_options = { search_options = {
@@ -189,7 +210,7 @@ class ModelListingHandler:
"search": search, "search": search,
"fuzzy_search": fuzzy_search, "fuzzy_search": fuzzy_search,
"base_models": base_models, "base_models": base_models,
"tags": tags, "tags": tag_filters,
"search_options": search_options, "search_options": search_options,
"hash_filters": hash_filters, "hash_filters": hash_filters,
"favorites_only": favorites_only, "favorites_only": favorites_only,

View File

@@ -152,14 +152,31 @@ class RecipeListingHandler:
"lora_model": request.query.get("search_lora_model", "true").lower() == "true", "lora_model": request.query.get("search_lora_model", "true").lower() == "true",
} }
filters: Dict[str, list[str]] = {} filters: Dict[str, Any] = {}
base_models = request.query.get("base_models") base_models = request.query.get("base_models")
if base_models: if base_models:
filters["base_model"] = base_models.split(",") filters["base_model"] = base_models.split(",")
tags = request.query.get("tags") tag_filters: Dict[str, str] = {}
if tags: legacy_tags = request.query.get("tags")
filters["tags"] = tags.split(",") if legacy_tags:
for tag in legacy_tags.split(","):
tag = tag.strip()
if tag:
tag_filters[tag] = "include"
include_tags = request.query.getall("tag_include", [])
for tag in include_tags:
if tag:
tag_filters[tag] = "include"
exclude_tags = request.query.getall("tag_exclude", [])
for tag in exclude_tags:
if tag:
tag_filters[tag] = "exclude"
if tag_filters:
filters["tags"] = tag_filters
lora_hash = request.query.get("lora_hash") lora_hash = request.query.get("lora_hash")

View File

@@ -59,7 +59,7 @@ class BaseModelService(ABC):
search: str = None, search: str = None,
fuzzy_search: bool = False, fuzzy_search: bool = False,
base_models: list = None, base_models: list = None,
tags: list = None, tags: Optional[Dict[str, str]] = None,
search_options: dict = None, search_options: dict = None,
hash_filters: dict = None, hash_filters: dict = None,
favorites_only: bool = False, favorites_only: bool = False,
@@ -149,7 +149,7 @@ class BaseModelService(ABC):
data: List[Dict], data: List[Dict],
folder: str = None, folder: str = None,
base_models: list = None, base_models: list = None,
tags: list = None, tags: Optional[Dict[str, str]] = None,
favorites_only: bool = False, favorites_only: bool = False,
search_options: dict = None, search_options: dict = None,
) -> List[Dict]: ) -> List[Dict]:

View File

@@ -28,7 +28,7 @@ class FilterCriteria:
folder: Optional[str] = None folder: Optional[str] = None
base_models: Optional[Sequence[str]] = None base_models: Optional[Sequence[str]] = None
tags: Optional[Sequence[str]] = None tags: Optional[Dict[str, str]] = None
favorites_only: bool = False favorites_only: bool = False
search_options: Optional[Dict[str, Any]] = None search_options: Optional[Dict[str, Any]] = None
@@ -108,12 +108,30 @@ class ModelFilterSet:
base_model_set = set(base_models) base_model_set = set(base_models)
items = [item for item in items if item.get("base_model") in base_model_set] items = [item for item in items if item.get("base_model") in base_model_set]
tags = criteria.tags or [] tag_filters = criteria.tags or {}
if tags: include_tags = set()
tag_set = set(tags) exclude_tags = set()
if isinstance(tag_filters, dict):
for tag, state in tag_filters.items():
if not tag:
continue
if state == "exclude":
exclude_tags.add(tag)
else:
include_tags.add(tag)
else:
include_tags = {tag for tag in tag_filters if tag}
if include_tags:
items = [ items = [
item for item in items item for item in items
if any(tag in tag_set for tag in item.get("tags", [])) if any(tag in include_tags for tag in (item.get("tags", []) or []))
]
if exclude_tags:
items = [
item for item in items
if not any(tag in exclude_tags for tag in (item.get("tags", []) or []))
] ]
return items return items

View File

@@ -729,10 +729,32 @@ class RecipeScanner:
# Filter by tags # Filter by tags
if 'tags' in filters and filters['tags']: if 'tags' in filters and filters['tags']:
filtered_data = [ tag_spec = filters['tags']
item for item in filtered_data include_tags = set()
if any(tag in item.get('tags', []) for tag in filters['tags']) exclude_tags = set()
]
if isinstance(tag_spec, dict):
for tag, state in tag_spec.items():
if not tag:
continue
if state == 'exclude':
exclude_tags.add(tag)
else:
include_tags.add(tag)
else:
include_tags = {tag for tag in tag_spec if tag}
if include_tags:
filtered_data = [
item for item in filtered_data
if any(tag in include_tags for tag in (item.get('tags', []) or []))
]
if exclude_tags:
filtered_data = [
item for item in filtered_data
if not any(tag in exclude_tags for tag in (item.get('tags', []) or []))
]
# Calculate pagination # Calculate pagination
total_items = len(filtered_data) total_items = len(filtered_data)

View File

@@ -806,9 +806,13 @@ export class BaseModelApiClient {
params.append('recursive', pageState.searchOptions.recursive ? 'true' : 'false'); params.append('recursive', pageState.searchOptions.recursive ? 'true' : 'false');
if (pageState.filters) { if (pageState.filters) {
if (pageState.filters.tags && pageState.filters.tags.length > 0) { if (pageState.filters.tags && Object.keys(pageState.filters.tags).length > 0) {
pageState.filters.tags.forEach(tag => { Object.entries(pageState.filters.tags).forEach(([tag, state]) => {
params.append('tag', tag); if (state === 'include') {
params.append('tag_include', tag);
} else if (state === 'exclude') {
params.append('tag_exclude', tag);
}
}); });
} }

View File

@@ -66,8 +66,14 @@ export async function fetchRecipesPage(page = 1, pageSize = 100) {
} }
// Add tag filters // Add tag filters
if (pageState.filters?.tags && pageState.filters.tags.length) { if (pageState.filters?.tags && Object.keys(pageState.filters.tags).length) {
params.append('tags', pageState.filters.tags.join(',')); Object.entries(pageState.filters.tags).forEach(([tag, state]) => {
if (state === 'include') {
params.append('tag_include', tag);
} else if (state === 'exclude') {
params.append('tag_exclude', tag);
}
});
} }
} }

View File

@@ -12,11 +12,7 @@ export class FilterManager {
this.currentPage = options.page || document.body.dataset.page || 'loras'; this.currentPage = options.page || document.body.dataset.page || 'loras';
const pageState = getCurrentPageState(); const pageState = getCurrentPageState();
this.filters = pageState.filters || { this.filters = this.initializeFilters(pageState ? pageState.filters : undefined);
baseModel: [],
tags: [],
license: {}
};
this.filterPanel = document.getElementById('filterPanel'); this.filterPanel = document.getElementById('filterPanel');
this.filterButton = document.getElementById('filterButton'); this.filterButton = document.getElementById('filterButton');
@@ -28,6 +24,7 @@ export class FilterManager {
// Store this instance in the state // Store this instance in the state
if (pageState) { if (pageState) {
pageState.filterManager = this; pageState.filterManager = this;
pageState.filters = this.cloneFilters();
} }
} }
@@ -111,17 +108,12 @@ export class FilterManager {
tagEl.dataset.tag = tagName; tagEl.dataset.tag = tagName;
tagEl.innerHTML = `${tagName} <span class="tag-count">${tag.count}</span>`; tagEl.innerHTML = `${tagName} <span class="tag-count">${tag.count}</span>`;
// Add click handler to toggle selection and automatically apply // Add click handler to cycle through tri-state filter and automatically apply
tagEl.addEventListener('click', async () => { tagEl.addEventListener('click', async () => {
tagEl.classList.toggle('active'); const currentState = (this.filters.tags && this.filters.tags[tagName]) || 'none';
const newState = this.getNextTriStateState(currentState);
if (tagEl.classList.contains('active')) { this.setTagFilterState(tagName, newState);
if (!this.filters.tags.includes(tagName)) { this.applyTagElementState(tagEl, newState);
this.filters.tags.push(tagName);
}
} else {
this.filters.tags = this.filters.tags.filter(t => t !== tagName);
}
this.updateActiveFiltersCount(); this.updateActiveFiltersCount();
@@ -129,6 +121,7 @@ export class FilterManager {
await this.applyFilters(false); await this.applyFilters(false);
}); });
this.applyTagElementState(tagEl, (this.filters.tags && this.filters.tags[tagName]) || 'none');
tagsContainer.appendChild(tagEl); tagsContainer.appendChild(tagEl);
}); });
} }
@@ -310,11 +303,8 @@ export class FilterManager {
const modelTags = document.querySelectorAll('.tag-filter'); const modelTags = document.querySelectorAll('.tag-filter');
modelTags.forEach(tag => { modelTags.forEach(tag => {
const tagName = tag.dataset.tag; const tagName = tag.dataset.tag;
if (this.filters.tags.includes(tagName)) { const state = (this.filters.tags && this.filters.tags[tagName]) || 'none';
tag.classList.add('active'); this.applyTagElementState(tag, state);
} else {
tag.classList.remove('active');
}
}); });
// Update license tags // Update license tags
@@ -322,9 +312,9 @@ export class FilterManager {
} }
updateActiveFiltersCount() { updateActiveFiltersCount() {
const totalActiveFilters = this.filters.baseModel.length + const tagFilterCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
this.filters.tags.length + const licenseFilterCount = this.filters.license ? Object.keys(this.filters.license).length : 0;
(this.filters.license ? Object.keys(this.filters.license).length : 0); const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount;
if (this.activeFiltersCount) { if (this.activeFiltersCount) {
if (totalActiveFilters > 0) { if (totalActiveFilters > 0) {
@@ -341,10 +331,11 @@ export class FilterManager {
const storageKey = `${this.currentPage}_filters`; const storageKey = `${this.currentPage}_filters`;
// Save filters to localStorage // Save filters to localStorage
setStorageItem(storageKey, this.filters); const filtersSnapshot = this.cloneFilters();
setStorageItem(storageKey, filtersSnapshot);
// Update state with current filters // Update state with current filters
pageState.filters = { ...this.filters }; pageState.filters = filtersSnapshot;
// Call the appropriate manager's load method based on page type // Call the appropriate manager's load method based on page type
if (this.currentPage === 'recipes' && window.recipeManager) { if (this.currentPage === 'recipes' && window.recipeManager) {
@@ -359,7 +350,7 @@ export class FilterManager {
this.filterButton.classList.add('active'); this.filterButton.classList.add('active');
if (showToastNotification) { if (showToastNotification) {
const baseModelCount = this.filters.baseModel.length; const baseModelCount = this.filters.baseModel.length;
const tagsCount = this.filters.tags.length; const tagsCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
let message = ''; let message = '';
if (baseModelCount > 0 && tagsCount > 0) { if (baseModelCount > 0 && tagsCount > 0) {
@@ -382,15 +373,16 @@ export class FilterManager {
async clearFilters() { async clearFilters() {
// Clear all filters // Clear all filters
this.filters = { this.filters = this.initializeFilters({
...this.filters,
baseModel: [], baseModel: [],
tags: [], tags: {},
license: {} // Initialize with empty object instead of deleting license: {}
}; });
// Update state // Update state
const pageState = getCurrentPageState(); const pageState = getCurrentPageState();
pageState.filters = { ...this.filters }; pageState.filters = this.cloneFilters();
// Update UI // Update UI
this.updateTagSelections(); this.updateTagSelections();
@@ -424,15 +416,11 @@ export class FilterManager {
if (savedFilters) { if (savedFilters) {
try { try {
// Ensure backward compatibility with older filter format // Ensure backward compatibility with older filter format
this.filters = { this.filters = this.initializeFilters(savedFilters);
baseModel: savedFilters.baseModel || [],
tags: savedFilters.tags || [],
license: savedFilters.license || {}
};
// Update state with loaded filters // Update state with loaded filters
const pageState = getCurrentPageState(); const pageState = getCurrentPageState();
pageState.filters = { ...this.filters }; pageState.filters = this.cloneFilters();
this.updateTagSelections(); this.updateTagSelections();
this.updateActiveFiltersCount(); this.updateActiveFiltersCount();
@@ -447,8 +435,109 @@ export class FilterManager {
} }
hasActiveFilters() { hasActiveFilters() {
return this.filters.baseModel.length > 0 || const tagCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
this.filters.tags.length > 0 || const licenseCount = this.filters.license ? Object.keys(this.filters.license).length : 0;
(this.filters.license && Object.keys(this.filters.license).length > 0); return this.filters.baseModel.length > 0 || tagCount > 0 || licenseCount > 0;
}
initializeFilters(existingFilters = {}) {
const source = existingFilters || {};
return {
...source,
baseModel: Array.isArray(source.baseModel) ? [...source.baseModel] : [],
tags: this.normalizeTagFilters(source.tags),
license: this.normalizeLicenseFilters(source.license)
};
}
normalizeTagFilters(tagFilters) {
if (!tagFilters) {
return {};
}
if (Array.isArray(tagFilters)) {
return tagFilters.reduce((acc, tag) => {
if (typeof tag === 'string' && tag.trim().length > 0) {
acc[tag] = 'include';
}
return acc;
}, {});
}
if (typeof tagFilters === 'object') {
const normalized = {};
Object.entries(tagFilters).forEach(([tag, state]) => {
if (!tag) {
return;
}
const normalizedState = typeof state === 'string' ? state.toLowerCase() : '';
if (normalizedState === 'include' || normalizedState === 'exclude') {
normalized[tag] = normalizedState;
}
});
return normalized;
}
return {};
}
normalizeLicenseFilters(licenseFilters) {
if (!licenseFilters || typeof licenseFilters !== 'object') {
return {};
}
const normalized = {};
Object.entries(licenseFilters).forEach(([key, state]) => {
const normalizedState = typeof state === 'string' ? state.toLowerCase() : '';
if (normalizedState === 'include' || normalizedState === 'exclude') {
normalized[key] = normalizedState;
}
});
return normalized;
}
cloneFilters() {
return {
...this.filters,
baseModel: [...(this.filters.baseModel || [])],
tags: { ...(this.filters.tags || {}) },
license: { ...(this.filters.license || {}) }
};
}
getNextTriStateState(currentState) {
switch (currentState) {
case 'none':
return 'include';
case 'include':
return 'exclude';
default:
return 'none';
}
}
setTagFilterState(tagName, state) {
if (!this.filters.tags) {
this.filters.tags = {};
}
if (state === 'none') {
delete this.filters.tags[tagName];
} else {
this.filters.tags[tagName] = state;
}
}
applyTagElementState(element, state) {
if (!element) {
return;
}
element.classList.remove('active', 'exclude');
if (state === 'include') {
element.classList.add('active');
} else if (state === 'exclude') {
element.classList.add('exclude');
}
} }
} }

View File

@@ -66,18 +66,19 @@ export const state = {
activeFolder: getStorageItem(`${MODEL_TYPES.LORA}_activeFolder`), activeFolder: getStorageItem(`${MODEL_TYPES.LORA}_activeFolder`),
activeLetterFilter: null, activeLetterFilter: null,
previewVersions: loraPreviewVersions, previewVersions: loraPreviewVersions,
searchManager: null, searchManager: null,
searchOptions: { searchOptions: {
filename: true, filename: true,
modelname: true, modelname: true,
tags: false, tags: false,
creator: false, creator: false,
recursive: getStorageItem(`${MODEL_TYPES.LORA}_recursiveSearch`, true), recursive: getStorageItem(`${MODEL_TYPES.LORA}_recursiveSearch`, true),
}, },
filters: { filters: {
baseModel: [], baseModel: [],
tags: [] tags: {},
}, license: {}
},
bulkMode: false, bulkMode: false,
selectedLoras: new Set(), selectedLoras: new Set(),
loraMetadataCache: new Map(), loraMetadataCache: new Map(),
@@ -91,18 +92,19 @@ export const state = {
isLoading: false, isLoading: false,
hasMore: true, hasMore: true,
sortBy: 'date', sortBy: 'date',
searchManager: null, searchManager: null,
searchOptions: { searchOptions: {
title: true, title: true,
tags: true, tags: true,
loraName: true, loraName: true,
loraModel: true loraModel: true
}, },
filters: { filters: {
baseModel: [], baseModel: [],
tags: [], tags: {},
search: '' license: {},
}, search: ''
},
pageSize: 20, pageSize: 20,
showFavoritesOnly: false, showFavoritesOnly: false,
duplicatesMode: false, duplicatesMode: false,
@@ -117,17 +119,18 @@ export const state = {
sortBy: 'name', sortBy: 'name',
activeFolder: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_activeFolder`), activeFolder: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_activeFolder`),
previewVersions: checkpointPreviewVersions, previewVersions: checkpointPreviewVersions,
searchManager: null, searchManager: null,
searchOptions: { searchOptions: {
filename: true, filename: true,
modelname: true, modelname: true,
creator: false, creator: false,
recursive: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_recursiveSearch`, true), recursive: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_recursiveSearch`, true),
}, },
filters: { filters: {
baseModel: [], baseModel: [],
tags: [] tags: {},
}, license: {}
},
modelType: 'checkpoint', // 'checkpoint' or 'diffusion_model' modelType: 'checkpoint', // 'checkpoint' or 'diffusion_model'
bulkMode: false, bulkMode: false,
selectedModels: new Set(), selectedModels: new Set(),
@@ -145,18 +148,19 @@ export const state = {
activeFolder: getStorageItem(`${MODEL_TYPES.EMBEDDING}_activeFolder`), activeFolder: getStorageItem(`${MODEL_TYPES.EMBEDDING}_activeFolder`),
activeLetterFilter: null, activeLetterFilter: null,
previewVersions: embeddingPreviewVersions, previewVersions: embeddingPreviewVersions,
searchManager: null, searchManager: null,
searchOptions: { searchOptions: {
filename: true, filename: true,
modelname: true, modelname: true,
tags: false, tags: false,
creator: false, creator: false,
recursive: getStorageItem(`${MODEL_TYPES.EMBEDDING}_recursiveSearch`, true), recursive: getStorageItem(`${MODEL_TYPES.EMBEDDING}_recursiveSearch`, true),
}, },
filters: { filters: {
baseModel: [], baseModel: [],
tags: [] tags: {},
}, license: {}
},
bulkMode: false, bulkMode: false,
selectedModels: new Set(), selectedModels: new Set(),
metadataCache: new Map(), metadataCache: new Map(),

View File

@@ -225,21 +225,31 @@ describe('FilterManager tag and base model filters', () => {
tagChip.dispatchEvent(new Event('click', { bubbles: true })); tagChip.dispatchEvent(new Event('click', { bubbles: true }));
await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1)); await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1));
expect(getCurrentPageState().filters.tags).toEqual(['style']); expect(getCurrentPageState().filters.tags).toEqual({ style: 'include' });
expect(tagChip.classList.contains('active')).toBe(true); expect(tagChip.classList.contains('active')).toBe(true);
expect(document.getElementById('activeFiltersCount').textContent).toBe('1'); expect(document.getElementById('activeFiltersCount').textContent).toBe('1');
expect(document.getElementById('activeFiltersCount').style.display).toBe('inline-flex'); expect(document.getElementById('activeFiltersCount').style.display).toBe('inline-flex');
const storageKey = `lora_manager_${pageKey}_filters`; const storageKey = `lora_manager_${pageKey}_filters`;
const storedFilters = JSON.parse(localStorage.getItem(storageKey)); const storedFilters = JSON.parse(localStorage.getItem(storageKey));
expect(storedFilters.tags).toEqual(['style']); expect(storedFilters.tags).toEqual({ style: 'include' });
loadMoreWithVirtualScrollMock.mockClear(); loadMoreWithVirtualScrollMock.mockClear();
tagChip.dispatchEvent(new Event('click', { bubbles: true })); tagChip.dispatchEvent(new Event('click', { bubbles: true }));
await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1)); await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1));
expect(getCurrentPageState().filters.tags).toEqual([]); expect(getCurrentPageState().filters.tags).toEqual({ style: 'exclude' });
expect(tagChip.classList.contains('exclude')).toBe(true);
expect(tagChip.classList.contains('active')).toBe(false);
expect(document.getElementById('activeFiltersCount').textContent).toBe('1');
loadMoreWithVirtualScrollMock.mockClear();
tagChip.dispatchEvent(new Event('click', { bubbles: true }));
await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1));
expect(getCurrentPageState().filters.tags).toEqual({});
expect(document.getElementById('activeFiltersCount').style.display).toBe('none'); expect(document.getElementById('activeFiltersCount').style.display).toBe('none');
}); });

View File

@@ -5,6 +5,7 @@ from py.services.lora_service import LoraService
from py.services.checkpoint_service import CheckpointService from py.services.checkpoint_service import CheckpointService
from py.services.embedding_service import EmbeddingService from py.services.embedding_service import EmbeddingService
from py.services.model_query import ( from py.services.model_query import (
FilterCriteria,
ModelCacheRepository, ModelCacheRepository,
ModelFilterSet, ModelFilterSet,
SearchStrategy, SearchStrategy,
@@ -126,7 +127,7 @@ async def test_get_paginated_data_uses_injected_collaborators():
search="query", search="query",
fuzzy_search=True, fuzzy_search=True,
base_models=["base"], base_models=["base"],
tags=["tag"], tags={"tag": "include"},
search_options={"recursive": False}, search_options={"recursive": False},
favorites_only=True, favorites_only=True,
) )
@@ -141,7 +142,7 @@ async def test_get_paginated_data_uses_injected_collaborators():
assert call_data == data assert call_data == data
assert criteria.folder == "root" assert criteria.folder == "root"
assert criteria.base_models == ["base"] assert criteria.base_models == ["base"]
assert criteria.tags == ["tag"] assert criteria.tags == {"tag": "include"}
assert criteria.favorites_only is True assert criteria.favorites_only is True
assert criteria.search_options.get("recursive") is False assert criteria.search_options.get("recursive") is False
@@ -234,7 +235,7 @@ async def test_get_paginated_data_filters_and_searches_combination():
folder="root", folder="root",
search="artist", search="artist",
base_models=["v1"], base_models=["v1"],
tags=["tag1"], tags={"tag1": "include"},
search_options={"creator": True, "tags": True}, search_options={"creator": True, "tags": True},
favorites_only=True, favorites_only=True,
) )
@@ -533,6 +534,36 @@ async def test_get_paginated_data_update_available_only_without_update_service()
assert response["total_pages"] == 0 assert response["total_pages"] == 0
def test_model_filter_set_handles_include_and_exclude_tag_filters():
settings = StubSettings({})
filter_set = ModelFilterSet(settings)
data = [
{"model_name": "StyleOnly", "tags": ["style"]},
{"model_name": "StyleAnime", "tags": ["style", "anime"]},
{"model_name": "AnimeOnly", "tags": ["anime"]},
]
criteria = FilterCriteria(tags={"style": "include", "anime": "exclude"})
result = filter_set.apply(data, criteria)
assert [item["model_name"] for item in result] == ["StyleOnly"]
def test_model_filter_set_supports_legacy_tag_arrays():
settings = StubSettings({})
filter_set = ModelFilterSet(settings)
data = [
{"model_name": "StyleOnly", "tags": ["style"]},
{"model_name": "StyleAnime", "tags": ["style", "anime"]},
{"model_name": "AnimeOnly", "tags": ["anime"]},
]
criteria = FilterCriteria(tags=["style"])
result = filter_set.apply(data, criteria)
assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"]
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"service_cls, extra_fields", "service_cls, extra_fields",