fix(preview): resolve CORS error when setting CivitAI remote media as preview

- Add new endpoint POST /api/lm/{prefix}/set-preview-from-url to handle
  remote image downloads server-side, avoiding CORS issues
- Use rewrite_preview_url() to download optimized smaller images (450px width)
- Use Downloader service for reliable downloads with retry logic and proxy support
- Update frontend to call new endpoint instead of fetching images in browser

fixes #837
This commit is contained in:
Will Miao
2026-03-02 13:21:18 +08:00
parent 8b924b1551
commit bde11b153f
6 changed files with 445 additions and 72 deletions

View File

@@ -74,18 +74,14 @@ class ModelPageView:
os.path.dirname(os.path.dirname(os.path.dirname(current_file)))
)
supporters_path = os.path.join(root_dir, "data", "supporters.json")
if os.path.exists(supporters_path):
with open(supporters_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
self._logger.debug(f"Failed to load supporters data: {e}")
return {
"specialThanks": [],
"allSupporters": [],
"totalCount": 0
}
return {"specialThanks": [], "allSupporters": [], "totalCount": 0}
def _get_app_version(self) -> str:
version = "1.0.0"
@@ -404,20 +400,26 @@ class ModelManagementHandler:
return web.json_response(
{"success": False, "error": "Model not found in cache"}, status=404
)
# Check if hash needs to be calculated (lazy hash for checkpoints)
sha256 = model_data.get("sha256")
hash_status = model_data.get("hash_status", "completed")
if not sha256 or hash_status != "completed":
# For checkpoints, calculate hash on-demand
scanner = self._service.scanner
if hasattr(scanner, 'calculate_hash_for_model'):
self._logger.info(f"Lazy hash calculation triggered for {file_path}")
if hasattr(scanner, "calculate_hash_for_model"):
self._logger.info(
f"Lazy hash calculation triggered for {file_path}"
)
sha256 = await scanner.calculate_hash_for_model(file_path)
if not sha256:
return web.json_response(
{"success": False, "error": "Failed to calculate SHA256 hash"}, status=500
{
"success": False,
"error": "Failed to calculate SHA256 hash",
},
status=500,
)
# Update model_data with new hash
model_data["sha256"] = sha256
@@ -545,6 +547,153 @@ class ModelManagementHandler:
self._logger.error("Error replacing preview: %s", exc, exc_info=True)
return web.Response(text=str(exc), status=500)
async def set_preview_from_url(self, request: web.Request) -> web.Response:
"""Set a preview image from a remote URL (e.g., CivitAI)."""
try:
from ...utils.civitai_utils import rewrite_preview_url
from ...services.downloader import get_downloader
data = await request.json()
model_path = data.get("model_path")
image_url = data.get("image_url")
nsfw_level = data.get("nsfw_level", 0)
if not model_path:
return web.json_response(
{"success": False, "error": "Model path is required"}, status=400
)
if not image_url:
return web.json_response(
{"success": False, "error": "Image URL is required"}, status=400
)
# Rewrite URL to use optimized rendition if it's a Civitai URL
optimized_url, was_rewritten = rewrite_preview_url(
image_url, media_type="image"
)
if was_rewritten and optimized_url:
self._logger.info(
f"Rewritten preview URL to optimized version: {optimized_url}"
)
else:
optimized_url = image_url
# Download the image using the Downloader service
self._logger.info(
f"Downloading preview from {optimized_url} for {model_path}"
)
downloader = await get_downloader()
success, preview_data, headers = await downloader.download_to_memory(
optimized_url, use_auth=False, return_headers=True
)
if not success:
return web.json_response(
{
"success": False,
"error": f"Failed to download image: {preview_data}",
},
status=502,
)
# preview_data is bytes when success is True
preview_bytes = (
preview_data
if isinstance(preview_data, bytes)
else preview_data.encode("utf-8")
)
# Determine content type from response headers
content_type = (
headers.get("Content-Type", "image/jpeg") if headers else "image/jpeg"
)
# Extract original filename from URL
original_filename = None
if "?" in image_url:
url_path = image_url.split("?")[0]
else:
url_path = image_url
original_filename = url_path.split("/")[-1] if "/" in url_path else None
result = await self._preview_service.replace_preview(
model_path=model_path,
preview_data=preview_data,
content_type=content_type,
original_filename=original_filename,
nsfw_level=nsfw_level,
update_preview_in_cache=self._service.scanner.update_preview_in_cache,
metadata_loader=self._metadata_sync.load_local_metadata,
)
return web.json_response(
{
"success": True,
"preview_url": config.get_preview_static_url(
result["preview_path"]
),
"preview_nsfw_level": result["preview_nsfw_level"],
}
)
except Exception as exc:
self._logger.error("Error setting preview from URL: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
if not image_url:
return web.json_response(
{"success": False, "error": "Image URL is required"}, status=400
)
# Download the image from the remote URL
self._logger.info(f"Downloading preview from {image_url} for {model_path}")
async with aiohttp.ClientSession() as session:
async with session.get(image_url) as response:
if response.status != 200:
return web.json_response(
{
"success": False,
"error": f"Failed to download image: HTTP {response.status}",
},
status=502,
)
content_type = response.headers.get("Content-Type", "image/jpeg")
preview_data = await response.read()
# Extract original filename from URL
original_filename = None
if "?" in image_url:
url_path = image_url.split("?")[0]
else:
url_path = image_url
original_filename = (
url_path.split("/")[-1] if "/" in url_path else None
)
result = await self._preview_service.replace_preview(
model_path=model_path,
preview_data=preview_bytes,
content_type=content_type,
original_filename=original_filename,
nsfw_level=nsfw_level,
update_preview_in_cache=self._service.scanner.update_preview_in_cache,
metadata_loader=self._metadata_sync.load_local_metadata,
)
return web.json_response(
{
"success": True,
"preview_url": config.get_preview_static_url(
result["preview_path"]
),
"preview_nsfw_level": result["preview_nsfw_level"],
}
)
except Exception as exc:
self._logger.error("Error setting preview from URL: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def save_metadata(self, request: web.Request) -> web.Response:
try:
data = await request.json()
@@ -835,9 +984,7 @@ class ModelQueryHandler:
# Format response
group = {"hash": sha256, "models": []}
for model in sorted_models:
group["models"].append(
await self._service.format_response(model)
)
group["models"].append(await self._service.format_response(model))
# Only include groups with 2+ models after filtering
if len(group["models"]) > 1:
@@ -866,7 +1013,9 @@ class ModelQueryHandler:
"favorites_only": request.query.get("favorites_only", "").lower() == "true",
}
def _apply_duplicate_filters(self, models: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]:
def _apply_duplicate_filters(
self, models: List[Dict[str, Any]], filters: Dict[str, Any]
) -> List[Dict[str, Any]]:
"""Apply filters to a list of models within a duplicate group."""
result = models
@@ -907,7 +1056,9 @@ class ModelQueryHandler:
return result
def _sort_duplicate_group(self, models: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def _sort_duplicate_group(
self, models: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Sort models: originals first (left), copies (with -????. pattern) last (right)."""
if len(models) <= 1:
return models
@@ -1192,10 +1343,13 @@ class ModelDownloadHandler:
data["source"] = source
if file_params_json:
import json
try:
data["file_params"] = json.loads(file_params_json)
except json.JSONDecodeError:
self._logger.warning("Invalid file_params JSON: %s", file_params_json)
self._logger.warning(
"Invalid file_params JSON: %s", file_params_json
)
loop = asyncio.get_event_loop()
future = loop.create_future()
@@ -1926,7 +2080,8 @@ class ModelUpdateHandler:
from dataclasses import replace
new_record = replace(
record, versions=list(version_map.values()),
record,
versions=list(version_map.values()),
)
# Optionally persist to database for caching
@@ -2141,6 +2296,7 @@ class ModelUpdateHandler:
if version.early_access_ends_at:
try:
from datetime import datetime, timezone
ea_date = datetime.fromisoformat(
version.early_access_ends_at.replace("Z", "+00:00")
)
@@ -2148,7 +2304,7 @@ class ModelUpdateHandler:
except (ValueError, AttributeError):
# If date parsing fails, treat as active EA (conservative)
is_early_access = True
elif getattr(version, 'is_early_access', False):
elif getattr(version, "is_early_access", False):
# Fallback to basic EA flag from bulk API
is_early_access = True
@@ -2228,6 +2384,7 @@ class ModelHandlerSet:
"fetch_all_civitai": self.civitai.fetch_all_civitai,
"relink_civitai": self.management.relink_civitai,
"replace_preview": self.management.replace_preview,
"set_preview_from_url": self.management.set_preview_from_url,
"save_metadata": self.management.save_metadata,
"add_tags": self.management.add_tags,
"rename_model": self.management.rename_model,