From 7b139b9b1da2a3470b1af2f6e887d94e7382f088 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Wed, 24 Dec 2025 20:35:06 +0800 Subject: [PATCH] refactor: Move `base_model` resolution to occur before checkpoint formatting and remove a `gen_params` checkpoint assertion. --- py/recipes/enrichment.py | 22 ++++++++++------------ tests/routes/test_recipe_routes.py | 1 - 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/py/recipes/enrichment.py b/py/recipes/enrichment.py index 83274164..34acdfcd 100644 --- a/py/recipes/enrichment.py +++ b/py/recipes/enrichment.py @@ -128,17 +128,7 @@ class RecipeEnricher: else: # Checkpoint exists, no need to sync to gen_params anymore. pass - # If base_model is empty or very generic, try to use what we found in checkpoint - current_base_model = recipe.get("base_model") - checkpoint_after = recipe.get("checkpoint") - if checkpoint_after and checkpoint_after.get("baseModel"): - resolved_base_model = checkpoint_after["baseModel"] - # Update if empty OR if it matches our generic prefix but is less specific - is_generic = not current_base_model or current_base_model.lower() in ["flux", "sdxl", "sd15"] - if is_generic and resolved_base_model != current_base_model: - recipe["base_model"] = resolved_base_model - updated = True - + # base_model resolution moved to _resolve_and_populate_checkpoint to support strict formatting return updated @staticmethod @@ -190,8 +180,16 @@ class RecipeEnricher: if existing_cp is None: existing_cp = {} checkpoint_data = await RecipeMetadataParser.populate_checkpoint_from_civitai(existing_cp, civitai_info) + # 1. First, resolve base_model using full data before we format it away + current_base_model = recipe.get("base_model") + resolved_base_model = checkpoint_data.get("baseModel") + if resolved_base_model: + # Update if empty OR if it matches our generic prefix but is less specific + is_generic = not current_base_model or current_base_model.lower() in ["flux", "sdxl", "sd15"] + if is_generic and resolved_base_model != current_base_model: + recipe["base_model"] = resolved_base_model - # Format according to requirements: type, modelId, modelVersionId, modelName, modelVersionName + # 2. Format according to requirements: type, modelId, modelVersionId, modelName, modelVersionName formatted_checkpoint = { "type": "checkpoint", "modelId": checkpoint_data.get("modelId"), diff --git a/tests/routes/test_recipe_routes.py b/tests/routes/test_recipe_routes.py index 9b8d6050..f1292d55 100644 --- a/tests/routes/test_recipe_routes.py +++ b/tests/routes/test_recipe_routes.py @@ -394,7 +394,6 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None: assert metadata["checkpoint"]["modelVersionId"] == 33 assert metadata["loras"][0]["weight"] == 0.25 assert metadata["gen_params"]["prompt"] == "hello world" - assert metadata["gen_params"]["checkpoint"]["modelVersionId"] == 33 assert harness.downloader.urls == ["https://example.com/images/1"]