From 9344d86332ed0c42ceca23dfe818e4646f33dd68 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Fri, 3 Apr 2026 22:16:09 +0800 Subject: [PATCH] test(misc): cover model existence download status --- .../__snapshots__/test_api_snapshots.ambr | 2 + tests/routes/test_misc_routes.py | 48 ++++++++++++++++++- 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/tests/routes/__snapshots__/test_api_snapshots.ambr b/tests/routes/__snapshots__/test_api_snapshots.ambr index 83a386bf..51be5c49 100644 --- a/tests/routes/__snapshots__/test_api_snapshots.ambr +++ b/tests/routes/__snapshots__/test_api_snapshots.ambr @@ -1,6 +1,8 @@ # serializer version: 1 # name: TestModelLibraryHandlerSnapshots.test_check_model_exists_empty_response dict({ + 'downloadedVersionIds': list([ + ]), 'modelType': None, 'success': True, 'versions': list([ diff --git a/tests/routes/test_misc_routes.py b/tests/routes/test_misc_routes.py index 67828a33..51753f66 100644 --- a/tests/routes/test_misc_routes.py +++ b/tests/routes/test_misc_routes.py @@ -23,9 +23,10 @@ from py.routes.misc_routes import MiscRoutes class FakeRequest: - def __init__(self, *, json_data=None, query=None): + def __init__(self, *, json_data=None, query=None, method="POST"): self._json_data = json_data or {} self.query = query or {} + self.method = method async def json(self): return self._json_data @@ -869,6 +870,32 @@ async def test_check_model_exists_returns_local_versions(): assert lora_scanner.version_calls == [5] +@pytest.mark.asyncio +async def test_check_model_exists_model_id_only_does_not_call_metadata_provider(): + async def metadata_provider_factory(): + raise AssertionError("metadata provider should not be called for modelId-only checks") + + handler = ModelLibraryHandler( + ServiceRegistryAdapter( + get_lora_scanner=fake_scanner_factory, + get_checkpoint_scanner=fake_scanner_factory, + get_embedding_scanner=fake_scanner_factory, + get_downloaded_version_history_service=fake_download_history_service_factory, + ), + metadata_provider_factory=metadata_provider_factory, + ) + + response = await handler.check_model_exists(FakeRequest(query={"modelId": "5"})) + payload = json.loads(response.text) + + assert payload == { + "success": True, + "modelType": None, + "versions": [], + "downloadedVersionIds": [], + } + + @pytest.mark.asyncio async def test_check_model_exists_returns_download_history_when_file_missing(): history_service = FakeDownloadHistoryService({"checkpoint": {999}}) @@ -949,6 +976,25 @@ async def test_model_version_download_status_endpoints(): ("checkpoint", 456, 78, "manual", "/tmp/model.safetensors") ] + set_get_response = await handler.set_model_version_download_status( + FakeRequest( + method="GET", + query={ + "modelType": "embedding", + "modelVersionId": "789", + "modelId": "12", + "downloaded": "false", + }, + ) + ) + set_get_payload = json.loads(set_get_response.text) + assert set_get_payload == { + "success": True, + "modelType": "embedding", + "modelVersionId": 789, + "hasBeenDownloaded": False, + } + def test_create_handler_set_uses_provided_dependencies(): recorded_handlers: list[dict] = []