feat: add license information handling for Civitai models

Add license resolution utilities and integrate license information into model metadata processing. The changes include:

- Add `resolve_license_payload` function to extract license data from Civitai model responses
- Integrate license information into model metadata in CivitaiClient and MetadataSyncService
- Add license flags support in model scanning and caching
- Implement CommercialUseLevel enum for standardized license classification
- Update model scanner to handle unknown fields when extracting metadata values

This ensures proper license attribution and compliance when working with Civitai models.
This commit is contained in:
Will Miao
2025-11-06 17:05:31 +08:00
parent 4301b3455f
commit ddf9e33961
10 changed files with 364 additions and 10 deletions

View File

@@ -11,6 +11,7 @@ from ..utils.models import BaseModelMetadata
from ..config import config
from ..utils.file_utils import find_preview_file, get_preview_extension
from ..utils.metadata_manager import MetadataManager
from ..utils.civitai_utils import resolve_license_info
from .model_cache import ModelCache
from .model_hash_index import ModelHashIndex
from ..utils.constants import PREVIEW_EXTENSIONS
@@ -175,7 +176,17 @@ class ModelScanner:
def get_value(key: str, default: Any = None) -> Any:
if is_mapping:
return source.get(key, default)
return getattr(source, key, default)
sentinel = object()
value = getattr(source, key, sentinel)
if value is not sentinel:
return value
unknown = getattr(source, "_unknown_fields", None)
if isinstance(unknown, dict) and key in unknown:
return unknown[key]
return default
file_path = file_path_override or get_value('file_path', '') or ''
normalized_path = file_path.replace('\\', '/')
@@ -197,7 +208,8 @@ class ModelScanner:
else:
preview_url = ''
civitai_slim = self._slim_civitai_payload(get_value('civitai'))
civitai_full = get_value('civitai')
civitai_slim = self._slim_civitai_payload(civitai_full)
usage_tips = get_value('usage_tips', '') or ''
if not isinstance(usage_tips, str):
usage_tips = str(usage_tips)
@@ -229,12 +241,76 @@ class ModelScanner:
'civitai_deleted': bool(get_value('civitai_deleted', False)),
}
license_source: Dict[str, Any] = {}
if isinstance(civitai_full, Mapping):
civitai_model = civitai_full.get('model')
if isinstance(civitai_model, Mapping):
for key in (
'allowNoCredit',
'allowCommercialUse',
'allowDerivatives',
'allowDifferentLicense',
):
if key in civitai_model:
license_source[key] = civitai_model.get(key)
for key in (
'allowNoCredit',
'allowCommercialUse',
'allowDerivatives',
'allowDifferentLicense',
):
if key not in license_source:
value = get_value(key)
if value is not None:
license_source[key] = value
_, license_flags = resolve_license_info(license_source or {})
entry['license_flags'] = license_flags
model_type = get_value('model_type', None)
if model_type:
entry['model_type'] = model_type
return entry
def _ensure_license_flags(self, entry: Dict[str, Any]) -> None:
"""Ensure cached entries include an integer license flag bitset."""
if not isinstance(entry, dict):
return
license_value = entry.get('license_flags')
if license_value is not None:
try:
entry['license_flags'] = int(license_value)
except (TypeError, ValueError):
_, fallback_flags = resolve_license_info({})
entry['license_flags'] = fallback_flags
return
license_source = {
'allowNoCredit': entry.get('allowNoCredit'),
'allowCommercialUse': entry.get('allowCommercialUse'),
'allowDerivatives': entry.get('allowDerivatives'),
'allowDifferentLicense': entry.get('allowDifferentLicense'),
}
civitai_full = entry.get('civitai')
if isinstance(civitai_full, Mapping):
civitai_model = civitai_full.get('model')
if isinstance(civitai_model, Mapping):
for key in (
'allowNoCredit',
'allowCommercialUse',
'allowDerivatives',
'allowDifferentLicense',
):
if key in civitai_model:
license_source[key] = civitai_model.get(key)
_, license_flags = resolve_license_info(license_source)
entry['license_flags'] = license_flags
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
@@ -567,6 +643,7 @@ class ModelScanner:
async def _initialize_cache(self) -> None:
"""Initialize or refresh the cache"""
print("init start", flush=True)
self._is_initializing = True # Set flag
try:
start_time = time.time()
@@ -575,6 +652,7 @@ class ModelScanner:
scan_result = await self._gather_model_data()
await self._apply_scan_result(scan_result)
await self._save_persistent_cache(scan_result)
print("init end", flush=True)
logger.info(
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
@@ -681,6 +759,7 @@ class ModelScanner:
model_data = self.adjust_cached_entry(dict(model_data))
if not model_data:
continue
self._ensure_license_flags(model_data)
# Add to cache
self._cache.raw_data.append(model_data)
self._cache.add_to_version_index(model_data)
@@ -975,6 +1054,7 @@ class ModelScanner:
processed_files += 1
if result:
self._ensure_license_flags(result)
raw_data.append(result)
sha_value = result.get('sha256')