From d77b6d78b7526a2bde0e453d1c56f5fb85785e00 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Sat, 25 Oct 2025 21:31:36 +0800 Subject: [PATCH] feat(model-updates): filter records without updates in refresh response Add logic to only include model update records that have actual updates in the refresh response. This improves API efficiency by reducing payload size and only returning relevant data to clients. The change: - Adds filtering in ModelUpdateHandler.refresh_model_updates to check has_update method - Only serializes records that have updates available - Updates corresponding test to verify filtering behavior This prevents returning unnecessary data for models that don't have updates available. --- py/routes/handlers/model_handlers.py | 9 ++- tests/routes/test_model_update_handler.py | 97 +++++++++++++++++++++++ 2 files changed, 104 insertions(+), 2 deletions(-) diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 76c4d040..905d2627 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -1066,10 +1066,16 @@ class ModelUpdateHandler: self._logger.error("Failed to refresh model updates: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) + serialized_records = [] + for record in records.values(): + has_update_fn = getattr(record, "has_update", None) + if callable(has_update_fn) and has_update_fn(): + serialized_records.append(self._serialize_record(record)) + return web.json_response( { "success": True, - "records": [self._serialize_record(record) for record in records.values()], + "records": serialized_records, } ) @@ -1331,4 +1337,3 @@ class ModelHandlerSet: "get_model_update_status": self.updates.get_model_update_status, "get_model_versions": self.updates.get_model_versions, } - diff --git a/tests/routes/test_model_update_handler.py b/tests/routes/test_model_update_handler.py index 864c1b78..b85fa83d 100644 --- a/tests/routes/test_model_update_handler.py +++ b/tests/routes/test_model_update_handler.py @@ -1,3 +1,4 @@ +import json import logging from types import SimpleNamespace @@ -22,6 +23,23 @@ class DummyService: self.scanner = DummyScanner(cache) +class DummyUpdateService: + def __init__(self, records): + self.records = records + self.calls = [] + + async def refresh_for_model_type(self, model_type, scanner, provider, *, force_refresh=False): + self.calls.append( + { + "model_type": model_type, + "scanner": scanner, + "provider": provider, + "force_refresh": force_refresh, + } + ) + return self.records + + @pytest.mark.asyncio async def test_build_preview_overrides_uses_static_urls(): cache = SimpleNamespace(version_index={123: {"preview_url": "/tmp/previews/example.png"}}) @@ -55,3 +73,82 @@ async def test_build_preview_overrides_uses_static_urls(): overrides = await handler._build_preview_overrides(record) expected = config.get_preview_static_url("/tmp/previews/example.png") assert overrides == {123: expected} + + +@pytest.mark.asyncio +async def test_refresh_model_updates_filters_records_without_updates(): + cache = SimpleNamespace(version_index={}) + service = DummyService(cache) + + record_with_update = ModelUpdateRecord( + model_type="lora", + model_id=1, + versions=[ + ModelVersionRecord( + version_id=10, + name="v1", + base_model=None, + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=False, + should_ignore=False, + ) + ], + last_checked_at=None, + should_ignore_model=False, + ) + record_without_update = ModelUpdateRecord( + model_type="lora", + model_id=2, + versions=[ + ModelVersionRecord( + version_id=20, + name="v2", + base_model=None, + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=True, + should_ignore=False, + ) + ], + last_checked_at=None, + should_ignore_model=False, + ) + + update_service = DummyUpdateService({1: record_with_update, 2: record_without_update}) + + async def metadata_selector(name): + assert name == "civitai_api" + return object() + + handler = ModelUpdateHandler( + service=service, + update_service=update_service, + metadata_provider_selector=metadata_selector, + logger=logging.getLogger(__name__), + ) + + class DummyRequest: + can_read_body = True + query = {} + + async def json(self): + return {} + + response = await handler.refresh_model_updates(DummyRequest()) + assert response.status == 200 + + payload = json.loads(response.text) + assert payload["success"] is True + assert len(payload["records"]) == 1 + assert payload["records"][0]["modelId"] == 1 + assert payload["records"][0]["hasUpdate"] is True + + assert len(update_service.calls) == 1 + call = update_service.calls[0] + assert call["model_type"] == "lora" + assert call["scanner"] is service.scanner + assert call["force_refresh"] is False + assert call["provider"] is not None