mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-06-11 13:19:24 -03:00
feat(stats): track embedding usage from prompt text — Plan A + hybrid approach docs
181
.omo/plans/embeddings-hybrid-approach.md
Normal file
@@ -0,0 +1,181 @@
# Embeddings Usage Tracking: Hybrid Approach (Plan C)

> **Status**: Reference document for future implementation
> **Current implementation**: Plan A (prompt text parsing only, see `usage_stats.py:_process_embeddings`)
> **Next step**: Add Plan B as a supplement when edge-case coverage is needed

## Problem

Unlike LoRAs and checkpoints, embeddings in ComfyUI are not loaded through dedicated
nodes. They are resolved during CLIP tokenization when the prompt text contains
`embedding:<name>` syntax (see `comfy/sd1_clip.py:SDTokenizer.tokenize_with_weights`).

This means the existing metadata_collector hook (which intercepts node execution via
`_map_node_over_list`) cannot capture embeddings the same way it captures LoRAs and
checkpoints: there is no "EmbeddingLoader" node to intercept.

## Solution Architecture

The hybrid approach combines **two complementary mechanisms** to capture embedding
usage from all possible paths.

```
┌─────────────────────────────────────────────────────────┐
│ Plan A (implemented)                                    │
│                                                         │
│ MetadataRegistry.prompt_metadata["prompts"]             │
│      │                                                  │
│      ▼                                                  │
│ _process_embeddings()                                   │
│      │                                                  │
│      ├─ Iterate all prompt node texts                   │
│      ├─ regex extract "embedding:<name>"                │
│      ├─ resolve name → sha256 via EmbeddingScanner      │
│      └─ UsageStats.stats["embeddings"][sha256]++        │
│                                                         │
│ Coverage: ~95% (all CLIPTextEncode/Flux/etc. nodes)     │
│                                                         │
│ Gap: Custom nodes that load embeddings programmatically │
│      without putting embedding:name in prompt text      │
└─────────────────────────────────────────────────────────┘
                             +
                             ↓ (future: enable Plan B when needed)
┌─────────────────────────────────────────────────────────┐
│ Plan B (future: monkey-patch)                           │
│                                                         │
│ comfy/sd1_clip.py:load_embed()                          │
│      │                                                  │
│      ▼                                                  │
│ Monkey-patch intercepts EVERY embedding file load       │
│      │                                                  │
│      ├─ Records embedding_name + success/failure        │
│      ├─ Associates with current prompt_id (via registry)│
│      └─ Feeds into UsageStats same as Plan A            │
│                                                         │
│ Coverage: all loads routed through load_embed()         │
│                                                         │
│ Cost: Requires patching into ComfyUI internals          │
│       (sd1_clip.load_embed only; see Files to Patch)    │
└─────────────────────────────────────────────────────────┘
```

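The Plan A flow in the diagram can be sketched end to end without ComfyUI. This is a minimal illustration, not the project's code: the regex is the one from this commit, but `count_embedding_usage`, the plain `name_to_hash` dict (standing in for `EmbeddingScanner`), and the sample prompts and hash are hypothetical.

```python
import re
from collections import Counter

# Same pattern as _extract_embedding_names in this commit.
EMBEDDING_RE = re.compile(r"embedding:([a-zA-Z0-9_.\-/]+)")


def extract_embedding_names(prompt_text):
    """Return the set of names referenced as embedding:<name>."""
    if not prompt_text:
        return set()
    return set(EMBEDDING_RE.findall(prompt_text))


def count_embedding_usage(prompts, name_to_hash, counts):
    """Count each resolvable embedding once per execution (hypothetical helper)."""
    seen = set()
    for prompt_data in prompts.values():
        for field in ("text", "positive_text", "negative_text"):
            text = prompt_data.get(field)
            if isinstance(text, str):
                seen |= extract_embedding_names(text)
    for name in seen:
        sha = name_to_hash.get(name)  # stand-in for the scanner lookup
        if sha:
            counts[sha] += 1
    return seen


# Hypothetical captured prompt metadata: node id -> fields
prompts = {
    "6": {"text": "a castle, embedding:style/ink, (embedding:EasyNegative:1.2)"},
    "7": {"text": "blurry, embedding:EasyNegative"},
}
name_to_hash = {"EasyNegative": "abc123"}  # made-up hash
counts = Counter()
seen = count_embedding_usage(prompts, name_to_hash, counts)
```

One property of the pattern worth knowing: because `.` is in the character class, a reference followed directly by a sentence-ending period (`embedding:foo.`) is captured as `foo.`, so the lookup would have to tolerate or trim it.
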
## Plan B Detail: Monkey-patch `load_embed`

### Target Function

**`comfy.sd1_clip.load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None)`**,
defined at line 415 of `sd1_clip.py` at the time of writing (the line number will drift
across ComfyUI versions).

This is the **single choke point** for all embedding file loads in ComfyUI. Every
CLIP variant (SD1, SDXL, SD3, Flux) calls this same function.

### Implementation Sketch

```python
# In metadata_collector/metadata_hook.py (or a new module)
import comfy.sd1_clip as sd1_clip

# Keep a reference to the original so the patch can delegate to it
_original_load_embed = sd1_clip.load_embed

def _patched_load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
    result = _original_load_embed(
        embedding_name, embedding_directory, embedding_size, embed_key
    )
    if result is not None:
        # Only successful loads count as usage.
        # _record_embedding_usage is a helper still to be written.
        _record_embedding_usage(embedding_name)
    return result

sd1_clip.load_embed = _patched_load_embed
```

### Prompt ID Association

The challenge is associating each `load_embed` call with the current `prompt_id`.
Options:

1. **Thread-local / contextvar**: Store the current `prompt_id` in a `contextvars.ContextVar`
   that the metadata_collector sets at the start of each prompt execution.

2. **MetadataRegistry singleton**: The MetadataRegistry already has `current_prompt_id`.
   The patch can read it directly since both run in the same thread.

3. **Lazy aggregation**: Instead of associating with `prompt_id` at load time, collect
   all loaded embedding names in a global set during execution, then flush to
   UsageStats after the prompt completes.

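Options 1 and 3 combine naturally, and the sketch below shows one way that could look. It runs without ComfyUI: `fake_sd1_clip` is a stand-in object, and `current_prompt_id`, `loaded_by_prompt`, and `run_prompt` are hypothetical names, not existing LoRA Manager APIs.

```python
import contextvars
import types
from collections import defaultdict

# Stand-in for comfy.sd1_clip so the sketch runs on its own (assumption).
fake_sd1_clip = types.SimpleNamespace(
    load_embed=lambda name, directory, size, embed_key=None: (
        None if name == "missing" else [0.0] * size
    )
)

# Option 1: the collector sets this at the start of each prompt execution.
current_prompt_id = contextvars.ContextVar("current_prompt_id", default=None)
# Option 3: names accumulate per prompt and are flushed when the prompt ends.
loaded_by_prompt = defaultdict(set)

_original_load_embed = fake_sd1_clip.load_embed


def _patched_load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
    result = _original_load_embed(
        embedding_name, embedding_directory, embedding_size, embed_key
    )
    prompt_id = current_prompt_id.get()
    if result is not None and prompt_id is not None:
        loaded_by_prompt[prompt_id].add(embedding_name)
    return result


fake_sd1_clip.load_embed = _patched_load_embed


def run_prompt(prompt_id, names):
    """Simulate one execution and return the names to hand to UsageStats."""
    token = current_prompt_id.set(prompt_id)
    try:
        for name in names:
            fake_sd1_clip.load_embed(name, "/embeddings", 4)
    finally:
        current_prompt_id.reset(token)
    return loaded_by_prompt.pop(prompt_id, set())


flushed = run_prompt("p1", ["EasyNegative", "EasyNegative", "missing"])
```

A `ContextVar` value set in one thread is not visible in a thread started separately, so this only works if `load_embed` runs in the same thread (or task context) that set the prompt ID, which is the same assumption option 2 makes.
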
### Files to Patch

| File | Function to patch | Coverage / notes |
|------|-------------------|------------------|
| `comfy/sd1_clip.py:415` | `load_embed()` | Primary: SD1.x, SDXL, SD3, Flux |
| `comfy/sdxl_clip.py` | None | Not needed (calls `sd1_clip.SDTokenizer`) |
| `comfy/text_encoders/sd3_clip.py` | None | Not needed (calls `sd1_clip.SDTokenizer`) |
| `comfy/text_encoders/flux.py` | None | Not needed (calls `sd1_clip.SDTokenizer`) |

The SD1 tokenizer is the base class for all CLIP variants' tokenizers, so patching
`load_embed` covers them all.

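Rebinding `sd1_clip.load_embed` reaches only callers that look the name up in the module's globals at call time. That is assumed here to be how the tokenizer calls it, and should be checked against the ComfyUI version in use. Code that took its own reference earlier (for example via `from comfy.sd1_clip import load_embed`) keeps the original. The toy module below, with hypothetical names throughout, shows the difference:

```python
import types

# Build a throwaway module whose tokenize() calls load_embed as a module global.
lib = types.ModuleType("lib")
exec(
    "def load_embed(name):\n"
    "    return 'original:' + name\n"
    "def tokenize(name):\n"
    "    return load_embed(name)\n",
    lib.__dict__,
)

# Equivalent to another module doing "from lib import load_embed" before the patch.
early_ref = lib.load_embed


def patched(name):
    return "patched:" + name


lib.load_embed = patched  # the monkey-patch

via_module_global = lib.tokenize("x")  # sees the patch
via_early_ref = early_ref("x")         # still the original function
```

If any custom node holds such an early reference, its loads would bypass Plan B, which is one more reason to keep Plan A as the baseline.
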
### Edge Cases

| Edge Case | Plan A | Plan B |
|-----------|--------|--------|
| `embedding:name` in CLIPTextEncode | ✅ | ✅ |
| `embedding:name` in CLIPTextEncodeFlux | ✅ | ✅ |
| `embedding:name` in PromptLM (LoRA Manager) | ✅ | ✅ |
| `embedding:name` in WAS_Text_to_Conditioning | ✅ | ✅ |
| Custom node that loads embedding programmatically | ❌ | ✅ |
| Embedding loaded multiple times in same prompt | ✅ (dedup via set) | ✅ (dedup via set) |
| Embedding file not found | N/A | ✅ (can log) |
| Embedding dimension mismatch | N/A | ✅ (can log) |
| Text encoder with non-standard tokenizer (LLaMA, T5...) | Partial | ✅ (if it calls load_embed) |

## Migration Path: Plan A Only → Hybrid

### Phase 1: Plan A (current state)
- Prompt text parsing only
- No monkey-patching required
- Covers all standard workflows

### Phase 2: Enable Plan B (future work)
1. Add monkey-patch of `load_embed` in `metadata_collector/metadata_hook.py` (alongside
   the existing `_map_node_over_list` hook)
2. Collect loaded embedding names in a `set()` on the registry
3. In `UsageStats._process_embeddings()`, merge the Plan A results (from prompt text)
   with the Plan B results (from the patch)
4. Add `prompt_data` field on MetadataRegistry to store loaded embeddings per prompt

### Deduplication

```python
# Merge Plan A + Plan B results in _process_embeddings
plan_a_names = extract_from_prompt_texts(prompts_data)
plan_b_names = registry.get_loaded_embeddings(prompt_id)

# Set union: a name reported by both plans is counted once
all_names = plan_a_names | plan_b_names
```

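A union of names deduplicates only when both plans report the same string. If the two paths could ever spell one embedding differently (say, with and without a file extension, which is an assumption rather than observed behavior), deduplicating on the resolved hash is more robust. A minimal sketch, with a hypothetical `merge_embedding_hits` helper and a dict standing in for the scanner lookup:

```python
def merge_embedding_hits(plan_a_names, plan_b_names, resolve_hash):
    """Resolve names from both plans and deduplicate on the hash.

    resolve_hash is any callable returning a hash string or None,
    standing in for EmbeddingScanner.get_hash_by_filename.
    """
    hashes = set()
    unresolved = set()
    for name in plan_a_names | plan_b_names:
        sha = resolve_hash(name)
        if sha:
            hashes.add(sha)
        else:
            unresolved.add(name)
    return hashes, unresolved


# Made-up lookup table: two spellings map to the same file hash.
table = {"EasyNegative": "abc123", "EasyNegative.pt": "abc123", "ink": "def456"}
hashes, unresolved = merge_embedding_hits(
    {"EasyNegative", "ghost"},   # from prompt text (Plan A)
    {"EasyNegative.pt", "ink"},  # from the load_embed patch (Plan B)
    table.get,
)
```

Incrementing the counter once per entry in `hashes` then keeps the "no double-counting" guarantee regardless of how each plan spelled the name.
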
## Testing the Hybrid

| Scenario | What to verify |
|----------|---------------|
| Standard `embedding:name` in prompt | Plan A captures it |
| Embedding loaded by custom node script | Plan B captures it |
| Both paths fire for same embedding | No double-counting (dedup) |
| Embedding name resolves to hash | EmbeddingScanner.get_hash_by_filename works |
| No embedding scanner available | Graceful skip, no crash |
| Missing embedding file | Plan B logs warning, Plan A skips gracefully |
| Empty prompt | No crash, no entries |
| Standalone mode | Both plans disabled gracefully |

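The Plan A rows of this table translate directly into small unit tests for the extraction step. The sketch below copies the regex from this commit into a free function so it can run on its own. The test names and sample prompts are illustrative, and pytest-style functions are assumed rather than taken from the repository's test suite.

```python
import re


def extract_embedding_names(prompt_text):
    """Free-function copy of UsageStats._extract_embedding_names."""
    if not prompt_text:
        return set()
    return set(re.findall(r"embedding:([a-zA-Z0-9_.\-/]+)", prompt_text))


def test_standard_reference():
    found = extract_embedding_names("masterpiece, embedding:EasyNegative")
    assert found == {"EasyNegative"}


def test_empty_prompt_yields_no_entries():
    assert extract_embedding_names("") == set()
    assert extract_embedding_names(None) == set()


def test_repeated_reference_is_deduplicated():
    text = "embedding:EasyNegative, embedding:EasyNegative"
    assert extract_embedding_names(text) == {"EasyNegative"}


def test_weighted_syntax_stops_at_weight():
    # The trailing ":0.8" is a prompt weight, not part of the name.
    assert extract_embedding_names("(embedding:bad-hands_v5:0.8)") == {"bad-hands_v5"}


def test_subfolder_and_extension_are_kept():
    assert extract_embedding_names("embedding:neg/verybad.pt") == {"neg/verybad.pt"}
```

The scanner-dependent rows (hash resolution, missing scanner, standalone mode) need the service layer mocked and are not covered here.
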
## Key Files Reference

| File | Role |
|------|------|
| `py/utils/usage_stats.py` | Core: `_process_embeddings()` for Plan A |
| `py/metadata_collector/constants.py` | `EMBEDDINGS` category constant |
| `py/metadata_collector/metadata_hook.py` | Future: monkey-patch for Plan B |
| `py/services/embedding_scanner.py` | Hash resolution service |
| `py/routes/stats_routes.py` | Already handles `usage_data.get('embeddings', {})` |
| `comfy/sd1_clip.py` (ComfyUI) | `load_embed()`: Plan B target |
@@ -5,9 +5,10 @@ MODELS = "models"
 PROMPTS = "prompts"
 SAMPLING = "sampling"
 LORAS = "loras"
+EMBEDDINGS = "embeddings"
 SIZE = "size"
 IMAGES = "images"
 IS_SAMPLER = "is_sampler"  # New constant to mark sampler nodes
 
 # Complete list of categories to track
-METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]
+METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, EMBEDDINGS, SIZE, IMAGES]
@@ -1,4 +1,5 @@
 import os
+import re
 import json
 import time
 import asyncio
@@ -16,14 +17,18 @@ standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.en
 # Define constants locally to avoid dependency on conditional imports
 MODELS = "models"
 LORAS = "loras"
+EMBEDDINGS = "embeddings"
+PROMPTS = "prompts"
 
 if not standalone_mode:
     from ..metadata_collector.metadata_registry import MetadataRegistry
     # Import constants from metadata_collector to ensure consistency, but we have fallbacks defined above
     try:
-        from ..metadata_collector.constants import MODELS as _MODELS, LORAS as _LORAS
+        from ..metadata_collector.constants import MODELS as _MODELS, LORAS as _LORAS, EMBEDDINGS as _EMBEDDINGS, PROMPTS as _PROMPTS
         MODELS = _MODELS
         LORAS = _LORAS
+        EMBEDDINGS = _EMBEDDINGS
+        PROMPTS = _PROMPTS
     except ImportError:
         pass  # Use the local definitions
 
@@ -65,6 +70,7 @@ class UsageStats:
         self.stats = {
             "checkpoints": {},  # sha256 -> { total: count, history: { date: count } }
             "loras": {},  # sha256 -> { total: count, history: { date: count } }
+            "embeddings": {},  # sha256 -> { total: count, history: { date: count } }
             "total_executions": 0,
             "last_save_time": 0
         }
||||||
@@ -115,6 +121,7 @@ class UsageStats:
         new_stats = {
             "checkpoints": {},
             "loras": {},
+            "embeddings": {},
             "total_executions": old_stats.get("total_executions", 0),
             "last_save_time": old_stats.get("last_save_time", time.time())
         }
||||||
@@ -142,19 +149,25 @@ class UsageStats:
                     }
                 }
 
+        # Convert embedding stats (if present in old format)
+        if "embeddings" in old_stats and isinstance(old_stats["embeddings"], dict):
+            for hash_id, count in old_stats["embeddings"].items():
+                new_stats["embeddings"][hash_id] = {
+                    "total": count,
+                    "history": {
+                        today: count
+                    }
+                }
+
         logger.info("Successfully converted stats from old format to new format with history")
         return new_stats
 
     def _is_old_format(self, stats):
         """Check if the stats are in the old format (direct count values)"""
         # Check if any lora or checkpoint entry is a direct number instead of an object
-        if "loras" in stats and isinstance(stats["loras"], dict):
-            for hash_id, data in stats["loras"].items():
-                if isinstance(data, (int, float)):
-                    return True
-
-        if "checkpoints" in stats and isinstance(stats["checkpoints"], dict):
-            for hash_id, data in stats["checkpoints"].items():
-                if isinstance(data, (int, float)):
-                    return True
+        for category in ("loras", "checkpoints", "embeddings"):
+            if category in stats and isinstance(stats[category], dict):
+                for hash_id, data in stats[category].items():
+                    if isinstance(data, (int, float)):
+                        return True
 
@@ -304,6 +317,10 @@ class UsageStats:
         if LORAS in metadata and isinstance(metadata[LORAS], dict):
             await self._process_loras(metadata[LORAS], today)
 
+        # Process embeddings — parse prompt text for embedding:name references
+        if PROMPTS in metadata and isinstance(metadata[PROMPTS], dict):
+            await self._process_embeddings(metadata[PROMPTS], today)
+
     def _increment_usage_counter(self, category: str, stat_key: str, today_date: str) -> None:
         """Increment usage counters for a resolved stats key."""
         if stat_key not in self.stats[category]:
@@ -510,6 +527,55 @@ class UsageStats:
         except Exception as e:
             logger.error(f"Error processing LoRA usage: {e}", exc_info=True)
 
+    @staticmethod
+    def _extract_embedding_names(prompt_text: str) -> set:
+        """Parse embedding:name references from prompt text.
+
+        ComfyUI's SDTokenizer resolves ``embedding:<name>`` during tokenization
+        (see ``sd1_clip.py _try_get_embedding``). This mirrors the same pattern
+        to extract embedding file names from the captured prompt strings.
+        """
+        if not prompt_text:
+            return set()
+        # Matches ``embedding:name`` where name is alphanumeric plus _ . - /
+        names = re.findall(r"embedding:([a-zA-Z0-9_.\-/]+)", prompt_text)
+        return set(names)
+
+    async def _process_embeddings(self, prompts_data, today_date):
+        """Extract embedding usage from prompt texts and record it.
+
+        Iterates every prompt node's text field captured by the metadata
+        collector, extracts ``embedding:<name>`` references, resolves each
+        name to its SHA256 hash via the embedding scanner, and increments
+        usage counters.
+        """
+        try:
+            embedding_scanner = await ServiceRegistry.get_embedding_scanner()
+            if not embedding_scanner:
+                logger.warning("Embedding scanner not available for usage tracking")
+                return
+
+            seen_names = set()
+            for _node_id, prompt_data in prompts_data.items():
+                if not isinstance(prompt_data, dict):
+                    continue
+                for text_field in ("text", "positive_text", "negative_text"):
+                    text = prompt_data.get(text_field)
+                    if isinstance(text, str):
+                        seen_names.update(self._extract_embedding_names(text))
+
+            for emb_name in seen_names:
+                emb_hash = embedding_scanner.get_hash_by_filename(emb_name)
+                if emb_hash:
+                    self._increment_usage_counter("embeddings", emb_hash, today_date)
+                else:
+                    logger.debug(
+                        "No hash found for embedding '%s', skipping usage tracking",
+                        emb_name,
+                    )
+        except Exception as e:
+            logger.error("Error processing embedding usage: %s", e, exc_info=True)
+
     async def get_stats(self):
         """Get current usage statistics"""
         return self.stats
@@ -522,6 +588,9 @@ class UsageStats:
         elif model_type == "lora":
             if sha256 in self.stats["loras"]:
                 return self.stats["loras"][sha256]["total"]
+        elif model_type == "embedding":
+            if sha256 in self.stats["embeddings"]:
+                return self.stats["embeddings"][sha256]["total"]
         return 0
 
     async def process_execution(self, prompt_id):