From 2eef629821d55237ada25310e0ea50a9e17bf0ea Mon Sep 17 00:00:00 2001 From: Will Miao Date: Thu, 23 Apr 2026 11:36:32 +0800 Subject: [PATCH] fix(checkpoints): singleflight pending hash calculation --- py/services/checkpoint_scanner.py | 61 +++++++- tests/services/test_checkpoint_lazy_hash.py | 155 ++++++++++++++++++++ 2 files changed, 213 insertions(+), 3 deletions(-) diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py index e3a46998..fb80f84b 100644 --- a/py/services/checkpoint_scanner.py +++ b/py/services/checkpoint_scanner.py @@ -1,3 +1,4 @@ +import asyncio import json import logging import os @@ -36,6 +37,9 @@ class CheckpointScanner(ModelScanner): file_extensions=file_extensions, hash_index=ModelHashIndex(), ) + if not hasattr(self, "_hash_calculation_lock"): + self._hash_calculation_lock = asyncio.Lock() + self._hash_calculation_tasks: dict[str, asyncio.Task[Optional[str]]] = {} async def _create_default_metadata( self, file_path: str @@ -88,7 +92,7 @@ class CheckpointScanner(ModelScanner): return None async def calculate_hash_for_model(self, file_path: str) -> Optional[str]: - """Calculate hash for a checkpoint on-demand. + """Calculate hash for a checkpoint on-demand with per-file singleflight. Args: file_path: Path to the model file @@ -96,14 +100,65 @@ class CheckpointScanner(ModelScanner): Returns: SHA256 hash string, or None if calculation failed """ - from ..utils.file_utils import calculate_sha256 - try: real_path = os.path.realpath(file_path) if not os.path.exists(real_path): logger.error(f"File not found for hash calculation: {file_path}") return None + metadata, _ = await MetadataManager.load_metadata( + file_path, self.model_class + ) + if ( + metadata is not None + and metadata.hash_status == "completed" + and metadata.sha256 + ): + return metadata.sha256 + + async with self._hash_calculation_lock: + metadata, _ = await MetadataManager.load_metadata( + file_path, self.model_class + ) + if ( + metadata is not None + and metadata.hash_status == "completed" + and metadata.sha256 + ): + return metadata.sha256 + + task = self._hash_calculation_tasks.get(real_path) + if task is None: + task = asyncio.create_task( + self._run_hash_calculation_task(file_path, real_path) + ) + self._hash_calculation_tasks[real_path] = task + + return await asyncio.shield(task) + + except Exception as e: + logger.error(f"Error calculating hash for {file_path}: {e}") + return None + + async def _run_hash_calculation_task( + self, file_path: str, real_path: str + ) -> Optional[str]: + """Run a hash calculation task and remove it from the in-flight map.""" + try: + return await self._calculate_hash_for_model_uncached(file_path, real_path) + finally: + task = asyncio.current_task() + async with self._hash_calculation_lock: + if self._hash_calculation_tasks.get(real_path) is task: + del self._hash_calculation_tasks[real_path] + + async def _calculate_hash_for_model_uncached( + self, file_path: str, real_path: str + ) -> Optional[str]: + """Calculate hash for a checkpoint without checking in-flight tasks.""" + from ..utils.file_utils import calculate_sha256 + + try: # Load current metadata metadata, should_skip = await MetadataManager.load_metadata( file_path, self.model_class diff --git a/tests/services/test_checkpoint_lazy_hash.py b/tests/services/test_checkpoint_lazy_hash.py index 78ef2d1d..ad2f00c6 100644 --- a/tests/services/test_checkpoint_lazy_hash.py +++ b/tests/services/test_checkpoint_lazy_hash.py @@ -1,5 +1,6 @@ """Tests for checkpoint lazy hash calculation feature.""" +import asyncio import json import os from pathlib import Path @@ -199,6 +200,160 @@ async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeyp mock_calc.assert_not_called(), "Should not recalculate if already completed" +@pytest.mark.asyncio +async def test_calculate_hash_for_model_singleflight_same_file( + tmp_path: Path, monkeypatch +): + """Concurrent calls for the same checkpoint should share one SHA256 task.""" + checkpoints_root = tmp_path / "checkpoints" + checkpoints_root.mkdir() + + checkpoint_file = checkpoints_root / "test_model.safetensors" + checkpoint_file.write_text("fake content", encoding="utf-8") + + normalized_root = _normalize(checkpoints_root) + normalized_file = _normalize(checkpoint_file) + real_file = os.path.realpath(normalized_file) + + monkeypatch.setattr( + model_scanner.config, + "base_models_roots", + [normalized_root], + raising=False, + ) + + scanner = CheckpointScanner() + metadata = await scanner._create_default_metadata(normalized_file) + assert metadata is not None + + calls = [] + + async def fake_calculate_sha256(file_path: str) -> str: + calls.append(file_path) + await asyncio.sleep(0.01) + return "a" * 64 + + with patch( + "py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256 + ): + results = await asyncio.gather( + *[scanner.calculate_hash_for_model(normalized_file) for _ in range(8)] + ) + + assert calls == [real_file] + assert results == ["a" * 64] * 8 + assert scanner._hash_calculation_tasks == {} + + +@pytest.mark.asyncio +async def test_calculate_hash_for_model_cleans_task_after_failure_and_retries( + tmp_path: Path, monkeypatch +): + """A failed in-flight task should be removed so later calls can retry.""" + checkpoints_root = tmp_path / "checkpoints" + checkpoints_root.mkdir() + + checkpoint_file = checkpoints_root / "retry_model.safetensors" + checkpoint_file.write_text("fake content", encoding="utf-8") + + normalized_root = _normalize(checkpoints_root) + normalized_file = _normalize(checkpoint_file) + + monkeypatch.setattr( + model_scanner.config, + "base_models_roots", + [normalized_root], + raising=False, + ) + + scanner = CheckpointScanner() + metadata = await scanner._create_default_metadata(normalized_file) + assert metadata is not None + + attempts = 0 + + async def fake_calculate_sha256(_file_path: str) -> str: + nonlocal attempts + attempts += 1 + if attempts == 1: + raise RuntimeError("hash failed") + return "b" * 64 + + with patch( + "py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256 + ): + assert await scanner.calculate_hash_for_model(normalized_file) is None + assert scanner._hash_calculation_tasks == {} + + hash_result = await scanner.calculate_hash_for_model(normalized_file) + + assert hash_result == "b" * 64 + assert attempts == 2 + assert scanner._hash_calculation_tasks == {} + + +@pytest.mark.asyncio +async def test_calculate_hash_for_model_uses_separate_tasks_for_different_files( + tmp_path: Path, monkeypatch +): + """Different checkpoint files should not share the same in-flight task.""" + checkpoints_root = tmp_path / "checkpoints" + checkpoints_root.mkdir() + + checkpoint_files = [ + checkpoints_root / "model_a.safetensors", + checkpoints_root / "model_b.safetensors", + ] + for checkpoint_file in checkpoint_files: + checkpoint_file.write_text( + f"fake content for {checkpoint_file.name}", encoding="utf-8" + ) + + normalized_root = _normalize(checkpoints_root) + normalized_files = [ + _normalize(checkpoint_file) for checkpoint_file in checkpoint_files + ] + real_files = {os.path.realpath(file_path) for file_path in normalized_files} + + monkeypatch.setattr( + model_scanner.config, + "base_models_roots", + [normalized_root], + raising=False, + ) + + scanner = CheckpointScanner() + for normalized_file in normalized_files: + metadata = await scanner._create_default_metadata(normalized_file) + assert metadata is not None + + calls = [] + hashes_by_path = { + os.path.realpath(normalized_files[0]): "c" * 64, + os.path.realpath(normalized_files[1]): "d" * 64, + } + + async def fake_calculate_sha256(file_path: str) -> str: + calls.append(file_path) + await asyncio.sleep(0.01) + return hashes_by_path[file_path] + + with patch( + "py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256 + ): + results = await asyncio.gather( + *[ + scanner.calculate_hash_for_model(file_path) + for file_path in normalized_files + ] + ) + + assert set(calls) == real_files + assert len(calls) == 2 + assert set(results) == {"c" * 64, "d" * 64} + assert scanner._hash_calculation_tasks == {} + + @pytest.mark.asyncio async def test_calculate_hash_for_model_bootstraps_missing_metadata(tmp_path: Path, monkeypatch): """Test that calculate_hash_for_model creates pending metadata when it is missing."""