mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: consolidate checkpoint metadata handling
- Extract checkpoint entry from multiple metadata locations using helper method - Sanitize checkpoint metadata by removing transient/local-only fields - Remove checkpoint duplication from generation parameters to store only at top level - Update frontend to properly populate checkpoint metadata during import - Add tests for new checkpoint handling functionality This ensures consistent checkpoint metadata structure and prevents data duplication across different storage locations.
This commit is contained in:
@@ -79,7 +79,7 @@ class RecipePersistenceService:
|
||||
|
||||
current_time = time.time()
|
||||
loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])]
|
||||
checkpoint_entry = metadata.get("checkpoint")
|
||||
checkpoint_entry = self._sanitize_checkpoint_entry(self._extract_checkpoint_entry(metadata))
|
||||
|
||||
gen_params = metadata.get("gen_params") or {}
|
||||
if not gen_params and "raw_metadata" in metadata:
|
||||
@@ -87,7 +87,6 @@ class RecipePersistenceService:
|
||||
gen_params = {
|
||||
"prompt": raw_metadata.get("prompt", ""),
|
||||
"negative_prompt": raw_metadata.get("negative_prompt", ""),
|
||||
"checkpoint": raw_metadata.get("checkpoint", {}),
|
||||
"steps": raw_metadata.get("steps", ""),
|
||||
"sampler": raw_metadata.get("sampler", ""),
|
||||
"cfg_scale": raw_metadata.get("cfg_scale", ""),
|
||||
@@ -95,8 +94,9 @@ class RecipePersistenceService:
|
||||
"size": raw_metadata.get("size", ""),
|
||||
"clip_skip": raw_metadata.get("clip_skip", ""),
|
||||
}
|
||||
if checkpoint_entry and "checkpoint" not in gen_params:
|
||||
gen_params["checkpoint"] = checkpoint_entry
|
||||
|
||||
# Drop checkpoint duplication from generation parameters to store it only at top level
|
||||
gen_params.pop("checkpoint", None)
|
||||
|
||||
fingerprint = calculate_recipe_fingerprint(loras_data)
|
||||
recipe_data: Dict[str, Any] = {
|
||||
@@ -335,7 +335,7 @@ class RecipePersistenceService:
|
||||
"created_date": time.time(),
|
||||
"base_model": most_common_base_model,
|
||||
"loras": loras_data,
|
||||
"checkpoint": metadata.get("checkpoint", ""),
|
||||
"checkpoint": self._sanitize_checkpoint_entry(metadata.get("checkpoint", "")),
|
||||
"gen_params": {
|
||||
key: value
|
||||
for key, value in metadata.items()
|
||||
@@ -364,6 +364,30 @@ class RecipePersistenceService:
|
||||
|
||||
# Helper methods ---------------------------------------------------
|
||||
|
||||
def _extract_checkpoint_entry(self, metadata: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||
"""Pull a checkpoint entry from various metadata locations."""
|
||||
|
||||
checkpoint_entry = metadata.get("checkpoint") or metadata.get("model")
|
||||
if not checkpoint_entry:
|
||||
gen_params = metadata.get("gen_params") or {}
|
||||
checkpoint_entry = gen_params.get("checkpoint")
|
||||
|
||||
return checkpoint_entry if isinstance(checkpoint_entry, dict) else None
|
||||
|
||||
def _sanitize_checkpoint_entry(self, checkpoint_entry: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
|
||||
"""Remove transient/local-only fields from checkpoint metadata."""
|
||||
|
||||
if not checkpoint_entry:
|
||||
return None
|
||||
|
||||
if not isinstance(checkpoint_entry, dict):
|
||||
return checkpoint_entry
|
||||
|
||||
pruned = dict(checkpoint_entry)
|
||||
for key in ("existsLocally", "localPath", "thumbnailUrl", "size", "downloadUrl"):
|
||||
pruned.pop(key, None)
|
||||
return pruned
|
||||
|
||||
def _resolve_image_bytes(self, image_bytes: bytes | None, image_base64: str | None) -> bytes:
|
||||
if image_bytes is not None:
|
||||
return image_bytes
|
||||
|
||||
@@ -56,6 +56,15 @@ export class DownloadManager {
|
||||
gen_params: this.importManager.recipeData.gen_params || {},
|
||||
raw_metadata: this.importManager.recipeData.raw_metadata || {}
|
||||
};
|
||||
|
||||
const checkpointMetadata =
|
||||
this.importManager.recipeData.checkpoint ||
|
||||
this.importManager.recipeData.model ||
|
||||
(this.importManager.recipeData.gen_params || {}).checkpoint;
|
||||
|
||||
if (checkpointMetadata && typeof checkpointMetadata === 'object') {
|
||||
completeMetadata.checkpoint = checkpointMetadata;
|
||||
}
|
||||
|
||||
// Add source_path to metadata to track where the recipe was imported from
|
||||
if (this.importManager.importMode === 'url') {
|
||||
|
||||
@@ -203,7 +203,117 @@ async def test_save_recipe_persists_checkpoint_metadata(tmp_path):
|
||||
|
||||
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||
assert stored["checkpoint"] == checkpoint_meta
|
||||
assert stored["gen_params"]["checkpoint"] == checkpoint_meta
|
||||
assert "checkpoint" not in stored["gen_params"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_recipe_promotes_checkpoint_from_gen_params(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, root):
|
||||
self.recipes_dir = str(root)
|
||||
|
||||
async def find_recipes_by_fingerprint(self, fingerprint):
|
||||
return []
|
||||
|
||||
async def add_recipe(self, recipe_data):
|
||||
return None
|
||||
|
||||
scanner = DummyScanner(tmp_path)
|
||||
service = RecipePersistenceService(
|
||||
exif_utils=exif_utils,
|
||||
card_preview_width=512,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
checkpoint_meta = {
|
||||
"type": "checkpoint",
|
||||
"modelId": 10,
|
||||
"modelVersionId": 20,
|
||||
"modelName": "Flux",
|
||||
"modelVersionName": "Dev",
|
||||
}
|
||||
|
||||
metadata = {
|
||||
"base_model": "Flux",
|
||||
"loras": [],
|
||||
"gen_params": {
|
||||
"checkpoint": checkpoint_meta,
|
||||
},
|
||||
}
|
||||
|
||||
result = await service.save_recipe(
|
||||
recipe_scanner=scanner,
|
||||
image_bytes=b"img",
|
||||
image_base64=None,
|
||||
name="Checkpointed",
|
||||
tags=[],
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||
assert stored["checkpoint"] == checkpoint_meta
|
||||
assert "checkpoint" not in stored["gen_params"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_recipe_strips_checkpoint_local_fields(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, root):
|
||||
self.recipes_dir = str(root)
|
||||
|
||||
async def find_recipes_by_fingerprint(self, fingerprint):
|
||||
return []
|
||||
|
||||
async def add_recipe(self, recipe_data):
|
||||
return None
|
||||
|
||||
scanner = DummyScanner(tmp_path)
|
||||
service = RecipePersistenceService(
|
||||
exif_utils=exif_utils,
|
||||
card_preview_width=512,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
checkpoint_meta = {
|
||||
"type": "checkpoint",
|
||||
"modelId": 10,
|
||||
"modelVersionId": 20,
|
||||
"modelName": "Flux",
|
||||
"modelVersionName": "Dev",
|
||||
"existsLocally": False,
|
||||
"localPath": "/tmp/foo",
|
||||
"thumbnailUrl": "http://example.com",
|
||||
"size": 123,
|
||||
"downloadUrl": "http://example.com/dl",
|
||||
}
|
||||
|
||||
metadata = {
|
||||
"base_model": "Flux",
|
||||
"loras": [],
|
||||
"checkpoint": checkpoint_meta,
|
||||
}
|
||||
|
||||
result = await service.save_recipe(
|
||||
recipe_scanner=scanner,
|
||||
image_bytes=b"img",
|
||||
image_base64=None,
|
||||
name="Checkpointed",
|
||||
tags=[],
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||
assert stored["checkpoint"] == {
|
||||
"type": "checkpoint",
|
||||
"modelId": 10,
|
||||
"modelVersionId": 20,
|
||||
"modelName": "Flux",
|
||||
"modelVersionName": "Dev",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user