fix(usage): resolve checkpoint hashes from disk
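
Usage tracking previously skipped any checkpoint whose hash was not already
in the scanner cache. Resolve such names by searching the configured
checkpoint roots on disk for an exact basename match against the supported
extensions, then calculating the hash on demand. To support this,
CheckpointScanner.calculate_hash_for_model now bootstraps default pending
metadata when none exists, so manually copied files can be hashed before the
cache refreshes. Cached entries with a non-pending hash_status (e.g. "failed")
are skipped rather than retried.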

Will Miao
2026-04-10 22:28:04 +08:00
parent 39643eb2bc
commit 25fa175aa2
4 changed files with 227 additions and 19 deletions

View File

@@ -105,12 +105,18 @@ class CheckpointScanner(ModelScanner):
             return None
 
         # Load current metadata
-        metadata, _ = await MetadataManager.load_metadata(
+        metadata, should_skip = await MetadataManager.load_metadata(
             file_path, self.model_class
         )
         if metadata is None:
-            logger.error(f"No metadata found for {file_path}")
-            return None
+            if should_skip:
+                logger.error(f"Invalid metadata found for {file_path}")
+                return None
+            created_metadata = await self._create_default_metadata(file_path)
+            if created_metadata is None:
+                logger.error(f"No metadata found for {file_path}")
+                return None
+            metadata = created_metadata
 
         # Check if hash is already calculated
         if metadata.hash_status == "completed" and metadata.sha256:
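The new branch calls _create_default_metadata, whose body is not part of this
diff, and relies on MetadataManager.load_metadata returning a
(metadata, should_skip) pair. A minimal sketch of the assumed contract follows;
the field names mirror what the hunk uses (sha256, hash_status), but the
implementations are illustrative guesses, not the project's code.

    # Sketch only: assumed shape of the metadata object and the bootstrap
    # helper used by the hunk above. Not the project's implementation.
    import os
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class SketchMetadata:
        file_path: str
        sha256: str = ""
        hash_status: str = "pending"  # flips to "completed" once hashed

    async def _create_default_metadata(file_path: str) -> Optional[SketchMetadata]:
        # Bootstrap pending metadata for a file that has none on disk,
        # so hash calculation can proceed; None if the file is gone.
        if not os.path.exists(file_path):
            return None
        return SketchMetadata(file_path=file_path)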

View File

@@ -29,6 +29,18 @@ if not standalone_mode:
 
 logger = logging.getLogger(__name__)
 
+_DEFAULT_CHECKPOINT_EXTENSIONS = {
+    ".ckpt",
+    ".pt",
+    ".pt2",
+    ".bin",
+    ".pth",
+    ".safetensors",
+    ".pkl",
+    ".sft",
+    ".gguf",
+}
+
 
 class UsageStats:
     """Track usage statistics for models and save to JSON"""
@@ -335,6 +347,55 @@ class UsageStats:
         return None
 
+    async def _find_checkpoint_file_on_disk(self, checkpoint_scanner, model_name: str):
+        """Search checkpoint roots directly for a matching file.
+
+        This is used when usage tracking sees a checkpoint name before the cache has
+        been refreshed. The lookup is intentionally exact: we only match the model
+        basename and supported checkpoint extensions.
+        """
+        get_model_roots = getattr(checkpoint_scanner, "get_model_roots", None)
+        if not callable(get_model_roots):
+            return None
+
+        roots = [root for root in get_model_roots() if root]
+        if not roots:
+            return None
+
+        supported_extensions = getattr(
+            checkpoint_scanner, "file_extensions", _DEFAULT_CHECKPOINT_EXTENSIONS
+        )
+        if not isinstance(supported_extensions, (set, frozenset, list, tuple)):
+            supported_extensions = _DEFAULT_CHECKPOINT_EXTENSIONS
+
+        normalized_name = self._normalize_model_lookup_name(model_name)
+        matches: list[str] = []
+        for root_path in roots:
+            if not os.path.exists(root_path):
+                continue
+            for dirpath, _dirnames, filenames in os.walk(root_path):
+                for filename in filenames:
+                    extension = os.path.splitext(filename)[1].lower()
+                    if extension not in supported_extensions:
+                        continue
+                    if os.path.splitext(filename)[0] != normalized_name:
+                        continue
+                    matches.append(os.path.join(dirpath, filename).replace(os.sep, "/"))
+
+        if len(matches) > 1:
+            logger.warning(
+                "Multiple checkpoint files matched '%s'; skipping usage tracking: %s",
+                normalized_name,
+                ", ".join(matches),
+            )
+            return None
+
+        return matches[0] if matches else None
+
     async def _resolve_checkpoint_hash(self, checkpoint_scanner, model_name: str):
         """Resolve a checkpoint hash, calculating pending hashes on demand when needed."""
         model_filename = self._normalize_model_lookup_name(model_name)
@@ -343,27 +404,49 @@ class UsageStats:
             return model_hash
 
         cached_entry = await self._find_cached_checkpoint_entry(checkpoint_scanner, model_name)
-        if not cached_entry:
-            logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
-            return None
-
-        cached_hash = cached_entry.get("sha256")
-        if cached_hash:
-            return cached_hash
-
-        if cached_entry.get("hash_status") == "pending":
-            calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None)
-            file_path = cached_entry.get("file_path")
-            if callable(calculate_hash) and file_path:
-                calculated_hash = await calculate_hash(file_path)
-                if calculated_hash:
-                    return calculated_hash
-            logger.warning(
-                f"Failed to calculate pending hash for checkpoint '{model_filename}', skipping usage tracking"
-            )
-            return None
-
-        logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
+        if cached_entry:
+            cached_hash = cached_entry.get("sha256")
+            if cached_hash:
+                return cached_hash
+
+            hash_status = cached_entry.get("hash_status")
+            if hash_status and hash_status != "pending":
+                logger.warning(
+                    "Checkpoint '%s' has hash_status=%s; skipping usage tracking",
+                    model_filename,
+                    hash_status,
+                )
+                return None
+
+        file_path = cached_entry.get("file_path") if cached_entry else None
+        if not file_path:
+            file_path = await self._find_checkpoint_file_on_disk(
+                checkpoint_scanner, model_name
+            )
+        if not file_path:
+            logger.warning(
+                f"No hash found for checkpoint '{model_filename}', skipping usage tracking"
+            )
+            return None
+
+        calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None)
+        if not callable(calculate_hash):
+            logger.warning("Checkpoint scanner not available for usage tracking")
+            return None
+
+        logger.info(
+            "Calculating hash for checkpoint '%s' from %s",
+            model_filename,
+            file_path,
+        )
+        calculated_hash = await calculate_hash(file_path)
+        if calculated_hash:
+            return calculated_hash
+
+        logger.warning(
+            f"Failed to calculate hash for checkpoint '{model_filename}', skipping usage tracking"
+        )
         return None
 
     async def _process_checkpoints(self, models_data, today_date):
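To make the lookup rule concrete, here is a standalone snippet reproducing the
exact-basename match above; the extension set and the os.sep normalization come
from the diff, while the temporary-directory harness around them is
illustrative only.

    import os
    import tempfile

    EXTENSIONS = {".ckpt", ".pt", ".pt2", ".bin", ".pth",
                  ".safetensors", ".pkl", ".sft", ".gguf"}

    def find_on_disk(root, name):
        # Exact basename match against supported extensions, mirroring
        # _find_checkpoint_file_on_disk; ambiguity (or no hit) yields None.
        matches = []
        for dirpath, _dirnames, filenames in os.walk(root):
            for filename in filenames:
                stem, ext = os.path.splitext(filename)
                if ext.lower() in EXTENSIONS and stem == name:
                    matches.append(os.path.join(dirpath, filename).replace(os.sep, "/"))
        return matches[0] if len(matches) == 1 else None

    with tempfile.TemporaryDirectory() as root:
        open(os.path.join(root, "model_a.gguf"), "w").close()
        open(os.path.join(root, "model_a.txt"), "w").close()   # wrong extension: ignored
        os.mkdir(os.path.join(root, "sub"))
        open(os.path.join(root, "sub", "model_a.safetensors"), "w").close()
        print(find_on_disk(root, "model_a"))   # None: two matches, ambiguous
        os.remove(os.path.join(root, "sub", "model_a.safetensors"))
        print(find_on_disk(root, "model_a"))   # .../model_a.gguf: unique match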

View File

@@ -199,6 +199,47 @@ async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeyp
     mock_calc.assert_not_called(), "Should not recalculate if already completed"
 
 
+@pytest.mark.asyncio
+async def test_calculate_hash_for_model_bootstraps_missing_metadata(tmp_path: Path, monkeypatch):
+    """Test that calculate_hash_for_model creates pending metadata when it is missing."""
+    checkpoints_root = tmp_path / "checkpoints"
+    checkpoints_root.mkdir()
+    checkpoint_file = checkpoints_root / "bootstrap_model.gguf"
+    checkpoint_file.write_text("fake content for hashing", encoding="utf-8")
+
+    normalized_root = _normalize(checkpoints_root)
+    normalized_file = _normalize(checkpoint_file)
+    monkeypatch.setattr(
+        model_scanner.config,
+        "base_models_roots",
+        [normalized_root],
+        raising=False,
+    )
+    monkeypatch.setattr(
+        model_scanner.config,
+        "checkpoints_roots",
+        [normalized_root],
+        raising=False,
+    )
+
+    scanner = CheckpointScanner()
+
+    hash_result = await scanner.calculate_hash_for_model(normalized_file)
+
+    assert hash_result is not None, "Hash calculation should succeed without existing metadata"
+    assert len(hash_result) == 64, "SHA256 should be 64 hex characters"
+    assert scanner.get_hash_by_filename("bootstrap_model") == hash_result
+
+    metadata_file = checkpoints_root / "bootstrap_model.metadata.json"
+    with open(metadata_file, "r", encoding="utf-8") as f:
+        saved_data = json.load(f)
+    assert saved_data.get("sha256") == hash_result, "sha256 should be updated"
+    assert saved_data.get("hash_status") == "completed", "hash_status should be 'completed'"
+
+
 @pytest.mark.asyncio
 async def test_calculate_all_pending_hashes(tmp_path: Path, monkeypatch):
     """Test bulk hash calculation for all pending checkpoints."""

View File

@@ -194,6 +194,84 @@ async def test_usage_stats_calculates_pending_checkpoint_hash_on_demand(tmp_path
     await _finalize_usage_stats(tasks)
 
 
+async def test_usage_stats_skips_failed_checkpoint_hash_retry(tmp_path, monkeypatch):
+    stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)
+
+    metadata_payload = {
+        "models": {
+            "1": {"type": "checkpoint", "name": "failed_model.safetensors"},
+        },
+        "loras": {},
+    }
+
+    checkpoint_cache = SimpleNamespace(
+        raw_data=[
+            {
+                "file_name": "failed_model",
+                "model_name": "failed_model",
+                "file_path": "/models/failed_model.safetensors",
+                "sha256": "",
+                "hash_status": "failed",
+            }
+        ]
+    )
+    checkpoint_scanner = SimpleNamespace(
+        get_hash_by_filename=lambda name: None,
+        get_cached_data=AsyncMock(return_value=checkpoint_cache),
+        calculate_hash_for_model=AsyncMock(return_value=None),
+    )
+    lora_scanner = SimpleNamespace(get_hash_by_filename=lambda name: None)
+    monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner))
+    monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner))
+
+    await stats._process_metadata(metadata_payload)
+
+    checkpoint_scanner.calculate_hash_for_model.assert_not_awaited()
+    assert stats.stats["checkpoints"] == {}
+
+    await _finalize_usage_stats(tasks)
+
+
+async def test_usage_stats_resolves_manually_copied_checkpoint_from_disk(tmp_path, monkeypatch):
+    stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)
+
+    checkpoints_root = tmp_path / "checkpoints"
+    checkpoints_root.mkdir()
+    checkpoint_file = checkpoints_root / "Chroma1-HD-Q8_0.gguf"
+    checkpoint_file.write_text("manual checkpoint content", encoding="utf-8")
+
+    metadata_payload = {
+        "models": {
+            "1": {"type": "checkpoint", "name": "Chroma1-HD-Q8_0"},
+        },
+        "loras": {},
+    }
+
+    checkpoint_cache = SimpleNamespace(raw_data=[])
+    checkpoint_scanner = SimpleNamespace(
+        get_hash_by_filename=lambda name: None,
+        get_cached_data=AsyncMock(return_value=checkpoint_cache),
+        get_model_roots=lambda: [str(checkpoints_root)],
+        file_extensions={".ckpt", ".pt", ".pt2", ".bin", ".pth", ".safetensors", ".pkl", ".sft", ".gguf"},
+        calculate_hash_for_model=AsyncMock(return_value="resolved-hash"),
+    )
+    monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner))
+    monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=None))
+
+    await stats._process_metadata(metadata_payload)
+
+    checkpoint_scanner.calculate_hash_for_model.assert_awaited_once_with(
+        str(checkpoint_file).replace(os.sep, "/")
+    )
+    today = datetime.now().strftime("%Y-%m-%d")
+    assert stats.stats["checkpoints"]["resolved-hash"]["history"][today] == 1
+
+    await _finalize_usage_stats(tasks)
+
+
 async def test_usage_stats_skips_name_fallback_for_missing_lora_hash(tmp_path, monkeypatch):
     stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)