feat(model): add model type filtering support

- Add model_types parameter to ModelListingHandler to support filtering by model type
- Implement get_model_types endpoint in ModelQueryHandler to retrieve available model types
- Register new /api/lm/{prefix}/model-types route for model type queries
- Extend BaseModelService to handle model type filtering in queries
- Support both model_type and civitai_model_type query parameters for backward compatibility

This enables users to filter models by type, improving model discovery and organization.
Author: Will Miao
Date: 2025-11-18 15:36:01 +08:00
Parent: 059ebeead7
Commit: 57f369a6de
9 changed files with 205 additions and 6 deletions
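For illustration, here is how a client might exercise the new filtering and the model-types endpoint. This is a minimal sketch assuming an aiohttp client and a local server; the "loras" prefix, host/port, and the listing path /api/lm/loras/list are assumptions for illustration, since this commit only shows the model-types route.

# Minimal client sketch; host, port, the "loras" prefix, and the listing
# path are assumptions, not defined by this commit.
import asyncio
import aiohttp

async def main() -> None:
    async with aiohttp.ClientSession(base_url="http://127.0.0.1:8188") as session:
        # Both parameter names feed the same filter list, so legacy
        # ?civitai_model_type=... clients keep working.
        params = [("model_type", "LoCon"), ("civitai_model_type", "LoRa")]
        async with session.get("/api/lm/loras/list", params=params) as resp:
            print(await resp.json())

        # New endpoint: model types present in the cache, with counts.
        async with session.get("/api/lm/loras/model-types", params={"limit": 5}) as resp:
            print(await resp.json())  # {"success": true, "model_types": [...]}

asyncio.run(main())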

View File

@@ -152,6 +152,8 @@ class ModelListingHandler:
fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true"
base_models = request.query.getall("base_model", [])
model_types = list(request.query.getall("model_type", []))
model_types.extend(request.query.getall("civitai_model_type", []))
# Support legacy ?tag=foo plus new ?tag_include=foo & ?tag_exclude parameters
legacy_tags = request.query.getall("tag", [])
if not legacy_tags:
@@ -225,6 +227,7 @@ class ModelListingHandler:
"update_available_only": update_available_only,
"credit_required": credit_required,
"allow_selling_generated_content": allow_selling_generated_content,
"model_types": model_types,
**self._parse_specific_params(request),
}
@@ -557,6 +560,17 @@ class ModelQueryHandler:
self._logger.error("Error retrieving base models: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_types(self, request: web.Request) -> web.Response:
try:
limit = int(request.query.get("limit", "20"))
if limit < 1 or limit > 100:
limit = 20
model_types = await self._service.get_model_types(limit)
return web.json_response({"success": True, "model_types": model_types})
except Exception as exc:
self._logger.error("Error retrieving model types: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def scan_models(self, request: web.Request) -> web.Response:
try:
full_rebuild = request.query.get("full_rebuild", "false").lower() == "true"
@@ -1579,6 +1593,7 @@ class ModelHandlerSet:
"verify_duplicates": self.management.verify_duplicates,
"get_top_tags": self.query.get_top_tags,
"get_base_models": self.query.get_base_models,
"get_model_types": self.query.get_model_types,
"scan_models": self.query.scan_models,
"get_model_roots": self.query.get_model_roots,
"get_folders": self.query.get_folders,

View File

@@ -39,6 +39,7 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"),
RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"),
RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"),
RouteDefinition("GET", "/api/lm/{prefix}/model-types", "get_model_types"),
RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"),
RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"),
RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"),

View File

@@ -1,12 +1,19 @@
from abc import ABC, abstractmethod
import asyncio
from typing import Dict, List, Optional, Type, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
import logging
import os
from ..utils.models import BaseModelMetadata
from ..utils.metadata_manager import MetadataManager
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
from .model_query import (
FilterCriteria,
ModelCacheRepository,
ModelFilterSet,
SearchStrategy,
SettingsProvider,
resolve_civitai_model_type,
)
from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__)
@@ -59,6 +66,7 @@ class BaseModelService(ABC):
search: str = None,
fuzzy_search: bool = False,
base_models: list = None,
model_types: list = None,
tags: Optional[Dict[str, str]] = None,
search_options: dict = None,
hash_filters: dict = None,
@@ -80,6 +88,7 @@ class BaseModelService(ABC):
sorted_data,
folder=folder,
base_models=base_models,
model_types=model_types,
tags=tags,
favorites_only=favorites_only,
search_options=search_options,
@@ -149,6 +158,7 @@ class BaseModelService(ABC):
data: List[Dict],
folder: str = None,
base_models: list = None,
model_types: list = None,
tags: Optional[Dict[str, str]] = None,
favorites_only: bool = False,
search_options: dict = None,
@@ -158,6 +168,7 @@ class BaseModelService(ABC):
criteria = FilterCriteria(
folder=folder,
base_models=base_models,
model_types=model_types,
tags=tags,
favorites_only=favorites_only,
search_options=normalized_options,
@@ -456,6 +467,22 @@ class BaseModelService(ABC):
async def get_base_models(self, limit: int = 20) -> List[Dict]:
"""Get base models sorted by frequency"""
return await self.scanner.get_base_models(limit)
async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]:
"""Get counts of CivitAI model types present in the cache."""
cache = await self.scanner.get_cached_data()
type_counts: Dict[str, int] = {}
for entry in cache.raw_data:
model_type = resolve_civitai_model_type(entry)
type_counts[model_type] = type_counts.get(model_type, 0) + 1
sorted_types = sorted(
[{"type": model_type, "count": count} for model_type, count in type_counts.items()],
key=lambda value: value["count"],
reverse=True,
)
return sorted_types[:limit]
def has_hash(self, sha256: str) -> bool:
"""Check if a model with given hash exists"""

View File

@@ -1,12 +1,49 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Protocol, Callable
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match as default_fuzzy_match
DEFAULT_CIVITAI_MODEL_TYPE = "LORA"
def _coerce_to_str(value: Any) -> Optional[str]:
if value is None:
return None
candidate = str(value).strip()
return candidate if candidate else None
def normalize_civitai_model_type(value: Any) -> Optional[str]:
"""Return a lowercase string suitable for comparisons."""
candidate = _coerce_to_str(value)
return candidate.lower() if candidate else None
def resolve_civitai_model_type(entry: Mapping[str, Any]) -> str:
"""Extract the model type from CivitAI metadata, defaulting to LORA."""
if not isinstance(entry, Mapping):
return DEFAULT_CIVITAI_MODEL_TYPE
civitai = entry.get("civitai")
if isinstance(civitai, Mapping):
civitai_model = civitai.get("model")
if isinstance(civitai_model, Mapping):
model_type = _coerce_to_str(civitai_model.get("type"))
if model_type:
return model_type
model_type = _coerce_to_str(entry.get("model_type"))
if model_type:
return model_type
return DEFAULT_CIVITAI_MODEL_TYPE
class SettingsProvider(Protocol):
"""Protocol describing the SettingsManager contract used by query helpers."""
@@ -31,6 +68,7 @@ class FilterCriteria:
tags: Optional[Dict[str, str]] = None
favorites_only: bool = False
search_options: Optional[Dict[str, Any]] = None
model_types: Optional[Sequence[str]] = None
class ModelCacheRepository:
@@ -134,6 +172,19 @@ class ModelFilterSet:
if not any(tag in exclude_tags for tag in (item.get("tags", []) or []))
]
model_types = criteria.model_types or []
normalized_model_types = {
model_type for model_type in (
normalize_civitai_model_type(value) for value in model_types
)
if model_type
}
if normalized_model_types:
items = [
item for item in items
if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types
]
return items
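The normalization above is what makes the type filter case-insensitive. A quick sketch with an inline stand-in for normalize_civitai_model_type (the real helper also returns None for non-string-coercible values):

# Inline stand-in for normalize_civitai_model_type, to show the matching.
def normalize(value):
    candidate = str(value).strip() if value is not None else ""
    return candidate.lower() or None

requested = {t for t in map(normalize, ["LoCon", "  lora "]) if t}
stored = ["LoCon", "LoRa", "LORA", "Checkpoint"]
print([s for s in stored if normalize(s) in requested])
# ['LoCon', 'LoRa', 'LORA'] -- casing and stray whitespace do not matter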

View File

@@ -161,6 +161,12 @@ class ModelScanner:
if trained_words:
slim['trainedWords'] = list(trained_words) if isinstance(trained_words, list) else trained_words
civitai_model = civitai.get('model')
if isinstance(civitai_model, Mapping):
model_type_value = civitai_model.get('type')
if model_type_value not in (None, '', []):
slim['model'] = {'type': model_type_value}
return slim or None
def _build_cache_entry(
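For context, the hunk above makes the slimmed-down CivitAI payload keep just enough structure for type filtering. An invented before/after illustration of that step:

# Invented example of the slimming step: only model["type"] survives.
full_civitai = {
    "id": 123,
    "model": {"type": "LoCon", "name": "Example", "nsfw": False},
    "images": [],  # everything else is dropped by the slimming pass
}

slim = {}
civitai_model = full_civitai.get("model")
if isinstance(civitai_model, dict):
    model_type_value = civitai_model.get("type")
    if model_type_value not in (None, "", []):
        slim["model"] = {"type": model_type_value}
print(slim)  # {'model': {'type': 'LoCon'}}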

View File

@@ -5,7 +5,7 @@ import re
import sqlite3
import threading
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple
from typing import Dict, List, Mapping, Optional, Sequence, Tuple
from ..utils.settings_paths import get_project_root, get_settings_dir
@@ -47,6 +47,7 @@ class PersistentModelCache:
"metadata_source",
"civitai_id",
"civitai_model_id",
"civitai_model_type",
"civitai_name",
"civitai_creator_username",
"trained_words",
@@ -138,7 +139,8 @@ class PersistentModelCache:
creator_username = row["civitai_creator_username"]
civitai: Optional[Dict] = None
civitai_has_data = any(
row[col] is not None for col in ("civitai_id", "civitai_model_id", "civitai_name")
row[col] is not None
for col in ("civitai_id", "civitai_model_id", "civitai_model_type", "civitai_name")
) or trained_words or creator_username
if civitai_has_data:
civitai = {}
@@ -152,6 +154,9 @@ class PersistentModelCache:
civitai["trainedWords"] = trained_words
if creator_username:
civitai.setdefault("creator", {})["username"] = creator_username
model_type_value = row["civitai_model_type"]
if model_type_value:
civitai.setdefault("model", {})["type"] = model_type_value
license_value = row["license_flags"]
if license_value is None:
@@ -443,6 +448,7 @@ class PersistentModelCache:
metadata_source TEXT,
civitai_id INTEGER,
civitai_model_id INTEGER,
civitai_model_type TEXT,
civitai_name TEXT,
civitai_creator_username TEXT,
trained_words TEXT,
@@ -492,6 +498,7 @@ class PersistentModelCache:
required_columns = {
"metadata_source": "TEXT",
"civitai_creator_username": "TEXT",
"civitai_model_type": "TEXT",
"civitai_deleted": "INTEGER DEFAULT 0",
# Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57).
"license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}",
@@ -528,6 +535,13 @@ class PersistentModelCache:
creator_data = civitai.get("creator") if isinstance(civitai, dict) else None
if isinstance(creator_data, dict):
creator_username = creator_data.get("username") or None
model_type_value = None
if isinstance(civitai, Mapping):
civitai_model_info = civitai.get("model")
if isinstance(civitai_model_info, Mapping):
candidate_type = civitai_model_info.get("type")
if candidate_type not in (None, "", []):
model_type_value = candidate_type
license_flags = item.get("license_flags")
if license_flags is None:
@@ -552,6 +566,7 @@ class PersistentModelCache:
metadata_source,
civitai.get("id"),
civitai.get("modelId"),
model_type_value,
civitai.get("name"),
creator_username,
trained_words_json,
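A rough round-trip sketch of what the new column enables. The table name and two-column schema here are simplifications; only the civitai_model_type column and the nested rebuild of civitai["model"]["type"] mirror the diff.

# Rough sqlite3 round-trip sketch for the new civitai_model_type column;
# the "models" table and its schema are simplified assumptions.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute("CREATE TABLE models (civitai_name TEXT, civitai_model_type TEXT)")
conn.execute("INSERT INTO models VALUES (?, ?)", ("Example", "LoCon"))

row = conn.execute("SELECT * FROM models").fetchone()
civitai = {"name": row["civitai_name"]}
model_type_value = row["civitai_model_type"]
if model_type_value:
    # Mirrors the read path: rebuild the nested civitai["model"]["type"].
    civitai.setdefault("model", {})["type"] = model_type_value
print(civitai)  # {'name': 'Example', 'model': {'type': 'LoCon'}}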

View File

@@ -212,6 +212,7 @@ class MockModelService:
self.model_type = "test-model"
self.paginated_items: List[Dict[str, Any]] = []
self.formatted: List[Dict[str, Any]] = []
self.model_types: List[Dict[str, Any]] = []
async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
items = [dict(item) for item in self.paginated_items]
@@ -257,6 +258,9 @@ class MockModelService:
async def get_relative_paths(self, *_args, **_kwargs): # pragma: no cover
return []
async def get_model_types(self, limit: int = 20):
return list(self.model_types)[:limit]
def has_hash(self, *_args, **_kwargs): # pragma: no cover
return False
@@ -283,4 +287,3 @@ def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockS
def mock_service(mock_scanner: MockScanner) -> MockModelService:
return MockModelService(scanner=mock_scanner)

View File

@@ -185,6 +185,26 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner):
asyncio.run(scenario())
def test_model_types_endpoint_returns_counts(mock_service, mock_scanner):
mock_service.model_types = [
{"type": "LoRa", "count": 3},
{"type": "Checkpoint", "count": 1},
]
async def scenario():
client = await create_test_client(mock_service)
try:
response = await client.get("/api/lm/test-models/model-types?limit=1")
payload = await response.json()
assert response.status == 200
assert payload["model_types"] == mock_service.model_types[:1]
finally:
await client.close()
asyncio.run(scenario())
def test_routes_return_service_not_ready_when_unattached():
async def scenario():
client = await create_test_client(None)

View File

@@ -776,6 +776,67 @@ def test_model_filter_set_supports_legacy_tag_arrays():
assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"]
def test_model_filter_set_filters_by_model_types():
settings = StubSettings({})
filter_set = ModelFilterSet(settings)
data = [
{"model_name": "LoConModel", "civitai": {"model": {"type": "LoCon"}}},
{"model_name": "LoRaModel", "civitai": {"model": {"type": "LoRa"}}},
]
criteria = FilterCriteria(model_types=["locon"])
result = filter_set.apply(data, criteria)
assert [item["model_name"] for item in result] == ["LoConModel"]
def test_model_filter_set_defaults_missing_model_type_to_lora():
settings = StubSettings({})
filter_set = ModelFilterSet(settings)
data = [
{"model_name": "DefaultModel"},
{"model_name": "CheckpointModel", "civitai": {"model": {"type": "checkpoint"}}},
]
criteria = FilterCriteria(model_types=["lora"])
result = filter_set.apply(data, criteria)
assert [item["model_name"] for item in result] == ["DefaultModel"]
@pytest.mark.asyncio
async def test_get_model_types_counts_and_limits():
raw_data = [
{"civitai": {"model": {"type": "LoRa"}}},
{"model_type": "LoRa"},
{"civitai": {"model": {"type": "LoCon"}}},
{},
]
class CacheStub:
def __init__(self, raw_data):
self.raw_data = raw_data
class ScannerStub:
def __init__(self, cache):
self._cache = cache
async def get_cached_data(self, *_, **__):
return self._cache
cache = CacheStub(raw_data)
scanner = ScannerStub(cache)
service = DummyService(
model_type="stub",
scanner=scanner,
metadata_class=BaseModelMetadata,
)
types = await service.get_model_types(limit=1)
assert types == [{"type": "LoRa", "count": 2}]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"service_cls, extra_fields",