diff --git a/py/lora_manager.py b/py/lora_manager.py
index 2b04da36..08b254a4 100644
--- a/py/lora_manager.py
+++ b/py/lora_manager.py
@@ -5,6 +5,8 @@ from .routes.lora_routes import LoraRoutes
 from .routes.api_routes import ApiRoutes
 from .routes.recipe_routes import RecipeRoutes
 from .routes.checkpoints_routes import CheckpointsRoutes
+from .routes.update_routes import UpdateRoutes
+from .routes.usage_stats_routes import UsageStatsRoutes
 from .services.service_registry import ServiceRegistry
 import logging
 
@@ -92,6 +94,8 @@ class LoraManager:
         checkpoints_routes.setup_routes(app)
         ApiRoutes.setup_routes(app)
         RecipeRoutes.setup_routes(app)
+        UpdateRoutes.setup_routes(app)
+        UsageStatsRoutes.setup_routes(app)  # Register usage stats routes
 
         # Schedule service initialization
         app.on_startup.append(lambda app: cls._initialize_services())
diff --git a/py/metadata_collector/constants.py b/py/metadata_collector/constants.py
index c1109580..9a3ba95f 100644
--- a/py/metadata_collector/constants.py
+++ b/py/metadata_collector/constants.py
@@ -1,12 +1,14 @@
 """Constants used by the metadata collector"""
 
-# Individual category constants
+# Metadata collection constants
+
+# Metadata categories
 MODELS = "models"
 PROMPTS = "prompts"
 SAMPLING = "sampling"
 LORAS = "loras"
 SIZE = "size"
-IMAGES = "images"  # Added new category for image results
+IMAGES = "images"
 
-# Collection of categories for iteration
-METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]  # Added IMAGES to categories
+# Complete list of categories to track
+METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]
diff --git a/py/routes/usage_stats_routes.py b/py/routes/usage_stats_routes.py
new file mode 100644
index 00000000..0d162681
--- /dev/null
+++ b/py/routes/usage_stats_routes.py
@@ -0,0 +1,49 @@
+import logging
+from aiohttp import web
+from ..utils.usage_stats import UsageStats
+
+logger = logging.getLogger(__name__)
+
+class UsageStatsRoutes:
+    """Routes for handling usage statistics updates"""
+
+    @staticmethod
+    def setup_routes(app):
+        """Register usage stats routes"""
+        app.router.add_post('/loras/api/update-usage-stats', UsageStatsRoutes.update_usage_stats)
+
+    @staticmethod
+    async def update_usage_stats(request):
+        """
+        Update usage statistics based on a prompt_id
+
+        Expects a JSON body with:
+        {
+            "prompt_id": "string"
+        }
+        """
+        try:
+            # Parse the request body
+            data = await request.json()
+            prompt_id = data.get('prompt_id')
+
+            if not prompt_id:
+                return web.json_response({
+                    'success': False,
+                    'error': 'Missing prompt_id'
+                }, status=400)
+
+            # Call the UsageStats to process this prompt_id synchronously
+            usage_stats = UsageStats()
+            await usage_stats.process_execution(prompt_id)
+
+            return web.json_response({
+                'success': True
+            })
+
+        except Exception as e:
+            logger.error(f"Failed to update usage stats: {e}", exc_info=True)
+            return web.json_response({
+                'success': False,
+                'error': str(e)
+            }, status=500)
diff --git a/py/utils/usage_stats.py b/py/utils/usage_stats.py
new file mode 100644
index 00000000..1d365120
--- /dev/null
+++ b/py/utils/usage_stats.py
@@ -0,0 +1,273 @@
+import os
+import json
+import time
+import asyncio
+import logging
+from typing import Dict, Set
+
+from ..config import config
+from ..services.service_registry import ServiceRegistry
+from ..metadata_collector.metadata_registry import MetadataRegistry
+from ..metadata_collector.constants import MODELS, LORAS
+
+logger = logging.getLogger(__name__)
+
+class UsageStats:
+    """Track usage statistics for models and save to JSON"""
+
+    _instance = None
+    _lock = asyncio.Lock()  # Guards stats mutation and file writes across asyncio tasks
+
+    # Default stats file name
+    STATS_FILENAME = "lora_manager_stats.json"
+
+    def __new__(cls):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+            cls._instance._initialized = False
+        return cls._instance
+
+    def __init__(self):
+        if self._initialized:
+            return
+
+        # Initialize stats storage
+        self.stats = {
+            "checkpoints": {},  # sha256 -> count
+            "loras": {},  # sha256 -> count
+            "total_executions": 0,
+            "last_save_time": 0
+        }
+
+        # Queue for prompt_ids to process
+        self.pending_prompt_ids = set()
+
+        # Load existing stats if available
+        self._stats_file_path = self._get_stats_file_path()
+        self._load_stats()
+
+        # Save interval in seconds
+        self.save_interval = 90  # 1.5 minutes
+
+        # Start background task to process queued prompt_ids
+        try:
+            self._bg_task = asyncio.create_task(self._background_processor())
+        except RuntimeError:
+            # No running event loop (e.g. constructed from a synchronous
+            # context); stats can still be updated via process_execution().
+            self._bg_task = None
+            logger.warning("No running event loop; usage stats background processor not started")
+
+        self._initialized = True
+        logger.info("Usage statistics tracker initialized")
+
+    def _get_stats_file_path(self) -> str:
+        """Get the path to the stats JSON file"""
+        if not config.loras_roots or len(config.loras_roots) == 0:
+            # Fallback to temporary directory if no lora roots
+            return os.path.join(config.temp_directory, self.STATS_FILENAME)
+
+        # Use the first lora root
+        return os.path.join(config.loras_roots[0], self.STATS_FILENAME)
+
+    def _load_stats(self):
+        """Load existing statistics from file"""
+        try:
+            if os.path.exists(self._stats_file_path):
+                with open(self._stats_file_path, 'r', encoding='utf-8') as f:
+                    loaded_stats = json.load(f)
+
+                # Update our stats with loaded data
+                if isinstance(loaded_stats, dict):
+                    # Update individual sections to maintain structure
+                    if "checkpoints" in loaded_stats and isinstance(loaded_stats["checkpoints"], dict):
+                        self.stats["checkpoints"] = loaded_stats["checkpoints"]
+
+                    if "loras" in loaded_stats and isinstance(loaded_stats["loras"], dict):
+                        self.stats["loras"] = loaded_stats["loras"]
+
+                    if "total_executions" in loaded_stats:
+                        self.stats["total_executions"] = loaded_stats["total_executions"]
+
+                logger.info(f"Loaded usage statistics from {self._stats_file_path}")
+        except Exception as e:
+            logger.error(f"Error loading usage statistics: {e}")
+
+    async def save_stats(self, force=False):
+        """Save statistics to file"""
+        try:
+            # Only save if it's been at least save_interval since last save or force is True
+            current_time = time.time()
+            if not force and (current_time - self.stats.get("last_save_time", 0)) < self.save_interval:
+                return False
+
+            # Use a lock to prevent concurrent writes
+            async with self._lock:
+                # Update last save time
+                self.stats["last_save_time"] = current_time
+
+                # Create directory if it doesn't exist
+                os.makedirs(os.path.dirname(self._stats_file_path), exist_ok=True)
+
+                # Write to a temporary file first, then move it to avoid corruption
+                temp_path = f"{self._stats_file_path}.tmp"
+                with open(temp_path, 'w', encoding='utf-8') as f:
+                    json.dump(self.stats, f, indent=2, ensure_ascii=False)
+
+                # Replace the old file with the new one
+                os.replace(temp_path, self._stats_file_path)
+
+            logger.debug(f"Saved usage statistics to {self._stats_file_path}")
+            return True
+        except Exception as e:
+            logger.error(f"Error saving usage statistics: {e}", exc_info=True)
+            return False
+
+    def register_execution(self, prompt_id):
+        """Register a completed execution by prompt_id for later processing"""
+        if prompt_id:
+            self.pending_prompt_ids.add(prompt_id)
+
+    async def _background_processor(self):
+        """Background task to process queued prompt_ids"""
+        try:
+            while True:
+                # Wait a short interval before checking for new prompt_ids
+                await asyncio.sleep(5)  # Check every 5 seconds
+
+                # Process any pending prompt_ids
+                if self.pending_prompt_ids:
+                    async with self._lock:
+                        # Get a copy of the set and clear original
+                        prompt_ids = self.pending_prompt_ids.copy()
+                        self.pending_prompt_ids.clear()
+
+                    # Process each prompt_id
+                    registry = MetadataRegistry()
+                    for prompt_id in prompt_ids:
+                        try:
+                            metadata = registry.get_metadata(prompt_id)
+                            await self._process_metadata(metadata)
+                        except Exception as e:
+                            logger.error(f"Error processing prompt_id {prompt_id}: {e}")
+
+                # Periodically save stats
+                await self.save_stats()
+        except asyncio.CancelledError:
+            # Task was cancelled, clean up
+            await self.save_stats(force=True)
+        except Exception as e:
+            logger.error(f"Error in background processing task: {e}", exc_info=True)
+            # Restart the task after a delay if it fails
+            asyncio.create_task(self._restart_background_task())
+
+    async def _restart_background_task(self):
+        """Restart the background task after a delay"""
+        await asyncio.sleep(30)  # Wait 30 seconds before restarting
+        self._bg_task = asyncio.create_task(self._background_processor())
+
+    async def _process_metadata(self, metadata):
+        """Process metadata from an execution"""
+        if not metadata or not isinstance(metadata, dict):
+            return
+
+        # Increment total executions count
+        self.stats["total_executions"] += 1
+
+        # Process checkpoints
+        if MODELS in metadata and isinstance(metadata[MODELS], dict):
+            await self._process_checkpoints(metadata[MODELS])
+
+        # Process loras
+        if LORAS in metadata and isinstance(metadata[LORAS], dict):
+            await self._process_loras(metadata[LORAS])
+
+    async def _process_checkpoints(self, models_data):
+        """Process checkpoint models from metadata"""
+        try:
+            # Get checkpoint scanner service
+            checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
+            if not checkpoint_scanner:
+                logger.warning("Checkpoint scanner not available for usage tracking")
+                return
+
+            for node_id, model_info in models_data.items():
+                if not isinstance(model_info, dict):
+                    continue
+
+                # Check if this is a checkpoint model
+                model_type = model_info.get("type")
+                if model_type == "checkpoint":
+                    model_name = model_info.get("name")
+                    if not model_name:
+                        continue
+
+                    # Clean up filename (remove extension if present)
+                    model_filename = os.path.splitext(os.path.basename(model_name))[0]
+
+                    # Get hash for this checkpoint
+                    model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
+                    if model_hash:
+                        # Update stats for this checkpoint
+                        self.stats["checkpoints"][model_hash] = self.stats["checkpoints"].get(model_hash, 0) + 1
+        except Exception as e:
+            logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)
+
+    async def _process_loras(self, loras_data):
+        """Process LoRA models from metadata"""
+        try:
+            # Get LoRA scanner service
+            lora_scanner = await ServiceRegistry.get_lora_scanner()
+            if not lora_scanner:
+                logger.warning("LoRA scanner not available for usage tracking")
+                return
+
+            for node_id, lora_info in loras_data.items():
+                if not isinstance(lora_info, dict):
+                    continue
+
+                # Get the list of LoRAs from standardized format
+                lora_list = lora_info.get("lora_list", [])
+                for lora in lora_list:
+                    if not isinstance(lora, dict):
+                        continue
+
+                    lora_name = lora.get("name")
+                    if not lora_name:
+                        continue
+
+                    # Get hash for this LoRA
+                    lora_hash = lora_scanner.get_hash_by_filename(lora_name)
+                    if lora_hash:
+                        # Update stats for this LoRA
+                        self.stats["loras"][lora_hash] = self.stats["loras"].get(lora_hash, 0) + 1
+        except Exception as e:
+            logger.error(f"Error processing LoRA usage: {e}", exc_info=True)
+
+    async def get_stats(self):
+        """Get current usage statistics"""
+        return self.stats
+
+    async def get_model_usage_count(self, model_type, sha256):
+        """Get usage count for a specific model by hash"""
+        if model_type == "checkpoint":
+            return self.stats["checkpoints"].get(sha256, 0)
+        elif model_type == "lora":
+            return self.stats["loras"].get(sha256, 0)
+        return 0
+
+    async def process_execution(self, prompt_id):
+        """Process a prompt execution immediately (synchronous approach)"""
+        if not prompt_id:
+            return
+
+        try:
+            # Process metadata for this prompt_id
+            registry = MetadataRegistry()
+            metadata = registry.get_metadata(prompt_id)
+            if metadata:
+                await self._process_metadata(metadata)
+                # Save stats if needed
+                await self.save_stats()
+        except Exception as e:
+            logger.error(f"Error processing prompt_id {prompt_id}: {e}", exc_info=True)
diff --git a/web/comfyui/usage_stats.js b/web/comfyui/usage_stats.js
new file mode 100644
index 00000000..97fcda74
--- /dev/null
+++ b/web/comfyui/usage_stats.js
@@ -0,0 +1,37 @@
+// ComfyUI extension to track model usage statistics
+import { app } from "../../scripts/app.js";
+import { api } from "../../scripts/api.js";
+
+// Register the extension
+app.registerExtension({
+    name: "ComfyUI-Lora-Manager.UsageStats",
+
+    init() {
+        // Listen for successful executions
+        api.addEventListener("execution_success", ({ detail }) => {
+            if (detail && detail.prompt_id) {
+                this.updateUsageStats(detail.prompt_id);
+            }
+        });
+    },
+
+    async updateUsageStats(promptId) {
+        try {
+            console.log("Updating usage statistics for prompt ID:", promptId);
+            // Call backend endpoint with the prompt_id
+            const response = await fetch(`/loras/api/update-usage-stats`, {
+                method: 'POST',
+                headers: {
+                    'Content-Type': 'application/json',
+                },
+                body: JSON.stringify({ prompt_id: promptId }),
+            });
+
+            if (!response.ok) {
+                console.warn("Failed to update usage statistics:", response.statusText);
+            }
+        } catch (error) {
+            console.error("Error updating usage statistics:", error);
+        }
+    }
+});