mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
License filter
This commit is contained in:
@@ -188,6 +188,9 @@
|
||||
"title": "Modelle filtern",
|
||||
"baseModel": "Basis-Modell",
|
||||
"modelTags": "Tags (Top 20)",
|
||||
"license": "Lizenz",
|
||||
"noCreditRequired": "Kein Credit erforderlich",
|
||||
"allowSellingGeneratedContent": "Verkauf erlaubt",
|
||||
"clearAll": "Alle Filter löschen"
|
||||
},
|
||||
"theme": {
|
||||
|
||||
@@ -188,6 +188,9 @@
|
||||
"title": "Filter Models",
|
||||
"baseModel": "Base Model",
|
||||
"modelTags": "Tags (Top 20)",
|
||||
"license": "License",
|
||||
"noCreditRequired": "No Credit Required",
|
||||
"allowSellingGeneratedContent": "Allow Selling",
|
||||
"clearAll": "Clear All Filters"
|
||||
},
|
||||
"theme": {
|
||||
|
||||
@@ -188,6 +188,9 @@
|
||||
"title": "Filtrar modelos",
|
||||
"baseModel": "Modelo base",
|
||||
"modelTags": "Etiquetas (Top 20)",
|
||||
"license": "Licencia",
|
||||
"noCreditRequired": "Sin crédito requerido",
|
||||
"allowSellingGeneratedContent": "Venta permitida",
|
||||
"clearAll": "Limpiar todos los filtros"
|
||||
},
|
||||
"theme": {
|
||||
|
||||
@@ -188,6 +188,9 @@
|
||||
"title": "Filtrer les modèles",
|
||||
"baseModel": "Modèle de base",
|
||||
"modelTags": "Tags (Top 20)",
|
||||
"license": "Licence",
|
||||
"noCreditRequired": "Crédit non requis",
|
||||
"allowSellingGeneratedContent": "Vente autorisée",
|
||||
"clearAll": "Effacer tous les filtres"
|
||||
},
|
||||
"theme": {
|
||||
|
||||
@@ -188,6 +188,9 @@
|
||||
"title": "סנן מודלים",
|
||||
"baseModel": "מודל בסיס",
|
||||
"modelTags": "תגיות (20 המובילות)",
|
||||
"license": "רישיון",
|
||||
"noCreditRequired": "ללא קרדיט נדרש",
|
||||
"allowSellingGeneratedContent": "אפשר מכירה",
|
||||
"clearAll": "נקה את כל המסננים"
|
||||
},
|
||||
"theme": {
|
||||
|
||||
@@ -188,6 +188,9 @@
|
||||
"title": "モデルをフィルタ",
|
||||
"baseModel": "ベースモデル",
|
||||
"modelTags": "タグ(上位20)",
|
||||
"license": "ライセンス",
|
||||
"noCreditRequired": "クレジット不要",
|
||||
"allowSellingGeneratedContent": "販売許可",
|
||||
"clearAll": "すべてのフィルタをクリア"
|
||||
},
|
||||
"theme": {
|
||||
|
||||
@@ -188,6 +188,9 @@
|
||||
"title": "모델 필터",
|
||||
"baseModel": "베이스 모델",
|
||||
"modelTags": "태그 (상위 20개)",
|
||||
"license": "라이선스",
|
||||
"noCreditRequired": "크레딧 표기 없음",
|
||||
"allowSellingGeneratedContent": "판매 허용",
|
||||
"clearAll": "모든 필터 지우기"
|
||||
},
|
||||
"theme": {
|
||||
|
||||
@@ -188,6 +188,9 @@
|
||||
"title": "Фильтр моделей",
|
||||
"baseModel": "Базовая модель",
|
||||
"modelTags": "Теги (Топ 20)",
|
||||
"license": "Лицензия",
|
||||
"noCreditRequired": "Без указания авторства",
|
||||
"allowSellingGeneratedContent": "Продажа разрешена",
|
||||
"clearAll": "Очистить все фильтры"
|
||||
},
|
||||
"theme": {
|
||||
|
||||
@@ -188,6 +188,9 @@
|
||||
"title": "筛选模型",
|
||||
"baseModel": "基础模型",
|
||||
"modelTags": "标签(前20)",
|
||||
"license": "许可证",
|
||||
"noCreditRequired": "无需署名",
|
||||
"allowSellingGeneratedContent": "允许销售",
|
||||
"clearAll": "清除所有筛选"
|
||||
},
|
||||
"theme": {
|
||||
|
||||
@@ -188,6 +188,9 @@
|
||||
"title": "篩選模型",
|
||||
"baseModel": "基礎模型",
|
||||
"modelTags": "標籤(前 20)",
|
||||
"license": "授權",
|
||||
"noCreditRequired": "無需署名",
|
||||
"allowSellingGeneratedContent": "允許銷售",
|
||||
"clearAll": "清除所有篩選"
|
||||
},
|
||||
"theme": {
|
||||
|
||||
@@ -167,6 +167,19 @@ class ModelListingHandler:
|
||||
pass
|
||||
|
||||
update_available_only = request.query.get("update_available_only", "false").lower() == "true"
|
||||
|
||||
# New license-based query filters
|
||||
credit_required = request.query.get("credit_required")
|
||||
if credit_required is not None:
|
||||
credit_required = credit_required.lower() not in ("false", "0", "")
|
||||
else:
|
||||
credit_required = None # None means no filter applied
|
||||
|
||||
allow_selling_generated_content = request.query.get("allow_selling_generated_content")
|
||||
if allow_selling_generated_content is not None:
|
||||
allow_selling_generated_content = allow_selling_generated_content.lower() not in ("false", "0", "")
|
||||
else:
|
||||
allow_selling_generated_content = None # None means no filter applied
|
||||
|
||||
return {
|
||||
"page": page,
|
||||
@@ -181,6 +194,8 @@ class ModelListingHandler:
|
||||
"hash_filters": hash_filters,
|
||||
"favorites_only": favorites_only,
|
||||
"update_available_only": update_available_only,
|
||||
"credit_required": credit_required,
|
||||
"allow_selling_generated_content": allow_selling_generated_content,
|
||||
**self._parse_specific_params(request),
|
||||
}
|
||||
|
||||
|
||||
@@ -64,6 +64,8 @@ class BaseModelService(ABC):
|
||||
hash_filters: dict = None,
|
||||
favorites_only: bool = False,
|
||||
update_available_only: bool = False,
|
||||
credit_required: Optional[bool] = None,
|
||||
allow_selling_generated_content: Optional[bool] = None,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
"""Get paginated and filtered model data"""
|
||||
@@ -93,6 +95,13 @@ class BaseModelService(ABC):
|
||||
|
||||
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||
|
||||
# Apply license-based filters
|
||||
if credit_required is not None:
|
||||
filtered_data = await self._apply_credit_required_filter(filtered_data, credit_required)
|
||||
|
||||
if allow_selling_generated_content is not None:
|
||||
filtered_data = await self._apply_allow_selling_filter(filtered_data, allow_selling_generated_content)
|
||||
|
||||
annotated_for_filter: Optional[List[Dict]] = None
|
||||
if update_available_only:
|
||||
annotated_for_filter = await self._annotate_update_flags(filtered_data)
|
||||
@@ -170,6 +179,61 @@ class BaseModelService(ABC):
|
||||
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||
return data
|
||||
|
||||
async def _apply_credit_required_filter(self, data: List[Dict], credit_required: bool) -> List[Dict]:
|
||||
"""Apply credit required filtering based on license_flags.
|
||||
|
||||
Args:
|
||||
data: List of model data items
|
||||
credit_required:
|
||||
- True: Return items where credit is required (allowNoCredit=False)
|
||||
- False: Return items where credit is not required (allowNoCredit=True)
|
||||
"""
|
||||
filtered_data = []
|
||||
for item in data:
|
||||
license_flags = item.get("license_flags", 127) # Default to all permissions enabled
|
||||
|
||||
# Bit 0 represents allowNoCredit (1 = no credit required, 0 = credit required)
|
||||
allow_no_credit = bool(license_flags & (1 << 0))
|
||||
|
||||
# If credit_required is True, we want items where allowNoCredit is False (credit required)
|
||||
# If credit_required is False, we want items where allowNoCredit is True (no credit required)
|
||||
if credit_required:
|
||||
if not allow_no_credit: # Credit is required
|
||||
filtered_data.append(item)
|
||||
else:
|
||||
if allow_no_credit: # Credit is not required
|
||||
filtered_data.append(item)
|
||||
|
||||
return filtered_data
|
||||
|
||||
async def _apply_allow_selling_filter(self, data: List[Dict], allow_selling: bool) -> List[Dict]:
|
||||
"""Apply allow selling generated content filtering based on license_flags.
|
||||
|
||||
Args:
|
||||
data: List of model data items
|
||||
allow_selling:
|
||||
- True: Return items where selling generated content is allowed (allowCommercialUse contains Image)
|
||||
- False: Return items where selling generated content is not allowed (allowCommercialUse does not contain Image)
|
||||
"""
|
||||
filtered_data = []
|
||||
for item in data:
|
||||
license_flags = item.get("license_flags", 127) # Default to all permissions enabled
|
||||
|
||||
# Bits 1-4 represent commercial use permissions
|
||||
# Bit 1 specifically represents Image permission (allowCommercialUse contains Image)
|
||||
has_image_permission = bool(license_flags & (1 << 1))
|
||||
|
||||
# If allow_selling is True, we want items where Image permission is granted
|
||||
# If allow_selling is False, we want items where Image permission is not granted
|
||||
if allow_selling:
|
||||
if has_image_permission: # Selling generated content is allowed
|
||||
filtered_data.append(item)
|
||||
else:
|
||||
if not has_image_permission: # Selling generated content is not allowed
|
||||
filtered_data.append(item)
|
||||
|
||||
return filtered_data
|
||||
|
||||
async def _annotate_update_flags(
|
||||
self,
|
||||
items: List[Dict],
|
||||
|
||||
@@ -51,6 +51,8 @@ html, body {
|
||||
--lora-border: oklch(72% 0.03 256 / 0.45);
|
||||
--lora-text: oklch(95% 0.02 256);
|
||||
--lora-error: oklch(75% 0.32 29);
|
||||
--lora-error-bg: color-mix(in oklch, var(--lora-error) 20%, transparent);
|
||||
--lora-error-border: color-mix(in oklch, var(--lora-error) 50%, transparent);
|
||||
--lora-warning: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
|
||||
--lora-success: oklch(var(--lora-success-l) var(--lora-success-c) var(--lora-success-h));
|
||||
--badge-update-bg: oklch(72% 0.2 220);
|
||||
@@ -103,6 +105,8 @@ html[data-theme="light"] {
|
||||
--lora-border: oklch(90% 0.02 256 / 0.15);
|
||||
--lora-text: oklch(98% 0.02 256);
|
||||
--lora-warning: oklch(75% 0.25 80); /* Modified to be used with oklch() */
|
||||
--lora-error-bg: color-mix(in oklch, var(--lora-error) 15%, transparent);
|
||||
--lora-error-border: color-mix(in oklch, var(--lora-error) 40%, transparent);
|
||||
--badge-update-bg: oklch(62% 0.18 220);
|
||||
--badge-update-text: oklch(98% 0.02 240);
|
||||
--badge-update-glow: oklch(62% 0.18 220 / 0.4);
|
||||
|
||||
@@ -235,6 +235,13 @@
|
||||
border-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
/* Exclude state styling for filter tags */
|
||||
.filter-tag.exclude {
|
||||
background-color: var(--lora-error-bg);
|
||||
color: var(--lora-error);
|
||||
border-color: var(--lora-error-border);
|
||||
}
|
||||
|
||||
/* Tag filter styles */
|
||||
.tag-filter {
|
||||
display: flex;
|
||||
|
||||
@@ -817,6 +817,33 @@ export class BaseModelApiClient {
|
||||
params.append('base_model', model);
|
||||
});
|
||||
}
|
||||
|
||||
// Add license filters
|
||||
if (pageState.filters.license) {
|
||||
const licenseFilters = pageState.filters.license;
|
||||
|
||||
if (licenseFilters.noCredit) {
|
||||
// For noCredit filter:
|
||||
// - 'include' means credit_required=False (no credit required)
|
||||
// - 'exclude' means credit_required=True (credit required)
|
||||
if (licenseFilters.noCredit === 'include') {
|
||||
params.append('credit_required', 'false');
|
||||
} else if (licenseFilters.noCredit === 'exclude') {
|
||||
params.append('credit_required', 'true');
|
||||
}
|
||||
}
|
||||
|
||||
if (licenseFilters.allowSelling) {
|
||||
// For allowSelling filter:
|
||||
// - 'include' means allow_selling_generated_content=True
|
||||
// - 'exclude' means allow_selling_generated_content=False
|
||||
if (licenseFilters.allowSelling === 'include') {
|
||||
params.append('allow_selling_generated_content', 'true');
|
||||
} else if (licenseFilters.allowSelling === 'exclude') {
|
||||
params.append('allow_selling_generated_content', 'false');
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this._addModelSpecificParams(params, pageState);
|
||||
|
||||
@@ -14,7 +14,8 @@ export class FilterManager {
|
||||
|
||||
this.filters = pageState.filters || {
|
||||
baseModel: [],
|
||||
tags: []
|
||||
tags: [],
|
||||
license: {}
|
||||
};
|
||||
|
||||
this.filterPanel = document.getElementById('filterPanel');
|
||||
@@ -36,6 +37,9 @@ export class FilterManager {
|
||||
this.createBaseModelTags();
|
||||
}
|
||||
|
||||
// Add click handlers for license filter tags
|
||||
this.initializeLicenseFilters();
|
||||
|
||||
// Add click handler for filter button
|
||||
if (this.filterButton) {
|
||||
this.filterButton.addEventListener('click', () => {
|
||||
@@ -129,6 +133,85 @@ export class FilterManager {
|
||||
});
|
||||
}
|
||||
|
||||
initializeLicenseFilters() {
|
||||
const licenseTags = document.querySelectorAll('.license-tag');
|
||||
licenseTags.forEach(tag => {
|
||||
tag.addEventListener('click', async () => {
|
||||
const licenseType = tag.dataset.license;
|
||||
|
||||
// Ensure license object exists
|
||||
if (!this.filters.license) {
|
||||
this.filters.license = {};
|
||||
}
|
||||
|
||||
// Get current state
|
||||
let currentState = this.filters.license[licenseType] || 'none'; // none, include, exclude
|
||||
|
||||
// Cycle through states: none -> include -> exclude -> none
|
||||
let newState;
|
||||
switch (currentState) {
|
||||
case 'none':
|
||||
newState = 'include';
|
||||
tag.classList.remove('exclude');
|
||||
tag.classList.add('active');
|
||||
break;
|
||||
case 'include':
|
||||
newState = 'exclude';
|
||||
tag.classList.remove('active');
|
||||
tag.classList.add('exclude');
|
||||
break;
|
||||
case 'exclude':
|
||||
newState = 'none';
|
||||
tag.classList.remove('active', 'exclude');
|
||||
break;
|
||||
}
|
||||
|
||||
// Update filter state
|
||||
if (newState === 'none') {
|
||||
delete this.filters.license[licenseType];
|
||||
// Clean up empty license object
|
||||
if (Object.keys(this.filters.license).length === 0) {
|
||||
delete this.filters.license;
|
||||
}
|
||||
} else {
|
||||
this.filters.license[licenseType] = newState;
|
||||
}
|
||||
|
||||
this.updateActiveFiltersCount();
|
||||
|
||||
// Auto-apply filter when tag is clicked
|
||||
await this.applyFilters(false);
|
||||
});
|
||||
});
|
||||
|
||||
// Update selections based on stored filters
|
||||
this.updateLicenseSelections();
|
||||
}
|
||||
|
||||
updateLicenseSelections() {
|
||||
const licenseTags = document.querySelectorAll('.license-tag');
|
||||
licenseTags.forEach(tag => {
|
||||
const licenseType = tag.dataset.license;
|
||||
const state = (this.filters.license && this.filters.license[licenseType]) || 'none';
|
||||
|
||||
// Reset classes
|
||||
tag.classList.remove('active', 'exclude');
|
||||
|
||||
// Apply appropriate class based on state
|
||||
switch (state) {
|
||||
case 'include':
|
||||
tag.classList.add('active');
|
||||
break;
|
||||
case 'exclude':
|
||||
tag.classList.add('exclude');
|
||||
break;
|
||||
default:
|
||||
// none state - no classes needed
|
||||
break;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
createBaseModelTags() {
|
||||
const baseModelTagsContainer = document.getElementById('baseModelTags');
|
||||
if (!baseModelTagsContainer) return;
|
||||
@@ -233,10 +316,15 @@ export class FilterManager {
|
||||
tag.classList.remove('active');
|
||||
}
|
||||
});
|
||||
|
||||
// Update license tags
|
||||
this.updateLicenseSelections();
|
||||
}
|
||||
|
||||
updateActiveFiltersCount() {
|
||||
const totalActiveFilters = this.filters.baseModel.length + this.filters.tags.length;
|
||||
const totalActiveFilters = this.filters.baseModel.length +
|
||||
this.filters.tags.length +
|
||||
(this.filters.license ? Object.keys(this.filters.license).length : 0);
|
||||
|
||||
if (this.activeFiltersCount) {
|
||||
if (totalActiveFilters > 0) {
|
||||
@@ -296,7 +384,8 @@ export class FilterManager {
|
||||
// Clear all filters
|
||||
this.filters = {
|
||||
baseModel: [],
|
||||
tags: []
|
||||
tags: [],
|
||||
license: {} // Initialize with empty object instead of deleting
|
||||
};
|
||||
|
||||
// Update state
|
||||
@@ -337,7 +426,8 @@ export class FilterManager {
|
||||
// Ensure backward compatibility with older filter format
|
||||
this.filters = {
|
||||
baseModel: savedFilters.baseModel || [],
|
||||
tags: savedFilters.tags || []
|
||||
tags: savedFilters.tags || [],
|
||||
license: savedFilters.license || {}
|
||||
};
|
||||
|
||||
// Update state with loaded filters
|
||||
@@ -357,6 +447,8 @@ export class FilterManager {
|
||||
}
|
||||
|
||||
hasActiveFilters() {
|
||||
return this.filters.baseModel.length > 0 || this.filters.tags.length > 0;
|
||||
return this.filters.baseModel.length > 0 ||
|
||||
this.filters.tags.length > 0 ||
|
||||
(this.filters.license && Object.keys(this.filters.license).length > 0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -139,6 +139,17 @@
|
||||
<div class="tags-loading">{{ t('common.status.loading') }}</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="filter-section">
|
||||
<h4>{{ t('header.filter.license') }}</h4>
|
||||
<div class="filter-tags">
|
||||
<div class="filter-tag license-tag" data-license="noCredit">
|
||||
{{ t('header.filter.noCreditRequired') }}
|
||||
</div>
|
||||
<div class="filter-tag license-tag" data-license="allowSelling">
|
||||
{{ t('header.filter.allowSellingGeneratedContent') }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div class="filter-actions">
|
||||
<button class="clear-filters-btn" onclick="filterManager.clearFilters()">
|
||||
{{ t('header.filter.clearAll') }}
|
||||
|
||||
146
tests/services/test_license_filters.py
Normal file
146
tests/services/test_license_filters.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Tests for license-based filtering functionality."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
|
||||
from py.services.base_model_service import BaseModelService
|
||||
from py.utils.civitai_utils import build_license_flags
|
||||
|
||||
|
||||
class DummyModelService(BaseModelService):
|
||||
"""Dummy implementation of BaseModelService for testing."""
|
||||
|
||||
def __init__(self):
|
||||
# Mock the required attributes
|
||||
self.model_type = "test"
|
||||
self.scanner = Mock()
|
||||
self.metadata_class = Mock()
|
||||
self.settings = Mock()
|
||||
self.cache_repository = Mock()
|
||||
self.filter_set = Mock()
|
||||
self.search_strategy = Mock()
|
||||
|
||||
# Mock the scanner's get_cached_data to return a mock cache
|
||||
async def mock_get_cached_data():
|
||||
cache_mock = Mock()
|
||||
cache_mock.get_sorted_data = AsyncMock(return_value=[])
|
||||
return cache_mock
|
||||
|
||||
self.scanner.get_cached_data = mock_get_cached_data
|
||||
|
||||
async def format_response(self, model_data: dict) -> dict:
|
||||
"""Required abstract method implementation."""
|
||||
return model_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_credit_required_filter():
|
||||
"""Test the credit required filtering logic."""
|
||||
service = DummyModelService()
|
||||
|
||||
# Create test data with different license flags
|
||||
test_data = [
|
||||
# Model requiring credit (allowNoCredit = False)
|
||||
{"file_path": "model1.safetensors", "license_flags": build_license_flags({"allowNoCredit": False})},
|
||||
# Model not requiring credit (allowNoCredit = True)
|
||||
{"file_path": "model2.safetensors", "license_flags": build_license_flags({"allowNoCredit": True})},
|
||||
# Model with default license flags (allowNoCredit = True by default)
|
||||
{"file_path": "model3.safetensors", "license_flags": build_license_flags(None)},
|
||||
]
|
||||
|
||||
# Test credit_required=True (should return models that require credit - allowNoCredit=False)
|
||||
filtered = await service._apply_credit_required_filter(test_data, credit_required=True)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["file_path"] == "model1.safetensors"
|
||||
|
||||
# Test credit_required=False (should return models that don't require credit - allowNoCredit=True)
|
||||
filtered = await service._apply_credit_required_filter(test_data, credit_required=False)
|
||||
assert len(filtered) == 2
|
||||
file_paths = {item["file_path"] for item in filtered}
|
||||
assert file_paths == {"model2.safetensors", "model3.safetensors"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_selling_filter():
|
||||
"""Test the allow selling generated content filtering logic."""
|
||||
service = DummyModelService()
|
||||
|
||||
# Create test data with different license flags
|
||||
test_data = [
|
||||
# Model allowing selling (contains Image in allowCommercialUse)
|
||||
{"file_path": "model1.safetensors", "license_flags": build_license_flags({"allowCommercialUse": ["Image"]})},
|
||||
# Model not allowing selling (doesn't contain Image in allowCommercialUse)
|
||||
{"file_path": "model2.safetensors", "license_flags": build_license_flags({"allowCommercialUse": ["RentCivit"]})},
|
||||
# Model with default license flags (includes Sell by default, which implies Image)
|
||||
{"file_path": "model3.safetensors", "license_flags": build_license_flags(None)},
|
||||
# Model allowing selling (contains Sell in allowCommercialUse, which implies Image)
|
||||
{"file_path": "model4.safetensors", "license_flags": build_license_flags({"allowCommercialUse": ["Sell"]})},
|
||||
# Model with empty allowCommercialUse (doesn't allow selling)
|
||||
{"file_path": "model5.safetensors", "license_flags": build_license_flags({"allowCommercialUse": []})},
|
||||
]
|
||||
|
||||
# Test allow_selling=True (should return models that allow selling - have Image permission)
|
||||
# Default and Sell permissions both include Image, so model3 and model4 will be included
|
||||
filtered = await service._apply_allow_selling_filter(test_data, allow_selling=True)
|
||||
assert len(filtered) == 3 # model1, model3 (default includes Sell which implies Image), model4
|
||||
file_paths = {item["file_path"] for item in filtered}
|
||||
assert file_paths == {"model1.safetensors", "model3.safetensors", "model4.safetensors"}
|
||||
|
||||
# Test allow_selling=False (should return models that don't allow selling - don't have Image permission)
|
||||
filtered = await service._apply_allow_selling_filter(test_data, allow_selling=False)
|
||||
assert len(filtered) == 2 # model2 and model5
|
||||
file_paths = {item["file_path"] for item in filtered}
|
||||
assert file_paths == {"model2.safetensors", "model5.safetensors"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combined_filters():
|
||||
"""Test combining both credit required and allow selling filters."""
|
||||
service = DummyModelService()
|
||||
|
||||
# Create test data
|
||||
test_data = [
|
||||
# Requires credit AND allows selling
|
||||
{"file_path": "model1.safetensors", "license_flags": build_license_flags({
|
||||
"allowNoCredit": False,
|
||||
"allowCommercialUse": ["Image"]
|
||||
})},
|
||||
# Requires credit AND doesn't allow selling
|
||||
{"file_path": "model2.safetensors", "license_flags": build_license_flags({
|
||||
"allowNoCredit": False,
|
||||
"allowCommercialUse": ["Rent"]
|
||||
})},
|
||||
# Doesn't require credit AND allows selling
|
||||
{"file_path": "model3.safetensors", "license_flags": build_license_flags({
|
||||
"allowNoCredit": True,
|
||||
"allowCommercialUse": ["Image"]
|
||||
})},
|
||||
# Doesn't require credit AND doesn't allow selling
|
||||
{"file_path": "model4.safetensors", "license_flags": build_license_flags({
|
||||
"allowNoCredit": True,
|
||||
"allowCommercialUse": ["Rent"]
|
||||
})},
|
||||
]
|
||||
|
||||
# First apply credit_required=True filter (requires credit)
|
||||
filtered = await service._apply_credit_required_filter(test_data, credit_required=True)
|
||||
assert len(filtered) == 2
|
||||
file_paths = {item["file_path"] for item in filtered}
|
||||
assert file_paths == {"model1.safetensors", "model2.safetensors"}
|
||||
|
||||
# Then apply allow_selling=True filter (allows selling) to the result
|
||||
filtered = await service._apply_allow_selling_filter(filtered, allow_selling=True)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["file_path"] == "model1.safetensors"
|
||||
|
||||
# Test the other combination
|
||||
# First apply credit_required=False filter (doesn't require credit)
|
||||
filtered = await service._apply_credit_required_filter(test_data, credit_required=False)
|
||||
assert len(filtered) == 2
|
||||
file_paths = {item["file_path"] for item in filtered}
|
||||
assert file_paths == {"model3.safetensors", "model4.safetensors"}
|
||||
|
||||
# Then apply allow_selling=False filter (doesn't allow selling) to the result
|
||||
filtered = await service._apply_allow_selling_filter(filtered, allow_selling=False)
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["file_path"] == "model4.safetensors"
|
||||
121
tests/services/test_license_filters_integration.py
Normal file
121
tests/services/test_license_filters_integration.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Integration tests for license-based filtering in BaseModelService."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
|
||||
from py.services.base_model_service import BaseModelService
|
||||
from py.utils.civitai_utils import build_license_flags
|
||||
from py.services.model_query import ModelCacheRepository, ModelFilterSet, SearchStrategy, SortParams
|
||||
|
||||
|
||||
class DummyModelService(BaseModelService):
|
||||
"""Dummy implementation of BaseModelService for testing."""
|
||||
|
||||
def __init__(self):
|
||||
# Mock the required attributes
|
||||
self.model_type = "test"
|
||||
self.scanner = Mock()
|
||||
self.metadata_class = Mock()
|
||||
self.settings = Mock()
|
||||
self.update_service = None # Add the missing attribute
|
||||
|
||||
# Mock the cache repository
|
||||
self.cache_repository = ModelCacheRepository(self.scanner)
|
||||
self.filter_set = ModelFilterSet(self.settings)
|
||||
self.search_strategy = SearchStrategy()
|
||||
|
||||
# Mock the scanner's get_cached_data to return a mock cache
|
||||
self.cache_mock = Mock()
|
||||
self.cache_mock.get_sorted_data = AsyncMock(return_value=[])
|
||||
|
||||
async def mock_get_cached_data():
|
||||
return self.cache_mock
|
||||
|
||||
self.scanner.get_cached_data = mock_get_cached_data
|
||||
|
||||
async def format_response(self, model_data: dict) -> dict:
|
||||
"""Required abstract method implementation."""
|
||||
return model_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_paginated_data_with_license_filters():
|
||||
"""Test that license filters are applied in get_paginated_data."""
|
||||
service = DummyModelService()
|
||||
|
||||
# Create test data with different license flags
|
||||
test_data = [
|
||||
# Model requiring credit AND allowing selling
|
||||
{"file_path": "model1.safetensors", "license_flags": build_license_flags({
|
||||
"allowNoCredit": False,
|
||||
"allowCommercialUse": ["Image"]
|
||||
})},
|
||||
# Model requiring credit AND not allowing selling
|
||||
{"file_path": "model2.safetensors", "license_flags": build_license_flags({
|
||||
"allowNoCredit": False,
|
||||
"allowCommercialUse": ["Rent"]
|
||||
})},
|
||||
# Model not requiring credit AND allowing selling
|
||||
{"file_path": "model3.safetensors", "license_flags": build_license_flags({
|
||||
"allowNoCredit": True,
|
||||
"allowCommercialUse": ["Image"]
|
||||
})},
|
||||
# Model not requiring credit AND not allowing selling
|
||||
{"file_path": "model4.safetensors", "license_flags": build_license_flags({
|
||||
"allowNoCredit": True,
|
||||
"allowCommercialUse": ["Rent"]
|
||||
})},
|
||||
]
|
||||
|
||||
# Mock the sorted data
|
||||
service.cache_mock.get_sorted_data = AsyncMock(return_value=test_data)
|
||||
|
||||
# Test with credit_required=True
|
||||
result = await service.get_paginated_data(
|
||||
page=1,
|
||||
page_size=10,
|
||||
credit_required=True
|
||||
)
|
||||
assert len(result["items"]) == 2
|
||||
file_paths = {item["file_path"] for item in result["items"]}
|
||||
assert file_paths == {"model1.safetensors", "model2.safetensors"}
|
||||
|
||||
# Test with credit_required=False
|
||||
result = await service.get_paginated_data(
|
||||
page=1,
|
||||
page_size=10,
|
||||
credit_required=False
|
||||
)
|
||||
assert len(result["items"]) == 2
|
||||
file_paths = {item["file_path"] for item in result["items"]}
|
||||
assert file_paths == {"model3.safetensors", "model4.safetensors"}
|
||||
|
||||
# Test with allow_selling_generated_content=True
|
||||
result = await service.get_paginated_data(
|
||||
page=1,
|
||||
page_size=10,
|
||||
allow_selling_generated_content=True
|
||||
)
|
||||
assert len(result["items"]) == 2
|
||||
file_paths = {item["file_path"] for item in result["items"]}
|
||||
assert file_paths == {"model1.safetensors", "model3.safetensors"}
|
||||
|
||||
# Test with allow_selling_generated_content=False
|
||||
result = await service.get_paginated_data(
|
||||
page=1,
|
||||
page_size=10,
|
||||
allow_selling_generated_content=False
|
||||
)
|
||||
assert len(result["items"]) == 2
|
||||
file_paths = {item["file_path"] for item in result["items"]}
|
||||
assert file_paths == {"model2.safetensors", "model4.safetensors"}
|
||||
|
||||
# Test with both filters
|
||||
result = await service.get_paginated_data(
|
||||
page=1,
|
||||
page_size=10,
|
||||
credit_required=True,
|
||||
allow_selling_generated_content=True
|
||||
)
|
||||
assert len(result["items"]) == 1
|
||||
assert result["items"][0]["file_path"] == "model1.safetensors"
|
||||
Reference in New Issue
Block a user