mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
feat: add remote recipe import functionality
Add support for importing recipes from remote sources by: - Adding import_remote_recipe endpoint to RecipeHandlerSet - Injecting downloader_factory and civitai_client_getter dependencies - Implementing image download and resource parsing logic - Supporting Civitai resource payloads with checkpoints and LoRAs - Adding required imports for regex and temporary file handling This enables users to import recipes directly from external sources like Civitai without manual file downloads.
This commit is contained in:
@@ -191,6 +191,8 @@ class BaseRecipeRoutes:
|
|||||||
logger=logger,
|
logger=logger,
|
||||||
persistence_service=persistence_service,
|
persistence_service=persistence_service,
|
||||||
analysis_service=analysis_service,
|
analysis_service=analysis_service,
|
||||||
|
downloader_factory=get_downloader,
|
||||||
|
civitai_client_getter=civitai_client_getter,
|
||||||
)
|
)
|
||||||
analysis = RecipeAnalysisHandler(
|
analysis = RecipeAnalysisHandler(
|
||||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||||
@@ -214,4 +216,3 @@ class BaseRecipeRoutes:
|
|||||||
analysis=analysis,
|
analysis=analysis,
|
||||||
sharing=sharing,
|
sharing=sharing,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional
|
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
@@ -45,6 +47,7 @@ class RecipeHandlerSet:
|
|||||||
"render_page": self.page_view.render_page,
|
"render_page": self.page_view.render_page,
|
||||||
"list_recipes": self.listing.list_recipes,
|
"list_recipes": self.listing.list_recipes,
|
||||||
"get_recipe": self.listing.get_recipe,
|
"get_recipe": self.listing.get_recipe,
|
||||||
|
"import_remote_recipe": self.management.import_remote_recipe,
|
||||||
"analyze_uploaded_image": self.analysis.analyze_uploaded_image,
|
"analyze_uploaded_image": self.analysis.analyze_uploaded_image,
|
||||||
"analyze_local_image": self.analysis.analyze_local_image,
|
"analyze_local_image": self.analysis.analyze_local_image,
|
||||||
"save_recipe": self.management.save_recipe,
|
"save_recipe": self.management.save_recipe,
|
||||||
@@ -404,12 +407,16 @@ class RecipeManagementHandler:
|
|||||||
logger: Logger,
|
logger: Logger,
|
||||||
persistence_service: RecipePersistenceService,
|
persistence_service: RecipePersistenceService,
|
||||||
analysis_service: RecipeAnalysisService,
|
analysis_service: RecipeAnalysisService,
|
||||||
|
downloader_factory,
|
||||||
|
civitai_client_getter: CivitaiClientGetter,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||||
self._recipe_scanner_getter = recipe_scanner_getter
|
self._recipe_scanner_getter = recipe_scanner_getter
|
||||||
self._logger = logger
|
self._logger = logger
|
||||||
self._persistence_service = persistence_service
|
self._persistence_service = persistence_service
|
||||||
self._analysis_service = analysis_service
|
self._analysis_service = analysis_service
|
||||||
|
self._downloader_factory = downloader_factory
|
||||||
|
self._civitai_client_getter = civitai_client_getter
|
||||||
|
|
||||||
async def save_recipe(self, request: web.Request) -> web.Response:
|
async def save_recipe(self, request: web.Request) -> web.Response:
|
||||||
try:
|
try:
|
||||||
@@ -436,6 +443,62 @@ class RecipeManagementHandler:
|
|||||||
self._logger.error("Error saving recipe: %s", exc, exc_info=True)
|
self._logger.error("Error saving recipe: %s", exc, exc_info=True)
|
||||||
return web.json_response({"error": str(exc)}, status=500)
|
return web.json_response({"error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def import_remote_recipe(self, request: web.Request) -> web.Response:
|
||||||
|
try:
|
||||||
|
await self._ensure_dependencies_ready()
|
||||||
|
recipe_scanner = self._recipe_scanner_getter()
|
||||||
|
if recipe_scanner is None:
|
||||||
|
raise RuntimeError("Recipe scanner unavailable")
|
||||||
|
|
||||||
|
params = request.rel_url.query
|
||||||
|
print(params)
|
||||||
|
image_url = params.get("image_url")
|
||||||
|
name = params.get("name")
|
||||||
|
resources_raw = params.get("resources")
|
||||||
|
if not image_url:
|
||||||
|
raise RecipeValidationError("Missing required field: image_url")
|
||||||
|
if not name:
|
||||||
|
raise RecipeValidationError("Missing required field: name")
|
||||||
|
if not resources_raw:
|
||||||
|
raise RecipeValidationError("Missing required field: resources")
|
||||||
|
|
||||||
|
checkpoint_entry, lora_entries = self._parse_resources_payload(resources_raw)
|
||||||
|
gen_params = self._parse_gen_params(params.get("gen_params"))
|
||||||
|
metadata: Dict[str, Any] = {
|
||||||
|
"base_model": params.get("base_model", "") or "",
|
||||||
|
"loras": lora_entries,
|
||||||
|
}
|
||||||
|
source_path = params.get("source_path")
|
||||||
|
if source_path:
|
||||||
|
metadata["source_path"] = source_path
|
||||||
|
if gen_params is not None:
|
||||||
|
metadata["gen_params"] = gen_params
|
||||||
|
if checkpoint_entry:
|
||||||
|
metadata["checkpoint"] = checkpoint_entry
|
||||||
|
gen_params_ref = metadata.setdefault("gen_params", {})
|
||||||
|
if "checkpoint" not in gen_params_ref:
|
||||||
|
gen_params_ref["checkpoint"] = checkpoint_entry
|
||||||
|
|
||||||
|
tags = self._parse_tags(params.get("tags"))
|
||||||
|
image_bytes = await self._download_image_bytes(image_url)
|
||||||
|
|
||||||
|
result = await self._persistence_service.save_recipe(
|
||||||
|
recipe_scanner=recipe_scanner,
|
||||||
|
image_bytes=image_bytes,
|
||||||
|
image_base64=None,
|
||||||
|
name=name,
|
||||||
|
tags=tags,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
return web.json_response(result.payload, status=result.status)
|
||||||
|
except RecipeValidationError as exc:
|
||||||
|
return web.json_response({"error": str(exc)}, status=400)
|
||||||
|
except RecipeDownloadError as exc:
|
||||||
|
return web.json_response({"error": str(exc)}, status=400)
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.error("Error importing recipe from remote source: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"error": str(exc)}, status=500)
|
||||||
|
|
||||||
async def delete_recipe(self, request: web.Request) -> web.Response:
|
async def delete_recipe(self, request: web.Request) -> web.Response:
|
||||||
try:
|
try:
|
||||||
await self._ensure_dependencies_ready()
|
await self._ensure_dependencies_ready()
|
||||||
@@ -595,6 +658,117 @@ class RecipeManagementHandler:
|
|||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _parse_tags(self, tag_text: Optional[str]) -> list[str]:
|
||||||
|
if not tag_text:
|
||||||
|
return []
|
||||||
|
return [tag.strip() for tag in tag_text.split(",") if tag.strip()]
|
||||||
|
|
||||||
|
def _parse_gen_params(self, payload: Optional[str]) -> Optional[Dict[str, Any]]:
|
||||||
|
if payload is None:
|
||||||
|
return None
|
||||||
|
if payload == "":
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
parsed = json.loads(payload)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise RecipeValidationError(f"Invalid gen_params payload: {exc}") from exc
|
||||||
|
if parsed is None:
|
||||||
|
return {}
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
raise RecipeValidationError("gen_params payload must be an object")
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
def _parse_resources_payload(self, payload_raw: str) -> tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||||
|
try:
|
||||||
|
payload = json.loads(payload_raw)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise RecipeValidationError(f"Invalid resources payload: {exc}") from exc
|
||||||
|
|
||||||
|
if not isinstance(payload, list):
|
||||||
|
raise RecipeValidationError("Resources payload must be a list")
|
||||||
|
|
||||||
|
checkpoint_entry: Optional[Dict[str, Any]] = None
|
||||||
|
lora_entries: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
for resource in payload:
|
||||||
|
if not isinstance(resource, dict):
|
||||||
|
continue
|
||||||
|
resource_type = str(resource.get("type") or "").lower()
|
||||||
|
if resource_type == "checkpoint":
|
||||||
|
checkpoint_entry = self._build_checkpoint_entry(resource)
|
||||||
|
elif resource_type in {"lora", "lycoris"}:
|
||||||
|
lora_entries.append(self._build_lora_entry(resource))
|
||||||
|
|
||||||
|
return checkpoint_entry, lora_entries
|
||||||
|
|
||||||
|
def _build_checkpoint_entry(self, resource: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": resource.get("type", "checkpoint"),
|
||||||
|
"modelId": self._safe_int(resource.get("modelId")),
|
||||||
|
"modelVersionId": self._safe_int(resource.get("modelVersionId")),
|
||||||
|
"modelName": resource.get("modelName", ""),
|
||||||
|
"modelVersionName": resource.get("modelVersionName", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_lora_entry(self, resource: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
weight_raw = resource.get("weight", 1.0)
|
||||||
|
try:
|
||||||
|
weight = float(weight_raw)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
weight = 1.0
|
||||||
|
return {
|
||||||
|
"file_name": resource.get("modelName", ""),
|
||||||
|
"weight": weight,
|
||||||
|
"id": self._safe_int(resource.get("modelVersionId")),
|
||||||
|
"name": resource.get("modelName", ""),
|
||||||
|
"version": resource.get("modelVersionName", ""),
|
||||||
|
"isDeleted": False,
|
||||||
|
"exclude": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _download_image_bytes(self, image_url: str) -> bytes:
|
||||||
|
civitai_client = self._civitai_client_getter()
|
||||||
|
downloader = await self._downloader_factory()
|
||||||
|
temp_path = None
|
||||||
|
try:
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||||
|
temp_path = temp_file.name
|
||||||
|
download_url = image_url
|
||||||
|
civitai_match = re.match(r"https://civitai\.com/images/(\d+)", image_url)
|
||||||
|
if civitai_match:
|
||||||
|
if civitai_client is None:
|
||||||
|
raise RecipeDownloadError("Civitai client unavailable for image download")
|
||||||
|
image_info = await civitai_client.get_image_info(civitai_match.group(1))
|
||||||
|
if not image_info:
|
||||||
|
raise RecipeDownloadError("Failed to fetch image information from Civitai")
|
||||||
|
download_url = image_info.get("url")
|
||||||
|
if not download_url:
|
||||||
|
raise RecipeDownloadError("No image URL found in Civitai response")
|
||||||
|
|
||||||
|
success, result = await downloader.download_file(download_url, temp_path, use_auth=False)
|
||||||
|
if not success:
|
||||||
|
raise RecipeDownloadError(f"Failed to download image: {result}")
|
||||||
|
with open(temp_path, "rb") as file_obj:
|
||||||
|
return file_obj.read()
|
||||||
|
except RecipeDownloadError:
|
||||||
|
raise
|
||||||
|
except RecipeValidationError:
|
||||||
|
raise
|
||||||
|
except Exception as exc: # pragma: no cover - defensive guard
|
||||||
|
raise RecipeValidationError(f"Unable to download image: {exc}") from exc
|
||||||
|
finally:
|
||||||
|
if temp_path:
|
||||||
|
try:
|
||||||
|
os.unlink(temp_path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _safe_int(self, value: Any) -> int:
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
class RecipeAnalysisHandler:
|
class RecipeAnalysisHandler:
|
||||||
"""Analyze images to extract recipe metadata."""
|
"""Analyze images to extract recipe metadata."""
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition("GET", "/loras/recipes", "render_page"),
|
RouteDefinition("GET", "/loras/recipes", "render_page"),
|
||||||
RouteDefinition("GET", "/api/lm/recipes", "list_recipes"),
|
RouteDefinition("GET", "/api/lm/recipes", "list_recipes"),
|
||||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"),
|
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"),
|
||||||
|
RouteDefinition("GET", "/api/lm/recipes/import-remote", "import_remote_recipe"),
|
||||||
RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"),
|
RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"),
|
||||||
RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"),
|
RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"),
|
||||||
RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"),
|
RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"),
|
||||||
@@ -61,4 +62,3 @@ class RecipeRouteRegistrar:
|
|||||||
add_method_name = self._METHOD_MAP[method.upper()]
|
add_method_name = self._METHOD_MAP[method.upper()]
|
||||||
add_method = getattr(self._app.router, add_method_name)
|
add_method = getattr(self._app.router, add_method_name)
|
||||||
add_method(path, handler)
|
add_method(path, handler)
|
||||||
|
|
||||||
|
|||||||
@@ -78,9 +78,10 @@ class RecipePersistenceService:
|
|||||||
file_obj.write(optimized_image)
|
file_obj.write(optimized_image)
|
||||||
|
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
loras_data = [self._normalise_lora_entry(lora) for lora in metadata.get("loras", [])]
|
loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])]
|
||||||
|
checkpoint_entry = metadata.get("checkpoint")
|
||||||
|
|
||||||
gen_params = metadata.get("gen_params", {})
|
gen_params = metadata.get("gen_params") or {}
|
||||||
if not gen_params and "raw_metadata" in metadata:
|
if not gen_params and "raw_metadata" in metadata:
|
||||||
raw_metadata = metadata.get("raw_metadata", {})
|
raw_metadata = metadata.get("raw_metadata", {})
|
||||||
gen_params = {
|
gen_params = {
|
||||||
@@ -94,6 +95,8 @@ class RecipePersistenceService:
|
|||||||
"size": raw_metadata.get("size", ""),
|
"size": raw_metadata.get("size", ""),
|
||||||
"clip_skip": raw_metadata.get("clip_skip", ""),
|
"clip_skip": raw_metadata.get("clip_skip", ""),
|
||||||
}
|
}
|
||||||
|
if checkpoint_entry and "checkpoint" not in gen_params:
|
||||||
|
gen_params["checkpoint"] = checkpoint_entry
|
||||||
|
|
||||||
fingerprint = calculate_recipe_fingerprint(loras_data)
|
fingerprint = calculate_recipe_fingerprint(loras_data)
|
||||||
recipe_data: Dict[str, Any] = {
|
recipe_data: Dict[str, Any] = {
|
||||||
@@ -107,6 +110,8 @@ class RecipePersistenceService:
|
|||||||
"gen_params": gen_params,
|
"gen_params": gen_params,
|
||||||
"fingerprint": fingerprint,
|
"fingerprint": fingerprint,
|
||||||
}
|
}
|
||||||
|
if checkpoint_entry:
|
||||||
|
recipe_data["checkpoint"] = checkpoint_entry
|
||||||
|
|
||||||
tags_list = list(tags)
|
tags_list = list(tags)
|
||||||
if tags_list:
|
if tags_list:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ class RecipeRouteHarness:
|
|||||||
analysis: "StubAnalysisService"
|
analysis: "StubAnalysisService"
|
||||||
persistence: "StubPersistenceService"
|
persistence: "StubPersistenceService"
|
||||||
sharing: "StubSharingService"
|
sharing: "StubSharingService"
|
||||||
|
downloader: "StubDownloader"
|
||||||
tmp_dir: Path
|
tmp_dir: Path
|
||||||
|
|
||||||
|
|
||||||
@@ -175,6 +176,18 @@ class StubSharingService:
|
|||||||
return self.download_info
|
return self.download_info
|
||||||
|
|
||||||
|
|
||||||
|
class StubDownloader:
|
||||||
|
"""Downloader stub that writes deterministic bytes to requested locations."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.urls: List[str] = []
|
||||||
|
|
||||||
|
async def download_file(self, url: str, destination: str, use_auth: bool = False): # noqa: ARG002 - use_auth unused
|
||||||
|
self.urls.append(url)
|
||||||
|
Path(destination).write_bytes(b"imported-image")
|
||||||
|
return True, destination
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRouteHarness]:
|
async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRouteHarness]:
|
||||||
"""Context manager that yields a fully wired recipe route harness."""
|
"""Context manager that yields a fully wired recipe route harness."""
|
||||||
@@ -191,11 +204,17 @@ async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRou
|
|||||||
async def fake_get_civitai_client():
|
async def fake_get_civitai_client():
|
||||||
return object()
|
return object()
|
||||||
|
|
||||||
|
downloader = StubDownloader()
|
||||||
|
|
||||||
|
async def fake_get_downloader():
|
||||||
|
return downloader
|
||||||
|
|
||||||
monkeypatch.setattr(ServiceRegistry, "get_recipe_scanner", fake_get_recipe_scanner)
|
monkeypatch.setattr(ServiceRegistry, "get_recipe_scanner", fake_get_recipe_scanner)
|
||||||
monkeypatch.setattr(ServiceRegistry, "get_civitai_client", fake_get_civitai_client)
|
monkeypatch.setattr(ServiceRegistry, "get_civitai_client", fake_get_civitai_client)
|
||||||
monkeypatch.setattr(base_recipe_routes, "RecipeAnalysisService", StubAnalysisService)
|
monkeypatch.setattr(base_recipe_routes, "RecipeAnalysisService", StubAnalysisService)
|
||||||
monkeypatch.setattr(base_recipe_routes, "RecipePersistenceService", StubPersistenceService)
|
monkeypatch.setattr(base_recipe_routes, "RecipePersistenceService", StubPersistenceService)
|
||||||
monkeypatch.setattr(base_recipe_routes, "RecipeSharingService", StubSharingService)
|
monkeypatch.setattr(base_recipe_routes, "RecipeSharingService", StubSharingService)
|
||||||
|
monkeypatch.setattr(base_recipe_routes, "get_downloader", fake_get_downloader)
|
||||||
monkeypatch.setattr(config, "loras_roots", [str(tmp_path)], raising=False)
|
monkeypatch.setattr(config, "loras_roots", [str(tmp_path)], raising=False)
|
||||||
|
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
@@ -211,6 +230,7 @@ async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRou
|
|||||||
analysis=StubAnalysisService.instances[-1],
|
analysis=StubAnalysisService.instances[-1],
|
||||||
persistence=StubPersistenceService.instances[-1],
|
persistence=StubPersistenceService.instances[-1],
|
||||||
sharing=StubSharingService.instances[-1],
|
sharing=StubSharingService.instances[-1],
|
||||||
|
downloader=downloader,
|
||||||
tmp_dir=tmp_path,
|
tmp_dir=tmp_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -275,6 +295,54 @@ async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) ->
|
|||||||
assert harness.persistence.delete_calls == ["saved-id"]
|
assert harness.persistence.delete_calls == ["saved-id"]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||||
|
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||||
|
resources = [
|
||||||
|
{
|
||||||
|
"type": "checkpoint",
|
||||||
|
"modelId": 10,
|
||||||
|
"modelVersionId": 33,
|
||||||
|
"modelName": "Flux",
|
||||||
|
"modelVersionName": "Dev",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "lora",
|
||||||
|
"modelId": 20,
|
||||||
|
"modelVersionId": 44,
|
||||||
|
"modelName": "Painterly",
|
||||||
|
"modelVersionName": "v2",
|
||||||
|
"weight": 0.25,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
response = await harness.client.get(
|
||||||
|
"/api/lm/recipes/import-remote",
|
||||||
|
params={
|
||||||
|
"image_url": "https://example.com/images/1",
|
||||||
|
"name": "Remote Recipe",
|
||||||
|
"resources": json.dumps(resources),
|
||||||
|
"tags": "foo,bar",
|
||||||
|
"base_model": "Flux",
|
||||||
|
"source_path": "https://example.com/images/1",
|
||||||
|
"gen_params": json.dumps({"prompt": "hello world", "cfg_scale": 7}),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = await response.json()
|
||||||
|
assert response.status == 200
|
||||||
|
assert payload["success"] is True
|
||||||
|
|
||||||
|
call = harness.persistence.save_calls[-1]
|
||||||
|
assert call["name"] == "Remote Recipe"
|
||||||
|
assert call["tags"] == ["foo", "bar"]
|
||||||
|
metadata = call["metadata"]
|
||||||
|
assert metadata["base_model"] == "Flux"
|
||||||
|
assert metadata["checkpoint"]["modelVersionId"] == 33
|
||||||
|
assert metadata["loras"][0]["weight"] == 0.25
|
||||||
|
assert metadata["gen_params"]["prompt"] == "hello world"
|
||||||
|
assert metadata["gen_params"]["checkpoint"]["modelVersionId"] == 33
|
||||||
|
assert harness.downloader.urls == ["https://example.com/images/1"]
|
||||||
|
|
||||||
|
|
||||||
async def test_analyze_uploaded_image_error_path(monkeypatch, tmp_path: Path) -> None:
|
async def test_analyze_uploaded_image_error_path(monkeypatch, tmp_path: Path) -> None:
|
||||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||||
harness.analysis.raise_for_uploaded = RecipeValidationError("No image data provided")
|
harness.analysis.raise_for_uploaded = RecipeValidationError("No image data provided")
|
||||||
@@ -327,4 +395,3 @@ async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None:
|
|||||||
assert body == b"stub"
|
assert body == b"stub"
|
||||||
|
|
||||||
download_path.unlink(missing_ok=True)
|
download_path.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|||||||
@@ -157,6 +157,55 @@ async def test_save_recipe_reports_duplicates(tmp_path):
|
|||||||
assert service._exif_utils.appended[0] == expected_image_path
|
assert service._exif_utils.appended[0] == expected_image_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_recipe_persists_checkpoint_metadata(tmp_path):
|
||||||
|
exif_utils = DummyExifUtils()
|
||||||
|
|
||||||
|
class DummyScanner:
|
||||||
|
def __init__(self, root):
|
||||||
|
self.recipes_dir = str(root)
|
||||||
|
|
||||||
|
async def find_recipes_by_fingerprint(self, fingerprint):
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def add_recipe(self, recipe_data):
|
||||||
|
return None
|
||||||
|
|
||||||
|
scanner = DummyScanner(tmp_path)
|
||||||
|
service = RecipePersistenceService(
|
||||||
|
exif_utils=exif_utils,
|
||||||
|
card_preview_width=512,
|
||||||
|
logger=logging.getLogger("test"),
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoint_meta = {
|
||||||
|
"type": "checkpoint",
|
||||||
|
"modelId": 10,
|
||||||
|
"modelVersionId": 20,
|
||||||
|
"modelName": "Flux",
|
||||||
|
"modelVersionName": "Dev",
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"base_model": "Flux",
|
||||||
|
"loras": [],
|
||||||
|
"checkpoint": checkpoint_meta,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await service.save_recipe(
|
||||||
|
recipe_scanner=scanner,
|
||||||
|
image_bytes=b"img",
|
||||||
|
image_base64=None,
|
||||||
|
name="Checkpointed",
|
||||||
|
tags=[],
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||||
|
assert stored["checkpoint"] == checkpoint_meta
|
||||||
|
assert stored["gen_params"]["checkpoint"] == checkpoint_meta
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_save_recipe_from_widget_allows_empty_lora(tmp_path):
|
async def test_save_recipe_from_widget_allows_empty_lora(tmp_path):
|
||||||
exif_utils = DummyExifUtils()
|
exif_utils = DummyExifUtils()
|
||||||
|
|||||||
Reference in New Issue
Block a user