diff --git a/py/recipes/constants.py b/py/recipes/constants.py index c1d4e383..b9ae8ba1 100644 --- a/py/recipes/constants.py +++ b/py/recipes/constants.py @@ -13,4 +13,5 @@ GEN_PARAM_KEYS = [ 'seed', 'size', 'clip_skip', + 'denoising_strength', ] diff --git a/py/recipes/merger.py b/py/recipes/merger.py index 93d19857..cb47b734 100644 --- a/py/recipes/merger.py +++ b/py/recipes/merger.py @@ -1,27 +1,33 @@ from typing import Any, Dict, Optional import logging +from .constants import GEN_PARAM_KEYS + logger = logging.getLogger(__name__) + class GenParamsMerger: """Utility to merge generation parameters from multiple sources with priority.""" + ALLOWED_KEYS = set(GEN_PARAM_KEYS) + BLACKLISTED_KEYS = { "id", "url", "userId", "username", "createdAt", "updatedAt", "hash", "meta", "draft", "extra", "width", "height", "process", "quantity", "workflow", "baseModel", "resources", "disablePoi", "aspectRatio", "Created Date", "experimental", "civitaiResources", "civitai_resources", "Civitai resources", "modelVersionId", "modelId", "hashes", "Model", "Model hash", "checkpoint_hash", - "checkpoint", "checksum", "model_checksum" + "checkpoint", "checksum", "model_checksum", "raw_metadata", } - + NORMALIZATION_MAPPING = { - # Civitai specific + "cfg": "cfg_scale", "cfgScale": "cfg_scale", "clipSkip": "clip_skip", "negativePrompt": "negative_prompt", - # Case variations "Sampler": "sampler", + "sampler_name": "sampler", + "scheduler": "sampler", "Steps": "steps", "Seed": "seed", "Size": "size", @@ -36,63 +42,40 @@ class GenParamsMerger: def merge( request_params: Optional[Dict[str, Any]] = None, civitai_meta: Optional[Dict[str, Any]] = None, - embedded_metadata: Optional[Dict[str, Any]] = None + embedded_metadata: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """ Merge generation parameters from three sources. - - Priority: request_params > civitai_meta > embedded_metadata - - Args: - request_params: Params provided directly in the import request - civitai_meta: Params from Civitai Image API 'meta' field - embedded_metadata: Params extracted from image EXIF/embedded metadata - - Returns: - Merged parameters dictionary - """ - result = {} - # 1. Start with embedded metadata (lowest priority) + Priority: request_params > civitai_meta > embedded_metadata + """ + result: Dict[str, Any] = {} + if embedded_metadata: - # If it's a full recipe metadata, we use its gen_params - if "gen_params" in embedded_metadata and isinstance(embedded_metadata["gen_params"], dict): + if "gen_params" in embedded_metadata and isinstance( + embedded_metadata["gen_params"], dict + ): GenParamsMerger._update_normalized(result, embedded_metadata["gen_params"]) else: - # Otherwise assume the dict itself contains gen_params GenParamsMerger._update_normalized(result, embedded_metadata) - # 2. Layer Civitai meta (medium priority) if civitai_meta: GenParamsMerger._update_normalized(result, civitai_meta) - # 3. Layer request params (highest priority) if request_params: GenParamsMerger._update_normalized(result, request_params) - # Filter out blacklisted keys and also the original camelCase keys if they were normalized - final_result = {} - for k, v in result.items(): - if k in GenParamsMerger.BLACKLISTED_KEYS: - continue - if k in GenParamsMerger.NORMALIZATION_MAPPING: - continue - final_result[k] = v - - return final_result + return result @staticmethod def _update_normalized(target: Dict[str, Any], source: Dict[str, Any]) -> None: - """Update target dict with normalized keys from source.""" - for k, v in source.items(): - normalized_key = GenParamsMerger.NORMALIZATION_MAPPING.get(k, k) - target[normalized_key] = v - # Also keep the original key for now if it's not the same, - # so we can filter at the end or avoid losing it if it wasn't supposed to be renamed? - # Actually, if we rename it, we should probably NOT keep both in 'target' - # because we want to filter them out at the end anyway. - if normalized_key != k: - # If we are overwriting an existing snake_case key with a camelCase one's value, - # that's fine because of the priority order of calls to _update_normalized. - pass - target[k] = v + """Update target dict with normalized, persistence-safe keys from source.""" + for key, value in source.items(): + if key in GenParamsMerger.BLACKLISTED_KEYS: + continue + + normalized_key = GenParamsMerger.NORMALIZATION_MAPPING.get(key, key) + if normalized_key not in GenParamsMerger.ALLOWED_KEYS: + continue + + target[normalized_key] = value diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py index 59e354f7..6120cb17 100644 --- a/py/routes/handlers/recipe_handlers.py +++ b/py/routes/handlers/recipe_handlers.py @@ -756,6 +756,14 @@ class RecipeManagementHandler: ) gen_params_request = self._parse_gen_params(params.get("gen_params")) + self._logger.info( + "Remote recipe import received: url=%s, request_gen_params_keys=%s, lora_count=%d, checkpoint_keys=%s", + image_url, + sorted(gen_params_request.keys()) if gen_params_request else [], + len(lora_entries), + sorted(checkpoint_entry.keys()) if isinstance(checkpoint_entry, dict) else [], + ) + # 2. Initial Metadata Construction metadata: Dict[str, Any] = { "base_model": params.get("base_model", "") or "", diff --git a/py/services/recipes/persistence_service.py b/py/services/recipes/persistence_service.py index be307f3b..3c5a7c00 100644 --- a/py/services/recipes/persistence_service.py +++ b/py/services/recipes/persistence_service.py @@ -12,6 +12,7 @@ from dataclasses import dataclass from typing import Any, Dict, Iterable, Optional from ...config import config +from ...recipes.constants import GEN_PARAM_KEYS from ...utils.utils import calculate_recipe_fingerprint from .errors import RecipeNotFoundError, RecipeValidationError @@ -90,23 +91,7 @@ class RecipePersistenceService: current_time = time.time() loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])] checkpoint_entry = self._sanitize_checkpoint_entry(self._extract_checkpoint_entry(metadata)) - - gen_params = metadata.get("gen_params") or {} - if not gen_params and "raw_metadata" in metadata: - raw_metadata = metadata.get("raw_metadata", {}) - gen_params = { - "prompt": raw_metadata.get("prompt", ""), - "negative_prompt": raw_metadata.get("negative_prompt", ""), - "steps": raw_metadata.get("steps", ""), - "sampler": raw_metadata.get("sampler", ""), - "cfg_scale": raw_metadata.get("cfg_scale", ""), - "seed": raw_metadata.get("seed", ""), - "size": raw_metadata.get("size", ""), - "clip_skip": raw_metadata.get("clip_skip", ""), - } - - # Drop checkpoint duplication from generation parameters to store it only at top level - gen_params.pop("checkpoint", None) + gen_params = self._sanitize_gen_params_for_storage(metadata) fingerprint = calculate_recipe_fingerprint(loras_data) recipe_data: Dict[str, Any] = { @@ -133,6 +118,7 @@ class RecipePersistenceService: json_filename = f"{recipe_id}.recipe.json" json_path = os.path.join(recipes_dir, json_filename) json_path = os.path.normpath(json_path) + with open(json_path, "w", encoding="utf-8") as file_obj: json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) @@ -152,6 +138,30 @@ class RecipePersistenceService: } ) + @staticmethod + def _sanitize_gen_params_for_storage(metadata: dict[str, Any]) -> dict[str, Any]: + gen_params = metadata.get("gen_params") + if isinstance(gen_params, dict) and gen_params: + source = gen_params + else: + source = metadata.get("raw_metadata") + + if not isinstance(source, dict): + return {} + + allowed_keys = set(GEN_PARAM_KEYS) + sanitized: dict[str, Any] = {} + for key in allowed_keys: + if key not in source: + continue + value = source.get(key) + if value in (None, ""): + continue + sanitized[key] = value + + sanitized.pop("checkpoint", None) + return sanitized + async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult: """Delete an existing recipe.""" diff --git a/tests/services/test_gen_params_merger.py b/tests/services/test_gen_params_merger.py index 16313f0a..036bab8a 100644 --- a/tests/services/test_gen_params_merger.py +++ b/tests/services/test_gen_params_merger.py @@ -1,95 +1,86 @@ -import pytest from py.recipes.merger import GenParamsMerger -def test_merge_priority(): - request_params = {"prompt": "from request", "steps": 20} - civitai_meta = {"prompt": "from civitai", "cfg": 7.0} + +def test_merge_priority_and_normalization(): + request_params = {"prompt": "from request", "Steps": 20, "cfg": 7.5} + civitai_meta = {"prompt": "from civitai", "cfgScale": 6.5, "negativePrompt": "bad"} embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}} - + merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata) - - assert merged["prompt"] == "from request" - assert merged["steps"] == 20 - assert merged["cfg"] == 7.0 - assert merged["seed"] == 123 -def test_merge_no_request_params(): - civitai_meta = {"prompt": "from civitai", "cfg": 7.0} - embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}} - - merged = GenParamsMerger.merge(None, civitai_meta, embedded_metadata) - - assert merged["prompt"] == "from civitai" - assert merged["cfg"] == 7.0 - assert merged["seed"] == 123 + assert merged == { + "prompt": "from request", + "steps": 20, + "cfg_scale": 7.5, + "negative_prompt": "bad", + "seed": 123, + } + + +def test_merge_accepts_raw_embedded_metadata(): + embedded_metadata = {"prompt": "from raw embedded", "seed": 456, "scheduler": "karras"} -def test_merge_only_embedded(): - embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}} - merged = GenParamsMerger.merge(None, None, embedded_metadata) - - assert merged["prompt"] == "from embedded" - assert merged["seed"] == 123 -def test_merge_raw_embedded(): - # Test when embedded metadata is just the gen_params themselves - embedded_metadata = {"prompt": "from raw embedded", "seed": 456} - - merged = GenParamsMerger.merge(None, None, embedded_metadata) - - assert merged["prompt"] == "from raw embedded" - assert merged["seed"] == 456 + assert merged == { + "prompt": "from raw embedded", + "seed": 456, + "sampler": "karras", + } -def test_merge_none_values(): - merged = GenParamsMerger.merge(None, None, None) - assert merged == {} -def test_merge_filters_blacklisted_keys(): - request_params = {"prompt": "test", "id": "should-be-removed", "checkpoint": "should-not-be-here"} - civitai_meta = {"cfg": 7, "url": "remove-me"} - embedded_metadata = {"seed": 123, "hash": "remove-also"} - - merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata) - - assert "prompt" in merged - assert "cfg" in merged - assert "seed" in merged - assert "id" not in merged - assert "url" not in merged - assert "hash" not in merged - assert "checkpoint" not in merged - -def test_merge_filters_meta_and_normalizes_keys(): +def test_merge_filters_unknown_and_blacklisted_keys(): + request_params = { + "prompt": "test", + "id": "should-be-removed", + "checkpoint": "should-not-be-here", + "raw_metadata": {"prompt": "remove"}, + } + civitai_meta = { + "Version": "ComfyUI", + "RNG": "cpu", + "cfgScale": 7, + "url": "remove-me", + } + embedded_metadata = { + "seed": 123, + "hash": "remove-also", + "Discard penultimate sigma": True, + "eps_scaling_factor": 0.1, + } + + merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata) + + assert merged == { + "prompt": "test", + "cfg_scale": 7, + "seed": 123, + } + + +def test_merge_does_not_keep_original_key_variants(): civitai_meta = { - "prompt": "masterpiece", "cfgScale": 5, "clipSkip": 2, "negativePrompt": "low quality", - "meta": {"irrelevant": "data"}, "Size": "1024x1024", - "draft": False, - "workflow": "txt2img", - "civitaiResources": [{"type": "checkpoint"}] + "Denoising strength": 0.35, } request_params = { "cfg_scale": 5.0, "clip_skip": "2", - "Steps": 30 } - + merged = GenParamsMerger.merge(request_params, civitai_meta) - - assert "meta" not in merged - assert "cfgScale" not in merged - assert "clipSkip" not in merged - assert "negativePrompt" not in merged - assert "Size" not in merged - assert "draft" not in merged - assert "workflow" not in merged - assert "civitaiResources" not in merged - - assert merged["cfg_scale"] == 5.0 # From request_params - assert merged["clip_skip"] == "2" # From request_params - assert merged["negative_prompt"] == "low quality" # Normalized from civitai_meta - assert merged["size"] == "1024x1024" # Normalized from civitai_meta - assert merged["steps"] == 30 # Normalized from request_params + + assert merged == { + "cfg_scale": 5.0, + "clip_skip": "2", + "negative_prompt": "low quality", + "size": "1024x1024", + "denoising_strength": 0.35, + } + + +def test_merge_none_values(): + assert GenParamsMerger.merge(None, None, None) == {} diff --git a/tests/services/test_recipe_services.py b/tests/services/test_recipe_services.py index 94b3ba59..32e93ee1 100644 --- a/tests/services/test_recipe_services.py +++ b/tests/services/test_recipe_services.py @@ -306,6 +306,120 @@ async def test_save_recipe_promotes_checkpoint_from_gen_params(tmp_path): assert "checkpoint" not in stored["gen_params"] +@pytest.mark.asyncio +async def test_save_recipe_strips_non_persistable_gen_params(tmp_path): + exif_utils = DummyExifUtils() + + class DummyScanner: + def __init__(self, root): + self.recipes_dir = str(root) + + async def find_recipes_by_fingerprint(self, fingerprint): + return [] + + async def add_recipe(self, recipe_data): + return None + + scanner = DummyScanner(tmp_path) + service = RecipePersistenceService( + exif_utils=exif_utils, + card_preview_width=512, + logger=logging.getLogger("test"), + ) + + metadata = { + "base_model": "Flux", + "loras": [], + "gen_params": { + "prompt": "hello world", + "negative_prompt": "bad hands", + "cfg_scale": 7, + "raw_metadata": {"prompt": "should not persist"}, + "Version": "ComfyUI", + "RNG": "cpu", + "Schedule type": "karras", + "Discard penultimate sigma": True, + "eps_scaling_factor": 0.1, + }, + } + + result = await service.save_recipe( + recipe_scanner=scanner, + image_bytes=b"img", + image_base64=None, + name="Sanitized", + tags=[], + metadata=metadata, + ) + + stored = json.loads(Path(result.payload["json_path"]).read_text()) + assert stored["gen_params"] == { + "prompt": "hello world", + "negative_prompt": "bad hands", + "cfg_scale": 7, + } + + +@pytest.mark.asyncio +async def test_save_recipe_derives_allowed_fields_from_raw_metadata(tmp_path): + exif_utils = DummyExifUtils() + + class DummyScanner: + def __init__(self, root): + self.recipes_dir = str(root) + + async def find_recipes_by_fingerprint(self, fingerprint): + return [] + + async def add_recipe(self, recipe_data): + return None + + scanner = DummyScanner(tmp_path) + service = RecipePersistenceService( + exif_utils=exif_utils, + card_preview_width=512, + logger=logging.getLogger("test"), + ) + + metadata = { + "base_model": "Flux", + "loras": [], + "raw_metadata": { + "prompt": "hello world", + "negative_prompt": "bad hands", + "steps": 30, + "sampler": "Euler", + "cfg_scale": 7, + "seed": 123, + "size": "1024x1024", + "clip_skip": 2, + "Version": "ComfyUI", + "raw_metadata": {"nested": True}, + }, + } + + result = await service.save_recipe( + recipe_scanner=scanner, + image_bytes=b"img", + image_base64=None, + name="Derived", + tags=[], + metadata=metadata, + ) + + stored = json.loads(Path(result.payload["json_path"]).read_text()) + assert stored["gen_params"] == { + "prompt": "hello world", + "negative_prompt": "bad hands", + "steps": 30, + "sampler": "Euler", + "cfg_scale": 7, + "seed": 123, + "size": "1024x1024", + "clip_skip": 2, + } + + @pytest.mark.asyncio async def test_save_recipe_strips_checkpoint_local_fields(tmp_path): exif_utils = DummyExifUtils()