fix(recipes): save widget checkpoint metadata as dict

This commit is contained in:
Will Miao
2026-04-23 11:20:20 +08:00
parent ef7f677933
commit 658a04736d
5 changed files with 217 additions and 2 deletions

View File

@@ -780,8 +780,10 @@ NODE_EXTRACTORS = {
"GGUFLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
"DiffusionModelLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes
"CheckpointLoaderKJ": CheckpointLoaderExtractor, # KJNodes
"CheckpointLoaderLM": CheckpointLoaderExtractor, # LoRA Manager
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
"UNETLoaderLM": UNETLoaderExtractor, # LoRA Manager
"LoraLoader": LoraLoaderExtractor,
"LoraLoaderLM": LoraLoaderManagerExtractor,
"RgthreePowerLoraLoader": RgthreePowerLoraLoaderExtractor,

View File

@@ -1815,6 +1815,15 @@ class RecipeScanner:
return await self._lora_scanner.get_model_info_by_name(name)
async def get_local_checkpoint(self, name: str) -> Optional[Dict[str, Any]]:
"""Lookup a local checkpoint model by name."""
checkpoint_scanner = getattr(self, "_checkpoint_scanner", None)
if not checkpoint_scanner or not name:
return None
return await checkpoint_scanner.get_model_info_by_name(name)
async def get_paginated_data(
self,
page: int,

View File

@@ -508,6 +508,10 @@ class RecipePersistenceService:
most_common_base_model = (
max(base_model_counts.items(), key=lambda item: item[1])[0] if base_model_counts else ""
)
checkpoint_entry = await self._build_widget_checkpoint_entry(
recipe_scanner,
metadata.get("checkpoint"),
)
recipe_data = {
"id": recipe_id,
@@ -515,9 +519,8 @@ class RecipePersistenceService:
"title": recipe_name,
"modified": time.time(),
"created_date": time.time(),
"base_model": most_common_base_model,
"base_model": most_common_base_model or (checkpoint_entry or {}).get("baseModel", ""),
"loras": loras_data,
"checkpoint": self._sanitize_checkpoint_entry(metadata.get("checkpoint", "")),
"gen_params": {
key: value
for key, value in metadata.items()
@@ -525,6 +528,8 @@ class RecipePersistenceService:
},
"loras_stack": lora_stack,
}
if checkpoint_entry:
recipe_data["checkpoint"] = checkpoint_entry
json_filename = f"{recipe_id}.recipe.json"
json_path = os.path.join(recipes_dir, json_filename)
@@ -546,6 +551,91 @@ class RecipePersistenceService:
# Helper methods ---------------------------------------------------
async def _build_widget_checkpoint_entry(
self,
recipe_scanner,
checkpoint_raw: Any,
) -> Optional[dict[str, Any]]:
"""Build recipe checkpoint metadata from widget generation metadata."""
if isinstance(checkpoint_raw, dict):
return self._sanitize_checkpoint_entry(checkpoint_raw)
if not isinstance(checkpoint_raw, str):
return None
checkpoint_name = checkpoint_raw.strip()
if not checkpoint_name:
return None
file_name = os.path.splitext(os.path.basename(checkpoint_name))[0]
checkpoint_info = await self._lookup_widget_checkpoint(
recipe_scanner,
checkpoint_name,
)
if not checkpoint_info:
return {
"type": "checkpoint",
"name": checkpoint_name,
"file_name": file_name,
"hash": "",
}
civitai = checkpoint_info.get("civitai") or {}
civitai_model = civitai.get("model") or {}
file_path = checkpoint_info.get("file_path") or checkpoint_info.get("path") or ""
cached_file_name = (
checkpoint_info.get("file_name")
or (os.path.splitext(os.path.basename(file_path))[0] if file_path else "")
or file_name
)
return {
"type": "checkpoint",
"modelId": civitai_model.get("id", 0),
"modelVersionId": civitai.get("id", 0),
"name": civitai_model.get("name") or checkpoint_info.get("model_name") or checkpoint_name,
"version": civitai.get("name", ""),
"hash": (checkpoint_info.get("sha256") or checkpoint_info.get("hash") or "").lower(),
"file_name": cached_file_name,
"modelName": civitai_model.get("name", ""),
"modelVersionName": civitai.get("name", ""),
"baseModel": checkpoint_info.get("base_model") or civitai.get("baseModel", ""),
}
async def _lookup_widget_checkpoint(
self,
recipe_scanner,
checkpoint_name: str,
) -> Optional[dict[str, Any]]:
lookup = getattr(recipe_scanner, "get_local_checkpoint", None)
if not callable(lookup):
return None
candidates = []
for candidate in (
checkpoint_name,
os.path.basename(checkpoint_name),
os.path.splitext(os.path.basename(checkpoint_name))[0],
):
if candidate and candidate not in candidates:
candidates.append(candidate)
for candidate in candidates:
try:
checkpoint_info = await lookup(candidate)
except Exception as exc:
self._logger.debug(
"Failed to lookup checkpoint %s while saving widget recipe: %s",
candidate,
exc,
)
continue
if checkpoint_info:
return checkpoint_info
return None
def _extract_checkpoint_entry(self, metadata: dict[str, Any]) -> Optional[dict[str, Any]]:
"""Pull a checkpoint entry from various metadata locations."""