fix(checkpoints): singleflight pending hash calculation

This commit is contained in:
Will Miao
2026-04-23 11:36:32 +08:00
parent 658a04736d
commit 2eef629821
2 changed files with 213 additions and 3 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
import json import json
import logging import logging
import os import os
@@ -36,6 +37,9 @@ class CheckpointScanner(ModelScanner):
file_extensions=file_extensions, file_extensions=file_extensions,
hash_index=ModelHashIndex(), hash_index=ModelHashIndex(),
) )
if not hasattr(self, "_hash_calculation_lock"):
self._hash_calculation_lock = asyncio.Lock()
self._hash_calculation_tasks: dict[str, asyncio.Task[Optional[str]]] = {}
async def _create_default_metadata( async def _create_default_metadata(
self, file_path: str self, file_path: str
@@ -88,7 +92,7 @@ class CheckpointScanner(ModelScanner):
return None return None
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
    """Calculate hash for a checkpoint on-demand with per-file singleflight.

    Args:
        file_path: Path to the model file

    Returns:
        SHA256 hash string, or None if calculation failed
    """
    try:
        real_path = os.path.realpath(file_path)
        if not os.path.exists(real_path):
            logger.error(f"File not found for hash calculation: {file_path}")
            return None
        # Fast path: a previously completed hash can be returned without
        # ever touching the lock or the in-flight task map.
        metadata, _ = await MetadataManager.load_metadata(
            file_path, self.model_class
        )
        if (
            metadata is not None
            and metadata.hash_status == "completed"
            and metadata.sha256
        ):
            return metadata.sha256
        async with self._hash_calculation_lock:
            # Double-check under the lock: another coroutine may have
            # completed the hash between the unlocked check above and
            # acquiring the lock.
            metadata, _ = await MetadataManager.load_metadata(
                file_path, self.model_class
            )
            if (
                metadata is not None
                and metadata.hash_status == "completed"
                and metadata.sha256
            ):
                return metadata.sha256
            # Singleflight: the in-flight map is keyed by real_path so
            # paths that resolve to the same file share one task.
            task = self._hash_calculation_tasks.get(real_path)
            if task is None:
                task = asyncio.create_task(
                    self._run_hash_calculation_task(file_path, real_path)
                )
                self._hash_calculation_tasks[real_path] = task
        # Await OUTSIDE the lock: the task's cleanup re-acquires the lock,
        # so awaiting while holding it would deadlock, and waiters would be
        # needlessly serialized. shield() keeps the shared task running even
        # if one individual awaiter is cancelled.
        return await asyncio.shield(task)
    except Exception as e:
        logger.error(f"Error calculating hash for {file_path}: {e}")
        return None
async def _run_hash_calculation_task(
    self, file_path: str, real_path: str
) -> Optional[str]:
    """Compute the hash, then drop this task from the in-flight registry.

    The cleanup only removes the registry entry when it still points at this
    very task, so a newer task registered under the same real path (e.g.
    after a retry) is never clobbered.
    """
    try:
        return await self._calculate_hash_for_model_uncached(file_path, real_path)
    finally:
        current = asyncio.current_task()
        async with self._hash_calculation_lock:
            registered = self._hash_calculation_tasks.get(real_path)
            if registered is current:
                del self._hash_calculation_tasks[real_path]
async def _calculate_hash_for_model_uncached(
self, file_path: str, real_path: str
) -> Optional[str]:
"""Calculate hash for a checkpoint without checking in-flight tasks."""
from ..utils.file_utils import calculate_sha256
try:
# Load current metadata # Load current metadata
metadata, should_skip = await MetadataManager.load_metadata( metadata, should_skip = await MetadataManager.load_metadata(
file_path, self.model_class file_path, self.model_class

View File

@@ -1,5 +1,6 @@
"""Tests for checkpoint lazy hash calculation feature.""" """Tests for checkpoint lazy hash calculation feature."""
import asyncio
import json import json
import os import os
from pathlib import Path from pathlib import Path
@@ -199,6 +200,160 @@ async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeyp
mock_calc.assert_not_called(), "Should not recalculate if already completed" mock_calc.assert_not_called(), "Should not recalculate if already completed"
@pytest.mark.asyncio
async def test_calculate_hash_for_model_singleflight_same_file(
    tmp_path: Path, monkeypatch
):
    """Concurrent calls for the same checkpoint should share one SHA256 task."""
    root_dir = tmp_path / "checkpoints"
    root_dir.mkdir()
    model_path = root_dir / "test_model.safetensors"
    model_path.write_text("fake content", encoding="utf-8")

    root_norm = _normalize(root_dir)
    file_norm = _normalize(model_path)
    resolved = os.path.realpath(file_norm)

    monkeypatch.setattr(
        model_scanner.config,
        "base_models_roots",
        [root_norm],
        raising=False,
    )

    scanner = CheckpointScanner()
    assert await scanner._create_default_metadata(file_norm) is not None

    observed = []

    async def fake_calculate_sha256(file_path: str) -> str:
        # Record every invocation; singleflight means exactly one is expected.
        observed.append(file_path)
        await asyncio.sleep(0.01)
        return "a" * 64

    with patch(
        "py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
    ):
        coros = [scanner.calculate_hash_for_model(file_norm) for _ in range(8)]
        results = await asyncio.gather(*coros)

    assert observed == [resolved]
    assert results == ["a" * 64] * 8
    assert scanner._hash_calculation_tasks == {}
@pytest.mark.asyncio
async def test_calculate_hash_for_model_cleans_task_after_failure_and_retries(
    tmp_path: Path, monkeypatch
):
    """A failed in-flight task should be removed so later calls can retry."""
    root_dir = tmp_path / "checkpoints"
    root_dir.mkdir()
    model_path = root_dir / "retry_model.safetensors"
    model_path.write_text("fake content", encoding="utf-8")

    file_norm = _normalize(model_path)
    monkeypatch.setattr(
        model_scanner.config,
        "base_models_roots",
        [_normalize(root_dir)],
        raising=False,
    )

    scanner = CheckpointScanner()
    assert await scanner._create_default_metadata(file_norm) is not None

    attempts = 0

    async def fake_calculate_sha256(_file_path: str) -> str:
        # Fail on the first attempt only, then succeed with a fixed digest.
        nonlocal attempts
        attempts += 1
        if attempts == 1:
            raise RuntimeError("hash failed")
        return "b" * 64

    with patch(
        "py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
    ):
        # First call fails and must leave no stale in-flight task behind.
        assert await scanner.calculate_hash_for_model(file_norm) is None
        assert scanner._hash_calculation_tasks == {}

        # A subsequent call starts a fresh task and succeeds.
        assert await scanner.calculate_hash_for_model(file_norm) == "b" * 64

    assert attempts == 2
    assert scanner._hash_calculation_tasks == {}
@pytest.mark.asyncio
async def test_calculate_hash_for_model_uses_separate_tasks_for_different_files(
    tmp_path: Path, monkeypatch
):
    """Different checkpoint files should not share the same in-flight task."""
    root_dir = tmp_path / "checkpoints"
    root_dir.mkdir()
    model_paths = []
    for name in ("model_a.safetensors", "model_b.safetensors"):
        model_path = root_dir / name
        model_path.write_text(f"fake content for {name}", encoding="utf-8")
        model_paths.append(model_path)

    norm_paths = [_normalize(p) for p in model_paths]
    resolved_paths = {os.path.realpath(p) for p in norm_paths}

    monkeypatch.setattr(
        model_scanner.config,
        "base_models_roots",
        [_normalize(root_dir)],
        raising=False,
    )

    scanner = CheckpointScanner()
    for norm in norm_paths:
        assert await scanner._create_default_metadata(norm) is not None

    observed = []
    digest_by_real_path = {
        os.path.realpath(norm_paths[0]): "c" * 64,
        os.path.realpath(norm_paths[1]): "d" * 64,
    }

    async def fake_calculate_sha256(file_path: str) -> str:
        observed.append(file_path)
        await asyncio.sleep(0.01)
        return digest_by_real_path[file_path]

    with patch(
        "py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
    ):
        results = await asyncio.gather(
            *(scanner.calculate_hash_for_model(p) for p in norm_paths)
        )

    # Each file got its own calculation, exactly once.
    assert set(observed) == resolved_paths
    assert len(observed) == 2
    assert set(results) == {"c" * 64, "d" * 64}
    assert scanner._hash_calculation_tasks == {}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_calculate_hash_for_model_bootstraps_missing_metadata(tmp_path: Path, monkeypatch): async def test_calculate_hash_for_model_bootstraps_missing_metadata(tmp_path: Path, monkeypatch):
"""Test that calculate_hash_for_model creates pending metadata when it is missing.""" """Test that calculate_hash_for_model creates pending metadata when it is missing."""