diff --git a/py/routes/base_recipe_routes.py b/py/routes/base_recipe_routes.py index f6b0b823..c598a6d2 100644 --- a/py/routes/base_recipe_routes.py +++ b/py/routes/base_recipe_routes.py @@ -191,6 +191,8 @@ class BaseRecipeRoutes: logger=logger, persistence_service=persistence_service, analysis_service=analysis_service, + downloader_factory=get_downloader, + civitai_client_getter=civitai_client_getter, ) analysis = RecipeAnalysisHandler( ensure_dependencies_ready=self.ensure_dependencies_ready, @@ -214,4 +216,3 @@ class BaseRecipeRoutes: analysis=analysis, sharing=sharing, ) - diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py index 5d76f885..b36ade5d 100644 --- a/py/routes/handlers/recipe_handlers.py +++ b/py/routes/handlers/recipe_handlers.py @@ -4,8 +4,10 @@ from __future__ import annotations import json import logging import os +import re +import tempfile from dataclasses import dataclass -from typing import Any, Awaitable, Callable, Dict, Mapping, Optional +from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional from aiohttp import web @@ -45,6 +47,7 @@ class RecipeHandlerSet: "render_page": self.page_view.render_page, "list_recipes": self.listing.list_recipes, "get_recipe": self.listing.get_recipe, + "import_remote_recipe": self.management.import_remote_recipe, "analyze_uploaded_image": self.analysis.analyze_uploaded_image, "analyze_local_image": self.analysis.analyze_local_image, "save_recipe": self.management.save_recipe, @@ -404,12 +407,16 @@ class RecipeManagementHandler: logger: Logger, persistence_service: RecipePersistenceService, analysis_service: RecipeAnalysisService, + downloader_factory, + civitai_client_getter: CivitaiClientGetter, ) -> None: self._ensure_dependencies_ready = ensure_dependencies_ready self._recipe_scanner_getter = recipe_scanner_getter self._logger = logger self._persistence_service = persistence_service self._analysis_service = analysis_service + self._downloader_factory = downloader_factory + self._civitai_client_getter = civitai_client_getter async def save_recipe(self, request: web.Request) -> web.Response: try: @@ -436,6 +443,62 @@ class RecipeManagementHandler: self._logger.error("Error saving recipe: %s", exc, exc_info=True) return web.json_response({"error": str(exc)}, status=500) + async def import_remote_recipe(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + params = request.rel_url.query + print(params) + image_url = params.get("image_url") + name = params.get("name") + resources_raw = params.get("resources") + if not image_url: + raise RecipeValidationError("Missing required field: image_url") + if not name: + raise RecipeValidationError("Missing required field: name") + if not resources_raw: + raise RecipeValidationError("Missing required field: resources") + + checkpoint_entry, lora_entries = self._parse_resources_payload(resources_raw) + gen_params = self._parse_gen_params(params.get("gen_params")) + metadata: Dict[str, Any] = { + "base_model": params.get("base_model", "") or "", + "loras": lora_entries, + } + source_path = params.get("source_path") + if source_path: + metadata["source_path"] = source_path + if gen_params is not None: + metadata["gen_params"] = gen_params + if checkpoint_entry: + metadata["checkpoint"] = checkpoint_entry + gen_params_ref = metadata.setdefault("gen_params", {}) + if "checkpoint" not in gen_params_ref: + gen_params_ref["checkpoint"] = checkpoint_entry + + tags = self._parse_tags(params.get("tags")) + image_bytes = await self._download_image_bytes(image_url) + + result = await self._persistence_service.save_recipe( + recipe_scanner=recipe_scanner, + image_bytes=image_bytes, + image_base64=None, + name=name, + tags=tags, + metadata=metadata, + ) + return web.json_response(result.payload, status=result.status) + except RecipeValidationError as exc: + return web.json_response({"error": str(exc)}, status=400) + except RecipeDownloadError as exc: + return web.json_response({"error": str(exc)}, status=400) + except Exception as exc: + self._logger.error("Error importing recipe from remote source: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + async def delete_recipe(self, request: web.Request) -> web.Response: try: await self._ensure_dependencies_ready() @@ -595,6 +658,117 @@ class RecipeManagementHandler: "metadata": metadata, } + def _parse_tags(self, tag_text: Optional[str]) -> list[str]: + if not tag_text: + return [] + return [tag.strip() for tag in tag_text.split(",") if tag.strip()] + + def _parse_gen_params(self, payload: Optional[str]) -> Optional[Dict[str, Any]]: + if payload is None: + return None + if payload == "": + return {} + try: + parsed = json.loads(payload) + except json.JSONDecodeError as exc: + raise RecipeValidationError(f"Invalid gen_params payload: {exc}") from exc + if parsed is None: + return {} + if not isinstance(parsed, dict): + raise RecipeValidationError("gen_params payload must be an object") + return parsed + + def _parse_resources_payload(self, payload_raw: str) -> tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]: + try: + payload = json.loads(payload_raw) + except json.JSONDecodeError as exc: + raise RecipeValidationError(f"Invalid resources payload: {exc}") from exc + + if not isinstance(payload, list): + raise RecipeValidationError("Resources payload must be a list") + + checkpoint_entry: Optional[Dict[str, Any]] = None + lora_entries: List[Dict[str, Any]] = [] + + for resource in payload: + if not isinstance(resource, dict): + continue + resource_type = str(resource.get("type") or "").lower() + if resource_type == "checkpoint": + checkpoint_entry = self._build_checkpoint_entry(resource) + elif resource_type in {"lora", "lycoris"}: + lora_entries.append(self._build_lora_entry(resource)) + + return checkpoint_entry, lora_entries + + def _build_checkpoint_entry(self, resource: Dict[str, Any]) -> Dict[str, Any]: + return { + "type": resource.get("type", "checkpoint"), + "modelId": self._safe_int(resource.get("modelId")), + "modelVersionId": self._safe_int(resource.get("modelVersionId")), + "modelName": resource.get("modelName", ""), + "modelVersionName": resource.get("modelVersionName", ""), + } + + def _build_lora_entry(self, resource: Dict[str, Any]) -> Dict[str, Any]: + weight_raw = resource.get("weight", 1.0) + try: + weight = float(weight_raw) + except (TypeError, ValueError): + weight = 1.0 + return { + "file_name": resource.get("modelName", ""), + "weight": weight, + "id": self._safe_int(resource.get("modelVersionId")), + "name": resource.get("modelName", ""), + "version": resource.get("modelVersionName", ""), + "isDeleted": False, + "exclude": False, + } + + async def _download_image_bytes(self, image_url: str) -> bytes: + civitai_client = self._civitai_client_getter() + downloader = await self._downloader_factory() + temp_path = None + try: + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + temp_path = temp_file.name + download_url = image_url + civitai_match = re.match(r"https://civitai\.com/images/(\d+)", image_url) + if civitai_match: + if civitai_client is None: + raise RecipeDownloadError("Civitai client unavailable for image download") + image_info = await civitai_client.get_image_info(civitai_match.group(1)) + if not image_info: + raise RecipeDownloadError("Failed to fetch image information from Civitai") + download_url = image_info.get("url") + if not download_url: + raise RecipeDownloadError("No image URL found in Civitai response") + + success, result = await downloader.download_file(download_url, temp_path, use_auth=False) + if not success: + raise RecipeDownloadError(f"Failed to download image: {result}") + with open(temp_path, "rb") as file_obj: + return file_obj.read() + except RecipeDownloadError: + raise + except RecipeValidationError: + raise + except Exception as exc: # pragma: no cover - defensive guard + raise RecipeValidationError(f"Unable to download image: {exc}") from exc + finally: + if temp_path: + try: + os.unlink(temp_path) + except FileNotFoundError: + pass + + def _safe_int(self, value: Any) -> int: + try: + return int(value) + except (TypeError, ValueError): + return 0 + class RecipeAnalysisHandler: """Analyze images to extract recipe metadata.""" diff --git a/py/routes/recipe_route_registrar.py b/py/routes/recipe_route_registrar.py index 471edf19..18bf4cba 100644 --- a/py/routes/recipe_route_registrar.py +++ b/py/routes/recipe_route_registrar.py @@ -20,6 +20,7 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("GET", "/loras/recipes", "render_page"), RouteDefinition("GET", "/api/lm/recipes", "list_recipes"), RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"), + RouteDefinition("GET", "/api/lm/recipes/import-remote", "import_remote_recipe"), RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"), RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"), RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"), @@ -61,4 +62,3 @@ class RecipeRouteRegistrar: add_method_name = self._METHOD_MAP[method.upper()] add_method = getattr(self._app.router, add_method_name) add_method(path, handler) - diff --git a/py/services/recipes/persistence_service.py b/py/services/recipes/persistence_service.py index 3e6db390..fbc628c5 100644 --- a/py/services/recipes/persistence_service.py +++ b/py/services/recipes/persistence_service.py @@ -78,9 +78,10 @@ class RecipePersistenceService: file_obj.write(optimized_image) current_time = time.time() - loras_data = [self._normalise_lora_entry(lora) for lora in metadata.get("loras", [])] + loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])] + checkpoint_entry = metadata.get("checkpoint") - gen_params = metadata.get("gen_params", {}) + gen_params = metadata.get("gen_params") or {} if not gen_params and "raw_metadata" in metadata: raw_metadata = metadata.get("raw_metadata", {}) gen_params = { @@ -94,6 +95,8 @@ class RecipePersistenceService: "size": raw_metadata.get("size", ""), "clip_skip": raw_metadata.get("clip_skip", ""), } + if checkpoint_entry and "checkpoint" not in gen_params: + gen_params["checkpoint"] = checkpoint_entry fingerprint = calculate_recipe_fingerprint(loras_data) recipe_data: Dict[str, Any] = { @@ -107,6 +110,8 @@ class RecipePersistenceService: "gen_params": gen_params, "fingerprint": fingerprint, } + if checkpoint_entry: + recipe_data["checkpoint"] = checkpoint_entry tags_list = list(tags) if tags_list: diff --git a/tests/routes/test_recipe_routes.py b/tests/routes/test_recipe_routes.py index 467cb5b5..93e3578d 100644 --- a/tests/routes/test_recipe_routes.py +++ b/tests/routes/test_recipe_routes.py @@ -27,6 +27,7 @@ class RecipeRouteHarness: analysis: "StubAnalysisService" persistence: "StubPersistenceService" sharing: "StubSharingService" + downloader: "StubDownloader" tmp_dir: Path @@ -175,6 +176,18 @@ class StubSharingService: return self.download_info +class StubDownloader: + """Downloader stub that writes deterministic bytes to requested locations.""" + + def __init__(self) -> None: + self.urls: List[str] = [] + + async def download_file(self, url: str, destination: str, use_auth: bool = False): # noqa: ARG002 - use_auth unused + self.urls.append(url) + Path(destination).write_bytes(b"imported-image") + return True, destination + + @asynccontextmanager async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRouteHarness]: """Context manager that yields a fully wired recipe route harness.""" @@ -191,11 +204,17 @@ async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRou async def fake_get_civitai_client(): return object() + downloader = StubDownloader() + + async def fake_get_downloader(): + return downloader + monkeypatch.setattr(ServiceRegistry, "get_recipe_scanner", fake_get_recipe_scanner) monkeypatch.setattr(ServiceRegistry, "get_civitai_client", fake_get_civitai_client) monkeypatch.setattr(base_recipe_routes, "RecipeAnalysisService", StubAnalysisService) monkeypatch.setattr(base_recipe_routes, "RecipePersistenceService", StubPersistenceService) monkeypatch.setattr(base_recipe_routes, "RecipeSharingService", StubSharingService) + monkeypatch.setattr(base_recipe_routes, "get_downloader", fake_get_downloader) monkeypatch.setattr(config, "loras_roots", [str(tmp_path)], raising=False) app = web.Application() @@ -211,6 +230,7 @@ async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRou analysis=StubAnalysisService.instances[-1], persistence=StubPersistenceService.instances[-1], sharing=StubSharingService.instances[-1], + downloader=downloader, tmp_dir=tmp_path, ) @@ -275,6 +295,54 @@ async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> assert harness.persistence.delete_calls == ["saved-id"] +async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + resources = [ + { + "type": "checkpoint", + "modelId": 10, + "modelVersionId": 33, + "modelName": "Flux", + "modelVersionName": "Dev", + }, + { + "type": "lora", + "modelId": 20, + "modelVersionId": 44, + "modelName": "Painterly", + "modelVersionName": "v2", + "weight": 0.25, + }, + ] + response = await harness.client.get( + "/api/lm/recipes/import-remote", + params={ + "image_url": "https://example.com/images/1", + "name": "Remote Recipe", + "resources": json.dumps(resources), + "tags": "foo,bar", + "base_model": "Flux", + "source_path": "https://example.com/images/1", + "gen_params": json.dumps({"prompt": "hello world", "cfg_scale": 7}), + }, + ) + + payload = await response.json() + assert response.status == 200 + assert payload["success"] is True + + call = harness.persistence.save_calls[-1] + assert call["name"] == "Remote Recipe" + assert call["tags"] == ["foo", "bar"] + metadata = call["metadata"] + assert metadata["base_model"] == "Flux" + assert metadata["checkpoint"]["modelVersionId"] == 33 + assert metadata["loras"][0]["weight"] == 0.25 + assert metadata["gen_params"]["prompt"] == "hello world" + assert metadata["gen_params"]["checkpoint"]["modelVersionId"] == 33 + assert harness.downloader.urls == ["https://example.com/images/1"] + + async def test_analyze_uploaded_image_error_path(monkeypatch, tmp_path: Path) -> None: async with recipe_harness(monkeypatch, tmp_path) as harness: harness.analysis.raise_for_uploaded = RecipeValidationError("No image data provided") @@ -327,4 +395,3 @@ async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None: assert body == b"stub" download_path.unlink(missing_ok=True) - diff --git a/tests/services/test_recipe_services.py b/tests/services/test_recipe_services.py index 034b20e2..fb6dd4e5 100644 --- a/tests/services/test_recipe_services.py +++ b/tests/services/test_recipe_services.py @@ -157,6 +157,55 @@ async def test_save_recipe_reports_duplicates(tmp_path): assert service._exif_utils.appended[0] == expected_image_path +@pytest.mark.asyncio +async def test_save_recipe_persists_checkpoint_metadata(tmp_path): + exif_utils = DummyExifUtils() + + class DummyScanner: + def __init__(self, root): + self.recipes_dir = str(root) + + async def find_recipes_by_fingerprint(self, fingerprint): + return [] + + async def add_recipe(self, recipe_data): + return None + + scanner = DummyScanner(tmp_path) + service = RecipePersistenceService( + exif_utils=exif_utils, + card_preview_width=512, + logger=logging.getLogger("test"), + ) + + checkpoint_meta = { + "type": "checkpoint", + "modelId": 10, + "modelVersionId": 20, + "modelName": "Flux", + "modelVersionName": "Dev", + } + + metadata = { + "base_model": "Flux", + "loras": [], + "checkpoint": checkpoint_meta, + } + + result = await service.save_recipe( + recipe_scanner=scanner, + image_bytes=b"img", + image_base64=None, + name="Checkpointed", + tags=[], + metadata=metadata, + ) + + stored = json.loads(Path(result.payload["json_path"]).read_text()) + assert stored["checkpoint"] == checkpoint_meta + assert stored["gen_params"]["checkpoint"] == checkpoint_meta + + @pytest.mark.asyncio async def test_save_recipe_from_widget_allows_empty_lora(tmp_path): exif_utils = DummyExifUtils()