mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Merge pull request #682 from willmiao/feature/model-type-filter
Feature/model type filter
This commit is contained in:
@@ -195,6 +195,7 @@
|
||||
"title": "Modelle filtern",
|
||||
"baseModel": "Basis-Modell",
|
||||
"modelTags": "Tags (Top 20)",
|
||||
"modelTypes": "Model Types",
|
||||
"license": "Lizenz",
|
||||
"noCreditRequired": "Kein Credit erforderlich",
|
||||
"allowSellingGeneratedContent": "Verkauf erlaubt",
|
||||
|
||||
@@ -195,6 +195,7 @@
|
||||
"title": "Filter Models",
|
||||
"baseModel": "Base Model",
|
||||
"modelTags": "Tags (Top 20)",
|
||||
"modelTypes": "Model Types",
|
||||
"license": "License",
|
||||
"noCreditRequired": "No Credit Required",
|
||||
"allowSellingGeneratedContent": "Allow Selling",
|
||||
|
||||
@@ -195,6 +195,7 @@
|
||||
"title": "Filtrar modelos",
|
||||
"baseModel": "Modelo base",
|
||||
"modelTags": "Etiquetas (Top 20)",
|
||||
"modelTypes": "Model Types",
|
||||
"license": "Licencia",
|
||||
"noCreditRequired": "Sin crédito requerido",
|
||||
"allowSellingGeneratedContent": "Venta permitida",
|
||||
|
||||
@@ -195,6 +195,7 @@
|
||||
"title": "Filtrer les modèles",
|
||||
"baseModel": "Modèle de base",
|
||||
"modelTags": "Tags (Top 20)",
|
||||
"modelTypes": "Model Types",
|
||||
"license": "Licence",
|
||||
"noCreditRequired": "Crédit non requis",
|
||||
"allowSellingGeneratedContent": "Vente autorisée",
|
||||
|
||||
@@ -195,6 +195,7 @@
|
||||
"title": "סנן מודלים",
|
||||
"baseModel": "מודל בסיס",
|
||||
"modelTags": "תגיות (20 המובילות)",
|
||||
"modelTypes": "Model Types",
|
||||
"license": "רישיון",
|
||||
"noCreditRequired": "ללא קרדיט נדרש",
|
||||
"allowSellingGeneratedContent": "אפשר מכירה",
|
||||
|
||||
@@ -195,6 +195,7 @@
|
||||
"title": "モデルをフィルタ",
|
||||
"baseModel": "ベースモデル",
|
||||
"modelTags": "タグ(上位20)",
|
||||
"modelTypes": "Model Types",
|
||||
"license": "ライセンス",
|
||||
"noCreditRequired": "クレジット不要",
|
||||
"allowSellingGeneratedContent": "販売許可",
|
||||
|
||||
@@ -195,6 +195,7 @@
|
||||
"title": "모델 필터",
|
||||
"baseModel": "베이스 모델",
|
||||
"modelTags": "태그 (상위 20개)",
|
||||
"modelTypes": "Model Types",
|
||||
"license": "라이선스",
|
||||
"noCreditRequired": "크레딧 표기 없음",
|
||||
"allowSellingGeneratedContent": "판매 허용",
|
||||
|
||||
@@ -195,6 +195,7 @@
|
||||
"title": "Фильтр моделей",
|
||||
"baseModel": "Базовая модель",
|
||||
"modelTags": "Теги (Топ 20)",
|
||||
"modelTypes": "Model Types",
|
||||
"license": "Лицензия",
|
||||
"noCreditRequired": "Без указания авторства",
|
||||
"allowSellingGeneratedContent": "Продажа разрешена",
|
||||
|
||||
@@ -195,6 +195,7 @@
|
||||
"title": "筛选模型",
|
||||
"baseModel": "基础模型",
|
||||
"modelTags": "标签(前20)",
|
||||
"modelTypes": "Model Types",
|
||||
"license": "许可证",
|
||||
"noCreditRequired": "无需署名",
|
||||
"allowSellingGeneratedContent": "允许销售",
|
||||
|
||||
@@ -195,6 +195,7 @@
|
||||
"title": "篩選模型",
|
||||
"baseModel": "基礎模型",
|
||||
"modelTags": "標籤(前 20)",
|
||||
"modelTypes": "Model Types",
|
||||
"license": "授權",
|
||||
"noCreditRequired": "無需署名",
|
||||
"allowSellingGeneratedContent": "允許銷售",
|
||||
|
||||
@@ -152,6 +152,8 @@ class ModelListingHandler:
|
||||
fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true"
|
||||
|
||||
base_models = request.query.getall("base_model", [])
|
||||
model_types = list(request.query.getall("model_type", []))
|
||||
model_types.extend(request.query.getall("civitai_model_type", []))
|
||||
# Support legacy ?tag=foo plus new ?tag_include/foo & ?tag_exclude parameters
|
||||
legacy_tags = request.query.getall("tag", [])
|
||||
if not legacy_tags:
|
||||
@@ -225,6 +227,7 @@ class ModelListingHandler:
|
||||
"update_available_only": update_available_only,
|
||||
"credit_required": credit_required,
|
||||
"allow_selling_generated_content": allow_selling_generated_content,
|
||||
"model_types": model_types,
|
||||
**self._parse_specific_params(request),
|
||||
}
|
||||
|
||||
@@ -557,6 +560,17 @@ class ModelQueryHandler:
|
||||
self._logger.error("Error retrieving base models: %s", exc)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_model_types(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
limit = int(request.query.get("limit", "20"))
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20
|
||||
model_types = await self._service.get_model_types(limit)
|
||||
return web.json_response({"success": True, "model_types": model_types})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving model types: %s", exc)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def scan_models(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
full_rebuild = request.query.get("full_rebuild", "false").lower() == "true"
|
||||
@@ -1579,6 +1593,7 @@ class ModelHandlerSet:
|
||||
"verify_duplicates": self.management.verify_duplicates,
|
||||
"get_top_tags": self.query.get_top_tags,
|
||||
"get_base_models": self.query.get_base_models,
|
||||
"get_model_types": self.query.get_model_types,
|
||||
"scan_models": self.query.scan_models,
|
||||
"get_model_roots": self.query.get_model_roots,
|
||||
"get_folders": self.query.get_folders,
|
||||
|
||||
@@ -39,6 +39,7 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/model-types", "get_model_types"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"),
|
||||
RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"),
|
||||
|
||||
@@ -1,12 +1,21 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Type, TYPE_CHECKING
|
||||
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
|
||||
import logging
|
||||
import os
|
||||
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
|
||||
from .model_query import (
|
||||
FilterCriteria,
|
||||
ModelCacheRepository,
|
||||
ModelFilterSet,
|
||||
SearchStrategy,
|
||||
SettingsProvider,
|
||||
normalize_civitai_model_type,
|
||||
resolve_civitai_model_type,
|
||||
)
|
||||
from .settings_manager import get_settings_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -59,6 +68,7 @@ class BaseModelService(ABC):
|
||||
search: str = None,
|
||||
fuzzy_search: bool = False,
|
||||
base_models: list = None,
|
||||
model_types: list = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
search_options: dict = None,
|
||||
hash_filters: dict = None,
|
||||
@@ -80,6 +90,7 @@ class BaseModelService(ABC):
|
||||
sorted_data,
|
||||
folder=folder,
|
||||
base_models=base_models,
|
||||
model_types=model_types,
|
||||
tags=tags,
|
||||
favorites_only=favorites_only,
|
||||
search_options=search_options,
|
||||
@@ -149,6 +160,7 @@ class BaseModelService(ABC):
|
||||
data: List[Dict],
|
||||
folder: str = None,
|
||||
base_models: list = None,
|
||||
model_types: list = None,
|
||||
tags: Optional[Dict[str, str]] = None,
|
||||
favorites_only: bool = False,
|
||||
search_options: dict = None,
|
||||
@@ -158,6 +170,7 @@ class BaseModelService(ABC):
|
||||
criteria = FilterCriteria(
|
||||
folder=folder,
|
||||
base_models=base_models,
|
||||
model_types=model_types,
|
||||
tags=tags,
|
||||
favorites_only=favorites_only,
|
||||
search_options=normalized_options,
|
||||
@@ -456,6 +469,25 @@ class BaseModelService(ABC):
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get base models sorted by frequency"""
|
||||
return await self.scanner.get_base_models(limit)
|
||||
|
||||
async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]:
|
||||
"""Get counts of normalized CivitAI model types present in the cache."""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
type_counts: Dict[str, int] = {}
|
||||
for entry in cache.raw_data:
|
||||
normalized_type = normalize_civitai_model_type(resolve_civitai_model_type(entry))
|
||||
if not normalized_type or normalized_type not in VALID_LORA_TYPES:
|
||||
continue
|
||||
type_counts[normalized_type] = type_counts.get(normalized_type, 0) + 1
|
||||
|
||||
sorted_types = sorted(
|
||||
[{"type": model_type, "count": count} for model_type, count in type_counts.items()],
|
||||
key=lambda value: value["count"],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
return sorted_types[:limit]
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
"""Check if a model with given hash exists"""
|
||||
|
||||
@@ -1,12 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable
|
||||
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Protocol, Callable
|
||||
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.utils import fuzzy_match as default_fuzzy_match
|
||||
|
||||
|
||||
DEFAULT_CIVITAI_MODEL_TYPE = "LORA"
|
||||
|
||||
|
||||
def _coerce_to_str(value: Any) -> Optional[str]:
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
candidate = str(value).strip()
|
||||
return candidate if candidate else None
|
||||
|
||||
|
||||
def normalize_civitai_model_type(value: Any) -> Optional[str]:
|
||||
"""Return a lowercase string suitable for comparisons."""
|
||||
candidate = _coerce_to_str(value)
|
||||
return candidate.lower() if candidate else None
|
||||
|
||||
|
||||
def resolve_civitai_model_type(entry: Mapping[str, Any]) -> str:
|
||||
"""Extract the model type from CivitAI metadata, defaulting to LORA."""
|
||||
if not isinstance(entry, Mapping):
|
||||
return DEFAULT_CIVITAI_MODEL_TYPE
|
||||
|
||||
civitai = entry.get("civitai")
|
||||
if isinstance(civitai, Mapping):
|
||||
civitai_model = civitai.get("model")
|
||||
if isinstance(civitai_model, Mapping):
|
||||
model_type = _coerce_to_str(civitai_model.get("type"))
|
||||
if model_type:
|
||||
return model_type
|
||||
|
||||
model_type = _coerce_to_str(entry.get("model_type"))
|
||||
if model_type:
|
||||
return model_type
|
||||
|
||||
return DEFAULT_CIVITAI_MODEL_TYPE
|
||||
|
||||
|
||||
class SettingsProvider(Protocol):
|
||||
"""Protocol describing the SettingsManager contract used by query helpers."""
|
||||
|
||||
@@ -31,6 +68,7 @@ class FilterCriteria:
|
||||
tags: Optional[Dict[str, str]] = None
|
||||
favorites_only: bool = False
|
||||
search_options: Optional[Dict[str, Any]] = None
|
||||
model_types: Optional[Sequence[str]] = None
|
||||
|
||||
|
||||
class ModelCacheRepository:
|
||||
@@ -134,6 +172,19 @@ class ModelFilterSet:
|
||||
if not any(tag in exclude_tags for tag in (item.get("tags", []) or []))
|
||||
]
|
||||
|
||||
model_types = criteria.model_types or []
|
||||
normalized_model_types = {
|
||||
model_type for model_type in (
|
||||
normalize_civitai_model_type(value) for value in model_types
|
||||
)
|
||||
if model_type
|
||||
}
|
||||
if normalized_model_types:
|
||||
items = [
|
||||
item for item in items
|
||||
if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types
|
||||
]
|
||||
|
||||
return items
|
||||
|
||||
|
||||
|
||||
@@ -161,6 +161,12 @@ class ModelScanner:
|
||||
if trained_words:
|
||||
slim['trainedWords'] = list(trained_words) if isinstance(trained_words, list) else trained_words
|
||||
|
||||
civitai_model = civitai.get('model')
|
||||
if isinstance(civitai_model, Mapping):
|
||||
model_type_value = civitai_model.get('type')
|
||||
if model_type_value not in (None, '', []):
|
||||
slim['model'] = {'type': model_type_value}
|
||||
|
||||
return slim or None
|
||||
|
||||
def _build_cache_entry(
|
||||
|
||||
@@ -5,7 +5,7 @@ import re
|
||||
import sqlite3
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
from typing import Dict, List, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
from ..utils.settings_paths import get_project_root, get_settings_dir
|
||||
|
||||
@@ -47,6 +47,7 @@ class PersistentModelCache:
|
||||
"metadata_source",
|
||||
"civitai_id",
|
||||
"civitai_model_id",
|
||||
"civitai_model_type",
|
||||
"civitai_name",
|
||||
"civitai_creator_username",
|
||||
"trained_words",
|
||||
@@ -138,7 +139,8 @@ class PersistentModelCache:
|
||||
creator_username = row["civitai_creator_username"]
|
||||
civitai: Optional[Dict] = None
|
||||
civitai_has_data = any(
|
||||
row[col] is not None for col in ("civitai_id", "civitai_model_id", "civitai_name")
|
||||
row[col] is not None
|
||||
for col in ("civitai_id", "civitai_model_id", "civitai_model_type", "civitai_name")
|
||||
) or trained_words or creator_username
|
||||
if civitai_has_data:
|
||||
civitai = {}
|
||||
@@ -152,6 +154,9 @@ class PersistentModelCache:
|
||||
civitai["trainedWords"] = trained_words
|
||||
if creator_username:
|
||||
civitai.setdefault("creator", {})["username"] = creator_username
|
||||
model_type_value = row["civitai_model_type"]
|
||||
if model_type_value:
|
||||
civitai.setdefault("model", {})["type"] = model_type_value
|
||||
|
||||
license_value = row["license_flags"]
|
||||
if license_value is None:
|
||||
@@ -443,6 +448,7 @@ class PersistentModelCache:
|
||||
metadata_source TEXT,
|
||||
civitai_id INTEGER,
|
||||
civitai_model_id INTEGER,
|
||||
civitai_model_type TEXT,
|
||||
civitai_name TEXT,
|
||||
civitai_creator_username TEXT,
|
||||
trained_words TEXT,
|
||||
@@ -492,6 +498,7 @@ class PersistentModelCache:
|
||||
required_columns = {
|
||||
"metadata_source": "TEXT",
|
||||
"civitai_creator_username": "TEXT",
|
||||
"civitai_model_type": "TEXT",
|
||||
"civitai_deleted": "INTEGER DEFAULT 0",
|
||||
# Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57).
|
||||
"license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}",
|
||||
@@ -528,6 +535,13 @@ class PersistentModelCache:
|
||||
creator_data = civitai.get("creator") if isinstance(civitai, dict) else None
|
||||
if isinstance(creator_data, dict):
|
||||
creator_username = creator_data.get("username") or None
|
||||
model_type_value = None
|
||||
if isinstance(civitai, Mapping):
|
||||
civitai_model_info = civitai.get("model")
|
||||
if isinstance(civitai_model_info, Mapping):
|
||||
candidate_type = civitai_model_info.get("type")
|
||||
if candidate_type not in (None, "", []):
|
||||
model_type_value = candidate_type
|
||||
|
||||
license_flags = item.get("license_flags")
|
||||
if license_flags is None:
|
||||
@@ -552,6 +566,7 @@ class PersistentModelCache:
|
||||
metadata_source,
|
||||
civitai.get("id"),
|
||||
civitai.get("modelId"),
|
||||
model_type_value,
|
||||
civitai.get("name"),
|
||||
creator_username,
|
||||
trained_words_json,
|
||||
|
||||
@@ -848,6 +848,12 @@ export class BaseModelApiClient {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (pageState.filters.modelTypes && pageState.filters.modelTypes.length > 0) {
|
||||
pageState.filters.modelTypes.forEach((type) => {
|
||||
params.append('model_type', type);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
this._addModelSpecificParams(params, pageState);
|
||||
|
||||
@@ -2,6 +2,7 @@ import { getCurrentPageState } from '../state/index.js';
|
||||
import { showToast, updatePanelPositions } from '../utils/uiHelpers.js';
|
||||
import { getModelApiClient } from '../api/modelApiFactory.js';
|
||||
import { removeStorageItem, setStorageItem, getStorageItem } from '../utils/storageHelpers.js';
|
||||
import { MODEL_TYPE_DISPLAY_NAMES } from '../utils/constants.js';
|
||||
|
||||
export class FilterManager {
|
||||
constructor(options = {}) {
|
||||
@@ -34,6 +35,10 @@ export class FilterManager {
|
||||
this.createBaseModelTags();
|
||||
}
|
||||
|
||||
if (document.getElementById('modelTypeTags')) {
|
||||
this.createModelTypeTags();
|
||||
}
|
||||
|
||||
// Add click handlers for license filter tags
|
||||
this.initializeLicenseFilters();
|
||||
|
||||
@@ -248,12 +253,86 @@ export class FilterManager {
|
||||
|
||||
// Update selections based on stored filters
|
||||
this.updateTagSelections();
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
console.error(`Error fetching base models for ${this.currentPage}:`, error);
|
||||
baseModelTagsContainer.innerHTML = '<div class="tags-error">Failed to load base models</div>';
|
||||
});
|
||||
}
|
||||
|
||||
async createModelTypeTags() {
|
||||
const modelTypeContainer = document.getElementById('modelTypeTags');
|
||||
if (!modelTypeContainer) return;
|
||||
|
||||
modelTypeContainer.innerHTML = '<div class="tags-loading">Loading model types...</div>';
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/lm/${this.currentPage}/model-types?limit=20`);
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to fetch model types');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
if (!data.success || !Array.isArray(data.model_types)) {
|
||||
throw new Error('Invalid response format');
|
||||
}
|
||||
|
||||
const normalizedTypes = data.model_types
|
||||
.map(entry => {
|
||||
if (!entry || !entry.type) {
|
||||
return null;
|
||||
}
|
||||
const typeKey = entry.type.toString().trim().toLowerCase();
|
||||
if (!typeKey || !MODEL_TYPE_DISPLAY_NAMES[typeKey]) {
|
||||
return null;
|
||||
}
|
||||
return {
|
||||
type: typeKey,
|
||||
count: Number(entry.count) || 0,
|
||||
};
|
||||
})
|
||||
.filter(Boolean);
|
||||
|
||||
if (!normalizedTypes.length) {
|
||||
modelTypeContainer.innerHTML = '<div class="no-tags">No model types available</div>';
|
||||
return;
|
||||
}
|
||||
|
||||
modelTypeContainer.innerHTML = '';
|
||||
|
||||
normalizedTypes.forEach(({ type, count }) => {
|
||||
const tag = document.createElement('div');
|
||||
tag.className = 'filter-tag model-type-tag';
|
||||
tag.dataset.modelType = type;
|
||||
tag.innerHTML = `${MODEL_TYPE_DISPLAY_NAMES[type]} <span class="tag-count">${count}</span>`;
|
||||
|
||||
if (this.filters.modelTypes.includes(type)) {
|
||||
tag.classList.add('active');
|
||||
}
|
||||
})
|
||||
.catch(error => {
|
||||
console.error(`Error fetching base models for ${this.currentPage}:`, error);
|
||||
baseModelTagsContainer.innerHTML = '<div class="tags-error">Failed to load base models</div>';
|
||||
|
||||
tag.addEventListener('click', async () => {
|
||||
const isSelected = this.filters.modelTypes.includes(type);
|
||||
if (isSelected) {
|
||||
this.filters.modelTypes = this.filters.modelTypes.filter(value => value !== type);
|
||||
tag.classList.remove('active');
|
||||
} else {
|
||||
this.filters.modelTypes.push(type);
|
||||
tag.classList.add('active');
|
||||
}
|
||||
|
||||
this.updateActiveFiltersCount();
|
||||
await this.applyFilters(false);
|
||||
});
|
||||
|
||||
modelTypeContainer.appendChild(tag);
|
||||
});
|
||||
|
||||
this.updateModelTypeSelections();
|
||||
} catch (error) {
|
||||
console.error('Error loading model types:', error);
|
||||
modelTypeContainer.innerHTML = '<div class="tags-error">Failed to load model types</div>';
|
||||
}
|
||||
}
|
||||
|
||||
toggleFilterPanel() {
|
||||
@@ -309,12 +388,26 @@ export class FilterManager {
|
||||
|
||||
// Update license tags
|
||||
this.updateLicenseSelections();
|
||||
this.updateModelTypeSelections();
|
||||
}
|
||||
|
||||
updateModelTypeSelections() {
|
||||
const typeTags = document.querySelectorAll('.model-type-tag');
|
||||
typeTags.forEach(tag => {
|
||||
const modelType = tag.dataset.modelType;
|
||||
if (this.filters.modelTypes.includes(modelType)) {
|
||||
tag.classList.add('active');
|
||||
} else {
|
||||
tag.classList.remove('active');
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
updateActiveFiltersCount() {
|
||||
const tagFilterCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
|
||||
const licenseFilterCount = this.filters.license ? Object.keys(this.filters.license).length : 0;
|
||||
const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount;
|
||||
const modelTypeFilterCount = this.filters.modelTypes.length;
|
||||
const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount + modelTypeFilterCount;
|
||||
|
||||
if (this.activeFiltersCount) {
|
||||
if (totalActiveFilters > 0) {
|
||||
@@ -377,7 +470,8 @@ export class FilterManager {
|
||||
...this.filters,
|
||||
baseModel: [],
|
||||
tags: {},
|
||||
license: {}
|
||||
license: {},
|
||||
modelTypes: []
|
||||
});
|
||||
|
||||
// Update state
|
||||
@@ -437,7 +531,13 @@ export class FilterManager {
|
||||
hasActiveFilters() {
|
||||
const tagCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
|
||||
const licenseCount = this.filters.license ? Object.keys(this.filters.license).length : 0;
|
||||
return this.filters.baseModel.length > 0 || tagCount > 0 || licenseCount > 0;
|
||||
const modelTypeCount = this.filters.modelTypes.length;
|
||||
return (
|
||||
this.filters.baseModel.length > 0 ||
|
||||
tagCount > 0 ||
|
||||
licenseCount > 0 ||
|
||||
modelTypeCount > 0
|
||||
);
|
||||
}
|
||||
|
||||
initializeFilters(existingFilters = {}) {
|
||||
@@ -446,7 +546,8 @@ export class FilterManager {
|
||||
...source,
|
||||
baseModel: Array.isArray(source.baseModel) ? [...source.baseModel] : [],
|
||||
tags: this.normalizeTagFilters(source.tags),
|
||||
license: this.normalizeLicenseFilters(source.license)
|
||||
license: this.normalizeLicenseFilters(source.license),
|
||||
modelTypes: this.normalizeModelTypeFilters(source.modelTypes)
|
||||
};
|
||||
}
|
||||
|
||||
@@ -496,12 +597,35 @@ export class FilterManager {
|
||||
return normalized;
|
||||
}
|
||||
|
||||
normalizeModelTypeFilters(modelTypes) {
|
||||
if (!Array.isArray(modelTypes)) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const seen = new Set();
|
||||
return modelTypes.reduce((acc, type) => {
|
||||
if (typeof type !== 'string') {
|
||||
return acc;
|
||||
}
|
||||
|
||||
const normalized = type.trim().toLowerCase();
|
||||
if (!normalized || seen.has(normalized)) {
|
||||
return acc;
|
||||
}
|
||||
|
||||
seen.add(normalized);
|
||||
acc.push(normalized);
|
||||
return acc;
|
||||
}, []);
|
||||
}
|
||||
|
||||
cloneFilters() {
|
||||
return {
|
||||
...this.filters,
|
||||
baseModel: [...(this.filters.baseModel || [])],
|
||||
tags: { ...(this.filters.tags || {}) },
|
||||
license: { ...(this.filters.license || {}) }
|
||||
license: { ...(this.filters.license || {}) },
|
||||
modelTypes: [...(this.filters.modelTypes || [])]
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -79,7 +79,8 @@ export const state = {
|
||||
filters: {
|
||||
baseModel: [],
|
||||
tags: {},
|
||||
license: {}
|
||||
license: {},
|
||||
modelTypes: []
|
||||
},
|
||||
bulkMode: false,
|
||||
selectedLoras: new Set(),
|
||||
@@ -105,6 +106,7 @@ export const state = {
|
||||
baseModel: [],
|
||||
tags: {},
|
||||
license: {},
|
||||
modelTypes: [],
|
||||
search: ''
|
||||
},
|
||||
pageSize: 20,
|
||||
@@ -131,7 +133,8 @@ export const state = {
|
||||
filters: {
|
||||
baseModel: [],
|
||||
tags: {},
|
||||
license: {}
|
||||
license: {},
|
||||
modelTypes: []
|
||||
},
|
||||
modelType: 'checkpoint', // 'checkpoint' or 'diffusion_model'
|
||||
bulkMode: false,
|
||||
@@ -161,7 +164,8 @@ export const state = {
|
||||
filters: {
|
||||
baseModel: [],
|
||||
tags: {},
|
||||
license: {}
|
||||
license: {},
|
||||
modelTypes: []
|
||||
},
|
||||
bulkMode: false,
|
||||
selectedModels: new Set(),
|
||||
|
||||
@@ -55,6 +55,12 @@ export const BASE_MODELS = {
|
||||
UNKNOWN: "Other"
|
||||
};
|
||||
|
||||
export const MODEL_TYPE_DISPLAY_NAMES = {
|
||||
lora: "LoRA",
|
||||
locon: "LyCORIS",
|
||||
dora: "DoRA",
|
||||
};
|
||||
|
||||
export const BASE_MODEL_ABBREVIATIONS = {
|
||||
// Stable Diffusion 1.x models
|
||||
[BASE_MODELS.SD_1_4]: 'SD1',
|
||||
|
||||
@@ -139,6 +139,14 @@
|
||||
<div class="tags-loading">{{ t('common.status.loading') }}</div>
|
||||
</div>
|
||||
</div>
|
||||
{% if current_page == 'loras' %}
|
||||
<div class="filter-section">
|
||||
<h4>{{ t('header.filter.modelTypes') }}</h4>
|
||||
<div class="filter-tags" id="modelTypeTags">
|
||||
<div class="tags-loading">{{ t('common.status.loading') }}</div>
|
||||
</div>
|
||||
</div>
|
||||
{% endif %}
|
||||
<div class="filter-section">
|
||||
<h4>{{ t('header.filter.license') }}</h4>
|
||||
<div class="filter-tags">
|
||||
@@ -156,4 +164,3 @@
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -212,6 +212,7 @@ class MockModelService:
|
||||
self.model_type = "test-model"
|
||||
self.paginated_items: List[Dict[str, Any]] = []
|
||||
self.formatted: List[Dict[str, Any]] = []
|
||||
self.model_types: List[Dict[str, Any]] = []
|
||||
|
||||
async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
|
||||
items = [dict(item) for item in self.paginated_items]
|
||||
@@ -257,6 +258,9 @@ class MockModelService:
|
||||
async def get_relative_paths(self, *_args, **_kwargs): # pragma: no cover
|
||||
return []
|
||||
|
||||
async def get_model_types(self, limit: int = 20):
|
||||
return list(self.model_types)[:limit]
|
||||
|
||||
def has_hash(self, *_args, **_kwargs): # pragma: no cover
|
||||
return False
|
||||
|
||||
@@ -283,4 +287,3 @@ def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockS
|
||||
def mock_service(mock_scanner: MockScanner) -> MockModelService:
|
||||
return MockModelService(scanner=mock_scanner)
|
||||
|
||||
|
||||
|
||||
@@ -185,6 +185,26 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner):
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_model_types_endpoint_returns_counts(mock_service, mock_scanner):
|
||||
mock_service.model_types = [
|
||||
{"type": "LoRa", "count": 3},
|
||||
{"type": "Checkpoint", "count": 1},
|
||||
]
|
||||
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
try:
|
||||
response = await client.get("/api/lm/test-models/model-types?limit=1")
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert payload["model_types"] == mock_service.model_types[:1]
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_routes_return_service_not_ready_when_unattached():
|
||||
async def scenario():
|
||||
client = await create_test_client(None)
|
||||
|
||||
@@ -776,6 +776,67 @@ def test_model_filter_set_supports_legacy_tag_arrays():
|
||||
assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"]
|
||||
|
||||
|
||||
def test_model_filter_set_filters_by_model_types():
|
||||
settings = StubSettings({})
|
||||
filter_set = ModelFilterSet(settings)
|
||||
data = [
|
||||
{"model_name": "LoConModel", "civitai": {"model": {"type": "LoCon"}}},
|
||||
{"model_name": "LoRaModel", "civitai": {"model": {"type": "LoRa"}}},
|
||||
]
|
||||
|
||||
criteria = FilterCriteria(model_types=["locon"])
|
||||
result = filter_set.apply(data, criteria)
|
||||
|
||||
assert [item["model_name"] for item in result] == ["LoConModel"]
|
||||
|
||||
|
||||
def test_model_filter_set_defaults_missing_model_type_to_lora():
|
||||
settings = StubSettings({})
|
||||
filter_set = ModelFilterSet(settings)
|
||||
data = [
|
||||
{"model_name": "DefaultModel"},
|
||||
{"model_name": "CheckpointModel", "civitai": {"model": {"type": "checkpoint"}}},
|
||||
]
|
||||
|
||||
criteria = FilterCriteria(model_types=["lora"])
|
||||
result = filter_set.apply(data, criteria)
|
||||
|
||||
assert [item["model_name"] for item in result] == ["DefaultModel"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_model_types_counts_and_limits():
|
||||
raw_data = [
|
||||
{"civitai": {"model": {"type": "LoRa"}}},
|
||||
{"model_type": "LoRa"},
|
||||
{"civitai": {"model": {"type": "LoCon"}}},
|
||||
{},
|
||||
]
|
||||
|
||||
class CacheStub:
|
||||
def __init__(self, raw_data):
|
||||
self.raw_data = raw_data
|
||||
|
||||
class ScannerStub:
|
||||
def __init__(self, cache):
|
||||
self._cache = cache
|
||||
|
||||
async def get_cached_data(self, *_, **__):
|
||||
return self._cache
|
||||
|
||||
cache = CacheStub(raw_data)
|
||||
scanner = ScannerStub(cache)
|
||||
service = DummyService(
|
||||
model_type="stub",
|
||||
scanner=scanner,
|
||||
metadata_class=BaseModelMetadata,
|
||||
)
|
||||
|
||||
types = await service.get_model_types(limit=1)
|
||||
|
||||
assert types == [{"type": "lora", "count": 3}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"service_cls, extra_fields",
|
||||
|
||||
Reference in New Issue
Block a user