mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-06-13 05:59:25 -03:00
_load_stats() was missing the embeddings section, so on every restart the embeddings usage-tracking dict started out empty, which made all previously saved embedding usage data appear to have been reset. Added the missing load path for the 'embeddings' key, parallel to the existing checkpoints and loras loading logic.
644 lines
26 KiB
Python
644 lines
26 KiB
Python
import os
|
|
import re
|
|
import json
|
|
import time
|
|
import asyncio
|
|
import logging
|
|
import datetime
|
|
import shutil
|
|
from typing import Dict, Set
|
|
|
|
from ..config import config
|
|
from ..services.service_registry import ServiceRegistry
|
|
from ..utils.settings_paths import get_settings_dir
|
|
|
|
# Decide whether this module runs outside ComfyUI. Standalone mode is on when it
# is requested explicitly (LORA_MANAGER_STANDALONE=1), or when
# HF_HUB_DISABLE_TELEMETRY is unset or "0" -- presumably because the ComfyUI
# launcher sets that variable to "1", so its absence suggests there is no
# ComfyUI host. NOTE(review): confirm this heuristic against the launcher.
standalone_mode = any((
    os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1",
    os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0",
))

# Keys for the sections of a prompt's captured metadata. These local values are
# fallbacks so the module never depends on conditional imports for them; when
# the metadata collector is importable its own constants replace them below.
MODELS, LORAS, EMBEDDINGS, PROMPTS = "models", "loras", "embeddings", "prompts"

if not standalone_mode:
    # Relative imports of the metadata collector; skipped in standalone mode.
    from ..metadata_collector.metadata_registry import MetadataRegistry

    try:
        from ..metadata_collector.constants import (
            MODELS as _MODELS,
            LORAS as _LORAS,
            EMBEDDINGS as _EMBEDDINGS,
            PROMPTS as _PROMPTS,
        )
    except ImportError:
        pass  # Keep the local fallback values defined above.
    else:
        MODELS, LORAS, EMBEDDINGS, PROMPTS = _MODELS, _LORAS, _EMBEDDINGS, _PROMPTS
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_DEFAULT_CHECKPOINT_EXTENSIONS = {
|
|
".ckpt",
|
|
".pt",
|
|
".pt2",
|
|
".bin",
|
|
".pth",
|
|
".safetensors",
|
|
".pkl",
|
|
".sft",
|
|
".gguf",
|
|
}
|
|
|
|
class UsageStats:
    """Track usage statistics for models and persist them to a JSON file.

    This is a process-wide singleton: every ``UsageStats()`` call returns the
    same instance, and ``__init__`` only performs its setup the first time.
    Counters are kept per category (``checkpoints``, ``loras``,
    ``embeddings``), keyed by the model's SHA256 hash. Each entry holds a
    running ``total`` and a per-day ``history`` keyed by ``YYYY-MM-DD``.
    """

    _instance = None
    # Guards swaps of pending_prompt_ids and writes of the stats file across
    # coroutines. asyncio.Lock is not reentrant, so save_stats() (which takes
    # it) must never be awaited while the lock is already held.
    _lock = asyncio.Lock()

    # Default stats file name
    STATS_FILENAME = "lora_manager_stats.json"
    BACKUP_SUFFIX = ".backup"

    # Mirrors the ``embedding:<name>`` syntax understood by ComfyUI's tokenizer.
    # ``\w`` is Unicode-aware for str patterns, so non-ASCII embedding names are
    # captured too. ``:`` terminates the name, so ``embedding:name:1.2`` yields
    # ``name``, and ``,``/whitespace/brackets end it as well.
    _EMBEDDING_PATTERN = re.compile(r"embedding:([\w.\-/]+)")

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        """Set up the singleton once: load persisted stats, start the worker.

        ``asyncio.create_task`` requires a running event loop, so the first
        instantiation must happen inside one.
        """
        if self._initialized:
            return

        # Initialize stats storage
        self.stats = {
            "checkpoints": {},  # sha256 -> { total: count, history: { date: count } }
            "loras": {},  # sha256 -> { total: count, history: { date: count } }
            "embeddings": {},  # sha256 -> { total: count, history: { date: count } }
            "total_executions": 0,
            "last_save_time": 0
        }

        # Track if stats have been modified since last save. Must be set before
        # _load_stats(), which may flag freshly converted legacy data as dirty.
        self._is_dirty = False

        # Queue for prompt_ids to process
        self.pending_prompt_ids = set()

        # Load existing stats if available
        self._stats_file_path = self._get_stats_file_path()
        self._migrate_from_old_location()
        self._load_stats()

        # Save interval in seconds
        self.save_interval = 90  # 1.5 minutes

        # Strong reference to a pending restart task. The event loop only keeps
        # weak references to tasks, so an unreferenced task may be collected.
        self._restart_task = None

        # Start background task to process queued prompt_ids
        self._bg_task = asyncio.create_task(self._background_processor())

        self._initialized = True
        logger.debug("Usage statistics tracker initialized")

    def _get_stats_file_path(self) -> str:
        """Return ``<settings_dir>/stats/<STATS_FILENAME>`` (creating settings_dir)."""
        settings_dir = get_settings_dir(create=True)
        return os.path.join(settings_dir, "stats", self.STATS_FILENAME)

    @staticmethod
    def _get_old_stats_file_path() -> str:
        """Return the legacy stats path in the first LoRA root, or "" if none."""
        if not config.loras_roots:
            return ""
        return os.path.join(config.loras_roots[0], UsageStats.STATS_FILENAME)

    def _migrate_from_old_location(self) -> None:
        """Copy the legacy stats file to the settings directory, then delete it.

        Does nothing if the new file already exists or no legacy file is
        found. Failures are logged; a failed deletion of the old file leaves
        the successful copy in place.
        """
        new_path = self._stats_file_path
        if os.path.exists(new_path):
            return

        old_path = self._get_old_stats_file_path()
        if not old_path or not os.path.exists(old_path):
            return

        try:
            os.makedirs(os.path.dirname(new_path), exist_ok=True)
            shutil.copy2(old_path, new_path)
            logger.info("Migrated usage stats from %s to %s", old_path, new_path)
            try:
                os.remove(old_path)
                logger.info("Cleaned up old stats file: %s", old_path)
            except Exception as e:
                logger.warning("Failed to remove old stats file %s: %s", old_path, e)
        except Exception as e:
            logger.error("Failed to migrate usage stats from %s to %s: %s", old_path, new_path, e)

    def _backup_old_stats(self):
        """Copy the current stats file to ``<path>.backup`` before conversion.

        Returns:
            True on success, False if the copy failed, and None if there is no
            stats file to back up.
        """
        if os.path.exists(self._stats_file_path):
            backup_path = f"{self._stats_file_path}{self.BACKUP_SUFFIX}"
            try:
                shutil.copy2(self._stats_file_path, backup_path)
                logger.info(f"Backed up old stats file to {backup_path}")
                return True
            except Exception as e:
                logger.error(f"Failed to backup stats file: {e}")
                return False

    def _convert_old_format(self, old_stats):
        """Convert legacy stats (bare counts) to the current format with history.

        Legacy files stored ``sha256 -> count``. Each count becomes
        ``{"total": count, "history": {today: count}}``. The real usage dates
        are unknown, so all legacy usage is attributed to today. Entries that
        are already dicts (a mixed-format file) are kept unchanged. Entries of
        any other type are dropped with a warning instead of being wrapped into
        a malformed record.

        Args:
            old_stats: The decoded JSON object from the stats file.

        Returns:
            A new stats dict in the current format.
        """
        new_stats = {
            "checkpoints": {},
            "loras": {},
            "embeddings": {},
            "total_executions": old_stats.get("total_executions", 0),
            "last_save_time": old_stats.get("last_save_time", time.time())
        }

        # Get today's date in YYYY-MM-DD format
        today = datetime.datetime.now().strftime("%Y-%m-%d")

        for category in ("checkpoints", "loras", "embeddings"):
            entries = old_stats.get(category)
            if not isinstance(entries, dict):
                continue

            converted = new_stats[category]
            for hash_id, value in entries.items():
                if isinstance(value, dict):
                    # Already in the new format; keep as-is.
                    converted[hash_id] = value
                elif isinstance(value, (int, float)):
                    converted[hash_id] = {"total": value, "history": {today: value}}
                else:
                    logger.warning(
                        "Skipping unrecognized %s stats entry for %s: %r",
                        category,
                        hash_id,
                        value,
                    )

        logger.info("Successfully converted stats from old format to new format with history")
        return new_stats

    def _is_old_format(self, stats):
        """Return True if any category stores a bare number instead of a dict."""
        if not isinstance(stats, dict):
            return False

        for category in ("loras", "checkpoints", "embeddings"):
            entries = stats.get(category)
            if isinstance(entries, dict) and any(
                isinstance(data, (int, float)) for data in entries.values()
            ):
                return True

        return False

    def _load_stats(self):
        """Load persisted statistics from the stats file, if it exists.

        Legacy-format files are backed up and converted. The converted stats
        are marked dirty so the next save writes them in the new format.
        Current-format files are merged section by section, so sections
        missing from the file keep their empty defaults. Any error is logged,
        and the in-memory defaults stay in place.
        """
        try:
            if not os.path.exists(self._stats_file_path):
                return

            with open(self._stats_file_path, 'r', encoding='utf-8') as f:
                loaded_stats = json.load(f)

            if not isinstance(loaded_stats, dict):
                logger.warning(
                    "Ignoring usage statistics file %s: expected a JSON object",
                    self._stats_file_path,
                )
                return

            if self._is_old_format(loaded_stats):
                logger.info("Detected old stats format, performing conversion")
                self._backup_old_stats()
                self.stats = self._convert_old_format(loaded_stats)
                # Persist the converted layout on the next save; otherwise the
                # file stays in the old format and is re-converted every start.
                self._is_dirty = True
            else:
                # Update individual sections to maintain structure
                for section in ("checkpoints", "loras", "embeddings"):
                    if isinstance(loaded_stats.get(section), dict):
                        self.stats[section] = loaded_stats[section]

                for key in ("total_executions", "last_save_time"):
                    if key in loaded_stats:
                        self.stats[key] = loaded_stats[key]

            logger.debug(f"Loaded usage statistics from {self._stats_file_path}")
        except Exception as e:
            logger.error(f"Error loading usage statistics: {e}")

    async def save_stats(self, force=False):
        """Write statistics to the stats file atomically.

        Unless *force* is True, nothing is written when there are no unsaved
        changes or when fewer than ``save_interval`` seconds have passed since
        the last save.

        Args:
            force: Save regardless of the dirty flag and the save interval.

        Returns:
            True if the file was written, False if skipped or on error.
        """
        try:
            current_time = time.time()
            time_since_last_save = current_time - self.stats.get("last_save_time", 0)

            if not force:
                if not self._is_dirty:
                    # No changes to save
                    return False
                if time_since_last_save < self.save_interval:
                    # Too soon since last save
                    return False

            # Use a lock to prevent concurrent writes
            async with self._lock:
                # Update last save time
                self.stats["last_save_time"] = current_time

                # Create directory if it doesn't exist
                os.makedirs(os.path.dirname(self._stats_file_path), exist_ok=True)

                # Write to a temporary file first, then move it to avoid corruption
                temp_path = f"{self._stats_file_path}.tmp"
                try:
                    with open(temp_path, 'w', encoding='utf-8') as f:
                        json.dump(self.stats, f, indent=2, ensure_ascii=False)

                    # Replace the old file with the new one
                    os.replace(temp_path, self._stats_file_path)
                except Exception:
                    # Don't leave a partial temp file behind; the existing stats
                    # file is left as it was because the replace did not happen.
                    try:
                        os.remove(temp_path)
                    except OSError:
                        pass
                    raise

                # Clear dirty flag since we've saved
                self._is_dirty = False

            logger.debug(f"Saved usage statistics to {self._stats_file_path}")
            return True
        except Exception as e:
            logger.error(f"Error saving usage statistics: {e}", exc_info=True)
            return False

    def register_execution(self, prompt_id):
        """Queue a completed execution's prompt_id for background processing."""
        if prompt_id:
            self.pending_prompt_ids.add(prompt_id)

    async def _background_processor(self):
        """Every 5 seconds, process queued prompt_ids and save if dirty.

        On cancellation a forced save is attempted and the cancellation is
        propagated. On any other error the task schedules its own restart.
        """
        try:
            while True:
                # Wait a short interval before checking for new prompt_ids
                await asyncio.sleep(5)  # Check every 5 seconds

                if self.pending_prompt_ids:
                    # Hold the lock only for the swap. save_stats() takes the
                    # same non-reentrant lock further down.
                    async with self._lock:
                        prompt_ids = self.pending_prompt_ids.copy()
                        self.pending_prompt_ids.clear()

                    try:
                        registry = MetadataRegistry()
                    except NameError:
                        # MetadataRegistry not available (standalone mode)
                        registry = None

                    if registry:
                        for prompt_id in prompt_ids:
                            try:
                                metadata = registry.get_metadata(prompt_id)
                                await self._process_metadata(metadata)
                            except Exception as e:
                                logger.error(f"Error processing prompt_id {prompt_id}: {e}")

                # Periodically save stats (only if there are changes)
                if self._is_dirty:
                    await self.save_stats()
        except asyncio.CancelledError:
            # Task was cancelled: flush pending changes, then let the
            # cancellation propagate as asyncio expects.
            await self.save_stats(force=True)
            raise
        except Exception as e:
            logger.error(f"Error in background processing task: {e}", exc_info=True)
            # Restart the task after a delay if it fails. Keep a reference so
            # the restart task cannot be garbage-collected before it runs.
            self._restart_task = asyncio.create_task(self._restart_background_task())

    async def _restart_background_task(self):
        """Restart the background task after a delay"""
        await asyncio.sleep(30)  # Wait 30 seconds before restarting
        self._bg_task = asyncio.create_task(self._background_processor())

    async def _process_metadata(self, metadata):
        """Count one execution and record the models referenced in *metadata*.

        Non-dict or empty metadata is ignored. The MODELS, LORAS and PROMPTS
        sections are each processed only when present as dicts.
        """
        if not metadata or not isinstance(metadata, dict):
            return

        # Increment total executions count
        self.stats["total_executions"] += 1
        self._is_dirty = True

        # Get today's date in YYYY-MM-DD format
        today = datetime.datetime.now().strftime("%Y-%m-%d")

        # Process checkpoints
        if MODELS in metadata and isinstance(metadata[MODELS], dict):
            await self._process_checkpoints(metadata[MODELS], today)

        # Process loras
        if LORAS in metadata and isinstance(metadata[LORAS], dict):
            await self._process_loras(metadata[LORAS], today)

        # Process embeddings — parse prompt text for embedding:name references
        if PROMPTS in metadata and isinstance(metadata[PROMPTS], dict):
            await self._process_embeddings(metadata[PROMPTS], today)

    def _increment_usage_counter(self, category: str, stat_key: str, today_date: str) -> None:
        """Add one use of *stat_key* to its total and to *today_date*'s history."""
        entry = self.stats[category].setdefault(stat_key, {"total": 0, "history": {}})
        entry["total"] += 1
        history = entry["history"]
        history[today_date] = history.get(today_date, 0) + 1

    def _normalize_model_lookup_name(self, model_name: str) -> str:
        """Normalize a model reference to its base filename without extension."""
        return os.path.splitext(os.path.basename(model_name))[0]

    async def _find_cached_checkpoint_entry(self, checkpoint_scanner, model_name: str):
        """Best-effort lookup for a checkpoint cache entry by filename/model name.

        Compares the normalized *model_name* against each cache entry's
        ``file_name``, ``model_name`` and ``file_path``. Returns the first
        matching entry dict, or None if the scanner has no usable cache.
        """
        get_cached_data = getattr(checkpoint_scanner, "get_cached_data", None)
        if not callable(get_cached_data):
            return None

        cache = await get_cached_data()
        raw_data = getattr(cache, "raw_data", None)
        if not isinstance(raw_data, list):
            return None

        normalized_name = self._normalize_model_lookup_name(model_name)
        for entry in raw_data:
            if not isinstance(entry, dict):
                continue

            for candidate_key in ("file_name", "model_name", "file_path"):
                candidate_value = entry.get(candidate_key)
                if not candidate_value or not isinstance(candidate_value, str):
                    continue
                if self._normalize_model_lookup_name(candidate_value) == normalized_name:
                    return entry

        return None

    async def _find_checkpoint_file_on_disk(self, checkpoint_scanner, model_name: str):
        """Search checkpoint roots directly for a matching file.

        This is used when usage tracking sees a checkpoint name before the cache has
        been refreshed. The lookup is intentionally exact: we only match the model
        basename and supported checkpoint extensions.

        Returns:
            The single matching path (with ``/`` separators), or None when
            there is no match or the name is ambiguous (more than one match).
        """
        get_model_roots = getattr(checkpoint_scanner, "get_model_roots", None)
        if not callable(get_model_roots):
            return None

        roots = [root for root in get_model_roots() if root]
        if not roots:
            return None

        supported_extensions = getattr(
            checkpoint_scanner, "file_extensions", _DEFAULT_CHECKPOINT_EXTENSIONS
        )
        if not isinstance(supported_extensions, (set, frozenset, list, tuple)):
            supported_extensions = _DEFAULT_CHECKPOINT_EXTENSIONS

        normalized_name = self._normalize_model_lookup_name(model_name)

        # os.walk over large model folders is blocking filesystem I/O; run it
        # in the default executor so the event loop is not stalled meanwhile.
        loop = asyncio.get_running_loop()
        matches = await loop.run_in_executor(
            None,
            self._walk_roots_for_checkpoint,
            roots,
            supported_extensions,
            normalized_name,
        )

        if len(matches) > 1:
            logger.warning(
                "Multiple checkpoint files matched '%s'; skipping usage tracking: %s",
                normalized_name,
                ", ".join(matches),
            )
            return None

        return matches[0] if matches else None

    @staticmethod
    def _walk_roots_for_checkpoint(roots, supported_extensions, normalized_name):
        """Return all files under *roots* named ``normalized_name`` + a supported ext.

        Missing roots are skipped. The extension is lower-cased before the
        membership test. Paths are returned with ``/`` separators.
        """
        matches = []
        for root_path in roots:
            if not os.path.exists(root_path):
                continue

            for dirpath, _dirnames, filenames in os.walk(root_path):
                for filename in filenames:
                    stem, extension = os.path.splitext(filename)
                    if extension.lower() not in supported_extensions:
                        continue
                    if stem != normalized_name:
                        continue
                    matches.append(os.path.join(dirpath, filename).replace(os.sep, "/"))
        return matches

    async def _resolve_checkpoint_hash(self, checkpoint_scanner, model_name: str):
        """Resolve a checkpoint hash, calculating pending hashes on demand when needed.

        Resolution order:
        1. The scanner's filename -> hash index.
        2. The ``sha256`` of a matching cache entry.
        3. Hashing the file located via the cache entry or a disk search.

        A cache entry whose ``hash_status`` is set to anything other than
        "pending" stops the lookup. Returns None when no hash can be obtained.
        """
        model_filename = self._normalize_model_lookup_name(model_name)
        model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
        if model_hash:
            return model_hash

        cached_entry = await self._find_cached_checkpoint_entry(checkpoint_scanner, model_name)
        if cached_entry:
            cached_hash = cached_entry.get("sha256")
            if cached_hash:
                return cached_hash

            hash_status = cached_entry.get("hash_status")
            if hash_status and hash_status != "pending":
                logger.warning(
                    "Checkpoint '%s' has hash_status=%s; skipping usage tracking",
                    model_filename,
                    hash_status,
                )
                return None

        file_path = cached_entry.get("file_path") if cached_entry else None
        if not file_path:
            file_path = await self._find_checkpoint_file_on_disk(
                checkpoint_scanner, model_name
            )

        if not file_path:
            logger.warning(
                f"No hash found for checkpoint '{model_filename}', skipping usage tracking"
            )
            return None

        calculate_hash = getattr(checkpoint_scanner, "calculate_hash_for_model", None)
        if not callable(calculate_hash):
            logger.warning("Checkpoint scanner not available for usage tracking")
            return None

        logger.info(
            "Calculating hash for checkpoint '%s' from %s",
            model_filename,
            file_path,
        )
        calculated_hash = await calculate_hash(file_path)
        if calculated_hash:
            return calculated_hash

        logger.warning(
            f"Failed to calculate hash for checkpoint '{model_filename}', skipping usage tracking"
        )
        return None

    async def _process_checkpoints(self, models_data, today_date):
        """Count each ``type == "checkpoint"`` entry in the MODELS metadata section."""
        try:
            # Get checkpoint scanner service
            checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
            if not checkpoint_scanner:
                logger.warning("Checkpoint scanner not available for usage tracking")
                return

            for model_info in models_data.values():
                if not isinstance(model_info, dict):
                    continue

                # Check if this is a checkpoint model
                if model_info.get("type") != "checkpoint":
                    continue

                model_name = model_info.get("name")
                if not model_name:
                    continue

                model_hash = await self._resolve_checkpoint_hash(checkpoint_scanner, model_name)
                if not model_hash:
                    continue

                self._increment_usage_counter("checkpoints", model_hash, today_date)
        except Exception as e:
            logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)

    async def _process_loras(self, loras_data, today_date):
        """Count each LoRA listed in the ``lora_list`` of every LORAS metadata node."""
        try:
            # Get LoRA scanner service
            lora_scanner = await ServiceRegistry.get_lora_scanner()
            if not lora_scanner:
                logger.warning("LoRA scanner not available for usage tracking")
                return

            for lora_info in loras_data.values():
                if not isinstance(lora_info, dict):
                    continue

                # Get the list of LoRAs from standardized format; tolerate an
                # explicit None so one malformed node doesn't abort the batch.
                lora_list = lora_info.get("lora_list") or []
                for lora in lora_list:
                    if not isinstance(lora, dict):
                        continue

                    lora_name = lora.get("name")
                    if not lora_name:
                        continue

                    # Get hash for this LoRA
                    lora_hash = lora_scanner.get_hash_by_filename(lora_name)
                    if not lora_hash:
                        logger.warning(f"No hash found for LoRA '{lora_name}', skipping usage tracking")
                        continue

                    self._increment_usage_counter("loras", lora_hash, today_date)
        except Exception as e:
            logger.error(f"Error processing LoRA usage: {e}", exc_info=True)

    @staticmethod
    def _extract_embedding_names(prompt_text: str) -> set:
        """Parse embedding:name references from prompt text.

        ComfyUI's SDTokenizer resolves ``embedding:<name>`` during tokenization
        (see ``sd1_clip.py _try_get_embedding``). This mirrors the same pattern
        to extract embedding file names from the captured prompt strings.
        Names may contain Unicode word characters plus ``.``, ``-`` and ``/``.

        Returns:
            The set of distinct names found; empty for empty/None input.
        """
        if not prompt_text:
            return set()
        return set(UsageStats._EMBEDDING_PATTERN.findall(prompt_text))

    async def _process_embeddings(self, prompts_data, today_date):
        """Extract embedding usage from prompt texts and record it.

        Iterates every prompt node's text field captured by the metadata
        collector, extracts ``embedding:<name>`` references, resolves each
        name to its SHA256 hash via the embedding scanner, and increments
        usage counters. Each distinct name counts once per execution, even if
        it appears in several prompt fields.
        """
        try:
            embedding_scanner = await ServiceRegistry.get_embedding_scanner()
            if not embedding_scanner:
                logger.warning("Embedding scanner not available for usage tracking")
                return

            seen_names = set()
            for prompt_data in prompts_data.values():
                if not isinstance(prompt_data, dict):
                    continue
                for text_field in ("text", "positive_text", "negative_text"):
                    text = prompt_data.get(text_field)
                    if isinstance(text, str):
                        seen_names.update(self._extract_embedding_names(text))

            for emb_name in seen_names:
                emb_hash = embedding_scanner.get_hash_by_filename(emb_name)
                if emb_hash:
                    self._increment_usage_counter("embeddings", emb_hash, today_date)
                else:
                    logger.debug(
                        "No hash found for embedding '%s', skipping usage tracking",
                        emb_name,
                    )
        except Exception as e:
            logger.error("Error processing embedding usage: %s", e, exc_info=True)

    async def get_stats(self):
        """Return the live statistics dict (not a copy)."""
        return self.stats

    async def get_model_usage_count(self, model_type, sha256):
        """Return the total use count for *sha256* in *model_type*, else 0.

        Args:
            model_type: One of "checkpoint", "lora" or "embedding".
            sha256: The model hash used as the stats key.
        """
        category = {
            "checkpoint": "checkpoints",
            "lora": "loras",
            "embedding": "embeddings",
        }.get(model_type)
        if category is not None and sha256 in self.stats[category]:
            return self.stats[category][sha256]["total"]
        return 0

    async def process_execution(self, prompt_id):
        """Process a prompt execution immediately (synchronous approach).

        No-op for an empty prompt_id or in standalone mode. Otherwise the
        prompt's metadata is recorded, and a non-forced save is attempted.
        """
        if not prompt_id:
            return

        if standalone_mode:
            # Usage statistics are not available in standalone mode
            return

        try:
            # Process metadata for this prompt_id
            registry = MetadataRegistry()
            metadata = registry.get_metadata(prompt_id)
            if metadata:
                await self._process_metadata(metadata)
                # Save stats if needed
                await self.save_stats()
        except Exception as e:
            logger.error(f"Error processing prompt_id {prompt_id}: {e}", exc_info=True)
|