feat(civitai): implement URL rewriting for Civitai previews and enhance download handling, fixes #499

This commit is contained in:
Will Miao
2025-10-09 17:54:37 +08:00
parent d2c2bfbe6a
commit f542ade628
7 changed files with 541 additions and 73 deletions

View File

@@ -34,6 +34,7 @@ from ...utils.constants import (
SUPPORTED_MEDIA_EXTENSIONS, SUPPORTED_MEDIA_EXTENSIONS,
VALID_LORA_TYPES, VALID_LORA_TYPES,
) )
from ...utils.civitai_utils import rewrite_preview_url
from ...utils.example_images_paths import is_valid_example_images_root from ...utils.example_images_paths import is_valid_example_images_root
from ...utils.lora_metadata import extract_trained_words from ...utils.lora_metadata import extract_trained_words
from ...utils.usage_stats import UsageStats from ...utils.usage_stats import UsageStats
@@ -692,7 +693,10 @@ class ModelLibraryHandler:
if images and isinstance(images, list): if images and isinstance(images, list):
first_image = images[0] first_image = images[0]
if isinstance(first_image, dict): if isinstance(first_image, dict):
thumbnail_url = first_image.get("url") raw_url = first_image.get("url")
media_type = first_image.get("type")
rewritten_url, _ = rewrite_preview_url(raw_url, media_type)
thumbnail_url = rewritten_url
in_library = await scanner.check_model_version_exists(version_id_int) in_library = await scanner.check_model_version_exists(version_id_int)

View File

@@ -4,8 +4,10 @@ import asyncio
from collections import OrderedDict from collections import OrderedDict
import uuid import uuid
from typing import Dict, List from typing import Dict, List
from urllib.parse import urlparse
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
from ..utils.civitai_utils import rewrite_preview_url
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry from .service_registry import ServiceRegistry
@@ -450,47 +452,78 @@ class DownloadManager:
# Download preview image if available # Download preview image if available
images = version_info.get('images', []) images = version_info.get('images', [])
if images: if images:
# Report preview download progress
if progress_callback: if progress_callback:
await progress_callback(1) # 1% progress for starting preview download await progress_callback(1) # 1% progress for starting preview download
# Check if it's a video or an image first_image = images[0] if isinstance(images[0], dict) else None
is_video = images[0].get('type') == 'video' preview_url = first_image.get('url') if first_image else None
media_type = (first_image.get('type') or '').lower() if first_image else ''
nsfw_level = first_image.get('nsfwLevel', 0) if first_image else 0
if (is_video): def _extension_from_url(url: str, fallback: str) -> str:
# For videos, use .mp4 extension try:
preview_ext = '.mp4' parsed = urlparse(url)
except ValueError:
return fallback
ext = os.path.splitext(parsed.path)[1]
return ext or fallback
preview_downloaded = False
preview_path = None
if preview_url:
downloader = await get_downloader()
if media_type == 'video':
preview_ext = _extension_from_url(preview_url, '.mp4')
preview_path = os.path.splitext(save_path)[0] + preview_ext preview_path = os.path.splitext(save_path)[0] + preview_ext
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='video')
attempt_urls: List[str] = []
if rewritten:
attempt_urls.append(rewritten_url)
attempt_urls.append(preview_url)
# Download video directly using downloader seen_attempts = set()
downloader = await get_downloader() for attempt in attempt_urls:
success, result = await downloader.download_file( if not attempt or attempt in seen_attempts:
images[0]['url'], continue
seen_attempts.add(attempt)
success, _ = await downloader.download_file(
attempt,
preview_path, preview_path,
use_auth=False # Preview images typically don't need auth
)
if success:
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
else:
# For images, use WebP format for better performance
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
temp_path = temp_file.name
# Download the original image to temp path using downloader
downloader = await get_downloader()
success, content, headers = await downloader.download_to_memory(
images[0]['url'],
use_auth=False use_auth=False
) )
if success: if success:
# Save to temp file preview_downloaded = True
with open(temp_path, 'wb') as f: break
f.write(content) else:
# Optimize and convert to WebP rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='image')
if rewritten:
preview_ext = _extension_from_url(preview_url, '.png')
preview_path = os.path.splitext(save_path)[0] + preview_ext
success, _ = await downloader.download_file(
rewritten_url,
preview_path,
use_auth=False
)
if success:
preview_downloaded = True
if not preview_downloaded:
temp_path: str | None = None
try:
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
temp_path = temp_file.name
success, content, _ = await downloader.download_to_memory(
preview_url,
use_auth=False
)
if success:
with open(temp_path, 'wb') as temp_file_handle:
temp_file_handle.write(content)
preview_path = os.path.splitext(save_path)[0] + '.webp' preview_path = os.path.splitext(save_path)[0] + '.webp'
# Use ExifUtils to optimize and convert the image
optimized_data, _ = ExifUtils.optimize_image( optimized_data, _ = ExifUtils.optimize_image(
image_data=temp_path, image_data=temp_path,
target_width=CARD_PREVIEW_WIDTH, target_width=CARD_PREVIEW_WIDTH,
@@ -499,21 +532,23 @@ class DownloadManager:
preserve_metadata=False preserve_metadata=False
) )
# Save the optimized image with open(preview_path, 'wb') as preview_file:
with open(preview_path, 'wb') as f: preview_file.write(optimized_data)
f.write(optimized_data)
# Update metadata preview_downloaded = True
metadata.preview_url = preview_path.replace(os.sep, '/') finally:
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) if temp_path and os.path.exists(temp_path):
# Remove temporary file
try: try:
os.unlink(temp_path) os.unlink(temp_path)
except Exception as e: except Exception as e:
logger.warning(f"Failed to delete temp file: {e}") logger.warning(f"Failed to delete temp file: {e}")
# Report preview download completion if preview_downloaded and preview_path:
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = nsfw_level
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]['preview_path'] = preview_path
if progress_callback: if progress_callback:
await progress_callback(3) # 3% progress after preview download await progress_callback(3) # 3% progress after preview download
@@ -677,7 +712,15 @@ class DownloadManager:
except Exception as e: except Exception as e:
logger.error(f"Error deleting metadata file: {e}") logger.error(f"Error deleting metadata file: {e}")
# Delete preview file if exists (.webp or .mp4) preview_path_value = download_info.get('preview_path')
if preview_path_value and os.path.exists(preview_path_value):
try:
os.unlink(preview_path_value)
logger.debug(f"Deleted preview file: {preview_path_value}")
except Exception as e:
logger.error(f"Error deleting preview file: {e}")
# Delete preview file if exists (.webp or .mp4) for legacy paths
for preview_ext in ['.webp', '.mp4']: for preview_ext in ['.webp', '.mp4']:
preview_path = os.path.splitext(file_path)[0] + preview_ext preview_path = os.path.splitext(file_path)[0] + preview_ext
if os.path.exists(preview_path): if os.path.exists(preview_path):

View File

@@ -5,8 +5,10 @@ from __future__ import annotations
import logging import logging
import os import os
from typing import Awaitable, Callable, Dict, Optional, Sequence from typing import Awaitable, Callable, Dict, Optional, Sequence
from urllib.parse import urlparse
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
from ..utils.civitai_utils import rewrite_preview_url
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -45,23 +47,59 @@ class PreviewAssetService:
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_dir = os.path.dirname(metadata_path) preview_dir = os.path.dirname(metadata_path)
is_video = first_preview.get("type") == "video" is_video = first_preview.get("type") == "video"
preview_url = first_preview.get("url")
if not preview_url:
return
def extension_from_url(url: str, fallback: str) -> str:
try:
parsed = urlparse(url)
except ValueError:
return fallback
ext = os.path.splitext(parsed.path)[1]
return ext or fallback
downloader = await self._downloader_factory()
if is_video: if is_video:
extension = ".mp4" extension = extension_from_url(preview_url, ".mp4")
preview_path = os.path.join(preview_dir, base_name + extension) preview_path = os.path.join(preview_dir, base_name + extension)
downloader = await self._downloader_factory() rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="video")
success, result = await downloader.download_file(
first_preview["url"], preview_path, use_auth=False attempt_urls = []
if rewritten:
attempt_urls.append(rewritten_url)
attempt_urls.append(preview_url)
seen: set[str] = set()
for candidate in attempt_urls:
if not candidate or candidate in seen:
continue
seen.add(candidate)
success, _ = await downloader.download_file(candidate, preview_path, use_auth=False)
if success:
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
return
else:
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="image")
if rewritten:
extension = extension_from_url(preview_url, ".png")
preview_path = os.path.join(preview_dir, base_name + extension)
success, _ = await downloader.download_file(
rewritten_url, preview_path, use_auth=False
) )
if success: if success:
local_metadata["preview_url"] = preview_path.replace(os.sep, "/") local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0) local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
else: return
extension = ".webp" extension = ".webp"
preview_path = os.path.join(preview_dir, base_name + extension) preview_path = os.path.join(preview_dir, base_name + extension)
downloader = await self._downloader_factory()
success, content, _headers = await downloader.download_to_memory( success, content, _headers = await downloader.download_to_memory(
first_preview["url"], use_auth=False preview_url, use_auth=False
) )
if not success: if not success:
return return

48
py/utils/civitai_utils.py Normal file
View File

@@ -0,0 +1,48 @@
"""Utilities for working with Civitai assets."""
from __future__ import annotations
from urllib.parse import urlparse, urlunparse
def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -> tuple[str | None, bool]:
"""Rewrite Civitai preview URLs to use optimized renditions.
Args:
source_url: Original preview URL from the Civitai API.
media_type: Optional media type hint (e.g. ``"image"`` or ``"video"``).
Returns:
A tuple of the potentially rewritten URL and a flag indicating whether the
replacement occurred. When the URL is not rewritten, the original value is
returned with ``False``.
"""
if not source_url:
return source_url, False
try:
parsed = urlparse(source_url)
except ValueError:
return source_url, False
if parsed.netloc.lower() != "image.civitai.com":
return source_url, False
replacement = "/width=450,optimized=true"
if (media_type or "").lower() == "video":
replacement = "/transcode=true,width=450,optimized=true"
if "/original=true" not in parsed.path:
return source_url, False
updated_path = parsed.path.replace("/original=true", replacement, 1)
if updated_path == parsed.path:
return source_url, False
rewritten = urlunparse(parsed._replace(path=updated_path))
print(rewritten)
return rewritten, True
__all__ = ["rewrite_preview_url"]

View File

@@ -526,6 +526,64 @@ async def test_get_civitai_user_models_marks_library_versions():
assert provider.received_usernames == ["pixel"] assert provider.received_usernames == ["pixel"]
@pytest.mark.asyncio
async def test_get_civitai_user_models_rewrites_civitai_previews():
    """Thumbnail URLs in the user-models payload use the rewritten renditions.

    One model with an image-preview version (id 100) and a video-preview
    version (id 101) is served by the fake provider; the response must expose
    the ``width=450,optimized=true`` URL for the image and the
    ``transcode=true,width=450,optimized=true`` URL for the video.
    """
    image_url = "https://image.civitai.com/container/example/original=true/sample.jpeg"
    video_url = "https://image.civitai.com/container/example/original=true/sample.mp4"

    image_version = {
        "id": 100,
        "name": "preview-image",
        "baseModel": "Flux.1",
        "images": [
            {"url": image_url, "type": "image"},
        ],
    }
    video_version = {
        "id": 101,
        "name": "preview-video",
        "baseModel": "Flux.1",
        "images": [
            {"url": video_url, "type": "video"},
        ],
    }
    model_entry = {
        "id": 1,
        "name": "Model A",
        "type": "LORA",
        "tags": ["style"],
        "modelVersions": [image_version, video_version],
    }

    provider = FakeUserModelsProvider([model_entry])

    async def provider_factory():
        return provider

    registry = ServiceRegistryAdapter(
        get_lora_scanner=fake_scanner_factory,
        get_checkpoint_scanner=fake_scanner_factory,
        get_embedding_scanner=fake_scanner_factory,
    )
    handler = ModelLibraryHandler(registry, metadata_provider_factory=provider_factory)

    request = FakeRequest(query={"username": "pixel"})
    response = await handler.get_civitai_user_models(request)
    payload = json.loads(response.text)

    assert payload["success"] is True

    # Index the returned thumbnails by version id for direct comparison.
    thumbnails = {}
    for entry in payload["versions"]:
        thumbnails[entry["versionId"]] = entry["thumbnailUrl"]

    expected_image = "https://image.civitai.com/container/example/width=450,optimized=true/sample.jpeg"
    expected_video = (
        "https://image.civitai.com/container/example/transcode=true,width=450,optimized=true/sample.mp4"
    )
    assert thumbnails[100] == expected_image
    assert thumbnails[101] == expected_video
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_civitai_user_models_requires_username(): async def test_get_civitai_user_models_requires_username():
provider = FakeUserModelsProvider([]) provider = FakeUserModelsProvider([])

View File

@@ -394,3 +394,98 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
assert result == {"success": True} assert result == {"success": True}
assert [url for url, *_ in dummy_downloader.calls] == download_urls assert [url for url, *_ in dummy_downloader.calls] == download_urls
assert dummy_scanner.calls # ensure cache updated assert dummy_scanner.calls # ensure cache updated
@pytest.mark.asyncio
async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_path):
    """``_execute_download`` fetches an image preview through the rewritten URL.

    The dummy downloader accepts any ``.jpeg`` or ``.safetensors`` URL via
    ``download_file`` and always fails ``download_to_memory``. The test then
    checks that a ``width=450,optimized=true`` preview URL was requested, that
    neither the in-memory download nor ``ExifUtils.optimize_image`` ran, and
    that the ``.jpeg`` preview path is recorded on the metadata and in the
    active-download entry.
    """
    manager = DownloadManager()

    save_dir = tmp_path / "downloads"
    save_dir.mkdir()
    target_path = save_dir / "file.safetensors"

    manager._active_downloads["dl"] = {}

    class DummyMetadata:
        """Minimal stand-in exposing the attributes and methods this test defines."""

        def __init__(self, path: Path):
            self.file_path = str(path)
            self.sha256 = "sha256"
            self.file_name = path.stem
            self.preview_url = None
            self.preview_nsfw_level = None

        def generate_unique_filename(self, *_args, **_kwargs):
            return os.path.basename(self.file_path)

        def update_file_info(self, _path):
            return None

        def to_dict(self):
            return {"file_path": self.file_path}

    metadata = DummyMetadata(target_path)

    preview_entry = {
        "url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
        "type": "image",
        "nsfwLevel": 2,
    }
    version_info = {"images": [preview_entry]}
    download_urls = ["https://example.invalid/file.safetensors"]

    class DummyDownloader:
        """Records ``download_file`` calls and writes canned bytes per URL suffix."""

        def __init__(self):
            self.file_calls: list[tuple[str, str]] = []
            self.memory_calls = 0

        async def download_file(self, url, path, progress_callback=None, use_auth=None):
            self.file_calls.append((url, path))
            # Suffix -> payload table, checked in the same order as before.
            canned = ((".jpeg", b"preview"), (".safetensors", b"model"))
            for suffix, payload in canned:
                if url.endswith(suffix):
                    Path(path).write_bytes(payload)
                    return True, None
            return False, "unexpected url"

        async def download_to_memory(self, *_args, **_kwargs):
            self.memory_calls += 1
            return False, b"", {}

    dummy_downloader = DummyDownloader()
    monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader))

    optimize_called = {"value": False}

    def fake_optimize_image(**_kwargs):
        optimize_called["value"] = True
        return b"", {}

    monkeypatch.setattr(download_manager.ExifUtils, "optimize_image", staticmethod(fake_optimize_image))
    monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))

    dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
    monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))

    result = await manager._execute_download(
        download_urls=download_urls,
        save_dir=str(save_dir),
        metadata=metadata,
        version_info=version_info,
        relative_path="",
        progress_callback=None,
        model_type="lora",
        download_id="dl",
    )

    assert result == {"success": True}

    # Only the preview requests end in ".jpeg"; the model URL does not.
    preview_urls = []
    for requested_url, _destination in dummy_downloader.file_calls:
        if requested_url.endswith(".jpeg"):
            preview_urls.append(requested_url)
    assert any("width=450,optimized=true" in requested_url for requested_url in preview_urls)

    assert dummy_downloader.memory_calls == 0
    assert optimize_called["value"] is False

    assert metadata.preview_url.endswith(".jpeg")
    assert metadata.preview_nsfw_level == 2

    stored_preview = manager._active_downloads["dl"]["preview_path"]
    assert stored_preview.endswith(".jpeg")
    assert Path(stored_preview).exists()

View File

@@ -0,0 +1,182 @@
from pathlib import Path
from typing import Any
import pytest
from py.services.preview_asset_service import PreviewAssetService
class StubMetadataManager:
    """Metadata-manager stand-in whose ``save_metadata`` always reports success."""

    async def save_metadata(self, *_args: Any, **_kwargs: Any) -> bool:  # pragma: no cover - helper
        """Ignore every argument and return ``True``."""
        saved = True
        return saved
class RecordingExifUtils:
    """Exif-utils stand-in that records whether ``optimize_image`` was invoked."""

    def __init__(self) -> None:
        # Flipped to True on the first optimize_image call.
        self.called = False

    def optimize_image(self, **kwargs):
        """Mark the stub as called and echo ``image_data`` back with empty info."""
        self.called = True
        passthrough = kwargs["image_data"]
        return passthrough, {}
@pytest.mark.asyncio
async def test_ensure_preview_prefers_rewritten_civitai_image(tmp_path):
    """An image preview is saved straight from the rewritten Civitai URL.

    The fake downloader only succeeds for URLs containing
    ``width=450,optimized=true``. The test expects exactly one
    ``download_file`` call, no in-memory download, no optimizer call, and a
    ``.jpeg`` preview recorded in the local metadata with its NSFW level.
    """
    metadata_path = tmp_path / "model.metadata.json"
    metadata_path.write_text("{}")
    local_metadata: dict[str, Any] = {}

    class Downloader:
        """Succeeds only for the optimized rendition; tracks every call."""

        def __init__(self):
            self.file_calls: list[tuple[str, str]] = []
            self.memory_calls = 0

        async def download_file(self, url, path, use_auth=False):
            self.file_calls.append((url, path))
            if "width=450,optimized=true" not in url:
                return False, "fail"
            Path(path).write_bytes(b"image-data")
            return True, None

        async def download_to_memory(self, *_args, **_kwargs):
            self.memory_calls += 1
            return False, b"", {}

    downloader = Downloader()

    async def downloader_factory():
        return downloader

    exif_utils = RecordingExifUtils()
    service = PreviewAssetService(
        metadata_manager=StubMetadataManager(),
        downloader_factory=downloader_factory,
        exif_utils=exif_utils,
    )

    first_image = {
        "url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
        "type": "image",
        "nsfwLevel": 3,
    }
    images = [first_image]

    await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)

    assert downloader.memory_calls == 0
    assert exif_utils.called is False
    assert len(downloader.file_calls) == 1
    assert "width=450,optimized=true" in downloader.file_calls[0][0]

    saved_preview = Path(local_metadata["preview_url"])
    assert saved_preview.exists()
    assert saved_preview.suffix == ".jpeg"
    assert local_metadata["preview_nsfw_level"] == 3
@pytest.mark.asyncio
async def test_ensure_preview_falls_back_to_webp_when_rewrite_fails(tmp_path):
    """A failed direct download falls back to the in-memory WebP path.

    Every ``download_file`` call fails while ``download_to_memory`` succeeds,
    so the test expects one in-memory download, one optimizer call, and a
    ``.webp`` preview recorded in the local metadata.
    """
    metadata_path = tmp_path / "model.metadata.json"
    metadata_path.write_text("{}")
    local_metadata: dict[str, Any] = {}

    class Downloader:
        """Fails every file download; serves raw bytes from memory."""

        def __init__(self):
            self.file_calls: list[tuple[str, str]] = []
            self.memory_calls = 0

        async def download_file(self, url, path, use_auth=False):
            self.file_calls.append((url, path))
            return False, "fail"

        async def download_to_memory(self, *_args, **_kwargs):
            self.memory_calls += 1
            return True, b"raw-image", {}

    downloader = Downloader()

    async def downloader_factory():
        return downloader

    class ExifUtils:
        """Counts optimizer invocations and returns canned WebP bytes."""

        def __init__(self):
            self.calls = 0

        def optimize_image(self, **kwargs):
            self.calls += 1
            return b"webp-data", {}

    exif_utils = ExifUtils()
    service = PreviewAssetService(
        metadata_manager=StubMetadataManager(),
        downloader_factory=downloader_factory,
        exif_utils=exif_utils,
    )

    first_image = {
        "url": "https://image.civitai.com/container/example/original=true/sample.png",
        "type": "image",
        "nsfwLevel": 1,
    }
    images = [first_image]

    await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)

    assert downloader.memory_calls == 1
    assert exif_utils.calls == 1

    saved_preview = Path(local_metadata["preview_url"])
    assert saved_preview.exists()
    assert saved_preview.suffix == ".webp"
@pytest.mark.asyncio
async def test_ensure_preview_rewrites_civitai_video(tmp_path):
    """A video preview is fetched through the transcoded Civitai rendition.

    The fake downloader only succeeds for URLs containing
    ``transcode=true,width=450,optimized=true`` and fails the test if
    ``download_to_memory`` is used. The test expects an ``.mp4`` preview
    recorded in the local metadata together with its NSFW level.
    """
    metadata_path = tmp_path / "model.metadata.json"
    metadata_path.write_text("{}")
    local_metadata: dict[str, Any] = {}

    class Downloader:
        """Accepts only the transcoded rendition; tracks every file request."""

        def __init__(self):
            self.file_calls: list[tuple[str, str]] = []

        async def download_file(self, url, path, use_auth=False):
            self.file_calls.append((url, path))
            if "transcode=true,width=450,optimized=true" in url:
                Path(path).write_bytes(b"video-data")
                return True, None
            failure_reason = "fail" if url.endswith(".mp4") else "unexpected"
            return False, failure_reason

        async def download_to_memory(self, *_args, **_kwargs):
            pytest.fail("download_to_memory should not be used for video previews")

    downloader = Downloader()

    async def downloader_factory():
        return downloader

    service = PreviewAssetService(
        metadata_manager=StubMetadataManager(),
        downloader_factory=downloader_factory,
        exif_utils=RecordingExifUtils(),
    )

    first_video = {
        "url": "https://image.civitai.com/container/example/original=true/sample.mp4",
        "type": "video",
        "nsfwLevel": 2,
    }
    images = [first_video]

    await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)

    assert len(downloader.file_calls) >= 1
    requested_urls = [requested for requested, _destination in downloader.file_calls]
    assert any("transcode=true,width=450,optimized=true" in requested for requested in requested_urls)

    saved_preview = Path(local_metadata["preview_url"])
    assert saved_preview.exists()
    assert saved_preview.suffix == ".mp4"
    assert local_metadata["preview_nsfw_level"] == 2