diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index 1d649791..806f12b3 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -441,7 +441,12 @@ class RecipeScanner: await self._update_lora_information(recipe_data) if recipe_data.get('checkpoint'): - recipe_data['checkpoint'] = self._enrich_checkpoint_entry(dict(recipe_data['checkpoint'])) + checkpoint_entry = self._normalize_checkpoint_entry(recipe_data['checkpoint']) + if checkpoint_entry: + recipe_data['checkpoint'] = self._enrich_checkpoint_entry(checkpoint_entry) + else: + logger.warning("Dropping invalid checkpoint entry in %s", recipe_path) + recipe_data.pop('checkpoint', None) # Calculate and update fingerprint if missing if 'loras' in recipe_data and 'fingerprint' not in recipe_data: @@ -665,6 +670,29 @@ class RecipeScanner: logger.error(f"Error getting base model for lora: {e}") return None + def _normalize_checkpoint_entry(self, checkpoint_raw: Any) -> Optional[Dict[str, Any]]: + """Coerce legacy or malformed checkpoint entries into a dict.""" + + if isinstance(checkpoint_raw, dict): + return dict(checkpoint_raw) + + if isinstance(checkpoint_raw, (list, tuple)) and len(checkpoint_raw) == 1: + return self._normalize_checkpoint_entry(checkpoint_raw[0]) + + if isinstance(checkpoint_raw, str): + name = checkpoint_raw.strip() + if not name: + return None + + file_name = os.path.splitext(os.path.basename(name))[0] + return { + "name": name, + "file_name": file_name, + } + + logger.warning("Unexpected checkpoint payload type %s", type(checkpoint_raw).__name__) + return None + def _enrich_checkpoint_entry(self, checkpoint: Dict[str, Any]) -> Dict[str, Any]: """Populate convenience fields for a checkpoint entry.""" diff --git a/tests/services/test_recipe_scanner.py b/tests/services/test_recipe_scanner.py index 45d0ce33..9d2b2519 100644 --- a/tests/services/test_recipe_scanner.py +++ b/tests/services/test_recipe_scanner.py @@ -226,6 +226,36 @@ async def test_load_recipe_rewrites_missing_image_path(tmp_path: Path, recipe_sc assert persisted["file_path"] == expected_path +@pytest.mark.asyncio +async def test_load_recipe_upgrades_string_checkpoint(tmp_path: Path, recipe_scanner): + scanner, _ = recipe_scanner + recipes_dir = Path(config.loras_roots[0]) / "recipes" + recipes_dir.mkdir(parents=True, exist_ok=True) + + recipe_id = "legacy-checkpoint" + image_path = recipes_dir / f"{recipe_id}.webp" + recipe_path = recipes_dir / f"{recipe_id}.recipe.json" + recipe_path.write_text( + json.dumps( + { + "id": recipe_id, + "file_path": str(image_path), + "title": "Legacy", + "modified": 0.0, + "created_date": 0.0, + "loras": [], + "checkpoint": "sd15.safetensors", + } + ) + ) + + loaded = await scanner._load_recipe_file(str(recipe_path)) + + assert isinstance(loaded["checkpoint"], dict) + assert loaded["checkpoint"]["name"] == "sd15.safetensors" + assert loaded["checkpoint"]["file_name"] == "sd15" + + def test_enrich_uses_version_index_when_hash_missing(recipe_scanner): scanner, stub = recipe_scanner version_id = 77