feat(civitai): implement URL rewriting for Civitai previews and enhance download handling, fixes #499

This commit is contained in:
Will Miao
2025-10-09 17:54:37 +08:00
parent d2c2bfbe6a
commit f542ade628
7 changed files with 541 additions and 73 deletions

View File

@@ -34,6 +34,7 @@ from ...utils.constants import (
SUPPORTED_MEDIA_EXTENSIONS, SUPPORTED_MEDIA_EXTENSIONS,
VALID_LORA_TYPES, VALID_LORA_TYPES,
) )
from ...utils.civitai_utils import rewrite_preview_url
from ...utils.example_images_paths import is_valid_example_images_root from ...utils.example_images_paths import is_valid_example_images_root
from ...utils.lora_metadata import extract_trained_words from ...utils.lora_metadata import extract_trained_words
from ...utils.usage_stats import UsageStats from ...utils.usage_stats import UsageStats
@@ -692,7 +693,10 @@ class ModelLibraryHandler:
if images and isinstance(images, list): if images and isinstance(images, list):
first_image = images[0] first_image = images[0]
if isinstance(first_image, dict): if isinstance(first_image, dict):
thumbnail_url = first_image.get("url") raw_url = first_image.get("url")
media_type = first_image.get("type")
rewritten_url, _ = rewrite_preview_url(raw_url, media_type)
thumbnail_url = rewritten_url
in_library = await scanner.check_model_version_exists(version_id_int) in_library = await scanner.check_model_version_exists(version_id_int)

View File

@@ -4,8 +4,10 @@ import asyncio
from collections import OrderedDict from collections import OrderedDict
import uuid import uuid
from typing import Dict, List from typing import Dict, List
from urllib.parse import urlparse
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
from ..utils.civitai_utils import rewrite_preview_url
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry from .service_registry import ServiceRegistry
@@ -450,70 +452,103 @@ class DownloadManager:
# Download preview image if available # Download preview image if available
images = version_info.get('images', []) images = version_info.get('images', [])
if images: if images:
# Report preview download progress
if progress_callback: if progress_callback:
await progress_callback(1) # 1% progress for starting preview download await progress_callback(1) # 1% progress for starting preview download
# Check if it's a video or an image first_image = images[0] if isinstance(images[0], dict) else None
is_video = images[0].get('type') == 'video' preview_url = first_image.get('url') if first_image else None
media_type = (first_image.get('type') or '').lower() if first_image else ''
if (is_video): nsfw_level = first_image.get('nsfwLevel', 0) if first_image else 0
# For videos, use .mp4 extension
preview_ext = '.mp4' def _extension_from_url(url: str, fallback: str) -> str:
preview_path = os.path.splitext(save_path)[0] + preview_ext try:
parsed = urlparse(url)
# Download video directly using downloader except ValueError:
downloader = await get_downloader() return fallback
success, result = await downloader.download_file( ext = os.path.splitext(parsed.path)[1]
images[0]['url'], return ext or fallback
preview_path,
use_auth=False # Preview images typically don't need auth preview_downloaded = False
) preview_path = None
if success:
metadata.preview_url = preview_path.replace(os.sep, '/') if preview_url:
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) downloader = await get_downloader()
else:
# For images, use WebP format for better performance if media_type == 'video':
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file: preview_ext = _extension_from_url(preview_url, '.mp4')
temp_path = temp_file.name preview_path = os.path.splitext(save_path)[0] + preview_ext
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='video')
# Download the original image to temp path using downloader attempt_urls: List[str] = []
downloader = await get_downloader() if rewritten:
success, content, headers = await downloader.download_to_memory( attempt_urls.append(rewritten_url)
images[0]['url'], attempt_urls.append(preview_url)
use_auth=False
) seen_attempts = set()
if success: for attempt in attempt_urls:
# Save to temp file if not attempt or attempt in seen_attempts:
with open(temp_path, 'wb') as f: continue
f.write(content) seen_attempts.add(attempt)
# Optimize and convert to WebP success, _ = await downloader.download_file(
preview_path = os.path.splitext(save_path)[0] + '.webp' attempt,
preview_path,
# Use ExifUtils to optimize and convert the image use_auth=False
optimized_data, _ = ExifUtils.optimize_image( )
image_data=temp_path, if success:
target_width=CARD_PREVIEW_WIDTH, preview_downloaded = True
format='webp', break
quality=85, else:
preserve_metadata=False rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='image')
) if rewritten:
preview_ext = _extension_from_url(preview_url, '.png')
# Save the optimized image preview_path = os.path.splitext(save_path)[0] + preview_ext
with open(preview_path, 'wb') as f: success, _ = await downloader.download_file(
f.write(optimized_data) rewritten_url,
preview_path,
# Update metadata use_auth=False
metadata.preview_url = preview_path.replace(os.sep, '/') )
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) if success:
preview_downloaded = True
# Remove temporary file
try: if not preview_downloaded:
os.unlink(temp_path) temp_path: str | None = None
except Exception as e: try:
logger.warning(f"Failed to delete temp file: {e}") with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
temp_path = temp_file.name
success, content, _ = await downloader.download_to_memory(
preview_url,
use_auth=False
)
if success:
with open(temp_path, 'wb') as temp_file_handle:
temp_file_handle.write(content)
preview_path = os.path.splitext(save_path)[0] + '.webp'
optimized_data, _ = ExifUtils.optimize_image(
image_data=temp_path,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=False
)
with open(preview_path, 'wb') as preview_file:
preview_file.write(optimized_data)
preview_downloaded = True
finally:
if temp_path and os.path.exists(temp_path):
try:
os.unlink(temp_path)
except Exception as e:
logger.warning(f"Failed to delete temp file: {e}")
if preview_downloaded and preview_path:
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = nsfw_level
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]['preview_path'] = preview_path
# Report preview download completion
if progress_callback: if progress_callback:
await progress_callback(3) # 3% progress after preview download await progress_callback(3) # 3% progress after preview download
@@ -677,7 +712,15 @@ class DownloadManager:
except Exception as e: except Exception as e:
logger.error(f"Error deleting metadata file: {e}") logger.error(f"Error deleting metadata file: {e}")
# Delete preview file if exists (.webp or .mp4) preview_path_value = download_info.get('preview_path')
if preview_path_value and os.path.exists(preview_path_value):
try:
os.unlink(preview_path_value)
logger.debug(f"Deleted preview file: {preview_path_value}")
except Exception as e:
logger.error(f"Error deleting preview file: {e}")
# Delete preview file if exists (.webp or .mp4) for legacy paths
for preview_ext in ['.webp', '.mp4']: for preview_ext in ['.webp', '.mp4']:
preview_path = os.path.splitext(file_path)[0] + preview_ext preview_path = os.path.splitext(file_path)[0] + preview_ext
if os.path.exists(preview_path): if os.path.exists(preview_path):
@@ -710,4 +753,4 @@ class DownloadManager:
} }
for task_id, info in self._active_downloads.items() for task_id, info in self._active_downloads.items()
] ]
} }

View File

@@ -5,8 +5,10 @@ from __future__ import annotations
import logging import logging
import os import os
from typing import Awaitable, Callable, Dict, Optional, Sequence from typing import Awaitable, Callable, Dict, Optional, Sequence
from urllib.parse import urlparse
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
from ..utils.civitai_utils import rewrite_preview_url
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -45,23 +47,59 @@ class PreviewAssetService:
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_dir = os.path.dirname(metadata_path) preview_dir = os.path.dirname(metadata_path)
is_video = first_preview.get("type") == "video" is_video = first_preview.get("type") == "video"
preview_url = first_preview.get("url")
if not preview_url:
return
def extension_from_url(url: str, fallback: str) -> str:
try:
parsed = urlparse(url)
except ValueError:
return fallback
ext = os.path.splitext(parsed.path)[1]
return ext or fallback
downloader = await self._downloader_factory()
if is_video: if is_video:
extension = ".mp4" extension = extension_from_url(preview_url, ".mp4")
preview_path = os.path.join(preview_dir, base_name + extension) preview_path = os.path.join(preview_dir, base_name + extension)
downloader = await self._downloader_factory() rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="video")
success, result = await downloader.download_file(
first_preview["url"], preview_path, use_auth=False attempt_urls = []
) if rewritten:
if success: attempt_urls.append(rewritten_url)
local_metadata["preview_url"] = preview_path.replace(os.sep, "/") attempt_urls.append(preview_url)
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
seen: set[str] = set()
for candidate in attempt_urls:
if not candidate or candidate in seen:
continue
seen.add(candidate)
success, _ = await downloader.download_file(candidate, preview_path, use_auth=False)
if success:
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
return
else: else:
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="image")
if rewritten:
extension = extension_from_url(preview_url, ".png")
preview_path = os.path.join(preview_dir, base_name + extension)
success, _ = await downloader.download_file(
rewritten_url, preview_path, use_auth=False
)
if success:
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
return
extension = ".webp" extension = ".webp"
preview_path = os.path.join(preview_dir, base_name + extension) preview_path = os.path.join(preview_dir, base_name + extension)
downloader = await self._downloader_factory()
success, content, _headers = await downloader.download_to_memory( success, content, _headers = await downloader.download_to_memory(
first_preview["url"], use_auth=False preview_url, use_auth=False
) )
if not success: if not success:
return return

48
py/utils/civitai_utils.py Normal file
View File

@@ -0,0 +1,48 @@
"""Utilities for working with Civitai assets."""
from __future__ import annotations
from urllib.parse import urlparse, urlunparse
def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -> tuple[str | None, bool]:
"""Rewrite Civitai preview URLs to use optimized renditions.
Args:
source_url: Original preview URL from the Civitai API.
media_type: Optional media type hint (e.g. ``"image"`` or ``"video"``).
Returns:
A tuple of the potentially rewritten URL and a flag indicating whether the
replacement occurred. When the URL is not rewritten, the original value is
returned with ``False``.
"""
if not source_url:
return source_url, False
try:
parsed = urlparse(source_url)
except ValueError:
return source_url, False
if parsed.netloc.lower() != "image.civitai.com":
return source_url, False
replacement = "/width=450,optimized=true"
if (media_type or "").lower() == "video":
replacement = "/transcode=true,width=450,optimized=true"
if "/original=true" not in parsed.path:
return source_url, False
updated_path = parsed.path.replace("/original=true", replacement, 1)
if updated_path == parsed.path:
return source_url, False
rewritten = urlunparse(parsed._replace(path=updated_path))
print(rewritten)
return rewritten, True
__all__ = ["rewrite_preview_url"]

View File

@@ -526,6 +526,64 @@ async def test_get_civitai_user_models_marks_library_versions():
assert provider.received_usernames == ["pixel"] assert provider.received_usernames == ["pixel"]
@pytest.mark.asyncio
async def test_get_civitai_user_models_rewrites_civitai_previews():
image_url = "https://image.civitai.com/container/example/original=true/sample.jpeg"
video_url = "https://image.civitai.com/container/example/original=true/sample.mp4"
models = [
{
"id": 1,
"name": "Model A",
"type": "LORA",
"tags": ["style"],
"modelVersions": [
{
"id": 100,
"name": "preview-image",
"baseModel": "Flux.1",
"images": [
{"url": image_url, "type": "image"},
],
},
{
"id": 101,
"name": "preview-video",
"baseModel": "Flux.1",
"images": [
{"url": video_url, "type": "video"},
],
},
],
},
]
provider = FakeUserModelsProvider(models)
async def provider_factory():
return provider
handler = ModelLibraryHandler(
ServiceRegistryAdapter(
get_lora_scanner=fake_scanner_factory,
get_checkpoint_scanner=fake_scanner_factory,
get_embedding_scanner=fake_scanner_factory,
),
metadata_provider_factory=provider_factory,
)
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
payload = json.loads(response.text)
assert payload["success"] is True
previews_by_version = {item["versionId"]: item["thumbnailUrl"] for item in payload["versions"]}
assert previews_by_version[100] == "https://image.civitai.com/container/example/width=450,optimized=true/sample.jpeg"
assert (
previews_by_version[101]
== "https://image.civitai.com/container/example/transcode=true,width=450,optimized=true/sample.mp4"
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_civitai_user_models_requires_username(): async def test_get_civitai_user_models_requires_username():
provider = FakeUserModelsProvider([]) provider = FakeUserModelsProvider([])

View File

@@ -394,3 +394,98 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
assert result == {"success": True} assert result == {"success": True}
assert [url for url, *_ in dummy_downloader.calls] == download_urls assert [url for url, *_ in dummy_downloader.calls] == download_urls
assert dummy_scanner.calls # ensure cache updated assert dummy_scanner.calls # ensure cache updated
@pytest.mark.asyncio
async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_path):
manager = DownloadManager()
save_dir = tmp_path / "downloads"
save_dir.mkdir()
target_path = save_dir / "file.safetensors"
manager._active_downloads["dl"] = {}
class DummyMetadata:
def __init__(self, path: Path):
self.file_path = str(path)
self.sha256 = "sha256"
self.file_name = path.stem
self.preview_url = None
self.preview_nsfw_level = None
def generate_unique_filename(self, *_args, **_kwargs):
return os.path.basename(self.file_path)
def update_file_info(self, _path):
return None
def to_dict(self):
return {"file_path": self.file_path}
metadata = DummyMetadata(target_path)
version_info = {
"images": [
{
"url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
"type": "image",
"nsfwLevel": 2,
}
]
}
download_urls = ["https://example.invalid/file.safetensors"]
class DummyDownloader:
def __init__(self):
self.file_calls: list[tuple[str, str]] = []
self.memory_calls = 0
async def download_file(self, url, path, progress_callback=None, use_auth=None):
self.file_calls.append((url, path))
if url.endswith(".jpeg"):
Path(path).write_bytes(b"preview")
return True, None
if url.endswith(".safetensors"):
Path(path).write_bytes(b"model")
return True, None
return False, "unexpected url"
async def download_to_memory(self, *_args, **_kwargs):
self.memory_calls += 1
return False, b"", {}
dummy_downloader = DummyDownloader()
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader))
optimize_called = {"value": False}
def fake_optimize_image(**_kwargs):
optimize_called["value"] = True
return b"", {}
monkeypatch.setattr(download_manager.ExifUtils, "optimize_image", staticmethod(fake_optimize_image))
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
result = await manager._execute_download(
download_urls=download_urls,
save_dir=str(save_dir),
metadata=metadata,
version_info=version_info,
relative_path="",
progress_callback=None,
model_type="lora",
download_id="dl",
)
assert result == {"success": True}
preview_urls = [url for url, _ in dummy_downloader.file_calls if url.endswith(".jpeg")]
assert any("width=450,optimized=true" in url for url in preview_urls)
assert dummy_downloader.memory_calls == 0
assert optimize_called["value"] is False
assert metadata.preview_url.endswith(".jpeg")
assert metadata.preview_nsfw_level == 2
stored_preview = manager._active_downloads["dl"]["preview_path"]
assert stored_preview.endswith(".jpeg")
assert Path(stored_preview).exists()

View File

@@ -0,0 +1,182 @@
from pathlib import Path
from typing import Any
import pytest
from py.services.preview_asset_service import PreviewAssetService
class StubMetadataManager:
async def save_metadata(self, *_args: Any, **_kwargs: Any) -> bool: # pragma: no cover - helper
return True
class RecordingExifUtils:
def __init__(self) -> None:
self.called = False
def optimize_image(self, **kwargs):
self.called = True
return kwargs["image_data"], {}
@pytest.mark.asyncio
async def test_ensure_preview_prefers_rewritten_civitai_image(tmp_path):
metadata_path = tmp_path / "model.metadata.json"
metadata_path.write_text("{}")
local_metadata: dict[str, Any] = {}
class Downloader:
def __init__(self):
self.file_calls: list[tuple[str, str]] = []
self.memory_calls = 0
async def download_file(self, url, path, use_auth=False):
self.file_calls.append((url, path))
if "width=450,optimized=true" in url:
Path(path).write_bytes(b"image-data")
return True, None
return False, "fail"
async def download_to_memory(self, *_args, **_kwargs):
self.memory_calls += 1
return False, b"", {}
downloader = Downloader()
async def downloader_factory():
return downloader
exif_utils = RecordingExifUtils()
service = PreviewAssetService(
metadata_manager=StubMetadataManager(),
downloader_factory=downloader_factory,
exif_utils=exif_utils,
)
images = [
{
"url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
"type": "image",
"nsfwLevel": 3,
}
]
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
assert downloader.memory_calls == 0
assert exif_utils.called is False
assert len(downloader.file_calls) == 1
assert "width=450,optimized=true" in downloader.file_calls[0][0]
preview_path = Path(local_metadata["preview_url"])
assert preview_path.exists()
assert preview_path.suffix == ".jpeg"
assert local_metadata["preview_nsfw_level"] == 3
@pytest.mark.asyncio
async def test_ensure_preview_falls_back_to_webp_when_rewrite_fails(tmp_path):
metadata_path = tmp_path / "model.metadata.json"
metadata_path.write_text("{}")
local_metadata: dict[str, Any] = {}
class Downloader:
def __init__(self):
self.file_calls: list[tuple[str, str]] = []
self.memory_calls = 0
async def download_file(self, url, path, use_auth=False):
self.file_calls.append((url, path))
return False, "fail"
async def download_to_memory(self, *_args, **_kwargs):
self.memory_calls += 1
return True, b"raw-image", {}
downloader = Downloader()
async def downloader_factory():
return downloader
class ExifUtils:
def __init__(self):
self.calls = 0
def optimize_image(self, **kwargs):
self.calls += 1
return b"webp-data", {}
exif_utils = ExifUtils()
service = PreviewAssetService(
metadata_manager=StubMetadataManager(),
downloader_factory=downloader_factory,
exif_utils=exif_utils,
)
images = [
{
"url": "https://image.civitai.com/container/example/original=true/sample.png",
"type": "image",
"nsfwLevel": 1,
}
]
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
assert downloader.memory_calls == 1
assert exif_utils.calls == 1
preview_path = Path(local_metadata["preview_url"])
assert preview_path.exists()
assert preview_path.suffix == ".webp"
@pytest.mark.asyncio
async def test_ensure_preview_rewrites_civitai_video(tmp_path):
metadata_path = tmp_path / "model.metadata.json"
metadata_path.write_text("{}")
local_metadata: dict[str, Any] = {}
class Downloader:
def __init__(self):
self.file_calls: list[tuple[str, str]] = []
async def download_file(self, url, path, use_auth=False):
self.file_calls.append((url, path))
if "transcode=true,width=450,optimized=true" in url:
Path(path).write_bytes(b"video-data")
return True, None
if url.endswith(".mp4"):
return False, "fail"
return False, "unexpected"
async def download_to_memory(self, *_args, **_kwargs):
pytest.fail("download_to_memory should not be used for video previews")
downloader = Downloader()
async def downloader_factory():
return downloader
service = PreviewAssetService(
metadata_manager=StubMetadataManager(),
downloader_factory=downloader_factory,
exif_utils=RecordingExifUtils(),
)
images = [
{
"url": "https://image.civitai.com/container/example/original=true/sample.mp4",
"type": "video",
"nsfwLevel": 2,
}
]
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
assert len(downloader.file_calls) >= 1
assert any("transcode=true,width=450,optimized=true" in url for url, _ in downloader.file_calls)
preview_path = Path(local_metadata["preview_url"])
assert preview_path.exists()
assert preview_path.suffix == ".mp4"
assert local_metadata["preview_nsfw_level"] == 2