diff --git a/py/services/model_update_service.py b/py/services/model_update_service.py index 7577f6eb..c061ae7a 100644 --- a/py/services/model_update_service.py +++ b/py/services/model_update_service.py @@ -63,13 +63,27 @@ class ModelUpdateRecord: return [version.version_id for version in self.versions if version.is_in_library] def has_update(self) -> bool: - """Return True when a non-ignored remote version is missing locally.""" + """Return True when a non-ignored remote version newer than the newest local copy is available.""" if self.should_ignore_model: return False - return any( - not version.is_in_library and not version.should_ignore for version in self.versions - ) + max_in_library = None + for version in self.versions: + if version.is_in_library: + if max_in_library is None or version.version_id > max_in_library: + max_in_library = version.version_id + + if max_in_library is None: + return any( + not version.is_in_library and not version.should_ignore for version in self.versions + ) + + for version in self.versions: + if version.is_in_library or version.should_ignore: + continue + if version.version_id > max_in_library: + return True + return False class ModelUpdateService: diff --git a/tests/services/test_model_update_service.py b/tests/services/test_model_update_service.py index 37eb1dff..2f49d452 100644 --- a/tests/services/test_model_update_service.py +++ b/tests/services/test_model_update_service.py @@ -2,7 +2,11 @@ from types import SimpleNamespace import pytest -from py.services.model_update_service import ModelUpdateService +from py.services.model_update_service import ( + ModelUpdateRecord, + ModelUpdateService, + ModelVersionRecord, +) class DummyScanner: @@ -31,6 +35,49 @@ class DummyProvider: return {model_id: self.response for model_id in model_ids} +def make_version(version_id, *, in_library, should_ignore=False): + return ModelVersionRecord( + version_id=version_id, + name=None, + base_model=None, + released_at=None, + size_bytes=None, + preview_url=None, + is_in_library=in_library, + should_ignore=should_ignore, + ) + + +def make_record(*versions, should_ignore_model=False): + return ModelUpdateRecord( + model_type="lora", + model_id=999, + versions=list(versions), + last_checked_at=None, + should_ignore_model=should_ignore_model, + ) + + +def test_has_update_requires_newer_version_than_library(): + record = make_record( + make_version(5, in_library=True), + make_version(4, in_library=False), + make_version(8, in_library=False, should_ignore=True), + ) + + assert record.has_update() is False + + +def test_has_update_detects_newer_remote_version(): + record = make_record( + make_version(5, in_library=True), + make_version(7, in_library=False), + make_version(6, in_library=False, should_ignore=True), + ) + + assert record.has_update() is True + + @pytest.mark.asyncio async def test_refresh_persists_versions_and_uses_cache(tmp_path): db_path = tmp_path / "updates.sqlite"