fix(recipes): sanitize remote import gen params

This commit is contained in:
Will Miao
2026-04-12 20:29:01 +08:00
parent 0253d001e6
commit 55e9e4bb6f
6 changed files with 244 additions and 137 deletions

View File

@@ -306,6 +306,120 @@ async def test_save_recipe_promotes_checkpoint_from_gen_params(tmp_path):
assert "checkpoint" not in stored["gen_params"]
@pytest.mark.asyncio
async def test_save_recipe_strips_non_persistable_gen_params(tmp_path):
exif_utils = DummyExifUtils()
class DummyScanner:
def __init__(self, root):
self.recipes_dir = str(root)
async def find_recipes_by_fingerprint(self, fingerprint):
return []
async def add_recipe(self, recipe_data):
return None
scanner = DummyScanner(tmp_path)
service = RecipePersistenceService(
exif_utils=exif_utils,
card_preview_width=512,
logger=logging.getLogger("test"),
)
metadata = {
"base_model": "Flux",
"loras": [],
"gen_params": {
"prompt": "hello world",
"negative_prompt": "bad hands",
"cfg_scale": 7,
"raw_metadata": {"prompt": "should not persist"},
"Version": "ComfyUI",
"RNG": "cpu",
"Schedule type": "karras",
"Discard penultimate sigma": True,
"eps_scaling_factor": 0.1,
},
}
result = await service.save_recipe(
recipe_scanner=scanner,
image_bytes=b"img",
image_base64=None,
name="Sanitized",
tags=[],
metadata=metadata,
)
stored = json.loads(Path(result.payload["json_path"]).read_text())
assert stored["gen_params"] == {
"prompt": "hello world",
"negative_prompt": "bad hands",
"cfg_scale": 7,
}
@pytest.mark.asyncio
async def test_save_recipe_derives_allowed_fields_from_raw_metadata(tmp_path):
exif_utils = DummyExifUtils()
class DummyScanner:
def __init__(self, root):
self.recipes_dir = str(root)
async def find_recipes_by_fingerprint(self, fingerprint):
return []
async def add_recipe(self, recipe_data):
return None
scanner = DummyScanner(tmp_path)
service = RecipePersistenceService(
exif_utils=exif_utils,
card_preview_width=512,
logger=logging.getLogger("test"),
)
metadata = {
"base_model": "Flux",
"loras": [],
"raw_metadata": {
"prompt": "hello world",
"negative_prompt": "bad hands",
"steps": 30,
"sampler": "Euler",
"cfg_scale": 7,
"seed": 123,
"size": "1024x1024",
"clip_skip": 2,
"Version": "ComfyUI",
"raw_metadata": {"nested": True},
},
}
result = await service.save_recipe(
recipe_scanner=scanner,
image_bytes=b"img",
image_base64=None,
name="Derived",
tags=[],
metadata=metadata,
)
stored = json.loads(Path(result.payload["json_path"]).read_text())
assert stored["gen_params"] == {
"prompt": "hello world",
"negative_prompt": "bad hands",
"steps": 30,
"sampler": "Euler",
"cfg_scale": 7,
"seed": 123,
"size": "1024x1024",
"clip_skip": 2,
}
@pytest.mark.asyncio
async def test_save_recipe_strips_checkpoint_local_fields(tmp_path):
exif_utils = DummyExifUtils()