Merge pull request #651 from willmiao/tag-filtering-with-include-exclude-states, see #622

feat: implement tag filtering with include/exclude states
This commit is contained in:
pixelpaws
2025-11-08 12:01:13 +08:00
committed by GitHub
12 changed files with 338 additions and 116 deletions

View File

@@ -21,7 +21,7 @@ This matrix captures the scenarios that Phase 3 frontend tests should cover for
| ID | Feature | Scenario | LoRAs Expectations | Checkpoints Expectations | Notes |
| --- | --- | --- | --- | --- | --- |
| F-01 | Search filter | Typing a query updates `pageState.filters.search`, persists to session, and triggers `resetAndReload` on submit | Validate `SearchManager` writes query and reloads via API stub; confirm LoRA cards pass query downstream | Same as LoRAs | Cover `enter` press and clicking search icon |
| F-02 | Tag filter | Selecting a tag chip adds it to filters, applies active styling, and reloads results | Tag stored under `filters.tags`; `FilterManager.applyFilters` persists and triggers `resetAndReload(true)` | Same; ensure base model tag set is scoped to checkpoints dataset | Include removal path |
| F-02 | Tag filter | Selecting a tag chip cycles include ➜ exclude ➜ clear, updates storage, and reloads results | Tag state stored under `filters.tags[tagName] = 'include'|'exclude'`; `FilterManager.applyFilters` persists and triggers `resetAndReload(true)` | Same; ensure base model tag set is scoped to checkpoints dataset | Include removal path |
| F-03 | Base model filter | Toggling base model checkboxes updates `filters.baseModel`, persists, and reloads | Ensure only LoRA-supported models show; toggle multi-select | Ensure SDXL/Flux base models appear as expected | Capture UI state restored from storage on next init |
| F-04 | Favorites-only | Clicking favorites toggle updates session flag and calls `resetAndReload(true)` | Button gains `.active` class and API called | Same | Verify duplicates badge refresh when active |
| F-05 | Sort selection | Changing sort select saves preference (legacy + new format) and reloads | Confirm `PageControls.saveSortPreference` invoked with option and API called | Same with checkpoints-specific defaults | Cover `convertLegacySortFormat` branch |

View File

@@ -144,7 +144,28 @@ class ModelListingHandler:
fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true"
base_models = request.query.getall("base_model", [])
tags = request.query.getall("tag", [])
# Support legacy ?tag=foo plus new ?tag_include/foo & ?tag_exclude parameters
legacy_tags = request.query.getall("tag", [])
if not legacy_tags:
legacy_csv = request.query.get("tags")
if legacy_csv:
legacy_tags = [tag.strip() for tag in legacy_csv.split(",") if tag.strip()]
include_tags = request.query.getall("tag_include", [])
exclude_tags = request.query.getall("tag_exclude", [])
tag_filters: Dict[str, str] = {}
for tag in legacy_tags:
if tag:
tag_filters[tag] = "include"
for tag in include_tags:
if tag:
tag_filters[tag] = "include"
for tag in exclude_tags:
if tag:
tag_filters[tag] = "exclude"
favorites_only = request.query.get("favorites_only", "false").lower() == "true"
search_options = {
@@ -189,7 +210,7 @@ class ModelListingHandler:
"search": search,
"fuzzy_search": fuzzy_search,
"base_models": base_models,
"tags": tags,
"tags": tag_filters,
"search_options": search_options,
"hash_filters": hash_filters,
"favorites_only": favorites_only,

View File

@@ -152,14 +152,31 @@ class RecipeListingHandler:
"lora_model": request.query.get("search_lora_model", "true").lower() == "true",
}
filters: Dict[str, list[str]] = {}
filters: Dict[str, Any] = {}
base_models = request.query.get("base_models")
if base_models:
filters["base_model"] = base_models.split(",")
tags = request.query.get("tags")
if tags:
filters["tags"] = tags.split(",")
tag_filters: Dict[str, str] = {}
legacy_tags = request.query.get("tags")
if legacy_tags:
for tag in legacy_tags.split(","):
tag = tag.strip()
if tag:
tag_filters[tag] = "include"
include_tags = request.query.getall("tag_include", [])
for tag in include_tags:
if tag:
tag_filters[tag] = "include"
exclude_tags = request.query.getall("tag_exclude", [])
for tag in exclude_tags:
if tag:
tag_filters[tag] = "exclude"
if tag_filters:
filters["tags"] = tag_filters
lora_hash = request.query.get("lora_hash")

View File

@@ -59,7 +59,7 @@ class BaseModelService(ABC):
search: str = None,
fuzzy_search: bool = False,
base_models: list = None,
tags: list = None,
tags: Optional[Dict[str, str]] = None,
search_options: dict = None,
hash_filters: dict = None,
favorites_only: bool = False,
@@ -149,7 +149,7 @@ class BaseModelService(ABC):
data: List[Dict],
folder: str = None,
base_models: list = None,
tags: list = None,
tags: Optional[Dict[str, str]] = None,
favorites_only: bool = False,
search_options: dict = None,
) -> List[Dict]:

View File

@@ -28,7 +28,7 @@ class FilterCriteria:
folder: Optional[str] = None
base_models: Optional[Sequence[str]] = None
tags: Optional[Sequence[str]] = None
tags: Optional[Dict[str, str]] = None
favorites_only: bool = False
search_options: Optional[Dict[str, Any]] = None
@@ -108,12 +108,30 @@ class ModelFilterSet:
base_model_set = set(base_models)
items = [item for item in items if item.get("base_model") in base_model_set]
tags = criteria.tags or []
if tags:
tag_set = set(tags)
tag_filters = criteria.tags or {}
include_tags = set()
exclude_tags = set()
if isinstance(tag_filters, dict):
for tag, state in tag_filters.items():
if not tag:
continue
if state == "exclude":
exclude_tags.add(tag)
else:
include_tags.add(tag)
else:
include_tags = {tag for tag in tag_filters if tag}
if include_tags:
items = [
item for item in items
if any(tag in tag_set for tag in item.get("tags", []))
if any(tag in include_tags for tag in (item.get("tags", []) or []))
]
if exclude_tags:
items = [
item for item in items
if not any(tag in exclude_tags for tag in (item.get("tags", []) or []))
]
return items

View File

@@ -729,10 +729,32 @@ class RecipeScanner:
# Filter by tags
if 'tags' in filters and filters['tags']:
filtered_data = [
item for item in filtered_data
if any(tag in item.get('tags', []) for tag in filters['tags'])
]
tag_spec = filters['tags']
include_tags = set()
exclude_tags = set()
if isinstance(tag_spec, dict):
for tag, state in tag_spec.items():
if not tag:
continue
if state == 'exclude':
exclude_tags.add(tag)
else:
include_tags.add(tag)
else:
include_tags = {tag for tag in tag_spec if tag}
if include_tags:
filtered_data = [
item for item in filtered_data
if any(tag in include_tags for tag in (item.get('tags', []) or []))
]
if exclude_tags:
filtered_data = [
item for item in filtered_data
if not any(tag in exclude_tags for tag in (item.get('tags', []) or []))
]
# Calculate pagination
total_items = len(filtered_data)

View File

@@ -806,9 +806,13 @@ export class BaseModelApiClient {
params.append('recursive', pageState.searchOptions.recursive ? 'true' : 'false');
if (pageState.filters) {
if (pageState.filters.tags && pageState.filters.tags.length > 0) {
pageState.filters.tags.forEach(tag => {
params.append('tag', tag);
if (pageState.filters.tags && Object.keys(pageState.filters.tags).length > 0) {
Object.entries(pageState.filters.tags).forEach(([tag, state]) => {
if (state === 'include') {
params.append('tag_include', tag);
} else if (state === 'exclude') {
params.append('tag_exclude', tag);
}
});
}

View File

@@ -66,8 +66,14 @@ export async function fetchRecipesPage(page = 1, pageSize = 100) {
}
// Add tag filters
if (pageState.filters?.tags && pageState.filters.tags.length) {
params.append('tags', pageState.filters.tags.join(','));
if (pageState.filters?.tags && Object.keys(pageState.filters.tags).length) {
Object.entries(pageState.filters.tags).forEach(([tag, state]) => {
if (state === 'include') {
params.append('tag_include', tag);
} else if (state === 'exclude') {
params.append('tag_exclude', tag);
}
});
}
}

View File

@@ -12,11 +12,7 @@ export class FilterManager {
this.currentPage = options.page || document.body.dataset.page || 'loras';
const pageState = getCurrentPageState();
this.filters = pageState.filters || {
baseModel: [],
tags: [],
license: {}
};
this.filters = this.initializeFilters(pageState ? pageState.filters : undefined);
this.filterPanel = document.getElementById('filterPanel');
this.filterButton = document.getElementById('filterButton');
@@ -28,6 +24,7 @@ export class FilterManager {
// Store this instance in the state
if (pageState) {
pageState.filterManager = this;
pageState.filters = this.cloneFilters();
}
}
@@ -111,17 +108,12 @@ export class FilterManager {
tagEl.dataset.tag = tagName;
tagEl.innerHTML = `${tagName} <span class="tag-count">${tag.count}</span>`;
// Add click handler to toggle selection and automatically apply
// Add click handler to cycle through tri-state filter and automatically apply
tagEl.addEventListener('click', async () => {
tagEl.classList.toggle('active');
if (tagEl.classList.contains('active')) {
if (!this.filters.tags.includes(tagName)) {
this.filters.tags.push(tagName);
}
} else {
this.filters.tags = this.filters.tags.filter(t => t !== tagName);
}
const currentState = (this.filters.tags && this.filters.tags[tagName]) || 'none';
const newState = this.getNextTriStateState(currentState);
this.setTagFilterState(tagName, newState);
this.applyTagElementState(tagEl, newState);
this.updateActiveFiltersCount();
@@ -129,6 +121,7 @@ export class FilterManager {
await this.applyFilters(false);
});
this.applyTagElementState(tagEl, (this.filters.tags && this.filters.tags[tagName]) || 'none');
tagsContainer.appendChild(tagEl);
});
}
@@ -310,11 +303,8 @@ export class FilterManager {
const modelTags = document.querySelectorAll('.tag-filter');
modelTags.forEach(tag => {
const tagName = tag.dataset.tag;
if (this.filters.tags.includes(tagName)) {
tag.classList.add('active');
} else {
tag.classList.remove('active');
}
const state = (this.filters.tags && this.filters.tags[tagName]) || 'none';
this.applyTagElementState(tag, state);
});
// Update license tags
@@ -322,9 +312,9 @@ export class FilterManager {
}
updateActiveFiltersCount() {
const totalActiveFilters = this.filters.baseModel.length +
this.filters.tags.length +
(this.filters.license ? Object.keys(this.filters.license).length : 0);
const tagFilterCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
const licenseFilterCount = this.filters.license ? Object.keys(this.filters.license).length : 0;
const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount;
if (this.activeFiltersCount) {
if (totalActiveFilters > 0) {
@@ -341,10 +331,11 @@ export class FilterManager {
const storageKey = `${this.currentPage}_filters`;
// Save filters to localStorage
setStorageItem(storageKey, this.filters);
const filtersSnapshot = this.cloneFilters();
setStorageItem(storageKey, filtersSnapshot);
// Update state with current filters
pageState.filters = { ...this.filters };
pageState.filters = filtersSnapshot;
// Call the appropriate manager's load method based on page type
if (this.currentPage === 'recipes' && window.recipeManager) {
@@ -359,7 +350,7 @@ export class FilterManager {
this.filterButton.classList.add('active');
if (showToastNotification) {
const baseModelCount = this.filters.baseModel.length;
const tagsCount = this.filters.tags.length;
const tagsCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
let message = '';
if (baseModelCount > 0 && tagsCount > 0) {
@@ -382,15 +373,16 @@ export class FilterManager {
async clearFilters() {
// Clear all filters
this.filters = {
this.filters = this.initializeFilters({
...this.filters,
baseModel: [],
tags: [],
license: {} // Initialize with empty object instead of deleting
};
tags: {},
license: {}
});
// Update state
const pageState = getCurrentPageState();
pageState.filters = { ...this.filters };
pageState.filters = this.cloneFilters();
// Update UI
this.updateTagSelections();
@@ -424,15 +416,11 @@ export class FilterManager {
if (savedFilters) {
try {
// Ensure backward compatibility with older filter format
this.filters = {
baseModel: savedFilters.baseModel || [],
tags: savedFilters.tags || [],
license: savedFilters.license || {}
};
this.filters = this.initializeFilters(savedFilters);
// Update state with loaded filters
const pageState = getCurrentPageState();
pageState.filters = { ...this.filters };
pageState.filters = this.cloneFilters();
this.updateTagSelections();
this.updateActiveFiltersCount();
@@ -447,8 +435,109 @@ export class FilterManager {
}
hasActiveFilters() {
return this.filters.baseModel.length > 0 ||
this.filters.tags.length > 0 ||
(this.filters.license && Object.keys(this.filters.license).length > 0);
const tagCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
const licenseCount = this.filters.license ? Object.keys(this.filters.license).length : 0;
return this.filters.baseModel.length > 0 || tagCount > 0 || licenseCount > 0;
}
initializeFilters(existingFilters = {}) {
const source = existingFilters || {};
return {
...source,
baseModel: Array.isArray(source.baseModel) ? [...source.baseModel] : [],
tags: this.normalizeTagFilters(source.tags),
license: this.normalizeLicenseFilters(source.license)
};
}
normalizeTagFilters(tagFilters) {
if (!tagFilters) {
return {};
}
if (Array.isArray(tagFilters)) {
return tagFilters.reduce((acc, tag) => {
if (typeof tag === 'string' && tag.trim().length > 0) {
acc[tag] = 'include';
}
return acc;
}, {});
}
if (typeof tagFilters === 'object') {
const normalized = {};
Object.entries(tagFilters).forEach(([tag, state]) => {
if (!tag) {
return;
}
const normalizedState = typeof state === 'string' ? state.toLowerCase() : '';
if (normalizedState === 'include' || normalizedState === 'exclude') {
normalized[tag] = normalizedState;
}
});
return normalized;
}
return {};
}
normalizeLicenseFilters(licenseFilters) {
if (!licenseFilters || typeof licenseFilters !== 'object') {
return {};
}
const normalized = {};
Object.entries(licenseFilters).forEach(([key, state]) => {
const normalizedState = typeof state === 'string' ? state.toLowerCase() : '';
if (normalizedState === 'include' || normalizedState === 'exclude') {
normalized[key] = normalizedState;
}
});
return normalized;
}
cloneFilters() {
return {
...this.filters,
baseModel: [...(this.filters.baseModel || [])],
tags: { ...(this.filters.tags || {}) },
license: { ...(this.filters.license || {}) }
};
}
getNextTriStateState(currentState) {
switch (currentState) {
case 'none':
return 'include';
case 'include':
return 'exclude';
default:
return 'none';
}
}
setTagFilterState(tagName, state) {
if (!this.filters.tags) {
this.filters.tags = {};
}
if (state === 'none') {
delete this.filters.tags[tagName];
} else {
this.filters.tags[tagName] = state;
}
}
applyTagElementState(element, state) {
if (!element) {
return;
}
element.classList.remove('active', 'exclude');
if (state === 'include') {
element.classList.add('active');
} else if (state === 'exclude') {
element.classList.add('exclude');
}
}
}

View File

@@ -66,18 +66,19 @@ export const state = {
activeFolder: getStorageItem(`${MODEL_TYPES.LORA}_activeFolder`),
activeLetterFilter: null,
previewVersions: loraPreviewVersions,
searchManager: null,
searchOptions: {
filename: true,
modelname: true,
tags: false,
creator: false,
recursive: getStorageItem(`${MODEL_TYPES.LORA}_recursiveSearch`, true),
},
filters: {
baseModel: [],
tags: []
},
searchManager: null,
searchOptions: {
filename: true,
modelname: true,
tags: false,
creator: false,
recursive: getStorageItem(`${MODEL_TYPES.LORA}_recursiveSearch`, true),
},
filters: {
baseModel: [],
tags: {},
license: {}
},
bulkMode: false,
selectedLoras: new Set(),
loraMetadataCache: new Map(),
@@ -91,18 +92,19 @@ export const state = {
isLoading: false,
hasMore: true,
sortBy: 'date',
searchManager: null,
searchOptions: {
title: true,
tags: true,
loraName: true,
loraModel: true
},
filters: {
baseModel: [],
tags: [],
search: ''
},
searchManager: null,
searchOptions: {
title: true,
tags: true,
loraName: true,
loraModel: true
},
filters: {
baseModel: [],
tags: {},
license: {},
search: ''
},
pageSize: 20,
showFavoritesOnly: false,
duplicatesMode: false,
@@ -117,17 +119,18 @@ export const state = {
sortBy: 'name',
activeFolder: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_activeFolder`),
previewVersions: checkpointPreviewVersions,
searchManager: null,
searchOptions: {
filename: true,
modelname: true,
creator: false,
recursive: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_recursiveSearch`, true),
},
filters: {
baseModel: [],
tags: []
},
searchManager: null,
searchOptions: {
filename: true,
modelname: true,
creator: false,
recursive: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_recursiveSearch`, true),
},
filters: {
baseModel: [],
tags: {},
license: {}
},
modelType: 'checkpoint', // 'checkpoint' or 'diffusion_model'
bulkMode: false,
selectedModels: new Set(),
@@ -145,18 +148,19 @@ export const state = {
activeFolder: getStorageItem(`${MODEL_TYPES.EMBEDDING}_activeFolder`),
activeLetterFilter: null,
previewVersions: embeddingPreviewVersions,
searchManager: null,
searchOptions: {
filename: true,
modelname: true,
tags: false,
creator: false,
recursive: getStorageItem(`${MODEL_TYPES.EMBEDDING}_recursiveSearch`, true),
},
filters: {
baseModel: [],
tags: []
},
searchManager: null,
searchOptions: {
filename: true,
modelname: true,
tags: false,
creator: false,
recursive: getStorageItem(`${MODEL_TYPES.EMBEDDING}_recursiveSearch`, true),
},
filters: {
baseModel: [],
tags: {},
license: {}
},
bulkMode: false,
selectedModels: new Set(),
metadataCache: new Map(),

View File

@@ -225,21 +225,31 @@ describe('FilterManager tag and base model filters', () => {
tagChip.dispatchEvent(new Event('click', { bubbles: true }));
await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1));
expect(getCurrentPageState().filters.tags).toEqual(['style']);
expect(getCurrentPageState().filters.tags).toEqual({ style: 'include' });
expect(tagChip.classList.contains('active')).toBe(true);
expect(document.getElementById('activeFiltersCount').textContent).toBe('1');
expect(document.getElementById('activeFiltersCount').style.display).toBe('inline-flex');
const storageKey = `lora_manager_${pageKey}_filters`;
const storedFilters = JSON.parse(localStorage.getItem(storageKey));
expect(storedFilters.tags).toEqual(['style']);
expect(storedFilters.tags).toEqual({ style: 'include' });
loadMoreWithVirtualScrollMock.mockClear();
tagChip.dispatchEvent(new Event('click', { bubbles: true }));
await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1));
expect(getCurrentPageState().filters.tags).toEqual([]);
expect(getCurrentPageState().filters.tags).toEqual({ style: 'exclude' });
expect(tagChip.classList.contains('exclude')).toBe(true);
expect(tagChip.classList.contains('active')).toBe(false);
expect(document.getElementById('activeFiltersCount').textContent).toBe('1');
loadMoreWithVirtualScrollMock.mockClear();
tagChip.dispatchEvent(new Event('click', { bubbles: true }));
await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1));
expect(getCurrentPageState().filters.tags).toEqual({});
expect(document.getElementById('activeFiltersCount').style.display).toBe('none');
});

View File

@@ -5,6 +5,7 @@ from py.services.lora_service import LoraService
from py.services.checkpoint_service import CheckpointService
from py.services.embedding_service import EmbeddingService
from py.services.model_query import (
FilterCriteria,
ModelCacheRepository,
ModelFilterSet,
SearchStrategy,
@@ -126,7 +127,7 @@ async def test_get_paginated_data_uses_injected_collaborators():
search="query",
fuzzy_search=True,
base_models=["base"],
tags=["tag"],
tags={"tag": "include"},
search_options={"recursive": False},
favorites_only=True,
)
@@ -141,7 +142,7 @@ async def test_get_paginated_data_uses_injected_collaborators():
assert call_data == data
assert criteria.folder == "root"
assert criteria.base_models == ["base"]
assert criteria.tags == ["tag"]
assert criteria.tags == {"tag": "include"}
assert criteria.favorites_only is True
assert criteria.search_options.get("recursive") is False
@@ -234,7 +235,7 @@ async def test_get_paginated_data_filters_and_searches_combination():
folder="root",
search="artist",
base_models=["v1"],
tags=["tag1"],
tags={"tag1": "include"},
search_options={"creator": True, "tags": True},
favorites_only=True,
)
@@ -533,6 +534,36 @@ async def test_get_paginated_data_update_available_only_without_update_service()
assert response["total_pages"] == 0
def test_model_filter_set_handles_include_and_exclude_tag_filters():
settings = StubSettings({})
filter_set = ModelFilterSet(settings)
data = [
{"model_name": "StyleOnly", "tags": ["style"]},
{"model_name": "StyleAnime", "tags": ["style", "anime"]},
{"model_name": "AnimeOnly", "tags": ["anime"]},
]
criteria = FilterCriteria(tags={"style": "include", "anime": "exclude"})
result = filter_set.apply(data, criteria)
assert [item["model_name"] for item in result] == ["StyleOnly"]
def test_model_filter_set_supports_legacy_tag_arrays():
settings = StubSettings({})
filter_set = ModelFilterSet(settings)
data = [
{"model_name": "StyleOnly", "tags": ["style"]},
{"model_name": "StyleAnime", "tags": ["style", "anime"]},
{"model_name": "AnimeOnly", "tags": ["anime"]},
]
criteria = FilterCriteria(tags=["style"])
result = filter_set.apply(data, criteria)
assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"service_cls, extra_fields",