mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
feat: consolidate checkpoint metadata handling
- Extract checkpoint entry from multiple metadata locations using helper method - Sanitize checkpoint metadata by removing transient/local-only fields - Remove checkpoint duplication from generation parameters to store only at top level - Update frontend to properly populate checkpoint metadata during import - Add tests for new checkpoint handling functionality This ensures consistent checkpoint metadata structure and prevents data duplication across different storage locations.
This commit is contained in:
@@ -79,7 +79,7 @@ class RecipePersistenceService:
|
|||||||
|
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])]
|
loras_data = [self._normalise_lora_entry(lora) for lora in (metadata.get("loras") or [])]
|
||||||
checkpoint_entry = metadata.get("checkpoint")
|
checkpoint_entry = self._sanitize_checkpoint_entry(self._extract_checkpoint_entry(metadata))
|
||||||
|
|
||||||
gen_params = metadata.get("gen_params") or {}
|
gen_params = metadata.get("gen_params") or {}
|
||||||
if not gen_params and "raw_metadata" in metadata:
|
if not gen_params and "raw_metadata" in metadata:
|
||||||
@@ -87,7 +87,6 @@ class RecipePersistenceService:
|
|||||||
gen_params = {
|
gen_params = {
|
||||||
"prompt": raw_metadata.get("prompt", ""),
|
"prompt": raw_metadata.get("prompt", ""),
|
||||||
"negative_prompt": raw_metadata.get("negative_prompt", ""),
|
"negative_prompt": raw_metadata.get("negative_prompt", ""),
|
||||||
"checkpoint": raw_metadata.get("checkpoint", {}),
|
|
||||||
"steps": raw_metadata.get("steps", ""),
|
"steps": raw_metadata.get("steps", ""),
|
||||||
"sampler": raw_metadata.get("sampler", ""),
|
"sampler": raw_metadata.get("sampler", ""),
|
||||||
"cfg_scale": raw_metadata.get("cfg_scale", ""),
|
"cfg_scale": raw_metadata.get("cfg_scale", ""),
|
||||||
@@ -95,8 +94,9 @@ class RecipePersistenceService:
|
|||||||
"size": raw_metadata.get("size", ""),
|
"size": raw_metadata.get("size", ""),
|
||||||
"clip_skip": raw_metadata.get("clip_skip", ""),
|
"clip_skip": raw_metadata.get("clip_skip", ""),
|
||||||
}
|
}
|
||||||
if checkpoint_entry and "checkpoint" not in gen_params:
|
|
||||||
gen_params["checkpoint"] = checkpoint_entry
|
# Drop checkpoint duplication from generation parameters to store it only at top level
|
||||||
|
gen_params.pop("checkpoint", None)
|
||||||
|
|
||||||
fingerprint = calculate_recipe_fingerprint(loras_data)
|
fingerprint = calculate_recipe_fingerprint(loras_data)
|
||||||
recipe_data: Dict[str, Any] = {
|
recipe_data: Dict[str, Any] = {
|
||||||
@@ -335,7 +335,7 @@ class RecipePersistenceService:
|
|||||||
"created_date": time.time(),
|
"created_date": time.time(),
|
||||||
"base_model": most_common_base_model,
|
"base_model": most_common_base_model,
|
||||||
"loras": loras_data,
|
"loras": loras_data,
|
||||||
"checkpoint": metadata.get("checkpoint", ""),
|
"checkpoint": self._sanitize_checkpoint_entry(metadata.get("checkpoint", "")),
|
||||||
"gen_params": {
|
"gen_params": {
|
||||||
key: value
|
key: value
|
||||||
for key, value in metadata.items()
|
for key, value in metadata.items()
|
||||||
@@ -364,6 +364,30 @@ class RecipePersistenceService:
|
|||||||
|
|
||||||
# Helper methods ---------------------------------------------------
|
# Helper methods ---------------------------------------------------
|
||||||
|
|
||||||
|
def _extract_checkpoint_entry(self, metadata: dict[str, Any]) -> Optional[dict[str, Any]]:
|
||||||
|
"""Pull a checkpoint entry from various metadata locations."""
|
||||||
|
|
||||||
|
checkpoint_entry = metadata.get("checkpoint") or metadata.get("model")
|
||||||
|
if not checkpoint_entry:
|
||||||
|
gen_params = metadata.get("gen_params") or {}
|
||||||
|
checkpoint_entry = gen_params.get("checkpoint")
|
||||||
|
|
||||||
|
return checkpoint_entry if isinstance(checkpoint_entry, dict) else None
|
||||||
|
|
||||||
|
def _sanitize_checkpoint_entry(self, checkpoint_entry: Optional[dict[str, Any]]) -> Optional[dict[str, Any]]:
|
||||||
|
"""Remove transient/local-only fields from checkpoint metadata."""
|
||||||
|
|
||||||
|
if not checkpoint_entry:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not isinstance(checkpoint_entry, dict):
|
||||||
|
return checkpoint_entry
|
||||||
|
|
||||||
|
pruned = dict(checkpoint_entry)
|
||||||
|
for key in ("existsLocally", "localPath", "thumbnailUrl", "size", "downloadUrl"):
|
||||||
|
pruned.pop(key, None)
|
||||||
|
return pruned
|
||||||
|
|
||||||
def _resolve_image_bytes(self, image_bytes: bytes | None, image_base64: str | None) -> bytes:
|
def _resolve_image_bytes(self, image_bytes: bytes | None, image_base64: str | None) -> bytes:
|
||||||
if image_bytes is not None:
|
if image_bytes is not None:
|
||||||
return image_bytes
|
return image_bytes
|
||||||
|
|||||||
@@ -56,6 +56,15 @@ export class DownloadManager {
|
|||||||
gen_params: this.importManager.recipeData.gen_params || {},
|
gen_params: this.importManager.recipeData.gen_params || {},
|
||||||
raw_metadata: this.importManager.recipeData.raw_metadata || {}
|
raw_metadata: this.importManager.recipeData.raw_metadata || {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const checkpointMetadata =
|
||||||
|
this.importManager.recipeData.checkpoint ||
|
||||||
|
this.importManager.recipeData.model ||
|
||||||
|
(this.importManager.recipeData.gen_params || {}).checkpoint;
|
||||||
|
|
||||||
|
if (checkpointMetadata && typeof checkpointMetadata === 'object') {
|
||||||
|
completeMetadata.checkpoint = checkpointMetadata;
|
||||||
|
}
|
||||||
|
|
||||||
// Add source_path to metadata to track where the recipe was imported from
|
// Add source_path to metadata to track where the recipe was imported from
|
||||||
if (this.importManager.importMode === 'url') {
|
if (this.importManager.importMode === 'url') {
|
||||||
|
|||||||
@@ -203,7 +203,117 @@ async def test_save_recipe_persists_checkpoint_metadata(tmp_path):
|
|||||||
|
|
||||||
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||||
assert stored["checkpoint"] == checkpoint_meta
|
assert stored["checkpoint"] == checkpoint_meta
|
||||||
assert stored["gen_params"]["checkpoint"] == checkpoint_meta
|
assert "checkpoint" not in stored["gen_params"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_recipe_promotes_checkpoint_from_gen_params(tmp_path):
|
||||||
|
exif_utils = DummyExifUtils()
|
||||||
|
|
||||||
|
class DummyScanner:
|
||||||
|
def __init__(self, root):
|
||||||
|
self.recipes_dir = str(root)
|
||||||
|
|
||||||
|
async def find_recipes_by_fingerprint(self, fingerprint):
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def add_recipe(self, recipe_data):
|
||||||
|
return None
|
||||||
|
|
||||||
|
scanner = DummyScanner(tmp_path)
|
||||||
|
service = RecipePersistenceService(
|
||||||
|
exif_utils=exif_utils,
|
||||||
|
card_preview_width=512,
|
||||||
|
logger=logging.getLogger("test"),
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoint_meta = {
|
||||||
|
"type": "checkpoint",
|
||||||
|
"modelId": 10,
|
||||||
|
"modelVersionId": 20,
|
||||||
|
"modelName": "Flux",
|
||||||
|
"modelVersionName": "Dev",
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"base_model": "Flux",
|
||||||
|
"loras": [],
|
||||||
|
"gen_params": {
|
||||||
|
"checkpoint": checkpoint_meta,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await service.save_recipe(
|
||||||
|
recipe_scanner=scanner,
|
||||||
|
image_bytes=b"img",
|
||||||
|
image_base64=None,
|
||||||
|
name="Checkpointed",
|
||||||
|
tags=[],
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||||
|
assert stored["checkpoint"] == checkpoint_meta
|
||||||
|
assert "checkpoint" not in stored["gen_params"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_recipe_strips_checkpoint_local_fields(tmp_path):
|
||||||
|
exif_utils = DummyExifUtils()
|
||||||
|
|
||||||
|
class DummyScanner:
|
||||||
|
def __init__(self, root):
|
||||||
|
self.recipes_dir = str(root)
|
||||||
|
|
||||||
|
async def find_recipes_by_fingerprint(self, fingerprint):
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def add_recipe(self, recipe_data):
|
||||||
|
return None
|
||||||
|
|
||||||
|
scanner = DummyScanner(tmp_path)
|
||||||
|
service = RecipePersistenceService(
|
||||||
|
exif_utils=exif_utils,
|
||||||
|
card_preview_width=512,
|
||||||
|
logger=logging.getLogger("test"),
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoint_meta = {
|
||||||
|
"type": "checkpoint",
|
||||||
|
"modelId": 10,
|
||||||
|
"modelVersionId": 20,
|
||||||
|
"modelName": "Flux",
|
||||||
|
"modelVersionName": "Dev",
|
||||||
|
"existsLocally": False,
|
||||||
|
"localPath": "/tmp/foo",
|
||||||
|
"thumbnailUrl": "http://example.com",
|
||||||
|
"size": 123,
|
||||||
|
"downloadUrl": "http://example.com/dl",
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"base_model": "Flux",
|
||||||
|
"loras": [],
|
||||||
|
"checkpoint": checkpoint_meta,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await service.save_recipe(
|
||||||
|
recipe_scanner=scanner,
|
||||||
|
image_bytes=b"img",
|
||||||
|
image_base64=None,
|
||||||
|
name="Checkpointed",
|
||||||
|
tags=[],
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
stored = json.loads(Path(result.payload["json_path"]).read_text())
|
||||||
|
assert stored["checkpoint"] == {
|
||||||
|
"type": "checkpoint",
|
||||||
|
"modelId": 10,
|
||||||
|
"modelVersionId": 20,
|
||||||
|
"modelName": "Flux",
|
||||||
|
"modelVersionName": "Dev",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Reference in New Issue
Block a user