mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
fix(recipes): sanitize shared recipe filenames
This commit is contained in:
@@ -2,9 +2,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
|
import unicodedata
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
||||||
@@ -59,8 +61,9 @@ class RecipeSharingService:
|
|||||||
}
|
}
|
||||||
self._cleanup_shared_recipes()
|
self._cleanup_shared_recipes()
|
||||||
|
|
||||||
safe_title = recipe.get("title", "").replace(" ", "_").lower()
|
filename = self._build_download_filename(
|
||||||
filename = f"recipe_{safe_title}{ext}" if safe_title else f"recipe_{recipe_id}{ext}"
|
title=recipe.get("title", ""), recipe_id=recipe_id, ext=ext
|
||||||
|
)
|
||||||
url_path = f"/api/lm/recipe/{recipe_id}/share/download?t={timestamp}"
|
url_path = f"/api/lm/recipe/{recipe_id}/share/download?t={timestamp}"
|
||||||
return SharingResult({"success": True, "download_url": url_path, "filename": filename})
|
return SharingResult({"success": True, "download_url": url_path, "filename": filename})
|
||||||
|
|
||||||
@@ -78,13 +81,38 @@ class RecipeSharingService:
|
|||||||
raise RecipeNotFoundError("Shared recipe file not found")
|
raise RecipeNotFoundError("Shared recipe file not found")
|
||||||
|
|
||||||
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||||
filename_base = (
|
|
||||||
f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}" if recipe else recipe_id
|
|
||||||
)
|
|
||||||
ext = os.path.splitext(file_path)[1]
|
ext = os.path.splitext(file_path)[1]
|
||||||
download_filename = f"{filename_base}{ext}"
|
download_filename = self._build_download_filename(
|
||||||
|
title=recipe.get("title", "") if recipe else "",
|
||||||
|
recipe_id=recipe_id,
|
||||||
|
ext=ext,
|
||||||
|
)
|
||||||
return DownloadInfo(file_path=file_path, download_filename=download_filename)
|
return DownloadInfo(file_path=file_path, download_filename=download_filename)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_download_filename(*, title: str, recipe_id: str, ext: str) -> str:
|
||||||
|
"""Generate a sanitized filename safe for HTTP headers and filesystems."""
|
||||||
|
|
||||||
|
ext = ext or ""
|
||||||
|
safe_title = RecipeSharingService._slugify(title)
|
||||||
|
fallback = RecipeSharingService._slugify(recipe_id)
|
||||||
|
identifier = safe_title or fallback or "recipe"
|
||||||
|
return f"recipe_{identifier}{ext}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _slugify(value: str) -> str:
|
||||||
|
"""Convert arbitrary input into a lowercase, header-safe slug."""
|
||||||
|
|
||||||
|
if not value:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
normalized = unicodedata.normalize("NFKD", value)
|
||||||
|
ascii_value = normalized.encode("ascii", "ignore").decode("ascii")
|
||||||
|
ascii_value = ascii_value.replace("\n", " ").replace("\r", " ")
|
||||||
|
sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", ascii_value)
|
||||||
|
sanitized = re.sub(r"_+", "_", sanitized).strip("._-")
|
||||||
|
return sanitized.lower()
|
||||||
|
|
||||||
def _cleanup_shared_recipes(self) -> None:
|
def _cleanup_shared_recipes(self) -> None:
|
||||||
for recipe_id in list(self._shared_recipes.keys()):
|
for recipe_id in list(self._shared_recipes.keys()):
|
||||||
shared = self._shared_recipes.get(recipe_id)
|
shared = self._shared_recipes.get(recipe_id)
|
||||||
|
|||||||
79
tests/services/test_recipe_sharing_service.py
Normal file
79
tests/services/test_recipe_sharing_service.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.recipes.errors import RecipeNotFoundError
|
||||||
|
from py.services.recipes.sharing_service import RecipeSharingService
|
||||||
|
|
||||||
|
|
||||||
|
class DummyScanner:
|
||||||
|
def __init__(self, recipe_by_id):
|
||||||
|
self._recipes = recipe_by_id
|
||||||
|
|
||||||
|
async def get_recipe_by_id(self, recipe_id):
|
||||||
|
return self._recipes.get(recipe_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_share_recipe_sanitizes_filename(tmp_path):
|
||||||
|
image_path = tmp_path / "original.png"
|
||||||
|
image_path.write_bytes(b"data")
|
||||||
|
|
||||||
|
recipe_id = "unsafe:id"
|
||||||
|
recipe = {
|
||||||
|
"file_path": str(image_path),
|
||||||
|
"title": "Bad\rTitle\n../",
|
||||||
|
}
|
||||||
|
scanner = DummyScanner({recipe_id: recipe})
|
||||||
|
|
||||||
|
service = RecipeSharingService(ttl_seconds=30, logger=logging.getLogger("test"))
|
||||||
|
|
||||||
|
result = await service.share_recipe(recipe_scanner=scanner, recipe_id=recipe_id)
|
||||||
|
assert result.payload["filename"] == "recipe_bad_title.png"
|
||||||
|
|
||||||
|
download_info = await service.prepare_download(recipe_scanner=scanner, recipe_id=recipe_id)
|
||||||
|
assert download_info.download_filename == "recipe_bad_title.png"
|
||||||
|
|
||||||
|
service._cleanup_entry(recipe_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_share_recipe_falls_back_to_recipe_id(tmp_path):
|
||||||
|
image_path = tmp_path / "original.png"
|
||||||
|
image_path.write_bytes(b"data")
|
||||||
|
|
||||||
|
recipe_id = "ID 123"
|
||||||
|
recipe = {
|
||||||
|
"file_path": str(image_path),
|
||||||
|
"title": "\n\t",
|
||||||
|
}
|
||||||
|
scanner = DummyScanner({recipe_id: recipe})
|
||||||
|
|
||||||
|
service = RecipeSharingService(ttl_seconds=30, logger=logging.getLogger("test"))
|
||||||
|
|
||||||
|
result = await service.share_recipe(recipe_scanner=scanner, recipe_id=recipe_id)
|
||||||
|
assert result.payload["filename"] == "recipe_id_123.png"
|
||||||
|
|
||||||
|
service._cleanup_entry(recipe_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prepare_download_rejects_expired(tmp_path):
|
||||||
|
service = RecipeSharingService(ttl_seconds=30, logger=logging.getLogger("test"))
|
||||||
|
|
||||||
|
image_path = tmp_path / "original.png"
|
||||||
|
image_path.write_bytes(b"data")
|
||||||
|
recipe_id = "recipe"
|
||||||
|
|
||||||
|
recipe = {"file_path": str(image_path), "title": "sample"}
|
||||||
|
scanner = DummyScanner({recipe_id: recipe})
|
||||||
|
|
||||||
|
await service.share_recipe(recipe_scanner=scanner, recipe_id=recipe_id)
|
||||||
|
|
||||||
|
# Force the entry to expire
|
||||||
|
service._shared_recipes[recipe_id]["expires"] = 0
|
||||||
|
|
||||||
|
with pytest.raises(RecipeNotFoundError):
|
||||||
|
await service.prepare_download(recipe_scanner=scanner, recipe_id=recipe_id)
|
||||||
|
|
||||||
|
service._cleanup_entry(recipe_id)
|
||||||
Reference in New Issue
Block a user