# ComfyUI-Lora-Manager/py/utils/usage_stats.py

import os
import json
import time
import asyncio
import logging
import datetime
import shutil
from typing import Dict, Set
from ..config import config
from ..services.service_registry import ServiceRegistry

# Check if running in standalone mode
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
# Define constants locally to avoid dependency on conditional imports
MODELS = "models"
LORAS = "loras"
if not standalone_mode:
    from ..metadata_collector.metadata_registry import MetadataRegistry
    # Import constants from metadata_collector to ensure consistency, but we have fallbacks defined above
    try:
        from ..metadata_collector.constants import MODELS as _MODELS, LORAS as _LORAS
        MODELS = _MODELS
        LORAS = _LORAS
    except ImportError:
        pass  # Use the local definitions

logger = logging.getLogger(__name__)


class UsageStats:
"""Track usage statistics for models and save to JSON"""
_instance = None
_lock = asyncio.Lock() # For thread safety
# Default stats file name
STATS_FILENAME = "lora_manager_stats.json"
BACKUP_SUFFIX = ".backup"
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
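
    # Note: because of the singleton __new__ above, every call site shares one
    # tracker, e.g. `UsageStats() is UsageStats()` evaluates to True.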

    def __init__(self):
        if self._initialized:
            return
        # Initialize stats storage
        self.stats = {
            "checkpoints": {},  # sha256 -> { total: count, history: { date: count } }
            "loras": {},  # sha256 -> { total: count, history: { date: count } }
            "total_executions": 0,
            "last_save_time": 0
        }
        # Track if stats have been modified since last save
        self._is_dirty = False
        # Queue for prompt_ids to process
        self.pending_prompt_ids = set()
        # Load existing stats if available
        self._stats_file_path = self._get_stats_file_path()
        self._load_stats()
        # Save interval in seconds
        self.save_interval = 90  # 1.5 minutes
        # Start background task to process queued prompt_ids
        self._bg_task = asyncio.create_task(self._background_processor())
        self._initialized = True
        logger.debug("Usage statistics tracker initialized")

    def _get_stats_file_path(self) -> str:
        """Get the path to the stats JSON file"""
        if not config.loras_roots or len(config.loras_roots) == 0:
            # If no lora roots are available, we can't save stats
            # This will be handled by the caller
            raise RuntimeError("No LoRA root directories configured. Cannot initialize usage statistics.")
        # Use the first lora root
        return os.path.join(config.loras_roots[0], self.STATS_FILENAME)

    def _backup_old_stats(self):
        """Backup the old stats file before conversion"""
        if os.path.exists(self._stats_file_path):
            backup_path = f"{self._stats_file_path}{self.BACKUP_SUFFIX}"
            try:
                shutil.copy2(self._stats_file_path, backup_path)
                logger.info(f"Backed up old stats file to {backup_path}")
                return True
            except Exception as e:
                logger.error(f"Failed to backup stats file: {e}")
        return False

    def _convert_old_format(self, old_stats):
        """Convert old stats format to new format with history"""
        new_stats = {
            "checkpoints": {},
            "loras": {},
            "total_executions": old_stats.get("total_executions", 0),
            "last_save_time": old_stats.get("last_save_time", time.time())
        }
        # Get today's date in YYYY-MM-DD format
        today = datetime.datetime.now().strftime("%Y-%m-%d")
        # Convert checkpoint stats
        if "checkpoints" in old_stats and isinstance(old_stats["checkpoints"], dict):
            for hash_id, count in old_stats["checkpoints"].items():
                new_stats["checkpoints"][hash_id] = {
                    "total": count,
                    "history": {
                        today: count
                    }
                }
        # Convert lora stats
        if "loras" in old_stats and isinstance(old_stats["loras"], dict):
            for hash_id, count in old_stats["loras"].items():
                new_stats["loras"][hash_id] = {
                    "total": count,
                    "history": {
                        today: count
                    }
                }
        logger.info("Successfully converted stats from old format to new format with history")
        return new_stats
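
    # Illustrative conversion (assumed example data): an old-format entry
    #   {"loras": {"<sha256>": 7}}
    # becomes, with all prior usage folded into today's history bucket,
    #   {"loras": {"<sha256>": {"total": 7, "history": {"2024-05-01": 7}}}}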

    def _is_old_format(self, stats):
        """Check if the stats are in the old format (direct count values)"""
        # Check if any lora or checkpoint entry is a direct number instead of an object
        if "loras" in stats and isinstance(stats["loras"], dict):
            for hash_id, data in stats["loras"].items():
                if isinstance(data, (int, float)):
                    return True
        if "checkpoints" in stats and isinstance(stats["checkpoints"], dict):
            for hash_id, data in stats["checkpoints"].items():
                if isinstance(data, (int, float)):
                    return True
        return False

    def _load_stats(self):
        """Load existing statistics from file"""
        try:
            if os.path.exists(self._stats_file_path):
                with open(self._stats_file_path, 'r', encoding='utf-8') as f:
                    loaded_stats = json.load(f)
                # Check if old format and needs conversion
                if self._is_old_format(loaded_stats):
                    logger.info("Detected old stats format, performing conversion")
                    self._backup_old_stats()
                    self.stats = self._convert_old_format(loaded_stats)
                else:
                    # Update our stats with loaded data (already in new format)
                    if isinstance(loaded_stats, dict):
                        # Update individual sections to maintain structure
                        if "checkpoints" in loaded_stats and isinstance(loaded_stats["checkpoints"], dict):
                            self.stats["checkpoints"] = loaded_stats["checkpoints"]
                        if "loras" in loaded_stats and isinstance(loaded_stats["loras"], dict):
                            self.stats["loras"] = loaded_stats["loras"]
                        if "total_executions" in loaded_stats:
                            self.stats["total_executions"] = loaded_stats["total_executions"]
                        if "last_save_time" in loaded_stats:
                            self.stats["last_save_time"] = loaded_stats["last_save_time"]
                logger.debug(f"Loaded usage statistics from {self._stats_file_path}")
        except Exception as e:
            logger.error(f"Error loading usage statistics: {e}")

    async def save_stats(self, force=False):
        """Save statistics to file"""
        try:
            # Only save if:
            # 1. force is True, OR
            # 2. stats have been modified (is_dirty) AND save_interval has passed
            current_time = time.time()
            time_since_last_save = current_time - self.stats.get("last_save_time", 0)
            if not force:
                if not self._is_dirty:
                    # No changes to save
                    return False
                if time_since_last_save < self.save_interval:
                    # Too soon since last save
                    return False
            # Use a lock to prevent concurrent writes
            async with self._lock:
                # Update last save time
                self.stats["last_save_time"] = current_time
                # Create directory if it doesn't exist
                os.makedirs(os.path.dirname(self._stats_file_path), exist_ok=True)
                # Write to a temporary file first, then move it to avoid corruption
                temp_path = f"{self._stats_file_path}.tmp"
                with open(temp_path, 'w', encoding='utf-8') as f:
                    json.dump(self.stats, f, indent=2, ensure_ascii=False)
                # Replace the old file with the new one
                os.replace(temp_path, self._stats_file_path)
                # Clear dirty flag since we've saved
                self._is_dirty = False
            logger.debug(f"Saved usage statistics to {self._stats_file_path}")
            return True
        except Exception as e:
            logger.error(f"Error saving usage statistics: {e}", exc_info=True)
            return False
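
    # Note on the write path above: dumping to "<stats file>.tmp" and then
    # calling os.replace() makes the swap atomic on the same filesystem, so a
    # crash mid-write leaves the previous stats file intact.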

    def register_execution(self, prompt_id):
        """Register a completed execution by prompt_id for later processing"""
        if prompt_id:
            self.pending_prompt_ids.add(prompt_id)
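
    # Flow summary: register_execution() only enqueues the prompt_id;
    # _background_processor() wakes every 5 seconds, drains the queue,
    # resolves each prompt's metadata into model hashes, and save_stats()
    # writes at most once per save_interval (90 s) unless forced.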

    async def _background_processor(self):
        """Background task to process queued prompt_ids"""
        try:
            while True:
                # Wait a short interval before checking for new prompt_ids
                await asyncio.sleep(5)  # Check every 5 seconds
                # Process any pending prompt_ids
                if self.pending_prompt_ids:
                    async with self._lock:
                        # Get a copy of the set and clear original
                        prompt_ids = self.pending_prompt_ids.copy()
                        self.pending_prompt_ids.clear()
                    # Process each prompt_id
                    try:
                        registry = MetadataRegistry()
                    except NameError:
                        # MetadataRegistry not available (standalone mode)
                        registry = None
                    if registry:
                        for prompt_id in prompt_ids:
                            try:
                                metadata = registry.get_metadata(prompt_id)
                                await self._process_metadata(metadata)
                            except Exception as e:
                                logger.error(f"Error processing prompt_id {prompt_id}: {e}")
                # Periodically save stats (only if there are changes)
                if self._is_dirty:
                    await self.save_stats()
        except asyncio.CancelledError:
            # Task was cancelled, clean up
            await self.save_stats(force=True)
        except Exception as e:
            logger.error(f"Error in background processing task: {e}", exc_info=True)
            # Restart the task after a delay if it fails
            asyncio.create_task(self._restart_background_task())

    async def _restart_background_task(self):
        """Restart the background task after a delay"""
        await asyncio.sleep(30)  # Wait 30 seconds before restarting
        self._bg_task = asyncio.create_task(self._background_processor())

    async def _process_metadata(self, metadata):
        """Process metadata from an execution"""
        if not metadata or not isinstance(metadata, dict):
            return
        # Increment total executions count
        self.stats["total_executions"] += 1
        self._is_dirty = True
        # Get today's date in YYYY-MM-DD format
        today = datetime.datetime.now().strftime("%Y-%m-%d")
        # Process checkpoints
        if MODELS in metadata and isinstance(metadata[MODELS], dict):
            await self._process_checkpoints(metadata[MODELS], today)
        # Process loras
        if LORAS in metadata and isinstance(metadata[LORAS], dict):
            await self._process_loras(metadata[LORAS], today)

    def _increment_usage_counter(self, category: str, stat_key: str, today_date: str) -> None:
        """Increment usage counters for a resolved stats key."""
        if stat_key not in self.stats[category]:
            self.stats[category][stat_key] = {
                "total": 0,
                "history": {}
            }
        self.stats[category][stat_key]["total"] += 1
        if today_date not in self.stats[category][stat_key]["history"]:
            self.stats[category][stat_key]["history"][today_date] = 0
        self.stats[category][stat_key]["history"][today_date] += 1

    def _normalize_model_lookup_name(self, model_name: str) -> str:
        """Normalize a model reference to its base filename without extension."""
        return os.path.splitext(os.path.basename(model_name))[0]
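
    # Illustrative inputs/outputs (hypothetical file names):
    #   _normalize_model_lookup_name("SDXL/juggernaut_v9.safetensors") -> "juggernaut_v9"
    #   _normalize_model_lookup_name("model.ckpt")                     -> "model"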

    async def _find_cached_checkpoint_entry(self, checkpoint_scanner, model_name: str):
        """Best-effort lookup for a checkpoint cache entry by filename/model name."""
        get_cached_data = getattr(checkpoint_scanner, "get_cached_data", None)
        if not callable(get_cached_data):
            return None
        cache = await get_cached_data()
        raw_data = getattr(cache, "raw_data", None)
        if not isinstance(raw_data, list):
            return None
        normalized_name = self._normalize_model_lookup_name(model_name)
        for entry in raw_data:
            if not isinstance(entry, dict):
                continue
            for candidate_key in ("file_name", "model_name", "file_path"):
                candidate_value = entry.get(candidate_key)
                if not candidate_value or not isinstance(candidate_value, str):
                    continue
                if self._normalize_model_lookup_name(candidate_value) == normalized_name:
                    return entry
        return None
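
    # The resolver below tries three sources in order: the scanner's
    # filename->hash index, the scanner's cached raw_data entries, and finally
    # an on-demand hash calculation for entries still marked "pending". Any
    # miss logs a warning and the usage event is dropped rather than
    # miscounted.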

    async def _resolve_checkpoint_hash(self, checkpoint_scanner, model_name: str):
        """Resolve a checkpoint hash, calculating pending hashes on demand when needed."""
        model_filename = self._normalize_model_lookup_name(model_name)
        model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
        if model_hash:
            return model_hash
        cached_entry = await self._find_cached_checkpoint_entry(checkpoint_scanner, model_name)
        if not cached_entry:
            logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
            return None
        cached_hash = cached_entry.get("sha256")
        if cached_hash:
            return cached_hash
        if cached_entry.get("hash_status") == "pending":
            calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None)
            file_path = cached_entry.get("file_path")
            if callable(calculate_hash) and file_path:
                calculated_hash = await calculate_hash(file_path)
                if calculated_hash:
                    return calculated_hash
            logger.warning(
                f"Failed to calculate pending hash for checkpoint '{model_filename}', skipping usage tracking"
            )
            return None
        logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
        return None

    async def _process_checkpoints(self, models_data, today_date):
        """Process checkpoint models from metadata"""
        try:
            # Get checkpoint scanner service
            checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
            if not checkpoint_scanner:
                logger.warning("Checkpoint scanner not available for usage tracking")
                return
            for node_id, model_info in models_data.items():
                if not isinstance(model_info, dict):
                    continue
                # Check if this is a checkpoint model
                model_type = model_info.get("type")
                if model_type == "checkpoint":
                    model_name = model_info.get("name")
                    if not model_name:
                        continue
                    model_hash = await self._resolve_checkpoint_hash(checkpoint_scanner, model_name)
                    if not model_hash:
                        continue
                    self._increment_usage_counter("checkpoints", model_hash, today_date)
        except Exception as e:
            logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)

    async def _process_loras(self, loras_data, today_date):
        """Process LoRA models from metadata"""
        try:
            # Get LoRA scanner service
            lora_scanner = await ServiceRegistry.get_lora_scanner()
            if not lora_scanner:
                logger.warning("LoRA scanner not available for usage tracking")
                return
            for node_id, lora_info in loras_data.items():
                if not isinstance(lora_info, dict):
                    continue
                # Get the list of LoRAs from standardized format
                lora_list = lora_info.get("lora_list", [])
                for lora in lora_list:
                    if not isinstance(lora, dict):
                        continue
                    lora_name = lora.get("name")
                    if not lora_name:
                        continue
                    # Get hash for this LoRA
                    lora_hash = lora_scanner.get_hash_by_filename(lora_name)
                    if not lora_hash:
                        logger.warning(f"No hash found for LoRA '{lora_name}', skipping usage tracking")
                        continue
                    self._increment_usage_counter("loras", lora_hash, today_date)
        except Exception as e:
            logger.error(f"Error processing LoRA usage: {e}", exc_info=True)

    async def get_stats(self):
        """Get current usage statistics"""
        return self.stats

    async def get_model_usage_count(self, model_type, sha256):
        """Get usage count for a specific model by hash"""
        if model_type == "checkpoint":
            if sha256 in self.stats["checkpoints"]:
                return self.stats["checkpoints"][sha256]["total"]
        elif model_type == "lora":
            if sha256 in self.stats["loras"]:
                return self.stats["loras"][sha256]["total"]
        return 0

    async def process_execution(self, prompt_id):
        """Process a prompt execution immediately rather than queueing it"""
        if not prompt_id:
            return
        if standalone_mode:
            # Usage statistics are not available in standalone mode
            return
        try:
            # Process metadata for this prompt_id
            registry = MetadataRegistry()
            metadata = registry.get_metadata(prompt_id)
            if metadata:
                await self._process_metadata(metadata)
            # Save stats if needed
            await self.save_stats()
        except Exception as e:
            logger.error(f"Error processing prompt_id {prompt_id}: {e}", exc_info=True)