mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-07 00:46:44 -03:00
fix(recipes): sanitize remote import gen params
This commit is contained in:
@@ -13,4 +13,5 @@ GEN_PARAM_KEYS = [
|
||||
'seed',
|
||||
'size',
|
||||
'clip_skip',
|
||||
'denoising_strength',
|
||||
]
|
||||
|
||||
@@ -1,27 +1,33 @@
|
||||
from typing import Any, Dict, Optional
|
||||
import logging
|
||||
|
||||
from .constants import GEN_PARAM_KEYS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenParamsMerger:
|
||||
"""Utility to merge generation parameters from multiple sources with priority."""
|
||||
|
||||
ALLOWED_KEYS = set(GEN_PARAM_KEYS)
|
||||
|
||||
BLACKLISTED_KEYS = {
|
||||
"id", "url", "userId", "username", "createdAt", "updatedAt", "hash", "meta",
|
||||
"draft", "extra", "width", "height", "process", "quantity", "workflow",
|
||||
"baseModel", "resources", "disablePoi", "aspectRatio", "Created Date",
|
||||
"experimental", "civitaiResources", "civitai_resources", "Civitai resources",
|
||||
"modelVersionId", "modelId", "hashes", "Model", "Model hash", "checkpoint_hash",
|
||||
"checkpoint", "checksum", "model_checksum"
|
||||
"checkpoint", "checksum", "model_checksum", "raw_metadata",
|
||||
}
|
||||
|
||||
NORMALIZATION_MAPPING = {
|
||||
# Civitai specific
|
||||
"cfg": "cfg_scale",
|
||||
"cfgScale": "cfg_scale",
|
||||
"clipSkip": "clip_skip",
|
||||
"negativePrompt": "negative_prompt",
|
||||
# Case variations
|
||||
"Sampler": "sampler",
|
||||
"sampler_name": "sampler",
|
||||
"scheduler": "sampler",
|
||||
"Steps": "steps",
|
||||
"Seed": "seed",
|
||||
"Size": "size",
|
||||
@@ -36,63 +42,40 @@ class GenParamsMerger:
|
||||
def merge(
|
||||
request_params: Optional[Dict[str, Any]] = None,
|
||||
civitai_meta: Optional[Dict[str, Any]] = None,
|
||||
embedded_metadata: Optional[Dict[str, Any]] = None
|
||||
embedded_metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Merge generation parameters from three sources.
|
||||
|
||||
Priority: request_params > civitai_meta > embedded_metadata
|
||||
|
||||
Args:
|
||||
request_params: Params provided directly in the import request
|
||||
civitai_meta: Params from Civitai Image API 'meta' field
|
||||
embedded_metadata: Params extracted from image EXIF/embedded metadata
|
||||
|
||||
Returns:
|
||||
Merged parameters dictionary
|
||||
"""
|
||||
result = {}
|
||||
result: Dict[str, Any] = {}
|
||||
|
||||
# 1. Start with embedded metadata (lowest priority)
|
||||
if embedded_metadata:
|
||||
# If it's a full recipe metadata, we use its gen_params
|
||||
if "gen_params" in embedded_metadata and isinstance(embedded_metadata["gen_params"], dict):
|
||||
if "gen_params" in embedded_metadata and isinstance(
|
||||
embedded_metadata["gen_params"], dict
|
||||
):
|
||||
GenParamsMerger._update_normalized(result, embedded_metadata["gen_params"])
|
||||
else:
|
||||
# Otherwise assume the dict itself contains gen_params
|
||||
GenParamsMerger._update_normalized(result, embedded_metadata)
|
||||
|
||||
# 2. Layer Civitai meta (medium priority)
|
||||
if civitai_meta:
|
||||
GenParamsMerger._update_normalized(result, civitai_meta)
|
||||
|
||||
# 3. Layer request params (highest priority)
|
||||
if request_params:
|
||||
GenParamsMerger._update_normalized(result, request_params)
|
||||
|
||||
# Filter out blacklisted keys and also the original camelCase keys if they were normalized
|
||||
final_result = {}
|
||||
for k, v in result.items():
|
||||
if k in GenParamsMerger.BLACKLISTED_KEYS:
|
||||
continue
|
||||
if k in GenParamsMerger.NORMALIZATION_MAPPING:
|
||||
continue
|
||||
final_result[k] = v
|
||||
|
||||
return final_result
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _update_normalized(target: Dict[str, Any], source: Dict[str, Any]) -> None:
|
||||
"""Update target dict with normalized keys from source."""
|
||||
for k, v in source.items():
|
||||
normalized_key = GenParamsMerger.NORMALIZATION_MAPPING.get(k, k)
|
||||
target[normalized_key] = v
|
||||
# Also keep the original key for now if it's not the same,
|
||||
# so we can filter at the end or avoid losing it if it wasn't supposed to be renamed?
|
||||
# Actually, if we rename it, we should probably NOT keep both in 'target'
|
||||
# because we want to filter them out at the end anyway.
|
||||
if normalized_key != k:
|
||||
# If we are overwriting an existing snake_case key with a camelCase one's value,
|
||||
# that's fine because of the priority order of calls to _update_normalized.
|
||||
pass
|
||||
target[k] = v
|
||||
"""Update target dict with normalized, persistence-safe keys from source."""
|
||||
for key, value in source.items():
|
||||
if key in GenParamsMerger.BLACKLISTED_KEYS:
|
||||
continue
|
||||
|
||||
normalized_key = GenParamsMerger.NORMALIZATION_MAPPING.get(key, key)
|
||||
if normalized_key not in GenParamsMerger.ALLOWED_KEYS:
|
||||
continue
|
||||
|
||||
target[normalized_key] = value
|
||||
|
||||
@@ -756,6 +756,14 @@ class RecipeManagementHandler:
|
||||
)
|
||||
gen_params_request = self._parse_gen_params(params.get("gen_params"))
|
||||
|
||||
self._logger.info(
|
||||
"Remote recipe import received: url=%s, request_gen_params_keys=%s, lora_count=%d, checkpoint_keys=%s",
|
||||
image_url,
|
||||
sorted(gen_params_request.keys()) if gen_params_request else [],
|
||||
len(lora_entries),
|
||||
sorted(checkpoint_entry.keys()) if isinstance(checkpoint_entry, dict) else [],
|
||||
)
|
||||
|
||||
# 2. Initial Metadata Construction
|
||||
metadata: Dict[str, Any] = {
|
||||
"base_model": params.get("base_model", "") or "",
|
||||
|
||||
@@ -12,6 +12,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
from ...config import config
|
||||
from ...recipes.constants import GEN_PARAM_KEYS
|
||||
from ...utils.utils import calculate_recipe_fingerprint
|
||||
from .errors import RecipeNotFoundError, RecipeValidationError
|
||||
|
||||
@@ -90,23 +91,7 @@ class RecipePersistenceService:
|
||||
current_time = time.time()
|
||||
loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])]
|
||||
checkpoint_entry = self._sanitize_checkpoint_entry(self._extract_checkpoint_entry(metadata))
|
||||
|
||||
gen_params = metadata.get("gen_params") or {}
|
||||
if not gen_params and "raw_metadata" in metadata:
|
||||
raw_metadata = metadata.get("raw_metadata", {})
|
||||
gen_params = {
|
||||
"prompt": raw_metadata.get("prompt", ""),
|
||||
"negative_prompt": raw_metadata.get("negative_prompt", ""),
|
||||
"steps": raw_metadata.get("steps", ""),
|
||||
"sampler": raw_metadata.get("sampler", ""),
|
||||
"cfg_scale": raw_metadata.get("cfg_scale", ""),
|
||||
"seed": raw_metadata.get("seed", ""),
|
||||
"size": raw_metadata.get("size", ""),
|
||||
"clip_skip": raw_metadata.get("clip_skip", ""),
|
||||
}
|
||||
|
||||
# Drop checkpoint duplication from generation parameters to store it only at top level
|
||||
gen_params.pop("checkpoint", None)
|
||||
gen_params = self._sanitize_gen_params_for_storage(metadata)
|
||||
|
||||
fingerprint = calculate_recipe_fingerprint(loras_data)
|
||||
recipe_data: Dict[str, Any] = {
|
||||
@@ -133,6 +118,7 @@ class RecipePersistenceService:
|
||||
json_filename = f"{recipe_id}.recipe.json"
|
||||
json_path = os.path.join(recipes_dir, json_filename)
|
||||
json_path = os.path.normpath(json_path)
|
||||
|
||||
with open(json_path, "w", encoding="utf-8") as file_obj:
|
||||
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
||||
|
||||
@@ -152,6 +138,30 @@ class RecipePersistenceService:
|
||||
}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_gen_params_for_storage(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||
gen_params = metadata.get("gen_params")
|
||||
if isinstance(gen_params, dict) and gen_params:
|
||||
source = gen_params
|
||||
else:
|
||||
source = metadata.get("raw_metadata")
|
||||
|
||||
if not isinstance(source, dict):
|
||||
return {}
|
||||
|
||||
allowed_keys = set(GEN_PARAM_KEYS)
|
||||
sanitized: dict[str, Any] = {}
|
||||
for key in allowed_keys:
|
||||
if key not in source:
|
||||
continue
|
||||
value = source.get(key)
|
||||
if value in (None, ""):
|
||||
continue
|
||||
sanitized[key] = value
|
||||
|
||||
sanitized.pop("checkpoint", None)
|
||||
return sanitized
|
||||
|
||||
async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult:
|
||||
"""Delete an existing recipe."""
|
||||
|
||||
|
||||
@@ -1,95 +1,86 @@
|
||||
import pytest
|
||||
from py.recipes.merger import GenParamsMerger
|
||||
|
||||
def test_merge_priority():
|
||||
request_params = {"prompt": "from request", "steps": 20}
|
||||
civitai_meta = {"prompt": "from civitai", "cfg": 7.0}
|
||||
|
||||
def test_merge_priority_and_normalization():
|
||||
request_params = {"prompt": "from request", "Steps": 20, "cfg": 7.5}
|
||||
civitai_meta = {"prompt": "from civitai", "cfgScale": 6.5, "negativePrompt": "bad"}
|
||||
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
|
||||
|
||||
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
|
||||
|
||||
assert merged["prompt"] == "from request"
|
||||
assert merged["steps"] == 20
|
||||
assert merged["cfg"] == 7.0
|
||||
assert merged["seed"] == 123
|
||||
assert merged == {
|
||||
"prompt": "from request",
|
||||
"steps": 20,
|
||||
"cfg_scale": 7.5,
|
||||
"negative_prompt": "bad",
|
||||
"seed": 123,
|
||||
}
|
||||
|
||||
def test_merge_no_request_params():
|
||||
civitai_meta = {"prompt": "from civitai", "cfg": 7.0}
|
||||
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
|
||||
|
||||
merged = GenParamsMerger.merge(None, civitai_meta, embedded_metadata)
|
||||
|
||||
assert merged["prompt"] == "from civitai"
|
||||
assert merged["cfg"] == 7.0
|
||||
assert merged["seed"] == 123
|
||||
|
||||
def test_merge_only_embedded():
|
||||
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
|
||||
def test_merge_accepts_raw_embedded_metadata():
|
||||
embedded_metadata = {"prompt": "from raw embedded", "seed": 456, "scheduler": "karras"}
|
||||
|
||||
merged = GenParamsMerger.merge(None, None, embedded_metadata)
|
||||
|
||||
assert merged["prompt"] == "from embedded"
|
||||
assert merged["seed"] == 123
|
||||
assert merged == {
|
||||
"prompt": "from raw embedded",
|
||||
"seed": 456,
|
||||
"sampler": "karras",
|
||||
}
|
||||
|
||||
def test_merge_raw_embedded():
|
||||
# Test when embedded metadata is just the gen_params themselves
|
||||
embedded_metadata = {"prompt": "from raw embedded", "seed": 456}
|
||||
|
||||
merged = GenParamsMerger.merge(None, None, embedded_metadata)
|
||||
|
||||
assert merged["prompt"] == "from raw embedded"
|
||||
assert merged["seed"] == 456
|
||||
|
||||
def test_merge_none_values():
|
||||
merged = GenParamsMerger.merge(None, None, None)
|
||||
assert merged == {}
|
||||
|
||||
def test_merge_filters_blacklisted_keys():
|
||||
request_params = {"prompt": "test", "id": "should-be-removed", "checkpoint": "should-not-be-here"}
|
||||
civitai_meta = {"cfg": 7, "url": "remove-me"}
|
||||
embedded_metadata = {"seed": 123, "hash": "remove-also"}
|
||||
|
||||
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
|
||||
|
||||
assert "prompt" in merged
|
||||
assert "cfg" in merged
|
||||
assert "seed" in merged
|
||||
assert "id" not in merged
|
||||
assert "url" not in merged
|
||||
assert "hash" not in merged
|
||||
assert "checkpoint" not in merged
|
||||
|
||||
def test_merge_filters_meta_and_normalizes_keys():
|
||||
def test_merge_filters_unknown_and_blacklisted_keys():
|
||||
request_params = {
|
||||
"prompt": "test",
|
||||
"id": "should-be-removed",
|
||||
"checkpoint": "should-not-be-here",
|
||||
"raw_metadata": {"prompt": "remove"},
|
||||
}
|
||||
civitai_meta = {
|
||||
"Version": "ComfyUI",
|
||||
"RNG": "cpu",
|
||||
"cfgScale": 7,
|
||||
"url": "remove-me",
|
||||
}
|
||||
embedded_metadata = {
|
||||
"seed": 123,
|
||||
"hash": "remove-also",
|
||||
"Discard penultimate sigma": True,
|
||||
"eps_scaling_factor": 0.1,
|
||||
}
|
||||
|
||||
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
|
||||
|
||||
assert merged == {
|
||||
"prompt": "test",
|
||||
"cfg_scale": 7,
|
||||
"seed": 123,
|
||||
}
|
||||
|
||||
|
||||
def test_merge_does_not_keep_original_key_variants():
|
||||
civitai_meta = {
|
||||
"prompt": "masterpiece",
|
||||
"cfgScale": 5,
|
||||
"clipSkip": 2,
|
||||
"negativePrompt": "low quality",
|
||||
"meta": {"irrelevant": "data"},
|
||||
"Size": "1024x1024",
|
||||
"draft": False,
|
||||
"workflow": "txt2img",
|
||||
"civitaiResources": [{"type": "checkpoint"}]
|
||||
"Denoising strength": 0.35,
|
||||
}
|
||||
request_params = {
|
||||
"cfg_scale": 5.0,
|
||||
"clip_skip": "2",
|
||||
"Steps": 30
|
||||
}
|
||||
|
||||
merged = GenParamsMerger.merge(request_params, civitai_meta)
|
||||
|
||||
assert "meta" not in merged
|
||||
assert "cfgScale" not in merged
|
||||
assert "clipSkip" not in merged
|
||||
assert "negativePrompt" not in merged
|
||||
assert "Size" not in merged
|
||||
assert "draft" not in merged
|
||||
assert "workflow" not in merged
|
||||
assert "civitaiResources" not in merged
|
||||
assert merged == {
|
||||
"cfg_scale": 5.0,
|
||||
"clip_skip": "2",
|
||||
"negative_prompt": "low quality",
|
||||
"size": "1024x1024",
|
||||
"denoising_strength": 0.35,
|
||||
}
|
||||
|
||||
assert merged["cfg_scale"] == 5.0 # From request_params
|
||||
assert merged["clip_skip"] == "2" # From request_params
|
||||
assert merged["negative_prompt"] == "low quality" # Normalized from civitai_meta
|
||||
assert merged["size"] == "1024x1024" # Normalized from civitai_meta
|
||||
assert merged["steps"] == 30 # Normalized from request_params
|
||||
|
||||
def test_merge_none_values():
|
||||
assert GenParamsMerger.merge(None, None, None) == {}
|
||||
|
||||
@@ -306,6 +306,120 @@ async def test_save_recipe_promotes_checkpoint_from_gen_params(tmp_path):
|
||||
assert "checkpoint" not in stored["gen_params"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_recipe_strips_non_persistable_gen_params(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, root):
|
||||
self.recipes_dir = str(root)
|
||||
|
||||
async def find_recipes_by_fingerprint(self, fingerprint):
|
||||
return []
|
||||
|
||||
async def add_recipe(self, recipe_data):
|
||||
return None
|
||||
|
||||
scanner = DummyScanner(tmp_path)
|
||||
service = RecipePersistenceService(
|
||||
exif_utils=exif_utils,
|
||||
card_preview_width=512,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"base_model": "Flux",
|
||||
"loras": [],
|
||||
"gen_params": {
|
||||
"prompt": "hello world",
|
||||
"negative_prompt": "bad hands",
|
||||
"cfg_scale": 7,
|
||||
"raw_metadata": {"prompt": "should not persist"},
|
||||
"Version": "ComfyUI",
|
||||
"RNG": "cpu",
|
||||
"Schedule type": "karras",
|
||||
"Discard penultimate sigma": True,
|
||||
"eps_scaling_factor": 0.1,
|
||||
},
|
||||
}
|
||||
|
||||
result = await service.save_recipe(
|
||||
recipe_scanner=scanner,
|
||||
image_bytes=b"img",
|
||||
image_base64=None,
|
||||
name="Sanitized",
|
||||
tags=[],
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||
assert stored["gen_params"] == {
|
||||
"prompt": "hello world",
|
||||
"negative_prompt": "bad hands",
|
||||
"cfg_scale": 7,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_recipe_derives_allowed_fields_from_raw_metadata(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, root):
|
||||
self.recipes_dir = str(root)
|
||||
|
||||
async def find_recipes_by_fingerprint(self, fingerprint):
|
||||
return []
|
||||
|
||||
async def add_recipe(self, recipe_data):
|
||||
return None
|
||||
|
||||
scanner = DummyScanner(tmp_path)
|
||||
service = RecipePersistenceService(
|
||||
exif_utils=exif_utils,
|
||||
card_preview_width=512,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"base_model": "Flux",
|
||||
"loras": [],
|
||||
"raw_metadata": {
|
||||
"prompt": "hello world",
|
||||
"negative_prompt": "bad hands",
|
||||
"steps": 30,
|
||||
"sampler": "Euler",
|
||||
"cfg_scale": 7,
|
||||
"seed": 123,
|
||||
"size": "1024x1024",
|
||||
"clip_skip": 2,
|
||||
"Version": "ComfyUI",
|
||||
"raw_metadata": {"nested": True},
|
||||
},
|
||||
}
|
||||
|
||||
result = await service.save_recipe(
|
||||
recipe_scanner=scanner,
|
||||
image_bytes=b"img",
|
||||
image_base64=None,
|
||||
name="Derived",
|
||||
tags=[],
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||
assert stored["gen_params"] == {
|
||||
"prompt": "hello world",
|
||||
"negative_prompt": "bad hands",
|
||||
"steps": 30,
|
||||
"sampler": "Euler",
|
||||
"cfg_scale": 7,
|
||||
"seed": 123,
|
||||
"size": "1024x1024",
|
||||
"clip_skip": 2,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_recipe_strips_checkpoint_local_fields(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
Reference in New Issue
Block a user