diff --git a/py/utils/usage_stats.py b/py/utils/usage_stats.py index c4696aed..9cc027a3 100644 --- a/py/utils/usage_stats.py +++ b/py/utils/usage_stats.py @@ -291,6 +291,80 @@ class UsageStats: # Process loras if LORAS in metadata and isinstance(metadata[LORAS], dict): await self._process_loras(metadata[LORAS], today) + + def _increment_usage_counter(self, category: str, stat_key: str, today_date: str) -> None: + """Increment usage counters for a resolved stats key.""" + if stat_key not in self.stats[category]: + self.stats[category][stat_key] = { + "total": 0, + "history": {} + } + + self.stats[category][stat_key]["total"] += 1 + + if today_date not in self.stats[category][stat_key]["history"]: + self.stats[category][stat_key]["history"][today_date] = 0 + self.stats[category][stat_key]["history"][today_date] += 1 + + def _normalize_model_lookup_name(self, model_name: str) -> str: + """Normalize a model reference to its base filename without extension.""" + return os.path.splitext(os.path.basename(model_name))[0] + + async def _find_cached_checkpoint_entry(self, checkpoint_scanner, model_name: str): + """Best-effort lookup for a checkpoint cache entry by filename/model name.""" + get_cached_data = getattr(checkpoint_scanner, "get_cached_data", None) + if not callable(get_cached_data): + return None + + cache = await get_cached_data() + raw_data = getattr(cache, "raw_data", None) + if not isinstance(raw_data, list): + return None + + normalized_name = self._normalize_model_lookup_name(model_name) + for entry in raw_data: + if not isinstance(entry, dict): + continue + + for candidate_key in ("file_name", "model_name", "file_path"): + candidate_value = entry.get(candidate_key) + if not candidate_value or not isinstance(candidate_value, str): + continue + if self._normalize_model_lookup_name(candidate_value) == normalized_name: + return entry + + return None + + async def _resolve_checkpoint_hash(self, checkpoint_scanner, model_name: str): + """Resolve a checkpoint hash, calculating pending hashes on demand when needed.""" + model_filename = self._normalize_model_lookup_name(model_name) + model_hash = checkpoint_scanner.get_hash_by_filename(model_filename) + if model_hash: + return model_hash + + cached_entry = await self._find_cached_checkpoint_entry(checkpoint_scanner, model_name) + if not cached_entry: + logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking") + return None + + cached_hash = cached_entry.get("sha256") + if cached_hash: + return cached_hash + + if cached_entry.get("hash_status") == "pending": + calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None) + file_path = cached_entry.get("file_path") + if callable(calculate_hash) and file_path: + calculated_hash = await calculate_hash(file_path) + if calculated_hash: + return calculated_hash + logger.warning( + f"Failed to calculate pending hash for checkpoint '{model_filename}', skipping usage tracking" + ) + return None + + logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking") + return None async def _process_checkpoints(self, models_data, today_date): """Process checkpoint models from metadata""" @@ -311,29 +385,12 @@ class UsageStats: model_name = model_info.get("name") if not model_name: continue - - # Clean up filename (remove extension if present) - model_filename = os.path.splitext(os.path.basename(model_name))[0] - - # Get hash for this checkpoint - model_hash = checkpoint_scanner.get_hash_by_filename(model_filename) + + model_hash = await self._resolve_checkpoint_hash(checkpoint_scanner, model_name) if not model_hash: - logger.warning(f"No hash found for checkpoint '{model_filename}', tracking by name") - stat_key = model_hash or f"name:{model_filename}" - # Update stats for this checkpoint with date tracking - if stat_key not in self.stats["checkpoints"]: - self.stats["checkpoints"][stat_key] = { - "total": 0, - "history": {} - } + continue - # Increment total count - self.stats["checkpoints"][stat_key]["total"] += 1 - - # Increment today's count - if today_date not in self.stats["checkpoints"][stat_key]["history"]: - self.stats["checkpoints"][stat_key]["history"][today_date] = 0 - self.stats["checkpoints"][stat_key]["history"][today_date] += 1 + self._increment_usage_counter("checkpoints", model_hash, today_date) except Exception as e: logger.error(f"Error processing checkpoint usage: {e}", exc_info=True) @@ -363,22 +420,10 @@ class UsageStats: # Get hash for this LoRA lora_hash = lora_scanner.get_hash_by_filename(lora_name) if not lora_hash: - logger.warning(f"No hash found for LoRA '{lora_name}', tracking by name") - stat_key = lora_hash or f"name:{lora_name}" - # Update stats for this LoRA with date tracking - if stat_key not in self.stats["loras"]: - self.stats["loras"][stat_key] = { - "total": 0, - "history": {} - } + logger.warning(f"No hash found for LoRA '{lora_name}', skipping usage tracking") + continue - # Increment total count - self.stats["loras"][stat_key]["total"] += 1 - - # Increment today's count - if today_date not in self.stats["loras"][stat_key]["history"]: - self.stats["loras"][stat_key]["history"][today_date] = 0 - self.stats["loras"][stat_key]["history"][today_date] += 1 + self._increment_usage_counter("loras", lora_hash, today_date) except Exception as e: logger.error(f"Error processing LoRA usage: {e}", exc_info=True) diff --git a/tests/utils/test_usage_stats.py b/tests/utils/test_usage_stats.py index 6102f690..d76cc598 100644 --- a/tests/utils/test_usage_stats.py +++ b/tests/utils/test_usage_stats.py @@ -152,3 +152,67 @@ async def test_usage_stats_background_processor_handles_pending_prompts(tmp_path assert stats.stats["loras"]["lora-hash"]["history"][today] == 1 await _finalize_usage_stats(tasks) + + +async def test_usage_stats_calculates_pending_checkpoint_hash_on_demand(tmp_path, monkeypatch): + stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch) + + metadata_payload = { + "models": { + "1": {"type": "checkpoint", "name": "pending_model.safetensors"}, + }, + "loras": {}, + } + + checkpoint_cache = SimpleNamespace( + raw_data=[ + { + "file_name": "pending_model", + "model_name": "pending_model", + "file_path": "/models/pending_model.safetensors", + "sha256": "", + "hash_status": "pending", + } + ] + ) + checkpoint_scanner = SimpleNamespace( + get_hash_by_filename=lambda name: None, + get_cached_data=AsyncMock(return_value=checkpoint_cache), + calculate_hash_for_model=AsyncMock(return_value="resolved-hash"), + ) + lora_scanner = SimpleNamespace(get_hash_by_filename=lambda name: None) + + monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner)) + monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner)) + + await stats._process_metadata(metadata_payload) + + today = datetime.now().strftime("%Y-%m-%d") + checkpoint_scanner.calculate_hash_for_model.assert_awaited_once_with("/models/pending_model.safetensors") + assert stats.stats["checkpoints"]["resolved-hash"]["history"][today] == 1 + + await _finalize_usage_stats(tasks) + + +async def test_usage_stats_skips_name_fallback_for_missing_lora_hash(tmp_path, monkeypatch): + stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch) + + metadata_payload = { + "models": {}, + "loras": { + "2": {"lora_list": [{"name": "missing_lora"}]}, + }, + } + + checkpoint_scanner = SimpleNamespace(get_hash_by_filename=lambda name: None) + lora_scanner = SimpleNamespace(get_hash_by_filename=lambda name: None) + + monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner)) + monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner)) + + await stats._process_metadata(metadata_payload) + + assert stats.stats["loras"] == {} + assert not any(key.startswith("name:") for key in stats.stats["loras"]) + + await _finalize_usage_stats(tasks)