fix(recipes): sanitize remote import gen params

This commit is contained in:
Will Miao
2026-04-12 20:29:01 +08:00
parent 0253d001e6
commit 55e9e4bb6f
6 changed files with 244 additions and 137 deletions

View File

@@ -13,4 +13,5 @@ GEN_PARAM_KEYS = [
'seed',
'size',
'clip_skip',
'denoising_strength',
]

View File

@@ -1,27 +1,33 @@
from typing import Any, Dict, Optional
import logging
from .constants import GEN_PARAM_KEYS
logger = logging.getLogger(__name__)
class GenParamsMerger:
"""Utility to merge generation parameters from multiple sources with priority."""
ALLOWED_KEYS = set(GEN_PARAM_KEYS)
BLACKLISTED_KEYS = {
"id", "url", "userId", "username", "createdAt", "updatedAt", "hash", "meta",
"draft", "extra", "width", "height", "process", "quantity", "workflow",
"baseModel", "resources", "disablePoi", "aspectRatio", "Created Date",
"experimental", "civitaiResources", "civitai_resources", "Civitai resources",
"modelVersionId", "modelId", "hashes", "Model", "Model hash", "checkpoint_hash",
"checkpoint", "checksum", "model_checksum"
"checkpoint", "checksum", "model_checksum", "raw_metadata",
}
NORMALIZATION_MAPPING = {
# Civitai specific
"cfg": "cfg_scale",
"cfgScale": "cfg_scale",
"clipSkip": "clip_skip",
"negativePrompt": "negative_prompt",
# Case variations
"Sampler": "sampler",
"sampler_name": "sampler",
"scheduler": "sampler",
"Steps": "steps",
"Seed": "seed",
"Size": "size",
@@ -36,63 +42,40 @@ class GenParamsMerger:
def merge(
request_params: Optional[Dict[str, Any]] = None,
civitai_meta: Optional[Dict[str, Any]] = None,
embedded_metadata: Optional[Dict[str, Any]] = None
embedded_metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
Merge generation parameters from three sources.
Priority: request_params > civitai_meta > embedded_metadata
Args:
request_params: Params provided directly in the import request
civitai_meta: Params from Civitai Image API 'meta' field
embedded_metadata: Params extracted from image EXIF/embedded metadata
Returns:
Merged parameters dictionary
"""
result = {}
result: Dict[str, Any] = {}
# 1. Start with embedded metadata (lowest priority)
if embedded_metadata:
# If it's a full recipe metadata, we use its gen_params
if "gen_params" in embedded_metadata and isinstance(embedded_metadata["gen_params"], dict):
if "gen_params" in embedded_metadata and isinstance(
embedded_metadata["gen_params"], dict
):
GenParamsMerger._update_normalized(result, embedded_metadata["gen_params"])
else:
# Otherwise assume the dict itself contains gen_params
GenParamsMerger._update_normalized(result, embedded_metadata)
# 2. Layer Civitai meta (medium priority)
if civitai_meta:
GenParamsMerger._update_normalized(result, civitai_meta)
# 3. Layer request params (highest priority)
if request_params:
GenParamsMerger._update_normalized(result, request_params)
# Filter out blacklisted keys and also the original camelCase keys if they were normalized
final_result = {}
for k, v in result.items():
if k in GenParamsMerger.BLACKLISTED_KEYS:
continue
if k in GenParamsMerger.NORMALIZATION_MAPPING:
continue
final_result[k] = v
return final_result
return result
@staticmethod
def _update_normalized(target: Dict[str, Any], source: Dict[str, Any]) -> None:
"""Update target dict with normalized keys from source."""
for k, v in source.items():
normalized_key = GenParamsMerger.NORMALIZATION_MAPPING.get(k, k)
target[normalized_key] = v
# Also keep the original key for now if it's not the same,
# so we can filter at the end or avoid losing it if it wasn't supposed to be renamed?
# Actually, if we rename it, we should probably NOT keep both in 'target'
# because we want to filter them out at the end anyway.
if normalized_key != k:
# If we are overwriting an existing snake_case key with a camelCase one's value,
# that's fine because of the priority order of calls to _update_normalized.
pass
target[k] = v
"""Update target dict with normalized, persistence-safe keys from source."""
for key, value in source.items():
if key in GenParamsMerger.BLACKLISTED_KEYS:
continue
normalized_key = GenParamsMerger.NORMALIZATION_MAPPING.get(key, key)
if normalized_key not in GenParamsMerger.ALLOWED_KEYS:
continue
target[normalized_key] = value

View File

@@ -756,6 +756,14 @@ class RecipeManagementHandler:
)
gen_params_request = self._parse_gen_params(params.get("gen_params"))
self._logger.info(
"Remote recipe import received: url=%s, request_gen_params_keys=%s, lora_count=%d, checkpoint_keys=%s",
image_url,
sorted(gen_params_request.keys()) if gen_params_request else [],
len(lora_entries),
sorted(checkpoint_entry.keys()) if isinstance(checkpoint_entry, dict) else [],
)
# 2. Initial Metadata Construction
metadata: Dict[str, Any] = {
"base_model": params.get("base_model", "") or "",

View File

@@ -12,6 +12,7 @@ from dataclasses import dataclass
from typing import Any, Dict, Iterable, Optional
from ...config import config
from ...recipes.constants import GEN_PARAM_KEYS
from ...utils.utils import calculate_recipe_fingerprint
from .errors import RecipeNotFoundError, RecipeValidationError
@@ -90,23 +91,7 @@ class RecipePersistenceService:
current_time = time.time()
loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])]
checkpoint_entry = self._sanitize_checkpoint_entry(self._extract_checkpoint_entry(metadata))
gen_params = metadata.get("gen_params") or {}
if not gen_params and "raw_metadata" in metadata:
raw_metadata = metadata.get("raw_metadata", {})
gen_params = {
"prompt": raw_metadata.get("prompt", ""),
"negative_prompt": raw_metadata.get("negative_prompt", ""),
"steps": raw_metadata.get("steps", ""),
"sampler": raw_metadata.get("sampler", ""),
"cfg_scale": raw_metadata.get("cfg_scale", ""),
"seed": raw_metadata.get("seed", ""),
"size": raw_metadata.get("size", ""),
"clip_skip": raw_metadata.get("clip_skip", ""),
}
# Drop checkpoint duplication from generation parameters to store it only at top level
gen_params.pop("checkpoint", None)
gen_params = self._sanitize_gen_params_for_storage(metadata)
fingerprint = calculate_recipe_fingerprint(loras_data)
recipe_data: Dict[str, Any] = {
@@ -133,6 +118,7 @@ class RecipePersistenceService:
json_filename = f"{recipe_id}.recipe.json"
json_path = os.path.join(recipes_dir, json_filename)
json_path = os.path.normpath(json_path)
with open(json_path, "w", encoding="utf-8") as file_obj:
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
@@ -152,6 +138,30 @@ class RecipePersistenceService:
}
)
@staticmethod
def _sanitize_gen_params_for_storage(metadata: dict[str, Any]) -> dict[str, Any]:
gen_params = metadata.get("gen_params")
if isinstance(gen_params, dict) and gen_params:
source = gen_params
else:
source = metadata.get("raw_metadata")
if not isinstance(source, dict):
return {}
allowed_keys = set(GEN_PARAM_KEYS)
sanitized: dict[str, Any] = {}
for key in allowed_keys:
if key not in source:
continue
value = source.get(key)
if value in (None, ""):
continue
sanitized[key] = value
sanitized.pop("checkpoint", None)
return sanitized
async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult:
"""Delete an existing recipe."""

View File

@@ -1,95 +1,86 @@
import pytest
from py.recipes.merger import GenParamsMerger
def test_merge_priority():
request_params = {"prompt": "from request", "steps": 20}
civitai_meta = {"prompt": "from civitai", "cfg": 7.0}
def test_merge_priority_and_normalization():
request_params = {"prompt": "from request", "Steps": 20, "cfg": 7.5}
civitai_meta = {"prompt": "from civitai", "cfgScale": 6.5, "negativePrompt": "bad"}
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
assert merged["prompt"] == "from request"
assert merged["steps"] == 20
assert merged["cfg"] == 7.0
assert merged["seed"] == 123
assert merged == {
"prompt": "from request",
"steps": 20,
"cfg_scale": 7.5,
"negative_prompt": "bad",
"seed": 123,
}
def test_merge_no_request_params():
civitai_meta = {"prompt": "from civitai", "cfg": 7.0}
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
merged = GenParamsMerger.merge(None, civitai_meta, embedded_metadata)
assert merged["prompt"] == "from civitai"
assert merged["cfg"] == 7.0
assert merged["seed"] == 123
def test_merge_only_embedded():
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
def test_merge_accepts_raw_embedded_metadata():
embedded_metadata = {"prompt": "from raw embedded", "seed": 456, "scheduler": "karras"}
merged = GenParamsMerger.merge(None, None, embedded_metadata)
assert merged["prompt"] == "from embedded"
assert merged["seed"] == 123
assert merged == {
"prompt": "from raw embedded",
"seed": 456,
"sampler": "karras",
}
def test_merge_raw_embedded():
# Test when embedded metadata is just the gen_params themselves
embedded_metadata = {"prompt": "from raw embedded", "seed": 456}
merged = GenParamsMerger.merge(None, None, embedded_metadata)
assert merged["prompt"] == "from raw embedded"
assert merged["seed"] == 456
def test_merge_none_values():
merged = GenParamsMerger.merge(None, None, None)
assert merged == {}
def test_merge_filters_blacklisted_keys():
request_params = {"prompt": "test", "id": "should-be-removed", "checkpoint": "should-not-be-here"}
civitai_meta = {"cfg": 7, "url": "remove-me"}
embedded_metadata = {"seed": 123, "hash": "remove-also"}
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
assert "prompt" in merged
assert "cfg" in merged
assert "seed" in merged
assert "id" not in merged
assert "url" not in merged
assert "hash" not in merged
assert "checkpoint" not in merged
def test_merge_filters_meta_and_normalizes_keys():
def test_merge_filters_unknown_and_blacklisted_keys():
request_params = {
"prompt": "test",
"id": "should-be-removed",
"checkpoint": "should-not-be-here",
"raw_metadata": {"prompt": "remove"},
}
civitai_meta = {
"Version": "ComfyUI",
"RNG": "cpu",
"cfgScale": 7,
"url": "remove-me",
}
embedded_metadata = {
"seed": 123,
"hash": "remove-also",
"Discard penultimate sigma": True,
"eps_scaling_factor": 0.1,
}
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
assert merged == {
"prompt": "test",
"cfg_scale": 7,
"seed": 123,
}
def test_merge_does_not_keep_original_key_variants():
civitai_meta = {
"prompt": "masterpiece",
"cfgScale": 5,
"clipSkip": 2,
"negativePrompt": "low quality",
"meta": {"irrelevant": "data"},
"Size": "1024x1024",
"draft": False,
"workflow": "txt2img",
"civitaiResources": [{"type": "checkpoint"}]
"Denoising strength": 0.35,
}
request_params = {
"cfg_scale": 5.0,
"clip_skip": "2",
"Steps": 30
}
merged = GenParamsMerger.merge(request_params, civitai_meta)
assert "meta" not in merged
assert "cfgScale" not in merged
assert "clipSkip" not in merged
assert "negativePrompt" not in merged
assert "Size" not in merged
assert "draft" not in merged
assert "workflow" not in merged
assert "civitaiResources" not in merged
assert merged == {
"cfg_scale": 5.0,
"clip_skip": "2",
"negative_prompt": "low quality",
"size": "1024x1024",
"denoising_strength": 0.35,
}
assert merged["cfg_scale"] == 5.0 # From request_params
assert merged["clip_skip"] == "2" # From request_params
assert merged["negative_prompt"] == "low quality" # Normalized from civitai_meta
assert merged["size"] == "1024x1024" # Normalized from civitai_meta
assert merged["steps"] == 30 # Normalized from request_params
def test_merge_none_values():
assert GenParamsMerger.merge(None, None, None) == {}

View File

@@ -306,6 +306,120 @@ async def test_save_recipe_promotes_checkpoint_from_gen_params(tmp_path):
assert "checkpoint" not in stored["gen_params"]
@pytest.mark.asyncio
async def test_save_recipe_strips_non_persistable_gen_params(tmp_path):
exif_utils = DummyExifUtils()
class DummyScanner:
def __init__(self, root):
self.recipes_dir = str(root)
async def find_recipes_by_fingerprint(self, fingerprint):
return []
async def add_recipe(self, recipe_data):
return None
scanner = DummyScanner(tmp_path)
service = RecipePersistenceService(
exif_utils=exif_utils,
card_preview_width=512,
logger=logging.getLogger("test"),
)
metadata = {
"base_model": "Flux",
"loras": [],
"gen_params": {
"prompt": "hello world",
"negative_prompt": "bad hands",
"cfg_scale": 7,
"raw_metadata": {"prompt": "should not persist"},
"Version": "ComfyUI",
"RNG": "cpu",
"Schedule type": "karras",
"Discard penultimate sigma": True,
"eps_scaling_factor": 0.1,
},
}
result = await service.save_recipe(
recipe_scanner=scanner,
image_bytes=b"img",
image_base64=None,
name="Sanitized",
tags=[],
metadata=metadata,
)
stored = json.loads(Path(result.payload["json_path"]).read_text())
assert stored["gen_params"] == {
"prompt": "hello world",
"negative_prompt": "bad hands",
"cfg_scale": 7,
}
@pytest.mark.asyncio
async def test_save_recipe_derives_allowed_fields_from_raw_metadata(tmp_path):
exif_utils = DummyExifUtils()
class DummyScanner:
def __init__(self, root):
self.recipes_dir = str(root)
async def find_recipes_by_fingerprint(self, fingerprint):
return []
async def add_recipe(self, recipe_data):
return None
scanner = DummyScanner(tmp_path)
service = RecipePersistenceService(
exif_utils=exif_utils,
card_preview_width=512,
logger=logging.getLogger("test"),
)
metadata = {
"base_model": "Flux",
"loras": [],
"raw_metadata": {
"prompt": "hello world",
"negative_prompt": "bad hands",
"steps": 30,
"sampler": "Euler",
"cfg_scale": 7,
"seed": 123,
"size": "1024x1024",
"clip_skip": 2,
"Version": "ComfyUI",
"raw_metadata": {"nested": True},
},
}
result = await service.save_recipe(
recipe_scanner=scanner,
image_bytes=b"img",
image_base64=None,
name="Derived",
tags=[],
metadata=metadata,
)
stored = json.loads(Path(result.payload["json_path"]).read_text())
assert stored["gen_params"] == {
"prompt": "hello world",
"negative_prompt": "bad hands",
"steps": 30,
"sampler": "Euler",
"cfg_scale": 7,
"seed": 123,
"size": "1024x1024",
"clip_skip": 2,
}
@pytest.mark.asyncio
async def test_save_recipe_strips_checkpoint_local_fields(tmp_path):
exif_utils = DummyExifUtils()