import os
import json
import time
import asyncio
import logging
import datetime
import shutil
from typing import Dict, Set

from ..config import config
from ..services.service_registry import ServiceRegistry

# Check if running in standalone mode
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1"

# Define constants locally to avoid dependency on conditional imports
MODELS = "models"
LORAS = "loras"

if not standalone_mode:
    from ..metadata_collector.metadata_registry import MetadataRegistry

# Import constants from metadata_collector to ensure consistency,
# falling back to the local definitions above
try:
    from ..metadata_collector.constants import MODELS as _MODELS, LORAS as _LORAS
    MODELS = _MODELS
    LORAS = _LORAS
except ImportError:
    pass  # Use the local definitions

logger = logging.getLogger(__name__)


class UsageStats:
    """Track usage statistics for models and save them to JSON"""

    _instance = None
    _lock = asyncio.Lock()  # Guards stats access across coroutines

    # Default stats file name
    STATS_FILENAME = "lora_manager_stats.json"
    BACKUP_SUFFIX = ".backup"

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        if self._initialized:
            return

        # Initialize stats storage
        self.stats = {
            "checkpoints": {},  # sha256 -> { total: count, history: { date: count } }
            "loras": {},        # sha256 -> { total: count, history: { date: count } }
            "total_executions": 0,
            "last_save_time": 0
        }

        # Track if stats have been modified since last save
        self._is_dirty = False

        # Queue of prompt_ids waiting to be processed
        self.pending_prompt_ids: Set[str] = set()

        # Load existing stats if available
        self._stats_file_path = self._get_stats_file_path()
        self._load_stats()

        # Save interval in seconds
        self.save_interval = 90  # 1.5 minutes

        # Start background task to process queued prompt_ids
        self._bg_task = asyncio.create_task(self._background_processor())

        self._initialized = True
        logger.debug("Usage statistics tracker initialized")

    def _get_stats_file_path(self) -> str:
        """Get the path to the stats JSON file"""
        if not config.loras_roots:
            # Without a lora root there is nowhere to persist stats;
            # the caller is expected to handle this error
            raise RuntimeError(
                "No LoRA root directories configured. Cannot initialize usage statistics."
            )
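    # Example (illustrative): with config.loras_roots == ["/data/loras", "/mnt/loras"],
    # stats would be persisted to "/data/loras/lora_manager_stats.json".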
        # Use the first lora root
        return os.path.join(config.loras_roots[0], self.STATS_FILENAME)

    def _backup_old_stats(self) -> bool:
        """Backup the old stats file before conversion"""
        if not os.path.exists(self._stats_file_path):
            return False

        backup_path = f"{self._stats_file_path}{self.BACKUP_SUFFIX}"
        try:
            shutil.copy2(self._stats_file_path, backup_path)
            logger.info(f"Backed up old stats file to {backup_path}")
            return True
        except Exception as e:
            logger.error(f"Failed to backup stats file: {e}")
            return False

    def _convert_old_format(self, old_stats):
        """Convert old stats format to new format with history"""
        new_stats = {
            "checkpoints": {},
            "loras": {},
            "total_executions": old_stats.get("total_executions", 0),
            "last_save_time": old_stats.get("last_save_time", time.time())
        }

        # Get today's date in YYYY-MM-DD format
        today = datetime.datetime.now().strftime("%Y-%m-%d")

        # Convert checkpoint stats
        if "checkpoints" in old_stats and isinstance(old_stats["checkpoints"], dict):
            for hash_id, count in old_stats["checkpoints"].items():
                new_stats["checkpoints"][hash_id] = {
                    "total": count,
                    "history": {today: count}
                }

        # Convert lora stats
        if "loras" in old_stats and isinstance(old_stats["loras"], dict):
            for hash_id, count in old_stats["loras"].items():
                new_stats["loras"][hash_id] = {
                    "total": count,
                    "history": {today: count}
                }

        logger.info("Successfully converted stats from old format to new format with history")
        return new_stats

    def _is_old_format(self, stats) -> bool:
        """Check if the stats are in the old format (direct count values)"""
        # Old format stores a bare number per hash instead of an object
        for category in ("loras", "checkpoints"):
            entries = stats.get(category)
            if isinstance(entries, dict):
                for data in entries.values():
                    if isinstance(data, (int, float)):
                        return True
        return False

    def _load_stats(self):
        """Load existing statistics from file"""
        try:
            if os.path.exists(self._stats_file_path):
                with open(self._stats_file_path, 'r', encoding='utf-8') as f:
                    loaded_stats = json.load(f)

                # Check if old format and needs conversion
                if self._is_old_format(loaded_stats):
                    logger.info("Detected old stats format, performing conversion")
                    self._backup_old_stats()
                    self.stats = self._convert_old_format(loaded_stats)
                elif isinstance(loaded_stats, dict):
                    # Already in the new format; update individual sections
                    # to maintain the expected structure
                    if isinstance(loaded_stats.get("checkpoints"), dict):
                        self.stats["checkpoints"] = loaded_stats["checkpoints"]
                    if isinstance(loaded_stats.get("loras"), dict):
                        self.stats["loras"] = loaded_stats["loras"]
                    if "total_executions" in loaded_stats:
                        self.stats["total_executions"] = loaded_stats["total_executions"]
                    if "last_save_time" in loaded_stats:
                        self.stats["last_save_time"] = loaded_stats["last_save_time"]

                logger.debug(f"Loaded usage statistics from {self._stats_file_path}")
        except Exception as e:
            logger.error(f"Error loading usage statistics: {e}")
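    # Conversion example (illustrative): an old-format entry such as
    #   {"loras": {"<sha256>": 5}}
    # is rewritten on load as
    #   {"loras": {"<sha256>": {"total": 5, "history": {"2024-01-15": 5}}}}
    # with the entire historical count attributed to the conversion date.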
    async def save_stats(self, force=False):
        """Save statistics to file"""
        try:
            # Only save if:
            #   1. force is True, OR
            #   2. stats have been modified (_is_dirty) AND save_interval has passed
            current_time = time.time()
            time_since_last_save = current_time - self.stats.get("last_save_time", 0)

            if not force:
                if not self._is_dirty:
                    # No changes to save
                    return False
                if time_since_last_save < self.save_interval:
                    # Too soon since last save
                    return False

            # Use a lock to prevent concurrent writes
            async with self._lock:
                # Update last save time
                self.stats["last_save_time"] = current_time

                # Create directory if it doesn't exist
                os.makedirs(os.path.dirname(self._stats_file_path), exist_ok=True)

                # Write to a temporary file first, then move it to avoid corruption
                temp_path = f"{self._stats_file_path}.tmp"
                with open(temp_path, 'w', encoding='utf-8') as f:
                    json.dump(self.stats, f, indent=2, ensure_ascii=False)

                # Replace the old file with the new one (atomic on most platforms)
                os.replace(temp_path, self._stats_file_path)

                # Clear dirty flag since we've saved
                self._is_dirty = False

            logger.debug(f"Saved usage statistics to {self._stats_file_path}")
            return True
        except Exception as e:
            logger.error(f"Error saving usage statistics: {e}", exc_info=True)
            return False

    def register_execution(self, prompt_id):
        """Register a completed execution by prompt_id for later processing"""
        if prompt_id:
            self.pending_prompt_ids.add(prompt_id)

    async def _background_processor(self):
        """Background task to process queued prompt_ids"""
        try:
            while True:
                # Wait a short interval before checking for new prompt_ids
                await asyncio.sleep(5)  # Check every 5 seconds

                # Process any pending prompt_ids
                if self.pending_prompt_ids:
                    async with self._lock:
                        # Copy the set and clear the original under the lock
                        prompt_ids = self.pending_prompt_ids.copy()
                        self.pending_prompt_ids.clear()

                    # Process each prompt_id
                    try:
                        registry = MetadataRegistry()
                    except NameError:
                        # MetadataRegistry not available (standalone mode)
                        registry = None

                    if registry:
                        for prompt_id in prompt_ids:
                            try:
                                metadata = registry.get_metadata(prompt_id)
                                await self._process_metadata(metadata)
                            except Exception as e:
                                logger.error(f"Error processing prompt_id {prompt_id}: {e}")

                # Periodically save stats (only if there are changes)
                if self._is_dirty:
                    await self.save_stats()
        except asyncio.CancelledError:
            # Task was cancelled; flush any unsaved stats before exiting
            await self.save_stats(force=True)
        except Exception as e:
            logger.error(f"Error in background processing task: {e}", exc_info=True)
            # Restart the task after a delay if it fails
            asyncio.create_task(self._restart_background_task())

    async def _restart_background_task(self):
        """Restart the background task after a delay"""
        await asyncio.sleep(30)  # Wait 30 seconds before restarting
        self._bg_task = asyncio.create_task(self._background_processor())

    async def _process_metadata(self, metadata):
        """Process metadata from an execution"""
        if not metadata or not isinstance(metadata, dict):
            return

        # Increment total executions count
        self.stats["total_executions"] += 1
        self._is_dirty = True

        # Get today's date in YYYY-MM-DD format
        today = datetime.datetime.now().strftime("%Y-%m-%d")

        # Process checkpoints
        if MODELS in metadata and isinstance(metadata[MODELS], dict):
            await self._process_checkpoints(metadata[MODELS], today)

        # Process loras
        if LORAS in metadata and isinstance(metadata[LORAS], dict):
            await self._process_loras(metadata[LORAS], today)

    def _increment_usage_counter(self, category: str, stat_key: str, today_date: str) -> None:
        """Increment usage counters for a resolved stats key."""
        if stat_key not in self.stats[category]:
            self.stats[category][stat_key] = {
                "total": 0,
                "history": {}
            }

        self.stats[category][stat_key]["total"] += 1

        if today_date not in self.stats[category][stat_key]["history"]:
            self.stats[category][stat_key]["history"][today_date] = 0
        self.stats[category][stat_key]["history"][today_date] += 1
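    # Counter example (illustrative): two executions using the same checkpoint
    # on 2024-01-15 produce
    #   stats["checkpoints"]["<sha256>"] == {"total": 2, "history": {"2024-01-15": 2}}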
    def _normalize_model_lookup_name(self, model_name: str) -> str:
        """Normalize a model reference to its base filename without extension."""
        return os.path.splitext(os.path.basename(model_name))[0]

    async def _find_cached_checkpoint_entry(self, checkpoint_scanner, model_name: str):
        """Best-effort lookup for a checkpoint cache entry by filename/model name."""
        get_cached_data = getattr(checkpoint_scanner, "get_cached_data", None)
        if not callable(get_cached_data):
            return None

        cache = await get_cached_data()
        raw_data = getattr(cache, "raw_data", None)
        if not isinstance(raw_data, list):
            return None

        normalized_name = self._normalize_model_lookup_name(model_name)
        for entry in raw_data:
            if not isinstance(entry, dict):
                continue
            for candidate_key in ("file_name", "model_name", "file_path"):
                candidate_value = entry.get(candidate_key)
                if not candidate_value or not isinstance(candidate_value, str):
                    continue
                if self._normalize_model_lookup_name(candidate_value) == normalized_name:
                    return entry
        return None

    async def _resolve_checkpoint_hash(self, checkpoint_scanner, model_name: str):
        """Resolve a checkpoint hash, calculating pending hashes on demand when needed."""
        model_filename = self._normalize_model_lookup_name(model_name)
        model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
        if model_hash:
            return model_hash

        cached_entry = await self._find_cached_checkpoint_entry(checkpoint_scanner, model_name)
        if not cached_entry:
            logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
            return None

        cached_hash = cached_entry.get("sha256")
        if cached_hash:
            return cached_hash

        if cached_entry.get("hash_status") == "pending":
            calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None)
            file_path = cached_entry.get("file_path")
            if callable(calculate_hash) and file_path:
                calculated_hash = await calculate_hash(file_path)
                if calculated_hash:
                    return calculated_hash
            logger.warning(
                f"Failed to calculate pending hash for checkpoint '{model_filename}', skipping usage tracking"
            )
            return None

        logger.warning(f"No hash found for checkpoint '{model_filename}', skipping usage tracking")
        return None
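    # Hash resolution order used above (summary):
    #   1. checkpoint_scanner.get_hash_by_filename(...)   -> fast path
    #   2. scan of the scanner's cached raw_data entries  -> entry["sha256"]
    #   3. calculate_hash_for_model(file_path)            -> for "pending" entries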
    async def _process_checkpoints(self, models_data, today_date):
        """Process checkpoint models from metadata"""
        try:
            # Get checkpoint scanner service
            checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
            if not checkpoint_scanner:
                logger.warning("Checkpoint scanner not available for usage tracking")
                return

            for node_id, model_info in models_data.items():
                if not isinstance(model_info, dict):
                    continue

                # Check if this is a checkpoint model
                if model_info.get("type") != "checkpoint":
                    continue

                model_name = model_info.get("name")
                if not model_name:
                    continue

                model_hash = await self._resolve_checkpoint_hash(checkpoint_scanner, model_name)
                if not model_hash:
                    continue

                self._increment_usage_counter("checkpoints", model_hash, today_date)
        except Exception as e:
            logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)

    async def _process_loras(self, loras_data, today_date):
        """Process LoRA models from metadata"""
        try:
            # Get LoRA scanner service
            lora_scanner = await ServiceRegistry.get_lora_scanner()
            if not lora_scanner:
                logger.warning("LoRA scanner not available for usage tracking")
                return

            for node_id, lora_info in loras_data.items():
                if not isinstance(lora_info, dict):
                    continue

                # Get the list of LoRAs from the standardized format
                lora_list = lora_info.get("lora_list", [])
                for lora in lora_list:
                    if not isinstance(lora, dict):
                        continue

                    lora_name = lora.get("name")
                    if not lora_name:
                        continue

                    # Get hash for this LoRA
                    lora_hash = lora_scanner.get_hash_by_filename(lora_name)
                    if not lora_hash:
                        logger.warning(f"No hash found for LoRA '{lora_name}', skipping usage tracking")
                        continue

                    self._increment_usage_counter("loras", lora_hash, today_date)
        except Exception as e:
            logger.error(f"Error processing LoRA usage: {e}", exc_info=True)

    async def get_stats(self):
        """Get current usage statistics"""
        return self.stats

    async def get_model_usage_count(self, model_type, sha256):
        """Get usage count for a specific model by hash"""
        if model_type == "checkpoint":
            if sha256 in self.stats["checkpoints"]:
                return self.stats["checkpoints"][sha256]["total"]
        elif model_type == "lora":
            if sha256 in self.stats["loras"]:
                return self.stats["loras"][sha256]["total"]
        return 0

    async def process_execution(self, prompt_id):
        """Process a prompt execution immediately (synchronous approach)"""
        if not prompt_id:
            return

        if standalone_mode:
            # Usage statistics are not available in standalone mode
            return

        try:
            # Process metadata for this prompt_id
            registry = MetadataRegistry()
            metadata = registry.get_metadata(prompt_id)
            if metadata:
                await self._process_metadata(metadata)

                # Save stats if needed
                await self.save_stats()
        except Exception as e:
            logger.error(f"Error processing prompt_id {prompt_id}: {e}", exc_info=True)
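# Usage sketch (illustrative only; assumes a running event loop and at least one
# configured LoRA root, since __init__ starts a background task and resolves the
# stats path). A hypothetical execution-complete hook might look like:
#
#   stats = UsageStats()                   # singleton: always the same instance
#   stats.register_execution(prompt_id)    # queued; handled by the background task
#   # or, to record and persist immediately:
#   await stats.process_execution(prompt_id)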