mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-23 22:22:11 -03:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dc4c11ddd2 | ||
|
|
d389e4d5d4 | ||
|
|
8cb78ad931 | ||
|
|
85f987d15c | ||
|
|
b12079e0f6 | ||
|
|
dcf5c6167a | ||
|
|
b395d3f487 | ||
|
|
37662cad10 | ||
|
|
aa1673063d | ||
|
|
f51f49eb60 | ||
|
|
54c9bac961 | ||
|
|
e70fd73bdd | ||
|
|
9bb9e7b64d | ||
|
|
f64c03543a | ||
|
|
51374de1a1 | ||
|
|
afcc12f263 | ||
|
|
88c5482366 | ||
|
|
bbf7295c32 | ||
|
|
ca5e23e68c | ||
|
|
eadb1487ae | ||
|
|
1faa70fc77 | ||
|
|
30d7c007de | ||
|
|
f54f6a4402 | ||
|
|
7b41cdec65 | ||
|
|
fb6a652a57 | ||
|
|
ea34d753c1 |
@@ -20,6 +20,12 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
|
|||||||
|
|
||||||
## Release Notes
|
## Release Notes
|
||||||
|
|
||||||
|
### v0.8.8
|
||||||
|
* **Real-time TriggerWord Updates** - Enhanced TriggerWord Toggle node to instantly update when connected Lora Loader or Lora Stacker nodes change, without requiring workflow execution
|
||||||
|
* **Optimized Metadata Recovery** - Improved utilization of existing .civitai.info files for faster initialization and preservation of metadata from models deleted from CivitAI
|
||||||
|
* **Migration Acceleration** - Further speed improvements for users transitioning from A1111/Forge environments
|
||||||
|
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
|
||||||
|
|
||||||
### v0.8.7
|
### v0.8.7
|
||||||
* **Enhanced Context Menu** - Added comprehensive context menu functionality to Recipes and Checkpoints pages for improved workflow
|
* **Enhanced Context Menu** - Added comprehensive context menu functionality to Recipes and Checkpoints pages for improved workflow
|
||||||
* **Interactive LoRA Strength Control** - Implemented drag functionality in LoRA Loader for intuitive strength adjustment
|
* **Interactive LoRA Strength Control** - Implemented drag functionality in LoRA Loader for intuitive strength adjustment
|
||||||
|
|||||||
30
py/config.py
30
py/config.py
@@ -103,21 +103,29 @@ class Config:
|
|||||||
|
|
||||||
def _init_lora_paths(self) -> List[str]:
|
def _init_lora_paths(self) -> List[str]:
|
||||||
"""Initialize and validate LoRA paths from ComfyUI settings"""
|
"""Initialize and validate LoRA paths from ComfyUI settings"""
|
||||||
paths = sorted(set(path.replace(os.sep, "/")
|
raw_paths = folder_paths.get_folder_paths("loras")
|
||||||
for path in folder_paths.get_folder_paths("loras")
|
|
||||||
if os.path.exists(path)), key=lambda p: p.lower())
|
|
||||||
print("Found LoRA roots:", "\n - " + "\n - ".join(paths))
|
|
||||||
|
|
||||||
if not paths:
|
# Normalize and resolve symlinks, store mapping from resolved -> original
|
||||||
|
path_map = {}
|
||||||
|
for path in raw_paths:
|
||||||
|
if os.path.exists(path):
|
||||||
|
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||||
|
path_map[real_path] = path_map.get(real_path, path) # preserve first seen
|
||||||
|
|
||||||
|
# Now sort and use only the deduplicated real paths
|
||||||
|
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
|
||||||
|
print("Found LoRA roots:", "\n - " + "\n - ".join(unique_paths))
|
||||||
|
|
||||||
|
if not unique_paths:
|
||||||
raise ValueError("No valid loras folders found in ComfyUI configuration")
|
raise ValueError("No valid loras folders found in ComfyUI configuration")
|
||||||
|
|
||||||
# 初始化路径映射
|
for original_path in unique_paths:
|
||||||
for path in paths:
|
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
if real_path != original_path:
|
||||||
if real_path != path:
|
self.add_path_mapping(original_path, real_path)
|
||||||
self.add_path_mapping(path, real_path)
|
|
||||||
|
return unique_paths
|
||||||
|
|
||||||
return paths
|
|
||||||
|
|
||||||
def _init_checkpoint_paths(self) -> List[str]:
|
def _init_checkpoint_paths(self) -> List[str]:
|
||||||
"""Initialize and validate checkpoint paths from ComfyUI settings"""
|
"""Initialize and validate checkpoint paths from ComfyUI settings"""
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ from .routes.lora_routes import LoraRoutes
|
|||||||
from .routes.api_routes import ApiRoutes
|
from .routes.api_routes import ApiRoutes
|
||||||
from .routes.recipe_routes import RecipeRoutes
|
from .routes.recipe_routes import RecipeRoutes
|
||||||
from .routes.checkpoints_routes import CheckpointsRoutes
|
from .routes.checkpoints_routes import CheckpointsRoutes
|
||||||
|
from .routes.update_routes import UpdateRoutes
|
||||||
|
from .routes.usage_stats_routes import UsageStatsRoutes
|
||||||
from .services.service_registry import ServiceRegistry
|
from .services.service_registry import ServiceRegistry
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@@ -92,6 +94,8 @@ class LoraManager:
|
|||||||
checkpoints_routes.setup_routes(app)
|
checkpoints_routes.setup_routes(app)
|
||||||
ApiRoutes.setup_routes(app)
|
ApiRoutes.setup_routes(app)
|
||||||
RecipeRoutes.setup_routes(app)
|
RecipeRoutes.setup_routes(app)
|
||||||
|
UpdateRoutes.setup_routes(app)
|
||||||
|
UsageStatsRoutes.setup_routes(app) # Register usage stats routes
|
||||||
|
|
||||||
# Schedule service initialization
|
# Schedule service initialization
|
||||||
app.on_startup.append(lambda app: cls._initialize_services())
|
app.on_startup.append(lambda app: cls._initialize_services())
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
"""Constants used by the metadata collector"""
|
"""Constants used by the metadata collector"""
|
||||||
|
|
||||||
# Individual category constants
|
# Metadata collection constants
|
||||||
|
|
||||||
|
# Metadata categories
|
||||||
MODELS = "models"
|
MODELS = "models"
|
||||||
PROMPTS = "prompts"
|
PROMPTS = "prompts"
|
||||||
SAMPLING = "sampling"
|
SAMPLING = "sampling"
|
||||||
LORAS = "loras"
|
LORAS = "loras"
|
||||||
SIZE = "size"
|
SIZE = "size"
|
||||||
IMAGES = "images" # Added new category for image results
|
IMAGES = "images"
|
||||||
|
|
||||||
# Collection of categories for iteration
|
# Complete list of categories to track
|
||||||
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES] # Added IMAGES to categories
|
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from ..services.lora_scanner import LoraScanner
|
|||||||
from ..config import config
|
from ..config import config
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from .utils import FlexibleOptionalInputType, any_type
|
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -33,48 +33,6 @@ class LoraManagerLoader:
|
|||||||
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
||||||
FUNCTION = "load_loras"
|
FUNCTION = "load_loras"
|
||||||
|
|
||||||
async def get_lora_info(self, lora_name):
|
|
||||||
"""Get the lora path and trigger words from cache"""
|
|
||||||
scanner = await LoraScanner.get_instance()
|
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
|
|
||||||
for item in cache.raw_data:
|
|
||||||
if item.get('file_name') == lora_name:
|
|
||||||
file_path = item.get('file_path')
|
|
||||||
if file_path:
|
|
||||||
for root in config.loras_roots:
|
|
||||||
root = root.replace(os.sep, '/')
|
|
||||||
if file_path.startswith(root):
|
|
||||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
|
||||||
# Get trigger words from civitai metadata
|
|
||||||
civitai = item.get('civitai', {})
|
|
||||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
|
||||||
return relative_path, trigger_words
|
|
||||||
return lora_name, [] # Fallback if not found
|
|
||||||
|
|
||||||
def extract_lora_name(self, lora_path):
|
|
||||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
|
||||||
# Get the basename without extension
|
|
||||||
basename = os.path.basename(lora_path)
|
|
||||||
return os.path.splitext(basename)[0]
|
|
||||||
|
|
||||||
def _get_loras_list(self, kwargs):
|
|
||||||
"""Helper to extract loras list from either old or new kwargs format"""
|
|
||||||
if 'loras' not in kwargs:
|
|
||||||
return []
|
|
||||||
|
|
||||||
loras_data = kwargs['loras']
|
|
||||||
# Handle new format: {'loras': {'__value__': [...]}}
|
|
||||||
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
|
||||||
return loras_data['__value__']
|
|
||||||
# Handle old format: {'loras': [...]}
|
|
||||||
elif isinstance(loras_data, list):
|
|
||||||
return loras_data
|
|
||||||
# Unexpected format
|
|
||||||
else:
|
|
||||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def load_loras(self, model, text, **kwargs):
|
def load_loras(self, model, text, **kwargs):
|
||||||
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
|
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
|
||||||
loaded_loras = []
|
loaded_loras = []
|
||||||
@@ -89,14 +47,14 @@ class LoraManagerLoader:
|
|||||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||||
|
|
||||||
# Extract lora name for trigger words lookup
|
# Extract lora name for trigger words lookup
|
||||||
lora_name = self.extract_lora_name(lora_path)
|
lora_name = extract_lora_name(lora_path)
|
||||||
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||||
|
|
||||||
all_trigger_words.extend(trigger_words)
|
all_trigger_words.extend(trigger_words)
|
||||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||||
|
|
||||||
# Then process loras from kwargs with support for both old and new formats
|
# Then process loras from kwargs with support for both old and new formats
|
||||||
loras_list = self._get_loras_list(kwargs)
|
loras_list = get_loras_list(kwargs)
|
||||||
for lora in loras_list:
|
for lora in loras_list:
|
||||||
if not lora.get('active', False):
|
if not lora.get('active', False):
|
||||||
continue
|
continue
|
||||||
@@ -105,7 +63,7 @@ class LoraManagerLoader:
|
|||||||
strength = float(lora['strength'])
|
strength = float(lora['strength'])
|
||||||
|
|
||||||
# Get lora path and trigger words
|
# Get lora path and trigger words
|
||||||
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||||
|
|
||||||
# Apply the LoRA using the resolved path
|
# Apply the LoRA using the resolved path
|
||||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength)
|
model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from ..services.lora_scanner import LoraScanner
|
|||||||
from ..config import config
|
from ..config import config
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from .utils import FlexibleOptionalInputType, any_type
|
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -30,48 +30,6 @@ class LoraStacker:
|
|||||||
RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras")
|
RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras")
|
||||||
FUNCTION = "stack_loras"
|
FUNCTION = "stack_loras"
|
||||||
|
|
||||||
async def get_lora_info(self, lora_name):
|
|
||||||
"""Get the lora path and trigger words from cache"""
|
|
||||||
scanner = await LoraScanner.get_instance()
|
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
|
|
||||||
for item in cache.raw_data:
|
|
||||||
if item.get('file_name') == lora_name:
|
|
||||||
file_path = item.get('file_path')
|
|
||||||
if file_path:
|
|
||||||
for root in config.loras_roots:
|
|
||||||
root = root.replace(os.sep, '/')
|
|
||||||
if file_path.startswith(root):
|
|
||||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
|
||||||
# Get trigger words from civitai metadata
|
|
||||||
civitai = item.get('civitai', {})
|
|
||||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
|
||||||
return relative_path, trigger_words
|
|
||||||
return lora_name, [] # Fallback if not found
|
|
||||||
|
|
||||||
def extract_lora_name(self, lora_path):
|
|
||||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
|
||||||
# Get the basename without extension
|
|
||||||
basename = os.path.basename(lora_path)
|
|
||||||
return os.path.splitext(basename)[0]
|
|
||||||
|
|
||||||
def _get_loras_list(self, kwargs):
|
|
||||||
"""Helper to extract loras list from either old or new kwargs format"""
|
|
||||||
if 'loras' not in kwargs:
|
|
||||||
return []
|
|
||||||
|
|
||||||
loras_data = kwargs['loras']
|
|
||||||
# Handle new format: {'loras': {'__value__': [...]}}
|
|
||||||
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
|
||||||
return loras_data['__value__']
|
|
||||||
# Handle old format: {'loras': [...]}
|
|
||||||
elif isinstance(loras_data, list):
|
|
||||||
return loras_data
|
|
||||||
# Unexpected format
|
|
||||||
else:
|
|
||||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def stack_loras(self, text, **kwargs):
|
def stack_loras(self, text, **kwargs):
|
||||||
"""Stacks multiple LoRAs based on the kwargs input without loading them."""
|
"""Stacks multiple LoRAs based on the kwargs input without loading them."""
|
||||||
stack = []
|
stack = []
|
||||||
@@ -84,12 +42,12 @@ class LoraStacker:
|
|||||||
stack.extend(lora_stack)
|
stack.extend(lora_stack)
|
||||||
# Get trigger words from existing stack entries
|
# Get trigger words from existing stack entries
|
||||||
for lora_path, _, _ in lora_stack:
|
for lora_path, _, _ in lora_stack:
|
||||||
lora_name = self.extract_lora_name(lora_path)
|
lora_name = extract_lora_name(lora_path)
|
||||||
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||||
all_trigger_words.extend(trigger_words)
|
all_trigger_words.extend(trigger_words)
|
||||||
|
|
||||||
# Process loras from kwargs with support for both old and new formats
|
# Process loras from kwargs with support for both old and new formats
|
||||||
loras_list = self._get_loras_list(kwargs)
|
loras_list = get_loras_list(kwargs)
|
||||||
for lora in loras_list:
|
for lora in loras_list:
|
||||||
if not lora.get('active', False):
|
if not lora.get('active', False):
|
||||||
continue
|
continue
|
||||||
@@ -99,7 +57,7 @@ class LoraStacker:
|
|||||||
clip_strength = model_strength # Using same strength for both as in the original loader
|
clip_strength = model_strength # Using same strength for both as in the original loader
|
||||||
|
|
||||||
# Get lora path and trigger words
|
# Get lora path and trigger words
|
||||||
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
|
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||||
|
|
||||||
# Add to stack without loading
|
# Add to stack without loading
|
||||||
# replace '/' with os.sep to avoid different OS path format
|
# replace '/' with os.sep to avoid different OS path format
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import re
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import folder_paths # type: ignore
|
import folder_paths # type: ignore
|
||||||
from ..services.lora_scanner import LoraScanner
|
from ..services.lora_scanner import LoraScanner
|
||||||
|
from ..services.checkpoint_scanner import CheckpointScanner
|
||||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||||
from ..metadata_collector import get_metadata
|
from ..metadata_collector import get_metadata
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
@@ -53,18 +54,55 @@ class SaveImage:
|
|||||||
async def get_lora_hash(self, lora_name):
|
async def get_lora_hash(self, lora_name):
|
||||||
"""Get the lora hash from cache"""
|
"""Get the lora hash from cache"""
|
||||||
scanner = await LoraScanner.get_instance()
|
scanner = await LoraScanner.get_instance()
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
|
|
||||||
|
# Use the new direct filename lookup method
|
||||||
|
hash_value = scanner.get_hash_by_filename(lora_name)
|
||||||
|
if hash_value:
|
||||||
|
return hash_value
|
||||||
|
|
||||||
|
# Fallback to old method for compatibility
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
for item in cache.raw_data:
|
for item in cache.raw_data:
|
||||||
if item.get('file_name') == lora_name:
|
if item.get('file_name') == lora_name:
|
||||||
return item.get('sha256')
|
return item.get('sha256')
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def get_checkpoint_hash(self, checkpoint_path):
|
||||||
|
"""Get the checkpoint hash from cache"""
|
||||||
|
scanner = await CheckpointScanner.get_instance()
|
||||||
|
|
||||||
|
if not checkpoint_path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Extract basename without extension
|
||||||
|
checkpoint_name = os.path.basename(checkpoint_path)
|
||||||
|
checkpoint_name = os.path.splitext(checkpoint_name)[0]
|
||||||
|
|
||||||
|
# Try direct filename lookup first
|
||||||
|
hash_value = scanner.get_hash_by_filename(checkpoint_name)
|
||||||
|
if hash_value:
|
||||||
|
return hash_value
|
||||||
|
|
||||||
|
# Fallback to old method for compatibility
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
normalized_path = checkpoint_path.replace('\\', '/')
|
||||||
|
|
||||||
|
for item in cache.raw_data:
|
||||||
|
if item.get('file_name') == checkpoint_name and item.get('file_path').endswith(normalized_path):
|
||||||
|
return item.get('sha256')
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
async def format_metadata(self, metadata_dict):
|
async def format_metadata(self, metadata_dict):
|
||||||
"""Format metadata in the requested format similar to userComment example"""
|
"""Format metadata in the requested format similar to userComment example"""
|
||||||
if not metadata_dict:
|
if not metadata_dict:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
# Helper function to only add parameter if value is not None
|
||||||
|
def add_param_if_not_none(param_list, label, value):
|
||||||
|
if value is not None:
|
||||||
|
param_list.append(f"{label}: {value}")
|
||||||
|
|
||||||
# Extract the prompt and negative prompt
|
# Extract the prompt and negative prompt
|
||||||
prompt = metadata_dict.get('prompt', '')
|
prompt = metadata_dict.get('prompt', '')
|
||||||
negative_prompt = metadata_dict.get('negative_prompt', '')
|
negative_prompt = metadata_dict.get('negative_prompt', '')
|
||||||
@@ -100,7 +138,11 @@ class SaveImage:
|
|||||||
|
|
||||||
# Add standard parameters in the correct order
|
# Add standard parameters in the correct order
|
||||||
if 'steps' in metadata_dict:
|
if 'steps' in metadata_dict:
|
||||||
params.append(f"Steps: {metadata_dict.get('steps')}")
|
add_param_if_not_none(params, "Steps", metadata_dict.get('steps'))
|
||||||
|
|
||||||
|
# Combine sampler and scheduler information
|
||||||
|
sampler_name = None
|
||||||
|
scheduler_name = None
|
||||||
|
|
||||||
if 'sampler' in metadata_dict:
|
if 'sampler' in metadata_dict:
|
||||||
sampler = metadata_dict.get('sampler')
|
sampler = metadata_dict.get('sampler')
|
||||||
@@ -123,7 +165,6 @@ class SaveImage:
|
|||||||
'ddim': 'DDIM'
|
'ddim': 'DDIM'
|
||||||
}
|
}
|
||||||
sampler_name = sampler_mapping.get(sampler, sampler)
|
sampler_name = sampler_mapping.get(sampler, sampler)
|
||||||
params.append(f"Sampler: {sampler_name}")
|
|
||||||
|
|
||||||
if 'scheduler' in metadata_dict:
|
if 'scheduler' in metadata_dict:
|
||||||
scheduler = metadata_dict.get('scheduler')
|
scheduler = metadata_dict.get('scheduler')
|
||||||
@@ -135,38 +176,48 @@ class SaveImage:
|
|||||||
'sgm_quadratic': 'SGM Quadratic'
|
'sgm_quadratic': 'SGM Quadratic'
|
||||||
}
|
}
|
||||||
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
|
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
|
||||||
params.append(f"Schedule type: {scheduler_name}")
|
|
||||||
|
|
||||||
# CFG scale (cfg_scale in metadata_dict)
|
# Add combined sampler and scheduler information
|
||||||
if 'cfg_scale' in metadata_dict:
|
if sampler_name:
|
||||||
params.append(f"CFG scale: {metadata_dict.get('cfg_scale')}")
|
if scheduler_name:
|
||||||
|
params.append(f"Sampler: {sampler_name} {scheduler_name}")
|
||||||
|
else:
|
||||||
|
params.append(f"Sampler: {sampler_name}")
|
||||||
|
|
||||||
|
# CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg)
|
||||||
|
if 'guidance' in metadata_dict:
|
||||||
|
add_param_if_not_none(params, "CFG scale", metadata_dict.get('guidance'))
|
||||||
|
elif 'cfg_scale' in metadata_dict:
|
||||||
|
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg_scale'))
|
||||||
elif 'cfg' in metadata_dict:
|
elif 'cfg' in metadata_dict:
|
||||||
params.append(f"CFG scale: {metadata_dict.get('cfg')}")
|
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg'))
|
||||||
|
|
||||||
# Seed
|
# Seed
|
||||||
if 'seed' in metadata_dict:
|
if 'seed' in metadata_dict:
|
||||||
params.append(f"Seed: {metadata_dict.get('seed')}")
|
add_param_if_not_none(params, "Seed", metadata_dict.get('seed'))
|
||||||
|
|
||||||
# Size
|
# Size
|
||||||
if 'size' in metadata_dict:
|
if 'size' in metadata_dict:
|
||||||
params.append(f"Size: {metadata_dict.get('size')}")
|
add_param_if_not_none(params, "Size", metadata_dict.get('size'))
|
||||||
|
|
||||||
# Model info
|
# Model info
|
||||||
if 'checkpoint' in metadata_dict:
|
if 'checkpoint' in metadata_dict:
|
||||||
# Ensure checkpoint is a string before processing
|
# Ensure checkpoint is a string before processing
|
||||||
checkpoint = metadata_dict.get('checkpoint')
|
checkpoint = metadata_dict.get('checkpoint')
|
||||||
if checkpoint is not None:
|
if checkpoint is not None:
|
||||||
# Handle both string and other types safely
|
# Get model hash
|
||||||
if isinstance(checkpoint, str):
|
model_hash = await self.get_checkpoint_hash(checkpoint)
|
||||||
# Extract basename without path
|
|
||||||
checkpoint = os.path.basename(checkpoint)
|
|
||||||
# Remove extension if present
|
|
||||||
checkpoint = os.path.splitext(checkpoint)[0]
|
|
||||||
else:
|
|
||||||
# Convert non-string to string
|
|
||||||
checkpoint = str(checkpoint)
|
|
||||||
|
|
||||||
params.append(f"Model: {checkpoint}")
|
# Extract basename without path
|
||||||
|
checkpoint_name = os.path.basename(checkpoint)
|
||||||
|
# Remove extension if present
|
||||||
|
checkpoint_name = os.path.splitext(checkpoint_name)[0]
|
||||||
|
|
||||||
|
# Add model hash if available
|
||||||
|
if model_hash:
|
||||||
|
params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
|
||||||
|
else:
|
||||||
|
params.append(f"Model: {checkpoint_name}")
|
||||||
|
|
||||||
# Add LoRA hashes if available
|
# Add LoRA hashes if available
|
||||||
if lora_hashes:
|
if lora_hashes:
|
||||||
@@ -284,7 +335,7 @@ class SaveImage:
|
|||||||
if add_counter_to_filename:
|
if add_counter_to_filename:
|
||||||
# Use counter + i to ensure unique filenames for all images in batch
|
# Use counter + i to ensure unique filenames for all images in batch
|
||||||
current_counter = counter + i
|
current_counter = counter + i
|
||||||
base_filename += f"_{current_counter:05}"
|
base_filename += f"_{current_counter:05}_"
|
||||||
|
|
||||||
# Set file extension and prepare saving parameters
|
# Set file extension and prepare saving parameters
|
||||||
if file_format == "png":
|
if file_format == "png":
|
||||||
|
|||||||
@@ -47,10 +47,10 @@ class TriggerWordToggle:
|
|||||||
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
|
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
|
||||||
|
|
||||||
# Send trigger words to frontend
|
# Send trigger words to frontend
|
||||||
PromptServer.instance.send_sync("trigger_word_update", {
|
# PromptServer.instance.send_sync("trigger_word_update", {
|
||||||
"id": id,
|
# "id": id,
|
||||||
"message": trigger_words
|
# "message": trigger_words
|
||||||
})
|
# })
|
||||||
|
|
||||||
filtered_triggers = trigger_words
|
filtered_triggers = trigger_words
|
||||||
|
|
||||||
|
|||||||
@@ -31,3 +31,54 @@ class FlexibleOptionalInputType(dict):
|
|||||||
|
|
||||||
|
|
||||||
any_type = AnyType("*")
|
any_type = AnyType("*")
|
||||||
|
|
||||||
|
# Common methods extracted from lora_loader.py and lora_stacker.py
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
from ..services.lora_scanner import LoraScanner
|
||||||
|
from ..config import config
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def get_lora_info(lora_name):
|
||||||
|
"""Get the lora path and trigger words from cache"""
|
||||||
|
scanner = await LoraScanner.get_instance()
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
|
||||||
|
for item in cache.raw_data:
|
||||||
|
if item.get('file_name') == lora_name:
|
||||||
|
file_path = item.get('file_path')
|
||||||
|
if file_path:
|
||||||
|
for root in config.loras_roots:
|
||||||
|
root = root.replace(os.sep, '/')
|
||||||
|
if file_path.startswith(root):
|
||||||
|
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||||
|
# Get trigger words from civitai metadata
|
||||||
|
civitai = item.get('civitai', {})
|
||||||
|
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||||
|
return relative_path, trigger_words
|
||||||
|
return lora_name, [] # Fallback if not found
|
||||||
|
|
||||||
|
def extract_lora_name(lora_path):
|
||||||
|
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||||
|
# Get the basename without extension
|
||||||
|
basename = os.path.basename(lora_path)
|
||||||
|
return os.path.splitext(basename)[0]
|
||||||
|
|
||||||
|
def get_loras_list(kwargs):
|
||||||
|
"""Helper to extract loras list from either old or new kwargs format"""
|
||||||
|
if 'loras' not in kwargs:
|
||||||
|
return []
|
||||||
|
|
||||||
|
loras_data = kwargs['loras']
|
||||||
|
# Handle new format: {'loras': {'__value__': [...]}}
|
||||||
|
if isinstance(loras_data, dict) and '__value__' in loras_data:
|
||||||
|
return loras_data['__value__']
|
||||||
|
# Handle old format: {'loras': [...]}
|
||||||
|
elif isinstance(loras_data, list):
|
||||||
|
return loras_data
|
||||||
|
# Unexpected format
|
||||||
|
else:
|
||||||
|
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||||
|
return []
|
||||||
@@ -3,8 +3,10 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
from server import PromptServer # type: ignore
|
||||||
|
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
from ..utils.routes_common import ModelRouteUtils
|
||||||
|
from ..nodes.utils import get_lora_info
|
||||||
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..services.websocket_manager import ws_manager
|
from ..services.websocket_manager import ws_manager
|
||||||
@@ -65,6 +67,9 @@ class ApiRoutes:
|
|||||||
app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files
|
app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files
|
||||||
app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files
|
app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files
|
||||||
|
|
||||||
|
# Add the new trigger words route
|
||||||
|
app.router.add_post('/loramanager/get_trigger_words', routes.get_trigger_words)
|
||||||
|
|
||||||
# Add update check routes
|
# Add update check routes
|
||||||
UpdateRoutes.setup_routes(app)
|
UpdateRoutes.setup_routes(app)
|
||||||
|
|
||||||
@@ -1022,3 +1027,34 @@ class ApiRoutes:
|
|||||||
'success': False,
|
'success': False,
|
||||||
'error': str(e)
|
'error': str(e)
|
||||||
}, status=500)
|
}, status=500)
|
||||||
|
|
||||||
|
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||||
|
"""Get trigger words for specified LoRA models"""
|
||||||
|
try:
|
||||||
|
json_data = await request.json()
|
||||||
|
lora_names = json_data.get("lora_names", [])
|
||||||
|
node_ids = json_data.get("node_ids", [])
|
||||||
|
|
||||||
|
all_trigger_words = []
|
||||||
|
for lora_name in lora_names:
|
||||||
|
_, trigger_words = await get_lora_info(lora_name)
|
||||||
|
all_trigger_words.extend(trigger_words)
|
||||||
|
|
||||||
|
# Format the trigger words
|
||||||
|
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||||
|
|
||||||
|
# Send update to all connected trigger word toggle nodes
|
||||||
|
for node_id in node_ids:
|
||||||
|
PromptServer.instance.send_sync("trigger_word_update", {
|
||||||
|
"id": node_id,
|
||||||
|
"message": trigger_words_text
|
||||||
|
})
|
||||||
|
|
||||||
|
return web.json_response({"success": True})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting trigger words: {e}")
|
||||||
|
return web.json_response({
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}, status=500)
|
||||||
69
py/routes/usage_stats_routes.py
Normal file
69
py/routes/usage_stats_routes.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
import logging
|
||||||
|
from aiohttp import web
|
||||||
|
from ..utils.usage_stats import UsageStats
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class UsageStatsRoutes:
|
||||||
|
"""Routes for handling usage statistics updates"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def setup_routes(app):
|
||||||
|
"""Register usage stats routes"""
|
||||||
|
app.router.add_post('/loras/api/update-usage-stats', UsageStatsRoutes.update_usage_stats)
|
||||||
|
app.router.add_get('/loras/api/get-usage-stats', UsageStatsRoutes.get_usage_stats)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def update_usage_stats(request):
|
||||||
|
"""
|
||||||
|
Update usage statistics based on a prompt_id
|
||||||
|
|
||||||
|
Expects a JSON body with:
|
||||||
|
{
|
||||||
|
"prompt_id": "string"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Parse the request body
|
||||||
|
data = await request.json()
|
||||||
|
prompt_id = data.get('prompt_id')
|
||||||
|
|
||||||
|
if not prompt_id:
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': 'Missing prompt_id'
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
|
# Call the UsageStats to process this prompt_id synchronously
|
||||||
|
usage_stats = UsageStats()
|
||||||
|
await usage_stats.process_execution(prompt_id)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to update usage stats: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def get_usage_stats(request):
|
||||||
|
"""Get current usage statistics"""
|
||||||
|
try:
|
||||||
|
usage_stats = UsageStats()
|
||||||
|
stats = await usage_stats.get_stats()
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'data': stats
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get usage stats: {e}", exc_info=True)
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': str(e)
|
||||||
|
}, status=500)
|
||||||
26
py/server_routes.py
Normal file
26
py/server_routes.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from aiohttp import web
|
||||||
|
from server import PromptServer
|
||||||
|
from .nodes.utils import get_lora_info
|
||||||
|
|
||||||
|
@PromptServer.instance.routes.post("/loramanager/get_trigger_words")
|
||||||
|
async def get_trigger_words(request):
|
||||||
|
json_data = await request.json()
|
||||||
|
lora_names = json_data.get("lora_names", [])
|
||||||
|
node_ids = json_data.get("node_ids", [])
|
||||||
|
|
||||||
|
all_trigger_words = []
|
||||||
|
for lora_name in lora_names:
|
||||||
|
_, trigger_words = await get_lora_info(lora_name)
|
||||||
|
all_trigger_words.extend(trigger_words)
|
||||||
|
|
||||||
|
# Format the trigger words
|
||||||
|
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||||
|
|
||||||
|
# Send update to all connected trigger word toggle nodes
|
||||||
|
for node_id in node_ids:
|
||||||
|
PromptServer.instance.send_sync("trigger_word_update", {
|
||||||
|
"id": node_id,
|
||||||
|
"message": trigger_words_text
|
||||||
|
})
|
||||||
|
|
||||||
|
return web.json_response({"success": True})
|
||||||
@@ -9,7 +9,7 @@ from typing import List, Dict, Optional, Set
|
|||||||
from ..utils.models import LoraMetadata
|
from ..utils.models import LoraMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from .model_scanner import ModelScanner
|
from .model_scanner import ModelScanner
|
||||||
from .lora_hash_index import LoraHashIndex
|
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
|
||||||
from .settings_manager import settings
|
from .settings_manager import settings
|
||||||
from ..utils.constants import NSFW_LEVELS
|
from ..utils.constants import NSFW_LEVELS
|
||||||
from ..utils.utils import fuzzy_match
|
from ..utils.utils import fuzzy_match
|
||||||
@@ -35,12 +35,12 @@ class LoraScanner(ModelScanner):
|
|||||||
# Define supported file extensions
|
# Define supported file extensions
|
||||||
file_extensions = {'.safetensors'}
|
file_extensions = {'.safetensors'}
|
||||||
|
|
||||||
# Initialize parent class
|
# Initialize parent class with ModelHashIndex
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_type="lora",
|
model_type="lora",
|
||||||
model_class=LoraMetadata,
|
model_class=LoraMetadata,
|
||||||
file_extensions=file_extensions,
|
file_extensions=file_extensions,
|
||||||
hash_index=LoraHashIndex()
|
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||||
)
|
)
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
from typing import Dict, Optional, Set
|
from typing import Dict, Optional, Set
|
||||||
|
import os
|
||||||
|
|
||||||
class ModelHashIndex:
|
class ModelHashIndex:
|
||||||
"""Index for looking up models by hash or path"""
|
"""Index for looking up models by hash or path"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._hash_to_path: Dict[str, str] = {}
|
self._hash_to_path: Dict[str, str] = {}
|
||||||
self._path_to_hash: Dict[str, str] = {}
|
self._filename_to_hash: Dict[str, str] = {} # Changed from path_to_hash to filename_to_hash
|
||||||
|
|
||||||
def add_entry(self, sha256: str, file_path: str) -> None:
|
def add_entry(self, sha256: str, file_path: str) -> None:
|
||||||
"""Add or update hash index entry"""
|
"""Add or update hash index entry"""
|
||||||
@@ -15,37 +16,47 @@ class ModelHashIndex:
|
|||||||
# Ensure hash is lowercase for consistency
|
# Ensure hash is lowercase for consistency
|
||||||
sha256 = sha256.lower()
|
sha256 = sha256.lower()
|
||||||
|
|
||||||
|
# Extract filename without extension
|
||||||
|
filename = self._get_filename_from_path(file_path)
|
||||||
|
|
||||||
# Remove old path mapping if hash exists
|
# Remove old path mapping if hash exists
|
||||||
if sha256 in self._hash_to_path:
|
if sha256 in self._hash_to_path:
|
||||||
old_path = self._hash_to_path[sha256]
|
old_path = self._hash_to_path[sha256]
|
||||||
if old_path in self._path_to_hash:
|
old_filename = self._get_filename_from_path(old_path)
|
||||||
del self._path_to_hash[old_path]
|
if old_filename in self._filename_to_hash:
|
||||||
|
del self._filename_to_hash[old_filename]
|
||||||
|
|
||||||
# Remove old hash mapping if path exists
|
# Remove old hash mapping if filename exists
|
||||||
if file_path in self._path_to_hash:
|
if filename in self._filename_to_hash:
|
||||||
old_hash = self._path_to_hash[file_path]
|
old_hash = self._filename_to_hash[filename]
|
||||||
if old_hash in self._hash_to_path:
|
if old_hash in self._hash_to_path:
|
||||||
del self._hash_to_path[old_hash]
|
del self._hash_to_path[old_hash]
|
||||||
|
|
||||||
# Add new mappings
|
# Add new mappings
|
||||||
self._hash_to_path[sha256] = file_path
|
self._hash_to_path[sha256] = file_path
|
||||||
self._path_to_hash[file_path] = sha256
|
self._filename_to_hash[filename] = sha256
|
||||||
|
|
||||||
|
def _get_filename_from_path(self, file_path: str) -> str:
|
||||||
|
"""Extract filename without extension from path"""
|
||||||
|
return os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
|
||||||
def remove_by_path(self, file_path: str) -> None:
|
def remove_by_path(self, file_path: str) -> None:
|
||||||
"""Remove entry by file path"""
|
"""Remove entry by file path"""
|
||||||
if file_path in self._path_to_hash:
|
filename = self._get_filename_from_path(file_path)
|
||||||
hash_val = self._path_to_hash[file_path]
|
if filename in self._filename_to_hash:
|
||||||
|
hash_val = self._filename_to_hash[filename]
|
||||||
if hash_val in self._hash_to_path:
|
if hash_val in self._hash_to_path:
|
||||||
del self._hash_to_path[hash_val]
|
del self._hash_to_path[hash_val]
|
||||||
del self._path_to_hash[file_path]
|
del self._filename_to_hash[filename]
|
||||||
|
|
||||||
def remove_by_hash(self, sha256: str) -> None:
|
def remove_by_hash(self, sha256: str) -> None:
|
||||||
"""Remove entry by hash"""
|
"""Remove entry by hash"""
|
||||||
sha256 = sha256.lower()
|
sha256 = sha256.lower()
|
||||||
if sha256 in self._hash_to_path:
|
if sha256 in self._hash_to_path:
|
||||||
path = self._hash_to_path[sha256]
|
path = self._hash_to_path[sha256]
|
||||||
if path in self._path_to_hash:
|
filename = self._get_filename_from_path(path)
|
||||||
del self._path_to_hash[path]
|
if filename in self._filename_to_hash:
|
||||||
|
del self._filename_to_hash[filename]
|
||||||
del self._hash_to_path[sha256]
|
del self._hash_to_path[sha256]
|
||||||
|
|
||||||
def has_hash(self, sha256: str) -> bool:
|
def has_hash(self, sha256: str) -> bool:
|
||||||
@@ -58,20 +69,27 @@ class ModelHashIndex:
|
|||||||
|
|
||||||
def get_hash(self, file_path: str) -> Optional[str]:
|
def get_hash(self, file_path: str) -> Optional[str]:
|
||||||
"""Get hash for a file path"""
|
"""Get hash for a file path"""
|
||||||
return self._path_to_hash.get(file_path)
|
filename = self._get_filename_from_path(file_path)
|
||||||
|
return self._filename_to_hash.get(filename)
|
||||||
|
|
||||||
|
def get_hash_by_filename(self, filename: str) -> Optional[str]:
|
||||||
|
"""Get hash for a filename without extension"""
|
||||||
|
# Strip extension if present to make the function more flexible
|
||||||
|
filename = os.path.splitext(filename)[0]
|
||||||
|
return self._filename_to_hash.get(filename)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear all entries"""
|
"""Clear all entries"""
|
||||||
self._hash_to_path.clear()
|
self._hash_to_path.clear()
|
||||||
self._path_to_hash.clear()
|
self._filename_to_hash.clear()
|
||||||
|
|
||||||
def get_all_hashes(self) -> Set[str]:
|
def get_all_hashes(self) -> Set[str]:
|
||||||
"""Get all hashes in the index"""
|
"""Get all hashes in the index"""
|
||||||
return set(self._hash_to_path.keys())
|
return set(self._hash_to_path.keys())
|
||||||
|
|
||||||
def get_all_paths(self) -> Set[str]:
|
def get_all_filenames(self) -> Set[str]:
|
||||||
"""Get all file paths in the index"""
|
"""Get all filenames in the index"""
|
||||||
return set(self._path_to_hash.keys())
|
return set(self._filename_to_hash.keys())
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Get number of entries"""
|
"""Get number of entries"""
|
||||||
|
|||||||
@@ -292,7 +292,7 @@ class ModelScanner:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# If force refresh is requested, initialize the cache directly
|
# If force refresh is requested, initialize the cache directly
|
||||||
if force_refresh:
|
if (force_refresh):
|
||||||
if self._cache is None:
|
if self._cache is None:
|
||||||
# For initial creation, do a full initialization
|
# For initial creation, do a full initialization
|
||||||
await self._initialize_cache()
|
await self._initialize_cache()
|
||||||
@@ -553,9 +553,36 @@ class ModelScanner:
|
|||||||
logger.debug(f"Created metadata from .civitai.info for {file_path}")
|
logger.debug(f"Created metadata from .civitai.info for {file_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
|
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
|
||||||
|
else:
|
||||||
|
# Check if metadata exists but civitai field is empty - try to restore from civitai.info
|
||||||
|
if metadata.civitai is None or metadata.civitai == {}:
|
||||||
|
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
|
||||||
|
if os.path.exists(civitai_info_path):
|
||||||
|
try:
|
||||||
|
with open(civitai_info_path, 'r', encoding='utf-8') as f:
|
||||||
|
version_info = json.load(f)
|
||||||
|
|
||||||
if metadata is None:
|
logger.debug(f"Restoring missing civitai data from .civitai.info for {file_path}")
|
||||||
metadata = await self._get_file_info(file_path)
|
metadata.civitai = version_info
|
||||||
|
|
||||||
|
# Ensure tags are also updated if they're missing
|
||||||
|
if (not metadata.tags or len(metadata.tags) == 0) and 'model' in version_info:
|
||||||
|
if 'tags' in version_info['model']:
|
||||||
|
metadata.tags = version_info['model']['tags']
|
||||||
|
|
||||||
|
# Also restore description if missing
|
||||||
|
if (not metadata.modelDescription or metadata.modelDescription == "") and 'model' in version_info:
|
||||||
|
if 'description' in version_info['model']:
|
||||||
|
metadata.modelDescription = version_info['model']['description']
|
||||||
|
|
||||||
|
# Save the updated metadata
|
||||||
|
await save_metadata(file_path, metadata)
|
||||||
|
logger.debug(f"Updated metadata with civitai info for {file_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error restoring civitai data from .civitai.info for {file_path}: {e}")
|
||||||
|
|
||||||
|
if metadata is None:
|
||||||
|
metadata = await self._get_file_info(file_path)
|
||||||
|
|
||||||
model_data = metadata.to_dict()
|
model_data = metadata.to_dict()
|
||||||
|
|
||||||
@@ -806,6 +833,10 @@ class ModelScanner:
|
|||||||
"""Get hash for a model by its file path"""
|
"""Get hash for a model by its file path"""
|
||||||
return self._hash_index.get_hash(file_path)
|
return self._hash_index.get_hash(file_path)
|
||||||
|
|
||||||
|
def get_hash_by_filename(self, filename: str) -> Optional[str]:
|
||||||
|
"""Get hash for a model by its filename without path"""
|
||||||
|
return self._hash_index.get_hash_by_filename(filename)
|
||||||
|
|
||||||
# TODO: Adjust this method to use metadata instead of finding the file
|
# TODO: Adjust this method to use metadata instead of finding the file
|
||||||
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
|
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
|
||||||
"""Get preview static URL for a model by its hash"""
|
"""Get preview static URL for a model by its hash"""
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ class BaseModelMetadata:
|
|||||||
civitai: Optional[Dict] = None # Civitai API data if available
|
civitai: Optional[Dict] = None # Civitai API data if available
|
||||||
tags: List[str] = None # Model tags
|
tags: List[str] = None # Model tags
|
||||||
modelDescription: str = "" # Full model description
|
modelDescription: str = "" # Full model description
|
||||||
|
civitai_deleted: bool = False # Whether deleted from Civitai
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Initialize empty lists to avoid mutable default parameter issue
|
# Initialize empty lists to avoid mutable default parameter issue
|
||||||
@@ -64,6 +65,15 @@ class LoraMetadata(BaseModelMetadata):
|
|||||||
file_name = file_info['name']
|
file_name = file_info['name']
|
||||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||||
|
|
||||||
|
# Extract tags and description if available
|
||||||
|
tags = []
|
||||||
|
description = ""
|
||||||
|
if 'model' in version_info:
|
||||||
|
if 'tags' in version_info['model']:
|
||||||
|
tags = version_info['model']['tags']
|
||||||
|
if 'description' in version_info['model']:
|
||||||
|
description = version_info['model']['description']
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
file_name=os.path.splitext(file_name)[0],
|
file_name=os.path.splitext(file_name)[0],
|
||||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||||
@@ -75,7 +85,9 @@ class LoraMetadata(BaseModelMetadata):
|
|||||||
preview_url=None, # Will be updated after preview download
|
preview_url=None, # Will be updated after preview download
|
||||||
preview_nsfw_level=0, # Will be updated after preview download
|
preview_nsfw_level=0, # Will be updated after preview download
|
||||||
from_civitai=True,
|
from_civitai=True,
|
||||||
civitai=version_info
|
civitai=version_info,
|
||||||
|
tags=tags,
|
||||||
|
modelDescription=description
|
||||||
)
|
)
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -90,6 +102,15 @@ class CheckpointMetadata(BaseModelMetadata):
|
|||||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||||
model_type = version_info.get('type', 'checkpoint')
|
model_type = version_info.get('type', 'checkpoint')
|
||||||
|
|
||||||
|
# Extract tags and description if available
|
||||||
|
tags = []
|
||||||
|
description = ""
|
||||||
|
if 'model' in version_info:
|
||||||
|
if 'tags' in version_info['model']:
|
||||||
|
tags = version_info['model']['tags']
|
||||||
|
if 'description' in version_info['model']:
|
||||||
|
description = version_info['model']['description']
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
file_name=os.path.splitext(file_name)[0],
|
file_name=os.path.splitext(file_name)[0],
|
||||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||||
@@ -102,6 +123,8 @@ class CheckpointMetadata(BaseModelMetadata):
|
|||||||
preview_nsfw_level=0,
|
preview_nsfw_level=0,
|
||||||
from_civitai=True,
|
from_civitai=True,
|
||||||
civitai=version_info,
|
civitai=version_info,
|
||||||
model_type=model_type
|
model_type=model_type,
|
||||||
|
tags=tags,
|
||||||
|
modelDescription=description
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -45,14 +45,14 @@ class RecipeMetadataParser(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info: Dict[str, Any],
|
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info_tuple: Tuple[Dict[str, Any], Optional[str]],
|
||||||
recipe_scanner=None, base_model_counts=None, hash_value=None) -> Dict[str, Any]:
|
recipe_scanner=None, base_model_counts=None, hash_value=None) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Populate a lora entry with information from Civitai API response
|
Populate a lora entry with information from Civitai API response
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
lora_entry: The lora entry to populate
|
lora_entry: The lora entry to populate
|
||||||
civitai_info: The response from Civitai API
|
civitai_info_tuple: The response tuple from Civitai API (data, error_msg)
|
||||||
recipe_scanner: Optional recipe scanner for local file lookup
|
recipe_scanner: Optional recipe scanner for local file lookup
|
||||||
base_model_counts: Optional dict to track base model counts
|
base_model_counts: Optional dict to track base model counts
|
||||||
hash_value: Optional hash value to use if not available in civitai_info
|
hash_value: Optional hash value to use if not available in civitai_info
|
||||||
@@ -61,6 +61,9 @@ class RecipeMetadataParser(ABC):
|
|||||||
The populated lora_entry dict
|
The populated lora_entry dict
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Unpack the tuple to get the actual data
|
||||||
|
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||||
|
|
||||||
if civitai_info and civitai_info.get("error") != "Model not found":
|
if civitai_info and civitai_info.get("error") != "Model not found":
|
||||||
# Check if this is an early access lora
|
# Check if this is an early access lora
|
||||||
if civitai_info.get('earlyAccessEndsAt'):
|
if civitai_info.get('earlyAccessEndsAt'):
|
||||||
@@ -241,11 +244,11 @@ class RecipeFormatParser(RecipeMetadataParser):
|
|||||||
# Try to get additional info from Civitai if we have a model version ID
|
# Try to get additional info from Civitai if we have a model version ID
|
||||||
if lora.get('modelVersionId') and civitai_client:
|
if lora.get('modelVersionId') and civitai_client:
|
||||||
try:
|
try:
|
||||||
civitai_info = await civitai_client.get_model_version_info(lora['modelVersionId'])
|
civitai_info_tuple = await civitai_client.get_model_version_info(lora['modelVersionId'])
|
||||||
# Populate lora entry with Civitai info
|
# Populate lora entry with Civitai info
|
||||||
lora_entry = await self.populate_lora_from_civitai(
|
lora_entry = await self.populate_lora_from_civitai(
|
||||||
lora_entry,
|
lora_entry,
|
||||||
civitai_info,
|
civitai_info_tuple,
|
||||||
recipe_scanner,
|
recipe_scanner,
|
||||||
None, # No need to track base model counts
|
None, # No need to track base model counts
|
||||||
lora['hash']
|
lora['hash']
|
||||||
@@ -336,12 +339,13 @@ class StandardMetadataParser(RecipeMetadataParser):
|
|||||||
# Get additional info from Civitai if client is available
|
# Get additional info from Civitai if client is available
|
||||||
if civitai_client:
|
if civitai_client:
|
||||||
try:
|
try:
|
||||||
civitai_info = await civitai_client.get_model_version_info(model_version_id)
|
civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id)
|
||||||
# Populate lora entry with Civitai info
|
# Populate lora entry with Civitai info
|
||||||
lora_entry = await self.populate_lora_from_civitai(
|
lora_entry = await self.populate_lora_from_civitai(
|
||||||
lora_entry,
|
lora_entry,
|
||||||
civitai_info,
|
civitai_info_tuple,
|
||||||
recipe_scanner
|
recipe_scanner,
|
||||||
|
base_model_counts
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching Civitai info for LoRA: {e}")
|
logger.error(f"Error fetching Civitai info for LoRA: {e}")
|
||||||
@@ -621,11 +625,11 @@ class ComfyMetadataParser(RecipeMetadataParser):
|
|||||||
# Get additional info from Civitai if client is available
|
# Get additional info from Civitai if client is available
|
||||||
if civitai_client:
|
if civitai_client:
|
||||||
try:
|
try:
|
||||||
civitai_info = await civitai_client.get_model_version_info(model_version_id)
|
civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id)
|
||||||
# Populate lora entry with Civitai info
|
# Populate lora entry with Civitai info
|
||||||
lora_entry = await self.populate_lora_from_civitai(
|
lora_entry = await self.populate_lora_from_civitai(
|
||||||
lora_entry,
|
lora_entry,
|
||||||
civitai_info,
|
civitai_info_tuple,
|
||||||
recipe_scanner
|
recipe_scanner
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -660,7 +664,8 @@ class ComfyMetadataParser(RecipeMetadataParser):
|
|||||||
# Get additional checkpoint info from Civitai
|
# Get additional checkpoint info from Civitai
|
||||||
if civitai_client:
|
if civitai_client:
|
||||||
try:
|
try:
|
||||||
civitai_info = await civitai_client.get_model_version_info(checkpoint_version_id)
|
civitai_info_tuple = await civitai_client.get_model_version_info(checkpoint_version_id)
|
||||||
|
civitai_info, _ = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
|
||||||
# Populate checkpoint with Civitai info
|
# Populate checkpoint with Civitai info
|
||||||
checkpoint = await self.populate_checkpoint_from_civitai(checkpoint, civitai_info)
|
checkpoint = await self.populate_checkpoint_from_civitai(checkpoint, civitai_info)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
267
py/utils/usage_stats.py
Normal file
267
py/utils/usage_stats.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Set
|
||||||
|
|
||||||
|
from ..config import config
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
from ..metadata_collector.metadata_registry import MetadataRegistry
|
||||||
|
from ..metadata_collector.constants import MODELS, LORAS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class UsageStats:
|
||||||
|
"""Track usage statistics for models and save to JSON"""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_lock = asyncio.Lock() # For thread safety
|
||||||
|
|
||||||
|
# Default stats file name
|
||||||
|
STATS_FILENAME = "lora_manager_stats.json"
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._initialized = False
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Initialize stats storage
|
||||||
|
self.stats = {
|
||||||
|
"checkpoints": {}, # sha256 -> count
|
||||||
|
"loras": {}, # sha256 -> count
|
||||||
|
"total_executions": 0,
|
||||||
|
"last_save_time": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# Queue for prompt_ids to process
|
||||||
|
self.pending_prompt_ids = set()
|
||||||
|
|
||||||
|
# Load existing stats if available
|
||||||
|
self._stats_file_path = self._get_stats_file_path()
|
||||||
|
self._load_stats()
|
||||||
|
|
||||||
|
# Save interval in seconds
|
||||||
|
self.save_interval = 90 # 1.5 minutes
|
||||||
|
|
||||||
|
# Start background task to process queued prompt_ids
|
||||||
|
self._bg_task = asyncio.create_task(self._background_processor())
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
logger.info("Usage statistics tracker initialized")
|
||||||
|
|
||||||
|
def _get_stats_file_path(self) -> str:
|
||||||
|
"""Get the path to the stats JSON file"""
|
||||||
|
if not config.loras_roots or len(config.loras_roots) == 0:
|
||||||
|
# Fallback to temporary directory if no lora roots
|
||||||
|
return os.path.join(config.temp_directory, self.STATS_FILENAME)
|
||||||
|
|
||||||
|
# Use the first lora root
|
||||||
|
return os.path.join(config.loras_roots[0], self.STATS_FILENAME)
|
||||||
|
|
||||||
|
def _load_stats(self):
|
||||||
|
"""Load existing statistics from file"""
|
||||||
|
try:
|
||||||
|
if os.path.exists(self._stats_file_path):
|
||||||
|
with open(self._stats_file_path, 'r', encoding='utf-8') as f:
|
||||||
|
loaded_stats = json.load(f)
|
||||||
|
|
||||||
|
# Update our stats with loaded data
|
||||||
|
if isinstance(loaded_stats, dict):
|
||||||
|
# Update individual sections to maintain structure
|
||||||
|
if "checkpoints" in loaded_stats and isinstance(loaded_stats["checkpoints"], dict):
|
||||||
|
self.stats["checkpoints"] = loaded_stats["checkpoints"]
|
||||||
|
|
||||||
|
if "loras" in loaded_stats and isinstance(loaded_stats["loras"], dict):
|
||||||
|
self.stats["loras"] = loaded_stats["loras"]
|
||||||
|
|
||||||
|
if "total_executions" in loaded_stats:
|
||||||
|
self.stats["total_executions"] = loaded_stats["total_executions"]
|
||||||
|
|
||||||
|
logger.info(f"Loaded usage statistics from {self._stats_file_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error loading usage statistics: {e}")
|
||||||
|
|
||||||
|
async def save_stats(self, force=False):
|
||||||
|
"""Save statistics to file"""
|
||||||
|
try:
|
||||||
|
# Only save if it's been at least save_interval since last save or force is True
|
||||||
|
current_time = time.time()
|
||||||
|
if not force and (current_time - self.stats.get("last_save_time", 0)) < self.save_interval:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Use a lock to prevent concurrent writes
|
||||||
|
async with self._lock:
|
||||||
|
# Update last save time
|
||||||
|
self.stats["last_save_time"] = current_time
|
||||||
|
|
||||||
|
# Create directory if it doesn't exist
|
||||||
|
os.makedirs(os.path.dirname(self._stats_file_path), exist_ok=True)
|
||||||
|
|
||||||
|
# Write to a temporary file first, then move it to avoid corruption
|
||||||
|
temp_path = f"{self._stats_file_path}.tmp"
|
||||||
|
with open(temp_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(self.stats, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
# Replace the old file with the new one
|
||||||
|
os.replace(temp_path, self._stats_file_path)
|
||||||
|
|
||||||
|
logger.debug(f"Saved usage statistics to {self._stats_file_path}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving usage statistics: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def register_execution(self, prompt_id):
|
||||||
|
"""Register a completed execution by prompt_id for later processing"""
|
||||||
|
if prompt_id:
|
||||||
|
self.pending_prompt_ids.add(prompt_id)
|
||||||
|
|
||||||
|
async def _background_processor(self):
|
||||||
|
"""Background task to process queued prompt_ids"""
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
# Wait a short interval before checking for new prompt_ids
|
||||||
|
await asyncio.sleep(5) # Check every 5 seconds
|
||||||
|
|
||||||
|
# Process any pending prompt_ids
|
||||||
|
if self.pending_prompt_ids:
|
||||||
|
async with self._lock:
|
||||||
|
# Get a copy of the set and clear original
|
||||||
|
prompt_ids = self.pending_prompt_ids.copy()
|
||||||
|
self.pending_prompt_ids.clear()
|
||||||
|
|
||||||
|
# Process each prompt_id
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
for prompt_id in prompt_ids:
|
||||||
|
try:
|
||||||
|
metadata = registry.get_metadata(prompt_id)
|
||||||
|
await self._process_metadata(metadata)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing prompt_id {prompt_id}: {e}")
|
||||||
|
|
||||||
|
# Periodically save stats
|
||||||
|
await self.save_stats()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Task was cancelled, clean up
|
||||||
|
await self.save_stats(force=True)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in background processing task: {e}", exc_info=True)
|
||||||
|
# Restart the task after a delay if it fails
|
||||||
|
asyncio.create_task(self._restart_background_task())
|
||||||
|
|
||||||
|
async def _restart_background_task(self):
|
||||||
|
"""Restart the background task after a delay"""
|
||||||
|
await asyncio.sleep(30) # Wait 30 seconds before restarting
|
||||||
|
self._bg_task = asyncio.create_task(self._background_processor())
|
||||||
|
|
||||||
|
async def _process_metadata(self, metadata):
|
||||||
|
"""Process metadata from an execution"""
|
||||||
|
if not metadata or not isinstance(metadata, dict):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Increment total executions count
|
||||||
|
self.stats["total_executions"] += 1
|
||||||
|
|
||||||
|
# Process checkpoints
|
||||||
|
if MODELS in metadata and isinstance(metadata[MODELS], dict):
|
||||||
|
await self._process_checkpoints(metadata[MODELS])
|
||||||
|
|
||||||
|
# Process loras
|
||||||
|
if LORAS in metadata and isinstance(metadata[LORAS], dict):
|
||||||
|
await self._process_loras(metadata[LORAS])
|
||||||
|
|
||||||
|
async def _process_checkpoints(self, models_data):
|
||||||
|
"""Process checkpoint models from metadata"""
|
||||||
|
try:
|
||||||
|
# Get checkpoint scanner service
|
||||||
|
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
if not checkpoint_scanner:
|
||||||
|
logger.warning("Checkpoint scanner not available for usage tracking")
|
||||||
|
return
|
||||||
|
|
||||||
|
for node_id, model_info in models_data.items():
|
||||||
|
if not isinstance(model_info, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if this is a checkpoint model
|
||||||
|
model_type = model_info.get("type")
|
||||||
|
if model_type == "checkpoint":
|
||||||
|
model_name = model_info.get("name")
|
||||||
|
if not model_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Clean up filename (remove extension if present)
|
||||||
|
model_filename = os.path.splitext(os.path.basename(model_name))[0]
|
||||||
|
|
||||||
|
# Get hash for this checkpoint
|
||||||
|
model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
|
||||||
|
if model_hash:
|
||||||
|
# Update stats for this checkpoint
|
||||||
|
self.stats["checkpoints"][model_hash] = self.stats["checkpoints"].get(model_hash, 0) + 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def _process_loras(self, loras_data):
|
||||||
|
"""Process LoRA models from metadata"""
|
||||||
|
try:
|
||||||
|
# Get LoRA scanner service
|
||||||
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
|
if not lora_scanner:
|
||||||
|
logger.warning("LoRA scanner not available for usage tracking")
|
||||||
|
return
|
||||||
|
|
||||||
|
for node_id, lora_info in loras_data.items():
|
||||||
|
if not isinstance(lora_info, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get the list of LoRAs from standardized format
|
||||||
|
lora_list = lora_info.get("lora_list", [])
|
||||||
|
for lora in lora_list:
|
||||||
|
if not isinstance(lora, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
lora_name = lora.get("name")
|
||||||
|
if not lora_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get hash for this LoRA
|
||||||
|
lora_hash = lora_scanner.get_hash_by_filename(lora_name)
|
||||||
|
if lora_hash:
|
||||||
|
# Update stats for this LoRA
|
||||||
|
self.stats["loras"][lora_hash] = self.stats["loras"].get(lora_hash, 0) + 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing LoRA usage: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def get_stats(self):
|
||||||
|
"""Get current usage statistics"""
|
||||||
|
return self.stats
|
||||||
|
|
||||||
|
async def get_model_usage_count(self, model_type, sha256):
|
||||||
|
"""Get usage count for a specific model by hash"""
|
||||||
|
if model_type == "checkpoint":
|
||||||
|
return self.stats["checkpoints"].get(sha256, 0)
|
||||||
|
elif model_type == "lora":
|
||||||
|
return self.stats["loras"].get(sha256, 0)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def process_execution(self, prompt_id):
|
||||||
|
"""Process a prompt execution immediately (synchronous approach)"""
|
||||||
|
if not prompt_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Process metadata for this prompt_id
|
||||||
|
registry = MetadataRegistry()
|
||||||
|
metadata = registry.get_metadata(prompt_id)
|
||||||
|
if metadata:
|
||||||
|
await self._process_metadata(metadata)
|
||||||
|
# Save stats if needed
|
||||||
|
await self.save_stats()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error processing prompt_id {prompt_id}: {e}", exc_info=True)
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "comfyui-lora-manager"
|
name = "comfyui-lora-manager"
|
||||||
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration."
|
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration."
|
||||||
version = "0.8.7"
|
version = "0.8.8"
|
||||||
license = {file = "LICENSE"}
|
license = {file = "LICENSE"}
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aiohttp",
|
"aiohttp",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { showToast } from '../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
|
||||||
import { state } from '../state/index.js';
|
import { state } from '../state/index.js';
|
||||||
import { showCheckpointModal } from './checkpointModal/index.js';
|
import { showCheckpointModal } from './checkpointModal/index.js';
|
||||||
import { NSFW_LEVELS } from '../utils/constants.js';
|
import { NSFW_LEVELS } from '../utils/constants.js';
|
||||||
@@ -204,21 +204,7 @@ export function createCheckpointCard(checkpoint) {
|
|||||||
const checkpointName = card.dataset.file_name;
|
const checkpointName = card.dataset.file_name;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Modern clipboard API
|
await copyToClipboard(checkpointName, 'Checkpoint name copied');
|
||||||
if (navigator.clipboard && window.isSecureContext) {
|
|
||||||
await navigator.clipboard.writeText(checkpointName);
|
|
||||||
} else {
|
|
||||||
// Fallback for older browsers
|
|
||||||
const textarea = document.createElement('textarea');
|
|
||||||
textarea.value = checkpointName;
|
|
||||||
textarea.style.position = 'absolute';
|
|
||||||
textarea.style.left = '-99999px';
|
|
||||||
document.body.appendChild(textarea);
|
|
||||||
textarea.select();
|
|
||||||
document.execCommand('copy');
|
|
||||||
document.body.removeChild(textarea);
|
|
||||||
}
|
|
||||||
showToast('Checkpoint name copied', 'success');
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Copy failed:', err);
|
console.error('Copy failed:', err);
|
||||||
showToast('Copy failed', 'error');
|
showToast('Copy failed', 'error');
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { showToast, openCivitai } from '../utils/uiHelpers.js';
|
import { showToast, openCivitai, copyToClipboard } from '../utils/uiHelpers.js';
|
||||||
import { state } from '../state/index.js';
|
import { state } from '../state/index.js';
|
||||||
import { showLoraModal } from './loraModal/index.js';
|
import { showLoraModal } from './loraModal/index.js';
|
||||||
import { bulkManager } from '../managers/BulkManager.js';
|
import { bulkManager } from '../managers/BulkManager.js';
|
||||||
@@ -205,26 +205,7 @@ export function createLoraCard(lora) {
|
|||||||
const strength = usageTips.strength || 1;
|
const strength = usageTips.strength || 1;
|
||||||
const loraSyntax = `<lora:${card.dataset.file_name}:${strength}>`;
|
const loraSyntax = `<lora:${card.dataset.file_name}:${strength}>`;
|
||||||
|
|
||||||
try {
|
await copyToClipboard(loraSyntax, 'LoRA syntax copied');
|
||||||
// Modern clipboard API
|
|
||||||
if (navigator.clipboard && window.isSecureContext) {
|
|
||||||
await navigator.clipboard.writeText(loraSyntax);
|
|
||||||
} else {
|
|
||||||
// Fallback for older browsers
|
|
||||||
const textarea = document.createElement('textarea');
|
|
||||||
textarea.value = loraSyntax;
|
|
||||||
textarea.style.position = 'absolute';
|
|
||||||
textarea.style.left = '-99999px';
|
|
||||||
document.body.appendChild(textarea);
|
|
||||||
textarea.select();
|
|
||||||
document.execCommand('copy');
|
|
||||||
document.body.removeChild(textarea);
|
|
||||||
}
|
|
||||||
showToast('LoRA syntax copied', 'success');
|
|
||||||
} catch (err) {
|
|
||||||
console.error('Copy failed:', err);
|
|
||||||
showToast('Copy failed', 'error');
|
|
||||||
}
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Civitai button click event
|
// Civitai button click event
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
// Recipe Card Component
|
// Recipe Card Component
|
||||||
import { showToast } from '../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
|
||||||
import { modalManager } from '../managers/ModalManager.js';
|
import { modalManager } from '../managers/ModalManager.js';
|
||||||
|
|
||||||
class RecipeCard {
|
class RecipeCard {
|
||||||
@@ -109,14 +109,11 @@ class RecipeCard {
|
|||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
if (data.success && data.syntax) {
|
if (data.success && data.syntax) {
|
||||||
return navigator.clipboard.writeText(data.syntax);
|
return copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
|
||||||
} else {
|
} else {
|
||||||
throw new Error(data.error || 'No syntax returned');
|
throw new Error(data.error || 'No syntax returned');
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.then(() => {
|
|
||||||
showToast('Recipe syntax copied to clipboard', 'success');
|
|
||||||
})
|
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
console.error('Failed to copy: ', err);
|
console.error('Failed to copy: ', err);
|
||||||
showToast('Failed to copy recipe syntax', 'error');
|
showToast('Failed to copy recipe syntax', 'error');
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
// Recipe Modal Component
|
// Recipe Modal Component
|
||||||
import { showToast } from '../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
|
||||||
import { state } from '../state/index.js';
|
import { state } from '../state/index.js';
|
||||||
import { setSessionItem, removeSessionItem } from '../utils/storageHelpers.js';
|
import { setSessionItem, removeSessionItem } from '../utils/storageHelpers.js';
|
||||||
|
|
||||||
@@ -747,9 +747,8 @@ class RecipeModal {
|
|||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
|
|
||||||
if (data.success && data.syntax) {
|
if (data.success && data.syntax) {
|
||||||
// Copy to clipboard
|
// Use the centralized copyToClipboard utility function
|
||||||
await navigator.clipboard.writeText(data.syntax);
|
await copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
|
||||||
showToast('Recipe syntax copied to clipboard', 'success');
|
|
||||||
} else {
|
} else {
|
||||||
throw new Error(data.error || 'No syntax returned from server');
|
throw new Error(data.error || 'No syntax returned from server');
|
||||||
}
|
}
|
||||||
@@ -761,12 +760,7 @@ class RecipeModal {
|
|||||||
|
|
||||||
// Helper method to copy text to clipboard
|
// Helper method to copy text to clipboard
|
||||||
copyToClipboard(text, successMessage) {
|
copyToClipboard(text, successMessage) {
|
||||||
navigator.clipboard.writeText(text).then(() => {
|
copyToClipboard(text, successMessage);
|
||||||
showToast(successMessage, 'success');
|
|
||||||
}).catch(err => {
|
|
||||||
console.error('Failed to copy text: ', err);
|
|
||||||
showToast('Failed to copy text', 'error');
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add new method to handle downloading missing LoRAs
|
// Add new method to handle downloading missing LoRAs
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* ShowcaseView.js
|
* ShowcaseView.js
|
||||||
* Handles showcase content (images, videos) display for checkpoint modal
|
* Handles showcase content (images, videos) display for checkpoint modal
|
||||||
*/
|
*/
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||||
import { state } from '../../state/index.js';
|
import { state } from '../../state/index.js';
|
||||||
import { NSFW_LEVELS } from '../../utils/constants.js';
|
import { NSFW_LEVELS } from '../../utils/constants.js';
|
||||||
|
|
||||||
@@ -307,8 +307,7 @@ function initMetadataPanelHandlers(container) {
|
|||||||
if (!promptElement) return;
|
if (!promptElement) return;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await navigator.clipboard.writeText(promptElement.textContent);
|
await copyToClipboard(promptElement.textContent, 'Prompt copied to clipboard');
|
||||||
showToast('Prompt copied to clipboard', 'success');
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Copy failed:', err);
|
console.error('Copy failed:', err);
|
||||||
showToast('Copy failed', 'error');
|
showToast('Copy failed', 'error');
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
/**
|
/**
|
||||||
* RecipeTab - Handles the recipes tab in the Lora Modal
|
* RecipeTab - Handles the recipes tab in the Lora Modal
|
||||||
*/
|
*/
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||||
import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
|
import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -172,14 +172,11 @@ function copyRecipeSyntax(recipeId) {
|
|||||||
.then(response => response.json())
|
.then(response => response.json())
|
||||||
.then(data => {
|
.then(data => {
|
||||||
if (data.success && data.syntax) {
|
if (data.success && data.syntax) {
|
||||||
return navigator.clipboard.writeText(data.syntax);
|
return copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
|
||||||
} else {
|
} else {
|
||||||
throw new Error(data.error || 'No syntax returned');
|
throw new Error(data.error || 'No syntax returned');
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.then(() => {
|
|
||||||
showToast('Recipe syntax copied to clipboard', 'success');
|
|
||||||
})
|
|
||||||
.catch(err => {
|
.catch(err => {
|
||||||
console.error('Failed to copy: ', err);
|
console.error('Failed to copy: ', err);
|
||||||
showToast('Failed to copy recipe syntax', 'error');
|
showToast('Failed to copy recipe syntax', 'error');
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* ShowcaseView.js
|
* ShowcaseView.js
|
||||||
* 处理LoRA模型展示内容(图片、视频)的功能模块
|
* 处理LoRA模型展示内容(图片、视频)的功能模块
|
||||||
*/
|
*/
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||||
import { state } from '../../state/index.js';
|
import { state } from '../../state/index.js';
|
||||||
import { NSFW_LEVELS } from '../../utils/constants.js';
|
import { NSFW_LEVELS } from '../../utils/constants.js';
|
||||||
|
|
||||||
@@ -311,8 +311,7 @@ function initMetadataPanelHandlers(container) {
|
|||||||
if (!promptElement) return;
|
if (!promptElement) return;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
await navigator.clipboard.writeText(promptElement.textContent);
|
await copyToClipboard(promptElement.textContent, 'Prompt copied to clipboard');
|
||||||
showToast('Prompt copied to clipboard', 'success');
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Copy failed:', err);
|
console.error('Copy failed:', err);
|
||||||
showToast('Copy failed', 'error');
|
showToast('Copy failed', 'error');
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* TriggerWords.js
|
* TriggerWords.js
|
||||||
* 处理LoRA模型触发词相关的功能模块
|
* 处理LoRA模型触发词相关的功能模块
|
||||||
*/
|
*/
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||||
import { saveModelMetadata } from './ModelMetadata.js';
|
import { saveModelMetadata } from './ModelMetadata.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -336,23 +336,7 @@ async function saveTriggerWords() {
|
|||||||
*/
|
*/
|
||||||
window.copyTriggerWord = async function(word) {
|
window.copyTriggerWord = async function(word) {
|
||||||
try {
|
try {
|
||||||
// Modern clipboard API - with fallback for non-secure contexts
|
await copyToClipboard(word, 'Trigger word copied');
|
||||||
if (navigator.clipboard && window.isSecureContext) {
|
|
||||||
await navigator.clipboard.writeText(word);
|
|
||||||
} else {
|
|
||||||
// Fallback for older browsers or non-secure contexts
|
|
||||||
const textarea = document.createElement('textarea');
|
|
||||||
textarea.value = word;
|
|
||||||
textarea.style.position = 'absolute';
|
|
||||||
textarea.style.left = '-99999px';
|
|
||||||
document.body.appendChild(textarea);
|
|
||||||
textarea.select();
|
|
||||||
const success = document.execCommand('copy');
|
|
||||||
document.body.removeChild(textarea);
|
|
||||||
|
|
||||||
if (!success) throw new Error('Copy command failed');
|
|
||||||
}
|
|
||||||
showToast('Trigger word copied', 'success');
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Copy failed:', err);
|
console.error('Copy failed:', err);
|
||||||
showToast('Copy failed', 'error');
|
showToast('Copy failed', 'error');
|
||||||
|
|||||||
@@ -3,8 +3,7 @@
|
|||||||
*
|
*
|
||||||
* 将原始的LoraModal.js拆分成多个功能模块后的主入口文件
|
* 将原始的LoraModal.js拆分成多个功能模块后的主入口文件
|
||||||
*/
|
*/
|
||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
|
||||||
import { state } from '../../state/index.js';
|
|
||||||
import { modalManager } from '../../managers/ModalManager.js';
|
import { modalManager } from '../../managers/ModalManager.js';
|
||||||
import { renderShowcaseContent, toggleShowcase, setupShowcaseScroll, scrollToTop } from './ShowcaseView.js';
|
import { renderShowcaseContent, toggleShowcase, setupShowcaseScroll, scrollToTop } from './ShowcaseView.js';
|
||||||
import { setupTabSwitching, loadModelDescription } from './ModelDescription.js';
|
import { setupTabSwitching, loadModelDescription } from './ModelDescription.js';
|
||||||
@@ -174,8 +173,7 @@ export function showLoraModal(lora) {
|
|||||||
// Copy file name function
|
// Copy file name function
|
||||||
window.copyFileName = async function(fileName) {
|
window.copyFileName = async function(fileName) {
|
||||||
try {
|
try {
|
||||||
await navigator.clipboard.writeText(fileName);
|
await copyToClipboard(fileName, 'File name copied');
|
||||||
showToast('File name copied', 'success');
|
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error('Copy failed:', err);
|
console.error('Copy failed:', err);
|
||||||
showToast('Copy failed', 'error');
|
showToast('Copy failed', 'error');
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { state } from '../state/index.js';
|
import { state } from '../state/index.js';
|
||||||
import { showToast } from '../utils/uiHelpers.js';
|
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
|
||||||
import { updateCardsForBulkMode } from '../components/LoraCard.js';
|
import { updateCardsForBulkMode } from '../components/LoraCard.js';
|
||||||
|
|
||||||
export class BulkManager {
|
export class BulkManager {
|
||||||
@@ -205,13 +205,7 @@ export class BulkManager {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
await copyToClipboard(loraSyntaxes.join(', '), `Copied ${loraSyntaxes.length} LoRA syntaxes to clipboard`);
|
||||||
await navigator.clipboard.writeText(loraSyntaxes.join(', '));
|
|
||||||
showToast(`Copied ${loraSyntaxes.length} LoRA syntaxes to clipboard`, 'success');
|
|
||||||
} catch (err) {
|
|
||||||
console.error('Copy failed:', err);
|
|
||||||
showToast('Copy failed', 'error');
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create and show the thumbnail strip of selected LoRAs
|
// Create and show the thumbnail strip of selected LoRAs
|
||||||
|
|||||||
@@ -2,6 +2,40 @@ import { state } from '../state/index.js';
|
|||||||
import { resetAndReload } from '../api/loraApi.js';
|
import { resetAndReload } from '../api/loraApi.js';
|
||||||
import { getStorageItem, setStorageItem } from './storageHelpers.js';
|
import { getStorageItem, setStorageItem } from './storageHelpers.js';
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Utility function to copy text to clipboard with fallback for older browsers
|
||||||
|
* @param {string} text - The text to copy to clipboard
|
||||||
|
* @param {string} successMessage - Optional success message to show in toast
|
||||||
|
* @returns {Promise<boolean>} - Promise that resolves to true if copy was successful
|
||||||
|
*/
|
||||||
|
export async function copyToClipboard(text, successMessage = 'Copied to clipboard') {
|
||||||
|
try {
|
||||||
|
// Modern clipboard API
|
||||||
|
if (navigator.clipboard && window.isSecureContext) {
|
||||||
|
await navigator.clipboard.writeText(text);
|
||||||
|
} else {
|
||||||
|
// Fallback for older browsers
|
||||||
|
const textarea = document.createElement('textarea');
|
||||||
|
textarea.value = text;
|
||||||
|
textarea.style.position = 'absolute';
|
||||||
|
textarea.style.left = '-99999px';
|
||||||
|
document.body.appendChild(textarea);
|
||||||
|
textarea.select();
|
||||||
|
document.execCommand('copy');
|
||||||
|
document.body.removeChild(textarea);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (successMessage) {
|
||||||
|
showToast(successMessage, 'success');
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Copy failed:', err);
|
||||||
|
showToast('Copy failed', 'error');
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export function showToast(message, type = 'info') {
|
export function showToast(message, type = 'info') {
|
||||||
const toast = document.createElement('div');
|
const toast = document.createElement('div');
|
||||||
toast.className = `toast toast-${type}`;
|
toast.className = `toast toast-${type}`;
|
||||||
@@ -108,12 +142,6 @@ export function toggleFolder(tag) {
|
|||||||
resetAndReload();
|
resetAndReload();
|
||||||
}
|
}
|
||||||
|
|
||||||
export function copyTriggerWord(word) {
|
|
||||||
navigator.clipboard.writeText(word).then(() => {
|
|
||||||
showToast('Trigger word copied', 'success');
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
function filterByFolder(folderPath) {
|
function filterByFolder(folderPath) {
|
||||||
document.querySelectorAll('.lora-card').forEach(card => {
|
document.querySelectorAll('.lora-card').forEach(card => {
|
||||||
card.style.display = card.dataset.folder === folderPath ? '' : 'none';
|
card.style.display = card.dataset.folder === folderPath ? '' : 'none';
|
||||||
|
|||||||
@@ -927,10 +927,6 @@ export function addLorasWidget(node, name, opts, callback) {
|
|||||||
// Function to directly save the recipe without dialog
|
// Function to directly save the recipe without dialog
|
||||||
async function saveRecipeDirectly(widget) {
|
async function saveRecipeDirectly(widget) {
|
||||||
try {
|
try {
|
||||||
// Get the workflow data from the ComfyUI app
|
|
||||||
const prompt = await app.graphToPrompt();
|
|
||||||
console.log('Prompt:', prompt);
|
|
||||||
|
|
||||||
// Show loading toast
|
// Show loading toast
|
||||||
if (app && app.extensionManager && app.extensionManager.toast) {
|
if (app && app.extensionManager && app.extensionManager.toast) {
|
||||||
app.extensionManager.toast.add({
|
app.extensionManager.toast.add({
|
||||||
@@ -941,14 +937,9 @@ async function saveRecipeDirectly(widget) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare the data - only send workflow JSON
|
|
||||||
const formData = new FormData();
|
|
||||||
formData.append('workflow_json', JSON.stringify(prompt.output));
|
|
||||||
|
|
||||||
// Send the request
|
// Send the request
|
||||||
const response = await fetch('/api/recipes/save-from-widget', {
|
const response = await fetch('/api/recipes/save-from-widget', {
|
||||||
method: 'POST',
|
method: 'POST'
|
||||||
body: formData
|
|
||||||
});
|
});
|
||||||
|
|
||||||
const result = await response.json();
|
const result = await response.json();
|
||||||
|
|||||||
@@ -9,6 +9,54 @@ async function getLorasWidgetModule() {
|
|||||||
return await dynamicImportByVersion("./loras_widget.js", "./legacy_loras_widget.js");
|
return await dynamicImportByVersion("./loras_widget.js", "./legacy_loras_widget.js");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Function to get connected trigger toggle nodes
|
||||||
|
function getConnectedTriggerToggleNodes(node) {
|
||||||
|
const connectedNodes = [];
|
||||||
|
|
||||||
|
// Check if node has outputs
|
||||||
|
if (node.outputs && node.outputs.length > 0) {
|
||||||
|
// For each output slot
|
||||||
|
for (const output of node.outputs) {
|
||||||
|
// Check if this output has any links
|
||||||
|
if (output.links && output.links.length > 0) {
|
||||||
|
// For each link, get the target node
|
||||||
|
for (const linkId of output.links) {
|
||||||
|
const link = app.graph.links[linkId];
|
||||||
|
if (link) {
|
||||||
|
const targetNode = app.graph.getNodeById(link.target_id);
|
||||||
|
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
|
||||||
|
connectedNodes.push(targetNode.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return connectedNodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to update trigger words for connected toggle nodes
|
||||||
|
function updateConnectedTriggerWords(node, text) {
|
||||||
|
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
|
||||||
|
if (connectedNodeIds.length > 0) {
|
||||||
|
const loraNames = new Set();
|
||||||
|
let match;
|
||||||
|
LORA_PATTERN.lastIndex = 0;
|
||||||
|
while ((match = LORA_PATTERN.exec(text)) !== null) {
|
||||||
|
loraNames.add(match[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
fetch("/loramanager/get_trigger_words", {
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify({
|
||||||
|
lora_names: Array.from(loraNames),
|
||||||
|
node_ids: connectedNodeIds
|
||||||
|
})
|
||||||
|
}).catch(err => console.error("Error fetching trigger words:", err));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function mergeLoras(lorasText, lorasArr) {
|
function mergeLoras(lorasText, lorasArr) {
|
||||||
const result = [];
|
const result = [];
|
||||||
let match;
|
let match;
|
||||||
@@ -99,6 +147,9 @@ app.registerExtension({
|
|||||||
newText = newText.replace(/\s+/g, ' ').trim();
|
newText = newText.replace(/\s+/g, ' ').trim();
|
||||||
|
|
||||||
inputWidget.value = newText;
|
inputWidget.value = newText;
|
||||||
|
|
||||||
|
// Add this line to update trigger words when lorasWidget changes cause inputWidget value to change
|
||||||
|
updateConnectedTriggerWords(node, newText);
|
||||||
} finally {
|
} finally {
|
||||||
isUpdating = false;
|
isUpdating = false;
|
||||||
}
|
}
|
||||||
@@ -117,6 +168,9 @@ app.registerExtension({
|
|||||||
const mergedLoras = mergeLoras(value, currentLoras);
|
const mergedLoras = mergeLoras(value, currentLoras);
|
||||||
|
|
||||||
node.lorasWidget.value = mergedLoras;
|
node.lorasWidget.value = mergedLoras;
|
||||||
|
|
||||||
|
// Replace the existing trigger word update code with the new function
|
||||||
|
updateConnectedTriggerWords(node, value);
|
||||||
} finally {
|
} finally {
|
||||||
isUpdating = false;
|
isUpdating = false;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,58 @@
|
|||||||
import { app } from "../../scripts/app.js";
|
import { app } from "../../scripts/app.js";
|
||||||
import { addLorasWidget } from "./loras_widget.js";
|
import { dynamicImportByVersion } from "./utils.js";
|
||||||
|
|
||||||
// Extract pattern into a constant for consistent use
|
// Extract pattern into a constant for consistent use
|
||||||
const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)>/g;
|
const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)>/g;
|
||||||
|
|
||||||
|
// Function to get the appropriate loras widget based on ComfyUI version
|
||||||
|
async function getLorasWidgetModule() {
|
||||||
|
return await dynamicImportByVersion("./loras_widget.js", "./legacy_loras_widget.js");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to get connected trigger toggle nodes
|
||||||
|
function getConnectedTriggerToggleNodes(node) {
|
||||||
|
const connectedNodes = [];
|
||||||
|
|
||||||
|
if (node.outputs && node.outputs.length > 0) {
|
||||||
|
for (const output of node.outputs) {
|
||||||
|
if (output.links && output.links.length > 0) {
|
||||||
|
for (const linkId of output.links) {
|
||||||
|
const link = app.graph.links[linkId];
|
||||||
|
if (link) {
|
||||||
|
const targetNode = app.graph.getNodeById(link.target_id);
|
||||||
|
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
|
||||||
|
connectedNodes.push(targetNode.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return connectedNodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Function to update trigger words for connected toggle nodes
|
||||||
|
function updateConnectedTriggerWords(node, text) {
|
||||||
|
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
|
||||||
|
if (connectedNodeIds.length > 0) {
|
||||||
|
const loraNames = new Set();
|
||||||
|
let match;
|
||||||
|
LORA_PATTERN.lastIndex = 0;
|
||||||
|
while ((match = LORA_PATTERN.exec(text)) !== null) {
|
||||||
|
loraNames.add(match[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
fetch("/loramanager/get_trigger_words", {
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify({
|
||||||
|
lora_names: Array.from(loraNames),
|
||||||
|
node_ids: connectedNodeIds
|
||||||
|
})
|
||||||
|
}).catch(err => console.error("Error fetching trigger words:", err));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function mergeLoras(lorasText, lorasArr) {
|
function mergeLoras(lorasText, lorasArr) {
|
||||||
const result = [];
|
const result = [];
|
||||||
let match;
|
let match;
|
||||||
@@ -40,7 +89,7 @@ app.registerExtension({
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Wait for node to be properly initialized
|
// Wait for node to be properly initialized
|
||||||
requestAnimationFrame(() => {
|
requestAnimationFrame(async () => {
|
||||||
// Restore saved value if exists
|
// Restore saved value if exists
|
||||||
let existingLoras = [];
|
let existingLoras = [];
|
||||||
if (node.widgets_values && node.widgets_values.length > 0) {
|
if (node.widgets_values && node.widgets_values.length > 0) {
|
||||||
@@ -64,7 +113,10 @@ app.registerExtension({
|
|||||||
// Add flag to prevent callback loops
|
// Add flag to prevent callback loops
|
||||||
let isUpdating = false;
|
let isUpdating = false;
|
||||||
|
|
||||||
// Get the widget object directly from the returned object
|
// Dynamically load the appropriate widget module
|
||||||
|
const lorasModule = await getLorasWidgetModule();
|
||||||
|
const { addLorasWidget } = lorasModule;
|
||||||
|
|
||||||
const result = addLorasWidget(node, "loras", {
|
const result = addLorasWidget(node, "loras", {
|
||||||
defaultVal: mergedLoras // Pass object directly
|
defaultVal: mergedLoras // Pass object directly
|
||||||
}, (value) => {
|
}, (value) => {
|
||||||
@@ -86,6 +138,9 @@ app.registerExtension({
|
|||||||
newText = newText.replace(/\s+/g, ' ').trim();
|
newText = newText.replace(/\s+/g, ' ').trim();
|
||||||
|
|
||||||
inputWidget.value = newText;
|
inputWidget.value = newText;
|
||||||
|
|
||||||
|
// Update trigger words when lorasWidget changes
|
||||||
|
updateConnectedTriggerWords(node, newText);
|
||||||
} finally {
|
} finally {
|
||||||
isUpdating = false;
|
isUpdating = false;
|
||||||
}
|
}
|
||||||
@@ -104,6 +159,9 @@ app.registerExtension({
|
|||||||
const mergedLoras = mergeLoras(value, currentLoras);
|
const mergedLoras = mergeLoras(value, currentLoras);
|
||||||
|
|
||||||
node.lorasWidget.value = mergedLoras;
|
node.lorasWidget.value = mergedLoras;
|
||||||
|
|
||||||
|
// Update trigger words when input changes
|
||||||
|
updateConnectedTriggerWords(node, value);
|
||||||
} finally {
|
} finally {
|
||||||
isUpdating = false;
|
isUpdating = false;
|
||||||
}
|
}
|
||||||
|
|||||||
36
web/comfyui/usage_stats.js
Normal file
36
web/comfyui/usage_stats.js
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
// ComfyUI extension to track model usage statistics
|
||||||
|
import { app } from "../../scripts/app.js";
|
||||||
|
import { api } from "../../scripts/api.js";
|
||||||
|
|
||||||
|
// Register the extension
|
||||||
|
app.registerExtension({
|
||||||
|
name: "ComfyUI-Lora-Manager.UsageStats",
|
||||||
|
|
||||||
|
init() {
|
||||||
|
// Listen for successful executions
|
||||||
|
api.addEventListener("execution_success", ({ detail }) => {
|
||||||
|
if (detail && detail.prompt_id) {
|
||||||
|
this.updateUsageStats(detail.prompt_id);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
},
|
||||||
|
|
||||||
|
async updateUsageStats(promptId) {
|
||||||
|
try {
|
||||||
|
// Call backend endpoint with the prompt_id
|
||||||
|
const response = await fetch(`/loras/api/update-usage-stats`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ prompt_id: promptId }),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
console.warn("Failed to update usage statistics:", response.statusText);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Error updating usage statistics:", error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user