mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(civitai): implement URL rewriting for Civitai previews and enhance download handling, fixes #499
This commit is contained in:
@@ -34,6 +34,7 @@ from ...utils.constants import (
|
||||
SUPPORTED_MEDIA_EXTENSIONS,
|
||||
VALID_LORA_TYPES,
|
||||
)
|
||||
from ...utils.civitai_utils import rewrite_preview_url
|
||||
from ...utils.example_images_paths import is_valid_example_images_root
|
||||
from ...utils.lora_metadata import extract_trained_words
|
||||
from ...utils.usage_stats import UsageStats
|
||||
@@ -692,7 +693,10 @@ class ModelLibraryHandler:
|
||||
if images and isinstance(images, list):
|
||||
first_image = images[0]
|
||||
if isinstance(first_image, dict):
|
||||
thumbnail_url = first_image.get("url")
|
||||
raw_url = first_image.get("url")
|
||||
media_type = first_image.get("type")
|
||||
rewritten_url, _ = rewrite_preview_url(raw_url, media_type)
|
||||
thumbnail_url = rewritten_url
|
||||
|
||||
in_library = await scanner.check_model_version_exists(version_id_int)
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ import asyncio
|
||||
from collections import OrderedDict
|
||||
import uuid
|
||||
from typing import Dict, List
|
||||
from urllib.parse import urlparse
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
|
||||
from ..utils.civitai_utils import rewrite_preview_url
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .service_registry import ServiceRegistry
|
||||
@@ -450,70 +452,103 @@ class DownloadManager:
|
||||
# Download preview image if available
|
||||
images = version_info.get('images', [])
|
||||
if images:
|
||||
# Report preview download progress
|
||||
if progress_callback:
|
||||
await progress_callback(1) # 1% progress for starting preview download
|
||||
|
||||
# Check if it's a video or an image
|
||||
is_video = images[0].get('type') == 'video'
|
||||
|
||||
if (is_video):
|
||||
# For videos, use .mp4 extension
|
||||
preview_ext = '.mp4'
|
||||
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
||||
|
||||
# Download video directly using downloader
|
||||
downloader = await get_downloader()
|
||||
success, result = await downloader.download_file(
|
||||
images[0]['url'],
|
||||
preview_path,
|
||||
use_auth=False # Preview images typically don't need auth
|
||||
)
|
||||
if success:
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
||||
else:
|
||||
# For images, use WebP format for better performance
|
||||
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
|
||||
# Download the original image to temp path using downloader
|
||||
downloader = await get_downloader()
|
||||
success, content, headers = await downloader.download_to_memory(
|
||||
images[0]['url'],
|
||||
use_auth=False
|
||||
)
|
||||
if success:
|
||||
# Save to temp file
|
||||
with open(temp_path, 'wb') as f:
|
||||
f.write(content)
|
||||
# Optimize and convert to WebP
|
||||
preview_path = os.path.splitext(save_path)[0] + '.webp'
|
||||
|
||||
# Use ExifUtils to optimize and convert the image
|
||||
optimized_data, _ = ExifUtils.optimize_image(
|
||||
image_data=temp_path,
|
||||
target_width=CARD_PREVIEW_WIDTH,
|
||||
format='webp',
|
||||
quality=85,
|
||||
preserve_metadata=False
|
||||
)
|
||||
|
||||
# Save the optimized image
|
||||
with open(preview_path, 'wb') as f:
|
||||
f.write(optimized_data)
|
||||
|
||||
# Update metadata
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
||||
|
||||
# Remove temporary file
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete temp file: {e}")
|
||||
first_image = images[0] if isinstance(images[0], dict) else None
|
||||
preview_url = first_image.get('url') if first_image else None
|
||||
media_type = (first_image.get('type') or '').lower() if first_image else ''
|
||||
nsfw_level = first_image.get('nsfwLevel', 0) if first_image else 0
|
||||
|
||||
def _extension_from_url(url: str, fallback: str) -> str:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError:
|
||||
return fallback
|
||||
ext = os.path.splitext(parsed.path)[1]
|
||||
return ext or fallback
|
||||
|
||||
preview_downloaded = False
|
||||
preview_path = None
|
||||
|
||||
if preview_url:
|
||||
downloader = await get_downloader()
|
||||
|
||||
if media_type == 'video':
|
||||
preview_ext = _extension_from_url(preview_url, '.mp4')
|
||||
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
||||
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='video')
|
||||
attempt_urls: List[str] = []
|
||||
if rewritten:
|
||||
attempt_urls.append(rewritten_url)
|
||||
attempt_urls.append(preview_url)
|
||||
|
||||
seen_attempts = set()
|
||||
for attempt in attempt_urls:
|
||||
if not attempt or attempt in seen_attempts:
|
||||
continue
|
||||
seen_attempts.add(attempt)
|
||||
success, _ = await downloader.download_file(
|
||||
attempt,
|
||||
preview_path,
|
||||
use_auth=False
|
||||
)
|
||||
if success:
|
||||
preview_downloaded = True
|
||||
break
|
||||
else:
|
||||
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='image')
|
||||
if rewritten:
|
||||
preview_ext = _extension_from_url(preview_url, '.png')
|
||||
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
||||
success, _ = await downloader.download_file(
|
||||
rewritten_url,
|
||||
preview_path,
|
||||
use_auth=False
|
||||
)
|
||||
if success:
|
||||
preview_downloaded = True
|
||||
|
||||
if not preview_downloaded:
|
||||
temp_path: str | None = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
|
||||
success, content, _ = await downloader.download_to_memory(
|
||||
preview_url,
|
||||
use_auth=False
|
||||
)
|
||||
if success:
|
||||
with open(temp_path, 'wb') as temp_file_handle:
|
||||
temp_file_handle.write(content)
|
||||
preview_path = os.path.splitext(save_path)[0] + '.webp'
|
||||
|
||||
optimized_data, _ = ExifUtils.optimize_image(
|
||||
image_data=temp_path,
|
||||
target_width=CARD_PREVIEW_WIDTH,
|
||||
format='webp',
|
||||
quality=85,
|
||||
preserve_metadata=False
|
||||
)
|
||||
|
||||
with open(preview_path, 'wb') as preview_file:
|
||||
preview_file.write(optimized_data)
|
||||
|
||||
preview_downloaded = True
|
||||
finally:
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete temp file: {e}")
|
||||
|
||||
if preview_downloaded and preview_path:
|
||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||
metadata.preview_nsfw_level = nsfw_level
|
||||
if download_id and download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]['preview_path'] = preview_path
|
||||
|
||||
# Report preview download completion
|
||||
if progress_callback:
|
||||
await progress_callback(3) # 3% progress after preview download
|
||||
|
||||
@@ -677,7 +712,15 @@ class DownloadManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting metadata file: {e}")
|
||||
|
||||
# Delete preview file if exists (.webp or .mp4)
|
||||
preview_path_value = download_info.get('preview_path')
|
||||
if preview_path_value and os.path.exists(preview_path_value):
|
||||
try:
|
||||
os.unlink(preview_path_value)
|
||||
logger.debug(f"Deleted preview file: {preview_path_value}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting preview file: {e}")
|
||||
|
||||
# Delete preview file if exists (.webp or .mp4) for legacy paths
|
||||
for preview_ext in ['.webp', '.mp4']:
|
||||
preview_path = os.path.splitext(file_path)[0] + preview_ext
|
||||
if os.path.exists(preview_path):
|
||||
@@ -710,4 +753,4 @@ class DownloadManager:
|
||||
}
|
||||
for task_id, info in self._active_downloads.items()
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,8 +5,10 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
from typing import Awaitable, Callable, Dict, Optional, Sequence
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
|
||||
from ..utils.civitai_utils import rewrite_preview_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -45,23 +47,59 @@ class PreviewAssetService:
|
||||
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
|
||||
preview_dir = os.path.dirname(metadata_path)
|
||||
is_video = first_preview.get("type") == "video"
|
||||
preview_url = first_preview.get("url")
|
||||
|
||||
if not preview_url:
|
||||
return
|
||||
|
||||
def extension_from_url(url: str, fallback: str) -> str:
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except ValueError:
|
||||
return fallback
|
||||
ext = os.path.splitext(parsed.path)[1]
|
||||
return ext or fallback
|
||||
|
||||
downloader = await self._downloader_factory()
|
||||
|
||||
if is_video:
|
||||
extension = ".mp4"
|
||||
extension = extension_from_url(preview_url, ".mp4")
|
||||
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||
downloader = await self._downloader_factory()
|
||||
success, result = await downloader.download_file(
|
||||
first_preview["url"], preview_path, use_auth=False
|
||||
)
|
||||
if success:
|
||||
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="video")
|
||||
|
||||
attempt_urls = []
|
||||
if rewritten:
|
||||
attempt_urls.append(rewritten_url)
|
||||
attempt_urls.append(preview_url)
|
||||
|
||||
seen: set[str] = set()
|
||||
for candidate in attempt_urls:
|
||||
if not candidate or candidate in seen:
|
||||
continue
|
||||
seen.add(candidate)
|
||||
|
||||
success, _ = await downloader.download_file(candidate, preview_path, use_auth=False)
|
||||
if success:
|
||||
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||
return
|
||||
else:
|
||||
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="image")
|
||||
if rewritten:
|
||||
extension = extension_from_url(preview_url, ".png")
|
||||
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||
success, _ = await downloader.download_file(
|
||||
rewritten_url, preview_path, use_auth=False
|
||||
)
|
||||
if success:
|
||||
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||
return
|
||||
|
||||
extension = ".webp"
|
||||
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||
downloader = await self._downloader_factory()
|
||||
success, content, _headers = await downloader.download_to_memory(
|
||||
first_preview["url"], use_auth=False
|
||||
preview_url, use_auth=False
|
||||
)
|
||||
if not success:
|
||||
return
|
||||
|
||||
48
py/utils/civitai_utils.py
Normal file
48
py/utils/civitai_utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Utilities for working with Civitai assets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
|
||||
def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -> tuple[str | None, bool]:
|
||||
"""Rewrite Civitai preview URLs to use optimized renditions.
|
||||
|
||||
Args:
|
||||
source_url: Original preview URL from the Civitai API.
|
||||
media_type: Optional media type hint (e.g. ``"image"`` or ``"video"``).
|
||||
|
||||
Returns:
|
||||
A tuple of the potentially rewritten URL and a flag indicating whether the
|
||||
replacement occurred. When the URL is not rewritten, the original value is
|
||||
returned with ``False``.
|
||||
"""
|
||||
if not source_url:
|
||||
return source_url, False
|
||||
|
||||
try:
|
||||
parsed = urlparse(source_url)
|
||||
except ValueError:
|
||||
return source_url, False
|
||||
|
||||
if parsed.netloc.lower() != "image.civitai.com":
|
||||
return source_url, False
|
||||
|
||||
replacement = "/width=450,optimized=true"
|
||||
if (media_type or "").lower() == "video":
|
||||
replacement = "/transcode=true,width=450,optimized=true"
|
||||
|
||||
if "/original=true" not in parsed.path:
|
||||
return source_url, False
|
||||
|
||||
updated_path = parsed.path.replace("/original=true", replacement, 1)
|
||||
if updated_path == parsed.path:
|
||||
return source_url, False
|
||||
|
||||
rewritten = urlunparse(parsed._replace(path=updated_path))
|
||||
print(rewritten)
|
||||
return rewritten, True
|
||||
|
||||
|
||||
__all__ = ["rewrite_preview_url"]
|
||||
|
||||
@@ -526,6 +526,64 @@ async def test_get_civitai_user_models_marks_library_versions():
|
||||
assert provider.received_usernames == ["pixel"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_civitai_user_models_rewrites_civitai_previews():
|
||||
image_url = "https://image.civitai.com/container/example/original=true/sample.jpeg"
|
||||
video_url = "https://image.civitai.com/container/example/original=true/sample.mp4"
|
||||
|
||||
models = [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "Model A",
|
||||
"type": "LORA",
|
||||
"tags": ["style"],
|
||||
"modelVersions": [
|
||||
{
|
||||
"id": 100,
|
||||
"name": "preview-image",
|
||||
"baseModel": "Flux.1",
|
||||
"images": [
|
||||
{"url": image_url, "type": "image"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": 101,
|
||||
"name": "preview-video",
|
||||
"baseModel": "Flux.1",
|
||||
"images": [
|
||||
{"url": video_url, "type": "video"},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
provider = FakeUserModelsProvider(models)
|
||||
|
||||
async def provider_factory():
|
||||
return provider
|
||||
|
||||
handler = ModelLibraryHandler(
|
||||
ServiceRegistryAdapter(
|
||||
get_lora_scanner=fake_scanner_factory,
|
||||
get_checkpoint_scanner=fake_scanner_factory,
|
||||
get_embedding_scanner=fake_scanner_factory,
|
||||
),
|
||||
metadata_provider_factory=provider_factory,
|
||||
)
|
||||
|
||||
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
|
||||
payload = json.loads(response.text)
|
||||
|
||||
assert payload["success"] is True
|
||||
previews_by_version = {item["versionId"]: item["thumbnailUrl"] for item in payload["versions"]}
|
||||
assert previews_by_version[100] == "https://image.civitai.com/container/example/width=450,optimized=true/sample.jpeg"
|
||||
assert (
|
||||
previews_by_version[101]
|
||||
== "https://image.civitai.com/container/example/transcode=true,width=450,optimized=true/sample.mp4"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_civitai_user_models_requires_username():
|
||||
provider = FakeUserModelsProvider([])
|
||||
|
||||
@@ -394,3 +394,98 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
||||
assert result == {"success": True}
|
||||
assert [url for url, *_ in dummy_downloader.calls] == download_urls
|
||||
assert dummy_scanner.calls # ensure cache updated
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_path):
|
||||
manager = DownloadManager()
|
||||
save_dir = tmp_path / "downloads"
|
||||
save_dir.mkdir()
|
||||
target_path = save_dir / "file.safetensors"
|
||||
|
||||
manager._active_downloads["dl"] = {}
|
||||
|
||||
class DummyMetadata:
|
||||
def __init__(self, path: Path):
|
||||
self.file_path = str(path)
|
||||
self.sha256 = "sha256"
|
||||
self.file_name = path.stem
|
||||
self.preview_url = None
|
||||
self.preview_nsfw_level = None
|
||||
|
||||
def generate_unique_filename(self, *_args, **_kwargs):
|
||||
return os.path.basename(self.file_path)
|
||||
|
||||
def update_file_info(self, _path):
|
||||
return None
|
||||
|
||||
def to_dict(self):
|
||||
return {"file_path": self.file_path}
|
||||
|
||||
metadata = DummyMetadata(target_path)
|
||||
version_info = {
|
||||
"images": [
|
||||
{
|
||||
"url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
|
||||
"type": "image",
|
||||
"nsfwLevel": 2,
|
||||
}
|
||||
]
|
||||
}
|
||||
download_urls = ["https://example.invalid/file.safetensors"]
|
||||
|
||||
class DummyDownloader:
|
||||
def __init__(self):
|
||||
self.file_calls: list[tuple[str, str]] = []
|
||||
self.memory_calls = 0
|
||||
|
||||
async def download_file(self, url, path, progress_callback=None, use_auth=None):
|
||||
self.file_calls.append((url, path))
|
||||
if url.endswith(".jpeg"):
|
||||
Path(path).write_bytes(b"preview")
|
||||
return True, None
|
||||
if url.endswith(".safetensors"):
|
||||
Path(path).write_bytes(b"model")
|
||||
return True, None
|
||||
return False, "unexpected url"
|
||||
|
||||
async def download_to_memory(self, *_args, **_kwargs):
|
||||
self.memory_calls += 1
|
||||
return False, b"", {}
|
||||
|
||||
dummy_downloader = DummyDownloader()
|
||||
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader))
|
||||
|
||||
optimize_called = {"value": False}
|
||||
|
||||
def fake_optimize_image(**_kwargs):
|
||||
optimize_called["value"] = True
|
||||
return b"", {}
|
||||
|
||||
monkeypatch.setattr(download_manager.ExifUtils, "optimize_image", staticmethod(fake_optimize_image))
|
||||
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||
|
||||
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||
monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
|
||||
|
||||
result = await manager._execute_download(
|
||||
download_urls=download_urls,
|
||||
save_dir=str(save_dir),
|
||||
metadata=metadata,
|
||||
version_info=version_info,
|
||||
relative_path="",
|
||||
progress_callback=None,
|
||||
model_type="lora",
|
||||
download_id="dl",
|
||||
)
|
||||
|
||||
assert result == {"success": True}
|
||||
preview_urls = [url for url, _ in dummy_downloader.file_calls if url.endswith(".jpeg")]
|
||||
assert any("width=450,optimized=true" in url for url in preview_urls)
|
||||
assert dummy_downloader.memory_calls == 0
|
||||
assert optimize_called["value"] is False
|
||||
assert metadata.preview_url.endswith(".jpeg")
|
||||
assert metadata.preview_nsfw_level == 2
|
||||
stored_preview = manager._active_downloads["dl"]["preview_path"]
|
||||
assert stored_preview.endswith(".jpeg")
|
||||
assert Path(stored_preview).exists()
|
||||
|
||||
182
tests/services/test_preview_asset_service.py
Normal file
182
tests/services/test_preview_asset_service.py
Normal file
@@ -0,0 +1,182 @@
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.preview_asset_service import PreviewAssetService
|
||||
|
||||
|
||||
class StubMetadataManager:
|
||||
async def save_metadata(self, *_args: Any, **_kwargs: Any) -> bool: # pragma: no cover - helper
|
||||
return True
|
||||
|
||||
|
||||
class RecordingExifUtils:
|
||||
def __init__(self) -> None:
|
||||
self.called = False
|
||||
|
||||
def optimize_image(self, **kwargs):
|
||||
self.called = True
|
||||
return kwargs["image_data"], {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_preview_prefers_rewritten_civitai_image(tmp_path):
|
||||
metadata_path = tmp_path / "model.metadata.json"
|
||||
metadata_path.write_text("{}")
|
||||
local_metadata: dict[str, Any] = {}
|
||||
|
||||
class Downloader:
|
||||
def __init__(self):
|
||||
self.file_calls: list[tuple[str, str]] = []
|
||||
self.memory_calls = 0
|
||||
|
||||
async def download_file(self, url, path, use_auth=False):
|
||||
self.file_calls.append((url, path))
|
||||
if "width=450,optimized=true" in url:
|
||||
Path(path).write_bytes(b"image-data")
|
||||
return True, None
|
||||
return False, "fail"
|
||||
|
||||
async def download_to_memory(self, *_args, **_kwargs):
|
||||
self.memory_calls += 1
|
||||
return False, b"", {}
|
||||
|
||||
downloader = Downloader()
|
||||
|
||||
async def downloader_factory():
|
||||
return downloader
|
||||
|
||||
exif_utils = RecordingExifUtils()
|
||||
service = PreviewAssetService(
|
||||
metadata_manager=StubMetadataManager(),
|
||||
downloader_factory=downloader_factory,
|
||||
exif_utils=exif_utils,
|
||||
)
|
||||
|
||||
images = [
|
||||
{
|
||||
"url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
|
||||
"type": "image",
|
||||
"nsfwLevel": 3,
|
||||
}
|
||||
]
|
||||
|
||||
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
|
||||
|
||||
assert downloader.memory_calls == 0
|
||||
assert exif_utils.called is False
|
||||
assert len(downloader.file_calls) == 1
|
||||
assert "width=450,optimized=true" in downloader.file_calls[0][0]
|
||||
preview_path = Path(local_metadata["preview_url"])
|
||||
assert preview_path.exists()
|
||||
assert preview_path.suffix == ".jpeg"
|
||||
assert local_metadata["preview_nsfw_level"] == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_preview_falls_back_to_webp_when_rewrite_fails(tmp_path):
|
||||
metadata_path = tmp_path / "model.metadata.json"
|
||||
metadata_path.write_text("{}")
|
||||
local_metadata: dict[str, Any] = {}
|
||||
|
||||
class Downloader:
|
||||
def __init__(self):
|
||||
self.file_calls: list[tuple[str, str]] = []
|
||||
self.memory_calls = 0
|
||||
|
||||
async def download_file(self, url, path, use_auth=False):
|
||||
self.file_calls.append((url, path))
|
||||
return False, "fail"
|
||||
|
||||
async def download_to_memory(self, *_args, **_kwargs):
|
||||
self.memory_calls += 1
|
||||
return True, b"raw-image", {}
|
||||
|
||||
downloader = Downloader()
|
||||
|
||||
async def downloader_factory():
|
||||
return downloader
|
||||
|
||||
class ExifUtils:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
def optimize_image(self, **kwargs):
|
||||
self.calls += 1
|
||||
return b"webp-data", {}
|
||||
|
||||
exif_utils = ExifUtils()
|
||||
|
||||
service = PreviewAssetService(
|
||||
metadata_manager=StubMetadataManager(),
|
||||
downloader_factory=downloader_factory,
|
||||
exif_utils=exif_utils,
|
||||
)
|
||||
|
||||
images = [
|
||||
{
|
||||
"url": "https://image.civitai.com/container/example/original=true/sample.png",
|
||||
"type": "image",
|
||||
"nsfwLevel": 1,
|
||||
}
|
||||
]
|
||||
|
||||
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
|
||||
|
||||
assert downloader.memory_calls == 1
|
||||
assert exif_utils.calls == 1
|
||||
preview_path = Path(local_metadata["preview_url"])
|
||||
assert preview_path.exists()
|
||||
assert preview_path.suffix == ".webp"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_preview_rewrites_civitai_video(tmp_path):
|
||||
metadata_path = tmp_path / "model.metadata.json"
|
||||
metadata_path.write_text("{}")
|
||||
local_metadata: dict[str, Any] = {}
|
||||
|
||||
class Downloader:
|
||||
def __init__(self):
|
||||
self.file_calls: list[tuple[str, str]] = []
|
||||
|
||||
async def download_file(self, url, path, use_auth=False):
|
||||
self.file_calls.append((url, path))
|
||||
if "transcode=true,width=450,optimized=true" in url:
|
||||
Path(path).write_bytes(b"video-data")
|
||||
return True, None
|
||||
if url.endswith(".mp4"):
|
||||
return False, "fail"
|
||||
return False, "unexpected"
|
||||
|
||||
async def download_to_memory(self, *_args, **_kwargs):
|
||||
pytest.fail("download_to_memory should not be used for video previews")
|
||||
|
||||
downloader = Downloader()
|
||||
|
||||
async def downloader_factory():
|
||||
return downloader
|
||||
|
||||
service = PreviewAssetService(
|
||||
metadata_manager=StubMetadataManager(),
|
||||
downloader_factory=downloader_factory,
|
||||
exif_utils=RecordingExifUtils(),
|
||||
)
|
||||
|
||||
images = [
|
||||
{
|
||||
"url": "https://image.civitai.com/container/example/original=true/sample.mp4",
|
||||
"type": "video",
|
||||
"nsfwLevel": 2,
|
||||
}
|
||||
]
|
||||
|
||||
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
|
||||
|
||||
assert len(downloader.file_calls) >= 1
|
||||
assert any("transcode=true,width=450,optimized=true" in url for url, _ in downloader.file_calls)
|
||||
preview_path = Path(local_metadata["preview_url"])
|
||||
assert preview_path.exists()
|
||||
assert preview_path.suffix == ".mp4"
|
||||
assert local_metadata["preview_nsfw_level"] == 2
|
||||
Reference in New Issue
Block a user