mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
fix(recipes): sanitize remote import gen params
This commit is contained in:
@@ -13,4 +13,5 @@ GEN_PARAM_KEYS = [
|
|||||||
'seed',
|
'seed',
|
||||||
'size',
|
'size',
|
||||||
'clip_skip',
|
'clip_skip',
|
||||||
|
'denoising_strength',
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,27 +1,33 @@
|
|||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from .constants import GEN_PARAM_KEYS
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GenParamsMerger:
|
class GenParamsMerger:
|
||||||
"""Utility to merge generation parameters from multiple sources with priority."""
|
"""Utility to merge generation parameters from multiple sources with priority."""
|
||||||
|
|
||||||
|
ALLOWED_KEYS = set(GEN_PARAM_KEYS)
|
||||||
|
|
||||||
BLACKLISTED_KEYS = {
|
BLACKLISTED_KEYS = {
|
||||||
"id", "url", "userId", "username", "createdAt", "updatedAt", "hash", "meta",
|
"id", "url", "userId", "username", "createdAt", "updatedAt", "hash", "meta",
|
||||||
"draft", "extra", "width", "height", "process", "quantity", "workflow",
|
"draft", "extra", "width", "height", "process", "quantity", "workflow",
|
||||||
"baseModel", "resources", "disablePoi", "aspectRatio", "Created Date",
|
"baseModel", "resources", "disablePoi", "aspectRatio", "Created Date",
|
||||||
"experimental", "civitaiResources", "civitai_resources", "Civitai resources",
|
"experimental", "civitaiResources", "civitai_resources", "Civitai resources",
|
||||||
"modelVersionId", "modelId", "hashes", "Model", "Model hash", "checkpoint_hash",
|
"modelVersionId", "modelId", "hashes", "Model", "Model hash", "checkpoint_hash",
|
||||||
"checkpoint", "checksum", "model_checksum"
|
"checkpoint", "checksum", "model_checksum", "raw_metadata",
|
||||||
}
|
}
|
||||||
|
|
||||||
NORMALIZATION_MAPPING = {
|
NORMALIZATION_MAPPING = {
|
||||||
# Civitai specific
|
"cfg": "cfg_scale",
|
||||||
"cfgScale": "cfg_scale",
|
"cfgScale": "cfg_scale",
|
||||||
"clipSkip": "clip_skip",
|
"clipSkip": "clip_skip",
|
||||||
"negativePrompt": "negative_prompt",
|
"negativePrompt": "negative_prompt",
|
||||||
# Case variations
|
|
||||||
"Sampler": "sampler",
|
"Sampler": "sampler",
|
||||||
|
"sampler_name": "sampler",
|
||||||
|
"scheduler": "sampler",
|
||||||
"Steps": "steps",
|
"Steps": "steps",
|
||||||
"Seed": "seed",
|
"Seed": "seed",
|
||||||
"Size": "size",
|
"Size": "size",
|
||||||
@@ -36,63 +42,40 @@ class GenParamsMerger:
|
|||||||
def merge(
|
def merge(
|
||||||
request_params: Optional[Dict[str, Any]] = None,
|
request_params: Optional[Dict[str, Any]] = None,
|
||||||
civitai_meta: Optional[Dict[str, Any]] = None,
|
civitai_meta: Optional[Dict[str, Any]] = None,
|
||||||
embedded_metadata: Optional[Dict[str, Any]] = None
|
embedded_metadata: Optional[Dict[str, Any]] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Merge generation parameters from three sources.
|
Merge generation parameters from three sources.
|
||||||
|
|
||||||
Priority: request_params > civitai_meta > embedded_metadata
|
Priority: request_params > civitai_meta > embedded_metadata
|
||||||
|
|
||||||
Args:
|
|
||||||
request_params: Params provided directly in the import request
|
|
||||||
civitai_meta: Params from Civitai Image API 'meta' field
|
|
||||||
embedded_metadata: Params extracted from image EXIF/embedded metadata
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Merged parameters dictionary
|
|
||||||
"""
|
"""
|
||||||
result = {}
|
result: Dict[str, Any] = {}
|
||||||
|
|
||||||
# 1. Start with embedded metadata (lowest priority)
|
|
||||||
if embedded_metadata:
|
if embedded_metadata:
|
||||||
# If it's a full recipe metadata, we use its gen_params
|
if "gen_params" in embedded_metadata and isinstance(
|
||||||
if "gen_params" in embedded_metadata and isinstance(embedded_metadata["gen_params"], dict):
|
embedded_metadata["gen_params"], dict
|
||||||
|
):
|
||||||
GenParamsMerger._update_normalized(result, embedded_metadata["gen_params"])
|
GenParamsMerger._update_normalized(result, embedded_metadata["gen_params"])
|
||||||
else:
|
else:
|
||||||
# Otherwise assume the dict itself contains gen_params
|
|
||||||
GenParamsMerger._update_normalized(result, embedded_metadata)
|
GenParamsMerger._update_normalized(result, embedded_metadata)
|
||||||
|
|
||||||
# 2. Layer Civitai meta (medium priority)
|
|
||||||
if civitai_meta:
|
if civitai_meta:
|
||||||
GenParamsMerger._update_normalized(result, civitai_meta)
|
GenParamsMerger._update_normalized(result, civitai_meta)
|
||||||
|
|
||||||
# 3. Layer request params (highest priority)
|
|
||||||
if request_params:
|
if request_params:
|
||||||
GenParamsMerger._update_normalized(result, request_params)
|
GenParamsMerger._update_normalized(result, request_params)
|
||||||
|
|
||||||
# Filter out blacklisted keys and also the original camelCase keys if they were normalized
|
return result
|
||||||
final_result = {}
|
|
||||||
for k, v in result.items():
|
|
||||||
if k in GenParamsMerger.BLACKLISTED_KEYS:
|
|
||||||
continue
|
|
||||||
if k in GenParamsMerger.NORMALIZATION_MAPPING:
|
|
||||||
continue
|
|
||||||
final_result[k] = v
|
|
||||||
|
|
||||||
return final_result
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_normalized(target: Dict[str, Any], source: Dict[str, Any]) -> None:
|
def _update_normalized(target: Dict[str, Any], source: Dict[str, Any]) -> None:
|
||||||
"""Update target dict with normalized keys from source."""
|
"""Update target dict with normalized, persistence-safe keys from source."""
|
||||||
for k, v in source.items():
|
for key, value in source.items():
|
||||||
normalized_key = GenParamsMerger.NORMALIZATION_MAPPING.get(k, k)
|
if key in GenParamsMerger.BLACKLISTED_KEYS:
|
||||||
target[normalized_key] = v
|
continue
|
||||||
# Also keep the original key for now if it's not the same,
|
|
||||||
# so we can filter at the end or avoid losing it if it wasn't supposed to be renamed?
|
normalized_key = GenParamsMerger.NORMALIZATION_MAPPING.get(key, key)
|
||||||
# Actually, if we rename it, we should probably NOT keep both in 'target'
|
if normalized_key not in GenParamsMerger.ALLOWED_KEYS:
|
||||||
# because we want to filter them out at the end anyway.
|
continue
|
||||||
if normalized_key != k:
|
|
||||||
# If we are overwriting an existing snake_case key with a camelCase one's value,
|
target[normalized_key] = value
|
||||||
# that's fine because of the priority order of calls to _update_normalized.
|
|
||||||
pass
|
|
||||||
target[k] = v
|
|
||||||
|
|||||||
@@ -756,6 +756,14 @@ class RecipeManagementHandler:
|
|||||||
)
|
)
|
||||||
gen_params_request = self._parse_gen_params(params.get("gen_params"))
|
gen_params_request = self._parse_gen_params(params.get("gen_params"))
|
||||||
|
|
||||||
|
self._logger.info(
|
||||||
|
"Remote recipe import received: url=%s, request_gen_params_keys=%s, lora_count=%d, checkpoint_keys=%s",
|
||||||
|
image_url,
|
||||||
|
sorted(gen_params_request.keys()) if gen_params_request else [],
|
||||||
|
len(lora_entries),
|
||||||
|
sorted(checkpoint_entry.keys()) if isinstance(checkpoint_entry, dict) else [],
|
||||||
|
)
|
||||||
|
|
||||||
# 2. Initial Metadata Construction
|
# 2. Initial Metadata Construction
|
||||||
metadata: Dict[str, Any] = {
|
metadata: Dict[str, Any] = {
|
||||||
"base_model": params.get("base_model", "") or "",
|
"base_model": params.get("base_model", "") or "",
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Any, Dict, Iterable, Optional
|
from typing import Any, Dict, Iterable, Optional
|
||||||
|
|
||||||
from ...config import config
|
from ...config import config
|
||||||
|
from ...recipes.constants import GEN_PARAM_KEYS
|
||||||
from ...utils.utils import calculate_recipe_fingerprint
|
from ...utils.utils import calculate_recipe_fingerprint
|
||||||
from .errors import RecipeNotFoundError, RecipeValidationError
|
from .errors import RecipeNotFoundError, RecipeValidationError
|
||||||
|
|
||||||
@@ -90,23 +91,7 @@ class RecipePersistenceService:
|
|||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])]
|
loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])]
|
||||||
checkpoint_entry = self._sanitize_checkpoint_entry(self._extract_checkpoint_entry(metadata))
|
checkpoint_entry = self._sanitize_checkpoint_entry(self._extract_checkpoint_entry(metadata))
|
||||||
|
gen_params = self._sanitize_gen_params_for_storage(metadata)
|
||||||
gen_params = metadata.get("gen_params") or {}
|
|
||||||
if not gen_params and "raw_metadata" in metadata:
|
|
||||||
raw_metadata = metadata.get("raw_metadata", {})
|
|
||||||
gen_params = {
|
|
||||||
"prompt": raw_metadata.get("prompt", ""),
|
|
||||||
"negative_prompt": raw_metadata.get("negative_prompt", ""),
|
|
||||||
"steps": raw_metadata.get("steps", ""),
|
|
||||||
"sampler": raw_metadata.get("sampler", ""),
|
|
||||||
"cfg_scale": raw_metadata.get("cfg_scale", ""),
|
|
||||||
"seed": raw_metadata.get("seed", ""),
|
|
||||||
"size": raw_metadata.get("size", ""),
|
|
||||||
"clip_skip": raw_metadata.get("clip_skip", ""),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Drop checkpoint duplication from generation parameters to store it only at top level
|
|
||||||
gen_params.pop("checkpoint", None)
|
|
||||||
|
|
||||||
fingerprint = calculate_recipe_fingerprint(loras_data)
|
fingerprint = calculate_recipe_fingerprint(loras_data)
|
||||||
recipe_data: Dict[str, Any] = {
|
recipe_data: Dict[str, Any] = {
|
||||||
@@ -133,6 +118,7 @@ class RecipePersistenceService:
|
|||||||
json_filename = f"{recipe_id}.recipe.json"
|
json_filename = f"{recipe_id}.recipe.json"
|
||||||
json_path = os.path.join(recipes_dir, json_filename)
|
json_path = os.path.join(recipes_dir, json_filename)
|
||||||
json_path = os.path.normpath(json_path)
|
json_path = os.path.normpath(json_path)
|
||||||
|
|
||||||
with open(json_path, "w", encoding="utf-8") as file_obj:
|
with open(json_path, "w", encoding="utf-8") as file_obj:
|
||||||
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
@@ -152,6 +138,30 @@ class RecipePersistenceService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sanitize_gen_params_for_storage(metadata: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
gen_params = metadata.get("gen_params")
|
||||||
|
if isinstance(gen_params, dict) and gen_params:
|
||||||
|
source = gen_params
|
||||||
|
else:
|
||||||
|
source = metadata.get("raw_metadata")
|
||||||
|
|
||||||
|
if not isinstance(source, dict):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
allowed_keys = set(GEN_PARAM_KEYS)
|
||||||
|
sanitized: dict[str, Any] = {}
|
||||||
|
for key in allowed_keys:
|
||||||
|
if key not in source:
|
||||||
|
continue
|
||||||
|
value = source.get(key)
|
||||||
|
if value in (None, ""):
|
||||||
|
continue
|
||||||
|
sanitized[key] = value
|
||||||
|
|
||||||
|
sanitized.pop("checkpoint", None)
|
||||||
|
return sanitized
|
||||||
|
|
||||||
async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult:
|
async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult:
|
||||||
"""Delete an existing recipe."""
|
"""Delete an existing recipe."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,95 +1,86 @@
|
|||||||
import pytest
|
|
||||||
from py.recipes.merger import GenParamsMerger
|
from py.recipes.merger import GenParamsMerger
|
||||||
|
|
||||||
def test_merge_priority():
|
|
||||||
request_params = {"prompt": "from request", "steps": 20}
|
def test_merge_priority_and_normalization():
|
||||||
civitai_meta = {"prompt": "from civitai", "cfg": 7.0}
|
request_params = {"prompt": "from request", "Steps": 20, "cfg": 7.5}
|
||||||
|
civitai_meta = {"prompt": "from civitai", "cfgScale": 6.5, "negativePrompt": "bad"}
|
||||||
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
|
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
|
||||||
|
|
||||||
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
|
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
|
||||||
|
|
||||||
assert merged["prompt"] == "from request"
|
assert merged == {
|
||||||
assert merged["steps"] == 20
|
"prompt": "from request",
|
||||||
assert merged["cfg"] == 7.0
|
"steps": 20,
|
||||||
assert merged["seed"] == 123
|
"cfg_scale": 7.5,
|
||||||
|
"negative_prompt": "bad",
|
||||||
|
"seed": 123,
|
||||||
|
}
|
||||||
|
|
||||||
def test_merge_no_request_params():
|
|
||||||
civitai_meta = {"prompt": "from civitai", "cfg": 7.0}
|
|
||||||
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
|
|
||||||
|
|
||||||
merged = GenParamsMerger.merge(None, civitai_meta, embedded_metadata)
|
def test_merge_accepts_raw_embedded_metadata():
|
||||||
|
embedded_metadata = {"prompt": "from raw embedded", "seed": 456, "scheduler": "karras"}
|
||||||
assert merged["prompt"] == "from civitai"
|
|
||||||
assert merged["cfg"] == 7.0
|
|
||||||
assert merged["seed"] == 123
|
|
||||||
|
|
||||||
def test_merge_only_embedded():
|
|
||||||
embedded_metadata = {"gen_params": {"prompt": "from embedded", "seed": 123}}
|
|
||||||
|
|
||||||
merged = GenParamsMerger.merge(None, None, embedded_metadata)
|
merged = GenParamsMerger.merge(None, None, embedded_metadata)
|
||||||
|
|
||||||
assert merged["prompt"] == "from embedded"
|
assert merged == {
|
||||||
assert merged["seed"] == 123
|
"prompt": "from raw embedded",
|
||||||
|
"seed": 456,
|
||||||
|
"sampler": "karras",
|
||||||
|
}
|
||||||
|
|
||||||
def test_merge_raw_embedded():
|
|
||||||
# Test when embedded metadata is just the gen_params themselves
|
|
||||||
embedded_metadata = {"prompt": "from raw embedded", "seed": 456}
|
|
||||||
|
|
||||||
merged = GenParamsMerger.merge(None, None, embedded_metadata)
|
def test_merge_filters_unknown_and_blacklisted_keys():
|
||||||
|
request_params = {
|
||||||
assert merged["prompt"] == "from raw embedded"
|
"prompt": "test",
|
||||||
assert merged["seed"] == 456
|
"id": "should-be-removed",
|
||||||
|
"checkpoint": "should-not-be-here",
|
||||||
def test_merge_none_values():
|
"raw_metadata": {"prompt": "remove"},
|
||||||
merged = GenParamsMerger.merge(None, None, None)
|
}
|
||||||
assert merged == {}
|
civitai_meta = {
|
||||||
|
"Version": "ComfyUI",
|
||||||
def test_merge_filters_blacklisted_keys():
|
"RNG": "cpu",
|
||||||
request_params = {"prompt": "test", "id": "should-be-removed", "checkpoint": "should-not-be-here"}
|
"cfgScale": 7,
|
||||||
civitai_meta = {"cfg": 7, "url": "remove-me"}
|
"url": "remove-me",
|
||||||
embedded_metadata = {"seed": 123, "hash": "remove-also"}
|
}
|
||||||
|
embedded_metadata = {
|
||||||
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
|
"seed": 123,
|
||||||
|
"hash": "remove-also",
|
||||||
assert "prompt" in merged
|
"Discard penultimate sigma": True,
|
||||||
assert "cfg" in merged
|
"eps_scaling_factor": 0.1,
|
||||||
assert "seed" in merged
|
}
|
||||||
assert "id" not in merged
|
|
||||||
assert "url" not in merged
|
merged = GenParamsMerger.merge(request_params, civitai_meta, embedded_metadata)
|
||||||
assert "hash" not in merged
|
|
||||||
assert "checkpoint" not in merged
|
assert merged == {
|
||||||
|
"prompt": "test",
|
||||||
def test_merge_filters_meta_and_normalizes_keys():
|
"cfg_scale": 7,
|
||||||
|
"seed": 123,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_does_not_keep_original_key_variants():
|
||||||
civitai_meta = {
|
civitai_meta = {
|
||||||
"prompt": "masterpiece",
|
|
||||||
"cfgScale": 5,
|
"cfgScale": 5,
|
||||||
"clipSkip": 2,
|
"clipSkip": 2,
|
||||||
"negativePrompt": "low quality",
|
"negativePrompt": "low quality",
|
||||||
"meta": {"irrelevant": "data"},
|
|
||||||
"Size": "1024x1024",
|
"Size": "1024x1024",
|
||||||
"draft": False,
|
"Denoising strength": 0.35,
|
||||||
"workflow": "txt2img",
|
|
||||||
"civitaiResources": [{"type": "checkpoint"}]
|
|
||||||
}
|
}
|
||||||
request_params = {
|
request_params = {
|
||||||
"cfg_scale": 5.0,
|
"cfg_scale": 5.0,
|
||||||
"clip_skip": "2",
|
"clip_skip": "2",
|
||||||
"Steps": 30
|
|
||||||
}
|
}
|
||||||
|
|
||||||
merged = GenParamsMerger.merge(request_params, civitai_meta)
|
merged = GenParamsMerger.merge(request_params, civitai_meta)
|
||||||
|
|
||||||
assert "meta" not in merged
|
assert merged == {
|
||||||
assert "cfgScale" not in merged
|
"cfg_scale": 5.0,
|
||||||
assert "clipSkip" not in merged
|
"clip_skip": "2",
|
||||||
assert "negativePrompt" not in merged
|
"negative_prompt": "low quality",
|
||||||
assert "Size" not in merged
|
"size": "1024x1024",
|
||||||
assert "draft" not in merged
|
"denoising_strength": 0.35,
|
||||||
assert "workflow" not in merged
|
}
|
||||||
assert "civitaiResources" not in merged
|
|
||||||
|
|
||||||
assert merged["cfg_scale"] == 5.0 # From request_params
|
|
||||||
assert merged["clip_skip"] == "2" # From request_params
|
def test_merge_none_values():
|
||||||
assert merged["negative_prompt"] == "low quality" # Normalized from civitai_meta
|
assert GenParamsMerger.merge(None, None, None) == {}
|
||||||
assert merged["size"] == "1024x1024" # Normalized from civitai_meta
|
|
||||||
assert merged["steps"] == 30 # Normalized from request_params
|
|
||||||
|
|||||||
@@ -306,6 +306,120 @@ async def test_save_recipe_promotes_checkpoint_from_gen_params(tmp_path):
|
|||||||
assert "checkpoint" not in stored["gen_params"]
|
assert "checkpoint" not in stored["gen_params"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_recipe_strips_non_persistable_gen_params(tmp_path):
|
||||||
|
exif_utils = DummyExifUtils()
|
||||||
|
|
||||||
|
class DummyScanner:
|
||||||
|
def __init__(self, root):
|
||||||
|
self.recipes_dir = str(root)
|
||||||
|
|
||||||
|
async def find_recipes_by_fingerprint(self, fingerprint):
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def add_recipe(self, recipe_data):
|
||||||
|
return None
|
||||||
|
|
||||||
|
scanner = DummyScanner(tmp_path)
|
||||||
|
service = RecipePersistenceService(
|
||||||
|
exif_utils=exif_utils,
|
||||||
|
card_preview_width=512,
|
||||||
|
logger=logging.getLogger("test"),
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"base_model": "Flux",
|
||||||
|
"loras": [],
|
||||||
|
"gen_params": {
|
||||||
|
"prompt": "hello world",
|
||||||
|
"negative_prompt": "bad hands",
|
||||||
|
"cfg_scale": 7,
|
||||||
|
"raw_metadata": {"prompt": "should not persist"},
|
||||||
|
"Version": "ComfyUI",
|
||||||
|
"RNG": "cpu",
|
||||||
|
"Schedule type": "karras",
|
||||||
|
"Discard penultimate sigma": True,
|
||||||
|
"eps_scaling_factor": 0.1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await service.save_recipe(
|
||||||
|
recipe_scanner=scanner,
|
||||||
|
image_bytes=b"img",
|
||||||
|
image_base64=None,
|
||||||
|
name="Sanitized",
|
||||||
|
tags=[],
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||||
|
assert stored["gen_params"] == {
|
||||||
|
"prompt": "hello world",
|
||||||
|
"negative_prompt": "bad hands",
|
||||||
|
"cfg_scale": 7,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_recipe_derives_allowed_fields_from_raw_metadata(tmp_path):
|
||||||
|
exif_utils = DummyExifUtils()
|
||||||
|
|
||||||
|
class DummyScanner:
|
||||||
|
def __init__(self, root):
|
||||||
|
self.recipes_dir = str(root)
|
||||||
|
|
||||||
|
async def find_recipes_by_fingerprint(self, fingerprint):
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def add_recipe(self, recipe_data):
|
||||||
|
return None
|
||||||
|
|
||||||
|
scanner = DummyScanner(tmp_path)
|
||||||
|
service = RecipePersistenceService(
|
||||||
|
exif_utils=exif_utils,
|
||||||
|
card_preview_width=512,
|
||||||
|
logger=logging.getLogger("test"),
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"base_model": "Flux",
|
||||||
|
"loras": [],
|
||||||
|
"raw_metadata": {
|
||||||
|
"prompt": "hello world",
|
||||||
|
"negative_prompt": "bad hands",
|
||||||
|
"steps": 30,
|
||||||
|
"sampler": "Euler",
|
||||||
|
"cfg_scale": 7,
|
||||||
|
"seed": 123,
|
||||||
|
"size": "1024x1024",
|
||||||
|
"clip_skip": 2,
|
||||||
|
"Version": "ComfyUI",
|
||||||
|
"raw_metadata": {"nested": True},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await service.save_recipe(
|
||||||
|
recipe_scanner=scanner,
|
||||||
|
image_bytes=b"img",
|
||||||
|
image_base64=None,
|
||||||
|
name="Derived",
|
||||||
|
tags=[],
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||||
|
assert stored["gen_params"] == {
|
||||||
|
"prompt": "hello world",
|
||||||
|
"negative_prompt": "bad hands",
|
||||||
|
"steps": 30,
|
||||||
|
"sampler": "Euler",
|
||||||
|
"cfg_scale": 7,
|
||||||
|
"seed": 123,
|
||||||
|
"size": "1024x1024",
|
||||||
|
"clip_skip": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_save_recipe_strips_checkpoint_local_fields(tmp_path):
|
async def test_save_recipe_strips_checkpoint_local_fields(tmp_path):
|
||||||
exif_utils = DummyExifUtils()
|
exif_utils = DummyExifUtils()
|
||||||
|
|||||||
Reference in New Issue
Block a user