refactor: Move base_model resolution to occur before checkpoint formatting and remove a gen_params checkpoint assertion.

This commit is contained in:
Will Miao
2025-12-24 20:35:06 +08:00
parent a552f07448
commit 7b139b9b1d
2 changed files with 10 additions and 13 deletions

View File

@@ -128,17 +128,7 @@ class RecipeEnricher:
else:
# Checkpoint exists, no need to sync to gen_params anymore.
pass
# If base_model is empty or very generic, try to use what we found in checkpoint
current_base_model = recipe.get("base_model")
checkpoint_after = recipe.get("checkpoint")
if checkpoint_after and checkpoint_after.get("baseModel"):
resolved_base_model = checkpoint_after["baseModel"]
# Update if empty OR if it matches our generic prefix but is less specific
is_generic = not current_base_model or current_base_model.lower() in ["flux", "sdxl", "sd15"]
if is_generic and resolved_base_model != current_base_model:
recipe["base_model"] = resolved_base_model
updated = True
# base_model resolution moved to _resolve_and_populate_checkpoint to support strict formatting
return updated
@staticmethod
@@ -190,8 +180,16 @@ class RecipeEnricher:
if existing_cp is None:
existing_cp = {}
checkpoint_data = await RecipeMetadataParser.populate_checkpoint_from_civitai(existing_cp, civitai_info)
# 1. First, resolve base_model using full data before we format it away
current_base_model = recipe.get("base_model")
resolved_base_model = checkpoint_data.get("baseModel")
if resolved_base_model:
# Update if empty OR if it matches our generic prefix but is less specific
is_generic = not current_base_model or current_base_model.lower() in ["flux", "sdxl", "sd15"]
if is_generic and resolved_base_model != current_base_model:
recipe["base_model"] = resolved_base_model
# Format according to requirements: type, modelId, modelVersionId, modelName, modelVersionName
# 2. Format according to requirements: type, modelId, modelVersionId, modelName, modelVersionName
formatted_checkpoint = {
"type": "checkpoint",
"modelId": checkpoint_data.get("modelId"),

View File

@@ -394,7 +394,6 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None:
assert metadata["checkpoint"]["modelVersionId"] == 33
assert metadata["loras"][0]["weight"] == 0.25
assert metadata["gen_params"]["prompt"] == "hello world"
assert metadata["gen_params"]["checkpoint"]["modelVersionId"] == 33
assert harness.downloader.urls == ["https://example.com/images/1"]