diff --git a/py/utils/civitai_utils.py b/py/utils/civitai_utils.py index e1eb1a10..198af376 100644 --- a/py/utils/civitai_utils.py +++ b/py/utils/civitai_utils.py @@ -20,11 +20,25 @@ _COMMERCIAL_SHIFT = 1 def _normalize_commercial_values(value: Any) -> Sequence[str]: """Return a normalized list of commercial permissions preserving source values.""" + def _split_aggregate(value_str: str) -> list[str]: + stripped = value_str.strip() + looks_aggregate = "," in stripped or (stripped.startswith("{") and stripped.endswith("}")) + if not looks_aggregate: + return [value_str] + + trimmed = stripped + if trimmed.startswith("{") and trimmed.endswith("}"): + trimmed = trimmed[1:-1] + + parts = [part.strip() for part in trimmed.split(",")] + result = [part for part in parts if part] + return result or [value_str] + if value is None: return list(_DEFAULT_ALLOW_COMMERCIAL_USE) if isinstance(value, str): - return [value] + return _split_aggregate(value) if isinstance(value, Iterable): result = [] @@ -32,7 +46,7 @@ def _normalize_commercial_values(value: Any) -> Sequence[str]: if item is None: continue if isinstance(item, str): - result.append(item) + result.extend(_split_aggregate(item)) continue result.append(str(item)) if result: diff --git a/static/js/components/shared/ModelModal.js b/static/js/components/shared/ModelModal.js index 13e22fd6..562562a3 100644 --- a/static/js/components/shared/ModelModal.js +++ b/static/js/components/shared/ModelModal.js @@ -77,28 +77,72 @@ function indentMarkup(markup, spaces) { .join('\n'); } +function splitAggregateCommercialValue(value) { + const trimmed = String(value ?? '').trim(); + const looksAggregate = trimmed.includes(',') || (trimmed.startsWith('{') && trimmed.endsWith('}')); + if (!looksAggregate) { + return [value]; + } + + let inner = trimmed; + if (inner.startsWith('{') && inner.endsWith('}')) { + inner = inner.slice(1, -1); + } + + const parts = inner + .split(',') + .map(part => part.trim()) + .filter(Boolean); + + return parts.length ? parts : [value]; +} + function normalizeCommercialValues(value) { if (!value && value !== '') { return ['Sell']; } + if (Array.isArray(value)) { - return value.filter(item => item !== null && item !== undefined); + const flattened = []; + value.forEach(item => { + if (item === null || item === undefined) { + return; + } + if (typeof item === 'string') { + flattened.push(...splitAggregateCommercialValue(item)); + return; + } + flattened.push(String(item)); + }); + if (flattened.length > 0) { + return flattened; + } + if (value.length === 0) { + return []; + } } + if (typeof value === 'string') { - return [value]; + return splitAggregateCommercialValue(value); } + if (value && typeof value[Symbol.iterator] === 'function') { const result = []; for (const item of value) { if (item === null || item === undefined) { continue; } + if (typeof item === 'string') { + result.push(...splitAggregateCommercialValue(item)); + continue; + } result.push(String(item)); } if (result.length > 0) { return result; } } + return ['Sell']; } diff --git a/tests/frontend/components/modelModal.licenseIcons.test.js b/tests/frontend/components/modelModal.licenseIcons.test.js new file mode 100644 index 00000000..355fdd63 --- /dev/null +++ b/tests/frontend/components/modelModal.licenseIcons.test.js @@ -0,0 +1,129 @@ +import { describe, it, beforeEach, expect, vi } from 'vitest'; + +const { + MODAL_MODULE, + API_FACTORY, + UI_HELPERS_MODULE, + MODAL_MANAGER_MODULE, + SHOWCASE_MODULE, + MODEL_TAGS_MODULE, + UTILS_MODULE, + TRIGGER_WORDS_MODULE, + PRESET_TAGS_MODULE, + MODEL_VERSIONS_MODULE, + RECIPE_TAB_MODULE, + I18N_HELPERS_MODULE, +} = vi.hoisted(() => ({ + MODAL_MODULE: new URL('../../../static/js/components/shared/ModelModal.js', import.meta.url).pathname, + API_FACTORY: new URL('../../../static/js/api/modelApiFactory.js', import.meta.url).pathname, + UI_HELPERS_MODULE: new URL('../../../static/js/utils/uiHelpers.js', import.meta.url).pathname, + MODAL_MANAGER_MODULE: new URL('../../../static/js/managers/ModalManager.js', import.meta.url).pathname, + SHOWCASE_MODULE: new URL('../../../static/js/components/shared/showcase/ShowcaseView.js', import.meta.url).pathname, + MODEL_TAGS_MODULE: new URL('../../../static/js/components/shared/ModelTags.js', import.meta.url).pathname, + UTILS_MODULE: new URL('../../../static/js/components/shared/utils.js', import.meta.url).pathname, + TRIGGER_WORDS_MODULE: new URL('../../../static/js/components/shared/TriggerWords.js', import.meta.url).pathname, + PRESET_TAGS_MODULE: new URL('../../../static/js/components/shared/PresetTags.js', import.meta.url).pathname, + MODEL_VERSIONS_MODULE: new URL('../../../static/js/components/shared/ModelVersionsTab.js', import.meta.url).pathname, + RECIPE_TAB_MODULE: new URL('../../../static/js/components/shared/RecipeTab.js', import.meta.url).pathname, + I18N_HELPERS_MODULE: new URL('../../../static/js/utils/i18nHelpers.js', import.meta.url).pathname, +})); + +vi.mock(UI_HELPERS_MODULE, () => ({ + showToast: vi.fn(), + openCivitai: vi.fn(), +})); + +vi.mock(MODAL_MANAGER_MODULE, () => ({ + modalManager: { + showModal: vi.fn((id, html) => { + document.body.innerHTML = `
${html}
`; + }), + closeModal: vi.fn(), + }, +})); + +vi.mock(SHOWCASE_MODULE, () => ({ + toggleShowcase: vi.fn(), + setupShowcaseScroll: vi.fn(), + scrollToTop: vi.fn(), + loadExampleImages: vi.fn(), +})); + +vi.mock(MODEL_TAGS_MODULE, () => ({ + setupTagEditMode: vi.fn(), +})); + +vi.mock(UTILS_MODULE, () => ({ + renderCompactTags: vi.fn(() => ''), + setupTagTooltip: vi.fn(), + formatFileSize: vi.fn(() => '1 MB'), +})); + +vi.mock(TRIGGER_WORDS_MODULE, () => ({ + renderTriggerWords: vi.fn(() => ''), + setupTriggerWordsEditMode: vi.fn(), +})); + +vi.mock(PRESET_TAGS_MODULE, () => ({ + parsePresets: vi.fn(() => ({})), + renderPresetTags: vi.fn(() => ''), +})); + +vi.mock(MODEL_VERSIONS_MODULE, () => ({ + initVersionsTab: vi.fn(() => ({ + load: vi.fn().mockResolvedValue(undefined), + })), +})); + +vi.mock(RECIPE_TAB_MODULE, () => ({ + loadRecipesForLora: vi.fn(), +})); + +vi.mock(I18N_HELPERS_MODULE, () => ({ + translate: vi.fn((_, __, fallback) => fallback || ''), +})); + +vi.mock(API_FACTORY, () => ({ + getModelApiClient: vi.fn(), +})); + +describe('Model modal license rendering', () => { + let getModelApiClient; + + beforeEach(async () => { + document.body.innerHTML = ''; + ({ getModelApiClient } = await import(API_FACTORY)); + getModelApiClient.mockReset(); + }); + + it('handles aggregated commercial strings without extra restrictions', async () => { + const fetchModelMetadata = vi.fn().mockResolvedValue(null); + getModelApiClient.mockReturnValue({ + fetchModelMetadata, + saveModelMetadata: vi.fn(), + }); + + const { showModelModal } = await import(MODAL_MODULE); + + await showModelModal( + { + model_name: 'Aggregated', + file_path: 'models/agg.safetensors', + file_name: 'agg.safetensors', + civitai: { + model: { + allowNoCredit: true, + allowCommercialUse: '{Image,RentCivit,Rent}', + allowDerivatives: true, + allowDifferentLicense: false, + }, + }, + }, + 'loras', + ); + + const iconTitles = Array.from(document.querySelectorAll('.license-restrictions .license-icon')).map(icon => icon.getAttribute('title')); + + expect(iconTitles).toEqual(['No selling models', 'Same permissions required']); + }); +}); diff --git a/tests/utils/test_civitai_utils.py b/tests/utils/test_civitai_utils.py index b1f3ae2a..840b17d3 100644 --- a/tests/utils/test_civitai_utils.py +++ b/tests/utils/test_civitai_utils.py @@ -46,3 +46,35 @@ def test_build_license_flags_respects_commercial_hierarchy(): assert build_license_flags({**base, "allowCommercialUse": ["Image"]}) == 2 # Sell forces all commercial bits regardless of image listing. assert build_license_flags({**base, "allowCommercialUse": ["Sell"]}) == 30 + + +def test_build_license_flags_parses_aggregate_string(): + source = { + "allowNoCredit": True, + "allowCommercialUse": "{Image,RentCivit,Rent}", + "allowDerivatives": True, + "allowDifferentLicense": False, + } + + payload = resolve_license_payload(source) + assert set(payload["allowCommercialUse"]) == {"Image", "RentCivit", "Rent"} + + flags = build_license_flags(source) + expected_flags = (1 << 0) | (7 << 1) | (1 << 5) + assert flags == expected_flags + + +def test_build_license_flags_parses_aggregate_inside_list(): + source = { + "allowNoCredit": True, + "allowCommercialUse": ["{Image,RentCivit,Rent}"], + "allowDerivatives": True, + "allowDifferentLicense": False, + } + + payload = resolve_license_payload(source) + assert set(payload["allowCommercialUse"]) == {"Image", "RentCivit", "Rent"} + + flags = build_license_flags(source) + expected_flags = (1 << 0) | (7 << 1) | (1 << 5) + assert flags == expected_flags