fix(checkpoints): singleflight pending hash calculation

This commit is contained in:
Will Miao
2026-04-23 11:36:32 +08:00
parent 658a04736d
commit 2eef629821
2 changed files with 213 additions and 3 deletions

View File

@@ -1,5 +1,6 @@
"""Tests for checkpoint lazy hash calculation feature."""
import asyncio
import json
import os
from pathlib import Path
@@ -199,6 +200,160 @@ async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeyp
mock_calc.assert_not_called(), "Should not recalculate if already completed"
@pytest.mark.asyncio
async def test_calculate_hash_for_model_singleflight_same_file(
    tmp_path: Path, monkeypatch
):
    """Concurrent calls for the same checkpoint should share one SHA256 task."""
    root_dir = tmp_path / "checkpoints"
    root_dir.mkdir()
    model_file = root_dir / "test_model.safetensors"
    model_file.write_text("fake content", encoding="utf-8")

    root_path = _normalize(root_dir)
    file_path = _normalize(model_file)
    resolved_path = os.path.realpath(file_path)

    monkeypatch.setattr(
        model_scanner.config,
        "base_models_roots",
        [root_path],
        raising=False,
    )

    scanner = CheckpointScanner()
    assert await scanner._create_default_metadata(file_path) is not None

    observed_calls = []

    async def fake_calculate_sha256(file_path: str) -> str:
        # Record every invocation; the sleep keeps the task "in flight"
        # long enough for the other callers to pile onto it.
        observed_calls.append(file_path)
        await asyncio.sleep(0.01)
        return "a" * 64

    with patch(
        "py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
    ):
        hash_results = await asyncio.gather(
            *(scanner.calculate_hash_for_model(file_path) for _ in range(8))
        )

    # One underlying computation served all eight concurrent callers.
    assert observed_calls == [resolved_path]
    assert hash_results == ["a" * 64] * 8
    # The in-flight task registry must be drained once everything settles.
    assert scanner._hash_calculation_tasks == {}
@pytest.mark.asyncio
async def test_calculate_hash_for_model_cleans_task_after_failure_and_retries(
    tmp_path: Path, monkeypatch
):
    """A failed in-flight task should be removed so later calls can retry."""
    root_dir = tmp_path / "checkpoints"
    root_dir.mkdir()
    model_file = root_dir / "retry_model.safetensors"
    model_file.write_text("fake content", encoding="utf-8")

    monkeypatch.setattr(
        model_scanner.config,
        "base_models_roots",
        [_normalize(root_dir)],
        raising=False,
    )

    file_path = _normalize(model_file)
    scanner = CheckpointScanner()
    assert await scanner._create_default_metadata(file_path) is not None

    attempts = 0

    async def fake_calculate_sha256(_file_path: str) -> str:
        # Fail exactly once, then succeed, so we can observe the retry path.
        nonlocal attempts
        attempts += 1
        if attempts == 1:
            raise RuntimeError("hash failed")
        return "b" * 64

    with patch(
        "py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
    ):
        # First attempt fails; the stale task must not linger in the registry.
        assert await scanner.calculate_hash_for_model(file_path) is None
        assert scanner._hash_calculation_tasks == {}

        # Second attempt starts a fresh task and succeeds.
        assert await scanner.calculate_hash_for_model(file_path) == "b" * 64

    assert attempts == 2
    assert scanner._hash_calculation_tasks == {}
@pytest.mark.asyncio
async def test_calculate_hash_for_model_uses_separate_tasks_for_different_files(
    tmp_path: Path, monkeypatch
):
    """Different checkpoint files should not share the same in-flight task."""
    root_dir = tmp_path / "checkpoints"
    root_dir.mkdir()

    model_files = [
        root_dir / "model_a.safetensors",
        root_dir / "model_b.safetensors",
    ]
    for model_file in model_files:
        model_file.write_text(
            f"fake content for {model_file.name}", encoding="utf-8"
        )

    monkeypatch.setattr(
        model_scanner.config,
        "base_models_roots",
        [_normalize(root_dir)],
        raising=False,
    )

    file_paths = [_normalize(model_file) for model_file in model_files]
    resolved_paths = [os.path.realpath(path) for path in file_paths]

    scanner = CheckpointScanner()
    for path in file_paths:
        assert await scanner._create_default_metadata(path) is not None

    # Each real path maps to a distinct fake digest.
    expected_hashes = dict(zip(resolved_paths, ["c" * 64, "d" * 64]))
    observed_calls = []

    async def fake_calculate_sha256(file_path: str) -> str:
        observed_calls.append(file_path)
        await asyncio.sleep(0.01)
        return expected_hashes[file_path]

    with patch(
        "py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
    ):
        results = await asyncio.gather(
            *(scanner.calculate_hash_for_model(path) for path in file_paths)
        )

    # Two distinct computations ran — one per file, no cross-file sharing.
    assert sorted(observed_calls) == sorted(resolved_paths)
    assert len(observed_calls) == 2
    assert set(results) == {"c" * 64, "d" * 64}
    assert scanner._hash_calculation_tasks == {}
@pytest.mark.asyncio
async def test_calculate_hash_for_model_bootstraps_missing_metadata(tmp_path: Path, monkeypatch):
"""Test that calculate_hash_for_model creates pending metadata when it is missing."""