mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
fix(checkpoints): singleflight pending hash calculation
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Tests for checkpoint lazy hash calculation feature."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -199,6 +200,160 @@ async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeyp
|
||||
mock_calc.assert_not_called(), "Should not recalculate if already completed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_for_model_singleflight_same_file(
|
||||
tmp_path: Path, monkeypatch
|
||||
):
|
||||
"""Concurrent calls for the same checkpoint should share one SHA256 task."""
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
|
||||
checkpoint_file = checkpoints_root / "test_model.safetensors"
|
||||
checkpoint_file.write_text("fake content", encoding="utf-8")
|
||||
|
||||
normalized_root = _normalize(checkpoints_root)
|
||||
normalized_file = _normalize(checkpoint_file)
|
||||
real_file = os.path.realpath(normalized_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
metadata = await scanner._create_default_metadata(normalized_file)
|
||||
assert metadata is not None
|
||||
|
||||
calls = []
|
||||
|
||||
async def fake_calculate_sha256(file_path: str) -> str:
|
||||
calls.append(file_path)
|
||||
await asyncio.sleep(0.01)
|
||||
return "a" * 64
|
||||
|
||||
with patch(
|
||||
"py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
|
||||
):
|
||||
results = await asyncio.gather(
|
||||
*[scanner.calculate_hash_for_model(normalized_file) for _ in range(8)]
|
||||
)
|
||||
|
||||
assert calls == [real_file]
|
||||
assert results == ["a" * 64] * 8
|
||||
assert scanner._hash_calculation_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_for_model_cleans_task_after_failure_and_retries(
|
||||
tmp_path: Path, monkeypatch
|
||||
):
|
||||
"""A failed in-flight task should be removed so later calls can retry."""
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
|
||||
checkpoint_file = checkpoints_root / "retry_model.safetensors"
|
||||
checkpoint_file.write_text("fake content", encoding="utf-8")
|
||||
|
||||
normalized_root = _normalize(checkpoints_root)
|
||||
normalized_file = _normalize(checkpoint_file)
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
metadata = await scanner._create_default_metadata(normalized_file)
|
||||
assert metadata is not None
|
||||
|
||||
attempts = 0
|
||||
|
||||
async def fake_calculate_sha256(_file_path: str) -> str:
|
||||
nonlocal attempts
|
||||
attempts += 1
|
||||
if attempts == 1:
|
||||
raise RuntimeError("hash failed")
|
||||
return "b" * 64
|
||||
|
||||
with patch(
|
||||
"py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
|
||||
):
|
||||
assert await scanner.calculate_hash_for_model(normalized_file) is None
|
||||
assert scanner._hash_calculation_tasks == {}
|
||||
|
||||
hash_result = await scanner.calculate_hash_for_model(normalized_file)
|
||||
|
||||
assert hash_result == "b" * 64
|
||||
assert attempts == 2
|
||||
assert scanner._hash_calculation_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_for_model_uses_separate_tasks_for_different_files(
|
||||
tmp_path: Path, monkeypatch
|
||||
):
|
||||
"""Different checkpoint files should not share the same in-flight task."""
|
||||
checkpoints_root = tmp_path / "checkpoints"
|
||||
checkpoints_root.mkdir()
|
||||
|
||||
checkpoint_files = [
|
||||
checkpoints_root / "model_a.safetensors",
|
||||
checkpoints_root / "model_b.safetensors",
|
||||
]
|
||||
for checkpoint_file in checkpoint_files:
|
||||
checkpoint_file.write_text(
|
||||
f"fake content for {checkpoint_file.name}", encoding="utf-8"
|
||||
)
|
||||
|
||||
normalized_root = _normalize(checkpoints_root)
|
||||
normalized_files = [
|
||||
_normalize(checkpoint_file) for checkpoint_file in checkpoint_files
|
||||
]
|
||||
real_files = {os.path.realpath(file_path) for file_path in normalized_files}
|
||||
|
||||
monkeypatch.setattr(
|
||||
model_scanner.config,
|
||||
"base_models_roots",
|
||||
[normalized_root],
|
||||
raising=False,
|
||||
)
|
||||
|
||||
scanner = CheckpointScanner()
|
||||
for normalized_file in normalized_files:
|
||||
metadata = await scanner._create_default_metadata(normalized_file)
|
||||
assert metadata is not None
|
||||
|
||||
calls = []
|
||||
hashes_by_path = {
|
||||
os.path.realpath(normalized_files[0]): "c" * 64,
|
||||
os.path.realpath(normalized_files[1]): "d" * 64,
|
||||
}
|
||||
|
||||
async def fake_calculate_sha256(file_path: str) -> str:
|
||||
calls.append(file_path)
|
||||
await asyncio.sleep(0.01)
|
||||
return hashes_by_path[file_path]
|
||||
|
||||
with patch(
|
||||
"py.utils.file_utils.calculate_sha256", side_effect=fake_calculate_sha256
|
||||
):
|
||||
results = await asyncio.gather(
|
||||
*[
|
||||
scanner.calculate_hash_for_model(file_path)
|
||||
for file_path in normalized_files
|
||||
]
|
||||
)
|
||||
|
||||
assert set(calls) == real_files
|
||||
assert len(calls) == 2
|
||||
assert set(results) == {"c" * 64, "d" * 64}
|
||||
assert scanner._hash_calculation_tasks == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calculate_hash_for_model_bootstraps_missing_metadata(tmp_path: Path, monkeypatch):
|
||||
"""Test that calculate_hash_for_model creates pending metadata when it is missing."""
|
||||
|
||||
Reference in New Issue
Block a user