feat(filters): add model type filter

This commit is contained in:
Will Miao
2025-11-18 16:43:44 +08:00
parent 57f369a6de
commit e9d55fe146
17 changed files with 179 additions and 17 deletions

View File

@@ -195,6 +195,7 @@
"title": "Modelle filtern", "title": "Modelle filtern",
"baseModel": "Basis-Modell", "baseModel": "Basis-Modell",
"modelTags": "Tags (Top 20)", "modelTags": "Tags (Top 20)",
"modelTypes": "Model Types",
"license": "Lizenz", "license": "Lizenz",
"noCreditRequired": "Kein Credit erforderlich", "noCreditRequired": "Kein Credit erforderlich",
"allowSellingGeneratedContent": "Verkauf erlaubt", "allowSellingGeneratedContent": "Verkauf erlaubt",

View File

@@ -195,6 +195,7 @@
"title": "Filter Models", "title": "Filter Models",
"baseModel": "Base Model", "baseModel": "Base Model",
"modelTags": "Tags (Top 20)", "modelTags": "Tags (Top 20)",
"modelTypes": "Model Types",
"license": "License", "license": "License",
"noCreditRequired": "No Credit Required", "noCreditRequired": "No Credit Required",
"allowSellingGeneratedContent": "Allow Selling", "allowSellingGeneratedContent": "Allow Selling",

View File

@@ -195,6 +195,7 @@
"title": "Filtrar modelos", "title": "Filtrar modelos",
"baseModel": "Modelo base", "baseModel": "Modelo base",
"modelTags": "Etiquetas (Top 20)", "modelTags": "Etiquetas (Top 20)",
"modelTypes": "Model Types",
"license": "Licencia", "license": "Licencia",
"noCreditRequired": "Sin crédito requerido", "noCreditRequired": "Sin crédito requerido",
"allowSellingGeneratedContent": "Venta permitida", "allowSellingGeneratedContent": "Venta permitida",

View File

@@ -195,6 +195,7 @@
"title": "Filtrer les modèles", "title": "Filtrer les modèles",
"baseModel": "Modèle de base", "baseModel": "Modèle de base",
"modelTags": "Tags (Top 20)", "modelTags": "Tags (Top 20)",
"modelTypes": "Model Types",
"license": "Licence", "license": "Licence",
"noCreditRequired": "Crédit non requis", "noCreditRequired": "Crédit non requis",
"allowSellingGeneratedContent": "Vente autorisée", "allowSellingGeneratedContent": "Vente autorisée",

View File

@@ -195,6 +195,7 @@
"title": "סנן מודלים", "title": "סנן מודלים",
"baseModel": "מודל בסיס", "baseModel": "מודל בסיס",
"modelTags": "תגיות (20 המובילות)", "modelTags": "תגיות (20 המובילות)",
"modelTypes": "Model Types",
"license": "רישיון", "license": "רישיון",
"noCreditRequired": "ללא קרדיט נדרש", "noCreditRequired": "ללא קרדיט נדרש",
"allowSellingGeneratedContent": "אפשר מכירה", "allowSellingGeneratedContent": "אפשר מכירה",

View File

@@ -195,6 +195,7 @@
"title": "モデルをフィルタ", "title": "モデルをフィルタ",
"baseModel": "ベースモデル", "baseModel": "ベースモデル",
"modelTags": "タグ上位20", "modelTags": "タグ上位20",
"modelTypes": "Model Types",
"license": "ライセンス", "license": "ライセンス",
"noCreditRequired": "クレジット不要", "noCreditRequired": "クレジット不要",
"allowSellingGeneratedContent": "販売許可", "allowSellingGeneratedContent": "販売許可",

View File

@@ -195,6 +195,7 @@
"title": "모델 필터", "title": "모델 필터",
"baseModel": "베이스 모델", "baseModel": "베이스 모델",
"modelTags": "태그 (상위 20개)", "modelTags": "태그 (상위 20개)",
"modelTypes": "Model Types",
"license": "라이선스", "license": "라이선스",
"noCreditRequired": "크레딧 표기 없음", "noCreditRequired": "크레딧 표기 없음",
"allowSellingGeneratedContent": "판매 허용", "allowSellingGeneratedContent": "판매 허용",

View File

@@ -195,6 +195,7 @@
"title": "Фильтр моделей", "title": "Фильтр моделей",
"baseModel": "Базовая модель", "baseModel": "Базовая модель",
"modelTags": "Теги (Топ 20)", "modelTags": "Теги (Топ 20)",
"modelTypes": "Model Types",
"license": "Лицензия", "license": "Лицензия",
"noCreditRequired": "Без указания авторства", "noCreditRequired": "Без указания авторства",
"allowSellingGeneratedContent": "Продажа разрешена", "allowSellingGeneratedContent": "Продажа разрешена",

View File

@@ -195,6 +195,7 @@
"title": "筛选模型", "title": "筛选模型",
"baseModel": "基础模型", "baseModel": "基础模型",
"modelTags": "标签前20", "modelTags": "标签前20",
"modelTypes": "Model Types",
"license": "许可证", "license": "许可证",
"noCreditRequired": "无需署名", "noCreditRequired": "无需署名",
"allowSellingGeneratedContent": "允许销售", "allowSellingGeneratedContent": "允许销售",

View File

@@ -195,6 +195,7 @@
"title": "篩選模型", "title": "篩選模型",
"baseModel": "基礎模型", "baseModel": "基礎模型",
"modelTags": "標籤(前 20", "modelTags": "標籤(前 20",
"modelTypes": "Model Types",
"license": "授權", "license": "授權",
"noCreditRequired": "無需署名", "noCreditRequired": "無需署名",
"allowSellingGeneratedContent": "允許銷售", "allowSellingGeneratedContent": "允許銷售",

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
import logging import logging
import os import os
from ..utils.constants import VALID_LORA_TYPES
from ..utils.models import BaseModelMetadata from ..utils.models import BaseModelMetadata
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .model_query import ( from .model_query import (
@@ -12,6 +13,7 @@ from .model_query import (
ModelFilterSet, ModelFilterSet,
SearchStrategy, SearchStrategy,
SettingsProvider, SettingsProvider,
normalize_civitai_model_type,
resolve_civitai_model_type, resolve_civitai_model_type,
) )
from .settings_manager import get_settings_manager from .settings_manager import get_settings_manager
@@ -469,12 +471,15 @@ class BaseModelService(ABC):
return await self.scanner.get_base_models(limit) return await self.scanner.get_base_models(limit)
async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]: async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]:
"""Get counts of CivitAI model types present in the cache.""" """Get counts of normalized CivitAI model types present in the cache."""
cache = await self.scanner.get_cached_data() cache = await self.scanner.get_cached_data()
type_counts: Dict[str, int] = {} type_counts: Dict[str, int] = {}
for entry in cache.raw_data: for entry in cache.raw_data:
model_type = resolve_civitai_model_type(entry) normalized_type = normalize_civitai_model_type(resolve_civitai_model_type(entry))
type_counts[model_type] = type_counts.get(model_type, 0) + 1 if not normalized_type or normalized_type not in VALID_LORA_TYPES:
continue
type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1
sorted_types = sorted( sorted_types = sorted(
[{"type": model_type, "count": count} for model_type, count in type_counts.items()], [{"type": model_type, "count": count} for model_type, count in type_counts.items()],

View File

@@ -848,6 +848,12 @@ export class BaseModelApiClient {
} }
} }
} }
if (pageState.filters.modelTypes && pageState.filters.modelTypes.length > 0) {
pageState.filters.modelTypes.forEach((type) => {
params.append('model_type', type);
});
}
} }
this._addModelSpecificParams(params, pageState); this._addModelSpecificParams(params, pageState);

View File

@@ -2,6 +2,7 @@ import { getCurrentPageState } from '../state/index.js';
import { showToast, updatePanelPositions } from '../utils/uiHelpers.js'; import { showToast, updatePanelPositions } from '../utils/uiHelpers.js';
import { getModelApiClient } from '../api/modelApiFactory.js'; import { getModelApiClient } from '../api/modelApiFactory.js';
import { removeStorageItem, setStorageItem, getStorageItem } from '../utils/storageHelpers.js'; import { removeStorageItem, setStorageItem, getStorageItem } from '../utils/storageHelpers.js';
import { MODEL_TYPE_DISPLAY_NAMES } from '../utils/constants.js';
export class FilterManager { export class FilterManager {
constructor(options = {}) { constructor(options = {}) {
@@ -34,6 +35,10 @@ export class FilterManager {
this.createBaseModelTags(); this.createBaseModelTags();
} }
if (document.getElementById('modelTypeTags')) {
this.createModelTypeTags();
}
// Add click handlers for license filter tags // Add click handlers for license filter tags
this.initializeLicenseFilters(); this.initializeLicenseFilters();
@@ -248,12 +253,86 @@ export class FilterManager {
// Update selections based on stored filters // Update selections based on stored filters
this.updateTagSelections(); this.updateTagSelections();
}
})
.catch(error => {
console.error(`Error fetching base models for ${this.currentPage}:`, error);
baseModelTagsContainer.innerHTML = '<div class="tags-error">Failed to load base models</div>';
});
}
async createModelTypeTags() {
const modelTypeContainer = document.getElementById('modelTypeTags');
if (!modelTypeContainer) return;
modelTypeContainer.innerHTML = '<div class="tags-loading">Loading model types...</div>';
try {
const response = await fetch(`/api/lm/${this.currentPage}/model-types?limit=20`);
if (!response.ok) {
throw new Error('Failed to fetch model types');
}
const data = await response.json();
if (!data.success || !Array.isArray(data.model_types)) {
throw new Error('Invalid response format');
}
const normalizedTypes = data.model_types
.map(entry => {
if (!entry || !entry.type) {
return null;
}
const typeKey = entry.type.toString().trim().toLowerCase();
if (!typeKey || !MODEL_TYPE_DISPLAY_NAMES[typeKey]) {
return null;
}
return {
type: typeKey,
count: Number(entry.count) || 0,
};
})
.filter(Boolean);
if (!normalizedTypes.length) {
modelTypeContainer.innerHTML = '<div class="no-tags">No model types available</div>';
return;
}
modelTypeContainer.innerHTML = '';
normalizedTypes.forEach(({ type, count }) => {
const tag = document.createElement('div');
tag.className = 'filter-tag model-type-tag';
tag.dataset.modelType = type;
tag.innerHTML = `${MODEL_TYPE_DISPLAY_NAMES[type]} <span class="tag-count">${count}</span>`;
if (this.filters.modelTypes.includes(type)) {
tag.classList.add('active');
} }
})
.catch(error => { tag.addEventListener('click', async () => {
console.error(`Error fetching base models for ${this.currentPage}:`, error); const isSelected = this.filters.modelTypes.includes(type);
baseModelTagsContainer.innerHTML = '<div class="tags-error">Failed to load base models</div>'; if (isSelected) {
this.filters.modelTypes = this.filters.modelTypes.filter(value => value !== type);
tag.classList.remove('active');
} else {
this.filters.modelTypes.push(type);
tag.classList.add('active');
}
this.updateActiveFiltersCount();
await this.applyFilters(false);
});
modelTypeContainer.appendChild(tag);
}); });
this.updateModelTypeSelections();
} catch (error) {
console.error('Error loading model types:', error);
modelTypeContainer.innerHTML = '<div class="tags-error">Failed to load model types</div>';
}
} }
toggleFilterPanel() { toggleFilterPanel() {
@@ -309,12 +388,26 @@ export class FilterManager {
// Update license tags // Update license tags
this.updateLicenseSelections(); this.updateLicenseSelections();
this.updateModelTypeSelections();
}
updateModelTypeSelections() {
const typeTags = document.querySelectorAll('.model-type-tag');
typeTags.forEach(tag => {
const modelType = tag.dataset.modelType;
if (this.filters.modelTypes.includes(modelType)) {
tag.classList.add('active');
} else {
tag.classList.remove('active');
}
});
} }
updateActiveFiltersCount() { updateActiveFiltersCount() {
const tagFilterCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0; const tagFilterCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
const licenseFilterCount = this.filters.license ? Object.keys(this.filters.license).length : 0; const licenseFilterCount = this.filters.license ? Object.keys(this.filters.license).length : 0;
const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount; const modelTypeFilterCount = this.filters.modelTypes.length;
const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount + modelTypeFilterCount;
if (this.activeFiltersCount) { if (this.activeFiltersCount) {
if (totalActiveFilters > 0) { if (totalActiveFilters > 0) {
@@ -377,7 +470,8 @@ export class FilterManager {
...this.filters, ...this.filters,
baseModel: [], baseModel: [],
tags: {}, tags: {},
license: {} license: {},
modelTypes: []
}); });
// Update state // Update state
@@ -437,7 +531,13 @@ export class FilterManager {
hasActiveFilters() { hasActiveFilters() {
const tagCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0; const tagCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
const licenseCount = this.filters.license ? Object.keys(this.filters.license).length : 0; const licenseCount = this.filters.license ? Object.keys(this.filters.license).length : 0;
return this.filters.baseModel.length > 0 || tagCount > 0 || licenseCount > 0; const modelTypeCount = this.filters.modelTypes.length;
return (
this.filters.baseModel.length > 0 ||
tagCount > 0 ||
licenseCount > 0 ||
modelTypeCount > 0
);
} }
initializeFilters(existingFilters = {}) { initializeFilters(existingFilters = {}) {
@@ -446,7 +546,8 @@ export class FilterManager {
...source, ...source,
baseModel: Array.isArray(source.baseModel) ? [...source.baseModel] : [], baseModel: Array.isArray(source.baseModel) ? [...source.baseModel] : [],
tags: this.normalizeTagFilters(source.tags), tags: this.normalizeTagFilters(source.tags),
license: this.normalizeLicenseFilters(source.license) license: this.normalizeLicenseFilters(source.license),
modelTypes: this.normalizeModelTypeFilters(source.modelTypes)
}; };
} }
@@ -496,12 +597,35 @@ export class FilterManager {
return normalized; return normalized;
} }
normalizeModelTypeFilters(modelTypes) {
if (!Array.isArray(modelTypes)) {
return [];
}
const seen = new Set();
return modelTypes.reduce((acc, type) => {
if (typeof type !== 'string') {
return acc;
}
const normalized = type.trim().toLowerCase();
if (!normalized || seen.has(normalized)) {
return acc;
}
seen.add(normalized);
acc.push(normalized);
return acc;
}, []);
}
cloneFilters() { cloneFilters() {
return { return {
...this.filters, ...this.filters,
baseModel: [...(this.filters.baseModel || [])], baseModel: [...(this.filters.baseModel || [])],
tags: { ...(this.filters.tags || {}) }, tags: { ...(this.filters.tags || {}) },
license: { ...(this.filters.license || {}) } license: { ...(this.filters.license || {}) },
modelTypes: [...(this.filters.modelTypes || [])]
}; };
} }

View File

@@ -79,7 +79,8 @@ export const state = {
filters: { filters: {
baseModel: [], baseModel: [],
tags: {}, tags: {},
license: {} license: {},
modelTypes: []
}, },
bulkMode: false, bulkMode: false,
selectedLoras: new Set(), selectedLoras: new Set(),
@@ -105,6 +106,7 @@ export const state = {
baseModel: [], baseModel: [],
tags: {}, tags: {},
license: {}, license: {},
modelTypes: [],
search: '' search: ''
}, },
pageSize: 20, pageSize: 20,
@@ -131,7 +133,8 @@ export const state = {
filters: { filters: {
baseModel: [], baseModel: [],
tags: {}, tags: {},
license: {} license: {},
modelTypes: []
}, },
modelType: 'checkpoint', // 'checkpoint' or 'diffusion_model' modelType: 'checkpoint', // 'checkpoint' or 'diffusion_model'
bulkMode: false, bulkMode: false,
@@ -161,7 +164,8 @@ export const state = {
filters: { filters: {
baseModel: [], baseModel: [],
tags: {}, tags: {},
license: {} license: {},
modelTypes: []
}, },
bulkMode: false, bulkMode: false,
selectedModels: new Set(), selectedModels: new Set(),

View File

@@ -55,6 +55,12 @@ export const BASE_MODELS = {
UNKNOWN: "Other" UNKNOWN: "Other"
}; };
export const MODEL_TYPE_DISPLAY_NAMES = {
lora: "LoRA",
locon: "LyCORIS",
dora: "DoRA",
};
export const BASE_MODEL_ABBREVIATIONS = { export const BASE_MODEL_ABBREVIATIONS = {
// Stable Diffusion 1.x models // Stable Diffusion 1.x models
[BASE_MODELS.SD_1_4]: 'SD1', [BASE_MODELS.SD_1_4]: 'SD1',

View File

@@ -139,6 +139,14 @@
<div class="tags-loading">{{ t('common.status.loading') }}</div> <div class="tags-loading">{{ t('common.status.loading') }}</div>
</div> </div>
</div> </div>
{% if current_page == 'loras' %}
<div class="filter-section">
<h4>{{ t('header.filter.modelTypes') }}</h4>
<div class="filter-tags" id="modelTypeTags">
<div class="tags-loading">{{ t('common.status.loading') }}</div>
</div>
</div>
{% endif %}
<div class="filter-section"> <div class="filter-section">
<h4>{{ t('header.filter.license') }}</h4> <h4>{{ t('header.filter.license') }}</h4>
<div class="filter-tags"> <div class="filter-tags">
@@ -156,4 +164,3 @@
</button> </button>
</div> </div>
</div> </div>

View File

@@ -834,7 +834,7 @@ async def test_get_model_types_counts_and_limits():
types = await service.get_model_types(limit=1) types = await service.get_model_types(limit=1)
assert types == [{"type": "LoRa", "count": 2}] assert types == [{"type": "lora", "count": 3}]
@pytest.mark.asyncio @pytest.mark.asyncio