diff --git a/py/routes/handlers/misc_handlers.py b/py/routes/handlers/misc_handlers.py index 459d19d5..2a6359b6 100644 --- a/py/routes/handlers/misc_handlers.py +++ b/py/routes/handlers/misc_handlers.py @@ -34,6 +34,7 @@ from ...utils.constants import ( SUPPORTED_MEDIA_EXTENSIONS, VALID_LORA_TYPES, ) +from ...utils.civitai_utils import rewrite_preview_url from ...utils.example_images_paths import is_valid_example_images_root from ...utils.lora_metadata import extract_trained_words from ...utils.usage_stats import UsageStats @@ -692,7 +693,10 @@ class ModelLibraryHandler: if images and isinstance(images, list): first_image = images[0] if isinstance(first_image, dict): - thumbnail_url = first_image.get("url") + raw_url = first_image.get("url") + media_type = first_image.get("type") + rewritten_url, _ = rewrite_preview_url(raw_url, media_type) + thumbnail_url = rewritten_url in_library = await scanner.check_model_version_exists(version_id_int) diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 1b8866ed..239c17d3 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -4,8 +4,10 @@ import asyncio from collections import OrderedDict import uuid from typing import Dict, List +from urllib.parse import urlparse from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS +from ..utils.civitai_utils import rewrite_preview_url from ..utils.exif_utils import ExifUtils from ..utils.metadata_manager import MetadataManager from .service_registry import ServiceRegistry @@ -450,70 +452,103 @@ class DownloadManager: # Download preview image if available images = version_info.get('images', []) if images: - # Report preview download progress if progress_callback: await progress_callback(1) # 1% progress for starting preview download - # Check if it's a video or an image - is_video = images[0].get('type') == 'video' - - if (is_video): - # For videos, use .mp4 extension - preview_ext = '.mp4' - preview_path = os.path.splitext(save_path)[0] + preview_ext - - # Download video directly using downloader - downloader = await get_downloader() - success, result = await downloader.download_file( - images[0]['url'], - preview_path, - use_auth=False # Preview images typically don't need auth - ) - if success: - metadata.preview_url = preview_path.replace(os.sep, '/') - metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) - else: - # For images, use WebP format for better performance - with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file: - temp_path = temp_file.name - - # Download the original image to temp path using downloader - downloader = await get_downloader() - success, content, headers = await downloader.download_to_memory( - images[0]['url'], - use_auth=False - ) - if success: - # Save to temp file - with open(temp_path, 'wb') as f: - f.write(content) - # Optimize and convert to WebP - preview_path = os.path.splitext(save_path)[0] + '.webp' - - # Use ExifUtils to optimize and convert the image - optimized_data, _ = ExifUtils.optimize_image( - image_data=temp_path, - target_width=CARD_PREVIEW_WIDTH, - format='webp', - quality=85, - preserve_metadata=False - ) - - # Save the optimized image - with open(preview_path, 'wb') as f: - f.write(optimized_data) - - # Update metadata - metadata.preview_url = preview_path.replace(os.sep, '/') - metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) - - # Remove temporary file - try: - os.unlink(temp_path) - except Exception as e: - logger.warning(f"Failed to delete temp file: {e}") + first_image = images[0] if isinstance(images[0], dict) else None + preview_url = first_image.get('url') if first_image else None + media_type = (first_image.get('type') or '').lower() if first_image else '' + nsfw_level = first_image.get('nsfwLevel', 0) if first_image else 0 + + def _extension_from_url(url: str, fallback: str) -> str: + try: + parsed = urlparse(url) + except ValueError: + return fallback + ext = os.path.splitext(parsed.path)[1] + return ext or fallback + + preview_downloaded = False + preview_path = None + + if preview_url: + downloader = await get_downloader() + + if media_type == 'video': + preview_ext = _extension_from_url(preview_url, '.mp4') + preview_path = os.path.splitext(save_path)[0] + preview_ext + rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='video') + attempt_urls: List[str] = [] + if rewritten: + attempt_urls.append(rewritten_url) + attempt_urls.append(preview_url) + + seen_attempts = set() + for attempt in attempt_urls: + if not attempt or attempt in seen_attempts: + continue + seen_attempts.add(attempt) + success, _ = await downloader.download_file( + attempt, + preview_path, + use_auth=False + ) + if success: + preview_downloaded = True + break + else: + rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='image') + if rewritten: + preview_ext = _extension_from_url(preview_url, '.png') + preview_path = os.path.splitext(save_path)[0] + preview_ext + success, _ = await downloader.download_file( + rewritten_url, + preview_path, + use_auth=False + ) + if success: + preview_downloaded = True + + if not preview_downloaded: + temp_path: str | None = None + try: + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file: + temp_path = temp_file.name + + success, content, _ = await downloader.download_to_memory( + preview_url, + use_auth=False + ) + if success: + with open(temp_path, 'wb') as temp_file_handle: + temp_file_handle.write(content) + preview_path = os.path.splitext(save_path)[0] + '.webp' + + optimized_data, _ = ExifUtils.optimize_image( + image_data=temp_path, + target_width=CARD_PREVIEW_WIDTH, + format='webp', + quality=85, + preserve_metadata=False + ) + + with open(preview_path, 'wb') as preview_file: + preview_file.write(optimized_data) + + preview_downloaded = True + finally: + if temp_path and os.path.exists(temp_path): + try: + os.unlink(temp_path) + except Exception as e: + logger.warning(f"Failed to delete temp file: {e}") + + if preview_downloaded and preview_path: + metadata.preview_url = preview_path.replace(os.sep, '/') + metadata.preview_nsfw_level = nsfw_level + if download_id and download_id in self._active_downloads: + self._active_downloads[download_id]['preview_path'] = preview_path - # Report preview download completion if progress_callback: await progress_callback(3) # 3% progress after preview download @@ -677,7 +712,15 @@ class DownloadManager: except Exception as e: logger.error(f"Error deleting metadata file: {e}") - # Delete preview file if exists (.webp or .mp4) + preview_path_value = download_info.get('preview_path') + if preview_path_value and os.path.exists(preview_path_value): + try: + os.unlink(preview_path_value) + logger.debug(f"Deleted preview file: {preview_path_value}") + except Exception as e: + logger.error(f"Error deleting preview file: {e}") + + # Delete preview file if exists (.webp or .mp4) for legacy paths for preview_ext in ['.webp', '.mp4']: preview_path = os.path.splitext(file_path)[0] + preview_ext if os.path.exists(preview_path): @@ -710,4 +753,4 @@ class DownloadManager: } for task_id, info in self._active_downloads.items() ] - } \ No newline at end of file + } diff --git a/py/services/preview_asset_service.py b/py/services/preview_asset_service.py index 42baadac..62c3b0a1 100644 --- a/py/services/preview_asset_service.py +++ b/py/services/preview_asset_service.py @@ -5,8 +5,10 @@ from __future__ import annotations import logging import os from typing import Awaitable, Callable, Dict, Optional, Sequence +from urllib.parse import urlparse from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS +from ..utils.civitai_utils import rewrite_preview_url logger = logging.getLogger(__name__) @@ -45,23 +47,59 @@ class PreviewAssetService: base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] preview_dir = os.path.dirname(metadata_path) is_video = first_preview.get("type") == "video" + preview_url = first_preview.get("url") + + if not preview_url: + return + + def extension_from_url(url: str, fallback: str) -> str: + try: + parsed = urlparse(url) + except ValueError: + return fallback + ext = os.path.splitext(parsed.path)[1] + return ext or fallback + + downloader = await self._downloader_factory() if is_video: - extension = ".mp4" + extension = extension_from_url(preview_url, ".mp4") preview_path = os.path.join(preview_dir, base_name + extension) - downloader = await self._downloader_factory() - success, result = await downloader.download_file( - first_preview["url"], preview_path, use_auth=False - ) - if success: - local_metadata["preview_url"] = preview_path.replace(os.sep, "/") - local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0) + rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="video") + + attempt_urls = [] + if rewritten: + attempt_urls.append(rewritten_url) + attempt_urls.append(preview_url) + + seen: set[str] = set() + for candidate in attempt_urls: + if not candidate or candidate in seen: + continue + seen.add(candidate) + + success, _ = await downloader.download_file(candidate, preview_path, use_auth=False) + if success: + local_metadata["preview_url"] = preview_path.replace(os.sep, "/") + local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0) + return else: + rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="image") + if rewritten: + extension = extension_from_url(preview_url, ".png") + preview_path = os.path.join(preview_dir, base_name + extension) + success, _ = await downloader.download_file( + rewritten_url, preview_path, use_auth=False + ) + if success: + local_metadata["preview_url"] = preview_path.replace(os.sep, "/") + local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0) + return + extension = ".webp" preview_path = os.path.join(preview_dir, base_name + extension) - downloader = await self._downloader_factory() success, content, _headers = await downloader.download_to_memory( - first_preview["url"], use_auth=False + preview_url, use_auth=False ) if not success: return diff --git a/py/utils/civitai_utils.py b/py/utils/civitai_utils.py new file mode 100644 index 00000000..01308d4d --- /dev/null +++ b/py/utils/civitai_utils.py @@ -0,0 +1,48 @@ +"""Utilities for working with Civitai assets.""" + +from __future__ import annotations + +from urllib.parse import urlparse, urlunparse + + +def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -> tuple[str | None, bool]: + """Rewrite Civitai preview URLs to use optimized renditions. + + Args: + source_url: Original preview URL from the Civitai API. + media_type: Optional media type hint (e.g. ``"image"`` or ``"video"``). + + Returns: + A tuple of the potentially rewritten URL and a flag indicating whether the + replacement occurred. When the URL is not rewritten, the original value is + returned with ``False``. + """ + if not source_url: + return source_url, False + + try: + parsed = urlparse(source_url) + except ValueError: + return source_url, False + + if parsed.netloc.lower() != "image.civitai.com": + return source_url, False + + replacement = "/width=450,optimized=true" + if (media_type or "").lower() == "video": + replacement = "/transcode=true,width=450,optimized=true" + + if "/original=true" not in parsed.path: + return source_url, False + + updated_path = parsed.path.replace("/original=true", replacement, 1) + if updated_path == parsed.path: + return source_url, False + + rewritten = urlunparse(parsed._replace(path=updated_path)) + print(rewritten) + return rewritten, True + + +__all__ = ["rewrite_preview_url"] + diff --git a/tests/routes/test_misc_routes.py b/tests/routes/test_misc_routes.py index 3dca0515..79dab8e4 100644 --- a/tests/routes/test_misc_routes.py +++ b/tests/routes/test_misc_routes.py @@ -526,6 +526,64 @@ async def test_get_civitai_user_models_marks_library_versions(): assert provider.received_usernames == ["pixel"] +@pytest.mark.asyncio +async def test_get_civitai_user_models_rewrites_civitai_previews(): + image_url = "https://image.civitai.com/container/example/original=true/sample.jpeg" + video_url = "https://image.civitai.com/container/example/original=true/sample.mp4" + + models = [ + { + "id": 1, + "name": "Model A", + "type": "LORA", + "tags": ["style"], + "modelVersions": [ + { + "id": 100, + "name": "preview-image", + "baseModel": "Flux.1", + "images": [ + {"url": image_url, "type": "image"}, + ], + }, + { + "id": 101, + "name": "preview-video", + "baseModel": "Flux.1", + "images": [ + {"url": video_url, "type": "video"}, + ], + }, + ], + }, + ] + + provider = FakeUserModelsProvider(models) + + async def provider_factory(): + return provider + + handler = ModelLibraryHandler( + ServiceRegistryAdapter( + get_lora_scanner=fake_scanner_factory, + get_checkpoint_scanner=fake_scanner_factory, + get_embedding_scanner=fake_scanner_factory, + ), + metadata_provider_factory=provider_factory, + ) + + response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"})) + payload = json.loads(response.text) + + assert payload["success"] is True + previews_by_version = {item["versionId"]: item["thumbnailUrl"] for item in payload["versions"]} + assert previews_by_version[100] == "https://image.civitai.com/container/example/width=450,optimized=true/sample.jpeg" + assert ( + previews_by_version[101] + == "https://image.civitai.com/container/example/transcode=true,width=450,optimized=true/sample.mp4" + ) + + @pytest.mark.asyncio async def test_get_civitai_user_models_requires_username(): provider = FakeUserModelsProvider([]) diff --git a/tests/services/test_download_manager.py b/tests/services/test_download_manager.py index f8bd3688..48b425af 100644 --- a/tests/services/test_download_manager.py +++ b/tests/services/test_download_manager.py @@ -394,3 +394,98 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path): assert result == {"success": True} assert [url for url, *_ in dummy_downloader.calls] == download_urls assert dummy_scanner.calls # ensure cache updated + + +@pytest.mark.asyncio +async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_path): + manager = DownloadManager() + save_dir = tmp_path / "downloads" + save_dir.mkdir() + target_path = save_dir / "file.safetensors" + + manager._active_downloads["dl"] = {} + + class DummyMetadata: + def __init__(self, path: Path): + self.file_path = str(path) + self.sha256 = "sha256" + self.file_name = path.stem + self.preview_url = None + self.preview_nsfw_level = None + + def generate_unique_filename(self, *_args, **_kwargs): + return os.path.basename(self.file_path) + + def update_file_info(self, _path): + return None + + def to_dict(self): + return {"file_path": self.file_path} + + metadata = DummyMetadata(target_path) + version_info = { + "images": [ + { + "url": "https://image.civitai.com/container/example/original=true/sample.jpeg", + "type": "image", + "nsfwLevel": 2, + } + ] + } + download_urls = ["https://example.invalid/file.safetensors"] + + class DummyDownloader: + def __init__(self): + self.file_calls: list[tuple[str, str]] = [] + self.memory_calls = 0 + + async def download_file(self, url, path, progress_callback=None, use_auth=None): + self.file_calls.append((url, path)) + if url.endswith(".jpeg"): + Path(path).write_bytes(b"preview") + return True, None + if url.endswith(".safetensors"): + Path(path).write_bytes(b"model") + return True, None + return False, "unexpected url" + + async def download_to_memory(self, *_args, **_kwargs): + self.memory_calls += 1 + return False, b"", {} + + dummy_downloader = DummyDownloader() + monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader)) + + optimize_called = {"value": False} + + def fake_optimize_image(**_kwargs): + optimize_called["value"] = True + return b"", {} + + monkeypatch.setattr(download_manager.ExifUtils, "optimize_image", staticmethod(fake_optimize_image)) + monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) + + dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) + monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)) + + result = await manager._execute_download( + download_urls=download_urls, + save_dir=str(save_dir), + metadata=metadata, + version_info=version_info, + relative_path="", + progress_callback=None, + model_type="lora", + download_id="dl", + ) + + assert result == {"success": True} + preview_urls = [url for url, _ in dummy_downloader.file_calls if url.endswith(".jpeg")] + assert any("width=450,optimized=true" in url for url in preview_urls) + assert dummy_downloader.memory_calls == 0 + assert optimize_called["value"] is False + assert metadata.preview_url.endswith(".jpeg") + assert metadata.preview_nsfw_level == 2 + stored_preview = manager._active_downloads["dl"]["preview_path"] + assert stored_preview.endswith(".jpeg") + assert Path(stored_preview).exists() diff --git a/tests/services/test_preview_asset_service.py b/tests/services/test_preview_asset_service.py new file mode 100644 index 00000000..e4a2a68f --- /dev/null +++ b/tests/services/test_preview_asset_service.py @@ -0,0 +1,182 @@ +from pathlib import Path +from typing import Any + +import pytest + +from py.services.preview_asset_service import PreviewAssetService + + +class StubMetadataManager: + async def save_metadata(self, *_args: Any, **_kwargs: Any) -> bool: # pragma: no cover - helper + return True + + +class RecordingExifUtils: + def __init__(self) -> None: + self.called = False + + def optimize_image(self, **kwargs): + self.called = True + return kwargs["image_data"], {} + + +@pytest.mark.asyncio +async def test_ensure_preview_prefers_rewritten_civitai_image(tmp_path): + metadata_path = tmp_path / "model.metadata.json" + metadata_path.write_text("{}") + local_metadata: dict[str, Any] = {} + + class Downloader: + def __init__(self): + self.file_calls: list[tuple[str, str]] = [] + self.memory_calls = 0 + + async def download_file(self, url, path, use_auth=False): + self.file_calls.append((url, path)) + if "width=450,optimized=true" in url: + Path(path).write_bytes(b"image-data") + return True, None + return False, "fail" + + async def download_to_memory(self, *_args, **_kwargs): + self.memory_calls += 1 + return False, b"", {} + + downloader = Downloader() + + async def downloader_factory(): + return downloader + + exif_utils = RecordingExifUtils() + service = PreviewAssetService( + metadata_manager=StubMetadataManager(), + downloader_factory=downloader_factory, + exif_utils=exif_utils, + ) + + images = [ + { + "url": "https://image.civitai.com/container/example/original=true/sample.jpeg", + "type": "image", + "nsfwLevel": 3, + } + ] + + await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images) + + assert downloader.memory_calls == 0 + assert exif_utils.called is False + assert len(downloader.file_calls) == 1 + assert "width=450,optimized=true" in downloader.file_calls[0][0] + preview_path = Path(local_metadata["preview_url"]) + assert preview_path.exists() + assert preview_path.suffix == ".jpeg" + assert local_metadata["preview_nsfw_level"] == 3 + + +@pytest.mark.asyncio +async def test_ensure_preview_falls_back_to_webp_when_rewrite_fails(tmp_path): + metadata_path = tmp_path / "model.metadata.json" + metadata_path.write_text("{}") + local_metadata: dict[str, Any] = {} + + class Downloader: + def __init__(self): + self.file_calls: list[tuple[str, str]] = [] + self.memory_calls = 0 + + async def download_file(self, url, path, use_auth=False): + self.file_calls.append((url, path)) + return False, "fail" + + async def download_to_memory(self, *_args, **_kwargs): + self.memory_calls += 1 + return True, b"raw-image", {} + + downloader = Downloader() + + async def downloader_factory(): + return downloader + + class ExifUtils: + def __init__(self): + self.calls = 0 + + def optimize_image(self, **kwargs): + self.calls += 1 + return b"webp-data", {} + + exif_utils = ExifUtils() + + service = PreviewAssetService( + metadata_manager=StubMetadataManager(), + downloader_factory=downloader_factory, + exif_utils=exif_utils, + ) + + images = [ + { + "url": "https://image.civitai.com/container/example/original=true/sample.png", + "type": "image", + "nsfwLevel": 1, + } + ] + + await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images) + + assert downloader.memory_calls == 1 + assert exif_utils.calls == 1 + preview_path = Path(local_metadata["preview_url"]) + assert preview_path.exists() + assert preview_path.suffix == ".webp" + + +@pytest.mark.asyncio +async def test_ensure_preview_rewrites_civitai_video(tmp_path): + metadata_path = tmp_path / "model.metadata.json" + metadata_path.write_text("{}") + local_metadata: dict[str, Any] = {} + + class Downloader: + def __init__(self): + self.file_calls: list[tuple[str, str]] = [] + + async def download_file(self, url, path, use_auth=False): + self.file_calls.append((url, path)) + if "transcode=true,width=450,optimized=true" in url: + Path(path).write_bytes(b"video-data") + return True, None + if url.endswith(".mp4"): + return False, "fail" + return False, "unexpected" + + async def download_to_memory(self, *_args, **_kwargs): + pytest.fail("download_to_memory should not be used for video previews") + + downloader = Downloader() + + async def downloader_factory(): + return downloader + + service = PreviewAssetService( + metadata_manager=StubMetadataManager(), + downloader_factory=downloader_factory, + exif_utils=RecordingExifUtils(), + ) + + images = [ + { + "url": "https://image.civitai.com/container/example/original=true/sample.mp4", + "type": "video", + "nsfwLevel": 2, + } + ] + + await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images) + + assert len(downloader.file_calls) >= 1 + assert any("transcode=true,width=450,optimized=true" in url for url, _ in downloader.file_calls) + preview_path = Path(local_metadata["preview_url"]) + assert preview_path.exists() + assert preview_path.suffix == ".mp4" + assert local_metadata["preview_nsfw_level"] == 2