feat(model): add model type filtering support

- Add model_types parameter to ModelListingHandler to support filtering by model type
- Implement get_model_types endpoint in ModelQueryHandler to retrieve available model types
- Register new /api/lm/{prefix}/model-types route for model type queries
- Extend BaseModelService to handle model type filtering in queries
- Support both model_type and civitai_model_type query parameters for backward compatibility

This enables users to filter models by specific types, improving model discovery and organization capabilities.
This commit is contained in:
Will Miao
2025-11-18 15:36:01 +08:00
parent 059ebeead7
commit 57f369a6de
9 changed files with 205 additions and 6 deletions

View File

@@ -776,6 +776,67 @@ def test_model_filter_set_supports_legacy_tag_arrays():
assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"]
def test_model_filter_set_filters_by_model_types():
settings = StubSettings({})
filter_set = ModelFilterSet(settings)
data = [
{"model_name": "LoConModel", "civitai": {"model": {"type": "LoCon"}}},
{"model_name": "LoRaModel", "civitai": {"model": {"type": "LoRa"}}},
]
criteria = FilterCriteria(model_types=["locon"])
result = filter_set.apply(data, criteria)
assert [item["model_name"] for item in result] == ["LoConModel"]
def test_model_filter_set_defaults_missing_model_type_to_lora():
settings = StubSettings({})
filter_set = ModelFilterSet(settings)
data = [
{"model_name": "DefaultModel"},
{"model_name": "CheckpointModel", "civitai": {"model": {"type": "checkpoint"}}},
]
criteria = FilterCriteria(model_types=["lora"])
result = filter_set.apply(data, criteria)
assert [item["model_name"] for item in result] == ["DefaultModel"]
@pytest.mark.asyncio
async def test_get_model_types_counts_and_limits():
raw_data = [
{"civitai": {"model": {"type": "LoRa"}}},
{"model_type": "LoRa"},
{"civitai": {"model": {"type": "LoCon"}}},
{},
]
class CacheStub:
def __init__(self, raw_data):
self.raw_data = raw_data
class ScannerStub:
def __init__(self, cache):
self._cache = cache
async def get_cached_data(self, *_, **__):
return self._cache
cache = CacheStub(raw_data)
scanner = ScannerStub(cache)
service = DummyService(
model_type="stub",
scanner=scanner,
metadata_class=BaseModelMetadata,
)
types = await service.get_model_types(limit=1)
assert types == [{"type": "LoRa", "count": 2}]
@pytest.mark.asyncio
@pytest.mark.parametrize(
"service_cls, extra_fields",