mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 14:42:11 -03:00
feat(model): add model type filtering support
- Add model_types parameter to ModelListingHandler to support filtering by model type
- Implement get_model_types endpoint in ModelQueryHandler to retrieve available model types
- Register new /api/lm/{prefix}/model-types route for model type queries
- Extend BaseModelService to handle model type filtering in queries
- Support both model_type and civitai_model_type query parameters for backward compatibility
This enables users to filter models by specific types, improving model discovery and organization capabilities.
This commit is contained in:
@@ -776,6 +776,67 @@ def test_model_filter_set_supports_legacy_tag_arrays():
|
||||
assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"]
|
||||
|
||||
|
||||
def test_model_filter_set_filters_by_model_types():
|
||||
settings = StubSettings({})
|
||||
filter_set = ModelFilterSet(settings)
|
||||
data = [
|
||||
{"model_name": "LoConModel", "civitai": {"model": {"type": "LoCon"}}},
|
||||
{"model_name": "LoRaModel", "civitai": {"model": {"type": "LoRa"}}},
|
||||
]
|
||||
|
||||
criteria = FilterCriteria(model_types=["locon"])
|
||||
result = filter_set.apply(data, criteria)
|
||||
|
||||
assert [item["model_name"] for item in result] == ["LoConModel"]
|
||||
|
||||
|
||||
def test_model_filter_set_defaults_missing_model_type_to_lora():
|
||||
settings = StubSettings({})
|
||||
filter_set = ModelFilterSet(settings)
|
||||
data = [
|
||||
{"model_name": "DefaultModel"},
|
||||
{"model_name": "CheckpointModel", "civitai": {"model": {"type": "checkpoint"}}},
|
||||
]
|
||||
|
||||
criteria = FilterCriteria(model_types=["lora"])
|
||||
result = filter_set.apply(data, criteria)
|
||||
|
||||
assert [item["model_name"] for item in result] == ["DefaultModel"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_model_types_counts_and_limits():
|
||||
raw_data = [
|
||||
{"civitai": {"model": {"type": "LoRa"}}},
|
||||
{"model_type": "LoRa"},
|
||||
{"civitai": {"model": {"type": "LoCon"}}},
|
||||
{},
|
||||
]
|
||||
|
||||
class CacheStub:
|
||||
def __init__(self, raw_data):
|
||||
self.raw_data = raw_data
|
||||
|
||||
class ScannerStub:
|
||||
def __init__(self, cache):
|
||||
self._cache = cache
|
||||
|
||||
async def get_cached_data(self, *_, **__):
|
||||
return self._cache
|
||||
|
||||
cache = CacheStub(raw_data)
|
||||
scanner = ScannerStub(cache)
|
||||
service = DummyService(
|
||||
model_type="stub",
|
||||
scanner=scanner,
|
||||
metadata_class=BaseModelMetadata,
|
||||
)
|
||||
|
||||
types = await service.get_model_types(limit=1)
|
||||
|
||||
assert types == [{"type": "LoRa", "count": 2}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"service_cls, extra_fields",
|
||||
|
||||
Reference in New Issue
Block a user