mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: implement tag filtering with include/exclude states
- Update frontend tag filter to cycle through include/exclude/clear states - Add backend support for tag_include and tag_exclude query parameters - Maintain backward compatibility with legacy tag parameter - Store tag states as dictionary with 'include'/'exclude' values - Update test matrix documentation to reflect new tag behavior The changes enable more granular tag filtering where users can now explicitly include or exclude specific tags, rather than just adding tags to a simple inclusion list. This provides better control over search results and improves the filtering user experience.
This commit is contained in:
@@ -21,7 +21,7 @@ This matrix captures the scenarios that Phase 3 frontend tests should cover for
|
||||
| ID | Feature | Scenario | LoRAs Expectations | Checkpoints Expectations | Notes |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| F-01 | Search filter | Typing a query updates `pageState.filters.search`, persists to session, and triggers `resetAndReload` on submit | Validate `SearchManager` writes query and reloads via API stub; confirm LoRA cards pass query downstream | Same as LoRAs | Cover `enter` press and clicking search icon |
|
||||
| F-02 | Tag filter | Selecting a tag chip adds it to filters, applies active styling, and reloads results | Tag stored under `filters.tags`; `FilterManager.applyFilters` persists and triggers `resetAndReload(true)` | Same; ensure base model tag set is scoped to checkpoints dataset | Include removal path |
|
||||
| F-02 | Tag filter | Selecting a tag chip cycles include ➜ exclude ➜ clear, updates storage, and reloads results | Tag state stored under `filters.tags[tagName] = 'include'|'exclude'`; `FilterManager.applyFilters` persists and triggers `resetAndReload(true)` | Same; ensure base model tag set is scoped to checkpoints dataset | Include removal path |
|
||||
| F-03 | Base model filter | Toggling base model checkboxes updates `filters.baseModel`, persists, and reloads | Ensure only LoRA-supported models show; toggle multi-select | Ensure SDXL/Flux base models appear as expected | Capture UI state restored from storage on next init |
|
||||
| F-04 | Favorites-only | Clicking favorites toggle updates session flag and calls `resetAndReload(true)` | Button gains `.active` class and API called | Same | Verify duplicates badge refresh when active |
|
||||
| F-05 | Sort selection | Changing sort select saves preference (legacy + new format) and reloads | Confirm `PageControls.saveSortPreference` invoked with option and API called | Same with checkpoints-specific defaults | Cover `convertLegacySortFormat` branch |
|
||||
|
||||
@@ -144,7 +144,28 @@ class ModelListingHandler:
|
||||
fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true"
|
||||
|
||||
base_models = request.query.getall("base_model", [])
|
||||
tags = request.query.getall("tag", [])
|
||||
# Support legacy ?tag=foo plus new ?tag_include/foo & ?tag_exclude parameters
|
||||
legacy_tags = request.query.getall("tag", [])
|
||||
if not legacy_tags:
|
||||
legacy_csv = request.query.get("tags")
|
||||
if legacy_csv:
|
||||
legacy_tags = [tag.strip() for tag in legacy_csv.split(",") if tag.strip()]
|
||||
|
||||
include_tags = request.query.getall("tag_include", [])
|
||||
exclude_tags = request.query.getall("tag_exclude", [])
|
||||
|
||||
tag_filters: Dict[str, str] = {}
|
||||
for tag in legacy_tags:
|
||||
if tag:
|
||||
tag_filters[tag] = "include"
|
||||
|
||||
for tag in include_tags:
|
||||
if tag:
|
||||
tag_filters[tag] = "include"
|
||||
|
||||
for tag in exclude_tags:
|
||||
if tag:
|
||||
tag_filters[tag] = "exclude"
|
||||
favorites_only = request.query.get("favorites_only", "false").lower() == "true"
|
||||
|
||||
search_options = {
|
||||
@@ -189,7 +210,7 @@ class ModelListingHandler:
|
||||
"search": search,
|
||||
"fuzzy_search": fuzzy_search,
|
||||
"base_models": base_models,
|
||||
"tags": tags,
|
||||
"tags": tag_filters,
|
||||
"search_options": search_options,
|
||||
"hash_filters": hash_filters,
|
||||
"favorites_only": favorites_only,
|
||||
|
||||
@@ -152,14 +152,31 @@ class RecipeListingHandler:
|
||||
"lora_model": request.query.get("search_lora_model", "true").lower() == "true",
|
||||
}
|
||||
|
||||
filters: Dict[str, list[str]] = {}
|
||||
filters: Dict[str, Any] = {}
|
||||
base_models = request.query.get("base_models")
|
||||
if base_models:
|
||||
filters["base_model"] = base_models.split(",")
|
||||
|
||||
tags = request.query.get("tags")
|
||||
if tags:
|
||||
filters["tags"] = tags.split(",")
|
||||
tag_filters: Dict[str, str] = {}
|
||||
legacy_tags = request.query.get("tags")
|
||||
if legacy_tags:
|
||||
for tag in legacy_tags.split(","):
|
||||
tag = tag.strip()
|
||||
if tag:
|
||||
tag_filters[tag] = "include"
|
||||
|
||||
include_tags = request.query.getall("tag_include", [])
|
||||
for tag in include_tags:
|
||||
if tag:
|
||||
tag_filters[tag] = "include"
|
||||
|
||||
exclude_tags = request.query.getall("tag_exclude", [])
|
||||
for tag in exclude_tags:
|
||||
if tag:
|
||||
tag_filters[tag] = "exclude"
|
||||
|
||||
if tag_filters:
|
||||
filters["tags"] = tag_filters
|
||||
|
||||
lora_hash = request.query.get("lora_hash")
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ class BaseModelService(ABC):
|
||||
search: str = None,
|
||||
fuzzy_search: bool = False,
|
||||
base_models: list = None,
|
||||
tags: list = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
search_options: dict = None,
|
||||
hash_filters: dict = None,
|
||||
favorites_only: bool = False,
|
||||
@@ -149,7 +149,7 @@ class BaseModelService(ABC):
|
||||
data: List[Dict],
|
||||
folder: str = None,
|
||||
base_models: list = None,
|
||||
tags: list = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
favorites_only: bool = False,
|
||||
search_options: dict = None,
|
||||
) -> List[Dict]:
|
||||
|
||||
@@ -28,7 +28,7 @@ class FilterCriteria:
|
||||
|
||||
folder: Optional[str] = None
|
||||
base_models: Optional[Sequence[str]] = None
|
||||
tags: Optional[Sequence[str]] = None
|
||||
tags: Optional[Dict[str, str]] = None
|
||||
favorites_only: bool = False
|
||||
search_options: Optional[Dict[str, Any]] = None
|
||||
|
||||
@@ -108,12 +108,30 @@ class ModelFilterSet:
|
||||
base_model_set = set(base_models)
|
||||
items = [item for item in items if item.get("base_model") in base_model_set]
|
||||
|
||||
tags = criteria.tags or []
|
||||
if tags:
|
||||
tag_set = set(tags)
|
||||
tag_filters = criteria.tags or {}
|
||||
include_tags = set()
|
||||
exclude_tags = set()
|
||||
if isinstance(tag_filters, dict):
|
||||
for tag, state in tag_filters.items():
|
||||
if not tag:
|
||||
continue
|
||||
if state == "exclude":
|
||||
exclude_tags.add(tag)
|
||||
else:
|
||||
include_tags.add(tag)
|
||||
else:
|
||||
include_tags = {tag for tag in tag_filters if tag}
|
||||
|
||||
if include_tags:
|
||||
items = [
|
||||
item for item in items
|
||||
if any(tag in tag_set for tag in item.get("tags", []))
|
||||
if any(tag in include_tags for tag in (item.get("tags", []) or []))
|
||||
]
|
||||
|
||||
if exclude_tags:
|
||||
items = [
|
||||
item for item in items
|
||||
if not any(tag in exclude_tags for tag in (item.get("tags", []) or []))
|
||||
]
|
||||
|
||||
return items
|
||||
|
||||
@@ -729,10 +729,32 @@ class RecipeScanner:
|
||||
|
||||
# Filter by tags
|
||||
if 'tags' in filters and filters['tags']:
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if any(tag in item.get('tags', []) for tag in filters['tags'])
|
||||
]
|
||||
tag_spec = filters['tags']
|
||||
include_tags = set()
|
||||
exclude_tags = set()
|
||||
|
||||
if isinstance(tag_spec, dict):
|
||||
for tag, state in tag_spec.items():
|
||||
if not tag:
|
||||
continue
|
||||
if state == 'exclude':
|
||||
exclude_tags.add(tag)
|
||||
else:
|
||||
include_tags.add(tag)
|
||||
else:
|
||||
include_tags = {tag for tag in tag_spec if tag}
|
||||
|
||||
if include_tags:
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if any(tag in include_tags for tag in (item.get('tags', []) or []))
|
||||
]
|
||||
|
||||
if exclude_tags:
|
||||
filtered_data = [
|
||||
item for item in filtered_data
|
||||
if not any(tag in exclude_tags for tag in (item.get('tags', []) or []))
|
||||
]
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(filtered_data)
|
||||
|
||||
@@ -806,9 +806,13 @@ export class BaseModelApiClient {
|
||||
params.append('recursive', pageState.searchOptions.recursive ? 'true' : 'false');
|
||||
|
||||
if (pageState.filters) {
|
||||
if (pageState.filters.tags && pageState.filters.tags.length > 0) {
|
||||
pageState.filters.tags.forEach(tag => {
|
||||
params.append('tag', tag);
|
||||
if (pageState.filters.tags && Object.keys(pageState.filters.tags).length > 0) {
|
||||
Object.entries(pageState.filters.tags).forEach(([tag, state]) => {
|
||||
if (state === 'include') {
|
||||
params.append('tag_include', tag);
|
||||
} else if (state === 'exclude') {
|
||||
params.append('tag_exclude', tag);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -66,8 +66,14 @@ export async function fetchRecipesPage(page = 1, pageSize = 100) {
|
||||
}
|
||||
|
||||
// Add tag filters
|
||||
if (pageState.filters?.tags && pageState.filters.tags.length) {
|
||||
params.append('tags', pageState.filters.tags.join(','));
|
||||
if (pageState.filters?.tags && Object.keys(pageState.filters.tags).length) {
|
||||
Object.entries(pageState.filters.tags).forEach(([tag, state]) => {
|
||||
if (state === 'include') {
|
||||
params.append('tag_include', tag);
|
||||
} else if (state === 'exclude') {
|
||||
params.append('tag_exclude', tag);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,11 +12,7 @@ export class FilterManager {
|
||||
this.currentPage = options.page || document.body.dataset.page || 'loras';
|
||||
const pageState = getCurrentPageState();
|
||||
|
||||
this.filters = pageState.filters || {
|
||||
baseModel: [],
|
||||
tags: [],
|
||||
license: {}
|
||||
};
|
||||
this.filters = this.initializeFilters(pageState ? pageState.filters : undefined);
|
||||
|
||||
this.filterPanel = document.getElementById('filterPanel');
|
||||
this.filterButton = document.getElementById('filterButton');
|
||||
@@ -28,6 +24,7 @@ export class FilterManager {
|
||||
// Store this instance in the state
|
||||
if (pageState) {
|
||||
pageState.filterManager = this;
|
||||
pageState.filters = this.cloneFilters();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,17 +108,12 @@ export class FilterManager {
|
||||
tagEl.dataset.tag = tagName;
|
||||
tagEl.innerHTML = `${tagName} <span class="tag-count">${tag.count}</span>`;
|
||||
|
||||
// Add click handler to toggle selection and automatically apply
|
||||
// Add click handler to cycle through tri-state filter and automatically apply
|
||||
tagEl.addEventListener('click', async () => {
|
||||
tagEl.classList.toggle('active');
|
||||
|
||||
if (tagEl.classList.contains('active')) {
|
||||
if (!this.filters.tags.includes(tagName)) {
|
||||
this.filters.tags.push(tagName);
|
||||
}
|
||||
} else {
|
||||
this.filters.tags = this.filters.tags.filter(t => t !== tagName);
|
||||
}
|
||||
const currentState = (this.filters.tags && this.filters.tags[tagName]) || 'none';
|
||||
const newState = this.getNextTriStateState(currentState);
|
||||
this.setTagFilterState(tagName, newState);
|
||||
this.applyTagElementState(tagEl, newState);
|
||||
|
||||
this.updateActiveFiltersCount();
|
||||
|
||||
@@ -129,6 +121,7 @@ export class FilterManager {
|
||||
await this.applyFilters(false);
|
||||
});
|
||||
|
||||
this.applyTagElementState(tagEl, (this.filters.tags && this.filters.tags[tagName]) || 'none');
|
||||
tagsContainer.appendChild(tagEl);
|
||||
});
|
||||
}
|
||||
@@ -310,11 +303,8 @@ export class FilterManager {
|
||||
const modelTags = document.querySelectorAll('.tag-filter');
|
||||
modelTags.forEach(tag => {
|
||||
const tagName = tag.dataset.tag;
|
||||
if (this.filters.tags.includes(tagName)) {
|
||||
tag.classList.add('active');
|
||||
} else {
|
||||
tag.classList.remove('active');
|
||||
}
|
||||
const state = (this.filters.tags && this.filters.tags[tagName]) || 'none';
|
||||
this.applyTagElementState(tag, state);
|
||||
});
|
||||
|
||||
// Update license tags
|
||||
@@ -322,9 +312,9 @@ export class FilterManager {
|
||||
}
|
||||
|
||||
updateActiveFiltersCount() {
|
||||
const totalActiveFilters = this.filters.baseModel.length +
|
||||
this.filters.tags.length +
|
||||
(this.filters.license ? Object.keys(this.filters.license).length : 0);
|
||||
const tagFilterCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
|
||||
const licenseFilterCount = this.filters.license ? Object.keys(this.filters.license).length : 0;
|
||||
const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount;
|
||||
|
||||
if (this.activeFiltersCount) {
|
||||
if (totalActiveFilters > 0) {
|
||||
@@ -341,10 +331,11 @@ export class FilterManager {
|
||||
const storageKey = `${this.currentPage}_filters`;
|
||||
|
||||
// Save filters to localStorage
|
||||
setStorageItem(storageKey, this.filters);
|
||||
const filtersSnapshot = this.cloneFilters();
|
||||
setStorageItem(storageKey, filtersSnapshot);
|
||||
|
||||
// Update state with current filters
|
||||
pageState.filters = { ...this.filters };
|
||||
pageState.filters = filtersSnapshot;
|
||||
|
||||
// Call the appropriate manager's load method based on page type
|
||||
if (this.currentPage === 'recipes' && window.recipeManager) {
|
||||
@@ -359,7 +350,7 @@ export class FilterManager {
|
||||
this.filterButton.classList.add('active');
|
||||
if (showToastNotification) {
|
||||
const baseModelCount = this.filters.baseModel.length;
|
||||
const tagsCount = this.filters.tags.length;
|
||||
const tagsCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
|
||||
|
||||
let message = '';
|
||||
if (baseModelCount > 0 && tagsCount > 0) {
|
||||
@@ -382,15 +373,16 @@ export class FilterManager {
|
||||
|
||||
async clearFilters() {
|
||||
// Clear all filters
|
||||
this.filters = {
|
||||
this.filters = this.initializeFilters({
|
||||
...this.filters,
|
||||
baseModel: [],
|
||||
tags: [],
|
||||
license: {} // Initialize with empty object instead of deleting
|
||||
};
|
||||
tags: {},
|
||||
license: {}
|
||||
});
|
||||
|
||||
// Update state
|
||||
const pageState = getCurrentPageState();
|
||||
pageState.filters = { ...this.filters };
|
||||
pageState.filters = this.cloneFilters();
|
||||
|
||||
// Update UI
|
||||
this.updateTagSelections();
|
||||
@@ -424,15 +416,11 @@ export class FilterManager {
|
||||
if (savedFilters) {
|
||||
try {
|
||||
// Ensure backward compatibility with older filter format
|
||||
this.filters = {
|
||||
baseModel: savedFilters.baseModel || [],
|
||||
tags: savedFilters.tags || [],
|
||||
license: savedFilters.license || {}
|
||||
};
|
||||
this.filters = this.initializeFilters(savedFilters);
|
||||
|
||||
// Update state with loaded filters
|
||||
const pageState = getCurrentPageState();
|
||||
pageState.filters = { ...this.filters };
|
||||
pageState.filters = this.cloneFilters();
|
||||
|
||||
this.updateTagSelections();
|
||||
this.updateActiveFiltersCount();
|
||||
@@ -447,8 +435,109 @@ export class FilterManager {
|
||||
}
|
||||
|
||||
hasActiveFilters() {
|
||||
return this.filters.baseModel.length > 0 ||
|
||||
this.filters.tags.length > 0 ||
|
||||
(this.filters.license && Object.keys(this.filters.license).length > 0);
|
||||
const tagCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
|
||||
const licenseCount = this.filters.license ? Object.keys(this.filters.license).length : 0;
|
||||
return this.filters.baseModel.length > 0 || tagCount > 0 || licenseCount > 0;
|
||||
}
|
||||
|
||||
initializeFilters(existingFilters = {}) {
|
||||
const source = existingFilters || {};
|
||||
return {
|
||||
...source,
|
||||
baseModel: Array.isArray(source.baseModel) ? [...source.baseModel] : [],
|
||||
tags: this.normalizeTagFilters(source.tags),
|
||||
license: this.normalizeLicenseFilters(source.license)
|
||||
};
|
||||
}
|
||||
|
||||
normalizeTagFilters(tagFilters) {
|
||||
if (!tagFilters) {
|
||||
return {};
|
||||
}
|
||||
|
||||
if (Array.isArray(tagFilters)) {
|
||||
return tagFilters.reduce((acc, tag) => {
|
||||
if (typeof tag === 'string' && tag.trim().length > 0) {
|
||||
acc[tag] = 'include';
|
||||
}
|
||||
return acc;
|
||||
}, {});
|
||||
}
|
||||
|
||||
if (typeof tagFilters === 'object') {
|
||||
const normalized = {};
|
||||
Object.entries(tagFilters).forEach(([tag, state]) => {
|
||||
if (!tag) {
|
||||
return;
|
||||
}
|
||||
const normalizedState = typeof state === 'string' ? state.toLowerCase() : '';
|
||||
if (normalizedState === 'include' || normalizedState === 'exclude') {
|
||||
normalized[tag] = normalizedState;
|
||||
}
|
||||
});
|
||||
return normalized;
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
normalizeLicenseFilters(licenseFilters) {
|
||||
if (!licenseFilters || typeof licenseFilters !== 'object') {
|
||||
return {};
|
||||
}
|
||||
|
||||
const normalized = {};
|
||||
Object.entries(licenseFilters).forEach(([key, state]) => {
|
||||
const normalizedState = typeof state === 'string' ? state.toLowerCase() : '';
|
||||
if (normalizedState === 'include' || normalizedState === 'exclude') {
|
||||
normalized[key] = normalizedState;
|
||||
}
|
||||
});
|
||||
return normalized;
|
||||
}
|
||||
|
||||
cloneFilters() {
|
||||
return {
|
||||
...this.filters,
|
||||
baseModel: [...(this.filters.baseModel || [])],
|
||||
tags: { ...(this.filters.tags || {}) },
|
||||
license: { ...(this.filters.license || {}) }
|
||||
};
|
||||
}
|
||||
|
||||
getNextTriStateState(currentState) {
|
||||
switch (currentState) {
|
||||
case 'none':
|
||||
return 'include';
|
||||
case 'include':
|
||||
return 'exclude';
|
||||
default:
|
||||
return 'none';
|
||||
}
|
||||
}
|
||||
|
||||
setTagFilterState(tagName, state) {
|
||||
if (!this.filters.tags) {
|
||||
this.filters.tags = {};
|
||||
}
|
||||
|
||||
if (state === 'none') {
|
||||
delete this.filters.tags[tagName];
|
||||
} else {
|
||||
this.filters.tags[tagName] = state;
|
||||
}
|
||||
}
|
||||
|
||||
applyTagElementState(element, state) {
|
||||
if (!element) {
|
||||
return;
|
||||
}
|
||||
|
||||
element.classList.remove('active', 'exclude');
|
||||
if (state === 'include') {
|
||||
element.classList.add('active');
|
||||
} else if (state === 'exclude') {
|
||||
element.classList.add('exclude');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,18 +66,19 @@ export const state = {
|
||||
activeFolder: getStorageItem(`${MODEL_TYPES.LORA}_activeFolder`),
|
||||
activeLetterFilter: null,
|
||||
previewVersions: loraPreviewVersions,
|
||||
searchManager: null,
|
||||
searchOptions: {
|
||||
filename: true,
|
||||
modelname: true,
|
||||
tags: false,
|
||||
creator: false,
|
||||
recursive: getStorageItem(`${MODEL_TYPES.LORA}_recursiveSearch`, true),
|
||||
},
|
||||
filters: {
|
||||
baseModel: [],
|
||||
tags: []
|
||||
},
|
||||
searchManager: null,
|
||||
searchOptions: {
|
||||
filename: true,
|
||||
modelname: true,
|
||||
tags: false,
|
||||
creator: false,
|
||||
recursive: getStorageItem(`${MODEL_TYPES.LORA}_recursiveSearch`, true),
|
||||
},
|
||||
filters: {
|
||||
baseModel: [],
|
||||
tags: {},
|
||||
license: {}
|
||||
},
|
||||
bulkMode: false,
|
||||
selectedLoras: new Set(),
|
||||
loraMetadataCache: new Map(),
|
||||
@@ -91,18 +92,19 @@ export const state = {
|
||||
isLoading: false,
|
||||
hasMore: true,
|
||||
sortBy: 'date',
|
||||
searchManager: null,
|
||||
searchOptions: {
|
||||
title: true,
|
||||
tags: true,
|
||||
loraName: true,
|
||||
loraModel: true
|
||||
},
|
||||
filters: {
|
||||
baseModel: [],
|
||||
tags: [],
|
||||
search: ''
|
||||
},
|
||||
searchManager: null,
|
||||
searchOptions: {
|
||||
title: true,
|
||||
tags: true,
|
||||
loraName: true,
|
||||
loraModel: true
|
||||
},
|
||||
filters: {
|
||||
baseModel: [],
|
||||
tags: {},
|
||||
license: {},
|
||||
search: ''
|
||||
},
|
||||
pageSize: 20,
|
||||
showFavoritesOnly: false,
|
||||
duplicatesMode: false,
|
||||
@@ -117,17 +119,18 @@ export const state = {
|
||||
sortBy: 'name',
|
||||
activeFolder: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_activeFolder`),
|
||||
previewVersions: checkpointPreviewVersions,
|
||||
searchManager: null,
|
||||
searchOptions: {
|
||||
filename: true,
|
||||
modelname: true,
|
||||
creator: false,
|
||||
recursive: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_recursiveSearch`, true),
|
||||
},
|
||||
filters: {
|
||||
baseModel: [],
|
||||
tags: []
|
||||
},
|
||||
searchManager: null,
|
||||
searchOptions: {
|
||||
filename: true,
|
||||
modelname: true,
|
||||
creator: false,
|
||||
recursive: getStorageItem(`${MODEL_TYPES.CHECKPOINT}_recursiveSearch`, true),
|
||||
},
|
||||
filters: {
|
||||
baseModel: [],
|
||||
tags: {},
|
||||
license: {}
|
||||
},
|
||||
modelType: 'checkpoint', // 'checkpoint' or 'diffusion_model'
|
||||
bulkMode: false,
|
||||
selectedModels: new Set(),
|
||||
@@ -145,18 +148,19 @@ export const state = {
|
||||
activeFolder: getStorageItem(`${MODEL_TYPES.EMBEDDING}_activeFolder`),
|
||||
activeLetterFilter: null,
|
||||
previewVersions: embeddingPreviewVersions,
|
||||
searchManager: null,
|
||||
searchOptions: {
|
||||
filename: true,
|
||||
modelname: true,
|
||||
tags: false,
|
||||
creator: false,
|
||||
recursive: getStorageItem(`${MODEL_TYPES.EMBEDDING}_recursiveSearch`, true),
|
||||
},
|
||||
filters: {
|
||||
baseModel: [],
|
||||
tags: []
|
||||
},
|
||||
searchManager: null,
|
||||
searchOptions: {
|
||||
filename: true,
|
||||
modelname: true,
|
||||
tags: false,
|
||||
creator: false,
|
||||
recursive: getStorageItem(`${MODEL_TYPES.EMBEDDING}_recursiveSearch`, true),
|
||||
},
|
||||
filters: {
|
||||
baseModel: [],
|
||||
tags: {},
|
||||
license: {}
|
||||
},
|
||||
bulkMode: false,
|
||||
selectedModels: new Set(),
|
||||
metadataCache: new Map(),
|
||||
|
||||
@@ -225,21 +225,31 @@ describe('FilterManager tag and base model filters', () => {
|
||||
tagChip.dispatchEvent(new Event('click', { bubbles: true }));
|
||||
await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1));
|
||||
|
||||
expect(getCurrentPageState().filters.tags).toEqual(['style']);
|
||||
expect(getCurrentPageState().filters.tags).toEqual({ style: 'include' });
|
||||
expect(tagChip.classList.contains('active')).toBe(true);
|
||||
expect(document.getElementById('activeFiltersCount').textContent).toBe('1');
|
||||
expect(document.getElementById('activeFiltersCount').style.display).toBe('inline-flex');
|
||||
|
||||
const storageKey = `lora_manager_${pageKey}_filters`;
|
||||
const storedFilters = JSON.parse(localStorage.getItem(storageKey));
|
||||
expect(storedFilters.tags).toEqual(['style']);
|
||||
expect(storedFilters.tags).toEqual({ style: 'include' });
|
||||
|
||||
loadMoreWithVirtualScrollMock.mockClear();
|
||||
|
||||
tagChip.dispatchEvent(new Event('click', { bubbles: true }));
|
||||
await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1));
|
||||
|
||||
expect(getCurrentPageState().filters.tags).toEqual([]);
|
||||
expect(getCurrentPageState().filters.tags).toEqual({ style: 'exclude' });
|
||||
expect(tagChip.classList.contains('exclude')).toBe(true);
|
||||
expect(tagChip.classList.contains('active')).toBe(false);
|
||||
expect(document.getElementById('activeFiltersCount').textContent).toBe('1');
|
||||
|
||||
loadMoreWithVirtualScrollMock.mockClear();
|
||||
|
||||
tagChip.dispatchEvent(new Event('click', { bubbles: true }));
|
||||
await vi.waitFor(() => expect(loadMoreWithVirtualScrollMock).toHaveBeenCalledTimes(1));
|
||||
|
||||
expect(getCurrentPageState().filters.tags).toEqual({});
|
||||
expect(document.getElementById('activeFiltersCount').style.display).toBe('none');
|
||||
});
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from py.services.lora_service import LoraService
|
||||
from py.services.checkpoint_service import CheckpointService
|
||||
from py.services.embedding_service import EmbeddingService
|
||||
from py.services.model_query import (
|
||||
FilterCriteria,
|
||||
ModelCacheRepository,
|
||||
ModelFilterSet,
|
||||
SearchStrategy,
|
||||
@@ -126,7 +127,7 @@ async def test_get_paginated_data_uses_injected_collaborators():
|
||||
search="query",
|
||||
fuzzy_search=True,
|
||||
base_models=["base"],
|
||||
tags=["tag"],
|
||||
tags={"tag": "include"},
|
||||
search_options={"recursive": False},
|
||||
favorites_only=True,
|
||||
)
|
||||
@@ -141,7 +142,7 @@ async def test_get_paginated_data_uses_injected_collaborators():
|
||||
assert call_data == data
|
||||
assert criteria.folder == "root"
|
||||
assert criteria.base_models == ["base"]
|
||||
assert criteria.tags == ["tag"]
|
||||
assert criteria.tags == {"tag": "include"}
|
||||
assert criteria.favorites_only is True
|
||||
assert criteria.search_options.get("recursive") is False
|
||||
|
||||
@@ -234,7 +235,7 @@ async def test_get_paginated_data_filters_and_searches_combination():
|
||||
folder="root",
|
||||
search="artist",
|
||||
base_models=["v1"],
|
||||
tags=["tag1"],
|
||||
tags={"tag1": "include"},
|
||||
search_options={"creator": True, "tags": True},
|
||||
favorites_only=True,
|
||||
)
|
||||
@@ -533,6 +534,36 @@ async def test_get_paginated_data_update_available_only_without_update_service()
|
||||
assert response["total_pages"] == 0
|
||||
|
||||
|
||||
def test_model_filter_set_handles_include_and_exclude_tag_filters():
|
||||
settings = StubSettings({})
|
||||
filter_set = ModelFilterSet(settings)
|
||||
data = [
|
||||
{"model_name": "StyleOnly", "tags": ["style"]},
|
||||
{"model_name": "StyleAnime", "tags": ["style", "anime"]},
|
||||
{"model_name": "AnimeOnly", "tags": ["anime"]},
|
||||
]
|
||||
|
||||
criteria = FilterCriteria(tags={"style": "include", "anime": "exclude"})
|
||||
result = filter_set.apply(data, criteria)
|
||||
|
||||
assert [item["model_name"] for item in result] == ["StyleOnly"]
|
||||
|
||||
|
||||
def test_model_filter_set_supports_legacy_tag_arrays():
|
||||
settings = StubSettings({})
|
||||
filter_set = ModelFilterSet(settings)
|
||||
data = [
|
||||
{"model_name": "StyleOnly", "tags": ["style"]},
|
||||
{"model_name": "StyleAnime", "tags": ["style", "anime"]},
|
||||
{"model_name": "AnimeOnly", "tags": ["anime"]},
|
||||
]
|
||||
|
||||
criteria = FilterCriteria(tags=["style"])
|
||||
result = filter_set.apply(data, criteria)
|
||||
|
||||
assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"service_cls, extra_fields",
|
||||
|
||||
Reference in New Issue
Block a user