fix(preview): resolve CORS error when setting CivitAI remote media as preview

- Add new endpoint POST /api/lm/{prefix}/set-preview-from-url to handle
  remote image downloads server-side, avoiding CORS issues
- Use rewrite_preview_url() to download optimized smaller images (450px width)
- Use Downloader service for reliable downloads with retry logic and proxy support
- Update frontend to call new endpoint instead of fetching images in browser

fixes #837
This commit is contained in:
Will Miao
2026-03-02 13:21:18 +08:00
parent 8b924b1551
commit bde11b153f
6 changed files with 445 additions and 72 deletions

View File

@@ -31,7 +31,9 @@ from py.utils.metadata_manager import MetadataManager
class DummyRoutes(BaseModelRoutes):
template_name = "dummy.html"
def setup_specific_routes(self, registrar, prefix: str) -> None: # pragma: no cover - no extra routes in smoke tests
def setup_specific_routes(
self, registrar, prefix: str
) -> None: # pragma: no cover - no extra routes in smoke tests
return None
def __init__(self, service=None):
@@ -59,7 +61,9 @@ class NullUpdateRecord:
@property
def in_library_version_ids(self) -> list[int]:
return [version.version_id for version in self.versions if version.is_in_library]
return [
version.version_id for version in self.versions if version.is_in_library
]
def has_update(self) -> bool:
return False
@@ -86,7 +90,9 @@ class NullModelUpdateService:
)
for version_id in version_ids
]
return NullUpdateRecord(model_type=model_type, model_id=model_id, versions=versions)
return NullUpdateRecord(
model_type=model_type, model_id=model_id, versions=versions
)
async def set_should_ignore(self, model_type, model_id, should_ignore):
return NullUpdateRecord(
@@ -95,7 +101,9 @@ class NullModelUpdateService:
should_ignore_model=should_ignore,
)
async def set_version_should_ignore(self, model_type, model_id, version_id, should_ignore):
async def set_version_should_ignore(
self, model_type, model_id, version_id, should_ignore
):
return await self.set_should_ignore(model_type, model_id, should_ignore)
async def get_record(self, *args, **kwargs):
@@ -167,7 +175,9 @@ def download_manager_stub():
def test_list_models_returns_formatted_items(mock_service, mock_scanner):
mock_service.paginated_items = [{"file_path": "/tmp/demo.safetensors", "name": "Demo"}]
mock_service.paginated_items = [
{"file_path": "/tmp/demo.safetensors", "name": "Demo"}
]
async def scenario():
client = await create_test_client(mock_service)
@@ -176,7 +186,13 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner):
payload = await response.json()
assert response.status == 200
assert payload["items"] == [{"file_path": "/tmp/demo.safetensors", "name": "Demo", "formatted": True}]
assert payload["items"] == [
{
"file_path": "/tmp/demo.safetensors",
"name": "Demo",
"formatted": True,
}
]
assert payload["total"] == 1
assert mock_service.formatted == payload["items"]
finally:
@@ -220,7 +236,9 @@ def test_routes_return_service_not_ready_when_unattached():
asyncio.run(scenario())
def test_delete_model_updates_cache_and_hash_index(mock_service, mock_scanner, tmp_path: Path):
def test_delete_model_updates_cache_and_hash_index(
mock_service, mock_scanner, tmp_path: Path
):
model_path = tmp_path / "sample.safetensors"
model_path.write_bytes(b"model")
mock_scanner._cache.raw_data = [{"file_path": str(model_path)}]
@@ -271,17 +289,23 @@ def test_replace_preview_writes_file_and_updates_cache(
)
form = FormData()
form.add_field("preview_file", b"binary-data", filename="preview.png", content_type="image/png")
form.add_field(
"preview_file", b"binary-data", filename="preview.png", content_type="image/png"
)
form.add_field("model_path", str(model_path))
form.add_field("nsfw_level", "2")
async def scenario():
client = await create_test_client(mock_service)
try:
response = await client.post("/api/lm/test-models/replace-preview", data=form)
response = await client.post(
"/api/lm/test-models/replace-preview", data=form
)
payload = await response.json()
expected_preview = str((tmp_path / "preview-model.webp")).replace(os.sep, "/")
expected_preview = str((tmp_path / "preview-model.webp")).replace(
os.sep, "/"
)
assert response.status == 200
assert payload["success"] is True
assert payload["preview_url"] == "/static/preview-model.webp"
@@ -299,6 +323,66 @@ def test_replace_preview_writes_file_and_updates_cache(
asyncio.run(scenario())
# NOTE(review): diff-rendered addition — the viewer stripped indentation; tokens kept verbatim.
def test_set_preview_from_url_downloads_and_updates_cache(
mock_service,
mock_scanner,
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
):
"""Test that set_preview_from_url endpoint downloads remote images and sets them as preview."""
# Create a dummy model file plus a matching .metadata.json so the route can resolve it.
model_path = tmp_path / "url-preview-model.safetensors"
model_path.write_bytes(b"model")
metadata_path = tmp_path / "url-preview-model.metadata.json"
metadata_path.write_text(json.dumps({"file_path": str(model_path)}))
# Seed the scanner cache so the endpoint treats the model as known.
mock_scanner._cache.raw_data = [{"file_path": str(model_path)}]
# Stub the static-URL mapper so the expected preview_url is deterministic.
monkeypatch.setattr(
config,
"get_preview_static_url",
lambda preview_path: f"/static/{Path(preview_path).name}",
)
async def scenario():
client = await create_test_client(mock_service)
try:
# Mock the Downloader to return a test image
from py.services import downloader
# Fake downloader hands back canned JPEG bytes — no real network I/O,
# which is the point of the new server-side download endpoint (issue #837).
class FakeDownloader:
async def download_to_memory(
self, url, use_auth=False, return_headers=True
):
return True, b"fake-image-data", {"Content-Type": "image/jpeg"}
async def fake_get_downloader():
return FakeDownloader()
monkeypatch.setattr(downloader, "get_downloader", fake_get_downloader)
# Exercise the new POST /set-preview-from-url route with a remote image URL.
response = await client.post(
"/api/lm/test-models/set-preview-from-url",
json={
"model_path": str(model_path),
"image_url": "https://example.com/image.jpg",
"nsfw_level": 3,
},
)
payload = await response.json()
# Preview is expected next to the model with a .webp extension
# (presumably converted server-side — TODO confirm in route impl).
expected_preview = str((tmp_path / "url-preview-model.webp")).replace(
os.sep, "/"
)
assert response.status == 200
assert payload["success"] is True
assert payload["preview_url"] == "/static/url-preview-model.webp"
# The downloaded preview must actually be written to disk.
assert Path(expected_preview).exists()
finally:
await client.close()
asyncio.run(scenario())
def test_fetch_civitai_hydrates_metadata_before_sync(
mock_service,
mock_scanner,
@@ -370,9 +454,15 @@ def test_fetch_civitai_hydrates_metadata_before_sync(
save_calls: list[tuple[str, dict]] = []
captured: dict[str, dict] = {}
monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load_metadata))
monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save_metadata))
monkeypatch.setattr(MetadataSyncService, "fetch_and_update_model", fake_fetch_and_update_model)
monkeypatch.setattr(
MetadataManager, "load_metadata", staticmethod(fake_load_metadata)
)
monkeypatch.setattr(
MetadataManager, "save_metadata", staticmethod(fake_save_metadata)
)
monkeypatch.setattr(
MetadataSyncService, "fetch_and_update_model", fake_fetch_and_update_model
)
async def scenario():
client = await create_test_client(mock_service)
@@ -386,7 +476,10 @@ def test_fetch_civitai_hydrates_metadata_before_sync(
assert response.status == 200
assert payload["success"] is True
assert captured["model_data"]["custom_field"] == "preserve"
assert captured["model_data"]["civitai"]["images"][0]["url"] == "https://example.com/existing.png"
assert (
captured["model_data"]["civitai"]["images"][0]["url"]
== "https://example.com/existing.png"
)
assert captured["model_data"]["civitai"]["trainedWords"] == ["keep"]
assert captured["model_data"]["civitai"]["id"] == 99
finally:
@@ -398,7 +491,10 @@ def test_fetch_civitai_hydrates_metadata_before_sync(
saved_path, saved_payload = save_calls[0]
assert saved_path == str(metadata_path)
assert saved_payload["custom_field"] == "preserve"
assert saved_payload["civitai"]["images"][0]["url"] == "https://example.com/existing.png"
assert (
saved_payload["civitai"]["images"][0]["url"]
== "https://example.com/existing.png"
)
assert saved_payload["civitai"]["trainedWords"] == ["keep"]
assert saved_payload["civitai"]["id"] == 99
assert saved_payload["legacy_field"] == "legacy"
@@ -432,11 +528,22 @@ def test_download_model_invokes_download_manager(
assert call_args["download_id"] == payload["download_id"]
progress = ws_manager.get_download_progress(payload["download_id"])
assert progress is not None
expected_progress = round(download_manager_stub.last_progress_snapshot.percent_complete)
expected_progress = round(
download_manager_stub.last_progress_snapshot.percent_complete
)
assert progress["progress"] == expected_progress
assert progress["bytes_downloaded"] == download_manager_stub.last_progress_snapshot.bytes_downloaded
assert progress["total_bytes"] == download_manager_stub.last_progress_snapshot.total_bytes
assert progress["bytes_per_second"] == download_manager_stub.last_progress_snapshot.bytes_per_second
assert (
progress["bytes_downloaded"]
== download_manager_stub.last_progress_snapshot.bytes_downloaded
)
assert (
progress["total_bytes"]
== download_manager_stub.last_progress_snapshot.total_bytes
)
assert (
progress["bytes_per_second"]
== download_manager_stub.last_progress_snapshot.bytes_per_second
)
assert "timestamp" in progress
progress_response = await client.get(
@@ -526,21 +633,30 @@ def test_auto_organize_progress_returns_latest_snapshot(mock_service):
async def scenario():
client = await create_test_client(mock_service)
try:
await ws_manager.broadcast_auto_organize_progress({"status": "processing", "percent": 50})
await ws_manager.broadcast_auto_organize_progress(
{"status": "processing", "percent": 50}
)
response = await client.get("/api/lm/test-models/auto-organize-progress")
payload = await response.json()
assert response.status == 200
assert payload == {"success": True, "progress": {"status": "processing", "percent": 50}}
assert payload == {
"success": True,
"progress": {"status": "processing", "percent": 50},
}
finally:
await client.close()
asyncio.run(scenario())
def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.MonkeyPatch):
async def fake_auto_organize(self, file_paths=None, progress_callback=None, exclusion_patterns=None):
def test_auto_organize_route_emits_progress(
mock_service, monkeypatch: pytest.MonkeyPatch
):
async def fake_auto_organize(
self, file_paths=None, progress_callback=None, exclusion_patterns=None
):
result = AutoOrganizeResult()
result.total = 1
result.processed = 1
@@ -549,8 +665,12 @@ def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.Mo
result.failure_count = 0
result.operation_type = "bulk"
if progress_callback is not None:
await progress_callback.on_progress({"type": "auto_organize_progress", "status": "started"})
await progress_callback.on_progress({"type": "auto_organize_progress", "status": "completed"})
await progress_callback.on_progress(
{"type": "auto_organize_progress", "status": "started"}
)
await progress_callback.on_progress(
{"type": "auto_organize_progress", "status": "completed"}
)
return result
monkeypatch.setattr(
@@ -562,7 +682,9 @@ def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.Mo
async def scenario():
client = await create_test_client(mock_service)
try:
response = await client.post("/api/lm/test-models/auto-organize", json={"file_paths": []})
response = await client.post(
"/api/lm/test-models/auto-organize", json={"file_paths": []}
)
payload = await response.json()
assert response.status == 200