mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-07 00:46:44 -03:00
fix(recipes): sanitize remote import gen params
This commit is contained in:
@@ -1,95 +1,86 @@
|
||||
import pytest
|
||||
from py.recipes.merger import GenParamsMerger
|
||||
|
||||
def test_merge_priority():
|
||||
request_params = {"prompt": "from request", "steps": 20}
|
||||
civitai_meta = {"prompt": "from civitai", "cfg": 7.0}
|
||||
|
||||
def test_merge_priority_and_normalization():
|
||||
request_params = {"prompt": "from request", "Steps": 20, "cfg": 7.5}
|
||||
civitai_meta = {"prompt": "from civitai", "cfgScale": 6.5, "negativePrompt": "bad"}
|
||||
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
|
||||
|
||||
|
||||
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
|
||||
|
||||
assert merged["prompt"] == "from request"
|
||||
assert merged["steps"] == 20
|
||||
assert merged["cfg"] == 7.0
|
||||
assert merged["seed"] == 123
|
||||
|
||||
def test_merge_no_request_params():
|
||||
civitai_meta = {"prompt": "from civitai", "cfg": 7.0}
|
||||
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
|
||||
|
||||
merged = GenParamsMerger.merge(None, civitai_meta, embedded_metadata)
|
||||
|
||||
assert merged["prompt"] == "from civitai"
|
||||
assert merged["cfg"] == 7.0
|
||||
assert merged["seed"] == 123
|
||||
assert merged == {
|
||||
"prompt": "from request",
|
||||
"steps": 20,
|
||||
"cfg_scale": 7.5,
|
||||
"negative_prompt": "bad",
|
||||
"seed": 123,
|
||||
}
|
||||
|
||||
|
||||
def test_merge_accepts_raw_embedded_metadata():
|
||||
embedded_metadata = {"prompt": "from raw embedded", "seed": 456, "scheduler": "karras"}
|
||||
|
||||
def test_merge_only_embedded():
|
||||
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
|
||||
|
||||
merged = GenParamsMerger.merge(None, None, embedded_metadata)
|
||||
|
||||
assert merged["prompt"] == "from embedded"
|
||||
assert merged["seed"] == 123
|
||||
|
||||
def test_merge_raw_embedded():
|
||||
# Test when embedded metadata is just the gen_params themselves
|
||||
embedded_metadata = {"prompt": "from raw embedded", "seed": 456}
|
||||
|
||||
merged = GenParamsMerger.merge(None, None, embedded_metadata)
|
||||
|
||||
assert merged["prompt"] == "from raw embedded"
|
||||
assert merged["seed"] == 456
|
||||
assert merged == {
|
||||
"prompt": "from raw embedded",
|
||||
"seed": 456,
|
||||
"sampler": "karras",
|
||||
}
|
||||
|
||||
def test_merge_none_values():
|
||||
merged = GenParamsMerger.merge(None, None, None)
|
||||
assert merged == {}
|
||||
|
||||
def test_merge_filters_blacklisted_keys():
|
||||
request_params = {"prompt": "test", "id": "should-be-removed", "checkpoint": "should-not-be-here"}
|
||||
civitai_meta = {"cfg": 7, "url": "remove-me"}
|
||||
embedded_metadata = {"seed": 123, "hash": "remove-also"}
|
||||
|
||||
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
|
||||
|
||||
assert "prompt" in merged
|
||||
assert "cfg" in merged
|
||||
assert "seed" in merged
|
||||
assert "id" not in merged
|
||||
assert "url" not in merged
|
||||
assert "hash" not in merged
|
||||
assert "checkpoint" not in merged
|
||||
|
||||
def test_merge_filters_meta_and_normalizes_keys():
|
||||
def test_merge_filters_unknown_and_blacklisted_keys():
|
||||
request_params = {
|
||||
"prompt": "test",
|
||||
"id": "should-be-removed",
|
||||
"checkpoint": "should-not-be-here",
|
||||
"raw_metadata": {"prompt": "remove"},
|
||||
}
|
||||
civitai_meta = {
|
||||
"Version": "ComfyUI",
|
||||
"RNG": "cpu",
|
||||
"cfgScale": 7,
|
||||
"url": "remove-me",
|
||||
}
|
||||
embedded_metadata = {
|
||||
"seed": 123,
|
||||
"hash": "remove-also",
|
||||
"Discard penultimate sigma": True,
|
||||
"eps_scaling_factor": 0.1,
|
||||
}
|
||||
|
||||
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
|
||||
|
||||
assert merged == {
|
||||
"prompt": "test",
|
||||
"cfg_scale": 7,
|
||||
"seed": 123,
|
||||
}
|
||||
|
||||
|
||||
def test_merge_does_not_keep_original_key_variants():
|
||||
civitai_meta = {
|
||||
"prompt": "masterpiece",
|
||||
"cfgScale": 5,
|
||||
"clipSkip": 2,
|
||||
"negativePrompt": "low quality",
|
||||
"meta": {"irrelevant": "data"},
|
||||
"Size": "1024x1024",
|
||||
"draft": False,
|
||||
"workflow": "txt2img",
|
||||
"civitaiResources": [{"type": "checkpoint"}]
|
||||
"Denoising strength": 0.35,
|
||||
}
|
||||
request_params = {
|
||||
"cfg_scale": 5.0,
|
||||
"clip_skip": "2",
|
||||
"Steps": 30
|
||||
}
|
||||
|
||||
|
||||
merged = GenParamsMerger.merge(request_params, civitai_meta)
|
||||
|
||||
assert "meta" not in merged
|
||||
assert "cfgScale" not in merged
|
||||
assert "clipSkip" not in merged
|
||||
assert "negativePrompt" not in merged
|
||||
assert "Size" not in merged
|
||||
assert "draft" not in merged
|
||||
assert "workflow" not in merged
|
||||
assert "civitaiResources" not in merged
|
||||
|
||||
assert merged["cfg_scale"] == 5.0 # From request_params
|
||||
assert merged["clip_skip"] == "2" # From request_params
|
||||
assert merged["negative_prompt"] == "low quality" # Normalized from civitai_meta
|
||||
assert merged["size"] == "1024x1024" # Normalized from civitai_meta
|
||||
assert merged["steps"] == 30 # Normalized from request_params
|
||||
|
||||
assert merged == {
|
||||
"cfg_scale": 5.0,
|
||||
"clip_skip": "2",
|
||||
"negative_prompt": "low quality",
|
||||
"size": "1024x1024",
|
||||
"denoising_strength": 0.35,
|
||||
}
|
||||
|
||||
|
||||
def test_merge_none_values():
|
||||
assert GenParamsMerger.merge(None, None, None) == {}
|
||||
|
||||
@@ -306,6 +306,120 @@ async def test_save_recipe_promotes_checkpoint_from_gen_params(tmp_path):
|
||||
assert "checkpoint" not in stored["gen_params"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_recipe_strips_non_persistable_gen_params(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, root):
|
||||
self.recipes_dir = str(root)
|
||||
|
||||
async def find_recipes_by_fingerprint(self, fingerprint):
|
||||
return []
|
||||
|
||||
async def add_recipe(self, recipe_data):
|
||||
return None
|
||||
|
||||
scanner = DummyScanner(tmp_path)
|
||||
service = RecipePersistenceService(
|
||||
exif_utils=exif_utils,
|
||||
card_preview_width=512,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"base_model": "Flux",
|
||||
"loras": [],
|
||||
"gen_params": {
|
||||
"prompt": "hello world",
|
||||
"negative_prompt": "bad hands",
|
||||
"cfg_scale": 7,
|
||||
"raw_metadata": {"prompt": "should not persist"},
|
||||
"Version": "ComfyUI",
|
||||
"RNG": "cpu",
|
||||
"Schedule type": "karras",
|
||||
"Discard penultimate sigma": True,
|
||||
"eps_scaling_factor": 0.1,
|
||||
},
|
||||
}
|
||||
|
||||
result = await service.save_recipe(
|
||||
recipe_scanner=scanner,
|
||||
image_bytes=b"img",
|
||||
image_base64=None,
|
||||
name="Sanitized",
|
||||
tags=[],
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||
assert stored["gen_params"] == {
|
||||
"prompt": "hello world",
|
||||
"negative_prompt": "bad hands",
|
||||
"cfg_scale": 7,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_recipe_derives_allowed_fields_from_raw_metadata(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, root):
|
||||
self.recipes_dir = str(root)
|
||||
|
||||
async def find_recipes_by_fingerprint(self, fingerprint):
|
||||
return []
|
||||
|
||||
async def add_recipe(self, recipe_data):
|
||||
return None
|
||||
|
||||
scanner = DummyScanner(tmp_path)
|
||||
service = RecipePersistenceService(
|
||||
exif_utils=exif_utils,
|
||||
card_preview_width=512,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"base_model": "Flux",
|
||||
"loras": [],
|
||||
"raw_metadata": {
|
||||
"prompt": "hello world",
|
||||
"negative_prompt": "bad hands",
|
||||
"steps": 30,
|
||||
"sampler": "Euler",
|
||||
"cfg_scale": 7,
|
||||
"seed": 123,
|
||||
"size": "1024x1024",
|
||||
"clip_skip": 2,
|
||||
"Version": "ComfyUI",
|
||||
"raw_metadata": {"nested": True},
|
||||
},
|
||||
}
|
||||
|
||||
result = await service.save_recipe(
|
||||
recipe_scanner=scanner,
|
||||
image_bytes=b"img",
|
||||
image_base64=None,
|
||||
name="Derived",
|
||||
tags=[],
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||
assert stored["gen_params"] == {
|
||||
"prompt": "hello world",
|
||||
"negative_prompt": "bad hands",
|
||||
"steps": 30,
|
||||
"sampler": "Euler",
|
||||
"cfg_scale": 7,
|
||||
"seed": 123,
|
||||
"size": "1024x1024",
|
||||
"clip_skip": 2,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_recipe_strips_checkpoint_local_fields(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
Reference in New Issue
Block a user