mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git
fix(checkpoints): singleflight pending hash calculation
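This change makes concurrent hash requests for the same checkpoint share a single SHA256 computation: the first caller starts the hash task and later callers await that same in-flight task. Below is a minimal, generic sketch of the per-key singleflight pattern being applied (hypothetical names, not the repository's API; the real implementation lives in CheckpointScanner and adds a metadata fast path, as the diff shows):

import asyncio
from typing import Awaitable, Callable, Dict, Optional


class SingleFlight:
    """Coalesce concurrent async calls that share a key (sketch only)."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._tasks: Dict[str, asyncio.Task] = {}

    async def run(
        self, key: str, fn: Callable[[], Awaitable[Optional[str]]]
    ) -> Optional[str]:
        async with self._lock:
            task = self._tasks.get(key)
            if task is None:
                task = asyncio.create_task(self._run_and_cleanup(key, fn))
                self._tasks[key] = task
        # shield: cancelling one waiting caller must not cancel the shared task
        return await asyncio.shield(task)

    async def _run_and_cleanup(
        self, key: str, fn: Callable[[], Awaitable[Optional[str]]]
    ) -> Optional[str]:
        try:
            return await fn()
        finally:
            # Only the task that owns the slot removes itself, so a failed
            # attempt is dropped and a later call can retry with a fresh task.
            async with self._lock:
                if self._tasks.get(key) is asyncio.current_task():
                    del self._tasks[key]

Eight concurrent run() calls with the same key execute fn once; that is exactly what the new tests assert for calculate_hash_for_model.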
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import logging
 import os
@@ -36,6 +37,9 @@ class CheckpointScanner(ModelScanner):
             file_extensions=file_extensions,
             hash_index=ModelHashIndex(),
         )
+        if not hasattr(self, "_hash_calculation_lock"):
+            self._hash_calculation_lock = asyncio.Lock()
+            self._hash_calculation_tasks: dict[str, asyncio.Task[Optional[str]]] = {}
 
     async def _create_default_metadata(
         self, file_path: str
@@ -88,7 +92,7 @@ class CheckpointScanner(ModelScanner):
         return None
 
     async def calculate_hash_for_model(self, file_path: str) -> Optional[str]:
-        """Calculate hash for a checkpoint on-demand.
+        """Calculate hash for a checkpoint on-demand with per-file singleflight.
 
         Args:
             file_path: Path to the model file
@@ -96,14 +100,65 @@ class CheckpointScanner(ModelScanner):
         Returns:
             SHA256 hash string, or None if calculation failed
         """
-        from ..utils.file_utils import calculate_sha256
-
         try:
             real_path = os.path.realpath(file_path)
             if not os.path.exists(real_path):
                 logger.error(f"File not found for hash calculation: {file_path}")
                 return None
 
+            metadata, _ = await MetadataManager.load_metadata(
+                file_path, self.model_class
+            )
+            if (
+                metadata is not None
+                and metadata.hash_status == "completed"
+                and metadata.sha256
+            ):
+                return metadata.sha256
+
+            async with self._hash_calculation_lock:
+                metadata, _ = await MetadataManager.load_metadata(
+                    file_path, self.model_class
+                )
+                if (
+                    metadata is not None
+                    and metadata.hash_status == "completed"
+                    and metadata.sha256
+                ):
+                    return metadata.sha256
+
+                task = self._hash_calculation_tasks.get(real_path)
+                if task is None:
+                    task = asyncio.create_task(
+                        self._run_hash_calculation_task(file_path, real_path)
+                    )
+                    self._hash_calculation_tasks[real_path] = task
+
+            return await asyncio.shield(task)
+
+        except Exception as e:
+            logger.error(f"Error calculating hash for {file_path}: {e}")
+            return None
+
+    async def _run_hash_calculation_task(
+        self, file_path: str, real_path: str
+    ) -> Optional[str]:
+        """Run a hash calculation task and remove it from the in-flight map."""
+        try:
+            return await self._calculate_hash_for_model_uncached(file_path, real_path)
+        finally:
+            task = asyncio.current_task()
+            async with self._hash_calculation_lock:
+                if self._hash_calculation_tasks.get(real_path) is task:
+                    del self._hash_calculation_tasks[real_path]
+
+    async def _calculate_hash_for_model_uncached(
+        self, file_path: str, real_path: str
+    ) -> Optional[str]:
+        """Calculate hash for a checkpoint without checking in-flight tasks."""
+        from ..utils.file_utils import calculate_sha256
+
+        try:
             # Load current metadata
             metadata, should_skip = await MetadataManager.load_metadata(
                 file_path, self.model_class
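Two details of the new method are worth calling out. The metadata is checked twice: once before taking the lock as a fast path, then again under the lock so a hash completed in the meantime is returned without queuing a new task. And callers await the shared task through asyncio.shield, so one caller being cancelled does not cancel the computation the other callers are still waiting on. A small standalone illustration of that shield behaviour (hypothetical names, not repository code):

import asyncio


async def slow_hash() -> str:
    await asyncio.sleep(0.05)
    return "done"


async def main() -> None:
    shared = asyncio.create_task(slow_hash())   # the single in-flight hash task

    async def one_caller() -> str:
        # Each caller awaits through shield, as calculate_hash_for_model does.
        return await asyncio.shield(shared)

    waiter = asyncio.create_task(one_caller())
    await asyncio.sleep(0)   # let the caller start waiting
    waiter.cancel()          # this caller gives up...
    try:
        await waiter
    except asyncio.CancelledError:
        pass
    assert await shared == "done"   # ...but the shared task still completes


asyncio.run(main())

The remaining hunks are the accompanying tests for the singleflight behaviour.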
@@ -1,5 +1,6 @@
 """Tests for checkpoint lazy hash calculation feature."""
 
+import asyncio
 import json
 import os
 from pathlib import Path
@@ -199,6 +200,160 @@ async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeyp
     mock_calc.assert_not_called(), "Should not recalculate if already completed"
 
 
+@pytest.mark.asyncio
+async def test_calculate_hash_for_model_singleflight_same_file(
+    tmp_path: Path, monkeypatch
+):
+    """Concurrent calls for the same checkpoint should share one SHA256 task."""
+    checkpoints_root = tmp_path / "checkpoints"
+    checkpoints_root.mkdir()
+
+    checkpoint_file = checkpoints_root / "test_model.safetensors"
+    checkpoint_file.write_text("fake content", encoding="utf-8")
+
+    normalized_root = _normalize(checkpoints_root)
+    normalized_file = _normalize(checkpoint_file)
+    real_file = os.path.realpath(normalized_file)
+
+    monkeypatch.setattr(
+        model_scanner.config,
+        "base_models_roots",
+        [normalized_root],
+        raising=False,
+    )
+
+    scanner = CheckpointScanner()
+    metadata = await scanner._create_default_metadata(normalized_file)
+    assert metadata is not None
+
+    calls = []
+
+    async def fake_calculate_sha256(file_path: str) -> str:
+        calls.append(file_path)
+        await asyncio.sleep(0.01)
+        return "a" * 64
+
+    with patch(
+        "py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
+    ):
+        results = await asyncio.gather(
+            *[scanner.calculate_hash_for_model(normalized_file) for _ in range(8)]
+        )
+
+    assert calls == [real_file]
+    assert results == ["a" * 64] * 8
+    assert scanner._hash_calculation_tasks == {}
+
+
+@pytest.mark.asyncio
+async def test_calculate_hash_for_model_cleans_task_after_failure_and_retries(
+    tmp_path: Path, monkeypatch
+):
+    """A failed in-flight task should be removed so later calls can retry."""
+    checkpoints_root = tmp_path / "checkpoints"
+    checkpoints_root.mkdir()
+
+    checkpoint_file = checkpoints_root / "retry_model.safetensors"
+    checkpoint_file.write_text("fake content", encoding="utf-8")
+
+    normalized_root = _normalize(checkpoints_root)
+    normalized_file = _normalize(checkpoint_file)
+
+    monkeypatch.setattr(
+        model_scanner.config,
+        "base_models_roots",
+        [normalized_root],
+        raising=False,
+    )
+
+    scanner = CheckpointScanner()
+    metadata = await scanner._create_default_metadata(normalized_file)
+    assert metadata is not None
+
+    attempts = 0
+
+    async def fake_calculate_sha256(_file_path: str) -> str:
+        nonlocal attempts
+        attempts += 1
+        if attempts == 1:
+            raise RuntimeError("hash failed")
+        return "b" * 64
+
+    with patch(
+        "py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
+    ):
+        assert await scanner.calculate_hash_for_model(normalized_file) is None
+        assert scanner._hash_calculation_tasks == {}
+
+        hash_result = await scanner.calculate_hash_for_model(normalized_file)
+
+    assert hash_result == "b" * 64
+    assert attempts == 2
+    assert scanner._hash_calculation_tasks == {}
+
+
+@pytest.mark.asyncio
+async def test_calculate_hash_for_model_uses_separate_tasks_for_different_files(
+    tmp_path: Path, monkeypatch
+):
+    """Different checkpoint files should not share the same in-flight task."""
+    checkpoints_root = tmp_path / "checkpoints"
+    checkpoints_root.mkdir()
+
+    checkpoint_files = [
+        checkpoints_root / "model_a.safetensors",
+        checkpoints_root / "model_b.safetensors",
+    ]
+    for checkpoint_file in checkpoint_files:
+        checkpoint_file.write_text(
+            f"fake content for {checkpoint_file.name}", encoding="utf-8"
+        )
+
+    normalized_root = _normalize(checkpoints_root)
+    normalized_files = [
+        _normalize(checkpoint_file) for checkpoint_file in checkpoint_files
+    ]
+    real_files = {os.path.realpath(file_path) for file_path in normalized_files}
+
+    monkeypatch.setattr(
+        model_scanner.config,
+        "base_models_roots",
+        [normalized_root],
+        raising=False,
+    )
+
+    scanner = CheckpointScanner()
+    for normalized_file in normalized_files:
+        metadata = await scanner._create_default_metadata(normalized_file)
+        assert metadata is not None
+
+    calls = []
+    hashes_by_path = {
+        os.path.realpath(normalized_files[0]): "c" * 64,
+        os.path.realpath(normalized_files[1]): "d" * 64,
+    }
+
+    async def fake_calculate_sha256(file_path: str) -> str:
+        calls.append(file_path)
+        await asyncio.sleep(0.01)
+        return hashes_by_path[file_path]
+
+    with patch(
+        "py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
+    ):
+        results = await asyncio.gather(
+            *[
+                scanner.calculate_hash_for_model(file_path)
+                for file_path in normalized_files
+            ]
+        )
+
+    assert set(calls) == real_files
+    assert len(calls) == 2
+    assert set(results) == {"c" * 64, "d" * 64}
+    assert scanner._hash_calculation_tasks == {}
+
+
 @pytest.mark.asyncio
 async def test_calculate_hash_for_model_bootstraps_missing_metadata(tmp_path: Path, monkeypatch):
     """Test that calculate_hash_for_model creates pending metadata when it is missing."""
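Note on the test setup: calculate_sha256 is patched at its defining module ("py.utils.file_utils.calculate_sha256"). That works here because the scanner imports the function inside the method body, so the name is resolved from that module on every call rather than being bound once at import time. A tiny self-contained illustration of the same mechanism, using hashlib instead of the repository's helper:

from unittest.mock import patch


def digest_hex(data: bytes) -> str:
    # Deferred import: sha256 is looked up in hashlib on each call,
    # so patching "hashlib.sha256" affects this function too.
    from hashlib import sha256

    return sha256(data).hexdigest()


with patch("hashlib.sha256") as mock_sha256:
    mock_sha256.return_value.hexdigest.return_value = "a" * 64
    assert digest_hex(b"fake content") == "a" * 64

assert digest_hex(b"fake content") != "a" * 64   # real hashing is restored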