diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 76c4d040..905d2627 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -1066,10 +1066,16 @@ class ModelUpdateHandler: self._logger.error("Failed to refresh model updates: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) + serialized_records = [] + for record in records.values(): + has_update_fn = getattr(record, "has_update", None) + if callable(has_update_fn) and has_update_fn(): + serialized_records.append(self._serialize_record(record)) + return web.json_response( { "success": True, - "records": [self._serialize_record(record) for record in records.values()], + "records": serialized_records, } ) @@ -1331,4 +1337,3 @@ class ModelHandlerSet: "get_model_update_status": self.updates.get_model_update_status, "get_model_versions": self.updates.get_model_versions, } - diff --git a/tests/routes/test_model_update_handler.py b/tests/routes/test_model_update_handler.py index 864c1b78..b85fa83d 100644 --- a/tests/routes/test_model_update_handler.py +++ b/tests/routes/test_model_update_handler.py @@ -1,3 +1,4 @@ +import json import logging from types import SimpleNamespace @@ -22,6 +23,23 @@ class DummyService: self.scanner = DummyScanner(cache) +class DummyUpdateService: + def __init__(self, records): + self.records = records + self.calls = [] + + async def refresh_for_model_type(self, model_type, scanner, provider, *, force_refresh=False): + self.calls.append( + { + "model_type": model_type, + "scanner": scanner, + "provider": provider, + "force_refresh": force_refresh, + } + ) + return self.records + + @pytest.mark.asyncio async def test_build_preview_overrides_uses_static_urls(): cache = SimpleNamespace(version_index={123: {"preview_url": "/tmp/previews/example.png"}}) @@ -55,3 +73,82 @@ async def test_build_preview_overrides_uses_static_urls(): overrides = await handler._build_preview_overrides(record) expected = config.get_preview_static_url("/tmp/previews/example.png") assert overrides == {123: expected} + + +@pytest.mark.asyncio +async def test_refresh_model_updates_filters_records_without_updates(): + cache = SimpleNamespace(version_index={}) + service = DummyService(cache) + + record_with_update = ModelUpdateRecord( + model_type="lora", + model_id=1, + versions=[ + ModelVersionRecord( + version_id=10, + name="v1", + base_model=None, + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=False, + should_ignore=False, + ) + ], + last_checked_at=None, + should_ignore_model=False, + ) + record_without_update = ModelUpdateRecord( + model_type="lora", + model_id=2, + versions=[ + ModelVersionRecord( + version_id=20, + name="v2", + base_model=None, + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=True, + should_ignore=False, + ) + ], + last_checked_at=None, + should_ignore_model=False, + ) + + update_service = DummyUpdateService({1: record_with_update, 2: record_without_update}) + + async def metadata_selector(name): + assert name == "civitai_api" + return object() + + handler = ModelUpdateHandler( + service=service, + update_service=update_service, + metadata_provider_selector=metadata_selector, + logger=logging.getLogger(__name__), + ) + + class DummyRequest: + can_read_body = True + query = {} + + async def json(self): + return {} + + response = await handler.refresh_model_updates(DummyRequest()) + assert response.status == 200 + + payload = json.loads(response.text) + assert payload["success"] is True + assert len(payload["records"]) == 1 + assert payload["records"][0]["modelId"] == 1 + assert payload["records"][0]["hasUpdate"] is True + + assert len(update_service.calls) == 1 + call = update_service.calls[0] + assert call["model_type"] == "lora" + assert call["scanner"] is service.scanner + assert call["force_refresh"] is False + assert call["provider"] is not None