diff --git a/py/services/checkpoint_service.py b/py/services/checkpoint_service.py index 55cb13cd..ef1d763f 100644 --- a/py/services/checkpoint_service.py +++ b/py/services/checkpoint_service.py @@ -38,6 +38,7 @@ class CheckpointService(BaseModelService): "notes": checkpoint_data.get("notes", ""), "model_type": checkpoint_data.get("model_type", "checkpoint"), "favorite": checkpoint_data.get("favorite", False), + "update_available": bool(checkpoint_data.get("update_available", False)), "civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True) } diff --git a/py/services/embedding_service.py b/py/services/embedding_service.py index 30a742df..3275552b 100644 --- a/py/services/embedding_service.py +++ b/py/services/embedding_service.py @@ -38,6 +38,7 @@ class EmbeddingService(BaseModelService): "notes": embedding_data.get("notes", ""), "model_type": embedding_data.get("model_type", "embedding"), "favorite": embedding_data.get("favorite", False), + "update_available": bool(embedding_data.get("update_available", False)), "civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True) } diff --git a/py/services/lora_service.py b/py/services/lora_service.py index 5a0e27f6..2de2cf96 100644 --- a/py/services/lora_service.py +++ b/py/services/lora_service.py @@ -38,6 +38,7 @@ class LoraService(BaseModelService): "usage_tips": lora_data.get("usage_tips", ""), "notes": lora_data.get("notes", ""), "favorite": lora_data.get("favorite", False), + "update_available": bool(lora_data.get("update_available", False)), "civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True) } diff --git a/tests/services/test_base_model_service.py b/tests/services/test_base_model_service.py index 0595963a..ec5bd333 100644 --- a/tests/services/test_base_model_service.py +++ b/tests/services/test_base_model_service.py @@ -1,6 +1,9 @@ import pytest from py.services.base_model_service import BaseModelService +from py.services.lora_service import LoraService +from py.services.checkpoint_service import CheckpointService +from py.services.embedding_service import EmbeddingService from py.services.model_query import ( ModelCacheRepository, ModelFilterSet, @@ -455,3 +458,54 @@ async def test_get_paginated_data_annotates_update_flags_with_bulk_dedup(): assert [item["update_available"] for item in response["items"]] == [True, True, False] assert response["total"] == 3 assert response["total_pages"] == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "service_cls, extra_fields", + [ + (LoraService, {"usage_tips": "tips"}), + (CheckpointService, {"model_type": "checkpoint"}), + (EmbeddingService, {"model_type": "embedding"}), + ], +) +async def test_format_response_includes_update_flag(service_cls, extra_fields): + service = service_cls(scanner=object()) + payload = { + "model_name": "Demo", + "file_name": "demo.safetensors", + "folder": "root", + "file_path": "root/demo.safetensors", + **extra_fields, + } + payload["update_available"] = True + + formatted = await service.format_response(payload) + + assert "update_available" in formatted + assert formatted["update_available"] is True + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "service_cls, extra_fields", + [ + (LoraService, {"usage_tips": "tips"}), + (CheckpointService, {"model_type": "checkpoint"}), + (EmbeddingService, {"model_type": "embedding"}), + ], +) +async def test_format_response_defaults_update_flag_false(service_cls, extra_fields): + service = service_cls(scanner=object()) + payload = { + "model_name": "Demo", + "file_name": "demo.safetensors", + "folder": "root", + "file_path": "root/demo.safetensors", + **extra_fields, + } + + formatted = await service.format_response(payload) + + assert "update_available" in formatted + assert formatted["update_available"] is False