Compare commits

..

26 Commits

Author SHA1 Message Date
Will Miao
dc4c11ddd2 feat: Update release notes and version to 0.8.8 with new features and bug fixes 2025-04-22 13:29:00 +08:00
pixelpaws
d389e4d5d4 Merge pull request #122 from willmiao/dev
Dev
2025-04-22 09:40:05 +08:00
Will Miao
8cb78ad931 feat: Add route for retrieving current usage statistics 2025-04-22 09:39:00 +08:00
Will Miao
85f987d15c feat: Centralize clipboard functionality with copyToClipboard utility across components 2025-04-22 09:33:05 +08:00
Will Miao
b12079e0f6 feat: Implement usage statistics tracking with backend integration and route setup 2025-04-22 08:56:34 +08:00
pixelpaws
dcf5c6167a Merge pull request #121 from willmiao/dev
Dev
2025-04-21 15:44:23 +08:00
Will Miao
b395d3f487 fix: Update filename formatting in save_images method to ensure unique filenames for batch images 2025-04-21 15:42:49 +08:00
Will Miao
37662cad10 Update workflow 2025-04-21 15:42:49 +08:00
pixelpaws
aa1673063d Merge pull request #120 from willmiao/dev
feat: Enhance LoraManager by updating trigger words handling and dyna…
2025-04-21 06:52:16 +08:00
Will Miao
f51f49eb60 feat: Enhance LoraManager by updating trigger words handling and dynamically loading widget modules. 2025-04-21 06:49:51 +08:00
pixelpaws
54c9bac961 Merge pull request #119 from willmiao/dev
Dev
2025-04-20 22:29:28 +08:00
Will Miao
e70fd73bdd feat: Implement trigger words API and update frontend integration for LoraManager. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/43 2025-04-20 22:27:53 +08:00
Will Miao
9bb9e7b64d refactor: Extract common methods for Lora handling into utils.py and update references in lora_loader.py and lora_stacker.py 2025-04-20 21:35:36 +08:00
pixelpaws
f64c03543a Merge pull request #116 from matrunchyk/main
Prevent duplicates of root folders when using symlinks
2025-04-20 17:05:08 +08:00
Will Miao
51374de1a1 fix: Update version to 0.8.7-bugfix2 in pyproject.toml for clarity on bug fixes 2025-04-20 15:04:24 +08:00
Will Miao
afcc12f263 fix: Update populate_lora_from_civitai method to accept a tuple for Civitai API response. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/117 2025-04-20 15:01:23 +08:00
Your Name
88c5482366 Merge branch 'main' of https://github.com/willmiao/ComfyUI-Lora-Manager 2025-04-19 21:47:41 +03:00
Your Name
bbf7295c32 Prevent duplicates of root folders when using symlinks 2025-04-19 21:42:01 +03:00
Will Miao
ca5e23e68c fix: Update version to 0.8.7-bugfix in pyproject.toml for clarity on bug fixes 2025-04-19 23:02:50 +08:00
Will Miao
eadb1487ae feat: Refactor metadata formatting to use helper function for conditional parameter addition 2025-04-19 23:00:09 +08:00
Will Miao
1faa70fc77 feat: Implement filename-based hash retrieval in LoraScanner and ModelScanner for improved compatibility 2025-04-19 21:12:26 +08:00
Will Miao
30d7c007de fix: Correct metadata restoration logic to ensure file info is fetched when metadata is missing 2025-04-19 20:51:23 +08:00
Will Miao
f54f6a4402 feat: Enhance metadata handling by restoring missing civitai data and extracting tags and descriptions from version info 2025-04-19 11:35:42 +08:00
Will Miao
7b41cdec65 feat: Add civitai_deleted attribute to BaseModelMetadata for tracking deletion status from Civitai 2025-04-19 09:30:43 +08:00
Will Miao
fb6a652a57 feat: Add checkpoint hash retrieval and enhance metadata formatting in SaveImage class 2025-04-18 23:55:45 +08:00
Will Miao
ea34d753c1 refactor: Remove unnecessary workflow data logging and streamline saveRecipeDirectly function for legacy loras widget 2025-04-18 21:52:26 +08:00
36 changed files with 895 additions and 286 deletions

View File

@@ -20,6 +20,12 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
## Release Notes
### v0.8.8
* **Real-time TriggerWord Updates** - Enhanced TriggerWord Toggle node to instantly update when connected Lora Loader or Lora Stacker nodes change, without requiring workflow execution
* **Optimized Metadata Recovery** - Improved utilization of existing .civitai.info files for faster initialization and preservation of metadata from models deleted from CivitAI
* **Migration Acceleration** - Further speed improvements for users transitioning from A1111/Forge environments
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
### v0.8.7
* **Enhanced Context Menu** - Added comprehensive context menu functionality to Recipes and Checkpoints pages for improved workflow
* **Interactive LoRA Strength Control** - Implemented drag functionality in LoRA Loader for intuitive strength adjustment

View File

@@ -103,21 +103,29 @@ class Config:
def _init_lora_paths(self) -> List[str]:
"""Initialize and validate LoRA paths from ComfyUI settings"""
paths = sorted(set(path.replace(os.sep, "/")
for path in folder_paths.get_folder_paths("loras")
if os.path.exists(path)), key=lambda p: p.lower())
print("Found LoRA roots:", "\n - " + "\n - ".join(paths))
raw_paths = folder_paths.get_folder_paths("loras")
if not paths:
# Normalize and resolve symlinks, store mapping from resolved -> original
path_map = {}
for path in raw_paths:
if os.path.exists(path):
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
path_map[real_path] = path_map.get(real_path, path) # preserve first seen
# Now sort and use only the deduplicated real paths
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
print("Found LoRA roots:", "\n - " + "\n - ".join(unique_paths))
if not unique_paths:
raise ValueError("No valid loras folders found in ComfyUI configuration")
# 初始化路径映射
for path in paths:
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
if real_path != path:
self.add_path_mapping(path, real_path)
for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
if real_path != original_path:
self.add_path_mapping(original_path, real_path)
return paths
return unique_paths
def _init_checkpoint_paths(self) -> List[str]:
"""Initialize and validate checkpoint paths from ComfyUI settings"""

View File

@@ -5,6 +5,8 @@ from .routes.lora_routes import LoraRoutes
from .routes.api_routes import ApiRoutes
from .routes.recipe_routes import RecipeRoutes
from .routes.checkpoints_routes import CheckpointsRoutes
from .routes.update_routes import UpdateRoutes
from .routes.usage_stats_routes import UsageStatsRoutes
from .services.service_registry import ServiceRegistry
import logging
@@ -92,6 +94,8 @@ class LoraManager:
checkpoints_routes.setup_routes(app)
ApiRoutes.setup_routes(app)
RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app)
UsageStatsRoutes.setup_routes(app) # Register usage stats routes
# Schedule service initialization
app.on_startup.append(lambda app: cls._initialize_services())

View File

@@ -1,12 +1,14 @@
"""Constants used by the metadata collector"""
# Individual category constants
# Metadata collection constants
# Metadata categories
MODELS = "models"
PROMPTS = "prompts"
SAMPLING = "sampling"
LORAS = "loras"
SIZE = "size"
IMAGES = "images" # Added new category for image results
IMAGES = "images"
# Collection of categories for iteration
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES] # Added IMAGES to categories
# Complete list of categories to track
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]

View File

@@ -5,7 +5,7 @@ from ..services.lora_scanner import LoraScanner
from ..config import config
import asyncio
import os
from .utils import FlexibleOptionalInputType, any_type
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
logger = logging.getLogger(__name__)
@@ -32,48 +32,6 @@ class LoraManagerLoader:
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING)
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
FUNCTION = "load_loras"
async def get_lora_info(self, lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def extract_lora_name(self, lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def _get_loras_list(self, kwargs):
"""Helper to extract loras list from either old or new kwargs format"""
if 'loras' not in kwargs:
return []
loras_data = kwargs['loras']
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
return loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
return loras_data
# Unexpected format
else:
logger.warning(f"Unexpected loras format: {type(loras_data)}")
return []
def load_loras(self, model, text, **kwargs):
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
@@ -89,14 +47,14 @@ class LoraManagerLoader:
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Extract lora name for trigger words lookup
lora_name = self.extract_lora_name(lora_path)
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
lora_name = extract_lora_name(lora_path)
_, trigger_words = asyncio.run(get_lora_info(lora_name))
all_trigger_words.extend(trigger_words)
loaded_loras.append(f"{lora_name}: {model_strength}")
# Then process loras from kwargs with support for both old and new formats
loras_list = self._get_loras_list(kwargs)
loras_list = get_loras_list(kwargs)
for lora in loras_list:
if not lora.get('active', False):
continue
@@ -105,7 +63,7 @@ class LoraManagerLoader:
strength = float(lora['strength'])
# Get lora path and trigger words
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
# Apply the LoRA using the resolved path
model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength)

View File

@@ -3,7 +3,7 @@ from ..services.lora_scanner import LoraScanner
from ..config import config
import asyncio
import os
from .utils import FlexibleOptionalInputType, any_type
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
import logging
logger = logging.getLogger(__name__)
@@ -29,48 +29,6 @@ class LoraStacker:
RETURN_TYPES = ("LORA_STACK", IO.STRING, IO.STRING)
RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras")
FUNCTION = "stack_loras"
async def get_lora_info(self, lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def extract_lora_name(self, lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def _get_loras_list(self, kwargs):
"""Helper to extract loras list from either old or new kwargs format"""
if 'loras' not in kwargs:
return []
loras_data = kwargs['loras']
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
return loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
return loras_data
# Unexpected format
else:
logger.warning(f"Unexpected loras format: {type(loras_data)}")
return []
def stack_loras(self, text, **kwargs):
"""Stacks multiple LoRAs based on the kwargs input without loading them."""
@@ -84,12 +42,12 @@ class LoraStacker:
stack.extend(lora_stack)
# Get trigger words from existing stack entries
for lora_path, _, _ in lora_stack:
lora_name = self.extract_lora_name(lora_path)
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
lora_name = extract_lora_name(lora_path)
_, trigger_words = asyncio.run(get_lora_info(lora_name))
all_trigger_words.extend(trigger_words)
# Process loras from kwargs with support for both old and new formats
loras_list = self._get_loras_list(kwargs)
loras_list = get_loras_list(kwargs)
for lora in loras_list:
if not lora.get('active', False):
continue
@@ -99,7 +57,7 @@ class LoraStacker:
clip_strength = model_strength # Using same strength for both as in the original loader
# Get lora path and trigger words
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
# Add to stack without loading
# replace '/' with os.sep to avoid different OS path format

View File

@@ -5,6 +5,7 @@ import re
import numpy as np
import folder_paths # type: ignore
from ..services.lora_scanner import LoraScanner
from ..services.checkpoint_scanner import CheckpointScanner
from ..metadata_collector.metadata_processor import MetadataProcessor
from ..metadata_collector import get_metadata
from PIL import Image, PngImagePlugin
@@ -53,18 +54,55 @@ class SaveImage:
async def get_lora_hash(self, lora_name):
"""Get the lora hash from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
# Use the new direct filename lookup method
hash_value = scanner.get_hash_by_filename(lora_name)
if hash_value:
return hash_value
# Fallback to old method for compatibility
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
return item.get('sha256')
return None
async def get_checkpoint_hash(self, checkpoint_path):
"""Get the checkpoint hash from cache"""
scanner = await CheckpointScanner.get_instance()
if not checkpoint_path:
return None
# Extract basename without extension
checkpoint_name = os.path.basename(checkpoint_path)
checkpoint_name = os.path.splitext(checkpoint_name)[0]
# Try direct filename lookup first
hash_value = scanner.get_hash_by_filename(checkpoint_name)
if hash_value:
return hash_value
# Fallback to old method for compatibility
cache = await scanner.get_cached_data()
normalized_path = checkpoint_path.replace('\\', '/')
for item in cache.raw_data:
if item.get('file_name') == checkpoint_name and item.get('file_path').endswith(normalized_path):
return item.get('sha256')
return None
async def format_metadata(self, metadata_dict):
"""Format metadata in the requested format similar to userComment example"""
if not metadata_dict:
return ""
# Helper function to only add parameter if value is not None
def add_param_if_not_none(param_list, label, value):
if value is not None:
param_list.append(f"{label}: {value}")
# Extract the prompt and negative prompt
prompt = metadata_dict.get('prompt', '')
negative_prompt = metadata_dict.get('negative_prompt', '')
@@ -100,7 +138,11 @@ class SaveImage:
# Add standard parameters in the correct order
if 'steps' in metadata_dict:
params.append(f"Steps: {metadata_dict.get('steps')}")
add_param_if_not_none(params, "Steps", metadata_dict.get('steps'))
# Combine sampler and scheduler information
sampler_name = None
scheduler_name = None
if 'sampler' in metadata_dict:
sampler = metadata_dict.get('sampler')
@@ -123,7 +165,6 @@ class SaveImage:
'ddim': 'DDIM'
}
sampler_name = sampler_mapping.get(sampler, sampler)
params.append(f"Sampler: {sampler_name}")
if 'scheduler' in metadata_dict:
scheduler = metadata_dict.get('scheduler')
@@ -135,38 +176,48 @@ class SaveImage:
'sgm_quadratic': 'SGM Quadratic'
}
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
params.append(f"Schedule type: {scheduler_name}")
# CFG scale (cfg_scale in metadata_dict)
if 'cfg_scale' in metadata_dict:
params.append(f"CFG scale: {metadata_dict.get('cfg_scale')}")
# Add combined sampler and scheduler information
if sampler_name:
if scheduler_name:
params.append(f"Sampler: {sampler_name} {scheduler_name}")
else:
params.append(f"Sampler: {sampler_name}")
# CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg)
if 'guidance' in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('guidance'))
elif 'cfg_scale' in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg_scale'))
elif 'cfg' in metadata_dict:
params.append(f"CFG scale: {metadata_dict.get('cfg')}")
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg'))
# Seed
if 'seed' in metadata_dict:
params.append(f"Seed: {metadata_dict.get('seed')}")
add_param_if_not_none(params, "Seed", metadata_dict.get('seed'))
# Size
if 'size' in metadata_dict:
params.append(f"Size: {metadata_dict.get('size')}")
add_param_if_not_none(params, "Size", metadata_dict.get('size'))
# Model info
if 'checkpoint' in metadata_dict:
# Ensure checkpoint is a string before processing
checkpoint = metadata_dict.get('checkpoint')
if checkpoint is not None:
# Handle both string and other types safely
if isinstance(checkpoint, str):
# Extract basename without path
checkpoint = os.path.basename(checkpoint)
# Remove extension if present
checkpoint = os.path.splitext(checkpoint)[0]
else:
# Convert non-string to string
checkpoint = str(checkpoint)
# Get model hash
model_hash = await self.get_checkpoint_hash(checkpoint)
params.append(f"Model: {checkpoint}")
# Extract basename without path
checkpoint_name = os.path.basename(checkpoint)
# Remove extension if present
checkpoint_name = os.path.splitext(checkpoint_name)[0]
# Add model hash if available
if model_hash:
params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
else:
params.append(f"Model: {checkpoint_name}")
# Add LoRA hashes if available
if lora_hashes:
@@ -284,7 +335,7 @@ class SaveImage:
if add_counter_to_filename:
# Use counter + i to ensure unique filenames for all images in batch
current_counter = counter + i
base_filename += f"_{current_counter:05}"
base_filename += f"_{current_counter:05}_"
# Set file extension and prepare saving parameters
if file_format == "png":

View File

@@ -47,10 +47,10 @@ class TriggerWordToggle:
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
# Send trigger words to frontend
PromptServer.instance.send_sync("trigger_word_update", {
"id": id,
"message": trigger_words
})
# PromptServer.instance.send_sync("trigger_word_update", {
# "id": id,
# "message": trigger_words
# })
filtered_triggers = trigger_words

View File

@@ -30,4 +30,55 @@ class FlexibleOptionalInputType(dict):
return True
any_type = AnyType("*")
any_type = AnyType("*")
# Common methods extracted from lora_loader.py and lora_stacker.py
import os
import logging
import asyncio
from ..services.lora_scanner import LoraScanner
from ..config import config
logger = logging.getLogger(__name__)
async def get_lora_info(lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def extract_lora_name(lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def get_loras_list(kwargs):
"""Helper to extract loras list from either old or new kwargs format"""
if 'loras' not in kwargs:
return []
loras_data = kwargs['loras']
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
return loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
return loras_data
# Unexpected format
else:
logger.warning(f"Unexpected loras format: {type(loras_data)}")
return []

View File

@@ -3,8 +3,10 @@ import json
import logging
from aiohttp import web
from typing import Dict
from server import PromptServer # type: ignore
from ..utils.routes_common import ModelRouteUtils
from ..nodes.utils import get_lora_info
from ..config import config
from ..services.websocket_manager import ws_manager
@@ -64,6 +66,9 @@ class ApiRoutes:
app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL
app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files
app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files
# Add the new trigger words route
app.router.add_post('/loramanager/get_trigger_words', routes.get_trigger_words)
# Add update check routes
UpdateRoutes.setup_routes(app)
@@ -1021,4 +1026,35 @@ class ApiRoutes:
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_trigger_words(self, request: web.Request) -> web.Response:
"""Get trigger words for specified LoRA models"""
try:
json_data = await request.json()
lora_names = json_data.get("lora_names", [])
node_ids = json_data.get("node_ids", [])
all_trigger_words = []
for lora_name in lora_names:
_, trigger_words = await get_lora_info(lora_name)
all_trigger_words.extend(trigger_words)
# Format the trigger words
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Send update to all connected trigger word toggle nodes
for node_id in node_ids:
PromptServer.instance.send_sync("trigger_word_update", {
"id": node_id,
"message": trigger_words_text
})
return web.json_response({"success": True})
except Exception as e:
logger.error(f"Error getting trigger words: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)

View File

@@ -0,0 +1,69 @@
import logging
from aiohttp import web
from ..utils.usage_stats import UsageStats
logger = logging.getLogger(__name__)
class UsageStatsRoutes:
"""Routes for handling usage statistics updates"""
@staticmethod
def setup_routes(app):
"""Register usage stats routes"""
app.router.add_post('/loras/api/update-usage-stats', UsageStatsRoutes.update_usage_stats)
app.router.add_get('/loras/api/get-usage-stats', UsageStatsRoutes.get_usage_stats)
@staticmethod
async def update_usage_stats(request):
"""
Update usage statistics based on a prompt_id
Expects a JSON body with:
{
"prompt_id": "string"
}
"""
try:
# Parse the request body
data = await request.json()
prompt_id = data.get('prompt_id')
if not prompt_id:
return web.json_response({
'success': False,
'error': 'Missing prompt_id'
}, status=400)
# Call the UsageStats to process this prompt_id synchronously
usage_stats = UsageStats()
await usage_stats.process_execution(prompt_id)
return web.json_response({
'success': True
})
except Exception as e:
logger.error(f"Failed to update usage stats: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def get_usage_stats(request):
"""Get current usage statistics"""
try:
usage_stats = UsageStats()
stats = await usage_stats.get_stats()
return web.json_response({
'success': True,
'data': stats
})
except Exception as e:
logger.error(f"Failed to get usage stats: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)

26
py/server_routes.py Normal file
View File

@@ -0,0 +1,26 @@
from aiohttp import web
from server import PromptServer
from .nodes.utils import get_lora_info
@PromptServer.instance.routes.post("/loramanager/get_trigger_words")
async def get_trigger_words(request):
json_data = await request.json()
lora_names = json_data.get("lora_names", [])
node_ids = json_data.get("node_ids", [])
all_trigger_words = []
for lora_name in lora_names:
_, trigger_words = await get_lora_info(lora_name)
all_trigger_words.extend(trigger_words)
# Format the trigger words
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Send update to all connected trigger word toggle nodes
for node_id in node_ids:
PromptServer.instance.send_sync("trigger_word_update", {
"id": node_id,
"message": trigger_words_text
})
return web.json_response({"success": True})

View File

@@ -9,7 +9,7 @@ from typing import List, Dict, Optional, Set
from ..utils.models import LoraMetadata
from ..config import config
from .model_scanner import ModelScanner
from .lora_hash_index import LoraHashIndex
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
from .settings_manager import settings
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match
@@ -35,12 +35,12 @@ class LoraScanner(ModelScanner):
# Define supported file extensions
file_extensions = {'.safetensors'}
# Initialize parent class
# Initialize parent class with ModelHashIndex
super().__init__(
model_type="lora",
model_class=LoraMetadata,
file_extensions=file_extensions,
hash_index=LoraHashIndex()
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
)
self._initialized = True

View File

@@ -1,11 +1,12 @@
from typing import Dict, Optional, Set
import os
class ModelHashIndex:
"""Index for looking up models by hash or path"""
def __init__(self):
self._hash_to_path: Dict[str, str] = {}
self._path_to_hash: Dict[str, str] = {}
self._filename_to_hash: Dict[str, str] = {} # Changed from path_to_hash to filename_to_hash
def add_entry(self, sha256: str, file_path: str) -> None:
"""Add or update hash index entry"""
@@ -15,37 +16,47 @@ class ModelHashIndex:
# Ensure hash is lowercase for consistency
sha256 = sha256.lower()
# Extract filename without extension
filename = self._get_filename_from_path(file_path)
# Remove old path mapping if hash exists
if sha256 in self._hash_to_path:
old_path = self._hash_to_path[sha256]
if old_path in self._path_to_hash:
del self._path_to_hash[old_path]
old_filename = self._get_filename_from_path(old_path)
if old_filename in self._filename_to_hash:
del self._filename_to_hash[old_filename]
# Remove old hash mapping if path exists
if file_path in self._path_to_hash:
old_hash = self._path_to_hash[file_path]
# Remove old hash mapping if filename exists
if filename in self._filename_to_hash:
old_hash = self._filename_to_hash[filename]
if old_hash in self._hash_to_path:
del self._hash_to_path[old_hash]
# Add new mappings
self._hash_to_path[sha256] = file_path
self._path_to_hash[file_path] = sha256
self._filename_to_hash[filename] = sha256
def _get_filename_from_path(self, file_path: str) -> str:
"""Extract filename without extension from path"""
return os.path.splitext(os.path.basename(file_path))[0]
def remove_by_path(self, file_path: str) -> None:
"""Remove entry by file path"""
if file_path in self._path_to_hash:
hash_val = self._path_to_hash[file_path]
filename = self._get_filename_from_path(file_path)
if filename in self._filename_to_hash:
hash_val = self._filename_to_hash[filename]
if hash_val in self._hash_to_path:
del self._hash_to_path[hash_val]
del self._path_to_hash[file_path]
del self._filename_to_hash[filename]
def remove_by_hash(self, sha256: str) -> None:
"""Remove entry by hash"""
sha256 = sha256.lower()
if sha256 in self._hash_to_path:
path = self._hash_to_path[sha256]
if path in self._path_to_hash:
del self._path_to_hash[path]
filename = self._get_filename_from_path(path)
if filename in self._filename_to_hash:
del self._filename_to_hash[filename]
del self._hash_to_path[sha256]
def has_hash(self, sha256: str) -> bool:
@@ -58,20 +69,27 @@ class ModelHashIndex:
def get_hash(self, file_path: str) -> Optional[str]:
"""Get hash for a file path"""
return self._path_to_hash.get(file_path)
filename = self._get_filename_from_path(file_path)
return self._filename_to_hash.get(filename)
def get_hash_by_filename(self, filename: str) -> Optional[str]:
"""Get hash for a filename without extension"""
# Strip extension if present to make the function more flexible
filename = os.path.splitext(filename)[0]
return self._filename_to_hash.get(filename)
def clear(self) -> None:
"""Clear all entries"""
self._hash_to_path.clear()
self._path_to_hash.clear()
self._filename_to_hash.clear()
def get_all_hashes(self) -> Set[str]:
"""Get all hashes in the index"""
return set(self._hash_to_path.keys())
def get_all_paths(self) -> Set[str]:
"""Get all file paths in the index"""
return set(self._path_to_hash.keys())
def get_all_filenames(self) -> Set[str]:
"""Get all filenames in the index"""
return set(self._filename_to_hash.keys())
def __len__(self) -> int:
"""Get number of entries"""

View File

@@ -292,7 +292,7 @@ class ModelScanner:
)
# If force refresh is requested, initialize the cache directly
if force_refresh:
if (force_refresh):
if self._cache is None:
# For initial creation, do a full initialization
await self._initialize_cache()
@@ -553,9 +553,36 @@ class ModelScanner:
logger.debug(f"Created metadata from .civitai.info for {file_path}")
except Exception as e:
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
else:
# Check if metadata exists but civitai field is empty - try to restore from civitai.info
if metadata.civitai is None or metadata.civitai == {}:
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
if os.path.exists(civitai_info_path):
try:
with open(civitai_info_path, 'r', encoding='utf-8') as f:
version_info = json.load(f)
logger.debug(f"Restoring missing civitai data from .civitai.info for {file_path}")
metadata.civitai = version_info
# Ensure tags are also updated if they're missing
if (not metadata.tags or len(metadata.tags) == 0) and 'model' in version_info:
if 'tags' in version_info['model']:
metadata.tags = version_info['model']['tags']
# Also restore description if missing
if (not metadata.modelDescription or metadata.modelDescription == "") and 'model' in version_info:
if 'description' in version_info['model']:
metadata.modelDescription = version_info['model']['description']
# Save the updated metadata
await save_metadata(file_path, metadata)
logger.debug(f"Updated metadata with civitai info for {file_path}")
except Exception as e:
logger.error(f"Error restoring civitai data from .civitai.info for {file_path}: {e}")
if metadata is None:
metadata = await self._get_file_info(file_path)
if metadata is None:
metadata = await self._get_file_info(file_path)
model_data = metadata.to_dict()
@@ -805,6 +832,10 @@ class ModelScanner:
def get_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a model by its file path"""
return self._hash_index.get_hash(file_path)
def get_hash_by_filename(self, filename: str) -> Optional[str]:
"""Get hash for a model by its filename without path"""
return self._hash_index.get_hash_by_filename(filename)
# TODO: Adjust this method to use metadata instead of finding the file
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:

View File

@@ -21,6 +21,7 @@ class BaseModelMetadata:
civitai: Optional[Dict] = None # Civitai API data if available
tags: List[str] = None # Model tags
modelDescription: str = "" # Full model description
civitai_deleted: bool = False # Whether deleted from Civitai
def __post_init__(self):
# Initialize empty lists to avoid mutable default parameter issue
@@ -64,6 +65,15 @@ class LoraMetadata(BaseModelMetadata):
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
# Extract tags and description if available
tags = []
description = ""
if 'model' in version_info:
if 'tags' in version_info['model']:
tags = version_info['model']['tags']
if 'description' in version_info['model']:
description = version_info['model']['description']
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
@@ -75,7 +85,9 @@ class LoraMetadata(BaseModelMetadata):
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0, # Will be updated after preview download
from_civitai=True,
civitai=version_info
civitai=version_info,
tags=tags,
modelDescription=description
)
@dataclass
@@ -90,6 +102,15 @@ class CheckpointMetadata(BaseModelMetadata):
base_model = determine_base_model(version_info.get('baseModel', ''))
model_type = version_info.get('type', 'checkpoint')
# Extract tags and description if available
tags = []
description = ""
if 'model' in version_info:
if 'tags' in version_info['model']:
tags = version_info['model']['tags']
if 'description' in version_info['model']:
description = version_info['model']['description']
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
@@ -102,6 +123,8 @@ class CheckpointMetadata(BaseModelMetadata):
preview_nsfw_level=0,
from_civitai=True,
civitai=version_info,
model_type=model_type
model_type=model_type,
tags=tags,
modelDescription=description
)

View File

@@ -45,14 +45,14 @@ class RecipeMetadataParser(ABC):
"""
pass
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info: Dict[str, Any],
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info_tuple: Tuple[Dict[str, Any], Optional[str]],
recipe_scanner=None, base_model_counts=None, hash_value=None) -> Dict[str, Any]:
"""
Populate a lora entry with information from Civitai API response
Args:
lora_entry: The lora entry to populate
civitai_info: The response from Civitai API
civitai_info_tuple: The response tuple from Civitai API (data, error_msg)
recipe_scanner: Optional recipe scanner for local file lookup
base_model_counts: Optional dict to track base model counts
hash_value: Optional hash value to use if not available in civitai_info
@@ -61,6 +61,9 @@ class RecipeMetadataParser(ABC):
The populated lora_entry dict
"""
try:
# Unpack the tuple to get the actual data
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
if civitai_info and civitai_info.get("error") != "Model not found":
# Check if this is an early access lora
if civitai_info.get('earlyAccessEndsAt'):
@@ -241,11 +244,11 @@ class RecipeFormatParser(RecipeMetadataParser):
# Try to get additional info from Civitai if we have a model version ID
if lora.get('modelVersionId') and civitai_client:
try:
civitai_info = await civitai_client.get_model_version_info(lora['modelVersionId'])
civitai_info_tuple = await civitai_client.get_model_version_info(lora['modelVersionId'])
# Populate lora entry with Civitai info
lora_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
civitai_info_tuple,
recipe_scanner,
None, # No need to track base model counts
lora['hash']
@@ -336,12 +339,13 @@ class StandardMetadataParser(RecipeMetadataParser):
# Get additional info from Civitai if client is available
if civitai_client:
try:
civitai_info = await civitai_client.get_model_version_info(model_version_id)
civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id)
# Populate lora entry with Civitai info
lora_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner
civitai_info_tuple,
recipe_scanner,
base_model_counts
)
except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA: {e}")
@@ -621,11 +625,11 @@ class ComfyMetadataParser(RecipeMetadataParser):
# Get additional info from Civitai if client is available
if civitai_client:
try:
civitai_info = await civitai_client.get_model_version_info(model_version_id)
civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id)
# Populate lora entry with Civitai info
lora_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
civitai_info_tuple,
recipe_scanner
)
except Exception as e:
@@ -660,7 +664,8 @@ class ComfyMetadataParser(RecipeMetadataParser):
# Get additional checkpoint info from Civitai
if civitai_client:
try:
civitai_info = await civitai_client.get_model_version_info(checkpoint_version_id)
civitai_info_tuple = await civitai_client.get_model_version_info(checkpoint_version_id)
civitai_info, _ = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
# Populate checkpoint with Civitai info
checkpoint = await self.populate_checkpoint_from_civitai(checkpoint, civitai_info)
except Exception as e:

267
py/utils/usage_stats.py Normal file
View File

@@ -0,0 +1,267 @@
import os
import json
import time
import asyncio
import logging
from typing import Dict, Set
from ..config import config
from ..services.service_registry import ServiceRegistry
from ..metadata_collector.metadata_registry import MetadataRegistry
from ..metadata_collector.constants import MODELS, LORAS
logger = logging.getLogger(__name__)
class UsageStats:
"""Track usage statistics for models and save to JSON"""
_instance = None
_lock = asyncio.Lock() # For thread safety
# Default stats file name
STATS_FILENAME = "lora_manager_stats.json"
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
# Initialize stats storage
self.stats = {
"checkpoints": {}, # sha256 -> count
"loras": {}, # sha256 -> count
"total_executions": 0,
"last_save_time": 0
}
# Queue for prompt_ids to process
self.pending_prompt_ids = set()
# Load existing stats if available
self._stats_file_path = self._get_stats_file_path()
self._load_stats()
# Save interval in seconds
self.save_interval = 90 # 1.5 minutes
# Start background task to process queued prompt_ids
self._bg_task = asyncio.create_task(self._background_processor())
self._initialized = True
logger.info("Usage statistics tracker initialized")
def _get_stats_file_path(self) -> str:
"""Get the path to the stats JSON file"""
if not config.loras_roots or len(config.loras_roots) == 0:
# Fallback to temporary directory if no lora roots
return os.path.join(config.temp_directory, self.STATS_FILENAME)
# Use the first lora root
return os.path.join(config.loras_roots[0], self.STATS_FILENAME)
def _load_stats(self):
"""Load existing statistics from file"""
try:
if os.path.exists(self._stats_file_path):
with open(self._stats_file_path, 'r', encoding='utf-8') as f:
loaded_stats = json.load(f)
# Update our stats with loaded data
if isinstance(loaded_stats, dict):
# Update individual sections to maintain structure
if "checkpoints" in loaded_stats and isinstance(loaded_stats["checkpoints"], dict):
self.stats["checkpoints"] = loaded_stats["checkpoints"]
if "loras" in loaded_stats and isinstance(loaded_stats["loras"], dict):
self.stats["loras"] = loaded_stats["loras"]
if "total_executions" in loaded_stats:
self.stats["total_executions"] = loaded_stats["total_executions"]
logger.info(f"Loaded usage statistics from {self._stats_file_path}")
except Exception as e:
logger.error(f"Error loading usage statistics: {e}")
async def save_stats(self, force=False):
"""Save statistics to file"""
try:
# Only save if it's been at least save_interval since last save or force is True
current_time = time.time()
if not force and (current_time - self.stats.get("last_save_time", 0)) < self.save_interval:
return False
# Use a lock to prevent concurrent writes
async with self._lock:
# Update last save time
self.stats["last_save_time"] = current_time
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(self._stats_file_path), exist_ok=True)
# Write to a temporary file first, then move it to avoid corruption
temp_path = f"{self._stats_file_path}.tmp"
with open(temp_path, 'w', encoding='utf-8') as f:
json.dump(self.stats, f, indent=2, ensure_ascii=False)
# Replace the old file with the new one
os.replace(temp_path, self._stats_file_path)
logger.debug(f"Saved usage statistics to {self._stats_file_path}")
return True
except Exception as e:
logger.error(f"Error saving usage statistics: {e}", exc_info=True)
return False
def register_execution(self, prompt_id):
"""Register a completed execution by prompt_id for later processing"""
if prompt_id:
self.pending_prompt_ids.add(prompt_id)
async def _background_processor(self):
"""Background task to process queued prompt_ids"""
try:
while True:
# Wait a short interval before checking for new prompt_ids
await asyncio.sleep(5) # Check every 5 seconds
# Process any pending prompt_ids
if self.pending_prompt_ids:
async with self._lock:
# Get a copy of the set and clear original
prompt_ids = self.pending_prompt_ids.copy()
self.pending_prompt_ids.clear()
# Process each prompt_id
registry = MetadataRegistry()
for prompt_id in prompt_ids:
try:
metadata = registry.get_metadata(prompt_id)
await self._process_metadata(metadata)
except Exception as e:
logger.error(f"Error processing prompt_id {prompt_id}: {e}")
# Periodically save stats
await self.save_stats()
except asyncio.CancelledError:
# Task was cancelled, clean up
await self.save_stats(force=True)
except Exception as e:
logger.error(f"Error in background processing task: {e}", exc_info=True)
# Restart the task after a delay if it fails
asyncio.create_task(self._restart_background_task())
async def _restart_background_task(self):
"""Restart the background task after a delay"""
await asyncio.sleep(30) # Wait 30 seconds before restarting
self._bg_task = asyncio.create_task(self._background_processor())
async def _process_metadata(self, metadata):
"""Process metadata from an execution"""
if not metadata or not isinstance(metadata, dict):
return
# Increment total executions count
self.stats["total_executions"] += 1
# Process checkpoints
if MODELS in metadata and isinstance(metadata[MODELS], dict):
await self._process_checkpoints(metadata[MODELS])
# Process loras
if LORAS in metadata and isinstance(metadata[LORAS], dict):
await self._process_loras(metadata[LORAS])
async def _process_checkpoints(self, models_data):
"""Process checkpoint models from metadata"""
try:
# Get checkpoint scanner service
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
if not checkpoint_scanner:
logger.warning("Checkpoint scanner not available for usage tracking")
return
for node_id, model_info in models_data.items():
if not isinstance(model_info, dict):
continue
# Check if this is a checkpoint model
model_type = model_info.get("type")
if model_type == "checkpoint":
model_name = model_info.get("name")
if not model_name:
continue
# Clean up filename (remove extension if present)
model_filename = os.path.splitext(os.path.basename(model_name))[0]
# Get hash for this checkpoint
model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
if model_hash:
# Update stats for this checkpoint
self.stats["checkpoints"][model_hash] = self.stats["checkpoints"].get(model_hash, 0) + 1
except Exception as e:
logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)
async def _process_loras(self, loras_data):
"""Process LoRA models from metadata"""
try:
# Get LoRA scanner service
lora_scanner = await ServiceRegistry.get_lora_scanner()
if not lora_scanner:
logger.warning("LoRA scanner not available for usage tracking")
return
for node_id, lora_info in loras_data.items():
if not isinstance(lora_info, dict):
continue
# Get the list of LoRAs from standardized format
lora_list = lora_info.get("lora_list", [])
for lora in lora_list:
if not isinstance(lora, dict):
continue
lora_name = lora.get("name")
if not lora_name:
continue
# Get hash for this LoRA
lora_hash = lora_scanner.get_hash_by_filename(lora_name)
if lora_hash:
# Update stats for this LoRA
self.stats["loras"][lora_hash] = self.stats["loras"].get(lora_hash, 0) + 1
except Exception as e:
logger.error(f"Error processing LoRA usage: {e}", exc_info=True)
async def get_stats(self):
"""Get current usage statistics"""
return self.stats
async def get_model_usage_count(self, model_type, sha256):
"""Get usage count for a specific model by hash"""
if model_type == "checkpoint":
return self.stats["checkpoints"].get(sha256, 0)
elif model_type == "lora":
return self.stats["loras"].get(sha256, 0)
return 0
async def process_execution(self, prompt_id):
"""Process a prompt execution immediately (synchronous approach)"""
if not prompt_id:
return
try:
# Process metadata for this prompt_id
registry = MetadataRegistry()
metadata = registry.get_metadata(prompt_id)
if metadata:
await self._process_metadata(metadata)
# Save stats if needed
await self.save_stats()
except Exception as e:
logger.error(f"Error processing prompt_id {prompt_id}: {e}", exc_info=True)

View File

@@ -1,7 +1,7 @@
[project]
name = "comfyui-lora-manager"
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration."
version = "0.8.7"
version = "0.8.8"
license = {file = "LICENSE"}
dependencies = [
"aiohttp",

View File

@@ -1,4 +1,4 @@
import { showToast } from '../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
import { state } from '../state/index.js';
import { showCheckpointModal } from './checkpointModal/index.js';
import { NSFW_LEVELS } from '../utils/constants.js';
@@ -204,21 +204,7 @@ export function createCheckpointCard(checkpoint) {
const checkpointName = card.dataset.file_name;
try {
// Modern clipboard API
if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(checkpointName);
} else {
// Fallback for older browsers
const textarea = document.createElement('textarea');
textarea.value = checkpointName;
textarea.style.position = 'absolute';
textarea.style.left = '-99999px';
document.body.appendChild(textarea);
textarea.select();
document.execCommand('copy');
document.body.removeChild(textarea);
}
showToast('Checkpoint name copied', 'success');
await copyToClipboard(checkpointName, 'Checkpoint name copied');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');

View File

@@ -1,4 +1,4 @@
import { showToast, openCivitai } from '../utils/uiHelpers.js';
import { showToast, openCivitai, copyToClipboard } from '../utils/uiHelpers.js';
import { state } from '../state/index.js';
import { showLoraModal } from './loraModal/index.js';
import { bulkManager } from '../managers/BulkManager.js';
@@ -205,26 +205,7 @@ export function createLoraCard(lora) {
const strength = usageTips.strength || 1;
const loraSyntax = `<lora:${card.dataset.file_name}:${strength}>`;
try {
// Modern clipboard API
if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(loraSyntax);
} else {
// Fallback for older browsers
const textarea = document.createElement('textarea');
textarea.value = loraSyntax;
textarea.style.position = 'absolute';
textarea.style.left = '-99999px';
document.body.appendChild(textarea);
textarea.select();
document.execCommand('copy');
document.body.removeChild(textarea);
}
showToast('LoRA syntax copied', 'success');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');
}
await copyToClipboard(loraSyntax, 'LoRA syntax copied');
});
// Civitai button click event

View File

@@ -1,5 +1,5 @@
// Recipe Card Component
import { showToast } from '../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
import { modalManager } from '../managers/ModalManager.js';
class RecipeCard {
@@ -109,14 +109,11 @@ class RecipeCard {
.then(response => response.json())
.then(data => {
if (data.success && data.syntax) {
return navigator.clipboard.writeText(data.syntax);
return copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
} else {
throw new Error(data.error || 'No syntax returned');
}
})
.then(() => {
showToast('Recipe syntax copied to clipboard', 'success');
})
.catch(err => {
console.error('Failed to copy: ', err);
showToast('Failed to copy recipe syntax', 'error');
@@ -279,4 +276,4 @@ class RecipeCard {
}
}
export { RecipeCard };
export { RecipeCard };

View File

@@ -1,5 +1,5 @@
// Recipe Modal Component
import { showToast } from '../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
import { state } from '../state/index.js';
import { setSessionItem, removeSessionItem } from '../utils/storageHelpers.js';
@@ -747,9 +747,8 @@ class RecipeModal {
const data = await response.json();
if (data.success && data.syntax) {
// Copy to clipboard
await navigator.clipboard.writeText(data.syntax);
showToast('Recipe syntax copied to clipboard', 'success');
// Use the centralized copyToClipboard utility function
await copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
} else {
throw new Error(data.error || 'No syntax returned from server');
}
@@ -761,12 +760,7 @@ class RecipeModal {
// Helper method to copy text to clipboard
copyToClipboard(text, successMessage) {
navigator.clipboard.writeText(text).then(() => {
showToast(successMessage, 'success');
}).catch(err => {
console.error('Failed to copy text: ', err);
showToast('Failed to copy text', 'error');
});
copyToClipboard(text, successMessage);
}
// Add new method to handle downloading missing LoRAs

View File

@@ -2,7 +2,7 @@
* ShowcaseView.js
* Handles showcase content (images, videos) display for checkpoint modal
*/
import { showToast } from '../../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
import { state } from '../../state/index.js';
import { NSFW_LEVELS } from '../../utils/constants.js';
@@ -307,8 +307,7 @@ function initMetadataPanelHandlers(container) {
if (!promptElement) return;
try {
await navigator.clipboard.writeText(promptElement.textContent);
showToast('Prompt copied to clipboard', 'success');
await copyToClipboard(promptElement.textContent, 'Prompt copied to clipboard');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');

View File

@@ -1,7 +1,7 @@
/**
* RecipeTab - Handles the recipes tab in the Lora Modal
*/
import { showToast } from '../../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
/**
@@ -172,14 +172,11 @@ function copyRecipeSyntax(recipeId) {
.then(response => response.json())
.then(data => {
if (data.success && data.syntax) {
return navigator.clipboard.writeText(data.syntax);
return copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
} else {
throw new Error(data.error || 'No syntax returned');
}
})
.then(() => {
showToast('Recipe syntax copied to clipboard', 'success');
})
.catch(err => {
console.error('Failed to copy: ', err);
showToast('Failed to copy recipe syntax', 'error');

View File

@@ -2,7 +2,7 @@
* ShowcaseView.js
* 处理LoRA模型展示内容图片、视频的功能模块
*/
import { showToast } from '../../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
import { state } from '../../state/index.js';
import { NSFW_LEVELS } from '../../utils/constants.js';
@@ -311,8 +311,7 @@ function initMetadataPanelHandlers(container) {
if (!promptElement) return;
try {
await navigator.clipboard.writeText(promptElement.textContent);
showToast('Prompt copied to clipboard', 'success');
await copyToClipboard(promptElement.textContent, 'Prompt copied to clipboard');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');

View File

@@ -2,7 +2,7 @@
* TriggerWords.js
* 处理LoRA模型触发词相关的功能模块
*/
import { showToast } from '../../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
import { saveModelMetadata } from './ModelMetadata.js';
/**
@@ -336,23 +336,7 @@ async function saveTriggerWords() {
*/
window.copyTriggerWord = async function(word) {
try {
// Modern clipboard API - with fallback for non-secure contexts
if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(word);
} else {
// Fallback for older browsers or non-secure contexts
const textarea = document.createElement('textarea');
textarea.value = word;
textarea.style.position = 'absolute';
textarea.style.left = '-99999px';
document.body.appendChild(textarea);
textarea.select();
const success = document.execCommand('copy');
document.body.removeChild(textarea);
if (!success) throw new Error('Copy command failed');
}
showToast('Trigger word copied', 'success');
await copyToClipboard(word, 'Trigger word copied');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');

View File

@@ -3,8 +3,7 @@
*
* 将原始的LoraModal.js拆分成多个功能模块后的主入口文件
*/
import { showToast } from '../../utils/uiHelpers.js';
import { state } from '../../state/index.js';
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
import { modalManager } from '../../managers/ModalManager.js';
import { renderShowcaseContent, toggleShowcase, setupShowcaseScroll, scrollToTop } from './ShowcaseView.js';
import { setupTabSwitching, loadModelDescription } from './ModelDescription.js';
@@ -174,8 +173,7 @@ export function showLoraModal(lora) {
// Copy file name function
window.copyFileName = async function(fileName) {
try {
await navigator.clipboard.writeText(fileName);
showToast('File name copied', 'success');
await copyToClipboard(fileName, 'File name copied');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');

View File

@@ -1,5 +1,5 @@
import { state } from '../state/index.js';
import { showToast } from '../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
import { updateCardsForBulkMode } from '../components/LoraCard.js';
export class BulkManager {
@@ -205,13 +205,7 @@ export class BulkManager {
return;
}
try {
await navigator.clipboard.writeText(loraSyntaxes.join(', '));
showToast(`Copied ${loraSyntaxes.length} LoRA syntaxes to clipboard`, 'success');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');
}
await copyToClipboard(loraSyntaxes.join(', '), `Copied ${loraSyntaxes.length} LoRA syntaxes to clipboard`);
}
// Create and show the thumbnail strip of selected LoRAs

View File

@@ -2,6 +2,40 @@ import { state } from '../state/index.js';
import { resetAndReload } from '../api/loraApi.js';
import { getStorageItem, setStorageItem } from './storageHelpers.js';
/**
* Utility function to copy text to clipboard with fallback for older browsers
* @param {string} text - The text to copy to clipboard
* @param {string} successMessage - Optional success message to show in toast
* @returns {Promise<boolean>} - Promise that resolves to true if copy was successful
*/
export async function copyToClipboard(text, successMessage = 'Copied to clipboard') {
try {
// Modern clipboard API
if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(text);
} else {
// Fallback for older browsers
const textarea = document.createElement('textarea');
textarea.value = text;
textarea.style.position = 'absolute';
textarea.style.left = '-99999px';
document.body.appendChild(textarea);
textarea.select();
document.execCommand('copy');
document.body.removeChild(textarea);
}
if (successMessage) {
showToast(successMessage, 'success');
}
return true;
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');
return false;
}
}
export function showToast(message, type = 'info') {
const toast = document.createElement('div');
toast.className = `toast toast-${type}`;
@@ -108,12 +142,6 @@ export function toggleFolder(tag) {
resetAndReload();
}
export function copyTriggerWord(word) {
navigator.clipboard.writeText(word).then(() => {
showToast('Trigger word copied', 'success');
});
}
function filterByFolder(folderPath) {
document.querySelectorAll('.lora-card').forEach(card => {
card.style.display = card.dataset.folder === folderPath ? '' : 'none';

View File

@@ -927,10 +927,6 @@ export function addLorasWidget(node, name, opts, callback) {
// Function to directly save the recipe without dialog
async function saveRecipeDirectly(widget) {
try {
// Get the workflow data from the ComfyUI app
const prompt = await app.graphToPrompt();
console.log('Prompt:', prompt);
// Show loading toast
if (app && app.extensionManager && app.extensionManager.toast) {
app.extensionManager.toast.add({
@@ -941,14 +937,9 @@ async function saveRecipeDirectly(widget) {
});
}
// Prepare the data - only send workflow JSON
const formData = new FormData();
formData.append('workflow_json', JSON.stringify(prompt.output));
// Send the request
const response = await fetch('/api/recipes/save-from-widget', {
method: 'POST',
body: formData
method: 'POST'
});
const result = await response.json();

View File

@@ -9,6 +9,54 @@ async function getLorasWidgetModule() {
return await dynamicImportByVersion("./loras_widget.js", "./legacy_loras_widget.js");
}
// Function to get connected trigger toggle nodes
function getConnectedTriggerToggleNodes(node) {
const connectedNodes = [];
// Check if node has outputs
if (node.outputs && node.outputs.length > 0) {
// For each output slot
for (const output of node.outputs) {
// Check if this output has any links
if (output.links && output.links.length > 0) {
// For each link, get the target node
for (const linkId of output.links) {
const link = app.graph.links[linkId];
if (link) {
const targetNode = app.graph.getNodeById(link.target_id);
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
connectedNodes.push(targetNode.id);
}
}
}
}
}
}
return connectedNodes;
}
// Function to update trigger words for connected toggle nodes
function updateConnectedTriggerWords(node, text) {
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
if (connectedNodeIds.length > 0) {
const loraNames = new Set();
let match;
LORA_PATTERN.lastIndex = 0;
while ((match = LORA_PATTERN.exec(text)) !== null) {
loraNames.add(match[1]);
}
fetch("/loramanager/get_trigger_words", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
lora_names: Array.from(loraNames),
node_ids: connectedNodeIds
})
}).catch(err => console.error("Error fetching trigger words:", err));
}
}
function mergeLoras(lorasText, lorasArr) {
const result = [];
let match;
@@ -99,6 +147,9 @@ app.registerExtension({
newText = newText.replace(/\s+/g, ' ').trim();
inputWidget.value = newText;
// Add this line to update trigger words when lorasWidget changes cause inputWidget value to change
updateConnectedTriggerWords(node, newText);
} finally {
isUpdating = false;
}
@@ -117,6 +168,9 @@ app.registerExtension({
const mergedLoras = mergeLoras(value, currentLoras);
node.lorasWidget.value = mergedLoras;
// Replace the existing trigger word update code with the new function
updateConnectedTriggerWords(node, value);
} finally {
isUpdating = false;
}

View File

@@ -1,9 +1,58 @@
import { app } from "../../scripts/app.js";
import { addLorasWidget } from "./loras_widget.js";
import { dynamicImportByVersion } from "./utils.js";
// Extract pattern into a constant for consistent use
const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)>/g;
// Function to get the appropriate loras widget based on ComfyUI version
async function getLorasWidgetModule() {
return await dynamicImportByVersion("./loras_widget.js", "./legacy_loras_widget.js");
}
// Function to get connected trigger toggle nodes
function getConnectedTriggerToggleNodes(node) {
const connectedNodes = [];
if (node.outputs && node.outputs.length > 0) {
for (const output of node.outputs) {
if (output.links && output.links.length > 0) {
for (const linkId of output.links) {
const link = app.graph.links[linkId];
if (link) {
const targetNode = app.graph.getNodeById(link.target_id);
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
connectedNodes.push(targetNode.id);
}
}
}
}
}
}
return connectedNodes;
}
// Function to update trigger words for connected toggle nodes
function updateConnectedTriggerWords(node, text) {
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
if (connectedNodeIds.length > 0) {
const loraNames = new Set();
let match;
LORA_PATTERN.lastIndex = 0;
while ((match = LORA_PATTERN.exec(text)) !== null) {
loraNames.add(match[1]);
}
fetch("/loramanager/get_trigger_words", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
lora_names: Array.from(loraNames),
node_ids: connectedNodeIds
})
}).catch(err => console.error("Error fetching trigger words:", err));
}
}
function mergeLoras(lorasText, lorasArr) {
const result = [];
let match;
@@ -40,7 +89,7 @@ app.registerExtension({
});
// Wait for node to be properly initialized
requestAnimationFrame(() => {
requestAnimationFrame(async () => {
// Restore saved value if exists
let existingLoras = [];
if (node.widgets_values && node.widgets_values.length > 0) {
@@ -64,7 +113,10 @@ app.registerExtension({
// Add flag to prevent callback loops
let isUpdating = false;
// Get the widget object directly from the returned object
// Dynamically load the appropriate widget module
const lorasModule = await getLorasWidgetModule();
const { addLorasWidget } = lorasModule;
const result = addLorasWidget(node, "loras", {
defaultVal: mergedLoras // Pass object directly
}, (value) => {
@@ -86,6 +138,9 @@ app.registerExtension({
newText = newText.replace(/\s+/g, ' ').trim();
inputWidget.value = newText;
// Update trigger words when lorasWidget changes
updateConnectedTriggerWords(node, newText);
} finally {
isUpdating = false;
}
@@ -104,6 +159,9 @@ app.registerExtension({
const mergedLoras = mergeLoras(value, currentLoras);
node.lorasWidget.value = mergedLoras;
// Update trigger words when input changes
updateConnectedTriggerWords(node, value);
} finally {
isUpdating = false;
}

View File

@@ -0,0 +1,36 @@
// ComfyUI extension to track model usage statistics
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
// Register the extension
app.registerExtension({
name: "ComfyUI-Lora-Manager.UsageStats",
init() {
// Listen for successful executions
api.addEventListener("execution_success", ({ detail }) => {
if (detail && detail.prompt_id) {
this.updateUsageStats(detail.prompt_id);
}
});
},
async updateUsageStats(promptId) {
try {
// Call backend endpoint with the prompt_id
const response = await fetch(`/loras/api/update-usage-stats`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ prompt_id: promptId }),
});
if (!response.ok) {
console.warn("Failed to update usage statistics:", response.statusText);
}
} catch (error) {
console.error("Error updating usage statistics:", error);
}
}
});

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long