From bde11b153f2e6c21a141cf3972f9eeff9d10af7b Mon Sep 17 00:00:00 2001 From: Will Miao Date: Mon, 2 Mar 2026 13:21:18 +0800 Subject: [PATCH] fix(preview): resolve CORS error when setting CivitAI remote media as preview - Add new endpoint POST /api/lm/{prefix}/set-preview-from-url to handle remote image downloads server-side, avoiding CORS issues - Use rewrite_preview_url() to download optimized smaller images (450px width) - Use Downloader service for reliable downloads with retry logic and proxy support - Update frontend to call new endpoint instead of fetching images in browser fixes #837 --- py/routes/handlers/model_handlers.py | 197 ++++++++++++++++-- py/routes/model_route_registrar.py | 74 +++++-- static/js/api/apiConfig.js | 1 + static/js/api/baseModelApi.js | 50 +++++ .../components/shared/showcase/MediaUtils.js | 19 +- tests/routes/test_base_model_routes_smoke.py | 176 +++++++++++++--- 6 files changed, 445 insertions(+), 72 deletions(-) diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 88ed56e0..af150b63 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -74,18 +74,14 @@ class ModelPageView: os.path.dirname(os.path.dirname(os.path.dirname(current_file))) ) supporters_path = os.path.join(root_dir, "data", "supporters.json") - + if os.path.exists(supporters_path): with open(supporters_path, "r", encoding="utf-8") as f: return json.load(f) except Exception as e: self._logger.debug(f"Failed to load supporters data: {e}") - - return { - "specialThanks": [], - "allSupporters": [], - "totalCount": 0 - } + + return {"specialThanks": [], "allSupporters": [], "totalCount": 0} def _get_app_version(self) -> str: version = "1.0.0" @@ -404,20 +400,26 @@ class ModelManagementHandler: return web.json_response( {"success": False, "error": "Model not found in cache"}, status=404 ) - + # Check if hash needs to be calculated (lazy hash for checkpoints) sha256 = model_data.get("sha256") hash_status = model_data.get("hash_status", "completed") - + if not sha256 or hash_status != "completed": # For checkpoints, calculate hash on-demand scanner = self._service.scanner - if hasattr(scanner, 'calculate_hash_for_model'): - self._logger.info(f"Lazy hash calculation triggered for {file_path}") + if hasattr(scanner, "calculate_hash_for_model"): + self._logger.info( + f"Lazy hash calculation triggered for {file_path}" + ) sha256 = await scanner.calculate_hash_for_model(file_path) if not sha256: return web.json_response( - {"success": False, "error": "Failed to calculate SHA256 hash"}, status=500 + { + "success": False, + "error": "Failed to calculate SHA256 hash", + }, + status=500, ) # Update model_data with new hash model_data["sha256"] = sha256 @@ -545,6 +547,153 @@ class ModelManagementHandler: self._logger.error("Error replacing preview: %s", exc, exc_info=True) return web.Response(text=str(exc), status=500) + async def set_preview_from_url(self, request: web.Request) -> web.Response: + """Set a preview image from a remote URL (e.g., CivitAI).""" + try: + from ...utils.civitai_utils import rewrite_preview_url + from ...services.downloader import get_downloader + + data = await request.json() + model_path = data.get("model_path") + image_url = data.get("image_url") + nsfw_level = data.get("nsfw_level", 0) + + if not model_path: + return web.json_response( + {"success": False, "error": "Model path is required"}, status=400 + ) + + if not image_url: + return web.json_response( + {"success": False, "error": "Image URL is required"}, status=400 + ) + + # Rewrite URL to use optimized rendition if it's a Civitai URL + optimized_url, was_rewritten = rewrite_preview_url( + image_url, media_type="image" + ) + if was_rewritten and optimized_url: + self._logger.info( + f"Rewritten preview URL to optimized version: {optimized_url}" + ) + else: + optimized_url = image_url + + # Download the image using the Downloader service + self._logger.info( + f"Downloading preview from {optimized_url} for {model_path}" + ) + downloader = await get_downloader() + success, preview_data, headers = await downloader.download_to_memory( + optimized_url, use_auth=False, return_headers=True + ) + + if not success: + return web.json_response( + { + "success": False, + "error": f"Failed to download image: {preview_data}", + }, + status=502, + ) + + # preview_data is bytes when success is True + preview_bytes = ( + preview_data + if isinstance(preview_data, bytes) + else preview_data.encode("utf-8") + ) + + # Determine content type from response headers + content_type = ( + headers.get("Content-Type", "image/jpeg") if headers else "image/jpeg" + ) + + # Extract original filename from URL + original_filename = None + if "?" in image_url: + url_path = image_url.split("?")[0] + else: + url_path = image_url + original_filename = url_path.split("/")[-1] if "/" in url_path else None + + result = await self._preview_service.replace_preview( + model_path=model_path, + preview_data=preview_data, + content_type=content_type, + original_filename=original_filename, + nsfw_level=nsfw_level, + update_preview_in_cache=self._service.scanner.update_preview_in_cache, + metadata_loader=self._metadata_sync.load_local_metadata, + ) + + return web.json_response( + { + "success": True, + "preview_url": config.get_preview_static_url( + result["preview_path"] + ), + "preview_nsfw_level": result["preview_nsfw_level"], + } + ) + except Exception as exc: + self._logger.error("Error setting preview from URL: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + if not image_url: + return web.json_response( + {"success": False, "error": "Image URL is required"}, status=400 + ) + + # Download the image from the remote URL + self._logger.info(f"Downloading preview from {image_url} for {model_path}") + async with aiohttp.ClientSession() as session: + async with session.get(image_url) as response: + if response.status != 200: + return web.json_response( + { + "success": False, + "error": f"Failed to download image: HTTP {response.status}", + }, + status=502, + ) + + content_type = response.headers.get("Content-Type", "image/jpeg") + preview_data = await response.read() + + # Extract original filename from URL + original_filename = None + if "?" in image_url: + url_path = image_url.split("?")[0] + else: + url_path = image_url + original_filename = ( + url_path.split("/")[-1] if "/" in url_path else None + ) + + result = await self._preview_service.replace_preview( + model_path=model_path, + preview_data=preview_bytes, + content_type=content_type, + original_filename=original_filename, + nsfw_level=nsfw_level, + update_preview_in_cache=self._service.scanner.update_preview_in_cache, + metadata_loader=self._metadata_sync.load_local_metadata, + ) + + return web.json_response( + { + "success": True, + "preview_url": config.get_preview_static_url( + result["preview_path"] + ), + "preview_nsfw_level": result["preview_nsfw_level"], + } + ) + except Exception as exc: + self._logger.error("Error setting preview from URL: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + async def save_metadata(self, request: web.Request) -> web.Response: try: data = await request.json() @@ -835,9 +984,7 @@ class ModelQueryHandler: # Format response group = {"hash": sha256, "models": []} for model in sorted_models: - group["models"].append( - await self._service.format_response(model) - ) + group["models"].append(await self._service.format_response(model)) # Only include groups with 2+ models after filtering if len(group["models"]) > 1: @@ -866,7 +1013,9 @@ class ModelQueryHandler: "favorites_only": request.query.get("favorites_only", "").lower() == "true", } - def _apply_duplicate_filters(self, models: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]: + def _apply_duplicate_filters( + self, models: List[Dict[str, Any]], filters: Dict[str, Any] + ) -> List[Dict[str, Any]]: """Apply filters to a list of models within a duplicate group.""" result = models @@ -907,7 +1056,9 @@ class ModelQueryHandler: return result - def _sort_duplicate_group(self, models: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _sort_duplicate_group( + self, models: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: """Sort models: originals first (left), copies (with -????. pattern) last (right).""" if len(models) <= 1: return models @@ -1192,10 +1343,13 @@ class ModelDownloadHandler: data["source"] = source if file_params_json: import json + try: data["file_params"] = json.loads(file_params_json) except json.JSONDecodeError: - self._logger.warning("Invalid file_params JSON: %s", file_params_json) + self._logger.warning( + "Invalid file_params JSON: %s", file_params_json + ) loop = asyncio.get_event_loop() future = loop.create_future() @@ -1926,7 +2080,8 @@ class ModelUpdateHandler: from dataclasses import replace new_record = replace( - record, versions=list(version_map.values()), + record, + versions=list(version_map.values()), ) # Optionally persist to database for caching @@ -2141,6 +2296,7 @@ class ModelUpdateHandler: if version.early_access_ends_at: try: from datetime import datetime, timezone + ea_date = datetime.fromisoformat( version.early_access_ends_at.replace("Z", "+00:00") ) @@ -2148,7 +2304,7 @@ class ModelUpdateHandler: except (ValueError, AttributeError): # If date parsing fails, treat as active EA (conservative) is_early_access = True - elif getattr(version, 'is_early_access', False): + elif getattr(version, "is_early_access", False): # Fallback to basic EA flag from bulk API is_early_access = True @@ -2228,6 +2384,7 @@ class ModelHandlerSet: "fetch_all_civitai": self.civitai.fetch_all_civitai, "relink_civitai": self.management.relink_civitai, "replace_preview": self.management.replace_preview, + "set_preview_from_url": self.management.set_preview_from_url, "save_metadata": self.management.save_metadata, "add_tags": self.management.add_tags, "rename_model": self.management.rename_model, diff --git a/py/routes/model_route_registrar.py b/py/routes/model_route_registrar.py index 9369db36..d3212cdc 100644 --- a/py/routes/model_route_registrar.py +++ b/py/routes/model_route_registrar.py @@ -1,4 +1,5 @@ """Route registrar for model endpoints.""" + from __future__ import annotations from dataclasses import dataclass @@ -27,6 +28,9 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("POST", "/api/lm/{prefix}/fetch-all-civitai", "fetch_all_civitai"), RouteDefinition("POST", "/api/lm/{prefix}/relink-civitai", "relink_civitai"), RouteDefinition("POST", "/api/lm/{prefix}/replace-preview", "replace_preview"), + RouteDefinition( + "POST", "/api/lm/{prefix}/set-preview-from-url", "set_preview_from_url" + ), RouteDefinition("POST", "/api/lm/{prefix}/save-metadata", "save_metadata"), RouteDefinition("POST", "/api/lm/{prefix}/add-tags", "add_tags"), RouteDefinition("POST", "/api/lm/{prefix}/rename", "rename_model"), @@ -36,7 +40,9 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("POST", "/api/lm/{prefix}/move_models_bulk", "move_models_bulk"), RouteDefinition("GET", "/api/lm/{prefix}/auto-organize", "auto_organize_models"), RouteDefinition("POST", "/api/lm/{prefix}/auto-organize", "auto_organize_models"), - RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"), + RouteDefinition( + "GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress" + ), RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"), RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"), RouteDefinition("GET", "/api/lm/{prefix}/model-types", "get_model_types"), @@ -44,30 +50,60 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"), RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"), RouteDefinition("GET", "/api/lm/{prefix}/folder-tree", "get_folder_tree"), - RouteDefinition("GET", "/api/lm/{prefix}/unified-folder-tree", "get_unified_folder_tree"), + RouteDefinition( + "GET", "/api/lm/{prefix}/unified-folder-tree", "get_unified_folder_tree" + ), RouteDefinition("GET", "/api/lm/{prefix}/find-duplicates", "find_duplicate_models"), - RouteDefinition("GET", "/api/lm/{prefix}/find-filename-conflicts", "find_filename_conflicts"), + RouteDefinition( + "GET", "/api/lm/{prefix}/find-filename-conflicts", "find_filename_conflicts" + ), RouteDefinition("GET", "/api/lm/{prefix}/get-notes", "get_model_notes"), RouteDefinition("GET", "/api/lm/{prefix}/preview-url", "get_model_preview_url"), RouteDefinition("GET", "/api/lm/{prefix}/civitai-url", "get_model_civitai_url"), RouteDefinition("GET", "/api/lm/{prefix}/metadata", "get_model_metadata"), - RouteDefinition("GET", "/api/lm/{prefix}/model-description", "get_model_description"), + RouteDefinition( + "GET", "/api/lm/{prefix}/model-description", "get_model_description" + ), RouteDefinition("GET", "/api/lm/{prefix}/relative-paths", "get_relative_paths"), - RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"), - RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"), - RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"), - RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"), - RouteDefinition("POST", "/api/lm/{prefix}/updates/fetch-missing-license", "fetch_missing_civitai_license_data"), - RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"), - RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore-version", "set_version_update_ignore"), - RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"), - RouteDefinition("GET", "/api/lm/{prefix}/updates/versions/{model_id}", "get_model_versions"), + RouteDefinition( + "GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions" + ), + RouteDefinition( + "GET", + "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", + "get_civitai_model_by_version", + ), + RouteDefinition( + "GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash" + ), + RouteDefinition( + "POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates" + ), + RouteDefinition( + "POST", + "/api/lm/{prefix}/updates/fetch-missing-license", + "fetch_missing_civitai_license_data", + ), + RouteDefinition( + "POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore" + ), + RouteDefinition( + "POST", "/api/lm/{prefix}/updates/ignore-version", "set_version_update_ignore" + ), + RouteDefinition( + "GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status" + ), + RouteDefinition( + "GET", "/api/lm/{prefix}/updates/versions/{model_id}", "get_model_versions" + ), RouteDefinition("POST", "/api/lm/download-model", "download_model"), RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"), RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"), RouteDefinition("GET", "/api/lm/pause-download", "pause_download_get"), RouteDefinition("GET", "/api/lm/resume-download", "resume_download_get"), - RouteDefinition("GET", "/api/lm/download-progress/{download_id}", "get_download_progress"), + RouteDefinition( + "GET", "/api/lm/download-progress/{download_id}", "get_download_progress" + ), RouteDefinition("POST", "/api/lm/{prefix}/cancel-task", "cancel_task"), RouteDefinition("GET", "/{prefix}", "handle_models_page"), ) @@ -94,12 +130,18 @@ class ModelRouteRegistrar: definitions: Iterable[RouteDefinition] = COMMON_ROUTE_DEFINITIONS, ) -> None: for definition in definitions: - self._bind_route(definition.method, definition.build_path(prefix), handler_lookup[definition.handler_name]) + self._bind_route( + definition.method, + definition.build_path(prefix), + handler_lookup[definition.handler_name], + ) def add_route(self, method: str, path: str, handler: Callable) -> None: self._bind_route(method, path, handler) - def add_prefixed_route(self, method: str, path_template: str, prefix: str, handler: Callable) -> None: + def add_prefixed_route( + self, method: str, path_template: str, prefix: str, handler: Callable + ) -> None: self._bind_route(method, path_template.replace("{prefix}", prefix), handler) def _bind_route(self, method: str, path: str, handler: Callable) -> None: diff --git a/static/js/api/apiConfig.js b/static/js/api/apiConfig.js index 06c5b121..4969a8ea 100644 --- a/static/js/api/apiConfig.js +++ b/static/js/api/apiConfig.js @@ -86,6 +86,7 @@ export function getApiEndpoints(modelType) { // Preview management replacePreview: `/api/lm/${modelType}/replace-preview`, + setPreviewFromUrl: `/api/lm/${modelType}/set-preview-from-url`, // Query operations scan: `/api/lm/${modelType}/scan`, diff --git a/static/js/api/baseModelApi.js b/static/js/api/baseModelApi.js index 4d144f6a..115c6d1e 100644 --- a/static/js/api/baseModelApi.js +++ b/static/js/api/baseModelApi.js @@ -307,6 +307,56 @@ export class BaseModelApiClient { } } + /** + * Set a preview from a remote URL (e.g., CivitAI) + * @param {string} filePath - Path to the model file + * @param {string} imageUrl - Remote image URL + * @param {number} nsfwLevel - NSFW level for the preview + */ + async setPreviewFromUrl(filePath, imageUrl, nsfwLevel = 0) { + try { + state.loadingManager.showSimpleLoading('Setting preview from URL...'); + + const response = await fetch(this.apiConfig.endpoints.setPreviewFromUrl, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + model_path: filePath, + image_url: imageUrl, + nsfw_level: nsfwLevel + }) + }); + + if (!response.ok) { + throw new Error('Failed to set preview from URL'); + } + + const data = await response.json(); + const pageState = this.getPageState(); + + const timestamp = Date.now(); + if (pageState.previewVersions) { + pageState.previewVersions.set(filePath, timestamp); + + const storageKey = `${this.modelType}_preview_versions`; + saveMapToStorage(storageKey, pageState.previewVersions); + } + + const updateData = { + preview_url: data.preview_url, + preview_nsfw_level: data.preview_nsfw_level + }; + + state.virtualScroller.updateSingleItem(filePath, updateData); + showToast('toast.api.previewUpdated', {}, 'success'); + } catch (error) { + console.error('Error setting preview from URL:', error); + showToast('toast.api.previewUploadFailed', {}, 'error'); + } finally { + state.loadingManager.hide(); + } + } + async saveModelMetadata(filePath, data) { try { state.loadingManager.showSimpleLoading('Saving metadata...'); diff --git a/static/js/components/shared/showcase/MediaUtils.js b/static/js/components/shared/showcase/MediaUtils.js index deae96e7..b1e36ffd 100644 --- a/static/js/components/shared/showcase/MediaUtils.js +++ b/static/js/components/shared/showcase/MediaUtils.js @@ -527,17 +527,18 @@ function initSetPreviewHandlers(container) { const response = await fetch(mediaElement.dataset.localSrc); const blob = await response.blob(); const file = new File([blob], 'preview.jpg', { type: blob.type }); - + // Use the existing baseModelApi uploadPreview method with nsfw level - await apiClient.uploadPreview(modelFilePath, file, modelType, nsfwLevel); + await apiClient.uploadPreview(modelFilePath, file, nsfwLevel); } else { - // We need to download the remote file first - const response = await fetch(mediaElement.src); - const blob = await response.blob(); - const file = new File([blob], 'preview.jpg', { type: blob.type }); - - // Use the existing baseModelApi uploadPreview method with nsfw level - await apiClient.uploadPreview(modelFilePath, file, modelType, nsfwLevel); + // Remote file - send URL to backend to download (avoids CORS issues) + const imageUrl = mediaElement.src || mediaElement.dataset.remoteSrc; + if (!imageUrl) { + throw new Error('No image URL available'); + } + + // Use the new setPreviewFromUrl method + await apiClient.setPreviewFromUrl(modelFilePath, imageUrl, nsfwLevel); } } catch (error) { console.error('Error setting preview:', error); diff --git a/tests/routes/test_base_model_routes_smoke.py b/tests/routes/test_base_model_routes_smoke.py index 835b4c46..542c8332 100644 --- a/tests/routes/test_base_model_routes_smoke.py +++ b/tests/routes/test_base_model_routes_smoke.py @@ -31,7 +31,9 @@ from py.utils.metadata_manager import MetadataManager class DummyRoutes(BaseModelRoutes): template_name = "dummy.html" - def setup_specific_routes(self, registrar, prefix: str) -> None: # pragma: no cover - no extra routes in smoke tests + def setup_specific_routes( + self, registrar, prefix: str + ) -> None: # pragma: no cover - no extra routes in smoke tests return None def __init__(self, service=None): @@ -59,7 +61,9 @@ class NullUpdateRecord: @property def in_library_version_ids(self) -> list[int]: - return [version.version_id for version in self.versions if version.is_in_library] + return [ + version.version_id for version in self.versions if version.is_in_library + ] def has_update(self) -> bool: return False @@ -86,7 +90,9 @@ class NullModelUpdateService: ) for version_id in version_ids ] - return NullUpdateRecord(model_type=model_type, model_id=model_id, versions=versions) + return NullUpdateRecord( + model_type=model_type, model_id=model_id, versions=versions + ) async def set_should_ignore(self, model_type, model_id, should_ignore): return NullUpdateRecord( @@ -95,7 +101,9 @@ class NullModelUpdateService: should_ignore_model=should_ignore, ) - async def set_version_should_ignore(self, model_type, model_id, version_id, should_ignore): + async def set_version_should_ignore( + self, model_type, model_id, version_id, should_ignore + ): return await self.set_should_ignore(model_type, model_id, should_ignore) async def get_record(self, *args, **kwargs): @@ -167,7 +175,9 @@ def download_manager_stub(): def test_list_models_returns_formatted_items(mock_service, mock_scanner): - mock_service.paginated_items = [{"file_path": "/tmp/demo.safetensors", "name": "Demo"}] + mock_service.paginated_items = [ + {"file_path": "/tmp/demo.safetensors", "name": "Demo"} + ] async def scenario(): client = await create_test_client(mock_service) @@ -176,7 +186,13 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner): payload = await response.json() assert response.status == 200 - assert payload["items"] == [{"file_path": "/tmp/demo.safetensors", "name": "Demo", "formatted": True}] + assert payload["items"] == [ + { + "file_path": "/tmp/demo.safetensors", + "name": "Demo", + "formatted": True, + } + ] assert payload["total"] == 1 assert mock_service.formatted == payload["items"] finally: @@ -220,7 +236,9 @@ def test_routes_return_service_not_ready_when_unattached(): asyncio.run(scenario()) -def test_delete_model_updates_cache_and_hash_index(mock_service, mock_scanner, tmp_path: Path): +def test_delete_model_updates_cache_and_hash_index( + mock_service, mock_scanner, tmp_path: Path +): model_path = tmp_path / "sample.safetensors" model_path.write_bytes(b"model") mock_scanner._cache.raw_data = [{"file_path": str(model_path)}] @@ -271,17 +289,23 @@ def test_replace_preview_writes_file_and_updates_cache( ) form = FormData() - form.add_field("preview_file", b"binary-data", filename="preview.png", content_type="image/png") + form.add_field( + "preview_file", b"binary-data", filename="preview.png", content_type="image/png" + ) form.add_field("model_path", str(model_path)) form.add_field("nsfw_level", "2") async def scenario(): client = await create_test_client(mock_service) try: - response = await client.post("/api/lm/test-models/replace-preview", data=form) + response = await client.post( + "/api/lm/test-models/replace-preview", data=form + ) payload = await response.json() - expected_preview = str((tmp_path / "preview-model.webp")).replace(os.sep, "/") + expected_preview = str((tmp_path / "preview-model.webp")).replace( + os.sep, "/" + ) assert response.status == 200 assert payload["success"] is True assert payload["preview_url"] == "/static/preview-model.webp" @@ -299,6 +323,66 @@ def test_replace_preview_writes_file_and_updates_cache( asyncio.run(scenario()) +def test_set_preview_from_url_downloads_and_updates_cache( + mock_service, + mock_scanner, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +): + """Test that set_preview_from_url endpoint downloads remote images and sets them as preview.""" + model_path = tmp_path / "url-preview-model.safetensors" + model_path.write_bytes(b"model") + metadata_path = tmp_path / "url-preview-model.metadata.json" + metadata_path.write_text(json.dumps({"file_path": str(model_path)})) + + mock_scanner._cache.raw_data = [{"file_path": str(model_path)}] + + monkeypatch.setattr( + config, + "get_preview_static_url", + lambda preview_path: f"/static/{Path(preview_path).name}", + ) + + async def scenario(): + client = await create_test_client(mock_service) + try: + # Mock the Downloader to return a test image + from py.services import downloader + + class FakeDownloader: + async def download_to_memory( + self, url, use_auth=False, return_headers=True + ): + return True, b"fake-image-data", {"Content-Type": "image/jpeg"} + + async def fake_get_downloader(): + return FakeDownloader() + + monkeypatch.setattr(downloader, "get_downloader", fake_get_downloader) + + response = await client.post( + "/api/lm/test-models/set-preview-from-url", + json={ + "model_path": str(model_path), + "image_url": "https://example.com/image.jpg", + "nsfw_level": 3, + }, + ) + payload = await response.json() + + expected_preview = str((tmp_path / "url-preview-model.webp")).replace( + os.sep, "/" + ) + assert response.status == 200 + assert payload["success"] is True + assert payload["preview_url"] == "/static/url-preview-model.webp" + assert Path(expected_preview).exists() + finally: + await client.close() + + asyncio.run(scenario()) + + def test_fetch_civitai_hydrates_metadata_before_sync( mock_service, mock_scanner, @@ -370,9 +454,15 @@ def test_fetch_civitai_hydrates_metadata_before_sync( save_calls: list[tuple[str, dict]] = [] captured: dict[str, dict] = {} - monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load_metadata)) - monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save_metadata)) - monkeypatch.setattr(MetadataSyncService, "fetch_and_update_model", fake_fetch_and_update_model) + monkeypatch.setattr( + MetadataManager, "load_metadata", staticmethod(fake_load_metadata) + ) + monkeypatch.setattr( + MetadataManager, "save_metadata", staticmethod(fake_save_metadata) + ) + monkeypatch.setattr( + MetadataSyncService, "fetch_and_update_model", fake_fetch_and_update_model + ) async def scenario(): client = await create_test_client(mock_service) @@ -386,7 +476,10 @@ def test_fetch_civitai_hydrates_metadata_before_sync( assert response.status == 200 assert payload["success"] is True assert captured["model_data"]["custom_field"] == "preserve" - assert captured["model_data"]["civitai"]["images"][0]["url"] == "https://example.com/existing.png" + assert ( + captured["model_data"]["civitai"]["images"][0]["url"] + == "https://example.com/existing.png" + ) assert captured["model_data"]["civitai"]["trainedWords"] == ["keep"] assert captured["model_data"]["civitai"]["id"] == 99 finally: @@ -398,7 +491,10 @@ def test_fetch_civitai_hydrates_metadata_before_sync( saved_path, saved_payload = save_calls[0] assert saved_path == str(metadata_path) assert saved_payload["custom_field"] == "preserve" - assert saved_payload["civitai"]["images"][0]["url"] == "https://example.com/existing.png" + assert ( + saved_payload["civitai"]["images"][0]["url"] + == "https://example.com/existing.png" + ) assert saved_payload["civitai"]["trainedWords"] == ["keep"] assert saved_payload["civitai"]["id"] == 99 assert saved_payload["legacy_field"] == "legacy" @@ -432,11 +528,22 @@ def test_download_model_invokes_download_manager( assert call_args["download_id"] == payload["download_id"] progress = ws_manager.get_download_progress(payload["download_id"]) assert progress is not None - expected_progress = round(download_manager_stub.last_progress_snapshot.percent_complete) + expected_progress = round( + download_manager_stub.last_progress_snapshot.percent_complete + ) assert progress["progress"] == expected_progress - assert progress["bytes_downloaded"] == download_manager_stub.last_progress_snapshot.bytes_downloaded - assert progress["total_bytes"] == download_manager_stub.last_progress_snapshot.total_bytes - assert progress["bytes_per_second"] == download_manager_stub.last_progress_snapshot.bytes_per_second + assert ( + progress["bytes_downloaded"] + == download_manager_stub.last_progress_snapshot.bytes_downloaded + ) + assert ( + progress["total_bytes"] + == download_manager_stub.last_progress_snapshot.total_bytes + ) + assert ( + progress["bytes_per_second"] + == download_manager_stub.last_progress_snapshot.bytes_per_second + ) assert "timestamp" in progress progress_response = await client.get( @@ -526,21 +633,30 @@ def test_auto_organize_progress_returns_latest_snapshot(mock_service): async def scenario(): client = await create_test_client(mock_service) try: - await ws_manager.broadcast_auto_organize_progress({"status": "processing", "percent": 50}) + await ws_manager.broadcast_auto_organize_progress( + {"status": "processing", "percent": 50} + ) response = await client.get("/api/lm/test-models/auto-organize-progress") payload = await response.json() assert response.status == 200 - assert payload == {"success": True, "progress": {"status": "processing", "percent": 50}} + assert payload == { + "success": True, + "progress": {"status": "processing", "percent": 50}, + } finally: await client.close() - + asyncio.run(scenario()) -def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.MonkeyPatch): - async def fake_auto_organize(self, file_paths=None, progress_callback=None, exclusion_patterns=None): +def test_auto_organize_route_emits_progress( + mock_service, monkeypatch: pytest.MonkeyPatch +): + async def fake_auto_organize( + self, file_paths=None, progress_callback=None, exclusion_patterns=None + ): result = AutoOrganizeResult() result.total = 1 result.processed = 1 @@ -549,8 +665,12 @@ def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.Mo result.failure_count = 0 result.operation_type = "bulk" if progress_callback is not None: - await progress_callback.on_progress({"type": "auto_organize_progress", "status": "started"}) - await progress_callback.on_progress({"type": "auto_organize_progress", "status": "completed"}) + await progress_callback.on_progress( + {"type": "auto_organize_progress", "status": "started"} + ) + await progress_callback.on_progress( + {"type": "auto_organize_progress", "status": "completed"} + ) return result monkeypatch.setattr( @@ -562,7 +682,9 @@ def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.Mo async def scenario(): client = await create_test_client(mock_service) try: - response = await client.post("/api/lm/test-models/auto-organize", json={"file_paths": []}) + response = await client.post( + "/api/lm/test-models/auto-organize", json={"file_paths": []} + ) payload = await response.json() assert response.status == 200