fix(usage): resolve checkpoint hashes from disk

This commit is contained in:
Will Miao
2026-04-10 22:28:04 +08:00
parent 39643eb2bc
commit 25fa175aa2
4 changed files with 227 additions and 19 deletions

View File

@@ -105,12 +105,18 @@ class CheckpointScanner(ModelScanner):
return None
# Load current metadata
metadata, _ = await MetadataManager.load_metadata(
metadata, should_skip = await MetadataManager.load_metadata(
file_path, self.model_class
)
if metadata is None:
logger.error(f"No metadata found for {file_path}")
return None
if should_skip:
logger.error(f"Invalid metadata found for {file_path}")
return None
created_metadata = await self._create_default_metadata(file_path)
if created_metadata is None:
logger.error(f"No metadata found for {file_path}")
return None
metadata = created_metadata
# Check if hash is already calculated
if metadata.hash_status == "completed" and metadata.sha256:

View File

@@ -29,6 +29,18 @@ if not standalone_mode:
logger = logging.getLogger(__name__)
# Fallback set of recognised checkpoint file extensions (lowercase, with dot).
# Used by the on-disk lookup when the checkpoint scanner does not expose a
# usable `file_extensions` collection of its own.
_DEFAULT_CHECKPOINT_EXTENSIONS = {
".ckpt",
".pt",
".pt2",
".bin",
".pth",
".safetensors",
".pkl",
".sft",
".gguf",
}
class UsageStats:
"""Track usage statistics for models and save to JSON"""
@@ -335,6 +347,55 @@ class UsageStats:
return None
async def _find_checkpoint_file_on_disk(self, checkpoint_scanner, model_name: str):
    """Search checkpoint roots directly for a matching file.

    This is used when usage tracking sees a checkpoint name before the cache has
    been refreshed. The lookup is intentionally exact: we only match the model
    basename and supported checkpoint extensions.

    Returns the forward-slash-normalized path of the single match, or ``None``
    when there is no match, no usable scanner API, or the match is ambiguous.
    """
    # The scanner may not expose root discovery at all; bail out quietly.
    roots_getter = getattr(checkpoint_scanner, "get_model_roots", None)
    if not callable(roots_getter):
        return None
    search_roots = [r for r in roots_getter() if r]
    if not search_roots:
        return None

    # Prefer the scanner's own extension collection, falling back to the
    # module default when it is absent or not a plain container type.
    valid_exts = getattr(
        checkpoint_scanner, "file_extensions", _DEFAULT_CHECKPOINT_EXTENSIONS
    )
    if not isinstance(valid_exts, (set, frozenset, list, tuple)):
        valid_exts = _DEFAULT_CHECKPOINT_EXTENSIONS

    target_stem = self._normalize_model_lookup_name(model_name)
    found: list[str] = []
    # NOTE(review): os.walk here is blocking inside an async method — fine for
    # small roots, but consider offloading to a thread if roots are large.
    for root in search_roots:
        if not os.path.exists(root):
            continue
        for current_dir, _subdirs, names in os.walk(root):
            for name in names:
                stem, ext = os.path.splitext(name)
                # Extension match is case-insensitive; the basename match is
                # exact against the normalized lookup name.
                if ext.lower() not in valid_exts:
                    continue
                if stem != target_stem:
                    continue
                found.append(os.path.join(current_dir, name).replace(os.sep, "/"))

    # An ambiguous result is treated as "not found" so usage is never
    # attributed to the wrong file.
    if len(found) > 1:
        logger.warning(
            "Multiple checkpoint files matched '%s'; skipping usage tracking: %s",
            target_stem,
            ", ".join(found),
        )
        return None
    return found[0] if found else None
async def _resolve_checkpoint_hash(self, checkpoint_scanner, model_name: str):
"""Resolve a checkpoint hash, calculating pending hashes on demand when needed."""
model_filename = self._normalize_model_lookup_name(model_name)
@@ -343,27 +404,49 @@ class UsageStats:
return model_hash
cached_entry = await self._find_cached_checkpoint_entry(checkpoint_scanner, model_name)
if not cached_entry:
logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
return None
if cached_entry:
cached_hash = cached_entry.get("sha256")
if cached_hash:
return cached_hash
cached_hash = cached_entry.get("sha256")
if cached_hash:
return cached_hash
if cached_entry.get("hash_status") == "pending":
calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None)
file_path = cached_entry.get("file_path")
if callable(calculate_hash) and file_path:
calculated_hash = await calculate_hash(file_path)
if calculated_hash:
return calculated_hash
hash_status = cached_entry.get("hash_status")
if hash_status and hash_status != "pending":
logger.warning(
f"Failed to calculate pending hash for checkpoint '{model_filename}', skipping usage tracking"
"Checkpoint '%s' has hash_status=%s; skipping usage tracking",
model_filename,
hash_status,
)
return None
logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
file_path = cached_entry.get("file_path") if cached_entry else None
if not file_path:
file_path = await self._find_checkpoint_file_on_disk(
checkpoint_scanner, model_name
)
if not file_path:
logger.warning(
f"No hash found for checkpoint '{model_filename}', skipping usage tracking"
)
return None
calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None)
if not callable(calculate_hash):
logger.warning("Checkpoint scanner not available for usage tracking")
return None
logger.info(
"Calculating hash for checkpoint '%s' from %s",
model_filename,
file_path,
)
calculated_hash = await calculate_hash(file_path)
if calculated_hash:
return calculated_hash
logger.warning(
f"Failed to calculate hash for checkpoint '{model_filename}', skipping usage tracking"
)
return None
async def _process_checkpoints(self, models_data, today_date):