diff --git a/py/services/recipes/sharing_service.py b/py/services/recipes/sharing_service.py index a1af4f37..2afe190c 100644 --- a/py/services/recipes/sharing_service.py +++ b/py/services/recipes/sharing_service.py @@ -2,9 +2,11 @@ from __future__ import annotations import os +import re import shutil import tempfile import time +import unicodedata from dataclasses import dataclass from typing import Any, Dict @@ -59,8 +61,9 @@ class RecipeSharingService: } self._cleanup_shared_recipes() - safe_title = recipe.get("title", "").replace(" ", "_").lower() - filename = f"recipe_{safe_title}{ext}" if safe_title else f"recipe_{recipe_id}{ext}" + filename = self._build_download_filename( + title=recipe.get("title", ""), recipe_id=recipe_id, ext=ext + ) url_path = f"/api/lm/recipe/{recipe_id}/share/download?t={timestamp}" return SharingResult({"success": True, "download_url": url_path, "filename": filename}) @@ -78,13 +81,38 @@ class RecipeSharingService: raise RecipeNotFoundError("Shared recipe file not found") recipe = await recipe_scanner.get_recipe_by_id(recipe_id) - filename_base = ( - f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}" if recipe else recipe_id - ) ext = os.path.splitext(file_path)[1] - download_filename = f"{filename_base}{ext}" + download_filename = self._build_download_filename( + title=recipe.get("title", "") if recipe else "", + recipe_id=recipe_id, + ext=ext, + ) return DownloadInfo(file_path=file_path, download_filename=download_filename) + @staticmethod + def _build_download_filename(*, title: str, recipe_id: str, ext: str) -> str: + """Generate a sanitized filename safe for HTTP headers and filesystems.""" + + ext = ext or "" + safe_title = RecipeSharingService._slugify(title) + fallback = RecipeSharingService._slugify(recipe_id) + identifier = safe_title or fallback or "recipe" + return f"recipe_{identifier}{ext}" + + @staticmethod + def _slugify(value: str) -> str: + """Convert arbitrary input into a lowercase, header-safe slug.""" + + if not value: + return "" + + normalized = unicodedata.normalize("NFKD", value) + ascii_value = normalized.encode("ascii", "ignore").decode("ascii") + ascii_value = ascii_value.replace("\n", " ").replace("\r", " ") + sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", ascii_value) + sanitized = re.sub(r"_+", "_", sanitized).strip("._-") + return sanitized.lower() + def _cleanup_shared_recipes(self) -> None: for recipe_id in list(self._shared_recipes.keys()): shared = self._shared_recipes.get(recipe_id) diff --git a/tests/services/test_recipe_sharing_service.py b/tests/services/test_recipe_sharing_service.py new file mode 100644 index 00000000..d70ff84b --- /dev/null +++ b/tests/services/test_recipe_sharing_service.py @@ -0,0 +1,79 @@ +import logging + +import pytest + +from py.services.recipes.errors import RecipeNotFoundError +from py.services.recipes.sharing_service import RecipeSharingService + + +class DummyScanner: + def __init__(self, recipe_by_id): + self._recipes = recipe_by_id + + async def get_recipe_by_id(self, recipe_id): + return self._recipes.get(recipe_id) + + +@pytest.mark.asyncio +async def test_share_recipe_sanitizes_filename(tmp_path): + image_path = tmp_path / "original.png" + image_path.write_bytes(b"data") + + recipe_id = "unsafe:id" + recipe = { + "file_path": str(image_path), + "title": "Bad\rTitle\n../", + } + scanner = DummyScanner({recipe_id: recipe}) + + service = RecipeSharingService(ttl_seconds=30, logger=logging.getLogger("test")) + + result = await service.share_recipe(recipe_scanner=scanner, recipe_id=recipe_id) + assert result.payload["filename"] == "recipe_bad_title.png" + + download_info = await service.prepare_download(recipe_scanner=scanner, recipe_id=recipe_id) + assert download_info.download_filename == "recipe_bad_title.png" + + service._cleanup_entry(recipe_id) + + +@pytest.mark.asyncio +async def test_share_recipe_falls_back_to_recipe_id(tmp_path): + image_path = tmp_path / "original.png" + image_path.write_bytes(b"data") + + recipe_id = "ID 123" + recipe = { + "file_path": str(image_path), + "title": "\n\t", + } + scanner = DummyScanner({recipe_id: recipe}) + + service = RecipeSharingService(ttl_seconds=30, logger=logging.getLogger("test")) + + result = await service.share_recipe(recipe_scanner=scanner, recipe_id=recipe_id) + assert result.payload["filename"] == "recipe_id_123.png" + + service._cleanup_entry(recipe_id) + + +@pytest.mark.asyncio +async def test_prepare_download_rejects_expired(tmp_path): + service = RecipeSharingService(ttl_seconds=30, logger=logging.getLogger("test")) + + image_path = tmp_path / "original.png" + image_path.write_bytes(b"data") + recipe_id = "recipe" + + recipe = {"file_path": str(image_path), "title": "sample"} + scanner = DummyScanner({recipe_id: recipe}) + + await service.share_recipe(recipe_scanner=scanner, recipe_id=recipe_id) + + # Force the entry to expire + service._shared_recipes[recipe_id]["expires"] = 0 + + with pytest.raises(RecipeNotFoundError): + await service.prepare_download(recipe_scanner=scanner, recipe_id=recipe_id) + + service._cleanup_entry(recipe_id)