mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 13:12:12 -03:00
feat: add license information handling for Civitai models
Add license resolution utilities and integrate license information into model metadata processing. The changes include: - Add `resolve_license_payload` function to extract license data from Civitai model responses - Integrate license information into model metadata in CivitaiClient and MetadataSyncService - Add license flags support in model scanning and caching - Implement CommercialUseLevel enum for standardized license classification - Update model scanner to handle unknown fields when extracting metadata values This ensures proper license attribution and compliance when working with Civitai models.
This commit is contained in:
@@ -6,6 +6,7 @@ from typing import Any, Optional, Dict, Tuple, List, Sequence
|
||||
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
||||
from .downloader import get_downloader
|
||||
from .errors import RateLimitError, ResourceNotFoundError
|
||||
from ..utils.civitai_utils import resolve_license_payload
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -420,6 +421,10 @@ class CivitaiClient:
|
||||
model_info['tags'] = model_data.get("tags", [])
|
||||
version['creator'] = model_data.get("creator")
|
||||
|
||||
license_payload = resolve_license_payload(model_data)
|
||||
for field, value in license_payload.items():
|
||||
model_info[field] = value
|
||||
|
||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||
"""Fetch model version metadata from Civitai
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from datetime import datetime
|
||||
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
|
||||
|
||||
from ..services.settings_manager import SettingsManager
|
||||
from ..utils.civitai_utils import resolve_license_payload
|
||||
from ..utils.model_utils import determine_base_model
|
||||
from .errors import RateLimitError
|
||||
|
||||
@@ -135,6 +136,17 @@ class MetadataSyncService:
|
||||
):
|
||||
local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"]
|
||||
|
||||
merged_civitai = local_metadata.get("civitai") or {}
|
||||
civitai_model = merged_civitai.get("model")
|
||||
if not isinstance(civitai_model, dict):
|
||||
civitai_model = {}
|
||||
|
||||
license_payload = resolve_license_payload(model_data)
|
||||
civitai_model.update(license_payload)
|
||||
|
||||
merged_civitai["model"] = civitai_model
|
||||
local_metadata["civitai"] = merged_civitai
|
||||
|
||||
local_metadata["base_model"] = determine_base_model(
|
||||
civitai_metadata.get("baseModel")
|
||||
)
|
||||
@@ -295,6 +307,7 @@ class MetadataSyncService:
|
||||
"preview_url": local_metadata.get("preview_url"),
|
||||
"civitai": local_metadata.get("civitai"),
|
||||
}
|
||||
|
||||
model_data.update(update_payload)
|
||||
|
||||
await update_cache_func(file_path, file_path, local_metadata)
|
||||
@@ -436,4 +449,3 @@ class MetadataSyncService:
|
||||
results["verified_as_duplicates"] = False
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from ..utils.models import BaseModelMetadata
|
||||
from ..config import config
|
||||
from ..utils.file_utils import find_preview_file, get_preview_extension
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from ..utils.civitai_utils import resolve_license_info
|
||||
from .model_cache import ModelCache
|
||||
from .model_hash_index import ModelHashIndex
|
||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||
@@ -175,7 +176,17 @@ class ModelScanner:
|
||||
def get_value(key: str, default: Any = None) -> Any:
|
||||
if is_mapping:
|
||||
return source.get(key, default)
|
||||
return getattr(source, key, default)
|
||||
|
||||
sentinel = object()
|
||||
value = getattr(source, key, sentinel)
|
||||
if value is not sentinel:
|
||||
return value
|
||||
|
||||
unknown = getattr(source, "_unknown_fields", None)
|
||||
if isinstance(unknown, dict) and key in unknown:
|
||||
return unknown[key]
|
||||
|
||||
return default
|
||||
|
||||
file_path = file_path_override or get_value('file_path', '') or ''
|
||||
normalized_path = file_path.replace('\\', '/')
|
||||
@@ -197,7 +208,8 @@ class ModelScanner:
|
||||
else:
|
||||
preview_url = ''
|
||||
|
||||
civitai_slim = self._slim_civitai_payload(get_value('civitai'))
|
||||
civitai_full = get_value('civitai')
|
||||
civitai_slim = self._slim_civitai_payload(civitai_full)
|
||||
usage_tips = get_value('usage_tips', '') or ''
|
||||
if not isinstance(usage_tips, str):
|
||||
usage_tips = str(usage_tips)
|
||||
@@ -229,12 +241,76 @@ class ModelScanner:
|
||||
'civitai_deleted': bool(get_value('civitai_deleted', False)),
|
||||
}
|
||||
|
||||
license_source: Dict[str, Any] = {}
|
||||
if isinstance(civitai_full, Mapping):
|
||||
civitai_model = civitai_full.get('model')
|
||||
if isinstance(civitai_model, Mapping):
|
||||
for key in (
|
||||
'allowNoCredit',
|
||||
'allowCommercialUse',
|
||||
'allowDerivatives',
|
||||
'allowDifferentLicense',
|
||||
):
|
||||
if key in civitai_model:
|
||||
license_source[key] = civitai_model.get(key)
|
||||
|
||||
for key in (
|
||||
'allowNoCredit',
|
||||
'allowCommercialUse',
|
||||
'allowDerivatives',
|
||||
'allowDifferentLicense',
|
||||
):
|
||||
if key not in license_source:
|
||||
value = get_value(key)
|
||||
if value is not None:
|
||||
license_source[key] = value
|
||||
|
||||
_, license_flags = resolve_license_info(license_source or {})
|
||||
entry['license_flags'] = license_flags
|
||||
|
||||
model_type = get_value('model_type', None)
|
||||
if model_type:
|
||||
entry['model_type'] = model_type
|
||||
|
||||
return entry
|
||||
|
||||
def _ensure_license_flags(self, entry: Dict[str, Any]) -> None:
|
||||
"""Ensure cached entries include an integer license flag bitset."""
|
||||
|
||||
if not isinstance(entry, dict):
|
||||
return
|
||||
|
||||
license_value = entry.get('license_flags')
|
||||
if license_value is not None:
|
||||
try:
|
||||
entry['license_flags'] = int(license_value)
|
||||
except (TypeError, ValueError):
|
||||
_, fallback_flags = resolve_license_info({})
|
||||
entry['license_flags'] = fallback_flags
|
||||
return
|
||||
|
||||
license_source = {
|
||||
'allowNoCredit': entry.get('allowNoCredit'),
|
||||
'allowCommercialUse': entry.get('allowCommercialUse'),
|
||||
'allowDerivatives': entry.get('allowDerivatives'),
|
||||
'allowDifferentLicense': entry.get('allowDifferentLicense'),
|
||||
}
|
||||
civitai_full = entry.get('civitai')
|
||||
if isinstance(civitai_full, Mapping):
|
||||
civitai_model = civitai_full.get('model')
|
||||
if isinstance(civitai_model, Mapping):
|
||||
for key in (
|
||||
'allowNoCredit',
|
||||
'allowCommercialUse',
|
||||
'allowDerivatives',
|
||||
'allowDifferentLicense',
|
||||
):
|
||||
if key in civitai_model:
|
||||
license_source[key] = civitai_model.get(key)
|
||||
|
||||
_, license_flags = resolve_license_info(license_source)
|
||||
entry['license_flags'] = license_flags
|
||||
|
||||
async def initialize_in_background(self) -> None:
|
||||
"""Initialize cache in background using thread pool"""
|
||||
try:
|
||||
@@ -567,6 +643,7 @@ class ModelScanner:
|
||||
|
||||
async def _initialize_cache(self) -> None:
|
||||
"""Initialize or refresh the cache"""
|
||||
print("init start", flush=True)
|
||||
self._is_initializing = True # Set flag
|
||||
try:
|
||||
start_time = time.time()
|
||||
@@ -575,6 +652,7 @@ class ModelScanner:
|
||||
scan_result = await self._gather_model_data()
|
||||
await self._apply_scan_result(scan_result)
|
||||
await self._save_persistent_cache(scan_result)
|
||||
print("init end", flush=True)
|
||||
|
||||
logger.info(
|
||||
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
|
||||
@@ -681,6 +759,7 @@ class ModelScanner:
|
||||
model_data = self.adjust_cached_entry(dict(model_data))
|
||||
if not model_data:
|
||||
continue
|
||||
self._ensure_license_flags(model_data)
|
||||
# Add to cache
|
||||
self._cache.raw_data.append(model_data)
|
||||
self._cache.add_to_version_index(model_data)
|
||||
@@ -975,6 +1054,7 @@ class ModelScanner:
|
||||
processed_files += 1
|
||||
|
||||
if result:
|
||||
self._ensure_license_flags(result)
|
||||
raw_data.append(result)
|
||||
|
||||
sha_value = result.get('sha256')
|
||||
|
||||
@@ -21,6 +21,9 @@ class PersistedCacheData:
|
||||
excluded_models: List[str]
|
||||
|
||||
|
||||
DEFAULT_LICENSE_FLAGS = 57 # 57 (0b111001) encodes CivitAI's documented default license permissions.
|
||||
|
||||
|
||||
class PersistentModelCache:
|
||||
"""Persist core model metadata and hash index data in SQLite."""
|
||||
|
||||
@@ -47,6 +50,7 @@ class PersistentModelCache:
|
||||
"civitai_name",
|
||||
"civitai_creator_username",
|
||||
"trained_words",
|
||||
"license_flags",
|
||||
"civitai_deleted",
|
||||
"exclude",
|
||||
"db_checked",
|
||||
@@ -149,6 +153,10 @@ class PersistentModelCache:
|
||||
if creator_username:
|
||||
civitai.setdefault("creator", {})["username"] = creator_username
|
||||
|
||||
license_value = row["license_flags"]
|
||||
if license_value is None:
|
||||
license_value = DEFAULT_LICENSE_FLAGS
|
||||
|
||||
item = {
|
||||
"file_path": file_path,
|
||||
"file_name": row["file_name"],
|
||||
@@ -171,6 +179,7 @@ class PersistentModelCache:
|
||||
"tags": tags.get(file_path, []),
|
||||
"civitai": civitai,
|
||||
"civitai_deleted": bool(row["civitai_deleted"]),
|
||||
"license_flags": int(license_value),
|
||||
}
|
||||
raw_data.append(item)
|
||||
|
||||
@@ -484,6 +493,8 @@ class PersistentModelCache:
|
||||
"metadata_source": "TEXT",
|
||||
"civitai_creator_username": "TEXT",
|
||||
"civitai_deleted": "INTEGER DEFAULT 0",
|
||||
# Persisting without explicit flags should assume CivitAI's documented defaults (0b111001 == 57).
|
||||
"license_flags": f"INTEGER DEFAULT {DEFAULT_LICENSE_FLAGS}",
|
||||
}
|
||||
|
||||
for column, definition in required_columns.items():
|
||||
@@ -518,6 +529,10 @@ class PersistentModelCache:
|
||||
if isinstance(creator_data, dict):
|
||||
creator_username = creator_data.get("username") or None
|
||||
|
||||
license_flags = item.get("license_flags")
|
||||
if license_flags is None:
|
||||
license_flags = DEFAULT_LICENSE_FLAGS
|
||||
|
||||
return (
|
||||
model_type,
|
||||
item.get("file_path"),
|
||||
@@ -540,6 +555,7 @@ class PersistentModelCache:
|
||||
civitai.get("name"),
|
||||
creator_username,
|
||||
trained_words_json,
|
||||
int(license_flags),
|
||||
1 if item.get("civitai_deleted") else 0,
|
||||
1 if item.get("exclude") else 0,
|
||||
1 if item.get("db_checked") else 0,
|
||||
|
||||
@@ -2,9 +2,141 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import IntEnum
|
||||
from typing import Any, Dict, Iterable, Mapping, Sequence
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
|
||||
class CommercialUseLevel(IntEnum):
|
||||
"""Enumerate supported commercial use permission levels."""
|
||||
|
||||
NONE = 0
|
||||
IMAGE = 1
|
||||
RENT_CIVIT = 2
|
||||
RENT = 3
|
||||
SELL = 4
|
||||
|
||||
|
||||
_DEFAULT_ALLOW_COMMERCIAL_USE: Sequence[str] = ("Sell",)
|
||||
_LICENSE_DEFAULTS: Dict[str, Any] = {
|
||||
"allowNoCredit": True,
|
||||
"allowCommercialUse": _DEFAULT_ALLOW_COMMERCIAL_USE,
|
||||
"allowDerivatives": True,
|
||||
"allowDifferentLicense": True,
|
||||
}
|
||||
_COMMERCIAL_VALUE_TO_LEVEL = {
|
||||
"none": CommercialUseLevel.NONE,
|
||||
"image": CommercialUseLevel.IMAGE,
|
||||
"rentcivit": CommercialUseLevel.RENT_CIVIT,
|
||||
"rent": CommercialUseLevel.RENT,
|
||||
"sell": CommercialUseLevel.SELL,
|
||||
}
|
||||
|
||||
|
||||
def _normalize_commercial_values(value: Any) -> Sequence[str]:
|
||||
"""Return a normalized list of commercial permissions preserving source values."""
|
||||
|
||||
if value is None:
|
||||
return list(_DEFAULT_ALLOW_COMMERCIAL_USE)
|
||||
|
||||
if isinstance(value, str):
|
||||
return [value]
|
||||
|
||||
if isinstance(value, Iterable):
|
||||
result = []
|
||||
for item in value:
|
||||
if item is None:
|
||||
continue
|
||||
if isinstance(item, str):
|
||||
result.append(item)
|
||||
continue
|
||||
result.append(str(item))
|
||||
if result:
|
||||
return result
|
||||
|
||||
return list(_DEFAULT_ALLOW_COMMERCIAL_USE)
|
||||
|
||||
|
||||
def _to_bool(value: Any, fallback: bool) -> bool:
|
||||
if value is None:
|
||||
return fallback
|
||||
return bool(value)
|
||||
|
||||
|
||||
def resolve_license_payload(model_data: Mapping[str, Any] | None) -> Dict[str, Any]:
|
||||
"""Extract license fields from model metadata applying documented defaults."""
|
||||
|
||||
payload: Dict[str, Any] = {}
|
||||
|
||||
allow_no_credit = payload["allowNoCredit"] = _to_bool(
|
||||
(model_data or {}).get("allowNoCredit"),
|
||||
_LICENSE_DEFAULTS["allowNoCredit"],
|
||||
)
|
||||
|
||||
commercial = _normalize_commercial_values(
|
||||
(model_data or {}).get("allowCommercialUse"),
|
||||
)
|
||||
payload["allowCommercialUse"] = list(commercial)
|
||||
|
||||
allow_derivatives = payload["allowDerivatives"] = _to_bool(
|
||||
(model_data or {}).get("allowDerivatives"),
|
||||
_LICENSE_DEFAULTS["allowDerivatives"],
|
||||
)
|
||||
|
||||
allow_different_license = payload["allowDifferentLicense"] = _to_bool(
|
||||
(model_data or {}).get("allowDifferentLicense"),
|
||||
_LICENSE_DEFAULTS["allowDifferentLicense"],
|
||||
)
|
||||
|
||||
# Ensure booleans are plain bool instances
|
||||
payload["allowNoCredit"] = bool(allow_no_credit)
|
||||
payload["allowDerivatives"] = bool(allow_derivatives)
|
||||
payload["allowDifferentLicense"] = bool(allow_different_license)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def _resolve_commercial_level(values: Sequence[str]) -> CommercialUseLevel:
|
||||
level = CommercialUseLevel.NONE
|
||||
for value in values:
|
||||
normalized = str(value).strip().lower().replace("_", "")
|
||||
normalized = normalized.replace("-", "")
|
||||
candidate = _COMMERCIAL_VALUE_TO_LEVEL.get(normalized)
|
||||
if candidate is None:
|
||||
continue
|
||||
if candidate > level:
|
||||
level = candidate
|
||||
return level
|
||||
|
||||
|
||||
def build_license_flags(payload: Mapping[str, Any] | None) -> int:
|
||||
"""Encode license payload into a compact bitset for cache storage."""
|
||||
|
||||
resolved = resolve_license_payload(payload or {})
|
||||
|
||||
flags = 0
|
||||
if resolved.get("allowNoCredit", True):
|
||||
flags |= 1 << 0
|
||||
|
||||
commercial_level = _resolve_commercial_level(resolved.get("allowCommercialUse", ()))
|
||||
flags |= (int(commercial_level) & 0b111) << 1
|
||||
|
||||
if resolved.get("allowDerivatives", True):
|
||||
flags |= 1 << 4
|
||||
|
||||
if resolved.get("allowDifferentLicense", True):
|
||||
flags |= 1 << 5
|
||||
|
||||
return flags
|
||||
|
||||
|
||||
def resolve_license_info(model_data: Mapping[str, Any] | None) -> tuple[Dict[str, Any], int]:
|
||||
"""Return normalized license payload and its encoded bitset."""
|
||||
|
||||
payload = resolve_license_payload(model_data)
|
||||
return payload, build_license_flags(payload)
|
||||
|
||||
|
||||
def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -> tuple[str | None, bool]:
|
||||
"""Rewrite Civitai preview URLs to use optimized renditions.
|
||||
|
||||
@@ -43,5 +175,10 @@ def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -
|
||||
return rewritten, True
|
||||
|
||||
|
||||
__all__ = ["rewrite_preview_url"]
|
||||
|
||||
__all__ = [
|
||||
"CommercialUseLevel",
|
||||
"build_license_flags",
|
||||
"resolve_license_payload",
|
||||
"resolve_license_info",
|
||||
"rewrite_preview_url",
|
||||
]
|
||||
|
||||
@@ -11,7 +11,11 @@
|
||||
"type": "LORA",
|
||||
"nsfw": false,
|
||||
"description": "description",
|
||||
"tags": ["style"]
|
||||
"tags": ["style"],
|
||||
"allowNoCredit": true,
|
||||
"allowCommercialUse": ["Sell"],
|
||||
"allowDerivatives": true,
|
||||
"allowDifferentLicense": true
|
||||
},
|
||||
"files": [
|
||||
{
|
||||
|
||||
@@ -75,6 +75,10 @@ async def test_update_model_metadata_merges_and_persists():
|
||||
"description": "desc",
|
||||
"tags": ["style"],
|
||||
"creator": {"id": 2},
|
||||
"allowNoCredit": False,
|
||||
"allowCommercialUse": ["Image"],
|
||||
"allowDerivatives": False,
|
||||
"allowDifferentLicense": True,
|
||||
},
|
||||
"baseModel": "sdxl",
|
||||
"images": ["img"],
|
||||
@@ -92,6 +96,13 @@ async def test_update_model_metadata_merges_and_persists():
|
||||
assert result["modelDescription"] == "desc"
|
||||
assert result["tags"] == ["style"]
|
||||
assert result["base_model"] == "SDXL 1.0"
|
||||
civitai_model = result["civitai"]["model"]
|
||||
assert civitai_model["allowNoCredit"] is False
|
||||
assert civitai_model["allowCommercialUse"] == ["Image"]
|
||||
assert civitai_model["allowDerivatives"] is False
|
||||
assert civitai_model["allowDifferentLicense"] is True
|
||||
for key in ("allowNoCredit", "allowCommercialUse", "allowDerivatives", "allowDifferentLicense"):
|
||||
assert key not in result
|
||||
|
||||
helpers.preview_service.ensure_preview_for_metadata.assert_awaited_once()
|
||||
helpers.metadata_manager.save_metadata.assert_awaited_once_with(
|
||||
@@ -142,6 +153,13 @@ async def test_fetch_and_update_model_success_updates_cache(tmp_path):
|
||||
assert model_data["civitai_deleted"] is False
|
||||
assert "civitai" in model_data
|
||||
assert model_data["metadata_source"] == "civitai_api"
|
||||
civitai_model = model_data["civitai"]["model"]
|
||||
assert civitai_model["allowNoCredit"] is True
|
||||
assert civitai_model["allowDerivatives"] is True
|
||||
assert civitai_model["allowDifferentLicense"] is True
|
||||
assert civitai_model["allowCommercialUse"] == ["Sell"]
|
||||
for key in ("allowNoCredit", "allowCommercialUse", "allowDerivatives", "allowDifferentLicense"):
|
||||
assert key not in model_data
|
||||
|
||||
helpers.metadata_manager.hydrate_model_data.assert_not_awaited()
|
||||
assert model_data["hydrated"] is True
|
||||
@@ -151,7 +169,13 @@ async def test_fetch_and_update_model_success_updates_cache(tmp_path):
|
||||
assert await_args, "expected metadata to be persisted"
|
||||
last_call = await_args[-1]
|
||||
assert last_call.args[0] == metadata_path
|
||||
assert last_call.args[1]["hydrated"] is True
|
||||
persisted_payload = last_call.args[1]
|
||||
assert persisted_payload["hydrated"] is True
|
||||
civitai_model = persisted_payload["civitai"]["model"]
|
||||
assert civitai_model["allowNoCredit"] is True
|
||||
assert civitai_model["allowCommercialUse"] == ["Sell"]
|
||||
for key in ("allowNoCredit", "allowCommercialUse", "allowDerivatives", "allowDifferentLicense"):
|
||||
assert key not in persisted_payload
|
||||
update_cache.assert_awaited_once()
|
||||
|
||||
|
||||
@@ -422,4 +446,3 @@ async def test_relink_metadata_raises_when_version_missing():
|
||||
model_id=9,
|
||||
model_version_id=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,7 +11,8 @@ from py.services import model_scanner
|
||||
from py.services.model_cache import ModelCache
|
||||
from py.services.model_hash_index import ModelHashIndex
|
||||
from py.services.model_scanner import CacheBuildResult, ModelScanner
|
||||
from py.services.persistent_model_cache import PersistentModelCache
|
||||
from py.services.persistent_model_cache import PersistentModelCache, DEFAULT_LICENSE_FLAGS
|
||||
from py.utils.civitai_utils import build_license_flags
|
||||
from py.utils.models import BaseModelMetadata
|
||||
|
||||
|
||||
@@ -122,6 +123,7 @@ async def test_initialize_cache_populates_cache(tmp_path: Path):
|
||||
_normalize_path(tmp_path / "one.txt"),
|
||||
_normalize_path(tmp_path / "nested" / "two.txt"),
|
||||
}
|
||||
assert {item["license_flags"] for item in cache.raw_data} == {DEFAULT_LICENSE_FLAGS}
|
||||
|
||||
assert scanner._hash_index.get_path("hash-one") == _normalize_path(tmp_path / "one.txt")
|
||||
assert scanner._hash_index.get_path("hash-two") == _normalize_path(tmp_path / "nested" / "two.txt")
|
||||
@@ -179,12 +181,49 @@ async def test_initialize_in_background_applies_scan_result(tmp_path: Path, monk
|
||||
_normalize_path(tmp_path / "one.txt"),
|
||||
_normalize_path(tmp_path / "nested" / "two.txt"),
|
||||
}
|
||||
assert {item["license_flags"] for item in cache.raw_data} == {DEFAULT_LICENSE_FLAGS}
|
||||
assert scanner._hash_index.get_path("hash-two") == _normalize_path(tmp_path / "nested" / "two.txt")
|
||||
assert scanner._tags_count == {"alpha": 1, "beta": 1}
|
||||
assert scanner._excluded_models == [_normalize_path(tmp_path / "skip-file.txt")]
|
||||
assert ws_stub.payloads[-1]["progress"] == 100
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_cache_entry_encodes_license_flags(tmp_path: Path):
|
||||
scanner = DummyScanner(tmp_path)
|
||||
|
||||
metadata = {
|
||||
"file_path": _normalize_path(tmp_path / "sample.txt"),
|
||||
"file_name": "sample",
|
||||
"model_name": "Sample",
|
||||
"folder": "",
|
||||
"size": 1,
|
||||
"modified": 1.0,
|
||||
"sha256": "hash",
|
||||
"tags": [],
|
||||
"civitai": {
|
||||
"model": {
|
||||
"allowNoCredit": False,
|
||||
"allowCommercialUse": ["Image", "Rent"],
|
||||
"allowDerivatives": True,
|
||||
"allowDifferentLicense": False,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
expected_flags = build_license_flags(
|
||||
{
|
||||
"allowNoCredit": False,
|
||||
"allowCommercialUse": ["Image", "Rent"],
|
||||
"allowDerivatives": True,
|
||||
"allowDifferentLicense": False,
|
||||
}
|
||||
)
|
||||
|
||||
entry = scanner._build_cache_entry(metadata)
|
||||
assert entry["license_flags"] == expected_flags
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_initialize_in_background_uses_persisted_cache_without_full_scan(tmp_path: Path, monkeypatch):
|
||||
monkeypatch.setenv('LORA_MANAGER_DISABLE_PERSISTENT_CACHE', '0')
|
||||
|
||||
@@ -2,7 +2,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.persistent_model_cache import PersistentModelCache
|
||||
from py.services.persistent_model_cache import PersistentModelCache, DEFAULT_LICENSE_FLAGS
|
||||
|
||||
|
||||
def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None:
|
||||
@@ -43,6 +43,7 @@ def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None:
|
||||
'trainedWords': ['word1'],
|
||||
'creator': {'username': 'artist42'},
|
||||
},
|
||||
'license_flags': 13,
|
||||
},
|
||||
{
|
||||
'file_path': file_b,
|
||||
@@ -91,12 +92,14 @@ def test_persistent_cache_roundtrip(tmp_path: Path, monkeypatch) -> None:
|
||||
assert first['metadata_source'] == 'civitai_api'
|
||||
assert first['civitai']['creator']['username'] == 'artist42'
|
||||
assert first['civitai_deleted'] is False
|
||||
assert first['license_flags'] == 13
|
||||
|
||||
second = items[file_b]
|
||||
assert second['exclude'] is True
|
||||
assert second['civitai'] is None
|
||||
assert second['metadata_source'] is None
|
||||
assert second['civitai_deleted'] is True
|
||||
assert second['license_flags'] == DEFAULT_LICENSE_FLAGS
|
||||
|
||||
expected_hash_pairs = {
|
||||
('hash-a', file_a),
|
||||
|
||||
35
tests/utils/test_civitai_utils.py
Normal file
35
tests/utils/test_civitai_utils.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from py.utils.civitai_utils import (
|
||||
CommercialUseLevel,
|
||||
build_license_flags,
|
||||
resolve_license_info,
|
||||
resolve_license_payload,
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_license_payload_defaults():
|
||||
payload, flags = resolve_license_info({})
|
||||
|
||||
assert payload["allowNoCredit"] is True
|
||||
assert payload["allowDerivatives"] is True
|
||||
assert payload["allowDifferentLicense"] is True
|
||||
assert payload["allowCommercialUse"] == ["Sell"]
|
||||
assert flags == 57
|
||||
|
||||
|
||||
def test_build_license_flags_custom_values():
|
||||
source = {
|
||||
"allowNoCredit": False,
|
||||
"allowCommercialUse": {"Image", "Sell"},
|
||||
"allowDerivatives": False,
|
||||
"allowDifferentLicense": False,
|
||||
}
|
||||
|
||||
payload = resolve_license_payload(source)
|
||||
assert payload["allowNoCredit"] is False
|
||||
assert set(payload["allowCommercialUse"]) == {"Image", "Sell"}
|
||||
assert payload["allowDerivatives"] is False
|
||||
assert payload["allowDifferentLicense"] is False
|
||||
|
||||
flags = build_license_flags(source)
|
||||
# Highest commercial level is SELL -> level 4 -> shifted by 1 == 8.
|
||||
assert flags == (CommercialUseLevel.SELL << 1)
|
||||
Reference in New Issue
Block a user