diff --git a/py/services/persistent_model_cache.py b/py/services/persistent_model_cache.py index 652f3aed..cda1d0b3 100644 --- a/py/services/persistent_model_cache.py +++ b/py/services/persistent_model_cache.py @@ -21,7 +21,7 @@ class PersistedCacheData: excluded_models: List[str] -DEFAULT_LICENSE_FLAGS = 57 # 57 (0b111001) encodes CivitAI's documented default license permissions. +DEFAULT_LICENSE_FLAGS = 127 # 127 (0b1111111) encodes default CivitAI permissions with all commercial modes enabled. class PersistentModelCache: diff --git a/py/utils/civitai_utils.py b/py/utils/civitai_utils.py index aeab2f49..e1eb1a10 100644 --- a/py/utils/civitai_utils.py +++ b/py/utils/civitai_utils.py @@ -2,21 +2,10 @@ from __future__ import annotations -from enum import IntEnum from typing import Any, Dict, Iterable, Mapping, Sequence from urllib.parse import urlparse, urlunparse -class CommercialUseLevel(IntEnum): - """Enumerate supported commercial use permission levels.""" - - NONE = 0 - IMAGE = 1 - RENT_CIVIT = 2 - RENT = 3 - SELL = 4 - - _DEFAULT_ALLOW_COMMERCIAL_USE: Sequence[str] = ("Sell",) _LICENSE_DEFAULTS: Dict[str, Any] = { "allowNoCredit": True, @@ -24,13 +13,8 @@ _LICENSE_DEFAULTS: Dict[str, Any] = { "allowDerivatives": True, "allowDifferentLicense": True, } -_COMMERCIAL_VALUE_TO_LEVEL = { - "none": CommercialUseLevel.NONE, - "image": CommercialUseLevel.IMAGE, - "rentcivit": CommercialUseLevel.RENT_CIVIT, - "rent": CommercialUseLevel.RENT, - "sell": CommercialUseLevel.SELL, -} +_COMMERCIAL_ALLOWED_VALUES = {"sell", "rent", "rentcivit", "image"} +_COMMERCIAL_SHIFT = 1 def _normalize_commercial_values(value: Any) -> Sequence[str]: @@ -53,6 +37,11 @@ def _normalize_commercial_values(value: Any) -> Sequence[str]: result.append(str(item)) if result: return result + try: + if len(value) == 0: # type: ignore[arg-type] + return [] + except TypeError: + pass return list(_DEFAULT_ALLOW_COMMERCIAL_USE) @@ -96,17 +85,25 @@ def resolve_license_payload(model_data: Mapping[str, Any] | None) -> Dict[str, A return payload -def _resolve_commercial_level(values: Sequence[str]) -> CommercialUseLevel: - level = CommercialUseLevel.NONE +def _resolve_commercial_bits(values: Sequence[str]) -> int: + normalized_values = set() for value in values: - normalized = str(value).strip().lower().replace("_", "") - normalized = normalized.replace("-", "") - candidate = _COMMERCIAL_VALUE_TO_LEVEL.get(normalized) - if candidate is None: - continue - if candidate > level: - level = candidate - return level + normalized = str(value).strip().lower().replace("_", "").replace("-", "") + if normalized in _COMMERCIAL_ALLOWED_VALUES: + normalized_values.add(normalized) + + has_sell = "sell" in normalized_values + has_rent = has_sell or "rent" in normalized_values + has_rentcivit = has_rent or "rentcivit" in normalized_values + has_image = has_sell or "image" in normalized_values + + commercial_bits = ( + (1 if has_sell else 0) << 3 + | (1 if has_rent else 0) << 2 + | (1 if has_rentcivit else 0) << 1 + | (1 if has_image else 0) + ) + return commercial_bits << _COMMERCIAL_SHIFT def build_license_flags(payload: Mapping[str, Any] | None) -> int: @@ -118,14 +115,14 @@ def build_license_flags(payload: Mapping[str, Any] | None) -> int: if resolved.get("allowNoCredit", True): flags |= 1 << 0 - commercial_level = _resolve_commercial_level(resolved.get("allowCommercialUse", ())) - flags |= (int(commercial_level) & 0b111) << 1 + commercial_bits = _resolve_commercial_bits(resolved.get("allowCommercialUse", ())) + flags |= commercial_bits if resolved.get("allowDerivatives", True): - flags |= 1 << 4 + flags |= 1 << 5 if resolved.get("allowDifferentLicense", True): - flags |= 1 << 5 + flags |= 1 << 6 return flags @@ -176,7 +173,6 @@ def rewrite_preview_url(source_url: str | None, media_type: str | None = None) - __all__ = [ - "CommercialUseLevel", "build_license_flags", "resolve_license_payload", "resolve_license_info", diff --git a/tests/utils/test_civitai_utils.py b/tests/utils/test_civitai_utils.py index 1204e56b..b1f3ae2a 100644 --- a/tests/utils/test_civitai_utils.py +++ b/tests/utils/test_civitai_utils.py @@ -1,9 +1,4 @@ -from py.utils.civitai_utils import ( - CommercialUseLevel, - build_license_flags, - resolve_license_info, - resolve_license_payload, -) +from py.utils.civitai_utils import build_license_flags, resolve_license_info, resolve_license_payload def test_resolve_license_payload_defaults(): @@ -13,7 +8,7 @@ def test_resolve_license_payload_defaults(): assert payload["allowDerivatives"] is True assert payload["allowDifferentLicense"] is True assert payload["allowCommercialUse"] == ["Sell"] - assert flags == 57 + assert flags == 127 def test_build_license_flags_custom_values(): @@ -31,5 +26,23 @@ def test_build_license_flags_custom_values(): assert payload["allowDifferentLicense"] is False flags = build_license_flags(source) - # Highest commercial level is SELL -> level 4 -> shifted by 1 == 8. - assert flags == (CommercialUseLevel.SELL << 1) + # Sell automatically enables all commercial bits including image. + assert flags == 30 + + +def test_build_license_flags_respects_commercial_hierarchy(): + base = { + "allowNoCredit": False, + "allowDerivatives": False, + "allowDifferentLicense": False, + } + + assert build_license_flags({**base, "allowCommercialUse": []}) == 0 + # Rent adds rent and rentcivit permissions. + assert build_license_flags({**base, "allowCommercialUse": ["Rent"]}) == 12 + # RentCivit alone should only set its own bit. + assert build_license_flags({**base, "allowCommercialUse": ["RentCivit"]}) == 4 + # Image only toggles the image bit. + assert build_license_flags({**base, "allowCommercialUse": ["Image"]}) == 2 + # Sell forces all commercial bits regardless of image listing. + assert build_license_flags({**base, "allowCommercialUse": ["Sell"]}) == 30