fix(usage-stats): resolve pending checkpoint hashes

This commit is contained in:
Will Miao
2026-04-08 09:40:20 +08:00
parent 1c4919a3e8
commit e13d70248a
2 changed files with 145 additions and 36 deletions

View File

@@ -292,6 +292,80 @@ class UsageStats:
if LORAS in metadata and isinstance(metadata[LORAS], dict):
await self._process_loras(metadata[LORAS], today)
def _increment_usage_counter(self, category: str, stat_key: str, today_date: str) -> None:
"""Increment usage counters for a resolved stats key."""
if stat_key not in self.stats[category]:
self.stats[category][stat_key] = {
"total": 0,
"history": {}
}
self.stats[category][stat_key]["total"] += 1
if today_date not in self.stats[category][stat_key]["history"]:
self.stats[category][stat_key]["history"][today_date] = 0
self.stats[category][stat_key]["history"][today_date] += 1
def _normalize_model_lookup_name(self, model_name: str) -> str:
"""Normalize a model reference to its base filename without extension."""
return os.path.splitext(os.path.basename(model_name))[0]
async def _find_cached_checkpoint_entry(self, checkpoint_scanner, model_name: str):
"""Best-effort lookup for a checkpoint cache entry by filename/model name."""
get_cached_data = getattr(checkpoint_scanner, "get_cached_data", None)
if not callable(get_cached_data):
return None
cache = await get_cached_data()
raw_data = getattr(cache, "raw_data", None)
if not isinstance(raw_data, list):
return None
normalized_name = self._normalize_model_lookup_name(model_name)
for entry in raw_data:
if not isinstance(entry, dict):
continue
for candidate_key in ("file_name", "model_name", "file_path"):
candidate_value = entry.get(candidate_key)
if not candidate_value or not isinstance(candidate_value, str):
continue
if self._normalize_model_lookup_name(candidate_value) == normalized_name:
return entry
return None
async def _resolve_checkpoint_hash(self, checkpoint_scanner, model_name: str):
"""Resolve a checkpoint hash, calculating pending hashes on demand when needed."""
model_filename = self._normalize_model_lookup_name(model_name)
model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
if model_hash:
return model_hash
cached_entry = await self._find_cached_checkpoint_entry(checkpoint_scanner, model_name)
if not cached_entry:
logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
return None
cached_hash = cached_entry.get("sha256")
if cached_hash:
return cached_hash
if cached_entry.get("hash_status") == "pending":
calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None)
file_path = cached_entry.get("file_path")
if callable(calculate_hash) and file_path:
calculated_hash = await calculate_hash(file_path)
if calculated_hash:
return calculated_hash
logger.warning(
f"Failed to calculate pending hash for checkpoint '{model_filename}', skipping usage tracking"
)
return None
logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
return None
async def _process_checkpoints(self, models_data, today_date):
"""Process checkpoint models from metadata"""
try:
@@ -312,28 +386,11 @@ class UsageStats:
if not model_name:
continue
# Clean up filename (remove extension if present)
model_filename = os.path.splitext(os.path.basename(model_name))[0]
# Get hash for this checkpoint
model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
model_hash = await self._resolve_checkpoint_hash(checkpoint_scanner, model_name)
if not model_hash:
logger.warning(f"No hash found for checkpoint '{model_filename}', tracking by name")
stat_key = model_hash or f"name:{model_filename}"
# Update stats for this checkpoint with date tracking
if stat_key not in self.stats["checkpoints"]:
self.stats["checkpoints"][stat_key] = {
"total": 0,
"history": {}
}
continue
# Increment total count
self.stats["checkpoints"][stat_key]["total"] += 1
# Increment today's count
if today_date not in self.stats["checkpoints"][stat_key]["history"]:
self.stats["checkpoints"][stat_key]["history"][today_date] = 0
self.stats["checkpoints"][stat_key]["history"][today_date] += 1
self._increment_usage_counter("checkpoints", model_hash, today_date)
except Exception as e:
logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)
@@ -363,22 +420,10 @@ class UsageStats:
# Get hash for this LoRA
lora_hash = lora_scanner.get_hash_by_filename(lora_name)
if not lora_hash:
logger.warning(f"No hash found for LoRA '{lora_name}', tracking by name")
stat_key = lora_hash or f"name:{lora_name}"
# Update stats for this LoRA with date tracking
if stat_key not in self.stats["loras"]:
self.stats["loras"][stat_key] = {
"total": 0,
"history": {}
}
logger.warning(f"No hash found for LoRA '{lora_name}', skipping usage tracking")
continue
# Increment total count
self.stats["loras"][stat_key]["total"] += 1
# Increment today's count
if today_date not in self.stats["loras"][stat_key]["history"]:
self.stats["loras"][stat_key]["history"][today_date] = 0
self.stats["loras"][stat_key]["history"][today_date] += 1
self._increment_usage_counter("loras", lora_hash, today_date)
except Exception as e:
logger.error(f"Error processing LoRA usage: {e}", exc_info=True)

View File

@@ -152,3 +152,67 @@ async def test_usage_stats_background_processor_handles_pending_prompts(tmp_path
assert stats.stats["loras"]["lora-hash"]["history"][today] == 1
await _finalize_usage_stats(tasks)
async def test_usage_stats_calculates_pending_checkpoint_hash_on_demand(tmp_path, monkeypatch):
stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)
metadata_payload = {
"models": {
"1": {"type": "checkpoint", "name": "pending_model.safetensors"},
},
"loras": {},
}
checkpoint_cache = SimpleNamespace(
raw_data=[
{
"file_name": "pending_model",
"model_name": "pending_model",
"file_path": "/models/pending_model.safetensors",
"sha256": "",
"hash_status": "pending",
}
]
)
checkpoint_scanner = SimpleNamespace(
get_hash_by_filename=lambda name: None,
get_cached_data=AsyncMock(return_value=checkpoint_cache),
calculate_hash_for_model=AsyncMock(return_value="resolved-hash"),
)
lora_scanner = SimpleNamespace(get_hash_by_filename=lambda name: None)
monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner))
monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner))
await stats._process_metadata(metadata_payload)
today = datetime.now().strftime("%Y-%m-%d")
checkpoint_scanner.calculate_hash_for_model.assert_awaited_once_with("/models/pending_model.safetensors")
assert stats.stats["checkpoints"]["resolved-hash"]["history"][today] == 1
await _finalize_usage_stats(tasks)
async def test_usage_stats_skips_name_fallback_for_missing_lora_hash(tmp_path, monkeypatch):
stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)
metadata_payload = {
"models": {},
"loras": {
"2": {"lora_list": [{"name": "missing_lora"}]},
},
}
checkpoint_scanner = SimpleNamespace(get_hash_by_filename=lambda name: None)
lora_scanner = SimpleNamespace(get_hash_by_filename=lambda name: None)
monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner))
monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner))
await stats._process_metadata(metadata_payload)
assert stats.stats["loras"] == {}
assert not any(key.startswith("name:") for key in stats.stats["loras"])
await _finalize_usage_stats(tasks)