fix(recipes): save widget checkpoint metadata as dict

This commit is contained in:
Will Miao
2026-04-23 11:20:20 +08:00
parent ef7f677933
commit 658a04736d
5 changed files with 217 additions and 2 deletions

View File

@@ -780,8 +780,10 @@ NODE_EXTRACTORS = {
"GGUFLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
"DiffusionModelLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
"CheckpointLoaderKJ": CheckpointLoaderExtractor, # KJNodes
"CheckpointLoaderLM": CheckpointLoaderExtractor, # LoRA Manager
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
"UNETLoaderLM": UNETLoaderExtractor, # LoRA Manager
"LoraLoader": LoraLoaderExtractor,
"LoraLoaderLM": LoraLoaderManagerExtractor,
"RgthreePowerLoraLoader": RgthreePowerLoraLoaderExtractor,

View File

@@ -1815,6 +1815,15 @@ class RecipeScanner:
return await self._lora_scanner.get_model_info_by_name(name)
async def get_local_checkpoint(self, name: str) -> Optional[Dict[str, Any]]:
"""Lookup a local checkpoint model by name."""
checkpoint_scanner = getattr(self, "_checkpoint_scanner", None)
if not checkpoint_scanner or not name:
return None
return await checkpoint_scanner.get_model_info_by_name(name)
async def get_paginated_data(
self,
page: int,

View File

@@ -508,6 +508,10 @@ class RecipePersistenceService:
most_common_base_model = (
max(base_model_counts.items(), key=lambda item: item[1])[0] if base_model_counts else ""
)
checkpoint_entry = await self._build_widget_checkpoint_entry(
recipe_scanner,
metadata.get("checkpoint"),
)
recipe_data = {
"id": recipe_id,
@@ -515,9 +519,8 @@ class RecipePersistenceService:
"title": recipe_name,
"modified": time.time(),
"created_date": time.time(),
"base_model": most_common_base_model,
"base_model": most_common_base_model or (checkpoint_entry or {}).get("baseModel", ""),
"loras": loras_data,
"checkpoint": self._sanitize_checkpoint_entry(metadata.get("checkpoint", "")),
"gen_params": {
key: value
for key, value in metadata.items()
@@ -525,6 +528,8 @@ class RecipePersistenceService:
},
"loras_stack": lora_stack,
}
if checkpoint_entry:
recipe_data["checkpoint"] = checkpoint_entry
json_filename = f"{recipe_id}.recipe.json"
json_path = os.path.join(recipes_dir, json_filename)
@@ -546,6 +551,91 @@ class RecipePersistenceService:
# Helper methods ---------------------------------------------------
async def _build_widget_checkpoint_entry(
self,
recipe_scanner,
checkpoint_raw: Any,
) -> Optional[dict[str, Any]]:
"""Build recipe checkpoint metadata from widget generation metadata."""
if isinstance(checkpoint_raw, dict):
return self._sanitize_checkpoint_entry(checkpoint_raw)
if not isinstance(checkpoint_raw, str):
return None
checkpoint_name = checkpoint_raw.strip()
if not checkpoint_name:
return None
file_name = os.path.splitext(os.path.basename(checkpoint_name))[0]
checkpoint_info = await self._lookup_widget_checkpoint(
recipe_scanner,
checkpoint_name,
)
if not checkpoint_info:
return {
"type": "checkpoint",
"name": checkpoint_name,
"file_name": file_name,
"hash": "",
}
civitai = checkpoint_info.get("civitai") or {}
civitai_model = civitai.get("model") or {}
file_path = checkpoint_info.get("file_path") or checkpoint_info.get("path") or ""
cached_file_name = (
checkpoint_info.get("file_name")
or (os.path.splitext(os.path.basename(file_path))[0] if file_path else "")
or file_name
)
return {
"type": "checkpoint",
"modelId": civitai_model.get("id", 0),
"modelVersionId": civitai.get("id", 0),
"name": civitai_model.get("name") or checkpoint_info.get("model_name") or checkpoint_name,
"version": civitai.get("name", ""),
"hash": (checkpoint_info.get("sha256") or checkpoint_info.get("hash") or "").lower(),
"file_name": cached_file_name,
"modelName": civitai_model.get("name", ""),
"modelVersionName": civitai.get("name", ""),
"baseModel": checkpoint_info.get("base_model") or civitai.get("baseModel", ""),
}
async def _lookup_widget_checkpoint(
self,
recipe_scanner,
checkpoint_name: str,
) -> Optional[dict[str, Any]]:
lookup = getattr(recipe_scanner, "get_local_checkpoint", None)
if not callable(lookup):
return None
candidates = []
for candidate in (
checkpoint_name,
os.path.basename(checkpoint_name),
os.path.splitext(os.path.basename(checkpoint_name))[0],
):
if candidate and candidate not in candidates:
candidates.append(candidate)
for candidate in candidates:
try:
checkpoint_info = await lookup(candidate)
except Exception as exc:
self._logger.debug(
"Failed to lookup checkpoint %s while saving widget recipe: %s",
candidate,
exc,
)
continue
if checkpoint_info:
return checkpoint_info
return None
def _extract_checkpoint_entry(self, metadata: dict[str, Any]) -> Optional[dict[str, Any]]:
"""Pull a checkpoint entry from various metadata locations."""

View File

@@ -354,3 +354,33 @@ def test_lora_manager_cache_updates_when_loras_removed(metadata_registry):
metadata = metadata_registry.get_metadata("prompt3")
assert "lora_node" not in metadata[LORAS]
def test_lora_manager_checkpoint_and_unet_loaders_extract_models(metadata_registry):
metadata_registry.start_collection("prompt1")
metadata_registry.record_node_execution(
"checkpoint_node",
"CheckpointLoaderLM",
{"ckpt_name": ["models/checkpoint.safetensors"]},
None,
)
metadata_registry.record_node_execution(
"unet_node",
"UNETLoaderLM",
{"unet_name": ["models/diffusion_model.safetensors"], "weight_dtype": ["default"]},
None,
)
metadata = metadata_registry.get_metadata("prompt1")
assert metadata[MODELS]["checkpoint_node"] == {
"name": "models/checkpoint.safetensors",
"type": "checkpoint",
"node_id": "checkpoint_node",
}
assert metadata[MODELS]["unet_node"] == {
"name": "models/diffusion_model.safetensors",
"type": "checkpoint",
"node_id": "unet_node",
}

View File

@@ -491,6 +491,9 @@ async def test_save_recipe_from_widget_allows_empty_lora(tmp_path):
async def get_local_lora(self, name): # pragma: no cover - no lookups expected
return None
async def get_local_checkpoint(self, name):
return None
async def add_recipe(self, recipe_data):
self.added.append(recipe_data)
@@ -518,9 +521,90 @@ async def test_save_recipe_from_widget_allows_empty_lora(tmp_path):
assert stored["loras"] == []
assert stored["title"] == "recipe"
assert stored["checkpoint"] == {
"type": "checkpoint",
"name": "base-model.safetensors",
"file_name": "base-model",
"hash": "",
}
assert scanner.added and scanner.added[0]["loras"] == []
@pytest.mark.asyncio
async def test_save_recipe_from_widget_enriches_checkpoint_from_local_cache(tmp_path):
exif_utils = DummyExifUtils()
class DummyScanner:
def __init__(self, root):
self.recipes_dir = str(root)
self.added = []
self.checkpoint_queries = []
async def get_local_lora(self, name): # pragma: no cover - no loras
return None
async def get_local_checkpoint(self, name):
self.checkpoint_queries.append(name)
if name != "matched-model":
return None
return {
"file_name": "matched-model",
"file_path": "/models/checkpoints/folder/matched-model.safetensors",
"sha256": "ABC123",
"base_model": "Illustrious",
"civitai": {
"id": 456,
"name": "v1.0",
"baseModel": "Illustrious",
"model": {
"id": 123,
"name": "Matched Model",
},
},
}
async def add_recipe(self, recipe_data):
self.added.append(recipe_data)
scanner = DummyScanner(tmp_path)
service = RecipePersistenceService(
exif_utils=exif_utils,
card_preview_width=512,
logger=logging.getLogger("test"),
)
result = await service.save_recipe_from_widget(
recipe_scanner=scanner,
metadata={
"loras": "",
"checkpoint": "folder/matched-model.safetensors",
"prompt": "a calm scene",
},
image_bytes=b"image-bytes",
)
stored = json.loads(Path(result.payload["json_path"]).read_text())
assert scanner.checkpoint_queries == [
"folder/matched-model.safetensors",
"matched-model.safetensors",
"matched-model",
]
assert stored["base_model"] == "Illustrious"
assert stored["checkpoint"] == {
"type": "checkpoint",
"modelId": 123,
"modelVersionId": 456,
"name": "Matched Model",
"version": "v1.0",
"hash": "abc123",
"file_name": "matched-model",
"modelName": "Matched Model",
"modelVersionName": "v1.0",
"baseModel": "Illustrious",
}
@pytest.mark.asyncio
async def test_move_recipe_updates_paths(tmp_path):
exif_utils = DummyExifUtils()