fix(checkpoints): singleflight pending hash calculation

This commit is contained in:
Will Miao
2026-04-23 11:36:32 +08:00
parent 658a04736d
commit 2eef629821
2 changed files with 213 additions and 3 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
import json
import logging
import os
@@ -36,6 +37,9 @@ class CheckpointScanner(ModelScanner):
file_extensions=file_extensions,
hash_index=ModelHashIndex(),
)
if not hasattr(self, "_hash_calculation_lock"):
self._hash_calculation_lock = asyncio.Lock()
self._hash_calculation_tasks: dict[str, asyncio.Task[Optional[str]]] = {}
async def _create_default_metadata(
self, file_path: str
@@ -88,7 +92,7 @@ class CheckpointScanner(ModelScanner):
return None
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
"""Calculate hash for a checkpoint on-demand.
"""Calculate hash for a checkpoint on-demand with per-file singleflight.
Args:
file_path: Path to the model file
@@ -96,14 +100,65 @@ class CheckpointScanner(ModelScanner):
Returns:
SHA256 hash string, or None if calculation failed
"""
from ..utils.file_utils import calculate_sha256
try:
real_path = os.path.realpath(file_path)
if not os.path.exists(real_path):
logger.error(f"File not found for hash calculation: {file_path}")
return None
metadata, _ = await MetadataManager.load_metadata(
file_path, self.model_class
)
if (
metadata is not None
and metadata.hash_status == "completed"
and metadata.sha256
):
return metadata.sha256
async with self._hash_calculation_lock:
metadata, _ = await MetadataManager.load_metadata(
file_path, self.model_class
)
if (
metadata is not None
and metadata.hash_status == "completed"
and metadata.sha256
):
return metadata.sha256
task = self._hash_calculation_tasks.get(real_path)
if task is None:
task = asyncio.create_task(
self._run_hash_calculation_task(file_path, real_path)
)
self._hash_calculation_tasks[real_path] = task
return await asyncio.shield(task)
except Exception as e:
logger.error(f"Error calculating hash for {file_path}: {e}")
return None
async def _run_hash_calculation_task(
self, file_path: str, real_path: str
) -> Optional[str]:
"""Run a hash calculation task and remove it from the in-flight map."""
try:
return await self._calculate_hash_for_model_uncached(file_path, real_path)
finally:
task = asyncio.current_task()
async with self._hash_calculation_lock:
if self._hash_calculation_tasks.get(real_path) is task:
del self._hash_calculation_tasks[real_path]
async def _calculate_hash_for_model_uncached(
self, file_path: str, real_path: str
) -> Optional[str]:
"""Calculate hash for a checkpoint without checking in-flight tasks."""
from ..utils.file_utils import calculate_sha256
try:
# Load current metadata
metadata, should_skip = await MetadataManager.load_metadata(
file_path, self.model_class