From 22ee37b8175972f029e96b1fefe7dd39831d0f63 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Sun, 30 Nov 2025 17:10:21 +0800 Subject: [PATCH] feat: parse aggregate commercial use values, see #708 Add support for parsing comma-separated and JSON-style commercial use permission values in both Python backend and JavaScript frontend. Implement helper functions to split aggregated values into individual permissions while preserving original values when no aggregation is detected. Added comprehensive test coverage for the new parsing functionality to ensure correct handling of various input formats including strings, arrays, and iterable objects with aggregated commercial use values. --- py/utils/civitai_utils.py | 18 ++- static/js/components/shared/ModelModal.js | 48 ++++++- .../modelModal.licenseIcons.test.js | 129 ++++++++++++++++++ tests/utils/test_civitai_utils.py | 32 +++++ 4 files changed, 223 insertions(+), 4 deletions(-) create mode 100644 tests/frontend/components/modelModal.licenseIcons.test.js diff --git a/py/utils/civitai_utils.py b/py/utils/civitai_utils.py index e1eb1a10..198af376 100644 --- a/py/utils/civitai_utils.py +++ b/py/utils/civitai_utils.py @@ -20,11 +20,25 @@ _COMMERCIAL_SHIFT = 1 def _normalize_commercial_values(value: Any) -> Sequence[str]: """Return a normalized list of commercial permissions preserving source values.""" + def _split_aggregate(value_str: str) -> list[str]: + stripped = value_str.strip() + looks_aggregate = "," in stripped or (stripped.startswith("{") and stripped.endswith("}")) + if not looks_aggregate: + return [value_str] + + trimmed = stripped + if trimmed.startswith("{") and trimmed.endswith("}"): + trimmed = trimmed[1:-1] + + parts = [part.strip() for part in trimmed.split(",")] + result = [part for part in parts if part] + return result or [value_str] + if value is None: return list(_DEFAULT_ALLOW_COMMERCIAL_USE) if isinstance(value, str): - return [value] + return _split_aggregate(value) if isinstance(value, Iterable): result = [] @@ -32,7 +46,7 @@ def _normalize_commercial_values(value: Any) -> Sequence[str]: if item is None: continue if isinstance(item, str): - result.append(item) + result.extend(_split_aggregate(item)) continue result.append(str(item)) if result: diff --git a/static/js/components/shared/ModelModal.js b/static/js/components/shared/ModelModal.js index 13e22fd6..562562a3 100644 --- a/static/js/components/shared/ModelModal.js +++ b/static/js/components/shared/ModelModal.js @@ -77,28 +77,72 @@ function indentMarkup(markup, spaces) { .join('\n'); } +function splitAggregateCommercialValue(value) { + const trimmed = String(value ?? '').trim(); + const looksAggregate = trimmed.includes(',') || (trimmed.startsWith('{') && trimmed.endsWith('}')); + if (!looksAggregate) { + return [value]; + } + + let inner = trimmed; + if (inner.startsWith('{') && inner.endsWith('}')) { + inner = inner.slice(1, -1); + } + + const parts = inner + .split(',') + .map(part => part.trim()) + .filter(Boolean); + + return parts.length ? parts : [value]; +} + function normalizeCommercialValues(value) { if (!value && value !== '') { return ['Sell']; } + if (Array.isArray(value)) { - return value.filter(item => item !== null && item !== undefined); + const flattened = []; + value.forEach(item => { + if (item === null || item === undefined) { + return; + } + if (typeof item === 'string') { + flattened.push(...splitAggregateCommercialValue(item)); + return; + } + flattened.push(String(item)); + }); + if (flattened.length > 0) { + return flattened; + } + if (value.length === 0) { + return []; + } } + if (typeof value === 'string') { - return [value]; + return splitAggregateCommercialValue(value); } + if (value && typeof value[Symbol.iterator] === 'function') { const result = []; for (const item of value) { if (item === null || item === undefined) { continue; } + if (typeof item === 'string') { + result.push(...splitAggregateCommercialValue(item)); + continue; + } result.push(String(item)); } if (result.length > 0) { return result; } } + return ['Sell']; } diff --git a/tests/frontend/components/modelModal.licenseIcons.test.js b/tests/frontend/components/modelModal.licenseIcons.test.js new file mode 100644 index 00000000..355fdd63 --- /dev/null +++ b/tests/frontend/components/modelModal.licenseIcons.test.js @@ -0,0 +1,129 @@ +import { describe, it, beforeEach, expect, vi } from 'vitest'; + +const { + MODAL_MODULE, + API_FACTORY, + UI_HELPERS_MODULE, + MODAL_MANAGER_MODULE, + SHOWCASE_MODULE, + MODEL_TAGS_MODULE, + UTILS_MODULE, + TRIGGER_WORDS_MODULE, + PRESET_TAGS_MODULE, + MODEL_VERSIONS_MODULE, + RECIPE_TAB_MODULE, + I18N_HELPERS_MODULE, +} = vi.hoisted(() => ({ + MODAL_MODULE: new URL('../../../static/js/components/shared/ModelModal.js', import.meta.url).pathname, + API_FACTORY: new URL('../../../static/js/api/modelApiFactory.js', import.meta.url).pathname, + UI_HELPERS_MODULE: new URL('../../../static/js/utils/uiHelpers.js', import.meta.url).pathname, + MODAL_MANAGER_MODULE: new URL('../../../static/js/managers/ModalManager.js', import.meta.url).pathname, + SHOWCASE_MODULE: new URL('../../../static/js/components/shared/showcase/ShowcaseView.js', import.meta.url).pathname, + MODEL_TAGS_MODULE: new URL('../../../static/js/components/shared/ModelTags.js', import.meta.url).pathname, + UTILS_MODULE: new URL('../../../static/js/components/shared/utils.js', import.meta.url).pathname, + TRIGGER_WORDS_MODULE: new URL('../../../static/js/components/shared/TriggerWords.js', import.meta.url).pathname, + PRESET_TAGS_MODULE: new URL('../../../static/js/components/shared/PresetTags.js', import.meta.url).pathname, + MODEL_VERSIONS_MODULE: new URL('../../../static/js/components/shared/ModelVersionsTab.js', import.meta.url).pathname, + RECIPE_TAB_MODULE: new URL('../../../static/js/components/shared/RecipeTab.js', import.meta.url).pathname, + I18N_HELPERS_MODULE: new URL('../../../static/js/utils/i18nHelpers.js', import.meta.url).pathname, +})); + +vi.mock(UI_HELPERS_MODULE, () => ({ + showToast: vi.fn(), + openCivitai: vi.fn(), +})); + +vi.mock(MODAL_MANAGER_MODULE, () => ({ + modalManager: { + showModal: vi.fn((id, html) => { + document.body.innerHTML = `
${html}
`; + }), + closeModal: vi.fn(), + }, +})); + +vi.mock(SHOWCASE_MODULE, () => ({ + toggleShowcase: vi.fn(), + setupShowcaseScroll: vi.fn(), + scrollToTop: vi.fn(), + loadExampleImages: vi.fn(), +})); + +vi.mock(MODEL_TAGS_MODULE, () => ({ + setupTagEditMode: vi.fn(), +})); + +vi.mock(UTILS_MODULE, () => ({ + renderCompactTags: vi.fn(() => ''), + setupTagTooltip: vi.fn(), + formatFileSize: vi.fn(() => '1 MB'), +})); + +vi.mock(TRIGGER_WORDS_MODULE, () => ({ + renderTriggerWords: vi.fn(() => ''), + setupTriggerWordsEditMode: vi.fn(), +})); + +vi.mock(PRESET_TAGS_MODULE, () => ({ + parsePresets: vi.fn(() => ({})), + renderPresetTags: vi.fn(() => ''), +})); + +vi.mock(MODEL_VERSIONS_MODULE, () => ({ + initVersionsTab: vi.fn(() => ({ + load: vi.fn().mockResolvedValue(undefined), + })), +})); + +vi.mock(RECIPE_TAB_MODULE, () => ({ + loadRecipesForLora: vi.fn(), +})); + +vi.mock(I18N_HELPERS_MODULE, () => ({ + translate: vi.fn((_, __, fallback) => fallback || ''), +})); + +vi.mock(API_FACTORY, () => ({ + getModelApiClient: vi.fn(), +})); + +describe('Model modal license rendering', () => { + let getModelApiClient; + + beforeEach(async () => { + document.body.innerHTML = ''; + ({ getModelApiClient } = await import(API_FACTORY)); + getModelApiClient.mockReset(); + }); + + it('handles aggregated commercial strings without extra restrictions', async () => { + const fetchModelMetadata = vi.fn().mockResolvedValue(null); + getModelApiClient.mockReturnValue({ + fetchModelMetadata, + saveModelMetadata: vi.fn(), + }); + + const { showModelModal } = await import(MODAL_MODULE); + + await showModelModal( + { + model_name: 'Aggregated', + file_path: 'models/agg.safetensors', + file_name: 'agg.safetensors', + civitai: { + model: { + allowNoCredit: true, + allowCommercialUse: '{Image,RentCivit,Rent}', + allowDerivatives: true, + allowDifferentLicense: false, + }, + }, + }, + 'loras', + ); + + const iconTitles = Array.from(document.querySelectorAll('.license-restrictions .license-icon')).map(icon => icon.getAttribute('title')); + + expect(iconTitles).toEqual(['No selling models', 'Same permissions required']); + }); +}); diff --git a/tests/utils/test_civitai_utils.py b/tests/utils/test_civitai_utils.py index b1f3ae2a..840b17d3 100644 --- a/tests/utils/test_civitai_utils.py +++ b/tests/utils/test_civitai_utils.py @@ -46,3 +46,35 @@ def test_build_license_flags_respects_commercial_hierarchy(): assert build_license_flags({**base, "allowCommercialUse": ["Image"]}) == 2 # Sell forces all commercial bits regardless of image listing. assert build_license_flags({**base, "allowCommercialUse": ["Sell"]}) == 30 + + +def test_build_license_flags_parses_aggregate_string(): + source = { + "allowNoCredit": True, + "allowCommercialUse": "{Image,RentCivit,Rent}", + "allowDerivatives": True, + "allowDifferentLicense": False, + } + + payload = resolve_license_payload(source) + assert set(payload["allowCommercialUse"]) == {"Image", "RentCivit", "Rent"} + + flags = build_license_flags(source) + expected_flags = (1 << 0) | (7 << 1) | (1 << 5) + assert flags == expected_flags + + +def test_build_license_flags_parses_aggregate_inside_list(): + source = { + "allowNoCredit": True, + "allowCommercialUse": ["{Image,RentCivit,Rent}"], + "allowDerivatives": True, + "allowDifferentLicense": False, + } + + payload = resolve_license_payload(source) + assert set(payload["allowCommercialUse"]) == {"Image", "RentCivit", "Rent"} + + flags = build_license_flags(source) + expected_flags = (1 << 0) | (7 << 1) | (1 << 5) + assert flags == expected_flags