diff --git a/locales/de.json b/locales/de.json index 7b9c426d..69035fe4 100644 --- a/locales/de.json +++ b/locales/de.json @@ -195,6 +195,7 @@ "title": "Modelle filtern", "baseModel": "Basis-Modell", "modelTags": "Tags (Top 20)", + "modelTypes": "Modelltypen", "license": "Lizenz", "noCreditRequired": "Kein Credit erforderlich", "allowSellingGeneratedContent": "Verkauf erlaubt", diff --git a/locales/en.json b/locales/en.json index d19a153b..c8a98ed2 100644 --- a/locales/en.json +++ b/locales/en.json @@ -195,6 +195,7 @@ "title": "Filter Models", "baseModel": "Base Model", "modelTags": "Tags (Top 20)", + "modelTypes": "Model Types", "license": "License", "noCreditRequired": "No Credit Required", "allowSellingGeneratedContent": "Allow Selling", diff --git a/locales/es.json b/locales/es.json index 47e7de39..93848935 100644 --- a/locales/es.json +++ b/locales/es.json @@ -195,6 +195,7 @@ "title": "Filtrar modelos", "baseModel": "Modelo base", "modelTags": "Etiquetas (Top 20)", + "modelTypes": "Tipos de modelo", "license": "Licencia", "noCreditRequired": "Sin crédito requerido", "allowSellingGeneratedContent": "Venta permitida", diff --git a/locales/fr.json b/locales/fr.json index f36908a1..9c6d3a34 100644 --- a/locales/fr.json +++ b/locales/fr.json @@ -195,6 +195,7 @@ "title": "Filtrer les modèles", "baseModel": "Modèle de base", "modelTags": "Tags (Top 20)", + "modelTypes": "Types de modèles", "license": "Licence", "noCreditRequired": "Crédit non requis", "allowSellingGeneratedContent": "Vente autorisée", diff --git a/locales/he.json b/locales/he.json index 5961086d..0ed1cd10 100644 --- a/locales/he.json +++ b/locales/he.json @@ -195,6 +195,7 @@ "title": "סנן מודלים", "baseModel": "מודל בסיס", "modelTags": "תגיות (20 המובילות)", + "modelTypes": "סוגי מודלים", "license": "רישיון", "noCreditRequired": "ללא קרדיט נדרש", "allowSellingGeneratedContent": "אפשר מכירה", diff --git a/locales/ja.json b/locales/ja.json index 626f0031..227923fe 100644 --- a/locales/ja.json +++ 
b/locales/ja.json @@ -195,6 +195,7 @@ "title": "モデルをフィルタ", "baseModel": "ベースモデル", "modelTags": "タグ(上位20)", + "modelTypes": "モデルタイプ", "license": "ライセンス", "noCreditRequired": "クレジット不要", "allowSellingGeneratedContent": "販売許可", diff --git a/locales/ko.json b/locales/ko.json index f97de8e4..5942b619 100644 --- a/locales/ko.json +++ b/locales/ko.json @@ -195,6 +195,7 @@ "title": "모델 필터", "baseModel": "베이스 모델", "modelTags": "태그 (상위 20개)", + "modelTypes": "모델 타입", "license": "라이선스", "noCreditRequired": "크레딧 표기 없음", "allowSellingGeneratedContent": "판매 허용", diff --git a/locales/ru.json b/locales/ru.json index e52a981b..10e26e58 100644 --- a/locales/ru.json +++ b/locales/ru.json @@ -195,6 +195,7 @@ "title": "Фильтр моделей", "baseModel": "Базовая модель", "modelTags": "Теги (Топ 20)", + "modelTypes": "Типы моделей", "license": "Лицензия", "noCreditRequired": "Без указания авторства", "allowSellingGeneratedContent": "Продажа разрешена", diff --git a/locales/zh-CN.json b/locales/zh-CN.json index 66615ff4..c8275ee0 100644 --- a/locales/zh-CN.json +++ b/locales/zh-CN.json @@ -195,6 +195,7 @@ "title": "筛选模型", "baseModel": "基础模型", "modelTags": "标签(前20)", + "modelTypes": "模型类型", "license": "许可证", "noCreditRequired": "无需署名", "allowSellingGeneratedContent": "允许销售", diff --git a/locales/zh-TW.json b/locales/zh-TW.json index 31be7e19..6eddf4ed 100644 --- a/locales/zh-TW.json +++ b/locales/zh-TW.json @@ -195,6 +195,7 @@ "title": "篩選模型", "baseModel": "基礎模型", "modelTags": "標籤(前 20)", + "modelTypes": "模型類型", "license": "授權", "noCreditRequired": "無需署名", "allowSellingGeneratedContent": "允許銷售", diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index fee783dd..c74dccca 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -152,6 +152,8 @@ class ModelListingHandler: fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true" base_models = request.query.getall("base_model", []) 
+ model_types = list(request.query.getall("model_type", [])) + model_types.extend(request.query.getall("civitai_model_type", [])) # Support legacy ?tag=foo plus new ?tag_include/foo & ?tag_exclude parameters legacy_tags = request.query.getall("tag", []) if not legacy_tags: @@ -225,6 +227,7 @@ class ModelListingHandler: "update_available_only": update_available_only, "credit_required": credit_required, "allow_selling_generated_content": allow_selling_generated_content, + "model_types": model_types, **self._parse_specific_params(request), } @@ -557,6 +560,17 @@ class ModelQueryHandler: self._logger.error("Error retrieving base models: %s", exc) return web.json_response({"success": False, "error": str(exc)}, status=500) + async def get_model_types(self, request: web.Request) -> web.Response: + try: + limit = int(request.query.get("limit", "20")) + if limit < 1 or limit > 100: + limit = 20 + model_types = await self._service.get_model_types(limit) + return web.json_response({"success": True, "model_types": model_types}) + except Exception as exc: + self._logger.error("Error retrieving model types: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=500) + async def scan_models(self, request: web.Request) -> web.Response: try: full_rebuild = request.query.get("full_rebuild", "false").lower() == "true" @@ -1579,6 +1593,7 @@ class ModelHandlerSet: "verify_duplicates": self.management.verify_duplicates, "get_top_tags": self.query.get_top_tags, "get_base_models": self.query.get_base_models, + "get_model_types": self.query.get_model_types, "scan_models": self.query.scan_models, "get_model_roots": self.query.get_model_roots, "get_folders": self.query.get_folders, diff --git a/py/routes/model_route_registrar.py b/py/routes/model_route_registrar.py index ce7a75ba..21589c7b 100644 --- a/py/routes/model_route_registrar.py +++ b/py/routes/model_route_registrar.py @@ -39,6 +39,7 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] 
= ( RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"), RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"), RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"), + RouteDefinition("GET", "/api/lm/{prefix}/model-types", "get_model_types"), RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"), RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"), RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"), diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index 2fd393ec..c7f9de64 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -1,12 +1,21 @@ from abc import ABC, abstractmethod import asyncio -from typing import Dict, List, Optional, Type, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING import logging import os +from ..utils.constants import VALID_LORA_TYPES from ..utils.models import BaseModelMetadata from ..utils.metadata_manager import MetadataManager -from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider +from .model_query import ( + FilterCriteria, + ModelCacheRepository, + ModelFilterSet, + SearchStrategy, + SettingsProvider, + normalize_civitai_model_type, + resolve_civitai_model_type, +) from .settings_manager import get_settings_manager logger = logging.getLogger(__name__) @@ -59,6 +68,7 @@ class BaseModelService(ABC): search: str = None, fuzzy_search: bool = False, base_models: list = None, + model_types: list = None, tags: Optional[Dict[str, str]] = None, search_options: dict = None, hash_filters: dict = None, @@ -80,6 +90,7 @@ class BaseModelService(ABC): sorted_data, folder=folder, base_models=base_models, + model_types=model_types, tags=tags, favorites_only=favorites_only, search_options=search_options, @@ -149,6 +160,7 @@ class BaseModelService(ABC): data: List[Dict], folder: 
str = None, base_models: list = None, + model_types: list = None, tags: Optional[Dict[str, str]] = None, favorites_only: bool = False, search_options: dict = None, @@ -158,6 +170,7 @@ class BaseModelService(ABC): criteria = FilterCriteria( folder=folder, base_models=base_models, + model_types=model_types, tags=tags, favorites_only=favorites_only, search_options=normalized_options, @@ -456,6 +469,25 @@ class BaseModelService(ABC): async def get_base_models(self, limit: int = 20) -> List[Dict]: """Get base models sorted by frequency""" return await self.scanner.get_base_models(limit) + + async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]: + """Get counts of normalized CivitAI model types present in the cache.""" + cache = await self.scanner.get_cached_data() + + type_counts: Dict[str, int] = {} + for entry in cache.raw_data: + normalized_type = normalize_civitai_model_type(resolve_civitai_model_type(entry)) + if not normalized_type or normalized_type not in VALID_LORA_TYPES: + continue + type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1 + + sorted_types = sorted( + [{"type": model_type, "count": count} for model_type, count in type_counts.items()], + key=lambda value: value["count"], + reverse=True, + ) + + return sorted_types[:limit] def has_hash(self, sha256: str) -> bool: """Check if a model with given hash exists""" diff --git a/py/services/model_query.py b/py/services/model_query.py index d88e9631..5b370138 100644 --- a/py/services/model_query.py +++ b/py/services/model_query.py @@ -1,12 +1,49 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Protocol, Callable from ..utils.constants import NSFW_LEVELS from ..utils.utils import fuzzy_match as default_fuzzy_match +DEFAULT_CIVITAI_MODEL_TYPE = "LORA" + + +def 
_coerce_to_str(value: Any) -> Optional[str]: + if value is None: + return None + + candidate = str(value).strip() + return candidate if candidate else None + + +def normalize_civitai_model_type(value: Any) -> Optional[str]: + """Return a lowercase string suitable for comparisons.""" + candidate = _coerce_to_str(value) + return candidate.lower() if candidate else None + + +def resolve_civitai_model_type(entry: Mapping[str, Any]) -> str: + """Extract the model type from CivitAI metadata, defaulting to LORA.""" + if not isinstance(entry, Mapping): + return DEFAULT_CIVITAI_MODEL_TYPE + + civitai = entry.get("civitai") + if isinstance(civitai, Mapping): + civitai_model = civitai.get("model") + if isinstance(civitai_model, Mapping): + model_type = _coerce_to_str(civitai_model.get("type")) + if model_type: + return model_type + + model_type = _coerce_to_str(entry.get("model_type")) + if model_type: + return model_type + + return DEFAULT_CIVITAI_MODEL_TYPE + + class SettingsProvider(Protocol): """Protocol describing the SettingsManager contract used by query helpers.""" @@ -31,6 +68,7 @@ class FilterCriteria: tags: Optional[Dict[str, str]] = None favorites_only: bool = False search_options: Optional[Dict[str, Any]] = None + model_types: Optional[Sequence[str]] = None class ModelCacheRepository: @@ -134,6 +172,19 @@ class ModelFilterSet: if not any(tag in exclude_tags for tag in (item.get("tags", []) or [])) ] + model_types = criteria.model_types or [] + normalized_model_types = { + model_type for model_type in ( + normalize_civitai_model_type(value) for value in model_types + ) + if model_type + } + if normalized_model_types: + items = [ + item for item in items + if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types + ] + return items diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index ebe05c42..8acaff4b 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -161,6 +161,12 @@ class 
ModelScanner: if trained_words: slim['trainedWords'] = list(trained_words) if isinstance(trained_words, list) else trained_words + civitai_model = civitai.get('model') + if isinstance(civitai_model, Mapping): + model_type_value = civitai_model.get('type') + if model_type_value not in (None, '', []): + slim['model'] = {'type': model_type_value} + return slim or None def _build_cache_entry( diff --git a/py/services/persistent_model_cache.py b/py/services/persistent_model_cache.py index cda1d0b3..c3ebcc27 100644 --- a/py/services/persistent_model_cache.py +++ b/py/services/persistent_model_cache.py @@ -5,7 +5,7 @@ import re import sqlite3 import threading from dataclasses import dataclass -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Mapping, Optional, Sequence, Tuple from ..utils.settings_paths import get_project_root, get_settings_dir @@ -47,6 +47,7 @@ class PersistentModelCache: "metadata_source", "civitai_id", "civitai_model_id", + "civitai_model_type", "civitai_name", "civitai_creator_username", "trained_words", @@ -138,7 +139,8 @@ class PersistentModelCache: creator_username = row["civitai_creator_username"] civitai: Optional[Dict] = None civitai_has_data = any( - row[col] is not None for col in ("civitai_id", "civitai_model_id", "civitai_name") + row[col] is not None + for col in ("civitai_id", "civitai_model_id", "civitai_model_type", "civitai_name") ) or trained_words or creator_username if civitai_has_data: civitai = {} @@ -152,6 +154,9 @@ class PersistentModelCache: civitai["trainedWords"] = trained_words if creator_username: civitai.setdefault("creator", {})["username"] = creator_username + model_type_value = row["civitai_model_type"] + if model_type_value: + civitai.setdefault("model", {})["type"] = model_type_value license_value = row["license_flags"] if license_value is None: @@ -443,6 +448,7 @@ class PersistentModelCache: metadata_source TEXT, civitai_id INTEGER, civitai_model_id INTEGER, + civitai_model_type 
TEXT, civitai_name TEXT, civitai_creator_username TEXT, trained_words TEXT, @@ -492,6 +498,7 @@ class PersistentModelCache: required_columns = { "metadata_source": "TEXT", "civitai_creator_username": "TEXT", + "civitai_model_type": "TEXT", "civitai_deleted": "INTEGER DEFAULT 0", # Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57). "license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}", @@ -528,6 +535,13 @@ class PersistentModelCache: creator_data = civitai.get("creator") if isinstance(civitai, dict) else None if isinstance(creator_data, dict): creator_username = creator_data.get("username") or None + model_type_value = None + if isinstance(civitai, Mapping): + civitai_model_info = civitai.get("model") + if isinstance(civitai_model_info, Mapping): + candidate_type = civitai_model_info.get("type") + if candidate_type not in (None, "", []): + model_type_value = candidate_type license_flags = item.get("license_flags") if license_flags is None: @@ -552,6 +566,7 @@ class PersistentModelCache: metadata_source, civitai.get("id"), civitai.get("modelId"), + model_type_value, civitai.get("name"), creator_username, trained_words_json, diff --git a/static/js/api/baseModelApi.js b/static/js/api/baseModelApi.js index 9b6c2300..e659dda8 100644 --- a/static/js/api/baseModelApi.js +++ b/static/js/api/baseModelApi.js @@ -848,6 +848,12 @@ export class BaseModelApiClient { } } } + + if (pageState.filters.modelTypes && pageState.filters.modelTypes.length > 0) { + pageState.filters.modelTypes.forEach((type) => { + params.append('model_type', type); + }); + } } this._addModelSpecificParams(params, pageState); diff --git a/static/js/managers/FilterManager.js b/static/js/managers/FilterManager.js index 1f0500c2..23e4ac89 100644 --- a/static/js/managers/FilterManager.js +++ b/static/js/managers/FilterManager.js @@ -2,6 +2,7 @@ import { getCurrentPageState } from '../state/index.js'; import { showToast, updatePanelPositions } from 
'../utils/uiHelpers.js'; import { getModelApiClient } from '../api/modelApiFactory.js'; import { removeStorageItem, setStorageItem, getStorageItem } from '../utils/storageHelpers.js'; +import { MODEL_TYPE_DISPLAY_NAMES } from '../utils/constants.js'; export class FilterManager { constructor(options = {}) { @@ -34,6 +35,10 @@ export class FilterManager { this.createBaseModelTags(); } + if (document.getElementById('modelTypeTags')) { + this.createModelTypeTags(); + } + // Add click handlers for license filter tags this.initializeLicenseFilters(); @@ -248,12 +253,86 @@ export class FilterManager { // Update selections based on stored filters this.updateTagSelections(); + } + }) + .catch(error => { + console.error(`Error fetching base models for ${this.currentPage}:`, error); + baseModelTagsContainer.innerHTML = '
Failed to load base models
'; + }); + } + + async createModelTypeTags() { + const modelTypeContainer = document.getElementById('modelTypeTags'); + if (!modelTypeContainer) return; + + modelTypeContainer.innerHTML = '
Loading model types...
'; + + try { + const response = await fetch(`/api/lm/${this.currentPage}/model-types?limit=20`); + if (!response.ok) { + throw new Error('Failed to fetch model types'); + } + + const data = await response.json(); + if (!data.success || !Array.isArray(data.model_types)) { + throw new Error('Invalid response format'); + } + + const normalizedTypes = data.model_types + .map(entry => { + if (!entry || !entry.type) { + return null; + } + const typeKey = entry.type.toString().trim().toLowerCase(); + if (!typeKey || !MODEL_TYPE_DISPLAY_NAMES[typeKey]) { + return null; + } + return { + type: typeKey, + count: Number(entry.count) || 0, + }; + }) + .filter(Boolean); + + if (!normalizedTypes.length) { + modelTypeContainer.innerHTML = '
No model types available
'; + return; + } + + modelTypeContainer.innerHTML = ''; + + normalizedTypes.forEach(({ type, count }) => { + const tag = document.createElement('div'); + tag.className = 'filter-tag model-type-tag'; + tag.dataset.modelType = type; + tag.innerHTML = `${MODEL_TYPE_DISPLAY_NAMES[type]} ${count}`; + + if (this.filters.modelTypes.includes(type)) { + tag.classList.add('active'); } - }) - .catch(error => { - console.error(`Error fetching base models for ${this.currentPage}:`, error); - baseModelTagsContainer.innerHTML = '
Failed to load base models
'; + + tag.addEventListener('click', async () => { + const isSelected = this.filters.modelTypes.includes(type); + if (isSelected) { + this.filters.modelTypes = this.filters.modelTypes.filter(value => value !== type); + tag.classList.remove('active'); + } else { + this.filters.modelTypes.push(type); + tag.classList.add('active'); + } + + this.updateActiveFiltersCount(); + await this.applyFilters(false); + }); + + modelTypeContainer.appendChild(tag); }); + + this.updateModelTypeSelections(); + } catch (error) { + console.error('Error loading model types:', error); + modelTypeContainer.innerHTML = '
Failed to load model types
'; + } } toggleFilterPanel() { @@ -309,12 +388,26 @@ export class FilterManager { // Update license tags this.updateLicenseSelections(); + this.updateModelTypeSelections(); + } + + updateModelTypeSelections() { + const typeTags = document.querySelectorAll('.model-type-tag'); + typeTags.forEach(tag => { + const modelType = tag.dataset.modelType; + if (this.filters.modelTypes.includes(modelType)) { + tag.classList.add('active'); + } else { + tag.classList.remove('active'); + } + }); } updateActiveFiltersCount() { const tagFilterCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0; const licenseFilterCount = this.filters.license ? Object.keys(this.filters.license).length : 0; - const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount; + const modelTypeFilterCount = this.filters.modelTypes.length; + const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount + modelTypeFilterCount; if (this.activeFiltersCount) { if (totalActiveFilters > 0) { @@ -377,7 +470,8 @@ export class FilterManager { ...this.filters, baseModel: [], tags: {}, - license: {} + license: {}, + modelTypes: [] }); // Update state @@ -437,7 +531,13 @@ export class FilterManager { hasActiveFilters() { const tagCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0; const licenseCount = this.filters.license ? Object.keys(this.filters.license).length : 0; - return this.filters.baseModel.length > 0 || tagCount > 0 || licenseCount > 0; + const modelTypeCount = this.filters.modelTypes.length; + return ( + this.filters.baseModel.length > 0 || + tagCount > 0 || + licenseCount > 0 || + modelTypeCount > 0 + ); } initializeFilters(existingFilters = {}) { @@ -446,7 +546,8 @@ export class FilterManager { ...source, baseModel: Array.isArray(source.baseModel) ? 
[...source.baseModel] : [], tags: this.normalizeTagFilters(source.tags), - license: this.normalizeLicenseFilters(source.license) + license: this.normalizeLicenseFilters(source.license), + modelTypes: this.normalizeModelTypeFilters(source.modelTypes) }; } @@ -496,12 +597,35 @@ export class FilterManager { return normalized; } + normalizeModelTypeFilters(modelTypes) { + if (!Array.isArray(modelTypes)) { + return []; + } + + const seen = new Set(); + return modelTypes.reduce((acc, type) => { + if (typeof type !== 'string') { + return acc; + } + + const normalized = type.trim().toLowerCase(); + if (!normalized || seen.has(normalized)) { + return acc; + } + + seen.add(normalized); + acc.push(normalized); + return acc; + }, []); + } + cloneFilters() { return { ...this.filters, baseModel: [...(this.filters.baseModel || [])], tags: { ...(this.filters.tags || {}) }, - license: { ...(this.filters.license || {}) } + license: { ...(this.filters.license || {}) }, + modelTypes: [...(this.filters.modelTypes || [])] }; } diff --git a/static/js/state/index.js b/static/js/state/index.js index b8157f66..d565e674 100644 --- a/static/js/state/index.js +++ b/static/js/state/index.js @@ -79,7 +79,8 @@ export const state = { filters: { baseModel: [], tags: {}, - license: {} + license: {}, + modelTypes: [] }, bulkMode: false, selectedLoras: new Set(), @@ -105,6 +106,7 @@ export const state = { baseModel: [], tags: {}, license: {}, + modelTypes: [], search: '' }, pageSize: 20, @@ -131,7 +133,8 @@ export const state = { filters: { baseModel: [], tags: {}, - license: {} + license: {}, + modelTypes: [] }, modelType: 'checkpoint', // 'checkpoint' or 'diffusion_model' bulkMode: false, @@ -161,7 +164,8 @@ export const state = { filters: { baseModel: [], tags: {}, - license: {} + license: {}, + modelTypes: [] }, bulkMode: false, selectedModels: new Set(), diff --git a/static/js/utils/constants.js b/static/js/utils/constants.js index 8d56f454..ea6c7f3f 100644 --- a/static/js/utils/constants.js +++ 
b/static/js/utils/constants.js @@ -55,6 +55,12 @@ export const BASE_MODELS = { UNKNOWN: "Other" }; +export const MODEL_TYPE_DISPLAY_NAMES = { + lora: "LoRA", + locon: "LyCORIS", + dora: "DoRA", +}; + export const BASE_MODEL_ABBREVIATIONS = { // Stable Diffusion 1.x models [BASE_MODELS.SD_1_4]: 'SD1', diff --git a/templates/components/header.html b/templates/components/header.html index bbf4b8d1..2c358914 100644 --- a/templates/components/header.html +++ b/templates/components/header.html @@ -139,6 +139,14 @@
{{ t('common.status.loading') }}
+ {% if current_page == 'loras' %} +
+

{{ t('header.filter.modelTypes') }}

+
+
{{ t('common.status.loading') }}
+
+
+ {% endif %}

{{ t('header.filter.license') }}

@@ -156,4 +164,3 @@
- diff --git a/tests/conftest.py b/tests/conftest.py index 7f5becb5..5ed3bc72 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -212,6 +212,7 @@ class MockModelService: self.model_type = "test-model" self.paginated_items: List[Dict[str, Any]] = [] self.formatted: List[Dict[str, Any]] = [] + self.model_types: List[Dict[str, Any]] = [] async def get_paginated_data(self, **params: Any) -> Dict[str, Any]: items = [dict(item) for item in self.paginated_items] @@ -257,6 +258,9 @@ class MockModelService: async def get_relative_paths(self, *_args, **_kwargs): # pragma: no cover return [] + async def get_model_types(self, limit: int = 20): + return list(self.model_types)[:limit] + def has_hash(self, *_args, **_kwargs): # pragma: no cover return False @@ -283,4 +287,3 @@ def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockS def mock_service(mock_scanner: MockScanner) -> MockModelService: return MockModelService(scanner=mock_scanner) - diff --git a/tests/routes/test_base_model_routes_smoke.py b/tests/routes/test_base_model_routes_smoke.py index bd4b8550..55776dcd 100644 --- a/tests/routes/test_base_model_routes_smoke.py +++ b/tests/routes/test_base_model_routes_smoke.py @@ -185,6 +185,26 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner): asyncio.run(scenario()) +def test_model_types_endpoint_returns_counts(mock_service, mock_scanner): + mock_service.model_types = [ + {"type": "LoRa", "count": 3}, + {"type": "Checkpoint", "count": 1}, + ] + + async def scenario(): + client = await create_test_client(mock_service) + try: + response = await client.get("/api/lm/test-models/model-types?limit=1") + payload = await response.json() + + assert response.status == 200 + assert payload["model_types"] == mock_service.model_types[:1] + finally: + await client.close() + + asyncio.run(scenario()) + + def test_routes_return_service_not_ready_when_unattached(): async def scenario(): client = await create_test_client(None) diff --git 
a/tests/services/test_base_model_service.py b/tests/services/test_base_model_service.py index 8c412aa1..57ff3d67 100644 --- a/tests/services/test_base_model_service.py +++ b/tests/services/test_base_model_service.py @@ -776,6 +776,67 @@ def test_model_filter_set_supports_legacy_tag_arrays(): assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"] +def test_model_filter_set_filters_by_model_types(): + settings = StubSettings({}) + filter_set = ModelFilterSet(settings) + data = [ + {"model_name": "LoConModel", "civitai": {"model": {"type": "LoCon"}}}, + {"model_name": "LoRaModel", "civitai": {"model": {"type": "LoRa"}}}, + ] + + criteria = FilterCriteria(model_types=["locon"]) + result = filter_set.apply(data, criteria) + + assert [item["model_name"] for item in result] == ["LoConModel"] + + +def test_model_filter_set_defaults_missing_model_type_to_lora(): + settings = StubSettings({}) + filter_set = ModelFilterSet(settings) + data = [ + {"model_name": "DefaultModel"}, + {"model_name": "CheckpointModel", "civitai": {"model": {"type": "checkpoint"}}}, + ] + + criteria = FilterCriteria(model_types=["lora"]) + result = filter_set.apply(data, criteria) + + assert [item["model_name"] for item in result] == ["DefaultModel"] + + +@pytest.mark.asyncio +async def test_get_model_types_counts_and_limits(): + raw_data = [ + {"civitai": {"model": {"type": "LoRa"}}}, + {"model_type": "LoRa"}, + {"civitai": {"model": {"type": "LoCon"}}}, + {}, + ] + + class CacheStub: + def __init__(self, raw_data): + self.raw_data = raw_data + + class ScannerStub: + def __init__(self, cache): + self._cache = cache + + async def get_cached_data(self, *_, **__): + return self._cache + + cache = CacheStub(raw_data) + scanner = ScannerStub(cache) + service = DummyService( + model_type="stub", + scanner=scanner, + metadata_class=BaseModelMetadata, + ) + + types = await service.get_model_types(limit=1) + + assert types == [{"type": "lora", "count": 3}] + + @pytest.mark.asyncio 
@pytest.mark.parametrize( "service_cls, extra_fields",