From 57f369a6deeb170af8e57e65fbb34a26c8047fb6 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Tue, 18 Nov 2025 15:36:01 +0800 Subject: [PATCH] feat(model): add model type filtering support - Add model_types parameter to ModelListingHandler to support filtering by model type - Implement get_model_types endpoint in ModelQueryHandler to retrieve available model types - Register new /api/lm/{prefix}/model-types route for model type queries - Extend BaseModelService to handle model type filtering in queries - Support both model_type and civitai_model_type query parameters for backward compatibility This enables users to filter models by specific types, improving model discovery and organization capabilities. --- py/routes/handlers/model_handlers.py | 15 +++++ py/routes/model_route_registrar.py | 1 + py/services/base_model_service.py | 31 +++++++++- py/services/model_query.py | 53 ++++++++++++++++- py/services/model_scanner.py | 6 ++ py/services/persistent_model_cache.py | 19 +++++- tests/conftest.py | 5 +- tests/routes/test_base_model_routes_smoke.py | 20 +++++++ tests/services/test_base_model_service.py | 61 ++++++++++++++++++++ 9 files changed, 205 insertions(+), 6 deletions(-) diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index fee783dd..c74dccca 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -152,6 +152,8 @@ class ModelListingHandler: fuzzy_search = request.query.get("fuzzy_search", "false").lower() == "true" base_models = request.query.getall("base_model", []) + model_types = list(request.query.getall("model_type", [])) + model_types.extend(request.query.getall("civitai_model_type", [])) # Support legacy ?tag=foo plus new ?tag_include/foo & ?tag_exclude parameters legacy_tags = request.query.getall("tag", []) if not legacy_tags: @@ -225,6 +227,7 @@ class ModelListingHandler: "update_available_only": update_available_only, "credit_required": credit_required, "allow_selling_generated_content": allow_selling_generated_content, + "model_types": model_types, **self._parse_specific_params(request), } @@ -557,6 +560,17 @@ class ModelQueryHandler: self._logger.error("Error retrieving base models: %s", exc) return web.json_response({"success": False, "error": str(exc)}, status=500) + async def get_model_types(self, request: web.Request) -> web.Response: + try: + limit = int(request.query.get("limit", "20")) + if limit < 1 or limit > 100: + limit = 20 + model_types = await self._service.get_model_types(limit) + return web.json_response({"success": True, "model_types": model_types}) + except Exception as exc: + self._logger.error("Error retrieving model types: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=500) + async def scan_models(self, request: web.Request) -> web.Response: try: full_rebuild = request.query.get("full_rebuild", "false").lower() == "true" @@ -1579,6 +1593,7 @@ class ModelHandlerSet: "verify_duplicates": self.management.verify_duplicates, "get_top_tags": self.query.get_top_tags, "get_base_models": self.query.get_base_models, + "get_model_types": self.query.get_model_types, "scan_models": self.query.scan_models, "get_model_roots": self.query.get_model_roots, "get_folders": self.query.get_folders, diff --git a/py/routes/model_route_registrar.py b/py/routes/model_route_registrar.py index ce7a75ba..21589c7b 100644 --- a/py/routes/model_route_registrar.py +++ b/py/routes/model_route_registrar.py @@ -39,6 +39,7 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"), RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"), RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"), + RouteDefinition("GET", "/api/lm/{prefix}/model-types", "get_model_types"), RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"), RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"), RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"), diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index 2fd393ec..1c983a39 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -1,12 +1,19 @@ from abc import ABC, abstractmethod import asyncio -from typing import Dict, List, Optional, Type, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING import logging import os from ..utils.models import BaseModelMetadata from ..utils.metadata_manager import MetadataManager -from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider +from .model_query import ( + FilterCriteria, + ModelCacheRepository, + ModelFilterSet, + SearchStrategy, + SettingsProvider, + resolve_civitai_model_type, +) from .settings_manager import get_settings_manager logger = logging.getLogger(__name__) @@ -59,6 +66,7 @@ class BaseModelService(ABC): search: str = None, fuzzy_search: bool = False, base_models: list = None, + model_types: list = None, tags: Optional[Dict[str, str]] = None, search_options: dict = None, hash_filters: dict = None, @@ -80,6 +88,7 @@ class BaseModelService(ABC): sorted_data, folder=folder, base_models=base_models, + model_types=model_types, tags=tags, favorites_only=favorites_only, search_options=search_options, @@ -149,6 +158,7 @@ class BaseModelService(ABC): data: List[Dict], folder: str = None, base_models: list = None, + model_types: list = None, tags: Optional[Dict[str, str]] = None, favorites_only: bool = False, search_options: dict = None, @@ -158,6 +168,7 @@ class BaseModelService(ABC): criteria = FilterCriteria( folder=folder, base_models=base_models, + model_types=model_types, tags=tags, favorites_only=favorites_only, search_options=normalized_options, @@ -456,6 +467,22 @@ class BaseModelService(ABC): async def get_base_models(self, limit: int = 20) -> List[Dict]: """Get base models sorted by frequency""" return await self.scanner.get_base_models(limit) + + async def get_model_types(self, limit: int = 20) -> List[Dict[str, Any]]: + """Get counts of CivitAI model types present in the cache.""" + cache = await self.scanner.get_cached_data() + type_counts: Dict[str, int] = {} + for entry in cache.raw_data: + model_type = resolve_civitai_model_type(entry) + type_counts[model_type] = type_counts.get(model_type, 0) + 1 + + sorted_types = sorted( + [{"type": model_type, "count": count} for model_type, count in type_counts.items()], + key=lambda value: value["count"], + reverse=True, + ) + + return sorted_types[:limit] def has_hash(self, sha256: str) -> bool: """Check if a model with given hash exists""" diff --git a/py/services/model_query.py b/py/services/model_query.py index d88e9631..5b370138 100644 --- a/py/services/model_query.py +++ b/py/services/model_query.py @@ -1,12 +1,49 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable +from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Protocol, Callable from ..utils.constants import NSFW_LEVELS from ..utils.utils import fuzzy_match as default_fuzzy_match +DEFAULT_CIVITAI_MODEL_TYPE = "LORA" + + +def _coerce_to_str(value: Any) -> Optional[str]: + if value is None: + return None + + candidate = str(value).strip() + return candidate if candidate else None + + +def normalize_civitai_model_type(value: Any) -> Optional[str]: + """Return a lowercase string suitable for comparisons.""" + candidate = _coerce_to_str(value) + return candidate.lower() if candidate else None + + +def resolve_civitai_model_type(entry: Mapping[str, Any]) -> str: + """Extract the model type from CivitAI metadata, defaulting to LORA.""" + if not isinstance(entry, Mapping): + return DEFAULT_CIVITAI_MODEL_TYPE + + civitai = entry.get("civitai") + if isinstance(civitai, Mapping): + civitai_model = civitai.get("model") + if isinstance(civitai_model, Mapping): + model_type = _coerce_to_str(civitai_model.get("type")) + if model_type: + return model_type + + model_type = _coerce_to_str(entry.get("model_type")) + if model_type: + return model_type + + return DEFAULT_CIVITAI_MODEL_TYPE + + class SettingsProvider(Protocol): """Protocol describing the SettingsManager contract used by query helpers.""" @@ -31,6 +68,7 @@ class FilterCriteria: tags: Optional[Dict[str, str]] = None favorites_only: bool = False search_options: Optional[Dict[str, Any]] = None + model_types: Optional[Sequence[str]] = None class ModelCacheRepository: @@ -134,6 +172,19 @@ class ModelFilterSet: if not any(tag in exclude_tags for tag in (item.get("tags", []) or [])) ] + model_types = criteria.model_types or [] + normalized_model_types = { + model_type for model_type in ( + normalize_civitai_model_type(value) for value in model_types + ) + if model_type + } + if normalized_model_types: + items = [ + item for item in items + if normalize_civitai_model_type(resolve_civitai_model_type(item)) in normalized_model_types + ] + return items diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index ebe05c42..8acaff4b 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -161,6 +161,12 @@ class ModelScanner: if trained_words: slim['trainedWords'] = list(trained_words) if isinstance(trained_words, list) else trained_words + civitai_model = civitai.get('model') + if isinstance(civitai_model, Mapping): + model_type_value = civitai_model.get('type') + if model_type_value not in (None, '', []): + slim['model'] = {'type': model_type_value} + return slim or None def _build_cache_entry( diff --git a/py/services/persistent_model_cache.py b/py/services/persistent_model_cache.py index cda1d0b3..c3ebcc27 100644 --- a/py/services/persistent_model_cache.py +++ b/py/services/persistent_model_cache.py @@ -5,7 +5,7 @@ import re import sqlite3 import threading from dataclasses import dataclass -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Mapping, Optional, Sequence, Tuple from ..utils.settings_paths import get_project_root, get_settings_dir @@ -47,6 +47,7 @@ class PersistentModelCache: "metadata_source", "civitai_id", "civitai_model_id", + "civitai_model_type", "civitai_name", "civitai_creator_username", "trained_words", @@ -138,7 +139,8 @@ class PersistentModelCache: creator_username = row["civitai_creator_username"] civitai: Optional[Dict] = None civitai_has_data = any( - row[col] is not None for col in ("civitai_id", "civitai_model_id", "civitai_name") + row[col] is not None + for col in ("civitai_id", "civitai_model_id", "civitai_model_type", "civitai_name") ) or trained_words or creator_username if civitai_has_data: civitai = {} @@ -152,6 +154,9 @@ class PersistentModelCache: civitai["trainedWords"] = trained_words if creator_username: civitai.setdefault("creator", {})["username"] = creator_username + model_type_value = row["civitai_model_type"] + if model_type_value: + civitai.setdefault("model", {})["type"] = model_type_value license_value = row["license_flags"] if license_value is None: @@ -443,6 +448,7 @@ class PersistentModelCache: metadata_source TEXT, civitai_id INTEGER, civitai_model_id INTEGER, + civitai_model_type TEXT, civitai_name TEXT, civitai_creator_username TEXT, trained_words TEXT, @@ -492,6 +498,7 @@ class PersistentModelCache: required_columns = { "metadata_source": "TEXT", "civitai_creator_username": "TEXT", + "civitai_model_type": "TEXT", "civitai_deleted": "INTEGER DEFAULT 0", # Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57). "license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}", @@ -528,6 +535,13 @@ class PersistentModelCache: creator_data = civitai.get("creator") if isinstance(civitai, dict) else None if isinstance(creator_data, dict): creator_username = creator_data.get("username") or None + model_type_value = None + if isinstance(civitai, Mapping): + civitai_model_info = civitai.get("model") + if isinstance(civitai_model_info, Mapping): + candidate_type = civitai_model_info.get("type") + if candidate_type not in (None, "", []): + model_type_value = candidate_type license_flags = item.get("license_flags") if license_flags is None: @@ -552,6 +566,7 @@ class PersistentModelCache: metadata_source, civitai.get("id"), civitai.get("modelId"), + model_type_value, civitai.get("name"), creator_username, trained_words_json, diff --git a/tests/conftest.py b/tests/conftest.py index 7f5becb5..5ed3bc72 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -212,6 +212,7 @@ class MockModelService: self.model_type = "test-model" self.paginated_items: List[Dict[str, Any]] = [] self.formatted: List[Dict[str, Any]] = [] + self.model_types: List[Dict[str, Any]] = [] async def get_paginated_data(self, **params: Any) -> Dict[str, Any]: items = [dict(item) for item in self.paginated_items] @@ -257,6 +258,9 @@ class MockModelService: async def get_relative_paths(self, *_args, **_kwargs): # pragma: no cover return [] + async def get_model_types(self, limit: int = 20): + return list(self.model_types)[:limit] + def has_hash(self, *_args, **_kwargs): # pragma: no cover return False @@ -283,4 +287,3 @@ def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockS def mock_service(mock_scanner: MockScanner) -> MockModelService: return MockModelService(scanner=mock_scanner) - diff --git a/tests/routes/test_base_model_routes_smoke.py b/tests/routes/test_base_model_routes_smoke.py index bd4b8550..55776dcd 100644 --- a/tests/routes/test_base_model_routes_smoke.py +++ b/tests/routes/test_base_model_routes_smoke.py @@ -185,6 +185,26 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner): asyncio.run(scenario()) +def test_model_types_endpoint_returns_counts(mock_service, mock_scanner): + mock_service.model_types = [ + {"type": "LoRa", "count": 3}, + {"type": "Checkpoint", "count": 1}, + ] + + async def scenario(): + client = await create_test_client(mock_service) + try: + response = await client.get("/api/lm/test-models/model-types?limit=1") + payload = await response.json() + + assert response.status == 200 + assert payload["model_types"] == mock_service.model_types[:1] + finally: + await client.close() + + asyncio.run(scenario()) + + def test_routes_return_service_not_ready_when_unattached(): async def scenario(): client = await create_test_client(None) diff --git a/tests/services/test_base_model_service.py b/tests/services/test_base_model_service.py index 8c412aa1..9a93f216 100644 --- a/tests/services/test_base_model_service.py +++ b/tests/services/test_base_model_service.py @@ -776,6 +776,67 @@ def test_model_filter_set_supports_legacy_tag_arrays(): assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"] +def test_model_filter_set_filters_by_model_types(): + settings = StubSettings({}) + filter_set = ModelFilterSet(settings) + data = [ + {"model_name": "LoConModel", "civitai": {"model": {"type": "LoCon"}}}, + {"model_name": "LoRaModel", "civitai": {"model": {"type": "LoRa"}}}, + ] + + criteria = FilterCriteria(model_types=["locon"]) + result = filter_set.apply(data, criteria) + + assert [item["model_name"] for item in result] == ["LoConModel"] + + +def test_model_filter_set_defaults_missing_model_type_to_lora(): + settings = StubSettings({}) + filter_set = ModelFilterSet(settings) + data = [ + {"model_name": "DefaultModel"}, + {"model_name": "CheckpointModel", "civitai": {"model": {"type": "checkpoint"}}}, + ] + + criteria = FilterCriteria(model_types=["lora"]) + result = filter_set.apply(data, criteria) + + assert [item["model_name"] for item in result] == ["DefaultModel"] + + +@pytest.mark.asyncio +async def test_get_model_types_counts_and_limits(): + raw_data = [ + {"civitai": {"model": {"type": "LoRa"}}}, + {"model_type": "LoRa"}, + {"civitai": {"model": {"type": "LoCon"}}}, + {}, + ] + + class CacheStub: + def __init__(self, raw_data): + self.raw_data = raw_data + + class ScannerStub: + def __init__(self, cache): + self._cache = cache + + async def get_cached_data(self, *_, **__): + return self._cache + + cache = CacheStub(raw_data) + scanner = ScannerStub(cache) + service = DummyService( + model_type="stub", + scanner=scanner, + metadata_class=BaseModelMetadata, + ) + + types = await service.get_model_types(limit=1) + + assert types == [{"type": "LoRa", "count": 2}] + + @pytest.mark.asyncio @pytest.mark.parametrize( "service_cls, extra_fields",