Merge pull request #682 from willmiao/feature/model-type-filter

Feature/model type filter
This commit is contained in:
pixelpaws
2025-11-18 18:51:52 +08:00
committed by GitHub
24 changed files with 380 additions and 19 deletions

View File

@@ -195,6 +195,7 @@
"title": "Modelle filtern",
"baseModel": "Basis-Modell",
"modelTags": "Tags (Top 20)",
"modelTypes": "Modelltypen",
"license": "Lizenz",
"noCreditRequired": "Kein Credit erforderlich",
"allowSellingGeneratedContent": "Verkauf erlaubt",

View File

@@ -195,6 +195,7 @@
"title": "Filter Models",
"baseModel": "Base Model",
"modelTags": "Tags (Top 20)",
"modelTypes": "Model Types",
"license": "License",
"noCreditRequired": "No Credit Required",
"allowSellingGeneratedContent": "Allow Selling",

View File

@@ -195,6 +195,7 @@
"title": "Filtrar modelos",
"baseModel": "Modelo base",
"modelTags": "Etiquetas (Top 20)",
"modelTypes": "Tipos de modelo",
"license": "Licencia",
"noCreditRequired": "Sin crédito requerido",
"allowSellingGeneratedContent": "Venta permitida",

View File

@@ -195,6 +195,7 @@
"title": "Filtrer les modèles",
"baseModel": "Modèle de base",
"modelTags": "Tags (Top 20)",
"modelTypes": "Types de modèles",
"license": "Licence",
"noCreditRequired": "Crédit non requis",
"allowSellingGeneratedContent": "Vente autorisée",

View File

@@ -195,6 +195,7 @@
"title": "סנן מודלים",
"baseModel": "מודל בסיס",
"modelTags": "תגיות (20 המובילות)",
"modelTypes": "סוגי מודלים",
"license": "רישיון",
"noCreditRequired": "ללא קרדיט נדרש",
"allowSellingGeneratedContent": "אפשר מכירה",

View File

@@ -195,6 +195,7 @@
"title": "モデルをフィルタ",
"baseModel": "ベースモデル",
"modelTags": "タグ上位20",
"modelTypes": "モデルタイプ",
"license": "ライセンス",
"noCreditRequired": "クレジット不要",
"allowSellingGeneratedContent": "販売許可",

View File

@@ -195,6 +195,7 @@
"title": "모델 필터",
"baseModel": "베이스 모델",
"modelTags": "태그 (상위 20개)",
"modelTypes": "모델 타입",
"license": "라이선스",
"noCreditRequired": "크레딧 표기 없음",
"allowSellingGeneratedContent": "판매 허용",

View File

@@ -195,6 +195,7 @@
"title": "Фильтр моделей",
"baseModel": "Базовая модель",
"modelTags": "Теги (Топ 20)",
"modelTypes": "Типы моделей",
"license": "Лицензия",
"noCreditRequired": "Без указания авторства",
"allowSellingGeneratedContent": "Продажа разрешена",

View File

@@ -195,6 +195,7 @@
"title": "筛选模型",
"baseModel": "基础模型",
"modelTags": "标签前20",
"modelTypes": "模型类型",
"license": "许可证",
"noCreditRequired": "无需署名",
"allowSellingGeneratedContent": "允许销售",

View File

@@ -195,6 +195,7 @@
"title": "篩選模型",
"baseModel": "基礎模型",
"modelTags": "標籤（前 20）",
"modelTypes": "模型類型",
"license": "授權",
"noCreditRequired": "無需署名",
"allowSellingGeneratedContent": "允許銷售",

View File

@@ -152,6 +152,8 @@ class ModelListingHandler:
fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true"
base_models = request.query.getall("base_model", [])
model_types = list(request.query.getall("model_type", []))
model_types.extend(request.query.getall("civitai_model_type", []))
# Support legacy ?tag=foo plus new ?tag_include/foo & ?tag_exclude parameters
legacy_tags = request.query.getall("tag", [])
if not legacy_tags:
@@ -225,6 +227,7 @@ class ModelListingHandler:
"update_available_only": update_available_only,
"credit_required": credit_required,
"allow_selling_generated_content": allow_selling_generated_content,
"model_types": model_types,
**self._parse_specific_params(request),
}
@@ -557,6 +560,17 @@ class ModelQueryHandler:
self._logger.error("Error retrieving base models: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_types(self, request: web.Request) -> web.Response:
    """Return normalized CivitAI model-type counts from the model service.

    Responds with ``{"success": True, "model_types": [...]}`` on success,
    or a 500 JSON error payload if anything goes wrong. The ``limit``
    query parameter is clamped: values outside [1, 100] (or unparseable
    ones, via the except path) fall back to the default of 20.
    """
    try:
        requested = int(request.query.get("limit", "20"))
        # Reject out-of-range limits rather than clamping to an edge value,
        # matching the behavior of the sibling top-tags/base-models routes.
        effective_limit = requested if 1 <= requested <= 100 else 20
        type_counts = await self._service.get_model_types(effective_limit)
        return web.json_response({"success": True, "model_types": type_counts})
    except Exception as exc:
        self._logger.error("Error retrieving model types: %s", exc)
        return web.json_response({"success": False, "error": str(exc)}, status=500)
async def scan_models(self, request: web.Request) -> web.Response:
try:
full_rebuild = request.query.get("full_rebuild", "false").lower() == "true"
@@ -1579,6 +1593,7 @@ class ModelHandlerSet:
"verify_duplicates": self.management.verify_duplicates,
"get_top_tags": self.query.get_top_tags,
"get_base_models": self.query.get_base_models,
"get_model_types": self.query.get_model_types,
"scan_models": self.query.scan_models,
"get_model_roots": self.query.get_model_roots,
"get_folders": self.query.get_folders,

View File

@@ -39,6 +39,7 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"),
RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"),
RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"),
RouteDefinition("GET", "/api/lm/{prefix}/model-types", "get_model_types"),
RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"),
RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"),
RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"),

View File

@@ -1,12 +1,21 @@
from abc import ABC, abstractmethod
import asyncio
from typing import Dict, List, Optional, Type, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
import logging
import os
from ..utils.constants import VALID_LORA_TYPES
from ..utils.models import BaseModelMetadata
from ..utils.metadata_manager import MetadataManager
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
from .model_query import (
FilterCriteria,
ModelCacheRepository,
ModelFilterSet,
SearchStrategy,
SettingsProvider,
normalize_civitai_model_type,
resolve_civitai_model_type,
)
from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
@@ -59,6 +68,7 @@ class BaseModelService(ABC):
search: str = None,
fuzzy_search: bool = False,
base_models: list = None,
model_types: list = None,
tags: Optional[Dict[str, str]] = None,
search_options: dict = None,
hash_filters: dict = None,
@@ -80,6 +90,7 @@ class BaseModelService(ABC):
sorted_data,
folder=folder,
base_models=base_models,
model_types=model_types,
tags=tags,
favorites_only=favorites_only,
search_options=search_options,
@@ -149,6 +160,7 @@ class BaseModelService(ABC):
data: List[Dict],
folder: str = None,
base_models: list = None,
model_types: list = None,
tags: Optional[Dict[str, str]] = None,
favorites_only: bool = False,
search_options: dict = None,
@@ -158,6 +170,7 @@ class BaseModelService(ABC):
criteria = FilterCriteria(
folder=folder,
base_models=base_models,
model_types=model_types,
tags=tags,
favorites_only=favorites_only,
search_options=normalized_options,
@@ -456,6 +469,25 @@ class BaseModelService(ABC):
async def get_base_models(self, limit: int = 20) -> List[Dict]:
"""Get base models sorted by frequency"""
return await self.scanner.get_base_models(limit)
async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]:
    """Get counts of normalized CivitAI model types present in the cache.

    Returns at most *limit* ``{"type": ..., "count": ...}`` records,
    ordered by descending frequency. Entries whose normalized type is
    missing or not a recognized LoRA variant are ignored.
    """
    cache = await self.scanner.get_cached_data()
    counts: Dict[str, int] = {}
    for item in cache.raw_data:
        resolved = normalize_civitai_model_type(resolve_civitai_model_type(item))
        # Only count types the UI knows how to filter on.
        if resolved and resolved in VALID_LORA_TYPES:
            counts[resolved] = counts.get(resolved, 0) + 1
    ranked = sorted(
        ({"type": type_name, "count": total} for type_name, total in counts.items()),
        key=lambda record: record["count"],
        reverse=True,
    )
    return ranked[:limit]
def has_hash(self, sha256: str) -> bool:
"""Check if a model with given hash exists"""

View File

@@ -1,12 +1,49 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Protocol, Callable
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match as default_fuzzy_match
DEFAULT_CIVITAI_MODEL_TYPE = "LORA"
def _coerce_to_str(value: Any) -> Optional[str]:
if value is None:
return None
candidate = str(value).strip()
return candidate if candidate else None
def normalize_civitai_model_type(value: Any) -> Optional[str]:
"""Return a lowercase string suitable for comparisons."""
candidate = _coerce_to_str(value)
return candidate.lower() if candidate else None
def resolve_civitai_model_type(entry: Mapping[str, Any]) -> str:
"""Extract the model type from CivitAI metadata, defaulting to LORA."""
if not isinstance(entry, Mapping):
return DEFAULT_CIVITAI_MODEL_TYPE
civitai = entry.get("civitai")
if isinstance(civitai, Mapping):
civitai_model = civitai.get("model")
if isinstance(civitai_model, Mapping):
model_type = _coerce_to_str(civitai_model.get("type"))
if model_type:
return model_type
model_type = _coerce_to_str(entry.get("model_type"))
if model_type:
return model_type
return DEFAULT_CIVITAI_MODEL_TYPE
class SettingsProvider(Protocol):
"""Protocol describing the SettingsManager contract used by query helpers."""
@@ -31,6 +68,7 @@ class FilterCriteria:
tags: Optional[Dict[str, str]] = None
favorites_only: bool = False
search_options: Optional[Dict[str, Any]] = None
model_types: Optional[Sequence[str]] = None
class ModelCacheRepository:
@@ -134,6 +172,19 @@ class ModelFilterSet:
if not any(tag in exclude_tags for tag in (item.get("tags", []) or []))
]
model_types = criteria.model_types or []
normalized_model_types = {
model_type for model_type in (
normalize_civitai_model_type(value) for value in model_types
)
if model_type
}
if normalized_model_types:
items = [
item for item in items
if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types
]
return items

View File

@@ -161,6 +161,12 @@ class ModelScanner:
if trained_words:
slim['trainedWords'] = list(trained_words) if isinstance(trained_words, list) else trained_words
civitai_model = civitai.get('model')
if isinstance(civitai_model, Mapping):
model_type_value = civitai_model.get('type')
if model_type_value not in (None, '', []):
slim['model'] = {'type': model_type_value}
return slim or None
def _build_cache_entry(

View File

@@ -5,7 +5,7 @@ import re
import sqlite3
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple
from typing import Dict, List, Mapping, Optional, Sequence, Tuple
from ..utils.settings_paths import get_project_root, get_settings_dir
@@ -47,6 +47,7 @@ class PersistentModelCache:
"metadata_source",
"civitai_id",
"civitai_model_id",
"civitai_model_type",
"civitai_name",
"civitai_creator_username",
"trained_words",
@@ -138,7 +139,8 @@ class PersistentModelCache:
creator_username = row["civitai_creator_username"]
civitai: Optional[Dict] = None
civitai_has_data = any(
row[col] is not None for col in ("civitai_id", "civitai_model_id", "civitai_name")
row[col] is not None
for col in ("civitai_id", "civitai_model_id", "civitai_model_type", "civitai_name")
) or trained_words or creator_username
if civitai_has_data:
civitai = {}
@@ -152,6 +154,9 @@ class PersistentModelCache:
civitai["trainedWords"] = trained_words
if creator_username:
civitai.setdefault("creator", {})["username"] = creator_username
model_type_value = row["civitai_model_type"]
if model_type_value:
civitai.setdefault("model", {})["type"] = model_type_value
license_value = row["license_flags"]
if license_value is None:
@@ -443,6 +448,7 @@ class PersistentModelCache:
metadata_source TEXT,
civitai_id INTEGER,
civitai_model_id INTEGER,
civitai_model_type TEXT,
civitai_name TEXT,
civitai_creator_username TEXT,
trained_words TEXT,
@@ -492,6 +498,7 @@ class PersistentModelCache:
required_columns = {
"metadata_source": "TEXT",
"civitai_creator_username": "TEXT",
"civitai_model_type": "TEXT",
"civitai_deleted": "INTEGER DEFAULT 0",
# Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57).
"license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}",
@@ -528,6 +535,13 @@ class PersistentModelCache:
creator_data = civitai.get("creator") if isinstance(civitai, dict) else None
if isinstance(creator_data, dict):
creator_username = creator_data.get("username") or None
model_type_value = None
if isinstance(civitai, Mapping):
civitai_model_info = civitai.get("model")
if isinstance(civitai_model_info, Mapping):
candidate_type = civitai_model_info.get("type")
if candidate_type not in (None, "", []):
model_type_value = candidate_type
license_flags = item.get("license_flags")
if license_flags is None:
@@ -552,6 +566,7 @@ class PersistentModelCache:
metadata_source,
civitai.get("id"),
civitai.get("modelId"),
model_type_value,
civitai.get("name"),
creator_username,
trained_words_json,

View File

@@ -848,6 +848,12 @@ export class BaseModelApiClient {
}
}
}
if (pageState.filters.modelTypes && pageState.filters.modelTypes.length > 0) {
pageState.filters.modelTypes.forEach((type) => {
params.append('model_type', type);
});
}
}
this._addModelSpecificParams(params, pageState);

View File

@@ -2,6 +2,7 @@ import { getCurrentPageState } from '../state/index.js';
import { showToast, updatePanelPositions } from '../utils/uiHelpers.js';
import { getModelApiClient } from '../api/modelApiFactory.js';
import { removeStorageItem, setStorageItem, getStorageItem } from '../utils/storageHelpers.js';
import { MODEL_TYPE_DISPLAY_NAMES } from '../utils/constants.js';
export class FilterManager {
constructor(options = {}) {
@@ -34,6 +35,10 @@ export class FilterManager {
this.createBaseModelTags();
}
if (document.getElementById('modelTypeTags')) {
this.createModelTypeTags();
}
// Add click handlers for license filter tags
this.initializeLicenseFilters();
@@ -248,12 +253,86 @@ export class FilterManager {
// Update selections based on stored filters
this.updateTagSelections();
}
})
.catch(error => {
console.error(`Error fetching base models for ${this.currentPage}:`, error);
baseModelTagsContainer.innerHTML = '<div class="tags-error">Failed to load base models</div>';
});
}
async createModelTypeTags() {
const modelTypeContainer = document.getElementById('modelTypeTags');
if (!modelTypeContainer) return;
modelTypeContainer.innerHTML = '<div class="tags-loading">Loading model types...</div>';
try {
const response = await fetch(`/api/lm/${this.currentPage}/model-types?limit=20`);
if (!response.ok) {
throw new Error('Failed to fetch model types');
}
const data = await response.json();
if (!data.success || !Array.isArray(data.model_types)) {
throw new Error('Invalid response format');
}
const normalizedTypes = data.model_types
.map(entry => {
if (!entry || !entry.type) {
return null;
}
const typeKey = entry.type.toString().trim().toLowerCase();
if (!typeKey || !MODEL_TYPE_DISPLAY_NAMES[typeKey]) {
return null;
}
return {
type: typeKey,
count: Number(entry.count) || 0,
};
})
.filter(Boolean);
if (!normalizedTypes.length) {
modelTypeContainer.innerHTML = '<div class="no-tags">No model types available</div>';
return;
}
modelTypeContainer.innerHTML = '';
normalizedTypes.forEach(({ type, count }) => {
const tag = document.createElement('div');
tag.className = 'filter-tag model-type-tag';
tag.dataset.modelType = type;
tag.innerHTML = `${MODEL_TYPE_DISPLAY_NAMES[type]} <span class="tag-count">${count}</span>`;
if (this.filters.modelTypes.includes(type)) {
tag.classList.add('active');
}
})
.catch(error => {
console.error(`Error fetching base models for ${this.currentPage}:`, error);
baseModelTagsContainer.innerHTML = '<div class="tags-error">Failed to load base models</div>';
tag.addEventListener('click', async () => {
const isSelected = this.filters.modelTypes.includes(type);
if (isSelected) {
this.filters.modelTypes = this.filters.modelTypes.filter(value => value !== type);
tag.classList.remove('active');
} else {
this.filters.modelTypes.push(type);
tag.classList.add('active');
}
this.updateActiveFiltersCount();
await this.applyFilters(false);
});
modelTypeContainer.appendChild(tag);
});
this.updateModelTypeSelections();
} catch (error) {
console.error('Error loading model types:', error);
modelTypeContainer.innerHTML = '<div class="tags-error">Failed to load model types</div>';
}
}
toggleFilterPanel() {
@@ -309,12 +388,26 @@ export class FilterManager {
// Update license tags
this.updateLicenseSelections();
this.updateModelTypeSelections();
}
updateModelTypeSelections() {
const typeTags = document.querySelectorAll('.model-type-tag');
typeTags.forEach(tag => {
const modelType = tag.dataset.modelType;
if (this.filters.modelTypes.includes(modelType)) {
tag.classList.add('active');
} else {
tag.classList.remove('active');
}
});
}
updateActiveFiltersCount() {
const tagFilterCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
const licenseFilterCount = this.filters.license ? Object.keys(this.filters.license).length : 0;
const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount;
const modelTypeFilterCount = this.filters.modelTypes.length;
const totalActiveFilters = this.filters.baseModel.length + tagFilterCount + licenseFilterCount + modelTypeFilterCount;
if (this.activeFiltersCount) {
if (totalActiveFilters > 0) {
@@ -377,7 +470,8 @@ export class FilterManager {
...this.filters,
baseModel: [],
tags: {},
license: {}
license: {},
modelTypes: []
});
// Update state
@@ -437,7 +531,13 @@ export class FilterManager {
hasActiveFilters() {
const tagCount = this.filters.tags ? Object.keys(this.filters.tags).length : 0;
const licenseCount = this.filters.license ? Object.keys(this.filters.license).length : 0;
return this.filters.baseModel.length > 0 || tagCount > 0 || licenseCount > 0;
const modelTypeCount = this.filters.modelTypes.length;
return (
this.filters.baseModel.length > 0 ||
tagCount > 0 ||
licenseCount > 0 ||
modelTypeCount > 0
);
}
initializeFilters(existingFilters = {}) {
@@ -446,7 +546,8 @@ export class FilterManager {
...source,
baseModel: Array.isArray(source.baseModel) ? [...source.baseModel] : [],
tags: this.normalizeTagFilters(source.tags),
license: this.normalizeLicenseFilters(source.license)
license: this.normalizeLicenseFilters(source.license),
modelTypes: this.normalizeModelTypeFilters(source.modelTypes)
};
}
@@ -496,12 +597,35 @@ export class FilterManager {
return normalized;
}
normalizeModelTypeFilters(modelTypes) {
if (!Array.isArray(modelTypes)) {
return [];
}
const seen = new Set();
return modelTypes.reduce((acc, type) => {
if (typeof type !== 'string') {
return acc;
}
const normalized = type.trim().toLowerCase();
if (!normalized || seen.has(normalized)) {
return acc;
}
seen.add(normalized);
acc.push(normalized);
return acc;
}, []);
}
cloneFilters() {
return {
...this.filters,
baseModel: [...(this.filters.baseModel || [])],
tags: { ...(this.filters.tags || {}) },
license: { ...(this.filters.license || {}) }
license: { ...(this.filters.license || {}) },
modelTypes: [...(this.filters.modelTypes || [])]
};
}

View File

@@ -79,7 +79,8 @@ export const state = {
filters: {
baseModel: [],
tags: {},
license: {}
license: {},
modelTypes: []
},
bulkMode: false,
selectedLoras: new Set(),
@@ -105,6 +106,7 @@ export const state = {
baseModel: [],
tags: {},
license: {},
modelTypes: [],
search: ''
},
pageSize: 20,
@@ -131,7 +133,8 @@ export const state = {
filters: {
baseModel: [],
tags: {},
license: {}
license: {},
modelTypes: []
},
modelType: 'checkpoint', // 'checkpoint' or 'diffusion_model'
bulkMode: false,
@@ -161,7 +164,8 @@ export const state = {
filters: {
baseModel: [],
tags: {},
license: {}
license: {},
modelTypes: []
},
bulkMode: false,
selectedModels: new Set(),

View File

@@ -55,6 +55,12 @@ export const BASE_MODELS = {
UNKNOWN: "Other"
};
// Human-readable labels for the normalized CivitAI model types supported by
// the model-type filter. Keys must stay lowercase: the filter pipeline
// lowercases incoming type strings before looking them up here, and unknown
// keys are dropped from the filter UI.
export const MODEL_TYPE_DISPLAY_NAMES = {
    lora: "LoRA",
    locon: "LyCORIS",
    dora: "DoRA",
};
export const BASE_MODEL_ABBREVIATIONS = {
// Stable Diffusion 1.x models
[BASE_MODELS.SD_1_4]: 'SD1',

View File

@@ -139,6 +139,14 @@
<div class="tags-loading">{{ t('common.status.loading') }}</div>
</div>
</div>
{% if current_page == 'loras' %}
<div class="filter-section">
<h4>{{ t('header.filter.modelTypes') }}</h4>
<div class="filter-tags" id="modelTypeTags">
<div class="tags-loading">{{ t('common.status.loading') }}</div>
</div>
</div>
{% endif %}
<div class="filter-section">
<h4>{{ t('header.filter.license') }}</h4>
<div class="filter-tags">
@@ -156,4 +164,3 @@
</button>
</div>
</div>

View File

@@ -212,6 +212,7 @@ class MockModelService:
self.model_type = "test-model"
self.paginated_items: List[Dict[str, Any]] = []
self.formatted: List[Dict[str, Any]] = []
self.model_types: List[Dict[str, Any]] = []
async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
items = [dict(item) for item in self.paginated_items]
@@ -257,6 +258,9 @@ class MockModelService:
async def get_relative_paths(self, *_args, **_kwargs): # pragma: no cover
return []
async def get_model_types(self, limit: int = 20):
return list(self.model_types)[:limit]
def has_hash(self, *_args, **_kwargs): # pragma: no cover
return False
@@ -283,4 +287,3 @@ def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockS
def mock_service(mock_scanner: MockScanner) -> MockModelService:
return MockModelService(scanner=mock_scanner)

View File

@@ -185,6 +185,26 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner):
asyncio.run(scenario())
def test_model_types_endpoint_returns_counts(mock_service, mock_scanner):
    """The model-types route echoes the service payload and honors ?limit."""
    mock_service.model_types = [
        {"type": "LoRa", "count": 3},
        {"type": "Checkpoint", "count": 1},
    ]

    async def scenario():
        client = await create_test_client(mock_service)
        try:
            response = await client.get("/api/lm/test-models/model-types?limit=1")
            body = await response.json()
            assert response.status == 200
            # limit=1 truncates the payload to the single most frequent type.
            assert body["model_types"] == mock_service.model_types[:1]
        finally:
            await client.close()

    asyncio.run(scenario())
def test_routes_return_service_not_ready_when_unattached():
async def scenario():
client = await create_test_client(None)

View File

@@ -776,6 +776,67 @@ def test_model_filter_set_supports_legacy_tag_arrays():
assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"]
def test_model_filter_set_filters_by_model_types():
    """Only entries whose normalized CivitAI type matches the criteria survive."""
    filter_set = ModelFilterSet(StubSettings({}))
    entries = [
        {"model_name": "LoConModel", "civitai": {"model": {"type": "LoCon"}}},
        {"model_name": "LoRaModel", "civitai": {"model": {"type": "LoRa"}}},
    ]

    filtered = filter_set.apply(entries, FilterCriteria(model_types=["locon"]))

    assert [entry["model_name"] for entry in filtered] == ["LoConModel"]
def test_model_filter_set_defaults_missing_model_type_to_lora():
    """Entries with no CivitAI type metadata are treated as LoRA models."""
    filter_set = ModelFilterSet(StubSettings({}))
    entries = [
        {"model_name": "DefaultModel"},
        {"model_name": "CheckpointModel", "civitai": {"model": {"type": "checkpoint"}}},
    ]

    filtered = filter_set.apply(entries, FilterCriteria(model_types=["lora"]))

    # Only the type-less entry falls back to "lora" and passes the filter.
    assert [entry["model_name"] for entry in filtered] == ["DefaultModel"]
@pytest.mark.asyncio
async def test_get_model_types_counts_and_limits():
    """Type counting is case-insensitive, defaults missing types to LoRA,
    and truncates the ranked result to the requested limit."""
    entries = [
        {"civitai": {"model": {"type": "LoRa"}}},
        {"model_type": "LoRa"},
        {"civitai": {"model": {"type": "LoCon"}}},
        {},  # no type info anywhere -> counted as lora via the default
    ]

    class FakeCache:
        # Minimal stand-in exposing the raw_data attribute the service reads.
        def __init__(self, raw_data):
            self.raw_data = raw_data

    class FakeScanner:
        # Async scanner stub returning the prepared cache.
        def __init__(self, cache):
            self._cache = cache

        async def get_cached_data(self, *_args, **_kwargs):
            return self._cache

    service = DummyService(
        model_type="stub",
        scanner=FakeScanner(FakeCache(entries)),
        metadata_class=BaseModelMetadata,
    )

    assert await service.get_model_types(limit=1) == [{"type": "lora", "count": 3}]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"service_cls, extra_fields",