From 658a04736d5a5b9800125be41748f9835b88088b Mon Sep 17 00:00:00 2001 From: Will Miao Date: Thu, 23 Apr 2026 11:20:20 +0800 Subject: [PATCH] fix(recipes): save widget checkpoint metadata as dict --- py/metadata_collector/node_extractors.py | 2 + py/services/recipe_scanner.py | 9 ++ py/services/recipes/persistence_service.py | 94 ++++++++++++++++++- .../test_metadata_collector.py | 30 ++++++ tests/services/test_recipe_services.py | 84 +++++++++++++++++ 5 files changed, 217 insertions(+), 2 deletions(-) diff --git a/py/metadata_collector/node_extractors.py b/py/metadata_collector/node_extractors.py index 1dda702f..7d6096a7 100644 --- a/py/metadata_collector/node_extractors.py +++ b/py/metadata_collector/node_extractors.py @@ -780,8 +780,10 @@ NODE_EXTRACTORS = { "GGUFLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes "DiffusionModelLoaderKJ": KJNodesModelLoaderExtractor, # KJNodes "CheckpointLoaderKJ": CheckpointLoaderExtractor, # KJNodes + "CheckpointLoaderLM": CheckpointLoaderExtractor, # LoRA Manager "UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor "UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor + "UNETLoaderLM": UNETLoaderExtractor, # LoRA Manager "LoraLoader": LoraLoaderExtractor, "LoraLoaderLM": LoraLoaderManagerExtractor, "RgthreePowerLoraLoader": RgthreePowerLoraLoaderExtractor, diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index 240fa752..3797e5d9 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -1815,6 +1815,15 @@ class RecipeScanner: return await self._lora_scanner.get_model_info_by_name(name) + async def get_local_checkpoint(self, name: str) -> Optional[Dict[str, Any]]: + """Lookup a local checkpoint model by name.""" + + checkpoint_scanner = getattr(self, "_checkpoint_scanner", None) + if not checkpoint_scanner or not name: + return None + + return await checkpoint_scanner.get_model_info_by_name(name) + async def get_paginated_data( self, page: int, diff --git a/py/services/recipes/persistence_service.py b/py/services/recipes/persistence_service.py index 3c5a7c00..fdd06fd4 100644 --- a/py/services/recipes/persistence_service.py +++ b/py/services/recipes/persistence_service.py @@ -508,6 +508,10 @@ class RecipePersistenceService: most_common_base_model = ( max(base_model_counts.items(), key=lambda item: item[1])[0] if base_model_counts else "" ) + checkpoint_entry = await self._build_widget_checkpoint_entry( + recipe_scanner, + metadata.get("checkpoint"), + ) recipe_data = { "id": recipe_id, @@ -515,9 +519,8 @@ class RecipePersistenceService: "title": recipe_name, "modified": time.time(), "created_date": time.time(), - "base_model": most_common_base_model, + "base_model": most_common_base_model or (checkpoint_entry or {}).get("baseModel", ""), "loras": loras_data, - "checkpoint": self._sanitize_checkpoint_entry(metadata.get("checkpoint", "")), "gen_params": { key: value for key, value in metadata.items() @@ -525,6 +528,8 @@ class RecipePersistenceService: }, "loras_stack": lora_stack, } + if checkpoint_entry: + recipe_data["checkpoint"] = checkpoint_entry json_filename = f"{recipe_id}.recipe.json" json_path = os.path.join(recipes_dir, json_filename) @@ -546,6 +551,91 @@ class RecipePersistenceService: # Helper methods --------------------------------------------------- + async def _build_widget_checkpoint_entry( + self, + recipe_scanner, + checkpoint_raw: Any, + ) -> Optional[dict[str, Any]]: + """Build recipe checkpoint metadata from widget generation metadata.""" + + if isinstance(checkpoint_raw, dict): + return self._sanitize_checkpoint_entry(checkpoint_raw) + + if not isinstance(checkpoint_raw, str): + return None + + checkpoint_name = checkpoint_raw.strip() + if not checkpoint_name: + return None + + file_name = os.path.splitext(os.path.basename(checkpoint_name))[0] + checkpoint_info = await self._lookup_widget_checkpoint( + recipe_scanner, + checkpoint_name, + ) + if not checkpoint_info: + return { + "type": "checkpoint", + "name": checkpoint_name, + "file_name": file_name, + "hash": "", + } + + civitai = checkpoint_info.get("civitai") or {} + civitai_model = civitai.get("model") or {} + file_path = checkpoint_info.get("file_path") or checkpoint_info.get("path") or "" + cached_file_name = ( + checkpoint_info.get("file_name") + or (os.path.splitext(os.path.basename(file_path))[0] if file_path else "") + or file_name + ) + + return { + "type": "checkpoint", + "modelId": civitai_model.get("id", 0), + "modelVersionId": civitai.get("id", 0), + "name": civitai_model.get("name") or checkpoint_info.get("model_name") or checkpoint_name, + "version": civitai.get("name", ""), + "hash": (checkpoint_info.get("sha256") or checkpoint_info.get("hash") or "").lower(), + "file_name": cached_file_name, + "modelName": civitai_model.get("name", ""), + "modelVersionName": civitai.get("name", ""), + "baseModel": checkpoint_info.get("base_model") or civitai.get("baseModel", ""), + } + + async def _lookup_widget_checkpoint( + self, + recipe_scanner, + checkpoint_name: str, + ) -> Optional[dict[str, Any]]: + lookup = getattr(recipe_scanner, "get_local_checkpoint", None) + if not callable(lookup): + return None + + candidates = [] + for candidate in ( + checkpoint_name, + os.path.basename(checkpoint_name), + os.path.splitext(os.path.basename(checkpoint_name))[0], + ): + if candidate and candidate not in candidates: + candidates.append(candidate) + + for candidate in candidates: + try: + checkpoint_info = await lookup(candidate) + except Exception as exc: + self._logger.debug( + "Failed to lookup checkpoint %s while saving widget recipe: %s", + candidate, + exc, + ) + continue + if checkpoint_info: + return checkpoint_info + + return None + def _extract_checkpoint_entry(self, metadata: dict[str, Any]) -> Optional[dict[str, Any]]: """Pull a checkpoint entry from various metadata locations.""" diff --git a/tests/metadata_collector/test_metadata_collector.py b/tests/metadata_collector/test_metadata_collector.py index e84d6fd7..dcb6b770 100644 --- a/tests/metadata_collector/test_metadata_collector.py +++ b/tests/metadata_collector/test_metadata_collector.py @@ -354,3 +354,33 @@ def test_lora_manager_cache_updates_when_loras_removed(metadata_registry): metadata = metadata_registry.get_metadata("prompt3") assert "lora_node" not in metadata[LORAS] + + +def test_lora_manager_checkpoint_and_unet_loaders_extract_models(metadata_registry): + metadata_registry.start_collection("prompt1") + + metadata_registry.record_node_execution( + "checkpoint_node", + "CheckpointLoaderLM", + {"ckpt_name": ["models/checkpoint.safetensors"]}, + None, + ) + metadata_registry.record_node_execution( + "unet_node", + "UNETLoaderLM", + {"unet_name": ["models/diffusion_model.safetensors"], "weight_dtype": ["default"]}, + None, + ) + + metadata = metadata_registry.get_metadata("prompt1") + + assert metadata[MODELS]["checkpoint_node"] == { + "name": "models/checkpoint.safetensors", + "type": "checkpoint", + "node_id": "checkpoint_node", + } + assert metadata[MODELS]["unet_node"] == { + "name": "models/diffusion_model.safetensors", + "type": "checkpoint", + "node_id": "unet_node", + } diff --git a/tests/services/test_recipe_services.py b/tests/services/test_recipe_services.py index 255441f7..417183dc 100644 --- a/tests/services/test_recipe_services.py +++ b/tests/services/test_recipe_services.py @@ -491,6 +491,9 @@ async def test_save_recipe_from_widget_allows_empty_lora(tmp_path): async def get_local_lora(self, name): # pragma: no cover - no lookups expected return None + async def get_local_checkpoint(self, name): + return None + async def add_recipe(self, recipe_data): self.added.append(recipe_data) @@ -518,9 +521,90 @@ async def test_save_recipe_from_widget_allows_empty_lora(tmp_path): assert stored["loras"] == [] assert stored["title"] == "recipe" + assert stored["checkpoint"] == { + "type": "checkpoint", + "name": "base-model.safetensors", + "file_name": "base-model", + "hash": "", + } assert scanner.added and scanner.added[0]["loras"] == [] +@pytest.mark.asyncio +async def test_save_recipe_from_widget_enriches_checkpoint_from_local_cache(tmp_path): + exif_utils = DummyExifUtils() + + class DummyScanner: + def __init__(self, root): + self.recipes_dir = str(root) + self.added = [] + self.checkpoint_queries = [] + + async def get_local_lora(self, name): # pragma: no cover - no loras + return None + + async def get_local_checkpoint(self, name): + self.checkpoint_queries.append(name) + if name != "matched-model": + return None + return { + "file_name": "matched-model", + "file_path": "/models/checkpoints/folder/matched-model.safetensors", + "sha256": "ABC123", + "base_model": "Illustrious", + "civitai": { + "id": 456, + "name": "v1.0", + "baseModel": "Illustrious", + "model": { + "id": 123, + "name": "Matched Model", + }, + }, + } + + async def add_recipe(self, recipe_data): + self.added.append(recipe_data) + + scanner = DummyScanner(tmp_path) + service = RecipePersistenceService( + exif_utils=exif_utils, + card_preview_width=512, + logger=logging.getLogger("test"), + ) + + result = await service.save_recipe_from_widget( + recipe_scanner=scanner, + metadata={ + "loras": "", + "checkpoint": "folder/matched-model.safetensors", + "prompt": "a calm scene", + }, + image_bytes=b"image-bytes", + ) + + stored = json.loads(Path(result.payload["json_path"]).read_text()) + + assert scanner.checkpoint_queries == [ + "folder/matched-model.safetensors", + "matched-model.safetensors", + "matched-model", + ] + assert stored["base_model"] == "Illustrious" + assert stored["checkpoint"] == { + "type": "checkpoint", + "modelId": 123, + "modelVersionId": 456, + "name": "Matched Model", + "version": "v1.0", + "hash": "abc123", + "file_name": "matched-model", + "modelName": "Matched Model", + "modelVersionName": "v1.0", + "baseModel": "Illustrious", + } + + @pytest.mark.asyncio async def test_move_recipe_updates_paths(tmp_path): exif_utils = DummyExifUtils()