Compare commits

..

30 Commits

Author SHA1 Message Date
Will Miao
2bc46e708e feat: Update release notes and version to 0.8.7 with enhancements and bug fixes 2025-04-18 19:03:00 +08:00
Will Miao
96e3b5b7b3 feat: Refactor Civitai model API routes and enhance RecipeContextMenu for missing LoRAs handling 2025-04-18 16:44:26 +08:00
Will Miao
fafbafa5e1 feat: Enhance copyTriggerWord function with modern clipboard API and fallback for non-secure contexts. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/110 2025-04-18 14:56:27 +08:00
Will Miao
be8605d8c6 feat: Enhance CivitaiClient and ApiRoutes to handle model version errors and improve metadata fetching. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/112 2025-04-18 14:44:53 +08:00
Will Miao
061660d47a feat: Increase maximum allowed trigger words from 10 to 30. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/109 2025-04-18 11:25:41 +08:00
pixelpaws
2ed6dbb344 Merge pull request #111 from willmiao/dev
Dev
2025-04-18 10:55:07 +08:00
Will Miao
4766b45746 feat: Update SaveImage node to modify default lossless_webp setting and adjust save_kwargs for image formats 2025-04-18 10:52:39 +08:00
Will Miao
0734252e98 feat: Enhance VAEDecodeExtractor to improve image caching and metadata handling 2025-04-18 10:03:26 +08:00
Will Miao
91b4827c1d feat: Enhance image retrieval in MetadataRegistry and update recipe routes to process images from metadata 2025-04-18 09:24:48 +08:00
Will Miao
df6d56ce66 feat: Add IMAGES category to constants and enhance metadata handling in node extractors 2025-04-18 07:12:43 +08:00
Will Miao
f0203c96ab feat: Simplify format_metadata method by removing custom_prompt parameter and update related function calls 2025-04-18 05:34:42 +08:00
Will Miao
bccabe40c0 feat: Enhance KSamplerAdvancedExtractor to include additional sampling parameters and update metadata processing 2025-04-18 05:29:36 +08:00
Will Miao
c2f599b4ff feat: Update node extractors to include UNETLoaderExtractor and enhance metadata handling for guidance parameters 2025-04-17 22:05:40 +08:00
Will Miao
5fd069d70d feat: Enhance checkpoint processing in format_metadata to handle non-string types safely 2025-04-17 09:38:20 +08:00
Will Miao
32d34d1748 feat: Enhance trace_node_input method with depth tracking and target class filtering; add FluxGuidanceExtractor for guidance parameter extraction 2025-04-17 08:06:21 +08:00
Will Miao
18eb605605 feat: Refactor metadata processing to use constants for category keys and improve structure 2025-04-17 06:23:31 +08:00
Will Miao
4fdc88e9e1 feat: Enhance LoraLoaderExtractor to extract base filename from lora_name input 2025-04-16 22:19:38 +08:00
Will Miao
4c69d8d3a8 feat: Integrate metadata collection in RecipeRoutes and simplify saveRecipeDirectly function 2025-04-16 22:15:46 +08:00
Will Miao
d4b2dd0ec1 refactor: Rename to_comfyui_format method to to_dict and update references in save_image.py 2025-04-16 21:42:54 +08:00
Will Miao
181f78421b feat: Standardize LoRA extraction format and enhance input handling in node extractors 2025-04-16 21:20:56 +08:00
Will Miao
8ed38527d0 feat: Implement metadata collection and processing framework with debug node for verification 2025-04-16 20:04:26 +08:00
Will Miao
c4c926070d fix: Update optimize_image method to handle image validation and error logging, and adjust metadata preservation logic. 2025-04-15 12:31:17 +08:00
Will Miao
ed87411e0d refactor: Change logging level from info to debug for service initialization and file monitoring 2025-04-15 11:48:37 +08:00
Will Miao
4ec2a448ab feat: Improve date formatting in filename generation with zero-padding and two-digit year support. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/102 2025-04-15 10:46:57 +08:00
Will Miao
73d01da94e feat: Enhance model preview version management with localStorage support 2025-04-15 10:35:50 +08:00
pixelpaws
df8e02157a Merge pull request #103 from willmiao/dev
feat: Add drag functionality for strength adjustment in LoRA entries.…
2025-04-15 08:57:52 +08:00
Will Miao
6e513ed32a feat: Add drag functionality for strength adjustment in LoRA entries. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/101 2025-04-15 08:56:19 +08:00
pixelpaws
325ef6327d Merge pull request #99 from willmiao/dev
Dev
2025-04-14 20:27:18 +08:00
Will Miao
46700e5ad0 feat: Refactor infinite scroll initialization for improved observer handling and sentinel management 2025-04-14 20:25:44 +08:00
Will Miao
d1e21fa345 feat: Implement context menus for checkpoints and recipes, including metadata refresh and NSFW level management 2025-04-14 15:37:36 +08:00
43 changed files with 2720 additions and 341 deletions

View File

@@ -20,6 +20,14 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
## Release Notes
### v0.8.7
* **Enhanced Context Menu** - Added comprehensive context menu functionality to Recipes and Checkpoints pages for improved workflow
* **Interactive LoRA Strength Control** - Implemented drag functionality in LoRA Loader for intuitive strength adjustment
* **Metadata Collector Overhaul** - Rebuilt metadata collection system with optimized architecture for better performance
* **Improved Save Image Node** - Enhanced metadata capture and image saving performance with the new metadata collector
* **Streamlined Recipe Saving** - Optimized Save Recipe functionality to work independently without requiring Preview Image nodes
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
### v0.8.6 Major Update
* **Checkpoint Management** - Added comprehensive management for model checkpoints including scanning, searching, filtering, and deletion
* **Enhanced Metadata Support** - New capabilities for retrieving and managing checkpoint metadata with improved operations

View File

@@ -3,16 +3,23 @@ from .py.nodes.lora_loader import LoraManagerLoader
from .py.nodes.trigger_word_toggle import TriggerWordToggle
from .py.nodes.lora_stacker import LoraStacker
from .py.nodes.save_image import SaveImage
from .py.nodes.debug_metadata import DebugMetadata
# Import metadata collector to install hooks on startup
from .py.metadata_collector import init as init_metadata_collector
NODE_CLASS_MAPPINGS = {
LoraManagerLoader.NAME: LoraManagerLoader,
TriggerWordToggle.NAME: TriggerWordToggle,
LoraStacker.NAME: LoraStacker,
SaveImage.NAME: SaveImage
SaveImage.NAME: SaveImage,
DebugMetadata.NAME: DebugMetadata
}
WEB_DIRECTORY = "./web/comfyui"
# Initialize metadata collector
init_metadata_collector()
# Register routes on import
LoraManager.add_routes()
__all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY']

View File

@@ -104,8 +104,6 @@ class LoraManager:
async def _initialize_services(cls):
"""Initialize all services using the ServiceRegistry"""
try:
logger.info("LoRA Manager: Initializing services via ServiceRegistry")
# Initialize CivitaiClient first to ensure it's ready for other services
civitai_client = await ServiceRegistry.get_civitai_client()
@@ -115,12 +113,12 @@ class LoraManager:
# Start monitors
lora_monitor.start()
logger.info("Lora monitor started")
logger.debug("Lora monitor started")
# Make sure checkpoint monitor has paths before starting
await checkpoint_monitor.initialize_paths()
checkpoint_monitor.start()
logger.info("Checkpoint monitor started")
logger.debug("Checkpoint monitor started")
# Register DownloadManager with ServiceRegistry
download_manager = await ServiceRegistry.get_download_manager()

View File

@@ -0,0 +1,18 @@
import os
import importlib
from .metadata_hook import MetadataHook
from .metadata_registry import MetadataRegistry
def init():
# Install hooks to collect metadata during execution
MetadataHook.install()
# Initialize registry
registry = MetadataRegistry()
print("ComfyUI Metadata Collector initialized")
def get_metadata(prompt_id=None):
"""Helper function to get metadata from the registry"""
registry = MetadataRegistry()
return registry.get_metadata(prompt_id)

View File

@@ -0,0 +1,12 @@
"""Constants used by the metadata collector"""
# Individual category constants
MODELS = "models"
PROMPTS = "prompts"
SAMPLING = "sampling"
LORAS = "loras"
SIZE = "size"
IMAGES = "images" # Added new category for image results
# Collection of categories for iteration
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES] # Added IMAGES to categories

View File

@@ -0,0 +1,123 @@
import sys
import inspect
from .metadata_registry import MetadataRegistry
class MetadataHook:
"""Install hooks for metadata collection"""
@staticmethod
def install():
"""Install hooks to collect metadata during execution"""
try:
# Import ComfyUI's execution module
execution = None
try:
# Try direct import first
import execution # type: ignore
except ImportError:
# Try to locate from system modules
for module_name in sys.modules:
if module_name.endswith('.execution'):
execution = sys.modules[module_name]
break
# If we can't find the execution module, we can't install hooks
if execution is None:
print("Could not locate ComfyUI execution module, metadata collection disabled")
return
# Store the original _map_node_over_list function
original_map_node_over_list = execution._map_node_over_list
# Define the wrapped _map_node_over_list function
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# Only collect metadata when calling the main function of nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
prompt_id = registry.current_prompt_id
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Unique ID might be available through the obj if it has a unique_id field
node_id = getattr(obj, 'unique_id', None)
if node_id is None and pre_execute_cb:
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
frame = inspect.currentframe()
while frame:
if 'unique_id' in frame.f_locals:
node_id = frame.f_locals['unique_id']
break
frame = frame.f_back
# Record inputs before execution
if node_id is not None:
registry.record_node_execution(node_id, class_type, input_data_all, None)
except Exception as e:
print(f"Error collecting metadata (pre-execution): {str(e)}")
# Execute the original function
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
# After execution, collect outputs for relevant nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
prompt_id = registry.current_prompt_id
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Unique ID might be available through the obj if it has a unique_id field
node_id = getattr(obj, 'unique_id', None)
if node_id is None and pre_execute_cb:
# Try to extract node_id through reflection
frame = inspect.currentframe()
while frame:
if 'unique_id' in frame.f_locals:
node_id = frame.f_locals['unique_id']
break
frame = frame.f_back
# Record outputs after execution
if node_id is not None:
registry.update_node_execution(node_id, class_type, results)
except Exception as e:
print(f"Error collecting metadata (post-execution): {str(e)}")
return results
# Also hook the execute function to track the current prompt_id
original_execute = execution.execute
def execute_with_prompt_tracking(*args, **kwargs):
if len(args) >= 7: # Check if we have enough arguments
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
registry = MetadataRegistry()
# Start collection if this is a new prompt
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
registry.start_collection(prompt_id)
# Store the dynprompt reference for node lookups
if hasattr(prompt, 'original_prompt'):
registry.set_current_prompt(prompt)
# Execute the original function
return original_execute(*args, **kwargs)
# Replace the functions
execution._map_node_over_list = map_node_over_list_with_metadata
execution.execute = execute_with_prompt_tracking
# Make map_node_over_list public to avoid it being hidden by hooks
execution.map_node_over_list = original_map_node_over_list
print("Metadata collection hooks installed for runtime values")
except Exception as e:
print(f"Error installing metadata hooks: {str(e)}")

View File

@@ -0,0 +1,245 @@
import json
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE
class MetadataProcessor:
"""Process and format collected metadata"""
@staticmethod
def find_primary_sampler(metadata):
"""Find the primary KSampler node (with denoise=1)"""
primary_sampler = None
primary_sampler_id = None
# First, check for KSamplerAdvanced with add_noise="enable"
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
parameters = sampler_info.get("parameters", {})
add_noise = parameters.get("add_noise")
# If add_noise is "enable", this is likely the primary sampler for KSamplerAdvanced
if add_noise == "enable":
primary_sampler = sampler_info
primary_sampler_id = node_id
break
# If no KSamplerAdvanced found, fall back to traditional KSampler with denoise=1
if primary_sampler is None:
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
parameters = sampler_info.get("parameters", {})
denoise = parameters.get("denoise")
# If denoise is 1.0, this is likely the primary sampler
if denoise == 1.0 or denoise == 1:
primary_sampler = sampler_info
primary_sampler_id = node_id
break
return primary_sampler_id, primary_sampler
@staticmethod
def trace_node_input(prompt, node_id, input_name, target_class=None, max_depth=10):
"""
Trace an input connection from a node to find the source node
Parameters:
- prompt: The prompt object containing node connections
- node_id: ID of the starting node
- input_name: Name of the input to trace
- target_class: Optional class name to search for (e.g., "CLIPTextEncode")
- max_depth: Maximum depth to follow the node chain to prevent infinite loops
Returns:
- node_id of the found node, or None if not found
"""
if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt:
return None
# For depth tracking
current_depth = 0
current_node_id = node_id
current_input = input_name
while current_depth < max_depth:
if current_node_id not in prompt.original_prompt:
return None
node_inputs = prompt.original_prompt[current_node_id].get("inputs", {})
if current_input not in node_inputs:
return None
input_value = node_inputs[current_input]
# Input connections are formatted as [node_id, output_index]
if isinstance(input_value, list) and len(input_value) >= 2:
found_node_id = input_value[0] # Connected node_id
# If we're looking for a specific node class
if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class:
return found_node_id
# If we're not looking for a specific class or haven't found it yet
if not target_class:
return found_node_id
# Continue tracing through intermediate nodes
current_node_id = found_node_id
# For most conditioning nodes, the input we want to follow is named "conditioning"
if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}):
current_input = "conditioning"
else:
# If there's no "conditioning" input, we can't trace further
return found_node_id if not target_class else None
else:
# We've reached a node with no further connections
return None
current_depth += 1
# If we've reached max depth without finding target_class
return None
@staticmethod
def find_primary_checkpoint(metadata):
"""Find the primary checkpoint model in the workflow"""
if not metadata.get(MODELS):
return None
# In most workflows, there's only one checkpoint, so we can just take the first one
for node_id, model_info in metadata.get(MODELS, {}).items():
if model_info.get("type") == "checkpoint":
return model_info.get("name")
return None
@staticmethod
def extract_generation_params(metadata):
"""Extract generation parameters from metadata using node relationships"""
params = {
"prompt": "",
"negative_prompt": "",
"seed": None,
"steps": None,
"cfg_scale": None,
"guidance": None, # Add guidance parameter
"sampler": None,
"scheduler": None,
"checkpoint": None,
"loras": "",
"size": None,
"clip_skip": None
}
# Get the prompt object for node relationship tracing
prompt = metadata.get("current_prompt")
# Find the primary KSampler node
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata)
# Directly get checkpoint from metadata instead of tracing
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata)
if checkpoint:
params["checkpoint"] = checkpoint
if primary_sampler:
# Extract sampling parameters
sampling_params = primary_sampler.get("parameters", {})
# Handle both seed and noise_seed
params["seed"] = sampling_params.get("seed") if sampling_params.get("seed") is not None else sampling_params.get("noise_seed")
params["steps"] = sampling_params.get("steps")
params["cfg_scale"] = sampling_params.get("cfg")
params["sampler"] = sampling_params.get("sampler_name")
params["scheduler"] = sampling_params.get("scheduler")
# Trace connections from the primary sampler
if prompt and primary_sampler_id:
# Trace positive prompt - look specifically for CLIPTextEncode
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
# Find any FluxGuidance nodes in the positive conditioning path
flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5)
if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}):
flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {})
params["guidance"] = flux_params.get("guidance")
# Trace negative prompt - look specifically for CLIPTextEncode
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10)
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
# Check if the sampler itself has size information (from latent_image)
if primary_sampler_id in metadata.get(SIZE, {}):
width = metadata[SIZE][primary_sampler_id].get("width")
height = metadata[SIZE][primary_sampler_id].get("height")
if width and height:
params["size"] = f"{width}x{height}"
else:
# Fallback to the previous trace method if needed
latent_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "latent_image")
if latent_node_id:
# Follow chain to find EmptyLatentImage node
size_found = False
current_node_id = latent_node_id
# Limit depth to avoid infinite loops in complex workflows
max_depth = 10
for _ in range(max_depth):
if current_node_id in metadata.get(SIZE, {}):
width = metadata[SIZE][current_node_id].get("width")
height = metadata[SIZE][current_node_id].get("height")
if width and height:
params["size"] = f"{width}x{height}"
size_found = True
break
# Try to follow the chain
if prompt and prompt.original_prompt and current_node_id in prompt.original_prompt:
node_info = prompt.original_prompt[current_node_id]
if "inputs" in node_info:
# Look for a connection that might lead to size information
for input_name, input_value in node_info["inputs"].items():
if isinstance(input_value, list) and len(input_value) >= 2:
current_node_id = input_value[0]
break
else:
break # No connections to follow
else:
break # No inputs to follow
else:
break # Can't follow further
# Extract LoRAs using the standardized format
lora_parts = []
for node_id, lora_info in metadata.get(LORAS, {}).items():
# Access the lora_list from the standardized format
lora_list = lora_info.get("lora_list", [])
for lora in lora_list:
name = lora.get("name", "unknown")
strength = lora.get("strength", 1.0)
lora_parts.append(f"<lora:{name}:{strength}>")
params["loras"] = " ".join(lora_parts)
# Set default clip_skip value
params["clip_skip"] = "1" # Common default
return params
@staticmethod
def to_dict(metadata):
"""Convert extracted metadata to the ComfyUI output.json format"""
params = MetadataProcessor.extract_generation_params(metadata)
# Convert all values to strings to match output.json format
for key in params:
if params[key] is not None:
params[key] = str(params[key])
return params
@staticmethod
def to_json(metadata):
"""Convert metadata to JSON string"""
params = MetadataProcessor.to_dict(metadata)
return json.dumps(params, indent=4)

View File

@@ -0,0 +1,275 @@
import time
from nodes import NODE_CLASS_MAPPINGS
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
from .constants import METADATA_CATEGORIES, IMAGES
class MetadataRegistry:
"""A singleton registry to store and retrieve workflow metadata"""
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._reset()
return cls._instance
def _reset(self):
self.current_prompt_id = None
self.current_prompt = None
self.metadata = {}
self.prompt_metadata = {}
self.executed_nodes = set()
# Node-level cache for metadata
self.node_cache = {}
# Limit the number of stored prompts
self.max_prompt_history = 3
# Categories we want to track and retrieve from cache
self.metadata_categories = METADATA_CATEGORIES
def _clean_old_prompts(self):
"""Clean up old prompt metadata, keeping only recent ones"""
if len(self.prompt_metadata) <= self.max_prompt_history:
return
# Sort all prompt_ids by timestamp
sorted_prompts = sorted(
self.prompt_metadata.keys(),
key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0)
)
# Remove oldest records
prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history]
for pid in prompts_to_remove:
del self.prompt_metadata[pid]
def start_collection(self, prompt_id):
"""Begin metadata collection for a new prompt"""
self.current_prompt_id = prompt_id
self.executed_nodes = set()
self.prompt_metadata[prompt_id] = {
category: {} for category in METADATA_CATEGORIES
}
# Add additional metadata fields
self.prompt_metadata[prompt_id].update({
"execution_order": [],
"current_prompt": None, # Will store the prompt object
"timestamp": time.time()
})
# Clean up old prompt data
self._clean_old_prompts()
def set_current_prompt(self, prompt):
"""Set the current prompt object reference"""
self.current_prompt = prompt
if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata:
# Store the prompt in the metadata for later relationship tracing
self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt
def get_metadata(self, prompt_id=None):
"""Get collected metadata for a prompt"""
key = prompt_id if prompt_id is not None else self.current_prompt_id
if key not in self.prompt_metadata:
return {}
metadata = self.prompt_metadata[key]
# If we have a current prompt object, check for non-executed nodes
prompt_obj = metadata.get("current_prompt")
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
original_prompt = prompt_obj.original_prompt
# Fill in missing metadata from cache for nodes that weren't executed
self._fill_missing_metadata(key, original_prompt)
return self.prompt_metadata.get(key, {})
def _fill_missing_metadata(self, prompt_id, original_prompt):
"""Fill missing metadata from cache for non-executed nodes"""
if not original_prompt:
return
executed_nodes = self.executed_nodes
metadata = self.prompt_metadata[prompt_id]
# Iterate through nodes in the original prompt
for node_id, node_data in original_prompt.items():
# Skip if already executed in this run
if node_id in executed_nodes:
continue
# Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS)
prompt_class_type = node_data.get("class_type")
if not prompt_class_type:
continue
# Convert to actual class name (which is what we use in our cache)
class_type = prompt_class_type
if prompt_class_type in NODE_CLASS_MAPPINGS:
class_obj = NODE_CLASS_MAPPINGS[prompt_class_type]
class_type = class_obj.__name__
# Create cache key using the actual class name
cache_key = f"{node_id}:{class_type}"
# Check if this node type is relevant for metadata collection
if class_type in NODE_EXTRACTORS:
# Check if we have cached metadata for this node
if cache_key in self.node_cache:
cached_data = self.node_cache[cache_key]
# Apply cached metadata to the current metadata
for category in self.metadata_categories:
if category in cached_data and node_id in cached_data[category]:
if node_id not in metadata[category]:
metadata[category][node_id] = cached_data[category][node_id]
def record_node_execution(self, node_id, class_type, inputs, outputs):
"""Record information about a node's execution"""
if not self.current_prompt_id:
return
# Add to execution order and mark as executed
if node_id not in self.executed_nodes:
self.executed_nodes.add(node_id)
self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id)
# Process inputs to simplify working with them
processed_inputs = {}
for input_name, input_values in inputs.items():
if isinstance(input_values, list) and len(input_values) > 0:
# For single values, just use the first one (most common case)
processed_inputs[input_name] = input_values[0]
else:
processed_inputs[input_name] = input_values
# Extract node-specific metadata
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
extractor.extract(
node_id,
processed_inputs,
outputs,
self.prompt_metadata[self.current_prompt_id]
)
# Cache this node's metadata
self._cache_node_metadata(node_id, class_type)
def update_node_execution(self, node_id, class_type, outputs):
"""Update node metadata with output information"""
if not self.current_prompt_id:
return
# Process outputs to make them more usable
processed_outputs = outputs
# Use the same extractor to update with outputs
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
if hasattr(extractor, 'update'):
extractor.update(
node_id,
processed_outputs,
self.prompt_metadata[self.current_prompt_id]
)
# Update the cached metadata for this node
self._cache_node_metadata(node_id, class_type)
def _cache_node_metadata(self, node_id, class_type):
"""Cache the metadata for a specific node"""
if not self.current_prompt_id or not node_id or not class_type:
return
# Create a cache key combining node_id and class_type
cache_key = f"{node_id}:{class_type}"
# Create a shallow copy of the node's metadata
node_metadata = {}
current_metadata = self.prompt_metadata[self.current_prompt_id]
for category in self.metadata_categories:
if category in current_metadata and node_id in current_metadata[category]:
if category not in node_metadata:
node_metadata[category] = {}
node_metadata[category][node_id] = current_metadata[category][node_id]
# Save to cache if we have any metadata for this node
if any(node_metadata.values()):
self.node_cache[cache_key] = node_metadata
def clear_unused_cache(self):
"""Clean up node_cache entries that are no longer in use"""
# Collect all node_ids currently in prompt_metadata
active_node_ids = set()
for prompt_data in self.prompt_metadata.values():
for category in self.metadata_categories:
if category in prompt_data:
active_node_ids.update(prompt_data[category].keys())
# Find cache keys that are no longer needed
keys_to_remove = []
for cache_key in self.node_cache:
node_id = cache_key.split(':')[0]
if node_id not in active_node_ids:
keys_to_remove.append(cache_key)
# Remove cache entries that are no longer needed
for key in keys_to_remove:
del self.node_cache[key]
def clear_metadata(self, prompt_id=None):
"""Clear metadata for a specific prompt or reset all data"""
if prompt_id is not None:
if prompt_id in self.prompt_metadata:
del self.prompt_metadata[prompt_id]
# Clean up cache after removing prompt
self.clear_unused_cache()
else:
# Reset all data
self._reset()
def get_first_decoded_image(self, prompt_id=None):
"""Get the first decoded image result"""
key = prompt_id if prompt_id is not None else self.current_prompt_id
if key not in self.prompt_metadata:
return None
metadata = self.prompt_metadata[key]
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
image_data = metadata[IMAGES]["first_decode"]["image"]
# If it's an image batch or tuple, handle various formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
# Return first element of list/tuple
return image_data[0]
# If it's a tensor, return as is for processing in the route handler
return image_data
# If no image is found in the current metadata, try to find it in the cache
# This handles the case where VAEDecode was cached by ComfyUI and not executed
prompt_obj = metadata.get("current_prompt")
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
original_prompt = prompt_obj.original_prompt
for node_id, node_data in original_prompt.items():
class_type = node_data.get("class_type")
if class_type and class_type in NODE_CLASS_MAPPINGS:
class_obj = NODE_CLASS_MAPPINGS[class_type]
class_name = class_obj.__name__
# Check if this is a VAEDecode node
if class_name == "VAEDecode":
# Try to find this node in the cache
cache_key = f"{node_id}:{class_name}"
if cache_key in self.node_cache:
cached_data = self.node_cache[cache_key]
if IMAGES in cached_data and node_id in cached_data[IMAGES]:
image_data = cached_data[IMAGES][node_id]["image"]
# Handle different image formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
return image_data[0]
return image_data
return None

View File

@@ -0,0 +1,280 @@
import os
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES
class NodeMetadataExtractor:
"""Base class for node-specific metadata extraction"""
@staticmethod
def extract(node_id, inputs, outputs, metadata):
"""Extract metadata from node inputs/outputs"""
pass
@staticmethod
def update(node_id, outputs, metadata):
"""Update metadata with node outputs after execution"""
pass
class GenericNodeExtractor(NodeMetadataExtractor):
"""Default extractor for nodes without specific handling"""
@staticmethod
def extract(node_id, inputs, outputs, metadata):
pass
class CheckpointLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "ckpt_name" not in inputs:
return
model_name = inputs.get("ckpt_name")
if model_name:
metadata[MODELS][node_id] = {
"name": model_name,
"type": "checkpoint",
"node_id": node_id
}
class CLIPTextEncodeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "text" not in inputs:
return
text = inputs.get("text", "")
metadata[PROMPTS][node_id] = {
"text": text,
"node_id": node_id
}
class SamplerExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
sampling_params = {}
for key in ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]:
if key in inputs:
sampling_params[key] = inputs[key]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
if isinstance(latent, dict) and "samples" in latent:
# Extract dimensions from latent tensor
samples = latent["samples"]
if hasattr(samples, "shape") and len(samples.shape) >= 3:
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
# Multiply by 8 to get actual pixel dimensions
height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class KSamplerAdvancedExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
sampling_params = {}
for key in ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]:
if key in inputs:
sampling_params[key] = inputs[key]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
if isinstance(latent, dict) and "samples" in latent:
# Extract dimensions from latent tensor
samples = latent["samples"]
if hasattr(samples, "shape") and len(samples.shape) >= 3:
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
# Multiply by 8 to get actual pixel dimensions
height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class LoraLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "lora_name" not in inputs:
return
lora_name = inputs.get("lora_name")
# Extract base filename without extension from path
lora_name = os.path.splitext(os.path.basename(lora_name))[0]
strength_model = round(float(inputs.get("strength_model", 1.0)), 2)
# Use the standardized format with lora_list
metadata[LORAS][node_id] = {
"lora_list": [
{
"name": lora_name,
"strength": strength_model
}
],
"node_id": node_id
}
class ImageSizeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
width = inputs.get("width", 512)
height = inputs.get("height", 512)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class LoraLoaderManagerExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
active_loras = []
# Process lora_stack if available
if "lora_stack" in inputs:
lora_stack = inputs.get("lora_stack", [])
for lora_path, model_strength, clip_strength in lora_stack:
# Extract lora name from path (following the format in lora_loader.py)
lora_name = os.path.splitext(os.path.basename(lora_path))[0]
active_loras.append({
"name": lora_name,
"strength": model_strength
})
# Process loras from inputs
if "loras" in inputs:
loras_data = inputs.get("loras", [])
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
loras_list = loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
# Filter for active loras
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", True) and not lora.get("_isDummy", False):
active_loras.append({
"name": lora.get("name", ""),
"strength": float(lora.get("strength", 1.0))
})
if active_loras:
metadata[LORAS][node_id] = {
"lora_list": active_loras,
"node_id": node_id
}
class FluxGuidanceExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "guidance" not in inputs:
return
guidance_value = inputs.get("guidance")
# Store the guidance value in SAMPLING category
if node_id not in metadata[SAMPLING]:
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value
class UNETLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "unet_name" not in inputs:
return
model_name = inputs.get("unet_name")
if model_name:
metadata[MODELS][node_id] = {
"name": model_name,
"type": "checkpoint",
"node_id": node_id
}
class VAEDecodeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
pass
@staticmethod
def update(node_id, outputs, metadata):
# Ensure IMAGES category exists
if IMAGES not in metadata:
metadata[IMAGES] = {}
# Save image data under node ID index to be captured by caching mechanism
metadata[IMAGES][node_id] = {
"node_id": node_id,
"image": outputs
}
# Only set first_decode if it hasn't been recorded yet
if "first_decode" not in metadata[IMAGES]:
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
# Registry of node-specific extractors
NODE_EXTRACTORS = {
# Sampling
"KSampler": SamplerExtractor,
"KSamplerAdvanced": KSamplerAdvancedExtractor, # Add KSamplerAdvanced
"SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced
# Loaders
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
"LoraLoader": LoraLoaderExtractor,
"LoraManagerLoader": LoraLoaderManagerExtractor,
# Conditioning
"CLIPTextEncode": CLIPTextEncodeExtractor,
# Latent
"EmptyLatentImage": ImageSizeExtractor,
# Flux
"FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance
# Image
"VAEDecode": VAEDecodeExtractor, # Added VAEDecode extractor
# Add other nodes as needed
}

View File

@@ -0,0 +1,35 @@
import logging
from ..metadata_collector.metadata_processor import MetadataProcessor
logger = logging.getLogger(__name__)
class DebugMetadata:
NAME = "Debug Metadata (LoraManager)"
CATEGORY = "Lora Manager/utils"
DESCRIPTION = "Debug node to verify metadata_processor functionality"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("metadata_json",)
FUNCTION = "process_metadata"
def process_metadata(self, images):
try:
# Get the current execution context's metadata
from ..metadata_collector import get_metadata
metadata = get_metadata()
# Use the MetadataProcessor to convert it to JSON string
metadata_json = MetadataProcessor.to_json(metadata)
return (metadata_json,)
except Exception as e:
logger.error(f"Error processing metadata: {e}")
return ("{}",) # Return empty JSON object in case of error

View File

@@ -5,10 +5,10 @@ import re
import numpy as np
import folder_paths # type: ignore
from ..services.lora_scanner import LoraScanner
from ..workflow.parser import WorkflowParser
from ..metadata_collector.metadata_processor import MetadataProcessor
from ..metadata_collector import get_metadata
from PIL import Image, PngImagePlugin
import piexif
from io import BytesIO
class SaveImage:
NAME = "Save Image (LoraManager)"
@@ -34,8 +34,7 @@ class SaveImage:
"file_format": (["png", "jpeg", "webp"],),
},
"optional": {
"custom_prompt": ("STRING", {"default": "", "forceInput": True}),
"lossless_webp": ("BOOLEAN", {"default": True}),
"lossless_webp": ("BOOLEAN", {"default": False}),
"quality": ("INT", {"default": 100, "min": 1, "max": 100}),
"embed_workflow": ("BOOLEAN", {"default": False}),
"add_counter_to_filename": ("BOOLEAN", {"default": True}),
@@ -61,21 +60,17 @@ class SaveImage:
return item.get('sha256')
return None
async def format_metadata(self, parsed_workflow, custom_prompt=None):
async def format_metadata(self, metadata_dict):
"""Format metadata in the requested format similar to userComment example"""
if not parsed_workflow:
if not metadata_dict:
return ""
# Extract the prompt and negative prompt
prompt = parsed_workflow.get('prompt', '')
negative_prompt = parsed_workflow.get('negative_prompt', '')
# Override prompt with custom_prompt if provided
if custom_prompt:
prompt = custom_prompt
prompt = metadata_dict.get('prompt', '')
negative_prompt = metadata_dict.get('negative_prompt', '')
# Extract loras from the prompt if present
loras_text = parsed_workflow.get('loras', '')
loras_text = metadata_dict.get('loras', '')
lora_hashes = {}
# If loras are found, add them on a new line after the prompt
@@ -104,11 +99,11 @@ class SaveImage:
params = []
# Add standard parameters in the correct order
if 'steps' in parsed_workflow:
params.append(f"Steps: {parsed_workflow.get('steps')}")
if 'steps' in metadata_dict:
params.append(f"Steps: {metadata_dict.get('steps')}")
if 'sampler' in parsed_workflow:
sampler = parsed_workflow.get('sampler')
if 'sampler' in metadata_dict:
sampler = metadata_dict.get('sampler')
# Convert ComfyUI sampler names to user-friendly names
sampler_mapping = {
'euler': 'Euler',
@@ -130,8 +125,8 @@ class SaveImage:
sampler_name = sampler_mapping.get(sampler, sampler)
params.append(f"Sampler: {sampler_name}")
if 'scheduler' in parsed_workflow:
scheduler = parsed_workflow.get('scheduler')
if 'scheduler' in metadata_dict:
scheduler = metadata_dict.get('scheduler')
scheduler_mapping = {
'normal': 'Simple',
'karras': 'Karras',
@@ -142,27 +137,36 @@ class SaveImage:
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
params.append(f"Schedule type: {scheduler_name}")
# CFG scale (cfg in parsed_workflow)
if 'cfg_scale' in parsed_workflow:
params.append(f"CFG scale: {parsed_workflow.get('cfg_scale')}")
elif 'cfg' in parsed_workflow:
params.append(f"CFG scale: {parsed_workflow.get('cfg')}")
# CFG scale (cfg_scale in metadata_dict)
if 'cfg_scale' in metadata_dict:
params.append(f"CFG scale: {metadata_dict.get('cfg_scale')}")
elif 'cfg' in metadata_dict:
params.append(f"CFG scale: {metadata_dict.get('cfg')}")
# Seed
if 'seed' in parsed_workflow:
params.append(f"Seed: {parsed_workflow.get('seed')}")
if 'seed' in metadata_dict:
params.append(f"Seed: {metadata_dict.get('seed')}")
# Size
if 'size' in parsed_workflow:
params.append(f"Size: {parsed_workflow.get('size')}")
if 'size' in metadata_dict:
params.append(f"Size: {metadata_dict.get('size')}")
# Model info
if 'checkpoint' in parsed_workflow:
# Extract basename without path
checkpoint = os.path.basename(parsed_workflow.get('checkpoint', ''))
# Remove extension if present
checkpoint = os.path.splitext(checkpoint)[0]
params.append(f"Model: {checkpoint}")
if 'checkpoint' in metadata_dict:
# Ensure checkpoint is a string before processing
checkpoint = metadata_dict.get('checkpoint')
if checkpoint is not None:
# Handle both string and other types safely
if isinstance(checkpoint, str):
# Extract basename without path
checkpoint = os.path.basename(checkpoint)
# Remove extension if present
checkpoint = os.path.splitext(checkpoint)[0]
else:
# Convert non-string to string
checkpoint = str(checkpoint)
params.append(f"Model: {checkpoint}")
# Add LoRA hashes if available
if lora_hashes:
@@ -181,9 +185,9 @@ class SaveImage:
# credit to nkchocoai
# Add format_filename method to handle pattern substitution
def format_filename(self, filename, parsed_workflow):
def format_filename(self, filename, metadata_dict):
"""Format filename with metadata values"""
if not parsed_workflow:
if not metadata_dict:
return filename
result = re.findall(self.pattern_format, filename)
@@ -191,30 +195,30 @@ class SaveImage:
parts = segment.replace("%", "").split(":")
key = parts[0]
if key == "seed" and 'seed' in parsed_workflow:
filename = filename.replace(segment, str(parsed_workflow.get('seed', '')))
elif key == "width" and 'size' in parsed_workflow:
size = parsed_workflow.get('size', 'x')
if key == "seed" and 'seed' in metadata_dict:
filename = filename.replace(segment, str(metadata_dict.get('seed', '')))
elif key == "width" and 'size' in metadata_dict:
size = metadata_dict.get('size', 'x')
w = size.split('x')[0] if isinstance(size, str) else size[0]
filename = filename.replace(segment, str(w))
elif key == "height" and 'size' in parsed_workflow:
size = parsed_workflow.get('size', 'x')
elif key == "height" and 'size' in metadata_dict:
size = metadata_dict.get('size', 'x')
h = size.split('x')[1] if isinstance(size, str) else size[1]
filename = filename.replace(segment, str(h))
elif key == "pprompt" and 'prompt' in parsed_workflow:
prompt = parsed_workflow.get('prompt', '').replace("\n", " ")
elif key == "pprompt" and 'prompt' in metadata_dict:
prompt = metadata_dict.get('prompt', '').replace("\n", " ")
if len(parts) >= 2:
length = int(parts[1])
prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip())
elif key == "nprompt" and 'negative_prompt' in parsed_workflow:
prompt = parsed_workflow.get('negative_prompt', '').replace("\n", " ")
elif key == "nprompt" and 'negative_prompt' in metadata_dict:
prompt = metadata_dict.get('negative_prompt', '').replace("\n", " ")
if len(parts) >= 2:
length = int(parts[1])
prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip())
elif key == "model" and 'checkpoint' in parsed_workflow:
model = parsed_workflow.get('checkpoint', '')
elif key == "model" and 'checkpoint' in metadata_dict:
model = metadata_dict.get('checkpoint', '')
model = os.path.splitext(os.path.basename(model))[0]
if len(parts) >= 2:
length = int(parts[1])
@@ -224,12 +228,13 @@ class SaveImage:
from datetime import datetime
now = datetime.now()
date_table = {
"yyyy": str(now.year),
"MM": str(now.month).zfill(2),
"dd": str(now.day).zfill(2),
"hh": str(now.hour).zfill(2),
"mm": str(now.minute).zfill(2),
"ss": str(now.second).zfill(2),
"yyyy": f"{now.year:04d}",
"yy": f"{now.year % 100:02d}",
"MM": f"{now.month:02d}",
"dd": f"{now.day:02d}",
"hh": f"{now.hour:02d}",
"mm": f"{now.minute:02d}",
"ss": f"{now.second:02d}",
}
if len(parts) >= 2:
date_format = parts[1]
@@ -245,23 +250,19 @@ class SaveImage:
return filename
def save_images(self, images, filename_prefix, file_format, prompt=None, extra_pnginfo=None,
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True,
custom_prompt=None):
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
"""Save images with metadata"""
results = []
# Parse the workflow using the WorkflowParser
parser = WorkflowParser()
if prompt:
parsed_workflow = parser.parse_workflow(prompt)
else:
parsed_workflow = {}
# Get metadata using the metadata collector
raw_metadata = get_metadata()
metadata_dict = MetadataProcessor.to_dict(raw_metadata)
# Get or create metadata asynchronously
metadata = asyncio.run(self.format_metadata(parsed_workflow, custom_prompt))
metadata = asyncio.run(self.format_metadata(metadata_dict))
# Process filename_prefix with pattern substitution
filename_prefix = self.format_filename(filename_prefix, parsed_workflow)
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
# Get initial save path info once for the batch
full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path(
@@ -289,7 +290,8 @@ class SaveImage:
if file_format == "png":
file = base_filename + ".png"
file_extension = ".png"
save_kwargs = {"optimize": True, "compress_level": self.compress_level}
# Remove "optimize": True to match built-in node behavior
save_kwargs = {"compress_level": self.compress_level}
pnginfo = PngImagePlugin.PngInfo()
elif file_format == "jpeg":
file = base_filename + ".jpg"
@@ -298,7 +300,8 @@ class SaveImage:
elif file_format == "webp":
file = base_filename + ".webp"
file_extension = ".webp"
save_kwargs = {"quality": quality, "lossless": lossless_webp}
# Add optimization param to control performance
save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0}
# Full save path
file_path = os.path.join(full_output_folder, file)
@@ -346,8 +349,7 @@ class SaveImage:
return results
def process_image(self, images, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None,
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True,
custom_prompt=""):
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
"""Process and save image with metadata"""
# Make sure the output directory exists
os.makedirs(self.output_dir, exist_ok=True)
@@ -368,8 +370,7 @@ class SaveImage:
lossless_webp,
quality,
embed_workflow,
add_counter_to_filename,
custom_prompt if custom_prompt.strip() else None
add_counter_to_filename
)
return (images,)

View File

@@ -50,8 +50,8 @@ class ApiRoutes:
app.router.add_get('/api/lora-roots', routes.get_lora_roots)
app.router.add_get('/api/folders', routes.get_folders)
app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions)
app.router.add_get('/api/civitai/model/{modelVersionId}', routes.get_civitai_model)
app.router.add_get('/api/civitai/model/{hash}', routes.get_civitai_model)
app.router.add_get('/api/civitai/model/version/{modelVersionId}', routes.get_civitai_model_by_version)
app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash)
app.router.add_post('/api/download-lora', routes.download_lora)
app.router.add_post('/api/settings', routes.update_settings)
app.router.add_post('/api/move_model', routes.move_model)
@@ -226,7 +226,7 @@ class ApiRoutes:
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
preserve_metadata=False
)
extension = '.webp' # Use .webp without .preview part
@@ -396,25 +396,52 @@ class ApiRoutes:
logger.error(f"Error fetching model versions: {e}")
return web.Response(status=500, text=str(e))
async def get_civitai_model(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID or hash"""
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID"""
try:
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
model_version_id = request.match_info.get('modelVersionId')
if not model_version_id:
hash = request.match_info.get('hash')
model = await self.civitai_client.get_model_by_hash(hash)
return web.json_response(model)
# Get model details from Civitai API
model = await self.civitai_client.get_model_version_info(model_version_id)
model, error_msg = await self.civitai_client.get_model_version_info(model_version_id)
if not model:
# Log warning for failed model retrieval
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
# Determine status code based on error message
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
return web.json_response({
"success": False,
"error": error_msg or "Failed to fetch model information"
}, status=status_code)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details: {e}")
return web.Response(status=500, text=str(e))
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by hash"""
try:
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
hash = request.match_info.get('hash')
model = await self.civitai_client.get_model_by_hash(hash)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details by hash: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def download_lora(self, request: web.Request) -> web.Response:
async with self._download_lock:
@@ -773,7 +800,7 @@ class ApiRoutes:
logger.info(f"Fetching model metadata for model ID: {model_id}")
model_metadata, _ = await self.civitai_client.get_model_metadata(model_id)
if model_metadata:
if (model_metadata):
description = model_metadata.get('description')
tags = model_metadata.get('tags', [])

View File

@@ -1,5 +1,9 @@
import os
import time
import numpy as np
from PIL import Image
import torch
import io
import logging
from aiohttp import web
from typing import Dict
@@ -11,9 +15,11 @@ from ..utils.recipe_parsers import RecipeParserFactory
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..config import config
from ..workflow.parser import WorkflowParser
from ..metadata_collector import get_metadata # Add MetadataCollector import
from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import
from ..utils.utils import download_civitai_image
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
from ..metadata_collector.metadata_registry import MetadataRegistry
logger = logging.getLogger(__name__)
@@ -24,7 +30,7 @@ class RecipeRoutes:
# Initialize service references as None, will be set during async init
self.recipe_scanner = None
self.civitai_client = None
self.parser = WorkflowParser()
# Remove WorkflowParser instance
# Pre-warm the cache
self._init_cache_task = None
@@ -656,8 +662,8 @@ class RecipeRoutes:
logger.error(f"Error retrieving base models: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
'error': str(e)}
, status=500)
async def share_recipe(self, request: web.Request) -> web.Response:
"""Process a recipe image for sharing by adding metadata to EXIF"""
@@ -786,50 +792,72 @@ class RecipeRoutes:
# Ensure services are initialized
await self.init_services()
reader = await request.multipart()
# Get metadata using the metadata collector instead of workflow parsing
raw_metadata = get_metadata()
metadata_dict = MetadataProcessor.to_dict(raw_metadata)
# Process form data
workflow_json = None
# Check if we have valid metadata
if not metadata_dict:
return web.json_response({"error": "No generation metadata found"}, status=400)
while True:
field = await reader.next()
if field is None:
break
# Get the most recent image from metadata registry instead of temp directory
metadata_registry = MetadataRegistry()
latest_image = metadata_registry.get_first_decoded_image()
if not latest_image:
return web.json_response({"error": "No recent images found to use for recipe. Try generating an image first."}, status=400)
# Convert the image data to bytes - handle tuple and tensor cases
logger.debug(f"Image type: {type(latest_image)}")
try:
# Handle the tuple case first
if isinstance(latest_image, tuple):
# Extract the tensor from the tuple
if len(latest_image) > 0:
tensor_image = latest_image[0]
else:
return web.json_response({"error": "Empty image tuple received"}, status=400)
else:
tensor_image = latest_image
if field.name == 'workflow_json':
workflow_text = await field.text()
try:
workflow_json = json.loads(workflow_text)
except:
return web.json_response({"error": "Invalid workflow JSON"}, status=400)
# Get the shape info for debugging
if hasattr(tensor_image, 'shape'):
shape_info = tensor_image.shape
logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
# Convert tensor to numpy array
if isinstance(tensor_image, torch.Tensor):
image_np = tensor_image.cpu().numpy()
else:
image_np = np.array(tensor_image)
# Handle different tensor shapes
# Case: (1, 1, H, W, 3) or (1, H, W, 3) - batch or multi-batch
if len(image_np.shape) > 3:
# Remove batch dimensions until we get to (H, W, 3)
while len(image_np.shape) > 3:
image_np = image_np[0]
# If values are in [0, 1] range, convert to [0, 255]
if image_np.dtype == np.float32 or image_np.dtype == np.float64:
if image_np.max() <= 1.0:
image_np = (image_np * 255).astype(np.uint8)
# Ensure image is in the right format (HWC with RGB channels)
if len(image_np.shape) == 3 and image_np.shape[2] == 3:
pil_image = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
pil_image.save(img_byte_arr, format='PNG')
image = img_byte_arr.getvalue()
else:
return web.json_response({"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, status=400)
except Exception as e:
logger.error(f"Error processing image data: {str(e)}", exc_info=True)
return web.json_response({"error": f"Error processing image: {str(e)}"}, status=400)
if not workflow_json:
return web.json_response({"error": "Missing workflow JSON"}, status=400)
# Find the latest image in the temp directory
temp_dir = config.temp_directory
image_files = []
for file in os.listdir(temp_dir):
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
file_path = os.path.join(temp_dir, file)
image_files.append((file_path, os.path.getmtime(file_path)))
if not image_files:
return web.json_response({"error": "No recent images found to use for recipe"}, status=400)
# Sort by modification time (newest first)
image_files.sort(key=lambda x: x[1], reverse=True)
latest_image_path = image_files[0][0]
# Parse the workflow to extract generation parameters and loras
parsed_workflow = self.parser.parse_workflow(workflow_json)
if not parsed_workflow:
return web.json_response({"error": "Could not extract parameters from workflow"}, status=400)
# Get the lora stack from the parsed workflow
lora_stack = parsed_workflow.get("loras", "")
# Get the lora stack from the metadata
lora_stack = metadata_dict.get("loras", "")
# Parse the lora stack format: "<lora:name:strength> <lora:name2:strength2> ..."
import re
@@ -837,7 +865,7 @@ class RecipeRoutes:
# Check if any loras were found
if not lora_matches:
return web.json_response({"error": "No LoRAs found in the workflow"}, status=400)
return web.json_response({"error": "No LoRAs found in the generation metadata"}, status=400)
# Generate recipe name from the first 3 loras (or less if fewer are available)
loras_for_name = lora_matches[:3] # Take at most 3 loras for the name
@@ -851,10 +879,6 @@ class RecipeRoutes:
recipe_name = " ".join(recipe_name_parts)
# Read the image
with open(latest_image_path, 'rb') as f:
image = f.read()
# Create recipes directory if it doesn't exist
recipes_dir = self.recipe_scanner.recipes_dir
os.makedirs(recipes_dir, exist_ok=True)
@@ -922,8 +946,8 @@ class RecipeRoutes:
"created_date": time.time(),
"base_model": most_common_base_model,
"loras": loras_data,
"checkpoint": parsed_workflow.get("checkpoint", ""),
"gen_params": {key: value for key, value in parsed_workflow.items()
"checkpoint": metadata_dict.get("checkpoint", ""),
"gen_params": {key: value for key, value in metadata_dict.items()
if key not in ['checkpoint', 'loras']},
"loras_stack": lora_stack # Include the original lora stack string
}

View File

@@ -210,8 +210,17 @@ class CivitaiClient:
logger.error(f"Error fetching model versions: {e}")
return None
async def get_model_version_info(self, version_id: str) -> Optional[Dict]:
"""Fetch model version metadata from Civitai"""
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from Civitai
Args:
version_id: The Civitai model version ID
Returns:
Tuple[Optional[Dict], Optional[str]]: A tuple containing:
- The model version data or None if not found
- An error message if there was an error, or None on success
"""
try:
session = await self.session
url = f"{self.base_url}/model-versions/{version_id}"
@@ -219,11 +228,25 @@ class CivitaiClient:
async with session.get(url, headers=headers) as response:
if response.status == 200:
return await response.json()
return None
return await response.json(), None
# Handle specific error cases
if response.status == 404:
# Try to parse the error message
try:
error_data = await response.json()
error_msg = error_data.get('error', f"Model not found (status 404)")
logger.warning(f"Model version not found: {version_id} - {error_msg}")
return None, error_msg
except:
return None, "Model not found (status 404)"
# Other error cases
return None, f"Failed to fetch model info (status {response.status})"
except Exception as e:
logger.error(f"Error fetching model version info: {e}")
return None
error_msg = f"Error fetching model version info: {e}"
logger.error(error_msg)
return None, error_msg
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
"""Fetch model metadata (description and tags) from Civitai API

View File

@@ -86,21 +86,24 @@ class DownloadManager:
# Get version info based on the provided identifier
version_info = None
error_msg = None
if download_url:
# Extract version ID from download URL
version_id = download_url.split('/')[-1]
version_info = await civitai_client.get_model_version_info(version_id)
version_info, error_msg = await civitai_client.get_model_version_info(version_id)
elif model_version_id:
# Use model version ID directly
version_info = await civitai_client.get_model_version_info(model_version_id)
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
elif model_hash:
# Get model by hash
version_info = await civitai_client.get_model_by_hash(model_hash)
if not version_info:
return {'success': False, 'error': 'Failed to fetch model metadata'}
if error_msg and "model not found" in error_msg.lower():
return {'success': False, 'error': f'Model not found on Civitai: {error_msg}'}
return {'success': False, 'error': error_msg or 'Failed to fetch model metadata'}
# Check if this is an early access model
if version_info.get('earlyAccessEndsAt'):
@@ -202,7 +205,7 @@ class DownloadManager:
# Check if it's a video or an image
is_video = images[0].get('type') == 'video'
if is_video:
if (is_video):
# For videos, use .mp4 extension
preview_ext = '.mp4'
preview_path = os.path.splitext(save_path)[0] + preview_ext
@@ -229,7 +232,7 @@ class DownloadManager:
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
preserve_metadata=False
)
# Save the optimized image

View File

@@ -408,7 +408,7 @@ class BaseFileMonitor:
def start(self):
"""Start file monitoring"""
if not ENABLE_FILE_MONITORING:
logger.info("File monitoring is disabled via ENABLE_FILE_MONITORING setting")
logger.debug("File monitoring is disabled via ENABLE_FILE_MONITORING setting")
return
for path in self.monitor_paths:
@@ -525,18 +525,18 @@ class CheckpointFileMonitor(BaseFileMonitor):
def start(self):
"""Override start to check global enable flag"""
if not ENABLE_FILE_MONITORING:
logger.info("Checkpoint file monitoring is disabled via ENABLE_FILE_MONITORING setting")
logger.debug("Checkpoint file monitoring is disabled via ENABLE_FILE_MONITORING setting")
return
logger.info("Checkpoint file monitoring is temporarily disabled")
logger.debug("Checkpoint file monitoring is temporarily disabled")
# Skip the actual monitoring setup
pass
async def initialize_paths(self):
"""Initialize monitor paths from scanner - currently disabled"""
if not ENABLE_FILE_MONITORING:
logger.info("Checkpoint path initialization skipped (monitoring disabled)")
logger.debug("Checkpoint path initialization skipped (monitoring disabled)")
return
logger.info("Checkpoint file path initialization skipped (monitoring disabled)")
logger.debug("Checkpoint file path initialization skipped (monitoring disabled)")
pass

View File

@@ -341,6 +341,10 @@ class RecipeScanner:
metadata_updated = False
for lora in recipe_data['loras']:
# Skip deleted loras that were already marked
if lora.get('isDeleted', False):
continue
# Skip if already has complete information
if 'hash' in lora and 'file_name' in lora and lora['file_name']:
continue
@@ -356,10 +360,17 @@ class RecipeScanner:
metadata_updated = True
else:
# If not in cache, fetch from Civitai
hash_from_civitai = await self._get_hash_from_civitai(model_version_id)
if hash_from_civitai:
lora['hash'] = hash_from_civitai
metadata_updated = True
result = await self._get_hash_from_civitai(model_version_id)
if isinstance(result, tuple):
hash_from_civitai, is_deleted = result
if hash_from_civitai:
lora['hash'] = hash_from_civitai
metadata_updated = True
elif is_deleted:
# Mark the lora as deleted if it was not found on Civitai
lora['isDeleted'] = True
logger.warning(f"Marked lora with modelVersionId {model_version_id} as deleted")
metadata_updated = True
else:
logger.debug(f"Could not get hash for modelVersionId {model_version_id}")
@@ -411,41 +422,26 @@ class RecipeScanner:
logger.error("Failed to get CivitaiClient from ServiceRegistry")
return None
version_info = await civitai_client.get_model_version_info(model_version_id)
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
if not version_info or not version_info.get('files'):
logger.debug(f"No files found in version info for ID: {model_version_id}")
return None
if not version_info:
if error_msg and "model not found" in error_msg.lower():
logger.warning(f"Model with version ID {model_version_id} was not found on Civitai - marking as deleted")
return None, True # Return None hash and True for isDeleted flag
else:
logger.debug(f"Could not get hash for modelVersionId {model_version_id}: {error_msg}")
return None, False # Return None hash but not marked as deleted
# Get hash from the first file
for file_info in version_info.get('files', []):
if file_info.get('hashes', {}).get('SHA256'):
return file_info['hashes']['SHA256']
return file_info['hashes']['SHA256'], False # Return hash with False for isDeleted flag
logger.debug(f"No SHA256 hash found in version info for ID: {model_version_id}")
return None
return None, False
except Exception as e:
logger.error(f"Error getting hash from Civitai: {e}")
return None
async def _get_model_version_name(self, model_version_id: str) -> Optional[str]:
"""Get model version name from Civitai API"""
try:
# Get CivitaiClient from ServiceRegistry
civitai_client = await self._get_civitai_client()
if not civitai_client:
return None
version_info = await civitai_client.get_model_version_info(model_version_id)
if version_info and 'name' in version_info:
return version_info['name']
logger.debug(f"No version name found for modelVersionId {model_version_id}")
return None
except Exception as e:
logger.error(f"Error getting model version name from Civitai: {e}")
return None
return None, False
async def _determine_base_model(self, loras: List[Dict]) -> Optional[str]:
"""Determine the most common base model among LoRAs"""

View File

@@ -203,7 +203,7 @@ class ExifUtils:
return user_comment[:recipe_marker_index] + user_comment[next_line_index:]
@staticmethod
def optimize_image(image_data, target_width=250, format='webp', quality=85, preserve_metadata=True):
def optimize_image(image_data, target_width=250, format='webp', quality=85, preserve_metadata=False):
"""
Optimize an image by resizing and converting to WebP format
@@ -218,98 +218,144 @@ class ExifUtils:
Tuple of (optimized_image_data, extension)
"""
try:
# Extract metadata if needed
# First validate the image data is usable
img = None
if isinstance(image_data, str) and os.path.exists(image_data):
# It's a file path - validate file
try:
with Image.open(image_data) as test_img:
# Verify the image can be fully loaded by accessing its size
width, height = test_img.size
# If we got here, the image is valid
img = Image.open(image_data)
except (IOError, OSError) as e:
logger.error(f"Invalid or corrupt image file: {image_data}: {e}")
raise ValueError(f"Cannot process corrupt image: {e}")
else:
# It's binary data - validate data
try:
with BytesIO(image_data) as temp_buf:
test_img = Image.open(temp_buf)
# Verify the image can be fully loaded
width, height = test_img.size
# If successful, reopen for processing
img = Image.open(BytesIO(image_data))
except Exception as e:
logger.error(f"Invalid binary image data: {e}")
raise ValueError(f"Cannot process corrupt image data: {e}")
# Extract metadata if needed and valid
metadata = None
if preserve_metadata:
if isinstance(image_data, str) and os.path.exists(image_data):
# It's a file path
metadata = ExifUtils.extract_image_metadata(image_data)
img = Image.open(image_data)
else:
# It's binary data
temp_img = BytesIO(image_data)
img = Image.open(temp_img)
# Save to a temporary file to extract metadata
import tempfile
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp_file:
temp_path = temp_file.name
temp_file.write(image_data)
metadata = ExifUtils.extract_image_metadata(temp_path)
os.unlink(temp_path)
else:
# Just open the image without extracting metadata
if isinstance(image_data, str) and os.path.exists(image_data):
img = Image.open(image_data)
else:
img = Image.open(BytesIO(image_data))
try:
if isinstance(image_data, str) and os.path.exists(image_data):
# For file path, extract directly
metadata = ExifUtils.extract_image_metadata(image_data)
else:
# For binary data, save to temp file first
import tempfile
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp_file:
temp_path = temp_file.name
temp_file.write(image_data)
try:
metadata = ExifUtils.extract_image_metadata(temp_path)
except Exception as e:
logger.warning(f"Failed to extract metadata from temp file: {e}")
finally:
# Clean up temp file
try:
os.unlink(temp_path)
except Exception:
pass
except Exception as e:
logger.warning(f"Failed to extract metadata, continuing without it: {e}")
# Continue without metadata
# Calculate new height to maintain aspect ratio
width, height = img.size
new_height = int(height * (target_width / width))
# Resize the image
resized_img = img.resize((target_width, new_height), Image.LANCZOS)
# Resize the image with error handling
try:
resized_img = img.resize((target_width, new_height), Image.LANCZOS)
except Exception as e:
logger.error(f"Failed to resize image: {e}")
# Return original image if resize fails
return image_data, '.jpg' if not isinstance(image_data, str) else os.path.splitext(image_data)[1]
# Save to BytesIO in the specified format
output = BytesIO()
# WebP format
# Set format and extension
if format.lower() == 'webp':
resized_img.save(output, format='WEBP', quality=quality)
extension = '.webp'
# JPEG format
save_format, extension = 'WEBP', '.webp'
elif format.lower() in ('jpg', 'jpeg'):
resized_img.save(output, format='JPEG', quality=quality)
extension = '.jpg'
# PNG format
save_format, extension = 'JPEG', '.jpg'
elif format.lower() == 'png':
resized_img.save(output, format='PNG', optimize=True)
extension = '.png'
save_format, extension = 'PNG', '.png'
else:
# Default to WebP
resized_img.save(output, format='WEBP', quality=quality)
extension = '.webp'
save_format, extension = 'WEBP', '.webp'
# Save with error handling
try:
if save_format == 'PNG':
resized_img.save(output, format=save_format, optimize=True)
else:
resized_img.save(output, format=save_format, quality=quality)
except Exception as e:
logger.error(f"Failed to save optimized image: {e}")
# Return original image if save fails
return image_data, '.jpg' if not isinstance(image_data, str) else os.path.splitext(image_data)[1]
# Get the optimized image data
optimized_data = output.getvalue()
# If we need to preserve metadata, write it to a temporary file
# Handle metadata preservation if requested and available
if preserve_metadata and metadata:
# For WebP format, we'll directly save with metadata
if format.lower() == 'webp':
# Create a new BytesIO with metadata
output_with_metadata = BytesIO()
# Create EXIF data with user comment
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}}
exif_bytes = piexif.dump(exif_dict)
# Save with metadata
resized_img.save(output_with_metadata, format='WEBP', exif=exif_bytes, quality=quality)
optimized_data = output_with_metadata.getvalue()
else:
# For other formats, use the temporary file approach
import tempfile
with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_file:
temp_path = temp_file.name
temp_file.write(optimized_data)
# Add the metadata back
ExifUtils.update_image_metadata(temp_path, metadata)
# Read the file with metadata
with open(temp_path, 'rb') as f:
optimized_data = f.read()
# Clean up
os.unlink(temp_path)
try:
if save_format == 'WEBP':
# For WebP format, directly save with metadata
try:
output_with_metadata = BytesIO()
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}}
exif_bytes = piexif.dump(exif_dict)
resized_img.save(output_with_metadata, format='WEBP', exif=exif_bytes, quality=quality)
optimized_data = output_with_metadata.getvalue()
except Exception as e:
logger.warning(f"Failed to add metadata to WebP, continuing without it: {e}")
else:
# For other formats, use temporary file
import tempfile
with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_file:
temp_path = temp_file.name
temp_file.write(optimized_data)
try:
# Add metadata
ExifUtils.update_image_metadata(temp_path, metadata)
# Read back the file
with open(temp_path, 'rb') as f:
optimized_data = f.read()
except Exception as e:
logger.warning(f"Failed to add metadata to image, continuing without it: {e}")
finally:
# Clean up temp file
try:
os.unlink(temp_path)
except Exception:
pass
except Exception as e:
logger.warning(f"Failed to preserve metadata: {e}, continuing with unmodified output")
return optimized_data, extension
except Exception as e:
logger.error(f"Error optimizing image: {e}", exc_info=True)
# Return original data if optimization fails
# Return original data if optimization completely fails
if isinstance(image_data, str) and os.path.exists(image_data):
with open(image_data, 'rb') as f:
return f.read(), os.path.splitext(image_data)[1]
try:
with open(image_data, 'rb') as f:
return f.read(), os.path.splitext(image_data)[1]
except Exception:
return image_data, '.jpg' # Last resort fallback
return image_data, '.jpg'

View File

@@ -42,7 +42,7 @@ def find_preview_file(base_name: str, dir_path: str) -> str:
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
preserve_metadata=False # Changed from True to False
)
# Save the optimized webp file

View File

@@ -95,7 +95,7 @@ class ModelRouteUtils:
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
preserve_metadata=False
)
# Save the optimized WebP image
@@ -387,7 +387,7 @@ class ModelRouteUtils:
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
preserve_metadata=False
)
extension = '.webp' # Use .webp without .preview part

View File

@@ -1,7 +1,7 @@
[project]
name = "comfyui-lora-manager"
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration."
version = "0.8.6"
version = "0.8.7"
license = {file = "LICENSE"}
dependencies = [
"aiohttp",

View File

@@ -2,7 +2,7 @@
import { state, getCurrentPageState } from '../state/index.js';
import { showToast } from '../utils/uiHelpers.js';
import { showDeleteModal, confirmDelete } from '../utils/modalUtils.js';
import { getSessionItem } from '../utils/storageHelpers.js';
import { getSessionItem, saveMapToStorage } from '../utils/storageHelpers.js';
/**
* Shared functionality for handling models (loras and checkpoints)
@@ -424,12 +424,20 @@ async function uploadPreview(filePath, file, modelType = 'lora') {
const previewContainer = card.querySelector('.card-preview');
const oldPreview = previewContainer.querySelector('img, video');
// For LoRA models, use timestamp to prevent caching
if (modelType === 'lora') {
state.previewVersions?.set(filePath, Date.now());
// Get the current page's previewVersions Map based on model type
const pageType = modelType === 'checkpoint' ? 'checkpoints' : 'loras';
const previewVersions = state.pages[pageType].previewVersions;
// Update the version timestamp
const timestamp = Date.now();
if (previewVersions) {
previewVersions.set(filePath, timestamp);
// Save the updated Map to localStorage
const storageKey = modelType === 'checkpoint' ? 'checkpoint_preview_versions' : 'lora_preview_versions';
saveMapToStorage(storageKey, previewVersions);
}
const timestamp = Date.now();
const previewUrl = data.preview_url ?
`${data.preview_url}?t=${timestamp}` :
`/api/model/preview_image?path=${encodeURIComponent(filePath)}&t=${timestamp}`;

View File

@@ -5,7 +5,8 @@ import {
refreshModels as baseRefreshModels,
deleteModel as baseDeleteModel,
replaceModelPreview,
fetchCivitaiMetadata
fetchCivitaiMetadata,
refreshSingleModelMetadata
} from './baseModelApi.js';
// Load more checkpoints with pagination
@@ -54,4 +55,29 @@ export async function fetchCivitai() {
fetchEndpoint: '/api/checkpoints/fetch-all-civitai',
resetAndReloadFunction: resetAndReload
});
}
// Refresh single checkpoint metadata
export async function refreshSingleCheckpointMetadata(filePath) {
return refreshSingleModelMetadata(filePath, 'checkpoint');
}
// Save checkpoint metadata (similar to the Lora version)
export async function saveCheckpointMetadata(filePath, data) {
const response = await fetch('/api/checkpoints/save-metadata', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath,
...data
})
});
if (!response.ok) {
throw new Error('Failed to save metadata');
}
return await response.json();
}

View File

@@ -4,6 +4,7 @@ import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js';
import { createPageControls } from './components/controls/index.js';
import { loadMoreCheckpoints } from './api/checkpointApi.js';
import { CheckpointDownloadManager } from './managers/CheckpointDownloadManager.js';
import { CheckpointContextMenu } from './components/ContextMenu/index.js';
// Initialize the Checkpoints page
class CheckpointsPageManager {
@@ -34,6 +35,9 @@ class CheckpointsPageManager {
this.pageControls.restoreFolderFilter();
this.pageControls.initFolderTagsVisibility();
// Initialize context menu
new CheckpointContextMenu();
// Initialize infinite scroll
initializeInfiniteScroll('checkpoints');

View File

@@ -44,7 +44,10 @@ export function createCheckpointCard(checkpoint) {
// Determine preview URL
const previewUrl = checkpoint.preview_url || '/loras_static/images/no-preview.png';
const version = state.previewVersions ? state.previewVersions.get(checkpoint.file_path) : null;
// Get the page-specific previewVersions map
const previewVersions = state.pages.checkpoints.previewVersions || new Map();
const version = previewVersions.get(checkpoint.file_path);
const versionedPreviewUrl = version ? `${previewUrl}?t=${version}` : previewUrl;
// Determine NSFW warning text based on level

View File

@@ -366,4 +366,7 @@ export class LoraContextMenu {
this.menu.style.display = 'none';
this.currentCard = null;
}
}
}
// For backward compatibility, re-export the LoraContextMenu class
// export { LoraContextMenu } from './ContextMenu/LoraContextMenu.js';

View File

@@ -0,0 +1,84 @@
export class BaseContextMenu {
constructor(menuId, cardSelector) {
this.menu = document.getElementById(menuId);
this.cardSelector = cardSelector;
this.currentCard = null;
if (!this.menu) {
console.error(`Context menu element with ID ${menuId} not found`);
return;
}
this.init();
}
init() {
// Hide menu on regular clicks
document.addEventListener('click', () => this.hideMenu());
// Show menu on right-click on cards
document.addEventListener('contextmenu', (e) => {
const card = e.target.closest(this.cardSelector);
if (!card) {
this.hideMenu();
return;
}
e.preventDefault();
this.showMenu(e.clientX, e.clientY, card);
});
// Handle menu item clicks
this.menu.addEventListener('click', (e) => {
const menuItem = e.target.closest('.context-menu-item');
if (!menuItem || !this.currentCard) return;
const action = menuItem.dataset.action;
if (!action) return;
this.handleMenuAction(action, menuItem);
this.hideMenu();
});
}
handleMenuAction(action, menuItem) {
// Override in subclass
console.warn('handleMenuAction not implemented');
}
showMenu(x, y, card) {
this.currentCard = card;
this.menu.style.display = 'block';
// Get menu dimensions
const menuRect = this.menu.getBoundingClientRect();
// Get viewport dimensions
const viewportWidth = document.documentElement.clientWidth;
const viewportHeight = document.documentElement.clientHeight;
// Calculate position
let finalX = x;
let finalY = y;
// Ensure menu doesn't go offscreen right
if (x + menuRect.width > viewportWidth) {
finalX = x - menuRect.width;
}
// Ensure menu doesn't go offscreen bottom
if (y + menuRect.height > viewportHeight) {
finalY = y - menuRect.height;
}
// Position menu
this.menu.style.left = `${finalX}px`;
this.menu.style.top = `${finalY}px`;
}
hideMenu() {
if (this.menu) {
this.menu.style.display = 'none';
}
this.currentCard = null;
}
}

View File

@@ -0,0 +1,315 @@
import { BaseContextMenu } from './BaseContextMenu.js';
import { refreshSingleCheckpointMetadata, saveCheckpointMetadata } from '../../api/checkpointApi.js';
import { showToast, getNSFWLevelName } from '../../utils/uiHelpers.js';
import { NSFW_LEVELS } from '../../utils/constants.js';
import { getStorageItem } from '../../utils/storageHelpers.js';
export class CheckpointContextMenu extends BaseContextMenu {
constructor() {
super('checkpointContextMenu', '.lora-card');
this.nsfwSelector = document.getElementById('nsfwLevelSelector');
// Initialize NSFW Level Selector events
if (this.nsfwSelector) {
this.initNSFWSelector();
}
}
handleMenuAction(action) {
switch(action) {
case 'details':
// Show checkpoint details
this.currentCard.click();
break;
case 'preview':
// Replace checkpoint preview
if (this.currentCard.querySelector('.fa-image')) {
this.currentCard.querySelector('.fa-image').click();
}
break;
case 'civitai':
// Open civitai page
if (this.currentCard.dataset.from_civitai === 'true') {
if (this.currentCard.querySelector('.fa-globe')) {
this.currentCard.querySelector('.fa-globe').click();
}
} else {
showToast('No CivitAI information available', 'info');
}
break;
case 'delete':
// Delete checkpoint
if (this.currentCard.querySelector('.fa-trash')) {
this.currentCard.querySelector('.fa-trash').click();
}
break;
case 'copyname':
// Copy checkpoint name
if (this.currentCard.querySelector('.fa-copy')) {
this.currentCard.querySelector('.fa-copy').click();
}
break;
case 'refresh-metadata':
// Refresh metadata from CivitAI
refreshSingleCheckpointMetadata(this.currentCard.dataset.filepath);
break;
case 'set-nsfw':
// Set NSFW level
this.showNSFWLevelSelector(null, null, this.currentCard);
break;
case 'move':
// Move to folder (placeholder)
showToast('Move to folder feature coming soon', 'info');
break;
}
}
// NSFW Selector methods
initNSFWSelector() {
// Close button
const closeBtn = this.nsfwSelector.querySelector('.close-nsfw-selector');
closeBtn.addEventListener('click', () => {
this.nsfwSelector.style.display = 'none';
});
// Level buttons
const levelButtons = this.nsfwSelector.querySelectorAll('.nsfw-level-btn');
levelButtons.forEach(btn => {
btn.addEventListener('click', async () => {
const level = parseInt(btn.dataset.level);
const filePath = this.nsfwSelector.dataset.cardPath;
if (!filePath) return;
try {
await saveCheckpointMetadata(filePath, { preview_nsfw_level: level });
// Update card data
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (card) {
let metaData = {};
try {
metaData = JSON.parse(card.dataset.meta || '{}');
} catch (err) {
console.error('Error parsing metadata:', err);
}
metaData.preview_nsfw_level = level;
card.dataset.meta = JSON.stringify(metaData);
card.dataset.nsfwLevel = level.toString();
// Apply blur effect immediately
this.updateCardBlurEffect(card, level);
}
showToast(`Content rating set to ${getNSFWLevelName(level)}`, 'success');
this.nsfwSelector.style.display = 'none';
} catch (error) {
showToast(`Failed to set content rating: ${error.message}`, 'error');
}
});
});
// Close when clicking outside
document.addEventListener('click', (e) => {
if (this.nsfwSelector.style.display === 'block' &&
!this.nsfwSelector.contains(e.target) &&
!e.target.closest('.context-menu-item[data-action="set-nsfw"]')) {
this.nsfwSelector.style.display = 'none';
}
});
}
updateCardBlurEffect(card, level) {
// Get user settings for blur threshold
const blurThreshold = parseInt(getStorageItem('nsfwBlurLevel') || '4');
// Get card preview container
const previewContainer = card.querySelector('.card-preview');
if (!previewContainer) return;
// Get preview media element
const previewMedia = previewContainer.querySelector('img') || previewContainer.querySelector('video');
if (!previewMedia) return;
// Check if blur should be applied
if (level >= blurThreshold) {
// Add blur class to the preview container
previewContainer.classList.add('blurred');
// Get or create the NSFW overlay
let nsfwOverlay = previewContainer.querySelector('.nsfw-overlay');
if (!nsfwOverlay) {
// Create new overlay
nsfwOverlay = document.createElement('div');
nsfwOverlay.className = 'nsfw-overlay';
// Create and configure the warning content
const warningContent = document.createElement('div');
warningContent.className = 'nsfw-warning';
// Determine NSFW warning text based on level
let nsfwText = "Mature Content";
if (level >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (level >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (level >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
// Add warning text and show button
warningContent.innerHTML = `
<p>${nsfwText}</p>
<button class="show-content-btn">Show</button>
`;
// Add click event to the show button
const showBtn = warningContent.querySelector('.show-content-btn');
showBtn.addEventListener('click', (e) => {
e.stopPropagation();
previewContainer.classList.remove('blurred');
nsfwOverlay.style.display = 'none';
// Update toggle button icon if it exists
const toggleBtn = card.querySelector('.toggle-blur-btn');
if (toggleBtn) {
toggleBtn.querySelector('i').className = 'fas fa-eye-slash';
}
});
nsfwOverlay.appendChild(warningContent);
previewContainer.appendChild(nsfwOverlay);
} else {
// Update existing overlay
const warningText = nsfwOverlay.querySelector('p');
if (warningText) {
let nsfwText = "Mature Content";
if (level >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (level >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (level >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
warningText.textContent = nsfwText;
}
nsfwOverlay.style.display = 'flex';
}
// Get or create the toggle button in the header
const cardHeader = previewContainer.querySelector('.card-header');
if (cardHeader) {
let toggleBtn = cardHeader.querySelector('.toggle-blur-btn');
if (!toggleBtn) {
toggleBtn = document.createElement('button');
toggleBtn.className = 'toggle-blur-btn';
toggleBtn.title = 'Toggle blur';
toggleBtn.innerHTML = '<i class="fas fa-eye"></i>';
// Add click event to toggle button
toggleBtn.addEventListener('click', (e) => {
e.stopPropagation();
const isBlurred = previewContainer.classList.toggle('blurred');
const icon = toggleBtn.querySelector('i');
// Update icon and overlay visibility
if (isBlurred) {
icon.className = 'fas fa-eye';
nsfwOverlay.style.display = 'flex';
} else {
icon.className = 'fas fa-eye-slash';
nsfwOverlay.style.display = 'none';
}
});
// Add to the beginning of header
cardHeader.insertBefore(toggleBtn, cardHeader.firstChild);
// Update base model label class
const baseModelLabel = cardHeader.querySelector('.base-model-label');
if (baseModelLabel && !baseModelLabel.classList.contains('with-toggle')) {
baseModelLabel.classList.add('with-toggle');
}
} else {
// Update existing toggle button
toggleBtn.querySelector('i').className = 'fas fa-eye';
}
}
} else {
// Remove blur
previewContainer.classList.remove('blurred');
// Hide overlay if it exists
const overlay = previewContainer.querySelector('.nsfw-overlay');
if (overlay) overlay.style.display = 'none';
// Remove toggle button when content is set to PG or PG13
const cardHeader = previewContainer.querySelector('.card-header');
if (cardHeader) {
const toggleBtn = cardHeader.querySelector('.toggle-blur-btn');
if (toggleBtn) {
// Remove the toggle button completely
toggleBtn.remove();
// Update base model label class if it exists
const baseModelLabel = cardHeader.querySelector('.base-model-label');
if (baseModelLabel && baseModelLabel.classList.contains('with-toggle')) {
baseModelLabel.classList.remove('with-toggle');
}
}
}
}
}
showNSFWLevelSelector(x, y, card) {
const selector = document.getElementById('nsfwLevelSelector');
const currentLevelEl = document.getElementById('currentNSFWLevel');
// Get current NSFW level
let currentLevel = 0;
try {
const metaData = JSON.parse(card.dataset.meta || '{}');
currentLevel = metaData.preview_nsfw_level || 0;
// Update if we have no recorded level but have a dataset attribute
if (!currentLevel && card.dataset.nsfwLevel) {
currentLevel = parseInt(card.dataset.nsfwLevel) || 0;
}
} catch (err) {
console.error('Error parsing metadata:', err);
}
currentLevelEl.textContent = getNSFWLevelName(currentLevel);
// Position the selector
if (x && y) {
const viewportWidth = document.documentElement.clientWidth;
const viewportHeight = document.documentElement.clientHeight;
const selectorRect = selector.getBoundingClientRect();
// Center the selector if no coordinates provided
let finalX = (viewportWidth - selectorRect.width) / 2;
let finalY = (viewportHeight - selectorRect.height) / 2;
selector.style.left = `${finalX}px`;
selector.style.top = `${finalY}px`;
}
// Highlight current level button
document.querySelectorAll('.nsfw-level-btn').forEach(btn => {
if (parseInt(btn.dataset.level) === currentLevel) {
btn.classList.add('active');
} else {
btn.classList.remove('active');
}
});
// Store reference to current card
selector.dataset.cardPath = card.dataset.filepath;
// Show selector
selector.style.display = 'block';
}
}

View File

@@ -0,0 +1,324 @@
import { BaseContextMenu } from './BaseContextMenu.js';
import { refreshSingleLoraMetadata } from '../../api/loraApi.js';
import { showToast, getNSFWLevelName } from '../../utils/uiHelpers.js';
import { NSFW_LEVELS } from '../../utils/constants.js';
import { getStorageItem } from '../../utils/storageHelpers.js';
export class LoraContextMenu extends BaseContextMenu {
constructor() {
super('loraContextMenu', '.lora-card');
this.nsfwSelector = document.getElementById('nsfwLevelSelector');
// Initialize NSFW Level Selector events
if (this.nsfwSelector) {
this.initNSFWSelector();
}
}
handleMenuAction(action, menuItem) {
switch(action) {
case 'detail':
// Trigger the main card click which shows the modal
this.currentCard.click();
break;
case 'civitai':
// Only trigger if the card is from civitai
if (this.currentCard.dataset.from_civitai === 'true') {
if (this.currentCard.dataset.meta === '{}') {
showToast('Please fetch metadata from CivitAI first', 'info');
} else {
this.currentCard.querySelector('.fa-globe')?.click();
}
} else {
showToast('No CivitAI information available', 'info');
}
break;
case 'copyname':
this.currentCard.querySelector('.fa-copy')?.click();
break;
case 'preview':
this.currentCard.querySelector('.fa-image')?.click();
break;
case 'delete':
this.currentCard.querySelector('.fa-trash')?.click();
break;
case 'move':
moveManager.showMoveModal(this.currentCard.dataset.filepath);
break;
case 'refresh-metadata':
refreshSingleLoraMetadata(this.currentCard.dataset.filepath);
break;
case 'set-nsfw':
this.showNSFWLevelSelector(null, null, this.currentCard);
break;
}
}
// NSFW Selector methods from the original context menu
initNSFWSelector() {
// Close button
const closeBtn = this.nsfwSelector.querySelector('.close-nsfw-selector');
closeBtn.addEventListener('click', () => {
this.nsfwSelector.style.display = 'none';
});
// Level buttons
const levelButtons = this.nsfwSelector.querySelectorAll('.nsfw-level-btn');
levelButtons.forEach(btn => {
btn.addEventListener('click', async () => {
const level = parseInt(btn.dataset.level);
const filePath = this.nsfwSelector.dataset.cardPath;
if (!filePath) return;
try {
await this.saveModelMetadata(filePath, { preview_nsfw_level: level });
// Update card data
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (card) {
let metaData = {};
try {
metaData = JSON.parse(card.dataset.meta || '{}');
} catch (err) {
console.error('Error parsing metadata:', err);
}
metaData.preview_nsfw_level = level;
card.dataset.meta = JSON.stringify(metaData);
card.dataset.nsfwLevel = level.toString();
// Apply blur effect immediately
this.updateCardBlurEffect(card, level);
}
showToast(`Content rating set to ${getNSFWLevelName(level)}`, 'success');
this.nsfwSelector.style.display = 'none';
} catch (error) {
showToast(`Failed to set content rating: ${error.message}`, 'error');
}
});
});
// Close when clicking outside
document.addEventListener('click', (e) => {
if (this.nsfwSelector.style.display === 'block' &&
!this.nsfwSelector.contains(e.target) &&
!e.target.closest('.context-menu-item[data-action="set-nsfw"]')) {
this.nsfwSelector.style.display = 'none';
}
});
}
async saveModelMetadata(filePath, data) {
const response = await fetch('/api/loras/save-metadata', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath,
...data
})
});
if (!response.ok) {
throw new Error('Failed to save metadata');
}
return await response.json();
}
updateCardBlurEffect(card, level) {
// Get user settings for blur threshold
const blurThreshold = parseInt(getStorageItem('nsfwBlurLevel') || '4');
// Get card preview container
const previewContainer = card.querySelector('.card-preview');
if (!previewContainer) return;
// Get preview media element
const previewMedia = previewContainer.querySelector('img') || previewContainer.querySelector('video');
if (!previewMedia) return;
// Check if blur should be applied
if (level >= blurThreshold) {
// Add blur class to the preview container
previewContainer.classList.add('blurred');
// Get or create the NSFW overlay
let nsfwOverlay = previewContainer.querySelector('.nsfw-overlay');
if (!nsfwOverlay) {
// Create new overlay
nsfwOverlay = document.createElement('div');
nsfwOverlay.className = 'nsfw-overlay';
// Create and configure the warning content
const warningContent = document.createElement('div');
warningContent.className = 'nsfw-warning';
// Determine NSFW warning text based on level
let nsfwText = "Mature Content";
if (level >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (level >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (level >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
// Add warning text and show button
warningContent.innerHTML = `
<p>${nsfwText}</p>
<button class="show-content-btn">Show</button>
`;
// Add click event to the show button
const showBtn = warningContent.querySelector('.show-content-btn');
showBtn.addEventListener('click', (e) => {
e.stopPropagation();
previewContainer.classList.remove('blurred');
nsfwOverlay.style.display = 'none';
// Update toggle button icon if it exists
const toggleBtn = card.querySelector('.toggle-blur-btn');
if (toggleBtn) {
toggleBtn.querySelector('i').className = 'fas fa-eye-slash';
}
});
nsfwOverlay.appendChild(warningContent);
previewContainer.appendChild(nsfwOverlay);
} else {
// Update existing overlay
const warningText = nsfwOverlay.querySelector('p');
if (warningText) {
let nsfwText = "Mature Content";
if (level >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (level >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (level >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
warningText.textContent = nsfwText;
}
nsfwOverlay.style.display = 'flex';
}
// Get or create the toggle button in the header
const cardHeader = previewContainer.querySelector('.card-header');
if (cardHeader) {
let toggleBtn = cardHeader.querySelector('.toggle-blur-btn');
if (!toggleBtn) {
toggleBtn = document.createElement('button');
toggleBtn.className = 'toggle-blur-btn';
toggleBtn.title = 'Toggle blur';
toggleBtn.innerHTML = '<i class="fas fa-eye"></i>';
// Add click event to toggle button
toggleBtn.addEventListener('click', (e) => {
e.stopPropagation();
const isBlurred = previewContainer.classList.toggle('blurred');
const icon = toggleBtn.querySelector('i');
// Update icon and overlay visibility
if (isBlurred) {
icon.className = 'fas fa-eye';
nsfwOverlay.style.display = 'flex';
} else {
icon.className = 'fas fa-eye-slash';
nsfwOverlay.style.display = 'none';
}
});
// Add to the beginning of header
cardHeader.insertBefore(toggleBtn, cardHeader.firstChild);
// Update base model label class
const baseModelLabel = cardHeader.querySelector('.base-model-label');
if (baseModelLabel && !baseModelLabel.classList.contains('with-toggle')) {
baseModelLabel.classList.add('with-toggle');
}
} else {
// Update existing toggle button
toggleBtn.querySelector('i').className = 'fas fa-eye';
}
}
} else {
// Remove blur
previewContainer.classList.remove('blurred');
// Hide overlay if it exists
const overlay = previewContainer.querySelector('.nsfw-overlay');
if (overlay) overlay.style.display = 'none';
// Remove toggle button when content is set to PG or PG13
const cardHeader = previewContainer.querySelector('.card-header');
if (cardHeader) {
const toggleBtn = cardHeader.querySelector('.toggle-blur-btn');
if (toggleBtn) {
// Remove the toggle button completely
toggleBtn.remove();
// Update base model label class if it exists
const baseModelLabel = cardHeader.querySelector('.base-model-label');
if (baseModelLabel && baseModelLabel.classList.contains('with-toggle')) {
baseModelLabel.classList.remove('with-toggle');
}
}
}
}
}
showNSFWLevelSelector(x, y, card) {
const selector = document.getElementById('nsfwLevelSelector');
const currentLevelEl = document.getElementById('currentNSFWLevel');
// Get current NSFW level
let currentLevel = 0;
try {
const metaData = JSON.parse(card.dataset.meta || '{}');
currentLevel = metaData.preview_nsfw_level || 0;
// Update if we have no recorded level but have a dataset attribute
if (!currentLevel && card.dataset.nsfwLevel) {
currentLevel = parseInt(card.dataset.nsfwLevel) || 0;
}
} catch (err) {
console.error('Error parsing metadata:', err);
}
currentLevelEl.textContent = getNSFWLevelName(currentLevel);
// Position the selector
if (x && y) {
const viewportWidth = document.documentElement.clientWidth;
const viewportHeight = document.documentElement.clientHeight;
const selectorRect = selector.getBoundingClientRect();
// Center the selector if no coordinates provided
let finalX = (viewportWidth - selectorRect.width) / 2;
let finalY = (viewportHeight - selectorRect.height) / 2;
selector.style.left = `${finalX}px`;
selector.style.top = `${finalY}px`;
}
// Highlight current level button
document.querySelectorAll('.nsfw-level-btn').forEach(btn => {
if (parseInt(btn.dataset.level) === currentLevel) {
btn.classList.add('active');
} else {
btn.classList.remove('active');
}
});
// Store reference to current card
selector.dataset.cardPath = card.dataset.filepath;
// Show selector
selector.style.display = 'block';
}
}

View File

@@ -0,0 +1,205 @@
import { BaseContextMenu } from './BaseContextMenu.js';
import { showToast } from '../../utils/uiHelpers.js';
import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
import { state } from '../../state/index.js';
export class RecipeContextMenu extends BaseContextMenu {
constructor() {
super('recipeContextMenu', '.lora-card');
}
showMenu(x, y, card) {
// Call the parent method first to handle basic positioning
super.showMenu(x, y, card);
// Get recipe data to check for missing LoRAs
const recipeId = card.dataset.id;
const missingLorasItem = this.menu.querySelector('.download-missing-item');
if (recipeId && missingLorasItem) {
// Check if this card has missing LoRAs
const loraCountElement = card.querySelector('.lora-count');
const hasMissingLoras = loraCountElement && loraCountElement.classList.contains('missing');
// Show/hide the download missing LoRAs option based on missing status
if (hasMissingLoras) {
missingLorasItem.style.display = 'flex';
} else {
missingLorasItem.style.display = 'none';
}
}
}
handleMenuAction(action) {
const recipeId = this.currentCard.dataset.id;
switch(action) {
case 'details':
// Show recipe details
this.currentCard.click();
break;
case 'copy':
// Copy recipe to clipboard
this.currentCard.querySelector('.fa-copy')?.click();
break;
case 'share':
// Share recipe
this.currentCard.querySelector('.fa-share-alt')?.click();
break;
case 'delete':
// Delete recipe
this.currentCard.querySelector('.fa-trash')?.click();
break;
case 'viewloras':
// View all LoRAs in the recipe
this.viewRecipeLoRAs(recipeId);
break;
case 'download-missing':
// Download missing LoRAs
this.downloadMissingLoRAs(recipeId);
break;
}
}
// View all LoRAs in the recipe
viewRecipeLoRAs(recipeId) {
if (!recipeId) {
showToast('Cannot view LoRAs: Missing recipe ID', 'error');
return;
}
// First get the recipe details to access its LoRAs
fetch(`/api/recipe/${recipeId}`)
.then(response => response.json())
.then(recipe => {
// Clear any previous filters first
removeSessionItem('recipe_to_lora_filterLoraHash');
removeSessionItem('recipe_to_lora_filterLoraHashes');
removeSessionItem('filterRecipeName');
removeSessionItem('viewLoraDetail');
// Collect all hashes from the recipe's LoRAs
const loraHashes = recipe.loras
.filter(lora => lora.hash)
.map(lora => lora.hash.toLowerCase());
if (loraHashes.length > 0) {
// Store the LoRA hashes and recipe name in session storage
setSessionItem('recipe_to_lora_filterLoraHashes', JSON.stringify(loraHashes));
setSessionItem('filterRecipeName', recipe.title);
// Navigate to the LoRAs page
window.location.href = '/loras';
} else {
showToast('No LoRAs found in this recipe', 'info');
}
})
.catch(error => {
console.error('Error loading recipe LoRAs:', error);
showToast('Error loading recipe LoRAs: ' + error.message, 'error');
});
}
// Download missing LoRAs
async downloadMissingLoRAs(recipeId) {
if (!recipeId) {
showToast('Cannot download LoRAs: Missing recipe ID', 'error');
return;
}
try {
// First get the recipe details
const response = await fetch(`/api/recipe/${recipeId}`);
const recipe = await response.json();
// Get missing LoRAs
const missingLoras = recipe.loras.filter(lora => !lora.inLibrary && !lora.isDeleted);
if (missingLoras.length === 0) {
showToast('No missing LoRAs to download', 'info');
return;
}
// Show loading toast
state.loadingManager.showSimpleLoading('Getting version info for missing LoRAs...');
// Get version info for each missing LoRA
const missingLorasWithVersionInfoPromises = missingLoras.map(async lora => {
let endpoint;
// Determine which endpoint to use based on available data
if (lora.modelVersionId) {
endpoint = `/api/civitai/model/version/${lora.modelVersionId}`;
} else if (lora.hash) {
endpoint = `/api/civitai/model/hash/${lora.hash}`;
} else {
console.error("Missing both hash and modelVersionId for lora:", lora);
return null;
}
const versionResponse = await fetch(endpoint);
const versionInfo = await versionResponse.json();
// Return original lora data combined with version info
return {
...lora,
civitaiInfo: versionInfo
};
});
// Wait for all API calls to complete
const lorasWithVersionInfo = await Promise.all(missingLorasWithVersionInfoPromises);
// Filter out null values (failed requests)
const validLoras = lorasWithVersionInfo.filter(lora => lora !== null);
if (validLoras.length === 0) {
showToast('Failed to get information for missing LoRAs', 'error');
return;
}
// Prepare data for import manager using the retrieved information
const recipeData = {
loras: validLoras.map(lora => {
const civitaiInfo = lora.civitaiInfo;
const modelFile = civitaiInfo.files ?
civitaiInfo.files.find(file => file.type === 'Model') : null;
return {
// Basic lora info
name: civitaiInfo.model?.name || lora.name,
version: civitaiInfo.name || '',
strength: lora.strength || 1.0,
// Model identifiers
hash: modelFile?.hashes?.SHA256?.toLowerCase() || lora.hash,
modelVersionId: civitaiInfo.id || lora.modelVersionId,
// Metadata
thumbnailUrl: civitaiInfo.images?.[0]?.url || '',
baseModel: civitaiInfo.baseModel || '',
downloadUrl: civitaiInfo.downloadUrl || '',
size: modelFile ? (modelFile.sizeKB * 1024) : 0,
file_name: modelFile ? modelFile.name.split('.')[0] : '',
// Status flags
existsLocally: false,
isDeleted: civitaiInfo.error === "Model not found",
isEarlyAccess: !!civitaiInfo.earlyAccessEndsAt,
earlyAccessEndsAt: civitaiInfo.earlyAccessEndsAt || ''
};
})
};
// Call ImportManager's download missing LoRAs method
window.importManager.downloadMissingLoras(recipeData, recipeId);
} catch (error) {
console.error('Error downloading missing LoRAs:', error);
showToast('Error preparing LoRAs for download: ' + error.message, 'error');
} finally {
if (state.loadingManager) {
state.loadingManager.hide();
}
}
}
}

View File

@@ -0,0 +1,3 @@
export { LoraContextMenu } from './LoraContextMenu.js';
export { RecipeContextMenu } from './RecipeContextMenu.js';
export { CheckpointContextMenu } from './CheckpointContextMenu.js';

View File

@@ -44,7 +44,9 @@ export function createLoraCard(lora) {
card.classList.add('selected');
}
const version = state.previewVersions.get(lora.file_path);
// Get the page-specific previewVersions map
const previewVersions = state.pages.loras.previewVersions || new Map();
const version = previewVersions.get(lora.file_path);
const previewUrl = lora.preview_url || '/loras_static/images/no-preview.png';
const versionedPreviewUrl = version ? `${previewUrl}?t=${version}` : previewUrl;

View File

@@ -790,9 +790,9 @@ class RecipeModal {
// Determine which endpoint to use based on available data
if (lora.modelVersionId) {
endpoint = `/api/civitai/model/${lora.modelVersionId}`;
endpoint = `/api/civitai/model/version/${lora.modelVersionId}`;
} else if (lora.hash) {
endpoint = `/api/civitai/model/${lora.hash}`;
endpoint = `/api/civitai/model/hash/${lora.hash}`;
} else {
console.error("Missing both hash and modelVersionId for lora:", lora);
return null;

View File

@@ -235,8 +235,8 @@ function addNewTriggerWord(word) {
// Validation: Check total number
const currentTags = tagsContainer.querySelectorAll('.trigger-word-tag');
if (currentTags.length >= 10) {
showToast('Maximum 10 trigger words allowed', 'error');
if (currentTags.length >= 30) {
showToast('Maximum 30 trigger words allowed', 'error');
return;
}
@@ -336,7 +336,22 @@ async function saveTriggerWords() {
*/
window.copyTriggerWord = async function(word) {
try {
await navigator.clipboard.writeText(word);
// Modern clipboard API - with fallback for non-secure contexts
if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(word);
} else {
// Fallback for older browsers or non-secure contexts
const textarea = document.createElement('textarea');
textarea.value = word;
textarea.style.position = 'absolute';
textarea.style.left = '-99999px';
document.body.appendChild(textarea);
textarea.select();
const success = document.execCommand('copy');
document.body.removeChild(textarea);
if (!success) throw new Error('Copy command failed');
}
showToast('Trigger word copied', 'success');
} catch (err) {
console.error('Copy failed:', err);

View File

@@ -6,7 +6,7 @@ import { updateCardsForBulkMode } from './components/LoraCard.js';
import { bulkManager } from './managers/BulkManager.js';
import { DownloadManager } from './managers/DownloadManager.js';
import { moveManager } from './managers/MoveManager.js';
import { LoraContextMenu } from './components/ContextMenu.js';
import { LoraContextMenu } from './components/ContextMenu/index.js';
import { createPageControls } from './components/controls/index.js';
import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js';

View File

@@ -5,6 +5,7 @@ import { RecipeCard } from './components/RecipeCard.js';
import { RecipeModal } from './components/RecipeModal.js';
import { getCurrentPageState } from './state/index.js';
import { getSessionItem, removeSessionItem } from './utils/storageHelpers.js';
import { RecipeContextMenu } from './components/ContextMenu/index.js';
class RecipeManager {
constructor() {
@@ -37,6 +38,9 @@ class RecipeManager {
// Set default search options if not already defined
this._initSearchOptions();
// Initialize context menu
new RecipeContextMenu();
// Check for custom filter parameters in session storage
this._checkCustomFilter();

View File

@@ -1,5 +1,5 @@
// Create the new hierarchical state structure
import { getStorageItem } from '../utils/storageHelpers.js';
import { getStorageItem, getMapFromStorage } from '../utils/storageHelpers.js';
// Load settings from localStorage or use defaults
const savedSettings = getStorageItem('settings', {
@@ -7,6 +7,10 @@ const savedSettings = getStorageItem('settings', {
show_only_sfw: false
});
// Load preview versions from localStorage
const loraPreviewVersions = getMapFromStorage('lora_preview_versions');
const checkpointPreviewVersions = getMapFromStorage('checkpoint_preview_versions');
export const state = {
// Global state
global: {
@@ -23,7 +27,7 @@ export const state = {
hasMore: true,
sortBy: 'name',
activeFolder: null,
previewVersions: new Map(),
previewVersions: loraPreviewVersions,
searchManager: null,
searchOptions: {
filename: true,
@@ -66,6 +70,7 @@ export const state = {
hasMore: true,
sortBy: 'name',
activeFolder: null,
previewVersions: checkpointPreviewVersions,
searchManager: null,
searchOptions: {
filename: true,

View File

@@ -4,6 +4,7 @@ import { loadMoreCheckpoints } from '../api/checkpointApi.js';
import { debounce } from './debounce.js';
export function initializeInfiniteScroll(pageType = 'loras') {
// Clean up any existing observer
if (state.observer) {
state.observer.disconnect();
}
@@ -47,53 +48,53 @@ export function initializeInfiniteScroll(pageType = 'loras') {
}
const debouncedLoadMore = debounce(loadMoreFunction, 100);
// Create a more robust observer with lower threshold and root margin
state.observer = new IntersectionObserver(
(entries) => {
const target = entries[0];
if (target.isIntersecting && !pageState.isLoading && pageState.hasMore) {
debouncedLoadMore();
}
},
{
threshold: 0.01, // Lower threshold to detect even minimal visibility
rootMargin: '0px 0px 300px 0px' // Increase bottom margin to trigger earlier
}
);
const grid = document.getElementById(gridId);
if (!grid) {
console.warn(`Grid with ID "${gridId}" not found for infinite scroll`);
return;
}
// Remove any existing sentinel
const existingSentinel = document.getElementById('scroll-sentinel');
if (existingSentinel) {
state.observer.observe(existingSentinel);
} else {
// Create a wrapper div that will be placed after the grid
const sentinelWrapper = document.createElement('div');
sentinelWrapper.style.width = '100%';
sentinelWrapper.style.height = '30px'; // Increased height for better visibility
sentinelWrapper.style.margin = '0';
sentinelWrapper.style.padding = '0';
// Create the actual sentinel element
const sentinel = document.createElement('div');
sentinel.id = 'scroll-sentinel';
sentinel.style.height = '30px'; // Match wrapper height
// Add the sentinel to the wrapper
sentinelWrapper.appendChild(sentinel);
// Insert the wrapper after the grid instead of inside it
grid.parentNode.insertBefore(sentinelWrapper, grid.nextSibling);
state.observer.observe(sentinel);
existingSentinel.remove();
}
// Add a scroll event backup to handle edge cases
// Create a sentinel element after the grid (not inside it)
const sentinel = document.createElement('div');
sentinel.id = 'scroll-sentinel';
sentinel.style.width = '100%';
sentinel.style.height = '20px';
sentinel.style.visibility = 'hidden'; // Make it invisible but still affect layout
// Insert after grid instead of inside
grid.parentNode.insertBefore(sentinel, grid.nextSibling);
// Create observer with appropriate settings, slightly different for checkpoints page
const observerOptions = {
threshold: 0.1,
rootMargin: pageType === 'checkpoints' ? '0px 0px 200px 0px' : '0px 0px 100px 0px'
};
// Initialize the observer
state.observer = new IntersectionObserver((entries) => {
const target = entries[0];
if (target.isIntersecting && !pageState.isLoading && pageState.hasMore) {
debouncedLoadMore();
}
}, observerOptions);
// Start observing
state.observer.observe(sentinel);
// Clean up any existing scroll event listener
if (state.scrollHandler) {
window.removeEventListener('scroll', state.scrollHandler);
state.scrollHandler = null;
}
// Add a simple backup scroll handler
const handleScroll = debounce(() => {
if (pageState.isLoading || !pageState.hasMore) return;
@@ -103,26 +104,17 @@ export function initializeInfiniteScroll(pageType = 'loras') {
const rect = sentinel.getBoundingClientRect();
const windowHeight = window.innerHeight;
// If sentinel is within 500px of viewport bottom, load more
if (rect.top < windowHeight + 500) {
if (rect.top < windowHeight + 200) {
debouncedLoadMore();
}
}, 200);
// Clean up existing scroll listener if any
if (state.scrollHandler) {
window.removeEventListener('scroll', state.scrollHandler);
}
// Save reference to the handler for cleanup
state.scrollHandler = handleScroll;
window.addEventListener('scroll', state.scrollHandler);
// Check position immediately in case content is already visible
setTimeout(() => {
const sentinel = document.getElementById('scroll-sentinel');
if (sentinel && sentinel.getBoundingClientRect().top < window.innerHeight) {
debouncedLoadMore();
}
}, 100);
// Clear any existing interval
if (state.scrollCheckInterval) {
clearInterval(state.scrollCheckInterval);
state.scrollCheckInterval = null;
}
}

View File

@@ -171,4 +171,45 @@ export function migrateStorageItems() {
localStorage.setItem(STORAGE_PREFIX + 'migration_completed', 'true');
console.log('Lora Manager: Storage migration completed');
}
/**
* Save a Map to localStorage
* @param {string} key - The localStorage key
* @param {Map} map - The Map to save
*/
export function saveMapToStorage(key, map) {
if (!(map instanceof Map)) {
console.error('Cannot save non-Map object:', map);
return;
}
try {
const prefixedKey = STORAGE_PREFIX + key;
// Convert Map to array of entries and save as JSON
const entries = Array.from(map.entries());
localStorage.setItem(prefixedKey, JSON.stringify(entries));
} catch (error) {
console.error(`Error saving Map to localStorage (${key}):`, error);
}
}
/**
* Load a Map from localStorage
* @param {string} key - The localStorage key
* @returns {Map} - The loaded Map or a new empty Map
*/
export function getMapFromStorage(key) {
try {
const prefixedKey = STORAGE_PREFIX + key;
const data = localStorage.getItem(prefixedKey);
if (!data) return new Map();
// Parse JSON and convert back to Map
const entries = JSON.parse(data);
return new Map(entries);
} catch (error) {
console.error(`Error loading Map from localStorage (${key}):`, error);
return new Map();
}
}

View File

@@ -13,6 +13,18 @@
{% block additional_components %}
{% include 'components/checkpoint_modals.html' %}
<div id="checkpointContextMenu" class="context-menu" style="display: none;">
<div class="context-menu-item" data-action="details"><i class="fas fa-info-circle"></i> View Details</div>
<div class="context-menu-item" data-action="civitai"><i class="fas fa-external-link-alt"></i> View on CivitAI</div>
<div class="context-menu-item" data-action="refresh-metadata"><i class="fas fa-sync"></i> Refresh Civitai Data</div>
<div class="context-menu-item" data-action="copyname"><i class="fas fa-copy"></i> Copy Model Filename</div>
<div class="context-menu-item" data-action="preview"><i class="fas fa-image"></i> Replace Preview</div>
<div class="context-menu-item" data-action="set-nsfw"><i class="fas fa-exclamation-triangle"></i> Set Content Rating</div>
<div class="context-menu-separator"></div>
<div class="context-menu-item" data-action="move"><i class="fas fa-folder-open"></i> Move to Folder</div>
<div class="context-menu-item delete-item" data-action="delete"><i class="fas fa-trash"></i> Delete Model</div>
</div>
{% endblock %}
{% block content %}

View File

@@ -16,6 +16,16 @@
{% block additional_components %}
{% include 'components/import_modal.html' %}
{% include 'components/recipe_modal.html' %}
<div id="recipeContextMenu" class="context-menu" style="display: none;">
<div class="context-menu-item" data-action="details"><i class="fas fa-info-circle"></i> View Details</div>
<div class="context-menu-item" data-action="share"><i class="fas fa-share-alt"></i> Share Recipe</div>
<div class="context-menu-item" data-action="copy"><i class="fas fa-copy"></i> Copy Recipe Syntax</div>
<div class="context-menu-item" data-action="viewloras"><i class="fas fa-layer-group"></i> View All LoRAs</div>
<div class="context-menu-item download-missing-item" data-action="download-missing"><i class="fas fa-download"></i> Download Missing LoRAs</div>
<div class="context-menu-separator"></div>
<div class="context-menu-item delete-item" data-action="delete"><i class="fas fa-trash"></i> Delete Recipe</div>
</div>
{% endblock %}
{% block init_title %}Initializing Recipe Manager{% endblock %}

View File

@@ -287,6 +287,108 @@ export function addLorasWidget(node, name, opts, callback) {
// 创建预览tooltip实例
const previewTooltip = new PreviewTooltip();
// Function to handle strength adjustment via dragging
const handleStrengthDrag = (name, initialStrength, initialX, event, widget) => {
// Calculate drag sensitivity (how much the strength changes per pixel)
// Using 0.01 per 10 pixels of movement
const sensitivity = 0.001;
// Get the current mouse position
const currentX = event.clientX;
// Calculate the distance moved
const deltaX = currentX - initialX;
// Calculate the new strength value based on movement
// Moving right increases, moving left decreases
let newStrength = Number(initialStrength) + (deltaX * sensitivity);
// Limit the strength to reasonable bounds (now between -10 and 10)
newStrength = Math.max(-10, Math.min(10, newStrength));
newStrength = Number(newStrength.toFixed(2));
// Update the lora data
const lorasData = parseLoraValue(widget.value);
const loraIndex = lorasData.findIndex(l => l.name === name);
if (loraIndex >= 0) {
lorasData[loraIndex].strength = newStrength;
// Update the widget value
widget.value = formatLoraValue(lorasData);
// Force re-render to show updated strength value
renderLoras(widget.value, widget);
}
};
// Function to initialize drag operation
const initDrag = (loraEl, nameEl, name, widget) => {
let isDragging = false;
let initialX = 0;
let initialStrength = 0;
// Create a style element for drag cursor override if it doesn't exist
if (!document.getElementById('comfy-lora-drag-style')) {
const styleEl = document.createElement('style');
styleEl.id = 'comfy-lora-drag-style';
styleEl.textContent = `
body.comfy-lora-dragging,
body.comfy-lora-dragging * {
cursor: ew-resize !important;
}
`;
document.head.appendChild(styleEl);
}
// Create a drag handler that's applied to the entire lora entry
// except toggle and strength controls
loraEl.addEventListener('mousedown', (e) => {
// Skip if clicking on toggle or strength control areas
if (e.target.closest('.comfy-lora-toggle') ||
e.target.closest('input') ||
e.target.closest('.comfy-lora-arrow')) {
return;
}
// Store initial values
const lorasData = parseLoraValue(widget.value);
const loraData = lorasData.find(l => l.name === name);
if (!loraData) return;
initialX = e.clientX;
initialStrength = loraData.strength;
isDragging = true;
// Add class to body to enforce cursor style globally
document.body.classList.add('comfy-lora-dragging');
// Prevent text selection during drag
e.preventDefault();
});
// Use the document for move and up events to ensure drag continues
// even if mouse leaves the element
document.addEventListener('mousemove', (e) => {
if (!isDragging) return;
// Call the strength adjustment function
handleStrengthDrag(name, initialStrength, initialX, e, widget);
// Prevent showing the preview tooltip during drag
previewTooltip.hide();
});
document.addEventListener('mouseup', () => {
if (isDragging) {
isDragging = false;
// Remove the class to restore normal cursor behavior
document.body.classList.remove('comfy-lora-dragging');
}
});
};
// Function to create menu item
const createMenuItem = (text, icon, onClick) => {
const menuItem = document.createElement('div');
@@ -756,6 +858,9 @@ export function addLorasWidget(node, name, opts, callback) {
loraEl.appendChild(strengthControl);
container.appendChild(loraEl);
// Initialize drag functionality
initDrag(loraEl, nameEl, name, widget);
});
};

View File

@@ -366,6 +366,108 @@ export function addLorasWidget(node, name, opts, callback) {
return menuItem;
};
// Function to handle strength adjustment via dragging
const handleStrengthDrag = (name, initialStrength, initialX, event, widget) => {
// Calculate drag sensitivity (how much the strength changes per pixel)
// Using 0.01 per 10 pixels of movement
const sensitivity = 0.001;
// Get the current mouse position
const currentX = event.clientX;
// Calculate the distance moved
const deltaX = currentX - initialX;
// Calculate the new strength value based on movement
// Moving right increases, moving left decreases
let newStrength = Number(initialStrength) + (deltaX * sensitivity);
// Limit the strength to reasonable bounds (now between -10 and 10)
newStrength = Math.max(-10, Math.min(10, newStrength));
newStrength = Number(newStrength.toFixed(2));
// Update the lora data
const lorasData = parseLoraValue(widget.value);
const loraIndex = lorasData.findIndex(l => l.name === name);
if (loraIndex >= 0) {
lorasData[loraIndex].strength = newStrength;
// Update the widget value
widget.value = formatLoraValue(lorasData);
// Force re-render to show updated strength value
renderLoras(widget.value, widget);
}
};
// Function to initialize drag operation
const initDrag = (loraEl, nameEl, name, widget) => {
let isDragging = false;
let initialX = 0;
let initialStrength = 0;
// Create a style element for drag cursor override if it doesn't exist
if (!document.getElementById('comfy-lora-drag-style')) {
const styleEl = document.createElement('style');
styleEl.id = 'comfy-lora-drag-style';
styleEl.textContent = `
body.comfy-lora-dragging,
body.comfy-lora-dragging * {
cursor: ew-resize !important;
}
`;
document.head.appendChild(styleEl);
}
// Create a drag handler that's applied to the entire lora entry
// except toggle and strength controls
loraEl.addEventListener('mousedown', (e) => {
// Skip if clicking on toggle or strength control areas
if (e.target.closest('.comfy-lora-toggle') ||
e.target.closest('input') ||
e.target.closest('.comfy-lora-arrow')) {
return;
}
// Store initial values
const lorasData = parseLoraValue(widget.value);
const loraData = lorasData.find(l => l.name === name);
if (!loraData) return;
initialX = e.clientX;
initialStrength = loraData.strength;
isDragging = true;
// Add class to body to enforce cursor style globally
document.body.classList.add('comfy-lora-dragging');
// Prevent text selection during drag
e.preventDefault();
});
// Use the document for move and up events to ensure drag continues
// even if mouse leaves the element
document.addEventListener('mousemove', (e) => {
if (!isDragging) return;
// Call the strength adjustment function
handleStrengthDrag(name, initialStrength, initialX, e, widget);
// Prevent showing the preview tooltip during drag
previewTooltip.hide();
});
document.addEventListener('mouseup', () => {
if (isDragging) {
isDragging = false;
// Remove the class to restore normal cursor behavior
document.body.classList.remove('comfy-lora-dragging');
}
});
};
// Function to create context menu
const createContextMenu = (x, y, loraName, widget) => {
// Hide preview tooltip first
@@ -649,6 +751,9 @@ export function addLorasWidget(node, name, opts, callback) {
e.stopPropagation();
previewTooltip.hide();
});
// Initialize drag functionality for strength adjustment
initDrag(loraEl, nameEl, name, widget);
// Remove the preview tooltip events from loraEl
loraEl.onmouseenter = () => {
@@ -861,9 +966,6 @@ export function addLorasWidget(node, name, opts, callback) {
// Function to directly save the recipe without dialog
async function saveRecipeDirectly(widget) {
try {
// Get the workflow data from the ComfyUI app
const prompt = await app.graphToPrompt();
// Show loading toast
if (app && app.extensionManager && app.extensionManager.toast) {
app.extensionManager.toast.add({
@@ -874,14 +976,9 @@ async function saveRecipeDirectly(widget) {
});
}
// Prepare the data - only send workflow JSON
const formData = new FormData();
formData.append('workflow_json', JSON.stringify(prompt.output));
// Send the request
// Send the request to the backend API without workflow data
const response = await fetch('/api/recipes/save-from-widget', {
method: 'POST',
body: formData
method: 'POST'
});
const result = await response.json();
@@ -917,4 +1014,4 @@ async function saveRecipeDirectly(widget) {
});
}
}
}
}