Mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git, synced 2026-03-25 23:25:43 -03:00.
feat: Implement usage statistics tracking with backend integration and route setup
This commit is contained in:
@@ -5,6 +5,8 @@ from .routes.lora_routes import LoraRoutes
|
|||||||
from .routes.api_routes import ApiRoutes
|
from .routes.api_routes import ApiRoutes
|
||||||
from .routes.recipe_routes import RecipeRoutes
|
from .routes.recipe_routes import RecipeRoutes
|
||||||
from .routes.checkpoints_routes import CheckpointsRoutes
|
from .routes.checkpoints_routes import CheckpointsRoutes
|
||||||
|
from .routes.update_routes import UpdateRoutes
|
||||||
|
from .routes.usage_stats_routes import UsageStatsRoutes
|
||||||
from .services.service_registry import ServiceRegistry
|
from .services.service_registry import ServiceRegistry
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -92,6 +94,8 @@ class LoraManager:
|
|||||||
checkpoints_routes.setup_routes(app)
|
checkpoints_routes.setup_routes(app)
|
||||||
ApiRoutes.setup_routes(app)
|
ApiRoutes.setup_routes(app)
|
||||||
RecipeRoutes.setup_routes(app)
|
RecipeRoutes.setup_routes(app)
|
||||||
|
UpdateRoutes.setup_routes(app)
|
||||||
|
UsageStatsRoutes.setup_routes(app) # Register usage stats routes
|
||||||
|
|
||||||
# Schedule service initialization
|
# Schedule service initialization
|
||||||
app.on_startup.append(lambda app: cls._initialize_services())
|
app.on_startup.append(lambda app: cls._initialize_services())
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
"""Constants used by the metadata collector"""
|
"""Constants used by the metadata collector"""
|
||||||
|
|
||||||
# Metadata collection constants

# Metadata categories
MODELS = "models"
PROMPTS = "prompts"
SAMPLING = "sampling"
LORAS = "loras"
SIZE = "size"
IMAGES = "images"

# Complete list of categories to track
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]
|
||||||
|
|||||||
49
py/routes/usage_stats_routes.py
Normal file
49
py/routes/usage_stats_routes.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import logging
|
||||||
|
from aiohttp import web
|
||||||
|
from ..utils.usage_stats import UsageStats
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class UsageStatsRoutes:
    """Routes for handling usage statistics updates"""

    @staticmethod
    def setup_routes(app):
        """Register usage stats routes on the aiohttp application."""
        app.router.add_post('/loras/api/update-usage-stats', UsageStatsRoutes.update_usage_stats)

    @staticmethod
    async def update_usage_stats(request):
        """
        Update usage statistics based on a prompt_id

        Expects a JSON body with:
        {
            "prompt_id": "string"
        }

        Returns:
            200 {"success": true} once the prompt has been processed,
            400 for a malformed body or missing prompt_id,
            500 for any internal failure.
        """
        # Parse the request body first, on its own: a malformed or non-object
        # JSON payload is a client error (400), not an internal error (500),
        # so it must not fall through to the broad handler below.
        try:
            data = await request.json()
        except ValueError:
            # aiohttp raises json.JSONDecodeError (a ValueError subclass)
            # for unparseable bodies.
            return web.json_response({
                'success': False,
                'error': 'Invalid JSON body'
            }, status=400)

        if not isinstance(data, dict):
            # e.g. a JSON array or bare string — .get() would raise otherwise
            return web.json_response({
                'success': False,
                'error': 'Invalid JSON body'
            }, status=400)

        try:
            prompt_id = data.get('prompt_id')

            if not prompt_id:
                return web.json_response({
                    'success': False,
                    'error': 'Missing prompt_id'
                }, status=400)

            # Call the UsageStats to process this prompt_id synchronously
            # so counts are current before the client's next request.
            usage_stats = UsageStats()
            await usage_stats.process_execution(prompt_id)

            return web.json_response({
                'success': True
            })

        except Exception as e:
            logger.error(f"Failed to update usage stats: {e}", exc_info=True)
            return web.json_response({
                'success': False,
                'error': str(e)
            }, status=500)
|
||||||
267
py/utils/usage_stats.py
Normal file
267
py/utils/usage_stats.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Set
|
||||||
|
|
||||||
|
from ..config import config
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
from ..metadata_collector.metadata_registry import MetadataRegistry
|
||||||
|
from ..metadata_collector.constants import MODELS, LORAS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class UsageStats:
    """Track usage statistics for models and save to JSON (process-wide singleton)."""

    _instance = None
    _lock = asyncio.Lock()  # Serializes stat mutation and file writes

    # Default stats file name
    STATS_FILENAME = "lora_manager_stats.json"

    def __new__(cls):
        # Classic singleton: a single tracker per process.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # Repeated UsageStats() calls return the same object; only set up once.
        if self._initialized:
            return

        # sha256 -> count maps plus bookkeeping fields
        self.stats = {
            "checkpoints": {},
            "loras": {},
            "total_executions": 0,
            "last_save_time": 0
        }

        # prompt_ids queued for deferred processing
        self.pending_prompt_ids = set()

        # Resolve the stats file location and pull in any persisted counts
        self._stats_file_path = self._get_stats_file_path()
        self._load_stats()

        # Minimum seconds between periodic saves
        self.save_interval = 90  # 1.5 minutes

        # Background loop that drains the pending queue
        # NOTE(review): assumes a running event loop at construction time —
        # asyncio.create_task raises otherwise.
        self._bg_task = asyncio.create_task(self._background_processor())

        self._initialized = True
        logger.info("Usage statistics tracker initialized")

    def _get_stats_file_path(self) -> str:
        """Get the path to the stats JSON file."""
        roots = config.loras_roots
        if not roots:
            # No lora roots configured; fall back to the temp directory
            return os.path.join(config.temp_directory, self.STATS_FILENAME)
        # Use the first lora root
        return os.path.join(roots[0], self.STATS_FILENAME)

    def _load_stats(self):
        """Load existing statistics from file"""
        try:
            if os.path.exists(self._stats_file_path):
                with open(self._stats_file_path, 'r', encoding='utf-8') as fh:
                    persisted = json.load(fh)

                if isinstance(persisted, dict):
                    # Merge section by section so the expected structure survives
                    # a partially-valid or older file.
                    for section in ("checkpoints", "loras"):
                        if isinstance(persisted.get(section), dict):
                            self.stats[section] = persisted[section]
                    if "total_executions" in persisted:
                        self.stats["total_executions"] = persisted["total_executions"]

                logger.info(f"Loaded usage statistics from {self._stats_file_path}")
        except Exception as e:
            logger.error(f"Error loading usage statistics: {e}")

    async def save_stats(self, force=False):
        """Save statistics to file.

        Rate-limited to one write per ``save_interval`` seconds unless
        *force* is True. Returns True when a write happened, False otherwise.
        """
        try:
            now = time.time()
            if not force and (now - self.stats.get("last_save_time", 0)) < self.save_interval:
                return False

            # Serialize concurrent writers
            async with self._lock:
                self.stats["last_save_time"] = now

                os.makedirs(os.path.dirname(self._stats_file_path), exist_ok=True)

                # Write-then-rename keeps the stats file intact if we crash
                # mid-write; os.replace is atomic on the same filesystem.
                temp_path = f"{self._stats_file_path}.tmp"
                with open(temp_path, 'w', encoding='utf-8') as fh:
                    json.dump(self.stats, fh, indent=2, ensure_ascii=False)
                os.replace(temp_path, self._stats_file_path)

            logger.debug(f"Saved usage statistics to {self._stats_file_path}")
            return True
        except Exception as e:
            logger.error(f"Error saving usage statistics: {e}", exc_info=True)
            return False

    def register_execution(self, prompt_id):
        """Register a completed execution by prompt_id for later processing"""
        if prompt_id:
            self.pending_prompt_ids.add(prompt_id)

    async def _background_processor(self):
        """Background task to process queued prompt_ids"""
        try:
            while True:
                await asyncio.sleep(5)  # Check every 5 seconds

                if self.pending_prompt_ids:
                    async with self._lock:
                        # Snapshot and reset the queue under the lock
                        batch = self.pending_prompt_ids.copy()
                        self.pending_prompt_ids.clear()

                    registry = MetadataRegistry()
                    for prompt_id in batch:
                        try:
                            await self._process_metadata(registry.get_metadata(prompt_id))
                        except Exception as e:
                            logger.error(f"Error processing prompt_id {prompt_id}: {e}")

                # Periodically save stats (throttled inside save_stats)
                await self.save_stats()
        except asyncio.CancelledError:
            # Shutting down: flush whatever we have
            await self.save_stats(force=True)
        except Exception as e:
            logger.error(f"Error in background processing task: {e}", exc_info=True)
            # Restart the task after a delay if it fails
            asyncio.create_task(self._restart_background_task())

    async def _restart_background_task(self):
        """Restart the background task after a delay"""
        await asyncio.sleep(30)  # Wait 30 seconds before restarting
        self._bg_task = asyncio.create_task(self._background_processor())

    async def _process_metadata(self, metadata):
        """Update counters from one execution's collected metadata."""
        if not metadata or not isinstance(metadata, dict):
            return

        self.stats["total_executions"] += 1

        models = metadata.get(MODELS)
        if isinstance(models, dict):
            await self._process_checkpoints(models)

        loras = metadata.get(LORAS)
        if isinstance(loras, dict):
            await self._process_loras(loras)

    async def _process_checkpoints(self, models_data):
        """Process checkpoint models from metadata"""
        try:
            scanner = await ServiceRegistry.get_checkpoint_scanner()
            if not scanner:
                logger.warning("Checkpoint scanner not available for usage tracking")
                return

            # Node ids are irrelevant here; only the model entries matter.
            for info in models_data.values():
                if not isinstance(info, dict) or info.get("type") != "checkpoint":
                    continue

                name = info.get("name")
                if not name:
                    continue

                # Strip any directory prefix and file extension
                filename = os.path.splitext(os.path.basename(name))[0]

                sha = scanner.get_hash_by_filename(filename)
                if sha:
                    counts = self.stats["checkpoints"]
                    counts[sha] = counts.get(sha, 0) + 1
        except Exception as e:
            logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)

    async def _process_loras(self, loras_data):
        """Process LoRA models from metadata"""
        try:
            scanner = await ServiceRegistry.get_lora_scanner()
            if not scanner:
                logger.warning("LoRA scanner not available for usage tracking")
                return

            for info in loras_data.values():
                if not isinstance(info, dict):
                    continue

                # Standardized format: each node carries a "lora_list"
                for lora in info.get("lora_list", []):
                    if not isinstance(lora, dict):
                        continue

                    name = lora.get("name")
                    if not name:
                        continue

                    sha = scanner.get_hash_by_filename(name)
                    if sha:
                        counts = self.stats["loras"]
                        counts[sha] = counts.get(sha, 0) + 1
        except Exception as e:
            logger.error(f"Error processing LoRA usage: {e}", exc_info=True)

    async def get_stats(self):
        """Get current usage statistics"""
        return self.stats

    async def get_model_usage_count(self, model_type, sha256):
        """Get usage count for a specific model by hash"""
        section = {"checkpoint": "checkpoints", "lora": "loras"}.get(model_type)
        if section is None:
            return 0
        return self.stats[section].get(sha256, 0)

    async def process_execution(self, prompt_id):
        """Process a prompt execution immediately (synchronous approach)"""
        if not prompt_id:
            return

        try:
            metadata = MetadataRegistry().get_metadata(prompt_id)
            if metadata:
                await self._process_metadata(metadata)
            # Save stats if needed (throttled)
            await self.save_stats()
        except Exception as e:
            logger.error(f"Error processing prompt_id {prompt_id}: {e}", exc_info=True)
|
||||||
37
web/comfyui/usage_stats.js
Normal file
37
web/comfyui/usage_stats.js
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
// ComfyUI extension to track model usage statistics
|
||||||
|
import { app } from "../../scripts/app.js";
|
||||||
|
import { api } from "../../scripts/api.js";
|
||||||
|
|
||||||
|
// Register the extension
app.registerExtension({
    name: "ComfyUI-Lora-Manager.UsageStats",

    init() {
        // Forward each successful execution's prompt_id to the backend
        api.addEventListener("execution_success", ({ detail }) => {
            if (detail && detail.prompt_id) {
                this.updateUsageStats(detail.prompt_id);
            }
        });
    },

    async updateUsageStats(promptId) {
        try {
            console.log("Updating usage statistics for prompt ID:", promptId);
            // Call backend endpoint with the prompt_id
            const response = await fetch(`/loras/api/update-usage-stats`, {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json',
                },
                body: JSON.stringify({ prompt_id: promptId }),
            });

            if (!response.ok) {
                console.warn("Failed to update usage statistics:", response.statusText);
            }
        } catch (error) {
            console.error("Error updating usage statistics:", error);
        }
    },
});
|
||||||
Reference in New Issue
Block a user