mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: add license-based filtering for model listings
Add support for filtering models by license requirements: - credit_required: filter models that require credits or allow free use - allow_selling_generated_content: filter models based on commercial usage rights These filters use license_flags bitmask to determine model permissions and enable users to find models that match their specific usage requirements and budget constraints.
This commit is contained in:
@@ -167,6 +167,19 @@ class ModelListingHandler:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
update_available_only = request.query.get("update_available_only", "false").lower() == "true"
|
update_available_only = request.query.get("update_available_only", "false").lower() == "true"
|
||||||
|
|
||||||
|
# New license-based query filters
|
||||||
|
credit_required = request.query.get("credit_required")
|
||||||
|
if credit_required is not None:
|
||||||
|
credit_required = credit_required.lower() not in ("false", "0", "")
|
||||||
|
else:
|
||||||
|
credit_required = None # None means no filter applied
|
||||||
|
|
||||||
|
allow_selling_generated_content = request.query.get("allow_selling_generated_content")
|
||||||
|
if allow_selling_generated_content is not None:
|
||||||
|
allow_selling_generated_content = allow_selling_generated_content.lower() not in ("false", "0", "")
|
||||||
|
else:
|
||||||
|
allow_selling_generated_content = None # None means no filter applied
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"page": page,
|
"page": page,
|
||||||
@@ -181,6 +194,8 @@ class ModelListingHandler:
|
|||||||
"hash_filters": hash_filters,
|
"hash_filters": hash_filters,
|
||||||
"favorites_only": favorites_only,
|
"favorites_only": favorites_only,
|
||||||
"update_available_only": update_available_only,
|
"update_available_only": update_available_only,
|
||||||
|
"credit_required": credit_required,
|
||||||
|
"allow_selling_generated_content": allow_selling_generated_content,
|
||||||
**self._parse_specific_params(request),
|
**self._parse_specific_params(request),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ class BaseModelService(ABC):
|
|||||||
hash_filters: dict = None,
|
hash_filters: dict = None,
|
||||||
favorites_only: bool = False,
|
favorites_only: bool = False,
|
||||||
update_available_only: bool = False,
|
update_available_only: bool = False,
|
||||||
|
credit_required: Optional[bool] = None,
|
||||||
|
allow_selling_generated_content: Optional[bool] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""Get paginated and filtered model data"""
|
"""Get paginated and filtered model data"""
|
||||||
@@ -93,6 +95,13 @@ class BaseModelService(ABC):
|
|||||||
|
|
||||||
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||||
|
|
||||||
|
# Apply license-based filters
|
||||||
|
if credit_required is not None:
|
||||||
|
filtered_data = await self._apply_credit_required_filter(filtered_data, credit_required)
|
||||||
|
|
||||||
|
if allow_selling_generated_content is not None:
|
||||||
|
filtered_data = await self._apply_allow_selling_filter(filtered_data, allow_selling_generated_content)
|
||||||
|
|
||||||
annotated_for_filter: Optional[List[Dict]] = None
|
annotated_for_filter: Optional[List[Dict]] = None
|
||||||
if update_available_only:
|
if update_available_only:
|
||||||
annotated_for_filter = await self._annotate_update_flags(filtered_data)
|
annotated_for_filter = await self._annotate_update_flags(filtered_data)
|
||||||
@@ -170,6 +179,61 @@ class BaseModelService(ABC):
|
|||||||
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
async def _apply_credit_required_filter(self, data: List[Dict], credit_required: bool) -> List[Dict]:
|
||||||
|
"""Apply credit required filtering based on license_flags.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: List of model data items
|
||||||
|
credit_required:
|
||||||
|
- True: Return items where credit is required (allowNoCredit=False)
|
||||||
|
- False: Return items where credit is not required (allowNoCredit=True)
|
||||||
|
"""
|
||||||
|
filtered_data = []
|
||||||
|
for item in data:
|
||||||
|
license_flags = item.get("license_flags", 127) # Default to all permissions enabled
|
||||||
|
|
||||||
|
# Bit 0 represents allowNoCredit (1 = no credit required, 0 = credit required)
|
||||||
|
allow_no_credit = bool(license_flags & (1 << 0))
|
||||||
|
|
||||||
|
# If credit_required is True, we want items where allowNoCredit is False (credit required)
|
||||||
|
# If credit_required is False, we want items where allowNoCredit is True (no credit required)
|
||||||
|
if credit_required:
|
||||||
|
if not allow_no_credit: # Credit is required
|
||||||
|
filtered_data.append(item)
|
||||||
|
else:
|
||||||
|
if allow_no_credit: # Credit is not required
|
||||||
|
filtered_data.append(item)
|
||||||
|
|
||||||
|
return filtered_data
|
||||||
|
|
||||||
|
async def _apply_allow_selling_filter(self, data: List[Dict], allow_selling: bool) -> List[Dict]:
|
||||||
|
"""Apply allow selling generated content filtering based on license_flags.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: List of model data items
|
||||||
|
allow_selling:
|
||||||
|
- True: Return items where selling generated content is allowed (allowCommercialUse contains Image)
|
||||||
|
- False: Return items where selling generated content is not allowed (allowCommercialUse does not contain Image)
|
||||||
|
"""
|
||||||
|
filtered_data = []
|
||||||
|
for item in data:
|
||||||
|
license_flags = item.get("license_flags", 127) # Default to all permissions enabled
|
||||||
|
|
||||||
|
# Bits 1-4 represent commercial use permissions
|
||||||
|
# Bit 1 specifically represents Image permission (allowCommercialUse contains Image)
|
||||||
|
has_image_permission = bool(license_flags & (1 << 1))
|
||||||
|
|
||||||
|
# If allow_selling is True, we want items where Image permission is granted
|
||||||
|
# If allow_selling is False, we want items where Image permission is not granted
|
||||||
|
if allow_selling:
|
||||||
|
if has_image_permission: # Selling generated content is allowed
|
||||||
|
filtered_data.append(item)
|
||||||
|
else:
|
||||||
|
if not has_image_permission: # Selling generated content is not allowed
|
||||||
|
filtered_data.append(item)
|
||||||
|
|
||||||
|
return filtered_data
|
||||||
|
|
||||||
async def _annotate_update_flags(
|
async def _annotate_update_flags(
|
||||||
self,
|
self,
|
||||||
items: List[Dict],
|
items: List[Dict],
|
||||||
|
|||||||
146
tests/services/test_license_filters.py
Normal file
146
tests/services/test_license_filters.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""Tests for license-based filtering functionality."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, AsyncMock
|
||||||
|
|
||||||
|
from py.services.base_model_service import BaseModelService
|
||||||
|
from py.utils.civitai_utils import build_license_flags
|
||||||
|
|
||||||
|
|
||||||
|
class DummyModelService(BaseModelService):
|
||||||
|
"""Dummy implementation of BaseModelService for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Mock the required attributes
|
||||||
|
self.model_type = "test"
|
||||||
|
self.scanner = Mock()
|
||||||
|
self.metadata_class = Mock()
|
||||||
|
self.settings = Mock()
|
||||||
|
self.cache_repository = Mock()
|
||||||
|
self.filter_set = Mock()
|
||||||
|
self.search_strategy = Mock()
|
||||||
|
|
||||||
|
# Mock the scanner's get_cached_data to return a mock cache
|
||||||
|
async def mock_get_cached_data():
|
||||||
|
cache_mock = Mock()
|
||||||
|
cache_mock.get_sorted_data = AsyncMock(return_value=[])
|
||||||
|
return cache_mock
|
||||||
|
|
||||||
|
self.scanner.get_cached_data = mock_get_cached_data
|
||||||
|
|
||||||
|
async def format_response(self, model_data: dict) -> dict:
|
||||||
|
"""Required abstract method implementation."""
|
||||||
|
return model_data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_credit_required_filter():
|
||||||
|
"""Test the credit required filtering logic."""
|
||||||
|
service = DummyModelService()
|
||||||
|
|
||||||
|
# Create test data with different license flags
|
||||||
|
test_data = [
|
||||||
|
# Model requiring credit (allowNoCredit = False)
|
||||||
|
{"file_path": "model1.safetensors", "license_flags": build_license_flags({"allowNoCredit": False})},
|
||||||
|
# Model not requiring credit (allowNoCredit = True)
|
||||||
|
{"file_path": "model2.safetensors", "license_flags": build_license_flags({"allowNoCredit": True})},
|
||||||
|
# Model with default license flags (allowNoCredit = True by default)
|
||||||
|
{"file_path": "model3.safetensors", "license_flags": build_license_flags(None)},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test credit_required=True (should return models that require credit - allowNoCredit=False)
|
||||||
|
filtered = await service._apply_credit_required_filter(test_data, credit_required=True)
|
||||||
|
assert len(filtered) == 1
|
||||||
|
assert filtered[0]["file_path"] == "model1.safetensors"
|
||||||
|
|
||||||
|
# Test credit_required=False (should return models that don't require credit - allowNoCredit=True)
|
||||||
|
filtered = await service._apply_credit_required_filter(test_data, credit_required=False)
|
||||||
|
assert len(filtered) == 2
|
||||||
|
file_paths = {item["file_path"] for item in filtered}
|
||||||
|
assert file_paths == {"model2.safetensors", "model3.safetensors"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_allow_selling_filter():
|
||||||
|
"""Test the allow selling generated content filtering logic."""
|
||||||
|
service = DummyModelService()
|
||||||
|
|
||||||
|
# Create test data with different license flags
|
||||||
|
test_data = [
|
||||||
|
# Model allowing selling (contains Image in allowCommercialUse)
|
||||||
|
{"file_path": "model1.safetensors", "license_flags": build_license_flags({"allowCommercialUse": ["Image"]})},
|
||||||
|
# Model not allowing selling (doesn't contain Image in allowCommercialUse)
|
||||||
|
{"file_path": "model2.safetensors", "license_flags": build_license_flags({"allowCommercialUse": ["RentCivit"]})},
|
||||||
|
# Model with default license flags (includes Sell by default, which implies Image)
|
||||||
|
{"file_path": "model3.safetensors", "license_flags": build_license_flags(None)},
|
||||||
|
# Model allowing selling (contains Sell in allowCommercialUse, which implies Image)
|
||||||
|
{"file_path": "model4.safetensors", "license_flags": build_license_flags({"allowCommercialUse": ["Sell"]})},
|
||||||
|
# Model with empty allowCommercialUse (doesn't allow selling)
|
||||||
|
{"file_path": "model5.safetensors", "license_flags": build_license_flags({"allowCommercialUse": []})},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test allow_selling=True (should return models that allow selling - have Image permission)
|
||||||
|
# Default and Sell permissions both include Image, so model3 and model4 will be included
|
||||||
|
filtered = await service._apply_allow_selling_filter(test_data, allow_selling=True)
|
||||||
|
assert len(filtered) == 3 # model1, model3 (default includes Sell which implies Image), model4
|
||||||
|
file_paths = {item["file_path"] for item in filtered}
|
||||||
|
assert file_paths == {"model1.safetensors", "model3.safetensors", "model4.safetensors"}
|
||||||
|
|
||||||
|
# Test allow_selling=False (should return models that don't allow selling - don't have Image permission)
|
||||||
|
filtered = await service._apply_allow_selling_filter(test_data, allow_selling=False)
|
||||||
|
assert len(filtered) == 2 # model2 and model5
|
||||||
|
file_paths = {item["file_path"] for item in filtered}
|
||||||
|
assert file_paths == {"model2.safetensors", "model5.safetensors"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_combined_filters():
|
||||||
|
"""Test combining both credit required and allow selling filters."""
|
||||||
|
service = DummyModelService()
|
||||||
|
|
||||||
|
# Create test data
|
||||||
|
test_data = [
|
||||||
|
# Requires credit AND allows selling
|
||||||
|
{"file_path": "model1.safetensors", "license_flags": build_license_flags({
|
||||||
|
"allowNoCredit": False,
|
||||||
|
"allowCommercialUse": ["Image"]
|
||||||
|
})},
|
||||||
|
# Requires credit AND doesn't allow selling
|
||||||
|
{"file_path": "model2.safetensors", "license_flags": build_license_flags({
|
||||||
|
"allowNoCredit": False,
|
||||||
|
"allowCommercialUse": ["Rent"]
|
||||||
|
})},
|
||||||
|
# Doesn't require credit AND allows selling
|
||||||
|
{"file_path": "model3.safetensors", "license_flags": build_license_flags({
|
||||||
|
"allowNoCredit": True,
|
||||||
|
"allowCommercialUse": ["Image"]
|
||||||
|
})},
|
||||||
|
# Doesn't require credit AND doesn't allow selling
|
||||||
|
{"file_path": "model4.safetensors", "license_flags": build_license_flags({
|
||||||
|
"allowNoCredit": True,
|
||||||
|
"allowCommercialUse": ["Rent"]
|
||||||
|
})},
|
||||||
|
]
|
||||||
|
|
||||||
|
# First apply credit_required=True filter (requires credit)
|
||||||
|
filtered = await service._apply_credit_required_filter(test_data, credit_required=True)
|
||||||
|
assert len(filtered) == 2
|
||||||
|
file_paths = {item["file_path"] for item in filtered}
|
||||||
|
assert file_paths == {"model1.safetensors", "model2.safetensors"}
|
||||||
|
|
||||||
|
# Then apply allow_selling=True filter (allows selling) to the result
|
||||||
|
filtered = await service._apply_allow_selling_filter(filtered, allow_selling=True)
|
||||||
|
assert len(filtered) == 1
|
||||||
|
assert filtered[0]["file_path"] == "model1.safetensors"
|
||||||
|
|
||||||
|
# Test the other combination
|
||||||
|
# First apply credit_required=False filter (doesn't require credit)
|
||||||
|
filtered = await service._apply_credit_required_filter(test_data, credit_required=False)
|
||||||
|
assert len(filtered) == 2
|
||||||
|
file_paths = {item["file_path"] for item in filtered}
|
||||||
|
assert file_paths == {"model3.safetensors", "model4.safetensors"}
|
||||||
|
|
||||||
|
# Then apply allow_selling=False filter (doesn't allow selling) to the result
|
||||||
|
filtered = await service._apply_allow_selling_filter(filtered, allow_selling=False)
|
||||||
|
assert len(filtered) == 1
|
||||||
|
assert filtered[0]["file_path"] == "model4.safetensors"
|
||||||
121
tests/services/test_license_filters_integration.py
Normal file
121
tests/services/test_license_filters_integration.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""Integration tests for license-based filtering in BaseModelService."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, AsyncMock
|
||||||
|
|
||||||
|
from py.services.base_model_service import BaseModelService
|
||||||
|
from py.utils.civitai_utils import build_license_flags
|
||||||
|
from py.services.model_query import ModelCacheRepository, ModelFilterSet, SearchStrategy, SortParams
|
||||||
|
|
||||||
|
|
||||||
|
class DummyModelService(BaseModelService):
|
||||||
|
"""Dummy implementation of BaseModelService for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Mock the required attributes
|
||||||
|
self.model_type = "test"
|
||||||
|
self.scanner = Mock()
|
||||||
|
self.metadata_class = Mock()
|
||||||
|
self.settings = Mock()
|
||||||
|
self.update_service = None # Add the missing attribute
|
||||||
|
|
||||||
|
# Mock the cache repository
|
||||||
|
self.cache_repository = ModelCacheRepository(self.scanner)
|
||||||
|
self.filter_set = ModelFilterSet(self.settings)
|
||||||
|
self.search_strategy = SearchStrategy()
|
||||||
|
|
||||||
|
# Mock the scanner's get_cached_data to return a mock cache
|
||||||
|
self.cache_mock = Mock()
|
||||||
|
self.cache_mock.get_sorted_data = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
async def mock_get_cached_data():
|
||||||
|
return self.cache_mock
|
||||||
|
|
||||||
|
self.scanner.get_cached_data = mock_get_cached_data
|
||||||
|
|
||||||
|
async def format_response(self, model_data: dict) -> dict:
|
||||||
|
"""Required abstract method implementation."""
|
||||||
|
return model_data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_paginated_data_with_license_filters():
|
||||||
|
"""Test that license filters are applied in get_paginated_data."""
|
||||||
|
service = DummyModelService()
|
||||||
|
|
||||||
|
# Create test data with different license flags
|
||||||
|
test_data = [
|
||||||
|
# Model requiring credit AND allowing selling
|
||||||
|
{"file_path": "model1.safetensors", "license_flags": build_license_flags({
|
||||||
|
"allowNoCredit": False,
|
||||||
|
"allowCommercialUse": ["Image"]
|
||||||
|
})},
|
||||||
|
# Model requiring credit AND not allowing selling
|
||||||
|
{"file_path": "model2.safetensors", "license_flags": build_license_flags({
|
||||||
|
"allowNoCredit": False,
|
||||||
|
"allowCommercialUse": ["Rent"]
|
||||||
|
})},
|
||||||
|
# Model not requiring credit AND allowing selling
|
||||||
|
{"file_path": "model3.safetensors", "license_flags": build_license_flags({
|
||||||
|
"allowNoCredit": True,
|
||||||
|
"allowCommercialUse": ["Image"]
|
||||||
|
})},
|
||||||
|
# Model not requiring credit AND not allowing selling
|
||||||
|
{"file_path": "model4.safetensors", "license_flags": build_license_flags({
|
||||||
|
"allowNoCredit": True,
|
||||||
|
"allowCommercialUse": ["Rent"]
|
||||||
|
})},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mock the sorted data
|
||||||
|
service.cache_mock.get_sorted_data = AsyncMock(return_value=test_data)
|
||||||
|
|
||||||
|
# Test with credit_required=True
|
||||||
|
result = await service.get_paginated_data(
|
||||||
|
page=1,
|
||||||
|
page_size=10,
|
||||||
|
credit_required=True
|
||||||
|
)
|
||||||
|
assert len(result["items"]) == 2
|
||||||
|
file_paths = {item["file_path"] for item in result["items"]}
|
||||||
|
assert file_paths == {"model1.safetensors", "model2.safetensors"}
|
||||||
|
|
||||||
|
# Test with credit_required=False
|
||||||
|
result = await service.get_paginated_data(
|
||||||
|
page=1,
|
||||||
|
page_size=10,
|
||||||
|
credit_required=False
|
||||||
|
)
|
||||||
|
assert len(result["items"]) == 2
|
||||||
|
file_paths = {item["file_path"] for item in result["items"]}
|
||||||
|
assert file_paths == {"model3.safetensors", "model4.safetensors"}
|
||||||
|
|
||||||
|
# Test with allow_selling_generated_content=True
|
||||||
|
result = await service.get_paginated_data(
|
||||||
|
page=1,
|
||||||
|
page_size=10,
|
||||||
|
allow_selling_generated_content=True
|
||||||
|
)
|
||||||
|
assert len(result["items"]) == 2
|
||||||
|
file_paths = {item["file_path"] for item in result["items"]}
|
||||||
|
assert file_paths == {"model1.safetensors", "model3.safetensors"}
|
||||||
|
|
||||||
|
# Test with allow_selling_generated_content=False
|
||||||
|
result = await service.get_paginated_data(
|
||||||
|
page=1,
|
||||||
|
page_size=10,
|
||||||
|
allow_selling_generated_content=False
|
||||||
|
)
|
||||||
|
assert len(result["items"]) == 2
|
||||||
|
file_paths = {item["file_path"] for item in result["items"]}
|
||||||
|
assert file_paths == {"model2.safetensors", "model4.safetensors"}
|
||||||
|
|
||||||
|
# Test with both filters
|
||||||
|
result = await service.get_paginated_data(
|
||||||
|
page=1,
|
||||||
|
page_size=10,
|
||||||
|
credit_required=True,
|
||||||
|
allow_selling_generated_content=True
|
||||||
|
)
|
||||||
|
assert len(result["items"]) == 1
|
||||||
|
assert result["items"][0]["file_path"] == "model1.safetensors"
|
||||||
Reference in New Issue
Block a user