mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 13:42:12 -03:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc4c11ddd2 | ||
|
|
d389e4d5d4 | ||
|
|
8cb78ad931 | ||
|
|
85f987d15c | ||
|
|
b12079e0f6 | ||
|
|
dcf5c6167a | ||
|
|
b395d3f487 | ||
|
|
37662cad10 | ||
|
|
aa1673063d | ||
|
|
f51f49eb60 | ||
|
|
54c9bac961 | ||
|
|
e70fd73bdd | ||
|
|
9bb9e7b64d | ||
|
|
f64c03543a | ||
|
|
51374de1a1 | ||
|
|
afcc12f263 | ||
|
|
88c5482366 | ||
|
|
bbf7295c32 | ||
|
|
ca5e23e68c | ||
|
|
eadb1487ae | ||
|
|
1faa70fc77 | ||
|
|
30d7c007de | ||
|
|
f54f6a4402 | ||
|
|
7b41cdec65 | ||
|
|
fb6a652a57 | ||
|
|
ea34d753c1 |
@@ -20,6 +20,12 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
|
||||
|
||||
## Release Notes
|
||||
|
||||
### v0.8.8
|
||||
* **Real-time TriggerWord Updates** - Enhanced TriggerWord Toggle node to instantly update when connected Lora Loader or Lora Stacker nodes change, without requiring workflow execution
|
||||
* **Optimized Metadata Recovery** - Improved utilization of existing .civitai.info files for faster initialization and preservation of metadata from models deleted from CivitAI
|
||||
* **Migration Acceleration** - Further speed improvements for users transitioning from A1111/Forge environments
|
||||
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
|
||||
|
||||
### v0.8.7
|
||||
* **Enhanced Context Menu** - Added comprehensive context menu functionality to Recipes and Checkpoints pages for improved workflow
|
||||
* **Interactive LoRA Strength Control** - Implemented drag functionality in LoRA Loader for intuitive strength adjustment
|
||||
|
||||
30
py/config.py
30
py/config.py
@@ -103,21 +103,29 @@ class Config:
|
||||
|
||||
def _init_lora_paths(self) -> List[str]:
|
||||
"""Initialize and validate LoRA paths from ComfyUI settings"""
|
||||
paths = sorted(set(path.replace(os.sep, "/")
|
||||
for path in folder_paths.get_folder_paths("loras")
|
||||
if os.path.exists(path)), key=lambda p: p.lower())
|
||||
print("Found LoRA roots:", "\n - " + "\n - ".join(paths))
|
||||
raw_paths = folder_paths.get_folder_paths("loras")
|
||||
|
||||
if not paths:
|
||||
# Normalize and resolve symlinks, store mapping from resolved -> original
|
||||
path_map = {}
|
||||
for path in raw_paths:
|
||||
if os.path.exists(path):
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
path_map[real_path] = path_map.get(real_path, path) # preserve first seen
|
||||
|
||||
# Now sort and use only the deduplicated real paths
|
||||
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
|
||||
print("Found LoRA roots:", "\n - " + "\n - ".join(unique_paths))
|
||||
|
||||
if not unique_paths:
|
||||
raise ValueError("No valid loras folders found in ComfyUI configuration")
|
||||
|
||||
# 初始化路径映射
|
||||
for path in paths:
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
if real_path != path:
|
||||
self.add_path_mapping(path, real_path)
|
||||
for original_path in unique_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return paths
|
||||
return unique_paths
|
||||
|
||||
|
||||
def _init_checkpoint_paths(self) -> List[str]:
|
||||
"""Initialize and validate checkpoint paths from ComfyUI settings"""
|
||||
|
||||
@@ -5,6 +5,8 @@ from .routes.lora_routes import LoraRoutes
|
||||
from .routes.api_routes import ApiRoutes
|
||||
from .routes.recipe_routes import RecipeRoutes
|
||||
from .routes.checkpoints_routes import CheckpointsRoutes
|
||||
from .routes.update_routes import UpdateRoutes
|
||||
from .routes.usage_stats_routes import UsageStatsRoutes
|
||||
from .services.service_registry import ServiceRegistry
|
||||
import logging
|
||||
|
||||
@@ -92,6 +94,8 @@ class LoraManager:
|
||||
checkpoints_routes.setup_routes(app)
|
||||
ApiRoutes.setup_routes(app)
|
||||
RecipeRoutes.setup_routes(app)
|
||||
UpdateRoutes.setup_routes(app)
|
||||
UsageStatsRoutes.setup_routes(app) # Register usage stats routes
|
||||
|
||||
# Schedule service initialization
|
||||
app.on_startup.append(lambda app: cls._initialize_services())
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""Constants used by the metadata collector"""
|
||||
|
||||
# Individual category constants
|
||||
# Metadata collection constants
|
||||
|
||||
# Metadata categories
|
||||
MODELS = "models"
|
||||
PROMPTS = "prompts"
|
||||
SAMPLING = "sampling"
|
||||
LORAS = "loras"
|
||||
SIZE = "size"
|
||||
IMAGES = "images" # Added new category for image results
|
||||
IMAGES = "images"
|
||||
|
||||
# Collection of categories for iteration
|
||||
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES] # Added IMAGES to categories
|
||||
# Complete list of categories to track
|
||||
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]
|
||||
|
||||
@@ -5,7 +5,7 @@ from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
import asyncio
|
||||
import os
|
||||
from .utils import FlexibleOptionalInputType, any_type
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -32,48 +32,6 @@ class LoraManagerLoader:
|
||||
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING)
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
||||
FUNCTION = "load_loras"
|
||||
|
||||
async def get_lora_info(self, lora_name):
|
||||
"""Get the lora path and trigger words from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
file_path = item.get('file_path')
|
||||
if file_path:
|
||||
for root in config.loras_roots:
|
||||
root = root.replace(os.sep, '/')
|
||||
if file_path.startswith(root):
|
||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||
# Get trigger words from civitai metadata
|
||||
civitai = item.get('civitai', {})
|
||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||
return relative_path, trigger_words
|
||||
return lora_name, [] # Fallback if not found
|
||||
|
||||
def extract_lora_name(self, lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
basename = os.path.basename(lora_path)
|
||||
return os.path.splitext(basename)[0]
|
||||
|
||||
def _get_loras_list(self, kwargs):
|
||||
"""Helper to extract loras list from either old or new kwargs format"""
|
||||
if 'loras' not in kwargs:
|
||||
return []
|
||||
|
||||
loras_data = kwargs['loras']
|
||||
# Handle new format: {'loras': {'__value__': [...]}}
|
||||
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
||||
return loras_data['__value__']
|
||||
# Handle old format: {'loras': [...]}
|
||||
elif isinstance(loras_data, list):
|
||||
return loras_data
|
||||
# Unexpected format
|
||||
else:
|
||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||
return []
|
||||
|
||||
def load_loras(self, model, text, **kwargs):
|
||||
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
|
||||
@@ -89,14 +47,14 @@ class LoraManagerLoader:
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Extract lora name for trigger words lookup
|
||||
lora_name = self.extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
|
||||
all_trigger_words.extend(trigger_words)
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Then process loras from kwargs with support for both old and new formats
|
||||
loras_list = self._get_loras_list(kwargs)
|
||||
loras_list = get_loras_list(kwargs)
|
||||
for lora in loras_list:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
@@ -105,7 +63,7 @@ class LoraManagerLoader:
|
||||
strength = float(lora['strength'])
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
|
||||
# Apply the LoRA using the resolved path
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength)
|
||||
|
||||
@@ -3,7 +3,7 @@ from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
import asyncio
|
||||
import os
|
||||
from .utils import FlexibleOptionalInputType, any_type
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -29,48 +29,6 @@ class LoraStacker:
|
||||
RETURN_TYPES = ("LORA_STACK", IO.STRING, IO.STRING)
|
||||
RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras")
|
||||
FUNCTION = "stack_loras"
|
||||
|
||||
async def get_lora_info(self, lora_name):
|
||||
"""Get the lora path and trigger words from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
file_path = item.get('file_path')
|
||||
if file_path:
|
||||
for root in config.loras_roots:
|
||||
root = root.replace(os.sep, '/')
|
||||
if file_path.startswith(root):
|
||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||
# Get trigger words from civitai metadata
|
||||
civitai = item.get('civitai', {})
|
||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||
return relative_path, trigger_words
|
||||
return lora_name, [] # Fallback if not found
|
||||
|
||||
def extract_lora_name(self, lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
basename = os.path.basename(lora_path)
|
||||
return os.path.splitext(basename)[0]
|
||||
|
||||
def _get_loras_list(self, kwargs):
|
||||
"""Helper to extract loras list from either old or new kwargs format"""
|
||||
if 'loras' not in kwargs:
|
||||
return []
|
||||
|
||||
loras_data = kwargs['loras']
|
||||
# Handle new format: {'loras': {'__value__': [...]}}
|
||||
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
||||
return loras_data['__value__']
|
||||
# Handle old format: {'loras': [...]}
|
||||
elif isinstance(loras_data, list):
|
||||
return loras_data
|
||||
# Unexpected format
|
||||
else:
|
||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||
return []
|
||||
|
||||
def stack_loras(self, text, **kwargs):
|
||||
"""Stacks multiple LoRAs based on the kwargs input without loading them."""
|
||||
@@ -84,12 +42,12 @@ class LoraStacker:
|
||||
stack.extend(lora_stack)
|
||||
# Get trigger words from existing stack entries
|
||||
for lora_path, _, _ in lora_stack:
|
||||
lora_name = self.extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Process loras from kwargs with support for both old and new formats
|
||||
loras_list = self._get_loras_list(kwargs)
|
||||
loras_list = get_loras_list(kwargs)
|
||||
for lora in loras_list:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
@@ -99,7 +57,7 @@ class LoraStacker:
|
||||
clip_strength = model_strength # Using same strength for both as in the original loader
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
||||
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
|
||||
# Add to stack without loading
|
||||
# replace '/' with os.sep to avoid different OS path format
|
||||
|
||||
@@ -5,6 +5,7 @@ import re
|
||||
import numpy as np
|
||||
import folder_paths # type: ignore
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..services.checkpoint_scanner import CheckpointScanner
|
||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||
from ..metadata_collector import get_metadata
|
||||
from PIL import Image, PngImagePlugin
|
||||
@@ -53,18 +54,55 @@ class SaveImage:
|
||||
async def get_lora_hash(self, lora_name):
|
||||
"""Get the lora hash from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
# Use the new direct filename lookup method
|
||||
hash_value = scanner.get_hash_by_filename(lora_name)
|
||||
if hash_value:
|
||||
return hash_value
|
||||
|
||||
# Fallback to old method for compatibility
|
||||
cache = await scanner.get_cached_data()
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
return item.get('sha256')
|
||||
return None
|
||||
|
||||
async def get_checkpoint_hash(self, checkpoint_path):
|
||||
"""Get the checkpoint hash from cache"""
|
||||
scanner = await CheckpointScanner.get_instance()
|
||||
|
||||
if not checkpoint_path:
|
||||
return None
|
||||
|
||||
# Extract basename without extension
|
||||
checkpoint_name = os.path.basename(checkpoint_path)
|
||||
checkpoint_name = os.path.splitext(checkpoint_name)[0]
|
||||
|
||||
# Try direct filename lookup first
|
||||
hash_value = scanner.get_hash_by_filename(checkpoint_name)
|
||||
if hash_value:
|
||||
return hash_value
|
||||
|
||||
# Fallback to old method for compatibility
|
||||
cache = await scanner.get_cached_data()
|
||||
normalized_path = checkpoint_path.replace('\\', '/')
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == checkpoint_name and item.get('file_path').endswith(normalized_path):
|
||||
return item.get('sha256')
|
||||
|
||||
return None
|
||||
|
||||
async def format_metadata(self, metadata_dict):
|
||||
"""Format metadata in the requested format similar to userComment example"""
|
||||
if not metadata_dict:
|
||||
return ""
|
||||
|
||||
# Helper function to only add parameter if value is not None
|
||||
def add_param_if_not_none(param_list, label, value):
|
||||
if value is not None:
|
||||
param_list.append(f"{label}: {value}")
|
||||
|
||||
# Extract the prompt and negative prompt
|
||||
prompt = metadata_dict.get('prompt', '')
|
||||
negative_prompt = metadata_dict.get('negative_prompt', '')
|
||||
@@ -100,7 +138,11 @@ class SaveImage:
|
||||
|
||||
# Add standard parameters in the correct order
|
||||
if 'steps' in metadata_dict:
|
||||
params.append(f"Steps: {metadata_dict.get('steps')}")
|
||||
add_param_if_not_none(params, "Steps", metadata_dict.get('steps'))
|
||||
|
||||
# Combine sampler and scheduler information
|
||||
sampler_name = None
|
||||
scheduler_name = None
|
||||
|
||||
if 'sampler' in metadata_dict:
|
||||
sampler = metadata_dict.get('sampler')
|
||||
@@ -123,7 +165,6 @@ class SaveImage:
|
||||
'ddim': 'DDIM'
|
||||
}
|
||||
sampler_name = sampler_mapping.get(sampler, sampler)
|
||||
params.append(f"Sampler: {sampler_name}")
|
||||
|
||||
if 'scheduler' in metadata_dict:
|
||||
scheduler = metadata_dict.get('scheduler')
|
||||
@@ -135,38 +176,48 @@ class SaveImage:
|
||||
'sgm_quadratic': 'SGM Quadratic'
|
||||
}
|
||||
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
|
||||
params.append(f"Schedule type: {scheduler_name}")
|
||||
|
||||
# CFG scale (cfg_scale in metadata_dict)
|
||||
if 'cfg_scale' in metadata_dict:
|
||||
params.append(f"CFG scale: {metadata_dict.get('cfg_scale')}")
|
||||
# Add combined sampler and scheduler information
|
||||
if sampler_name:
|
||||
if scheduler_name:
|
||||
params.append(f"Sampler: {sampler_name} {scheduler_name}")
|
||||
else:
|
||||
params.append(f"Sampler: {sampler_name}")
|
||||
|
||||
# CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg)
|
||||
if 'guidance' in metadata_dict:
|
||||
add_param_if_not_none(params, "CFG scale", metadata_dict.get('guidance'))
|
||||
elif 'cfg_scale' in metadata_dict:
|
||||
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg_scale'))
|
||||
elif 'cfg' in metadata_dict:
|
||||
params.append(f"CFG scale: {metadata_dict.get('cfg')}")
|
||||
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg'))
|
||||
|
||||
# Seed
|
||||
if 'seed' in metadata_dict:
|
||||
params.append(f"Seed: {metadata_dict.get('seed')}")
|
||||
add_param_if_not_none(params, "Seed", metadata_dict.get('seed'))
|
||||
|
||||
# Size
|
||||
if 'size' in metadata_dict:
|
||||
params.append(f"Size: {metadata_dict.get('size')}")
|
||||
add_param_if_not_none(params, "Size", metadata_dict.get('size'))
|
||||
|
||||
# Model info
|
||||
if 'checkpoint' in metadata_dict:
|
||||
# Ensure checkpoint is a string before processing
|
||||
checkpoint = metadata_dict.get('checkpoint')
|
||||
if checkpoint is not None:
|
||||
# Handle both string and other types safely
|
||||
if isinstance(checkpoint, str):
|
||||
# Extract basename without path
|
||||
checkpoint = os.path.basename(checkpoint)
|
||||
# Remove extension if present
|
||||
checkpoint = os.path.splitext(checkpoint)[0]
|
||||
else:
|
||||
# Convert non-string to string
|
||||
checkpoint = str(checkpoint)
|
||||
# Get model hash
|
||||
model_hash = await self.get_checkpoint_hash(checkpoint)
|
||||
|
||||
params.append(f"Model: {checkpoint}")
|
||||
# Extract basename without path
|
||||
checkpoint_name = os.path.basename(checkpoint)
|
||||
# Remove extension if present
|
||||
checkpoint_name = os.path.splitext(checkpoint_name)[0]
|
||||
|
||||
# Add model hash if available
|
||||
if model_hash:
|
||||
params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
|
||||
else:
|
||||
params.append(f"Model: {checkpoint_name}")
|
||||
|
||||
# Add LoRA hashes if available
|
||||
if lora_hashes:
|
||||
@@ -284,7 +335,7 @@ class SaveImage:
|
||||
if add_counter_to_filename:
|
||||
# Use counter + i to ensure unique filenames for all images in batch
|
||||
current_counter = counter + i
|
||||
base_filename += f"_{current_counter:05}"
|
||||
base_filename += f"_{current_counter:05}_"
|
||||
|
||||
# Set file extension and prepare saving parameters
|
||||
if file_format == "png":
|
||||
|
||||
@@ -47,10 +47,10 @@ class TriggerWordToggle:
|
||||
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
|
||||
|
||||
# Send trigger words to frontend
|
||||
PromptServer.instance.send_sync("trigger_word_update", {
|
||||
"id": id,
|
||||
"message": trigger_words
|
||||
})
|
||||
# PromptServer.instance.send_sync("trigger_word_update", {
|
||||
# "id": id,
|
||||
# "message": trigger_words
|
||||
# })
|
||||
|
||||
filtered_triggers = trigger_words
|
||||
|
||||
|
||||
@@ -30,4 +30,55 @@ class FlexibleOptionalInputType(dict):
|
||||
return True
|
||||
|
||||
|
||||
any_type = AnyType("*")
|
||||
any_type = AnyType("*")
|
||||
|
||||
# Common methods extracted from lora_loader.py and lora_stacker.py
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_lora_info(lora_name):
|
||||
"""Get the lora path and trigger words from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
file_path = item.get('file_path')
|
||||
if file_path:
|
||||
for root in config.loras_roots:
|
||||
root = root.replace(os.sep, '/')
|
||||
if file_path.startswith(root):
|
||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||
# Get trigger words from civitai metadata
|
||||
civitai = item.get('civitai', {})
|
||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||
return relative_path, trigger_words
|
||||
return lora_name, [] # Fallback if not found
|
||||
|
||||
def extract_lora_name(lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
basename = os.path.basename(lora_path)
|
||||
return os.path.splitext(basename)[0]
|
||||
|
||||
def get_loras_list(kwargs):
|
||||
"""Helper to extract loras list from either old or new kwargs format"""
|
||||
if 'loras' not in kwargs:
|
||||
return []
|
||||
|
||||
loras_data = kwargs['loras']
|
||||
# Handle new format: {'loras': {'__value__': [...]}}
|
||||
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
||||
return loras_data['__value__']
|
||||
# Handle old format: {'loras': [...]}
|
||||
elif isinstance(loras_data, list):
|
||||
return loras_data
|
||||
# Unexpected format
|
||||
else:
|
||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||
return []
|
||||
@@ -3,8 +3,10 @@ import json
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from typing import Dict
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..nodes.utils import get_lora_info
|
||||
|
||||
from ..config import config
|
||||
from ..services.websocket_manager import ws_manager
|
||||
@@ -64,6 +66,9 @@ class ApiRoutes:
|
||||
app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL
|
||||
app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files
|
||||
app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files
|
||||
|
||||
# Add the new trigger words route
|
||||
app.router.add_post('/loramanager/get_trigger_words', routes.get_trigger_words)
|
||||
|
||||
# Add update check routes
|
||||
UpdateRoutes.setup_routes(app)
|
||||
@@ -1021,4 +1026,35 @@ class ApiRoutes:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for specified LoRA models"""
|
||||
try:
|
||||
json_data = await request.json()
|
||||
lora_names = json_data.get("lora_names", [])
|
||||
node_ids = json_data.get("node_ids", [])
|
||||
|
||||
all_trigger_words = []
|
||||
for lora_name in lora_names:
|
||||
_, trigger_words = await get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Format the trigger words
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Send update to all connected trigger word toggle nodes
|
||||
for node_id in node_ids:
|
||||
PromptServer.instance.send_sync("trigger_word_update", {
|
||||
"id": node_id,
|
||||
"message": trigger_words_text
|
||||
})
|
||||
|
||||
return web.json_response({"success": True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting trigger words: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
69
py/routes/usage_stats_routes.py
Normal file
69
py/routes/usage_stats_routes.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from ..utils.usage_stats import UsageStats
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class UsageStatsRoutes:
|
||||
"""Routes for handling usage statistics updates"""
|
||||
|
||||
@staticmethod
|
||||
def setup_routes(app):
|
||||
"""Register usage stats routes"""
|
||||
app.router.add_post('/loras/api/update-usage-stats', UsageStatsRoutes.update_usage_stats)
|
||||
app.router.add_get('/loras/api/get-usage-stats', UsageStatsRoutes.get_usage_stats)
|
||||
|
||||
@staticmethod
|
||||
async def update_usage_stats(request):
|
||||
"""
|
||||
Update usage statistics based on a prompt_id
|
||||
|
||||
Expects a JSON body with:
|
||||
{
|
||||
"prompt_id": "string"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
# Parse the request body
|
||||
data = await request.json()
|
||||
prompt_id = data.get('prompt_id')
|
||||
|
||||
if not prompt_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing prompt_id'
|
||||
}, status=400)
|
||||
|
||||
# Call the UsageStats to process this prompt_id synchronously
|
||||
usage_stats = UsageStats()
|
||||
await usage_stats.process_execution(prompt_id)
|
||||
|
||||
return web.json_response({
|
||||
'success': True
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update usage stats: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def get_usage_stats(request):
|
||||
"""Get current usage statistics"""
|
||||
try:
|
||||
usage_stats = UsageStats()
|
||||
stats = await usage_stats.get_stats()
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': stats
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get usage stats: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
26
py/server_routes.py
Normal file
26
py/server_routes.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from aiohttp import web
|
||||
from server import PromptServer
|
||||
from .nodes.utils import get_lora_info
|
||||
|
||||
@PromptServer.instance.routes.post("/loramanager/get_trigger_words")
|
||||
async def get_trigger_words(request):
|
||||
json_data = await request.json()
|
||||
lora_names = json_data.get("lora_names", [])
|
||||
node_ids = json_data.get("node_ids", [])
|
||||
|
||||
all_trigger_words = []
|
||||
for lora_name in lora_names:
|
||||
_, trigger_words = await get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Format the trigger words
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Send update to all connected trigger word toggle nodes
|
||||
for node_id in node_ids:
|
||||
PromptServer.instance.send_sync("trigger_word_update", {
|
||||
"id": node_id,
|
||||
"message": trigger_words_text
|
||||
})
|
||||
|
||||
return web.json_response({"success": True})
|
||||
@@ -9,7 +9,7 @@ from typing import List, Dict, Optional, Set
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .lora_hash_index import LoraHashIndex
|
||||
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
|
||||
from .settings_manager import settings
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.utils import fuzzy_match
|
||||
@@ -35,12 +35,12 @@ class LoraScanner(ModelScanner):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors'}
|
||||
|
||||
# Initialize parent class
|
||||
# Initialize parent class with ModelHashIndex
|
||||
super().__init__(
|
||||
model_type="lora",
|
||||
model_class=LoraMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=LoraHashIndex()
|
||||
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||
)
|
||||
self._initialized = True
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from typing import Dict, Optional, Set
|
||||
import os
|
||||
|
||||
class ModelHashIndex:
|
||||
"""Index for looking up models by hash or path"""
|
||||
|
||||
def __init__(self):
|
||||
self._hash_to_path: Dict[str, str] = {}
|
||||
self._path_to_hash: Dict[str, str] = {}
|
||||
self._filename_to_hash: Dict[str, str] = {} # Changed from path_to_hash to filename_to_hash
|
||||
|
||||
def add_entry(self, sha256: str, file_path: str) -> None:
|
||||
"""Add or update hash index entry"""
|
||||
@@ -15,37 +16,47 @@ class ModelHashIndex:
|
||||
# Ensure hash is lowercase for consistency
|
||||
sha256 = sha256.lower()
|
||||
|
||||
# Extract filename without extension
|
||||
filename = self._get_filename_from_path(file_path)
|
||||
|
||||
# Remove old path mapping if hash exists
|
||||
if sha256 in self._hash_to_path:
|
||||
old_path = self._hash_to_path[sha256]
|
||||
if old_path in self._path_to_hash:
|
||||
del self._path_to_hash[old_path]
|
||||
old_filename = self._get_filename_from_path(old_path)
|
||||
if old_filename in self._filename_to_hash:
|
||||
del self._filename_to_hash[old_filename]
|
||||
|
||||
# Remove old hash mapping if path exists
|
||||
if file_path in self._path_to_hash:
|
||||
old_hash = self._path_to_hash[file_path]
|
||||
# Remove old hash mapping if filename exists
|
||||
if filename in self._filename_to_hash:
|
||||
old_hash = self._filename_to_hash[filename]
|
||||
if old_hash in self._hash_to_path:
|
||||
del self._hash_to_path[old_hash]
|
||||
|
||||
# Add new mappings
|
||||
self._hash_to_path[sha256] = file_path
|
||||
self._path_to_hash[file_path] = sha256
|
||||
self._filename_to_hash[filename] = sha256
|
||||
|
||||
def _get_filename_from_path(self, file_path: str) -> str:
|
||||
"""Extract filename without extension from path"""
|
||||
return os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
def remove_by_path(self, file_path: str) -> None:
|
||||
"""Remove entry by file path"""
|
||||
if file_path in self._path_to_hash:
|
||||
hash_val = self._path_to_hash[file_path]
|
||||
filename = self._get_filename_from_path(file_path)
|
||||
if filename in self._filename_to_hash:
|
||||
hash_val = self._filename_to_hash[filename]
|
||||
if hash_val in self._hash_to_path:
|
||||
del self._hash_to_path[hash_val]
|
||||
del self._path_to_hash[file_path]
|
||||
del self._filename_to_hash[filename]
|
||||
|
||||
def remove_by_hash(self, sha256: str) -> None:
|
||||
"""Remove entry by hash"""
|
||||
sha256 = sha256.lower()
|
||||
if sha256 in self._hash_to_path:
|
||||
path = self._hash_to_path[sha256]
|
||||
if path in self._path_to_hash:
|
||||
del self._path_to_hash[path]
|
||||
filename = self._get_filename_from_path(path)
|
||||
if filename in self._filename_to_hash:
|
||||
del self._filename_to_hash[filename]
|
||||
del self._hash_to_path[sha256]
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
@@ -58,20 +69,27 @@ class ModelHashIndex:
|
||||
|
||||
def get_hash(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a file path"""
|
||||
return self._path_to_hash.get(file_path)
|
||||
filename = self._get_filename_from_path(file_path)
|
||||
return self._filename_to_hash.get(filename)
|
||||
|
||||
def get_hash_by_filename(self, filename: str) -> Optional[str]:
|
||||
"""Get hash for a filename without extension"""
|
||||
# Strip extension if present to make the function more flexible
|
||||
filename = os.path.splitext(filename)[0]
|
||||
return self._filename_to_hash.get(filename)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all entries"""
|
||||
self._hash_to_path.clear()
|
||||
self._path_to_hash.clear()
|
||||
self._filename_to_hash.clear()
|
||||
|
||||
def get_all_hashes(self) -> Set[str]:
|
||||
"""Get all hashes in the index"""
|
||||
return set(self._hash_to_path.keys())
|
||||
|
||||
def get_all_paths(self) -> Set[str]:
|
||||
"""Get all file paths in the index"""
|
||||
return set(self._path_to_hash.keys())
|
||||
def get_all_filenames(self) -> Set[str]:
|
||||
"""Get all filenames in the index"""
|
||||
return set(self._filename_to_hash.keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Get number of entries"""
|
||||
|
||||
@@ -292,7 +292,7 @@ class ModelScanner:
|
||||
)
|
||||
|
||||
# If force refresh is requested, initialize the cache directly
|
||||
if force_refresh:
|
||||
if (force_refresh):
|
||||
if self._cache is None:
|
||||
# For initial creation, do a full initialization
|
||||
await self._initialize_cache()
|
||||
@@ -553,9 +553,36 @@ class ModelScanner:
|
||||
logger.debug(f"Created metadata from .civitai.info for {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
|
||||
else:
|
||||
# Check if metadata exists but civitai field is empty - try to restore from civitai.info
|
||||
if metadata.civitai is None or metadata.civitai == {}:
|
||||
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
|
||||
if os.path.exists(civitai_info_path):
|
||||
try:
|
||||
with open(civitai_info_path, 'r', encoding='utf-8') as f:
|
||||
version_info = json.load(f)
|
||||
|
||||
logger.debug(f"Restoring missing civitai data from .civitai.info for {file_path}")
|
||||
metadata.civitai = version_info
|
||||
|
||||
# Ensure tags are also updated if they're missing
|
||||
if (not metadata.tags or len(metadata.tags) == 0) and 'model' in version_info:
|
||||
if 'tags' in version_info['model']:
|
||||
metadata.tags = version_info['model']['tags']
|
||||
|
||||
# Also restore description if missing
|
||||
if (not metadata.modelDescription or metadata.modelDescription == "") and 'model' in version_info:
|
||||
if 'description' in version_info['model']:
|
||||
metadata.modelDescription = version_info['model']['description']
|
||||
|
||||
# Save the updated metadata
|
||||
await save_metadata(file_path, metadata)
|
||||
logger.debug(f"Updated metadata with civitai info for {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error restoring civitai data from .civitai.info for {file_path}: {e}")
|
||||
|
||||
if metadata is None:
|
||||
metadata = await self._get_file_info(file_path)
|
||||
if metadata is None:
|
||||
metadata = await self._get_file_info(file_path)
|
||||
|
||||
model_data = metadata.to_dict()
|
||||
|
||||
@@ -805,6 +832,10 @@ class ModelScanner:
|
||||
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a model by its file path"""
|
||||
return self._hash_index.get_hash(file_path)
|
||||
|
||||
def get_hash_by_filename(self, filename: str) -> Optional[str]:
|
||||
"""Get hash for a model by its filename without path"""
|
||||
return self._hash_index.get_hash_by_filename(filename)
|
||||
|
||||
# TODO: Adjust this method to use metadata instead of finding the file
|
||||
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
|
||||
|
||||
@@ -21,6 +21,7 @@ class BaseModelMetadata:
|
||||
civitai: Optional[Dict] = None # Civitai API data if available
|
||||
tags: List[str] = None # Model tags
|
||||
modelDescription: str = "" # Full model description
|
||||
civitai_deleted: bool = False # Whether deleted from Civitai
|
||||
|
||||
def __post_init__(self):
|
||||
# Initialize empty lists to avoid mutable default parameter issue
|
||||
@@ -64,6 +65,15 @@ class LoraMetadata(BaseModelMetadata):
|
||||
file_name = file_info['name']
|
||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||
|
||||
# Extract tags and description if available
|
||||
tags = []
|
||||
description = ""
|
||||
if 'model' in version_info:
|
||||
if 'tags' in version_info['model']:
|
||||
tags = version_info['model']['tags']
|
||||
if 'description' in version_info['model']:
|
||||
description = version_info['model']['description']
|
||||
|
||||
return cls(
|
||||
file_name=os.path.splitext(file_name)[0],
|
||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||
@@ -75,7 +85,9 @@ class LoraMetadata(BaseModelMetadata):
|
||||
preview_url=None, # Will be updated after preview download
|
||||
preview_nsfw_level=0, # Will be updated after preview download
|
||||
from_civitai=True,
|
||||
civitai=version_info
|
||||
civitai=version_info,
|
||||
tags=tags,
|
||||
modelDescription=description
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@@ -90,6 +102,15 @@ class CheckpointMetadata(BaseModelMetadata):
|
||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||
model_type = version_info.get('type', 'checkpoint')
|
||||
|
||||
# Extract tags and description if available
|
||||
tags = []
|
||||
description = ""
|
||||
if 'model' in version_info:
|
||||
if 'tags' in version_info['model']:
|
||||
tags = version_info['model']['tags']
|
||||
if 'description' in version_info['model']:
|
||||
description = version_info['model']['description']
|
||||
|
||||
return cls(
|
||||
file_name=os.path.splitext(file_name)[0],
|
||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||
@@ -102,6 +123,8 @@ class CheckpointMetadata(BaseModelMetadata):
|
||||
preview_nsfw_level=0,
|
||||
from_civitai=True,
|
||||
civitai=version_info,
|
||||
model_type=model_type
|
||||
model_type=model_type,
|
||||
tags=tags,
|
||||
modelDescription=description
|
||||
)
|
||||
|
||||
|
||||
@@ -45,14 +45,14 @@ class RecipeMetadataParser(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info: Dict[str, Any],
|
||||
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info_tuple: Tuple[Dict[str, Any], Optional[str]],
|
||||
recipe_scanner=None, base_model_counts=None, hash_value=None) -> Dict[str, Any]:
|
||||
"""
|
||||
Populate a lora entry with information from Civitai API response
|
||||
|
||||
Args:
|
||||
lora_entry: The lora entry to populate
|
||||
civitai_info: The response from Civitai API
|
||||
civitai_info_tuple: The response tuple from Civitai API (data, error_msg)
|
||||
recipe_scanner: Optional recipe scanner for local file lookup
|
||||
base_model_counts: Optional dict to track base model counts
|
||||
hash_value: Optional hash value to use if not available in civitai_info
|
||||
@@ -61,6 +61,9 @@ class RecipeMetadataParser(ABC):
|
||||
The populated lora_entry dict
|
||||
"""
|
||||
try:
|
||||
# Unpack the tuple to get the actual data
|
||||
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||
|
||||
if civitai_info and civitai_info.get("error") != "Model not found":
|
||||
# Check if this is an early access lora
|
||||
if civitai_info.get('earlyAccessEndsAt'):
|
||||
@@ -241,11 +244,11 @@ class RecipeFormatParser(RecipeMetadataParser):
|
||||
# Try to get additional info from Civitai if we have a model version ID
|
||||
if lora.get('modelVersionId') and civitai_client:
|
||||
try:
|
||||
civitai_info = await civitai_client.get_model_version_info(lora['modelVersionId'])
|
||||
civitai_info_tuple = await civitai_client.get_model_version_info(lora['modelVersionId'])
|
||||
# Populate lora entry with Civitai info
|
||||
lora_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
civitai_info_tuple,
|
||||
recipe_scanner,
|
||||
None, # No need to track base model counts
|
||||
lora['hash']
|
||||
@@ -336,12 +339,13 @@ class StandardMetadataParser(RecipeMetadataParser):
|
||||
# Get additional info from Civitai if client is available
|
||||
if civitai_client:
|
||||
try:
|
||||
civitai_info = await civitai_client.get_model_version_info(model_version_id)
|
||||
civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id)
|
||||
# Populate lora entry with Civitai info
|
||||
lora_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner
|
||||
civitai_info_tuple,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA: {e}")
|
||||
@@ -621,11 +625,11 @@ class ComfyMetadataParser(RecipeMetadataParser):
|
||||
# Get additional info from Civitai if client is available
|
||||
if civitai_client:
|
||||
try:
|
||||
civitai_info = await civitai_client.get_model_version_info(model_version_id)
|
||||
civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id)
|
||||
# Populate lora entry with Civitai info
|
||||
lora_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
civitai_info_tuple,
|
||||
recipe_scanner
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -660,7 +664,8 @@ class ComfyMetadataParser(RecipeMetadataParser):
|
||||
# Get additional checkpoint info from Civitai
|
||||
if civitai_client:
|
||||
try:
|
||||
civitai_info = await civitai_client.get_model_version_info(checkpoint_version_id)
|
||||
civitai_info_tuple = await civitai_client.get_model_version_info(checkpoint_version_id)
|
||||
civitai_info, _ = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||
# Populate checkpoint with Civitai info
|
||||
checkpoint = await self.populate_checkpoint_from_civitai(checkpoint, civitai_info)
|
||||
except Exception as e:
|
||||
|
||||
267
py/utils/usage_stats.py
Normal file
267
py/utils/usage_stats.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Dict, Set
|
||||
|
||||
from ..config import config
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..metadata_collector.metadata_registry import MetadataRegistry
|
||||
from ..metadata_collector.constants import MODELS, LORAS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class UsageStats:
|
||||
"""Track usage statistics for models and save to JSON"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock() # For thread safety
|
||||
|
||||
# Default stats file name
|
||||
STATS_FILENAME = "lora_manager_stats.json"
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# Initialize stats storage
|
||||
self.stats = {
|
||||
"checkpoints": {}, # sha256 -> count
|
||||
"loras": {}, # sha256 -> count
|
||||
"total_executions": 0,
|
||||
"last_save_time": 0
|
||||
}
|
||||
|
||||
# Queue for prompt_ids to process
|
||||
self.pending_prompt_ids = set()
|
||||
|
||||
# Load existing stats if available
|
||||
self._stats_file_path = self._get_stats_file_path()
|
||||
self._load_stats()
|
||||
|
||||
# Save interval in seconds
|
||||
self.save_interval = 90 # 1.5 minutes
|
||||
|
||||
# Start background task to process queued prompt_ids
|
||||
self._bg_task = asyncio.create_task(self._background_processor())
|
||||
|
||||
self._initialized = True
|
||||
logger.info("Usage statistics tracker initialized")
|
||||
|
||||
def _get_stats_file_path(self) -> str:
|
||||
"""Get the path to the stats JSON file"""
|
||||
if not config.loras_roots or len(config.loras_roots) == 0:
|
||||
# Fallback to temporary directory if no lora roots
|
||||
return os.path.join(config.temp_directory, self.STATS_FILENAME)
|
||||
|
||||
# Use the first lora root
|
||||
return os.path.join(config.loras_roots[0], self.STATS_FILENAME)
|
||||
|
||||
def _load_stats(self):
|
||||
"""Load existing statistics from file"""
|
||||
try:
|
||||
if os.path.exists(self._stats_file_path):
|
||||
with open(self._stats_file_path, 'r', encoding='utf-8') as f:
|
||||
loaded_stats = json.load(f)
|
||||
|
||||
# Update our stats with loaded data
|
||||
if isinstance(loaded_stats, dict):
|
||||
# Update individual sections to maintain structure
|
||||
if "checkpoints" in loaded_stats and isinstance(loaded_stats["checkpoints"], dict):
|
||||
self.stats["checkpoints"] = loaded_stats["checkpoints"]
|
||||
|
||||
if "loras" in loaded_stats and isinstance(loaded_stats["loras"], dict):
|
||||
self.stats["loras"] = loaded_stats["loras"]
|
||||
|
||||
if "total_executions" in loaded_stats:
|
||||
self.stats["total_executions"] = loaded_stats["total_executions"]
|
||||
|
||||
logger.info(f"Loaded usage statistics from {self._stats_file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading usage statistics: {e}")
|
||||
|
||||
async def save_stats(self, force=False):
|
||||
"""Save statistics to file"""
|
||||
try:
|
||||
# Only save if it's been at least save_interval since last save or force is True
|
||||
current_time = time.time()
|
||||
if not force and (current_time - self.stats.get("last_save_time", 0)) < self.save_interval:
|
||||
return False
|
||||
|
||||
# Use a lock to prevent concurrent writes
|
||||
async with self._lock:
|
||||
# Update last save time
|
||||
self.stats["last_save_time"] = current_time
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(self._stats_file_path), exist_ok=True)
|
||||
|
||||
# Write to a temporary file first, then move it to avoid corruption
|
||||
temp_path = f"{self._stats_file_path}.tmp"
|
||||
with open(temp_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(self.stats, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Replace the old file with the new one
|
||||
os.replace(temp_path, self._stats_file_path)
|
||||
|
||||
logger.debug(f"Saved usage statistics to {self._stats_file_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving usage statistics: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def register_execution(self, prompt_id):
|
||||
"""Register a completed execution by prompt_id for later processing"""
|
||||
if prompt_id:
|
||||
self.pending_prompt_ids.add(prompt_id)
|
||||
|
||||
async def _background_processor(self):
|
||||
"""Background task to process queued prompt_ids"""
|
||||
try:
|
||||
while True:
|
||||
# Wait a short interval before checking for new prompt_ids
|
||||
await asyncio.sleep(5) # Check every 5 seconds
|
||||
|
||||
# Process any pending prompt_ids
|
||||
if self.pending_prompt_ids:
|
||||
async with self._lock:
|
||||
# Get a copy of the set and clear original
|
||||
prompt_ids = self.pending_prompt_ids.copy()
|
||||
self.pending_prompt_ids.clear()
|
||||
|
||||
# Process each prompt_id
|
||||
registry = MetadataRegistry()
|
||||
for prompt_id in prompt_ids:
|
||||
try:
|
||||
metadata = registry.get_metadata(prompt_id)
|
||||
await self._process_metadata(metadata)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing prompt_id {prompt_id}: {e}")
|
||||
|
||||
# Periodically save stats
|
||||
await self.save_stats()
|
||||
except asyncio.CancelledError:
|
||||
# Task was cancelled, clean up
|
||||
await self.save_stats(force=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in background processing task: {e}", exc_info=True)
|
||||
# Restart the task after a delay if it fails
|
||||
asyncio.create_task(self._restart_background_task())
|
||||
|
||||
async def _restart_background_task(self):
|
||||
"""Restart the background task after a delay"""
|
||||
await asyncio.sleep(30) # Wait 30 seconds before restarting
|
||||
self._bg_task = asyncio.create_task(self._background_processor())
|
||||
|
||||
async def _process_metadata(self, metadata):
|
||||
"""Process metadata from an execution"""
|
||||
if not metadata or not isinstance(metadata, dict):
|
||||
return
|
||||
|
||||
# Increment total executions count
|
||||
self.stats["total_executions"] += 1
|
||||
|
||||
# Process checkpoints
|
||||
if MODELS in metadata and isinstance(metadata[MODELS], dict):
|
||||
await self._process_checkpoints(metadata[MODELS])
|
||||
|
||||
# Process loras
|
||||
if LORAS in metadata and isinstance(metadata[LORAS], dict):
|
||||
await self._process_loras(metadata[LORAS])
|
||||
|
||||
async def _process_checkpoints(self, models_data):
|
||||
"""Process checkpoint models from metadata"""
|
||||
try:
|
||||
# Get checkpoint scanner service
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
if not checkpoint_scanner:
|
||||
logger.warning("Checkpoint scanner not available for usage tracking")
|
||||
return
|
||||
|
||||
for node_id, model_info in models_data.items():
|
||||
if not isinstance(model_info, dict):
|
||||
continue
|
||||
|
||||
# Check if this is a checkpoint model
|
||||
model_type = model_info.get("type")
|
||||
if model_type == "checkpoint":
|
||||
model_name = model_info.get("name")
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
# Clean up filename (remove extension if present)
|
||||
model_filename = os.path.splitext(os.path.basename(model_name))[0]
|
||||
|
||||
# Get hash for this checkpoint
|
||||
model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
|
||||
if model_hash:
|
||||
# Update stats for this checkpoint
|
||||
self.stats["checkpoints"][model_hash] = self.stats["checkpoints"].get(model_hash, 0) + 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)
|
||||
|
||||
async def _process_loras(self, loras_data):
|
||||
"""Process LoRA models from metadata"""
|
||||
try:
|
||||
# Get LoRA scanner service
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
if not lora_scanner:
|
||||
logger.warning("LoRA scanner not available for usage tracking")
|
||||
return
|
||||
|
||||
for node_id, lora_info in loras_data.items():
|
||||
if not isinstance(lora_info, dict):
|
||||
continue
|
||||
|
||||
# Get the list of LoRAs from standardized format
|
||||
lora_list = lora_info.get("lora_list", [])
|
||||
for lora in lora_list:
|
||||
if not isinstance(lora, dict):
|
||||
continue
|
||||
|
||||
lora_name = lora.get("name")
|
||||
if not lora_name:
|
||||
continue
|
||||
|
||||
# Get hash for this LoRA
|
||||
lora_hash = lora_scanner.get_hash_by_filename(lora_name)
|
||||
if lora_hash:
|
||||
# Update stats for this LoRA
|
||||
self.stats["loras"][lora_hash] = self.stats["loras"].get(lora_hash, 0) + 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing LoRA usage: {e}", exc_info=True)
|
||||
|
||||
async def get_stats(self):
|
||||
"""Get current usage statistics"""
|
||||
return self.stats
|
||||
|
||||
async def get_model_usage_count(self, model_type, sha256):
|
||||
"""Get usage count for a specific model by hash"""
|
||||
if model_type == "checkpoint":
|
||||
return self.stats["checkpoints"].get(sha256, 0)
|
||||
elif model_type == "lora":
|
||||
return self.stats["loras"].get(sha256, 0)
|
||||
return 0
|
||||
|
||||
async def process_execution(self, prompt_id):
|
||||
"""Process a prompt execution immediately (synchronous approach)"""
|
||||
if not prompt_id:
|
||||
return
|
||||
|
||||
try:
|
||||
# Process metadata for this prompt_id
|
||||
registry = MetadataRegistry()
|
||||
metadata = registry.get_metadata(prompt_id)
|
||||
if metadata:
|
||||
await self._process_metadata(metadata)
|
||||
# Save stats if needed
|
||||
await self.save_stats()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing prompt_id {prompt_id}: {e}", exc_info=True)
|
||||
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "comfyui-lora-manager"
|
||||
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration."
|
||||
version = "0.8.7"
|
||||
version = "0.8.8"
|
||||
license = {file = "LICENSE"}
|
||||
dependencies = [
|
||||
"aiohttp",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { showToast } from '../utils/uiHelpers.js';
|
||||
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
|
||||
import { state } from '../state/index.js';
|
||||
import { showCheckpointModal } from './checkpointModal/index.js';
|
||||
import { NSFW_LEVELS } from '../utils/constants.js';
|
||||
@@ -204,21 +204,7 @@ export function createCheckpointCard(checkpoint) {
|
||||
const checkpointName = card.dataset.file_name;
|
||||
|
||||
try {
|
||||
// Modern clipboard API
|
||||
if (navigator.clipboard && window.isSecureContext) {
|
||||
await navigator.clipboard.writeText(checkpointName);
|
||||
} else {
|
||||
// Fallback for older browsers
|
||||
const textarea = document.createElement('textarea');
|
||||
textarea.value = checkpointName;
|
||||
textarea.style.position = 'absolute';
|
||||
textarea.style.left = '-99999px';
|
||||
document.body.appendChild(textarea);
|
||||
textarea.select();
|
||||
document.execCommand('copy');
|
||||
document.body.removeChild(textarea);
|
||||
}
|
||||
showToast('Checkpoint name copied', 'success');
|
||||
await copyToClipboard(checkpointName, 'Checkpoint name copied');
|
||||
} catch (err) {
|
||||
console.error('Copy failed:', err);
|
||||
showToast('Copy failed', 'error');
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { showToast, openCivitai } from '../utils/uiHelpers.js';
|
||||
import { showToast, openCivitai, copyToClipboard } from '../utils/uiHelpers.js';
|
||||
import { state } from '../state/index.js';
|
||||
import { showLoraModal } from './loraModal/index.js';
|
||||
import { bulkManager } from '../managers/BulkManager.js';
|
||||
@@ -205,26 +205,7 @@ export function createLoraCard(lora) {
|
||||
const strength = usageTips.strength || 1;
|
||||
const loraSyntax = `<lora:${card.dataset.file_name}:${strength}>`;
|
||||
|
||||
try {
|
||||
// Modern clipboard API
|
||||
if (navigator.clipboard && window.isSecureContext) {
|
||||
await navigator.clipboard.writeText(loraSyntax);
|
||||
} else {
|
||||
// Fallback for older browsers
|
||||
const textarea = document.createElement('textarea');
|
||||
textarea.value = loraSyntax;
|
||||
textarea.style.position = 'absolute';
|
||||
textarea.style.left = '-99999px';
|
||||
document.body.appendChild(textarea);
|
||||
textarea.select();
|
||||
document.execCommand('copy');
|
||||
document.body.removeChild(textarea);
|
||||
}
|
||||
showToast('LoRA syntax copied', 'success');
|
||||
} catch (err) {
|
||||
console.error('Copy failed:', err);
|
||||
showToast('Copy failed', 'error');
|
||||
}
|
||||
await copyToClipboard(loraSyntax, 'LoRA syntax copied');
|
||||
});
|
||||
|
||||
// Civitai button click event
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// Recipe Card Component
|
||||
import { showToast } from '../utils/uiHelpers.js';
|
||||
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
|
||||
import { modalManager } from '../managers/ModalManager.js';
|
||||
|
||||
class RecipeCard {
|
||||
@@ -109,14 +109,11 @@ class RecipeCard {
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.success && data.syntax) {
|
||||
return navigator.clipboard.writeText(data.syntax);
|
||||
return copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
|
||||
} else {
|
||||
throw new Error(data.error || 'No syntax returned');
|
||||
}
|
||||
})
|
||||
.then(() => {
|
||||
showToast('Recipe syntax copied to clipboard', 'success');
|
||||
})
|
||||
.catch(err => {
|
||||
console.error('Failed to copy: ', err);
|
||||
showToast('Failed to copy recipe syntax', 'error');
|
||||
@@ -279,4 +276,4 @@ class RecipeCard {
|
||||
}
|
||||
}
|
||||
|
||||
export { RecipeCard };
|
||||
export { RecipeCard };
|
||||
@@ -1,5 +1,5 @@
|
||||
// Recipe Modal Component
|
||||
import { showToast } from '../utils/uiHelpers.js';
|
||||
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
|
||||
import { state } from '../state/index.js';
|
||||
import { setSessionItem, removeSessionItem } from '../utils/storageHelpers.js';
|
||||
|
||||
@@ -747,9 +747,8 @@ class RecipeModal {
|
||||
const data = await response.json();
|
||||
|
||||
if (data.success && data.syntax) {
|
||||
// Copy to clipboard
|
||||
await navigator.clipboard.writeText(data.syntax);
|
||||
showToast('Recipe syntax copied to clipboard', 'success');
|
||||
// Use the centralized copyToClipboard utility function
|
||||
await copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
|
||||
} else {
|
||||
throw new Error(data.error || 'No syntax returned from server');
|
||||
}
|
||||
@@ -761,12 +760,7 @@ class RecipeModal {
|
||||
|
||||
// Helper method to copy text to clipboard
|
||||
copyToClipboard(text, successMessage) {
|
||||
navigator.clipboard.writeText(text).then(() => {
|
||||
showToast(successMessage, 'success');
|
||||
}).catch(err => {
|
||||
console.error('Failed to copy text: ', err);
|
||||
showToast('Failed to copy text', 'error');
|
||||
});
|
||||
copyToClipboard(text, successMessage);
|
||||
}
|
||||
|
||||
// Add new method to handle downloading missing LoRAs
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* ShowcaseView.js
|
||||
* Handles showcase content (images, videos) display for checkpoint modal
|
||||
*/
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||
import { state } from '../../state/index.js';
|
||||
import { NSFW_LEVELS } from '../../utils/constants.js';
|
||||
|
||||
@@ -307,8 +307,7 @@ function initMetadataPanelHandlers(container) {
|
||||
if (!promptElement) return;
|
||||
|
||||
try {
|
||||
await navigator.clipboard.writeText(promptElement.textContent);
|
||||
showToast('Prompt copied to clipboard', 'success');
|
||||
await copyToClipboard(promptElement.textContent, 'Prompt copied to clipboard');
|
||||
} catch (err) {
|
||||
console.error('Copy failed:', err);
|
||||
showToast('Copy failed', 'error');
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
/**
|
||||
* RecipeTab - Handles the recipes tab in the Lora Modal
|
||||
*/
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||
import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
|
||||
|
||||
/**
|
||||
@@ -172,14 +172,11 @@ function copyRecipeSyntax(recipeId) {
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
if (data.success && data.syntax) {
|
||||
return navigator.clipboard.writeText(data.syntax);
|
||||
return copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
|
||||
} else {
|
||||
throw new Error(data.error || 'No syntax returned');
|
||||
}
|
||||
})
|
||||
.then(() => {
|
||||
showToast('Recipe syntax copied to clipboard', 'success');
|
||||
})
|
||||
.catch(err => {
|
||||
console.error('Failed to copy: ', err);
|
||||
showToast('Failed to copy recipe syntax', 'error');
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* ShowcaseView.js
|
||||
* 处理LoRA模型展示内容(图片、视频)的功能模块
|
||||
*/
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||
import { state } from '../../state/index.js';
|
||||
import { NSFW_LEVELS } from '../../utils/constants.js';
|
||||
|
||||
@@ -311,8 +311,7 @@ function initMetadataPanelHandlers(container) {
|
||||
if (!promptElement) return;
|
||||
|
||||
try {
|
||||
await navigator.clipboard.writeText(promptElement.textContent);
|
||||
showToast('Prompt copied to clipboard', 'success');
|
||||
await copyToClipboard(promptElement.textContent, 'Prompt copied to clipboard');
|
||||
} catch (err) {
|
||||
console.error('Copy failed:', err);
|
||||
showToast('Copy failed', 'error');
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* TriggerWords.js
|
||||
* 处理LoRA模型触发词相关的功能模块
|
||||
*/
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||
import { saveModelMetadata } from './ModelMetadata.js';
|
||||
|
||||
/**
|
||||
@@ -336,23 +336,7 @@ async function saveTriggerWords() {
|
||||
*/
|
||||
window.copyTriggerWord = async function(word) {
|
||||
try {
|
||||
// Modern clipboard API - with fallback for non-secure contexts
|
||||
if (navigator.clipboard && window.isSecureContext) {
|
||||
await navigator.clipboard.writeText(word);
|
||||
} else {
|
||||
// Fallback for older browsers or non-secure contexts
|
||||
const textarea = document.createElement('textarea');
|
||||
textarea.value = word;
|
||||
textarea.style.position = 'absolute';
|
||||
textarea.style.left = '-99999px';
|
||||
document.body.appendChild(textarea);
|
||||
textarea.select();
|
||||
const success = document.execCommand('copy');
|
||||
document.body.removeChild(textarea);
|
||||
|
||||
if (!success) throw new Error('Copy command failed');
|
||||
}
|
||||
showToast('Trigger word copied', 'success');
|
||||
await copyToClipboard(word, 'Trigger word copied');
|
||||
} catch (err) {
|
||||
console.error('Copy failed:', err);
|
||||
showToast('Copy failed', 'error');
|
||||
|
||||
@@ -3,8 +3,7 @@
|
||||
*
|
||||
* 将原始的LoraModal.js拆分成多个功能模块后的主入口文件
|
||||
*/
|
||||
import { showToast } from '../../utils/uiHelpers.js';
|
||||
import { state } from '../../state/index.js';
|
||||
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||
import { modalManager } from '../../managers/ModalManager.js';
|
||||
import { renderShowcaseContent, toggleShowcase, setupShowcaseScroll, scrollToTop } from './ShowcaseView.js';
|
||||
import { setupTabSwitching, loadModelDescription } from './ModelDescription.js';
|
||||
@@ -174,8 +173,7 @@ export function showLoraModal(lora) {
|
||||
// Copy file name function
|
||||
window.copyFileName = async function(fileName) {
|
||||
try {
|
||||
await navigator.clipboard.writeText(fileName);
|
||||
showToast('File name copied', 'success');
|
||||
await copyToClipboard(fileName, 'File name copied');
|
||||
} catch (err) {
|
||||
console.error('Copy failed:', err);
|
||||
showToast('Copy failed', 'error');
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { state } from '../state/index.js';
|
||||
import { showToast } from '../utils/uiHelpers.js';
|
||||
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
|
||||
import { updateCardsForBulkMode } from '../components/LoraCard.js';
|
||||
|
||||
export class BulkManager {
|
||||
@@ -205,13 +205,7 @@ export class BulkManager {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
await navigator.clipboard.writeText(loraSyntaxes.join(', '));
|
||||
showToast(`Copied ${loraSyntaxes.length} LoRA syntaxes to clipboard`, 'success');
|
||||
} catch (err) {
|
||||
console.error('Copy failed:', err);
|
||||
showToast('Copy failed', 'error');
|
||||
}
|
||||
await copyToClipboard(loraSyntaxes.join(', '), `Copied ${loraSyntaxes.length} LoRA syntaxes to clipboard`);
|
||||
}
|
||||
|
||||
// Create and show the thumbnail strip of selected LoRAs
|
||||
|
||||
@@ -2,6 +2,40 @@ import { state } from '../state/index.js';
|
||||
import { resetAndReload } from '../api/loraApi.js';
|
||||
import { getStorageItem, setStorageItem } from './storageHelpers.js';
|
||||
|
||||
/**
|
||||
* Utility function to copy text to clipboard with fallback for older browsers
|
||||
* @param {string} text - The text to copy to clipboard
|
||||
* @param {string} successMessage - Optional success message to show in toast
|
||||
* @returns {Promise<boolean>} - Promise that resolves to true if copy was successful
|
||||
*/
|
||||
export async function copyToClipboard(text, successMessage = 'Copied to clipboard') {
|
||||
try {
|
||||
// Modern clipboard API
|
||||
if (navigator.clipboard && window.isSecureContext) {
|
||||
await navigator.clipboard.writeText(text);
|
||||
} else {
|
||||
// Fallback for older browsers
|
||||
const textarea = document.createElement('textarea');
|
||||
textarea.value = text;
|
||||
textarea.style.position = 'absolute';
|
||||
textarea.style.left = '-99999px';
|
||||
document.body.appendChild(textarea);
|
||||
textarea.select();
|
||||
document.execCommand('copy');
|
||||
document.body.removeChild(textarea);
|
||||
}
|
||||
|
||||
if (successMessage) {
|
||||
showToast(successMessage, 'success');
|
||||
}
|
||||
return true;
|
||||
} catch (err) {
|
||||
console.error('Copy failed:', err);
|
||||
showToast('Copy failed', 'error');
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export function showToast(message, type = 'info') {
|
||||
const toast = document.createElement('div');
|
||||
toast.className = `toast toast-${type}`;
|
||||
@@ -108,12 +142,6 @@ export function toggleFolder(tag) {
|
||||
resetAndReload();
|
||||
}
|
||||
|
||||
export function copyTriggerWord(word) {
|
||||
navigator.clipboard.writeText(word).then(() => {
|
||||
showToast('Trigger word copied', 'success');
|
||||
});
|
||||
}
|
||||
|
||||
function filterByFolder(folderPath) {
|
||||
document.querySelectorAll('.lora-card').forEach(card => {
|
||||
card.style.display = card.dataset.folder === folderPath ? '' : 'none';
|
||||
|
||||
@@ -927,10 +927,6 @@ export function addLorasWidget(node, name, opts, callback) {
|
||||
// Function to directly save the recipe without dialog
|
||||
async function saveRecipeDirectly(widget) {
|
||||
try {
|
||||
// Get the workflow data from the ComfyUI app
|
||||
const prompt = await app.graphToPrompt();
|
||||
console.log('Prompt:', prompt);
|
||||
|
||||
// Show loading toast
|
||||
if (app && app.extensionManager && app.extensionManager.toast) {
|
||||
app.extensionManager.toast.add({
|
||||
@@ -941,14 +937,9 @@ async function saveRecipeDirectly(widget) {
|
||||
});
|
||||
}
|
||||
|
||||
// Prepare the data - only send workflow JSON
|
||||
const formData = new FormData();
|
||||
formData.append('workflow_json', JSON.stringify(prompt.output));
|
||||
|
||||
// Send the request
|
||||
const response = await fetch('/api/recipes/save-from-widget', {
|
||||
method: 'POST',
|
||||
body: formData
|
||||
method: 'POST'
|
||||
});
|
||||
|
||||
const result = await response.json();
|
||||
|
||||
@@ -9,6 +9,54 @@ async function getLorasWidgetModule() {
|
||||
return await dynamicImportByVersion("./loras_widget.js", "./legacy_loras_widget.js");
|
||||
}
|
||||
|
||||
// Function to get connected trigger toggle nodes
|
||||
function getConnectedTriggerToggleNodes(node) {
|
||||
const connectedNodes = [];
|
||||
|
||||
// Check if node has outputs
|
||||
if (node.outputs && node.outputs.length > 0) {
|
||||
// For each output slot
|
||||
for (const output of node.outputs) {
|
||||
// Check if this output has any links
|
||||
if (output.links && output.links.length > 0) {
|
||||
// For each link, get the target node
|
||||
for (const linkId of output.links) {
|
||||
const link = app.graph.links[linkId];
|
||||
if (link) {
|
||||
const targetNode = app.graph.getNodeById(link.target_id);
|
||||
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
|
||||
connectedNodes.push(targetNode.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return connectedNodes;
|
||||
}
|
||||
|
||||
// Function to update trigger words for connected toggle nodes
|
||||
function updateConnectedTriggerWords(node, text) {
|
||||
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
|
||||
if (connectedNodeIds.length > 0) {
|
||||
const loraNames = new Set();
|
||||
let match;
|
||||
LORA_PATTERN.lastIndex = 0;
|
||||
while ((match = LORA_PATTERN.exec(text)) !== null) {
|
||||
loraNames.add(match[1]);
|
||||
}
|
||||
|
||||
fetch("/loramanager/get_trigger_words", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
lora_names: Array.from(loraNames),
|
||||
node_ids: connectedNodeIds
|
||||
})
|
||||
}).catch(err => console.error("Error fetching trigger words:", err));
|
||||
}
|
||||
}
|
||||
|
||||
function mergeLoras(lorasText, lorasArr) {
|
||||
const result = [];
|
||||
let match;
|
||||
@@ -99,6 +147,9 @@ app.registerExtension({
|
||||
newText = newText.replace(/\s+/g, ' ').trim();
|
||||
|
||||
inputWidget.value = newText;
|
||||
|
||||
// Add this line to update trigger words when lorasWidget changes cause inputWidget value to change
|
||||
updateConnectedTriggerWords(node, newText);
|
||||
} finally {
|
||||
isUpdating = false;
|
||||
}
|
||||
@@ -117,6 +168,9 @@ app.registerExtension({
|
||||
const mergedLoras = mergeLoras(value, currentLoras);
|
||||
|
||||
node.lorasWidget.value = mergedLoras;
|
||||
|
||||
// Replace the existing trigger word update code with the new function
|
||||
updateConnectedTriggerWords(node, value);
|
||||
} finally {
|
||||
isUpdating = false;
|
||||
}
|
||||
|
||||
@@ -1,9 +1,58 @@
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { addLorasWidget } from "./loras_widget.js";
|
||||
import { dynamicImportByVersion } from "./utils.js";
|
||||
|
||||
// Extract pattern into a constant for consistent use
|
||||
const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)>/g;
|
||||
|
||||
// Function to get the appropriate loras widget based on ComfyUI version
|
||||
async function getLorasWidgetModule() {
|
||||
return await dynamicImportByVersion("./loras_widget.js", "./legacy_loras_widget.js");
|
||||
}
|
||||
|
||||
// Function to get connected trigger toggle nodes
|
||||
function getConnectedTriggerToggleNodes(node) {
|
||||
const connectedNodes = [];
|
||||
|
||||
if (node.outputs && node.outputs.length > 0) {
|
||||
for (const output of node.outputs) {
|
||||
if (output.links && output.links.length > 0) {
|
||||
for (const linkId of output.links) {
|
||||
const link = app.graph.links[linkId];
|
||||
if (link) {
|
||||
const targetNode = app.graph.getNodeById(link.target_id);
|
||||
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
|
||||
connectedNodes.push(targetNode.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return connectedNodes;
|
||||
}
|
||||
|
||||
// Function to update trigger words for connected toggle nodes
|
||||
function updateConnectedTriggerWords(node, text) {
|
||||
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
|
||||
if (connectedNodeIds.length > 0) {
|
||||
const loraNames = new Set();
|
||||
let match;
|
||||
LORA_PATTERN.lastIndex = 0;
|
||||
while ((match = LORA_PATTERN.exec(text)) !== null) {
|
||||
loraNames.add(match[1]);
|
||||
}
|
||||
|
||||
fetch("/loramanager/get_trigger_words", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
lora_names: Array.from(loraNames),
|
||||
node_ids: connectedNodeIds
|
||||
})
|
||||
}).catch(err => console.error("Error fetching trigger words:", err));
|
||||
}
|
||||
}
|
||||
|
||||
function mergeLoras(lorasText, lorasArr) {
|
||||
const result = [];
|
||||
let match;
|
||||
@@ -40,7 +89,7 @@ app.registerExtension({
|
||||
});
|
||||
|
||||
// Wait for node to be properly initialized
|
||||
requestAnimationFrame(() => {
|
||||
requestAnimationFrame(async () => {
|
||||
// Restore saved value if exists
|
||||
let existingLoras = [];
|
||||
if (node.widgets_values && node.widgets_values.length > 0) {
|
||||
@@ -64,7 +113,10 @@ app.registerExtension({
|
||||
// Add flag to prevent callback loops
|
||||
let isUpdating = false;
|
||||
|
||||
// Get the widget object directly from the returned object
|
||||
// Dynamically load the appropriate widget module
|
||||
const lorasModule = await getLorasWidgetModule();
|
||||
const { addLorasWidget } = lorasModule;
|
||||
|
||||
const result = addLorasWidget(node, "loras", {
|
||||
defaultVal: mergedLoras // Pass object directly
|
||||
}, (value) => {
|
||||
@@ -86,6 +138,9 @@ app.registerExtension({
|
||||
newText = newText.replace(/\s+/g, ' ').trim();
|
||||
|
||||
inputWidget.value = newText;
|
||||
|
||||
// Update trigger words when lorasWidget changes
|
||||
updateConnectedTriggerWords(node, newText);
|
||||
} finally {
|
||||
isUpdating = false;
|
||||
}
|
||||
@@ -104,6 +159,9 @@ app.registerExtension({
|
||||
const mergedLoras = mergeLoras(value, currentLoras);
|
||||
|
||||
node.lorasWidget.value = mergedLoras;
|
||||
|
||||
// Update trigger words when input changes
|
||||
updateConnectedTriggerWords(node, value);
|
||||
} finally {
|
||||
isUpdating = false;
|
||||
}
|
||||
|
||||
36
web/comfyui/usage_stats.js
Normal file
36
web/comfyui/usage_stats.js
Normal file
@@ -0,0 +1,36 @@
|
||||
// ComfyUI extension to track model usage statistics
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { api } from "../../scripts/api.js";
|
||||
|
||||
// Register the extension
|
||||
app.registerExtension({
|
||||
name: "ComfyUI-Lora-Manager.UsageStats",
|
||||
|
||||
init() {
|
||||
// Listen for successful executions
|
||||
api.addEventListener("execution_success", ({ detail }) => {
|
||||
if (detail && detail.prompt_id) {
|
||||
this.updateUsageStats(detail.prompt_id);
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
async updateUsageStats(promptId) {
|
||||
try {
|
||||
// Call backend endpoint with the prompt_id
|
||||
const response = await fetch(`/loras/api/update-usage-stats`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ prompt_id: promptId }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
console.warn("Failed to update usage statistics:", response.statusText);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error updating usage statistics:", error);
|
||||
}
|
||||
}
|
||||
});
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user