diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 3dffbc01..505d997f 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -9,6 +9,7 @@ from urllib.parse import urlparse from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES from ..utils.civitai_utils import rewrite_preview_url +from ..utils.preview_selection import select_preview_media from ..utils.utils import sanitize_folder_name from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager @@ -495,10 +496,21 @@ class DownloadManager: if progress_callback: await progress_callback(1) # 1% progress for starting preview download - first_image = images[0] if isinstance(images[0], dict) else None - preview_url = first_image.get('url') if first_image else None - media_type = (first_image.get('type') or '').lower() if first_image else '' - nsfw_level = first_image.get('nsfwLevel', 0) if first_image else 0 + settings_manager = get_settings_manager() + blur_mature_content = bool( + settings_manager.get('blur_mature_content', True) + ) + selected_image, nsfw_level = select_preview_media( + images, + blur_mature_content=blur_mature_content, + ) + + preview_url = selected_image.get('url') if selected_image else None + media_type = ( + (selected_image.get('type') or '').lower() + if selected_image + else '' + ) def _extension_from_url(url: str, fallback: str) -> str: try: diff --git a/py/services/preview_asset_service.py b/py/services/preview_asset_service.py index 62c3b0a1..9d339a3e 100644 --- a/py/services/preview_asset_service.py +++ b/py/services/preview_asset_service.py @@ -9,6 +9,8 @@ from urllib.parse import urlparse from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS from ..utils.civitai_utils import rewrite_preview_url +from ..utils.preview_selection import select_preview_media +from .settings_manager import get_settings_manager logger = logging.getLogger(__name__) @@ -43,7 +45,18 @@ class PreviewAssetService: if not images: return - first_preview = images[0] + settings_manager = get_settings_manager() + blur_mature_content = bool( + settings_manager.get("blur_mature_content", True) + ) + first_preview, nsfw_level = select_preview_media( + images, + blur_mature_content=blur_mature_content, + ) + + if not first_preview: + return + base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] preview_dir = os.path.dirname(metadata_path) is_video = first_preview.get("type") == "video" @@ -81,7 +94,7 @@ class PreviewAssetService: success, _ = await downloader.download_file(candidate, preview_path, use_auth=False) if success: local_metadata["preview_url"] = preview_path.replace(os.sep, "/") - local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0) + local_metadata["preview_nsfw_level"] = nsfw_level return else: rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="image") @@ -93,7 +106,7 @@ class PreviewAssetService: ) if success: local_metadata["preview_url"] = preview_path.replace(os.sep, "/") - local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0) + local_metadata["preview_nsfw_level"] = nsfw_level return extension = ".webp" @@ -124,7 +137,7 @@ class PreviewAssetService: return local_metadata["preview_url"] = preview_path.replace(os.sep, "/") - local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0) + local_metadata["preview_nsfw_level"] = nsfw_level async def replace_preview( self, diff --git a/py/utils/preview_selection.py b/py/utils/preview_selection.py new file mode 100644 index 00000000..a815a81b --- /dev/null +++ b/py/utils/preview_selection.py @@ -0,0 +1,63 @@ +"""Utilities for selecting preview media from Civitai image metadata.""" + +from __future__ import annotations + +from typing import Mapping, Optional, Sequence, Tuple + +from .constants import NSFW_LEVELS + +PreviewMedia = Mapping[str, object] + + +def _extract_nsfw_level(entry: Mapping[str, object]) -> int: + """Return a normalized NSFW level value for the supplied media entry.""" + + value = entry.get("nsfwLevel", 0) + try: + return int(value) # type: ignore[return-value] + except (TypeError, ValueError): + return 0 + + +def select_preview_media( + images: Sequence[Mapping[str, object]] | None, + *, + blur_mature_content: bool, +) -> Tuple[Optional[PreviewMedia], int]: + """Select the most appropriate preview media entry. + + When ``blur_mature_content`` is enabled we first try to return the first media + item with an ``nsfwLevel`` lower than :pydata:`NSFW_LEVELS["R"]`. If none are + available we return the media entry with the lowest NSFW level. When the + setting is disabled we simply return the first entry. + """ + + if not images: + return None, 0 + + candidates = [item for item in images if isinstance(item, Mapping)] + if not candidates: + return None, 0 + + selected = candidates[0] + selected_level = _extract_nsfw_level(selected) + + if not blur_mature_content: + return selected, selected_level + + safe_threshold = NSFW_LEVELS.get("R", 4) + for candidate in candidates: + level = _extract_nsfw_level(candidate) + if level < safe_threshold: + return candidate, level + + for candidate in candidates[1:]: + level = _extract_nsfw_level(candidate) + if level < selected_level: + selected = candidate + selected_level = level + + return selected, selected_level + + +__all__ = ["select_preview_media"] diff --git a/tests/services/test_download_manager.py b/tests/services/test_download_manager.py index 7ecd163e..c8bbeed8 100644 --- a/tests/services/test_download_manager.py +++ b/tests/services/test_download_manager.py @@ -699,3 +699,116 @@ async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_ stored_preview = manager._active_downloads["dl"]["preview_path"] assert stored_preview.endswith(".jpeg") assert Path(stored_preview).exists() + + +@pytest.mark.asyncio +async def test_execute_download_respects_blur_setting(monkeypatch, tmp_path): + manager = DownloadManager() + save_dir = tmp_path / "downloads" + save_dir.mkdir() + target_path = save_dir / "file.safetensors" + + manager._active_downloads["dl"] = {} + + class DummyMetadata: + def __init__(self, path: Path): + self.file_path = str(path) + self.sha256 = "sha256" + self.file_name = path.stem + self.preview_url = None + self.preview_nsfw_level = None + + def generate_unique_filename(self, *_args, **_kwargs): + return os.path.basename(self.file_path) + + def update_file_info(self, _path): + return None + + def to_dict(self): + return {"file_path": self.file_path} + + metadata = DummyMetadata(target_path) + version_info = { + "images": [ + { + "url": "https://image.civitai.com/container/example/original=true/nsfw.jpeg", + "type": "image", + "nsfwLevel": 8, + }, + { + "url": "https://image.civitai.com/container/example/original=true/safe.jpeg", + "type": "image", + "nsfwLevel": 1, + }, + ], + "files": [ + { + "type": "Model", + "primary": True, + "downloadUrl": "https://example.invalid/file.safetensors", + "name": "file.safetensors", + } + ], + } + download_urls = ["https://example.invalid/file.safetensors"] + + class DummyDownloader: + def __init__(self): + self.file_calls: list[tuple[str, str]] = [] + + async def download_file(self, url, path, progress_callback=None, use_auth=None): + self.file_calls.append((url, path)) + if url.endswith(".safetensors"): + Path(path).write_bytes(b"model") + return True, None + if "safe.jpeg" in url: + Path(path).write_bytes(b"preview") + return True, None + return False, "unexpected url" + + async def download_to_memory(self, *_args, **_kwargs): + return False, b"", {} + + dummy_downloader = DummyDownloader() + + class StubSettingsManager: + def __init__(self, blur: bool) -> None: + self.blur = blur + + def get(self, key: str, default=None): + if key == "blur_mature_content": + return self.blur + return default + + monkeypatch.setattr( + download_manager, + "get_settings_manager", + lambda: StubSettingsManager(True), + ) + + monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader)) + monkeypatch.setattr(download_manager.ExifUtils, "optimize_image", staticmethod(lambda **_kwargs: (b"", {}))) + monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) + + dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) + monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)) + + result = await manager._execute_download( + download_urls=download_urls, + save_dir=str(save_dir), + metadata=metadata, + version_info=version_info, + relative_path="", + progress_callback=None, + model_type="lora", + download_id="dl", + ) + + assert result == {"success": True} + preview_urls = [url for url, _ in dummy_downloader.file_calls if url.endswith(".jpeg")] + assert preview_urls + assert all("nsfw.jpeg" not in url for url in preview_urls) + assert any("safe.jpeg" in url for url in preview_urls) + assert metadata.preview_nsfw_level == 1 + stored_preview = manager._active_downloads["dl"].get("preview_path") + assert stored_preview and stored_preview.endswith(".jpeg") diff --git a/tests/services/test_preview_asset_service.py b/tests/services/test_preview_asset_service.py index e4a2a68f..c67d78df 100644 --- a/tests/services/test_preview_asset_service.py +++ b/tests/services/test_preview_asset_service.py @@ -3,6 +3,7 @@ from typing import Any import pytest +from py.services import preview_asset_service from py.services.preview_asset_service import PreviewAssetService @@ -180,3 +181,68 @@ async def test_ensure_preview_rewrites_civitai_video(tmp_path): assert preview_path.exists() assert preview_path.suffix == ".mp4" assert local_metadata["preview_nsfw_level"] == 2 + + +@pytest.mark.asyncio +async def test_ensure_preview_respects_blur_setting(monkeypatch, tmp_path): + metadata_path = tmp_path / "model.metadata.json" + metadata_path.write_text("{}") + local_metadata: dict[str, Any] = {} + + class Downloader: + def __init__(self): + self.file_calls: list[tuple[str, str]] = [] + + async def download_file(self, url, path, use_auth=False): + self.file_calls.append((url, path)) + Path(path).write_bytes(b"image-data") + return True, None + + async def download_to_memory(self, *_args, **_kwargs): + pytest.fail("download_to_memory should not be used when download_file succeeds") + + downloader = Downloader() + + async def downloader_factory(): + return downloader + + class StubSettingsManager: + def __init__(self, blur: bool) -> None: + self.blur = blur + + def get(self, key: str, default=None): + if key == "blur_mature_content": + return self.blur + return default + + monkeypatch.setattr( + preview_asset_service, + "get_settings_manager", + lambda: StubSettingsManager(True), + ) + + service = PreviewAssetService( + metadata_manager=StubMetadataManager(), + downloader_factory=downloader_factory, + exif_utils=RecordingExifUtils(), + ) + + images = [ + { + "url": "https://image.civitai.com/container/example/original=true/nsfw.jpeg", + "type": "image", + "nsfwLevel": 8, + }, + { + "url": "https://image.civitai.com/container/example/original=true/safe.jpeg", + "type": "image", + "nsfwLevel": 1, + }, + ] + + await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images) + + assert len(downloader.file_calls) == 1 + requested_url = downloader.file_calls[0][0] + assert "safe.jpeg" in requested_url + assert local_metadata["preview_nsfw_level"] == 1 diff --git a/tests/utils/test_preview_selection.py b/tests/utils/test_preview_selection.py new file mode 100644 index 00000000..5109c9e8 --- /dev/null +++ b/tests/utils/test_preview_selection.py @@ -0,0 +1,39 @@ +from py.utils.preview_selection import select_preview_media + + +def test_select_preview_prefers_safe_media_when_blurred(): + images = [ + {"url": "nsfw", "type": "image", "nsfwLevel": 8}, + {"url": "mid", "type": "image", "nsfwLevel": 4}, + {"url": "safe", "type": "image", "nsfwLevel": 1}, + ] + + selected, level = select_preview_media(images, blur_mature_content=True) + + assert selected["url"] == "safe" + assert level == 1 + + +def test_select_preview_returns_lowest_when_no_safe_media(): + images = [ + {"url": "x", "type": "image", "nsfwLevel": 16}, + {"url": "r", "type": "image", "nsfwLevel": 4}, + {"url": "xx", "type": "image", "nsfwLevel": 8}, + ] + + selected, level = select_preview_media(images, blur_mature_content=True) + + assert selected["url"] == "r" + assert level == 4 + + +def test_select_preview_returns_first_when_blur_disabled(): + images = [ + {"url": "nsfw", "type": "image", "nsfwLevel": 32}, + {"url": "safe", "type": "image", "nsfwLevel": 1}, + ] + + selected, level = select_preview_media(images, blur_mature_content=False) + + assert selected["url"] == "nsfw" + assert level == 32