mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(model): add model type filtering support
- Add model_types parameter to ModelListingHandler to support filtering by model type
- Implement get_model_types endpoint in ModelQueryHandler to retrieve available model types
- Register new /api/lm/{prefix}/model-types route for model type queries
- Extend BaseModelService to handle model type filtering in queries
- Support both model_type and civitai_model_type query parameters for backward compatibility
This enables users to filter models by specific types, improving model discovery and organization capabilities.
This commit is contained in:
@@ -212,6 +212,7 @@ class MockModelService:
|
||||
self.model_type = "test-model"
|
||||
self.paginated_items: List[Dict[str, Any]] = []
|
||||
self.formatted: List[Dict[str, Any]] = []
|
||||
self.model_types: List[Dict[str, Any]] = []
|
||||
|
||||
async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
|
||||
items = [dict(item) for item in self.paginated_items]
|
||||
@@ -257,6 +258,9 @@ class MockModelService:
|
||||
async def get_relative_paths(self, *_args, **_kwargs):  # pragma: no cover
    """Stub used by route tests: no relative paths are ever available."""
    return list()
|
||||
|
||||
async def get_model_types(self, limit: int = 20):
    """Return at most ``limit`` entries from the preconfigured type list."""
    snapshot = list(self.model_types)
    return snapshot[:limit]
|
||||
|
||||
def has_hash(self, *_args, **_kwargs):  # pragma: no cover
    """Stub: the mock never reports a known hash."""
    return False
|
||||
|
||||
@@ -283,4 +287,3 @@ def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockS
|
||||
def mock_service(mock_scanner: MockScanner) -> MockModelService:
    """Build a MockModelService wired to the scanner fixture."""
    service = MockModelService(scanner=mock_scanner)
    return service
|
||||
|
||||
|
||||
|
||||
@@ -185,6 +185,26 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner):
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_model_types_endpoint_returns_counts(mock_service, mock_scanner):
    """The /model-types route returns the service's entries, honouring ``limit``."""
    configured = [
        {"type": "LoRa", "count": 3},
        {"type": "Checkpoint", "count": 1},
    ]
    mock_service.model_types = configured

    async def exercise():
        client = await create_test_client(mock_service)
        try:
            response = await client.get("/api/lm/test-models/model-types?limit=1")
            payload = await response.json()
            assert response.status == 200
            # limit=1 → only the first configured entry comes back.
            assert payload["model_types"] == configured[:1]
        finally:
            await client.close()

    asyncio.run(exercise())
|
||||
|
||||
|
||||
def test_routes_return_service_not_ready_when_unattached():
|
||||
async def scenario():
|
||||
client = await create_test_client(None)
|
||||
|
||||
@@ -776,6 +776,67 @@ def test_model_filter_set_supports_legacy_tag_arrays():
|
||||
assert [item["model_name"] for item in result] == ["StyleOnly", "StyleAnime"]
|
||||
|
||||
|
||||
def test_model_filter_set_filters_by_model_types():
    """Only entries whose civitai model type matches the criteria survive."""
    filter_set = ModelFilterSet(StubSettings({}))
    entries = [
        {"model_name": "LoConModel", "civitai": {"model": {"type": "LoCon"}}},
        {"model_name": "LoRaModel", "civitai": {"model": {"type": "LoRa"}}},
    ]

    filtered = filter_set.apply(entries, FilterCriteria(model_types=["locon"]))

    names = [entry["model_name"] for entry in filtered]
    assert names == ["LoConModel"]
|
||||
|
||||
|
||||
def test_model_filter_set_defaults_missing_model_type_to_lora():
    """An entry with no civitai type info is matched by the 'lora' filter."""
    filter_set = ModelFilterSet(StubSettings({}))
    entries = [
        {"model_name": "DefaultModel"},
        {"model_name": "CheckpointModel", "civitai": {"model": {"type": "checkpoint"}}},
    ]

    filtered = filter_set.apply(entries, FilterCriteria(model_types=["lora"]))

    names = [entry["model_name"] for entry in filtered]
    assert names == ["DefaultModel"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_model_types_counts_and_limits():
    """get_model_types aggregates per-type counts and honours ``limit``."""

    class _Cache:
        def __init__(self, raw_data):
            self.raw_data = raw_data

    class _Scanner:
        def __init__(self, cache):
            self._cache = cache

        async def get_cached_data(self, *_, **__):
            return self._cache

    # Two LoRa entries (one via civitai metadata, one via the flat
    # ``model_type`` field), one LoCon, and one with no type info at all.
    entries = [
        {"civitai": {"model": {"type": "LoRa"}}},
        {"model_type": "LoRa"},
        {"civitai": {"model": {"type": "LoCon"}}},
        {},
    ]
    service = DummyService(
        model_type="stub",
        scanner=_Scanner(_Cache(entries)),
        metadata_class=BaseModelMetadata,
    )

    # limit=1 keeps a single entry; the LoRa bucket (count 2) is the one
    # returned here.
    types = await service.get_model_types(limit=1)
    assert types == [{"type": "LoRa", "count": 2}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"service_cls, extra_fields",
|
||||
|
||||
Reference in New Issue
Block a user