Merge pull request #682 from willmiao/feature/model-type-filter

Feature/model type filter
This commit is contained in:
pixelpaws
2025-11-18 18:51:52 +08:00
committed by GitHub
24 changed files with 380 additions and 19 deletions

View File

@@ -195,6 +195,7 @@
"title": "Modelle filtern", "title": "Modelle filtern",
"baseModel": "Basis-Modell", "baseModel": "Basis-Modell",
"modelTags": "Tags (Top 20)", "modelTags": "Tags (Top 20)",
"modelTypes": "Model Types",
"license": "Lizenz", "license": "Lizenz",
"noCreditRequired": "Kein Credit erforderlich", "noCreditRequired": "Kein Credit erforderlich",
"allowSellingGeneratedContent": "Verkauf erlaubt", "allowSellingGeneratedContent": "Verkauf erlaubt",

View File

@@ -195,6 +195,7 @@
"title": "Filter Models", "title": "Filter Models",
"baseModel": "Base Model", "baseModel": "Base Model",
"modelTags": "Tags (Top 20)", "modelTags": "Tags (Top 20)",
"modelTypes": "Model Types",
"license": "License", "license": "License",
"noCreditRequired": "No Credit Required", "noCreditRequired": "No Credit Required",
"allowSellingGeneratedContent": "Allow Selling", "allowSellingGeneratedContent": "Allow Selling",

View File

@@ -195,6 +195,7 @@
"title": "Filtrar modelos", "title": "Filtrar modelos",
"baseModel": "Modelo base", "baseModel": "Modelo base",
"modelTags": "Etiquetas (Top 20)", "modelTags": "Etiquetas (Top 20)",
"modelTypes": "Model Types",
"license": "Licencia", "license": "Licencia",
"noCreditRequired": "Sin crédito requerido", "noCreditRequired": "Sin crédito requerido",
"allowSellingGeneratedContent": "Venta permitida", "allowSellingGeneratedContent": "Venta permitida",

View File

@@ -195,6 +195,7 @@
"title": "Filtrer les modèles", "title": "Filtrer les modèles",
"baseModel": "Modèle de base", "baseModel": "Modèle de base",
"modelTags": "Tags (Top 20)", "modelTags": "Tags (Top 20)",
"modelTypes": "Model Types",
"license": "Licence", "license": "Licence",
"noCreditRequired": "Crédit non requis", "noCreditRequired": "Crédit non requis",
"allowSellingGeneratedContent": "Vente autorisée", "allowSellingGeneratedContent": "Vente autorisée",

View File

@@ -195,6 +195,7 @@
"title": "סנן מודלים", "title": "סנן מודלים",
"baseModel": "מודל בסיס", "baseModel": "מודל בסיס",
"modelTags": "תגיות (20 המובילות)", "modelTags": "תגיות (20 המובילות)",
"modelTypes": "Model Types",
"license": "רישיון", "license": "רישיון",
"noCreditRequired": "ללא קרדיט נדרש", "noCreditRequired": "ללא קרדיט נדרש",
"allowSellingGeneratedContent": "אפשר מכירה", "allowSellingGeneratedContent": "אפשר מכירה",

View File

@@ -195,6 +195,7 @@
"title": "モデルをフィルタ", "title": "モデルをフィルタ",
"baseModel": "ベースモデル", "baseModel": "ベースモデル",
"modelTags": "タグ上位20", "modelTags": "タグ上位20",
"modelTypes": "Model Types",
"license": "ライセンス", "license": "ライセンス",
"noCreditRequired": "クレジット不要", "noCreditRequired": "クレジット不要",
"allowSellingGeneratedContent": "販売許可", "allowSellingGeneratedContent": "販売許可",

View File

@@ -195,6 +195,7 @@
"title": "모델 필터", "title": "모델 필터",
"baseModel": "베이스 모델", "baseModel": "베이스 모델",
"modelTags": "태그 (상위 20개)", "modelTags": "태그 (상위 20개)",
"modelTypes": "Model Types",
"license": "라이선스", "license": "라이선스",
"noCreditRequired": "크레딧 표기 없음", "noCreditRequired": "크레딧 표기 없음",
"allowSellingGeneratedContent": "판매 허용", "allowSellingGeneratedContent": "판매 허용",

View File

@@ -195,6 +195,7 @@
"title": "Фильтр моделей", "title": "Фильтр моделей",
"baseModel": "Базовая модель", "baseModel": "Базовая модель",
"modelTags": "Теги (Топ 20)", "modelTags": "Теги (Топ 20)",
"modelTypes": "Model Types",
"license": "Лицензия", "license": "Лицензия",
"noCreditRequired": "Без указания авторства", "noCreditRequired": "Без указания авторства",
"allowSellingGeneratedContent": "Продажа разрешена", "allowSellingGeneratedContent": "Продажа разрешена",

View File

@@ -195,6 +195,7 @@
"title": "筛选模型", "title": "筛选模型",
"baseModel": "基础模型", "baseModel": "基础模型",
"modelTags": "标签前20", "modelTags": "标签前20",
"modelTypes": "Model Types",
"license": "许可证", "license": "许可证",
"noCreditRequired": "无需署名", "noCreditRequired": "无需署名",
"allowSellingGeneratedContent": "允许销售", "allowSellingGeneratedContent": "允许销售",

View File

@@ -195,6 +195,7 @@
"title": "篩選模型", "title": "篩選模型",
"baseModel": "基礎模型", "baseModel": "基礎模型",
"modelTags": "標籤(前 20", "modelTags": "標籤(前 20",
"modelTypes": "Model Types",
"license": "授權", "license": "授權",
"noCreditRequired": "無需署名", "noCreditRequired": "無需署名",
"allowSellingGeneratedContent": "允許銷售", "allowSellingGeneratedContent": "允許銷售",

View File

@@ -152,6 +152,8 @@ class ModelListingHandler:
fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true" fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true"
base_models = request.query.getall("base_model", []) base_models = request.query.getall("base_model", [])
model_types = list(request.query.getall("model_type", []))
model_types.extend(request.query.getall("civitai_model_type", []))
# Support legacy ?tag=foo plus new ?tag_include/foo & ?tag_exclude parameters # Support legacy ?tag=foo plus new ?tag_include/foo & ?tag_exclude parameters
legacy_tags = request.query.getall("tag", []) legacy_tags = request.query.getall("tag", [])
if not legacy_tags: if not legacy_tags:
@@ -225,6 +227,7 @@ class ModelListingHandler:
"update_available_only": update_available_only, "update_available_only": update_available_only,
"credit_required": credit_required, "credit_required": credit_required,
"allow_selling_generated_content": allow_selling_generated_content, "allow_selling_generated_content": allow_selling_generated_content,
"model_types": model_types,
**self._parse_specific_params(request), **self._parse_specific_params(request),
} }
@@ -557,6 +560,17 @@ class ModelQueryHandler:
self._logger.error("Error retrieving base models: %s", exc) self._logger.error("Error retrieving base models: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500) return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_types(self, request: web.Request) -> web.Response:
try:
limit = int(request.query.get("limit", "20"))
if limit < 1 or limit > 100:
limit = 20
model_types = await self._service.get_model_types(limit)
return web.json_response({"success": True, "model_types": model_types})
except Exception as exc:
self._logger.error("Error retrieving model types: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def scan_models(self, request: web.Request) -> web.Response: async def scan_models(self, request: web.Request) -> web.Response:
try: try:
full_rebuild = request.query.get("full_rebuild", "false").lower() == "true" full_rebuild = request.query.get("full_rebuild", "false").lower() == "true"
@@ -1579,6 +1593,7 @@ class ModelHandlerSet:
"verify_duplicates": self.management.verify_duplicates, "verify_duplicates": self.management.verify_duplicates,
"get_top_tags": self.query.get_top_tags, "get_top_tags": self.query.get_top_tags,
"get_base_models": self.query.get_base_models, "get_base_models": self.query.get_base_models,
"get_model_types": self.query.get_model_types,
"scan_models": self.query.scan_models, "scan_models": self.query.scan_models,
"get_model_roots": self.query.get_model_roots, "get_model_roots": self.query.get_model_roots,
"get_folders": self.query.get_folders, "get_folders": self.query.get_folders,

View File

@@ -39,6 +39,7 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"), RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"),
RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"), RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"),
RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"), RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"),
RouteDefinition("GET", "/api/lm/{prefix}/model-types", "get_model_types"),
RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"), RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"),
RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"), RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"),
RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"), RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"),

View File

@@ -1,12 +1,21 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
from typing import Dict, List, Optional, Type, TYPE_CHECKING from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
import logging import logging
import os import os
from ..utils.constants import VALID_LORA_TYPES
from ..utils.models import BaseModelMetadata from ..utils.models import BaseModelMetadata
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider from .model_query import (
FilterCriteria,
ModelCacheRepository,
ModelFilterSet,
SearchStrategy,
SettingsProvider,
normalize_civitai_model_type,
resolve_civitai_model_type,
)
from .settings_manager import get_settings_manager from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -59,6 +68,7 @@ class BaseModelService(ABC):
search: str = None, search: str = None,
fuzzy_search: bool = False, fuzzy_search: bool = False,
base_models: list = None, base_models: list = None,
model_types: list = None,
tags: Optional[Dict[str, str]] = None, tags: Optional[Dict[str, str]] = None,
search_options: dict = None, search_options: dict = None,
hash_filters: dict = None, hash_filters: dict = None,
@@ -80,6 +90,7 @@ class BaseModelService(ABC):
sorted_data, sorted_data,
folder=folder, folder=folder,
base_models=base_models, base_models=base_models,
model_types=model_types,
tags=tags, tags=tags,
favorites_only=favorites_only, favorites_only=favorites_only,
search_options=search_options, search_options=search_options,
@@ -149,6 +160,7 @@ class BaseModelService(ABC):
data: List[Dict], data: List[Dict],
folder: str = None, folder: str = None,
base_models: list = None, base_models: list = None,
model_types: list = None,
tags: Optional[Dict[str, str]] = None, tags: Optional[Dict[str, str]] = None,
favorites_only: bool = False, favorites_only: bool = False,
search_options: dict = None, search_options: dict = None,
@@ -158,6 +170,7 @@ class BaseModelService(ABC):
criteria = FilterCriteria( criteria = FilterCriteria(
folder=folder, folder=folder,
base_models=base_models, base_models=base_models,
model_types=model_types,
tags=tags, tags=tags,
favorites_only=favorites_only, favorites_only=favorites_only,
search_options=normalized_options, search_options=normalized_options,
@@ -457,6 +470,25 @@ class BaseModelService(ABC):
"""Get base models sorted by frequency""" """Get base models sorted by frequency"""
return await self.scanner.get_base_models(limit) return await self.scanner.get_base_models(limit)
async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]:
"""Get counts of normalized CivitAI model types present in the cache."""
cache = await self.scanner.get_cached_data()
type_counts: Dict[str, int] = {}
for entry in cache.raw_data:
normalized_type = normalize_civitai_model_type(resolve_civitai_model_type(entry))
if not normalized_type or normalized_type not in VALID_LORA_TYPES:
continue
type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1
sorted_types = sorted(
[{"type": model_type, "count": count} for model_type, count in type_counts.items()],
key=lambda value: value["count"],
reverse=True,
)
return sorted_types[:limit]
def has_hash(self, sha256: str) -> bool: def has_hash(self, sha256: str) -> bool:
"""Check if a model with given hash exists""" """Check if a model with given hash exists"""
return self.scanner.has_hash(sha256) return self.scanner.has_hash(sha256)

View File

@@ -1,12 +1,49 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Protocol, Callable
from ..utils.constants import NSFW_LEVELS from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match as default_fuzzy_match from ..utils.utils import fuzzy_match as default_fuzzy_match
DEFAULT_CIVITAI_MODEL_TYPE = "LORA"
def _coerce_to_str(value: Any) -> Optional[str]:
if value is None:
return None
candidate = str(value).strip()
return candidate if candidate else None
def normalize_civitai_model_type(value: Any) -> Optional[str]:
"""Return a lowercase string suitable for comparisons."""
candidate = _coerce_to_str(value)
return candidate.lower() if candidate else None
def resolve_civitai_model_type(entry: Mapping[str, Any]) -> str:
"""Extract the model type from CivitAI metadata, defaulting to LORA."""
if not isinstance(entry, Mapping):
return DEFAULT_CIVITAI_MODEL_TYPE
civitai = entry.get("civitai")
if isinstance(civitai, Mapping):
civitai_model = civitai.get("model")
if isinstance(civitai_model, Mapping):
model_type = _coerce_to_str(civitai_model.get("type"))
if model_type:
return model_type
model_type = _coerce_to_str(entry.get("model_type"))
if model_type:
return model_type
return DEFAULT_CIVITAI_MODEL_TYPE
class SettingsProvider(Protocol): class SettingsProvider(Protocol):
"""Protocol describing the SettingsManager contract used by query helpers.""" """Protocol describing the SettingsManager contract used by query helpers."""
@@ -31,6 +68,7 @@ class FilterCriteria:
tags: Optional[Dict[str, str]] = None tags: Optional[Dict[str, str]] = None
favorites_only: bool = False favorites_only: bool = False
search_options: Optional[Dict[str, Any]] = None search_options: Optional[Dict[str, Any]] = None
model_types: Optional[Sequence[str]] = None
class ModelCacheRepository: class ModelCacheRepository:
@@ -134,6 +172,19 @@ class ModelFilterSet:
if not any(tag in exclude_tags for tag in (item.get("tags", []) or [])) if not any(tag in exclude_tags for tag in (item.get("tags", []) or []))
] ]
model_types = criteria.model_types or []
normalized_model_types = {
model_type for model_type in (
normalize_civitai_model_type(value) for value in model_types
)
if model_type
}
if normalized_model_types:
items = [
item for item in items
if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types
]
return items return items

View File

@@ -161,6 +161,12 @@ class ModelScanner:
if trained_words: if trained_words:
slim['trainedWords'] = list(trained_words) if isinstance(trained_words, list) else trained_words slim['trainedWords'] = list(trained_words) if isinstance(trained_words, list) else trained_words
civitai_model = civitai.get('model')
if isinstance(civitai_model, Mapping):
model_type_value = civitai_model.get('type')
if model_type_value not in (None, '', []):
slim['model'] = {'type': model_type_value}
return slim or None return slim or None
def _build_cache_entry( def _build_cache_entry(

View File

@@ -5,7 +5,7 @@ import re
import sqlite3 import sqlite3
import threading import threading
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple from typing import Dict, List, Mapping, Optional, Sequence, Tuple
from ..utils.settings_paths import get_project_root, get_settings_dir from ..utils.settings_paths import get_project_root, get_settings_dir
@@ -47,6 +47,7 @@ class PersistentModelCache:
"metadata_source", "metadata_source",
"civitai_id", "civitai_id",
"civitai_model_id", "civitai_model_id",
"civitai_model_type",
"civitai_name", "civitai_name",
"civitai_creator_username", "civitai_creator_username",
"trained_words", "trained_words",
@@ -138,7 +139,8 @@ class PersistentModelCache:
creator_username = row["civitai_creator_username"] creator_username = row["civitai_creator_username"]
civitai: Optional[Dict] = None civitai: Optional[Dict] = None
civitai_has_data = any( civitai_has_data = any(
row[col] is not None for col in ("civitai_id", "civitai_model_id", "civitai_name") row[col] is not None
for col in ("civitai_id", "civitai_model_id", "civitai_model_type", "civitai_name")
) or trained_words or creator_username ) or trained_words or creator_username
if civitai_has_data: if civitai_has_data:
civitai = {} civitai = {}
@@ -152,6 +154,9 @@ class PersistentModelCache:
civitai["trainedWords"] = trained_words civitai["trainedWords"] = trained_words
if creator_username: if creator_username:
civitai.setdefault("creator", {})["username"] = creator_username civitai.setdefault("creator", {})["username"] = creator_username
model_type_value = row["civitai_model_type"]
if model_type_value:
civitai.setdefault("model", {})["type"] = model_type_value
license_value = row["license_flags"] license_value = row["license_flags"]
if license_value is None: if license_value is None:
@@ -443,6 +448,7 @@ class PersistentModelCache:
metadata_source TEXT, metadata_source TEXT,
civitai_id INTEGER, civitai_id INTEGER,
civitai_model_id INTEGER, civitai_model_id INTEGER,
civitai_model_type TEXT,
civitai_name TEXT, civitai_name TEXT,
civitai_creator_username TEXT, civitai_creator_username TEXT,
trained_words TEXT, trained_words TEXT,
@@ -492,6 +498,7 @@ class PersistentModelCache:
required_columns = { required_columns = {
"metadata_source": "TEXT", "metadata_source": "TEXT",
"civitai_creator_username": "TEXT", "civitai_creator_username": "TEXT",
"civitai_model_type": "TEXT",
"civitai_deleted": "INTEGER DEFAULT 0", "civitai_deleted": "INTEGER DEFAULT 0",
# Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57). # Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57).
"license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}", "license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}",
@@ -528,6 +535,13 @@ class PersistentModelCache:
creator_data = civitai.get("creator") if isinstance(civitai, dict) else None creator_data = civitai.get("creator") if isinstance(civitai, dict) else None
if isinstance(creator_data, dict): if isinstance(creator_data, dict):
creator_username = creator_data.get("username") or None creator_username = creator_data.get("username") or None
model_type_value = None
if isinstance(civitai, Mapping):
civitai_model_info = civitai.get("model")
if isinstance(civitai_model_info, Mapping):
candidate_type = civitai_model_info.get("type")
if candidate_type not in (None, "", []):
model_type_value = candidate_type
license_flags = item.get("license_flags") license_flags = item.get("license_flags")
if license_flags is None: if license_flags is None:
@@ -552,6 +566,7 @@ class PersistentModelCache:
metadata_source, metadata_source,
civitai.get("id"), civitai.get("id"),
civitai.get("modelId"), civitai.get("modelId"),
model_type_value,
civitai.get("name"), civitai.get("name"),
creator_username, creator_username,
trained_words_json, trained_words_json,

View File

@@ -848,6 +848,12 @@ export class BaseModelApiClient {
} }
} }
} }
if (pageState.filters.modelTypes && pageState.filters.modelTypes.length > 0) {
pageState.filters.modelTypes.forEach((type) => {
params.append('model_type', type);
});
}
} }
this._addModelSpecificParams(params, pageState); this._addModelSpecificParams(params, pageState);

View File

@@ -2,6 +2,7 @@ import { getCurrentPageState } from '../state/index.js';
import { showToast, updatePanelPositions } from '../utils/uiHelpers.js'; import { showToast, updatePanelPositions } from '../utils/uiHelpers.js';
import { getModelApiClient } from '../api/modelApiFactory.js'; import { getModelApiClient } from '../api/modelApiFactory.js';
import { removeStorageItem, setStorageItem, getStorageItem } from '../utils/storageHelpers.js'; import { removeStorageItem, setStorageItem, getStorageItem } from '../utils/storageHelpers.js';
import { MODEL_TYPE_DISPLAY_NAMES } from '../utils/constants.js';
export class FilterManager { export class FilterManager {
constructor(options = {}) { constructor(options = {}) {
@@ -34,6 +35,10 @@ export class FilterManager {
this.createBaseModelTags(); this.createBaseModelTags();
} }
if (document.getElementById('modelTypeTags')) {
this.createModelTypeTags();
}
// Add click handlers for license filter tags // Add click handlers for license filter tags
this.initializeLicenseFilters(); this.initializeLicenseFilters();
@@ -248,12 +253,86 @@ export class FilterManager {
// Update selections based on stored filters // Update selections based on stored filters
this.updateTagSelections(); this.updateTagSelections();
}
})
.catch(error => {
console.error(`Error fetching base models for ${this.currentPage}:`, error);
baseModelTagsContainer.innerHTML = '<div class="tags-error">Failed to load base models</div>';
});
}
async createModelTypeTags() {
const modelTypeContainer = document.getElementById('modelTypeTags');
if (!modelTypeContainer) return;
modelTypeContainer.innerHTML = '<div class="tags-loading">Loading model types...</div>';
try {
const response = await fetch(`/api/lm/${this.currentPage}/model-types?limit=20`);
if (!response.ok) {
throw new Error('Failed to fetch model types');
}
const data = await response.json();
if (!data.success || !Array.isArray(data.model_types)) {
throw new Error('Invalid response format');
}
const normalizedTypes = data.model_types
.map(entry => {
if (!entry || !entry.type) {
return null;
}
const typeKey = entry.type.toString().trim().toLowerCase();
if (!typeKey || !MODEL_TYPE_DISPLAY_NAMES[typeKey]) {
return null;
}
return {
type: typeKey,
count: Number(entry.count) || 0,
};
})
.filter(Boolean);
if (!normalizedTypes.length) {
modelTypeContainer.innerHTML = '<div class="no-tags">No model types available</div>';
return;
}
modelTypeContainer.innerHTML = '';
normalizedTypes.forEach(({ type, count }) => {
const tag = document.createElement('div');
tag.className = 'filter-tag model-type-tag';
tag.dataset.modelType = type;
tag.innerHTML = `${MODEL_TYPE_DISPLAY_NAMES[type]} <span class="tag-count">${count}</span>`;
if (this.filters.modelTypes.includes(type)) {
tag.classList.add('active');
} }
})
.catch(error => { tag.addEventListener('click', async () => {
console.error(`Error fetching base models for ${this.currentPage}:`, error); const isSelected = this.filters.modelTypes.includes(type);
baseModelTagsContainer.innerHTML = '<div class="tags-error">Failed to load base models</div>'; if (isSelected) {
this.filters.modelTypes = this.filters.modelTypes.filter(value => value !== type);
tag.classList.remove('active');
} else {
this.filters.modelTypes.push(type);
tag.classList.add('active');
}
this.updateActiveFiltersCount();
await this.applyFilters(false);
});
modelTypeContainer.appendChild(tag);
}); });
this.updateModelTypeSelections();
} catch (error) {
console.error('Error loading model types:', error);
modelTypeContainer.innerHTML = '<div class="tags-error">Failed to load model types</div>';
}
} }
toggleFilterPanel() { toggleFilterPanel() {
@@ -309,12 +388,26 @@ export class FilterManager {
// Update license tags // Update license tags
this.updateLicenseSelections(); this.updateLicenseSelections();
this.updateModelTypeSelections();
}
updateModelTypeSelections() {
const typeTags = document.querySelectorAll('.model-type-tag');
typeTags.forEach(tag => {
const modelType = tag.dataset.modelType;
if (this.filters.modelTypes.includes(modelType)) {
tag.classList.add('active');
} else {
tag.classList.remove('active');
}
});
} }
updateActiveFiltersCount() { updateActiveFiltersCount() {
const tagFilterCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0; const tagFilterCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
const licenseFilterCount = this.filters.license ? Object.keys(this.filters.license).length : 0; const licenseFilterCount = this.filters.license ? Object.keys(this.filters.license).length : 0;
const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount; const modelTypeFilterCount = this.filters.modelTypes.length;
const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount + modelTypeFilterCount;
if (this.activeFiltersCount) { if (this.activeFiltersCount) {
if (totalActiveFilters > 0) { if (totalActiveFilters > 0) {
@@ -377,7 +470,8 @@ export class FilterManager {
...this.filters, ...this.filters,
baseModel: [], baseModel: [],
tags: {}, tags: {},
license: {} license: {},
modelTypes: []
}); });
// Update state // Update state
@@ -437,7 +531,13 @@ export class FilterManager {
hasActiveFilters() { hasActiveFilters() {
const tagCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0; const tagCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
const licenseCount = this.filters.license ? Object.keys(this.filters.license).length : 0; const licenseCount = this.filters.license ? Object.keys(this.filters.license).length : 0;
return this.filters.baseModel.length > 0 || tagCount > 0 || licenseCount > 0; const modelTypeCount = this.filters.modelTypes.length;
return (
this.filters.baseModel.length > 0 ||
tagCount > 0 ||
licenseCount > 0 ||
modelTypeCount > 0
);
} }
initializeFilters(existingFilters = {}) { initializeFilters(existingFilters = {}) {
@@ -446,7 +546,8 @@ export class FilterManager {
...source, ...source,
baseModel: Array.isArray(source.baseModel) ? [...source.baseModel] : [], baseModel: Array.isArray(source.baseModel) ? [...source.baseModel] : [],
tags: this.normalizeTagFilters(source.tags), tags: this.normalizeTagFilters(source.tags),
license: this.normalizeLicenseFilters(source.license) license: this.normalizeLicenseFilters(source.license),
modelTypes: this.normalizeModelTypeFilters(source.modelTypes)
}; };
} }
@@ -496,12 +597,35 @@ export class FilterManager {
return normalized; return normalized;
} }
normalizeModelTypeFilters(modelTypes) {
if (!Array.isArray(modelTypes)) {
return [];
}
const seen = new Set();
return modelTypes.reduce((acc, type) => {
if (typeof type !== 'string') {
return acc;
}
const normalized = type.trim().toLowerCase();
if (!normalized || seen.has(normalized)) {
return acc;
}
seen.add(normalized);
acc.push(normalized);
return acc;
}, []);
}
cloneFilters() { cloneFilters() {
return { return {
...this.filters, ...this.filters,
baseModel: [...(this.filters.baseModel || [])], baseModel: [...(this.filters.baseModel || [])],
tags: { ...(this.filters.tags || {}) }, tags: { ...(this.filters.tags || {}) },
license: { ...(this.filters.license || {}) } license: { ...(this.filters.license || {}) },
modelTypes: [...(this.filters.modelTypes || [])]
}; };
} }

View File

@@ -79,7 +79,8 @@ export const state = {
filters: { filters: {
baseModel: [], baseModel: [],
tags: {}, tags: {},
license: {} license: {},
modelTypes: []
}, },
bulkMode: false, bulkMode: false,
selectedLoras: new Set(), selectedLoras: new Set(),
@@ -105,6 +106,7 @@ export const state = {
baseModel: [], baseModel: [],
tags: {}, tags: {},
license: {}, license: {},
modelTypes: [],
search: '' search: ''
}, },
pageSize: 20, pageSize: 20,
@@ -131,7 +133,8 @@ export const state = {
filters: { filters: {
baseModel: [], baseModel: [],
tags: {}, tags: {},
license: {} license: {},
modelTypes: []
}, },
modelType: 'checkpoint', // 'checkpoint' or 'diffusion_model' modelType: 'checkpoint', // 'checkpoint' or 'diffusion_model'
bulkMode: false, bulkMode: false,
@@ -161,7 +164,8 @@ export const state = {
filters: { filters: {
baseModel: [], baseModel: [],
tags: {}, tags: {},
license: {} license: {},
modelTypes: []
}, },
bulkMode: false, bulkMode: false,
selectedModels: new Set(), selectedModels: new Set(),

View File

@@ -55,6 +55,12 @@ export const BASE_MODELS = {
UNKNOWN: "Other" UNKNOWN: "Other"
}; };
export const MODEL_TYPE_DISPLAY_NAMES = {
lora: "LoRA",
locon: "LyCORIS",
dora: "DoRA",
};
export const BASE_MODEL_ABBREVIATIONS = { export const BASE_MODEL_ABBREVIATIONS = {
// Stable Diffusion 1.x models // Stable Diffusion 1.x models
[BASE_MODELS.SD_1_4]: 'SD1', [BASE_MODELS.SD_1_4]: 'SD1',

View File

@@ -139,6 +139,14 @@
<div class="tags-loading">{{ t('common.status.loading') }}</div> <div class="tags-loading">{{ t('common.status.loading') }}</div>
</div> </div>
</div> </div>
{% if current_page == 'loras' %}
<div class="filter-section">
<h4>{{ t('header.filter.modelTypes') }}</h4>
<div class="filter-tags" id="modelTypeTags">
<div class="tags-loading">{{ t('common.status.loading') }}</div>
</div>
</div>
{% endif %}
<div class="filter-section"> <div class="filter-section">
<h4>{{ t('header.filter.license') }}</h4> <h4>{{ t('header.filter.license') }}</h4>
<div class="filter-tags"> <div class="filter-tags">
@@ -156,4 +164,3 @@
</button> </button>
</div> </div>
</div> </div>

View File

@@ -212,6 +212,7 @@ class MockModelService:
self.model_type = "test-model" self.model_type = "test-model"
self.paginated_items: List[Dict[str, Any]] = [] self.paginated_items: List[Dict[str, Any]] = []
self.formatted: List[Dict[str, Any]] = [] self.formatted: List[Dict[str, Any]] = []
self.model_types: List[Dict[str, Any]] = []
async def get_paginated_data(self, **params: Any) -> Dict[str, Any]: async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
items = [dict(item) for item in self.paginated_items] items = [dict(item) for item in self.paginated_items]
@@ -257,6 +258,9 @@ class MockModelService:
async def get_relative_paths(self, *_args, **_kwargs): # pragma: no cover async def get_relative_paths(self, *_args, **_kwargs): # pragma: no cover
return [] return []
async def get_model_types(self, limit: int = 20):
return list(self.model_types)[:limit]
def has_hash(self, *_args, **_kwargs): # pragma: no cover def has_hash(self, *_args, **_kwargs): # pragma: no cover
return False return False
@@ -283,4 +287,3 @@ def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockS
def mock_service(mock_scanner: MockScanner) -> MockModelService: def mock_service(mock_scanner: MockScanner) -> MockModelService:
return MockModelService(scanner=mock_scanner) return MockModelService(scanner=mock_scanner)

View File

@@ -185,6 +185,26 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner):
asyncio.run(scenario()) asyncio.run(scenario())
def test_model_types_endpoint_returns_counts(mock_service, mock_scanner):
mock_service.model_types = [
{"type": "LoRa", "count": 3},
{"type": "Checkpoint", "count": 1},
]
async def scenario():
client = await create_test_client(mock_service)
try:
response = await client.get("/api/lm/test-models/model-types?limit=1")
payload = await response.json()
assert response.status == 200
assert payload["model_types"] == mock_service.model_types[:1]
finally:
await client.close()
asyncio.run(scenario())
def test_routes_return_service_not_ready_when_unattached(): def test_routes_return_service_not_ready_when_unattached():
async def scenario(): async def scenario():
client = await create_test_client(None) client = await create_test_client(None)

View File

@@ -776,6 +776,67 @@ def test_model_filter_set_supports_legacy_tag_arrays():
assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"] assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"]
def test_model_filter_set_filters_by_model_types():
settings = StubSettings({})
filter_set = ModelFilterSet(settings)
data = [
{"model_name": "LoConModel", "civitai": {"model": {"type": "LoCon"}}},
{"model_name": "LoRaModel", "civitai": {"model": {"type": "LoRa"}}},
]
criteria = FilterCriteria(model_types=["locon"])
result = filter_set.apply(data, criteria)
assert [item["model_name"] for item in result] == ["LoConModel"]
def test_model_filter_set_defaults_missing_model_type_to_lora():
settings = StubSettings({})
filter_set = ModelFilterSet(settings)
data = [
{"model_name": "DefaultModel"},
{"model_name": "CheckpointModel", "civitai": {"model": {"type": "checkpoint"}}},
]
criteria = FilterCriteria(model_types=["lora"])
result = filter_set.apply(data, criteria)
assert [item["model_name"] for item in result] == ["DefaultModel"]
@pytest.mark.asyncio
async def test_get_model_types_counts_and_limits():
raw_data = [
{"civitai": {"model": {"type": "LoRa"}}},
{"model_type": "LoRa"},
{"civitai": {"model": {"type": "LoCon"}}},
{},
]
class CacheStub:
def __init__(self, raw_data):
self.raw_data = raw_data
class ScannerStub:
def __init__(self, cache):
self._cache = cache
async def get_cached_data(self, *_, **__):
return self._cache
cache = CacheStub(raw_data)
scanner = ScannerStub(cache)
service = DummyService(
model_type="stub",
scanner=scanner,
metadata_class=BaseModelMetadata,
)
types = await service.get_model_types(limit=1)
assert types == [{"type": "lora", "count": 3}]
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"service_cls, extra_fields", "service_cls, extra_fields",