import os
import json
import time
import asyncio
import logging
from typing import Dict, Set

from ..config import config
from ..services.service_registry import ServiceRegistry
from ..metadata_collector.metadata_registry import MetadataRegistry
from ..metadata_collector.constants import MODELS, LORAS

logger = logging.getLogger(__name__)


class UsageStats:
    """Track usage statistics for models and save them to a JSON file.

    The class is a process-wide singleton: ``UsageStats()`` always returns
    the same instance and ``__init__`` only runs its body once.

    Constructing the first instance calls ``asyncio.create_task``, so it
    must happen while an event loop is running.
    """

    _instance = None

    # Serializes coroutines that swap the pending queue or write the stats
    # file. NOTE: asyncio.Lock coordinates coroutines on one event loop; it
    # does not provide safety across OS threads.
    _lock = asyncio.Lock()

    # Default stats file name
    STATS_FILENAME = "lora_manager_stats.json"

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Set up storage, load persisted stats and start the background task.

        Subsequent constructions of the singleton return immediately.
        """
        if self._initialized:
            return

        # Initialize stats storage
        self.stats = {
            "checkpoints": {},  # sha256 -> count
            "loras": {},  # sha256 -> count
            "total_executions": 0,
            "last_save_time": 0,
        }

        # Queue for prompt_ids to process
        self.pending_prompt_ids = set()

        # Load existing stats if available
        self._stats_file_path = self._get_stats_file_path()
        self._load_stats()

        # Save interval in seconds
        self.save_interval = 90  # 1.5 minutes

        # Reference to a scheduled restart of the background task (if any),
        # kept so the pending task cannot be garbage-collected.
        self._restart_task = None

        # Start background task to process queued prompt_ids
        self._bg_task = asyncio.create_task(self._background_processor())

        self._initialized = True
        logger.info("Usage statistics tracker initialized")

    def _get_stats_file_path(self) -> str:
        """Return the path of the stats JSON file.

        The file lives in the first configured LoRA root, or in the
        configured temp directory when no LoRA root is configured.
        """
        if not config.loras_roots:
            # Fallback to temporary directory if no lora roots
            return os.path.join(config.temp_directory, self.STATS_FILENAME)

        # Use the first lora root
        return os.path.join(config.loras_roots[0], self.STATS_FILENAME)

    def _load_stats(self) -> None:
        """Load existing statistics from the stats file, if it exists.

        Only well-formed sections are copied into ``self.stats``; anything
        else keeps its default so later updates cannot fail on a corrupted
        file. Errors are logged and never raised.
        """
        try:
            if not os.path.exists(self._stats_file_path):
                return

            with open(self._stats_file_path, 'r', encoding='utf-8') as f:
                loaded_stats = json.load(f)

            # Update our stats with loaded data
            if isinstance(loaded_stats, dict):
                # Update individual sections to maintain structure
                for section in ("checkpoints", "loras"):
                    section_data = loaded_stats.get(section)
                    if isinstance(section_data, dict):
                        self.stats[section] = section_data

                if "total_executions" in loaded_stats:
                    total = loaded_stats["total_executions"]
                    # A non-integer here would make every later
                    # `total_executions += 1` raise, so reject it up front.
                    # bool is excluded explicitly because it subclasses int.
                    if isinstance(total, int) and not isinstance(total, bool):
                        self.stats["total_executions"] = total
                    else:
                        logger.warning(
                            "Ignoring invalid total_executions value in "
                            f"{self._stats_file_path}: {total!r}"
                        )

                logger.info(f"Loaded usage statistics from {self._stats_file_path}")
        except Exception as e:
            logger.error(f"Error loading usage statistics: {e}")

    def _write_stats_file(self) -> None:
        """Write ``self.stats`` to the stats file via a temp file + replace.

        Raises whatever the underlying file operations raise; the temp file
        is removed (best effort) when the write or the replace fails.
        """
        # Create directory if it doesn't exist. A path without a directory
        # component has an empty dirname, which os.makedirs rejects.
        directory = os.path.dirname(self._stats_file_path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        # Write to a temporary file first, then move it to avoid corruption
        temp_path = f"{self._stats_file_path}.tmp"
        try:
            with open(temp_path, 'w', encoding='utf-8') as f:
                json.dump(self.stats, f, indent=2, ensure_ascii=False)

            # Replace the old file with the new one
            os.replace(temp_path, self._stats_file_path)
        except Exception:
            # Best-effort cleanup so a failed save does not leave a stray
            # .tmp file behind; the original error is what matters.
            try:
                os.remove(temp_path)
            except OSError:
                pass
            raise

    async def save_stats(self, force=False):
        """Save statistics to file.

        Args:
            force: When true, save even if ``save_interval`` seconds have
                not elapsed since the last save.

        Returns:
            True when the file was written, False when the save was skipped
            or failed (failures are logged, not raised).
        """
        try:
            # Only save if it's been at least save_interval since last save
            # or force is True
            current_time = time.time()
            elapsed = current_time - self.stats.get("last_save_time", 0)
            if not force and elapsed < self.save_interval:
                return False

            # Use a lock to prevent concurrent writes
            async with self._lock:
                # Update last save time (it is persisted with the stats)
                self.stats["last_save_time"] = current_time
                self._write_stats_file()

            logger.debug(f"Saved usage statistics to {self._stats_file_path}")
            return True
        except Exception as e:
            logger.error(f"Error saving usage statistics: {e}", exc_info=True)
            return False

    def register_execution(self, prompt_id):
        """Register a completed execution by prompt_id for later processing.

        Falsy prompt_ids are ignored. Queued ids are consumed by the
        background task.
        """
        if prompt_id:
            self.pending_prompt_ids.add(prompt_id)

    async def _background_processor(self):
        """Background task to process queued prompt_ids.

        Every 5 seconds: drain the pending queue, process each prompt's
        metadata, then save the stats (subject to ``save_interval``).
        """
        try:
            while True:
                # Wait a short interval before checking for new prompt_ids
                await asyncio.sleep(5)  # Check every 5 seconds

                # Process any pending prompt_ids
                if self.pending_prompt_ids:
                    async with self._lock:
                        # Get a copy of the set and clear original
                        prompt_ids = self.pending_prompt_ids.copy()
                        self.pending_prompt_ids.clear()

                    # Process each prompt_id (outside the lock: save_stats
                    # acquires the same, non-reentrant lock)
                    registry = MetadataRegistry()
                    for prompt_id in prompt_ids:
                        try:
                            metadata = registry.get_metadata(prompt_id)
                            await self._process_metadata(metadata)
                        except Exception as e:
                            logger.error(f"Error processing prompt_id {prompt_id}: {e}")

                # Periodically save stats
                await self.save_stats()
        except asyncio.CancelledError:
            # Task was cancelled, clean up
            # NOTE(review): the cancellation is not re-raised, so the task
            # finishes normally instead of as cancelled; asyncio recommends
            # re-raising. Left as-is in case shutdown code awaits this task.
            await self.save_stats(force=True)
        except Exception as e:
            logger.error(f"Error in background processing task: {e}", exc_info=True)
            # Restart the task after a delay if it fails. Keep a reference:
            # the event loop only holds weak references to tasks, so an
            # unreferenced task may be garbage-collected before it runs.
            self._restart_task = asyncio.create_task(self._restart_background_task())

    async def _restart_background_task(self):
        """Restart the background task after a delay."""
        await asyncio.sleep(30)  # Wait 30 seconds before restarting
        self._bg_task = asyncio.create_task(self._background_processor())

    async def _process_metadata(self, metadata):
        """Process metadata from an execution.

        Non-dict or empty metadata is ignored. Otherwise the total execution
        count is incremented and the ``MODELS`` / ``LORAS`` sections (when
        present and dicts) are counted.
        """
        if not metadata or not isinstance(metadata, dict):
            return

        # Increment total executions count
        self.stats["total_executions"] += 1

        # Process checkpoints
        if MODELS in metadata and isinstance(metadata[MODELS], dict):
            await self._process_checkpoints(metadata[MODELS])

        # Process loras
        if LORAS in metadata and isinstance(metadata[LORAS], dict):
            await self._process_loras(metadata[LORAS])

    async def _process_checkpoints(self, models_data):
        """Process checkpoint models from metadata.

        For every dict entry whose ``"type"`` is ``"checkpoint"``, resolve
        its hash through the checkpoint scanner and increment that hash's
        count. Errors are logged and never raised.
        """
        try:
            # Get checkpoint scanner service
            checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
            if not checkpoint_scanner:
                logger.warning("Checkpoint scanner not available for usage tracking")
                return

            for node_id, model_info in models_data.items():
                if not isinstance(model_info, dict):
                    continue

                # Check if this is a checkpoint model
                if model_info.get("type") != "checkpoint":
                    continue

                model_name = model_info.get("name")
                if not model_name:
                    continue

                # Clean up filename (drop directories and the extension)
                model_filename = os.path.splitext(os.path.basename(model_name))[0]

                # Get hash for this checkpoint
                model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
                if model_hash:
                    # Update stats for this checkpoint
                    self.stats["checkpoints"][model_hash] = (
                        self.stats["checkpoints"].get(model_hash, 0) + 1
                    )
        except Exception as e:
            logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)

    async def _process_loras(self, loras_data):
        """Process LoRA models from metadata.

        For every dict entry, walk its ``"lora_list"``, resolve each named
        LoRA's hash through the LoRA scanner and increment that hash's
        count. Errors are logged and never raised.
        """
        try:
            # Get LoRA scanner service
            lora_scanner = await ServiceRegistry.get_lora_scanner()
            if not lora_scanner:
                logger.warning("LoRA scanner not available for usage tracking")
                return

            for node_id, lora_info in loras_data.items():
                if not isinstance(lora_info, dict):
                    continue

                # Get the list of LoRAs from standardized format. An explicit
                # None is treated as empty so one bad node does not abort
                # counting for the remaining nodes.
                lora_list = lora_info.get("lora_list") or []
                for lora in lora_list:
                    if not isinstance(lora, dict):
                        continue

                    lora_name = lora.get("name")
                    if not lora_name:
                        continue

                    # Get hash for this LoRA
                    lora_hash = lora_scanner.get_hash_by_filename(lora_name)
                    if lora_hash:
                        # Update stats for this LoRA
                        self.stats["loras"][lora_hash] = (
                            self.stats["loras"].get(lora_hash, 0) + 1
                        )
        except Exception as e:
            logger.error(f"Error processing LoRA usage: {e}", exc_info=True)

    async def get_stats(self):
        """Get current usage statistics.

        Returns the live ``self.stats`` dict, not a copy.
        """
        return self.stats

    async def get_model_usage_count(self, model_type, sha256):
        """Get usage count for a specific model by hash.

        Args:
            model_type: ``"checkpoint"`` or ``"lora"``.
            sha256: Hash key to look up in the matching stats section.

        Returns:
            The stored count, or 0 when the hash is unknown or the model
            type is not recognized.
        """
        if model_type == "checkpoint":
            return self.stats["checkpoints"].get(sha256, 0)
        elif model_type == "lora":
            return self.stats["loras"].get(sha256, 0)
        return 0

    async def process_execution(self, prompt_id):
        """Process a prompt execution immediately instead of queueing it.

        Falsy prompt_ids are ignored. After processing, the stats are saved
        subject to ``save_interval``. Errors are logged and never raised.
        """
        if not prompt_id:
            return

        try:
            # Process metadata for this prompt_id
            registry = MetadataRegistry()
            metadata = registry.get_metadata(prompt_id)
            if metadata:
                await self._process_metadata(metadata)
                # Save stats if needed
                await self.save_stats()
        except Exception as e:
            logger.error(f"Error processing prompt_id {prompt_id}: {e}", exc_info=True)