diff --git a/locales/de.json b/locales/de.json index 7b9c426d..69035fe4 100644 --- a/locales/de.json +++ b/locales/de.json @@ -195,6 +195,7 @@ "title": "Modelle filtern", "baseModel": "Basis-Modell", "modelTags": "Tags (Top 20)", + "modelTypes": "Model Types", "license": "Lizenz", "noCreditRequired": "Kein Credit erforderlich", "allowSellingGeneratedContent": "Verkauf erlaubt", diff --git a/locales/en.json b/locales/en.json index d19a153b..c8a98ed2 100644 --- a/locales/en.json +++ b/locales/en.json @@ -195,6 +195,7 @@ "title": "Filter Models", "baseModel": "Base Model", "modelTags": "Tags (Top 20)", + "modelTypes": "Model Types", "license": "License", "noCreditRequired": "No Credit Required", "allowSellingGeneratedContent": "Allow Selling", diff --git a/locales/es.json b/locales/es.json index 47e7de39..93848935 100644 --- a/locales/es.json +++ b/locales/es.json @@ -195,6 +195,7 @@ "title": "Filtrar modelos", "baseModel": "Modelo base", "modelTags": "Etiquetas (Top 20)", + "modelTypes": "Model Types", "license": "Licencia", "noCreditRequired": "Sin crédito requerido", "allowSellingGeneratedContent": "Venta permitida", diff --git a/locales/fr.json b/locales/fr.json index f36908a1..9c6d3a34 100644 --- a/locales/fr.json +++ b/locales/fr.json @@ -195,6 +195,7 @@ "title": "Filtrer les modèles", "baseModel": "Modèle de base", "modelTags": "Tags (Top 20)", + "modelTypes": "Model Types", "license": "Licence", "noCreditRequired": "Crédit non requis", "allowSellingGeneratedContent": "Vente autorisée", diff --git a/locales/he.json b/locales/he.json index 5961086d..0ed1cd10 100644 --- a/locales/he.json +++ b/locales/he.json @@ -195,6 +195,7 @@ "title": "סנן מודלים", "baseModel": "מודל בסיס", "modelTags": "תגיות (20 המובילות)", + "modelTypes": "Model Types", "license": "רישיון", "noCreditRequired": "ללא קרדיט נדרש", "allowSellingGeneratedContent": "אפשר מכירה", diff --git a/locales/ja.json b/locales/ja.json index 626f0031..227923fe 100644 --- a/locales/ja.json +++ b/locales/ja.json @@ -195,6 +195,7 @@ "title": "モデルをフィルタ", "baseModel": "ベースモデル", "modelTags": "タグ(上位20)", + "modelTypes": "Model Types", "license": "ライセンス", "noCreditRequired": "クレジット不要", "allowSellingGeneratedContent": "販売許可", diff --git a/locales/ko.json b/locales/ko.json index f97de8e4..5942b619 100644 --- a/locales/ko.json +++ b/locales/ko.json @@ -195,6 +195,7 @@ "title": "모델 필터", "baseModel": "베이스 모델", "modelTags": "태그 (상위 20개)", + "modelTypes": "Model Types", "license": "라이선스", "noCreditRequired": "크레딧 표기 없음", "allowSellingGeneratedContent": "판매 허용", diff --git a/locales/ru.json b/locales/ru.json index e52a981b..10e26e58 100644 --- a/locales/ru.json +++ b/locales/ru.json @@ -195,6 +195,7 @@ "title": "Фильтр моделей", "baseModel": "Базовая модель", "modelTags": "Теги (Топ 20)", + "modelTypes": "Model Types", "license": "Лицензия", "noCreditRequired": "Без указания авторства", "allowSellingGeneratedContent": "Продажа разрешена", diff --git a/locales/zh-CN.json b/locales/zh-CN.json index 66615ff4..c8275ee0 100644 --- a/locales/zh-CN.json +++ b/locales/zh-CN.json @@ -195,6 +195,7 @@ "title": "筛选模型", "baseModel": "基础模型", "modelTags": "标签(前20)", + "modelTypes": "Model Types", "license": "许可证", "noCreditRequired": "无需署名", "allowSellingGeneratedContent": "允许销售", diff --git a/locales/zh-TW.json b/locales/zh-TW.json index 31be7e19..6eddf4ed 100644 --- a/locales/zh-TW.json +++ b/locales/zh-TW.json @@ -195,6 +195,7 @@ "title": "篩選模型", "baseModel": "基礎模型", "modelTags": "標籤(前 20)", + "modelTypes": "Model Types", "license": "授權", "noCreditRequired": "無需署名", "allowSellingGeneratedContent": "允許銷售", diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index 1c983a39..c7f9de64 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING import logging import os +from ..utils.constants import VALID_LORA_TYPES from ..utils.models import BaseModelMetadata from ..utils.metadata_manager import MetadataManager from .model_query import ( @@ -12,6 +13,7 @@ from .model_query import ( ModelFilterSet, SearchStrategy, SettingsProvider, + normalize_civitai_model_type, resolve_civitai_model_type, ) from .settings_manager import get_settings_manager @@ -469,12 +471,15 @@ class BaseModelService(ABC): return await self.scanner.get_base_models(limit) async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]: - """Get counts of CivitAI model types present in the cache.""" + """Get counts of normalized CivitAI model types present in the cache.""" cache = await self.scanner.get_cached_data() + type_counts: Dict[str, int] = {} for entry in cache.raw_data: - model_type = resolve_civitai_model_type(entry) - type_counts[model_type] = type_counts.get(model_type, 0) + 1 + normalized_type = normalize_civitai_model_type(resolve_civitai_model_type(entry)) + if not normalized_type or normalized_type not in VALID_LORA_TYPES: + continue + type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1 sorted_types = sorted( [{"type": model_type, "count": count} for model_type, count in type_counts.items()], diff --git a/static/js/api/baseModelApi.js b/static/js/api/baseModelApi.js index 9b6c2300..e659dda8 100644 --- a/static/js/api/baseModelApi.js +++ b/static/js/api/baseModelApi.js @@ -848,6 +848,12 @@ export class BaseModelApiClient { } } } + + if (pageState.filters.modelTypes && pageState.filters.modelTypes.length > 0) { + pageState.filters.modelTypes.forEach((type) => { + params.append('model_type', type); + }); + } } this._addModelSpecificParams(params, pageState); diff --git a/static/js/managers/FilterManager.js b/static/js/managers/FilterManager.js index 1f0500c2..23e4ac89 100644 --- a/static/js/managers/FilterManager.js +++ b/static/js/managers/FilterManager.js @@ -2,6 +2,7 @@ import { getCurrentPageState } from '../state/index.js'; import { showToast, updatePanelPositions } from '../utils/uiHelpers.js'; import { getModelApiClient } from '../api/modelApiFactory.js'; import { removeStorageItem, setStorageItem, getStorageItem } from '../utils/storageHelpers.js'; +import { MODEL_TYPE_DISPLAY_NAMES } from '../utils/constants.js'; export class FilterManager { constructor(options = {}) { @@ -34,6 +35,10 @@ export class FilterManager { this.createBaseModelTags(); } + if (document.getElementById('modelTypeTags')) { + this.createModelTypeTags(); + } + // Add click handlers for license filter tags this.initializeLicenseFilters(); @@ -248,12 +253,86 @@ export class FilterManager { // Update selections based on stored filters this.updateTagSelections(); + } + }) + .catch(error => { + console.error(`Error fetching base models for ${this.currentPage}:`, error); + baseModelTagsContainer.innerHTML = '
'; + }); + } + + async createModelTypeTags() { + const modelTypeContainer = document.getElementById('modelTypeTags'); + if (!modelTypeContainer) return; + + modelTypeContainer.innerHTML = ''; + + try { + const response = await fetch(`/api/lm/${this.currentPage}/model-types?limit=20`); + if (!response.ok) { + throw new Error('Failed to fetch model types'); + } + + const data = await response.json(); + if (!data.success || !Array.isArray(data.model_types)) { + throw new Error('Invalid response format'); + } + + const normalizedTypes = data.model_types + .map(entry => { + if (!entry || !entry.type) { + return null; + } + const typeKey = entry.type.toString().trim().toLowerCase(); + if (!typeKey || !MODEL_TYPE_DISPLAY_NAMES[typeKey]) { + return null; + } + return { + type: typeKey, + count: Number(entry.count) || 0, + }; + }) + .filter(Boolean); + + if (!normalizedTypes.length) { + modelTypeContainer.innerHTML = ''; + return; + } + + modelTypeContainer.innerHTML = ''; + + normalizedTypes.forEach(({ type, count }) => { + const tag = document.createElement('div'); + tag.className = 'filter-tag model-type-tag'; + tag.dataset.modelType = type; + tag.innerHTML = `${MODEL_TYPE_DISPLAY_NAMES[type]} ${count}`; + + if (this.filters.modelTypes.includes(type)) { + tag.classList.add('active'); } - }) - .catch(error => { - console.error(`Error fetching base models for ${this.currentPage}:`, error); - baseModelTagsContainer.innerHTML = ''; + + tag.addEventListener('click', async () => { + const isSelected = this.filters.modelTypes.includes(type); + if (isSelected) { + this.filters.modelTypes = this.filters.modelTypes.filter(value => value !== type); + tag.classList.remove('active'); + } else { + this.filters.modelTypes.push(type); + tag.classList.add('active'); + } + + this.updateActiveFiltersCount(); + await this.applyFilters(false); + }); + + modelTypeContainer.appendChild(tag); }); + + this.updateModelTypeSelections(); + } catch (error) { + console.error('Error loading model types:', error); + modelTypeContainer.innerHTML = ''; + } } toggleFilterPanel() { @@ -309,12 +388,26 @@ export class FilterManager { // Update license tags this.updateLicenseSelections(); + this.updateModelTypeSelections(); + } + + updateModelTypeSelections() { + const typeTags = document.querySelectorAll('.model-type-tag'); + typeTags.forEach(tag => { + const modelType = tag.dataset.modelType; + if (this.filters.modelTypes.includes(modelType)) { + tag.classList.add('active'); + } else { + tag.classList.remove('active'); + } + }); } updateActiveFiltersCount() { const tagFilterCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0; const licenseFilterCount = this.filters.license ? Object.keys(this.filters.license).length : 0; - const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount; + const modelTypeFilterCount = this.filters.modelTypes.length; + const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount + modelTypeFilterCount; if (this.activeFiltersCount) { if (totalActiveFilters > 0) { @@ -377,7 +470,8 @@ export class FilterManager { ...this.filters, baseModel: [], tags: {}, - license: {} + license: {}, + modelTypes: [] }); // Update state @@ -437,7 +531,13 @@ export class FilterManager { hasActiveFilters() { const tagCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0; const licenseCount = this.filters.license ? Object.keys(this.filters.license).length : 0; - return this.filters.baseModel.length > 0 || tagCount > 0 || licenseCount > 0; + const modelTypeCount = this.filters.modelTypes.length; + return ( + this.filters.baseModel.length > 0 || + tagCount > 0 || + licenseCount > 0 || + modelTypeCount > 0 + ); } initializeFilters(existingFilters = {}) { @@ -446,7 +546,8 @@ export class FilterManager { ...source, baseModel: Array.isArray(source.baseModel) ? [...source.baseModel] : [], tags: this.normalizeTagFilters(source.tags), - license: this.normalizeLicenseFilters(source.license) + license: this.normalizeLicenseFilters(source.license), + modelTypes: this.normalizeModelTypeFilters(source.modelTypes) }; } @@ -496,12 +597,35 @@ export class FilterManager { return normalized; } + normalizeModelTypeFilters(modelTypes) { + if (!Array.isArray(modelTypes)) { + return []; + } + + const seen = new Set(); + return modelTypes.reduce((acc, type) => { + if (typeof type !== 'string') { + return acc; + } + + const normalized = type.trim().toLowerCase(); + if (!normalized || seen.has(normalized)) { + return acc; + } + + seen.add(normalized); + acc.push(normalized); + return acc; + }, []); + } + cloneFilters() { return { ...this.filters, baseModel: [...(this.filters.baseModel || [])], tags: { ...(this.filters.tags || {}) }, - license: { ...(this.filters.license || {}) } + license: { ...(this.filters.license || {}) }, + modelTypes: [...(this.filters.modelTypes || [])] }; } diff --git a/static/js/state/index.js b/static/js/state/index.js index b8157f66..d565e674 100644 --- a/static/js/state/index.js +++ b/static/js/state/index.js @@ -79,7 +79,8 @@ export const state = { filters: { baseModel: [], tags: {}, - license: {} + license: {}, + modelTypes: [] }, bulkMode: false, selectedLoras: new Set(), @@ -105,6 +106,7 @@ export const state = { baseModel: [], tags: {}, license: {}, + modelTypes: [], search: '' }, pageSize: 20, @@ -131,7 +133,8 @@ export const state = { filters: { baseModel: [], tags: {}, - license: {} + license: {}, + modelTypes: [] }, modelType: 'checkpoint', // 'checkpoint' or 'diffusion_model' bulkMode: false, @@ -161,7 +164,8 @@ export const state = { filters: { baseModel: [], tags: {}, - license: {} + license: {}, + modelTypes: [] }, bulkMode: false, selectedModels: new Set(), diff --git a/static/js/utils/constants.js b/static/js/utils/constants.js index 8d56f454..ea6c7f3f 100644 --- a/static/js/utils/constants.js +++ b/static/js/utils/constants.js @@ -55,6 +55,12 @@ export const BASE_MODELS = { UNKNOWN: "Other" }; +export const MODEL_TYPE_DISPLAY_NAMES = { + lora: "LoRA", + locon: "LyCORIS", + dora: "DoRA", +}; + export const BASE_MODEL_ABBREVIATIONS = { // Stable Diffusion 1.x models [BASE_MODELS.SD_1_4]: 'SD1', diff --git a/templates/components/header.html b/templates/components/header.html index bbf4b8d1..2c358914 100644 --- a/templates/components/header.html +++ b/templates/components/header.html @@ -139,6 +139,14 @@ + {% if current_page == 'loras' %} +