Files
ComfyUI-Lora-Manager/py/utils/usage_stats.py

615 lines
24 KiB
Python

import os
import re
import json
import time
import asyncio
import logging
import datetime
import shutil
from typing import Dict, Set
from ..config import config
from ..services.service_registry import ServiceRegistry
# Standalone detection: true when LORA_MANAGER_STANDALONE is "1", or when
# HF_HUB_DISABLE_TELEMETRY is unset or "0".
# NOTE(review): the second test presumably relies on the host application
# setting HF_HUB_DISABLE_TELEMETRY to a value other than "0" -- confirm.
standalone_mode = (
    os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"
    or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
)

# Metadata section keys. Defined locally first so the module works even when
# the metadata_collector package is never imported.
MODELS = "models"
LORAS = "loras"
EMBEDDINGS = "embeddings"
PROMPTS = "prompts"

if not standalone_mode:
    from ..metadata_collector.metadata_registry import MetadataRegistry

    # Prefer the collector's own constants; the local values above remain in
    # effect when the constants module cannot be imported.
    try:
        from ..metadata_collector.constants import (
            MODELS as _MODELS,
            LORAS as _LORAS,
            EMBEDDINGS as _EMBEDDINGS,
            PROMPTS as _PROMPTS,
        )
    except ImportError:
        pass  # Use the local definitions
    else:
        MODELS = _MODELS
        LORAS = _LORAS
        EMBEDDINGS = _EMBEDDINGS
        PROMPTS = _PROMPTS
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Fallback checkpoint file extensions (lower-case, with leading dot).
# UsageStats._find_checkpoint_file_on_disk uses this set when the checkpoint
# scanner has no ``file_extensions`` attribute, or when that attribute is not
# a set, frozenset, list or tuple.
_DEFAULT_CHECKPOINT_EXTENSIONS = {
    ".ckpt",
    ".pt",
    ".pt2",
    ".bin",
    ".pth",
    ".safetensors",
    ".pkl",
    ".sft",
    ".gguf",
}
class UsageStats:
    """Track usage statistics for models and save to JSON.

    The class is a singleton: ``__new__`` always hands back the same
    instance and ``__init__`` only performs its setup once.
    """

    # The single shared instance; populated on first construction.
    _instance = None
    # Serializes the stats-file write and the swap of the pending prompt-id
    # set between coroutines. asyncio.Lock is not thread-safe; it only
    # coordinates tasks running on the same event loop.
    _lock = asyncio.Lock()
    # Default stats file name
    STATS_FILENAME = "lora_manager_stats.json"
    # Appended to the stats file path when backing up an old-format file.
    BACKUP_SUFFIX = ".backup"
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
# Initialize stats storage
self.stats = {
"checkpoints": {}, # sha256 -> { total: count, history: { date: count } }
"loras": {}, # sha256 -> { total: count, history: { date: count } }
"embeddings": {}, # sha256 -> { total: count, history: { date: count } }
"total_executions": 0,
"last_save_time": 0
}
# Track if stats have been modified since last save
self._is_dirty = False
# Queue for prompt_ids to process
self.pending_prompt_ids = set()
# Load existing stats if available
self._stats_file_path = self._get_stats_file_path()
self._load_stats()
# Save interval in seconds
self.save_interval = 90 # 1.5 minutes
# Start background task to process queued prompt_ids
self._bg_task = asyncio.create_task(self._background_processor())
self._initialized = True
logger.debug("Usage statistics tracker initialized")
def _get_stats_file_path(self) -> str:
    """Return the stats file path inside the first configured LoRA root.

    Raises:
        RuntimeError: when ``config.loras_roots`` is empty or unset, since
            there is then nowhere to store the file.
    """
    roots = config.loras_roots
    if roots:
        # Stats always live in the first lora root.
        return os.path.join(roots[0], self.STATS_FILENAME)
    raise RuntimeError("No LoRA root directories configured. Cannot initialize usage statistics.")
def _backup_old_stats(self):
"""Backup the old stats file before conversion"""
if os.path.exists(self._stats_file_path):
backup_path = f"{self._stats_file_path}{self.BACKUP_SUFFIX}"
try:
shutil.copy2(self._stats_file_path, backup_path)
logger.info(f"Backed up old stats file to {backup_path}")
return True
except Exception as e:
logger.error(f"Failed to backup stats file: {e}")
return False
def _convert_old_format(self, old_stats):
"""Convert old stats format to new format with history"""
new_stats = {
"checkpoints": {},
"loras": {},
"embeddings": {},
"total_executions": old_stats.get("total_executions", 0),
"last_save_time": old_stats.get("last_save_time", time.time())
}
# Get today's date in YYYY-MM-DD format
today = datetime.datetime.now().strftime("%Y-%m-%d")
# Convert checkpoint stats
if "checkpoints" in old_stats and isinstance(old_stats["checkpoints"], dict):
for hash_id, count in old_stats["checkpoints"].items():
new_stats["checkpoints"][hash_id] = {
"total": count,
"history": {
today: count
}
}
# Convert lora stats
if "loras" in old_stats and isinstance(old_stats["loras"], dict):
for hash_id, count in old_stats["loras"].items():
new_stats["loras"][hash_id] = {
"total": count,
"history": {
today: count
}
}
# Convert embedding stats (if present in old format)
if "embeddings" in old_stats and isinstance(old_stats["embeddings"], dict):
for hash_id, count in old_stats["embeddings"].items():
new_stats["embeddings"][hash_id] = {
"total": count,
"history": {
today: count
}
}
logger.info("Successfully converted stats from old format to new format with history")
return new_stats
def _is_old_format(self, stats):
"""Check if the stats are in the old format (direct count values)"""
# Check if any lora or checkpoint entry is a direct number instead of an object
for category in ("loras", "checkpoints", "embeddings"):
if category in stats and isinstance(stats[category], dict):
for hash_id, data in stats[category].items():
if isinstance(data, (int, float)):
return True
return False
def _load_stats(self):
"""Load existing statistics from file"""
try:
if os.path.exists(self._stats_file_path):
with open(self._stats_file_path, 'r', encoding='utf-8') as f:
loaded_stats = json.load(f)
# Check if old format and needs conversion
if self._is_old_format(loaded_stats):
logger.info("Detected old stats format, performing conversion")
self._backup_old_stats()
self.stats = self._convert_old_format(loaded_stats)
else:
# Update our stats with loaded data (already in new format)
if isinstance(loaded_stats, dict):
# Update individual sections to maintain structure
if "checkpoints" in loaded_stats and isinstance(loaded_stats["checkpoints"], dict):
self.stats["checkpoints"] = loaded_stats["checkpoints"]
if "loras" in loaded_stats and isinstance(loaded_stats["loras"], dict):
self.stats["loras"] = loaded_stats["loras"]
if "total_executions" in loaded_stats:
self.stats["total_executions"] = loaded_stats["total_executions"]
if "last_save_time" in loaded_stats:
self.stats["last_save_time"] = loaded_stats["last_save_time"]
logger.debug(f"Loaded usage statistics from {self._stats_file_path}")
except Exception as e:
logger.error(f"Error loading usage statistics: {e}")
async def save_stats(self, force=False):
"""Save statistics to file"""
try:
# Only save if:
# 1. force is True, OR
# 2. stats have been modified (is_dirty) AND save_interval has passed
current_time = time.time()
time_since_last_save = current_time - self.stats.get("last_save_time", 0)
if not force:
if not self._is_dirty:
# No changes to save
return False
if time_since_last_save < self.save_interval:
# Too soon since last save
return False
# Use a lock to prevent concurrent writes
async with self._lock:
# Update last save time
self.stats["last_save_time"] = current_time
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(self._stats_file_path), exist_ok=True)
# Write to a temporary file first, then move it to avoid corruption
temp_path = f"{self._stats_file_path}.tmp"
with open(temp_path, 'w', encoding='utf-8') as f:
json.dump(self.stats, f, indent=2, ensure_ascii=False)
# Replace the old file with the new one
os.replace(temp_path, self._stats_file_path)
# Clear dirty flag since we've saved
self._is_dirty = False
logger.debug(f"Saved usage statistics to {self._stats_file_path}")
return True
except Exception as e:
logger.error(f"Error saving usage statistics: {e}", exc_info=True)
return False
def register_execution(self, prompt_id):
"""Register a completed execution by prompt_id for later processing"""
if prompt_id:
self.pending_prompt_ids.add(prompt_id)
async def _background_processor(self):
    """Background task to process queued prompt_ids.

    Every 5 seconds the pending prompt-id set is swapped out under the
    lock. Each id is then resolved through ``MetadataRegistry`` and fed to
    ``_process_metadata``, and the stats are saved when dirty (subject to
    the throttling in ``save_stats``).

    On cancellation the stats are force-saved. The ``CancelledError`` is
    not re-raised, so the task then finishes normally. Any other exception
    ends this task and schedules a restart through
    ``_restart_background_task``.
    """
    try:
        while True:
            # Wait a short interval before checking for new prompt_ids
            await asyncio.sleep(5)

            if self.pending_prompt_ids:
                # Take a snapshot and empty the queue atomically with
                # respect to other coroutines holding the lock.
                async with self._lock:
                    prompt_ids = self.pending_prompt_ids.copy()
                    self.pending_prompt_ids.clear()

                try:
                    registry = MetadataRegistry()
                except NameError:
                    # MetadataRegistry not available (standalone mode)
                    registry = None

                if registry:
                    for prompt_id in prompt_ids:
                        # One bad prompt must not stop the rest of the batch.
                        try:
                            metadata = registry.get_metadata(prompt_id)
                            await self._process_metadata(metadata)
                        except Exception as e:
                            logger.error(f"Error processing prompt_id {prompt_id}: {e}")

            # Periodically save stats (only if there are changes)
            if self._is_dirty:
                await self.save_stats()
    except asyncio.CancelledError:
        # Task was cancelled, clean up
        await self.save_stats(force=True)
    except Exception as e:
        logger.error(f"Error in background processing task: {e}", exc_info=True)
        # Restart the task after a delay if it fails. The event loop only
        # keeps weak references to tasks, so hold a strong reference here;
        # otherwise the restart task may be garbage-collected before it runs.
        self._restart_task = asyncio.create_task(self._restart_background_task())
async def _restart_background_task(self):
    """Wait 30 seconds, then start a fresh background processor task.

    The new task replaces ``self._bg_task``.
    """
    restart_delay_seconds = 30
    await asyncio.sleep(restart_delay_seconds)
    replacement = asyncio.create_task(self._background_processor())
    self._bg_task = replacement
async def _process_metadata(self, metadata):
"""Process metadata from an execution"""
if not metadata or not isinstance(metadata, dict):
return
# Increment total executions count
self.stats["total_executions"] += 1
self._is_dirty = True
# Get today's date in YYYY-MM-DD format
today = datetime.datetime.now().strftime("%Y-%m-%d")
# Process checkpoints
if MODELS in metadata and isinstance(metadata[MODELS], dict):
await self._process_checkpoints(metadata[MODELS], today)
# Process loras
if LORAS in metadata and isinstance(metadata[LORAS], dict):
await self._process_loras(metadata[LORAS], today)
# Process embeddings — parse prompt text for embedding:name references
if PROMPTS in metadata and isinstance(metadata[PROMPTS], dict):
await self._process_embeddings(metadata[PROMPTS], today)
def _increment_usage_counter(self, category: str, stat_key: str, today_date: str) -> None:
"""Increment usage counters for a resolved stats key."""
if stat_key not in self.stats[category]:
self.stats[category][stat_key] = {
"total": 0,
"history": {}
}
self.stats[category][stat_key]["total"] += 1
if today_date not in self.stats[category][stat_key]["history"]:
self.stats[category][stat_key]["history"][today_date] = 0
self.stats[category][stat_key]["history"][today_date] += 1
def _normalize_model_lookup_name(self, model_name: str) -> str:
"""Normalize a model reference to its base filename without extension."""
return os.path.splitext(os.path.basename(model_name))[0]
async def _find_cached_checkpoint_entry(self, checkpoint_scanner, model_name: str):
"""Best-effort lookup for a checkpoint cache entry by filename/model name."""
get_cached_data = getattr(checkpoint_scanner, "get_cached_data", None)
if not callable(get_cached_data):
return None
cache = await get_cached_data()
raw_data = getattr(cache, "raw_data", None)
if not isinstance(raw_data, list):
return None
normalized_name = self._normalize_model_lookup_name(model_name)
for entry in raw_data:
if not isinstance(entry, dict):
continue
for candidate_key in ("file_name", "model_name", "file_path"):
candidate_value = entry.get(candidate_key)
if not candidate_value or not isinstance(candidate_value, str):
continue
if self._normalize_model_lookup_name(candidate_value) == normalized_name:
return entry
return None
async def _find_checkpoint_file_on_disk(self, checkpoint_scanner, model_name: str):
"""Search checkpoint roots directly for a matching file.
This is used when usage tracking sees a checkpoint name before the cache has
been refreshed. The lookup is intentionally exact: we only match the model
basename and supported checkpoint extensions.
"""
get_model_roots = getattr(checkpoint_scanner, "get_model_roots", None)
if not callable(get_model_roots):
return None
roots = [root for root in get_model_roots() if root]
if not roots:
return None
supported_extensions = getattr(
checkpoint_scanner, "file_extensions", _DEFAULT_CHECKPOINT_EXTENSIONS
)
if not isinstance(supported_extensions, (set, frozenset, list, tuple)):
supported_extensions = _DEFAULT_CHECKPOINT_EXTENSIONS
normalized_name = self._normalize_model_lookup_name(model_name)
matches: list[str] = []
for root_path in roots:
if not os.path.exists(root_path):
continue
for dirpath, _dirnames, filenames in os.walk(root_path):
for filename in filenames:
extension = os.path.splitext(filename)[1].lower()
if extension not in supported_extensions:
continue
if os.path.splitext(filename)[0] != normalized_name:
continue
matches.append(os.path.join(dirpath, filename).replace(os.sep, "/"))
if len(matches) > 1:
logger.warning(
"Multiple checkpoint files matched '%s'; skipping usage tracking: %s",
normalized_name,
", ".join(matches),
)
return None
return matches[0] if matches else None
async def _resolve_checkpoint_hash(self, checkpoint_scanner, model_name: str):
"""Resolve a checkpoint hash, calculating pending hashes on demand when needed."""
model_filename = self._normalize_model_lookup_name(model_name)
model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
if model_hash:
return model_hash
cached_entry = await self._find_cached_checkpoint_entry(checkpoint_scanner, model_name)
if cached_entry:
cached_hash = cached_entry.get("sha256")
if cached_hash:
return cached_hash
hash_status = cached_entry.get("hash_status")
if hash_status and hash_status != "pending":
logger.warning(
"Checkpoint '%s' has hash_status=%s; skipping usage tracking",
model_filename,
hash_status,
)
return None
file_path = cached_entry.get("file_path") if cached_entry else None
if not file_path:
file_path = await self._find_checkpoint_file_on_disk(
checkpoint_scanner, model_name
)
if not file_path:
logger.warning(
f"No hash found for checkpoint '{model_filename}', skipping usage tracking"
)
return None
calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None)
if not callable(calculate_hash):
logger.warning("Checkpoint scanner not available for usage tracking")
return None
logger.info(
"Calculating hash for checkpoint '%s' from %s",
model_filename,
file_path,
)
calculated_hash = await calculate_hash(file_path)
if calculated_hash:
return calculated_hash
logger.warning(
f"Failed to calculate hash for checkpoint '{model_filename}', skipping usage tracking"
)
return None
async def _process_checkpoints(self, models_data, today_date):
    """Process checkpoint models from metadata.

    For every dict entry of ``models_data`` whose "type" is "checkpoint"
    and that has a "name", the hash is resolved and one use is recorded
    under ``today_date``. Entries whose hash cannot be resolved are
    skipped. Errors are logged, never raised.
    """
    try:
        checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
        if not checkpoint_scanner:
            logger.warning("Checkpoint scanner not available for usage tracking")
            return

        for model_info in models_data.values():
            if not isinstance(model_info, dict):
                continue
            # Only checkpoint models are counted here.
            if model_info.get("type") != "checkpoint":
                continue

            model_name = model_info.get("name")
            if not model_name:
                continue

            model_hash = await self._resolve_checkpoint_hash(checkpoint_scanner, model_name)
            if model_hash:
                self._increment_usage_counter("checkpoints", model_hash, today_date)
    except Exception as e:
        logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)
async def _process_loras(self, loras_data, today_date):
    """Process LoRA models from metadata.

    Walks the "lora_list" of every dict entry in ``loras_data``, looks each
    named LoRA up through the LoRA scanner's ``get_hash_by_filename`` and
    records one use under ``today_date``. LoRAs without a known hash are
    skipped with a warning. Errors are logged, never raised.
    """
    try:
        lora_scanner = await ServiceRegistry.get_lora_scanner()
        if not lora_scanner:
            logger.warning("LoRA scanner not available for usage tracking")
            return

        for lora_info in loras_data.values():
            if not isinstance(lora_info, dict):
                continue

            # Standardized format: each node carries a list of LoRA dicts.
            for lora in lora_info.get("lora_list", []):
                if not isinstance(lora, dict):
                    continue

                lora_name = lora.get("name")
                if not lora_name:
                    continue

                lora_hash = lora_scanner.get_hash_by_filename(lora_name)
                if lora_hash:
                    self._increment_usage_counter("loras", lora_hash, today_date)
                else:
                    logger.warning(f"No hash found for LoRA '{lora_name}', skipping usage tracking")
    except Exception as e:
        logger.error(f"Error processing LoRA usage: {e}", exc_info=True)
@staticmethod
def _extract_embedding_names(prompt_text: str) -> set:
"""Parse embedding:name references from prompt text.
ComfyUI's SDTokenizer resolves ``embedding:<name>`` during tokenization
(see ``sd1_clip.py _try_get_embedding``). This mirrors the same pattern
to extract embedding file names from the captured prompt strings.
"""
if not prompt_text:
return set()
# Matches ``embedding:name`` where name is alphanumeric plus _ . - /
names = re.findall(r"embedding:([a-zA-Z0-9_.\-/]+)", prompt_text)
return set(names)
async def _process_embeddings(self, prompts_data, today_date):
    """Extract embedding usage from prompt texts and record it.

    Collects the ``embedding:<name>`` references found in the "text",
    "positive_text" and "negative_text" fields of every prompt node,
    de-duplicated across all nodes. Each name is resolved to a hash through
    the embedding scanner's ``get_hash_by_filename`` and counted once for
    ``today_date``. Names without a hash are skipped (logged at debug
    level). Errors are logged, never raised.
    """
    try:
        embedding_scanner = await ServiceRegistry.get_embedding_scanner()
        if not embedding_scanner:
            logger.warning("Embedding scanner not available for usage tracking")
            return

        text_fields = ("text", "positive_text", "negative_text")
        referenced_names = set()
        for prompt_data in prompts_data.values():
            if not isinstance(prompt_data, dict):
                continue
            for field_name in text_fields:
                text = prompt_data.get(field_name)
                if isinstance(text, str):
                    referenced_names |= self._extract_embedding_names(text)

        for emb_name in referenced_names:
            emb_hash = embedding_scanner.get_hash_by_filename(emb_name)
            if not emb_hash:
                logger.debug(
                    "No hash found for embedding '%s', skipping usage tracking",
                    emb_name,
                )
                continue
            self._increment_usage_counter("embeddings", emb_hash, today_date)
    except Exception as e:
        logger.error("Error processing embedding usage: %s", e, exc_info=True)
async def get_stats(self):
"""Get current usage statistics"""
return self.stats
async def get_model_usage_count(self, model_type, sha256):
"""Get usage count for a specific model by hash"""
if model_type == "checkpoint":
if sha256 in self.stats["checkpoints"]:
return self.stats["checkpoints"][sha256]["total"]
elif model_type == "lora":
if sha256 in self.stats["loras"]:
return self.stats["loras"][sha256]["total"]
elif model_type == "embedding":
if sha256 in self.stats["embeddings"]:
return self.stats["embeddings"][sha256]["total"]
return 0
async def process_execution(self, prompt_id):
"""Process a prompt execution immediately (synchronous approach)"""
if not prompt_id:
return
if standalone_mode:
# Usage statistics are not available in standalone mode
return
try:
# Process metadata for this prompt_id
registry = MetadataRegistry()
metadata = registry.get_metadata(prompt_id)
if metadata:
await self._process_metadata(metadata)
# Save stats if needed
await self.save_stats()
except Exception as e:
logger.error(f"Error processing prompt_id {prompt_id}: {e}", exc_info=True)