diff --git a/py/recipes/merger.py b/py/recipes/merger.py new file mode 100644 index 00000000..1ddd3268 --- /dev/null +++ b/py/recipes/merger.py @@ -0,0 +1,50 @@ +from typing import Any, Dict, Optional +import logging + +logger = logging.getLogger(__name__) + +class GenParamsMerger: + """Utility to merge generation parameters from multiple sources with priority.""" + + BLACKLISTED_KEYS = {"id", "url", "userId", "username", "createdAt", "updatedAt", "hash"} + + @staticmethod + def merge( + request_params: Optional[Dict[str, Any]] = None, + civitai_meta: Optional[Dict[str, Any]] = None, + embedded_metadata: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: + """ + Merge generation parameters from three sources. + + Priority: request_params > civitai_meta > embedded_metadata + + Args: + request_params: Params provided directly in the import request + civitai_meta: Params from Civitai Image API 'meta' field + embedded_metadata: Params extracted from image EXIF/embedded metadata + + Returns: + Merged parameters dictionary + """ + result = {} + + # 1. Start with embedded metadata (lowest priority) + if embedded_metadata: + # If it's a full recipe metadata, we use its gen_params + if "gen_params" in embedded_metadata and isinstance(embedded_metadata["gen_params"], dict): + result.update(embedded_metadata["gen_params"]) + else: + # Otherwise assume the dict itself contains gen_params + result.update(embedded_metadata) + + # 2. Layer Civitai meta (medium priority) + if civitai_meta: + result.update(civitai_meta) + + # 3. Layer request params (highest priority) + if request_params: + result.update(request_params) + + # Filter out blacklisted keys + return {k: v for k, v in result.items() if k not in GenParamsMerger.BLACKLISTED_KEYS} diff --git a/py/recipes/parsers/comfy.py b/py/recipes/parsers/comfy.py index f81a15ad..e1d7251e 100644 --- a/py/recipes/parsers/comfy.py +++ b/py/recipes/parsers/comfy.py @@ -36,9 +36,6 @@ class ComfyMetadataParser(RecipeMetadataParser): # Find all LoraLoader nodes lora_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and v.get('class_type') == 'LoraLoader'} - if not lora_nodes: - return {"error": "No LoRA information found in this ComfyUI workflow", "loras": []} - # Process each LoraLoader node for node_id, node in lora_nodes.items(): if 'inputs' not in node or 'lora_name' not in node['inputs']: diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py index a798f0b6..c18155dc 100644 --- a/py/routes/handlers/recipe_handlers.py +++ b/py/routes/handlers/recipe_handlers.py @@ -24,6 +24,8 @@ from ...services.recipes import ( ) from ...services.metadata_service import get_default_metadata_provider from ...utils.civitai_utils import rewrite_preview_url +from ...utils.exif_utils import ExifUtils +from ...recipes.merger import GenParamsMerger Logger = logging.Logger EnsureDependenciesCallable = Callable[[], Awaitable[None]] @@ -552,7 +554,41 @@ class RecipeManagementHandler: metadata["base_model"] = base_model_from_metadata tags = self._parse_tags(params.get("tags")) - image_bytes, extension = await self._download_remote_media(image_url) + image_bytes, extension, civitai_meta = await self._download_remote_media(image_url) + + # Extract embedded metadata from the downloaded image + embedded_metadata = None + try: + with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_img: + temp_img.write(image_bytes) + temp_img_path = temp_img.name + + try: + raw_embedded = ExifUtils.extract_image_metadata(temp_img_path) + if raw_embedded: + # Try to parse it using standard parsers if it looks like a recipe + parser = self._analysis_service._recipe_parser_factory.create_parser(raw_embedded) + if parser: + parsed_embedded = await parser.parse_metadata(raw_embedded, recipe_scanner=recipe_scanner) + embedded_metadata = parsed_embedded + else: + # Fallback to raw string if no parser matches (might be simple params) + embedded_metadata = {"gen_params": {"raw_metadata": raw_embedded}} + finally: + if os.path.exists(temp_img_path): + os.unlink(temp_img_path) + except Exception as exc: + self._logger.warning("Failed to extract embedded metadata during import: %s", exc) + + # Merge gen_params from all sources + merged_gen_params = GenParamsMerger.merge( + request_params=gen_params, + civitai_meta=civitai_meta, + embedded_metadata=embedded_metadata + ) + + if merged_gen_params: + metadata["gen_params"] = merged_gen_params result = await self._persistence_service.save_recipe( recipe_scanner=recipe_scanner, @@ -900,7 +936,7 @@ class RecipeManagementHandler: extension = ".webp" # Default to webp if unknown with open(temp_path, "rb") as file_obj: - return file_obj.read(), extension + return file_obj.read(), extension, image_info.get("meta") if civitai_match and image_info else None except RecipeDownloadError: raise except RecipeValidationError: diff --git a/tests/routes/test_recipe_routes.py b/tests/routes/test_recipe_routes.py index 0a179732..51d3459e 100644 --- a/tests/routes/test_recipe_routes.py +++ b/tests/routes/test_recipe_routes.py @@ -91,6 +91,7 @@ class StubAnalysisService: self.remote_calls: List[Optional[str]] = [] self.local_calls: List[Optional[str]] = [] self.result = SimpleNamespace(payload={"loras": []}, status=200) + self._recipe_parser_factory = None StubAnalysisService.instances.append(self) async def analyze_uploaded_image(self, *, image_bytes: bytes | None, recipe_scanner) -> SimpleNamespace: # noqa: D401 - mirrors real signature @@ -527,3 +528,69 @@ async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None: assert body == b"stub" download_path.unlink(missing_ok=True) +async def test_import_remote_recipe_merges_metadata(monkeypatch, tmp_path: Path) -> None: + # 1. Mock Metadata Provider + class Provider: + async def get_model_version_info(self, model_version_id): + return {"baseModel": "Flux Provider"}, None + + async def fake_get_default_metadata_provider(): + return Provider() + + monkeypatch.setattr(recipe_handlers, "get_default_metadata_provider", fake_get_default_metadata_provider) + + # 2. Mock ExifUtils to return some embedded metadata + class MockExifUtils: + @staticmethod + def extract_image_metadata(path): + return "Recipe metadata: " + json.dumps({ + "gen_params": {"prompt": "from embedded", "seed": 123} + }) + + monkeypatch.setattr(recipe_handlers, "ExifUtils", MockExifUtils) + + # 3. Mock Parser Factory for StubAnalysisService + class MockParser: + async def parse_metadata(self, raw, recipe_scanner=None): + return json.loads(raw[len("Recipe metadata: "):]) + + class MockFactory: + def create_parser(self, raw): + if raw.startswith("Recipe metadata: "): + return MockParser() + return None + + # 4. Setup Harness and run test + async with recipe_harness(monkeypatch, tmp_path) as harness: + harness.analysis._recipe_parser_factory = MockFactory() + + # Civitai meta via image_info + harness.civitai.image_info["1"] = { + "id": 1, + "url": "https://example.com/images/1.jpg", + "meta": {"prompt": "from civitai", "cfg": 7.0} + } + + resources = [] + response = await harness.client.get( + "/api/lm/recipes/import-remote", + params={ + "image_url": "https://civitai.com/images/1", + "name": "Merged Recipe", + "resources": json.dumps(resources), + "gen_params": json.dumps({"prompt": "from request", "steps": 25}), + }, + ) + + payload = await response.json() + assert response.status == 200 + + call = harness.persistence.save_calls[-1] + metadata = call["metadata"] + gen_params = metadata["gen_params"] + + # Priority: request (prompt=request, steps=25) > civitai (prompt=civitai, cfg=7.0) > embedded (prompt=embedded, seed=123) + assert gen_params["prompt"] == "from request" + assert gen_params["steps"] == 25 + assert gen_params["cfg"] == 7.0 + assert gen_params["seed"] == 123 diff --git a/tests/services/test_comfy_metadata_parser.py b/tests/services/test_comfy_metadata_parser.py new file mode 100644 index 00000000..dac489ab --- /dev/null +++ b/tests/services/test_comfy_metadata_parser.py @@ -0,0 +1,113 @@ +import pytest +import json +from py.recipes.parsers.comfy import ComfyMetadataParser + +@pytest.mark.asyncio +async def test_parse_metadata_without_loras(monkeypatch): + checkpoint_info = { + "id": 2224012, + "modelId": 1908679, + "model": {"name": "SDXL Checkpoint", "type": "checkpoint"}, + "name": "v1.0", + "images": [{"url": "https://image.civitai.com/checkpoints/original=true"}], + "baseModel": "sdxl", + "downloadUrl": "https://civitai.com/api/download/checkpoint", + } + + async def fake_metadata_provider(): + class Provider: + async def get_model_version_info(self, version_id): + assert version_id == "2224012" + return checkpoint_info, None + return Provider() + + monkeypatch.setattr( + "py.recipes.parsers.comfy.get_default_metadata_provider", + fake_metadata_provider, + ) + + parser = ComfyMetadataParser() + + # User provided metadata + metadata_json = { + "resource-stack": { + "class_type": "CheckpointLoaderSimple", + "inputs": {"ckpt_name": "urn:air:sdxl:checkpoint:civitai:1908679@2224012"} + }, + "6": { + "class_type": "smZ CLIPTextEncode", + "inputs": {"text": "Positive prompt content"}, + "_meta": {"title": "Positive"} + }, + "7": { + "class_type": "smZ CLIPTextEncode", + "inputs": {"text": "Negative prompt content"}, + "_meta": {"title": "Negative"} + }, + "11": { + "class_type": "KSampler", + "inputs": { + "sampler_name": "euler_ancestral", + "scheduler": "normal", + "seed": 904124997, + "steps": 35, + "cfg": 6, + "denoise": 0.1, + "model": ["resource-stack", 0], + "positive": ["6", 0], + "negative": ["7", 0], + "latent_image": ["21", 0] + }, + "_meta": {"title": "KSampler"} + }, + "extraMetadata": json.dumps({ + "prompt": "One woman, (solo:1.3), ...", + "negativePrompt": "lowres, worst quality, ...", + "steps": 35, + "cfgScale": 6, + "sampler": "euler_ancestral", + "seed": 904124997, + "width": 1024, + "height": 1024 + }) + } + + result = await parser.parse_metadata(json.dumps(metadata_json)) + + assert "error" not in result + assert result["loras"] == [] + assert result["checkpoint"] is not None + assert int(result["checkpoint"]["modelId"]) == 1908679 + assert int(result["checkpoint"]["id"]) == 2224012 + assert result["gen_params"]["prompt"] == "One woman, (solo:1.3), ..." + assert result["gen_params"]["steps"] == 35 + assert result["gen_params"]["size"] == "1024x1024" + assert result["from_comfy_metadata"] is True + +@pytest.mark.asyncio +async def test_parse_metadata_without_extra_metadata(monkeypatch): + async def fake_metadata_provider(): + class Provider: + async def get_model_version_info(self, version_id): + return {"model": {"name": "Test"}, "id": version_id}, None + return Provider() + + monkeypatch.setattr( + "py.recipes.parsers.comfy.get_default_metadata_provider", + fake_metadata_provider, + ) + + parser = ComfyMetadataParser() + + metadata_json = { + "node_1": { + "class_type": "CheckpointLoaderSimple", + "inputs": {"ckpt_name": "urn:air:sdxl:checkpoint:civitai:123@456"} + } + } + + result = await parser.parse_metadata(json.dumps(metadata_json)) + + assert "error" not in result + assert result["loras"] == [] + assert result["checkpoint"]["id"] == "456" diff --git a/tests/services/test_gen_params_merger.py b/tests/services/test_gen_params_merger.py new file mode 100644 index 00000000..291fff19 --- /dev/null +++ b/tests/services/test_gen_params_merger.py @@ -0,0 +1,59 @@ +import pytest +from py.recipes.merger import GenParamsMerger + +def test_merge_priority(): + request_params = {"prompt": "from request", "steps": 20} + civitai_meta = {"prompt": "from civitai", "cfg": 7.0} + embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}} + + merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata) + + assert merged["prompt"] == "from request" + assert merged["steps"] == 20 + assert merged["cfg"] == 7.0 + assert merged["seed"] == 123 + +def test_merge_no_request_params(): + civitai_meta = {"prompt": "from civitai", "cfg": 7.0} + embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}} + + merged = GenParamsMerger.merge(None, civitai_meta, embedded_metadata) + + assert merged["prompt"] == "from civitai" + assert merged["cfg"] == 7.0 + assert merged["seed"] == 123 + +def test_merge_only_embedded(): + embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}} + + merged = GenParamsMerger.merge(None, None, embedded_metadata) + + assert merged["prompt"] == "from embedded" + assert merged["seed"] == 123 + +def test_merge_raw_embedded(): + # Test when embedded metadata is just the gen_params themselves + embedded_metadata = {"prompt": "from raw embedded", "seed": 456} + + merged = GenParamsMerger.merge(None, None, embedded_metadata) + + assert merged["prompt"] == "from raw embedded" + assert merged["seed"] == 456 + +def test_merge_none_values(): + merged = GenParamsMerger.merge(None, None, None) + assert merged == {} + +def test_merge_filters_blacklisted_keys(): + request_params = {"prompt": "test", "id": "should-be-removed"} + civitai_meta = {"cfg": 7, "url": "remove-me"} + embedded_metadata = {"seed": 123, "hash": "remove-also"} + + merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata) + + assert "prompt" in merged + assert "cfg" in merged + assert "seed" in merged + assert "id" not in merged + assert "url" not in merged + assert "hash" not in merged