From f565cc35caf49485ce361234e0f5930568fe5f94 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Thu, 11 Jun 2026 17:12:34 +0800 Subject: [PATCH] =?UTF-8?q?feat(stats):=20track=20embedding=20usage=20from?= =?UTF-8?q?=20prompt=20text=20=E2=80=94=20Plan=20A=20+=20hybrid=20approach?= =?UTF-8?q?=20docs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .omo/plans/embeddings-hybrid-approach.md | 181 +++++++++++++++++++++++ py/metadata_collector/constants.py | 3 +- py/utils/usage_stats.py | 89 +++++++++-- 3 files changed, 262 insertions(+), 11 deletions(-) create mode 100644 .omo/plans/embeddings-hybrid-approach.md diff --git a/.omo/plans/embeddings-hybrid-approach.md b/.omo/plans/embeddings-hybrid-approach.md new file mode 100644 index 00000000..23b45c17 --- /dev/null +++ b/.omo/plans/embeddings-hybrid-approach.md @@ -0,0 +1,181 @@ +# Embeddings Usage Tracking — Hybrid Approach (Plan C) + +> **Status**: Reference document for future implementation +> **Current implementation**: Plan A (prompt text parsing only, see `usage_stats.py:_process_embeddings`) +> **Next step**: Add Plan B as a supplement when edge-case coverage is needed + +## Problem + +Embeddings in ComfyUI are not loaded through dedicated ComfyUI nodes like LoRAs or +Checkpoints. They are resolved during CLIP tokenization when the prompt text contains +`embedding:` syntax (see `comfy/sd1_clip.py:SDTokenizer.tokenize_with_weights`). + +This means the existing metadata_collector hook (which intercepts node execution via +`_map_node_over_list`) cannot capture embeddings the same way it captures LoRAs and +checkpoints — there is no "EmbeddingLoader" node to intercept. + +## Solution Architecture + +The hybrid approach combines **two complementary mechanisms** to capture embedding +usage from all possible paths. + +``` +┌─────────────────────────────────────────────────────────┐ +│ Plan A (已实现) │ +│ │ +│ MetadataRegistry.prompt_metadata["prompts"] │ +│ │ │ +│ ▼ │ +│ _process_embeddings() │ +│ │ │ +│ ├─ Iterate all prompt node texts │ +│ ├─ regex extract "embedding:" │ +│ ├─ resolve name → sha256 via EmbeddingScanner │ +│ └─ UsageStats.stats["embeddings"][sha256]++ │ +│ │ +│ Coverage: ~95% — all CLIPTextEncode/Flux/etc nodes │ +│ │ +│ Gap: Custom nodes that load embeddings programmatically │ +│ without putting embedding:name in prompt text │ +└─────────────────────────────────────────────────────────┘ + + + + ↓ (future: enable Plan B when needed) + +┌─────────────────────────────────────────────────────────┐ +│ Plan B (未来 — monkey-patch) │ +│ │ +│ comfy/sd1_clip.py:load_embed() │ +│ │ │ +│ ▼ │ +│ Monkey-patch intercepts EVERY embedding file load │ +│ │ │ +│ ├─ Records embedding_name + success/failure │ +│ ├─ Associates with current prompt_id (via registry)│ +│ └─ Feeds into UsageStats same as Plan A │ +│ │ +│ Coverage: 100% — catches ALL embedding loads │ +│ │ +│ Cost: Requires patching into ComfyUI internals │ +│ (sd1_clip.py, sdxl_clip.py, some text_encoders) │ +└─────────────────────────────────────────────────────────┘ +``` + +## Plan B Detail — Monkey-patch `load_embed` + +### Target Function + +**`comfy.sd1_clip.load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None)`** +at line 415 of `sd1_clip.py`. + +This is the **single choke point** for all embedding file loads in ComfyUI. Every +CLIP variant (SD1, SDXL, SD3, Flux) calls this same function. + +### Implementation Sketch + +```python +# In metadata_collector/metadata_hook.py (or a new module) +import comfy.sd1_clip as sd1_clip + +_original_load_embed = sd1_clip.load_embed + +def _patched_load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None): + result = _original_load_embed( + embedding_name, embedding_directory, embedding_size, embed_key + ) + if result is not None: + _record_embedding_usage(embedding_name) + return result + +sd1_clip.load_embed = _patched_load_embed +``` + +### Prompt ID Association + +The challenge is associating the `load_embed` call with the current `prompt_id`. +Options: + +1. **Thread-local / contextvar**: Store current `prompt_id` in a `contextvars.ContextVar` + that the metadata_collector sets at the start of each prompt execution. + +2. **MetadataRegistry singleton**: The MetadataRegistry already has `current_prompt_id`. + The patch can read it directly since both run in the same thread. + +3. **Lazy aggregation**: Instead of associating with prompt_id at load time, collect + all loaded embedding names in a global set during execution, then flush to + UsageStats after the prompt completes. + +### Files to Patch + +| File | Function | Coverage | +|------|----------|----------| +| `comfy/sd1_clip.py:415` | `load_embed()` | Primary — SD1.x, SDXL, SD3, Flux | +| `comfy/sdxl_clip.py` | Not needed (calls `sd1_clip.SDTokenizer`) | — | +| `comfy/text_encoders/sd3_clip.py` | Not needed (calls `sd1_clip.SDTokenizer`) | — | +| `comfy/text_encoders/flux.py` | Not needed (calls `sd1_clip.SDTokenizer`) | — | + +The SD1 tokenizer is the base class for all CLIP variants' tokenizers, so patching +`load_embed` covers them all. + +### Edge Cases + +| Edge Case | Plan A | Plan B | +|-----------|--------|--------| +| `embedding:name` in CLIPTextEncode | ✅ | ✅ | +| `embedding:name` in CLIPTextEncodeFlux | ✅ | ✅ | +| `embedding:name` in PromptLM (LoRA Manager) | ✅ | ✅ | +| `embedding:name` in WAS_Text_to_Conditioning | ✅ | ✅ | +| Custom node that loads embedding programmatically | ❌ | ✅ | +| Embedding loaded multiple times in same prompt | ✅ (dedup via set) | ✅ (dedup via set) | +| Embedding file not found | N/A | ✅ (can log) | +| Embedding dimension mismatch | N/A | ✅ (can log) | +| Text encoder with non-standard tokenizer (LLaMA, T5...) | Partial | ✅ (if it calls load_embed) | + +## Migration Path: Standalone → Hybrid + +### Phase 1 — Plan A (当前状态) +- Prompt text parsing only +- No monkey-patching required +- Covers all standard workflows + +### Phase 2 — Enable Plan B (未来工作) +1. Add monkey-patch of `load_embed` in `metadata_collector/metadata_hook.py` (alongside + the existing `_map_node_over_list` hook) +2. Collect loaded embedding names in a `set()` on the registry +3. In `UsageStats._process_embeddings()`, merge the Plan A results (from prompt text) + with the Plan B results (from the patch) +4. Add `prompt_data` field on MetadataRegistry to store loaded embeddings per prompt + +### Deduplication + +```python +# Merge Plan A + Plan B results in _process_embeddings +plan_a_names = extract_from_prompt_texts(prompts_data) +plan_b_names = registry.get_loaded_embeddings(prompt_id) + +all_names = plan_a_names | plan_b_names +``` + +## Testing the Hybrid + +| Scenario | What to verify | +|----------|---------------| +| Standard `embedding:name` in prompt | Plan A captures it | +| Embedding loaded by custom node script | Plan B captures it | +| Both paths fire for same embedding | No double-counting (dedup) | +| Embedding name resolves to hash | EmbeddingScanner.get_hash_by_filename works | +| No embedding scanner available | Graceful skip, no crash | +| Missing embedding file | Plan B logs warning, Plan A skips gracefully | +| Empty prompt | No crash, no entries | +| Standalone mode | Both plans disabled gracefully | + +## Key Files Reference + +| File | Role | +|------|------| +| `py/utils/usage_stats.py` | Core — `_process_embeddings()` for Plan A | +| `py/metadata_collector/constants.py` | `EMBEDDINGS` category constant | +| `py/metadata_collector/metadata_hook.py` | Future — monkey-patch for Plan B | +| `py/services/embedding_scanner.py` | Hash resolution service | +| `py/routes/stats_routes.py` | Already handles `usage_data.get('embeddings', {})` | +| `comfy/sd1_clip.py` (ComfyUI) | `load_embed()` — Plan B target | diff --git a/py/metadata_collector/constants.py b/py/metadata_collector/constants.py index b38f010a..85072c05 100644 --- a/py/metadata_collector/constants.py +++ b/py/metadata_collector/constants.py @@ -5,9 +5,10 @@ MODELS = "models" PROMPTS = "prompts" SAMPLING = "sampling" LORAS = "loras" +EMBEDDINGS = "embeddings" SIZE = "size" IMAGES = "images" IS_SAMPLER = "is_sampler" # New constant to mark sampler nodes # Complete list of categories to track -METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES] +METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, EMBEDDINGS, SIZE, IMAGES] diff --git a/py/utils/usage_stats.py b/py/utils/usage_stats.py index b9fe4ac8..a603668c 100644 --- a/py/utils/usage_stats.py +++ b/py/utils/usage_stats.py @@ -1,4 +1,5 @@ import os +import re import json import time import asyncio @@ -16,14 +17,18 @@ standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.en # Define constants locally to avoid dependency on conditional imports MODELS = "models" LORAS = "loras" +EMBEDDINGS = "embeddings" +PROMPTS = "prompts" if not standalone_mode: from ..metadata_collector.metadata_registry import MetadataRegistry # Import constants from metadata_collector to ensure consistency, but we have fallbacks defined above try: - from ..metadata_collector.constants import MODELS as _MODELS, LORAS as _LORAS + from ..metadata_collector.constants import MODELS as _MODELS, LORAS as _LORAS, EMBEDDINGS as _EMBEDDINGS, PROMPTS as _PROMPTS MODELS = _MODELS LORAS = _LORAS + EMBEDDINGS = _EMBEDDINGS + PROMPTS = _PROMPTS except ImportError: pass # Use the local definitions @@ -65,6 +70,7 @@ class UsageStats: self.stats = { "checkpoints": {}, # sha256 -> { total: count, history: { date: count } } "loras": {}, # sha256 -> { total: count, history: { date: count } } + "embeddings": {}, # sha256 -> { total: count, history: { date: count } } "total_executions": 0, "last_save_time": 0 } @@ -115,6 +121,7 @@ class UsageStats: new_stats = { "checkpoints": {}, "loras": {}, + "embeddings": {}, "total_executions": old_stats.get("total_executions", 0), "last_save_time": old_stats.get("last_save_time", time.time()) } @@ -142,21 +149,27 @@ class UsageStats: } } + # Convert embedding stats (if present in old format) + if "embeddings" in old_stats and isinstance(old_stats["embeddings"], dict): + for hash_id, count in old_stats["embeddings"].items(): + new_stats["embeddings"][hash_id] = { + "total": count, + "history": { + today: count + } + } + logger.info("Successfully converted stats from old format to new format with history") return new_stats def _is_old_format(self, stats): """Check if the stats are in the old format (direct count values)""" # Check if any lora or checkpoint entry is a direct number instead of an object - if "loras" in stats and isinstance(stats["loras"], dict): - for hash_id, data in stats["loras"].items(): - if isinstance(data, (int, float)): - return True - - if "checkpoints" in stats and isinstance(stats["checkpoints"], dict): - for hash_id, data in stats["checkpoints"].items(): - if isinstance(data, (int, float)): - return True + for category in ("loras", "checkpoints", "embeddings"): + if category in stats and isinstance(stats[category], dict): + for hash_id, data in stats[category].items(): + if isinstance(data, (int, float)): + return True return False @@ -304,6 +317,10 @@ class UsageStats: if LORAS in metadata and isinstance(metadata[LORAS], dict): await self._process_loras(metadata[LORAS], today) + # Process embeddings — parse prompt text for embedding:name references + if PROMPTS in metadata and isinstance(metadata[PROMPTS], dict): + await self._process_embeddings(metadata[PROMPTS], today) + def _increment_usage_counter(self, category: str, stat_key: str, today_date: str) -> None: """Increment usage counters for a resolved stats key.""" if stat_key not in self.stats[category]: @@ -510,6 +527,55 @@ class UsageStats: except Exception as e: logger.error(f"Error processing LoRA usage: {e}", exc_info=True) + @staticmethod + def _extract_embedding_names(prompt_text: str) -> set: + """Parse embedding:name references from prompt text. + + ComfyUI's SDTokenizer resolves ``embedding:`` during tokenization + (see ``sd1_clip.py _try_get_embedding``). This mirrors the same pattern + to extract embedding file names from the captured prompt strings. + """ + if not prompt_text: + return set() + # Matches ``embedding:name`` where name is alphanumeric plus _ . - / + names = re.findall(r"embedding:([a-zA-Z0-9_.\-/]+)", prompt_text) + return set(names) + + async def _process_embeddings(self, prompts_data, today_date): + """Extract embedding usage from prompt texts and record it. + + Iterates every prompt node's text field captured by the metadata + collector, extracts ``embedding:`` references, resolves each + name to its SHA256 hash via the embedding scanner, and increments + usage counters. + """ + try: + embedding_scanner = await ServiceRegistry.get_embedding_scanner() + if not embedding_scanner: + logger.warning("Embedding scanner not available for usage tracking") + return + + seen_names = set() + for _node_id, prompt_data in prompts_data.items(): + if not isinstance(prompt_data, dict): + continue + for text_field in ("text", "positive_text", "negative_text"): + text = prompt_data.get(text_field) + if isinstance(text, str): + seen_names.update(self._extract_embedding_names(text)) + + for emb_name in seen_names: + emb_hash = embedding_scanner.get_hash_by_filename(emb_name) + if emb_hash: + self._increment_usage_counter("embeddings", emb_hash, today_date) + else: + logger.debug( + "No hash found for embedding '%s', skipping usage tracking", + emb_name, + ) + except Exception as e: + logger.error("Error processing embedding usage: %s", e, exc_info=True) + async def get_stats(self): """Get current usage statistics""" return self.stats @@ -522,6 +588,9 @@ class UsageStats: elif model_type == "lora": if sha256 in self.stats["loras"]: return self.stats["loras"][sha256]["total"] + elif model_type == "embedding": + if sha256 in self.stats["embeddings"]: + return self.stats["embeddings"][sha256]["total"] return 0 async def process_execution(self, prompt_id):