fix(recipes): sanitize remote import gen params

This commit is contained in:
Will Miao
2026-04-12 20:29:01 +08:00
parent 0253d001e6
commit 55e9e4bb6f
6 changed files with 244 additions and 137 deletions

View File

@@ -13,4 +13,5 @@ GEN_PARAM_KEYS = [
'seed', 'seed',
'size', 'size',
'clip_skip', 'clip_skip',
'denoising_strength',
] ]

View File

@@ -1,27 +1,33 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import logging import logging
from .constants import GEN_PARAM_KEYS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class GenParamsMerger: class GenParamsMerger:
"""Utility to merge generation parameters from multiple sources with priority.""" """Utility to merge generation parameters from multiple sources with priority."""
ALLOWED_KEYS = set(GEN_PARAM_KEYS)
BLACKLISTED_KEYS = { BLACKLISTED_KEYS = {
"id", "url", "userId", "username", "createdAt", "updatedAt", "hash", "meta", "id", "url", "userId", "username", "createdAt", "updatedAt", "hash", "meta",
"draft", "extra", "width", "height", "process", "quantity", "workflow", "draft", "extra", "width", "height", "process", "quantity", "workflow",
"baseModel", "resources", "disablePoi", "aspectRatio", "Created Date", "baseModel", "resources", "disablePoi", "aspectRatio", "Created Date",
"experimental", "civitaiResources", "civitai_resources", "Civitai resources", "experimental", "civitaiResources", "civitai_resources", "Civitai resources",
"modelVersionId", "modelId", "hashes", "Model", "Model hash", "checkpoint_hash", "modelVersionId", "modelId", "hashes", "Model", "Model hash", "checkpoint_hash",
"checkpoint", "checksum", "model_checksum" "checkpoint", "checksum", "model_checksum", "raw_metadata",
} }
NORMALIZATION_MAPPING = { NORMALIZATION_MAPPING = {
# Civitai specific "cfg": "cfg_scale",
"cfgScale": "cfg_scale", "cfgScale": "cfg_scale",
"clipSkip": "clip_skip", "clipSkip": "clip_skip",
"negativePrompt": "negative_prompt", "negativePrompt": "negative_prompt",
# Case variations
"Sampler": "sampler", "Sampler": "sampler",
"sampler_name": "sampler",
"scheduler": "sampler",
"Steps": "steps", "Steps": "steps",
"Seed": "seed", "Seed": "seed",
"Size": "size", "Size": "size",
@@ -36,63 +42,40 @@ class GenParamsMerger:
def merge( def merge(
request_params: Optional[Dict[str, Any]] = None, request_params: Optional[Dict[str, Any]] = None,
civitai_meta: Optional[Dict[str, Any]] = None, civitai_meta: Optional[Dict[str, Any]] = None,
embedded_metadata: Optional[Dict[str, Any]] = None embedded_metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Merge generation parameters from three sources. Merge generation parameters from three sources.
Priority: request_params > civitai_meta > embedded_metadata Priority: request_params > civitai_meta > embedded_metadata
Args:
request_params: Params provided directly in the import request
civitai_meta: Params from Civitai Image API 'meta' field
embedded_metadata: Params extracted from image EXIF/embedded metadata
Returns:
Merged parameters dictionary
""" """
result = {} result: Dict[str, Any] = {}
# 1. Start with embedded metadata (lowest priority)
if embedded_metadata: if embedded_metadata:
# If it's a full recipe metadata, we use its gen_params if "gen_params" in embedded_metadata and isinstance(
if "gen_params" in embedded_metadata and isinstance(embedded_metadata["gen_params"], dict): embedded_metadata["gen_params"], dict
):
GenParamsMerger._update_normalized(result, embedded_metadata["gen_params"]) GenParamsMerger._update_normalized(result, embedded_metadata["gen_params"])
else: else:
# Otherwise assume the dict itself contains gen_params
GenParamsMerger._update_normalized(result, embedded_metadata) GenParamsMerger._update_normalized(result, embedded_metadata)
# 2. Layer Civitai meta (medium priority)
if civitai_meta: if civitai_meta:
GenParamsMerger._update_normalized(result, civitai_meta) GenParamsMerger._update_normalized(result, civitai_meta)
# 3. Layer request params (highest priority)
if request_params: if request_params:
GenParamsMerger._update_normalized(result, request_params) GenParamsMerger._update_normalized(result, request_params)
# Filter out blacklisted keys and also the original camelCase keys if they were normalized return result
final_result = {}
for k, v in result.items():
if k in GenParamsMerger.BLACKLISTED_KEYS:
continue
if k in GenParamsMerger.NORMALIZATION_MAPPING:
continue
final_result[k] = v
return final_result
@staticmethod @staticmethod
def _update_normalized(target: Dict[str, Any], source: Dict[str, Any]) -> None: def _update_normalized(target: Dict[str, Any], source: Dict[str, Any]) -> None:
"""Update target dict with normalized keys from source.""" """Update target dict with normalized, persistence-safe keys from source."""
for k, v in source.items(): for key, value in source.items():
normalized_key = GenParamsMerger.NORMALIZATION_MAPPING.get(k, k) if key in GenParamsMerger.BLACKLISTED_KEYS:
target[normalized_key] = v continue
# Also keep the original key for now if it's not the same,
# so we can filter at the end or avoid losing it if it wasn't supposed to be renamed? normalized_key = GenParamsMerger.NORMALIZATION_MAPPING.get(key, key)
# Actually, if we rename it, we should probably NOT keep both in 'target' if normalized_key not in GenParamsMerger.ALLOWED_KEYS:
# because we want to filter them out at the end anyway. continue
if normalized_key != k:
# If we are overwriting an existing snake_case key with a camelCase one's value, target[normalized_key] = value
# that's fine because of the priority order of calls to _update_normalized.
pass
target[k] = v

View File

@@ -756,6 +756,14 @@ class RecipeManagementHandler:
) )
gen_params_request = self._parse_gen_params(params.get("gen_params")) gen_params_request = self._parse_gen_params(params.get("gen_params"))
self._logger.info(
"Remote recipe import received: url=%s, request_gen_params_keys=%s, lora_count=%d, checkpoint_keys=%s",
image_url,
sorted(gen_params_request.keys()) if gen_params_request else [],
len(lora_entries),
sorted(checkpoint_entry.keys()) if isinstance(checkpoint_entry, dict) else [],
)
# 2. Initial Metadata Construction # 2. Initial Metadata Construction
metadata: Dict[str, Any] = { metadata: Dict[str, Any] = {
"base_model": params.get("base_model", "") or "", "base_model": params.get("base_model", "") or "",

View File

@@ -12,6 +12,7 @@ from dataclasses import dataclass
from typing import Any, Dict, Iterable, Optional from typing import Any, Dict, Iterable, Optional
from ...config import config from ...config import config
from ...recipes.constants import GEN_PARAM_KEYS
from ...utils.utils import calculate_recipe_fingerprint from ...utils.utils import calculate_recipe_fingerprint
from .errors import RecipeNotFoundError, RecipeValidationError from .errors import RecipeNotFoundError, RecipeValidationError
@@ -90,23 +91,7 @@ class RecipePersistenceService:
current_time = time.time() current_time = time.time()
loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])] loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])]
checkpoint_entry = self._sanitize_checkpoint_entry(self._extract_checkpoint_entry(metadata)) checkpoint_entry = self._sanitize_checkpoint_entry(self._extract_checkpoint_entry(metadata))
gen_params = self._sanitize_gen_params_for_storage(metadata)
gen_params = metadata.get("gen_params") or {}
if not gen_params and "raw_metadata" in metadata:
raw_metadata = metadata.get("raw_metadata", {})
gen_params = {
"prompt": raw_metadata.get("prompt", ""),
"negative_prompt": raw_metadata.get("negative_prompt", ""),
"steps": raw_metadata.get("steps", ""),
"sampler": raw_metadata.get("sampler", ""),
"cfg_scale": raw_metadata.get("cfg_scale", ""),
"seed": raw_metadata.get("seed", ""),
"size": raw_metadata.get("size", ""),
"clip_skip": raw_metadata.get("clip_skip", ""),
}
# Drop checkpoint duplication from generation parameters to store it only at top level
gen_params.pop("checkpoint", None)
fingerprint = calculate_recipe_fingerprint(loras_data) fingerprint = calculate_recipe_fingerprint(loras_data)
recipe_data: Dict[str, Any] = { recipe_data: Dict[str, Any] = {
@@ -133,6 +118,7 @@ class RecipePersistenceService:
json_filename = f"{recipe_id}.recipe.json" json_filename = f"{recipe_id}.recipe.json"
json_path = os.path.join(recipes_dir, json_filename) json_path = os.path.join(recipes_dir, json_filename)
json_path = os.path.normpath(json_path) json_path = os.path.normpath(json_path)
with open(json_path, "w", encoding="utf-8") as file_obj: with open(json_path, "w", encoding="utf-8") as file_obj:
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
@@ -152,6 +138,30 @@ class RecipePersistenceService:
} }
) )
@staticmethod
def _sanitize_gen_params_for_storage(metadata: dict[str, Any]) -> dict[str, Any]:
gen_params = metadata.get("gen_params")
if isinstance(gen_params, dict) and gen_params:
source = gen_params
else:
source = metadata.get("raw_metadata")
if not isinstance(source, dict):
return {}
allowed_keys = set(GEN_PARAM_KEYS)
sanitized: dict[str, Any] = {}
for key in allowed_keys:
if key not in source:
continue
value = source.get(key)
if value in (None, ""):
continue
sanitized[key] = value
sanitized.pop("checkpoint", None)
return sanitized
async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult: async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult:
"""Delete an existing recipe.""" """Delete an existing recipe."""

View File

@@ -1,95 +1,86 @@
import pytest
from py.recipes.merger import GenParamsMerger from py.recipes.merger import GenParamsMerger
def test_merge_priority():
request_params = {"prompt": "from request", "steps": 20} def test_merge_priority_and_normalization():
civitai_meta = {"prompt": "from civitai", "cfg": 7.0} request_params = {"prompt": "from request", "Steps": 20, "cfg": 7.5}
civitai_meta = {"prompt": "from civitai", "cfgScale": 6.5, "negativePrompt": "bad"}
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}} embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata) merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
assert merged["prompt"] == "from request" assert merged == {
assert merged["steps"] == 20 "prompt": "from request",
assert merged["cfg"] == 7.0 "steps": 20,
assert merged["seed"] == 123 "cfg_scale": 7.5,
"negative_prompt": "bad",
"seed": 123,
}
def test_merge_no_request_params():
civitai_meta = {"prompt": "from civitai", "cfg": 7.0}
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
merged = GenParamsMerger.merge(None, civitai_meta, embedded_metadata) def test_merge_accepts_raw_embedded_metadata():
embedded_metadata = {"prompt": "from raw embedded", "seed": 456, "scheduler": "karras"}
assert merged["prompt"] == "from civitai"
assert merged["cfg"] == 7.0
assert merged["seed"] == 123
def test_merge_only_embedded():
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
merged = GenParamsMerger.merge(None, None, embedded_metadata) merged = GenParamsMerger.merge(None, None, embedded_metadata)
assert merged["prompt"] == "from embedded" assert merged == {
assert merged["seed"] == 123 "prompt": "from raw embedded",
"seed": 456,
"sampler": "karras",
}
def test_merge_raw_embedded():
# Test when embedded metadata is just the gen_params themselves
embedded_metadata = {"prompt": "from raw embedded", "seed": 456}
merged = GenParamsMerger.merge(None, None, embedded_metadata) def test_merge_filters_unknown_and_blacklisted_keys():
request_params = {
assert merged["prompt"] == "from raw embedded" "prompt": "test",
assert merged["seed"] == 456 "id": "should-be-removed",
"checkpoint": "should-not-be-here",
def test_merge_none_values(): "raw_metadata": {"prompt": "remove"},
merged = GenParamsMerger.merge(None, None, None) }
assert merged == {} civitai_meta = {
"Version": "ComfyUI",
def test_merge_filters_blacklisted_keys(): "RNG": "cpu",
request_params = {"prompt": "test", "id": "should-be-removed", "checkpoint": "should-not-be-here"} "cfgScale": 7,
civitai_meta = {"cfg": 7, "url": "remove-me"} "url": "remove-me",
embedded_metadata = {"seed": 123, "hash": "remove-also"} }
embedded_metadata = {
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata) "seed": 123,
"hash": "remove-also",
assert "prompt" in merged "Discard penultimate sigma": True,
assert "cfg" in merged "eps_scaling_factor": 0.1,
assert "seed" in merged }
assert "id" not in merged
assert "url" not in merged merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
assert "hash" not in merged
assert "checkpoint" not in merged assert merged == {
"prompt": "test",
def test_merge_filters_meta_and_normalizes_keys(): "cfg_scale": 7,
"seed": 123,
}
def test_merge_does_not_keep_original_key_variants():
civitai_meta = { civitai_meta = {
"prompt": "masterpiece",
"cfgScale": 5, "cfgScale": 5,
"clipSkip": 2, "clipSkip": 2,
"negativePrompt": "low quality", "negativePrompt": "low quality",
"meta": {"irrelevant": "data"},
"Size": "1024x1024", "Size": "1024x1024",
"draft": False, "Denoising strength": 0.35,
"workflow": "txt2img",
"civitaiResources": [{"type": "checkpoint"}]
} }
request_params = { request_params = {
"cfg_scale": 5.0, "cfg_scale": 5.0,
"clip_skip": "2", "clip_skip": "2",
"Steps": 30
} }
merged = GenParamsMerger.merge(request_params, civitai_meta) merged = GenParamsMerger.merge(request_params, civitai_meta)
assert "meta" not in merged assert merged == {
assert "cfgScale" not in merged "cfg_scale": 5.0,
assert "clipSkip" not in merged "clip_skip": "2",
assert "negativePrompt" not in merged "negative_prompt": "low quality",
assert "Size" not in merged "size": "1024x1024",
assert "draft" not in merged "denoising_strength": 0.35,
assert "workflow" not in merged }
assert "civitaiResources" not in merged
assert merged["cfg_scale"] == 5.0 # From request_params
assert merged["clip_skip"] == "2" # From request_params def test_merge_none_values():
assert merged["negative_prompt"] == "low quality" # Normalized from civitai_meta assert GenParamsMerger.merge(None, None, None) == {}
assert merged["size"] == "1024x1024" # Normalized from civitai_meta
assert merged["steps"] == 30 # Normalized from request_params

View File

@@ -306,6 +306,120 @@ async def test_save_recipe_promotes_checkpoint_from_gen_params(tmp_path):
assert "checkpoint" not in stored["gen_params"] assert "checkpoint" not in stored["gen_params"]
@pytest.mark.asyncio
async def test_save_recipe_strips_non_persistable_gen_params(tmp_path):
exif_utils = DummyExifUtils()
class DummyScanner:
def __init__(self, root):
self.recipes_dir = str(root)
async def find_recipes_by_fingerprint(self, fingerprint):
return []
async def add_recipe(self, recipe_data):
return None
scanner = DummyScanner(tmp_path)
service = RecipePersistenceService(
exif_utils=exif_utils,
card_preview_width=512,
logger=logging.getLogger("test"),
)
metadata = {
"base_model": "Flux",
"loras": [],
"gen_params": {
"prompt": "hello world",
"negative_prompt": "bad hands",
"cfg_scale": 7,
"raw_metadata": {"prompt": "should not persist"},
"Version": "ComfyUI",
"RNG": "cpu",
"Schedule type": "karras",
"Discard penultimate sigma": True,
"eps_scaling_factor": 0.1,
},
}
result = await service.save_recipe(
recipe_scanner=scanner,
image_bytes=b"img",
image_base64=None,
name="Sanitized",
tags=[],
metadata=metadata,
)
stored = json.loads(Path(result.payload["json_path"]).read_text())
assert stored["gen_params"] == {
"prompt": "hello world",
"negative_prompt": "bad hands",
"cfg_scale": 7,
}
@pytest.mark.asyncio
async def test_save_recipe_derives_allowed_fields_from_raw_metadata(tmp_path):
exif_utils = DummyExifUtils()
class DummyScanner:
def __init__(self, root):
self.recipes_dir = str(root)
async def find_recipes_by_fingerprint(self, fingerprint):
return []
async def add_recipe(self, recipe_data):
return None
scanner = DummyScanner(tmp_path)
service = RecipePersistenceService(
exif_utils=exif_utils,
card_preview_width=512,
logger=logging.getLogger("test"),
)
metadata = {
"base_model": "Flux",
"loras": [],
"raw_metadata": {
"prompt": "hello world",
"negative_prompt": "bad hands",
"steps": 30,
"sampler": "Euler",
"cfg_scale": 7,
"seed": 123,
"size": "1024x1024",
"clip_skip": 2,
"Version": "ComfyUI",
"raw_metadata": {"nested": True},
},
}
result = await service.save_recipe(
recipe_scanner=scanner,
image_bytes=b"img",
image_base64=None,
name="Derived",
tags=[],
metadata=metadata,
)
stored = json.loads(Path(result.payload["json_path"]).read_text())
assert stored["gen_params"] == {
"prompt": "hello world",
"negative_prompt": "bad hands",
"steps": 30,
"sampler": "Euler",
"cfg_scale": 7,
"seed": 123,
"size": "1024x1024",
"clip_skip": 2,
}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_save_recipe_strips_checkpoint_local_fields(tmp_path): async def test_save_recipe_strips_checkpoint_local_fields(tmp_path):
exif_utils = DummyExifUtils() exif_utils = DummyExifUtils()