diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py index b36ade5d..cee3ad0c 100644 --- a/py/routes/handlers/recipe_handlers.py +++ b/py/routes/handlers/recipe_handlers.py @@ -22,6 +22,7 @@ from ...services.recipes import ( RecipeSharingService, RecipeValidationError, ) +from ...services.metadata_service import get_default_metadata_provider Logger = logging.Logger EnsureDependenciesCallable = Callable[[], Awaitable[None]] @@ -451,7 +452,6 @@ class RecipeManagementHandler: raise RuntimeError("Recipe scanner unavailable") params = request.rel_url.query - print(params) image_url = params.get("image_url") name = params.get("name") resources_raw = params.get("resources") @@ -478,6 +478,9 @@ class RecipeManagementHandler: gen_params_ref = metadata.setdefault("gen_params", {}) if "checkpoint" not in gen_params_ref: gen_params_ref["checkpoint"] = checkpoint_entry + base_model_from_metadata = await self._resolve_base_model_from_checkpoint(checkpoint_entry) + if base_model_from_metadata: + metadata["base_model"] = base_model_from_metadata tags = self._parse_tags(params.get("tags")) image_bytes = await self._download_image_bytes(image_url) @@ -769,6 +772,29 @@ class RecipeManagementHandler: except (TypeError, ValueError): return 0 + async def _resolve_base_model_from_checkpoint(self, checkpoint_entry: Dict[str, Any]) -> str: + version_id = self._safe_int(checkpoint_entry.get("modelVersionId")) + + if not version_id: + return "" + + try: + provider = await get_default_metadata_provider() + if not provider: + return "" + + version_info = await provider.get_model_version_info(version_id) + if isinstance(version_info, tuple): + version_info = version_info[0] + + if isinstance(version_info, dict): + base_model = version_info.get("baseModel") or "" + return str(base_model) if base_model is not None else "" + except Exception as exc: # pragma: no cover - defensive logging + self._logger.warning("Failed to resolve base model from checkpoint metadata: %s", exc) + + return "" + class RecipeAnalysisHandler: """Analyze images to extract recipe metadata.""" diff --git a/tests/routes/test_recipe_routes.py b/tests/routes/test_recipe_routes.py index 93e3578d..93d66300 100644 --- a/tests/routes/test_recipe_routes.py +++ b/tests/routes/test_recipe_routes.py @@ -13,6 +13,7 @@ from aiohttp.test_utils import TestClient, TestServer from py.config import config from py.routes import base_recipe_routes +from py.routes.handlers import recipe_handlers from py.routes.recipe_routes import RecipeRoutes from py.services.recipes import RecipeValidationError from py.services.service_registry import ServiceRegistry @@ -296,6 +297,18 @@ async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None: + provider_calls: list[int] = [] + + class Provider: + async def get_model_version_info(self, model_version_id): + provider_calls.append(model_version_id) + return {"baseModel": "Flux Provider"}, None + + async def fake_get_default_metadata_provider(): + return Provider() + + monkeypatch.setattr(recipe_handlers, "get_default_metadata_provider", fake_get_default_metadata_provider) + async with recipe_harness(monkeypatch, tmp_path) as harness: resources = [ { @@ -335,7 +348,8 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None: assert call["name"] == "Remote Recipe" assert call["tags"] == ["foo", "bar"] metadata = call["metadata"] - assert metadata["base_model"] == "Flux" + assert metadata["base_model"] == "Flux Provider" + assert provider_calls == [33] assert metadata["checkpoint"]["modelVersionId"] == 33 assert metadata["loras"][0]["weight"] == 0.25 assert metadata["gen_params"]["prompt"] == "hello world" @@ -343,6 +357,49 @@ async def test_import_remote_recipe(monkeypatch, tmp_path: Path) -> None: assert harness.downloader.urls == ["https://example.com/images/1"] +async def test_import_remote_recipe_falls_back_to_request_base_model(monkeypatch, tmp_path: Path) -> None: + provider_calls: list[int] = [] + + class Provider: + async def get_model_version_info(self, model_version_id): + provider_calls.append(model_version_id) + return {}, None + + async def fake_get_default_metadata_provider(): + return Provider() + + monkeypatch.setattr(recipe_handlers, "get_default_metadata_provider", fake_get_default_metadata_provider) + + async with recipe_harness(monkeypatch, tmp_path) as harness: + resources = [ + { + "type": "checkpoint", + "modelId": 11, + "modelVersionId": 77, + "modelName": "Flux", + "modelVersionName": "Dev", + }, + ] + response = await harness.client.get( + "/api/lm/recipes/import-remote", + params={ + "image_url": "https://example.com/images/1", + "name": "Remote Recipe", + "resources": json.dumps(resources), + "tags": "foo,bar", + "base_model": "Flux", + }, + ) + + payload = await response.json() + assert response.status == 200 + assert payload["success"] is True + + metadata = harness.persistence.save_calls[-1]["metadata"] + assert metadata["base_model"] == "Flux" + assert provider_calls == [77] + + async def test_analyze_uploaded_image_error_path(monkeypatch, tmp_path: Path) -> None: async with recipe_harness(monkeypatch, tmp_path) as harness: harness.analysis.raise_for_uploaded = RecipeValidationError("No image data provided")