Compare commits

...

184 Commits

Author SHA1 Message Date
Will Miao
dc4c11ddd2 feat: Update release notes and version to 0.8.8 with new features and bug fixes 2025-04-22 13:29:00 +08:00
pixelpaws
d389e4d5d4 Merge pull request #122 from willmiao/dev
Dev
2025-04-22 09:40:05 +08:00
Will Miao
8cb78ad931 feat: Add route for retrieving current usage statistics 2025-04-22 09:39:00 +08:00
Will Miao
85f987d15c feat: Centralize clipboard functionality with copyToClipboard utility across components 2025-04-22 09:33:05 +08:00
Will Miao
b12079e0f6 feat: Implement usage statistics tracking with backend integration and route setup 2025-04-22 08:56:34 +08:00
pixelpaws
dcf5c6167a Merge pull request #121 from willmiao/dev
Dev
2025-04-21 15:44:23 +08:00
Will Miao
b395d3f487 fix: Update filename formatting in save_images method to ensure unique filenames for batch images 2025-04-21 15:42:49 +08:00
Will Miao
37662cad10 Update workflow 2025-04-21 15:42:49 +08:00
pixelpaws
aa1673063d Merge pull request #120 from willmiao/dev
feat: Enhance LoraManager by updating trigger words handling and dyna…
2025-04-21 06:52:16 +08:00
Will Miao
f51f49eb60 feat: Enhance LoraManager by updating trigger words handling and dynamically loading widget modules. 2025-04-21 06:49:51 +08:00
pixelpaws
54c9bac961 Merge pull request #119 from willmiao/dev
Dev
2025-04-20 22:29:28 +08:00
Will Miao
e70fd73bdd feat: Implement trigger words API and update frontend integration for LoraManager. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/43 2025-04-20 22:27:53 +08:00
Will Miao
9bb9e7b64d refactor: Extract common methods for Lora handling into utils.py and update references in lora_loader.py and lora_stacker.py 2025-04-20 21:35:36 +08:00
pixelpaws
f64c03543a Merge pull request #116 from matrunchyk/main
Prevent duplicates of root folders when using symlinks
2025-04-20 17:05:08 +08:00
Will Miao
51374de1a1 fix: Update version to 0.8.7-bugfix2 in pyproject.toml for clarity on bug fixes 2025-04-20 15:04:24 +08:00
Will Miao
afcc12f263 fix: Update populate_lora_from_civitai method to accept a tuple for Civitai API response. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/117 2025-04-20 15:01:23 +08:00
Your Name
88c5482366 Merge branch 'main' of https://github.com/willmiao/ComfyUI-Lora-Manager 2025-04-19 21:47:41 +03:00
Your Name
bbf7295c32 Prevent duplicates of root folders when using symlinks 2025-04-19 21:42:01 +03:00
Will Miao
ca5e23e68c fix: Update version to 0.8.7-bugfix in pyproject.toml for clarity on bug fixes 2025-04-19 23:02:50 +08:00
Will Miao
eadb1487ae feat: Refactor metadata formatting to use helper function for conditional parameter addition 2025-04-19 23:00:09 +08:00
Will Miao
1faa70fc77 feat: Implement filename-based hash retrieval in LoraScanner and ModelScanner for improved compatibility 2025-04-19 21:12:26 +08:00
Will Miao
30d7c007de fix: Correct metadata restoration logic to ensure file info is fetched when metadata is missing 2025-04-19 20:51:23 +08:00
Will Miao
f54f6a4402 feat: Enhance metadata handling by restoring missing civitai data and extracting tags and descriptions from version info 2025-04-19 11:35:42 +08:00
Will Miao
7b41cdec65 feat: Add civitai_deleted attribute to BaseModelMetadata for tracking deletion status from Civitai 2025-04-19 09:30:43 +08:00
Will Miao
fb6a652a57 feat: Add checkpoint hash retrieval and enhance metadata formatting in SaveImage class 2025-04-18 23:55:45 +08:00
Will Miao
ea34d753c1 refactor: Remove unnecessary workflow data logging and streamline saveRecipeDirectly function for legacy loras widget 2025-04-18 21:52:26 +08:00
Will Miao
2bc46e708e feat: Update release notes and version to 0.8.7 with enhancements and bug fixes 2025-04-18 19:03:00 +08:00
Will Miao
96e3b5b7b3 feat: Refactor Civitai model API routes and enhance RecipeContextMenu for missing LoRAs handling 2025-04-18 16:44:26 +08:00
Will Miao
fafbafa5e1 feat: Enhance copyTriggerWord function with modern clipboard API and fallback for non-secure contexts. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/110 2025-04-18 14:56:27 +08:00
Will Miao
be8605d8c6 feat: Enhance CivitaiClient and ApiRoutes to handle model version errors and improve metadata fetching. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/112 2025-04-18 14:44:53 +08:00
Will Miao
061660d47a feat: Increase maximum allowed trigger words from 10 to 30. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/109 2025-04-18 11:25:41 +08:00
pixelpaws
2ed6dbb344 Merge pull request #111 from willmiao/dev
Dev
2025-04-18 10:55:07 +08:00
Will Miao
4766b45746 feat: Update SaveImage node to modify default lossless_webp setting and adjust save_kwargs for image formats 2025-04-18 10:52:39 +08:00
Will Miao
0734252e98 feat: Enhance VAEDecodeExtractor to improve image caching and metadata handling 2025-04-18 10:03:26 +08:00
Will Miao
91b4827c1d feat: Enhance image retrieval in MetadataRegistry and update recipe routes to process images from metadata 2025-04-18 09:24:48 +08:00
Will Miao
df6d56ce66 feat: Add IMAGES category to constants and enhance metadata handling in node extractors 2025-04-18 07:12:43 +08:00
Will Miao
f0203c96ab feat: Simplify format_metadata method by removing custom_prompt parameter and update related function calls 2025-04-18 05:34:42 +08:00
Will Miao
bccabe40c0 feat: Enhance KSamplerAdvancedExtractor to include additional sampling parameters and update metadata processing 2025-04-18 05:29:36 +08:00
Will Miao
c2f599b4ff feat: Update node extractors to include UNETLoaderExtractor and enhance metadata handling for guidance parameters 2025-04-17 22:05:40 +08:00
Will Miao
5fd069d70d feat: Enhance checkpoint processing in format_metadata to handle non-string types safely 2025-04-17 09:38:20 +08:00
Will Miao
32d34d1748 feat: Enhance trace_node_input method with depth tracking and target class filtering; add FluxGuidanceExtractor for guidance parameter extraction 2025-04-17 08:06:21 +08:00
Will Miao
18eb605605 feat: Refactor metadata processing to use constants for category keys and improve structure 2025-04-17 06:23:31 +08:00
Will Miao
4fdc88e9e1 feat: Enhance LoraLoaderExtractor to extract base filename from lora_name input 2025-04-16 22:19:38 +08:00
Will Miao
4c69d8d3a8 feat: Integrate metadata collection in RecipeRoutes and simplify saveRecipeDirectly function 2025-04-16 22:15:46 +08:00
Will Miao
d4b2dd0ec1 refactor: Rename to_comfyui_format method to to_dict and update references in save_image.py 2025-04-16 21:42:54 +08:00
Will Miao
181f78421b feat: Standardize LoRA extraction format and enhance input handling in node extractors 2025-04-16 21:20:56 +08:00
Will Miao
8ed38527d0 feat: Implement metadata collection and processing framework with debug node for verification 2025-04-16 20:04:26 +08:00
Will Miao
c4c926070d fix: Update optimize_image method to handle image validation and error logging, and adjust metadata preservation logic. 2025-04-15 12:31:17 +08:00
Will Miao
ed87411e0d refactor: Change logging level from info to debug for service initialization and file monitoring 2025-04-15 11:48:37 +08:00
Will Miao
4ec2a448ab feat: Improve date formatting in filename generation with zero-padding and two-digit year support. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/102 2025-04-15 10:46:57 +08:00
Will Miao
73d01da94e feat: Enhance model preview version management with localStorage support 2025-04-15 10:35:50 +08:00
pixelpaws
df8e02157a Merge pull request #103 from willmiao/dev
feat: Add drag functionality for strength adjustment in LoRA entries.…
2025-04-15 08:57:52 +08:00
Will Miao
6e513ed32a feat: Add drag functionality for strength adjustment in LoRA entries. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/101 2025-04-15 08:56:19 +08:00
pixelpaws
325ef6327d Merge pull request #99 from willmiao/dev
Dev
2025-04-14 20:27:18 +08:00
Will Miao
46700e5ad0 feat: Refactor infinite scroll initialization for improved observer handling and sentinel management 2025-04-14 20:25:44 +08:00
Will Miao
d1e21fa345 feat: Implement context menus for checkpoints and recipes, including metadata refresh and NSFW level management 2025-04-14 15:37:36 +08:00
Will Miao
cede387783 Bump version to 0.8.6 in pyproject.toml 2025-04-14 08:42:00 +08:00
Will Miao
b206427d50 feat: Update README to include enhanced checkpoint management features and improved initial loading details 2025-04-14 08:40:42 +08:00
Will Miao
47d96e2037 feat: Simplify recipe page initialization and enhance error handling for recipe cache loading 2025-04-14 07:03:34 +08:00
Will Miao
e51f7cc1a7 feat: Enhance checkpoint download manager to save active folder preference and update UI accordingly 2025-04-13 22:12:18 +08:00
Will Miao
40381d4b11 feat: Optimize session management and enhance download functionality with resumable support 2025-04-13 21:51:21 +08:00
Will Miao
76fc9e5a3d feat: Add WebSocket support for checkpoint download progress and update related components 2025-04-13 21:31:01 +08:00
Will Miao
9822f2c614 feat: Add Civitai model version retrieval for Checkpoints and update error handling in download managers 2025-04-13 20:36:19 +08:00
Will Miao
8854334ab5 Add tip images 2025-04-13 18:46:44 +08:00
Will Miao
53080844d2 feat: Refactor progress bar classes for initialization component to improve clarity and avoid conflicts 2025-04-13 18:42:36 +08:00
Will Miao
76fd722e33 feat: Improve card layout by adding overflow hidden and fixing flexbox sizing issues 2025-04-13 18:20:15 +08:00
Will Miao
fa27513f76 feat: Enhance infinite scroll functionality with improved observer settings and scroll event handling 2025-04-13 17:58:14 +08:00
Will Miao
72c6f91130 feat: Update initialization component with loading progress and tips carousel 2025-04-13 14:03:02 +08:00
Will Miao
5918f35b8b feat: Add keyboard shortcuts for search input focus and selection 2025-04-13 13:12:32 +08:00
Will Miao
0b11e6e6d0 feat: Enhance initialization component with progress tracking and UI improvements 2025-04-13 12:58:38 +08:00
Will Miao
a043b487bd feat: Add initialization progress WebSocket and UI components
- Implement WebSocket route for initialization progress updates
- Create initialization component with progress bar and stages
- Add styles for initialization UI
- Update base template to include initialization component
- Enhance model scanner to broadcast progress during initialization
2025-04-13 10:41:27 +08:00
pixelpaws
3982489e67 Merge pull request #97 from willmiao/dev
feat: Enhance checkpoint handling by initializing paths and adding st…
2025-04-12 19:10:13 +08:00
Will Miao
5f3c515323 feat: Enhance checkpoint handling by initializing paths and adding static routes 2025-04-12 19:06:17 +08:00
pixelpaws
6e1297d734 Merge pull request #96 from willmiao/dev
Dev
2025-04-12 17:01:07 +08:00
Will Miao
8f3cbdd257 fix: Simplify session item retrieval in loadMoreModels function 2025-04-12 16:54:27 +08:00
Will Miao
2fc06ae64e Refactor file name update in Lora card
- Updated the setupFileNameEditing function to pass the new file name in the updates object when calling updateLoraCard.
- Removed the page reload after file name change to improve user experience.
- Enhanced the updateLoraCard function to handle the 'file_name' update, ensuring the dataset reflects the new file name correctly.
2025-04-12 16:35:35 +08:00
Will Miao
515aa1d2bd fix: Improve error logging and update lora monitor path handling 2025-04-12 16:24:29 +08:00
Will Miao
ff7a36394a refactor: Optimize event handling for folder tags using delegation 2025-04-12 16:15:29 +08:00
Will Miao
5261ab249a Fix checkpoints sort_by 2025-04-12 13:39:32 +08:00
Will Miao
c3192351da feat: Add support for reading SHA256 from .sha256 file in get_file_info function 2025-04-12 11:59:40 +08:00
Will Miao
ce30d067a6 feat: Import and expose loadMoreLoras function in LoraPageManager 2025-04-12 11:46:26 +08:00
Will Miao
e84a8a72c5 feat: Add save metadata route and update checkpoint card functionality 2025-04-12 11:18:21 +08:00
Will Miao
10a4fe04d1 refactor: Update API endpoint for saving model metadata to use consistent route structure 2025-04-12 09:03:34 +08:00
Will Miao
d5ce6441e3 refactor: Simplify service initialization in LoraRoutes and RecipeRoutes, and adjust logging level in ServiceRegistry 2025-04-12 09:01:09 +08:00
Will Miao
a8d21fb1d6 refactor: Remove unused service imports and add new route for scanning LoRA files 2025-04-12 07:49:11 +08:00
Will Miao
9277d8d8f8 refactor: Disable file monitoring functionality with ENABLE_FILE_MONITORING flag 2025-04-12 06:47:47 +08:00
Will Miao
0618541527 checkpoint 2025-04-11 20:22:12 +08:00
Will Miao
1db49a4dd4 refactor: Enhance checkpoint download functionality with new modal and manager integration 2025-04-11 18:25:37 +08:00
Will Miao
3df96034a1 refactor: Consolidate model handling functions into baseModelApi for better code reuse and organization 2025-04-11 14:35:56 +08:00
Will Miao
e991dc061d refactor: Implement common endpoint handlers for model management in ModelRouteUtils and update routes in CheckpointsRoutes 2025-04-11 12:06:05 +08:00
Will Miao
56670066c7 refactor: Optimize preview image handling by converting to webp format and improving error logging 2025-04-11 11:17:49 +08:00
Will Miao
31d27ff3fa refactor: Extract model-related utility functions into ModelRouteUtils for better code organization 2025-04-11 10:54:19 +08:00
Will Miao
297ff0dd25 refactor: Improve download handling for previews and optimize image conversion in DownloadManager 2025-04-11 09:00:58 +08:00
Will Miao
b0a5b48fb2 refactor: Enhance preview file handling and add update_preview_in_cache method for ModelScanner 2025-04-11 08:43:21 +08:00
Will Miao
ac244e6ad9 refactor: Replace hardcoded image width with CARD_PREVIEW_WIDTH constant for consistency 2025-04-11 08:19:19 +08:00
Will Miao
7393e92b21 refactor: Consolidate preview file extensions into constants for improved maintainability 2025-04-11 06:19:15 +08:00
Will Miao
86810d9f03 refactor: Remove move_model method from LoraScanner class to streamline code 2025-04-11 06:05:19 +08:00
Will Miao
18aa8d11ad refactor: Remove showToast call from clearCustomFilter method in LorasControls 2025-04-11 05:59:32 +08:00
Will Miao
fafec56f09 refactor: Rename update_single_lora_cache to update_single_model_cache for consistency 2025-04-11 05:52:56 +08:00
Will Miao
129ca9da81 feat: Implement checkpoint modal functionality with metadata editing, showcase display, and utility functions
- Added ModelMetadata.js for handling model metadata editing, including model name, base model, and file name.
- Introduced ShowcaseView.js to manage the display of images and videos in the checkpoint modal, including NSFW filtering and lazy loading.
- Created index.js as the main entry point for the checkpoint modal, integrating various components and functionalities.
- Developed utils.js for utility functions related to file size formatting and tag rendering.
- Enhanced user experience with editable fields, toast notifications, and improved showcase scrolling.
2025-04-10 22:59:09 +08:00
Will Miao
cbfb9ac87c Enhance CheckpointModal: Implement detailed checkpoint display, editable fields, and showcase functionality 2025-04-10 22:25:40 +08:00
Will Miao
42309edef4 Refactor visibility toggle: Remove toggleApiKeyVisibility function and update related button in modals 2025-04-10 21:43:56 +08:00
Will Miao
559e57ca46 Enhance CheckpointCard: Implement NSFW content handling, toggle blur functionality, and improve video autoplay behavior 2025-04-10 21:28:34 +08:00
Will Miao
311bf1f157 Add support for '.gguf' file extension in CheckpointScanner 2025-04-10 21:15:12 +08:00
Will Miao
131c3cc324 Add Civitai metadata fetching functionality for checkpoints
- Implement fetchCivitai API method to retrieve metadata from Civitai.
- Enhance CheckpointsControls to include fetch from Civitai functionality.
- Update PageControls to register fetch from Civitai event listener for both LoRAs and Checkpoints.
2025-04-10 21:07:17 +08:00
Will Miao
152ec0da0d Refactor Checkpoints functionality: Integrate loadMoreCheckpoints API, remove CheckpointSearchManager, and enhance FilterManager for improved checkpoint loading and filtering. 2025-04-10 19:57:04 +08:00
Will Miao
ee04df40c3 Refactor controls and pagination for Checkpoints and LoRAs: Implement unified PageControls, enhance API integration, and improve event handling for better user experience. 2025-04-10 19:41:02 +08:00
Will Miao
252e90a633 Enhance Checkpoints Manager: Implement API integration for checkpoints, add filtering and sorting options, and improve UI components for better user experience 2025-04-10 16:04:08 +08:00
Will Miao
048d486fa6 Refactor cache initialization in LoraManager and RecipeScanner for improved background processing and error handling 2025-04-10 11:34:19 +08:00
Will Miao
8fdfb68741 checkpoint 2025-04-10 09:08:51 +08:00
Will Miao
64c9e4aeca Update version to 0.8.5 and add release notes for enhanced features and improvements 2025-04-09 11:41:38 +08:00
Will Miao
08b90e8767 Update toast messages to clarify settings update notifications 2025-04-09 11:29:02 +08:00
Will Miao
0206613f9e Update NSFW level filter to include 'R' rating for improved content moderation 2025-04-09 11:25:52 +08:00
Will Miao
ae0629628e Enhance settings modal with video autoplay on hover option and improve layout. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/92 2025-04-09 11:18:30 +08:00
Will Miao
785b2e7287 style: Add padding to recipe list to prevent item cutoff on hover 2025-04-08 13:51:00 +08:00
Will Miao
43e3d0552e style: Update filter indicator and button styles for improved UI consistency
feat: Add pulse animation to filter indicators in Lora and recipe management
refactor: Change filter-active button to a div for better semantic structure
2025-04-08 13:45:15 +08:00
Will Miao
801aa2e876 Enhance Lora and recipe integration with improved filtering and UI updates
- Added support for filtering LoRAs by hash in both API and UI components.
- Implemented session storage management for custom filter states when navigating between recipes and LoRAs.
- Introduced a new button in the recipe modal to view associated LoRAs, enhancing user navigation.
- Updated CSS styles for new UI elements, including a custom filter indicator and LoRA view button.
- Refactored existing JavaScript components to streamline the handling of filter parameters and improve maintainability.
2025-04-08 12:23:51 +08:00
Will Miao
bddc7a438d feat: Add Lora recipes retrieval and filtering functionality
- Implemented a new API endpoint to fetch recipes associated with a specific Lora by its hash.
- Enhanced the recipe scanning logic to support filtering by Lora hash and bypassing other filters.
- Added a new method to retrieve a recipe by its ID with formatted metadata.
- Created a new RecipeTab component to display recipes in the Lora modal.
- Introduced session storage utilities for managing custom filter states.
- Updated the UI to include a custom filter indicator and loading/error states for recipes.
- Refactored existing recipe management logic to accommodate new features and improve maintainability.
2025-04-07 21:53:39 +08:00
Will Miao
b8c78a68e7 refactor: remove unused recipe card CSS styles 2025-04-07 20:36:58 +08:00
Will Miao
49219f4447 feat: Refactor LoraModal into modular components
- Added ShowcaseView.js for rendering LoRA model showcase content with NSFW filtering and lazy loading.
- Introduced TriggerWords.js to manage trigger words, including editing, adding, and saving functionality.
- Created index.js as the main entry point for the LoraModal, integrating all components and functionalities.
- Implemented utils.js for utility functions such as file size formatting and tag rendering.
- Enhanced user experience with editable fields, tooltips, and improved event handling for trigger words and presets.
2025-04-07 15:36:13 +08:00
Will Miao
59b1abb719 Update version to 0.8.4 and add release notes for node layout improvements and bug fixes 2025-04-07 14:49:34 +08:00
Will Miao
3e2cfb552b Refactor image saving logic for batch processing and unique filename generation. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/79 2025-04-07 14:37:39 +08:00
Will Miao
779be1b8d0 Refactor loras_widget styles for improved layout consistency 2025-04-07 13:42:31 +08:00
Will Miao
faf74de238 Enhance model move functionality with detailed error handling and user feedback 2025-04-07 11:14:56 +08:00
Will Miao
50a51c2e79 Refactor Lora widget and dynamic module loading
- Updated lora_loader.js to dynamically import the appropriate loras widget based on ComfyUI version, enhancing compatibility and maintainability.
- Enhanced loras_widget.js with improved height management and styling for better user experience.
- Introduced utility functions in utils.js for version checking and dynamic imports, streamlining widget loading processes.
- Improved overall structure and readability of the code, ensuring better performance and easier future updates.
2025-04-07 09:02:36 +08:00
Will Miao
d31e641496 Add dynamic tags widget selection based on ComfyUI version
- Introduced a mechanism to dynamically import either the legacy or modern tags widget based on the ComfyUI frontend version.
- Updated the `addTagsWidget` function in both `tags_widget.js` and `legacy_tags_widget.js` to enhance tag rendering and widget height management.
- Improved styling and layout for tags, ensuring better alignment and responsiveness.
- Added a new serialization method to handle potential issues with ComfyUI's serialization process.
- Enhanced the overall user experience by providing a more modern and flexible tags widget implementation.
2025-04-07 08:42:20 +08:00
Will Miao
f2d36f5be9 Refactor DownloadManager and LoraFileHandler for improved file monitoring
- Simplified the path handling in DownloadManager by directly adding normalized paths to the ignore list.
- Updated LoraFileHandler to utilize a set for ignore paths, enhancing performance and clarity.
- Implemented debouncing for modified file events to prevent duplicate processing and improve efficiency.
- Enhanced the handling of file creation, modification, and deletion events for .safetensors files, ensuring accurate processing and logging.
- Adjusted cache operations to streamline the addition and removal of files based on real paths.
2025-04-06 22:27:55 +08:00
Will Miao
0b55f61fac Refactor LoraFileHandler to use real file paths for monitoring
- Updated the file monitoring logic to store and verify real file paths instead of mapped paths, ensuring accurate existence checks.
- Enhanced logging for error handling and processing actions, including detailed error messages with exception info.
- Adjusted cache operations to reflect the use of normalized paths for consistency in add/remove actions.
- Improved handling of ignore paths by removing successfully processed files from the ignore list.
2025-04-05 12:10:46 +08:00
pixelpaws
4156dcbafd Merge pull request #83 from willmiao/dev
Dev
2025-04-05 05:28:22 +08:00
Will Miao
36e6ac2362 Add CheckpointMetadata class for enhanced model metadata management
- Introduced a new CheckpointMetadata dataclass to encapsulate metadata for checkpoint models.
- Included fields for file details, model specifications, and additional attributes such as resolution and architecture.
- Implemented a __post_init__ method to initialize tags as an empty list if not provided, ensuring consistent data handling.
2025-04-05 05:16:52 +08:00
Will Miao
9613199152 Enhance SaveImage functionality with custom prompt support
- Added a new optional parameter `custom_prompt` to the SaveImage class methods to allow users to override the default prompt.
- Updated the `format_metadata` method to utilize the custom prompt if provided.
- Modified the `save_images` and `process_image` methods to accept and pass the custom prompt through the workflow processing.
2025-04-04 07:47:46 +08:00
pixelpaws
14328d7496 Merge pull request #77 from willmiao/dev
Add reconnect functionality for deleted LoRAs in recipe modal
2025-04-03 16:56:04 +08:00
Will Miao
6af12d1acc Add reconnect functionality for deleted LoRAs in recipe modal
- Introduced a new API endpoint to reconnect deleted LoRAs to local files.
- Updated RecipeModal to include UI elements for reconnecting LoRAs, including input fields and buttons.
- Enhanced CSS styles for deleted badges and reconnect containers to improve user experience.
- Implemented event handling for reconnect actions, including input validation and API calls.
- Updated recipe data handling to reflect changes after reconnecting LoRAs.
2025-04-03 16:55:19 +08:00
pixelpaws
9b44e49879 Merge pull request #75 from willmiao/dev
Enhance file monitoring for LoRA files
2025-04-03 11:10:29 +08:00
Will Miao
afee18f146 Enhance file monitoring for LoRA files
- Added a method to map symbolic links back to actual paths in the Config class.
- Improved file creation handling in LoraFileHandler to check for file size and existence before processing.
- Introduced handling for file modification events to update the ignore list and schedule updates.
- Increased debounce delay in _process_changes to allow for file downloads to complete.
- Enhanced action processing to prioritize 'add' actions and verify file existence before adding to cache.
2025-04-03 11:09:30 +08:00
Will Miao
f007369a66 Bump version to v0.8.3 2025-04-02 20:18:51 +08:00
pixelpaws
9a9c166dbe Merge pull request #74 from willmiao/dev
Dev
2025-04-02 20:15:11 +08:00
Will Miao
2f90e32dbf Delete unused files 2025-04-02 20:11:41 +08:00
Will Miao
26355ccb79 chore: remove .vscode from git 2025-04-02 20:09:58 +08:00
Will Miao
27ea3c0c8e chore: add .vscode to gitignore 2025-04-02 20:09:08 +08:00
Will Miao
5aa35b211a Update README and update_logs 2025-04-02 20:03:18 +08:00
Will Miao
92450385d2 Update README 2025-04-02 20:00:04 +08:00
Will Miao
8d15e23f3c Add markdown support for changelog in modal
- Introduced a simple markdown parser to convert markdown syntax in changelog items to HTML.
- Updated modal CSS to style markdown elements, enhancing the presentation of changelog items.
- Improved user experience by allowing formatted text in changelog, including bold, italic, code, and links.
2025-04-02 19:36:52 +08:00
Will Miao
73686d4146 Enhance modal and settings functionality with default LoRA root selection
- Updated modal styles for improved layout and added select control for default LoRA root.
- Modified DownloadManager, ImportManager, MoveManager, and SettingsManager to retrieve and set the default LoRA root from storage.
- Introduced asynchronous loading of LoRA roots in SettingsManager to dynamically populate the select options.
- Improved user experience by allowing users to set a default LoRA root for downloads, imports, and moves.
2025-04-02 17:37:16 +08:00
Will Miao
0499ca1300 Update process_node function to ignore type checking
- Added a type: ignore comment to the process_node function to suppress type checking errors.
- Removed the README.md file as it is no longer needed.
2025-04-02 17:02:11 +08:00
Will Miao
234c942f34 Refactor transform functions and update node mappers
- Moved and redefined transform functions for KSampler, EmptyLatentImage, CLIPTextEncode, and FluxGuidance to improve organization and maintainability.
- Updated NODE_MAPPERS to include new input tracking for clip_skip in KSampler and added new transform functions for LatentUpscale and CLIPSetLastLayer.
- Enhanced the transform_sampler_custom_advanced function to handle clip_skip extraction from model inputs.
2025-04-02 17:01:10 +08:00
Will Miao
aec218ba00 Enhance SaveImage class with filename formatting and multiple image support
- Updated the INPUT_TYPES to accept multiple images and modified the corresponding processing methods.
- Introduced a new format_filename method to handle dynamic filename generation using metadata patterns.
- Replaced save_workflow_json with embed_workflow for better clarity in saving workflow metadata.
- Improved directory handling and filename generation logic to ensure proper file saving.
2025-04-02 15:08:36 +08:00
Will Miao
b508f51fcf checkpoint 2025-04-02 14:13:53 +08:00
Will Miao
435628ea59 Refactor WorkflowParser by removing unused methods 2025-04-02 14:13:24 +08:00
Will Miao
4933dbfb87 Refactor ExifUtils by removing unused methods and imports
- Removed the extract_user_comment and update_user_comment methods to streamline the ExifUtils class.
- Cleaned up unnecessary imports and reduced code complexity, focusing on essential functionality for image metadata extraction.
2025-04-02 11:14:05 +08:00
Will Miao
5a93c40b79 Refactor logging levels and improve mapper registration
- Changed warning logs to debug logs in CivitaiClient and RecipeScanner for better log granularity.
- Updated the mapper registration function name for clarity and adjusted related logging messages.
- Enhanced extension loading process to automatically register mappers from NODE_MAPPERS_EXT, improving modularity and maintainability.
2025-04-02 10:29:31 +08:00
Will Miao
a8ec5af037 checkpoint 2025-04-02 06:05:24 +08:00
Will Miao
27db60ce68 checkpoint 2025-04-01 19:17:43 +08:00
Will Miao
195866b00d Implement KJNodes extension with new mappers and transform functions
- Added KJNodes mappers for JoinStrings, StringConstantMultiline, and EmptyLatentImagePresets.
- Introduced transform functions to handle string joining, string constants, and dimension extraction with optional inversion.
- Registered new mappers and logged successful registration for better traceability.
2025-04-01 16:22:57 +08:00
Will Miao
60575b6546 checkpoint 2025-04-01 08:38:49 +08:00
pixelpaws
350b81d678 Merge pull request #64 from richardhristov/main
Remember sort by name/date in LoRAs page
2025-03-31 20:16:29 +08:00
Will Miao
cc95314dae Bump version to v0.8.2 2025-03-30 20:53:22 +08:00
Will Miao
3f97087abb Update unauthorized access error message 2025-03-30 20:15:50 +08:00
Will Miao
f04af2de21 Add Civitai model retrieval and missing LoRAs download functionality
- Introduced new API endpoints for fetching Civitai model details by model version ID or hash.
- Enhanced the download manager to support downloading LoRAs using model version ID or hash, improving flexibility.
- Updated RecipeModal to handle missing LoRAs, allowing users to download them directly from the recipe interface.
- Added tooltip and click functionality for missing LoRAs status, enhancing user experience.
- Improved error handling for missing LoRAs download process, providing clearer feedback to users.
2025-03-30 19:45:03 +08:00
Richard Hristov
e7871bf843 Remember sort by name/date in LoRAs page 2025-03-29 17:11:53 +02:00
Will Miao
8e3308039a Refactor Lora handling in RecipeRoutes and enhance RecipeManager
- Updated Lora filtering logic in RecipeRoutes to skip deleted LoRAs without exclusion checks, improving performance and clarity.
- Enhanced condition for fetching cached LoRAs to ensure valid data is processed.
- Added toggleApiKeyVisibility function to RecipeManager, improving API key management in the UI.
2025-03-29 19:11:13 +08:00
Will Miao
b65350b7cb Add update functionality for recipe metadata in RecipeRoutes and RecipeModal
- Introduced a new API endpoint to update recipe metadata, allowing users to modify recipe titles and tags.
- Enhanced RecipeModal to support inline editing of recipe titles and tags, improving user interaction.
- Updated RecipeCard to reflect changes in recipe metadata, ensuring consistency across the application.
- Improved error handling for metadata updates to provide clearer feedback to users.
2025-03-29 18:46:19 +08:00
Will Miao
069ebce895 Add recipe syntax endpoint and update RecipeCard and RecipeModal for syntax fetching
- Introduced a new API endpoint to retrieve recipe syntax for LoRAs, allowing for better integration with the frontend.
- Updated RecipeCard to fetch recipe syntax from the backend instead of generating it locally.
- Modified RecipeModal to store the recipe ID and fetch syntax when the copy button is clicked, improving user experience.
- Enhanced error handling for fetching recipe syntax to provide clearer feedback to users.
2025-03-29 15:38:49 +08:00
Will Miao
63aa4e188e Add rename functionality for LoRA files and enhance UI for editing file names
- Introduced a new API endpoint to rename LoRA files, including validation and error handling for file paths and names.
- Updated the RecipeScanner to reflect changes in LoRA filenames across recipe files and cache.
- Enhanced the LoraModal UI to allow inline editing of file names with improved user interaction and validation.
- Added CSS styles for the editing interface to improve visual feedback during file name editing.
2025-03-29 09:25:41 +08:00
Will Miao
c31c9c16cf Enhance LoraScanner and file_utils for improved metadata handling
- Updated LoraScanner to first attempt to create metadata from .civitai.info files, improving metadata extraction from existing files.
- Added error handling for reading .civitai.info files and fallback to generating metadata using get_file_info if necessary.
- Refactored file_utils to expose find_preview_file function and added logic to utilize SHA256 from existing .json files to avoid recalculation.
- Improved overall robustness of metadata loading and preview file retrieval processes.
2025-03-28 16:27:59 +08:00
Will Miao
5a8a402fdc Enhance LoraRoutes and templates for improved cache initialization handling
- Updated LoraRoutes to better check cache initialization status and handle loading states.
- Added logging for successful cache loading and error handling for cache retrieval failures.
- Enhanced base.html and loras.html templates to display a loading spinner and initialization notice during cache setup.
- Improved user experience by ensuring the loading notice is displayed appropriately based on initialization state.
2025-03-28 15:04:35 +08:00
Will Miao
85c3e33343 Update version to 0.8.1 and add release notes for new features and improvements
- Bump version from 0.8.0 to 0.8.1 in pyproject.toml.
- Document new features in README.md, including base model correction, LoRA loader flexibility, expanded recipe support, enhanced showcase images, and various UI improvements and bug fixes.
2025-03-28 04:15:54 +08:00
Will Miao
1420ab31a2 Enhance CivitaiClient error handling for unauthorized access
- Updated handling of 401 unauthorized responses to differentiate between API key issues and early access restrictions.
- Improved logging for unauthorized access attempts.
- Refactored condition to check for early access restrictions based on response headers.
- Adjusted logic in DownloadManager to check for early access using a more concise method.
2025-03-28 04:11:08 +08:00
Will Miao
fd1435537f Add ImageSaverMetadataParser for ComfyUI Image Saver plugin metadata handling
- Introduced ImageSaverMetadataParser class to parse metadata from the Image Saver plugin format.
- Implemented methods to extract prompts, negative prompts, and LoRA information, including weights and hashes.
- Enhanced error handling and logging for metadata parsing failures.
- Updated RecipeParserFactory to include ImageSaverMetadataParser for relevant user comments.
2025-03-28 03:27:35 +08:00
Will Miao
4e0473ce11 Fix redownloading loras issue 2025-03-28 02:53:30 +08:00
Will Miao
450592b0d4 Implement Civitai data population methods for LoRA and checkpoint entries
- Added `populate_lora_from_civitai` and `populate_checkpoint_from_civitai` methods to enhance the extraction of model information from Civitai API responses.
- These methods populate LoRA and checkpoint entries with relevant data such as model name, version, thumbnail URL, base model, download URL, and file details.
- Improved error handling and logging for scenarios where models are not found or data retrieval fails.
- Refactored existing code to utilize the new methods, streamlining the process of fetching and updating LoRA and checkpoint metadata.
2025-03-28 02:16:53 +08:00
Will Miao
7cae0ee169 Enhance LoraModal to include image metadata panel
- Added a new image metadata panel to display generation parameters and prompts for images and videos.
- Implemented styles for the metadata panel in lora-modal.css, ensuring it is responsive and visually integrated.
- Introduced functionality to copy prompts to the clipboard and handle metadata interactions within the modal.
- Updated media rendering logic in LoraModal.js to incorporate metadata display and improve user experience.
2025-03-27 20:09:48 +08:00
Will Miao
ecd0e05f79 Add MetaFormatParser for Lora_N Model hash format metadata handling
- Introduced MetaFormatParser class to parse metadata from images with Lora_N Model hash format.
- Implemented methods to validate metadata structure, extract prompts, negative prompts, and LoRA information.
- Enhanced error handling and logging for metadata parsing failures.
- Updated RecipeParserFactory to include MetaFormatParser for relevant user comments.
2025-03-27 17:28:11 +08:00
Will Miao
6e3b4178ac Enhance LoraStacker to return active LoRAs in stack_loras method
- Updated RETURN_TYPES and RETURN_NAMES to include active LoRAs.
- Introduced active_loras list to track active LoRAs and their strengths.
- Formatted active_loras for return as a string in the format <lora:lora_name:strength>.
2025-03-27 16:10:50 +08:00
Will Miao
ba18cbabfd Add ComfyMetadataParser for Civitai ComfyUI metadata handling
- Introduced ComfyMetadataParser class to parse metadata from Civitai ComfyUI JSON format.
- Implemented methods to validate metadata structure, extract LoRA and checkpoint information, and retrieve additional model details from Civitai.
- Enhanced error handling and logging for metadata parsing failures.
- Updated RecipeParserFactory to prioritize ComfyMetadataParser for valid JSON inputs.
2025-03-27 15:43:58 +08:00
Will Miao
dec757c23b Refactor image metadata handling in RecipeRoutes and ExifUtils
- Replaced the download function for images from Twitter to Civitai in recipe_routes.py.
- Updated metadata extraction from user comments to a more comprehensive image metadata extraction method in ExifUtils.
- Enhanced the appending of recipe metadata to utilize the new metadata extraction method.
- Added a new utility function to download images from Civitai.
2025-03-27 14:56:37 +08:00
Will Miao
0459710c9b Made CLIP input optional in LoRA Loader, enabling compatibility with Hunyuan workflows 2025-03-26 21:50:26 +08:00
Will Miao
83582ef8a3 Refactor RecipeScanner to remove custom async timeout and streamline cache initialization
- Removed the custom async_timeout function and replaced it with direct usage of the initialization lock.
- Simplified the cache initialization process by eliminating the dependency on the lora scanner.
- Enhanced error handling during cache initialization to ensure a fallback to an empty cache on failure.
2025-03-26 18:56:18 +08:00
Will Miao
0dc396e148 Enhance RecipeModal to support video previews
- Updated RecipeModal.js to dynamically handle video and image previews based on the file type.
- Modified recipe-modal.css to ensure proper styling for both images and videos.
- Adjusted recipe_modal.html to accommodate the new media handling structure.
2025-03-26 16:39:53 +08:00
pixelpaws
86958e1420 Merge pull request #51 from AlUlkesh/main
Python < 3.11 backward compatibility for timeout.
2025-03-26 10:47:23 +08:00
Will Miao
c5b8e629fb Enhance save functionality in LoraModal for base model editing
- Added a check to prevent saving if the base model value has not changed.
- Stored the original value during editing to compare with the new selection.
- Updated the saveBaseModel function to accept the original value for comparison.
2025-03-26 07:05:32 +08:00
Will Miao
b0a495b4f6 Add base model editing functionality to LoraModal
- Introduced new styles for base model display and editing in lora-modal.css.
- Enhanced LoraModal.js to support editing of the base model with a dropdown selector.
- Implemented save functionality for the updated base model, including UI interactions for editing and saving changes.
2025-03-26 06:49:33 +08:00
Will Miao
7d2809467b Update tutorial video link 2025-03-25 14:10:13 +08:00
AlUlkesh
509e513f3a Python < 3.11 backward compatibility for timeout. 2025-03-24 14:16:46 +01:00
153 changed files with 22505 additions and 6050 deletions

3
.gitignore vendored
View File

@@ -1,4 +1,5 @@
__pycache__/ __pycache__/
settings.json settings.json
output/* output/*
py/run_test.py py/run_test.py
.vscode/

116
README.md
View File

@@ -6,7 +6,7 @@
[![Release](https://img.shields.io/github/v/release/willmiao/ComfyUI-Lora-Manager?include_prereleases&color=blue&logo=github)](https://github.com/willmiao/ComfyUI-Lora-Manager/releases) [![Release](https://img.shields.io/github/v/release/willmiao/ComfyUI-Lora-Manager?include_prereleases&color=blue&logo=github)](https://github.com/willmiao/ComfyUI-Lora-Manager/releases)
[![Release Date](https://img.shields.io/github/release-date/willmiao/ComfyUI-Lora-Manager?color=green&logo=github)](https://github.com/willmiao/ComfyUI-Lora-Manager/releases) [![Release Date](https://img.shields.io/github/release-date/willmiao/ComfyUI-Lora-Manager?color=green&logo=github)](https://github.com/willmiao/ComfyUI-Lora-Manager/releases)
A comprehensive toolset that streamlines organizing, downloading, and applying LoRA models in ComfyUI. With powerful features like recipe management and one-click workflow integration, working with LoRAs becomes faster, smoother, and significantly easier. Access the interface at: `http://localhost:8188/loras` A comprehensive toolset that streamlines organizing, downloading, and applying LoRA models in ComfyUI. With powerful features like recipe management, checkpoint organization, and one-click workflow integration, working with models becomes faster, smoother, and significantly easier. Access the interface at: `http://localhost:8188/loras`
![Interface Preview](https://github.com/willmiao/ComfyUI-Lora-Manager/blob/main/static/images/screenshot.png) ![Interface Preview](https://github.com/willmiao/ComfyUI-Lora-Manager/blob/main/static/images/screenshot.png)
@@ -14,11 +14,63 @@ A comprehensive toolset that streamlines organizing, downloading, and applying L
Watch this quick tutorial to learn how to use the new one-click LoRA integration feature: Watch this quick tutorial to learn how to use the new one-click LoRA integration feature:
[![One-Click LoRA Integration Tutorial](https://img.youtube.com/vi/qS95OjX3e70/0.jpg)](https://youtu.be/qS95OjX3e70) [![One-Click LoRA Integration Tutorial](https://img.youtube.com/vi/qS95OjX3e70/0.jpg)](https://youtu.be/qS95OjX3e70)
[![LoRA Manager v0.8.0 - New Recipe Feature & Bulk Operations](https://img.youtube.com/vi/noN7f_ER7yo/0.jpg)](https://youtu.be/noN7f_ER7yo)
--- ---
## Release Notes ## Release Notes
### v0.8.8
* **Real-time TriggerWord Updates** - Enhanced TriggerWord Toggle node to instantly update when connected Lora Loader or Lora Stacker nodes change, without requiring workflow execution
* **Optimized Metadata Recovery** - Improved utilization of existing .civitai.info files for faster initialization and preservation of metadata from models deleted from CivitAI
* **Migration Acceleration** - Further speed improvements for users transitioning from A1111/Forge environments
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
### v0.8.7
* **Enhanced Context Menu** - Added comprehensive context menu functionality to Recipes and Checkpoints pages for improved workflow
* **Interactive LoRA Strength Control** - Implemented drag functionality in LoRA Loader for intuitive strength adjustment
* **Metadata Collector Overhaul** - Rebuilt metadata collection system with optimized architecture for better performance
* **Improved Save Image Node** - Enhanced metadata capture and image saving performance with the new metadata collector
* **Streamlined Recipe Saving** - Optimized Save Recipe functionality to work independently without requiring Preview Image nodes
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
### v0.8.6 Major Update
* **Checkpoint Management** - Added comprehensive management for model checkpoints including scanning, searching, filtering, and deletion
* **Enhanced Metadata Support** - New capabilities for retrieving and managing checkpoint metadata with improved operations
* **Improved Initial Loading** - Optimized cache initialization with visual progress indicators for better user experience
### v0.8.5
* **Enhanced LoRA & Recipe Connectivity** - Added Recipes tab in LoRA details to see all recipes using a specific LoRA
* **Improved Navigation** - New shortcuts to jump between related LoRAs and Recipes with one-click navigation
* **Video Preview Controls** - Added "Autoplay Videos on Hover" setting to optimize performance and reduce resource usage
* **UI Experience Refinements** - Smoother transitions between related content pages
### v0.8.4
* **Node Layout Improvements** - Fixed layout issues with LoRA Loader and Trigger Words Toggle nodes in newer ComfyUI frontend versions
* **Recipe LoRA Reconnection** - Added ability to reconnect deleted LoRAs in recipes by clicking the "deleted" badge in recipe details
* **Bug Fixes & Stability** - Resolved various issues for improved reliability
### v0.8.3
* **Enhanced Workflow Parser** - Rebuilt workflow analysis engine with improved support for ComfyUI core nodes and easier extensibility
* **Improved Recipe System** - Refined the experimental Save Recipe functionality with better workflow integration
* **New Save Image Node** - Added experimental node with metadata support for perfect CivitAI compatibility
* Supports dynamic filename prefixes with variables [1](https://github.com/nkchocoai/ComfyUI-SaveImageWithMetaData?tab=readme-ov-file#filename_prefix)
* **Default LoRA Root Setting** - Added configuration option for setting your preferred LoRA directory
### v0.8.2
* **Faster Initialization for Forge Users** - Improved first-run efficiency by utilizing existing `.json` and `.civitai.info` files from Forges CivitAI helper extension, making migration smoother.
* **LoRA Filename Editing** - Added support for renaming LoRA files directly within LoRA Manager.
* **Recipe Editing** - Users can now edit recipe names and tags.
* **Retain Deleted LoRAs in Recipes** - Deleted LoRAs will remain listed in recipes, allowing future functionality to reconnect them once re-obtained.
* **Download Missing LoRAs from Recipes** - Easily fetch missing LoRAs associated with a recipe.
### v0.8.1
* **Base Model Correction** - Added support for modifying base model associations to fix incorrect metadata for non-CivitAI LoRAs
* **LoRA Loader Flexibility** - Made CLIP input optional for model-only workflows like Hunyuan video generation
* **Expanded Recipe Support** - Added compatibility with 3 additional recipe metadata formats
* **Enhanced Showcase Images** - Generation parameters now displayed alongside LoRA preview images
* **UI Improvements & Bug Fixes** - Various interface refinements and stability enhancements
### v0.8.0 ### v0.8.0
* **Introduced LoRA Recipes** - Create, import, save, and share your favorite LoRA combinations * **Introduced LoRA Recipes** - Create, import, save, and share your favorite LoRA combinations
* **Recipe Management System** - Easily browse, search, and organize your LoRA recipes * **Recipe Management System** - Easily browse, search, and organize your LoRA recipes
@@ -27,52 +79,6 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
* **Enhanced UI & UX** - Improved interface design and user experience * **Enhanced UI & UX** - Improved interface design and user experience
* **Bug Fixes & Stability** - Resolved various issues and enhanced overall performance * **Bug Fixes & Stability** - Resolved various issues and enhanced overall performance
### v0.7.37
* Added NSFW content control settings (blur mature content and SFW-only filter)
* Implemented intelligent blur effects for previews and showcase media
* Added manual content rating option through context menu
* Enhanced user experience with configurable content visibility
* Fixed various bugs and improved stability
### v0.7.36
* Enhanced LoRA details view with model descriptions and tags display
* Added tag filtering system for improved model discovery
* Implemented editable trigger words functionality
* Improved TriggerWord Toggle node with new group mode option for granular control
* Added new Lora Stacker node with cross-compatibility support (works with efficiency nodes, ComfyRoll, easy-use, etc.)
* Fixed several bugs
### v0.7.35-beta
* Added base model filtering
* Implemented bulk operations (copy syntax, move multiple LoRAs)
* Added ability to edit LoRA model names in details view
* Added update checker with notification system
* Added support modal for user feedback and community links
### v0.7.33
* Enhanced LoRA Loader node with visual strength adjustment widgets
* Added toggle switches for LoRA enable/disable
* Implemented image tooltips for LoRA preview
* Added TriggerWord Toggle node with visual word selection
* Fixed various bugs and improved stability
### v0.7.3
* Added "Lora Loader (LoraManager)" custom node for workflows
* Implemented one-click LoRA integration
* Added direct copying of LoRA syntax from manager interface
* Added automatic preset strength value application
* Added automatic trigger word loading
### v0.7.0
* Added direct CivitAI integration for downloading LoRAs
* Implemented version selection for model downloads
* Added target folder selection for downloads
* Added context menu with quick actions
* Added force refresh for CivitAI data
* Implemented LoRA movement between folders
* Added personal usage tips and notes for LoRAs
* Improved performance for details window
[View Update History](./update_logs.md) [View Update History](./update_logs.md)
--- ---
@@ -105,6 +111,12 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
- Trigger words at a glance - Trigger words at a glance
- One-click workflow integration with preset values - One-click workflow integration with preset values
- 🔄 **Checkpoint Management**
- Scan and organize checkpoint models
- Filter and search your collection
- View and edit metadata
- Clean up and manage disk space
- 🧩 **LoRA Recipes** - 🧩 **LoRA Recipes**
- Save and share favorite LoRA combinations - Save and share favorite LoRA combinations
- Preserve generation parameters for future reference - Preserve generation parameters for future reference
@@ -116,6 +128,7 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
- Context menu for quick actions - Context menu for quick actions
- Custom notes and usage tips - Custom notes and usage tips
- Multi-folder support - Multi-folder support
- Visual progress indicators during initialization
--- ---
@@ -156,6 +169,15 @@ pip install requirements.txt
--- ---
## Credits
This project has been inspired by and benefited from other excellent ComfyUI extensions:
- [ComfyUI-SaveImageWithMetaData](https://github.com/Comfy-Community/ComfyUI-SaveImageWithMetaData) - For the image metadata functionality
- [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) - For the lora loader functionality
---
## Contributing ## Contributing
If you have suggestions, bug reports, or improvements, feel free to open an issue or contribute directly to the codebase. Pull requests are always welcome! If you have suggestions, bug reports, or improvements, feel free to open an issue or contribute directly to the codebase. Pull requests are always welcome!

View File

@@ -2,17 +2,24 @@ from .py.lora_manager import LoraManager
from .py.nodes.lora_loader import LoraManagerLoader from .py.nodes.lora_loader import LoraManagerLoader
from .py.nodes.trigger_word_toggle import TriggerWordToggle from .py.nodes.trigger_word_toggle import TriggerWordToggle
from .py.nodes.lora_stacker import LoraStacker from .py.nodes.lora_stacker import LoraStacker
# from .py.nodes.save_image import SaveImage from .py.nodes.save_image import SaveImage
from .py.nodes.debug_metadata import DebugMetadata
# Import metadata collector to install hooks on startup
from .py.metadata_collector import init as init_metadata_collector
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
LoraManagerLoader.NAME: LoraManagerLoader, LoraManagerLoader.NAME: LoraManagerLoader,
TriggerWordToggle.NAME: TriggerWordToggle, TriggerWordToggle.NAME: TriggerWordToggle,
LoraStacker.NAME: LoraStacker, LoraStacker.NAME: LoraStacker,
# SaveImage.NAME: SaveImage SaveImage.NAME: SaveImage,
DebugMetadata.NAME: DebugMetadata
} }
WEB_DIRECTORY = "./web/comfyui" WEB_DIRECTORY = "./web/comfyui"
# Initialize metadata collector
init_metadata_collector()
# Register routes on import # Register routes on import
LoraManager.add_routes() LoraManager.add_routes()
__all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY'] __all__ = ['NODE_CLASS_MAPPINGS', 'WEB_DIRECTORY']

View File

@@ -17,6 +17,7 @@ class Config:
# 静态路由映射字典, target to route mapping # 静态路由映射字典, target to route mapping
self._route_mappings = {} self._route_mappings = {}
self.loras_roots = self._init_lora_paths() self.loras_roots = self._init_lora_paths()
self.checkpoints_roots = self._init_checkpoint_paths()
self.temp_directory = folder_paths.get_temp_directory() self.temp_directory = folder_paths.get_temp_directory()
# 在初始化时扫描符号链接 # 在初始化时扫描符号链接
self._scan_symbolic_links() self._scan_symbolic_links()
@@ -39,9 +40,12 @@ class Config:
return False return False
def _scan_symbolic_links(self): def _scan_symbolic_links(self):
"""扫描所有 LoRA 根目录中的符号链接""" """扫描所有 LoRA 和 Checkpoint 根目录中的符号链接"""
for root in self.loras_roots: for root in self.loras_roots:
self._scan_directory_links(root) self._scan_directory_links(root)
for root in self.checkpoints_roots:
self._scan_directory_links(root)
def _scan_directory_links(self, root: str): def _scan_directory_links(self, root: str):
"""递归扫描目录中的符号链接""" """递归扫描目录中的符号链接"""
@@ -73,7 +77,7 @@ class Config:
"""添加静态路由映射""" """添加静态路由映射"""
normalized_path = os.path.normpath(path).replace(os.sep, '/') normalized_path = os.path.normpath(path).replace(os.sep, '/')
self._route_mappings[normalized_path] = route self._route_mappings[normalized_path] = route
logger.info(f"Added route mapping: {normalized_path} -> {route}") # logger.info(f"Added route mapping: {normalized_path} -> {route}")
def map_path_to_link(self, path: str) -> str: def map_path_to_link(self, path: str) -> str:
"""将目标路径映射回符号链接路径""" """将目标路径映射回符号链接路径"""
@@ -85,18 +89,66 @@ class Config:
mapped_path = normalized_path.replace(target_path, link_path, 1) mapped_path = normalized_path.replace(target_path, link_path, 1)
return mapped_path return mapped_path
return path return path
def map_link_to_path(self, link_path: str) -> str:
"""将符号链接路径映射回实际路径"""
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
# 检查路径是否包含在任何映射的目标路径中
for target_path, link_path in self._path_mappings.items():
if normalized_link.startswith(target_path):
# 如果路径以目标路径开头,则替换为实际路径
mapped_path = normalized_link.replace(target_path, link_path, 1)
return mapped_path
return link_path
def _init_lora_paths(self) -> List[str]: def _init_lora_paths(self) -> List[str]:
"""Initialize and validate LoRA paths from ComfyUI settings""" """Initialize and validate LoRA paths from ComfyUI settings"""
paths = sorted(set(path.replace(os.sep, "/") raw_paths = folder_paths.get_folder_paths("loras")
for path in folder_paths.get_folder_paths("loras")
if os.path.exists(path)), key=lambda p: p.lower())
print("Found LoRA roots:", "\n - " + "\n - ".join(paths))
if not paths: # Normalize and resolve symlinks, store mapping from resolved -> original
path_map = {}
for path in raw_paths:
if os.path.exists(path):
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
path_map[real_path] = path_map.get(real_path, path) # preserve first seen
# Now sort and use only the deduplicated real paths
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
print("Found LoRA roots:", "\n - " + "\n - ".join(unique_paths))
if not unique_paths:
raise ValueError("No valid loras folders found in ComfyUI configuration") raise ValueError("No valid loras folders found in ComfyUI configuration")
# 初始化路径映射 for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
if real_path != original_path:
self.add_path_mapping(original_path, real_path)
return unique_paths
def _init_checkpoint_paths(self) -> List[str]:
"""Initialize and validate checkpoint paths from ComfyUI settings"""
# Get checkpoint paths from folder_paths
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffusion_paths = folder_paths.get_folder_paths("diffusers")
unet_paths = folder_paths.get_folder_paths("unet")
# Combine all checkpoint-related paths
all_paths = checkpoint_paths + diffusion_paths + unet_paths
# Filter and normalize paths
paths = sorted(set(path.replace(os.sep, "/")
for path in all_paths
if os.path.exists(path)), key=lambda p: p.lower())
print("Found checkpoint roots:", paths)
if not paths:
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
return []
# 初始化路径映射,与 LoRA 路径处理方式相同
for path in paths: for path in paths:
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
if real_path != path: if real_path != path:

View File

@@ -1,16 +1,13 @@
import asyncio import asyncio
import os
from server import PromptServer # type: ignore from server import PromptServer # type: ignore
from .config import config from .config import config
from .routes.lora_routes import LoraRoutes from .routes.lora_routes import LoraRoutes
from .routes.api_routes import ApiRoutes from .routes.api_routes import ApiRoutes
from .routes.recipe_routes import RecipeRoutes from .routes.recipe_routes import RecipeRoutes
from .routes.checkpoints_routes import CheckpointsRoutes from .routes.checkpoints_routes import CheckpointsRoutes
from .services.lora_scanner import LoraScanner from .routes.update_routes import UpdateRoutes
from .services.recipe_scanner import RecipeScanner from .routes.usage_stats_routes import UsageStatsRoutes
from .services.file_monitor import LoraFileMonitor from .services.service_registry import ServiceRegistry
from .services.lora_cache import LoraCache
from .services.recipe_cache import RecipeCache
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -23,7 +20,7 @@ class LoraManager:
"""Initialize and register all routes""" """Initialize and register all routes"""
app = PromptServer.instance.app app = PromptServer.instance.app
added_targets = set() # 用于跟踪已添加的目标路径 added_targets = set() # Track already added target paths
# Add static routes for each lora root # Add static routes for each lora root
for idx, root in enumerate(config.loras_roots, start=1): for idx, root in enumerate(config.loras_roots, start=1):
@@ -35,102 +32,143 @@ class LoraManager:
if link == root: if link == root:
real_root = target real_root = target
break break
# 为原始路径添加静态路由 # Add static route for original path
app.router.add_static(preview_path, real_root) app.router.add_static(preview_path, real_root)
logger.info(f"Added static route {preview_path} -> {real_root}") logger.info(f"Added static route {preview_path} -> {real_root}")
# 记录路由映射 # Record route mapping
config.add_route_mapping(real_root, preview_path) config.add_route_mapping(real_root, preview_path)
added_targets.add(real_root) added_targets.add(real_root)
# 为符号链接的目标路径添加额外的静态路由 # Add static routes for each checkpoint root
link_idx = 1 for idx, root in enumerate(config.checkpoints_roots, start=1):
preview_path = f'/checkpoints_static/root{idx}/preview'
real_root = root
if root in config._path_mappings.values():
for target, link in config._path_mappings.items():
if link == root:
real_root = target
break
# Add static route for original path
app.router.add_static(preview_path, real_root)
logger.info(f"Added static route {preview_path} -> {real_root}")
# Record route mapping
config.add_route_mapping(real_root, preview_path)
added_targets.add(real_root)
# Add static routes for symlink target paths
link_idx = {
'lora': 1,
'checkpoint': 1
}
for target_path, link_path in config._path_mappings.items(): for target_path, link_path in config._path_mappings.items():
if target_path not in added_targets: if target_path not in added_targets:
route_path = f'/loras_static/link_{link_idx}/preview' # Determine if this is a checkpoint or lora link based on path
is_checkpoint = any(cp_root in link_path for cp_root in config.checkpoints_roots)
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.checkpoints_roots)
if is_checkpoint:
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
link_idx["checkpoint"] += 1
else:
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
link_idx["lora"] += 1
app.router.add_static(route_path, target_path) app.router.add_static(route_path, target_path)
logger.info(f"Added static route for link target {route_path} -> {target_path}") logger.info(f"Added static route for link target {route_path} -> {target_path}")
config.add_route_mapping(target_path, route_path) config.add_route_mapping(target_path, route_path)
added_targets.add(target_path) added_targets.add(target_path)
link_idx += 1
# Add static route for plugin assets # Add static route for plugin assets
app.router.add_static('/loras_static', config.static_path) app.router.add_static('/loras_static', config.static_path)
# Setup feature routes # Setup feature routes
routes = LoraRoutes() lora_routes = LoraRoutes()
checkpoints_routes = CheckpointsRoutes() checkpoints_routes = CheckpointsRoutes()
# Setup file monitoring # Initialize routes
monitor = LoraFileMonitor(routes.scanner, config.loras_roots) lora_routes.setup_routes(app)
monitor.start()
routes.setup_routes(app)
checkpoints_routes.setup_routes(app) checkpoints_routes.setup_routes(app)
ApiRoutes.setup_routes(app, monitor) ApiRoutes.setup_routes(app)
RecipeRoutes.setup_routes(app) RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app)
UsageStatsRoutes.setup_routes(app) # Register usage stats routes
# Store monitor in app for cleanup # Schedule service initialization
app['lora_monitor'] = monitor app.on_startup.append(lambda app: cls._initialize_services())
# Schedule cache initialization using the application's startup handler
app.on_startup.append(lambda app: cls._schedule_cache_init(routes.scanner, routes.recipe_scanner))
# Add cleanup # Add cleanup
app.on_shutdown.append(cls._cleanup) app.on_shutdown.append(cls._cleanup)
app.on_shutdown.append(ApiRoutes.cleanup) app.on_shutdown.append(ApiRoutes.cleanup)
@classmethod @classmethod
async def _schedule_cache_init(cls, scanner: LoraScanner, recipe_scanner: RecipeScanner): async def _initialize_services(cls):
"""Schedule cache initialization in the running event loop""" """Initialize all services using the ServiceRegistry"""
try: try:
# 创建低优先级的初始化任务 # Initialize CivitaiClient first to ensure it's ready for other services
lora_task = asyncio.create_task(cls._initialize_lora_cache(scanner), name='lora_cache_init') civitai_client = await ServiceRegistry.get_civitai_client()
# Schedule recipe cache initialization with a delay to let lora scanner initialize first # Get file monitors through ServiceRegistry
recipe_task = asyncio.create_task(cls._initialize_recipe_cache(recipe_scanner, delay=2), name='recipe_cache_init') lora_monitor = await ServiceRegistry.get_lora_monitor()
checkpoint_monitor = await ServiceRegistry.get_checkpoint_monitor()
# Start monitors
lora_monitor.start()
logger.debug("Lora monitor started")
# Make sure checkpoint monitor has paths before starting
await checkpoint_monitor.initialize_paths()
checkpoint_monitor.start()
logger.debug("Checkpoint monitor started")
# Register DownloadManager with ServiceRegistry
download_manager = await ServiceRegistry.get_download_manager()
# Initialize WebSocket manager
ws_manager = await ServiceRegistry.get_websocket_manager()
# Initialize scanners in background
lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
# Initialize recipe scanner if needed
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
# Create low-priority initialization tasks
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
logger.info("LoRA Manager: All services initialized and background tasks scheduled")
except Exception as e: except Exception as e:
logger.error(f"LoRA Manager: Error scheduling cache initialization: {e}") logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True)
@classmethod
async def _initialize_lora_cache(cls, scanner: LoraScanner):
"""Initialize lora cache in background"""
try:
# 设置初始缓存占位
scanner._cache = LoraCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# 分阶段加载缓存
await scanner.get_cached_data(force_refresh=True)
except Exception as e:
logger.error(f"LoRA Manager: Error initializing lora cache: {e}")
@classmethod
async def _initialize_recipe_cache(cls, scanner: RecipeScanner, delay: float = 2.0):
"""Initialize recipe cache in background with a delay"""
try:
# Wait for the specified delay to let lora scanner initialize first
await asyncio.sleep(delay)
# Set initial empty cache
scanner._cache = RecipeCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[]
)
# Force refresh to load the actual data
await scanner.get_cached_data(force_refresh=True)
except Exception as e:
logger.error(f"LoRA Manager: Error initializing recipe cache: {e}")
@classmethod @classmethod
async def _cleanup(cls, app): async def _cleanup(cls, app):
"""Cleanup resources""" """Cleanup resources using ServiceRegistry"""
if 'lora_monitor' in app: try:
app['lora_monitor'].stop() logger.info("LoRA Manager: Cleaning up services")
# Get monitors from ServiceRegistry
lora_monitor = await ServiceRegistry.get_service("lora_monitor")
if lora_monitor:
lora_monitor.stop()
logger.info("Stopped LoRA monitor")
checkpoint_monitor = await ServiceRegistry.get_service("checkpoint_monitor")
if checkpoint_monitor:
checkpoint_monitor.stop()
logger.info("Stopped checkpoint monitor")
# Close CivitaiClient gracefully
civitai_client = await ServiceRegistry.get_service("civitai_client")
if civitai_client:
await civitai_client.close()
logger.info("Closed CivitaiClient connection")
except Exception as e:
logger.error(f"Error during cleanup: {e}", exc_info=True)

View File

@@ -0,0 +1,18 @@
import os
import importlib
from .metadata_hook import MetadataHook
from .metadata_registry import MetadataRegistry
def init():
# Install hooks to collect metadata during execution
MetadataHook.install()
# Initialize registry
registry = MetadataRegistry()
print("ComfyUI Metadata Collector initialized")
def get_metadata(prompt_id=None):
"""Helper function to get metadata from the registry"""
registry = MetadataRegistry()
return registry.get_metadata(prompt_id)

View File

@@ -0,0 +1,14 @@
"""Constants used by the metadata collector"""
# Metadata collection constants
# Metadata categories
MODELS = "models"
PROMPTS = "prompts"
SAMPLING = "sampling"
LORAS = "loras"
SIZE = "size"
IMAGES = "images"
# Complete list of categories to track
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]

View File

@@ -0,0 +1,123 @@
import sys
import inspect
from .metadata_registry import MetadataRegistry
class MetadataHook:
"""Install hooks for metadata collection"""
@staticmethod
def install():
"""Install hooks to collect metadata during execution"""
try:
# Import ComfyUI's execution module
execution = None
try:
# Try direct import first
import execution # type: ignore
except ImportError:
# Try to locate from system modules
for module_name in sys.modules:
if module_name.endswith('.execution'):
execution = sys.modules[module_name]
break
# If we can't find the execution module, we can't install hooks
if execution is None:
print("Could not locate ComfyUI execution module, metadata collection disabled")
return
# Store the original _map_node_over_list function
original_map_node_over_list = execution._map_node_over_list
# Define the wrapped _map_node_over_list function
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# Only collect metadata when calling the main function of nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
prompt_id = registry.current_prompt_id
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Unique ID might be available through the obj if it has a unique_id field
node_id = getattr(obj, 'unique_id', None)
if node_id is None and pre_execute_cb:
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
frame = inspect.currentframe()
while frame:
if 'unique_id' in frame.f_locals:
node_id = frame.f_locals['unique_id']
break
frame = frame.f_back
# Record inputs before execution
if node_id is not None:
registry.record_node_execution(node_id, class_type, input_data_all, None)
except Exception as e:
print(f"Error collecting metadata (pre-execution): {str(e)}")
# Execute the original function
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
# After execution, collect outputs for relevant nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
prompt_id = registry.current_prompt_id
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Unique ID might be available through the obj if it has a unique_id field
node_id = getattr(obj, 'unique_id', None)
if node_id is None and pre_execute_cb:
# Try to extract node_id through reflection
frame = inspect.currentframe()
while frame:
if 'unique_id' in frame.f_locals:
node_id = frame.f_locals['unique_id']
break
frame = frame.f_back
# Record outputs after execution
if node_id is not None:
registry.update_node_execution(node_id, class_type, results)
except Exception as e:
print(f"Error collecting metadata (post-execution): {str(e)}")
return results
# Also hook the execute function to track the current prompt_id
original_execute = execution.execute
def execute_with_prompt_tracking(*args, **kwargs):
if len(args) >= 7: # Check if we have enough arguments
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
registry = MetadataRegistry()
# Start collection if this is a new prompt
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
registry.start_collection(prompt_id)
# Store the dynprompt reference for node lookups
if hasattr(prompt, 'original_prompt'):
registry.set_current_prompt(prompt)
# Execute the original function
return original_execute(*args, **kwargs)
# Replace the functions
execution._map_node_over_list = map_node_over_list_with_metadata
execution.execute = execute_with_prompt_tracking
# Make map_node_over_list public to avoid it being hidden by hooks
execution.map_node_over_list = original_map_node_over_list
print("Metadata collection hooks installed for runtime values")
except Exception as e:
print(f"Error installing metadata hooks: {str(e)}")

View File

@@ -0,0 +1,245 @@
import json
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE
class MetadataProcessor:
"""Process and format collected metadata"""
@staticmethod
def find_primary_sampler(metadata):
"""Find the primary KSampler node (with denoise=1)"""
primary_sampler = None
primary_sampler_id = None
# First, check for KSamplerAdvanced with add_noise="enable"
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
parameters = sampler_info.get("parameters", {})
add_noise = parameters.get("add_noise")
# If add_noise is "enable", this is likely the primary sampler for KSamplerAdvanced
if add_noise == "enable":
primary_sampler = sampler_info
primary_sampler_id = node_id
break
# If no KSamplerAdvanced found, fall back to traditional KSampler with denoise=1
if primary_sampler is None:
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
parameters = sampler_info.get("parameters", {})
denoise = parameters.get("denoise")
# If denoise is 1.0, this is likely the primary sampler
if denoise == 1.0 or denoise == 1:
primary_sampler = sampler_info
primary_sampler_id = node_id
break
return primary_sampler_id, primary_sampler
@staticmethod
def trace_node_input(prompt, node_id, input_name, target_class=None, max_depth=10):
"""
Trace an input connection from a node to find the source node
Parameters:
- prompt: The prompt object containing node connections
- node_id: ID of the starting node
- input_name: Name of the input to trace
- target_class: Optional class name to search for (e.g., "CLIPTextEncode")
- max_depth: Maximum depth to follow the node chain to prevent infinite loops
Returns:
- node_id of the found node, or None if not found
"""
if not prompt or not prompt.original_prompt or node_id not in prompt.original_prompt:
return None
# For depth tracking
current_depth = 0
current_node_id = node_id
current_input = input_name
while current_depth < max_depth:
if current_node_id not in prompt.original_prompt:
return None
node_inputs = prompt.original_prompt[current_node_id].get("inputs", {})
if current_input not in node_inputs:
return None
input_value = node_inputs[current_input]
# Input connections are formatted as [node_id, output_index]
if isinstance(input_value, list) and len(input_value) >= 2:
found_node_id = input_value[0] # Connected node_id
# If we're looking for a specific node class
if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class:
return found_node_id
# If we're not looking for a specific class or haven't found it yet
if not target_class:
return found_node_id
# Continue tracing through intermediate nodes
current_node_id = found_node_id
# For most conditioning nodes, the input we want to follow is named "conditioning"
if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}):
current_input = "conditioning"
else:
# If there's no "conditioning" input, we can't trace further
return found_node_id if not target_class else None
else:
# We've reached a node with no further connections
return None
current_depth += 1
# If we've reached max depth without finding target_class
return None
@staticmethod
def find_primary_checkpoint(metadata):
"""Find the primary checkpoint model in the workflow"""
if not metadata.get(MODELS):
return None
# In most workflows, there's only one checkpoint, so we can just take the first one
for node_id, model_info in metadata.get(MODELS, {}).items():
if model_info.get("type") == "checkpoint":
return model_info.get("name")
return None
@staticmethod
def extract_generation_params(metadata):
"""Extract generation parameters from metadata using node relationships"""
params = {
"prompt": "",
"negative_prompt": "",
"seed": None,
"steps": None,
"cfg_scale": None,
"guidance": None, # Add guidance parameter
"sampler": None,
"scheduler": None,
"checkpoint": None,
"loras": "",
"size": None,
"clip_skip": None
}
# Get the prompt object for node relationship tracing
prompt = metadata.get("current_prompt")
# Find the primary KSampler node
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata)
# Directly get checkpoint from metadata instead of tracing
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata)
if checkpoint:
params["checkpoint"] = checkpoint
if primary_sampler:
# Extract sampling parameters
sampling_params = primary_sampler.get("parameters", {})
# Handle both seed and noise_seed
params["seed"] = sampling_params.get("seed") if sampling_params.get("seed") is not None else sampling_params.get("noise_seed")
params["steps"] = sampling_params.get("steps")
params["cfg_scale"] = sampling_params.get("cfg")
params["sampler"] = sampling_params.get("sampler_name")
params["scheduler"] = sampling_params.get("scheduler")
# Trace connections from the primary sampler
if prompt and primary_sampler_id:
# Trace positive prompt - look specifically for CLIPTextEncode
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
# Find any FluxGuidance nodes in the positive conditioning path
flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5)
if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}):
flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {})
params["guidance"] = flux_params.get("guidance")
# Trace negative prompt - look specifically for CLIPTextEncode
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10)
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
# Check if the sampler itself has size information (from latent_image)
if primary_sampler_id in metadata.get(SIZE, {}):
width = metadata[SIZE][primary_sampler_id].get("width")
height = metadata[SIZE][primary_sampler_id].get("height")
if width and height:
params["size"] = f"{width}x{height}"
else:
# Fallback to the previous trace method if needed
latent_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "latent_image")
if latent_node_id:
# Follow chain to find EmptyLatentImage node
size_found = False
current_node_id = latent_node_id
# Limit depth to avoid infinite loops in complex workflows
max_depth = 10
for _ in range(max_depth):
if current_node_id in metadata.get(SIZE, {}):
width = metadata[SIZE][current_node_id].get("width")
height = metadata[SIZE][current_node_id].get("height")
if width and height:
params["size"] = f"{width}x{height}"
size_found = True
break
# Try to follow the chain
if prompt and prompt.original_prompt and current_node_id in prompt.original_prompt:
node_info = prompt.original_prompt[current_node_id]
if "inputs" in node_info:
# Look for a connection that might lead to size information
for input_name, input_value in node_info["inputs"].items():
if isinstance(input_value, list) and len(input_value) >= 2:
current_node_id = input_value[0]
break
else:
break # No connections to follow
else:
break # No inputs to follow
else:
break # Can't follow further
# Extract LoRAs using the standardized format
lora_parts = []
for node_id, lora_info in metadata.get(LORAS, {}).items():
# Access the lora_list from the standardized format
lora_list = lora_info.get("lora_list", [])
for lora in lora_list:
name = lora.get("name", "unknown")
strength = lora.get("strength", 1.0)
lora_parts.append(f"<lora:{name}:{strength}>")
params["loras"] = " ".join(lora_parts)
# Set default clip_skip value
params["clip_skip"] = "1" # Common default
return params
@staticmethod
def to_dict(metadata):
"""Convert extracted metadata to the ComfyUI output.json format"""
params = MetadataProcessor.extract_generation_params(metadata)
# Convert all values to strings to match output.json format
for key in params:
if params[key] is not None:
params[key] = str(params[key])
return params
@staticmethod
def to_json(metadata):
"""Convert metadata to JSON string"""
params = MetadataProcessor.to_dict(metadata)
return json.dumps(params, indent=4)

View File

@@ -0,0 +1,275 @@
import time
from nodes import NODE_CLASS_MAPPINGS
from .node_extractors import NODE_EXTRACTORS, GenericNodeExtractor
from .constants import METADATA_CATEGORIES, IMAGES
class MetadataRegistry:
"""A singleton registry to store and retrieve workflow metadata"""
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._reset()
return cls._instance
def _reset(self):
self.current_prompt_id = None
self.current_prompt = None
self.metadata = {}
self.prompt_metadata = {}
self.executed_nodes = set()
# Node-level cache for metadata
self.node_cache = {}
# Limit the number of stored prompts
self.max_prompt_history = 3
# Categories we want to track and retrieve from cache
self.metadata_categories = METADATA_CATEGORIES
def _clean_old_prompts(self):
"""Clean up old prompt metadata, keeping only recent ones"""
if len(self.prompt_metadata) <= self.max_prompt_history:
return
# Sort all prompt_ids by timestamp
sorted_prompts = sorted(
self.prompt_metadata.keys(),
key=lambda pid: self.prompt_metadata[pid].get("timestamp", 0)
)
# Remove oldest records
prompts_to_remove = sorted_prompts[:len(sorted_prompts) - self.max_prompt_history]
for pid in prompts_to_remove:
del self.prompt_metadata[pid]
def start_collection(self, prompt_id):
"""Begin metadata collection for a new prompt"""
self.current_prompt_id = prompt_id
self.executed_nodes = set()
self.prompt_metadata[prompt_id] = {
category: {} for category in METADATA_CATEGORIES
}
# Add additional metadata fields
self.prompt_metadata[prompt_id].update({
"execution_order": [],
"current_prompt": None, # Will store the prompt object
"timestamp": time.time()
})
# Clean up old prompt data
self._clean_old_prompts()
def set_current_prompt(self, prompt):
"""Set the current prompt object reference"""
self.current_prompt = prompt
if self.current_prompt_id and self.current_prompt_id in self.prompt_metadata:
# Store the prompt in the metadata for later relationship tracing
self.prompt_metadata[self.current_prompt_id]["current_prompt"] = prompt
def get_metadata(self, prompt_id=None):
"""Get collected metadata for a prompt"""
key = prompt_id if prompt_id is not None else self.current_prompt_id
if key not in self.prompt_metadata:
return {}
metadata = self.prompt_metadata[key]
# If we have a current prompt object, check for non-executed nodes
prompt_obj = metadata.get("current_prompt")
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
original_prompt = prompt_obj.original_prompt
# Fill in missing metadata from cache for nodes that weren't executed
self._fill_missing_metadata(key, original_prompt)
return self.prompt_metadata.get(key, {})
def _fill_missing_metadata(self, prompt_id, original_prompt):
"""Fill missing metadata from cache for non-executed nodes"""
if not original_prompt:
return
executed_nodes = self.executed_nodes
metadata = self.prompt_metadata[prompt_id]
# Iterate through nodes in the original prompt
for node_id, node_data in original_prompt.items():
# Skip if already executed in this run
if node_id in executed_nodes:
continue
# Get the node type from the prompt (this is the key in NODE_CLASS_MAPPINGS)
prompt_class_type = node_data.get("class_type")
if not prompt_class_type:
continue
# Convert to actual class name (which is what we use in our cache)
class_type = prompt_class_type
if prompt_class_type in NODE_CLASS_MAPPINGS:
class_obj = NODE_CLASS_MAPPINGS[prompt_class_type]
class_type = class_obj.__name__
# Create cache key using the actual class name
cache_key = f"{node_id}:{class_type}"
# Check if this node type is relevant for metadata collection
if class_type in NODE_EXTRACTORS:
# Check if we have cached metadata for this node
if cache_key in self.node_cache:
cached_data = self.node_cache[cache_key]
# Apply cached metadata to the current metadata
for category in self.metadata_categories:
if category in cached_data and node_id in cached_data[category]:
if node_id not in metadata[category]:
metadata[category][node_id] = cached_data[category][node_id]
def record_node_execution(self, node_id, class_type, inputs, outputs):
"""Record information about a node's execution"""
if not self.current_prompt_id:
return
# Add to execution order and mark as executed
if node_id not in self.executed_nodes:
self.executed_nodes.add(node_id)
self.prompt_metadata[self.current_prompt_id]["execution_order"].append(node_id)
# Process inputs to simplify working with them
processed_inputs = {}
for input_name, input_values in inputs.items():
if isinstance(input_values, list) and len(input_values) > 0:
# For single values, just use the first one (most common case)
processed_inputs[input_name] = input_values[0]
else:
processed_inputs[input_name] = input_values
# Extract node-specific metadata
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
extractor.extract(
node_id,
processed_inputs,
outputs,
self.prompt_metadata[self.current_prompt_id]
)
# Cache this node's metadata
self._cache_node_metadata(node_id, class_type)
def update_node_execution(self, node_id, class_type, outputs):
"""Update node metadata with output information"""
if not self.current_prompt_id:
return
# Process outputs to make them more usable
processed_outputs = outputs
# Use the same extractor to update with outputs
extractor = NODE_EXTRACTORS.get(class_type, GenericNodeExtractor)
if hasattr(extractor, 'update'):
extractor.update(
node_id,
processed_outputs,
self.prompt_metadata[self.current_prompt_id]
)
# Update the cached metadata for this node
self._cache_node_metadata(node_id, class_type)
def _cache_node_metadata(self, node_id, class_type):
"""Cache the metadata for a specific node"""
if not self.current_prompt_id or not node_id or not class_type:
return
# Create a cache key combining node_id and class_type
cache_key = f"{node_id}:{class_type}"
# Create a shallow copy of the node's metadata
node_metadata = {}
current_metadata = self.prompt_metadata[self.current_prompt_id]
for category in self.metadata_categories:
if category in current_metadata and node_id in current_metadata[category]:
if category not in node_metadata:
node_metadata[category] = {}
node_metadata[category][node_id] = current_metadata[category][node_id]
# Save to cache if we have any metadata for this node
if any(node_metadata.values()):
self.node_cache[cache_key] = node_metadata
def clear_unused_cache(self):
"""Clean up node_cache entries that are no longer in use"""
# Collect all node_ids currently in prompt_metadata
active_node_ids = set()
for prompt_data in self.prompt_metadata.values():
for category in self.metadata_categories:
if category in prompt_data:
active_node_ids.update(prompt_data[category].keys())
# Find cache keys that are no longer needed
keys_to_remove = []
for cache_key in self.node_cache:
node_id = cache_key.split(':')[0]
if node_id not in active_node_ids:
keys_to_remove.append(cache_key)
# Remove cache entries that are no longer needed
for key in keys_to_remove:
del self.node_cache[key]
def clear_metadata(self, prompt_id=None):
"""Clear metadata for a specific prompt or reset all data"""
if prompt_id is not None:
if prompt_id in self.prompt_metadata:
del self.prompt_metadata[prompt_id]
# Clean up cache after removing prompt
self.clear_unused_cache()
else:
# Reset all data
self._reset()
def get_first_decoded_image(self, prompt_id=None):
"""Get the first decoded image result"""
key = prompt_id if prompt_id is not None else self.current_prompt_id
if key not in self.prompt_metadata:
return None
metadata = self.prompt_metadata[key]
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
image_data = metadata[IMAGES]["first_decode"]["image"]
# If it's an image batch or tuple, handle various formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
# Return first element of list/tuple
return image_data[0]
# If it's a tensor, return as is for processing in the route handler
return image_data
# If no image is found in the current metadata, try to find it in the cache
# This handles the case where VAEDecode was cached by ComfyUI and not executed
prompt_obj = metadata.get("current_prompt")
if prompt_obj and hasattr(prompt_obj, "original_prompt"):
original_prompt = prompt_obj.original_prompt
for node_id, node_data in original_prompt.items():
class_type = node_data.get("class_type")
if class_type and class_type in NODE_CLASS_MAPPINGS:
class_obj = NODE_CLASS_MAPPINGS[class_type]
class_name = class_obj.__name__
# Check if this is a VAEDecode node
if class_name == "VAEDecode":
# Try to find this node in the cache
cache_key = f"{node_id}:{class_name}"
if cache_key in self.node_cache:
cached_data = self.node_cache[cache_key]
if IMAGES in cached_data and node_id in cached_data[IMAGES]:
image_data = cached_data[IMAGES][node_id]["image"]
# Handle different image formats
if isinstance(image_data, (list, tuple)) and len(image_data) > 0:
return image_data[0]
return image_data
return None

View File

@@ -0,0 +1,280 @@
import os
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES
class NodeMetadataExtractor:
"""Base class for node-specific metadata extraction"""
@staticmethod
def extract(node_id, inputs, outputs, metadata):
"""Extract metadata from node inputs/outputs"""
pass
@staticmethod
def update(node_id, outputs, metadata):
"""Update metadata with node outputs after execution"""
pass
class GenericNodeExtractor(NodeMetadataExtractor):
"""Default extractor for nodes without specific handling"""
@staticmethod
def extract(node_id, inputs, outputs, metadata):
pass
class CheckpointLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "ckpt_name" not in inputs:
return
model_name = inputs.get("ckpt_name")
if model_name:
metadata[MODELS][node_id] = {
"name": model_name,
"type": "checkpoint",
"node_id": node_id
}
class CLIPTextEncodeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "text" not in inputs:
return
text = inputs.get("text", "")
metadata[PROMPTS][node_id] = {
"text": text,
"node_id": node_id
}
class SamplerExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
sampling_params = {}
for key in ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]:
if key in inputs:
sampling_params[key] = inputs[key]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
if isinstance(latent, dict) and "samples" in latent:
# Extract dimensions from latent tensor
samples = latent["samples"]
if hasattr(samples, "shape") and len(samples.shape) >= 3:
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
# Multiply by 8 to get actual pixel dimensions
height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class KSamplerAdvancedExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
sampling_params = {}
for key in ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]:
if key in inputs:
sampling_params[key] = inputs[key]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
if isinstance(latent, dict) and "samples" in latent:
# Extract dimensions from latent tensor
samples = latent["samples"]
if hasattr(samples, "shape") and len(samples.shape) >= 3:
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
# Multiply by 8 to get actual pixel dimensions
height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class LoraLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "lora_name" not in inputs:
return
lora_name = inputs.get("lora_name")
# Extract base filename without extension from path
lora_name = os.path.splitext(os.path.basename(lora_name))[0]
strength_model = round(float(inputs.get("strength_model", 1.0)), 2)
# Use the standardized format with lora_list
metadata[LORAS][node_id] = {
"lora_list": [
{
"name": lora_name,
"strength": strength_model
}
],
"node_id": node_id
}
class ImageSizeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
width = inputs.get("width", 512)
height = inputs.get("height", 512)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class LoraLoaderManagerExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
active_loras = []
# Process lora_stack if available
if "lora_stack" in inputs:
lora_stack = inputs.get("lora_stack", [])
for lora_path, model_strength, clip_strength in lora_stack:
# Extract lora name from path (following the format in lora_loader.py)
lora_name = os.path.splitext(os.path.basename(lora_path))[0]
active_loras.append({
"name": lora_name,
"strength": model_strength
})
# Process loras from inputs
if "loras" in inputs:
loras_data = inputs.get("loras", [])
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
loras_list = loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
# Filter for active loras
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", True) and not lora.get("_isDummy", False):
active_loras.append({
"name": lora.get("name", ""),
"strength": float(lora.get("strength", 1.0))
})
if active_loras:
metadata[LORAS][node_id] = {
"lora_list": active_loras,
"node_id": node_id
}
class FluxGuidanceExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "guidance" not in inputs:
return
guidance_value = inputs.get("guidance")
# Store the guidance value in SAMPLING category
if node_id not in metadata[SAMPLING]:
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value
class UNETLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "unet_name" not in inputs:
return
model_name = inputs.get("unet_name")
if model_name:
metadata[MODELS][node_id] = {
"name": model_name,
"type": "checkpoint",
"node_id": node_id
}
class VAEDecodeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
pass
@staticmethod
def update(node_id, outputs, metadata):
# Ensure IMAGES category exists
if IMAGES not in metadata:
metadata[IMAGES] = {}
# Save image data under node ID index to be captured by caching mechanism
metadata[IMAGES][node_id] = {
"node_id": node_id,
"image": outputs
}
# Only set first_decode if it hasn't been recorded yet
if "first_decode" not in metadata[IMAGES]:
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
# Registry of node-specific extractors
NODE_EXTRACTORS = {
# Sampling
"KSampler": SamplerExtractor,
"KSamplerAdvanced": KSamplerAdvancedExtractor, # Add KSamplerAdvanced
"SamplerCustomAdvanced": SamplerExtractor, # Add SamplerCustomAdvanced
# Loaders
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
"LoraLoader": LoraLoaderExtractor,
"LoraManagerLoader": LoraLoaderManagerExtractor,
# Conditioning
"CLIPTextEncode": CLIPTextEncodeExtractor,
# Latent
"EmptyLatentImage": ImageSizeExtractor,
# Flux
"FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance
# Image
"VAEDecode": VAEDecodeExtractor, # Added VAEDecode extractor
# Add other nodes as needed
}

View File

@@ -0,0 +1,35 @@
import logging
from ..metadata_collector.metadata_processor import MetadataProcessor
logger = logging.getLogger(__name__)
class DebugMetadata:
NAME = "Debug Metadata (LoraManager)"
CATEGORY = "Lora Manager/utils"
DESCRIPTION = "Debug node to verify metadata_processor functionality"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("metadata_json",)
FUNCTION = "process_metadata"
def process_metadata(self, images):
try:
# Get the current execution context's metadata
from ..metadata_collector import get_metadata
metadata = get_metadata()
# Use the MetadataProcessor to convert it to JSON string
metadata_json = MetadataProcessor.to_json(metadata)
return (metadata_json,)
except Exception as e:
logger.error(f"Error processing metadata: {e}")
return ("{}",) # Return empty JSON object in case of error

View File

@@ -5,7 +5,7 @@ from ..services.lora_scanner import LoraScanner
from ..config import config from ..config import config
import asyncio import asyncio
import os import os
from .utils import FlexibleOptionalInputType, any_type from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -18,7 +18,7 @@ class LoraManagerLoader:
return { return {
"required": { "required": {
"model": ("MODEL",), "model": ("MODEL",),
"clip": ("CLIP",), # "clip": ("CLIP",),
"text": (IO.STRING, { "text": (IO.STRING, {
"multiline": True, "multiline": True,
"dynamicPrompts": True, "dynamicPrompts": True,
@@ -32,54 +32,13 @@ class LoraManagerLoader:
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING) RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING)
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras") RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
FUNCTION = "load_loras" FUNCTION = "load_loras"
async def get_lora_info(self, lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def extract_lora_name(self, lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def _get_loras_list(self, kwargs): def load_loras(self, model, text, **kwargs):
"""Helper to extract loras list from either old or new kwargs format"""
if 'loras' not in kwargs:
return []
loras_data = kwargs['loras']
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
return loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
return loras_data
# Unexpected format
else:
logger.warning(f"Unexpected loras format: {type(loras_data)}")
return []
def load_loras(self, model, clip, text, **kwargs):
"""Loads multiple LoRAs based on the kwargs input and lora_stack.""" """Loads multiple LoRAs based on the kwargs input and lora_stack."""
loaded_loras = [] loaded_loras = []
all_trigger_words = [] all_trigger_words = []
clip = kwargs.get('clip', None)
lora_stack = kwargs.get('lora_stack', None) lora_stack = kwargs.get('lora_stack', None)
# First process lora_stack if available # First process lora_stack if available
if lora_stack: if lora_stack:
@@ -88,14 +47,14 @@ class LoraManagerLoader:
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength) model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Extract lora name for trigger words lookup # Extract lora name for trigger words lookup
lora_name = self.extract_lora_name(lora_path) lora_name = extract_lora_name(lora_path)
_, trigger_words = asyncio.run(self.get_lora_info(lora_name)) _, trigger_words = asyncio.run(get_lora_info(lora_name))
all_trigger_words.extend(trigger_words) all_trigger_words.extend(trigger_words)
loaded_loras.append(f"{lora_name}: {model_strength}") loaded_loras.append(f"{lora_name}: {model_strength}")
# Then process loras from kwargs with support for both old and new formats # Then process loras from kwargs with support for both old and new formats
loras_list = self._get_loras_list(kwargs) loras_list = get_loras_list(kwargs)
for lora in loras_list: for lora in loras_list:
if not lora.get('active', False): if not lora.get('active', False):
continue continue
@@ -104,7 +63,7 @@ class LoraManagerLoader:
strength = float(lora['strength']) strength = float(lora['strength'])
# Get lora path and trigger words # Get lora path and trigger words
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name)) lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
# Apply the LoRA using the resolved path # Apply the LoRA using the resolved path
model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength) model, clip = LoraLoader().load_lora(model, clip, lora_path, strength, strength)

View File

@@ -3,7 +3,7 @@ from ..services.lora_scanner import LoraScanner
from ..config import config from ..config import config
import asyncio import asyncio
import os import os
from .utils import FlexibleOptionalInputType, any_type from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -26,55 +26,14 @@ class LoraStacker:
"optional": FlexibleOptionalInputType(any_type), "optional": FlexibleOptionalInputType(any_type),
} }
RETURN_TYPES = ("LORA_STACK", IO.STRING) RETURN_TYPES = ("LORA_STACK", IO.STRING, IO.STRING)
RETURN_NAMES = ("LORA_STACK", "trigger_words") RETURN_NAMES = ("LORA_STACK", "trigger_words", "active_loras")
FUNCTION = "stack_loras" FUNCTION = "stack_loras"
async def get_lora_info(self, lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def extract_lora_name(self, lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def _get_loras_list(self, kwargs):
"""Helper to extract loras list from either old or new kwargs format"""
if 'loras' not in kwargs:
return []
loras_data = kwargs['loras']
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
return loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
return loras_data
# Unexpected format
else:
logger.warning(f"Unexpected loras format: {type(loras_data)}")
return []
def stack_loras(self, text, **kwargs): def stack_loras(self, text, **kwargs):
"""Stacks multiple LoRAs based on the kwargs input without loading them.""" """Stacks multiple LoRAs based on the kwargs input without loading them."""
stack = [] stack = []
active_loras = []
all_trigger_words = [] all_trigger_words = []
# Process existing lora_stack if available # Process existing lora_stack if available
@@ -83,12 +42,12 @@ class LoraStacker:
stack.extend(lora_stack) stack.extend(lora_stack)
# Get trigger words from existing stack entries # Get trigger words from existing stack entries
for lora_path, _, _ in lora_stack: for lora_path, _, _ in lora_stack:
lora_name = self.extract_lora_name(lora_path) lora_name = extract_lora_name(lora_path)
_, trigger_words = asyncio.run(self.get_lora_info(lora_name)) _, trigger_words = asyncio.run(get_lora_info(lora_name))
all_trigger_words.extend(trigger_words) all_trigger_words.extend(trigger_words)
# Process loras from kwargs with support for both old and new formats # Process loras from kwargs with support for both old and new formats
loras_list = self._get_loras_list(kwargs) loras_list = get_loras_list(kwargs)
for lora in loras_list: for lora in loras_list:
if not lora.get('active', False): if not lora.get('active', False):
continue continue
@@ -98,16 +57,20 @@ class LoraStacker:
clip_strength = model_strength # Using same strength for both as in the original loader clip_strength = model_strength # Using same strength for both as in the original loader
# Get lora path and trigger words # Get lora path and trigger words
lora_path, trigger_words = asyncio.run(self.get_lora_info(lora_name)) lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
# Add to stack without loading # Add to stack without loading
# replace '/' with os.sep to avoid different OS path format # replace '/' with os.sep to avoid different OS path format
stack.append((lora_path.replace('/', os.sep), model_strength, clip_strength)) stack.append((lora_path.replace('/', os.sep), model_strength, clip_strength))
active_loras.append((lora_name, model_strength))
# Add trigger words to collection # Add trigger words to collection
all_trigger_words.extend(trigger_words) all_trigger_words.extend(trigger_words)
# use ',, ' to separate trigger words for group mode # use ',, ' to separate trigger words for group mode
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Format active_loras as <lora:lora_name:strength> separated by spaces
active_loras_text = " ".join([f"<lora:{name}:{str(strength).strip()}>"
for name, strength in active_loras])
return (stack, trigger_words_text) return (stack, trigger_words_text, active_loras_text)

View File

@@ -1,16 +1,44 @@
import json import json
from server import PromptServer # type: ignore import os
import asyncio
import re
import numpy as np
import folder_paths # type: ignore
from ..services.lora_scanner import LoraScanner
from ..services.checkpoint_scanner import CheckpointScanner
from ..metadata_collector.metadata_processor import MetadataProcessor
from ..metadata_collector import get_metadata
from PIL import Image, PngImagePlugin
import piexif
class SaveImage: class SaveImage:
NAME = "Save Image (LoraManager)" NAME = "Save Image (LoraManager)"
CATEGORY = "Lora Manager/utils" CATEGORY = "Lora Manager/utils"
DESCRIPTION = "Experimental node to display image preview and print prompt and extra_pnginfo" DESCRIPTION = "Save images with embedded generation metadata in compatible format"
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
self.compress_level = 4
self.counter = 0
# Add pattern format regex for filename substitution
pattern_format = re.compile(r"(%[^%]+%)")
@classmethod @classmethod
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
"required": { "required": {
"image": ("IMAGE",), "images": ("IMAGE",),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"file_format": (["png", "jpeg", "webp"],),
},
"optional": {
"lossless_webp": ("BOOLEAN", {"default": False}),
"quality": ("INT", {"default": 100, "min": 1, "max": 100}),
"embed_workflow": ("BOOLEAN", {"default": False}),
"add_counter_to_filename": ("BOOLEAN", {"default": True}),
}, },
"hidden": { "hidden": {
"prompt": "PROMPT", "prompt": "PROMPT",
@@ -19,23 +47,381 @@ class SaveImage:
} }
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image",) RETURN_NAMES = ("images",)
FUNCTION = "process_image" FUNCTION = "process_image"
OUTPUT_NODE = True
def process_image(self, image, prompt=None, extra_pnginfo=None): async def get_lora_hash(self, lora_name):
# Print the prompt information """Get the lora hash from cache"""
print("SaveImage Node - Prompt:") scanner = await LoraScanner.get_instance()
if prompt:
print(json.dumps(prompt, indent=2))
else:
print("No prompt information available")
# Print the extra_pnginfo # Use the new direct filename lookup method
print("\nSaveImage Node - Extra PNG Info:") hash_value = scanner.get_hash_by_filename(lora_name)
if extra_pnginfo: if hash_value:
print(json.dumps(extra_pnginfo, indent=2)) return hash_value
else:
print("No extra PNG info available") # Fallback to old method for compatibility
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
return item.get('sha256')
return None
async def get_checkpoint_hash(self, checkpoint_path):
"""Get the checkpoint hash from cache"""
scanner = await CheckpointScanner.get_instance()
# Return the image unchanged if not checkpoint_path:
return (image,) return None
# Extract basename without extension
checkpoint_name = os.path.basename(checkpoint_path)
checkpoint_name = os.path.splitext(checkpoint_name)[0]
# Try direct filename lookup first
hash_value = scanner.get_hash_by_filename(checkpoint_name)
if hash_value:
return hash_value
# Fallback to old method for compatibility
cache = await scanner.get_cached_data()
normalized_path = checkpoint_path.replace('\\', '/')
for item in cache.raw_data:
if item.get('file_name') == checkpoint_name and item.get('file_path').endswith(normalized_path):
return item.get('sha256')
return None
async def format_metadata(self, metadata_dict):
"""Format metadata in the requested format similar to userComment example"""
if not metadata_dict:
return ""
# Helper function to only add parameter if value is not None
def add_param_if_not_none(param_list, label, value):
if value is not None:
param_list.append(f"{label}: {value}")
# Extract the prompt and negative prompt
prompt = metadata_dict.get('prompt', '')
negative_prompt = metadata_dict.get('negative_prompt', '')
# Extract loras from the prompt if present
loras_text = metadata_dict.get('loras', '')
lora_hashes = {}
# If loras are found, add them on a new line after the prompt
if loras_text:
prompt_with_loras = f"{prompt}\n{loras_text}"
# Extract lora names from the format <lora:name:strength>
lora_matches = re.findall(r'<lora:([^:]+):([^>]+)>', loras_text)
# Get hash for each lora
for lora_name, strength in lora_matches:
hash_value = await self.get_lora_hash(lora_name)
if hash_value:
lora_hashes[lora_name] = hash_value
else:
prompt_with_loras = prompt
# Format the first part (prompt and loras)
metadata_parts = [prompt_with_loras]
# Add negative prompt
if negative_prompt:
metadata_parts.append(f"Negative prompt: {negative_prompt}")
# Format the second part (generation parameters)
params = []
# Add standard parameters in the correct order
if 'steps' in metadata_dict:
add_param_if_not_none(params, "Steps", metadata_dict.get('steps'))
# Combine sampler and scheduler information
sampler_name = None
scheduler_name = None
if 'sampler' in metadata_dict:
sampler = metadata_dict.get('sampler')
# Convert ComfyUI sampler names to user-friendly names
sampler_mapping = {
'euler': 'Euler',
'euler_ancestral': 'Euler a',
'dpm_2': 'DPM2',
'dpm_2_ancestral': 'DPM2 a',
'heun': 'Heun',
'dpm_fast': 'DPM fast',
'dpm_adaptive': 'DPM adaptive',
'lms': 'LMS',
'dpmpp_2s_ancestral': 'DPM++ 2S a',
'dpmpp_sde': 'DPM++ SDE',
'dpmpp_sde_gpu': 'DPM++ SDE',
'dpmpp_2m': 'DPM++ 2M',
'dpmpp_2m_sde': 'DPM++ 2M SDE',
'dpmpp_2m_sde_gpu': 'DPM++ 2M SDE',
'ddim': 'DDIM'
}
sampler_name = sampler_mapping.get(sampler, sampler)
if 'scheduler' in metadata_dict:
scheduler = metadata_dict.get('scheduler')
scheduler_mapping = {
'normal': 'Simple',
'karras': 'Karras',
'exponential': 'Exponential',
'sgm_uniform': 'SGM Uniform',
'sgm_quadratic': 'SGM Quadratic'
}
scheduler_name = scheduler_mapping.get(scheduler, scheduler)
# Add combined sampler and scheduler information
if sampler_name:
if scheduler_name:
params.append(f"Sampler: {sampler_name} {scheduler_name}")
else:
params.append(f"Sampler: {sampler_name}")
# CFG scale (Use guidance if available, otherwise fall back to cfg_scale or cfg)
if 'guidance' in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('guidance'))
elif 'cfg_scale' in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg_scale'))
elif 'cfg' in metadata_dict:
add_param_if_not_none(params, "CFG scale", metadata_dict.get('cfg'))
# Seed
if 'seed' in metadata_dict:
add_param_if_not_none(params, "Seed", metadata_dict.get('seed'))
# Size
if 'size' in metadata_dict:
add_param_if_not_none(params, "Size", metadata_dict.get('size'))
# Model info
if 'checkpoint' in metadata_dict:
# Ensure checkpoint is a string before processing
checkpoint = metadata_dict.get('checkpoint')
if checkpoint is not None:
# Get model hash
model_hash = await self.get_checkpoint_hash(checkpoint)
# Extract basename without path
checkpoint_name = os.path.basename(checkpoint)
# Remove extension if present
checkpoint_name = os.path.splitext(checkpoint_name)[0]
# Add model hash if available
if model_hash:
params.append(f"Model hash: {model_hash[:10]}, Model: {checkpoint_name}")
else:
params.append(f"Model: {checkpoint_name}")
# Add LoRA hashes if available
if lora_hashes:
lora_hash_parts = []
for lora_name, hash_value in lora_hashes.items():
lora_hash_parts.append(f"{lora_name}: {hash_value}")
if lora_hash_parts:
params.append(f"Lora hashes: \"{', '.join(lora_hash_parts)}\"")
# Combine all parameters with commas
metadata_parts.append(", ".join(params))
# Join all parts with a new line
return "\n".join(metadata_parts)
# credit to nkchocoai
# Add format_filename method to handle pattern substitution
def format_filename(self, filename, metadata_dict):
"""Format filename with metadata values"""
if not metadata_dict:
return filename
result = re.findall(self.pattern_format, filename)
for segment in result:
parts = segment.replace("%", "").split(":")
key = parts[0]
if key == "seed" and 'seed' in metadata_dict:
filename = filename.replace(segment, str(metadata_dict.get('seed', '')))
elif key == "width" and 'size' in metadata_dict:
size = metadata_dict.get('size', 'x')
w = size.split('x')[0] if isinstance(size, str) else size[0]
filename = filename.replace(segment, str(w))
elif key == "height" and 'size' in metadata_dict:
size = metadata_dict.get('size', 'x')
h = size.split('x')[1] if isinstance(size, str) else size[1]
filename = filename.replace(segment, str(h))
elif key == "pprompt" and 'prompt' in metadata_dict:
prompt = metadata_dict.get('prompt', '').replace("\n", " ")
if len(parts) >= 2:
length = int(parts[1])
prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip())
elif key == "nprompt" and 'negative_prompt' in metadata_dict:
prompt = metadata_dict.get('negative_prompt', '').replace("\n", " ")
if len(parts) >= 2:
length = int(parts[1])
prompt = prompt[:length]
filename = filename.replace(segment, prompt.strip())
elif key == "model" and 'checkpoint' in metadata_dict:
model = metadata_dict.get('checkpoint', '')
model = os.path.splitext(os.path.basename(model))[0]
if len(parts) >= 2:
length = int(parts[1])
model = model[:length]
filename = filename.replace(segment, model)
elif key == "date":
from datetime import datetime
now = datetime.now()
date_table = {
"yyyy": f"{now.year:04d}",
"yy": f"{now.year % 100:02d}",
"MM": f"{now.month:02d}",
"dd": f"{now.day:02d}",
"hh": f"{now.hour:02d}",
"mm": f"{now.minute:02d}",
"ss": f"{now.second:02d}",
}
if len(parts) >= 2:
date_format = parts[1]
for k, v in date_table.items():
date_format = date_format.replace(k, v)
filename = filename.replace(segment, date_format)
else:
date_format = "yyyyMMddhhmmss"
for k, v in date_table.items():
date_format = date_format.replace(k, v)
filename = filename.replace(segment, date_format)
return filename
def save_images(self, images, filename_prefix, file_format, prompt=None, extra_pnginfo=None,
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
"""Save images with metadata"""
results = []
# Get metadata using the metadata collector
raw_metadata = get_metadata()
metadata_dict = MetadataProcessor.to_dict(raw_metadata)
# Get or create metadata asynchronously
metadata = asyncio.run(self.format_metadata(metadata_dict))
# Process filename_prefix with pattern substitution
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
# Get initial save path info once for the batch
full_output_folder, filename, counter, subfolder, processed_prefix = folder_paths.get_save_image_path(
filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]
)
# Create directory if it doesn't exist
if not os.path.exists(full_output_folder):
os.makedirs(full_output_folder, exist_ok=True)
# Process each image with incrementing counter
for i, image in enumerate(images):
# Convert the tensor image to numpy array
img = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8))
# Generate filename with counter if needed
base_filename = filename
if add_counter_to_filename:
# Use counter + i to ensure unique filenames for all images in batch
current_counter = counter + i
base_filename += f"_{current_counter:05}_"
# Set file extension and prepare saving parameters
if file_format == "png":
file = base_filename + ".png"
file_extension = ".png"
# Remove "optimize": True to match built-in node behavior
save_kwargs = {"compress_level": self.compress_level}
pnginfo = PngImagePlugin.PngInfo()
elif file_format == "jpeg":
file = base_filename + ".jpg"
file_extension = ".jpg"
save_kwargs = {"quality": quality, "optimize": True}
elif file_format == "webp":
file = base_filename + ".webp"
file_extension = ".webp"
# Add optimization param to control performance
save_kwargs = {"quality": quality, "lossless": lossless_webp, "method": 0}
# Full save path
file_path = os.path.join(full_output_folder, file)
# Save the image with metadata
try:
if file_format == "png":
if metadata:
pnginfo.add_text("parameters", metadata)
if embed_workflow and extra_pnginfo is not None:
workflow_json = json.dumps(extra_pnginfo["workflow"])
pnginfo.add_text("workflow", workflow_json)
save_kwargs["pnginfo"] = pnginfo
img.save(file_path, format="PNG", **save_kwargs)
elif file_format == "jpeg":
# For JPEG, use piexif
if metadata:
try:
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}}
exif_bytes = piexif.dump(exif_dict)
save_kwargs["exif"] = exif_bytes
except Exception as e:
print(f"Error adding EXIF data: {e}")
img.save(file_path, format="JPEG", **save_kwargs)
elif file_format == "webp":
# For WebP, also use piexif for metadata
if metadata:
try:
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}}
exif_bytes = piexif.dump(exif_dict)
save_kwargs["exif"] = exif_bytes
except Exception as e:
print(f"Error adding EXIF data: {e}")
img.save(file_path, format="WEBP", **save_kwargs)
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
except Exception as e:
print(f"Error saving image: {e}")
return results
def process_image(self, images, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None,
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
"""Process and save image with metadata"""
# Make sure the output directory exists
os.makedirs(self.output_dir, exist_ok=True)
# Ensure images is always a list of images
if len(images.shape) == 3: # Single image (height, width, channels)
images = [images]
else: # Multiple images (batch, height, width, channels)
images = [img for img in images]
# Save all images
results = self.save_images(
images,
filename_prefix,
file_format,
prompt,
extra_pnginfo,
lossless_webp,
quality,
embed_workflow,
add_counter_to_filename
)
return (images,)

View File

@@ -47,10 +47,10 @@ class TriggerWordToggle:
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else "" trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
# Send trigger words to frontend # Send trigger words to frontend
PromptServer.instance.send_sync("trigger_word_update", { # PromptServer.instance.send_sync("trigger_word_update", {
"id": id, # "id": id,
"message": trigger_words # "message": trigger_words
}) # })
filtered_triggers = trigger_words filtered_triggers = trigger_words

View File

@@ -30,4 +30,55 @@ class FlexibleOptionalInputType(dict):
return True return True
any_type = AnyType("*") any_type = AnyType("*")
# Common methods extracted from lora_loader.py and lora_stacker.py
import os
import logging
import asyncio
from ..services.lora_scanner import LoraScanner
from ..config import config
logger = logging.getLogger(__name__)
async def get_lora_info(lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def extract_lora_name(lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
basename = os.path.basename(lora_path)
return os.path.splitext(basename)[0]
def get_loras_list(kwargs):
"""Helper to extract loras list from either old or new kwargs format"""
if 'loras' not in kwargs:
return []
loras_data = kwargs['loras']
# Handle new format: {'loras': {'__value__': [...]}}
if isinstance(loras_data, dict) and '__value__' in loras_data:
return loras_data['__value__']
# Handle old format: {'loras': [...]}
elif isinstance(loras_data, list):
return loras_data
# Unexpected format
else:
logger.warning(f"Unexpected loras format: {type(loras_data)}")
return []

File diff suppressed because it is too large Load Diff

View File

@@ -1,37 +1,483 @@
import os import os
from aiohttp import web import json
import jinja2 import jinja2
from aiohttp import web
import logging import logging
import asyncio
from ..utils.routes_common import ModelRouteUtils
from ..utils.constants import NSFW_LEVELS
from ..services.websocket_manager import ws_manager
from ..services.service_registry import ServiceRegistry
from ..config import config from ..config import config
from ..services.settings_manager import settings from ..services.settings_manager import settings
from ..utils.utils import fuzzy_match
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
class CheckpointsRoutes: class CheckpointsRoutes:
"""Route handlers for Checkpoints management endpoints""" """API routes for checkpoint management"""
def __init__(self): def __init__(self):
self.scanner = None # Will be initialized in setup_routes
self.template_env = jinja2.Environment( self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path), loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True autoescape=True
) )
self.download_manager = None # Will be initialized in setup_routes
self._download_lock = asyncio.Lock()
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
self.download_manager = await ServiceRegistry.get_download_manager()
def setup_routes(self, app):
"""Register routes with the aiohttp app"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
app.router.add_get('/api/checkpoints', self.get_checkpoints)
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info)
app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots)
app.router.add_get('/api/checkpoints/civitai/versions/{model_id}', self.get_civitai_versions) # Add new route
# Add new routes for model management similar to LoRA routes
app.router.add_post('/api/checkpoints/delete', self.delete_model)
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
# Add new WebSocket endpoint for checkpoint progress
app.router.add_get('/ws/checkpoint-progress', ws_manager.handle_checkpoint_connection)
async def get_checkpoints(self, request):
"""Get paginated checkpoint data"""
try:
# Parse query parameters
page = int(request.query.get('page', '1'))
page_size = min(int(request.query.get('page_size', '20')), 100)
sort_by = request.query.get('sort_by', 'name')
folder = request.query.get('folder', None)
search = request.query.get('search', None)
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
base_models = request.query.getall('base_model', [])
tags = request.query.getall('tag', [])
# Process search options
search_options = {
'filename': request.query.get('search_filename', 'true').lower() == 'true',
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
'tags': request.query.get('search_tags', 'false').lower() == 'true',
'recursive': request.query.get('recursive', 'false').lower() == 'true',
}
# Process hash filters if provided
hash_filters = {}
if 'hash' in request.query:
hash_filters['single_hash'] = request.query['hash']
elif 'hashes' in request.query:
try:
hash_list = json.loads(request.query['hashes'])
if isinstance(hash_list, list):
hash_filters['multiple_hashes'] = hash_list
except (json.JSONDecodeError, TypeError):
pass
# Get data from scanner
result = await self.get_paginated_data(
page=page,
page_size=page_size,
sort_by=sort_by,
folder=folder,
search=search,
fuzzy_search=fuzzy_search,
base_models=base_models,
tags=tags,
search_options=search_options,
hash_filters=hash_filters
)
# Format response items
formatted_result = {
'items': [self._format_checkpoint_response(cp) for cp in result['items']],
'total': result['total'],
'page': result['page'],
'page_size': result['page_size'],
'total_pages': result['total_pages']
}
# Return as JSON
return web.json_response(formatted_result)
except Exception as e:
logger.error(f"Error in get_checkpoints: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_paginated_data(self, page, page_size, sort_by='name',
folder=None, search=None, fuzzy_search=False,
base_models=None, tags=None,
search_options=None, hash_filters=None):
"""Get paginated and filtered checkpoint data"""
cache = await self.scanner.get_cached_data()
# Get default search options if not provided
if search_options is None:
search_options = {
'filename': True,
'modelname': True,
'tags': False,
'recursive': False,
}
# Get the base data set
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# Apply hash filtering if provided (highest priority)
if hash_filters:
single_hash = hash_filters.get('single_hash')
multiple_hashes = hash_filters.get('multiple_hashes')
if single_hash:
# Filter by single hash
single_hash = single_hash.lower() # Ensure lowercase for matching
filtered_data = [
cp for cp in filtered_data
if cp.get('sha256', '').lower() == single_hash
]
elif multiple_hashes:
# Filter by multiple hashes
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
filtered_data = [
cp for cp in filtered_data
if cp.get('sha256', '').lower() in hash_set
]
# Jump to pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
# Apply SFW filtering if enabled in settings
if settings.get('show_only_sfw', False):
filtered_data = [
cp for cp in filtered_data
if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R']
]
# Apply folder filtering
if folder is not None:
if search_options.get('recursive', False):
# Recursive folder filtering - include all subfolders
filtered_data = [
cp for cp in filtered_data
if cp['folder'].startswith(folder)
]
else:
# Exact folder filtering
filtered_data = [
cp for cp in filtered_data
if cp['folder'] == folder
]
# Apply base model filtering
if base_models and len(base_models) > 0:
filtered_data = [
cp for cp in filtered_data
if cp.get('base_model') in base_models
]
# Apply tag filtering
if tags and len(tags) > 0:
filtered_data = [
cp for cp in filtered_data
if any(tag in cp.get('tags', []) for tag in tags)
]
# Apply search filtering
if search:
search_results = []
for cp in filtered_data:
# Search by file name
if search_options.get('filename', True):
if fuzzy_search:
if fuzzy_match(cp.get('file_name', ''), search):
search_results.append(cp)
continue
elif search.lower() in cp.get('file_name', '').lower():
search_results.append(cp)
continue
# Search by model name
if search_options.get('modelname', True):
if fuzzy_search:
if fuzzy_match(cp.get('model_name', ''), search):
search_results.append(cp)
continue
elif search.lower() in cp.get('model_name', '').lower():
search_results.append(cp)
continue
# Search by tags
if search_options.get('tags', False) and 'tags' in cp:
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) for tag in cp['tags']):
search_results.append(cp)
continue
filtered_data = search_results
# Calculate pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
def _format_checkpoint_response(self, checkpoint):
"""Format checkpoint data for API response"""
return {
"model_name": checkpoint["model_name"],
"file_name": checkpoint["file_name"],
"preview_url": config.get_preview_static_url(checkpoint.get("preview_url", "")),
"preview_nsfw_level": checkpoint.get("preview_nsfw_level", 0),
"base_model": checkpoint.get("base_model", ""),
"folder": checkpoint["folder"],
"sha256": checkpoint.get("sha256", ""),
"file_path": checkpoint["file_path"].replace(os.sep, "/"),
"file_size": checkpoint.get("size", 0),
"modified": checkpoint.get("modified", ""),
"tags": checkpoint.get("tags", []),
"modelDescription": checkpoint.get("modelDescription", ""),
"from_civitai": checkpoint.get("from_civitai", True),
"notes": checkpoint.get("notes", ""),
"model_type": checkpoint.get("model_type", "checkpoint"),
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
}
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all checkpoints in the background"""
try:
cache = await self.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
success = 0
needs_resort = False
# Prepare checkpoints to process
to_process = [
cp for cp in cache.raw_data
if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True)
]
total_to_process = len(to_process)
# Send initial progress
await ws_manager.broadcast({
'status': 'started',
'total': total_to_process,
'processed': 0,
'success': 0
})
# Process each checkpoint
for cp in to_process:
try:
original_name = cp.get('model_name')
if await ModelRouteUtils.fetch_and_update_model(
sha256=cp['sha256'],
file_path=cp['file_path'],
model_data=cp,
update_cache_func=self.scanner.update_single_model_cache
):
success += 1
if original_name != cp.get('model_name'):
needs_resort = True
processed += 1
# Send progress update
await ws_manager.broadcast({
'status': 'processing',
'total': total_to_process,
'processed': processed,
'success': success,
'current_name': cp.get('model_name', 'Unknown')
})
except Exception as e:
logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}")
if needs_resort:
await cache.resort(name_only=True)
# Send completion message
await ws_manager.broadcast({
'status': 'completed',
'total': total_to_process,
'processed': processed,
'success': success
})
return web.json_response({
"success": True,
"message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})"
})
except Exception as e:
# Send error message
await ws_manager.broadcast({
'status': 'error',
'error': str(e)
})
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
return web.Response(text=str(e), status=500)
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency"""
try:
# Parse query parameters
limit = int(request.query.get('limit', '20'))
# Validate limit
if limit < 1 or limit > 100:
limit = 20 # Default to a reasonable limit
# Get top tags
top_tags = await self.scanner.get_top_tags(limit)
return web.json_response({
'success': True,
'tags': top_tags
})
except Exception as e:
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
return web.json_response({
'success': False,
'error': 'Internal server error'
}, status=500)
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in loras"""
try:
# Parse query parameters
limit = int(request.query.get('limit', '20'))
# Validate limit
if limit < 1 or limit > 100:
limit = 20 # Default to a reasonable limit
# Get base models
base_models = await self.scanner.get_base_models(limit)
return web.json_response({
'success': True,
'base_models': base_models
})
except Exception as e:
logger.error(f"Error retrieving base models: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def scan_checkpoints(self, request):
"""Force a rescan of checkpoint files"""
try:
await self.scanner.get_cached_data(force_refresh=True)
return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
except Exception as e:
logger.error(f"Error in scan_checkpoints: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_checkpoint_info(self, request):
"""Get detailed information for a specific checkpoint by name"""
try:
name = request.match_info.get('name', '')
checkpoint_info = await self.scanner.get_checkpoint_info_by_name(name)
if checkpoint_info:
return web.json_response(checkpoint_info)
else:
return web.json_response({"error": "Checkpoint not found"}, status=404)
except Exception as e:
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def handle_checkpoints_page(self, request: web.Request) -> web.Response: async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
"""Handle GET /checkpoints request""" """Handle GET /checkpoints request"""
try: try:
template = self.template_env.get_template('checkpoints.html') # Check if the CheckpointScanner is initializing
rendered = template.render( # It's initializing if the cache object doesn't exist yet,
is_initializing=False, # OR if the scanner explicitly says it's initializing (background task running).
settings=settings, is_initializing = (
request=request self.scanner._cache is None or
(hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
) )
if is_initializing:
# If still initializing, return loading page
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[], # 空文件夹列表
is_initializing=True, # 新增标志
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
logger.info("Checkpoints page is initializing, returning loading page")
else:
# 正常流程 - 获取已经初始化好的缓存数据
try:
cache = await self.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=cache.folders,
is_initializing=False,
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
except Exception as cache_error:
logger.error(f"Error loading checkpoints cache data: {cache_error}")
# 如果获取缓存失败,也显示初始化页面
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Checkpoints cache error, returning initialization page")
return web.Response( return web.Response(
text=rendered, text=rendered,
content_type='text/html' content_type='text/html'
) )
except Exception as e: except Exception as e:
logger.error(f"Error handling checkpoints request: {e}", exc_info=True) logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
return web.Response( return web.Response(
@@ -39,6 +485,194 @@ class CheckpointsRoutes:
status=500 status=500
) )
def setup_routes(self, app: web.Application): async def delete_model(self, request: web.Request) -> web.Response:
"""Register routes with the application""" """Handle checkpoint model deletion request"""
app.router.add_get('/checkpoints', self.handle_checkpoints_page) return await ModelRouteUtils.handle_delete_model(request, self.scanner)
async def fetch_civitai(self, request: web.Request) -> web.Response:
"""Handle CivitAI metadata fetch request for checkpoints"""
return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
async def replace_preview(self, request: web.Request) -> web.Response:
"""Handle preview image replacement for checkpoints"""
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
async def download_checkpoint(self, request: web.Request) -> web.Response:
"""Handle checkpoint download request"""
async with self._download_lock:
# Get the download manager from service registry if not already initialized
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
try:
data = await request.json()
# Create progress callback that uses checkpoint-specific WebSocket
async def progress_callback(progress):
await ws_manager.broadcast_checkpoint_progress({
'status': 'progress',
'progress': progress
})
# Check which identifier is provided
download_url = data.get('download_url')
model_hash = data.get('model_hash')
model_version_id = data.get('model_version_id')
# Validate that at least one identifier is provided
if not any([download_url, model_hash, model_version_id]):
return web.Response(
status=400,
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
)
result = await self.download_manager.download_from_civitai(
download_url=download_url,
model_hash=model_hash,
model_version_id=model_version_id,
save_dir=data.get('checkpoint_root'),
relative_path=data.get('relative_path', ''),
progress_callback=progress_callback,
model_type="checkpoint"
)
if not result.get('success', False):
error_message = result.get('error', 'Unknown error')
# Return 401 for early access errors
if 'early access' in error_message.lower():
logger.warning(f"Early access download failed: {error_message}")
return web.Response(
status=401,
text=f"Early Access Restriction: {error_message}"
)
return web.Response(status=500, text=error_message)
return web.json_response(result)
except Exception as e:
error_message = str(e)
# Check if this might be an early access error
if '401' in error_message:
logger.warning(f"Early access error (401): {error_message}")
return web.Response(
status=401,
text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai."
)
logger.error(f"Error downloading checkpoint: {error_message}")
return web.Response(status=500, text=error_message)
async def get_checkpoint_roots(self, request):
"""Return the checkpoint root directories"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
roots = self.scanner.get_model_roots()
return web.json_response({
"success": True,
"roots": roots
})
except Exception as e:
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def save_metadata(self, request: web.Request) -> web.Response:
"""Handle saving metadata updates for checkpoints"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
data = await request.json()
file_path = data.get('file_path')
if not file_path:
return web.Response(text='File path is required', status=400)
# Remove file path from data to avoid saving it
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
# Get metadata file path
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Load existing metadata
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Update metadata
metadata.update(metadata_updates)
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
# Update cache
await self.scanner.update_single_model_cache(file_path, file_path, metadata)
# If model_name was updated, resort the cache
if 'model_name' in metadata_updates:
cache = await self.scanner.get_cached_data()
await cache.resort(name_only=True)
return web.json_response({'success': True})
except Exception as e:
logger.error(f"Error saving checkpoint metadata: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
async def get_civitai_versions(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai checkpoint model with local availability info"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
# Get the civitai client from service registry
civitai_client = await ServiceRegistry.get_civitai_client()
model_id = request.match_info['model_id']
response = await civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be Checkpoint
if model_type.lower() != 'checkpoint':
return web.json_response({
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model' and file.get('primary') == True), None)
# If no primary file found, try to find any model file
if not model_file:
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.scanner.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.scanner.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching checkpoint model versions: {e}")
return web.Response(status=500, text=str(e))

View File

@@ -1,12 +1,11 @@
import os import os
from aiohttp import web from aiohttp import web
import jinja2 import jinja2
from typing import Dict, List from typing import Dict
import logging import logging
from ..services.lora_scanner import LoraScanner
from ..services.recipe_scanner import RecipeScanner
from ..config import config from ..config import config
from ..services.settings_manager import settings # Add this import from ..services.settings_manager import settings
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logging.getLogger('asyncio').setLevel(logging.CRITICAL) logging.getLogger('asyncio').setLevel(logging.CRITICAL)
@@ -15,13 +14,19 @@ class LoraRoutes:
"""Route handlers for LoRA management endpoints""" """Route handlers for LoRA management endpoints"""
def __init__(self): def __init__(self):
self.scanner = LoraScanner() # Initialize service references as None, will be set during async init
self.recipe_scanner = RecipeScanner(self.scanner) self.scanner = None
self.recipe_scanner = None
self.template_env = jinja2.Environment( self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path), loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True autoescape=True
) )
async def init_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_lora_scanner()
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
def format_lora_data(self, lora: Dict) -> Dict: def format_lora_data(self, lora: Dict) -> Dict:
"""Format LoRA data for template rendering""" """Format LoRA data for template rendering"""
return { return {
@@ -58,32 +63,49 @@ class LoraRoutes:
async def handle_loras_page(self, request: web.Request) -> web.Response: async def handle_loras_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras request""" """Handle GET /loras request"""
try: try:
# 不等待缓存数据,直接检查缓存状态 # Ensure services are initialized
await self.init_services()
# Check if the LoraScanner is initializing
# It's initializing if the cache object doesn't exist yet,
# OR if the scanner explicitly says it's initializing (background task running).
is_initializing = ( is_initializing = (
self.scanner._cache is None and self.scanner._cache is None or
(self.scanner._initialization_task is not None and (hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
not self.scanner._initialization_task.done())
) )
if is_initializing: if is_initializing:
# 如果正在初始化,返回一个只包含加载提示的页面 # If still initializing, return loading page
template = self.template_env.get_template('loras.html') template = self.template_env.get_template('loras.html')
rendered = template.render( rendered = template.render(
folders=[], # 空文件夹列表 folders=[],
is_initializing=True, # 新增标志 is_initializing=True,
settings=settings, # Pass settings to template settings=settings,
request=request # Pass the request object to the template request=request
) )
logger.info("Loras page is initializing, returning loading page")
else: else:
# 正常流程 # Normal flow - get data from initialized cache
cache = await self.scanner.get_cached_data() try:
template = self.template_env.get_template('loras.html') cache = await self.scanner.get_cached_data(force_refresh=False)
rendered = template.render( template = self.template_env.get_template('loras.html')
folders=cache.folders, rendered = template.render(
is_initializing=False, folders=cache.folders,
settings=settings, # Pass settings to template is_initializing=False,
request=request # Pass the request object to the template settings=settings,
) request=request
)
except Exception as cache_error:
logger.error(f"Error loading cache data: {cache_error}")
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Cache error, returning initialization page")
return web.Response( return web.Response(
text=rendered, text=rendered,
@@ -100,32 +122,30 @@ class LoraRoutes:
async def handle_recipes_page(self, request: web.Request) -> web.Response: async def handle_recipes_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras/recipes request""" """Handle GET /loras/recipes request"""
try: try:
# Check cache initialization status # Ensure services are initialized
is_initializing = ( await self.init_services()
self.recipe_scanner._cache is None and
(self.recipe_scanner._initialization_task is not None and # Skip initialization check and directly try to get cached data
not self.recipe_scanner._initialization_task.done()) try:
) # Recipe scanner will initialize cache if needed
await self.recipe_scanner.get_cached_data(force_refresh=False)
if is_initializing: template = self.template_env.get_template('recipes.html')
# If initializing, return a loading page rendered = template.render(
recipes=[], # Frontend will load recipes via API
is_initializing=False,
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading recipe cache data: {cache_error}")
# Still keep error handling - show initializing page on error
template = self.template_env.get_template('recipes.html') template = self.template_env.get_template('recipes.html')
rendered = template.render( rendered = template.render(
is_initializing=True, is_initializing=True,
settings=settings, settings=settings,
request=request # Pass the request object to the template request=request
)
else:
# return empty recipes
recipes_data = []
template = self.template_env.get_template('recipes.html')
rendered = template.render(
recipes=recipes_data,
is_initializing=False,
settings=settings,
request=request # Pass the request object to the template
) )
logger.info("Recipe cache error, returning initialization page")
return web.Response( return web.Response(
text=rendered, text=rendered,
@@ -157,5 +177,13 @@ class LoraRoutes:
def setup_routes(self, app: web.Application): def setup_routes(self, app: web.Application):
"""Register routes with the application""" """Register routes with the application"""
# Add an app startup handler to initialize services
app.on_startup.append(self._on_startup)
# Register routes
app.router.add_get('/loras', self.handle_loras_page) app.router.add_get('/loras', self.handle_loras_page)
app.router.add_get('/loras/recipes', self.handle_recipes_page) app.router.add_get('/loras/recipes', self.handle_recipes_page)
async def _on_startup(self, app):
"""Initialize services when the app starts"""
await self.init_services()

View File

@@ -1,5 +1,9 @@
import os import os
import time import time
import numpy as np
from PIL import Image
import torch
import io
import logging import logging
from aiohttp import web from aiohttp import web
from typing import Dict from typing import Dict
@@ -8,13 +12,14 @@ import json
import asyncio import asyncio
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
from ..utils.recipe_parsers import RecipeParserFactory from ..utils.recipe_parsers import RecipeParserFactory
from ..services.civitai_client import CivitaiClient from ..utils.constants import CARD_PREVIEW_WIDTH
from ..services.recipe_scanner import RecipeScanner
from ..services.lora_scanner import LoraScanner
from ..config import config from ..config import config
from ..workflow.parser import WorkflowParser from ..metadata_collector import get_metadata # Add MetadataCollector import
from ..utils.utils import download_twitter_image from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import
from ..utils.utils import download_civitai_image
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
from ..metadata_collector.metadata_registry import MetadataRegistry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -22,13 +27,19 @@ class RecipeRoutes:
"""API route handlers for Recipe management""" """API route handlers for Recipe management"""
def __init__(self): def __init__(self):
self.recipe_scanner = RecipeScanner(LoraScanner()) # Initialize service references as None, will be set during async init
self.civitai_client = CivitaiClient() self.recipe_scanner = None
self.parser = WorkflowParser() self.civitai_client = None
# Remove WorkflowParser instance
# Pre-warm the cache # Pre-warm the cache
self._init_cache_task = None self._init_cache_task = None
async def init_services(self):
"""Initialize services from ServiceRegistry"""
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
self.civitai_client = await ServiceRegistry.get_civitai_client()
@classmethod @classmethod
def setup_routes(cls, app: web.Application): def setup_routes(cls, app: web.Application):
"""Register API routes""" """Register API routes"""
@@ -47,15 +58,30 @@ class RecipeRoutes:
app.router.add_get('/api/recipe/{recipe_id}/share', routes.share_recipe) app.router.add_get('/api/recipe/{recipe_id}/share', routes.share_recipe)
app.router.add_get('/api/recipe/{recipe_id}/share/download', routes.download_shared_recipe) app.router.add_get('/api/recipe/{recipe_id}/share/download', routes.download_shared_recipe)
# Add new endpoint for getting recipe syntax
app.router.add_get('/api/recipe/{recipe_id}/syntax', routes.get_recipe_syntax)
# Add new endpoint for updating recipe metadata (name and tags)
app.router.add_put('/api/recipe/{recipe_id}/update', routes.update_recipe)
# Add new endpoint for reconnecting deleted LoRAs
app.router.add_post('/api/recipe/lora/reconnect', routes.reconnect_lora)
# Start cache initialization # Start cache initialization
app.on_startup.append(routes._init_cache) app.on_startup.append(routes._init_cache)
app.router.add_post('/api/recipes/save-from-widget', routes.save_recipe_from_widget) app.router.add_post('/api/recipes/save-from-widget', routes.save_recipe_from_widget)
# Add route to get recipes for a specific Lora
app.router.add_get('/api/recipes/for-lora', routes.get_recipes_for_lora)
async def _init_cache(self, app): async def _init_cache(self, app):
"""Initialize cache on startup""" """Initialize cache on startup"""
try: try:
# First, ensure the lora scanner is fully initialized # Initialize services first
await self.init_services()
# Now that services are initialized, get the lora scanner
lora_scanner = self.recipe_scanner._lora_scanner lora_scanner = self.recipe_scanner._lora_scanner
# Get lora cache to ensure it's initialized # Get lora cache to ensure it's initialized
@@ -73,6 +99,9 @@ class RecipeRoutes:
async def get_recipes(self, request: web.Request) -> web.Response: async def get_recipes(self, request: web.Request) -> web.Response:
"""API endpoint for getting paginated recipes""" """API endpoint for getting paginated recipes"""
try: try:
# Ensure services are initialized
await self.init_services()
# Get query parameters with defaults # Get query parameters with defaults
page = int(request.query.get('page', '1')) page = int(request.query.get('page', '1'))
page_size = int(request.query.get('page_size', '20')) page_size = int(request.query.get('page_size', '20'))
@@ -89,6 +118,9 @@ class RecipeRoutes:
base_models = request.query.get('base_models', None) base_models = request.query.get('base_models', None)
tags = request.query.get('tags', None) tags = request.query.get('tags', None)
# New parameter: get LoRA hash filter
lora_hash = request.query.get('lora_hash', None)
# Parse filter parameters # Parse filter parameters
filters = {} filters = {}
if base_models: if base_models:
@@ -104,14 +136,15 @@ class RecipeRoutes:
'lora_model': search_lora_model 'lora_model': search_lora_model
} }
# Get paginated data # Get paginated data with the new lora_hash parameter
result = await self.recipe_scanner.get_paginated_data( result = await self.recipe_scanner.get_paginated_data(
page=page, page=page,
page_size=page_size, page_size=page_size,
sort_by=sort_by, sort_by=sort_by,
search=search, search=search,
filters=filters, filters=filters,
search_options=search_options search_options=search_options,
lora_hash=lora_hash
) )
# Format the response data with static URLs for file paths # Format the response data with static URLs for file paths
@@ -138,21 +171,18 @@ class RecipeRoutes:
async def get_recipe_detail(self, request: web.Request) -> web.Response: async def get_recipe_detail(self, request: web.Request) -> web.Response:
"""Get detailed information about a specific recipe""" """Get detailed information about a specific recipe"""
try: try:
recipe_id = request.match_info['recipe_id'] # Ensure services are initialized
await self.init_services()
# Get all recipes from cache
cache = await self.recipe_scanner.get_cached_data()
# Find the specific recipe recipe_id = request.match_info['recipe_id']
recipe = next((r for r in cache.raw_data if str(r.get('id', '')) == recipe_id), None)
# Use the new get_recipe_by_id method from recipe_scanner
recipe = await self.recipe_scanner.get_recipe_by_id(recipe_id)
if not recipe: if not recipe:
return web.json_response({"error": "Recipe not found"}, status=404) return web.json_response({"error": "Recipe not found"}, status=404)
# Format recipe data return web.json_response(recipe)
formatted_recipe = self._format_recipe_data(recipe)
return web.json_response(formatted_recipe)
except Exception as e: except Exception as e:
logger.error(f"Error retrieving recipe details: {e}", exc_info=True) logger.error(f"Error retrieving recipe details: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500) return web.json_response({"error": str(e)}, status=500)
@@ -197,6 +227,9 @@ class RecipeRoutes:
"""Analyze an uploaded image or URL for recipe metadata""" """Analyze an uploaded image or URL for recipe metadata"""
temp_path = None temp_path = None
try: try:
# Ensure services are initialized
await self.init_services()
# Check if request contains multipart data (image) or JSON data (url) # Check if request contains multipart data (image) or JSON data (url)
content_type = request.headers.get('Content-Type', '') content_type = request.headers.get('Content-Type', '')
@@ -235,7 +268,7 @@ class RecipeRoutes:
}, status=400) }, status=400)
# Download image from URL # Download image from URL
temp_path = download_twitter_image(url) temp_path = download_civitai_image(url)
if not temp_path: if not temp_path:
return web.json_response({ return web.json_response({
@@ -244,10 +277,10 @@ class RecipeRoutes:
}, status=400) }, status=400)
# Extract metadata from the image using ExifUtils # Extract metadata from the image using ExifUtils
user_comment = ExifUtils.extract_user_comment(temp_path) metadata = ExifUtils.extract_image_metadata(temp_path)
# If no metadata found, return a more specific error # If no metadata found, return a more specific error
if not user_comment: if not metadata:
result = { result = {
"error": "No metadata found in this image", "error": "No metadata found in this image",
"loras": [] # Return empty loras array to prevent client-side errors "loras": [] # Return empty loras array to prevent client-side errors
@@ -262,7 +295,7 @@ class RecipeRoutes:
return web.json_response(result, status=200) return web.json_response(result, status=200)
# Use the parser factory to get the appropriate parser # Use the parser factory to get the appropriate parser
parser = RecipeParserFactory.create_parser(user_comment) parser = RecipeParserFactory.create_parser(metadata)
if parser is None: if parser is None:
result = { result = {
@@ -280,7 +313,7 @@ class RecipeRoutes:
# Parse the metadata # Parse the metadata
result = await parser.parse_metadata( result = await parser.parse_metadata(
user_comment, metadata,
recipe_scanner=self.recipe_scanner, recipe_scanner=self.recipe_scanner,
civitai_client=self.civitai_client civitai_client=self.civitai_client
) )
@@ -315,6 +348,9 @@ class RecipeRoutes:
async def save_recipe(self, request: web.Request) -> web.Response: async def save_recipe(self, request: web.Request) -> web.Response:
"""Save a recipe to the recipes folder""" """Save a recipe to the recipes folder"""
try: try:
# Ensure services are initialized
await self.init_services()
reader = await request.multipart() reader = await request.multipart()
# Process form data # Process form data
@@ -387,8 +423,7 @@ class RecipeRoutes:
return web.json_response({"error": f"Invalid base64 image data: {str(e)}"}, status=400) return web.json_response({"error": f"Invalid base64 image data: {str(e)}"}, status=400)
elif image_url: elif image_url:
# Download image from URL # Download image from URL
from ..utils.utils import download_twitter_image temp_path = download_civitai_image(image_url)
temp_path = download_twitter_image(image_url)
if not temp_path: if not temp_path:
return web.json_response({"error": "Failed to download image from URL"}, status=400) return web.json_response({"error": "Failed to download image from URL"}, status=400)
@@ -415,7 +450,7 @@ class RecipeRoutes:
# Optimize the image (resize and convert to WebP) # Optimize the image (resize and convert to WebP)
optimized_image, extension = ExifUtils.optimize_image( optimized_image, extension = ExifUtils.optimize_image(
image_data=image, image_data=image,
target_width=480, target_width=CARD_PREVIEW_WIDTH,
format='webp', format='webp',
quality=85, quality=85,
preserve_metadata=True preserve_metadata=True
@@ -433,19 +468,20 @@ class RecipeRoutes:
# Format loras data according to the recipe.json format # Format loras data according to the recipe.json format
loras_data = [] loras_data = []
for lora in metadata.get("loras", []): for lora in metadata.get("loras", []):
# Skip deleted LoRAs if they're marked to be excluded # Modified: Always include deleted LoRAs in the recipe metadata
if lora.get("isDeleted", False) and lora.get("exclude", False): # Even if they're marked to be excluded, we still keep their identifying information
continue # The exclude flag will only be used to determine if they should be included in recipe syntax
# Convert frontend lora format to recipe format # Convert frontend lora format to recipe format
lora_entry = { lora_entry = {
"file_name": lora.get("file_name", "") or os.path.splitext(os.path.basename(lora.get("localPath", "")))[0], "file_name": lora.get("file_name", "") or os.path.splitext(os.path.basename(lora.get("localPath", "")))[0] if lora.get("localPath") else "",
"hash": lora.get("hash", "").lower() if lora.get("hash") else "", "hash": lora.get("hash", "").lower() if lora.get("hash") else "",
"strength": float(lora.get("weight", 1.0)), "strength": float(lora.get("weight", 1.0)),
"modelVersionId": lora.get("id", ""), "modelVersionId": lora.get("id", ""),
"modelName": lora.get("name", ""), "modelName": lora.get("name", ""),
"modelVersionName": lora.get("version", ""), "modelVersionName": lora.get("version", ""),
"isDeleted": lora.get("isDeleted", False) # Preserve deletion status in saved recipe "isDeleted": lora.get("isDeleted", False), # Preserve deletion status in saved recipe
"exclude": lora.get("exclude", False) # Add exclude flag to the recipe
} }
loras_data.append(lora_entry) loras_data.append(lora_entry)
@@ -516,6 +552,9 @@ class RecipeRoutes:
async def delete_recipe(self, request: web.Request) -> web.Response: async def delete_recipe(self, request: web.Request) -> web.Response:
"""Delete a recipe by ID""" """Delete a recipe by ID"""
try: try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id'] recipe_id = request.match_info['recipe_id']
# Get recipes directory # Get recipes directory
@@ -563,6 +602,9 @@ class RecipeRoutes:
async def get_top_tags(self, request: web.Request) -> web.Response: async def get_top_tags(self, request: web.Request) -> web.Response:
"""Get top tags used in recipes""" """Get top tags used in recipes"""
try: try:
# Ensure services are initialized
await self.init_services()
# Get limit parameter with default # Get limit parameter with default
limit = int(request.query.get('limit', '20')) limit = int(request.query.get('limit', '20'))
@@ -595,6 +637,9 @@ class RecipeRoutes:
async def get_base_models(self, request: web.Request) -> web.Response: async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in recipes""" """Get base models used in recipes"""
try: try:
# Ensure services are initialized
await self.init_services()
# Get all recipes from cache # Get all recipes from cache
cache = await self.recipe_scanner.get_cached_data() cache = await self.recipe_scanner.get_cached_data()
@@ -617,12 +662,15 @@ class RecipeRoutes:
logger.error(f"Error retrieving base models: {e}", exc_info=True) logger.error(f"Error retrieving base models: {e}", exc_info=True)
return web.json_response({ return web.json_response({
'success': False, 'success': False,
'error': str(e) 'error': str(e)}
}, status=500) , status=500)
async def share_recipe(self, request: web.Request) -> web.Response: async def share_recipe(self, request: web.Request) -> web.Response:
"""Process a recipe image for sharing by adding metadata to EXIF""" """Process a recipe image for sharing by adding metadata to EXIF"""
try: try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id'] recipe_id = request.match_info['recipe_id']
# Get all recipes from cache # Get all recipes from cache
@@ -682,6 +730,9 @@ class RecipeRoutes:
async def download_shared_recipe(self, request: web.Request) -> web.Response: async def download_shared_recipe(self, request: web.Request) -> web.Response:
"""Serve a processed recipe image for download""" """Serve a processed recipe image for download"""
try: try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id'] recipe_id = request.match_info['recipe_id']
# Check if we have this shared recipe # Check if we have this shared recipe
@@ -738,50 +789,75 @@ class RecipeRoutes:
async def save_recipe_from_widget(self, request: web.Request) -> web.Response: async def save_recipe_from_widget(self, request: web.Request) -> web.Response:
"""Save a recipe from the LoRAs widget""" """Save a recipe from the LoRAs widget"""
try: try:
reader = await request.multipart() # Ensure services are initialized
await self.init_services()
# Process form data # Get metadata using the metadata collector instead of workflow parsing
workflow_json = None raw_metadata = get_metadata()
metadata_dict = MetadataProcessor.to_dict(raw_metadata)
while True: # Check if we have valid metadata
field = await reader.next() if not metadata_dict:
if field is None: return web.json_response({"error": "No generation metadata found"}, status=400)
break
# Get the most recent image from metadata registry instead of temp directory
metadata_registry = MetadataRegistry()
latest_image = metadata_registry.get_first_decoded_image()
if not latest_image:
return web.json_response({"error": "No recent images found to use for recipe. Try generating an image first."}, status=400)
# Convert the image data to bytes - handle tuple and tensor cases
logger.debug(f"Image type: {type(latest_image)}")
try:
# Handle the tuple case first
if isinstance(latest_image, tuple):
# Extract the tensor from the tuple
if len(latest_image) > 0:
tensor_image = latest_image[0]
else:
return web.json_response({"error": "Empty image tuple received"}, status=400)
else:
tensor_image = latest_image
if field.name == 'workflow_json': # Get the shape info for debugging
workflow_text = await field.text() if hasattr(tensor_image, 'shape'):
try: shape_info = tensor_image.shape
workflow_json = json.loads(workflow_text) logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
except:
return web.json_response({"error": "Invalid workflow JSON"}, status=400) # Convert tensor to numpy array
if isinstance(tensor_image, torch.Tensor):
image_np = tensor_image.cpu().numpy()
else:
image_np = np.array(tensor_image)
# Handle different tensor shapes
# Case: (1, 1, H, W, 3) or (1, H, W, 3) - batch or multi-batch
if len(image_np.shape) > 3:
# Remove batch dimensions until we get to (H, W, 3)
while len(image_np.shape) > 3:
image_np = image_np[0]
# If values are in [0, 1] range, convert to [0, 255]
if image_np.dtype == np.float32 or image_np.dtype == np.float64:
if image_np.max() <= 1.0:
image_np = (image_np * 255).astype(np.uint8)
# Ensure image is in the right format (HWC with RGB channels)
if len(image_np.shape) == 3 and image_np.shape[2] == 3:
pil_image = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
pil_image.save(img_byte_arr, format='PNG')
image = img_byte_arr.getvalue()
else:
return web.json_response({"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, status=400)
except Exception as e:
logger.error(f"Error processing image data: {str(e)}", exc_info=True)
return web.json_response({"error": f"Error processing image: {str(e)}"}, status=400)
if not workflow_json: # Get the lora stack from the metadata
return web.json_response({"error": "Missing required workflow_json field"}, status=400) lora_stack = metadata_dict.get("loras", "")
# Find the latest image in the temp directory
temp_dir = config.temp_directory
image_files = []
for file in os.listdir(temp_dir):
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.webp')):
file_path = os.path.join(temp_dir, file)
image_files.append((file_path, os.path.getmtime(file_path)))
if not image_files:
return web.json_response({"error": "No recent images found to use for recipe"}, status=400)
# Sort by modification time (newest first)
image_files.sort(key=lambda x: x[1], reverse=True)
latest_image_path = image_files[0][0]
# Parse the workflow to extract generation parameters and loras
parsed_workflow = self.parser.parse_workflow(workflow_json)
if not parsed_workflow or not parsed_workflow.get("gen_params"):
return web.json_response({"error": "Could not extract generation parameters from workflow"}, status=400)
# Get the lora stack from the parsed workflow
lora_stack = parsed_workflow.get("loras", "")
# Parse the lora stack format: "<lora:name:strength> <lora:name2:strength2> ..." # Parse the lora stack format: "<lora:name:strength> <lora:name2:strength2> ..."
import re import re
@@ -789,7 +865,7 @@ class RecipeRoutes:
# Check if any loras were found # Check if any loras were found
if not lora_matches: if not lora_matches:
return web.json_response({"error": "No LoRAs found in the workflow"}, status=400) return web.json_response({"error": "No LoRAs found in the generation metadata"}, status=400)
# Generate recipe name from the first 3 loras (or less if fewer are available) # Generate recipe name from the first 3 loras (or less if fewer are available)
loras_for_name = lora_matches[:3] # Take at most 3 loras for the name loras_for_name = lora_matches[:3] # Take at most 3 loras for the name
@@ -803,10 +879,6 @@ class RecipeRoutes:
recipe_name = " ".join(recipe_name_parts) recipe_name = " ".join(recipe_name_parts)
# Read the image
with open(latest_image_path, 'rb') as f:
image = f.read()
# Create recipes directory if it doesn't exist # Create recipes directory if it doesn't exist
recipes_dir = self.recipe_scanner.recipes_dir recipes_dir = self.recipe_scanner.recipes_dir
os.makedirs(recipes_dir, exist_ok=True) os.makedirs(recipes_dir, exist_ok=True)
@@ -818,7 +890,7 @@ class RecipeRoutes:
# Optimize the image (resize and convert to WebP) # Optimize the image (resize and convert to WebP)
optimized_image, extension = ExifUtils.optimize_image( optimized_image, extension = ExifUtils.optimize_image(
image_data=image, image_data=image,
target_width=480, target_width=CARD_PREVIEW_WIDTH,
format='webp', format='webp',
quality=85, quality=85,
preserve_metadata=True preserve_metadata=True
@@ -874,7 +946,9 @@ class RecipeRoutes:
"created_date": time.time(), "created_date": time.time(),
"base_model": most_common_base_model, "base_model": most_common_base_model,
"loras": loras_data, "loras": loras_data,
"gen_params": parsed_workflow.get("gen_params", {}), # Use the parsed workflow parameters "checkpoint": metadata_dict.get("checkpoint", ""),
"gen_params": {key: value for key, value in metadata_dict.items()
if key not in ['checkpoint', 'loras']},
"loras_stack": lora_stack # Include the original lora stack string "loras_stack": lora_stack # Include the original lora stack string
} }
@@ -906,3 +980,278 @@ class RecipeRoutes:
except Exception as e: except Exception as e:
logger.error(f"Error saving recipe from widget: {e}", exc_info=True) logger.error(f"Error saving recipe from widget: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500) return web.json_response({"error": str(e)}, status=500)
async def get_recipe_syntax(self, request: web.Request) -> web.Response:
"""Generate recipe syntax for LoRAs in the recipe, looking up proper file names using hash_index"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
# Get all recipes from cache
cache = await self.recipe_scanner.get_cached_data()
# Find the specific recipe
recipe = next((r for r in cache.raw_data if str(r.get('id', '')) == recipe_id), None)
if not recipe:
return web.json_response({"error": "Recipe not found"}, status=404)
# Get the loras from the recipe
loras = recipe.get('loras', [])
if not loras:
return web.json_response({"error": "No LoRAs found in this recipe"}, status=400)
# Generate recipe syntax for all LoRAs that:
# 1. Are in the library (not deleted) OR
# 2. Are deleted but not marked for exclusion
lora_syntax_parts = []
# Access the hash_index from lora_scanner
hash_index = self.recipe_scanner._lora_scanner._hash_index
for lora in loras:
# Skip loras that are deleted AND marked for exclusion
if lora.get("isDeleted", False):
continue
if not self.recipe_scanner._lora_scanner.has_lora_hash(lora.get("hash", "")):
continue
# Get the strength
strength = lora.get("strength", 1.0)
# Try to find the actual file name for this lora
file_name = None
hash_value = lora.get("hash", "").lower()
if hash_value and hasattr(hash_index, "_hash_to_path"):
# Look up the file path from the hash
file_path = hash_index._hash_to_path.get(hash_value)
if file_path:
# Extract the file name without extension from the path
file_name = os.path.splitext(os.path.basename(file_path))[0]
# If hash lookup failed, fall back to modelVersionId lookup
if not file_name and lora.get("modelVersionId"):
# Search for files with matching modelVersionId
all_loras = await self.recipe_scanner._lora_scanner.get_cached_data()
for cached_lora in all_loras.raw_data:
if not cached_lora.get("civitai"):
continue
if cached_lora.get("civitai", {}).get("id") == lora.get("modelVersionId"):
file_name = os.path.splitext(os.path.basename(cached_lora["path"]))[0]
break
# If all lookups failed, use the file_name from the recipe
if not file_name:
file_name = lora.get("file_name", "unknown-lora")
# Add to syntax parts
lora_syntax_parts.append(f"<lora:{file_name}:{strength}>")
# Join the LoRA syntax parts
lora_syntax = " ".join(lora_syntax_parts)
return web.json_response({
'success': True,
'syntax': lora_syntax
})
except Exception as e:
logger.error(f"Error generating recipe syntax: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def update_recipe(self, request: web.Request) -> web.Response:
"""Update recipe metadata (name and tags)"""
try:
# Ensure services are initialized
await self.init_services()
recipe_id = request.match_info['recipe_id']
data = await request.json()
# Validate required fields
if 'title' not in data and 'tags' not in data:
return web.json_response({
"error": "At least one field to update must be provided (title or tags)"
}, status=400)
# Use the recipe scanner's update method
success = await self.recipe_scanner.update_recipe_metadata(recipe_id, data)
if not success:
return web.json_response({"error": "Recipe not found or update failed"}, status=404)
return web.json_response({
"success": True,
"recipe_id": recipe_id,
"updates": data
})
except Exception as e:
logger.error(f"Error updating recipe: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def reconnect_lora(self, request: web.Request) -> web.Response:
"""Reconnect a deleted LoRA in a recipe to a local LoRA file"""
try:
# Ensure services are initialized
await self.init_services()
# Parse request data
data = await request.json()
# Validate required fields
required_fields = ['recipe_id', 'lora_data', 'target_name']
for field in required_fields:
if field not in data:
return web.json_response({
"error": f"Missing required field: {field}"
}, status=400)
recipe_id = data['recipe_id']
lora_data = data['lora_data']
target_name = data['target_name']
# Get recipe scanner
scanner = self.recipe_scanner
lora_scanner = scanner._lora_scanner
# Check if recipe exists
recipe_path = os.path.join(scanner.recipes_dir, f"{recipe_id}.recipe.json")
if not os.path.exists(recipe_path):
return web.json_response({"error": "Recipe not found"}, status=404)
# Find target LoRA by name
target_lora = await lora_scanner.get_lora_info_by_name(target_name)
if not target_lora:
return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404)
# Load recipe data
with open(recipe_path, 'r', encoding='utf-8') as f:
recipe_data = json.load(f)
# Find the deleted LoRA in the recipe
found = False
updated_lora = None
# Identification can be by hash, modelVersionId, or modelName
for i, lora in enumerate(recipe_data.get('loras', [])):
match_found = False
# Try to match by available identifiers
if 'hash' in lora and 'hash' in lora_data and lora['hash'] == lora_data['hash']:
match_found = True
elif 'modelVersionId' in lora and 'modelVersionId' in lora_data and lora['modelVersionId'] == lora_data['modelVersionId']:
match_found = True
elif 'modelName' in lora and 'modelName' in lora_data and lora['modelName'] == lora_data['modelName']:
match_found = True
if match_found:
# Update LoRA data
lora['isDeleted'] = False
lora['file_name'] = target_name
# Update with information from the target LoRA
if 'sha256' in target_lora:
lora['hash'] = target_lora['sha256'].lower()
if target_lora.get("civitai"):
lora['modelName'] = target_lora['civitai']['model']['name']
lora['modelVersionName'] = target_lora['civitai']['name']
lora['modelVersionId'] = target_lora['civitai']['id']
# Keep original fields for identification
# Mark as found and store updated lora
found = True
updated_lora = dict(lora) # Make a copy for response
break
if not found:
return web.json_response({"error": "Could not find matching deleted LoRA in recipe"}, status=404)
# Save updated recipe
with open(recipe_path, 'w', encoding='utf-8') as f:
json.dump(recipe_data, f, indent=4, ensure_ascii=False)
updated_lora['inLibrary'] = True
updated_lora['preview_url'] = target_lora['preview_url']
updated_lora['localPath'] = target_lora['file_path']
# Update in cache if it exists
if scanner._cache is not None:
for cache_item in scanner._cache.raw_data:
if cache_item.get('id') == recipe_id:
# Replace loras array with updated version
cache_item['loras'] = recipe_data['loras']
# Resort the cache
asyncio.create_task(scanner._cache.resort())
break
# Update EXIF metadata if image exists
image_path = recipe_data.get('file_path')
if image_path and os.path.exists(image_path):
from ..utils.exif_utils import ExifUtils
ExifUtils.append_recipe_metadata(image_path, recipe_data)
return web.json_response({
"success": True,
"recipe_id": recipe_id,
"updated_lora": updated_lora
})
except Exception as e:
logger.error(f"Error reconnecting LoRA: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
"""Get recipes that use a specific Lora"""
try:
# Ensure services are initialized
await self.init_services()
lora_hash = request.query.get('hash')
# Hash is required
if not lora_hash:
return web.json_response({'success': False, 'error': 'Lora hash is required'}, status=400)
# Log the search parameters
logger.debug(f"Getting recipes for Lora by hash: {lora_hash}")
# Get all recipes from cache
cache = await self.recipe_scanner.get_cached_data()
# Filter recipes that use this Lora by hash
matching_recipes = []
for recipe in cache.raw_data:
# Check if any of the recipe's loras match this hash
loras = recipe.get('loras', [])
for lora in loras:
if lora.get('hash', '').lower() == lora_hash.lower():
matching_recipes.append(recipe)
break # No need to check other loras in this recipe
# Process the recipes similar to get_paginated_data to ensure all needed data is available
for recipe in matching_recipes:
# Add inLibrary information for each lora
if 'loras' in recipe:
for lora in recipe['loras']:
if 'hash' in lora and lora['hash']:
lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_lora_hash(lora['hash'].lower())
lora['preview_url'] = self.recipe_scanner._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
lora['localPath'] = self.recipe_scanner._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
# Ensure file_url is set (needed by frontend)
if 'file_path' in recipe:
recipe['file_url'] = self._format_recipe_file_url(recipe['file_path'])
else:
recipe['file_url'] = '/loras_static/images/no-preview.png'
return web.json_response({'success': True, 'recipes': matching_recipes})
except Exception as e:
logger.error(f"Error getting recipes for Lora: {str(e)}")
return web.json_response({'success': False, 'error': str(e)}, status=500)

View File

@@ -0,0 +1,69 @@
import logging
from aiohttp import web
from ..utils.usage_stats import UsageStats
logger = logging.getLogger(__name__)
class UsageStatsRoutes:
"""Routes for handling usage statistics updates"""
@staticmethod
def setup_routes(app):
"""Register usage stats routes"""
app.router.add_post('/loras/api/update-usage-stats', UsageStatsRoutes.update_usage_stats)
app.router.add_get('/loras/api/get-usage-stats', UsageStatsRoutes.get_usage_stats)
@staticmethod
async def update_usage_stats(request):
"""
Update usage statistics based on a prompt_id
Expects a JSON body with:
{
"prompt_id": "string"
}
"""
try:
# Parse the request body
data = await request.json()
prompt_id = data.get('prompt_id')
if not prompt_id:
return web.json_response({
'success': False,
'error': 'Missing prompt_id'
}, status=400)
# Call the UsageStats to process this prompt_id synchronously
usage_stats = UsageStats()
await usage_stats.process_execution(prompt_id)
return web.json_response({
'success': True
})
except Exception as e:
logger.error(f"Failed to update usage stats: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def get_usage_stats(request):
"""Get current usage statistics"""
try:
usage_stats = UsageStats()
stats = await usage_stats.get_stats()
return web.json_response({
'success': True,
'data': stats
})
except Exception as e:
logger.error(f"Failed to get usage stats: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)

26
py/server_routes.py Normal file
View File

@@ -0,0 +1,26 @@
from aiohttp import web
from server import PromptServer
from .nodes.utils import get_lora_info
@PromptServer.instance.routes.post("/loramanager/get_trigger_words")
async def get_trigger_words(request):
json_data = await request.json()
lora_names = json_data.get("lora_names", [])
node_ids = json_data.get("node_ids", [])
all_trigger_words = []
for lora_name in lora_names:
_, trigger_words = await get_lora_info(lora_name)
all_trigger_words.extend(trigger_words)
# Format the trigger words
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Send update to all connected trigger word toggle nodes
for node_id in node_ids:
PromptServer.instance.send_sync("trigger_word_update", {
"id": node_id,
"message": trigger_words_text
})
return web.json_response({"success": True})

View File

@@ -0,0 +1,131 @@
import os
import logging
import asyncio
from typing import List, Dict, Optional, Set
import folder_paths # type: ignore
from ..utils.models import CheckpointMetadata
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class CheckpointScanner(ModelScanner):
"""Service for scanning and managing checkpoint files"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not hasattr(self, '_initialized'):
# Define supported file extensions
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
super().__init__(
model_type="checkpoint",
model_class=CheckpointMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex()
)
self._checkpoint_roots = self._init_checkpoint_roots()
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def _init_checkpoint_roots(self) -> List[str]:
"""Initialize checkpoint roots from ComfyUI settings"""
# Get both checkpoint and diffusion_models paths
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffusion_paths = folder_paths.get_folder_paths("diffusion_models")
# Combine, normalize and deduplicate paths
all_paths = set()
for path in checkpoint_paths + diffusion_paths:
if os.path.exists(path):
norm_path = path.replace(os.sep, "/")
all_paths.add(norm_path)
# Sort for consistent order
sorted_paths = sorted(all_paths, key=lambda p: p.lower())
return sorted_paths
def get_model_roots(self) -> List[str]:
"""Get checkpoint root directories"""
return self._checkpoint_roots
async def scan_all_models(self) -> List[Dict]:
"""Scan all checkpoint directories and return metadata"""
all_checkpoints = []
# Create scan tasks for each directory
scan_tasks = []
for root in self._checkpoint_roots:
task = asyncio.create_task(self._scan_directory(root))
scan_tasks.append(task)
# Wait for all tasks to complete
for task in scan_tasks:
try:
checkpoints = await task
all_checkpoints.extend(checkpoints)
except Exception as e:
logger.error(f"Error scanning checkpoint directory: {e}")
return all_checkpoints
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a directory for checkpoint files"""
checkpoints = []
original_root = root_path
async def scan_recursive(path: str, visited_paths: set):
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
# Check if file has supported extension
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, checkpoints)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return checkpoints
async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
"""Process a single checkpoint file and add to results"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
checkpoints.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")

View File

@@ -3,6 +3,7 @@ import aiohttp
import os import os
import json import json
import logging import logging
import asyncio
from email.parser import Parser from email.parser import Parser
from typing import Optional, Dict, Tuple, List from typing import Optional, Dict, Tuple, List
from urllib.parse import unquote from urllib.parse import unquote
@@ -11,20 +12,51 @@ from ..utils.models import LoraMetadata
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class CivitaiClient: class CivitaiClient:
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls):
"""Get singleton instance of CivitaiClient"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self): def __init__(self):
# Check if already initialized for singleton pattern
if hasattr(self, '_initialized'):
return
self._initialized = True
self.base_url = "https://civitai.com/api/v1" self.base_url = "https://civitai.com/api/v1"
self.headers = { self.headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0' 'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
} }
self._session = None self._session = None
# Set default buffer size to 1MB for higher throughput
self.chunk_size = 1024 * 1024
@property @property
async def session(self) -> aiohttp.ClientSession: async def session(self) -> aiohttp.ClientSession:
"""Lazy initialize the session""" """Lazy initialize the session"""
if self._session is None: if self._session is None:
connector = aiohttp.TCPConnector(ssl=True) # Optimize TCP connection parameters
trust_env = True # 允许使用系统环境变量中的代理设置 connector = aiohttp.TCPConnector(
self._session = aiohttp.ClientSession(connector=connector, trust_env=trust_env) ssl=True,
limit=10, # Increase parallel connections
ttl_dns_cache=300, # DNS cache time
force_close=False, # Keep connections for reuse
enable_cleanup_closed=True
)
trust_env = True # Allow using system environment proxy settings
# Configure timeout parameters
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
self._session = aiohttp.ClientSession(
connector=connector,
trust_env=trust_env,
timeout=timeout
)
return self._session return self._session
def _parse_content_disposition(self, header: str) -> str: def _parse_content_disposition(self, header: str) -> str:
@@ -74,12 +106,17 @@ class CivitaiClient:
session = await self.session session = await self.session
try: try:
headers = self._get_request_headers() headers = self._get_request_headers()
# Add Range header to allow resumable downloads
headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads
async with session.get(url, headers=headers, allow_redirects=True) as response: async with session.get(url, headers=headers, allow_redirects=True) as response:
if response.status != 200: if response.status != 200:
# Handle early access 401 unauthorized responses # Handle 401 unauthorized responses
if response.status == 401: if response.status == 401:
logger.warning(f"Unauthorized access to resource: {url} (Status 401)") logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
return False, "Early access restriction: You must purchase early access to download this LoRA."
return False, "Invalid or missing CivitAI API key, or early access restriction."
# Handle other client errors that might be permission-related # Handle other client errors that might be permission-related
if response.status == 403: if response.status == 403:
@@ -100,16 +137,23 @@ class CivitaiClient:
# Get total file size for progress calculation # Get total file size for progress calculation
total_size = int(response.headers.get('content-length', 0)) total_size = int(response.headers.get('content-length', 0))
current_size = 0 current_size = 0
last_progress_report_time = datetime.now()
# Stream download to file with progress updates # Stream download to file with progress updates using larger buffer
with open(save_path, 'wb') as f: with open(save_path, 'wb') as f:
async for chunk in response.content.iter_chunked(8192): async for chunk in response.content.iter_chunked(self.chunk_size):
if chunk: if chunk:
f.write(chunk) f.write(chunk)
current_size += len(chunk) current_size += len(chunk)
if progress_callback and total_size:
# Limit progress update frequency to reduce overhead
now = datetime.now()
time_diff = (now - last_progress_report_time).total_seconds()
if progress_callback and total_size and time_diff >= 0.5:
progress = (current_size / total_size) * 100 progress = (current_size / total_size) * 100
await progress_callback(progress) await progress_callback(progress)
last_progress_report_time = now
# Ensure 100% progress is reported # Ensure 100% progress is reported
if progress_callback: if progress_callback:
@@ -117,6 +161,9 @@ class CivitaiClient:
return True, save_path return True, save_path
except aiohttp.ClientError as e:
logger.error(f"Network error during download: {e}")
return False, f"Network error: {str(e)}"
except Exception as e: except Exception as e:
logger.error(f"Download error: {e}") logger.error(f"Download error: {e}")
return False, str(e) return False, str(e)
@@ -154,13 +201,26 @@ class CivitaiClient:
if response.status != 200: if response.status != 200:
return None return None
data = await response.json() data = await response.json()
return data.get('modelVersions', []) # Also return model type along with versions
return {
'modelVersions': data.get('modelVersions', []),
'type': data.get('type', '')
}
except Exception as e: except Exception as e:
logger.error(f"Error fetching model versions: {e}") logger.error(f"Error fetching model versions: {e}")
return None return None
async def get_model_version_info(self, version_id: str) -> Optional[Dict]: async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from Civitai""" """Fetch model version metadata from Civitai
Args:
version_id: The Civitai model version ID
Returns:
Tuple[Optional[Dict], Optional[str]]: A tuple containing:
- The model version data or None if not found
- An error message if there was an error, or None on success
"""
try: try:
session = await self.session session = await self.session
url = f"{self.base_url}/model-versions/{version_id}" url = f"{self.base_url}/model-versions/{version_id}"
@@ -168,11 +228,25 @@ class CivitaiClient:
async with session.get(url, headers=headers) as response: async with session.get(url, headers=headers) as response:
if response.status == 200: if response.status == 200:
return await response.json() return await response.json(), None
return None
# Handle specific error cases
if response.status == 404:
# Try to parse the error message
try:
error_data = await response.json()
error_msg = error_data.get('error', f"Model not found (status 404)")
logger.warning(f"Model version not found: {version_id} - {error_msg}")
return None, error_msg
except:
return None, "Model not found (status 404)"
# Other error cases
return None, f"Failed to fetch model info (status {response.status})"
except Exception as e: except Exception as e:
logger.error(f"Error fetching model version info: {e}") error_msg = f"Error fetching model version info: {e}"
return None logger.error(error_msg)
return None, error_msg
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]: async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
"""Fetch model metadata (description and tags) from Civitai API """Fetch model metadata (description and tags) from Civitai API
@@ -233,11 +307,9 @@ class CivitaiClient:
if not self._session: if not self._session:
return None return None
logger.info(f"Fetching model version info from Civitai for ID: {model_version_id}")
version_info = await self._session.get(f"{self.base_url}/model-versions/{model_version_id}") version_info = await self._session.get(f"{self.base_url}/model-versions/{model_version_id}")
if not version_info or not version_info.json().get('files'): if not version_info or not version_info.json().get('files'):
logger.warning(f"No files found in version info for ID: {model_version_id}")
return None return None
# Get hash from the first file # Get hash from the first file
@@ -247,8 +319,7 @@ class CivitaiClient:
hash_value = file_info['hashes']['SHA256'].lower() hash_value = file_info['hashes']['SHA256'].lower()
return hash_value return hash_value
logger.warning(f"No SHA256 hash found in version info for ID: {model_version_id}")
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error getting hash from Civitai: {e}") logger.error(f"Error getting hash from Civitai: {e}")
return None return None

View File

@@ -1,20 +1,79 @@
import logging import logging
import os import os
import json import json
from typing import Optional, Dict import asyncio
from typing import Optional, Dict, Any
from .civitai_client import CivitaiClient from .civitai_client import CivitaiClient
from .file_monitor import LoraFileMonitor from ..utils.models import LoraMetadata, CheckpointMetadata
from ..utils.models import LoraMetadata from ..utils.constants import CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils
from .service_registry import ServiceRegistry
# Download to temporary file first
import tempfile
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DownloadManager: class DownloadManager:
def __init__(self, file_monitor: Optional[LoraFileMonitor] = None): _instance = None
self.civitai_client = CivitaiClient() _lock = asyncio.Lock()
self.file_monitor = file_monitor
@classmethod
async def get_instance(cls):
"""Get singleton instance of DownloadManager"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
async def download_from_civitai(self, download_url: str, save_dir: str, relative_path: str = '', def __init__(self):
progress_callback=None) -> Dict: # Check if already initialized for singleton pattern
if hasattr(self, '_initialized'):
return
self._initialized = True
self._civitai_client = None # Will be lazily initialized
async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry"""
if self._civitai_client is None:
self._civitai_client = await ServiceRegistry.get_civitai_client()
return self._civitai_client
async def _get_lora_monitor(self):
"""Get the lora file monitor from registry"""
return await ServiceRegistry.get_lora_monitor()
async def _get_checkpoint_monitor(self):
"""Get the checkpoint file monitor from registry"""
return await ServiceRegistry.get_checkpoint_monitor()
async def _get_lora_scanner(self):
"""Get the lora scanner from registry"""
return await ServiceRegistry.get_lora_scanner()
async def _get_checkpoint_scanner(self):
"""Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner()
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
model_version_id: str = None, save_dir: str = None,
relative_path: str = '', progress_callback=None,
model_type: str = "lora") -> Dict:
"""Download model from Civitai
Args:
download_url: Direct download URL for the model
model_hash: SHA256 hash of the model
model_version_id: Civitai model version ID
save_dir: Directory to save the model to
relative_path: Relative path within save_dir
progress_callback: Callback function for progress updates
model_type: Type of model ('lora' or 'checkpoint')
Returns:
Dict with download result
"""
try: try:
# Update save directory with relative path if provided # Update save directory with relative path if provided
if relative_path: if relative_path:
@@ -22,26 +81,44 @@ class DownloadManager:
# Create directory if it doesn't exist # Create directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
# Get version info # Get civitai client
version_id = download_url.split('/')[-1] civitai_client = await self._get_civitai_client()
version_info = await self.civitai_client.get_model_version_info(version_id)
if not version_info:
return {'success': False, 'error': 'Failed to fetch model metadata'}
# Check if this is an early access LoRA # Get version info based on the provided identifier
if 'earlyAccessEndsAt' in version_info: version_info = None
error_msg = None
if download_url:
# Extract version ID from download URL
version_id = download_url.split('/')[-1]
version_info, error_msg = await civitai_client.get_model_version_info(version_id)
elif model_version_id:
# Use model version ID directly
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
elif model_hash:
# Get model by hash
version_info = await civitai_client.get_model_by_hash(model_hash)
if not version_info:
if error_msg and "model not found" in error_msg.lower():
return {'success': False, 'error': f'Model not found on Civitai: {error_msg}'}
return {'success': False, 'error': error_msg or 'Failed to fetch model metadata'}
# Check if this is an early access model
if version_info.get('earlyAccessEndsAt'):
early_access_date = version_info.get('earlyAccessEndsAt', '') early_access_date = version_info.get('earlyAccessEndsAt', '')
# Convert to a readable date if possible # Convert to a readable date if possible
try: try:
from datetime import datetime from datetime import datetime
date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00')) date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00'))
formatted_date = date_obj.strftime('%Y-%m-%d') formatted_date = date_obj.strftime('%Y-%m-%d')
early_access_msg = f"This LoRA requires early access payment (until {formatted_date}). " early_access_msg = f"This model requires early access payment (until {formatted_date}). "
except: except:
early_access_msg = "This LoRA requires early access payment. " early_access_msg = "This model requires early access payment. "
early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai." early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai."
logger.warning(f"Early access LoRA detected: {version_info.get('name', 'Unknown')}") logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}")
# We'll still try to download, but log a warning and prepare for potential failure # We'll still try to download, but log a warning and prepare for potential failure
if progress_callback: if progress_callback:
@@ -51,50 +128,51 @@ class DownloadManager:
if progress_callback: if progress_callback:
await progress_callback(0) await progress_callback(0)
# 2. 获取文件信息 # 2. Get file information
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
if not file_info: if not file_info:
return {'success': False, 'error': 'No primary file found in metadata'} return {'success': False, 'error': 'No primary file found in metadata'}
# 3. 准备下载 # 3. Prepare download
file_name = file_info['name'] file_name = file_info['name']
save_path = os.path.join(save_dir, file_name) save_path = os.path.join(save_dir, file_name)
file_size = file_info.get('sizeKB', 0) * 1024 file_size = file_info.get('sizeKB', 0) * 1024
# 4. 通知文件监控系统 - 使用规范化路径和文件大小 # 4. Notify file monitor - use normalized path and file size
if self.file_monitor and self.file_monitor.handler: file_monitor = await self._get_lora_monitor() if model_type == "lora" else await self._get_checkpoint_monitor()
# Add both the normalized path and potential alternative paths if file_monitor and file_monitor.handler:
normalized_path = save_path.replace(os.sep, '/') file_monitor.handler.add_ignore_path(
self.file_monitor.handler.add_ignore_path(normalized_path, file_size) save_path.replace(os.sep, '/'),
file_size
# Also add the path with file extension variations (.safetensors) )
if not normalized_path.endswith('.safetensors'):
safetensors_path = os.path.splitext(normalized_path)[0] + '.safetensors'
self.file_monitor.handler.add_ignore_path(safetensors_path, file_size)
logger.debug(f"Added download path to ignore list: {normalized_path} (size: {file_size} bytes)")
# 5. 准备元数据 # 5. Prepare metadata based on model type
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) if model_type == "checkpoint":
metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating CheckpointMetadata for {file_name}")
else:
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating LoraMetadata for {file_name}")
# 5.1 获取并更新模型标签和描述信息 # 5.1 Get and update model tags and description
model_id = version_info.get('modelId') model_id = version_info.get('modelId')
if model_id: if model_id:
model_metadata, _ = await self.civitai_client.get_model_metadata(str(model_id)) model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
if model_metadata: if model_metadata:
if model_metadata.get("tags"): if model_metadata.get("tags"):
metadata.tags = model_metadata.get("tags", []) metadata.tags = model_metadata.get("tags", [])
if model_metadata.get("description"): if model_metadata.get("description"):
metadata.modelDescription = model_metadata.get("description", "") metadata.modelDescription = model_metadata.get("description", "")
# 6. 开始下载流程 # 6. Start download process
result = await self._execute_download( result = await self._execute_download(
download_url=download_url, download_url=file_info.get('downloadUrl', ''),
save_dir=save_dir, save_dir=save_dir,
metadata=metadata, metadata=metadata,
version_info=version_info, version_info=version_info,
relative_path=relative_path, relative_path=relative_path,
progress_callback=progress_callback progress_callback=progress_callback,
model_type=model_type
) )
return result return result
@@ -108,10 +186,12 @@ class DownloadManager:
return {'success': False, 'error': str(e)} return {'success': False, 'error': str(e)}
async def _execute_download(self, download_url: str, save_dir: str, async def _execute_download(self, download_url: str, save_dir: str,
metadata: LoraMetadata, version_info: Dict, metadata, version_info: Dict,
relative_path: str, progress_callback=None) -> Dict: relative_path: str, progress_callback=None,
model_type: str = "lora") -> Dict:
"""Execute the actual download process including preview images and model files""" """Execute the actual download process including preview images and model files"""
try: try:
civitai_client = await self._get_civitai_client()
save_path = metadata.file_path save_path = metadata.file_path
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json' metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
@@ -122,20 +202,61 @@ class DownloadManager:
if progress_callback: if progress_callback:
await progress_callback(1) # 1% progress for starting preview download await progress_callback(1) # 1% progress for starting preview download
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png' # Check if it's a video or an image
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext is_video = images[0].get('type') == 'video'
if await self.civitai_client.download_preview_image(images[0]['url'], preview_path):
metadata.preview_url = preview_path.replace(os.sep, '/') if (is_video):
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) # For videos, use .mp4 extension
with open(metadata_path, 'w', encoding='utf-8') as f: preview_ext = '.mp4'
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) preview_path = os.path.splitext(save_path)[0] + preview_ext
# Download video directly
if await civitai_client.download_preview_image(images[0]['url'], preview_path):
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
else:
# For images, use WebP format for better performance
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
temp_path = temp_file.name
# Download the original image to temp path
if await civitai_client.download_preview_image(images[0]['url'], temp_path):
# Optimize and convert to WebP
preview_path = os.path.splitext(save_path)[0] + '.webp'
# Use ExifUtils to optimize and convert the image
optimized_data, _ = ExifUtils.optimize_image(
image_data=temp_path,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=False
)
# Save the optimized image
with open(preview_path, 'wb') as f:
f.write(optimized_data)
# Update metadata
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# Remove temporary file
try:
os.unlink(temp_path)
except Exception as e:
logger.warning(f"Failed to delete temp file: {e}")
# Report preview download completion # Report preview download completion
if progress_callback: if progress_callback:
await progress_callback(3) # 3% progress after preview download await progress_callback(3) # 3% progress after preview download
# Download model file with progress tracking # Download model file with progress tracking
success, result = await self.civitai_client._download_file( success, result = await civitai_client._download_file(
download_url, download_url,
save_dir, save_dir,
os.path.basename(save_path), os.path.basename(save_path),
@@ -149,15 +270,22 @@ class DownloadManager:
os.remove(path) os.remove(path)
return {'success': False, 'error': result} return {'success': False, 'error': result}
# 4. 更新文件信息(大小和修改时间) # 4. Update file information (size and modified time)
metadata.update_file_info(save_path) metadata.update_file_info(save_path)
# 5. 最终更新元数据 # 5. Final metadata update
with open(metadata_path, 'w', encoding='utf-8') as f: with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False) json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# 6. update lora cache # 6. Update cache based on model type
cache = await self.file_monitor.scanner.get_cached_data() if model_type == "checkpoint":
scanner = await self._get_checkpoint_scanner()
logger.info(f"Updating checkpoint cache for {save_path}")
else:
scanner = await self._get_lora_scanner()
logger.info(f"Updating lora cache for {save_path}")
cache = await scanner.get_cached_data()
metadata_dict = metadata.to_dict() metadata_dict = metadata.to_dict()
metadata_dict['folder'] = relative_path metadata_dict['folder'] = relative_path
cache.raw_data.append(metadata_dict) cache.raw_data.append(metadata_dict)
@@ -166,11 +294,8 @@ class DownloadManager:
all_folders.add(relative_path) all_folders.add(relative_path)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Update the hash index with the new LoRA entry # Update the hash index with the new model entry
self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path']) scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
# Update the hash index with the new LoRA entry
self.file_monitor.scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
# Report 100% completion # Report 100% completion
if progress_callback: if progress_callback:

View File

@@ -1,114 +1,202 @@
from operator import itemgetter
import os import os
import logging import logging
import asyncio import asyncio
import time
from watchdog.observers import Observer from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDeletedEvent from watchdog.events import FileSystemEventHandler
from typing import List from typing import List, Dict, Set, Optional
from threading import Lock from threading import Lock
from .lora_scanner import LoraScanner
from ..config import config from ..config import config
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LoraFileHandler(FileSystemEventHandler): # Configuration constant to control file monitoring functionality
"""Handler for LoRA file system events""" ENABLE_FILE_MONITORING = False
class BaseFileHandler(FileSystemEventHandler):
"""Base handler for file system events"""
def __init__(self, scanner: LoraScanner, loop: asyncio.AbstractEventLoop): def __init__(self, loop: asyncio.AbstractEventLoop):
self.scanner = scanner self.loop = loop # Store event loop reference
self.loop = loop # 存储事件循环引用 self.pending_changes = set() # Pending changes
self.pending_changes = set() # 待处理的变更 self.lock = Lock() # Thread-safe lock
self.lock = Lock() # 线程安全锁 self.update_task = None # Async update task
self.update_task = None # 异步更新任务 self._ignore_paths = set() # Paths to ignore
self._ignore_paths = {} # Change to dictionary to store expiration times self._min_ignore_timeout = 5 # Minimum timeout in seconds
self._min_ignore_timeout = 5 # minimum timeout in seconds self._download_speed = 1024 * 1024 # Assume 1MB/s as base speed
self._download_speed = 1024 * 1024 # assume 1MB/s as base speed
# Track modified files with timestamps for debouncing
self.modified_files: Dict[str, float] = {}
self.debounce_timer = None
self.debounce_delay = 3.0 # Seconds to wait after last modification
# Track files already scheduled for processing
self.scheduled_files: Set[str] = set()
# File extensions to monitor - should be overridden by subclasses
self.file_extensions = set()
def _should_ignore(self, path: str) -> bool: def _should_ignore(self, path: str) -> bool:
"""Check if path should be ignored""" """Check if path should be ignored"""
real_path = os.path.realpath(path) # Resolve any symbolic links real_path = os.path.realpath(path) # Resolve any symbolic links
normalized_path = real_path.replace(os.sep, '/') return real_path.replace(os.sep, '/') in self._ignore_paths
# Also check with backslashes for Windows compatibility
alt_path = real_path.replace('/', '\\')
# 使用传入的事件循环而不是尝试获取当前线程的事件循环
current_time = self.loop.time()
# Check if path is in ignore list and not expired
if normalized_path in self._ignore_paths and self._ignore_paths[normalized_path] > current_time:
return True
# Also check alternative path format
if alt_path in self._ignore_paths and self._ignore_paths[alt_path] > current_time:
return True
return False
def add_ignore_path(self, path: str, file_size: int = 0): def add_ignore_path(self, path: str, file_size: int = 0):
"""Add path to ignore list with dynamic timeout based on file size""" """Add path to ignore list with dynamic timeout based on file size"""
real_path = os.path.realpath(path) # Resolve any symbolic links real_path = os.path.realpath(path) # Resolve any symbolic links
normalized_path = real_path.replace(os.sep, '/') self._ignore_paths.add(real_path.replace(os.sep, '/'))
# Calculate timeout based on file size # Short timeout (e.g. 5 seconds) is sufficient to ignore the CREATE event
# For small files, use minimum timeout timeout = 5
# For larger files, estimate download time + buffer
if file_size > 0:
# Estimate download time in seconds (size / speed) + buffer
estimated_time = (file_size / self._download_speed) + 10
timeout = max(self._min_ignore_timeout, estimated_time)
else:
timeout = self._min_ignore_timeout
current_time = self.loop.time()
expiration_time = current_time + timeout
# Store both normalized and alternative path formats
self._ignore_paths[normalized_path] = expiration_time
# Also store with backslashes for Windows compatibility
alt_path = real_path.replace('/', '\\')
self._ignore_paths[alt_path] = expiration_time
logger.debug(f"Added ignore path: {normalized_path} (expires in {timeout:.1f}s)")
self.loop.call_later( self.loop.call_later(
timeout, timeout,
self._remove_ignore_path, self._ignore_paths.discard,
normalized_path real_path.replace(os.sep, '/')
) )
def _remove_ignore_path(self, path: str):
"""Remove path from ignore list after timeout"""
if path in self._ignore_paths:
del self._ignore_paths[path]
logger.debug(f"Removed ignore path: {path}")
# Also remove alternative path format
alt_path = path.replace('/', '\\')
if alt_path in self._ignore_paths:
del self._ignore_paths[alt_path]
def on_created(self, event): def on_created(self, event):
if event.is_directory or not event.src_path.endswith('.safetensors'): if event.is_directory:
return return
if self._should_ignore(event.src_path):
# Handle appropriate files based on extensions
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
# Process this file directly and ignore subsequent modifications
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
if normalized_path not in self.scheduled_files:
logger.info(f"File created: {event.src_path}")
self.scheduled_files.add(normalized_path)
self._schedule_update('add', event.src_path)
# Ignore modifications for a short period after creation
self.loop.call_later(
self.debounce_delay * 2,
self.scheduled_files.discard,
normalized_path
)
def on_modified(self, event):
if event.is_directory:
return return
logger.info(f"LoRA file created: {event.src_path}")
self._schedule_update('add', event.src_path) # Only process files with supported extensions
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
# Skip if this file is already scheduled for processing
if normalized_path in self.scheduled_files:
return
# Update the timestamp for this file
self.modified_files[normalized_path] = time.time()
# Cancel any existing timer
if self.debounce_timer:
self.debounce_timer.cancel()
# Set a new timer to process modified files after debounce period
self.debounce_timer = self.loop.call_later(
self.debounce_delay,
self.loop.call_soon_threadsafe,
self._process_modified_files
)
def _process_modified_files(self):
"""Process files that have been modified after debounce period"""
current_time = time.time()
files_to_process = []
# Find files that haven't been modified for debounce_delay seconds
for file_path, last_modified in list(self.modified_files.items()):
if current_time - last_modified >= self.debounce_delay:
# Only process if not already scheduled
if file_path not in self.scheduled_files:
files_to_process.append(file_path)
self.scheduled_files.add(file_path)
# Auto-remove from scheduled list after reasonable time
self.loop.call_later(
self.debounce_delay * 2,
self.scheduled_files.discard,
file_path
)
del self.modified_files[file_path]
# Process stable files
for file_path in files_to_process:
logger.info(f"Processing modified file: {file_path}")
self._schedule_update('add', file_path)
def on_deleted(self, event): def on_deleted(self, event):
if event.is_directory or not event.src_path.endswith('.safetensors'): if event.is_directory:
return return
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext not in self.file_extensions:
return
if self._should_ignore(event.src_path): if self._should_ignore(event.src_path):
return return
logger.info(f"LoRA file deleted: {event.src_path}")
# Remove from scheduled files if present
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
self.scheduled_files.discard(normalized_path)
logger.info(f"File deleted: {event.src_path}")
self._schedule_update('remove', event.src_path) self._schedule_update('remove', event.src_path)
def _schedule_update(self, action: str, file_path: str): #file_path is a real path def on_moved(self, event):
"""Handle file move/rename events"""
src_ext = os.path.splitext(event.src_path)[1].lower()
dest_ext = os.path.splitext(event.dest_path)[1].lower()
# If destination has supported extension, treat as new file
if dest_ext in self.file_extensions:
if self._should_ignore(event.dest_path):
return
normalized_path = os.path.realpath(event.dest_path).replace(os.sep, '/')
# Only process if not already scheduled
if normalized_path not in self.scheduled_files:
logger.info(f"File renamed/moved to: {event.dest_path}")
self.scheduled_files.add(normalized_path)
self._schedule_update('add', event.dest_path)
# Auto-remove from scheduled list after reasonable time
self.loop.call_later(
self.debounce_delay * 2,
self.scheduled_files.discard,
normalized_path
)
# If source was a supported file, treat it as deleted
if src_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
self.scheduled_files.discard(normalized_path)
logger.info(f"File moved/renamed from: {event.src_path}")
self._schedule_update('remove', event.src_path)
def _schedule_update(self, action: str, file_path: str):
"""Schedule a cache update""" """Schedule a cache update"""
with self.lock: with self.lock:
# 使用 config 中的方法映射路径 # Use config method to map path
mapped_path = config.map_path_to_link(file_path) mapped_path = config.map_path_to_link(file_path)
normalized_path = mapped_path.replace(os.sep, '/') normalized_path = mapped_path.replace(os.sep, '/')
self.pending_changes.add((action, normalized_path)) self.pending_changes.add((action, normalized_path))
@@ -119,7 +207,20 @@ class LoraFileHandler(FileSystemEventHandler):
"""Create update task in the event loop""" """Create update task in the event loop"""
if self.update_task is None or self.update_task.done(): if self.update_task is None or self.update_task.done():
self.update_task = asyncio.create_task(self._process_changes()) self.update_task = asyncio.create_task(self._process_changes())
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing - should be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement _process_changes")
class LoraFileHandler(BaseFileHandler):
"""Handler for LoRA file system events"""
def __init__(self, loop: asyncio.AbstractEventLoop):
super().__init__(loop)
# Set supported file extensions for LoRAs
self.file_extensions = {'.safetensors'}
async def _process_changes(self, delay: float = 2.0): async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing""" """Process pending changes with debouncing"""
await asyncio.sleep(delay) await asyncio.sleep(delay)
@@ -132,46 +233,54 @@ class LoraFileHandler(FileSystemEventHandler):
if not changes: if not changes:
return return
logger.info(f"Processing {len(changes)} file changes") logger.info(f"Processing {len(changes)} LoRA file changes")
cache = await self.scanner.get_cached_data() # Get scanner through ServiceRegistry
scanner = await ServiceRegistry.get_lora_scanner()
cache = await scanner.get_cached_data()
needs_resort = False needs_resort = False
new_folders = set() new_folders = set()
for action, file_path in changes: for action, file_path in changes:
try: try:
if action == 'add': if action == 'add':
# Scan new file # Check if file already exists in cache
lora_data = await self.scanner.scan_single_lora(file_path) existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if lora_data: if existing:
# Update tags count logger.info(f"File {file_path} already in cache, skipping")
for tag in lora_data.get('tags', []): continue
self.scanner._tags_count[tag] = self.scanner._tags_count.get(tag, 0) + 1
cache.raw_data.append(lora_data) # Scan new file
new_folders.add(lora_data['folder']) model_data = await scanner.scan_single_model(file_path)
if model_data:
# Update tags count
for tag in model_data.get('tags', []):
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
cache.raw_data.append(model_data)
new_folders.add(model_data['folder'])
# Update hash index # Update hash index
if 'sha256' in lora_data: if 'sha256' in model_data:
self.scanner._hash_index.add_entry( scanner._hash_index.add_entry(
lora_data['sha256'], model_data['sha256'],
lora_data['file_path'] model_data['file_path']
) )
needs_resort = True needs_resort = True
elif action == 'remove': elif action == 'remove':
# Find the lora to remove so we can update tags count # Find the model to remove so we can update tags count
lora_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None) model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if lora_to_remove: if model_to_remove:
# Update tags count by reducing counts # Update tags count by reducing counts
for tag in lora_to_remove.get('tags', []): for tag in model_to_remove.get('tags', []):
if tag in self.scanner._tags_count: if tag in scanner._tags_count:
self.scanner._tags_count[tag] = max(0, self.scanner._tags_count[tag] - 1) scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
if self.scanner._tags_count[tag] == 0: if scanner._tags_count[tag] == 0:
del self.scanner._tags_count[tag] del scanner._tags_count[tag]
# Remove from cache and hash index # Remove from cache and hash index
logger.info(f"Removing {file_path} from cache") logger.info(f"Removing {file_path} from cache")
self.scanner._hash_index.remove_by_path(file_path) scanner._hash_index.remove_by_path(file_path)
cache.raw_data = [ cache.raw_data = [
item for item in cache.raw_data item for item in cache.raw_data
if item['file_path'] != file_path if item['file_path'] != file_path
@@ -189,62 +298,245 @@ class LoraFileHandler(FileSystemEventHandler):
cache.folders = sorted(list(all_folders), key=lambda x: x.lower()) cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
except Exception as e: except Exception as e:
logger.error(f"Error in process_changes: {e}") logger.error(f"Error in process_changes for LoRA: {e}")
class LoraFileMonitor: class CheckpointFileHandler(BaseFileHandler):
"""Monitor for LoRA file changes""" """Handler for checkpoint file system events"""
def __init__(self, scanner: LoraScanner, roots: List[str]): def __init__(self, loop: asyncio.AbstractEventLoop):
self.scanner = scanner super().__init__(loop)
scanner.set_file_monitor(self) # Set supported file extensions for checkpoints
self.file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing for checkpoint files"""
await asyncio.sleep(delay)
try:
with self.lock:
changes = self.pending_changes.copy()
self.pending_changes.clear()
if not changes:
return
logger.info(f"Processing {len(changes)} checkpoint file changes")
# Get scanner through ServiceRegistry
scanner = await ServiceRegistry.get_checkpoint_scanner()
cache = await scanner.get_cached_data()
needs_resort = False
new_folders = set()
for action, file_path in changes:
try:
if action == 'add':
# Check if file already exists in cache
existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if existing:
logger.info(f"File {file_path} already in cache, skipping")
continue
# Scan new file
model_data = await scanner.scan_single_model(file_path)
if model_data:
# Update tags count if applicable
for tag in model_data.get('tags', []):
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
cache.raw_data.append(model_data)
new_folders.add(model_data['folder'])
# Update hash index
if 'sha256' in model_data:
scanner._hash_index.add_entry(
model_data['sha256'],
model_data['file_path']
)
needs_resort = True
elif action == 'remove':
# Find the model to remove so we can update tags count
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if model_to_remove:
# Update tags count by reducing counts
for tag in model_to_remove.get('tags', []):
if tag in scanner._tags_count:
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
if scanner._tags_count[tag] == 0:
del scanner._tags_count[tag]
# Remove from cache and hash index
logger.info(f"Removing {file_path} from checkpoint cache")
scanner._hash_index.remove_by_path(file_path)
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != file_path
]
needs_resort = True
except Exception as e:
logger.error(f"Error processing checkpoint {action} for {file_path}: {e}")
if needs_resort:
await cache.resort()
# Update folder list
all_folders = set(cache.folders) | new_folders
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
except Exception as e:
logger.error(f"Error in process_changes for checkpoint: {e}")
class BaseFileMonitor:
"""Base class for file monitoring"""
def __init__(self, monitor_paths: List[str]):
self.observer = Observer() self.observer = Observer()
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self.handler = LoraFileHandler(scanner, self.loop)
# 使用已存在的路径映射
self.monitor_paths = set() self.monitor_paths = set()
for root in roots:
self.monitor_paths.add(os.path.realpath(root).replace(os.sep, '/')) # Process monitor paths
for path in monitor_paths:
self.monitor_paths.add(os.path.realpath(path).replace(os.sep, '/'))
# 添加所有已映射的目标路径 # Add mapped paths from config
for target_path in config._path_mappings.keys(): for target_path in config._path_mappings.keys():
self.monitor_paths.add(target_path) self.monitor_paths.add(target_path)
def start(self): def start(self):
"""Start monitoring""" """Start file monitoring"""
for path_info in self.monitor_paths: if not ENABLE_FILE_MONITORING:
logger.debug("File monitoring is disabled via ENABLE_FILE_MONITORING setting")
return
for path in self.monitor_paths:
try: try:
if isinstance(path_info, tuple): self.observer.schedule(self.handler, path, recursive=True)
# 对于链接,监控目标路径 logger.info(f"Started monitoring: {path}")
_, target_path = path_info
self.observer.schedule(self.handler, target_path, recursive=True)
logger.info(f"Started monitoring target path: {target_path}")
else:
# 对于普通路径,直接监控
self.observer.schedule(self.handler, path_info, recursive=True)
logger.info(f"Started monitoring: {path_info}")
except Exception as e: except Exception as e:
logger.error(f"Error monitoring {path_info}: {e}") logger.error(f"Error monitoring {path}: {e}")
self.observer.start() self.observer.start()
def stop(self): def stop(self):
"""Stop monitoring""" """Stop file monitoring"""
if not ENABLE_FILE_MONITORING:
return
self.observer.stop() self.observer.stop()
self.observer.join() self.observer.join()
def rescan_links(self): def rescan_links(self):
"""重新扫描链接(当添加新的链接时调用)""" """Rescan links when new ones are added"""
if not ENABLE_FILE_MONITORING:
return
# Find new paths not yet being monitored
new_paths = set() new_paths = set()
for path in self.monitor_paths.copy(): for path in config._path_mappings.keys():
self._add_link_targets(path) real_path = os.path.realpath(path).replace(os.sep, '/')
if real_path not in self.monitor_paths:
new_paths.add(real_path)
self.monitor_paths.add(real_path)
# 添加新发现的路径到监控 # Add new paths to monitoring
new_paths = self.monitor_paths - set(self.observer.watches.keys())
for path in new_paths: for path in new_paths:
try: try:
self.observer.schedule(self.handler, path, recursive=True) self.observer.schedule(self.handler, path, recursive=True)
logger.info(f"Added new monitoring path: {path}") logger.info(f"Added new monitoring path: {path}")
except Exception as e: except Exception as e:
logger.error(f"Error adding new monitor for {path}: {e}") logger.error(f"Error adding new monitor for {path}: {e}")
class LoraFileMonitor(BaseFileMonitor):
"""Monitor for LoRA file changes"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls, monitor_paths=None):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, monitor_paths=None):
if not hasattr(self, '_initialized'):
if monitor_paths is None:
from ..config import config
monitor_paths = config.loras_roots
super().__init__(monitor_paths)
self.handler = LoraFileHandler(self.loop)
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
from ..config import config
cls._instance = cls(config.loras_roots)
return cls._instance
class CheckpointFileMonitor(BaseFileMonitor):
"""Monitor for checkpoint file changes"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls, monitor_paths=None):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, monitor_paths=None):
if not hasattr(self, '_initialized'):
if monitor_paths is None:
# Get checkpoint roots from scanner
monitor_paths = []
# We'll initialize monitor paths later when scanner is available
super().__init__(monitor_paths or [])
self.handler = CheckpointFileHandler(self.loop)
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls([])
# Now get checkpoint roots from scanner
from .checkpoint_scanner import CheckpointScanner
scanner = await CheckpointScanner.get_instance()
monitor_paths = scanner.get_model_roots()
# Update monitor paths - but don't actually monitor them
for path in monitor_paths:
real_path = os.path.realpath(path).replace(os.sep, '/')
cls._instance.monitor_paths.add(real_path)
return cls._instance
def start(self):
"""Override start to check global enable flag"""
if not ENABLE_FILE_MONITORING:
logger.debug("Checkpoint file monitoring is disabled via ENABLE_FILE_MONITORING setting")
return
logger.debug("Checkpoint file monitoring is temporarily disabled")
# Skip the actual monitoring setup
pass
async def initialize_paths(self):
"""Initialize monitor paths from scanner - currently disabled"""
if not ENABLE_FILE_MONITORING:
logger.debug("Checkpoint path initialization skipped (monitoring disabled)")
return
logger.debug("Checkpoint file path initialization skipped (monitoring disabled)")
pass

View File

@@ -3,21 +3,22 @@ import os
import logging import logging
import asyncio import asyncio
import shutil import shutil
from typing import List, Dict, Optional import time
from dataclasses import dataclass from typing import List, Dict, Optional, Set
from operator import itemgetter
from ..utils.models import LoraMetadata
from ..config import config from ..config import config
from ..utils.file_utils import load_metadata, get_file_info from .model_scanner import ModelScanner
from .lora_cache import LoraCache from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
from .lora_hash_index import LoraHashIndex
from .settings_manager import settings from .settings_manager import settings
from ..utils.constants import NSFW_LEVELS from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match from ..utils.utils import fuzzy_match
from .service_registry import ServiceRegistry
import sys import sys
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LoraScanner: class LoraScanner(ModelScanner):
"""Service for scanning and managing LoRA files""" """Service for scanning and managing LoRA files"""
_instance = None _instance = None
@@ -29,20 +30,20 @@ class LoraScanner:
return cls._instance return cls._instance
def __init__(self): def __init__(self):
# 确保初始化只执行一次 # Ensure initialization happens only once
if not hasattr(self, '_initialized'): if not hasattr(self, '_initialized'):
self._cache: Optional[LoraCache] = None # Define supported file extensions
self._hash_index = LoraHashIndex() file_extensions = {'.safetensors'}
self._initialization_lock = asyncio.Lock()
self._initialization_task: Optional[asyncio.Task] = None # Initialize parent class with ModelHashIndex
super().__init__(
model_type="lora",
model_class=LoraMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
)
self._initialized = True self._initialized = True
self.file_monitor = None # Add this line
self._tags_count = {} # Add a dictionary to store tag counts
def set_file_monitor(self, monitor):
"""Set file monitor instance"""
self.file_monitor = monitor
@classmethod @classmethod
async def get_instance(cls): async def get_instance(cls):
"""Get singleton instance with async support""" """Get singleton instance with async support"""
@@ -50,92 +51,78 @@ class LoraScanner:
if cls._instance is None: if cls._instance is None:
cls._instance = cls() cls._instance = cls()
return cls._instance return cls._instance
async def get_cached_data(self, force_refresh: bool = False) -> LoraCache: def get_model_roots(self) -> List[str]:
"""Get cached LoRA data, refresh if needed""" """Get lora root directories"""
async with self._initialization_lock: return config.loras_roots
async def scan_all_models(self) -> List[Dict]:
"""Scan all LoRA directories and return metadata"""
all_loras = []
# Create scan tasks for each directory
scan_tasks = []
for lora_root in self.get_model_roots():
task = asyncio.create_task(self._scan_directory(lora_root))
scan_tasks.append(task)
# 如果缓存未初始化但需要响应请求,返回空缓存 # Wait for all tasks to complete
if self._cache is None and not force_refresh: for task in scan_tasks:
return LoraCache( try:
raw_data=[], loras = await task
sorted_by_name=[], all_loras.extend(loras)
sorted_by_date=[], except Exception as e:
folders=[] logger.error(f"Error scanning directory: {e}")
)
# 如果正在初始化,等待完成
if self._initialization_task and not self._initialization_task.done():
try:
await self._initialization_task
except Exception as e:
logger.error(f"Cache initialization failed: {e}")
self._initialization_task = None
if (self._cache is None or force_refresh):
# 创建新的初始化任务 return all_loras
if not self._initialization_task or self._initialization_task.done():
self._initialization_task = asyncio.create_task(self._initialize_cache()) async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for LoRA files"""
loras = []
original_root = root_path # Save original root path
async def scan_recursive(path: str, visited_paths: set):
"""Recursively scan directory, avoiding circular symlinks"""
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
try: with os.scandir(path) as it:
await self._initialization_task entries = list(it)
except Exception as e: for entry in entries:
logger.error(f"Cache initialization failed: {e}") try:
# 如果缓存已存在,继续使用旧缓存 if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
if self._cache is None: # Use original path instead of real path
raise # 如果没有缓存,则抛出异常 file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, loras)
return self._cache await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
async def _initialize_cache(self) -> None: await scan_recursive(root_path, set())
"""Initialize or refresh the cache""" return loras
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
"""Process a single file and add to results list"""
try: try:
# Clear existing hash index result = await self._process_model_file(file_path, root_path)
self._hash_index.clear() if result:
loras.append(result)
# Clear existing tags count
self._tags_count = {}
# Scan for new data
raw_data = await self.scan_all_loras()
# Build hash index and tags count
for lora_data in raw_data:
if 'sha256' in lora_data and 'file_path' in lora_data:
self._hash_index.add_entry(lora_data['sha256'].lower(), lora_data['file_path'])
# Count tags
if 'tags' in lora_data and lora_data['tags']:
for tag in lora_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Update cache
self._cache = LoraCache(
raw_data=raw_data,
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# Call resort_cache to create sorted views
await self._cache.resort()
self._initialization_task = None
logger.info("LoRA Manager: Cache initialization completed")
except Exception as e: except Exception as e:
logger.error(f"LoRA Manager: Error initializing cache: {e}") logger.error(f"Error processing {file_path}: {e}")
self._cache = LoraCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name', async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
folder: str = None, search: str = None, fuzzy: bool = False, folder: str = None, search: str = None, fuzzy_search: bool = False,
base_models: list = None, tags: list = None, base_models: list = None, tags: list = None,
search_options: dict = None) -> Dict: search_options: dict = None, hash_filters: dict = None) -> Dict:
"""Get paginated and filtered lora data """Get paginated and filtered lora data
Args: Args:
@@ -144,10 +131,11 @@ class LoraScanner:
sort_by: Sort method ('name' or 'date') sort_by: Sort method ('name' or 'date')
folder: Filter by folder path folder: Filter by folder path
search: Search term search: Search term
fuzzy: Use fuzzy matching for search fuzzy_search: Use fuzzy matching for search
base_models: List of base models to filter by base_models: List of base models to filter by
tags: List of tags to filter by tags: List of tags to filter by
search_options: Dictionary with search options (filename, modelname, tags, recursive) search_options: Dictionary with search options (filename, modelname, tags, recursive)
hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes)
""" """
cache = await self.get_cached_data() cache = await self.get_cached_data()
@@ -157,90 +145,108 @@ class LoraScanner:
'filename': True, 'filename': True,
'modelname': True, 'modelname': True,
'tags': False, 'tags': False,
'recursive': False 'recursive': False,
} }
# Get the base data set # Get the base data set
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# Apply hash filtering if provided (highest priority)
if hash_filters:
single_hash = hash_filters.get('single_hash')
multiple_hashes = hash_filters.get('multiple_hashes')
if single_hash:
# Filter by single hash
single_hash = single_hash.lower() # Ensure lowercase for matching
filtered_data = [
lora for lora in filtered_data
if lora.get('sha256', '').lower() == single_hash
]
elif multiple_hashes:
# Filter by multiple hashes
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
filtered_data = [
lora for lora in filtered_data
if lora.get('sha256', '').lower() in hash_set
]
# Jump to pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
# Apply SFW filtering if enabled # Apply SFW filtering if enabled
if settings.get('show_only_sfw', False): if settings.get('show_only_sfw', False):
filtered_data = [ filtered_data = [
item for item in filtered_data lora for lora in filtered_data
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R'] if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R']
] ]
# Apply folder filtering # Apply folder filtering
if folder is not None: if folder is not None:
if search_options.get('recursive', False): if search_options.get('recursive', False):
# Recursive mode: match all paths starting with this folder # Recursive folder filtering - include all subfolders
filtered_data = [ filtered_data = [
item for item in filtered_data lora for lora in filtered_data
if item['folder'].startswith(folder + '/') or item['folder'] == folder if lora['folder'].startswith(folder)
] ]
else: else:
# Non-recursive mode: match exact folder # Exact folder filtering
filtered_data = [ filtered_data = [
item for item in filtered_data lora for lora in filtered_data
if item['folder'] == folder if lora['folder'] == folder
] ]
# Apply base model filtering # Apply base model filtering
if base_models and len(base_models) > 0: if base_models and len(base_models) > 0:
filtered_data = [ filtered_data = [
item for item in filtered_data lora for lora in filtered_data
if item.get('base_model') in base_models if lora.get('base_model') in base_models
] ]
# Apply tag filtering # Apply tag filtering
if tags and len(tags) > 0: if tags and len(tags) > 0:
filtered_data = [ filtered_data = [
item for item in filtered_data lora for lora in filtered_data
if any(tag in item.get('tags', []) for tag in tags) if any(tag in lora.get('tags', []) for tag in tags)
] ]
# Apply search filtering # Apply search filtering
if search: if search:
search_results = [] search_results = []
for item in filtered_data: search_opts = search_options or {}
# Check filename if enabled
if search_options.get('filename', True): for lora in filtered_data:
if fuzzy: # Search by file name
if fuzzy_match(item.get('file_name', ''), search): if search_opts.get('filename', True):
search_results.append(item) if fuzzy_match(lora.get('file_name', ''), search):
continue search_results.append(lora)
else:
if search.lower() in item.get('file_name', '').lower():
search_results.append(item)
continue
# Check model name if enabled
if search_options.get('modelname', True):
if fuzzy:
if fuzzy_match(item.get('model_name', ''), search):
search_results.append(item)
continue
else:
if search.lower() in item.get('model_name', '').lower():
search_results.append(item)
continue
# Check tags if enabled
if search_options.get('tags', False) and item.get('tags'):
found_tag = False
for tag in item['tags']:
if fuzzy:
if fuzzy_match(tag, search):
found_tag = True
break
else:
if search.lower() in tag.lower():
found_tag = True
break
if found_tag:
search_results.append(item)
continue continue
# Search by model name
if search_opts.get('modelname', True):
if fuzzy_match(lora.get('model_name', ''), search):
search_results.append(lora)
continue
# Search by tags
if search_opts.get('tags', False) and 'tags' in lora:
if any(fuzzy_match(tag, search) for tag in lora['tags']):
search_results.append(lora)
continue
filtered_data = search_results filtered_data = search_results
# Calculate pagination # Calculate pagination
@@ -258,326 +264,6 @@ class LoraScanner:
return result return result
def invalidate_cache(self):
"""Invalidate the current cache"""
self._cache = None
async def scan_all_loras(self) -> List[Dict]:
"""Scan all LoRA directories and return metadata"""
all_loras = []
# 分目录异步扫描
scan_tasks = []
for loras_root in config.loras_roots:
task = asyncio.create_task(self._scan_directory(loras_root))
scan_tasks.append(task)
for task in scan_tasks:
try:
loras = await task
all_loras.extend(loras)
except Exception as e:
logger.error(f"Error scanning directory: {e}")
return all_loras
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for LoRA files"""
loras = []
original_root = root_path # 保存原始根路径
async def scan_recursive(path: str, visited_paths: set):
"""递归扫描目录,避免循环链接"""
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.safetensors'):
# 使用原始路径而不是真实路径
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, loras)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# 对于目录,使用原始路径继续扫描
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return loras
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
"""处理单个文件并添加到结果列表"""
try:
result = await self._process_lora_file(file_path, root_path)
if result:
loras.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
async def _process_lora_file(self, file_path: str, root_path: str) -> Dict:
"""Process a single LoRA file and return its metadata"""
# Try loading existing metadata
metadata = await load_metadata(file_path)
if metadata is None:
# Create new metadata if none exists
metadata = await get_file_info(file_path)
# Convert to dict and add folder info
lora_data = metadata.to_dict()
# Try to fetch missing metadata from Civitai if needed
await self._fetch_missing_metadata(file_path, lora_data)
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
lora_data['folder'] = folder.replace(os.path.sep, '/')
return lora_data
async def _fetch_missing_metadata(self, file_path: str, lora_data: Dict) -> None:
"""Fetch missing description and tags from Civitai if needed
Args:
file_path: Path to the lora file
lora_data: Lora metadata dictionary to update
"""
try:
# Skip if already marked as deleted on Civitai
if lora_data.get('civitai_deleted', False):
logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
return
# Check if we need to fetch additional metadata from Civitai
needs_metadata_update = False
model_id = None
# Check if we have Civitai model ID but missing metadata
if lora_data.get('civitai'):
# Try to get model ID directly from the correct location
model_id = lora_data['civitai'].get('modelId')
if model_id:
model_id = str(model_id)
# Check if tags are missing or empty
tags_missing = not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0
# Check if description is missing or empty
desc_missing = not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")
needs_metadata_update = tags_missing or desc_missing
# Fetch missing metadata if needed
if needs_metadata_update and model_id:
logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
from ..services.civitai_client import CivitaiClient
client = CivitaiClient()
# Get metadata and status code
model_metadata, status_code = await client.get_model_metadata(model_id)
await client.close()
# Handle 404 status (model deleted from Civitai)
if status_code == 404:
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
# Mark as deleted to avoid future API calls
lora_data['civitai_deleted'] = True
# Save the updated metadata back to file
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(lora_data, f, indent=2, ensure_ascii=False)
# Process valid metadata if available
elif model_metadata:
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
# Update tags if they were missing
if model_metadata.get('tags') and (not lora_data.get('tags') or len(lora_data.get('tags', [])) == 0):
lora_data['tags'] = model_metadata['tags']
# Update description if it was missing
if model_metadata.get('description') and (not lora_data.get('modelDescription') or lora_data.get('modelDescription') in (None, "")):
lora_data['modelDescription'] = model_metadata['description']
# Save the updated metadata back to file
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(lora_data, f, indent=2, ensure_ascii=False)
except Exception as e:
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
"""Update preview URL in cache for a specific lora
Args:
file_path: The file path of the lora to update
preview_url: The new preview URL
Returns:
bool: True if the update was successful, False if cache doesn't exist or lora wasn't found
"""
if self._cache is None:
return False
return await self._cache.update_preview_url(file_path, preview_url)
async def scan_single_lora(self, file_path: str) -> Optional[Dict]:
"""Scan a single LoRA file and return its metadata"""
try:
if not os.path.exists(os.path.realpath(file_path)):
return None
# 获取基本文件信息
metadata = await get_file_info(file_path)
if not metadata:
return None
folder = self._calculate_folder(file_path)
# 确保 folder 字段存在
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = folder or ''
return metadata_dict
except Exception as e:
logger.error(f"Error scanning {file_path}: {e}")
return None
def _calculate_folder(self, file_path: str) -> str:
"""Calculate the folder path for a LoRA file"""
# 使用原始路径计算相对路径
for root in config.loras_roots:
if file_path.startswith(root):
rel_path = os.path.relpath(file_path, root)
return os.path.dirname(rel_path).replace(os.path.sep, '/')
return ''
async def move_model(self, source_path: str, target_path: str) -> bool:
"""Move a model and its associated files to a new location"""
try:
# 保持原始路径格式
source_path = source_path.replace(os.sep, '/')
target_path = target_path.replace(os.sep, '/')
# 其余代码保持不变
base_name = os.path.splitext(os.path.basename(source_path))[0]
source_dir = os.path.dirname(source_path)
os.makedirs(target_path, exist_ok=True)
target_lora = os.path.join(target_path, f"{base_name}.safetensors").replace(os.sep, '/')
# 使用真实路径进行文件操作
real_source = os.path.realpath(source_path)
real_target = os.path.realpath(target_lora)
file_size = os.path.getsize(real_source)
if self.file_monitor:
self.file_monitor.handler.add_ignore_path(
real_source,
file_size
)
self.file_monitor.handler.add_ignore_path(
real_target,
file_size
)
# 使用真实路径进行文件操作
shutil.move(real_source, real_target)
# Move associated files
source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
if os.path.exists(source_metadata):
target_metadata = os.path.join(target_path, f"{base_name}.metadata.json")
shutil.move(source_metadata, target_metadata)
metadata = await self._update_metadata_paths(target_metadata, target_lora)
# Move preview file if exists
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
'.png', '.jpeg', '.jpg', '.mp4']
for ext in preview_extensions:
source_preview = os.path.join(source_dir, f"{base_name}{ext}")
if os.path.exists(source_preview):
target_preview = os.path.join(target_path, f"{base_name}{ext}")
shutil.move(source_preview, target_preview)
break
# Update cache
await self.update_single_lora_cache(source_path, target_lora, metadata)
return True
except Exception as e:
logger.error(f"Error moving model: {e}", exc_info=True)
return False
async def update_single_lora_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool:
cache = await self.get_cached_data()
# Find the existing item to remove its tags from count
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
if existing_item and 'tags' in existing_item:
for tag in existing_item.get('tags', []):
if tag in self._tags_count:
self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
if self._tags_count[tag] == 0:
del self._tags_count[tag]
# Remove old path from hash index if exists
self._hash_index.remove_by_path(original_path)
# Remove the old entry from raw_data
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != original_path
]
if metadata:
# If this is an update to an existing path (not a move), ensure folder is preserved
if original_path == new_path:
# Find the folder from existing entries or calculate it
existing_folder = next((item['folder'] for item in cache.raw_data
if item['file_path'] == original_path), None)
if existing_folder:
metadata['folder'] = existing_folder
else:
metadata['folder'] = self._calculate_folder(new_path)
else:
# For moved files, recalculate the folder
metadata['folder'] = self._calculate_folder(new_path)
# Add the updated metadata to raw_data
cache.raw_data.append(metadata)
# Update hash index with new path
if 'sha256' in metadata:
self._hash_index.add_entry(metadata['sha256'].lower(), new_path)
# Update folders list
all_folders = set(item['folder'] for item in cache.raw_data)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Update tags count with the new/updated tags
if 'tags' in metadata:
for tag in metadata.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Resort cache
await cache.resort()
return True
async def _update_metadata_paths(self, metadata_path: str, lora_path: str) -> Dict: async def _update_metadata_paths(self, metadata_path: str, lora_path: str) -> Dict:
"""Update file paths in metadata file""" """Update file paths in metadata file"""
try: try:
@@ -604,49 +290,21 @@ class LoraScanner:
except Exception as e: except Exception as e:
logger.error(f"Error updating metadata paths: {e}", exc_info=True) logger.error(f"Error updating metadata paths: {e}", exc_info=True)
# Add new methods for hash index functionality # Lora-specific hash index functionality
def has_lora_hash(self, sha256: str) -> bool: def has_lora_hash(self, sha256: str) -> bool:
"""Check if a LoRA with given hash exists""" """Check if a LoRA with given hash exists"""
return self._hash_index.has_hash(sha256.lower()) return self.has_hash(sha256)
def get_lora_path_by_hash(self, sha256: str) -> Optional[str]: def get_lora_path_by_hash(self, sha256: str) -> Optional[str]:
"""Get file path for a LoRA by its hash""" """Get file path for a LoRA by its hash"""
return self._hash_index.get_path(sha256.lower()) return self.get_path_by_hash(sha256)
def get_lora_hash_by_path(self, file_path: str) -> Optional[str]: def get_lora_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a LoRA by its file path""" """Get hash for a LoRA by its file path"""
return self._hash_index.get_hash(file_path) return self.get_hash_by_path(file_path)
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
"""Get preview static URL for a LoRA by its hash"""
# Get the file path first
file_path = self._hash_index.get_path(sha256.lower())
if not file_path:
return None
# Determine the preview file path (typically same name with different extension)
base_name = os.path.splitext(file_path)[0]
preview_extensions = ['.preview.png', '.preview.jpeg', '.preview.jpg', '.preview.mp4',
'.png', '.jpeg', '.jpg', '.mp4']
for ext in preview_extensions:
preview_path = f"{base_name}{ext}"
if os.path.exists(preview_path):
# Convert to static URL using config
return config.get_preview_static_url(preview_path)
return None
# Add new method to get top tags
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]: async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get top tags sorted by count """Get top tags sorted by count"""
Args:
limit: Maximum number of tags to return
Returns:
List of dictionaries with tag name and count, sorted by count
"""
# Make sure cache is initialized # Make sure cache is initialized
await self.get_cached_data() await self.get_cached_data()
@@ -661,14 +319,7 @@ class LoraScanner:
return sorted_tags[:limit] return sorted_tags[:limit]
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]: async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get base models used in loras sorted by frequency """Get base models used in loras sorted by frequency"""
Args:
limit: Maximum number of base models to return
Returns:
List of dictionaries with base model name and count, sorted by count
"""
# Make sure cache is initialized # Make sure cache is initialized
cache = await self.get_cached_data() cache = await self.get_cached_data()

View File

@@ -0,0 +1,64 @@
import asyncio
from typing import List, Dict
from dataclasses import dataclass
from operator import itemgetter
@dataclass
class ModelCache:
"""Cache structure for model data"""
raw_data: List[Dict]
sorted_by_name: List[Dict]
sorted_by_date: List[Dict]
folders: List[str]
def __post_init__(self):
self._lock = asyncio.Lock()
async def resort(self, name_only: bool = False):
"""Resort all cached data views"""
async with self._lock:
self.sorted_by_name = sorted(
self.raw_data,
key=lambda x: x['model_name'].lower() # Case-insensitive sort
)
if not name_only:
self.sorted_by_date = sorted(
self.raw_data,
key=itemgetter('modified'),
reverse=True
)
# Update folder list
all_folders = set(l['folder'] for l in self.raw_data)
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
"""Update preview_url for a specific model in all cached data
Args:
file_path: The file path of the model to update
preview_url: The new preview URL
Returns:
bool: True if the update was successful, False if the model wasn't found
"""
async with self._lock:
# Update in raw_data
for item in self.raw_data:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
else:
return False # Model not found
# Update in sorted lists (references to the same dict objects)
for item in self.sorted_by_name:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
for item in self.sorted_by_date:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
return True

View File

@@ -0,0 +1,96 @@
from typing import Dict, Optional, Set
import os
class ModelHashIndex:
"""Index for looking up models by hash or path"""
def __init__(self):
self._hash_to_path: Dict[str, str] = {}
self._filename_to_hash: Dict[str, str] = {} # Changed from path_to_hash to filename_to_hash
def add_entry(self, sha256: str, file_path: str) -> None:
"""Add or update hash index entry"""
if not sha256 or not file_path:
return
# Ensure hash is lowercase for consistency
sha256 = sha256.lower()
# Extract filename without extension
filename = self._get_filename_from_path(file_path)
# Remove old path mapping if hash exists
if sha256 in self._hash_to_path:
old_path = self._hash_to_path[sha256]
old_filename = self._get_filename_from_path(old_path)
if old_filename in self._filename_to_hash:
del self._filename_to_hash[old_filename]
# Remove old hash mapping if filename exists
if filename in self._filename_to_hash:
old_hash = self._filename_to_hash[filename]
if old_hash in self._hash_to_path:
del self._hash_to_path[old_hash]
# Add new mappings
self._hash_to_path[sha256] = file_path
self._filename_to_hash[filename] = sha256
def _get_filename_from_path(self, file_path: str) -> str:
"""Extract filename without extension from path"""
return os.path.splitext(os.path.basename(file_path))[0]
def remove_by_path(self, file_path: str) -> None:
"""Remove entry by file path"""
filename = self._get_filename_from_path(file_path)
if filename in self._filename_to_hash:
hash_val = self._filename_to_hash[filename]
if hash_val in self._hash_to_path:
del self._hash_to_path[hash_val]
del self._filename_to_hash[filename]
def remove_by_hash(self, sha256: str) -> None:
"""Remove entry by hash"""
sha256 = sha256.lower()
if sha256 in self._hash_to_path:
path = self._hash_to_path[sha256]
filename = self._get_filename_from_path(path)
if filename in self._filename_to_hash:
del self._filename_to_hash[filename]
del self._hash_to_path[sha256]
def has_hash(self, sha256: str) -> bool:
"""Check if hash exists in index"""
return sha256.lower() in self._hash_to_path
def get_path(self, sha256: str) -> Optional[str]:
"""Get file path for a hash"""
return self._hash_to_path.get(sha256.lower())
def get_hash(self, file_path: str) -> Optional[str]:
"""Get hash for a file path"""
filename = self._get_filename_from_path(file_path)
return self._filename_to_hash.get(filename)
def get_hash_by_filename(self, filename: str) -> Optional[str]:
"""Get hash for a filename without extension"""
# Strip extension if present to make the function more flexible
filename = os.path.splitext(filename)[0]
return self._filename_to_hash.get(filename)
def clear(self) -> None:
"""Clear all entries"""
self._hash_to_path.clear()
self._filename_to_hash.clear()
def get_all_hashes(self) -> Set[str]:
"""Get all hashes in the index"""
return set(self._hash_to_path.keys())
def get_all_filenames(self) -> Set[str]:
"""Get all filenames in the index"""
return set(self._filename_to_hash.keys())
def __len__(self) -> int:
"""Get number of entries"""
return len(self._hash_to_path)

View File

@@ -0,0 +1,910 @@
import json
import os
import logging
import asyncio
import time
import shutil
from typing import List, Dict, Optional, Type, Set
from ..utils.models import BaseModelMetadata
from ..config import config
from ..utils.file_utils import load_metadata, get_file_info, find_preview_file, save_metadata
from .model_cache import ModelCache
from .model_hash_index import ModelHashIndex
from ..utils.constants import PREVIEW_EXTENSIONS
from .service_registry import ServiceRegistry
from .websocket_manager import ws_manager
logger = logging.getLogger(__name__)
class ModelScanner:
"""Base service for scanning and managing model files"""
_lock = asyncio.Lock()
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
"""Initialize the scanner
Args:
model_type: Type of model (lora, checkpoint, etc.)
model_class: Class used to create metadata instances
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
hash_index: Hash index instance (optional)
"""
self.model_type = model_type
self.model_class = model_class
self.file_extensions = file_extensions
self._cache = None
self._hash_index = hash_index or ModelHashIndex()
self._tags_count = {} # Dictionary to store tag counts
self._is_initializing = False # Flag to track initialization state
# Register this service
asyncio.create_task(self._register_service())
async def _register_service(self):
"""Register this instance with the ServiceRegistry"""
service_name = f"{self.model_type}_scanner"
await ServiceRegistry.register_service(service_name, self)
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
# Set initial empty cache to avoid None reference errors
if self._cache is None:
self._cache = ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# Set initializing flag to true
self._is_initializing = True
# Determine the page type based on model type
page_type = 'loras' if self.model_type == 'lora' else 'checkpoints'
# First, count all model files to track progress
await ws_manager.broadcast_init_progress({
'stage': 'scan_folders',
'progress': 0,
'details': f"Scanning {self.model_type} folders...",
'scanner_type': self.model_type,
'pageType': page_type
})
# Count files in a separate thread to avoid blocking
loop = asyncio.get_event_loop()
total_files = await loop.run_in_executor(
None, # Use default thread pool
self._count_model_files # Run file counting in thread
)
await ws_manager.broadcast_init_progress({
'stage': 'count_models',
'progress': 1, # Changed from 10 to 1
'details': f"Found {total_files} {self.model_type} files",
'scanner_type': self.model_type,
'pageType': page_type
})
start_time = time.time()
# Use thread pool to execute CPU-intensive operations with progress reporting
await loop.run_in_executor(
None, # Use default thread pool
self._initialize_cache_sync, # Run synchronous version in thread
total_files, # Pass the total file count for progress reporting
page_type # Pass the page type for progress reporting
)
# Send final progress update
await ws_manager.broadcast_init_progress({
'stage': 'finalizing',
'progress': 99, # Changed from 95 to 99
'details': f"Finalizing {self.model_type} cache...",
'scanner_type': self.model_type,
'pageType': page_type
})
logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
# Send completion message
await asyncio.sleep(0.5) # Small delay to ensure final progress message is sent
await ws_manager.broadcast_init_progress({
'stage': 'finalizing',
'progress': 100,
'status': 'complete',
'details': f"Completed! Found {len(self._cache.raw_data)} {self.model_type} files.",
'scanner_type': self.model_type,
'pageType': page_type
})
except Exception as e:
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache in background: {e}")
finally:
# Always clear the initializing flag when done
self._is_initializing = False
def _count_model_files(self) -> int:
"""Count all model files with supported extensions in all roots
Returns:
int: Total number of model files found
"""
total_files = 0
visited_real_paths = set()
for root_path in self.get_model_roots():
if not os.path.exists(root_path):
continue
def count_recursive(path):
nonlocal total_files
try:
real_path = os.path.realpath(path)
if real_path in visited_real_paths:
return
visited_real_paths.add(real_path)
with os.scandir(path) as it:
for entry in it:
try:
if entry.is_file(follow_symlinks=True):
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
total_files += 1
elif entry.is_dir(follow_symlinks=True):
count_recursive(entry.path)
except Exception as e:
logger.error(f"Error counting files in entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error counting files in {path}: {e}")
count_recursive(root_path)
return total_files
def _initialize_cache_sync(self, total_files=0, page_type='loras'):
"""Synchronous version of cache initialization for thread pool execution"""
try:
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Create a synchronous method to bypass the async lock
def sync_initialize_cache():
# Track progress
processed_files = 0
last_progress_time = time.time()
last_progress_percent = 0
# We need a wrapper around scan_all_models to track progress
# This is a local function that will run in our thread's event loop
async def scan_with_progress():
nonlocal processed_files, last_progress_time, last_progress_percent
# For storing raw model data
all_models = []
# Process each model root
for root_path in self.get_model_roots():
if not os.path.exists(root_path):
continue
# Track visited paths to avoid symlink loops
visited_paths = set()
# Recursively process directory
async def scan_dir_with_progress(path):
nonlocal processed_files, last_progress_time, last_progress_percent
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/")
result = await self._process_model_file(file_path, root_path)
if result:
all_models.append(result)
# Update progress counter
processed_files += 1
# Update progress periodically (not every file to avoid excessive updates)
current_time = time.time()
if total_files > 0 and (current_time - last_progress_time > 0.5 or processed_files == total_files):
# Adjusted progress calculation
progress_percent = min(99, int(1 + (processed_files / total_files) * 98))
if progress_percent > last_progress_percent:
last_progress_percent = progress_percent
last_progress_time = current_time
# Send progress update through websocket
await ws_manager.broadcast_init_progress({
'stage': 'process_models',
'progress': progress_percent,
'details': f"Processing {self.model_type} files: {processed_files}/{total_files}",
'scanner_type': self.model_type,
'pageType': page_type
})
elif entry.is_dir(follow_symlinks=True):
await scan_dir_with_progress(entry.path)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
# Process the root path
await scan_dir_with_progress(root_path)
return all_models
# Run the progress-tracking scan function
raw_data = loop.run_until_complete(scan_with_progress())
# Update hash index and tags count
for model_data in raw_data:
if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
# Count tags
if 'tags' in model_data and model_data['tags']:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Update cache
self._cache.raw_data = raw_data
loop.run_until_complete(self._cache.resort())
return self._cache
# Run our sync initialization that avoids lock conflicts
return sync_initialize_cache()
except Exception as e:
logger.error(f"Error in thread-based {self.model_type} cache initialization: {e}")
finally:
# Clean up the event loop
loop.close()
async def get_cached_data(self, force_refresh: bool = False) -> ModelCache:
"""Get cached model data, refresh if needed"""
# If cache is not initialized, return an empty cache
# Actual initialization should be done via initialize_in_background
if self._cache is None and not force_refresh:
return ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# If force refresh is requested, initialize the cache directly
if (force_refresh):
if self._cache is None:
# For initial creation, do a full initialization
await self._initialize_cache()
else:
# For subsequent refreshes, use fast reconciliation
await self._reconcile_cache()
return self._cache
async def _initialize_cache(self) -> None:
"""Initialize or refresh the cache"""
self._is_initializing = True # Set flag
try:
start_time = time.time()
# Clear existing hash index
self._hash_index.clear()
# Clear existing tags count
self._tags_count = {}
# Determine the page type based on model type
page_type = 'loras' if self.model_type == 'lora' else 'checkpoints'
# Scan for new data
raw_data = await self.scan_all_models()
# Build hash index and tags count
for model_data in raw_data:
if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
# Count tags
if 'tags' in model_data and model_data['tags']:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Update cache
self._cache = ModelCache(
raw_data=raw_data,
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# Resort cache
await self._cache.resort()
logger.info(f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, found {len(raw_data)} models")
except Exception as e:
logger.error(f"{self.model_type.capitalize()} Scanner: Error initializing cache: {e}")
# Ensure cache is at least an empty structure on error
if self._cache is None:
self._cache = ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
finally:
self._is_initializing = False # Unset flag
async def _reconcile_cache(self) -> None:
"""Fast cache reconciliation - only process differences between cache and filesystem"""
self._is_initializing = True # Set flag for reconciliation duration
try:
start_time = time.time()
logger.info(f"{self.model_type.capitalize()} Scanner: Starting fast cache reconciliation...")
# Get current cached file paths
cached_paths = {item['file_path'] for item in self._cache.raw_data}
path_to_item = {item['file_path']: item for item in self._cache.raw_data}
# Track found files and new files
found_paths = set()
new_files = []
# Scan all model roots
for root_path in self.get_model_roots():
if not os.path.exists(root_path):
continue
# Track visited real paths to avoid symlink loops
visited_real_paths = set()
# Recursively scan directory
for root, _, files in os.walk(root_path, followlinks=True):
real_root = os.path.realpath(root)
if real_root in visited_real_paths:
continue
visited_real_paths.add(real_root)
for file in files:
ext = os.path.splitext(file)[1].lower()
if ext in self.file_extensions:
# Construct paths exactly as they would be in cache
file_path = os.path.join(root, file).replace(os.sep, '/')
# Check if this file is already in cache
if file_path in cached_paths:
found_paths.add(file_path)
continue
# Try case-insensitive match on Windows
if os.name == 'nt':
lower_path = file_path.lower()
matched = False
for cached_path in cached_paths:
if cached_path.lower() == lower_path:
found_paths.add(cached_path)
matched = True
break
if matched:
continue
# This is a new file to process
new_files.append(file_path)
# Yield control periodically
await asyncio.sleep(0)
# Process new files in batches
total_added = 0
if new_files:
logger.info(f"{self.model_type.capitalize()} Scanner: Found {len(new_files)} new files to process")
batch_size = 50
for i in range(0, len(new_files), batch_size):
batch = new_files[i:i+batch_size]
for path in batch:
try:
model_data = await self.scan_single_model(path)
if model_data:
# Add to cache
self._cache.raw_data.append(model_data)
# Update hash index if available
if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
# Update tags count
if 'tags' in model_data and model_data['tags']:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
total_added += 1
except Exception as e:
logger.error(f"Error adding {path} to cache: {e}")
# Yield control after each batch
await asyncio.sleep(0)
# Find missing files (in cache but not in filesystem)
missing_files = cached_paths - found_paths
total_removed = 0
if missing_files:
logger.info(f"{self.model_type.capitalize()} Scanner: Found {len(missing_files)} files to remove from cache")
# Process files to remove
for path in missing_files:
try:
model_to_remove = path_to_item[path]
# Update tags count
for tag in model_to_remove.get('tags', []):
if tag in self._tags_count:
self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
if self._tags_count[tag] == 0:
del self._tags_count[tag]
# Remove from hash index
self._hash_index.remove_by_path(path)
total_removed += 1
except Exception as e:
logger.error(f"Error removing {path} from cache: {e}")
# Update cache data
self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in missing_files]
# Resort cache if changes were made
if total_added > 0 or total_removed > 0:
# Update folders list
all_folders = set(item.get('folder', '') for item in self._cache.raw_data)
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Resort cache
await self._cache.resort()
logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
except Exception as e:
logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)
finally:
self._is_initializing = False # Unset flag
# These methods should be implemented in child classes
async def scan_all_models(self) -> List[Dict]:
"""Scan all model directories and return metadata"""
raise NotImplementedError("Subclasses must implement scan_all_models")
def get_model_roots(self) -> List[str]:
"""Get model root directories"""
raise NotImplementedError("Subclasses must implement get_model_roots")
async def scan_single_model(self, file_path: str) -> Optional[Dict]:
"""Scan a single model file and return its metadata"""
try:
if not os.path.exists(os.path.realpath(file_path)):
return None
# Get basic file info
metadata = await self._get_file_info(file_path)
if not metadata:
return None
folder = self._calculate_folder(file_path)
# Ensure folder field exists
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = folder or ''
return metadata_dict
except Exception as e:
logger.error(f"Error scanning {file_path}: {e}")
return None
async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]:
"""Get model file info and metadata (extensible for different model types)"""
return await get_file_info(file_path, self.model_class)
def _calculate_folder(self, file_path: str) -> str:
"""Calculate the folder path for a model file"""
for root in self.get_model_roots():
if file_path.startswith(root):
rel_path = os.path.relpath(file_path, root)
return os.path.dirname(rel_path).replace(os.path.sep, '/')
return ''
# Common methods shared between scanners
async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
"""Process a single model file and return its metadata"""
metadata = await load_metadata(file_path, self.model_class)
if metadata is None:
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
if os.path.exists(civitai_info_path):
try:
with open(civitai_info_path, 'r', encoding='utf-8') as f:
version_info = json.load(f)
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
if file_info:
file_name = os.path.splitext(os.path.basename(file_path))[0]
file_info['name'] = file_name
metadata = self.model_class.from_civitai_info(version_info, file_info, file_path)
metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
await save_metadata(file_path, metadata)
logger.debug(f"Created metadata from .civitai.info for {file_path}")
except Exception as e:
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
else:
# Check if metadata exists but civitai field is empty - try to restore from civitai.info
if metadata.civitai is None or metadata.civitai == {}:
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
if os.path.exists(civitai_info_path):
try:
with open(civitai_info_path, 'r', encoding='utf-8') as f:
version_info = json.load(f)
logger.debug(f"Restoring missing civitai data from .civitai.info for {file_path}")
metadata.civitai = version_info
# Ensure tags are also updated if they're missing
if (not metadata.tags or len(metadata.tags) == 0) and 'model' in version_info:
if 'tags' in version_info['model']:
metadata.tags = version_info['model']['tags']
# Also restore description if missing
if (not metadata.modelDescription or metadata.modelDescription == "") and 'model' in version_info:
if 'description' in version_info['model']:
metadata.modelDescription = version_info['model']['description']
# Save the updated metadata
await save_metadata(file_path, metadata)
logger.debug(f"Updated metadata with civitai info for {file_path}")
except Exception as e:
logger.error(f"Error restoring civitai data from .civitai.info for {file_path}: {e}")
if metadata is None:
metadata = await self._get_file_info(file_path)
model_data = metadata.to_dict()
await self._fetch_missing_metadata(file_path, model_data)
rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path)
model_data['folder'] = folder.replace(os.path.sep, '/')
return model_data
async def _fetch_missing_metadata(self, file_path: str, model_data: Dict) -> None:
"""Fetch missing description and tags from Civitai if needed"""
try:
if model_data.get('civitai_deleted', False):
logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
return
needs_metadata_update = False
model_id = None
if model_data.get('civitai'):
model_id = model_data['civitai'].get('modelId')
if model_id:
model_id = str(model_id)
tags_missing = not model_data.get('tags') or len(model_data.get('tags', [])) == 0
desc_missing = not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")
needs_metadata_update = tags_missing or desc_missing
if needs_metadata_update and model_id:
logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
from ..services.civitai_client import CivitaiClient
client = CivitaiClient()
model_metadata, status_code = await client.get_model_metadata(model_id)
await client.close()
if status_code == 404:
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
model_data['civitai_deleted'] = True
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(model_data, f, indent=2, ensure_ascii=False)
elif model_metadata:
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
if model_metadata.get('tags') and (not model_data.get('tags') or len(model_data.get('tags', [])) == 0):
model_data['tags'] = model_metadata['tags']
if model_metadata.get('description') and (not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")):
model_data['modelDescription'] = model_metadata['description']
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(model_data, f, indent=2, ensure_ascii=False)
except Exception as e:
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Base implementation for directory scanning"""
models = []
original_root = root_path
async def scan_recursive(path: str, visited_paths: set):
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, models)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return models
async def _process_single_file(self, file_path: str, root_path: str, models_list: list):
"""Process a single file and add to results list"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
models_list.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
async def move_model(self, source_path: str, target_path: str) -> bool:
"""Move a model and its associated files to a new location"""
try:
source_path = source_path.replace(os.sep, '/')
target_path = target_path.replace(os.sep, '/')
file_ext = os.path.splitext(source_path)[1]
if not file_ext or file_ext.lower() not in self.file_extensions:
logger.error(f"Invalid file extension for model: {file_ext}")
return False
base_name = os.path.splitext(os.path.basename(source_path))[0]
source_dir = os.path.dirname(source_path)
os.makedirs(target_path, exist_ok=True)
target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/')
real_source = os.path.realpath(source_path)
real_target = os.path.realpath(target_file)
file_size = os.path.getsize(real_source)
# Get the appropriate file monitor through ServiceRegistry
if self.model_type == "lora":
monitor = await ServiceRegistry.get_lora_monitor()
elif self.model_type == "checkpoint":
monitor = await ServiceRegistry.get_checkpoint_monitor()
else:
monitor = None
if monitor:
monitor.handler.add_ignore_path(
real_source,
file_size
)
monitor.handler.add_ignore_path(
real_target,
file_size
)
shutil.move(real_source, real_target)
source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
metadata = None
if os.path.exists(source_metadata):
target_metadata = os.path.join(target_path, f"{base_name}.metadata.json")
shutil.move(source_metadata, target_metadata)
metadata = await self._update_metadata_paths(target_metadata, target_file)
for ext in PREVIEW_EXTENSIONS:
source_preview = os.path.join(source_dir, f"{base_name}{ext}")
if os.path.exists(source_preview):
target_preview = os.path.join(target_path, f"{base_name}{ext}")
shutil.move(source_preview, target_preview)
break
await self.update_single_model_cache(source_path, target_file, metadata)
return True
except Exception as e:
logger.error(f"Error moving model: {e}", exc_info=True)
return False
async def _update_metadata_paths(self, metadata_path: str, model_path: str) -> Dict:
"""Update file paths in metadata file"""
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
metadata['file_path'] = model_path.replace(os.sep, '/')
if 'preview_url' in metadata:
preview_dir = os.path.dirname(model_path)
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
preview_ext = os.path.splitext(metadata['preview_url'])[1]
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
return metadata
except Exception as e:
logger.error(f"Error updating metadata paths: {e}", exc_info=True)
return None
async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict) -> bool:
"""Update cache after a model has been moved or modified"""
cache = await self.get_cached_data()
existing_item = next((item for item in cache.raw_data if item['file_path'] == original_path), None)
if existing_item and 'tags' in existing_item:
for tag in existing_item.get('tags', []):
if tag in self._tags_count:
self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
if self._tags_count[tag] == 0:
del self._tags_count[tag]
self._hash_index.remove_by_path(original_path)
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != original_path
]
if metadata:
if original_path == new_path:
existing_folder = next((item['folder'] for item in cache.raw_data
if item['file_path'] == original_path), None)
if existing_folder:
metadata['folder'] = existing_folder
else:
metadata['folder'] = self._calculate_folder(new_path)
else:
metadata['folder'] = self._calculate_folder(new_path)
cache.raw_data.append(metadata)
if 'sha256' in metadata:
self._hash_index.add_entry(metadata['sha256'].lower(), new_path)
all_folders = set(item['folder'] for item in cache.raw_data)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
if 'tags' in metadata:
for tag in metadata.get('tags', []):
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
await cache.resort()
return True
def has_hash(self, sha256: str) -> bool:
"""Check if a model with given hash exists"""
return self._hash_index.has_hash(sha256.lower())
def get_path_by_hash(self, sha256: str) -> Optional[str]:
"""Get file path for a model by its hash"""
return self._hash_index.get_path(sha256.lower())
def get_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a model by its file path"""
return self._hash_index.get_hash(file_path)
def get_hash_by_filename(self, filename: str) -> Optional[str]:
"""Get hash for a model by its filename without path"""
return self._hash_index.get_hash_by_filename(filename)
# TODO: Adjust this method to use metadata instead of finding the file
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
"""Get preview static URL for a model by its hash"""
file_path = self._hash_index.get_path(sha256.lower())
if not file_path:
return None
base_name = os.path.splitext(file_path)[0]
for ext in PREVIEW_EXTENSIONS:
preview_path = f"{base_name}{ext}"
if os.path.exists(preview_path):
return config.get_preview_static_url(preview_path)
return None
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get top tags sorted by count"""
await self.get_cached_data()
sorted_tags = sorted(
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
key=lambda x: x['count'],
reverse=True
)
return sorted_tags[:limit]
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get base models sorted by frequency"""
cache = await self.get_cached_data()
base_model_counts = {}
for model in cache.raw_data:
if 'base_model' in model and model['base_model']:
base_model = model['base_model']
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
sorted_models.sort(key=lambda x: x['count'], reverse=True)
return sorted_models[:limit]
async def get_model_info_by_name(self, name):
"""Get model information by name"""
try:
cache = await self.get_cached_data()
for model in cache.raw_data:
if model.get("file_name") == name:
return model
return None
except Exception as e:
logger.error(f"Error getting model info by name: {e}", exc_info=True)
return None
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
"""Update preview URL in cache for a specific lora
Args:
file_path: The file path of the lora to update
preview_url: The new preview URL
Returns:
bool: True if the update was successful, False if cache doesn't exist or lora wasn't found
"""
if self._cache is None:
return False
return await self._cache.update_preview_url(file_path, preview_url)

View File

@@ -37,18 +37,18 @@ class RecipeCache:
Returns: Returns:
bool: True if the update was successful, False if the recipe wasn't found bool: True if the update was successful, False if the recipe wasn't found
""" """
async with self._lock:
# Update in raw_data # Update in raw_data
for item in self.raw_data: for item in self.raw_data:
if item.get('id') == recipe_id: if item.get('id') == recipe_id:
item.update(metadata) item.update(metadata)
break break
else: else:
return False # Recipe not found return False # Recipe not found
# Resort to reflect changes # Resort to reflect changes
await self.resort() await self.resort()
return True return True
async def add_recipe(self, recipe_data: Dict) -> None: async def add_recipe(self, recipe_data: Dict) -> None:
"""Add a new recipe to the cache """Add a new recipe to the cache

View File

@@ -2,13 +2,12 @@ import os
import logging import logging
import asyncio import asyncio
import json import json
import re import time
from typing import List, Dict, Optional, Any from typing import List, Dict, Optional, Any, Tuple
from datetime import datetime
from ..config import config from ..config import config
from .recipe_cache import RecipeCache from .recipe_cache import RecipeCache
from .service_registry import ServiceRegistry
from .lora_scanner import LoraScanner from .lora_scanner import LoraScanner
from .civitai_client import CivitaiClient
from ..utils.utils import fuzzy_match from ..utils.utils import fuzzy_match
import sys import sys
@@ -20,11 +19,22 @@ class RecipeScanner:
_instance = None _instance = None
_lock = asyncio.Lock() _lock = asyncio.Lock()
@classmethod
async def get_instance(cls, lora_scanner: Optional[LoraScanner] = None):
"""Get singleton instance of RecipeScanner"""
async with cls._lock:
if cls._instance is None:
if not lora_scanner:
# Get lora scanner from service registry if not provided
lora_scanner = await ServiceRegistry.get_lora_scanner()
cls._instance = cls(lora_scanner)
return cls._instance
def __new__(cls, lora_scanner: Optional[LoraScanner] = None): def __new__(cls, lora_scanner: Optional[LoraScanner] = None):
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
cls._instance._lora_scanner = lora_scanner cls._instance._lora_scanner = lora_scanner
cls._instance._civitai_client = CivitaiClient() cls._instance._civitai_client = None # Will be lazily initialized
return cls._instance return cls._instance
def __init__(self, lora_scanner: Optional[LoraScanner] = None): def __init__(self, lora_scanner: Optional[LoraScanner] = None):
@@ -37,9 +47,148 @@ class RecipeScanner:
if lora_scanner: if lora_scanner:
self._lora_scanner = lora_scanner self._lora_scanner = lora_scanner
self._initialized = True self._initialized = True
# Initialization will be scheduled by LoraManager
async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry"""
if self._civitai_client is None:
self._civitai_client = await ServiceRegistry.get_civitai_client()
return self._civitai_client
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
# Set initial empty cache to avoid None reference errors
if self._cache is None:
self._cache = RecipeCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[]
)
# Mark as initializing to prevent concurrent initializations
self._is_initializing = True
try:
# Start timer
start_time = time.time()
# Use thread pool to execute CPU-intensive operations
loop = asyncio.get_event_loop()
cache = await loop.run_in_executor(
None, # Use default thread pool
self._initialize_recipe_cache_sync # Run synchronous version in thread
)
# Calculate elapsed time and log it
elapsed_time = time.time() - start_time
recipe_count = len(cache.raw_data) if cache and hasattr(cache, 'raw_data') else 0
logger.info(f"Recipe cache initialized in {elapsed_time:.2f} seconds. Found {recipe_count} recipes")
finally:
# Mark initialization as complete regardless of outcome
self._is_initializing = False
except Exception as e:
logger.error(f"Recipe Scanner: Error initializing cache in background: {e}")
def _initialize_recipe_cache_sync(self):
"""Synchronous version of recipe cache initialization for thread pool execution"""
try:
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Create a synchronous method to bypass the async lock
def sync_initialize_cache():
# We need to implement scan_all_recipes logic synchronously here
# instead of calling the async method to avoid event loop issues
recipes = []
recipes_dir = self.recipes_dir
if not recipes_dir or not os.path.exists(recipes_dir):
logger.warning(f"Recipes directory not found: {recipes_dir}")
return recipes
# Get all recipe JSON files in the recipes directory
recipe_files = []
for root, _, files in os.walk(recipes_dir):
recipe_count = sum(1 for f in files if f.lower().endswith('.recipe.json'))
if recipe_count > 0:
for file in files:
if file.lower().endswith('.recipe.json'):
recipe_files.append(os.path.join(root, file))
# Process each recipe file
for recipe_path in recipe_files:
try:
with open(recipe_path, 'r', encoding='utf-8') as f:
recipe_data = json.load(f)
# Validate recipe data
if not recipe_data or not isinstance(recipe_data, dict):
logger.warning(f"Invalid recipe data in {recipe_path}")
continue
# Ensure required fields exist
required_fields = ['id', 'file_path', 'title']
if not all(field in recipe_data for field in required_fields):
logger.warning(f"Missing required fields in {recipe_path}")
continue
# Ensure the image file exists
image_path = recipe_data.get('file_path')
if not os.path.exists(image_path):
recipe_dir = os.path.dirname(recipe_path)
image_filename = os.path.basename(image_path)
alternative_path = os.path.join(recipe_dir, image_filename)
if os.path.exists(alternative_path):
recipe_data['file_path'] = alternative_path
# Ensure loras array exists
if 'loras' not in recipe_data:
recipe_data['loras'] = []
# Ensure gen_params exists
if 'gen_params' not in recipe_data:
recipe_data['gen_params'] = {}
# Add to list without async operations
recipes.append(recipe_data)
except Exception as e:
logger.error(f"Error loading recipe file {recipe_path}: {e}")
import traceback
traceback.print_exc(file=sys.stderr)
# Update cache with the collected data
self._cache.raw_data = recipes
# Create a simplified resort function that doesn't use await
if hasattr(self._cache, "resort"):
try:
# Sort by name
self._cache.sorted_by_name = sorted(
self._cache.raw_data,
key=lambda x: x.get('title', '').lower()
)
# Sort by date (modified or created)
self._cache.sorted_by_date = sorted(
self._cache.raw_data,
key=lambda x: x.get('modified', x.get('created_date', 0)),
reverse=True
)
except Exception as e:
logger.error(f"Error sorting recipe cache: {e}")
return self._cache
# Run our sync initialization that avoids lock conflicts
return sync_initialize_cache()
except Exception as e:
logger.error(f"Error in thread-based recipe cache initialization: {e}")
return self._cache if hasattr(self, '_cache') else None
finally:
# Clean up the event loop
loop.close()
@property @property
def recipes_dir(self) -> str: def recipes_dir(self) -> str:
"""Get path to recipes directory""" """Get path to recipes directory"""
@@ -62,32 +211,16 @@ class RecipeScanner:
if self._is_initializing and not force_refresh: if self._is_initializing and not force_refresh:
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[]) return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
# Try to acquire the lock with a timeout to prevent deadlocks # If force refresh is requested, initialize the cache directly
try: if force_refresh:
# Use a timeout for acquiring the lock # Try to acquire the lock with a timeout to prevent deadlocks
async with asyncio.timeout(1.0): try:
async with self._initialization_lock: async with self._initialization_lock:
# Check again after acquiring the lock
if self._cache is not None and not force_refresh:
return self._cache
# Mark as initializing to prevent concurrent initializations # Mark as initializing to prevent concurrent initializations
self._is_initializing = True self._is_initializing = True
try: try:
# First ensure the lora scanner is initialized # Scan for recipe data directly
if self._lora_scanner:
try:
lora_cache = await asyncio.wait_for(
self._lora_scanner.get_cached_data(),
timeout=10.0
)
except asyncio.TimeoutError:
logger.error("Timeout waiting for lora scanner initialization")
except Exception as e:
logger.error(f"Error waiting for lora scanner: {e}")
# Scan for recipe data
raw_data = await self.scan_all_recipes() raw_data = await self.scan_all_recipes()
# Update cache # Update cache
@@ -114,14 +247,12 @@ class RecipeScanner:
finally: finally:
# Mark initialization as complete # Mark initialization as complete
self._is_initializing = False self._is_initializing = False
except Exception as e:
logger.error(f"Unexpected error in get_cached_data: {e}")
except asyncio.TimeoutError: # Return the cache (may be empty or partially initialized)
# If we can't acquire the lock in time, return the current cache or an empty one return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
logger.warning("Timeout acquiring initialization lock - returning current cache state")
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
except Exception as e:
logger.error(f"Unexpected error in get_cached_data: {e}")
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
async def scan_all_recipes(self) -> List[Dict]: async def scan_all_recipes(self) -> List[Dict]:
"""Scan all recipe JSON files and return metadata""" """Scan all recipe JSON files and return metadata"""
@@ -210,6 +341,10 @@ class RecipeScanner:
metadata_updated = False metadata_updated = False
for lora in recipe_data['loras']: for lora in recipe_data['loras']:
# Skip deleted loras that were already marked
if lora.get('isDeleted', False):
continue
# Skip if already has complete information # Skip if already has complete information
if 'hash' in lora and 'file_name' in lora and lora['file_name']: if 'hash' in lora and 'file_name' in lora and lora['file_name']:
continue continue
@@ -225,12 +360,19 @@ class RecipeScanner:
metadata_updated = True metadata_updated = True
else: else:
# If not in cache, fetch from Civitai # If not in cache, fetch from Civitai
hash_from_civitai = await self._get_hash_from_civitai(model_version_id) result = await self._get_hash_from_civitai(model_version_id)
if hash_from_civitai: if isinstance(result, tuple):
lora['hash'] = hash_from_civitai hash_from_civitai, is_deleted = result
metadata_updated = True if hash_from_civitai:
lora['hash'] = hash_from_civitai
metadata_updated = True
elif is_deleted:
# Mark the lora as deleted if it was not found on Civitai
lora['isDeleted'] = True
logger.warning(f"Marked lora with modelVersionId {model_version_id} as deleted")
metadata_updated = True
else: else:
logger.warning(f"Could not get hash for modelVersionId {model_version_id}") logger.debug(f"Could not get hash for modelVersionId {model_version_id}")
# If has hash but no file_name, look up in lora library # If has hash but no file_name, look up in lora library
if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']): if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']):
@@ -274,42 +416,32 @@ class RecipeScanner:
async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]: async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
"""Get hash from Civitai API""" """Get hash from Civitai API"""
try: try:
if not self._civitai_client: # Get CivitaiClient from ServiceRegistry
civitai_client = await self._get_civitai_client()
if not civitai_client:
logger.error("Failed to get CivitaiClient from ServiceRegistry")
return None return None
version_info = await self._civitai_client.get_model_version_info(model_version_id) version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
if not version_info or not version_info.get('files'): if not version_info:
logger.warning(f"No files found in version info for ID: {model_version_id}") if error_msg and "model not found" in error_msg.lower():
return None logger.warning(f"Model with version ID {model_version_id} was not found on Civitai - marking as deleted")
return None, True # Return None hash and True for isDeleted flag
else:
logger.debug(f"Could not get hash for modelVersionId {model_version_id}: {error_msg}")
return None, False # Return None hash but not marked as deleted
# Get hash from the first file # Get hash from the first file
for file_info in version_info.get('files', []): for file_info in version_info.get('files', []):
if file_info.get('hashes', {}).get('SHA256'): if file_info.get('hashes', {}).get('SHA256'):
return file_info['hashes']['SHA256'] return file_info['hashes']['SHA256'], False # Return hash with False for isDeleted flag
logger.warning(f"No SHA256 hash found in version info for ID: {model_version_id}") logger.debug(f"No SHA256 hash found in version info for ID: {model_version_id}")
return None return None, False
except Exception as e: except Exception as e:
logger.error(f"Error getting hash from Civitai: {e}") logger.error(f"Error getting hash from Civitai: {e}")
return None return None, False
async def _get_model_version_name(self, model_version_id: str) -> Optional[str]:
"""Get model version name from Civitai API"""
try:
if not self._civitai_client:
return None
version_info = await self._civitai_client.get_model_version_info(model_version_id)
if version_info and 'name' in version_info:
return version_info['name']
logger.warning(f"No version name found for modelVersionId {model_version_id}")
return None
except Exception as e:
logger.error(f"Error getting model version name from Civitai: {e}")
return None
async def _determine_base_model(self, loras: List[Dict]) -> Optional[str]: async def _determine_base_model(self, loras: List[Dict]) -> Optional[str]:
"""Determine the most common base model among LoRAs""" """Determine the most common base model among LoRAs"""
@@ -349,7 +481,7 @@ class RecipeScanner:
logger.error(f"Error getting base model for lora: {e}") logger.error(f"Error getting base model for lora: {e}")
return None return None
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None): async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True):
"""Get paginated and filtered recipe data """Get paginated and filtered recipe data
Args: Args:
@@ -359,69 +491,89 @@ class RecipeScanner:
search: Search term search: Search term
filters: Dictionary of filters to apply filters: Dictionary of filters to apply
search_options: Dictionary of search options to apply search_options: Dictionary of search options to apply
lora_hash: Optional SHA256 hash of a LoRA to filter recipes by
bypass_filters: If True, ignore other filters when a lora_hash is provided
""" """
cache = await self.get_cached_data() cache = await self.get_cached_data()
# Get base dataset # Get base dataset
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# Apply search filter # Special case: Filter by LoRA hash (takes precedence if bypass_filters is True)
if search: if lora_hash:
# Default search options if none provided # Filter recipes that contain this LoRA hash
if not search_options: filtered_data = [
search_options = { item for item in filtered_data
'title': True, if 'loras' in item and any(
'tags': True, lora.get('hash', '').lower() == lora_hash.lower()
'lora_name': True, for lora in item['loras']
'lora_model': True )
} ]
# Build the search predicate based on search options if bypass_filters:
def matches_search(item): # Skip other filters if bypass_filters is True
# Search in title if enabled pass
if search_options.get('title', True): # Otherwise continue with normal filtering after applying LoRA hash filter
if fuzzy_match(str(item.get('title', '')), search):
return True
# Search in tags if enabled
if search_options.get('tags', True) and 'tags' in item:
for tag in item['tags']:
if fuzzy_match(tag, search):
return True
# Search in lora file names if enabled
if search_options.get('lora_name', True) and 'loras' in item:
for lora in item['loras']:
if fuzzy_match(str(lora.get('file_name', '')), search):
return True
# Search in lora model names if enabled
if search_options.get('lora_model', True) and 'loras' in item:
for lora in item['loras']:
if fuzzy_match(str(lora.get('modelName', '')), search):
return True
# No match found
return False
# Filter the data using the search predicate
filtered_data = [item for item in filtered_data if matches_search(item)]
# Apply additional filters # Skip further filtering if we're only filtering by LoRA hash with bypass enabled
if filters: if not (lora_hash and bypass_filters):
# Filter by base model # Apply search filter
if 'base_model' in filters and filters['base_model']: if search:
filtered_data = [ # Default search options if none provided
item for item in filtered_data if not search_options:
if item.get('base_model', '') in filters['base_model'] search_options = {
] 'title': True,
'tags': True,
'lora_name': True,
'lora_model': True
}
# Build the search predicate based on search options
def matches_search(item):
# Search in title if enabled
if search_options.get('title', True):
if fuzzy_match(str(item.get('title', '')), search):
return True
# Search in tags if enabled
if search_options.get('tags', True) and 'tags' in item:
for tag in item['tags']:
if fuzzy_match(tag, search):
return True
# Search in lora file names if enabled
if search_options.get('lora_name', True) and 'loras' in item:
for lora in item['loras']:
if fuzzy_match(str(lora.get('file_name', '')), search):
return True
# Search in lora model names if enabled
if search_options.get('lora_model', True) and 'loras' in item:
for lora in item['loras']:
if fuzzy_match(str(lora.get('modelName', '')), search):
return True
# No match found
return False
# Filter the data using the search predicate
filtered_data = [item for item in filtered_data if matches_search(item)]
# Filter by tags # Apply additional filters
if 'tags' in filters and filters['tags']: if filters:
filtered_data = [ # Filter by base model
item for item in filtered_data if 'base_model' in filters and filters['base_model']:
if any(tag in item.get('tags', []) for tag in filters['tags']) filtered_data = [
] item for item in filtered_data
if item.get('base_model', '') in filters['base_model']
]
# Filter by tags
if 'tags' in filters and filters['tags']:
filtered_data = [
item for item in filtered_data
if any(tag in item.get('tags', []) for tag in filters['tags'])
]
# Calculate pagination # Calculate pagination
total_items = len(filtered_data) total_items = len(filtered_data)
@@ -448,4 +600,204 @@ class RecipeScanner:
'total_pages': (total_items + page_size - 1) // page_size 'total_pages': (total_items + page_size - 1) // page_size
} }
return result return result
async def get_recipe_by_id(self, recipe_id: str) -> dict:
"""Get a single recipe by ID with all metadata and formatted URLs
Args:
recipe_id: The ID of the recipe to retrieve
Returns:
Dict containing the recipe data or None if not found
"""
if not recipe_id:
return None
# Get all recipes from cache
cache = await self.get_cached_data()
# Find the recipe with the specified ID
recipe = next((r for r in cache.raw_data if str(r.get('id', '')) == recipe_id), None)
if not recipe:
return None
# Format the recipe with all needed information
formatted_recipe = {**recipe} # Copy all fields
# Format file path to URL
if 'file_path' in formatted_recipe:
formatted_recipe['file_url'] = self._format_file_url(formatted_recipe['file_path'])
# Format dates for display
for date_field in ['created_date', 'modified']:
if date_field in formatted_recipe:
formatted_recipe[f"{date_field}_formatted"] = self._format_timestamp(formatted_recipe[date_field])
# Add lora metadata
if 'loras' in formatted_recipe:
for lora in formatted_recipe['loras']:
if 'hash' in lora and lora['hash']:
lora_hash = lora['hash'].lower()
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora_hash)
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash)
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora_hash)
return formatted_recipe
def _format_file_url(self, file_path: str) -> str:
"""Format file path as URL for serving in web UI"""
if not file_path:
return '/loras_static/images/no-preview.png'
try:
# Format file path as a URL that will work with static file serving
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/')
if file_path.replace(os.sep, '/').startswith(recipes_dir):
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/')
return f"/loras_static/root1/preview/{relative_path}"
# If not in recipes dir, try to create a valid URL from the file name
file_name = os.path.basename(file_path)
return f"/loras_static/root1/preview/recipes/{file_name}"
except Exception as e:
logger.error(f"Error formatting file URL: {e}")
return '/loras_static/images/no-preview.png'
def _format_timestamp(self, timestamp: float) -> str:
"""Format timestamp for display"""
from datetime import datetime
return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S')
async def update_recipe_metadata(self, recipe_id: str, metadata: dict) -> bool:
"""Update recipe metadata (like title and tags) in both file system and cache
Args:
recipe_id: The ID of the recipe to update
metadata: Dictionary containing metadata fields to update (title, tags, etc.)
Returns:
bool: True if successful, False otherwise
"""
import os
import json
# First, find the recipe JSON file path
recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
if not os.path.exists(recipe_json_path):
return False
try:
# Load existing recipe data
with open(recipe_json_path, 'r', encoding='utf-8') as f:
recipe_data = json.load(f)
# Update fields
for key, value in metadata.items():
recipe_data[key] = value
# Save updated recipe
with open(recipe_json_path, 'w', encoding='utf-8') as f:
json.dump(recipe_data, f, indent=4, ensure_ascii=False)
# Update the cache if it exists
if self._cache is not None:
await self._cache.update_recipe_metadata(recipe_id, metadata)
# If the recipe has an image, update its EXIF metadata
from ..utils.exif_utils import ExifUtils
image_path = recipe_data.get('file_path')
if image_path and os.path.exists(image_path):
ExifUtils.append_recipe_metadata(image_path, recipe_data)
return True
except Exception as e:
import logging
logging.getLogger(__name__).error(f"Error updating recipe metadata: {e}", exc_info=True)
return False
async def update_lora_filename_by_hash(self, hash_value: str, new_file_name: str) -> Tuple[int, int]:
"""Update file_name in all recipes that contain a LoRA with the specified hash.
Args:
hash_value: The SHA256 hash value of the LoRA
new_file_name: The new file_name to set
Returns:
Tuple[int, int]: (number of recipes updated in files, number of recipes updated in cache)
"""
if not hash_value or not new_file_name:
return 0, 0
# Always use lowercase hash for consistency
hash_value = hash_value.lower()
# Get recipes directory
recipes_dir = self.recipes_dir
if not recipes_dir or not os.path.exists(recipes_dir):
logger.warning(f"Recipes directory not found: {recipes_dir}")
return 0, 0
# Check if cache is initialized
cache_initialized = self._cache is not None
cache_updated_count = 0
file_updated_count = 0
# Get all recipe JSON files in the recipes directory
recipe_files = []
for root, _, files in os.walk(recipes_dir):
for file in files:
if file.lower().endswith('.recipe.json'):
recipe_files.append(os.path.join(root, file))
# Process each recipe file
for recipe_path in recipe_files:
try:
# Load the recipe data
with open(recipe_path, 'r', encoding='utf-8') as f:
recipe_data = json.load(f)
# Skip if no loras or invalid structure
if not recipe_data or not isinstance(recipe_data, dict) or 'loras' not in recipe_data:
continue
# Check if any lora has matching hash
file_updated = False
for lora in recipe_data.get('loras', []):
if 'hash' in lora and lora['hash'].lower() == hash_value:
# Update file_name
old_file_name = lora.get('file_name', '')
lora['file_name'] = new_file_name
file_updated = True
logger.info(f"Updated file_name in recipe {recipe_path}: {old_file_name} -> {new_file_name}")
# If updated, save the file
if file_updated:
with open(recipe_path, 'w', encoding='utf-8') as f:
json.dump(recipe_data, f, indent=4, ensure_ascii=False)
file_updated_count += 1
# Also update in cache if it exists
if cache_initialized:
recipe_id = recipe_data.get('id')
if recipe_id:
for cache_item in self._cache.raw_data:
if cache_item.get('id') == recipe_id:
# Replace loras array with updated version
cache_item['loras'] = recipe_data['loras']
cache_updated_count += 1
break
except Exception as e:
logger.error(f"Error updating recipe file {recipe_path}: {e}")
import traceback
traceback.print_exc(file=sys.stderr)
# Resort cache if updates were made
if cache_initialized and cache_updated_count > 0:
await self._cache.resort()
logger.info(f"Resorted recipe cache after updating {cache_updated_count} items")
return file_updated_count, cache_updated_count

View File

@@ -0,0 +1,124 @@
import asyncio
import logging
from typing import Optional, Dict, Any, TypeVar, Type
logger = logging.getLogger(__name__)
T = TypeVar('T') # Define a type variable for service types
class ServiceRegistry:
"""Centralized registry for service singletons"""
_instance = None
_services: Dict[str, Any] = {}
_lock = asyncio.Lock()
@classmethod
def get_instance(cls):
"""Get singleton instance of the registry"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
async def register_service(cls, service_name: str, service_instance: Any) -> None:
"""Register a service instance with the registry"""
registry = cls.get_instance()
async with cls._lock:
registry._services[service_name] = service_instance
logger.debug(f"Registered service: {service_name}")
@classmethod
async def get_service(cls, service_name: str) -> Any:
"""Get a service instance by name"""
registry = cls.get_instance()
async with cls._lock:
if service_name not in registry._services:
logger.debug(f"Service {service_name} not found in registry")
return None
return registry._services[service_name]
# Convenience methods for common services
@classmethod
async def get_lora_scanner(cls):
"""Get the LoraScanner instance"""
from .lora_scanner import LoraScanner
scanner = await cls.get_service("lora_scanner")
if scanner is None:
scanner = await LoraScanner.get_instance()
await cls.register_service("lora_scanner", scanner)
return scanner
@classmethod
async def get_checkpoint_scanner(cls):
"""Get the CheckpointScanner instance"""
from .checkpoint_scanner import CheckpointScanner
scanner = await cls.get_service("checkpoint_scanner")
if scanner is None:
scanner = await CheckpointScanner.get_instance()
await cls.register_service("checkpoint_scanner", scanner)
return scanner
@classmethod
async def get_lora_monitor(cls):
"""Get the LoraFileMonitor instance"""
from .file_monitor import LoraFileMonitor
monitor = await cls.get_service("lora_monitor")
if monitor is None:
monitor = await LoraFileMonitor.get_instance()
await cls.register_service("lora_monitor", monitor)
return monitor
@classmethod
async def get_checkpoint_monitor(cls):
"""Get the CheckpointFileMonitor instance"""
from .file_monitor import CheckpointFileMonitor
monitor = await cls.get_service("checkpoint_monitor")
if monitor is None:
monitor = await CheckpointFileMonitor.get_instance()
await cls.register_service("checkpoint_monitor", monitor)
return monitor
@classmethod
async def get_civitai_client(cls):
"""Get the CivitaiClient instance"""
from .civitai_client import CivitaiClient
client = await cls.get_service("civitai_client")
if client is None:
client = await CivitaiClient.get_instance()
await cls.register_service("civitai_client", client)
return client
@classmethod
async def get_download_manager(cls):
"""Get the DownloadManager instance"""
from .download_manager import DownloadManager
manager = await cls.get_service("download_manager")
if manager is None:
# We'll let DownloadManager.get_instance handle file_monitor parameter
manager = await DownloadManager.get_instance()
await cls.register_service("download_manager", manager)
return manager
@classmethod
async def get_recipe_scanner(cls):
"""Get the RecipeScanner instance"""
from .recipe_scanner import RecipeScanner
scanner = await cls.get_service("recipe_scanner")
if scanner is None:
lora_scanner = await cls.get_lora_scanner()
scanner = RecipeScanner(lora_scanner)
await cls.register_service("recipe_scanner", scanner)
return scanner
@classmethod
async def get_websocket_manager(cls):
"""Get the WebSocketManager instance"""
from .websocket_manager import ws_manager
manager = await cls.get_service("websocket_manager")
if manager is None:
# ws_manager is already a global instance in websocket_manager.py
from .websocket_manager import ws_manager
await cls.register_service("websocket_manager", ws_manager)
manager = ws_manager
return manager

View File

@@ -9,6 +9,8 @@ class WebSocketManager:
def __init__(self): def __init__(self):
self._websockets: Set[web.WebSocketResponse] = set() self._websockets: Set[web.WebSocketResponse] = set()
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse: async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection""" """Handle new WebSocket connection"""
@@ -23,6 +25,34 @@ class WebSocketManager:
finally: finally:
self._websockets.discard(ws) self._websockets.discard(ws)
return ws return ws
async def handle_init_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection for initialization progress"""
ws = web.WebSocketResponse()
await ws.prepare(request)
self._init_websockets.add(ws)
try:
async for msg in ws:
if msg.type == web.WSMsgType.ERROR:
logger.error(f'Init WebSocket error: {ws.exception()}')
finally:
self._init_websockets.discard(ws)
return ws
async def handle_checkpoint_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection for checkpoint download progress"""
ws = web.WebSocketResponse()
await ws.prepare(request)
self._checkpoint_websockets.add(ws)
try:
async for msg in ws:
if msg.type == web.WSMsgType.ERROR:
logger.error(f'Checkpoint WebSocket error: {ws.exception()}')
finally:
self._checkpoint_websockets.discard(ws)
return ws
async def broadcast(self, data: Dict): async def broadcast(self, data: Dict):
"""Broadcast message to all connected clients""" """Broadcast message to all connected clients"""
@@ -34,10 +64,48 @@ class WebSocketManager:
await ws.send_json(data) await ws.send_json(data)
except Exception as e: except Exception as e:
logger.error(f"Error sending progress: {e}") logger.error(f"Error sending progress: {e}")
async def broadcast_init_progress(self, data: Dict):
"""Broadcast initialization progress to connected clients"""
if not self._init_websockets:
return
# Ensure data has all required fields
if 'stage' not in data:
data['stage'] = 'processing'
if 'progress' not in data:
data['progress'] = 0
if 'details' not in data:
data['details'] = 'Processing...'
for ws in self._init_websockets:
try:
await ws.send_json(data)
except Exception as e:
logger.error(f"Error sending initialization progress: {e}")
async def broadcast_checkpoint_progress(self, data: Dict):
"""Broadcast checkpoint download progress to connected clients"""
if not self._checkpoint_websockets:
return
for ws in self._checkpoint_websockets:
try:
await ws.send_json(data)
except Exception as e:
logger.error(f"Error sending checkpoint progress: {e}")
def get_connected_clients_count(self) -> int: def get_connected_clients_count(self) -> int:
"""Get number of connected clients""" """Get number of connected clients"""
return len(self._websockets) return len(self._websockets)
def get_init_clients_count(self) -> int:
"""Get number of initialization progress clients"""
return len(self._init_websockets)
def get_checkpoint_clients_count(self) -> int:
"""Get number of checkpoint progress clients"""
return len(self._checkpoint_websockets)
# Global instance # Global instance
ws_manager = WebSocketManager() ws_manager = WebSocketManager()

View File

@@ -5,4 +5,21 @@ NSFW_LEVELS = {
"X": 8, "X": 8,
"XXX": 16, "XXX": 16,
"Blocked": 32, # Probably not actually visible through the API without being logged in on model owner account? "Blocked": 32, # Probably not actually visible through the API without being logged in on model owner account?
} }
# preview extensions
PREVIEW_EXTENSIONS = [
'.webp',
'.preview.webp',
'.preview.png',
'.preview.jpeg',
'.preview.jpg',
'.preview.mp4',
'.png',
'.jpeg',
'.jpg',
'.mp4'
]
# Card preview image width
CARD_PREVIEW_WIDTH = 480

View File

@@ -1,11 +1,10 @@
import piexif import piexif
import json import json
import logging import logging
from typing import Dict, Optional, Any from typing import Optional
from io import BytesIO from io import BytesIO
import os import os
from PIL import Image from PIL import Image
import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -13,11 +12,23 @@ class ExifUtils:
"""Utility functions for working with EXIF data in images""" """Utility functions for working with EXIF data in images"""
@staticmethod @staticmethod
def extract_user_comment(image_path: str) -> Optional[str]: def extract_image_metadata(image_path: str) -> Optional[str]:
"""Extract UserComment field from image EXIF data""" """Extract metadata from image including UserComment or parameters field
Args:
image_path (str): Path to the image file
Returns:
Optional[str]: Extracted metadata or None if not found
"""
try: try:
# First try to open as image to check format # First try to open the image
with Image.open(image_path) as img: with Image.open(image_path) as img:
# Method 1: Check for parameters in image info
if hasattr(img, 'info') and 'parameters' in img.info:
return img.info['parameters']
# Method 2: Check EXIF UserComment field
if img.format not in ['JPEG', 'TIFF', 'WEBP']: if img.format not in ['JPEG', 'TIFF', 'WEBP']:
# For non-JPEG/TIFF/WEBP images, try to get EXIF through PIL # For non-JPEG/TIFF/WEBP images, try to get EXIF through PIL
exif = img._getexif() exif = img._getexif()
@@ -28,82 +39,106 @@ class ExifUtils:
return user_comment[8:].decode('utf-16be') return user_comment[8:].decode('utf-16be')
return user_comment.decode('utf-8', errors='ignore') return user_comment.decode('utf-8', errors='ignore')
return user_comment return user_comment
return None
# For JPEG/TIFF/WEBP, use piexif # For JPEG/TIFF/WEBP, use piexif
exif_dict = piexif.load(image_path) try:
exif_dict = piexif.load(image_path)
if piexif.ExifIFD.UserComment in exif_dict.get('Exif', {}):
user_comment = exif_dict['Exif'][piexif.ExifIFD.UserComment]
if isinstance(user_comment, bytes):
if user_comment.startswith(b'UNICODE\0'):
user_comment = user_comment[8:].decode('utf-16be')
else:
user_comment = user_comment.decode('utf-8', errors='ignore')
return user_comment
except Exception as e:
logger.debug(f"Error loading EXIF data: {e}")
# Method 3: Check PNG metadata for workflow info (for ComfyUI images)
if img.format == 'PNG':
# Look for workflow or prompt metadata in PNG chunks
for key in img.info:
if key in ['workflow', 'prompt', 'parameters']:
return img.info[key]
if piexif.ExifIFD.UserComment in exif_dict.get('Exif', {}):
user_comment = exif_dict['Exif'][piexif.ExifIFD.UserComment]
if isinstance(user_comment, bytes):
if user_comment.startswith(b'UNICODE\0'):
user_comment = user_comment[8:].decode('utf-16be')
else:
user_comment = user_comment.decode('utf-8', errors='ignore')
return user_comment
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error extracting image metadata: {e}", exc_info=True)
return None return None
@staticmethod @staticmethod
def update_user_comment(image_path: str, user_comment: str) -> str: def update_image_metadata(image_path: str, metadata: str) -> str:
"""Update UserComment field in image EXIF data""" """Update metadata in image's EXIF data or parameters fields
Args:
image_path (str): Path to the image file
metadata (str): Metadata string to save
Returns:
str: Path to the updated image
"""
try: try:
# Load the image and its EXIF data # Load the image and check its format
with Image.open(image_path) as img: with Image.open(image_path) as img:
# Get original format
img_format = img.format img_format = img.format
# For WebP format, we need a different approach # For PNG, try to update parameters directly
if img_format == 'WEBP': if img_format == 'PNG':
# WebP doesn't support standard EXIF through piexif # We'll save with parameters in the PNG info
# We'll use PIL's exif parameter directly info_dict = {'parameters': metadata}
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + user_comment.encode('utf-16be')}} img.save(image_path, format='PNG', pnginfo=info_dict)
return image_path
# For WebP format, use PIL's exif parameter directly
elif img_format == 'WEBP':
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}}
exif_bytes = piexif.dump(exif_dict) exif_bytes = piexif.dump(exif_dict)
# Save with the exif data # Save with the exif data
img.save(image_path, format='WEBP', exif=exif_bytes, quality=85) img.save(image_path, format='WEBP', exif=exif_bytes, quality=85)
return image_path return image_path
# For other formats, use the standard approach # For other formats, use standard EXIF approach
try: else:
exif_dict = piexif.load(img.info.get('exif', b'')) try:
except: exif_dict = piexif.load(img.info.get('exif', b''))
exif_dict = {'0th':{}, 'Exif':{}, 'GPS':{}, 'Interop':{}, '1st':{}} except:
exif_dict = {'0th':{}, 'Exif':{}, 'GPS':{}, 'Interop':{}, '1st':{}}
# If no Exif dictionary exists, create one
if 'Exif' not in exif_dict: # If no Exif dictionary exists, create one
exif_dict['Exif'] = {} if 'Exif' not in exif_dict:
exif_dict['Exif'] = {}
# Update the UserComment field - use UNICODE format
unicode_bytes = user_comment.encode('utf-16be') # Update the UserComment field - use UNICODE format
user_comment_bytes = b'UNICODE\0' + unicode_bytes unicode_bytes = metadata.encode('utf-16be')
metadata_bytes = b'UNICODE\0' + unicode_bytes
exif_dict['Exif'][piexif.ExifIFD.UserComment] = user_comment_bytes
exif_dict['Exif'][piexif.ExifIFD.UserComment] = metadata_bytes
# Convert EXIF dict back to bytes
exif_bytes = piexif.dump(exif_dict) # Convert EXIF dict back to bytes
exif_bytes = piexif.dump(exif_dict)
# Save the image with updated EXIF data
img.save(image_path, exif=exif_bytes) # Save the image with updated EXIF data
img.save(image_path, exif=exif_bytes)
return image_path return image_path
except Exception as e: except Exception as e:
logger.error(f"Error updating EXIF data in {image_path}: {e}") logger.error(f"Error updating metadata in {image_path}: {e}")
return image_path return image_path
@staticmethod @staticmethod
def append_recipe_metadata(image_path, recipe_data) -> str: def append_recipe_metadata(image_path, recipe_data) -> str:
"""Append recipe metadata to an image's EXIF data""" """Append recipe metadata to an image's EXIF data"""
try: try:
# First, extract existing user comment # First, extract existing metadata
user_comment = ExifUtils.extract_user_comment(image_path) metadata = ExifUtils.extract_image_metadata(image_path)
# Check if there's already recipe metadata in the user comment # Check if there's already recipe metadata
if user_comment: if metadata:
# Remove any existing recipe metadata # Remove any existing recipe metadata
user_comment = ExifUtils.remove_recipe_metadata(user_comment) metadata = ExifUtils.remove_recipe_metadata(metadata)
# Prepare simplified loras data # Prepare simplified loras data
simplified_loras = [] simplified_loras = []
@@ -133,11 +168,11 @@ class ExifUtils:
# Create the recipe metadata marker # Create the recipe metadata marker
recipe_metadata_marker = f"Recipe metadata: {recipe_metadata_json}" recipe_metadata_marker = f"Recipe metadata: {recipe_metadata_json}"
# Append to existing user comment or create new one # Append to existing metadata or create new one
new_user_comment = f"{user_comment} \n {recipe_metadata_marker}" if user_comment else recipe_metadata_marker new_metadata = f"{metadata} \n {recipe_metadata_marker}" if metadata else recipe_metadata_marker
# Write back to the image # Write back to the image
return ExifUtils.update_user_comment(image_path, new_user_comment) return ExifUtils.update_image_metadata(image_path, new_metadata)
except Exception as e: except Exception as e:
logger.error(f"Error appending recipe metadata: {e}", exc_info=True) logger.error(f"Error appending recipe metadata: {e}", exc_info=True)
return image_path return image_path
@@ -168,7 +203,7 @@ class ExifUtils:
return user_comment[:recipe_marker_index] + user_comment[next_line_index:] return user_comment[:recipe_marker_index] + user_comment[next_line_index:]
@staticmethod @staticmethod
def optimize_image(image_data, target_width=250, format='webp', quality=85, preserve_metadata=True): def optimize_image(image_data, target_width=250, format='webp', quality=85, preserve_metadata=False):
""" """
Optimize an image by resizing and converting to WebP format Optimize an image by resizing and converting to WebP format
@@ -183,304 +218,144 @@ class ExifUtils:
Tuple of (optimized_image_data, extension) Tuple of (optimized_image_data, extension)
""" """
try: try:
# Extract metadata if needed # First validate the image data is usable
user_comment = None img = None
if preserve_metadata: if isinstance(image_data, str) and os.path.exists(image_data):
if isinstance(image_data, str) and os.path.exists(image_data): # It's a file path - validate file
# It's a file path try:
user_comment = ExifUtils.extract_user_comment(image_data) with Image.open(image_data) as test_img:
# Verify the image can be fully loaded by accessing its size
width, height = test_img.size
# If we got here, the image is valid
img = Image.open(image_data) img = Image.open(image_data)
else: except (IOError, OSError) as e:
# It's binary data logger.error(f"Invalid or corrupt image file: {image_data}: {e}")
temp_img = BytesIO(image_data) raise ValueError(f"Cannot process corrupt image: {e}")
img = Image.open(temp_img)
# Save to a temporary file to extract metadata
import tempfile
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp_file:
temp_path = temp_file.name
temp_file.write(image_data)
user_comment = ExifUtils.extract_user_comment(temp_path)
os.unlink(temp_path)
else: else:
# Just open the image without extracting metadata # It's binary data - validate data
if isinstance(image_data, str) and os.path.exists(image_data): try:
img = Image.open(image_data) with BytesIO(image_data) as temp_buf:
else: test_img = Image.open(temp_buf)
# Verify the image can be fully loaded
width, height = test_img.size
# If successful, reopen for processing
img = Image.open(BytesIO(image_data)) img = Image.open(BytesIO(image_data))
except Exception as e:
logger.error(f"Invalid binary image data: {e}")
raise ValueError(f"Cannot process corrupt image data: {e}")
# Extract metadata if needed and valid
metadata = None
if preserve_metadata:
try:
if isinstance(image_data, str) and os.path.exists(image_data):
# For file path, extract directly
metadata = ExifUtils.extract_image_metadata(image_data)
else:
# For binary data, save to temp file first
import tempfile
with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp_file:
temp_path = temp_file.name
temp_file.write(image_data)
try:
metadata = ExifUtils.extract_image_metadata(temp_path)
except Exception as e:
logger.warning(f"Failed to extract metadata from temp file: {e}")
finally:
# Clean up temp file
try:
os.unlink(temp_path)
except Exception:
pass
except Exception as e:
logger.warning(f"Failed to extract metadata, continuing without it: {e}")
# Continue without metadata
# Calculate new height to maintain aspect ratio # Calculate new height to maintain aspect ratio
width, height = img.size width, height = img.size
new_height = int(height * (target_width / width)) new_height = int(height * (target_width / width))
# Resize the image # Resize the image with error handling
resized_img = img.resize((target_width, new_height), Image.LANCZOS) try:
resized_img = img.resize((target_width, new_height), Image.LANCZOS)
except Exception as e:
logger.error(f"Failed to resize image: {e}")
# Return original image if resize fails
return image_data, '.jpg' if not isinstance(image_data, str) else os.path.splitext(image_data)[1]
# Save to BytesIO in the specified format # Save to BytesIO in the specified format
output = BytesIO() output = BytesIO()
# WebP format # Set format and extension
if format.lower() == 'webp': if format.lower() == 'webp':
resized_img.save(output, format='WEBP', quality=quality) save_format, extension = 'WEBP', '.webp'
extension = '.webp'
# JPEG format
elif format.lower() in ('jpg', 'jpeg'): elif format.lower() in ('jpg', 'jpeg'):
resized_img.save(output, format='JPEG', quality=quality) save_format, extension = 'JPEG', '.jpg'
extension = '.jpg'
# PNG format
elif format.lower() == 'png': elif format.lower() == 'png':
resized_img.save(output, format='PNG', optimize=True) save_format, extension = 'PNG', '.png'
extension = '.png'
else: else:
# Default to WebP save_format, extension = 'WEBP', '.webp'
resized_img.save(output, format='WEBP', quality=quality)
extension = '.webp' # Save with error handling
try:
if save_format == 'PNG':
resized_img.save(output, format=save_format, optimize=True)
else:
resized_img.save(output, format=save_format, quality=quality)
except Exception as e:
logger.error(f"Failed to save optimized image: {e}")
# Return original image if save fails
return image_data, '.jpg' if not isinstance(image_data, str) else os.path.splitext(image_data)[1]
# Get the optimized image data # Get the optimized image data
optimized_data = output.getvalue() optimized_data = output.getvalue()
# If we need to preserve metadata, write it to a temporary file # Handle metadata preservation if requested and available
if preserve_metadata and user_comment: if preserve_metadata and metadata:
# For WebP format, we'll directly save with metadata try:
if format.lower() == 'webp': if save_format == 'WEBP':
# Create a new BytesIO with metadata # For WebP format, directly save with metadata
output_with_metadata = BytesIO() try:
output_with_metadata = BytesIO()
# Create EXIF data with user comment exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}}
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + user_comment.encode('utf-16be')}} exif_bytes = piexif.dump(exif_dict)
exif_bytes = piexif.dump(exif_dict) resized_img.save(output_with_metadata, format='WEBP', exif=exif_bytes, quality=quality)
optimized_data = output_with_metadata.getvalue()
# Save with metadata except Exception as e:
resized_img.save(output_with_metadata, format='WEBP', exif=exif_bytes, quality=quality) logger.warning(f"Failed to add metadata to WebP, continuing without it: {e}")
optimized_data = output_with_metadata.getvalue() else:
else: # For other formats, use temporary file
# For other formats, use the temporary file approach import tempfile
import tempfile with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_file:
with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as temp_file: temp_path = temp_file.name
temp_path = temp_file.name temp_file.write(optimized_data)
temp_file.write(optimized_data)
try:
# Add the metadata back # Add metadata
ExifUtils.update_user_comment(temp_path, user_comment) ExifUtils.update_image_metadata(temp_path, metadata)
# Read back the file
# Read the file with metadata with open(temp_path, 'rb') as f:
with open(temp_path, 'rb') as f: optimized_data = f.read()
optimized_data = f.read() except Exception as e:
logger.warning(f"Failed to add metadata to image, continuing without it: {e}")
# Clean up finally:
os.unlink(temp_path) # Clean up temp file
try:
os.unlink(temp_path)
except Exception:
pass
except Exception as e:
logger.warning(f"Failed to preserve metadata: {e}, continuing with unmodified output")
return optimized_data, extension return optimized_data, extension
except Exception as e: except Exception as e:
logger.error(f"Error optimizing image: {e}", exc_info=True) logger.error(f"Error optimizing image: {e}", exc_info=True)
# Return original data if optimization fails # Return original data if optimization completely fails
if isinstance(image_data, str) and os.path.exists(image_data): if isinstance(image_data, str) and os.path.exists(image_data):
with open(image_data, 'rb') as f:
return f.read(), os.path.splitext(image_data)[1]
return image_data, '.jpg'
@staticmethod
def _parse_comfyui_workflow(workflow_data: Any) -> Dict[str, Any]:
"""
Parse ComfyUI workflow data and extract relevant generation parameters
Args:
workflow_data: Raw workflow data (string or dict)
Returns:
Formatted generation parameters dictionary
"""
try:
# If workflow_data is a string, try to parse it as JSON
if isinstance(workflow_data, str):
try: try:
workflow_data = json.loads(workflow_data) with open(image_data, 'rb') as f:
except json.JSONDecodeError: return f.read(), os.path.splitext(image_data)[1]
logger.error("Failed to parse workflow data as JSON") except Exception:
return {} return image_data, '.jpg' # Last resort fallback
return image_data, '.jpg'
# Now workflow_data should be a dictionary
if not isinstance(workflow_data, dict):
logger.error(f"Workflow data is not a dictionary: {type(workflow_data)}")
return {}
# Initialize parameters dictionary with only the required fields
gen_params = {
"prompt": "",
"negative_prompt": "",
"steps": "",
"sampler": "",
"cfg_scale": "",
"seed": "",
"size": "",
"clip_skip": ""
}
# First pass: find the KSampler node to get basic parameters and node references
# Store node references to follow for prompts
positive_ref = None
negative_ref = None
for node_id, node_data in workflow_data.items():
if not isinstance(node_data, dict):
continue
# Extract node inputs if available
inputs = node_data.get("inputs", {})
if not inputs:
continue
# KSampler nodes contain most generation parameters and references to prompt nodes
if "KSampler" in node_data.get("class_type", ""):
# Extract basic sampling parameters
gen_params["steps"] = inputs.get("steps", "")
gen_params["cfg_scale"] = inputs.get("cfg", "")
gen_params["sampler"] = inputs.get("sampler_name", "")
gen_params["seed"] = inputs.get("seed", "")
if isinstance(gen_params["seed"], list) and len(gen_params["seed"]) > 1:
gen_params["seed"] = gen_params["seed"][1] # Use the actual value if it's a list
# Get references to positive and negative prompt nodes
positive_ref = inputs.get("positive", "")
negative_ref = inputs.get("negative", "")
# CLIPSetLastLayer contains clip_skip information
elif "CLIPSetLastLayer" in node_data.get("class_type", ""):
gen_params["clip_skip"] = inputs.get("stop_at_clip_layer", "")
if isinstance(gen_params["clip_skip"], int) and gen_params["clip_skip"] < 0:
# Convert negative layer index to positive clip skip value
gen_params["clip_skip"] = abs(gen_params["clip_skip"])
# Look for resolution information
elif "LatentImage" in node_data.get("class_type", "") or "Empty" in node_data.get("class_type", ""):
width = inputs.get("width", 0)
height = inputs.get("height", 0)
if width and height:
gen_params["size"] = f"{width}x{height}"
# Some nodes have resolution as a string like "832x1216 (0.68)"
resolution = inputs.get("resolution", "")
if isinstance(resolution, str) and "x" in resolution:
gen_params["size"] = resolution.split(" ")[0] # Extract just the dimensions
# Helper function to follow node references and extract text content
def get_text_from_node_ref(node_ref, workflow_data):
if not node_ref or not isinstance(node_ref, list) or len(node_ref) < 2:
return ""
node_id, slot_idx = node_ref
# If we can't find the node, return empty string
if node_id not in workflow_data:
return ""
node = workflow_data[node_id]
inputs = node.get("inputs", {})
# Direct text input in CLIP Text Encode nodes
if "CLIPTextEncode" in node.get("class_type", ""):
text = inputs.get("text", "")
if isinstance(text, str):
return text
elif isinstance(text, list) and len(text) >= 2:
# If text is a reference to another node, follow it
return get_text_from_node_ref(text, workflow_data)
# Other nodes might have text input with different field names
for field_name, field_value in inputs.items():
if field_name == "text" and isinstance(field_value, str):
return field_value
elif isinstance(field_value, list) and len(field_value) >= 2 and field_name in ["text"]:
# If it's a reference to another node, follow it
return get_text_from_node_ref(field_value, workflow_data)
return ""
# Extract prompts by following references from KSampler node
if positive_ref:
gen_params["prompt"] = get_text_from_node_ref(positive_ref, workflow_data)
if negative_ref:
gen_params["negative_prompt"] = get_text_from_node_ref(negative_ref, workflow_data)
# Fallback: if we couldn't extract prompts via references, use the traditional method
if not gen_params["prompt"] or not gen_params["negative_prompt"]:
for node_id, node_data in workflow_data.items():
if not isinstance(node_data, dict):
continue
inputs = node_data.get("inputs", {})
if not inputs:
continue
if "CLIPTextEncode" in node_data.get("class_type", ""):
# Check for negative prompt nodes
title = node_data.get("_meta", {}).get("title", "").lower()
prompt_text = inputs.get("text", "")
if isinstance(prompt_text, str):
if "negative" in title and not gen_params["negative_prompt"]:
gen_params["negative_prompt"] = prompt_text
elif prompt_text and not "negative" in title and not gen_params["prompt"]:
gen_params["prompt"] = prompt_text
return gen_params
except Exception as e:
logger.error(f"Error parsing ComfyUI workflow: {e}", exc_info=True)
return {}
@staticmethod
def extract_comfyui_gen_params(image_path: str) -> Dict[str, Any]:
"""
Extract ComfyUI workflow data from PNG images and format for recipe data
Only extracts the specific generation parameters needed for recipes.
Args:
image_path: Path to the ComfyUI-generated PNG image
Returns:
Dictionary containing formatted generation parameters
"""
try:
# Check if the file exists and is accessible
if not os.path.exists(image_path):
logger.error(f"Image file not found: {image_path}")
return {}
# Open the image to extract embedded workflow data
with Image.open(image_path) as img:
workflow_data = None
# For PNG images, look for the ComfyUI workflow data in PNG chunks
if img.format == 'PNG':
# Check standard metadata fields that might contain workflow
if 'parameters' in img.info:
workflow_data = img.info['parameters']
elif 'prompt' in img.info:
workflow_data = img.info['prompt']
else:
# Look for other potential field names that might contain workflow data
for key in img.info:
if isinstance(key, str) and ('workflow' in key.lower() or 'comfy' in key.lower()):
workflow_data = img.info[key]
break
# If no workflow data found in PNG chunks, try EXIF as fallback
if not workflow_data:
user_comment = ExifUtils.extract_user_comment(image_path)
if user_comment and '{' in user_comment and '}' in user_comment:
# Try to extract JSON part
json_start = user_comment.find('{')
json_end = user_comment.rfind('}') + 1
workflow_data = user_comment[json_start:json_end]
# Parse workflow data if found
if workflow_data:
return ExifUtils._parse_comfyui_workflow(workflow_data)
return {}
except Exception as e:
logger.error(f"Error extracting ComfyUI gen params from {image_path}: {e}", exc_info=True)
return {}

View File

@@ -2,12 +2,14 @@ import logging
import os import os
import hashlib import hashlib
import json import json
from typing import Dict, Optional import time
from typing import Dict, Optional, Type
from .model_utils import determine_base_model from .model_utils import determine_base_model
from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata
from .lora_metadata import extract_lora_metadata from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata
from .models import LoraMetadata from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from .exif_utils import ExifUtils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -15,35 +17,56 @@ async def calculate_sha256(file_path: str) -> str:
"""Calculate SHA256 hash of a file""" """Calculate SHA256 hash of a file"""
sha256_hash = hashlib.sha256() sha256_hash = hashlib.sha256()
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
for byte_block in iter(lambda: f.read(4096), b""): for byte_block in iter(lambda: f.read(128 * 1024), b""):
sha256_hash.update(byte_block) sha256_hash.update(byte_block)
return sha256_hash.hexdigest() return sha256_hash.hexdigest()
def _find_preview_file(base_name: str, dir_path: str) -> str: def find_preview_file(base_name: str, dir_path: str) -> str:
"""Find preview file for given base name in directory""" """Find preview file for given base name in directory"""
preview_patterns = [
f"{base_name}.preview.png",
f"{base_name}.preview.jpg",
f"{base_name}.preview.jpeg",
f"{base_name}.preview.mp4",
f"{base_name}.png",
f"{base_name}.jpg",
f"{base_name}.jpeg",
f"{base_name}.mp4"
]
for pattern in preview_patterns: for ext in PREVIEW_EXTENSIONS:
full_pattern = os.path.join(dir_path, pattern) full_pattern = os.path.join(dir_path, f"{base_name}{ext}")
if os.path.exists(full_pattern): if os.path.exists(full_pattern):
# Check if this is an image and not already webp
if ext.lower().endswith(('.jpg', '.jpeg', '.png')) and not ext.lower().endswith('.webp'):
try:
# Optimize the image to webp format
webp_path = os.path.join(dir_path, f"{base_name}.webp")
# Use ExifUtils to optimize the image
with open(full_pattern, 'rb') as f:
image_data = f.read()
optimized_data, _ = ExifUtils.optimize_image(
image_data=image_data,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=False # Changed from True to False
)
# Save the optimized webp file
with open(webp_path, 'wb') as f:
f.write(optimized_data)
logger.debug(f"Optimized preview image from {full_pattern} to {webp_path}")
return webp_path.replace(os.sep, "/")
except Exception as e:
logger.error(f"Error optimizing preview image {full_pattern}: {e}")
# Fall back to original file if optimization fails
return full_pattern.replace(os.sep, "/")
# Return the original path for webp images or non-image files
return full_pattern.replace(os.sep, "/") return full_pattern.replace(os.sep, "/")
return "" return ""
def normalize_path(path: str) -> str: def normalize_path(path: str) -> str:
"""Normalize file path to use forward slashes""" """Normalize file path to use forward slashes"""
return path.replace(os.sep, "/") if path else path return path.replace(os.sep, "/") if path else path
async def get_file_info(file_path: str) -> Optional[LoraMetadata]: async def get_file_info(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
"""Get basic file information as LoraMetadata object""" """Get basic file information as a model metadata object"""
# First check if file actually exists and resolve symlinks # First check if file actually exists and resolve symlinks
try: try:
real_path = os.path.realpath(file_path) real_path = os.path.realpath(file_path)
@@ -56,28 +79,81 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
base_name = os.path.splitext(os.path.basename(file_path))[0] base_name = os.path.splitext(os.path.basename(file_path))[0]
dir_path = os.path.dirname(file_path) dir_path = os.path.dirname(file_path)
preview_url = _find_preview_file(base_name, dir_path) preview_url = find_preview_file(base_name, dir_path)
# Check if a .json file exists with SHA256 hash to avoid recalculation
json_path = f"{os.path.splitext(file_path)[0]}.json"
sha256 = None
if os.path.exists(json_path):
try:
with open(json_path, 'r', encoding='utf-8') as f:
json_data = json.load(f)
if 'sha256' in json_data:
sha256 = json_data['sha256'].lower()
logger.debug(f"Using SHA256 from .json file for {file_path}")
except Exception as e:
logger.error(f"Error reading .json file for {file_path}: {e}")
# If SHA256 is still not found, check for a .sha256 file
if sha256 is None:
sha256_file = f"{os.path.splitext(file_path)[0]}.sha256"
if os.path.exists(sha256_file):
try:
with open(sha256_file, 'r', encoding='utf-8') as f:
sha256 = f.read().strip().lower()
logger.debug(f"Using SHA256 from .sha256 file for {file_path}")
except Exception as e:
logger.error(f"Error reading .sha256 file for {file_path}: {e}")
try: try:
metadata = LoraMetadata( # If we didn't get SHA256 from the .json file, calculate it
file_name=base_name, if not sha256:
model_name=base_name, start_time = time.time()
file_path=normalize_path(file_path), sha256 = await calculate_sha256(real_path)
size=os.path.getsize(real_path), logger.debug(f"Calculated SHA256 for {file_path} in {time.time() - start_time:.2f} seconds")
modified=os.path.getmtime(real_path),
sha256=await calculate_sha256(real_path), # Create default metadata based on model class
base_model="Unknown", # Will be updated later if model_class == CheckpointMetadata:
usage_tips="", metadata = CheckpointMetadata(
notes="", file_name=base_name,
from_civitai=True, model_name=base_name,
preview_url=normalize_path(preview_url), file_path=normalize_path(file_path),
tags=[], size=os.path.getsize(real_path),
modelDescription="" modified=os.path.getmtime(real_path),
) sha256=sha256,
base_model="Unknown", # Will be updated later
preview_url=normalize_path(preview_url),
tags=[],
modelDescription="",
model_type="checkpoint"
)
# Extract checkpoint-specific metadata
# model_info = await extract_checkpoint_metadata(real_path)
# metadata.base_model = model_info['base_model']
# if 'model_type' in model_info:
# metadata.model_type = model_info['model_type']
else: # Default to LoraMetadata
metadata = LoraMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
usage_tips="{}",
preview_url=normalize_path(preview_url),
tags=[],
modelDescription=""
)
# Extract lora-specific metadata
model_info = await extract_lora_metadata(real_path)
metadata.base_model = model_info['base_model']
# create metadata file # Save metadata to file
base_model_info = await extract_lora_metadata(real_path)
metadata.base_model = base_model_info['base_model']
await save_metadata(file_path, metadata) await save_metadata(file_path, metadata)
return metadata return metadata
@@ -85,7 +161,7 @@ async def get_file_info(file_path: str) -> Optional[LoraMetadata]:
logger.error(f"Error getting file info for {file_path}: {e}") logger.error(f"Error getting file info for {file_path}: {e}")
return None return None
async def save_metadata(file_path: str, metadata: LoraMetadata) -> None: async def save_metadata(file_path: str, metadata: BaseModelMetadata) -> None:
"""Save metadata to .metadata.json file""" """Save metadata to .metadata.json file"""
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json" metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
try: try:
@@ -98,7 +174,7 @@ async def save_metadata(file_path: str, metadata: LoraMetadata) -> None:
except Exception as e: except Exception as e:
print(f"Error saving metadata to {metadata_path}: {str(e)}") print(f"Error saving metadata to {metadata_path}: {str(e)}")
async def load_metadata(file_path: str) -> Optional[LoraMetadata]: async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
"""Load metadata from .metadata.json file""" """Load metadata from .metadata.json file"""
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json" metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
try: try:
@@ -121,11 +197,12 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
data['file_path'] = normalize_path(file_path) data['file_path'] = normalize_path(file_path)
needs_update = True needs_update = True
# TODO: optimize preview image to webp format if not already done
preview_url = data.get('preview_url', '') preview_url = data.get('preview_url', '')
if not preview_url or not os.path.exists(preview_url): if not preview_url or not os.path.exists(preview_url):
base_name = os.path.splitext(os.path.basename(file_path))[0] base_name = os.path.splitext(os.path.basename(file_path))[0]
dir_path = os.path.dirname(file_path) dir_path = os.path.dirname(file_path)
new_preview_url = normalize_path(_find_preview_file(base_name, dir_path)) new_preview_url = normalize_path(find_preview_file(base_name, dir_path))
if new_preview_url != preview_url: if new_preview_url != preview_url:
data['preview_url'] = new_preview_url data['preview_url'] = new_preview_url
needs_update = True needs_update = True
@@ -145,12 +222,22 @@ async def load_metadata(file_path: str) -> Optional[LoraMetadata]:
if 'modelDescription' not in data: if 'modelDescription' not in data:
data['modelDescription'] = "" data['modelDescription'] = ""
needs_update = True needs_update = True
# For checkpoint metadata
if model_class == CheckpointMetadata and 'model_type' not in data:
data['model_type'] = "checkpoint"
needs_update = True
# For lora metadata
if model_class == LoraMetadata and 'usage_tips' not in data:
data['usage_tips'] = "{}"
needs_update = True
if needs_update: if needs_update:
with open(metadata_path, 'w', encoding='utf-8') as f: with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2, ensure_ascii=False) json.dump(data, f, indent=2, ensure_ascii=False)
return LoraMetadata.from_dict(data) return model_class.from_dict(data)
except Exception as e: except Exception as e:
print(f"Error loading metadata from {metadata_path}: {str(e)}") print(f"Error loading metadata from {metadata_path}: {str(e)}")

View File

@@ -1,6 +1,7 @@
from safetensors import safe_open from safetensors import safe_open
from typing import Dict from typing import Dict
from .model_utils import determine_base_model from .model_utils import determine_base_model
import os
async def extract_lora_metadata(file_path: str) -> Dict: async def extract_lora_metadata(file_path: str) -> Dict:
"""Extract essential metadata from safetensors file""" """Extract essential metadata from safetensors file"""
@@ -13,4 +14,67 @@ async def extract_lora_metadata(file_path: str) -> Dict:
return {"base_model": base_model} return {"base_model": base_model}
except Exception as e: except Exception as e:
print(f"Error reading metadata from {file_path}: {str(e)}") print(f"Error reading metadata from {file_path}: {str(e)}")
return {"base_model": "Unknown"} return {"base_model": "Unknown"}
async def extract_checkpoint_metadata(file_path: str) -> dict:
"""Extract metadata from a checkpoint file to determine model type and base model"""
try:
# Analyze filename for clues about the model
filename = os.path.basename(file_path).lower()
model_info = {
'base_model': 'Unknown',
'model_type': 'checkpoint'
}
# Detect base model from filename
if 'xl' in filename or 'sdxl' in filename:
model_info['base_model'] = 'SDXL'
elif 'sd3' in filename:
model_info['base_model'] = 'SD3'
elif 'sd2' in filename or 'v2' in filename:
model_info['base_model'] = 'SD2.x'
elif 'sd1' in filename or 'v1' in filename:
model_info['base_model'] = 'SD1.5'
# Detect model type from filename
if 'inpaint' in filename:
model_info['model_type'] = 'inpainting'
elif 'anime' in filename:
model_info['model_type'] = 'anime'
elif 'realistic' in filename:
model_info['model_type'] = 'realistic'
# Try to peek at the safetensors file structure if available
if file_path.endswith('.safetensors'):
import json
import struct
with open(file_path, 'rb') as f:
header_size = struct.unpack('<Q', f.read(8))[0]
header_json = f.read(header_size)
header = json.loads(header_json)
# Look for specific keys to identify model type
metadata = header.get('__metadata__', {})
if metadata:
# Try to determine if it's SDXL
if any(key.startswith('conditioner.embedders.1') for key in header):
model_info['base_model'] = 'SDXL'
# Look for model type info
if metadata.get('modelspec.architecture') == 'SD-XL':
model_info['base_model'] = 'SDXL'
elif metadata.get('modelspec.architecture') == 'SD-3':
model_info['base_model'] = 'SD3'
# Check for specific use case
if metadata.get('modelspec.purpose') == 'inpainting':
model_info['model_type'] = 'inpainting'
return model_info
except Exception as e:
logger.error(f"Error extracting checkpoint metadata for {file_path}: {e}")
# Return default values
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}

View File

@@ -5,23 +5,23 @@ import os
from .model_utils import determine_base_model from .model_utils import determine_base_model
@dataclass @dataclass
class LoraMetadata: class BaseModelMetadata:
"""Represents the metadata structure for a Lora model""" """Base class for all model metadata structures"""
file_name: str # The filename without extension of the lora file_name: str # The filename without extension
model_name: str # The lora's name defined by the creator, initially same as file_name model_name: str # The model's name defined by the creator
file_path: str # Full path to the safetensors file file_path: str # Full path to the model file
size: int # File size in bytes size: int # File size in bytes
modified: float # Last modified timestamp modified: float # Last modified timestamp
sha256: str # SHA256 hash of the file sha256: str # SHA256 hash of the file
base_model: str # Base model (SD1.5/SD2.1/SDXL/etc.) base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.)
preview_url: str # Preview image URL preview_url: str # Preview image URL
preview_nsfw_level: int = 0 # NSFW level of the preview image preview_nsfw_level: int = 0 # NSFW level of the preview image
usage_tips: str = "{}" # Usage tips for the model, json string
notes: str = "" # Additional notes notes: str = "" # Additional notes
from_civitai: bool = True # Whether the lora is from Civitai from_civitai: bool = True # Whether from Civitai
civitai: Optional[Dict] = None # Civitai API data if available civitai: Optional[Dict] = None # Civitai API data if available
tags: List[str] = None # Model tags tags: List[str] = None # Model tags
modelDescription: str = "" # Full model description modelDescription: str = "" # Full model description
civitai_deleted: bool = False # Whether deleted from Civitai
def __post_init__(self): def __post_init__(self):
# Initialize empty lists to avoid mutable default parameter issue # Initialize empty lists to avoid mutable default parameter issue
@@ -29,32 +29,11 @@ class LoraMetadata:
self.tags = [] self.tags = []
@classmethod @classmethod
def from_dict(cls, data: Dict) -> 'LoraMetadata': def from_dict(cls, data: Dict) -> 'BaseModelMetadata':
"""Create LoraMetadata instance from dictionary""" """Create instance from dictionary"""
# Create a copy of the data to avoid modifying the input
data_copy = data.copy() data_copy = data.copy()
return cls(**data_copy) return cls(**data_copy)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
"""Create LoraMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0, # Will be updated after preview download, it is decided by the nsfw level of the preview image
from_civitai=True,
civitai=version_info
)
def to_dict(self) -> Dict: def to_dict(self) -> Dict:
"""Convert to dictionary for JSON serialization""" """Convert to dictionary for JSON serialization"""
return asdict(self) return asdict(self)
@@ -75,3 +54,77 @@ class LoraMetadata:
self.modified = os.path.getmtime(file_path) self.modified = os.path.getmtime(file_path)
self.file_path = file_path.replace(os.sep, '/') self.file_path = file_path.replace(os.sep, '/')
@dataclass
class LoraMetadata(BaseModelMetadata):
"""Represents the metadata structure for a Lora model"""
usage_tips: str = "{}" # Usage tips for the model, json string
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'LoraMetadata':
"""Create LoraMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
# Extract tags and description if available
tags = []
description = ""
if 'model' in version_info:
if 'tags' in version_info['model']:
tags = version_info['model']['tags']
if 'description' in version_info['model']:
description = version_info['model']['description']
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0, # Will be updated after preview download
from_civitai=True,
civitai=version_info,
tags=tags,
modelDescription=description
)
@dataclass
class CheckpointMetadata(BaseModelMetadata):
"""Represents the metadata structure for a Checkpoint model"""
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata':
"""Create CheckpointMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
model_type = version_info.get('type', 'checkpoint')
# Extract tags and description if available
tags = []
description = ""
if 'model' in version_info:
if 'tags' in version_info['model']:
tags = version_info['model']['tags']
if 'description' in version_info['model']:
description = version_info['model']['description']
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0,
from_civitai=True,
civitai=version_info,
model_type=model_type,
tags=tags,
modelDescription=description
)

View File

@@ -44,6 +44,141 @@ class RecipeMetadataParser(ABC):
Dict containing parsed recipe data with standardized format Dict containing parsed recipe data with standardized format
""" """
pass pass
async def populate_lora_from_civitai(self, lora_entry: Dict[str, Any], civitai_info_tuple: Tuple[Dict[str, Any], Optional[str]],
recipe_scanner=None, base_model_counts=None, hash_value=None) -> Dict[str, Any]:
"""
Populate a lora entry with information from Civitai API response
Args:
lora_entry: The lora entry to populate
civitai_info_tuple: The response tuple from Civitai API (data, error_msg)
recipe_scanner: Optional recipe scanner for local file lookup
base_model_counts: Optional dict to track base model counts
hash_value: Optional hash value to use if not available in civitai_info
Returns:
The populated lora_entry dict
"""
try:
# Unpack the tuple to get the actual data
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
if civitai_info and civitai_info.get("error") != "Model not found":
# Check if this is an early access lora
if civitai_info.get('earlyAccessEndsAt'):
# Convert earlyAccessEndsAt to a human-readable date
early_access_date = civitai_info.get('earlyAccessEndsAt', '')
lora_entry['isEarlyAccess'] = True
lora_entry['earlyAccessEndsAt'] = early_access_date
# Update model name if available
if 'model' in civitai_info and 'name' in civitai_info['model']:
lora_entry['name'] = civitai_info['model']['name']
# Update version if available
if 'name' in civitai_info:
lora_entry['version'] = civitai_info.get('name', '')
# Get thumbnail URL from first image
if 'images' in civitai_info and civitai_info['images']:
lora_entry['thumbnailUrl'] = civitai_info['images'][0].get('url', '')
# Get base model
current_base_model = civitai_info.get('baseModel', '')
lora_entry['baseModel'] = current_base_model
# Update base model counts if tracking them
if base_model_counts is not None and current_base_model:
base_model_counts[current_base_model] = base_model_counts.get(current_base_model, 0) + 1
# Get download URL
lora_entry['downloadUrl'] = civitai_info.get('downloadUrl', '')
# Process file information if available
if 'files' in civitai_info:
model_file = next((file for file in civitai_info.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
# Get size
lora_entry['size'] = model_file.get('sizeKB', 0) * 1024
# Get SHA256 hash
sha256 = model_file.get('hashes', {}).get('SHA256', hash_value)
if sha256:
lora_entry['hash'] = sha256.lower()
# Check if exists locally
if recipe_scanner and lora_entry['hash']:
lora_scanner = recipe_scanner._lora_scanner
exists_locally = lora_scanner.has_lora_hash(lora_entry['hash'])
if exists_locally:
try:
local_path = lora_scanner.get_lora_path_by_hash(lora_entry['hash'])
lora_entry['existsLocally'] = True
lora_entry['localPath'] = local_path
lora_entry['file_name'] = os.path.splitext(os.path.basename(local_path))[0]
# Get thumbnail from local preview if available
lora_cache = await lora_scanner.get_cached_data()
lora_item = next((item for item in lora_cache.raw_data
if item['sha256'].lower() == lora_entry['hash'].lower()), None)
if lora_item and 'preview_url' in lora_item:
lora_entry['thumbnailUrl'] = config.get_preview_static_url(lora_item['preview_url'])
except Exception as e:
logger.error(f"Error getting local lora path: {e}")
else:
# For missing LoRAs, get file_name from model_file.name
file_name = model_file.get('name', '')
lora_entry['file_name'] = os.path.splitext(file_name)[0] if file_name else ''
else:
# Model not found or deleted
lora_entry['isDeleted'] = True
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
except Exception as e:
logger.error(f"Error populating lora from Civitai info: {e}")
return lora_entry
async def populate_checkpoint_from_civitai(self, checkpoint: Dict[str, Any], civitai_info: Dict[str, Any]) -> Dict[str, Any]:
"""
Populate checkpoint information from Civitai API response
Args:
checkpoint: The checkpoint entry to populate
civitai_info: The response from Civitai API
Returns:
The populated checkpoint dict
"""
try:
if civitai_info and civitai_info.get("error") != "Model not found":
# Update model name if available
if 'model' in civitai_info and 'name' in civitai_info['model']:
checkpoint['name'] = civitai_info['model']['name']
# Update version if available
if 'name' in civitai_info:
checkpoint['version'] = civitai_info.get('name', '')
# Get thumbnail URL from first image
if 'images' in civitai_info and civitai_info['images']:
checkpoint['thumbnailUrl'] = civitai_info['images'][0].get('url', '')
# Get base model
checkpoint['baseModel'] = civitai_info.get('baseModel', '')
# Get download URL
checkpoint['downloadUrl'] = civitai_info.get('downloadUrl', '')
else:
# Model not found or deleted
checkpoint['isDeleted'] = True
except Exception as e:
logger.error(f"Error populating checkpoint from Civitai info: {e}")
return checkpoint
class RecipeFormatParser(RecipeMetadataParser): class RecipeFormatParser(RecipeMetadataParser):
@@ -109,34 +244,15 @@ class RecipeFormatParser(RecipeMetadataParser):
# Try to get additional info from Civitai if we have a model version ID # Try to get additional info from Civitai if we have a model version ID
if lora.get('modelVersionId') and civitai_client: if lora.get('modelVersionId') and civitai_client:
try: try:
civitai_info = await civitai_client.get_model_version_info(lora['modelVersionId']) civitai_info_tuple = await civitai_client.get_model_version_info(lora['modelVersionId'])
if civitai_info and civitai_info.get("error") != "Model not found": # Populate lora entry with Civitai info
# Check if this is an early access lora lora_entry = await self.populate_lora_from_civitai(
if civitai_info.get('earlyAccessEndsAt'): lora_entry,
# Convert earlyAccessEndsAt to a human-readable date civitai_info_tuple,
early_access_date = civitai_info.get('earlyAccessEndsAt', '') recipe_scanner,
lora_entry['isEarlyAccess'] = True None, # No need to track base model counts
lora_entry['earlyAccessEndsAt'] = early_access_date lora['hash']
)
# Get thumbnail URL from first image
if 'images' in civitai_info and civitai_info['images']:
lora_entry['thumbnailUrl'] = civitai_info['images'][0].get('url', '')
# Get base model
lora_entry['baseModel'] = civitai_info.get('baseModel', '')
# Get download URL
lora_entry['downloadUrl'] = civitai_info.get('downloadUrl', '')
# Get size from files if available
if 'files' in civitai_info:
model_file = next((file for file in civitai_info.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
lora_entry['size'] = model_file.get('sizeKB', 0) * 1024
else:
lora_entry['isDeleted'] = True
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA: {e}") logger.error(f"Error fetching Civitai info for LoRA: {e}")
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png' lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
@@ -222,57 +338,17 @@ class StandardMetadataParser(RecipeMetadataParser):
# Get additional info from Civitai if client is available # Get additional info from Civitai if client is available
if civitai_client: if civitai_client:
civitai_info = await civitai_client.get_model_version_info(model_version_id) try:
civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id)
# Check if this LoRA exists locally by SHA256 hash # Populate lora entry with Civitai info
if civitai_info and civitai_info.get("error") != "Model not found": lora_entry = await self.populate_lora_from_civitai(
# Check if this is an early access lora lora_entry,
if civitai_info.get('earlyAccessEndsAt'): civitai_info_tuple,
# Convert earlyAccessEndsAt to a human-readable date recipe_scanner,
early_access_date = civitai_info.get('earlyAccessEndsAt', '') base_model_counts
lora_entry['isEarlyAccess'] = True )
lora_entry['earlyAccessEndsAt'] = early_access_date except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA: {e}")
# LoRA exists on Civitai, process its information
if 'files' in civitai_info:
# Find the model file (type="Model") in the files list
model_file = next((file for file in civitai_info.get('files', [])
if file.get('type') == 'Model'), None)
if model_file and recipe_scanner:
sha256 = model_file.get('hashes', {}).get('SHA256', '')
if sha256:
lora_scanner = recipe_scanner._lora_scanner
exists_locally = lora_scanner.has_lora_hash(sha256)
if exists_locally:
local_path = lora_scanner.get_lora_path_by_hash(sha256)
lora_entry['existsLocally'] = True
lora_entry['localPath'] = local_path
lora_entry['file_name'] = os.path.splitext(os.path.basename(local_path))[0]
else:
# For missing LoRAs, get file_name from model_file.name
file_name = model_file.get('name', '')
lora_entry['file_name'] = os.path.splitext(file_name)[0] if file_name else ''
lora_entry['hash'] = sha256
lora_entry['size'] = model_file.get('sizeKB', 0) * 1024
# Get thumbnail URL from first image
if 'images' in civitai_info and civitai_info['images']:
lora_entry['thumbnailUrl'] = civitai_info['images'][0].get('url', '')
# Get base model and update counts
current_base_model = civitai_info.get('baseModel', '')
lora_entry['baseModel'] = current_base_model
if current_base_model:
base_model_counts[current_base_model] = base_model_counts.get(current_base_model, 0) + 1
# Get download URL
lora_entry['downloadUrl'] = civitai_info.get('downloadUrl', '')
else:
# LoRA is deleted from Civitai or not found
lora_entry['isDeleted'] = True
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
loras.append(lora_entry) loras.append(lora_entry)
@@ -394,7 +470,9 @@ class A1111MetadataParser(RecipeMetadataParser):
# Extract LoRA information from prompt # Extract LoRA information from prompt
lora_weights = {} lora_weights = {}
lora_matches = re.findall(self.LORA_PATTERN, prompt) lora_matches = re.findall(self.LORA_PATTERN, prompt)
for lora_name, weight in lora_matches: for lora_name, weights in lora_matches:
# Take only the first strength value (before the colon)
weight = weights.split(':')[0]
lora_weights[lora_name.strip()] = float(weight.strip()) lora_weights[lora_name.strip()] = float(weight.strip())
# Remove LoRA patterns from prompt # Remove LoRA patterns from prompt
@@ -436,61 +514,18 @@ class A1111MetadataParser(RecipeMetadataParser):
'isDeleted': False 'isDeleted': False
} }
# Get info from Civitai by hash # Get info from Civitai by hash if available
if civitai_client: if civitai_client and hash_value:
try: try:
civitai_info = await civitai_client.get_model_by_hash(hash_value) civitai_info = await civitai_client.get_model_by_hash(hash_value)
if civitai_info and civitai_info.get("error") != "Model not found": # Populate lora entry with Civitai info
# Check if this is an early access lora lora_entry = await self.populate_lora_from_civitai(
if civitai_info.get('earlyAccessEndsAt'): lora_entry,
# Convert earlyAccessEndsAt to a human-readable date civitai_info,
early_access_date = civitai_info.get('earlyAccessEndsAt', '') recipe_scanner,
lora_entry['isEarlyAccess'] = True base_model_counts,
lora_entry['earlyAccessEndsAt'] = early_access_date hash_value
)
# Get model version ID
lora_entry['id'] = civitai_info.get('id', '')
# Get model name and version
lora_entry['name'] = civitai_info.get('model', {}).get('name', lora_name)
lora_entry['version'] = civitai_info.get('name', '')
# Get thumbnail URL
if 'images' in civitai_info and civitai_info['images']:
lora_entry['thumbnailUrl'] = civitai_info['images'][0].get('url', '')
# Get base model and update counts
current_base_model = civitai_info.get('baseModel', '')
lora_entry['baseModel'] = current_base_model
if current_base_model:
base_model_counts[current_base_model] = base_model_counts.get(current_base_model, 0) + 1
# Get download URL
lora_entry['downloadUrl'] = civitai_info.get('downloadUrl', '')
# Get file name and size from Civitai
if 'files' in civitai_info:
model_file = next((file for file in civitai_info.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
file_name = model_file.get('name', '')
lora_entry['file_name'] = os.path.splitext(file_name)[0] if file_name else lora_name
lora_entry['size'] = model_file.get('sizeKB', 0) * 1024
# Update hash to sha256
lora_entry['hash'] = model_file.get('hashes', {}).get('SHA256', hash_value).lower()
# Check if exists locally with sha256 hash
if recipe_scanner and lora_entry['hash']:
lora_scanner = recipe_scanner._lora_scanner
exists_locally = lora_scanner.has_lora_hash(lora_entry['hash'])
if exists_locally:
lora_cache = await lora_scanner.get_cached_data()
lora_item = next((item for item in lora_cache.raw_data if item['sha256'] == lora_entry['hash']), None)
if lora_item:
lora_entry['existsLocally'] = True
lora_entry['localPath'] = lora_item['file_path']
lora_entry['thumbnailUrl'] = config.get_preview_static_url(lora_item['preview_url'])
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {hash_value}: {e}") logger.error(f"Error fetching Civitai info for LoRA hash {hash_value}: {e}")
@@ -523,6 +558,500 @@ class A1111MetadataParser(RecipeMetadataParser):
return {"error": str(e), "loras": []} return {"error": str(e), "loras": []}
class ComfyMetadataParser(RecipeMetadataParser):
"""Parser for Civitai ComfyUI metadata JSON format"""
METADATA_MARKER = r"class_type"
def is_metadata_matching(self, user_comment: str) -> bool:
"""Check if the user comment matches the ComfyUI metadata format"""
try:
data = json.loads(user_comment)
# Check if it contains class_type nodes typical of ComfyUI workflow
return isinstance(data, dict) and any(isinstance(v, dict) and 'class_type' in v for v in data.values())
except (json.JSONDecodeError, TypeError):
return False
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
"""Parse metadata from Civitai ComfyUI metadata format"""
try:
data = json.loads(user_comment)
loras = []
# Find all LoraLoader nodes
lora_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and v.get('class_type') == 'LoraLoader'}
if not lora_nodes:
return {"error": "No LoRA information found in this ComfyUI workflow", "loras": []}
# Process each LoraLoader node
for node_id, node in lora_nodes.items():
if 'inputs' not in node or 'lora_name' not in node['inputs']:
continue
lora_name = node['inputs'].get('lora_name', '')
# Parse the URN to extract model ID and version ID
# Format: "urn:air:sdxl:lora:civitai:1107767@1253442"
lora_id_match = re.search(r'civitai:(\d+)@(\d+)', lora_name)
if not lora_id_match:
continue
model_id = lora_id_match.group(1)
model_version_id = lora_id_match.group(2)
# Get strength from node inputs
weight = node['inputs'].get('strength_model', 1.0)
# Initialize lora entry with default values
lora_entry = {
'id': model_version_id,
'modelId': model_id,
'name': f"Lora {model_id}", # Default name
'version': '',
'type': 'lora',
'weight': weight,
'existsLocally': False,
'localPath': None,
'file_name': '',
'hash': '',
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
}
# Get additional info from Civitai if client is available
if civitai_client:
try:
civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id)
# Populate lora entry with Civitai info
lora_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info_tuple,
recipe_scanner
)
except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA: {e}")
loras.append(lora_entry)
# Find checkpoint info
checkpoint_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and v.get('class_type') == 'CheckpointLoaderSimple'}
checkpoint = None
checkpoint_id = None
checkpoint_version_id = None
if checkpoint_nodes:
# Get the first checkpoint node
checkpoint_node = next(iter(checkpoint_nodes.values()))
if 'inputs' in checkpoint_node and 'ckpt_name' in checkpoint_node['inputs']:
checkpoint_name = checkpoint_node['inputs']['ckpt_name']
# Parse checkpoint URN
checkpoint_match = re.search(r'civitai:(\d+)@(\d+)', checkpoint_name)
if checkpoint_match:
checkpoint_id = checkpoint_match.group(1)
checkpoint_version_id = checkpoint_match.group(2)
checkpoint = {
'id': checkpoint_version_id,
'modelId': checkpoint_id,
'name': f"Checkpoint {checkpoint_id}",
'version': '',
'type': 'checkpoint'
}
# Get additional checkpoint info from Civitai
if civitai_client:
try:
civitai_info_tuple = await civitai_client.get_model_version_info(checkpoint_version_id)
civitai_info, _ = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
# Populate checkpoint with Civitai info
checkpoint = await self.populate_checkpoint_from_civitai(checkpoint, civitai_info)
except Exception as e:
logger.error(f"Error fetching Civitai info for checkpoint: {e}")
# Extract generation parameters
gen_params = {}
# First try to get from extraMetadata
if 'extraMetadata' in data:
try:
# extraMetadata is a JSON string that needs to be parsed
extra_metadata = json.loads(data['extraMetadata'])
# Map fields from extraMetadata to our standard format
mapping = {
'prompt': 'prompt',
'negativePrompt': 'negative_prompt',
'steps': 'steps',
'sampler': 'sampler',
'cfgScale': 'cfg_scale',
'seed': 'seed'
}
for src_key, dest_key in mapping.items():
if src_key in extra_metadata:
gen_params[dest_key] = extra_metadata[src_key]
# If size info is available, format as "width x height"
if 'width' in extra_metadata and 'height' in extra_metadata:
gen_params['size'] = f"{extra_metadata['width']}x{extra_metadata['height']}"
except Exception as e:
logger.error(f"Error parsing extraMetadata: {e}")
# If extraMetadata doesn't have all the info, try to get from nodes
if not gen_params or len(gen_params) < 3: # At least we want prompt, negative_prompt, and steps
# Find positive prompt node
positive_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and
v.get('class_type', '').endswith('CLIPTextEncode') and
v.get('_meta', {}).get('title') == 'Positive'}
if positive_nodes:
positive_node = next(iter(positive_nodes.values()))
if 'inputs' in positive_node and 'text' in positive_node['inputs']:
gen_params['prompt'] = positive_node['inputs']['text']
# Find negative prompt node
negative_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and
v.get('class_type', '').endswith('CLIPTextEncode') and
v.get('_meta', {}).get('title') == 'Negative'}
if negative_nodes:
negative_node = next(iter(negative_nodes.values()))
if 'inputs' in negative_node and 'text' in negative_node['inputs']:
gen_params['negative_prompt'] = negative_node['inputs']['text']
# Find KSampler node for other parameters
ksampler_nodes = {k: v for k, v in data.items() if isinstance(v, dict) and v.get('class_type') == 'KSampler'}
if ksampler_nodes:
ksampler_node = next(iter(ksampler_nodes.values()))
if 'inputs' in ksampler_node:
inputs = ksampler_node['inputs']
if 'sampler_name' in inputs:
gen_params['sampler'] = inputs['sampler_name']
if 'steps' in inputs:
gen_params['steps'] = inputs['steps']
if 'cfg' in inputs:
gen_params['cfg_scale'] = inputs['cfg']
if 'seed' in inputs:
gen_params['seed'] = inputs['seed']
# Determine base model from loras info
base_model = None
if loras:
# Use the most common base model from loras
base_models = [lora['baseModel'] for lora in loras if lora.get('baseModel')]
if base_models:
from collections import Counter
base_model_counts = Counter(base_models)
base_model = base_model_counts.most_common(1)[0][0]
return {
'base_model': base_model,
'loras': loras,
'checkpoint': checkpoint,
'gen_params': gen_params,
'from_comfy_metadata': True
}
except Exception as e:
logger.error(f"Error parsing ComfyUI metadata: {e}", exc_info=True)
return {"error": str(e), "loras": []}
class MetaFormatParser(RecipeMetadataParser):
"""Parser for images with meta format metadata (Lora_N Model hash format)"""
METADATA_MARKER = r'Lora_\d+ Model hash:'
def is_metadata_matching(self, user_comment: str) -> bool:
"""Check if the user comment matches the metadata format"""
return re.search(self.METADATA_MARKER, user_comment, re.IGNORECASE | re.DOTALL) is not None
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
"""Parse metadata from images with meta format metadata"""
try:
# Extract prompt and negative prompt
parts = user_comment.split('Negative prompt:', 1)
prompt = parts[0].strip()
# Initialize metadata
metadata = {"prompt": prompt, "loras": []}
# Extract negative prompt and parameters if available
if len(parts) > 1:
negative_and_params = parts[1]
# Extract negative prompt - everything until the first parameter (usually "Steps:")
param_start = re.search(r'([A-Za-z]+): ', negative_and_params)
if param_start:
neg_prompt = negative_and_params[:param_start.start()].strip()
metadata["negative_prompt"] = neg_prompt
params_section = negative_and_params[param_start.start():]
else:
params_section = negative_and_params
# Extract key-value parameters (Steps, Sampler, Seed, etc.)
param_pattern = r'([A-Za-z_0-9 ]+): ([^,]+)'
params = re.findall(param_pattern, params_section)
for key, value in params:
clean_key = key.strip().lower().replace(' ', '_')
metadata[clean_key] = value.strip()
# Extract LoRA information
# Pattern to match lora entries: Lora_0 Model name: ArtVador I.safetensors, Lora_0 Model hash: 08f7133a58, etc.
lora_pattern = r'Lora_(\d+) Model name: ([^,]+), Lora_\1 Model hash: ([^,]+), Lora_\1 Strength model: ([^,]+), Lora_\1 Strength clip: ([^,]+)'
lora_matches = re.findall(lora_pattern, user_comment)
# If the regular pattern doesn't match, try a more flexible approach
if not lora_matches:
# First find all Lora indices
lora_indices = set(re.findall(r'Lora_(\d+)', user_comment))
# For each index, extract the information
for idx in lora_indices:
lora_info = {}
# Extract model name
name_match = re.search(f'Lora_{idx} Model name: ([^,]+)', user_comment)
if name_match:
lora_info['name'] = name_match.group(1).strip()
# Extract model hash
hash_match = re.search(f'Lora_{idx} Model hash: ([^,]+)', user_comment)
if hash_match:
lora_info['hash'] = hash_match.group(1).strip()
# Extract strength model
strength_model_match = re.search(f'Lora_{idx} Strength model: ([^,]+)', user_comment)
if strength_model_match:
lora_info['strength_model'] = float(strength_model_match.group(1).strip())
# Extract strength clip
strength_clip_match = re.search(f'Lora_{idx} Strength clip: ([^,]+)', user_comment)
if strength_clip_match:
lora_info['strength_clip'] = float(strength_clip_match.group(1).strip())
# Only add if we have at least name and hash
if 'name' in lora_info and 'hash' in lora_info:
lora_matches.append((idx, lora_info['name'], lora_info['hash'],
str(lora_info.get('strength_model', 1.0)),
str(lora_info.get('strength_clip', 1.0))))
# Process LoRAs
base_model_counts = {}
loras = []
for match in lora_matches:
if len(match) == 5: # Regular pattern match
idx, name, hash_value, strength_model, strength_clip = match
else: # Flexible approach match
continue # Should not happen now
# Clean up the values
name = name.strip()
if name.endswith('.safetensors'):
name = name[:-12] # Remove .safetensors extension
hash_value = hash_value.strip()
weight = float(strength_model) # Use model strength as weight
# Initialize lora entry with default values
lora_entry = {
'name': name,
'type': 'lora',
'weight': weight,
'existsLocally': False,
'localPath': None,
'file_name': name,
'hash': hash_value,
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
}
# Get info from Civitai by hash if available
if civitai_client and hash_value:
try:
civitai_info = await civitai_client.get_model_by_hash(hash_value)
# Populate lora entry with Civitai info
lora_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts,
hash_value
)
except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {hash_value}: {e}")
loras.append(lora_entry)
# Extract model information
model = None
if 'model' in metadata:
model = metadata['model']
# Set base_model to the most common one from civitai_info
base_model = None
if base_model_counts:
base_model = max(base_model_counts.items(), key=lambda x: x[1])[0]
# Extract generation parameters for recipe metadata
gen_params = {}
for key in GEN_PARAM_KEYS:
if key in metadata:
gen_params[key] = metadata.get(key, '')
# Try to extract size information if available
if 'width' in metadata and 'height' in metadata:
gen_params['size'] = f"{metadata['width']}x{metadata['height']}"
return {
'base_model': base_model,
'loras': loras,
'gen_params': gen_params,
'raw_metadata': metadata,
'from_meta_format': True
}
except Exception as e:
logger.error(f"Error parsing meta format metadata: {e}", exc_info=True)
return {"error": str(e), "loras": []}
class ImageSaverMetadataParser(RecipeMetadataParser):
"""Parser for ComfyUI Image Saver plugin metadata format"""
METADATA_MARKER = r'Hashes: \{"LORA:'
LORA_PATTERN = r'<lora:([^:]+):([^>]+)>'
HASH_PATTERN = r'Hashes: (\{.*?\})'
def is_metadata_matching(self, user_comment: str) -> bool:
"""Check if the user comment matches the Image Saver metadata format"""
return re.search(self.METADATA_MARKER, user_comment, re.IGNORECASE | re.DOTALL) is not None
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
"""Parse metadata from Image Saver plugin format"""
try:
# Extract prompt and negative prompt
parts = user_comment.split('Negative prompt:', 1)
prompt = parts[0].strip()
# Initialize metadata
metadata = {"prompt": prompt, "loras": []}
# Extract negative prompt and parameters
if len(parts) > 1:
negative_and_params = parts[1]
# Extract negative prompt
if "Steps:" in negative_and_params:
neg_prompt = negative_and_params.split("Steps:", 1)[0].strip()
metadata["negative_prompt"] = neg_prompt
# Extract key-value parameters (Steps, Sampler, CFG scale, etc.)
param_pattern = r'([A-Za-z ]+): ([^,]+)'
params = re.findall(param_pattern, negative_and_params)
for key, value in params:
clean_key = key.strip().lower().replace(' ', '_')
metadata[clean_key] = value.strip()
# Extract LoRA information from prompt
lora_weights = {}
lora_matches = re.findall(self.LORA_PATTERN, prompt)
for lora_name, weight in lora_matches:
lora_weights[lora_name.strip()] = float(weight.split(':')[0].strip())
# Remove LoRA patterns from prompt
metadata["prompt"] = re.sub(self.LORA_PATTERN, '', prompt).strip()
# Extract LoRA hashes from Hashes section
lora_hashes = {}
hash_match = re.search(self.HASH_PATTERN, user_comment)
if hash_match:
try:
hashes = json.loads(hash_match.group(1))
for key, hash_value in hashes.items():
if key.startswith('LORA:'):
lora_name = key[5:] # Remove 'LORA:' prefix
lora_hashes[lora_name] = hash_value.strip()
except json.JSONDecodeError:
pass
# Process LoRAs and collect base models
base_model_counts = {}
loras = []
# Process each LoRA with hash and weight
for lora_name, hash_value in lora_hashes.items():
weight = lora_weights.get(lora_name, 1.0)
# Initialize lora entry with default values
lora_entry = {
'name': lora_name,
'type': 'lora',
'weight': weight,
'existsLocally': False,
'localPath': None,
'file_name': lora_name,
'hash': hash_value,
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
}
# Get info from Civitai by hash if available
if civitai_client and hash_value:
try:
civitai_info = await civitai_client.get_model_by_hash(hash_value)
# Populate lora entry with Civitai info
lora_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts,
hash_value
)
except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {hash_value}: {e}")
loras.append(lora_entry)
# Set base_model to the most common one from civitai_info
base_model = None
if base_model_counts:
base_model = max(base_model_counts.items(), key=lambda x: x[1])[0]
# Extract generation parameters for recipe metadata
gen_params = {}
for key in GEN_PARAM_KEYS:
if key in metadata:
gen_params[key] = metadata.get(key, '')
# Add model information if available
if 'model' in metadata:
gen_params['checkpoint'] = metadata['model']
return {
'base_model': base_model,
'loras': loras,
'gen_params': gen_params,
'raw_metadata': metadata
}
except Exception as e:
logger.error(f"Error parsing Image Saver metadata: {e}", exc_info=True)
return {"error": str(e), "loras": []}
class RecipeParserFactory: class RecipeParserFactory:
"""Factory for creating recipe metadata parsers""" """Factory for creating recipe metadata parsers"""
@@ -537,11 +1066,23 @@ class RecipeParserFactory:
Returns: Returns:
Appropriate RecipeMetadataParser implementation Appropriate RecipeMetadataParser implementation
""" """
# Try ComfyMetadataParser first since it requires valid JSON
try:
if ComfyMetadataParser().is_metadata_matching(user_comment):
return ComfyMetadataParser()
except Exception:
# If JSON parsing fails, move on to other parsers
pass
if RecipeFormatParser().is_metadata_matching(user_comment): if RecipeFormatParser().is_metadata_matching(user_comment):
return RecipeFormatParser() return RecipeFormatParser()
elif StandardMetadataParser().is_metadata_matching(user_comment): elif StandardMetadataParser().is_metadata_matching(user_comment):
return StandardMetadataParser() return StandardMetadataParser()
elif A1111MetadataParser().is_metadata_matching(user_comment): elif A1111MetadataParser().is_metadata_matching(user_comment):
return A1111MetadataParser() return A1111MetadataParser()
elif MetaFormatParser().is_metadata_matching(user_comment):
return MetaFormatParser()
elif ImageSaverMetadataParser().is_metadata_matching(user_comment):
return ImageSaverMetadataParser()
else: else:
return None return None

503
py/utils/routes_common.py Normal file
View File

@@ -0,0 +1,503 @@
import os
import json
import logging
from typing import Dict, List, Callable, Awaitable
from aiohttp import web
from .model_utils import determine_base_model
from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from ..config import config
from ..services.civitai_client import CivitaiClient
from ..utils.exif_utils import ExifUtils
from ..services.download_manager import DownloadManager
logger = logging.getLogger(__name__)
class ModelRouteUtils:
"""Shared utilities for model routes (LoRAs, Checkpoints, etc.)"""
@staticmethod
async def load_local_metadata(metadata_path: str) -> Dict:
"""Load local metadata file"""
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading metadata from {metadata_path}: {e}")
return {}
@staticmethod
async def handle_not_found_on_civitai(metadata_path: str, local_metadata: Dict) -> None:
"""Handle case when model is not found on CivitAI"""
local_metadata['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
@staticmethod
async def update_model_metadata(metadata_path: str, local_metadata: Dict,
civitai_metadata: Dict, client: CivitaiClient) -> None:
"""Update local metadata with CivitAI data"""
local_metadata['civitai'] = civitai_metadata
# Update model name if available
if 'model' in civitai_metadata:
if civitai_metadata.get('model', {}).get('name'):
local_metadata['model_name'] = civitai_metadata['model']['name']
# Fetch additional model metadata (description and tags) if we have model ID
model_id = civitai_metadata['modelId']
if model_id:
model_metadata, _ = await client.get_model_metadata(str(model_id))
if model_metadata:
local_metadata['modelDescription'] = model_metadata.get('description', '')
local_metadata['tags'] = model_metadata.get('tags', [])
# Update base model
local_metadata['base_model'] = determine_base_model(civitai_metadata.get('baseModel'))
# Update preview if needed
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
if first_preview:
# Determine if content is video or image
is_video = first_preview['type'] == 'video'
if is_video:
# For videos use .mp4 extension
preview_ext = '.mp4'
else:
# For images use .webp extension
preview_ext = '.webp'
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_filename = base_name + preview_ext
preview_path = os.path.join(os.path.dirname(metadata_path), preview_filename)
if is_video:
# Download video as is
if await client.download_preview_image(first_preview['url'], preview_path):
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
else:
# For images, download and then optimize to WebP
temp_path = preview_path + ".temp"
if await client.download_preview_image(first_preview['url'], temp_path):
try:
# Read the downloaded image
with open(temp_path, 'rb') as f:
image_data = f.read()
# Optimize and convert to WebP
optimized_data, _ = ExifUtils.optimize_image(
image_data=image_data,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=False
)
# Save the optimized WebP image
with open(preview_path, 'wb') as f:
f.write(optimized_data)
# Update metadata
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
# Remove the temporary file
if os.path.exists(temp_path):
os.remove(temp_path)
except Exception as e:
logger.error(f"Error optimizing preview image: {e}")
# If optimization fails, try to use the downloaded image directly
if os.path.exists(temp_path):
os.rename(temp_path, preview_path)
local_metadata['preview_url'] = preview_path.replace(os.sep, '/')
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
@staticmethod
async def fetch_and_update_model(
sha256: str,
file_path: str,
model_data: dict,
update_cache_func: Callable[[str, str, Dict], Awaitable[bool]]
) -> bool:
"""Fetch and update metadata for a single model
Args:
sha256: SHA256 hash of the model file
file_path: Path to the model file
model_data: The model object in cache to update
update_cache_func: Function to update the cache with new metadata
Returns:
bool: True if successful, False otherwise
"""
client = CivitaiClient()
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Check if model metadata exists
local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Fetch metadata from Civitai
civitai_metadata = await client.get_model_by_hash(sha256)
if not civitai_metadata:
# Mark as not from CivitAI if not found
local_metadata['from_civitai'] = False
model_data['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
return False
# Update metadata
await ModelRouteUtils.update_model_metadata(
metadata_path,
local_metadata,
civitai_metadata,
client
)
# Update cache object directly
model_data.update({
'model_name': local_metadata.get('model_name'),
'preview_url': local_metadata.get('preview_url'),
'from_civitai': True,
'civitai': civitai_metadata
})
# Update cache using the provided function
await update_cache_func(file_path, file_path, local_metadata)
return True
except Exception as e:
logger.error(f"Error fetching CivitAI data: {e}")
return False
finally:
await client.close()
@staticmethod
def filter_civitai_data(data: Dict) -> Dict:
"""Filter relevant fields from CivitAI data"""
if not data:
return {}
fields = [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images"
]
return {k: data[k] for k in fields if k in data}
@staticmethod
async def delete_model_files(target_dir: str, file_name: str, file_monitor=None) -> List[str]:
"""Delete model and associated files
Args:
target_dir: Directory containing the model files
file_name: Base name of the model file without extension
file_monitor: Optional file monitor to ignore delete events
Returns:
List of deleted file paths
"""
patterns = [
f"{file_name}.safetensors", # Required
f"{file_name}.metadata.json",
]
# Add all preview file extensions
for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{file_name}{ext}")
deleted = []
main_file = patterns[0]
main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
if os.path.exists(main_path):
# Notify file monitor to ignore delete event if available
if file_monitor:
file_monitor.handler.add_ignore_path(main_path, 0)
# Delete file
os.remove(main_path)
deleted.append(main_path)
else:
logger.warning(f"Model file not found: {main_file}")
# Delete optional files
for pattern in patterns[1:]:
path = os.path.join(target_dir, pattern)
if os.path.exists(path):
try:
os.remove(path)
deleted.append(pattern)
except Exception as e:
logger.warning(f"Failed to delete {pattern}: {e}")
return deleted
@staticmethod
def get_multipart_ext(filename):
"""Get extension that may have multiple parts like .metadata.json"""
parts = filename.split(".")
if len(parts) > 2: # If contains multi-part extension
return "." + ".".join(parts[-2:]) # Take the last two parts, like ".metadata.json"
return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors"
# New common endpoint handlers
@staticmethod
async def handle_delete_model(request: web.Request, scanner) -> web.Response:
"""Handle model deletion request
Args:
request: The aiohttp request
scanner: The model scanner instance with cache management methods
Returns:
web.Response: The HTTP response
"""
try:
data = await request.json()
file_path = data.get('file_path')
if not file_path:
return web.Response(text='Model path is required', status=400)
target_dir = os.path.dirname(file_path)
file_name = os.path.splitext(os.path.basename(file_path))[0]
# Get the file monitor from the scanner if available
file_monitor = getattr(scanner, 'file_monitor', None)
deleted_files = await ModelRouteUtils.delete_model_files(
target_dir,
file_name,
file_monitor
)
# Remove from cache
cache = await scanner.get_cached_data()
cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path]
await cache.resort()
# Update hash index if available
if hasattr(scanner, '_hash_index') and scanner._hash_index:
scanner._hash_index.remove_by_path(file_path)
return web.json_response({
'success': True,
'deleted_files': deleted_files
})
except Exception as e:
logger.error(f"Error deleting model: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
@staticmethod
async def handle_fetch_civitai(request: web.Request, scanner) -> web.Response:
"""Handle CivitAI metadata fetch request
Args:
request: The aiohttp request
scanner: The model scanner instance with cache management methods
Returns:
web.Response: The HTTP response
"""
try:
data = await request.json()
metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json'
# Check if model metadata exists
local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
if not local_metadata or not local_metadata.get('sha256'):
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
# Create a client for fetching from Civitai
client = CivitaiClient()
try:
# Fetch and update metadata
civitai_metadata = await client.get_model_by_hash(local_metadata["sha256"])
if not civitai_metadata:
await ModelRouteUtils.handle_not_found_on_civitai(metadata_path, local_metadata)
return web.json_response({"success": False, "error": "Not found on CivitAI"}, status=404)
await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, client)
# Update the cache
await scanner.update_single_model_cache(data['file_path'], data['file_path'], local_metadata)
return web.json_response({"success": True})
finally:
await client.close()
except Exception as e:
logger.error(f"Error fetching from CivitAI: {e}", exc_info=True)
return web.json_response({"success": False, "error": str(e)}, status=500)
@staticmethod
async def handle_replace_preview(request: web.Request, scanner) -> web.Response:
"""Handle preview image replacement request
Args:
request: The aiohttp request
scanner: The model scanner instance with methods to update cache
Returns:
web.Response: The HTTP response
"""
try:
reader = await request.multipart()
# Read preview file data
field = await reader.next()
if field.name != 'preview_file':
raise ValueError("Expected 'preview_file' field")
content_type = field.headers.get('Content-Type', 'image/png')
preview_data = await field.read()
# Read model path
field = await reader.next()
if field.name != 'model_path':
raise ValueError("Expected 'model_path' field")
model_path = (await field.read()).decode()
# Save preview file
base_name = os.path.splitext(os.path.basename(model_path))[0]
folder = os.path.dirname(model_path)
# Determine if content is video or image
if content_type.startswith('video/'):
# For videos, keep original format and use .mp4 extension
extension = '.mp4'
optimized_data = preview_data
else:
# For images, optimize and convert to WebP
optimized_data, _ = ExifUtils.optimize_image(
image_data=preview_data,
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=False
)
extension = '.webp' # Use .webp without .preview part
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/')
with open(preview_path, 'wb') as f:
f.write(optimized_data)
# Update preview path in metadata
metadata_path = os.path.splitext(model_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
# Update preview_url directly in the metadata dict
metadata['preview_url'] = preview_path
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
except Exception as e:
logger.error(f"Error updating metadata: {e}")
# Update preview URL in scanner cache
if hasattr(scanner, 'update_preview_in_cache'):
await scanner.update_preview_in_cache(model_path, preview_path)
return web.json_response({
"success": True,
"preview_url": config.get_preview_static_url(preview_path)
})
except Exception as e:
logger.error(f"Error replacing preview: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
@staticmethod
async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response:
"""Handle model download request
Args:
request: The aiohttp request
download_manager: Instance of DownloadManager
model_type: Type of model ('lora' or 'checkpoint')
Returns:
web.Response: The HTTP response
"""
try:
data = await request.json()
# Create progress callback
async def progress_callback(progress):
from ..services.websocket_manager import ws_manager
await ws_manager.broadcast({
'status': 'progress',
'progress': progress
})
# Check which identifier is provided
download_url = data.get('download_url')
model_hash = data.get('model_hash')
model_version_id = data.get('model_version_id')
# Validate that at least one identifier is provided
if not any([download_url, model_hash, model_version_id]):
return web.Response(
status=400,
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
)
# Use the correct root directory based on model type
root_key = 'checkpoint_root' if model_type == 'checkpoint' else 'lora_root'
save_dir = data.get(root_key)
result = await download_manager.download_from_civitai(
download_url=download_url,
model_hash=model_hash,
model_version_id=model_version_id,
save_dir=save_dir,
relative_path=data.get('relative_path', ''),
progress_callback=progress_callback,
model_type=model_type
)
if not result.get('success', False):
error_message = result.get('error', 'Unknown error')
# Return 401 for early access errors
if 'early access' in error_message.lower():
logger.warning(f"Early access download failed: {error_message}")
return web.Response(
status=401, # Use 401 status code to match Civitai's response
text=f"Early Access Restriction: {error_message}"
)
return web.Response(status=500, text=error_message)
return web.json_response(result)
except Exception as e:
error_message = str(e)
# Check if this might be an early access error
if '401' in error_message:
logger.warning(f"Early access error (401): {error_message}")
return web.Response(
status=401,
text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
)
logger.error(f"Error downloading {model_type}: {error_message}")
return web.Response(status=500, text=error_message)

267
py/utils/usage_stats.py Normal file
View File

@@ -0,0 +1,267 @@
import os
import json
import time
import asyncio
import logging
from typing import Dict, Set
from ..config import config
from ..services.service_registry import ServiceRegistry
from ..metadata_collector.metadata_registry import MetadataRegistry
from ..metadata_collector.constants import MODELS, LORAS
logger = logging.getLogger(__name__)
class UsageStats:
"""Track usage statistics for models and save to JSON"""
_instance = None
_lock = asyncio.Lock() # For thread safety
# Default stats file name
STATS_FILENAME = "lora_manager_stats.json"
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
# Initialize stats storage
self.stats = {
"checkpoints": {}, # sha256 -> count
"loras": {}, # sha256 -> count
"total_executions": 0,
"last_save_time": 0
}
# Queue for prompt_ids to process
self.pending_prompt_ids = set()
# Load existing stats if available
self._stats_file_path = self._get_stats_file_path()
self._load_stats()
# Save interval in seconds
self.save_interval = 90 # 1.5 minutes
# Start background task to process queued prompt_ids
self._bg_task = asyncio.create_task(self._background_processor())
self._initialized = True
logger.info("Usage statistics tracker initialized")
def _get_stats_file_path(self) -> str:
"""Get the path to the stats JSON file"""
if not config.loras_roots or len(config.loras_roots) == 0:
# Fallback to temporary directory if no lora roots
return os.path.join(config.temp_directory, self.STATS_FILENAME)
# Use the first lora root
return os.path.join(config.loras_roots[0], self.STATS_FILENAME)
def _load_stats(self):
"""Load existing statistics from file"""
try:
if os.path.exists(self._stats_file_path):
with open(self._stats_file_path, 'r', encoding='utf-8') as f:
loaded_stats = json.load(f)
# Update our stats with loaded data
if isinstance(loaded_stats, dict):
# Update individual sections to maintain structure
if "checkpoints" in loaded_stats and isinstance(loaded_stats["checkpoints"], dict):
self.stats["checkpoints"] = loaded_stats["checkpoints"]
if "loras" in loaded_stats and isinstance(loaded_stats["loras"], dict):
self.stats["loras"] = loaded_stats["loras"]
if "total_executions" in loaded_stats:
self.stats["total_executions"] = loaded_stats["total_executions"]
logger.info(f"Loaded usage statistics from {self._stats_file_path}")
except Exception as e:
logger.error(f"Error loading usage statistics: {e}")
async def save_stats(self, force=False):
"""Save statistics to file"""
try:
# Only save if it's been at least save_interval since last save or force is True
current_time = time.time()
if not force and (current_time - self.stats.get("last_save_time", 0)) < self.save_interval:
return False
# Use a lock to prevent concurrent writes
async with self._lock:
# Update last save time
self.stats["last_save_time"] = current_time
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(self._stats_file_path), exist_ok=True)
# Write to a temporary file first, then move it to avoid corruption
temp_path = f"{self._stats_file_path}.tmp"
with open(temp_path, 'w', encoding='utf-8') as f:
json.dump(self.stats, f, indent=2, ensure_ascii=False)
# Replace the old file with the new one
os.replace(temp_path, self._stats_file_path)
logger.debug(f"Saved usage statistics to {self._stats_file_path}")
return True
except Exception as e:
logger.error(f"Error saving usage statistics: {e}", exc_info=True)
return False
def register_execution(self, prompt_id):
"""Register a completed execution by prompt_id for later processing"""
if prompt_id:
self.pending_prompt_ids.add(prompt_id)
async def _background_processor(self):
"""Background task to process queued prompt_ids"""
try:
while True:
# Wait a short interval before checking for new prompt_ids
await asyncio.sleep(5) # Check every 5 seconds
# Process any pending prompt_ids
if self.pending_prompt_ids:
async with self._lock:
# Get a copy of the set and clear original
prompt_ids = self.pending_prompt_ids.copy()
self.pending_prompt_ids.clear()
# Process each prompt_id
registry = MetadataRegistry()
for prompt_id in prompt_ids:
try:
metadata = registry.get_metadata(prompt_id)
await self._process_metadata(metadata)
except Exception as e:
logger.error(f"Error processing prompt_id {prompt_id}: {e}")
# Periodically save stats
await self.save_stats()
except asyncio.CancelledError:
# Task was cancelled, clean up
await self.save_stats(force=True)
except Exception as e:
logger.error(f"Error in background processing task: {e}", exc_info=True)
# Restart the task after a delay if it fails
asyncio.create_task(self._restart_background_task())
async def _restart_background_task(self):
"""Restart the background task after a delay"""
await asyncio.sleep(30) # Wait 30 seconds before restarting
self._bg_task = asyncio.create_task(self._background_processor())
async def _process_metadata(self, metadata):
"""Process metadata from an execution"""
if not metadata or not isinstance(metadata, dict):
return
# Increment total executions count
self.stats["total_executions"] += 1
# Process checkpoints
if MODELS in metadata and isinstance(metadata[MODELS], dict):
await self._process_checkpoints(metadata[MODELS])
# Process loras
if LORAS in metadata and isinstance(metadata[LORAS], dict):
await self._process_loras(metadata[LORAS])
async def _process_checkpoints(self, models_data):
"""Process checkpoint models from metadata"""
try:
# Get checkpoint scanner service
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
if not checkpoint_scanner:
logger.warning("Checkpoint scanner not available for usage tracking")
return
for node_id, model_info in models_data.items():
if not isinstance(model_info, dict):
continue
# Check if this is a checkpoint model
model_type = model_info.get("type")
if model_type == "checkpoint":
model_name = model_info.get("name")
if not model_name:
continue
# Clean up filename (remove extension if present)
model_filename = os.path.splitext(os.path.basename(model_name))[0]
# Get hash for this checkpoint
model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
if model_hash:
# Update stats for this checkpoint
self.stats["checkpoints"][model_hash] = self.stats["checkpoints"].get(model_hash, 0) + 1
except Exception as e:
logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)
async def _process_loras(self, loras_data):
"""Process LoRA models from metadata"""
try:
# Get LoRA scanner service
lora_scanner = await ServiceRegistry.get_lora_scanner()
if not lora_scanner:
logger.warning("LoRA scanner not available for usage tracking")
return
for node_id, lora_info in loras_data.items():
if not isinstance(lora_info, dict):
continue
# Get the list of LoRAs from standardized format
lora_list = lora_info.get("lora_list", [])
for lora in lora_list:
if not isinstance(lora, dict):
continue
lora_name = lora.get("name")
if not lora_name:
continue
# Get hash for this LoRA
lora_hash = lora_scanner.get_hash_by_filename(lora_name)
if lora_hash:
# Update stats for this LoRA
self.stats["loras"][lora_hash] = self.stats["loras"].get(lora_hash, 0) + 1
except Exception as e:
logger.error(f"Error processing LoRA usage: {e}", exc_info=True)
async def get_stats(self):
"""Get current usage statistics"""
return self.stats
async def get_model_usage_count(self, model_type, sha256):
"""Get usage count for a specific model by hash"""
if model_type == "checkpoint":
return self.stats["checkpoints"].get(sha256, 0)
elif model_type == "lora":
return self.stats["loras"].get(sha256, 0)
return 0
async def process_execution(self, prompt_id):
"""Process a prompt execution immediately (synchronous approach)"""
if not prompt_id:
return
try:
# Process metadata for this prompt_id
registry = MetadataRegistry()
metadata = registry.get_metadata(prompt_id)
if metadata:
await self._process_metadata(metadata)
# Save stats if needed
await self.save_stats()
except Exception as e:
logger.error(f"Error processing prompt_id {prompt_id}: {e}", exc_info=True)

View File

@@ -40,7 +40,45 @@ def download_twitter_image(url):
except Exception as e: except Exception as e:
print(f"Error downloading twitter image: {e}") print(f"Error downloading twitter image: {e}")
return None return None
def download_civitai_image(url):
"""Download image from a URL containing avatar image with specific class and style attributes
Args:
url (str): The URL to download image from
Returns:
str: Path to downloaded temporary image file
"""
try:
# Download page content
response = requests.get(url)
response.raise_for_status()
# Parse HTML
soup = BeautifulSoup(response.text, 'html.parser')
# Find image with specific class and style attributes
image = soup.select_one('img.EdgeImage_image__iH4_q.max-h-full.w-auto.max-w-full')
if not image or 'src' not in image.attrs:
return None
image_url = image['src']
# Download image
image_response = requests.get(image_url)
image_response.raise_for_status()
# Save to temp file
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
temp_file.write(image_response.content)
return temp_file.name
except Exception as e:
print(f"Error downloading civitai avatar: {e}")
return None
def fuzzy_match(text: str, pattern: str, threshold: float = 0.7) -> bool: def fuzzy_match(text: str, pattern: str, threshold: float = 0.7) -> bool:
""" """
Check if text matches pattern using fuzzy matching. Check if text matches pattern using fuzzy matching.

View File

@@ -1,149 +0,0 @@
# ComfyUI Workflow Parser
本模块提供了一个灵活的解析系统可以从ComfyUI工作流中提取生成参数和LoRA信息。
## 设计理念
工作流解析器基于以下设计原则:
1. **模块化**: 每种节点类型由独立的mapper处理
2. **可扩展性**: 通过扩展系统轻松添加新的节点类型支持
3. **回溯**: 通过工作流图的模型输入路径跟踪LoRA节点
4. **灵活性**: 适应不同的ComfyUI工作流结构
## 主要组件
### 1. NodeMapper
`NodeMapper`是所有节点映射器的基类,定义了如何从工作流中提取节点信息:
```python
class NodeMapper:
def __init__(self, node_type: str, inputs_to_track: List[str]):
self.node_type = node_type
self.inputs_to_track = inputs_to_track
def process(self, node_id: str, node_data: Dict, workflow: Dict, parser) -> Any:
# 处理节点的通用逻辑
...
def transform(self, inputs: Dict) -> Any:
# 由子类覆盖以提供特定转换
return inputs
```
### 2. WorkflowParser
主要解析类,通过跟踪工作流图来提取参数:
```python
parser = WorkflowParser()
result = parser.parse_workflow("workflow.json")
```
### 3. 扩展系统
允许通过添加新的自定义mapper来扩展支持的节点类型:
```python
# 在py/workflow/ext/中添加自定义mapper模块
load_extensions() # 自动加载所有扩展
```
## 使用方法
### 基本用法
```python
from workflow.parser import parse_workflow
# 解析工作流并保存结果
result = parse_workflow("workflow.json", "output.json")
```
### 自定义解析
```python
from workflow.parser import WorkflowParser
from workflow.mappers import register_mapper, load_extensions
# 加载扩展
load_extensions()
# 创建解析器
parser = WorkflowParser(load_extensions_on_init=False) # 不自动加载扩展
# 解析工作流
result = parser.parse_workflow(workflow_data)
```
## 扩展系统
### 添加新的节点映射器
`py/workflow/ext/`目录中创建Python文件定义从`NodeMapper`继承的类:
```python
# example_mapper.py
from ..mappers import NodeMapper
class MyCustomNodeMapper(NodeMapper):
def __init__(self):
super().__init__(
node_type="MyCustomNode", # 节点的class_type
inputs_to_track=["param1", "param2"] # 要提取的参数
)
def transform(self, inputs: Dict) -> Any:
# 处理提取的参数
return {
"custom_param": inputs.get("param1", "default")
}
```
扩展系统会自动加载和注册这些映射器。
### LoraManager节点说明
LoraManager相关节点的处理方式:
1. **Lora Loader**: 处理`loras`数组,过滤出`active=true`的条目,和`lora_stack`输入
2. **Lora Stacker**: 处理`loras`数组和已有的`lora_stack`构建叠加的LoRA
3. **TriggerWord Toggle**: 从`toggle_trigger_words`中提取`active=true`的条目
## 输出格式
解析器生成的输出格式如下:
```json
{
"gen_params": {
"prompt": "...",
"negative_prompt": "",
"steps": "25",
"sampler": "dpmpp_2m",
"scheduler": "beta",
"cfg": "1",
"seed": "48",
"guidance": 3.5,
"size": "896x1152",
"clip_skip": "2"
},
"loras": "<lora:name1:0.9> <lora:name2:0.8>"
}
```
## 高级用法
### 直接注册映射器
```python
from workflow.mappers import register_mapper
from workflow.mappers import NodeMapper
# 创建自定义映射器
class CustomMapper(NodeMapper):
# ...实现映射器
# 注册映射器
register_mapper(CustomMapper())

View File

@@ -0,0 +1,285 @@
"""
ComfyUI Core nodes mappers extension for workflow parsing
"""
import logging
from typing import Dict, Any, List
logger = logging.getLogger(__name__)
# =============================================================================
# Transform Functions
# =============================================================================
def transform_random_noise(inputs: Dict) -> Dict:
"""Transform function for RandomNoise node"""
return {"seed": str(inputs.get("noise_seed", ""))}
def transform_ksampler_select(inputs: Dict) -> Dict:
"""Transform function for KSamplerSelect node"""
return {"sampler": inputs.get("sampler_name", "")}
def transform_basic_scheduler(inputs: Dict) -> Dict:
"""Transform function for BasicScheduler node"""
result = {
"scheduler": inputs.get("scheduler", ""),
"denoise": str(inputs.get("denoise", "1.0"))
}
# Get steps from inputs or steps input
if "steps" in inputs:
if isinstance(inputs["steps"], str):
result["steps"] = inputs["steps"]
elif isinstance(inputs["steps"], dict) and "value" in inputs["steps"]:
result["steps"] = str(inputs["steps"]["value"])
else:
result["steps"] = str(inputs["steps"])
return result
def transform_basic_guider(inputs: Dict) -> Dict:
"""Transform function for BasicGuider node"""
result = {}
# Process conditioning
if "conditioning" in inputs:
if isinstance(inputs["conditioning"], str):
result["prompt"] = inputs["conditioning"]
elif isinstance(inputs["conditioning"], dict):
result["conditioning"] = inputs["conditioning"]
# Get model information if needed
if "model" in inputs and isinstance(inputs["model"], dict):
result["model"] = inputs["model"]
return result
def transform_model_sampling_flux(inputs: Dict) -> Dict:
"""Transform function for ModelSamplingFlux - mostly a pass-through node"""
# This node is primarily used for routing, so we mostly pass through values
return inputs["model"]
def transform_sampler_custom_advanced(inputs: Dict) -> Dict:
"""Transform function for SamplerCustomAdvanced node"""
result = {}
# Extract seed from noise
if "noise" in inputs and isinstance(inputs["noise"], dict):
result["seed"] = str(inputs["noise"].get("seed", ""))
# Extract sampler info
if "sampler" in inputs and isinstance(inputs["sampler"], dict):
sampler = inputs["sampler"].get("sampler", "")
if sampler:
result["sampler"] = sampler
# Extract scheduler, steps, denoise from sigmas
if "sigmas" in inputs and isinstance(inputs["sigmas"], dict):
sigmas = inputs["sigmas"]
result["scheduler"] = sigmas.get("scheduler", "")
result["steps"] = str(sigmas.get("steps", ""))
result["denoise"] = str(sigmas.get("denoise", "1.0"))
# Extract prompt and guidance from guider
if "guider" in inputs and isinstance(inputs["guider"], dict):
guider = inputs["guider"]
# Get prompt from conditioning
if "conditioning" in guider and isinstance(guider["conditioning"], str):
result["prompt"] = guider["conditioning"]
elif "conditioning" in guider and isinstance(guider["conditioning"], dict):
result["guidance"] = guider["conditioning"].get("guidance", "")
result["prompt"] = guider["conditioning"].get("prompt", "")
if "model" in guider and isinstance(guider["model"], dict):
result["checkpoint"] = guider["model"].get("checkpoint", "")
result["loras"] = guider["model"].get("loras", "")
result["clip_skip"] = str(int(guider["model"].get("clip_skip", "-1")) * -1)
# Extract dimensions from latent_image
if "latent_image" in inputs and isinstance(inputs["latent_image"], dict):
latent = inputs["latent_image"]
width = latent.get("width", 0)
height = latent.get("height", 0)
if width and height:
result["width"] = width
result["height"] = height
result["size"] = f"{width}x{height}"
return result
def transform_ksampler(inputs: Dict) -> Dict:
"""Transform function for KSampler nodes"""
result = {
"seed": str(inputs.get("seed", "")),
"steps": str(inputs.get("steps", "")),
"cfg": str(inputs.get("cfg", "")),
"sampler": inputs.get("sampler_name", ""),
"scheduler": inputs.get("scheduler", ""),
}
# Process positive prompt
if "positive" in inputs:
result["prompt"] = inputs["positive"]
# Process negative prompt
if "negative" in inputs:
result["negative_prompt"] = inputs["negative"]
# Get dimensions from latent image
if "latent_image" in inputs and isinstance(inputs["latent_image"], dict):
width = inputs["latent_image"].get("width", 0)
height = inputs["latent_image"].get("height", 0)
if width and height:
result["size"] = f"{width}x{height}"
# Add clip_skip if present
if "clip_skip" in inputs:
result["clip_skip"] = str(inputs.get("clip_skip", ""))
# Add guidance if present
if "guidance" in inputs:
result["guidance"] = str(inputs.get("guidance", ""))
# Add model if present
if "model" in inputs:
result["checkpoint"] = inputs.get("model", {}).get("checkpoint", "")
result["loras"] = inputs.get("model", {}).get("loras", "")
result["clip_skip"] = str(inputs.get("model", {}).get("clip_skip", -1) * -1)
return result
def transform_empty_latent(inputs: Dict) -> Dict:
"""Transform function for EmptyLatentImage nodes"""
width = inputs.get("width", 0)
height = inputs.get("height", 0)
return {"width": width, "height": height, "size": f"{width}x{height}"}
def transform_clip_text(inputs: Dict) -> Any:
"""Transform function for CLIPTextEncode nodes"""
return inputs.get("text", "")
def transform_flux_guidance(inputs: Dict) -> Dict:
"""Transform function for FluxGuidance nodes"""
result = {}
if "guidance" in inputs:
result["guidance"] = inputs["guidance"]
if "conditioning" in inputs:
conditioning = inputs["conditioning"]
if isinstance(conditioning, str):
result["prompt"] = conditioning
else:
result["prompt"] = "Unknown prompt"
return result
def transform_unet_loader(inputs: Dict) -> Dict:
"""Transform function for UNETLoader node"""
unet_name = inputs.get("unet_name", "")
return {"checkpoint": unet_name} if unet_name else {}
def transform_checkpoint_loader(inputs: Dict) -> Dict:
"""Transform function for CheckpointLoaderSimple node"""
ckpt_name = inputs.get("ckpt_name", "")
return {"checkpoint": ckpt_name} if ckpt_name else {}
def transform_latent_upscale_by(inputs: Dict) -> Dict:
"""Transform function for LatentUpscaleBy node"""
result = {}
width = inputs["samples"].get("width", 0) * inputs["scale_by"]
height = inputs["samples"].get("height", 0) * inputs["scale_by"]
result["width"] = width
result["height"] = height
result["size"] = f"{width}x{height}"
return result
def transform_clip_set_last_layer(inputs: Dict) -> Dict:
"""Transform function for CLIPSetLastLayer node"""
result = {}
if "stop_at_clip_layer" in inputs:
result["clip_skip"] = inputs["stop_at_clip_layer"]
return result
# =============================================================================
# Node Mapper Definitions
# =============================================================================
# Define the mappers for ComfyUI core nodes not in main mapper
NODE_MAPPERS_EXT = {
# KSamplers
"SamplerCustomAdvanced": {
"inputs_to_track": ["noise", "guider", "sampler", "sigmas", "latent_image"],
"transform_func": transform_sampler_custom_advanced
},
"KSampler": {
"inputs_to_track": [
"seed", "steps", "cfg", "sampler_name", "scheduler",
"denoise", "positive", "negative", "latent_image",
"model", "clip_skip"
],
"transform_func": transform_ksampler
},
# ComfyUI core nodes
"EmptyLatentImage": {
"inputs_to_track": ["width", "height", "batch_size"],
"transform_func": transform_empty_latent
},
"EmptySD3LatentImage": {
"inputs_to_track": ["width", "height", "batch_size"],
"transform_func": transform_empty_latent
},
"CLIPTextEncode": {
"inputs_to_track": ["text", "clip"],
"transform_func": transform_clip_text
},
"FluxGuidance": {
"inputs_to_track": ["guidance", "conditioning"],
"transform_func": transform_flux_guidance
},
"RandomNoise": {
"inputs_to_track": ["noise_seed"],
"transform_func": transform_random_noise
},
"KSamplerSelect": {
"inputs_to_track": ["sampler_name"],
"transform_func": transform_ksampler_select
},
"BasicScheduler": {
"inputs_to_track": ["scheduler", "steps", "denoise", "model"],
"transform_func": transform_basic_scheduler
},
"BasicGuider": {
"inputs_to_track": ["model", "conditioning"],
"transform_func": transform_basic_guider
},
"ModelSamplingFlux": {
"inputs_to_track": ["max_shift", "base_shift", "width", "height", "model"],
"transform_func": transform_model_sampling_flux
},
"UNETLoader": {
"inputs_to_track": ["unet_name"],
"transform_func": transform_unet_loader
},
"CheckpointLoaderSimple": {
"inputs_to_track": ["ckpt_name"],
"transform_func": transform_checkpoint_loader
},
"LatentUpscale": {
"inputs_to_track": ["width", "height"],
"transform_func": transform_empty_latent
},
"LatentUpscaleBy": {
"inputs_to_track": ["samples", "scale_by"],
"transform_func": transform_latent_upscale_by
},
"CLIPSetLastLayer": {
"inputs_to_track": ["clip", "stop_at_clip_layer"],
"transform_func": transform_clip_set_last_layer
}
}

View File

@@ -1,54 +0,0 @@
"""
Example extension mapper for demonstrating the extension system
"""
from typing import Dict, Any
from ..mappers import NodeMapper
class ExampleNodeMapper(NodeMapper):
"""Example mapper for custom nodes"""
def __init__(self):
super().__init__(
node_type="ExampleCustomNode",
inputs_to_track=["param1", "param2", "image"]
)
def transform(self, inputs: Dict) -> Dict:
"""Transform extracted inputs into the desired output format"""
result = {}
# Extract interesting parameters
if "param1" in inputs:
result["example_param1"] = inputs["param1"]
if "param2" in inputs:
result["example_param2"] = inputs["param2"]
# You can process the data in any way needed
return result
class VAEMapperExtension(NodeMapper):
"""Extension mapper for VAE nodes"""
def __init__(self):
super().__init__(
node_type="VAELoader",
inputs_to_track=["vae_name"]
)
def transform(self, inputs: Dict) -> Dict:
"""Extract VAE information"""
vae_name = inputs.get("vae_name", "")
# Remove path prefix if present
if "/" in vae_name or "\\" in vae_name:
# Get just the filename without path or extension
vae_name = vae_name.replace("\\", "/").split("/")[-1]
vae_name = vae_name.split(".")[0] # Remove extension
return {"vae": vae_name}
# Note: No need to register manually - extensions are automatically registered
# when the extension system loads this file

View File

@@ -0,0 +1,74 @@
"""
KJNodes mappers extension for ComfyUI workflow parsing
"""
import logging
import re
from typing import Dict, Any
logger = logging.getLogger(__name__)
# =============================================================================
# Transform Functions
# =============================================================================
def transform_join_strings(inputs: Dict) -> str:
"""Transform function for JoinStrings nodes"""
string1 = inputs.get("string1", "")
string2 = inputs.get("string2", "")
delimiter = inputs.get("delimiter", "")
return f"{string1}{delimiter}{string2}"
def transform_string_constant(inputs: Dict) -> str:
"""Transform function for StringConstant nodes"""
return inputs.get("string", "")
def transform_empty_latent_presets(inputs: Dict) -> Dict:
"""Transform function for EmptyLatentImagePresets nodes"""
dimensions = inputs.get("dimensions", "")
invert = inputs.get("invert", False)
# Extract width and height from dimensions string
# Expected format: "width x height (ratio)" or similar
width = 0
height = 0
if dimensions:
# Try to extract dimensions using regex
match = re.search(r'(\d+)\s*x\s*(\d+)', dimensions)
if match:
width = int(match.group(1))
height = int(match.group(2))
# If invert is True, swap width and height
if invert and width and height:
width, height = height, width
return {"width": width, "height": height, "size": f"{width}x{height}"}
def transform_int_constant(inputs: Dict) -> int:
"""Transform function for INTConstant nodes"""
return inputs.get("value", 0)
# =============================================================================
# Node Mapper Definitions
# =============================================================================
# Define the mappers for KJNodes
NODE_MAPPERS_EXT = {
"JoinStrings": {
"inputs_to_track": ["string1", "string2", "delimiter"],
"transform_func": transform_join_strings
},
"StringConstantMultiline": {
"inputs_to_track": ["string"],
"transform_func": transform_string_constant
},
"EmptyLatentImagePresets": {
"inputs_to_track": ["dimensions", "invert", "batch_size"],
"transform_func": transform_empty_latent_presets
},
"INTConstant": {
"inputs_to_track": ["value"],
"transform_func": transform_int_constant
}
}

View File

@@ -5,375 +5,229 @@ import logging
import os import os
import importlib.util import importlib.util
import inspect import inspect
from typing import Dict, List, Any, Optional, Union, Type, Callable from typing import Dict, List, Any, Optional, Union, Type, Callable, Tuple
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Global mapper registry # Global mapper registry
_MAPPER_REGISTRY: Dict[str, 'NodeMapper'] = {} _MAPPER_REGISTRY: Dict[str, Dict] = {}
class NodeMapper:
"""Base class for node mappers that define how to extract information from a specific node type"""
def __init__(self, node_type: str, inputs_to_track: List[str]):
self.node_type = node_type
self.inputs_to_track = inputs_to_track
def process(self, node_id: str, node_data: Dict, workflow: Dict, parser: 'WorkflowParser') -> Any: # type: ignore
"""Process the node and extract relevant information"""
result = {}
for input_name in self.inputs_to_track:
if input_name in node_data.get("inputs", {}):
input_value = node_data["inputs"][input_name]
# Check if input is a reference to another node's output
if isinstance(input_value, list) and len(input_value) == 2:
# Format is [node_id, output_slot]
try:
ref_node_id, output_slot = input_value
# Convert node_id to string if it's an integer
if isinstance(ref_node_id, int):
ref_node_id = str(ref_node_id)
# Recursively process the referenced node
ref_value = parser.process_node(ref_node_id, workflow)
# Store the processed value
if ref_value is not None:
result[input_name] = ref_value
else:
# If we couldn't get a value from the reference, store the raw value
result[input_name] = input_value
except Exception as e:
logger.error(f"Error processing reference in node {node_id}, input {input_name}: {e}")
# If we couldn't process the reference, store the raw value
result[input_name] = input_value
else:
# Direct value
result[input_name] = input_value
# Apply any transformations
return self.transform(result)
def transform(self, inputs: Dict) -> Any:
"""Transform the extracted inputs - override in subclasses"""
return inputs
class KSamplerMapper(NodeMapper):
"""Mapper for KSampler nodes"""
def __init__(self):
super().__init__(
node_type="KSampler",
inputs_to_track=["seed", "steps", "cfg", "sampler_name", "scheduler",
"denoise", "positive", "negative", "latent_image",
"model", "clip_skip"]
)
def transform(self, inputs: Dict) -> Dict:
result = {
"seed": str(inputs.get("seed", "")),
"steps": str(inputs.get("steps", "")),
"cfg": str(inputs.get("cfg", "")),
"sampler": inputs.get("sampler_name", ""),
"scheduler": inputs.get("scheduler", ""),
}
# Process positive prompt
if "positive" in inputs:
result["prompt"] = inputs["positive"]
# Process negative prompt
if "negative" in inputs:
result["negative_prompt"] = inputs["negative"]
# Get dimensions from latent image
if "latent_image" in inputs and isinstance(inputs["latent_image"], dict):
width = inputs["latent_image"].get("width", 0)
height = inputs["latent_image"].get("height", 0)
if width and height:
result["size"] = f"{width}x{height}"
# Add clip_skip if present
if "clip_skip" in inputs:
result["clip_skip"] = str(inputs.get("clip_skip", ""))
return result
class EmptyLatentImageMapper(NodeMapper):
"""Mapper for EmptyLatentImage nodes"""
def __init__(self):
super().__init__(
node_type="EmptyLatentImage",
inputs_to_track=["width", "height", "batch_size"]
)
def transform(self, inputs: Dict) -> Dict:
width = inputs.get("width", 0)
height = inputs.get("height", 0)
return {"width": width, "height": height, "size": f"{width}x{height}"}
class EmptySD3LatentImageMapper(NodeMapper):
"""Mapper for EmptySD3LatentImage nodes"""
def __init__(self):
super().__init__(
node_type="EmptySD3LatentImage",
inputs_to_track=["width", "height", "batch_size"]
)
def transform(self, inputs: Dict) -> Dict:
width = inputs.get("width", 0)
height = inputs.get("height", 0)
return {"width": width, "height": height, "size": f"{width}x{height}"}
class CLIPTextEncodeMapper(NodeMapper):
"""Mapper for CLIPTextEncode nodes"""
def __init__(self):
super().__init__(
node_type="CLIPTextEncode",
inputs_to_track=["text", "clip"]
)
def transform(self, inputs: Dict) -> Any:
# Simply return the text
return inputs.get("text", "")
class LoraLoaderMapper(NodeMapper):
"""Mapper for LoraLoader nodes"""
def __init__(self):
super().__init__(
node_type="Lora Loader (LoraManager)",
inputs_to_track=["loras", "lora_stack"]
)
def transform(self, inputs: Dict) -> Dict:
# Fallback to loras array if text field doesn't exist or is invalid
loras_data = inputs.get("loras", [])
lora_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
# Process loras array - filter active entries
lora_texts = []
# Check if loras_data is a list or a dict with __value__ key (new format)
if isinstance(loras_data, dict) and "__value__" in loras_data:
loras_list = loras_data["__value__"]
elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
# Process each active lora entry
for lora in loras_list:
logger.info(f"Lora: {lora}, active: {lora.get('active')}")
if isinstance(lora, dict) and lora.get("active", False):
lora_name = lora.get("name", "")
strength = lora.get("strength", 1.0)
lora_texts.append(f"<lora:{lora_name}:{strength}>")
# Process lora_stack if it exists and is a valid format (list of tuples)
if lora_stack and isinstance(lora_stack, list):
# If lora_stack is a reference to another node ([node_id, output_slot]),
# we don't process it here as it's already been processed recursively
if len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int):
# This is a reference to another node, already processed
pass
else:
# Format each entry from the stack (assuming it's a list of tuples)
for stack_entry in lora_stack:
lora_name = stack_entry[0]
strength = stack_entry[1]
lora_texts.append(f"<lora:{lora_name}:{strength}>")
# Join with spaces
combined_text = " ".join(lora_texts)
return {"loras": combined_text}
class LoraStackerMapper(NodeMapper):
"""Mapper for LoraStacker nodes"""
def __init__(self):
super().__init__(
node_type="Lora Stacker (LoraManager)",
inputs_to_track=["loras", "lora_stack"]
)
def transform(self, inputs: Dict) -> Dict:
loras_data = inputs.get("loras", [])
result_stack = []
# Handle existing stack entries
existing_stack = []
lora_stack_input = inputs.get("lora_stack", [])
# Handle different formats of lora_stack
if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input:
# Format from another LoraStacker node
existing_stack = lora_stack_input["lora_stack"]
elif isinstance(lora_stack_input, list):
# Direct list format or reference format [node_id, output_slot]
if len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and isinstance(lora_stack_input[1], int):
# This is likely a reference that was already processed
pass
else:
# Regular list of tuples/entries
existing_stack = lora_stack_input
# Add existing entries first
if existing_stack:
result_stack.extend(existing_stack)
# Process loras array - filter active entries
# Check if loras_data is a list or a dict with __value__ key (new format)
if isinstance(loras_data, dict) and "__value__" in loras_data:
loras_list = loras_data["__value__"]
elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
# Process each active lora entry
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", False):
lora_name = lora.get("name", "")
strength = float(lora.get("strength", 1.0))
result_stack.append((lora_name, strength))
return {"lora_stack": result_stack}
class JoinStringsMapper(NodeMapper):
"""Mapper for JoinStrings nodes"""
def __init__(self):
super().__init__(
node_type="JoinStrings",
inputs_to_track=["string1", "string2", "delimiter"]
)
def transform(self, inputs: Dict) -> str:
string1 = inputs.get("string1", "")
string2 = inputs.get("string2", "")
delimiter = inputs.get("delimiter", "")
return f"{string1}{delimiter}{string2}"
class StringConstantMapper(NodeMapper):
"""Mapper for StringConstant and StringConstantMultiline nodes"""
def __init__(self):
super().__init__(
node_type="StringConstantMultiline",
inputs_to_track=["string"]
)
def transform(self, inputs: Dict) -> str:
return inputs.get("string", "")
class TriggerWordToggleMapper(NodeMapper):
"""Mapper for TriggerWordToggle nodes"""
def __init__(self):
super().__init__(
node_type="TriggerWord Toggle (LoraManager)",
inputs_to_track=["toggle_trigger_words"]
)
def transform(self, inputs: Dict) -> str:
toggle_data = inputs.get("toggle_trigger_words", [])
# check if toggle_words is a list or a dict with __value__ key (new format)
if isinstance(toggle_data, dict) and "__value__" in toggle_data:
toggle_words = toggle_data["__value__"]
elif isinstance(toggle_data, list):
toggle_words = toggle_data
else:
toggle_words = []
# Filter active trigger words
active_words = []
for item in toggle_words:
if isinstance(item, dict) and item.get("active", False):
word = item.get("text", "")
if word and not word.startswith("__dummy"):
active_words.append(word)
# Join with commas
result = ", ".join(active_words)
return result
class FluxGuidanceMapper(NodeMapper):
"""Mapper for FluxGuidance nodes"""
def __init__(self):
super().__init__(
node_type="FluxGuidance",
inputs_to_track=["guidance", "conditioning"]
)
def transform(self, inputs: Dict) -> Dict:
result = {}
# Handle guidance parameter
if "guidance" in inputs:
result["guidance"] = inputs["guidance"]
# Handle conditioning (the prompt text)
if "conditioning" in inputs:
conditioning = inputs["conditioning"]
if isinstance(conditioning, str):
result["prompt"] = conditioning
else:
result["prompt"] = "Unknown prompt"
return result
# ============================================================================= # =============================================================================
# Mapper Registry Functions # Mapper Definition Functions
# ============================================================================= # =============================================================================
def register_mapper(mapper: NodeMapper) -> None: def create_mapper(
node_type: str,
inputs_to_track: List[str],
transform_func: Callable[[Dict], Any] = None
) -> Dict:
"""Create a mapper definition for a node type"""
mapper = {
"node_type": node_type,
"inputs_to_track": inputs_to_track,
"transform": transform_func or (lambda inputs: inputs)
}
return mapper
def register_mapper(mapper: Dict) -> None:
"""Register a node mapper in the global registry""" """Register a node mapper in the global registry"""
_MAPPER_REGISTRY[mapper.node_type] = mapper _MAPPER_REGISTRY[mapper["node_type"]] = mapper
logger.debug(f"Registered mapper for node type: {mapper.node_type}") logger.debug(f"Registered mapper for node type: {mapper['node_type']}")
def get_mapper(node_type: str) -> Optional[NodeMapper]: def get_mapper(node_type: str) -> Optional[Dict]:
"""Get a mapper for the specified node type""" """Get a mapper for the specified node type"""
return _MAPPER_REGISTRY.get(node_type) return _MAPPER_REGISTRY.get(node_type)
def get_all_mappers() -> Dict[str, NodeMapper]: def get_all_mappers() -> Dict[str, Dict]:
"""Get all registered mappers""" """Get all registered mappers"""
return _MAPPER_REGISTRY.copy() return _MAPPER_REGISTRY.copy()
def register_default_mappers() -> None: # =============================================================================
"""Register all default mappers""" # Node Processing Function
default_mappers = [ # =============================================================================
KSamplerMapper(),
EmptyLatentImageMapper(), def process_node(node_id: str, node_data: Dict, workflow: Dict, parser: 'WorkflowParser') -> Any: # type: ignore
EmptySD3LatentImageMapper(), """Process a node using its mapper and extract relevant information"""
CLIPTextEncodeMapper(), node_type = node_data.get("class_type")
LoraLoaderMapper(), mapper = get_mapper(node_type)
LoraStackerMapper(),
JoinStringsMapper(),
StringConstantMapper(),
TriggerWordToggleMapper(),
FluxGuidanceMapper()
]
for mapper in default_mappers: if not mapper:
logger.warning(f"No mapper found for node type: {node_type}")
return None
result = {}
# Extract inputs based on the mapper's tracked inputs
for input_name in mapper["inputs_to_track"]:
if input_name in node_data.get("inputs", {}):
input_value = node_data["inputs"][input_name]
# Check if input is a reference to another node's output
if isinstance(input_value, list) and len(input_value) == 2:
try:
# Format is [node_id, output_slot]
ref_node_id, output_slot = input_value
# Convert node_id to string if it's an integer
if isinstance(ref_node_id, int):
ref_node_id = str(ref_node_id)
# Recursively process the referenced node
ref_value = parser.process_node(ref_node_id, workflow)
if ref_value is not None:
result[input_name] = ref_value
else:
# If we couldn't get a value from the reference, store the raw value
result[input_name] = input_value
except Exception as e:
logger.error(f"Error processing reference in node {node_id}, input {input_name}: {e}")
result[input_name] = input_value
else:
# Direct value
result[input_name] = input_value
# Apply the transform function
try:
return mapper["transform"](result)
except Exception as e:
logger.error(f"Error in transform function for node {node_id} of type {node_type}: {e}")
return result
# =============================================================================
# Transform Functions
# =============================================================================
def transform_lora_loader(inputs: Dict) -> Dict:
"""Transform function for LoraLoader nodes"""
loras_data = inputs.get("loras", [])
lora_stack = inputs.get("lora_stack", {}).get("lora_stack", [])
lora_texts = []
# Process loras array
if isinstance(loras_data, dict) and "__value__" in loras_data:
loras_list = loras_data["__value__"]
elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
# Process each active lora entry
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", False):
lora_name = lora.get("name", "")
strength = lora.get("strength", 1.0)
lora_texts.append(f"<lora:{lora_name}:{strength}>")
# Process lora_stack if valid
if lora_stack and isinstance(lora_stack, list):
if not (len(lora_stack) == 2 and isinstance(lora_stack[0], (str, int)) and isinstance(lora_stack[1], int)):
for stack_entry in lora_stack:
lora_name = stack_entry[0]
strength = stack_entry[1]
lora_texts.append(f"<lora:{lora_name}:{strength}>")
result = {
"checkpoint": inputs.get("model", {}).get("checkpoint", ""),
"loras": " ".join(lora_texts)
}
if "clip" in inputs and isinstance(inputs["clip"], dict):
result["clip_skip"] = inputs["clip"].get("clip_skip", "-1")
return result
def transform_lora_stacker(inputs: Dict) -> Dict:
"""Transform function for LoraStacker nodes"""
loras_data = inputs.get("loras", [])
result_stack = []
# Handle existing stack entries
existing_stack = []
lora_stack_input = inputs.get("lora_stack", [])
if isinstance(lora_stack_input, dict) and "lora_stack" in lora_stack_input:
existing_stack = lora_stack_input["lora_stack"]
elif isinstance(lora_stack_input, list):
if not (len(lora_stack_input) == 2 and isinstance(lora_stack_input[0], (str, int)) and
isinstance(lora_stack_input[1], int)):
existing_stack = lora_stack_input
# Add existing entries
if existing_stack:
result_stack.extend(existing_stack)
# Process new loras
if isinstance(loras_data, dict) and "__value__" in loras_data:
loras_list = loras_data["__value__"]
elif isinstance(loras_data, list):
loras_list = loras_data
else:
loras_list = []
for lora in loras_list:
if isinstance(lora, dict) and lora.get("active", False):
lora_name = lora.get("name", "")
strength = float(lora.get("strength", 1.0))
result_stack.append((lora_name, strength))
return {"lora_stack": result_stack}
def transform_trigger_word_toggle(inputs: Dict) -> str:
"""Transform function for TriggerWordToggle nodes"""
toggle_data = inputs.get("toggle_trigger_words", [])
if isinstance(toggle_data, dict) and "__value__" in toggle_data:
toggle_words = toggle_data["__value__"]
elif isinstance(toggle_data, list):
toggle_words = toggle_data
else:
toggle_words = []
# Filter active trigger words
active_words = []
for item in toggle_words:
if isinstance(item, dict) and item.get("active", False):
word = item.get("text", "")
if word and not word.startswith("__dummy"):
active_words.append(word)
return ", ".join(active_words)
# =============================================================================
# Node Mapper Definitions
# =============================================================================
# Central definition of all supported node types and their configurations
NODE_MAPPERS = {
# LoraManager nodes
"Lora Loader (LoraManager)": {
"inputs_to_track": ["model", "clip", "loras", "lora_stack"],
"transform_func": transform_lora_loader
},
"Lora Stacker (LoraManager)": {
"inputs_to_track": ["loras", "lora_stack"],
"transform_func": transform_lora_stacker
},
"TriggerWord Toggle (LoraManager)": {
"inputs_to_track": ["toggle_trigger_words"],
"transform_func": transform_trigger_word_toggle
}
}
def register_all_mappers() -> None:
"""Register all mappers from the NODE_MAPPERS dictionary"""
for node_type, config in NODE_MAPPERS.items():
mapper = create_mapper(
node_type=node_type,
inputs_to_track=config["inputs_to_track"],
transform_func=config["transform_func"]
)
register_mapper(mapper) register_mapper(mapper)
logger.info(f"Registered {len(NODE_MAPPERS)} node mappers")
# ============================================================================= # =============================================================================
# Extension Loading # Extension Loading
@@ -383,8 +237,8 @@ def load_extensions(ext_dir: str = None) -> None:
""" """
Load mapper extensions from the specified directory Load mapper extensions from the specified directory
Each Python file in the directory will be loaded, and any NodeMapper subclasses Extension files should define a NODE_MAPPERS_EXT dictionary containing mapper configurations.
defined in those files will be automatically registered. These will be added to the global NODE_MAPPERS dictionary and registered automatically.
""" """
# Use default path if none provided # Use default path if none provided
if ext_dir is None: if ext_dir is None:
@@ -411,18 +265,18 @@ def load_extensions(ext_dir: str = None) -> None:
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) spec.loader.exec_module(module)
# Find all NodeMapper subclasses in the module # Check if the module defines NODE_MAPPERS_EXT
for name, obj in inspect.getmembers(module): if hasattr(module, 'NODE_MAPPERS_EXT'):
if (inspect.isclass(obj) and issubclass(obj, NodeMapper) # Add the extension mappers to the global NODE_MAPPERS dictionary
and obj != NodeMapper and hasattr(obj, 'node_type')): NODE_MAPPERS.update(module.NODE_MAPPERS_EXT)
# Instantiate and register the mapper logger.info(f"Added {len(module.NODE_MAPPERS_EXT)} mappers from extension: {filename}")
mapper = obj() else:
register_mapper(mapper) logger.warning(f"Extension {filename} does not define NODE_MAPPERS_EXT dictionary")
logger.info(f"Loaded extension mapper: {mapper.node_type} from {filename}")
except Exception as e: except Exception as e:
logger.warning(f"Error loading extension {filename}: {e}") logger.warning(f"Error loading extension {filename}: {e}")
# Re-register all mappers after loading extensions
register_all_mappers()
# Initialize the registry with default mappers # Initialize the registry with default mappers
register_default_mappers() # register_default_mappers()

View File

@@ -4,7 +4,7 @@ Main workflow parser implementation for ComfyUI
import json import json
import logging import logging
from typing import Dict, List, Any, Optional, Union, Set from typing import Dict, List, Any, Optional, Union, Set
from .mappers import get_mapper, get_all_mappers, load_extensions from .mappers import get_mapper, get_all_mappers, load_extensions, process_node
from .utils import ( from .utils import (
load_workflow, save_output, find_node_by_type, load_workflow, save_output, find_node_by_type,
trace_model_path trace_model_path
@@ -15,14 +15,13 @@ logger = logging.getLogger(__name__)
class WorkflowParser: class WorkflowParser:
"""Parser for ComfyUI workflows""" """Parser for ComfyUI workflows"""
def __init__(self, load_extensions_on_init: bool = True): def __init__(self):
"""Initialize the parser with mappers""" """Initialize the parser with mappers"""
self.processed_nodes: Set[str] = set() # Track processed nodes to avoid cycles self.processed_nodes: Set[str] = set() # Track processed nodes to avoid cycles
self.node_results_cache: Dict[str, Any] = {} # Cache for processed node results self.node_results_cache: Dict[str, Any] = {} # Cache for processed node results
# Load extensions if requested # Load extensions
if load_extensions_on_init: load_extensions()
load_extensions()
def process_node(self, node_id: str, workflow: Dict) -> Any: def process_node(self, node_id: str, workflow: Dict) -> Any:
"""Process a single node and extract relevant information""" """Process a single node and extract relevant information"""
@@ -45,10 +44,9 @@ class WorkflowParser:
node_type = node_data.get("class_type") node_type = node_data.get("class_type")
result = None result = None
mapper = get_mapper(node_type) if get_mapper(node_type):
if mapper:
try: try:
result = mapper.process(node_id, node_data, workflow, self) result = process_node(node_id, node_data, workflow, self)
# Cache the result # Cache the result
self.node_results_cache[node_id] = result self.node_results_cache[node_id] = result
except Exception as e: except Exception as e:
@@ -60,32 +58,58 @@ class WorkflowParser:
self.processed_nodes.remove(node_id) self.processed_nodes.remove(node_id)
return result return result
def collect_loras_from_model(self, model_input: List, workflow: Dict) -> str: def find_primary_sampler_node(self, workflow: Dict) -> Optional[str]:
"""Collect loras information from the model node chain""" """
if not isinstance(model_input, list) or len(model_input) != 2: Find the primary sampler node in the workflow.
return ""
model_node_id, _ = model_input
# Convert node_id to string if it's an integer
if isinstance(model_node_id, int):
model_node_id = str(model_node_id)
# Process the model node
model_result = self.process_node(model_node_id, workflow)
# If this is a Lora Loader node, return the loras text Priority:
if model_result and isinstance(model_result, dict) and "loras" in model_result: 1. First try to find a SamplerCustomAdvanced node
return model_result["loras"] 2. If not found, look for KSampler nodes with denoise=1.0
3. If still not found, use the first KSampler node
# If not a lora loader, check the node's inputs for a model connection
node_data = workflow.get(model_node_id, {})
inputs = node_data.get("inputs", {})
# If this node has a model input, follow that path Args:
if "model" in inputs and isinstance(inputs["model"], list): workflow: The workflow data as a dictionary
return self.collect_loras_from_model(inputs["model"], workflow)
return "" Returns:
The node ID of the primary sampler node, or None if not found
"""
# First check for SamplerCustomAdvanced nodes
sampler_advanced_nodes = []
ksampler_nodes = []
# Scan workflow for sampler nodes
for node_id, node_data in workflow.items():
node_type = node_data.get("class_type")
if node_type == "SamplerCustomAdvanced":
sampler_advanced_nodes.append(node_id)
elif node_type == "KSampler":
ksampler_nodes.append(node_id)
# If we found SamplerCustomAdvanced nodes, return the first one
if sampler_advanced_nodes:
logger.debug(f"Found SamplerCustomAdvanced node: {sampler_advanced_nodes[0]}")
return sampler_advanced_nodes[0]
# If we have KSampler nodes, look for one with denoise=1.0
if ksampler_nodes:
for node_id in ksampler_nodes:
node_data = workflow[node_id]
inputs = node_data.get("inputs", {})
denoise = inputs.get("denoise", 0)
# Check if denoise is 1.0 (allowing for small floating point differences)
if abs(float(denoise) - 1.0) < 0.001:
logger.debug(f"Found KSampler node with denoise=1.0: {node_id}")
return node_id
# If no KSampler with denoise=1.0 found, use the first one
logger.debug(f"No KSampler with denoise=1.0 found, using first KSampler: {ksampler_nodes[0]}")
return ksampler_nodes[0]
# No sampler nodes found
logger.warning("No sampler nodes found in workflow")
return None
def parse_workflow(self, workflow_data: Union[str, Dict], output_path: Optional[str] = None) -> Dict: def parse_workflow(self, workflow_data: Union[str, Dict], output_path: Optional[str] = None) -> Dict:
""" """
@@ -108,77 +132,38 @@ class WorkflowParser:
self.processed_nodes = set() self.processed_nodes = set()
self.node_results_cache = {} self.node_results_cache = {}
# Find the KSampler node # Find the primary sampler node
ksampler_node_id = find_node_by_type(workflow, "KSampler") sampler_node_id = self.find_primary_sampler_node(workflow)
if not ksampler_node_id: if not sampler_node_id:
logger.warning("No KSampler node found in workflow") logger.warning("No suitable sampler node found in workflow")
return {} return {}
# Start parsing from the KSampler node # Process sampler node to extract parameters
result = { sampler_result = self.process_node(sampler_node_id, workflow)
"gen_params": {}, if not sampler_result:
"loras": "" return {}
}
# Process KSampler node to extract parameters # Return the sampler result directly - it's already in the format we need
ksampler_result = self.process_node(ksampler_node_id, workflow) # This simplifies the structure and makes it easier to use in recipe_routes.py
if ksampler_result:
# Process the result
for key, value in ksampler_result.items():
# Special handling for the positive prompt from FluxGuidance
if key == "positive" and isinstance(value, dict):
# Extract guidance value
if "guidance" in value:
result["gen_params"]["guidance"] = value["guidance"]
# Extract prompt
if "prompt" in value:
result["gen_params"]["prompt"] = value["prompt"]
else:
# Normal handling for other values
result["gen_params"][key] = value
# Process the positive prompt node if it exists and we don't have a prompt yet
if "prompt" not in result["gen_params"] and "positive" in ksampler_result:
positive_value = ksampler_result.get("positive")
if isinstance(positive_value, str):
result["gen_params"]["prompt"] = positive_value
# Manually check for FluxGuidance if we don't have guidance value
if "guidance" not in result["gen_params"]:
flux_node_id = find_node_by_type(workflow, "FluxGuidance")
if flux_node_id:
# Get the direct input from the node
node_inputs = workflow[flux_node_id].get("inputs", {})
if "guidance" in node_inputs:
result["gen_params"]["guidance"] = node_inputs["guidance"]
# Extract loras from the model input of KSampler
ksampler_node = workflow.get(ksampler_node_id, {})
ksampler_inputs = ksampler_node.get("inputs", {})
if "model" in ksampler_inputs and isinstance(ksampler_inputs["model"], list):
loras_text = self.collect_loras_from_model(ksampler_inputs["model"], workflow)
if loras_text:
result["loras"] = loras_text
# Handle standard ComfyUI names vs our output format # Handle standard ComfyUI names vs our output format
if "cfg" in result["gen_params"]: if "cfg" in sampler_result:
result["gen_params"]["cfg_scale"] = result["gen_params"].pop("cfg") sampler_result["cfg_scale"] = sampler_result.pop("cfg")
# Add clip_skip = 2 to match reference output if not already present # Add clip_skip = 1 to match reference output if not already present
if "clip_skip" not in result["gen_params"]: if "clip_skip" not in sampler_result:
result["gen_params"]["clip_skip"] = "2" sampler_result["clip_skip"] = "1"
# Ensure the prompt is a string and not a nested dictionary # Ensure the prompt is a string and not a nested dictionary
if "prompt" in result["gen_params"] and isinstance(result["gen_params"]["prompt"], dict): if "prompt" in sampler_result and isinstance(sampler_result["prompt"], dict):
if "prompt" in result["gen_params"]["prompt"]: if "prompt" in sampler_result["prompt"]:
result["gen_params"]["prompt"] = result["gen_params"]["prompt"]["prompt"] sampler_result["prompt"] = sampler_result["prompt"]["prompt"]
# Save the result if requested # Save the result if requested
if output_path: if output_path:
save_output(result, output_path) save_output(sampler_result, output_path)
return result return sampler_result
def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dict: def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dict:
@@ -193,4 +178,4 @@ def parse_workflow(workflow_path: str, output_path: Optional[str] = None) -> Dic
Dictionary containing extracted parameters Dictionary containing extracted parameters
""" """
parser = WorkflowParser() parser = WorkflowParser()
return parser.parse_workflow(workflow_path, output_path) return parser.parse_workflow(workflow_path, output_path)

View File

@@ -1,7 +1,7 @@
[project] [project]
name = "comfyui-lora-manager" name = "comfyui-lora-manager"
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration." description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration."
version = "0.8.0" version = "0.8.8"
license = {file = "LICENSE"} license = {file = "LICENSE"}
dependencies = [ dependencies = [
"aiohttp", "aiohttp",
@@ -11,6 +11,7 @@ dependencies = [
"beautifulsoup4", "beautifulsoup4",
"piexif", "piexif",
"Pillow", "Pillow",
"olefile", # for getting rid of warning message
"requests" "requests"
] ]

View File

@@ -95,7 +95,6 @@
"onSite": false, "onSite": false,
"remixOfId": null "remixOfId": null
} }
// more images here
], ],
"downloadUrl": "https://civitai.com/api/download/models/1387174" "downloadUrl": "https://civitai.com/api/download/models/1387174"
} }

View File

@@ -0,0 +1,153 @@
{
"resource-stack": {
"class_type": "CheckpointLoaderSimple",
"inputs": { "ckpt_name": "urn:air:sdxl:checkpoint:civitai:827184@1410435" }
},
"resource-stack-1": {
"class_type": "LoraLoader",
"inputs": {
"lora_name": "urn:air:sdxl:lora:civitai:1107767@1253442",
"strength_model": 1,
"strength_clip": 1,
"model": ["resource-stack", 0],
"clip": ["resource-stack", 1]
}
},
"resource-stack-2": {
"class_type": "LoraLoader",
"inputs": {
"lora_name": "urn:air:sdxl:lora:civitai:1342708@1516344",
"strength_model": 1,
"strength_clip": 1,
"model": ["resource-stack-1", 0],
"clip": ["resource-stack-1", 1]
}
},
"resource-stack-3": {
"class_type": "LoraLoader",
"inputs": {
"lora_name": "urn:air:sdxl:lora:civitai:122359@135867",
"strength_model": 1.55,
"strength_clip": 1,
"model": ["resource-stack-2", 0],
"clip": ["resource-stack-2", 1]
}
},
"6": {
"class_type": "smZ CLIPTextEncode",
"inputs": {
"text": "masterpiece, best quality, amazing quality, detailed setting, detailed background, 1girl, yunyun (konosuba), nude, red eyes, hair ornament, braid, hair between eyes,low twintails, pink ribbon, bow, hair bow, pussy, frilled skirt, layered skirt, belt, pink thighhighs, (pussy juice), large insertion, vaginal tugging, pussy grip, detailed skin, detailed soles, stretched pussy, feet in stockings, ass, nipples, medium breasts, french kiss, anus, shocked, nervous, penis awe, BREAK Professor\u0027s office, college student, pornographic, 1boy, close eyes, (musscular male, detailed large cock), vaginal sex, college office setting, ass grab, fucking, riding, cowgirl, erotic, side view, deep fucking",
"parser": "comfy",
"text_g": "",
"text_l": "",
"ascore": 2.5,
"width": 0,
"height": 0,
"crop_w": 0,
"crop_h": 0,
"target_width": 0,
"target_height": 0,
"smZ_steps": 1,
"mean_normalization": true,
"multi_conditioning": true,
"use_old_emphasis_implementation": false,
"with_SDXL": false,
"clip": ["resource-stack-3", 1]
},
"_meta": { "title": "Positive" }
},
"7": {
"class_type": "smZ CLIPTextEncode",
"inputs": {
"text": "bad quality,worst quality,worst detail,sketch,censor",
"parser": "comfy",
"text_g": "",
"text_l": "",
"ascore": 2.5,
"width": 0,
"height": 0,
"crop_w": 0,
"crop_h": 0,
"target_width": 0,
"target_height": 0,
"smZ_steps": 1,
"mean_normalization": true,
"multi_conditioning": true,
"use_old_emphasis_implementation": false,
"with_SDXL": false,
"clip": ["resource-stack-3", 1]
},
"_meta": { "title": "Negative" }
},
"20": {
"class_type": "UpscaleModelLoader",
"inputs": { "model_name": "urn:air:other:upscaler:civitai:147759@164821" },
"_meta": { "title": "Load Upscale Model" }
},
"17": {
"class_type": "LoadImage",
"inputs": {
"image": "https://orchestration.civitai.com/v2/consumer/blobs/5KZ6358TW8CNEGPZKD08NVDB30",
"upload": "image"
},
"_meta": { "title": "Image Load" }
},
"19": {
"class_type": "ImageUpscaleWithModel",
"inputs": { "upscale_model": ["20", 0], "image": ["17", 0] },
"_meta": { "title": "Upscale Image (using Model)" }
},
"23": {
"class_type": "ImageScale",
"inputs": {
"upscale_method": "nearest-exact",
"crop": "disabled",
"width": 1280,
"height": 1856,
"image": ["19", 0]
},
"_meta": { "title": "Upscale Image" }
},
"21": {
"class_type": "VAEEncode",
"inputs": { "pixels": ["23", 0], "vae": ["resource-stack", 2] },
"_meta": { "title": "VAE Encode" }
},
"11": {
"class_type": "KSampler",
"inputs": {
"sampler_name": "euler_ancestral",
"scheduler": "normal",
"seed": 2088370631,
"steps": 47,
"cfg": 6.5,
"denoise": 0.3,
"model": ["resource-stack-3", 0],
"positive": ["6", 0],
"negative": ["7", 0],
"latent_image": ["21", 0]
},
"_meta": { "title": "KSampler" }
},
"13": {
"class_type": "VAEDecode",
"inputs": { "samples": ["11", 0], "vae": ["resource-stack", 2] },
"_meta": { "title": "VAE Decode" }
},
"12": {
"class_type": "SaveImage",
"inputs": { "filename_prefix": "ComfyUI", "images": ["13", 0] },
"_meta": { "title": "Save Image" }
},
"extra": {
"airs": [
"urn:air:other:upscaler:civitai:147759@164821",
"urn:air:sdxl:checkpoint:civitai:827184@1410435",
"urn:air:sdxl:lora:civitai:1107767@1253442",
"urn:air:sdxl:lora:civitai:1342708@1516344",
"urn:air:sdxl:lora:civitai:122359@135867"
]
},
"extraMetadata": "{\u0022prompt\u0022:\u0022masterpiece, best quality, amazing quality, detailed setting, detailed background, 1girl, yunyun (konosuba), nude, red eyes, hair ornament, braid, hair between eyes,low twintails, pink ribbon, bow, hair bow, pussy, frilled skirt, layered skirt, belt, pink thighhighs, (pussy juice), large insertion, vaginal tugging, pussy grip, detailed skin, detailed soles, stretched pussy, feet in stockings, ass, nipples, medium breasts, french kiss, anus, shocked, nervous, penis awe, BREAK Professor\u0027s office, college student, pornographic, 1boy, close eyes, (musscular male, detailed large cock), vaginal sex, college office setting, ass grab, fucking, riding, cowgirl, erotic, side view, deep fucking\u0022,\u0022negativePrompt\u0022:\u0022bad quality,worst quality,worst detail,sketch,censor\u0022,\u0022steps\u0022:47,\u0022cfgScale\u0022:6.5,\u0022sampler\u0022:\u0022euler_ancestral\u0022,\u0022workflowId\u0022:\u0022img2img-hires\u0022,\u0022resources\u0022:[{\u0022modelVersionId\u0022:1410435,\u0022strength\u0022:1},{\u0022modelVersionId\u0022:1410435,\u0022strength\u0022:1},{\u0022modelVersionId\u0022:1253442,\u0022strength\u0022:1},{\u0022modelVersionId\u0022:1516344,\u0022strength\u0022:1},{\u0022modelVersionId\u0022:135867,\u0022strength\u0022:1.55}],\u0022remixOfId\u0022:32140259}"
}

View File

@@ -2,28 +2,17 @@ a dynamic and dramatic digital artwork featuring a stylized anthropomorphic whit
Negative prompt: Negative prompt:
Steps: 30, Sampler: Undefined, CFG scale: 3.5, Seed: 90300501, Size: 832x1216, Clip skip: 2, Created Date: 2025-03-05T13:51:18.1770234Z, Civitai resources: [{"type":"checkpoint","modelVersionId":691639,"modelName":"FLUX","modelVersionName":"Dev"},{"type":"lora","weight":0.4,"modelVersionId":1202162,"modelName":"Velvet\u0027s Mythic Fantasy Styles | Flux \u002B Pony \u002B illustrious","modelVersionName":"Flux Gothic Lines"},{"type":"lora","weight":0.8,"modelVersionId":1470588,"modelName":"Velvet\u0027s Mythic Fantasy Styles | Flux \u002B Pony \u002B illustrious","modelVersionName":"Flux Retro"},{"type":"lora","weight":0.75,"modelVersionId":746484,"modelName":"Elden Ring - Yoshitaka Amano","modelVersionName":"V1"},{"type":"lora","weight":0.2,"modelVersionId":914935,"modelName":"Ink-style","modelVersionName":"ink-dynamic"},{"type":"lora","weight":0.2,"modelVersionId":1189379,"modelName":"Painterly Fantasy by ChronoKnight - [FLUX \u0026 IL]","modelVersionName":"FLUX"},{"type":"lora","weight":0.2,"modelVersionId":757030,"modelName":"Mezzotint Artstyle for Flux - by Ethanar","modelVersionName":"V1"}], Civitai metadata: {} Steps: 30, Sampler: Undefined, CFG scale: 3.5, Seed: 90300501, Size: 832x1216, Clip skip: 2, Created Date: 2025-03-05T13:51:18.1770234Z, Civitai resources: [{"type":"checkpoint","modelVersionId":691639,"modelName":"FLUX","modelVersionName":"Dev"},{"type":"lora","weight":0.4,"modelVersionId":1202162,"modelName":"Velvet\u0027s Mythic Fantasy Styles | Flux \u002B Pony \u002B illustrious","modelVersionName":"Flux Gothic Lines"},{"type":"lora","weight":0.8,"modelVersionId":1470588,"modelName":"Velvet\u0027s Mythic Fantasy Styles | Flux \u002B Pony \u002B illustrious","modelVersionName":"Flux Retro"},{"type":"lora","weight":0.75,"modelVersionId":746484,"modelName":"Elden Ring - Yoshitaka Amano","modelVersionName":"V1"},{"type":"lora","weight":0.2,"modelVersionId":914935,"modelName":"Ink-style","modelVersionName":"ink-dynamic"},{"type":"lora","weight":0.2,"modelVersionId":1189379,"modelName":"Painterly Fantasy by ChronoKnight - [FLUX \u0026 IL]","modelVersionName":"FLUX"},{"type":"lora","weight":0.2,"modelVersionId":757030,"modelName":"Mezzotint Artstyle for Flux - by Ethanar","modelVersionName":"V1"}], Civitai metadata: {}
<lora:ck-shadow-circuit-IL:0.78>,
masterpiece, best quality, good quality, very aesthetic, absurdres, newest, 8K, depth of field, focused subject, masterpiece, best quality, good quality, very aesthetic, absurdres, newest, 8K, depth of field, focused subject,
dynamic angle, dutch angle, from below, epic half body portrait, gritty, wabi sabi, looking at viewer, woman is a geisha, parted lips, dynamic angle, dutch angle, from below, epic half body portrait, gritty, wabi sabi, looking at viewer, woman is a geisha, parted lips,
holographic skin, holofoil glitter, faint, glowing, ethereal, neon hair, glowing hair, otherworldly glow, she is dangerous, holographic skin, holofoil glitter, faint, glowing, ethereal, neon hair, glowing hair, otherworldly glow, she is dangerous
<lora:ck-nc-cyberpunk-IL-000011:0.4> <lora:ck-shadow-circuit-IL:0.78>, <lora:ck-nc-cyberpunk-IL-000011:0.4>, <lora:ck-neon-retrowave-IL:0.2>, <lora:ck-yoneyama-mai-IL-000014:0.4>
<lora:ck-neon-retrowave-IL:0.2>
<lora:ck-yoneyama-mai-IL-000014:0.4>
Negative prompt: score_6, score_5, score_4, bad quality, worst quality, worst detail, sketch, censorship, furry, window, headphones, Negative prompt: score_6, score_5, score_4, bad quality, worst quality, worst detail, sketch, censorship, furry, window, headphones,
Steps: 30, Sampler: Euler a, Schedule type: Simple, CFG scale: 7, Seed: 1405717592, Size: 832x1216, Model hash: 1ad6ca7f70, Model: waiNSFWIllustrious_v100, Denoising strength: 0.35, Hires CFG Scale: 5, Hires upscale: 1.3, Hires steps: 20, Hires upscaler: 4x-AnimeSharp, Lora hashes: "ck-shadow-circuit-IL: 88e247aa8c3d, ck-nc-cyberpunk-IL-000011: 935e6755554c, ck-neon-retrowave-IL: edafb9df7da1, ck-yoneyama-mai-IL-000014: 1b9305692a2e", Version: f2.0.1v1.10.1-1.10.1, Diffusion in Low Bits: Automatic (fp16 LoRA) Steps: 30, Sampler: Euler a, Schedule type: Simple, CFG scale: 7, Seed: 1405717592, Size: 832x1216, Model hash: 1ad6ca7f70, Model: waiNSFWIllustrious_v100, Denoising strength: 0.35, Hires CFG Scale: 5, Hires upscale: 1.3, Hires steps: 20, Hires upscaler: 4x-AnimeSharp, Lora hashes: "ck-shadow-circuit-IL: 88e247aa8c3d, ck-nc-cyberpunk-IL-000011: 935e6755554c, ck-neon-retrowave-IL: edafb9df7da1, ck-yoneyama-mai-IL-000014: 1b9305692a2e", Version: f2.0.1v1.10.1-1.10.1, Diffusion in Low Bits: Automatic (fp16 LoRA)
masterpiece, best quality,high quality, newest, highres,8K,HDR,absurdres, 1girl, solo, steampunk aesthetic, mechanical monocle, long trench coat, leather gloves, brass accessories, intricate clockwork rifle, aiming at viewer, wind-blown scarf, high boots, fingerless gloves, pocket watch, corset, brown and gold color scheme, industrial cityscape, smoke and gears, atmospheric lighting, depth of field, dynamic pose, dramatic composition, detailed background, foreshortening, detailed background, dynamic pose, dynamic composition,dutch angle, detailed backgroud,foreshortening,blurry edges <lora:iLLMythAn1m3Style:1> MythAn1m3
Negative prompt: worst quality, normal quality, anatomical nonsense, bad anatomy,interlocked fingers, extra fingers,watermark,simple background, loli,
Steps: 35, Sampler: DPM++ 2M SDE, Schedule type: Karras, CFG scale: 4, Seed: 3537159932, Size: 1072x1376, Model hash: c364bbdae9, Model: waiNSFWIllustrious_v110, Clip skip: 2, ADetailer model: face_yolov8n.pt, ADetailer confidence: 0.3, ADetailer dilate erode: 4, ADetailer mask blur: 4, ADetailer denoising strength: 0.4, ADetailer inpaint only masked: True, ADetailer inpaint padding: 32, ADetailer version: 24.8.0, Lora hashes: "iLLMythAn1m3Style: d3480076057b", Version: f2.0.1v1.10.1-previous-519-g44eb4ea8, Module 1: sdxl.vae
Masterpiece, best quality, high quality, newest, highres, 8K, HDR, absurdres, 1girl, solo, futuristic warrior, sleek exosuit with glowing energy cores, long braided hair flowing behind, gripping a high-tech bow with an energy arrow drawn, standing on a floating platform overlooking a massive space station, planets and nebulae in the distance, soft glow from distant stars, cinematic depth, foreshortening, dynamic pose, dramatic sci-fi lighting. Masterpiece, best quality, high quality, newest, highres, 8K, HDR, absurdres, 1girl, solo, futuristic warrior, sleek exosuit with glowing energy cores, long braided hair flowing behind, gripping a high-tech bow with an energy arrow drawn, standing on a floating platform overlooking a massive space station, planets and nebulae in the distance, soft glow from distant stars, cinematic depth, foreshortening, dynamic pose, dramatic sci-fi lighting.
Negative prompt: worst quality, normal quality, anatomical nonsense, bad anatomy,interlocked fingers, extra fingers,watermark,simple background, loli, Negative prompt: worst quality, normal quality, anatomical nonsense, bad anatomy,interlocked fingers, extra fingers,watermark,simple background, loli,
Steps: 20, Sampler: euler_ancestral_karras, CFG scale: 8.0, Seed: 691121152183439, Model: il\waiNSFWIllustrious_v110.safetensors, Model hash: c3688ee04c, Lora_0 Model name: iLLMythAn1m3Style.safetensors, Lora_0 Model hash: ba7a040786, Lora_0 Strength model: 1.0, Lora_0 Strength clip: 1.0, Hashes: {"model": "c3688ee04c", "lora:iLLMythAn1m3Style": "ba7a040786"} Steps: 20, Sampler: euler_ancestral_karras, CFG scale: 8.0, Seed: 691121152183439, Model: il\waiNSFWIllustrious_v110.safetensors, Model hash: c3688ee04c, Lora_0 Model name: iLLMythAn1m3Style.safetensors, Lora_0 Model hash: ba7a040786, Lora_0 Strength model: 1.0, Lora_0 Strength clip: 1.0, Hashes: {"model": "c3688ee04c", "lora:iLLMythAn1m3Style": "ba7a040786"}
Masterpiece, best quality, high quality, newest, highres, 8K, HDR, absurdres, 1boy, solo, gothic horror, pale vampire lord in regal, intricately detailed robes, crimson eyes glowing under the dim candlelight of a grand but decayed castle hall, holding a silver goblet filled with an unknown substance, a massive stained-glass window shattered behind him, cold mist rolling in, dramatic lighting, dark yet elegant aesthetic, foreshortening, cinematic perspective. Immerse yourself in the enchanting journey, where harmonious transmutation of Bauhaus art unites photographic precision and contemporary illustration, capturing an enthralling blend between vivid abstract nature and urban landscapes. Let your eyes be captivated by a kaleidoscope of rich, deep reds and yellows, entwined with intriguing shades that beckon a somber atmosphere. As your spirit ventures along this haunting path, witness the mysterious, high-angle perspective dominated by scattered clouds granting you a mesmerizing glimpse into the ever-transforming realm of metamorphosing environments. ,<lora:flux/fav/ck-charcoal-drawing-000014.safetensors:1.0:1.0>
Negative prompt: worst quality, normal quality, anatomical nonsense, bad anatomy,interlocked fingers, extra fingers,watermark,simple background, loli,
Steps: 20, Sampler: euler_ancestral_karras, CFG scale: 8.0, Seed: 290117945770094, Model: il\waiNSFWIllustrious_v110.safetensors, Model hash: c3688ee04c, Lora_0 Model name: iLLMythAn1m3Style.safetensors, Lora_0 Model hash: ba7a040786, Lora_0 Strength model: 0.6, Lora_0 Strength clip: 0.7000000000000001, Hashes: {"model": "c3688ee04c", "lora:iLLMythAn1m3Style": "ba7a040786"}
bo-exposure, An impressionistic oil painting in the style of J.M.W. Turner, depicting a ghostly ship sailing through a sea of swirling golden mist. The waves crash and dissolve into abstract, fiery strokes of orange and deep indigo, blurring the line between ocean and sky. The ship appears almost ethereal, as if drifting between worlds, lost in the ever-changing tides of memory and myth. The dynamic brushstrokes capture the relentless power of nature and the fleeting essence of time.
Negative prompt: Negative prompt:
Steps: 25, Sampler: DPM++ 2M, CFG scale: 3.5, Seed: 1024252061321625, Size: 832x1216, Clip skip: 1, Model hash: , Model: flux_dev, Hashes: {"model": ""}, Version: ComfyUI Steps: 20, Sampler: Euler, CFG scale: 3.5, Seed: 885491426361006, Size: 832x1216, Model hash: 4610115bb0, Model: flux_dev, Hashes: {"LORA:flux/fav/ck-charcoal-drawing-000014.safetensors": "34d36c17c1", "model": "4610115bb0"}, Version: ComfyUI

3
refs/meta_format.txt Normal file
View File

@@ -0,0 +1,3 @@
In this ethereal masterpiece, metallic sculptures juxtapose effortlessly against a subtle backdrop of misty neutral hues. Exquisite curvatures and geometric shapes converge harmoniously, creating an illuminating realm of polished metallic surfaces. Shimmering copper, gleaming silver, and lustrous gold hues dance in perfect balance, highlighting the intricate play of light and shadow cast upon these celestial forms. A halo of diffused radiance envelops each piece, enhancing their textured depths and metallic brilliance while allowing delicate details to emerge from obscurity. The composition conveys a serene yet mesmerizing atmosphere, as if suspended in a dreamlike limbo between reality and fantasy. The tantalizing interplay of colors within this transcendent realm creates a profound sense of depth and grandeur that invites the viewer into an enchanting voyage through abstract metallic beauty. This captivating artwork evokes emotions of boundless curiosity and reverence reminiscent of the timeless works by artists such as Giorgio de Chirico or Paul Klee, while asserting a unique, modern artistic sensibility. With every observation, a new nuance unfolds, as if a never-ending story waiting to be discovered through the lens of metallic artistry.
Negative prompt:
Steps: 25, Sampler: dpmpp_2m_sgm_uniform, Seed: 471889513588087, Model: Fluxmania V5P.safetensors, Model hash: 8ae0583b06, VAE: ae.sft, VAE hash: afc8e28272, Lora_0 Model name: ArtVador I.safetensors, Lora_0 Model hash: 08f7133a58, Lora_0 Strength model: 0.65, Lora_0 Strength clip: 0.65, Lora_1 Model name: Kaoru Yamada.safetensors, Lora_1 Model hash: d4893f7202, Lora_1 Strength model: 0.75, Lora_1 Strength clip: 0.75, Hashes: {"model": "8ae0583b06", "vae": "afc8e28272", "lora:ArtVador I": "08f7133a58", "lora:Kaoru Yamada": "d4893f7202"}

View File

@@ -1,13 +1,11 @@
{ {
"loras": "<lora:ck-neon-retrowave-IL-000012:0.8> <lora:aorunIllstrious:1> <lora:ck-shadow-circuit-IL-000012:0.78> <lora:MoriiMee_Gothic_Niji_Style_Illustrious_r1:0.45> <lora:ck-nc-cyberpunk-IL-000011:0.4>", "loras": "<lora:ck-neon-retrowave-IL-000012:0.8> <lora:aorunIllstrious:1> <lora:ck-shadow-circuit-IL-000012:0.78> <lora:MoriiMee_Gothic_Niji_Style_Illustrious_r1:0.45> <lora:ck-nc-cyberpunk-IL-000011:0.4>",
"gen_params": { "prompt": "in the style of ck-rw, aorun, scales, makeup, bare shoulders, pointy ears, dress, claws, in the style of cksc, artist:moriimee, in the style of cknc, masterpiece, best quality, good quality, very aesthetic, absurdres, newest, 8K, depth of field, focused subject, close up, stylized, in gold and neon shades, wabi sabi, 1girl, rainbow angel wings, looking at viewer, dynamic angle, from below, from side, relaxing",
"prompt": "in the style of ck-rw, aorun, scales, makeup, bare shoulders, pointy ears, dress, claws, in the style of cksc, artist:moriimee, in the style of cknc, masterpiece, best quality, good quality, very aesthetic, absurdres, newest, 8K, depth of field, focused subject, close up, stylized, in gold and neon shades, wabi sabi, 1girl, rainbow angel wings, looking at viewer, dynamic angle, from below, from side, relaxing", "negative_prompt": "bad quality, worst quality, worst detail, sketch ,signature, watermark, patreon logo, nsfw",
"negative_prompt": "bad quality, worst quality, worst detail, sketch ,signature, watermark, patreon logo, nsfw", "steps": "20",
"steps": "20", "sampler": "euler_ancestral",
"sampler": "euler_ancestral", "cfg_scale": "8",
"cfg_scale": "8", "seed": "241",
"seed": "241", "size": "832x1216",
"size": "832x1216", "clip_skip": "2"
"clip_skip": "2"
}
} }

View File

@@ -1,75 +1,12 @@
{ {
"3": {
"inputs": {
"seed": 241,
"steps": 20,
"cfg": 8,
"sampler_name": "euler_ancestral",
"scheduler": "karras",
"denoise": 1,
"model": [
"56",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "il\\waiNSFWIllustrious_v110.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 832,
"height": 1216,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": { "6": {
"inputs": { "inputs": {
"text": [ "text": [
"22", "301",
0 0
], ],
"clip": [ "clip": [
"56", "299",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "bad quality, worst quality, worst detail, sketch ,signature, watermark, patreon logo, nsfw",
"clip": [
"56",
1 1
] ]
}, },
@@ -81,12 +18,12 @@
"8": { "8": {
"inputs": { "inputs": {
"samples": [ "samples": [
"3", "13",
0 1
], ],
"vae": [ "vae": [
"4", "10",
2 0
] ]
}, },
"class_type": "VAEDecode", "class_type": "VAEDecode",
@@ -94,7 +31,230 @@
"title": "VAE Decode" "title": "VAE Decode"
} }
}, },
"14": { "10": {
"inputs": {
"vae_name": "flux1\\ae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "Load VAE"
}
},
"11": {
"inputs": {
"clip_name1": "t5xxl_fp8_e4m3fn.safetensors",
"clip_name2": "ViT-L-14-TEXT-detail-improved-hiT-GmP-TE-only-HF.safetensors",
"type": "flux",
"device": "default"
},
"class_type": "DualCLIPLoader",
"_meta": {
"title": "DualCLIPLoader"
}
},
"13": {
"inputs": {
"noise": [
"147",
0
],
"guider": [
"22",
0
],
"sampler": [
"16",
0
],
"sigmas": [
"17",
0
],
"latent_image": [
"48",
0
]
},
"class_type": "SamplerCustomAdvanced",
"_meta": {
"title": "SamplerCustomAdvanced"
}
},
"16": {
"inputs": {
"sampler_name": "dpmpp_2m"
},
"class_type": "KSamplerSelect",
"_meta": {
"title": "KSamplerSelect"
}
},
"17": {
"inputs": {
"scheduler": "beta",
"steps": [
"246",
0
],
"denoise": 1,
"model": [
"28",
0
]
},
"class_type": "BasicScheduler",
"_meta": {
"title": "BasicScheduler"
}
},
"22": {
"inputs": {
"model": [
"28",
0
],
"conditioning": [
"29",
0
]
},
"class_type": "BasicGuider",
"_meta": {
"title": "BasicGuider"
}
},
"28": {
"inputs": {
"max_shift": 1.1500000000000001,
"base_shift": 0.5,
"width": [
"48",
1
],
"height": [
"48",
2
],
"model": [
"299",
0
]
},
"class_type": "ModelSamplingFlux",
"_meta": {
"title": "ModelSamplingFlux"
}
},
"29": {
"inputs": {
"guidance": 3.5,
"conditioning": [
"6",
0
]
},
"class_type": "FluxGuidance",
"_meta": {
"title": "FluxGuidance"
}
},
"48": {
"inputs": {
"resolution": "832x1216 (0.68)",
"batch_size": 1,
"width_override": 0,
"height_override": 0
},
"class_type": "SDXLEmptyLatentSizePicker+",
"_meta": {
"title": "🔧 SDXL Empty Latent Size Picker"
}
},
"65": {
"inputs": {
"unet_name": "flux\\flux1-dev-fp8-e4m3fn.safetensors",
"weight_dtype": "fp8_e4m3fn_fast"
},
"class_type": "UNETLoader",
"_meta": {
"title": "Load Diffusion Model"
}
},
"147": {
"inputs": {
"noise_seed": 651532572596956
},
"class_type": "RandomNoise",
"_meta": {
"title": "RandomNoise"
}
},
"148": {
"inputs": {
"wildcard_text": "__some-prompts__",
"populated_text": "A surreal digital artwork showcases a forward-thinking inventor captivated by his intricate mechanical creation through a large magnifying glass. Viewed from an unconventional perspective, the scene reveals an eccentric assembly of gears, springs, and brass instruments within his workshop. Soft, ethereal light radiates from the invention, casting enigmatic shadows on the walls as time appears to bend around its metallic form, invoking a sense of curiosity, wonder, and exhilaration in discovery.",
"mode": "fixed",
"seed": 553084268162351,
"Select to add Wildcard": "Select the Wildcard to add to the text"
},
"class_type": "ImpactWildcardProcessor",
"_meta": {
"title": "ImpactWildcardProcessor"
}
},
"151": {
"inputs": {
"text": "A hyper-realistic close-up portrait of a young woman with shoulder-length black hair styled in edgy, futuristic layers, adorned with glowing tips. She wears mecha eyewear with a neon green visor that transitions into iridescent shades of teal and gold. The frame is sleek, with angular edges and fine mechanical detailing. Her expression is fierce and confident, with flawless skin highlighted by the neon reflections. She wears a high-tech bodysuit with integrated LED lines and metallic panels. The background depicts a hazy rendition of The Great Wave off Kanagawa by Hokusai, its powerful waves blending seamlessly with the neon tones, amplifying her intense, defiant aura."
},
"class_type": "Text Multiline",
"_meta": {
"title": "Text Multiline"
}
},
"191": {
"inputs": {
"text": "A cinematic, oil painting masterpiece captures the essence of impressionistic surrealism, inspired by Claude Monet. A mysterious woman in a flowing crimson dress stands at the edge of a tranquil lake, where lily pads shimmer under an ethereal, golden twilight. The waters surface reflects a dreamlike sky, its swirling hues of violet and sapphire melting together like liquid light. The thick, expressive brushstrokes lend depth to the scene, evoking a sense of nostalgia and quiet longing, as if the world itself is caught between reality and a fleeting dream. \nA mesmerizing oil painting masterpiece inspired by Salvador Dalí, blending surrealism with post-impressionist texture. A lone violinist plays atop a melting clock tower, his form distorted by the passage of time. The sky is a cascade of swirling, liquid oranges and deep blues, where floating staircases spiral endlessly into the horizon. The impasto technique gives depth and movement to the surreal elements, making time itself feel fluid, as if the world is dissolving into a dream. \nA stunning impressionistic oil painting evokes the spirit of Edvard Munch, capturing a solitary figure standing on a rain-soaked street, illuminated by the glow of flickering gas lamps. The swirling, chaotic strokes of deep blues and fiery reds reflect the turbulence of emotion, while the blurred reflections in the wet cobblestone suggest a merging of past and present. The faceless figure, draped in a dark overcoat, seems lost in thought, embodying the ephemeral nature of memory and time. \nA breathtaking oil painting masterpiece, inspired by Gustav Klimt, presents a celestial ballroom where faceless dancers swirl in an eternal waltz beneath a gilded, star-speckled sky. Their golden garments shimmer with intricate patterns, blending into the opulent mosaic floor that seems to stretch into infinity. The dreamlike composition, rich in warm amber and deep sapphire hues, captures an otherworldly elegance, as if the dancers are suspended in a moment that transcends time. \nA visionary oil painting inspired by Marc Chagall depicts a dreamlike cityscape where gravity ceases to exist. A couple floats above a crimson-tinted town, their forms dissolving into the swirling strokes of a vast, cerulean sky. The buildings below twist and bend in rhythmic motion, their windows glowing like tiny stars. The thick, textured brushwork conveys a sense of weightlessness and wonder, as if love itself has defied the laws of the universe. \nAn impressionistic oil painting in the style of J.M.W. Turner, depicting a ghostly ship sailing through a sea of swirling golden mist. The waves crash and dissolve into abstract, fiery strokes of orange and deep indigo, blurring the line between ocean and sky. The ship appears almost ethereal, as if drifting between worlds, lost in the ever-changing tides of memory and myth. The dynamic brushstrokes capture the relentless power of nature and the fleeting essence of time. \nA captivating oil painting masterpiece, infused with surrealist impressionism, portrays a grand library where books float midair, their pages unraveling into ribbons of light. The towering shelves twist into the heavens, vanishing into an infinite, starry void. A lone scholar, illuminated by the glow of a suspended lantern, reaches for a book that seems to pulse with life. The scene pulses with mystery, where the impasto textures bring depth to the interplay between knowledge and dreams. \nA luminous impressionistic oil painting captures the melancholic beauty of an abandoned carnival, its faded carousel horses frozen mid-gallop beneath a sky of swirling lavender and gold. The wind carries fragments of forgotten laughter through the empty fairground, where scattered ticket stubs and crumbling banners whisper tales of joy long past. The thick, textured brushstrokes blend nostalgia with an eerie dreamlike quality, as if the carnival exists only in the echoes of memory. \nA surreal oil painting in the spirit of René Magritte, featuring a towering lighthouse that emits not light, but cascading waterfalls from its peak. The swirling sky, painted in deep midnight blues, is punctuated by glowing, crescent moons that defy gravity. A lone figure stands at the waters edge, gazing up in quiet contemplation, as if caught between wonder and the unknown. The paintings rich textures and luminous colors create an enigmatic, dreamlike landscape. \nA striking impressionistic oil painting, reminiscent of Van Gogh, portrays a lone traveler on a winding cobblestone path, their silhouette bathed in the golden glow of lantern-lit cherry blossoms. The petals swirl through the night air like glowing embers, blending with the deep, rhythmic strokes of a star-filled indigo sky. The scene captures a feeling of wistful solitude, as if the traveler is walking not only through the city, but through the fleeting nature of time itself."
},
"class_type": "Text Multiline",
"_meta": {
"title": "Text Multiline"
}
},
"203": {
"inputs": {
"string1": [
"289",
0
],
"string2": [
"293",
0
],
"delimiter": ", "
},
"class_type": "JoinStrings",
"_meta": {
"title": "Join Strings"
}
},
"208": {
"inputs": {
"file_path": "",
"dictionary_name": "[filename]",
"label": "TextBatch",
"mode": "automatic",
"index": 0,
"multiline_text": [
"191",
0
]
},
"class_type": "Text Load Line From File",
"_meta": {
"title": "Text Load Line From File"
}
},
"226": {
"inputs": { "inputs": {
"images": [ "images": [
"8", "8",
@@ -106,60 +266,21 @@
"title": "Preview Image" "title": "Preview Image"
} }
}, },
"19": { "246": {
"inputs": { "inputs": {
"stop_at_clip_layer": -2, "value": 25
"clip": [
"4",
1
]
}, },
"class_type": "CLIPSetLastLayer", "class_type": "INTConstant",
"_meta": { "_meta": {
"title": "CLIP Set Last Layer" "title": "Steps"
} }
}, },
"21": { "289": {
"inputs": {
"string": "masterpiece, best quality, good quality, very aesthetic, absurdres, newest, 8K, depth of field, focused subject, close up, stylized, in gold and neon shades, wabi sabi, 1girl, rainbow angel wings, looking at viewer, dynamic angle, from below, from side, relaxing",
"strip_newlines": false
},
"class_type": "StringConstantMultiline",
"_meta": {
"title": "positive"
}
},
"22": {
"inputs": {
"string1": [
"55",
0
],
"string2": [
"21",
0
],
"delimiter": ", "
},
"class_type": "JoinStrings",
"_meta": {
"title": "Join Strings"
}
},
"55": {
"inputs": { "inputs": {
"group_mode": true, "group_mode": true,
"toggle_trigger_words": [ "toggle_trigger_words": [
{ {
"text": "in the style of ck-rw", "text": "bo-exposure",
"active": true
},
{
"text": "in the style of cksc",
"active": true
},
{
"text": "artist:moriimee",
"active": true "active": true
}, },
{ {
@@ -173,9 +294,9 @@
"_isDummy": true "_isDummy": true
} }
], ],
"orinalMessage": "in the style of ck-rw,, in the style of cksc,, artist:moriimee", "orinalMessage": "bo-exposure",
"trigger_words": [ "trigger_words": [
"56", "299",
2 2
] ]
}, },
@@ -184,25 +305,58 @@
"title": "TriggerWord Toggle (LoraManager)" "title": "TriggerWord Toggle (LoraManager)"
} }
}, },
"56": { "293": {
"inputs": { "inputs": {
"text": "<lora:ck-shadow-circuit-IL-000012:0.78> <lora:MoriiMee_Gothic_Niji_Style_Illustrious_r1:0.45> <lora:ck-nc-cyberpunk-IL-000011:0.4>", "input": 1,
"text1": [
"208",
0
],
"text2": [
"151",
0
]
},
"class_type": "easy textSwitch",
"_meta": {
"title": "Text Switch"
}
},
"297": {
"inputs": {
"text": ""
},
"class_type": "Lora Stacker (LoraManager)",
"_meta": {
"title": "Lora Stacker (LoraManager)"
}
},
"298": {
"inputs": {
"anything": [
"297",
0
]
},
"class_type": "easy showAnything",
"_meta": {
"title": "Show Any"
}
},
"299": {
"inputs": {
"text": "<lora:boFLUX Double Exposure Magic v2:0.8> <lora:FluxDFaeTasticDetails:0.65>",
"loras": [ "loras": [
{ {
"name": "ck-shadow-circuit-IL-000012", "name": "boFLUX Double Exposure Magic v2",
"strength": 0.78, "strength": 0.8,
"active": true "active": true
}, },
{ {
"name": "MoriiMee_Gothic_Niji_Style_Illustrious_r1", "name": "FluxDFaeTasticDetails",
"strength": 0.45, "strength": 0.65,
"active": true "active": true
}, },
{
"name": "ck-nc-cyberpunk-IL-000011",
"strength": 0.4,
"active": false
},
{ {
"name": "__dummy_item1__", "name": "__dummy_item1__",
"strength": 0, "strength": 0,
@@ -217,15 +371,15 @@
} }
], ],
"model": [ "model": [
"4", "65",
0 0
], ],
"clip": [ "clip": [
"4", "11",
1 0
], ],
"lora_stack": [ "lora_stack": [
"57", "297",
0 0
] ]
}, },
@@ -234,64 +388,14 @@
"title": "Lora Loader (LoraManager)" "title": "Lora Loader (LoraManager)"
} }
}, },
"57": { "301": {
"inputs": { "inputs": {
"text": "<lora:aorunIllstrious:1>", "string": "A hyper-realistic close-up portrait of a young woman with shoulder-length black hair styled in edgy, futuristic layers, adorned with glowing tips. She wears mecha eyewear with a neon green visor that transitions into iridescent shades of teal and gold. The frame is sleek, with angular edges and fine mechanical detailing. Her expression is fierce and confident, with flawless skin highlighted by the neon reflections. She wears a high-tech bodysuit with integrated LED lines and metallic panels. The background depicts a hazy rendition of The Great Wave off Kanagawa by Hokusai, its powerful waves blending seamlessly with the neon tones, amplifying her intense, defiant aura.",
"loras": [ "strip_newlines": true
{
"name": "aorunIllstrious",
"strength": "0.90",
"active": false
},
{
"name": "__dummy_item1__",
"strength": 0,
"active": false,
"_isDummy": true
},
{
"name": "__dummy_item2__",
"strength": 0,
"active": false,
"_isDummy": true
}
],
"lora_stack": [
"59",
0
]
}, },
"class_type": "Lora Stacker (LoraManager)", "class_type": "StringConstantMultiline",
"_meta": { "_meta": {
"title": "Lora Stacker (LoraManager)" "title": "String Constant Multiline"
}
},
"59": {
"inputs": {
"text": "<lora:ck-neon-retrowave-IL-000012:0.8>",
"loras": [
{
"name": "ck-neon-retrowave-IL-000012",
"strength": 0.8,
"active": true
},
{
"name": "__dummy_item1__",
"strength": 0,
"active": false,
"_isDummy": true
},
{
"name": "__dummy_item2__",
"strength": 0,
"active": false,
"_isDummy": true
}
]
},
"class_type": "Lora Stacker (LoraManager)",
"_meta": {
"title": "Lora Stacker (LoraManager)"
} }
} }
} }

View File

@@ -5,4 +5,5 @@ watchdog
beautifulsoup4 beautifulsoup4
piexif piexif
Pillow Pillow
olefile
requests requests

View File

@@ -1,25 +0,0 @@
import json
from py.workflow.parser import WorkflowParser
# Load workflow data
with open('refs/prompt.json', 'r') as f:
workflow_data = json.load(f)
# Parse workflow
parser = WorkflowParser()
try:
# Parse the workflow
result = parser.parse_workflow(workflow_data)
print("Parsing successful!")
# Print each component separately
print("\nGeneration Parameters:")
for k, v in result.get("gen_params", {}).items():
print(f" {k}: {v}")
print("\nLoRAs:")
print(result.get("loras", ""))
except Exception as e:
print(f"Error parsing workflow: {e}")
import traceback
traceback.print_exc()

View File

@@ -23,6 +23,7 @@
cursor: pointer; /* Added from recipe-card */ cursor: pointer; /* Added from recipe-card */
display: flex; /* Added from recipe-card */ display: flex; /* Added from recipe-card */
flex-direction: column; /* Added from recipe-card */ flex-direction: column; /* Added from recipe-card */
overflow: hidden; /* Add overflow hidden to contain children */
} }
.lora-card:hover { .lora-card:hover {
@@ -50,9 +51,11 @@
.card-preview { .card-preview {
position: relative; position: relative;
width: 100%; width: 100%;
height: 100%; height: 100%; /* This should work with aspect-ratio on parent */
border-radius: var(--border-radius-base); border-radius: var(--border-radius-base);
overflow: hidden; overflow: hidden;
flex-shrink: 0; /* Prevent shrinking */
min-height: 0; /* Fix for potential flexbox sizing issue in Firefox */
} }
.card-preview img, .card-preview img,

View File

@@ -0,0 +1,84 @@
/* Filter indicator styles */
.control-group .filter-active {
display: flex;
align-items: center;
gap: 6px;
background: var(--lora-accent);
color: white;
border-radius: var(--border-radius-xs);
padding: 4px 10px;
transition: all 0.2s ease;
border: 1px solid var(--lora-accent);
cursor: pointer;
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
font-size: 0.85em;
}
.control-group .filter-active:hover {
opacity: 0.92;
transform: translateY(-1px);
box-shadow: 0 3px 5px rgba(0, 0, 0, 0.15);
}
.control-group .filter-active:active {
transform: translateY(0);
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
}
.control-group .filter-active i.fa-filter {
font-size: 0.9em;
margin-right: 2px;
opacity: 0.9;
}
.control-group .filter-active i.clear-filter {
transition: transform 0.2s ease, background-color 0.2s ease;
cursor: pointer;
margin-left: 4px;
border-radius: 50%;
font-size: 0.85em;
width: 16px;
height: 16px;
display: flex;
align-items: center;
justify-content: center;
}
.control-group .filter-active i.clear-filter:hover {
transform: scale(1.2);
background-color: rgba(255, 255, 255, 0.2);
}
.control-group .filter-active .lora-name {
font-weight: 500;
max-width: 150px;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
/* Animation for filter indicator */
@keyframes filterPulse {
0% { transform: scale(1); box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); }
50% { transform: scale(1.03); box-shadow: 0 3px 8px rgba(0, 0, 0, 0.15); }
100% { transform: scale(1); box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); }
}
.filter-active.animate {
animation: filterPulse 0.6s ease;
}
/* Make responsive */
@media (max-width: 576px) {
.control-group .filter-active {
padding: 6px 10px;
}
.control-group .filter-active .lora-name {
max-width: 100px;
}
.control-group .filter-active:hover {
transform: none; /* Disable hover effects on mobile */
}
}

View File

@@ -0,0 +1,359 @@
/* Initialization Component Styles */
.initialization-container {
width: 100%;
height: 100%;
padding: var(--space-3);
background: var(--lora-surface);
animation: fadeIn 0.3s ease-in-out;
display: flex;
align-items: center;
justify-content: center;
}
.initialization-content {
max-width: 800px;
width: 100%;
}
/* Override loading.css width for initialization component */
.initialization-container .loading-content {
width: 100%;
max-width: 100%;
background: transparent;
backdrop-filter: none;
border: none;
padding: 0;
}
.initialization-header {
text-align: center;
margin-bottom: var(--space-3);
}
.initialization-header h2 {
font-size: 1.8rem;
margin-bottom: var(--space-1);
color: var(--text-color);
}
.init-subtitle {
color: var(--text-color);
opacity: 0.8;
font-size: 1rem;
}
/* Progress Bar Styles specific to initialization */
.initialization-progress {
margin-bottom: var(--space-3);
}
/* Renamed container class */
.init-progress-container {
width: 100%; /* Use full width within its container */
height: 8px; /* Match height from previous .progress-bar-container */
background-color: var(--lora-border); /* Consistent background */
border-radius: 4px;
overflow: hidden;
margin: 0 auto var(--space-1); /* Center horizontally, add bottom margin */
}
/* Renamed progress bar class */
.init-progress-bar {
height: 100%;
/* Use a gradient consistent with the theme accent */
background: linear-gradient(90deg, var(--lora-accent) 0%, color-mix(in oklch, var(--lora-accent) 80%, transparent) 100%);
border-radius: 4px; /* Match container radius */
transition: width 0.3s ease;
width: 0%; /* Start at 0% */
}
/* Remove the old .progress-bar rule specific to initialization to avoid conflicts */
/* .progress-bar { ... } */
/* Progress Details */
.progress-details {
display: flex;
justify-content: space-between;
font-size: 0.9rem;
color: var(--text-color);
margin-top: var(--space-1);
padding: 0 2px;
}
#remainingTime {
font-style: italic;
color: var(--text-color);
opacity: 0.8;
}
/* Stages Styles */
.initialization-stages {
margin-bottom: var(--space-3);
}
.stage-item {
display: flex;
align-items: flex-start;
padding: var(--space-2);
border-radius: var(--border-radius-xs);
margin-bottom: var(--space-1);
transition: background-color 0.2s ease;
border: 1px solid transparent;
}
.stage-item.active {
background-color: rgba(var(--lora-accent), 0.1);
border-color: var(--lora-accent);
}
.stage-item.completed {
background-color: rgba(0, 150, 0, 0.05);
border-color: rgba(0, 150, 0, 0.2);
}
.stage-icon {
display: flex;
align-items: center;
justify-content: center;
width: 40px;
height: 40px;
background: var(--lora-border);
border-radius: 50%;
margin-right: var(--space-2);
}
.stage-item.active .stage-icon {
background: var(--lora-accent);
color: white;
}
.stage-item.completed .stage-icon {
background: rgb(0, 150, 0);
color: white;
}
.stage-content {
flex: 1;
}
.stage-content h4 {
margin: 0 0 5px 0;
font-size: 1rem;
color: var(--text-color);
}
.stage-details {
font-size: 0.85rem;
color: var(--text-color);
opacity: 0.8;
}
.stage-status {
display: flex;
align-items: center;
justify-content: center;
width: 24px;
height: 24px;
}
.stage-status.pending {
color: var(--text-color);
opacity: 0.5;
}
.stage-status.in-progress {
color: var(--lora-accent);
}
.stage-status.completed {
color: rgb(0, 150, 0);
}
/* Tips Container */
.tips-container {
margin-top: var(--space-3);
background: rgba(var(--lora-accent), 0.05);
border-radius: var(--border-radius-base);
padding: var(--space-2);
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
}
.tips-header {
display: flex;
align-items: center;
margin-bottom: var(--space-2);
padding-bottom: var(--space-1);
border-bottom: 1px solid var(--lora-border);
}
.tips-header i {
margin-right: 10px;
color: var(--lora-accent);
font-size: 1.2rem;
}
.tips-header h3 {
font-size: 1.2rem;
margin: 0;
color: var(--text-color);
}
/* Tip Carousel with Images */
.tips-content {
position: relative;
}
.tip-carousel {
position: relative;
height: 160px;
overflow: hidden;
}
.tip-item {
position: absolute;
width: 100%;
height: 100%;
display: flex;
opacity: 0;
transition: opacity 0.5s ease;
padding: 0;
border-radius: var(--border-radius-sm);
overflow: hidden;
}
.tip-item.active {
opacity: 1;
}
.tip-image {
width: 40%;
overflow: hidden;
display: flex;
align-items: center;
justify-content: center;
background-color: var(--lora-border);
}
.tip-image img {
width: 100%;
height: 100%;
object-fit: cover;
}
.tip-text {
width: 60%;
padding: var(--space-2);
display: flex;
flex-direction: column;
justify-content: center;
}
.tip-text h4 {
margin: 0 0 var(--space-1) 0;
font-size: 1.1rem;
color: var(--text-color);
}
.tip-text p {
margin: 0;
line-height: 1.5;
font-size: 0.9rem;
color: var(--text-color);
}
.tip-navigation {
display: flex;
justify-content: center;
margin-top: var(--space-2);
}
.tip-dot {
width: 10px;
height: 10px;
border-radius: 50%;
background-color: var(--lora-border);
margin: 0 5px;
cursor: pointer;
transition: background-color 0.2s ease, transform 0.2s ease;
}
.tip-dot:hover {
transform: scale(1.2);
}
.tip-dot.active {
background-color: var(--lora-accent);
}
/* Animation */
@keyframes fadeIn {
from {
opacity: 0;
transform: translateY(10px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
/* Different stage status animations */
@keyframes pulse {
0% {
transform: scale(1);
}
50% {
transform: scale(1.2);
}
100% {
transform: scale(1);
}
}
.stage-item.active .stage-icon i {
animation: pulse 1s infinite;
}
/* Responsive Adjustments */
@media (max-width: 768px) {
.initialization-container {
padding: var(--space-2);
}
.stage-item {
padding: var(--space-1);
}
.stage-icon {
width: 32px;
height: 32px;
min-width: 32px;
}
.tip-item {
flex-direction: column;
height: 220px;
}
.tip-image, .tip-text {
width: 100%;
}
.tip-image {
height: 120px;
}
.tip-carousel {
height: 220px;
}
}
@media (prefers-reduced-motion: reduce) {
.initialization-container,
.tip-item,
.tip-dot {
transition: none;
animation: none;
}
}

View File

@@ -99,6 +99,7 @@
width: 100%; width: 100%;
background: var(--lora-surface); background: var(--lora-surface);
margin-bottom: var(--space-2); margin-bottom: var(--space-2);
overflow: hidden; /* Ensure metadata panel is contained */
} }
.media-wrapper:last-child { .media-wrapper:last-child {
@@ -542,25 +543,53 @@
display: flex; display: flex;
align-items: center; align-items: center;
gap: 8px; gap: 8px;
cursor: pointer;
padding: 4px; padding: 4px;
border-radius: var(--border-radius-xs); border-radius: var(--border-radius-xs);
transition: background-color 0.2s; transition: background-color 0.2s;
position: relative;
} }
.file-name-wrapper:hover { .file-name-wrapper:hover {
background: oklch(var(--lora-accent) / 0.1); background: oklch(var(--lora-accent) / 0.1);
} }
.file-name-wrapper i { .file-name-content {
color: var(--text-color); padding: 2px 4px;
opacity: 0.5; border-radius: var(--border-radius-xs);
transition: opacity 0.2s; border: 1px solid transparent;
flex: 1;
} }
.file-name-wrapper:hover i { .file-name-wrapper.editing .file-name-content {
opacity: 1; border: 1px solid var(--lora-accent);
color: var(--lora-accent); background: var(--bg-color);
outline: none;
}
.edit-file-name-btn {
background: transparent;
border: none;
color: var(--text-color);
opacity: 0;
cursor: pointer;
padding: 2px 5px;
border-radius: var(--border-radius-xs);
transition: all 0.2s ease;
margin-left: var(--space-1);
}
.edit-file-name-btn.visible,
.file-name-wrapper:hover .edit-file-name-btn {
opacity: 0.5;
}
.edit-file-name-btn:hover {
opacity: 0.8 !important;
background: rgba(0, 0, 0, 0.05);
}
[data-theme="dark"] .edit-file-name-btn:hover {
background: rgba(255, 255, 255, 0.05);
} }
/* Base Model and Size combined styles */ /* Base Model and Size combined styles */
@@ -573,6 +602,59 @@
flex: 2; /* 分配更多空间给base model */ flex: 2; /* 分配更多空间给base model */
} }
/* Base model display and editing styles */
.base-model-display {
display: flex;
align-items: center;
position: relative;
}
.base-model-content {
padding: 2px 4px;
border-radius: var(--border-radius-xs);
border: 1px solid transparent;
color: var(--text-color);
flex: 1;
}
.edit-base-model-btn {
background: transparent;
border: none;
color: var(--text-color);
opacity: 0;
cursor: pointer;
padding: 2px 5px;
border-radius: var(--border-radius-xs);
transition: all 0.2s ease;
margin-left: var(--space-1);
}
.edit-base-model-btn.visible,
.base-model-display:hover .edit-base-model-btn {
opacity: 0.5;
}
.edit-base-model-btn:hover {
opacity: 0.8 !important;
background: rgba(0, 0, 0, 0.05);
}
[data-theme="dark"] .edit-base-model-btn:hover {
background: rgba(255, 255, 255, 0.05);
}
.base-model-selector {
width: 100%;
padding: 3px 5px;
background: var(--bg-color);
border: 1px solid var(--lora-accent);
border-radius: var(--border-radius-xs);
color: var(--text-color);
font-size: 0.9em;
outline: none;
margin-right: var(--space-1);
}
.size-wrapper { .size-wrapper {
flex: 1; flex: 1;
border-left: 1px solid var(--lora-border); border-left: 1px solid var(--lora-border);
@@ -781,7 +863,7 @@
} }
.model-description-content blockquote { .model-description-content blockquote {
border-left: 3px solid var(--lora-accent); border-left: 3px solid var (--lora-accent);
padding-left: 1em; padding-left: 1em;
margin-left: 0; margin-left: 0;
margin-right: 0; margin-right: 0;
@@ -1030,4 +1112,215 @@
/* Make sure media wrapper maintains position: relative for absolute positioning of children */ /* Make sure media wrapper maintains position: relative for absolute positioning of children */
.carousel .media-wrapper { .carousel .media-wrapper {
position: relative; position: relative;
}
/* Image Metadata Panel Styles */
.image-metadata-panel {
position: absolute;
bottom: 0;
left: 0;
right: 0;
background: var(--bg-color);
border-top: 1px solid var(--border-color);
padding: var(--space-2);
transform: translateY(100%);
transition: transform 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275), opacity 0.25s ease;
z-index: 5;
max-height: 50%; /* Reduced to take less space */
overflow-y: auto;
box-shadow: 0 -2px 8px rgba(0, 0, 0, 0.1);
opacity: 0;
pointer-events: none;
}
/* Show metadata panel only on hover */
.media-wrapper:hover .image-metadata-panel {
transform: translateY(0);
opacity: 0.98;
pointer-events: auto;
}
/* Adjust to dark theme */
[data-theme="dark"] .image-metadata-panel {
background: var(--card-bg);
box-shadow: 0 -2px 8px rgba(0, 0, 0, 0.3);
}
.metadata-content {
display: flex;
flex-direction: column;
gap: 10px;
}
/* Styling for parameters tags */
.params-tags {
display: flex;
flex-wrap: wrap;
gap: 6px;
margin-bottom: var(--space-1);
padding-bottom: var(--space-1);
border-bottom: 1px solid var(--lora-border);
}
.param-tag {
display: inline-flex;
align-items: center;
background: var(--lora-surface);
border: 1px solid var(--lora-border);
border-radius: var(--border-radius-xs);
padding: 2px 6px;
font-size: 0.8em;
line-height: 1.2;
white-space: nowrap;
}
.param-tag .param-name {
font-weight: 600;
color: var(--text-color);
margin-right: 4px;
opacity: 0.8;
}
.param-tag .param-value {
color: var(--lora-accent);
}
/* Special styling for prompt row */
.metadata-row.prompt-row {
flex-direction: column;
padding-top: 0;
}
.metadata-row.prompt-row + .metadata-row.prompt-row {
margin-top: var(--space-2);
}
.metadata-label {
font-weight: 600;
color: var(--text-color);
opacity: 0.8;
font-size: 0.85em;
display: block;
margin-bottom: 4px;
}
.metadata-prompt-wrapper {
position: relative;
background: var(--lora-surface);
border: 1px solid var(--lora-border);
border-radius: var(--border-radius-xs);
padding: 6px 30px 6px 8px;
margin-top: 2px;
max-height: 80px; /* Reduced from 120px */
overflow-y: auto;
word-break: break-word;
width: 100%;
box-sizing: border-box;
}
.metadata-prompt {
color: var(--text-color);
font-family: monospace;
font-size: 0.85em;
white-space: pre-wrap;
}
.copy-prompt-btn {
position: absolute;
top: 6px;
right: 6px;
background: transparent;
border: none;
color: var(--text-color);
opacity: 0.6;
cursor: pointer;
padding: 3px;
transition: all 0.2s ease;
}
.copy-prompt-btn:hover {
opacity: 1;
color: var(--lora-accent);
}
/* Scrollbar styling for metadata panel */
.image-metadata-panel::-webkit-scrollbar {
width: 6px;
}
.image-metadata-panel::-webkit-scrollbar-track {
background: transparent;
}
.image-metadata-panel::-webkit-scrollbar-thumb {
background-color: var(--border-color);
border-radius: 3px;
}
/* For Firefox */
.image-metadata-panel {
scrollbar-width: thin;
scrollbar-color: var(--border-color) transparent;
}
/* No metadata message styling */
.no-metadata-message {
display: flex;
align-items: center;
justify-content: center;
padding: var(--space-2);
color: var(--text-color);
opacity: 0.7;
text-align: center;
font-style: italic;
gap: 8px;
}
.no-metadata-message i {
font-size: 1.1em;
color: var(--lora-accent);
opacity: 0.8;
}
.view-all-btn {
display: flex;
align-items: center;
gap: 5px;
padding: 6px 12px;
background-color: var(--lora-accent);
color: var(--lora-text);
border: none;
border-radius: var(--border-radius-sm);
cursor: pointer;
transition: background-color 0.2s;
font-size: 13px;
}
.view-all-btn:hover {
opacity: 0.9;
}
/* Loading, error and empty states */
.recipes-loading,
.recipes-error,
.recipes-empty {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 40px;
text-align: center;
min-height: 200px;
}
.recipes-loading i,
.recipes-error i,
.recipes-empty i {
font-size: 32px;
margin-bottom: 15px;
color: var(--lora-accent);
}
.recipes-error i {
color: var(--lora-error);
} }

View File

@@ -196,7 +196,7 @@ body.modal-open {
} }
.settings-modal { .settings-modal {
max-width: 500px; max-width: 650px; /* Further increased from 600px for more space */
} }
/* Settings Links */ /* Settings Links */
@@ -266,14 +266,22 @@ body.modal-open {
} }
} }
/* API key input specific styles */
.api-key-input { .api-key-input {
width: 100%; /* Take full width of parent */
position: relative; position: relative;
display: flex; display: flex;
align-items: center; align-items: center;
} }
.api-key-input input { .api-key-input input {
padding-right: 40px; width: 100%;
padding: 6px 40px 6px 10px; /* Add left padding */
height: 32px;
border-radius: var(--border-radius-xs);
border: 1px solid var(--border-color);
background-color: var(--lora-surface);
color: var(--text-color);
} }
.api-key-input .toggle-visibility { .api-key-input .toggle-visibility {
@@ -294,8 +302,10 @@ body.modal-open {
.input-help { .input-help {
font-size: 0.85em; font-size: 0.85em;
color: var(--text-color); color: var(--text-color);
opacity: 0.8; opacity: 0.7;
margin-top: 4px; margin-top: 8px; /* Space between control and help */
line-height: 1.4;
width: 100%; /* Full width */
} }
/* 统一各个 section 的样式 */ /* 统一各个 section 的样式 */
@@ -341,9 +351,8 @@ body.modal-open {
.setting-item { .setting-item {
display: flex; display: flex;
justify-content: space-between; flex-direction: column; /* Changed to column for help text placement */
align-items: flex-start; margin-bottom: var(--space-3); /* Increased to provide more spacing between items */
margin-bottom: var(--space-2);
padding: var(--space-1); padding: var(--space-1);
border-radius: var(--border-radius-xs); border-radius: var(--border-radius-xs);
} }
@@ -356,18 +365,68 @@ body.modal-open {
background: rgba(255, 255, 255, 0.05); background: rgba(255, 255, 255, 0.05);
} }
/* Control row with label and input together */
.setting-row {
display: flex;
flex-direction: row;
justify-content: space-between;
align-items: center;
width: 100%;
}
.setting-info { .setting-info {
flex: 1; margin-bottom: 0;
width: 35%; /* Increased from 30% to prevent wrapping */
flex-shrink: 0; /* Prevent shrinking */
} }
.setting-info label { .setting-info label {
display: block; display: block;
margin-bottom: 4px;
font-weight: 500; font-weight: 500;
margin-bottom: 0;
white-space: nowrap; /* Prevent label wrapping */
} }
.setting-control { .setting-control {
padding-left: var(--space-2); width: 60%; /* Decreased slightly from 65% */
margin-bottom: 0;
display: flex;
justify-content: flex-end; /* Right-align all controls */
}
/* Select Control Styles */
.select-control {
width: 100%;
display: flex;
justify-content: flex-end;
}
.select-control select {
width: 100%;
max-width: 100%; /* Increased from 200px */
padding: 6px 10px;
border-radius: var(--border-radius-xs);
border: 1px solid var(--border-color);
background-color: var(--lora-surface);
color: var(--text-color);
font-size: 0.95em;
height: 32px;
}
/* Fix dark theme select dropdown text color */
[data-theme="dark"] .select-control select {
background-color: rgba(30, 30, 30, 0.9);
color: var(--text-color);
}
[data-theme="dark"] .select-control select option {
background-color: #2d2d2d;
color: var(--text-color);
}
.select-control select:focus {
border-color: var(--lora-accent);
outline: none;
} }
/* Toggle Switch */ /* Toggle Switch */
@@ -377,6 +436,7 @@ body.modal-open {
width: 50px; width: 50px;
height: 24px; height: 24px;
cursor: pointer; cursor: pointer;
margin-left: auto; /* Push to right side */
} }
.toggle-switch input { .toggle-switch input {
@@ -426,15 +486,6 @@ input:checked + .toggle-slider:before {
width: 22px; width: 22px;
} }
/* Update input help styles */
.input-help {
font-size: 0.85em;
color: var(--text-color);
opacity: 0.7;
margin-top: 4px;
line-height: 1.4;
}
/* Blur effect for NSFW content */ /* Blur effect for NSFW content */
.nsfw-blur { .nsfw-blur {
filter: blur(12px); filter: blur(12px);
@@ -482,4 +533,44 @@ input:checked + .toggle-slider:before {
font-style: italic; font-style: italic;
margin-top: var(--space-1); margin-top: var(--space-1);
text-align: center; text-align: center;
}
/* Add styles for markdown elements in changelog */
.changelog-item ul {
padding-left: 20px;
margin-top: 8px;
}
.changelog-item li {
margin-bottom: 6px;
line-height: 1.4;
}
.changelog-item strong {
font-weight: 600;
}
.changelog-item em {
font-style: italic;
}
.changelog-item code {
background: rgba(0, 0, 0, 0.05);
padding: 2px 4px;
border-radius: 3px;
font-family: monospace;
font-size: 0.9em;
}
[data-theme="dark"] .changelog-item code {
background: rgba(255, 255, 255, 0.1);
}
.changelog-item a {
color: var(--lora-accent);
text-decoration: none;
}
.changelog-item a:hover {
text-decoration: underline;
} }

View File

@@ -1,184 +0,0 @@
.recipe-tag-container {
display: flex;
flex-wrap: wrap;
gap: 0.5rem;
margin-bottom: 1rem;
}
.recipe-tag {
background: var(--lora-surface-hover);
color: var(--lora-text-secondary);
padding: 0.25rem 0.5rem;
border-radius: var(--border-radius-sm);
font-size: 0.8rem;
cursor: pointer;
transition: all 0.2s ease;
}
.recipe-tag:hover, .recipe-tag.active {
background: var(--lora-primary);
color: var(--lora-text-on-primary);
}
.recipe-card {
position: relative;
background: var(--lora-surface);
border-radius: var(--border-radius-base);
overflow: hidden;
box-shadow: var(--shadow-sm);
transition: all 0.2s ease;
aspect-ratio: 896/1152;
cursor: pointer;
display: flex;
flex-direction: column;
}
.recipe-card:hover {
transform: translateY(-3px);
box-shadow: var(--shadow-md);
}
.recipe-card:focus-visible {
outline: 2px solid var(--lora-accent);
outline-offset: 2px;
}
.recipe-indicator {
position: absolute;
top: 6px;
left: 8px;
width: 24px;
height: 24px;
background: var(--lora-primary);
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
color: white;
font-weight: bold;
z-index: 2;
}
.recipe-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
gap: 1.5rem;
margin-top: 1.5rem;
}
.placeholder-message {
grid-column: 1 / -1;
text-align: center;
padding: 2rem;
background: var(--lora-surface-alt);
border-radius: var(--border-radius-base);
}
.card-preview {
position: relative;
width: 100%;
height: 100%;
border-radius: var(--border-radius-base);
overflow: hidden;
}
.card-preview img {
width: 100%;
height: 100%;
object-fit: cover;
object-position: center top;
}
.card-header {
position: absolute;
top: 0;
left: 0;
right: 0;
background: linear-gradient(oklch(0% 0 0 / 0.75), transparent 85%);
backdrop-filter: blur(8px);
color: white;
padding: var(--space-1);
display: flex;
justify-content: space-between;
align-items: center;
z-index: 1;
min-height: 20px;
}
.base-model-wrapper {
display: flex;
align-items: center;
gap: 8px;
margin-left: 32px;
}
.card-actions {
display: flex;
gap: 8px;
}
.card-actions i {
cursor: pointer;
opacity: 0.8;
transition: opacity 0.2s ease;
}
.card-actions i:hover {
opacity: 1;
}
.card-footer {
position: absolute;
bottom: 0;
left: 0;
right: 0;
background: linear-gradient(transparent 15%, oklch(0% 0 0 / 0.75));
backdrop-filter: blur(8px);
color: white;
padding: var(--space-1);
display: flex;
justify-content: space-between;
align-items: flex-start;
min-height: 32px;
gap: var(--space-1);
}
.lora-count {
display: flex;
align-items: center;
gap: 4px;
background: rgba(255, 255, 255, 0.2);
padding: 2px 8px;
border-radius: var(--border-radius-xs);
font-size: 0.85em;
position: relative;
}
.lora-count.ready {
background: rgba(46, 204, 113, 0.3);
}
.lora-count.missing {
background: rgba(231, 76, 60, 0.3);
}
/* 响应式设计 */
@media (max-width: 1400px) {
.recipe-grid {
grid-template-columns: repeat(auto-fill, minmax(240px, 1fr));
}
.recipe-card {
max-width: 240px;
}
}
@media (max-width: 768px) {
.recipe-grid {
grid-template-columns: minmax(260px, 1fr);
}
.recipe-card {
max-width: 100%;
}
}

View File

@@ -18,6 +18,110 @@
display: -webkit-box; display: -webkit-box;
-webkit-line-clamp: 2; -webkit-line-clamp: 2;
-webkit-box-orient: vertical; -webkit-box-orient: vertical;
width: calc(100% - 20px);
}
/* Editable content styles */
.editable-content {
position: relative;
width: 100%;
display: flex;
align-items: center;
justify-content: space-between;
}
.editable-content.hide {
display: none;
}
.editable-content .content-text {
flex: 1;
min-width: 0;
overflow: hidden;
text-overflow: ellipsis;
}
.edit-icon {
background: none;
border: none;
color: var(--text-color);
opacity: 0;
cursor: pointer;
padding: 4px 8px;
margin-left: 8px;
border-radius: var(--border-radius-xs);
transition: all 0.2s;
flex-shrink: 0;
display: flex;
align-items: center;
justify-content: center;
}
.editable-content:hover .edit-icon {
opacity: 0.6;
}
.edit-icon:hover {
opacity: 1 !important;
background: var(--lora-surface);
}
/* Content editor styles */
.content-editor {
display: none;
width: 100%;
padding: 4px 0;
}
.content-editor.active {
display: flex;
align-items: center;
gap: 8px;
}
.content-editor input {
flex: 1;
background: var(--bg-color);
border: 1px solid var(--lora-border);
border-radius: var(--border-radius-xs);
padding: 6px 8px;
font-size: 1em;
color: var(--text-color);
min-width: 0;
}
.content-editor.tags-editor input {
font-size: 0.9em;
}
/* 删除不再需要的按钮样式 */
.editor-actions {
display: none;
}
/* Special styling for tags content */
.tags-content {
display: flex;
align-items: center;
flex-wrap: nowrap;
gap: 8px;
}
.tags-display {
display: flex;
flex-wrap: nowrap;
gap: 6px;
align-items: center;
flex: 1;
min-width: 0;
overflow: hidden;
}
.no-tags {
font-size: 0.85em;
color: var(--text-color);
opacity: 0.6;
font-style: italic;
} }
/* Recipe Tags styles */ /* Recipe Tags styles */
@@ -129,7 +233,14 @@
justify-content: center; justify-content: center;
} }
.recipe-preview-container img { .recipe-preview-container img,
.recipe-preview-container video {
max-width: 100%;
max-height: 100%;
object-fit: contain;
}
.recipe-preview-media {
max-width: 100%; max-width: 100%;
max-height: 100%; max-height: 100%;
object-fit: contain; object-fit: contain;
@@ -289,6 +400,27 @@
gap: var(--space-1); gap: var(--space-1);
} }
/* View LoRAs button */
.view-loras-btn {
background: none;
border: none;
color: var(--text-color);
opacity: 0.7;
cursor: pointer;
padding: 4px 8px;
border-radius: var(--border-radius-xs);
transition: all 0.2s;
display: flex;
align-items: center;
justify-content: center;
}
.view-loras-btn:hover {
opacity: 1;
background: var(--lora-surface);
color: var(--lora-accent);
}
#recipeLorasCount { #recipeLorasCount {
font-size: 0.9em; font-size: 0.9em;
color: var(--text-color); color: var(--text-color);
@@ -309,6 +441,7 @@
gap: 10px; gap: 10px;
overflow-y: auto; overflow-y: auto;
flex: 1; flex: 1;
padding-top: 4px; /* Add padding to prevent first item from being cut off when hovered */
} }
.recipe-lora-item { .recipe-lora-item {
@@ -322,6 +455,14 @@
will-change: transform; will-change: transform;
/* Create a new containing block for absolutely positioned descendants */ /* Create a new containing block for absolutely positioned descendants */
transform: translateZ(0); transform: translateZ(0);
cursor: pointer; /* Make it clear the item is clickable */
transition: transform 0.2s ease, box-shadow 0.2s ease, border-color 0.2s ease;
}
.recipe-lora-item:hover {
transform: translateY(-1px);
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08);
border-color: var(--lora-accent);
} }
.recipe-lora-item.exists-locally { .recipe-lora-item.exists-locally {
@@ -333,6 +474,12 @@
border-left: 4px solid var(--lora-error); border-left: 4px solid var(--lora-error);
} }
.recipe-lora-item.is-deleted {
background: rgba(127, 127, 127, 0.05);
border-left: 4px solid #777;
opacity: 0.8;
}
.recipe-lora-thumbnail { .recipe-lora-thumbnail {
width: 46px; width: 46px;
height: 46px; height: 46px;
@@ -340,9 +487,19 @@
border-radius: var(--border-radius-xs); border-radius: var(--border-radius-xs);
overflow: hidden; overflow: hidden;
background: var(--bg-color); background: var(--bg-color);
display: flex;
align-items: center;
justify-content: center;
} }
.recipe-lora-thumbnail img { .recipe-lora-thumbnail img,
.recipe-lora-thumbnail video {
width: 100%;
height: 100%;
object-fit: cover;
}
.thumbnail-video {
width: 100%; width: 100%;
height: 100%; height: 100%;
object-fit: cover; object-fit: cover;
@@ -457,6 +614,170 @@
font-size: 0.9em; font-size: 0.9em;
} }
/* Deleted badge with reconnect functionality */
.deleted-badge {
display: inline-flex;
align-items: center;
background: #777;
color: white;
padding: 3px 6px;
border-radius: var(--border-radius-xs);
font-size: 0.75em;
font-weight: 500;
white-space: nowrap;
flex-shrink: 0;
}
.deleted-badge i {
margin-right: 4px;
font-size: 0.9em;
}
/* Add reconnect functionality styles */
.deleted-badge.reconnectable {
position: relative;
cursor: pointer;
transition: background-color 0.2s ease;
}
.deleted-badge.reconnectable:hover {
background-color: var(--lora-accent);
}
.deleted-badge .reconnect-tooltip {
position: absolute;
display: none;
background-color: var(--card-bg);
color: var(--text-color);
padding: 8px 12px;
border-radius: var(--border-radius-xs);
border: 1px solid var(--border-color);
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
z-index: var(--z-overlay);
width: max-content;
max-width: 200px;
font-size: 0.85rem;
font-weight: normal;
top: calc(100% + 5px);
left: 0;
margin-left: -100px;
}
.deleted-badge.reconnectable:hover .reconnect-tooltip {
display: block;
}
/* LoRA reconnect container */
.lora-reconnect-container {
display: none;
flex-direction: column;
background: var(--lora-surface);
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
padding: 12px;
margin-top: 10px;
gap: 10px;
}
.lora-reconnect-container.active {
display: flex;
}
.reconnect-instructions {
display: flex;
flex-direction: column;
gap: 5px;
}
.reconnect-instructions p {
margin: 0;
font-size: 0.95em;
font-weight: 500;
color: var(--text-color);
}
.reconnect-instructions small {
color: var(--text-color);
opacity: 0.7;
font-size: 0.85em;
}
.reconnect-instructions code {
background: rgba(0, 0, 0, 0.1);
padding: 2px 4px;
border-radius: 3px;
font-family: monospace;
font-size: 0.9em;
}
[data-theme="dark"] .reconnect-instructions code {
background: rgba(255, 255, 255, 0.1);
}
.reconnect-form {
display: flex;
flex-direction: column;
gap: 10px;
}
.reconnect-input {
width: calc(100% - 20px);
padding: 8px 10px;
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
background: var(--bg-color);
color: var(--text-color);
font-size: 0.9em;
}
.reconnect-actions {
display: flex;
justify-content: flex-end;
gap: 8px;
}
.reconnect-cancel-btn,
.reconnect-confirm-btn {
padding: 6px 12px;
border-radius: var(--border-radius-xs);
font-size: 0.85em;
cursor: pointer;
border: none;
transition: all 0.2s;
}
.reconnect-cancel-btn {
background: var(--bg-color);
color: var(--text-color);
border: 1px solid var(--border-color);
}
.reconnect-confirm-btn {
background: var(--lora-accent);
color: white;
}
.reconnect-cancel-btn:hover {
background: var(--lora-surface);
}
.reconnect-confirm-btn:hover {
background: color-mix(in oklch, var(--lora-accent), black 10%);
}
/* Recipe status partial state */
.recipe-status.partial {
background: rgba(127, 127, 127, 0.1);
color: #777;
}
/* 标题输入框特定的样式 */
.title-input {
font-size: 1.2em !important; /* 调整为更合适的大小 */
line-height: 1.2;
font-weight: 500;
}
/* Responsive adjustments */ /* Responsive adjustments */
@media (max-width: 768px) { @media (max-width: 768px) {
.recipe-top-section { .recipe-top-section {
@@ -485,7 +806,8 @@
/* Update the local-badge and missing-badge to be positioned within the badge-container */ /* Update the local-badge and missing-badge to be positioned within the badge-container */
.badge-container .local-badge, .badge-container .local-badge,
.badge-container .missing-badge { .badge-container .missing-badge,
.badge-container .deleted-badge {
position: static; /* Override absolute positioning */ position: static; /* Override absolute positioning */
transform: none; /* Remove the transform */ transform: none; /* Remove the transform */
} }
@@ -495,3 +817,46 @@
position: fixed; /* Keep as fixed for Chrome */ position: fixed; /* Keep as fixed for Chrome */
z-index: 100; z-index: 100;
} }
/* Add styles for missing LoRAs download feature */
.recipe-status.missing {
position: relative;
cursor: pointer;
transition: background-color 0.2s ease;
}
.recipe-status.missing:hover {
background-color: rgba(var(--lora-warning-rgb, 255, 165, 0), 0.2);
}
.recipe-status.missing .missing-tooltip {
position: absolute;
display: none;
background-color: var(--card-bg);
color: var(--text-color);
padding: 8px 12px;
border-radius: var(--border-radius-xs);
border: 1px solid var(--border-color);
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
z-index: var(--z-overlay);
width: max-content;
max-width: 200px;
font-size: 0.85rem;
font-weight: normal;
margin-left: -100px;
margin-top: -65px;
}
.recipe-status.missing:hover .missing-tooltip {
display: block;
}
.recipe-status.clickable {
cursor: pointer;
padding: 4px 8px;
border-radius: var(--border-radius-xs);
}
.recipe-status.clickable:hover {
background-color: rgba(var(--lora-warning-rgb, 255, 165, 0), 0.2);
}

View File

@@ -38,6 +38,90 @@
flex-wrap: nowrap; flex-wrap: nowrap;
} }
/* Action button styling */
.control-group {
position: relative;
}
.control-group button {
min-width: 100px;
display: flex;
align-items: center;
justify-content: center;
gap: 4px;
border-radius: var(--border-radius-xs);
padding: 4px 10px;
border: 1px solid var(--border-color);
background: var(--card-bg);
color: var(--text-color);
font-size: 0.85em;
transition: all 0.2s ease;
cursor: pointer;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
}
.control-group button:hover {
border-color: var(--lora-accent);
background: var(--bg-color);
transform: translateY(-1px);
box-shadow: 0 3px 5px rgba(0, 0, 0, 0.08);
}
.control-group button:active {
transform: translateY(0);
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
}
.control-group button i {
opacity: 0.8;
transition: opacity 0.2s ease;
}
.control-group button:hover i {
opacity: 1;
}
/* Active state for buttons that can be toggled */
.control-group button.active {
background: var(--lora-accent);
color: white;
border-color: var(--lora-accent);
}
/* Select dropdown styling */
.control-group select {
min-width: 100px;
padding: 4px 26px 4px 10px;
border-radius: var(--border-radius-xs);
border: 1px solid var(--border-color);
background-color: var(--card-bg);
color: var(--text-color);
font-size: 0.85em;
appearance: none;
-webkit-appearance: none;
-moz-appearance: none;
background-image: url("data:image/svg+xml;charset=UTF-8,%3csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='none' stroke='currentColor' stroke-width='2' stroke-linecap='round' stroke-linejoin='round'%3e%3cpolyline points='6 9 12 15 18 9'%3e%3c/polyline%3e%3c/svg%3e");
background-repeat: no-repeat;
background-position: right 6px center;
background-size: 14px;
cursor: pointer;
transition: all 0.2s ease;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
}
.control-group select:hover {
border-color: var(--lora-accent);
background-color: var(--bg-color);
transform: translateY(-1px);
box-shadow: 0 3px 5px rgba(0, 0, 0, 0.08);
}
.control-group select:focus {
outline: none;
border-color: var(--lora-accent);
box-shadow: 0 0 0 2px oklch(var(--lora-accent) / 0.15);
}
/* Ensure hidden class works properly */ /* Ensure hidden class works properly */
.hidden { .hidden {
display: none !important; display: none !important;
@@ -86,12 +170,14 @@
justify-content: center; justify-content: center;
cursor: pointer; cursor: pointer;
transition: all 0.3s ease; transition: all 0.3s ease;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
} }
.toggle-folders-btn:hover { .toggle-folders-btn:hover {
background: var(--lora-accent); background: var(--lora-accent);
color: white; color: white;
transform: translateY(-2px); transform: translateY(-2px);
box-shadow: 0 3px 6px rgba(0, 0, 0, 0.1);
} }
.toggle-folders-btn i { .toggle-folders-btn i {
@@ -101,8 +187,9 @@
/* Icon-only button style */ /* Icon-only button style */
.icon-only { .icon-only {
min-width: unset !important; min-width: unset !important;
width: 36px !important; width: 32px !important;
padding: 0 !important; padding: 0 !important;
height: 32px !important;
} }
/* Rotate icon when folders are collapsed */ /* Rotate icon when folders are collapsed */
@@ -133,16 +220,25 @@
cursor: pointer; cursor: pointer;
padding: 2px 8px; padding: 2px 8px;
margin: 2px; margin: 2px;
border: 1px solid #ccc; border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs); border-radius: var(--border-radius-xs);
display: inline-block; display: inline-block;
line-height: 1.2; line-height: 1.2;
font-size: 14px; font-size: 14px;
background-color: var(--card-bg);
transition: all 0.2s ease;
}
.tag:hover {
border-color: var(--lora-accent);
background-color: oklch(var(--lora-accent) / 0.1);
transform: translateY(-1px);
} }
.tag.active { .tag.active {
background-color: #007bff; background-color: var(--lora-accent);
color: white; color: white;
border-color: var(--lora-accent);
} }
/* Back to Top Button */ /* Back to Top Button */
@@ -155,7 +251,7 @@
border-radius: 50%; border-radius: 50%;
background: var(--card-bg); background: var(--card-bg);
border: 1px solid var(--border-color); border: 1px solid var(--border-color);
color: var (--text-color); color: var(--text-color);
display: flex; display: flex;
align-items: center; align-items: center;
justify-content: center; justify-content: center;
@@ -165,6 +261,7 @@
transform: translateY(10px); transform: translateY(10px);
transition: all 0.3s ease; transition: all 0.3s ease;
z-index: var(--z-overlay); z-index: var(--z-overlay);
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
} }
.back-to-top.visible { .back-to-top.visible {
@@ -174,9 +271,10 @@
} }
.back-to-top:hover { .back-to-top:hover {
background: var (--lora-accent); background: var(--lora-accent);
color: white; color: white;
transform: translateY(-2px); transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15);
} }
@media (max-width: 768px) { @media (max-width: 768px) {
@@ -203,19 +301,22 @@
} }
.toggle-folders-btn:hover { .toggle-folders-btn:hover {
transform: none; /* 移动端下禁用hover效果 */ transform: none; /* Disable hover effects on mobile */
}
.control-group button:hover {
transform: none; /* Disable hover effects on mobile */
}
.control-group select:hover {
transform: none; /* Disable hover effects on mobile */
}
.tag:hover {
transform: none; /* Disable hover effects on mobile */
} }
.back-to-top { .back-to-top {
bottom: 60px; /* Give some extra space from bottom on mobile */ bottom: 60px; /* Give some extra space from bottom on mobile */
} }
} }
/* Standardize button widths in controls */
.control-group button {
min-width: 100px;
display: flex;
align-items: center;
justify-content: center;
gap: 6px;
}

View File

@@ -18,6 +18,8 @@
@import 'components/search-filter.css'; @import 'components/search-filter.css';
@import 'components/bulk.css'; @import 'components/bulk.css';
@import 'components/shared.css'; @import 'components/shared.css';
@import 'components/filter-indicator.css';
@import 'components/initialization.css';
.initialization-notice { .initialization-notice {
display: flex; display: flex;

Binary file not shown.

After

Width:  |  Height:  |  Size: 22 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 213 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 91 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 338 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.8 MiB

View File

@@ -0,0 +1,507 @@
// filepath: d:\Workspace\ComfyUI\custom_nodes\ComfyUI-Lora-Manager\static\js\api\baseModelApi.js
import { state, getCurrentPageState } from '../state/index.js';
import { showToast } from '../utils/uiHelpers.js';
import { showDeleteModal, confirmDelete } from '../utils/modalUtils.js';
import { getSessionItem, saveMapToStorage } from '../utils/storageHelpers.js';
/**
* Shared functionality for handling models (loras and checkpoints)
*/
// Generic function to load more models with pagination
export async function loadMoreModels(options = {}) {
const {
resetPage = false,
updateFolders = false,
modelType = 'lora', // 'lora' or 'checkpoint'
createCardFunction,
endpoint = '/api/loras'
} = options;
const pageState = getCurrentPageState();
if (pageState.isLoading || (!pageState.hasMore && !resetPage)) return;
pageState.isLoading = true;
document.body.classList.add('loading');
try {
// Reset to first page if requested
if (resetPage) {
pageState.currentPage = 1;
// Clear grid if resetting
const gridId = modelType === 'checkpoint' ? 'checkpointGrid' : 'loraGrid';
const grid = document.getElementById(gridId);
if (grid) grid.innerHTML = '';
}
const params = new URLSearchParams({
page: pageState.currentPage,
page_size: pageState.pageSize || 20,
sort_by: pageState.sortBy
});
if (pageState.activeFolder !== null) {
params.append('folder', pageState.activeFolder);
}
// Add search parameters if there's a search term
if (pageState.filters?.search) {
params.append('search', pageState.filters.search);
params.append('fuzzy', 'true');
// Add search option parameters if available
if (pageState.searchOptions) {
params.append('search_filename', pageState.searchOptions.filename.toString());
params.append('search_modelname', pageState.searchOptions.modelname.toString());
if (pageState.searchOptions.tags !== undefined) {
params.append('search_tags', pageState.searchOptions.tags.toString());
}
params.append('recursive', (pageState.searchOptions?.recursive ?? false).toString());
}
}
// Add filter parameters if active
if (pageState.filters) {
// Handle tags filters
if (pageState.filters.tags && pageState.filters.tags.length > 0) {
// Checkpoints API expects individual 'tag' parameters, Loras API expects comma-separated 'tags'
if (modelType === 'checkpoint') {
pageState.filters.tags.forEach(tag => {
params.append('tag', tag);
});
} else {
params.append('tags', pageState.filters.tags.join(','));
}
}
// Handle base model filters
if (pageState.filters.baseModel && pageState.filters.baseModel.length > 0) {
if (modelType === 'checkpoint') {
pageState.filters.baseModel.forEach(model => {
params.append('base_model', model);
});
} else {
params.append('base_models', pageState.filters.baseModel.join(','));
}
}
}
// Add model-specific parameters
if (modelType === 'lora') {
// Check for recipe-based filtering parameters from session storage
const filterLoraHash = getSessionItem('recipe_to_lora_filterLoraHash');
const filterLoraHashes = getSessionItem('recipe_to_lora_filterLoraHashes');
// Add hash filter parameter if present
if (filterLoraHash) {
params.append('lora_hash', filterLoraHash);
}
// Add multiple hashes filter if present
else if (filterLoraHashes) {
try {
if (Array.isArray(filterLoraHashes) && filterLoraHashes.length > 0) {
params.append('lora_hashes', filterLoraHashes.join(','));
}
} catch (error) {
console.error('Error parsing lora hashes from session storage:', error);
}
}
}
const response = await fetch(`${endpoint}?${params}`);
if (!response.ok) {
throw new Error(`Failed to fetch models: ${response.statusText}`);
}
const data = await response.json();
const gridId = modelType === 'checkpoint' ? 'checkpointGrid' : 'loraGrid';
const grid = document.getElementById(gridId);
if (data.items.length === 0 && pageState.currentPage === 1) {
grid.innerHTML = `<div class="no-results">No ${modelType}s found in this folder</div>`;
pageState.hasMore = false;
} else if (data.items.length > 0) {
pageState.hasMore = pageState.currentPage < data.total_pages;
// Append model cards using the provided card creation function
data.items.forEach(model => {
const card = createCardFunction(model);
grid.appendChild(card);
});
// Increment the page number AFTER successful loading
pageState.currentPage++;
} else {
pageState.hasMore = false;
}
if (updateFolders && data.folders) {
updateFolderTags(data.folders);
}
} catch (error) {
console.error(`Error loading ${modelType}s:`, error);
showToast(`Failed to load ${modelType}s: ${error.message}`, 'error');
} finally {
pageState.isLoading = false;
document.body.classList.remove('loading');
}
}
// Update folder tags in the UI
export function updateFolderTags(folders) {
const folderTagsContainer = document.querySelector('.folder-tags');
if (!folderTagsContainer) return;
// Keep track of currently selected folder
const pageState = getCurrentPageState();
const currentFolder = pageState.activeFolder;
// Create HTML for folder tags
const tagsHTML = folders.map(folder => {
const isActive = folder === currentFolder;
return `<div class="tag ${isActive ? 'active' : ''}" data-folder="${folder}">${folder}</div>`;
}).join('');
// Update the container
folderTagsContainer.innerHTML = tagsHTML;
// Reattach click handlers and ensure the active tag is visible
const tags = folderTagsContainer.querySelectorAll('.tag');
tags.forEach(tag => {
if (typeof toggleFolder === 'function') {
tag.addEventListener('click', toggleFolder);
}
if (tag.dataset.folder === currentFolder) {
tag.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
}
});
}
// Generic function to replace a model preview
export function replaceModelPreview(filePath, modelType = 'lora') {
// Open file picker
const input = document.createElement('input');
input.type = 'file';
input.accept ='image/*,video/mp4';
input.onchange = async function() {
if (!input.files || !input.files[0]) return;
const file = input.files[0];
await uploadPreview(filePath, file, modelType);
};
input.click();
}
// Delete a model (generic)
export function deleteModel(filePath, modelType = 'lora') {
if (modelType === 'checkpoint') {
confirmDelete('Are you sure you want to delete this checkpoint?', () => {
performDelete(filePath, modelType);
});
} else {
showDeleteModal(filePath);
}
}
// Reset and reload models
export async function resetAndReload(options = {}) {
const {
updateFolders = false,
modelType = 'lora',
loadMoreFunction
} = options;
const pageState = getCurrentPageState();
// Reset pagination and load more models
if (typeof loadMoreFunction === 'function') {
await loadMoreFunction(true, updateFolders);
}
}
// Generic function to refresh models
export async function refreshModels(options = {}) {
const {
modelType = 'lora',
scanEndpoint = '/api/loras/scan',
resetAndReloadFunction
} = options;
try {
state.loadingManager.showSimpleLoading(`Refreshing ${modelType}s...`);
const response = await fetch(scanEndpoint);
if (!response.ok) {
throw new Error(`Failed to refresh ${modelType}s: ${response.status} ${response.statusText}`);
}
if (typeof resetAndReloadFunction === 'function') {
await resetAndReloadFunction();
}
showToast(`Refresh complete`, 'success');
} catch (error) {
console.error(`Refresh failed:`, error);
showToast(`Failed to refresh ${modelType}s`, 'error');
} finally {
state.loadingManager.hide();
state.loadingManager.restoreProgressBar();
}
}
// Generic fetch from Civitai
export async function fetchCivitaiMetadata(options = {}) {
const {
modelType = 'lora',
fetchEndpoint = '/api/fetch-all-civitai',
resetAndReloadFunction
} = options;
let ws = null;
await state.loadingManager.showWithProgress(async (loading) => {
try {
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
const operationComplete = new Promise((resolve, reject) => {
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
switch(data.status) {
case 'started':
loading.setStatus('Starting metadata fetch...');
break;
case 'processing':
const percent = ((data.processed / data.total) * 100).toFixed(1);
loading.setProgress(percent);
loading.setStatus(
`Processing (${data.processed}/${data.total}) ${data.current_name}`
);
break;
case 'completed':
loading.setProgress(100);
loading.setStatus(
`Completed: Updated ${data.success} of ${data.processed} ${modelType}s`
);
resolve();
break;
case 'error':
reject(new Error(data.error));
break;
}
};
ws.onerror = (error) => {
reject(new Error('WebSocket error: ' + error.message));
};
});
await new Promise((resolve, reject) => {
ws.onopen = resolve;
ws.onerror = reject;
});
const requestBody = modelType === 'checkpoint'
? JSON.stringify({ model_type: 'checkpoint' })
: JSON.stringify({});
const response = await fetch(fetchEndpoint, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: requestBody
});
if (!response.ok) {
throw new Error('Failed to fetch metadata');
}
await operationComplete;
if (typeof resetAndReloadFunction === 'function') {
await resetAndReloadFunction();
}
} catch (error) {
console.error('Error fetching metadata:', error);
showToast('Failed to fetch metadata: ' + error.message, 'error');
} finally {
if (ws) {
ws.close();
}
}
}, {
initialMessage: 'Connecting...',
completionMessage: 'Metadata update complete'
});
}
// Generic function to refresh single model metadata
export async function refreshSingleModelMetadata(filePath, modelType = 'lora') {
try {
state.loadingManager.showSimpleLoading('Refreshing metadata...');
const endpoint = modelType === 'checkpoint'
? '/api/checkpoints/fetch-civitai'
: '/api/fetch-civitai';
const response = await fetch(endpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ file_path: filePath })
});
if (!response.ok) {
throw new Error('Failed to refresh metadata');
}
const data = await response.json();
if (data.success) {
showToast('Metadata refreshed successfully', 'success');
return true;
} else {
throw new Error(data.error || 'Failed to refresh metadata');
}
} catch (error) {
console.error('Error refreshing metadata:', error);
showToast(error.message, 'error');
return false;
} finally {
state.loadingManager.hide();
state.loadingManager.restoreProgressBar();
}
}
// Private methods
// Upload a preview image
async function uploadPreview(filePath, file, modelType = 'lora') {
const loadingOverlay = document.getElementById('loading-overlay');
const loadingStatus = document.querySelector('.loading-status');
try {
if (loadingOverlay) loadingOverlay.style.display = 'flex';
if (loadingStatus) loadingStatus.textContent = 'Uploading preview...';
const formData = new FormData();
// Use appropriate parameter names and endpoint based on model type
// Prepare common form data
formData.append('preview_file', file);
formData.append('model_path', filePath);
// Set endpoint based on model type
const endpoint = modelType === 'checkpoint'
? '/api/checkpoints/replace-preview'
: '/api/replace_preview';
const response = await fetch(endpoint, {
method: 'POST',
body: formData
});
if (!response.ok) {
throw new Error('Upload failed');
}
const data = await response.json();
// Update the card preview in UI
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (card) {
const previewContainer = card.querySelector('.card-preview');
const oldPreview = previewContainer.querySelector('img, video');
// Get the current page's previewVersions Map based on model type
const pageType = modelType === 'checkpoint' ? 'checkpoints' : 'loras';
const previewVersions = state.pages[pageType].previewVersions;
// Update the version timestamp
const timestamp = Date.now();
if (previewVersions) {
previewVersions.set(filePath, timestamp);
// Save the updated Map to localStorage
const storageKey = modelType === 'checkpoint' ? 'checkpoint_preview_versions' : 'lora_preview_versions';
saveMapToStorage(storageKey, previewVersions);
}
const previewUrl = data.preview_url ?
`${data.preview_url}?t=${timestamp}` :
`/api/model/preview_image?path=${encodeURIComponent(filePath)}&t=${timestamp}`;
// Create appropriate element based on file type
if (file.type.startsWith('video/')) {
const video = document.createElement('video');
video.controls = true;
video.autoplay = true;
video.muted = true;
video.loop = true;
video.src = previewUrl;
oldPreview.replaceWith(video);
} else {
const img = document.createElement('img');
img.src = previewUrl;
oldPreview.replaceWith(img);
}
showToast('Preview updated successfully', 'success');
}
} catch (error) {
console.error('Error uploading preview:', error);
showToast('Failed to upload preview image', 'error');
} finally {
if (loadingOverlay) loadingOverlay.style.display = 'none';
}
}
// Private function to perform the delete operation
async function performDelete(filePath, modelType = 'lora') {
try {
showToast(`Deleting ${modelType}...`, 'info');
const response = await fetch('/api/model/delete', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
file_path: filePath,
model_type: modelType
})
});
if (!response.ok) {
throw new Error(`Failed to delete ${modelType}: ${response.status} ${response.statusText}`);
}
const data = await response.json();
if (data.success) {
// Remove the card from UI
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (card) {
card.remove();
}
showToast(`${modelType} deleted successfully`, 'success');
} else {
throw new Error(data.error || `Failed to delete ${modelType}`);
}
} catch (error) {
console.error(`Error deleting ${modelType}:`, error);
showToast(`Failed to delete ${modelType}: ${error.message}`, 'error');
}
}

View File

@@ -0,0 +1,83 @@
import { createCheckpointCard } from '../components/CheckpointCard.js';
import {
loadMoreModels,
resetAndReload as baseResetAndReload,
refreshModels as baseRefreshModels,
deleteModel as baseDeleteModel,
replaceModelPreview,
fetchCivitaiMetadata,
refreshSingleModelMetadata
} from './baseModelApi.js';
// Load more checkpoints with pagination
export async function loadMoreCheckpoints(resetPagination = true) {
return loadMoreModels({
resetPage: resetPagination,
updateFolders: true,
modelType: 'checkpoint',
createCardFunction: createCheckpointCard,
endpoint: '/api/checkpoints'
});
}
// Reset and reload checkpoints
export async function resetAndReload() {
return baseResetAndReload({
updateFolders: true,
modelType: 'checkpoint',
loadMoreFunction: loadMoreCheckpoints
});
}
// Refresh checkpoints
export async function refreshCheckpoints() {
return baseRefreshModels({
modelType: 'checkpoint',
scanEndpoint: '/api/checkpoints/scan',
resetAndReloadFunction: resetAndReload
});
}
// Delete a checkpoint
export function deleteCheckpoint(filePath) {
return baseDeleteModel(filePath, 'checkpoint');
}
// Replace checkpoint preview
export function replaceCheckpointPreview(filePath) {
return replaceModelPreview(filePath, 'checkpoint');
}
// Fetch metadata from Civitai for checkpoints
export async function fetchCivitai() {
return fetchCivitaiMetadata({
modelType: 'checkpoint',
fetchEndpoint: '/api/checkpoints/fetch-all-civitai',
resetAndReloadFunction: resetAndReload
});
}
// Refresh single checkpoint metadata
export async function refreshSingleCheckpointMetadata(filePath) {
return refreshSingleModelMetadata(filePath, 'checkpoint');
}
// Save checkpoint metadata (similar to the Lora version)
export async function saveCheckpointMetadata(filePath, data) {
const response = await fetch('/api/checkpoints/save-metadata', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath,
...data
})
});
if (!response.ok) {
throw new Error('Failed to save metadata');
}
return await response.json();
}

View File

@@ -1,269 +1,38 @@
import { state, getCurrentPageState } from '../state/index.js';
import { showToast } from '../utils/uiHelpers.js';
import { createLoraCard } from '../components/LoraCard.js'; import { createLoraCard } from '../components/LoraCard.js';
import { initializeInfiniteScroll } from '../utils/infiniteScroll.js'; import {
import { showDeleteModal } from '../utils/modalUtils.js'; loadMoreModels,
import { toggleFolder } from '../utils/uiHelpers.js'; resetAndReload as baseResetAndReload,
refreshModels as baseRefreshModels,
deleteModel as baseDeleteModel,
replaceModelPreview,
fetchCivitaiMetadata,
refreshSingleModelMetadata
} from './baseModelApi.js';
export async function loadMoreLoras(resetPage = false, updateFolders = false) { export async function loadMoreLoras(resetPage = false, updateFolders = false) {
const pageState = getCurrentPageState(); return loadMoreModels({
resetPage,
if (pageState.isLoading || (!pageState.hasMore && !resetPage)) return; updateFolders,
modelType: 'lora',
pageState.isLoading = true; createCardFunction: createLoraCard,
try { endpoint: '/api/loras'
// Reset to first page if requested
if (resetPage) {
pageState.currentPage = 1;
// Clear grid if resetting
const grid = document.getElementById('loraGrid');
if (grid) grid.innerHTML = '';
initializeInfiniteScroll();
}
const params = new URLSearchParams({
page: pageState.currentPage,
page_size: 20,
sort_by: pageState.sortBy
});
if (pageState.activeFolder !== null) {
params.append('folder', pageState.activeFolder);
}
// Add search parameters if there's a search term
if (pageState.filters?.search) {
params.append('search', pageState.filters.search);
params.append('fuzzy', 'true');
// Add search option parameters if available
if (pageState.searchOptions) {
params.append('search_filename', pageState.searchOptions.filename.toString());
params.append('search_modelname', pageState.searchOptions.modelname.toString());
params.append('search_tags', (pageState.searchOptions.tags || false).toString());
params.append('recursive', (pageState.searchOptions?.recursive ?? false).toString());
}
}
// Add filter parameters if active
if (pageState.filters) {
if (pageState.filters.tags && pageState.filters.tags.length > 0) {
// Convert the array of tags to a comma-separated string
params.append('tags', pageState.filters.tags.join(','));
}
if (pageState.filters.baseModel && pageState.filters.baseModel.length > 0) {
// Convert the array of base models to a comma-separated string
params.append('base_models', pageState.filters.baseModel.join(','));
}
}
const response = await fetch(`/api/loras?${params}`);
if (!response.ok) {
throw new Error(`Failed to fetch loras: ${response.statusText}`);
}
const data = await response.json();
if (data.items.length === 0 && pageState.currentPage === 1) {
const grid = document.getElementById('loraGrid');
grid.innerHTML = '<div class="no-results">No loras found in this folder</div>';
pageState.hasMore = false;
} else if (data.items.length > 0) {
pageState.hasMore = pageState.currentPage < data.total_pages;
pageState.currentPage++;
appendLoraCards(data.items);
const sentinel = document.getElementById('scroll-sentinel');
if (sentinel && state.observer) {
state.observer.observe(sentinel);
}
} else {
pageState.hasMore = false;
}
if (updateFolders && data.folders) {
updateFolderTags(data.folders);
}
} catch (error) {
console.error('Error loading loras:', error);
showToast('Failed to load loras: ' + error.message, 'error');
} finally {
pageState.isLoading = false;
}
}
function updateFolderTags(folders) {
const folderTagsContainer = document.querySelector('.folder-tags');
if (!folderTagsContainer) return;
// Keep track of currently selected folder
const pageState = getCurrentPageState();
const currentFolder = pageState.activeFolder;
// Create HTML for folder tags
const tagsHTML = folders.map(folder => {
const isActive = folder === currentFolder;
return `<div class="tag ${isActive ? 'active' : ''}" data-folder="${folder}">${folder}</div>`;
}).join('');
// Update the container
folderTagsContainer.innerHTML = tagsHTML;
// Reattach click handlers and ensure the active tag is visible
const tags = folderTagsContainer.querySelectorAll('.tag');
tags.forEach(tag => {
tag.addEventListener('click', toggleFolder);
if (tag.dataset.folder === currentFolder) {
tag.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
}
}); });
} }
export async function fetchCivitai() { export async function fetchCivitai() {
let ws = null; return fetchCivitaiMetadata({
modelType: 'lora',
await state.loadingManager.showWithProgress(async (loading) => { fetchEndpoint: '/api/fetch-all-civitai',
try { resetAndReloadFunction: resetAndReload
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
const ws = new WebSocket(`${wsProtocol}${window.location.host}/ws/fetch-progress`);
const operationComplete = new Promise((resolve, reject) => {
ws.onmessage = (event) => {
const data = JSON.parse(event.data);
switch(data.status) {
case 'started':
loading.setStatus('Starting metadata fetch...');
break;
case 'processing':
const percent = ((data.processed / data.total) * 100).toFixed(1);
loading.setProgress(percent);
loading.setStatus(
`Processing (${data.processed}/${data.total}) ${data.current_name}`
);
break;
case 'completed':
loading.setProgress(100);
loading.setStatus(
`Completed: Updated ${data.success} of ${data.processed} loras`
);
resolve();
break;
case 'error':
reject(new Error(data.error));
break;
}
};
ws.onerror = (error) => {
reject(new Error('WebSocket error: ' + error.message));
};
});
await new Promise((resolve, reject) => {
ws.onopen = resolve;
ws.onerror = reject;
});
const response = await fetch('/api/fetch-all-civitai', {
method: 'POST',
headers: { 'Content-Type': 'application/json' }
});
if (!response.ok) {
throw new Error('Failed to fetch metadata');
}
await operationComplete;
await resetAndReload();
} catch (error) {
console.error('Error fetching metadata:', error);
showToast('Failed to fetch metadata: ' + error.message, 'error');
} finally {
if (ws) {
ws.close();
}
}
}, {
initialMessage: 'Connecting...',
completionMessage: 'Metadata update complete'
}); });
} }
export async function deleteModel(filePath) { export async function deleteModel(filePath) {
showDeleteModal(filePath); return baseDeleteModel(filePath, 'lora');
} }
export async function replacePreview(filePath) { export async function replacePreview(filePath) {
const loadingOverlay = document.getElementById('loading-overlay'); return replaceModelPreview(filePath, 'lora');
const loadingStatus = document.querySelector('.loading-status');
const input = document.createElement('input');
input.type = 'file';
input.accept = 'image/*,video/mp4';
input.onchange = async function() {
if (!input.files || !input.files[0]) return;
const file = input.files[0];
const formData = new FormData();
formData.append('preview_file', file);
formData.append('model_path', filePath);
try {
loadingOverlay.style.display = 'flex';
loadingStatus.textContent = 'Uploading preview...';
const response = await fetch('/api/replace_preview', {
method: 'POST',
body: formData
});
if (!response.ok) {
throw new Error('Upload failed');
}
const data = await response.json();
// 更新预览版本
state.previewVersions.set(filePath, Date.now());
// 更新卡片显示
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
const previewContainer = card.querySelector('.card-preview');
const oldPreview = previewContainer.querySelector('img, video');
const previewUrl = `${data.preview_url}?t=${state.previewVersions.get(filePath)}`;
if (file.type.startsWith('video/')) {
const video = document.createElement('video');
video.controls = true;
video.autoplay = true;
video.muted = true;
video.loop = true;
video.src = previewUrl;
oldPreview.replaceWith(video);
} else {
const img = document.createElement('img');
img.src = previewUrl;
oldPreview.replaceWith(img);
}
} catch (error) {
console.error('Error uploading preview:', error);
alert('Failed to upload preview image');
} finally {
loadingOverlay.style.display = 'none';
}
};
input.click();
} }
export function appendLoraCards(loras) { export function appendLoraCards(loras) {
@@ -277,60 +46,26 @@ export function appendLoraCards(loras) {
} }
export async function resetAndReload(updateFolders = false) { export async function resetAndReload(updateFolders = false) {
const pageState = getCurrentPageState(); return baseResetAndReload({
console.log('Resetting with state:', { ...pageState }); updateFolders,
modelType: 'lora',
// Initialize infinite scroll - will reset the observer loadMoreFunction: loadMoreLoras
initializeInfiniteScroll(); });
// Load more loras with reset flag
await loadMoreLoras(true, updateFolders);
} }
export async function refreshLoras() { export async function refreshLoras() {
try { return baseRefreshModels({
state.loadingManager.showSimpleLoading('Refreshing loras...'); modelType: 'lora',
await resetAndReload(); scanEndpoint: '/api/loras/scan',
showToast('Refresh complete', 'success'); resetAndReloadFunction: resetAndReload
} catch (error) { });
console.error('Refresh failed:', error);
showToast('Failed to refresh loras', 'error');
} finally {
state.loadingManager.hide();
state.loadingManager.restoreProgressBar();
}
} }
export async function refreshSingleLoraMetadata(filePath) { export async function refreshSingleLoraMetadata(filePath) {
try { const success = await refreshSingleModelMetadata(filePath, 'lora');
state.loadingManager.showSimpleLoading('Refreshing metadata...'); if (success) {
const response = await fetch('/api/fetch-civitai', { // Reload the current view to show updated data
method: 'POST', await resetAndReload();
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({ file_path: filePath })
});
if (!response.ok) {
throw new Error('Failed to refresh metadata');
}
const data = await response.json();
if (data.success) {
showToast('Metadata refreshed successfully', 'success');
// Reload the current view to show updated data
await resetAndReload();
} else {
throw new Error(data.error || 'Failed to refresh metadata');
}
} catch (error) {
console.error('Error refreshing metadata:', error);
showToast(error.message, 'error');
} finally {
state.loadingManager.hide();
state.loadingManager.restoreProgressBar();
} }
} }

View File

@@ -1,36 +1,59 @@
import { appCore } from './core.js'; import { appCore } from './core.js';
import { state, initPageState } from './state/index.js'; import { initializeInfiniteScroll } from './utils/infiniteScroll.js';
import { confirmDelete, closeDeleteModal } from './utils/modalUtils.js';
import { createPageControls } from './components/controls/index.js';
import { loadMoreCheckpoints } from './api/checkpointApi.js';
import { CheckpointDownloadManager } from './managers/CheckpointDownloadManager.js';
import { CheckpointContextMenu } from './components/ContextMenu/index.js';
// Initialize the Checkpoints page // Initialize the Checkpoints page
class CheckpointsPageManager { class CheckpointsPageManager {
constructor() { constructor() {
// Initialize any necessary state // Initialize page controls
this.initialized = false; this.pageControls = createPageControls('checkpoints');
// Initialize checkpoint download manager
window.checkpointDownloadManager = new CheckpointDownloadManager();
// Expose only necessary functions to global scope
this._exposeRequiredGlobalFunctions();
}
_exposeRequiredGlobalFunctions() {
// Minimal set of functions that need to remain global
window.confirmDelete = confirmDelete;
window.closeDeleteModal = closeDeleteModal;
// Add loadCheckpoints function to window for FilterManager compatibility
window.checkpointManager = {
loadCheckpoints: (reset) => loadMoreCheckpoints(reset)
};
} }
async initialize() { async initialize() {
if (this.initialized) return;
// Initialize page state
initPageState('checkpoints');
// Initialize core application
await appCore.initialize();
// Initialize page-specific components // Initialize page-specific components
this._initializeWorkInProgress(); this.pageControls.restoreFolderFilter();
this.pageControls.initFolderTagsVisibility();
this.initialized = true; // Initialize context menu
} new CheckpointContextMenu();
_initializeWorkInProgress() { // Initialize infinite scroll
// Add any work-in-progress specific initialization here initializeInfiniteScroll('checkpoints');
console.log('Checkpoints Manager is under development');
// Initialize common page features
appCore.initializePageFeatures();
console.log('Checkpoints Manager initialized');
} }
} }
// Initialize everything when DOM is ready // Initialize everything when DOM is ready
document.addEventListener('DOMContentLoaded', async () => { document.addEventListener('DOMContentLoaded', async () => {
// Initialize core application
await appCore.initialize();
// Initialize checkpoints page
const checkpointsPage = new CheckpointsPageManager(); const checkpointsPage = new CheckpointsPageManager();
await checkpointsPage.initialize(); await checkpointsPage.initialize();
}); });

View File

@@ -0,0 +1,302 @@
import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
import { state } from '../state/index.js';
import { showCheckpointModal } from './checkpointModal/index.js';
import { NSFW_LEVELS } from '../utils/constants.js';
import { replaceCheckpointPreview as apiReplaceCheckpointPreview } from '../api/checkpointApi.js';
export function createCheckpointCard(checkpoint) {
const card = document.createElement('div');
card.className = 'lora-card'; // Reuse the same class for styling
card.dataset.sha256 = checkpoint.sha256;
card.dataset.filepath = checkpoint.file_path;
card.dataset.name = checkpoint.model_name;
card.dataset.file_name = checkpoint.file_name;
card.dataset.folder = checkpoint.folder;
card.dataset.modified = checkpoint.modified;
card.dataset.file_size = checkpoint.file_size;
card.dataset.from_civitai = checkpoint.from_civitai;
card.dataset.notes = checkpoint.notes || '';
card.dataset.base_model = checkpoint.base_model || 'Unknown';
// Store metadata if available
if (checkpoint.civitai) {
card.dataset.meta = JSON.stringify(checkpoint.civitai || {});
}
// Store tags if available
if (checkpoint.tags && Array.isArray(checkpoint.tags)) {
card.dataset.tags = JSON.stringify(checkpoint.tags);
}
if (checkpoint.modelDescription) {
card.dataset.modelDescription = checkpoint.modelDescription;
}
// Store NSFW level if available
const nsfwLevel = checkpoint.preview_nsfw_level !== undefined ? checkpoint.preview_nsfw_level : 0;
card.dataset.nsfwLevel = nsfwLevel;
// Determine if the preview should be blurred based on NSFW level and user settings
const shouldBlur = state.settings.blurMatureContent && nsfwLevel > NSFW_LEVELS.PG13;
if (shouldBlur) {
card.classList.add('nsfw-content');
}
// Determine preview URL
const previewUrl = checkpoint.preview_url || '/loras_static/images/no-preview.png';
// Get the page-specific previewVersions map
const previewVersions = state.pages.checkpoints.previewVersions || new Map();
const version = previewVersions.get(checkpoint.file_path);
const versionedPreviewUrl = version ? `${previewUrl}?t=${version}` : previewUrl;
// Determine NSFW warning text based on level
let nsfwText = "Mature Content";
if (nsfwLevel >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (nsfwLevel >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (nsfwLevel >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
// Check if autoplayOnHover is enabled for video previews
const autoplayOnHover = state.global?.settings?.autoplayOnHover || false;
const isVideo = previewUrl.endsWith('.mp4');
const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop';
card.innerHTML = `
<div class="card-preview ${shouldBlur ? 'blurred' : ''}">
${isVideo ?
`<video ${videoAttrs}>
<source src="${versionedPreviewUrl}" type="video/mp4">
</video>` :
`<img src="${versionedPreviewUrl}" alt="${checkpoint.model_name}">`
}
<div class="card-header">
${shouldBlur ?
`<button class="toggle-blur-btn" title="Toggle blur">
<i class="fas fa-eye"></i>
</button>` : ''}
<span class="base-model-label ${shouldBlur ? 'with-toggle' : ''}" title="${checkpoint.base_model}">
${checkpoint.base_model}
</span>
<div class="card-actions">
<i class="fas fa-globe"
title="${checkpoint.from_civitai ? 'View on Civitai' : 'Not available from Civitai'}"
${!checkpoint.from_civitai ? 'style="opacity: 0.5; cursor: not-allowed"' : ''}>
</i>
<i class="fas fa-copy"
title="Copy Checkpoint Name">
</i>
<i class="fas fa-trash"
title="Delete Model">
</i>
</div>
</div>
${shouldBlur ? `
<div class="nsfw-overlay">
<div class="nsfw-warning">
<p>${nsfwText}</p>
<button class="show-content-btn">Show</button>
</div>
</div>
` : ''}
<div class="card-footer">
<div class="model-info">
<span class="model-name">${checkpoint.model_name}</span>
</div>
<div class="card-actions">
<i class="fas fa-image"
title="Replace Preview Image">
</i>
</div>
</div>
</div>
`;
// Main card click event
card.addEventListener('click', () => {
// Show checkpoint details modal
const checkpointMeta = {
sha256: card.dataset.sha256,
file_path: card.dataset.filepath,
model_name: card.dataset.name,
file_name: card.dataset.file_name,
folder: card.dataset.folder,
modified: card.dataset.modified,
file_size: parseInt(card.dataset.file_size || '0'),
from_civitai: card.dataset.from_civitai === 'true',
base_model: card.dataset.base_model,
notes: card.dataset.notes || '',
preview_url: versionedPreviewUrl,
// Parse civitai metadata from the card's dataset
civitai: (() => {
try {
return JSON.parse(card.dataset.meta || '{}');
} catch (e) {
console.error('Failed to parse civitai metadata:', e);
return {}; // Return empty object on error
}
})(),
tags: (() => {
try {
return JSON.parse(card.dataset.tags || '[]');
} catch (e) {
console.error('Failed to parse tags:', e);
return []; // Return empty array on error
}
})(),
modelDescription: card.dataset.modelDescription || ''
};
showCheckpointModal(checkpointMeta);
});
// Toggle blur button functionality
const toggleBlurBtn = card.querySelector('.toggle-blur-btn');
if (toggleBlurBtn) {
toggleBlurBtn.addEventListener('click', (e) => {
e.stopPropagation();
const preview = card.querySelector('.card-preview');
const isBlurred = preview.classList.toggle('blurred');
const icon = toggleBlurBtn.querySelector('i');
// Update the icon based on blur state
if (isBlurred) {
icon.className = 'fas fa-eye';
} else {
icon.className = 'fas fa-eye-slash';
}
// Toggle the overlay visibility
const overlay = card.querySelector('.nsfw-overlay');
if (overlay) {
overlay.style.display = isBlurred ? 'flex' : 'none';
}
});
}
// Show content button functionality
const showContentBtn = card.querySelector('.show-content-btn');
if (showContentBtn) {
showContentBtn.addEventListener('click', (e) => {
e.stopPropagation();
const preview = card.querySelector('.card-preview');
preview.classList.remove('blurred');
// Update the toggle button icon
const toggleBtn = card.querySelector('.toggle-blur-btn');
if (toggleBtn) {
toggleBtn.querySelector('i').className = 'fas fa-eye-slash';
}
// Hide the overlay
const overlay = card.querySelector('.nsfw-overlay');
if (overlay) {
overlay.style.display = 'none';
}
});
}
// Copy button click event
card.querySelector('.fa-copy')?.addEventListener('click', async e => {
e.stopPropagation();
const checkpointName = card.dataset.file_name;
try {
await copyToClipboard(checkpointName, 'Checkpoint name copied');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');
}
});
// Civitai button click event
if (checkpoint.from_civitai) {
card.querySelector('.fa-globe')?.addEventListener('click', e => {
e.stopPropagation();
openCivitai(checkpoint.model_name);
});
}
// Delete button click event
card.querySelector('.fa-trash')?.addEventListener('click', e => {
e.stopPropagation();
deleteCheckpoint(checkpoint.file_path);
});
// Replace preview button click event
card.querySelector('.fa-image')?.addEventListener('click', e => {
e.stopPropagation();
replaceCheckpointPreview(checkpoint.file_path);
});
// Add autoplayOnHover handlers for video elements if needed
const videoElement = card.querySelector('video');
if (videoElement && autoplayOnHover) {
const cardPreview = card.querySelector('.card-preview');
// Remove autoplay attribute and pause initially
videoElement.removeAttribute('autoplay');
videoElement.pause();
// Add mouse events to trigger play/pause
cardPreview.addEventListener('mouseenter', () => {
videoElement.play();
});
cardPreview.addEventListener('mouseleave', () => {
videoElement.pause();
videoElement.currentTime = 0;
});
}
return card;
}
// These functions will be implemented in checkpointApi.js
function openCivitai(modelName) {
// Check if the global function exists (registered by PageControls)
if (window.openCivitai) {
window.openCivitai(modelName);
} else {
// Fallback implementation
const card = document.querySelector(`.lora-card[data-name="${modelName}"]`);
if (!card) return;
const metaData = JSON.parse(card.dataset.meta || '{}');
const civitaiId = metaData.modelId;
const versionId = metaData.id;
// Build URL
if (civitaiId) {
let url = `https://civitai.com/models/${civitaiId}`;
if (versionId) {
url += `?modelVersionId=${versionId}`;
}
window.open(url, '_blank');
} else {
// If no ID, try searching by name
window.open(`https://civitai.com/models?query=${encodeURIComponent(modelName)}`, '_blank');
}
}
}
function deleteCheckpoint(filePath) {
if (window.deleteCheckpoint) {
window.deleteCheckpoint(filePath);
} else {
// Use the modal delete functionality
import('../utils/modalUtils.js').then(({ showDeleteModal }) => {
showDeleteModal(filePath, 'checkpoint');
});
}
}
function replaceCheckpointPreview(filePath) {
if (window.replaceCheckpointPreview) {
window.replaceCheckpointPreview(filePath);
} else {
apiReplaceCheckpointPreview(filePath);
}
}

View File

@@ -130,7 +130,7 @@ export class LoraContextMenu {
} }
async saveModelMetadata(filePath, data) { async saveModelMetadata(filePath, data) {
const response = await fetch('/loras/api/save-metadata', { const response = await fetch('/api/loras/save-metadata', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
@@ -366,4 +366,7 @@ export class LoraContextMenu {
this.menu.style.display = 'none'; this.menu.style.display = 'none';
this.currentCard = null; this.currentCard = null;
} }
} }
// For backward compatibility, re-export the LoraContextMenu class
// export { LoraContextMenu } from './ContextMenu/LoraContextMenu.js';

View File

@@ -0,0 +1,84 @@
export class BaseContextMenu {
constructor(menuId, cardSelector) {
this.menu = document.getElementById(menuId);
this.cardSelector = cardSelector;
this.currentCard = null;
if (!this.menu) {
console.error(`Context menu element with ID ${menuId} not found`);
return;
}
this.init();
}
init() {
// Hide menu on regular clicks
document.addEventListener('click', () => this.hideMenu());
// Show menu on right-click on cards
document.addEventListener('contextmenu', (e) => {
const card = e.target.closest(this.cardSelector);
if (!card) {
this.hideMenu();
return;
}
e.preventDefault();
this.showMenu(e.clientX, e.clientY, card);
});
// Handle menu item clicks
this.menu.addEventListener('click', (e) => {
const menuItem = e.target.closest('.context-menu-item');
if (!menuItem || !this.currentCard) return;
const action = menuItem.dataset.action;
if (!action) return;
this.handleMenuAction(action, menuItem);
this.hideMenu();
});
}
handleMenuAction(action, menuItem) {
// Override in subclass
console.warn('handleMenuAction not implemented');
}
showMenu(x, y, card) {
this.currentCard = card;
this.menu.style.display = 'block';
// Get menu dimensions
const menuRect = this.menu.getBoundingClientRect();
// Get viewport dimensions
const viewportWidth = document.documentElement.clientWidth;
const viewportHeight = document.documentElement.clientHeight;
// Calculate position
let finalX = x;
let finalY = y;
// Ensure menu doesn't go offscreen right
if (x + menuRect.width > viewportWidth) {
finalX = x - menuRect.width;
}
// Ensure menu doesn't go offscreen bottom
if (y + menuRect.height > viewportHeight) {
finalY = y - menuRect.height;
}
// Position menu
this.menu.style.left = `${finalX}px`;
this.menu.style.top = `${finalY}px`;
}
hideMenu() {
if (this.menu) {
this.menu.style.display = 'none';
}
this.currentCard = null;
}
}

View File

@@ -0,0 +1,315 @@
import { BaseContextMenu } from './BaseContextMenu.js';
import { refreshSingleCheckpointMetadata, saveCheckpointMetadata } from '../../api/checkpointApi.js';
import { showToast, getNSFWLevelName } from '../../utils/uiHelpers.js';
import { NSFW_LEVELS } from '../../utils/constants.js';
import { getStorageItem } from '../../utils/storageHelpers.js';
export class CheckpointContextMenu extends BaseContextMenu {
constructor() {
super('checkpointContextMenu', '.lora-card');
this.nsfwSelector = document.getElementById('nsfwLevelSelector');
// Initialize NSFW Level Selector events
if (this.nsfwSelector) {
this.initNSFWSelector();
}
}
handleMenuAction(action) {
switch(action) {
case 'details':
// Show checkpoint details
this.currentCard.click();
break;
case 'preview':
// Replace checkpoint preview
if (this.currentCard.querySelector('.fa-image')) {
this.currentCard.querySelector('.fa-image').click();
}
break;
case 'civitai':
// Open civitai page
if (this.currentCard.dataset.from_civitai === 'true') {
if (this.currentCard.querySelector('.fa-globe')) {
this.currentCard.querySelector('.fa-globe').click();
}
} else {
showToast('No CivitAI information available', 'info');
}
break;
case 'delete':
// Delete checkpoint
if (this.currentCard.querySelector('.fa-trash')) {
this.currentCard.querySelector('.fa-trash').click();
}
break;
case 'copyname':
// Copy checkpoint name
if (this.currentCard.querySelector('.fa-copy')) {
this.currentCard.querySelector('.fa-copy').click();
}
break;
case 'refresh-metadata':
// Refresh metadata from CivitAI
refreshSingleCheckpointMetadata(this.currentCard.dataset.filepath);
break;
case 'set-nsfw':
// Set NSFW level
this.showNSFWLevelSelector(null, null, this.currentCard);
break;
case 'move':
// Move to folder (placeholder)
showToast('Move to folder feature coming soon', 'info');
break;
}
}
// NSFW Selector methods
initNSFWSelector() {
// Close button
const closeBtn = this.nsfwSelector.querySelector('.close-nsfw-selector');
closeBtn.addEventListener('click', () => {
this.nsfwSelector.style.display = 'none';
});
// Level buttons
const levelButtons = this.nsfwSelector.querySelectorAll('.nsfw-level-btn');
levelButtons.forEach(btn => {
btn.addEventListener('click', async () => {
const level = parseInt(btn.dataset.level);
const filePath = this.nsfwSelector.dataset.cardPath;
if (!filePath) return;
try {
await saveCheckpointMetadata(filePath, { preview_nsfw_level: level });
// Update card data
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (card) {
let metaData = {};
try {
metaData = JSON.parse(card.dataset.meta || '{}');
} catch (err) {
console.error('Error parsing metadata:', err);
}
metaData.preview_nsfw_level = level;
card.dataset.meta = JSON.stringify(metaData);
card.dataset.nsfwLevel = level.toString();
// Apply blur effect immediately
this.updateCardBlurEffect(card, level);
}
showToast(`Content rating set to ${getNSFWLevelName(level)}`, 'success');
this.nsfwSelector.style.display = 'none';
} catch (error) {
showToast(`Failed to set content rating: ${error.message}`, 'error');
}
});
});
// Close when clicking outside
document.addEventListener('click', (e) => {
if (this.nsfwSelector.style.display === 'block' &&
!this.nsfwSelector.contains(e.target) &&
!e.target.closest('.context-menu-item[data-action="set-nsfw"]')) {
this.nsfwSelector.style.display = 'none';
}
});
}
updateCardBlurEffect(card, level) {
// Get user settings for blur threshold
const blurThreshold = parseInt(getStorageItem('nsfwBlurLevel') || '4');
// Get card preview container
const previewContainer = card.querySelector('.card-preview');
if (!previewContainer) return;
// Get preview media element
const previewMedia = previewContainer.querySelector('img') || previewContainer.querySelector('video');
if (!previewMedia) return;
// Check if blur should be applied
if (level >= blurThreshold) {
// Add blur class to the preview container
previewContainer.classList.add('blurred');
// Get or create the NSFW overlay
let nsfwOverlay = previewContainer.querySelector('.nsfw-overlay');
if (!nsfwOverlay) {
// Create new overlay
nsfwOverlay = document.createElement('div');
nsfwOverlay.className = 'nsfw-overlay';
// Create and configure the warning content
const warningContent = document.createElement('div');
warningContent.className = 'nsfw-warning';
// Determine NSFW warning text based on level
let nsfwText = "Mature Content";
if (level >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (level >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (level >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
// Add warning text and show button
warningContent.innerHTML = `
<p>${nsfwText}</p>
<button class="show-content-btn">Show</button>
`;
// Add click event to the show button
const showBtn = warningContent.querySelector('.show-content-btn');
showBtn.addEventListener('click', (e) => {
e.stopPropagation();
previewContainer.classList.remove('blurred');
nsfwOverlay.style.display = 'none';
// Update toggle button icon if it exists
const toggleBtn = card.querySelector('.toggle-blur-btn');
if (toggleBtn) {
toggleBtn.querySelector('i').className = 'fas fa-eye-slash';
}
});
nsfwOverlay.appendChild(warningContent);
previewContainer.appendChild(nsfwOverlay);
} else {
// Update existing overlay
const warningText = nsfwOverlay.querySelector('p');
if (warningText) {
let nsfwText = "Mature Content";
if (level >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (level >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (level >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
warningText.textContent = nsfwText;
}
nsfwOverlay.style.display = 'flex';
}
// Get or create the toggle button in the header
const cardHeader = previewContainer.querySelector('.card-header');
if (cardHeader) {
let toggleBtn = cardHeader.querySelector('.toggle-blur-btn');
if (!toggleBtn) {
toggleBtn = document.createElement('button');
toggleBtn.className = 'toggle-blur-btn';
toggleBtn.title = 'Toggle blur';
toggleBtn.innerHTML = '<i class="fas fa-eye"></i>';
// Add click event to toggle button
toggleBtn.addEventListener('click', (e) => {
e.stopPropagation();
const isBlurred = previewContainer.classList.toggle('blurred');
const icon = toggleBtn.querySelector('i');
// Update icon and overlay visibility
if (isBlurred) {
icon.className = 'fas fa-eye';
nsfwOverlay.style.display = 'flex';
} else {
icon.className = 'fas fa-eye-slash';
nsfwOverlay.style.display = 'none';
}
});
// Add to the beginning of header
cardHeader.insertBefore(toggleBtn, cardHeader.firstChild);
// Update base model label class
const baseModelLabel = cardHeader.querySelector('.base-model-label');
if (baseModelLabel && !baseModelLabel.classList.contains('with-toggle')) {
baseModelLabel.classList.add('with-toggle');
}
} else {
// Update existing toggle button
toggleBtn.querySelector('i').className = 'fas fa-eye';
}
}
} else {
// Remove blur
previewContainer.classList.remove('blurred');
// Hide overlay if it exists
const overlay = previewContainer.querySelector('.nsfw-overlay');
if (overlay) overlay.style.display = 'none';
// Remove toggle button when content is set to PG or PG13
const cardHeader = previewContainer.querySelector('.card-header');
if (cardHeader) {
const toggleBtn = cardHeader.querySelector('.toggle-blur-btn');
if (toggleBtn) {
// Remove the toggle button completely
toggleBtn.remove();
// Update base model label class if it exists
const baseModelLabel = cardHeader.querySelector('.base-model-label');
if (baseModelLabel && baseModelLabel.classList.contains('with-toggle')) {
baseModelLabel.classList.remove('with-toggle');
}
}
}
}
}
showNSFWLevelSelector(x, y, card) {
const selector = document.getElementById('nsfwLevelSelector');
const currentLevelEl = document.getElementById('currentNSFWLevel');
// Get current NSFW level
let currentLevel = 0;
try {
const metaData = JSON.parse(card.dataset.meta || '{}');
currentLevel = metaData.preview_nsfw_level || 0;
// Update if we have no recorded level but have a dataset attribute
if (!currentLevel && card.dataset.nsfwLevel) {
currentLevel = parseInt(card.dataset.nsfwLevel) || 0;
}
} catch (err) {
console.error('Error parsing metadata:', err);
}
currentLevelEl.textContent = getNSFWLevelName(currentLevel);
// Position the selector
if (x && y) {
const viewportWidth = document.documentElement.clientWidth;
const viewportHeight = document.documentElement.clientHeight;
const selectorRect = selector.getBoundingClientRect();
// Center the selector if no coordinates provided
let finalX = (viewportWidth - selectorRect.width) / 2;
let finalY = (viewportHeight - selectorRect.height) / 2;
selector.style.left = `${finalX}px`;
selector.style.top = `${finalY}px`;
}
// Highlight current level button
document.querySelectorAll('.nsfw-level-btn').forEach(btn => {
if (parseInt(btn.dataset.level) === currentLevel) {
btn.classList.add('active');
} else {
btn.classList.remove('active');
}
});
// Store reference to current card
selector.dataset.cardPath = card.dataset.filepath;
// Show selector
selector.style.display = 'block';
}
}

View File

@@ -0,0 +1,324 @@
import { BaseContextMenu } from './BaseContextMenu.js';
import { refreshSingleLoraMetadata } from '../../api/loraApi.js';
import { showToast, getNSFWLevelName } from '../../utils/uiHelpers.js';
import { NSFW_LEVELS } from '../../utils/constants.js';
import { getStorageItem } from '../../utils/storageHelpers.js';
export class LoraContextMenu extends BaseContextMenu {
constructor() {
super('loraContextMenu', '.lora-card');
this.nsfwSelector = document.getElementById('nsfwLevelSelector');
// Initialize NSFW Level Selector events
if (this.nsfwSelector) {
this.initNSFWSelector();
}
}
handleMenuAction(action, menuItem) {
switch(action) {
case 'detail':
// Trigger the main card click which shows the modal
this.currentCard.click();
break;
case 'civitai':
// Only trigger if the card is from civitai
if (this.currentCard.dataset.from_civitai === 'true') {
if (this.currentCard.dataset.meta === '{}') {
showToast('Please fetch metadata from CivitAI first', 'info');
} else {
this.currentCard.querySelector('.fa-globe')?.click();
}
} else {
showToast('No CivitAI information available', 'info');
}
break;
case 'copyname':
this.currentCard.querySelector('.fa-copy')?.click();
break;
case 'preview':
this.currentCard.querySelector('.fa-image')?.click();
break;
case 'delete':
this.currentCard.querySelector('.fa-trash')?.click();
break;
case 'move':
moveManager.showMoveModal(this.currentCard.dataset.filepath);
break;
case 'refresh-metadata':
refreshSingleLoraMetadata(this.currentCard.dataset.filepath);
break;
case 'set-nsfw':
this.showNSFWLevelSelector(null, null, this.currentCard);
break;
}
}
// NSFW Selector methods from the original context menu
initNSFWSelector() {
// Close button
const closeBtn = this.nsfwSelector.querySelector('.close-nsfw-selector');
closeBtn.addEventListener('click', () => {
this.nsfwSelector.style.display = 'none';
});
// Level buttons
const levelButtons = this.nsfwSelector.querySelectorAll('.nsfw-level-btn');
levelButtons.forEach(btn => {
btn.addEventListener('click', async () => {
const level = parseInt(btn.dataset.level);
const filePath = this.nsfwSelector.dataset.cardPath;
if (!filePath) return;
try {
await this.saveModelMetadata(filePath, { preview_nsfw_level: level });
// Update card data
const card = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (card) {
let metaData = {};
try {
metaData = JSON.parse(card.dataset.meta || '{}');
} catch (err) {
console.error('Error parsing metadata:', err);
}
metaData.preview_nsfw_level = level;
card.dataset.meta = JSON.stringify(metaData);
card.dataset.nsfwLevel = level.toString();
// Apply blur effect immediately
this.updateCardBlurEffect(card, level);
}
showToast(`Content rating set to ${getNSFWLevelName(level)}`, 'success');
this.nsfwSelector.style.display = 'none';
} catch (error) {
showToast(`Failed to set content rating: ${error.message}`, 'error');
}
});
});
// Close when clicking outside
document.addEventListener('click', (e) => {
if (this.nsfwSelector.style.display === 'block' &&
!this.nsfwSelector.contains(e.target) &&
!e.target.closest('.context-menu-item[data-action="set-nsfw"]')) {
this.nsfwSelector.style.display = 'none';
}
});
}
async saveModelMetadata(filePath, data) {
const response = await fetch('/api/loras/save-metadata', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath,
...data
})
});
if (!response.ok) {
throw new Error('Failed to save metadata');
}
return await response.json();
}
updateCardBlurEffect(card, level) {
// Get user settings for blur threshold
const blurThreshold = parseInt(getStorageItem('nsfwBlurLevel') || '4');
// Get card preview container
const previewContainer = card.querySelector('.card-preview');
if (!previewContainer) return;
// Get preview media element
const previewMedia = previewContainer.querySelector('img') || previewContainer.querySelector('video');
if (!previewMedia) return;
// Check if blur should be applied
if (level >= blurThreshold) {
// Add blur class to the preview container
previewContainer.classList.add('blurred');
// Get or create the NSFW overlay
let nsfwOverlay = previewContainer.querySelector('.nsfw-overlay');
if (!nsfwOverlay) {
// Create new overlay
nsfwOverlay = document.createElement('div');
nsfwOverlay.className = 'nsfw-overlay';
// Create and configure the warning content
const warningContent = document.createElement('div');
warningContent.className = 'nsfw-warning';
// Determine NSFW warning text based on level
let nsfwText = "Mature Content";
if (level >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (level >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (level >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
// Add warning text and show button
warningContent.innerHTML = `
<p>${nsfwText}</p>
<button class="show-content-btn">Show</button>
`;
// Add click event to the show button
const showBtn = warningContent.querySelector('.show-content-btn');
showBtn.addEventListener('click', (e) => {
e.stopPropagation();
previewContainer.classList.remove('blurred');
nsfwOverlay.style.display = 'none';
// Update toggle button icon if it exists
const toggleBtn = card.querySelector('.toggle-blur-btn');
if (toggleBtn) {
toggleBtn.querySelector('i').className = 'fas fa-eye-slash';
}
});
nsfwOverlay.appendChild(warningContent);
previewContainer.appendChild(nsfwOverlay);
} else {
// Update existing overlay
const warningText = nsfwOverlay.querySelector('p');
if (warningText) {
let nsfwText = "Mature Content";
if (level >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (level >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (level >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
warningText.textContent = nsfwText;
}
nsfwOverlay.style.display = 'flex';
}
// Get or create the toggle button in the header
const cardHeader = previewContainer.querySelector('.card-header');
if (cardHeader) {
let toggleBtn = cardHeader.querySelector('.toggle-blur-btn');
if (!toggleBtn) {
toggleBtn = document.createElement('button');
toggleBtn.className = 'toggle-blur-btn';
toggleBtn.title = 'Toggle blur';
toggleBtn.innerHTML = '<i class="fas fa-eye"></i>';
// Add click event to toggle button
toggleBtn.addEventListener('click', (e) => {
e.stopPropagation();
const isBlurred = previewContainer.classList.toggle('blurred');
const icon = toggleBtn.querySelector('i');
// Update icon and overlay visibility
if (isBlurred) {
icon.className = 'fas fa-eye';
nsfwOverlay.style.display = 'flex';
} else {
icon.className = 'fas fa-eye-slash';
nsfwOverlay.style.display = 'none';
}
});
// Add to the beginning of header
cardHeader.insertBefore(toggleBtn, cardHeader.firstChild);
// Update base model label class
const baseModelLabel = cardHeader.querySelector('.base-model-label');
if (baseModelLabel && !baseModelLabel.classList.contains('with-toggle')) {
baseModelLabel.classList.add('with-toggle');
}
} else {
// Update existing toggle button
toggleBtn.querySelector('i').className = 'fas fa-eye';
}
}
} else {
// Remove blur
previewContainer.classList.remove('blurred');
// Hide overlay if it exists
const overlay = previewContainer.querySelector('.nsfw-overlay');
if (overlay) overlay.style.display = 'none';
// Remove toggle button when content is set to PG or PG13
const cardHeader = previewContainer.querySelector('.card-header');
if (cardHeader) {
const toggleBtn = cardHeader.querySelector('.toggle-blur-btn');
if (toggleBtn) {
// Remove the toggle button completely
toggleBtn.remove();
// Update base model label class if it exists
const baseModelLabel = cardHeader.querySelector('.base-model-label');
if (baseModelLabel && baseModelLabel.classList.contains('with-toggle')) {
baseModelLabel.classList.remove('with-toggle');
}
}
}
}
}
showNSFWLevelSelector(x, y, card) {
const selector = document.getElementById('nsfwLevelSelector');
const currentLevelEl = document.getElementById('currentNSFWLevel');
// Get current NSFW level
let currentLevel = 0;
try {
const metaData = JSON.parse(card.dataset.meta || '{}');
currentLevel = metaData.preview_nsfw_level || 0;
// Update if we have no recorded level but have a dataset attribute
if (!currentLevel && card.dataset.nsfwLevel) {
currentLevel = parseInt(card.dataset.nsfwLevel) || 0;
}
} catch (err) {
console.error('Error parsing metadata:', err);
}
currentLevelEl.textContent = getNSFWLevelName(currentLevel);
// Position the selector
if (x && y) {
const viewportWidth = document.documentElement.clientWidth;
const viewportHeight = document.documentElement.clientHeight;
const selectorRect = selector.getBoundingClientRect();
// Center the selector if no coordinates provided
let finalX = (viewportWidth - selectorRect.width) / 2;
let finalY = (viewportHeight - selectorRect.height) / 2;
selector.style.left = `${finalX}px`;
selector.style.top = `${finalY}px`;
}
// Highlight current level button
document.querySelectorAll('.nsfw-level-btn').forEach(btn => {
if (parseInt(btn.dataset.level) === currentLevel) {
btn.classList.add('active');
} else {
btn.classList.remove('active');
}
});
// Store reference to current card
selector.dataset.cardPath = card.dataset.filepath;
// Show selector
selector.style.display = 'block';
}
}

View File

@@ -0,0 +1,205 @@
import { BaseContextMenu } from './BaseContextMenu.js';
import { showToast } from '../../utils/uiHelpers.js';
import { setSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
import { state } from '../../state/index.js';
export class RecipeContextMenu extends BaseContextMenu {
constructor() {
super('recipeContextMenu', '.lora-card');
}
showMenu(x, y, card) {
// Call the parent method first to handle basic positioning
super.showMenu(x, y, card);
// Get recipe data to check for missing LoRAs
const recipeId = card.dataset.id;
const missingLorasItem = this.menu.querySelector('.download-missing-item');
if (recipeId && missingLorasItem) {
// Check if this card has missing LoRAs
const loraCountElement = card.querySelector('.lora-count');
const hasMissingLoras = loraCountElement && loraCountElement.classList.contains('missing');
// Show/hide the download missing LoRAs option based on missing status
if (hasMissingLoras) {
missingLorasItem.style.display = 'flex';
} else {
missingLorasItem.style.display = 'none';
}
}
}
handleMenuAction(action) {
const recipeId = this.currentCard.dataset.id;
switch(action) {
case 'details':
// Show recipe details
this.currentCard.click();
break;
case 'copy':
// Copy recipe to clipboard
this.currentCard.querySelector('.fa-copy')?.click();
break;
case 'share':
// Share recipe
this.currentCard.querySelector('.fa-share-alt')?.click();
break;
case 'delete':
// Delete recipe
this.currentCard.querySelector('.fa-trash')?.click();
break;
case 'viewloras':
// View all LoRAs in the recipe
this.viewRecipeLoRAs(recipeId);
break;
case 'download-missing':
// Download missing LoRAs
this.downloadMissingLoRAs(recipeId);
break;
}
}
// View all LoRAs in the recipe
viewRecipeLoRAs(recipeId) {
if (!recipeId) {
showToast('Cannot view LoRAs: Missing recipe ID', 'error');
return;
}
// First get the recipe details to access its LoRAs
fetch(`/api/recipe/${recipeId}`)
.then(response => response.json())
.then(recipe => {
// Clear any previous filters first
removeSessionItem('recipe_to_lora_filterLoraHash');
removeSessionItem('recipe_to_lora_filterLoraHashes');
removeSessionItem('filterRecipeName');
removeSessionItem('viewLoraDetail');
// Collect all hashes from the recipe's LoRAs
const loraHashes = recipe.loras
.filter(lora => lora.hash)
.map(lora => lora.hash.toLowerCase());
if (loraHashes.length > 0) {
// Store the LoRA hashes and recipe name in session storage
setSessionItem('recipe_to_lora_filterLoraHashes', JSON.stringify(loraHashes));
setSessionItem('filterRecipeName', recipe.title);
// Navigate to the LoRAs page
window.location.href = '/loras';
} else {
showToast('No LoRAs found in this recipe', 'info');
}
})
.catch(error => {
console.error('Error loading recipe LoRAs:', error);
showToast('Error loading recipe LoRAs: ' + error.message, 'error');
});
}
// Download missing LoRAs
async downloadMissingLoRAs(recipeId) {
if (!recipeId) {
showToast('Cannot download LoRAs: Missing recipe ID', 'error');
return;
}
try {
// First get the recipe details
const response = await fetch(`/api/recipe/${recipeId}`);
const recipe = await response.json();
// Get missing LoRAs
const missingLoras = recipe.loras.filter(lora => !lora.inLibrary && !lora.isDeleted);
if (missingLoras.length === 0) {
showToast('No missing LoRAs to download', 'info');
return;
}
// Show loading toast
state.loadingManager.showSimpleLoading('Getting version info for missing LoRAs...');
// Get version info for each missing LoRA
const missingLorasWithVersionInfoPromises = missingLoras.map(async lora => {
let endpoint;
// Determine which endpoint to use based on available data
if (lora.modelVersionId) {
endpoint = `/api/civitai/model/version/${lora.modelVersionId}`;
} else if (lora.hash) {
endpoint = `/api/civitai/model/hash/${lora.hash}`;
} else {
console.error("Missing both hash and modelVersionId for lora:", lora);
return null;
}
const versionResponse = await fetch(endpoint);
const versionInfo = await versionResponse.json();
// Return original lora data combined with version info
return {
...lora,
civitaiInfo: versionInfo
};
});
// Wait for all API calls to complete
const lorasWithVersionInfo = await Promise.all(missingLorasWithVersionInfoPromises);
// Filter out null values (failed requests)
const validLoras = lorasWithVersionInfo.filter(lora => lora !== null);
if (validLoras.length === 0) {
showToast('Failed to get information for missing LoRAs', 'error');
return;
}
// Prepare data for import manager using the retrieved information
const recipeData = {
loras: validLoras.map(lora => {
const civitaiInfo = lora.civitaiInfo;
const modelFile = civitaiInfo.files ?
civitaiInfo.files.find(file => file.type === 'Model') : null;
return {
// Basic lora info
name: civitaiInfo.model?.name || lora.name,
version: civitaiInfo.name || '',
strength: lora.strength || 1.0,
// Model identifiers
hash: modelFile?.hashes?.SHA256?.toLowerCase() || lora.hash,
modelVersionId: civitaiInfo.id || lora.modelVersionId,
// Metadata
thumbnailUrl: civitaiInfo.images?.[0]?.url || '',
baseModel: civitaiInfo.baseModel || '',
downloadUrl: civitaiInfo.downloadUrl || '',
size: modelFile ? (modelFile.sizeKB * 1024) : 0,
file_name: modelFile ? modelFile.name.split('.')[0] : '',
// Status flags
existsLocally: false,
isDeleted: civitaiInfo.error === "Model not found",
isEarlyAccess: !!civitaiInfo.earlyAccessEndsAt,
earlyAccessEndsAt: civitaiInfo.earlyAccessEndsAt || ''
};
})
};
// Call ImportManager's download missing LoRAs method
window.importManager.downloadMissingLoras(recipeData, recipeId);
} catch (error) {
console.error('Error downloading missing LoRAs:', error);
showToast('Error preparing LoRAs for download: ' + error.message, 'error');
} finally {
if (state.loadingManager) {
state.loadingManager.hide();
}
}
}
}

View File

@@ -0,0 +1,3 @@
export { LoraContextMenu } from './LoraContextMenu.js';
export { RecipeContextMenu } from './RecipeContextMenu.js';
export { CheckpointContextMenu } from './CheckpointContextMenu.js';

View File

@@ -1,8 +1,9 @@
import { showToast } from '../utils/uiHelpers.js'; import { showToast, openCivitai, copyToClipboard } from '../utils/uiHelpers.js';
import { state } from '../state/index.js'; import { state } from '../state/index.js';
import { showLoraModal } from './LoraModal.js'; import { showLoraModal } from './loraModal/index.js';
import { bulkManager } from '../managers/BulkManager.js'; import { bulkManager } from '../managers/BulkManager.js';
import { NSFW_LEVELS } from '../utils/constants.js'; import { NSFW_LEVELS } from '../utils/constants.js';
import { replacePreview, deleteModel } from '../api/loraApi.js'
export function createLoraCard(lora) { export function createLoraCard(lora) {
const card = document.createElement('div'); const card = document.createElement('div');
@@ -43,7 +44,9 @@ export function createLoraCard(lora) {
card.classList.add('selected'); card.classList.add('selected');
} }
const version = state.previewVersions.get(lora.file_path); // Get the page-specific previewVersions map
const previewVersions = state.pages.loras.previewVersions || new Map();
const version = previewVersions.get(lora.file_path);
const previewUrl = lora.preview_url || '/loras_static/images/no-preview.png'; const previewUrl = lora.preview_url || '/loras_static/images/no-preview.png';
const versionedPreviewUrl = version ? `${previewUrl}?t=${version}` : previewUrl; const versionedPreviewUrl = version ? `${previewUrl}?t=${version}` : previewUrl;
@@ -57,10 +60,15 @@ export function createLoraCard(lora) {
nsfwText = "R-rated Content"; nsfwText = "R-rated Content";
} }
// Check if autoplayOnHover is enabled for video previews
const autoplayOnHover = state.global.settings.autoplayOnHover || false;
const isVideo = previewUrl.endsWith('.mp4');
const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop';
card.innerHTML = ` card.innerHTML = `
<div class="card-preview ${shouldBlur ? 'blurred' : ''}"> <div class="card-preview ${shouldBlur ? 'blurred' : ''}">
${previewUrl.endsWith('.mp4') ? ${isVideo ?
`<video controls autoplay muted loop> `<video ${videoAttrs}>
<source src="${versionedPreviewUrl}" type="video/mp4"> <source src="${versionedPreviewUrl}" type="video/mp4">
</video>` : </video>` :
`<img src="${versionedPreviewUrl}" alt="${lora.model_name}">` `<img src="${versionedPreviewUrl}" alt="${lora.model_name}">`
@@ -197,26 +205,7 @@ export function createLoraCard(lora) {
const strength = usageTips.strength || 1; const strength = usageTips.strength || 1;
const loraSyntax = `<lora:${card.dataset.file_name}:${strength}>`; const loraSyntax = `<lora:${card.dataset.file_name}:${strength}>`;
try { await copyToClipboard(loraSyntax, 'LoRA syntax copied');
// Modern clipboard API
if (navigator.clipboard && window.isSecureContext) {
await navigator.clipboard.writeText(loraSyntax);
} else {
// Fallback for older browsers
const textarea = document.createElement('textarea');
textarea.value = loraSyntax;
textarea.style.position = 'absolute';
textarea.style.left = '-99999px';
document.body.appendChild(textarea);
textarea.select();
document.execCommand('copy');
document.body.removeChild(textarea);
}
showToast('LoRA syntax copied', 'success');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');
}
}); });
// Civitai button click event // Civitai button click event
@@ -246,6 +235,26 @@ export function createLoraCard(lora) {
actionGroup.style.display = 'none'; actionGroup.style.display = 'none';
}); });
} }
// Add autoplayOnHover handlers for video elements if needed
const videoElement = card.querySelector('video');
if (videoElement && autoplayOnHover) {
const cardPreview = card.querySelector('.card-preview');
// Remove autoplay attribute and pause initially
videoElement.removeAttribute('autoplay');
videoElement.pause();
// Add mouse events to trigger play/pause
cardPreview.addEventListener('mouseenter', () => {
videoElement.play();
});
cardPreview.addEventListener('mouseleave', () => {
videoElement.pause();
videoElement.currentTime = 0;
});
}
return card; return card;
} }

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
// Recipe Card Component // Recipe Card Component
import { showToast } from '../utils/uiHelpers.js'; import { showToast, copyToClipboard } from '../utils/uiHelpers.js';
import { modalManager } from '../managers/ModalManager.js'; import { modalManager } from '../managers/ModalManager.js';
class RecipeCard { class RecipeCard {
@@ -25,7 +25,7 @@ class RecipeCard {
const lorasCount = loras.length; const lorasCount = loras.length;
// Check if all LoRAs are available in the library // Check if all LoRAs are available in the library
const missingLorasCount = loras.filter(lora => !lora.inLibrary).length; const missingLorasCount = loras.filter(lora => !lora.inLibrary && !lora.isDeleted).length;
const allLorasAvailable = missingLorasCount === 0 && lorasCount > 0; const allLorasAvailable = missingLorasCount === 0 && lorasCount > 0;
// Ensure file_url exists, fallback to file_path if needed // Ensure file_url exists, fallback to file_path if needed
@@ -96,24 +96,23 @@ class RecipeCard {
copyRecipeSyntax() { copyRecipeSyntax() {
try { try {
// Generate recipe syntax in the format <lora:file_name:strength> separated by spaces // Get recipe ID
const loras = this.recipe.loras || []; const recipeId = this.recipe.id;
if (loras.length === 0) { if (!recipeId) {
showToast('No LoRAs in this recipe to copy', 'warning'); showToast('Cannot copy recipe syntax: Missing recipe ID', 'error');
return; return;
} }
const syntax = loras.map(lora => { // Fallback if button not found
// Use file_name if available, otherwise use empty placeholder fetch(`/api/recipe/${recipeId}/syntax`)
const fileName = lora.file_name || '[missing-lora]'; .then(response => response.json())
const strength = lora.strength || 1.0; .then(data => {
return `<lora:${fileName}:${strength}>`; if (data.success && data.syntax) {
}).join(' '); return copyToClipboard(data.syntax, 'Recipe syntax copied to clipboard');
} else {
// Copy to clipboard throw new Error(data.error || 'No syntax returned');
navigator.clipboard.writeText(syntax) }
.then(() => {
showToast('Recipe syntax copied to clipboard', 'success');
}) })
.catch(err => { .catch(err => {
console.error('Failed to copy: ', err); console.error('Failed to copy: ', err);
@@ -277,4 +276,4 @@ class RecipeCard {
} }
} }
export { RecipeCard }; export { RecipeCard };

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,102 @@
/**
* ModelDescription.js
* Handles checkpoint model descriptions
*/
import { showToast } from '../../utils/uiHelpers.js';
/**
* Set up tab switching functionality
*/
export function setupTabSwitching() {
const tabButtons = document.querySelectorAll('.showcase-tabs .tab-btn');
tabButtons.forEach(button => {
button.addEventListener('click', () => {
// Remove active class from all tabs
document.querySelectorAll('.showcase-tabs .tab-btn').forEach(btn =>
btn.classList.remove('active')
);
document.querySelectorAll('.tab-content .tab-pane').forEach(tab =>
tab.classList.remove('active')
);
// Add active class to clicked tab
button.classList.add('active');
const tabId = `${button.dataset.tab}-tab`;
document.getElementById(tabId).classList.add('active');
// If switching to description tab, make sure content is properly loaded and displayed
if (button.dataset.tab === 'description') {
const descriptionContent = document.querySelector('.model-description-content');
if (descriptionContent) {
const hasContent = descriptionContent.innerHTML.trim() !== '';
document.querySelector('.model-description-loading')?.classList.add('hidden');
// If no content, show a message
if (!hasContent) {
descriptionContent.innerHTML = '<div class="no-description">No model description available</div>';
descriptionContent.classList.remove('hidden');
}
}
}
});
});
}
/**
* Load model description from API
* @param {string} modelId - The Civitai model ID
* @param {string} filePath - File path for the model
*/
export async function loadModelDescription(modelId, filePath) {
try {
const descriptionContainer = document.querySelector('.model-description-content');
const loadingElement = document.querySelector('.model-description-loading');
if (!descriptionContainer || !loadingElement) return;
// Show loading indicator
loadingElement.classList.remove('hidden');
descriptionContainer.classList.add('hidden');
// Try to get model description from API
const response = await fetch(`/api/checkpoint-model-description?model_id=${modelId}&file_path=${encodeURIComponent(filePath)}`);
if (!response.ok) {
throw new Error(`Failed to fetch model description: ${response.statusText}`);
}
const data = await response.json();
if (data.success && data.description) {
// Update the description content
descriptionContainer.innerHTML = data.description;
// Process any links in the description to open in new tab
const links = descriptionContainer.querySelectorAll('a');
links.forEach(link => {
link.setAttribute('target', '_blank');
link.setAttribute('rel', 'noopener noreferrer');
});
// Show the description and hide loading indicator
descriptionContainer.classList.remove('hidden');
loadingElement.classList.add('hidden');
} else {
throw new Error(data.error || 'No description available');
}
} catch (error) {
console.error('Error loading model description:', error);
const loadingElement = document.querySelector('.model-description-loading');
if (loadingElement) {
loadingElement.innerHTML = `<div class="error-message">Failed to load model description. ${error.message}</div>`;
}
// Show empty state message in the description container
const descriptionContainer = document.querySelector('.model-description-content');
if (descriptionContainer) {
descriptionContainer.innerHTML = '<div class="no-description">No model description available</div>';
descriptionContainer.classList.remove('hidden');
}
}
}

View File

@@ -0,0 +1,484 @@
/**
* ModelMetadata.js
* Handles checkpoint model metadata editing functionality
*/
import { showToast } from '../../utils/uiHelpers.js';
import { BASE_MODELS } from '../../utils/constants.js';
import { updateCheckpointCard } from '../../utils/cardUpdater.js';
/**
* Save model metadata to the server
* @param {string} filePath - Path to the model file
* @param {Object} data - Metadata to save
* @returns {Promise} - Promise that resolves with the server response
*/
export async function saveModelMetadata(filePath, data) {
const response = await fetch('/api/checkpoints/save-metadata', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath,
...data
})
});
if (!response.ok) {
throw new Error('Failed to save metadata');
}
return response.json();
}
/**
* Set up model name editing functionality
* @param {string} filePath - The full file path of the model.
*/
export function setupModelNameEditing(filePath) {
const modelNameContent = document.querySelector('.model-name-content');
const editBtn = document.querySelector('.edit-model-name-btn');
if (!modelNameContent || !editBtn) return;
// Show edit button on hover
const modelNameHeader = document.querySelector('.model-name-header');
modelNameHeader.addEventListener('mouseenter', () => {
editBtn.classList.add('visible');
});
modelNameHeader.addEventListener('mouseleave', () => {
if (!modelNameContent.getAttribute('data-editing')) {
editBtn.classList.remove('visible');
}
});
// Handle edit button click
editBtn.addEventListener('click', () => {
modelNameContent.setAttribute('data-editing', 'true');
modelNameContent.focus();
// Place cursor at the end
const range = document.createRange();
const sel = window.getSelection();
if (modelNameContent.childNodes.length > 0) {
range.setStart(modelNameContent.childNodes[0], modelNameContent.textContent.length);
range.collapse(true);
sel.removeAllRanges();
sel.addRange(range);
}
editBtn.classList.add('visible');
});
// Handle focus out
modelNameContent.addEventListener('blur', function() {
this.removeAttribute('data-editing');
editBtn.classList.remove('visible');
if (this.textContent.trim() === '') {
// Restore original model name if empty
// Use the passed filePath to find the card
const checkpointCard = document.querySelector(`.checkpoint-card[data-filepath="${filePath}"]`);
if (checkpointCard) {
this.textContent = checkpointCard.dataset.model_name;
}
}
});
// Handle enter key
modelNameContent.addEventListener('keydown', function(e) {
if (e.key === 'Enter') {
e.preventDefault();
// Use the passed filePath
saveModelName(filePath);
this.blur();
}
});
// Limit model name length
modelNameContent.addEventListener('input', function() {
if (this.textContent.length > 100) {
this.textContent = this.textContent.substring(0, 100);
// Place cursor at the end
const range = document.createRange();
const sel = window.getSelection();
range.setStart(this.childNodes[0], 100);
range.collapse(true);
sel.removeAllRanges();
sel.addRange(range);
showToast('Model name is limited to 100 characters', 'warning');
}
});
}
/**
* Save model name
* @param {string} filePath - File path
*/
async function saveModelName(filePath) {
const modelNameElement = document.querySelector('.model-name-content');
const newModelName = modelNameElement.textContent.trim();
// Validate model name
if (!newModelName) {
showToast('Model name cannot be empty', 'error');
return;
}
// Check if model name is too long
if (newModelName.length > 100) {
showToast('Model name is too long (maximum 100 characters)', 'error');
// Truncate the displayed text
modelNameElement.textContent = newModelName.substring(0, 100);
return;
}
try {
await saveModelMetadata(filePath, { model_name: newModelName });
// Update the card with the new model name
updateCheckpointCard(filePath, { name: newModelName });
showToast('Model name updated successfully', 'success');
// No need to reload the entire page
// setTimeout(() => {
// window.location.reload();
// }, 1500);
} catch (error) {
showToast('Failed to update model name', 'error');
}
}
/**
* Set up base model editing functionality
* @param {string} filePath - The full file path of the model.
*/
export function setupBaseModelEditing(filePath) {
const baseModelContent = document.querySelector('.base-model-content');
const editBtn = document.querySelector('.edit-base-model-btn');
if (!baseModelContent || !editBtn) return;
// Show edit button on hover
const baseModelDisplay = document.querySelector('.base-model-display');
baseModelDisplay.addEventListener('mouseenter', () => {
editBtn.classList.add('visible');
});
baseModelDisplay.addEventListener('mouseleave', () => {
if (!baseModelDisplay.classList.contains('editing')) {
editBtn.classList.remove('visible');
}
});
// Handle edit button click
editBtn.addEventListener('click', () => {
baseModelDisplay.classList.add('editing');
// Store the original value to check for changes later
const originalValue = baseModelContent.textContent.trim();
// Create dropdown selector to replace the base model content
const currentValue = originalValue;
const dropdown = document.createElement('select');
dropdown.className = 'base-model-selector';
// Flag to track if a change was made
let valueChanged = false;
// Add options from BASE_MODELS constants
const baseModelCategories = {
'Stable Diffusion 1.x': [BASE_MODELS.SD_1_4, BASE_MODELS.SD_1_5, BASE_MODELS.SD_1_5_LCM, BASE_MODELS.SD_1_5_HYPER],
'Stable Diffusion 2.x': [BASE_MODELS.SD_2_0, BASE_MODELS.SD_2_1],
'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO],
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
'Video Models': [BASE_MODELS.SVD, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO],
'Other Models': [
BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.AURAFLOW,
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.UNKNOWN
]
};
// Create option groups for better organization
Object.entries(baseModelCategories).forEach(([category, models]) => {
const group = document.createElement('optgroup');
group.label = category;
models.forEach(model => {
const option = document.createElement('option');
option.value = model;
option.textContent = model;
option.selected = model === currentValue;
group.appendChild(option);
});
dropdown.appendChild(group);
});
// Replace content with dropdown
baseModelContent.style.display = 'none';
baseModelDisplay.insertBefore(dropdown, editBtn);
// Hide edit button during editing
editBtn.style.display = 'none';
// Focus the dropdown
dropdown.focus();
// Handle dropdown change
dropdown.addEventListener('change', function() {
const selectedModel = this.value;
baseModelContent.textContent = selectedModel;
// Mark that a change was made if the value differs from original
if (selectedModel !== originalValue) {
valueChanged = true;
} else {
valueChanged = false;
}
});
// Function to save changes and exit edit mode
const saveAndExit = function() {
// Check if dropdown still exists and remove it
if (dropdown && dropdown.parentNode === baseModelDisplay) {
baseModelDisplay.removeChild(dropdown);
}
// Show the content and edit button
baseModelContent.style.display = '';
editBtn.style.display = '';
// Remove editing class
baseModelDisplay.classList.remove('editing');
// Only save if the value has actually changed
if (valueChanged || baseModelContent.textContent.trim() !== originalValue) {
// Use the passed filePath for saving
saveBaseModel(filePath, originalValue);
}
// Remove this event listener
document.removeEventListener('click', outsideClickHandler);
};
// Handle outside clicks to save and exit
const outsideClickHandler = function(e) {
// If click is outside the dropdown and base model display
if (!baseModelDisplay.contains(e.target)) {
saveAndExit();
}
};
// Add delayed event listener for outside clicks
setTimeout(() => {
document.addEventListener('click', outsideClickHandler);
}, 0);
// Also handle dropdown blur event
dropdown.addEventListener('blur', function(e) {
// Only save if the related target is not the edit button or inside the baseModelDisplay
if (!baseModelDisplay.contains(e.relatedTarget)) {
saveAndExit();
}
});
});
}
/**
* Save base model
* @param {string} filePath - File path
* @param {string} originalValue - Original value (for comparison)
*/
async function saveBaseModel(filePath, originalValue) {
const baseModelElement = document.querySelector('.base-model-content');
const newBaseModel = baseModelElement.textContent.trim();
// Only save if the value has actually changed
if (newBaseModel === originalValue) {
return; // No change, no need to save
}
try {
await saveModelMetadata(filePath, { base_model: newBaseModel });
// Update the card with the new base model
updateCheckpointCard(filePath, { base_model: newBaseModel });
showToast('Base model updated successfully', 'success');
} catch (error) {
showToast('Failed to update base model', 'error');
}
}
/**
* Set up file name editing functionality
* @param {string} filePath - The full file path of the model.
*/
export function setupFileNameEditing(filePath) {
const fileNameContent = document.querySelector('.file-name-content');
const editBtn = document.querySelector('.edit-file-name-btn');
if (!fileNameContent || !editBtn) return;
// Show edit button on hover
const fileNameWrapper = document.querySelector('.file-name-wrapper');
fileNameWrapper.addEventListener('mouseenter', () => {
editBtn.classList.add('visible');
});
fileNameWrapper.addEventListener('mouseleave', () => {
if (!fileNameWrapper.classList.contains('editing')) {
editBtn.classList.remove('visible');
}
});
// Handle edit button click
editBtn.addEventListener('click', () => {
fileNameWrapper.classList.add('editing');
fileNameContent.setAttribute('contenteditable', 'true');
fileNameContent.focus();
// Store original value for comparison later
fileNameContent.dataset.originalValue = fileNameContent.textContent.trim();
// Place cursor at the end
const range = document.createRange();
const sel = window.getSelection();
range.selectNodeContents(fileNameContent);
range.collapse(false);
sel.removeAllRanges();
sel.addRange(range);
editBtn.classList.add('visible');
});
// Handle keyboard events in edit mode
fileNameContent.addEventListener('keydown', function(e) {
if (!this.getAttribute('contenteditable')) return;
if (e.key === 'Enter') {
e.preventDefault();
this.blur(); // Trigger save on Enter
} else if (e.key === 'Escape') {
e.preventDefault();
// Restore original value
this.textContent = this.dataset.originalValue;
exitEditMode();
}
});
// Handle input validation
fileNameContent.addEventListener('input', function() {
if (!this.getAttribute('contenteditable')) return;
// Replace invalid characters for filenames
const invalidChars = /[\\/:*?"<>|]/g;
if (invalidChars.test(this.textContent)) {
const cursorPos = window.getSelection().getRangeAt(0).startOffset;
this.textContent = this.textContent.replace(invalidChars, '');
// Restore cursor position
const range = document.createRange();
const sel = window.getSelection();
const newPos = Math.min(cursorPos, this.textContent.length);
if (this.firstChild) {
range.setStart(this.firstChild, newPos);
range.collapse(true);
sel.removeAllRanges();
sel.addRange(range);
}
showToast('Invalid characters removed from filename', 'warning');
}
});
// Handle focus out - save changes
fileNameContent.addEventListener('blur', async function() {
if (!this.getAttribute('contenteditable')) return;
const newFileName = this.textContent.trim();
const originalValue = this.dataset.originalValue;
// Basic validation
if (!newFileName) {
// Restore original value if empty
this.textContent = originalValue;
showToast('File name cannot be empty', 'error');
exitEditMode();
return;
}
if (newFileName === originalValue) {
// No changes, just exit edit mode
exitEditMode();
return;
}
try {
// Use the passed filePath (which includes the original filename)
// Call API to rename the file
const response = await fetch('/api/rename_checkpoint', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath, // Use the full original path
new_file_name: newFileName
})
});
const result = await response.json();
if (result.success) {
showToast('File name updated successfully', 'success');
// Get the new file path from the result
const pathParts = filePath.split(/[\\/]/);
pathParts.pop(); // Remove old filename
const newFilePath = [...pathParts, newFileName].join('/');
// Update the checkpoint card with new file path
updateCheckpointCard(filePath, {
filepath: newFilePath,
file_name: newFileName
});
// Update the file name display in the modal
document.querySelector('#file-name').textContent = newFileName;
// Update the modal's data-filepath attribute
const modalContent = document.querySelector('#checkpointModal .modal-content');
if (modalContent) {
modalContent.dataset.filepath = newFilePath;
}
// Reload the page after a short delay to reflect changes
setTimeout(() => {
window.location.reload();
}, 1500);
} else {
throw new Error(result.error || 'Unknown error');
}
} catch (error) {
console.error('Error renaming file:', error);
this.textContent = originalValue; // Restore original file name
showToast(`Failed to rename file: ${error.message}`, 'error');
} finally {
exitEditMode();
}
});
function exitEditMode() {
fileNameContent.removeAttribute('contenteditable');
fileNameWrapper.classList.remove('editing');
editBtn.classList.remove('visible');
}
}

View File

@@ -0,0 +1,488 @@
/**
* ShowcaseView.js
* Handles showcase content (images, videos) display for checkpoint modal
*/
import { showToast, copyToClipboard } from '../../utils/uiHelpers.js';
import { state } from '../../state/index.js';
import { NSFW_LEVELS } from '../../utils/constants.js';
/**
* Render showcase content
* @param {Array} images - Array of images/videos to show
* @returns {string} HTML content
*/
export function renderShowcaseContent(images) {
if (!images?.length) return '<div class="no-examples">No example images available</div>';
// Filter images based on SFW setting
const showOnlySFW = state.settings.show_only_sfw;
let filteredImages = images;
let hiddenCount = 0;
if (showOnlySFW) {
filteredImages = images.filter(img => {
const nsfwLevel = img.nsfwLevel !== undefined ? img.nsfwLevel : 0;
const isSfw = nsfwLevel < NSFW_LEVELS.R;
if (!isSfw) hiddenCount++;
return isSfw;
});
}
// Show message if no images are available after filtering
if (filteredImages.length === 0) {
return `
<div class="no-examples">
<p>All example images are filtered due to NSFW content settings</p>
<p class="nsfw-filter-info">Your settings are currently set to show only safe-for-work content</p>
<p>You can change this in Settings <i class="fas fa-cog"></i></p>
</div>
`;
}
// Show hidden content notification if applicable
const hiddenNotification = hiddenCount > 0 ?
`<div class="nsfw-filter-notification">
<i class="fas fa-eye-slash"></i> ${hiddenCount} ${hiddenCount === 1 ? 'image' : 'images'} hidden due to SFW-only setting
</div>` : '';
return `
<div class="scroll-indicator" onclick="toggleShowcase(this)">
<i class="fas fa-chevron-down"></i>
<span>Scroll or click to show ${filteredImages.length} examples</span>
</div>
<div class="carousel collapsed">
${hiddenNotification}
<div class="carousel-container">
${filteredImages.map(img => generateMediaWrapper(img)).join('')}
</div>
</div>
`;
}
/**
* Generate media wrapper HTML for an image or video
* @param {Object} media - Media object with image or video data
* @returns {string} HTML content
*/
function generateMediaWrapper(media) {
// Calculate appropriate aspect ratio:
// 1. Keep original aspect ratio
// 2. Limit maximum height to 60% of viewport height
// 3. Ensure minimum height is 40% of container width
const aspectRatio = (media.height / media.width) * 100;
const containerWidth = 800; // modal content maximum width
const minHeightPercent = 40;
const maxHeightPercent = (window.innerHeight * 0.6 / containerWidth) * 100;
const heightPercent = Math.max(
minHeightPercent,
Math.min(maxHeightPercent, aspectRatio)
);
// Check if media should be blurred
const nsfwLevel = media.nsfwLevel !== undefined ? media.nsfwLevel : 0;
const shouldBlur = state.settings.blurMatureContent && nsfwLevel > NSFW_LEVELS.PG13;
// Determine NSFW warning text based on level
let nsfwText = "Mature Content";
if (nsfwLevel >= NSFW_LEVELS.XXX) {
nsfwText = "XXX-rated Content";
} else if (nsfwLevel >= NSFW_LEVELS.X) {
nsfwText = "X-rated Content";
} else if (nsfwLevel >= NSFW_LEVELS.R) {
nsfwText = "R-rated Content";
}
// Extract metadata from the media
const meta = media.meta || {};
const prompt = meta.prompt || '';
const negativePrompt = meta.negative_prompt || meta.negativePrompt || '';
const size = meta.Size || `${media.width}x${media.height}`;
const seed = meta.seed || '';
const model = meta.Model || '';
const steps = meta.steps || '';
const sampler = meta.sampler || '';
const cfgScale = meta.cfgScale || '';
const clipSkip = meta.clipSkip || '';
// Check if we have any meaningful generation parameters
const hasParams = seed || model || steps || sampler || cfgScale || clipSkip;
const hasPrompts = prompt || negativePrompt;
// Create metadata panel content
const metadataPanel = generateMetadataPanel(
hasParams, hasPrompts,
prompt, negativePrompt,
size, seed, model, steps, sampler, cfgScale, clipSkip
);
// Check if this is a video or image
if (media.type === 'video') {
return generateVideoWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel);
}
return generateImageWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel);
}
/**
* Generate metadata panel HTML
*/
function generateMetadataPanel(hasParams, hasPrompts, prompt, negativePrompt, size, seed, model, steps, sampler, cfgScale, clipSkip) {
// Create unique IDs for prompt copying
const promptIndex = Math.random().toString(36).substring(2, 15);
const negPromptIndex = Math.random().toString(36).substring(2, 15);
let content = '<div class="image-metadata-panel"><div class="metadata-content">';
if (hasParams) {
content += `
<div class="params-tags">
${size ? `<div class="param-tag"><span class="param-name">Size:</span><span class="param-value">${size}</span></div>` : ''}
${seed ? `<div class="param-tag"><span class="param-name">Seed:</span><span class="param-value">${seed}</span></div>` : ''}
${model ? `<div class="param-tag"><span class="param-name">Model:</span><span class="param-value">${model}</span></div>` : ''}
${steps ? `<div class="param-tag"><span class="param-name">Steps:</span><span class="param-value">${steps}</span></div>` : ''}
${sampler ? `<div class="param-tag"><span class="param-name">Sampler:</span><span class="param-value">${sampler}</span></div>` : ''}
${cfgScale ? `<div class="param-tag"><span class="param-name">CFG:</span><span class="param-value">${cfgScale}</span></div>` : ''}
${clipSkip ? `<div class="param-tag"><span class="param-name">Clip Skip:</span><span class="param-value">${clipSkip}</span></div>` : ''}
</div>
`;
}
if (!hasParams && !hasPrompts) {
content += `
<div class="no-metadata-message">
<i class="fas fa-info-circle"></i>
<span>No generation parameters available</span>
</div>
`;
}
if (prompt) {
content += `
<div class="metadata-row prompt-row">
<span class="metadata-label">Prompt:</span>
<div class="metadata-prompt-wrapper">
<div class="metadata-prompt">${prompt}</div>
<button class="copy-prompt-btn" data-prompt-index="${promptIndex}">
<i class="fas fa-copy"></i>
</button>
</div>
</div>
<div class="hidden-prompt" id="prompt-${promptIndex}" style="display:none;">${prompt}</div>
`;
}
if (negativePrompt) {
content += `
<div class="metadata-row prompt-row">
<span class="metadata-label">Negative Prompt:</span>
<div class="metadata-prompt-wrapper">
<div class="metadata-prompt">${negativePrompt}</div>
<button class="copy-prompt-btn" data-prompt-index="${negPromptIndex}">
<i class="fas fa-copy"></i>
</button>
</div>
</div>
<div class="hidden-prompt" id="prompt-${negPromptIndex}" style="display:none;">${negativePrompt}</div>
`;
}
content += '</div></div>';
return content;
}
/**
* Generate video wrapper HTML
*/
function generateVideoWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel) {
return `
<div class="media-wrapper ${shouldBlur ? 'nsfw-media-wrapper' : ''}" style="padding-bottom: ${heightPercent}%">
${shouldBlur ? `
<button class="toggle-blur-btn showcase-toggle-btn" title="Toggle blur">
<i class="fas fa-eye"></i>
</button>
` : ''}
<video controls autoplay muted loop crossorigin="anonymous"
referrerpolicy="no-referrer" data-src="${media.url}"
class="lazy ${shouldBlur ? 'blurred' : ''}">
<source data-src="${media.url}" type="video/mp4">
Your browser does not support video playback
</video>
${shouldBlur ? `
<div class="nsfw-overlay">
<div class="nsfw-warning">
<p>${nsfwText}</p>
<button class="show-content-btn">Show</button>
</div>
</div>
` : ''}
${metadataPanel}
</div>
`;
}
/**
* Generate image wrapper HTML
*/
function generateImageWrapper(media, heightPercent, shouldBlur, nsfwText, metadataPanel) {
return `
<div class="media-wrapper ${shouldBlur ? 'nsfw-media-wrapper' : ''}" style="padding-bottom: ${heightPercent}%">
${shouldBlur ? `
<button class="toggle-blur-btn showcase-toggle-btn" title="Toggle blur">
<i class="fas fa-eye"></i>
</button>
` : ''}
<img data-src="${media.url}"
alt="Preview"
crossorigin="anonymous"
referrerpolicy="no-referrer"
width="${media.width}"
height="${media.height}"
class="lazy ${shouldBlur ? 'blurred' : ''}">
${shouldBlur ? `
<div class="nsfw-overlay">
<div class="nsfw-warning">
<p>${nsfwText}</p>
<button class="show-content-btn">Show</button>
</div>
</div>
` : ''}
${metadataPanel}
</div>
`;
}
/**
* Toggle showcase expansion
*/
export function toggleShowcase(element) {
const carousel = element.nextElementSibling;
const isCollapsed = carousel.classList.contains('collapsed');
const indicator = element.querySelector('span');
const icon = element.querySelector('i');
carousel.classList.toggle('collapsed');
if (isCollapsed) {
const count = carousel.querySelectorAll('.media-wrapper').length;
indicator.textContent = `Scroll or click to hide examples`;
icon.classList.replace('fa-chevron-down', 'fa-chevron-up');
initLazyLoading(carousel);
// Initialize NSFW content blur toggle handlers
initNsfwBlurHandlers(carousel);
// Initialize metadata panel interaction handlers
initMetadataPanelHandlers(carousel);
} else {
const count = carousel.querySelectorAll('.media-wrapper').length;
indicator.textContent = `Scroll or click to show ${count} examples`;
icon.classList.replace('fa-chevron-up', 'fa-chevron-down');
}
}
/**
* Initialize metadata panel interaction handlers
*/
function initMetadataPanelHandlers(container) {
const mediaWrappers = container.querySelectorAll('.media-wrapper');
mediaWrappers.forEach(wrapper => {
const metadataPanel = wrapper.querySelector('.image-metadata-panel');
if (!metadataPanel) return;
// Prevent events from bubbling
metadataPanel.addEventListener('click', (e) => {
e.stopPropagation();
});
// Handle copy prompt buttons
const copyBtns = metadataPanel.querySelectorAll('.copy-prompt-btn');
copyBtns.forEach(copyBtn => {
const promptIndex = copyBtn.dataset.promptIndex;
const promptElement = wrapper.querySelector(`#prompt-${promptIndex}`);
copyBtn.addEventListener('click', async (e) => {
e.stopPropagation();
if (!promptElement) return;
try {
await copyToClipboard(promptElement.textContent, 'Prompt copied to clipboard');
} catch (err) {
console.error('Copy failed:', err);
showToast('Copy failed', 'error');
}
});
});
// Prevent panel scroll from causing modal scroll
metadataPanel.addEventListener('wheel', (e) => {
e.stopPropagation();
});
});
}
/**
* Initialize blur toggle handlers
*/
function initNsfwBlurHandlers(container) {
// Handle toggle blur buttons
const toggleButtons = container.querySelectorAll('.toggle-blur-btn');
toggleButtons.forEach(btn => {
btn.addEventListener('click', (e) => {
e.stopPropagation();
const wrapper = btn.closest('.media-wrapper');
const media = wrapper.querySelector('img, video');
const isBlurred = media.classList.toggle('blurred');
const icon = btn.querySelector('i');
// Update the icon based on blur state
if (isBlurred) {
icon.className = 'fas fa-eye';
} else {
icon.className = 'fas fa-eye-slash';
}
// Toggle the overlay visibility
const overlay = wrapper.querySelector('.nsfw-overlay');
if (overlay) {
overlay.style.display = isBlurred ? 'flex' : 'none';
}
});
});
// Handle "Show" buttons in overlays
const showButtons = container.querySelectorAll('.show-content-btn');
showButtons.forEach(btn => {
btn.addEventListener('click', (e) => {
e.stopPropagation();
const wrapper = btn.closest('.media-wrapper');
const media = wrapper.querySelector('img, video');
media.classList.remove('blurred');
// Update the toggle button icon
const toggleBtn = wrapper.querySelector('.toggle-blur-btn');
if (toggleBtn) {
toggleBtn.querySelector('i').className = 'fas fa-eye-slash';
}
// Hide the overlay
const overlay = wrapper.querySelector('.nsfw-overlay');
if (overlay) {
overlay.style.display = 'none';
}
});
});
}
/**
* Initialize lazy loading for images and videos
*/
function initLazyLoading(container) {
const lazyElements = container.querySelectorAll('.lazy');
const lazyLoad = (element) => {
if (element.tagName.toLowerCase() === 'video') {
element.src = element.dataset.src;
element.querySelector('source').src = element.dataset.src;
element.load();
} else {
element.src = element.dataset.src;
}
element.classList.remove('lazy');
};
const observer = new IntersectionObserver((entries) => {
entries.forEach(entry => {
if (entry.isIntersecting) {
lazyLoad(entry.target);
observer.unobserve(entry.target);
}
});
});
lazyElements.forEach(element => observer.observe(element));
}
/**
* Set up showcase scroll functionality
*/
export function setupShowcaseScroll() {
// Listen for wheel events
document.addEventListener('wheel', (event) => {
const modalContent = document.querySelector('#checkpointModal .modal-content');
if (!modalContent) return;
const showcase = modalContent.querySelector('.showcase-section');
if (!showcase) return;
const carousel = showcase.querySelector('.carousel');
const scrollIndicator = showcase.querySelector('.scroll-indicator');
if (carousel?.classList.contains('collapsed') && event.deltaY > 0) {
const isNearBottom = modalContent.scrollHeight - modalContent.scrollTop - modalContent.clientHeight < 100;
if (isNearBottom) {
toggleShowcase(scrollIndicator);
event.preventDefault();
}
}
}, { passive: false });
// Use MutationObserver to set up back-to-top button when modal content is added
const observer = new MutationObserver((mutations) => {
for (const mutation of mutations) {
if (mutation.type === 'childList' && mutation.addedNodes.length) {
const checkpointModal = document.getElementById('checkpointModal');
if (checkpointModal && checkpointModal.querySelector('.modal-content')) {
setupBackToTopButton(checkpointModal.querySelector('.modal-content'));
}
}
}
});
// Start observing the document body for changes
observer.observe(document.body, { childList: true, subtree: true });
// Also try to set up the button immediately in case the modal is already open
const modalContent = document.querySelector('#checkpointModal .modal-content');
if (modalContent) {
setupBackToTopButton(modalContent);
}
}
/**
* Set up back-to-top button
*/
function setupBackToTopButton(modalContent) {
// Remove any existing scroll listeners to avoid duplicates
modalContent.onscroll = null;
// Add new scroll listener
modalContent.addEventListener('scroll', () => {
const backToTopBtn = modalContent.querySelector('.back-to-top');
if (backToTopBtn) {
if (modalContent.scrollTop > 300) {
backToTopBtn.classList.add('visible');
} else {
backToTopBtn.classList.remove('visible');
}
}
});
// Trigger a scroll event to check initial position
modalContent.dispatchEvent(new Event('scroll'));
}
/**
* Scroll to top of modal content
*/
export function scrollToTop(button) {
const modalContent = button.closest('.modal-content');
if (modalContent) {
modalContent.scrollTo({
top: 0,
behavior: 'smooth'
});
}
}

View File

@@ -0,0 +1,214 @@
/**
* CheckpointModal - Main entry point
*
* Modularized checkpoint modal component that handles checkpoint model details display
*/
import { showToast } from '../../utils/uiHelpers.js';
import { state } from '../../state/index.js';
import { modalManager } from '../../managers/ModalManager.js';
import { renderShowcaseContent, toggleShowcase, setupShowcaseScroll, scrollToTop } from './ShowcaseView.js';
import { setupTabSwitching, loadModelDescription } from './ModelDescription.js';
import {
setupModelNameEditing,
setupBaseModelEditing,
setupFileNameEditing,
saveModelMetadata
} from './ModelMetadata.js';
import { renderCompactTags, setupTagTooltip, formatFileSize } from './utils.js';
import { updateCheckpointCard } from '../../utils/cardUpdater.js';
/**
* Display the checkpoint modal with the given checkpoint data
* @param {Object} checkpoint - Checkpoint data object
*/
export function showCheckpointModal(checkpoint) {
const content = `
<div class="modal-content">
<button class="close" onclick="modalManager.closeModal('checkpointModal')">&times;</button>
<header class="modal-header">
<div class="model-name-header">
<h2 class="model-name-content" contenteditable="true" spellcheck="false">${checkpoint.model_name || 'Checkpoint Details'}</h2>
<button class="edit-model-name-btn" title="Edit model name">
<i class="fas fa-pencil-alt"></i>
</button>
</div>
${renderCompactTags(checkpoint.tags || [])}
</header>
<div class="modal-body">
<div class="info-section">
<div class="info-grid">
<div class="info-item">
<label>Version</label>
<span>${checkpoint.civitai?.name || 'N/A'}</span>
</div>
<div class="info-item">
<label>File Name</label>
<div class="file-name-wrapper">
<span id="file-name" class="file-name-content">${checkpoint.file_name || 'N/A'}</span>
<button class="edit-file-name-btn" title="Edit file name">
<i class="fas fa-pencil-alt"></i>
</button>
</div>
</div>
<div class="info-item location-size">
<div class="location-wrapper">
<label>Location</label>
<span class="file-path">${checkpoint.file_path.replace(/[^/]+$/, '')}</span>
</div>
</div>
<div class="info-item base-size">
<div class="base-wrapper">
<label>Base Model</label>
<div class="base-model-display">
<span class="base-model-content">${checkpoint.base_model || 'Unknown'}</span>
<button class="edit-base-model-btn" title="Edit base model">
<i class="fas fa-pencil-alt"></i>
</button>
</div>
</div>
<div class="size-wrapper">
<label>Size</label>
<span>${formatFileSize(checkpoint.file_size)}</span>
</div>
</div>
<div class="info-item notes">
<label>Additional Notes</label>
<div class="editable-field">
<div class="notes-content" contenteditable="true" spellcheck="false">${checkpoint.notes || 'Add your notes here...'}</div>
<button class="save-btn" onclick="saveCheckpointNotes('${checkpoint.file_path}')">
<i class="fas fa-save"></i>
</button>
</div>
</div>
<div class="info-item full-width">
<label>About this version</label>
<div class="description-text">${checkpoint.description || 'N/A'}</div>
</div>
</div>
</div>
<div class="showcase-section" data-checkpoint-id="${checkpoint.civitai?.modelId || ''}">
<div class="showcase-tabs">
<button class="tab-btn active" data-tab="showcase">Examples</button>
<button class="tab-btn" data-tab="description">Model Description</button>
</div>
<div class="tab-content">
<div id="showcase-tab" class="tab-pane active">
${renderShowcaseContent(checkpoint.civitai?.images || [])}
</div>
<div id="description-tab" class="tab-pane">
<div class="model-description-container">
<div class="model-description-loading">
<i class="fas fa-spinner fa-spin"></i> Loading model description...
</div>
<div class="model-description-content">
${checkpoint.modelDescription || ''}
</div>
</div>
</div>
</div>
<button class="back-to-top" onclick="scrollToTopCheckpoint(this)">
<i class="fas fa-arrow-up"></i>
</button>
</div>
</div>
</div>
`;
modalManager.showModal('checkpointModal', content);
setupEditableFields(checkpoint.file_path);
setupShowcaseScroll();
setupTabSwitching();
setupTagTooltip();
setupModelNameEditing(checkpoint.file_path);
setupBaseModelEditing(checkpoint.file_path);
setupFileNameEditing(checkpoint.file_path);
// If we have a model ID but no description, fetch it
if (checkpoint.civitai?.modelId && !checkpoint.modelDescription) {
loadModelDescription(checkpoint.civitai.modelId, checkpoint.file_path);
}
}
/**
* Set up editable fields in the checkpoint modal
* @param {string} filePath - The full file path of the model.
*/
function setupEditableFields(filePath) {
const editableFields = document.querySelectorAll('.editable-field [contenteditable]');
editableFields.forEach(field => {
field.addEventListener('focus', function() {
if (this.textContent === 'Add your notes here...') {
this.textContent = '';
}
});
field.addEventListener('blur', function() {
if (this.textContent.trim() === '') {
if (this.classList.contains('notes-content')) {
this.textContent = 'Add your notes here...';
}
}
});
});
// Add keydown event listeners for notes
const notesContent = document.querySelector('.notes-content');
if (notesContent) {
notesContent.addEventListener('keydown', async function(e) {
if (e.key === 'Enter') {
if (e.shiftKey) {
// Allow shift+enter for new line
return;
}
e.preventDefault();
await saveNotes(filePath);
}
});
}
}
/**
* Save checkpoint notes
* @param {string} filePath - Path to the checkpoint file
*/
async function saveNotes(filePath) {
const content = document.querySelector('.notes-content').textContent;
try {
await saveModelMetadata(filePath, { notes: content });
// Update the corresponding checkpoint card's dataset
updateCheckpointCard(filePath, { notes: content });
showToast('Notes saved successfully', 'success');
} catch (error) {
showToast('Failed to save notes', 'error');
}
}
// Export the checkpoint modal API
const checkpointModal = {
show: showCheckpointModal,
toggleShowcase,
scrollToTop
};
export { checkpointModal };
// Define global functions for use in HTML
window.toggleShowcase = function(element) {
toggleShowcase(element);
};
window.scrollToTopCheckpoint = function(button) {
scrollToTop(button);
};
window.saveCheckpointNotes = function(filePath) {
saveNotes(filePath);
};

View File

@@ -0,0 +1,74 @@
/**
* utils.js
* CheckpointModal component utility functions
*/
import { showToast } from '../../utils/uiHelpers.js';
/**
* Format file size for display
* @param {number} bytes - File size in bytes
* @returns {string} - Formatted file size
*/
export function formatFileSize(bytes) {
if (!bytes) return 'N/A';
const units = ['B', 'KB', 'MB', 'GB', 'TB'];
let size = bytes;
let unitIndex = 0;
while (size >= 1024 && unitIndex < units.length - 1) {
size /= 1024;
unitIndex++;
}
return `${size.toFixed(1)} ${units[unitIndex]}`;
}
/**
* Render compact tags
* @param {Array} tags - Array of tags
* @returns {string} HTML content
*/
export function renderCompactTags(tags) {
if (!tags || tags.length === 0) return '';
// Display up to 5 tags, with a tooltip indicator if there are more
const visibleTags = tags.slice(0, 5);
const remainingCount = Math.max(0, tags.length - 5);
return `
<div class="model-tags-container">
<div class="model-tags-compact">
${visibleTags.map(tag => `<span class="model-tag-compact">${tag}</span>`).join('')}
${remainingCount > 0 ?
`<span class="model-tag-more" data-count="${remainingCount}">+${remainingCount}</span>` :
''}
</div>
${tags.length > 0 ?
`<div class="model-tags-tooltip">
<div class="tooltip-content">
${tags.map(tag => `<span class="tooltip-tag">${tag}</span>`).join('')}
</div>
</div>` :
''}
</div>
`;
}
/**
* Set up tag tooltip functionality
*/
export function setupTagTooltip() {
const tagsContainer = document.querySelector('.model-tags-container');
const tooltip = document.querySelector('.model-tags-tooltip');
if (tagsContainer && tooltip) {
tagsContainer.addEventListener('mouseenter', () => {
tooltip.classList.add('visible');
});
tagsContainer.addEventListener('mouseleave', () => {
tooltip.classList.remove('visible');
});
}
}

View File

@@ -0,0 +1,60 @@
// CheckpointsControls.js - Specific implementation for the Checkpoints page
import { PageControls } from './PageControls.js';
import { loadMoreCheckpoints, resetAndReload, refreshCheckpoints, fetchCivitai } from '../../api/checkpointApi.js';
import { showToast } from '../../utils/uiHelpers.js';
import { CheckpointDownloadManager } from '../../managers/CheckpointDownloadManager.js';
/**
* CheckpointsControls class - Extends PageControls for Checkpoint-specific functionality
*/
export class CheckpointsControls extends PageControls {
constructor() {
// Initialize with 'checkpoints' page type
super('checkpoints');
// Initialize checkpoint download manager
this.downloadManager = new CheckpointDownloadManager();
// Register API methods specific to the Checkpoints page
this.registerCheckpointsAPI();
}
/**
* Register Checkpoint-specific API methods
*/
registerCheckpointsAPI() {
const checkpointsAPI = {
// Core API functions
loadMoreModels: async (resetPage = false, updateFolders = false) => {
return await loadMoreCheckpoints(resetPage, updateFolders);
},
resetAndReload: async (updateFolders = false) => {
return await resetAndReload(updateFolders);
},
refreshModels: async () => {
return await refreshCheckpoints();
},
// Add fetch from Civitai functionality for checkpoints
fetchFromCivitai: async () => {
return await fetchCivitai();
},
// Add show download modal functionality
showDownloadModal: () => {
this.downloadManager.showDownloadModal();
},
// No clearCustomFilter implementation is needed for checkpoints
// as custom filters are currently only used for LoRAs
clearCustomFilter: async () => {
showToast('No custom filter to clear', 'info');
}
};
// Register the API
this.registerAPI(checkpointsAPI);
}
}

View File

@@ -0,0 +1,146 @@
// LorasControls.js - Specific implementation for the LoRAs page
import { PageControls } from './PageControls.js';
import { loadMoreLoras, fetchCivitai, resetAndReload, refreshLoras } from '../../api/loraApi.js';
import { getSessionItem, removeSessionItem } from '../../utils/storageHelpers.js';
import { showToast } from '../../utils/uiHelpers.js';
/**
* LorasControls class - Extends PageControls for LoRA-specific functionality
*/
export class LorasControls extends PageControls {
constructor() {
// Initialize with 'loras' page type
super('loras');
// Register API methods specific to the LoRAs page
this.registerLorasAPI();
// Check for custom filters (e.g., from recipe navigation)
this.checkCustomFilters();
}
/**
* Register LoRA-specific API methods
*/
registerLorasAPI() {
const lorasAPI = {
// Core API functions
loadMoreModels: async (resetPage = false, updateFolders = false) => {
return await loadMoreLoras(resetPage, updateFolders);
},
resetAndReload: async (updateFolders = false) => {
return await resetAndReload(updateFolders);
},
refreshModels: async () => {
return await refreshLoras();
},
// LoRA-specific API functions
fetchFromCivitai: async () => {
return await fetchCivitai();
},
showDownloadModal: () => {
if (window.downloadManager) {
window.downloadManager.showDownloadModal();
} else {
console.error('Download manager not available');
}
},
toggleBulkMode: () => {
if (window.bulkManager) {
window.bulkManager.toggleBulkMode();
} else {
console.error('Bulk manager not available');
}
},
clearCustomFilter: async () => {
await this.clearCustomFilter();
}
};
// Register the API
this.registerAPI(lorasAPI);
}
/**
* Check for custom filter parameters in session storage (e.g., from recipe page navigation)
*/
checkCustomFilters() {
const filterLoraHash = getSessionItem('recipe_to_lora_filterLoraHash');
const filterLoraHashes = getSessionItem('recipe_to_lora_filterLoraHashes');
const filterRecipeName = getSessionItem('filterRecipeName');
const viewLoraDetail = getSessionItem('viewLoraDetail');
if ((filterLoraHash || filterLoraHashes) && filterRecipeName) {
// Found custom filter parameters, set up the custom filter
// Show the filter indicator
const indicator = document.getElementById('customFilterIndicator');
const filterText = indicator?.querySelector('.customFilterText');
if (indicator && filterText) {
indicator.classList.remove('hidden');
// Set text content with recipe name
const filterType = filterLoraHash && viewLoraDetail ? "Viewing LoRA from" : "Viewing LoRAs from";
const displayText = `${filterType}: ${filterRecipeName}`;
filterText.textContent = this._truncateText(displayText, 30);
filterText.setAttribute('title', displayText);
// Add pulse animation
const filterElement = indicator.querySelector('.filter-active');
if (filterElement) {
filterElement.classList.add('animate');
setTimeout(() => filterElement.classList.remove('animate'), 600);
}
}
// If we're viewing a specific LoRA detail, set up to open the modal
if (filterLoraHash && viewLoraDetail) {
this.pageState.pendingLoraHash = filterLoraHash;
}
}
}
/**
* Clear the custom filter and reload the page
*/
async clearCustomFilter() {
console.log("Clearing custom filter...");
// Remove filter parameters from session storage
removeSessionItem('recipe_to_lora_filterLoraHash');
removeSessionItem('recipe_to_lora_filterLoraHashes');
removeSessionItem('filterRecipeName');
removeSessionItem('viewLoraDetail');
// Hide the filter indicator
const indicator = document.getElementById('customFilterIndicator');
if (indicator) {
indicator.classList.add('hidden');
}
// Reset state
if (this.pageState.pendingLoraHash) {
delete this.pageState.pendingLoraHash;
}
// Reload the loras
await resetAndReload();
}
/**
* Helper to truncate text with ellipsis
* @param {string} text - Text to truncate
* @param {number} maxLength - Maximum length before truncating
* @returns {string} - Truncated text
*/
_truncateText(text, maxLength) {
return text.length > maxLength ? text.substring(0, maxLength - 3) + '...' : text;
}
}

View File

@@ -0,0 +1,388 @@
// PageControls.js - Manages controls for both LoRAs and Checkpoints pages
import { state, getCurrentPageState, setCurrentPageType } from '../../state/index.js';
import { getStorageItem, setStorageItem } from '../../utils/storageHelpers.js';
import { showToast } from '../../utils/uiHelpers.js';
/**
* PageControls class - Unified control management for model pages
*/
export class PageControls {
constructor(pageType) {
// Set the current page type in state
setCurrentPageType(pageType);
// Store the page type
this.pageType = pageType;
// Get the current page state
this.pageState = getCurrentPageState();
// Initialize state based on page type
this.initializeState();
// Store API methods
this.api = null;
// Initialize event listeners
this.initEventListeners();
console.log(`PageControls initialized for ${pageType} page`);
}
/**
* Initialize state based on page type
*/
initializeState() {
// Set default values
this.pageState.pageSize = 20;
this.pageState.isLoading = false;
this.pageState.hasMore = true;
// Load sort preference
this.loadSortPreference();
}
/**
* Register API methods for the page
* @param {Object} api - API methods for the page
*/
registerAPI(api) {
this.api = api;
console.log(`API methods registered for ${this.pageType} page`);
}
/**
* Initialize event listeners for controls
*/
initEventListeners() {
// Sort select handler
const sortSelect = document.getElementById('sortSelect');
if (sortSelect) {
sortSelect.value = this.pageState.sortBy;
sortSelect.addEventListener('change', async (e) => {
this.pageState.sortBy = e.target.value;
this.saveSortPreference(e.target.value);
await this.resetAndReload();
});
}
// Use event delegation for folder tags - this is the key fix
const folderTagsContainer = document.querySelector('.folder-tags-container');
if (folderTagsContainer) {
folderTagsContainer.addEventListener('click', (e) => {
const tag = e.target.closest('.tag');
if (tag) {
this.handleFolderClick(tag);
}
});
}
// Refresh button handler
const refreshBtn = document.querySelector('[data-action="refresh"]');
if (refreshBtn) {
refreshBtn.addEventListener('click', () => this.refreshModels());
}
// Toggle folders button
const toggleFoldersBtn = document.querySelector('.toggle-folders-btn');
if (toggleFoldersBtn) {
toggleFoldersBtn.addEventListener('click', () => this.toggleFolderTags());
}
// Clear custom filter handler
const clearFilterBtn = document.querySelector('.clear-filter');
if (clearFilterBtn) {
clearFilterBtn.addEventListener('click', () => this.clearCustomFilter());
}
// Page-specific event listeners
this.initPageSpecificListeners();
}
/**
* Initialize page-specific event listeners
*/
initPageSpecificListeners() {
// Fetch from Civitai button - available for both loras and checkpoints
const fetchButton = document.querySelector('[data-action="fetch"]');
if (fetchButton) {
fetchButton.addEventListener('click', () => this.fetchFromCivitai());
}
const downloadButton = document.querySelector('[data-action="download"]');
if (downloadButton) {
downloadButton.addEventListener('click', () => this.showDownloadModal());
}
if (this.pageType === 'loras') {
// Bulk operations button - LoRAs only
const bulkButton = document.querySelector('[data-action="bulk"]');
if (bulkButton) {
bulkButton.addEventListener('click', () => this.toggleBulkMode());
}
}
}
/**
* Toggle folder selection
* @param {HTMLElement} tagElement - The folder tag element that was clicked
*/
handleFolderClick(tagElement) {
const folder = tagElement.dataset.folder;
const wasActive = tagElement.classList.contains('active');
document.querySelectorAll('.folder-tags .tag').forEach(t => {
t.classList.remove('active');
});
if (!wasActive) {
tagElement.classList.add('active');
this.pageState.activeFolder = folder;
setStorageItem(`${this.pageType}_activeFolder`, folder);
} else {
this.pageState.activeFolder = null;
setStorageItem(`${this.pageType}_activeFolder`, null);
}
this.resetAndReload();
}
/**
* Restore folder filter from storage
*/
restoreFolderFilter() {
const activeFolder = getStorageItem(`${this.pageType}_activeFolder`);
const folderTag = activeFolder && document.querySelector(`.tag[data-folder="${activeFolder}"]`);
if (folderTag) {
folderTag.classList.add('active');
this.pageState.activeFolder = activeFolder;
this.filterByFolder(activeFolder);
}
}
/**
* Filter displayed cards by folder
* @param {string} folderPath - Folder path to filter by
*/
filterByFolder(folderPath) {
const cardSelector = this.pageType === 'loras' ? '.lora-card' : '.checkpoint-card';
document.querySelectorAll(cardSelector).forEach(card => {
card.style.display = card.dataset.folder === folderPath ? '' : 'none';
});
}
/**
* Update the folder tags display with new folder list
* @param {Array} folders - List of folder names
*/
updateFolderTags(folders) {
const folderTagsContainer = document.querySelector('.folder-tags');
if (!folderTagsContainer) return;
// Keep track of currently selected folder
const currentFolder = this.pageState.activeFolder;
// Create HTML for folder tags
const tagsHTML = folders.map(folder => {
const isActive = folder === currentFolder;
return `<div class="tag ${isActive ? 'active' : ''}" data-folder="${folder}">${folder}</div>`;
}).join('');
// Update the container
folderTagsContainer.innerHTML = tagsHTML;
// Scroll active folder into view (no need to reattach click handlers)
const activeTag = folderTagsContainer.querySelector(`.tag[data-folder="${currentFolder}"]`);
if (activeTag) {
activeTag.scrollIntoView({ behavior: 'smooth', block: 'nearest' });
}
}
/**
* Toggle visibility of folder tags
*/
toggleFolderTags() {
const folderTags = document.querySelector('.folder-tags');
const toggleBtn = document.querySelector('.toggle-folders-btn i');
if (folderTags) {
folderTags.classList.toggle('collapsed');
if (folderTags.classList.contains('collapsed')) {
// Change icon to indicate folders are hidden
toggleBtn.className = 'fas fa-folder-plus';
toggleBtn.parentElement.title = 'Show folder tags';
setStorageItem('folderTagsCollapsed', 'true');
} else {
// Change icon to indicate folders are visible
toggleBtn.className = 'fas fa-folder-minus';
toggleBtn.parentElement.title = 'Hide folder tags';
setStorageItem('folderTagsCollapsed', 'false');
}
}
}
/**
* Initialize folder tags visibility based on stored preference
*/
initFolderTagsVisibility() {
const isCollapsed = getStorageItem('folderTagsCollapsed');
if (isCollapsed) {
const folderTags = document.querySelector('.folder-tags');
const toggleBtn = document.querySelector('.toggle-folders-btn i');
if (folderTags) {
folderTags.classList.add('collapsed');
}
if (toggleBtn) {
toggleBtn.className = 'fas fa-folder-plus';
toggleBtn.parentElement.title = 'Show folder tags';
}
} else {
const toggleBtn = document.querySelector('.toggle-folders-btn i');
if (toggleBtn) {
toggleBtn.className = 'fas fa-folder-minus';
toggleBtn.parentElement.title = 'Hide folder tags';
}
}
}
/**
* Load sort preference from storage
*/
loadSortPreference() {
const savedSort = getStorageItem(`${this.pageType}_sort`);
if (savedSort) {
this.pageState.sortBy = savedSort;
const sortSelect = document.getElementById('sortSelect');
if (sortSelect) {
sortSelect.value = savedSort;
}
}
}
/**
* Save sort preference to storage
* @param {string} sortValue - The sort value to save
*/
saveSortPreference(sortValue) {
setStorageItem(`${this.pageType}_sort`, sortValue);
}
/**
* Open model page on Civitai
* @param {string} modelName - Name of the model
*/
openCivitai(modelName) {
// Get card selector based on page type
const cardSelector = this.pageType === 'loras'
? `.lora-card[data-name="${modelName}"]`
: `.checkpoint-card[data-name="${modelName}"]`;
const card = document.querySelector(cardSelector);
if (!card) return;
const metaData = JSON.parse(card.dataset.meta);
const civitaiId = metaData.modelId;
const versionId = metaData.id;
// Build URL
if (civitaiId) {
let url = `https://civitai.com/models/${civitaiId}`;
if (versionId) {
url += `?modelVersionId=${versionId}`;
}
window.open(url, '_blank');
} else {
// If no ID, try searching by name
window.open(`https://civitai.com/models?query=${encodeURIComponent(modelName)}`, '_blank');
}
}
/**
* Reset and reload the models list
*/
async resetAndReload(updateFolders = false) {
if (!this.api) {
console.error('API methods not registered');
return;
}
try {
await this.api.resetAndReload(updateFolders);
} catch (error) {
console.error(`Error reloading ${this.pageType}:`, error);
showToast(`Failed to reload ${this.pageType}: ${error.message}`, 'error');
}
}
/**
* Refresh models list
*/
async refreshModels() {
if (!this.api) {
console.error('API methods not registered');
return;
}
try {
await this.api.refreshModels();
} catch (error) {
console.error(`Error refreshing ${this.pageType}:`, error);
showToast(`Failed to refresh ${this.pageType}: ${error.message}`, 'error');
}
}
/**
* Fetch metadata from Civitai (available for both LoRAs and Checkpoints)
*/
async fetchFromCivitai() {
if (!this.api) {
console.error('API methods not registered');
return;
}
try {
await this.api.fetchFromCivitai();
} catch (error) {
console.error('Error fetching metadata:', error);
showToast('Failed to fetch metadata: ' + error.message, 'error');
}
}
/**
* Show download modal
*/
showDownloadModal() {
this.api.showDownloadModal();
}
/**
* Toggle bulk mode (LoRAs only)
*/
toggleBulkMode() {
if (this.pageType !== 'loras' || !this.api) {
console.error('Bulk mode is only available for LoRAs');
return;
}
this.api.toggleBulkMode();
}
/**
* Clear custom filter
*/
async clearCustomFilter() {
if (!this.api) {
console.error('API methods not registered');
return;
}
try {
await this.api.clearCustomFilter();
} catch (error) {
console.error('Error clearing custom filter:', error);
showToast('Failed to clear custom filter: ' + error.message, 'error');
}
}
}

View File

@@ -0,0 +1,23 @@
// Controls components index file
import { PageControls } from './PageControls.js';
import { LorasControls } from './LorasControls.js';
import { CheckpointsControls } from './CheckpointsControls.js';
// Export the classes
export { PageControls, LorasControls, CheckpointsControls };
/**
* Factory function to create the appropriate controls based on page type
* @param {string} pageType - The type of page ('loras' or 'checkpoints')
* @returns {PageControls} - The appropriate controls instance
*/
export function createPageControls(pageType) {
if (pageType === 'loras') {
return new LorasControls();
} else if (pageType === 'checkpoints') {
return new CheckpointsControls();
} else {
console.error(`Unknown page type: ${pageType}`);
return null;
}
}

View File

@@ -0,0 +1,495 @@
/**
* Initialization Component
* Manages the display of initialization progress and status
*/
import { appCore } from '../core.js';
import { getSessionItem, setSessionItem } from '../utils/storageHelpers.js';
import { state, getCurrentPageState } from '../state/index.js';
class InitializationManager {
constructor() {
this.currentTipIndex = 0;
this.tipInterval = null;
this.websocket = null;
this.progress = 0;
this.processingStartTime = null;
this.processedFilesCount = 0;
this.totalFilesCount = 0;
this.averageProcessingTime = null;
this.pageType = null; // Added page type property
}
/**
* Initialize the component
*/
initialize() {
// Initialize core application for theme and header functionality
appCore.initialize().then(() => {
console.log('Core application initialized for initialization component');
});
// Detect the current page type
this.detectPageType();
// Check session storage for saved progress
this.restoreProgress();
// Setup the tip carousel
this.setupTipCarousel();
// Connect to WebSocket for progress updates
this.connectWebSocket();
// Add event listeners for tip navigation
this.setupTipNavigation();
// Show first tip as active
document.querySelector('.tip-item').classList.add('active');
}
/**
* Detect the current page type
*/
detectPageType() {
// Get the current page type from URL or data attribute
const path = window.location.pathname;
if (path.includes('/checkpoints')) {
this.pageType = 'checkpoints';
} else if (path.includes('/loras')) {
this.pageType = 'loras';
} else {
// Default to loras if can't determine
this.pageType = 'loras';
}
console.log(`Initialization component detected page type: ${this.pageType}`);
}
/**
* Get the storage key with page type prefix
*/
getStorageKey(key) {
return `${this.pageType}_${key}`;
}
/**
* Restore progress from session storage if available
*/
restoreProgress() {
const savedProgress = getSessionItem(this.getStorageKey('initProgress'));
if (savedProgress) {
console.log(`Restoring ${this.pageType} progress from session storage:`, savedProgress);
// Restore progress percentage
if (savedProgress.progress !== undefined) {
this.updateProgress(savedProgress.progress);
}
// Restore processed files count and total files
if (savedProgress.processedFiles !== undefined) {
this.processedFilesCount = savedProgress.processedFiles;
}
if (savedProgress.totalFiles !== undefined) {
this.totalFilesCount = savedProgress.totalFiles;
}
// Restore processing time metrics if available
if (savedProgress.averageProcessingTime !== undefined) {
this.averageProcessingTime = savedProgress.averageProcessingTime;
this.updateRemainingTime();
}
// Restore progress status message
if (savedProgress.details) {
this.updateStatusMessage(savedProgress.details);
}
}
}
/**
* Connect to WebSocket for initialization progress updates
*/
connectWebSocket() {
try {
const wsProtocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
this.websocket = new WebSocket(`${wsProtocol}${window.location.host}/ws/init-progress`);
this.websocket.onopen = () => {
console.log('Connected to initialization progress WebSocket');
};
this.websocket.onmessage = (event) => {
this.handleProgressUpdate(JSON.parse(event.data));
};
this.websocket.onerror = (error) => {
console.error('WebSocket error:', error);
// Fall back to polling if WebSocket fails
this.fallbackToPolling();
};
this.websocket.onclose = () => {
console.log('WebSocket connection closed');
// Check if we need to fall back to polling
if (!this.pollingActive) {
this.fallbackToPolling();
}
};
} catch (error) {
console.error('Failed to connect to WebSocket:', error);
this.fallbackToPolling();
}
}
/**
* Fall back to polling if WebSocket connection fails
*/
fallbackToPolling() {
this.pollingActive = true;
this.pollProgress();
// Set a simulated progress that moves forward slowly
// This gives users feedback even if the backend isn't providing updates
let simulatedProgress = this.progress || 0;
const simulateInterval = setInterval(() => {
simulatedProgress += 0.5;
if (simulatedProgress > 95) {
clearInterval(simulateInterval);
return;
}
// Only use simulated progress if we haven't received a real update
if (this.progress < simulatedProgress) {
this.updateProgress(simulatedProgress);
}
}, 1000);
}
/**
* Poll for progress updates from the server
*/
pollProgress() {
const checkProgress = () => {
fetch('/api/init-status')
.then(response => response.json())
.then(data => {
this.handleProgressUpdate(data);
// If initialization is complete, stop polling
if (data.status !== 'complete') {
setTimeout(checkProgress, 2000);
} else {
window.location.reload();
}
})
.catch(error => {
console.error('Error polling for progress:', error);
setTimeout(checkProgress, 3000); // Try again after a longer delay
});
};
checkProgress();
}
/**
* Handle progress updates from WebSocket or polling
*/
handleProgressUpdate(data) {
if (!data) return;
// Check if this update is for our page type
if (data.pageType && data.pageType !== this.pageType) {
console.log(`Ignoring update for ${data.pageType}, we're on ${this.pageType}`);
return;
}
// If no pageType is specified in the data but we have scanner_type, map it to pageType
if (!data.pageType && data.scanner_type) {
const scannerTypeToPageType = {
'lora': 'loras',
'checkpoint': 'checkpoints'
};
if (scannerTypeToPageType[data.scanner_type] !== this.pageType) {
console.log(`Ignoring update for ${data.scanner_type}, we're on ${this.pageType}`);
return;
}
}
// Save progress data to session storage
setSessionItem(this.getStorageKey('initProgress'), {
...data,
averageProcessingTime: this.averageProcessingTime,
processedFiles: this.processedFilesCount,
totalFiles: this.totalFilesCount
});
// Update progress percentage
if (data.progress !== undefined) {
this.updateProgress(data.progress);
}
// Update stage-specific details
if (data.details) {
this.updateStatusMessage(data.details);
}
// Track files count for time estimation
if (data.stage === 'count_models' && data.details) {
const match = data.details.match(/Found (\d+)/);
if (match && match[1]) {
this.totalFilesCount = parseInt(match[1]);
}
}
// Track processed files for time estimation
if (data.stage === 'process_models' && data.details) {
const match = data.details.match(/Processing .* files: (\d+)\/(\d+)/);
if (match && match[1] && match[2]) {
const currentCount = parseInt(match[1]);
const totalCount = parseInt(match[2]);
// Make sure we have valid total count
if (totalCount > 0 && this.totalFilesCount === 0) {
this.totalFilesCount = totalCount;
}
// Start tracking processing time once we've processed some files
if (currentCount > 0 && !this.processingStartTime && this.processedFilesCount === 0) {
this.processingStartTime = Date.now();
}
// Calculate average processing time based on elapsed time and files processed
if (this.processingStartTime && currentCount > this.processedFilesCount) {
const newFiles = currentCount - this.processedFilesCount;
const elapsedTime = Date.now() - this.processingStartTime;
const timePerFile = elapsedTime / currentCount; // ms per file
// Update moving average
if (!this.averageProcessingTime) {
this.averageProcessingTime = timePerFile;
} else {
// Simple exponential moving average
this.averageProcessingTime = this.averageProcessingTime * 0.7 + timePerFile * 0.3;
}
// Update remaining time estimate
this.updateRemainingTime();
}
this.processedFilesCount = currentCount;
}
}
// If initialization is complete, reload the page
if (data.status === 'complete') {
this.showCompletionMessage();
// Remove session storage data since we're done
setSessionItem(this.getStorageKey('initProgress'), null);
// Give the user a moment to see the completion message
setTimeout(() => {
window.location.reload();
}, 1500);
}
}
/**
* Update the remaining time display based on current progress
*/
updateRemainingTime() {
if (!this.averageProcessingTime || !this.totalFilesCount || this.totalFilesCount <= 0) {
document.getElementById('remainingTime').textContent = 'Estimating...';
return;
}
const remainingFiles = this.totalFilesCount - this.processedFilesCount;
const remainingTimeMs = remainingFiles * this.averageProcessingTime;
if (remainingTimeMs <= 0) {
document.getElementById('remainingTime').textContent = 'Almost done...';
return;
}
// Format the time for display
let formattedTime;
if (remainingTimeMs < 60000) {
// Less than a minute
formattedTime = 'Less than a minute';
} else if (remainingTimeMs < 3600000) {
// Less than an hour
const minutes = Math.round(remainingTimeMs / 60000);
formattedTime = `~${minutes} minute${minutes !== 1 ? 's' : ''}`;
} else {
// Hours and minutes
const hours = Math.floor(remainingTimeMs / 3600000);
const minutes = Math.round((remainingTimeMs % 3600000) / 60000);
formattedTime = `~${hours} hour${hours !== 1 ? 's' : ''} ${minutes} minute${minutes !== 1 ? 's' : ''}`;
}
document.getElementById('remainingTime').textContent = formattedTime + ' remaining';
}
/**
* Update status message
*/
updateStatusMessage(message) {
const progressStatus = document.getElementById('progressStatus');
if (progressStatus) {
progressStatus.textContent = message;
}
}
/**
* Update the progress bar and percentage
*/
updateProgress(progress) {
this.progress = progress;
const progressBar = document.getElementById('initProgressBar');
const progressPercentage = document.getElementById('progressPercentage');
if (progressBar && progressPercentage) {
progressBar.style.width = `${progress}%`;
progressPercentage.textContent = `${Math.round(progress)}%`;
}
}
/**
* Setup the tip carousel to rotate through tips
*/
setupTipCarousel() {
const tipItems = document.querySelectorAll('.tip-item');
if (tipItems.length === 0) return;
// Show the first tip
tipItems[0].classList.add('active');
document.querySelector('.tip-dot').classList.add('active');
// Set up automatic rotation
this.tipInterval = setInterval(() => {
this.showNextTip();
}, 8000); // Change tip every 8 seconds
}
/**
* Setup tip navigation dots
*/
setupTipNavigation() {
const tipDots = document.querySelectorAll('.tip-dot');
tipDots.forEach((dot, index) => {
dot.addEventListener('click', () => {
this.showTipByIndex(index);
});
});
}
/**
* Show the next tip in the carousel
*/
showNextTip() {
const tipItems = document.querySelectorAll('.tip-item');
const tipDots = document.querySelectorAll('.tip-dot');
if (tipItems.length === 0) return;
// Hide current tip
tipItems[this.currentTipIndex].classList.remove('active');
tipDots[this.currentTipIndex].classList.remove('active');
// Calculate next index
this.currentTipIndex = (this.currentTipIndex + 1) % tipItems.length;
// Show next tip
tipItems[this.currentTipIndex].classList.add('active');
tipDots[this.currentTipIndex].classList.add('active');
}
/**
* Show a specific tip by index
*/
showTipByIndex(index) {
const tipItems = document.querySelectorAll('.tip-item');
const tipDots = document.querySelectorAll('.tip-dot');
if (index >= tipItems.length || index < 0) return;
// Hide current tip
tipItems[this.currentTipIndex].classList.remove('active');
tipDots[this.currentTipIndex].classList.remove('active');
// Update index and show selected tip
this.currentTipIndex = index;
// Show selected tip
tipItems[this.currentTipIndex].classList.add('active');
tipDots[this.currentTipIndex].classList.add('active');
// Reset interval to prevent quick tip change
if (this.tipInterval) {
clearInterval(this.tipInterval);
this.tipInterval = setInterval(() => {
this.showNextTip();
}, 8000);
}
}
/**
* Show completion message
*/
showCompletionMessage() {
// Update progress to 100%
this.updateProgress(100);
// Update status message
this.updateStatusMessage('Initialization complete!');
// Update title and subtitle
const initTitle = document.getElementById('initTitle');
const initSubtitle = document.getElementById('initSubtitle');
const remainingTime = document.getElementById('remainingTime');
if (initTitle) {
initTitle.textContent = 'Initialization Complete';
}
if (initSubtitle) {
initSubtitle.textContent = 'Reloading page...';
}
if (remainingTime) {
remainingTime.textContent = 'Done!';
}
}
/**
* Clean up resources when the component is destroyed
*/
cleanup() {
if (this.tipInterval) {
clearInterval(this.tipInterval);
}
if (this.websocket && this.websocket.readyState === WebSocket.OPEN) {
this.websocket.close();
}
}
}
// Create and export the initialization manager
export const initManager = new InitializationManager();
// Initialize the component when the DOM is loaded
document.addEventListener('DOMContentLoaded', () => {
// Only initialize if we're in initialization mode
const initContainer = document.getElementById('initializationContainer');
if (initContainer) {
initManager.initialize();
}
});
// Clean up when the page is unloaded
window.addEventListener('beforeunload', () => {
initManager.cleanup();
});

View File

@@ -0,0 +1,102 @@
/**
* ModelDescription.js
* 处理LoRA模型描述相关的功能模块
*/
import { showToast } from '../../utils/uiHelpers.js';
/**
* 设置标签页切换功能
*/
export function setupTabSwitching() {
const tabButtons = document.querySelectorAll('.showcase-tabs .tab-btn');
tabButtons.forEach(button => {
button.addEventListener('click', () => {
// Remove active class from all tabs
document.querySelectorAll('.showcase-tabs .tab-btn').forEach(btn =>
btn.classList.remove('active')
);
document.querySelectorAll('.tab-content .tab-pane').forEach(tab =>
tab.classList.remove('active')
);
// Add active class to clicked tab
button.classList.add('active');
const tabId = `${button.dataset.tab}-tab`;
document.getElementById(tabId).classList.add('active');
// If switching to description tab, make sure content is properly sized
if (button.dataset.tab === 'description') {
const descriptionContent = document.querySelector('.model-description-content');
if (descriptionContent) {
const hasContent = descriptionContent.innerHTML.trim() !== '';
document.querySelector('.model-description-loading')?.classList.add('hidden');
// If no content, show a message
if (!hasContent) {
descriptionContent.innerHTML = '<div class="no-description">No model description available</div>';
descriptionContent.classList.remove('hidden');
}
}
}
});
});
}
/**
* 加载模型描述
* @param {string} modelId - 模型ID
* @param {string} filePath - 文件路径
*/
export async function loadModelDescription(modelId, filePath) {
try {
const descriptionContainer = document.querySelector('.model-description-content');
const loadingElement = document.querySelector('.model-description-loading');
if (!descriptionContainer || !loadingElement) return;
// Show loading indicator
loadingElement.classList.remove('hidden');
descriptionContainer.classList.add('hidden');
// Try to get model description from API
const response = await fetch(`/api/lora-model-description?model_id=${modelId}&file_path=${encodeURIComponent(filePath)}`);
if (!response.ok) {
throw new Error(`Failed to fetch model description: ${response.statusText}`);
}
const data = await response.json();
if (data.success && data.description) {
// Update the description content
descriptionContainer.innerHTML = data.description;
// Process any links in the description to open in new tab
const links = descriptionContainer.querySelectorAll('a');
links.forEach(link => {
link.setAttribute('target', '_blank');
link.setAttribute('rel', 'noopener noreferrer');
});
// Show the description and hide loading indicator
descriptionContainer.classList.remove('hidden');
loadingElement.classList.add('hidden');
} else {
throw new Error(data.error || 'No description available');
}
} catch (error) {
console.error('Error loading model description:', error);
const loadingElement = document.querySelector('.model-description-loading');
if (loadingElement) {
loadingElement.innerHTML = `<div class="error-message">Failed to load model description. ${error.message}</div>`;
}
// Show empty state message in the description container
const descriptionContainer = document.querySelector('.model-description-content');
if (descriptionContainer) {
descriptionContainer.innerHTML = '<div class="no-description">No model description available</div>';
descriptionContainer.classList.remove('hidden');
}
}
}

View File

@@ -0,0 +1,474 @@
/**
* ModelMetadata.js
* 处理LoRA模型元数据编辑相关的功能模块
*/
import { showToast } from '../../utils/uiHelpers.js';
import { BASE_MODELS } from '../../utils/constants.js';
import { updateLoraCard } from '../../utils/cardUpdater.js';
/**
* 保存模型元数据到服务器
* @param {string} filePath - 文件路径
* @param {Object} data - 要保存的数据
* @returns {Promise} 保存操作的Promise
*/
export async function saveModelMetadata(filePath, data) {
const response = await fetch('/api/loras/save-metadata', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath,
...data
})
});
if (!response.ok) {
throw new Error('Failed to save metadata');
}
return response.json();
}
/**
* 设置模型名称编辑功能
* @param {string} filePath - 文件路径
*/
export function setupModelNameEditing(filePath) {
const modelNameContent = document.querySelector('.model-name-content');
const editBtn = document.querySelector('.edit-model-name-btn');
if (!modelNameContent || !editBtn) return;
// Store the file path in a data attribute for later use
modelNameContent.dataset.filePath = filePath;
// Show edit button on hover
const modelNameHeader = document.querySelector('.model-name-header');
modelNameHeader.addEventListener('mouseenter', () => {
editBtn.classList.add('visible');
});
modelNameHeader.addEventListener('mouseleave', () => {
if (!modelNameContent.getAttribute('data-editing')) {
editBtn.classList.remove('visible');
}
});
// Handle edit button click
editBtn.addEventListener('click', () => {
modelNameContent.setAttribute('data-editing', 'true');
modelNameContent.focus();
// Place cursor at the end
const range = document.createRange();
const sel = window.getSelection();
if (modelNameContent.childNodes.length > 0) {
range.setStart(modelNameContent.childNodes[0], modelNameContent.textContent.length);
range.collapse(true);
sel.removeAllRanges();
sel.addRange(range);
}
editBtn.classList.add('visible');
});
// Handle focus out
modelNameContent.addEventListener('blur', function() {
this.removeAttribute('data-editing');
editBtn.classList.remove('visible');
if (this.textContent.trim() === '') {
// Restore original model name if empty
const filePath = this.dataset.filePath;
const loraCard = document.querySelector(`.lora-card[data-filepath="${filePath}"]`);
if (loraCard) {
this.textContent = loraCard.dataset.model_name;
}
}
});
// Handle enter key
modelNameContent.addEventListener('keydown', function(e) {
if (e.key === 'Enter') {
e.preventDefault();
const filePath = this.dataset.filePath;
saveModelName(filePath);
this.blur();
}
});
// Limit model name length
modelNameContent.addEventListener('input', function() {
// Limit model name length
if (this.textContent.length > 100) {
this.textContent = this.textContent.substring(0, 100);
// Place cursor at the end
const range = document.createRange();
const sel = window.getSelection();
range.setStart(this.childNodes[0], 100);
range.collapse(true);
sel.removeAllRanges();
sel.addRange(range);
showToast('Model name is limited to 100 characters', 'warning');
}
});
}
/**
* 保存模型名称
* @param {string} filePath - 文件路径
*/
async function saveModelName(filePath) {
const modelNameElement = document.querySelector('.model-name-content');
const newModelName = modelNameElement.textContent.trim();
// Validate model name
if (!newModelName) {
showToast('Model name cannot be empty', 'error');
return;
}
// Check if model name is too long (limit to 100 characters)
if (newModelName.length > 100) {
showToast('Model name is too long (maximum 100 characters)', 'error');
// Truncate the displayed text
modelNameElement.textContent = newModelName.substring(0, 100);
return;
}
try {
await saveModelMetadata(filePath, { model_name: newModelName });
// Update the corresponding lora card's dataset and display
updateLoraCard(filePath, { model_name: newModelName });
showToast('Model name updated successfully', 'success');
} catch (error) {
showToast('Failed to update model name', 'error');
}
}
/**
* 设置基础模型编辑功能
* @param {string} filePath - 文件路径
*/
export function setupBaseModelEditing(filePath) {
const baseModelContent = document.querySelector('.base-model-content');
const editBtn = document.querySelector('.edit-base-model-btn');
if (!baseModelContent || !editBtn) return;
// Store the file path in a data attribute for later use
baseModelContent.dataset.filePath = filePath;
// Show edit button on hover
const baseModelDisplay = document.querySelector('.base-model-display');
baseModelDisplay.addEventListener('mouseenter', () => {
editBtn.classList.add('visible');
});
baseModelDisplay.addEventListener('mouseleave', () => {
if (!baseModelDisplay.classList.contains('editing')) {
editBtn.classList.remove('visible');
}
});
// Handle edit button click
editBtn.addEventListener('click', () => {
baseModelDisplay.classList.add('editing');
// Store the original value to check for changes later
const originalValue = baseModelContent.textContent.trim();
// Create dropdown selector to replace the base model content
const currentValue = originalValue;
const dropdown = document.createElement('select');
dropdown.className = 'base-model-selector';
// Flag to track if a change was made
let valueChanged = false;
// Add options from BASE_MODELS constants
const baseModelCategories = {
'Stable Diffusion 1.x': [BASE_MODELS.SD_1_4, BASE_MODELS.SD_1_5, BASE_MODELS.SD_1_5_LCM, BASE_MODELS.SD_1_5_HYPER],
'Stable Diffusion 2.x': [BASE_MODELS.SD_2_0, BASE_MODELS.SD_2_1],
'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO],
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
'Video Models': [BASE_MODELS.SVD, BASE_MODELS.WAN_VIDEO, BASE_MODELS.HUNYUAN_VIDEO],
'Other Models': [
BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.AURAFLOW,
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.UNKNOWN
]
};
// Create option groups for better organization
Object.entries(baseModelCategories).forEach(([category, models]) => {
const group = document.createElement('optgroup');
group.label = category;
models.forEach(model => {
const option = document.createElement('option');
option.value = model;
option.textContent = model;
option.selected = model === currentValue;
group.appendChild(option);
});
dropdown.appendChild(group);
});
// Replace content with dropdown
baseModelContent.style.display = 'none';
baseModelDisplay.insertBefore(dropdown, editBtn);
// Hide edit button during editing
editBtn.style.display = 'none';
// Focus the dropdown
dropdown.focus();
// Handle dropdown change
dropdown.addEventListener('change', function() {
const selectedModel = this.value;
baseModelContent.textContent = selectedModel;
// Mark that a change was made if the value differs from original
if (selectedModel !== originalValue) {
valueChanged = true;
} else {
valueChanged = false;
}
});
// Function to save changes and exit edit mode
const saveAndExit = function() {
// Check if dropdown still exists and remove it
if (dropdown && dropdown.parentNode === baseModelDisplay) {
baseModelDisplay.removeChild(dropdown);
}
// Show the content and edit button
baseModelContent.style.display = '';
editBtn.style.display = '';
// Remove editing class
baseModelDisplay.classList.remove('editing');
// Only save if the value has actually changed
if (valueChanged || baseModelContent.textContent.trim() !== originalValue) {
// Get file path from the dataset
const filePath = baseModelContent.dataset.filePath;
// Save the changes, passing the original value for comparison
saveBaseModel(filePath, originalValue);
}
// Remove this event listener
document.removeEventListener('click', outsideClickHandler);
};
// Handle outside clicks to save and exit
const outsideClickHandler = function(e) {
// If click is outside the dropdown and base model display
if (!baseModelDisplay.contains(e.target)) {
saveAndExit();
}
};
// Add delayed event listener for outside clicks
setTimeout(() => {
document.addEventListener('click', outsideClickHandler);
}, 0);
// Also handle dropdown blur event
dropdown.addEventListener('blur', function(e) {
// Only save if the related target is not the edit button or inside the baseModelDisplay
if (!baseModelDisplay.contains(e.relatedTarget)) {
saveAndExit();
}
});
});
}
/**
* 保存基础模型
* @param {string} filePath - 文件路径
* @param {string} originalValue - 原始值(用于比较)
*/
async function saveBaseModel(filePath, originalValue) {
const baseModelElement = document.querySelector('.base-model-content');
const newBaseModel = baseModelElement.textContent.trim();
// Only save if the value has actually changed
if (newBaseModel === originalValue) {
return; // No change, no need to save
}
try {
await saveModelMetadata(filePath, { base_model: newBaseModel });
// Update the corresponding lora card's dataset
updateLoraCard(filePath, { base_model: newBaseModel });
showToast('Base model updated successfully', 'success');
} catch (error) {
showToast('Failed to update base model', 'error');
}
}
/**
* 设置文件名编辑功能
* @param {string} filePath - 文件路径
*/
export function setupFileNameEditing(filePath) {
const fileNameContent = document.querySelector('.file-name-content');
const editBtn = document.querySelector('.edit-file-name-btn');
if (!fileNameContent || !editBtn) return;
// Store the original file path
fileNameContent.dataset.filePath = filePath;
// Show edit button on hover
const fileNameWrapper = document.querySelector('.file-name-wrapper');
fileNameWrapper.addEventListener('mouseenter', () => {
editBtn.classList.add('visible');
});
fileNameWrapper.addEventListener('mouseleave', () => {
if (!fileNameWrapper.classList.contains('editing')) {
editBtn.classList.remove('visible');
}
});
// Handle edit button click
editBtn.addEventListener('click', () => {
fileNameWrapper.classList.add('editing');
fileNameContent.setAttribute('contenteditable', 'true');
fileNameContent.focus();
// Store original value for comparison later
fileNameContent.dataset.originalValue = fileNameContent.textContent.trim();
// Place cursor at the end
const range = document.createRange();
const sel = window.getSelection();
range.selectNodeContents(fileNameContent);
range.collapse(false);
sel.removeAllRanges();
sel.addRange(range);
editBtn.classList.add('visible');
});
// Handle keyboard events in edit mode
fileNameContent.addEventListener('keydown', function(e) {
if (!this.getAttribute('contenteditable')) return;
if (e.key === 'Enter') {
e.preventDefault();
this.blur(); // Trigger save on Enter
} else if (e.key === 'Escape') {
e.preventDefault();
// Restore original value
this.textContent = this.dataset.originalValue;
exitEditMode();
}
});
// Handle input validation
fileNameContent.addEventListener('input', function() {
if (!this.getAttribute('contenteditable')) return;
// Replace invalid characters for filenames
const invalidChars = /[\\/:*?"<>|]/g;
if (invalidChars.test(this.textContent)) {
const cursorPos = window.getSelection().getRangeAt(0).startOffset;
this.textContent = this.textContent.replace(invalidChars, '');
// Restore cursor position
const range = document.createRange();
const sel = window.getSelection();
const newPos = Math.min(cursorPos, this.textContent.length);
if (this.firstChild) {
range.setStart(this.firstChild, newPos);
range.collapse(true);
sel.removeAllRanges();
sel.addRange(range);
}
showToast('Invalid characters removed from filename', 'warning');
}
});
// Handle focus out - save changes
fileNameContent.addEventListener('blur', async function() {
if (!this.getAttribute('contenteditable')) return;
const newFileName = this.textContent.trim();
const originalValue = this.dataset.originalValue;
// Basic validation
if (!newFileName) {
// Restore original value if empty
this.textContent = originalValue;
showToast('File name cannot be empty', 'error');
exitEditMode();
return;
}
if (newFileName === originalValue) {
// No changes, just exit edit mode
exitEditMode();
return;
}
try {
// Get the file path from the dataset
const filePath = this.dataset.filePath;
// Call API to rename the file
const response = await fetch('/api/rename_lora', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
file_path: filePath,
new_file_name: newFileName
})
});
const result = await response.json();
if (result.success) {
showToast('File name updated successfully', 'success');
// Get the new file path and update the card
const newFilePath = filePath.replace(originalValue, newFileName);
// Pass the new file_name in the updates object for proper card update
updateLoraCard(filePath, { file_name: newFileName }, newFilePath);
} else {
throw new Error(result.error || 'Unknown error');
}
} catch (error) {
console.error('Error renaming file:', error);
this.textContent = originalValue; // Restore original file name
showToast(`Failed to rename file: ${error.message}`, 'error');
} finally {
exitEditMode();
}
});
function exitEditMode() {
fileNameContent.removeAttribute('contenteditable');
fileNameWrapper.classList.remove('editing');
editBtn.classList.remove('visible');
}
}

Some files were not shown because too many files have changed in this diff Show More