feat(stats): track embedding usage from prompt text — Plan A + hybrid approach docs

This commit is contained in:
Will Miao
2026-06-11 17:12:34 +08:00
parent dd1cdce16d
commit f565cc35ca
3 changed files with 262 additions and 11 deletions

View File

@@ -1,4 +1,5 @@
import os
import re
import json
import time
import asyncio
@@ -16,14 +17,18 @@ standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.en
# Define constants locally to avoid dependency on conditional imports
# These literal values are the fallbacks: they stay in effect in standalone
# mode, and also whenever the constants import below raises ImportError.
MODELS = "models"
LORAS = "loras"
EMBEDDINGS = "embeddings"
PROMPTS = "prompts"
# Outside standalone mode, pull in the metadata registry and prefer the
# constant values published by the metadata_collector package.
if not standalone_mode:
from ..metadata_collector.metadata_registry import MetadataRegistry
# Import constants from metadata_collector to ensure consistency, but we have fallbacks defined above
try:
# NOTE(review): the two import lines below overlap; the second imports
# everything the first does plus EMBEDDINGS and PROMPTS. This looks like
# an old line kept next to its replacement; confirm and drop the first.
from ..metadata_collector.constants import MODELS as _MODELS, LORAS as _LORAS
from ..metadata_collector.constants import MODELS as _MODELS, LORAS as _LORAS, EMBEDDINGS as _EMBEDDINGS, PROMPTS as _PROMPTS
# The overrides are all-or-nothing: if any one name is missing from the
# constants module, the ImportError skips every assignment below.
MODELS = _MODELS
LORAS = _LORAS
EMBEDDINGS = _EMBEDDINGS
PROMPTS = _PROMPTS
except ImportError:
pass # Use the local definitions
@@ -65,6 +70,7 @@ class UsageStats:
self.stats = {
"checkpoints": {}, # sha256 -> { total: count, history: { date: count } }
"loras": {}, # sha256 -> { total: count, history: { date: count } }
"embeddings": {}, # sha256 -> { total: count, history: { date: count } }
"total_executions": 0,
"last_save_time": 0
}
@@ -115,6 +121,7 @@ class UsageStats:
new_stats = {
"checkpoints": {},
"loras": {},
"embeddings": {},
"total_executions": old_stats.get("total_executions", 0),
"last_save_time": old_stats.get("last_save_time", time.time())
}
@@ -142,21 +149,27 @@ class UsageStats:
}
}
# Convert embedding stats (if present in old format)
if "embeddings" in old_stats and isinstance(old_stats["embeddings"], dict):
for hash_id, count in old_stats["embeddings"].items():
new_stats["embeddings"][hash_id] = {
"total": count,
"history": {
today: count
}
}
logger.info("Successfully converted stats from old format to new format with history")
return new_stats
def _is_old_format(self, stats):
    """Check if the stats are in the old format (direct count values).

    In the old format a category maps each hash directly to a number; the
    check looks for any such numeric entry.

    Args:
        stats: Loaded stats container. Categories that are missing, or
            whose value is not a dict, are ignored.

    Returns:
        True if any entry under "loras", "checkpoints" or "embeddings" is
        an int or float, otherwise False.
    """
    # One pass over every tracked category. The separate loras/checkpoints
    # loops that used to precede this were fully covered by it.
    for category in ("loras", "checkpoints", "embeddings"):
        if category not in stats or not isinstance(stats[category], dict):
            continue
        if any(isinstance(data, (int, float)) for data in stats[category].values()):
            return True
    return False
@@ -304,6 +317,10 @@ class UsageStats:
if LORAS in metadata and isinstance(metadata[LORAS], dict):
await self._process_loras(metadata[LORAS], today)
# Process embeddings — parse prompt text for embedding:name references
if PROMPTS in metadata and isinstance(metadata[PROMPTS], dict):
await self._process_embeddings(metadata[PROMPTS], today)
def _increment_usage_counter(self, category: str, stat_key: str, today_date: str) -> None:
"""Increment usage counters for a resolved stats key."""
if stat_key not in self.stats[category]:
@@ -510,6 +527,55 @@ class UsageStats:
except Exception as e:
logger.error(f"Error processing LoRA usage: {e}", exc_info=True)
@staticmethod
def _extract_embedding_names(prompt_text: str) -> set:
    """Parse embedding:name references from prompt text.

    ComfyUI's SDTokenizer resolves ``embedding:<name>`` during tokenization
    (see ``sd1_clip.py _try_get_embedding``). This mirrors the same pattern
    to extract embedding file names from the captured prompt strings.

    Args:
        prompt_text: Prompt string to scan. Falsy values (``None``, ``""``)
            yield an empty set.

    Returns:
        The set of distinct names found after an ``embedding:`` prefix.
        Duplicates within the text collapse to one entry.
    """
    if not prompt_text:
        return set()
    # Matches ``embedding:name`` where name is made of word characters plus
    # . - /  ``\w`` is used instead of an ASCII-only class so that file names
    # containing non-ASCII letters are captured whole rather than being
    # truncated at the first such character. ASCII results are unchanged.
    # A following ":" (weight syntax), "," or whitespace ends the name.
    return set(re.findall(r"embedding:([\w.\-/]+)", prompt_text))
async def _process_embeddings(self, prompts_data, today_date):
    """Record embedding usage found in the captured prompt texts.

    Each dict entry of ``prompts_data`` is inspected for the string fields
    ``text``, ``positive_text`` and ``negative_text``. Every
    ``embedding:<name>`` reference found in them is collected, so a name is
    counted at most once per call. Each name is then resolved to a hash
    through the embedding scanner and the "embeddings" usage counter for
    that hash is incremented. Names without a hash are skipped with a debug
    log. Any exception is logged and not re-raised.

    Args:
        prompts_data: Mapping of node id to prompt data; non-dict values
            are ignored.
        today_date: Date key passed on to the usage counter.
    """
    try:
        scanner = await ServiceRegistry.get_embedding_scanner()
        if not scanner:
            logger.warning("Embedding scanner not available for usage tracking")
            return

        text_fields = ("text", "positive_text", "negative_text")
        referenced = set()
        for entry in prompts_data.values():
            if not isinstance(entry, dict):
                continue
            for field_name in text_fields:
                candidate = entry.get(field_name)
                if isinstance(candidate, str):
                    referenced.update(self._extract_embedding_names(candidate))

        for name in referenced:
            resolved_hash = scanner.get_hash_by_filename(name)
            if not resolved_hash:
                logger.debug(
                    "No hash found for embedding '%s', skipping usage tracking",
                    name,
                )
                continue
            self._increment_usage_counter("embeddings", resolved_hash, today_date)
    except Exception as e:
        logger.error("Error processing embedding usage: %s", e, exc_info=True)
async def get_stats(self):
    """Return the usage statistics currently held by this instance.

    The object returned is ``self.stats`` itself, not a copy.
    """
    current_stats = self.stats
    return current_stats
@@ -522,6 +588,9 @@ class UsageStats:
elif model_type == "lora":
if sha256 in self.stats["loras"]:
return self.stats["loras"][sha256]["total"]
elif model_type == "embedding":
if sha256 in self.stats["embeddings"]:
return self.stats["embeddings"][sha256]["total"]
return 0
async def process_execution(self, prompt_id):