diff --git a/py/services/model_update_service.py b/py/services/model_update_service.py index 9df0dbeb..dabe4ad4 100644 --- a/py/services/model_update_service.py +++ b/py/services/model_update_service.py @@ -930,18 +930,39 @@ class ModelUpdateService: def _extract_size_bytes(self, files) -> Optional[int]: if not isinstance(files, Iterable): return None + + def parse_size(entry: Mapping) -> Optional[int]: + size_kb = entry.get("sizeKB") + if size_kb is None: + return None + try: + return int(float(size_kb) * 1024) + except (TypeError, ValueError): + return None + + preferred_size: Optional[int] = None + fallback_size: Optional[int] = None for entry in files: if not isinstance(entry, Mapping): continue - size_kb = entry.get("sizeKB") - if size_kb is None: + size_bytes = parse_size(entry) + if size_bytes is None: continue - try: - size_float = float(size_kb) - except (TypeError, ValueError): - continue - return int(size_float * 1024) - return None + + entry_type = entry.get("type") + is_model_type = isinstance(entry_type, str) and entry_type.lower() == "model" + primary_flag = entry.get("primary") + is_primary = primary_flag is True or ( + isinstance(primary_flag, str) and primary_flag.strip().lower() == "true" + ) + + if is_model_type and is_primary: + preferred_size = size_bytes + break + if fallback_size is None: + fallback_size = size_bytes + + return preferred_size if preferred_size is not None else fallback_size def _extract_preview_url(self, images) -> Optional[str]: if not isinstance(images, Iterable): diff --git a/tests/services/test_model_update_service.py b/tests/services/test_model_update_service.py index 5538f3b8..61799715 100644 --- a/tests/services/test_model_update_service.py +++ b/tests/services/test_model_update_service.py @@ -74,6 +74,58 @@ def make_record(*versions, should_ignore_model=False): ) +def test_extract_size_bytes_prefers_primary_model_file(tmp_path): + db_path = tmp_path / "updates.sqlite" + service = ModelUpdateService(str(db_path)) + + response = { + "modelVersions": [ + { + "id": 42, + "files": [ + {"sizeKB": 2018.0400390625, "type": "Training Data", "primary": False}, + { + "sizeKB": 1152322.3515625, + "type": "Model", + "primary": "True", + }, + ], + "images": [], + } + ] + } + + versions = service._extract_versions(response) + assert versions is not None + assert versions[0].size_bytes == int(1152322.3515625 * 1024) + + +def test_extract_size_bytes_falls_back_without_primary(tmp_path): + db_path = tmp_path / "updates.sqlite" + service = ModelUpdateService(str(db_path)) + + response = { + "modelVersions": [ + { + "id": 43, + "files": [ + { + "sizeKB": 2048, + "type": "Training Data", + "primary": True, + }, + {"sizeKB": 1024, "type": "Archive", "primary": False}, + ], + "images": [], + } + ] + } + + versions = service._extract_versions(response) + assert versions is not None + assert versions[0].size_bytes == int(2048 * 1024) + + def test_has_update_requires_newer_version_than_library(): record = make_record( make_version(5, in_library=True),