mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-06-11 13:19:24 -03:00
feat(stats): track embedding usage from prompt text — Plan A + hybrid approach docs
This commit is contained in:
@@ -5,9 +5,10 @@ MODELS = "models"
|
||||
# Category keys used to group collected metadata.
# NOTE(review): MODELS is referenced below but is defined above this span.
PROMPTS = "prompts"
SAMPLING = "sampling"
LORAS = "loras"
EMBEDDINGS = "embeddings"
SIZE = "size"
IMAGES = "images"
IS_SAMPLER = "is_sampler"  # New constant to mark sampler nodes

# Complete list of categories to track.
# The earlier assignment without EMBEDDINGS was immediately overwritten by
# this one, so only the effective definition is kept.
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, EMBEDDINGS, SIZE, IMAGES]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
@@ -16,14 +17,18 @@ standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.en
|
||||
# Define constants locally to avoid dependency on conditional imports
MODELS = "models"
LORAS = "loras"
EMBEDDINGS = "embeddings"
PROMPTS = "prompts"

if not standalone_mode:
    from ..metadata_collector.metadata_registry import MetadataRegistry

    # Import constants from metadata_collector to ensure consistency, but we
    # have fallbacks defined above. A single import covers all four names;
    # the narrower MODELS/LORAS-only import that preceded it was redundant.
    try:
        from ..metadata_collector.constants import (
            MODELS as _MODELS,
            LORAS as _LORAS,
            EMBEDDINGS as _EMBEDDINGS,
            PROMPTS as _PROMPTS,
        )
    except ImportError:
        pass  # Use the local definitions
    else:
        # Only rebind once the import is known to have succeeded as a whole.
        MODELS = _MODELS
        LORAS = _LORAS
        EMBEDDINGS = _EMBEDDINGS
        PROMPTS = _PROMPTS
|
||||
|
||||
@@ -65,6 +70,7 @@ class UsageStats:
|
||||
self.stats = {
|
||||
"checkpoints": {}, # sha256 -> { total: count, history: { date: count } }
|
||||
"loras": {}, # sha256 -> { total: count, history: { date: count } }
|
||||
"embeddings": {}, # sha256 -> { total: count, history: { date: count } }
|
||||
"total_executions": 0,
|
||||
"last_save_time": 0
|
||||
}
|
||||
@@ -115,6 +121,7 @@ class UsageStats:
|
||||
new_stats = {
|
||||
"checkpoints": {},
|
||||
"loras": {},
|
||||
"embeddings": {},
|
||||
"total_executions": old_stats.get("total_executions", 0),
|
||||
"last_save_time": old_stats.get("last_save_time", time.time())
|
||||
}
|
||||
@@ -142,21 +149,27 @@ class UsageStats:
|
||||
}
|
||||
}
|
||||
|
||||
# Convert embedding stats (if present in old format)
|
||||
if "embeddings" in old_stats and isinstance(old_stats["embeddings"], dict):
|
||||
for hash_id, count in old_stats["embeddings"].items():
|
||||
new_stats["embeddings"][hash_id] = {
|
||||
"total": count,
|
||||
"history": {
|
||||
today: count
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("Successfully converted stats from old format to new format with history")
|
||||
return new_stats
|
||||
|
||||
def _is_old_format(self, stats):
    """Check if the stats are in the old format (direct count values).

    A category is considered old-format when any of its entries is a bare
    number instead of an object. The three tracked categories are scanned
    in one pass; separate per-category loops for ``loras`` and
    ``checkpoints`` would only repeat the same check.

    Args:
        stats: Loaded stats mapping; categories that are missing or are
            not dicts are ignored.

    Returns:
        True if any ``loras``, ``checkpoints`` or ``embeddings`` entry is
        an ``int`` or ``float``, otherwise False.
    """
    for category in ("loras", "checkpoints", "embeddings"):
        if category not in stats:
            continue
        entries = stats[category]
        if not isinstance(entries, dict):
            continue
        if any(isinstance(data, (int, float)) for data in entries.values()):
            return True

    return False
|
||||
|
||||
@@ -304,6 +317,10 @@ class UsageStats:
|
||||
if LORAS in metadata and isinstance(metadata[LORAS], dict):
|
||||
await self._process_loras(metadata[LORAS], today)
|
||||
|
||||
# Process embeddings — parse prompt text for embedding:name references
|
||||
if PROMPTS in metadata and isinstance(metadata[PROMPTS], dict):
|
||||
await self._process_embeddings(metadata[PROMPTS], today)
|
||||
|
||||
def _increment_usage_counter(self, category: str, stat_key: str, today_date: str) -> None:
|
||||
"""Increment usage counters for a resolved stats key."""
|
||||
if stat_key not in self.stats[category]:
|
||||
@@ -510,6 +527,55 @@ class UsageStats:
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing LoRA usage: {e}", exc_info=True)
|
||||
|
||||
@staticmethod
def _extract_embedding_names(prompt_text: str) -> set:
    """Parse ``embedding:name`` references from prompt text.

    ComfyUI's SDTokenizer resolves ``embedding:<name>`` during tokenization
    (see ``sd1_clip.py _try_get_embedding``). This mirrors the same pattern
    to extract embedding file names from the captured prompt strings.

    Args:
        prompt_text: Raw prompt string. Empty or non-string values yield
            an empty set instead of raising.

    Returns:
        The set of distinct names following ``embedding:``. A name is made
        of ASCII letters, digits and ``_ . - /``; a trailing weight such as
        ``:1.2`` is not part of the name because ``:`` ends the match.
    """
    if not prompt_text or not isinstance(prompt_text, str):
        return set()
    # ``(?<!\w)`` keeps the marker from matching in the middle of another
    # token (e.g. ``text_embedding:foo``).
    # NOTE(review): assumes the tokenizer only honours the marker at the
    # start of a word - confirm against sd1_clip.py.
    names = re.findall(r"(?<!\w)embedding:([a-zA-Z0-9_.\-/]+)", prompt_text)
    return set(names)
|
||||
|
||||
async def _process_embeddings(self, prompts_data, today_date):
    """Record embedding usage found in captured prompt text.

    Scans the ``text``, ``positive_text`` and ``negative_text`` fields of
    every prompt node entry, collects the distinct ``embedding:<name>``
    references, asks the embedding scanner for each name's hash and bumps
    the ``embeddings`` usage counter for every name that resolves. Names
    without a hash are skipped with a debug log. Any exception is logged
    and swallowed so usage tracking never propagates errors to the caller.

    Args:
        prompts_data: Mapping of node id to that node's prompt data;
            non-dict values are ignored.
        today_date: Date key forwarded to ``_increment_usage_counter``.
    """
    try:
        scanner = await ServiceRegistry.get_embedding_scanner()
        if not scanner:
            logger.warning("Embedding scanner not available for usage tracking")
            return

        text_fields = ("text", "positive_text", "negative_text")

        # Gather each referenced name once per call, regardless of how many
        # nodes or fields mention it.
        referenced = set()
        for node_data in prompts_data.values():
            if not isinstance(node_data, dict):
                continue
            for candidate in (node_data.get(field) for field in text_fields):
                if isinstance(candidate, str):
                    referenced.update(self._extract_embedding_names(candidate))

        for name in referenced:
            resolved_hash = scanner.get_hash_by_filename(name)
            if not resolved_hash:
                logger.debug(
                    "No hash found for embedding '%s', skipping usage tracking",
                    name,
                )
                continue
            self._increment_usage_counter("embeddings", resolved_hash, today_date)
    except Exception as e:
        logger.error("Error processing embedding usage: %s", e, exc_info=True)
|
||||
|
||||
async def get_stats(self):
    """Return the current usage statistics.

    The live ``self.stats`` object is handed back as-is, not a copy.
    """
    return self.stats
|
||||
@@ -522,6 +588,9 @@ class UsageStats:
|
||||
elif model_type == "lora":
|
||||
if sha256 in self.stats["loras"]:
|
||||
return self.stats["loras"][sha256]["total"]
|
||||
elif model_type == "embedding":
|
||||
if sha256 in self.stats["embeddings"]:
|
||||
return self.stats["embeddings"][sha256]["total"]
|
||||
return 0
|
||||
|
||||
async def process_execution(self, prompt_id):
|
||||
|
||||
Reference in New Issue
Block a user