From 2abb5bf12204b2e5e4a9058bd8dd1d2218e5a100 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Sat, 25 Oct 2025 16:29:54 +0800 Subject: [PATCH] feat: add update_available flag to model services Add update_available field to checkpoint, embedding, and LoRA service response formatting. The flag indicates whether a model update is available and defaults to false when not specified. Include comprehensive tests to verify the update flag is properly included in formatted responses and defaults to false when not present in the payload. --- py/services/checkpoint_service.py | 1 + py/services/embedding_service.py | 1 + py/services/lora_service.py | 1 + tests/services/test_base_model_service.py | 54 +++++++++++++++++++++++ 4 files changed, 57 insertions(+) diff --git a/py/services/checkpoint_service.py b/py/services/checkpoint_service.py index 55cb13cd..ef1d763f 100644 --- a/py/services/checkpoint_service.py +++ b/py/services/checkpoint_service.py @@ -38,6 +38,7 @@ class CheckpointService(BaseModelService): "notes": checkpoint_data.get("notes", ""), "model_type": checkpoint_data.get("model_type", "checkpoint"), "favorite": checkpoint_data.get("favorite", False), + "update_available": bool(checkpoint_data.get("update_available", False)), "civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True) } diff --git a/py/services/embedding_service.py b/py/services/embedding_service.py index 30a742df..3275552b 100644 --- a/py/services/embedding_service.py +++ b/py/services/embedding_service.py @@ -38,6 +38,7 @@ class EmbeddingService(BaseModelService): "notes": embedding_data.get("notes", ""), "model_type": embedding_data.get("model_type", "embedding"), "favorite": embedding_data.get("favorite", False), + "update_available": bool(embedding_data.get("update_available", False)), "civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True) } diff --git a/py/services/lora_service.py b/py/services/lora_service.py index 5a0e27f6..2de2cf96 100644 --- a/py/services/lora_service.py +++ b/py/services/lora_service.py @@ -38,6 +38,7 @@ class LoraService(BaseModelService): "usage_tips": lora_data.get("usage_tips", ""), "notes": lora_data.get("notes", ""), "favorite": lora_data.get("favorite", False), + "update_available": bool(lora_data.get("update_available", False)), "civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True) } diff --git a/tests/services/test_base_model_service.py b/tests/services/test_base_model_service.py index 0595963a..ec5bd333 100644 --- a/tests/services/test_base_model_service.py +++ b/tests/services/test_base_model_service.py @@ -1,6 +1,9 @@ import pytest from py.services.base_model_service import BaseModelService +from py.services.lora_service import LoraService +from py.services.checkpoint_service import CheckpointService +from py.services.embedding_service import EmbeddingService from py.services.model_query import ( ModelCacheRepository, ModelFilterSet, @@ -455,3 +458,54 @@ async def test_get_paginated_data_annotates_update_flags_with_bulk_dedup(): assert [item["update_available"] for item in response["items"]] == [True, True, False] assert response["total"] == 3 assert response["total_pages"] == 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "service_cls, extra_fields", + [ + (LoraService, {"usage_tips": "tips"}), + (CheckpointService, {"model_type": "checkpoint"}), + (EmbeddingService, {"model_type": "embedding"}), + ], +) +async def test_format_response_includes_update_flag(service_cls, extra_fields): + service = service_cls(scanner=object()) + payload = { + "model_name": "Demo", + "file_name": "demo.safetensors", + "folder": "root", + "file_path": "root/demo.safetensors", + **extra_fields, + } + payload["update_available"] = True + + formatted = await service.format_response(payload) + + assert "update_available" in formatted + assert formatted["update_available"] is True + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "service_cls, extra_fields", + [ + (LoraService, {"usage_tips": "tips"}), + (CheckpointService, {"model_type": "checkpoint"}), + (EmbeddingService, {"model_type": "embedding"}), + ], +) +async def test_format_response_defaults_update_flag_false(service_cls, extra_fields): + service = service_cls(scanner=object()) + payload = { + "model_name": "Demo", + "file_name": "demo.safetensors", + "folder": "root", + "file_path": "root/demo.safetensors", + **extra_fields, + } + + formatted = await service.format_response(payload) + + assert "update_available" in formatted + assert formatted["update_available"] is False