From 25fa175aa283112b399f90c96c962b2b82c5ea8a Mon Sep 17 00:00:00 2001
From: Will Miao
Date: Fri, 10 Apr 2026 22:28:04 +0800
Subject: [PATCH] fix(usage): resolve checkpoint hashes from disk

---
 py/services/checkpoint_scanner.py           |  12 +-
 py/utils/usage_stats.py                     | 115 +++++++++++++++++---
 tests/services/test_checkpoint_lazy_hash.py |  41 +++++++
 tests/utils/test_usage_stats.py             |  78 +++++++++++++
 4 files changed, 227 insertions(+), 19 deletions(-)

diff --git a/py/services/checkpoint_scanner.py b/py/services/checkpoint_scanner.py
index 42dc0580..e3a46998 100644
--- a/py/services/checkpoint_scanner.py
+++ b/py/services/checkpoint_scanner.py
@@ -105,12 +105,18 @@ class CheckpointScanner(ModelScanner):
             return None
 
         # Load current metadata
-        metadata, _ = await MetadataManager.load_metadata(
+        metadata, should_skip = await MetadataManager.load_metadata(
             file_path, self.model_class
         )
         if metadata is None:
-            logger.error(f"No metadata found for {file_path}")
-            return None
+            if should_skip:
+                logger.error(f"Invalid metadata found for {file_path}")
+                return None
+            created_metadata = await self._create_default_metadata(file_path)
+            if created_metadata is None:
+                logger.error(f"No metadata found for {file_path}")
+                return None
+            metadata = created_metadata
 
         # Check if hash is already calculated
         if metadata.hash_status == "completed" and metadata.sha256:
diff --git a/py/utils/usage_stats.py b/py/utils/usage_stats.py
index 9cc027a3..b9fe4ac8 100644
--- a/py/utils/usage_stats.py
+++ b/py/utils/usage_stats.py
@@ -29,6 +29,18 @@ if not standalone_mode:
 
 logger = logging.getLogger(__name__)
 
+_DEFAULT_CHECKPOINT_EXTENSIONS = {
+    ".ckpt",
+    ".pt",
+    ".pt2",
+    ".bin",
+    ".pth",
+    ".safetensors",
+    ".pkl",
+    ".sft",
+    ".gguf",
+}
+
 
 class UsageStats:
     """Track usage statistics for models and save to JSON"""
@@ -335,6 +347,55 @@ class UsageStats:
 
         return None
 
+    async def _find_checkpoint_file_on_disk(self, checkpoint_scanner, model_name: str):
+        """Search checkpoint roots directly for a matching file.
+
+        This is used when usage tracking sees a checkpoint name before the cache has
+        been refreshed. The lookup is intentionally exact: we only match the model
+        basename and supported checkpoint extensions.
+        """
+        get_model_roots = getattr(checkpoint_scanner, "get_model_roots", None)
+        if not callable(get_model_roots):
+            return None
+
+        roots = [root for root in get_model_roots() if root]
+        if not roots:
+            return None
+
+        supported_extensions = getattr(
+            checkpoint_scanner, "file_extensions", _DEFAULT_CHECKPOINT_EXTENSIONS
+        )
+        if not isinstance(supported_extensions, (set, frozenset, list, tuple)):
+            supported_extensions = _DEFAULT_CHECKPOINT_EXTENSIONS
+
+        normalized_name = self._normalize_model_lookup_name(model_name)
+        matches: list[str] = []
+
+        for root_path in roots:
+            if not os.path.exists(root_path):
+                continue
+
+            for dirpath, _dirnames, filenames in os.walk(root_path):
+                for filename in filenames:
+                    extension = os.path.splitext(filename)[1].lower()
+                    if extension not in supported_extensions:
+                        continue
+
+                    if os.path.splitext(filename)[0] != normalized_name:
+                        continue
+
+                    matches.append(os.path.join(dirpath, filename).replace(os.sep, "/"))
+
+        if len(matches) > 1:
+            logger.warning(
+                "Multiple checkpoint files matched '%s'; skipping usage tracking: %s",
+                normalized_name,
+                ", ".join(matches),
+            )
+            return None
+
+        return matches[0] if matches else None
+
     async def _resolve_checkpoint_hash(self, checkpoint_scanner, model_name: str):
         """Resolve a checkpoint hash, calculating pending hashes on demand when needed."""
         model_filename = self._normalize_model_lookup_name(model_name)
@@ -343,27 +404,49 @@ class UsageStats:
             return model_hash
 
         cached_entry = await self._find_cached_checkpoint_entry(checkpoint_scanner, model_name)
-        if not cached_entry:
-            logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
-            return None
+        if cached_entry:
+            cached_hash = cached_entry.get("sha256")
+            if cached_hash:
+                return cached_hash
 
-        cached_hash = cached_entry.get("sha256")
-        if cached_hash:
-            return cached_hash
-
-        if cached_entry.get("hash_status") == "pending":
-            calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None)
-            file_path = cached_entry.get("file_path")
-            if callable(calculate_hash) and file_path:
-                calculated_hash = await calculate_hash(file_path)
-                if calculated_hash:
-                    return calculated_hash
+            hash_status = cached_entry.get("hash_status")
+            if hash_status and hash_status != "pending":
                 logger.warning(
-                    f"Failed to calculate pending hash for checkpoint '{model_filename}', skipping usage tracking"
+                    "Checkpoint '%s' has hash_status=%s; skipping usage tracking",
+                    model_filename,
+                    hash_status,
                 )
                 return None
 
-        logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
+        file_path = cached_entry.get("file_path") if cached_entry else None
+        if not file_path:
+            file_path = await self._find_checkpoint_file_on_disk(
+                checkpoint_scanner, model_name
+            )
+
+        if not file_path:
+            logger.warning(
+                f"No hash found for checkpoint '{model_filename}', skipping usage tracking"
+            )
+            return None
+
+        calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None)
+        if not callable(calculate_hash):
+            logger.warning("Checkpoint scanner not available for usage tracking")
+            return None
+
+        logger.info(
+            "Calculating hash for checkpoint '%s' from %s",
+            model_filename,
+            file_path,
+        )
+        calculated_hash = await calculate_hash(file_path)
+        if calculated_hash:
+            return calculated_hash
+
+        logger.warning(
+            f"Failed to calculate hash for checkpoint '{model_filename}', skipping usage tracking"
+        )
         return None
 
     async def _process_checkpoints(self, models_data, today_date):
diff --git a/tests/services/test_checkpoint_lazy_hash.py b/tests/services/test_checkpoint_lazy_hash.py
index 67a8462f..78ef2d1d 100644
--- a/tests/services/test_checkpoint_lazy_hash.py
+++ b/tests/services/test_checkpoint_lazy_hash.py
@@ -199,6 +199,47 @@ async def test_calculate_hash_skips_if_already_completed(tmp_path: Path, monkeyp
     mock_calc.assert_not_called(), "Should not recalculate if already completed"
 
 
+@pytest.mark.asyncio
+async def test_calculate_hash_for_model_bootstraps_missing_metadata(tmp_path: Path, monkeypatch):
+    """Test that calculate_hash_for_model creates pending metadata when it is missing."""
+    checkpoints_root = tmp_path / "checkpoints"
+    checkpoints_root.mkdir()
+
+    checkpoint_file = checkpoints_root / "bootstrap_model.gguf"
+    checkpoint_file.write_text("fake content for hashing", encoding="utf-8")
+
+    normalized_root = _normalize(checkpoints_root)
+    normalized_file = _normalize(checkpoint_file)
+
+    monkeypatch.setattr(
+        model_scanner.config,
+        "base_models_roots",
+        [normalized_root],
+        raising=False,
+    )
+    monkeypatch.setattr(
+        model_scanner.config,
+        "checkpoints_roots",
+        [normalized_root],
+        raising=False,
+    )
+
+    scanner = CheckpointScanner()
+
+    hash_result = await scanner.calculate_hash_for_model(normalized_file)
+
+    assert hash_result is not None, "Hash calculation should succeed without existing metadata"
+    assert len(hash_result) == 64, "SHA256 should be 64 hex characters"
+    assert scanner.get_hash_by_filename("bootstrap_model") == hash_result
+
+    metadata_file = checkpoints_root / "bootstrap_model.metadata.json"
+    with open(metadata_file, "r", encoding="utf-8") as f:
+        saved_data = json.load(f)
+
+    assert saved_data.get("sha256") == hash_result, "sha256 should be updated"
+    assert saved_data.get("hash_status") == "completed", "hash_status should be 'completed'"
+
+
 @pytest.mark.asyncio
 async def test_calculate_all_pending_hashes(tmp_path: Path, monkeypatch):
     """Test bulk hash calculation for all pending checkpoints."""
diff --git a/tests/utils/test_usage_stats.py b/tests/utils/test_usage_stats.py
index d76cc598..5794a265 100644
--- a/tests/utils/test_usage_stats.py
+++ b/tests/utils/test_usage_stats.py
@@ -194,6 +194,84 @@ async def test_usage_stats_calculates_pending_checkpoint_hash_on_demand(tmp_path
     await _finalize_usage_stats(tasks)
 
 
+async def test_usage_stats_skips_failed_checkpoint_hash_retry(tmp_path, monkeypatch):
+    stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)
+
+    metadata_payload = {
+        "models": {
+            "1": {"type": "checkpoint", "name": "failed_model.safetensors"},
+        },
+        "loras": {},
+    }
+
+    checkpoint_cache = SimpleNamespace(
+        raw_data=[
+            {
+                "file_name": "failed_model",
+                "model_name": "failed_model",
+                "file_path": "/models/failed_model.safetensors",
+                "sha256": "",
+                "hash_status": "failed",
+            }
+        ]
+    )
+    checkpoint_scanner = SimpleNamespace(
+        get_hash_by_filename=lambda name: None,
+        get_cached_data=AsyncMock(return_value=checkpoint_cache),
+        calculate_hash_for_model=AsyncMock(return_value=None),
+    )
+    lora_scanner = SimpleNamespace(get_hash_by_filename=lambda name: None)
+
+    monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner))
+    monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner))
+
+    await stats._process_metadata(metadata_payload)
+
+    checkpoint_scanner.calculate_hash_for_model.assert_not_awaited()
+    assert stats.stats["checkpoints"] == {}
+
+    await _finalize_usage_stats(tasks)
+
+
+async def test_usage_stats_resolves_manually_copied_checkpoint_from_disk(tmp_path, monkeypatch):
+    stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)
+
+    checkpoints_root = tmp_path / "checkpoints"
+    checkpoints_root.mkdir()
+    checkpoint_file = checkpoints_root / "Chroma1-HD-Q8_0.gguf"
+    checkpoint_file.write_text("manual checkpoint content", encoding="utf-8")
+
+    metadata_payload = {
+        "models": {
+            "1": {"type": "checkpoint", "name": "Chroma1-HD-Q8_0"},
+        },
+        "loras": {},
+    }
+
+    checkpoint_cache = SimpleNamespace(raw_data=[])
+    checkpoint_scanner = SimpleNamespace(
+        get_hash_by_filename=lambda name: None,
+        get_cached_data=AsyncMock(return_value=checkpoint_cache),
+        get_model_roots=lambda: [str(checkpoints_root)],
+        file_extensions={".ckpt", ".pt", ".pt2", ".bin", ".pth", ".safetensors", ".pkl", ".sft", ".gguf"},
+        calculate_hash_for_model=AsyncMock(return_value="resolved-hash"),
+    )
+
+    monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner))
+    monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=None))
+
+    await stats._process_metadata(metadata_payload)
+
+    checkpoint_scanner.calculate_hash_for_model.assert_awaited_once_with(
+        str(checkpoint_file).replace(os.sep, "/")
+    )
+
+    today = datetime.now().strftime("%Y-%m-%d")
+    assert stats.stats["checkpoints"]["resolved-hash"]["history"][today] == 1
+
+    await _finalize_usage_stats(tasks)
+
+
 async def test_usage_stats_skips_name_fallback_for_missing_lora_hash(tmp_path, monkeypatch):
     stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)