mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: add license information handling for Civitai models
Add license resolution utilities and integrate license information into model metadata processing. The changes include: - Add `resolve_license_payload` function to extract license data from Civitai model responses - Integrate license information into model metadata in CivitaiClient and MetadataSyncService - Add license flags support in model scanning and caching - Implement CommercialUseLevel enum for standardized license classification - Update model scanner to handle unknown fields when extracting metadata values This ensures proper license attribution and compliance when working with Civitai models.
This commit is contained in:
@@ -6,6 +6,7 @@ from typing import Any, Optional, Dict, Tuple, List, Sequence
|
|||||||
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
||||||
from .downloader import get_downloader
|
from .downloader import get_downloader
|
||||||
from .errors import RateLimitError, ResourceNotFoundError
|
from .errors import RateLimitError, ResourceNotFoundError
|
||||||
|
from ..utils.civitai_utils import resolve_license_payload
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -420,6 +421,10 @@ class CivitaiClient:
|
|||||||
model_info['tags'] = model_data.get("tags", [])
|
model_info['tags'] = model_data.get("tags", [])
|
||||||
version['creator'] = model_data.get("creator")
|
version['creator'] = model_data.get("creator")
|
||||||
|
|
||||||
|
license_payload = resolve_license_payload(model_data)
|
||||||
|
for field, value in license_payload.items():
|
||||||
|
model_info[field] = value
|
||||||
|
|
||||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
"""Fetch model version metadata from Civitai
|
"""Fetch model version metadata from Civitai
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from datetime import datetime
|
|||||||
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
|
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
|
||||||
|
|
||||||
from ..services.settings_manager import SettingsManager
|
from ..services.settings_manager import SettingsManager
|
||||||
|
from ..utils.civitai_utils import resolve_license_payload
|
||||||
from ..utils.model_utils import determine_base_model
|
from ..utils.model_utils import determine_base_model
|
||||||
from .errors import RateLimitError
|
from .errors import RateLimitError
|
||||||
|
|
||||||
@@ -135,6 +136,17 @@ class MetadataSyncService:
|
|||||||
):
|
):
|
||||||
local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"]
|
local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"]
|
||||||
|
|
||||||
|
merged_civitai = local_metadata.get("civitai") or {}
|
||||||
|
civitai_model = merged_civitai.get("model")
|
||||||
|
if not isinstance(civitai_model, dict):
|
||||||
|
civitai_model = {}
|
||||||
|
|
||||||
|
license_payload = resolve_license_payload(model_data)
|
||||||
|
civitai_model.update(license_payload)
|
||||||
|
|
||||||
|
merged_civitai["model"] = civitai_model
|
||||||
|
local_metadata["civitai"] = merged_civitai
|
||||||
|
|
||||||
local_metadata["base_model"] = determine_base_model(
|
local_metadata["base_model"] = determine_base_model(
|
||||||
civitai_metadata.get("baseModel")
|
civitai_metadata.get("baseModel")
|
||||||
)
|
)
|
||||||
@@ -295,6 +307,7 @@ class MetadataSyncService:
|
|||||||
"preview_url": local_metadata.get("preview_url"),
|
"preview_url": local_metadata.get("preview_url"),
|
||||||
"civitai": local_metadata.get("civitai"),
|
"civitai": local_metadata.get("civitai"),
|
||||||
}
|
}
|
||||||
|
|
||||||
model_data.update(update_payload)
|
model_data.update(update_payload)
|
||||||
|
|
||||||
await update_cache_func(file_path, file_path, local_metadata)
|
await update_cache_func(file_path, file_path, local_metadata)
|
||||||
@@ -436,4 +449,3 @@ class MetadataSyncService:
|
|||||||
results["verified_as_duplicates"] = False
|
results["verified_as_duplicates"] = False
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from ..utils.models import BaseModelMetadata
|
|||||||
from ..config import config
|
from ..config import config
|
||||||
from ..utils.file_utils import find_preview_file, get_preview_extension
|
from ..utils.file_utils import find_preview_file, get_preview_extension
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
|
from ..utils.civitai_utils import resolve_license_info
|
||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .model_hash_index import ModelHashIndex
|
from .model_hash_index import ModelHashIndex
|
||||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||||
@@ -175,7 +176,17 @@ class ModelScanner:
|
|||||||
def get_value(key: str, default: Any = None) -> Any:
|
def get_value(key: str, default: Any = None) -> Any:
|
||||||
if is_mapping:
|
if is_mapping:
|
||||||
return source.get(key, default)
|
return source.get(key, default)
|
||||||
return getattr(source, key, default)
|
|
||||||
|
sentinel = object()
|
||||||
|
value = getattr(source, key, sentinel)
|
||||||
|
if value is not sentinel:
|
||||||
|
return value
|
||||||
|
|
||||||
|
unknown = getattr(source, "_unknown_fields", None)
|
||||||
|
if isinstance(unknown, dict) and key in unknown:
|
||||||
|
return unknown[key]
|
||||||
|
|
||||||
|
return default
|
||||||
|
|
||||||
file_path = file_path_override or get_value('file_path', '') or ''
|
file_path = file_path_override or get_value('file_path', '') or ''
|
||||||
normalized_path = file_path.replace('\\', '/')
|
normalized_path = file_path.replace('\\', '/')
|
||||||
@@ -197,7 +208,8 @@ class ModelScanner:
|
|||||||
else:
|
else:
|
||||||
preview_url = ''
|
preview_url = ''
|
||||||
|
|
||||||
civitai_slim = self._slim_civitai_payload(get_value('civitai'))
|
civitai_full = get_value('civitai')
|
||||||
|
civitai_slim = self._slim_civitai_payload(civitai_full)
|
||||||
usage_tips = get_value('usage_tips', '') or ''
|
usage_tips = get_value('usage_tips', '') or ''
|
||||||
if not isinstance(usage_tips, str):
|
if not isinstance(usage_tips, str):
|
||||||
usage_tips = str(usage_tips)
|
usage_tips = str(usage_tips)
|
||||||
@@ -229,12 +241,76 @@ class ModelScanner:
|
|||||||
'civitai_deleted': bool(get_value('civitai_deleted', False)),
|
'civitai_deleted': bool(get_value('civitai_deleted', False)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
license_source: Dict[str, Any] = {}
|
||||||
|
if isinstance(civitai_full, Mapping):
|
||||||
|
civitai_model = civitai_full.get('model')
|
||||||
|
if isinstance(civitai_model, Mapping):
|
||||||
|
for key in (
|
||||||
|
'allowNoCredit',
|
||||||
|
'allowCommercialUse',
|
||||||
|
'allowDerivatives',
|
||||||
|
'allowDifferentLicense',
|
||||||
|
):
|
||||||
|
if key in civitai_model:
|
||||||
|
license_source[key] = civitai_model.get(key)
|
||||||
|
|
||||||
|
for key in (
|
||||||
|
'allowNoCredit',
|
||||||
|
'allowCommercialUse',
|
||||||
|
'allowDerivatives',
|
||||||
|
'allowDifferentLicense',
|
||||||
|
):
|
||||||
|
if key not in license_source:
|
||||||
|
value = get_value(key)
|
||||||
|
if value is not None:
|
||||||
|
license_source[key] = value
|
||||||
|
|
||||||
|
_, license_flags = resolve_license_info(license_source or {})
|
||||||
|
entry['license_flags'] = license_flags
|
||||||
|
|
||||||
model_type = get_value('model_type', None)
|
model_type = get_value('model_type', None)
|
||||||
if model_type:
|
if model_type:
|
||||||
entry['model_type'] = model_type
|
entry['model_type'] = model_type
|
||||||
|
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
|
def _ensure_license_flags(self, entry: Dict[str, Any]) -> None:
|
||||||
|
"""Ensure cached entries include an integer license flag bitset."""
|
||||||
|
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
license_value = entry.get('license_flags')
|
||||||
|
if license_value is not None:
|
||||||
|
try:
|
||||||
|
entry['license_flags'] = int(license_value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
_, fallback_flags = resolve_license_info({})
|
||||||
|
entry['license_flags'] = fallback_flags
|
||||||
|
return
|
||||||
|
|
||||||
|
license_source = {
|
||||||
|
'allowNoCredit': entry.get('allowNoCredit'),
|
||||||
|
'allowCommercialUse': entry.get('allowCommercialUse'),
|
||||||
|
'allowDerivatives': entry.get('allowDerivatives'),
|
||||||
|
'allowDifferentLicense': entry.get('allowDifferentLicense'),
|
||||||
|
}
|
||||||
|
civitai_full = entry.get('civitai')
|
||||||
|
if isinstance(civitai_full, Mapping):
|
||||||
|
civitai_model = civitai_full.get('model')
|
||||||
|
if isinstance(civitai_model, Mapping):
|
||||||
|
for key in (
|
||||||
|
'allowNoCredit',
|
||||||
|
'allowCommercialUse',
|
||||||
|
'allowDerivatives',
|
||||||
|
'allowDifferentLicense',
|
||||||
|
):
|
||||||
|
if key in civitai_model:
|
||||||
|
license_source[key] = civitai_model.get(key)
|
||||||
|
|
||||||
|
_, license_flags = resolve_license_info(license_source)
|
||||||
|
entry['license_flags'] = license_flags
|
||||||
|
|
||||||
async def initialize_in_background(self) -> None:
|
async def initialize_in_background(self) -> None:
|
||||||
"""Initialize cache in background using thread pool"""
|
"""Initialize cache in background using thread pool"""
|
||||||
try:
|
try:
|
||||||
@@ -567,6 +643,7 @@ class ModelScanner:
|
|||||||
|
|
||||||
async def _initialize_cache(self) -> None:
|
async def _initialize_cache(self) -> None:
|
||||||
"""Initialize or refresh the cache"""
|
"""Initialize or refresh the cache"""
|
||||||
|
print("init start", flush=True)
|
||||||
self._is_initializing = True # Set flag
|
self._is_initializing = True # Set flag
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -575,6 +652,7 @@ class ModelScanner:
|
|||||||
scan_result = await self._gather_model_data()
|
scan_result = await self._gather_model_data()
|
||||||
await self._apply_scan_result(scan_result)
|
await self._apply_scan_result(scan_result)
|
||||||
await self._save_persistent_cache(scan_result)
|
await self._save_persistent_cache(scan_result)
|
||||||
|
print("init end", flush=True)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
|
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
|
||||||
@@ -681,6 +759,7 @@ class ModelScanner:
|
|||||||
model_data = self.adjust_cached_entry(dict(model_data))
|
model_data = self.adjust_cached_entry(dict(model_data))
|
||||||
if not model_data:
|
if not model_data:
|
||||||
continue
|
continue
|
||||||
|
self._ensure_license_flags(model_data)
|
||||||
# Add to cache
|
# Add to cache
|
||||||
self._cache.raw_data.append(model_data)
|
self._cache.raw_data.append(model_data)
|
||||||
self._cache.add_to_version_index(model_data)
|
self._cache.add_to_version_index(model_data)
|
||||||
@@ -975,6 +1054,7 @@ class ModelScanner:
|
|||||||
processed_files += 1
|
processed_files += 1
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
|
self._ensure_license_flags(result)
|
||||||
raw_data.append(result)
|
raw_data.append(result)
|
||||||
|
|
||||||
sha_value = result.get('sha256')
|
sha_value = result.get('sha256')
|
||||||
|
|||||||
@@ -21,6 +21,9 @@ class PersistedCacheData:
|
|||||||
excluded_models: List[str]
|
excluded_models: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_LICENSE_FLAGS = 57 # 57 (0b111001) encodes CivitAI's documented default license permissions.
|
||||||
|
|
||||||
|
|
||||||
class PersistentModelCache:
|
class PersistentModelCache:
|
||||||
"""Persist core model metadata and hash index data in SQLite."""
|
"""Persist core model metadata and hash index data in SQLite."""
|
||||||
|
|
||||||
@@ -47,6 +50,7 @@ class PersistentModelCache:
|
|||||||
"civitai_name",
|
"civitai_name",
|
||||||
"civitai_creator_username",
|
"civitai_creator_username",
|
||||||
"trained_words",
|
"trained_words",
|
||||||
|
"license_flags",
|
||||||
"civitai_deleted",
|
"civitai_deleted",
|
||||||
"exclude",
|
"exclude",
|
||||||
"db_checked",
|
"db_checked",
|
||||||
@@ -149,6 +153,10 @@ class PersistentModelCache:
|
|||||||
if creator_username:
|
if creator_username:
|
||||||
civitai.setdefault("creator", {})["username"] = creator_username
|
civitai.setdefault("creator", {})["username"] = creator_username
|
||||||
|
|
||||||
|
license_value = row["license_flags"]
|
||||||
|
if license_value is None:
|
||||||
|
license_value = DEFAULT_LICENSE_FLAGS
|
||||||
|
|
||||||
item = {
|
item = {
|
||||||
"file_path": file_path,
|
"file_path": file_path,
|
||||||
"file_name": row["file_name"],
|
"file_name": row["file_name"],
|
||||||
@@ -171,6 +179,7 @@ class PersistentModelCache:
|
|||||||
"tags": tags.get(file_path, []),
|
"tags": tags.get(file_path, []),
|
||||||
"civitai": civitai,
|
"civitai": civitai,
|
||||||
"civitai_deleted": bool(row["civitai_deleted"]),
|
"civitai_deleted": bool(row["civitai_deleted"]),
|
||||||
|
"license_flags": int(license_value),
|
||||||
}
|
}
|
||||||
raw_data.append(item)
|
raw_data.append(item)
|
||||||
|
|
||||||
@@ -484,6 +493,8 @@ class PersistentModelCache:
|
|||||||
"metadata_source": "TEXT",
|
"metadata_source": "TEXT",
|
||||||
"civitai_creator_username": "TEXT",
|
"civitai_creator_username": "TEXT",
|
||||||
"civitai_deleted": "INTEGER DEFAULT 0",
|
"civitai_deleted": "INTEGER DEFAULT 0",
|
||||||
|
# Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57).
|
||||||
|
"license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}",
|
||||||
}
|
}
|
||||||
|
|
||||||
for column, definition in required_columns.items():
|
for column, definition in required_columns.items():
|
||||||
@@ -518,6 +529,10 @@ class PersistentModelCache:
|
|||||||
if isinstance(creator_data, dict):
|
if isinstance(creator_data, dict):
|
||||||
creator_username = creator_data.get("username") or None
|
creator_username = creator_data.get("username") or None
|
||||||
|
|
||||||
|
license_flags = item.get("license_flags")
|
||||||
|
if license_flags is None:
|
||||||
|
license_flags = DEFAULT_LICENSE_FLAGS
|
||||||
|
|
||||||
return (
|
return (
|
||||||
model_type,
|
model_type,
|
||||||
item.get("file_path"),
|
item.get("file_path"),
|
||||||
@@ -540,6 +555,7 @@ class PersistentModelCache:
|
|||||||
civitai.get("name"),
|
civitai.get("name"),
|
||||||
creator_username,
|
creator_username,
|
||||||
trained_words_json,
|
trained_words_json,
|
||||||
|
int(license_flags),
|
||||||
1 if item.get("civitai_deleted") else 0,
|
1 if item.get("civitai_deleted") else 0,
|
||||||
1 if item.get("exclude") else 0,
|
1 if item.get("exclude") else 0,
|
||||||
1 if item.get("db_checked") else 0,
|
1 if item.get("db_checked") else 0,
|
||||||
|
|||||||
@@ -2,9 +2,141 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import IntEnum
|
||||||
|
from typing import Any, Dict, Iterable, Mapping, Sequence
|
||||||
from urllib.parse import urlparse, urlunparse
|
from urllib.parse import urlparse, urlunparse
|
||||||
|
|
||||||
|
|
||||||
|
class CommercialUseLevel(IntEnum):
|
||||||
|
"""Enumerate supported commercial use permission levels."""
|
||||||
|
|
||||||
|
NONE = 0
|
||||||
|
IMAGE = 1
|
||||||
|
RENT_CIVIT = 2
|
||||||
|
RENT = 3
|
||||||
|
SELL = 4
|
||||||
|
|
||||||
|
|
||||||
|
_DEFAULT_ALLOW_COMMERCIAL_USE: Sequence[str] = ("Sell",)
|
||||||
|
_LICENSE_DEFAULTS: Dict[str, Any] = {
|
||||||
|
"allowNoCredit": True,
|
||||||
|
"allowCommercialUse": _DEFAULT_ALLOW_COMMERCIAL_USE,
|
||||||
|
"allowDerivatives": True,
|
||||||
|
"allowDifferentLicense": True,
|
||||||
|
}
|
||||||
|
_COMMERCIAL_VALUE_TO_LEVEL = {
|
||||||
|
"none": CommercialUseLevel.NONE,
|
||||||
|
"image": CommercialUseLevel.IMAGE,
|
||||||
|
"rentcivit": CommercialUseLevel.RENT_CIVIT,
|
||||||
|
"rent": CommercialUseLevel.RENT,
|
||||||
|
"sell": CommercialUseLevel.SELL,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_commercial_values(value: Any) -> Sequence[str]:
|
||||||
|
"""Return a normalized list of commercial permissions preserving source values."""
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
return list(_DEFAULT_ALLOW_COMMERCIAL_USE)
|
||||||
|
|
||||||
|
if isinstance(value, str):
|
||||||
|
return [value]
|
||||||
|
|
||||||
|
if isinstance(value, Iterable):
|
||||||
|
result = []
|
||||||
|
for item in value:
|
||||||
|
if item is None:
|
||||||
|
continue
|
||||||
|
if isinstance(item, str):
|
||||||
|
result.append(item)
|
||||||
|
continue
|
||||||
|
result.append(str(item))
|
||||||
|
if result:
|
||||||
|
return result
|
||||||
|
|
||||||
|
return list(_DEFAULT_ALLOW_COMMERCIAL_USE)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_bool(value: Any, fallback: bool) -> bool:
|
||||||
|
if value is None:
|
||||||
|
return fallback
|
||||||
|
return bool(value)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_license_payload(model_data: Mapping[str, Any] | None) -> Dict[str, Any]:
|
||||||
|
"""Extract license fields from model metadata applying documented defaults."""
|
||||||
|
|
||||||
|
payload: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
allow_no_credit = payload["allowNoCredit"] = _to_bool(
|
||||||
|
(model_data or {}).get("allowNoCredit"),
|
||||||
|
_LICENSE_DEFAULTS["allowNoCredit"],
|
||||||
|
)
|
||||||
|
|
||||||
|
commercial = _normalize_commercial_values(
|
||||||
|
(model_data or {}).get("allowCommercialUse"),
|
||||||
|
)
|
||||||
|
payload["allowCommercialUse"] = list(commercial)
|
||||||
|
|
||||||
|
allow_derivatives = payload["allowDerivatives"] = _to_bool(
|
||||||
|
(model_data or {}).get("allowDerivatives"),
|
||||||
|
_LICENSE_DEFAULTS["allowDerivatives"],
|
||||||
|
)
|
||||||
|
|
||||||
|
allow_different_license = payload["allowDifferentLicense"] = _to_bool(
|
||||||
|
(model_data or {}).get("allowDifferentLicense"),
|
||||||
|
_LICENSE_DEFAULTS["allowDifferentLicense"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure booleans are plain bool instances
|
||||||
|
payload["allowNoCredit"] = bool(allow_no_credit)
|
||||||
|
payload["allowDerivatives"] = bool(allow_derivatives)
|
||||||
|
payload["allowDifferentLicense"] = bool(allow_different_license)
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_commercial_level(values: Sequence[str]) -> CommercialUseLevel:
|
||||||
|
level = CommercialUseLevel.NONE
|
||||||
|
for value in values:
|
||||||
|
normalized = str(value).strip().lower().replace("_", "")
|
||||||
|
normalized = normalized.replace("-", "")
|
||||||
|
candidate = _COMMERCIAL_VALUE_TO_LEVEL.get(normalized)
|
||||||
|
if candidate is None:
|
||||||
|
continue
|
||||||
|
if candidate > level:
|
||||||
|
level = candidate
|
||||||
|
return level
|
||||||
|
|
||||||
|
|
||||||
|
def build_license_flags(payload: Mapping[str, Any] | None) -> int:
|
||||||
|
"""Encode license payload into a compact bitset for cache storage."""
|
||||||
|
|
||||||
|
resolved = resolve_license_payload(payload or {})
|
||||||
|
|
||||||
|
flags = 0
|
||||||
|
if resolved.get("allowNoCredit", True):
|
||||||
|
flags |= 1 << 0
|
||||||
|
|
||||||
|
commercial_level = _resolve_commercial_level(resolved.get("allowCommercialUse", ()))
|
||||||
|
flags |= (int(commercial_level) & 0b111) << 1
|
||||||
|
|
||||||
|
if resolved.get("allowDerivatives", True):
|
||||||
|
flags |= 1 << 4
|
||||||
|
|
||||||
|
if resolved.get("allowDifferentLicense", True):
|
||||||
|
flags |= 1 << 5
|
||||||
|
|
||||||
|
return flags
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_license_info(model_data: Mapping[str, Any] | None) -> tuple[Dict[str, Any], int]:
|
||||||
|
"""Return normalized license payload and its encoded bitset."""
|
||||||
|
|
||||||
|
payload = resolve_license_payload(model_data)
|
||||||
|
return payload, build_license_flags(payload)
|
||||||
|
|
||||||
|
|
||||||
def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -> tuple[str | None, bool]:
|
def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -> tuple[str | None, bool]:
|
||||||
"""Rewrite Civitai preview URLs to use optimized renditions.
|
"""Rewrite Civitai preview URLs to use optimized renditions.
|
||||||
|
|
||||||
@@ -43,5 +175,10 @@ def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -
|
|||||||
return rewritten, True
|
return rewritten, True
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["rewrite_preview_url"]
|
__all__ = [
|
||||||
|
"CommercialUseLevel",
|
||||||
|
"build_license_flags",
|
||||||
|
"resolve_license_payload",
|
||||||
|
"resolve_license_info",
|
||||||
|
"rewrite_preview_url",
|
||||||
|
]
|
||||||
|
|||||||
@@ -11,7 +11,11 @@
|
|||||||
"type": "LORA",
|
"type": "LORA",
|
||||||
"nsfw": false,
|
"nsfw": false,
|
||||||
"description": "description",
|
"description": "description",
|
||||||
"tags": ["style"]
|
"tags": ["style"],
|
||||||
|
"allowNoCredit": true,
|
||||||
|
"allowCommercialUse": ["Sell"],
|
||||||
|
"allowDerivatives": true,
|
||||||
|
"allowDifferentLicense": true
|
||||||
},
|
},
|
||||||
"files": [
|
"files": [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -75,6 +75,10 @@ async def test_update_model_metadata_merges_and_persists():
|
|||||||
"description": "desc",
|
"description": "desc",
|
||||||
"tags": ["style"],
|
"tags": ["style"],
|
||||||
"creator": {"id": 2},
|
"creator": {"id": 2},
|
||||||
|
"allowNoCredit": False,
|
||||||
|
"allowCommercialUse": ["Image"],
|
||||||
|
"allowDerivatives": False,
|
||||||
|
"allowDifferentLicense": True,
|
||||||
},
|
},
|
||||||
"baseModel": "sdxl",
|
"baseModel": "sdxl",
|
||||||
"images": ["img"],
|
"images": ["img"],
|
||||||
@@ -92,6 +96,13 @@ async def test_update_model_metadata_merges_and_persists():
|
|||||||
assert result["modelDescription"] == "desc"
|
assert result["modelDescription"] == "desc"
|
||||||
assert result["tags"] == ["style"]
|
assert result["tags"] == ["style"]
|
||||||
assert result["base_model"] == "SDXL 1.0"
|
assert result["base_model"] == "SDXL 1.0"
|
||||||
|
civitai_model = result["civitai"]["model"]
|
||||||
|
assert civitai_model["allowNoCredit"] is False
|
||||||
|
assert civitai_model["allowCommercialUse"] == ["Image"]
|
||||||
|
assert civitai_model["allowDerivatives"] is False
|
||||||
|
assert civitai_model["allowDifferentLicense"] is True
|
||||||
|
for key in ("allowNoCredit", "allowCommercialUse", "allowDerivatives", "allowDifferentLicense"):
|
||||||
|
assert key not in result
|
||||||
|
|
||||||
helpers.preview_service.ensure_preview_for_metadata.assert_awaited_once()
|
helpers.preview_service.ensure_preview_for_metadata.assert_awaited_once()
|
||||||
helpers.metadata_manager.save_metadata.assert_awaited_once_with(
|
helpers.metadata_manager.save_metadata.assert_awaited_once_with(
|
||||||
@@ -142,6 +153,13 @@ async def test_fetch_and_update_model_success_updates_cache(tmp_path):
|
|||||||
assert model_data["civitai_deleted"] is False
|
assert model_data["civitai_deleted"] is False
|
||||||
assert "civitai" in model_data
|
assert "civitai" in model_data
|
||||||
assert model_data["metadata_source"] == "civitai_api"
|
assert model_data["metadata_source"] == "civitai_api"
|
||||||
|
civitai_model = model_data["civitai"]["model"]
|
||||||
|
assert civitai_model["allowNoCredit"] is True
|
||||||
|
assert civitai_model["allowDerivatives"] is True
|
||||||
|
assert civitai_model["allowDifferentLicense"] is True
|
||||||
|
assert civitai_model["allowCommercialUse"] == ["Sell"]
|
||||||
|
for key in ("allowNoCredit", "allowCommercialUse", "allowDerivatives", "allowDifferentLicense"):
|
||||||
|
assert key not in model_data
|
||||||
|
|
||||||
helpers.metadata_manager.hydrate_model_data.assert_not_awaited()
|
helpers.metadata_manager.hydrate_model_data.assert_not_awaited()
|
||||||
assert model_data["hydrated"] is True
|
assert model_data["hydrated"] is True
|
||||||
@@ -151,7 +169,13 @@ async def test_fetch_and_update_model_success_updates_cache(tmp_path):
|
|||||||
assert await_args, "expected metadata to be persisted"
|
assert await_args, "expected metadata to be persisted"
|
||||||
last_call = await_args[-1]
|
last_call = await_args[-1]
|
||||||
assert last_call.args[0] == metadata_path
|
assert last_call.args[0] == metadata_path
|
||||||
assert last_call.args[1]["hydrated"] is True
|
persisted_payload = last_call.args[1]
|
||||||
|
assert persisted_payload["hydrated"] is True
|
||||||
|
civitai_model = persisted_payload["civitai"]["model"]
|
||||||
|
assert civitai_model["allowNoCredit"] is True
|
||||||
|
assert civitai_model["allowCommercialUse"] == ["Sell"]
|
||||||
|
for key in ("allowNoCredit", "allowCommercialUse", "allowDerivatives", "allowDifferentLicense"):
|
||||||
|
assert key not in persisted_payload
|
||||||
update_cache.assert_awaited_once()
|
update_cache.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
@@ -422,4 +446,3 @@ async def test_relink_metadata_raises_when_version_missing():
|
|||||||
model_id=9,
|
model_id=9,
|
||||||
model_version_id=None,
|
model_version_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ from py.services import model_scanner
|
|||||||
from py.services.model_cache import ModelCache
|
from py.services.model_cache import ModelCache
|
||||||
from py.services.model_hash_index import ModelHashIndex
|
from py.services.model_hash_index import ModelHashIndex
|
||||||
from py.services.model_scanner import CacheBuildResult, ModelScanner
|
from py.services.model_scanner import CacheBuildResult, ModelScanner
|
||||||
from py.services.persistent_model_cache import PersistentModelCache
|
from py.services.persistent_model_cache import PersistentModelCache, DEFAULT_LICENSE_FLAGS
|
||||||
|
from py.utils.civitai_utils import build_license_flags
|
||||||
from py.utils.models import BaseModelMetadata
|
from py.utils.models import BaseModelMetadata
|
||||||
|
|
||||||
|
|
||||||
@@ -122,6 +123,7 @@ async def test_initialize_cache_populates_cache(tmp_path: Path):
|
|||||||
_normalize_path(tmp_path / "one.txt"),
|
_normalize_path(tmp_path / "one.txt"),
|
||||||
_normalize_path(tmp_path / "nested" / "two.txt"),
|
_normalize_path(tmp_path / "nested" / "two.txt"),
|
||||||
}
|
}
|
||||||
|
assert {item["license_flags"] for item in cache.raw_data} == {DEFAULT_LICENSE_FLAGS}
|
||||||
|
|
||||||
assert scanner._hash_index.get_path("hash-one") == _normalize_path(tmp_path / "one.txt")
|
assert scanner._hash_index.get_path("hash-one") == _normalize_path(tmp_path / "one.txt")
|
||||||
assert scanner._hash_index.get_path("hash-two") == _normalize_path(tmp_path / "nested" / "two.txt")
|
assert scanner._hash_index.get_path("hash-two") == _normalize_path(tmp_path / "nested" / "two.txt")
|
||||||
@@ -179,12 +181,49 @@ async def test_initialize_in_background_applies_scan_result(tmp_path: Path, monk
|
|||||||
_normalize_path(tmp_path / "one.txt"),
|
_normalize_path(tmp_path / "one.txt"),
|
||||||
_normalize_path(tmp_path / "nested" / "two.txt"),
|
_normalize_path(tmp_path / "nested" / "two.txt"),
|
||||||
}
|
}
|
||||||
|
assert {item["license_flags"] for item in cache.raw_data} == {DEFAULT_LICENSE_FLAGS}
|
||||||
assert scanner._hash_index.get_path("hash-two") == _normalize_path(tmp_path / "nested" / "two.txt")
|
assert scanner._hash_index.get_path("hash-two") == _normalize_path(tmp_path / "nested" / "two.txt")
|
||||||
assert scanner._tags_count == {"alpha": 1, "beta": 1}
|
assert scanner._tags_count == {"alpha": 1, "beta": 1}
|
||||||
assert scanner._excluded_models == [_normalize_path(tmp_path / "skip-file.txt")]
|
assert scanner._excluded_models == [_normalize_path(tmp_path / "skip-file.txt")]
|
||||||
assert ws_stub.payloads[-1]["progress"] == 100
|
assert ws_stub.payloads[-1]["progress"] == 100
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_build_cache_entry_encodes_license_flags(tmp_path: Path):
|
||||||
|
scanner = DummyScanner(tmp_path)
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"file_path": _normalize_path(tmp_path / "sample.txt"),
|
||||||
|
"file_name": "sample",
|
||||||
|
"model_name": "Sample",
|
||||||
|
"folder": "",
|
||||||
|
"size": 1,
|
||||||
|
"modified": 1.0,
|
||||||
|
"sha256": "hash",
|
||||||
|
"tags": [],
|
||||||
|
"civitai": {
|
||||||
|
"model": {
|
||||||
|
"allowNoCredit": False,
|
||||||
|
"allowCommercialUse": ["Image", "Rent"],
|
||||||
|
"allowDerivatives": True,
|
||||||
|
"allowDifferentLicense": False,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
expected_flags = build_license_flags(
|
||||||
|
{
|
||||||
|
"allowNoCredit": False,
|
||||||
|
"allowCommercialUse": ["Image", "Rent"],
|
||||||
|
"allowDerivatives": True,
|
||||||
|
"allowDifferentLicense": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
entry = scanner._build_cache_entry(metadata)
|
||||||
|
assert entry["license_flags"] == expected_flags
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_initialize_in_background_uses_persisted_cache_without_full_scan(tmp_path: Path, monkeypatch):
|
async def test_initialize_in_background_uses_persisted_cache_without_full_scan(tmp_path: Path, monkeypatch):
|
||||||
monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0')
|
monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0')
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from py.services.persistent_model_cache import PersistentModelCache
|
from py.services.persistent_model_cache import PersistentModelCache, DEFAULT_LICENSE_FLAGS
|
||||||
|
|
||||||
|
|
||||||
def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None:
|
def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None:
|
||||||
@@ -43,6 +43,7 @@ def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None:
|
|||||||
'trainedWords': ['word1'],
|
'trainedWords': ['word1'],
|
||||||
'creator': {'username': 'artist42'},
|
'creator': {'username': 'artist42'},
|
||||||
},
|
},
|
||||||
|
'license_flags': 13,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'file_path': file_b,
|
'file_path': file_b,
|
||||||
@@ -91,12 +92,14 @@ def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None:
|
|||||||
assert first['metadata_source'] == 'civitai_api'
|
assert first['metadata_source'] == 'civitai_api'
|
||||||
assert first['civitai']['creator']['username'] == 'artist42'
|
assert first['civitai']['creator']['username'] == 'artist42'
|
||||||
assert first['civitai_deleted'] is False
|
assert first['civitai_deleted'] is False
|
||||||
|
assert first['license_flags'] == 13
|
||||||
|
|
||||||
second = items[file_b]
|
second = items[file_b]
|
||||||
assert second['exclude'] is True
|
assert second['exclude'] is True
|
||||||
assert second['civitai'] is None
|
assert second['civitai'] is None
|
||||||
assert second['metadata_source'] is None
|
assert second['metadata_source'] is None
|
||||||
assert second['civitai_deleted'] is True
|
assert second['civitai_deleted'] is True
|
||||||
|
assert second['license_flags'] == DEFAULT_LICENSE_FLAGS
|
||||||
|
|
||||||
expected_hash_pairs = {
|
expected_hash_pairs = {
|
||||||
('hash-a', file_a),
|
('hash-a', file_a),
|
||||||
|
|||||||
35
tests/utils/test_civitai_utils.py
Normal file
35
tests/utils/test_civitai_utils.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
from py.utils.civitai_utils import (
|
||||||
|
CommercialUseLevel,
|
||||||
|
build_license_flags,
|
||||||
|
resolve_license_info,
|
||||||
|
resolve_license_payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_license_payload_defaults():
|
||||||
|
payload, flags = resolve_license_info({})
|
||||||
|
|
||||||
|
assert payload["allowNoCredit"] is True
|
||||||
|
assert payload["allowDerivatives"] is True
|
||||||
|
assert payload["allowDifferentLicense"] is True
|
||||||
|
assert payload["allowCommercialUse"] == ["Sell"]
|
||||||
|
assert flags == 57
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_license_flags_custom_values():
|
||||||
|
source = {
|
||||||
|
"allowNoCredit": False,
|
||||||
|
"allowCommercialUse": {"Image", "Sell"},
|
||||||
|
"allowDerivatives": False,
|
||||||
|
"allowDifferentLicense": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = resolve_license_payload(source)
|
||||||
|
assert payload["allowNoCredit"] is False
|
||||||
|
assert set(payload["allowCommercialUse"]) == {"Image", "Sell"}
|
||||||
|
assert payload["allowDerivatives"] is False
|
||||||
|
assert payload["allowDifferentLicense"] is False
|
||||||
|
|
||||||
|
flags = build_license_flags(source)
|
||||||
|
# Highest commercial level is SELL -> level 4 -> shifted by 1 == 8.
|
||||||
|
assert flags == (CommercialUseLevel.SELL << 1)
|
||||||
Reference in New Issue
Block a user