From ddf9e339610af01c789dc26578325cf3d7ba5f56 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Thu, 6 Nov 2025 17:05:31 +0800 Subject: [PATCH] feat: add license information handling for Civitai models Add license resolution utilities and integrate license information into model metadata processing. The changes include: - Add `resolve_license_payload` function to extract license data from Civitai model responses - Integrate license information into model metadata in CivitaiClient and MetadataSyncService - Add license flags support in model scanning and caching - Implement CommercialUseLevel enum for standardized license classification - Update model scanner to handle unknown fields when extracting metadata values This ensures proper license attribution and compliance when working with Civitai models. --- py/services/civitai_client.py | 5 + py/services/metadata_sync_service.py | 14 +- py/services/model_scanner.py | 84 ++++++++++- py/services/persistent_model_cache.py | 16 ++ py/utils/civitai_utils.py | 141 +++++++++++++++++- refs/target_version.json | 6 +- tests/services/test_metadata_sync_service.py | 27 +++- tests/services/test_model_scanner.py | 41 ++++- tests/services/test_persistent_model_cache.py | 5 +- tests/utils/test_civitai_utils.py | 35 +++++ 10 files changed, 364 insertions(+), 10 deletions(-) create mode 100644 tests/utils/test_civitai_utils.py diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 30651e46..c008e20b 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -6,6 +6,7 @@ from typing import Any, Optional, Dict, Tuple, List, Sequence from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager from .downloader import get_downloader from .errors import RateLimitError, ResourceNotFoundError +from ..utils.civitai_utils import resolve_license_payload logger = logging.getLogger(__name__) @@ -420,6 +421,10 @@ class CivitaiClient: model_info['tags'] = model_data.get("tags", []) version['creator'] = model_data.get("creator") + license_payload = resolve_license_payload(model_data) + for field, value in license_payload.items(): + model_info[field] = value + async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]: """Fetch model version metadata from Civitai diff --git a/py/services/metadata_sync_service.py b/py/services/metadata_sync_service.py index cd160a57..6a302bd1 100644 --- a/py/services/metadata_sync_service.py +++ b/py/services/metadata_sync_service.py @@ -9,6 +9,7 @@ from datetime import datetime from typing import Any, Awaitable, Callable, Dict, Iterable, Optional from ..services.settings_manager import SettingsManager +from ..utils.civitai_utils import resolve_license_payload from ..utils.model_utils import determine_base_model from .errors import RateLimitError @@ -135,6 +136,17 @@ class MetadataSyncService: ): local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"] + merged_civitai = local_metadata.get("civitai") or {} + civitai_model = merged_civitai.get("model") + if not isinstance(civitai_model, dict): + civitai_model = {} + + license_payload = resolve_license_payload(model_data) + civitai_model.update(license_payload) + + merged_civitai["model"] = civitai_model + local_metadata["civitai"] = merged_civitai + local_metadata["base_model"] = determine_base_model( civitai_metadata.get("baseModel") ) @@ -295,6 +307,7 @@ class MetadataSyncService: "preview_url": local_metadata.get("preview_url"), "civitai": local_metadata.get("civitai"), } + model_data.update(update_payload) await update_cache_func(file_path, file_path, local_metadata) @@ -436,4 +449,3 @@ class MetadataSyncService: results["verified_as_duplicates"] = False return results - diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 6e1e8a0a..ebe05c42 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -11,6 +11,7 @@ from ..utils.models import BaseModelMetadata from ..config import config from ..utils.file_utils import find_preview_file, get_preview_extension from ..utils.metadata_manager import MetadataManager +from ..utils.civitai_utils import resolve_license_info from .model_cache import ModelCache from .model_hash_index import ModelHashIndex from ..utils.constants import PREVIEW_EXTENSIONS @@ -175,7 +176,17 @@ class ModelScanner: def get_value(key: str, default: Any = None) -> Any: if is_mapping: return source.get(key, default) - return getattr(source, key, default) + + sentinel = object() + value = getattr(source, key, sentinel) + if value is not sentinel: + return value + + unknown = getattr(source, "_unknown_fields", None) + if isinstance(unknown, dict) and key in unknown: + return unknown[key] + + return default file_path = file_path_override or get_value('file_path', '') or '' normalized_path = file_path.replace('\\', '/') @@ -197,7 +208,8 @@ class ModelScanner: else: preview_url = '' - civitai_slim = self._slim_civitai_payload(get_value('civitai')) + civitai_full = get_value('civitai') + civitai_slim = self._slim_civitai_payload(civitai_full) usage_tips = get_value('usage_tips', '') or '' if not isinstance(usage_tips, str): usage_tips = str(usage_tips) @@ -229,12 +241,76 @@ class ModelScanner: 'civitai_deleted': bool(get_value('civitai_deleted', False)), } + license_source: Dict[str, Any] = {} + if isinstance(civitai_full, Mapping): + civitai_model = civitai_full.get('model') + if isinstance(civitai_model, Mapping): + for key in ( + 'allowNoCredit', + 'allowCommercialUse', + 'allowDerivatives', + 'allowDifferentLicense', + ): + if key in civitai_model: + license_source[key] = civitai_model.get(key) + + for key in ( + 'allowNoCredit', + 'allowCommercialUse', + 'allowDerivatives', + 'allowDifferentLicense', + ): + if key not in license_source: + value = get_value(key) + if value is not None: + license_source[key] = value + + _, license_flags = resolve_license_info(license_source or {}) + entry['license_flags'] = license_flags + model_type = get_value('model_type', None) if model_type: entry['model_type'] = model_type return entry + def _ensure_license_flags(self, entry: Dict[str, Any]) -> None: + """Ensure cached entries include an integer license flag bitset.""" + + if not isinstance(entry, dict): + return + + license_value = entry.get('license_flags') + if license_value is not None: + try: + entry['license_flags'] = int(license_value) + except (TypeError, ValueError): + _, fallback_flags = resolve_license_info({}) + entry['license_flags'] = fallback_flags + return + + license_source = { + 'allowNoCredit': entry.get('allowNoCredit'), + 'allowCommercialUse': entry.get('allowCommercialUse'), + 'allowDerivatives': entry.get('allowDerivatives'), + 'allowDifferentLicense': entry.get('allowDifferentLicense'), + } + civitai_full = entry.get('civitai') + if isinstance(civitai_full, Mapping): + civitai_model = civitai_full.get('model') + if isinstance(civitai_model, Mapping): + for key in ( + 'allowNoCredit', + 'allowCommercialUse', + 'allowDerivatives', + 'allowDifferentLicense', + ): + if key in civitai_model: + license_source[key] = civitai_model.get(key) + + _, license_flags = resolve_license_info(license_source) + entry['license_flags'] = license_flags + async def initialize_in_background(self) -> None: """Initialize cache in background using thread pool""" try: @@ -567,6 +643,7 @@ class ModelScanner: async def _initialize_cache(self) -> None: """Initialize or refresh the cache""" + print("init start", flush=True) self._is_initializing = True # Set flag try: start_time = time.time() @@ -575,6 +652,7 @@ class ModelScanner: scan_result = await self._gather_model_data() await self._apply_scan_result(scan_result) await self._save_persistent_cache(scan_result) + print("init end", flush=True) logger.info( f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, " @@ -681,6 +759,7 @@ class ModelScanner: model_data = self.adjust_cached_entry(dict(model_data)) if not model_data: continue + self._ensure_license_flags(model_data) # Add to cache self._cache.raw_data.append(model_data) self._cache.add_to_version_index(model_data) @@ -975,6 +1054,7 @@ class ModelScanner: processed_files += 1 if result: + self._ensure_license_flags(result) raw_data.append(result) sha_value = result.get('sha256') diff --git a/py/services/persistent_model_cache.py b/py/services/persistent_model_cache.py index 91ae80d1..652f3aed 100644 --- a/py/services/persistent_model_cache.py +++ b/py/services/persistent_model_cache.py @@ -21,6 +21,9 @@ class PersistedCacheData: excluded_models: List[str] +DEFAULT_LICENSE_FLAGS = 57 # 57 (0b111001) encodes CivitAI's documented default license permissions. + + class PersistentModelCache: """Persist core model metadata and hash index data in SQLite.""" @@ -47,6 +50,7 @@ class PersistentModelCache: "civitai_name", "civitai_creator_username", "trained_words", + "license_flags", "civitai_deleted", "exclude", "db_checked", @@ -149,6 +153,10 @@ class PersistentModelCache: if creator_username: civitai.setdefault("creator", {})["username"] = creator_username + license_value = row["license_flags"] + if license_value is None: + license_value = DEFAULT_LICENSE_FLAGS + item = { "file_path": file_path, "file_name": row["file_name"], @@ -171,6 +179,7 @@ class PersistentModelCache: "tags": tags.get(file_path, []), "civitai": civitai, "civitai_deleted": bool(row["civitai_deleted"]), + "license_flags": int(license_value), } raw_data.append(item) @@ -484,6 +493,8 @@ class PersistentModelCache: "metadata_source": "TEXT", "civitai_creator_username": "TEXT", "civitai_deleted": "INTEGER DEFAULT 0", + # Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57). + "license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}", } for column, definition in required_columns.items(): @@ -518,6 +529,10 @@ class PersistentModelCache: if isinstance(creator_data, dict): creator_username = creator_data.get("username") or None + license_flags = item.get("license_flags") + if license_flags is None: + license_flags = DEFAULT_LICENSE_FLAGS + return ( model_type, item.get("file_path"), @@ -540,6 +555,7 @@ class PersistentModelCache: civitai.get("name"), creator_username, trained_words_json, + int(license_flags), 1 if item.get("civitai_deleted") else 0, 1 if item.get("exclude") else 0, 1 if item.get("db_checked") else 0, diff --git a/py/utils/civitai_utils.py b/py/utils/civitai_utils.py index 02155145..aeab2f49 100644 --- a/py/utils/civitai_utils.py +++ b/py/utils/civitai_utils.py @@ -2,9 +2,141 @@ from __future__ import annotations +from enum import IntEnum +from typing import Any, Dict, Iterable, Mapping, Sequence from urllib.parse import urlparse, urlunparse +class CommercialUseLevel(IntEnum): + """Enumerate supported commercial use permission levels.""" + + NONE = 0 + IMAGE = 1 + RENT_CIVIT = 2 + RENT = 3 + SELL = 4 + + +_DEFAULT_ALLOW_COMMERCIAL_USE: Sequence[str] = ("Sell",) +_LICENSE_DEFAULTS: Dict[str, Any] = { + "allowNoCredit": True, + "allowCommercialUse": _DEFAULT_ALLOW_COMMERCIAL_USE, + "allowDerivatives": True, + "allowDifferentLicense": True, +} +_COMMERCIAL_VALUE_TO_LEVEL = { + "none": CommercialUseLevel.NONE, + "image": CommercialUseLevel.IMAGE, + "rentcivit": CommercialUseLevel.RENT_CIVIT, + "rent": CommercialUseLevel.RENT, + "sell": CommercialUseLevel.SELL, +} + + +def _normalize_commercial_values(value: Any) -> Sequence[str]: + """Return a normalized list of commercial permissions preserving source values.""" + + if value is None: + return list(_DEFAULT_ALLOW_COMMERCIAL_USE) + + if isinstance(value, str): + return [value] + + if isinstance(value, Iterable): + result = [] + for item in value: + if item is None: + continue + if isinstance(item, str): + result.append(item) + continue + result.append(str(item)) + if result: + return result + + return list(_DEFAULT_ALLOW_COMMERCIAL_USE) + + +def _to_bool(value: Any, fallback: bool) -> bool: + if value is None: + return fallback + return bool(value) + + +def resolve_license_payload(model_data: Mapping[str, Any] | None) -> Dict[str, Any]: + """Extract license fields from model metadata applying documented defaults.""" + + payload: Dict[str, Any] = {} + + allow_no_credit = payload["allowNoCredit"] = _to_bool( + (model_data or {}).get("allowNoCredit"), + _LICENSE_DEFAULTS["allowNoCredit"], + ) + + commercial = _normalize_commercial_values( + (model_data or {}).get("allowCommercialUse"), + ) + payload["allowCommercialUse"] = list(commercial) + + allow_derivatives = payload["allowDerivatives"] = _to_bool( + (model_data or {}).get("allowDerivatives"), + _LICENSE_DEFAULTS["allowDerivatives"], + ) + + allow_different_license = payload["allowDifferentLicense"] = _to_bool( + (model_data or {}).get("allowDifferentLicense"), + _LICENSE_DEFAULTS["allowDifferentLicense"], + ) + + # Ensure booleans are plain bool instances + payload["allowNoCredit"] = bool(allow_no_credit) + payload["allowDerivatives"] = bool(allow_derivatives) + payload["allowDifferentLicense"] = bool(allow_different_license) + + return payload + + +def _resolve_commercial_level(values: Sequence[str]) -> CommercialUseLevel: + level = CommercialUseLevel.NONE + for value in values: + normalized = str(value).strip().lower().replace("_", "") + normalized = normalized.replace("-", "") + candidate = _COMMERCIAL_VALUE_TO_LEVEL.get(normalized) + if candidate is None: + continue + if candidate > level: + level = candidate + return level + + +def build_license_flags(payload: Mapping[str, Any] | None) -> int: + """Encode license payload into a compact bitset for cache storage.""" + + resolved = resolve_license_payload(payload or {}) + + flags = 0 + if resolved.get("allowNoCredit", True): + flags |= 1 << 0 + + commercial_level = _resolve_commercial_level(resolved.get("allowCommercialUse", ())) + flags |= (int(commercial_level) & 0b111) << 1 + + if resolved.get("allowDerivatives", True): + flags |= 1 << 4 + + if resolved.get("allowDifferentLicense", True): + flags |= 1 << 5 + + return flags + + +def resolve_license_info(model_data: Mapping[str, Any] | None) -> tuple[Dict[str, Any], int]: + """Return normalized license payload and its encoded bitset.""" + + payload = resolve_license_payload(model_data) + return payload, build_license_flags(payload) + + def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -> tuple[str | None, bool]: """Rewrite Civitai preview URLs to use optimized renditions. @@ -43,5 +175,10 @@ def rewrite_preview_url(source_url: str | None, media_type: str | None = None) - return rewritten, True -__all__ = ["rewrite_preview_url"] - +__all__ = [ + "CommercialUseLevel", + "build_license_flags", + "resolve_license_payload", + "resolve_license_info", + "rewrite_preview_url", +] diff --git a/refs/target_version.json b/refs/target_version.json index 6db1f1b1..17e1d0f8 100644 --- a/refs/target_version.json +++ b/refs/target_version.json @@ -11,7 +11,11 @@ "type": "LORA", "nsfw": false, "description": "description", - "tags": ["style"] + "tags": ["style"], + "allowNoCredit": true, + "allowCommercialUse": ["Sell"], + "allowDerivatives": true, + "allowDifferentLicense": true }, "files": [ { diff --git a/tests/services/test_metadata_sync_service.py b/tests/services/test_metadata_sync_service.py index 4af25168..6d44629d 100644 --- a/tests/services/test_metadata_sync_service.py +++ b/tests/services/test_metadata_sync_service.py @@ -75,6 +75,10 @@ async def test_update_model_metadata_merges_and_persists(): "description": "desc", "tags": ["style"], "creator": {"id": 2}, + "allowNoCredit": False, + "allowCommercialUse": ["Image"], + "allowDerivatives": False, + "allowDifferentLicense": True, }, "baseModel": "sdxl", "images": ["img"], @@ -92,6 +96,13 @@ async def test_update_model_metadata_merges_and_persists(): assert result["modelDescription"] == "desc" assert result["tags"] == ["style"] assert result["base_model"] == "SDXL 1.0" + civitai_model = result["civitai"]["model"] + assert civitai_model["allowNoCredit"] is False + assert civitai_model["allowCommercialUse"] == ["Image"] + assert civitai_model["allowDerivatives"] is False + assert civitai_model["allowDifferentLicense"] is True + for key in ("allowNoCredit", "allowCommercialUse", "allowDerivatives", "allowDifferentLicense"): + assert key not in result helpers.preview_service.ensure_preview_for_metadata.assert_awaited_once() helpers.metadata_manager.save_metadata.assert_awaited_once_with( @@ -142,6 +153,13 @@ async def test_fetch_and_update_model_success_updates_cache(tmp_path): assert model_data["civitai_deleted"] is False assert "civitai" in model_data assert model_data["metadata_source"] == "civitai_api" + civitai_model = model_data["civitai"]["model"] + assert civitai_model["allowNoCredit"] is True + assert civitai_model["allowDerivatives"] is True + assert civitai_model["allowDifferentLicense"] is True + assert civitai_model["allowCommercialUse"] == ["Sell"] + for key in ("allowNoCredit", "allowCommercialUse", "allowDerivatives", "allowDifferentLicense"): + assert key not in model_data helpers.metadata_manager.hydrate_model_data.assert_not_awaited() assert model_data["hydrated"] is True @@ -151,7 +169,13 @@ async def test_fetch_and_update_model_success_updates_cache(tmp_path): assert await_args, "expected metadata to be persisted" last_call = await_args[-1] assert last_call.args[0] == metadata_path - assert last_call.args[1]["hydrated"] is True + persisted_payload = last_call.args[1] + assert persisted_payload["hydrated"] is True + civitai_model = persisted_payload["civitai"]["model"] + assert civitai_model["allowNoCredit"] is True + assert civitai_model["allowCommercialUse"] == ["Sell"] + for key in ("allowNoCredit", "allowCommercialUse", "allowDerivatives", "allowDifferentLicense"): + assert key not in persisted_payload update_cache.assert_awaited_once() @@ -422,4 +446,3 @@ async def test_relink_metadata_raises_when_version_missing(): model_id=9, model_version_id=None, ) - diff --git a/tests/services/test_model_scanner.py b/tests/services/test_model_scanner.py index 972d5f5a..02d85f97 100644 --- a/tests/services/test_model_scanner.py +++ b/tests/services/test_model_scanner.py @@ -11,7 +11,8 @@ from py.services import model_scanner from py.services.model_cache import ModelCache from py.services.model_hash_index import ModelHashIndex from py.services.model_scanner import CacheBuildResult, ModelScanner -from py.services.persistent_model_cache import PersistentModelCache +from py.services.persistent_model_cache import PersistentModelCache, DEFAULT_LICENSE_FLAGS +from py.utils.civitai_utils import build_license_flags from py.utils.models import BaseModelMetadata @@ -122,6 +123,7 @@ async def test_initialize_cache_populates_cache(tmp_path: Path): _normalize_path(tmp_path / "one.txt"), _normalize_path(tmp_path / "nested" / "two.txt"), } + assert {item["license_flags"] for item in cache.raw_data} == {DEFAULT_LICENSE_FLAGS} assert scanner._hash_index.get_path("hash-one") == _normalize_path(tmp_path / "one.txt") assert scanner._hash_index.get_path("hash-two") == _normalize_path(tmp_path / "nested" / "two.txt") @@ -179,12 +181,49 @@ async def test_initialize_in_background_applies_scan_result(tmp_path: Path, monk _normalize_path(tmp_path / "one.txt"), _normalize_path(tmp_path / "nested" / "two.txt"), } + assert {item["license_flags"] for item in cache.raw_data} == {DEFAULT_LICENSE_FLAGS} assert scanner._hash_index.get_path("hash-two") == _normalize_path(tmp_path / "nested" / "two.txt") assert scanner._tags_count == {"alpha": 1, "beta": 1} assert scanner._excluded_models == [_normalize_path(tmp_path / "skip-file.txt")] assert ws_stub.payloads[-1]["progress"] == 100 +@pytest.mark.asyncio +async def test_build_cache_entry_encodes_license_flags(tmp_path: Path): + scanner = DummyScanner(tmp_path) + + metadata = { + "file_path": _normalize_path(tmp_path / "sample.txt"), + "file_name": "sample", + "model_name": "Sample", + "folder": "", + "size": 1, + "modified": 1.0, + "sha256": "hash", + "tags": [], + "civitai": { + "model": { + "allowNoCredit": False, + "allowCommercialUse": ["Image", "Rent"], + "allowDerivatives": True, + "allowDifferentLicense": False, + } + }, + } + + expected_flags = build_license_flags( + { + "allowNoCredit": False, + "allowCommercialUse": ["Image", "Rent"], + "allowDerivatives": True, + "allowDifferentLicense": False, + } + ) + + entry = scanner._build_cache_entry(metadata) + assert entry["license_flags"] == expected_flags + + @pytest.mark.asyncio async def test_initialize_in_background_uses_persisted_cache_without_full_scan(tmp_path: Path, monkeypatch): monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0') diff --git a/tests/services/test_persistent_model_cache.py b/tests/services/test_persistent_model_cache.py index 602ef85e..b861bef5 100644 --- a/tests/services/test_persistent_model_cache.py +++ b/tests/services/test_persistent_model_cache.py @@ -2,7 +2,7 @@ from pathlib import Path import pytest -from py.services.persistent_model_cache import PersistentModelCache +from py.services.persistent_model_cache import PersistentModelCache, DEFAULT_LICENSE_FLAGS def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None: @@ -43,6 +43,7 @@ def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None: 'trainedWords': ['word1'], 'creator': {'username': 'artist42'}, }, + 'license_flags': 13, }, { 'file_path': file_b, @@ -91,12 +92,14 @@ def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None: assert first['metadata_source'] == 'civitai_api' assert first['civitai']['creator']['username'] == 'artist42' assert first['civitai_deleted'] is False + assert first['license_flags'] == 13 second = items[file_b] assert second['exclude'] is True assert second['civitai'] is None assert second['metadata_source'] is None assert second['civitai_deleted'] is True + assert second['license_flags'] == DEFAULT_LICENSE_FLAGS expected_hash_pairs = { ('hash-a', file_a), diff --git a/tests/utils/test_civitai_utils.py b/tests/utils/test_civitai_utils.py new file mode 100644 index 00000000..1204e56b --- /dev/null +++ b/tests/utils/test_civitai_utils.py @@ -0,0 +1,35 @@ +from py.utils.civitai_utils import ( + CommercialUseLevel, + build_license_flags, + resolve_license_info, + resolve_license_payload, +) + + +def test_resolve_license_payload_defaults(): + payload, flags = resolve_license_info({}) + + assert payload["allowNoCredit"] is True + assert payload["allowDerivatives"] is True + assert payload["allowDifferentLicense"] is True + assert payload["allowCommercialUse"] == ["Sell"] + assert flags == 57 + + +def test_build_license_flags_custom_values(): + source = { + "allowNoCredit": False, + "allowCommercialUse": {"Image", "Sell"}, + "allowDerivatives": False, + "allowDifferentLicense": False, + } + + payload = resolve_license_payload(source) + assert payload["allowNoCredit"] is False + assert set(payload["allowCommercialUse"]) == {"Image", "Sell"} + assert payload["allowDerivatives"] is False + assert payload["allowDifferentLicense"] is False + + flags = build_license_flags(source) + # Highest commercial level is SELL -> level 4 -> shifted by 1 == 8. + assert flags == (CommercialUseLevel.SELL << 1)