feat: consolidate checkpoint metadata handling

- Extract checkpoint entry from multiple metadata locations using helper method
- Sanitize checkpoint metadata by removing transient/local-only fields
- Remove checkpoint duplication from generation parameters to store only at top level
- Update frontend to properly populate checkpoint metadata during import
- Add tests for new checkpoint handling functionality

This ensures consistent checkpoint metadata structure and prevents data duplication across different storage locations.
This commit is contained in:
Will Miao
2025-11-21 14:55:45 +08:00
parent 36f28b3c65
commit 4eb46a8d3e
3 changed files with 149 additions and 6 deletions

View File

@@ -79,7 +79,7 @@ class RecipePersistenceService:
current_time = time.time()
loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])]
checkpoint_entry = metadata.get("checkpoint")
checkpoint_entry = self._sanitize_checkpoint_entry(self._extract_checkpoint_entry(metadata))
gen_params = metadata.get("gen_params") or {}
if not gen_params and "raw_metadata" in metadata:
@@ -87,7 +87,6 @@ class RecipePersistenceService:
gen_params = {
"prompt": raw_metadata.get("prompt", ""),
"negative_prompt": raw_metadata.get("negative_prompt", ""),
"checkpoint": raw_metadata.get("checkpoint", {}),
"steps": raw_metadata.get("steps", ""),
"sampler": raw_metadata.get("sampler", ""),
"cfg_scale": raw_metadata.get("cfg_scale", ""),
@@ -95,8 +94,9 @@ class RecipePersistenceService:
"size": raw_metadata.get("size", ""),
"clip_skip": raw_metadata.get("clip_skip", ""),
}
if checkpoint_entry and "checkpoint" not in gen_params:
gen_params["checkpoint"] = checkpoint_entry
# Drop checkpoint duplication from generation parameters to store it only at top level
gen_params.pop("checkpoint", None)
fingerprint = calculate_recipe_fingerprint(loras_data)
recipe_data: Dict[str, Any] = {
@@ -335,7 +335,7 @@ class RecipePersistenceService:
"created_date": time.time(),
"base_model": most_common_base_model,
"loras": loras_data,
"checkpoint": metadata.get("checkpoint", ""),
"checkpoint": self._sanitize_checkpoint_entry(metadata.get("checkpoint", "")),
"gen_params": {
key: value
for key, value in metadata.items()
@@ -364,6 +364,30 @@ class RecipePersistenceService:
# Helper methods ---------------------------------------------------
def _extract_checkpoint_entry(self, metadata: dict[str, Any]) -> Optional[dict[str, Any]]:
"""Pull a checkpoint entry from various metadata locations."""
checkpoint_entry = metadata.get("checkpoint") or metadata.get("model")
if not checkpoint_entry:
gen_params = metadata.get("gen_params") or {}
checkpoint_entry = gen_params.get("checkpoint")
return checkpoint_entry if isinstance(checkpoint_entry, dict) else None
def _sanitize_checkpoint_entry(self, checkpoint_entry: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
"""Remove transient/local-only fields from checkpoint metadata."""
if not checkpoint_entry:
return None
if not isinstance(checkpoint_entry, dict):
return checkpoint_entry
pruned = dict(checkpoint_entry)
for key in ("existsLocally", "localPath", "thumbnailUrl", "size", "downloadUrl"):
pruned.pop(key, None)
return pruned
def _resolve_image_bytes(self, image_bytes: bytes | None, image_base64: str | None) -> bytes:
if image_bytes is not None:
return image_bytes

View File

@@ -56,6 +56,15 @@ export class DownloadManager {
gen_params: this.importManager.recipeData.gen_params || {},
raw_metadata: this.importManager.recipeData.raw_metadata || {}
};
const checkpointMetadata =
this.importManager.recipeData.checkpoint ||
this.importManager.recipeData.model ||
(this.importManager.recipeData.gen_params || {}).checkpoint;
if (checkpointMetadata && typeof checkpointMetadata === 'object') {
completeMetadata.checkpoint = checkpointMetadata;
}
// Add source_path to metadata to track where the recipe was imported from
if (this.importManager.importMode === 'url') {

View File

@@ -203,7 +203,117 @@ async def test_save_recipe_persists_checkpoint_metadata(tmp_path):
stored = json.loads(Path(result.payload["json_path"]).read_text())
assert stored["checkpoint"] == checkpoint_meta
assert stored["gen_params"]["checkpoint"] == checkpoint_meta
assert "checkpoint" not in stored["gen_params"]
@pytest.mark.asyncio
async def test_save_recipe_promotes_checkpoint_from_gen_params(tmp_path):
exif_utils = DummyExifUtils()
class DummyScanner:
def __init__(self, root):
self.recipes_dir = str(root)
async def find_recipes_by_fingerprint(self, fingerprint):
return []
async def add_recipe(self, recipe_data):
return None
scanner = DummyScanner(tmp_path)
service = RecipePersistenceService(
exif_utils=exif_utils,
card_preview_width=512,
logger=logging.getLogger("test"),
)
checkpoint_meta = {
"type": "checkpoint",
"modelId": 10,
"modelVersionId": 20,
"modelName": "Flux",
"modelVersionName": "Dev",
}
metadata = {
"base_model": "Flux",
"loras": [],
"gen_params": {
"checkpoint": checkpoint_meta,
},
}
result = await service.save_recipe(
recipe_scanner=scanner,
image_bytes=b"img",
image_base64=None,
name="Checkpointed",
tags=[],
metadata=metadata,
)
stored = json.loads(Path(result.payload["json_path"]).read_text())
assert stored["checkpoint"] == checkpoint_meta
assert "checkpoint" not in stored["gen_params"]
@pytest.mark.asyncio
async def test_save_recipe_strips_checkpoint_local_fields(tmp_path):
exif_utils = DummyExifUtils()
class DummyScanner:
def __init__(self, root):
self.recipes_dir = str(root)
async def find_recipes_by_fingerprint(self, fingerprint):
return []
async def add_recipe(self, recipe_data):
return None
scanner = DummyScanner(tmp_path)
service = RecipePersistenceService(
exif_utils=exif_utils,
card_preview_width=512,
logger=logging.getLogger("test"),
)
checkpoint_meta = {
"type": "checkpoint",
"modelId": 10,
"modelVersionId": 20,
"modelName": "Flux",
"modelVersionName": "Dev",
"existsLocally": False,
"localPath": "/tmp/foo",
"thumbnailUrl": "http://example.com",
"size": 123,
"downloadUrl": "http://example.com/dl",
}
metadata = {
"base_model": "Flux",
"loras": [],
"checkpoint": checkpoint_meta,
}
result = await service.save_recipe(
recipe_scanner=scanner,
image_bytes=b"img",
image_base64=None,
name="Checkpointed",
tags=[],
metadata=metadata,
)
stored = json.loads(Path(result.payload["json_path"]).read_text())
assert stored["checkpoint"] == {
"type": "checkpoint",
"modelId": 10,
"modelVersionId": 20,
"modelName": "Flux",
"modelVersionName": "Dev",
}
@pytest.mark.asyncio