mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-04-10 04:42:14 -03:00
Merge pull request #887 from NubeBuster/feat/usage-extractors
feat(usage-stats): add extractors for rgthree Power LoRA Loader and TensorRT loaders
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES, IS_SAMPLER
|
||||
|
||||
@@ -427,6 +429,75 @@ class ImageSizeExtractor(NodeMetadataExtractor):
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class RgthreePowerLoraLoaderExtractor(NodeMetadataExtractor):
|
||||
"""Extract LoRA metadata from rgthree Power Lora Loader.
|
||||
|
||||
The node passes LoRAs as dynamic kwargs: LORA_1, LORA_2, ... each containing
|
||||
{'on': bool, 'lora': filename, 'strength': float, 'strengthTwo': float}.
|
||||
"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
active_loras = []
|
||||
for key, value in inputs.items():
|
||||
if not key.upper().startswith('LORA_'):
|
||||
continue
|
||||
if not isinstance(value, dict):
|
||||
continue
|
||||
if not value.get('on') or not value.get('lora'):
|
||||
continue
|
||||
lora_name = os.path.splitext(os.path.basename(value['lora']))[0]
|
||||
active_loras.append({
|
||||
"name": lora_name,
|
||||
"strength": round(float(value.get('strength', 1.0)), 2)
|
||||
})
|
||||
|
||||
if active_loras:
|
||||
metadata[LORAS][node_id] = {
|
||||
"lora_list": active_loras,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
|
||||
class TensorRTLoaderExtractor(NodeMetadataExtractor):
|
||||
"""Extract checkpoint metadata from TensorRT Loader.
|
||||
|
||||
extract() parses the engine filename from 'unet_name' as a best-effort
|
||||
fallback (strips profile suffix after '_$' and counter suffix).
|
||||
|
||||
update() checks if the output MODEL has attachments["source_model"]
|
||||
set by the node (NubeBuster fork) and overrides with the real name.
|
||||
Vanilla TRT doesn't set this — the filename parse stands.
|
||||
"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs or "unet_name" not in inputs:
|
||||
return
|
||||
unet_name = inputs.get("unet_name")
|
||||
# Strip path and extension, then drop the $_profile suffix
|
||||
model_name = os.path.splitext(os.path.basename(unet_name))[0]
|
||||
if "_$" in model_name:
|
||||
model_name = model_name[:model_name.index("_$")]
|
||||
# Strip counter suffix (e.g. _00001_) left by ComfyUI's save path
|
||||
model_name = re.sub(r'_\d+_?$', '', model_name)
|
||||
_store_checkpoint_metadata(metadata, node_id, model_name)
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
if not outputs or not isinstance(outputs, list) or len(outputs) == 0:
|
||||
return
|
||||
first_output = outputs[0]
|
||||
if not isinstance(first_output, tuple) or len(first_output) < 1:
|
||||
return
|
||||
model = first_output[0]
|
||||
# NubeBuster fork sets attachments["source_model"] on the ModelPatcher
|
||||
source_model = getattr(model, 'attachments', {}).get("source_model")
|
||||
if source_model:
|
||||
_store_checkpoint_metadata(metadata, node_id, source_model)
|
||||
|
||||
|
||||
class LoraLoaderManagerExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
@@ -577,8 +648,6 @@ class SamplerCustomAdvancedExtractor(BaseSamplerExtractor):
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
import json
|
||||
|
||||
class CLIPTextEncodeFluxExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
@@ -715,6 +784,8 @@ NODE_EXTRACTORS = {
|
||||
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"LoraLoader": LoraLoaderExtractor,
|
||||
"LoraLoaderLM": LoraLoaderManagerExtractor,
|
||||
"RgthreePowerLoraLoader": RgthreePowerLoraLoaderExtractor,
|
||||
"TensorRTLoader": TensorRTLoaderExtractor,
|
||||
# Conditioning
|
||||
"CLIPTextEncode": CLIPTextEncodeExtractor,
|
||||
"PromptLM": CLIPTextEncodeExtractor,
|
||||
|
||||
@@ -317,21 +317,23 @@ class UsageStats:
|
||||
|
||||
# Get hash for this checkpoint
|
||||
model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
|
||||
if model_hash:
|
||||
# Update stats for this checkpoint with date tracking
|
||||
if model_hash not in self.stats["checkpoints"]:
|
||||
self.stats["checkpoints"][model_hash] = {
|
||||
"total": 0,
|
||||
"history": {}
|
||||
}
|
||||
|
||||
# Increment total count
|
||||
self.stats["checkpoints"][model_hash]["total"] += 1
|
||||
|
||||
# Increment today's count
|
||||
if today_date not in self.stats["checkpoints"][model_hash]["history"]:
|
||||
self.stats["checkpoints"][model_hash]["history"][today_date] = 0
|
||||
self.stats["checkpoints"][model_hash]["history"][today_date] += 1
|
||||
if not model_hash:
|
||||
logger.warning(f"No hash found for checkpoint '{model_filename}', tracking by name")
|
||||
stat_key = model_hash or f"name:{model_filename}"
|
||||
# Update stats for this checkpoint with date tracking
|
||||
if stat_key not in self.stats["checkpoints"]:
|
||||
self.stats["checkpoints"][stat_key] = {
|
||||
"total": 0,
|
||||
"history": {}
|
||||
}
|
||||
|
||||
# Increment total count
|
||||
self.stats["checkpoints"][stat_key]["total"] += 1
|
||||
|
||||
# Increment today's count
|
||||
if today_date not in self.stats["checkpoints"][stat_key]["history"]:
|
||||
self.stats["checkpoints"][stat_key]["history"][today_date] = 0
|
||||
self.stats["checkpoints"][stat_key]["history"][today_date] += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)
|
||||
|
||||
@@ -360,21 +362,23 @@ class UsageStats:
|
||||
|
||||
# Get hash for this LoRA
|
||||
lora_hash = lora_scanner.get_hash_by_filename(lora_name)
|
||||
if lora_hash:
|
||||
# Update stats for this LoRA with date tracking
|
||||
if lora_hash not in self.stats["loras"]:
|
||||
self.stats["loras"][lora_hash] = {
|
||||
"total": 0,
|
||||
"history": {}
|
||||
}
|
||||
|
||||
# Increment total count
|
||||
self.stats["loras"][lora_hash]["total"] += 1
|
||||
|
||||
# Increment today's count
|
||||
if today_date not in self.stats["loras"][lora_hash]["history"]:
|
||||
self.stats["loras"][lora_hash]["history"][today_date] = 0
|
||||
self.stats["loras"][lora_hash]["history"][today_date] += 1
|
||||
if not lora_hash:
|
||||
logger.warning(f"No hash found for LoRA '{lora_name}', tracking by name")
|
||||
stat_key = lora_hash or f"name:{lora_name}"
|
||||
# Update stats for this LoRA with date tracking
|
||||
if stat_key not in self.stats["loras"]:
|
||||
self.stats["loras"][stat_key] = {
|
||||
"total": 0,
|
||||
"history": {}
|
||||
}
|
||||
|
||||
# Increment total count
|
||||
self.stats["loras"][stat_key]["total"] += 1
|
||||
|
||||
# Increment today's count
|
||||
if today_date not in self.stats["loras"][stat_key]["history"]:
|
||||
self.stats["loras"][stat_key]["history"][today_date] = 0
|
||||
self.stats["loras"][stat_key]["history"][today_date] += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing LoRA usage: {e}", exc_info=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user