feat(model): add model type filtering support

- Add model_types parameter to ModelListingHandler to support filtering by model type
- Implement get_model_types endpoint in ModelQueryHandler to retrieve available model types
- Register new /api/lm/{prefix}/model-types route for model type queries
- Extend BaseModelService to handle model type filtering in queries
- Support both model_type and civitai_model_type query parameters for backward compatibility

This enables users to filter models by specific types, improving model discovery and organization capabilities.
This commit is contained in:
Will Miao
2025-11-18 15:36:01 +08:00
parent 059ebeead7
commit 57f369a6de
9 changed files with 205 additions and 6 deletions

View File

@@ -152,6 +152,8 @@ class ModelListingHandler:
fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true" fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true"
base_models = request.query.getall("base_model", []) base_models = request.query.getall("base_model", [])
model_types = list(request.query.getall("model_type", []))
model_types.extend(request.query.getall("civitai_model_type", []))
# Support legacy ?tag=foo plus new ?tag_include/foo & ?tag_exclude parameters # Support legacy ?tag=foo plus new ?tag_include/foo & ?tag_exclude parameters
legacy_tags = request.query.getall("tag", []) legacy_tags = request.query.getall("tag", [])
if not legacy_tags: if not legacy_tags:
@@ -225,6 +227,7 @@ class ModelListingHandler:
"update_available_only": update_available_only, "update_available_only": update_available_only,
"credit_required": credit_required, "credit_required": credit_required,
"allow_selling_generated_content": allow_selling_generated_content, "allow_selling_generated_content": allow_selling_generated_content,
"model_types": model_types,
**self._parse_specific_params(request), **self._parse_specific_params(request),
} }
@@ -557,6 +560,17 @@ class ModelQueryHandler:
self._logger.error("Error retrieving base models: %s", exc) self._logger.error("Error retrieving base models: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500) return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_model_types(self, request: web.Request) -> web.Response:
    """Return CivitAI model-type frequency counts for this model service.

    Query parameters:
        limit: maximum number of types to return (1-100). Missing,
            non-numeric, or out-of-range values fall back to 20.

    Returns ``{"success": True, "model_types": [...]}`` on success, or a
    500 JSON payload on unexpected failure, mirroring the sibling query
    endpoints (e.g. get_base_models).
    """
    try:
        # A malformed ?limit= (e.g. "abc") previously raised ValueError and
        # surfaced as a 500; degrade to the default instead, matching how
        # out-of-range values are already handled.
        try:
            limit = int(request.query.get("limit", "20"))
        except (TypeError, ValueError):
            limit = 20
        if not 1 <= limit <= 100:
            limit = 20
        model_types = await self._service.get_model_types(limit)
        return web.json_response({"success": True, "model_types": model_types})
    except Exception as exc:
        self._logger.error("Error retrieving model types: %s", exc)
        return web.json_response({"success": False, "error": str(exc)}, status=500)
async def scan_models(self, request: web.Request) -> web.Response: async def scan_models(self, request: web.Request) -> web.Response:
try: try:
full_rebuild = request.query.get("full_rebuild", "false").lower() == "true" full_rebuild = request.query.get("full_rebuild", "false").lower() == "true"
@@ -1579,6 +1593,7 @@ class ModelHandlerSet:
"verify_duplicates": self.management.verify_duplicates, "verify_duplicates": self.management.verify_duplicates,
"get_top_tags": self.query.get_top_tags, "get_top_tags": self.query.get_top_tags,
"get_base_models": self.query.get_base_models, "get_base_models": self.query.get_base_models,
"get_model_types": self.query.get_model_types,
"scan_models": self.query.scan_models, "scan_models": self.query.scan_models,
"get_model_roots": self.query.get_model_roots, "get_model_roots": self.query.get_model_roots,
"get_folders": self.query.get_folders, "get_folders": self.query.get_folders,

View File

@@ -39,6 +39,7 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"), RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"),
RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"), RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"),
RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"), RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"),
RouteDefinition("GET", "/api/lm/{prefix}/model-types", "get_model_types"),
RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"), RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"),
RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"), RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"),
RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"), RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"),

View File

@@ -1,12 +1,19 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import asyncio import asyncio
from typing import Dict, List, Optional, Type, TYPE_CHECKING from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
import logging import logging
import os import os
from ..utils.models import BaseModelMetadata from ..utils.models import BaseModelMetadata
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider from .model_query import (
FilterCriteria,
ModelCacheRepository,
ModelFilterSet,
SearchStrategy,
SettingsProvider,
resolve_civitai_model_type,
)
from .settings_manager import get_settings_manager from .settings_manager import get_settings_manager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -59,6 +66,7 @@ class BaseModelService(ABC):
search: str = None, search: str = None,
fuzzy_search: bool = False, fuzzy_search: bool = False,
base_models: list = None, base_models: list = None,
model_types: list = None,
tags: Optional[Dict[str, str]] = None, tags: Optional[Dict[str, str]] = None,
search_options: dict = None, search_options: dict = None,
hash_filters: dict = None, hash_filters: dict = None,
@@ -80,6 +88,7 @@ class BaseModelService(ABC):
sorted_data, sorted_data,
folder=folder, folder=folder,
base_models=base_models, base_models=base_models,
model_types=model_types,
tags=tags, tags=tags,
favorites_only=favorites_only, favorites_only=favorites_only,
search_options=search_options, search_options=search_options,
@@ -149,6 +158,7 @@ class BaseModelService(ABC):
data: List[Dict], data: List[Dict],
folder: str = None, folder: str = None,
base_models: list = None, base_models: list = None,
model_types: list = None,
tags: Optional[Dict[str, str]] = None, tags: Optional[Dict[str, str]] = None,
favorites_only: bool = False, favorites_only: bool = False,
search_options: dict = None, search_options: dict = None,
@@ -158,6 +168,7 @@ class BaseModelService(ABC):
criteria = FilterCriteria( criteria = FilterCriteria(
folder=folder, folder=folder,
base_models=base_models, base_models=base_models,
model_types=model_types,
tags=tags, tags=tags,
favorites_only=favorites_only, favorites_only=favorites_only,
search_options=normalized_options, search_options=normalized_options,
@@ -456,6 +467,22 @@ class BaseModelService(ABC):
async def get_base_models(self, limit: int = 20) -> List[Dict]: async def get_base_models(self, limit: int = 20) -> List[Dict]:
"""Get base models sorted by frequency""" """Get base models sorted by frequency"""
return await self.scanner.get_base_models(limit) return await self.scanner.get_base_models(limit)
async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]:
    """Return CivitAI model-type frequency counts from the scanner cache.

    Each element has the shape ``{"type": <str>, "count": <int>}``; the
    list is ordered by descending count and truncated to ``limit`` entries.
    """
    cache = await self.scanner.get_cached_data()
    counts: Dict[str, int] = {}
    for item in cache.raw_data:
        resolved = resolve_civitai_model_type(item)
        counts[resolved] = counts.get(resolved, 0) + 1
    # Most frequent types first; ties keep first-seen order because
    # sorted() is stable over the dict's insertion order.
    ranked = sorted(counts.items(), key=lambda pair: pair[1], reverse=True)
    return [{"type": name, "count": count} for name, count in ranked[:limit]]
def has_hash(self, sha256: str) -> bool: def has_hash(self, sha256: str) -> bool:
"""Check if a model with given hash exists""" """Check if a model with given hash exists"""

View File

@@ -1,12 +1,49 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Protocol, Callable
from ..utils.constants import NSFW_LEVELS from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match as default_fuzzy_match from ..utils.utils import fuzzy_match as default_fuzzy_match
DEFAULT_CIVITAI_MODEL_TYPE = "LORA"
def _coerce_to_str(value: Any) -> Optional[str]:
if value is None:
return None
candidate = str(value).strip()
return candidate if candidate else None
def normalize_civitai_model_type(value: Any) -> Optional[str]:
"""Return a lowercase string suitable for comparisons."""
candidate = _coerce_to_str(value)
return candidate.lower() if candidate else None
def resolve_civitai_model_type(entry: Mapping[str, Any]) -> str:
"""Extract the model type from CivitAI metadata, defaulting to LORA."""
if not isinstance(entry, Mapping):
return DEFAULT_CIVITAI_MODEL_TYPE
civitai = entry.get("civitai")
if isinstance(civitai, Mapping):
civitai_model = civitai.get("model")
if isinstance(civitai_model, Mapping):
model_type = _coerce_to_str(civitai_model.get("type"))
if model_type:
return model_type
model_type = _coerce_to_str(entry.get("model_type"))
if model_type:
return model_type
return DEFAULT_CIVITAI_MODEL_TYPE
class SettingsProvider(Protocol): class SettingsProvider(Protocol):
"""Protocol describing the SettingsManager contract used by query helpers.""" """Protocol describing the SettingsManager contract used by query helpers."""
@@ -31,6 +68,7 @@ class FilterCriteria:
tags: Optional[Dict[str, str]] = None tags: Optional[Dict[str, str]] = None
favorites_only: bool = False favorites_only: bool = False
search_options: Optional[Dict[str, Any]] = None search_options: Optional[Dict[str, Any]] = None
model_types: Optional[Sequence[str]] = None
class ModelCacheRepository: class ModelCacheRepository:
@@ -134,6 +172,19 @@ class ModelFilterSet:
if not any(tag in exclude_tags for tag in (item.get("tags", []) or [])) if not any(tag in exclude_tags for tag in (item.get("tags", []) or []))
] ]
model_types = criteria.model_types or []
normalized_model_types = {
model_type for model_type in (
normalize_civitai_model_type(value) for value in model_types
)
if model_type
}
if normalized_model_types:
items = [
item for item in items
if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types
]
return items return items

View File

@@ -161,6 +161,12 @@ class ModelScanner:
if trained_words: if trained_words:
slim['trainedWords'] = list(trained_words) if isinstance(trained_words, list) else trained_words slim['trainedWords'] = list(trained_words) if isinstance(trained_words, list) else trained_words
civitai_model = civitai.get('model')
if isinstance(civitai_model, Mapping):
model_type_value = civitai_model.get('type')
if model_type_value not in (None, '', []):
slim['model'] = {'type': model_type_value}
return slim or None return slim or None
def _build_cache_entry( def _build_cache_entry(

View File

@@ -5,7 +5,7 @@ import re
import sqlite3 import sqlite3
import threading import threading
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple from typing import Dict, List, Mapping, Optional, Sequence, Tuple
from ..utils.settings_paths import get_project_root, get_settings_dir from ..utils.settings_paths import get_project_root, get_settings_dir
@@ -47,6 +47,7 @@ class PersistentModelCache:
"metadata_source", "metadata_source",
"civitai_id", "civitai_id",
"civitai_model_id", "civitai_model_id",
"civitai_model_type",
"civitai_name", "civitai_name",
"civitai_creator_username", "civitai_creator_username",
"trained_words", "trained_words",
@@ -138,7 +139,8 @@ class PersistentModelCache:
creator_username = row["civitai_creator_username"] creator_username = row["civitai_creator_username"]
civitai: Optional[Dict] = None civitai: Optional[Dict] = None
civitai_has_data = any( civitai_has_data = any(
row[col] is not None for col in ("civitai_id", "civitai_model_id", "civitai_name") row[col] is not None
for col in ("civitai_id", "civitai_model_id", "civitai_model_type", "civitai_name")
) or trained_words or creator_username ) or trained_words or creator_username
if civitai_has_data: if civitai_has_data:
civitai = {} civitai = {}
@@ -152,6 +154,9 @@ class PersistentModelCache:
civitai["trainedWords"] = trained_words civitai["trainedWords"] = trained_words
if creator_username: if creator_username:
civitai.setdefault("creator", {})["username"] = creator_username civitai.setdefault("creator", {})["username"] = creator_username
model_type_value = row["civitai_model_type"]
if model_type_value:
civitai.setdefault("model", {})["type"] = model_type_value
license_value = row["license_flags"] license_value = row["license_flags"]
if license_value is None: if license_value is None:
@@ -443,6 +448,7 @@ class PersistentModelCache:
metadata_source TEXT, metadata_source TEXT,
civitai_id INTEGER, civitai_id INTEGER,
civitai_model_id INTEGER, civitai_model_id INTEGER,
civitai_model_type TEXT,
civitai_name TEXT, civitai_name TEXT,
civitai_creator_username TEXT, civitai_creator_username TEXT,
trained_words TEXT, trained_words TEXT,
@@ -492,6 +498,7 @@ class PersistentModelCache:
required_columns = { required_columns = {
"metadata_source": "TEXT", "metadata_source": "TEXT",
"civitai_creator_username": "TEXT", "civitai_creator_username": "TEXT",
"civitai_model_type": "TEXT",
"civitai_deleted": "INTEGER DEFAULT 0", "civitai_deleted": "INTEGER DEFAULT 0",
# Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57). # Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57).
"license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}", "license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}",
@@ -528,6 +535,13 @@ class PersistentModelCache:
creator_data = civitai.get("creator") if isinstance(civitai, dict) else None creator_data = civitai.get("creator") if isinstance(civitai, dict) else None
if isinstance(creator_data, dict): if isinstance(creator_data, dict):
creator_username = creator_data.get("username") or None creator_username = creator_data.get("username") or None
model_type_value = None
if isinstance(civitai, Mapping):
civitai_model_info = civitai.get("model")
if isinstance(civitai_model_info, Mapping):
candidate_type = civitai_model_info.get("type")
if candidate_type not in (None, "", []):
model_type_value = candidate_type
license_flags = item.get("license_flags") license_flags = item.get("license_flags")
if license_flags is None: if license_flags is None:
@@ -552,6 +566,7 @@ class PersistentModelCache:
metadata_source, metadata_source,
civitai.get("id"), civitai.get("id"),
civitai.get("modelId"), civitai.get("modelId"),
model_type_value,
civitai.get("name"), civitai.get("name"),
creator_username, creator_username,
trained_words_json, trained_words_json,

View File

@@ -212,6 +212,7 @@ class MockModelService:
self.model_type = "test-model" self.model_type = "test-model"
self.paginated_items: List[Dict[str, Any]] = [] self.paginated_items: List[Dict[str, Any]] = []
self.formatted: List[Dict[str, Any]] = [] self.formatted: List[Dict[str, Any]] = []
self.model_types: List[Dict[str, Any]] = []
async def get_paginated_data(self, **params: Any) -> Dict[str, Any]: async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
items = [dict(item) for item in self.paginated_items] items = [dict(item) for item in self.paginated_items]
@@ -257,6 +258,9 @@ class MockModelService:
async def get_relative_paths(self, *_args, **_kwargs): # pragma: no cover async def get_relative_paths(self, *_args, **_kwargs): # pragma: no cover
return [] return []
async def get_model_types(self, limit: int = 20):
    """Mimic the service endpoint: return the first *limit* stub entries."""
    snapshot = [entry for entry in self.model_types]
    return snapshot[:limit]
def has_hash(self, *_args, **_kwargs): # pragma: no cover def has_hash(self, *_args, **_kwargs): # pragma: no cover
return False return False
@@ -283,4 +287,3 @@ def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockS
def mock_service(mock_scanner: MockScanner) -> MockModelService: def mock_service(mock_scanner: MockScanner) -> MockModelService:
return MockModelService(scanner=mock_scanner) return MockModelService(scanner=mock_scanner)

View File

@@ -185,6 +185,26 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner):
asyncio.run(scenario()) asyncio.run(scenario())
def test_model_types_endpoint_returns_counts(mock_service, mock_scanner):
    """The /model-types route should honor ?limit and echo the service data."""
    mock_service.model_types = [
        {"type": "LoRa", "count": 3},
        {"type": "Checkpoint", "count": 1},
    ]

    async def exercise_route():
        client = await create_test_client(mock_service)
        try:
            response = await client.get("/api/lm/test-models/model-types?limit=1")
            body = await response.json()
            assert response.status == 200
            # Only the first entry should survive the limit=1 truncation.
            assert body["model_types"] == mock_service.model_types[:1]
        finally:
            await client.close()

    asyncio.run(exercise_route())
def test_routes_return_service_not_ready_when_unattached(): def test_routes_return_service_not_ready_when_unattached():
async def scenario(): async def scenario():
client = await create_test_client(None) client = await create_test_client(None)

View File

@@ -776,6 +776,67 @@ def test_model_filter_set_supports_legacy_tag_arrays():
assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"] assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"]
def test_model_filter_set_filters_by_model_types():
    """A lowercase model_types filter should match types case-insensitively."""
    filter_set = ModelFilterSet(StubSettings({}))
    items = [
        {"model_name": "LoConModel", "civitai": {"model": {"type": "LoCon"}}},
        {"model_name": "LoRaModel", "civitai": {"model": {"type": "LoRa"}}},
    ]
    filtered = filter_set.apply(items, FilterCriteria(model_types=["locon"]))
    assert [entry["model_name"] for entry in filtered] == ["LoConModel"]
def test_model_filter_set_defaults_missing_model_type_to_lora():
    """Entries without explicit type metadata should be treated as LORA."""
    filter_set = ModelFilterSet(StubSettings({}))
    items = [
        {"model_name": "DefaultModel"},
        {"model_name": "CheckpointModel", "civitai": {"model": {"type": "checkpoint"}}},
    ]
    filtered = filter_set.apply(items, FilterCriteria(model_types=["lora"]))
    assert [entry["model_name"] for entry in filtered] == ["DefaultModel"]
@pytest.mark.asyncio
async def test_get_model_types_counts_and_limits():
    """get_model_types should aggregate type counts and respect the limit."""

    class FakeCache:
        # Minimal stand-in exposing only the raw_data the service reads.
        def __init__(self, raw_data):
            self.raw_data = raw_data

    class FakeScanner:
        # Async scanner stub returning the fixed cache object.
        def __init__(self, cache):
            self._cache = cache

        async def get_cached_data(self, *_, **__):
            return self._cache

    scanner = FakeScanner(
        FakeCache(
            [
                {"civitai": {"model": {"type": "LoRa"}}},
                {"model_type": "LoRa"},
                {"civitai": {"model": {"type": "LoCon"}}},
                {},
            ]
        )
    )
    service = DummyService(
        model_type="stub",
        scanner=scanner,
        metadata_class=BaseModelMetadata,
    )
    # Two LoRa entries outrank LoCon and the defaulted empty entry; limit=1
    # keeps only the leader.
    assert await service.get_model_types(limit=1) == [{"type": "LoRa", "count": 2}]
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"service_cls, extra_fields", "service_cls, extra_fields",