feat: add remote recipe import functionality

Add support for importing recipes from remote sources by:
- Adding import_remote_recipe endpoint to RecipeHandlerSet
- Injecting downloader_factory and civitai_client_getter dependencies
- Implementing image download and resource parsing logic
- Supporting Civitai resource payloads with checkpoints and LoRAs
- Adding required imports for regex and temporary file handling

This enables users to import recipes directly from external sources like Civitai without manual file downloads.
This commit is contained in:
Will Miao
2025-11-21 11:12:58 +08:00
parent d540b21aac
commit 7173a2b9d6
6 changed files with 302 additions and 6 deletions

View File

@@ -191,6 +191,8 @@ class BaseRecipeRoutes:
logger=logger,
persistence_service=persistence_service,
analysis_service=analysis_service,
downloader_factory=get_downloader,
civitai_client_getter=civitai_client_getter,
)
analysis = RecipeAnalysisHandler(
ensure_dependencies_ready=self.ensure_dependencies_ready,
@@ -214,4 +216,3 @@ class BaseRecipeRoutes:
analysis=analysis,
sharing=sharing,
)

View File

@@ -4,8 +4,10 @@ from __future__ import annotations
import json
import logging
import os
import re
import tempfile
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional
from aiohttp import web
@@ -45,6 +47,7 @@ class RecipeHandlerSet:
"render_page": self.page_view.render_page,
"list_recipes": self.listing.list_recipes,
"get_recipe": self.listing.get_recipe,
"import_remote_recipe": self.management.import_remote_recipe,
"analyze_uploaded_image": self.analysis.analyze_uploaded_image,
"analyze_local_image": self.analysis.analyze_local_image,
"save_recipe": self.management.save_recipe,
@@ -404,12 +407,16 @@ class RecipeManagementHandler:
logger: Logger,
persistence_service: RecipePersistenceService,
analysis_service: RecipeAnalysisService,
downloader_factory,
civitai_client_getter: CivitaiClientGetter,
) -> None:
self._ensure_dependencies_ready = ensure_dependencies_ready
self._recipe_scanner_getter = recipe_scanner_getter
self._logger = logger
self._persistence_service = persistence_service
self._analysis_service = analysis_service
self._downloader_factory = downloader_factory
self._civitai_client_getter = civitai_client_getter
async def save_recipe(self, request: web.Request) -> web.Response:
try:
@@ -436,6 +443,62 @@ class RecipeManagementHandler:
self._logger.error("Error saving recipe: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
async def import_remote_recipe(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
params = request.rel_url.query
print(params)
image_url = params.get("image_url")
name = params.get("name")
resources_raw = params.get("resources")
if not image_url:
raise RecipeValidationError("Missing required field: image_url")
if not name:
raise RecipeValidationError("Missing required field: name")
if not resources_raw:
raise RecipeValidationError("Missing required field: resources")
checkpoint_entry, lora_entries = self._parse_resources_payload(resources_raw)
gen_params = self._parse_gen_params(params.get("gen_params"))
metadata: Dict[str, Any] = {
"base_model": params.get("base_model", "") or "",
"loras": lora_entries,
}
source_path = params.get("source_path")
if source_path:
metadata["source_path"] = source_path
if gen_params is not None:
metadata["gen_params"] = gen_params
if checkpoint_entry:
metadata["checkpoint"] = checkpoint_entry
gen_params_ref = metadata.setdefault("gen_params", {})
if "checkpoint" not in gen_params_ref:
gen_params_ref["checkpoint"] = checkpoint_entry
tags = self._parse_tags(params.get("tags"))
image_bytes = await self._download_image_bytes(image_url)
result = await self._persistence_service.save_recipe(
recipe_scanner=recipe_scanner,
image_bytes=image_bytes,
image_base64=None,
name=name,
tags=tags,
metadata=metadata,
)
return web.json_response(result.payload, status=result.status)
except RecipeValidationError as exc:
return web.json_response({"error": str(exc)}, status=400)
except RecipeDownloadError as exc:
return web.json_response({"error": str(exc)}, status=400)
except Exception as exc:
self._logger.error("Error importing recipe from remote source: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
async def delete_recipe(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
@@ -595,6 +658,117 @@ class RecipeManagementHandler:
"metadata": metadata,
}
def _parse_tags(self, tag_text: Optional[str]) -> list[str]:
if not tag_text:
return []
return [tag.strip() for tag in tag_text.split(",") if tag.strip()]
def _parse_gen_params(self, payload: Optional[str]) -> Optional[Dict[str, Any]]:
if payload is None:
return None
if payload == "":
return {}
try:
parsed = json.loads(payload)
except json.JSONDecodeError as exc:
raise RecipeValidationError(f"Invalid gen_params payload: {exc}") from exc
if parsed is None:
return {}
if not isinstance(parsed, dict):
raise RecipeValidationError("gen_params payload must be an object")
return parsed
def _parse_resources_payload(self, payload_raw: str) -> tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]:
try:
payload = json.loads(payload_raw)
except json.JSONDecodeError as exc:
raise RecipeValidationError(f"Invalid resources payload: {exc}") from exc
if not isinstance(payload, list):
raise RecipeValidationError("Resources payload must be a list")
checkpoint_entry: Optional[Dict[str, Any]] = None
lora_entries: List[Dict[str, Any]] = []
for resource in payload:
if not isinstance(resource, dict):
continue
resource_type = str(resource.get("type") or "").lower()
if resource_type == "checkpoint":
checkpoint_entry = self._build_checkpoint_entry(resource)
elif resource_type in {"lora", "lycoris"}:
lora_entries.append(self._build_lora_entry(resource))
return checkpoint_entry, lora_entries
def _build_checkpoint_entry(self, resource: Dict[str, Any]) -> Dict[str, Any]:
return {
"type": resource.get("type", "checkpoint"),
"modelId": self._safe_int(resource.get("modelId")),
"modelVersionId": self._safe_int(resource.get("modelVersionId")),
"modelName": resource.get("modelName", ""),
"modelVersionName": resource.get("modelVersionName", ""),
}
def _build_lora_entry(self, resource: Dict[str, Any]) -> Dict[str, Any]:
weight_raw = resource.get("weight", 1.0)
try:
weight = float(weight_raw)
except (TypeError, ValueError):
weight = 1.0
return {
"file_name": resource.get("modelName", ""),
"weight": weight,
"id": self._safe_int(resource.get("modelVersionId")),
"name": resource.get("modelName", ""),
"version": resource.get("modelVersionName", ""),
"isDeleted": False,
"exclude": False,
}
async def _download_image_bytes(self, image_url: str) -> bytes:
civitai_client = self._civitai_client_getter()
downloader = await self._downloader_factory()
temp_path = None
try:
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
temp_path = temp_file.name
download_url = image_url
civitai_match = re.match(r"https://civitai\.com/images/(\d+)", image_url)
if civitai_match:
if civitai_client is None:
raise RecipeDownloadError("Civitai client unavailable for image download")
image_info = await civitai_client.get_image_info(civitai_match.group(1))
if not image_info:
raise RecipeDownloadError("Failed to fetch image information from Civitai")
download_url = image_info.get("url")
if not download_url:
raise RecipeDownloadError("No image URL found in Civitai response")
success, result = await downloader.download_file(download_url, temp_path, use_auth=False)
if not success:
raise RecipeDownloadError(f"Failed to download image: {result}")
with open(temp_path, "rb") as file_obj:
return file_obj.read()
except RecipeDownloadError:
raise
except RecipeValidationError:
raise
except Exception as exc: # pragma: no cover - defensive guard
raise RecipeValidationError(f"Unable to download image: {exc}") from exc
finally:
if temp_path:
try:
os.unlink(temp_path)
except FileNotFoundError:
pass
def _safe_int(self, value: Any) -> int:
try:
return int(value)
except (TypeError, ValueError):
return 0
class RecipeAnalysisHandler:
"""Analyze images to extract recipe metadata."""

View File

@@ -20,6 +20,7 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/loras/recipes", "render_page"),
RouteDefinition("GET", "/api/lm/recipes", "list_recipes"),
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"),
RouteDefinition("GET", "/api/lm/recipes/import-remote", "import_remote_recipe"),
RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"),
RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"),
RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"),
@@ -61,4 +62,3 @@ class RecipeRouteRegistrar:
add_method_name = self._METHOD_MAP[method.upper()]
add_method = getattr(self._app.router, add_method_name)
add_method(path, handler)

View File

@@ -78,9 +78,10 @@ class RecipePersistenceService:
file_obj.write(optimized_image)
current_time = time.time()
loras_data = [self._normalise_lora_entry(lora) for lora in metadata.get("loras", [])]
loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])]
checkpoint_entry = metadata.get("checkpoint")
gen_params = metadata.get("gen_params", {})
gen_params = metadata.get("gen_params") or {}
if not gen_params and "raw_metadata" in metadata:
raw_metadata = metadata.get("raw_metadata", {})
gen_params = {
@@ -94,6 +95,8 @@ class RecipePersistenceService:
"size": raw_metadata.get("size", ""),
"clip_skip": raw_metadata.get("clip_skip", ""),
}
if checkpoint_entry and "checkpoint" not in gen_params:
gen_params["checkpoint"] = checkpoint_entry
fingerprint = calculate_recipe_fingerprint(loras_data)
recipe_data: Dict[str, Any] = {
@@ -107,6 +110,8 @@ class RecipePersistenceService:
"gen_params": gen_params,
"fingerprint": fingerprint,
}
if checkpoint_entry:
recipe_data["checkpoint"] = checkpoint_entry
tags_list = list(tags)
if tags_list:

View File

@@ -27,6 +27,7 @@ class RecipeRouteHarness:
analysis: "StubAnalysisService"
persistence: "StubPersistenceService"
sharing: "StubSharingService"
downloader: "StubDownloader"
tmp_dir: Path
@@ -175,6 +176,18 @@ class StubSharingService:
return self.download_info
class StubDownloader:
"""Downloader stub that writes deterministic bytes to requested locations."""
def __init__(self) -> None:
self.urls: List[str] = []
async def download_file(self, url: str, destination: str, use_auth: bool = False): # noqa: ARG002 - use_auth unused
self.urls.append(url)
Path(destination).write_bytes(b"imported-image")
return True, destination
@asynccontextmanager
async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRouteHarness]:
"""Context manager that yields a fully wired recipe route harness."""
@@ -191,11 +204,17 @@ async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRou
async def fake_get_civitai_client():
return object()
downloader = StubDownloader()
async def fake_get_downloader():
return downloader
monkeypatch.setattr(ServiceRegistry, "get_recipe_scanner", fake_get_recipe_scanner)
monkeypatch.setattr(ServiceRegistry, "get_civitai_client", fake_get_civitai_client)
monkeypatch.setattr(base_recipe_routes, "RecipeAnalysisService", StubAnalysisService)
monkeypatch.setattr(base_recipe_routes, "RecipePersistenceService", StubPersistenceService)
monkeypatch.setattr(base_recipe_routes, "RecipeSharingService", StubSharingService)
monkeypatch.setattr(base_recipe_routes, "get_downloader", fake_get_downloader)
monkeypatch.setattr(config, "loras_roots", [str(tmp_path)], raising=False)
app = web.Application()
@@ -211,6 +230,7 @@ async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRou
analysis=StubAnalysisService.instances[-1],
persistence=StubPersistenceService.instances[-1],
sharing=StubSharingService.instances[-1],
downloader=downloader,
tmp_dir=tmp_path,
)
@@ -275,6 +295,54 @@ async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) ->
assert harness.persistence.delete_calls == ["saved-id"]
async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
resources = [
{
"type": "checkpoint",
"modelId": 10,
"modelVersionId": 33,
"modelName": "Flux",
"modelVersionName": "Dev",
},
{
"type": "lora",
"modelId": 20,
"modelVersionId": 44,
"modelName": "Painterly",
"modelVersionName": "v2",
"weight": 0.25,
},
]
response = await harness.client.get(
"/api/lm/recipes/import-remote",
params={
"image_url": "https://example.com/images/1",
"name": "Remote Recipe",
"resources": json.dumps(resources),
"tags": "foo,bar",
"base_model": "Flux",
"source_path": "https://example.com/images/1",
"gen_params": json.dumps({"prompt": "hello world", "cfg_scale": 7}),
},
)
payload = await response.json()
assert response.status == 200
assert payload["success"] is True
call = harness.persistence.save_calls[-1]
assert call["name"] == "Remote Recipe"
assert call["tags"] == ["foo", "bar"]
metadata = call["metadata"]
assert metadata["base_model"] == "Flux"
assert metadata["checkpoint"]["modelVersionId"] == 33
assert metadata["loras"][0]["weight"] == 0.25
assert metadata["gen_params"]["prompt"] == "hello world"
assert metadata["gen_params"]["checkpoint"]["modelVersionId"] == 33
assert harness.downloader.urls == ["https://example.com/images/1"]
async def test_analyze_uploaded_image_error_path(monkeypatch, tmp_path: Path) -> None:
async with recipe_harness(monkeypatch, tmp_path) as harness:
harness.analysis.raise_for_uploaded = RecipeValidationError("No image data provided")
@@ -327,4 +395,3 @@ async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None:
assert body == b"stub"
download_path.unlink(missing_ok=True)

View File

@@ -157,6 +157,55 @@ async def test_save_recipe_reports_duplicates(tmp_path):
assert service._exif_utils.appended[0] == expected_image_path
@pytest.mark.asyncio
async def test_save_recipe_persists_checkpoint_metadata(tmp_path):
exif_utils = DummyExifUtils()
class DummyScanner:
def __init__(self, root):
self.recipes_dir = str(root)
async def find_recipes_by_fingerprint(self, fingerprint):
return []
async def add_recipe(self, recipe_data):
return None
scanner = DummyScanner(tmp_path)
service = RecipePersistenceService(
exif_utils=exif_utils,
card_preview_width=512,
logger=logging.getLogger("test"),
)
checkpoint_meta = {
"type": "checkpoint",
"modelId": 10,
"modelVersionId": 20,
"modelName": "Flux",
"modelVersionName": "Dev",
}
metadata = {
"base_model": "Flux",
"loras": [],
"checkpoint": checkpoint_meta,
}
result = await service.save_recipe(
recipe_scanner=scanner,
image_bytes=b"img",
image_base64=None,
name="Checkpointed",
tags=[],
metadata=metadata,
)
stored = json.loads(Path(result.payload["json_path"]).read_text())
assert stored["checkpoint"] == checkpoint_meta
assert stored["gen_params"]["checkpoint"] == checkpoint_meta
@pytest.mark.asyncio
async def test_save_recipe_from_widget_allows_empty_lora(tmp_path):
exif_utils = DummyExifUtils()