mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 08:26:45 -03:00
fix(checkpoints): singleflight pending hash calculation
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -36,6 +37,9 @@ class CheckpointScanner(ModelScanner):
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex(),
|
||||
)
|
||||
if not hasattr(self, "_hash_calculation_lock"):
|
||||
self._hash_calculation_lock = asyncio.Lock()
|
||||
self._hash_calculation_tasks: dict[str, asyncio.Task[Optional[str]]] = {}
|
||||
|
||||
async def _create_default_metadata(
|
||||
self, file_path: str
|
||||
@@ -88,7 +92,7 @@ class CheckpointScanner(ModelScanner):
|
||||
return None
|
||||
|
||||
async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
|
||||
"""Calculate hash for a checkpoint on-demand.
|
||||
"""Calculate hash for a checkpoint on-demand with per-file singleflight.
|
||||
|
||||
Args:
|
||||
file_path: Path to the model file
|
||||
@@ -96,14 +100,65 @@ class CheckpointScanner(ModelScanner):
|
||||
Returns:
|
||||
SHA256 hash string, or None if calculation failed
|
||||
"""
|
||||
from ..utils.file_utils import calculate_sha256
|
||||
|
||||
try:
|
||||
real_path = os.path.realpath(file_path)
|
||||
if not os.path.exists(real_path):
|
||||
logger.error(f"File not found for hash calculation: {file_path}")
|
||||
return None
|
||||
|
||||
metadata, _ = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
)
|
||||
if (
|
||||
metadata is not None
|
||||
and metadata.hash_status == "completed"
|
||||
and metadata.sha256
|
||||
):
|
||||
return metadata.sha256
|
||||
|
||||
async with self._hash_calculation_lock:
|
||||
metadata, _ = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
)
|
||||
if (
|
||||
metadata is not None
|
||||
and metadata.hash_status == "completed"
|
||||
and metadata.sha256
|
||||
):
|
||||
return metadata.sha256
|
||||
|
||||
task = self._hash_calculation_tasks.get(real_path)
|
||||
if task is None:
|
||||
task = asyncio.create_task(
|
||||
self._run_hash_calculation_task(file_path, real_path)
|
||||
)
|
||||
self._hash_calculation_tasks[real_path] = task
|
||||
|
||||
return await asyncio.shield(task)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating hash for {file_path}: {e}")
|
||||
return None
|
||||
|
||||
async def _run_hash_calculation_task(
|
||||
self, file_path: str, real_path: str
|
||||
) -> Optional[str]:
|
||||
"""Run a hash calculation task and remove it from the in-flight map."""
|
||||
try:
|
||||
return await self._calculate_hash_for_model_uncached(file_path, real_path)
|
||||
finally:
|
||||
task = asyncio.current_task()
|
||||
async with self._hash_calculation_lock:
|
||||
if self._hash_calculation_tasks.get(real_path) is task:
|
||||
del self._hash_calculation_tasks[real_path]
|
||||
|
||||
async def _calculate_hash_for_model_uncached(
|
||||
self, file_path: str, real_path: str
|
||||
) -> Optional[str]:
|
||||
"""Calculate hash for a checkpoint without checking in-flight tasks."""
|
||||
from ..utils.file_utils import calculate_sha256
|
||||
|
||||
try:
|
||||
# Load current metadata
|
||||
metadata, should_skip = await MetadataManager.load_metadata(
|
||||
file_path, self.model_class
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for checkpoint lazy hash calculation feature."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -199,6 +200,160 @@ async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeyp
|
||||
mock_calc.assert_not_called(), "Should not recalculate if already completed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_for_model_singleflight_same_file(
|
||||
tmp_path: Path, monkeypatch
|
||||
):
|
||||
"""Concurrent calls for the same checkpoint should share one SHA256 task."""
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
|
||||
checkpoint_file = checkpoints_root / "test_model.safetensors"
|
||||
checkpoint_file.write_text("fake content", encoding="utf-8")
|
||||
|
||||
normalized_root = _normalize(checkpoints_root)
|
||||
normalized_file = _normalize(checkpoint_file)
|
||||
real_file = os.path.realpath(normalized_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
metadata = await scanner._create_default_metadata(normalized_file)
|
||||
assert metadata is not None
|
||||
|
||||
calls = []
|
||||
|
||||
async def fake_calculate_sha256(file_path: str) -> str:
|
||||
calls.append(file_path)
|
||||
await asyncio.sleep(0.01)
|
||||
return "a" * 64
|
||||
|
||||
with patch(
|
||||
"py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
|
||||
):
|
||||
results = await asyncio.gather(
|
||||
*[scanner.calculate_hash_for_model(normalized_file) for _ in range(8)]
|
||||
)
|
||||
|
||||
assert calls == [real_file]
|
||||
assert results == ["a" * 64] * 8
|
||||
assert scanner._hash_calculation_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_for_model_cleans_task_after_failure_and_retries(
|
||||
tmp_path: Path, monkeypatch
|
||||
):
|
||||
"""A failed in-flight task should be removed so later calls can retry."""
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
|
||||
checkpoint_file = checkpoints_root / "retry_model.safetensors"
|
||||
checkpoint_file.write_text("fake content", encoding="utf-8")
|
||||
|
||||
normalized_root = _normalize(checkpoints_root)
|
||||
normalized_file = _normalize(checkpoint_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
metadata = await scanner._create_default_metadata(normalized_file)
|
||||
assert metadata is not None
|
||||
|
||||
attempts = 0
|
||||
|
||||
async def fake_calculate_sha256(_file_path: str) -> str:
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
if attempts == 1:
|
||||
raise RuntimeError("hash failed")
|
||||
return "b" * 64
|
||||
|
||||
with patch(
|
||||
"py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
|
||||
):
|
||||
assert await scanner.calculate_hash_for_model(normalized_file) is None
|
||||
assert scanner._hash_calculation_tasks == {}
|
||||
|
||||
hash_result = await scanner.calculate_hash_for_model(normalized_file)
|
||||
|
||||
assert hash_result == "b" * 64
|
||||
assert attempts == 2
|
||||
assert scanner._hash_calculation_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_for_model_uses_separate_tasks_for_different_files(
|
||||
tmp_path: Path, monkeypatch
|
||||
):
|
||||
"""Different checkpoint files should not share the same in-flight task."""
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
|
||||
checkpoint_files = [
|
||||
checkpoints_root / "model_a.safetensors",
|
||||
checkpoints_root / "model_b.safetensors",
|
||||
]
|
||||
for checkpoint_file in checkpoint_files:
|
||||
checkpoint_file.write_text(
|
||||
f"fake content for {checkpoint_file.name}", encoding="utf-8"
|
||||
)
|
||||
|
||||
normalized_root = _normalize(checkpoints_root)
|
||||
normalized_files = [
|
||||
_normalize(checkpoint_file) for checkpoint_file in checkpoint_files
|
||||
]
|
||||
real_files = {os.path.realpath(file_path) for file_path in normalized_files}
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
for normalized_file in normalized_files:
|
||||
metadata = await scanner._create_default_metadata(normalized_file)
|
||||
assert metadata is not None
|
||||
|
||||
calls = []
|
||||
hashes_by_path = {
|
||||
os.path.realpath(normalized_files[0]): "c" * 64,
|
||||
os.path.realpath(normalized_files[1]): "d" * 64,
|
||||
}
|
||||
|
||||
async def fake_calculate_sha256(file_path: str) -> str:
|
||||
calls.append(file_path)
|
||||
await asyncio.sleep(0.01)
|
||||
return hashes_by_path[file_path]
|
||||
|
||||
with patch(
|
||||
"py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
|
||||
):
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
scanner.calculate_hash_for_model(file_path)
|
||||
for file_path in normalized_files
|
||||
]
|
||||
)
|
||||
|
||||
assert set(calls) == real_files
|
||||
assert len(calls) == 2
|
||||
assert set(results) == {"c" * 64, "d" * 64}
|
||||
assert scanner._hash_calculation_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_for_model_bootstraps_missing_metadata(tmp_path: Path, monkeypatch):
|
||||
"""Test that calculate_hash_for_model creates pending metadata when it is missing."""
|
||||
|
||||
Reference in New Issue
Block a user