Compare commits

...

85 Commits

Author SHA1 Message Date
Will Miao
0817901bef feat: update README and pyproject.toml for v0.8.10 release; add standalone mode and portable edition features 2025-04-28 18:24:02 +08:00
Will Miao
ac22172e53 Update requirements for standalone mode 2025-04-28 15:14:11 +08:00
Will Miao
fd87fbf31e Update workflow 2025-04-28 07:08:35 +08:00
Will Miao
554be0908f feat: add dynamic filename format patterns for Save Image Node in README 2025-04-28 07:01:33 +08:00
Will Miao
eaec4e5f13 feat: update README and settings.json.example for standalone mode; enhance standalone.py to redirect status requests to loras page 2025-04-27 09:41:33 +08:00
Will Miao
0e7ba27a7d feat: enhance Civitai resource extraction in StandardMetadataParser for improved JSON handling. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/141 2025-04-26 22:12:40 +08:00
Will Miao
c551f5c23b feat: update README with standalone mode instructions and add settings.json.example file 2025-04-26 20:39:24 +08:00
pixelpaws
5159657ae5 Merge pull request #142 from willmiao/dev
Dev
2025-04-26 20:25:26 +08:00
Will Miao
d35db7df72 feat: add standalone mode for LoRA Manager with setup instructions 2025-04-26 20:23:27 +08:00
Will Miao
2b5399c559 feat: enhance folder path retrieval for diffusion models and improve warning messages 2025-04-26 20:08:00 +08:00
Will Miao
9e61bbbd8e feat: improve warning management by removing existing deleted LoRAs and early access warnings 2025-04-26 19:46:48 +08:00
Will Miao
7ce5857cd5 feat: implement standalone mode support with mock modules and path handling 2025-04-26 19:14:38 +08:00
Will Miao
38fbae99fd feat: limit maximum height of loras widget to accommodate up to 5 entries. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/109 2025-04-26 12:00:36 +08:00
Will Miao
b0a9d44b0c Add support for SamplerCustomAdvanced node in metadata extraction 2025-04-26 09:40:44 +08:00
Will Miao
b4e22cd375 feat: update release notes and version to 0.8.9 with new favorites system and UI enhancements 2025-04-25 22:13:16 +08:00
Will Miao
9bc92736a7 feat: enhance session management by ensuring freshness and optimizing connection parameters 2025-04-25 20:54:25 +08:00
pixelpaws
111b34d05c Merge pull request #138 from willmiao/dev
feat: implement theme management with auto-detection and user prefere…
2025-04-25 19:47:17 +08:00
Will Miao
07d9599a2f feat: implement theme management with auto-detection and user preference storage. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/137 2025-04-25 19:39:11 +08:00
pixelpaws
d8194f211d Merge pull request #136 from willmiao/dev
Dev
2025-04-25 17:56:26 +08:00
Will Miao
51a6374c33 feat: add favorites filtering functionality across models and UI components 2025-04-25 17:55:33 +08:00
Will Miao
aa6c6035b6 refactor: consolidate save model metadata functionality across APIs 2025-04-25 13:31:01 +08:00
Will Miao
44b4a7ffbb fix: update requirements to include 'toml' and correct pip install command in README. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/134 2025-04-25 10:26:01 +08:00
Will Miao
e5bb018d22 feat: integrate Font Awesome resources locally. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/131
- Replace CDN references with local resources
- Download and include Font Awesome CSS and webfonts in project
- Remove CDN preconnect as resources are now served locally
- Improve reliability for users with limited network access
2025-04-25 10:09:20 +08:00
Will Miao
79b8a6536e docs: Update README to clarify contribution guidelines and acknowledge project inspirations 2025-04-25 09:48:00 +08:00
Will Miao
3de31cd06a feat: Add functionality to move civitai.info file during model relocation 2025-04-25 09:41:23 +08:00
Will Miao
c579b54d40 fix: Preserve original path separators when mapping real paths in Config. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/132 2025-04-25 09:33:07 +08:00
Will Miao
0a52575e8b feat: Enhance model file retrieval by ensuring primary model is selected from files list. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/127 2025-04-25 05:45:29 +08:00
Will Miao
23c9a98f66 feat: Add endpoint for scanning and rebuilding recipe cache, and update UI to use new refresh method 2025-04-24 13:23:31 +08:00
Will Miao
796fc33b5b feat: Optimize TCP connection parameters and enhance logging for download operations 2025-04-22 19:43:37 +08:00
Will Miao
dc4c11ddd2 feat: Update release notes and version to 0.8.8 with new features and bug fixes 2025-04-22 13:29:00 +08:00
pixelpaws
d389e4d5d4 Merge pull request #122 from willmiao/dev
Dev
2025-04-22 09:40:05 +08:00
Will Miao
8cb78ad931 feat: Add route for retrieving current usage statistics 2025-04-22 09:39:00 +08:00
Will Miao
85f987d15c feat: Centralize clipboard functionality with copyToClipboard utility across components 2025-04-22 09:33:05 +08:00
Will Miao
b12079e0f6 feat: Implement usage statistics tracking with backend integration and route setup 2025-04-22 08:56:34 +08:00
pixelpaws
dcf5c6167a Merge pull request #121 from willmiao/dev
Dev
2025-04-21 15:44:23 +08:00
Will Miao
b395d3f487 fix: Update filename formatting in save_images method to ensure unique filenames for batch images 2025-04-21 15:42:49 +08:00
Will Miao
37662cad10 Update workflow 2025-04-21 15:42:49 +08:00
pixelpaws
aa1673063d Merge pull request #120 from willmiao/dev
feat: Enhance LoraManager by updating trigger words handling and dyna…
2025-04-21 06:52:16 +08:00
Will Miao
f51f49eb60 feat: Enhance LoraManager by updating trigger words handling and dynamically loading widget modules. 2025-04-21 06:49:51 +08:00
pixelpaws
54c9bac961 Merge pull request #119 from willmiao/dev
Dev
2025-04-20 22:29:28 +08:00
Will Miao
e70fd73bdd feat: Implement trigger words API and update frontend integration for LoraManager. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/43 2025-04-20 22:27:53 +08:00
Will Miao
9bb9e7b64d refactor: Extract common methods for Lora handling into utils.py and update references in lora_loader.py and lora_stacker.py 2025-04-20 21:35:36 +08:00
pixelpaws
f64c03543a Merge pull request #116 from matrunchyk/main
Prevent duplicates of root folders when using symlinks
2025-04-20 17:05:08 +08:00
Will Miao
51374de1a1 fix: Update version to 0.8.7-bugfix2 in pyproject.toml for clarity on bug fixes 2025-04-20 15:04:24 +08:00
Will Miao
afcc12f263 fix: Update populate_lora_from_civitai method to accept a tuple for Civitai API response. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/117 2025-04-20 15:01:23 +08:00
Your Name
88c5482366 Merge branch 'main' of https://github.com/willmiao/ComfyUI-Lora-Manager 2025-04-19 21:47:41 +03:00
Your Name
bbf7295c32 Prevent duplicates of root folders when using symlinks 2025-04-19 21:42:01 +03:00
Will Miao
ca5e23e68c fix: Update version to 0.8.7-bugfix in pyproject.toml for clarity on bug fixes 2025-04-19 23:02:50 +08:00
Will Miao
eadb1487ae feat: Refactor metadata formatting to use helper function for conditional parameter addition 2025-04-19 23:00:09 +08:00
Will Miao
1faa70fc77 feat: Implement filename-based hash retrieval in LoraScanner and ModelScanner for improved compatibility 2025-04-19 21:12:26 +08:00
Will Miao
30d7c007de fix: Correct metadata restoration logic to ensure file info is fetched when metadata is missing 2025-04-19 20:51:23 +08:00
Will Miao
f54f6a4402 feat: Enhance metadata handling by restoring missing civitai data and extracting tags and descriptions from version info 2025-04-19 11:35:42 +08:00
Will Miao
7b41cdec65 feat: Add civitai_deleted attribute to BaseModelMetadata for tracking deletion status from Civitai 2025-04-19 09:30:43 +08:00
Will Miao
fb6a652a57 feat: Add checkpoint hash retrieval and enhance metadata formatting in SaveImage class 2025-04-18 23:55:45 +08:00
Will Miao
ea34d753c1 refactor: Remove unnecessary workflow data logging and streamline saveRecipeDirectly function for legacy loras widget 2025-04-18 21:52:26 +08:00
Will Miao
2bc46e708e feat: Update release notes and version to 0.8.7 with enhancements and bug fixes 2025-04-18 19:03:00 +08:00
Will Miao
96e3b5b7b3 feat: Refactor Civitai model API routes and enhance RecipeContextMenu for missing LoRAs handling 2025-04-18 16:44:26 +08:00
Will Miao
fafbafa5e1 feat: Enhance copyTriggerWord function with modern clipboard API and fallback for non-secure contexts. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/110 2025-04-18 14:56:27 +08:00
Will Miao
be8605d8c6 feat: Enhance CivitaiClient and ApiRoutes to handle model version errors and improve metadata fetching. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/112 2025-04-18 14:44:53 +08:00
Will Miao
061660d47a feat: Increase maximum allowed trigger words from 10 to 30. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/109 2025-04-18 11:25:41 +08:00
pixelpaws
2ed6dbb344 Merge pull request #111 from willmiao/dev
Dev
2025-04-18 10:55:07 +08:00
Will Miao
4766b45746 feat: Update SaveImage node to modify default lossless_webp setting and adjust save_kwargs for image formats 2025-04-18 10:52:39 +08:00
Will Miao
0734252e98 feat: Enhance VAEDecodeExtractor to improve image caching and metadata handling 2025-04-18 10:03:26 +08:00
Will Miao
91b4827c1d feat: Enhance image retrieval in MetadataRegistry and update recipe routes to process images from metadata 2025-04-18 09:24:48 +08:00
Will Miao
df6d56ce66 feat: Add IMAGES category to constants and enhance metadata handling in node extractors 2025-04-18 07:12:43 +08:00
Will Miao
f0203c96ab feat: Simplify format_metadata method by removing custom_prompt parameter and update related function calls 2025-04-18 05:34:42 +08:00
Will Miao
bccabe40c0 feat: Enhance KSamplerAdvancedExtractor to include additional sampling parameters and update metadata processing 2025-04-18 05:29:36 +08:00
Will Miao
c2f599b4ff feat: Update node extractors to include UNETLoaderExtractor and enhance metadata handling for guidance parameters 2025-04-17 22:05:40 +08:00
Will Miao
5fd069d70d feat: Enhance checkpoint processing in format_metadata to handle non-string types safely 2025-04-17 09:38:20 +08:00
Will Miao
32d34d1748 feat: Enhance trace_node_input method with depth tracking and target class filtering; add FluxGuidanceExtractor for guidance parameter extraction 2025-04-17 08:06:21 +08:00
Will Miao
18eb605605 feat: Refactor metadata processing to use constants for category keys and improve structure 2025-04-17 06:23:31 +08:00
Will Miao
4fdc88e9e1 feat: Enhance LoraLoaderExtractor to extract base filename from lora_name input 2025-04-16 22:19:38 +08:00
Will Miao
4c69d8d3a8 feat: Integrate metadata collection in RecipeRoutes and simplify saveRecipeDirectly function 2025-04-16 22:15:46 +08:00
Will Miao
d4b2dd0ec1 refactor: Rename to_comfyui_format method to to_dict and update references in save_image.py 2025-04-16 21:42:54 +08:00
Will Miao
181f78421b feat: Standardize LoRA extraction format and enhance input handling in node extractors 2025-04-16 21:20:56 +08:00
Will Miao
8ed38527d0 feat: Implement metadata collection and processing framework with debug node for verification 2025-04-16 20:04:26 +08:00
Will Miao
c4c926070d fix: Update optimize_image method to handle image validation and error logging, and adjust metadata preservation logic. 2025-04-15 12:31:17 +08:00
Will Miao
ed87411e0d refactor: Change logging level from info to debug for service initialization and file monitoring 2025-04-15 11:48:37 +08:00
Will Miao
4ec2a448ab feat: Improve date formatting in filename generation with zero-padding and two-digit year support. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/102 2025-04-15 10:46:57 +08:00
Will Miao
73d01da94e feat: Enhance model preview version management with localStorage support 2025-04-15 10:35:50 +08:00
pixelpaws
df8e02157a Merge pull request #103 from willmiao/dev
feat: Add drag functionality for strength adjustment in LoRA entries.…
2025-04-15 08:57:52 +08:00
Will Miao
6e513ed32a feat: Add drag functionality for strength adjustment in LoRA entries. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/101 2025-04-15 08:56:19 +08:00
pixelpaws
325ef6327d Merge pull request #99 from willmiao/dev
Dev
2025-04-14 20:27:18 +08:00
Will Miao
46700e5ad0 feat: Refactor infinite scroll initialization for improved observer handling and sentinel management 2025-04-14 20:25:44 +08:00
Will Miao
d1e21fa345 feat: Implement context menus for checkpoints and recipes, including metadata refresh and NSFW level management 2025-04-14 15:37:36 +08:00
90 changed files with 4773 additions and 751 deletions

112
README.md
View File

@@ -20,6 +20,32 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
## Release Notes
### v0.8.10
* **Standalone Mode** - Run LoRA Manager independently from ComfyUI for a lightweight experience that works even with other stable diffusion interfaces
* **Portable Edition** - New one-click portable version for easy startup and updates in standalone mode
* **Enhanced Metadata Collection** - Added support for SamplerCustomAdvanced node in the metadata collector module
* **Improved UI Organization** - Optimized Lora Loader node height to display up to 5 LoRAs at once with scrolling capability for larger collections
### v0.8.9
* **Favorites System** - New functionality to bookmark your favorite LoRAs and checkpoints for quick access and better organization
* **Enhanced UI Controls** - Increased model card button sizes for improved usability and easier interaction
* **Smoother Page Transitions** - Optimized interface switching between pages, eliminating flash issues particularly noticeable in dark theme
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
### v0.8.8
* **Real-time TriggerWord Updates** - Enhanced TriggerWord Toggle node to instantly update when connected Lora Loader or Lora Stacker nodes change, without requiring workflow execution
* **Optimized Metadata Recovery** - Improved utilization of existing .civitai.info files for faster initialization and preservation of metadata from models deleted from CivitAI
* **Migration Acceleration** - Further speed improvements for users transitioning from A1111/Forge environments
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
### v0.8.7
* **Enhanced Context Menu** - Added comprehensive context menu functionality to Recipes and Checkpoints pages for improved workflow
* **Interactive LoRA Strength Control** - Implemented drag functionality in LoRA Loader for intuitive strength adjustment
* **Metadata Collector Overhaul** - Rebuilt metadata collection system with optimized architecture for better performance
* **Improved Save Image Node** - Enhanced metadata capture and image saving performance with the new metadata collector
* **Streamlined Recipe Saving** - Optimized Save Recipe functionality to work independently without requiring Preview Image nodes
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
### v0.8.6 Major Update
* **Checkpoint Management** - Added comprehensive management for model checkpoints including scanning, searching, filtering, and deletion
* **Enhanced Metadata Support** - New capabilities for retrieving and managing checkpoint metadata with improved operations
@@ -132,7 +158,7 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
```bash
git clone https://github.com/willmiao/ComfyUI-Lora-Manager.git
cd ComfyUI-Lora-Manager
pip install requirements.txt
pip install -r requirements.txt
```
## Usage
@@ -153,23 +179,92 @@ pip install requirements.txt
- Paste into the Lora Loader node's text input
- The node will automatically apply preset strength and trigger words
### Filename Format Patterns for Save Image Node
The Save Image Node supports dynamic filename generation using pattern codes. You can customize how your images are named using the following format patterns:
#### Available Pattern Codes
- `%seed%` - Inserts the generation seed number
- `%width%` - Inserts the image width
- `%height%` - Inserts the image height
- `%pprompt:N%` - Inserts the positive prompt (limited to N characters)
- `%nprompt:N%` - Inserts the negative prompt (limited to N characters)
- `%model:N%` - Inserts the model/checkpoint name (limited to N characters)
- `%date%` - Inserts current date/time as "yyyyMMddhhmmss"
- `%date:FORMAT%` - Inserts date using custom format with:
- `yyyy` - 4-digit year
- `yy` - 2-digit year
- `MM` - 2-digit month
- `dd` - 2-digit day
- `hh` - 2-digit hour
- `mm` - 2-digit minute
- `ss` - 2-digit second
#### Examples
- `image_%seed%` → `image_1234567890`
- `gen_%width%x%height%` → `gen_512x768`
- `%model:10%_%seed%` → `dreamshape_1234567890`
- `%date:yyyy-MM-dd%` → `2025-04-28`
- `%pprompt:20%_%seed%` → `beautiful landscape_1234567890`
- `%model%_%date:yyMMdd%_%seed%` → `dreamshaper_v8_250428_1234567890`
You can combine multiple patterns to create detailed, organized filenames for your generated images.
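To illustrate how such pattern codes can be expanded, here is a minimal sketch of a substitution helper. It is not the node's actual implementation; the function name, signature, and default behavior are assumptions for illustration only:
```python
import re
from datetime import datetime

# Minimal sketch of %pattern% expansion, modeled on the documented codes.
# The helper name and signature are illustrative, not the node's real code.
def expand_filename(pattern: str, seed: int, width: int, height: int,
                    pprompt: str = "", nprompt: str = "", model: str = "") -> str:
    now = datetime.now()

    def date_repl(match):
        fmt = match.group(1) or "yyyyMMddhhmmss"
        # Map the documented tokens onto strftime codes (longer tokens first,
        # so "yyyy" is not consumed as two "yy").
        for token, code in [("yyyy", "%Y"), ("yy", "%y"), ("MM", "%m"),
                            ("dd", "%d"), ("hh", "%H"), ("mm", "%M"), ("ss", "%S")]:
            fmt = fmt.replace(token, code)
        return now.strftime(fmt)

    def truncated(value: str):
        def repl(match):
            limit = match.group(1)
            return value[:int(limit)] if limit else value
        return repl

    result = re.sub(r"%date(?::([^%]+))?%", date_repl, pattern)
    result = re.sub(r"%pprompt(?::(\d+))?%", truncated(pprompt), result)
    result = re.sub(r"%nprompt(?::(\d+))?%", truncated(nprompt), result)
    result = re.sub(r"%model(?::(\d+))?%", truncated(model), result)
    return (result.replace("%seed%", str(seed))
                  .replace("%width%", str(width))
                  .replace("%height%", str(height)))

# e.g. expand_filename("%model:10%_%seed%", 1234567890, 512, 768,
#                      model="dreamshaper_v8")  ->  "dreamshape_1234567890"
```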
### Standalone Mode
You can now run LoRA Manager independently from ComfyUI:
1. **For ComfyUI users**:
- Launch ComfyUI with LoRA Manager at least once to initialize the necessary path information in the `settings.json` file.
- Make sure dependencies are installed: `pip install -r requirements.txt`
- From your ComfyUI root directory, run:
```bash
python custom_nodes\comfyui-lora-manager\standalone.py
```
- Access the interface at: `http://localhost:8188/loras`
- You can specify a different host or port with arguments:
```bash
python custom_nodes\comfyui-lora-manager\standalone.py --host 127.0.0.1 --port 9000
```
2. **For non-ComfyUI users**:
- Copy the provided `settings.json.example` file to create a new file named `settings.json`
- Edit `settings.json` to include your correct model folder paths and CivitAI API key
- Install required dependencies: `pip install -r requirements.txt`
- Run standalone mode:
```bash
python standalone.py
```
- Access the interface through your browser at: `http://localhost:8188/loras`
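For non-ComfyUI users, a minimal `settings.json` might look like the sketch below. The `folder_paths` keys mirror those written out by the ComfyUI integration; the API key field name is a guess, so check `settings.json.example` for the exact keys your version expects:
```json
{
  "civitai_api_key": "your-civitai-api-key",
  "folder_paths": {
    "loras": ["D:/models/loras"],
    "checkpoints": ["D:/models/checkpoints"],
    "diffusers": [],
    "unet": []
  }
}
```
Paths and the key value above are placeholders; replace them with your own.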
This standalone mode provides a lightweight option for managing your model and recipe collection without needing to run the full ComfyUI environment, making it useful even for users who primarily use other stable diffusion interfaces.
---
## Contributing
Thank you for your interest in contributing to ComfyUI LoRA Manager! As this project is currently in its early stages and undergoing rapid development and refactoring, we are temporarily not accepting pull requests.
However, your feedback and ideas are extremely valuable to us:
- Please feel free to open issues for any bugs you encounter
- Submit feature requests through GitHub issues
- Share your suggestions for improvements
We appreciate your understanding and look forward to potentially accepting code contributions once the project architecture stabilizes.
---
## Credits
This project has been inspired by and benefited from other excellent ComfyUI extensions:
- [ComfyUI-SaveImageWithMetaData](https://github.com/Comfy-Community/ComfyUI-SaveImageWithMetaData) - For the image metadata functionality
- [ComfyUI-SaveImageWithMetaData](https://github.com/nkchocoai/ComfyUI-SaveImageWithMetaData) - For the image metadata functionality
- [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) - For the lora loader functionality
---
## Contributing
If you have suggestions, bug reports, or improvements, feel free to open an issue or contribute directly to the codebase. Pull requests are always welcome!
---
## ☕ Support
If you find this project helpful, consider supporting its development:
@@ -182,3 +277,4 @@ Join our Discord community for support, discussions, and updates:
[Discord Server](https://discord.gg/vcqNrWVFvM)
---

View File

@@ -3,16 +3,23 @@ from .py.nodes.lora_loader import LoraManagerLoader
from .py.nodes.trigger_word_toggle import TriggerWordToggle
from .py.nodes.lora_stacker import LoraStacker
from .py.nodes.save_image import SaveImage
from .py.nodes.debug_metadata import DebugMetadata
# Import metadata collector to install hooks on startup
from .py.metadata_collector import init as init_metadata_collector
NODE_CLASS_MAPPINGS = {
LoraManagerLoader.NAME: LoraManagerLoader,
TriggerWordToggle.NAME: TriggerWordToggle,
LoraStacker.NAME: LoraStacker,
SaveImage.NAME: SaveImage
SaveImage.NAME: SaveImage,
DebugMetadata.NAME: DebugMetadata
}
WEB_DIRECTORY = "./web/comfyui"
# Initialize metadata collector
init_metadata_collector()
# Register routes on import
LoraManager.add_routes()
__all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY']

View File

@@ -3,6 +3,11 @@ import platform
import folder_paths # type: ignore
from typing import List
import logging
import sys
import json
# Check if running in standalone mode
standalone_mode = 'nodes' not in sys.modules
logger = logging.getLogger(__name__)
@@ -18,9 +23,46 @@ class Config:
self._route_mappings = {}
self.loras_roots = self._init_lora_paths()
self.checkpoints_roots = self._init_checkpoint_paths()
self.temp_directory = folder_paths.get_temp_directory()
# Scan symbolic links during initialization
self._scan_symbolic_links()
if not standalone_mode:
# Save the paths to settings.json when running in ComfyUI mode
self.save_folder_paths_to_settings()
def save_folder_paths_to_settings(self):
"""Save folder paths to settings.json for standalone mode to use later"""
try:
# Check if we're running in ComfyUI mode (not standalone)
if hasattr(folder_paths, "get_folder_paths") and not isinstance(folder_paths, type):
# Get all relevant paths
lora_paths = folder_paths.get_folder_paths("loras")
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffuser_paths = folder_paths.get_folder_paths("diffusers")
unet_paths = folder_paths.get_folder_paths("unet")
# Load existing settings
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
settings = {}
if os.path.exists(settings_path):
with open(settings_path, 'r', encoding='utf-8') as f:
settings = json.load(f)
# Update settings with paths
settings['folder_paths'] = {
'loras': lora_paths,
'checkpoints': checkpoint_paths,
'diffusers': diffuser_paths,
'unet': unet_paths
}
# Save settings
with open(settings_path, 'w', encoding='utf-8') as f:
json.dump(settings, f, indent=2)
logger.info("Saved folder paths to settings.json")
except Exception as e:
logger.warning(f"Failed to save folder paths: {e}")
def _is_link(self, path: str) -> bool:
try:
@@ -103,50 +145,66 @@ class Config:
def _init_lora_paths(self) -> List[str]:
"""Initialize and validate LoRA paths from ComfyUI settings"""
paths = sorted(set(path.replace(os.sep, "/")
for path in folder_paths.get_folder_paths("loras")
if os.path.exists(path)), key=lambda p: p.lower())
print("Found LoRA roots:", "\n - " + "\n - ".join(paths))
if not paths:
raise ValueError("No valid loras folders found in ComfyUI configuration")
# Initialize the path mappings
for path in paths:
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
if real_path != path:
self.add_path_mapping(path, real_path)
return paths
try:
raw_paths = folder_paths.get_folder_paths("loras")
# Normalize and resolve symlinks, store mapping from resolved -> original
path_map = {}
for path in raw_paths:
if os.path.exists(path):
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
path_map[real_path] = path_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
# Now sort and use only the deduplicated real paths
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
logger.info("Found LoRA roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
if not unique_paths:
logger.warning("No valid loras folders found in ComfyUI configuration")
return []
for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
if real_path != original_path:
self.add_path_mapping(original_path, real_path)
return unique_paths
except Exception as e:
logger.warning(f"Error initializing LoRA paths: {e}")
return []
def _init_checkpoint_paths(self) -> List[str]:
"""Initialize and validate checkpoint paths from ComfyUI settings"""
# Get checkpoint paths from folder_paths
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffusion_paths = folder_paths.get_folder_paths("diffusers")
unet_paths = folder_paths.get_folder_paths("unet")
# Combine all checkpoint-related paths
all_paths = checkpoint_paths + diffusion_paths + unet_paths
# Filter and normalize paths
paths = sorted(set(path.replace(os.sep, "/")
for path in all_paths
if os.path.exists(path)), key=lambda p: p.lower())
print("Found checkpoint roots:", paths)
if not paths:
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
try:
# Get checkpoint paths from folder_paths
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffusion_paths = folder_paths.get_folder_paths("diffusers")
unet_paths = folder_paths.get_folder_paths("unet")
# Combine all checkpoint-related paths
all_paths = checkpoint_paths + diffusion_paths + unet_paths
# Filter and normalize paths
paths = sorted(set(path.replace(os.sep, "/")
for path in all_paths
if os.path.exists(path)), key=lambda p: p.lower())
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(paths) if paths else "[]"))
if not paths:
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
return []
# Initialize the path mappings, same approach as for LoRA paths
for path in paths:
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
if real_path != path:
self.add_path_mapping(path, real_path)
return paths
except Exception as e:
logger.warning(f"Error initializing checkpoint paths: {e}")
return []
# Initialize the path mappings, same approach as for LoRA paths
for path in paths:
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
if real_path != path:
self.add_path_mapping(path, real_path)
return paths
def get_preview_static_url(self, preview_path: str) -> str:
"""Convert local preview path to static URL"""

View File

@@ -5,11 +5,17 @@ from .routes.lora_routes import LoraRoutes
from .routes.api_routes import ApiRoutes
from .routes.recipe_routes import RecipeRoutes
from .routes.checkpoints_routes import CheckpointsRoutes
from .routes.update_routes import UpdateRoutes
from .routes.usage_stats_routes import UsageStatsRoutes
from .services.service_registry import ServiceRegistry
import logging
import sys
logger = logging.getLogger(__name__)
# Check if we're in standalone mode
STANDALONE_MODE = 'nodes' not in sys.modules
class LoraManager:
"""Main entry point for LoRA Manager plugin"""
@@ -18,6 +24,9 @@ class LoraManager:
"""Initialize and register all routes"""
app = PromptServer.instance.app
# Configure aiohttp access logger to be less verbose
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
added_targets = set() # Track already added target paths
# Add static routes for each lora root
@@ -92,6 +101,8 @@ class LoraManager:
checkpoints_routes.setup_routes(app)
ApiRoutes.setup_routes(app)
RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app)
UsageStatsRoutes.setup_routes(app) # Register usage stats routes
# Schedule service initialization
app.on_startup.append(lambda app: cls._initialize_services())
@@ -104,7 +115,8 @@ class LoraManager:
async def _initialize_services(cls):
"""Initialize all services using the ServiceRegistry"""
try:
logger.info("LoRA Manager: Initializing services via ServiceRegistry")
# Ensure aiohttp access logger is configured with reduced verbosity
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
# Initialize CivitaiClient first to ensure it's ready for other services
civitai_client = await ServiceRegistry.get_civitai_client()
@@ -115,12 +127,12 @@ class LoraManager:
# Start monitors
lora_monitor.start()
logger.info("Lora monitor started")
logger.debug("Lora monitor started")
# Make sure checkpoint monitor has paths before starting
await checkpoint_monitor.initialize_paths()
checkpoint_monitor.start()
logger.info("Checkpoint monitor started")
logger.debug("Checkpoint monitor started")
# Register DownloadManager with ServiceRegistry
download_manager = await ServiceRegistry.get_download_manager()
@@ -135,6 +147,12 @@ class LoraManager:
# Initialize recipe scanner if needed
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
# Initialize metadata collector if not in standalone mode
if not STANDALONE_MODE:
from .metadata_collector import init as init_metadata
init_metadata()
logger.debug("Metadata collector initialized")
# Create low-priority initialization tasks
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')

View File

@@ -0,0 +1,32 @@
import os
import importlib
import sys
# Check if running in standalone mode
standalone_mode = 'nodes' not in sys.modules
if not standalone_mode:
from .metadata_hook import MetadataHook
from .metadata_registry import MetadataRegistry
def init():
# Install hooks to collect metadata during execution
MetadataHook.install()
# Initialize registry
registry = MetadataRegistry()
print("ComfyUI Metadata Collector initialized")
def get_metadata(prompt_id=None):
"""Helper function to get metadata from the registry"""
registry = MetadataRegistry()
return registry.get_metadata(prompt_id)
else:
# Standalone mode - provide dummy implementations
def init():
print("ComfyUI Metadata Collector disabled in standalone mode")
def get_metadata(prompt_id=None):
"""Dummy implementation for standalone mode"""
return {}

View File

@@ -0,0 +1,14 @@
"""Constants used by the metadata collector"""
# Metadata collection constants
# Metadata categories
MODELS = "models"
PROMPTS = "prompts"
SAMPLING = "sampling"
LORAS = "loras"
SIZE = "size"
IMAGES = "images"
# Complete list of categories to track
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]

View File

@@ -0,0 +1,123 @@
import sys
import inspect
from .metadata_registry import MetadataRegistry
class MetadataHook:
"""Install hooks for metadata collection"""
@staticmethod
def install():
"""Install hooks to collect metadata during execution"""
try:
# Import ComfyUI's execution module
execution = None
try:
# Try direct import first
import execution # type: ignore
except ImportError:
# Try to locate from system modules
for module_name in sys.modules:
if module_name.endswith('.execution'):
execution = sys.modules[module_name]
break
# If we can't find the execution module, we can't install hooks
if execution is None:
print("Could not locate ComfyUI execution module, metadata collection disabled")
return
# Store the original _map_node_over_list function
original_map_node_over_list = execution._map_node_over_list
# Define the wrapped _map_node_over_list function
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# Only collect metadata when calling the main function of nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
prompt_id = registry.current_prompt_id
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Unique ID might be available through the obj if it has a unique_id field
node_id = getattr(obj, 'unique_id', None)
if node_id is None and pre_execute_cb:
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
frame = inspect.currentframe()
while frame:
if 'unique_id' in frame.f_locals:
node_id = frame.f_locals['unique_id']
break
frame = frame.f_back
# Record inputs before execution
if node_id is not None:
registry.record_node_execution(node_id, class_type, input_data_all, None)
except Exception as e:
print(f"Error collecting metadata (pre-execution): {str(e)}")
# Execute the original function
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
# After execution, collect outputs for relevant nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
prompt_id = registry.current_prompt_id
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Unique ID might be available through the obj if it has a unique_id field
node_id = getattr(obj, 'unique_id', None)
if node_id is None and pre_execute_cb:
# Try to extract node_id through reflection
frame = inspect.currentframe()
while frame:
if 'unique_id' in frame.f_locals:
node_id = frame.f_locals['unique_id']
break
frame = frame.f_back
# Record outputs after execution
if node_id is not None:
registry.update_node_execution(node_id, class_type, results)
except Exception as e:
print(f"Error collecting metadata (post-execution): {str(e)}")
return results
# Also hook the execute function to track the current prompt_id
original_execute = execution.execute
def execute_with_prompt_tracking(*args, **kwargs):
if len(args) >= 7: # Check if we have enough arguments
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
registry = MetadataRegistry()
# Start collection if this is a new prompt
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
registry.start_collection(prompt_id)
# Store the dynprompt reference for node lookups
if hasattr(prompt, 'original_prompt'):
registry.set_current_prompt(prompt)
# Execute the original function
return original_execute(*args, **kwargs)
# Replace the functions
execution._map_node_over_list = map_node_over_list_with_metadata
execution.execute = execute_with_prompt_tracking
# Make map_node_over_list public to avoid it being hidden by hooks
execution.map_node_over_list = original_map_node_over_list
print("Metadata collection hooks installed for runtime values")
except Exception as e:
print(f"Error installing metadata hooks: {str(e)}")

View File

@@ -0,0 +1,300 @@
import json
import sys
# Check if running in standalone mode
standalone_mode = 'nodes' not in sys.modules
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE
class MetadataProcessor:
"""Process and format collected metadata"""
@staticmethod
def find_primary_sampler(metadata):
"""Find the primary KSampler node (with denoise=1)"""
primary_sampler = None
primary_sampler_id = None
# First, check for SamplerCustomAdvanced
prompt = metadata.get("current_prompt")
if prompt and prompt.original_prompt:
for node_id, node_info in prompt.original_prompt.items():
if node_info.get("class_type") == "SamplerCustomAdvanced":
# Found a SamplerCustomAdvanced node
if node_id in metadata.get(SAMPLING, {}):
return node_id, metadata[SAMPLING][node_id]
# Next, check for KSamplerAdvanced with add_noise="enable"
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
parameters = sampler_info.get("parameters", {})
add_noise = parameters.get("add_noise")
# If add_noise is "enable", this is likely the primary sampler for KSamplerAdvanced
if add_noise == "enable":
primary_sampler = sampler_info
primary_sampler_id = node_id
break
# If no specialized sampler found, fall back to traditional KSampler with denoise=1
if primary_sampler is None:
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
parameters = sampler_info.get("parameters", {})
denoise = parameters.get("denoise")
# If denoise is 1.0, this is likely the primary sampler
if denoise == 1.0 or denoise == 1:
primary_sampler = sampler_info
primary_sampler_id = node_id
break
return primary_sampler_id, primary_sampler
@staticmethod
def trace_node_input(prompt, node_id, input_name, target_class=None, max_depth=10):
"""
Trace an input connection from a node to find the source node
Parameters:
- prompt: The prompt object containing node connections
- node_id: ID of the starting node
- input_name: Name of the input to trace
- target_class: Optional class name to search for (e.g., "CLIPTextEncode")
- max_depth: Maximum depth to follow the node chain to prevent infinite loops
Returns:
- node_id of the found node, or None if not found
"""
if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt:
return None
# For depth tracking
current_depth = 0
current_node_id = node_id
current_input = input_name
while current_depth < max_depth:
if current_node_id not in prompt.original_prompt:
return None
node_inputs = prompt.original_prompt[current_node_id].get("inputs", {})
if current_input not in node_inputs:
return None
input_value = node_inputs[current_input]
# Input connections are formatted as [node_id, output_index]
if isinstance(input_value, list) and len(input_value) >= 2:
found_node_id = input_value[0] # Connected node_id
# If we're looking for a specific node class
if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class:
return found_node_id
# If we're not looking for a specific class or haven't found it yet
if not target_class:
return found_node_id
# Continue tracing through intermediate nodes
current_node_id = found_node_id
# For most conditioning nodes, the input we want to follow is named "conditioning"
if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}):
current_input = "conditioning"
else:
# If there's no "conditioning" input, we can't trace further
return found_node_id if not target_class else None
else:
# We've reached a node with no further connections
return None
current_depth += 1
# If we've reached max depth without finding target_class
return None
@staticmethod
def find_primary_checkpoint(metadata):
"""Find the primary checkpoint model in the workflow"""
if not metadata.get(MODELS):
return None
# In most workflows, there's only one checkpoint, so we can just take the first one
for node_id, model_info in metadata.get(MODELS, {}).items():
if model_info.get("type") == "checkpoint":
return model_info.get("name")
return None
@staticmethod
def extract_generation_params(metadata):
"""Extract generation parameters from metadata using node relationships"""
params = {
"prompt": "",
"negative_prompt": "",
"seed": None,
"steps": None,
"cfg_scale": None,
"guidance": None, # Add guidance parameter
"sampler": None,
"scheduler": None,
"checkpoint": None,
"loras": "",
"size": None,
"clip_skip": None
}
# Get the prompt object for node relationship tracing
prompt = metadata.get("current_prompt")
# Find the primary KSampler node
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata)
# Directly get checkpoint from metadata instead of tracing
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata)
if checkpoint:
params["checkpoint"] = checkpoint
if primary_sampler:
# Extract sampling parameters
sampling_params = primary_sampler.get("parameters", {})
# Handle both seed and noise_seed
params["seed"] = sampling_params.get("seed") if sampling_params.get("seed") is not None else sampling_params.get("noise_seed")
params["steps"] = sampling_params.get("steps")
params["cfg_scale"] = sampling_params.get("cfg")
params["sampler"] = sampling_params.get("sampler_name")
params["scheduler"] = sampling_params.get("scheduler")
# Trace connections from the primary sampler
if prompt and primary_sampler_id:
# Check if this is a SamplerCustomAdvanced node
is_custom_advanced = False
if prompt.original_prompt and primary_sampler_id in prompt.original_prompt:
is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced"
if is_custom_advanced:
# For SamplerCustomAdvanced, trace specific inputs
# 1. Trace sigmas input to find BasicScheduler
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", "BasicScheduler", max_depth=5)
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
params["steps"] = scheduler_params.get("steps")
params["scheduler"] = scheduler_params.get("scheduler")
# 2. Trace sampler input to find KSamplerSelect
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
params["sampler"] = sampler_params.get("sampler_name")
# 3. Trace guider input for FluxGuidance and CLIPTextEncode
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
if guider_node_id:
# Look for FluxGuidance along the guider path
flux_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", "FluxGuidance", max_depth=5)
if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}):
flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {})
params["guidance"] = flux_params.get("guidance")
# Find CLIPTextEncode for positive prompt (through conditioning)
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", "CLIPTextEncode", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
else:
# Original tracing for standard samplers
# Trace positive prompt - look specifically for CLIPTextEncode
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
# Find any FluxGuidance nodes in the positive conditioning path
flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5)
if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}):
flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {})
params["guidance"] = flux_params.get("guidance")
# Trace negative prompt - look specifically for CLIPTextEncode
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10)
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
# Size extraction is same for all sampler types
# Check if the sampler itself has size information (from latent_image)
if primary_sampler_id in metadata.get(SIZE, {}):
width = metadata[SIZE][primary_sampler_id].get("width")
height = metadata[SIZE][primary_sampler_id].get("height")
if width and height:
params["size"] = f"{width}x{height}"
else:
# Fallback to the previous trace method if needed
latent_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "latent_image")
if latent_node_id:
# Follow chain to find EmptyLatentImage node
size_found = False
current_node_id = latent_node_id
# Limit depth to avoid infinite loops in complex workflows
max_depth = 10
for _ in range(max_depth):
if current_node_id in metadata.get(SIZE, {}):
width = metadata[SIZE][current_node_id].get("width")
height = metadata[SIZE][current_node_id].get("height")
if width and height:
params["size"] = f"{width}x{height}"
size_found = True
break
# Try to follow the chain
if prompt and prompt.original_prompt and current_node_id in prompt.original_prompt:
node_info = prompt.original_prompt[current_node_id]
if "inputs" in node_info:
# Look for a connection that might lead to size information
for input_name, input_value in node_info["inputs"].items():
if isinstance(input_value, list) and len(input_value) >= 2:
current_node_id = input_value[0]
break
else:
break # No connections to follow
else:
break # No inputs to follow
else:
break # Can't follow further
# Extract LoRAs using the standardized format
lora_parts = []
for node_id, lora_info in metadata.get(LORAS, {}).items():
# Access the lora_list from the standardized format
lora_list = lora_info.get("lora_list", [])
for lora in lora_list:
name = lora.get("name", "unknown")
strength = lora.get("strength", 1.0)
lora_parts.append(f"<lora:{name}:{strength}>")
params["loras"] = " ".join(lora_parts)
# Set default clip_skip value
params["clip_skip"] = "1" # Common default
return params
@staticmethod
def to_dict(metadata):
"""Convert extracted metadata to the ComfyUI output.json format"""
if standalone_mode:
# Return empty dictionary in standalone mode
return {}
params = MetadataProcessor.extract_generation_params(metadata)
# Convert all values to strings to match output.json format
for key in params:
if params[key] is not None:
params[key] = str(params[key])
return params
@staticmethod
def to_json(metadata):
"""Convert metadata to JSON string"""
params = MetadataProcessor.to_dict(metadata)
return json.dumps(params, indent=4)

View File

@@ -0,0 +1,275 @@
import time
from nodes import NODE_CLASS_MAPPINGS
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
from .constants import METADATA_CATEGORIES, IMAGES
class MetadataRegistry:
"""A singleton registry to store and retrieve workflow metadata"""
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._reset()
return cls._instance
def _reset(self):
self.current_prompt_id = None
self.current_prompt = None
self.metadata = {}
self.prompt_metadata = {}
self.executed_nodes = set()
# Node-level cache for metadata
self.node_cache = {}
# Limit the number of stored prompts
self.max_prompt_history = 3
# Categories we want to track and retrieve from cache
self.metadata_categories = METADATA_CATEGORIES
def _clean_old_prompts(self):
"""Clean up old prompt metadata, keeping only recent ones"""
if len(self.prompt_metadata) <= self.max_prompt_history:
return
# Sort all prompt_ids by timestamp
sorted_prompts = sorted(
self.prompt_metadata.keys(),
key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0)
)
# Remove oldest records
prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history]
for pid in prompts_to_remove:
del self.prompt_metadata[pid]
def start_collection(self, prompt_id):
"""Begin metadata collection for a new prompt"""
self.current_prompt_id = prompt_id
self.executed_nodes = set()
self.prompt_metadata[prompt_id] = {
category: {} for category in METADATA_CATEGORIES
}
# Add additional metadata fields
self.prompt_metadata[prompt_id].update({
"execution_order": [],
"current_prompt": None, # Will store the prompt object
"timestamp": time.time()
})
# Clean up old prompt data
self._clean_old_prompts()
def set_current_prompt(self, prompt):
"""Set the current prompt object reference"""
self.current_prompt = prompt
if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata:
# Store the prompt in the metadata for later relationship tracing
self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt
def get_metadata(self, prompt_id=None):
"""Get collected metadata for a prompt"""
key = prompt_id if prompt_id is not None else self.current_prompt_id
if key not in self.prompt_metadata:
return {}
metadata = self.prompt_metadata[key]
# If we have a current prompt object, check for non-executed nodes
prompt_obj = metadata.get("current_prompt")
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
original_prompt = prompt_obj.original_prompt
# Fill in missing metadata from cache for nodes that weren't executed
self._fill_missing_metadata(key, original_prompt)
return self.prompt_metadata.get(key, {})
def _fill_missing_metadata(self, prompt_id, original_prompt):
"""Fill missing metadata from cache for non-executed nodes"""
if not original_prompt:
return
executed_nodes = self.executed_nodes
metadata = self.prompt_metadata[prompt_id]
# Iterate through nodes in the original prompt
for node_id, node_data in original_prompt.items():
# Skip if already executed in this run
if node_id in executed_nodes:
continue
# Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS)
prompt_class_type = node_data.get("class_type")
if not prompt_class_type:
continue
# Convert to actual class name (which is what we use in our cache)
class_type = prompt_class_type
if prompt_class_type in NODE_CLASS_MAPPINGS:
class_obj = NODE_CLASS_MAPPINGS[prompt_class_type]
class_type = class_obj.__name__
# Create cache key using the actual class name
cache_key = f"{node_id}:{class_type}"
# Check if this node type is relevant for metadata collection
if class_type in NODE_EXTRACTORS:
# Check if we have cached metadata for this node
if cache_key in self.node_cache:
cached_data = self.node_cache[cache_key]
# Apply cached metadata to the current metadata
for category in self.metadata_categories:
if category in cached_data and node_id in cached_data[category]:
if node_id not in metadata[category]:
metadata[category][node_id] = cached_data[category][node_id]
def record_node_execution(self, node_id, class_type, inputs, outputs):
"""Record information about a node's execution"""
if not self.current_prompt_id:
return
# Add to execution order and mark as executed
if node_id not in self.executed_nodes:
self.executed_nodes.add(node_id)
self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id)
# Process inputs to simplify working with them
processed_inputs = {}
for input_name, input_values in inputs.items():
if isinstance(input_values, list) and len(input_values) > 0:
# For single values, just use the first one (most common case)
processed_inputs[input_name] = input_values[0]
else:
processed_inputs[input_name] = input_values
# Extract node-specific metadata
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
extractor.extract(
node_id,
processed_inputs,
outputs,
self.prompt_metadata[self.current_prompt_id]
)
# Cache this node's metadata
self._cache_node_metadata(node_id, class_type)
def update_node_execution(self, node_id, class_type, outputs):
"""Update node metadata with output information"""
if not self.current_prompt_id:
return
# Process outputs to make them more usable
processed_outputs = outputs
# Use the same extractor to update with outputs
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
if hasattr(extractor, 'update'):
extractor.update(
node_id,
processed_outputs,
self.prompt_metadata[self.current_prompt_id]
)
# Update the cached metadata for this node
self._cache_node_metadata(node_id, class_type)
def _cache_node_metadata(self, node_id, class_type):
"""Cache the metadata for a specific node"""
if not self.current_prompt_id or not node_id or not class_type:
return
# Create a cache key combining node_id and class_type
cache_key = f"{node_id}:{class_type}"
# Create a shallow copy of the node's metadata
node_metadata = {}
current_metadata = self.prompt_metadata[self.current_prompt_id]
for category in self.metadata_categories:
if category in current_metadata and node_id in current_metadata[category]:
if category not in node_metadata:
node_metadata[category] = {}
node_metadata[category][node_id] = current_metadata[category][node_id]
# Save to cache if we have any metadata for this node
if any(node_metadata.values()):
self.node_cache[cache_key] = node_metadata
def clear_unused_cache(self):
"""Clean up node_cache entries that are no longer in use"""
# Collect all node_ids currently in prompt_metadata
active_node_ids = set()
for prompt_data in self.prompt_metadata.values():
for category in self.metadata_categories:
if category in prompt_data:
active_node_ids.update(prompt_data[category].keys())
# Find cache keys that are no longer needed
keys_to_remove = []
for cache_key in self.node_cache:
node_id = cache_key.split(':')[0]
if node_id not in active_node_ids:
keys_to_remove.append(cache_key)
# Remove cache entries that are no longer needed
for key in keys_to_remove:
del self.node_cache[key]
def clear_metadata(self, prompt_id=None):
"""Clear metadata for a specific prompt or reset all data"""
if prompt_id is not None:
if prompt_id in self.prompt_metadata:
del self.prompt_metadata[prompt_id]
# Clean up cache after removing prompt
self.clear_unused_cache()
else:
# Reset all data
self._reset()
def get_first_decoded_image(self, prompt_id=None):
"""Get the first decoded image result"""
key = prompt_id if prompt_id is not None else self.current_prompt_id
if key not in self.prompt_metadata:
return None
metadata = self.prompt_metadata[key]
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
image_data = metadata[IMAGES]["first_decode"]["image"]
# If it's an image batch or tuple, handle various formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
# Return first element of list/tuple
return image_data[0]
# If it's a tensor, return as is for processing in the route handler
return image_data
# If no image is found in the current metadata, try to find it in the cache
# This handles the case where VAEDecode was cached by ComfyUI and not executed
prompt_obj = metadata.get("current_prompt")
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
original_prompt = prompt_obj.original_prompt
for node_id, node_data in original_prompt.items():
class_type = node_data.get("class_type")
if class_type and class_type in NODE_CLASS_MAPPINGS:
class_obj = NODE_CLASS_MAPPINGS[class_type]
class_name = class_obj.__name__
# Check if this is a VAEDecode node
if class_name == "VAEDecode":
# Try to find this node in the cache
cache_key = f"{node_id}:{class_name}"
if cache_key in self.node_cache:
cached_data = self.node_cache[cache_key]
if IMAGES in cached_data and node_id in cached_data[IMAGES]:
image_data = cached_data[IMAGES][node_id]["image"]
# Handle different image formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
return image_data[0]
return image_data
return None

View File

@@ -0,0 +1,353 @@
import os
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES
class NodeMetadataExtractor:
"""Base class for node-specific metadata extraction"""
@staticmethod
def extract(node_id, inputs, outputs, metadata):
"""Extract metadata from node inputs/outputs"""
pass
@staticmethod
def update(node_id, outputs, metadata):
"""Update metadata with node outputs after execution"""
pass
class GenericNodeExtractor(NodeMetadataExtractor):
"""Default extractor for nodes without specific handling"""
@staticmethod
def extract(node_id, inputs, outputs, metadata):
pass
class CheckpointLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "ckpt_name" not in inputs:
return
model_name = inputs.get("ckpt_name")
if model_name:
metadata[MODELS][node_id] = {
"name": model_name,
"type": "checkpoint",
"node_id": node_id
}
class CLIPTextEncodeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "text" not in inputs:
return
text = inputs.get("text", "")
metadata[PROMPTS][node_id] = {
"text": text,
"node_id": node_id
}
class SamplerExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
sampling_params = {}
for key in ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]:
if key in inputs:
sampling_params[key] = inputs[key]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
if isinstance(latent, dict) and "samples" in latent:
# Extract dimensions from latent tensor
samples = latent["samples"]
if hasattr(samples, "shape") and len(samples.shape) >= 3:
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
# Multiply by 8 to get actual pixel dimensions
height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class KSamplerAdvancedExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
sampling_params = {}
for key in ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]:
if key in inputs:
sampling_params[key] = inputs[key]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
if isinstance(latent, dict) and "samples" in latent:
# Extract dimensions from latent tensor
samples = latent["samples"]
if hasattr(samples, "shape") and len(samples.shape) >= 3:
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
# Multiply by 8 to get actual pixel dimensions
height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class LoraLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "lora_name" not in inputs:
return
lora_name = inputs.get("lora_name")
# Extract base filename without extension from path
lora_name = os.path.splitext(os.path.basename(lora_name))[0]
strength_model = round(float(inputs.get("strength_model", 1.0)), 2)
# Use the standardized format with lora_list
metadata[LORAS][node_id] = {
"lora_list": [
{
"name": lora_name,
"strength": strength_model
}
],
"node_id": node_id
}
class ImageSizeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
width = inputs.get("width", 512)
height = inputs.get("height", 512)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class LoraLoaderManagerExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
active_loras = []
# Process lora_stack if available
if "lora_stack" in inputs:
lora_stack = inputs.get("lora_stack", [])
for lora_path, model_strength, clip_strength in lora_stack:
# Extract lora name from path (following the format in lora_loader.py)
lora_name = os.path.splitext(os.path.basename(lora_path))[0]
active_loras.append({
"name": lora_name,
"strength": model_strength
})
# Process loras from inputs
if "loras" in inputs:
loras_data = inputs.get("loras", [])
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
loras_list = loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
# Filter for active loras
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", True) and not lora.get("_isDummy", False):
active_loras.append({
"name": lora.get("name", ""),
"strength": float(lora.get("strength", 1.0))
})
if active_loras:
metadata[LORAS][node_id] = {
"lora_list": active_loras,
"node_id": node_id
}
class FluxGuidanceExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "guidance" not in inputs:
return
guidance_value = inputs.get("guidance")
# Store the guidance value in SAMPLING category
if node_id not in metadata[SAMPLING]:
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value
class UNETLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "unet_name" not in inputs:
return
model_name = inputs.get("unet_name")
if model_name:
metadata[MODELS][node_id] = {
"name": model_name,
"type": "checkpoint",
"node_id": node_id
}
class VAEDecodeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
pass
@staticmethod
def update(node_id, outputs, metadata):
# Ensure IMAGES category exists
if IMAGES not in metadata:
metadata[IMAGES] = {}
# Save image data under node ID index to be captured by caching mechanism
metadata[IMAGES][node_id] = {
"node_id": node_id,
"image": outputs
}
# Only set first_decode if it hasn't been recorded yet
if "first_decode" not in metadata[IMAGES]:
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
class KSamplerSelectExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "sampler_name" not in inputs:
return
sampling_params = {}
if "sampler_name" in inputs:
sampling_params["sampler_name"] = inputs["sampler_name"]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
class BasicSchedulerExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
sampling_params = {}
for key in ["scheduler", "steps", "denoise"]:
if key in inputs:
sampling_params[key] = inputs[key]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
class SamplerCustomAdvancedExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
sampling_params = {}
# Handle noise.seed as seed
if "noise" in inputs and inputs["noise"] is not None and hasattr(inputs["noise"], "seed"):
noise = inputs["noise"]
sampling_params["seed"] = noise.seed
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
if isinstance(latent, dict) and "samples" in latent:
# Extract dimensions from latent tensor
samples = latent["samples"]
if hasattr(samples, "shape") and len(samples.shape) >= 3:
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
# Multiply by 8 to get actual pixel dimensions
height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
# Registry of node-specific extractors
NODE_EXTRACTORS = {
# Sampling
"KSampler": SamplerExtractor,
"KSamplerAdvanced": KSamplerAdvancedExtractor,
"SamplerCustomAdvanced": SamplerCustomAdvancedExtractor, # Updated to use dedicated extractor
# Sampling Selectors
"KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
# Loaders
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
"LoraLoader": LoraLoaderExtractor,
"LoraManagerLoader": LoraLoaderManagerExtractor,
# Conditioning
"CLIPTextEncode": CLIPTextEncodeExtractor,
# Latent
"EmptyLatentImage": ImageSizeExtractor,
# Flux
"FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance
# Image
"VAEDecode": VAEDecodeExtractor, # Added VAEDecode extractor
# Add other nodes as needed
}
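A minimal sketch of how this registry is typically consumed (the dispatch helper below is hypothetical and not part of this diff; it assumes the collector already has the node's id, class name, inputs, outputs and the shared metadata dict in scope):
def run_extractor(class_name, node_id, inputs, outputs, metadata):
    # Fall back to GenericNodeExtractor for node types without a dedicated entry.
    extractor = NODE_EXTRACTORS.get(class_name, GenericNodeExtractor)
    extractor.extract(node_id, inputs, outputs, metadata)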

View File

@@ -0,0 +1,35 @@
import logging
from ..metadata_collector.metadata_processor import MetadataProcessor
logger = logging.getLogger(__name__)
class DebugMetadata:
NAME = "Debug Metadata (LoraManager)"
CATEGORY = "Lora Manager/utils"
DESCRIPTION = "Debug node to verify metadata_processor functionality"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("metadata_json",)
FUNCTION = "process_metadata"
def process_metadata(self, images):
try:
# Get the current execution context's metadata
from ..metadata_collector import get_metadata
metadata = get_metadata()
# Use the MetadataProcessor to convert it to JSON string
metadata_json = MetadataProcessor.to_json(metadata)
return (metadata_json,)
except Exception as e:
logger.error(f"Error processing metadata: {e}")
return ("{}",) # Return empty JSON object in case of error

View File

@@ -5,7 +5,7 @@ from ..services.lora_scanner import LoraScanner
from ..config import config
import asyncio
import os
from .utils import FlexibleOptionalInputType, any_type
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
logger = logging.getLogger(__name__)
@@ -32,48 +32,6 @@ class LoraManagerLoader:
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING)
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
FUNCTION = "load_loras"
async def get_lora_info(self, lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def extract_lora_name(self, lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def _get_loras_list(self, kwargs):
"""Helper to extract loras list from either old or new kwargs format"""
if 'loras' not in kwargs:
return []
loras_data = kwargs['loras']
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
return loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
return loras_data
# Unexpected format
else:
logger.warning(f"Unexpected loras format: {type(loras_data)}")
return []
def load_loras(self, model, text, **kwargs):
"""Loads multiple LoRAs based on the kwargs input and lora_stack."""
@@ -89,14 +47,14 @@ class LoraManagerLoader:
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Extract lora name for trigger words lookup
lora_name = self.extract_lora_name(lora_path)
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
lora_name = extract_lora_name(lora_path)
_, trigger_words = asyncio.run(get_lora_info(lora_name))
all_trigger_words.extend(trigger_words)
loaded_loras.append(f"{lora_name}: {model_strength}")
# Then process loras from kwargs with support for both old and new formats
loras_list = self._get_loras_list(kwargs)
loras_list = get_loras_list(kwargs)
for lora in loras_list:
if not lora.get('active', False):
continue
@@ -105,7 +63,7 @@ class LoraManagerLoader:
strength = float(lora['strength'])
# Get lora path and trigger words
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
# Apply the LoRA using the resolved path
model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength)

View File

@@ -3,7 +3,7 @@ from ..services.lora_scanner import LoraScanner
from ..config import config
import asyncio
import os
from .utils import FlexibleOptionalInputType, any_type
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
import logging
logger = logging.getLogger(__name__)
@@ -29,48 +29,6 @@ class LoraStacker:
RETURN_TYPES = ("LORA_STACK", IO.STRING, IO.STRING)
RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras")
FUNCTION = "stack_loras"
async def get_lora_info(self, lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def extract_lora_name(self, lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def _get_loras_list(self, kwargs):
"""Helper to extract loras list from either old or new kwargs format"""
if 'loras' not in kwargs:
return []
loras_data = kwargs['loras']
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
return loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
return loras_data
# Unexpected format
else:
logger.warning(f"Unexpected loras format: {type(loras_data)}")
return []
def stack_loras(self, text, **kwargs):
"""Stacks multiple LoRAs based on the kwargs input without loading them."""
@@ -84,12 +42,12 @@ class LoraStacker:
stack.extend(lora_stack)
# Get trigger words from existing stack entries
for lora_path, _, _ in lora_stack:
lora_name = self.extract_lora_name(lora_path)
_, trigger_words = asyncio.run(self.get_lora_info(lora_name))
lora_name = extract_lora_name(lora_path)
_, trigger_words = asyncio.run(get_lora_info(lora_name))
all_trigger_words.extend(trigger_words)
# Process loras from kwargs with support for both old and new formats
loras_list = self._get_loras_list(kwargs)
loras_list = get_loras_list(kwargs)
for lora in loras_list:
if not lora.get('active', False):
continue
@@ -99,7 +57,7 @@ class LoraStacker:
clip_strength = model_strength # Using same strength for both as in the original loader
# Get lora path and trigger words
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name))
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
# Add to stack without loading
# replace '/' with os.sep to avoid different OS path format

View File

@@ -5,10 +5,11 @@ import re
import numpy as np
import folder_paths # type: ignore
from ..services.lora_scanner import LoraScanner
from ..workflow.parser import WorkflowParser
from ..services.checkpoint_scanner import CheckpointScanner
from ..metadata_collector.metadata_processor import MetadataProcessor
from ..metadata_collector import get_metadata
from PIL import Image, PngImagePlugin
import piexif
from io import BytesIO
class SaveImage:
NAME = "Save Image (LoraManager)"
@@ -34,8 +35,7 @@ class SaveImage:
"file_format": (["png", "jpeg", "webp"],),
},
"optional": {
"custom_prompt": ("STRING", {"default": "", "forceInput": True}),
"lossless_webp": ("BOOLEAN", {"default": True}),
"lossless_webp": ("BOOLEAN", {"default": False}),
"quality": ("INT", {"default": 100, "min": 1, "max": 100}),
"embed_workflow": ("BOOLEAN", {"default": False}),
"add_counter_to_filename": ("BOOLEAN", {"default": True}),
@@ -54,28 +54,61 @@ class SaveImage:
async def get_lora_hash(self, lora_name):
"""Get the lora hash from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
# Use the new direct filename lookup method
hash_value = scanner.get_hash_by_filename(lora_name)
if hash_value:
return hash_value
# Fallback to old method for compatibility
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
return item.get('sha256')
return None
async def format_metadata(self, parsed_workflow, custom_prompt=None):
async def get_checkpoint_hash(self, checkpoint_path):
"""Get the checkpoint hash from cache"""
scanner = await CheckpointScanner.get_instance()
if not checkpoint_path:
return None
# Extract basename without extension
checkpoint_name = os.path.basename(checkpoint_path)
checkpoint_name = os.path.splitext(checkpoint_name)[0]
# Try direct filename lookup first
hash_value = scanner.get_hash_by_filename(checkpoint_name)
if hash_value:
return hash_value
# Fallback to old method for compatibility
cache = await scanner.get_cached_data()
normalized_path = checkpoint_path.replace('\\', '/')
for item in cache.raw_data:
if item.get('file_name') == checkpoint_name and item.get('file_path').endswith(normalized_path):
return item.get('sha256')
return None
async def format_metadata(self, metadata_dict):
"""Format metadata in the requested format similar to userComment example"""
if not parsed_workflow:
if not metadata_dict:
return ""
# Extract the prompt and negative prompt
prompt = parsed_workflow.get('prompt', '')
negative_prompt = parsed_workflow.get('negative_prompt', '')
# Helper function to only add parameter if value is not None
def add_param_if_not_none(param_list, label, value):
if value is not None:
param_list.append(f"{label}: {value}")
# Override prompt with custom_prompt if provided
if custom_prompt:
prompt = custom_prompt
# Extract the prompt and negative prompt
prompt = metadata_dict.get('prompt', '')
negative_prompt = metadata_dict.get('negative_prompt', '')
# Extract loras from the prompt if present
loras_text = parsed_workflow.get('loras', '')
loras_text = metadata_dict.get('loras', '')
lora_hashes = {}
# If loras are found, add them on a new line after the prompt
@@ -104,11 +137,15 @@ class SaveImage:
params = []
# Add standard parameters in the correct order
if 'steps' in parsed_workflow:
params.append(f"Steps: {parsed_workflow.get('steps')}")
if 'steps' in metadata_dict:
add_param_if_not_none(params, "Steps", metadata_dict.get('steps'))
if 'sampler' in parsed_workflow:
sampler = parsed_workflow.get('sampler')
# Combine sampler and scheduler information
sampler_name = None
scheduler_name = None
if 'sampler' in metadata_dict:
sampler = metadata_dict.get('sampler')
# Convert ComfyUI sampler names to user-friendly names
sampler_mapping = {
'euler': 'Euler',
@@ -128,10 +165,9 @@ class SaveImage:
'ddim': 'DDIM'
}
sampler_name = sampler_mapping.get(sampler, sampler)
params.append(f"Sampler: {sampler_name}")
if 'scheduler' in parsed_workflow:
scheduler = parsed_workflow.get('scheduler')
if 'scheduler' in metadata_dict:
scheduler = metadata_dict.get('scheduler')
scheduler_mapping = {
'normal': 'Simple',
'karras': 'Karras',
@@ -140,29 +176,48 @@ class SaveImage:
'sgm_quadratic': 'SGM Quadratic'
}
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
params.append(f"Schedule type: {scheduler_name}")
# CFG scale (cfg in parsed_workflow)
if 'cfg_scale' in parsed_workflow:
params.append(f"CFG scale: {parsed_workflow.get('cfg_scale')}")
elif 'cfg' in parsed_workflow:
params.append(f"CFG scale: {parsed_workflow.get('cfg')}")
# Add combined sampler and scheduler information
if sampler_name:
if scheduler_name:
params.append(f"Sampler: {sampler_name} {scheduler_name}")
else:
params.append(f"Sampler: {sampler_name}")
# CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg)
if 'guidance' in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('guidance'))
elif 'cfg_scale' in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg_scale'))
elif 'cfg' in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg'))
# Seed
if 'seed' in parsed_workflow:
params.append(f"Seed: {parsed_workflow.get('seed')}")
if 'seed' in metadata_dict:
add_param_if_not_none(params, "Seed", metadata_dict.get('seed'))
# Size
if 'size' in parsed_workflow:
params.append(f"Size: {parsed_workflow.get('size')}")
if 'size' in metadata_dict:
add_param_if_not_none(params, "Size", metadata_dict.get('size'))
# Model info
if 'checkpoint' in parsed_workflow:
# Extract basename without path
checkpoint = os.path.basename(parsed_workflow.get('checkpoint', ''))
# Remove extension if present
checkpoint = os.path.splitext(checkpoint)[0]
params.append(f"Model: {checkpoint}")
if 'checkpoint' in metadata_dict:
# Ensure checkpoint is a string before processing
checkpoint = metadata_dict.get('checkpoint')
if checkpoint is not None:
# Get model hash
model_hash = await self.get_checkpoint_hash(checkpoint)
# Extract basename without path
checkpoint_name = os.path.basename(checkpoint)
# Remove extension if present
checkpoint_name = os.path.splitext(checkpoint_name)[0]
# Add model hash if available
if model_hash:
params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
else:
params.append(f"Model: {checkpoint_name}")
# Add LoRA hashes if available
if lora_hashes:
@@ -181,9 +236,9 @@ class SaveImage:
# credit to nkchocoai
# Add format_filename method to handle pattern substitution
def format_filename(self, filename, parsed_workflow):
def format_filename(self, filename, metadata_dict):
"""Format filename with metadata values"""
if not parsed_workflow:
if not metadata_dict:
return filename
result = re.findall(self.pattern_format, filename)
@@ -191,30 +246,30 @@ class SaveImage:
parts = segment.replace("%", "").split(":")
key = parts[0]
if key == "seed" and 'seed' in parsed_workflow:
filename = filename.replace(segment, str(parsed_workflow.get('seed', '')))
elif key == "width" and 'size' in parsed_workflow:
size = parsed_workflow.get('size', 'x')
if key == "seed" and 'seed' in metadata_dict:
filename = filename.replace(segment, str(metadata_dict.get('seed', '')))
elif key == "width" and 'size' in metadata_dict:
size = metadata_dict.get('size', 'x')
w = size.split('x')[0] if isinstance(size, str) else size[0]
filename = filename.replace(segment, str(w))
elif key == "height" and 'size' in parsed_workflow:
size = parsed_workflow.get('size', 'x')
elif key == "height" and 'size' in metadata_dict:
size = metadata_dict.get('size', 'x')
h = size.split('x')[1] if isinstance(size, str) else size[1]
filename = filename.replace(segment, str(h))
elif key == "pprompt" and 'prompt' in parsed_workflow:
prompt = parsed_workflow.get('prompt', '').replace("\n", " ")
elif key == "pprompt" and 'prompt' in metadata_dict:
prompt = metadata_dict.get('prompt', '').replace("\n", " ")
if len(parts) >= 2:
length = int(parts[1])
prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip())
elif key == "nprompt" and 'negative_prompt' in parsed_workflow:
prompt = parsed_workflow.get('negative_prompt', '').replace("\n", " ")
elif key == "nprompt" and 'negative_prompt' in metadata_dict:
prompt = metadata_dict.get('negative_prompt', '').replace("\n", " ")
if len(parts) >= 2:
length = int(parts[1])
prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip())
elif key == "model" and 'checkpoint' in parsed_workflow:
model = parsed_workflow.get('checkpoint', '')
elif key == "model" and 'checkpoint' in metadata_dict:
model = metadata_dict.get('checkpoint', '')
model = os.path.splitext(os.path.basename(model))[0]
if len(parts) >= 2:
length = int(parts[1])
@@ -224,12 +279,13 @@ class SaveImage:
from datetime import datetime
now = datetime.now()
date_table = {
"yyyy": str(now.year),
"MM": str(now.month).zfill(2),
"dd": str(now.day).zfill(2),
"hh": str(now.hour).zfill(2),
"mm": str(now.minute).zfill(2),
"ss": str(now.second).zfill(2),
"yyyy": f"{now.year:04d}",
"yy": f"{now.year % 100:02d}",
"MM": f"{now.month:02d}",
"dd": f"{now.day:02d}",
"hh": f"{now.hour:02d}",
"mm": f"{now.minute:02d}",
"ss": f"{now.second:02d}",
}
if len(parts) >= 2:
date_format = parts[1]
@@ -245,23 +301,19 @@ class SaveImage:
return filename
def save_images(self, images, filename_prefix, file_format, prompt=None, extra_pnginfo=None,
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True,
custom_prompt=None):
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
"""Save images with metadata"""
results = []
# Parse the workflow using the WorkflowParser
parser = WorkflowParser()
if prompt:
parsed_workflow = parser.parse_workflow(prompt)
else:
parsed_workflow = {}
# Get metadata using the metadata collector
raw_metadata = get_metadata()
metadata_dict = MetadataProcessor.to_dict(raw_metadata)
# Get or create metadata asynchronously
metadata = asyncio.run(self.format_metadata(parsed_workflow, custom_prompt))
metadata = asyncio.run(self.format_metadata(metadata_dict))
# Process filename_prefix with pattern substitution
filename_prefix = self.format_filename(filename_prefix, parsed_workflow)
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
# Get initial save path info once for the batch
full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path(
@@ -283,13 +335,14 @@ class SaveImage:
if add_counter_to_filename:
# Use counter + i to ensure unique filenames for all images in batch
current_counter = counter + i
base_filename += f"_{current_counter:05}"
base_filename += f"_{current_counter:05}_"
# Set file extension and prepare saving parameters
if file_format == "png":
file = base_filename + ".png"
file_extension = ".png"
save_kwargs = {"optimize": True, "compress_level": self.compress_level}
# Remove "optimize": True to match built-in node behavior
save_kwargs = {"compress_level": self.compress_level}
pnginfo = PngImagePlugin.PngInfo()
elif file_format == "jpeg":
file = base_filename + ".jpg"
@@ -298,7 +351,8 @@ class SaveImage:
elif file_format == "webp":
file = base_filename + ".webp"
file_extension = ".webp"
save_kwargs = {"quality": quality, "lossless": lossless_webp}
# Add optimization param to control performance
save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0}
# Full save path
file_path = os.path.join(full_output_folder, file)
@@ -346,8 +400,7 @@ class SaveImage:
return results
def process_image(self, images, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None,
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True,
custom_prompt=""):
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
"""Process and save image with metadata"""
# Make sure the output directory exists
os.makedirs(self.output_dir, exist_ok=True)
@@ -368,8 +421,7 @@ class SaveImage:
lossless_webp,
quality,
embed_workflow,
add_counter_to_filename,
custom_prompt if custom_prompt.strip() else None
add_counter_to_filename
)
return (images,)
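For reference, a small usage sketch of the %key% filename patterns handled by format_filename above (the node instance and metadata values are illustrative; the accepted tokens are defined by self.pattern_format):
# Assumes `node` is a SaveImage instance and metadata_dict mirrors the shape
# produced by MetadataProcessor.to_dict(); all values here are made up.
metadata_dict = {"seed": 123456, "size": "1024x768", "checkpoint": "SDXL/dreamshaper.safetensors"}
prefix = node.format_filename("img_%seed%_%width%x%height%_%model%", metadata_dict)
# -> "img_123456_1024x768_dreamshaper"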

View File

@@ -47,10 +47,10 @@ class TriggerWordToggle:
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
# Send trigger words to frontend
PromptServer.instance.send_sync("trigger_word_update", {
"id": id,
"message": trigger_words
})
# PromptServer.instance.send_sync("trigger_word_update", {
# "id": id,
# "message": trigger_words
# })
filtered_triggers = trigger_words

View File

@@ -30,4 +30,55 @@ class FlexibleOptionalInputType(dict):
return True
any_type = AnyType("*")
# Common methods extracted from lora_loader.py and lora_stacker.py
import os
import logging
import asyncio
from ..services.lora_scanner import LoraScanner
from ..config import config
logger = logging.getLogger(__name__)
async def get_lora_info(lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def extract_lora_name(lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def get_loras_list(kwargs):
"""Helper to extract loras list from either old or new kwargs format"""
if 'loras' not in kwargs:
return []
loras_data = kwargs['loras']
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
return loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
return loras_data
# Unexpected format
else:
logger.warning(f"Unexpected loras format: {type(loras_data)}")
return []
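For reference, a small sketch of the two kwargs shapes that get_loras_list normalizes (the entry contents are illustrative):
entry = {"name": "styleLora", "strength": 0.8, "active": True}
assert get_loras_list({"loras": {"__value__": [entry]}}) == [entry]  # new widget format
assert get_loras_list({"loras": [entry]}) == [entry]                 # old list format
assert get_loras_list({}) == []                                      # no loras provided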

View File

@@ -3,8 +3,10 @@ import json
import logging
from aiohttp import web
from typing import Dict
from server import PromptServer # type: ignore
from ..utils.routes_common import ModelRouteUtils
from ..nodes.utils import get_lora_info
from ..config import config
from ..services.websocket_manager import ws_manager
@@ -50,8 +52,8 @@ class ApiRoutes:
app.router.add_get('/api/lora-roots', routes.get_lora_roots)
app.router.add_get('/api/folders', routes.get_folders)
app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions)
app.router.add_get('/api/civitai/model/{modelVersionId}', routes.get_civitai_model)
app.router.add_get('/api/civitai/model/{hash}', routes.get_civitai_model)
app.router.add_get('/api/civitai/model/version/{modelVersionId}', routes.get_civitai_model_by_version)
app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash)
app.router.add_post('/api/download-lora', routes.download_lora)
app.router.add_post('/api/settings', routes.update_settings)
app.router.add_post('/api/move_model', routes.move_model)
@@ -64,6 +66,9 @@ class ApiRoutes:
app.router.add_get('/api/lora-civitai-url', routes.get_lora_civitai_url) # Add new route for Civitai URL
app.router.add_post('/api/rename_lora', routes.rename_lora) # Add new route for renaming LoRA files
app.router.add_get('/api/loras/scan', routes.scan_loras) # Add new route for scanning LoRA files
# Add the new trigger words route
app.router.add_post('/loramanager/get_trigger_words', routes.get_trigger_words)
# Add update check routes
UpdateRoutes.setup_routes(app)
@@ -120,6 +125,7 @@ class ApiRoutes:
# Get filter parameters
base_models = request.query.get('base_models', None)
tags = request.query.get('tags', None)
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # New parameter
# New parameters for recipe filtering
lora_hash = request.query.get('lora_hash', None)
@@ -150,7 +156,8 @@ class ApiRoutes:
base_models=filters.get('base_model', None),
tags=filters.get('tags', None),
search_options=search_options,
hash_filters=hash_filters
hash_filters=hash_filters,
favorites_only=favorites_only # Pass favorites_only parameter
)
# Get all available folders from cache
@@ -190,6 +197,7 @@ class ApiRoutes:
"from_civitai": lora.get("from_civitai", True),
"usage_tips": lora.get("usage_tips", ""),
"notes": lora.get("notes", ""),
"favorite": lora.get("favorite", False), # Include favorite status in response
"civitai": ModelRouteUtils.filter_civitai_data(lora.get("civitai", {}))
}
@@ -226,7 +234,7 @@ class ApiRoutes:
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
preserve_metadata=False
)
extension = '.webp' # Use .webp without .preview part
@@ -396,25 +404,52 @@ class ApiRoutes:
logger.error(f"Error fetching model versions: {e}")
return web.Response(status=500, text=str(e))
async def get_civitai_model(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID or hash"""
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID"""
try:
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
model_version_id = request.match_info.get('modelVersionId')
if not model_version_id:
hash = request.match_info.get('hash')
model = await self.civitai_client.get_model_by_hash(hash)
return web.json_response(model)
# Get model details from Civitai API
model = await self.civitai_client.get_model_version_info(model_version_id)
model, error_msg = await self.civitai_client.get_model_version_info(model_version_id)
if not model:
# Log warning for failed model retrieval
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
# Determine status code based on error message
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
return web.json_response({
"success": False,
"error": error_msg or "Failed to fetch model information"
}, status=status_code)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details: {e}")
return web.Response(status=500, text=str(e))
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by hash"""
try:
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
hash = request.match_info.get('hash')
model = await self.civitai_client.get_model_by_hash(hash)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details by hash: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def download_lora(self, request: web.Request) -> web.Response:
async with self._download_lock:
@@ -773,7 +808,7 @@ class ApiRoutes:
logger.info(f"Fetching model metadata for model ID: {model_id}")
model_metadata, _ = await self.civitai_client.get_model_metadata(model_id)
if model_metadata:
if (model_metadata):
description = model_metadata.get('description')
tags = model_metadata.get('tags', [])
@@ -994,4 +1029,35 @@ class ApiRoutes:
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_trigger_words(self, request: web.Request) -> web.Response:
"""Get trigger words for specified LoRA models"""
try:
json_data = await request.json()
lora_names = json_data.get("lora_names", [])
node_ids = json_data.get("node_ids", [])
all_trigger_words = []
for lora_name in lora_names:
_, trigger_words = await get_lora_info(lora_name)
all_trigger_words.extend(trigger_words)
# Format the trigger words
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Send update to all connected trigger word toggle nodes
for node_id in node_ids:
PromptServer.instance.send_sync("trigger_word_update", {
"id": node_id,
"message": trigger_words_text
})
return web.json_response({"success": True})
except Exception as e:
logger.error(f"Error getting trigger words: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)

View File

@@ -69,6 +69,7 @@ class CheckpointsRoutes:
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
base_models = request.query.getall('base_model', [])
tags = request.query.getall('tag', [])
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # Add favorites_only parameter
# Process search options
search_options = {
@@ -101,7 +102,8 @@ class CheckpointsRoutes:
base_models=base_models,
tags=tags,
search_options=search_options,
hash_filters=hash_filters
hash_filters=hash_filters,
favorites_only=favorites_only # Pass favorites_only parameter
)
# Format response items
@@ -123,7 +125,8 @@ class CheckpointsRoutes:
async def get_paginated_data(self, page, page_size, sort_by='name',
folder=None, search=None, fuzzy_search=False,
base_models=None, tags=None,
search_options=None, hash_filters=None):
search_options=None, hash_filters=None,
favorites_only=False): # Add favorites_only parameter with default False
"""Get paginated and filtered checkpoint data"""
cache = await self.scanner.get_cached_data()
@@ -181,6 +184,13 @@ class CheckpointsRoutes:
if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R']
]
# Apply favorites filtering if enabled
if favorites_only:
filtered_data = [
cp for cp in filtered_data
if cp.get('favorite', False) is True
]
# Apply folder filtering
if folder is not None:
if search_options.get('recursive', False):
@@ -276,6 +286,7 @@ class CheckpointsRoutes:
"from_civitai": checkpoint.get("from_civitai", True),
"notes": checkpoint.get("notes", ""),
"model_type": checkpoint.get("model_type", "checkpoint"),
"favorite": checkpoint.get("favorite", False),
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
}

View File

@@ -1,20 +1,35 @@
import os
import time
import numpy as np
from PIL import Image
import torch
import io
import logging
from aiohttp import web
from typing import Dict
import tempfile
import json
import asyncio
import sys
from ..utils.exif_utils import ExifUtils
from ..utils.recipe_parsers import RecipeParserFactory
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..config import config
from ..workflow.parser import WorkflowParser
# Check if running in standalone mode
standalone_mode = 'nodes' not in sys.modules
from ..utils.utils import download_civitai_image
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
# Only import MetadataRegistry in non-standalone mode
if not standalone_mode:
# Import metadata_collector functions and classes conditionally
from ..metadata_collector import get_metadata # Add MetadataCollector import
from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import
from ..metadata_collector.metadata_registry import MetadataRegistry
logger = logging.getLogger(__name__)
class RecipeRoutes:
@@ -24,7 +39,7 @@ class RecipeRoutes:
# Initialize service references as None, will be set during async init
self.recipe_scanner = None
self.civitai_client = None
self.parser = WorkflowParser()
# Remove WorkflowParser instance
# Pre-warm the cache
self._init_cache_task = None
@@ -68,6 +83,9 @@ class RecipeRoutes:
# Add route to get recipes for a specific Lora
app.router.add_get('/api/recipes/for-lora', routes.get_recipes_for_lora)
# Add new endpoint for scanning and rebuilding the recipe cache
app.router.add_get('/api/recipes/scan', routes.scan_recipes)
async def _init_cache(self, app):
"""Initialize cache on startup"""
@@ -656,8 +674,8 @@ class RecipeRoutes:
logger.error(f"Error retrieving base models: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
'error': str(e)}
, status=500)
async def share_recipe(self, request: web.Request) -> web.Response:
"""Process a recipe image for sharing by adding metadata to EXIF"""
@@ -786,50 +804,75 @@ class RecipeRoutes:
# Ensure services are initialized
await self.init_services()
reader = await request.multipart()
# Get metadata using the metadata collector instead of workflow parsing
raw_metadata = get_metadata()
metadata_dict = MetadataProcessor.to_dict(raw_metadata)
# Process form data
workflow_json = None
# Check if we have valid metadata
if not metadata_dict:
return web.json_response({"error": "No generation metadata found"}, status=400)
while True:
field = await reader.next()
if field is None:
break
# Get the most recent image from metadata registry instead of temp directory
if not standalone_mode:
metadata_registry = MetadataRegistry()
latest_image = metadata_registry.get_first_decoded_image()
else:
latest_image = None
if not latest_image:
return web.json_response({"error": "No recent images found to use for recipe. Try generating an image first."}, status=400)
# Convert the image data to bytes - handle tuple and tensor cases
logger.debug(f"Image type: {type(latest_image)}")
try:
# Handle the tuple case first
if isinstance(latest_image, tuple):
# Extract the tensor from the tuple
if len(latest_image) > 0:
tensor_image = latest_image[0]
else:
return web.json_response({"error": "Empty image tuple received"}, status=400)
else:
tensor_image = latest_image
if field.name == 'workflow_json':
workflow_text = await field.text()
try:
workflow_json = json.loads(workflow_text)
except:
return web.json_response({"error": "Invalid workflow JSON"}, status=400)
# Get the shape info for debugging
if hasattr(tensor_image, 'shape'):
shape_info = tensor_image.shape
logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
# Convert tensor to numpy array
if isinstance(tensor_image, torch.Tensor):
image_np = tensor_image.cpu().numpy()
else:
image_np = np.array(tensor_image)
# Handle different tensor shapes
# Case: (1, 1, H, W, 3) or (1, H, W, 3) - batch or multi-batch
if len(image_np.shape) > 3:
# Remove batch dimensions until we get to (H, W, 3)
while len(image_np.shape) > 3:
image_np = image_np[0]
# If values are in [0, 1] range, convert to [0, 255]
if image_np.dtype == np.float32 or image_np.dtype == np.float64:
if image_np.max() <= 1.0:
image_np = (image_np * 255).astype(np.uint8)
# Ensure image is in the right format (HWC with RGB channels)
if len(image_np.shape) == 3 and image_np.shape[2] == 3:
pil_image = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
pil_image.save(img_byte_arr, format='PNG')
image = img_byte_arr.getvalue()
else:
return web.json_response({"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, status=400)
except Exception as e:
logger.error(f"Error processing image data: {str(e)}", exc_info=True)
return web.json_response({"error": f"Error processing image: {str(e)}"}, status=400)
if not workflow_json:
return web.json_response({"error": "Missing workflow JSON"}, status=400)
# Find the latest image in the temp directory
temp_dir = config.temp_directory
image_files = []
for file in os.listdir(temp_dir):
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
file_path = os.path.join(temp_dir, file)
image_files.append((file_path, os.path.getmtime(file_path)))
if not image_files:
return web.json_response({"error": "No recent images found to use for recipe"}, status=400)
# Sort by modification time (newest first)
image_files.sort(key=lambda x: x[1], reverse=True)
latest_image_path = image_files[0][0]
# Parse the workflow to extract generation parameters and loras
parsed_workflow = self.parser.parse_workflow(workflow_json)
if not parsed_workflow:
return web.json_response({"error": "Could not extract parameters from workflow"}, status=400)
# Get the lora stack from the parsed workflow
lora_stack = parsed_workflow.get("loras", "")
# Get the lora stack from the metadata
lora_stack = metadata_dict.get("loras", "")
# Parse the lora stack format: "<lora:name:strength> <lora:name2:strength2> ..."
import re
@@ -837,7 +880,7 @@ class RecipeRoutes:
# Check if any loras were found
if not lora_matches:
return web.json_response({"error": "No LoRAs found in the workflow"}, status=400)
return web.json_response({"error": "No LoRAs found in the generation metadata"}, status=400)
# Generate recipe name from the first 3 loras (or less if fewer are available)
loras_for_name = lora_matches[:3] # Take at most 3 loras for the name
@@ -851,10 +894,6 @@ class RecipeRoutes:
recipe_name = " ".join(recipe_name_parts)
# Read the image
with open(latest_image_path, 'rb') as f:
image = f.read()
# Create recipes directory if it doesn't exist
recipes_dir = self.recipe_scanner.recipes_dir
os.makedirs(recipes_dir, exist_ok=True)
@@ -922,8 +961,8 @@ class RecipeRoutes:
"created_date": time.time(),
"base_model": most_common_base_model,
"loras": loras_data,
"checkpoint": parsed_workflow.get("checkpoint", ""),
"gen_params": {key: value for key, value in parsed_workflow.items()
"checkpoint": metadata_dict.get("checkpoint", ""),
"gen_params": {key: value for key, value in metadata_dict.items()
if key not in ['checkpoint', 'loras']},
"loras_stack": lora_stack # Include the original lora stack string
}
@@ -1231,3 +1270,24 @@ class RecipeRoutes:
except Exception as e:
logger.error(f"Error getting recipes for Lora: {str(e)}")
return web.json_response({'success': False, 'error': str(e)}, status=500)
async def scan_recipes(self, request: web.Request) -> web.Response:
"""API endpoint for scanning and rebuilding the recipe cache"""
try:
# Ensure services are initialized
await self.init_services()
# Force refresh the recipe cache
logger.info("Manually triggering recipe cache rebuild")
await self.recipe_scanner.get_cached_data(force_refresh=True)
return web.json_response({
'success': True,
'message': 'Recipe cache refreshed successfully'
})
except Exception as e:
logger.error(f"Error refreshing recipe cache: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
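A quick sketch of triggering the new recipe cache rebuild route from a client (127.0.0.1:8188 is the usual ComfyUI default address and is an assumption here):
import json
import urllib.request
# GET /api/recipes/scan forces a recipe cache refresh.
with urllib.request.urlopen("http://127.0.0.1:8188/api/recipes/scan") as resp:
    print(json.loads(resp.read()))  # expected: {'success': True, 'message': 'Recipe cache refreshed successfully'}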

View File

@@ -0,0 +1,69 @@
import logging
from aiohttp import web
from ..utils.usage_stats import UsageStats
logger = logging.getLogger(__name__)
class UsageStatsRoutes:
"""Routes for handling usage statistics updates"""
@staticmethod
def setup_routes(app):
"""Register usage stats routes"""
app.router.add_post('/loras/api/update-usage-stats', UsageStatsRoutes.update_usage_stats)
app.router.add_get('/loras/api/get-usage-stats', UsageStatsRoutes.get_usage_stats)
@staticmethod
async def update_usage_stats(request):
"""
Update usage statistics based on a prompt_id
Expects a JSON body with:
{
"prompt_id": "string"
}
"""
try:
# Parse the request body
data = await request.json()
prompt_id = data.get('prompt_id')
if not prompt_id:
return web.json_response({
'success': False,
'error': 'Missing prompt_id'
}, status=400)
# Call the UsageStats to process this prompt_id synchronously
usage_stats = UsageStats()
await usage_stats.process_execution(prompt_id)
return web.json_response({
'success': True
})
except Exception as e:
logger.error(f"Failed to update usage stats: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def get_usage_stats(request):
"""Get current usage statistics"""
try:
usage_stats = UsageStats()
stats = await usage_stats.get_stats()
return web.json_response({
'success': True,
'data': stats
})
except Exception as e:
logger.error(f"Failed to get usage stats: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)

py/server_routes.py (new file, 26 lines added)
View File

@@ -0,0 +1,26 @@
from aiohttp import web
from server import PromptServer
from .nodes.utils import get_lora_info
@PromptServer.instance.routes.post("/loramanager/get_trigger_words")
async def get_trigger_words(request):
json_data = await request.json()
lora_names = json_data.get("lora_names", [])
node_ids = json_data.get("node_ids", [])
all_trigger_words = []
for lora_name in lora_names:
_, trigger_words = await get_lora_info(lora_name)
all_trigger_words.extend(trigger_words)
# Format the trigger words
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Send update to all connected trigger word toggle nodes
for node_id in node_ids:
PromptServer.instance.send_sync("trigger_word_update", {
"id": node_id,
"message": trigger_words_text
})
return web.json_response({"success": True})
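A minimal client-side sketch for exercising this route (lora and node id values are illustrative; 127.0.0.1:8188 is the default ComfyUI address and is an assumption here):
import json
import urllib.request
payload = json.dumps({"lora_names": ["myLora"], "node_ids": ["12"]}).encode()
req = urllib.request.Request(
    "http://127.0.0.1:8188/loramanager/get_trigger_words",
    data=payload,
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read()))  # expected: {'success': True}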

View File

@@ -34,6 +34,7 @@ class CivitaiClient:
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
}
self._session = None
self._session_created_at = None
# Set default buffer size to 1MB for higher throughput
self.chunk_size = 1024 * 1024
@@ -44,8 +45,8 @@ class CivitaiClient:
# Optimize TCP connection parameters
connector = aiohttp.TCPConnector(
ssl=True,
limit=10, # Increase parallel connections
ttl_dns_cache=300, # DNS cache time
limit=3, # Further reduced from 5 to 3
ttl_dns_cache=0, # Disabled DNS caching completely
force_close=False, # Keep connections for reuse
enable_cleanup_closed=True
)
@@ -57,7 +58,18 @@ class CivitaiClient:
trust_env=trust_env,
timeout=timeout
)
self._session_created_at = datetime.now()
return self._session
async def _ensure_fresh_session(self):
"""Refresh session if it's been open too long"""
if self._session is not None:
if not hasattr(self, '_session_created_at') or \
(datetime.now() - self._session_created_at).total_seconds() > 300: # 5 minutes
await self.close()
self._session = None
return await self.session
def _parse_content_disposition(self, header: str) -> str:
"""Parse filename from content-disposition header"""
@@ -103,13 +115,15 @@ class CivitaiClient:
Returns:
Tuple[bool, str]: (success, save_path or error message)
"""
session = await self.session
logger.debug(f"Resolving DNS for: {url}")
session = await self._ensure_fresh_session()
try:
headers = self._get_request_headers()
# Add Range header to allow resumable downloads
headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads
logger.debug(f"Starting download from: {url}")
async with session.get(url, headers=headers, allow_redirects=True) as response:
if response.status != 200:
# Handle 401 unauthorized responses
@@ -124,6 +138,7 @@ class CivitaiClient:
return False, "Access forbidden: You don't have permission to download this file."
# Generic error response for other status codes
logger.error(f"Download failed for {url} with status {response.status}")
return False, f"Download failed with status {response.status}"
# Get filename from content-disposition header
@@ -170,7 +185,7 @@ class CivitaiClient:
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
try:
session = await self.session
session = await self._ensure_fresh_session()
async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response:
if response.status == 200:
return await response.json()
@@ -181,7 +196,7 @@ class CivitaiClient:
async def download_preview_image(self, image_url: str, save_path: str):
try:
session = await self.session
session = await self._ensure_fresh_session()
async with session.get(image_url) as response:
if response.status == 200:
content = await response.read()
@@ -196,7 +211,7 @@ class CivitaiClient:
async def get_model_versions(self, model_id: str) -> List[Dict]:
"""Get all versions of a model with local availability info"""
try:
session = await self.session # wait to acquire the session
session = await self._ensure_fresh_session() # Use fresh session
async with session.get(f"{self.base_url}/models/{model_id}") as response:
if response.status != 200:
return None
@@ -210,20 +225,46 @@ class CivitaiClient:
logger.error(f"Error fetching model versions: {e}")
return None
async def get_model_version_info(self, version_id: str) -> Optional[Dict]:
"""Fetch model version metadata from Civitai"""
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from Civitai
Args:
version_id: The Civitai model version ID
Returns:
Tuple[Optional[Dict], Optional[str]]: A tuple containing:
- The model version data or None if not found
- An error message if there was an error, or None on success
"""
try:
session = await self.session
session = await self._ensure_fresh_session()
url = f"{self.base_url}/model-versions/{version_id}"
headers = self._get_request_headers()
logger.debug(f"Resolving DNS for model version info: {url}")
async with session.get(url, headers=headers) as response:
if response.status == 200:
return await response.json()
return None
logger.debug(f"Successfully fetched model version info for: {version_id}")
return await response.json(), None
# Handle specific error cases
if response.status == 404:
# Try to parse the error message
try:
error_data = await response.json()
error_msg = error_data.get('error', f"Model not found (status 404)")
logger.warning(f"Model version not found: {version_id} - {error_msg}")
return None, error_msg
except:
return None, "Model not found (status 404)"
# Other error cases
logger.error(f"Failed to fetch model info for {version_id} (status {response.status})")
return None, f"Failed to fetch model info (status {response.status})"
except Exception as e:
logger.error(f"Error fetching model version info: {e}")
return None
error_msg = f"Error fetching model version info: {e}"
logger.error(error_msg)
return None, error_msg
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
"""Fetch model metadata (description and tags) from Civitai API
@@ -237,7 +278,7 @@ class CivitaiClient:
- The HTTP status code from the request
"""
try:
session = await self.session
session = await self._ensure_fresh_session()
headers = self._get_request_headers()
url = f"{self.base_url}/models/{model_id}"
@@ -281,10 +322,11 @@ class CivitaiClient:
async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
"""Get hash from Civitai API"""
try:
if not self._session:
session = await self._ensure_fresh_session()
if not session:
return None
version_info = await self._session.get(f"{self.base_url}/model-versions/{model_version_id}")
version_info = await session.get(f"{self.base_url}/model-versions/{model_version_id}")
if not version_info or not version_info.json().get('files'):
return None

View File

@@ -86,21 +86,24 @@ class DownloadManager:
# Get version info based on the provided identifier
version_info = None
error_msg = None
if download_url:
# Extract version ID from download URL
version_id = download_url.split('/')[-1]
version_info = await civitai_client.get_model_version_info(version_id)
elif model_version_id:
# Use model version ID directly
version_info = await civitai_client.get_model_version_info(model_version_id)
elif model_hash:
if model_hash:
# Get model by hash
version_info = await civitai_client.get_model_by_hash(model_hash)
elif model_version_id:
# Use model version ID directly
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
elif download_url:
# Extract version ID from download URL
version_id = download_url.split('/')[-1]
version_info, error_msg = await civitai_client.get_model_version_info(version_id)
if not version_info:
return {'success': False, 'error': 'Failed to fetch model metadata'}
if error_msg and "model not found" in error_msg.lower():
return {'success': False, 'error': f'Model not found on Civitai: {error_msg}'}
return {'success': False, 'error': error_msg or 'Failed to fetch model metadata'}
# Check if this is an early access model
if version_info.get('earlyAccessEndsAt'):
@@ -202,7 +205,7 @@ class DownloadManager:
# Check if it's a video or an image
is_video = images[0].get('type') == 'video'
if is_video:
if (is_video):
# For videos, use .mp4 extension
preview_ext = '.mp4'
preview_path = os.path.splitext(save_path)[0] + preview_ext
@@ -229,7 +232,7 @@ class DownloadManager:
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
preserve_metadata=False
)
# Save the optimized image

View File

@@ -408,7 +408,7 @@ class BaseFileMonitor:
def start(self):
"""Start file monitoring"""
if not ENABLE_FILE_MONITORING:
logger.info("File monitoring is disabled via ENABLE_FILE_MONITORING setting")
logger.debug("File monitoring is disabled via ENABLE_FILE_MONITORING setting")
return
for path in self.monitor_paths:
@@ -525,18 +525,18 @@ class CheckpointFileMonitor(BaseFileMonitor):
def start(self):
"""Override start to check global enable flag"""
if not ENABLE_FILE_MONITORING:
logger.info("Checkpoint file monitoring is disabled via ENABLE_FILE_MONITORING setting")
logger.debug("Checkpoint file monitoring is disabled via ENABLE_FILE_MONITORING setting")
return
logger.info("Checkpoint file monitoring is temporarily disabled")
logger.debug("Checkpoint file monitoring is temporarily disabled")
# Skip the actual monitoring setup
pass
async def initialize_paths(self):
"""Initialize monitor paths from scanner - currently disabled"""
if not ENABLE_FILE_MONITORING:
logger.info("Checkpoint path initialization skipped (monitoring disabled)")
logger.debug("Checkpoint path initialization skipped (monitoring disabled)")
return
logger.info("Checkpoint file path initialization skipped (monitoring disabled)")
logger.debug("Checkpoint file path initialization skipped (monitoring disabled)")
pass

View File

@@ -9,7 +9,7 @@ from typing import List, Dict, Optional, Set
from ..utils.models import LoraMetadata
from ..config import config
from .model_scanner import ModelScanner
from .lora_hash_index import LoraHashIndex
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
from .settings_manager import settings
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match
@@ -35,12 +35,12 @@ class LoraScanner(ModelScanner):
# Define supported file extensions
file_extensions = {'.safetensors'}
# Initialize parent class
# Initialize parent class with ModelHashIndex
super().__init__(
model_type="lora",
model_class=LoraMetadata,
file_extensions=file_extensions,
hash_index=LoraHashIndex()
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
)
self._initialized = True
@@ -122,7 +122,8 @@ class LoraScanner(ModelScanner):
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
folder: str = None, search: str = None, fuzzy_search: bool = False,
base_models: list = None, tags: list = None,
search_options: dict = None, hash_filters: dict = None) -> Dict:
search_options: dict = None, hash_filters: dict = None,
favorites_only: bool = False) -> Dict:
"""Get paginated and filtered lora data
Args:
@@ -136,6 +137,7 @@ class LoraScanner(ModelScanner):
tags: List of tags to filter by
search_options: Dictionary with search options (filename, modelname, tags, recursive)
hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes)
favorites_only: Filter for favorite models only
"""
cache = await self.get_cached_data()
@@ -194,6 +196,13 @@ class LoraScanner(ModelScanner):
if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R']
]
# Apply favorites filtering if enabled
if favorites_only:
filtered_data = [
lora for lora in filtered_data
if lora.get('favorite', False) is True
]
# Apply folder filtering
if folder is not None:
if search_options.get('recursive', False):

View File

@@ -1,11 +1,12 @@
from typing import Dict, Optional, Set
import os
class ModelHashIndex:
"""Index for looking up models by hash or path"""
def __init__(self):
self._hash_to_path: Dict[str, str] = {}
self._path_to_hash: Dict[str, str] = {}
self._filename_to_hash: Dict[str, str] = {} # Changed from path_to_hash to filename_to_hash
def add_entry(self, sha256: str, file_path: str) -> None:
"""Add or update hash index entry"""
@@ -15,37 +16,47 @@ class ModelHashIndex:
# Ensure hash is lowercase for consistency
sha256 = sha256.lower()
# Extract filename without extension
filename = self._get_filename_from_path(file_path)
# Remove old path mapping if hash exists
if sha256 in self._hash_to_path:
old_path = self._hash_to_path[sha256]
if old_path in self._path_to_hash:
del self._path_to_hash[old_path]
old_filename = self._get_filename_from_path(old_path)
if old_filename in self._filename_to_hash:
del self._filename_to_hash[old_filename]
# Remove old hash mapping if path exists
if file_path in self._path_to_hash:
old_hash = self._path_to_hash[file_path]
# Remove old hash mapping if filename exists
if filename in self._filename_to_hash:
old_hash = self._filename_to_hash[filename]
if old_hash in self._hash_to_path:
del self._hash_to_path[old_hash]
# Add new mappings
self._hash_to_path[sha256] = file_path
self._path_to_hash[file_path] = sha256
self._filename_to_hash[filename] = sha256
def _get_filename_from_path(self, file_path: str) -> str:
"""Extract filename without extension from path"""
return os.path.splitext(os.path.basename(file_path))[0]
def remove_by_path(self, file_path: str) -> None:
"""Remove entry by file path"""
if file_path in self._path_to_hash:
hash_val = self._path_to_hash[file_path]
filename = self._get_filename_from_path(file_path)
if filename in self._filename_to_hash:
hash_val = self._filename_to_hash[filename]
if hash_val in self._hash_to_path:
del self._hash_to_path[hash_val]
del self._path_to_hash[file_path]
del self._filename_to_hash[filename]
def remove_by_hash(self, sha256: str) -> None:
"""Remove entry by hash"""
sha256 = sha256.lower()
if sha256 in self._hash_to_path:
path = self._hash_to_path[sha256]
if path in self._path_to_hash:
del self._path_to_hash[path]
filename = self._get_filename_from_path(path)
if filename in self._filename_to_hash:
del self._filename_to_hash[filename]
del self._hash_to_path[sha256]
def has_hash(self, sha256: str) -> bool:
@@ -58,20 +69,27 @@ class ModelHashIndex:
def get_hash(self, file_path: str) -> Optional[str]:
"""Get hash for a file path"""
return self._path_to_hash.get(file_path)
filename = self._get_filename_from_path(file_path)
return self._filename_to_hash.get(filename)
def get_hash_by_filename(self, filename: str) -> Optional[str]:
"""Get hash for a filename without extension"""
# Strip extension if present to make the function more flexible
filename = os.path.splitext(filename)[0]
return self._filename_to_hash.get(filename)
def clear(self) -> None:
"""Clear all entries"""
self._hash_to_path.clear()
self._path_to_hash.clear()
self._filename_to_hash.clear()
def get_all_hashes(self) -> Set[str]:
"""Get all hashes in the index"""
return set(self._hash_to_path.keys())
def get_all_paths(self) -> Set[str]:
"""Get all file paths in the index"""
return set(self._path_to_hash.keys())
def get_all_filenames(self) -> Set[str]:
"""Get all filenames in the index"""
return set(self._filename_to_hash.keys())
def __len__(self) -> int:
"""Get number of entries"""

View File

@@ -292,7 +292,7 @@ class ModelScanner:
)
# If force refresh is requested, initialize the cache directly
if force_refresh:
if (force_refresh):
if self._cache is None:
# For initial creation, do a full initialization
await self._initialize_cache()
@@ -553,9 +553,36 @@ class ModelScanner:
logger.debug(f"Created metadata from .civitai.info for {file_path}")
except Exception as e:
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
else:
# Check if metadata exists but civitai field is empty - try to restore from civitai.info
if metadata.civitai is None or metadata.civitai == {}:
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
if os.path.exists(civitai_info_path):
try:
with open(civitai_info_path, 'r', encoding='utf-8') as f:
version_info = json.load(f)
logger.debug(f"Restoring missing civitai data from .civitai.info for {file_path}")
metadata.civitai = version_info
# Ensure tags are also updated if they're missing
if (not metadata.tags or len(metadata.tags) == 0) and 'model' in version_info:
if 'tags' in version_info['model']:
metadata.tags = version_info['model']['tags']
# Also restore description if missing
if (not metadata.modelDescription or metadata.modelDescription == "") and 'model' in version_info:
if 'description' in version_info['model']:
metadata.modelDescription = version_info['model']['description']
# Save the updated metadata
await save_metadata(file_path, metadata)
logger.debug(f"Updated metadata with civitai info for {file_path}")
except Exception as e:
logger.error(f"Error restoring civitai data from .civitai.info for {file_path}: {e}")
if metadata is None:
metadata = await self._get_file_info(file_path)
if metadata is None:
metadata = await self._get_file_info(file_path)
model_data = metadata.to_dict()
@@ -709,6 +736,12 @@ class ModelScanner:
shutil.move(source_metadata, target_metadata)
metadata = await self._update_metadata_paths(target_metadata, target_file)
# Move civitai.info file if exists
source_civitai = os.path.join(source_dir, f"{base_name}.civitai.info")
if os.path.exists(source_civitai):
target_civitai = os.path.join(target_path, f"{base_name}.civitai.info")
shutil.move(source_civitai, target_civitai)
for ext in PREVIEW_EXTENSIONS:
source_preview = os.path.join(source_dir, f"{base_name}{ext}")
if os.path.exists(source_preview):
@@ -805,6 +838,10 @@ class ModelScanner:
def get_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a model by its file path"""
return self._hash_index.get_hash(file_path)
def get_hash_by_filename(self, filename: str) -> Optional[str]:
"""Get hash for a model by its filename without path"""
return self._hash_index.get_hash_by_filename(filename)
# TODO: Adjust this method to use metadata instead of finding the file
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
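The restore branch added above can be summarized as a small standalone helper; a hedged sketch, assuming a LoraMetadata-like object exposing civitai, tags and modelDescription attributes:

    import json
    import os

    def restore_from_civitai_info(file_path, metadata):
        """Backfill empty civitai data, tags and description from <name>.civitai.info."""
        info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
        if metadata.civitai or not os.path.exists(info_path):
            return False
        with open(info_path, 'r', encoding='utf-8') as f:
            version_info = json.load(f)
        metadata.civitai = version_info
        model = version_info.get('model', {})
        if not metadata.tags:
            metadata.tags = model.get('tags', [])
        if not metadata.modelDescription:
            metadata.modelDescription = model.get('description', "")
        return True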

View File

@@ -341,6 +341,10 @@ class RecipeScanner:
metadata_updated = False
for lora in recipe_data['loras']:
# Skip deleted loras that were already marked
if lora.get('isDeleted', False):
continue
# Skip if already has complete information
if 'hash' in lora and 'file_name' in lora and lora['file_name']:
continue
@@ -356,10 +360,17 @@ class RecipeScanner:
metadata_updated = True
else:
# If not in cache, fetch from Civitai
hash_from_civitai = await self._get_hash_from_civitai(model_version_id)
if hash_from_civitai:
lora['hash'] = hash_from_civitai
metadata_updated = True
result = await self._get_hash_from_civitai(model_version_id)
if isinstance(result, tuple):
hash_from_civitai, is_deleted = result
if hash_from_civitai:
lora['hash'] = hash_from_civitai
metadata_updated = True
elif is_deleted:
# Mark the lora as deleted if it was not found on Civitai
lora['isDeleted'] = True
logger.warning(f"Marked lora with modelVersionId {model_version_id} as deleted")
metadata_updated = True
else:
logger.debug(f"Could not get hash for modelVersionId {model_version_id}")
@@ -411,41 +422,26 @@ class RecipeScanner:
logger.error("Failed to get CivitaiClient from ServiceRegistry")
return None
version_info = await civitai_client.get_model_version_info(model_version_id)
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
if not version_info or not version_info.get('files'):
logger.debug(f"No files found in version info for ID: {model_version_id}")
return None
if not version_info:
if error_msg and "model not found" in error_msg.lower():
logger.warning(f"Model with version ID {model_version_id} was not found on Civitai - marking as deleted")
return None, True # Return None hash and True for isDeleted flag
else:
logger.debug(f"Could not get hash for modelVersionId {model_version_id}: {error_msg}")
return None, False # Return None hash but not marked as deleted
# Get hash from the first file
for file_info in version_info.get('files', []):
if file_info.get('hashes', {}).get('SHA256'):
return file_info['hashes']['SHA256']
return file_info['hashes']['SHA256'], False # Return hash with False for isDeleted flag
logger.debug(f"No SHA256 hash found in version info for ID: {model_version_id}")
return None
return None, False
except Exception as e:
logger.error(f"Error getting hash from Civitai: {e}")
return None
async def _get_model_version_name(self, model_version_id: str) -> Optional[str]:
"""Get model version name from Civitai API"""
try:
# Get CivitaiClient from ServiceRegistry
civitai_client = await self._get_civitai_client()
if not civitai_client:
return None
version_info = await civitai_client.get_model_version_info(model_version_id)
if version_info and 'name' in version_info:
return version_info['name']
logger.debug(f"No version name found for modelVersionId {model_version_id}")
return None
except Exception as e:
logger.error(f"Error getting model version name from Civitai: {e}")
return None
return None, False
async def _determine_base_model(self, loras: List[Dict]) -> Optional[str]:
"""Determine the most common base model among LoRAs"""

View File

@@ -203,7 +203,7 @@ class ExifUtils:
return user_comment[:recipe_marker_index] + user_comment[next_line_index:]
@staticmethod
def optimize_image(image_data, target_width=250, format='webp', quality=85, preserve_metadata=True):
def optimize_image(image_data, target_width=250, format='webp', quality=85, preserve_metadata=False):
"""
Optimize an image by resizing and converting to WebP format
@@ -218,98 +218,144 @@ class ExifUtils:
Tuple of (optimized_image_data, extension)
"""
try:
# Extract metadata if needed
# First validate the image data is usable
img = None
if isinstance(image_data, str) and os.path.exists(image_data):
# It's a file path - validate file
try:
with Image.open(image_data) as test_img:
# Verify the image can be fully loaded by accessing its size
width, height = test_img.size
# If we got here, the image is valid
img = Image.open(image_data)
except (IOError, OSError) as e:
logger.error(f"Invalid or corrupt image file: {image_data}: {e}")
raise ValueError(f"Cannot process corrupt image: {e}")
else:
# It's binary data - validate data
try:
with BytesIO(image_data) as temp_buf:
test_img = Image.open(temp_buf)
# Verify the image can be fully loaded
width, height = test_img.size
# If successful, reopen for processing
img = Image.open(BytesIO(image_data))
except Exception as e:
logger.error(f"Invalid binary image data: {e}")
raise ValueError(f"Cannot process corrupt image data: {e}")
# Extract metadata if needed and valid
metadata = None
if preserve_metadata:
if isinstance(image_data, str) and os.path.exists(image_data):
# It's a file path
metadata = ExifUtils.extract_image_metadata(image_data)
img = Image.open(image_data)
else:
# It's binary data
temp_img = BytesIO(image_data)
img = Image.open(temp_img)
# Save to a temporary file to extract metadata
import tempfile
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp_file:
temp_path = temp_file.name
temp_file.write(image_data)
metadata = ExifUtils.extract_image_metadata(temp_path)
os.unlink(temp_path)
else:
# Just open the image without extracting metadata
if isinstance(image_data, str) and os.path.exists(image_data):
img = Image.open(image_data)
else:
img = Image.open(BytesIO(image_data))
try:
if isinstance(image_data, str) and os.path.exists(image_data):
# For file path, extract directly
metadata = ExifUtils.extract_image_metadata(image_data)
else:
# For binary data, save to temp file first
import tempfile
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp_file:
temp_path = temp_file.name
temp_file.write(image_data)
try:
metadata = ExifUtils.extract_image_metadata(temp_path)
except Exception as e:
logger.warning(f"Failed to extract metadata from temp file: {e}")
finally:
# Clean up temp file
try:
os.unlink(temp_path)
except Exception:
pass
except Exception as e:
logger.warning(f"Failed to extract metadata, continuing without it: {e}")
# Continue without metadata
# Calculate new height to maintain aspect ratio
width, height = img.size
new_height = int(height * (target_width / width))
# Resize the image
resized_img = img.resize((target_width, new_height), Image.LANCZOS)
# Resize the image with error handling
try:
resized_img = img.resize((target_width, new_height), Image.LANCZOS)
except Exception as e:
logger.error(f"Failed to resize image: {e}")
# Return original image if resize fails
return image_data, '.jpg' if not isinstance(image_data, str) else os.path.splitext(image_data)[1]
# Save to BytesIO in the specified format
output = BytesIO()
# WebP format
# Set format and extension
if format.lower() == 'webp':
resized_img.save(output, format='WEBP', quality=quality)
extension = '.webp'
# JPEG format
save_format, extension = 'WEBP', '.webp'
elif format.lower() in ('jpg', 'jpeg'):
resized_img.save(output, format='JPEG', quality=quality)
extension = '.jpg'
# PNG format
save_format, extension = 'JPEG', '.jpg'
elif format.lower() == 'png':
resized_img.save(output, format='PNG', optimize=True)
extension = '.png'
save_format, extension = 'PNG', '.png'
else:
# Default to WebP
resized_img.save(output, format='WEBP', quality=quality)
extension = '.webp'
save_format, extension = 'WEBP', '.webp'
# Save with error handling
try:
if save_format == 'PNG':
resized_img.save(output, format=save_format, optimize=True)
else:
resized_img.save(output, format=save_format, quality=quality)
except Exception as e:
logger.error(f"Failed to save optimized image: {e}")
# Return original image if save fails
return image_data, '.jpg' if not isinstance(image_data, str) else os.path.splitext(image_data)[1]
# Get the optimized image data
optimized_data = output.getvalue()
# If we need to preserve metadata, write it to a temporary file
# Handle metadata preservation if requested and available
if preserve_metadata and metadata:
# For WebP format, we'll directly save with metadata
if format.lower() == 'webp':
# Create a new BytesIO with metadata
output_with_metadata = BytesIO()
# Create EXIF data with user comment
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}}
exif_bytes = piexif.dump(exif_dict)
# Save with metadata
resized_img.save(output_with_metadata, format='WEBP', exif=exif_bytes, quality=quality)
optimized_data = output_with_metadata.getvalue()
else:
# For other formats, use the temporary file approach
import tempfile
with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_file:
temp_path = temp_file.name
temp_file.write(optimized_data)
# Add the metadata back
ExifUtils.update_image_metadata(temp_path, metadata)
# Read the file with metadata
with open(temp_path, 'rb') as f:
optimized_data = f.read()
# Clean up
os.unlink(temp_path)
try:
if save_format == 'WEBP':
# For WebP format, directly save with metadata
try:
output_with_metadata = BytesIO()
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}}
exif_bytes = piexif.dump(exif_dict)
resized_img.save(output_with_metadata, format='WEBP', exif=exif_bytes, quality=quality)
optimized_data = output_with_metadata.getvalue()
except Exception as e:
logger.warning(f"Failed to add metadata to WebP, continuing without it: {e}")
else:
# For other formats, use temporary file
import tempfile
with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_file:
temp_path = temp_file.name
temp_file.write(optimized_data)
try:
# Add metadata
ExifUtils.update_image_metadata(temp_path, metadata)
# Read back the file
with open(temp_path, 'rb') as f:
optimized_data = f.read()
except Exception as e:
logger.warning(f"Failed to add metadata to image, continuing without it: {e}")
finally:
# Clean up temp file
try:
os.unlink(temp_path)
except Exception:
pass
except Exception as e:
logger.warning(f"Failed to preserve metadata: {e}, continuing with unmodified output")
return optimized_data, extension
except Exception as e:
logger.error(f"Error optimizing image: {e}", exc_info=True)
# Return original data if optimization fails
# Return original data if optimization completely fails
if isinstance(image_data, str) and os.path.exists(image_data):
with open(image_data, 'rb') as f:
return f.read(), os.path.splitext(image_data)[1]
try:
with open(image_data, 'rb') as f:
return f.read(), os.path.splitext(image_data)[1]
except Exception:
return image_data, '.jpg' # Last resort fallback
return image_data, '.jpg'
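With the default flipped to preserve_metadata=False, preview-generation callers simply omit the flag; an illustrative call, with the import path and file names as placeholders:

    from py.utils.exif_utils import ExifUtils  # assumed module path

    optimized_data, ext = ExifUtils.optimize_image(
        image_data="/models/loras/detail_tweaker.preview.png",  # path or raw bytes
        target_width=250,
        format='webp',
        quality=85,
        # preserve_metadata defaults to False: EXIF is stripped from card previews
    )
    with open(f"/models/loras/detail_tweaker{ext}", "wb") as f:
        f.write(optimized_data)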

View File

@@ -42,7 +42,7 @@ def find_preview_file(base_name: str, dir_path: str) -> str:
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
preserve_metadata=False # Changed from True to False
)
# Save the optimized webp file

View File

@@ -21,6 +21,8 @@ class BaseModelMetadata:
civitai: Optional[Dict] = None # Civitai API data if available
tags: List[str] = None # Model tags
modelDescription: str = "" # Full model description
civitai_deleted: bool = False # Whether deleted from Civitai
favorite: bool = False # Whether the model is a favorite
def __post_init__(self):
# Initialize empty lists to avoid mutable default parameter issue
@@ -64,6 +66,15 @@ class LoraMetadata(BaseModelMetadata):
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
# Extract tags and description if available
tags = []
description = ""
if 'model' in version_info:
if 'tags' in version_info['model']:
tags = version_info['model']['tags']
if 'description' in version_info['model']:
description = version_info['model']['description']
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
@@ -75,7 +86,9 @@ class LoraMetadata(BaseModelMetadata):
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0, # Will be updated after preview download
from_civitai=True,
civitai=version_info
civitai=version_info,
tags=tags,
modelDescription=description
)
@dataclass
@@ -90,6 +103,15 @@ class CheckpointMetadata(BaseModelMetadata):
base_model = determine_base_model(version_info.get('baseModel', ''))
model_type = version_info.get('type', 'checkpoint')
# Extract tags and description if available
tags = []
description = ""
if 'model' in version_info:
if 'tags' in version_info['model']:
tags = version_info['model']['tags']
if 'description' in version_info['model']:
description = version_info['model']['description']
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
@@ -102,6 +124,8 @@ class CheckpointMetadata(BaseModelMetadata):
preview_nsfw_level=0,
from_civitai=True,
civitai=version_info,
model_type=model_type
model_type=model_type,
tags=tags,
modelDescription=description
)
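The tag/description extraction added to both from_civitai builders reads the 'model' sub-object of the version info; a toy payload (entirely made up) makes the expected shape explicit:

    version_info = {
        "baseModel": "SDXL 1.0",
        "model": {
            "name": "Detail Tweaker XL",
            "tags": ["style", "detail"],
            "description": "<p>Adds fine detail.</p>",
        },
        "files": [{"name": "detail_tweaker_xl.safetensors"}],
    }

    model = version_info.get("model", {})
    tags = model.get("tags", [])                 # -> ["style", "detail"]
    description = model.get("description", "")   # -> "<p>Adds fine detail.</p>"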

View File

@@ -45,14 +45,14 @@ class RecipeMetadataParser(ABC):
"""
pass
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info: Dict[str, Any],
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info_tuple: Tuple[Dict[str, Any], Optional[str]],
recipe_scanner=None, base_model_counts=None, hash_value=None) -> Dict[str, Any]:
"""
Populate a lora entry with information from Civitai API response
Args:
lora_entry: The lora entry to populate
civitai_info: The response from Civitai API
civitai_info_tuple: The response tuple from Civitai API (data, error_msg)
recipe_scanner: Optional recipe scanner for local file lookup
base_model_counts: Optional dict to track base model counts
hash_value: Optional hash value to use if not available in civitai_info
@@ -61,6 +61,9 @@ class RecipeMetadataParser(ABC):
The populated lora_entry dict
"""
try:
# Unpack the tuple to get the actual data
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
if civitai_info and civitai_info.get("error") != "Model not found":
# Check if this is an early access lora
if civitai_info.get('earlyAccessEndsAt'):
@@ -94,8 +97,9 @@ class RecipeMetadataParser(ABC):
# Process file information if available
if 'files' in civitai_info:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in civitai_info.get('files', [])
if file.get('type') == 'Model'), None)
if file.get('type') == 'Model' and file.get('primary') == True), None)
if model_file:
# Get size
@@ -241,11 +245,11 @@ class RecipeFormatParser(RecipeMetadataParser):
# Try to get additional info from Civitai if we have a model version ID
if lora.get('modelVersionId') and civitai_client:
try:
civitai_info = await civitai_client.get_model_version_info(lora['modelVersionId'])
civitai_info_tuple = await civitai_client.get_model_version_info(lora['modelVersionId'])
# Populate lora entry with Civitai info
lora_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
civitai_info_tuple,
recipe_scanner,
None, # No need to track base model counts
lora['hash']
@@ -336,12 +340,13 @@ class StandardMetadataParser(RecipeMetadataParser):
# Get additional info from Civitai if client is available
if civitai_client:
try:
civitai_info = await civitai_client.get_model_version_info(model_version_id)
civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id)
# Populate lora entry with Civitai info
lora_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner
civitai_info_tuple,
recipe_scanner,
base_model_counts
)
except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA: {e}")
@@ -398,27 +403,43 @@ class StandardMetadataParser(RecipeMetadataParser):
# Extract Civitai resources
if 'Civitai resources:' in user_comment:
resources_part = user_comment.split('Civitai resources:', 1)[1]
if '],' in resources_part:
resources_json = resources_part.split('],', 1)[0] + ']'
try:
resources = json.loads(resources_json)
# Filter loras and checkpoints
for resource in resources:
if resource.get('type') == 'lora':
# Ensure the weight field is preserved correctly
lora_entry = resource.copy()
# Default weight to 1.0 if it is missing
if 'weight' not in lora_entry:
lora_entry['weight'] = 1.0
# Ensure modelVersionName is included
if 'modelVersionName' not in lora_entry:
lora_entry['modelVersionName'] = ''
metadata['loras'].append(lora_entry)
elif resource.get('type') == 'checkpoint':
metadata['checkpoint'] = resource
except json.JSONDecodeError:
pass
resources_part = user_comment.split('Civitai resources:', 1)[1].strip()
# Look for the opening and closing brackets to extract the JSON array
if resources_part.startswith('['):
# Find the position of the closing bracket
bracket_count = 0
end_pos = -1
for i, char in enumerate(resources_part):
if char == '[':
bracket_count += 1
elif char == ']':
bracket_count -= 1
if bracket_count == 0:
end_pos = i
break
if end_pos != -1:
resources_json = resources_part[:end_pos+1]
try:
resources = json.loads(resources_json)
# Filter loras and checkpoints
for resource in resources:
if resource.get('type') == 'lora':
# Ensure the weight field is preserved correctly
lora_entry = resource.copy()
# Default weight to 1.0 if it is missing
if 'weight' not in lora_entry:
lora_entry['weight'] = 1.0
# Ensure modelVersionName is included
if 'modelVersionName' not in lora_entry:
lora_entry['modelVersionName'] = ''
metadata['loras'].append(lora_entry)
elif resource.get('type') == 'checkpoint':
metadata['checkpoint'] = resource
except json.JSONDecodeError:
pass
return metadata
except Exception as e:
@@ -621,11 +642,11 @@ class ComfyMetadataParser(RecipeMetadataParser):
# Get additional info from Civitai if client is available
if civitai_client:
try:
civitai_info = await civitai_client.get_model_version_info(model_version_id)
civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id)
# Populate lora entry with Civitai info
lora_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
civitai_info_tuple,
recipe_scanner
)
except Exception as e:
@@ -660,7 +681,8 @@ class ComfyMetadataParser(RecipeMetadataParser):
# Get additional checkpoint info from Civitai
if civitai_client:
try:
civitai_info = await civitai_client.get_model_version_info(checkpoint_version_id)
civitai_info_tuple = await civitai_client.get_model_version_info(checkpoint_version_id)
civitai_info, _ = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
# Populate checkpoint with Civitai info
checkpoint = await self.populate_checkpoint_from_civitai(checkpoint, civitai_info)
except Exception as e:
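The bracket-matching logic that replaces the old '],'-split is the interesting part of this hunk; a standalone sketch of it, exercised against a made-up user comment:

    import json

    def extract_civitai_resources(user_comment: str):
        """Return the parsed 'Civitai resources:' JSON array, or [] if absent/malformed."""
        if 'Civitai resources:' not in user_comment:
            return []
        part = user_comment.split('Civitai resources:', 1)[1].strip()
        if not part.startswith('['):
            return []
        depth = 0
        for i, char in enumerate(part):
            if char == '[':
                depth += 1
            elif char == ']':
                depth -= 1
                if depth == 0:
                    try:
                        return json.loads(part[:i + 1])
                    except json.JSONDecodeError:
                        return []
        return []

    comment = 'Steps: 30, Civitai resources: [{"type": "lora", "weight": 0.8}], Civitai metadata: {}'
    print(extract_civitai_resources(comment))  # -> [{'type': 'lora', 'weight': 0.8}]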

View File

@@ -95,7 +95,7 @@ class ModelRouteUtils:
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
preserve_metadata=False
)
# Save the optimized WebP image
@@ -387,7 +387,7 @@ class ModelRouteUtils:
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=True
preserve_metadata=False
)
extension = '.webp' # Use .webp without .preview part

py/utils/usage_stats.py Normal file (273 lines)
View File

@@ -0,0 +1,273 @@
import os
import json
import sys
import time
import asyncio
import logging
from typing import Dict, Set
from ..config import config
from ..services.service_registry import ServiceRegistry
# Check if running in standalone mode
standalone_mode = 'nodes' not in sys.modules
if not standalone_mode:
from ..metadata_collector.metadata_registry import MetadataRegistry
from ..metadata_collector.constants import MODELS, LORAS
logger = logging.getLogger(__name__)
class UsageStats:
"""Track usage statistics for models and save to JSON"""
_instance = None
_lock = asyncio.Lock() # For thread safety
# Default stats file name
STATS_FILENAME = "lora_manager_stats.json"
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
# Initialize stats storage
self.stats = {
"checkpoints": {}, # sha256 -> count
"loras": {}, # sha256 -> count
"total_executions": 0,
"last_save_time": 0
}
# Queue for prompt_ids to process
self.pending_prompt_ids = set()
# Load existing stats if available
self._stats_file_path = self._get_stats_file_path()
self._load_stats()
# Save interval in seconds
self.save_interval = 90 # 1.5 minutes
# Start background task to process queued prompt_ids
self._bg_task = asyncio.create_task(self._background_processor())
self._initialized = True
logger.info("Usage statistics tracker initialized")
def _get_stats_file_path(self) -> str:
"""Get the path to the stats JSON file"""
if not config.loras_roots or len(config.loras_roots) == 0:
# Fallback to temporary directory if no lora roots
return os.path.join(config.temp_directory, self.STATS_FILENAME)
# Use the first lora root
return os.path.join(config.loras_roots[0], self.STATS_FILENAME)
def _load_stats(self):
"""Load existing statistics from file"""
try:
if os.path.exists(self._stats_file_path):
with open(self._stats_file_path, 'r', encoding='utf-8') as f:
loaded_stats = json.load(f)
# Update our stats with loaded data
if isinstance(loaded_stats, dict):
# Update individual sections to maintain structure
if "checkpoints" in loaded_stats and isinstance(loaded_stats["checkpoints"], dict):
self.stats["checkpoints"] = loaded_stats["checkpoints"]
if "loras" in loaded_stats and isinstance(loaded_stats["loras"], dict):
self.stats["loras"] = loaded_stats["loras"]
if "total_executions" in loaded_stats:
self.stats["total_executions"] = loaded_stats["total_executions"]
logger.info(f"Loaded usage statistics from {self._stats_file_path}")
except Exception as e:
logger.error(f"Error loading usage statistics: {e}")
async def save_stats(self, force=False):
"""Save statistics to file"""
try:
# Only save if it's been at least save_interval since last save or force is True
current_time = time.time()
if not force and (current_time - self.stats.get("last_save_time", 0)) < self.save_interval:
return False
# Use a lock to prevent concurrent writes
async with self._lock:
# Update last save time
self.stats["last_save_time"] = current_time
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(self._stats_file_path), exist_ok=True)
# Write to a temporary file first, then move it to avoid corruption
temp_path = f"{self._stats_file_path}.tmp"
with open(temp_path, 'w', encoding='utf-8') as f:
json.dump(self.stats, f, indent=2, ensure_ascii=False)
# Replace the old file with the new one
os.replace(temp_path, self._stats_file_path)
logger.debug(f"Saved usage statistics to {self._stats_file_path}")
return True
except Exception as e:
logger.error(f"Error saving usage statistics: {e}", exc_info=True)
return False
def register_execution(self, prompt_id):
"""Register a completed execution by prompt_id for later processing"""
if prompt_id:
self.pending_prompt_ids.add(prompt_id)
async def _background_processor(self):
"""Background task to process queued prompt_ids"""
try:
while True:
# Wait a short interval before checking for new prompt_ids
await asyncio.sleep(5) # Check every 5 seconds
# Process any pending prompt_ids
if self.pending_prompt_ids:
async with self._lock:
# Get a copy of the set and clear original
prompt_ids = self.pending_prompt_ids.copy()
self.pending_prompt_ids.clear()
# Process each prompt_id
registry = MetadataRegistry()
for prompt_id in prompt_ids:
try:
metadata = registry.get_metadata(prompt_id)
await self._process_metadata(metadata)
except Exception as e:
logger.error(f"Error processing prompt_id {prompt_id}: {e}")
# Periodically save stats
await self.save_stats()
except asyncio.CancelledError:
# Task was cancelled, clean up
await self.save_stats(force=True)
except Exception as e:
logger.error(f"Error in background processing task: {e}", exc_info=True)
# Restart the task after a delay if it fails
asyncio.create_task(self._restart_background_task())
async def _restart_background_task(self):
"""Restart the background task after a delay"""
await asyncio.sleep(30) # Wait 30 seconds before restarting
self._bg_task = asyncio.create_task(self._background_processor())
async def _process_metadata(self, metadata):
"""Process metadata from an execution"""
if not metadata or not isinstance(metadata, dict):
return
# Increment total executions count
self.stats["total_executions"] += 1
# Process checkpoints
if MODELS in metadata and isinstance(metadata[MODELS], dict):
await self._process_checkpoints(metadata[MODELS])
# Process loras
if LORAS in metadata and isinstance(metadata[LORAS], dict):
await self._process_loras(metadata[LORAS])
async def _process_checkpoints(self, models_data):
"""Process checkpoint models from metadata"""
try:
# Get checkpoint scanner service
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
if not checkpoint_scanner:
logger.warning("Checkpoint scanner not available for usage tracking")
return
for node_id, model_info in models_data.items():
if not isinstance(model_info, dict):
continue
# Check if this is a checkpoint model
model_type = model_info.get("type")
if model_type == "checkpoint":
model_name = model_info.get("name")
if not model_name:
continue
# Clean up filename (remove extension if present)
model_filename = os.path.splitext(os.path.basename(model_name))[0]
# Get hash for this checkpoint
model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
if model_hash:
# Update stats for this checkpoint
self.stats["checkpoints"][model_hash] = self.stats["checkpoints"].get(model_hash, 0) + 1
except Exception as e:
logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)
async def _process_loras(self, loras_data):
"""Process LoRA models from metadata"""
try:
# Get LoRA scanner service
lora_scanner = await ServiceRegistry.get_lora_scanner()
if not lora_scanner:
logger.warning("LoRA scanner not available for usage tracking")
return
for node_id, lora_info in loras_data.items():
if not isinstance(lora_info, dict):
continue
# Get the list of LoRAs from standardized format
lora_list = lora_info.get("lora_list", [])
for lora in lora_list:
if not isinstance(lora, dict):
continue
lora_name = lora.get("name")
if not lora_name:
continue
# Get hash for this LoRA
lora_hash = lora_scanner.get_hash_by_filename(lora_name)
if lora_hash:
# Update stats for this LoRA
self.stats["loras"][lora_hash] = self.stats["loras"].get(lora_hash, 0) + 1
except Exception as e:
logger.error(f"Error processing LoRA usage: {e}", exc_info=True)
async def get_stats(self):
"""Get current usage statistics"""
return self.stats
async def get_model_usage_count(self, model_type, sha256):
"""Get usage count for a specific model by hash"""
if model_type == "checkpoint":
return self.stats["checkpoints"].get(sha256, 0)
elif model_type == "lora":
return self.stats["loras"].get(sha256, 0)
return 0
async def process_execution(self, prompt_id):
"""Process a prompt execution immediately (synchronous approach)"""
if not prompt_id:
return
try:
# Process metadata for this prompt_id
registry = MetadataRegistry()
metadata = registry.get_metadata(prompt_id)
if metadata:
await self._process_metadata(metadata)
# Save stats if needed
await self.save_stats()
except Exception as e:
logger.error(f"Error processing prompt_id {prompt_id}: {e}", exc_info=True)

View File

@@ -1,7 +1,7 @@
[project]
name = "comfyui-lora-manager"
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration."
version = "0.8.6"
version = "0.8.10"
license = {file = "LICENSE"}
dependencies = [
"aiohttp",
@@ -12,7 +12,8 @@ dependencies = [
"piexif",
"Pillow",
"olefile", # for getting rid of warning message
"requests"
"requests",
"toml"
]
[project.urls]

View File

@@ -6,4 +6,7 @@ beautifulsoup4
piexif
Pillow
olefile
requests
requests
toml
numpy
torch

settings.json.example Normal file (14 lines)
View File

@@ -0,0 +1,14 @@
{
"civitai_api_key": "your_civitai_api_key_here",
"show_only_sfw": false,
"folder_paths": {
"loras": [
"C:/path/to/your/loras_folder",
"C:/path/to/another/loras_folder"
],
"checkpoints": [
"C:/path/to/your/checkpoints_folder",
"C:/path/to/another/checkpoints_folder"
]
}
}

standalone.py Normal file (347 lines)
View File

@@ -0,0 +1,347 @@
import os
import sys
import json
# Create mock folder_paths module BEFORE any other imports
class MockFolderPaths:
@staticmethod
def get_folder_paths(folder_name):
# Load paths from settings.json
settings_path = os.path.join(os.path.dirname(__file__), 'settings.json')
try:
if os.path.exists(settings_path):
with open(settings_path, 'r', encoding='utf-8') as f:
settings = json.load(f)
# For diffusion_models, combine unet and diffusers paths
if folder_name == "diffusion_models":
paths = []
if 'folder_paths' in settings:
if 'unet' in settings['folder_paths']:
paths.extend(settings['folder_paths']['unet'])
if 'diffusers' in settings['folder_paths']:
paths.extend(settings['folder_paths']['diffusers'])
# Filter out paths that don't exist
valid_paths = [p for p in paths if os.path.exists(p)]
if valid_paths:
return valid_paths
else:
print(f"Warning: No valid paths found for {folder_name}")
# For other folder names, return their paths directly
elif 'folder_paths' in settings and folder_name in settings['folder_paths']:
paths = settings['folder_paths'][folder_name]
valid_paths = [p for p in paths if os.path.exists(p)]
if valid_paths:
return valid_paths
else:
print(f"Warning: No valid paths found for {folder_name}")
except Exception as e:
print(f"Error loading folder paths from settings: {e}")
# Fallback to empty list if no paths found
return []
@staticmethod
def get_temp_directory():
return os.path.join(os.path.dirname(__file__), 'temp')
@staticmethod
def set_temp_directory(path):
os.makedirs(path, exist_ok=True)
return path
# Create mock server module with PromptServer
class MockPromptServer:
def __init__(self):
self.app = None
def send_sync(self, *args, **kwargs):
pass
# Create mock metadata_collector module
class MockMetadataCollector:
def init(self):
pass
def get_metadata(self, prompt_id=None):
return {}
# Initialize basic mocks before any imports
sys.modules['folder_paths'] = MockFolderPaths()
sys.modules['server'] = type('server', (), {'PromptServer': MockPromptServer()})
sys.modules['py.metadata_collector'] = MockMetadataCollector()
# Now we can safely import modules that depend on folder_paths and server
import argparse
import asyncio
import logging
from aiohttp import web
# Setup logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger("lora-manager-standalone")
# Configure aiohttp access logger to be less verbose
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
# Now we can import the global config from our local modules
from py.config import config
class StandaloneServer:
"""Server implementation for standalone mode"""
def __init__(self):
self.app = web.Application(logger=logger)
self.instance = self # Make it compatible with PromptServer.instance pattern
# Ensure the app's access logger is configured to reduce verbosity
self.app._subapps = [] # Ensure this exists to avoid AttributeError
# Configure access logging for the app
self.app.on_startup.append(self._configure_access_logger)
async def _configure_access_logger(self, app):
"""Configure access logger to reduce verbosity"""
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
# If using aiohttp>=3.8.0, configure access logger through app directly
if hasattr(app, 'access_logger'):
app.access_logger.setLevel(logging.WARNING)
async def setup(self):
"""Set up the standalone server"""
# Create placeholders for compatibility with ComfyUI's implementation
self.last_prompt_id = None
self.last_node_id = None
self.client_id = None
# Set up routes
self.setup_routes()
# Add startup and shutdown handlers
self.app.on_startup.append(self.on_startup)
self.app.on_shutdown.append(self.on_shutdown)
def setup_routes(self):
"""Set up basic routes"""
# Add a simple status endpoint
self.app.router.add_get('/', self.handle_status)
async def handle_status(self, request):
"""Handle status request by redirecting to loras page"""
# Redirect to loras page instead of showing status
raise web.HTTPFound('/loras')
# Original JSON response (commented out)
# return web.json_response({
# "status": "running",
# "mode": "standalone",
# "loras_roots": config.loras_roots,
# "checkpoints_roots": config.checkpoints_roots
# })
async def on_startup(self, app):
"""Startup handler"""
logger.info("LoRA Manager standalone server starting...")
async def on_shutdown(self, app):
"""Shutdown handler"""
logger.info("LoRA Manager standalone server shutting down...")
def send_sync(self, event_type, data, sid=None):
"""Stub for compatibility with PromptServer"""
# In standalone mode, we don't have the same websocket system
pass
async def start(self, host='127.0.0.1', port=8188):
"""Start the server"""
runner = web.AppRunner(self.app)
await runner.setup()
site = web.TCPSite(runner, host, port)
await site.start()
# Log the server address with a clickable localhost URL regardless of the actual binding
logger.info(f"Server started at http://127.0.0.1:{port}")
# Keep the server running
while True:
await asyncio.sleep(3600) # Sleep for a long time
async def publish_loop(self):
"""Stub for compatibility with PromptServer"""
# This method exists in ComfyUI's server but we don't need it
pass
# After all mocks are in place, import LoraManager
from py.lora_manager import LoraManager
class StandaloneLoraManager(LoraManager):
"""Extended LoraManager for standalone mode"""
@classmethod
def add_routes(cls, server_instance):
"""Initialize and register all routes for standalone mode"""
app = server_instance.app
# Store app in a global-like location for compatibility
sys.modules['server'].PromptServer.instance = server_instance
# Configure aiohttp access logger to be less verbose
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
added_targets = set() # Track already added target paths
# Add static routes for each lora root
for idx, root in enumerate(config.loras_roots, start=1):
if not os.path.exists(root):
logger.warning(f"Lora root path does not exist: {root}")
continue
preview_path = f'/loras_static/root{idx}/preview'
# Check if this root is a link path in the mappings
real_root = root
for target, link in config._path_mappings.items():
if os.path.normpath(link) == os.path.normpath(root):
# If so, route should point to the target (real path)
real_root = target
break
# Normalize and standardize path display for consistency
display_root = real_root.replace('\\', '/')
# Add static route for original path - use the normalized path
app.router.add_static(preview_path, real_root)
logger.info(f"Added static route {preview_path} -> {display_root}")
# Record route mapping with normalized path
config.add_route_mapping(real_root, preview_path)
added_targets.add(os.path.normpath(real_root))
# Add static routes for each checkpoint root
for idx, root in enumerate(config.checkpoints_roots, start=1):
if not os.path.exists(root):
logger.warning(f"Checkpoint root path does not exist: {root}")
continue
preview_path = f'/checkpoints_static/root{idx}/preview'
# Check if this root is a link path in the mappings
real_root = root
for target, link in config._path_mappings.items():
if os.path.normpath(link) == os.path.normpath(root):
# If so, route should point to the target (real path)
real_root = target
break
# Normalize and standardize path display for consistency
display_root = real_root.replace('\\', '/')
# Add static route for original path
app.router.add_static(preview_path, real_root)
logger.info(f"Added static route {preview_path} -> {display_root}")
# Record route mapping
config.add_route_mapping(real_root, preview_path)
added_targets.add(os.path.normpath(real_root))
# Add static routes for symlink target paths that aren't already covered
link_idx = {
'lora': 1,
'checkpoint': 1
}
for target_path, link_path in config._path_mappings.items():
norm_target = os.path.normpath(target_path)
if norm_target not in added_targets:
# Determine if this is a checkpoint or lora link based on path
is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.checkpoints_roots)
is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.checkpoints_roots)
if is_checkpoint:
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
link_idx["checkpoint"] += 1
else:
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
link_idx["lora"] += 1
# Display path with forward slashes for consistency
display_target = target_path.replace('\\', '/')
app.router.add_static(route_path, target_path)
logger.info(f"Added static route for link target {route_path} -> {display_target}")
config.add_route_mapping(target_path, route_path)
added_targets.add(norm_target)
# Add static route for plugin assets
app.router.add_static('/loras_static', config.static_path)
# Setup feature routes
from py.routes.lora_routes import LoraRoutes
from py.routes.api_routes import ApiRoutes
from py.routes.recipe_routes import RecipeRoutes
from py.routes.checkpoints_routes import CheckpointsRoutes
from py.routes.update_routes import UpdateRoutes
from py.routes.usage_stats_routes import UsageStatsRoutes
lora_routes = LoraRoutes()
checkpoints_routes = CheckpointsRoutes()
# Initialize routes
lora_routes.setup_routes(app)
checkpoints_routes.setup_routes(app)
ApiRoutes.setup_routes(app)
RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app)
UsageStatsRoutes.setup_routes(app)
# Schedule service initialization
app.on_startup.append(lambda app: cls._initialize_services())
# Add cleanup
app.on_shutdown.append(cls._cleanup)
app.on_shutdown.append(ApiRoutes.cleanup)
def parse_args():
"""Parse command line arguments"""
parser = argparse.ArgumentParser(description="LoRA Manager Standalone Server")
parser.add_argument("--host", type=str, default="0.0.0.0",
help="Host address to bind the server to (default: 0.0.0.0)")
parser.add_argument("--port", type=int, default=8188,
help="Port to bind the server to (default: 8188, access via http://localhost:8188/loras)")
# parser.add_argument("--loras", type=str, nargs="+",
# help="Additional paths to LoRA model directories (optional if settings.json has paths)")
# parser.add_argument("--checkpoints", type=str, nargs="+",
# help="Additional paths to checkpoint model directories (optional if settings.json has paths)")
parser.add_argument("--log-level", type=str, default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Logging level")
return parser.parse_args()
async def main():
"""Main entry point for standalone mode"""
args = parse_args()
# Set log level
logging.getLogger().setLevel(getattr(logging, args.log_level))
# Explicitly configure aiohttp access logger regardless of selected log level
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
# Create the server instance
server = StandaloneServer()
# Initialize routes via the standalone lora manager
StandaloneLoraManager.add_routes(server)
# Set up and start the server
await server.setup()
await server.start(host=args.host, port=args.port)
if __name__ == "__main__":
try:
# Run the main function
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Server stopped by user")

View File

@@ -59,6 +59,16 @@ html, body {
--scrollbar-width: 8px; /* Scrollbar width variable */
}
html[data-theme="dark"] {
background-color: #1a1a1a !important;
color-scheme: dark;
}
html[data-theme="light"] {
background-color: #ffffff !important;
color-scheme: light;
}
[data-theme="dark"] {
--bg-color: #1a1a1a;
--text-color: #e0e0e0;

View File

@@ -192,12 +192,43 @@
margin-left: var(--space-1);
cursor: pointer;
color: white;
transition: opacity 0.2s;
font-size: 0.9em;
transition: opacity 0.2s, transform 0.15s ease;
font-size: 1.0em; /* Increased from 0.9em for better visibility */
width: 16px; /* Fixed width for consistent spacing */
height: 16px; /* Fixed height for larger touch target */
display: flex;
align-items: center;
justify-content: center;
border-radius: 50%;
padding: 4px; /* Add padding to increase clickable area */
box-sizing: content-box; /* Ensure padding adds to dimensions */
position: relative; /* For proper positioning */
margin: 0; /* Reset margin */
}
.card-actions i::before {
position: absolute; /* Position the icon glyph */
top: 50%;
left: 50%;
transform: translate(-50%, -50%); /* Center the icon */
}
.card-actions {
display: flex;
gap: var(--space-1); /* Use gap instead of margin for spacing between icons */
align-items: center;
}
.card-actions i:hover {
opacity: 0.8;
opacity: 0.9;
transform: scale(1.1);
background-color: rgba(255, 255, 255, 0.1);
}
/* Style for active favorites */
.favorite-active {
color: #ffc107 !important; /* Gold color for favorites */
text-shadow: 0 0 5px rgba(255, 193, 7, 0.5);
}
/* Responsive design */

View File

@@ -81,6 +81,22 @@
opacity: 1;
}
/* Controls */
.control-group button.favorite-filter {
position: relative;
overflow: hidden;
}
.control-group button.favorite-filter.active {
background: var(--lora-accent);
color: white;
}
.control-group button.favorite-filter i {
margin-right: 4px;
color: #ffc107;
}
/* Active state for buttons that can be toggled */
.control-group button.active {
background: var(--lora-accent);

View File

@@ -2,7 +2,7 @@
import { state, getCurrentPageState } from '../state/index.js';
import { showToast } from '../utils/uiHelpers.js';
import { showDeleteModal, confirmDelete } from '../utils/modalUtils.js';
import { getSessionItem } from '../utils/storageHelpers.js';
import { getSessionItem, saveMapToStorage } from '../utils/storageHelpers.js';
/**
* Shared functionality for handling models (loras and checkpoints)
@@ -45,6 +45,11 @@ export async function loadMoreModels(options = {}) {
params.append('folder', pageState.activeFolder);
}
// Add favorites filter parameter if enabled
if (pageState.showFavoritesOnly) {
params.append('favorites_only', 'true');
}
// Add search parameters if there's a search term
if (pageState.filters?.search) {
params.append('search', pageState.filters.search);
@@ -424,12 +429,20 @@ async function uploadPreview(filePath, file, modelType = 'lora') {
const previewContainer = card.querySelector('.card-preview');
const oldPreview = previewContainer.querySelector('img, video');
// For LoRA models, use timestamp to prevent caching
if (modelType === 'lora') {
state.previewVersions?.set(filePath, Date.now());
// Get the current page's previewVersions Map based on model type
const pageType = modelType === 'checkpoint' ? 'checkpoints' : 'loras';
const previewVersions = state.pages[pageType].previewVersions;
// Update the version timestamp
const timestamp = Date.now();
if (previewVersions) {
previewVersions.set(filePath, timestamp);
// Save the updated Map to localStorage
const storageKey = modelType === 'checkpoint' ? 'checkpoint_preview_versions' : 'lora_preview_versions';
saveMapToStorage(storageKey, previewVersions);
}
const timestamp = Date.now();
const previewUrl = data.preview_url ?
`${data.preview_url}?t=${timestamp}` :
`/api/model/preview_image?path=${encodeURIComponent(filePath)}&t=${timestamp}`;

View File

@@ -5,7 +5,8 @@ import {
refreshModels as baseRefreshModels,
deleteModel as baseDeleteModel,
replaceModelPreview,
fetchCivitaiMetadata
fetchCivitaiMetadata,
refreshSingleModelMetadata
} from './baseModelApi.js';
// Load more checkpoints with pagination
@@ -54,4 +55,34 @@ export async function fetchCivitai() {
fetchEndpoint: '/api/checkpoints/fetch-all-civitai',
resetAndReloadFunction: resetAndReload
});
}
// Refresh single checkpoint metadata
export async function refreshSingleCheckpointMetadata(filePath) {
return refreshSingleModelMetadata(filePath, 'checkpoint');
}
/**
* Save model metadata to the server
* @param {string} filePath - Path to the model file
* @param {Object} data - Metadata to save
* @returns {Promise} - Promise that resolves with the server response
*/
export async function saveModelMetadata(filePath, data) {
const response = await fetch('/api/checkpoints/save-metadata', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath,
...data
})
});
if (!response.ok) {
throw new Error('Failed to save metadata');
}
return response.json();
}

View File

@@ -9,6 +9,31 @@ import {
refreshSingleModelMetadata
} from './baseModelApi.js';
/**
* Save model metadata to the server
* @param {string} filePath - File path
* @param {Object} data - Data to save
* @returns {Promise} Promise of the save operation
*/
export async function saveModelMetadata(filePath, data) {
const response = await fetch('/api/loras/save-metadata', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath,
...data
})
});
if (!response.ok) {
throw new Error('Failed to save metadata');
}
return response.json();
}
export async function loadMoreLoras(resetPage = false, updateFolders = false) {
return loadMoreModels({
resetPage,

View File

@@ -4,6 +4,7 @@ import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js';
import { createPageControls } from './components/controls/index.js';
import { loadMoreCheckpoints } from './api/checkpointApi.js';
import { CheckpointDownloadManager } from './managers/CheckpointDownloadManager.js';
import { CheckpointContextMenu } from './components/ContextMenu/index.js';
// Initialize the Checkpoints page
class CheckpointsPageManager {
@@ -34,6 +35,9 @@ class CheckpointsPageManager {
this.pageControls.restoreFolderFilter();
this.pageControls.initFolderTagsVisibility();
// Initialize context menu
new CheckpointContextMenu();
// Initialize infinite scroll
initializeInfiniteScroll('checkpoints');

View File

@@ -1,8 +1,8 @@
import { showToast } from '../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
import { state } from '../state/index.js';
import { showCheckpointModal } from './checkpointModal/index.js';
import { NSFW_LEVELS } from '../utils/constants.js';
import { replaceCheckpointPreview as apiReplaceCheckpointPreview } from '../api/checkpointApi.js';
import { replaceCheckpointPreview as apiReplaceCheckpointPreview, saveModelMetadata } from '../api/checkpointApi.js';
export function createCheckpointCard(checkpoint) {
const card = document.createElement('div');
@@ -17,6 +17,7 @@ export function createCheckpointCard(checkpoint) {
card.dataset.from_civitai = checkpoint.from_civitai;
card.dataset.notes = checkpoint.notes || '';
card.dataset.base_model = checkpoint.base_model || 'Unknown';
card.dataset.favorite = checkpoint.favorite ? 'true' : 'false';
// Store metadata if available
if (checkpoint.civitai) {
@@ -44,7 +45,10 @@ export function createCheckpointCard(checkpoint) {
// Determine preview URL
const previewUrl = checkpoint.preview_url || '/loras_static/images/no-preview.png';
const version = state.previewVersions ? state.previewVersions.get(checkpoint.file_path) : null;
// Get the page-specific previewVersions map
const previewVersions = state.pages.checkpoints.previewVersions || new Map();
const version = previewVersions.get(checkpoint.file_path);
const versionedPreviewUrl = version ? `${previewUrl}?t=${version}` : previewUrl;
// Determine NSFW warning text based on level
@@ -62,6 +66,9 @@ export function createCheckpointCard(checkpoint) {
const isVideo = previewUrl.endsWith('.mp4');
const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop';
// Get favorite status from checkpoint data
const isFavorite = checkpoint.favorite === true;
card.innerHTML = `
<div class="card-preview ${shouldBlur ? 'blurred' : ''}">
${isVideo ?
@@ -79,6 +86,9 @@ export function createCheckpointCard(checkpoint) {
${checkpoint.base_model}
</span>
<div class="card-actions">
<i class="${isFavorite ? 'fas fa-star favorite-active' : 'far fa-star'}"
title="${isFavorite ? 'Remove from favorites' : 'Add to favorites'}">
</i>
<i class="fas fa-globe"
title="${checkpoint.from_civitai ? 'View on Civitai' : 'Not available from Civitai'}"
${!checkpoint.from_civitai ? 'style="opacity: 0.5; cursor: not-allowed"' : ''}>
@@ -195,27 +205,46 @@ export function createCheckpointCard(checkpoint) {
});
}
// Favorite button click event
card.querySelector('.fa-star')?.addEventListener('click', async e => {
e.stopPropagation();
const starIcon = e.currentTarget;
const isFavorite = starIcon.classList.contains('fas');
const newFavoriteState = !isFavorite;
try {
// Save the new favorite state to the server
await saveModelMetadata(card.dataset.filepath, {
favorite: newFavoriteState
});
// Update the UI
if (newFavoriteState) {
starIcon.classList.remove('far');
starIcon.classList.add('fas', 'favorite-active');
starIcon.title = 'Remove from favorites';
card.dataset.favorite = 'true';
showToast('Added to favorites', 'success');
} else {
starIcon.classList.remove('fas', 'favorite-active');
starIcon.classList.add('far');
starIcon.title = 'Add to favorites';
card.dataset.favorite = 'false';
showToast('Removed from favorites', 'success');
}
} catch (error) {
console.error('Failed to update favorite status:', error);
showToast('Failed to update favorite status', 'error');
}
});
// Copy button click event
card.querySelector('.fa-copy')?.addEventListener('click', async e => {
e.stopPropagation();
const checkpointName = card.dataset.file_name;
try {
// Modern clipboard API
if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(checkpointName);
} else {
// Fallback for older browsers
const textarea = document.createElement('textarea');
textarea.value = checkpointName;
textarea.style.position = 'absolute';
textarea.style.left = '-99999px';
document.body.appendChild(textarea);
textarea.select();
document.execCommand('copy');
document.body.removeChild(textarea);
}
showToast('Checkpoint name copied', 'success');
await copyToClipboard(checkpointName, 'Checkpoint name copied');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');

View File

@@ -366,4 +366,7 @@ export class LoraContextMenu {
this.menu.style.display = 'none';
this.currentCard = null;
}
}
}
// For backward compatibility, re-export the LoraContextMenu class
// export { LoraContextMenu } from './ContextMenu/LoraContextMenu.js';

View File

@@ -0,0 +1,84 @@
export class BaseContextMenu {
constructor(menuId, cardSelector) {
this.menu = document.getElementById(menuId);
this.cardSelector = cardSelector;
this.currentCard = null;
if (!this.menu) {
console.error(`Context menu element with ID ${menuId} not found`);
return;
}
this.init();
}
init() {
// Hide menu on regular clicks
document.addEventListener('click', () => this.hideMenu());
// Show menu on right-click on cards
document.addEventListener('contextmenu', (e) => {
const card = e.target.closest(this.cardSelector);
if (!card) {
this.hideMenu();
return;
}
e.preventDefault();
this.showMenu(e.clientX, e.clientY, card);
});
// Handle menu item clicks
this.menu.addEventListener('click', (e) => {
const menuItem = e.target.closest('.context-menu-item');
if (!menuItem || !this.currentCard) return;
const action = menuItem.dataset.action;
if (!action) return;
this.handleMenuAction(action, menuItem);
this.hideMenu();
});
}
handleMenuAction(action, menuItem) {
// Override in subclass
console.warn('handleMenuAction not implemented');
}
showMenu(x, y, card) {
this.currentCard = card;
this.menu.style.display = 'block';
// Get menu dimensions
const menuRect = this.menu.getBoundingClientRect();
// Get viewport dimensions
const viewportWidth = document.documentElement.clientWidth;
const viewportHeight = document.documentElement.clientHeight;
// Calculate position
let finalX = x;
let finalY = y;
// Ensure menu doesn't go offscreen right
if (x + menuRect.width > viewportWidth) {
finalX = x - menuRect.width;
}
// Ensure menu doesn't go offscreen bottom
if (y + menuRect.height > viewportHeight) {
finalY = y - menuRect.height;
}
// Position menu
this.menu.style.left = `${finalX}px`;
this.menu.style.top = `${finalY}px`;
}
hideMenu() {
if (this.menu) {
this.menu.style.display = 'none';
}
this.currentCard = null;
}
}
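handleMenuAction is intentionally a no-op hook here; a minimal subclass sketch (menu id, card selector, and action names are hypothetical) showing the override pattern the concrete menus below follow:
// Hypothetical subclass for illustration only; the real menus are defined below in this diff.
import { BaseContextMenu } from './BaseContextMenu.js';

export class ExampleContextMenu extends BaseContextMenu {
    constructor() {
        // The element with this id must exist in the page template.
        super('exampleContextMenu', '.example-card');
    }

    handleMenuAction(action, menuItem) {
        // this.currentCard was set by showMenu() before the click reached the menu item.
        switch (action) {
            case 'details':
                this.currentCard.click();
                break;
            default:
                console.warn(`Unhandled context menu action: ${action}`);
        }
    }
}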

View File

@@ -0,0 +1,315 @@
import { BaseContextMenu } from './BaseContextMenu.js';
import { refreshSingleCheckpointMetadata, saveModelMetadata } from '../../api/checkpointApi.js';
import { showToast, getNSFWLevelName } from '../../utils/uiHelpers.js';
import { NSFW_LEVELS } from '../../utils/constants.js';
import { getStorageItem } from '../../utils/storageHelpers.js';
export class CheckpointContextMenu extends BaseContextMenu {
constructor() {
super('checkpointContextMenu', '.lora-card');
this.nsfwSelector = document.getElementById('nsfwLevelSelector');
// Initialize NSFW Level Selector events
if (this.nsfwSelector) {
this.initNSFWSelector();
}
}
handleMenuAction(action) {
switch(action) {
case 'details':
// Show checkpoint details
this.currentCard.click();
break;
case 'preview':
// Replace checkpoint preview
if (this.currentCard.querySelector('.fa-image')) {
this.currentCard.querySelector('.fa-image').click();
}
break;
case 'civitai':
// Open civitai page
if (this.currentCard.dataset.from_civitai === 'true') {
if (this.currentCard.querySelector('.fa-globe')) {
this.currentCard.querySelector('.fa-globe').click();
}
} else {
showToast('No CivitAI information available', 'info');
}
break;
case 'delete':
// Delete checkpoint
if (this.currentCard.querySelector('.fa-trash')) {
this.currentCard.querySelector('.fa-trash').click();
}
break;
case 'copyname':
// Copy checkpoint name
if (this.currentCard.querySelector('.fa-copy')) {
this.currentCard.querySelector('.fa-copy').click();
}
break;
case 'refresh-metadata':
// Refresh metadata from CivitAI
refreshSingleCheckpointMetadata(this.currentCard.dataset.filepath);
break;
case 'set-nsfw':
// Set NSFW level
this.showNSFWLevelSelector(null, null, this.currentCard);
break;
case 'move':
// Move to folder (placeholder)
showToast('Move to folder feature coming soon', 'info');
break;
}
}
// NSFW Selector methods
initNSFWSelector() {
// Close button
const closeBtn = this.nsfwSelector.querySelector('.close-nsfw-selector');
closeBtn.addEventListener('click', () => {
this.nsfwSelector.style.display = 'none';
});
// Level buttons
const levelButtons = this.nsfwSelector.querySelectorAll('.nsfw-level-btn');
levelButtons.forEach(btn => {
btn.addEventListener('click', async () => {
const level = parseInt(btn.dataset.level);
const filePath = this.nsfwSelector.dataset.cardPath;
if (!filePath) return;
try {
await saveModelMetadata(filePath, { preview_nsfw_level: level });
// Update card data
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (card) {
let metaData = {};
try {
metaData = JSON.parse(card.dataset.meta || '{}');
} catch (err) {
console.error('Error parsing metadata:', err);
}
metaData.preview_nsfw_level = level;
card.dataset.meta = JSON.stringify(metaData);
card.dataset.nsfwLevel = level.toString();
// Apply blur effect immediately
this.updateCardBlurEffect(card, level);
}
showToast(`Content rating set to ${getNSFWLevelName(level)}`, 'success');
this.nsfwSelector.style.display = 'none';
} catch (error) {
showToast(`Failed to set content rating: ${error.message}`, 'error');
}
});
});
// Close when clicking outside
document.addEventListener('click', (e) => {
if (this.nsfwSelector.style.display === 'block' &&
!this.nsfwSelector.contains(e.target) &&
!e.target.closest('.context-menu-item[data-action="set-nsfw"]')) {
this.nsfwSelector.style.display = 'none';
}
});
}
updateCardBlurEffect(card, level) {
// Get user settings for blur threshold
const blurThreshold = parseInt(getStorageItem('nsfwBlurLevel') || '4');
// Get card preview container
const previewContainer = card.querySelector('.card-preview');
if (!previewContainer) return;
// Get preview media element
const previewMedia = previewContainer.querySelector('img') || previewContainer.querySelector('video');
if (!previewMedia) return;
// Check if blur should be applied
if (level >= blurThreshold) {
// Add blur class to the preview container
previewContainer.classList.add('blurred');
// Get or create the NSFW overlay
let nsfwOverlay = previewContainer.querySelector('.nsfw-overlay');
if (!nsfwOverlay) {
// Create new overlay
nsfwOverlay = document.createElement('div');
nsfwOverlay.className = 'nsfw-overlay';
// Create and configure the warning content
const warningContent = document.createElement('div');
warningContent.className = 'nsfw-warning';
// Determine NSFW warning text based on level
let nsfwText = "Mature Content";
if (level >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (level >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (level >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
// Add warning text and show button
warningContent.innerHTML = `
<p>${nsfwText}</p>
<button class="show-content-btn">Show</button>
`;
// Add click event to the show button
const showBtn = warningContent.querySelector('.show-content-btn');
showBtn.addEventListener('click', (e) => {
e.stopPropagation();
previewContainer.classList.remove('blurred');
nsfwOverlay.style.display = 'none';
// Update toggle button icon if it exists
const toggleBtn = card.querySelector('.toggle-blur-btn');
if (toggleBtn) {
toggleBtn.querySelector('i').className = 'fas fa-eye-slash';
}
});
nsfwOverlay.appendChild(warningContent);
previewContainer.appendChild(nsfwOverlay);
} else {
// Update existing overlay
const warningText = nsfwOverlay.querySelector('p');
if (warningText) {
let nsfwText = "Mature Content";
if (level >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (level >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (level >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
warningText.textContent = nsfwText;
}
nsfwOverlay.style.display = 'flex';
}
// Get or create the toggle button in the header
const cardHeader = previewContainer.querySelector('.card-header');
if (cardHeader) {
let toggleBtn = cardHeader.querySelector('.toggle-blur-btn');
if (!toggleBtn) {
toggleBtn = document.createElement('button');
toggleBtn.className = 'toggle-blur-btn';
toggleBtn.title = 'Toggle blur';
toggleBtn.innerHTML = '<i class="fas fa-eye"></i>';
// Add click event to toggle button
toggleBtn.addEventListener('click', (e) => {
e.stopPropagation();
const isBlurred = previewContainer.classList.toggle('blurred');
const icon = toggleBtn.querySelector('i');
// Update icon and overlay visibility
if (isBlurred) {
icon.className = 'fas fa-eye';
nsfwOverlay.style.display = 'flex';
} else {
icon.className = 'fas fa-eye-slash';
nsfwOverlay.style.display = 'none';
}
});
// Add to the beginning of header
cardHeader.insertBefore(toggleBtn, cardHeader.firstChild);
// Update base model label class
const baseModelLabel = cardHeader.querySelector('.base-model-label');
if (baseModelLabel && !baseModelLabel.classList.contains('with-toggle')) {
baseModelLabel.classList.add('with-toggle');
}
} else {
// Update existing toggle button
toggleBtn.querySelector('i').className = 'fas fa-eye';
}
}
} else {
// Remove blur
previewContainer.classList.remove('blurred');
// Hide overlay if it exists
const overlay = previewContainer.querySelector('.nsfw-overlay');
if (overlay) overlay.style.display = 'none';
// Remove toggle button when content is set to PG or PG13
const cardHeader = previewContainer.querySelector('.card-header');
if (cardHeader) {
const toggleBtn = cardHeader.querySelector('.toggle-blur-btn');
if (toggleBtn) {
// Remove the toggle button completely
toggleBtn.remove();
// Update base model label class if it exists
const baseModelLabel = cardHeader.querySelector('.base-model-label');
if (baseModelLabel && baseModelLabel.classList.contains('with-toggle')) {
baseModelLabel.classList.remove('with-toggle');
}
}
}
}
}
showNSFWLevelSelector(x, y, card) {
const selector = document.getElementById('nsfwLevelSelector');
const currentLevelEl = document.getElementById('currentNSFWLevel');
// Get current NSFW level
let currentLevel = 0;
try {
const metaData = JSON.parse(card.dataset.meta || '{}');
currentLevel = metaData.preview_nsfw_level || 0;
// Update if we have no recorded level but have a dataset attribute
if (!currentLevel && card.dataset.nsfwLevel) {
currentLevel = parseInt(card.dataset.nsfwLevel) || 0;
}
} catch (err) {
console.error('Error parsing metadata:', err);
}
currentLevelEl.textContent = getNSFWLevelName(currentLevel);
// Position the selector
if (x && y) {
const viewportWidth = document.documentElement.clientWidth;
const viewportHeight = document.documentElement.clientHeight;
const selectorRect = selector.getBoundingClientRect();
// Center the selector if no coordinates provided
let finalX = (viewportWidth - selectorRect.width) / 2;
let finalY = (viewportHeight - selectorRect.height) / 2;
selector.style.left = `${finalX}px`;
selector.style.top = `${finalY}px`;
}
// Highlight current level button
document.querySelectorAll('.nsfw-level-btn').forEach(btn => {
if (parseInt(btn.dataset.level) === currentLevel) {
btn.classList.add('active');
} else {
btn.classList.remove('active');
}
});
// Store reference to current card
selector.dataset.cardPath = card.dataset.filepath;
// Show selector
selector.style.display = 'block';
}
}
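The class is picked up by the checkpoints page entry script; a small sketch of the expected wiring, mirroring how RecipeManager constructs RecipeContextMenu later in this diff (the exact entry module is an assumption):
// Assumed wiring on the checkpoints page; the import path matches the
// ContextMenu/index.js re-exports shown later in this diff.
import { CheckpointContextMenu } from './components/ContextMenu/index.js';

document.addEventListener('DOMContentLoaded', () => {
    new CheckpointContextMenu();
});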

View File

@@ -0,0 +1,309 @@
import { BaseContextMenu } from './BaseContextMenu.js';
import { refreshSingleLoraMetadata, saveModelMetadata } from '../../api/loraApi.js';
import { showToast, getNSFWLevelName } from '../../utils/uiHelpers.js';
import { NSFW_LEVELS } from '../../utils/constants.js';
import { getStorageItem } from '../../utils/storageHelpers.js';
export class LoraContextMenu extends BaseContextMenu {
constructor() {
super('loraContextMenu', '.lora-card');
this.nsfwSelector = document.getElementById('nsfwLevelSelector');
// Initialize NSFW Level Selector events
if (this.nsfwSelector) {
this.initNSFWSelector();
}
}
handleMenuAction(action, menuItem) {
switch(action) {
case 'detail':
// Trigger the main card click which shows the modal
this.currentCard.click();
break;
case 'civitai':
// Only trigger if the card is from civitai
if (this.currentCard.dataset.from_civitai === 'true') {
if (this.currentCard.dataset.meta === '{}') {
showToast('Please fetch metadata from CivitAI first', 'info');
} else {
this.currentCard.querySelector('.fa-globe')?.click();
}
} else {
showToast('No CivitAI information available', 'info');
}
break;
case 'copyname':
this.currentCard.querySelector('.fa-copy')?.click();
break;
case 'preview':
this.currentCard.querySelector('.fa-image')?.click();
break;
case 'delete':
this.currentCard.querySelector('.fa-trash')?.click();
break;
case 'move':
moveManager.showMoveModal(this.currentCard.dataset.filepath);
break;
case 'refresh-metadata':
refreshSingleLoraMetadata(this.currentCard.dataset.filepath);
break;
case 'set-nsfw':
this.showNSFWLevelSelector(null, null, this.currentCard);
break;
}
}
// NSFW Selector methods from the original context menu
initNSFWSelector() {
// Close button
const closeBtn = this.nsfwSelector.querySelector('.close-nsfw-selector');
closeBtn.addEventListener('click', () => {
this.nsfwSelector.style.display = 'none';
});
// Level buttons
const levelButtons = this.nsfwSelector.querySelectorAll('.nsfw-level-btn');
levelButtons.forEach(btn => {
btn.addEventListener('click', async () => {
const level = parseInt(btn.dataset.level);
const filePath = this.nsfwSelector.dataset.cardPath;
if (!filePath) return;
try {
await this.saveModelMetadata(filePath, { preview_nsfw_level: level });
// Update card data
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (card) {
let metaData = {};
try {
metaData = JSON.parse(card.dataset.meta || '{}');
} catch (err) {
console.error('Error parsing metadata:', err);
}
metaData.preview_nsfw_level = level;
card.dataset.meta = JSON.stringify(metaData);
card.dataset.nsfwLevel = level.toString();
// Apply blur effect immediately
this.updateCardBlurEffect(card, level);
}
showToast(`Content rating set to ${getNSFWLevelName(level)}`, 'success');
this.nsfwSelector.style.display = 'none';
} catch (error) {
showToast(`Failed to set content rating: ${error.message}`, 'error');
}
});
});
// Close when clicking outside
document.addEventListener('click', (e) => {
if (this.nsfwSelector.style.display === 'block' &&
!this.nsfwSelector.contains(e.target) &&
!e.target.closest('.context-menu-item[data-action="set-nsfw"]')) {
this.nsfwSelector.style.display = 'none';
}
});
}
async saveModelMetadata(filePath, data) {
return saveModelMetadata(filePath, data);
}
updateCardBlurEffect(card, level) {
// Get user settings for blur threshold
const blurThreshold = parseInt(getStorageItem('nsfwBlurLevel') || '4');
// Get card preview container
const previewContainer = card.querySelector('.card-preview');
if (!previewContainer) return;
// Get preview media element
const previewMedia = previewContainer.querySelector('img') || previewContainer.querySelector('video');
if (!previewMedia) return;
// Check if blur should be applied
if (level >= blurThreshold) {
// Add blur class to the preview container
previewContainer.classList.add('blurred');
// Get or create the NSFW overlay
let nsfwOverlay = previewContainer.querySelector('.nsfw-overlay');
if (!nsfwOverlay) {
// Create new overlay
nsfwOverlay = document.createElement('div');
nsfwOverlay.className = 'nsfw-overlay';
// Create and configure the warning content
const warningContent = document.createElement('div');
warningContent.className = 'nsfw-warning';
// Determine NSFW warning text based on level
let nsfwText = "Mature Content";
if (level >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (level >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (level >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
// Add warning text and show button
warningContent.innerHTML = `
<p>${nsfwText}</p>
<button class="show-content-btn">Show</button>
`;
// Add click event to the show button
const showBtn = warningContent.querySelector('.show-content-btn');
showBtn.addEventListener('click', (e) => {
e.stopPropagation();
previewContainer.classList.remove('blurred');
nsfwOverlay.style.display = 'none';
// Update toggle button icon if it exists
const toggleBtn = card.querySelector('.toggle-blur-btn');
if (toggleBtn) {
toggleBtn.querySelector('i').className = 'fas fa-eye-slash';
}
});
nsfwOverlay.appendChild(warningContent);
previewContainer.appendChild(nsfwOverlay);
} else {
// Update existing overlay
const warningText = nsfwOverlay.querySelector('p');
if (warningText) {
let nsfwText = "Mature Content";
if (level >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (level >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (level >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
warningText.textContent = nsfwText;
}
nsfwOverlay.style.display = 'flex';
}
// Get or create the toggle button in the header
const cardHeader = previewContainer.querySelector('.card-header');
if (cardHeader) {
let toggleBtn = cardHeader.querySelector('.toggle-blur-btn');
if (!toggleBtn) {
toggleBtn = document.createElement('button');
toggleBtn.className = 'toggle-blur-btn';
toggleBtn.title = 'Toggle blur';
toggleBtn.innerHTML = '<i class="fas fa-eye"></i>';
// Add click event to toggle button
toggleBtn.addEventListener('click', (e) => {
e.stopPropagation();
const isBlurred = previewContainer.classList.toggle('blurred');
const icon = toggleBtn.querySelector('i');
// Update icon and overlay visibility
if (isBlurred) {
icon.className = 'fas fa-eye';
nsfwOverlay.style.display = 'flex';
} else {
icon.className = 'fas fa-eye-slash';
nsfwOverlay.style.display = 'none';
}
});
// Add to the beginning of header
cardHeader.insertBefore(toggleBtn, cardHeader.firstChild);
// Update base model label class
const baseModelLabel = cardHeader.querySelector('.base-model-label');
if (baseModelLabel && !baseModelLabel.classList.contains('with-toggle')) {
baseModelLabel.classList.add('with-toggle');
}
} else {
// Update existing toggle button
toggleBtn.querySelector('i').className = 'fas fa-eye';
}
}
} else {
// Remove blur
previewContainer.classList.remove('blurred');
// Hide overlay if it exists
const overlay = previewContainer.querySelector('.nsfw-overlay');
if (overlay) overlay.style.display = 'none';
// Remove toggle button when content is set to PG or PG13
const cardHeader = previewContainer.querySelector('.card-header');
if (cardHeader) {
const toggleBtn = cardHeader.querySelector('.toggle-blur-btn');
if (toggleBtn) {
// Remove the toggle button completely
toggleBtn.remove();
// Update base model label class if it exists
const baseModelLabel = cardHeader.querySelector('.base-model-label');
if (baseModelLabel && baseModelLabel.classList.contains('with-toggle')) {
baseModelLabel.classList.remove('with-toggle');
}
}
}
}
}
showNSFWLevelSelector(x, y, card) {
const selector = document.getElementById('nsfwLevelSelector');
const currentLevelEl = document.getElementById('currentNSFWLevel');
// Get current NSFW level
let currentLevel = 0;
try {
const metaData = JSON.parse(card.dataset.meta || '{}');
currentLevel = metaData.preview_nsfw_level || 0;
// Update if we have no recorded level but have a dataset attribute
if (!currentLevel && card.dataset.nsfwLevel) {
currentLevel = parseInt(card.dataset.nsfwLevel) || 0;
}
} catch (err) {
console.error('Error parsing metadata:', err);
}
currentLevelEl.textContent = getNSFWLevelName(currentLevel);
// Position the selector
if (x && y) {
const viewportWidth = document.documentElement.clientWidth;
const viewportHeight = document.documentElement.clientHeight;
const selectorRect = selector.getBoundingClientRect();
// Center the selector if no coordinates provided
let finalX = (viewportWidth - selectorRect.width) / 2;
let finalY = (viewportHeight - selectorRect.height) / 2;
selector.style.left = `${finalX}px`;
selector.style.top = `${finalY}px`;
}
// Highlight current level button
document.querySelectorAll('.nsfw-level-btn').forEach(btn => {
if (parseInt(btn.dataset.level) === currentLevel) {
btn.classList.add('active');
} else {
btn.classList.remove('active');
}
});
// Store reference to current card
selector.dataset.cardPath = card.dataset.filepath;
// Show selector
selector.style.display = 'block';
}
}

View File

@@ -0,0 +1,205 @@
import { BaseContextMenu } from './BaseContextMenu.js';
import { showToast } from '../../utils/uiHelpers.js';
import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
import { state } from '../../state/index.js';
export class RecipeContextMenu extends BaseContextMenu {
constructor() {
super('recipeContextMenu', '.lora-card');
}
showMenu(x, y, card) {
// Call the parent method first to handle basic positioning
super.showMenu(x, y, card);
// Get recipe data to check for missing LoRAs
const recipeId = card.dataset.id;
const missingLorasItem = this.menu.querySelector('.download-missing-item');
if (recipeId && missingLorasItem) {
// Check if this card has missing LoRAs
const loraCountElement = card.querySelector('.lora-count');
const hasMissingLoras = loraCountElement && loraCountElement.classList.contains('missing');
// Show/hide the download missing LoRAs option based on missing status
if (hasMissingLoras) {
missingLorasItem.style.display = 'flex';
} else {
missingLorasItem.style.display = 'none';
}
}
}
handleMenuAction(action) {
const recipeId = this.currentCard.dataset.id;
switch(action) {
case 'details':
// Show recipe details
this.currentCard.click();
break;
case 'copy':
// Copy recipe to clipboard
this.currentCard.querySelector('.fa-copy')?.click();
break;
case 'share':
// Share recipe
this.currentCard.querySelector('.fa-share-alt')?.click();
break;
case 'delete':
// Delete recipe
this.currentCard.querySelector('.fa-trash')?.click();
break;
case 'viewloras':
// View all LoRAs in the recipe
this.viewRecipeLoRAs(recipeId);
break;
case 'download-missing':
// Download missing LoRAs
this.downloadMissingLoRAs(recipeId);
break;
}
}
// View all LoRAs in the recipe
viewRecipeLoRAs(recipeId) {
if (!recipeId) {
showToast('Cannot view LoRAs: Missing recipe ID', 'error');
return;
}
// First get the recipe details to access its LoRAs
fetch(`/api/recipe/${recipeId}`)
.then(response => response.json())
.then(recipe => {
// Clear any previous filters first
removeSessionItem('recipe_to_lora_filterLoraHash');
removeSessionItem('recipe_to_lora_filterLoraHashes');
removeSessionItem('filterRecipeName');
removeSessionItem('viewLoraDetail');
// Collect all hashes from the recipe's LoRAs
const loraHashes = recipe.loras
.filter(lora => lora.hash)
.map(lora => lora.hash.toLowerCase());
if (loraHashes.length > 0) {
// Store the LoRA hashes and recipe name in session storage
setSessionItem('recipe_to_lora_filterLoraHashes', JSON.stringify(loraHashes));
setSessionItem('filterRecipeName', recipe.title);
// Navigate to the LoRAs page
window.location.href = '/loras';
} else {
showToast('No LoRAs found in this recipe', 'info');
}
})
.catch(error => {
console.error('Error loading recipe LoRAs:', error);
showToast('Error loading recipe LoRAs: ' + error.message, 'error');
});
}
// Download missing LoRAs
async downloadMissingLoRAs(recipeId) {
if (!recipeId) {
showToast('Cannot download LoRAs: Missing recipe ID', 'error');
return;
}
try {
// First get the recipe details
const response = await fetch(`/api/recipe/${recipeId}`);
const recipe = await response.json();
// Get missing LoRAs
const missingLoras = recipe.loras.filter(lora => !lora.inLibrary && !lora.isDeleted);
if (missingLoras.length === 0) {
showToast('No missing LoRAs to download', 'info');
return;
}
// Show loading toast
state.loadingManager.showSimpleLoading('Getting version info for missing LoRAs...');
// Get version info for each missing LoRA
const missingLorasWithVersionInfoPromises = missingLoras.map(async lora => {
let endpoint;
// Determine which endpoint to use based on available data
if (lora.modelVersionId) {
endpoint = `/api/civitai/model/version/${lora.modelVersionId}`;
} else if (lora.hash) {
endpoint = `/api/civitai/model/hash/${lora.hash}`;
} else {
console.error("Missing both hash and modelVersionId for lora:", lora);
return null;
}
const versionResponse = await fetch(endpoint);
const versionInfo = await versionResponse.json();
// Return original lora data combined with version info
return {
...lora,
civitaiInfo: versionInfo
};
});
// Wait for all API calls to complete
const lorasWithVersionInfo = await Promise.all(missingLorasWithVersionInfoPromises);
// Filter out null values (failed requests)
const validLoras = lorasWithVersionInfo.filter(lora => lora !== null);
if (validLoras.length === 0) {
showToast('Failed to get information for missing LoRAs', 'error');
return;
}
// Prepare data for import manager using the retrieved information
const recipeData = {
loras: validLoras.map(lora => {
const civitaiInfo = lora.civitaiInfo;
const modelFile = civitaiInfo.files ?
civitaiInfo.files.find(file => file.type === 'Model') : null;
return {
// Basic lora info
name: civitaiInfo.model?.name || lora.name,
version: civitaiInfo.name || '',
strength: lora.strength || 1.0,
// Model identifiers
hash: modelFile?.hashes?.SHA256?.toLowerCase() || lora.hash,
modelVersionId: civitaiInfo.id || lora.modelVersionId,
// Metadata
thumbnailUrl: civitaiInfo.images?.[0]?.url || '',
baseModel: civitaiInfo.baseModel || '',
downloadUrl: civitaiInfo.downloadUrl || '',
size: modelFile ? (modelFile.sizeKB * 1024) : 0,
file_name: modelFile ? modelFile.name.split('.')[0] : '',
// Status flags
existsLocally: false,
isDeleted: civitaiInfo.error === "Model not found",
isEarlyAccess: !!civitaiInfo.earlyAccessEndsAt,
earlyAccessEndsAt: civitaiInfo.earlyAccessEndsAt || ''
};
})
};
// Call ImportManager's download missing LoRAs method
window.importManager.downloadMissingLoras(recipeData, recipeId);
} catch (error) {
console.error('Error downloading missing LoRAs:', error);
showToast('Error preparing LoRAs for download: ' + error.message, 'error');
} finally {
if (state.loadingManager) {
state.loadingManager.hide();
}
}
}
}
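viewRecipeLoRAs and downloadMissingLoRAs both assume /api/recipe/{id} returns a recipe object with a loras array; a hedged sketch listing only the fields this file actually reads (values are illustrative, other fields are omitted):
// Only the fields referenced above are shown; everything else in the response is ignored here.
const exampleRecipeResponse = {
    title: 'Example recipe',            // stored as filterRecipeName for the LoRAs page
    loras: [
        {
            name: 'example-lora',       // fallback display name when Civitai info is missing
            hash: 'abcdef123456',       // lowercased before filtering and hash lookup
            modelVersionId: 123456,     // preferred Civitai lookup key when present
            strength: 1.0,
            inLibrary: false,           // false + not deleted => candidate for download
            isDeleted: false            // true => skipped when collecting missing LoRAs
        }
    ]
};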

View File

@@ -0,0 +1,3 @@
export { LoraContextMenu } from './LoraContextMenu.js';
export { RecipeContextMenu } from './RecipeContextMenu.js';
export { CheckpointContextMenu } from './CheckpointContextMenu.js';

View File

@@ -1,9 +1,9 @@
import { showToast, openCivitai } from '../utils/uiHelpers.js';
import { showToast, openCivitai, copyToClipboard } from '../utils/uiHelpers.js';
import { state } from '../state/index.js';
import { showLoraModal } from './loraModal/index.js';
import { bulkManager } from '../managers/BulkManager.js';
import { NSFW_LEVELS } from '../utils/constants.js';
import { replacePreview, deleteModel } from '../api/loraApi.js'
import { replacePreview, deleteModel, saveModelMetadata } from '../api/loraApi.js'
export function createLoraCard(lora) {
const card = document.createElement('div');
@@ -20,6 +20,7 @@ export function createLoraCard(lora) {
card.dataset.usage_tips = lora.usage_tips;
card.dataset.notes = lora.notes;
card.dataset.meta = JSON.stringify(lora.civitai || {});
card.dataset.favorite = lora.favorite ? 'true' : 'false';
// Store tags and model description
if (lora.tags && Array.isArray(lora.tags)) {
@@ -44,7 +45,9 @@ export function createLoraCard(lora) {
card.classList.add('selected');
}
const version = state.previewVersions.get(lora.file_path);
// Get the page-specific previewVersions map
const previewVersions = state.pages.loras.previewVersions || new Map();
const version = previewVersions.get(lora.file_path);
const previewUrl = lora.preview_url || '/loras_static/images/no-preview.png';
const versionedPreviewUrl = version ? `${previewUrl}?t=${version}` : previewUrl;
@@ -63,6 +66,9 @@ export function createLoraCard(lora) {
const isVideo = previewUrl.endsWith('.mp4');
const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop';
// Get favorite status from the lora data
const isFavorite = lora.favorite === true;
card.innerHTML = `
<div class="card-preview ${shouldBlur ? 'blurred' : ''}">
${isVideo ?
@@ -80,6 +86,9 @@ export function createLoraCard(lora) {
${lora.base_model}
</span>
<div class="card-actions">
<i class="${isFavorite ? 'fas fa-star favorite-active' : 'far fa-star'}"
title="${isFavorite ? 'Remove from favorites' : 'Add to favorites'}">
</i>
<i class="fas fa-globe"
title="${lora.from_civitai ? 'View on Civitai' : 'Not available from Civitai'}"
${!lora.from_civitai ? 'style="opacity: 0.5; cursor: not-allowed"' : ''}>
@@ -133,6 +142,7 @@ export function createLoraCard(lora) {
base_model: card.dataset.base_model,
usage_tips: card.dataset.usage_tips,
notes: card.dataset.notes,
favorite: card.dataset.favorite === 'true',
// Parse civitai metadata from the card's dataset
civitai: (() => {
try {
@@ -196,6 +206,39 @@ export function createLoraCard(lora) {
});
}
// Favorite button click event
card.querySelector('.fa-star')?.addEventListener('click', async e => {
e.stopPropagation();
const starIcon = e.currentTarget;
const isFavorite = starIcon.classList.contains('fas');
const newFavoriteState = !isFavorite;
try {
// Save the new favorite state to the server
await saveModelMetadata(card.dataset.filepath, {
favorite: newFavoriteState
});
// Update the UI
if (newFavoriteState) {
starIcon.classList.remove('far');
starIcon.classList.add('fas', 'favorite-active');
starIcon.title = 'Remove from favorites';
card.dataset.favorite = 'true';
showToast('Added to favorites', 'success');
} else {
starIcon.classList.remove('fas', 'favorite-active');
starIcon.classList.add('far');
starIcon.title = 'Add to favorites';
card.dataset.favorite = 'false';
showToast('Removed from favorites', 'success');
}
} catch (error) {
console.error('Failed to update favorite status:', error);
showToast('Failed to update favorite status', 'error');
}
});
// Copy button click event
card.querySelector('.fa-copy')?.addEventListener('click', async e => {
e.stopPropagation();
@@ -203,26 +246,7 @@ export function createLoraCard(lora) {
const strength = usageTips.strength || 1;
const loraSyntax = `<lora:${card.dataset.file_name}:${strength}>`;
try {
// Modern clipboard API
if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(loraSyntax);
} else {
// Fallback for older browsers
const textarea = document.createElement('textarea');
textarea.value = loraSyntax;
textarea.style.position = 'absolute';
textarea.style.left = '-99999px';
document.body.appendChild(textarea);
textarea.select();
document.execCommand('copy');
document.body.removeChild(textarea);
}
showToast('LoRA syntax copied', 'success');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');
}
await copyToClipboard(loraSyntax, 'LoRA syntax copied');
});
// Civitai button click event

View File

@@ -1,5 +1,5 @@
// Recipe Card Component
import { showToast } from '../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
import { modalManager } from '../managers/ModalManager.js';
class RecipeCard {
@@ -109,14 +109,11 @@ class RecipeCard {
.then(response => response.json())
.then(data => {
if (data.success && data.syntax) {
return navigator.clipboard.writeText(data.syntax);
return copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
} else {
throw new Error(data.error || 'No syntax returned');
}
})
.then(() => {
showToast('Recipe syntax copied to clipboard', 'success');
})
.catch(err => {
console.error('Failed to copy: ', err);
showToast('Failed to copy recipe syntax', 'error');
@@ -279,4 +276,4 @@ class RecipeCard {
}
}
export { RecipeCard };
export { RecipeCard };

View File

@@ -1,5 +1,5 @@
// Recipe Modal Component
import { showToast } from '../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
import { state } from '../state/index.js';
import { setSessionItem, removeSessionItem } from '../utils/storageHelpers.js';
@@ -747,9 +747,8 @@ class RecipeModal {
const data = await response.json();
if (data.success && data.syntax) {
// Copy to clipboard
await navigator.clipboard.writeText(data.syntax);
showToast('Recipe syntax copied to clipboard', 'success');
// Use the centralized copyToClipboard utility function
await copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
} else {
throw new Error(data.error || 'No syntax returned from server');
}
@@ -761,12 +760,7 @@ class RecipeModal {
// Helper method to copy text to clipboard
copyToClipboard(text, successMessage) {
navigator.clipboard.writeText(text).then(() => {
showToast(successMessage, 'success');
}).catch(err => {
console.error('Failed to copy text: ', err);
showToast('Failed to copy text', 'error');
});
copyToClipboard(text, successMessage);
}
// Add new method to handle downloading missing LoRAs
@@ -790,9 +784,9 @@ class RecipeModal {
// Determine which endpoint to use based on available data
if (lora.modelVersionId) {
endpoint = `/api/civitai/model/${lora.modelVersionId}`;
endpoint = `/api/civitai/model/version/${lora.modelVersionId}`;
} else if (lora.hash) {
endpoint = `/api/civitai/model/${lora.hash}`;
endpoint = `/api/civitai/model/hash/${lora.hash}`;
} else {
console.error("Missing both hash and modelVersionId for lora:", lora);
return null;

View File

@@ -5,31 +5,7 @@
import { showToast } from '../../utils/uiHelpers.js';
import { BASE_MODELS } from '../../utils/constants.js';
import { updateCheckpointCard } from '../../utils/cardUpdater.js';
/**
* Save model metadata to the server
* @param {string} filePath - Path to the model file
* @param {Object} data - Metadata to save
* @returns {Promise} - Promise that resolves with the server response
*/
export async function saveModelMetadata(filePath, data) {
const response = await fetch('/api/checkpoints/save-metadata', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath,
...data
})
});
if (!response.ok) {
throw new Error('Failed to save metadata');
}
return response.json();
}
import { saveModelMetadata } from '../../api/checkpointApi.js';
/**
* Set up model name editing functionality

View File

@@ -2,7 +2,7 @@
* ShowcaseView.js
* Handles showcase content (images, videos) display for checkpoint modal
*/
import { showToast } from '../../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
import { state } from '../../state/index.js';
import { NSFW_LEVELS } from '../../utils/constants.js';
@@ -307,8 +307,7 @@ function initMetadataPanelHandlers(container) {
if (!promptElement) return;
try {
await navigator.clipboard.writeText(promptElement.textContent);
showToast('Prompt copied to clipboard', 'success');
await copyToClipboard(promptElement.textContent, 'Prompt copied to clipboard');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');

View File

@@ -11,9 +11,9 @@ import { setupTabSwitching, loadModelDescription } from './ModelDescription.js';
import {
setupModelNameEditing,
setupBaseModelEditing,
setupFileNameEditing,
saveModelMetadata
setupFileNameEditing
} from './ModelMetadata.js';
import { saveModelMetadata } from '../../api/checkpointApi.js';
import { renderCompactTags, setupTagTooltip, formatFileSize } from './utils.js';
import { updateCheckpointCard } from '../../utils/cardUpdater.js';

View File

@@ -2,7 +2,6 @@
import { PageControls } from './PageControls.js';
import { loadMoreLoras, fetchCivitai, resetAndReload, refreshLoras } from '../../api/loraApi.js';
import { getSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
import { showToast } from '../../utils/uiHelpers.js';
/**
* LorasControls class - Extends PageControls for LoRA-specific functionality

View File

@@ -1,6 +1,6 @@
// PageControls.js - Manages controls for both LoRAs and Checkpoints pages
import { state, getCurrentPageState, setCurrentPageType } from '../../state/index.js';
import { getStorageItem, setStorageItem } from '../../utils/storageHelpers.js';
import { getStorageItem, setStorageItem, getSessionItem, setSessionItem } from '../../utils/storageHelpers.js';
import { showToast } from '../../utils/uiHelpers.js';
/**
@@ -26,6 +26,9 @@ export class PageControls {
// Initialize event listeners
this.initEventListeners();
// Initialize favorites filter button state
this.initFavoritesFilter();
console.log(`PageControls initialized for ${pageType} page`);
}
@@ -121,6 +124,12 @@ export class PageControls {
bulkButton.addEventListener('click', () => this.toggleBulkMode());
}
}
// Favorites filter button handler
const favoriteFilterBtn = document.getElementById('favoriteFilterBtn');
if (favoriteFilterBtn) {
favoriteFilterBtn.addEventListener('click', () => this.toggleFavoritesOnly());
}
}
/**
@@ -385,4 +394,50 @@ export class PageControls {
showToast('Failed to clear custom filter: ' + error.message, 'error');
}
}
/**
* Initialize the favorites filter button state
*/
initFavoritesFilter() {
const favoriteFilterBtn = document.getElementById('favoriteFilterBtn');
if (favoriteFilterBtn) {
// Get current state from session storage with page-specific key
const storageKey = `show_favorites_only_${this.pageType}`;
const showFavoritesOnly = getSessionItem(storageKey, false);
// Update button state
if (showFavoritesOnly) {
favoriteFilterBtn.classList.add('active');
}
// Update app state
this.pageState.showFavoritesOnly = showFavoritesOnly;
}
}
/**
* Toggle favorites-only filter and reload models
*/
async toggleFavoritesOnly() {
const favoriteFilterBtn = document.getElementById('favoriteFilterBtn');
// Toggle the filter state in storage
const storageKey = `show_favorites_only_${this.pageType}`;
const currentState = this.pageState.showFavoritesOnly;
const newState = !currentState;
// Update session storage
setSessionItem(storageKey, newState);
// Update state
this.pageState.showFavoritesOnly = newState;
// Update button appearance
if (favoriteFilterBtn) {
favoriteFilterBtn.classList.toggle('active', newState);
}
// Reload models with new filter
await this.resetAndReload(true);
}
}
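The favorites toggle persists per page under a session key built from pageType; a minimal sketch of reading that flag elsewhere, assuming the same storageHelpers are in scope:
// Sketch: reading the flag written by toggleFavoritesOnly(); key format follows
// the code above: `show_favorites_only_${pageType}` (here, the loras page).
import { getSessionItem } from '../../utils/storageHelpers.js';

const showFavoritesOnly = getSessionItem('show_favorites_only_loras', false);
if (showFavoritesOnly) {
    console.log('Favorites-only filter is active for the loras page');
}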

View File

@@ -5,31 +5,7 @@
import { showToast } from '../../utils/uiHelpers.js';
import { BASE_MODELS } from '../../utils/constants.js';
import { updateLoraCard } from '../../utils/cardUpdater.js';
/**
* Save model metadata to the server
* @param {string} filePath - Path to the model file
* @param {Object} data - Metadata to save
* @returns {Promise} - Promise that resolves with the server response
*/
export async function saveModelMetadata(filePath, data) {
const response = await fetch('/api/loras/save-metadata', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath,
...data
})
});
if (!response.ok) {
throw new Error('Failed to save metadata');
}
return response.json();
}
import { saveModelMetadata } from '../../api/loraApi.js';
/**
* Set up model name editing functionality

View File

@@ -2,8 +2,7 @@
* PresetTags.js
* Module handling preset parameter tags for LoRA models
*/
import { saveModelMetadata } from './ModelMetadata.js';
import { showToast } from '../../utils/uiHelpers.js';
import { saveModelMetadata } from '../../api/loraApi.js';
/**
* Parse preset parameters

View File

@@ -1,7 +1,7 @@
/**
* RecipeTab - Handles the recipes tab in the Lora Modal
*/
import { showToast } from '../../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
/**
@@ -172,14 +172,11 @@ function copyRecipeSyntax(recipeId) {
.then(response => response.json())
.then(data => {
if (data.success && data.syntax) {
return navigator.clipboard.writeText(data.syntax);
return copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
} else {
throw new Error(data.error || 'No syntax returned');
}
})
.then(() => {
showToast('Recipe syntax copied to clipboard', 'success');
})
.catch(err => {
console.error('Failed to copy: ', err);
showToast('Failed to copy recipe syntax', 'error');

View File

@@ -2,7 +2,7 @@
* ShowcaseView.js
* Handles showcase content (images, videos) display for the LoRA modal
*/
import { showToast } from '../../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
import { state } from '../../state/index.js';
import { NSFW_LEVELS } from '../../utils/constants.js';
@@ -311,8 +311,7 @@ function initMetadataPanelHandlers(container) {
if (!promptElement) return;
try {
await navigator.clipboard.writeText(promptElement.textContent);
showToast('Prompt copied to clipboard', 'success');
await copyToClipboard(promptElement.textContent, 'Prompt copied to clipboard');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');

View File

@@ -2,8 +2,8 @@
* TriggerWords.js
* Module handling trigger words for LoRA models
*/
import { showToast } from '../../utils/uiHelpers.js';
import { saveModelMetadata } from './ModelMetadata.js';
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
import { saveModelMetadata } from '../../api/loraApi.js';
/**
* Render trigger words
@@ -235,8 +235,8 @@ function addNewTriggerWord(word) {
// Validation: Check total number
const currentTags = tagsContainer.querySelectorAll('.trigger-word-tag');
if (currentTags.length >= 10) {
showToast('Maximum 10 trigger words allowed', 'error');
if (currentTags.length >= 30) {
showToast('Maximum 30 trigger words allowed', 'error');
return;
}
@@ -336,8 +336,7 @@ async function saveTriggerWords() {
*/
window.copyTriggerWord = async function(word) {
try {
await navigator.clipboard.writeText(word);
showToast('Trigger word copied', 'success');
await copyToClipboard(word, 'Trigger word copied');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');

View File

@@ -3,8 +3,7 @@
*
* Main entry file after splitting the original LoraModal.js into multiple feature modules
*/
import { showToast } from '../../utils/uiHelpers.js';
import { state } from '../../state/index.js';
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
import { modalManager } from '../../managers/ModalManager.js';
import { renderShowcaseContent, toggleShowcase, setupShowcaseScroll, scrollToTop } from './ShowcaseView.js';
import { setupTabSwitching, loadModelDescription } from './ModelDescription.js';
@@ -14,9 +13,9 @@ import { loadRecipesForLora } from './RecipeTab.js'; // Add import for recipe ta
import {
setupModelNameEditing,
setupBaseModelEditing,
setupFileNameEditing,
saveModelMetadata
setupFileNameEditing
} from './ModelMetadata.js';
import { saveModelMetadata } from '../../api/loraApi.js';
import { renderCompactTags, setupTagTooltip, formatFileSize } from './utils.js';
import { updateLoraCard } from '../../utils/cardUpdater.js';
@@ -174,8 +173,7 @@ export function showLoraModal(lora) {
// Copy file name function
window.copyFileName = async function(fileName) {
try {
await navigator.clipboard.writeText(fileName);
showToast('File name copied', 'success');
await copyToClipboard(fileName, 'File name copied');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');

View File

@@ -6,7 +6,7 @@ import { updateCardsForBulkMode } from './components/LoraCard.js';
import { bulkManager } from './managers/BulkManager.js';
import { DownloadManager } from './managers/DownloadManager.js';
import { moveManager } from './managers/MoveManager.js';
import { LoraContextMenu } from './components/ContextMenu.js';
import { LoraContextMenu } from './components/ContextMenu/index.js';
import { createPageControls } from './components/controls/index.js';
import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js';

View File

@@ -1,5 +1,5 @@
import { state } from '../state/index.js';
import { showToast } from '../utils/uiHelpers.js';
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
import { updateCardsForBulkMode } from '../components/LoraCard.js';
export class BulkManager {
@@ -205,13 +205,7 @@ export class BulkManager {
return;
}
try {
await navigator.clipboard.writeText(loraSyntaxes.join(', '));
showToast(`Copied ${loraSyntaxes.length} LoRA syntaxes to clipboard`, 'success');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');
}
await copyToClipboard(loraSyntaxes.join(', '), `Copied ${loraSyntaxes.length} LoRA syntaxes to clipboard`);
}
// Create and show the thumbnail strip of selected LoRAs

View File

@@ -146,6 +146,18 @@ export class ImportManager {
if (totalSizeDisplay) {
totalSizeDisplay.textContent = 'Calculating...';
}
// Remove any existing deleted LoRAs warning
const deletedLorasWarning = document.getElementById('deletedLorasWarning');
if (deletedLorasWarning) {
deletedLorasWarning.remove();
}
// Remove any existing early access warning
const earlyAccessWarning = document.getElementById('earlyAccessWarning');
if (earlyAccessWarning) {
earlyAccessWarning.remove();
}
}
toggleImportMode(mode) {
@@ -532,17 +544,17 @@ export class ImportManager {
const nextButton = document.querySelector('#detailsStep .primary-btn');
if (!nextButton) return;
// Always clean up previous warnings first
const existingWarning = document.getElementById('deletedLorasWarning');
if (existingWarning) {
existingWarning.remove();
}
// Count deleted LoRAs
const deletedLoras = this.recipeData.loras.filter(lora => lora.isDeleted).length;
// If we have deleted LoRAs, show a warning and update button text
if (deletedLoras > 0) {
// Remove any existing warning
const existingWarning = document.getElementById('deletedLorasWarning');
if (existingWarning) {
existingWarning.remove();
}
// Create a new warning container above the buttons
const buttonsContainer = document.querySelector('#detailsStep .modal-actions') || nextButton.parentNode;
const warningContainer = document.createElement('div');

View File

@@ -5,6 +5,7 @@ import { RecipeCard } from './components/RecipeCard.js';
import { RecipeModal } from './components/RecipeModal.js';
import { getCurrentPageState } from './state/index.js';
import { getSessionItem, removeSessionItem } from './utils/storageHelpers.js';
import { RecipeContextMenu } from './components/ContextMenu/index.js';
class RecipeManager {
constructor() {
@@ -37,6 +38,9 @@ class RecipeManager {
// Set default search options if not already defined
this._initSearchOptions();
// Initialize context menu
new RecipeContextMenu();
// Check for custom filter parameters in session storage
this._checkCustomFilter();
@@ -264,6 +268,32 @@ class RecipeManager {
}
}
/**
* Refreshes the recipe list by first rebuilding the cache and then loading recipes
*/
async refreshRecipes() {
try {
// Call the new endpoint to rebuild the recipe cache
const response = await fetch('/api/recipes/scan');
if (!response.ok) {
const data = await response.json();
throw new Error(data.error || 'Failed to refresh recipe cache');
}
// After successful cache rebuild, load the recipes
await this.loadRecipes(true);
appCore.showToast('Refresh complete', 'success');
} catch (error) {
console.error('Error refreshing recipes:', error);
appCore.showToast(error.message || 'Failed to refresh recipes', 'error');
// Still try to load recipes even if scan failed
await this.loadRecipes(true);
}
}
async _loadSpecificRecipe(recipeId) {
try {
// Fetch specific recipe by ID

View File

@@ -1,5 +1,5 @@
// Create the new hierarchical state structure
import { getStorageItem } from '../utils/storageHelpers.js';
import { getStorageItem, getMapFromStorage } from '../utils/storageHelpers.js';
// Load settings from localStorage or use defaults
const savedSettings = getStorageItem('settings', {
@@ -7,6 +7,10 @@ const savedSettings = getStorageItem('settings', {
show_only_sfw: false
});
// Load preview versions from localStorage
const loraPreviewVersions = getMapFromStorage('lora_preview_versions');
const checkpointPreviewVersions = getMapFromStorage('checkpoint_preview_versions');
export const state = {
// Global state
global: {
@@ -23,7 +27,7 @@ export const state = {
hasMore: true,
sortBy: 'name',
activeFolder: null,
previewVersions: new Map(),
previewVersions: loraPreviewVersions,
searchManager: null,
searchOptions: {
filename: true,
@@ -38,6 +42,7 @@ export const state = {
bulkMode: false,
selectedLoras: new Set(),
loraMetadataCache: new Map(),
showFavoritesOnly: false,
},
recipes: {
@@ -57,7 +62,8 @@ export const state = {
tags: [],
search: ''
},
pageSize: 20
pageSize: 20,
showFavoritesOnly: false,
},
checkpoints: {
@@ -66,6 +72,7 @@ export const state = {
hasMore: true,
sortBy: 'name',
activeFolder: null,
previewVersions: checkpointPreviewVersions,
searchManager: null,
searchOptions: {
filename: true,
@@ -75,7 +82,8 @@ export const state = {
filters: {
baseModel: [],
tags: []
}
},
showFavoritesOnly: false,
}
},

View File

@@ -4,6 +4,7 @@ import { loadMoreCheckpoints } from '../api/checkpointApi.js';
import { debounce } from './debounce.js';
export function initializeInfiniteScroll(pageType = 'loras') {
// Clean up any existing observer
if (state.observer) {
state.observer.disconnect();
}
@@ -47,53 +48,53 @@ export function initializeInfiniteScroll(pageType = 'loras') {
}
const debouncedLoadMore = debounce(loadMoreFunction, 100);
// Create a more robust observer with lower threshold and root margin
state.observer = new IntersectionObserver(
(entries) => {
const target = entries[0];
if (target.isIntersecting && !pageState.isLoading && pageState.hasMore) {
debouncedLoadMore();
}
},
{
threshold: 0.01, // Lower threshold to detect even minimal visibility
rootMargin: '0px 0px 300px 0px' // Increase bottom margin to trigger earlier
}
);
const grid = document.getElementById(gridId);
if (!grid) {
console.warn(`Grid with ID "${gridId}" not found for infinite scroll`);
return;
}
// Remove any existing sentinel
const existingSentinel = document.getElementById('scroll-sentinel');
if (existingSentinel) {
state.observer.observe(existingSentinel);
} else {
// Create a wrapper div that will be placed after the grid
const sentinelWrapper = document.createElement('div');
sentinelWrapper.style.width = '100%';
sentinelWrapper.style.height = '30px'; // Increased height for better visibility
sentinelWrapper.style.margin = '0';
sentinelWrapper.style.padding = '0';
// Create the actual sentinel element
const sentinel = document.createElement('div');
sentinel.id = 'scroll-sentinel';
sentinel.style.height = '30px'; // Match wrapper height
// Add the sentinel to the wrapper
sentinelWrapper.appendChild(sentinel);
// Insert the wrapper after the grid instead of inside it
grid.parentNode.insertBefore(sentinelWrapper, grid.nextSibling);
state.observer.observe(sentinel);
existingSentinel.remove();
}
// Add a scroll event backup to handle edge cases
// Create a sentinel element after the grid (not inside it)
const sentinel = document.createElement('div');
sentinel.id = 'scroll-sentinel';
sentinel.style.width = '100%';
sentinel.style.height = '20px';
sentinel.style.visibility = 'hidden'; // Make it invisible but still affect layout
// Insert after grid instead of inside
grid.parentNode.insertBefore(sentinel, grid.nextSibling);
// Create observer with appropriate settings, slightly different for checkpoints page
const observerOptions = {
threshold: 0.1,
rootMargin: pageType === 'checkpoints' ? '0px 0px 200px 0px' : '0px 0px 100px 0px'
};
// Initialize the observer
state.observer = new IntersectionObserver((entries) => {
const target = entries[0];
if (target.isIntersecting && !pageState.isLoading && pageState.hasMore) {
debouncedLoadMore();
}
}, observerOptions);
// Start observing
state.observer.observe(sentinel);
// Clean up any existing scroll event listener
if (state.scrollHandler) {
window.removeEventListener('scroll', state.scrollHandler);
state.scrollHandler = null;
}
// Add a simple backup scroll handler
const handleScroll = debounce(() => {
if (pageState.isLoading || !pageState.hasMore) return;
@@ -103,26 +104,17 @@ export function initializeInfiniteScroll(pageType = 'loras') {
const rect = sentinel.getBoundingClientRect();
const windowHeight = window.innerHeight;
// If sentinel is within 500px of viewport bottom, load more
if (rect.top < windowHeight + 500) {
if (rect.top < windowHeight + 200) {
debouncedLoadMore();
}
}, 200);
// Clean up existing scroll listener if any
if (state.scrollHandler) {
window.removeEventListener('scroll', state.scrollHandler);
}
// Save reference to the handler for cleanup
state.scrollHandler = handleScroll;
window.addEventListener('scroll', state.scrollHandler);
// Check position immediately in case content is already visible
setTimeout(() => {
const sentinel = document.getElementById('scroll-sentinel');
if (sentinel && sentinel.getBoundingClientRect().top < window.innerHeight) {
debouncedLoadMore();
}
}, 100);
// Clear any existing interval
if (state.scrollCheckInterval) {
clearInterval(state.scrollCheckInterval);
state.scrollCheckInterval = null;
}
}

View File

@@ -171,4 +171,45 @@ export function migrateStorageItems() {
localStorage.setItem(STORAGE_PREFIX + 'migration_completed', 'true');
console.log('Lora Manager: Storage migration completed');
}
/**
* Save a Map to localStorage
* @param {string} key - The localStorage key
* @param {Map} map - The Map to save
*/
export function saveMapToStorage(key, map) {
if (!(map instanceof Map)) {
console.error('Cannot save non-Map object:', map);
return;
}
try {
const prefixedKey = STORAGE_PREFIX + key;
// Convert Map to array of entries and save as JSON
const entries = Array.from(map.entries());
localStorage.setItem(prefixedKey, JSON.stringify(entries));
} catch (error) {
console.error(`Error saving Map to localStorage (${key}):`, error);
}
}
/**
* Load a Map from localStorage
* @param {string} key - The localStorage key
* @returns {Map} - The loaded Map or a new empty Map
*/
export function getMapFromStorage(key) {
try {
const prefixedKey = STORAGE_PREFIX + key;
const data = localStorage.getItem(prefixedKey);
if (!data) return new Map();
// Parse JSON and convert back to Map
const entries = JSON.parse(data);
return new Map(entries);
} catch (error) {
console.error(`Error loading Map from localStorage (${key}):`, error);
return new Map();
}
}
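Both helpers serialize Map entries as JSON under the shared storage prefix; a minimal round-trip sketch using the 'lora_preview_versions' key that state/index.js loads at startup (relative import path assumed):
// Round-trip sketch: entries survive reloads as [key, value] pairs in localStorage.
import { saveMapToStorage, getMapFromStorage } from './storageHelpers.js';

const previewVersions = getMapFromStorage('lora_preview_versions');
previewVersions.set('/models/loras/example.safetensors', Date.now()); // example entry
saveMapToStorage('lora_preview_versions', previewVersions);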

View File

@@ -2,6 +2,40 @@ import { state } from '../state/index.js';
import { resetAndReload } from '../api/loraApi.js';
import { getStorageItem, setStorageItem } from './storageHelpers.js';
/**
* Utility function to copy text to clipboard with fallback for older browsers
* @param {string} text - The text to copy to clipboard
* @param {string} successMessage - Optional success message to show in toast
* @returns {Promise<boolean>} - Promise that resolves to true if copy was successful
*/
export async function copyToClipboard(text, successMessage = 'Copied to clipboard') {
try {
// Modern clipboard API
if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(text);
} else {
// Fallback for older browsers
const textarea = document.createElement('textarea');
textarea.value = text;
textarea.style.position = 'absolute';
textarea.style.left = '-99999px';
document.body.appendChild(textarea);
textarea.select();
document.execCommand('copy');
document.body.removeChild(textarea);
}
if (successMessage) {
showToast(successMessage, 'success');
}
return true;
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');
return false;
}
}
export function showToast(message, type = 'info') {
const toast = document.createElement('div');
toast.className = `toast toast-${type}`;
@@ -80,13 +114,55 @@ export function restoreFolderFilter() {
}
export function initTheme() {
document.body.dataset.theme = getStorageItem('theme') || 'dark';
const savedTheme = getStorageItem('theme') || 'auto';
applyTheme(savedTheme);
// Update theme when system preference changes (for 'auto' mode)
window.matchMedia('(prefers-color-scheme: dark)').addEventListener('change', () => {
const currentTheme = getStorageItem('theme') || 'auto';
if (currentTheme === 'auto') {
applyTheme('auto');
}
});
}
export function toggleTheme() {
const theme = document.body.dataset.theme === 'light' ? 'dark' : 'light';
document.body.dataset.theme = theme;
setStorageItem('theme', theme);
const currentTheme = getStorageItem('theme') || 'auto';
let newTheme;
if (currentTheme === 'dark') {
newTheme = 'light';
} else {
newTheme = 'dark';
}
setStorageItem('theme', newTheme);
applyTheme(newTheme);
// Force a repaint to ensure theme changes are applied immediately
document.body.style.display = 'none';
document.body.offsetHeight; // Trigger a reflow
document.body.style.display = '';
return newTheme;
}
// Add a new helper function to apply the theme
function applyTheme(theme) {
const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches;
const htmlElement = document.documentElement;
// Remove any existing theme attributes
htmlElement.removeAttribute('data-theme');
// Apply the appropriate theme
if (theme === 'dark' || (theme === 'auto' && prefersDark)) {
htmlElement.setAttribute('data-theme', 'dark');
document.body.dataset.theme = 'dark';
} else {
htmlElement.setAttribute('data-theme', 'light');
document.body.dataset.theme = 'light';
}
}
export function toggleFolder(tag) {
@@ -108,12 +184,6 @@ export function toggleFolder(tag) {
resetAndReload();
}
export function copyTriggerWord(word) {
navigator.clipboard.writeText(word).then(() => {
showToast('Trigger word copied', 'success');
});
}
function filterByFolder(folderPath) {
document.querySelectorAll('.lora-card').forEach(card => {
card.style.display = card.dataset.folder === folderPath ? '' : 'none';
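The copyToClipboard helper added at the top of this file replaces the scattered navigator.clipboard calls removed throughout this diff; a minimal usage sketch (import path relative to a sibling module is assumed):
// Callers pass the text plus an optional toast message; the helper handles the
// execCommand fallback and the error toast itself and resolves to a boolean.
import { copyToClipboard } from '../utils/uiHelpers.js';

async function copyExample() {
    const copied = await copyToClipboard('<lora:example:1.0>', 'LoRA syntax copied');
    if (!copied) {
        console.debug('Clipboard copy failed; the helper already showed the error toast');
    }
}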

File diff suppressed because one or more lines are too long

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -6,7 +6,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1">
<link rel="stylesheet" href="/loras_static/css/style.css">
{% block page_css %}{% endblock %}
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css"
<link rel="stylesheet" href="/loras_static/vendor/font-awesome/css/all.min.css"
crossorigin="anonymous" referrerpolicy="no-referrer">
<link rel="icon" type="image/png" sizes="32x32" href="/loras_static/images/favicon-32x32.png">
<link rel="icon" type="image/png" sizes="16x16" href="/loras_static/images/favicon-16x16.png">
@@ -17,7 +17,7 @@
{% block preload %}{% endblock %}
<!-- Optimize font loading -->
<link rel="preload" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/webfonts/fa-solid-900.woff2"
<link rel="preload" href="/loras_static/vendor/font-awesome/webfonts/fa-solid-900.woff2"
as="font" type="font/woff2" crossorigin>
<!-- Add performance monitoring -->
@@ -35,7 +35,7 @@
<!-- Add resource loading strategy -->
<link rel="preconnect" href="https://civitai.com">
<link rel="preconnect" href="https://cdnjs.cloudflare.com">
<!-- <link rel="preconnect" href="https://cdnjs.cloudflare.com"> -->
<script>
// Calculate the scrollbar width and expose it as a CSS variable
@@ -48,6 +48,20 @@
document.documentElement.style.setProperty('--scrollbar-width', scrollbarWidth + 'px');
});
</script>
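<!-- Apply the stored theme before first paint to avoid a flash of the wrong theme -->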
<script>
(function() {
// Apply theme immediately based on stored preference
const STORAGE_PREFIX = 'lora_manager_';
const savedTheme = localStorage.getItem(STORAGE_PREFIX + 'theme') || 'auto';
const prefersDark = window.matchMedia('(prefers-color-scheme: dark)').matches;
if (savedTheme === 'dark' || (savedTheme === 'auto' && prefersDark)) {
document.documentElement.setAttribute('data-theme', 'dark');
} else {
document.documentElement.setAttribute('data-theme', 'light');
}
})();
</script>
{% block head_scripts %}{% endblock %}
</head>

@@ -13,6 +13,18 @@
{% block additional_components %}
{% include 'components/checkpoint_modals.html' %}
<div id="checkpointContextMenu" class="context-menu" style="display: none;">
<div class="context-menu-item" data-action="details"><i class="fas fa-info-circle"></i> View Details</div>
<div class="context-menu-item" data-action="civitai"><i class="fas fa-external-link-alt"></i> View on CivitAI</div>
<div class="context-menu-item" data-action="refresh-metadata"><i class="fas fa-sync"></i> Refresh Civitai Data</div>
<div class="context-menu-item" data-action="copyname"><i class="fas fa-copy"></i> Copy Model Filename</div>
<div class="context-menu-item" data-action="preview"><i class="fas fa-image"></i> Replace Preview</div>
<div class="context-menu-item" data-action="set-nsfw"><i class="fas fa-exclamation-triangle"></i> Set Content Rating</div>
<div class="context-menu-separator"></div>
<div class="context-menu-item" data-action="move"><i class="fas fa-folder-open"></i> Move to Folder</div>
<div class="context-menu-item delete-item" data-action="delete"><i class="fas fa-trash"></i> Delete Model</div>
</div>
{% endblock %}
{% block content %}

@@ -35,6 +35,11 @@
</button>
</div>
{% endif %}
<div class="control-group">
<button id="favoriteFilterBtn" data-action="toggle-favorites" class="favorite-filter" title="Show favorites only">
<i class="fas fa-star"></i> Favorites
</button>
</div>
<div id="customFilterIndicator" class="control-group hidden">
<div class="filter-active">
<i class="fas fa-filter"></i> <span class="customFilterText" title=""></span>

@@ -16,6 +16,16 @@
{% block additional_components %}
{% include 'components/import_modal.html' %}
{% include 'components/recipe_modal.html' %}
<div id="recipeContextMenu" class="context-menu" style="display: none;">
<div class="context-menu-item" data-action="details"><i class="fas fa-info-circle"></i> View Details</div>
<div class="context-menu-item" data-action="share"><i class="fas fa-share-alt"></i> Share Recipe</div>
<div class="context-menu-item" data-action="copy"><i class="fas fa-copy"></i> Copy Recipe Syntax</div>
<div class="context-menu-item" data-action="viewloras"><i class="fas fa-layer-group"></i> View All LoRAs</div>
<div class="context-menu-item download-missing-item" data-action="download-missing"><i class="fas fa-download"></i> Download Missing LoRAs</div>
<div class="context-menu-separator"></div>
<div class="context-menu-item delete-item" data-action="delete"><i class="fas fa-trash"></i> Delete Recipe</div>
</div>
{% endblock %}
{% block init_title %}Initializing Recipe Manager{% endblock %}
@@ -27,7 +37,7 @@
<div class="controls">
<div class="action-buttons">
<div title="Refresh recipe list" class="control-group">
<button onclick="recipeManager.loadRecipes(true)"><i class="fas fa-sync"></i> Refresh</button>
<button onclick="recipeManager.refreshRecipes()"><i class="fas fa-sync"></i> Refresh</button>
</div>
<div title="Import recipes" class="control-group">
<button onclick="importManager.showImportModal()"><i class="fas fa-file-import"></i> Import</button>

@@ -287,6 +287,108 @@ export function addLorasWidget(node, name, opts, callback) {
// Create the preview tooltip instance
const previewTooltip = new PreviewTooltip();
// Function to handle strength adjustment via dragging
const handleStrengthDrag = (name, initialStrength, initialX, event, widget) => {
// Calculate drag sensitivity (how much the strength changes per pixel)
// Using 0.01 per 10 pixels of movement
const sensitivity = 0.001;
// Get the current mouse position
const currentX = event.clientX;
// Calculate the distance moved
const deltaX = currentX - initialX;
// Calculate the new strength value based on movement
// Moving right increases, moving left decreases
let newStrength = Number(initialStrength) + (deltaX * sensitivity);
// Limit the strength to reasonable bounds (now between -10 and 10)
newStrength = Math.max(-10, Math.min(10, newStrength));
newStrength = Number(newStrength.toFixed(2));
// Update the lora data
const lorasData = parseLoraValue(widget.value);
const loraIndex = lorasData.findIndex(l => l.name === name);
if (loraIndex >= 0) {
lorasData[loraIndex].strength = newStrength;
// Update the widget value
widget.value = formatLoraValue(lorasData);
// Force re-render to show updated strength value
renderLoras(widget.value, widget);
}
};
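// e.g. dragging 100px to the right raises the strength by 0.10 (100 * 0.001), clamped to [-10, 10] and rounded to two decimals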
// Function to initialize drag operation
const initDrag = (loraEl, nameEl, name, widget) => {
let isDragging = false;
let initialX = 0;
let initialStrength = 0;
// Create a style element for drag cursor override if it doesn't exist
if (!document.getElementById('comfy-lora-drag-style')) {
const styleEl = document.createElement('style');
styleEl.id = 'comfy-lora-drag-style';
styleEl.textContent = `
body.comfy-lora-dragging,
body.comfy-lora-dragging * {
cursor: ew-resize !important;
}
`;
document.head.appendChild(styleEl);
}
// Create a drag handler that's applied to the entire lora entry
// except toggle and strength controls
loraEl.addEventListener('mousedown', (e) => {
// Skip if clicking on toggle or strength control areas
if (e.target.closest('.comfy-lora-toggle') ||
e.target.closest('input') ||
e.target.closest('.comfy-lora-arrow')) {
return;
}
// Store initial values
const lorasData = parseLoraValue(widget.value);
const loraData = lorasData.find(l => l.name === name);
if (!loraData) return;
initialX = e.clientX;
initialStrength = loraData.strength;
isDragging = true;
// Add class to body to enforce cursor style globally
document.body.classList.add('comfy-lora-dragging');
// Prevent text selection during drag
e.preventDefault();
});
// Use the document for move and up events to ensure drag continues
// even if mouse leaves the element
document.addEventListener('mousemove', (e) => {
if (!isDragging) return;
// Call the strength adjustment function
handleStrengthDrag(name, initialStrength, initialX, e, widget);
// Prevent showing the preview tooltip during drag
previewTooltip.hide();
});
document.addEventListener('mouseup', () => {
if (isDragging) {
isDragging = false;
// Remove the class to restore normal cursor behavior
document.body.classList.remove('comfy-lora-dragging');
}
});
};
// Function to create menu item
const createMenuItem = (text, icon, onClick) => {
const menuItem = document.createElement('div');
@@ -756,6 +858,9 @@ export function addLorasWidget(node, name, opts, callback) {
loraEl.appendChild(strengthControl);
container.appendChild(loraEl);
// Initialize drag functionality
initDrag(loraEl, nameEl, name, widget);
});
};
@@ -822,10 +927,6 @@ export function addLorasWidget(node, name, opts, callback) {
// Function to directly save the recipe without dialog
async function saveRecipeDirectly(widget) {
try {
// Get the workflow data from the ComfyUI app
const prompt = await app.graphToPrompt();
console.log('Prompt:', prompt);
// Show loading toast
if (app && app.extensionManager && app.extensionManager.toast) {
app.extensionManager.toast.add({
@@ -836,14 +937,9 @@ async function saveRecipeDirectly(widget) {
});
}
// Prepare the data - only send workflow JSON
const formData = new FormData();
formData.append('workflow_json', JSON.stringify(prompt.output));
// Send the request
const response = await fetch('/api/recipes/save-from-widget', {
method: 'POST',
body: formData
method: 'POST'
});
const result = await response.json();

@@ -9,6 +9,54 @@ async function getLorasWidgetModule() {
return await dynamicImportByVersion("./loras_widget.js", "./legacy_loras_widget.js");
}
// Function to get connected trigger toggle nodes
function getConnectedTriggerToggleNodes(node) {
const connectedNodes = [];
// Check if node has outputs
if (node.outputs && node.outputs.length > 0) {
// For each output slot
for (const output of node.outputs) {
// Check if this output has any links
if (output.links && output.links.length > 0) {
// For each link, get the target node
for (const linkId of output.links) {
const link = app.graph.links[linkId];
if (link) {
const targetNode = app.graph.getNodeById(link.target_id);
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
connectedNodes.push(targetNode.id);
}
}
}
}
}
}
return connectedNodes;
}
// Function to update trigger words for connected toggle nodes
function updateConnectedTriggerWords(node, text) {
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
if (connectedNodeIds.length > 0) {
const loraNames = new Set();
let match;
LORA_PATTERN.lastIndex = 0;
while ((match = LORA_PATTERN.exec(text)) !== null) {
loraNames.add(match[1]);
}
fetch("/loramanager/get_trigger_words", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
lora_names: Array.from(loraNames),
node_ids: connectedNodeIds
})
}).catch(err => console.error("Error fetching trigger words:", err));
}
}
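// Illustrative request body for the call above (lora names and node id are examples only):
// { "lora_names": ["styleA", "styleB"], "node_ids": [12] }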
function mergeLoras(lorasText, lorasArr) {
const result = [];
let match;
@@ -99,6 +147,9 @@ app.registerExtension({
newText = newText.replace(/\s+/g, ' ').trim();
inputWidget.value = newText;
// Update trigger words when the lorasWidget change propagates to the inputWidget value
updateConnectedTriggerWords(node, newText);
} finally {
isUpdating = false;
}
@@ -117,6 +168,9 @@ app.registerExtension({
const mergedLoras = mergeLoras(value, currentLoras);
node.lorasWidget.value = mergedLoras;
// Update trigger words for connected toggle nodes when the input text changes
updateConnectedTriggerWords(node, value);
} finally {
isUpdating = false;
}

@@ -1,9 +1,58 @@
import { app } from "../../scripts/app.js";
import { addLorasWidget } from "./loras_widget.js";
import { dynamicImportByVersion } from "./utils.js";
// Extract pattern into a constant for consistent use
const LORA_PATTERN = /<lora:([^:]+):([-\d\.]+)>/g;
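// Matches inline prompt syntax such as <lora:example_lora:0.8>, capturing the name and the strength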
// Function to get the appropriate loras widget based on ComfyUI version
async function getLorasWidgetModule() {
return await dynamicImportByVersion("./loras_widget.js", "./legacy_loras_widget.js");
}
// Function to get connected trigger toggle nodes
function getConnectedTriggerToggleNodes(node) {
const connectedNodes = [];
if (node.outputs && node.outputs.length > 0) {
for (const output of node.outputs) {
if (output.links && output.links.length > 0) {
for (const linkId of output.links) {
const link = app.graph.links[linkId];
if (link) {
const targetNode = app.graph.getNodeById(link.target_id);
if (targetNode && targetNode.comfyClass === "TriggerWord Toggle (LoraManager)") {
connectedNodes.push(targetNode.id);
}
}
}
}
}
}
return connectedNodes;
}
// Function to update trigger words for connected toggle nodes
function updateConnectedTriggerWords(node, text) {
const connectedNodeIds = getConnectedTriggerToggleNodes(node);
if (connectedNodeIds.length > 0) {
const loraNames = new Set();
let match;
LORA_PATTERN.lastIndex = 0;
while ((match = LORA_PATTERN.exec(text)) !== null) {
loraNames.add(match[1]);
}
fetch("/loramanager/get_trigger_words", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
lora_names: Array.from(loraNames),
node_ids: connectedNodeIds
})
}).catch(err => console.error("Error fetching trigger words:", err));
}
}
function mergeLoras(lorasText, lorasArr) {
const result = [];
let match;
@@ -40,7 +89,7 @@ app.registerExtension({
});
// Wait for node to be properly initialized
requestAnimationFrame(() => {
requestAnimationFrame(async () => {
// Restore saved value if exists
let existingLoras = [];
if (node.widgets_values && node.widgets_values.length > 0) {
@@ -64,7 +113,10 @@ app.registerExtension({
// Add flag to prevent callback loops
let isUpdating = false;
// Get the widget object directly from the returned object
// Dynamically load the appropriate widget module
const lorasModule = await getLorasWidgetModule();
const { addLorasWidget } = lorasModule;
const result = addLorasWidget(node, "loras", {
defaultVal: mergedLoras // Pass object directly
}, (value) => {
@@ -86,6 +138,9 @@ app.registerExtension({
newText = newText.replace(/\s+/g, ' ').trim();
inputWidget.value = newText;
// Update trigger words when lorasWidget changes
updateConnectedTriggerWords(node, newText);
} finally {
isUpdating = false;
}
@@ -104,6 +159,9 @@ app.registerExtension({
const mergedLoras = mergeLoras(value, currentLoras);
node.lorasWidget.value = mergedLoras;
// Update trigger words when input changes
updateConnectedTriggerWords(node, value);
} finally {
isUpdating = false;
}

@@ -366,6 +366,108 @@ export function addLorasWidget(node, name, opts, callback) {
return menuItem;
};
// Function to handle strength adjustment via dragging
const handleStrengthDrag = (name, initialStrength, initialX, event, widget) => {
// Calculate drag sensitivity (how much the strength changes per pixel)
// Using 0.01 per 10 pixels of movement
const sensitivity = 0.001;
// Get the current mouse position
const currentX = event.clientX;
// Calculate the distance moved
const deltaX = currentX - initialX;
// Calculate the new strength value based on movement
// Moving right increases, moving left decreases
let newStrength = Number(initialStrength) + (deltaX * sensitivity);
// Limit the strength to reasonable bounds (now between -10 and 10)
newStrength = Math.max(-10, Math.min(10, newStrength));
newStrength = Number(newStrength.toFixed(2));
// Update the lora data
const lorasData = parseLoraValue(widget.value);
const loraIndex = lorasData.findIndex(l => l.name === name);
if (loraIndex >= 0) {
lorasData[loraIndex].strength = newStrength;
// Update the widget value
widget.value = formatLoraValue(lorasData);
// Force re-render to show updated strength value
renderLoras(widget.value, widget);
}
};
// Function to initialize drag operation
const initDrag = (loraEl, nameEl, name, widget) => {
let isDragging = false;
let initialX = 0;
let initialStrength = 0;
// Create a style element for drag cursor override if it doesn't exist
if (!document.getElementById('comfy-lora-drag-style')) {
const styleEl = document.createElement('style');
styleEl.id = 'comfy-lora-drag-style';
styleEl.textContent = `
body.comfy-lora-dragging,
body.comfy-lora-dragging * {
cursor: ew-resize !important;
}
`;
document.head.appendChild(styleEl);
}
// Create a drag handler that's applied to the entire lora entry
// except toggle and strength controls
loraEl.addEventListener('mousedown', (e) => {
// Skip if clicking on toggle or strength control areas
if (e.target.closest('.comfy-lora-toggle') ||
e.target.closest('input') ||
e.target.closest('.comfy-lora-arrow')) {
return;
}
// Store initial values
const lorasData = parseLoraValue(widget.value);
const loraData = lorasData.find(l => l.name === name);
if (!loraData) return;
initialX = e.clientX;
initialStrength = loraData.strength;
isDragging = true;
// Add class to body to enforce cursor style globally
document.body.classList.add('comfy-lora-dragging');
// Prevent text selection during drag
e.preventDefault();
});
// Use the document for move and up events to ensure drag continues
// even if mouse leaves the element
document.addEventListener('mousemove', (e) => {
if (!isDragging) return;
// Call the strength adjustment function
handleStrengthDrag(name, initialStrength, initialX, e, widget);
// Prevent showing the preview tooltip during drag
previewTooltip.hide();
});
document.addEventListener('mouseup', () => {
if (isDragging) {
isDragging = false;
// Remove the class to restore normal cursor behavior
document.body.classList.remove('comfy-lora-dragging');
}
});
};
// Function to create context menu
const createContextMenu = (x, y, loraName, widget) => {
// Hide preview tooltip first
@@ -649,6 +751,9 @@ export function addLorasWidget(node, name, opts, callback) {
e.stopPropagation();
previewTooltip.hide();
});
// Initialize drag functionality for strength adjustment
initDrag(loraEl, nameEl, name, widget);
// Remove the preview tooltip events from loraEl
loraEl.onmouseenter = () => {
@@ -795,7 +900,7 @@ export function addLorasWidget(node, name, opts, callback) {
});
// Calculate height based on number of loras and fixed sizes
const calculatedHeight = CONTAINER_PADDING + HEADER_HEIGHT + (lorasData.length * LORA_ENTRY_HEIGHT);
const calculatedHeight = CONTAINER_PADDING + HEADER_HEIGHT + (Math.min(lorasData.length, 5) * LORA_ENTRY_HEIGHT);
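// e.g. 8 LoRAs still yield the height of 5 entries, so the widget no longer grows without bound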
updateWidgetHeight(calculatedHeight);
};
@@ -861,9 +966,6 @@ export function addLorasWidget(node, name, opts, callback) {
// Function to directly save the recipe without dialog
async function saveRecipeDirectly(widget) {
try {
// Get the workflow data from the ComfyUI app
const prompt = await app.graphToPrompt();
// Show loading toast
if (app && app.extensionManager && app.extensionManager.toast) {
app.extensionManager.toast.add({
@@ -874,14 +976,9 @@ async function saveRecipeDirectly(widget) {
});
}
// Prepare the data - only send workflow JSON
const formData = new FormData();
formData.append('workflow_json', JSON.stringify(prompt.output));
// Send the request
// Send the request to the backend API without workflow data
const response = await fetch('/api/recipes/save-from-widget', {
method: 'POST',
body: formData
method: 'POST'
});
const result = await response.json();
@@ -917,4 +1014,4 @@ async function saveRecipeDirectly(widget) {
});
}
}
}
}

@@ -0,0 +1,36 @@
// ComfyUI extension to track model usage statistics
import { app } from "../../scripts/app.js";
import { api } from "../../scripts/api.js";
// Register the extension
app.registerExtension({
name: "ComfyUI-Lora-Manager.UsageStats",
init() {
// Listen for successful executions
api.addEventListener("execution_success", ({ detail }) => {
if (detail && detail.prompt_id) {
this.updateUsageStats(detail.prompt_id);
}
});
},
async updateUsageStats(promptId) {
try {
// Call backend endpoint with the prompt_id
const response = await fetch(`/loras/api/update-usage-stats`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ prompt_id: promptId }),
});
if (!response.ok) {
console.warn("Failed to update usage statistics:", response.statusText);
}
} catch (error) {
console.error("Error updating usage statistics:", error);
}
}
});
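// Quick manual check from the browser console (illustrative only; assumes the backend route above is reachable):
// fetch('/loras/api/update-usage-stats', {
//     method: 'POST',
//     headers: { 'Content-Type': 'application/json' },
//     body: JSON.stringify({ prompt_id: 'example-prompt-id' })
// }).then(r => console.log('usage-stats status:', r.status));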

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long