feat: add license information handling for Civitai models

Add license resolution utilities and integrate license information into model metadata processing. The changes include:

- Add `resolve_license_payload` function to extract license data from Civitai model responses
- Integrate license information into model metadata in CivitaiClient and MetadataSyncService
- Add license flags support in model scanning and caching
- Implement CommercialUseLevel enum for standardized license classification
- Update model scanner to handle unknown fields when extracting metadata values

This ensures proper license attribution and compliance when working with Civitai models.
This commit is contained in:
Will Miao
2025-11-06 17:05:31 +08:00
parent 4301b3455f
commit ddf9e33961
10 changed files with 364 additions and 10 deletions

View File

@@ -6,6 +6,7 @@ from typing import Any, Optional, Dict, Tuple, List, Sequence
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
from .downloader import get_downloader
from .errors import RateLimitError, ResourceNotFoundError
from ..utils.civitai_utils import resolve_license_payload
logger = logging.getLogger(__name__)
@@ -420,6 +421,10 @@ class CivitaiClient:
model_info['tags'] = model_data.get("tags", [])
version['creator'] = model_data.get("creator")
license_payload = resolve_license_payload(model_data)
for field, value in license_payload.items():
model_info[field] = value
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from Civitai

View File

@@ -9,6 +9,7 @@ from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
from ..services.settings_manager import SettingsManager
from ..utils.civitai_utils import resolve_license_payload
from ..utils.model_utils import determine_base_model
from .errors import RateLimitError
@@ -135,6 +136,17 @@ class MetadataSyncService:
):
local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"]
merged_civitai = local_metadata.get("civitai") or {}
civitai_model = merged_civitai.get("model")
if not isinstance(civitai_model, dict):
civitai_model = {}
license_payload = resolve_license_payload(model_data)
civitai_model.update(license_payload)
merged_civitai["model"] = civitai_model
local_metadata["civitai"] = merged_civitai
local_metadata["base_model"] = determine_base_model(
civitai_metadata.get("baseModel")
)
@@ -295,6 +307,7 @@ class MetadataSyncService:
"preview_url": local_metadata.get("preview_url"),
"civitai": local_metadata.get("civitai"),
}
model_data.update(update_payload)
await update_cache_func(file_path, file_path, local_metadata)
@@ -436,4 +449,3 @@ class MetadataSyncService:
results["verified_as_duplicates"] = False
return results

View File

@@ -11,6 +11,7 @@ from ..utils.models import BaseModelMetadata
from ..config import config
from ..utils.file_utils import find_preview_file, get_preview_extension
from ..utils.metadata_manager import MetadataManager
from ..utils.civitai_utils import resolve_license_info
from .model_cache import ModelCache
from .model_hash_index import ModelHashIndex
from ..utils.constants import PREVIEW_EXTENSIONS
@@ -175,7 +176,17 @@ class ModelScanner:
def get_value(key: str, default: Any = None) -> Any:
    """Read ``key`` from the enclosing ``source``, whatever its shape.

    ``source`` and ``is_mapping`` come from the enclosing scope.

    Lookup order:
      1. When ``is_mapping`` is true, ``source.get(key, default)``.
      2. Otherwise a regular attribute on ``source``.
      3. Otherwise an entry in ``source._unknown_fields`` when that
         attribute is a dict containing ``key``.
      4. Otherwise ``default``.
    """
    if is_mapping:
        return source.get(key, default)

    # A private sentinel distinguishes "attribute missing" from an attribute
    # that is present but legitimately set to None (or to ``default``).
    sentinel = object()
    value = getattr(source, key, sentinel)
    if value is not sentinel:
        return value

    # The previous unconditional ``return getattr(source, key, default)``
    # made this fallback unreachable; it must only run once the attribute
    # lookup has actually missed.
    unknown = getattr(source, "_unknown_fields", None)
    if isinstance(unknown, dict) and key in unknown:
        return unknown[key]

    return default
file_path = file_path_override or get_value('file_path', '') or ''
normalized_path = file_path.replace('\\', '/')
@@ -197,7 +208,8 @@ class ModelScanner:
else:
preview_url = ''
civitai_slim = self._slim_civitai_payload(get_value('civitai'))
civitai_full = get_value('civitai')
civitai_slim = self._slim_civitai_payload(civitai_full)
usage_tips = get_value('usage_tips', '') or ''
if not isinstance(usage_tips, str):
usage_tips = str(usage_tips)
@@ -229,12 +241,76 @@ class ModelScanner:
'civitai_deleted': bool(get_value('civitai_deleted', False)),
}
license_source: Dict[str, Any] = {}
if isinstance(civitai_full, Mapping):
civitai_model = civitai_full.get('model')
if isinstance(civitai_model, Mapping):
for key in (
'allowNoCredit',
'allowCommercialUse',
'allowDerivatives',
'allowDifferentLicense',
):
if key in civitai_model:
license_source[key] = civitai_model.get(key)
for key in (
'allowNoCredit',
'allowCommercialUse',
'allowDerivatives',
'allowDifferentLicense',
):
if key not in license_source:
value = get_value(key)
if value is not None:
license_source[key] = value
_, license_flags = resolve_license_info(license_source or {})
entry['license_flags'] = license_flags
model_type = get_value('model_type', None)
if model_type:
entry['model_type'] = model_type
return entry
def _ensure_license_flags(self, entry: Dict[str, Any]) -> None:
    """Guarantee that ``entry['license_flags']`` holds an integer bitset.

    Non-dict entries are left untouched. When a ``license_flags`` value is
    already present it is coerced with ``int()``; if that coercion fails the
    flags that ``resolve_license_info`` yields for an empty payload are
    stored instead. When no value is present, the flags are computed from
    the four license permission fields, read first from the entry itself
    and then overridden by any of those fields found under
    ``entry['civitai']['model']``.
    """
    if not isinstance(entry, dict):
        return

    permission_keys = (
        'allowNoCredit',
        'allowCommercialUse',
        'allowDerivatives',
        'allowDifferentLicense',
    )

    stored_flags = entry.get('license_flags')
    if stored_flags is not None:
        # An existing value only needs normalising to int; never recompute it.
        try:
            normalized_flags = int(stored_flags)
        except (TypeError, ValueError):
            _, normalized_flags = resolve_license_info({})
        entry['license_flags'] = normalized_flags
        return

    # Start from the entry's own top-level fields (missing keys become None).
    permissions = {name: entry.get(name) for name in permission_keys}

    # Fields present on the nested Civitai model payload take precedence.
    civitai_payload = entry.get('civitai')
    model_section = (
        civitai_payload.get('model')
        if isinstance(civitai_payload, Mapping)
        else None
    )
    if isinstance(model_section, Mapping):
        for name in permission_keys:
            if name in model_section:
                permissions[name] = model_section.get(name)

    _, computed_flags = resolve_license_info(permissions)
    entry['license_flags'] = computed_flags
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
@@ -567,6 +643,7 @@ class ModelScanner:
async def _initialize_cache(self) -> None:
"""Initialize or refresh the cache"""
print("init start", flush=True)
self._is_initializing = True # Set flag
try:
start_time = time.time()
@@ -575,6 +652,7 @@ class ModelScanner:
scan_result = await self._gather_model_data()
await self._apply_scan_result(scan_result)
await self._save_persistent_cache(scan_result)
print("init end", flush=True)
logger.info(
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
@@ -681,6 +759,7 @@ class ModelScanner:
model_data = self.adjust_cached_entry(dict(model_data))
if not model_data:
continue
self._ensure_license_flags(model_data)
# Add to cache
self._cache.raw_data.append(model_data)
self._cache.add_to_version_index(model_data)
@@ -975,6 +1054,7 @@ class ModelScanner:
processed_files += 1
if result:
self._ensure_license_flags(result)
raw_data.append(result)
sha_value = result.get('sha256')

View File

@@ -21,6 +21,9 @@ class PersistedCacheData:
excluded_models: List[str]
# CivitAI's documented default license permissions, written as a binary
# literal so the individual flag bits stay visible (0b111001 == 57).
DEFAULT_LICENSE_FLAGS = 0b111001
class PersistentModelCache:
"""Persist core model metadata and hash index data in SQLite."""
@@ -47,6 +50,7 @@ class PersistentModelCache:
"civitai_name",
"civitai_creator_username",
"trained_words",
"license_flags",
"civitai_deleted",
"exclude",
"db_checked",
@@ -149,6 +153,10 @@ class PersistentModelCache:
if creator_username:
civitai.setdefault("creator", {})["username"] = creator_username
license_value = row["license_flags"]
if license_value is None:
license_value = DEFAULT_LICENSE_FLAGS
item = {
"file_path": file_path,
"file_name": row["file_name"],
@@ -171,6 +179,7 @@ class PersistentModelCache:
"tags": tags.get(file_path, []),
"civitai": civitai,
"civitai_deleted": bool(row["civitai_deleted"]),
"license_flags": int(license_value),
}
raw_data.append(item)
@@ -484,6 +493,8 @@ class PersistentModelCache:
"metadata_source": "TEXT",
"civitai_creator_username": "TEXT",
"civitai_deleted": "INTEGER DEFAULT 0",
# Rows written without an explicit license_flags value fall back to CivitAI's documented defaults (0b111001 == 57).
"license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}",
}
for column, definition in required_columns.items():
@@ -518,6 +529,10 @@ class PersistentModelCache:
if isinstance(creator_data, dict):
creator_username = creator_data.get("username") or None
license_flags = item.get("license_flags")
if license_flags is None:
license_flags = DEFAULT_LICENSE_FLAGS
return (
model_type,
item.get("file_path"),
@@ -540,6 +555,7 @@ class PersistentModelCache:
civitai.get("name"),
creator_username,
trained_words_json,
int(license_flags),
1 if item.get("civitai_deleted") else 0,
1 if item.get("exclude") else 0,
1 if item.get("db_checked") else 0,