Compare commits

...

442 Commits

Author SHA1 Message Date
Will Miao
6c5559ae2d chore: Update version to 0.8.29 and add release notes for enhanced recipe imports and bug fixes 2025-08-21 08:44:07 +08:00
Will Miao
9f54622b17 fix: Improve author retrieval logic in calculate_relative_path_for_model function to handle missing creator data 2025-08-21 07:34:54 +08:00
Will Miao
03b6f4b378 refactor: Clean up and optimize import modal and related components, removing unused styles and improving path selection functionality 2025-08-20 23:12:38 +08:00
Will Miao
af4cbe2332 feat: Add LoraManagerTextLoader for loading LoRAs from text syntax with enhanced parsing 2025-08-20 18:16:29 +08:00
Will Miao
141f72963a fix: Enhance download functionality with resumable downloads and improved error handling 2025-08-20 16:40:22 +08:00
Will Miao
3d3c66e12f fix: Improve widget handling in lora_loader, lora_stacker, and wanvideo_lora_select, and ensuring expanded state preservation in loras_widget 2025-08-19 22:31:11 +08:00
Will Miao
ee84571bdb refactor: Simplify handling of base model path mappings and download path templates by removing unnecessary JSON.stringify calls 2025-08-19 20:20:30 +08:00
Will Miao
6500936aad refactor: Remove unused DataWrapper class to clean up utils.js 2025-08-19 20:19:58 +08:00
Will Miao
32d2b6c013 fix: disable pysssss autocomplete in Lora-related nodes
Disable PySSSS autocomplete functionality in:
- Lora Loader
- Lora Stacker
- WanVideo Lora Select node
2025-08-19 08:54:12 +08:00
Will Miao
05df40977d refactor: Update chunk size to 4MB for improved HDD throughput and optimize file writing during downloads 2025-08-18 17:21:24 +08:00
Will Miao
5d7a1dcde5 refactor: Comment out duplicate filename logging in ModelScanner for cleaner cache build process, fixes #365 2025-08-18 16:46:16 +08:00
Will Miao
9c45d9db6c feat: Enhance WanVideoLoraSelect with improved low_mem_load and merge_loras options for better LORA management, see #363 2025-08-18 15:05:57 +08:00
Will Miao
ca692ed0f2 feat: Update release notes and version to v0.8.28 with new features and enhancements 2025-08-18 07:14:08 +08:00
Will Miao
af499565d3 Revert "feat: Add CheckpointLoaderSimpleExtended to NODE_EXTRACTORS for enhanced checkpoint loading"
This reverts commit fe2d7e3a9e.
2025-08-17 22:43:15 +08:00
Will Miao
fe2d7e3a9e feat: Add CheckpointLoaderSimpleExtended to NODE_EXTRACTORS for enhanced checkpoint loading 2025-08-17 21:16:27 +08:00
Will Miao
9f69822221 feat: Refactor SamplerCustom handling and enhance node extractor mappings for improved metadata processing 2025-08-17 20:42:52 +08:00
Will Miao
bb43f047c2 feat: Add auto-organize progress tracking and WebSocket broadcasting in BaseModelRoutes and WebSocketManager 2025-08-16 21:11:33 +08:00
Will Miao
2356662492 fix: Improve author retrieval logic in DownloadManager to handle non-dictionary creator data 2025-08-16 21:10:57 +08:00
Will Miao
1624a45093 fix: Update author retrieval to handle missing username gracefully in DownloadManager and utils 2025-08-16 16:11:56 +08:00
Will Miao
dcb9983786 feat: Refactor duplicates management with user preference for notification visibility and modular banner component, fixes #359 2025-08-16 09:14:35 +08:00
Will Miao
83d1828905 feat: Enhance text cleanup in LoraLoader, LoraStacker, and WanVideoLoraSelect to handle extra commas and trailing commas 2025-08-16 08:31:04 +08:00
Will Miao
6a281cf3ee feat: Implement autocomplete feature with enhanced UI and tooltip support
- Added AutoComplete class to handle input suggestions based on user input.
- Integrated TextAreaCaretHelper for accurate positioning of the dropdown.
- Enhanced dropdown styling with a new color scheme and custom scrollbar.
- Implemented dynamic loading of preview tooltips for selected items.
- Added keyboard navigation support for dropdown items.
- Included functionality to insert selected items into the input field with usage tips.
- Created a separate TextAreaCaretHelper module for managing caret position calculations.
2025-08-16 07:53:55 +08:00
Will Miao
ed1cd39a6c feat: add model notes, preview URL, and Civitai URL endpoints to BaseModelRoutes and BaseModelService 2025-08-15 18:58:49 +08:00
Will Miao
dda19b3920 feat: add download example images functionality to context menus, see #347 2025-08-15 17:15:31 +08:00
Will Miao
25139ca922 feat: enhance bulk operations panel styling and update downloadExampleImages method to accept optional modelTypes parameter 2025-08-15 15:58:33 +08:00
Will Miao
3cd57a582c feat: add force download functionality for example images with progress tracking 2025-08-15 15:16:12 +08:00
Will Miao
d3903ac655 feat: add success toast notification after metadata update completion 2025-08-15 09:43:16 +08:00
Will Miao
199e374318 feat: update release notes for v0.8.27 and bump version to 0.8.27 2025-08-14 07:32:09 +08:00
pixelpaws
8375c1413d Merge pull request #354 from Clusters/main
feat: Add qwen-image as a selectable base model
2025-08-14 07:14:27 +08:00
Andreas
9e268cf016 Merge branch 'willmiao:main' into main 2025-08-13 17:51:10 +02:00
Andreas
112b3abc26 feat: add qwen-image base model to ModelMetadata 2025-08-13 15:49:30 +00:00
Andreas
a8331a2357 feat: add qwen-image to base model constants.js 2025-08-13 15:48:10 +00:00
Will Miao
52e3ad08c1 feat: add placeholder for empty folder tree in download modal 2025-08-13 23:45:37 +08:00
Will Miao
8d01d04ef0 feat: add default path toggle and update download modal for improved path selection 2025-08-13 23:30:48 +08:00
Will Miao
a141384907 feat: update default path customization image for improved clarity 2025-08-13 20:15:11 +08:00
Will Miao
b8aa7184bd feat: update download path template handling for model types and migrate old settings 2025-08-13 19:23:37 +08:00
Will Miao
e4195f874d feat: implement download path templates configuration with support for multiple model types and custom templates 2025-08-13 17:42:40 +08:00
Will Miao
d04deff5ca feat: enhance download and move modals with improved folder path input, autocomplete, and folder tree integration 2025-08-13 14:41:21 +08:00
Will Miao
20ce0778a0 fix: correct default root key generation by using singular model type 2025-08-13 11:06:39 +08:00
Will Miao
5a0b3470f1 feat: enhance auto-organize functionality with empty directory cleanup and progress reporting 2025-08-13 10:36:31 +08:00
Will Miao
a920921570 feat: implement auto-organize models endpoint with batch processing and error handling 2025-08-12 22:39:40 +08:00
Will Miao
286f4ff384 feat: add folder tree and unified folder tree endpoints, enhance download modal with folder path input and tree navigation 2025-08-12 22:34:53 +08:00
Will Miao
71ddfafa98 refactor: move download modal styles to a dedicated file and update import path 2025-08-12 21:40:43 +08:00
Will Miao
b7e3e53697 feat: implement version mismatch handling and banner registration in UpdateService 2025-08-12 15:09:45 +08:00
Will Miao
16df548b77 fix: expand supported file extensions in CheckpointScanner initialization, fixes #353 2025-08-12 09:20:08 +08:00
Will Miao
425c33ae00 fix: update model identifier handling in RecipeModal and DownloadManager for consistency 2025-08-11 17:13:42 +08:00
Will Miao
c9289ed2dc fix: improve duplicate filename handling and logging in ModelScanner and ModelHashIndex 2025-08-11 17:13:21 +08:00
Will Miao
96517cbdef fix: update model_id and model_version_id handling across various services for improved flexibility 2025-08-11 15:31:49 +08:00
Will Miao
b03420faac fix: skip LoRAs without proper identification in Civitai metadata parser 2025-08-11 11:14:45 +08:00
Will Miao
65a1aa7ca2 fix: add missing embeddings folder paths in settings example 2025-08-11 07:05:58 +08:00
pixelpaws
3a92e8eaf9 Update README.md 2025-08-10 16:11:28 +08:00
Will Miao
a8dc50d64a fix: update portable package link to version 0.8.26 in README 2025-08-10 16:05:50 +08:00
Will Miao
3397cc7d8d fix: update screenshot image to reflect latest UI changes 2025-08-10 09:02:46 +08:00
Will Miao
c3e8131b24 feat: enhance download manager to track failed models and update progress reporting 2025-08-10 08:07:52 +08:00
Will Miao
f8ca8584ae feat: enhance URL safety in path mapping by encoding special characters 2025-08-09 16:25:55 +08:00
Will Miao
3050bbe260 fix: improve image handling logic to ensure input is always a list or array, see #346 2025-08-09 07:20:28 +08:00
Will Miao
e1dda2795a update README.md 2025-08-08 20:13:20 +08:00
Will Miao
6d8408e626 feat: update release notes and version to 0.8.26, adding creator search and enhancing node usability 2025-08-08 20:10:06 +08:00
Will Miao
0906271aa9 refactor: simplify auto download check logic by removing unnecessary progress updates 2025-08-08 19:58:20 +08:00
Will Miao
4c33c9d256 feat: enhance folder update logic with error handling in fetchModelsPage 2025-08-08 17:33:11 +08:00
Will Miao
fa9c78209f feat: update API endpoints to include '/list' for model retrieval in routes and templates, fixes #344 2025-08-07 18:06:40 +08:00
Will Miao
6678ec8a60 refactor: remove unused height properties and simplify widget height handling in various components, fixes #284 2025-08-07 16:49:39 +08:00
Will Miao
854e467c12 feat: add debug logging for default root settings in DownloadManager 2025-08-07 14:42:05 +08:00
Will Miao
e6b94c7b21 refactor: remove unused import and simplify filename handling in ModelHashIndex, fixes #342 2025-08-06 19:11:07 +08:00
Will Miao
2c6f9d8602 feat: add creator search option and update related functionality across models and UI 2025-08-06 18:32:57 +08:00
Will Miao
c74033b9c0 refactor: conditionally initialize managers in HeaderManager to avoid unnecessary setup on statistics page 2025-08-06 11:14:02 +08:00
Will Miao
d2b21d27bb refactor: remove unused imports from update_routes.py and requirements.txt 2025-08-06 10:34:40 +08:00
Will Miao
215272469f refactor: replace model API client import and remove performance logging, add reset and reload functionality 2025-08-06 07:56:48 +08:00
Will Miao
f7d05ab0f1 refactor: change logging level from info to debug for download progress messages 2025-08-06 06:44:35 +08:00
Will Miao
6f2ad2be77 fix: update LoRA model type check to use constant for improved readability, fixes #341 2025-08-05 19:11:28 +08:00
Will Miao
66575c719a feat: update version to 0.8.25, add release notes for v0.8.25 including LoRA list reordering, bulk operations, and auto download setting for example images 2025-08-05 18:30:06 +08:00
Will Miao
677a239d53 feat: add setting to include trigger words in LoRA syntax, update UI and functionality, fixes #268 2025-08-05 18:04:10 +08:00
Will Miao
3b96bfe5af feat: add auto download setting for example images with UI toggle and functionality, fixes #288 2025-08-05 16:49:46 +08:00
Will Miao
83be5cfa64 feat: enhance plugin update process by adding .tracking file for extracted files 2025-08-05 15:46:57 +08:00
Will Miao
6b834c2362 Add wiki image 2025-08-05 13:00:10 +08:00
Will Miao
7abfc49e08 feat: implement bulk operations for model management including delete, move, and refresh functionalities 2025-08-05 11:23:20 +08:00
Will Miao
65d5f50088 feat: add LoRA extraction and Civitai info population in CivitaiApiMetadataParser (#307) 2025-08-05 09:29:54 +08:00
Will Miao
4f1f4ffe3d feat: remove unused image download functions and dependencies for cleaner code 2025-08-05 09:09:17 +08:00
Will Miao
b0c2027a1c feat: add path validation for model folder in ExampleImagesFileManager 2025-08-05 07:35:19 +08:00
Will Miao
33c83358b0 feat: streamline Git information retrieval using GitPython for improved accuracy and performance 2025-08-05 07:28:08 +08:00
Will Miao
31223f0526 feat: enhance model root fetching and moving functionality across various components 2025-08-04 23:37:27 +08:00
Will Miao
92daadb92c feat: add endpoints for retrieving checkpoints and unet roots in CheckpointApiClient 2025-08-04 22:23:43 +08:00
Will Miao
fae2e274fd feat: enable move operations for all model types and remove unsupported methods from specific clients 2025-08-04 19:51:02 +08:00
Will Miao
342a722991 feat: refactor model API structure to support specific model types with dedicated API clients for Checkpoints, LoRAs, and Embeddings
refactor: consolidate model API client creation into a factory function for better maintainability
feat: implement move operations for LoRAs and handle unsupported operations for Checkpoints and Embeddings
2025-08-04 19:37:53 +08:00
Will Miao
65ec6aacb7 feat: add model moving endpoints for individual and bulk operations 2025-08-04 18:15:03 +08:00
Will Miao
9387470c69 feat: add endpoints for retrieving checkpoint and unet roots from config 2025-08-04 17:40:19 +08:00
Will Miao
31f6edf8f0 feat: enhance responsiveness of header container for larger screens 2025-08-04 17:19:04 +08:00
Will Miao
487b062175 refactor: simplify API endpoint construction in FilterManager for top tags and base models 2025-08-04 17:06:54 +08:00
Will Miao
d8e13de096 feat: enhance metadata adjustment in CheckpointScanner and ModelScanner for improved model type handling 2025-08-04 17:06:46 +08:00
Will Miao
e8a30088ef refactor: streamline model scanning by removing redundant file processing method and enhancing directory scanning logic 2025-08-04 15:49:50 +08:00
Will Miao
bf7b07ba74 feat: deduplicate and merge checkpoint and unet paths in configuration. See #338 and #312 2025-08-04 10:48:48 +08:00
Will Miao
28fe3e7b7a chore: update version to 0.8.24 in pyproject.toml 2025-08-02 16:23:19 +08:00
Will Miao
c0eff2bb5e feat: enhance async metadata collection by updating function signature and preserving all parameters. Fixes #328 #327 2025-08-01 21:47:52 +08:00
Will Miao
848c1741fe feat: add parsing for 'air' field in Civitai resources to enhance metadata extraction. Fixes #322 2025-07-31 14:15:22 +08:00
Will Miao
1370b8e8c1 feat: implement drag-and-drop reordering for LoRA entries and enhance keyboard navigation. Fixes #302 2025-07-30 15:32:31 +08:00
Will Miao
82a068e610 feat: auto set default root paths for loras, checkpoints, and embeddings in settings 2025-07-30 10:08:21 +08:00
Will Miao
32f42bafaa chore: update version to 0.8.23 in pyproject.toml 2025-07-29 20:30:45 +08:00
Will Miao
4081b7f022 feat: implement settings synchronization with backend and migrate legacy settings 2025-07-29 20:29:19 +08:00
Will Miao
a5808193a6 fix: rename URL error element ID to 'importUrlError' for consistency across components 2025-07-29 16:13:27 +08:00
Will Miao
854ca322c1 fix: update short_hash in git_info to 'stable' in update_routes.py 2025-07-29 08:34:41 +08:00
Will Miao
c1d9b5137a feat: add version name display to model cards in ModelCard.js and style it in card.css. Fixes #287 2025-07-28 16:36:23 +08:00
Will Miao
f33d5745b3 feat: enhance model description editing functionality in ModelDescription.js and integrate with ModelModal.js. Fixes #292 2025-07-28 11:52:04 +08:00
Will Miao
d89c2ca128 chore: Update version to 0.8.22 in pyproject.toml 2025-07-27 21:20:35 +08:00
Will Miao
835584cc85 fix: update restart message for ComfyUI and LoRA Manager after successful update 2025-07-27 21:20:09 +08:00
Will Miao
b2ffbe3a68 feat: implement fallback ZIP download for plugin updates when .git is missing 2025-07-27 20:56:51 +08:00
Will Miao
defcc79e6c feat: add release notes for v0.8.22 2025-07-27 20:34:46 +08:00
Will Miao
c06d9f84f0 fix: disable pointer events on video element in model card preview 2025-07-27 20:02:21 +08:00
Will Miao
fe57a8e156 feat: implement banner service for managing notification banners, including UI integration and storage handling 2025-07-27 18:07:43 +08:00
Will Miao
b77105795a feat: add embedding support in statistics page, including data handling and UI updates 2025-07-27 16:36:14 +08:00
Will Miao
e2df5fcf27 feat: add default embedding root setting and load functionality in settings manager 2025-07-27 15:58:15 +08:00
Will Miao
836a64e728 refactor: enhance bulk metadata refresh functionality and update UI components 2025-07-26 23:45:57 +08:00
Will Miao
08ba0c9f42 refactor: remove one-click integration image from README 2025-07-26 09:55:06 +08:00
Will Miao
6fcc6a5299 Update README.md 2025-07-26 09:53:19 +08:00
Will Miao
6dd58248c6 refactor: add embedding scanner support in download manager and example images processor 2025-07-26 07:35:53 +08:00
pixelpaws
2786801b71 Merge pull request #317 from willmiao/refactor
Refactor
2025-07-26 07:06:37 +08:00
Will Miao
ea29cbeb7a refactor: add synchronous service retrieval method to ServiceRegistry 2025-07-26 07:05:27 +08:00
Will Miao
3cf9121a8c refactor: enhance scanner handling and add embedding support in download manager and misc routes 2025-07-25 23:59:27 +08:00
Will Miao
381bd3938a refactor: rename 'lora-card' to 'model-card' across styles and scripts for consistency 2025-07-25 23:23:57 +08:00
Will Miao
e4ce384023 feat: implement embeddings functionality with context menus, controls, and page management 2025-07-25 23:15:33 +08:00
Will Miao
12d1857b13 refactor: update model type references from 'lora' to 'loras' and streamline event delegation setup 2025-07-25 22:33:46 +08:00
Will Miao
0d9003dea4 refactor: remove legacy card components and update imports to use shared ModelCard component 2025-07-25 22:00:38 +08:00
Will Miao
1a3751acfa refactor: unify API client usage across models and remove deprecated API files 2025-07-25 21:38:54 +08:00
Will Miao
c5a3af2399 feat: add embedding management functionality with routes, services, and UI integration 2025-07-25 21:14:56 +08:00
Will Miao
ea8a64fafc refactor: remove unused get_models method from LoraRoutes 2025-07-25 18:23:52 +08:00
Will Miao
981e367bf1 refactor: remove unused API and page routes from routes.js 2025-07-25 17:51:40 +08:00
Will Miao
a3d6e62035 refactor: centralize resetAndReload functionality in baseModelApi 2025-07-25 17:48:02 +08:00
Will Miao
7f205cdcc8 refactor: unify model download system across all model types
- Add download-related methods to baseModelApi.js for fetching versions, roots, folders, and downloading models
- Replace separate download managers with a unified DownloadManager.js supporting all model types
- Create a single download_modals.html template that adapts to model type (LoRA, checkpoint, etc.)
- Remove old download modals from lora_modals.html and checkpoint_modals.html
- Update apiConfig.js to include civitaiVersions endpoints for each model type
- Centralize event handler binding in DownloadManager.js (no more inline HTML handlers)
- Modal UI and logic now auto-adapt to the current model type, making future extension easier
2025-07-25 17:35:06 +08:00
Will Miao
e587189880 Refactor modal.css into modular components 2025-07-25 16:36:07 +08:00
Will Miao
206c1bd69f Refactor modals.html into modular components 2025-07-25 16:10:16 +08:00
Will Miao
a7d9255c2c refactor: Replace direct model metadata API calls with unified model API client 2025-07-25 15:35:16 +08:00
Will Miao
08265a85ec refactor: Include new file path in response after moving model 2025-07-25 15:10:03 +08:00
Will Miao
1ed5630464 Merge branch 'refactor' of https://github.com/willmiao/ComfyUI-Lora-Manager into refactor 2025-07-25 14:49:30 +08:00
Will Miao
c784615f11 refactor: Simplify API calls and enhance model moving functionality 2025-07-25 14:48:28 +08:00
Will Miao
26d51b1190 refactor: Simplify API calls and enhance model moving functionality 2025-07-25 14:44:05 +08:00
Will Miao
d83fad6abc Refactor API structure to unify model operations
- Introduced MODEL_TYPES and MODEL_CONFIG for centralized model type management.
- Created a unified API client for checkpoints and loras to streamline operations.
- Updated all API calls in checkpointApi.js and loraApi.js to use the new client.
- Simplified context menus and model card operations to leverage the unified API client.
- Enhanced state management to accommodate new model types and their configurations.
- Added virtual scrolling functions for recipes and improved loading states.
- Refactored modal utilities to handle model exclusion and deletion generically.
- Improved error handling and user feedback across various operations.
2025-07-25 10:04:18 +08:00
Will Miao
692796db46 refactor: Update API endpoints to include 'loras' prefix for consistency 2025-07-24 19:56:55 +08:00
pixelpaws
f15c6f33f9 Merge pull request #313 from willmiao/refactor
Refactor
2025-07-24 19:37:17 +08:00
Will Miao
dda9eb4d7c refactor: Remove MessagePack dependency and related cache management code 2025-07-24 19:30:47 +08:00
Will Miao
6f3aeb61e7 feat: Implement Git-based update functionality with nightly mode support and UI enhancements 2025-07-24 19:03:52 +08:00
Will Miao
d6145e633f refactor: Simplify cache resort calls in model metadata updates and API routes 2025-07-24 10:47:19 +08:00
Will Miao
07014d98ce refactor: Enhance logging configuration by adding a filter for non-critical connection reset errors 2025-07-24 09:47:51 +08:00
Will Miao
e8ccdabe6c refactor: Enhance sorting functionality and UI for model selection, including legacy format conversion 2025-07-24 09:26:15 +08:00
Will Miao
cf9fd2d5c2 refactor: Rename LoraScanner methods for consistency and remove deprecated checkpoint methods 2025-07-24 06:25:33 +08:00
Will Miao
bf9aa9356b refactor: Update model retrieval methods in RecipeRoutes and streamline CheckpointScanner and LoraScanner initialization 2025-07-23 23:27:18 +08:00
Will Miao
68d00ce289 refactor: Adjust logging configuration to reduce verbosity for asyncio logger 2025-07-23 22:58:40 +08:00
Will Miao
5288021e4f refactor: Simplify filtering methods and enhance CJK character handling in LoraService 2025-07-23 22:55:42 +08:00
Will Miao
4d38add291 Revert "refactor: Update logging configuration to use asyncio logger and remove aiohttp access logger references"
This reverts commit 804808da4a.
2025-07-23 22:23:48 +08:00
Will Miao
804808da4a refactor: Update logging configuration to use asyncio logger and remove aiohttp access logger references 2025-07-23 22:09:42 +08:00
Will Miao
298a95432d feat: Integrate WebSocket routes for download progress tracking in standalone manager 2025-07-23 18:02:38 +08:00
Will Miao
a834fc4b30 feat: Update API routes for LoRA management and enhance folder handling 2025-07-23 17:26:06 +08:00
Will Miao
2c6c9542dd refactor: Change logging level from info to debug for service registration 2025-07-23 16:59:16 +08:00
Will Miao
a9a7f4c8ec refactor: Remove legacy API route handlers from standalone manager 2025-07-23 16:30:00 +08:00
Will Miao
ea9370443d refactor: Implement download management routes and update API endpoints for LoRA 2025-07-23 16:11:02 +08:00
Will Miao
c2e00b240e feat: Enhance model routes with generic page handling and template integration 2025-07-23 15:30:39 +08:00
Will Miao
a2b81ea099 refactor: Implement base model routes and services for LoRA and Checkpoint
- Added BaseModelRoutes class to handle common routes and logic for model types.
- Created CheckpointRoutes class inheriting from BaseModelRoutes for checkpoint-specific routes.
- Implemented CheckpointService class for handling checkpoint-related data and operations.
- Developed LoraService class for managing LoRA-specific functionalities.
- Introduced ModelServiceFactory to manage service and route registrations for different model types.
- Established methods for fetching, filtering, and formatting model data across services.
- Integrated CivitAI metadata handling within model routes and services.
- Added pagination and filtering capabilities for model data retrieval.
2025-07-23 14:39:02 +08:00
Will Miao
ee609e8eac Revert "feat: Implement check for missing creator in model metadata"
This reverts commit 0184dfd7eb.
2025-07-23 06:33:00 +08:00
Will Miao
e04ef671e9 feat: Update metadata handling to use current timestamp for model modifications 2025-07-22 22:56:45 +08:00
Will Miao
0184dfd7eb feat: Implement check for missing creator in model metadata 2025-07-22 20:14:39 +08:00
Will Miao
eccfa0ca54 feat: Add keyboard shortcuts for bulk operations and enhance shortcut key styling 2025-07-22 19:14:36 +08:00
Will Miao
6d3feb4bef feat: Update styles for creator info and Civitai view in Lora modal; refactor button to div 2025-07-22 18:07:19 +08:00
Will Miao
29d2b5ee4b feat: Enhance creator info display and add Civitai view functionality in ModelModal 2025-07-22 17:43:33 +08:00
Will Miao
c82fabb67f feat: Refactor model type determination to use state for saving metadata and handling events 2025-07-22 16:44:21 +08:00
Will Miao
fcfc868e57 feat: Move LoRA related components to shared directory for consistency
- Added PresetTags.js to handle LoRA model preset parameter tags.
- Introduced RecipeTab.js for managing recipes associated with LoRA models.
- Created TriggerWords.js to manage trigger word functionality for LoRA models.
- Implemented utility functions in utils.js for general model modal operations.
2025-07-22 16:00:04 +08:00
Will Miao
67b403f8ca Update wiki images 2025-07-21 16:39:00 +08:00
Will Miao
de06c6b2f6 feat: Add download cancellation and tracking features in DownloadManager and API routes 2025-07-21 15:38:20 +08:00
Will Miao
fa444dfb8a Fix typo 2025-07-21 08:30:36 +08:00
Will Miao
124002a472 feat: Add JSON parsing for base_model_path_mappings and refactor path handling in DownloadManager 2025-07-21 07:37:34 +08:00
Will Miao
0c883433c1 feat: Implement download path template settings and base model path mappings in UI 2025-07-21 07:37:03 +08:00
Will Miao
bcf3b2cf55 feat: Add default root paths for LoRA and checkpoint if only one exists 2025-07-20 09:45:09 +08:00
Will Miao
357c4e9c08 refactor: Normalize and deduplicate checkpoint and unet paths in configuration 2025-07-19 23:06:43 +08:00
Will Miao
9edfc68e91 fix: Remove path_mappings.yaml from repository and update .gitignore 2025-07-19 10:09:02 +08:00
Will Miao
8c06cb3e80 chore: Bump version to 0.8.21 in pyproject.toml 2025-07-19 08:28:02 +08:00
Will Miao
144fa0a6d4 refactor: Remove redundant metadata collector initialization 2025-07-18 09:39:54 +08:00
Will Miao
25d5a1541e feat: Add pyyaml to requirements for YAML support 2025-07-17 15:09:29 +08:00
Will Miao
a579d36389 fix: Improve error message for example image import failure 2025-07-17 14:58:02 +08:00
Will Miao
d766dac341 feat: Enhance metadata collection by adding support for async execution hooks and improving error handling. See #291 #298 2025-07-17 14:45:56 +08:00
Will Miao
b15ef1bbc6 feat: Update metadata file name in MetadataManager to match actual file name. See #294 2025-07-17 06:30:41 +08:00
Will Miao
3e52e00597 feat: Add path mappings configuration file for customizable model download directories 2025-07-16 17:41:23 +08:00
Will Miao
f749dd0d52 feat: Add YAML configuration for path mappings to customize model download directories 2025-07-16 17:07:13 +08:00
Will Miao
48a8a42108 Update README.md 2025-07-16 10:33:18 +08:00
Will Miao
db7f57a5a4 feat: Refactor sampler extractors to reduce redundancy and improve maintainability. Add support for KSampler [pipe] from comfyui-impact-pack and comfyui-inspire-pack 2025-07-16 08:08:11 +08:00
Will Miao
556381b983 feat: Simplify error responses in handle_download_model with consistent JSON format 2025-07-14 17:07:52 +08:00
Will Miao
158d7d5898 Update wiki images 2025-07-12 20:24:34 +08:00
Will Miao
18844da95d chore: Update version to 0.8.20 2025-07-12 10:33:15 +08:00
Will Miao
7e0df4d718 feat: Add Civitai model tags for prioritized subfolder organization in download manager 2025-07-12 10:32:15 +08:00
Will Miao
0dbb76e8c8 feat: Add download progress endpoint and implement progress tracking in WebSocketManager 2025-07-12 10:11:16 +08:00
Will Miao
f73b3422a6 feat: Add GET endpoint for model download and handle parameters conversion 2025-07-12 09:17:36 +08:00
Will Miao
bd95e802ec refactor: Replace asynchronous service calls with synchronous counterparts in SaveImage and ServiceRegistry. Fixes #282 2025-07-11 22:48:39 +08:00
Will Miao
5de16a78c5 refactor: Replace asyncio.run with synchronous get_lora_info calls in LoraManagerLoader, LoraStacker, WanVideoLoraSelect, and ApiRoutes. See #282 2025-07-11 07:24:33 +08:00
Will Miao
6f8e09fcde chore: Update version to 0.8.20-beta in pyproject.toml 2025-07-10 18:48:56 +08:00
Will Miao
f54d480f03 refactor: Update section title and improve alignment in README for Browser Extension 2025-07-10 18:43:12 +08:00
Will Miao
e68b213fb3 feat: Add LM Civitai Extension details to README and update release notes for v0.8.20 2025-07-10 18:37:22 +08:00
Will Miao
132334d500 feat: Add new content indicators for Documentation tab and update links in modals 2025-07-10 17:39:59 +08:00
Will Miao
a6f04c6d7e refactor: Remove unused imports and dependencies from utils, recipe_routes, requirements, and pyproject files. See #278 2025-07-10 16:36:28 +08:00
Will Miao
854e8bf356 feat: Adjust CivitaiClient.get_model_version logic to handle API changes — querying by model ID no longer includes image generation metadata. Fixes #279 2025-07-10 15:29:34 +08:00
Will Miao
6ff883d2d3 fix: Update diffusers version requirement to >=0.33.1 in requirements.txt. See #278 2025-07-10 10:55:13 +08:00
Will Miao
849b97afba feat: Add CR_ApplyControlNetStack extractor and enhance prompt conditioning handling in metadata processing. Fixes #277 2025-07-10 09:26:53 +08:00
Will Miao
1bd2635864 feat: Add smZ_CLIPTextEncode extractor to NODE_EXTRACTORS. See #277 2025-07-09 22:56:56 +08:00
Will Miao
79ab0f7b6c refactor: Update folder loading to fetch dynamically from API in DownloadManager and MoveManager. Fixes #274 2025-07-09 20:29:49 +08:00
Will Miao
79011bd257 refactor: Update model_id and model_version_id types to integers and add validation in routes 2025-07-09 14:21:49 +08:00
Will Miao
c692713ffb refactor: Simplify model version existence checks and enhance version retrieval methods in scanners 2025-07-09 10:26:03 +08:00
pixelpaws
df9b554ce1 Merge pull request #267 from younyokel/patch-2
Update requirements.txt
2025-07-08 21:24:49 +08:00
Will Miao
277a8e4682 Add wiki images 2025-07-08 10:05:43 +08:00
Will Miao
acb52dba09 refactor: Remove redundant local file fallback and debug logs in showcase file handling 2025-07-07 16:34:19 +08:00
Will Miao
8f10765254 feat: Add health check route to MiscRoutes for server status monitoring 2025-07-06 21:40:47 +08:00
Will Miao
0653f59473 feat: Enhance relative path handling in download manager to include base model 2025-07-03 10:28:52 +08:00
Will Miao
7a4b5a4667 feat: Implement download progress WebSocket and enhance download manager with unique IDs 2025-07-02 23:48:35 +08:00
Will Miao
49c4a4068b feat: Add default checkpoint root setting with dynamic options in settings modal 2025-07-02 21:46:21 +08:00
Will Miao
40ad590046 refactor: Update checkpoint handling to use base_models_roots and streamline path management 2025-07-02 21:29:41 +08:00
Will Miao
30374ae3e6 feat: Add ServiceRegistry import to routes_common.py for improved service management 2025-07-02 19:24:04 +08:00
Will Miao
ab22d16bad feat: Rename download endpoint from /api/download-lora to /api/download-model and update related logic 2025-07-02 19:21:25 +08:00
Will Miao
971cd56a4a feat: Update WebSocket endpoint for checkpoint progress and adjust related routes 2025-07-02 18:38:02 +08:00
Will Miao
d7cb546c5f refactor: Simplify model download handling by consolidating download logic and updating parameter usage 2025-07-02 18:25:42 +08:00
Will Miao
9d8b7344cd feat: Enhance Civitai image metadata parser to prevent duplicate LoRAs 2025-07-02 16:50:19 +08:00
Will Miao
2d4f6ae7ce feat: Add route to check if a model exists in the library 2025-07-02 14:45:19 +08:00
Edward Johan
d9126807b0 Update requirements.txt 2025-07-01 00:13:29 +05:00
Will Miao
cad5fb3fba feat: Add mock module creation for py/nodes directory to prevent loading modules from the nodes directory 2025-06-30 20:19:37 +08:00
Will Miao
afe23ad6b7 fix: Update project description for clarity and engagement 2025-06-30 15:21:50 +08:00
Will Miao
fc4327087b Add WanVideo Lora Select node and related functionality. Fixes #266
- Implemented the WanVideo Lora Select node in Python with input handling for low memory loading and LORA syntax processing.
- Updated the JavaScript side to register the new node and manage its widget interactions.
- Enhanced constants files to include the new node type and its corresponding ID.
- Modified existing Lora Loader and Stacker references to accommodate the new node in various workflows and UI components.
- Added example workflow JSON for the new node to demonstrate its usage.
2025-06-30 15:10:34 +08:00
Will Miao
71762d788f Add Lora Loader node support for Nunchaku SVDQuant FLUX model architecture with template workflow. Fixes #255 2025-06-29 23:57:50 +08:00
Will Miao
6472e00fb0 fix: Update EXTRANETS_REGEX to allow for hyphens in hypernet identifiers. Fixes #264 2025-06-29 16:48:02 +08:00
pixelpaws
4043846767 Merge pull request #261 from Rauks/add-flux-kontext
feat: Add "Flux.1 Kontext" base model
2025-06-28 21:10:51 +08:00
Karl Woditsch
d3b2bc962c feat: Add "Flux.1 Kontext" base model 2025-06-28 15:01:26 +02:00
Will Miao
54f7b64821 Replace Chart.js CDN link with local path for statistics page. Fixes #260 2025-06-28 20:53:00 +08:00
Will Miao
82a2a6e669 chore: update version to 0.8.19 and add release notes for new features and enhancements 2025-06-28 08:04:16 +08:00
Will Miao
6376d60af5 Add temp debug console logging 2025-06-27 17:47:19 +08:00
Will Miao
b1e2e3831f fix: enhance model processing logic to skip already processed models only if their directories contain files. See #259 2025-06-27 13:09:19 +08:00
Will Miao
5de1c8aa82 feat: add node selector header with action mode indicator and instructions for improved user guidance 2025-06-27 12:39:20 +08:00
Will Miao
63dc5c2bdb fix: change overflow-y property to scroll for consistent vertical scrolling behavior 2025-06-27 11:44:43 +08:00
Will Miao
7f2d1670a0 feat: add startExpanded option to renderShowcaseContent for improved showcase interaction 2025-06-27 10:12:17 +08:00
Will Miao
53c8c337fc fix: remove unnecessary variable assignment for trigger words section in edit mode 2025-06-27 09:58:24 +08:00
Will Miao
5b4ec1b2a2 feat: implement disabled state for header search on statistics page with appropriate styling and functionality adjustments 2025-06-27 09:45:48 +08:00
Will Miao
64dd2ed141 feat: enhance node registration and management with support for multiple nodes and improved UI elements. Fixes #220 2025-06-26 23:00:55 +08:00
Will Miao
eb57e04e95 feat: implement thread-safe node registry and registration endpoints for Lora nodes 2025-06-26 18:31:14 +08:00
Will Miao
ae905c8630 fix: correct extension name format and update initialization method in usage stats 2025-06-26 16:57:26 +08:00
Will Miao
c157e794f0 feat: implement event delegation for checkpoint cards and enhance Civitai link handling 2025-06-26 11:42:43 +08:00
Will Miao
ed9bae6f6a feat: enhance recipe metadata handling with NSFW level updates and context menu actions. FIxes #247 2025-06-26 11:04:51 +08:00
Will Miao
9fe1ce19ad feat: add Patreon support section to the support modal with styling 2025-06-26 09:54:07 +08:00
Will Miao
6148236cbd fix: add missing patreon entry in FUNDING.yml 2025-06-26 08:23:12 +08:00
Will Miao
2471eb518a fix: correct key reference in process_trigger_words and update comment for widget values. Fixes #254 2025-06-25 20:57:12 +08:00
Will Miao
8931b41c76 feat: refactor API routes for renaming models and update related functions 2025-06-25 19:38:38 +08:00
Will Miao
7f523f167d fix: correct indentation for appending lora_entry in CivitaiApiMetadataParser. Fixes #253 2025-06-25 15:57:14 +08:00
Will Miao
446b6d6158 feat: sync saved example images path with backend on path update. Fixes #250 2025-06-25 15:34:25 +08:00
Will Miao
2ee057e19b feat: update metadata saving to ensure backup creation and support nested civitai structure 2025-06-25 11:50:10 +08:00
Will Miao
afc810f21f feat: prevent Ctrl+A behavior when search input is focused. See #251 2025-06-24 22:12:53 +08:00
pixelpaws
357052a903 Merge pull request #252 from willmiao/stats-page
Add statistics page with metrics, charts, and insights functionality
2025-06-24 21:37:06 +08:00
Will Miao
39d6d8d04a Add statistics page with metrics, charts, and insights functionality
- Implemented CSS styles for the statistics page layout and components.
- Developed JavaScript functionality for managing statistics, including data fetching, chart rendering, and tab navigation.
- Created HTML template for the statistics page, integrating dynamic content for metrics, charts, and insights.
- Added responsive design adjustments and loading states for better user experience.
2025-06-24 21:36:20 +08:00
Will Miao
888896c0c0 feat: add card info display setting with options for always visible or reveal on hover 2025-06-24 17:41:52 +08:00
Will Miao
ceee482ecc feat: refactor Lora handling by introducing chainCallback for improved node initialization and widget management. Fixes #176 2025-06-24 16:36:15 +08:00
Will Miao
d0ed1213d8 feat: enhance LoRA metadata handling by adding model IDs and updating recipe data structure. Fixes #246 2025-06-24 11:12:21 +08:00
Will Miao
f6ef428008 feat: update preview URL handling in RecipeRoutes and optimize recipe refresh logic in RecipeModal. Fixes #244 2025-06-23 15:29:22 +08:00
Will Miao
e726c4f442 feat: enhance metadata extraction for TSC samplers with vae_decode handling 2025-06-23 10:55:27 +08:00
Will Miao
402318e586 feat: enhance metadata processing and extraction for Efficient nodes with improved prompt handling and conditioning outputs. 2025-06-22 13:21:31 +08:00
Will Miao
b198cc2a6e feat: enhance metadata enrichment process to update file paths and preview URLs dynamically. See #113 2025-06-21 21:24:22 +08:00
Will Miao
c3dd4da11b feat: enhance theme toggle functionality with auto theme support and icon updates. Fix #243 2025-06-21 20:43:44 +08:00
Will Miao
ba2e42b06e feat: enhance LoraModal with notes hint and cleanup functionality on close 2025-06-21 20:04:57 +08:00
Will Miao
fa0902dc74 feat: add AdvancedCLIPTextEncode to NODE_EXTRACTORS for enhanced metadata extraction. See #234 2025-06-21 06:22:33 +08:00
Will Miao
8fcb6083dc feat: update release notes and version to 0.8.18 with new features and improvements 2025-06-20 18:25:15 +08:00
Will Miao
1ef88140e3 fix: adjust widget heights and padding for improved layout and text alignment 2025-06-20 17:21:31 +08:00
Will Miao
aa34c4c84c refactor: streamline prompt matching logic in MetadataProcessor 2025-06-20 17:00:23 +08:00
Will Miao
32d12bb334 feat: update API routes for version info and enhance version fetching functionality 2025-06-20 16:38:11 +08:00
Will Miao
1b2a02cb1a feat: add git information display in update modals and enhance version check functionality 2025-06-20 15:22:07 +08:00
Will Miao
2ff11a16c4 feat: implement DebugMetadata node with metadata display and update functionality 2025-06-20 14:17:39 +08:00
Will Miao
441af82dbd fix: update EXIF metadata extraction method for better compatibility with non-JPEG formats 2025-06-20 11:15:05 +08:00
Will Miao
e09c09af6f feat: support GIF format for preview images. Fixes #236 2025-06-20 10:51:52 +08:00
Will Miao
3721fe226f Remove unused code 2025-06-20 10:43:02 +08:00
Will Miao
8ace0e11cf Update find_preview_file to include example extension from Civitai Helper for A1111. Fixes #225 2025-06-20 10:41:42 +08:00
Will Miao
5e249b0b59 fix: Update from_civitai flag to True in metadata creation for checkpoints and LoraMetadata. Fixes #238 2025-06-20 05:48:28 +08:00
Will Miao
4889955ecf feat: Add conditioning matching to prompts and update metadata handling in node extractors. See #235 2025-06-20 00:04:02 +08:00
pixelpaws
d840fd53da Merge pull request #231 from PredatorIWD/fix-crash-on-symlinks
Don't crash completely if a symlink resolve fails
2025-06-19 18:34:03 +08:00
pixelpaws
a61819cdb3 Merge branch 'main' into fix-crash-on-symlinks 2025-06-19 18:33:40 +08:00
Will Miao
e986fbb5fb refactor: Streamline progress file handling and enhance metadata extraction for images 2025-06-19 18:12:16 +08:00
Will Miao
8f4d575ec8 refactor: Improve metadata handling and streamline example image loading in modals 2025-06-19 17:07:28 +08:00
Will Miao
605a06317b feat: Enhance media handling by adding NSFW level support and improving preview image management 2025-06-19 15:19:24 +08:00
Will Miao
a7304ccf47 feat: Add deepMerge method for improved object merging in VirtualScroller 2025-06-19 12:46:50 +08:00
Will Miao
374e2bd4b9 refactor: Add MediaRenderers, MediaUtils, MetadataPanel, and ShowcaseView components for enhanced media handling in showcase
- Implemented MediaRenderers.js to generate HTML for video and image wrappers, including NSFW handling and media controls.
- Created MediaUtils.js for utility functions to manage media loading, lazy loading, and metadata panel interactions.
- Developed MetadataPanel.js to generate metadata panels for media items, including prompts and generation parameters.
- Introduced ShowcaseView.js to render showcase content, manage media items, and handle file imports with drag-and-drop support.
2025-06-19 11:21:32 +08:00
Will Miao
09a3246ddb Add delete functionality for custom example images with API endpoint 2025-06-19 11:21:00 +08:00
Will Miao
a615603866 Prevent Ctrl+A behavior in modals by checking for open modals before handling the key event 2025-06-18 18:43:11 +08:00
Will Miao
1ca05808e1 Enhance preview image upload by deleting existing previews and updating UI state management 2025-06-18 18:37:13 +08:00
Will Miao
5febc2a805 Add update indicator and animation for updated cards in VirtualScroller 2025-06-18 17:30:49 +08:00
Will Miao
3c047bee58 Refactor example images handling by introducing migration logic, updating metadata structure, and enhancing image loading in the UI 2025-06-18 17:14:49 +08:00
Will Miao
022c6c157a Refactor example images code 2025-06-18 09:28:00 +08:00
Will Miao
fa587d5678 Refactor modal components by removing unused imports and commenting out cache management section in modals.html 2025-06-17 21:06:01 +08:00
Will Miao
afa5a42f5a Refactor metadata handling by introducing MetadataManager for centralized operations and improving error handling 2025-06-17 21:01:48 +08:00
Will Miao
71df8ba3e2 Refactor metadata handling by removing direct UI updates from saveModelMetadata and related functions 2025-06-17 20:25:39 +08:00
Will Miao
8764998e8c Update example images optimization message to clarify metadata preservation 2025-06-16 23:26:55 +08:00
Will Miao
2cb4f3aac8 Add example images access modal and API integration for checking image availability. Fixes #183 and #209 2025-06-16 21:33:49 +08:00
Will Miao
1ccaf33aac Refactor example images management by removing centralized examples settings and migration functionality 2025-06-16 18:29:37 +08:00
Will Miao
cb0a8e0413 Implement example image import functionality with UI and backend integration 2025-06-16 18:14:53 +08:00
Luka Celebic
8674168df4 Don't crash completely if a symlink resolve fails 2025-06-15 20:00:21 +02:00
Will Miao
2221653801 Add bulk selection functionality and limit thumbnail display in BulkManager. See #229 2025-06-15 22:21:21 +08:00
Will Miao
78bcdcef5d Enhance CivitAI metadata fetch handling and update virtual scroller item management. See #227 2025-06-15 08:34:22 +08:00
Will Miao
672fbe2ac0 Remove unused and outdated code to improve clarity 2025-06-15 06:18:47 +08:00
Will Miao
56a5970b44 Adjust NSFW warning styles for medium and compact density modes 2025-06-14 19:49:54 +08:00
Will Miao
a66cef7cfe Increase max-height for model names in medium and compact density modes to prevent text cutoff 2025-06-14 19:30:46 +08:00
Will Miao
c0b1c2e099 Remove commented-out Civitai context menu item from checkpoints and context menu templates 2025-06-14 18:13:37 +08:00
Will Miao
9e553bb87b Refactor card update functions to unify model and Lora card handling; remove unused metadata path update logic. See #228 2025-06-14 09:39:59 +08:00
Will Miao
f966514bc7 Add tag editing functionality and update compact tags rendering 2025-06-13 20:42:44 +08:00
Will Miao
dc0a49f96d Refactor trigger words and metadata editing styles
- Removed outdated styles from trigger words CSS and consolidated into a new shared edit-metadata CSS file.
- Updated JavaScript components for trigger words and model tags to utilize the new metadata styles.
- Adjusted class names and structure in the HTML to align with the new styling conventions.
- Enhanced the UI for editing tags and trigger words, ensuring consistency across components.
2025-06-13 20:19:10 +08:00
Will Miao
65c783c024 Refactor lora-modal.css into modular components 2025-06-13 15:10:26 +08:00
Will Miao
6395836fbb Add styles for empty tags and update tag rendering logic to always display container 2025-06-13 07:11:07 +08:00
Will Miao
a7207084ef Remove unused monitor cleanup logic from LoraManager and DownloadManager 2025-06-13 05:52:52 +08:00
Will Miao
27ef1f1e71 Refactor tag editing setup: improve event handler management for edit and save buttons 2025-06-13 05:46:53 +08:00
Will Miao
68fdb14cd6 Remove unused lora monitor retrieval and ignore path logic from ApiRoutes, DownloadManager, and ModelScanner. Fixes #226 2025-06-13 05:46:22 +08:00
Will Miao
c2af282a85 Add tag editing functionality: implement UI for editing model tags, including save and delete options, and integrate with existing modal structure. 2025-06-12 21:00:17 +08:00
Will Miao
92d48335cb Add endpoints and functionality for verifying duplicates in Lora and Checkpoints
- Implemented `/api/loras/verify-duplicates` and `/api/checkpoints/verify-duplicates` endpoints.
- Added `handle_verify_duplicates` method in `ModelRouteUtils` to process duplicate verification requests.
- Enhanced `ModelDuplicatesManager` to manage verification state and display results.
- Updated CSS for verification badges and hash mismatch indicators. Fixes #221
2025-06-12 12:06:01 +08:00
Will Miao
78cac2edc2 Add DoRA type support. move VALID_LORA_TYPES to utils.constants and update imports in recipe parsers and API routes. 2025-06-12 09:25:00 +08:00
Will Miao
26d105c439 Enhance Civitai model handling: add get_model_version method for detailed metadata retrieval, update routes to utilize new method, and improve URL handling in context menu for model re-linking. 2025-06-11 22:06:16 +08:00
Will Miao
7fec107b98 Refactor context menus to use ModelContextMenuMixin for shared functionality
- Introduced ModelContextMenuMixin to encapsulate shared methods for Lora and Checkpoint context menus.
- Updated CheckpointContextMenu to utilize the mixin for common actions and NSFW level handling.
- Simplified LoraContextMenu by integrating the mixin, removing redundant methods.
- Removed duplicated NSFW handling logic and centralized it in the mixin.
- Adjusted import/export statements to reflect the new structure and ensure proper functionality.
2025-06-11 20:52:45 +08:00
Will Miao
eb01ad3af9 Refactor model response inclusion to only include groups with multiple models; update model removal logic to accept hash value. See #221 2025-06-11 19:52:44 +08:00
Will Miao
e0d9880b32 Remove duplicate hash entries with a single path in get_duplicate_hashes method 2025-06-11 17:33:13 +08:00
Will Miao
e81e96f0ab Refactor file monitoring and model scanning; remove unused monitors and streamline model file deletion process. 2025-06-11 17:02:10 +08:00
Will Miao
06d5bd259c Refactor model file processing in ModelScanner to determine root paths and enhance error logging for missing roots. 2025-06-11 15:53:35 +08:00
Will Miao
14238b8d62 Update preview URL handling in load_metadata function to reflect model location changes. See #113 2025-06-11 15:43:12 +08:00
Will Miao
3b51886927 Add cache file control to ModelScanner; implement flags to enable/disable cache usage and clear cache files accordingly. See #222 2025-06-11 09:17:10 +08:00
Will Miao
a295ff2e06 Refactor video embed implementation to enhance privacy and user experience; replace iframe with a privacy-friendly video container and add external link buttons for YouTube access. 2025-06-10 06:44:08 +08:00
Will Miao
18cdaabf5e Update release notes and version to v0.8.17, adding new features including duplicate model detection, enhanced URL recipe imports, and improved trigger word control. 2025-06-09 19:07:53 +08:00
Will Miao
787e37b7c6 Add CivitAI re-linking functionality and related UI components. Fixes #216
- Implemented new API endpoints for re-linking models to CivitAI.
- Added context menu options for re-linking in both Lora and Checkpoint context menus.
- Created a modal for user confirmation and input for CivitAI model URL.
- Updated styles for the new modal and context menu items.
- Enhanced error handling and user feedback during the re-linking process.
2025-06-09 17:23:03 +08:00
Will Miao
4e5c8b2dd0 Add help modal functionality and update related UI components 2025-06-09 14:55:18 +08:00
Will Miao
d8ddacde38 Remove 'folder' field from model metadata before saving to file. See #211 2025-06-09 11:26:24 +08:00
Will Miao
bb1e42f0d3 Add restart required icon to example images download location label. See #212 2025-06-08 20:43:10 +08:00
pixelpaws
923669c495 Merge pull request #213 from willmiao/migrate-images
Migrate images
2025-06-08 20:11:37 +08:00
Will Miao
7a4139544c Add method to update model metadata from local example images. Fixes #211 2025-06-08 20:10:36 +08:00
Will Miao
4d6ea0236b Add centralized example images setting and update related UI components 2025-06-08 17:38:46 +08:00
Will Miao
e872a06f22 Refactor MiscRoutes and move example images related api to ExampleImagesRoutes 2025-06-08 14:40:30 +08:00
Will Miao
647bda2160 Add API endpoint and frontend integration for fetching example image files 2025-06-07 22:31:57 +08:00
Will Miao
c1e93d23f3 Merge branch 'migrate-images' of https://github.com/willmiao/ComfyUI-Lora-Manager into migrate-images 2025-06-07 11:32:55 +08:00
Will Miao
c96550cc68 Enhance migration and download processes: add backend path update and prevent duplicate completion toasts 2025-06-07 11:29:53 +08:00
Will Miao
b1015ecdc5 Add migration functionality for example images: implement API endpoint and UI controls 2025-06-07 11:27:25 +08:00
Will Miao
f1b928a037 Add migration functionality for example images: implement API endpoint and UI controls 2025-06-07 09:34:07 +08:00
Will Miao
16c312c90b Fix version description not showing. Fixes #210 2025-06-07 01:29:38 +08:00
Will Miao
110ffd0118 Refactor modal close behavior: ensure consistent handling of closeOnOutsideClick option across multiple modals. 2025-06-06 10:32:18 +08:00
Will Miao
35ad872419 Enhance duplicates management: add help tooltip for duplicate groups and improve responsive styling for banners and groups. 2025-06-05 15:06:53 +08:00
Will Miao
9b943cf2b8 Update custom node icon 2025-06-05 06:48:48 +08:00
Will Miao
9d1b357e64 Enhance cache validation logic: add logging for version and model type mismatches, and relax directory structure checks to improve cache validity. 2025-06-04 20:47:14 +08:00
Will Miao
9fc2fb4d17 Enhance model caching and exclusion functionality: update cache version, add excluded models to cache data, and ensure cache is saved to disk after model exclusion and deletion. 2025-06-04 18:38:45 +08:00
Will Miao
641fa8a3d9 Enhance duplicates mode functionality: add toggle for entering/exiting mode, improve exit button styling, and manage control button states during duplicates mode. 2025-06-04 16:46:57 +08:00
Will Miao
add9269706 Enhance duplicate mode exit logic: hide duplicates banner, clear model grid, and re-enable virtual scrolling. Improve spacer element handling in VirtualScroller by recreating it if not found in the DOM. 2025-06-04 16:05:57 +08:00
Will Miao
1a01c4a344 Refactor trigger words UI handling: improve event listener management, restore original words on cancel, and enhance dropdown update logic. See #147 2025-06-04 15:02:13 +08:00
Will Miao
b4e7feed06 Enhance trained words extraction and display: include class tokens in response and update UI accordingly. See #147 2025-06-04 12:04:38 +08:00
Will Miao
4b96c650eb Enhance example image handling: improve filename extraction and fallback for local images 2025-06-04 11:30:56 +08:00
Will Miao
107aef3785 Enhance SaveImage and TriggerWordToggle: add tooltips for parameters to improve user guidance 2025-06-03 19:40:01 +08:00
Will Miao
b49807824f Fix optimizeExampleImages setting in SettingsManager 2025-06-03 18:10:43 +08:00
Will Miao
e5ef2ef8b5 Add default_active parameter to TriggerWordToggle for controlling default state 2025-06-03 17:45:52 +08:00
Will Miao
88779ed56c Enhance Lora Manager widget: add configurable window size for Shift+Click behavior 2025-06-03 16:25:31 +08:00
Will Miao
8b59fb6adc Refactor ShowcaseView and uiHelpers for improved image/video handling
- Moved getLocalExampleImageUrl function to uiHelpers.js for better modularity.
- Updated ShowcaseView.js to utilize the new structure for local and fallback URLs.
- Enhanced lazy loading functions to support both primary and fallback URLs for images and videos.
- Simplified metadata panel generation in ShowcaseView.js.
- Improved showcase toggle functionality and added initialization for lazy loading and metadata handlers.
2025-06-03 16:06:54 +08:00
Will Miao
7945647b0b Refactor core application and recipe manager: remove lazy loading functionality and clean up imports in uiHelpers. 2025-06-03 15:40:51 +08:00
Will Miao
2d39b84806 Add CivitaiApiMetadataParser and improve recipe parsing logic for Civitai images. Also fixes #197
Additional info: Now prioritizes using the Civitai Images API to fetch image and generation metadata. Even NSFW images can now be imported via URL.
2025-06-03 14:58:43 +08:00
Will Miao
e151a19fcf Implement bulk operations for LoRAs: add send to workflow and bulk delete functionality with modal confirmation. 2025-06-03 07:44:52 +08:00
Will Miao
99d2ba26b9 Add API endpoint for fetching trained words and implement dropdown suggestions in the trigger words editor. See #147 2025-06-02 17:04:33 +08:00
Will Miao
396924f4cc Add badge for duplicate count and update logic in ModelDuplicatesManager and PageControls 2025-06-02 09:42:28 +08:00
Will Miao
7545312229 Add bulk delete endpoint for checkpoints and enhance ModelDuplicatesManager for better handling of model types 2025-06-02 08:54:31 +08:00
Will Miao
26f9779fbf Add bulk delete functionality for loras and implement model duplicates management. See #198
- Introduced a new API endpoint for bulk deleting loras.
- Added ModelDuplicatesManager to handle duplicate models for loras and checkpoints.
- Implemented UI components for displaying duplicates and managing selections.
- Enhanced controls with a button for finding duplicates.
- Updated templates to include a duplicates banner and associated actions.
2025-06-02 08:08:45 +08:00
Will Miao
0bd62eef3a Add endpoints for finding duplicate loras and filename conflicts; implement tracking for duplicates in ModelHashIndex and update ModelScanner to handle new data structures. 2025-05-31 20:50:51 +08:00
Will Miao
e06d15f508 Remove LoraHashIndex class and related functionality to streamline codebase. 2025-05-31 20:25:12 +08:00
Will Miao
aa1ee96bc9 Add versioning and history tracking to usage statistics. Implement backup and conversion for old stats format, enhancing data structure for checkpoints and loras. 2025-05-31 16:38:18 +08:00
Will Miao
355c73512d Enhance modal close behavior by tracking mouse events on the background. Implement logic to close modals only if mouseup occurs on the background after mousedown, improving user experience. 2025-05-31 08:53:20 +08:00
Will Miao
0daf9d92ff Update version to 0.8.16 and enhance release notes with new features, improvements, and bug fixes. 2025-05-30 21:04:24 +08:00
Will Miao
37de26ce25 Enhance Lora code update handling for browser and desktop modes. Implement broadcast support for Lora Loader nodes and improve node ID management in the workflow. 2025-05-30 20:12:38 +08:00
Will Miao
0eaef7e7a0 Refactor extension name for consistency in usage statistics tracking 2025-05-30 17:30:29 +08:00
Will Miao
8063cee3cd Add rename functionality for checkpoint and LoRA files with loading indicators 2025-05-30 16:38:18 +08:00
Will Miao
cbb25b4ac0 Enhance model metadata saving functionality with loading indicators and improved validation. Refactor editing logic for better user experience in both checkpoint and LoRA modals. Fixes #200 2025-05-30 16:30:01 +08:00
Will Miao
c62206a157 Add preprocessing for MessagePack serialization to handle large integers. See #201 2025-05-30 10:55:48 +08:00
Will Miao
09832141d0 Add functionality to open example images folder for models 2025-05-30 09:42:36 +08:00
Will Miao
bf8e121a10 Add functionality to copy LoRA syntax and update event handling for copy action 2025-05-30 09:02:17 +08:00
Will Miao
68568073ec Refactor model caching logic to streamline adding models and ensure disk persistence 2025-05-30 07:34:39 +08:00
Will Miao
ec36524c35 Add Civitai image URL optimization and simplify image processing logic 2025-05-29 22:20:16 +08:00
Will Miao
67acd9fd2c Relax cache validation by removing strict modification time checks, allowing users to refresh the cache as needed. 2025-05-29 20:58:06 +08:00
Will Miao
f7be5c8d25 Change log level to info for cache save operation and ensure cache is saved to disk after updating preview URL 2025-05-29 20:09:58 +08:00
Will Miao
ceacac75e0 Increase minimum width of dropdown menu for improved usability 2025-05-29 15:55:14 +08:00
Will Miao
bae66f94e8 Add full rebuild option to model refresh functionality and enhance dropdown controls 2025-05-29 15:51:45 +08:00
Will Miao
ddf132bd78 Add cache management feature: implement clear cache API and modal confirmation 2025-05-29 14:36:13 +08:00
Will Miao
afb012029f Enhance get_cached_data method: improve cache rebuilding logic and ensure cache is saved after initialization 2025-05-29 08:50:17 +08:00
Will Miao
651e14c8c3 Enhance get_cached_data method: add rebuild_cache option for improved cache management 2025-05-29 08:36:18 +08:00
Will Miao
e7c626eb5f Add MessagePack support for efficient cache serialization and update dependencies 2025-05-28 22:30:06 +08:00
pixelpaws
a0b0d40a19 Update README.md 2025-05-27 22:28:26 +08:00
Will Miao
42e3ab9e27 Update tutorial links in README: replace outdated video links with the latest tutorial 2025-05-27 19:24:22 +08:00
Will Miao
6e5f333364 Enhance model file moving logic: support moving associated files and handle metadata paths 2025-05-27 05:41:39 +08:00
Will Miao
f33a9abe60 Limit Lora hash display to first 10 characters and improve WebP metadata handling 2025-05-22 16:29:12 +08:00
Will Miao
7f1bbdd615 Remove debug print statement for primary sampler ID in MetadataProcessor 2025-05-22 16:01:55 +08:00
Will Miao
d3bf8eaceb Add container padding properties to VirtualScroller and adjust card padding 2025-05-22 15:23:32 +08:00
Will Miao
b9c9d602de Enhance download modals: auto-focus on URL input and auto-select version if only one available 2025-05-22 11:07:52 +08:00
Will Miao
b25fbd6e24 Refactor modal styles: remove model name field and adjust margin for modal content header 2025-05-22 10:02:13 +08:00
Will Miao
6052608a4e Update version to 0.8.15-bugfix in pyproject.toml 2025-05-22 04:42:12 +08:00
Will Miao
a073b82751 Enhance WebP image saving: add EXIF data and workflow metadata support. Fixes #193 2025-05-21 19:17:12 +08:00
Will Miao
8250acdfb5 Add creator information display to Lora and Checkpoint modals. #186 2025-05-21 15:31:23 +08:00
Will Miao
8e1f73a34e Refactor display density settings: replace compact mode with display density option and update related UI components 2025-05-20 19:35:41 +08:00
Will Miao
50704bc882 Enhance error handling and input validation in fetch_and_update_model method 2025-05-20 13:57:22 +08:00
Will Miao
35d34e3513 Revert db0b49c427 Refactor load_metadata to use save_metadata for updating metadata files 2025-05-19 21:46:01 +08:00
Will Miao
ea834f3de6 Revert "Enhance metadata processing in ModelScanner: prevent intermediate writes, restore missing civitai data, and ensure base_model consistency. #185"
This reverts commit 99b36442bb.
2025-05-19 21:39:31 +08:00
Will Miao
11aedde72f Fix save_metadata call to await asynchronous execution in load_metadata function. Fixes #192 2025-05-19 15:01:56 +08:00
Will Miao
488654abc8 Improve card layout responsiveness and scrolling behavior 2025-05-18 07:49:39 +08:00
Will Miao
da1be0dc65 Merge branch 'main' of https://github.com/willmiao/ComfyUI-Lora-Manager 2025-05-17 15:40:23 +08:00
Will Miao
d0c728a339 Enhance node tracing logic and improve prompt handling in metadata processing. See #189 2025-05-17 15:40:05 +08:00
pixelpaws
66c66c4d9b Update README.md 2025-05-16 17:08:23 +08:00
Will Miao
4882721387 Update version to 0.8.15 and add release notes for enhanced features and improvements 2025-05-16 16:13:37 +08:00
Will Miao
06a8850c0c Add more wiki images 2025-05-16 15:54:52 +08:00
Will Miao
370aa06c67 Refactor duplicates banner styles for improved layout and responsiveness 2025-05-16 15:47:08 +08:00
Will Miao
c9fa0564e7 Update images 2025-05-16 11:36:37 +08:00
Will Miao
2ba7a0ceba Add keyboard navigation support and related styles for enhanced user experience 2025-05-15 20:17:57 +08:00
Will Miao
276aedfbb9 Set 'from_civitai' flag to True when updating local metadata with CivitAI data 2025-05-15 16:50:32 +08:00
Will Miao
c193c75674 Fix misleading error message for invalid civitai api key or early access deny 2025-05-15 13:46:46 +08:00
Will Miao
a562ba3746 Fix TriggerWord Toggle not updating when all LoRAs are disabled 2025-05-15 10:30:46 +08:00
Will Miao
2fedd572ff Add header drag functionality for proportional strength adjustment of LoRAs 2025-05-15 10:12:46 +08:00
Will Miao
db0b49c427 Refactor load_metadata to use save_metadata for updating metadata files 2025-05-15 09:49:30 +08:00
Will Miao
03a6f8111c Add functionality to copy and send LoRA/Recipe syntax to workflow
- Implemented copy functionality for LoRA and Recipe syntax in context menus.
- Added options to send LoRA and Recipe to workflow in both append and replace modes.
- Updated HTML templates to include new context menu items for sending actions.
2025-05-15 07:01:50 +08:00
Will Miao
925ad7b3e0 Add user-select: none to prevent text selection on cards and control elements 2025-05-15 05:36:56 +08:00
Will Miao
bf793d5b8b Refactor Lora and Recipe card event handling: replace copy functionality with direct send to ComfyUI workflow, update UI elements, and enhance sendLoraToWorkflow to support recipe syntax. 2025-05-14 23:51:00 +08:00
Will Miao
64a906ca5e Add Lora syntax send to comfyui functionality: implement API endpoint and frontend integration for sending and updating LoRA codes in ComfyUI nodes. 2025-05-14 21:09:36 +08:00
Will Miao
99b36442bb Enhance metadata processing in ModelScanner: prevent intermediate writes, restore missing civitai data, and ensure base_model consistency. #185 2025-05-14 19:16:58 +08:00
Will Miao
3c5164d510 Update screenshot 2025-05-13 22:56:51 +08:00
Will Miao
ec4b5a4d45 Update release notes and version to v0.8.14: add virtualized scrolling, compact display mode, and enhanced LoRA node functionality. 2025-05-13 22:50:32 +08:00
Will Miao
78e1901779 Add compact mode settings and styles for improved layout control. Fixes #33 2025-05-13 21:40:37 +08:00
Will Miao
cb539314de Ensure full LoRA node chain is considered when updating TriggerWord Toggle nodes 2025-05-13 20:33:52 +08:00
Will Miao
c7627fe0de Remove no longer needed ref files. 2025-05-13 17:57:59 +08:00
Will Miao
84bfad7ce5 Enhance model deletion handling in UI: integrate virtual scroller updates and remove legacy UI card removal logic. 2025-05-13 17:50:28 +08:00
Will Miao
3e06938b05 Add enableDataWindowing option to VirtualScroller for improved control over data fetching. (Disable data windowing for now) 2025-05-13 17:13:17 +08:00
Will Miao
4f712fec14 Reduce default delay in model processing from 0.2 to 0.1 seconds for improved responsiveness. 2025-05-13 15:30:09 +08:00
Will Miao
c5c9659c76 Update refreshModels to pass folder update flag to resetAndReloadFunction 2025-05-13 15:25:40 +08:00
Will Miao
d6e175c1f1 Add API endpoints for retrieving LoRA notes and trigger words; enhance context menu with copy options. Supports #177 2025-05-13 15:14:25 +08:00
Will Miao
88088e1071 Restructure the code of loras_widget into smaller, more manageable modules. 2025-05-13 14:42:28 +08:00
Will Miao
958ddbca86 Fix workaround for saved value retrieval in Loras widget to address custom nodes issue. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/176 2025-05-13 12:27:18 +08:00
Will Miao
6670fd28f4 Add sync functionality for clipStrength when collapsed in Loras widget. https://github.com/willmiao/ComfyUI-Lora-Manager/issues/176 2025-05-13 11:45:13 +08:00
pixelpaws
1e59c31de3 Merge pull request #184 from willmiao/vscroll
Add virtual scroll
2025-05-12 22:27:40 +08:00
Will Miao
c966dbbbbc Enhance DuplicatesManager and VirtualScroller to manage virtual scrolling state and improve rendering logic 2025-05-12 21:31:03 +08:00
Will Miao
af8f5ba04e Implement client-side placeholder handling for empty recipe grid and remove server-side conditional rendering 2025-05-12 21:20:28 +08:00
Will Miao
b741ed0b3b Refactor recipe and checkpoint management to implement virtual scrolling and improve state handling 2025-05-12 20:07:47 +08:00
Will Miao
01ba3c14f8 Implement virtual scrolling for model loading and checkpoint management 2025-05-12 17:47:57 +08:00
Will Miao
d13b1a83ad checkpoint 2025-05-12 16:44:45 +08:00
Will Miao
303477db70 update 2025-05-12 14:50:10 +08:00
Will Miao
311e89e9e7 checkpoint 2025-05-12 13:59:11 +08:00
Will Miao
8546cfe714 checkpoint 2025-05-12 10:25:58 +08:00
Will Miao
e6f4d84b9a Merge branch 'main' of https://github.com/willmiao/ComfyUI-Lora-Manager 2025-05-11 18:50:53 +08:00
Will Miao
ce7e422169 Revert "refactor: streamline LoraCard event handling and implement virtual scrolling for improved performance"
This reverts commit 5dd8d905fa.
2025-05-11 18:50:19 +08:00
pixelpaws
e5aec80984 Merge pull request #179 from jakerdy/patch-1
[Fix] `/api/chekcpoints/info/{name}` change misspelled method call
2025-05-11 17:10:40 +08:00
Jak Erdy
6d97817390 [Fix] /api/chekcpoints/info/{name} change misspelled method call
If you call:
`http://127.0.0.1:8188/api/checkpoints/info/some_name`
You will get error, that there is no method `get_checkpoint_info_by_name` in `scanner`.
Lookslike it wasn't fixed after refactoring or something. Now it works as expected.
2025-05-10 17:38:10 +07:00
Will Miao
d516f22159 Merge branch 'main' of https://github.com/willmiao/ComfyUI-Lora-Manager 2025-05-10 07:34:06 +08:00
pixelpaws
e918c18ca2 Create FUNDING.yml 2025-05-09 20:17:35 +08:00
Will Miao
5dd8d905fa refactor: streamline LoraCard event handling and implement virtual scrolling for improved performance 2025-05-09 16:33:34 +08:00
Will Miao
1121d1ee6c Revert "update"
This reverts commit 4793f096af.
2025-05-09 16:14:10 +08:00
Will Miao
4793f096af update 2025-05-09 15:42:56 +08:00
Will Miao
7b5b4ce082 refactor: enhance CFGGuider handling and add CFGGuiderExtractor for improved metadata extraction. Fixes https://github.com/willmiao/ComfyUI-Lora-Manager/issues/172 2025-05-09 13:50:22 +08:00
258 changed files with 35913 additions and 17694 deletions

5
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1,5 @@
# These are supported funding model platforms
patreon: PixelPawsAI
ko_fi: pixelpawsai
custom: ['paypal.me/pixelpawsai']

2
.gitignore vendored
View File

@@ -1,5 +1,7 @@
__pycache__/
settings.json
path_mappings.yaml
output/*
py/run_test.py
.vscode/
cache/

180
README.md
View File

@@ -13,101 +13,104 @@ A comprehensive toolset that streamlines organizing, downloading, and applying L
## 📺 Tutorial: One-Click LoRA Integration
Watch this quick tutorial to learn how to use the new one-click LoRA integration feature:
[![One-Click LoRA Integration Tutorial](https://img.youtube.com/vi/qS95OjX3e70/0.jpg)](https://youtu.be/qS95OjX3e70)
[![LoRA Manager v0.8.10 - Checkpoint Management, Standalone Mode, and New Features!](https://img.youtube.com/vi/VKvTlCB78h4/0.jpg)](https://youtu.be/VKvTlCB78h4)
[![One-Click LoRA Integration Tutorial](https://github.com/willmiao/ComfyUI-Lora-Manager/blob/main/static/images/video-thumbnails/getting-started.jpg)](https://youtu.be/hvKw31YpE-U)
## 🌐 Browser Extension
Enhance your Civitai browsing experience with our companion browser extension! See which models you already have, download new ones with a single click, and manage your downloads efficiently.
![LM Civitai Extension Preview](https://github.com/willmiao/ComfyUI-Lora-Manager/blob/main/wiki-images/civitai-models-page.png)
<div>
<a href="https://chromewebstore.google.com/detail/lm-civitai-extension/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb" style="display: inline-block; background-color: #4285F4; color: white; padding: 8px 16px; text-decoration: none; border-radius: 4px; font-weight: bold; margin: 10px 0;">
<img src="https://www.google.com/chrome/static/images/chrome-logo.svg" width="20" style="vertical-align: middle; margin-right: 8px;"> Get Extension from Chrome Web Store
</a>
</div>
<div id="firefox-install" class="install-ok"><a href="https://github.com/willmiao/lm-civitai-extension-firefox/releases/latest/download/extension.xpi">📦 Install Firefox Extension (reviewed and verified by Mozilla)</a></div>
📚 [Learn More: Complete Tutorial](https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/LoRA-Manager-Civitai-Extension-(Chrome-Extension))
---
## Release Notes
### v0.8.13
* **Enhanced Recipe Management** - Added "Find duplicates" feature to identify and batch delete duplicate recipes with duplicate detection notifications during imports
* **Improved Source Tracking** - Source URLs are now saved with recipes imported via URL, allowing users to view original content with one click or manually edit links
* **Advanced LoRA Control** - Double-click LoRAs in Loader/Stacker nodes to access expanded CLIP strength controls for more precise adjustments of model and CLIP strength separately
* **Lycoris Model Support** - Added compatibility with Lycoris models for expanded creative options
* **Bug Fixes & UX Improvements** - Resolved various issues and enhanced overall user experience with numerous optimizations
### v0.8.29
* **Enhanced Recipe Imports** - Improved recipe importing with new target folder selection, featuring path input autocomplete and interactive folder tree navigation. Added a "Use Default Path" option when downloading missing LoRAs.
* **WanVideo Lora Select Node Update** - Updated the WanVideo Lora Select node with a 'merge_loras' option to match the counterpart node in the WanVideoWrapper node package.
* **Autocomplete Conflict Resolution** - Resolved an autocomplete feature conflict in LoRA nodes with pysssss autocomplete.
* **Improved Download Functionality** - Enhanced download functionality with resumable downloads and improved error handling.
* **Bug Fixes** - Addressed several bugs for improved stability and performance.
### v0.8.12
* **Enhanced Model Discovery** - Added alphabetical navigation bar to LoRAs page for faster browsing through large collections
* **Optimized Example Images** - Improved download logic to automatically refresh stale metadata before fetching example images
* **Model Exclusion System** - New right-click option to exclude specific LoRAs or checkpoints from management
* **Improved Showcase Experience** - Enhanced interaction in LoRA and checkpoint showcase areas for better usability
### v0.8.28
* **Autocomplete for Node Inputs** - Instantly find and add LoRAs by filename directly in Lora Loader, Lora Stacker, and WanVideo Lora Select nodes. Autocomplete suggestions include preview tooltips and preset weights, allowing you to quickly select LoRAs without opening the LoRA Manager UI.
* **Duplicate Notification Control** - Added a switch to duplicates mode, enabling users to turn off duplicate model notifications for a more streamlined experience.
* **Download Example Images from Context Menu** - Introduced a new context menu option to download example images for individual models.
### v0.8.11
* **Offline Image Support** - Added functionality to download and save all model example images locally, ensuring access even when offline or if images are removed from CivitAI or the site is down
* **Resilient Download System** - Implemented pause/resume capability with checkpoint recovery that persists through restarts or unexpected exits
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
### v0.8.27
* **User Experience Enhancements** - Improved the model download target folder selection with path input autocomplete and interactive folder tree navigation, making it easier and faster to choose where models are saved.
* **Default Path Option for Downloads** - Added a "Use Default Path" option when downloading models. When enabled, models are automatically organized and stored according to your configured path template settings.
* **Advanced Download Path Templates** - Expanded path template settings, allowing users to set individual templates for LoRA, checkpoint, and embedding models for greater flexibility. Introduced the `{author}` placeholder, enabling automatic organization of model files by creator name.
* **Bug Fixes & Stability Improvements** - Addressed various bugs and improved overall stability for a smoother experience.
### v0.8.10
* **Standalone Mode** - Run LoRA Manager independently from ComfyUI for a lightweight experience that works even with other stable diffusion interfaces
* **Portable Edition** - New one-click portable version for easy startup and updates in standalone mode
* **Enhanced Metadata Collection** - Added support for SamplerCustomAdvanced node in the metadata collector module
* **Improved UI Organization** - Optimized Lora Loader node height to display up to 5 LoRAs at once with scrolling capability for larger collections
### v0.8.26
* **Creator Search Option** - Added ability to search models by creator name, making it easier to find models from specific authors.
* **Enhanced Node Usability** - Improved user experience for Lora Loader, Lora Stacker, and WanVideo Lora Select nodes by fixing the maximum height of the text input area. Users can now freely and conveniently adjust the LoRA region within these nodes.
* **Compatibility Fixes** - Resolved compatibility issues with ComfyUI and certain custom nodes, including ComfyUI-Custom-Scripts, ensuring smoother integration and operation.
### v0.8.9
* **Favorites System** - New functionality to bookmark your favorite LoRAs and checkpoints for quick access and better organization
* **Enhanced UI Controls** - Increased model card button sizes for improved usability and easier interaction
* **Smoother Page Transitions** - Optimized interface switching between pages, eliminating flash issues particularly noticeable in dark theme
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
### v0.8.25
* **LoRA List Reordering**
- Drag & Drop: Easily rearrange LoRA entries using the drag handle.
- Keyboard Shortcuts:
- Arrow keys: Navigate between LoRAs
- Ctrl/Cmd + Arrow: Move selected LoRA up/down
- Ctrl/Cmd + Home/End: Move selected LoRA to top/bottom
- Delete/Backspace: Remove selected LoRA
- Context Menu: Right-click for quick actions like Move Up, Move Down, Move to Top, Move to Bottom.
* **Bulk Operations for Checkpoints & Embeddings**
- Bulk Mode: Select multiple checkpoints or embeddings for batch actions.
- Bulk Refresh: Update Civitai metadata for selected models.
- Bulk Delete: Remove multiple models at once.
- Bulk Move (Embeddings): Move selected embeddings to a different folder.
* **New Setting: Auto Download Example Images**
- Automatically fetch example images for models missing previews (requires download location to be set). Enabled by default.
* **General Improvements**
- Various user experience enhancements and stability fixes.
### v0.8.8
* **Real-time TriggerWord Updates** - Enhanced TriggerWord Toggle node to instantly update when connected Lora Loader or Lora Stacker nodes change, without requiring workflow execution
* **Optimized Metadata Recovery** - Improved utilization of existing .civitai.info files for faster initialization and preservation of metadata from models deleted from CivitAI
* **Migration Acceleration** - Further speed improvements for users transitioning from A1111/Forge environments
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
### v0.8.22
* **Embeddings Management** - Added Embeddings page for comprehensive embedding model management.
* **Advanced Sorting Options** - Introduced flexible sorting controls, allowing sorting by name, added date, or file size in both ascending and descending order.
* **Custom Download Path Templates & Base Model Mapping** - Implemented UI settings for configuring download path templates and base model path mappings, allowing customized model organization and storage location when downloading models via LM Civitai Extension.
* **LM Civitai Extension Enhancements** - Improved concurrent download performance and stability, with new support for canceling active downloads directly from the extension interface.
* **Update Feature** - Added update functionality, allowing users to update LoRA Manager to the latest release version directly from the LoRA Manager UI.
* **Bulk Operations: Refresh All** - Added bulk refresh functionality, allowing users to update Civitai metadata across multiple LoRAs.
### v0.8.7
* **Enhanced Context Menu** - Added comprehensive context menu functionality to Recipes and Checkpoints pages for improved workflow
* **Interactive LoRA Strength Control** - Implemented drag functionality in LoRA Loader for intuitive strength adjustment
* **Metadata Collector Overhaul** - Rebuilt metadata collection system with optimized architecture for better performance
* **Improved Save Image Node** - Enhanced metadata capture and image saving performance with the new metadata collector
* **Streamlined Recipe Saving** - Optimized Save Recipe functionality to work independently without requiring Preview Image nodes
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
### v0.8.20
* **LM Civitai Extension** - Released [browser extension through Chrome Web Store](https://chromewebstore.google.com/detail/lm-civitai-extension/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb) that works seamlessly with LoRA Manager to enhance Civitai browsing experience, showing which models are already in your local library, enabling one-click downloads, and providing queue and parallel download support
* **Enhanced Lora Loader** - Added support for nunchaku, improving convenience when working with ComfyUI-nunchaku workflows, plus new template workflows for quick onboarding
* **WanVideo Integration** - Introduced WanVideo Lora Select (LoraManager) node compatible with ComfyUI-WanVideoWrapper for streamlined lora usage in video workflows, including a template workflow to help you get started quickly
### v0.8.6 Major Update
* **Checkpoint Management** - Added comprehensive management for model checkpoints including scanning, searching, filtering, and deletion
* **Enhanced Metadata Support** - New capabilities for retrieving and managing checkpoint metadata with improved operations
* **Improved Initial Loading** - Optimized cache initialization with visual progress indicators for better user experience
### v0.8.19
* **Analytics Dashboard** - Added new Statistics page providing comprehensive visual analysis of model collection and usage patterns for better library insights
* **Target Node Selection** - Enhanced workflow integration with intelligent target choosing when sending LoRAs/recipes to workflows with multiple loader/stacker nodes; a visual selector now appears showing node color, type, ID, and title for precise targeting
* **Enhanced NSFW Controls** - Added support for setting NSFW levels on recipes with automatic content blurring based on user preferences
* **Customizable Card Display** - New display settings allowing users to choose whether card information and action buttons are always visible or only revealed on hover
* **Expanded Compatibility** - Added support for efficiency-nodes-comfyui in Save Recipe and Save Image nodes, plus fixed compatibility with ComfyUI_Custom_Nodes_AlekPet
### v0.8.5
* **Enhanced LoRA & Recipe Connectivity** - Added Recipes tab in LoRA details to see all recipes using a specific LoRA
* **Improved Navigation** - New shortcuts to jump between related LoRAs and Recipes with one-click navigation
* **Video Preview Controls** - Added "Autoplay Videos on Hover" setting to optimize performance and reduce resource usage
* **UI Experience Refinements** - Smoother transitions between related content pages
### v0.8.18
* **Custom Example Images** - Added ability to import your own example images for LoRAs and checkpoints with automatic metadata extraction from embedded information
* **Enhanced Example Management** - New action buttons to set specific examples as previews or delete custom examples
* **Improved Duplicate Detection** - Enhanced "Find Duplicates" with hash verification feature to eliminate false positives when identifying duplicate models
* **Tag Management** - Added tag editing functionality allowing users to customize and manage model tags
* **Advanced Selection Controls** - Implemented Ctrl+A shortcut for quickly selecting all filtered LoRAs, automatically entering bulk mode when needed
* **Note**: Cache file functionality temporarily disabled pending rework
### v0.8.4
* **Node Layout Improvements** - Fixed layout issues with LoRA Loader and Trigger Words Toggle nodes in newer ComfyUI frontend versions
* **Recipe LoRA Reconnection** - Added ability to reconnect deleted LoRAs in recipes by clicking the "deleted" badge in recipe details
* **Bug Fixes & Stability** - Resolved various issues for improved reliability
### v0.8.3
* **Enhanced Workflow Parser** - Rebuilt workflow analysis engine with improved support for ComfyUI core nodes and easier extensibility
* **Improved Recipe System** - Refined the experimental Save Recipe functionality with better workflow integration
* **New Save Image Node** - Added experimental node with metadata support for perfect CivitAI compatibility
* Supports dynamic filename prefixes with variables [1](https://github.com/nkchocoai/ComfyUI-SaveImageWithMetaData?tab=readme-ov-file#filename_prefix)
* **Default LoRA Root Setting** - Added configuration option for setting your preferred LoRA directory
### v0.8.2
* **Faster Initialization for Forge Users** - Improved first-run efficiency by utilizing existing `.json` and `.civitai.info` files from Forges CivitAI helper extension, making migration smoother.
* **LoRA Filename Editing** - Added support for renaming LoRA files directly within LoRA Manager.
* **Recipe Editing** - Users can now edit recipe names and tags.
* **Retain Deleted LoRAs in Recipes** - Deleted LoRAs will remain listed in recipes, allowing future functionality to reconnect them once re-obtained.
* **Download Missing LoRAs from Recipes** - Easily fetch missing LoRAs associated with a recipe.
### v0.8.1
* **Base Model Correction** - Added support for modifying base model associations to fix incorrect metadata for non-CivitAI LoRAs
* **LoRA Loader Flexibility** - Made CLIP input optional for model-only workflows like Hunyuan video generation
* **Expanded Recipe Support** - Added compatibility with 3 additional recipe metadata formats
* **Enhanced Showcase Images** - Generation parameters now displayed alongside LoRA preview images
* **UI Improvements & Bug Fixes** - Various interface refinements and stability enhancements
### v0.8.0
* **Introduced LoRA Recipes** - Create, import, save, and share your favorite LoRA combinations
* **Recipe Management System** - Easily browse, search, and organize your LoRA recipes
* **Workflow Integration** - Save recipes directly from your workflow with generation parameters preserved
* **Simplified Workflow Application** - Quickly apply saved recipes to new projects
* **Enhanced UI & UX** - Improved interface design and user experience
* **Bug Fixes & Stability** - Resolved various issues and enhanced overall performance
### v0.8.17
* **Duplicate Model Detection** - Added "Find Duplicates" functionality for LoRAs and checkpoints using model file hash detection, enabling convenient viewing and batch deletion of duplicate models
* **Enhanced URL Recipe Imports** - Optimized import recipe via URL functionality using CivitAI API calls instead of web scraping, now supporting all rated images (including NSFW) for recipe imports
* **Improved TriggerWord Control** - Enhanced TriggerWord Toggle node with new default_active switch to set the initial state (active/inactive) when trigger words are added
* **Centralized Example Management** - Added "Migrate Existing Example Images" feature to consolidate downloaded example images from model folders into central storage with customizable naming patterns
* **Intelligent Word Suggestions** - Implemented smart trigger word suggestions by reading class tokens and tag frequency from safetensors files, displaying recommendations when editing trigger words
* **Model Version Management** - Added "Re-link to CivitAI" context menu option for connecting models to different CivitAI versions when needed
[View Update History](./update_logs.md)
@@ -126,13 +129,6 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
- 🚀 **High Performance**
- Fast model loading and browsing
- Smooth scrolling through large collections
- Real-time updates when files change
- 📂 **Advanced Organization**
- Quick search with fuzzy matching
- Folder-based categorization
- Move LoRAs between folders
- Sort by name or date
- 🌐 **Rich Model Integration**
- Direct download from CivitAI
@@ -173,10 +169,11 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
### Option 2: **Portable Standalone Edition** (No ComfyUI required)
1. Download the [Portable Package](https://github.com/willmiao/ComfyUI-Lora-Manager/releases/download/v0.8.10/lora_manager_portable.7z)
1. Download the [Portable Package](https://github.com/willmiao/ComfyUI-Lora-Manager/releases/download/v0.8.26/lora_manager_portable.7z)
2. Copy the provided `settings.json.example` file to create a new file named `settings.json` in `comfyui-lora-manager` folder
3. Edit `settings.json` to include your correct model folder paths and CivitAI API key
4. Run run.bat
- To change the startup port, edit `run.bat` and modify the parameter (e.g. `--port 9001`)
### Option 3: **Manual Installation**
@@ -296,6 +293,8 @@ If you find this project helpful, consider supporting its development:
[![ko-fi](https://ko-fi.com/img/githubbutton_sm.svg)](https://ko-fi.com/pixelpawsai)
[![Patreon](https://img.shields.io/badge/Become%20a%20Patron-F96854.svg?style=for-the-badge&logo=patreon&logoColor=white)](https://patreon.com/PixelPawsAI)
WeChat: [Click to view QR code](https://raw.githubusercontent.com/willmiao/ComfyUI-Lora-Manager/main/static/images/wechat-qr.webp)
## 💬 Community
@@ -304,3 +303,6 @@ Join our Discord community for support, discussions, and updates:
[Discord Server](https://discord.gg/vcqNrWVFvM)
---
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=willmiao/ComfyUI-Lora-Manager&type=Date)](https://star-history.com/#willmiao/ComfyUI-Lora-Manager&Date)

View File

@@ -1,18 +1,21 @@
from .py.lora_manager import LoraManager
from .py.nodes.lora_loader import LoraManagerLoader
from .py.nodes.lora_loader import LoraManagerLoader, LoraManagerTextLoader
from .py.nodes.trigger_word_toggle import TriggerWordToggle
from .py.nodes.lora_stacker import LoraStacker
from .py.nodes.save_image import SaveImage
from .py.nodes.debug_metadata import DebugMetadata
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelect
# Import metadata collector to install hooks on startup
from .py.metadata_collector import init as init_metadata_collector
NODE_CLASS_MAPPINGS = {
LoraManagerLoader.NAME: LoraManagerLoader,
LoraManagerTextLoader.NAME: LoraManagerTextLoader,
TriggerWordToggle.NAME: TriggerWordToggle,
LoraStacker.NAME: LoraStacker,
SaveImage.NAME: SaveImage,
DebugMetadata.NAME: DebugMetadata
DebugMetadata.NAME: DebugMetadata,
WanVideoLoraSelect.NAME: WanVideoLoraSelect
}
WEB_DIRECTORY = "./web/comfyui"

Binary file not shown.

After

Width:  |  Height:  |  Size: 68 KiB

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -5,6 +5,7 @@ from typing import List
import logging
import sys
import json
import urllib.parse
# Check if running in standalone mode
standalone_mode = 'nodes' not in sys.modules
@@ -17,13 +18,17 @@ class Config:
def __init__(self):
self.templates_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'templates')
self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static')
# 路径映射字典, target to link mapping
# Path mapping dictionary, target to link mapping
self._path_mappings = {}
# 静态路由映射字典, target to route mapping
# Static route mapping dictionary, target to route mapping
self._route_mappings = {}
self.loras_roots = self._init_lora_paths()
self.checkpoints_roots = self._init_checkpoint_paths()
# 在初始化时扫描符号链接
self.checkpoints_roots = None
self.unet_roots = None
self.embeddings_roots = None
self.base_models_roots = self._init_checkpoint_paths()
self.embeddings_roots = self._init_embedding_paths()
# Scan symbolic links during initialization
self._scan_symbolic_links()
if not standalone_mode:
@@ -33,34 +38,37 @@ class Config:
def save_folder_paths_to_settings(self):
"""Save folder paths to settings.json for standalone mode to use later"""
try:
# Check if we're running in ComfyUI mode (not standalone)
if hasattr(folder_paths, "get_folder_paths") and not isinstance(folder_paths, type):
# Get all relevant paths
lora_paths = folder_paths.get_folder_paths("loras")
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffuser_paths = folder_paths.get_folder_paths("diffusers")
unet_paths = folder_paths.get_folder_paths("unet")
# Load existing settings
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
settings = {}
if os.path.exists(settings_path):
with open(settings_path, 'r', encoding='utf-8') as f:
settings = json.load(f)
# Update settings with paths
settings['folder_paths'] = {
'loras': lora_paths,
'checkpoints': checkpoint_paths,
'diffusers': diffuser_paths,
'unet': unet_paths
}
# Save settings
with open(settings_path, 'w', encoding='utf-8') as f:
json.dump(settings, f, indent=2)
logger.info("Saved folder paths to settings.json")
# Check if we're running in ComfyUI mode (not standalone)
# Load existing settings
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
settings = {}
if os.path.exists(settings_path):
with open(settings_path, 'r', encoding='utf-8') as f:
settings = json.load(f)
# Update settings with paths
settings['folder_paths'] = {
'loras': self.loras_roots,
'checkpoints': self.checkpoints_roots,
'unet': self.unet_roots,
'embeddings': self.embeddings_roots,
}
# Add default roots if there's only one item and key doesn't exist
if len(self.loras_roots) == 1 and "default_lora_root" not in settings:
settings["default_lora_root"] = self.loras_roots[0]
if self.checkpoints_roots and len(self.checkpoints_roots) == 1 and "default_checkpoint_root" not in settings:
settings["default_checkpoint_root"] = self.checkpoints_roots[0]
if self.embeddings_roots and len(self.embeddings_roots) == 1 and "default_embedding_root" not in settings:
settings["default_embedding_root"] = self.embeddings_roots[0]
# Save settings
with open(settings_path, 'w', encoding='utf-8') as f:
json.dump(settings, f, indent=2)
logger.info("Saved folder paths to settings.json")
except Exception as e:
logger.warning(f"Failed to save folder paths: {e}")
@@ -82,15 +90,18 @@ class Config:
return False
def _scan_symbolic_links(self):
"""扫描所有 LoRA Checkpoint 根目录中的符号链接"""
"""Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories"""
for root in self.loras_roots:
self._scan_directory_links(root)
for root in self.checkpoints_roots:
for root in self.base_models_roots:
self._scan_directory_links(root)
for root in self.embeddings_roots:
self._scan_directory_links(root)
def _scan_directory_links(self, root: str):
"""递归扫描目录中的符号链接"""
"""Recursively scan symbolic links in a directory"""
try:
with os.scandir(root) as it:
for entry in it:
@@ -105,40 +116,40 @@ class Config:
logger.error(f"Error scanning links in {root}: {e}")
def add_path_mapping(self, link_path: str, target_path: str):
"""添加符号链接路径映射
target_path: 实际目标路径
link_path: 符号链接路径
"""Add a symbolic link path mapping
target_path: actual target path
link_path: symbolic link path
"""
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
normalized_target = os.path.normpath(target_path).replace(os.sep, '/')
# 保持原有的映射关系:目标路径 -> 链接路径
# Keep the original mapping: target path -> link path
self._path_mappings[normalized_target] = normalized_link
logger.info(f"Added path mapping: {normalized_target} -> {normalized_link}")
def add_route_mapping(self, path: str, route: str):
"""添加静态路由映射"""
"""Add a static route mapping"""
normalized_path = os.path.normpath(path).replace(os.sep, '/')
self._route_mappings[normalized_path] = route
# logger.info(f"Added route mapping: {normalized_path} -> {route}")
def map_path_to_link(self, path: str) -> str:
"""将目标路径映射回符号链接路径"""
"""Map a target path back to its symbolic link path"""
normalized_path = os.path.normpath(path).replace(os.sep, '/')
# 检查路径是否包含在任何映射的目标路径中
# Check if the path is contained in any mapped target path
for target_path, link_path in self._path_mappings.items():
if normalized_path.startswith(target_path):
# 如果路径以目标路径开头,则替换为链接路径
# If the path starts with the target path, replace with link path
mapped_path = normalized_path.replace(target_path, link_path, 1)
return mapped_path
return path
def map_link_to_path(self, link_path: str) -> str:
"""将符号链接路径映射回实际路径"""
"""Map a symbolic link path back to the actual path"""
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
# 检查路径是否包含在任何映射的目标路径中
# Check if the path is contained in any mapped target path
for target_path, link_path in self._path_mappings.items():
if normalized_link.startswith(target_path):
# 如果路径以目标路径开头,则替换为实际路径
# If the path starts with the target path, replace with actual path
mapped_path = normalized_link.replace(target_path, link_path, 1)
return mapped_path
return link_path
@@ -177,35 +188,85 @@ class Config:
"""Initialize and validate checkpoint paths from ComfyUI settings"""
try:
# Get checkpoint paths from folder_paths
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffusion_paths = folder_paths.get_folder_paths("diffusers")
unet_paths = folder_paths.get_folder_paths("unet")
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
raw_unet_paths = folder_paths.get_folder_paths("unet")
# Combine all checkpoint-related paths
all_paths = checkpoint_paths + diffusion_paths + unet_paths
# Normalize and resolve symlinks for checkpoints, store mapping from resolved -> original
checkpoint_map = {}
for path in raw_checkpoint_paths:
if os.path.exists(path):
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
checkpoint_map[real_path] = checkpoint_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
# Filter and normalize paths
paths = sorted(set(path.replace(os.sep, "/")
for path in all_paths
if os.path.exists(path)), key=lambda p: p.lower())
# Normalize and resolve symlinks for unet, store mapping from resolved -> original
unet_map = {}
for path in raw_unet_paths:
if os.path.exists(path):
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
unet_map[real_path] = unet_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(paths) if paths else "[]"))
# Merge both maps and deduplicate by real path
merged_map = {}
for real_path, orig_path in {**checkpoint_map, **unet_map}.items():
if real_path not in merged_map:
merged_map[real_path] = orig_path
# Now sort and use only the deduplicated real paths
unique_paths = sorted(merged_map.values(), key=lambda p: p.lower())
if not paths:
# Split back into checkpoints and unet roots for class properties
self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_map.values()]
self.unet_roots = [p for p in unique_paths if p in unet_map.values()]
all_paths = unique_paths
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(all_paths) if all_paths else "[]"))
if not all_paths:
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
return []
# 初始化路径映射,与 LoRA 路径处理方式相同
for path in paths:
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
if real_path != path:
self.add_path_mapping(path, real_path)
# Initialize path mappings
for original_path in all_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
if real_path != original_path:
self.add_path_mapping(original_path, real_path)
return paths
return all_paths
except Exception as e:
logger.warning(f"Error initializing checkpoint paths: {e}")
return []
def _init_embedding_paths(self) -> List[str]:
"""Initialize and validate embedding paths from ComfyUI settings"""
try:
raw_paths = folder_paths.get_folder_paths("embeddings")
# Normalize and resolve symlinks, store mapping from resolved -> original
path_map = {}
for path in raw_paths:
if os.path.exists(path):
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
path_map[real_path] = path_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
# Now sort and use only the deduplicated real paths
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
logger.info("Found embedding roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
if not unique_paths:
logger.warning("No valid embeddings folders found in ComfyUI configuration")
return []
for original_path in unique_paths:
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
if real_path != original_path:
self.add_path_mapping(original_path, real_path)
return unique_paths
except Exception as e:
logger.warning(f"Error initializing embedding paths: {e}")
return []
def get_preview_static_url(self, preview_path: str) -> str:
"""Convert local preview path to static URL"""
if not preview_path:
@@ -215,8 +276,10 @@ class Config:
for path, route in self._route_mappings.items():
if real_path.startswith(path):
relative_path = os.path.relpath(real_path, path)
return f'{route}/{relative_path.replace(os.sep, "/")}'
relative_path = os.path.relpath(real_path, path).replace(os.sep, '/')
safe_parts = [urllib.parse.quote(part) for part in relative_path.split('/')]
safe_path = '/'.join(safe_parts)
return f'{route}/{safe_path}'
return ""

View File

@@ -1,17 +1,21 @@
import asyncio
from server import PromptServer # type: ignore
from .config import config
from .routes.lora_routes import LoraRoutes
from .routes.api_routes import ApiRoutes
from .routes.recipe_routes import RecipeRoutes
from .routes.checkpoints_routes import CheckpointsRoutes
from .routes.update_routes import UpdateRoutes
from .routes.misc_routes import MiscRoutes
from .services.service_registry import ServiceRegistry
from .services.settings_manager import settings
import logging
import sys
import os
import logging
from pathlib import Path
from server import PromptServer # type: ignore
from .config import config
from .services.model_service_factory import ModelServiceFactory, register_default_model_types
from .routes.recipe_routes import RecipeRoutes
from .routes.stats_routes import StatsRoutes
from .routes.update_routes import UpdateRoutes
from .routes.misc_routes import MiscRoutes
from .routes.example_images_routes import ExampleImagesRoutes
from .services.service_registry import ServiceRegistry
from .services.settings_manager import settings
from .utils.example_images_migration import ExampleImagesMigration
from .services.websocket_manager import ws_manager
logger = logging.getLogger(__name__)
@@ -23,12 +27,28 @@ class LoraManager:
@classmethod
def add_routes(cls):
"""Initialize and register all routes"""
"""Initialize and register all routes using the new refactored architecture"""
app = PromptServer.instance.app
# Configure aiohttp access logger to be less verbose
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
# Add specific suppression for connection reset errors
class ConnectionResetFilter(logging.Filter):
def filter(self, record):
# Filter out connection reset errors that are not critical
if "ConnectionResetError" in str(record.getMessage()):
return False
if "_call_connection_lost" in str(record.getMessage()):
return False
if "WinError 10054" in str(record.getMessage()):
return False
return True
# Apply the filter to asyncio logger
asyncio_logger = logging.getLogger("asyncio")
asyncio_logger.addFilter(ConnectionResetFilter())
added_targets = set() # Track already added target paths
# Add static route for example images if the path exists in settings
@@ -57,7 +77,7 @@ class LoraManager:
added_targets.add(real_root)
# Add static routes for each checkpoint root
for idx, root in enumerate(config.checkpoints_roots, start=1):
for idx, root in enumerate(config.base_models_roots, start=1):
preview_path = f'/checkpoints_static/root{idx}/preview'
real_root = root
@@ -74,98 +94,116 @@ class LoraManager:
config.add_route_mapping(real_root, preview_path)
added_targets.add(real_root)
# Add static routes for each embedding root
for idx, root in enumerate(config.embeddings_roots, start=1):
preview_path = f'/embeddings_static/root{idx}/preview'
real_root = root
if root in config._path_mappings.values():
for target, link in config._path_mappings.items():
if link == root:
real_root = target
break
# Add static route for original path
app.router.add_static(preview_path, real_root)
logger.info(f"Added static route {preview_path} -> {real_root}")
# Record route mapping
config.add_route_mapping(real_root, preview_path)
added_targets.add(real_root)
# Add static routes for symlink target paths
link_idx = {
'lora': 1,
'checkpoint': 1
'checkpoint': 1,
'embedding': 1
}
for target_path, link_path in config._path_mappings.items():
if target_path not in added_targets:
# Determine if this is a checkpoint or lora link based on path
is_checkpoint = any(cp_root in link_path for cp_root in config.checkpoints_roots)
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.checkpoints_roots)
# Determine if this is a checkpoint, lora, or embedding link based on path
is_checkpoint = any(cp_root in link_path for cp_root in config.base_models_roots)
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.base_models_roots)
is_embedding = any(emb_root in link_path for emb_root in config.embeddings_roots)
is_embedding = is_embedding or any(emb_root in target_path for emb_root in config.embeddings_roots)
if is_checkpoint:
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
link_idx["checkpoint"] += 1
elif is_embedding:
route_path = f'/embeddings_static/link_{link_idx["embedding"]}/preview'
link_idx["embedding"] += 1
else:
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
link_idx["lora"] += 1
app.router.add_static(route_path, target_path)
logger.info(f"Added static route for link target {route_path} -> {target_path}")
config.add_route_mapping(target_path, route_path)
added_targets.add(target_path)
try:
app.router.add_static(route_path, Path(target_path).resolve(strict=False))
logger.info(f"Added static route for link target {route_path} -> {target_path}")
config.add_route_mapping(target_path, route_path)
added_targets.add(target_path)
except Exception as e:
logger.warning(f"Failed to add static route on initialization for {target_path}: {e}")
continue
# Add static route for plugin assets
app.router.add_static('/loras_static', config.static_path)
# Setup feature routes
lora_routes = LoraRoutes()
checkpoints_routes = CheckpointsRoutes()
# Register default model types with the factory
register_default_model_types()
# Initialize routes
lora_routes.setup_routes(app)
checkpoints_routes.setup_routes(app)
ApiRoutes.setup_routes(app)
# Setup all model routes using the factory
ModelServiceFactory.setup_all_routes(app)
# Setup non-model-specific routes
stats_routes = StatsRoutes()
stats_routes.setup_routes(app)
RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app)
MiscRoutes.setup_routes(app) # Register miscellaneous routes
MiscRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app)
# Setup WebSocket routes that are shared across all model types
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
# Schedule service initialization
app.on_startup.append(lambda app: cls._initialize_services())
# Add cleanup
app.on_shutdown.append(cls._cleanup)
app.on_shutdown.append(ApiRoutes.cleanup)
logger.info(f"LoRA Manager: Set up routes for {len(ModelServiceFactory.get_registered_types())} model types: {', '.join(ModelServiceFactory.get_registered_types())}")
@classmethod
async def _initialize_services(cls):
"""Initialize all services using the ServiceRegistry"""
try:
# Ensure aiohttp access logger is configured with reduced verbosity
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
# Initialize CivitaiClient first to ensure it's ready for other services
civitai_client = await ServiceRegistry.get_civitai_client()
# Get file monitors through ServiceRegistry
lora_monitor = await ServiceRegistry.get_lora_monitor()
checkpoint_monitor = await ServiceRegistry.get_checkpoint_monitor()
# Start monitors
lora_monitor.start()
logger.debug("Lora monitor started")
# Make sure checkpoint monitor has paths before starting
await checkpoint_monitor.initialize_paths()
checkpoint_monitor.start()
logger.debug("Checkpoint monitor started")
await ServiceRegistry.get_civitai_client()
# Register DownloadManager with ServiceRegistry
download_manager = await ServiceRegistry.get_download_manager()
await ServiceRegistry.get_download_manager()
# Initialize WebSocket manager
ws_manager = await ServiceRegistry.get_websocket_manager()
await ServiceRegistry.get_websocket_manager()
# Initialize scanners in background
lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
# Initialize recipe scanner if needed
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
# Initialize metadata collector if not in standalone mode
if not STANDALONE_MODE:
from .metadata_collector import init as init_metadata
init_metadata()
logger.debug("Metadata collector initialized")
# Create low-priority initialization tasks
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init')
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
await ExampleImagesMigration.check_and_run_migrations()
logger.info("LoRA Manager: All services initialized and background tasks scheduled")
@@ -177,17 +215,6 @@ class LoraManager:
"""Cleanup resources using ServiceRegistry"""
try:
logger.info("LoRA Manager: Cleaning up services")
# Get monitors from ServiceRegistry
lora_monitor = await ServiceRegistry.get_service("lora_monitor")
if lora_monitor:
lora_monitor.stop()
logger.info("Stopped LoRA monitor")
checkpoint_monitor = await ServiceRegistry.get_service("checkpoint_monitor")
if checkpoint_monitor:
checkpoint_monitor.stop()
logger.info("Stopped checkpoint monitor")
# Close CivitaiClient gracefully
civitai_client = await ServiceRegistry.get_service("civitai_client")

View File

@@ -1,7 +1,5 @@
"""Constants used by the metadata collector"""
# Metadata collection constants
# Metadata categories
MODELS = "models"
PROMPTS = "prompts"
@@ -9,6 +7,7 @@ SAMPLING = "sampling"
LORAS = "loras"
SIZE = "size"
IMAGES = "images"
IS_SAMPLER = "is_sampler" # New constant to mark sampler nodes
# Complete list of categories to track
METADATA_CATEGORIES = [MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES]

View File

@@ -26,98 +26,179 @@ class MetadataHook:
print("Could not locate ComfyUI execution module, metadata collection disabled")
return
# Store the original _map_node_over_list function
original_map_node_over_list = execution._map_node_over_list
# Detect whether we're using the new async version of ComfyUI
is_async = False
map_node_func_name = '_map_node_over_list'
# Define the wrapped _map_node_over_list function
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# Only collect metadata when calling the main function of nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
prompt_id = registry.current_prompt_id
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Unique ID might be available through the obj if it has a unique_id field
node_id = getattr(obj, 'unique_id', None)
if node_id is None and pre_execute_cb:
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
frame = inspect.currentframe()
while frame:
if 'unique_id' in frame.f_locals:
node_id = frame.f_locals['unique_id']
break
frame = frame.f_back
# Record inputs before execution
if node_id is not None:
registry.record_node_execution(node_id, class_type, input_data_all, None)
except Exception as e:
print(f"Error collecting metadata (pre-execution): {str(e)}")
# Execute the original function
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
# After execution, collect outputs for relevant nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
prompt_id = registry.current_prompt_id
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Unique ID might be available through the obj if it has a unique_id field
node_id = getattr(obj, 'unique_id', None)
if node_id is None and pre_execute_cb:
# Try to extract node_id through reflection
frame = inspect.currentframe()
while frame:
if 'unique_id' in frame.f_locals:
node_id = frame.f_locals['unique_id']
break
frame = frame.f_back
# Record outputs after execution
if node_id is not None:
registry.update_node_execution(node_id, class_type, results)
except Exception as e:
print(f"Error collecting metadata (post-execution): {str(e)}")
return results
# Also hook the execute function to track the current prompt_id
original_execute = execution.execute
if hasattr(execution, '_async_map_node_over_list'):
is_async = inspect.iscoroutinefunction(execution._async_map_node_over_list)
map_node_func_name = '_async_map_node_over_list'
elif hasattr(execution, '_map_node_over_list'):
is_async = inspect.iscoroutinefunction(execution._map_node_over_list)
def execute_with_prompt_tracking(*args, **kwargs):
if len(args) >= 7: # Check if we have enough arguments
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
registry = MetadataRegistry()
# Start collection if this is a new prompt
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
registry.start_collection(prompt_id)
# Store the dynprompt reference for node lookups
if hasattr(prompt, 'original_prompt'):
registry.set_current_prompt(prompt)
# Execute the original function
return original_execute(*args, **kwargs)
# Replace the functions
execution._map_node_over_list = map_node_over_list_with_metadata
execution.execute = execute_with_prompt_tracking
# Make map_node_over_list public to avoid it being hidden by hooks
execution.map_node_over_list = original_map_node_over_list
if is_async:
print("Detected async ComfyUI execution, installing async metadata hooks")
MetadataHook._install_async_hooks(execution, map_node_func_name)
else:
print("Detected sync ComfyUI execution, installing sync metadata hooks")
MetadataHook._install_sync_hooks(execution)
print("Metadata collection hooks installed for runtime values")
except Exception as e:
print(f"Error installing metadata hooks: {str(e)}")
@staticmethod
def _install_sync_hooks(execution):
"""Install hooks for synchronous execution model"""
# Store the original _map_node_over_list function
original_map_node_over_list = execution._map_node_over_list
# Define the wrapped _map_node_over_list function
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
# Only collect metadata when calling the main function of nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
prompt_id = registry.current_prompt_id
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Unique ID might be available through the obj if it has a unique_id field
node_id = getattr(obj, 'unique_id', None)
if node_id is None and pre_execute_cb:
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
frame = inspect.currentframe()
while frame:
if 'unique_id' in frame.f_locals:
node_id = frame.f_locals['unique_id']
break
frame = frame.f_back
# Record inputs before execution
if node_id is not None:
registry.record_node_execution(node_id, class_type, input_data_all, None)
except Exception as e:
print(f"Error collecting metadata (pre-execution): {str(e)}")
# Execute the original function
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
# After execution, collect outputs for relevant nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
# Get the current prompt_id from the registry
registry = MetadataRegistry()
prompt_id = registry.current_prompt_id
if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__
# Unique ID might be available through the obj if it has a unique_id field
node_id = getattr(obj, 'unique_id', None)
if node_id is None and pre_execute_cb:
# Try to extract node_id through reflection
frame = inspect.currentframe()
while frame:
if 'unique_id' in frame.f_locals:
node_id = frame.f_locals['unique_id']
break
frame = frame.f_back
# Record outputs after execution
if node_id is not None:
registry.update_node_execution(node_id, class_type, results)
except Exception as e:
print(f"Error collecting metadata (post-execution): {str(e)}")
return results
# Also hook the execute function to track the current prompt_id
original_execute = execution.execute
def execute_with_prompt_tracking(*args, **kwargs):
if len(args) >= 7: # Check if we have enough arguments
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
registry = MetadataRegistry()
# Start collection if this is a new prompt
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
registry.start_collection(prompt_id)
# Store the dynprompt reference for node lookups
if hasattr(prompt, 'original_prompt'):
registry.set_current_prompt(prompt)
# Execute the original function
return original_execute(*args, **kwargs)
# Replace the functions
execution._map_node_over_list = map_node_over_list_with_metadata
execution.execute = execute_with_prompt_tracking
@staticmethod
def _install_async_hooks(execution, map_node_func_name='_async_map_node_over_list'):
"""Install hooks for asynchronous execution model"""
# Store the original _async_map_node_over_list function
original_map_node_over_list = getattr(execution, map_node_func_name)
# Wrapped async function, compatible with both stable and nightly
async def async_map_node_over_list_with_metadata(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, *args, **kwargs):
hidden_inputs = kwargs.get('hidden_inputs', None)
# Only collect metadata when calling the main function of nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
registry = MetadataRegistry()
if prompt_id is not None:
class_type = obj.__class__.__name__
node_id = unique_id
if node_id is not None:
registry.record_node_execution(node_id, class_type, input_data_all, None)
except Exception as e:
print(f"Error collecting metadata (pre-execution): {str(e)}")
# Call original function with all args/kwargs
results = await original_map_node_over_list(
prompt_id, unique_id, obj, input_data_all, func,
allow_interrupt, execution_block_cb, pre_execute_cb, *args, **kwargs
)
if func == obj.FUNCTION and hasattr(obj, '__class__'):
try:
registry = MetadataRegistry()
if prompt_id is not None:
class_type = obj.__class__.__name__
node_id = unique_id
if node_id is not None:
registry.update_node_execution(node_id, class_type, results)
except Exception as e:
print(f"Error collecting metadata (post-execution): {str(e)}")
return results
# Also hook the execute function to track the current prompt_id
original_execute = execution.execute
async def async_execute_with_prompt_tracking(*args, **kwargs):
if len(args) >= 7: # Check if we have enough arguments
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
registry = MetadataRegistry()
# Start collection if this is a new prompt
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
registry.start_collection(prompt_id)
# Store the dynprompt reference for node lookups
if hasattr(prompt, 'original_prompt'):
registry.set_current_prompt(prompt)
# Execute the original function
return await original_execute(*args, **kwargs)
# Replace the functions with async versions
setattr(execution, map_node_func_name, async_map_node_over_list_with_metadata)
execution.execute = async_execute_with_prompt_tracking

View File

@@ -1,36 +1,117 @@
import json
import sys
from .constants import IMAGES
# Check if running in standalone mode
standalone_mode = 'nodes' not in sys.modules
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IS_SAMPLER
class MetadataProcessor:
"""Process and format collected metadata"""
@staticmethod
def find_primary_sampler(metadata):
"""Find the primary KSampler node (with highest denoise value)"""
def find_primary_sampler(metadata, downstream_id=None):
"""
Find the primary KSampler node that executed before the given downstream node
Parameters:
- metadata: The workflow metadata
- downstream_id: Optional ID of a downstream node to help identify the specific primary sampler
"""
if downstream_id is None:
if IMAGES in metadata and "first_decode" in metadata[IMAGES]:
downstream_id = metadata[IMAGES]["first_decode"]["node_id"]
# If we have a downstream_id and execution_order, use it to narrow down potential samplers
if downstream_id and "execution_order" in metadata:
execution_order = metadata["execution_order"]
# Find the index of the downstream node in the execution order
if downstream_id in execution_order:
downstream_index = execution_order.index(downstream_id)
# Extract all sampler nodes that executed before the downstream node
candidate_samplers = {}
for i in range(downstream_index):
node_id = execution_order[i]
# Use IS_SAMPLER flag to identify true sampler nodes
if node_id in metadata.get(SAMPLING, {}) and metadata[SAMPLING][node_id].get(IS_SAMPLER, False):
candidate_samplers[node_id] = metadata[SAMPLING][node_id]
# If we found candidate samplers, apply primary sampler logic to these candidates only
if candidate_samplers:
# Collect potential primary samplers based on different criteria
custom_advanced_samplers = []
advanced_add_noise_samplers = []
high_denoise_samplers = []
max_denoise = -1
high_denoise_id = None
# First, check for SamplerCustomAdvanced among candidates
prompt = metadata.get("current_prompt")
if prompt and prompt.original_prompt:
for node_id in candidate_samplers:
node_info = prompt.original_prompt.get(node_id, {})
if node_info.get("class_type") == "SamplerCustomAdvanced":
custom_advanced_samplers.append(node_id)
# Next, check for KSamplerAdvanced with add_noise="enable" among candidates
for node_id, sampler_info in candidate_samplers.items():
parameters = sampler_info.get("parameters", {})
add_noise = parameters.get("add_noise")
if add_noise == "enable":
advanced_add_noise_samplers.append(node_id)
# Find the sampler with highest denoise value among candidates
for node_id, sampler_info in candidate_samplers.items():
parameters = sampler_info.get("parameters", {})
denoise = parameters.get("denoise")
if denoise is not None and denoise > max_denoise:
max_denoise = denoise
high_denoise_id = node_id
if high_denoise_id:
high_denoise_samplers.append(high_denoise_id)
# Combine all potential primary samplers
potential_samplers = custom_advanced_samplers + advanced_add_noise_samplers + high_denoise_samplers
# Find the most recent potential primary sampler (closest to downstream node)
for i in range(downstream_index - 1, -1, -1):
node_id = execution_order[i]
if node_id in potential_samplers:
return node_id, candidate_samplers[node_id]
# If no potential sampler found from our criteria, return the most recent sampler
if candidate_samplers:
for i in range(downstream_index - 1, -1, -1):
node_id = execution_order[i]
if node_id in candidate_samplers:
return node_id, candidate_samplers[node_id]
# If no downstream_id provided or no suitable sampler found, fall back to original logic
primary_sampler = None
primary_sampler_id = None
max_denoise = -1 # Track the highest denoise value
max_denoise = -1
# First, check for SamplerCustomAdvanced
prompt = metadata.get("current_prompt")
if prompt and prompt.original_prompt:
for node_id, node_info in prompt.original_prompt.items():
if node_info.get("class_type") == "SamplerCustomAdvanced":
# Found a SamplerCustomAdvanced node
if node_id in metadata.get(SAMPLING, {}):
# Check if the node is in SAMPLING and has IS_SAMPLER flag
if node_id in metadata.get(SAMPLING, {}) and metadata[SAMPLING][node_id].get(IS_SAMPLER, False):
return node_id, metadata[SAMPLING][node_id]
# Next, check for KSamplerAdvanced with add_noise="enable"
# Next, check for KSamplerAdvanced with add_noise="enable" using IS_SAMPLER flag
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
# Skip if not marked as a sampler
if not sampler_info.get(IS_SAMPLER, False):
continue
parameters = sampler_info.get("parameters", {})
add_noise = parameters.get("add_noise")
# If add_noise is "enable", this is likely the primary sampler for KSamplerAdvanced
if add_noise == "enable":
primary_sampler = sampler_info
primary_sampler_id = node_id
@@ -39,10 +120,12 @@ class MetadataProcessor:
# If no specialized sampler found, find the sampler with highest denoise value
if primary_sampler is None:
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
# Skip if not marked as a sampler
if not sampler_info.get(IS_SAMPLER, False):
continue
parameters = sampler_info.get("parameters", {})
denoise = parameters.get("denoise")
# If denoise exists and is higher than current max, use this sampler
if denoise is not None and denoise > max_denoise:
max_denoise = denoise
primary_sampler = sampler_info
@@ -74,13 +157,18 @@ class MetadataProcessor:
current_node_id = node_id
current_input = input_name
# If we're just tracing to origin (no target_class), keep track of the last valid node
last_valid_node = None
while current_depth < max_depth:
if current_node_id not in prompt.original_prompt:
return None
return last_valid_node if not target_class else None
node_inputs = prompt.original_prompt[current_node_id].get("inputs", {})
if current_input not in node_inputs:
return None
# We've reached a node without the specified input - this is our origin node
# if we're not looking for a specific target_class
return current_node_id if not target_class else None
input_value = node_inputs[current_input]
# Input connections are formatted as [node_id, output_index]
@@ -91,9 +179,9 @@ class MetadataProcessor:
if target_class and prompt.original_prompt[found_node_id].get("class_type") == target_class:
return found_node_id
# If we're not looking for a specific class or haven't found it yet
# If we're not looking for a specific class, update the last valid node
if not target_class:
return found_node_id
last_valid_node = found_node_id
# Continue tracing through intermediate nodes
current_node_id = found_node_id
@@ -101,16 +189,17 @@ class MetadataProcessor:
if "conditioning" in prompt.original_prompt[current_node_id].get("inputs", {}):
current_input = "conditioning"
else:
# If there's no "conditioning" input, we can't trace further
# If there's no "conditioning" input, return the current node
# if we're not looking for a specific target_class
return found_node_id if not target_class else None
else:
# We've reached a node with no further connections
return None
return last_valid_node if not target_class else None
current_depth += 1
# If we've reached max depth without finding target_class
return None
return last_valid_node if not target_class else None
@staticmethod
def find_primary_checkpoint(metadata):
@@ -126,8 +215,80 @@ class MetadataProcessor:
return None
@staticmethod
def extract_generation_params(metadata):
"""Extract generation parameters from metadata using node relationships"""
def match_conditioning_to_prompts(metadata, sampler_id):
"""
Match conditioning objects from a sampler to prompts in metadata
Parameters:
- metadata: The workflow metadata
- sampler_id: ID of the sampler node to match
Returns:
- Dictionary with 'prompt' and 'negative_prompt' if found
"""
result = {
"prompt": "",
"negative_prompt": ""
}
# Check if we have stored conditioning objects for this sampler
if sampler_id in metadata.get(PROMPTS, {}) and (
"pos_conditioning" in metadata[PROMPTS][sampler_id] or
"neg_conditioning" in metadata[PROMPTS][sampler_id]):
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
# Helper function to recursively find prompt text for a conditioning object
def find_prompt_text_for_conditioning(conditioning_obj, is_positive=True):
if conditioning_obj is None:
return ""
# Try to match conditioning objects with those stored by extractors
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
# For nodes with single conditioning output
if "conditioning" in prompt_data:
if id(prompt_data["conditioning"]) == id(conditioning_obj):
return prompt_data.get("text", "")
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
if is_positive and "positive_encoded" in prompt_data:
if id(prompt_data["positive_encoded"]) == id(conditioning_obj):
if "positive_text" in prompt_data:
return prompt_data["positive_text"]
else:
orig_conditioning = prompt_data.get("orig_pos_cond", None)
if orig_conditioning is not None:
# Recursively find the prompt text for the original conditioning
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=True)
if not is_positive and "negative_encoded" in prompt_data:
if id(prompt_data["negative_encoded"]) == id(conditioning_obj):
if "negative_text" in prompt_data:
return prompt_data["negative_text"]
else:
orig_conditioning = prompt_data.get("orig_neg_cond", None)
if orig_conditioning is not None:
# Recursively find the prompt text for the original conditioning
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=False)
return ""
# Find prompt texts using the helper function
result["prompt"] = find_prompt_text_for_conditioning(pos_conditioning, is_positive=True)
result["negative_prompt"] = find_prompt_text_for_conditioning(neg_conditioning, is_positive=False)
return result
@staticmethod
def extract_generation_params(metadata, id=None):
"""
Extract generation parameters from metadata using node relationships
Parameters:
- metadata: The workflow metadata
- id: Optional ID of a downstream node to help identify the specific primary sampler
"""
params = {
"prompt": "",
"negative_prompt": "",
@@ -147,13 +308,20 @@ class MetadataProcessor:
prompt = metadata.get("current_prompt")
# Find the primary KSampler node
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata)
primary_sampler_id, primary_sampler = MetadataProcessor.find_primary_sampler(metadata, id)
# Directly get checkpoint from metadata instead of tracing
checkpoint = MetadataProcessor.find_primary_checkpoint(metadata)
if checkpoint:
params["checkpoint"] = checkpoint
# Check if guidance parameter exists in any sampling node
for node_id, sampler_info in metadata.get(SAMPLING, {}).items():
parameters = sampler_info.get("parameters", {})
if "guidance" in parameters and parameters["guidance"] is not None:
params["guidance"] = parameters["guidance"]
break
if primary_sampler:
# Extract sampling parameters
sampling_params = primary_sampler.get("parameters", {})
@@ -164,7 +332,6 @@ class MetadataProcessor:
params["sampler"] = sampling_params.get("sampler_name")
params["scheduler"] = sampling_params.get("scheduler")
# Trace connections from the primary sampler
if prompt and primary_sampler_id:
# Check if this is a SamplerCustomAdvanced node
is_custom_advanced = False
@@ -172,71 +339,43 @@ class MetadataProcessor:
is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced"
if is_custom_advanced:
# For SamplerCustomAdvanced, trace specific inputs
# 1. Trace sigmas input to find BasicScheduler
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", "BasicScheduler", max_depth=5)
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
params["steps"] = scheduler_params.get("steps")
params["scheduler"] = scheduler_params.get("scheduler")
# 2. Trace sampler input to find KSamplerSelect
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
params["sampler"] = sampler_params.get("sampler_name")
# 3. Trace guider input for FluxGuidance and CLIPTextEncode
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
if guider_node_id:
# Look for FluxGuidance along the guider path
flux_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", "FluxGuidance", max_depth=5)
if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}):
flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {})
params["guidance"] = flux_params.get("guidance")
# Find CLIPTextEncode for positive prompt (through conditioning)
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", "CLIPTextEncode", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
# For SamplerCustomAdvanced, use the new handler method
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
else:
# Original tracing for standard samplers
# Trace positive prompt - look specifically for CLIPTextEncode
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncode", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
else:
# If CLIPTextEncode is not found, try to find CLIPTextEncodeFlux
positive_flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncodeFlux", max_depth=10)
if positive_flux_node_id and positive_flux_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_flux_node_id].get("text", "")
# Also extract guidance value if present in the sampling data
if positive_flux_node_id in metadata.get(SAMPLING, {}):
flux_params = metadata[SAMPLING][positive_flux_node_id].get("parameters", {})
if "guidance" in flux_params:
params["guidance"] = flux_params.get("guidance")
# For standard samplers, match conditioning objects to prompts
prompt_results = MetadataProcessor.match_conditioning_to_prompts(metadata, primary_sampler_id)
params["prompt"] = prompt_results["prompt"]
params["negative_prompt"] = prompt_results["negative_prompt"]
# If prompts were still not found, fall back to tracing connections
if not params["prompt"]:
# Original tracing for standard samplers
# Trace positive prompt - look specifically for CLIPTextEncode
positive_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
else:
# If CLIPTextEncode is not found, try to find CLIPTextEncodeFlux
positive_flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "CLIPTextEncodeFlux", max_depth=10)
if positive_flux_node_id and positive_flux_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_flux_node_id].get("text", "")
# Trace negative prompt - look specifically for CLIPTextEncode
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", max_depth=10)
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
# Find any FluxGuidance nodes in the positive conditioning path
flux_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "positive", "FluxGuidance", max_depth=5)
if flux_node_id and flux_node_id in metadata.get(SAMPLING, {}):
flux_params = metadata[SAMPLING][flux_node_id].get("parameters", {})
params["guidance"] = flux_params.get("guidance")
# Trace negative prompt - look specifically for CLIPTextEncode
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", "CLIPTextEncode", max_depth=10)
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
# For SamplerCustom, handle any additional parameters
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
# Size extraction is same for all sampler types
# Check if the sampler itself has size information (from latent_image)
if primary_sampler_id in metadata.get(SIZE, {}):
width = metadata[SIZE][primary_sampler_id].get("width")
height = metadata[SIZE][primary_sampler_id].get("height")
if width and height:
params["size"] = f"{width}x{height}"
# Size extraction is same for all sampler types
# Check if the sampler itself has size information (from latent_image)
if primary_sampler_id in metadata.get(SIZE, {}):
width = metadata[SIZE][primary_sampler_id].get("width")
height = metadata[SIZE][primary_sampler_id].get("height")
if width and height:
params["size"] = f"{width}x{height}"
# Extract LoRAs using the standardized format
lora_parts = []
@@ -256,13 +395,19 @@ class MetadataProcessor:
return params
@staticmethod
def to_dict(metadata):
"""Convert extracted metadata to the ComfyUI output.json format"""
def to_dict(metadata, id=None):
"""
Convert extracted metadata to the ComfyUI output.json format
Parameters:
- metadata: The workflow metadata
- id: Optional ID of a downstream node to help identify the specific primary sampler
"""
if standalone_mode:
# Return empty dictionary in standalone mode
return {}
params = MetadataProcessor.extract_generation_params(metadata)
params = MetadataProcessor.extract_generation_params(metadata, id)
# Convert all values to strings to match output.json format
for key in params:
@@ -272,7 +417,63 @@ class MetadataProcessor:
return params
@staticmethod
def to_json(metadata):
def to_json(metadata, id=None):
"""Convert metadata to JSON string"""
params = MetadataProcessor.to_dict(metadata)
params = MetadataProcessor.to_dict(metadata, id)
return json.dumps(params, indent=4)
@staticmethod
def handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params):
"""
Handle parameter extraction for SamplerCustomAdvanced nodes
Parameters:
- metadata: The workflow metadata
- prompt: The prompt object containing node connections
- primary_sampler_id: ID of the SamplerCustomAdvanced node
- params: Parameters dictionary to update
"""
if not prompt.original_prompt or primary_sampler_id not in prompt.original_prompt:
return
sampler_inputs = prompt.original_prompt[primary_sampler_id].get("inputs", {})
# 1. Trace sigmas input to find BasicScheduler (only if sigmas input exists)
if "sigmas" in sampler_inputs:
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", None, max_depth=5)
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
params["steps"] = scheduler_params.get("steps")
params["scheduler"] = scheduler_params.get("scheduler")
# 2. Trace sampler input to find KSamplerSelect (only if sampler input exists)
if "sampler" in sampler_inputs:
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
params["sampler"] = sampler_params.get("sampler_name")
# 3. Trace guider input for CFGGuider and CLIPTextEncode
if "guider" in sampler_inputs:
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
if guider_node_id and guider_node_id in prompt.original_prompt:
# Check if the guider node is a CFGGuider
if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider":
# Extract cfg value from the CFGGuider
if guider_node_id in metadata.get(SAMPLING, {}):
cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {})
params["cfg_scale"] = cfg_params.get("cfg")
# Find CLIPTextEncode for positive prompt
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
# Find CLIPTextEncode for negative prompt
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10)
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
else:
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")

View File

@@ -1,6 +1,6 @@
import os
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IMAGES, IS_SAMPLER
class NodeMetadataExtractor:
@@ -35,7 +35,70 @@ class CheckpointLoaderExtractor(NodeMetadataExtractor):
"type": "checkpoint",
"node_id": node_id
}
class TSCCheckpointLoaderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "ckpt_name" not in inputs:
return
model_name = inputs.get("ckpt_name")
if model_name:
metadata[MODELS][node_id] = {
"name": model_name,
"type": "checkpoint",
"node_id": node_id
}
# For loader node has lora_stack input, like Efficient Loader from Efficient Nodes
active_loras = []
# Process lora_stack if available
if "lora_stack" in inputs:
lora_stack = inputs.get("lora_stack", [])
for lora_path, model_strength, clip_strength in lora_stack:
# Extract lora name from path (following the format in lora_loader.py)
lora_name = os.path.splitext(os.path.basename(lora_path))[0]
active_loras.append({
"name": lora_name,
"strength": model_strength
})
if active_loras:
metadata[LORAS][node_id] = {
"lora_list": active_loras,
"node_id": node_id
}
# Extract positive and negative prompt text if available
positive_text = inputs.get("positive", "")
negative_text = inputs.get("negative", "")
if positive_text or negative_text:
if node_id not in metadata[PROMPTS]:
metadata[PROMPTS][node_id] = {"node_id": node_id}
# Store both positive and negative text
metadata[PROMPTS][node_id]["positive_text"] = positive_text
metadata[PROMPTS][node_id]["negative_text"] = negative_text
@staticmethod
def update(node_id, outputs, metadata):
# Handle conditioning outputs from TSC_EfficientLoader
# outputs is a list with [(model, positive_encoded, negative_encoded, {"samples":latent}, vae, clip, dependencies,)]
if outputs and isinstance(outputs, list) and len(outputs) > 0:
first_output = outputs[0]
if isinstance(first_output, tuple) and len(first_output) >= 3:
positive_conditioning = first_output[1]
negative_conditioning = first_output[2]
# Save both conditioning objects in metadata
if node_id not in metadata[PROMPTS]:
metadata[PROMPTS][node_id] = {"node_id": node_id}
metadata[PROMPTS][node_id]["positive_encoded"] = positive_conditioning
metadata[PROMPTS][node_id]["negative_encoded"] = negative_conditioning
class CLIPTextEncodeExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
@@ -47,60 +110,49 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor):
"text": text,
"node_id": node_id
}
class SamplerExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
sampling_params = {}
for key in ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]:
if key in inputs:
sampling_params[key] = inputs[key]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
}
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
if isinstance(latent, dict) and "samples" in latent:
# Extract dimensions from latent tensor
samples = latent["samples"]
if hasattr(samples, "shape") and len(samples.shape) >= 3:
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
# Multiply by 8 to get actual pixel dimensions
height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
class KSamplerAdvancedExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
def update(node_id, outputs, metadata):
if outputs and isinstance(outputs, list) and len(outputs) > 0:
if isinstance(outputs[0], tuple) and len(outputs[0]) > 0:
conditioning = outputs[0][0]
metadata[PROMPTS][node_id]["conditioning"] = conditioning
# Base Sampler Extractor to reduce code redundancy
class BaseSamplerExtractor(NodeMetadataExtractor):
"""Base extractor for sampler nodes with common functionality"""
@staticmethod
def extract_sampling_params(node_id, inputs, metadata, param_keys):
"""Extract sampling parameters from inputs"""
sampling_params = {}
for key in ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]:
for key in param_keys:
if key in inputs:
sampling_params[key] = inputs[key]
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
"node_id": node_id,
IS_SAMPLER: True # Add sampler flag
}
@staticmethod
def extract_conditioning(node_id, inputs, metadata):
"""Extract conditioning objects from inputs"""
# Store the conditioning objects directly in metadata for later matching
pos_conditioning = inputs.get("positive", None)
neg_conditioning = inputs.get("negative", None)
# Save conditioning objects in metadata for later matching
if pos_conditioning is not None or neg_conditioning is not None:
if node_id not in metadata[PROMPTS]:
metadata[PROMPTS][node_id] = {"node_id": node_id}
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
@staticmethod
def extract_latent_dimensions(node_id, inputs, metadata):
"""Extract dimensions from latent image"""
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
@@ -121,6 +173,176 @@ class KSamplerAdvancedExtractor(NodeMetadataExtractor):
"height": height,
"node_id": node_id
}
class SamplerExtractor(BaseSamplerExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
# Extract common sampling parameters
BaseSamplerExtractor.extract_sampling_params(
node_id, inputs, metadata,
["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]
)
# Extract conditioning objects
BaseSamplerExtractor.extract_conditioning(node_id, inputs, metadata)
# Extract latent dimensions
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
class KSamplerAdvancedExtractor(BaseSamplerExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
# Extract common sampling parameters
BaseSamplerExtractor.extract_sampling_params(
node_id, inputs, metadata,
["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]
)
# Extract conditioning objects
BaseSamplerExtractor.extract_conditioning(node_id, inputs, metadata)
# Extract latent dimensions
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
class KSamplerBasicPipeExtractor(BaseSamplerExtractor):
"""Extractor for KSamplerBasicPipe and KSampler_inspire_pipe nodes"""
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
# Extract common sampling parameters
BaseSamplerExtractor.extract_sampling_params(
node_id, inputs, metadata,
["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]
)
# Extract conditioning objects from basic_pipe
if "basic_pipe" in inputs and inputs["basic_pipe"] is not None:
basic_pipe = inputs["basic_pipe"]
# Typically, basic_pipe structure is (model, clip, vae, positive, negative)
if isinstance(basic_pipe, tuple) and len(basic_pipe) >= 5:
pos_conditioning = basic_pipe[3] # positive is at index 3
neg_conditioning = basic_pipe[4] # negative is at index 4
# Save conditioning objects in metadata
if node_id not in metadata[PROMPTS]:
metadata[PROMPTS][node_id] = {"node_id": node_id}
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
# Extract latent dimensions
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
class KSamplerAdvancedBasicPipeExtractor(BaseSamplerExtractor):
"""Extractor for KSamplerAdvancedBasicPipe nodes"""
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
# Extract common sampling parameters
BaseSamplerExtractor.extract_sampling_params(
node_id, inputs, metadata,
["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]
)
# Extract conditioning objects from basic_pipe
if "basic_pipe" in inputs and inputs["basic_pipe"] is not None:
basic_pipe = inputs["basic_pipe"]
# Typically, basic_pipe structure is (model, clip, vae, positive, negative)
if isinstance(basic_pipe, tuple) and len(basic_pipe) >= 5:
pos_conditioning = basic_pipe[3] # positive is at index 3
neg_conditioning = basic_pipe[4] # negative is at index 4
# Save conditioning objects in metadata
if node_id not in metadata[PROMPTS]:
metadata[PROMPTS][node_id] = {"node_id": node_id}
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
# Extract latent dimensions
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
class TSCSamplerBaseExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
# Store vae_decode setting for later use in update
if inputs and "vae_decode" in inputs:
if SAMPLING not in metadata:
metadata[SAMPLING] = {}
if node_id not in metadata[SAMPLING]:
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
# Store the vae_decode setting
metadata[SAMPLING][node_id]["vae_decode"] = inputs["vae_decode"]
@staticmethod
def update(node_id, outputs, metadata):
# Check if vae_decode was set to "true"
should_save_image = True
if SAMPLING in metadata and node_id in metadata[SAMPLING]:
vae_decode = metadata[SAMPLING][node_id].get("vae_decode")
if vae_decode is not None:
should_save_image = (vae_decode == "true")
# Skip image saving if vae_decode isn't "true"
if not should_save_image:
return
# Ensure IMAGES category exists
if IMAGES not in metadata:
metadata[IMAGES] = {}
# Extract output_images from the TSC sampler format
# outputs = [{"ui": {"images": preview_images}, "result": result}]
# where result = (original_model, original_positive, original_negative, latent_list, optional_vae, output_images,)
if outputs and isinstance(outputs, list) and len(outputs) > 0:
# Get the first item in the list
output_item = outputs[0]
if isinstance(output_item, dict) and "result" in output_item:
result = output_item["result"]
if isinstance(result, tuple) and len(result) >= 6:
# The output_images is the last element in the result tuple
output_images = (result[5],)
# Save image data under node ID index to be captured by caching mechanism
metadata[IMAGES][node_id] = {
"node_id": node_id,
"image": output_images
}
# Only set first_decode if it hasn't been recorded yet
if "first_decode" not in metadata[IMAGES]:
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
class TSCKSamplerExtractor(SamplerExtractor, TSCSamplerBaseExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
# Call parent extract methods
SamplerExtractor.extract(node_id, inputs, outputs, metadata)
TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata)
# Update method is inherited from TSCSamplerBaseExtractor
class TSCKSamplerAdvancedExtractor(KSamplerAdvancedExtractor, TSCSamplerBaseExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
# Call parent extract methods
KSamplerAdvancedExtractor.extract(node_id, inputs, outputs, metadata)
TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata)
# Update method is inherited from TSCSamplerBaseExtractor
class LoraLoaderExtractor(NodeMetadataExtractor):
@staticmethod
@@ -269,7 +491,8 @@ class KSamplerSelectExtractor(NodeMetadataExtractor):
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
"node_id": node_id,
IS_SAMPLER: False # Mark as non-primary sampler
}
class BasicSchedulerExtractor(NodeMetadataExtractor):
@@ -285,10 +508,11 @@ class BasicSchedulerExtractor(NodeMetadataExtractor):
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
"node_id": node_id,
IS_SAMPLER: False # Mark as non-primary sampler
}
class SamplerCustomAdvancedExtractor(NodeMetadataExtractor):
class SamplerCustomAdvancedExtractor(BaseSamplerExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
@@ -303,29 +527,12 @@ class SamplerCustomAdvancedExtractor(NodeMetadataExtractor):
metadata[SAMPLING][node_id] = {
"parameters": sampling_params,
"node_id": node_id
"node_id": node_id,
IS_SAMPLER: True # Add sampler flag
}
# Extract latent image dimensions if available
if "latent_image" in inputs and inputs["latent_image"] is not None:
latent = inputs["latent_image"]
if isinstance(latent, dict) and "samples" in latent:
# Extract dimensions from latent tensor
samples = latent["samples"]
if hasattr(samples, "shape") and len(samples.shape) >= 3:
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
# Multiply by 8 to get actual pixel dimensions
height = int(samples.shape[2] * 8)
width = int(samples.shape[3] * 8)
if SIZE not in metadata:
metadata[SIZE] = {}
metadata[SIZE][node_id] = {
"width": width,
"height": height,
"node_id": node_id
}
# Extract latent dimensions
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
import json
@@ -338,11 +545,20 @@ class CLIPTextEncodeFluxExtractor(NodeMetadataExtractor):
clip_l_text = inputs.get("clip_l", "")
t5xxl_text = inputs.get("t5xxl", "")
# Create JSON string with T5 content first, then CLIP-L
combined_text = json.dumps({
"T5": t5xxl_text,
"CLIP-L": clip_l_text
})
# If both are empty, use empty string
if not clip_l_text and not t5xxl_text:
combined_text = ""
# If one is empty, use the non-empty one
elif not clip_l_text:
combined_text = t5xxl_text
elif not t5xxl_text:
combined_text = clip_l_text
# If both have content, use JSON format
else:
combined_text = json.dumps({
"T5": t5xxl_text,
"CLIP-L": clip_l_text
})
metadata[PROMPTS][node_id] = {
"text": combined_text,
@@ -362,27 +578,103 @@ class CLIPTextEncodeFluxExtractor(NodeMetadataExtractor):
metadata[SAMPLING][node_id]["parameters"]["guidance"] = guidance_value
@staticmethod
def update(node_id, outputs, metadata):
if outputs and isinstance(outputs, list) and len(outputs) > 0:
if isinstance(outputs[0], tuple) and len(outputs[0]) > 0:
conditioning = outputs[0][0]
metadata[PROMPTS][node_id]["conditioning"] = conditioning
class CFGGuiderExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs or "cfg" not in inputs:
return
cfg_value = inputs.get("cfg")
# Store the cfg value in SAMPLING category
if SAMPLING not in metadata:
metadata[SAMPLING] = {}
if node_id not in metadata[SAMPLING]:
metadata[SAMPLING][node_id] = {"parameters": {}, "node_id": node_id}
metadata[SAMPLING][node_id]["parameters"]["cfg"] = cfg_value
class CR_ApplyControlNetStackExtractor(NodeMetadataExtractor):
@staticmethod
def extract(node_id, inputs, outputs, metadata):
if not inputs:
return
# Save the original conditioning inputs
base_positive = inputs.get("base_positive")
base_negative = inputs.get("base_negative")
if base_positive is not None or base_negative is not None:
if node_id not in metadata[PROMPTS]:
metadata[PROMPTS][node_id] = {"node_id": node_id}
metadata[PROMPTS][node_id]["orig_pos_cond"] = base_positive
metadata[PROMPTS][node_id]["orig_neg_cond"] = base_negative
@staticmethod
def update(node_id, outputs, metadata):
# Extract transformed conditionings from outputs
# outputs structure: [(base_positive, base_negative, show_help, )]
if outputs and isinstance(outputs, list) and len(outputs) > 0:
first_output = outputs[0]
if isinstance(first_output, tuple) and len(first_output) >= 2:
transformed_positive = first_output[0]
transformed_negative = first_output[1]
# Save transformed conditioning objects in metadata
if node_id not in metadata[PROMPTS]:
metadata[PROMPTS][node_id] = {"node_id": node_id}
metadata[PROMPTS][node_id]["positive_encoded"] = transformed_positive
metadata[PROMPTS][node_id]["negative_encoded"] = transformed_negative
# Registry of node-specific extractors
# Keys are node class names
NODE_EXTRACTORS = {
# Sampling
"KSampler": SamplerExtractor,
"KSamplerAdvanced": KSamplerAdvancedExtractor,
"SamplerCustomAdvanced": SamplerCustomAdvancedExtractor, # Updated to use dedicated extractor
"SamplerCustom": KSamplerAdvancedExtractor,
"SamplerCustomAdvanced": SamplerCustomAdvancedExtractor,
"TSC_KSampler": TSCKSamplerExtractor, # Efficient Nodes
"TSC_KSamplerAdvanced": TSCKSamplerAdvancedExtractor, # Efficient Nodes
"KSamplerBasicPipe": KSamplerBasicPipeExtractor, # comfyui-impact-pack
"KSamplerAdvancedBasicPipe": KSamplerAdvancedBasicPipeExtractor, # comfyui-impact-pack
"KSampler_inspire_pipe": KSamplerBasicPipeExtractor, # comfyui-inspire-pack
"KSamplerAdvanced_inspire_pipe": KSamplerAdvancedBasicPipeExtractor, # comfyui-inspire-pack
# Sampling Selectors
"KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
"AlignYourStepsScheduler": BasicSchedulerExtractor, # Add AlignYourStepsScheduler
# Loaders
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
"comfyLoader": CheckpointLoaderExtractor, # easy comfyLoader
"CheckpointLoaderSimpleWithImages": CheckpointLoaderExtractor, # CheckpointLoader|pysssss
"TSC_EfficientLoader": TSCCheckpointLoaderExtractor, # Efficient Nodes
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
"LoraLoader": LoraLoaderExtractor,
"LoraManagerLoader": LoraLoaderManagerExtractor,
# Conditioning
"CLIPTextEncode": CLIPTextEncodeExtractor,
"CLIPTextEncodeFlux": CLIPTextEncodeFluxExtractor, # Add CLIPTextEncodeFlux
"WAS_Text_to_Conditioning": CLIPTextEncodeExtractor,
"AdvancedCLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/BlenderNeko/ComfyUI_ADV_CLIP_emb
"smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes
"CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack
# Latent
"EmptyLatentImage": ImageSizeExtractor,
# Flux
"FluxGuidance": FluxGuidanceExtractor, # Add FluxGuidance
"CFGGuider": CFGGuiderExtractor, # Add CFGGuider
# Image
"VAEDecode": VAEDecodeExtractor, # Added VAEDecode extractor
# Add other nodes as needed

View File

@@ -1,4 +1,5 @@
import logging
from server import PromptServer # type: ignore
from ..metadata_collector.metadata_processor import MetadataProcessor
logger = logging.getLogger(__name__)
@@ -7,6 +8,7 @@ class DebugMetadata:
NAME = "Debug Metadata (LoraManager)"
CATEGORY = "Lora Manager/utils"
DESCRIPTION = "Debug node to verify metadata_processor functionality"
OUTPUT_NODE = True
@classmethod
def INPUT_TYPES(cls):
@@ -14,22 +16,30 @@ class DebugMetadata:
"required": {
"images": ("IMAGE",),
},
"hidden": {
"id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("metadata_json",)
RETURN_TYPES = ()
FUNCTION = "process_metadata"
def process_metadata(self, images):
def process_metadata(self, images, id):
try:
# Get the current execution context's metadata
from ..metadata_collector import get_metadata
metadata = get_metadata()
# Use the MetadataProcessor to convert it to JSON string
metadata_json = MetadataProcessor.to_json(metadata)
metadata_json = MetadataProcessor.to_json(metadata, id)
# Send metadata to frontend for display
PromptServer.instance.send_sync("metadata_update", {
"id": id,
"metadata": metadata_json
})
return (metadata_json,)
except Exception as e:
logger.error(f"Error processing metadata: {e}")
return ("{}",) # Return empty JSON object in case of error
return ()

View File

@@ -1,15 +1,16 @@
import logging
import re
from nodes import LoraLoader
from comfy.comfy_types import IO # type: ignore
import asyncio
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
from ..utils.utils import get_lora_info
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
logger = logging.getLogger(__name__)
class LoraManagerLoader:
NAME = "Lora Loader (LoraManager)"
CATEGORY = "Lora Manager/loaders"
@classmethod
def INPUT_TYPES(cls):
return {
@@ -17,7 +18,8 @@ class LoraManagerLoader:
"model": ("MODEL",),
# "clip": ("CLIP",),
"text": (IO.STRING, {
"multiline": True,
"multiline": True,
"pysssss.autocomplete": False,
"dynamicPrompts": True,
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
"placeholder": "LoRA syntax input: <lora:name:strength>"
@@ -37,19 +39,39 @@ class LoraManagerLoader:
clip = kwargs.get('clip', None)
lora_stack = kwargs.get('lora_stack', None)
# Check if model is a Nunchaku Flux model - simplified approach
is_nunchaku_model = False
try:
model_wrapper = model.model.diffusion_model
# Check if model is a Nunchaku Flux model using only class name
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
is_nunchaku_model = True
logger.info("Detected Nunchaku Flux model")
except (AttributeError, TypeError):
# Not a model with the expected structure
pass
# First process lora_stack if available
if lora_stack:
for lora_path, model_strength, clip_strength in lora_stack:
# Apply the LoRA using the provided path and strengths
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Apply the LoRA using the appropriate loader
if is_nunchaku_model:
# Use our custom function for Flux models
model = nunchaku_load_lora(model, lora_path, model_strength)
# clip remains unchanged for Nunchaku models
else:
# Use default loader for standard models
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Extract lora name for trigger words lookup
lora_name = extract_lora_name(lora_path)
_, trigger_words = asyncio.run(get_lora_info(lora_name))
_, trigger_words = get_lora_info(lora_name)
all_trigger_words.extend(trigger_words)
# Add clip strength to output if different from model strength
if abs(model_strength - clip_strength) > 0.001:
# Add clip strength to output if different from model strength (except for Nunchaku models)
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
else:
loaded_loras.append(f"{lora_name}: {model_strength}")
@@ -66,13 +88,19 @@ class LoraManagerLoader:
clip_strength = float(lora.get('clipStrength', model_strength))
# Get lora path and trigger words
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
lora_path, trigger_words = get_lora_info(lora_name)
# Apply the LoRA using the resolved path with separate strengths
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Apply the LoRA using the appropriate loader
if is_nunchaku_model:
# For Nunchaku models, use our custom function
model = nunchaku_load_lora(model, lora_path, model_strength)
# clip remains unchanged
else:
# Use default loader for standard models
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Include clip strength in output if different from model strength
if abs(model_strength - clip_strength) > 0.001:
# Include clip strength in output if different from model strength and not a Nunchaku model
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
else:
loaded_loras.append(f"{lora_name}: {model_strength}")
@@ -102,4 +130,142 @@ class LoraManagerLoader:
formatted_loras_text = " ".join(formatted_loras)
return (model, clip, trigger_words_text, formatted_loras_text)
class LoraManagerTextLoader:
NAME = "LoRA Text Loader (LoraManager)"
CATEGORY = "Lora Manager/loaders"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"lora_syntax": (IO.STRING, {
"defaultInput": True,
"forceInput": True,
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation"
}),
},
"optional": {
"clip": ("CLIP",),
"lora_stack": ("LORA_STACK",),
}
}
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING)
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
FUNCTION = "load_loras_from_text"
def parse_lora_syntax(self, text):
"""Parse LoRA syntax from text input."""
# Pattern to match <lora:name:strength> or <lora:name:model_strength:clip_strength>
pattern = r'<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>'
matches = re.findall(pattern, text, re.IGNORECASE)
loras = []
for match in matches:
lora_name = match[0].strip()
model_strength = float(match[1])
clip_strength = float(match[2]) if match[2] else model_strength
loras.append({
'name': lora_name,
'model_strength': model_strength,
'clip_strength': clip_strength
})
return loras
def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None):
"""Load LoRAs based on text syntax input."""
loaded_loras = []
all_trigger_words = []
# Check if model is a Nunchaku Flux model - simplified approach
is_nunchaku_model = False
try:
model_wrapper = model.model.diffusion_model
# Check if model is a Nunchaku Flux model using only class name
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
is_nunchaku_model = True
logger.info("Detected Nunchaku Flux model")
except (AttributeError, TypeError):
# Not a model with the expected structure
pass
# First process lora_stack if available
if lora_stack:
for lora_path, model_strength, clip_strength in lora_stack:
# Apply the LoRA using the appropriate loader
if is_nunchaku_model:
# Use our custom function for Flux models
model = nunchaku_load_lora(model, lora_path, model_strength)
# clip remains unchanged for Nunchaku models
else:
# Use default loader for standard models
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Extract lora name for trigger words lookup
lora_name = extract_lora_name(lora_path)
_, trigger_words = get_lora_info(lora_name)
all_trigger_words.extend(trigger_words)
# Add clip strength to output if different from model strength (except for Nunchaku models)
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
else:
loaded_loras.append(f"{lora_name}: {model_strength}")
# Parse and process LoRAs from text syntax
parsed_loras = self.parse_lora_syntax(lora_syntax)
for lora in parsed_loras:
lora_name = lora['name']
model_strength = lora['model_strength']
clip_strength = lora['clip_strength']
# Get lora path and trigger words
lora_path, trigger_words = get_lora_info(lora_name)
# Apply the LoRA using the appropriate loader
if is_nunchaku_model:
# For Nunchaku models, use our custom function
model = nunchaku_load_lora(model, lora_path, model_strength)
# clip remains unchanged
else:
# Use default loader for standard models
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Include clip strength in output if different from model strength and not a Nunchaku model
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
else:
loaded_loras.append(f"{lora_name}: {model_strength}")
# Add trigger words to collection
all_trigger_words.extend(trigger_words)
# use ',, ' to separate trigger words for group mode
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Format loaded_loras with support for both formats
formatted_loras = []
for item in loaded_loras:
parts = item.split(":")
lora_name = parts[0].strip()
strength_parts = parts[1].strip().split(",")
if len(strength_parts) > 1:
# Different model and clip strengths
model_str = strength_parts[0].strip()
clip_str = strength_parts[1].strip()
formatted_loras.append(f"<lora:{lora_name}:{model_str}:{clip_str}>")
else:
# Same strength for both
model_str = strength_parts[0].strip()
formatted_loras.append(f"<lora:{lora_name}:{model_str}>")
formatted_loras_text = " ".join(formatted_loras)
return (model, clip, trigger_words_text, formatted_loras_text)

View File

@@ -1,9 +1,8 @@
from comfy.comfy_types import IO # type: ignore
from ..services.lora_scanner import LoraScanner
from ..config import config
import asyncio
import os
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
from ..utils.utils import get_lora_info
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list
import logging
logger = logging.getLogger(__name__)
@@ -18,6 +17,7 @@ class LoraStacker:
"required": {
"text": (IO.STRING, {
"multiline": True,
"pysssss.autocomplete": False,
"dynamicPrompts": True,
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
"placeholder": "LoRA syntax input: <lora:name:strength>"
@@ -43,7 +43,7 @@ class LoraStacker:
# Get trigger words from existing stack entries
for lora_path, _, _ in lora_stack:
lora_name = extract_lora_name(lora_path)
_, trigger_words = asyncio.run(get_lora_info(lora_name))
_, trigger_words = get_lora_info(lora_name)
all_trigger_words.extend(trigger_words)
# Process loras from kwargs with support for both old and new formats
@@ -58,7 +58,7 @@ class LoraStacker:
clip_strength = float(lora.get('clipStrength', model_strength))
# Get lora path and trigger words
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
lora_path, trigger_words = get_lora_info(lora_name)
# Add to stack without loading
# replace '/' with os.sep to avoid different OS path format

View File

@@ -1,11 +1,9 @@
import json
import os
import asyncio
import re
import numpy as np
import folder_paths # type: ignore
from ..services.lora_scanner import LoraScanner
from ..services.checkpoint_scanner import CheckpointScanner
from ..services.service_registry import ServiceRegistry
from ..metadata_collector.metadata_processor import MetadataProcessor
from ..metadata_collector import get_metadata
from PIL import Image, PngImagePlugin
@@ -31,16 +29,36 @@ class SaveImage:
return {
"required": {
"images": ("IMAGE",),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"file_format": (["png", "jpeg", "webp"],),
"filename_prefix": ("STRING", {
"default": "ComfyUI",
"tooltip": "Base filename for saved images. Supports format patterns like %seed%, %width%, %height%, %model%, etc."
}),
"file_format": (["png", "jpeg", "webp"], {
"tooltip": "Image format to save as. PNG preserves quality, JPEG is smaller, WebP balances size and quality."
}),
},
"optional": {
"lossless_webp": ("BOOLEAN", {"default": False}),
"quality": ("INT", {"default": 100, "min": 1, "max": 100}),
"embed_workflow": ("BOOLEAN", {"default": False}),
"add_counter_to_filename": ("BOOLEAN", {"default": True}),
"lossless_webp": ("BOOLEAN", {
"default": False,
"tooltip": "When enabled, saves WebP images with lossless compression. Results in larger files but no quality loss."
}),
"quality": ("INT", {
"default": 100,
"min": 1,
"max": 100,
"tooltip": "Compression quality for JPEG and lossy WebP formats (1-100). Higher values mean better quality but larger files."
}),
"embed_workflow": ("BOOLEAN", {
"default": False,
"tooltip": "Embeds the complete workflow data into the image metadata. Only works with PNG and WebP formats."
}),
"add_counter_to_filename": ("BOOLEAN", {
"default": True,
"tooltip": "Adds an incremental counter to filenames to prevent overwriting previous images."
}),
},
"hidden": {
"id": "UNIQUE_ID",
"prompt": "PROMPT",
"extra_pnginfo": "EXTRA_PNGINFO",
},
@@ -51,25 +69,20 @@ class SaveImage:
FUNCTION = "process_image"
OUTPUT_NODE = True
async def get_lora_hash(self, lora_name):
def get_lora_hash(self, lora_name):
"""Get the lora hash from cache"""
scanner = await LoraScanner.get_instance()
scanner = ServiceRegistry.get_service_sync("lora_scanner")
# Use the new direct filename lookup method
hash_value = scanner.get_hash_by_filename(lora_name)
if hash_value:
return hash_value
# Fallback to old method for compatibility
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
return item.get('sha256')
return None
async def get_checkpoint_hash(self, checkpoint_path):
def get_checkpoint_hash(self, checkpoint_path):
"""Get the checkpoint hash from cache"""
scanner = await CheckpointScanner.get_instance()
scanner = ServiceRegistry.get_service_sync("checkpoint_scanner")
if not checkpoint_path:
return None
@@ -82,18 +95,10 @@ class SaveImage:
hash_value = scanner.get_hash_by_filename(checkpoint_name)
if hash_value:
return hash_value
# Fallback to old method for compatibility
cache = await scanner.get_cached_data()
normalized_path = checkpoint_path.replace('\\', '/')
for item in cache.raw_data:
if item.get('file_name') == checkpoint_name and item.get('file_path').endswith(normalized_path):
return item.get('sha256')
return None
async def format_metadata(self, metadata_dict):
def format_metadata(self, metadata_dict):
"""Format metadata in the requested format similar to userComment example"""
if not metadata_dict:
return ""
@@ -120,7 +125,7 @@ class SaveImage:
# Get hash for each lora
for lora_name, strength in lora_matches:
hash_value = await self.get_lora_hash(lora_name)
hash_value = self.get_lora_hash(lora_name)
if hash_value:
lora_hashes[lora_name] = hash_value
else:
@@ -206,7 +211,7 @@ class SaveImage:
checkpoint = metadata_dict.get('checkpoint')
if checkpoint is not None:
# Get model hash
model_hash = await self.get_checkpoint_hash(checkpoint)
model_hash = self.get_checkpoint_hash(checkpoint)
# Extract basename without path
checkpoint_name = os.path.basename(checkpoint)
@@ -223,7 +228,7 @@ class SaveImage:
if lora_hashes:
lora_hash_parts = []
for lora_name, hash_value in lora_hashes.items():
lora_hash_parts.append(f"{lora_name}: {hash_value}")
lora_hash_parts.append(f"{lora_name}: {hash_value[:10]}")
if lora_hash_parts:
params.append(f"Lora hashes: \"{', '.join(lora_hash_parts)}\"")
@@ -300,17 +305,16 @@ class SaveImage:
return filename
def save_images(self, images, filename_prefix, file_format, prompt=None, extra_pnginfo=None,
def save_images(self, images, filename_prefix, file_format, id, prompt=None, extra_pnginfo=None,
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
"""Save images with metadata"""
results = []
# Get metadata using the metadata collector
raw_metadata = get_metadata()
metadata_dict = MetadataProcessor.to_dict(raw_metadata)
metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)
# Get or create metadata asynchronously
metadata = asyncio.run(self.format_metadata(metadata_dict))
metadata = self.format_metadata(metadata_dict)
# Process filename_prefix with pattern substitution
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
@@ -378,14 +382,23 @@ class SaveImage:
print(f"Error adding EXIF data: {e}")
img.save(file_path, format="JPEG", **save_kwargs)
elif file_format == "webp":
# For WebP, also use piexif for metadata
if metadata:
try:
exif_dict = {'Exif': {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}}
exif_bytes = piexif.dump(exif_dict)
save_kwargs["exif"] = exif_bytes
except Exception as e:
print(f"Error adding EXIF data: {e}")
try:
# For WebP, use piexif for metadata
exif_dict = {}
if metadata:
exif_dict['Exif'] = {piexif.ExifIFD.UserComment: b'UNICODE\0' + metadata.encode('utf-16be')}
# Add workflow if needed
if embed_workflow and extra_pnginfo is not None:
workflow_json = json.dumps(extra_pnginfo["workflow"])
exif_dict['0th'] = {piexif.ImageIFD.ImageDescription: "Workflow:" + workflow_json}
exif_bytes = piexif.dump(exif_dict)
save_kwargs["exif"] = exif_bytes
except Exception as e:
print(f"Error adding EXIF data: {e}")
img.save(file_path, format="WEBP", **save_kwargs)
results.append({
@@ -399,23 +412,28 @@ class SaveImage:
return results
def process_image(self, images, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None,
def process_image(self, images, id, filename_prefix="ComfyUI", file_format="png", prompt=None, extra_pnginfo=None,
lossless_webp=True, quality=100, embed_workflow=False, add_counter_to_filename=True):
"""Process and save image with metadata"""
# Make sure the output directory exists
os.makedirs(self.output_dir, exist_ok=True)
# Ensure images is always a list of images
if len(images.shape) == 3: # Single image (height, width, channels)
images = [images]
else: # Multiple images (batch, height, width, channels)
images = [img for img in images]
# If images is already a list or array of images, do nothing; otherwise, convert to list
if isinstance(images, (list, np.ndarray)):
pass
else:
# Ensure images is always a list of images
if len(images.shape) == 3: # Single image (height, width, channels)
images = [images]
else: # Multiple images (batch, height, width, channels)
images = [img for img in images]
# Save all images
results = self.save_images(
images,
filename_prefix,
file_format,
id,
prompt,
extra_pnginfo,
lossless_webp,

View File

@@ -16,11 +16,18 @@ class TriggerWordToggle:
def INPUT_TYPES(cls):
return {
"required": {
"group_mode": ("BOOLEAN", {"default": True}),
"group_mode": ("BOOLEAN", {
"default": True,
"tooltip": "When enabled, treats each group of trigger words as a single toggleable unit."
}),
"default_active": ("BOOLEAN", {
"default": True,
"tooltip": "Sets the default initial state (active or inactive) when trigger words are added."
}),
},
"optional": FlexibleOptionalInputType(any_type),
"hidden": {
"id": "UNIQUE_ID", # 会被 ComfyUI 自动替换为唯一ID
"id": "UNIQUE_ID",
},
}
@@ -41,17 +48,11 @@ class TriggerWordToggle:
else:
return data
def process_trigger_words(self, id, group_mode, **kwargs):
def process_trigger_words(self, id, group_mode, default_active, **kwargs):
# Handle both old and new formats for trigger_words
trigger_words_data = self._get_toggle_data(kwargs, 'trigger_words')
trigger_words_data = self._get_toggle_data(kwargs, 'orinalMessage')
trigger_words = trigger_words_data if isinstance(trigger_words_data, str) else ""
# Send trigger words to frontend
# PromptServer.instance.send_sync("trigger_word_update", {
# "id": id,
# "message": trigger_words
# })
filtered_triggers = trigger_words
# Get toggle data with support for both formats

View File

@@ -35,31 +35,11 @@ any_type = AnyType("*")
# Common methods extracted from lora_loader.py and lora_stacker.py
import os
import logging
import asyncio
from ..services.lora_scanner import LoraScanner
from ..config import config
import copy
import folder_paths
logger = logging.getLogger(__name__)
async def get_lora_info(lora_name):
"""Get the lora path and trigger words from cache"""
scanner = await LoraScanner.get_instance()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, [] # Fallback if not found
def extract_lora_name(lora_path):
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
# Get the basename without extension
@@ -81,4 +61,70 @@ def get_loras_list(kwargs):
# Unexpected format
else:
logger.warning(f"Unexpected loras format: {type(loras_data)}")
return []
return []
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
import safetensors.torch
state_dict = {}
with safetensors.torch.safe_open(path, framework="pt", device=device) as f:
for k in f.keys():
if filter_prefix and not k.startswith(filter_prefix):
continue
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
return state_dict
def to_diffusers(input_lora):
"""Simplified version of to_diffusers for Flux LoRA conversion"""
import torch
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
from diffusers.loaders import FluxLoraLoaderMixin
if isinstance(input_lora, str):
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
else:
tensors = {k: v for k, v in input_lora.items()}
# Convert FP8 tensors to BF16
for k, v in tensors.items():
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
tensors[k] = v.to(torch.bfloat16)
new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors)
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
return new_tensors
def nunchaku_load_lora(model, lora_name, lora_strength):
"""Load a Flux LoRA for Nunchaku model"""
model_wrapper = model.model.diffusion_model
transformer = model_wrapper.model
# Save the transformer temporarily
model_wrapper.model = None
ret_model = copy.deepcopy(model) # copy everything except the model
ret_model_wrapper = ret_model.model.diffusion_model
# Restore the model and set it for the copy
model_wrapper.model = transformer
ret_model_wrapper.model = transformer
# Get full path to the LoRA file
lora_path = folder_paths.get_full_path("loras", lora_name)
ret_model_wrapper.loras.append((lora_path, lora_strength))
# Convert the LoRA to diffusers format
sd = to_diffusers(lora_path)
# Handle embedding adjustment if needed
if "transformer.x_embedder.lora_A.weight" in sd:
new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
assert new_in_channels % 4 == 0
new_in_channels = new_in_channels // 4
old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
if old_in_channels < new_in_channels:
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
return ret_model

View File

@@ -0,0 +1,98 @@
from comfy.comfy_types import IO # type: ignore
import folder_paths # type: ignore
from ..utils.utils import get_lora_info
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
import logging
logger = logging.getLogger(__name__)
class WanVideoLoraSelect:
NAME = "WanVideo Lora Select (LoraManager)"
CATEGORY = "Lora Manager/stackers"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load LORA models with less VRAM usage, slower loading. This affects ALL LoRAs, not just the current ones. No effect if merge_loras is False"}),
"merge_loras": ("BOOLEAN", {"default": True, "tooltip": "Merge LoRAs into the model, otherwise they are loaded on the fly. Always disabled for GGUF and scaled fp8 models. This affects ALL LoRAs, not just the current one"}),
"text": (IO.STRING, {
"multiline": True,
"pysssss.autocomplete": False,
"dynamicPrompts": True,
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
"placeholder": "LoRA syntax input: <lora:name:strength>"
}),
},
"optional": FlexibleOptionalInputType(any_type),
}
RETURN_TYPES = ("WANVIDLORA", IO.STRING, IO.STRING)
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
FUNCTION = "process_loras"
def process_loras(self, text, low_mem_load=False, merge_loras=True, **kwargs):
loras_list = []
all_trigger_words = []
active_loras = []
# Process existing prev_lora if available
prev_lora = kwargs.get('prev_lora', None)
if prev_lora is not None:
loras_list.extend(prev_lora)
if not merge_loras:
low_mem_load = False # Unmerged LoRAs don't need low_mem_load
# Get blocks if available
blocks = kwargs.get('blocks', {})
selected_blocks = blocks.get("selected_blocks", {})
layer_filter = blocks.get("layer_filter", "")
# Process loras from kwargs with support for both old and new formats
loras_from_widget = get_loras_list(kwargs)
for lora in loras_from_widget:
if not lora.get('active', False):
continue
lora_name = lora['name']
model_strength = float(lora['strength'])
clip_strength = float(lora.get('clipStrength', model_strength))
# Get lora path and trigger words
lora_path, trigger_words = get_lora_info(lora_name)
# Create lora item for WanVideo format
lora_item = {
"path": folder_paths.get_full_path("loras", lora_path),
"strength": model_strength,
"name": lora_path.split(".")[0],
"blocks": selected_blocks,
"layer_filter": layer_filter,
"low_mem_load": low_mem_load,
"merge_loras": merge_loras,
}
# Add to list and collect active loras
loras_list.append(lora_item)
active_loras.append((lora_name, model_strength, clip_strength))
# Add trigger words to collection
all_trigger_words.extend(trigger_words)
# Format trigger_words for output
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Format active_loras for output
formatted_loras = []
for name, model_strength, clip_strength in active_loras:
if abs(model_strength - clip_strength) > 0.001:
# Different model and clip strengths
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
else:
# Same strength for both
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
active_loras_text = " ".join(formatted_loras)
return (loras_list, trigger_words_text, active_loras_text)

View File

@@ -7,7 +7,8 @@ from .parsers import (
RecipeFormatParser,
ComfyMetadataParser,
MetaFormatParser,
AutomaticMetadataParser
AutomaticMetadataParser,
CivitaiApiMetadataParser
)
__all__ = [
@@ -18,5 +19,6 @@ __all__ = [
'RecipeFormatParser',
'ComfyMetadataParser',
'MetaFormatParser',
'AutomaticMetadataParser'
'AutomaticMetadataParser',
'CivitaiApiMetadataParser'
]

View File

@@ -7,7 +7,7 @@ import re
from typing import Dict, List, Any, Optional, Tuple
from abc import ABC, abstractmethod
from ..config import config
from .constants import VALID_LORA_TYPES
from ..utils.constants import VALID_LORA_TYPES
logger = logging.getLogger(__name__)
@@ -79,6 +79,9 @@ class RecipeMetadataParser(ABC):
if 'model' in civitai_info and 'name' in civitai_info['model']:
lora_entry['name'] = civitai_info['model']['name']
lora_entry['id'] = civitai_info.get('id')
lora_entry['modelId'] = civitai_info.get('modelId')
# Update version if available
if 'name' in civitai_info:
lora_entry['version'] = civitai_info.get('name', '')
@@ -116,10 +119,10 @@ class RecipeMetadataParser(ABC):
# Check if exists locally
if recipe_scanner and lora_entry['hash']:
lora_scanner = recipe_scanner._lora_scanner
exists_locally = lora_scanner.has_lora_hash(lora_entry['hash'])
exists_locally = lora_scanner.has_hash(lora_entry['hash'])
if exists_locally:
try:
local_path = lora_scanner.get_lora_path_by_hash(lora_entry['hash'])
local_path = lora_scanner.get_path_by_hash(lora_entry['hash'])
lora_entry['existsLocally'] = True
lora_entry['localPath'] = local_path
lora_entry['file_name'] = os.path.splitext(os.path.basename(local_path))[0]

View File

@@ -1,5 +1,8 @@
"""Constants used across recipe parsers."""
# Import VALID_LORA_TYPES from utils.constants
from ..utils.constants import VALID_LORA_TYPES
# Constants for generation parameters
GEN_PARAM_KEYS = [
'prompt',
@@ -11,6 +14,3 @@ GEN_PARAM_KEYS = [
'size',
'clip_skip',
]
# Valid Lora types
VALID_LORA_TYPES = ['lora', 'locon']

View File

@@ -5,7 +5,8 @@ from .parsers import (
RecipeFormatParser,
ComfyMetadataParser,
MetaFormatParser,
AutomaticMetadataParser
AutomaticMetadataParser,
CivitaiApiMetadataParser
)
from .base import RecipeMetadataParser
@@ -15,29 +16,49 @@ class RecipeParserFactory:
"""Factory for creating recipe metadata parsers"""
@staticmethod
def create_parser(user_comment: str) -> RecipeMetadataParser:
def create_parser(metadata) -> RecipeMetadataParser:
"""
Create appropriate parser based on the user comment content
Create appropriate parser based on the metadata content
Args:
user_comment: The EXIF UserComment string from the image
metadata: The metadata from the image (dict or str)
Returns:
Appropriate RecipeMetadataParser implementation
"""
# Try ComfyMetadataParser first since it requires valid JSON
# First, try CivitaiApiMetadataParser for dict input
if isinstance(metadata, dict):
try:
if CivitaiApiMetadataParser().is_metadata_matching(metadata):
return CivitaiApiMetadataParser()
except Exception as e:
logger.debug(f"CivitaiApiMetadataParser check failed: {e}")
pass
# Convert dict to string for other parsers that expect string input
try:
import json
metadata_str = json.dumps(metadata)
except Exception as e:
logger.debug(f"Failed to convert dict to JSON string: {e}")
return None
else:
metadata_str = metadata
# Try ComfyMetadataParser which requires valid JSON
try:
if ComfyMetadataParser().is_metadata_matching(user_comment):
if ComfyMetadataParser().is_metadata_matching(metadata_str):
return ComfyMetadataParser()
except Exception:
# If JSON parsing fails, move on to other parsers
pass
if RecipeFormatParser().is_metadata_matching(user_comment):
# Check other parsers that expect string input
if RecipeFormatParser().is_metadata_matching(metadata_str):
return RecipeFormatParser()
elif AutomaticMetadataParser().is_metadata_matching(user_comment):
elif AutomaticMetadataParser().is_metadata_matching(metadata_str):
return AutomaticMetadataParser()
elif MetaFormatParser().is_metadata_matching(user_comment):
elif MetaFormatParser().is_metadata_matching(metadata_str):
return MetaFormatParser()
else:
return None

View File

@@ -4,10 +4,12 @@ from .recipe_format import RecipeFormatParser
from .comfy import ComfyMetadataParser
from .meta_format import MetaFormatParser
from .automatic import AutomaticMetadataParser
from .civitai_image import CivitaiApiMetadataParser
__all__ = [
'RecipeFormatParser',
'ComfyMetadataParser',
'MetaFormatParser',
'AutomaticMetadataParser',
'CivitaiApiMetadataParser',
]

View File

@@ -19,7 +19,7 @@ class AutomaticMetadataParser(RecipeMetadataParser):
LORA_HASHES_REGEX = r', Lora hashes:\s*"([^"]+)"'
CIVITAI_RESOURCES_REGEX = r', Civitai resources:\s*(\[\{.*?\}\])'
CIVITAI_METADATA_REGEX = r', Civitai metadata:\s*(\{.*?\})'
EXTRANETS_REGEX = r'<(lora|hypernet):([a-zA-Z0-9_\.\-]+):([0-9.]+)>'
EXTRANETS_REGEX = r'<(lora|hypernet):([^:]+):(-?[0-9.]+)>'
MODEL_HASH_PATTERN = r'Model hash: ([a-zA-Z0-9]+)'
VAE_HASH_PATTERN = r'VAE hash: ([a-zA-Z0-9]+)'
@@ -181,13 +181,30 @@ class AutomaticMetadataParser(RecipeMetadataParser):
# First use Civitai resources if available (more reliable source)
if metadata.get("civitai_resources"):
for resource in metadata.get("civitai_resources", []):
# --- Added: Parse 'air' field if present ---
air = resource.get("air")
if air:
# Format: urn:air:sdxl:lora:civitai:1221007@1375651
# Or: urn:air:sdxl:checkpoint:civitai:623891@2019115
air_pattern = r"urn:air:[^:]+:(?P<type>[^:]+):civitai:(?P<modelId>\d+)@(?P<modelVersionId>\d+)"
air_match = re.match(air_pattern, air)
if air_match:
air_type = air_match.group("type")
air_modelId = int(air_match.group("modelId"))
air_modelVersionId = int(air_match.group("modelVersionId"))
# checkpoint/lycoris/lora/hypernet
resource["type"] = air_type
resource["modelId"] = air_modelId
resource["modelVersionId"] = air_modelVersionId
# --- End added ---
if resource.get("type") in ["lora", "lycoris", "hypernet"] and resource.get("modelVersionId"):
# Initialize lora entry
lora_entry = {
'id': str(resource.get("modelVersionId")),
'modelId': str(resource.get("modelId")) if resource.get("modelId") else None,
'id': resource.get("modelVersionId", 0),
'modelId': resource.get("modelId", 0),
'name': resource.get("modelName", "Unknown LoRA"),
'version': resource.get("modelVersionName", ""),
'version': resource.get("modelVersionName", resource.get("versionName", "")),
'type': resource.get("type", "lora"),
'weight': round(float(resource.get("weight", 1.0)), 2),
'existsLocally': False,

View File

@@ -0,0 +1,347 @@
"""Parser for Civitai image metadata format."""
import json
import logging
from typing import Dict, Any, Union
from ..base import RecipeMetadataParser
from ..constants import GEN_PARAM_KEYS
logger = logging.getLogger(__name__)
class CivitaiApiMetadataParser(RecipeMetadataParser):
"""Parser for Civitai image metadata format"""
def is_metadata_matching(self, metadata) -> bool:
"""Check if the metadata matches the Civitai image metadata format
Args:
metadata: The metadata from the image (dict)
Returns:
bool: True if this parser can handle the metadata
"""
if not metadata or not isinstance(metadata, dict):
return False
# Check for key markers specific to Civitai image metadata
return any([
"resources" in metadata,
"civitaiResources" in metadata,
"additionalResources" in metadata
])
async def parse_metadata(self, metadata, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
"""Parse metadata from Civitai image format
Args:
metadata: The metadata from the image (dict)
recipe_scanner: Optional recipe scanner service
civitai_client: Optional Civitai API client
Returns:
Dict containing parsed recipe data
"""
try:
# Initialize result structure
result = {
'base_model': None,
'loras': [],
'gen_params': {},
'from_civitai_image': True
}
# Track already added LoRAs to prevent duplicates
added_loras = {} # key: model_version_id or hash, value: index in result["loras"]
# Extract prompt and negative prompt
if "prompt" in metadata:
result["gen_params"]["prompt"] = metadata["prompt"]
if "negativePrompt" in metadata:
result["gen_params"]["negative_prompt"] = metadata["negativePrompt"]
# Extract other generation parameters
param_mapping = {
"steps": "steps",
"sampler": "sampler",
"cfgScale": "cfg_scale",
"seed": "seed",
"Size": "size",
"clipSkip": "clip_skip",
}
for civitai_key, our_key in param_mapping.items():
if civitai_key in metadata and our_key in GEN_PARAM_KEYS:
result["gen_params"][our_key] = metadata[civitai_key]
# Extract base model information - directly if available
if "baseModel" in metadata:
result["base_model"] = metadata["baseModel"]
elif "Model hash" in metadata and civitai_client:
model_hash = metadata["Model hash"]
model_info = await civitai_client.get_model_by_hash(model_hash)
if model_info:
result["base_model"] = model_info.get("baseModel", "")
elif "Model" in metadata and isinstance(metadata.get("resources"), list):
# Try to find base model in resources
for resource in metadata.get("resources", []):
if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"):
# This is likely the checkpoint model
if civitai_client and resource.get("hash"):
model_info = await civitai_client.get_model_by_hash(resource.get("hash"))
if model_info:
result["base_model"] = model_info.get("baseModel", "")
base_model_counts = {}
# Process standard resources array
if "resources" in metadata and isinstance(metadata["resources"], list):
for resource in metadata["resources"]:
# Modified to process resources without a type field as potential LoRAs
if resource.get("type", "lora") == "lora":
lora_hash = resource.get("hash", "")
# Skip LoRAs without proper identification (hash or modelVersionId)
if not lora_hash and not resource.get("modelVersionId"):
logger.debug(f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId")
continue
# Skip if we've already added this LoRA by hash
if lora_hash and lora_hash in added_loras:
continue
lora_entry = {
'name': resource.get("name", "Unknown LoRA"),
'type': "lora",
'weight': float(resource.get("weight", 1.0)),
'hash': lora_hash,
'existsLocally': False,
'localPath': None,
'file_name': resource.get("name", "Unknown"),
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
}
# Try to get info from Civitai if hash is available
if lora_entry['hash'] and civitai_client:
try:
civitai_info = await civitai_client.get_model_by_hash(lora_hash)
populated_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts,
lora_hash
)
if populated_entry is None:
continue # Skip invalid LoRA types
lora_entry = populated_entry
# If we have a version ID from Civitai, track it for deduplication
if 'id' in lora_entry and lora_entry['id']:
added_loras[str(lora_entry['id'])] = len(result["loras"])
except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
# Track by hash if we have it
if lora_hash:
added_loras[lora_hash] = len(result["loras"])
result["loras"].append(lora_entry)
# Process civitaiResources array
if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list):
for resource in metadata["civitaiResources"]:
# Get unique identifier for deduplication
version_id = str(resource.get("modelVersionId", ""))
# Skip if we've already added this LoRA
if version_id and version_id in added_loras:
continue
# Initialize lora entry
lora_entry = {
'id': resource.get("modelVersionId", 0),
'modelId': resource.get("modelId", 0),
'name': resource.get("modelName", "Unknown LoRA"),
'version': resource.get("modelVersionName", ""),
'type': resource.get("type", "lora"),
'weight': round(float(resource.get("weight", 1.0)), 2),
'existsLocally': False,
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
}
# Try to get info from Civitai if modelVersionId is available
if version_id and civitai_client:
try:
# Use get_model_version_info instead of get_model_version
civitai_info, error = await civitai_client.get_model_version_info(version_id)
if error:
logger.warning(f"Error getting model version info: {error}")
continue
populated_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts
)
if populated_entry is None:
continue # Skip invalid LoRA types
lora_entry = populated_entry
except Exception as e:
logger.error(f"Error fetching Civitai info for model version {version_id}: {e}")
# Track this LoRA in our deduplication dict
if version_id:
added_loras[version_id] = len(result["loras"])
result["loras"].append(lora_entry)
# Process additionalResources array
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list):
for resource in metadata["additionalResources"]:
# Skip resources that aren't LoRAs or LyCORIS
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource:
continue
lora_type = resource.get("type", "lora")
name = resource.get("name", "")
# Extract ID from URN format if available
version_id = None
if name and "civitai:" in name:
parts = name.split("@")
if len(parts) > 1:
version_id = parts[1]
# Skip if we've already added this LoRA
if version_id in added_loras:
continue
lora_entry = {
'name': name,
'type': lora_type,
'weight': float(resource.get("strength", 1.0)),
'hash': "",
'existsLocally': False,
'localPath': None,
'file_name': name,
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
}
# If we have a version ID and civitai client, try to get more info
if version_id and civitai_client:
try:
# Use get_model_version_info with the version ID
civitai_info, error = await civitai_client.get_model_version_info(version_id)
if error:
logger.warning(f"Error getting model version info: {error}")
else:
populated_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts
)
if populated_entry is None:
continue # Skip invalid LoRA types
lora_entry = populated_entry
# Track this LoRA for deduplication
if version_id:
added_loras[version_id] = len(result["loras"])
except Exception as e:
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")
result["loras"].append(lora_entry)
# Check for LoRA info in the format "Lora_0 Model hash", "Lora_0 Model name", etc.
lora_index = 0
while f"Lora_{lora_index} Model hash" in metadata and f"Lora_{lora_index} Model name" in metadata:
lora_hash = metadata[f"Lora_{lora_index} Model hash"]
lora_name = metadata[f"Lora_{lora_index} Model name"]
lora_strength_model = float(metadata.get(f"Lora_{lora_index} Strength model", 1.0))
# Skip if we've already added this LoRA by hash
if lora_hash and lora_hash in added_loras:
lora_index += 1
continue
lora_entry = {
'name': lora_name,
'type': "lora",
'weight': lora_strength_model,
'hash': lora_hash,
'existsLocally': False,
'localPath': None,
'file_name': lora_name,
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
}
# Try to get info from Civitai if hash is available
if lora_entry['hash'] and civitai_client:
try:
civitai_info = await civitai_client.get_model_by_hash(lora_hash)
populated_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts,
lora_hash
)
if populated_entry is None:
lora_index += 1
continue # Skip invalid LoRA types
lora_entry = populated_entry
# If we have a version ID from Civitai, track it for deduplication
if 'id' in lora_entry and lora_entry['id']:
added_loras[str(lora_entry['id'])] = len(result["loras"])
except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
# Track by hash if we have it
if lora_hash:
added_loras[lora_hash] = len(result["loras"])
result["loras"].append(lora_entry)
lora_index += 1
# If base model wasn't found earlier, use the most common one from LoRAs
if not result["base_model"] and base_model_counts:
result["base_model"] = max(base_model_counts.items(), key=lambda x: x[1])[0]
return result
except Exception as e:
logger.error(f"Error parsing Civitai image metadata: {e}", exc_info=True)
return {"error": str(e), "loras": []}

View File

@@ -43,7 +43,7 @@ class RecipeFormatParser(RecipeMetadataParser):
for lora in recipe_metadata.get('loras', []):
# Convert recipe lora format to frontend format
lora_entry = {
'id': lora.get('modelVersionId', ''),
'id': int(lora.get('modelVersionId', 0)),
'name': lora.get('modelName', ''),
'version': lora.get('modelVersionName', ''),
'type': 'lora',
@@ -55,7 +55,7 @@ class RecipeFormatParser(RecipeMetadataParser):
# Check if this LoRA exists locally by SHA256 hash
if lora.get('hash') and recipe_scanner:
lora_scanner = recipe_scanner._lora_scanner
exists_locally = lora_scanner.has_lora_hash(lora['hash'])
exists_locally = lora_scanner.has_hash(lora['hash'])
if exists_locally:
lora_cache = await lora_scanner.get_cached_data()
lora_item = next((item for item in lora_cache.raw_data if item['sha256'].lower() == lora['hash'].lower()), None)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,140 @@
import logging
from aiohttp import web
from .base_model_routes import BaseModelRoutes
from ..services.checkpoint_service import CheckpointService
from ..services.service_registry import ServiceRegistry
from ..config import config
logger = logging.getLogger(__name__)
class CheckpointRoutes(BaseModelRoutes):
"""Checkpoint-specific route controller"""
def __init__(self):
"""Initialize Checkpoint routes with Checkpoint service"""
# Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.template_name = "checkpoints.html"
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
self.service = CheckpointService(checkpoint_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Initialize parent with the service
super().__init__(self.service)
def setup_routes(self, app: web.Application):
"""Setup Checkpoint routes"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
# Setup common routes with 'checkpoints' prefix (includes page route)
super().setup_routes(app, 'checkpoints')
def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup Checkpoint-specific routes"""
# Checkpoint-specific CivitAI integration
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint)
# Checkpoint info by name
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
# Checkpoint roots and Unet roots
app.router.add_get(f'/api/{prefix}/checkpoints_roots', self.get_checkpoints_roots)
app.router.add_get(f'/api/{prefix}/unet_roots', self.get_unet_roots)
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
"""Get detailed information for a specific checkpoint by name"""
try:
name = request.match_info.get('name', '')
checkpoint_info = await self.service.get_model_info_by_name(name)
if checkpoint_info:
return web.json_response(checkpoint_info)
else:
return web.json_response({"error": "Checkpoint not found"}, status=404)
except Exception as e:
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai checkpoint model with local availability info"""
try:
model_id = request.match_info['model_id']
response = await self.civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be Checkpoint
if model_type.lower() != 'checkpoint':
return web.json_response({
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model' and file.get('primary') == True), None)
# If no primary file found, try to find any model file
if not model_file:
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching checkpoint model versions: {e}")
return web.Response(status=500, text=str(e))
async def get_checkpoints_roots(self, request: web.Request) -> web.Response:
"""Return the list of checkpoint roots from config"""
try:
roots = config.checkpoints_roots
return web.json_response({
"success": True,
"roots": roots
})
except Exception as e:
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_unet_roots(self, request: web.Request) -> web.Response:
"""Return the list of unet roots from config"""
try:
roots = config.unet_roots
return web.json_response({
"success": True,
"roots": roots
})
except Exception as e:
logger.error(f"Error getting unet roots: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)

View File

@@ -1,694 +0,0 @@
import os
import json
import jinja2
from aiohttp import web
import logging
import asyncio
from ..utils.routes_common import ModelRouteUtils
from ..utils.constants import NSFW_LEVELS
from ..services.websocket_manager import ws_manager
from ..services.service_registry import ServiceRegistry
from ..config import config
from ..services.settings_manager import settings
from ..utils.utils import fuzzy_match
logger = logging.getLogger(__name__)
class CheckpointsRoutes:
"""API routes for checkpoint management"""
def __init__(self):
self.scanner = None # Will be initialized in setup_routes
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
self.download_manager = None # Will be initialized in setup_routes
self._download_lock = asyncio.Lock()
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
self.download_manager = await ServiceRegistry.get_download_manager()
def setup_routes(self, app):
"""Register routes with the aiohttp app"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
app.router.add_get('/api/checkpoints', self.get_checkpoints)
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info)
app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots)
app.router.add_get('/api/checkpoints/civitai/versions/{model_id}', self.get_civitai_versions) # Add new route
# Add new routes for model management similar to LoRA routes
app.router.add_post('/api/checkpoints/delete', self.delete_model)
app.router.add_post('/api/checkpoints/exclude', self.exclude_model) # Add new exclude endpoint
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
# Add new WebSocket endpoint for checkpoint progress
app.router.add_get('/ws/checkpoint-progress', ws_manager.handle_checkpoint_connection)
async def get_checkpoints(self, request):
"""Get paginated checkpoint data"""
try:
# Parse query parameters
page = int(request.query.get('page', '1'))
page_size = min(int(request.query.get('page_size', '20')), 100)
sort_by = request.query.get('sort_by', 'name')
folder = request.query.get('folder', None)
search = request.query.get('search', None)
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
base_models = request.query.getall('base_model', [])
tags = request.query.getall('tag', [])
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # Add favorites_only parameter
# Process search options
search_options = {
'filename': request.query.get('search_filename', 'true').lower() == 'true',
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
'tags': request.query.get('search_tags', 'false').lower() == 'true',
'recursive': request.query.get('recursive', 'false').lower() == 'true',
}
# Process hash filters if provided
hash_filters = {}
if 'hash' in request.query:
hash_filters['single_hash'] = request.query['hash']
elif 'hashes' in request.query:
try:
hash_list = json.loads(request.query['hashes'])
if isinstance(hash_list, list):
hash_filters['multiple_hashes'] = hash_list
except (json.JSONDecodeError, TypeError):
pass
# Get data from scanner
result = await self.get_paginated_data(
page=page,
page_size=page_size,
sort_by=sort_by,
folder=folder,
search=search,
fuzzy_search=fuzzy_search,
base_models=base_models,
tags=tags,
search_options=search_options,
hash_filters=hash_filters,
favorites_only=favorites_only # Pass favorites_only parameter
)
# Format response items
formatted_result = {
'items': [self._format_checkpoint_response(cp) for cp in result['items']],
'total': result['total'],
'page': result['page'],
'page_size': result['page_size'],
'total_pages': result['total_pages']
}
# Return as JSON
return web.json_response(formatted_result)
except Exception as e:
logger.error(f"Error in get_checkpoints: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_paginated_data(self, page, page_size, sort_by='name',
folder=None, search=None, fuzzy_search=False,
base_models=None, tags=None,
search_options=None, hash_filters=None,
favorites_only=False): # Add favorites_only parameter with default False
"""Get paginated and filtered checkpoint data"""
cache = await self.scanner.get_cached_data()
# Get default search options if not provided
if search_options is None:
search_options = {
'filename': True,
'modelname': True,
'tags': False,
'recursive': False,
}
# Get the base data set
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# Apply hash filtering if provided (highest priority)
if hash_filters:
single_hash = hash_filters.get('single_hash')
multiple_hashes = hash_filters.get('multiple_hashes')
if single_hash:
# Filter by single hash
single_hash = single_hash.lower() # Ensure lowercase for matching
filtered_data = [
cp for cp in filtered_data
if cp.get('sha256', '').lower() == single_hash
]
elif multiple_hashes:
# Filter by multiple hashes
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
filtered_data = [
cp for cp in filtered_data
if cp.get('sha256', '').lower() in hash_set
]
# Jump to pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
# Apply SFW filtering if enabled in settings
if settings.get('show_only_sfw', False):
filtered_data = [
cp for cp in filtered_data
if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R']
]
# Apply favorites filtering if enabled
if favorites_only:
filtered_data = [
cp for cp in filtered_data
if cp.get('favorite', False) is True
]
# Apply folder filtering
if folder is not None:
if search_options.get('recursive', False):
# Recursive folder filtering - include all subfolders
filtered_data = [
cp for cp in filtered_data
if cp['folder'].startswith(folder)
]
else:
# Exact folder filtering
filtered_data = [
cp for cp in filtered_data
if cp['folder'] == folder
]
# Apply base model filtering
if base_models and len(base_models) > 0:
filtered_data = [
cp for cp in filtered_data
if cp.get('base_model') in base_models
]
# Apply tag filtering
if tags and len(tags) > 0:
filtered_data = [
cp for cp in filtered_data
if any(tag in cp.get('tags', []) for tag in tags)
]
# Apply search filtering
if search:
search_results = []
for cp in filtered_data:
# Search by file name
if search_options.get('filename', True):
if fuzzy_search:
if fuzzy_match(cp.get('file_name', ''), search):
search_results.append(cp)
continue
elif search.lower() in cp.get('file_name', '').lower():
search_results.append(cp)
continue
# Search by model name
if search_options.get('modelname', True):
if fuzzy_search:
if fuzzy_match(cp.get('model_name', ''), search):
search_results.append(cp)
continue
elif search.lower() in cp.get('model_name', '').lower():
search_results.append(cp)
continue
# Search by tags
if search_options.get('tags', False) and 'tags' in cp:
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) for tag in cp['tags']):
search_results.append(cp)
continue
filtered_data = search_results
# Calculate pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
def _format_checkpoint_response(self, checkpoint):
"""Format checkpoint data for API response"""
return {
"model_name": checkpoint["model_name"],
"file_name": checkpoint["file_name"],
"preview_url": config.get_preview_static_url(checkpoint.get("preview_url", "")),
"preview_nsfw_level": checkpoint.get("preview_nsfw_level", 0),
"base_model": checkpoint.get("base_model", ""),
"folder": checkpoint["folder"],
"sha256": checkpoint.get("sha256", ""),
"file_path": checkpoint["file_path"].replace(os.sep, "/"),
"file_size": checkpoint.get("size", 0),
"modified": checkpoint.get("modified", ""),
"tags": checkpoint.get("tags", []),
"modelDescription": checkpoint.get("modelDescription", ""),
"from_civitai": checkpoint.get("from_civitai", True),
"notes": checkpoint.get("notes", ""),
"model_type": checkpoint.get("model_type", "checkpoint"),
"favorite": checkpoint.get("favorite", False),
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
}
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all checkpoints in the background"""
try:
cache = await self.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
success = 0
needs_resort = False
# Prepare checkpoints to process
to_process = [
cp for cp in cache.raw_data
if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True)
]
total_to_process = len(to_process)
# Send initial progress
await ws_manager.broadcast({
'status': 'started',
'total': total_to_process,
'processed': 0,
'success': 0
})
# Process each checkpoint
for cp in to_process:
try:
original_name = cp.get('model_name')
if await ModelRouteUtils.fetch_and_update_model(
sha256=cp['sha256'],
file_path=cp['file_path'],
model_data=cp,
update_cache_func=self.scanner.update_single_model_cache
):
success += 1
if original_name != cp.get('model_name'):
needs_resort = True
processed += 1
# Send progress update
await ws_manager.broadcast({
'status': 'processing',
'total': total_to_process,
'processed': processed,
'success': success,
'current_name': cp.get('model_name', 'Unknown')
})
except Exception as e:
logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}")
if needs_resort:
await cache.resort(name_only=True)
# Send completion message
await ws_manager.broadcast({
'status': 'completed',
'total': total_to_process,
'processed': processed,
'success': success
})
return web.json_response({
"success": True,
"message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})"
})
except Exception as e:
# Send error message
await ws_manager.broadcast({
'status': 'error',
'error': str(e)
})
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
return web.Response(text=str(e), status=500)
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency"""
try:
# Parse query parameters
limit = int(request.query.get('limit', '20'))
# Validate limit
if limit < 1 or limit > 100:
limit = 20 # Default to a reasonable limit
# Get top tags
top_tags = await self.scanner.get_top_tags(limit)
return web.json_response({
'success': True,
'tags': top_tags
})
except Exception as e:
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
return web.json_response({
'success': False,
'error': 'Internal server error'
}, status=500)
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in loras"""
try:
# Parse query parameters
limit = int(request.query.get('limit', '20'))
# Validate limit
if limit < 1 or limit > 100:
limit = 20 # Default to a reasonable limit
# Get base models
base_models = await self.scanner.get_base_models(limit)
return web.json_response({
'success': True,
'base_models': base_models
})
except Exception as e:
logger.error(f"Error retrieving base models: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def scan_checkpoints(self, request):
"""Force a rescan of checkpoint files"""
try:
await self.scanner.get_cached_data(force_refresh=True)
return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
except Exception as e:
logger.error(f"Error in scan_checkpoints: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_checkpoint_info(self, request):
"""Get detailed information for a specific checkpoint by name"""
try:
name = request.match_info.get('name', '')
checkpoint_info = await self.scanner.get_checkpoint_info_by_name(name)
if checkpoint_info:
return web.json_response(checkpoint_info)
else:
return web.json_response({"error": "Checkpoint not found"}, status=404)
except Exception as e:
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
"""Handle GET /checkpoints request"""
try:
# Check if the CheckpointScanner is initializing
# It's initializing if the cache object doesn't exist yet,
# OR if the scanner explicitly says it's initializing (background task running).
is_initializing = (
self.scanner._cache is None or
(hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
)
if is_initializing:
# If still initializing, return loading page
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[], # 空文件夹列表
is_initializing=True, # 新增标志
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
logger.info("Checkpoints page is initializing, returning loading page")
else:
# 正常流程 - 获取已经初始化好的缓存数据
try:
cache = await self.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=cache.folders,
is_initializing=False,
settings=settings, # Pass settings to template
request=request # Pass the request object to the template
)
except Exception as cache_error:
logger.error(f"Error loading checkpoints cache data: {cache_error}")
# 如果获取缓存失败,也显示初始化页面
template = self.template_env.get_template('checkpoints.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Checkpoints cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
return web.Response(
text="Error loading checkpoints page",
status=500
)
async def delete_model(self, request: web.Request) -> web.Response:
"""Handle checkpoint model deletion request"""
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
async def exclude_model(self, request: web.Request) -> web.Response:
"""Handle checkpoint model exclusion request"""
return await ModelRouteUtils.handle_exclude_model(request, self.scanner)
async def fetch_civitai(self, request: web.Request) -> web.Response:
"""Handle CivitAI metadata fetch request for checkpoints"""
return await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
async def replace_preview(self, request: web.Request) -> web.Response:
"""Handle preview image replacement for checkpoints"""
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
async def download_checkpoint(self, request: web.Request) -> web.Response:
"""Handle checkpoint download request"""
async with self._download_lock:
# Get the download manager from service registry if not already initialized
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
try:
data = await request.json()
# Create progress callback that uses checkpoint-specific WebSocket
async def progress_callback(progress):
await ws_manager.broadcast_checkpoint_progress({
'status': 'progress',
'progress': progress
})
# Check which identifier is provided
download_url = data.get('download_url')
model_hash = data.get('model_hash')
model_version_id = data.get('model_version_id')
# Validate that at least one identifier is provided
if not any([download_url, model_hash, model_version_id]):
return web.Response(
status=400,
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
)
result = await self.download_manager.download_from_civitai(
download_url=download_url,
model_hash=model_hash,
model_version_id=model_version_id,
save_dir=data.get('checkpoint_root'),
relative_path=data.get('relative_path', ''),
progress_callback=progress_callback,
model_type="checkpoint"
)
if not result.get('success', False):
error_message = result.get('error', 'Unknown error')
# Return 401 for early access errors
if 'early access' in error_message.lower():
logger.warning(f"Early access download failed: {error_message}")
return web.Response(
status=401,
text=f"Early Access Restriction: {error_message}"
)
return web.Response(status=500, text=error_message)
return web.json_response(result)
except Exception as e:
error_message = str(e)
# Check if this might be an early access error
if '401' in error_message:
logger.warning(f"Early access error (401): {error_message}")
return web.Response(
status=401,
text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai."
)
logger.error(f"Error downloading checkpoint: {error_message}")
return web.Response(status=500, text=error_message)
async def get_checkpoint_roots(self, request):
"""Return the checkpoint root directories"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
roots = self.scanner.get_model_roots()
return web.json_response({
"success": True,
"roots": roots
})
except Exception as e:
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def save_metadata(self, request: web.Request) -> web.Response:
"""Handle saving metadata updates for checkpoints"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
data = await request.json()
file_path = data.get('file_path')
if not file_path:
return web.Response(text='File path is required', status=400)
# Remove file path from data to avoid saving it
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
# Get metadata file path
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Load existing metadata
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Update metadata
metadata.update(metadata_updates)
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
# Update cache
await self.scanner.update_single_model_cache(file_path, file_path, metadata)
# If model_name was updated, resort the cache
if 'model_name' in metadata_updates:
cache = await self.scanner.get_cached_data()
await cache.resort(name_only=True)
return web.json_response({'success': True})
except Exception as e:
logger.error(f"Error saving checkpoint metadata: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
async def get_civitai_versions(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai checkpoint model with local availability info"""
try:
if self.scanner is None:
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
# Get the civitai client from service registry
civitai_client = await ServiceRegistry.get_civitai_client()
model_id = request.match_info['model_id']
response = await civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be Checkpoint
if (model_type.lower() != 'checkpoint'):
return web.json_response({
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model' and file.get('primary') == True), None)
# If no primary file found, try to find any model file
if not model_file:
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.scanner.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.scanner.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching checkpoint model versions: {e}")
return web.Response(status=500, text=str(e))

View File

@@ -0,0 +1,105 @@
import logging
from aiohttp import web
from .base_model_routes import BaseModelRoutes
from ..services.embedding_service import EmbeddingService
from ..services.service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class EmbeddingRoutes(BaseModelRoutes):
"""Embedding-specific route controller"""
def __init__(self):
"""Initialize Embedding routes with Embedding service"""
# Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.template_name = "embeddings.html"
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
self.service = EmbeddingService(embedding_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Initialize parent with the service
super().__init__(self.service)
def setup_routes(self, app: web.Application):
"""Setup Embedding routes"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
# Setup common routes with 'embeddings' prefix (includes page route)
super().setup_routes(app, 'embeddings')
def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup Embedding-specific routes"""
# Embedding-specific CivitAI integration
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_embedding)
# Embedding info by name
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_embedding_info)
async def get_embedding_info(self, request: web.Request) -> web.Response:
"""Get detailed information for a specific embedding by name"""
try:
name = request.match_info.get('name', '')
embedding_info = await self.service.get_model_info_by_name(name)
if embedding_info:
return web.json_response(embedding_info)
else:
return web.json_response({"error": "Embedding not found"}, status=404)
except Exception as e:
logger.error(f"Error in get_embedding_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_civitai_versions_embedding(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai embedding model with local availability info"""
try:
model_id = request.match_info['model_id']
response = await self.civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be TextualInversion (Embedding)
if model_type.lower() not in ['textualinversion', 'embedding']:
return web.json_response({
'error': f"Model type mismatch. Expected TextualInversion/Embedding, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model' and file.get('primary') == True), None)
# If no primary file found, try to find any model file
if not model_file:
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching embedding model versions: {e}")
return web.Response(status=500, text=str(e))

View File

@@ -0,0 +1,74 @@
import logging
from ..utils.example_images_download_manager import DownloadManager
from ..utils.example_images_processor import ExampleImagesProcessor
from ..utils.example_images_file_manager import ExampleImagesFileManager
from ..services.websocket_manager import ws_manager
logger = logging.getLogger(__name__)
class ExampleImagesRoutes:
"""Routes for example images related functionality"""
@staticmethod
def setup_routes(app):
"""Register example images routes"""
app.router.add_post('/api/download-example-images', ExampleImagesRoutes.download_example_images)
app.router.add_post('/api/import-example-images', ExampleImagesRoutes.import_example_images)
app.router.add_get('/api/example-images-status', ExampleImagesRoutes.get_example_images_status)
app.router.add_post('/api/pause-example-images', ExampleImagesRoutes.pause_example_images)
app.router.add_post('/api/resume-example-images', ExampleImagesRoutes.resume_example_images)
app.router.add_post('/api/open-example-images-folder', ExampleImagesRoutes.open_example_images_folder)
app.router.add_get('/api/example-image-files', ExampleImagesRoutes.get_example_image_files)
app.router.add_get('/api/has-example-images', ExampleImagesRoutes.has_example_images)
app.router.add_post('/api/delete-example-image', ExampleImagesRoutes.delete_example_image)
app.router.add_post('/api/force-download-example-images', ExampleImagesRoutes.force_download_example_images)
@staticmethod
async def download_example_images(request):
"""Download example images for models from Civitai"""
return await DownloadManager.start_download(request)
@staticmethod
async def get_example_images_status(request):
"""Get the current status of example images download"""
return await DownloadManager.get_status(request)
@staticmethod
async def pause_example_images(request):
"""Pause the example images download"""
return await DownloadManager.pause_download(request)
@staticmethod
async def resume_example_images(request):
"""Resume the example images download"""
return await DownloadManager.resume_download(request)
@staticmethod
async def open_example_images_folder(request):
"""Open the example images folder for a specific model"""
return await ExampleImagesFileManager.open_folder(request)
@staticmethod
async def get_example_image_files(request):
"""Get list of example image files for a specific model"""
return await ExampleImagesFileManager.get_files(request)
@staticmethod
async def import_example_images(request):
"""Import local example images for a model"""
return await ExampleImagesProcessor.import_images(request)
@staticmethod
async def has_example_images(request):
"""Check if example images folder exists and is not empty for a model"""
return await ExampleImagesFileManager.has_images(request)
@staticmethod
async def delete_example_image(request):
"""Delete a custom example image for a model"""
return await ExampleImagesProcessor.delete_custom_image(request)
@staticmethod
async def force_download_example_images(request):
"""Force download example images for specific models"""
return await DownloadManager.start_force_download(request)

View File

@@ -1,189 +1,398 @@
import os
from aiohttp import web
import jinja2
from typing import Dict
import asyncio
import logging
from ..config import config
from ..services.settings_manager import settings
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
from aiohttp import web
from typing import Dict
from server import PromptServer # type: ignore
from .base_model_routes import BaseModelRoutes
from ..services.lora_service import LoraService
from ..services.service_registry import ServiceRegistry
from ..utils.routes_common import ModelRouteUtils
from ..utils.utils import get_lora_info
logger = logging.getLogger(__name__)
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
class LoraRoutes:
"""Route handlers for LoRA management endpoints"""
class LoraRoutes(BaseModelRoutes):
"""LoRA-specific route controller"""
def __init__(self):
# Initialize service references as None, will be set during async init
self.scanner = None
self.recipe_scanner = None
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
async def init_services(self):
"""Initialize services from ServiceRegistry"""
self.scanner = await ServiceRegistry.get_lora_scanner()
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
"""Initialize LoRA routes with LoRA service"""
# Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.template_name = "loras.html"
async def initialize_services(self):
"""Initialize services from ServiceRegistry"""
lora_scanner = await ServiceRegistry.get_lora_scanner()
self.service = LoraService(lora_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Initialize parent with the service
super().__init__(self.service)
def format_lora_data(self, lora: Dict) -> Dict:
"""Format LoRA data for template rendering"""
return {
"model_name": lora["model_name"],
"file_name": lora["file_name"],
"preview_url": config.get_preview_static_url(lora["preview_url"]),
"preview_nsfw_level": lora.get("preview_nsfw_level", 0),
"base_model": lora["base_model"],
"folder": lora["folder"],
"sha256": lora["sha256"],
"file_path": lora["file_path"].replace(os.sep, "/"),
"size": lora["size"],
"tags": lora["tags"],
"modelDescription": lora["modelDescription"],
"usage_tips": lora["usage_tips"],
"notes": lora["notes"],
"modified": lora["modified"],
"from_civitai": lora.get("from_civitai", True),
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
}
def _filter_civitai_data(self, data: Dict) -> Dict:
"""Filter relevant fields from CivitAI data"""
if not data:
return {}
fields = [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images"
]
return {k: data[k] for k in fields if k in data}
async def handle_loras_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras request"""
try:
# Ensure services are initialized
await self.init_services()
# Check if the LoraScanner is initializing
# It's initializing if the cache object doesn't exist yet,
# OR if the scanner explicitly says it's initializing (background task running).
is_initializing = (
self.scanner._cache is None or
(hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
)
if is_initializing:
# If still initializing, return loading page
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Loras page is initializing, returning loading page")
else:
# Normal flow - get data from initialized cache
try:
cache = await self.scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=cache.folders,
is_initializing=False,
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading cache data: {cache_error}")
template = self.template_env.get_template('loras.html')
rendered = template.render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
logger.info("Cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling loras request: {e}", exc_info=True)
return web.Response(
text="Error loading loras page",
status=500
)
async def handle_recipes_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras/recipes request"""
try:
# Ensure services are initialized
await self.init_services()
# Skip initialization check and directly try to get cached data
try:
# Recipe scanner will initialize cache if needed
await self.recipe_scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('recipes.html')
rendered = template.render(
recipes=[], # Frontend will load recipes via API
is_initializing=False,
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading recipe cache data: {cache_error}")
# Still keep error handling - show initializing page on error
template = self.template_env.get_template('recipes.html')
rendered = template.render(
is_initializing=True,
settings=settings,
request=request
)
logger.info("Recipe cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling recipes request: {e}", exc_info=True)
return web.Response(
text="Error loading recipes page",
status=500
)
def _format_recipe_file_url(self, file_path: str) -> str:
"""Format file path for recipe image as a URL - same as in recipe_routes"""
try:
# Return the file URL directly for the first lora root's preview
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/')
if file_path.replace(os.sep, '/').startswith(recipes_dir):
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/')
return f"/loras_static/root1/preview/{relative_path}"
# If not in recipes dir, try to create a valid URL from the file path
file_name = os.path.basename(file_path)
return f"/loras_static/root1/preview/recipes/{file_name}"
except Exception as e:
logger.error(f"Error formatting recipe file URL: {e}", exc_info=True)
return '/loras_static/images/no-preview.png' # Return default image on error
def setup_routes(self, app: web.Application):
"""Register routes with the application"""
# Add an app startup handler to initialize services
app.on_startup.append(self._on_startup)
"""Setup LoRA routes"""
# Schedule service initialization on app startup
app.on_startup.append(lambda _: self.initialize_services())
# Register routes
app.router.add_get('/loras', self.handle_loras_page)
app.router.add_get('/loras/recipes', self.handle_recipes_page)
# Setup common routes with 'loras' prefix (includes page route)
super().setup_routes(app, 'loras')
def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup LoRA-specific routes"""
# LoRA-specific query routes
app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts)
app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words)
app.router.add_get(f'/api/{prefix}/model-description', self.get_lora_model_description)
app.router.add_get(f'/api/{prefix}/usage-tips-by-path', self.get_lora_usage_tips_by_path)
async def _on_startup(self, app):
"""Initialize services when the app starts"""
await self.init_services()
# CivitAI integration with LoRA-specific validation
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_lora)
app.router.add_get(f'/api/{prefix}/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
app.router.add_get(f'/api/{prefix}/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
# ComfyUI integration
app.router.add_post(f'/api/{prefix}/get_trigger_words', self.get_trigger_words)
def _parse_specific_params(self, request: web.Request) -> Dict:
"""Parse LoRA-specific parameters"""
params = {}
# LoRA-specific parameters
if 'first_letter' in request.query:
params['first_letter'] = request.query.get('first_letter')
# Handle fuzzy search parameter name variation
if request.query.get('fuzzy') == 'true':
params['fuzzy_search'] = True
# Handle additional filter parameters for LoRAs
if 'lora_hash' in request.query:
if not params.get('hash_filters'):
params['hash_filters'] = {}
params['hash_filters']['single_hash'] = request.query['lora_hash'].lower()
elif 'lora_hashes' in request.query:
if not params.get('hash_filters'):
params['hash_filters'] = {}
params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')]
return params
# LoRA-specific route handlers
async def get_letter_counts(self, request: web.Request) -> web.Response:
"""Get count of LoRAs for each letter of the alphabet"""
try:
letter_counts = await self.service.get_letter_counts()
return web.json_response({
'success': True,
'letter_counts': letter_counts
})
except Exception as e:
logger.error(f"Error getting letter counts: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_notes(self, request: web.Request) -> web.Response:
"""Get notes for a specific LoRA file"""
try:
lora_name = request.query.get('name')
if not lora_name:
return web.Response(text='Lora file name is required', status=400)
notes = await self.service.get_lora_notes(lora_name)
if notes is not None:
return web.json_response({
'success': True,
'notes': notes
})
else:
return web.json_response({
'success': False,
'error': 'LoRA not found in cache'
}, status=404)
except Exception as e:
logger.error(f"Error getting lora notes: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
"""Get trigger words for a specific LoRA file"""
try:
lora_name = request.query.get('name')
if not lora_name:
return web.Response(text='Lora file name is required', status=400)
trigger_words = await self.service.get_lora_trigger_words(lora_name)
return web.json_response({
'success': True,
'trigger_words': trigger_words
})
except Exception as e:
logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_usage_tips_by_path(self, request: web.Request) -> web.Response:
"""Get usage tips for a LoRA by its relative path"""
try:
relative_path = request.query.get('relative_path')
if not relative_path:
return web.Response(text='Relative path is required', status=400)
usage_tips = await self.service.get_lora_usage_tips_by_relative_path(relative_path)
return web.json_response({
'success': True,
'usage_tips': usage_tips or ''
})
except Exception as e:
logger.error(f"Error getting lora usage tips by path: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
"""Get the static preview URL for a LoRA file"""
try:
lora_name = request.query.get('name')
if not lora_name:
return web.Response(text='Lora file name is required', status=400)
preview_url = await self.service.get_lora_preview_url(lora_name)
if preview_url:
return web.json_response({
'success': True,
'preview_url': preview_url
})
else:
return web.json_response({
'success': False,
'error': 'No preview URL found for the specified lora'
}, status=404)
except Exception as e:
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
"""Get the Civitai URL for a LoRA file"""
try:
lora_name = request.query.get('name')
if not lora_name:
return web.Response(text='Lora file name is required', status=400)
result = await self.service.get_lora_civitai_url(lora_name)
if result['civitai_url']:
return web.json_response({
'success': True,
**result
})
else:
return web.json_response({
'success': False,
'error': 'No Civitai data found for the specified lora'
}, status=404)
except Exception as e:
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
# CivitAI integration methods
async def get_civitai_versions_lora(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai LoRA model with local availability info"""
try:
model_id = request.match_info['model_id']
response = await self.civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be LORA, LoCon, or DORA
from ..utils.constants import VALID_LORA_TYPES
if model_type.lower() not in VALID_LORA_TYPES:
return web.json_response({
'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the model file (type="Model") in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching LoRA model versions: {e}")
return web.Response(status=500, text=str(e))
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID"""
try:
model_version_id = request.match_info.get('modelVersionId')
# Get model details from Civitai API
model, error_msg = await self.civitai_client.get_model_version_info(model_version_id)
if not model:
# Log warning for failed model retrieval
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
# Determine status code based on error message
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
return web.json_response({
"success": False,
"error": error_msg or "Failed to fetch model information"
}, status=status_code)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by hash"""
try:
hash = request.match_info.get('hash')
model = await self.civitai_client.get_model_by_hash(hash)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details by hash: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_lora_model_description(self, request: web.Request) -> web.Response:
"""Get model description for a Lora model"""
try:
# Get parameters
model_id = request.query.get('model_id')
file_path = request.query.get('file_path')
if not model_id:
return web.json_response({
'success': False,
'error': 'Model ID is required'
}, status=400)
# Check if we already have the description stored in metadata
description = None
tags = []
creator = {}
if file_path:
import os
from ..utils.metadata_manager import MetadataManager
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
description = metadata.get('modelDescription')
tags = metadata.get('tags', [])
creator = metadata.get('creator', {})
# If description is not in metadata, fetch from CivitAI
if not description:
logger.info(f"Fetching model metadata for model ID: {model_id}")
model_metadata, _ = await self.civitai_client.get_model_metadata(model_id)
if model_metadata:
description = model_metadata.get('description')
tags = model_metadata.get('tags', [])
creator = model_metadata.get('creator', {})
# Save the metadata to file if we have a file path and got metadata
if file_path:
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
metadata['modelDescription'] = description
metadata['tags'] = tags
# Ensure the civitai dict exists
if 'civitai' not in metadata:
metadata['civitai'] = {}
# Store creator in the civitai nested structure
metadata['civitai']['creator'] = creator
await MetadataManager.save_metadata(file_path, metadata, True)
except Exception as e:
logger.error(f"Error saving model metadata: {e}")
return web.json_response({
'success': True,
'description': description or "<p>No model description available.</p>",
'tags': tags,
'creator': creator
})
except Exception as e:
logger.error(f"Error getting model metadata: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_trigger_words(self, request: web.Request) -> web.Response:
"""Get trigger words for specified LoRA models"""
try:
json_data = await request.json()
lora_names = json_data.get("lora_names", [])
node_ids = json_data.get("node_ids", [])
all_trigger_words = []
for lora_name in lora_names:
_, trigger_words = get_lora_info(lora_name)
all_trigger_words.extend(trigger_words)
# Format the trigger words
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Send update to all connected trigger word toggle nodes
for node_id in node_ids:
PromptServer.instance.send_sync("trigger_word_update", {
"id": node_id,
"message": trigger_words_text
})
return web.json_response({"success": True})
except Exception as e:
logger.error(f"Error getting trigger words: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,9 @@
import os
import time
import base64
import jinja2
import numpy as np
from PIL import Image
import torch
import io
import logging
from aiohttp import web
@@ -16,12 +16,12 @@ from ..utils.exif_utils import ExifUtils
from ..recipes import RecipeParserFactory
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..services.settings_manager import settings
from ..config import config
# Check if running in standalone mode
standalone_mode = 'nodes' not in sys.modules
from ..utils.utils import download_civitai_image
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
# Only import MetadataRegistry in non-standalone mode
@@ -40,7 +40,10 @@ class RecipeRoutes:
# Initialize service references as None, will be set during async init
self.recipe_scanner = None
self.civitai_client = None
# Remove WorkflowParser instance
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
# Pre-warm the cache
self._init_cache_task = None
@@ -54,6 +57,8 @@ class RecipeRoutes:
def setup_routes(cls, app: web.Application):
"""Register API routes"""
routes = cls()
app.router.add_get('/loras/recipes', routes.handle_recipes_page)
app.router.add_get('/api/recipes', routes.get_recipes)
app.router.add_get('/api/recipe/{recipe_id}', routes.get_recipe_detail)
app.router.add_post('/api/recipes/analyze-image', routes.analyze_recipe_image)
@@ -115,6 +120,46 @@ class RecipeRoutes:
await self.recipe_scanner.get_cached_data(force_refresh=True)
except Exception as e:
logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True)
async def handle_recipes_page(self, request: web.Request) -> web.Response:
"""Handle GET /loras/recipes request"""
try:
# Ensure services are initialized
await self.init_services()
# Skip initialization check and directly try to get cached data
try:
# Recipe scanner will initialize cache if needed
await self.recipe_scanner.get_cached_data(force_refresh=False)
template = self.template_env.get_template('recipes.html')
rendered = template.render(
recipes=[], # Frontend will load recipes via API
is_initializing=False,
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading recipe cache data: {cache_error}")
# Still keep error handling - show initializing page on error
template = self.template_env.get_template('recipes.html')
rendered = template.render(
is_initializing=True,
settings=settings,
request=request
)
logger.info("Recipe cache error, returning initialization page")
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling recipes request: {e}", exc_info=True)
return web.Response(
text="Error loading recipes page",
status=500
)
async def get_recipes(self, request: web.Request) -> web.Response:
"""API endpoint for getting paginated recipes"""
@@ -254,6 +299,7 @@ class RecipeRoutes:
content_type = request.headers.get('Content-Type', '')
is_url_mode = False
metadata = None # Initialize metadata variable
if 'multipart/form-data' in content_type:
# Handle image upload
@@ -287,17 +333,53 @@ class RecipeRoutes:
"loras": []
}, status=400)
# Download image from URL
temp_path = download_civitai_image(url)
# Check if this is a Civitai image URL
import re
civitai_image_match = re.match(r'https://civitai\.com/images/(\d+)', url)
if not temp_path:
return web.json_response({
"error": "Failed to download image from URL",
"loras": []
}, status=400)
if civitai_image_match:
# Extract image ID and fetch image info using get_image_info
image_id = civitai_image_match.group(1)
image_info = await self.civitai_client.get_image_info(image_id)
if not image_info:
return web.json_response({
"error": "Failed to fetch image information from Civitai",
"loras": []
}, status=400)
# Get image URL from response
image_url = image_info.get('url')
if not image_url:
return web.json_response({
"error": "No image URL found in Civitai response",
"loras": []
}, status=400)
# Download image directly from URL
session = await self.civitai_client.session
# Create a temporary file to save the downloaded image
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
temp_path = temp_file.name
async with session.get(image_url) as response:
if response.status != 200:
return web.json_response({
"error": f"Failed to download image from URL: HTTP {response.status}",
"loras": []
}, status=400)
with open(temp_path, 'wb') as f:
f.write(await response.read())
# Use meta field from image_info as metadata
if 'meta' in image_info:
metadata = image_info['meta']
# Extract metadata from the image using ExifUtils
metadata = ExifUtils.extract_image_metadata(temp_path)
# If metadata wasn't obtained from Civitai API, extract it from the image
if metadata is None:
# Extract metadata from the image using ExifUtils
metadata = ExifUtils.extract_image_metadata(temp_path)
# If no metadata found, return a more specific error
if not metadata:
@@ -545,21 +627,6 @@ class RecipeRoutes:
image = base64.b64decode(image_base64)
except Exception as e:
return web.json_response({"error": f"Invalid base64 image data: {str(e)}"}, status=400)
elif image_url:
# Download image from URL
temp_path = download_civitai_image(image_url)
if not temp_path:
return web.json_response({"error": "Failed to download image from URL"}, status=400)
# Read the downloaded image
with open(temp_path, 'rb') as f:
image = f.read()
# Clean up temp file
try:
os.unlink(temp_path)
except:
pass
else:
return web.json_response({"error": "No image data provided"}, status=400)
@@ -601,7 +668,7 @@ class RecipeRoutes:
"file_name": lora.get("file_name", "") or os.path.splitext(os.path.basename(lora.get("localPath", "")))[0] if lora.get("localPath") else "",
"hash": lora.get("hash", "").lower() if lora.get("hash") else "",
"strength": float(lora.get("weight", 1.0)),
"modelVersionId": lora.get("id", ""),
"modelVersionId": lora.get("id", 0),
"modelName": lora.get("name", ""),
"modelVersionName": lora.get("version", ""),
"isDeleted": lora.get("isDeleted", False), # Preserve deletion status in saved recipe
@@ -949,7 +1016,7 @@ class RecipeRoutes:
else:
latest_image = None
if not latest_image:
if latest_image is None:
return web.json_response({"error": "No recent images found to use for recipe. Try generating an image first."}, status=400)
# Convert the image data to bytes - handle tuple and tensor cases
@@ -971,6 +1038,8 @@ class RecipeRoutes:
shape_info = tensor_image.shape
logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
import torch
# Convert tensor to numpy array
if isinstance(tensor_image, torch.Tensor):
image_np = tensor_image.cpu().numpy()
@@ -1053,14 +1122,14 @@ class RecipeRoutes:
for lora_name, lora_strength in lora_matches:
try:
# Get lora info from scanner
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora_name)
lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora_name)
# Create lora entry
lora_entry = {
"file_name": lora_name,
"hash": lora_info.get("sha256", "").lower() if lora_info else "",
"strength": float(lora_strength),
"modelVersionId": lora_info.get("civitai", {}).get("id", "") if lora_info else "",
"modelVersionId": lora_info.get("civitai", {}).get("id", 0) if lora_info else 0,
"modelName": lora_info.get("civitai", {}).get("model", {}).get("name", "") if lora_info else lora_name,
"modelVersionName": lora_info.get("civitai", {}).get("name", "") if lora_info else "",
"isDeleted": False
@@ -1072,7 +1141,7 @@ class RecipeRoutes:
# Get base model from lora scanner for the available loras
base_model_counts = {}
for lora in loras_data:
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora.get("file_name", ""))
lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora.get("file_name", ""))
if lora_info and "base_model" in lora_info:
base_model = lora_info["base_model"]
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
@@ -1162,7 +1231,7 @@ class RecipeRoutes:
if lora.get("isDeleted", False):
continue
if not self.recipe_scanner._lora_scanner.has_lora_hash(lora.get("hash", "")):
if not self.recipe_scanner._lora_scanner.has_hash(lora.get("hash", "")):
continue
# Get the strength
@@ -1219,9 +1288,9 @@ class RecipeRoutes:
data = await request.json()
# Validate required fields
if 'title' not in data and 'tags' not in data and 'source_path' not in data:
if 'title' not in data and 'tags' not in data and 'source_path' not in data and 'preview_nsfw_level' not in data:
return web.json_response({
"error": "At least one field to update must be provided (title or tags or source_path)"
"error": "At least one field to update must be provided (title or tags or source_path or preview_nsfw_level)"
}, status=400)
# Use the recipe scanner's update method
@@ -1249,7 +1318,7 @@ class RecipeRoutes:
data = await request.json()
# Validate required fields
required_fields = ['recipe_id', 'lora_data', 'target_name']
required_fields = ['recipe_id', 'lora_index', 'target_name']
for field in required_fields:
if field not in data:
return web.json_response({
@@ -1257,7 +1326,7 @@ class RecipeRoutes:
}, status=400)
recipe_id = data['recipe_id']
lora_data = data['lora_data']
lora_index = int(data['lora_index'])
target_name = data['target_name']
# Get recipe scanner
@@ -1270,53 +1339,34 @@ class RecipeRoutes:
return web.json_response({"error": "Recipe not found"}, status=404)
# Find target LoRA by name
target_lora = await lora_scanner.get_lora_info_by_name(target_name)
target_lora = await lora_scanner.get_model_info_by_name(target_name)
if not target_lora:
return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404)
# Load recipe data
with open(recipe_path, 'r', encoding='utf-8') as f:
recipe_data = json.load(f)
# Find the deleted LoRA in the recipe
found = False
updated_lora = None
lora = recipe_data.get("loras", [])[lora_index] if lora_index < len(recipe_data.get('loras', [])) else None
if lora is None:
return web.json_response({"error": "LoRA index out of range in recipe"}, status=404)
# Update LoRA data
lora['isDeleted'] = False
lora['exclude'] = False
lora['file_name'] = target_name
# Identification can be by hash, modelVersionId, or modelName
for i, lora in enumerate(recipe_data.get('loras', [])):
match_found = False
# Try to match by available identifiers
if 'hash' in lora and 'hash' in lora_data and lora['hash'] == lora_data['hash']:
match_found = True
elif 'modelVersionId' in lora and 'modelVersionId' in lora_data and lora['modelVersionId'] == lora_data['modelVersionId']:
match_found = True
elif 'modelName' in lora and 'modelName' in lora_data and lora['modelName'] == lora_data['modelName']:
match_found = True
if match_found:
# Update LoRA data
lora['isDeleted'] = False
lora['file_name'] = target_name
# Update with information from the target LoRA
if 'sha256' in target_lora:
lora['hash'] = target_lora['sha256'].lower()
if target_lora.get("civitai"):
lora['modelName'] = target_lora['civitai']['model']['name']
lora['modelVersionName'] = target_lora['civitai']['name']
lora['modelVersionId'] = target_lora['civitai']['id']
# Keep original fields for identification
# Mark as found and store updated lora
found = True
updated_lora = dict(lora) # Make a copy for response
break
if not found:
return web.json_response({"error": "Could not find matching deleted LoRA in recipe"}, status=404)
# Update with information from the target LoRA
if 'sha256' in target_lora:
lora['hash'] = target_lora['sha256'].lower()
if target_lora.get("civitai"):
lora['modelName'] = target_lora['civitai']['model']['name']
lora['modelVersionName'] = target_lora['civitai']['name']
lora['modelVersionId'] = target_lora['civitai']['id']
updated_lora = dict(lora) # Make a copy for response
# Recalculate recipe fingerprint after updating LoRA
from ..utils.utils import calculate_recipe_fingerprint
recipe_data['fingerprint'] = calculate_recipe_fingerprint(recipe_data.get('loras', []))
@@ -1326,7 +1376,7 @@ class RecipeRoutes:
json.dump(recipe_data, f, indent=4, ensure_ascii=False)
updated_lora['inLibrary'] = True
updated_lora['preview_url'] = target_lora['preview_url']
updated_lora['preview_url'] = config.get_preview_static_url(target_lora['preview_url'])
updated_lora['localPath'] = target_lora['file_path']
# Update in cache if it exists
@@ -1401,9 +1451,9 @@ class RecipeRoutes:
if 'loras' in recipe:
for lora in recipe['loras']:
if 'hash' in lora and lora['hash']:
lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_lora_hash(lora['hash'].lower())
lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_hash(lora['hash'].lower())
lora['preview_url'] = self.recipe_scanner._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
lora['localPath'] = self.recipe_scanner._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
lora['localPath'] = self.recipe_scanner._lora_scanner.get_path_by_hash(lora['hash'].lower())
# Ensure file_url is set (needed by frontend)
if 'file_path' in recipe:

500
py/routes/stats_routes.py Normal file
View File

@@ -0,0 +1,500 @@
import os
import json
import jinja2
from aiohttp import web
import logging
from datetime import datetime, timedelta
from collections import defaultdict, Counter
from typing import Dict, List, Any
from ..config import config
from ..services.settings_manager import settings
from ..services.service_registry import ServiceRegistry
from ..utils.usage_stats import UsageStats
logger = logging.getLogger(__name__)
class StatsRoutes:
"""Route handlers for Statistics page and API endpoints"""
def __init__(self):
self.lora_scanner = None
self.checkpoint_scanner = None
self.embedding_scanner = None
self.usage_stats = None
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True
)
async def init_services(self):
"""Initialize services from ServiceRegistry"""
self.lora_scanner = await ServiceRegistry.get_lora_scanner()
self.checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
self.embedding_scanner = await ServiceRegistry.get_embedding_scanner()
self.usage_stats = UsageStats()
async def handle_stats_page(self, request: web.Request) -> web.Response:
"""Handle GET /statistics request"""
try:
# Ensure services are initialized
await self.init_services()
# Check if scanners are initializing
lora_initializing = (
self.lora_scanner._cache is None or
(hasattr(self.lora_scanner, 'is_initializing') and self.lora_scanner.is_initializing())
)
checkpoint_initializing = (
self.checkpoint_scanner._cache is None or
(hasattr(self.checkpoint_scanner, '_is_initializing') and self.checkpoint_scanner._is_initializing)
)
embedding_initializing = (
self.embedding_scanner._cache is None or
(hasattr(self.embedding_scanner, 'is_initializing') and self.embedding_scanner.is_initializing())
)
is_initializing = lora_initializing or checkpoint_initializing or embedding_initializing
template = self.template_env.get_template('statistics.html')
rendered = template.render(
is_initializing=is_initializing,
settings=settings,
request=request
)
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling statistics request: {e}", exc_info=True)
return web.Response(
text="Error loading statistics page",
status=500
)
async def get_collection_overview(self, request: web.Request) -> web.Response:
"""Get collection overview statistics"""
try:
await self.init_services()
# Get LoRA statistics
lora_cache = await self.lora_scanner.get_cached_data()
lora_count = len(lora_cache.raw_data)
lora_size = sum(lora.get('size', 0) for lora in lora_cache.raw_data)
# Get Checkpoint statistics
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
checkpoint_count = len(checkpoint_cache.raw_data)
checkpoint_size = sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data)
# Get Embedding statistics
embedding_cache = await self.embedding_scanner.get_cached_data()
embedding_count = len(embedding_cache.raw_data)
embedding_size = sum(emb.get('size', 0) for emb in embedding_cache.raw_data)
# Get usage statistics
usage_data = await self.usage_stats.get_stats()
return web.json_response({
'success': True,
'data': {
'total_models': lora_count + checkpoint_count + embedding_count,
'lora_count': lora_count,
'checkpoint_count': checkpoint_count,
'embedding_count': embedding_count,
'total_size': lora_size + checkpoint_size + embedding_size,
'lora_size': lora_size,
'checkpoint_size': checkpoint_size,
'embedding_size': embedding_size,
'total_generations': usage_data.get('total_executions', 0),
'unused_loras': self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {})),
'unused_checkpoints': self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {})),
'unused_embeddings': self._count_unused_models(embedding_cache.raw_data, usage_data.get('embeddings', {}))
}
})
except Exception as e:
logger.error(f"Error getting collection overview: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_usage_analytics(self, request: web.Request) -> web.Response:
"""Get usage analytics data"""
try:
await self.init_services()
# Get usage statistics
usage_data = await self.usage_stats.get_stats()
# Get model data for enrichment
lora_cache = await self.lora_scanner.get_cached_data()
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
embedding_cache = await self.embedding_scanner.get_cached_data()
# Create hash to model mapping
lora_map = {lora['sha256']: lora for lora in lora_cache.raw_data}
checkpoint_map = {cp['sha256']: cp for cp in checkpoint_cache.raw_data}
embedding_map = {emb['sha256']: emb for emb in embedding_cache.raw_data}
# Prepare top used models
top_loras = self._get_top_used_models(usage_data.get('loras', {}), lora_map, 10)
top_checkpoints = self._get_top_used_models(usage_data.get('checkpoints', {}), checkpoint_map, 10)
top_embeddings = self._get_top_used_models(usage_data.get('embeddings', {}), embedding_map, 10)
# Prepare usage timeline (last 30 days)
timeline = self._get_usage_timeline(usage_data, 30)
return web.json_response({
'success': True,
'data': {
'top_loras': top_loras,
'top_checkpoints': top_checkpoints,
'top_embeddings': top_embeddings,
'usage_timeline': timeline,
'total_executions': usage_data.get('total_executions', 0)
}
})
except Exception as e:
logger.error(f"Error getting usage analytics: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_base_model_distribution(self, request: web.Request) -> web.Response:
"""Get base model distribution statistics"""
try:
await self.init_services()
# Get model data
lora_cache = await self.lora_scanner.get_cached_data()
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
embedding_cache = await self.embedding_scanner.get_cached_data()
# Count by base model
lora_base_models = Counter(lora.get('base_model', 'Unknown') for lora in lora_cache.raw_data)
checkpoint_base_models = Counter(cp.get('base_model', 'Unknown') for cp in checkpoint_cache.raw_data)
embedding_base_models = Counter(emb.get('base_model', 'Unknown') for emb in embedding_cache.raw_data)
return web.json_response({
'success': True,
'data': {
'loras': dict(lora_base_models),
'checkpoints': dict(checkpoint_base_models),
'embeddings': dict(embedding_base_models)
}
})
except Exception as e:
logger.error(f"Error getting base model distribution: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_tag_analytics(self, request: web.Request) -> web.Response:
"""Get tag usage analytics"""
try:
await self.init_services()
# Get model data
lora_cache = await self.lora_scanner.get_cached_data()
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
embedding_cache = await self.embedding_scanner.get_cached_data()
# Count tag frequencies
all_tags = []
for lora in lora_cache.raw_data:
all_tags.extend(lora.get('tags', []))
for cp in checkpoint_cache.raw_data:
all_tags.extend(cp.get('tags', []))
for emb in embedding_cache.raw_data:
all_tags.extend(emb.get('tags', []))
tag_counts = Counter(all_tags)
# Get top 50 tags
top_tags = [{'tag': tag, 'count': count} for tag, count in tag_counts.most_common(50)]
return web.json_response({
'success': True,
'data': {
'top_tags': top_tags,
'total_unique_tags': len(tag_counts)
}
})
except Exception as e:
logger.error(f"Error getting tag analytics: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_storage_analytics(self, request: web.Request) -> web.Response:
"""Get storage usage analytics"""
try:
await self.init_services()
# Get usage statistics
usage_data = await self.usage_stats.get_stats()
# Get model data
lora_cache = await self.lora_scanner.get_cached_data()
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
embedding_cache = await self.embedding_scanner.get_cached_data()
# Create models with usage data
lora_storage = []
for lora in lora_cache.raw_data:
usage_count = 0
if lora['sha256'] in usage_data.get('loras', {}):
usage_count = usage_data['loras'][lora['sha256']].get('total', 0)
lora_storage.append({
'name': lora['model_name'],
'size': lora.get('size', 0),
'usage_count': usage_count,
'folder': lora.get('folder', ''),
'base_model': lora.get('base_model', 'Unknown')
})
checkpoint_storage = []
for cp in checkpoint_cache.raw_data:
usage_count = 0
if cp['sha256'] in usage_data.get('checkpoints', {}):
usage_count = usage_data['checkpoints'][cp['sha256']].get('total', 0)
checkpoint_storage.append({
'name': cp['model_name'],
'size': cp.get('size', 0),
'usage_count': usage_count,
'folder': cp.get('folder', ''),
'base_model': cp.get('base_model', 'Unknown')
})
embedding_storage = []
for emb in embedding_cache.raw_data:
usage_count = 0
if emb['sha256'] in usage_data.get('embeddings', {}):
usage_count = usage_data['embeddings'][emb['sha256']].get('total', 0)
embedding_storage.append({
'name': emb['model_name'],
'size': emb.get('size', 0),
'usage_count': usage_count,
'folder': emb.get('folder', ''),
'base_model': emb.get('base_model', 'Unknown')
})
# Sort by size
lora_storage.sort(key=lambda x: x['size'], reverse=True)
checkpoint_storage.sort(key=lambda x: x['size'], reverse=True)
embedding_storage.sort(key=lambda x: x['size'], reverse=True)
return web.json_response({
'success': True,
'data': {
'loras': lora_storage[:20], # Top 20 by size
'checkpoints': checkpoint_storage[:20],
'embeddings': embedding_storage[:20]
}
})
except Exception as e:
logger.error(f"Error getting storage analytics: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_insights(self, request: web.Request) -> web.Response:
"""Get smart insights about the collection"""
try:
await self.init_services()
# Get usage statistics
usage_data = await self.usage_stats.get_stats()
# Get model data
lora_cache = await self.lora_scanner.get_cached_data()
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
embedding_cache = await self.embedding_scanner.get_cached_data()
insights = []
# Calculate unused models
unused_loras = self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {}))
unused_checkpoints = self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {}))
unused_embeddings = self._count_unused_models(embedding_cache.raw_data, usage_data.get('embeddings', {}))
total_loras = len(lora_cache.raw_data)
total_checkpoints = len(checkpoint_cache.raw_data)
total_embeddings = len(embedding_cache.raw_data)
if total_loras > 0:
unused_lora_percent = (unused_loras / total_loras) * 100
if unused_lora_percent > 50:
insights.append({
'type': 'warning',
'title': 'High Number of Unused LoRAs',
'description': f'{unused_lora_percent:.1f}% of your LoRAs ({unused_loras}/{total_loras}) have never been used.',
'suggestion': 'Consider organizing or archiving unused models to free up storage space.'
})
if total_checkpoints > 0:
unused_checkpoint_percent = (unused_checkpoints / total_checkpoints) * 100
if unused_checkpoint_percent > 30:
insights.append({
'type': 'warning',
'title': 'Unused Checkpoints Detected',
'description': f'{unused_checkpoint_percent:.1f}% of your checkpoints ({unused_checkpoints}/{total_checkpoints}) have never been used.',
'suggestion': 'Review and consider removing checkpoints you no longer need.'
})
if total_embeddings > 0:
unused_embedding_percent = (unused_embeddings / total_embeddings) * 100
if unused_embedding_percent > 50:
insights.append({
'type': 'warning',
'title': 'High Number of Unused Embeddings',
'description': f'{unused_embedding_percent:.1f}% of your embeddings ({unused_embeddings}/{total_embeddings}) have never been used.',
'suggestion': 'Consider organizing or archiving unused embeddings to optimize your collection.'
})
# Storage insights
total_size = sum(lora.get('size', 0) for lora in lora_cache.raw_data) + \
sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data) + \
sum(emb.get('size', 0) for emb in embedding_cache.raw_data)
if total_size > 100 * 1024 * 1024 * 1024: # 100GB
insights.append({
'type': 'info',
'title': 'Large Collection Detected',
'description': f'Your model collection is using {self._format_size(total_size)} of storage.',
'suggestion': 'Consider using external storage or cloud solutions for better organization.'
})
# Recent activity insight
if usage_data.get('total_executions', 0) > 100:
insights.append({
'type': 'success',
'title': 'Active User',
'description': f'You\'ve completed {usage_data["total_executions"]} generations so far!',
'suggestion': 'Keep exploring and creating amazing content with your models.'
})
return web.json_response({
'success': True,
'data': {
'insights': insights
}
})
except Exception as e:
logger.error(f"Error getting insights: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
def _count_unused_models(self, models: List[Dict], usage_data: Dict) -> int:
"""Count models that have never been used"""
used_hashes = set(usage_data.keys())
unused_count = 0
for model in models:
if model.get('sha256') not in used_hashes:
unused_count += 1
return unused_count
def _get_top_used_models(self, usage_data: Dict, model_map: Dict, limit: int) -> List[Dict]:
"""Get top used models with their metadata"""
sorted_usage = sorted(usage_data.items(), key=lambda x: x[1].get('total', 0), reverse=True)
top_models = []
for sha256, usage_info in sorted_usage[:limit]:
if sha256 in model_map:
model = model_map[sha256]
top_models.append({
'name': model['model_name'],
'usage_count': usage_info.get('total', 0),
'base_model': model.get('base_model', 'Unknown'),
'preview_url': config.get_preview_static_url(model.get('preview_url', '')),
'folder': model.get('folder', '')
})
return top_models
def _get_usage_timeline(self, usage_data: Dict, days: int) -> List[Dict]:
"""Get usage timeline for the past N days"""
timeline = []
today = datetime.now()
for i in range(days):
date = today - timedelta(days=i)
date_str = date.strftime('%Y-%m-%d')
lora_usage = 0
checkpoint_usage = 0
embedding_usage = 0
# Count usage for this date
for model_usage in usage_data.get('loras', {}).values():
if isinstance(model_usage, dict) and 'history' in model_usage:
lora_usage += model_usage['history'].get(date_str, 0)
for model_usage in usage_data.get('checkpoints', {}).values():
if isinstance(model_usage, dict) and 'history' in model_usage:
checkpoint_usage += model_usage['history'].get(date_str, 0)
for model_usage in usage_data.get('embeddings', {}).values():
if isinstance(model_usage, dict) and 'history' in model_usage:
embedding_usage += model_usage['history'].get(date_str, 0)
timeline.append({
'date': date_str,
'lora_usage': lora_usage,
'checkpoint_usage': checkpoint_usage,
'embedding_usage': embedding_usage,
'total_usage': lora_usage + checkpoint_usage + embedding_usage
})
return list(reversed(timeline)) # Oldest to newest
def _format_size(self, size_bytes: int) -> str:
"""Format file size in human readable format"""
for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
if size_bytes < 1024.0:
return f"{size_bytes:.1f} {unit}"
size_bytes /= 1024.0
return f"{size_bytes:.1f} PB"
def setup_routes(self, app: web.Application):
"""Register routes with the application"""
# Add an app startup handler to initialize services
app.on_startup.append(self._on_startup)
# Register page route
app.router.add_get('/statistics', self.handle_stats_page)
# Register API routes
app.router.add_get('/api/stats/collection-overview', self.get_collection_overview)
app.router.add_get('/api/stats/usage-analytics', self.get_usage_analytics)
app.router.add_get('/api/stats/base-model-distribution', self.get_base_model_distribution)
app.router.add_get('/api/stats/tag-analytics', self.get_tag_analytics)
app.router.add_get('/api/stats/storage-analytics', self.get_storage_analytics)
app.router.add_get('/api/stats/insights', self.get_insights)
async def _on_startup(self, app):
"""Initialize services when the app starts"""
await self.init_services()

View File

@@ -2,8 +2,13 @@ import os
import aiohttp
import logging
import toml
import git
import zipfile
import shutil
import tempfile
from aiohttp import web
from typing import Dict, Any, List
from typing import Dict, List
logger = logging.getLogger(__name__)
@@ -13,7 +18,9 @@ class UpdateRoutes:
@staticmethod
def setup_routes(app):
"""Register update check routes"""
app.router.add_get('/loras/api/check-updates', UpdateRoutes.check_updates)
app.router.add_get('/api/check-updates', UpdateRoutes.check_updates)
app.router.add_get('/api/version-info', UpdateRoutes.get_version_info)
app.router.add_post('/api/perform-update', UpdateRoutes.perform_update)
@staticmethod
async def check_updates(request):
@@ -22,24 +29,39 @@ class UpdateRoutes:
Returns update status and version information
"""
try:
nightly = request.query.get('nightly', 'false').lower() == 'true'
# Read local version from pyproject.toml
local_version = UpdateRoutes._get_local_version()
# Get git info (commit hash, branch)
git_info = UpdateRoutes._get_git_info()
# Fetch remote version from GitHub
remote_version, changelog = await UpdateRoutes._get_remote_version()
if nightly:
remote_version, changelog = await UpdateRoutes._get_nightly_version()
else:
remote_version, changelog = await UpdateRoutes._get_remote_version()
# Compare versions
update_available = UpdateRoutes._compare_versions(
local_version.replace('v', ''),
remote_version.replace('v', '')
)
if nightly:
# For nightly, compare commit hashes
update_available = UpdateRoutes._compare_nightly_versions(git_info, remote_version)
else:
# For stable, compare semantic versions
update_available = UpdateRoutes._compare_versions(
local_version.replace('v', ''),
remote_version.replace('v', '')
)
return web.json_response({
'success': True,
'current_version': local_version,
'latest_version': remote_version,
'update_available': update_available,
'changelog': changelog
'changelog': changelog,
'git_info': git_info,
'nightly': nightly
})
except Exception as e:
@@ -48,6 +70,279 @@ class UpdateRoutes:
'success': False,
'error': str(e)
})
@staticmethod
async def get_version_info(request):
"""
Returns the current version in the format 'version-short_hash'
"""
try:
# Read local version from pyproject.toml
local_version = UpdateRoutes._get_local_version().replace('v', '')
# Get git info (commit hash, branch)
git_info = UpdateRoutes._get_git_info()
short_hash = git_info['short_hash']
# Format: version-short_hash
version_string = f"{local_version}-{short_hash}"
return web.json_response({
'success': True,
'version': version_string
})
except Exception as e:
logger.error(f"Failed to get version info: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
})
@staticmethod
async def perform_update(request):
"""
Perform Git-based update to latest release tag or main branch.
If .git is missing, fallback to ZIP download.
"""
try:
body = await request.json() if request.has_body else {}
nightly = body.get('nightly', False)
current_dir = os.path.dirname(os.path.abspath(__file__))
plugin_root = os.path.dirname(os.path.dirname(current_dir))
settings_path = os.path.join(plugin_root, 'settings.json')
settings_backup = None
if os.path.exists(settings_path):
with open(settings_path, 'r', encoding='utf-8') as f:
settings_backup = f.read()
logger.info("Backed up settings.json")
git_folder = os.path.join(plugin_root, '.git')
if os.path.exists(git_folder):
# Git update
success, new_version = await UpdateRoutes._perform_git_update(plugin_root, nightly)
else:
# Fallback: Download ZIP and replace files
success, new_version = await UpdateRoutes._download_and_replace_zip(plugin_root)
if settings_backup and success:
with open(settings_path, 'w', encoding='utf-8') as f:
f.write(settings_backup)
logger.info("Restored settings.json")
if success:
return web.json_response({
'success': True,
'message': f'Successfully updated to {new_version}',
'new_version': new_version
})
else:
return web.json_response({
'success': False,
'error': 'Failed to complete update'
})
except Exception as e:
logger.error(f"Failed to perform update: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
})
@staticmethod
async def _download_and_replace_zip(plugin_root: str) -> tuple[bool, str]:
"""
Download latest release ZIP from GitHub and replace plugin files.
Skips settings.json. Writes extracted file list to .tracking.
"""
repo_owner = "willmiao"
repo_name = "ComfyUI-Lora-Manager"
github_api = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
try:
async with aiohttp.ClientSession() as session:
async with session.get(github_api) as resp:
if resp.status != 200:
logger.error(f"Failed to fetch release info: {resp.status}")
return False, ""
data = await resp.json()
zip_url = data.get("zipball_url")
version = data.get("tag_name", "unknown")
# Download ZIP
async with session.get(zip_url) as zip_resp:
if zip_resp.status != 200:
logger.error(f"Failed to download ZIP: {zip_resp.status}")
return False, ""
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
tmp_zip.write(await zip_resp.read())
zip_path = tmp_zip.name
UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json'])
# Extract ZIP to temp dir
with tempfile.TemporaryDirectory() as tmp_dir:
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(tmp_dir)
# Find extracted folder (GitHub ZIP contains a root folder)
extracted_root = next(os.scandir(tmp_dir)).path
# Copy files, skipping settings.json
for item in os.listdir(extracted_root):
src = os.path.join(extracted_root, item)
dst = os.path.join(plugin_root, item)
if os.path.isdir(src):
if os.path.exists(dst):
shutil.rmtree(dst)
shutil.copytree(src, dst, ignore=shutil.ignore_patterns('settings.json'))
else:
if item == 'settings.json':
continue
shutil.copy2(src, dst)
# Write .tracking file: list all files under extracted_root, relative to extracted_root
# for ComfyUI Manager to work properly
tracking_info_file = os.path.join(plugin_root, '.tracking')
tracking_files = []
for root, dirs, files in os.walk(extracted_root):
for file in files:
rel_path = os.path.relpath(os.path.join(root, file), extracted_root)
tracking_files.append(rel_path.replace("\\", "/"))
with open(tracking_info_file, "w", encoding='utf-8') as file:
file.write('\n'.join(tracking_files))
os.remove(zip_path)
logger.info(f"Updated plugin via ZIP to {version}")
return True, version
except Exception as e:
logger.error(f"ZIP update failed: {e}", exc_info=True)
return False, ""
def _clean_plugin_folder(plugin_root, skip_files=None):
skip_files = skip_files or []
for item in os.listdir(plugin_root):
if item in skip_files:
continue
path = os.path.join(plugin_root, item)
if os.path.isdir(path):
shutil.rmtree(path)
else:
os.remove(path)
@staticmethod
async def _get_nightly_version() -> tuple[str, List[str]]:
"""
Fetch latest commit from main branch
"""
repo_owner = "willmiao"
repo_name = "ComfyUI-Lora-Manager"
# Use GitHub API to fetch the latest commit from main branch
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
try:
async with aiohttp.ClientSession() as session:
async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response:
if response.status != 200:
logger.warning(f"Failed to fetch GitHub commit: {response.status}")
return "main", []
data = await response.json()
commit_sha = data.get('sha', '')[:7] # Short hash
commit_message = data.get('commit', {}).get('message', '')
# Format as "main-{short_hash}"
version = f"main-{commit_sha}"
# Use commit message as changelog
changelog = [commit_message] if commit_message else []
return version, changelog
except Exception as e:
logger.error(f"Error fetching nightly version: {e}", exc_info=True)
return "main", []
@staticmethod
def _compare_nightly_versions(local_git_info: Dict[str, str], remote_version: str) -> bool:
"""
Compare local commit hash with remote main branch
"""
try:
local_hash = local_git_info.get('short_hash', 'unknown')
if local_hash == 'unknown':
return True # Assume update available if we can't get local hash
# Extract remote hash from version string (format: "main-{hash}")
if '-' in remote_version:
remote_hash = remote_version.split('-')[-1]
return local_hash != remote_hash
return True # Default to update available
except Exception as e:
logger.error(f"Error comparing nightly versions: {e}")
return False
@staticmethod
async def _perform_git_update(plugin_root: str, nightly: bool = False) -> tuple[bool, str]:
"""
Perform Git-based update using GitPython
Args:
plugin_root: Path to the plugin root directory
nightly: Whether to update to main branch or latest release
Returns:
tuple: (success, new_version)
"""
try:
# Open the Git repository
repo = git.Repo(plugin_root)
# Fetch latest changes
origin = repo.remotes.origin
origin.fetch()
if nightly:
# Switch to main branch and pull latest
main_branch = 'main'
if main_branch not in [branch.name for branch in repo.branches]:
# Create local main branch if it doesn't exist
repo.create_head(main_branch, origin.refs.main)
repo.heads[main_branch].checkout()
origin.pull(main_branch)
# Get new commit hash
new_version = f"main-{repo.head.commit.hexsha[:7]}"
else:
# Get latest release tag
tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True)
if not tags:
logger.error("No tags found in repository")
return False, ""
latest_tag = tags[0]
# Checkout to latest tag
repo.git.checkout(latest_tag.name)
new_version = latest_tag.name
logger.info(f"Successfully updated to {new_version}")
return True, new_version
except git.exc.GitError as e:
logger.error(f"Git error during update: {e}")
return False, ""
except Exception as e:
logger.error(f"Error during Git update: {e}")
return False, ""
@staticmethod
def _get_local_version() -> str:
@@ -72,6 +367,35 @@ class UpdateRoutes:
logger.error(f"Failed to get local version: {e}", exc_info=True)
return "v0.0.0"
@staticmethod
def _get_git_info() -> Dict[str, str]:
"""Get Git repository information"""
current_dir = os.path.dirname(os.path.abspath(__file__))
plugin_root = os.path.dirname(os.path.dirname(current_dir))
git_info = {
'commit_hash': 'unknown',
'short_hash': 'stable',
'branch': 'unknown',
'commit_date': 'unknown'
}
try:
# Check if we're in a git repository
if not os.path.exists(os.path.join(plugin_root, '.git')):
return git_info
repo = git.Repo(plugin_root)
commit = repo.head.commit
git_info['commit_hash'] = commit.hexsha
git_info['short_hash'] = commit.hexsha[:7]
git_info['branch'] = repo.active_branch.name if not repo.head.is_detached else 'detached'
git_info['commit_date'] = commit.committed_datetime.strftime('%Y-%m-%d')
except Exception as e:
logger.warning(f"Error getting git info: {e}")
return git_info
@staticmethod
async def _get_remote_version() -> tuple[str, List[str]]:
"""

View File

@@ -0,0 +1,422 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Type
import logging
import os
from ..utils.models import BaseModelMetadata
from ..utils.constants import NSFW_LEVELS
from .settings_manager import settings
from ..utils.utils import fuzzy_match
logger = logging.getLogger(__name__)
class BaseModelService(ABC):
"""Base service class for all model types"""
def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]):
"""Initialize the service
Args:
model_type: Type of model (lora, checkpoint, etc.)
scanner: Model scanner instance
metadata_class: Metadata class for this model type
"""
self.model_type = model_type
self.scanner = scanner
self.metadata_class = metadata_class
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
folder: str = None, search: str = None, fuzzy_search: bool = False,
base_models: list = None, tags: list = None,
search_options: dict = None, hash_filters: dict = None,
favorites_only: bool = False, **kwargs) -> Dict:
"""Get paginated and filtered model data
Args:
page: Page number (1-based)
page_size: Number of items per page
sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc'
folder: Folder filter
search: Search term
fuzzy_search: Whether to use fuzzy search
base_models: List of base models to filter by
tags: List of tags to filter by
search_options: Search options dict
hash_filters: Hash filtering options
favorites_only: Filter for favorites only
**kwargs: Additional model-specific filters
Returns:
Dict containing paginated results
"""
cache = await self.scanner.get_cached_data()
# Parse sort_by into sort_key and order
if ':' in sort_by:
sort_key, order = sort_by.split(':', 1)
sort_key = sort_key.strip()
order = order.strip().lower()
if order not in ('asc', 'desc'):
order = 'asc'
else:
sort_key = sort_by.strip()
order = 'asc'
# Get default search options if not provided
if search_options is None:
search_options = {
'filename': True,
'modelname': True,
'tags': False,
'recursive': False,
}
# Get the base data set using new sort logic
filtered_data = await cache.get_sorted_data(sort_key, order)
# Apply hash filtering if provided (highest priority)
if hash_filters:
filtered_data = await self._apply_hash_filters(filtered_data, hash_filters)
# Jump to pagination for hash filters
return self._paginate(filtered_data, page, page_size)
# Apply common filters
filtered_data = await self._apply_common_filters(
filtered_data, folder, base_models, tags, favorites_only, search_options
)
# Apply search filtering
if search:
filtered_data = await self._apply_search_filters(
filtered_data, search, fuzzy_search, search_options
)
# Apply model-specific filters
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
return self._paginate(filtered_data, page, page_size)
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
"""Apply hash-based filtering"""
single_hash = hash_filters.get('single_hash')
multiple_hashes = hash_filters.get('multiple_hashes')
if single_hash:
# Filter by single hash
single_hash = single_hash.lower()
return [
item for item in data
if item.get('sha256', '').lower() == single_hash
]
elif multiple_hashes:
# Filter by multiple hashes
hash_set = set(hash.lower() for hash in multiple_hashes)
return [
item for item in data
if item.get('sha256', '').lower() in hash_set
]
return data
async def _apply_common_filters(self, data: List[Dict], folder: str = None,
base_models: list = None, tags: list = None,
favorites_only: bool = False, search_options: dict = None) -> List[Dict]:
"""Apply common filters that work across all model types"""
# Apply SFW filtering if enabled in settings
if settings.get('show_only_sfw', False):
data = [
item for item in data
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
]
# Apply favorites filtering if enabled
if favorites_only:
data = [
item for item in data
if item.get('favorite', False) is True
]
# Apply folder filtering
if folder is not None:
if search_options and search_options.get('recursive', False):
# Recursive folder filtering - include all subfolders
data = [
item for item in data
if item['folder'].startswith(folder)
]
else:
# Exact folder filtering
data = [
item for item in data
if item['folder'] == folder
]
# Apply base model filtering
if base_models and len(base_models) > 0:
data = [
item for item in data
if item.get('base_model') in base_models
]
# Apply tag filtering
if tags and len(tags) > 0:
data = [
item for item in data
if any(tag in item.get('tags', []) for tag in tags)
]
return data
async def _apply_search_filters(self, data: List[Dict], search: str,
fuzzy_search: bool, search_options: dict) -> List[Dict]:
"""Apply search filtering"""
search_results = []
for item in data:
# Search by file name
if search_options.get('filename', True):
if fuzzy_search:
if fuzzy_match(item.get('file_name', ''), search):
search_results.append(item)
continue
elif search.lower() in item.get('file_name', '').lower():
search_results.append(item)
continue
# Search by model name
if search_options.get('modelname', True):
if fuzzy_search:
if fuzzy_match(item.get('model_name', ''), search):
search_results.append(item)
continue
elif search.lower() in item.get('model_name', '').lower():
search_results.append(item)
continue
# Search by tags
if search_options.get('tags', False) and 'tags' in item:
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower())
for tag in item['tags']):
search_results.append(item)
continue
# Search by creator
civitai = item.get('civitai')
creator_username = ''
if civitai and isinstance(civitai, dict):
creator = civitai.get('creator')
if creator and isinstance(creator, dict):
creator_username = creator.get('username', '')
if search_options.get('creator', False) and creator_username:
if fuzzy_search:
if fuzzy_match(creator_username, search):
search_results.append(item)
continue
elif search.lower() in creator_username.lower():
search_results.append(item)
continue
return search_results
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
"""Apply model-specific filters - to be overridden by subclasses if needed"""
return data
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
"""Apply pagination to filtered data"""
total_items = len(data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
return {
'items': data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
@abstractmethod
async def format_response(self, model_data: Dict) -> Dict:
"""Format model data for API response - must be implemented by subclasses"""
pass
# Common service methods that delegate to scanner
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
"""Get top tags sorted by frequency"""
return await self.scanner.get_top_tags(limit)
async def get_base_models(self, limit: int = 20) -> List[Dict]:
"""Get base models sorted by frequency"""
return await self.scanner.get_base_models(limit)
def has_hash(self, sha256: str) -> bool:
"""Check if a model with given hash exists"""
return self.scanner.has_hash(sha256)
def get_path_by_hash(self, sha256: str) -> Optional[str]:
"""Get file path for a model by its hash"""
return self.scanner.get_path_by_hash(sha256)
def get_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a model by its file path"""
return self.scanner.get_hash_by_path(file_path)
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
"""Trigger model scanning"""
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache)
async def get_model_info_by_name(self, name: str):
"""Get model information by name"""
return await self.scanner.get_model_info_by_name(name)
def get_model_roots(self) -> List[str]:
"""Get model root directories"""
return self.scanner.get_model_roots()
async def get_folder_tree(self, model_root: str) -> Dict:
"""Get hierarchical folder tree for a specific model root"""
cache = await self.scanner.get_cached_data()
# Build tree structure from folders
tree = {}
for folder in cache.folders:
# Check if this folder belongs to the specified model root
folder_belongs_to_root = False
for root in self.scanner.get_model_roots():
if root == model_root:
folder_belongs_to_root = True
break
if not folder_belongs_to_root:
continue
# Split folder path into components
parts = folder.split('/') if folder else []
current_level = tree
for part in parts:
if part not in current_level:
current_level[part] = {}
current_level = current_level[part]
return tree
async def get_unified_folder_tree(self) -> Dict:
"""Get unified folder tree across all model roots"""
cache = await self.scanner.get_cached_data()
# Build unified tree structure by analyzing all relative paths
unified_tree = {}
# Get all model roots for path normalization
model_roots = self.scanner.get_model_roots()
for folder in cache.folders:
if not folder: # Skip empty folders
continue
# Find which root this folder belongs to by checking the actual file paths
# This is a simplified approach - we'll use the folder as-is since it should already be relative
relative_path = folder
# Split folder path into components
parts = relative_path.split('/')
current_level = unified_tree
for part in parts:
if part not in current_level:
current_level[part] = {}
current_level = current_level[part]
return unified_tree
async def get_model_notes(self, model_name: str) -> Optional[str]:
"""Get notes for a specific model file"""
cache = await self.scanner.get_cached_data()
for model in cache.raw_data:
if model['file_name'] == model_name:
return model.get('notes', '')
return None
async def get_model_preview_url(self, model_name: str) -> Optional[str]:
"""Get the static preview URL for a model file"""
cache = await self.scanner.get_cached_data()
for model in cache.raw_data:
if model['file_name'] == model_name:
preview_url = model.get('preview_url')
if preview_url:
from ..config import config
return config.get_preview_static_url(preview_url)
return None
async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]:
"""Get the Civitai URL for a model file"""
cache = await self.scanner.get_cached_data()
for model in cache.raw_data:
if model['file_name'] == model_name:
civitai_data = model.get('civitai', {})
model_id = civitai_data.get('modelId')
version_id = civitai_data.get('id')
if model_id:
civitai_url = f"https://civitai.com/models/{model_id}"
if version_id:
civitai_url += f"?modelVersionId={version_id}"
return {
'civitai_url': civitai_url,
'model_id': str(model_id),
'version_id': str(version_id) if version_id else None
}
return {'civitai_url': None, 'model_id': None, 'version_id': None}
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
"""Search model relative file paths for autocomplete functionality"""
cache = await self.scanner.get_cached_data()
matching_paths = []
search_lower = search_term.lower()
# Get model roots for path calculation
model_roots = self.scanner.get_model_roots()
for model in cache.raw_data:
file_path = model.get('file_path', '')
if not file_path:
continue
# Calculate relative path from model root
relative_path = None
for root in model_roots:
# Normalize paths for comparison
normalized_root = os.path.normpath(root).replace(os.sep, '/')
normalized_file = os.path.normpath(file_path).replace(os.sep, '/')
if normalized_file.startswith(normalized_root):
# Remove root and leading slash to get relative path
relative_path = normalized_file[len(normalized_root):].lstrip('/')
break
if relative_path and search_lower in relative_path.lower():
matching_paths.append(relative_path)
if len(matching_paths) >= limit * 2: # Get more for better sorting
break
# Sort by relevance (exact matches first, then by length)
matching_paths.sort(key=lambda x: (
not x.lower().startswith(search_lower), # Exact prefix matches first
len(x), # Then by length (shorter first)
x.lower() # Then alphabetically
))
return matching_paths[:limit]

View File

@@ -1,131 +1,34 @@
import os
import logging
import asyncio
from typing import List, Dict, Optional, Set
import folder_paths # type: ignore
from typing import List
from ..utils.models import CheckpointMetadata
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
class CheckpointScanner(ModelScanner):
"""Service for scanning and managing checkpoint files"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not hasattr(self, '_initialized'):
# Define supported file extensions
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
super().__init__(
model_type="checkpoint",
model_class=CheckpointMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex()
)
self._checkpoint_roots = self._init_checkpoint_roots()
self._initialized = True
# Define supported file extensions
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
super().__init__(
model_type="checkpoint",
model_class=CheckpointMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex()
)
def adjust_metadata(self, metadata, file_path, root_path):
if hasattr(metadata, "model_type"):
if root_path in config.checkpoints_roots:
metadata.model_type = "checkpoint"
elif root_path in config.unet_roots:
metadata.model_type = "diffusion_model"
return metadata
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def _init_checkpoint_roots(self) -> List[str]:
"""Initialize checkpoint roots from ComfyUI settings"""
# Get both checkpoint and diffusion_models paths
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
diffusion_paths = folder_paths.get_folder_paths("diffusion_models")
# Combine, normalize and deduplicate paths
all_paths = set()
for path in checkpoint_paths + diffusion_paths:
if os.path.exists(path):
norm_path = path.replace(os.sep, "/")
all_paths.add(norm_path)
# Sort for consistent order
sorted_paths = sorted(all_paths, key=lambda p: p.lower())
return sorted_paths
def get_model_roots(self) -> List[str]:
"""Get checkpoint root directories"""
return self._checkpoint_roots
async def scan_all_models(self) -> List[Dict]:
"""Scan all checkpoint directories and return metadata"""
all_checkpoints = []
# Create scan tasks for each directory
scan_tasks = []
for root in self._checkpoint_roots:
task = asyncio.create_task(self._scan_directory(root))
scan_tasks.append(task)
# Wait for all tasks to complete
for task in scan_tasks:
try:
checkpoints = await task
all_checkpoints.extend(checkpoints)
except Exception as e:
logger.error(f"Error scanning checkpoint directory: {e}")
return all_checkpoints
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a directory for checkpoint files"""
checkpoints = []
original_root = root_path
async def scan_recursive(path: str, visited_paths: set):
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
# Check if file has supported extension
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, checkpoints)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return checkpoints
async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
"""Process a single checkpoint file and add to results"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
checkpoints.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
return config.base_models_roots

View File

@@ -0,0 +1,51 @@
import os
import logging
from typing import Dict, List, Optional
from .base_model_service import BaseModelService
from ..utils.models import CheckpointMetadata
from ..config import config
from ..utils.routes_common import ModelRouteUtils
logger = logging.getLogger(__name__)
class CheckpointService(BaseModelService):
"""Checkpoint-specific service implementation"""
def __init__(self, scanner):
"""Initialize Checkpoint service
Args:
scanner: Checkpoint scanner instance
"""
super().__init__("checkpoint", scanner, CheckpointMetadata)
async def format_response(self, checkpoint_data: Dict) -> Dict:
"""Format Checkpoint data for API response"""
return {
"model_name": checkpoint_data["model_name"],
"file_name": checkpoint_data["file_name"],
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
"base_model": checkpoint_data.get("base_model", ""),
"folder": checkpoint_data["folder"],
"sha256": checkpoint_data.get("sha256", ""),
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
"file_size": checkpoint_data.get("size", 0),
"modified": checkpoint_data.get("modified", ""),
"tags": checkpoint_data.get("tags", []),
"modelDescription": checkpoint_data.get("modelDescription", ""),
"from_civitai": checkpoint_data.get("from_civitai", True),
"notes": checkpoint_data.get("notes", ""),
"model_type": checkpoint_data.get("model_type", "checkpoint"),
"favorite": checkpoint_data.get("favorite", False),
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}))
}
def find_duplicate_hashes(self) -> Dict:
"""Find Checkpoints with duplicate SHA256 hashes"""
return self.scanner._hash_index.get_duplicate_hashes()
def find_duplicate_filenames(self) -> Dict:
"""Find Checkpoints with conflicting filenames"""
return self.scanner._hash_index.get_duplicate_filenames()

View File

@@ -1,13 +1,11 @@
from datetime import datetime
import aiohttp
import os
import json
import logging
import asyncio
from email.parser import Parser
from typing import Optional, Dict, Tuple, List
from urllib.parse import unquote
from ..utils.models import LoraMetadata
logger = logging.getLogger(__name__)
@@ -35,8 +33,8 @@ class CivitaiClient:
}
self._session = None
self._session_created_at = None
# Set default buffer size to 1MB for higher throughput
self.chunk_size = 1024 * 1024
# Adjust chunk size based on storage type - consider making this configurable
self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better HDD throughput
@property
async def session(self) -> aiohttp.ClientSession:
@@ -45,14 +43,14 @@ class CivitaiClient:
# Optimize TCP connection parameters
connector = aiohttp.TCPConnector(
ssl=True,
limit=3, # Further reduced from 5 to 3
ttl_dns_cache=0, # Disabled DNS caching completely
limit=8, # Increase from 3 to 8 for better parallelism
ttl_dns_cache=300, # Enable DNS caching with reasonable timeout
force_close=False, # Keep connections for reuse
enable_cleanup_closed=True
)
trust_env = True # Allow using system environment proxy settings
# Configure timeout parameters
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
# Configure timeout parameters - increase read timeout for large files and remove sock_read timeout
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=None)
self._session = aiohttp.ClientSession(
connector=connector,
trust_env=trust_env,
@@ -104,7 +102,7 @@ class CivitaiClient:
return headers
async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
"""Download file with content-disposition support and progress tracking
"""Download file with resumable downloads and retry mechanism
Args:
url: Download URL
@@ -115,73 +113,190 @@ class CivitaiClient:
Returns:
Tuple[bool, str]: (success, save_path or error message)
"""
logger.debug(f"Resolving DNS for: {url}")
max_retries = 5
retry_count = 0
base_delay = 2.0 # Base delay for exponential backoff
# Initial setup
session = await self._ensure_fresh_session()
try:
headers = self._get_request_headers()
# Add Range header to allow resumable downloads
headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads
logger.debug(f"Starting download from: {url}")
async with session.get(url, headers=headers, allow_redirects=True) as response:
if response.status != 200:
# Handle 401 unauthorized responses
if response.status == 401:
save_path = os.path.join(save_dir, default_filename)
part_path = save_path + '.part'
# Get existing file size for resume
resume_offset = 0
if os.path.exists(part_path):
resume_offset = os.path.getsize(part_path)
logger.info(f"Resuming download from offset {resume_offset} bytes")
total_size = 0
filename = default_filename
while retry_count <= max_retries:
try:
headers = self._get_request_headers()
# Add Range header for resume if we have partial data
if resume_offset > 0:
headers['Range'] = f'bytes={resume_offset}-'
# Add Range header to allow resumable downloads
headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads
logger.debug(f"Download attempt {retry_count + 1}/{max_retries + 1} from: {url}")
if resume_offset > 0:
logger.debug(f"Requesting range from byte {resume_offset}")
async with session.get(url, headers=headers, allow_redirects=True) as response:
# Handle different response codes
if response.status == 200:
# Full content response
if resume_offset > 0:
# Server doesn't support ranges, restart from beginning
logger.warning("Server doesn't support range requests, restarting download")
resume_offset = 0
if os.path.exists(part_path):
os.remove(part_path)
elif response.status == 206:
# Partial content response (resume successful)
content_range = response.headers.get('Content-Range')
if content_range:
# Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048")
range_parts = content_range.split('/')
if len(range_parts) == 2:
total_size = int(range_parts[1])
logger.info(f"Successfully resumed download from byte {resume_offset}")
elif response.status == 416:
# Range not satisfiable - file might be complete or corrupted
if os.path.exists(part_path):
part_size = os.path.getsize(part_path)
logger.warning(f"Range not satisfiable. Part file size: {part_size}")
# Try to get actual file size
head_response = await session.head(url, headers=self._get_request_headers())
if head_response.status == 200:
actual_size = int(head_response.headers.get('content-length', 0))
if part_size == actual_size:
# File is complete, just rename it
os.rename(part_path, save_path)
if progress_callback:
await progress_callback(100)
return True, save_path
# Remove corrupted part file and restart
os.remove(part_path)
resume_offset = 0
continue
elif response.status == 401:
logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
return False, "Invalid or missing CivitAI API key, or early access restriction."
# Handle other client errors that might be permission-related
if response.status == 403:
elif response.status == 403:
logger.warning(f"Forbidden access to resource: {url} (Status 403)")
return False, "Access forbidden: You don't have permission to download this file."
else:
logger.error(f"Download failed for {url} with status {response.status}")
return False, f"Download failed with status {response.status}"
# Get filename from content-disposition header (only on first attempt)
if retry_count == 0:
content_disposition = response.headers.get('Content-Disposition')
parsed_filename = self._parse_content_disposition(content_disposition)
if parsed_filename:
filename = parsed_filename
# Update paths with correct filename
save_path = os.path.join(save_dir, filename)
new_part_path = save_path + '.part'
# Rename existing part file if filename changed
if part_path != new_part_path and os.path.exists(part_path):
os.rename(part_path, new_part_path)
part_path = new_part_path
# Generic error response for other status codes
logger.error(f"Download failed for {url} with status {response.status}")
return False, f"Download failed with status {response.status}"
# Get total file size for progress calculation (if not set from Content-Range)
if total_size == 0:
total_size = int(response.headers.get('content-length', 0))
if response.status == 206:
# For partial content, add the offset to get total file size
total_size += resume_offset
# Get filename from content-disposition header
content_disposition = response.headers.get('Content-Disposition')
filename = self._parse_content_disposition(content_disposition)
if not filename:
filename = default_filename
save_path = os.path.join(save_dir, filename)
# Get total file size for progress calculation
total_size = int(response.headers.get('content-length', 0))
current_size = 0
last_progress_report_time = datetime.now()
current_size = resume_offset
last_progress_report_time = datetime.now()
# Stream download to file with progress updates using larger buffer
with open(save_path, 'wb') as f:
async for chunk in response.content.iter_chunked(self.chunk_size):
if chunk:
f.write(chunk)
current_size += len(chunk)
# Limit progress update frequency to reduce overhead
now = datetime.now()
time_diff = (now - last_progress_report_time).total_seconds()
if progress_callback and total_size and time_diff >= 0.5:
progress = (current_size / total_size) * 100
await progress_callback(progress)
last_progress_report_time = now
# Ensure 100% progress is reported
if progress_callback:
await progress_callback(100)
# Stream download to file with progress updates using larger buffer
loop = asyncio.get_running_loop()
mode = 'ab' if resume_offset > 0 else 'wb'
with open(part_path, mode) as f:
async for chunk in response.content.iter_chunked(self.chunk_size):
if chunk:
# Run blocking file write in executor
await loop.run_in_executor(None, f.write, chunk)
current_size += len(chunk)
# Limit progress update frequency to reduce overhead
now = datetime.now()
time_diff = (now - last_progress_report_time).total_seconds()
if progress_callback and total_size and time_diff >= 1.0:
progress = (current_size / total_size) * 100
await progress_callback(progress)
last_progress_report_time = now
# Download completed successfully
# Verify file size if total_size was provided
final_size = os.path.getsize(part_path)
if total_size > 0 and final_size != total_size:
logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}")
# Don't treat this as fatal error, rename anyway
# Atomically rename .part to final file with retries
max_rename_attempts = 5
rename_attempt = 0
rename_success = False
while rename_attempt < max_rename_attempts and not rename_success:
try:
os.rename(part_path, save_path)
rename_success = True
except PermissionError as e:
rename_attempt += 1
if rename_attempt < max_rename_attempts:
logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})")
await asyncio.sleep(2) # Wait before retrying
else:
logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}")
return False, f"Failed to finalize download: {str(e)}"
# Ensure 100% progress is reported
if progress_callback:
await progress_callback(100)
return True, save_path
return True, save_path
except (aiohttp.ClientError, aiohttp.ClientPayloadError,
aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
retry_count += 1
logger.warning(f"Network error during download (attempt {retry_count}/{max_retries + 1}): {e}")
except aiohttp.ClientError as e:
logger.error(f"Network error during download: {e}")
return False, f"Network error: {str(e)}"
except Exception as e:
logger.error(f"Download error: {e}")
return False, str(e)
if retry_count <= max_retries:
# Calculate delay with exponential backoff
delay = base_delay * (2 ** (retry_count - 1))
logger.info(f"Retrying in {delay} seconds...")
await asyncio.sleep(delay)
# Update resume offset for next attempt
if os.path.exists(part_path):
resume_offset = os.path.getsize(part_path)
logger.info(f"Will resume from byte {resume_offset}")
# Refresh session to get new connection
await self.close()
session = await self._ensure_fresh_session()
continue
else:
logger.error(f"Max retries exceeded for download: {e}")
return False, f"Network error after {max_retries + 1} attempts: {str(e)}"
except Exception as e:
logger.error(f"Unexpected download error: {e}")
return False, str(e)
return False, f"Download failed after {max_retries + 1} attempts"
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
try:
@@ -224,6 +339,89 @@ class CivitaiClient:
except Exception as e:
logger.error(f"Error fetching model versions: {e}")
return None
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
"""Get specific model version with additional metadata
Args:
model_id: The Civitai model ID (optional if version_id is provided)
version_id: Optional specific version ID to retrieve
Returns:
Optional[Dict]: The model version data with additional fields or None if not found
"""
try:
session = await self._ensure_fresh_session()
headers = self._get_request_headers()
# Case 1: Only version_id is provided
if model_id is None and version_id is not None:
# First get the version info to extract model_id
async with session.get(f"{self.base_url}/model-versions/{version_id}", headers=headers) as response:
if response.status != 200:
return None
version = await response.json()
model_id = version.get('modelId')
if not model_id:
logger.error(f"No modelId found in version {version_id}")
return None
# Now get the model data for additional metadata
async with session.get(f"{self.base_url}/models/{model_id}") as response:
if response.status != 200:
return version # Return version without additional metadata
model_data = await response.json()
# Enrich version with model data
version['model']['description'] = model_data.get("description")
version['model']['tags'] = model_data.get("tags", [])
version['creator'] = model_data.get("creator")
return version
# Case 2: model_id is provided (with or without version_id)
elif model_id is not None:
# Step 1: Get model data to find version_id if not provided and get additional metadata
async with session.get(f"{self.base_url}/models/{model_id}") as response:
if response.status != 200:
return None
data = await response.json()
model_versions = data.get('modelVersions', [])
# Step 2: Determine the version_id to use
target_version_id = version_id
if target_version_id is None:
target_version_id = model_versions[0].get('id')
# Step 3: Get detailed version info using the version_id
async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
if response.status != 200:
return None
version = await response.json()
# Step 4: Enrich version_info with model data
# Add description and tags from model data
version['model']['description'] = data.get("description")
version['model']['tags'] = data.get("tags", [])
# Add creator from model data
version['creator'] = data.get("creator")
return version
# Case 3: Neither model_id nor version_id provided
else:
logger.error("Either model_id or version_id must be provided")
return None
except Exception as e:
logger.error(f"Error fetching model version: {e}")
return None
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from Civitai
@@ -346,3 +544,34 @@ class CivitaiClient:
except Exception as e:
logger.error(f"Error getting hash from Civitai: {e}")
return None
async def get_image_info(self, image_id: str) -> Optional[Dict]:
"""Fetch image information from Civitai API
Args:
image_id: The Civitai image ID
Returns:
Optional[Dict]: The image data or None if not found
"""
try:
session = await self._ensure_fresh_session()
headers = self._get_request_headers()
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
logger.debug(f"Fetching image info for ID: {image_id}")
async with session.get(url, headers=headers) as response:
if response.status == 200:
data = await response.json()
if data and "items" in data and len(data["items"]) > 0:
logger.debug(f"Successfully fetched image info for ID: {image_id}")
return data["items"][0]
logger.warning(f"No image found with ID: {image_id}")
return None
logger.error(f"Failed to fetch image info for ID: {image_id} (status {response.status})")
return None
except Exception as e:
error_msg = f"Error fetching image info: {e}"
logger.error(error_msg)
return None

View File

@@ -1,13 +1,15 @@
import logging
import os
import json
import asyncio
from typing import Optional, Dict, Any
from .civitai_client import CivitaiClient
from ..utils.models import LoraMetadata, CheckpointMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH
from collections import OrderedDict
import uuid
from typing import Dict
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry
from .settings_manager import settings
# Download to temporary file first
import tempfile
@@ -33,20 +35,16 @@ class DownloadManager:
self._initialized = True
self._civitai_client = None # Will be lazily initialized
# Add download management
self._active_downloads = OrderedDict() # download_id -> download_info
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
self._download_tasks = {} # download_id -> asyncio.Task
async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry"""
if self._civitai_client is None:
self._civitai_client = await ServiceRegistry.get_civitai_client()
return self._civitai_client
async def _get_lora_monitor(self):
"""Get the lora file monitor from registry"""
return await ServiceRegistry.get_lora_monitor()
async def _get_checkpoint_monitor(self):
"""Get the checkpoint file monitor from registry"""
return await ServiceRegistry.get_checkpoint_monitor()
async def _get_lora_scanner(self):
"""Get the lora scanner from registry"""
@@ -55,55 +53,218 @@ class DownloadManager:
async def _get_checkpoint_scanner(self):
"""Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner()
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
model_version_id: str = None, save_dir: str = None,
relative_path: str = '', progress_callback=None,
model_type: str = "lora") -> Dict:
"""Download model from Civitai
async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
save_dir: str = None, relative_path: str = '',
progress_callback=None, use_default_paths: bool = False,
download_id: str = None) -> Dict:
"""Download model from Civitai with task tracking and concurrency control
Args:
download_url: Direct download URL for the model
model_hash: SHA256 hash of the model
model_version_id: Civitai model version ID
save_dir: Directory to save the model to
model_id: Civitai model ID (optional if model_version_id is provided)
model_version_id: Civitai model version ID (optional if model_id is provided)
save_dir: Directory to save the model
relative_path: Relative path within save_dir
progress_callback: Callback function for progress updates
model_type: Type of model ('lora' or 'checkpoint')
use_default_paths: Flag to use default paths
download_id: Unique identifier for this download task
Returns:
Dict with download result
"""
# Validate that at least one identifier is provided
if not model_id and not model_version_id:
return {'success': False, 'error': 'Either model_id or model_version_id must be provided'}
# Use provided download_id or generate new one
task_id = download_id or str(uuid.uuid4())
# Register download task in tracking dict
self._active_downloads[task_id] = {
'model_id': model_id,
'model_version_id': model_version_id,
'progress': 0,
'status': 'queued'
}
# Create tracking task
download_task = asyncio.create_task(
self._download_with_semaphore(
task_id, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths
)
)
# Store task for tracking and cancellation
self._download_tasks[task_id] = download_task
try:
# Update save directory with relative path if provided
if relative_path:
save_dir = os.path.join(save_dir, relative_path)
# Create directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)
# Wait for download to complete
result = await download_task
result['download_id'] = task_id # Include download_id in result
return result
except asyncio.CancelledError:
return {'success': False, 'error': 'Download was cancelled', 'download_id': task_id}
finally:
# Clean up task reference
if task_id in self._download_tasks:
del self._download_tasks[task_id]
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
save_dir: str, relative_path: str,
progress_callback=None, use_default_paths: bool = False):
"""Execute download with semaphore to limit concurrency"""
# Update status to waiting
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'waiting'
# Wrap progress callback to track progress in active_downloads
original_callback = progress_callback
async def tracking_callback(progress):
if task_id in self._active_downloads:
self._active_downloads[task_id]['progress'] = progress
if original_callback:
await original_callback(progress)
# Acquire semaphore to limit concurrent downloads
try:
async with self._download_semaphore:
# Update status to downloading
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'downloading'
# Use original download implementation
try:
# Check for cancellation before starting
if asyncio.current_task().cancelled():
raise asyncio.CancelledError()
result = await self._execute_original_download(
model_id, model_version_id, save_dir,
relative_path, tracking_callback, use_default_paths,
task_id
)
# Update status based on result
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed'
if not result['success']:
self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error')
return result
except asyncio.CancelledError:
# Handle cancellation
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'cancelled'
logger.info(f"Download cancelled for task {task_id}")
raise
except Exception as e:
# Handle other errors
logger.error(f"Download error for task {task_id}: {str(e)}", exc_info=True)
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'failed'
self._active_downloads[task_id]['error'] = str(e)
return {'success': False, 'error': str(e)}
finally:
# Schedule cleanup of download record after delay
asyncio.create_task(self._cleanup_download_record(task_id))
async def _cleanup_download_record(self, task_id: str):
"""Keep completed downloads in history for a short time"""
await asyncio.sleep(600) # Keep for 10 minutes
if task_id in self._active_downloads:
del self._active_downloads[task_id]
async def _execute_original_download(self, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths,
download_id=None):
"""Wrapper for original download_from_civitai implementation"""
try:
# Check if model version already exists in library
if model_version_id is not None:
# Check both scanners
lora_scanner = await self._get_lora_scanner()
checkpoint_scanner = await self._get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
# Check lora scanner first
if await lora_scanner.check_model_version_exists(model_version_id):
return {'success': False, 'error': 'Model version already exists in lora library'}
# Check checkpoint scanner
if await checkpoint_scanner.check_model_version_exists(model_version_id):
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
# Check embedding scanner
if await embedding_scanner.check_model_version_exists(model_version_id):
return {'success': False, 'error': 'Model version already exists in embedding library'}
# Get civitai client
civitai_client = await self._get_civitai_client()
# Get version info based on the provided identifier
version_info = None
error_msg = None
if model_hash:
# Get model by hash
version_info = await civitai_client.get_model_by_hash(model_hash)
elif model_version_id:
# Use model version ID directly
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
elif download_url:
# Extract version ID from download URL
version_id = download_url.split('/')[-1]
version_info, error_msg = await civitai_client.get_model_version_info(version_id)
version_info = await civitai_client.get_model_version(model_id, model_version_id)
if not version_info:
if error_msg and "model not found" in error_msg.lower():
return {'success': False, 'error': f'Model not found on Civitai: {error_msg}'}
return {'success': False, 'error': error_msg or 'Failed to fetch model metadata'}
return {'success': False, 'error': 'Failed to fetch model metadata'}
model_type_from_info = version_info.get('model', {}).get('type', '').lower()
if model_type_from_info == 'checkpoint':
model_type = 'checkpoint'
elif model_type_from_info in VALID_LORA_TYPES:
model_type = 'lora'
elif model_type_from_info == 'textualinversion':
model_type = 'embedding'
else:
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
# Case 2: model_version_id was None, check after getting version_info
if model_version_id is None:
version_id = version_info.get('id')
if model_type == 'lora':
# Check lora scanner
lora_scanner = await self._get_lora_scanner()
if await lora_scanner.check_model_version_exists(version_id):
return {'success': False, 'error': 'Model version already exists in lora library'}
elif model_type == 'checkpoint':
# Check checkpoint scanner
checkpoint_scanner = await self._get_checkpoint_scanner()
if await checkpoint_scanner.check_model_version_exists(version_id):
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
elif model_type == 'embedding':
# Embeddings are not checked in scanners, but we can still check if it exists
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
if await embedding_scanner.check_model_version_exists(version_id):
return {'success': False, 'error': 'Model version already exists in embedding library'}
# Handle use_default_paths
if use_default_paths:
# Set save_dir based on model type
if model_type == 'checkpoint':
default_path = settings.get('default_checkpoint_root')
if not default_path:
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
save_dir = default_path
elif model_type == 'lora':
default_path = settings.get('default_lora_root')
if not default_path:
return {'success': False, 'error': 'Default lora root path not set in settings'}
save_dir = default_path
elif model_type == 'embedding':
default_path = settings.get('default_embedding_root')
if not default_path:
return {'success': False, 'error': 'Default embedding root path not set in settings'}
save_dir = default_path
# Calculate relative path using template
relative_path = self._calculate_relative_path(version_info, model_type)
# Update save directory with relative path if provided
if relative_path:
save_dir = os.path.join(save_dir, relative_path)
# Create directory if it doesn't exist
os.makedirs(save_dir, exist_ok=True)
# Check if this is an early access model
if version_info.get('earlyAccessEndsAt'):
@@ -113,9 +274,9 @@ class DownloadManager:
from datetime import datetime
date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00'))
formatted_date = date_obj.strftime('%Y-%m-%d')
early_access_msg = f"This model requires early access payment (until {formatted_date}). "
early_access_msg = f"This model requires payment (until {formatted_date}). "
except:
early_access_msg = "This model requires early access payment. "
early_access_msg = "This model requires payment. "
early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai."
logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}")
@@ -137,28 +298,16 @@ class DownloadManager:
file_name = file_info['name']
save_path = os.path.join(save_dir, file_name)
# 4. Notify file monitor - use normalized path and file size
# file monitor is despreted, so we don't need to use it
# 5. Prepare metadata based on model type
if model_type == "checkpoint":
metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating CheckpointMetadata for {file_name}")
else:
elif model_type == "lora":
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating LoraMetadata for {file_name}")
# 5.1 Get and update model tags, description and creator info
model_id = version_info.get('modelId')
if model_id:
model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
if model_metadata:
if model_metadata.get("tags"):
metadata.tags = model_metadata.get("tags", [])
if model_metadata.get("description"):
metadata.modelDescription = model_metadata.get("description", "")
if model_metadata.get("creator"):
metadata.civitai["creator"] = model_metadata.get("creator")
elif model_type == "embedding":
metadata = EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path)
logger.info(f"Creating EmbeddingMetadata for {file_name}")
# 6. Start download process
result = await self._execute_download(
@@ -168,9 +317,14 @@ class DownloadManager:
version_info=version_info,
relative_path=relative_path,
progress_callback=progress_callback,
model_type=model_type
model_type=model_type,
download_id=download_id
)
# If early_access_msg exists and download failed, replace error message
if 'early_access_msg' in locals() and not result.get('success', False):
result['error'] = early_access_msg
return result
except Exception as e:
@@ -181,15 +335,74 @@ class DownloadManager:
return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."}
return {'success': False, 'error': str(e)}
def _calculate_relative_path(self, version_info: Dict, model_type: str = 'lora') -> str:
"""Calculate relative path using template from settings
Args:
version_info: Version info from Civitai API
model_type: Type of model ('lora', 'checkpoint', 'embedding')
Returns:
Relative path string
"""
# Get path template from settings for specific model type
path_template = settings.get_download_path_template(model_type)
# If template is empty, return empty path (flat structure)
if not path_template:
return ''
# Get base model name
base_model = version_info.get('baseModel', '')
# Get author from creator data
creator_info = version_info.get('creator')
if creator_info and isinstance(creator_info, dict):
author = creator_info.get('username') or 'Anonymous'
else:
author = 'Anonymous'
# Apply mapping if available
base_model_mappings = settings.get('base_model_path_mappings', {})
mapped_base_model = base_model_mappings.get(base_model, base_model)
# Get model tags
model_tags = version_info.get('model', {}).get('tags', [])
# Find the first Civitai model tag that exists in model_tags
first_tag = ''
for civitai_tag in CIVITAI_MODEL_TAGS:
if civitai_tag in model_tags:
first_tag = civitai_tag
break
# If no Civitai model tag found, fallback to first tag
if not first_tag and model_tags:
first_tag = model_tags[0]
# Format the template with available data
formatted_path = path_template
formatted_path = formatted_path.replace('{base_model}', mapped_base_model)
formatted_path = formatted_path.replace('{first_tag}', first_tag)
formatted_path = formatted_path.replace('{author}', author)
return formatted_path
async def _execute_download(self, download_url: str, save_dir: str,
metadata, version_info: Dict,
relative_path: str, progress_callback=None,
model_type: str = "lora") -> Dict:
model_type: str = "lora", download_id: str = None) -> Dict:
"""Execute the actual download process including preview images and model files"""
try:
civitai_client = await self._get_civitai_client()
save_path = metadata.file_path
part_path = save_path + '.part'
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
# Store file paths in active_downloads for potential cleanup
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]['file_path'] = save_path
self._active_downloads[download_id]['part_path'] = part_path
# Download preview image if available
images = version_info.get('images', [])
@@ -210,8 +423,6 @@ class DownloadManager:
if await civitai_client.download_preview_image(images[0]['url'], preview_path):
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
else:
# For images, use WebP format for better performance
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
@@ -238,8 +449,6 @@ class DownloadManager:
# Update metadata
metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
# Remove temporary file
try:
@@ -260,38 +469,46 @@ class DownloadManager:
)
if not success:
# Clean up files on failure
for path in [save_path, metadata_path, metadata.preview_url]:
# Clean up files on failure, but preserve .part file for resume
cleanup_files = [metadata_path]
if metadata.preview_url and os.path.exists(metadata.preview_url):
cleanup_files.append(metadata.preview_url)
for path in cleanup_files:
if path and os.path.exists(path):
os.remove(path)
try:
os.remove(path)
except Exception as e:
logger.warning(f"Failed to cleanup file {path}: {e}")
# Log but don't remove .part file to allow resume
if os.path.exists(part_path):
logger.info(f"Preserving partial download for resume: {part_path}")
return {'success': False, 'error': result}
# 4. Update file information (size and modified time)
metadata.update_file_info(save_path)
# 5. Final metadata update
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata.to_dict(), f, indent=2, ensure_ascii=False)
await MetadataManager.save_metadata(save_path, metadata, True)
# 6. Update cache based on model type
if model_type == "checkpoint":
scanner = await self._get_checkpoint_scanner()
logger.info(f"Updating checkpoint cache for {save_path}")
else:
elif model_type == "lora":
scanner = await self._get_lora_scanner()
logger.info(f"Updating lora cache for {save_path}")
elif model_type == "embedding":
scanner = await ServiceRegistry.get_embedding_scanner()
logger.info(f"Updating embedding cache for {save_path}")
cache = await scanner.get_cached_data()
# Convert metadata to dictionary
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = relative_path
cache.raw_data.append(metadata_dict)
await cache.resort()
all_folders = set(cache.folders)
all_folders.add(relative_path)
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Update the hash index with the new model entry
scanner._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
# Add model to cache and save to disk in a single operation
await scanner.add_model_to_cache(metadata_dict, relative_path)
# Report 100% completion
if progress_callback:
@@ -303,10 +520,18 @@ class DownloadManager:
except Exception as e:
logger.error(f"Error in _execute_download: {e}", exc_info=True)
# Clean up partial downloads
for path in [save_path, metadata_path]:
# Clean up partial downloads except .part file
cleanup_files = [metadata_path]
if hasattr(metadata, 'preview_url') and metadata.preview_url and os.path.exists(metadata.preview_url):
cleanup_files.append(metadata.preview_url)
for path in cleanup_files:
if path and os.path.exists(path):
os.remove(path)
try:
os.remove(path)
except Exception as e:
logger.warning(f"Failed to cleanup file {path}: {e}")
return {'success': False, 'error': str(e)}
async def _handle_download_progress(self, file_progress: float, progress_callback):
@@ -319,4 +544,99 @@ class DownloadManager:
if progress_callback:
# Scale file progress to 3-100 range (after preview download)
overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download
await progress_callback(round(overall_progress))
await progress_callback(round(overall_progress))
async def cancel_download(self, download_id: str) -> Dict:
"""Cancel an active download by download_id
Args:
download_id: The unique identifier of the download task
Returns:
Dict: Status of the cancellation operation
"""
if download_id not in self._download_tasks:
return {'success': False, 'error': 'Download task not found'}
try:
# Get the task and cancel it
task = self._download_tasks[download_id]
task.cancel()
# Update status in active downloads
if download_id in self._active_downloads:
self._active_downloads[download_id]['status'] = 'cancelling'
# Wait briefly for the task to acknowledge cancellation
try:
await asyncio.wait_for(asyncio.shield(task), timeout=2.0)
except (asyncio.CancelledError, asyncio.TimeoutError):
pass
# Clean up ALL files including .part when user cancels
download_info = self._active_downloads.get(download_id)
if download_info:
# Delete the main file
if 'file_path' in download_info:
file_path = download_info['file_path']
if os.path.exists(file_path):
try:
os.unlink(file_path)
logger.debug(f"Deleted cancelled download: {file_path}")
except Exception as e:
logger.error(f"Error deleting file: {e}")
# Delete the .part file (only on user cancellation)
if 'part_path' in download_info:
part_path = download_info['part_path']
if os.path.exists(part_path):
try:
os.unlink(part_path)
logger.debug(f"Deleted partial download: {part_path}")
except Exception as e:
logger.error(f"Error deleting part file: {e}")
# Delete metadata file if exists
if 'file_path' in download_info:
file_path = download_info['file_path']
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
try:
os.unlink(metadata_path)
except Exception as e:
logger.error(f"Error deleting metadata file: {e}")
# Delete preview file if exists (.webp or .mp4)
for preview_ext in ['.webp', '.mp4']:
preview_path = os.path.splitext(file_path)[0] + preview_ext
if os.path.exists(preview_path):
try:
os.unlink(preview_path)
logger.debug(f"Deleted preview file: {preview_path}")
except Exception as e:
logger.error(f"Error deleting preview file: {e}")
return {'success': True, 'message': 'Download cancelled successfully'}
except Exception as e:
logger.error(f"Error cancelling download: {e}", exc_info=True)
return {'success': False, 'error': str(e)}
async def get_active_downloads(self) -> Dict:
"""Get information about all active downloads
Returns:
Dict: List of active downloads and their status
"""
return {
'downloads': [
{
'download_id': task_id,
'model_id': info.get('model_id'),
'model_version_id': info.get('model_version_id'),
'progress': info.get('progress', 0),
'status': info.get('status', 'unknown'),
'error': info.get('error', None)
}
for task_id, info in self._active_downloads.items()
]
}

View File

@@ -0,0 +1,26 @@
import logging
from typing import List
from ..utils.models import EmbeddingMetadata
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex
logger = logging.getLogger(__name__)
class EmbeddingScanner(ModelScanner):
"""Service for scanning and managing embedding files"""
def __init__(self):
# Define supported file extensions
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
super().__init__(
model_type="embedding",
model_class=EmbeddingMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex()
)
def get_model_roots(self) -> List[str]:
"""Get embedding root directories"""
return config.embeddings_roots

View File

@@ -0,0 +1,51 @@
import os
import logging
from typing import Dict, List, Optional
from .base_model_service import BaseModelService
from ..utils.models import EmbeddingMetadata
from ..config import config
from ..utils.routes_common import ModelRouteUtils
logger = logging.getLogger(__name__)
class EmbeddingService(BaseModelService):
"""Embedding-specific service implementation"""
def __init__(self, scanner):
"""Initialize Embedding service
Args:
scanner: Embedding scanner instance
"""
super().__init__("embedding", scanner, EmbeddingMetadata)
async def format_response(self, embedding_data: Dict) -> Dict:
"""Format Embedding data for API response"""
return {
"model_name": embedding_data["model_name"],
"file_name": embedding_data["file_name"],
"preview_url": config.get_preview_static_url(embedding_data.get("preview_url", "")),
"preview_nsfw_level": embedding_data.get("preview_nsfw_level", 0),
"base_model": embedding_data.get("base_model", ""),
"folder": embedding_data["folder"],
"sha256": embedding_data.get("sha256", ""),
"file_path": embedding_data["file_path"].replace(os.sep, "/"),
"file_size": embedding_data.get("size", 0),
"modified": embedding_data.get("modified", ""),
"tags": embedding_data.get("tags", []),
"modelDescription": embedding_data.get("modelDescription", ""),
"from_civitai": embedding_data.get("from_civitai", True),
"notes": embedding_data.get("notes", ""),
"model_type": embedding_data.get("model_type", "embedding"),
"favorite": embedding_data.get("favorite", False),
"civitai": ModelRouteUtils.filter_civitai_data(embedding_data.get("civitai", {}))
}
def find_duplicate_hashes(self) -> Dict:
"""Find Embeddings with duplicate SHA256 hashes"""
return self.scanner._hash_index.get_duplicate_hashes()
def find_duplicate_filenames(self) -> Dict:
"""Find Embeddings with conflicting filenames"""
return self.scanner._hash_index.get_duplicate_filenames()

View File

@@ -1,542 +0,0 @@
import os
import logging
import asyncio
import time
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
from typing import List, Dict, Set, Optional
from threading import Lock
from ..config import config
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
# Configuration constant to control file monitoring functionality
ENABLE_FILE_MONITORING = False
class BaseFileHandler(FileSystemEventHandler):
"""Base handler for file system events"""
def __init__(self, loop: asyncio.AbstractEventLoop):
self.loop = loop # Store event loop reference
self.pending_changes = set() # Pending changes
self.lock = Lock() # Thread-safe lock
self.update_task = None # Async update task
self._ignore_paths = set() # Paths to ignore
self._min_ignore_timeout = 5 # Minimum timeout in seconds
self._download_speed = 1024 * 1024 # Assume 1MB/s as base speed
# Track modified files with timestamps for debouncing
self.modified_files: Dict[str, float] = {}
self.debounce_timer = None
self.debounce_delay = 3.0 # Seconds to wait after last modification
# Track files already scheduled for processing
self.scheduled_files: Set[str] = set()
# File extensions to monitor - should be overridden by subclasses
self.file_extensions = set()
def _should_ignore(self, path: str) -> bool:
"""Check if path should be ignored"""
real_path = os.path.realpath(path) # Resolve any symbolic links
return real_path.replace(os.sep, '/') in self._ignore_paths
def add_ignore_path(self, path: str, file_size: int = 0):
"""Add path to ignore list with dynamic timeout based on file size"""
real_path = os.path.realpath(path) # Resolve any symbolic links
self._ignore_paths.add(real_path.replace(os.sep, '/'))
# Short timeout (e.g. 5 seconds) is sufficient to ignore the CREATE event
timeout = 5
self.loop.call_later(
timeout,
self._ignore_paths.discard,
real_path.replace(os.sep, '/')
)
def on_created(self, event):
if event.is_directory:
return
# Handle appropriate files based on extensions
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
# Process this file directly and ignore subsequent modifications
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
if normalized_path not in self.scheduled_files:
logger.info(f"File created: {event.src_path}")
self.scheduled_files.add(normalized_path)
self._schedule_update('add', event.src_path)
# Ignore modifications for a short period after creation
self.loop.call_later(
self.debounce_delay * 2,
self.scheduled_files.discard,
normalized_path
)
def on_modified(self, event):
if event.is_directory:
return
# Only process files with supported extensions
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
# Skip if this file is already scheduled for processing
if normalized_path in self.scheduled_files:
return
# Update the timestamp for this file
self.modified_files[normalized_path] = time.time()
# Cancel any existing timer
if self.debounce_timer:
self.debounce_timer.cancel()
# Set a new timer to process modified files after debounce period
self.debounce_timer = self.loop.call_later(
self.debounce_delay,
self.loop.call_soon_threadsafe,
self._process_modified_files
)
def _process_modified_files(self):
"""Process files that have been modified after debounce period"""
current_time = time.time()
files_to_process = []
# Find files that haven't been modified for debounce_delay seconds
for file_path, last_modified in list(self.modified_files.items()):
if current_time - last_modified >= self.debounce_delay:
# Only process if not already scheduled
if file_path not in self.scheduled_files:
files_to_process.append(file_path)
self.scheduled_files.add(file_path)
# Auto-remove from scheduled list after reasonable time
self.loop.call_later(
self.debounce_delay * 2,
self.scheduled_files.discard,
file_path
)
del self.modified_files[file_path]
# Process stable files
for file_path in files_to_process:
logger.info(f"Processing modified file: {file_path}")
self._schedule_update('add', file_path)
def on_deleted(self, event):
if event.is_directory:
return
file_ext = os.path.splitext(event.src_path)[1].lower()
if file_ext not in self.file_extensions:
return
if self._should_ignore(event.src_path):
return
# Remove from scheduled files if present
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
self.scheduled_files.discard(normalized_path)
logger.info(f"File deleted: {event.src_path}")
self._schedule_update('remove', event.src_path)
def on_moved(self, event):
"""Handle file move/rename events"""
src_ext = os.path.splitext(event.src_path)[1].lower()
dest_ext = os.path.splitext(event.dest_path)[1].lower()
# If destination has supported extension, treat as new file
if dest_ext in self.file_extensions:
if self._should_ignore(event.dest_path):
return
normalized_path = os.path.realpath(event.dest_path).replace(os.sep, '/')
# Only process if not already scheduled
if normalized_path not in self.scheduled_files:
logger.info(f"File renamed/moved to: {event.dest_path}")
self.scheduled_files.add(normalized_path)
self._schedule_update('add', event.dest_path)
# Auto-remove from scheduled list after reasonable time
self.loop.call_later(
self.debounce_delay * 2,
self.scheduled_files.discard,
normalized_path
)
# If source was a supported file, treat it as deleted
if src_ext in self.file_extensions:
if self._should_ignore(event.src_path):
return
normalized_path = os.path.realpath(event.src_path).replace(os.sep, '/')
self.scheduled_files.discard(normalized_path)
logger.info(f"File moved/renamed from: {event.src_path}")
self._schedule_update('remove', event.src_path)
def _schedule_update(self, action: str, file_path: str):
"""Schedule a cache update"""
with self.lock:
# Use config method to map path
mapped_path = config.map_path_to_link(file_path)
normalized_path = mapped_path.replace(os.sep, '/')
self.pending_changes.add((action, normalized_path))
self.loop.call_soon_threadsafe(self._create_update_task)
def _create_update_task(self):
"""Create update task in the event loop"""
if self.update_task is None or self.update_task.done():
self.update_task = asyncio.create_task(self._process_changes())
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing - should be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement _process_changes")
class LoraFileHandler(BaseFileHandler):
"""Handler for LoRA file system events"""
def __init__(self, loop: asyncio.AbstractEventLoop):
super().__init__(loop)
# Set supported file extensions for LoRAs
self.file_extensions = {'.safetensors'}
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing"""
await asyncio.sleep(delay)
try:
with self.lock:
changes = self.pending_changes.copy()
self.pending_changes.clear()
if not changes:
return
logger.info(f"Processing {len(changes)} LoRA file changes")
# Get scanner through ServiceRegistry
scanner = await ServiceRegistry.get_lora_scanner()
cache = await scanner.get_cached_data()
needs_resort = False
new_folders = set()
for action, file_path in changes:
try:
if action == 'add':
# Check if file already exists in cache
existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if existing:
logger.info(f"File {file_path} already in cache, skipping")
continue
# Scan new file
model_data = await scanner.scan_single_model(file_path)
if model_data:
# Update tags count
for tag in model_data.get('tags', []):
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
cache.raw_data.append(model_data)
new_folders.add(model_data['folder'])
# Update hash index
if 'sha256' in model_data:
scanner._hash_index.add_entry(
model_data['sha256'],
model_data['file_path']
)
needs_resort = True
elif action == 'remove':
# Find the model to remove so we can update tags count
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if model_to_remove:
# Update tags count by reducing counts
for tag in model_to_remove.get('tags', []):
if tag in scanner._tags_count:
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
if scanner._tags_count[tag] == 0:
del scanner._tags_count[tag]
# Remove from cache and hash index
logger.info(f"Removing {file_path} from cache")
scanner._hash_index.remove_by_path(file_path)
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != file_path
]
needs_resort = True
except Exception as e:
logger.error(f"Error processing {action} for {file_path}: {e}")
if needs_resort:
await cache.resort()
# Update folder list
all_folders = set(cache.folders) | new_folders
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
except Exception as e:
logger.error(f"Error in process_changes for LoRA: {e}")
class CheckpointFileHandler(BaseFileHandler):
"""Handler for checkpoint file system events"""
def __init__(self, loop: asyncio.AbstractEventLoop):
super().__init__(loop)
# Set supported file extensions for checkpoints
self.file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
async def _process_changes(self, delay: float = 2.0):
"""Process pending changes with debouncing for checkpoint files"""
await asyncio.sleep(delay)
try:
with self.lock:
changes = self.pending_changes.copy()
self.pending_changes.clear()
if not changes:
return
logger.info(f"Processing {len(changes)} checkpoint file changes")
# Get scanner through ServiceRegistry
scanner = await ServiceRegistry.get_checkpoint_scanner()
cache = await scanner.get_cached_data()
needs_resort = False
new_folders = set()
for action, file_path in changes:
try:
if action == 'add':
# Check if file already exists in cache
existing = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if existing:
logger.info(f"File {file_path} already in cache, skipping")
continue
# Scan new file
model_data = await scanner.scan_single_model(file_path)
if model_data:
# Update tags count if applicable
for tag in model_data.get('tags', []):
scanner._tags_count[tag] = scanner._tags_count.get(tag, 0) + 1
cache.raw_data.append(model_data)
new_folders.add(model_data['folder'])
# Update hash index
if 'sha256' in model_data:
scanner._hash_index.add_entry(
model_data['sha256'],
model_data['file_path']
)
needs_resort = True
elif action == 'remove':
# Find the model to remove so we can update tags count
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
if model_to_remove:
# Update tags count by reducing counts
for tag in model_to_remove.get('tags', []):
if tag in scanner._tags_count:
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
if scanner._tags_count[tag] == 0:
del scanner._tags_count[tag]
# Remove from cache and hash index
logger.info(f"Removing {file_path} from checkpoint cache")
scanner._hash_index.remove_by_path(file_path)
cache.raw_data = [
item for item in cache.raw_data
if item['file_path'] != file_path
]
needs_resort = True
except Exception as e:
logger.error(f"Error processing checkpoint {action} for {file_path}: {e}")
if needs_resort:
await cache.resort()
# Update folder list
all_folders = set(cache.folders) | new_folders
cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
except Exception as e:
logger.error(f"Error in process_changes for checkpoint: {e}")
class BaseFileMonitor:
"""Base class for file monitoring"""
def __init__(self, monitor_paths: List[str]):
self.observer = Observer()
self.loop = asyncio.get_event_loop()
self.monitor_paths = set()
# Process monitor paths
for path in monitor_paths:
self.monitor_paths.add(os.path.realpath(path).replace(os.sep, '/'))
# Add mapped paths from config
for target_path in config._path_mappings.keys():
self.monitor_paths.add(target_path)
def start(self):
"""Start file monitoring"""
if not ENABLE_FILE_MONITORING:
logger.debug("File monitoring is disabled via ENABLE_FILE_MONITORING setting")
return
for path in self.monitor_paths:
try:
self.observer.schedule(self.handler, path, recursive=True)
logger.info(f"Started monitoring: {path}")
except Exception as e:
logger.error(f"Error monitoring {path}: {e}")
self.observer.start()
def stop(self):
"""Stop file monitoring"""
if not ENABLE_FILE_MONITORING:
return
self.observer.stop()
self.observer.join()
def rescan_links(self):
"""Rescan links when new ones are added"""
if not ENABLE_FILE_MONITORING:
return
# Find new paths not yet being monitored
new_paths = set()
for path in config._path_mappings.keys():
real_path = os.path.realpath(path).replace(os.sep, '/')
if real_path not in self.monitor_paths:
new_paths.add(real_path)
self.monitor_paths.add(real_path)
# Add new paths to monitoring
for path in new_paths:
try:
self.observer.schedule(self.handler, path, recursive=True)
logger.info(f"Added new monitoring path: {path}")
except Exception as e:
logger.error(f"Error adding new monitor for {path}: {e}")
class LoraFileMonitor(BaseFileMonitor):
"""Monitor for LoRA file changes"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls, monitor_paths=None):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, monitor_paths=None):
if not hasattr(self, '_initialized'):
if monitor_paths is None:
from ..config import config
monitor_paths = config.loras_roots
super().__init__(monitor_paths)
self.handler = LoraFileHandler(self.loop)
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
from ..config import config
cls._instance = cls(config.loras_roots)
return cls._instance
class CheckpointFileMonitor(BaseFileMonitor):
"""Monitor for checkpoint file changes"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls, monitor_paths=None):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self, monitor_paths=None):
if not hasattr(self, '_initialized'):
if monitor_paths is None:
# Get checkpoint roots from scanner
monitor_paths = []
# We'll initialize monitor paths later when scanner is available
super().__init__(monitor_paths or [])
self.handler = CheckpointFileHandler(self.loop)
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls([])
# Now get checkpoint roots from scanner
from .checkpoint_scanner import CheckpointScanner
scanner = await CheckpointScanner.get_instance()
monitor_paths = scanner.get_model_roots()
# Update monitor paths - but don't actually monitor them
for path in monitor_paths:
real_path = os.path.realpath(path).replace(os.sep, '/')
cls._instance.monitor_paths.add(real_path)
return cls._instance
def start(self):
"""Override start to check global enable flag"""
if not ENABLE_FILE_MONITORING:
logger.debug("Checkpoint file monitoring is disabled via ENABLE_FILE_MONITORING setting")
return
logger.debug("Checkpoint file monitoring is temporarily disabled")
# Skip the actual monitoring setup
pass
async def initialize_paths(self):
"""Initialize monitor paths from scanner - currently disabled"""
if not ENABLE_FILE_MONITORING:
logger.debug("Checkpoint path initialization skipped (monitoring disabled)")
return
logger.debug("Checkpoint file path initialization skipped (monitoring disabled)")
pass

View File

@@ -1,65 +0,0 @@
import asyncio
from typing import List, Dict
from dataclasses import dataclass
from operator import itemgetter
from natsort import natsorted
@dataclass
class LoraCache:
"""Cache structure for LoRA data"""
raw_data: List[Dict]
sorted_by_name: List[Dict]
sorted_by_date: List[Dict]
folders: List[str]
def __post_init__(self):
self._lock = asyncio.Lock()
async def resort(self, name_only: bool = False):
"""Resort all cached data views"""
async with self._lock:
self.sorted_by_name = natsorted(
self.raw_data,
key=lambda x: x['model_name'].lower() # Case-insensitive sort
)
if not name_only:
self.sorted_by_date = sorted(
self.raw_data,
key=itemgetter('modified'),
reverse=True
)
# Update folder list
all_folders = set(l['folder'] for l in self.raw_data)
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
"""Update preview_url for a specific lora in all cached data
Args:
file_path: The file path of the lora to update
preview_url: The new preview URL
Returns:
bool: True if the update was successful, False if the lora wasn't found
"""
async with self._lock:
# Update in raw_data
for item in self.raw_data:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
else:
return False # Lora not found
# Update in sorted lists (references to the same dict objects)
for item in self.sorted_by_name:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
for item in self.sorted_by_date:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
return True

View File

@@ -1,54 +0,0 @@
from typing import Dict, Optional
import logging
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class LoraHashIndex:
"""Index for mapping LoRA file hashes to their file paths"""
def __init__(self):
self._hash_to_path: Dict[str, str] = {}
def add_entry(self, sha256: str, file_path: str) -> None:
"""Add or update a hash -> path mapping"""
if not sha256 or not file_path:
return
# Always store lowercase hashes for consistency
self._hash_to_path[sha256.lower()] = file_path
def remove_entry(self, sha256: str) -> None:
"""Remove a hash entry"""
if sha256:
self._hash_to_path.pop(sha256.lower(), None)
def remove_by_path(self, file_path: str) -> None:
"""Remove entry by file path"""
for sha256, path in list(self._hash_to_path.items()):
if path == file_path:
del self._hash_to_path[sha256]
break
def get_path(self, sha256: str) -> Optional[str]:
"""Get file path for a given hash"""
if not sha256:
return None
return self._hash_to_path.get(sha256.lower())
def get_hash(self, file_path: str) -> Optional[str]:
"""Get hash for a given file path"""
for sha256, path in self._hash_to_path.items():
if path == file_path:
return sha256
return None
def has_hash(self, sha256: str) -> bool:
"""Check if hash exists in index"""
if not sha256:
return False
return sha256.lower() in self._hash_to_path
def clear(self) -> None:
"""Clear all entries"""
self._hash_to_path.clear()

View File

@@ -1,20 +1,10 @@
import json
import os
import logging
import asyncio
import shutil
import time
import re
from typing import List, Dict, Optional, Set
from typing import List
from ..utils.models import LoraMetadata
from ..config import config
from .model_scanner import ModelScanner
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
from .settings_manager import settings
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match
from .service_registry import ServiceRegistry
import sys
logger = logging.getLogger(__name__)
@@ -22,430 +12,21 @@ logger = logging.getLogger(__name__)
class LoraScanner(ModelScanner):
"""Service for scanning and managing LoRA files"""
_instance = None
_lock = asyncio.Lock()
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
# Ensure initialization happens only once
if not hasattr(self, '_initialized'):
# Define supported file extensions
file_extensions = {'.safetensors'}
# Initialize parent class with ModelHashIndex
super().__init__(
model_type="lora",
model_class=LoraMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
)
self._initialized = True
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
# Define supported file extensions
file_extensions = {'.safetensors'}
# Initialize parent class with ModelHashIndex
super().__init__(
model_type="lora",
model_class=LoraMetadata,
file_extensions=file_extensions,
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
)
def get_model_roots(self) -> List[str]:
"""Get lora root directories"""
return config.loras_roots
async def scan_all_models(self) -> List[Dict]:
"""Scan all LoRA directories and return metadata"""
all_loras = []
# Create scan tasks for each directory
scan_tasks = []
for lora_root in self.get_model_roots():
task = asyncio.create_task(self._scan_directory(lora_root))
scan_tasks.append(task)
# Wait for all tasks to complete
for task in scan_tasks:
try:
loras = await task
all_loras.extend(loras)
except Exception as e:
logger.error(f"Error scanning directory: {e}")
return all_loras
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for LoRA files"""
loras = []
original_root = root_path # Save original root path
async def scan_recursive(path: str, visited_paths: set):
"""Recursively scan directory, avoiding circular symlinks"""
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
# Use original path instead of real path
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, loras)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return loras
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
"""Process a single file and add to results list"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
loras.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
folder: str = None, search: str = None, fuzzy_search: bool = False,
base_models: list = None, tags: list = None,
search_options: dict = None, hash_filters: dict = None,
favorites_only: bool = False, first_letter: str = None) -> Dict:
"""Get paginated and filtered lora data
Args:
page: Current page number (1-based)
page_size: Number of items per page
sort_by: Sort method ('name' or 'date')
folder: Filter by folder path
search: Search term
fuzzy_search: Use fuzzy matching for search
base_models: List of base models to filter by
tags: List of tags to filter by
search_options: Dictionary with search options (filename, modelname, tags, recursive)
hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes)
favorites_only: Filter for favorite models only
first_letter: Filter by first letter of model name
"""
cache = await self.get_cached_data()
# Get default search options if not provided
if search_options is None:
search_options = {
'filename': True,
'modelname': True,
'tags': False,
'recursive': False,
}
# Get the base data set
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
# Apply hash filtering if provided (highest priority)
if hash_filters:
single_hash = hash_filters.get('single_hash')
multiple_hashes = hash_filters.get('multiple_hashes')
if single_hash:
# Filter by single hash
single_hash = single_hash.lower() # Ensure lowercase for matching
filtered_data = [
lora for lora in filtered_data
if lora.get('sha256', '').lower() == single_hash
]
elif multiple_hashes:
# Filter by multiple hashes
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
filtered_data = [
lora for lora in filtered_data
if lora.get('sha256', '').lower() in hash_set
]
# Jump to pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
# Apply SFW filtering if enabled
if settings.get('show_only_sfw', False):
filtered_data = [
lora for lora in filtered_data
if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R']
]
# Apply favorites filtering if enabled
if favorites_only:
filtered_data = [
lora for lora in filtered_data
if lora.get('favorite', False) is True
]
# Apply first letter filtering
if first_letter:
filtered_data = self._filter_by_first_letter(filtered_data, first_letter)
# Apply folder filtering
if folder is not None:
if search_options.get('recursive', False):
# Recursive folder filtering - include all subfolders
filtered_data = [
lora for lora in filtered_data
if lora['folder'].startswith(folder)
]
else:
# Exact folder filtering
filtered_data = [
lora for lora in filtered_data
if lora['folder'] == folder
]
# Apply base model filtering
if base_models and len(base_models) > 0:
filtered_data = [
lora for lora in filtered_data
if lora.get('base_model') in base_models
]
# Apply tag filtering
if tags and len(tags) > 0:
filtered_data = [
lora for lora in filtered_data
if any(tag in lora.get('tags', []) for tag in tags)
]
# Apply search filtering
if search:
search_results = []
search_opts = search_options or {}
for lora in filtered_data:
# Search by file name
if search_opts.get('filename', True):
if fuzzy_match(lora.get('file_name', ''), search):
search_results.append(lora)
continue
# Search by model name
if search_opts.get('modelname', True):
if fuzzy_match(lora.get('model_name', ''), search):
search_results.append(lora)
continue
# Search by tags
if search_opts.get('tags', False) and 'tags' in lora:
if any(fuzzy_match(tag, search) for tag in lora['tags']):
search_results.append(lora)
continue
filtered_data = search_results
# Calculate pagination
total_items = len(filtered_data)
start_idx = (page - 1) * page_size
end_idx = min(start_idx + page_size, total_items)
result = {
'items': filtered_data[start_idx:end_idx],
'total': total_items,
'page': page,
'page_size': page_size,
'total_pages': (total_items + page_size - 1) // page_size
}
return result
def _filter_by_first_letter(self, data, letter):
"""Filter data by first letter of model name
Special handling:
- '#': Numbers (0-9)
- '@': Special characters (not alphanumeric)
- '': CJK characters
"""
filtered_data = []
for lora in data:
model_name = lora.get('model_name', '')
if not model_name:
continue
first_char = model_name[0].upper()
if letter == '#' and first_char.isdigit():
filtered_data.append(lora)
elif letter == '@' and not first_char.isalnum():
# Special characters (not alphanumeric)
filtered_data.append(lora)
elif letter == '' and self._is_cjk_character(first_char):
# CJK characters
filtered_data.append(lora)
elif letter.upper() == first_char:
# Regular alphabet matching
filtered_data.append(lora)
return filtered_data
def _is_cjk_character(self, char):
"""Check if character is a CJK character"""
# Define Unicode ranges for CJK characters
cjk_ranges = [
(0x4E00, 0x9FFF), # CJK Unified Ideographs
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
(0x3300, 0x33FF), # CJK Compatibility
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
(0x3100, 0x312F), # Bopomofo
(0x31A0, 0x31BF), # Bopomofo Extended
(0x3040, 0x309F), # Hiragana
(0x30A0, 0x30FF), # Katakana
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
(0xAC00, 0xD7AF), # Hangul Syllables
(0x1100, 0x11FF), # Hangul Jamo
(0xA960, 0xA97F), # Hangul Jamo Extended-A
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
]
code_point = ord(char)
return any(start <= code_point <= end for start, end in cjk_ranges)
async def get_letter_counts(self):
"""Get count of models for each letter of the alphabet"""
cache = await self.get_cached_data()
data = cache.sorted_by_name
# Define letter categories
letters = {
'#': 0, # Numbers
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
'Y': 0, 'Z': 0,
'@': 0, # Special characters
'': 0 # CJK characters
}
# Count models for each letter
for lora in data:
model_name = lora.get('model_name', '')
if not model_name:
continue
first_char = model_name[0].upper()
if first_char.isdigit():
letters['#'] += 1
elif first_char in letters:
letters[first_char] += 1
elif self._is_cjk_character(first_char):
letters[''] += 1
elif not first_char.isalnum():
letters['@'] += 1
return letters
async def _update_metadata_paths(self, metadata_path: str, lora_path: str) -> Dict:
"""Update file paths in metadata file"""
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
# Update file_path
metadata['file_path'] = lora_path.replace(os.sep, '/')
# Update preview_url if exists
if 'preview_url' in metadata:
preview_dir = os.path.dirname(lora_path)
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
preview_ext = os.path.splitext(metadata['preview_url'])[1]
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
return metadata
except Exception as e:
logger.error(f"Error updating metadata paths: {e}", exc_info=True)
# Lora-specific hash index functionality
def has_lora_hash(self, sha256: str) -> bool:
"""Check if a LoRA with given hash exists"""
return self.has_hash(sha256)
def get_lora_path_by_hash(self, sha256: str) -> Optional[str]:
"""Get file path for a LoRA by its hash"""
return self.get_path_by_hash(sha256)
def get_lora_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a LoRA by its file path"""
return self.get_hash_by_path(file_path)
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get top tags sorted by count"""
# Make sure cache is initialized
await self.get_cached_data()
# Sort tags by count in descending order
sorted_tags = sorted(
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
key=lambda x: x['count'],
reverse=True
)
# Return limited number
return sorted_tags[:limit]
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
"""Get base models used in loras sorted by frequency"""
# Make sure cache is initialized
cache = await self.get_cached_data()
# Count base model occurrences
base_model_counts = {}
for lora in cache.raw_data:
if 'base_model' in lora and lora['base_model']:
base_model = lora['base_model']
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
# Sort base models by count
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
sorted_models.sort(key=lambda x: x['count'], reverse=True)
# Return limited number
return sorted_models[:limit]
async def diagnose_hash_index(self):
"""Diagnostic method to verify hash index functionality"""
@@ -482,19 +63,3 @@ class LoraScanner(ModelScanner):
test_hash_result = self._hash_index.get_hash(test_path)
print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
async def get_lora_info_by_name(self, name):
"""Get LoRA information by name"""
try:
# Get cached data
cache = await self.get_cached_data()
# Find the LoRA by name
for lora in cache.raw_data:
if lora.get("file_name") == name:
return lora
return None
except Exception as e:
logger.error(f"Error getting LoRA info by name: {e}", exc_info=True)
return None

182
py/services/lora_service.py Normal file
View File

@@ -0,0 +1,182 @@
import os
import logging
from typing import Dict, List, Optional
from .base_model_service import BaseModelService
from ..utils.models import LoraMetadata
from ..config import config
from ..utils.routes_common import ModelRouteUtils
logger = logging.getLogger(__name__)
class LoraService(BaseModelService):
"""LoRA-specific service implementation"""
def __init__(self, scanner):
"""Initialize LoRA service
Args:
scanner: LoRA scanner instance
"""
super().__init__("lora", scanner, LoraMetadata)
async def format_response(self, lora_data: Dict) -> Dict:
"""Format LoRA data for API response"""
return {
"model_name": lora_data["model_name"],
"file_name": lora_data["file_name"],
"preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
"base_model": lora_data.get("base_model", ""),
"folder": lora_data["folder"],
"sha256": lora_data.get("sha256", ""),
"file_path": lora_data["file_path"].replace(os.sep, "/"),
"file_size": lora_data.get("size", 0),
"modified": lora_data.get("modified", ""),
"tags": lora_data.get("tags", []),
"modelDescription": lora_data.get("modelDescription", ""),
"from_civitai": lora_data.get("from_civitai", True),
"usage_tips": lora_data.get("usage_tips", ""),
"notes": lora_data.get("notes", ""),
"favorite": lora_data.get("favorite", False),
"civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}))
}
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
"""Apply LoRA-specific filters"""
# Handle first_letter filter for LoRAs
first_letter = kwargs.get('first_letter')
if first_letter:
data = self._filter_by_first_letter(data, first_letter)
return data
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
"""Filter data by first letter of model name
Special handling:
- '#': Numbers (0-9)
- '@': Special characters (not alphanumeric)
- '': CJK characters
"""
filtered_data = []
for lora in data:
model_name = lora.get('model_name', '')
if not model_name:
continue
first_char = model_name[0].upper()
if letter == '#' and first_char.isdigit():
filtered_data.append(lora)
elif letter == '@' and not first_char.isalnum():
# Special characters (not alphanumeric)
filtered_data.append(lora)
elif letter == '' and self._is_cjk_character(first_char):
# CJK characters
filtered_data.append(lora)
elif letter.upper() == first_char:
# Regular alphabet matching
filtered_data.append(lora)
return filtered_data
def _is_cjk_character(self, char: str) -> bool:
"""Check if character is a CJK character"""
# Define Unicode ranges for CJK characters
cjk_ranges = [
(0x4E00, 0x9FFF), # CJK Unified Ideographs
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
(0x3300, 0x33FF), # CJK Compatibility
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
(0x3100, 0x312F), # Bopomofo
(0x31A0, 0x31BF), # Bopomofo Extended
(0x3040, 0x309F), # Hiragana
(0x30A0, 0x30FF), # Katakana
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
(0xAC00, 0xD7AF), # Hangul Syllables
(0x1100, 0x11FF), # Hangul Jamo
(0xA960, 0xA97F), # Hangul Jamo Extended-A
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
]
code_point = ord(char)
return any(start <= code_point <= end for start, end in cjk_ranges)
# LoRA-specific methods
async def get_letter_counts(self) -> Dict[str, int]:
"""Get count of LoRAs for each letter of the alphabet"""
cache = await self.scanner.get_cached_data()
data = cache.raw_data
# Define letter categories
letters = {
'#': 0, # Numbers
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
'Y': 0, 'Z': 0,
'@': 0, # Special characters
'': 0 # CJK characters
}
# Count models for each letter
for lora in data:
model_name = lora.get('model_name', '')
if not model_name:
continue
first_char = model_name[0].upper()
if first_char.isdigit():
letters['#'] += 1
elif first_char in letters:
letters[first_char] += 1
elif self._is_cjk_character(first_char):
letters[''] += 1
elif not first_char.isalnum():
letters['@'] += 1
return letters
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
"""Get trigger words for a specific LoRA file"""
cache = await self.scanner.get_cached_data()
for lora in cache.raw_data:
if lora['file_name'] == lora_name:
civitai_data = lora.get('civitai', {})
return civitai_data.get('trainedWords', [])
return []
async def get_lora_usage_tips_by_relative_path(self, relative_path: str) -> Optional[str]:
"""Get usage tips for a LoRA by its relative path"""
cache = await self.scanner.get_cached_data()
for lora in cache.raw_data:
file_path = lora.get('file_path', '')
if file_path:
# Convert to forward slashes and extract relative path
file_path_normalized = file_path.replace('\\', '/')
# Find the relative path part by looking for the relative_path in the full path
if file_path_normalized.endswith(relative_path) or relative_path in file_path_normalized:
return lora.get('usage_tips', '')
return None
def find_duplicate_hashes(self) -> Dict:
"""Find LoRAs with duplicate SHA256 hashes"""
return self.scanner._hash_index.get_duplicate_hashes()
def find_duplicate_filenames(self) -> Dict:
"""Find LoRAs with conflicting filenames"""
return self.scanner._hash_index.get_duplicate_filenames()

View File

@@ -1,43 +1,92 @@
import asyncio
from typing import List, Dict
from typing import List, Dict, Tuple
from dataclasses import dataclass
from operator import itemgetter
from natsort import natsorted
# Supported sort modes: (sort_key, order)
# order: 'asc' for ascending, 'desc' for descending
SUPPORTED_SORT_MODES = [
('name', 'asc'),
('name', 'desc'),
('date', 'asc'),
('date', 'desc'),
('size', 'asc'),
('size', 'desc'),
]
@dataclass
class ModelCache:
"""Cache structure for model data"""
"""Cache structure for model data with extensible sorting"""
raw_data: List[Dict]
sorted_by_name: List[Dict]
sorted_by_date: List[Dict]
folders: List[str]
def __post_init__(self):
self._lock = asyncio.Lock()
# Cache for last sort: (sort_key, order) -> sorted list
self._last_sort: Tuple[str, str] = (None, None)
self._last_sorted_data: List[Dict] = []
# Default sort on init
asyncio.create_task(self.resort())
async def resort(self, name_only: bool = False):
"""Resort all cached data views"""
async def resort(self):
"""Resort cached data according to last sort mode if set"""
async with self._lock:
self.sorted_by_name = natsorted(
self.raw_data,
key=lambda x: x['model_name'].lower() # Case-insensitive sort
)
if not name_only:
self.sorted_by_date = sorted(
self.raw_data,
key=itemgetter('modified'),
reverse=True
)
# Update folder list
if self._last_sort != (None, None):
sort_key, order = self._last_sort
sorted_data = self._sort_data(self.raw_data, sort_key, order)
self._last_sorted_data = sorted_data
# Update folder list
# else: do nothing
all_folders = set(l['folder'] for l in self.raw_data)
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
"""Sort data by sort_key and order"""
reverse = (order == 'desc')
if sort_key == 'name':
# Natural sort by model_name, case-insensitive
return natsorted(
data,
key=lambda x: x['model_name'].lower(),
reverse=reverse
)
elif sort_key == 'date':
# Sort by modified timestamp
return sorted(
data,
key=itemgetter('modified'),
reverse=reverse
)
elif sort_key == 'size':
# Sort by file size
return sorted(
data,
key=itemgetter('size'),
reverse=reverse
)
else:
# Fallback: no sort
return list(data)
async def get_sorted_data(self, sort_key: str = 'name', order: str = 'asc') -> List[Dict]:
"""Get sorted data by sort_key and order, using cache if possible"""
async with self._lock:
if (sort_key, order) == self._last_sort:
return self._last_sorted_data
sorted_data = self._sort_data(self.raw_data, sort_key, order)
self._last_sort = (sort_key, order)
self._last_sorted_data = sorted_data
return sorted_data
async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
"""Update preview_url for a specific model in all cached data
Args:
file_path: The file path of the model to update
preview_url: The new preview URL
preview_nsfw_level: The NSFW level of the preview
Returns:
bool: True if the update was successful, False if the model wasn't found
@@ -47,19 +96,9 @@ class ModelCache:
for item in self.raw_data:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
item['preview_nsfw_level'] = preview_nsfw_level
break
else:
return False # Model not found
# Update in sorted lists (references to the same dict objects)
for item in self.sorted_by_name:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
for item in self.sorted_by_date:
if item['file_path'] == file_path:
item['preview_url'] = preview_url
break
return True

View File

@@ -1,12 +1,15 @@
from typing import Dict, Optional, Set
from typing import Dict, Optional, Set, List
import os
class ModelHashIndex:
"""Index for looking up models by hash or path"""
"""Index for looking up models by hash or filename"""
def __init__(self):
self._hash_to_path: Dict[str, str] = {}
self._filename_to_hash: Dict[str, str] = {} # Changed from path_to_hash to filename_to_hash
self._filename_to_hash: Dict[str, str] = {}
# New data structures for tracking duplicates
self._duplicate_hashes: Dict[str, List[str]] = {} # sha256 -> list of paths
self._duplicate_filenames: Dict[str, List[str]] = {} # filename -> list of paths
def add_entry(self, sha256: str, file_path: str) -> None:
"""Add or update hash index entry"""
@@ -19,18 +22,43 @@ class ModelHashIndex:
# Extract filename without extension
filename = self._get_filename_from_path(file_path)
# Track duplicates by hash
if sha256 in self._hash_to_path:
old_path = self._hash_to_path[sha256]
if old_path != file_path: # Only record if it's actually a different path
if sha256 not in self._duplicate_hashes:
self._duplicate_hashes[sha256] = [old_path]
if file_path not in self._duplicate_hashes.get(sha256, []):
self._duplicate_hashes.setdefault(sha256, []).append(file_path)
# Track duplicates by filename - FIXED LOGIC
if filename in self._filename_to_hash:
existing_hash = self._filename_to_hash[filename]
existing_path = self._hash_to_path.get(existing_hash)
# If this is a different file with the same filename
if existing_path and existing_path != file_path:
# Initialize duplicates tracking if needed
if filename not in self._duplicate_filenames:
self._duplicate_filenames[filename] = [existing_path]
# Add current file to duplicates if not already present
if file_path not in self._duplicate_filenames[filename]:
self._duplicate_filenames[filename].append(file_path)
# Remove old path mapping if hash exists
if sha256 in self._hash_to_path:
old_path = self._hash_to_path[sha256]
old_filename = self._get_filename_from_path(old_path)
if old_filename in self._filename_to_hash:
if old_filename in self._filename_to_hash and self._filename_to_hash[old_filename] == sha256:
del self._filename_to_hash[old_filename]
# Remove old hash mapping if filename exists
# Remove old hash mapping if filename exists and points to different hash
if filename in self._filename_to_hash:
old_hash = self._filename_to_hash[filename]
if old_hash in self._hash_to_path:
del self._hash_to_path[old_hash]
if old_hash != sha256 and old_hash in self._hash_to_path:
# Don't delete the old hash mapping, just update filename mapping
pass
# Add new mappings
self._hash_to_path[sha256] = file_path
@@ -40,24 +68,126 @@ class ModelHashIndex:
"""Extract filename without extension from path"""
return os.path.splitext(os.path.basename(file_path))[0]
def remove_by_path(self, file_path: str) -> None:
def remove_by_path(self, file_path: str, hash_val: str = None) -> None:
"""Remove entry by file path"""
filename = self._get_filename_from_path(file_path)
if filename in self._filename_to_hash:
hash_val = self._filename_to_hash[filename]
if hash_val in self._hash_to_path:
# Find the hash for this file path
if hash_val is None:
for h, p in self._hash_to_path.items():
if p == file_path:
hash_val = h
break
# If we didn't find a hash, nothing to do
if not hash_val:
return
# Update duplicates tracking for hash
if hash_val in self._duplicate_hashes:
# Remove the current path from duplicates
self._duplicate_hashes[hash_val] = [p for p in self._duplicate_hashes[hash_val] if p != file_path]
# Update or remove hash mapping based on remaining duplicates
if len(self._duplicate_hashes[hash_val]) > 0:
# Replace with one of the remaining paths
new_path = self._duplicate_hashes[hash_val][0]
new_filename = self._get_filename_from_path(new_path)
# Update hash-to-path mapping
self._hash_to_path[hash_val] = new_path
# IMPORTANT: Update filename-to-hash mapping for consistency
# Remove old filename mapping if it points to this hash
if filename in self._filename_to_hash and self._filename_to_hash[filename] == hash_val:
del self._filename_to_hash[filename]
# Add new filename mapping
self._filename_to_hash[new_filename] = hash_val
# If only one duplicate left, remove from duplicates tracking
if len(self._duplicate_hashes[hash_val]) == 1:
del self._duplicate_hashes[hash_val]
else:
# No duplicates left, remove hash entry completely
del self._duplicate_hashes[hash_val]
del self._hash_to_path[hash_val]
del self._filename_to_hash[filename]
# Remove corresponding filename entry if it points to this hash
if filename in self._filename_to_hash and self._filename_to_hash[filename] == hash_val:
del self._filename_to_hash[filename]
else:
# No duplicates, simply remove the hash entry
del self._hash_to_path[hash_val]
# Remove corresponding filename entry if it points to this hash
if filename in self._filename_to_hash and self._filename_to_hash[filename] == hash_val:
del self._filename_to_hash[filename]
# Update duplicates tracking for filename
if filename in self._duplicate_filenames:
# Remove the current path from duplicates
self._duplicate_filenames[filename] = [p for p in self._duplicate_filenames[filename] if p != file_path]
# Update or remove filename mapping based on remaining duplicates
if len(self._duplicate_filenames[filename]) > 0:
# Get the hash for the first remaining duplicate path
first_dup_path = self._duplicate_filenames[filename][0]
first_dup_hash = None
for h, p in self._hash_to_path.items():
if p == first_dup_path:
first_dup_hash = h
break
# Update the filename to hash mapping if we found a hash
if first_dup_hash:
self._filename_to_hash[filename] = first_dup_hash
# If only one duplicate left, remove from duplicates tracking
if len(self._duplicate_filenames[filename]) == 1:
del self._duplicate_filenames[filename]
else:
# No duplicates left, remove filename entry completely
del self._duplicate_filenames[filename]
if filename in self._filename_to_hash:
del self._filename_to_hash[filename]
def remove_by_hash(self, sha256: str) -> None:
"""Remove entry by hash"""
sha256 = sha256.lower()
if sha256 in self._hash_to_path:
path = self._hash_to_path[sha256]
filename = self._get_filename_from_path(path)
if filename in self._filename_to_hash:
del self._filename_to_hash[filename]
del self._hash_to_path[sha256]
if sha256 not in self._hash_to_path:
return
# Get the path and filename
path = self._hash_to_path[sha256]
filename = self._get_filename_from_path(path)
# Get all paths for this hash (including duplicates)
paths_to_remove = [path]
if sha256 in self._duplicate_hashes:
paths_to_remove.extend(self._duplicate_hashes[sha256])
del self._duplicate_hashes[sha256]
# Remove hash-to-path mapping
del self._hash_to_path[sha256]
# Update filename-to-hash and duplicate filenames for all paths
for path_to_remove in paths_to_remove:
fname = self._get_filename_from_path(path_to_remove)
# If this filename maps to the hash we're removing, remove it
if fname in self._filename_to_hash and self._filename_to_hash[fname] == sha256:
del self._filename_to_hash[fname]
# Update duplicate filenames tracking
if fname in self._duplicate_filenames:
self._duplicate_filenames[fname] = [p for p in self._duplicate_filenames[fname] if p != path_to_remove]
if not self._duplicate_filenames[fname]:
del self._duplicate_filenames[fname]
elif len(self._duplicate_filenames[fname]) == 1:
# If only one entry remains, it's no longer a duplicate
del self._duplicate_filenames[fname]
def has_hash(self, sha256: str) -> bool:
"""Check if hash exists in index"""
@@ -74,14 +204,14 @@ class ModelHashIndex:
def get_hash_by_filename(self, filename: str) -> Optional[str]:
"""Get hash for a filename without extension"""
# Strip extension if present to make the function more flexible
filename = os.path.splitext(filename)[0]
return self._filename_to_hash.get(filename)
def clear(self) -> None:
"""Clear all entries"""
self._hash_to_path.clear()
self._filename_to_hash.clear()
self._duplicate_hashes.clear()
self._duplicate_filenames.clear()
def get_all_hashes(self) -> Set[str]:
"""Get all hashes in the index"""
@@ -91,6 +221,14 @@ class ModelHashIndex:
"""Get all filenames in the index"""
return set(self._filename_to_hash.keys())
def get_duplicate_hashes(self) -> Dict[str, List[str]]:
"""Get dictionary of duplicate hashes and their paths"""
return self._duplicate_hashes
def get_duplicate_filenames(self) -> Dict[str, List[str]]:
"""Get dictionary of duplicate filenames and their paths"""
return self._duplicate_filenames
def __len__(self) -> int:
"""Get number of entries"""
return len(self._hash_to_path)

View File

@@ -8,7 +8,8 @@ from typing import List, Dict, Optional, Type, Set
from ..utils.models import BaseModelMetadata
from ..config import config
from ..utils.file_utils import load_metadata, get_file_info, find_preview_file, save_metadata
from ..utils.file_utils import find_preview_file
from ..utils.metadata_manager import MetadataManager
from .model_cache import ModelCache
from .model_hash_index import ModelHashIndex
from ..utils.constants import PREVIEW_EXTENSIONS
@@ -20,7 +21,30 @@ logger = logging.getLogger(__name__)
class ModelScanner:
"""Base service for scanning and managing model files"""
_lock = asyncio.Lock()
_instances = {} # Dictionary to store instances by class
_locks = {} # Dictionary to store locks by class
def __new__(cls, *args, **kwargs):
"""Implement singleton pattern for each subclass"""
if cls not in cls._instances:
cls._instances[cls] = super().__new__(cls)
return cls._instances[cls]
@classmethod
def _get_lock(cls):
"""Get or create a lock for this class"""
if cls not in cls._locks:
cls._locks[cls] = asyncio.Lock()
return cls._locks[cls]
@classmethod
async def get_instance(cls):
"""Get singleton instance with async support"""
lock = cls._get_lock()
async with lock:
if cls not in cls._instances:
cls._instances[cls] = cls()
return cls._instances[cls]
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
"""Initialize the scanner
@@ -31,6 +55,10 @@ class ModelScanner:
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
hash_index: Hash index instance (optional)
"""
# Ensure initialization happens only once per instance
if hasattr(self, '_initialized'):
return
self.model_type = model_type
self.model_class = model_class
self.file_extensions = file_extensions
@@ -39,15 +67,16 @@ class ModelScanner:
self._tags_count = {} # Dictionary to store tag counts
self._is_initializing = False # Flag to track initialization state
self._excluded_models = [] # List to track excluded models
self._initialized = True
# Register this service
asyncio.create_task(self._register_service())
async def _register_service(self):
"""Register this instance with the ServiceRegistry"""
service_name = f"{self.model_type}_scanner"
await ServiceRegistry.register_service(service_name, self)
async def initialize_in_background(self) -> None:
"""Initialize cache in background using thread pool"""
try:
@@ -55,8 +84,6 @@ class ModelScanner:
if self._cache is None:
self._cache = ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
@@ -66,7 +93,16 @@ class ModelScanner:
# Determine the page type based on model type
page_type = 'loras' if self.model_type == 'lora' else 'checkpoints'
# First, count all model files to track progress
# First, try to load from cache
await ws_manager.broadcast_init_progress({
'stage': 'loading_cache',
'progress': 0,
'details': f"Loading {self.model_type} cache...",
'scanner_type': self.model_type,
'pageType': page_type
})
# If cache loading failed, proceed with full scan
await ws_manager.broadcast_init_progress({
'stage': 'scan_folders',
'progress': 0,
@@ -266,6 +302,13 @@ class ModelScanner:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Log duplicate filename warnings after building the index
# duplicate_filenames = self._hash_index.get_duplicate_filenames()
# if duplicate_filenames:
# logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:")
# for filename, paths in duplicate_filenames.items():
# logger.warning(f" Duplicate filename '{filename}': {paths}")
# Update cache
self._cache.raw_data = raw_data
loop.run_until_complete(self._cache.resort())
@@ -280,25 +323,26 @@ class ModelScanner:
# Clean up the event loop
loop.close()
async def get_cached_data(self, force_refresh: bool = False) -> ModelCache:
"""Get cached model data, refresh if needed"""
async def get_cached_data(self, force_refresh: bool = False, rebuild_cache: bool = False) -> ModelCache:
"""Get cached model data, refresh if needed
Args:
force_refresh: Whether to refresh the cache
rebuild_cache: Whether to completely rebuild the cache
"""
# If cache is not initialized, return an empty cache
# Actual initialization should be done via initialize_in_background
if self._cache is None and not force_refresh:
return ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
# If force refresh is requested, initialize the cache directly
if (force_refresh):
if self._cache is None:
# For initial creation, do a full initialization
if force_refresh:
if rebuild_cache:
await self._initialize_cache()
else:
# For subsequent refreshes, use fast reconciliation
await self._reconcile_cache()
return self._cache
@@ -330,11 +374,16 @@ class ModelScanner:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Log duplicate filename warnings after building the index
# duplicate_filenames = self._hash_index.get_duplicate_filenames()
# if duplicate_filenames:
# logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:")
# for filename, paths in duplicate_filenames.items():
# logger.warning(f" Duplicate filename '{filename}': {paths}")
# Update cache
self._cache = ModelCache(
raw_data=raw_data,
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
@@ -348,8 +397,6 @@ class ModelScanner:
if self._cache is None:
self._cache = ModelCache(
raw_data=[],
sorted_by_name=[],
sorted_by_date=[],
folders=[]
)
finally:
@@ -426,26 +473,33 @@ class ModelScanner:
batch = new_files[i:i+batch_size]
for path in batch:
try:
model_data = await self.scan_single_model(path)
if model_data:
# Add to cache
self._cache.raw_data.append(model_data)
# Update hash index if available
if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
# Update tags count
if 'tags' in model_data and model_data['tags']:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
total_added += 1
# Find the appropriate root path for this file
root_path = None
for potential_root in self.get_model_roots():
if path.startswith(potential_root):
root_path = potential_root
break
if root_path:
model_data = await self._process_model_file(path, root_path)
if model_data:
# Add to cache
self._cache.raw_data.append(model_data)
# Update hash index if available
if 'sha256' in model_data and 'file_path' in model_data:
self._hash_index.add_entry(model_data['sha256'].lower(), model_data['file_path'])
# Update tags count
if 'tags' in model_data and model_data['tags']:
for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
total_added += 1
else:
logger.error(f"Could not determine root path for {path}")
except Exception as e:
logger.error(f"Error adding {path} to cache: {e}")
# Yield control after each batch
await asyncio.sleep(0)
# Find missing files (in cache but not in filesystem)
missing_files = cached_paths - found_paths
@@ -490,41 +544,71 @@ class ModelScanner:
finally:
self._is_initializing = False # Unset flag
# These methods should be implemented in child classes
async def scan_all_models(self) -> List[Dict]:
"""Scan all model directories and return metadata"""
raise NotImplementedError("Subclasses must implement scan_all_models")
all_models = []
# Create scan tasks for each directory
scan_tasks = []
for model_root in self.get_model_roots():
task = asyncio.create_task(self._scan_directory(model_root))
scan_tasks.append(task)
# Wait for all tasks to complete
for task in scan_tasks:
try:
models = await task
all_models.extend(models)
except Exception as e:
logger.error(f"Error scanning directory: {e}")
return all_models
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Scan a single directory for model files"""
models = []
original_root = root_path # Save original root path
async def scan_recursive(path: str, visited_paths: set):
"""Recursively scan directory, avoiding circular symlinks"""
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
file_path = entry.path.replace(os.sep, "/")
result = await self._process_model_file(file_path, original_root)
if result:
models.append(result)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return models
def is_initializing(self) -> bool:
"""Check if the scanner is currently initializing"""
return self._is_initializing
def get_model_roots(self) -> List[str]:
"""Get model root directories"""
raise NotImplementedError("Subclasses must implement get_model_roots")
async def scan_single_model(self, file_path: str) -> Optional[Dict]:
"""Scan a single model file and return its metadata"""
try:
if not os.path.exists(os.path.realpath(file_path)):
return None
# Get basic file info
metadata = await self._get_file_info(file_path)
if not metadata:
return None
folder = self._calculate_folder(file_path)
# Ensure folder field exists
metadata_dict = metadata.to_dict()
metadata_dict['folder'] = folder or ''
return metadata_dict
except Exception as e:
logger.error(f"Error scanning {file_path}: {e}")
return None
async def _get_file_info(self, file_path: str) -> Optional[BaseModelMetadata]:
async def _create_default_metadata(self, file_path: str) -> Optional[BaseModelMetadata]:
"""Get model file info and metadata (extensible for different model types)"""
return await get_file_info(file_path, self.model_class)
return await MetadataManager.create_default_metadata(file_path, self.model_class)
def _calculate_folder(self, file_path: str) -> str:
"""Calculate the folder path for a model file"""
@@ -534,10 +618,13 @@ class ModelScanner:
return os.path.dirname(rel_path).replace(os.path.sep, '/')
return ''
# Common methods shared between scanners
def adjust_metadata(self, metadata, file_path, root_path):
"""Hook for subclasses: adjust metadata during scanning"""
return metadata
async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
"""Process a single model file and return its metadata"""
metadata = await load_metadata(file_path, self.model_class)
metadata = await MetadataManager.load_metadata(file_path, self.model_class)
if metadata is None:
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
@@ -553,7 +640,7 @@ class ModelScanner:
metadata = self.model_class.from_civitai_info(version_info, file_info, file_path)
metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
await save_metadata(file_path, metadata)
await MetadataManager.save_metadata(file_path, metadata, True)
logger.debug(f"Created metadata from .civitai.info for {file_path}")
except Exception as e:
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
@@ -580,13 +667,16 @@ class ModelScanner:
metadata.modelDescription = version_info['model']['description']
# Save the updated metadata
await save_metadata(file_path, metadata)
await MetadataManager.save_metadata(file_path, metadata, True)
logger.debug(f"Updated metadata with civitai info for {file_path}")
except Exception as e:
logger.error(f"Error restoring civitai data from .civitai.info for {file_path}: {e}")
if metadata is None:
metadata = await self._get_file_info(file_path)
metadata = await self._create_default_metadata(file_path)
# Hook: allow subclasses to adjust metadata
metadata = self.adjust_metadata(metadata, file_path, root_path)
model_data = metadata.to_dict()
@@ -594,6 +684,14 @@ class ModelScanner:
if model_data.get('exclude', False):
self._excluded_models.append(model_data['file_path'])
return None
# Check for duplicate filename before adding to hash index
filename = os.path.splitext(os.path.basename(file_path))[0]
existing_hash = self._hash_index.get_hash_by_filename(filename)
if existing_hash and existing_hash != model_data.get('sha256', '').lower():
existing_path = self._hash_index.get_path(existing_hash)
if existing_path and existing_path != file_path:
logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")
await self._fetch_missing_metadata(file_path, model_data)
rel_path = os.path.relpath(file_path, root_path)
@@ -636,9 +734,7 @@ class ModelScanner:
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
model_data['civitai_deleted'] = True
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(model_data, f, indent=2, ensure_ascii=False)
await MetadataManager.save_metadata(file_path, model_data)
elif model_metadata:
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
@@ -651,53 +747,44 @@ class ModelScanner:
model_data['civitai']['creator'] = model_metadata['creator']
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(model_data, f, indent=2, ensure_ascii=False)
await MetadataManager.save_metadata(file_path, model_data, True)
except Exception as e:
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Base implementation for directory scanning"""
models = []
original_root = root_path
async def scan_recursive(path: str, visited_paths: set):
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, models)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool:
"""Add a model to the cache
await scan_recursive(root_path, set())
return models
async def _process_single_file(self, file_path: str, root_path: str, models_list: list):
"""Process a single file and add to results list"""
Args:
metadata_dict: The model metadata dictionary
folder: The relative folder path for the model
Returns:
bool: True if successful, False otherwise
"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
models_list.append(result)
if self._cache is None:
await self.get_cached_data()
# Update folder in metadata
metadata_dict['folder'] = folder
# Add to cache
self._cache.raw_data.append(metadata_dict)
# Resort cache data
await self._cache.resort()
# Update folders list
all_folders = set(self._cache.folders)
all_folders.add(folder)
self._cache.folders = sorted(list(all_folders), key=lambda x: x.lower())
# Update the hash index
self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
return True
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
logger.error(f"Error adding model to cache: {e}")
return False
async def move_model(self, source_path: str, target_path: str) -> bool:
"""Move a model and its associated files to a new location"""
@@ -721,47 +808,42 @@ class ModelScanner:
real_source = os.path.realpath(source_path)
real_target = os.path.realpath(target_file)
file_size = os.path.getsize(real_source)
# Get the appropriate file monitor through ServiceRegistry
if self.model_type == "lora":
monitor = await ServiceRegistry.get_lora_monitor()
elif self.model_type == "checkpoint":
monitor = await ServiceRegistry.get_checkpoint_monitor()
else:
monitor = None
if monitor:
monitor.handler.add_ignore_path(
real_source,
file_size
)
monitor.handler.add_ignore_path(
real_target,
file_size
)
shutil.move(real_source, real_target)
source_metadata = os.path.join(source_dir, f"{base_name}.metadata.json")
# Move all associated files with the same base name
source_metadata = None
moved_metadata_path = None
# Find all files with the same base name in the source directory
files_to_move = []
try:
for file in os.listdir(source_dir):
if file.startswith(base_name + ".") and file != os.path.basename(source_path):
source_file_path = os.path.join(source_dir, file)
# Store metadata file path for special handling
if file == f"{base_name}.metadata.json":
source_metadata = source_file_path
moved_metadata_path = os.path.join(target_path, file)
else:
files_to_move.append((source_file_path, os.path.join(target_path, file)))
except Exception as e:
logger.error(f"Error listing files in {source_dir}: {e}")
# Move all associated files
metadata = None
if os.path.exists(source_metadata):
target_metadata = os.path.join(target_path, f"{base_name}.metadata.json")
shutil.move(source_metadata, target_metadata)
metadata = await self._update_metadata_paths(target_metadata, target_file)
for source_file, target_file_path in files_to_move:
try:
shutil.move(source_file, target_file_path)
except Exception as e:
logger.error(f"Error moving associated file {source_file}: {e}")
# Move civitai.info file if exists
source_civitai = os.path.join(source_dir, f"{base_name}.civitai.info")
if os.path.exists(source_civitai):
target_civitai = os.path.join(target_path, f"{base_name}.civitai.info")
shutil.move(source_civitai, target_civitai)
for ext in PREVIEW_EXTENSIONS:
source_preview = os.path.join(source_dir, f"{base_name}{ext}")
if os.path.exists(source_preview):
target_preview = os.path.join(target_path, f"{base_name}{ext}")
shutil.move(source_preview, target_preview)
break
# Handle metadata file specially to update paths
if source_metadata and os.path.exists(source_metadata):
try:
shutil.move(source_metadata, moved_metadata_path)
metadata = await self._update_metadata_paths(moved_metadata_path, target_file)
except Exception as e:
logger.error(f"Error moving metadata file: {e}")
await self.update_single_model_cache(source_path, target_file, metadata)
@@ -779,15 +861,14 @@ class ModelScanner:
metadata['file_path'] = model_path.replace(os.sep, '/')
if 'preview_url' in metadata:
if 'preview_url' in metadata and metadata['preview_url']:
preview_dir = os.path.dirname(model_path)
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
preview_ext = os.path.splitext(metadata['preview_url'])[1]
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
await MetadataManager.save_metadata(metadata_path, metadata)
return metadata
@@ -918,12 +999,13 @@ class ModelScanner:
"""Get list of excluded model file paths"""
return self._excluded_models.copy()
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
async def update_preview_in_cache(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
"""Update preview URL in cache for a specific lora
Args:
file_path: The file path of the lora to update
preview_url: The new preview URL
preview_nsfw_level: The NSFW level of the preview
Returns:
bool: True if the update was successful, False if cache doesn't exist or lora wasn't found
@@ -931,4 +1013,213 @@ class ModelScanner:
if self._cache is None:
return False
return await self._cache.update_preview_url(file_path, preview_url)
return await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
async def bulk_delete_models(self, file_paths: List[str]) -> Dict:
"""Delete multiple models and update cache in a batch operation
Args:
file_paths: List of file paths to delete
Returns:
Dict containing results of the operation
"""
try:
if not file_paths:
return {
'success': False,
'error': 'No file paths provided for deletion',
'results': []
}
# Keep track of success and failures
results = []
total_deleted = 0
cache_updated = False
# Get cache data
cache = await self.get_cached_data()
# Track deleted models to update cache once
deleted_models = []
for file_path in file_paths:
try:
target_dir = os.path.dirname(file_path)
file_name = os.path.splitext(os.path.basename(file_path))[0]
# Delete all associated files for the model
from ..utils.routes_common import ModelRouteUtils
deleted_files = await ModelRouteUtils.delete_model_files(
target_dir,
file_name
)
if deleted_files:
deleted_models.append(file_path)
results.append({
'file_path': file_path,
'success': True,
'deleted_files': deleted_files
})
total_deleted += 1
else:
results.append({
'file_path': file_path,
'success': False,
'error': 'No files deleted'
})
except Exception as e:
logger.error(f"Error deleting file {file_path}: {e}")
results.append({
'file_path': file_path,
'success': False,
'error': str(e)
})
# Batch update cache if any models were deleted
if deleted_models:
# Update the cache in a batch operation
cache_updated = await self._batch_update_cache_for_deleted_models(deleted_models)
return {
'success': True,
'total_deleted': total_deleted,
'total_attempted': len(file_paths),
'cache_updated': cache_updated,
'results': results
}
except Exception as e:
logger.error(f"Error in bulk delete: {e}", exc_info=True)
return {
'success': False,
'error': str(e),
'results': []
}
async def _batch_update_cache_for_deleted_models(self, file_paths: List[str]) -> bool:
"""Update cache after multiple models have been deleted
Args:
file_paths: List of file paths that were deleted
Returns:
bool: True if cache was updated and saved successfully
"""
if not file_paths or self._cache is None:
return False
try:
# Get all models that need to be removed from cache
models_to_remove = [item for item in self._cache.raw_data if item['file_path'] in file_paths]
if not models_to_remove:
return False
# Update tag counts
for model in models_to_remove:
for tag in model.get('tags', []):
if tag in self._tags_count:
self._tags_count[tag] = max(0, self._tags_count[tag] - 1)
if self._tags_count[tag] == 0:
del self._tags_count[tag]
# Update hash index
for model in models_to_remove:
file_path = model['file_path']
if hasattr(self, '_hash_index') and self._hash_index:
# Get the hash and filename before removal for duplicate checking
file_name = os.path.splitext(os.path.basename(file_path))[0]
hash_val = model.get('sha256', '').lower()
# Remove from hash index
self._hash_index.remove_by_path(file_path, hash_val)
# Check and clean up duplicates
self._cleanup_duplicates_after_removal(hash_val, file_name)
# Update cache data
self._cache.raw_data = [item for item in self._cache.raw_data if item['file_path'] not in file_paths]
# Resort cache
await self._cache.resort()
return True
except Exception as e:
logger.error(f"Error updating cache after bulk delete: {e}", exc_info=True)
return False
def _cleanup_duplicates_after_removal(self, hash_val: str, file_name: str) -> None:
"""Clean up duplicate entries in hash index after removing a model
Args:
hash_val: SHA256 hash of the removed model
file_name: File name of the removed model without extension
"""
if not hash_val or not file_name or not hasattr(self, '_hash_index'):
return
# Clean up hash duplicates if only 0 or 1 entries remain
if hash_val in self._hash_index._duplicate_hashes:
if len(self._hash_index._duplicate_hashes[hash_val]) <= 1:
del self._hash_index._duplicate_hashes[hash_val]
# Clean up filename duplicates if only 0 or 1 entries remain
if file_name in self._hash_index._duplicate_filenames:
if len(self._hash_index._duplicate_filenames[file_name]) <= 1:
del self._hash_index._duplicate_filenames[file_name]
async def check_model_version_exists(self, model_version_id: int) -> bool:
"""Check if a specific model version exists in the cache
Args:
model_version_id: Civitai model version ID
Returns:
bool: True if the model version exists, False otherwise
"""
try:
cache = await self.get_cached_data()
if not cache or not cache.raw_data:
return False
for item in cache.raw_data:
if item.get('civitai') and item['civitai'].get('id') == model_version_id:
return True
return False
except Exception as e:
logger.error(f"Error checking model version existence: {e}")
return False
async def get_model_versions_by_id(self, model_id: int) -> List[Dict]:
"""Get all versions of a model by its ID
Args:
model_id: Civitai model ID
Returns:
List[Dict]: List of version information dictionaries
"""
try:
cache = await self.get_cached_data()
if not cache or not cache.raw_data:
return []
versions = []
for item in cache.raw_data:
if (item.get('civitai') and
item['civitai'].get('modelId') == model_id and
item['civitai'].get('id')):
versions.append({
'versionId': item['civitai'].get('id'),
'name': item['civitai'].get('name'),
'fileName': item.get('file_name', '')
})
return versions
except Exception as e:
logger.error(f"Error getting model versions: {e}")
return []

View File

@@ -0,0 +1,142 @@
from typing import Dict, Type, Any
import logging
logger = logging.getLogger(__name__)
class ModelServiceFactory:
"""Factory for managing model services and routes"""
_services: Dict[str, Type] = {}
_routes: Dict[str, Type] = {}
_initialized_services: Dict[str, Any] = {}
_initialized_routes: Dict[str, Any] = {}
@classmethod
def register_model_type(cls, model_type: str, service_class: Type, route_class: Type):
"""Register a new model type with its service and route classes
Args:
model_type: The model type identifier (e.g., 'lora', 'checkpoint')
service_class: The service class for this model type
route_class: The route class for this model type
"""
cls._services[model_type] = service_class
cls._routes[model_type] = route_class
logger.info(f"Registered model type '{model_type}' with service {service_class.__name__} and routes {route_class.__name__}")
@classmethod
def get_service_class(cls, model_type: str) -> Type:
"""Get service class for a model type
Args:
model_type: The model type identifier
Returns:
The service class for the model type
Raises:
ValueError: If model type is not registered
"""
if model_type not in cls._services:
raise ValueError(f"Unknown model type: {model_type}")
return cls._services[model_type]
@classmethod
def get_route_class(cls, model_type: str) -> Type:
"""Get route class for a model type
Args:
model_type: The model type identifier
Returns:
The route class for the model type
Raises:
ValueError: If model type is not registered
"""
if model_type not in cls._routes:
raise ValueError(f"Unknown model type: {model_type}")
return cls._routes[model_type]
@classmethod
def get_route_instance(cls, model_type: str):
"""Get or create route instance for a model type
Args:
model_type: The model type identifier
Returns:
The route instance for the model type
"""
if model_type not in cls._initialized_routes:
route_class = cls.get_route_class(model_type)
cls._initialized_routes[model_type] = route_class()
return cls._initialized_routes[model_type]
@classmethod
def setup_all_routes(cls, app):
"""Setup routes for all registered model types
Args:
app: The aiohttp application instance
"""
logger.info(f"Setting up routes for {len(cls._services)} registered model types")
for model_type in cls._services.keys():
try:
routes_instance = cls.get_route_instance(model_type)
routes_instance.setup_routes(app)
logger.info(f"Successfully set up routes for {model_type}")
except Exception as e:
logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True)
@classmethod
def get_registered_types(cls) -> list:
"""Get list of all registered model types
Returns:
List of registered model type identifiers
"""
return list(cls._services.keys())
@classmethod
def is_registered(cls, model_type: str) -> bool:
"""Check if a model type is registered
Args:
model_type: The model type identifier
Returns:
True if the model type is registered, False otherwise
"""
return model_type in cls._services
@classmethod
def clear_registrations(cls):
"""Clear all registrations - mainly for testing purposes"""
cls._services.clear()
cls._routes.clear()
cls._initialized_services.clear()
cls._initialized_routes.clear()
logger.info("Cleared all model type registrations")
def register_default_model_types():
"""Register the default model types (LoRA, Checkpoint, and Embedding)"""
from ..services.lora_service import LoraService
from ..services.checkpoint_service import CheckpointService
from ..services.embedding_service import EmbeddingService
from ..routes.lora_routes import LoraRoutes
from ..routes.checkpoint_routes import CheckpointRoutes
from ..routes.embedding_routes import EmbeddingRoutes
# Register LoRA model type
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
# Register Checkpoint model type
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
# Register Embedding model type
ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes)
logger.info("Registered default model types: lora, checkpoint, embedding")

View File

@@ -393,8 +393,8 @@ class RecipeScanner:
if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']):
hash_value = lora['hash']
if self._lora_scanner.has_lora_hash(hash_value):
lora_path = self._lora_scanner.get_lora_path_by_hash(hash_value)
if self._lora_scanner.has_hash(hash_value):
lora_path = self._lora_scanner.get_path_by_hash(hash_value)
if lora_path:
file_name = os.path.splitext(os.path.basename(lora_path))[0]
lora['file_name'] = file_name
@@ -465,7 +465,7 @@ class RecipeScanner:
# Count occurrences of each base model
for lora in loras:
if 'hash' in lora:
lora_path = self._lora_scanner.get_lora_path_by_hash(lora['hash'])
lora_path = self._lora_scanner.get_path_by_hash(lora['hash'])
if lora_path:
base_model = await self._get_base_model_for_lora(lora_path)
if base_model:
@@ -603,9 +603,9 @@ class RecipeScanner:
if 'loras' in item:
for lora in item['loras']:
if 'hash' in lora and lora['hash']:
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora['hash'].lower())
lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower())
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower())
result = {
'items': paginated_items,
@@ -655,9 +655,9 @@ class RecipeScanner:
for lora in formatted_recipe['loras']:
if 'hash' in lora and lora['hash']:
lora_hash = lora['hash'].lower()
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora_hash)
lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash)
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash)
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora_hash)
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash)
return formatted_recipe

View File

@@ -7,118 +7,209 @@ logger = logging.getLogger(__name__)
T = TypeVar('T') # Define a type variable for service types
class ServiceRegistry:
"""Centralized registry for service singletons"""
"""Central registry for managing singleton services"""
_instance = None
_services: Dict[str, Any] = {}
_lock = asyncio.Lock()
_locks: Dict[str, asyncio.Lock] = {}
@classmethod
def get_instance(cls):
"""Get singleton instance of the registry"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
async def register_service(cls, name: str, service: Any) -> None:
"""Register a service instance with the registry
Args:
name: Service name identifier
service: Service instance to register
"""
cls._services[name] = service
logger.debug(f"Registered service: {name}")
@classmethod
async def register_service(cls, service_name: str, service_instance: Any) -> None:
"""Register a service instance with the registry"""
registry = cls.get_instance()
async with cls._lock:
registry._services[service_name] = service_instance
logger.debug(f"Registered service: {service_name}")
async def get_service(cls, name: str) -> Optional[Any]:
"""Get a service instance by name
Args:
name: Service name identifier
Returns:
Service instance or None if not found
"""
return cls._services.get(name)
@classmethod
async def get_service(cls, service_name: str) -> Any:
"""Get a service instance by name"""
registry = cls.get_instance()
async with cls._lock:
if service_name not in registry._services:
logger.debug(f"Service {service_name} not found in registry")
return None
return registry._services[service_name]
def get_service_sync(cls, name: str) -> Optional[Any]:
"""Synchronously get a service instance by name
Args:
name: Service name identifier
Returns:
Service instance or None if not found
"""
return cls._services.get(name)
@classmethod
def _get_lock(cls, name: str) -> asyncio.Lock:
"""Get or create a lock for a service
Args:
name: Service name identifier
Returns:
AsyncIO lock for the service
"""
if name not in cls._locks:
cls._locks[name] = asyncio.Lock()
return cls._locks[name]
# Convenience methods for common services
@classmethod
async def get_lora_scanner(cls):
"""Get the LoraScanner instance"""
from .lora_scanner import LoraScanner
scanner = await cls.get_service("lora_scanner")
if scanner is None:
scanner = await LoraScanner.get_instance()
await cls.register_service("lora_scanner", scanner)
return scanner
"""Get or create LoRA scanner instance"""
service_name = "lora_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .lora_scanner import LoraScanner
scanner = await LoraScanner.get_instance()
cls._services[service_name] = scanner
logger.debug(f"Created and registered {service_name}")
return scanner
@classmethod
async def get_checkpoint_scanner(cls):
"""Get the CheckpointScanner instance"""
from .checkpoint_scanner import CheckpointScanner
scanner = await cls.get_service("checkpoint_scanner")
if scanner is None:
"""Get or create Checkpoint scanner instance"""
service_name = "checkpoint_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .checkpoint_scanner import CheckpointScanner
scanner = await CheckpointScanner.get_instance()
await cls.register_service("checkpoint_scanner", scanner)
return scanner
cls._services[service_name] = scanner
logger.debug(f"Created and registered {service_name}")
return scanner
@classmethod
async def get_lora_monitor(cls):
"""Get the LoraFileMonitor instance"""
from .file_monitor import LoraFileMonitor
monitor = await cls.get_service("lora_monitor")
if monitor is None:
monitor = await LoraFileMonitor.get_instance()
await cls.register_service("lora_monitor", monitor)
return monitor
@classmethod
async def get_checkpoint_monitor(cls):
"""Get the CheckpointFileMonitor instance"""
from .file_monitor import CheckpointFileMonitor
monitor = await cls.get_service("checkpoint_monitor")
if monitor is None:
monitor = await CheckpointFileMonitor.get_instance()
await cls.register_service("checkpoint_monitor", monitor)
return monitor
@classmethod
async def get_civitai_client(cls):
"""Get the CivitaiClient instance"""
from .civitai_client import CivitaiClient
client = await cls.get_service("civitai_client")
if client is None:
client = await CivitaiClient.get_instance()
await cls.register_service("civitai_client", client)
return client
@classmethod
async def get_download_manager(cls):
"""Get the DownloadManager instance"""
from .download_manager import DownloadManager
manager = await cls.get_service("download_manager")
if manager is None:
# We'll let DownloadManager.get_instance handle file_monitor parameter
manager = await DownloadManager.get_instance()
await cls.register_service("download_manager", manager)
return manager
@classmethod
async def get_recipe_scanner(cls):
"""Get the RecipeScanner instance"""
from .recipe_scanner import RecipeScanner
scanner = await cls.get_service("recipe_scanner")
if scanner is None:
lora_scanner = await cls.get_lora_scanner()
scanner = RecipeScanner(lora_scanner)
await cls.register_service("recipe_scanner", scanner)
return scanner
"""Get or create Recipe scanner instance"""
service_name = "recipe_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .recipe_scanner import RecipeScanner
scanner = await RecipeScanner.get_instance()
cls._services[service_name] = scanner
logger.debug(f"Created and registered {service_name}")
return scanner
@classmethod
async def get_civitai_client(cls):
"""Get or create CivitAI client instance"""
service_name = "civitai_client"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .civitai_client import CivitaiClient
client = await CivitaiClient.get_instance()
cls._services[service_name] = client
logger.debug(f"Created and registered {service_name}")
return client
@classmethod
async def get_download_manager(cls):
"""Get or create Download manager instance"""
service_name = "download_manager"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .download_manager import DownloadManager
manager = DownloadManager()
cls._services[service_name] = manager
logger.debug(f"Created and registered {service_name}")
return manager
@classmethod
async def get_websocket_manager(cls):
"""Get the WebSocketManager instance"""
from .websocket_manager import ws_manager
manager = await cls.get_service("websocket_manager")
if manager is None:
# ws_manager is already a global instance in websocket_manager.py
"""Get or create WebSocket manager instance"""
service_name = "websocket_manager"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .websocket_manager import ws_manager
await cls.register_service("websocket_manager", ws_manager)
manager = ws_manager
return manager
cls._services[service_name] = ws_manager
logger.debug(f"Registered {service_name}")
return ws_manager
@classmethod
async def get_embedding_scanner(cls):
"""Get or create Embedding scanner instance"""
service_name = "embedding_scanner"
if service_name in cls._services:
return cls._services[service_name]
async with cls._get_lock(service_name):
# Double-check after acquiring lock
if service_name in cls._services:
return cls._services[service_name]
# Import here to avoid circular imports
from .embedding_scanner import EmbeddingScanner
scanner = await EmbeddingScanner.get_instance()
cls._services[service_name] = scanner
logger.debug(f"Created and registered {service_name}")
return scanner
@classmethod
def clear_services(cls):
"""Clear all registered services - mainly for testing"""
cls._services.clear()
cls._locks.clear()
logger.info("Cleared all registered services")

View File

@@ -9,6 +9,8 @@ class SettingsManager:
def __init__(self):
self.settings_file = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'settings.json')
self.settings = self._load_settings()
self._migrate_download_path_template()
self._auto_set_default_roots()
self._check_environment_variables()
def _load_settings(self) -> Dict[str, Any]:
@@ -21,6 +23,46 @@ class SettingsManager:
logger.error(f"Error loading settings: {e}")
return self._get_default_settings()
def _migrate_download_path_template(self):
"""Migrate old download_path_template to new download_path_templates"""
old_template = self.settings.get('download_path_template')
templates = self.settings.get('download_path_templates')
# If old template exists and new templates don't exist, migrate
if old_template is not None and not templates:
logger.info("Migrating download_path_template to download_path_templates")
self.settings['download_path_templates'] = {
'lora': old_template,
'checkpoint': old_template,
'embedding': old_template
}
# Remove old setting
del self.settings['download_path_template']
self._save_settings()
logger.info("Migration completed")
def _auto_set_default_roots(self):
"""Auto set default root paths if only one folder is present and default is empty."""
folder_paths = self.settings.get('folder_paths', {})
updated = False
# loras
loras = folder_paths.get('loras', [])
if isinstance(loras, list) and len(loras) == 1 and not self.settings.get('default_lora_root'):
self.settings['default_lora_root'] = loras[0]
updated = True
# checkpoints
checkpoints = folder_paths.get('checkpoints', [])
if isinstance(checkpoints, list) and len(checkpoints) == 1 and not self.settings.get('default_checkpoint_root'):
self.settings['default_checkpoint_root'] = checkpoints[0]
updated = True
# embeddings
embeddings = folder_paths.get('embeddings', [])
if isinstance(embeddings, list) and len(embeddings) == 1 and not self.settings.get('default_embedding_root'):
self.settings['default_embedding_root'] = embeddings[0]
updated = True
if updated:
self._save_settings()
def _check_environment_variables(self) -> None:
"""Check for environment variables and update settings if needed"""
env_api_key = os.environ.get('CIVITAI_API_KEY')
@@ -58,4 +100,16 @@ class SettingsManager:
except Exception as e:
logger.error(f"Error saving settings: {e}")
def get_download_path_template(self, model_type: str) -> str:
"""Get download path template for specific model type
Args:
model_type: The type of model ('lora', 'checkpoint', 'embedding')
Returns:
Template string for the model type, defaults to '{base_model}/{first_tag}'
"""
templates = self.settings.get('download_path_templates', {})
return templates.get(model_type, '{base_model}/{first_tag}')
settings = SettingsManager()

View File

@@ -1,6 +1,9 @@
import logging
from aiohttp import web
from typing import Set, Dict, Optional
from uuid import uuid4
import asyncio
from datetime import datetime, timedelta
logger = logging.getLogger(__name__)
@@ -10,7 +13,12 @@ class WebSocketManager:
def __init__(self):
self._websockets: Set[web.WebSocketResponse] = set()
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
# Add progress tracking dictionary
self._download_progress: Dict[str, Dict] = {}
# Add auto-organize progress tracking
self._auto_organize_progress: Optional[Dict] = None
self._auto_organize_lock = asyncio.Lock()
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection"""
@@ -39,21 +47,48 @@ class WebSocketManager:
finally:
self._init_websockets.discard(ws)
return ws
async def handle_checkpoint_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection for checkpoint download progress"""
async def handle_download_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection for download progress"""
ws = web.WebSocketResponse()
await ws.prepare(request)
self._checkpoint_websockets.add(ws)
# Get download_id from query parameters
download_id = request.query.get('id')
if not download_id:
# Generate a new download ID if not provided
download_id = str(uuid4())
# Store the websocket with its download ID
self._download_websockets[download_id] = ws
try:
# Send the download ID back to the client
await ws.send_json({
'type': 'download_id',
'download_id': download_id
})
async for msg in ws:
if msg.type == web.WSMsgType.ERROR:
logger.error(f'Checkpoint WebSocket error: {ws.exception()}')
logger.error(f'Download WebSocket error: {ws.exception()}')
finally:
self._checkpoint_websockets.discard(ws)
if download_id in self._download_websockets:
del self._download_websockets[download_id]
# Schedule cleanup of completed downloads after WebSocket disconnection
asyncio.create_task(self._delayed_cleanup(download_id))
return ws
async def _delayed_cleanup(self, download_id: str, delay_seconds: int = 300):
"""Clean up download progress after a delay (5 minutes by default)"""
await asyncio.sleep(delay_seconds)
progress_data = self._download_progress.get(download_id)
if progress_data and progress_data.get('progress', 0) >= 100:
self.cleanup_download_progress(download_id)
logger.debug(f"Delayed cleanup completed for download {download_id}")
async def broadcast(self, data: Dict):
"""Broadcast message to all connected clients"""
if not self._websockets:
@@ -84,17 +119,72 @@ class WebSocketManager:
except Exception as e:
logger.error(f"Error sending initialization progress: {e}")
async def broadcast_checkpoint_progress(self, data: Dict):
"""Broadcast checkpoint download progress to connected clients"""
if not self._checkpoint_websockets:
async def broadcast_download_progress(self, download_id: str, data: Dict):
"""Send progress update to specific download client"""
# Store simplified progress data in memory (only progress percentage)
self._download_progress[download_id] = {
'progress': data.get('progress', 0),
'timestamp': datetime.now()
}
if download_id not in self._download_websockets:
logger.debug(f"No WebSocket found for download ID: {download_id}")
return
for ws in self._checkpoint_websockets:
try:
await ws.send_json(data)
except Exception as e:
logger.error(f"Error sending checkpoint progress: {e}")
ws = self._download_websockets[download_id]
try:
await ws.send_json(data)
except Exception as e:
logger.error(f"Error sending download progress: {e}")
async def broadcast_auto_organize_progress(self, data: Dict):
"""Broadcast auto-organize progress to connected clients"""
# Store progress data in memory
self._auto_organize_progress = data
# Broadcast via WebSocket
await self.broadcast(data)
def get_auto_organize_progress(self) -> Optional[Dict]:
"""Get current auto-organize progress"""
return self._auto_organize_progress
def cleanup_auto_organize_progress(self):
"""Clear auto-organize progress data"""
self._auto_organize_progress = None
def is_auto_organize_running(self) -> bool:
"""Check if auto-organize is currently running"""
if not self._auto_organize_progress:
return False
status = self._auto_organize_progress.get('status')
return status in ['started', 'processing', 'cleaning']
async def get_auto_organize_lock(self):
"""Get the auto-organize lock"""
return self._auto_organize_lock
def get_download_progress(self, download_id: str) -> Optional[Dict]:
"""Get progress information for a specific download"""
return self._download_progress.get(download_id)
def cleanup_download_progress(self, download_id: str):
"""Remove progress info for a specific download"""
self._download_progress.pop(download_id, None)
def cleanup_old_downloads(self, max_age_hours: int = 24):
"""Clean up old download progress entries"""
cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
to_remove = []
for download_id, progress_data in self._download_progress.items():
if progress_data.get('timestamp', datetime.now()) < cutoff_time:
to_remove.append(download_id)
for download_id in to_remove:
self._download_progress.pop(download_id, None)
logger.debug(f"Cleaned up old download progress for {download_id}")
def get_connected_clients_count(self) -> int:
"""Get number of connected clients"""
return len(self._websockets)
@@ -102,10 +192,14 @@ class WebSocketManager:
def get_init_clients_count(self) -> int:
"""Get number of initialization progress clients"""
return len(self._init_websockets)
def get_checkpoint_clients_count(self) -> int:
"""Get number of checkpoint progress clients"""
return len(self._checkpoint_websockets)
def get_download_clients_count(self) -> int:
"""Get number of download progress clients"""
return len(self._download_websockets)
def generate_download_id(self) -> str:
"""Generate a unique download ID"""
return str(uuid4())
# Global instance
ws_manager = WebSocketManager()

View File

@@ -7,6 +7,16 @@ NSFW_LEVELS = {
"Blocked": 32, # Probably not actually visible through the API without being logged in on model owner account?
}
# Node type constants
NODE_TYPES = {
"Lora Loader (LoraManager)": 1,
"Lora Stacker (LoraManager)": 2,
"WanVideo Lora Select (LoraManager)": 3
}
# Default ComfyUI node color when bgcolor is null
DEFAULT_NODE_COLOR = "#353535"
# preview extensions
PREVIEW_EXTENSIONS = [
'.webp',
@@ -18,7 +28,9 @@ PREVIEW_EXTENSIONS = [
'.png',
'.jpeg',
'.jpg',
'.mp4'
'.mp4',
'.gif',
'.webm'
]
# Card preview image width
@@ -31,4 +43,18 @@ EXAMPLE_IMAGE_WIDTH = 832
SUPPORTED_MEDIA_EXTENSIONS = {
'images': ['.jpg', '.jpeg', '.png', '.webp', '.gif'],
'videos': ['.mp4', '.webm']
}
}
# Valid Lora types
VALID_LORA_TYPES = ['lora', 'locon', 'dora']
# Auto-organize settings
AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system
# Civitai model tags in priority order for subfolder organization
CIVITAI_MODEL_TAGS = [
'character', 'style', 'concept', 'clothing',
# 'base model', # exclude 'base model'
'poses', 'background', 'tool', 'vehicle', 'buildings',
'objects', 'assets', 'animal', 'action'
]

View File

@@ -0,0 +1,796 @@
import logging
import os
import asyncio
import json
import time
import aiohttp
from aiohttp import web
from ..services.service_registry import ServiceRegistry
from ..utils.metadata_manager import MetadataManager
from .example_images_processor import ExampleImagesProcessor
from .example_images_metadata import MetadataUpdater
from ..services.websocket_manager import ws_manager # Add this import at the top
logger = logging.getLogger(__name__)
# Download status tracking
download_task = None
is_downloading = False
download_progress = {
'total': 0,
'completed': 0,
'current_model': '',
'status': 'idle', # idle, running, paused, completed, error
'errors': [],
'last_error': None,
'start_time': None,
'end_time': None,
'processed_models': set(), # Track models that have been processed
'refreshed_models': set(), # Track models that had metadata refreshed
'failed_models': set() # Track models that failed to download after metadata refresh
}
class DownloadManager:
"""Manages downloading example images for models"""
@staticmethod
async def start_download(request):
"""
Start downloading example images for models
Expects a JSON body with:
{
"output_dir": "path/to/output", # Base directory to save example images
"optimize": true, # Whether to optimize images (default: true)
"model_types": ["lora", "checkpoint"], # Model types to process (default: both)
"delay": 1.0 # Delay between downloads to avoid rate limiting (default: 1.0)
}
"""
global download_task, is_downloading, download_progress
if is_downloading:
# Create a copy for JSON serialization
response_progress = download_progress.copy()
response_progress['processed_models'] = list(download_progress['processed_models'])
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
response_progress['failed_models'] = list(download_progress['failed_models'])
return web.json_response({
'success': False,
'error': 'Download already in progress',
'status': response_progress
}, status=400)
try:
# Parse the request body
data = await request.json()
output_dir = data.get('output_dir')
optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
if not output_dir:
return web.json_response({
'success': False,
'error': 'Missing output_dir parameter'
}, status=400)
# Create the output directory
os.makedirs(output_dir, exist_ok=True)
# Initialize progress tracking
download_progress['total'] = 0
download_progress['completed'] = 0
download_progress['current_model'] = ''
download_progress['status'] = 'running'
download_progress['errors'] = []
download_progress['last_error'] = None
download_progress['start_time'] = time.time()
download_progress['end_time'] = None
# Get the processed models list from a file if it exists
progress_file = os.path.join(output_dir, '.download_progress.json')
if os.path.exists(progress_file):
try:
with open(progress_file, 'r', encoding='utf-8') as f:
saved_progress = json.load(f)
download_progress['processed_models'] = set(saved_progress.get('processed_models', []))
download_progress['failed_models'] = set(saved_progress.get('failed_models', []))
logger.debug(f"Loaded previous progress, {len(download_progress['processed_models'])} models already processed, {len(download_progress['failed_models'])} models marked as failed")
except Exception as e:
logger.error(f"Failed to load progress file: {e}")
download_progress['processed_models'] = set()
download_progress['failed_models'] = set()
else:
download_progress['processed_models'] = set()
download_progress['failed_models'] = set()
# Start the download task
is_downloading = True
download_task = asyncio.create_task(
DownloadManager._download_all_example_images(
output_dir,
optimize,
model_types,
delay
)
)
# Create a copy for JSON serialization
response_progress = download_progress.copy()
response_progress['processed_models'] = list(download_progress['processed_models'])
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
response_progress['failed_models'] = list(download_progress['failed_models'])
return web.json_response({
'success': True,
'message': 'Download started',
'status': response_progress
})
except Exception as e:
logger.error(f"Failed to start example images download: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def get_status(request):
"""Get the current status of example images download"""
global download_progress
# Create a copy of the progress dict with the set converted to a list for JSON serialization
response_progress = download_progress.copy()
response_progress['processed_models'] = list(download_progress['processed_models'])
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
response_progress['failed_models'] = list(download_progress['failed_models'])
return web.json_response({
'success': True,
'is_downloading': is_downloading,
'status': response_progress
})
@staticmethod
async def pause_download(request):
"""Pause the example images download"""
global download_progress
if not is_downloading:
return web.json_response({
'success': False,
'error': 'No download in progress'
}, status=400)
download_progress['status'] = 'paused'
return web.json_response({
'success': True,
'message': 'Download paused'
})
@staticmethod
async def resume_download(request):
"""Resume the example images download"""
global download_progress
if not is_downloading:
return web.json_response({
'success': False,
'error': 'No download in progress'
}, status=400)
if download_progress['status'] == 'paused':
download_progress['status'] = 'running'
return web.json_response({
'success': True,
'message': 'Download resumed'
})
else:
return web.json_response({
'success': False,
'error': f"Download is in '{download_progress['status']}' state, cannot resume"
}, status=400)
@staticmethod
async def _download_all_example_images(output_dir, optimize, model_types, delay):
"""Download example images for all models"""
global is_downloading, download_progress
# Create independent download session
connector = aiohttp.TCPConnector(
ssl=True,
limit=3,
force_close=False,
enable_cleanup_closed=True
)
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
independent_session = aiohttp.ClientSession(
connector=connector,
trust_env=True,
timeout=timeout
)
try:
# Get scanners
scanners = []
if 'lora' in model_types:
lora_scanner = await ServiceRegistry.get_lora_scanner()
scanners.append(('lora', lora_scanner))
if 'checkpoint' in model_types:
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
scanners.append(('checkpoint', checkpoint_scanner))
if 'embedding' in model_types:
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
scanners.append(('embedding', embedding_scanner))
# Get all models
all_models = []
for scanner_type, scanner in scanners:
cache = await scanner.get_cached_data()
if cache and cache.raw_data:
for model in cache.raw_data:
if model.get('sha256'):
all_models.append((scanner_type, model, scanner))
# Update total count
download_progress['total'] = len(all_models)
logger.debug(f"Found {download_progress['total']} models to process")
# Process each model
for i, (scanner_type, model, scanner) in enumerate(all_models):
# Main logic for processing model is here, but actual operations are delegated to other classes
was_remote_download = await DownloadManager._process_model(
scanner_type, model, scanner,
output_dir, optimize, independent_session
)
# Update progress
download_progress['completed'] += 1
# Only add delay after remote download of models, and not after processing the last model
if was_remote_download and i < len(all_models) - 1 and download_progress['status'] == 'running':
await asyncio.sleep(delay)
# Mark as completed
download_progress['status'] = 'completed'
download_progress['end_time'] = time.time()
logger.debug(f"Example images download completed: {download_progress['completed']}/{download_progress['total']} models processed")
except Exception as e:
error_msg = f"Error during example images download: {str(e)}"
logger.error(error_msg, exc_info=True)
download_progress['errors'].append(error_msg)
download_progress['last_error'] = error_msg
download_progress['status'] = 'error'
download_progress['end_time'] = time.time()
finally:
# Close the independent session
try:
await independent_session.close()
except Exception as e:
logger.error(f"Error closing download session: {e}")
# Save final progress to file
try:
DownloadManager._save_progress(output_dir)
except Exception as e:
logger.error(f"Failed to save progress file: {e}")
# Set download status to not downloading
is_downloading = False
@staticmethod
async def _process_model(scanner_type, model, scanner, output_dir, optimize, independent_session):
"""Process a single model download"""
global download_progress
# Check if download is paused
while download_progress['status'] == 'paused':
await asyncio.sleep(1)
# Check if download should continue
if download_progress['status'] != 'running':
logger.info(f"Download stopped: {download_progress['status']}")
return False # Return False to indicate no remote download happened
model_hash = model.get('sha256', '').lower()
model_name = model.get('model_name', 'Unknown')
model_file_path = model.get('file_path', '')
model_file_name = model.get('file_name', '')
try:
# Update current model info
download_progress['current_model'] = f"{model_name} ({model_hash[:8]})"
# Skip if already in failed models
if model_hash in download_progress['failed_models']:
logger.debug(f"Skipping known failed model: {model_name}")
return False
# Skip if already processed AND directory exists with files
if model_hash in download_progress['processed_models']:
model_dir = os.path.join(output_dir, model_hash)
has_files = os.path.exists(model_dir) and any(os.listdir(model_dir))
if has_files:
logger.debug(f"Skipping already processed model: {model_name}")
return False
else:
logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing")
# Remove from processed models since we need to reprocess
download_progress['processed_models'].discard(model_hash)
# Create model directory
model_dir = os.path.join(output_dir, model_hash)
os.makedirs(model_dir, exist_ok=True)
# First check for local example images - local processing doesn't need delay
local_images_processed = await ExampleImagesProcessor.process_local_examples(
model_file_path, model_file_name, model_name, model_dir, optimize
)
# If we processed local images, update metadata
if local_images_processed:
await MetadataUpdater.update_metadata_from_local_examples(
model_hash, model, scanner_type, scanner, model_dir
)
download_progress['processed_models'].add(model_hash)
return False # Return False to indicate no remote download happened
# If no local images, try to download from remote
elif model.get('civitai') and model.get('civitai', {}).get('images'):
images = model.get('civitai', {}).get('images', [])
success, is_stale = await ExampleImagesProcessor.download_model_images(
model_hash, model_name, images, model_dir, optimize, independent_session
)
# If metadata is stale, try to refresh it
if is_stale and model_hash not in download_progress['refreshed_models']:
await MetadataUpdater.refresh_model_metadata(
model_hash, model_name, scanner_type, scanner
)
# Get the updated model data
updated_model = await MetadataUpdater.get_updated_model(
model_hash, scanner
)
if updated_model and updated_model.get('civitai', {}).get('images'):
# Retry download with updated metadata
updated_images = updated_model.get('civitai', {}).get('images', [])
success, _ = await ExampleImagesProcessor.download_model_images(
model_hash, model_name, updated_images, model_dir, optimize, independent_session
)
download_progress['refreshed_models'].add(model_hash)
# Mark as processed if successful, or as failed if unsuccessful after refresh
if success:
download_progress['processed_models'].add(model_hash)
else:
# If we refreshed metadata and still failed, mark as permanently failed
if model_hash in download_progress['refreshed_models']:
download_progress['failed_models'].add(model_hash)
logger.info(f"Marking model {model_name} as failed after metadata refresh")
return True # Return True to indicate a remote download happened
else:
# No civitai data or images available, mark as failed to avoid future attempts
download_progress['failed_models'].add(model_hash)
logger.debug(f"No civitai images available for model {model_name}, marking as failed")
# Save progress periodically
if download_progress['completed'] % 10 == 0 or download_progress['completed'] == download_progress['total'] - 1:
DownloadManager._save_progress(output_dir)
return False # Default return if no conditions met
except Exception as e:
error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
logger.error(error_msg, exc_info=True)
download_progress['errors'].append(error_msg)
download_progress['last_error'] = error_msg
return False # Return False on exception
@staticmethod
def _save_progress(output_dir):
"""Save download progress to file"""
global download_progress
try:
progress_file = os.path.join(output_dir, '.download_progress.json')
# Read existing progress file if it exists
existing_data = {}
if os.path.exists(progress_file):
try:
with open(progress_file, 'r', encoding='utf-8') as f:
existing_data = json.load(f)
except Exception as e:
logger.warning(f"Failed to read existing progress file: {e}")
# Create new progress data
progress_data = {
'processed_models': list(download_progress['processed_models']),
'refreshed_models': list(download_progress['refreshed_models']),
'failed_models': list(download_progress['failed_models']),
'completed': download_progress['completed'],
'total': download_progress['total'],
'last_update': time.time()
}
# Preserve existing fields (especially naming_version)
for key, value in existing_data.items():
if key not in progress_data:
progress_data[key] = value
# Write updated progress data
with open(progress_file, 'w', encoding='utf-8') as f:
json.dump(progress_data, f, indent=2)
except Exception as e:
logger.error(f"Failed to save progress file: {e}")
@staticmethod
async def start_force_download(request):
"""
Force download example images for specific models
Expects a JSON body with:
{
"model_hashes": ["hash1", "hash2", ...], # List of model hashes to download
"output_dir": "path/to/output", # Base directory to save example images
"optimize": true, # Whether to optimize images (default: true)
"model_types": ["lora", "checkpoint"], # Model types to process (default: both)
"delay": 1.0 # Delay between downloads (default: 1.0)
}
"""
global download_task, is_downloading, download_progress
if is_downloading:
return web.json_response({
'success': False,
'error': 'Download already in progress'
}, status=400)
try:
# Parse the request body
data = await request.json()
model_hashes = data.get('model_hashes', [])
output_dir = data.get('output_dir')
optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
if not model_hashes:
return web.json_response({
'success': False,
'error': 'Missing model_hashes parameter'
}, status=400)
if not output_dir:
return web.json_response({
'success': False,
'error': 'Missing output_dir parameter'
}, status=400)
# Create the output directory
os.makedirs(output_dir, exist_ok=True)
# Initialize progress tracking
download_progress['total'] = len(model_hashes)
download_progress['completed'] = 0
download_progress['current_model'] = ''
download_progress['status'] = 'running'
download_progress['errors'] = []
download_progress['last_error'] = None
download_progress['start_time'] = time.time()
download_progress['end_time'] = None
download_progress['processed_models'] = set()
download_progress['refreshed_models'] = set()
download_progress['failed_models'] = set()
# Set download status to downloading
is_downloading = True
# Execute the download function directly instead of creating a background task
result = await DownloadManager._download_specific_models_example_images_sync(
model_hashes,
output_dir,
optimize,
model_types,
delay
)
# Set download status to not downloading
is_downloading = False
return web.json_response({
'success': True,
'message': 'Force download completed',
'result': result
})
except Exception as e:
# Set download status to not downloading
is_downloading = False
logger.error(f"Failed during forced example images download: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay):
"""Download example images for specific models only - synchronous version"""
global download_progress
# Create independent download session
connector = aiohttp.TCPConnector(
ssl=True,
limit=3,
force_close=False,
enable_cleanup_closed=True
)
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
independent_session = aiohttp.ClientSession(
connector=connector,
trust_env=True,
timeout=timeout
)
try:
# Get scanners
scanners = []
if 'lora' in model_types:
lora_scanner = await ServiceRegistry.get_lora_scanner()
scanners.append(('lora', lora_scanner))
if 'checkpoint' in model_types:
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
scanners.append(('checkpoint', checkpoint_scanner))
if 'embedding' in model_types:
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
scanners.append(('embedding', embedding_scanner))
# Find the specified models
models_to_process = []
for scanner_type, scanner in scanners:
cache = await scanner.get_cached_data()
if cache and cache.raw_data:
for model in cache.raw_data:
if model.get('sha256') in model_hashes:
models_to_process.append((scanner_type, model, scanner))
# Update total count based on found models
download_progress['total'] = len(models_to_process)
logger.debug(f"Found {download_progress['total']} models to process")
# Send initial progress via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': 0,
'total': download_progress['total'],
'status': 'running',
'current_model': ''
})
# Process each model
success_count = 0
for i, (scanner_type, model, scanner) in enumerate(models_to_process):
# Force process this model regardless of previous status
was_successful = await DownloadManager._process_specific_model(
scanner_type, model, scanner,
output_dir, optimize, independent_session
)
if was_successful:
success_count += 1
# Update progress
download_progress['completed'] += 1
# Send progress update via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': download_progress['completed'],
'total': download_progress['total'],
'status': 'running',
'current_model': download_progress['current_model']
})
# Only add delay after remote download, and not after processing the last model
if was_successful and i < len(models_to_process) - 1 and download_progress['status'] == 'running':
await asyncio.sleep(delay)
# Mark as completed
download_progress['status'] = 'completed'
download_progress['end_time'] = time.time()
logger.debug(f"Forced example images download completed: {download_progress['completed']}/{download_progress['total']} models processed")
# Send final progress via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': download_progress['completed'],
'total': download_progress['total'],
'status': 'completed',
'current_model': ''
})
return {
'total': download_progress['total'],
'processed': download_progress['completed'],
'successful': success_count,
'errors': download_progress['errors']
}
except Exception as e:
error_msg = f"Error during forced example images download: {str(e)}"
logger.error(error_msg, exc_info=True)
download_progress['errors'].append(error_msg)
download_progress['last_error'] = error_msg
download_progress['status'] = 'error'
download_progress['end_time'] = time.time()
# Send error status via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': download_progress['completed'],
'total': download_progress['total'],
'status': 'error',
'error': error_msg,
'current_model': ''
})
raise
finally:
# Close the independent session
try:
await independent_session.close()
except Exception as e:
logger.error(f"Error closing download session: {e}")
@staticmethod
async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, independent_session):
"""Process a specific model for forced download, ignoring previous download status"""
global download_progress
# Check if download is paused
while download_progress['status'] == 'paused':
await asyncio.sleep(1)
# Check if download should continue
if download_progress['status'] != 'running':
logger.info(f"Download stopped: {download_progress['status']}")
return False
model_hash = model.get('sha256', '').lower()
model_name = model.get('model_name', 'Unknown')
model_file_path = model.get('file_path', '')
model_file_name = model.get('file_name', '')
try:
# Update current model info
download_progress['current_model'] = f"{model_name} ({model_hash[:8]})"
# Create model directory
model_dir = os.path.join(output_dir, model_hash)
os.makedirs(model_dir, exist_ok=True)
# First check for local example images - local processing doesn't need delay
local_images_processed = await ExampleImagesProcessor.process_local_examples(
model_file_path, model_file_name, model_name, model_dir, optimize
)
# If we processed local images, update metadata
if local_images_processed:
await MetadataUpdater.update_metadata_from_local_examples(
model_hash, model, scanner_type, scanner, model_dir
)
download_progress['processed_models'].add(model_hash)
return False # Return False to indicate no remote download happened
# If no local images, try to download from remote
elif model.get('civitai') and model.get('civitai', {}).get('images'):
images = model.get('civitai', {}).get('images', [])
success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
model_hash, model_name, images, model_dir, optimize, independent_session
)
# If metadata is stale, try to refresh it
if is_stale and model_hash not in download_progress['refreshed_models']:
await MetadataUpdater.refresh_model_metadata(
model_hash, model_name, scanner_type, scanner
)
# Get the updated model data
updated_model = await MetadataUpdater.get_updated_model(
model_hash, scanner
)
if updated_model and updated_model.get('civitai', {}).get('images'):
# Retry download with updated metadata
updated_images = updated_model.get('civitai', {}).get('images', [])
success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
model_hash, model_name, updated_images, model_dir, optimize, independent_session
)
# Combine failed images from both attempts
failed_images.extend(additional_failed_images)
download_progress['refreshed_models'].add(model_hash)
# For forced downloads, remove failed images from metadata
if failed_images:
# Create a copy of images excluding failed ones
await DownloadManager._remove_failed_images_from_metadata(
model_hash, model_name, failed_images, scanner
)
# Mark as processed
if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones
download_progress['processed_models'].add(model_hash)
return True # Return True to indicate a remote download happened
else:
logger.debug(f"No civitai images available for model {model_name}")
return False
except Exception as e:
error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
logger.error(error_msg, exc_info=True)
download_progress['errors'].append(error_msg)
download_progress['last_error'] = error_msg
return False # Return False on exception
@staticmethod
async def _remove_failed_images_from_metadata(model_hash, model_name, failed_images, scanner):
"""Remove failed images from model metadata"""
try:
# Get current model data
model_data = await MetadataUpdater.get_updated_model(model_hash, scanner)
if not model_data:
logger.warning(f"Could not find model data for {model_name} to remove failed images")
return
if not model_data.get('civitai', {}).get('images'):
logger.warning(f"No images in metadata for {model_name}")
return
# Get current images
current_images = model_data['civitai']['images']
# Filter out failed images
updated_images = [img for img in current_images if img.get('url') not in failed_images]
# If images were removed, update metadata
if len(updated_images) < len(current_images):
removed_count = len(current_images) - len(updated_images)
logger.info(f"Removing {removed_count} failed images from metadata for {model_name}")
# Update the images list
model_data['civitai']['images'] = updated_images
# Save metadata to file
file_path = model_data.get('file_path')
if file_path:
# Create a copy of model data without 'folder' field
model_copy = model_data.copy()
model_copy.pop('folder', None)
# Write metadata to file
await MetadataManager.save_metadata(file_path, model_copy)
logger.info(f"Saved updated metadata for {model_name} after removing failed images")
# Update the scanner cache
await scanner.update_single_model_cache(file_path, file_path, model_data)
except Exception as e:
logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True)

View File

@@ -0,0 +1,209 @@
import logging
import os
import re
import sys
import subprocess
from aiohttp import web
from ..services.settings_manager import settings
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
logger = logging.getLogger(__name__)
class ExampleImagesFileManager:
"""Manages access and operations for example image files"""
@staticmethod
async def open_folder(request):
"""
Open the example images folder for a specific model
Expects a JSON request body with:
{
"model_hash": "sha256_hash" # SHA256 hash of the model
}
"""
try:
# Parse request body
data = await request.json()
model_hash = data.get('model_hash')
if not model_hash:
return web.json_response({
'success': False,
'error': 'Missing model_hash parameter'
}, status=400)
# Get example images path from settings
example_images_path = settings.get('example_images_path')
if not example_images_path:
return web.json_response({
'success': False,
'error': 'No example images path configured. Please set it in the settings panel first.'
}, status=400)
# Construct folder path for this model
model_folder = os.path.join(example_images_path, model_hash)
model_folder = os.path.abspath(model_folder) # Get absolute path
# Path validation: ensure model_folder is under example_images_path
if not model_folder.startswith(os.path.abspath(example_images_path)):
return web.json_response({
'success': False,
'error': 'Invalid model folder path'
}, status=400)
# Check if folder exists
if not os.path.exists(model_folder):
return web.json_response({
'success': False,
'error': 'No example images found for this model. Download example images first.'
}, status=404)
# Open folder in file explorer
if os.name == 'nt': # Windows
os.startfile(model_folder)
elif os.name == 'posix': # macOS and Linux
if sys.platform == 'darwin': # macOS
subprocess.Popen(['open', model_folder])
else: # Linux
subprocess.Popen(['xdg-open', model_folder])
return web.json_response({
'success': True,
'message': f'Opened example images folder for model {model_hash}'
})
except Exception as e:
logger.error(f"Failed to open example images folder: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def get_files(request):
"""
Get the list of example image files for a specific model
Expects:
- model_hash in query parameters
Returns:
- List of image files and their paths
"""
try:
# Get model_hash from query parameters
model_hash = request.query.get('model_hash')
if not model_hash:
return web.json_response({
'success': False,
'error': 'Missing model_hash parameter'
}, status=400)
# Get example images path from settings
example_images_path = settings.get('example_images_path')
if not example_images_path:
return web.json_response({
'success': False,
'error': 'No example images path configured'
}, status=400)
# Construct folder path for this model
model_folder = os.path.join(example_images_path, model_hash)
# Check if folder exists
if not os.path.exists(model_folder):
return web.json_response({
'success': False,
'error': 'No example images found for this model',
'files': []
}, status=404)
# Get list of files in the folder
files = []
for file in os.listdir(model_folder):
file_path = os.path.join(model_folder, file)
if os.path.isfile(file_path):
# Check if file is a supported media file
file_ext = os.path.splitext(file)[1].lower()
if (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']):
files.append({
'name': file,
'path': f'/example_images_static/{model_hash}/{file}',
'extension': file_ext,
'is_video': file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']
})
return web.json_response({
'success': True,
'files': files
})
except Exception as e:
logger.error(f"Failed to get example image files: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def has_images(request):
"""
Check if the example images folder for a model exists and is not empty
Expects:
- model_hash in query parameters
Returns:
- Boolean indicating whether the folder exists and contains images/videos
"""
try:
# Get model_hash from query parameters
model_hash = request.query.get('model_hash')
if not model_hash:
return web.json_response({
'success': False,
'error': 'Missing model_hash parameter'
}, status=400)
# Get example images path from settings
example_images_path = settings.get('example_images_path')
if not example_images_path:
return web.json_response({
'has_images': False
})
# Construct folder path for this model
model_folder = os.path.join(example_images_path, model_hash)
# Check if folder exists
if not os.path.exists(model_folder) or not os.path.isdir(model_folder):
return web.json_response({
'has_images': False
})
# Check if folder contains any supported media files
for file in os.listdir(model_folder):
file_path = os.path.join(model_folder, file)
if os.path.isfile(file_path):
file_ext = os.path.splitext(file)[1].lower()
if (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']):
return web.json_response({
'has_images': True
})
# If reached here, folder exists but has no supported media files
return web.json_response({
'has_images': False
})
except Exception as e:
logger.error(f"Failed to check example images folder: {e}", exc_info=True)
return web.json_response({
'has_images': False,
'error': str(e)
})

View File

@@ -0,0 +1,390 @@
import logging
import os
import re
from ..utils.metadata_manager import MetadataManager
from ..utils.routes_common import ModelRouteUtils
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
from ..utils.exif_utils import ExifUtils
from ..recipes.constants import GEN_PARAM_KEYS
logger = logging.getLogger(__name__)
class MetadataUpdater:
"""Handles updating model metadata related to example images"""
@staticmethod
async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner):
"""Refresh model metadata from CivitAI
Args:
model_hash: SHA256 hash of the model
model_name: Model name (for logging)
scanner_type: Scanner type ('lora' or 'checkpoint')
scanner: Scanner instance for this model type
Returns:
bool: True if metadata was successfully refreshed, False otherwise
"""
from ..utils.example_images_download_manager import download_progress
try:
# Find the model in the scanner cache
cache = await scanner.get_cached_data()
model_data = None
for item in cache.raw_data:
if item.get('sha256') == model_hash:
model_data = item
break
if not model_data:
logger.warning(f"Model {model_name} with hash {model_hash} not found in cache")
return False
file_path = model_data.get('file_path')
if not file_path:
logger.warning(f"Model {model_name} has no file path")
return False
# Track that we're refreshing this model
download_progress['refreshed_models'].add(model_hash)
# Use ModelRouteUtils to refresh metadata
async def update_cache_func(old_path, new_path, metadata):
return await scanner.update_single_model_cache(old_path, new_path, metadata)
success = await ModelRouteUtils.fetch_and_update_model(
model_hash,
file_path,
model_data,
update_cache_func
)
if success:
logger.info(f"Successfully refreshed metadata for {model_name}")
return True
else:
logger.warning(f"Failed to refresh metadata for {model_name}")
return False
except Exception as e:
error_msg = f"Error refreshing metadata for {model_name}: {str(e)}"
logger.error(error_msg, exc_info=True)
download_progress['errors'].append(error_msg)
download_progress['last_error'] = error_msg
return False
@staticmethod
async def get_updated_model(model_hash, scanner):
"""Get updated model data
Args:
model_hash: SHA256 hash of the model
scanner: Scanner instance
Returns:
dict: Updated model data or None if not found
"""
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('sha256') == model_hash:
return item
return None
@staticmethod
async def update_metadata_from_local_examples(model_hash, model, scanner_type, scanner, model_dir):
"""Update model metadata with local example image information
Args:
model_hash: SHA256 hash of the model
model: Model data dictionary
scanner_type: Scanner type ('lora' or 'checkpoint')
scanner: Scanner instance for this model type
model_dir: Model images directory
Returns:
bool: True if metadata was successfully updated, False otherwise
"""
try:
# Collect local image paths
local_images_paths = []
if os.path.exists(model_dir):
for file in os.listdir(model_dir):
file_path = os.path.join(model_dir, file)
if os.path.isfile(file_path):
file_ext = os.path.splitext(file)[1].lower()
is_supported = (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos'])
if is_supported:
local_images_paths.append(file_path)
# Check if metadata update is needed (no civitai field or empty images)
needs_update = not model.get('civitai') or not model.get('civitai', {}).get('images')
if needs_update and local_images_paths:
logger.debug(f"Found {len(local_images_paths)} local example images for {model.get('model_name')}, updating metadata")
# Create or get civitai field
if not model.get('civitai'):
model['civitai'] = {}
# Create images array
images = []
# Generate metadata for each local image/video
for path in local_images_paths:
# Determine if video or image
file_ext = os.path.splitext(path)[1].lower()
is_video = file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']
# Create image metadata entry
image_entry = {
"url": "", # Empty URL as required
"nsfwLevel": 0,
"width": 720, # Default dimensions
"height": 1280,
"type": "video" if is_video else "image",
"meta": None,
"hasMeta": False,
"hasPositivePrompt": False
}
# If it's an image, try to get actual dimensions (optional enhancement)
try:
from PIL import Image
if not is_video and os.path.exists(path):
with Image.open(path) as img:
image_entry["width"], image_entry["height"] = img.size
except:
# If PIL fails or is unavailable, use default dimensions
pass
images.append(image_entry)
# Update the model's civitai.images field
model['civitai']['images'] = images
# Save metadata to .metadata.json file
file_path = model.get('file_path')
try:
# Create a copy of model data without 'folder' field
model_copy = model.copy()
model_copy.pop('folder', None)
# Write metadata to file
await MetadataManager.save_metadata(file_path, model_copy)
logger.info(f"Saved metadata for {model.get('model_name')}")
except Exception as e:
logger.error(f"Failed to save metadata for {model.get('model_name')}: {str(e)}")
# Save updated metadata to scanner cache
success = await scanner.update_single_model_cache(file_path, file_path, model)
if success:
logger.info(f"Successfully updated metadata for {model.get('model_name')} with {len(images)} local examples")
return True
else:
logger.warning(f"Failed to update metadata for {model.get('model_name')}")
return False
except Exception as e:
logger.error(f"Error updating metadata from local examples: {str(e)}", exc_info=True)
return False
@staticmethod
async def update_metadata_after_import(model_hash, model_data, scanner, newly_imported_paths):
"""Update model metadata after importing example images
Args:
model_hash: SHA256 hash of the model
model_data: Model data dictionary
scanner: Scanner instance (lora or checkpoint)
newly_imported_paths: List of paths to newly imported files
Returns:
tuple: (regular_images, custom_images) - Both image arrays
"""
try:
# Ensure civitai field exists in model_data
if not model_data.get('civitai'):
model_data['civitai'] = {}
# Ensure customImages array exists
if not model_data['civitai'].get('customImages'):
model_data['civitai']['customImages'] = []
# Get current customImages array
custom_images = model_data['civitai']['customImages']
# Add new image entry for each imported file
for path_tuple in newly_imported_paths:
path, short_id = path_tuple
# Determine if video or image
file_ext = os.path.splitext(path)[1].lower()
is_video = file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']
# Create image metadata entry
image_entry = {
"url": "", # Empty URL as requested
"id": short_id,
"nsfwLevel": 0,
"width": 720, # Default dimensions
"height": 1280,
"type": "video" if is_video else "image",
"meta": None,
"hasMeta": False,
"hasPositivePrompt": False
}
# Extract and parse metadata if this is an image
if not is_video:
try:
# Extract metadata from image
extracted_metadata = ExifUtils.extract_image_metadata(path)
if extracted_metadata:
# Parse the extracted metadata to get generation parameters
parsed_meta = MetadataUpdater._parse_image_metadata(extracted_metadata)
if parsed_meta:
image_entry["meta"] = parsed_meta
image_entry["hasMeta"] = True
image_entry["hasPositivePrompt"] = bool(parsed_meta.get("prompt", ""))
logger.debug(f"Extracted metadata from {os.path.basename(path)}")
except Exception as e:
logger.warning(f"Failed to extract metadata from {os.path.basename(path)}: {e}")
# If it's an image, try to get actual dimensions
try:
from PIL import Image
if not is_video and os.path.exists(path):
with Image.open(path) as img:
image_entry["width"], image_entry["height"] = img.size
except:
# If PIL fails or is unavailable, use default dimensions
pass
# Append to existing customImages array
custom_images.append(image_entry)
# Save metadata to .metadata.json file
file_path = model_data.get('file_path')
if file_path:
try:
# Create a copy of model data without 'folder' field
model_copy = model_data.copy()
model_copy.pop('folder', None)
# Write metadata to file
await MetadataManager.save_metadata(file_path, model_copy)
logger.info(f"Saved metadata for {model_data.get('model_name')}")
except Exception as e:
logger.error(f"Failed to save metadata: {str(e)}")
# Save updated metadata to scanner cache
if file_path:
await scanner.update_single_model_cache(file_path, file_path, model_data)
# Get regular images array (might be None)
regular_images = model_data['civitai'].get('images', [])
# Return both image arrays
return regular_images, custom_images
except Exception as e:
logger.error(f"Failed to update metadata after import: {e}", exc_info=True)
return [], []
@staticmethod
def _parse_image_metadata(user_comment):
"""Parse metadata from image to extract generation parameters
Args:
user_comment: Metadata string extracted from image
Returns:
dict: Parsed metadata with generation parameters
"""
if not user_comment:
return None
try:
# Initialize metadata dictionary
metadata = {}
# Split on Negative prompt if it exists
if "Negative prompt:" in user_comment:
parts = user_comment.split('Negative prompt:', 1)
prompt = parts[0].strip()
negative_and_params = parts[1] if len(parts) > 1 else ""
else:
# No negative prompt section
param_start = re.search(r'Steps: \d+', user_comment)
if param_start:
prompt = user_comment[:param_start.start()].strip()
negative_and_params = user_comment[param_start.start():]
else:
prompt = user_comment.strip()
negative_and_params = ""
# Add prompt if it's in GEN_PARAM_KEYS
if 'prompt' in GEN_PARAM_KEYS:
metadata['prompt'] = prompt
# Extract negative prompt and parameters
if negative_and_params:
# If we split on "Negative prompt:", check for params section
if "Negative prompt:" in user_comment:
param_start = re.search(r'Steps: ', negative_and_params)
if param_start:
neg_prompt = negative_and_params[:param_start.start()].strip()
if 'negative_prompt' in GEN_PARAM_KEYS:
metadata['negative_prompt'] = neg_prompt
params_section = negative_and_params[param_start.start():]
else:
if 'negative_prompt' in GEN_PARAM_KEYS:
metadata['negative_prompt'] = negative_and_params.strip()
params_section = ""
else:
# No negative prompt, entire section is params
params_section = negative_and_params
# Extract generation parameters
if params_section:
# Extract basic parameters
param_pattern = r'([A-Za-z\s]+): ([^,]+)'
params = re.findall(param_pattern, params_section)
for key, value in params:
clean_key = key.strip().lower().replace(' ', '_')
# Skip if not in recognized gen param keys
if clean_key not in GEN_PARAM_KEYS:
continue
# Convert numeric values
if clean_key in ['steps', 'seed']:
try:
metadata[clean_key] = int(value.strip())
except ValueError:
metadata[clean_key] = value.strip()
elif clean_key in ['cfg_scale']:
try:
metadata[clean_key] = float(value.strip())
except ValueError:
metadata[clean_key] = value.strip()
else:
metadata[clean_key] = value.strip()
# Extract size if available and add if a recognized key
size_match = re.search(r'Size: (\d+)x(\d+)', params_section)
if size_match and 'size' in GEN_PARAM_KEYS:
width, height = size_match.groups()
metadata['size'] = f"{width}x{height}"
# Return metadata if we have any entries
return metadata if metadata else None
except Exception as e:
logger.error(f"Error parsing image metadata: {e}", exc_info=True)
return None

View File

@@ -0,0 +1,318 @@
import asyncio
import logging
import os
import re
import json
from ..services.settings_manager import settings
from ..services.service_registry import ServiceRegistry
from ..utils.metadata_manager import MetadataManager
from ..utils.example_images_processor import ExampleImagesProcessor
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
logger = logging.getLogger(__name__)
CURRENT_NAMING_VERSION = 2 # Increment this when naming conventions change
class ExampleImagesMigration:
"""Handles migrations for example images naming conventions"""
@staticmethod
async def check_and_run_migrations():
"""Check if migrations are needed and run them in background"""
example_images_path = settings.get('example_images_path')
if not example_images_path or not os.path.exists(example_images_path):
logger.debug("No example images path configured or path doesn't exist, skipping migrations")
return
# Check current version from progress file
current_version = 0
progress_file = os.path.join(example_images_path, '.download_progress.json')
if os.path.exists(progress_file):
try:
with open(progress_file, 'r', encoding='utf-8') as f:
progress_data = json.load(f)
current_version = progress_data.get('naming_version', 0)
except Exception as e:
logger.error(f"Failed to load progress file for migration check: {e}")
# If current version is less than target version, start migration
if current_version < CURRENT_NAMING_VERSION:
logger.info(f"Starting example images naming migration from v{current_version} to v{CURRENT_NAMING_VERSION}")
# Start migration in background task
asyncio.create_task(
ExampleImagesMigration.run_migrations(example_images_path, current_version, CURRENT_NAMING_VERSION)
)
@staticmethod
async def run_migrations(example_images_path, from_version, to_version):
"""Run necessary migrations based on version difference"""
try:
# Get all model folders
model_folders = []
for item in os.listdir(example_images_path):
item_path = os.path.join(example_images_path, item)
if os.path.isdir(item_path) and len(item) == 64: # SHA256 hash is 64 chars
model_folders.append(item_path)
logger.info(f"Found {len(model_folders)} model folders to check for migration")
# Apply migrations sequentially
if from_version < 1 and to_version >= 1:
await ExampleImagesMigration._migrate_to_v1(model_folders)
if from_version < 2 and to_version >= 2:
await ExampleImagesMigration._migrate_to_v2(model_folders)
# Update version in progress file
progress_file = os.path.join(example_images_path, '.download_progress.json')
try:
progress_data = {}
if os.path.exists(progress_file):
with open(progress_file, 'r', encoding='utf-8') as f:
progress_data = json.load(f)
progress_data['naming_version'] = to_version
with open(progress_file, 'w', encoding='utf-8') as f:
json.dump(progress_data, f, indent=2)
logger.info(f"Example images naming migration to v{to_version} completed")
except Exception as e:
logger.error(f"Failed to update version in progress file: {e}")
except Exception as e:
logger.error(f"Error during migration: {e}", exc_info=True)
@staticmethod
async def _migrate_to_v1(model_folders):
"""Migrate from 1-based to 0-based indexing"""
count = 0
for folder in model_folders:
has_one_based = False
has_zero_based = False
files_to_rename = []
# Check naming pattern in this folder
for file in os.listdir(folder):
if re.match(r'image_1\.\w+$', file):
has_one_based = True
if re.match(r'image_0\.\w+$', file):
has_zero_based = True
# Only migrate folders with 1-based indexing and no 0-based
if has_one_based and not has_zero_based:
# Create rename mapping
for file in os.listdir(folder):
match = re.match(r'image_(\d+)\.(\w+)$', file)
if match:
index = int(match.group(1))
ext = match.group(2)
if index > 0: # Only rename if index is positive
files_to_rename.append((
file,
f"image_{index-1}.{ext}"
))
# Use temporary names to avoid conflicts
for old_name, new_name in files_to_rename:
old_path = os.path.join(folder, old_name)
temp_path = os.path.join(folder, f"temp_{old_name}")
try:
os.rename(old_path, temp_path)
except Exception as e:
logger.error(f"Failed to rename {old_path} to {temp_path}: {e}")
# Rename from temporary names to final names
for old_name, new_name in files_to_rename:
temp_path = os.path.join(folder, f"temp_{old_name}")
new_path = os.path.join(folder, new_name)
try:
os.rename(temp_path, new_path)
logger.debug(f"Renamed {old_name} to {new_name} in {folder}")
except Exception as e:
logger.error(f"Failed to rename {temp_path} to {new_path}: {e}")
count += 1
# Give other tasks a chance to run
if count % 10 == 0:
await asyncio.sleep(0)
logger.info(f"Migrated {count} folders from 1-based to 0-based indexing")
@staticmethod
async def _migrate_to_v2(model_folders):
"""
Migrate to v2 naming scheme:
- Move custom examples from images array to customImages array
- Rename files from image_<index>.<ext> to custom_<short_id>.<ext>
- Add id field to each custom image entry
"""
count = 0
updated_models = 0
migration_errors = 0
# Get scanner instances
lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
# Wait until scanners are initialized
scanners = [lora_scanner, checkpoint_scanner]
for scanner in scanners:
if scanner.is_initializing():
logger.info("Waiting for scanners to complete initialization before starting migration...")
initialized = False
retry_count = 0
while not initialized and retry_count < 120: # Wait up to 120 seconds
await asyncio.sleep(1)
initialized = not scanner.is_initializing()
retry_count += 1
if not initialized:
logger.warning("Scanner initialization timeout - proceeding with migration anyway")
logger.info(f"Starting migration to v2 naming scheme for {len(model_folders)} model folders")
for folder in model_folders:
try:
# Extract model hash from folder name
model_hash = os.path.basename(folder)
if not model_hash or len(model_hash) != 64:
continue
# Find the model in scanner cache
model_data = None
scanner = None
for scan_obj in scanners:
if scan_obj.has_hash(model_hash):
cache = await scan_obj.get_cached_data()
for item in cache.raw_data:
if item.get('sha256') == model_hash:
model_data = item
scanner = scan_obj
break
if model_data:
break
if not model_data or not scanner:
logger.debug(f"Model with hash {model_hash} not found in cache, skipping migration")
continue
# Clone model data to avoid modifying the cache directly
model_metadata = model_data.copy()
# Check if model has civitai metadata
if not model_metadata.get('civitai'):
continue
# Get images array
images = model_metadata.get('civitai', {}).get('images', [])
if not images:
continue
# Initialize customImages array if it doesn't exist
if not model_metadata['civitai'].get('customImages'):
model_metadata['civitai']['customImages'] = []
# Find custom examples (entries with empty url)
custom_indices = []
for i, image in enumerate(images):
if image.get('url') == "":
custom_indices.append(i)
if not custom_indices:
continue
logger.debug(f"Found {len(custom_indices)} custom examples in {model_hash}")
# Process each custom example
for index in custom_indices:
try:
image_entry = images[index]
# Determine media type based on the entry type
media_type = 'videos' if image_entry.get('type') == 'video' else 'images'
extensions_to_try = SUPPORTED_MEDIA_EXTENSIONS[media_type]
# Find the image file by trying possible extensions
old_path = None
old_filename = None
found = False
for ext in extensions_to_try:
test_path = os.path.join(folder, f"image_{index}{ext}")
if os.path.exists(test_path):
old_path = test_path
old_filename = f"image_{index}{ext}"
found = True
break
if not found:
logger.warning(f"Could not find file for index {index} in {model_hash}, skipping")
continue
# Generate short ID for the custom example
short_id = ExampleImagesProcessor.generate_short_id()
# Get file extension
file_ext = os.path.splitext(old_path)[1]
# Create new filename
new_filename = f"custom_{short_id}{file_ext}"
new_path = os.path.join(folder, new_filename)
# Rename the file
try:
os.rename(old_path, new_path)
logger.debug(f"Renamed {old_filename} to {new_filename} in {folder}")
except Exception as e:
logger.error(f"Failed to rename {old_path} to {new_path}: {e}")
continue
# Create a copy of the image entry with the id field
custom_entry = image_entry.copy()
custom_entry['id'] = short_id
# Add to customImages array
model_metadata['civitai']['customImages'].append(custom_entry)
count += 1
except Exception as e:
logger.error(f"Error migrating custom example at index {index} for {model_hash}: {e}")
# Remove custom examples from the original images array
model_metadata['civitai']['images'] = [
img for i, img in enumerate(images) if i not in custom_indices
]
# Save the updated metadata
file_path = model_data.get('file_path')
if file_path:
try:
# Create a copy of model data without 'folder' field
model_copy = model_metadata.copy()
model_copy.pop('folder', None)
# Save metadata to file
await MetadataManager.save_metadata(file_path, model_copy)
# Update scanner cache
await scanner.update_single_model_cache(file_path, file_path, model_metadata)
updated_models += 1
except Exception as e:
logger.error(f"Failed to save metadata for {model_hash}: {e}")
migration_errors += 1
# Give other tasks a chance to run
if count % 10 == 0:
await asyncio.sleep(0)
except Exception as e:
logger.error(f"Error migrating folder {folder}: {e}")
migration_errors += 1
logger.info(f"Migration to v2 complete: migrated {count} custom examples across {updated_models} models with {migration_errors} errors")

View File

@@ -0,0 +1,568 @@
import logging
import os
import re
import tempfile
import random
import string
from aiohttp import web
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings
from .example_images_metadata import MetadataUpdater
from ..utils.metadata_manager import MetadataManager
logger = logging.getLogger(__name__)
class ExampleImagesProcessor:
"""Processes and manipulates example images"""
@staticmethod
def generate_short_id(length=8):
"""Generate a short random alphanumeric identifier"""
chars = string.ascii_lowercase + string.digits
return ''.join(random.choice(chars) for _ in range(length))
@staticmethod
def get_civitai_optimized_url(image_url):
"""Convert Civitai image URL to its optimized WebP version"""
base_pattern = r'(https://image\.civitai\.com/[^/]+/[^/]+)'
match = re.match(base_pattern, image_url)
if match:
base_url = match.group(1)
return f"{base_url}/optimized=true/image.webp"
return image_url
@staticmethod
async def download_model_images(model_hash, model_name, model_images, model_dir, optimize, independent_session):
"""Download images for a single model
Returns:
tuple: (success, is_stale_metadata) - whether download was successful, whether metadata is stale
"""
model_success = True
for i, image in enumerate(model_images):
image_url = image.get('url')
if not image_url:
continue
# Get image filename from URL
image_filename = os.path.basename(image_url.split('?')[0])
image_ext = os.path.splitext(image_filename)[1].lower()
# Handle images and videos
is_image = image_ext in SUPPORTED_MEDIA_EXTENSIONS['images']
is_video = image_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']
if not (is_image or is_video):
logger.debug(f"Skipping unsupported file type: {image_filename}")
continue
# Use 0-based indexing instead of 1-based indexing
save_filename = f"image_{i}{image_ext}"
# If optimizing images and this is a Civitai image, use their pre-optimized WebP version
if is_image and optimize and 'civitai.com' in image_url:
image_url = ExampleImagesProcessor.get_civitai_optimized_url(image_url)
save_filename = f"image_{i}.webp"
# Check if already downloaded
save_path = os.path.join(model_dir, save_filename)
if os.path.exists(save_path):
logger.debug(f"File already exists: {save_path}")
continue
# Download the file
try:
logger.debug(f"Downloading {save_filename} for {model_name}")
# Download directly using the independent session
async with independent_session.get(image_url, timeout=60) as response:
if response.status == 200:
with open(save_path, 'wb') as f:
async for chunk in response.content.iter_chunked(8192):
if chunk:
f.write(chunk)
elif response.status == 404:
error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale"
logger.warning(error_msg)
model_success = False # Mark the model as failed due to 404 error
# Return early to trigger metadata refresh attempt
return False, True # (success, is_metadata_stale)
else:
error_msg = f"Failed to download file: {image_url}, status code: {response.status}"
logger.warning(error_msg)
model_success = False # Mark the model as failed
except Exception as e:
error_msg = f"Error downloading file {image_url}: {str(e)}"
logger.error(error_msg)
model_success = False # Mark the model as failed
return model_success, False # (success, is_metadata_stale)
@staticmethod
async def download_model_images_with_tracking(model_hash, model_name, model_images, model_dir, optimize, independent_session):
"""Download images for a single model with tracking of failed image URLs
Returns:
tuple: (success, is_stale_metadata, failed_images) - whether download was successful, whether metadata is stale, list of failed image URLs
"""
model_success = True
failed_images = []
for i, image in enumerate(model_images):
image_url = image.get('url')
if not image_url:
continue
# Get image filename from URL
image_filename = os.path.basename(image_url.split('?')[0])
image_ext = os.path.splitext(image_filename)[1].lower()
# Handle images and videos
is_image = image_ext in SUPPORTED_MEDIA_EXTENSIONS['images']
is_video = image_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']
if not (is_image or is_video):
logger.debug(f"Skipping unsupported file type: {image_filename}")
continue
# Use 0-based indexing instead of 1-based indexing
save_filename = f"image_{i}{image_ext}"
# If optimizing images and this is a Civitai image, use their pre-optimized WebP version
if is_image and optimize and 'civitai.com' in image_url:
image_url = ExampleImagesProcessor.get_civitai_optimized_url(image_url)
save_filename = f"image_{i}.webp"
# Check if already downloaded
save_path = os.path.join(model_dir, save_filename)
if os.path.exists(save_path):
logger.debug(f"File already exists: {save_path}")
continue
# Download the file
try:
logger.debug(f"Downloading {save_filename} for {model_name}")
# Download directly using the independent session
async with independent_session.get(image_url, timeout=60) as response:
if response.status == 200:
with open(save_path, 'wb') as f:
async for chunk in response.content.iter_chunked(8192):
if chunk:
f.write(chunk)
elif response.status == 404:
error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale"
logger.warning(error_msg)
model_success = False # Mark the model as failed due to 404 error
failed_images.append(image_url) # Track failed URL
# Return early to trigger metadata refresh attempt
return False, True, failed_images # (success, is_metadata_stale, failed_images)
else:
error_msg = f"Failed to download file: {image_url}, status code: {response.status}"
logger.warning(error_msg)
model_success = False # Mark the model as failed
failed_images.append(image_url) # Track failed URL
except Exception as e:
error_msg = f"Error downloading file {image_url}: {str(e)}"
logger.error(error_msg)
model_success = False # Mark the model as failed
failed_images.append(image_url) # Track failed URL
return model_success, False, failed_images # (success, is_metadata_stale, failed_images)
@staticmethod
async def process_local_examples(model_file_path, model_file_name, model_name, model_dir, optimize):
"""Process local example images
Returns:
bool: True if local images were processed successfully, False otherwise
"""
try:
if not model_file_path or not os.path.exists(os.path.dirname(model_file_path)):
return False
model_dir_path = os.path.dirname(model_file_path)
local_images = []
# Look for files with pattern: filename.example.*.ext
if model_file_name:
example_prefix = f"{model_file_name}.example."
if os.path.exists(model_dir_path):
for file in os.listdir(model_dir_path):
file_lower = file.lower()
if file_lower.startswith(example_prefix.lower()):
file_ext = os.path.splitext(file_lower)[1]
is_supported = (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos'])
if is_supported:
local_images.append(os.path.join(model_dir_path, file))
# Process local images if found
if local_images:
logger.info(f"Found {len(local_images)} local example images for {model_name}")
for local_image_path in local_images:
# Extract index from filename
file_name = os.path.basename(local_image_path)
example_prefix = f"{model_file_name}.example."
try:
# Extract the part between '.example.' and the file extension
index_part = file_name[len(example_prefix):].split('.')[0]
# Try to parse it as an integer
index = int(index_part)
local_ext = os.path.splitext(local_image_path)[1].lower()
save_filename = f"image_{index}{local_ext}"
except (ValueError, IndexError):
# If we can't parse the index, fall back to sequential numbering
logger.warning(f"Could not extract index from {file_name}, using sequential numbering")
local_ext = os.path.splitext(local_image_path)[1].lower()
save_filename = f"image_{len(local_images)}{local_ext}"
save_path = os.path.join(model_dir, save_filename)
# Skip if already exists in output directory
if os.path.exists(save_path):
logger.debug(f"File already exists in output: {save_path}")
continue
# Copy the file
with open(local_image_path, 'rb') as src_file:
with open(save_path, 'wb') as dst_file:
dst_file.write(src_file.read())
return True
return False
except Exception as e:
logger.error(f"Error processing local examples for {model_name}: {str(e)}")
return False
@staticmethod
async def import_images(request):
"""
Import local example images
Accepts:
- multipart/form-data form with model_hash and files fields
or
- JSON request with model_hash and file_paths
Returns:
- Success status and list of imported files
"""
try:
model_hash = None
files_to_import = []
temp_files_to_cleanup = []
# Check if it's a multipart form-data request (direct file upload)
if request.content_type and 'multipart/form-data' in request.content_type:
reader = await request.multipart()
# First get model_hash
field = await reader.next()
if field.name == 'model_hash':
model_hash = await field.text()
# Then process all files
while True:
field = await reader.next()
if field is None:
break
if field.name == 'files':
# Create a temporary file with appropriate suffix for type detection
file_name = field.filename
file_ext = os.path.splitext(file_name)[1].lower()
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
temp_path = tmp_file.name
temp_files_to_cleanup.append(temp_path) # Track for cleanup
# Write chunks to the temporary file
while True:
chunk = await field.read_chunk()
if not chunk:
break
tmp_file.write(chunk)
# Add to the list of files to process
files_to_import.append(temp_path)
else:
# Parse JSON request (legacy method using file paths)
data = await request.json()
model_hash = data.get('model_hash')
files_to_import = data.get('file_paths', [])
if not model_hash:
return web.json_response({
'success': False,
'error': 'Missing model_hash parameter'
}, status=400)
if not files_to_import:
return web.json_response({
'success': False,
'error': 'No files provided to import'
}, status=400)
# Get example images path
example_images_path = settings.get('example_images_path')
if not example_images_path:
return web.json_response({
'success': False,
'error': 'No example images path configured'
}, status=400)
# Find the model and get current metadata
lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
model_data = None
scanner = None
# Check both scanners to find the model
for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]:
cache = await scan_obj.get_cached_data()
for item in cache.raw_data:
if item.get('sha256') == model_hash:
model_data = item
scanner = scan_obj
break
if model_data:
break
if not model_data:
return web.json_response({
'success': False,
'error': f"Model with hash {model_hash} not found in cache"
}, status=404)
# Create model folder
model_folder = os.path.join(example_images_path, model_hash)
os.makedirs(model_folder, exist_ok=True)
imported_files = []
errors = []
newly_imported_paths = []
# Process each file path
for file_path in files_to_import:
try:
# Ensure the file exists
if not os.path.isfile(file_path):
errors.append(f"File not found: {file_path}")
continue
# Check if file type is supported
file_ext = os.path.splitext(file_path)[1].lower()
if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']):
errors.append(f"Unsupported file type: {file_path}")
continue
# Generate new filename using short ID instead of UUID
short_id = ExampleImagesProcessor.generate_short_id()
new_filename = f"custom_{short_id}{file_ext}"
dest_path = os.path.join(model_folder, new_filename)
# Copy the file
import shutil
shutil.copy2(file_path, dest_path)
# Store both the dest_path and the short_id
newly_imported_paths.append((dest_path, short_id))
# Add to imported files list
imported_files.append({
'name': new_filename,
'path': f'/example_images_static/{model_hash}/{new_filename}',
'extension': file_ext,
'is_video': file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']
})
except Exception as e:
errors.append(f"Error importing {file_path}: {str(e)}")
# Update metadata with new example images
regular_images, custom_images = await MetadataUpdater.update_metadata_after_import(
model_hash,
model_data,
scanner,
newly_imported_paths
)
return web.json_response({
'success': len(imported_files) > 0,
'message': f'Successfully imported {len(imported_files)} files' +
(f' with {len(errors)} errors' if errors else ''),
'files': imported_files,
'errors': errors,
'regular_images': regular_images,
'custom_images': custom_images,
"model_file_path": model_data.get('file_path', ''),
})
except Exception as e:
logger.error(f"Failed to import example images: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
finally:
# Clean up temporary files
for temp_file in temp_files_to_cleanup:
try:
os.remove(temp_file)
except Exception as e:
logger.error(f"Failed to remove temporary file {temp_file}: {e}")
@staticmethod
async def delete_custom_image(request):
"""
Delete a custom example image for a model
Accepts:
- JSON request with model_hash and short_id
Returns:
- Success status and updated image lists
"""
try:
# Parse request data
data = await request.json()
model_hash = data.get('model_hash')
short_id = data.get('short_id')
if not model_hash or not short_id:
return web.json_response({
'success': False,
'error': 'Missing required parameters: model_hash and short_id'
}, status=400)
# Get example images path
example_images_path = settings.get('example_images_path')
if not example_images_path:
return web.json_response({
'success': False,
'error': 'No example images path configured'
}, status=400)
# Find the model and get current metadata
lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
model_data = None
scanner = None
# Check both scanners to find the model
for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]:
if scan_obj.has_hash(model_hash):
cache = await scan_obj.get_cached_data()
for item in cache.raw_data:
if item.get('sha256') == model_hash:
model_data = item
scanner = scan_obj
break
if model_data:
break
if not model_data:
return web.json_response({
'success': False,
'error': f"Model with hash {model_hash} not found in cache"
}, status=404)
# Check if model has custom images
if not model_data.get('civitai', {}).get('customImages'):
return web.json_response({
'success': False,
'error': f"Model has no custom images"
}, status=404)
# Find the custom image with matching short_id
custom_images = model_data['civitai']['customImages']
matching_image = None
new_custom_images = []
for image in custom_images:
if image.get('id') == short_id:
matching_image = image
else:
new_custom_images.append(image)
if not matching_image:
return web.json_response({
'success': False,
'error': f"Custom image with id {short_id} not found"
}, status=404)
# Find and delete the actual file
model_folder = os.path.join(example_images_path, model_hash)
file_deleted = False
if os.path.exists(model_folder):
for filename in os.listdir(model_folder):
if f"custom_{short_id}" in filename:
file_path = os.path.join(model_folder, filename)
try:
os.remove(file_path)
file_deleted = True
logger.info(f"Deleted custom example file: {file_path}")
break
except Exception as e:
return web.json_response({
'success': False,
'error': f"Failed to delete file: {str(e)}"
}, status=500)
if not file_deleted:
logger.warning(f"File for custom example with id {short_id} not found, but metadata will still be updated")
# Update metadata
model_data['civitai']['customImages'] = new_custom_images
# Save updated metadata to file
file_path = model_data.get('file_path')
if file_path:
try:
# Create a copy of model data without 'folder' field
model_copy = model_data.copy()
model_copy.pop('folder', None)
# Write metadata to file
await MetadataManager.save_metadata(file_path, model_copy)
logger.debug(f"Saved updated metadata for {model_data.get('model_name')}")
except Exception as e:
logger.error(f"Failed to save metadata: {str(e)}")
return web.json_response({
'success': False,
'error': f"Failed to save metadata: {str(e)}"
}, status=500)
# Update cache
await scanner.update_single_model_cache(file_path, file_path, model_data)
# Get regular images array (might be None)
regular_images = model_data['civitai'].get('images', [])
return web.json_response({
'success': True,
'regular_images': regular_images,
'custom_images': new_custom_images,
'model_file_path': model_data.get('file_path', '')
})
except Exception as e:
logger.error(f"Failed to delete custom example image: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)

View File

@@ -31,7 +31,7 @@ class ExifUtils:
# Method 2: Check EXIF UserComment field
if img.format not in ['JPEG', 'TIFF', 'WEBP']:
# For non-JPEG/TIFF/WEBP images, try to get EXIF through PIL
exif = img._getexif()
exif = img.getexif()
if exif and piexif.ExifIFD.UserComment in exif:
user_comment = exif[piexif.ExifIFD.UserComment]
if isinstance(user_comment, bytes):
@@ -147,7 +147,7 @@ class ExifUtils:
"file_name": lora.get("file_name", ""),
"hash": lora.get("hash", "").lower() if lora.get("hash") else "",
"strength": float(lora.get("strength", 1.0)),
"modelVersionId": lora.get("modelVersionId", ""),
"modelVersionId": lora.get("modelVersionId", 0),
"modelName": lora.get("modelName", ""),
"modelVersionName": lora.get("modelVersionName", ""),
}

View File

@@ -1,13 +1,7 @@
import logging
import os
import hashlib
import json
import time
from typing import Dict, Optional, Type
from .model_utils import determine_base_model
from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata
from .models import BaseModelMetadata, LoraMetadata, CheckpointMetadata
from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from .exif_utils import ExifUtils
@@ -24,7 +18,12 @@ async def calculate_sha256(file_path: str) -> str:
def find_preview_file(base_name: str, dir_path: str) -> str:
"""Find preview file for given base name in directory"""
for ext in PREVIEW_EXTENSIONS:
temp_extensions = PREVIEW_EXTENSIONS.copy()
# Add example extension for compatibility
# https://github.com/willmiao/ComfyUI-Lora-Manager/issues/225
# The preview image will be optimized to lora-name.webp, so it won't affect other logic
temp_extensions.append(".example.0.jpeg")
for ext in temp_extensions:
full_pattern = os.path.join(dir_path, f"{base_name}{ext}")
if os.path.exists(full_pattern):
# Check if this is an image and not already webp
@@ -42,7 +41,7 @@ def find_preview_file(base_name: str, dir_path: str) -> str:
target_width=CARD_PREVIEW_WIDTH,
format='webp',
quality=85,
preserve_metadata=False # Changed from True to False
preserve_metadata=False
)
# Save the optimized webp file
@@ -63,199 +62,4 @@ def find_preview_file(base_name: str, dir_path: str) -> str:
def normalize_path(path: str) -> str:
"""Normalize file path to use forward slashes"""
return path.replace(os.sep, "/") if path else path
async def get_file_info(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
"""Get basic file information as a model metadata object"""
# First check if file actually exists and resolve symlinks
try:
real_path = os.path.realpath(file_path)
if not os.path.exists(real_path):
return None
except Exception as e:
logger.error(f"Error checking file existence for {file_path}: {e}")
return None
base_name = os.path.splitext(os.path.basename(file_path))[0]
dir_path = os.path.dirname(file_path)
preview_url = find_preview_file(base_name, dir_path)
# Check if a .json file exists with SHA256 hash to avoid recalculation
json_path = f"{os.path.splitext(file_path)[0]}.json"
sha256 = None
if os.path.exists(json_path):
try:
with open(json_path, 'r', encoding='utf-8') as f:
json_data = json.load(f)
if 'sha256' in json_data:
sha256 = json_data['sha256'].lower()
logger.debug(f"Using SHA256 from .json file for {file_path}")
except Exception as e:
logger.error(f"Error reading .json file for {file_path}: {e}")
# If SHA256 is still not found, check for a .sha256 file
if sha256 is None:
sha256_file = f"{os.path.splitext(file_path)[0]}.sha256"
if os.path.exists(sha256_file):
try:
with open(sha256_file, 'r', encoding='utf-8') as f:
sha256 = f.read().strip().lower()
logger.debug(f"Using SHA256 from .sha256 file for {file_path}")
except Exception as e:
logger.error(f"Error reading .sha256 file for {file_path}: {e}")
try:
# If we didn't get SHA256 from the .json file, calculate it
if not sha256:
start_time = time.time()
sha256 = await calculate_sha256(real_path)
logger.debug(f"Calculated SHA256 for {file_path} in {time.time() - start_time:.2f} seconds")
# Create default metadata based on model class
if model_class == CheckpointMetadata:
metadata = CheckpointMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
preview_url=normalize_path(preview_url),
tags=[],
modelDescription="",
model_type="checkpoint"
)
# Extract checkpoint-specific metadata
# model_info = await extract_checkpoint_metadata(real_path)
# metadata.base_model = model_info['base_model']
# if 'model_type' in model_info:
# metadata.model_type = model_info['model_type']
else: # Default to LoraMetadata
metadata = LoraMetadata(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=os.path.getmtime(real_path),
sha256=sha256,
base_model="Unknown", # Will be updated later
usage_tips="{}",
preview_url=normalize_path(preview_url),
tags=[],
modelDescription=""
)
# Extract lora-specific metadata
model_info = await extract_lora_metadata(real_path)
metadata.base_model = model_info['base_model']
# Save metadata to file
await save_metadata(file_path, metadata)
return metadata
except Exception as e:
logger.error(f"Error getting file info for {file_path}: {e}")
return None
async def save_metadata(file_path: str, metadata: BaseModelMetadata) -> None:
"""Save metadata to .metadata.json file"""
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
try:
metadata_dict = metadata.to_dict()
metadata_dict['file_path'] = normalize_path(metadata_dict['file_path'])
metadata_dict['preview_url'] = normalize_path(metadata_dict['preview_url'])
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata_dict, f, indent=2, ensure_ascii=False)
except Exception as e:
print(f"Error saving metadata to {metadata_path}: {str(e)}")
async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
"""Load metadata from .metadata.json file"""
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
try:
if os.path.exists(metadata_path):
with open(metadata_path, 'r', encoding='utf-8') as f:
data = json.load(f)
needs_update = False
# Check and normalize base model name
normalized_base_model = determine_base_model(data['base_model'])
if data['base_model'] != normalized_base_model:
data['base_model'] = normalized_base_model
needs_update = True
# Compare paths without extensions
stored_path_base = os.path.splitext(data['file_path'])[0]
current_path_base = os.path.splitext(normalize_path(file_path))[0]
if stored_path_base != current_path_base:
data['file_path'] = normalize_path(file_path)
needs_update = True
# TODO: optimize preview image to webp format if not already done
preview_url = data.get('preview_url', '')
if not preview_url or not os.path.exists(preview_url):
base_name = os.path.splitext(os.path.basename(file_path))[0]
dir_path = os.path.dirname(file_path)
new_preview_url = normalize_path(find_preview_file(base_name, dir_path))
if new_preview_url != preview_url:
data['preview_url'] = new_preview_url
needs_update = True
else:
# Compare preview paths without extensions
stored_preview_base = os.path.splitext(preview_url)[0]
current_preview_base = os.path.splitext(normalize_path(preview_url))[0]
if stored_preview_base != current_preview_base:
data['preview_url'] = normalize_path(preview_url)
needs_update = True
# Ensure all fields are present
if 'tags' not in data:
data['tags'] = []
needs_update = True
if 'modelDescription' not in data:
data['modelDescription'] = ""
needs_update = True
# For checkpoint metadata
if model_class == CheckpointMetadata and 'model_type' not in data:
data['model_type'] = "checkpoint"
needs_update = True
# For lora metadata
if model_class == LoraMetadata and 'usage_tips' not in data:
data['usage_tips'] = "{}"
needs_update = True
# Update preview_nsfw_level if needed
civitai_data = data.get('civitai', {})
civitai_images = civitai_data.get('images', []) if civitai_data else []
if (data.get('preview_url') and
data.get('preview_nsfw_level', 0) == 0 and
civitai_images and
civitai_images[0].get('nsfwLevel', 0) != 0):
data['preview_nsfw_level'] = civitai_images[0]['nsfwLevel']
# TODO: write to metadata file
# needs_update = True
if needs_update:
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(data, f, indent=2, ensure_ascii=False)
return model_class.from_dict(data)
except Exception as e:
print(f"Error loading metadata from {metadata_path}: {str(e)}")
return None
async def update_civitai_metadata(file_path: str, civitai_data: Dict) -> None:
"""Update metadata file with Civitai data"""
metadata = await load_metadata(file_path)
metadata['civitai'] = civitai_data
await save_metadata(file_path, metadata)
return path.replace(os.sep, "/") if path else path

View File

@@ -1,8 +1,9 @@
from safetensors import safe_open
from typing import Dict
from typing import Dict, List, Tuple
from .model_utils import determine_base_model
import os
import logging
import json
logger = logging.getLogger(__name__)
@@ -80,4 +81,53 @@ async def extract_checkpoint_metadata(file_path: str) -> dict:
except Exception as e:
logger.error(f"Error extracting checkpoint metadata for {file_path}: {e}")
# Return default values
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}
return {'base_model': 'Unknown', 'model_type': 'checkpoint'}
async def extract_trained_words(file_path: str) -> Tuple[List[Tuple[str, int]], str]:
"""Extract trained words from a safetensors file and sort by frequency
Args:
file_path: Path to the safetensors file
Returns:
Tuple of:
- List of (word, frequency) tuples sorted by frequency (highest first)
- class_tokens value (or None if not found)
"""
class_tokens = None
try:
with safe_open(file_path, framework="pt", device="cpu") as f:
metadata = f.metadata()
# Extract class_tokens from ss_datasets if present
if metadata and "ss_datasets" in metadata:
try:
datasets_data = json.loads(metadata["ss_datasets"])
# Look for class_tokens in the first subset
if datasets_data and isinstance(datasets_data, list) and datasets_data[0].get("subsets"):
subsets = datasets_data[0].get("subsets", [])
if subsets and isinstance(subsets, list) and len(subsets) > 0:
class_tokens = subsets[0].get("class_tokens")
except Exception as e:
logger.error(f"Error parsing ss_datasets for class_tokens: {str(e)}")
# Extract tag frequency as before
if metadata and "ss_tag_frequency" in metadata:
# Parse the JSON string into a dictionary
tag_data = json.loads(metadata["ss_tag_frequency"])
# The structure may have an outer key (like "image_dir" or "img")
# We need to get the inner dictionary with the actual word frequencies
if tag_data:
# Get the first key (usually "image_dir" or "img")
first_key = list(tag_data.keys())[0]
words_dict = tag_data[first_key]
# Sort words by frequency (highest first)
sorted_words = sorted(words_dict.items(), key=lambda x: x[1], reverse=True)
return sorted_words, class_tokens
except Exception as e:
logger.error(f"Error extracting trained words from {file_path}: {str(e)}")
return [], class_tokens

View File

@@ -0,0 +1,313 @@
from datetime import datetime
import os
import json
import shutil
import logging
from typing import Dict, Optional, Type, Union
from .models import BaseModelMetadata, LoraMetadata
from .file_utils import normalize_path, find_preview_file, calculate_sha256
from .lora_metadata import extract_lora_metadata, extract_checkpoint_metadata
logger = logging.getLogger(__name__)
class MetadataManager:
"""
Centralized manager for all metadata operations.
This class is responsible for:
1. Loading metadata safely with fallback mechanisms
2. Saving metadata with atomic operations and backups
3. Creating default metadata for models
4. Handling unknown fields gracefully
"""
@staticmethod
async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
"""
Load metadata with robust error handling and data preservation.
Args:
file_path: Path to the model file
model_class: Class to instantiate (LoraMetadata, CheckpointMetadata, etc.)
Returns:
BaseModelMetadata instance or None if file doesn't exist
"""
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
backup_path = f"{metadata_path}.bak"
# Try loading the main metadata file
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# Create model instance
metadata = model_class.from_dict(data)
# Normalize paths
await MetadataManager._normalize_metadata_paths(metadata, file_path)
return metadata
except json.JSONDecodeError:
# JSON parsing error - try to restore from backup
logger.warning(f"Invalid JSON in metadata file: {metadata_path}")
return await MetadataManager._restore_from_backup(backup_path, file_path, model_class)
except Exception as e:
# Other errors might be due to unknown fields or schema changes
logger.error(f"Error loading metadata from {metadata_path}: {str(e)}")
return await MetadataManager._restore_from_backup(backup_path, file_path, model_class)
return None
@staticmethod
async def _restore_from_backup(backup_path: str, file_path: str, model_class: Type[BaseModelMetadata]) -> Optional[BaseModelMetadata]:
"""
Try to restore metadata from backup file
Args:
backup_path: Path to backup file
file_path: Path to the original model file
model_class: Class to instantiate
Returns:
BaseModelMetadata instance or None if restoration fails
"""
if os.path.exists(backup_path):
try:
logger.info(f"Attempting to restore metadata from backup: {backup_path}")
with open(backup_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# Process data similarly to normal loading
metadata = model_class.from_dict(data)
await MetadataManager._normalize_metadata_paths(metadata, file_path)
return metadata
except Exception as e:
logger.error(f"Failed to restore from backup: {str(e)}")
return None
@staticmethod
async def save_metadata(path: str, metadata: Union[BaseModelMetadata, Dict], create_backup: bool = False) -> bool:
"""
Save metadata with atomic write operations and backup creation.
Args:
path: Path to the model file or directly to the metadata file
metadata: Metadata to save (either BaseModelMetadata object or dict)
create_backup: Whether to create a new backup of existing file if a backup doesn't already exist
Returns:
bool: Success or failure
"""
# Determine if the input is a metadata path or a model file path
if path.endswith('.metadata.json'):
metadata_path = path
else:
# Use existing logic for model file paths
file_path = path
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
temp_path = f"{metadata_path}.tmp"
backup_path = f"{metadata_path}.bak"
try:
# Create backup if file exists and either:
# 1. create_backup is True, OR
# 2. backup file doesn't already exist
if os.path.exists(metadata_path) and (create_backup or not os.path.exists(backup_path)):
try:
shutil.copy2(metadata_path, backup_path)
logger.debug(f"Created metadata backup at: {backup_path}")
except Exception as e:
logger.warning(f"Failed to create metadata backup: {str(e)}")
# Convert to dict if needed
if isinstance(metadata, BaseModelMetadata):
metadata_dict = metadata.to_dict()
# Preserve unknown fields if present
if hasattr(metadata, '_unknown_fields'):
metadata_dict.update(metadata._unknown_fields)
else:
metadata_dict = metadata.copy()
# Normalize paths
if 'file_path' in metadata_dict:
metadata_dict['file_path'] = normalize_path(metadata_dict['file_path'])
if 'preview_url' in metadata_dict:
metadata_dict['preview_url'] = normalize_path(metadata_dict['preview_url'])
# Write to temporary file first
with open(temp_path, 'w', encoding='utf-8') as f:
json.dump(metadata_dict, f, indent=2, ensure_ascii=False)
# Atomic rename operation
os.replace(temp_path, metadata_path)
return True
except Exception as e:
logger.error(f"Error saving metadata to {metadata_path}: {str(e)}")
# Clean up temporary file if it exists
if os.path.exists(temp_path):
try:
os.remove(temp_path)
except:
pass
return False
@staticmethod
async def create_default_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
"""
Create basic metadata structure for a model file.
This replaces the old get_file_info function with a more appropriately named method.
Args:
file_path: Path to the model file
model_class: Class to instantiate
Returns:
BaseModelMetadata instance or None if file doesn't exist
"""
# First check if file actually exists and resolve symlinks
try:
real_path = os.path.realpath(file_path)
if not os.path.exists(real_path):
return None
except Exception as e:
logger.error(f"Error checking file existence for {file_path}: {e}")
return None
try:
base_name = os.path.splitext(os.path.basename(file_path))[0]
dir_path = os.path.dirname(file_path)
# Find preview image
preview_url = find_preview_file(base_name, dir_path)
# Calculate file hash
sha256 = await calculate_sha256(real_path)
# Create instance based on model type
if model_class.__name__ == "CheckpointMetadata":
metadata = model_class(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=datetime.now().timestamp(),
sha256=sha256,
base_model="Unknown",
preview_url=normalize_path(preview_url),
tags=[],
modelDescription="",
model_type="checkpoint",
from_civitai=True
)
elif model_class.__name__ == "EmbeddingMetadata":
metadata = model_class(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=datetime.now().timestamp(),
sha256=sha256,
base_model="Unknown",
preview_url=normalize_path(preview_url),
tags=[],
modelDescription="",
from_civitai=True
)
else: # Default to LoraMetadata
metadata = model_class(
file_name=base_name,
model_name=base_name,
file_path=normalize_path(file_path),
size=os.path.getsize(real_path),
modified=datetime.now().timestamp(),
sha256=sha256,
base_model="Unknown",
preview_url=normalize_path(preview_url),
tags=[],
modelDescription="",
from_civitai=True,
usage_tips="{}"
)
# Try to extract model-specific metadata
# await MetadataManager._enrich_metadata(metadata, real_path)
# Save the created metadata
await MetadataManager.save_metadata(file_path, metadata, create_backup=False)
return metadata
except Exception as e:
logger.error(f"Error creating default metadata for {file_path}: {e}")
return None
@staticmethod
async def _enrich_metadata(metadata: BaseModelMetadata, file_path: str) -> None:
"""
Enrich metadata with model-specific information
Args:
metadata: Metadata to enrich
file_path: Path to the model file
"""
try:
if metadata.__class__.__name__ == "LoraMetadata":
model_info = await extract_lora_metadata(file_path)
metadata.base_model = model_info['base_model']
# elif metadata.__class__.__name__ == "CheckpointMetadata":
# model_info = await extract_checkpoint_metadata(file_path)
# metadata.base_model = model_info['base_model']
# if 'model_type' in model_info:
# metadata.model_type = model_info['model_type']
except Exception as e:
logger.error(f"Error enriching metadata: {str(e)}")
@staticmethod
async def _normalize_metadata_paths(metadata: BaseModelMetadata, file_path: str) -> None:
"""
Normalize paths in metadata object
Args:
metadata: Metadata object to update
file_path: Current file path for the model
"""
need_update = False
# Check if file_name matches the actual file name
base_name = os.path.splitext(os.path.basename(file_path))[0]
if metadata.file_name != base_name:
metadata.file_name = base_name
need_update = True
# Check if file path is different from what's in metadata
if normalize_path(file_path) != metadata.file_path:
metadata.file_path = normalize_path(file_path)
need_update = True
# Check if preview exists at the current location
preview_url = metadata.preview_url
if preview_url:
# Get directory parts of both paths
file_dir = os.path.dirname(file_path)
preview_dir = os.path.dirname(preview_url)
# Update preview if it doesn't exist OR if model and preview are in different directories
if not os.path.exists(preview_url) or file_dir != preview_dir:
base_name = os.path.splitext(os.path.basename(file_path))[0]
dir_path = os.path.dirname(file_path)
new_preview_url = find_preview_file(base_name, dir_path)
if new_preview_url:
metadata.preview_url = normalize_path(new_preview_url)
need_update = True
# If path attributes were changed, save the metadata back to disk
if need_update:
await MetadataManager.save_metadata(file_path, metadata, create_backup=False)

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass, asdict
from typing import Dict, Optional, List
from dataclasses import dataclass, asdict, field
from typing import Dict, Optional, List, Any
from datetime import datetime
import os
from .model_utils import determine_base_model
@@ -11,7 +11,7 @@ class BaseModelMetadata:
model_name: str # The model's name defined by the creator
file_path: str # Full path to the model file
size: int # File size in bytes
modified: float # Last modified timestamp
modified: float # Timestamp when the model was added to the management system
sha256: str # SHA256 hash of the file
base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.)
preview_url: str # Preview image URL
@@ -24,6 +24,7 @@ class BaseModelMetadata:
civitai_deleted: bool = False # Whether deleted from Civitai
favorite: bool = False # Whether the model is a favorite
exclude: bool = False # Whether to exclude this model from the cache
_unknown_fields: Dict[str, Any] = field(default_factory=dict, repr=False, compare=False) # Store unknown fields
def __post_init__(self):
# Initialize empty lists to avoid mutable default parameter issue
@@ -34,16 +35,43 @@ class BaseModelMetadata:
def from_dict(cls, data: Dict) -> 'BaseModelMetadata':
"""Create instance from dictionary"""
data_copy = data.copy()
return cls(**data_copy)
# Use cached fields if available, otherwise compute them
if not hasattr(cls, '_known_fields_cache'):
known_fields = set()
for c in cls.mro():
if hasattr(c, '__annotations__'):
known_fields.update(c.__annotations__.keys())
cls._known_fields_cache = known_fields
known_fields = cls._known_fields_cache
# Extract fields that match our class attributes
fields_to_use = {k: v for k, v in data_copy.items() if k in known_fields}
# Store unknown fields separately
unknown_fields = {k: v for k, v in data_copy.items() if k not in known_fields and not k.startswith('_')}
# Create instance with known fields
instance = cls(**fields_to_use)
# Add unknown fields as a separate attribute
instance._unknown_fields = unknown_fields
return instance
def to_dict(self) -> Dict:
"""Convert to dictionary for JSON serialization"""
return asdict(self)
@property
def modified_datetime(self) -> datetime:
"""Convert modified timestamp to datetime object"""
return datetime.fromtimestamp(self.modified)
result = asdict(self)
# Remove private fields
result = {k: v for k, v in result.items() if not k.startswith('_')}
# Add back unknown fields if they exist
if hasattr(self, '_unknown_fields'):
result.update(self._unknown_fields)
return result
def update_civitai_info(self, civitai_data: Dict) -> None:
"""Update Civitai information"""
@@ -95,7 +123,7 @@ class LoraMetadata(BaseModelMetadata):
@dataclass
class CheckpointMetadata(BaseModelMetadata):
"""Represents the metadata structure for a Checkpoint model"""
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
model_type: str = "checkpoint" # Model type (checkpoint, diffusion_model, etc.)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata':
@@ -130,3 +158,41 @@ class CheckpointMetadata(BaseModelMetadata):
modelDescription=description
)
@dataclass
class EmbeddingMetadata(BaseModelMetadata):
"""Represents the metadata structure for an Embedding model"""
model_type: str = "embedding" # Model type (embedding, textual_inversion, etc.)
@classmethod
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'EmbeddingMetadata':
"""Create EmbeddingMetadata instance from Civitai version info"""
file_name = file_info['name']
base_model = determine_base_model(version_info.get('baseModel', ''))
model_type = version_info.get('type', 'embedding')
# Extract tags and description if available
tags = []
description = ""
if 'model' in version_info:
if 'tags' in version_info['model']:
tags = version_info['model']['tags']
if 'description' in version_info['model']:
description = version_info['model']['description']
return cls(
file_name=os.path.splitext(file_name)[0],
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
file_path=save_path.replace(os.sep, '/'),
size=file_info.get('sizeKB', 0) * 1024,
modified=datetime.now().timestamp(),
sha256=file_info['hashes'].get('SHA256', '').lower(),
base_model=base_model,
preview_url=None, # Will be updated after preview download
preview_nsfw_level=0,
from_civitai=True,
civitai=version_info,
model_type=model_type,
tags=tags,
modelDescription=description
)

View File

@@ -8,8 +8,11 @@ from .model_utils import determine_base_model
from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
from ..config import config
from ..services.civitai_client import CivitaiClient
from ..services.service_registry import ServiceRegistry
from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager
from ..services.download_manager import DownloadManager
from ..services.websocket_manager import ws_manager
logger = logging.getLogger(__name__)
@@ -32,27 +35,60 @@ class ModelRouteUtils:
async def handle_not_found_on_civitai(metadata_path: str, local_metadata: Dict) -> None:
"""Handle case when model is not found on CivitAI"""
local_metadata['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
await MetadataManager.save_metadata(metadata_path, local_metadata)
@staticmethod
async def update_model_metadata(metadata_path: str, local_metadata: Dict,
civitai_metadata: Dict, client: CivitaiClient) -> None:
"""Update local metadata with CivitAI data"""
local_metadata['civitai'] = civitai_metadata
# Save existing trainedWords and customImages if they exist
existing_civitai = local_metadata.get('civitai') or {} # Use empty dict if None
# Create a new civitai metadata by updating existing with new
merged_civitai = existing_civitai.copy()
merged_civitai.update(civitai_metadata)
# Special handling for trainedWords - ensure we don't lose any existing trained words
if 'trainedWords' in existing_civitai:
existing_trained_words = existing_civitai.get('trainedWords', [])
new_trained_words = civitai_metadata.get('trainedWords', [])
# Use a set to combine words without duplicates, then convert back to list
merged_trained_words = list(set(existing_trained_words + new_trained_words))
merged_civitai['trainedWords'] = merged_trained_words
# Update local metadata with merged civitai data
local_metadata['civitai'] = merged_civitai
local_metadata['from_civitai'] = True
# Update model name if available
if 'model' in civitai_metadata:
if civitai_metadata.get('model', {}).get('name'):
local_metadata['model_name'] = civitai_metadata['model']['name']
# Fetch additional model metadata (description and tags) if we have model ID
model_id = civitai_metadata['modelId']
if model_id:
model_metadata, _ = await client.get_model_metadata(str(model_id))
if model_metadata:
local_metadata['modelDescription'] = model_metadata.get('description', '')
local_metadata['tags'] = model_metadata.get('tags', [])
# Extract model metadata directly from civitai_metadata if available
model_metadata = None
if 'model' in civitai_metadata and civitai_metadata.get('model'):
# Data is already available in the response from get_model_version
model_metadata = {
'description': civitai_metadata.get('model', {}).get('description', ''),
'tags': civitai_metadata.get('model', {}).get('tags', []),
'creator': civitai_metadata.get('creator', {})
}
# If we have modelId and don't have enough metadata, fetch additional data
if not model_metadata or not model_metadata.get('description'):
model_id = civitai_metadata.get('modelId')
if model_id:
fetched_metadata, _ = await client.get_model_metadata(str(model_id))
if fetched_metadata:
model_metadata = fetched_metadata
# Update local metadata with the model information
if model_metadata:
local_metadata['modelDescription'] = model_metadata.get('description', '')
local_metadata['tags'] = model_metadata.get('tags', [])
if 'creator' in model_metadata and model_metadata['creator']:
local_metadata['civitai']['creator'] = model_metadata['creator']
# Update base model
@@ -61,7 +97,7 @@ class ModelRouteUtils:
# Update preview if needed
if not local_metadata.get('preview_url') or not os.path.exists(local_metadata['preview_url']):
first_preview = next((img for img in civitai_metadata.get('images', [])), None)
if first_preview:
if (first_preview):
# Determine if content is video or image
is_video = first_preview['type'] == 'video'
@@ -120,8 +156,7 @@ class ModelRouteUtils:
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
await MetadataManager.save_metadata(metadata_path, local_metadata, True)
@staticmethod
async def fetch_and_update_model(
@@ -143,6 +178,11 @@ class ModelRouteUtils:
"""
client = CivitaiClient()
try:
# Validate input parameters
if not isinstance(model_data, dict):
logger.error(f"Invalid model_data type: {type(model_data)}")
return False
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Check if model metadata exists
@@ -154,8 +194,7 @@ class ModelRouteUtils:
# Mark as not from CivitAI if not found
local_metadata['from_civitai'] = False
model_data['from_civitai'] = False
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(local_metadata, f, indent=2, ensure_ascii=False)
await MetadataManager.save_metadata(file_path, local_metadata)
return False
# Update metadata
@@ -166,21 +205,25 @@ class ModelRouteUtils:
client
)
# Update cache object directly
model_data.update({
# Update cache object directly using safe .get() method
update_dict = {
'model_name': local_metadata.get('model_name'),
'preview_url': local_metadata.get('preview_url'),
'from_civitai': True,
'civitai': civitai_metadata
})
}
model_data.update(update_dict)
# Update cache using the provided function
await update_cache_func(file_path, file_path, local_metadata)
return True
except KeyError as e:
logger.error(f"Error fetching CivitAI data - Missing key: {e} in model_data={model_data}")
return False
except Exception as e:
logger.error(f"Error fetching CivitAI data: {e}")
logger.error(f"Error fetching CivitAI data: {str(e)}", exc_info=True) # Include stack trace
return False
finally:
await client.close()
@@ -194,18 +237,17 @@ class ModelRouteUtils:
fields = [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images"
"model", "images", "customImages", "creator"
]
return {k: data[k] for k in fields if k in data}
@staticmethod
async def delete_model_files(target_dir: str, file_name: str, file_monitor=None) -> List[str]:
async def delete_model_files(target_dir: str, file_name: str) -> List[str]:
"""Delete model and associated files
Args:
target_dir: Directory containing the model files
file_name: Base name of the model file without extension
file_monitor: Optional file monitor to ignore delete events
Returns:
List of deleted file paths
@@ -223,11 +265,7 @@ class ModelRouteUtils:
main_file = patterns[0]
main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
if os.path.exists(main_path):
# Notify file monitor to ignore delete event if available
if file_monitor:
file_monitor.handler.add_ignore_path(main_path, 0)
if os.path.exists(main_path):
# Delete file
os.remove(main_path)
deleted.append(main_path)
@@ -248,10 +286,12 @@ class ModelRouteUtils:
@staticmethod
def get_multipart_ext(filename):
"""Get extension that may have multiple parts like .metadata.json"""
"""Get extension that may have multiple parts like .metadata.json or .metadata.json.bak"""
parts = filename.split(".")
if len(parts) > 2: # If contains multi-part extension
if len(parts) == 3: # If contains 2-part extension
return "." + ".".join(parts[-2:]) # Take the last two parts, like ".metadata.json"
elif len(parts) >= 4: # If contains 3-part or more extensions
return "." + ".".join(parts[-3:]) # Take the last three parts, like ".metadata.json.bak"
return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors"
# New common endpoint handlers
@@ -276,13 +316,9 @@ class ModelRouteUtils:
target_dir = os.path.dirname(file_path)
file_name = os.path.splitext(os.path.basename(file_path))[0]
# Get the file monitor from the scanner if available
file_monitor = getattr(scanner, 'file_monitor', None)
deleted_files = await ModelRouteUtils.delete_model_files(
target_dir,
file_name,
file_monitor
file_name
)
# Remove from cache
@@ -312,7 +348,7 @@ class ModelRouteUtils:
scanner: The model scanner instance with cache management methods
Returns:
web.Response: The HTTP response
web.Response: The HTTP response with metadata on success
"""
try:
data = await request.json()
@@ -337,7 +373,8 @@ class ModelRouteUtils:
# Update the cache
await scanner.update_single_model_cache(data['file_path'], data['file_path'], local_metadata)
return web.json_response({"success": True})
# Return the updated metadata along with success status
return web.json_response({"success": True, "metadata": local_metadata})
finally:
await client.close()
@@ -347,15 +384,7 @@ class ModelRouteUtils:
@staticmethod
async def handle_replace_preview(request: web.Request, scanner) -> web.Response:
"""Handle preview image replacement request
Args:
request: The aiohttp request
scanner: The model scanner instance with methods to update cache
Returns:
web.Response: The HTTP response
"""
"""Handle preview image replacement request"""
try:
reader = await request.multipart()
@@ -364,6 +393,15 @@ class ModelRouteUtils:
if field.name != 'preview_file':
raise ValueError("Expected 'preview_file' field")
content_type = field.headers.get('Content-Type', 'image/png')
# Try to get original filename if available
content_disposition = field.headers.get('Content-Disposition', '')
original_filename = None
import re
filename_match = re.search(r'filename="(.*?)"', content_disposition)
if filename_match:
original_filename = filename_match.group(1)
preview_data = await field.read()
# Read model path
@@ -372,17 +410,47 @@ class ModelRouteUtils:
raise ValueError("Expected 'model_path' field")
model_path = (await field.read()).decode()
# Read NSFW level
nsfw_level = 0 # Default to 0 (unknown)
field = await reader.next()
if field and field.name == 'nsfw_level':
try:
nsfw_level = int((await field.read()).decode())
except (ValueError, TypeError):
logger.warning("Invalid NSFW level format, using default 0")
# Save preview file
base_name = os.path.splitext(os.path.basename(model_path))[0]
folder = os.path.dirname(model_path)
# Determine if content is video or image
# Determine format based on content type and original filename
is_gif = False
if original_filename and original_filename.lower().endswith('.gif'):
is_gif = True
elif content_type.lower() == 'image/gif':
is_gif = True
# Determine if content is video or image and handle specific formats
if content_type.startswith('video/'):
# For videos, keep original format and use .mp4 extension
extension = '.mp4'
# For videos, preserve original format if possible
if original_filename:
extension = os.path.splitext(original_filename)[1].lower()
# Default to .mp4 if no extension or unrecognized
if not extension or extension not in ['.mp4', '.webm', '.mov', '.avi']:
extension = '.mp4'
else:
# Try to determine extension from content type
if 'webm' in content_type:
extension = '.webm'
else:
extension = '.mp4' # Default
optimized_data = preview_data # No optimization for videos
elif is_gif:
# Preserve GIF format without optimization
extension = '.gif'
optimized_data = preview_data
else:
# For images, optimize and convert to WebP
# For other images, optimize and convert to WebP
optimized_data, _ = ExifUtils.optimize_image(
image_data=preview_data,
target_width=CARD_PREVIEW_WIDTH,
@@ -390,35 +458,45 @@ class ModelRouteUtils:
quality=85,
preserve_metadata=False
)
extension = '.webp' # Use .webp without .preview part
extension = '.webp'
# Delete any existing preview files for this model
for ext in PREVIEW_EXTENSIONS:
existing_preview = os.path.join(folder, base_name + ext)
if os.path.exists(existing_preview):
try:
os.remove(existing_preview)
logger.debug(f"Deleted existing preview: {existing_preview}")
except Exception as e:
logger.warning(f"Failed to delete existing preview {existing_preview}: {e}")
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, '/')
with open(preview_path, 'wb') as f:
f.write(optimized_data)
# Update preview path in metadata
# Update preview path and NSFW level in metadata
metadata_path = os.path.splitext(model_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
# Update preview_url directly in the metadata dict
# Update preview_url and preview_nsfw_level in the metadata dict
metadata['preview_url'] = preview_path
metadata['preview_nsfw_level'] = nsfw_level
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
await MetadataManager.save_metadata(model_path, metadata)
except Exception as e:
logger.error(f"Error updating metadata: {e}")
# Update preview URL in scanner cache
if hasattr(scanner, 'update_preview_in_cache'):
await scanner.update_preview_in_cache(model_path, preview_path)
await scanner.update_preview_in_cache(model_path, preview_path, nsfw_level)
return web.json_response({
"success": True,
"preview_url": config.get_preview_static_url(preview_path)
"preview_url": config.get_preview_static_url(preview_path),
"preview_nsfw_level": nsfw_level
})
except Exception as e:
@@ -448,8 +526,7 @@ class ModelRouteUtils:
metadata['exclude'] = True
# Save updated metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
await MetadataManager.save_metadata(file_path, metadata)
# Update cache
cache = await scanner.get_cached_data()
@@ -485,66 +562,77 @@ class ModelRouteUtils:
return web.Response(text=str(e), status=500)
@staticmethod
async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response:
"""Handle model download request
Args:
request: The aiohttp request
download_manager: Instance of DownloadManager
model_type: Type of model ('lora' or 'checkpoint')
Returns:
web.Response: The HTTP response
"""
async def handle_download_model(request: web.Request) -> web.Response:
"""Handle model download request"""
try:
download_manager = await ServiceRegistry.get_download_manager()
data = await request.json()
# Create progress callback
# Get or generate a download ID
download_id = data.get('download_id', ws_manager.generate_download_id())
# Create progress callback with download ID
async def progress_callback(progress):
from ..services.websocket_manager import ws_manager
await ws_manager.broadcast({
await ws_manager.broadcast_download_progress(download_id, {
'status': 'progress',
'progress': progress
'progress': progress,
'download_id': download_id
})
# Check which identifier is provided
download_url = data.get('download_url')
model_hash = data.get('model_hash')
model_version_id = data.get('model_version_id')
# Check which identifier is provided and convert to int
model_id = None
model_version_id = None
# Validate that at least one identifier is provided
if not any([download_url, model_hash, model_version_id]):
return web.Response(
status=400,
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
)
if data.get('model_id'):
try:
model_id = int(data.get('model_id'))
except (TypeError, ValueError):
return web.json_response({
'success': False,
'error': "Invalid model_id: Must be an integer"
}, status=400)
# Convert model_version_id to int if provided
if data.get('model_version_id'):
try:
model_version_id = int(data.get('model_version_id'))
except (TypeError, ValueError):
return web.json_response({
'success': False,
'error': "Invalid model_version_id: Must be an integer"
}, status=400)
# Use the correct root directory based on model type
root_key = 'checkpoint_root' if model_type == 'checkpoint' else 'lora_root'
save_dir = data.get(root_key)
# At least one identifier is required
if not model_id and not model_version_id:
return web.json_response({
'success': False,
'error': "Missing required parameter: Please provide either 'model_id' or 'model_version_id'"
}, status=400)
use_default_paths = data.get('use_default_paths', False)
# Pass the download_id to download_from_civitai
result = await download_manager.download_from_civitai(
download_url=download_url,
model_hash=model_hash,
model_id=model_id,
model_version_id=model_version_id,
save_dir=save_dir,
save_dir=data.get('model_root'),
relative_path=data.get('relative_path', ''),
use_default_paths=use_default_paths,
progress_callback=progress_callback,
model_type=model_type
download_id=download_id # Pass download_id explicitly
)
# Include download_id in the response
result['download_id'] = download_id
if not result.get('success', False):
error_message = result.get('error', 'Unknown error')
# Return 401 for early access errors
if 'early access' in error_message.lower():
logger.warning(f"Early access download failed: {error_message}")
return web.Response(
status=401, # Use 401 status code to match Civitai's response
text=f"Early Access Restriction: {error_message}"
)
return web.Response(status=500, text=error_message)
return web.json_response({
'success': False,
'error': error_message,
'download_id': download_id
}, status=500)
return web.json_response(result)
@@ -554,10 +642,453 @@ class ModelRouteUtils:
# Check if this might be an early access error
if '401' in error_message:
logger.warning(f"Early access error (401): {error_message}")
return web.Response(
status=401,
text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
)
return web.json_response({
'success': False,
'error': "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
}, status=401)
logger.error(f"Error downloading {model_type}: {error_message}")
return web.Response(status=500, text=error_message)
logger.error(f"Error downloading model: {error_message}")
return web.json_response({
'success': False,
'error': error_message
}, status=500)
@staticmethod
async def handle_cancel_download(request: web.Request) -> web.Response:
"""Handle cancellation of a download task
Args:
request: The aiohttp request
Returns:
web.Response: The HTTP response
"""
try:
download_manager = await ServiceRegistry.get_download_manager()
download_id = request.match_info.get('download_id')
if not download_id:
return web.json_response({
'success': False,
'error': 'Download ID is required'
}, status=400)
result = await download_manager.cancel_download(download_id)
# Notify clients about cancellation via WebSocket
await ws_manager.broadcast_download_progress(download_id, {
'status': 'cancelled',
'progress': 0,
'download_id': download_id,
'message': 'Download cancelled by user'
})
return web.json_response(result)
except Exception as e:
logger.error(f"Error cancelling download: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def handle_list_downloads(request: web.Request) -> web.Response:
"""Get list of active downloads
Args:
request: The aiohttp request
Returns:
web.Response: The HTTP response with list of downloads
"""
try:
download_manager = await ServiceRegistry.get_download_manager()
result = await download_manager.get_active_downloads()
return web.json_response(result)
except Exception as e:
logger.error(f"Error listing downloads: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def handle_bulk_delete_models(request: web.Request, scanner) -> web.Response:
"""Handle bulk deletion of models
Args:
request: The aiohttp request
scanner: The model scanner instance with cache management methods
Returns:
web.Response: The HTTP response
"""
try:
data = await request.json()
file_paths = data.get('file_paths', [])
if not file_paths:
return web.json_response({
'success': False,
'error': 'No file paths provided for deletion'
}, status=400)
# Use the scanner's bulk delete method to handle all cache and file operations
result = await scanner.bulk_delete_models(file_paths)
return web.json_response({
'success': result.get('success', False),
'total_deleted': result.get('total_deleted', 0),
'total_attempted': result.get('total_attempted', len(file_paths)),
'results': result.get('results', [])
})
except Exception as e:
logger.error(f"Error in bulk delete: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def handle_relink_civitai(request: web.Request, scanner) -> web.Response:
"""Handle CivitAI metadata re-linking request by model ID and/or version ID
Args:
request: The aiohttp request
scanner: The model scanner instance with cache management methods
Returns:
web.Response: The HTTP response
"""
try:
data = await request.json()
file_path = data.get('file_path')
model_id = int(data.get('model_id'))
model_version_id = None
if data.get('model_version_id'):
model_version_id = int(data.get('model_version_id'))
if not file_path or not model_id:
return web.json_response({"success": False, "error": "Both file_path and model_id are required"}, status=400)
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Check if model metadata exists
local_metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Create a client for fetching from Civitai
client = await CivitaiClient.get_instance()
try:
# Fetch metadata using get_model_version which includes more comprehensive data
civitai_metadata = await client.get_model_version(model_id, model_version_id)
if not civitai_metadata:
error_msg = f"Model version not found on CivitAI for ID: {model_id}"
if model_version_id:
error_msg += f" with version: {model_version_id}"
return web.json_response({"success": False, "error": error_msg}, status=404)
# Try to find the primary model file to get the SHA256 hash
primary_model_file = None
for file in civitai_metadata.get('files', []):
if file.get('primary', False) and file.get('type') == 'Model':
primary_model_file = file
break
# Update the SHA256 hash in local metadata if available
if primary_model_file and primary_model_file.get('hashes', {}).get('SHA256'):
local_metadata['sha256'] = primary_model_file['hashes']['SHA256'].lower()
# Update metadata with CivitAI information
await ModelRouteUtils.update_model_metadata(metadata_path, local_metadata, civitai_metadata, client)
# Update the cache
await scanner.update_single_model_cache(file_path, file_path, local_metadata)
return web.json_response({
"success": True,
"message": f"Model successfully re-linked to Civitai model {model_id}" +
(f" version {model_version_id}" if model_version_id else ""),
"hash": local_metadata.get('sha256', '')
})
finally:
await client.close()
except Exception as e:
logger.error(f"Error re-linking to CivitAI: {e}", exc_info=True)
return web.json_response({"success": False, "error": str(e)}, status=500)
@staticmethod
async def handle_verify_duplicates(request: web.Request, scanner) -> web.Response:
"""Handle verification of duplicate model hashes
Args:
request: The aiohttp request
scanner: The model scanner instance with cache management methods
Returns:
web.Response: The HTTP response with verification results
"""
try:
data = await request.json()
file_paths = data.get('file_paths', [])
if not file_paths:
return web.json_response({
'success': False,
'error': 'No file paths provided for verification'
}, status=400)
# Results tracking
results = {
'verified_as_duplicates': True, # Start true, set to false if any mismatch
'mismatched_files': [],
'new_hash_map': {}
}
# Get expected hash from the first file's metadata
expected_hash = None
first_metadata_path = os.path.splitext(file_paths[0])[0] + '.metadata.json'
first_metadata = await ModelRouteUtils.load_local_metadata(first_metadata_path)
if first_metadata and 'sha256' in first_metadata:
expected_hash = first_metadata['sha256'].lower()
# Process each file
for file_path in file_paths:
# Skip files that don't exist
if not os.path.exists(file_path):
continue
# Calculate actual hash
try:
from .file_utils import calculate_sha256
actual_hash = await calculate_sha256(file_path)
# Get metadata
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Compare hashes
stored_hash = metadata.get('sha256', '').lower()
# Set expected hash from first file if not yet set
if not expected_hash:
expected_hash = stored_hash
# Check if hash matches expected hash
if actual_hash != expected_hash:
results['verified_as_duplicates'] = False
results['mismatched_files'].append(file_path)
results['new_hash_map'][file_path] = actual_hash
# Check if stored hash needs updating
if actual_hash != stored_hash:
# Update metadata with actual hash
metadata['sha256'] = actual_hash
# Save updated metadata
await MetadataManager.save_metadata(file_path, metadata)
# Update cache
await scanner.update_single_model_cache(file_path, file_path, metadata)
except Exception as e:
logger.error(f"Error verifying hash for {file_path}: {e}")
results['mismatched_files'].append(file_path)
results['new_hash_map'][file_path] = "error_calculating_hash"
results['verified_as_duplicates'] = False
return web.json_response({
'success': True,
**results
})
except Exception as e:
logger.error(f"Error verifying duplicate models: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def handle_rename_model(request: web.Request, scanner) -> web.Response:
"""Handle renaming a model file and its associated files
Args:
request: The aiohttp request
scanner: The model scanner instance
Returns:
web.Response: The HTTP response
"""
try:
data = await request.json()
file_path = data.get('file_path')
new_file_name = data.get('new_file_name')
if not file_path or not new_file_name:
return web.json_response({
'success': False,
'error': 'File path and new file name are required'
}, status=400)
# Validate the new file name (no path separators or invalid characters)
invalid_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|']
if any(char in new_file_name for char in invalid_chars):
return web.json_response({
'success': False,
'error': 'Invalid characters in file name'
}, status=400)
# Get the directory and current file name
target_dir = os.path.dirname(file_path)
old_file_name = os.path.splitext(os.path.basename(file_path))[0]
# Check if the target file already exists
new_file_path = os.path.join(target_dir, f"{new_file_name}.safetensors").replace(os.sep, '/')
if os.path.exists(new_file_path):
return web.json_response({
'success': False,
'error': 'A file with this name already exists'
}, status=400)
# Define the patterns for associated files
patterns = [
f"{old_file_name}.safetensors", # Required
f"{old_file_name}.metadata.json",
f"{old_file_name}.metadata.json.bak",
]
# Add all preview file extensions
for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{old_file_name}{ext}")
# Find all matching files
existing_files = []
for pattern in patterns:
path = os.path.join(target_dir, pattern)
if os.path.exists(path):
existing_files.append((path, pattern))
# Get the hash from the main file to update hash index
hash_value = None
metadata = None
metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
if os.path.exists(metadata_path):
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
hash_value = metadata.get('sha256')
# Rename all files
renamed_files = []
new_metadata_path = None
for old_path, pattern in existing_files:
# Get the file extension like .safetensors or .metadata.json
ext = ModelRouteUtils.get_multipart_ext(pattern)
# Create the new path
new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/')
# Rename the file
os.rename(old_path, new_path)
renamed_files.append(new_path)
# Keep track of metadata path for later update
if ext == '.metadata.json':
new_metadata_path = new_path
# Update the metadata file with new file name and paths
if new_metadata_path and metadata:
# Update file_name, file_path and preview_url in metadata
metadata['file_name'] = new_file_name
metadata['file_path'] = new_file_path
# Update preview_url if it exists
if 'preview_url' in metadata and metadata['preview_url']:
old_preview = metadata['preview_url']
ext = ModelRouteUtils.get_multipart_ext(old_preview)
new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/')
metadata['preview_url'] = new_preview
# Save updated metadata
await MetadataManager.save_metadata(new_file_path, metadata)
# Update the scanner cache
if metadata:
await scanner.update_single_model_cache(file_path, new_file_path, metadata)
# Update recipe files and cache if hash is available and recipe_scanner exists
if hash_value and hasattr(scanner, 'update_lora_filename_by_hash'):
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
if recipe_scanner:
recipes_updated, cache_updated = await recipe_scanner.update_lora_filename_by_hash(hash_value, new_file_name)
logger.info(f"Updated {recipes_updated} recipe files and {cache_updated} cache entries for renamed model")
return web.json_response({
'success': True,
'new_file_path': new_file_path,
'new_preview_path': config.get_preview_static_url(new_preview),
'renamed_files': renamed_files,
'reload_required': False
})
except Exception as e:
logger.error(f"Error renaming model: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def handle_save_metadata(request: web.Request, scanner) -> web.Response:
"""Handle saving metadata updates
Args:
request: The aiohttp request
scanner: The model scanner instance
Returns:
web.Response: The HTTP response
"""
try:
data = await request.json()
file_path = data.get('file_path')
if not file_path:
return web.Response(text='File path is required', status=400)
# Remove file path from data to avoid saving it
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
# Get metadata file path
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
# Load existing metadata
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
# Handle nested updates (for civitai.trainedWords)
for key, value in metadata_updates.items():
if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict):
# Deep update for nested dictionaries
for nested_key, nested_value in value.items():
metadata[key][nested_key] = nested_value
else:
# Regular update for top-level keys
metadata[key] = value
# Save updated metadata
await MetadataManager.save_metadata(file_path, metadata)
# Update cache
await scanner.update_single_model_cache(file_path, file_path, metadata)
# If model_name was updated, resort the cache
if 'model_name' in metadata_updates:
cache = await scanner.get_cached_data()
await cache.resort()
return web.json_response({'success': True})
except Exception as e:
logger.error(f"Error saving metadata: {e}", exc_info=True)
return web.Response(text=str(e), status=500)

View File

@@ -4,6 +4,8 @@ import sys
import time
import asyncio
import logging
import datetime
import shutil
from typing import Dict, Set
from ..config import config
@@ -26,6 +28,7 @@ class UsageStats:
# Default stats file name
STATS_FILENAME = "lora_manager_stats.json"
BACKUP_SUFFIX = ".backup"
def __new__(cls):
if cls._instance is None:
@@ -39,8 +42,8 @@ class UsageStats:
# Initialize stats storage
self.stats = {
"checkpoints": {}, # sha256 -> count
"loras": {}, # sha256 -> count
"checkpoints": {}, # sha256 -> { total: count, history: { date: count } }
"loras": {}, # sha256 -> { total: count, history: { date: count } }
"total_executions": 0,
"last_save_time": 0
}
@@ -70,6 +73,68 @@ class UsageStats:
# Use the first lora root
return os.path.join(config.loras_roots[0], self.STATS_FILENAME)
def _backup_old_stats(self):
"""Backup the old stats file before conversion"""
if os.path.exists(self._stats_file_path):
backup_path = f"{self._stats_file_path}{self.BACKUP_SUFFIX}"
try:
shutil.copy2(self._stats_file_path, backup_path)
logger.info(f"Backed up old stats file to {backup_path}")
return True
except Exception as e:
logger.error(f"Failed to backup stats file: {e}")
return False
def _convert_old_format(self, old_stats):
"""Convert old stats format to new format with history"""
new_stats = {
"checkpoints": {},
"loras": {},
"total_executions": old_stats.get("total_executions", 0),
"last_save_time": old_stats.get("last_save_time", time.time())
}
# Get today's date in YYYY-MM-DD format
today = datetime.datetime.now().strftime("%Y-%m-%d")
# Convert checkpoint stats
if "checkpoints" in old_stats and isinstance(old_stats["checkpoints"], dict):
for hash_id, count in old_stats["checkpoints"].items():
new_stats["checkpoints"][hash_id] = {
"total": count,
"history": {
today: count
}
}
# Convert lora stats
if "loras" in old_stats and isinstance(old_stats["loras"], dict):
for hash_id, count in old_stats["loras"].items():
new_stats["loras"][hash_id] = {
"total": count,
"history": {
today: count
}
}
logger.info("Successfully converted stats from old format to new format with history")
return new_stats
def _is_old_format(self, stats):
"""Check if the stats are in the old format (direct count values)"""
# Check if any lora or checkpoint entry is a direct number instead of an object
if "loras" in stats and isinstance(stats["loras"], dict):
for hash_id, data in stats["loras"].items():
if isinstance(data, (int, float)):
return True
if "checkpoints" in stats and isinstance(stats["checkpoints"], dict):
for hash_id, data in stats["checkpoints"].items():
if isinstance(data, (int, float)):
return True
return False
def _load_stats(self):
"""Load existing statistics from file"""
try:
@@ -77,18 +142,27 @@ class UsageStats:
with open(self._stats_file_path, 'r', encoding='utf-8') as f:
loaded_stats = json.load(f)
# Update our stats with loaded data
if isinstance(loaded_stats, dict):
# Update individual sections to maintain structure
if "checkpoints" in loaded_stats and isinstance(loaded_stats["checkpoints"], dict):
self.stats["checkpoints"] = loaded_stats["checkpoints"]
if "loras" in loaded_stats and isinstance(loaded_stats["loras"], dict):
self.stats["loras"] = loaded_stats["loras"]
if "total_executions" in loaded_stats:
self.stats["total_executions"] = loaded_stats["total_executions"]
# Check if old format and needs conversion
if self._is_old_format(loaded_stats):
logger.info("Detected old stats format, performing conversion")
self._backup_old_stats()
self.stats = self._convert_old_format(loaded_stats)
else:
# Update our stats with loaded data (already in new format)
if isinstance(loaded_stats, dict):
# Update individual sections to maintain structure
if "checkpoints" in loaded_stats and isinstance(loaded_stats["checkpoints"], dict):
self.stats["checkpoints"] = loaded_stats["checkpoints"]
if "loras" in loaded_stats and isinstance(loaded_stats["loras"], dict):
self.stats["loras"] = loaded_stats["loras"]
if "total_executions" in loaded_stats:
self.stats["total_executions"] = loaded_stats["total_executions"]
if "last_save_time" in loaded_stats:
self.stats["last_save_time"] = loaded_stats["last_save_time"]
logger.info(f"Loaded usage statistics from {self._stats_file_path}")
except Exception as e:
logger.error(f"Error loading usage statistics: {e}")
@@ -174,15 +248,18 @@ class UsageStats:
# Increment total executions count
self.stats["total_executions"] += 1
# Get today's date in YYYY-MM-DD format
today = datetime.datetime.now().strftime("%Y-%m-%d")
# Process checkpoints
if MODELS in metadata and isinstance(metadata[MODELS], dict):
await self._process_checkpoints(metadata[MODELS])
await self._process_checkpoints(metadata[MODELS], today)
# Process loras
if LORAS in metadata and isinstance(metadata[LORAS], dict):
await self._process_loras(metadata[LORAS])
await self._process_loras(metadata[LORAS], today)
async def _process_checkpoints(self, models_data):
async def _process_checkpoints(self, models_data, today_date):
"""Process checkpoint models from metadata"""
try:
# Get checkpoint scanner service
@@ -208,12 +285,24 @@ class UsageStats:
# Get hash for this checkpoint
model_hash = checkpoint_scanner.get_hash_by_filename(model_filename)
if model_hash:
# Update stats for this checkpoint
self.stats["checkpoints"][model_hash] = self.stats["checkpoints"].get(model_hash, 0) + 1
# Update stats for this checkpoint with date tracking
if model_hash not in self.stats["checkpoints"]:
self.stats["checkpoints"][model_hash] = {
"total": 0,
"history": {}
}
# Increment total count
self.stats["checkpoints"][model_hash]["total"] += 1
# Increment today's count
if today_date not in self.stats["checkpoints"][model_hash]["history"]:
self.stats["checkpoints"][model_hash]["history"][today_date] = 0
self.stats["checkpoints"][model_hash]["history"][today_date] += 1
except Exception as e:
logger.error(f"Error processing checkpoint usage: {e}", exc_info=True)
async def _process_loras(self, loras_data):
async def _process_loras(self, loras_data, today_date):
"""Process LoRA models from metadata"""
try:
# Get LoRA scanner service
@@ -239,8 +328,20 @@ class UsageStats:
# Get hash for this LoRA
lora_hash = lora_scanner.get_hash_by_filename(lora_name)
if lora_hash:
# Update stats for this LoRA
self.stats["loras"][lora_hash] = self.stats["loras"].get(lora_hash, 0) + 1
# Update stats for this LoRA with date tracking
if lora_hash not in self.stats["loras"]:
self.stats["loras"][lora_hash] = {
"total": 0,
"history": {}
}
# Increment total count
self.stats["loras"][lora_hash]["total"] += 1
# Increment today's count
if today_date not in self.stats["loras"][lora_hash]["history"]:
self.stats["loras"][lora_hash]["history"][today_date] = 0
self.stats["loras"][lora_hash]["history"][today_date] += 1
except Exception as e:
logger.error(f"Error processing LoRA usage: {e}", exc_info=True)
@@ -251,9 +352,11 @@ class UsageStats:
async def get_model_usage_count(self, model_type, sha256):
"""Get usage count for a specific model by hash"""
if model_type == "checkpoint":
return self.stats["checkpoints"].get(sha256, 0)
if sha256 in self.stats["checkpoints"]:
return self.stats["checkpoints"][sha256]["total"]
elif model_type == "lora":
return self.stats["loras"].get(sha256, 0)
if sha256 in self.stats["loras"]:
return self.stats["loras"][sha256]["total"]
return 0
async def process_execution(self, prompt_id):

View File

@@ -1,85 +1,56 @@
from difflib import SequenceMatcher
import requests
import tempfile
import re
from bs4 import BeautifulSoup
import os
from typing import Dict
from ..services.service_registry import ServiceRegistry
from ..config import config
from ..services.settings_manager import settings
from .constants import CIVITAI_MODEL_TAGS
import asyncio
def download_twitter_image(url):
"""Download image from a URL containing twitter:image meta tag
def get_lora_info(lora_name):
"""Get the lora path and trigger words from cache"""
async def _get_lora_info_async():
scanner = await ServiceRegistry.get_lora_scanner()
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
file_path = item.get('file_path')
if file_path:
for root in config.loras_roots:
root = root.replace(os.sep, '/')
if file_path.startswith(root):
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
# Get trigger words from civitai metadata
civitai = item.get('civitai', {})
trigger_words = civitai.get('trainedWords', []) if civitai else []
return relative_path, trigger_words
return lora_name, []
Args:
url (str): The URL to download image from
Returns:
str: Path to downloaded temporary image file
"""
try:
# Download page content
response = requests.get(url)
response.raise_for_status()
# Check if we're already in an event loop
loop = asyncio.get_running_loop()
# If we're in a running loop, we need to use a different approach
# Create a new thread to run the async code
import concurrent.futures
# Parse HTML
soup = BeautifulSoup(response.text, 'html.parser')
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(_get_lora_info_async())
finally:
new_loop.close()
# Find twitter:image meta tag
meta_tag = soup.find('meta', attrs={'property': 'twitter:image'})
if not meta_tag:
return None
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()
image_url = meta_tag['content']
# Download image
image_response = requests.get(image_url)
image_response.raise_for_status()
# Save to temp file
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
temp_file.write(image_response.content)
return temp_file.name
except Exception as e:
print(f"Error downloading twitter image: {e}")
return None
except RuntimeError:
# No event loop is running, we can use asyncio.run()
return asyncio.run(_get_lora_info_async())
def download_civitai_image(url):
"""Download image from a URL containing avatar image with specific class and style attributes
Args:
url (str): The URL to download image from
Returns:
str: Path to downloaded temporary image file
"""
try:
# Download page content
response = requests.get(url)
response.raise_for_status()
# Parse HTML
soup = BeautifulSoup(response.text, 'html.parser')
# Find image with specific class and style attributes
image = soup.select_one('img.EdgeImage_image__iH4_q.max-h-full.w-auto.max-w-full')
if not image or 'src' not in image.attrs:
return None
image_url = image['src']
# Download image
image_response = requests.get(image_url)
image_response.raise_for_status()
# Save to temp file
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
temp_file.write(image_response.content)
return temp_file.name
except Exception as e:
print(f"Error downloading civitai avatar: {e}")
return None
def fuzzy_match(text: str, pattern: str, threshold: float = 0.7) -> bool:
def fuzzy_match(text: str, pattern: str, threshold: float = 0.85) -> bool:
"""
Check if text matches pattern using fuzzy matching.
Returns True if similarity ratio is above threshold.
@@ -142,7 +113,7 @@ def calculate_recipe_fingerprint(loras):
# Get the hash - use modelVersionId as fallback if hash is empty
hash_value = lora.get("hash", "").lower()
if not hash_value and lora.get("isDeleted", False) and lora.get("modelVersionId"):
hash_value = lora.get("modelVersionId")
hash_value = str(lora.get("modelVersionId"))
# Skip entries without a valid hash
if not hash_value:
@@ -160,3 +131,95 @@ def calculate_recipe_fingerprint(loras):
fingerprint = "|".join([f"{hash_value}:{strength}" for hash_value, strength in valid_loras])
return fingerprint
def calculate_relative_path_for_model(model_data: Dict, model_type: str = 'lora') -> str:
"""Calculate relative path for existing model using template from settings
Args:
model_data: Model data from scanner cache
model_type: Type of model ('lora', 'checkpoint', 'embedding')
Returns:
Relative path string (empty string for flat structure)
"""
# Get path template from settings for specific model type
path_template = settings.get_download_path_template(model_type)
# If template is empty, return empty path (flat structure)
if not path_template:
return ''
# Get base model name from model metadata
civitai_data = model_data.get('civitai', {})
# For CivitAI models, prefer civitai data only if 'id' exists; for non-CivitAI models, use model_data directly
if civitai_data and civitai_data.get('id') is not None:
base_model = civitai_data.get('baseModel', '')
# Get author from civitai creator data
creator_info = civitai_data.get('creator') or {}
author = creator_info.get('username') or 'Anonymous'
else:
# Fallback to model_data fields for non-CivitAI models
base_model = model_data.get('base_model', '')
author = 'Anonymous' # Default for non-CivitAI models
model_tags = model_data.get('tags', [])
# Apply mapping if available
base_model_mappings = settings.get('base_model_path_mappings', {})
mapped_base_model = base_model_mappings.get(base_model, base_model)
# Find the first Civitai model tag that exists in model_tags
first_tag = ''
for civitai_tag in CIVITAI_MODEL_TAGS:
if civitai_tag in model_tags:
first_tag = civitai_tag
break
# If no Civitai model tag found, fallback to first tag
if not first_tag and model_tags:
first_tag = model_tags[0]
if not first_tag:
first_tag = 'no tags' # Default if no tags available
# Format the template with available data
formatted_path = path_template
formatted_path = formatted_path.replace('{base_model}', mapped_base_model)
formatted_path = formatted_path.replace('{first_tag}', first_tag)
formatted_path = formatted_path.replace('{author}', author)
return formatted_path
def remove_empty_dirs(path):
"""Recursively remove empty directories starting from the given path.
Args:
path (str): Root directory to start cleaning from
Returns:
int: Number of empty directories removed
"""
removed_count = 0
if not os.path.isdir(path):
return removed_count
# List all files in directory
files = os.listdir(path)
# Process all subdirectories first
for file in files:
full_path = os.path.join(path, file)
if os.path.isdir(full_path):
removed_count += remove_empty_dirs(full_path)
# Check if directory is now empty (after processing subdirectories)
if not os.listdir(path):
try:
os.rmdir(path)
removed_count += 1
except OSError:
pass
return removed_count

View File

@@ -1,20 +1,18 @@
[project]
name = "comfyui-lora-manager"
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration."
version = "0.8.13"
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
version = "0.8.29"
license = {file = "LICENSE"}
dependencies = [
"aiohttp",
"jinja2",
"safetensors",
"watchdog",
"beautifulsoup4",
"piexif",
"Pillow",
"olefile", # for getting rid of warning message
"requests",
"toml",
"natsort"
"natsort",
"GitPython"
]
[project.urls]
@@ -24,4 +22,4 @@ Repository = "https://github.com/willmiao/ComfyUI-Lora-Manager"
[tool.comfy]
PublisherId = "willmiao"
DisplayName = "ComfyUI-Lora-Manager"
Icon = ""
Icon = "https://github.com/willmiao/ComfyUI-Lora-Manager/blob/main/static/images/android-chrome-512x512.png?raw=true"

View File

@@ -1,11 +1,258 @@
{
"loras": "<lora:ck-neon-retrowave-IL-000012:0.8> <lora:aorunIllstrious:1> <lora:ck-shadow-circuit-IL-000012:0.78> <lora:MoriiMee_Gothic_Niji_Style_Illustrious_r1:0.45> <lora:ck-nc-cyberpunk-IL-000011:0.4>",
"prompt": "in the style of ck-rw, aorun, scales, makeup, bare shoulders, pointy ears, dress, claws, in the style of cksc, artist:moriimee, in the style of cknc, masterpiece, best quality, good quality, very aesthetic, absurdres, newest, 8K, depth of field, focused subject, close up, stylized, in gold and neon shades, wabi sabi, 1girl, rainbow angel wings, looking at viewer, dynamic angle, from below, from side, relaxing",
"negative_prompt": "bad quality, worst quality, worst detail, sketch ,signature, watermark, patreon logo, nsfw",
"steps": "20",
"sampler": "euler_ancestral",
"cfg_scale": "8",
"seed": "241",
"size": "832x1216",
"clip_skip": "2"
"id": 649516,
"name": "Cynthia -シロナ - Pokemon Diamond and Pearl - PDXL LORA",
"description": "<p><strong>Warning: Without Adetailer eyes are fucked (rainbow color and artefact)</strong></p><p><span style=\"color:rgb(193, 194, 197)\">Trained on </span><a target=\"_blank\" rel=\"ugc\" href=\"https://civitai.com/models/257749/horsefucker-diffusion-v6-xl\"><strong>Pony Diffusion V6 XL</strong></a> with 63 pictures.<br />Best result with weight between : 0.8-1.</p><p><span style=\"color:rgb(193, 194, 197)\">Basic prompts : </span><code>1girl, cynthia \\(pokemon\\), blonde hair, hair over one eye, very long hair, grey eyes, eyelashes, hair ornament</code> <br /><span style=\"color:rgb(193, 194, 197)\">Outfit prompts : </span><code>fur collar, black coat, fur-trimmed coat, long sleeves, black pants, black shirt, high heels</code></p><p>Reviews are really appreciated, i love to see the community use my work, that's why I share it.<br />If you like my work, you can tip me <a target=\"_blank\" rel=\"ugc\" href=\"https://ko-fi.com/konan49773\"><strong>here.</strong></a></p><p>Got a specific request ? I'm open for commission on my <a target=\"_blank\" rel=\"ugc\" href=\"https://ko-fi.com/konan49773/commissions\"><strong>kofi</strong></a> or<strong> </strong><a target=\"_blank\" rel=\"ugc\" href=\"https://www.fiverr.com/konanai/create-lora-model-for-you\"><strong>fiverr gig</strong></a> *! If you provide enough data, OCs are accepted</p>",
"allowNoCredit": true,
"allowCommercialUse": [
"Image",
"RentCivit"
],
"allowDerivatives": true,
"allowDifferentLicense": true,
"type": "LORA",
"minor": false,
"sfwOnly": false,
"poi": false,
"nsfw": false,
"nsfwLevel": 29,
"availability": "Public",
"cosmetic": null,
"supportsGeneration": true,
"stats": {
"downloadCount": 811,
"favoriteCount": 0,
"thumbsUpCount": 175,
"thumbsDownCount": 0,
"commentCount": 4,
"ratingCount": 0,
"rating": 0,
"tippedAmountCount": 10
},
"creator": {
"username": "Konan",
"image": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/7cd552a1-60fe-4baf-a0e4-f7d5d5381711/width=96/Konan.jpeg"
},
"tags": [
"anime",
"character",
"cynthia",
"woman",
"pokemon",
"pokegirl"
],
"modelVersions": [
{
"id": 726676,
"index": 0,
"name": "v1.0",
"baseModel": "Pony",
"createdAt": "2024-08-16T01:13:16.099Z",
"publishedAt": "2024-08-16T01:14:44.984Z",
"status": "Published",
"availability": "Public",
"nsfwLevel": 29,
"trainedWords": [
"1girl, cynthia \\(pokemon\\), blonde hair, hair over one eye, very long hair, grey eyes, eyelashes, hair ornament",
"fur collar, black coat, fur-trimmed coat, long sleeves, black pants, black shirt, high heels"
],
"covered": true,
"stats": {
"downloadCount": 811,
"ratingCount": 0,
"rating": 0,
"thumbsUpCount": 175,
"thumbsDownCount": 0
},
"files": [
{
"id": 641092,
"sizeKB": 56079.65234375,
"name": "CynthiaXL.safetensors",
"type": "Model",
"pickleScanResult": "Success",
"pickleScanMessage": "No Pickle imports",
"virusScanResult": "Success",
"virusScanMessage": null,
"scannedAt": "2024-08-16T01:17:19.087Z",
"metadata": {
"format": "SafeTensor"
},
"hashes": {},
"downloadUrl": "https://civitai.com/api/download/models/726676",
"primary": true
}
],
"images": [
{
"url": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/b346d757-2b59-4aeb-9f09-3bee2724519d/width=1248/24511993.jpeg",
"nsfwLevel": 1,
"width": 1248,
"height": 1824,
"hash": "UqNc==RP.9s+~pxvIst7kWWBWBjY%MWBt7WB",
"type": "image",
"minor": false,
"poi": false,
"hasMeta": true,
"hasPositivePrompt": true,
"onSite": false,
"remixOfId": null
},
{
"url": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/fc132ac0-cc1c-4b68-a1d7-5b97b0996ac2/width=1248/24511997.jpeg",
"nsfwLevel": 1,
"width": 1248,
"height": 1824,
"hash": "UMGSS+?tTw.60MIX9cbb~WxHRRR-NEtLRiR%",
"type": "image",
"minor": false,
"poi": false,
"hasMeta": true,
"hasPositivePrompt": true,
"onSite": false,
"remixOfId": null
},
{
"url": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/7b3237d1-e672-466a-85d0-cc5dd42ab130/width=1160/24512001.jpeg",
"nsfwLevel": 4,
"width": 1160,
"height": 1696,
"hash": "U9NA6f~o00%h00wvIYt74:ER-=D%5600DiE1",
"type": "image",
"minor": false,
"poi": false,
"hasMeta": true,
"hasPositivePrompt": true,
"onSite": false,
"remixOfId": null
},
{
"url": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/ccd7d11d-4fa9-4434-85a1-fb999312e60d/width=1248/24511991.jpeg",
"nsfwLevel": 1,
"width": 1248,
"height": 1824,
"hash": "UyNTg.j?~qxu?aoLRkj]%MfkM{jZaya}a#ax",
"type": "image",
"minor": false,
"poi": false,
"hasMeta": true,
"hasPositivePrompt": true,
"onSite": false,
"remixOfId": null
},
{
"url": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/1743be6d-7fe5-4b55-9f19-c931618fa259/width=1248/24511996.jpeg",
"nsfwLevel": 4,
"width": 1248,
"height": 1824,
"hash": "UGOC~n^+?w~6Tx_4oM^$yYEkMds74:9F#*xY",
"type": "image",
"minor": false,
"poi": false,
"hasMeta": true,
"hasPositivePrompt": true,
"onSite": false,
"remixOfId": null
},
{
"url": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/91693c98-d037-4489-882c-100eb26019a0/width=1160/24512010.jpeg",
"nsfwLevel": 4,
"width": 1160,
"height": 1696,
"hash": "UJI}kp^-Kl%hXAIX4;Nf^+M|9GRP0Mt8%L%2",
"type": "image",
"minor": false,
"poi": false,
"hasMeta": true,
"hasPositivePrompt": true,
"onSite": false,
"remixOfId": null
},
{
"url": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/49c7a294-ac5b-4832-98e5-2acd0f1a8782/width=1248/24512017.jpeg",
"nsfwLevel": 4,
"width": 1248,
"height": 1824,
"hash": "UML;8Qn|9G%3mnWA4nWFMf%N?Hae~qog-oNF",
"type": "image",
"minor": false,
"poi": false,
"hasMeta": true,
"hasPositivePrompt": true,
"onSite": false,
"remixOfId": null
},
{
"url": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/d7b442f2-6ead-4a7a-9578-54d9ec2ff148/width=1248/24512015.jpeg",
"nsfwLevel": 1,
"width": 1248,
"height": 1824,
"hash": "UPGR#kt8xw%M0LWC9bWC?wxtR*NLM^jrxWM|",
"type": "image",
"minor": false,
"poi": false,
"hasMeta": true,
"hasPositivePrompt": true,
"onSite": false,
"remixOfId": null
},
{
"url": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/d840f1e9-3dd3-4531-b83a-1ba2c6b7feaa/width=1160/24512004.jpeg",
"nsfwLevel": 8,
"width": 1160,
"height": 1696,
"hash": "ULNm1i_39wi^*I%hDiM_tlo#xuV?^kNIxCs,",
"type": "image",
"minor": false,
"poi": false,
"hasMeta": true,
"hasPositivePrompt": true,
"onSite": false,
"remixOfId": null
},
{
"url": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/520387ae-c176-43e3-92bd-5cd2a672475e/width=1248/24512012.jpeg",
"nsfwLevel": 4,
"width": 1248,
"height": 1824,
"hash": "URM%l.%M.9Ip~poIkExu_3V@M|xuD%oJM{D*",
"type": "image",
"minor": false,
"poi": false,
"hasMeta": true,
"hasPositivePrompt": true,
"onSite": false,
"remixOfId": null
},
{
"url": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/9ea28b94-f326-4776-83ff-851cc203c627/width=1248/24511988.jpeg",
"nsfwLevel": 1,
"width": 1248,
"height": 1824,
"hash": "U-PZloog_Nxut6j]WXWB-;j?IVa#ofaxj]j]",
"type": "image",
"minor": false,
"poi": false,
"hasMeta": true,
"hasPositivePrompt": true,
"onSite": false,
"remixOfId": null
},
{
"url": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/2e749dbb-7d5a-48f1-8e29-fea5022a5fe9/width=1248/24522268.jpeg",
"nsfwLevel": 16,
"width": 1248,
"height": 1824,
"hash": "UPLgtm9Z0z=|0yRRE2-A9rWAoNE1~DwOr=t7",
"type": "image",
"minor": false,
"poi": false,
"hasMeta": true,
"hasPositivePrompt": true,
"onSite": false,
"remixOfId": null
}
],
"downloadUrl": "https://civitai.com/api/download/models/726676"
}
]
}

View File

@@ -1,13 +1,10 @@
aiohttp
jinja2
safetensors
watchdog
beautifulsoup4
piexif
Pillow
olefile
requests
toml
numpy
torch
natsort
natsort
GitPython

View File

@@ -9,6 +9,10 @@
"checkpoints": [
"C:/path/to/your/checkpoints_folder",
"C:/path/to/another/checkpoints_folder"
],
"embeddings": [
"C:/path/to/your/embeddings_folder",
"C:/path/to/another/embeddings_folder"
]
}
}

View File

@@ -1,7 +1,28 @@
from pathlib import Path
import os
import sys
import json
# Create mock modules for py/nodes directory - add this before any other imports
def mock_nodes_directory():
"""Create mock modules for all Python files in the py/nodes directory"""
nodes_dir = os.path.join(os.path.dirname(__file__), 'py', 'nodes')
if os.path.exists(nodes_dir):
# Create a mock module for the nodes package itself
sys.modules['py.nodes'] = type('MockNodesModule', (), {})
# Create mock modules for all Python files in the nodes directory
for file in os.listdir(nodes_dir):
if file.endswith('.py') and file != '__init__.py':
module_name = file[:-3] # Remove .py extension
full_module_name = f'py.nodes.{module_name}'
# Create empty module object
sys.modules[full_module_name] = type(f'Mock{module_name.capitalize()}Module', (), {})
print(f"Created mock module for: {full_module_name}")
# Run the mocking function before any other imports
mock_nodes_directory()
# Create mock folder_paths module BEFORE any other imports
class MockFolderPaths:
@staticmethod
@@ -85,6 +106,22 @@ logger = logging.getLogger("lora-manager-standalone")
# Configure aiohttp access logger to be less verbose
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
# Add specific suppression for connection reset errors
class ConnectionResetFilter(logging.Filter):
def filter(self, record):
# Filter out connection reset errors that are not critical
if "ConnectionResetError" in str(record.getMessage()):
return False
if "_call_connection_lost" in str(record.getMessage()):
return False
if "WinError 10054" in str(record.getMessage()):
return False
return True
# Apply the filter to asyncio logger
asyncio_logger = logging.getLogger("asyncio")
asyncio_logger.addFilter(ConnectionResetFilter())
# Now we can import the global config from our local modules
from py.config import config
@@ -97,17 +134,6 @@ class StandaloneServer:
# Ensure the app's access logger is configured to reduce verbosity
self.app._subapps = [] # Ensure this exists to avoid AttributeError
# Configure access logging for the app
self.app.on_startup.append(self._configure_access_logger)
async def _configure_access_logger(self, app):
"""Configure access logger to reduce verbosity"""
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
# If using aiohttp>=3.8.0, configure access logger through app directly
if hasattr(app, 'access_logger'):
app.access_logger.setLevel(logging.WARNING)
async def setup(self):
"""Set up the standalone server"""
@@ -197,9 +223,6 @@ class StandaloneLoraManager(LoraManager):
# Store app in a global-like location for compatibility
sys.modules['server'].PromptServer.instance = server_instance
# Configure aiohttp access logger to be less verbose
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
added_targets = set() # Track already added target paths
@@ -231,7 +254,7 @@ class StandaloneLoraManager(LoraManager):
added_targets.add(os.path.normpath(real_root))
# Add static routes for each checkpoint root
for idx, root in enumerate(config.checkpoints_roots, start=1):
for idx, root in enumerate(config.base_models_roots, start=1):
if not os.path.exists(root):
logger.warning(f"Checkpoint root path does not exist: {root}")
continue
@@ -256,23 +279,50 @@ class StandaloneLoraManager(LoraManager):
# Record route mapping
config.add_route_mapping(real_root, preview_path)
added_targets.add(os.path.normpath(real_root))
# Add static routes for each embedding root
for idx, root in enumerate(getattr(config, "embeddings_roots", []), start=1):
if not os.path.exists(root):
logger.warning(f"Embedding root path does not exist: {root}")
continue
preview_path = f'/embeddings_static/root{idx}/preview'
real_root = root
for target, link in config._path_mappings.items():
if os.path.normpath(link) == os.path.normpath(root):
real_root = target
break
display_root = real_root.replace('\\', '/')
app.router.add_static(preview_path, real_root)
logger.info(f"Added static route {preview_path} -> {display_root}")
config.add_route_mapping(real_root, preview_path)
added_targets.add(os.path.normpath(real_root))
# Add static routes for symlink target paths that aren't already covered
link_idx = {
'lora': 1,
'checkpoint': 1
'checkpoint': 1,
'embedding': 1
}
for target_path, link_path in config._path_mappings.items():
norm_target = os.path.normpath(target_path)
if norm_target not in added_targets:
# Determine if this is a checkpoint or lora link based on path
is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.checkpoints_roots)
is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.checkpoints_roots)
# Determine if this is a checkpoint, lora, or embedding link based on path
is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.base_models_roots)
is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.base_models_roots)
is_embedding = any(os.path.normpath(emb_root) in os.path.normpath(link_path) for emb_root in getattr(config, "embeddings_roots", []))
is_embedding = is_embedding or any(os.path.normpath(emb_root) in norm_target for emb_root in getattr(config, "embeddings_roots", []))
if is_checkpoint:
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
link_idx["checkpoint"] += 1
elif is_embedding:
route_path = f'/embeddings_static/link_{link_idx["embedding"]}/preview'
link_idx["embedding"] += 1
else:
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
link_idx["lora"] += 1
@@ -280,39 +330,52 @@ class StandaloneLoraManager(LoraManager):
# Display path with forward slashes for consistency
display_target = target_path.replace('\\', '/')
app.router.add_static(route_path, target_path)
logger.info(f"Added static route for link target {route_path} -> {display_target}")
config.add_route_mapping(target_path, route_path)
added_targets.add(norm_target)
try:
app.router.add_static(route_path, Path(target_path).resolve(strict=False))
logger.info(f"Added static route for link target {route_path} -> {display_target}")
config.add_route_mapping(target_path, route_path)
added_targets.add(norm_target)
except Exception as e:
logger.warning(f"Failed to add static route on initialization for {target_path}: {e}")
continue
# Add static route for plugin assets
app.router.add_static('/loras_static', config.static_path)
# Setup feature routes
from py.routes.lora_routes import LoraRoutes
from py.routes.api_routes import ApiRoutes
from py.services.model_service_factory import ModelServiceFactory, register_default_model_types
from py.routes.recipe_routes import RecipeRoutes
from py.routes.checkpoints_routes import CheckpointsRoutes
from py.routes.update_routes import UpdateRoutes
from py.routes.misc_routes import MiscRoutes
from py.routes.example_images_routes import ExampleImagesRoutes
from py.routes.stats_routes import StatsRoutes
from py.services.websocket_manager import ws_manager
lora_routes = LoraRoutes()
checkpoints_routes = CheckpointsRoutes()
register_default_model_types()
# Setup all model routes using the factory
ModelServiceFactory.setup_all_routes(app)
stats_routes = StatsRoutes()
# Initialize routes
lora_routes.setup_routes(app)
checkpoints_routes.setup_routes(app)
ApiRoutes.setup_routes(app)
stats_routes.setup_routes(app)
RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app)
MiscRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app)
# Setup WebSocket routes that are shared across all model types
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
# Schedule service initialization
app.on_startup.append(lambda app: cls._initialize_services())
# Add cleanup
app.on_shutdown.append(cls._cleanup)
app.on_shutdown.append(ApiRoutes.cleanup)
def parse_args():
"""Parse command line arguments"""
@@ -337,9 +400,6 @@ async def main():
# Set log level
logging.getLogger().setLevel(getattr(logging, args.log_level))
# Explicitly configure aiohttp access logger regardless of selected log level
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
# Create the server instance
server = StandaloneServer()

View File

@@ -29,16 +29,29 @@ html, body {
:root {
--bg-color: #ffffff;
--text-color: #333333;
--text-muted: #6c757d;
--card-bg: #ffffff;
--border-color: #e0e0e0;
/* Color System */
--lora-accent: oklch(68% 0.28 256);
/* Color Components */
--lora-accent-l: 68%;
--lora-accent-c: 0.28;
--lora-accent-h: 256;
--lora-warning-l: 75%;
--lora-warning-c: 0.25;
--lora-warning-h: 80;
--lora-success-l: 70%;
--lora-success-c: 0.2;
--lora-success-h: 140;
/* Composed Colors */
--lora-accent: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h));
--lora-surface: oklch(100% 0 0 / 0.98);
--lora-border: oklch(90% 0.02 256 / 0.15);
--lora-text: oklch(95% 0.02 256);
--lora-error: oklch(75% 0.32 29);
--lora-warning: oklch(75% 0.25 80); /* Modified to be used with oklch() */
--lora-warning: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
--lora-success: oklch(var(--lora-success-l) var(--lora-success-c) var(--lora-success-h));
/* Spacing Scale */
--space-1: calc(8px * 1);
@@ -57,6 +70,11 @@ html, body {
--border-radius-xs: 4px;
--scrollbar-width: 8px; /* 添加滚动条宽度变量 */
/* Shortcut styles */
--shortcut-bg: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.12);
--shortcut-border: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.25);
--shortcut-text: var(--text-color);
}
html[data-theme="dark"] {
@@ -72,6 +90,7 @@ html[data-theme="light"] {
[data-theme="dark"] {
--bg-color: #1a1a1a;
--text-color: #e0e0e0;
--text-muted: #a0a0a0;
--card-bg: #2d2d2d;
--border-color: #404040;

View File

@@ -0,0 +1,245 @@
/* Banner Container */
.banner-container {
position: relative;
width: 100%;
z-index: calc(var(--z-header) - 1);
border-bottom: 1px solid var(--border-color);
background: var(--card-bg);
margin-bottom: var(--space-2);
}
/* Individual Banner */
.banner-item {
position: relative;
padding: var(--space-2) var(--space-3);
background: linear-gradient(135deg,
oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.05),
oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.02)
);
border-left: 4px solid var(--lora-accent);
animation: banner-slide-down 0.3s ease-in-out;
}
/* Banner Content Layout */
.banner-content {
display: flex;
align-items: center;
justify-content: space-between;
gap: var(--space-3);
max-width: 1400px;
margin: 0 auto;
}
/* Banner Text Section */
.banner-text {
flex: 1;
min-width: 0;
}
.banner-title {
margin: 0 0 4px 0;
font-size: 1.1em;
font-weight: 600;
color: var(--text-color);
line-height: 1.3;
}
.banner-description {
margin: 0;
font-size: 0.9em;
color: var(--text-muted);
line-height: 1.4;
}
/* Banner Actions */
.banner-actions {
display: flex;
align-items: center;
gap: var(--space-1);
flex-shrink: 0;
}
.banner-action {
display: inline-flex;
align-items: center;
gap: 6px;
padding: 6px 12px;
border-radius: var(--border-radius-xs);
text-decoration: none;
font-size: 0.85em;
font-weight: 500;
transition: all 0.2s ease;
white-space: nowrap;
border: 1px solid transparent;
}
.banner-action i {
font-size: 0.9em;
}
/* Primary Action Button */
.banner-action-primary {
background: var(--lora-accent);
color: white;
border-color: var(--lora-accent);
}
.banner-action-primary:hover {
background: oklch(calc(var(--lora-accent-l) - 5%) var(--lora-accent-c) var(--lora-accent-h));
transform: translateY(-1px);
box-shadow: 0 3px 6px oklch(var(--lora-accent) / 0.3);
}
/* Secondary Action Button */
.banner-action-secondary {
background: var(--card-bg);
color: var(--text-color);
border-color: var(--border-color);
}
.banner-action-secondary:hover {
background: var(--lora-accent);
color: white;
border-color: var(--lora-accent);
transform: translateY(-1px);
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
/* Tertiary Action Button */
.banner-action-tertiary {
background: transparent;
color: var(--lora-accent);
border-color: var(--lora-accent);
}
.banner-action-tertiary:hover {
background: var(--lora-accent);
color: white;
transform: translateY(-1px);
}
/* Dismiss Button */
.banner-dismiss {
position: absolute;
top: 8px;
right: 8px;
width: 24px;
height: 24px;
border: none;
background: transparent;
color: var(--text-muted);
cursor: pointer;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
transition: all 0.2s ease;
font-size: 0.8em;
}
.banner-dismiss:hover {
background: oklch(var(--lora-accent) / 0.1);
color: var(--lora-accent);
transform: scale(1.1);
}
/* Animations */
@keyframes banner-slide-down {
from {
opacity: 0;
transform: translateY(-100%);
}
to {
opacity: 1;
transform: translateY(0);
}
}
@keyframes banner-slide-up {
from {
opacity: 1;
transform: translateY(0);
max-height: 200px;
}
to {
opacity: 0;
transform: translateY(-20px);
max-height: 0;
padding-top: 0;
padding-bottom: 0;
}
}
/* Responsive Design */
@media (max-width: 768px) {
.banner-content {
flex-direction: column;
align-items: flex-start;
gap: var(--space-2);
}
.banner-actions {
width: 100%;
flex-wrap: wrap;
justify-content: flex-start;
}
.banner-action {
flex: 1;
min-width: 0;
justify-content: center;
}
.banner-dismiss {
top: 6px;
right: 6px;
}
.banner-item {
padding: var(--space-2);
}
.banner-title {
font-size: 1em;
}
.banner-description {
font-size: 0.85em;
}
}
@media (max-width: 480px) {
.banner-actions {
flex-direction: column;
width: 100%;
}
.banner-action {
width: 100%;
justify-content: center;
}
.banner-content {
gap: var(--space-1);
}
}
/* Dark theme adjustments */
[data-theme="dark"] .banner-item {
background: linear-gradient(135deg,
oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.08),
oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.03)
);
}
/* Prevent text selection */
.banner-item,
.banner-title,
.banner-description,
.banner-action,
.banner-dismiss {
-webkit-user-select: none;
-moz-user-select: none;
-ms-user-select: none;
user-select: none;
}

View File

@@ -12,7 +12,9 @@
z-index: var(--z-overlay);
display: flex;
flex-direction: column;
min-width: 300px;
min-width: 420px;
max-width: 900px;
width: auto;
transition: all 0.4s cubic-bezier(0.175, 0.885, 0.32, 1.275);
opacity: 0;
}
@@ -48,6 +50,8 @@
color: var(--text-color);
cursor: pointer;
font-size: 14px;
white-space: nowrap;
min-height: 36px;
display: flex;
align-items: center;
gap: 6px;
@@ -60,13 +64,25 @@
border-color: var(--lora-accent);
}
/* Danger button style - updated to use proper theme variables */
.bulk-operations-actions button.danger-btn {
background: oklch(70% 0.2 29); /* Light red background that works in both themes */
color: oklch(98% 0.01 0); /* Almost white text for good contrast */
border-color: var(--lora-error);
}
.bulk-operations-actions button.danger-btn:hover {
background: var(--lora-error);
color: oklch(100% 0 0); /* Pure white text on hover for maximum contrast */
}
/* Style for selected cards */
.lora-card.selected {
.model-card.selected {
box-shadow: 0 0 0 2px var(--lora-accent);
position: relative;
}
.lora-card.selected::after {
.model-card.selected::after {
content: "✓";
position: absolute;
top: 10px;
@@ -93,6 +109,8 @@
@media (max-width: 768px) {
.bulk-operations-panel {
width: calc(100% - 40px);
min-width: unset;
max-width: unset;
left: 20px;
transform: none;
border-radius: var(--border-radius-sm);
@@ -262,83 +280,6 @@
background: var(--lora-accent);
}
/* NSFW Level Selector */
.nsfw-level-selector {
position: fixed;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
background: var(--card-bg);
border: 1px solid var(--border-color);
border-radius: var(--border-radius-base);
padding: 16px;
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2);
z-index: var(--z-modal);
width: 300px;
display: none;
}
.nsfw-level-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 16px;
}
.nsfw-level-header h3 {
margin: 0;
font-size: 16px;
font-weight: 500;
}
.close-nsfw-selector {
background: transparent;
border: none;
color: var(--text-color);
cursor: pointer;
padding: 4px;
border-radius: var(--border-radius-xs);
}
.close-nsfw-selector:hover {
background: var(--border-color);
}
.current-level {
margin-bottom: 12px;
padding: 8px;
background: var(--bg-color);
border-radius: var(--border-radius-xs);
border: 1px solid var(--border-color);
}
.nsfw-level-options {
display: flex;
flex-wrap: wrap;
gap: 8px;
}
.nsfw-level-btn {
flex: 1 0 calc(33% - 8px);
padding: 8px;
border-radius: var(--border-radius-xs);
background: var(--bg-color);
border: 1px solid var(--border-color);
color: var(--text-color);
cursor: pointer;
transition: all 0.2s ease;
}
.nsfw-level-btn:hover {
background: var(--lora-border);
}
.nsfw-level-btn.active {
background: var(--lora-accent);
color: white;
border-color: var(--lora-accent);
}
/* Mobile optimizations */
@media (max-width: 768px) {
.selected-thumbnails-strip {

View File

@@ -1,48 +1,76 @@
/* 卡片网格布局 */
.card-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(260px, 1fr)); /* Adjusted from 320px */
gap: 12px; /* Reduced from var(--space-2) for tighter horizontal spacing */
grid-template-columns: repeat(auto-fill, minmax(260px, 1fr)); /* Base size */
gap: 12px; /* Consistent gap for both row and column spacing */
row-gap: 20px; /* Increase vertical spacing between rows */
margin-top: var(--space-2);
padding-top: 4px; /* 添加顶部内边距,为悬停动画提供空间 */
padding-bottom: 4px; /* 添加底部内边距,为悬停动画提供空间 */
max-width: 1400px; /* Container width control */
width: 100%; /* Ensure it takes full width of container */
max-width: 1400px; /* Base container width */
margin-left: auto;
margin-right: auto;
box-sizing: border-box; /* Include padding in width calculation */
}
.lora-card {
.model-card {
background: var(--lora-surface);
border: 1px solid var(--lora-border);
border-radius: var(--border-radius-base);
backdrop-filter: blur(16px);
transition: transform 160ms ease-out;
aspect-ratio: 896/1152;
max-width: 260px; /* Adjusted from 320px to fit 5 cards */
aspect-ratio: 896/1152; /* Preserve aspect ratio */
max-width: 260px; /* Base size */
width: 100%;
margin: 0 auto;
cursor: pointer; /* Added from recipe-card */
display: flex; /* Added from recipe-card */
flex-direction: column; /* Added from recipe-card */
overflow: hidden; /* Add overflow hidden to contain children */
cursor: pointer;
display: flex;
flex-direction: column;
overflow: hidden;
}
.lora-card:hover {
.model-card:hover {
transform: translateY(-2px);
background: oklch(100% 0 0 / 0.6);
}
.lora-card:focus-visible {
.model-card:focus-visible {
outline: 2px solid var(--lora-accent);
outline-offset: 2px;
}
/* Responsive adjustments for 1440p screens (2K) */
@media (min-width: 2000px) {
.card-grid {
max-width: 1800px; /* Increased for 2K screens */
grid-template-columns: repeat(auto-fill, minmax(270px, 1fr));
}
.model-card {
max-width: 270px;
}
}
/* Responsive adjustments for 4K screens */
@media (min-width: 3000px) {
.card-grid {
max-width: 2400px; /* Increased for 4K screens */
grid-template-columns: repeat(auto-fill, minmax(280px, 1fr));
}
.model-card {
max-width: 280px;
}
}
/* Responsive adjustments */
@media (max-width: 1400px) {
.card-grid {
grid-template-columns: repeat(auto-fill, minmax(240px, 1fr));
}
.lora-card {
.model-card {
max-width: 240px;
}
}
@@ -58,6 +86,42 @@
min-height: 0; /* Fix for potential flexbox sizing issue in Firefox */
}
/* Smaller text for medium density */
.medium-density .model-name {
font-size: 0.95em;
max-height: 3em; /* Increased from 2.6em */
}
.medium-density .base-model-label {
font-size: 0.85em;
max-width: 120px;
}
.medium-density .card-actions i {
font-size: 0.98em;
padding: 4px;
}
/* Smaller text for compact mode */
.compact-density .model-name {
font-size: 0.9em;
max-height: 2.8em; /* Increased from 2.4em */
}
.compact-density .base-model-label {
font-size: 0.8em;
max-width: 110px;
}
.compact-density .card-actions i {
font-size: 0.95em;
padding: 3px;
}
.compact-density .model-info {
padding-bottom: 2px;
}
.card-preview img,
.card-preview video {
width: 100%;
@@ -103,6 +167,38 @@
text-shadow: 1px 1px 1px rgba(0, 0, 0, 0.5);
}
/* NSFW warning adjustments for medium density */
.medium-density .nsfw-warning {
padding: calc(var(--space-2) * 0.85);
max-width: 70%;
}
.medium-density .nsfw-warning p {
font-size: 0.95em;
margin-bottom: calc(var(--space-1) * 0.85);
}
.medium-density .show-content-btn {
font-size: 0.85em;
padding: 3px calc(var(--space-1) * 0.85);
}
/* NSFW warning adjustments for compact density */
.compact-density .nsfw-warning {
padding: calc(var(--space-2) * 0.7);
max-width: 60%;
}
.compact-density .nsfw-warning p {
font-size: 0.85em;
margin-bottom: calc(var(--space-1) * 0.7);
}
.compact-density .show-content-btn {
font-size: 0.8em;
padding: 2px var(--space-1);
}
.toggle-blur-btn {
position: absolute;
left: var(--space-1);
@@ -156,6 +252,18 @@
z-index: 3;
}
/* New styles for hover reveal mode */
.hover-reveal .card-header,
.hover-reveal .card-footer {
opacity: 0;
transition: opacity 0.2s ease;
}
.hover-reveal .model-card:hover .card-header,
.hover-reveal .model-card:hover .card-footer {
opacity: 1;
}
.card-footer {
position: absolute;
bottom: 0;
@@ -237,7 +345,7 @@
grid-template-columns: minmax(260px, 1fr); /* Adjusted minimum size for mobile */
}
.lora-card {
.model-card {
max-width: 100%; /* Allow cards to fill available space on mobile */
}
}
@@ -267,21 +375,24 @@
text-decoration: none;
}
/* Updated model name to fix text cutoff issues */
.model-name {
font-weight: bold;
text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.5);
font-size: 0.95em;
word-break: break-word;
display: block;
max-height: 2.8em;
max-height: 3em; /* Increased to ensure two full lines */
overflow: hidden;
/* Add line height for consistency */
line-height: 1.4;
}
.model-info {
flex: 1;
min-width: 0;
overflow: hidden;
padding-bottom: 4px;
padding-bottom: 6px; /* Increased from 4px to give more room for text */
}
.base-model {
@@ -313,28 +424,50 @@
font-size: 0.85em;
}
/* Recipe specific elements - migrated from recipe-card.css */
.recipe-indicator {
position: absolute;
top: 6px;
left: 8px;
width: 24px;
height: 24px;
background: var(--lora-primary);
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
color: white;
font-weight: bold;
z-index: 2;
/* Style for version name */
.version-name {
display: inline-block;
color: rgba(255,255,255,0.8); /* Muted white */
text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.5);
font-size: 0.85em;
word-break: break-word;
overflow: hidden;
line-height: 1.4;
margin-top: 2px;
opacity: 0.8; /* Slightly transparent for better readability */
border: 1px solid rgba(255,255,255,0.25); /* Subtle border */
border-radius: var(--border-radius-xs);
padding: 1px 6px;
background: rgba(0,0,0,0.18); /* Optional: subtle background for contrast */
}
.base-model-wrapper {
display: flex;
align-items: center;
gap: 8px;
margin-left: 32px; /* For accommodating the recipe indicator */
/* Medium density adjustments for version name */
.medium-density .version-name {
font-size: 0.8em;
}
/* Compact density adjustments for version name */
.compact-density .version-name {
font-size: 0.75em;
}
/* Prevent text selection on cards and interactive elements */
.model-card,
.model-card *,
.card-actions,
.card-actions i,
.toggle-blur-btn,
.show-content-btn,
.card-preview img,
.card-preview video,
.card-footer,
.card-header,
.model-name,
.base-model-label {
-webkit-user-select: none;
-moz-user-select: none;
-ms-user-select: none;
user-select: none;
}
.lora-count {
@@ -362,4 +495,84 @@
padding: 2rem;
background: var(--lora-surface-alt);
border-radius: var(--border-radius-base);
}
}
/* Virtual scrolling specific styles - updated */
.virtual-scroll-item {
position: absolute;
box-sizing: border-box;
transition: transform 160ms ease-out;
margin: 0; /* Remove margins, positioning is handled by VirtualScroller */
width: 100%; /* Allow width to be set by the VirtualScroller */
}
.virtual-scroll-item:hover {
transform: translateY(-2px); /* Keep hover effect */
z-index: 1; /* Ensure hovered items appear above others */
}
/* When using virtual scroll, adjust container */
.card-grid.virtual-scroll {
display: block;
position: relative;
margin: 0 auto;
padding: 4px 0; /* Add top/bottom padding equivalent to card padding */
height: auto;
width: 100%;
max-width: 1400px; /* Keep the max-width from original grid */
box-sizing: border-box; /* Include padding in width calculation */
overflow-x: hidden; /* Prevent horizontal overflow */
}
/* For larger screens, allow more space for the cards */
@media (min-width: 2000px) {
.card-grid.virtual-scroll {
max-width: 1800px;
}
}
@media (min-width: 3000px) {
.card-grid.virtual-scroll {
max-width: 2400px;
}
}
/* Add after the existing .model-card:hover styles */
@keyframes update-pulse {
0% { box-shadow: 0 0 0 0 var(--lora-accent-transparent); }
50% { box-shadow: 0 0 0 4px var(--lora-accent-transparent); }
100% { box-shadow: 0 0 0 0 var(--lora-accent-transparent); }
}
/* Add semi-transparent version of accent color for animation */
:root {
--lora-accent-transparent: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.6);
}
.model-card.updated {
animation: update-pulse 1.2s ease-out;
}
/* Add a subtle updated tag that fades in and out */
.update-indicator {
position: absolute;
top: 8px;
right: 8px;
background: var(--lora-accent);
color: white;
border-radius: var(--border-radius-xs);
padding: 3px 6px;
font-size: 0.75em;
opacity: 0;
transform: translateY(-5px);
z-index: 4;
animation: update-tag 1.8s ease-out forwards;
}
@keyframes update-tag {
0% { opacity: 0; transform: translateY(-5px); }
15% { opacity: 1; transform: translateY(0); }
85% { opacity: 1; transform: translateY(0); }
100% { opacity: 0; transform: translateY(0); }
}

View File

@@ -1,197 +0,0 @@
/* Download Modal Styles */
.download-step {
margin: var(--space-2) 0;
}
.input-group {
margin-bottom: var(--space-2);
}
.input-group label {
display: block;
margin-bottom: 8px;
color: var(--text-color);
}
.input-group input,
.input-group select {
width: 100%;
padding: 8px;
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
background: var(--bg-color);
color: var(--text-color);
}
/* Version List Styles */
.version-list {
max-height: 400px;
overflow-y: auto;
margin: var(--space-2) 0;
display: flex;
flex-direction: column;
gap: 12px;
padding: 1px;
}
.version-item {
display: flex;
gap: var(--space-2);
padding: var(--space-2);
border: 1px solid var(--border-color);
border-radius: var(--border-radius-sm);
cursor: pointer;
transition: all 0.2s ease;
background: var(--bg-color);
margin: 1px;
position: relative;
}
.version-item:hover {
border-color: var(--lora-accent);
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
z-index: 1;
}
.version-item.selected {
border: 2px solid var(--lora-accent);
background: oklch(var(--lora-accent) / 0.05);
}
.version-thumbnail {
width: 80px;
height: 80px;
flex-shrink: 0;
border-radius: var(--border-radius-xs);
overflow: hidden;
background: var(--bg-color);
}
.version-thumbnail img {
width: 100%;
height: 100%;
object-fit: cover;
}
.version-content {
display: flex;
flex-direction: column;
gap: 8px;
flex: 1;
min-width: 0;
}
.version-header {
display: flex;
align-items: flex-start;
justify-content: space-between;
gap: var(--space-2);
}
.version-content h3 {
margin: 0;
font-size: 1.1em;
color: var(--text-color);
flex: 1;
}
.version-info {
display: flex;
flex-wrap: wrap;
flex-direction: row !important;
gap: 8px;
align-items: center;
font-size: 0.9em;
}
.version-info .base-model {
background: oklch(var(--lora-accent) / 0.1);
color: var(--lora-accent);
padding: 2px 8px;
border-radius: var(--border-radius-xs);
}
.version-meta {
display: flex;
gap: 12px;
font-size: 0.85em;
color: var(--text-color);
opacity: 0.7;
}
.version-meta span {
display: flex;
align-items: center;
gap: 4px;
}
/* Folder Browser Styles */
.folder-browser {
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
padding: var(--space-1);
max-height: 200px;
overflow-y: auto;
}
.folder-item {
padding: 8px;
cursor: pointer;
border-radius: var(--border-radius-xs);
transition: background-color 0.2s;
}
.folder-item:hover {
background: var(--lora-surface);
}
.folder-item.selected {
background: oklch(var(--lora-accent) / 0.1);
border: 1px solid var(--lora-accent);
}
/* Path Preview Styles */
.path-preview {
margin-bottom: var(--space-3);
padding: var(--space-2);
background: var(--bg-color);
border-radius: var(--border-radius-sm);
border: 1px dashed var(--border-color);
}
.path-preview label {
display: block;
margin-bottom: 8px;
color: var(--text-color);
font-size: 0.9em;
opacity: 0.8;
}
.path-display {
padding: var(--space-1);
color: var(--text-color);
font-family: monospace;
font-size: 0.9em;
line-height: 1.4;
white-space: pre-wrap;
word-break: break-all;
opacity: 0.85;
background: var(--lora-surface);
border-radius: var(--border-radius-xs);
}
/* Dark theme adjustments */
[data-theme="dark"] .version-item {
background: var(--lora-surface);
}
[data-theme="dark"] .local-path {
background: var(--lora-surface);
border-color: var(--lora-border);
}
/* Enhance the local badge to make it more noticeable */
.version-item.exists-locally {
background: oklch(var(--lora-accent) / 0.05);
border-left: 4px solid var(--lora-accent);
}

View File

@@ -2,30 +2,46 @@
/* Duplicates banner */
.duplicates-banner {
position: sticky;
top: 48px; /* Match header height */
left: 0;
position: sticky; /* Keep the sticky position */
top: var(--space-1);
width: 100%;
background-color: var(--card-bg);
background-color: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1); /* Use accent color with low opacity */
color: var(--text-color);
border-bottom: 1px solid var(--border-color);
border-top: 1px solid oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.3); /* Add top border with accent color */
border-bottom: 1px solid oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.4); /* Make bottom border stronger */
z-index: var(--z-overlay);
padding: 12px 16px;
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15);
padding: 12px 0;
box-shadow: 0 3px 10px rgba(0, 0, 0, 0.2); /* Stronger shadow */
transition: all 0.3s ease;
margin-bottom: 20px;
}
.duplicates-banner .banner-content {
position: relative;
max-width: 1400px;
margin: 0 auto;
display: flex;
align-items: center;
gap: 12px;
padding: 0 16px;
}
/* Responsive container for larger screens - match container in layout.css */
@media (min-width: 2000px) {
.duplicates-banner .banner-content {
max-width: 1800px;
}
}
@media (min-width: 3000px) {
.duplicates-banner .banner-content {
max-width: 2400px;
}
}
.duplicates-banner i.fa-exclamation-triangle {
font-size: 18px;
color: oklch(var(--lora-warning));
color: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
}
.duplicates-banner .banner-actions {
@@ -35,6 +51,29 @@
align-items: center;
}
/* Improved exit button in banner */
.duplicates-banner button.btn-exit-mode {
min-width: 120px;
background-color: var(--card-bg);
color: var(--text-color);
border: 1px solid var(--border-color);
padding: 6px 12px;
border-radius: var(--border-radius-xs);
font-size: 0.85em;
cursor: pointer;
display: flex;
align-items: center;
justify-content: center;
gap: 6px;
transition: all 0.2s ease;
}
.duplicates-banner button.btn-exit-mode:hover {
background-color: var(--bg-color);
border-color: var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h);
transform: translateY(-1px);
}
.duplicates-banner button {
min-width: 100px;
display: flex;
@@ -53,7 +92,7 @@
}
.duplicates-banner button:hover {
border-color: var(--lora-accent);
border-color: var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h);
background: var(--bg-color);
transform: translateY(-1px);
box-shadow: 0 3px 5px rgba(0, 0, 0, 0.08);
@@ -78,23 +117,42 @@
/* Duplicate groups */
.duplicate-group {
position: relative;
border: 2px solid oklch(var(--lora-warning));
border: 2px solid oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
border-radius: var(--border-radius-base);
padding: 16px;
margin-bottom: 24px;
background: var(--card-bg);
box-shadow: 0 2px 6px rgba(0, 0, 0, 0.12); /* Add subtle shadow to groups */
/* Add responsive width settings to match banner */
max-width: 1400px;
margin-left: auto;
margin-right: auto;
}
/* Add responsive container adjustments for duplicate groups - match container in banner */
@media (min-width: 2000px) {
.duplicate-group {
max-width: 1800px;
}
}
@media (min-width: 3000px) {
.duplicate-group {
max-width: 2400px;
}
}
.duplicate-group-header {
background-color: var(--bg-color);
color: var(--text-color);
border: 1px solid var(--border-color);
padding: 8px 16px;
padding: 10px 16px; /* Slightly increased padding */
border-radius: var(--border-radius-xs);
margin-bottom: 16px;
display: flex;
justify-content: space-between;
align-items: center;
border-left: 4px solid oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h)); /* Add accent border on the left */
}
.duplicate-group-header span:last-child {
@@ -122,7 +180,7 @@
}
.duplicate-group-header button:hover {
border-color: var(--lora-accent);
border-color: var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h);
background: var(--bg-color);
transform: translateY(-1px);
box-shadow: 0 3px 5px rgba(0, 0, 0, 0.08);
@@ -137,7 +195,7 @@
}
/* Make cards in duplicate groups have consistent width */
.card-group-container .lora-card {
.card-group-container .model-card {
flex: 0 0 auto;
width: 240px;
margin: 0;
@@ -177,32 +235,32 @@
}
.group-toggle-btn:hover {
border-color: var(--lora-accent);
border-color: var(--lora-accent-l) var(--lora-accent-c) var (--lora-accent-h);
transform: translateY(-1px);
box-shadow: 0 3px 5px rgba(0, 0, 0, 0.08);
}
/* Duplicate card styling */
.lora-card.duplicate {
.model-card.duplicate {
position: relative;
transition: all 0.2s ease;
}
.lora-card.duplicate:hover {
border-color: var(--lora-accent);
.model-card.duplicate:hover {
border-color: var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h);
}
.lora-card.duplicate.latest {
.model-card.duplicate.latest {
border-style: solid;
border-color: oklch(var(--lora-warning));
border-color: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
}
.lora-card.duplicate-selected {
border: 2px solid oklch(var(--lora-accent));
.model-card.duplicate-selected {
border: 2px solid oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h));
box-shadow: 0 0 8px rgba(0, 0, 0, 0.2);
}
.lora-card .selector-checkbox {
.model-card .selector-checkbox {
position: absolute;
top: 10px;
right: 10px;
@@ -213,12 +271,12 @@
}
/* Latest indicator */
.lora-card.duplicate.latest::after {
.model-card.duplicate.latest::after {
content: "Latest";
position: absolute;
top: 10px;
left: 10px;
background: oklch(var(--lora-accent));
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h));
color: white;
font-size: 12px;
padding: 2px 6px;
@@ -226,6 +284,251 @@
z-index: 5;
}
/* Model tooltip for duplicates mode */
.model-tooltip {
position: absolute;
background-color: var(--card-bg);
border: 1px solid var(--border-color);
border-radius: var(--border-radius-sm);
box-shadow: 0 2px 10px rgba(0,0,0,0.2);
padding: 10px;
z-index: 1000;
max-width: 350px;
min-width: 250px;
color: var(--text-color);
font-size: 0.9em;
pointer-events: none; /* Don't block mouse events */
}
.model-tooltip .tooltip-header {
font-weight: bold;
font-size: 1.1em;
margin-bottom: 8px;
padding-bottom: 5px;
border-bottom: 1px solid var(--border-color);
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.model-tooltip .tooltip-info div {
margin-bottom: 4px;
display: flex;
flex-wrap: wrap;
word-break: break-all; /* Ensure long hashes wrap properly */
}
.model-tooltip .tooltip-info div strong {
margin-right: 5px;
min-width: 70px;
}
/* Latest indicator */
.hash-mismatch-info {
margin-top: 8px;
padding-top: 8px;
border-top: 1px dashed var(--border-color);
color: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
font-weight: bold;
word-break: break-all; /* Ensure long hashes wrap properly */
}
/* Verification Badge Styles */
.verification-badge {
display: inline-flex;
align-items: center;
margin-left: 8px;
padding: 2px 6px;
font-size: 0.8em;
border-radius: var(--border-radius-xs);
font-weight: normal;
}
.verification-badge.metadata {
background-color: var(--bg-color);
border: 1px solid var(--border-color);
color: var(--text-color);
}
.verification-badge.verified {
background-color: oklch(70% 0.2 140); /* Green for verified */
color: white;
}
.verification-badge.mismatch {
background-color: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
color: white;
}
.verification-badge i {
margin-right: 4px;
}
/* Hash Mismatch Styling */
.model-card.duplicate.hash-mismatch {
border: 2px dashed oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
opacity: 0.85;
position: relative;
}
.model-card.duplicate.hash-mismatch::before {
content: "";
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
background: repeating-linear-gradient(
45deg,
oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h) / 0.05),
oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h) / 0.05) 10px,
transparent 10px,
transparent 20px
);
z-index: 1;
pointer-events: none;
}
.model-card.duplicate.hash-mismatch .card-preview {
filter: grayscale(20%);
}
/* Mismatch Badge */
.mismatch-badge {
position: absolute;
top: 10px;
left: 10px; /* Changed from right:10px to left:10px */
background: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
color: white;
font-size: 12px;
padding: 3px 8px;
border-radius: var(--border-radius-xs);
z-index: 5;
}
/* Disabled checkbox style */
.model-card.duplicate.hash-mismatch .selector-checkbox {
opacity: 0.5;
cursor: not-allowed;
}
/* Hash mismatch info in tooltip */
.hash-mismatch-info {
margin-top: 8px;
padding-top: 8px;
border-top: 1px dashed var(--border-color);
color: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
font-weight: bold;
}
/* Verify hash button styling */
.btn-verify-hashes {
display: flex;
align-items: center;
gap: 6px;
padding: 4px 10px;
background: var(--card-bg);
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
font-size: 0.85em;
cursor: pointer;
transition: all 0.2s ease;
}
.btn-verify-hashes:hover {
background: var(--bg-color);
border-color: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h));
transform: translateY(-1px);
}
.btn-verify-hashes i {
font-size: 0.9em;
}
/* Badge Styles */
.badge {
display: inline-flex;
align-items: center;
justify-content: center;
min-width: 16px; /* Reduced from 20px */
height: 16px; /* Reduced from 20px */
border-radius: 8px; /* Adjusted for smaller size */
background-color: var(--lora-error);
color: white;
font-size: 10px; /* Smaller font size */
font-weight: bold;
padding: 0 4px; /* Reduced padding */
position: absolute;
top: -8px; /* Moved closer to button */
right: -8px; /* Moved closer to button */
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.15); /* Softer shadow */
transition: transform 0.2s ease, opacity 0.2s ease;
}
.badge:empty {
display: none;
}
/* Make the pulse animation more subtle */
.badge.pulse {
animation: badge-pulse 2s infinite; /* Slower animation */
}
@keyframes badge-pulse {
0% {
transform: scale(1);
}
50% {
transform: scale(1.1); /* Less expansion */
}
100% {
transform: scale(1);
}
}
/* Help icon styling */
.help-icon {
color: var(--text-color);
opacity: 0.7;
cursor: help;
font-size: 16px;
margin-left: 8px;
transition: all 0.2s ease;
}
.help-icon:hover {
opacity: 1;
color: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h));
}
/* Help tooltip */
.help-tooltip {
display: none;
position: absolute;
max-width: 400px;
background: var(--card-bg);
color: var(--text-color);
border: 1px solid var(--border-color);
border-radius: var(--border-radius-sm);
padding: 12px 16px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
z-index: var(--z-overlay);
font-size: 0.9em;
margin-top: 10px;
text-align: left;
pointer-events: none;
}
.help-tooltip:after {
content: "";
position: absolute;
top: -8px;
left: 10px; /* Position the arrow near the left instead of center */
border-width: 0 8px 8px 8px;
border-style: solid;
border-color: transparent transparent var(--card-bg) transparent;
}
/* Responsive adjustments */
@media (max-width: 768px) {
.duplicates-banner .banner-content {
@@ -256,4 +559,50 @@
margin-left: 0;
flex: 1;
}
.help-tooltip {
max-width: calc(100% - 40px);
}
/* Remove the fixed positioning adjustments for mobile since we're now using dynamic positioning */
.help-tooltip:after {
left: 10px;
}
}
/* In dark mode, add additional distinction */
html[data-theme="dark"] .duplicates-banner {
box-shadow: 0 3px 12px rgba(0, 0, 0, 0.4); /* Stronger shadow in dark mode */
background-color: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.15); /* Slightly stronger background in dark mode */
}
html[data-theme="dark"] .duplicate-group {
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.25); /* Stronger shadow in dark mode */
}
html[data-theme="dark"] .help-tooltip {
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
}
/* Styles for disabled controls during duplicates mode */
.disabled-during-duplicates {
opacity: 0.5 !important;
pointer-events: none !important;
cursor: not-allowed !important;
user-select: none !important;
filter: grayscale(50%) !important;
}
/* Make the active duplicates button more prominent */
#findDuplicatesBtn.active {
background: var(--lora-accent);
color: white;
border-color: var(--lora-accent);
box-shadow: 0 0 0 2px oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.25);
position: relative;
z-index: 5;
}
#findDuplicatesBtn.active:hover {
background: oklch(calc(var(--lora-accent-l) - 5%) var(--lora-accent-c) var(--lora-accent-h));
}

View File

@@ -19,6 +19,18 @@
height: 100%;
}
/* Responsive header container for larger screens */
@media (min-width: 2000px) {
.header-container {
max-width: 1800px;
}
}
@media (min-width: 3000px) {
.header-container {
max-width: 2400px;
}
}
/* Logo and title styling */
.header-branding {
display: flex;
@@ -31,7 +43,7 @@
align-items: center;
text-decoration: none;
color: var(--text-color);
gap: 8px;
gap: 2px;
}
.app-logo {
@@ -79,6 +91,50 @@
flex: 1;
max-width: 400px;
margin: 0 1rem;
transition: opacity 0.2s ease;
}
/* Disabled state for header search */
.header-search.disabled {
opacity: 0.5;
pointer-events: none;
}
.header-search.disabled input {
background-color: var(--input-disabled-bg, #f5f5f5);
color: var(--text-muted);
cursor: not-allowed;
}
.header-search.disabled button {
background-color: var(--button-disabled-bg, #e0e0e0);
color: var(--text-muted);
cursor: not-allowed;
}
.header-search.disabled .search-icon {
color: var(--text-muted);
}
/* Dark theme specific styles for disabled header search */
[data-theme="dark"] .header-search.disabled input {
background-color: #3a3a3a;
color: #888888;
border-color: #555555;
}
[data-theme="dark"] .header-search.disabled button {
background-color: #3a3a3a;
color: #888888;
border-color: #555555;
}
[data-theme="dark"] .header-search.disabled .search-icon {
color: #888888;
}
[data-theme="dark"] .header-search.disabled .fas {
color: #888888;
}
/* Header controls (formerly corner controls) */
@@ -115,7 +171,8 @@
}
.theme-toggle .light-icon,
.theme-toggle .dark-icon {
.theme-toggle .dark-icon,
.theme-toggle .auto-icon {
position: absolute;
top: 50%;
left: 50%;
@@ -124,18 +181,60 @@
transition: opacity 0.3s ease;
}
/* Default state shows dark icon */
.theme-toggle .dark-icon {
opacity: 1;
}
[data-theme="light"] .theme-toggle .light-icon {
/* Light theme shows light icon */
.theme-toggle.theme-light .light-icon {
opacity: 1;
}
[data-theme="light"] .theme-toggle .dark-icon {
.theme-toggle.theme-light .dark-icon,
.theme-toggle.theme-light .auto-icon {
opacity: 0;
}
/* Dark theme shows dark icon */
.theme-toggle.theme-dark .dark-icon {
opacity: 1;
}
.theme-toggle.theme-dark .light-icon,
.theme-toggle.theme-dark .auto-icon {
opacity: 0;
}
/* Auto theme shows auto icon */
.theme-toggle.theme-auto .auto-icon {
opacity: 1;
}
.theme-toggle.theme-auto .light-icon,
.theme-toggle.theme-auto .dark-icon {
opacity: 0;
}
/* Badge styling */
.update-badge {
position: absolute;
top: -3px;
right: -3px;
width: 8px;
height: 8px;
background-color: var(--lora-error);
border-radius: 50%;
border: 2px solid var(--card-bg);
transition: all 0.2s ease;
pointer-events: none;
opacity: 0;
}
.update-badge.visible {
opacity: 1;
}
/* Mobile adjustments */
@media (max-width: 768px) {
.app-title {

View File

@@ -337,72 +337,7 @@
margin-left: 8px;
}
/* Location Selection Styles */
.location-selection {
margin: var(--space-2) 0;
padding: var(--space-2);
background: var(--lora-surface);
border-radius: var(--border-radius-sm);
}
/* Reuse folder browser and path preview styles from download-modal.css */
.folder-browser {
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
padding: var(--space-1);
max-height: 200px;
overflow-y: auto;
}
.folder-item {
padding: 8px;
cursor: pointer;
border-radius: var(--border-radius-xs);
transition: background-color 0.2s;
}
.folder-item:hover {
background: var(--lora-surface);
}
.folder-item.selected {
background: oklch(var(--lora-accent) / 0.1);
border: 1px solid var(--lora-accent);
}
.path-preview {
margin-bottom: var(--space-3);
padding: var(--space-2);
background: var(--bg-color);
border-radius: var(--border-radius-sm);
border: 1px dashed var(--border-color);
}
.path-preview label {
display: block;
margin-bottom: 8px;
color: var(--text-color);
font-size: 0.9em;
opacity: 0.8;
}
.path-display {
padding: var(--space-1);
color: var(--text-color);
font-family: monospace;
font-size: 0.9em;
line-height: 1.4;
white-space: pre-wrap;
word-break: break-all;
opacity: 0.85;
background: var(--lora-surface);
border-radius: var(--border-radius-xs);
}
/* Input Group Styles */
.input-group {
margin-bottom: var(--space-2);
}
.input-with-button {
display: flex;
@@ -430,22 +365,6 @@
background: oklch(from var(--lora-accent) l c h / 0.9);
}
.input-group label {
display: block;
margin-bottom: 8px;
color: var(--text-color);
}
.input-group input,
.input-group select {
width: 100%;
padding: 8px;
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
background: var(--bg-color);
color: var(--text-color);
}
/* Dark theme adjustments */
[data-theme="dark"] .lora-item {
background: var(--lora-surface);

View File

@@ -0,0 +1,96 @@
/* Keyboard navigation indicator and help */
.keyboard-nav-hint {
display: inline-flex;
align-items: center;
justify-content: center;
position: relative;
width: 32px;
height: 32px;
border-radius: 50%;
background: var(--card-bg);
border: 1px solid var(--border-color);
color: var(--text-color);
cursor: help;
transition: all 0.2s ease;
margin-left: 8px;
}
.keyboard-nav-hint:hover {
background: var(--lora-accent);
color: white;
transform: translateY(-2px);
box-shadow: 0 3px 5px rgba(0, 0, 0, 0.08);
}
.keyboard-nav-hint i {
font-size: 14px;
}
/* Tooltip styling */
.tooltip {
position: relative;
}
.tooltip .tooltiptext {
visibility: hidden;
width: 240px;
background-color: var(--lora-surface);
color: var(--text-color);
text-align: center;
border-radius: var(--border-radius-xs);
padding: 8px;
position: absolute;
z-index: 9999; /* 确保在卡片上方显示 */
left: 120%; /* 将tooltip显示在图标右侧 */
top: 50%; /* 垂直居中 */
transform: translateY(-50%); /* 垂直居中 */
opacity: 0;
transition: opacity 0.3s;
box-shadow: 0 3px 8px rgba(0, 0, 0, 0.15);
border: 1px solid var(--lora-border);
font-size: 0.85em;
line-height: 1.4;
}
.tooltip .tooltiptext::after {
content: "";
position: absolute;
top: 50%; /* 箭头垂直居中 */
right: 100%; /* 箭头在左侧 */
margin-top: -5px;
border-width: 5px;
border-style: solid;
border-color: transparent var(--lora-border) transparent transparent; /* 箭头指向左侧 */
}
.tooltip:hover .tooltiptext {
visibility: visible;
opacity: 1;
}
/* Keyboard shortcuts table */
.keyboard-shortcuts {
width: 100%;
border-collapse: collapse;
margin-top: 5px;
}
.keyboard-shortcuts td {
padding: 4px;
text-align: left;
}
.keyboard-shortcuts td:first-child {
font-weight: bold;
width: 40%;
}
.key {
display: inline-block;
background: var(--bg-color);
border: 1px solid var(--border-color);
border-radius: 3px;
padding: 1px 5px;
font-size: 0.8em;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.08);
}

View File

@@ -109,7 +109,7 @@
}
@media (prefers-reduced-motion: reduce) {
.lora-card,
.model-card,
.progress-bar,
.current-item-bar {
transition: none;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,100 @@
/* Model Description Styling */
.model-description-container {
background: var(--lora-surface);
border-radius: var(--border-radius-sm);
overflow: hidden;
min-height: 200px;
position: relative;
/* Remove the max-height and overflow-y to allow content to expand naturally */
}
.model-description-loading {
display: flex;
align-items: center;
justify-content: center;
padding: var(--space-3);
color: var(--text-color);
opacity: 0.7;
font-size: 0.9em;
}
.model-description-loading .fa-spinner {
margin-right: var(--space-1);
}
.model-description-content {
padding: var(--space-2);
line-height: 1.5;
overflow-wrap: break-word;
font-size: 0.95em;
}
.model-description-content h1,
.model-description-content h2,
.model-description-content h3,
.model-description-content h4,
.model-description-content h5,
.model-description-content h6 {
margin-top: 1em;
margin-bottom: 0.5em;
font-weight: 600;
}
.model-description-content p {
margin-bottom: 1em;
}
.model-description-content img {
max-width: 100%;
height: auto;
border-radius: var(--border-radius-xs);
display: block;
margin: 1em 0;
}
.model-description-content pre {
background: rgba(0, 0, 0, 0.05);
border-radius: var(--border-radius-xs);
padding: var(--space-1);
white-space: pre-wrap;
margin: 1em 0;
overflow-x: auto;
}
.model-description-content code {
font-family: monospace;
font-size: 0.9em;
background: rgba(0, 0, 0, 0.05);
padding: 0.1em 0.3em;
border-radius: 3px;
}
.model-description-content pre code {
background: transparent;
padding: 0;
}
.model-description-content ul,
.model-description-content ol {
margin-left: 1.5em;
margin-bottom: 1em;
}
.model-description-content li {
margin-bottom: 0.5em;
}
.model-description-content blockquote {
border-left: 3px solid var (--lora-accent);
padding-left: 1em;
margin-left: 0;
margin-right: 0;
font-style: italic;
opacity: 0.8;
}
/* Adjust dark mode for model description */
[data-theme="dark"] .model-description-content pre,
[data-theme="dark"] .model-description-content code {
background: rgba(255, 255, 255, 0.05);
}

View File

@@ -0,0 +1,486 @@
/* Lora Modal Header */
.modal-header {
display: flex;
flex-direction: column;
justify-content: flex-start;
align-items: flex-start;
margin-bottom: var(--space-3);
padding-bottom: var(--space-2);
border-bottom: 1px solid var(--lora-border);
}
/* Info Grid */
.info-grid {
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: var(--space-2);
margin-bottom: var(--space-3);
}
.info-item {
padding: var(--space-2);
background: rgba(0, 0, 0, 0.03);
border: 1px solid rgba(0, 0, 0, 0.1);
border-radius: var(--border-radius-sm);
}
/* 调整深色主题下的样式 */
[data-theme="dark"] .info-item {
background: rgba(255, 255, 255, 0.03);
border: 1px solid var(--lora-border);
}
.info-item.full-width {
grid-column: 1 / -1;
}
.info-item label {
display: block;
font-size: 0.85em;
color: var(--text-color);
opacity: 0.8;
margin-bottom: 4px;
}
.info-item span {
color: var(--text-color);
word-break: break-word;
}
.info-item.usage-tips,
.info-item.notes {
grid-column: 1 / -1 !important; /* Make notes section full width */
}
/* Add specific styles for notes content */
.info-item.notes .editable-field [contenteditable] {
min-height: 60px; /* Increase height for multiple lines */
max-height: 150px; /* Limit maximum height */
overflow-y: auto; /* Add scrolling for long content */
white-space: pre-wrap; /* Preserve line breaks */
line-height: 1.5; /* Improve readability */
padding: 8px 12px; /* Slightly increase padding */
}
.file-path {
font-family: monospace;
font-size: 0.9em;
}
.description-text {
line-height: 1.5;
max-height: 100px;
overflow-y: auto;
}
/* Editable Fields */
.editable-field {
position: relative;
display: flex;
gap: 8px;
align-items: flex-start;
}
.editable-field [contenteditable] {
flex: 1;
min-height: 24px;
padding: 4px 8px;
background: var(--bg-color);
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
font-size: 0.9em;
line-height: 1.4;
color: var(--text-color);
transition: border-color 0.2s;
word-break: break-word;
}
.editable-field [contenteditable]:focus {
outline: none;
border-color: var(--lora-accent);
background: var(--bg-color);
}
.editable-field [contenteditable]:empty::before {
content: attr(data-placeholder);
color: var(--text-color);
opacity: 0.5;
}
.notes-hint {
font-size: 0.8em;
color: var(--text-color);
opacity: 0.7;
margin-left: 5px;
cursor: help;
position: relative; /* Add positioning context */
}
@media (max-width: 640px) {
.info-item.usage-tips,
.info-item.notes {
grid-column: 1 / -1;
}
}
/* 修改 back-to-top 按钮样式,使其固定在 modal 内部 */
.modal-content .back-to-top {
position: sticky; /* 改用 sticky 定位 */
float: right; /* 使用 float 确保按钮在右侧 */
bottom: 20px; /* 距离底部的距离 */
margin-right: 20px; /* 右侧间距 */
margin-top: -56px; /* 负边距确保不占用额外空间 */
width: 36px;
height: 36px;
border-radius: 50%;
background: var(--card-bg);
border: 1px solid var(--border-color);
color: var(--text-color);
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
opacity: 0;
visibility: hidden;
transform: translateY(10px);
transition: all 0.3s ease;
z-index: 10;
}
.modal-content .back-to-top.visible {
opacity: 1;
visibility: visible;
transform: translateY(0);
}
.modal-content .back-to-top:hover {
background: var(--lora-accent);
color: white;
transform: translateY(-2px);
}
/* File name copy styles */
.file-name-wrapper {
display: flex;
align-items: center;
gap: 8px;
padding: 4px;
border-radius: var(--border-radius-xs);
transition: background-color 0.2s;
position: relative;
}
.file-name-content {
padding: 2px 4px;
border-radius: var(--border-radius-xs);
border: 1px solid transparent;
flex: 1;
}
.file-name-wrapper.editing .file-name-content {
border: 1px solid var(--lora-accent);
background: var(--bg-color);
outline: none;
}
/* 合并编辑按钮样式 */
.edit-model-name-btn,
.edit-file-name-btn,
.edit-base-model-btn,
.edit-model-description-btn {
background: transparent;
border: none;
color: var(--text-color);
opacity: 0;
cursor: pointer;
padding: 2px 5px;
border-radius: var(--border-radius-xs);
transition: all 0.2s ease;
margin-left: var(--space-1);
}
.edit-model-name-btn.visible,
.edit-file-name-btn.visible,
.edit-base-model-btn.visible,
.edit-model-description-btn.visible,
.model-name-header:hover .edit-model-name-btn,
.file-name-wrapper:hover .edit-file-name-btn,
.base-model-display:hover .edit-base-model-btn,
.model-name-header:hover .edit-model-description-btn {
opacity: 0.5;
}
.edit-model-name-btn:hover,
.edit-file-name-btn:hover,
.edit-base-model-btn:hover,
.edit-model-description-btn:hover {
opacity: 0.8 !important;
background: rgba(0, 0, 0, 0.05);
}
[data-theme="dark"] .edit-model-name-btn:hover,
[data-theme="dark"] .edit-file-name-btn:hover,
[data-theme="dark"] .edit-base-model-btn:hover {
background: rgba(255, 255, 255, 0.05);
}
/* Base Model and Size combined styles */
.info-item.base-size {
display: flex;
gap: var(--space-3);
}
.base-wrapper {
flex: 2; /* 分配更多空间给base model */
}
/* Base model display and editing styles */
.base-model-display {
display: flex;
align-items: center;
position: relative;
}
.base-model-content {
padding: 2px 4px;
border-radius: var(--border-radius-xs);
border: 1px solid transparent;
color: var(--text-color);
flex: 1;
}
.base-model-selector {
width: 100%;
padding: 3px 5px;
background: var(--bg-color);
border: 1px solid var(--lora-accent);
border-radius: var(--border-radius-xs);
color: var(--text-color);
font-size: 0.9em;
outline: none;
margin-right: var(--space-1);
}
.size-wrapper {
flex: 1;
border-left: 1px solid var(--lora-border);
padding-left: var(--space-3);
}
.base-wrapper label,
.size-wrapper label {
display: block;
margin-bottom: 4px;
}
.size-wrapper span {
font-family: monospace;
font-size: 0.9em;
opacity: 0.9;
}
/* New Model Name Header Styles */
.model-name-header {
display: flex;
align-items: center;
width: calc(100% - 40px); /* Avoid overlap with close button */
position: relative;
}
.model-name-content {
margin: 0;
padding: var(--space-1);
border-radius: var(--border-radius-xs);
font-size: 1.5em !important;
font-weight: 600;
line-height: 1.2;
color: var(--text-color);
border: 1px solid transparent;
outline: none;
flex: 1;
}
.model-name-content:focus {
border: 1px solid var(--lora-accent);
background: var(--bg-color);
}
/* Tab System Styling */
.showcase-tabs {
display: flex;
border-bottom: 1px solid var(--lora-border);
margin-bottom: var(--space-2);
position: relative;
z-index: 2;
}
.tab-btn {
padding: var(--space-1) var(--space-2);
background: transparent;
border: none;
border-bottom: 2px solid transparent;
color: var(--text-color);
cursor: pointer;
font-size: 0.95em;
transition: all 0.2s;
opacity: 0.7;
position: relative;
}
.tab-btn:hover {
opacity: 1;
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.05);
}
.tab-btn.active {
border-bottom: 2px solid var(--lora-accent);
opacity: 1;
font-weight: 600;
}
.tab-content {
position: relative;
min-height: 100px;
}
.tab-pane {
display: none;
}
.tab-pane.active {
display: block;
}
.view-all-btn {
display: flex;
align-items: center;
gap: 5px;
padding: 6px 12px;
background-color: var(--lora-accent);
color: var(--lora-text);
border: none;
border-radius: var(--border-radius-sm);
cursor: pointer;
transition: background-color 0.2s;
font-size: 13px;
}
.view-all-btn:hover {
opacity: 0.9;
}
/* Loading, error and empty states */
.recipes-loading,
.recipes-error,
.recipes-empty {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
padding: 40px;
text-align: center;
min-height: 200px;
}
.recipes-loading i,
.recipes-error i,
.recipes-empty i {
font-size: 32px;
margin-bottom: 15px;
color: var(--lora-accent);
}
.recipes-error i {
color: var(--lora-error);
}
/* Creator Information Styles */
.creator-info {
display: flex;
align-items: center;
gap: 10px;
padding: 2px 10px;
background: rgba(0, 0, 0, 0.03);
border: 1px solid rgba(0, 0, 0, 0.1);
border-radius: var(--border-radius-sm);
max-width: fit-content;
cursor: pointer;
transition: all 0.2s;
}
[data-theme="dark"] .creator-info,
[data-theme="dark"] .civitai-view {
background: rgba(255, 255, 255, 0.03);
border: 1px solid var(--lora-border);
}
.creator-avatar {
width: 26px;
height: 26px;
border-radius: 50%;
overflow: hidden;
flex-shrink: 0;
display: flex;
align-items: center;
justify-content: center;
background: var(--lora-surface);
border: 1px solid var(--lora-border);
}
.creator-avatar img {
width: 100%;
height: 100%;
object-fit: cover;
}
.creator-placeholder {
background: var(--lora-accent);
color: white;
display: flex;
align-items: center;
justify-content: center;
}
.creator-username {
font-size: 0.9em;
font-weight: 500;
color: var(--text-color);
}
/* Add hover effect for creator info */
.creator-info:hover,
.civitai-view:hover {
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1);
border-color: var(--lora-accent);
transform: translateY(-1px);
}
.creator-actions {
display: flex;
align-items: center;
gap: 10px;
margin-bottom: var(--space-1);
flex-wrap: wrap;
}
.civitai-view {
display: flex;
align-items: center;
gap: 6px;
padding: 6px 12px;
background: rgba(0, 0, 0, 0.03);
border: 1px solid rgba(0, 0, 0, 0.1);
border-radius: var(--border-radius-sm);
color: var(--text-color);
cursor: pointer;
font-weight: 500;
font-size: 0.9em;
transition: all 0.2s;
}
.civitai-view i {
font-size: 20px;
display: flex;
align-items: center;
justify-content: center;
}

View File

@@ -0,0 +1,68 @@
/* Update Preset Controls styles */
.preset-controls {
display: flex;
gap: var(--space-2);
margin-bottom: var(--space-2);
}
.preset-controls select,
.preset-controls input {
padding: var(--space-1);
background: var(--bg-color);
border: 1px solid var(--lora-border);
border-radius: var(--border-radius-xs);
color: var(--text-color);
}
.preset-tags {
display: flex;
flex-wrap: wrap;
gap: var(--space-1);
}
.preset-tag {
display: flex;
align-items: center;
background: var(--lora-surface);
border: 1px solid var(--lora-border);
border-radius: var(--border-radius-xs);
padding: calc(var(--space-1) * 0.5) var(--space-1);
gap: var(--space-1);
transition: all 0.2s ease;
}
.preset-tag span {
color: var(--lora-accent);
font-size: 0.9em;
}
.preset-tag i {
color: var(--text-color);
opacity: 0.5;
cursor: pointer;
transition: all 0.2s ease;
}
.preset-tag:hover {
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1);
border-color: var(--lora-accent);
}
.preset-tag i:hover {
color: var(--lora-error);
opacity: 1;
}
.add-preset-btn {
padding: calc(var(--space-1) * 0.5) var(--space-2);
background: var(--lora-accent);
color: var(--lora-text);
border: none;
border-radius: var(--border-radius-xs);
cursor: pointer;
transition: opacity 0.2s;
}
.add-preset-btn:hover {
opacity: 0.9;
}

View File

@@ -0,0 +1,478 @@
/* Showcase Section */
.showcase-section {
position: relative;
margin-top: var(--space-4);
}
.carousel {
transition: max-height 0.3s ease-in-out;
overflow: hidden;
}
.carousel.collapsed {
max-height: 0;
}
.carousel-container {
display: flex;
flex-direction: column;
gap: var(--space-2);
}
.media-wrapper {
position: relative;
width: 100%;
background: var(--lora-surface);
margin-bottom: var(--space-2);
overflow: hidden; /* Ensure metadata panel is contained */
}
.media-wrapper:last-child {
margin-bottom: 0;
}
.media-wrapper img,
.media-wrapper video {
position: absolute;
top: 0;
left: 0;
width: 100%;
height: 100%;
object-fit: contain;
}
.no-examples {
text-align: center;
padding: var(--space-3);
color: var(--text-color);
opacity: 0.7;
}
/* Adjust the media wrapper for tab system */
#showcase-tab .carousel-container {
margin-top: var(--space-2);
}
/* Add styles for blurred showcase content */
.nsfw-media-wrapper {
position: relative;
}
.media-wrapper img.blurred,
.media-wrapper video.blurred {
filter: blur(25px);
}
.media-wrapper .nsfw-overlay {
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
display: flex;
align-items: center;
justify-content: center;
z-index: 2;
pointer-events: none;
}
/* Position the toggle button at the top left of showcase media */
.showcase-toggle-btn {
position: absolute;
z-index: 3;
}
/* Add styles for showcase media controls */
.media-controls {
position: absolute;
display: flex;
gap: 6px;
z-index: 4;
opacity: 0;
transform: translateY(-5px);
transition: opacity 0.2s ease, transform 0.2s ease;
pointer-events: none;
}
.media-controls.visible {
opacity: 1;
transform: translateY(0);
pointer-events: auto;
}
.media-control-btn {
width: 28px;
height: 28px;
border-radius: 50%;
background: var(--bg-color);
border: 1px solid var(--border-color);
color: var(--text-color);
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
transition: all 0.2s ease;
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.15);
padding: 0;
position: relative;
overflow: hidden;
}
.media-control-btn:hover {
transform: translateY(-2px);
box-shadow: 0 3px 7px rgba(0, 0, 0, 0.2);
}
.media-control-btn.set-preview-btn:hover {
background: var(--lora-accent);
color: white;
border-color: var(--lora-accent);
}
.media-control-btn.example-delete-btn:hover:not(.disabled) {
background: var(--lora-error);
color: white;
border-color: var(--lora-error);
}
/* Disabled state for delete button */
.media-control-btn.example-delete-btn.disabled {
opacity: 0.5;
cursor: not-allowed;
}
/* Two-step confirmation for delete button */
.media-control-btn.example-delete-btn .confirm-icon {
position: absolute;
top: 0;
left: 0;
right: 0;
bottom: 0;
display: flex;
align-items: center;
justify-content: center;
background: var(--lora-error);
color: white;
font-size: 1em;
opacity: 0;
transition: opacity 0.2s ease;
}
.media-control-btn.example-delete-btn.confirm .fa-trash-alt {
opacity: 0;
}
.media-control-btn.example-delete-btn.confirm .confirm-icon {
opacity: 1;
}
.media-control-btn.example-delete-btn.confirm {
background: var(--lora-error);
color: white;
border-color: var(--lora-error);
}
@keyframes pulse {
0% {
box-shadow: 0 0 0 0 rgba(220, 53, 69, 0.7);
}
70% {
box-shadow: 0 0 0 5px rgba(220, 53, 69, 0);
}
100% {
box-shadow: 0 0 0 0 rgba(220, 53, 69, 0);
}
}
/* Image Metadata Panel Styles */
.image-metadata-panel {
position: absolute;
bottom: 0;
left: 0;
right: 0;
background: var(--bg-color);
border-top: 1px solid var(--border-color);
padding: var(--space-2);
transform: translateY(100%);
transition: transform 0.3s cubic-bezier(0.175, 0.885, 0.32, 1.275), opacity 0.25s ease;
z-index: 5;
max-height: 50%; /* Reduced to take less space */
overflow-y: auto;
box-shadow: 0 -2px 8px rgba(0, 0, 0, 0.1);
opacity: 0;
pointer-events: none;
}
/* Show metadata panel only when the 'visible' class is added */
.media-wrapper .image-metadata-panel.visible {
transform: translateY(0);
opacity: 0.98;
pointer-events: auto;
}
/* Adjust to dark theme */
[data-theme="dark"] .image-metadata-panel {
background: var(--card-bg);
box-shadow: 0 -2px 8px rgba(0, 0, 0, 0.3);
}
.metadata-content {
display: flex;
flex-direction: column;
gap: 10px;
}
/* Styling for parameters tags */
.params-tags {
display: flex;
flex-wrap: wrap;
gap: 6px;
margin-bottom: var(--space-1);
padding-bottom: var(--space-1);
border-bottom: 1px solid var(--lora-border);
}
.param-tag {
display: inline-flex;
align-items: center;
background: var(--lora-surface);
border: 1px solid var(--lora-border);
border-radius: var(--border-radius-xs);
padding: 2px 6px;
font-size: 0.8em;
line-height: 1.2;
white-space: nowrap;
}
.param-tag .param-name {
font-weight: 600;
color: var(--text-color);
margin-right: 4px;
opacity: 0.8;
}
.param-tag .param-value {
color: var(--lora-accent);
}
/* Special styling for prompt row */
.metadata-row.prompt-row {
flex-direction: column;
padding-top: 0;
}
.metadata-row.prompt-row + .metadata-row.prompt-row {
margin-top: var(--space-2);
}
.metadata-label {
font-weight: 600;
color: var(--text-color);
opacity: 0.8;
font-size: 0.85em;
display: block;
margin-bottom: 4px;
}
.metadata-prompt-wrapper {
position: relative;
background: var(--lora-surface);
border: 1px solid var(--lora-border);
border-radius: var(--border-radius-xs);
padding: 6px 30px 6px 8px;
margin-top: 2px;
max-height: 80px; /* Reduced from 120px */
overflow-y: auto;
word-break: break-word;
width: 100%;
box-sizing: border-box;
}
.metadata-prompt {
color: var(--text-color);
font-family: monospace;
font-size: 0.85em;
white-space: pre-wrap;
}
.copy-prompt-btn {
position: absolute;
top: 6px;
right: 6px;
background: transparent;
border: none;
color: var(--text-color);
opacity: 0.6;
cursor: pointer;
padding: 3px;
transition: all 0.2s ease;
}
.copy-prompt-btn:hover {
opacity: 1;
color: var(--lora-accent);
}
/* Scrollbar styling for metadata panel */
.image-metadata-panel::-webkit-scrollbar {
width: 6px;
}
.image-metadata-panel::-webkit-scrollbar-track {
background: transparent;
}
.image-metadata-panel::-webkit-scrollbar-thumb {
background-color: var(--border-color);
border-radius: 3px;
}
/* For Firefox */
.image-metadata-panel {
scrollbar-width: thin;
scrollbar-color: var(--border-color) transparent;
}
/* No metadata message styling */
.no-metadata-message {
display: flex;
align-items: center;
justify-content: center;
padding: var(--space-2);
color: var(--text-color);
opacity: 0.7;
text-align: center;
font-style: italic;
gap: 8px;
}
.no-metadata-message i {
font-size: 1.1em;
color: var(--lora-accent);
opacity: 0.8;
}
/* Scroll Indicator */
.scroll-indicator {
cursor: pointer;
padding: var(--space-2);
background: var(--lora-surface);
border: 1px solid var(--lora-border);
border-radius: var(--border-radius-sm);
display: flex;
align-items: center;
justify-content: center;
gap: 8px;
margin-bottom: var(--space-2);
transition: background-color 0.2s, transform 0.2s;
}
.scroll-indicator:hover {
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1);
transform: translateY(-1px);
}
.scroll-indicator span {
font-size: 0.9em;
color: var(--text-color);
}
.lazy {
opacity: 0;
transition: opacity 0.3s;
}
.lazy[src] {
opacity: 1;
}
/* Example Import Area */
.example-import-area {
margin-top: var(--space-4);
padding: var(--space-2);
}
.example-import-area.empty {
margin-top: var(--space-2);
padding: var(--space-4) var(--space-2);
}
.import-container {
border: 2px dashed var(--border-color);
border-radius: var(--border-radius-sm);
padding: var(--space-4);
text-align: center;
transition: all 0.3s ease;
background: var(--lora-surface);
cursor: pointer;
}
.import-container.highlight {
border-color: var(--lora-accent);
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1);
transform: scale(1.01);
}
.import-placeholder {
display: flex;
flex-direction: column;
align-items: center;
gap: var(--space-1);
padding-top: var(--space-1);
}
.import-placeholder i {
font-size: 2.5rem;
/* color: var(--lora-accent); */
opacity: 0.8;
margin-bottom: var(--space-1);
}
.import-placeholder h3 {
margin: 0 0 var(--space-1);
font-size: 1.2rem;
font-weight: 500;
color: var(--text-color);
}
.import-placeholder p {
margin: var(--space-1) 0;
color: var(--text-color);
opacity: 0.8;
}
.import-placeholder .sub-text {
font-size: 0.9em;
opacity: 0.6;
margin: var(--space-1) 0;
}
.import-formats {
font-size: 0.8em !important;
opacity: 0.6 !important;
margin-top: var(--space-2) !important;
}
.select-files-btn {
background: var(--lora-accent);
color: var(--lora-text);
border: none;
border-radius: var(--border-radius-xs);
padding: var(--space-2) var(--space-3);
cursor: pointer;
font-size: 0.9em;
display: flex;
align-items: center;
gap: 8px;
transition: all 0.2s;
}
.select-files-btn:hover {
opacity: 0.9;
transform: translateY(-1px);
}
/* For dark theme */
[data-theme="dark"] .import-container {
background: rgba(255, 255, 255, 0.03);
}

View File

@@ -0,0 +1,148 @@
/* Model Tags styles */
.model-tags {
display: none;
}
.model-tag {
display: none;
}
/* Updated Model Tags styles - improved visibility in light theme */
.model-tags-container {
position: relative;
}
.model-tags-compact {
display: flex;
flex-wrap: nowrap;
gap: 6px;
align-items: center;
}
.model-tag-compact {
/* Updated styles to match info-item appearance */
background: rgba(0, 0, 0, 0.03);
border: 1px solid rgba(0, 0, 0, 0.1);
border-radius: var(--border-radius-xs);
padding: 2px 8px;
font-size: 0.75em;
color: var(--text-color);
white-space: nowrap;
}
/* Style for empty tags placeholder */
.model-tag-empty {
background: rgba(0, 0, 0, 0.02);
border: 1px dashed rgba(0, 0, 0, 0.1);
border-radius: var(--border-radius-xs);
padding: 2px 8px;
font-size: 0.75em;
color: var(--text-color);
white-space: nowrap;
opacity: 0.7;
font-style: italic;
}
/* Adjust dark theme tag styles */
[data-theme="dark"] .model-tag-compact {
background: rgba(255, 255, 255, 0.03);
border: 1px solid var(--lora-border);
}
/* Dark theme for empty tags */
[data-theme="dark"] .model-tag-empty {
background: rgba(255, 255, 255, 0.02);
border: 1px dashed var(--lora-border);
}
.model-tag-more {
background: var(--lora-accent);
color: var(--lora-text);
border-radius: var(--border-radius-xs);
padding: 2px 8px;
font-size: 0.75em;
cursor: pointer;
white-space: nowrap;
font-weight: 500;
}
.model-tags-tooltip {
position: absolute;
top: calc(100% + 8px);
left: 0;
background: var(--card-bg);
border: 1px solid var(--border-color);
border-radius: var(--border-radius-sm);
box-shadow: 0 3px 8px rgba(0, 0, 0, 0.15);
padding: 10px 14px;
max-width: 400px;
z-index: 10;
opacity: 0;
visibility: hidden;
transform: translateY(-4px);
transition: all 0.2s ease;
pointer-events: none;
}
.model-tags-tooltip.visible {
opacity: 1;
visibility: visible;
transform: translateY(0);
pointer-events: auto;
}
.tooltip-content {
display: flex;
flex-wrap: wrap;
gap: 6px;
max-height: 200px;
overflow-y: auto;
}
.tooltip-tag {
/* Updated styles to match info-item appearance */
background: rgba(0, 0, 0, 0.03);
border: 1px solid rgba(0, 0, 0, 0.1);
border-radius: var(--border-radius-xs);
padding: 3px 8px;
font-size: 0.75em;
color: var(--text-color);
}
/* Adjust dark theme tooltip tag styles */
[data-theme="dark"] .tooltip-tag {
background: rgba(255, 255, 255, 0.03);
border: 1px solid var(--lora-border);
}
/* Model Tags Edit Mode */
.model-tags-header {
display: flex;
justify-content: space-between;
align-items: center;
}
.edit-tags-btn {
background: transparent;
border: none;
color: var(--text-color);
opacity: 0;
cursor: pointer;
padding: 2px 5px;
border-radius: var(--border-radius-xs);
transition: all 0.2s ease;
margin-left: var(--space-1);
}
.edit-tags-btn.visible,
.model-tags-container:hover .edit-tags-btn {
opacity: 0.5;
}
/* Edit mode active state */
.model-tags-container.edit-mode {
width: 100%;
display: block;
flex-basis: 100%;
grid-column: 1 / -1;
}

View File

@@ -0,0 +1,112 @@
/* Update Trigger Words styles */
.info-item.trigger-words {
padding: var(--space-2);
background: rgba(0, 0, 0, 0.03);
border: 1px solid rgba(0, 0, 0, 0.1);
border-radius: var(--border-radius-sm);
}
/* 调整 trigger words 样式 */
[data-theme="dark"] .info-item.trigger-words {
background: rgba(255, 255, 255, 0.03);
border: 1px solid var(--lora-border);
}
/* New header style for trigger words */
.trigger-words-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 6px;
}
.trigger-words-content {
margin-bottom: var(--space-1);
}
.trigger-words-tags {
display: flex;
flex-wrap: wrap;
gap: 8px;
align-items: flex-start;
}
/* No trigger words message */
.no-trigger-words {
color: var(--text-color);
opacity: 0.7;
font-style: italic;
font-size: 0.9em;
}
/* Trigger word tags in display mode */
.trigger-word-tag {
display: inline-flex;
align-items: center;
background: var(--bg-color);
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
padding: 4px 8px;
cursor: pointer;
transition: all 0.2s ease;
gap: 6px;
position: relative;
}
.trigger-word-content {
color: var(--lora-accent) !important;
font-size: 0.85em;
line-height: 1.4;
word-break: break-word;
}
.trigger-word-tag:hover {
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1);
border-color: var(--lora-accent);
}
.trigger-word-copy {
display: flex;
align-items: center;
color: var(--text-color);
opacity: 0.5;
flex-shrink: 0;
transition: opacity 0.2s;
}
.trained-word-freq {
color: var(--text-color);
font-size: 0.75em;
background: rgba(0, 0, 0, 0.05);
border-radius: 10px;
min-width: 20px;
padding: 1px 5px;
text-align: center;
line-height: 1.2;
}
[data-theme="dark"] .trained-word-freq {
background: rgba(255, 255, 255, 0.05);
}
/* Class tokens styling */
.class-tokens-container {
padding: 10px;
display: flex;
flex-wrap: wrap;
gap: 8px;
}
.class-token-item {
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1) !important;
border: 1px solid var(--lora-accent) !important;
}
.token-badge {
background: var(--lora-accent);
color: white;
font-size: 0.7em;
padding: 2px 5px;
border-radius: 8px;
white-space: nowrap;
}

View File

@@ -39,4 +39,182 @@
.context-menu-item i {
width: 16px;
text-align: center;
}
/* NSFW Level Selector */
.nsfw-level-selector {
position: fixed;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
background: var(--card-bg);
border: 1px solid var(--border-color);
border-radius: var(--border-radius-base);
padding: 16px;
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.2);
z-index: var(--z-modal);
width: 300px;
display: none;
}
.nsfw-level-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 16px;
}
.nsfw-level-header h3 {
margin: 0;
font-size: 16px;
font-weight: 500;
}
.close-nsfw-selector {
background: transparent;
border: none;
color: var(--text-color);
cursor: pointer;
padding: 4px;
border-radius: var(--border-radius-xs);
}
.close-nsfw-selector:hover {
background: var(--border-color);
}
.current-level {
margin-bottom: 12px;
padding: 8px;
background: var(--bg-color);
border-radius: var(--border-radius-xs);
border: 1px solid var(--border-color);
}
.nsfw-level-options {
display: flex;
flex-wrap: wrap;
gap: 8px;
}
.nsfw-level-btn {
flex: 1 0 calc(33% - 8px);
padding: 8px;
border-radius: var(--border-radius-xs);
background: var(--bg-color);
border: 1px solid var(--border-color);
color: var(--text-color);
cursor: pointer;
transition: all 0.2s ease;
}
.nsfw-level-btn:hover {
background: var(--lora-border);
}
.nsfw-level-btn.active {
background: var(--lora-accent);
color: white;
border-color: var(--lora-accent);
}
/* Node Selector */
.node-selector {
position: fixed;
background: var(--lora-surface);
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
padding: 4px 0;
min-width: 200px;
max-width: 350px;
max-height: 400px;
overflow-y: auto;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.2);
z-index: 1000;
display: none;
backdrop-filter: blur(10px);
}
.node-item {
padding: 10px 15px;
cursor: pointer;
display: flex;
align-items: center;
gap: 10px;
color: var(--text-color);
background: var(--lora-surface);
transition: background-color 0.2s;
border-bottom: 1px solid var(--border-color);
}
.node-item:last-child {
border-bottom: none;
}
.node-item:hover {
background-color: var(--lora-accent);
color: var(--lora-text);
}
.node-icon-indicator {
width: 24px;
height: 24px;
border-radius: 4px;
display: flex;
align-items: center;
justify-content: center;
flex-shrink: 0;
}
.node-icon-indicator i {
color: white;
font-size: 12px;
text-shadow: 0 1px 2px rgba(0, 0, 0, 0.3);
}
.node-icon-indicator.all-nodes {
background: linear-gradient(45deg, #4a90e2, #357abd);
}
/* Remove old node-color-indicator styles */
.node-color-indicator {
display: none;
}
.send-all-item {
border-top: 1px solid var(--border-color);
font-weight: 500;
background: var(--card-bg);
}
.send-all-item:hover {
background-color: var(--lora-accent);
color: var(--lora-text);
}
.send-all-item i {
width: 16px;
text-align: center;
}
/* Node Selector Header */
.node-selector-header {
padding: 10px 15px;
border-bottom: 1px solid var(--border-color);
background: var(--card-bg);
display: flex;
flex-direction: column;
gap: 4px;
}
.selector-action-type {
font-weight: 600;
font-size: 14px;
color: var(--lora-accent);
}
.selector-instruction {
font-size: 12px;
color: var(--text-muted);
font-style: italic;
}

View File

@@ -0,0 +1,274 @@
/* modal 基础样式 */
.modal {
display: none;
position: fixed;
top: 48px; /* Start below the header */
left: 0;
width: 100%;
height: calc(100% - 48px); /* Adjust height to exclude header */
background: rgba(0, 0, 0, 0.2); /* 调整为更淡的半透明黑色 */
z-index: var(--z-modal);
overflow: auto; /* Change from hidden to auto to allow scrolling */
}
/* 当模态窗口打开时禁止body滚动 */
body.modal-open {
position: fixed;
width: 100%;
padding-right: var(--scrollbar-width, 0px); /* 补偿滚动条消失导致的页面偏移 */
}
/* modal-content 样式 */
.modal-content {
position: relative;
max-width: 800px;
height: auto;
/* max-height: calc(90vh - 48px); */
margin: 1rem auto; /* Keep reduced top margin */
background: var(--lora-surface);
border-radius: var(--border-radius-base);
padding: var(--space-3);
border: 1px solid var(--lora-border);
box-shadow:
0 4px 6px -1px rgba(0, 0, 0, 0.1),
0 2px 4px -1px rgba(0, 0, 0, 0.06),
0 10px 15px -3px rgba(0, 0, 0, 0.05);
overflow-y: auto;
overflow-x: hidden; /* 防止水平滚动条 */
}
/* 当 modal 打开时锁定 body */
body.modal-open {
overflow: hidden !important; /* 覆盖 base.css 中的 scroll */
padding-right: var(--scrollbar-width, 8px); /* 使用滚动条宽度作为补偿 */
}
@keyframes modalFadeIn {
from {
opacity: 0;
transform: translateY(-20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.modal-actions {
display: flex;
gap: var(--space-2);
justify-content: center;
margin-top: var(--space-3);
}
.cancel-btn, .delete-btn, .exclude-btn, .confirm-btn {
padding: 8px var(--space-2);
border-radius: 6px;
border: none;
cursor: pointer;
font-weight: 500;
min-width: 100px;
}
.cancel-btn {
background: var(--lora-surface);
border: 1px solid var(--lora-border);
color: var(--text-color);
}
.delete-btn {
background: var(--lora-error);
color: white;
}
/* Style for exclude button - different from delete button */
.exclude-btn, .confirm-btn {
background: var(--lora-accent, #4f46e5);
color: white;
}
.cancel-btn:hover {
background: var(--lora-border);
}
.delete-btn:hover {
opacity: 0.9;
}
.exclude-btn:hover, .confirm-btn:hover {
opacity: 0.9;
background: oklch(from var(--lora-accent, #4f46e5) l c h / 85%);
}
.modal-content h2 {
color: var(--text-color);
margin-bottom: var(--space-1);
font-size: 1.5em;
}
.close {
position: absolute;
top: var(--space-2);
right: var(--space-2);
background: transparent;
border: none;
color: var(--text-color);
font-size: 1.5em;
cursor: pointer;
opacity: 0.7;
transition: opacity 0.2s;
}
.close:hover {
opacity: 1;
}
/* 统一各个 section 的样式 */
.support-section,
.changelog-section,
.update-info,
.info-item,
.path-preview {
background: rgba(0, 0, 0, 0.03);
border: 1px solid rgba(0, 0, 0, 0.1);
border-radius: var(--border-radius-sm);
padding: var(--space-2);
}
/* 深色主题统一样式 */
[data-theme="dark"] .modal-content {
background: var(--lora-surface);
border: 1px solid var(--lora-border);
}
[data-theme="dark"] .support-section,
[data-theme="dark"] .changelog-section,
[data-theme="dark"] .update-info,
[data-theme="dark"] .info-item,
[data-theme="dark"] .path-preview {
background: rgba(255, 255, 255, 0.03);
border: 1px solid var(--lora-border);
}
.primary-btn {
display: flex;
align-items: center;
gap: 8px;
padding: 8px 16px;
background-color: var(--lora-accent);
color: var(--lora-text);
border: none;
border-radius: var(--border-radius-sm);
cursor: pointer;
transition: background-color 0.2s;
font-size: 0.95em;
}
.primary-btn:hover {
background-color: oklch(from var(--lora-accent) l c h / 85%);
color: var(--lora-text);
}
/* Secondary button styles */
.secondary-btn {
display: flex;
align-items: center;
gap: 8px;
padding: 8px 16px;
background-color: var(--card-bg);
color: var (--text-color);
border: 1px solid var(--border-color);
border-radius: var(--border-radius-sm);
cursor: pointer;
transition: all 0.2s;
font-size: 0.95em;
}
.secondary-btn:hover {
background-color: var(--border-color);
color: var(--text-color);
}
/* Disabled button styles */
.primary-btn.disabled {
opacity: 0.5;
cursor: not-allowed;
background-color: var(--lora-accent);
color: var(--lora-text);
pointer-events: none;
}
.secondary-btn.disabled {
opacity: 0.5;
cursor: not-allowed;
pointer-events: none;
}
.restart-required-icon {
color: var(--lora-warning);
margin-left: 5px;
font-size: 0.85em;
vertical-align: text-bottom;
}
/* Dark theme specific button adjustments */
[data-theme="dark"] .primary-btn:hover {
background-color: oklch(from var(--lora-accent) l c h / 75%);
}
[data-theme="dark"] .secondary-btn {
background-color: var(--lora-surface);
}
[data-theme="dark"] .secondary-btn:hover {
background-color: oklch(35% 0.02 256 / 0.98);
}
.primary-btn.disabled {
opacity: 0.5;
cursor: not-allowed;
}
.primary-btn.disabled {
opacity: 0.5;
cursor: not-allowed;
}
/* Add styles for delete preview image */
.delete-preview {
max-width: 150px;
margin: 0 auto var(--space-2);
overflow: hidden;
}
.delete-preview img {
width: 100%;
height: auto;
max-height: 150px;
object-fit: contain;
border-radius: var(--border-radius-sm);
}
.delete-info {
text-align: center;
}
.delete-info h3 {
margin-bottom: var(--space-1);
word-break: break-word;
}
.delete-info p {
margin: var(--space-1) 0;
font-size: 0.9em;
opacity: 0.8;
}
.delete-note {
font-size: 0.85em;
color: var(--text-color);
opacity: 0.7;
font-style: italic;
margin-top: var(--space-1);
text-align: center;
}

View File

@@ -0,0 +1,48 @@
/* Delete Modal specific styles */
.delete-message {
color: var(--text-color);
margin: var(--space-2) 0;
}
/* Update delete modal styles */
.delete-modal {
display: none; /* Set initial display to none */
position: fixed;
top: 0;
left: 0;
width: 100%;
height: 100%;
background: rgba(0, 0, 0, 0.8);
z-index: var(--z-overlay);
}
/* Add new style for when modal is shown */
.delete-modal.show {
display: flex;
align-items: center;
justify-content: center;
}
.delete-modal-content {
max-width: 500px;
width: 90%;
text-align: center;
margin: 0 auto;
position: relative;
animation: modalFadeIn 0.2s ease-out;
}
.delete-model-info,
.exclude-model-info {
/* Update info display styling */
background: var(--lora-surface);
border: 1px solid var(--lora-border);
border-radius: var(--border-radius-sm);
padding: var(--space-2);
margin: var(--space-2) 0;
color: var(--text-color);
word-break: break-all;
text-align: left;
line-height: 1.5;
}

View File

@@ -0,0 +1,505 @@
/* Download Modal Styles */
.input-group {
margin-bottom: var(--space-2);
}
.input-group label {
display: block;
margin-bottom: 8px;
color: var(--text-color);
}
.input-group input,
.input-group select {
width: 100%;
padding: 8px;
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
background: var(--bg-color);
color: var(--text-color);
}
/* Version List Styles */
.version-list {
max-height: 400px;
overflow-y: auto;
margin: var(--space-2) 0;
display: flex;
flex-direction: column;
gap: 12px;
padding: 1px;
}
.version-item {
display: flex;
gap: var(--space-2);
padding: var(--space-2);
border: 1px solid var(--border-color);
border-radius: var(--border-radius-sm);
cursor: pointer;
transition: all 0.2s ease;
background: var(--bg-color);
margin: 1px;
position: relative;
}
.version-item:hover {
border-color: var(--lora-accent);
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
z-index: 1;
}
.version-item.selected {
border: 2px solid var(--lora-accent);
background: oklch(var(--lora-accent) / 0.05);
}
.version-thumbnail {
width: 80px;
height: 80px;
flex-shrink: 0;
border-radius: var(--border-radius-xs);
overflow: hidden;
background: var(--bg-color);
}
.version-thumbnail img {
width: 100%;
height: 100%;
object-fit: cover;
}
.version-content {
display: flex;
flex-direction: column;
gap: 8px;
flex: 1;
min-width: 0;
}
.version-header {
display: flex;
align-items: flex-start;
justify-content: space-between;
gap: var(--space-2);
}
.version-content h3 {
margin: 0;
font-size: 1.1em;
color: var(--text-color);
flex: 1;
}
.version-content .version-info {
display: flex;
flex-wrap: wrap;
flex-direction: row !important;
gap: 8px;
align-items: center;
font-size: 0.9em;
}
.version-content .version-info .base-model {
background: oklch(var(--lora-accent) / 0.1);
color: var(--lora-accent);
padding: 2px 8px;
border-radius: var(--border-radius-xs);
}
.version-meta {
display: flex;
gap: 12px;
font-size: 0.85em;
color: var(--text-color);
opacity: 0.7;
}
.version-meta span {
display: flex;
align-items: center;
gap: 4px;
}
.folder-item {
padding: 8px;
cursor: pointer;
border-radius: var(--border-radius-xs);
transition: background-color 0.2s;
}
.folder-item:hover {
background: var(--lora-surface);
}
.folder-item.selected {
background: oklch(var(--lora-accent) / 0.1);
border: 1px solid var(--lora-accent);
}
/* Path Input Styles */
.path-input-container {
position: relative;
display: flex;
gap: 8px;
align-items: center;
}
.path-input-container input {
flex: 1;
}
.create-folder-btn {
padding: 8px;
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
background: var(--bg-color);
color: var(--text-color);
cursor: pointer;
transition: all 0.2s ease;
display: flex;
align-items: center;
justify-content: center;
width: 36px;
height: 36px;
}
.create-folder-btn:hover {
border-color: var(--lora-accent);
background: oklch(var(--lora-accent) / 0.05);
}
.path-suggestions {
position: absolute;
top: 46%;
left: 0;
right: 0;
z-index: 1000;
margin: 0 24px;
background: var(--bg-color);
border: 1px solid var(--border-color);
border-top: none;
border-radius: 0 0 var(--border-radius-xs) var(--border-radius-xs);
max-height: 200px;
overflow-y: auto;
}
.path-suggestion {
padding: 8px 12px;
cursor: pointer;
transition: background-color 0.2s;
border-bottom: 1px solid var(--border-color);
}
.path-suggestion:last-child {
border-bottom: none;
}
.path-suggestion:hover {
background: var(--lora-surface);
}
.path-suggestion.active {
background: oklch(var(--lora-accent) / 0.1);
color: var(--lora-accent);
}
/* Breadcrumb Navigation Styles */
.breadcrumb-nav {
display: flex;
align-items: center;
gap: 4px;
margin-bottom: var(--space-2);
padding: var(--space-1);
background: var(--lora-surface);
border-radius: var(--border-radius-xs);
border: 1px solid var(--border-color);
overflow-x: auto;
white-space: nowrap;
}
.breadcrumb-item {
display: flex;
align-items: center;
gap: 4px;
padding: 4px 8px;
border-radius: var(--border-radius-xs);
cursor: pointer;
transition: all 0.2s ease;
color: var(--text-color);
opacity: 0.7;
text-decoration: none;
}
.breadcrumb-item:hover {
background: var(--bg-color);
opacity: 1;
}
.breadcrumb-item.active {
background: oklch(var(--lora-accent) / 0.1);
color: var(--lora-accent);
opacity: 1;
}
.breadcrumb-separator {
color: var(--text-color);
opacity: 0.5;
margin: 0 4px;
}
/* Folder Tree Styles */
.folder-tree-container {
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
background: var(--bg-color);
max-height: 300px;
overflow-y: auto;
}
.folder-tree {
padding: var(--space-1);
}
.tree-node {
user-select: none;
}
.tree-node-content {
display: flex;
align-items: center;
gap: 4px;
padding: 4px 8px;
cursor: pointer;
border-radius: var(--border-radius-xs);
transition: all 0.2s ease;
position: relative;
}
.tree-node-content:hover {
background: var(--lora-surface);
}
.tree-node-content.selected {
background: oklch(var(--lora-accent) / 0.1);
border: 1px solid var(--lora-accent);
}
.tree-expand-icon {
width: 16px;
height: 16px;
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
border-radius: 2px;
transition: all 0.2s ease;
}
.tree-expand-icon:hover {
background: var(--lora-surface);
}
.tree-expand-icon.expanded {
transform: rotate(90deg);
}
.tree-folder-icon {
width: 16px;
height: 16px;
display: flex;
align-items: center;
justify-content: center;
color: var(--lora-accent);
}
.tree-folder-name {
flex: 1;
font-size: 0.9em;
color: var(--text-color);
}
.tree-children {
margin-left: 20px;
display: none;
}
.tree-children.expanded {
display: block;
}
.tree-node.has-children > .tree-node-content .tree-expand-icon {
opacity: 1;
}
.tree-node:not(.has-children) > .tree-node-content .tree-expand-icon {
opacity: 0;
pointer-events: none;
}
/* Create folder inline form */
.create-folder-form {
display: flex;
gap: 8px;
margin-left: 20px;
align-items: center;
height: 21px;
}
.create-folder-form input {
flex: 1;
padding: 4px 8px;
border: 1px solid var(--lora-accent);
border-radius: var(--border-radius-xs);
background: var(--bg-color);
color: var(--text-color);
font-size: 0.9em;
}
.create-folder-form button {
padding: 4px 8px;
border: 1px solid var(--border-color);
border-radius: var(--border-radius-xs);
background: var(--bg-color);
color: var(--text-color);
cursor: pointer;
font-size: 0.8em;
transition: all 0.2s ease;
}
.create-folder-form button.confirm {
background: var(--lora-accent);
color: white;
border-color: var(--lora-accent);
}
.create-folder-form button:hover {
background: var(--lora-surface);
}
/* Path Preview Styles */
.path-preview {
margin-bottom: var(--space-3);
padding: var(--space-2);
background: var(--bg-color);
border-radius: var(--border-radius-sm);
border: 1px dashed var(--border-color);
}
.path-preview-header {
display: flex;
align-items: center;
justify-content: space-between;
margin-bottom: 12px;
gap: var(--space-2);
}
.path-preview-header label {
margin: 0;
color: var(--text-color);
font-size: 0.9em;
opacity: 0.8;
}
.path-display {
padding: var(--space-1);
color: var(--text-color);
font-family: monospace;
font-size: 0.9em;
line-height: 1.4;
white-space: pre-wrap;
word-break: break-all;
opacity: 0.85;
background: var(--lora-surface);
border-radius: var(--border-radius-xs);
}
/* Inline Toggle Styles */
.inline-toggle-container {
display: flex;
align-items: center;
gap: 8px;
cursor: pointer;
user-select: none;
position: relative;
}
.inline-toggle-label {
font-size: 0.85em;
color: var(--text-color);
opacity: 0.9;
white-space: nowrap;
}
.inline-toggle-container .toggle-switch {
position: relative;
width: 36px;
height: 18px;
flex-shrink: 0;
}
.inline-toggle-container .toggle-switch input {
opacity: 0;
width: 0;
height: 0;
position: absolute;
}
.inline-toggle-container .toggle-slider {
position: absolute;
cursor: pointer;
top: 0;
left: 0;
right: 0;
bottom: 0;
background-color: var(--border-color);
transition: all 0.3s ease;
border-radius: 18px;
}
.inline-toggle-container .toggle-slider:before {
position: absolute;
content: "";
height: 12px;
width: 12px;
left: 3px;
bottom: 3px;
background-color: white;
transition: all 0.3s ease;
border-radius: 50%;
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2);
}
.inline-toggle-container .toggle-switch input:checked + .toggle-slider {
background-color: var(--lora-accent);
}
.inline-toggle-container .toggle-switch input:checked + .toggle-slider:before {
transform: translateX(18px);
}
/* Dark theme adjustments */
[data-theme="dark"] .version-item {
background: var(--lora-surface);
}
[data-theme="dark"] .local-path {
background: var(--lora-surface);
border-color: var(--lora-border);
}
[data-theme="dark"] .toggle-slider:before {
background-color: #f0f0f0;
}
/* Enhance the local badge to make it more noticeable */
.version-item.exists-locally {
background: oklch(var(--lora-accent) / 0.05);
border-left: 4px solid var(--lora-accent);
}
.manual-path-selection.disabled {
opacity: 0.5;
pointer-events: none;
user-select: none;
}

View File

@@ -0,0 +1,72 @@
/* Example Access Modal */
.example-access-modal {
max-width: 550px;
text-align: center;
}
.example-access-options {
display: flex;
flex-direction: column;
gap: var(--space-2);
margin: var(--space-3) 0;
}
.example-option-btn {
display: flex;
flex-direction: column;
align-items: center;
padding: var(--space-2);
border-radius: var(--border-radius-sm);
border: 1px solid var(--lora-border);
background-color: var(--lora-surface);
cursor: pointer;
transition: all 0.2s;
}
.example-option-btn:hover {
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
border-color: var(--lora-accent);
}
.example-option-btn i {
font-size: 2em;
margin-bottom: var(--space-1);
color: var(--lora-accent);
}
.option-title {
font-weight: 500;
margin-bottom: 4px;
font-size: 1.1em;
}
.option-desc {
font-size: 0.9em;
opacity: 0.8;
}
.example-option-btn.disabled {
opacity: 0.5;
cursor: not-allowed;
}
.example-option-btn.disabled i {
color: var(--text-color);
opacity: 0.5;
}
.modal-footer-note {
font-size: 0.9em;
opacity: 0.7;
margin-top: var(--space-2);
display: flex;
align-items: center;
justify-content: center;
gap: 8px;
}
/* Dark theme adjustments */
[data-theme="dark"] .example-option-btn:hover {
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.25);
}

Some files were not shown because too many files have changed in this diff Show More