fix(usage): resolve checkpoint hashes from disk

This commit is contained in:
Will Miao
2026-04-10 22:28:04 +08:00
parent 39643eb2bc
commit 25fa175aa2
4 changed files with 227 additions and 19 deletions

View File

@@ -199,6 +199,47 @@ async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeyp
# Bug fix: the original `mock_calc.assert_not_called(), "message"` built a
# throwaway tuple — the string was never attached to the AssertionError that
# assert_not_called() raises, so the explanation was dead code. Keep the
# rationale as a comment instead.
mock_calc.assert_not_called()  # should not recalculate if already completed
@pytest.mark.asyncio
async def test_calculate_hash_for_model_bootstraps_missing_metadata(tmp_path: Path, monkeypatch):
    """calculate_hash_for_model should bootstrap metadata when none exists yet."""
    # Lay out a checkpoint file with no accompanying *.metadata.json.
    root_dir = tmp_path / "checkpoints"
    root_dir.mkdir()
    model_path = root_dir / "bootstrap_model.gguf"
    model_path.write_text("fake content for hashing", encoding="utf-8")
    root_posix = _normalize(root_dir)
    file_posix = _normalize(model_path)
    # Point both configured root lists at the temp directory.
    for attr_name in ("base_models_roots", "checkpoints_roots"):
        monkeypatch.setattr(model_scanner.config, attr_name, [root_posix], raising=False)

    scanner = CheckpointScanner()
    digest = await scanner.calculate_hash_for_model(file_posix)

    assert digest is not None, "Hash calculation should succeed without existing metadata"
    assert len(digest) == 64, "SHA256 should be 64 hex characters"
    assert scanner.get_hash_by_filename("bootstrap_model") == digest

    # The scanner should have persisted a completed metadata file beside the model.
    metadata_path = root_dir / "bootstrap_model.metadata.json"
    with open(metadata_path, "r", encoding="utf-8") as handle:
        persisted = json.load(handle)
    assert persisted.get("sha256") == digest, "sha256 should be updated"
    assert persisted.get("hash_status") == "completed", "hash_status should be 'completed'"
@pytest.mark.asyncio
async def test_calculate_all_pending_hashes(tmp_path: Path, monkeypatch):
"""Test bulk hash calculation for all pending checkpoints."""

View File

@@ -194,6 +194,84 @@ async def test_usage_stats_calculates_pending_checkpoint_hash_on_demand(tmp_path
await _finalize_usage_stats(tasks)
async def test_usage_stats_skips_failed_checkpoint_hash_retry(tmp_path, monkeypatch):
    """A checkpoint whose hashing already failed must not be retried or recorded."""
    stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)
    payload = {
        "models": {"1": {"type": "checkpoint", "name": "failed_model.safetensors"}},
        "loras": {},
    }
    # Cached entry reports a prior failure: empty sha256 plus a "failed" status.
    failed_entry = {
        "file_name": "failed_model",
        "model_name": "failed_model",
        "file_path": "/models/failed_model.safetensors",
        "sha256": "",
        "hash_status": "failed",
    }
    cache_stub = SimpleNamespace(raw_data=[failed_entry])
    scanner_stub = SimpleNamespace(
        get_hash_by_filename=lambda name: None,
        get_cached_data=AsyncMock(return_value=cache_stub),
        calculate_hash_for_model=AsyncMock(return_value=None),
    )
    lora_stub = SimpleNamespace(get_hash_by_filename=lambda name: None)
    monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=scanner_stub))
    monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_stub))

    await stats._process_metadata(payload)

    # Neither a hash retry nor a stats entry should result from the failed model.
    scanner_stub.calculate_hash_for_model.assert_not_awaited()
    assert stats.stats["checkpoints"] == {}
    await _finalize_usage_stats(tasks)
async def test_usage_stats_resolves_manually_copied_checkpoint_from_disk(tmp_path, monkeypatch):
    """A checkpoint missing from the cache is located on disk and hashed on demand."""
    stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)
    root_dir = tmp_path / "checkpoints"
    root_dir.mkdir()
    manual_copy = root_dir / "Chroma1-HD-Q8_0.gguf"
    manual_copy.write_text("manual checkpoint content", encoding="utf-8")
    payload = {
        "models": {"1": {"type": "checkpoint", "name": "Chroma1-HD-Q8_0"}},
        "loras": {},
    }
    # An empty cache forces the disk-resolution code path.
    scanner_stub = SimpleNamespace(
        get_hash_by_filename=lambda name: None,
        get_cached_data=AsyncMock(return_value=SimpleNamespace(raw_data=[])),
        get_model_roots=lambda: [str(root_dir)],
        file_extensions={".ckpt", ".pt", ".pt2", ".bin", ".pth", ".safetensors", ".pkl", ".sft", ".gguf"},
        calculate_hash_for_model=AsyncMock(return_value="resolved-hash"),
    )
    monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=scanner_stub))
    monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=None))

    await stats._process_metadata(payload)

    # The scanner is asked to hash the file it found, with a POSIX-normalized path.
    expected_path = str(manual_copy).replace(os.sep, "/")
    scanner_stub.calculate_hash_for_model.assert_awaited_once_with(expected_path)
    # Usage is recorded under the freshly resolved hash for today's date.
    today = datetime.now().strftime("%Y-%m-%d")
    assert stats.stats["checkpoints"]["resolved-hash"]["history"][today] == 1
    await _finalize_usage_stats(tasks)
async def test_usage_stats_skips_name_fallback_for_missing_lora_hash(tmp_path, monkeypatch):
stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)