diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 2071264b..fa68f55f 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -167,6 +167,19 @@ class ModelListingHandler: pass update_available_only = request.query.get("update_available_only", "false").lower() == "true" + + # New license-based query filters + credit_required = request.query.get("credit_required") + if credit_required is not None: + credit_required = credit_required.lower() not in ("false", "0", "") + else: + credit_required = None # None means no filter applied + + allow_selling_generated_content = request.query.get("allow_selling_generated_content") + if allow_selling_generated_content is not None: + allow_selling_generated_content = allow_selling_generated_content.lower() not in ("false", "0", "") + else: + allow_selling_generated_content = None # None means no filter applied return { "page": page, @@ -181,6 +194,8 @@ class ModelListingHandler: "hash_filters": hash_filters, "favorites_only": favorites_only, "update_available_only": update_available_only, + "credit_required": credit_required, + "allow_selling_generated_content": allow_selling_generated_content, **self._parse_specific_params(request), } diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index dee143aa..d90c2ed8 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -64,6 +64,8 @@ class BaseModelService(ABC): hash_filters: dict = None, favorites_only: bool = False, update_available_only: bool = False, + credit_required: Optional[bool] = None, + allow_selling_generated_content: Optional[bool] = None, **kwargs, ) -> Dict: """Get paginated and filtered model data""" @@ -93,6 +95,13 @@ class BaseModelService(ABC): filtered_data = await self._apply_specific_filters(filtered_data, **kwargs) + # Apply license-based filters + if credit_required is not None: + filtered_data = await self._apply_credit_required_filter(filtered_data, credit_required) + + if allow_selling_generated_content is not None: + filtered_data = await self._apply_allow_selling_filter(filtered_data, allow_selling_generated_content) + annotated_for_filter: Optional[List[Dict]] = None if update_available_only: annotated_for_filter = await self._annotate_update_flags(filtered_data) @@ -170,6 +179,61 @@ class BaseModelService(ABC): """Apply model-specific filters - to be overridden by subclasses if needed""" return data + async def _apply_credit_required_filter(self, data: List[Dict], credit_required: bool) -> List[Dict]: + """Apply credit required filtering based on license_flags. + + Args: + data: List of model data items + credit_required: + - True: Return items where credit is required (allowNoCredit=False) + - False: Return items where credit is not required (allowNoCredit=True) + """ + filtered_data = [] + for item in data: + license_flags = item.get("license_flags", 127) # Default to all permissions enabled + + # Bit 0 represents allowNoCredit (1 = no credit required, 0 = credit required) + allow_no_credit = bool(license_flags & (1 << 0)) + + # If credit_required is True, we want items where allowNoCredit is False (credit required) + # If credit_required is False, we want items where allowNoCredit is True (no credit required) + if credit_required: + if not allow_no_credit: # Credit is required + filtered_data.append(item) + else: + if allow_no_credit: # Credit is not required + filtered_data.append(item) + + return filtered_data + + async def _apply_allow_selling_filter(self, data: List[Dict], allow_selling: bool) -> List[Dict]: + """Apply allow selling generated content filtering based on license_flags. + + Args: + data: List of model data items + allow_selling: + - True: Return items where selling generated content is allowed (allowCommercialUse contains Image) + - False: Return items where selling generated content is not allowed (allowCommercialUse does not contain Image) + """ + filtered_data = [] + for item in data: + license_flags = item.get("license_flags", 127) # Default to all permissions enabled + + # Bits 1-4 represent commercial use permissions + # Bit 1 specifically represents Image permission (allowCommercialUse contains Image) + has_image_permission = bool(license_flags & (1 << 1)) + + # If allow_selling is True, we want items where Image permission is granted + # If allow_selling is False, we want items where Image permission is not granted + if allow_selling: + if has_image_permission: # Selling generated content is allowed + filtered_data.append(item) + else: + if not has_image_permission: # Selling generated content is not allowed + filtered_data.append(item) + + return filtered_data + async def _annotate_update_flags( self, items: List[Dict], diff --git a/tests/services/test_license_filters.py b/tests/services/test_license_filters.py new file mode 100644 index 00000000..6419d88d --- /dev/null +++ b/tests/services/test_license_filters.py @@ -0,0 +1,146 @@ +"""Tests for license-based filtering functionality.""" + +import pytest +from unittest.mock import Mock, AsyncMock + +from py.services.base_model_service import BaseModelService +from py.utils.civitai_utils import build_license_flags + + +class DummyModelService(BaseModelService): + """Dummy implementation of BaseModelService for testing.""" + + def __init__(self): + # Mock the required attributes + self.model_type = "test" + self.scanner = Mock() + self.metadata_class = Mock() + self.settings = Mock() + self.cache_repository = Mock() + self.filter_set = Mock() + self.search_strategy = Mock() + + # Mock the scanner's get_cached_data to return a mock cache + async def mock_get_cached_data(): + cache_mock = Mock() + cache_mock.get_sorted_data = AsyncMock(return_value=[]) + return cache_mock + + self.scanner.get_cached_data = mock_get_cached_data + + async def format_response(self, model_data: dict) -> dict: + """Required abstract method implementation.""" + return model_data + + +@pytest.mark.asyncio +async def test_credit_required_filter(): + """Test the credit required filtering logic.""" + service = DummyModelService() + + # Create test data with different license flags + test_data = [ + # Model requiring credit (allowNoCredit = False) + {"file_path": "model1.safetensors", "license_flags": build_license_flags({"allowNoCredit": False})}, + # Model not requiring credit (allowNoCredit = True) + {"file_path": "model2.safetensors", "license_flags": build_license_flags({"allowNoCredit": True})}, + # Model with default license flags (allowNoCredit = True by default) + {"file_path": "model3.safetensors", "license_flags": build_license_flags(None)}, + ] + + # Test credit_required=True (should return models that require credit - allowNoCredit=False) + filtered = await service._apply_credit_required_filter(test_data, credit_required=True) + assert len(filtered) == 1 + assert filtered[0]["file_path"] == "model1.safetensors" + + # Test credit_required=False (should return models that don't require credit - allowNoCredit=True) + filtered = await service._apply_credit_required_filter(test_data, credit_required=False) + assert len(filtered) == 2 + file_paths = {item["file_path"] for item in filtered} + assert file_paths == {"model2.safetensors", "model3.safetensors"} + + +@pytest.mark.asyncio +async def test_allow_selling_filter(): + """Test the allow selling generated content filtering logic.""" + service = DummyModelService() + + # Create test data with different license flags + test_data = [ + # Model allowing selling (contains Image in allowCommercialUse) + {"file_path": "model1.safetensors", "license_flags": build_license_flags({"allowCommercialUse": ["Image"]})}, + # Model not allowing selling (doesn't contain Image in allowCommercialUse) + {"file_path": "model2.safetensors", "license_flags": build_license_flags({"allowCommercialUse": ["RentCivit"]})}, + # Model with default license flags (includes Sell by default, which implies Image) + {"file_path": "model3.safetensors", "license_flags": build_license_flags(None)}, + # Model allowing selling (contains Sell in allowCommercialUse, which implies Image) + {"file_path": "model4.safetensors", "license_flags": build_license_flags({"allowCommercialUse": ["Sell"]})}, + # Model with empty allowCommercialUse (doesn't allow selling) + {"file_path": "model5.safetensors", "license_flags": build_license_flags({"allowCommercialUse": []})}, + ] + + # Test allow_selling=True (should return models that allow selling - have Image permission) + # Default and Sell permissions both include Image, so model3 and model4 will be included + filtered = await service._apply_allow_selling_filter(test_data, allow_selling=True) + assert len(filtered) == 3 # model1, model3 (default includes Sell which implies Image), model4 + file_paths = {item["file_path"] for item in filtered} + assert file_paths == {"model1.safetensors", "model3.safetensors", "model4.safetensors"} + + # Test allow_selling=False (should return models that don't allow selling - don't have Image permission) + filtered = await service._apply_allow_selling_filter(test_data, allow_selling=False) + assert len(filtered) == 2 # model2 and model5 + file_paths = {item["file_path"] for item in filtered} + assert file_paths == {"model2.safetensors", "model5.safetensors"} + + +@pytest.mark.asyncio +async def test_combined_filters(): + """Test combining both credit required and allow selling filters.""" + service = DummyModelService() + + # Create test data + test_data = [ + # Requires credit AND allows selling + {"file_path": "model1.safetensors", "license_flags": build_license_flags({ + "allowNoCredit": False, + "allowCommercialUse": ["Image"] + })}, + # Requires credit AND doesn't allow selling + {"file_path": "model2.safetensors", "license_flags": build_license_flags({ + "allowNoCredit": False, + "allowCommercialUse": ["Rent"] + })}, + # Doesn't require credit AND allows selling + {"file_path": "model3.safetensors", "license_flags": build_license_flags({ + "allowNoCredit": True, + "allowCommercialUse": ["Image"] + })}, + # Doesn't require credit AND doesn't allow selling + {"file_path": "model4.safetensors", "license_flags": build_license_flags({ + "allowNoCredit": True, + "allowCommercialUse": ["Rent"] + })}, + ] + + # First apply credit_required=True filter (requires credit) + filtered = await service._apply_credit_required_filter(test_data, credit_required=True) + assert len(filtered) == 2 + file_paths = {item["file_path"] for item in filtered} + assert file_paths == {"model1.safetensors", "model2.safetensors"} + + # Then apply allow_selling=True filter (allows selling) to the result + filtered = await service._apply_allow_selling_filter(filtered, allow_selling=True) + assert len(filtered) == 1 + assert filtered[0]["file_path"] == "model1.safetensors" + + # Test the other combination + # First apply credit_required=False filter (doesn't require credit) + filtered = await service._apply_credit_required_filter(test_data, credit_required=False) + assert len(filtered) == 2 + file_paths = {item["file_path"] for item in filtered} + assert file_paths == {"model3.safetensors", "model4.safetensors"} + + # Then apply allow_selling=False filter (doesn't allow selling) to the result + filtered = await service._apply_allow_selling_filter(filtered, allow_selling=False) + assert len(filtered) == 1 + assert filtered[0]["file_path"] == "model4.safetensors" \ No newline at end of file diff --git a/tests/services/test_license_filters_integration.py b/tests/services/test_license_filters_integration.py new file mode 100644 index 00000000..4fd2ab7f --- /dev/null +++ b/tests/services/test_license_filters_integration.py @@ -0,0 +1,121 @@ +"""Integration tests for license-based filtering in BaseModelService.""" + +import pytest +from unittest.mock import Mock, AsyncMock + +from py.services.base_model_service import BaseModelService +from py.utils.civitai_utils import build_license_flags +from py.services.model_query import ModelCacheRepository, ModelFilterSet, SearchStrategy, SortParams + + +class DummyModelService(BaseModelService): + """Dummy implementation of BaseModelService for testing.""" + + def __init__(self): + # Mock the required attributes + self.model_type = "test" + self.scanner = Mock() + self.metadata_class = Mock() + self.settings = Mock() + self.update_service = None # Add the missing attribute + + # Mock the cache repository + self.cache_repository = ModelCacheRepository(self.scanner) + self.filter_set = ModelFilterSet(self.settings) + self.search_strategy = SearchStrategy() + + # Mock the scanner's get_cached_data to return a mock cache + self.cache_mock = Mock() + self.cache_mock.get_sorted_data = AsyncMock(return_value=[]) + + async def mock_get_cached_data(): + return self.cache_mock + + self.scanner.get_cached_data = mock_get_cached_data + + async def format_response(self, model_data: dict) -> dict: + """Required abstract method implementation.""" + return model_data + + +@pytest.mark.asyncio +async def test_get_paginated_data_with_license_filters(): + """Test that license filters are applied in get_paginated_data.""" + service = DummyModelService() + + # Create test data with different license flags + test_data = [ + # Model requiring credit AND allowing selling + {"file_path": "model1.safetensors", "license_flags": build_license_flags({ + "allowNoCredit": False, + "allowCommercialUse": ["Image"] + })}, + # Model requiring credit AND not allowing selling + {"file_path": "model2.safetensors", "license_flags": build_license_flags({ + "allowNoCredit": False, + "allowCommercialUse": ["Rent"] + })}, + # Model not requiring credit AND allowing selling + {"file_path": "model3.safetensors", "license_flags": build_license_flags({ + "allowNoCredit": True, + "allowCommercialUse": ["Image"] + })}, + # Model not requiring credit AND not allowing selling + {"file_path": "model4.safetensors", "license_flags": build_license_flags({ + "allowNoCredit": True, + "allowCommercialUse": ["Rent"] + })}, + ] + + # Mock the sorted data + service.cache_mock.get_sorted_data = AsyncMock(return_value=test_data) + + # Test with credit_required=True + result = await service.get_paginated_data( + page=1, + page_size=10, + credit_required=True + ) + assert len(result["items"]) == 2 + file_paths = {item["file_path"] for item in result["items"]} + assert file_paths == {"model1.safetensors", "model2.safetensors"} + + # Test with credit_required=False + result = await service.get_paginated_data( + page=1, + page_size=10, + credit_required=False + ) + assert len(result["items"]) == 2 + file_paths = {item["file_path"] for item in result["items"]} + assert file_paths == {"model3.safetensors", "model4.safetensors"} + + # Test with allow_selling_generated_content=True + result = await service.get_paginated_data( + page=1, + page_size=10, + allow_selling_generated_content=True + ) + assert len(result["items"]) == 2 + file_paths = {item["file_path"] for item in result["items"]} + assert file_paths == {"model1.safetensors", "model3.safetensors"} + + # Test with allow_selling_generated_content=False + result = await service.get_paginated_data( + page=1, + page_size=10, + allow_selling_generated_content=False + ) + assert len(result["items"]) == 2 + file_paths = {item["file_path"] for item in result["items"]} + assert file_paths == {"model2.safetensors", "model4.safetensors"} + + # Test with both filters + result = await service.get_paginated_data( + page=1, + page_size=10, + credit_required=True, + allow_selling_generated_content=True + ) + assert len(result["items"]) == 1 + assert result["items"][0]["file_path"] == "model1.safetensors" \ No newline at end of file