mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-04-10 04:42:14 -03:00
fix(usage-stats): resolve pending checkpoint hashes
This commit is contained in:
@@ -292,6 +292,80 @@ class UsageStats:
|
||||
if LORAS in metadata and isinstance(metadata[LORAS], dict):
|
||||
await self._process_loras(metadata[LORAS], today)
|
||||
|
||||
def _increment_usage_counter(self, category: str, stat_key: str, today_date: str) -> None:
|
||||
"""Increment usage counters for a resolved stats key."""
|
||||
if stat_key not in self.stats[category]:
|
||||
self.stats[category][stat_key] = {
|
||||
"total": 0,
|
||||
"history": {}
|
||||
}
|
||||
|
||||
self.stats[category][stat_key]["total"] += 1
|
||||
|
||||
if today_date not in self.stats[category][stat_key]["history"]:
|
||||
self.stats[category][stat_key]["history"][today_date] = 0
|
||||
self.stats[category][stat_key]["history"][today_date] += 1
|
||||
|
||||
def _normalize_model_lookup_name(self, model_name: str) -> str:
|
||||
"""Normalize a model reference to its base filename without extension."""
|
||||
return os.path.splitext(os.path.basename(model_name))[0]
|
||||
|
||||
async def _find_cached_checkpoint_entry(self, checkpoint_scanner, model_name: str):
|
||||
"""Best-effort lookup for a checkpoint cache entry by filename/model name."""
|
||||
get_cached_data = getattr(checkpoint_scanner, "get_cached_data", None)
|
||||
if not callable(get_cached_data):
|
||||
return None
|
||||
|
||||
cache = await get_cached_data()
|
||||
raw_data = getattr(cache, "raw_data", None)
|
||||
if not isinstance(raw_data, list):
|
||||
return None
|
||||
|
||||
normalized_name = self._normalize_model_lookup_name(model_name)
|
||||
for entry in raw_data:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
|
||||
for candidate_key in ("file_name", "model_name", "file_path"):
|
||||
candidate_value = entry.get(candidate_key)
|
||||
if not candidate_value or not isinstance(candidate_value, str):
|
||||
continue
|
||||
if self._normalize_model_lookup_name(candidate_value) == normalized_name:
|
||||
return entry
|
||||
|
||||
return None
|
||||
|
||||
async def _resolve_checkpoint_hash(self, checkpoint_scanner, model_name: str):
|
||||
"""Resolve a checkpoint hash, calculating pending hashes on demand when needed."""
|
||||
model_filename = self._normalize_model_lookup_name(model_name)
|
||||
model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
|
||||
if model_hash:
|
||||
return model_hash
|
||||
|
||||
cached_entry = await self._find_cached_checkpoint_entry(checkpoint_scanner, model_name)
|
||||
if not cached_entry:
|
||||
logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
|
||||
return None
|
||||
|
||||
cached_hash = cached_entry.get("sha256")
|
||||
if cached_hash:
|
||||
return cached_hash
|
||||
|
||||
if cached_entry.get("hash_status") == "pending":
|
||||
calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None)
|
||||
file_path = cached_entry.get("file_path")
|
||||
if callable(calculate_hash) and file_path:
|
||||
calculated_hash = await calculate_hash(file_path)
|
||||
if calculated_hash:
|
||||
return calculated_hash
|
||||
logger.warning(
|
||||
f"Failed to calculate pending hash for checkpoint '{model_filename}', skipping usage tracking"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
|
||||
return None
|
||||
|
||||
async def _process_checkpoints(self, models_data, today_date):
|
||||
"""Process checkpoint models from metadata"""
|
||||
try:
|
||||
@@ -312,28 +386,11 @@ class UsageStats:
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
# Clean up filename (remove extension if present)
|
||||
model_filename = os.path.splitext(os.path.basename(model_name))[0]
|
||||
|
||||
# Get hash for this checkpoint
|
||||
model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
|
||||
model_hash = await self._resolve_checkpoint_hash(checkpoint_scanner, model_name)
|
||||
if not model_hash:
|
||||
logger.warning(f"No hash found for checkpoint '{model_filename}', tracking by name")
|
||||
stat_key = model_hash or f"name:{model_filename}"
|
||||
# Update stats for this checkpoint with date tracking
|
||||
if stat_key not in self.stats["checkpoints"]:
|
||||
self.stats["checkpoints"][stat_key] = {
|
||||
"total": 0,
|
||||
"history": {}
|
||||
}
|
||||
continue
|
||||
|
||||
# Increment total count
|
||||
self.stats["checkpoints"][stat_key]["total"] += 1
|
||||
|
||||
# Increment today's count
|
||||
if today_date not in self.stats["checkpoints"][stat_key]["history"]:
|
||||
self.stats["checkpoints"][stat_key]["history"][today_date] = 0
|
||||
self.stats["checkpoints"][stat_key]["history"][today_date] += 1
|
||||
self._increment_usage_counter("checkpoints", model_hash, today_date)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)
|
||||
|
||||
@@ -363,22 +420,10 @@ class UsageStats:
|
||||
# Get hash for this LoRA
|
||||
lora_hash = lora_scanner.get_hash_by_filename(lora_name)
|
||||
if not lora_hash:
|
||||
logger.warning(f"No hash found for LoRA '{lora_name}', tracking by name")
|
||||
stat_key = lora_hash or f"name:{lora_name}"
|
||||
# Update stats for this LoRA with date tracking
|
||||
if stat_key not in self.stats["loras"]:
|
||||
self.stats["loras"][stat_key] = {
|
||||
"total": 0,
|
||||
"history": {}
|
||||
}
|
||||
logger.warning(f"No hash found for LoRA '{lora_name}', skipping usage tracking")
|
||||
continue
|
||||
|
||||
# Increment total count
|
||||
self.stats["loras"][stat_key]["total"] += 1
|
||||
|
||||
# Increment today's count
|
||||
if today_date not in self.stats["loras"][stat_key]["history"]:
|
||||
self.stats["loras"][stat_key]["history"][today_date] = 0
|
||||
self.stats["loras"][stat_key]["history"][today_date] += 1
|
||||
self._increment_usage_counter("loras", lora_hash, today_date)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing LoRA usage: {e}", exc_info=True)
|
||||
|
||||
|
||||
@@ -152,3 +152,67 @@ async def test_usage_stats_background_processor_handles_pending_prompts(tmp_path
|
||||
assert stats.stats["loras"]["lora-hash"]["history"][today] == 1
|
||||
|
||||
await _finalize_usage_stats(tasks)
|
||||
|
||||
|
||||
async def test_usage_stats_calculates_pending_checkpoint_hash_on_demand(tmp_path, monkeypatch):
    """A checkpoint whose cache entry is marked hash_status == 'pending' gets its
    hash computed on demand and is tracked under the resolved hash."""
    stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)

    metadata_payload = {
        "models": {
            "1": {"type": "checkpoint", "name": "pending_model.safetensors"},
        },
        "loras": {},
    }

    # Cache entry mimicking a checkpoint whose hash has not been computed yet:
    # empty sha256 plus the "pending" marker forces the on-demand path.
    checkpoint_cache = SimpleNamespace(
        raw_data=[
            {
                "file_name": "pending_model",
                "model_name": "pending_model",
                "file_path": "/models/pending_model.safetensors",
                "sha256": "",
                "hash_status": "pending",
            }
        ]
    )
    # Returning None from get_hash_by_filename forces the cache-lookup fallback.
    checkpoint_scanner = SimpleNamespace(
        get_hash_by_filename=lambda name: None,
        get_cached_data=AsyncMock(return_value=checkpoint_cache),
        calculate_hash_for_model=AsyncMock(return_value="resolved-hash"),
    )
    lora_scanner = SimpleNamespace(get_hash_by_filename=lambda name: None)

    monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner))
    monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner))

    await stats._process_metadata(metadata_payload)

    today = datetime.now().strftime("%Y-%m-%d")
    # The pending hash must be computed exactly once, from the cached file path,
    # and the usage recorded under the hash it returned.
    checkpoint_scanner.calculate_hash_for_model.assert_awaited_once_with("/models/pending_model.safetensors")
    assert stats.stats["checkpoints"]["resolved-hash"]["history"][today] == 1

    await _finalize_usage_stats(tasks)
|
||||
|
||||
|
||||
async def test_usage_stats_skips_name_fallback_for_missing_lora_hash(tmp_path, monkeypatch):
    """A LoRA whose hash cannot be resolved is skipped entirely — no legacy
    'name:<file>' fallback key is recorded."""
    stats, tasks, _ = _prepare_usage_stats(tmp_path, monkeypatch)

    metadata_payload = {
        "models": {},
        "loras": {
            "2": {"lora_list": [{"name": "missing_lora"}]},
        },
    }

    # Both scanners report no hash for any filename.
    checkpoint_scanner = SimpleNamespace(get_hash_by_filename=lambda name: None)
    lora_scanner = SimpleNamespace(get_hash_by_filename=lambda name: None)

    monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner))
    monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner))

    await stats._process_metadata(metadata_payload)

    # No entries at all — in particular no "name:" fallback keys.
    assert stats.stats["loras"] == {}
    assert not any(key.startswith("name:") for key in stats.stats["loras"])

    await _finalize_usage_stats(tasks)
|
||||
|
||||
Reference in New Issue
Block a user