diff --git a/py/services/recipes/persistence_service.py b/py/services/recipes/persistence_service.py index fbc628c5..2640035e 100644 --- a/py/services/recipes/persistence_service.py +++ b/py/services/recipes/persistence_service.py @@ -79,7 +79,7 @@ class RecipePersistenceService: current_time = time.time() loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])] - checkpoint_entry = metadata.get("checkpoint") + checkpoint_entry = self._sanitize_checkpoint_entry(self._extract_checkpoint_entry(metadata)) gen_params = metadata.get("gen_params") or {} if not gen_params and "raw_metadata" in metadata: @@ -87,7 +87,6 @@ class RecipePersistenceService: gen_params = { "prompt": raw_metadata.get("prompt", ""), "negative_prompt": raw_metadata.get("negative_prompt", ""), - "checkpoint": raw_metadata.get("checkpoint", {}), "steps": raw_metadata.get("steps", ""), "sampler": raw_metadata.get("sampler", ""), "cfg_scale": raw_metadata.get("cfg_scale", ""), @@ -95,8 +94,9 @@ class RecipePersistenceService: "size": raw_metadata.get("size", ""), "clip_skip": raw_metadata.get("clip_skip", ""), } - if checkpoint_entry and "checkpoint" not in gen_params: - gen_params["checkpoint"] = checkpoint_entry + + # Drop checkpoint duplication from generation parameters to store it only at top level + gen_params.pop("checkpoint", None) fingerprint = calculate_recipe_fingerprint(loras_data) recipe_data: Dict[str, Any] = { @@ -335,7 +335,7 @@ class RecipePersistenceService: "created_date": time.time(), "base_model": most_common_base_model, "loras": loras_data, - "checkpoint": metadata.get("checkpoint", ""), + "checkpoint": self._sanitize_checkpoint_entry(metadata.get("checkpoint", "")), "gen_params": { key: value for key, value in metadata.items() @@ -364,6 +364,30 @@ class RecipePersistenceService: # Helper methods --------------------------------------------------- + def _extract_checkpoint_entry(self, metadata: dict[str, Any]) -> Optional[dict[str, Any]]: + """Pull a checkpoint entry from various metadata locations.""" + + checkpoint_entry = metadata.get("checkpoint") or metadata.get("model") + if not checkpoint_entry: + gen_params = metadata.get("gen_params") or {} + checkpoint_entry = gen_params.get("checkpoint") + + return checkpoint_entry if isinstance(checkpoint_entry, dict) else None + + def _sanitize_checkpoint_entry(self, checkpoint_entry: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]: + """Remove transient/local-only fields from checkpoint metadata.""" + + if not checkpoint_entry: + return None + + if not isinstance(checkpoint_entry, dict): + return checkpoint_entry + + pruned = dict(checkpoint_entry) + for key in ("existsLocally", "localPath", "thumbnailUrl", "size", "downloadUrl"): + pruned.pop(key, None) + return pruned + def _resolve_image_bytes(self, image_bytes: bytes | None, image_base64: str | None) -> bytes: if image_bytes is not None: return image_bytes diff --git a/static/js/managers/import/DownloadManager.js b/static/js/managers/import/DownloadManager.js index b8b3a4e4..dc4b24b7 100644 --- a/static/js/managers/import/DownloadManager.js +++ b/static/js/managers/import/DownloadManager.js @@ -56,6 +56,15 @@ export class DownloadManager { gen_params: this.importManager.recipeData.gen_params || {}, raw_metadata: this.importManager.recipeData.raw_metadata || {} }; + + const checkpointMetadata = + this.importManager.recipeData.checkpoint || + this.importManager.recipeData.model || + (this.importManager.recipeData.gen_params || {}).checkpoint; + + if (checkpointMetadata && typeof checkpointMetadata === 'object') { + completeMetadata.checkpoint = checkpointMetadata; + } // Add source_path to metadata to track where the recipe was imported from if (this.importManager.importMode === 'url') { diff --git a/tests/services/test_recipe_services.py b/tests/services/test_recipe_services.py index fb6dd4e5..d0fc1462 100644 --- a/tests/services/test_recipe_services.py +++ b/tests/services/test_recipe_services.py @@ -203,7 +203,117 @@ async def test_save_recipe_persists_checkpoint_metadata(tmp_path): stored = json.loads(Path(result.payload["json_path"]).read_text()) assert stored["checkpoint"] == checkpoint_meta - assert stored["gen_params"]["checkpoint"] == checkpoint_meta + assert "checkpoint" not in stored["gen_params"] + + +@pytest.mark.asyncio +async def test_save_recipe_promotes_checkpoint_from_gen_params(tmp_path): + exif_utils = DummyExifUtils() + + class DummyScanner: + def __init__(self, root): + self.recipes_dir = str(root) + + async def find_recipes_by_fingerprint(self, fingerprint): + return [] + + async def add_recipe(self, recipe_data): + return None + + scanner = DummyScanner(tmp_path) + service = RecipePersistenceService( + exif_utils=exif_utils, + card_preview_width=512, + logger=logging.getLogger("test"), + ) + + checkpoint_meta = { + "type": "checkpoint", + "modelId": 10, + "modelVersionId": 20, + "modelName": "Flux", + "modelVersionName": "Dev", + } + + metadata = { + "base_model": "Flux", + "loras": [], + "gen_params": { + "checkpoint": checkpoint_meta, + }, + } + + result = await service.save_recipe( + recipe_scanner=scanner, + image_bytes=b"img", + image_base64=None, + name="Checkpointed", + tags=[], + metadata=metadata, + ) + + stored = json.loads(Path(result.payload["json_path"]).read_text()) + assert stored["checkpoint"] == checkpoint_meta + assert "checkpoint" not in stored["gen_params"] + + +@pytest.mark.asyncio +async def test_save_recipe_strips_checkpoint_local_fields(tmp_path): + exif_utils = DummyExifUtils() + + class DummyScanner: + def __init__(self, root): + self.recipes_dir = str(root) + + async def find_recipes_by_fingerprint(self, fingerprint): + return [] + + async def add_recipe(self, recipe_data): + return None + + scanner = DummyScanner(tmp_path) + service = RecipePersistenceService( + exif_utils=exif_utils, + card_preview_width=512, + logger=logging.getLogger("test"), + ) + + checkpoint_meta = { + "type": "checkpoint", + "modelId": 10, + "modelVersionId": 20, + "modelName": "Flux", + "modelVersionName": "Dev", + "existsLocally": False, + "localPath": "/tmp/foo", + "thumbnailUrl": "http://example.com", + "size": 123, + "downloadUrl": "http://example.com/dl", + } + + metadata = { + "base_model": "Flux", + "loras": [], + "checkpoint": checkpoint_meta, + } + + result = await service.save_recipe( + recipe_scanner=scanner, + image_bytes=b"img", + image_base64=None, + name="Checkpointed", + tags=[], + metadata=metadata, + ) + + stored = json.loads(Path(result.payload["json_path"]).read_text()) + assert stored["checkpoint"] == { + "type": "checkpoint", + "modelId": 10, + "modelVersionId": 20, + "modelName": "Flux", + "modelVersionName": "Dev", + } @pytest.mark.asyncio