feat: add license information handling for Civitai models

Add license resolution utilities and integrate license information into model metadata processing. The changes include:

- Add `resolve_license_payload` function to extract license data from Civitai model responses
- Integrate license information into model metadata in CivitaiClient and MetadataSyncService
- Add license flags support in model scanning and caching
- Implement CommercialUseLevel enum for standardized license classification
- Update model scanner to handle unknown fields when extracting metadata values

This ensures proper license attribution and compliance when working with Civitai models.
This commit is contained in:
Will Miao
2025-11-06 17:05:31 +08:00
parent 4301b3455f
commit ddf9e33961
10 changed files with 364 additions and 10 deletions

View File

@@ -6,6 +6,7 @@ from typing import Any, Optional, Dict, Tuple, List, Sequence
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
from .downloader import get_downloader
from .errors import RateLimitError, ResourceNotFoundError
from ..utils.civitai_utils import resolve_license_payload
logger = logging.getLogger(__name__)
@@ -420,6 +421,10 @@ class CivitaiClient:
model_info['tags'] = model_data.get("tags", [])
version['creator'] = model_data.get("creator")
license_payload = resolve_license_payload(model_data)
for field, value in license_payload.items():
model_info[field] = value
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from Civitai

View File

@@ -9,6 +9,7 @@ from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
from ..services.settings_manager import SettingsManager
from ..utils.civitai_utils import resolve_license_payload
from ..utils.model_utils import determine_base_model
from .errors import RateLimitError
@@ -135,6 +136,17 @@ class MetadataSyncService:
):
local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"]
merged_civitai = local_metadata.get("civitai") or {}
civitai_model = merged_civitai.get("model")
if not isinstance(civitai_model, dict):
civitai_model = {}
license_payload = resolve_license_payload(model_data)
civitai_model.update(license_payload)
merged_civitai["model"] = civitai_model
local_metadata["civitai"] = merged_civitai
local_metadata["base_model"] = determine_base_model(
civitai_metadata.get("baseModel")
)
@@ -295,6 +307,7 @@ class MetadataSyncService:
"preview_url": local_metadata.get("preview_url"),
"civitai": local_metadata.get("civitai"),
}
model_data.update(update_payload)
await update_cache_func(file_path, file_path, local_metadata)
@@ -436,4 +449,3 @@ class MetadataSyncService:
results["verified_as_duplicates"] = False
return results

View File

@@ -11,6 +11,7 @@ from ..utils.models import BaseModelMetadata
from ..config import config
from ..utils.file_utils import find_preview_file, get_preview_extension
from ..utils.metadata_manager import MetadataManager
from ..utils.civitai_utils import resolve_license_info
from .model_cache import ModelCache
from .model_hash_index import ModelHashIndex
from ..utils.constants import PREVIEW_EXTENSIONS
@@ -175,7 +176,17 @@ class ModelScanner:
def get_value(key: str, default: Any = None) -> Any:
if is_mapping:
return source.get(key, default)
return getattr(source, key, default)
sentinel = object()
value = getattr(source, key, sentinel)
if value is not sentinel:
return value
unknown = getattr(source, "_unknown_fields", None)
if isinstance(unknown, dict) and key in unknown:
return unknown[key]
return default
file_path = file_path_override or get_value('file_path', '') or ''
normalized_path = file_path.replace('\\', '/')
@@ -197,7 +208,8 @@ class ModelScanner:
else:
preview_url = ''
civitai_slim = self._slim_civitai_payload(get_value('civitai'))
civitai_full = get_value('civitai')
civitai_slim = self._slim_civitai_payload(civitai_full)
usage_tips = get_value('usage_tips', '') or ''
if not isinstance(usage_tips, str):
usage_tips = str(usage_tips)
@@ -229,12 +241,76 @@ class ModelScanner:
'civitai_deleted': bool(get_value('civitai_deleted', False)),
}
license_source: Dict[str, Any] = {}
if isinstance(civitai_full, Mapping):
civitai_model = civitai_full.get('model')
if isinstance(civitai_model, Mapping):
for key in (
'allowNoCredit',
'allowCommercialUse',
'allowDerivatives',
'allowDifferentLicense',
):
if key in civitai_model:
license_source[key] = civitai_model.get(key)
for key in (
'allowNoCredit',
'allowCommercialUse',
'allowDerivatives',
'allowDifferentLicense',
):
if key not in license_source:
value = get_value(key)
if value is not None:
license_source[key] = value
_, license_flags = resolve_license_info(license_source or {})
entry['license_flags'] = license_flags
model_type = get_value('model_type', None)
if model_type:
entry['model_type'] = model_type
return entry
def _ensure_license_flags(self, entry: Dict[str, Any]) -> None:
"""Ensure cached entries include an integer license flag bitset."""
if not isinstance(entry, dict):
return
license_value = entry.get('license_flags')
if license_value is not None:
try:
entry['license_flags'] = int(license_value)
except (TypeError, ValueError):
_, fallback_flags = resolve_license_info({})
entry['license_flags'] = fallback_flags
return
license_source = {
'allowNoCredit': entry.get('allowNoCredit'),
'allowCommercialUse': entry.get('allowCommercialUse'),
'allowDerivatives': entry.get('allowDerivatives'),
'allowDifferentLicense': entry.get('allowDifferentLicense'),
}
civitai_full = entry.get('civitai')
if isinstance(civitai_full, Mapping):
civitai_model = civitai_full.get('model')
if isinstance(civitai_model, Mapping):
for key in (
'allowNoCredit',
'allowCommercialUse',
'allowDerivatives',
'allowDifferentLicense',
):
if key in civitai_model:
license_source[key] = civitai_model.get(key)
_, license_flags = resolve_license_info(license_source)
entry['license_flags'] = license_flags
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
@@ -567,6 +643,7 @@ class ModelScanner:
async def _initialize_cache(self) -> None:
"""Initialize or refresh the cache"""
print("init start", flush=True)
self._is_initializing = True # Set flag
try:
start_time = time.time()
@@ -575,6 +652,7 @@ class ModelScanner:
scan_result = await self._gather_model_data()
await self._apply_scan_result(scan_result)
await self._save_persistent_cache(scan_result)
print("init end", flush=True)
logger.info(
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
@@ -681,6 +759,7 @@ class ModelScanner:
model_data = self.adjust_cached_entry(dict(model_data))
if not model_data:
continue
self._ensure_license_flags(model_data)
# Add to cache
self._cache.raw_data.append(model_data)
self._cache.add_to_version_index(model_data)
@@ -975,6 +1054,7 @@ class ModelScanner:
processed_files += 1
if result:
self._ensure_license_flags(result)
raw_data.append(result)
sha_value = result.get('sha256')

View File

@@ -21,6 +21,9 @@ class PersistedCacheData:
excluded_models: List[str]
DEFAULT_LICENSE_FLAGS = 57 # 57 (0b111001) encodes CivitAI's documented default license permissions.
class PersistentModelCache:
"""Persist core model metadata and hash index data in SQLite."""
@@ -47,6 +50,7 @@ class PersistentModelCache:
"civitai_name",
"civitai_creator_username",
"trained_words",
"license_flags",
"civitai_deleted",
"exclude",
"db_checked",
@@ -149,6 +153,10 @@ class PersistentModelCache:
if creator_username:
civitai.setdefault("creator", {})["username"] = creator_username
license_value = row["license_flags"]
if license_value is None:
license_value = DEFAULT_LICENSE_FLAGS
item = {
"file_path": file_path,
"file_name": row["file_name"],
@@ -171,6 +179,7 @@ class PersistentModelCache:
"tags": tags.get(file_path, []),
"civitai": civitai,
"civitai_deleted": bool(row["civitai_deleted"]),
"license_flags": int(license_value),
}
raw_data.append(item)
@@ -484,6 +493,8 @@ class PersistentModelCache:
"metadata_source": "TEXT",
"civitai_creator_username": "TEXT",
"civitai_deleted": "INTEGER DEFAULT 0",
# Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57).
"license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}",
}
for column, definition in required_columns.items():
@@ -518,6 +529,10 @@ class PersistentModelCache:
if isinstance(creator_data, dict):
creator_username = creator_data.get("username") or None
license_flags = item.get("license_flags")
if license_flags is None:
license_flags = DEFAULT_LICENSE_FLAGS
return (
model_type,
item.get("file_path"),
@@ -540,6 +555,7 @@ class PersistentModelCache:
civitai.get("name"),
creator_username,
trained_words_json,
int(license_flags),
1 if item.get("civitai_deleted") else 0,
1 if item.get("exclude") else 0,
1 if item.get("db_checked") else 0,

View File

@@ -2,9 +2,141 @@
from __future__ import annotations
from enum import IntEnum
from typing import Any, Dict, Iterable, Mapping, Sequence
from urllib.parse import urlparse, urlunparse
class CommercialUseLevel(IntEnum):
    """Commercial-use permission tiers as ordered integers.

    Being an ``IntEnum``, members compare numerically, so the largest value
    wins when several permissions are combined, and a member can be packed
    directly into an integer bitset.
    """

    # Values run from 0 to 4 and therefore fit in three bits.
    NONE = 0
    IMAGE = 1
    RENT_CIVIT = 2
    RENT = 3
    SELL = 4
# Commercial-use permissions substituted when a model supplies no usable
# ``allowCommercialUse`` value.
_DEFAULT_ALLOW_COMMERCIAL_USE: Sequence[str] = ("Sell",)
# Per-field fallbacks; the boolean entries are used by
# ``resolve_license_payload`` when a field is missing or ``None``.
_LICENSE_DEFAULTS: Dict[str, Any] = {
    "allowNoCredit": True,
    "allowCommercialUse": _DEFAULT_ALLOW_COMMERCIAL_USE,
    "allowDerivatives": True,
    "allowDifferentLicense": True,
}
# Lookup from a canonical permission string (lower-case, with "_" and "-"
# removed, as produced in ``_resolve_commercial_level``) to its level.
_COMMERCIAL_VALUE_TO_LEVEL = {
    "none": CommercialUseLevel.NONE,
    "image": CommercialUseLevel.IMAGE,
    "rentcivit": CommercialUseLevel.RENT_CIVIT,
    "rent": CommercialUseLevel.RENT,
    "sell": CommercialUseLevel.SELL,
}
def _normalize_commercial_values(value: Any) -> Sequence[str]:
"""Return a normalized list of commercial permissions preserving source values."""
if value is None:
return list(_DEFAULT_ALLOW_COMMERCIAL_USE)
if isinstance(value, str):
return [value]
if isinstance(value, Iterable):
result = []
for item in value:
if item is None:
continue
if isinstance(item, str):
result.append(item)
continue
result.append(str(item))
if result:
return result
return list(_DEFAULT_ALLOW_COMMERCIAL_USE)
def _to_bool(value: Any, fallback: bool) -> bool:
if value is None:
return fallback
return bool(value)
def resolve_license_payload(model_data: Mapping[str, Any] | None) -> Dict[str, Any]:
    """Build a normalized license payload from model metadata.

    The result always contains ``allowNoCredit``, ``allowCommercialUse``,
    ``allowDerivatives`` and ``allowDifferentLicense``, in that order.
    Boolean fields that are missing or ``None`` take their value from
    ``_LICENSE_DEFAULTS``; ``allowCommercialUse`` is returned as a new list
    produced by ``_normalize_commercial_values``.

    Args:
        model_data: Mapping holding the raw license fields, or ``None``
            (treated like an empty mapping).

    Returns:
        A new dict with the four normalized license fields.
    """
    source = model_data or {}

    def _flag(field: str) -> bool:
        # The outer bool() guarantees a plain ``bool`` in the payload.
        return bool(_to_bool(source.get(field), _LICENSE_DEFAULTS[field]))

    return {
        "allowNoCredit": _flag("allowNoCredit"),
        "allowCommercialUse": list(
            _normalize_commercial_values(source.get("allowCommercialUse"))
        ),
        "allowDerivatives": _flag("allowDerivatives"),
        "allowDifferentLicense": _flag("allowDifferentLicense"),
    }
def _resolve_commercial_level(values: Sequence[str]) -> CommercialUseLevel:
    """Return the highest recognised commercial level named in ``values``.

    Each value is stringified, stripped, lower-cased and has ``_`` and ``-``
    removed before being looked up in ``_COMMERCIAL_VALUE_TO_LEVEL``.
    Unrecognised values are ignored; if nothing matches, the result is
    ``CommercialUseLevel.NONE``.
    """
    recognised = []
    for raw in values:
        lookup_key = str(raw).strip().lower().replace("_", "").replace("-", "")
        matched = _COMMERCIAL_VALUE_TO_LEVEL.get(lookup_key)
        if matched is not None:
            recognised.append(matched)
    return max(recognised, default=CommercialUseLevel.NONE)
def build_license_flags(payload: Mapping[str, Any] | None) -> int:
    """Pack a license payload into an integer bitset for cache storage.

    The payload is first normalized with ``resolve_license_payload``.

    Bit layout:
        * bit 0      - ``allowNoCredit``
        * bits 1-3   - commercial level (``CommercialUseLevel`` value)
        * bit 4      - ``allowDerivatives``
        * bit 5      - ``allowDifferentLicense``

    Args:
        payload: Raw or already-normalized license mapping, or ``None``.

    Returns:
        The encoded flags as a plain ``int``.
    """
    normalized = resolve_license_payload(payload or {})
    level = _resolve_commercial_level(normalized.get("allowCommercialUse", ()))

    # Start from the three-bit commercial level, then add the boolean bits.
    bits = (int(level) & 0b111) << 1
    if normalized.get("allowNoCredit", True):
        bits |= 1 << 0
    if normalized.get("allowDerivatives", True):
        bits |= 1 << 4
    if normalized.get("allowDifferentLicense", True):
        bits |= 1 << 5
    return bits
def resolve_license_info(model_data: Mapping[str, Any] | None) -> tuple[Dict[str, Any], int]:
    """Normalize license metadata and encode it in one call.

    Returns:
        A ``(payload, flags)`` pair: the dict from
        ``resolve_license_payload`` and the bitset that
        ``build_license_flags`` derives from that dict.
    """
    normalized = resolve_license_payload(model_data)
    encoded = build_license_flags(normalized)
    return normalized, encoded
def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -> tuple[str | None, bool]:
"""Rewrite Civitai preview URLs to use optimized renditions.
@@ -43,5 +175,10 @@ def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -
return rewritten, True
__all__ = ["rewrite_preview_url"]
__all__ = [
"CommercialUseLevel",
"build_license_flags",
"resolve_license_payload",
"resolve_license_info",
"rewrite_preview_url",
]

View File

@@ -11,7 +11,11 @@
"type": "LORA",
"nsfw": false,
"description": "description",
"tags": ["style"]
"tags": ["style"],
"allowNoCredit": true,
"allowCommercialUse": ["Sell"],
"allowDerivatives": true,
"allowDifferentLicense": true
},
"files": [
{

View File

@@ -75,6 +75,10 @@ async def test_update_model_metadata_merges_and_persists():
"description": "desc",
"tags": ["style"],
"creator": {"id": 2},
"allowNoCredit": False,
"allowCommercialUse": ["Image"],
"allowDerivatives": False,
"allowDifferentLicense": True,
},
"baseModel": "sdxl",
"images": ["img"],
@@ -92,6 +96,13 @@ async def test_update_model_metadata_merges_and_persists():
assert result["modelDescription"] == "desc"
assert result["tags"] == ["style"]
assert result["base_model"] == "SDXL 1.0"
civitai_model = result["civitai"]["model"]
assert civitai_model["allowNoCredit"] is False
assert civitai_model["allowCommercialUse"] == ["Image"]
assert civitai_model["allowDerivatives"] is False
assert civitai_model["allowDifferentLicense"] is True
for key in ("allowNoCredit", "allowCommercialUse", "allowDerivatives", "allowDifferentLicense"):
assert key not in result
helpers.preview_service.ensure_preview_for_metadata.assert_awaited_once()
helpers.metadata_manager.save_metadata.assert_awaited_once_with(
@@ -142,6 +153,13 @@ async def test_fetch_and_update_model_success_updates_cache(tmp_path):
assert model_data["civitai_deleted"] is False
assert "civitai" in model_data
assert model_data["metadata_source"] == "civitai_api"
civitai_model = model_data["civitai"]["model"]
assert civitai_model["allowNoCredit"] is True
assert civitai_model["allowDerivatives"] is True
assert civitai_model["allowDifferentLicense"] is True
assert civitai_model["allowCommercialUse"] == ["Sell"]
for key in ("allowNoCredit", "allowCommercialUse", "allowDerivatives", "allowDifferentLicense"):
assert key not in model_data
helpers.metadata_manager.hydrate_model_data.assert_not_awaited()
assert model_data["hydrated"] is True
@@ -151,7 +169,13 @@ async def test_fetch_and_update_model_success_updates_cache(tmp_path):
assert await_args, "expected metadata to be persisted"
last_call = await_args[-1]
assert last_call.args[0] == metadata_path
assert last_call.args[1]["hydrated"] is True
persisted_payload = last_call.args[1]
assert persisted_payload["hydrated"] is True
civitai_model = persisted_payload["civitai"]["model"]
assert civitai_model["allowNoCredit"] is True
assert civitai_model["allowCommercialUse"] == ["Sell"]
for key in ("allowNoCredit", "allowCommercialUse", "allowDerivatives", "allowDifferentLicense"):
assert key not in persisted_payload
update_cache.assert_awaited_once()
@@ -422,4 +446,3 @@ async def test_relink_metadata_raises_when_version_missing():
model_id=9,
model_version_id=None,
)

View File

@@ -11,7 +11,8 @@ from py.services import model_scanner
from py.services.model_cache import ModelCache
from py.services.model_hash_index import ModelHashIndex
from py.services.model_scanner import CacheBuildResult, ModelScanner
from py.services.persistent_model_cache import PersistentModelCache
from py.services.persistent_model_cache import PersistentModelCache, DEFAULT_LICENSE_FLAGS
from py.utils.civitai_utils import build_license_flags
from py.utils.models import BaseModelMetadata
@@ -122,6 +123,7 @@ async def test_initialize_cache_populates_cache(tmp_path: Path):
_normalize_path(tmp_path / "one.txt"),
_normalize_path(tmp_path / "nested" / "two.txt"),
}
assert {item["license_flags"] for item in cache.raw_data} == {DEFAULT_LICENSE_FLAGS}
assert scanner._hash_index.get_path("hash-one") == _normalize_path(tmp_path / "one.txt")
assert scanner._hash_index.get_path("hash-two") == _normalize_path(tmp_path / "nested" / "two.txt")
@@ -179,12 +181,49 @@ async def test_initialize_in_background_applies_scan_result(tmp_path: Path, monk
_normalize_path(tmp_path / "one.txt"),
_normalize_path(tmp_path / "nested" / "two.txt"),
}
assert {item["license_flags"] for item in cache.raw_data} == {DEFAULT_LICENSE_FLAGS}
assert scanner._hash_index.get_path("hash-two") == _normalize_path(tmp_path / "nested" / "two.txt")
assert scanner._tags_count == {"alpha": 1, "beta": 1}
assert scanner._excluded_models == [_normalize_path(tmp_path / "skip-file.txt")]
assert ws_stub.payloads[-1]["progress"] == 100
@pytest.mark.asyncio
async def test_build_cache_entry_encodes_license_flags(tmp_path: Path):
    """``_build_cache_entry`` stores the bitset built from ``civitai.model``."""

    def license_fields() -> dict:
        # Fresh objects on every call so the metadata and the expectation
        # never share a mutable list.
        return {
            "allowNoCredit": False,
            "allowCommercialUse": ["Image", "Rent"],
            "allowDerivatives": True,
            "allowDifferentLicense": False,
        }

    scanner = DummyScanner(tmp_path)
    metadata = {
        "file_path": _normalize_path(tmp_path / "sample.txt"),
        "file_name": "sample",
        "model_name": "Sample",
        "folder": "",
        "size": 1,
        "modified": 1.0,
        "sha256": "hash",
        "tags": [],
        "civitai": {"model": license_fields()},
    }
    expected_flags = build_license_flags(license_fields())

    entry = scanner._build_cache_entry(metadata)

    assert entry["license_flags"] == expected_flags
@pytest.mark.asyncio
async def test_initialize_in_background_uses_persisted_cache_without_full_scan(tmp_path: Path, monkeypatch):
monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0')

View File

@@ -2,7 +2,7 @@ from pathlib import Path
import pytest
from py.services.persistent_model_cache import PersistentModelCache
from py.services.persistent_model_cache import PersistentModelCache, DEFAULT_LICENSE_FLAGS
def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None:
@@ -43,6 +43,7 @@ def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None:
'trainedWords': ['word1'],
'creator': {'username': 'artist42'},
},
'license_flags': 13,
},
{
'file_path': file_b,
@@ -91,12 +92,14 @@ def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None:
assert first['metadata_source'] == 'civitai_api'
assert first['civitai']['creator']['username'] == 'artist42'
assert first['civitai_deleted'] is False
assert first['license_flags'] == 13
second = items[file_b]
assert second['exclude'] is True
assert second['civitai'] is None
assert second['metadata_source'] is None
assert second['civitai_deleted'] is True
assert second['license_flags'] == DEFAULT_LICENSE_FLAGS
expected_hash_pairs = {
('hash-a', file_a),

View File

@@ -0,0 +1,35 @@
from py.utils.civitai_utils import (
CommercialUseLevel,
build_license_flags,
resolve_license_info,
resolve_license_payload,
)
def test_resolve_license_payload_defaults():
    """An empty mapping resolves to the default payload and flag value 57."""
    payload, flags = resolve_license_info({})

    for field in ("allowNoCredit", "allowDerivatives", "allowDifferentLicense"):
        assert payload[field] is True
    assert payload["allowCommercialUse"] == ["Sell"]
    assert flags == 57
def test_build_license_flags_custom_values():
    """With every boolean permission off, only the commercial level is set."""
    source = {
        "allowNoCredit": False,
        "allowCommercialUse": {"Image", "Sell"},
        "allowDerivatives": False,
        "allowDifferentLicense": False,
    }

    payload = resolve_license_payload(source)
    for field in ("allowNoCredit", "allowDerivatives", "allowDifferentLicense"):
        assert payload[field] is False
    # The source is a set, so compare without relying on element order.
    assert set(payload["allowCommercialUse"]) == {"Image", "Sell"}

    # SELL (4) is the highest listed level; stored one bit up it equals 8.
    assert build_license_flags(source) == (CommercialUseLevel.SELL << 1)