feat: normalize and validate checkpoint entries before enrichment

Add _normalize_checkpoint_entry method to handle legacy checkpoint data formats (strings, tuples) by converting them to dictionaries. This prevents errors during enrichment when checkpoint data is not in the expected dictionary format. Invalid checkpoint entries are now removed instead of causing processing failures.

- Update get_paginated_data and get_recipe_by_id methods to use normalization
- Add test cases for legacy string and tuple checkpoint formats
- Ensure backward compatibility with existing checkpoint handling
This commit is contained in:
Will Miao
2025-11-21 23:36:32 +08:00
parent 02bac7edfb
commit 9198a23ba9
2 changed files with 56 additions and 2 deletions

View File

@@ -949,7 +949,11 @@ class RecipeScanner:
if 'loras' in item:
item['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in item['loras']]
if item.get('checkpoint'):
item['checkpoint'] = self._enrich_checkpoint_entry(dict(item['checkpoint']))
checkpoint_entry = self._normalize_checkpoint_entry(item['checkpoint'])
if checkpoint_entry:
item['checkpoint'] = self._enrich_checkpoint_entry(checkpoint_entry)
else:
item.pop('checkpoint', None)
result = {
'items': paginated_items,
@@ -998,7 +1002,11 @@ class RecipeScanner:
if 'loras' in formatted_recipe:
formatted_recipe['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in formatted_recipe['loras']]
if formatted_recipe.get('checkpoint'):
formatted_recipe['checkpoint'] = self._enrich_checkpoint_entry(dict(formatted_recipe['checkpoint']))
checkpoint_entry = self._normalize_checkpoint_entry(formatted_recipe['checkpoint'])
if checkpoint_entry:
formatted_recipe['checkpoint'] = self._enrich_checkpoint_entry(checkpoint_entry)
else:
formatted_recipe.pop('checkpoint', None)
return formatted_recipe

View File

@@ -256,6 +256,52 @@ async def test_load_recipe_upgrades_string_checkpoint(tmp_path: Path, recipe_sca
assert loaded["checkpoint"]["file_name"] == "sd15"
@pytest.mark.asyncio
async def test_get_paginated_data_normalizes_legacy_checkpoint(recipe_scanner):
    """Paginated results expose a list-form checkpoint as a dictionary.

    A recipe whose ``checkpoint`` is the one-element list
    ``["legacy.safetensors"]`` is added to the scanner. The first item of
    the first page must then carry a checkpoint mapping with ``name``
    ``"legacy.safetensors"`` and ``file_name`` ``"legacy"``.
    """
    scanner, _unused = recipe_scanner

    legacy_recipe = {
        "id": "legacy-checkpoint",
        "file_path": str(Path(config.loras_roots[0]).joinpath("legacy.webp")),
        "title": "Legacy",
        "modified": 0.0,
        "created_date": 0.0,
        "loras": [],
        # Legacy shape: a bare list instead of a checkpoint dictionary.
        "checkpoint": ["legacy.safetensors"],
    }
    await scanner.add_recipe(legacy_recipe)

    # Yield to the event loop once before querying.
    # NOTE(review): presumably lets pending scanner work run -- confirm.
    await asyncio.sleep(0)

    page = await scanner.get_paginated_data(page=1, page_size=5)
    first_checkpoint = page["items"][0]["checkpoint"]

    assert first_checkpoint["name"] == "legacy.safetensors"
    assert first_checkpoint["file_name"] == "legacy"
@pytest.mark.asyncio
async def test_get_recipe_by_id_handles_non_dict_checkpoint(recipe_scanner):
    """Lookup by id exposes a tuple-form checkpoint as a dictionary.

    A recipe whose ``checkpoint`` is the one-element tuple
    ``("by-id.safetensors",)`` is added to the scanner. Fetching it by id
    must return a checkpoint mapping with ``name`` ``"by-id.safetensors"``
    and ``file_name`` ``"by-id"``.
    """
    scanner, _unused = recipe_scanner

    recipe_id = "by-id-checkpoint"
    tuple_recipe = {
        "id": recipe_id,
        "file_path": str(Path(config.loras_roots[0]).joinpath("by-id.webp")),
        "title": "ById",
        "modified": 0.0,
        "created_date": 0.0,
        "loras": [],
        # Non-dict shape: a bare tuple instead of a checkpoint dictionary.
        "checkpoint": ("by-id.safetensors",),
    }
    await scanner.add_recipe(tuple_recipe)

    fetched = await scanner.get_recipe_by_id(recipe_id)
    fetched_checkpoint = fetched["checkpoint"]

    assert fetched_checkpoint["name"] == "by-id.safetensors"
    assert fetched_checkpoint["file_name"] == "by-id"
def test_enrich_uses_version_index_when_hash_missing(recipe_scanner):
scanner, stub = recipe_scanner
version_id = 77