From 2452cc4df1c7339657a0c38852d26d6bee808149 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Fri, 21 Nov 2025 12:12:27 +0800 Subject: [PATCH] feat(recipes): resolve base model from checkpoint metadata Add metadata service integration to automatically resolve base model information from checkpoint metadata during recipe import. This replaces the previous approach of relying solely on request parameters and provides more accurate base model information. - Add _resolve_base_model_from_checkpoint method to fetch base model from metadata provider - Update recipe import logic to use resolved base model when available - Add comprehensive tests for base model resolution with fallback behavior - Remove debug print statement from import parameters --- py/routes/handlers/recipe_handlers.py | 28 ++++++++++++- tests/routes/test_recipe_routes.py | 59 ++++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 2 deletions(-) diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py index b36ade5d..cee3ad0c 100644 --- a/py/routes/handlers/recipe_handlers.py +++ b/py/routes/handlers/recipe_handlers.py @@ -22,6 +22,7 @@ from ...services.recipes import ( RecipeSharingService, RecipeValidationError, ) +from ...services.metadata_service import get_default_metadata_provider Logger = logging.Logger EnsureDependenciesCallable = Callable[[], Awaitable[None]] @@ -451,7 +452,6 @@ class RecipeManagementHandler: raise RuntimeError("Recipe scanner unavailable") params = request.rel_url.query - print(params) image_url = params.get("image_url") name = params.get("name") resources_raw = params.get("resources") @@ -478,6 +478,9 @@ class RecipeManagementHandler: gen_params_ref = metadata.setdefault("gen_params", {}) if "checkpoint" not in gen_params_ref: gen_params_ref["checkpoint"] = checkpoint_entry + base_model_from_metadata = await self._resolve_base_model_from_checkpoint(checkpoint_entry) + if base_model_from_metadata: + metadata["base_model"] = base_model_from_metadata tags = self._parse_tags(params.get("tags")) image_bytes = await self._download_image_bytes(image_url) @@ -769,6 +772,29 @@ class RecipeManagementHandler: except (TypeError, ValueError): return 0 + async def _resolve_base_model_from_checkpoint(self, checkpoint_entry: Dict[str, Any]) -> str: + version_id = self._safe_int(checkpoint_entry.get("modelVersionId")) + + if not version_id: + return "" + + try: + provider = await get_default_metadata_provider() + if not provider: + return "" + + version_info = await provider.get_model_version_info(version_id) + if isinstance(version_info, tuple): + version_info = version_info[0] + + if isinstance(version_info, dict): + base_model = version_info.get("baseModel") or "" + return str(base_model) if base_model is not None else "" + except Exception as exc: # pragma: no cover - defensive logging + self._logger.warning("Failed to resolve base model from checkpoint metadata: %s", exc) + + return "" + class RecipeAnalysisHandler: """Analyze images to extract recipe metadata.""" diff --git a/tests/routes/test_recipe_routes.py b/tests/routes/test_recipe_routes.py index 93e3578d..93d66300 100644 --- a/tests/routes/test_recipe_routes.py +++ b/tests/routes/test_recipe_routes.py @@ -13,6 +13,7 @@ from aiohttp.test_utils import TestClient, TestServer from py.config import config from py.routes import base_recipe_routes +from py.routes.handlers import recipe_handlers from py.routes.recipe_routes import RecipeRoutes from py.services.recipes import RecipeValidationError from py.services.service_registry import ServiceRegistry @@ -296,6 +297,18 @@ async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None: + provider_calls: list[int] = [] + + class Provider: + async def get_model_version_info(self, model_version_id): + provider_calls.append(model_version_id) + return {"baseModel": "Flux Provider"}, None + + async def fake_get_default_metadata_provider(): + return Provider() + + monkeypatch.setattr(recipe_handlers, "get_default_metadata_provider", fake_get_default_metadata_provider) + async with recipe_harness(monkeypatch, tmp_path) as harness: resources = [ { @@ -335,7 +348,8 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None: assert call["name"] == "Remote Recipe" assert call["tags"] == ["foo", "bar"] metadata = call["metadata"] - assert metadata["base_model"] == "Flux" + assert metadata["base_model"] == "Flux Provider" + assert provider_calls == [33] assert metadata["checkpoint"]["modelVersionId"] == 33 assert metadata["loras"][0]["weight"] == 0.25 assert metadata["gen_params"]["prompt"] == "hello world" @@ -343,6 +357,49 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None: assert harness.downloader.urls == ["https://example.com/images/1"] +async def test_import_remote_recipe_falls_back_to_request_base_model(monkeypatch, tmp_path: Path) -> None: + provider_calls: list[int] = [] + + class Provider: + async def get_model_version_info(self, model_version_id): + provider_calls.append(model_version_id) + return {}, None + + async def fake_get_default_metadata_provider(): + return Provider() + + monkeypatch.setattr(recipe_handlers, "get_default_metadata_provider", fake_get_default_metadata_provider) + + async with recipe_harness(monkeypatch, tmp_path) as harness: + resources = [ + { + "type": "checkpoint", + "modelId": 11, + "modelVersionId": 77, + "modelName": "Flux", + "modelVersionName": "Dev", + }, + ] + response = await harness.client.get( + "/api/lm/recipes/import-remote", + params={ + "image_url": "https://example.com/images/1", + "name": "Remote Recipe", + "resources": json.dumps(resources), + "tags": "foo,bar", + "base_model": "Flux", + }, + ) + + payload = await response.json() + assert response.status == 200 + assert payload["success"] is True + + metadata = harness.persistence.save_calls[-1]["metadata"] + assert metadata["base_model"] == "Flux" + assert provider_calls == [77] + + async def test_analyze_uploaded_image_error_path(monkeypatch, tmp_path: Path) -> None: async with recipe_harness(monkeypatch, tmp_path) as harness: harness.analysis.raise_for_uploaded = RecipeValidationError("No image data provided")