mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: add remote recipe import functionality
Add support for importing recipes from remote sources by: - Adding import_remote_recipe endpoint to RecipeHandlerSet - Injecting downloader_factory and civitai_client_getter dependencies - Implementing image download and resource parsing logic - Supporting Civitai resource payloads with checkpoints and LoRAs - Adding required imports for regex and temporary file handling This enables users to import recipes directly from external sources like Civitai without manual file downloads.
This commit is contained in:
@@ -191,6 +191,8 @@ class BaseRecipeRoutes:
|
||||
logger=logger,
|
||||
persistence_service=persistence_service,
|
||||
analysis_service=analysis_service,
|
||||
downloader_factory=get_downloader,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
)
|
||||
analysis = RecipeAnalysisHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
@@ -214,4 +216,3 @@ class BaseRecipeRoutes:
|
||||
analysis=analysis,
|
||||
sharing=sharing,
|
||||
)
|
||||
|
||||
|
||||
@@ -4,8 +4,10 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
@@ -45,6 +47,7 @@ class RecipeHandlerSet:
|
||||
"render_page": self.page_view.render_page,
|
||||
"list_recipes": self.listing.list_recipes,
|
||||
"get_recipe": self.listing.get_recipe,
|
||||
"import_remote_recipe": self.management.import_remote_recipe,
|
||||
"analyze_uploaded_image": self.analysis.analyze_uploaded_image,
|
||||
"analyze_local_image": self.analysis.analyze_local_image,
|
||||
"save_recipe": self.management.save_recipe,
|
||||
@@ -404,12 +407,16 @@ class RecipeManagementHandler:
|
||||
logger: Logger,
|
||||
persistence_service: RecipePersistenceService,
|
||||
analysis_service: RecipeAnalysisService,
|
||||
downloader_factory,
|
||||
civitai_client_getter: CivitaiClientGetter,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
self._persistence_service = persistence_service
|
||||
self._analysis_service = analysis_service
|
||||
self._downloader_factory = downloader_factory
|
||||
self._civitai_client_getter = civitai_client_getter
|
||||
|
||||
async def save_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
@@ -436,6 +443,62 @@ class RecipeManagementHandler:
|
||||
self._logger.error("Error saving recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def import_remote_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
params = request.rel_url.query
|
||||
print(params)
|
||||
image_url = params.get("image_url")
|
||||
name = params.get("name")
|
||||
resources_raw = params.get("resources")
|
||||
if not image_url:
|
||||
raise RecipeValidationError("Missing required field: image_url")
|
||||
if not name:
|
||||
raise RecipeValidationError("Missing required field: name")
|
||||
if not resources_raw:
|
||||
raise RecipeValidationError("Missing required field: resources")
|
||||
|
||||
checkpoint_entry, lora_entries = self._parse_resources_payload(resources_raw)
|
||||
gen_params = self._parse_gen_params(params.get("gen_params"))
|
||||
metadata: Dict[str, Any] = {
|
||||
"base_model": params.get("base_model", "") or "",
|
||||
"loras": lora_entries,
|
||||
}
|
||||
source_path = params.get("source_path")
|
||||
if source_path:
|
||||
metadata["source_path"] = source_path
|
||||
if gen_params is not None:
|
||||
metadata["gen_params"] = gen_params
|
||||
if checkpoint_entry:
|
||||
metadata["checkpoint"] = checkpoint_entry
|
||||
gen_params_ref = metadata.setdefault("gen_params", {})
|
||||
if "checkpoint" not in gen_params_ref:
|
||||
gen_params_ref["checkpoint"] = checkpoint_entry
|
||||
|
||||
tags = self._parse_tags(params.get("tags"))
|
||||
image_bytes = await self._download_image_bytes(image_url)
|
||||
|
||||
result = await self._persistence_service.save_recipe(
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_bytes=image_bytes,
|
||||
image_base64=None,
|
||||
name=name,
|
||||
tags=tags,
|
||||
metadata=metadata,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except RecipeDownloadError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error importing recipe from remote source: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def delete_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
@@ -595,6 +658,117 @@ class RecipeManagementHandler:
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
def _parse_tags(self, tag_text: Optional[str]) -> list[str]:
|
||||
if not tag_text:
|
||||
return []
|
||||
return [tag.strip() for tag in tag_text.split(",") if tag.strip()]
|
||||
|
||||
def _parse_gen_params(self, payload: Optional[str]) -> Optional[Dict[str, Any]]:
|
||||
if payload is None:
|
||||
return None
|
||||
if payload == "":
|
||||
return {}
|
||||
try:
|
||||
parsed = json.loads(payload)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RecipeValidationError(f"Invalid gen_params payload: {exc}") from exc
|
||||
if parsed is None:
|
||||
return {}
|
||||
if not isinstance(parsed, dict):
|
||||
raise RecipeValidationError("gen_params payload must be an object")
|
||||
return parsed
|
||||
|
||||
def _parse_resources_payload(self, payload_raw: str) -> tuple[Optional[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
try:
|
||||
payload = json.loads(payload_raw)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RecipeValidationError(f"Invalid resources payload: {exc}") from exc
|
||||
|
||||
if not isinstance(payload, list):
|
||||
raise RecipeValidationError("Resources payload must be a list")
|
||||
|
||||
checkpoint_entry: Optional[Dict[str, Any]] = None
|
||||
lora_entries: List[Dict[str, Any]] = []
|
||||
|
||||
for resource in payload:
|
||||
if not isinstance(resource, dict):
|
||||
continue
|
||||
resource_type = str(resource.get("type") or "").lower()
|
||||
if resource_type == "checkpoint":
|
||||
checkpoint_entry = self._build_checkpoint_entry(resource)
|
||||
elif resource_type in {"lora", "lycoris"}:
|
||||
lora_entries.append(self._build_lora_entry(resource))
|
||||
|
||||
return checkpoint_entry, lora_entries
|
||||
|
||||
def _build_checkpoint_entry(self, resource: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": resource.get("type", "checkpoint"),
|
||||
"modelId": self._safe_int(resource.get("modelId")),
|
||||
"modelVersionId": self._safe_int(resource.get("modelVersionId")),
|
||||
"modelName": resource.get("modelName", ""),
|
||||
"modelVersionName": resource.get("modelVersionName", ""),
|
||||
}
|
||||
|
||||
def _build_lora_entry(self, resource: Dict[str, Any]) -> Dict[str, Any]:
|
||||
weight_raw = resource.get("weight", 1.0)
|
||||
try:
|
||||
weight = float(weight_raw)
|
||||
except (TypeError, ValueError):
|
||||
weight = 1.0
|
||||
return {
|
||||
"file_name": resource.get("modelName", ""),
|
||||
"weight": weight,
|
||||
"id": self._safe_int(resource.get("modelVersionId")),
|
||||
"name": resource.get("modelName", ""),
|
||||
"version": resource.get("modelVersionName", ""),
|
||||
"isDeleted": False,
|
||||
"exclude": False,
|
||||
}
|
||||
|
||||
async def _download_image_bytes(self, image_url: str) -> bytes:
|
||||
civitai_client = self._civitai_client_getter()
|
||||
downloader = await self._downloader_factory()
|
||||
temp_path = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
download_url = image_url
|
||||
civitai_match = re.match(r"https://civitai\.com/images/(\d+)", image_url)
|
||||
if civitai_match:
|
||||
if civitai_client is None:
|
||||
raise RecipeDownloadError("Civitai client unavailable for image download")
|
||||
image_info = await civitai_client.get_image_info(civitai_match.group(1))
|
||||
if not image_info:
|
||||
raise RecipeDownloadError("Failed to fetch image information from Civitai")
|
||||
download_url = image_info.get("url")
|
||||
if not download_url:
|
||||
raise RecipeDownloadError("No image URL found in Civitai response")
|
||||
|
||||
success, result = await downloader.download_file(download_url, temp_path, use_auth=False)
|
||||
if not success:
|
||||
raise RecipeDownloadError(f"Failed to download image: {result}")
|
||||
with open(temp_path, "rb") as file_obj:
|
||||
return file_obj.read()
|
||||
except RecipeDownloadError:
|
||||
raise
|
||||
except RecipeValidationError:
|
||||
raise
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
raise RecipeValidationError(f"Unable to download image: {exc}") from exc
|
||||
finally:
|
||||
if temp_path:
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
def _safe_int(self, value: Any) -> int:
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
|
||||
class RecipeAnalysisHandler:
|
||||
"""Analyze images to extract recipe metadata."""
|
||||
|
||||
@@ -20,6 +20,7 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/loras/recipes", "render_page"),
|
||||
RouteDefinition("GET", "/api/lm/recipes", "list_recipes"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/import-remote", "import_remote_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"),
|
||||
@@ -61,4 +62,3 @@ class RecipeRouteRegistrar:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
|
||||
|
||||
@@ -78,9 +78,10 @@ class RecipePersistenceService:
|
||||
file_obj.write(optimized_image)
|
||||
|
||||
current_time = time.time()
|
||||
loras_data = [self._normalise_lora_entry(lora) for lora in metadata.get("loras", [])]
|
||||
loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])]
|
||||
checkpoint_entry = metadata.get("checkpoint")
|
||||
|
||||
gen_params = metadata.get("gen_params", {})
|
||||
gen_params = metadata.get("gen_params") or {}
|
||||
if not gen_params and "raw_metadata" in metadata:
|
||||
raw_metadata = metadata.get("raw_metadata", {})
|
||||
gen_params = {
|
||||
@@ -94,6 +95,8 @@ class RecipePersistenceService:
|
||||
"size": raw_metadata.get("size", ""),
|
||||
"clip_skip": raw_metadata.get("clip_skip", ""),
|
||||
}
|
||||
if checkpoint_entry and "checkpoint" not in gen_params:
|
||||
gen_params["checkpoint"] = checkpoint_entry
|
||||
|
||||
fingerprint = calculate_recipe_fingerprint(loras_data)
|
||||
recipe_data: Dict[str, Any] = {
|
||||
@@ -107,6 +110,8 @@ class RecipePersistenceService:
|
||||
"gen_params": gen_params,
|
||||
"fingerprint": fingerprint,
|
||||
}
|
||||
if checkpoint_entry:
|
||||
recipe_data["checkpoint"] = checkpoint_entry
|
||||
|
||||
tags_list = list(tags)
|
||||
if tags_list:
|
||||
|
||||
@@ -27,6 +27,7 @@ class RecipeRouteHarness:
|
||||
analysis: "StubAnalysisService"
|
||||
persistence: "StubPersistenceService"
|
||||
sharing: "StubSharingService"
|
||||
downloader: "StubDownloader"
|
||||
tmp_dir: Path
|
||||
|
||||
|
||||
@@ -175,6 +176,18 @@ class StubSharingService:
|
||||
return self.download_info
|
||||
|
||||
|
||||
class StubDownloader:
|
||||
"""Downloader stub that writes deterministic bytes to requested locations."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.urls: List[str] = []
|
||||
|
||||
async def download_file(self, url: str, destination: str, use_auth: bool = False): # noqa: ARG002 - use_auth unused
|
||||
self.urls.append(url)
|
||||
Path(destination).write_bytes(b"imported-image")
|
||||
return True, destination
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRouteHarness]:
|
||||
"""Context manager that yields a fully wired recipe route harness."""
|
||||
@@ -191,11 +204,17 @@ async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRou
|
||||
async def fake_get_civitai_client():
|
||||
return object()
|
||||
|
||||
downloader = StubDownloader()
|
||||
|
||||
async def fake_get_downloader():
|
||||
return downloader
|
||||
|
||||
monkeypatch.setattr(ServiceRegistry, "get_recipe_scanner", fake_get_recipe_scanner)
|
||||
monkeypatch.setattr(ServiceRegistry, "get_civitai_client", fake_get_civitai_client)
|
||||
monkeypatch.setattr(base_recipe_routes, "RecipeAnalysisService", StubAnalysisService)
|
||||
monkeypatch.setattr(base_recipe_routes, "RecipePersistenceService", StubPersistenceService)
|
||||
monkeypatch.setattr(base_recipe_routes, "RecipeSharingService", StubSharingService)
|
||||
monkeypatch.setattr(base_recipe_routes, "get_downloader", fake_get_downloader)
|
||||
monkeypatch.setattr(config, "loras_roots", [str(tmp_path)], raising=False)
|
||||
|
||||
app = web.Application()
|
||||
@@ -211,6 +230,7 @@ async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRou
|
||||
analysis=StubAnalysisService.instances[-1],
|
||||
persistence=StubPersistenceService.instances[-1],
|
||||
sharing=StubSharingService.instances[-1],
|
||||
downloader=downloader,
|
||||
tmp_dir=tmp_path,
|
||||
)
|
||||
|
||||
@@ -275,6 +295,54 @@ async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) ->
|
||||
assert harness.persistence.delete_calls == ["saved-id"]
|
||||
|
||||
|
||||
async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
resources = [
|
||||
{
|
||||
"type": "checkpoint",
|
||||
"modelId": 10,
|
||||
"modelVersionId": 33,
|
||||
"modelName": "Flux",
|
||||
"modelVersionName": "Dev",
|
||||
},
|
||||
{
|
||||
"type": "lora",
|
||||
"modelId": 20,
|
||||
"modelVersionId": 44,
|
||||
"modelName": "Painterly",
|
||||
"modelVersionName": "v2",
|
||||
"weight": 0.25,
|
||||
},
|
||||
]
|
||||
response = await harness.client.get(
|
||||
"/api/lm/recipes/import-remote",
|
||||
params={
|
||||
"image_url": "https://example.com/images/1",
|
||||
"name": "Remote Recipe",
|
||||
"resources": json.dumps(resources),
|
||||
"tags": "foo,bar",
|
||||
"base_model": "Flux",
|
||||
"source_path": "https://example.com/images/1",
|
||||
"gen_params": json.dumps({"prompt": "hello world", "cfg_scale": 7}),
|
||||
},
|
||||
)
|
||||
|
||||
payload = await response.json()
|
||||
assert response.status == 200
|
||||
assert payload["success"] is True
|
||||
|
||||
call = harness.persistence.save_calls[-1]
|
||||
assert call["name"] == "Remote Recipe"
|
||||
assert call["tags"] == ["foo", "bar"]
|
||||
metadata = call["metadata"]
|
||||
assert metadata["base_model"] == "Flux"
|
||||
assert metadata["checkpoint"]["modelVersionId"] == 33
|
||||
assert metadata["loras"][0]["weight"] == 0.25
|
||||
assert metadata["gen_params"]["prompt"] == "hello world"
|
||||
assert metadata["gen_params"]["checkpoint"]["modelVersionId"] == 33
|
||||
assert harness.downloader.urls == ["https://example.com/images/1"]
|
||||
|
||||
|
||||
async def test_analyze_uploaded_image_error_path(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
harness.analysis.raise_for_uploaded = RecipeValidationError("No image data provided")
|
||||
@@ -327,4 +395,3 @@ async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
assert body == b"stub"
|
||||
|
||||
download_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
@@ -157,6 +157,55 @@ async def test_save_recipe_reports_duplicates(tmp_path):
|
||||
assert service._exif_utils.appended[0] == expected_image_path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_recipe_persists_checkpoint_metadata(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, root):
|
||||
self.recipes_dir = str(root)
|
||||
|
||||
async def find_recipes_by_fingerprint(self, fingerprint):
|
||||
return []
|
||||
|
||||
async def add_recipe(self, recipe_data):
|
||||
return None
|
||||
|
||||
scanner = DummyScanner(tmp_path)
|
||||
service = RecipePersistenceService(
|
||||
exif_utils=exif_utils,
|
||||
card_preview_width=512,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
checkpoint_meta = {
|
||||
"type": "checkpoint",
|
||||
"modelId": 10,
|
||||
"modelVersionId": 20,
|
||||
"modelName": "Flux",
|
||||
"modelVersionName": "Dev",
|
||||
}
|
||||
|
||||
metadata = {
|
||||
"base_model": "Flux",
|
||||
"loras": [],
|
||||
"checkpoint": checkpoint_meta,
|
||||
}
|
||||
|
||||
result = await service.save_recipe(
|
||||
recipe_scanner=scanner,
|
||||
image_bytes=b"img",
|
||||
image_base64=None,
|
||||
name="Checkpointed",
|
||||
tags=[],
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||
assert stored["checkpoint"] == checkpoint_meta
|
||||
assert stored["gen_params"]["checkpoint"] == checkpoint_meta
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_recipe_from_widget_allows_empty_lora(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
Reference in New Issue
Block a user