mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
feat(civitai): implement URL rewriting for Civitai previews and enhance download handling, fixes #499
This commit is contained in:
@@ -34,6 +34,7 @@ from ...utils.constants import (
|
|||||||
SUPPORTED_MEDIA_EXTENSIONS,
|
SUPPORTED_MEDIA_EXTENSIONS,
|
||||||
VALID_LORA_TYPES,
|
VALID_LORA_TYPES,
|
||||||
)
|
)
|
||||||
|
from ...utils.civitai_utils import rewrite_preview_url
|
||||||
from ...utils.example_images_paths import is_valid_example_images_root
|
from ...utils.example_images_paths import is_valid_example_images_root
|
||||||
from ...utils.lora_metadata import extract_trained_words
|
from ...utils.lora_metadata import extract_trained_words
|
||||||
from ...utils.usage_stats import UsageStats
|
from ...utils.usage_stats import UsageStats
|
||||||
@@ -692,7 +693,10 @@ class ModelLibraryHandler:
|
|||||||
if images and isinstance(images, list):
|
if images and isinstance(images, list):
|
||||||
first_image = images[0]
|
first_image = images[0]
|
||||||
if isinstance(first_image, dict):
|
if isinstance(first_image, dict):
|
||||||
thumbnail_url = first_image.get("url")
|
raw_url = first_image.get("url")
|
||||||
|
media_type = first_image.get("type")
|
||||||
|
rewritten_url, _ = rewrite_preview_url(raw_url, media_type)
|
||||||
|
thumbnail_url = rewritten_url
|
||||||
|
|
||||||
in_library = await scanner.check_model_version_exists(version_id_int)
|
in_library = await scanner.check_model_version_exists(version_id_int)
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import asyncio
|
|||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
from urllib.parse import urlparse
|
||||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||||
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
|
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
|
||||||
|
from ..utils.civitai_utils import rewrite_preview_url
|
||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
@@ -450,70 +452,103 @@ class DownloadManager:
|
|||||||
# Download preview image if available
|
# Download preview image if available
|
||||||
images = version_info.get('images', [])
|
images = version_info.get('images', [])
|
||||||
if images:
|
if images:
|
||||||
# Report preview download progress
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(1) # 1% progress for starting preview download
|
await progress_callback(1) # 1% progress for starting preview download
|
||||||
|
|
||||||
# Check if it's a video or an image
|
first_image = images[0] if isinstance(images[0], dict) else None
|
||||||
is_video = images[0].get('type') == 'video'
|
preview_url = first_image.get('url') if first_image else None
|
||||||
|
media_type = (first_image.get('type') or '').lower() if first_image else ''
|
||||||
if (is_video):
|
nsfw_level = first_image.get('nsfwLevel', 0) if first_image else 0
|
||||||
# For videos, use .mp4 extension
|
|
||||||
preview_ext = '.mp4'
|
def _extension_from_url(url: str, fallback: str) -> str:
|
||||||
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
# Download video directly using downloader
|
except ValueError:
|
||||||
downloader = await get_downloader()
|
return fallback
|
||||||
success, result = await downloader.download_file(
|
ext = os.path.splitext(parsed.path)[1]
|
||||||
images[0]['url'],
|
return ext or fallback
|
||||||
preview_path,
|
|
||||||
use_auth=False # Preview images typically don't need auth
|
preview_downloaded = False
|
||||||
)
|
preview_path = None
|
||||||
if success:
|
|
||||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
if preview_url:
|
||||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
downloader = await get_downloader()
|
||||||
else:
|
|
||||||
# For images, use WebP format for better performance
|
if media_type == 'video':
|
||||||
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
|
preview_ext = _extension_from_url(preview_url, '.mp4')
|
||||||
temp_path = temp_file.name
|
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
||||||
|
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='video')
|
||||||
# Download the original image to temp path using downloader
|
attempt_urls: List[str] = []
|
||||||
downloader = await get_downloader()
|
if rewritten:
|
||||||
success, content, headers = await downloader.download_to_memory(
|
attempt_urls.append(rewritten_url)
|
||||||
images[0]['url'],
|
attempt_urls.append(preview_url)
|
||||||
use_auth=False
|
|
||||||
)
|
seen_attempts = set()
|
||||||
if success:
|
for attempt in attempt_urls:
|
||||||
# Save to temp file
|
if not attempt or attempt in seen_attempts:
|
||||||
with open(temp_path, 'wb') as f:
|
continue
|
||||||
f.write(content)
|
seen_attempts.add(attempt)
|
||||||
# Optimize and convert to WebP
|
success, _ = await downloader.download_file(
|
||||||
preview_path = os.path.splitext(save_path)[0] + '.webp'
|
attempt,
|
||||||
|
preview_path,
|
||||||
# Use ExifUtils to optimize and convert the image
|
use_auth=False
|
||||||
optimized_data, _ = ExifUtils.optimize_image(
|
)
|
||||||
image_data=temp_path,
|
if success:
|
||||||
target_width=CARD_PREVIEW_WIDTH,
|
preview_downloaded = True
|
||||||
format='webp',
|
break
|
||||||
quality=85,
|
else:
|
||||||
preserve_metadata=False
|
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='image')
|
||||||
)
|
if rewritten:
|
||||||
|
preview_ext = _extension_from_url(preview_url, '.png')
|
||||||
# Save the optimized image
|
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
||||||
with open(preview_path, 'wb') as f:
|
success, _ = await downloader.download_file(
|
||||||
f.write(optimized_data)
|
rewritten_url,
|
||||||
|
preview_path,
|
||||||
# Update metadata
|
use_auth=False
|
||||||
metadata.preview_url = preview_path.replace(os.sep, '/')
|
)
|
||||||
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
if success:
|
||||||
|
preview_downloaded = True
|
||||||
# Remove temporary file
|
|
||||||
try:
|
if not preview_downloaded:
|
||||||
os.unlink(temp_path)
|
temp_path: str | None = None
|
||||||
except Exception as e:
|
try:
|
||||||
logger.warning(f"Failed to delete temp file: {e}")
|
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
|
||||||
|
temp_path = temp_file.name
|
||||||
|
|
||||||
|
success, content, _ = await downloader.download_to_memory(
|
||||||
|
preview_url,
|
||||||
|
use_auth=False
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
with open(temp_path, 'wb') as temp_file_handle:
|
||||||
|
temp_file_handle.write(content)
|
||||||
|
preview_path = os.path.splitext(save_path)[0] + '.webp'
|
||||||
|
|
||||||
|
optimized_data, _ = ExifUtils.optimize_image(
|
||||||
|
image_data=temp_path,
|
||||||
|
target_width=CARD_PREVIEW_WIDTH,
|
||||||
|
format='webp',
|
||||||
|
quality=85,
|
||||||
|
preserve_metadata=False
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(preview_path, 'wb') as preview_file:
|
||||||
|
preview_file.write(optimized_data)
|
||||||
|
|
||||||
|
preview_downloaded = True
|
||||||
|
finally:
|
||||||
|
if temp_path and os.path.exists(temp_path):
|
||||||
|
try:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to delete temp file: {e}")
|
||||||
|
|
||||||
|
if preview_downloaded and preview_path:
|
||||||
|
metadata.preview_url = preview_path.replace(os.sep, '/')
|
||||||
|
metadata.preview_nsfw_level = nsfw_level
|
||||||
|
if download_id and download_id in self._active_downloads:
|
||||||
|
self._active_downloads[download_id]['preview_path'] = preview_path
|
||||||
|
|
||||||
# Report preview download completion
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
await progress_callback(3) # 3% progress after preview download
|
await progress_callback(3) # 3% progress after preview download
|
||||||
|
|
||||||
@@ -677,7 +712,15 @@ class DownloadManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting metadata file: {e}")
|
logger.error(f"Error deleting metadata file: {e}")
|
||||||
|
|
||||||
# Delete preview file if exists (.webp or .mp4)
|
preview_path_value = download_info.get('preview_path')
|
||||||
|
if preview_path_value and os.path.exists(preview_path_value):
|
||||||
|
try:
|
||||||
|
os.unlink(preview_path_value)
|
||||||
|
logger.debug(f"Deleted preview file: {preview_path_value}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error deleting preview file: {e}")
|
||||||
|
|
||||||
|
# Delete preview file if exists (.webp or .mp4) for legacy paths
|
||||||
for preview_ext in ['.webp', '.mp4']:
|
for preview_ext in ['.webp', '.mp4']:
|
||||||
preview_path = os.path.splitext(file_path)[0] + preview_ext
|
preview_path = os.path.splitext(file_path)[0] + preview_ext
|
||||||
if os.path.exists(preview_path):
|
if os.path.exists(preview_path):
|
||||||
@@ -710,4 +753,4 @@ class DownloadManager:
|
|||||||
}
|
}
|
||||||
for task_id, info in self._active_downloads.items()
|
for task_id, info in self._active_downloads.items()
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,8 +5,10 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Awaitable, Callable, Dict, Optional, Sequence
|
from typing import Awaitable, Callable, Dict, Optional, Sequence
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
|
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
|
||||||
|
from ..utils.civitai_utils import rewrite_preview_url
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -45,23 +47,59 @@ class PreviewAssetService:
|
|||||||
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
|
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
|
||||||
preview_dir = os.path.dirname(metadata_path)
|
preview_dir = os.path.dirname(metadata_path)
|
||||||
is_video = first_preview.get("type") == "video"
|
is_video = first_preview.get("type") == "video"
|
||||||
|
preview_url = first_preview.get("url")
|
||||||
|
|
||||||
|
if not preview_url:
|
||||||
|
return
|
||||||
|
|
||||||
|
def extension_from_url(url: str, fallback: str) -> str:
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
except ValueError:
|
||||||
|
return fallback
|
||||||
|
ext = os.path.splitext(parsed.path)[1]
|
||||||
|
return ext or fallback
|
||||||
|
|
||||||
|
downloader = await self._downloader_factory()
|
||||||
|
|
||||||
if is_video:
|
if is_video:
|
||||||
extension = ".mp4"
|
extension = extension_from_url(preview_url, ".mp4")
|
||||||
preview_path = os.path.join(preview_dir, base_name + extension)
|
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||||
downloader = await self._downloader_factory()
|
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="video")
|
||||||
success, result = await downloader.download_file(
|
|
||||||
first_preview["url"], preview_path, use_auth=False
|
attempt_urls = []
|
||||||
)
|
if rewritten:
|
||||||
if success:
|
attempt_urls.append(rewritten_url)
|
||||||
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
attempt_urls.append(preview_url)
|
||||||
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
|
||||||
|
seen: set[str] = set()
|
||||||
|
for candidate in attempt_urls:
|
||||||
|
if not candidate or candidate in seen:
|
||||||
|
continue
|
||||||
|
seen.add(candidate)
|
||||||
|
|
||||||
|
success, _ = await downloader.download_file(candidate, preview_path, use_auth=False)
|
||||||
|
if success:
|
||||||
|
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||||
|
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
|
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type="image")
|
||||||
|
if rewritten:
|
||||||
|
extension = extension_from_url(preview_url, ".png")
|
||||||
|
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||||
|
success, _ = await downloader.download_file(
|
||||||
|
rewritten_url, preview_path, use_auth=False
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||||
|
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||||
|
return
|
||||||
|
|
||||||
extension = ".webp"
|
extension = ".webp"
|
||||||
preview_path = os.path.join(preview_dir, base_name + extension)
|
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||||
downloader = await self._downloader_factory()
|
|
||||||
success, content, _headers = await downloader.download_to_memory(
|
success, content, _headers = await downloader.download_to_memory(
|
||||||
first_preview["url"], use_auth=False
|
preview_url, use_auth=False
|
||||||
)
|
)
|
||||||
if not success:
|
if not success:
|
||||||
return
|
return
|
||||||
|
|||||||
48
py/utils/civitai_utils.py
Normal file
48
py/utils/civitai_utils.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Utilities for working with Civitai assets."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from urllib.parse import urlparse, urlunparse
|
||||||
|
|
||||||
|
|
||||||
|
def rewrite_preview_url(source_url: str | None, media_type: str | None = None) -> tuple[str | None, bool]:
|
||||||
|
"""Rewrite Civitai preview URLs to use optimized renditions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_url: Original preview URL from the Civitai API.
|
||||||
|
media_type: Optional media type hint (e.g. ``"image"`` or ``"video"``).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of the potentially rewritten URL and a flag indicating whether the
|
||||||
|
replacement occurred. When the URL is not rewritten, the original value is
|
||||||
|
returned with ``False``.
|
||||||
|
"""
|
||||||
|
if not source_url:
|
||||||
|
return source_url, False
|
||||||
|
|
||||||
|
try:
|
||||||
|
parsed = urlparse(source_url)
|
||||||
|
except ValueError:
|
||||||
|
return source_url, False
|
||||||
|
|
||||||
|
if parsed.netloc.lower() != "image.civitai.com":
|
||||||
|
return source_url, False
|
||||||
|
|
||||||
|
replacement = "/width=450,optimized=true"
|
||||||
|
if (media_type or "").lower() == "video":
|
||||||
|
replacement = "/transcode=true,width=450,optimized=true"
|
||||||
|
|
||||||
|
if "/original=true" not in parsed.path:
|
||||||
|
return source_url, False
|
||||||
|
|
||||||
|
updated_path = parsed.path.replace("/original=true", replacement, 1)
|
||||||
|
if updated_path == parsed.path:
|
||||||
|
return source_url, False
|
||||||
|
|
||||||
|
rewritten = urlunparse(parsed._replace(path=updated_path))
|
||||||
|
print(rewritten)
|
||||||
|
return rewritten, True
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["rewrite_preview_url"]
|
||||||
|
|
||||||
@@ -526,6 +526,64 @@ async def test_get_civitai_user_models_marks_library_versions():
|
|||||||
assert provider.received_usernames == ["pixel"]
|
assert provider.received_usernames == ["pixel"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_civitai_user_models_rewrites_civitai_previews():
|
||||||
|
image_url = "https://image.civitai.com/container/example/original=true/sample.jpeg"
|
||||||
|
video_url = "https://image.civitai.com/container/example/original=true/sample.mp4"
|
||||||
|
|
||||||
|
models = [
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"name": "Model A",
|
||||||
|
"type": "LORA",
|
||||||
|
"tags": ["style"],
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 100,
|
||||||
|
"name": "preview-image",
|
||||||
|
"baseModel": "Flux.1",
|
||||||
|
"images": [
|
||||||
|
{"url": image_url, "type": "image"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 101,
|
||||||
|
"name": "preview-video",
|
||||||
|
"baseModel": "Flux.1",
|
||||||
|
"images": [
|
||||||
|
{"url": video_url, "type": "video"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
provider = FakeUserModelsProvider(models)
|
||||||
|
|
||||||
|
async def provider_factory():
|
||||||
|
return provider
|
||||||
|
|
||||||
|
handler = ModelLibraryHandler(
|
||||||
|
ServiceRegistryAdapter(
|
||||||
|
get_lora_scanner=fake_scanner_factory,
|
||||||
|
get_checkpoint_scanner=fake_scanner_factory,
|
||||||
|
get_embedding_scanner=fake_scanner_factory,
|
||||||
|
),
|
||||||
|
metadata_provider_factory=provider_factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handler.get_civitai_user_models(FakeRequest(query={"username": "pixel"}))
|
||||||
|
payload = json.loads(response.text)
|
||||||
|
|
||||||
|
assert payload["success"] is True
|
||||||
|
previews_by_version = {item["versionId"]: item["thumbnailUrl"] for item in payload["versions"]}
|
||||||
|
assert previews_by_version[100] == "https://image.civitai.com/container/example/width=450,optimized=true/sample.jpeg"
|
||||||
|
assert (
|
||||||
|
previews_by_version[101]
|
||||||
|
== "https://image.civitai.com/container/example/transcode=true,width=450,optimized=true/sample.mp4"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_civitai_user_models_requires_username():
|
async def test_get_civitai_user_models_requires_username():
|
||||||
provider = FakeUserModelsProvider([])
|
provider = FakeUserModelsProvider([])
|
||||||
|
|||||||
@@ -394,3 +394,98 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
|||||||
assert result == {"success": True}
|
assert result == {"success": True}
|
||||||
assert [url for url, *_ in dummy_downloader.calls] == download_urls
|
assert [url for url, *_ in dummy_downloader.calls] == download_urls
|
||||||
assert dummy_scanner.calls # ensure cache updated
|
assert dummy_scanner.calls # ensure cache updated
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_path):
|
||||||
|
manager = DownloadManager()
|
||||||
|
save_dir = tmp_path / "downloads"
|
||||||
|
save_dir.mkdir()
|
||||||
|
target_path = save_dir / "file.safetensors"
|
||||||
|
|
||||||
|
manager._active_downloads["dl"] = {}
|
||||||
|
|
||||||
|
class DummyMetadata:
|
||||||
|
def __init__(self, path: Path):
|
||||||
|
self.file_path = str(path)
|
||||||
|
self.sha256 = "sha256"
|
||||||
|
self.file_name = path.stem
|
||||||
|
self.preview_url = None
|
||||||
|
self.preview_nsfw_level = None
|
||||||
|
|
||||||
|
def generate_unique_filename(self, *_args, **_kwargs):
|
||||||
|
return os.path.basename(self.file_path)
|
||||||
|
|
||||||
|
def update_file_info(self, _path):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def to_dict(self):
|
||||||
|
return {"file_path": self.file_path}
|
||||||
|
|
||||||
|
metadata = DummyMetadata(target_path)
|
||||||
|
version_info = {
|
||||||
|
"images": [
|
||||||
|
{
|
||||||
|
"url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
|
||||||
|
"type": "image",
|
||||||
|
"nsfwLevel": 2,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
download_urls = ["https://example.invalid/file.safetensors"]
|
||||||
|
|
||||||
|
class DummyDownloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.file_calls: list[tuple[str, str]] = []
|
||||||
|
self.memory_calls = 0
|
||||||
|
|
||||||
|
async def download_file(self, url, path, progress_callback=None, use_auth=None):
|
||||||
|
self.file_calls.append((url, path))
|
||||||
|
if url.endswith(".jpeg"):
|
||||||
|
Path(path).write_bytes(b"preview")
|
||||||
|
return True, None
|
||||||
|
if url.endswith(".safetensors"):
|
||||||
|
Path(path).write_bytes(b"model")
|
||||||
|
return True, None
|
||||||
|
return False, "unexpected url"
|
||||||
|
|
||||||
|
async def download_to_memory(self, *_args, **_kwargs):
|
||||||
|
self.memory_calls += 1
|
||||||
|
return False, b"", {}
|
||||||
|
|
||||||
|
dummy_downloader = DummyDownloader()
|
||||||
|
monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader))
|
||||||
|
|
||||||
|
optimize_called = {"value": False}
|
||||||
|
|
||||||
|
def fake_optimize_image(**_kwargs):
|
||||||
|
optimize_called["value"] = True
|
||||||
|
return b"", {}
|
||||||
|
|
||||||
|
monkeypatch.setattr(download_manager.ExifUtils, "optimize_image", staticmethod(fake_optimize_image))
|
||||||
|
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True))
|
||||||
|
|
||||||
|
dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
|
||||||
|
monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner))
|
||||||
|
|
||||||
|
result = await manager._execute_download(
|
||||||
|
download_urls=download_urls,
|
||||||
|
save_dir=str(save_dir),
|
||||||
|
metadata=metadata,
|
||||||
|
version_info=version_info,
|
||||||
|
relative_path="",
|
||||||
|
progress_callback=None,
|
||||||
|
model_type="lora",
|
||||||
|
download_id="dl",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == {"success": True}
|
||||||
|
preview_urls = [url for url, _ in dummy_downloader.file_calls if url.endswith(".jpeg")]
|
||||||
|
assert any("width=450,optimized=true" in url for url in preview_urls)
|
||||||
|
assert dummy_downloader.memory_calls == 0
|
||||||
|
assert optimize_called["value"] is False
|
||||||
|
assert metadata.preview_url.endswith(".jpeg")
|
||||||
|
assert metadata.preview_nsfw_level == 2
|
||||||
|
stored_preview = manager._active_downloads["dl"]["preview_path"]
|
||||||
|
assert stored_preview.endswith(".jpeg")
|
||||||
|
assert Path(stored_preview).exists()
|
||||||
|
|||||||
182
tests/services/test_preview_asset_service.py
Normal file
182
tests/services/test_preview_asset_service.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.preview_asset_service import PreviewAssetService
|
||||||
|
|
||||||
|
|
||||||
|
class StubMetadataManager:
|
||||||
|
async def save_metadata(self, *_args: Any, **_kwargs: Any) -> bool: # pragma: no cover - helper
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingExifUtils:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.called = False
|
||||||
|
|
||||||
|
def optimize_image(self, **kwargs):
|
||||||
|
self.called = True
|
||||||
|
return kwargs["image_data"], {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ensure_preview_prefers_rewritten_civitai_image(tmp_path):
|
||||||
|
metadata_path = tmp_path / "model.metadata.json"
|
||||||
|
metadata_path.write_text("{}")
|
||||||
|
local_metadata: dict[str, Any] = {}
|
||||||
|
|
||||||
|
class Downloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.file_calls: list[tuple[str, str]] = []
|
||||||
|
self.memory_calls = 0
|
||||||
|
|
||||||
|
async def download_file(self, url, path, use_auth=False):
|
||||||
|
self.file_calls.append((url, path))
|
||||||
|
if "width=450,optimized=true" in url:
|
||||||
|
Path(path).write_bytes(b"image-data")
|
||||||
|
return True, None
|
||||||
|
return False, "fail"
|
||||||
|
|
||||||
|
async def download_to_memory(self, *_args, **_kwargs):
|
||||||
|
self.memory_calls += 1
|
||||||
|
return False, b"", {}
|
||||||
|
|
||||||
|
downloader = Downloader()
|
||||||
|
|
||||||
|
async def downloader_factory():
|
||||||
|
return downloader
|
||||||
|
|
||||||
|
exif_utils = RecordingExifUtils()
|
||||||
|
service = PreviewAssetService(
|
||||||
|
metadata_manager=StubMetadataManager(),
|
||||||
|
downloader_factory=downloader_factory,
|
||||||
|
exif_utils=exif_utils,
|
||||||
|
)
|
||||||
|
|
||||||
|
images = [
|
||||||
|
{
|
||||||
|
"url": "https://image.civitai.com/container/example/original=true/sample.jpeg",
|
||||||
|
"type": "image",
|
||||||
|
"nsfwLevel": 3,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
|
||||||
|
|
||||||
|
assert downloader.memory_calls == 0
|
||||||
|
assert exif_utils.called is False
|
||||||
|
assert len(downloader.file_calls) == 1
|
||||||
|
assert "width=450,optimized=true" in downloader.file_calls[0][0]
|
||||||
|
preview_path = Path(local_metadata["preview_url"])
|
||||||
|
assert preview_path.exists()
|
||||||
|
assert preview_path.suffix == ".jpeg"
|
||||||
|
assert local_metadata["preview_nsfw_level"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ensure_preview_falls_back_to_webp_when_rewrite_fails(tmp_path):
|
||||||
|
metadata_path = tmp_path / "model.metadata.json"
|
||||||
|
metadata_path.write_text("{}")
|
||||||
|
local_metadata: dict[str, Any] = {}
|
||||||
|
|
||||||
|
class Downloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.file_calls: list[tuple[str, str]] = []
|
||||||
|
self.memory_calls = 0
|
||||||
|
|
||||||
|
async def download_file(self, url, path, use_auth=False):
|
||||||
|
self.file_calls.append((url, path))
|
||||||
|
return False, "fail"
|
||||||
|
|
||||||
|
async def download_to_memory(self, *_args, **_kwargs):
|
||||||
|
self.memory_calls += 1
|
||||||
|
return True, b"raw-image", {}
|
||||||
|
|
||||||
|
downloader = Downloader()
|
||||||
|
|
||||||
|
async def downloader_factory():
|
||||||
|
return downloader
|
||||||
|
|
||||||
|
class ExifUtils:
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = 0
|
||||||
|
|
||||||
|
def optimize_image(self, **kwargs):
|
||||||
|
self.calls += 1
|
||||||
|
return b"webp-data", {}
|
||||||
|
|
||||||
|
exif_utils = ExifUtils()
|
||||||
|
|
||||||
|
service = PreviewAssetService(
|
||||||
|
metadata_manager=StubMetadataManager(),
|
||||||
|
downloader_factory=downloader_factory,
|
||||||
|
exif_utils=exif_utils,
|
||||||
|
)
|
||||||
|
|
||||||
|
images = [
|
||||||
|
{
|
||||||
|
"url": "https://image.civitai.com/container/example/original=true/sample.png",
|
||||||
|
"type": "image",
|
||||||
|
"nsfwLevel": 1,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
|
||||||
|
|
||||||
|
assert downloader.memory_calls == 1
|
||||||
|
assert exif_utils.calls == 1
|
||||||
|
preview_path = Path(local_metadata["preview_url"])
|
||||||
|
assert preview_path.exists()
|
||||||
|
assert preview_path.suffix == ".webp"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ensure_preview_rewrites_civitai_video(tmp_path):
|
||||||
|
metadata_path = tmp_path / "model.metadata.json"
|
||||||
|
metadata_path.write_text("{}")
|
||||||
|
local_metadata: dict[str, Any] = {}
|
||||||
|
|
||||||
|
class Downloader:
|
||||||
|
def __init__(self):
|
||||||
|
self.file_calls: list[tuple[str, str]] = []
|
||||||
|
|
||||||
|
async def download_file(self, url, path, use_auth=False):
|
||||||
|
self.file_calls.append((url, path))
|
||||||
|
if "transcode=true,width=450,optimized=true" in url:
|
||||||
|
Path(path).write_bytes(b"video-data")
|
||||||
|
return True, None
|
||||||
|
if url.endswith(".mp4"):
|
||||||
|
return False, "fail"
|
||||||
|
return False, "unexpected"
|
||||||
|
|
||||||
|
async def download_to_memory(self, *_args, **_kwargs):
|
||||||
|
pytest.fail("download_to_memory should not be used for video previews")
|
||||||
|
|
||||||
|
downloader = Downloader()
|
||||||
|
|
||||||
|
async def downloader_factory():
|
||||||
|
return downloader
|
||||||
|
|
||||||
|
service = PreviewAssetService(
|
||||||
|
metadata_manager=StubMetadataManager(),
|
||||||
|
downloader_factory=downloader_factory,
|
||||||
|
exif_utils=RecordingExifUtils(),
|
||||||
|
)
|
||||||
|
|
||||||
|
images = [
|
||||||
|
{
|
||||||
|
"url": "https://image.civitai.com/container/example/original=true/sample.mp4",
|
||||||
|
"type": "video",
|
||||||
|
"nsfwLevel": 2,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
await service.ensure_preview_for_metadata(str(metadata_path), local_metadata, images)
|
||||||
|
|
||||||
|
assert len(downloader.file_calls) >= 1
|
||||||
|
assert any("transcode=true,width=450,optimized=true" in url for url, _ in downloader.file_calls)
|
||||||
|
preview_path = Path(local_metadata["preview_url"])
|
||||||
|
assert preview_path.exists()
|
||||||
|
assert preview_path.suffix == ".mp4"
|
||||||
|
assert local_metadata["preview_nsfw_level"] == 2
|
||||||
Reference in New Issue
Block a user