mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Compare commits
358 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c23ab04d90 | ||
|
|
d50dde6cf6 | ||
|
|
fcb1fb39be | ||
|
|
b0ef74f802 | ||
|
|
f332aef41d | ||
|
|
1f91a3da8e | ||
|
|
16840c321d | ||
|
|
c109e392ad | ||
|
|
5e69671366 | ||
|
|
52d23d9b75 | ||
|
|
4c4e6d7a7b | ||
|
|
03b6e78705 | ||
|
|
24c01141d7 | ||
|
|
6dc2811af4 | ||
|
|
e6425dce32 | ||
|
|
95e2ff5f1e | ||
|
|
92ac487128 | ||
|
|
3250fa89cb | ||
|
|
7475de366b | ||
|
|
affb507b37 | ||
|
|
3320b80150 | ||
|
|
fb2b69b787 | ||
|
|
29a05f6533 | ||
|
|
9fa3fac973 | ||
|
|
904b0d104a | ||
|
|
1d31dae110 | ||
|
|
476ecb7423 | ||
|
|
4eb67cf6da | ||
|
|
a5a9f7ed83 | ||
|
|
c0b029e228 | ||
|
|
9bebcc9a4b | ||
|
|
ac7d23011c | ||
|
|
491e09b7b5 | ||
|
|
192bc237bf | ||
|
|
f041f4a114 | ||
|
|
2546580377 | ||
|
|
8fbf2ab56d | ||
|
|
ea727aad2e | ||
|
|
5520aecbba | ||
|
|
6b738a4769 | ||
|
|
903a8050b3 | ||
|
|
31b032429d | ||
|
|
2bcf341f04 | ||
|
|
ca6f45b359 | ||
|
|
2a67cec16b | ||
|
|
1800afe31b | ||
|
|
8c6311355d | ||
|
|
91801dff85 | ||
|
|
be594133f0 | ||
|
|
8a538d117e | ||
|
|
8d9118cbee | ||
|
|
b67464ea13 | ||
|
|
33334da0bb | ||
|
|
40ce2baa7b | ||
|
|
1134466cc0 | ||
|
|
92341111ad | ||
|
|
4956d6781f | ||
|
|
63562240c4 | ||
|
|
84d801cf14 | ||
|
|
b56fe4ca68 | ||
|
|
6c83c65e02 | ||
|
|
a83f020fcc | ||
|
|
7f9a3bf272 | ||
|
|
f80e266d02 | ||
|
|
7bef562541 | ||
|
|
b2428f607c | ||
|
|
8303196b57 | ||
|
|
987b8c8742 | ||
|
|
e60a579b85 | ||
|
|
be8edafed0 | ||
|
|
a258a18fa4 | ||
|
|
59010ca431 | ||
|
|
75f3764e6c | ||
|
|
867ffd1163 | ||
|
|
6acccbbb94 | ||
|
|
b2c4efab45 | ||
|
|
408a435b71 | ||
|
|
36d3cd93d5 | ||
|
|
b36fea002e | ||
|
|
52acbd954a | ||
|
|
f6709a55c3 | ||
|
|
7b374d747b | ||
|
|
fd480a9360 | ||
|
|
ec8b228867 | ||
|
|
401200050b | ||
|
|
29160bd6e5 | ||
|
|
3c9e402bc0 | ||
|
|
ff4d0f0208 | ||
|
|
f82908221c | ||
|
|
4246908f2e | ||
|
|
f64597afd2 | ||
|
|
975ff2672d | ||
|
|
e90ba31784 | ||
|
|
a4074c93bc | ||
|
|
7a8b7598c7 | ||
|
|
cd0d832f14 | ||
|
|
5b0becaaf2 | ||
|
|
9817bac2fe | ||
|
|
f6bd48cfcd | ||
|
|
01843b8f2b | ||
|
|
94ed81de5e | ||
|
|
0700b8f399 | ||
|
|
d62cff9841 | ||
|
|
083f4805b2 | ||
|
|
8e5bfd379e | ||
|
|
2366f143d8 | ||
|
|
e997f5bc1b | ||
|
|
842beec7cc | ||
|
|
d2268fc9e0 | ||
|
|
a98e26139f | ||
|
|
522a3ea88b | ||
|
|
d7949fbc30 | ||
|
|
6df083a1d5 | ||
|
|
4dc80e7f6e | ||
|
|
c2a8508513 | ||
|
|
159193ef43 | ||
|
|
1f37ffb105 | ||
|
|
919fed05c5 | ||
|
|
1814f83bee | ||
|
|
1823840456 | ||
|
|
623c28bfc3 | ||
|
|
3079131337 | ||
|
|
a34ade0120 | ||
|
|
e9ada70088 | ||
|
|
597cc48248 | ||
|
|
ec3f857ef1 | ||
|
|
383b4de539 | ||
|
|
1bf9326604 | ||
|
|
d9f5459d46 | ||
|
|
e45a1b1e19 | ||
|
|
331ad8f644 | ||
|
|
52fa88b04c | ||
|
|
8895a64d24 | ||
|
|
fdec535559 | ||
|
|
6c5559ae2d | ||
|
|
9f54622b17 | ||
|
|
03b6f4b378 | ||
|
|
af4cbe2332 | ||
|
|
141f72963a | ||
|
|
3d3c66e12f | ||
|
|
ee84571bdb | ||
|
|
6500936aad | ||
|
|
32d2b6c013 | ||
|
|
05df40977d | ||
|
|
5d7a1dcde5 | ||
|
|
9c45d9db6c | ||
|
|
ca692ed0f2 | ||
|
|
af499565d3 | ||
|
|
fe2d7e3a9e | ||
|
|
9f69822221 | ||
|
|
bb43f047c2 | ||
|
|
2356662492 | ||
|
|
1624a45093 | ||
|
|
dcb9983786 | ||
|
|
83d1828905 | ||
|
|
6a281cf3ee | ||
|
|
ed1cd39a6c | ||
|
|
dda19b3920 | ||
|
|
25139ca922 | ||
|
|
3cd57a582c | ||
|
|
d3903ac655 | ||
|
|
199e374318 | ||
|
|
8375c1413d | ||
|
|
9e268cf016 | ||
|
|
112b3abc26 | ||
|
|
a8331a2357 | ||
|
|
52e3ad08c1 | ||
|
|
8d01d04ef0 | ||
|
|
a141384907 | ||
|
|
b8aa7184bd | ||
|
|
e4195f874d | ||
|
|
d04deff5ca | ||
|
|
20ce0778a0 | ||
|
|
5a0b3470f1 | ||
|
|
a920921570 | ||
|
|
286f4ff384 | ||
|
|
71ddfafa98 | ||
|
|
b7e3e53697 | ||
|
|
16df548b77 | ||
|
|
425c33ae00 | ||
|
|
c9289ed2dc | ||
|
|
96517cbdef | ||
|
|
b03420faac | ||
|
|
65a1aa7ca2 | ||
|
|
3a92e8eaf9 | ||
|
|
a8dc50d64a | ||
|
|
3397cc7d8d | ||
|
|
c3e8131b24 | ||
|
|
f8ca8584ae | ||
|
|
3050bbe260 | ||
|
|
e1dda2795a | ||
|
|
6d8408e626 | ||
|
|
0906271aa9 | ||
|
|
4c33c9d256 | ||
|
|
fa9c78209f | ||
|
|
6678ec8a60 | ||
|
|
854e467c12 | ||
|
|
e6b94c7b21 | ||
|
|
2c6f9d8602 | ||
|
|
c74033b9c0 | ||
|
|
d2b21d27bb | ||
|
|
215272469f | ||
|
|
f7d05ab0f1 | ||
|
|
6f2ad2be77 | ||
|
|
66575c719a | ||
|
|
677a239d53 | ||
|
|
3b96bfe5af | ||
|
|
83be5cfa64 | ||
|
|
6b834c2362 | ||
|
|
7abfc49e08 | ||
|
|
65d5f50088 | ||
|
|
4f1f4ffe3d | ||
|
|
b0c2027a1c | ||
|
|
33c83358b0 | ||
|
|
31223f0526 | ||
|
|
92daadb92c | ||
|
|
fae2e274fd | ||
|
|
342a722991 | ||
|
|
65ec6aacb7 | ||
|
|
9387470c69 | ||
|
|
31f6edf8f0 | ||
|
|
487b062175 | ||
|
|
d8e13de096 | ||
|
|
e8a30088ef | ||
|
|
bf7b07ba74 | ||
|
|
28fe3e7b7a | ||
|
|
c0eff2bb5e | ||
|
|
848c1741fe | ||
|
|
1370b8e8c1 | ||
|
|
82a068e610 | ||
|
|
32f42bafaa | ||
|
|
4081b7f022 | ||
|
|
a5808193a6 | ||
|
|
854ca322c1 | ||
|
|
c1d9b5137a | ||
|
|
f33d5745b3 | ||
|
|
d89c2ca128 | ||
|
|
835584cc85 | ||
|
|
b2ffbe3a68 | ||
|
|
defcc79e6c | ||
|
|
c06d9f84f0 | ||
|
|
fe57a8e156 | ||
|
|
b77105795a | ||
|
|
e2df5fcf27 | ||
|
|
836a64e728 | ||
|
|
08ba0c9f42 | ||
|
|
6fcc6a5299 | ||
|
|
6dd58248c6 | ||
|
|
2786801b71 | ||
|
|
ea29cbeb7a | ||
|
|
3cf9121a8c | ||
|
|
381bd3938a | ||
|
|
e4ce384023 | ||
|
|
12d1857b13 | ||
|
|
0d9003dea4 | ||
|
|
1a3751acfa | ||
|
|
c5a3af2399 | ||
|
|
ea8a64fafc | ||
|
|
981e367bf1 | ||
|
|
a3d6e62035 | ||
|
|
7f205cdcc8 | ||
|
|
e587189880 | ||
|
|
206c1bd69f | ||
|
|
a7d9255c2c | ||
|
|
08265a85ec | ||
|
|
1ed5630464 | ||
|
|
c784615f11 | ||
|
|
26d51b1190 | ||
|
|
d83fad6abc | ||
|
|
692796db46 | ||
|
|
f15c6f33f9 | ||
|
|
dda9eb4d7c | ||
|
|
6f3aeb61e7 | ||
|
|
d6145e633f | ||
|
|
07014d98ce | ||
|
|
e8ccdabe6c | ||
|
|
cf9fd2d5c2 | ||
|
|
bf9aa9356b | ||
|
|
68d00ce289 | ||
|
|
5288021e4f | ||
|
|
4d38add291 | ||
|
|
804808da4a | ||
|
|
298a95432d | ||
|
|
a834fc4b30 | ||
|
|
2c6c9542dd | ||
|
|
a9a7f4c8ec | ||
|
|
ea9370443d | ||
|
|
c2e00b240e | ||
|
|
a2b81ea099 | ||
|
|
ee609e8eac | ||
|
|
e04ef671e9 | ||
|
|
0184dfd7eb | ||
|
|
eccfa0ca54 | ||
|
|
6d3feb4bef | ||
|
|
29d2b5ee4b | ||
|
|
c82fabb67f | ||
|
|
fcfc868e57 | ||
|
|
67b403f8ca | ||
|
|
de06c6b2f6 | ||
|
|
fa444dfb8a | ||
|
|
124002a472 | ||
|
|
0c883433c1 | ||
|
|
bcf3b2cf55 | ||
|
|
357c4e9c08 | ||
|
|
9edfc68e91 | ||
|
|
8c06cb3e80 | ||
|
|
144fa0a6d4 | ||
|
|
25d5a1541e | ||
|
|
a579d36389 | ||
|
|
d766dac341 | ||
|
|
b15ef1bbc6 | ||
|
|
3e52e00597 | ||
|
|
f749dd0d52 | ||
|
|
48a8a42108 | ||
|
|
db7f57a5a4 | ||
|
|
556381b983 | ||
|
|
158d7d5898 | ||
|
|
18844da95d | ||
|
|
7e0df4d718 | ||
|
|
0dbb76e8c8 | ||
|
|
f73b3422a6 | ||
|
|
bd95e802ec | ||
|
|
5de16a78c5 | ||
|
|
6f8e09fcde | ||
|
|
f54d480f03 | ||
|
|
e68b213fb3 | ||
|
|
132334d500 | ||
|
|
a6f04c6d7e | ||
|
|
854e8bf356 | ||
|
|
6ff883d2d3 | ||
|
|
849b97afba | ||
|
|
1bd2635864 | ||
|
|
79ab0f7b6c | ||
|
|
79011bd257 | ||
|
|
c692713ffb | ||
|
|
df9b554ce1 | ||
|
|
277a8e4682 | ||
|
|
acb52dba09 | ||
|
|
8f10765254 | ||
|
|
0653f59473 | ||
|
|
7a4b5a4667 | ||
|
|
49c4a4068b | ||
|
|
40ad590046 | ||
|
|
30374ae3e6 | ||
|
|
ab22d16bad | ||
|
|
971cd56a4a | ||
|
|
d7cb546c5f | ||
|
|
9d8b7344cd | ||
|
|
2d4f6ae7ce | ||
|
|
d9126807b0 | ||
|
|
cad5fb3fba | ||
|
|
afe23ad6b7 | ||
|
|
fc4327087b | ||
|
|
71762d788f | ||
|
|
6472e00fb0 | ||
|
|
4043846767 | ||
|
|
d3b2bc962c | ||
|
|
54f7b64821 |
1
.github/copilot-instructions.md
vendored
Normal file
1
.github/copilot-instructions.md
vendored
Normal file
@@ -0,0 +1 @@
|
||||
Always use English for comments.
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,5 +1,6 @@
|
||||
__pycache__/
|
||||
settings.json
|
||||
path_mappings.yaml
|
||||
output/*
|
||||
py/run_test.py
|
||||
.vscode/
|
||||
|
||||
122
README.md
122
README.md
@@ -10,86 +10,72 @@ A comprehensive toolset that streamlines organizing, downloading, and applying L
|
||||
|
||||

|
||||
|
||||
One-click Integration:
|
||||

|
||||
|
||||
## 📺 Tutorial: One-Click LoRA Integration
|
||||
Watch this quick tutorial to learn how to use the new one-click LoRA integration feature:
|
||||
|
||||
[](https://youtu.be/hvKw31YpE-U)
|
||||
[](https://youtu.be/hvKw31YpE-U)
|
||||
|
||||
## 🌐 Browser Extension
|
||||
Enhance your Civitai browsing experience with our companion browser extension! See which models you already have, download new ones with a single click, and manage your downloads efficiently.
|
||||
|
||||

|
||||
|
||||
<div>
|
||||
<a href="https://chromewebstore.google.com/detail/lm-civitai-extension/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb" style="display: inline-block; background-color: #4285F4; color: white; padding: 8px 16px; text-decoration: none; border-radius: 4px; font-weight: bold; margin: 10px 0;">
|
||||
<img src="https://www.google.com/chrome/static/images/chrome-logo.svg" width="20" style="vertical-align: middle; margin-right: 8px;"> Get Extension from Chrome Web Store
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div id="firefox-install" class="install-ok"><a href="https://github.com/willmiao/lm-civitai-extension-firefox/releases/latest/download/extension.xpi">📦 Install Firefox Extension (reviewed and verified by Mozilla)</a></div>
|
||||
|
||||
📚 [Learn More: Complete Tutorial](https://github.com/willmiao/ComfyUI-Lora-Manager/wiki/LoRA-Manager-Civitai-Extension-(Chrome-Extension))
|
||||
|
||||
---
|
||||
|
||||
## Release Notes
|
||||
|
||||
### v0.8.19
|
||||
* **Analytics Dashboard** - Added new Statistics page providing comprehensive visual analysis of model collection and usage patterns for better library insights
|
||||
* **Target Node Selection** - Enhanced workflow integration with intelligent target choosing when sending LoRAs/recipes to workflows with multiple loader/stacker nodes; a visual selector now appears showing node color, type, ID, and title for precise targeting
|
||||
* **Enhanced NSFW Controls** - Added support for setting NSFW levels on recipes with automatic content blurring based on user preferences
|
||||
* **Customizable Card Display** - New display settings allowing users to choose whether card information and action buttons are always visible or only revealed on hover
|
||||
* **Expanded Compatibility** - Added support for efficiency-nodes-comfyui in Save Recipe and Save Image nodes, plus fixed compatibility with ComfyUI_Custom_Nodes_AlekPet
|
||||
### v0.9.2
|
||||
* **Bulk Auto-Organization Action** - Added a new bulk auto-organization feature. You can now select multiple models and automatically organize them according to your current path template settings for streamlined management.
|
||||
* **Bug Fixes** - Addressed several bugs to improve stability and reliability.
|
||||
|
||||
### v0.8.18
|
||||
* **Custom Example Images** - Added ability to import your own example images for LoRAs and checkpoints with automatic metadata extraction from embedded information
|
||||
* **Enhanced Example Management** - New action buttons to set specific examples as previews or delete custom examples
|
||||
* **Improved Duplicate Detection** - Enhanced "Find Duplicates" with hash verification feature to eliminate false positives when identifying duplicate models
|
||||
* **Tag Management** - Added tag editing functionality allowing users to customize and manage model tags
|
||||
* **Advanced Selection Controls** - Implemented Ctrl+A shortcut for quickly selecting all filtered LoRAs, automatically entering bulk mode when needed
|
||||
* **Note**: Cache file functionality temporarily disabled pending rework
|
||||
### v0.9.1
|
||||
* **Enhanced Bulk Operations** - Improved bulk operations with Marquee Selection and a bulk operation context menu, providing a more intuitive, desktop-application-like user experience.
|
||||
* **New Bulk Actions** - Added bulk operations for adding tags and setting base models to multiple models simultaneously.
|
||||
|
||||
### v0.8.17
|
||||
* **Duplicate Model Detection** - Added "Find Duplicates" functionality for LoRAs and checkpoints using model file hash detection, enabling convenient viewing and batch deletion of duplicate models
|
||||
* **Enhanced URL Recipe Imports** - Optimized import recipe via URL functionality using CivitAI API calls instead of web scraping, now supporting all rated images (including NSFW) for recipe imports
|
||||
* **Improved TriggerWord Control** - Enhanced TriggerWord Toggle node with new default_active switch to set the initial state (active/inactive) when trigger words are added
|
||||
* **Centralized Example Management** - Added "Migrate Existing Example Images" feature to consolidate downloaded example images from model folders into central storage with customizable naming patterns
|
||||
* **Intelligent Word Suggestions** - Implemented smart trigger word suggestions by reading class tokens and tag frequency from safetensors files, displaying recommendations when editing trigger words
|
||||
* **Model Version Management** - Added "Re-link to CivitAI" context menu option for connecting models to different CivitAI versions when needed
|
||||
### v0.9.0
|
||||
* **UI Overhaul for Enhanced Navigation** - Replaced the top flat folder tags with a new folder sidebar and breadcrumb navigation system for a more intuitive folder browsing and selection experience.
|
||||
* **Dual-Mode Folder Sidebar** - The new folder sidebar offers two display modes: 'List Mode,' which mirrors the classic folder view, and 'Tree Mode,' which presents a hierarchical folder structure for effortless navigation through nested directories.
|
||||
* **Internationalization Support** - Introduced multi-language support, now available in English, Simplified Chinese, Traditional Chinese, Spanish, Japanese, Korean, French, Russian, and German. Feedback from native speakers is welcome to improve the translations.
|
||||
* **Automatic Filename Conflict Resolution** - Implemented automatic file renaming (`original name + short hash`) to prevent conflicts when downloading or moving models.
|
||||
* **Performance Optimizations & Bug Fixes** - Various performance improvements and bug fixes for a more stable and responsive experience.
|
||||
|
||||
### v0.8.16
|
||||
* **Dramatic Startup Speed Improvement** - Added cache serialization mechanism for significantly faster loading times, especially beneficial for large model collections
|
||||
* **Enhanced Refresh Options** - Extended functionality with "Full Rebuild (complete)" option alongside "Quick Refresh (incremental)" to fix potential memory cache issues without requiring application restart
|
||||
* **Customizable Display Density** - Replaced compact mode with adjustable display density settings for personalized layout customization
|
||||
* **Model Creator Information** - Added creator details to model information panels for better attribution
|
||||
* **Improved WebP Support** - Enhanced Save Image node with workflow embedding capability for WebP format images
|
||||
* **Direct Example Access** - Added "Open Example Images Folder" button to card interfaces for convenient browsing of downloaded model examples
|
||||
* **Enhanced Compatibility** - Full ComfyUI Desktop support for "Send lora or recipe to workflow" functionality
|
||||
* **Cache Management** - Added settings to clear existing cache files when needed
|
||||
* **Bug Fixes & Stability** - Various improvements for overall reliability and performance
|
||||
### v0.8.30
|
||||
* **Automatic Model Path Correction** - Added auto-correction for model paths in built-in nodes such as Load Checkpoint, Load Diffusion Model, Load LoRA, and other custom nodes with similar functionality. Workflows containing outdated or incorrect model paths will now be automatically updated to reflect the current location of your models.
|
||||
* **Node UI Enhancements** - Improved node interface for a smoother and more intuitive user experience.
|
||||
* **Bug Fixes** - Addressed various bugs to enhance stability and reliability.
|
||||
|
||||
### v0.8.15
|
||||
* **Enhanced One-Click Integration** - Replaced copy button with direct send button allowing LoRAs/recipes to be sent directly to your current ComfyUI workflow without needing to paste
|
||||
* **Flexible Workflow Integration** - Click to append LoRAs/recipes to existing loader nodes or Shift+click to replace content, with additional right-click menu options for "Send to Workflow (Append)" or "Send to Workflow (Replace)"
|
||||
* **Improved LoRA Loader Controls** - Added header drag functionality for proportional strength adjustment of all LoRAs simultaneously (including CLIP strengths when expanded)
|
||||
* **Keyboard Navigation Support** - Implemented Page Up/Down for page scrolling, Home key to jump to top, and End key to jump to bottom for faster browsing through large collections
|
||||
### v0.8.29
|
||||
* **Enhanced Recipe Imports** - Improved recipe importing with new target folder selection, featuring path input autocomplete and interactive folder tree navigation. Added a "Use Default Path" option when downloading missing LoRAs.
|
||||
* **WanVideo Lora Select Node Update** - Updated the WanVideo Lora Select node with a 'merge_loras' option to match the counterpart node in the WanVideoWrapper node package.
|
||||
* **Autocomplete Conflict Resolution** - Resolved an autocomplete feature conflict in LoRA nodes with pysssss autocomplete.
|
||||
* **Improved Download Functionality** - Enhanced download functionality with resumable downloads and improved error handling.
|
||||
* **Bug Fixes** - Addressed several bugs for improved stability and performance.
|
||||
|
||||
### v0.8.14
|
||||
* **Virtualized Scrolling** - Completely rebuilt rendering mechanism for smooth browsing with no lag or freezing, now supporting virtually unlimited model collections with optimized layouts for large displays, improving space utilization and user experience
|
||||
* **Compact Display Mode** - Added space-efficient view option that displays more cards per row (7 on 1080p, 8 on 2K, 10 on 4K)
|
||||
* **Enhanced LoRA Node Functionality** - Comprehensive improvements to LoRA loader/stacker nodes including real-time trigger word updates (reflecting any change anywhere in the LoRA chain for precise updates) and expanded context menu with "Copy Notes" and "Copy Trigger Words" options for faster workflow
|
||||
### v0.8.28
|
||||
* **Autocomplete for Node Inputs** - Instantly find and add LoRAs by filename directly in Lora Loader, Lora Stacker, and WanVideo Lora Select nodes. Autocomplete suggestions include preview tooltips and preset weights, allowing you to quickly select LoRAs without opening the LoRA Manager UI.
|
||||
* **Duplicate Notification Control** - Added a switch to duplicates mode, enabling users to turn off duplicate model notifications for a more streamlined experience.
|
||||
* **Download Example Images from Context Menu** - Introduced a new context menu option to download example images for individual models.
|
||||
|
||||
### v0.8.13
|
||||
* **Enhanced Recipe Management** - Added "Find duplicates" feature to identify and batch delete duplicate recipes with duplicate detection notifications during imports
|
||||
* **Improved Source Tracking** - Source URLs are now saved with recipes imported via URL, allowing users to view original content with one click or manually edit links
|
||||
* **Advanced LoRA Control** - Double-click LoRAs in Loader/Stacker nodes to access expanded CLIP strength controls for more precise adjustments of model and CLIP strength separately
|
||||
* **Lycoris Model Support** - Added compatibility with Lycoris models for expanded creative options
|
||||
* **Bug Fixes & UX Improvements** - Resolved various issues and enhanced overall user experience with numerous optimizations
|
||||
### v0.8.27
|
||||
* **User Experience Enhancements** - Improved the model download target folder selection with path input autocomplete and interactive folder tree navigation, making it easier and faster to choose where models are saved.
|
||||
* **Default Path Option for Downloads** - Added a "Use Default Path" option when downloading models. When enabled, models are automatically organized and stored according to your configured path template settings.
|
||||
* **Advanced Download Path Templates** - Expanded path template settings, allowing users to set individual templates for LoRA, checkpoint, and embedding models for greater flexibility. Introduced the `{author}` placeholder, enabling automatic organization of model files by creator name.
|
||||
* **Bug Fixes & Stability Improvements** - Addressed various bugs and improved overall stability for a smoother experience.
|
||||
|
||||
### v0.8.12
|
||||
* **Enhanced Model Discovery** - Added alphabetical navigation bar to LoRAs page for faster browsing through large collections
|
||||
* **Optimized Example Images** - Improved download logic to automatically refresh stale metadata before fetching example images
|
||||
* **Model Exclusion System** - New right-click option to exclude specific LoRAs or checkpoints from management
|
||||
* **Improved Showcase Experience** - Enhanced interaction in LoRA and checkpoint showcase areas for better usability
|
||||
|
||||
### v0.8.11
|
||||
* **Offline Image Support** - Added functionality to download and save all model example images locally, ensuring access even when offline or if images are removed from CivitAI or the site is down
|
||||
* **Resilient Download System** - Implemented pause/resume capability with checkpoint recovery that persists through restarts or unexpected exits
|
||||
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
|
||||
|
||||
### v0.8.10
|
||||
* **Standalone Mode** - Run LoRA Manager independently from ComfyUI for a lightweight experience that works even with other stable diffusion interfaces
|
||||
* **Portable Edition** - New one-click portable version for easy startup and updates in standalone mode
|
||||
* **Enhanced Metadata Collection** - Added support for SamplerCustomAdvanced node in the metadata collector module
|
||||
* **Improved UI Organization** - Optimized Lora Loader node height to display up to 5 LoRAs at once with scrolling capability for larger collections
|
||||
### v0.8.26
|
||||
* **Creator Search Option** - Added ability to search models by creator name, making it easier to find models from specific authors.
|
||||
* **Enhanced Node Usability** - Improved user experience for Lora Loader, Lora Stacker, and WanVideo Lora Select nodes by fixing the maximum height of the text input area. Users can now freely and conveniently adjust the LoRA region within these nodes.
|
||||
* **Compatibility Fixes** - Resolved compatibility issues with ComfyUI and certain custom nodes, including ComfyUI-Custom-Scripts, ensuring smoother integration and operation.
|
||||
|
||||
[View Update History](./update_logs.md)
|
||||
|
||||
@@ -148,10 +134,11 @@ Watch this quick tutorial to learn how to use the new one-click LoRA integration
|
||||
|
||||
### Option 2: **Portable Standalone Edition** (No ComfyUI required)
|
||||
|
||||
1. Download the [Portable Package](https://github.com/willmiao/ComfyUI-Lora-Manager/releases/download/v0.8.15/lora_manager_portable.7z)
|
||||
1. Download the [Portable Package](https://github.com/willmiao/ComfyUI-Lora-Manager/releases/download/v0.8.26/lora_manager_portable.7z)
|
||||
2. Copy the provided `settings.json.example` file to create a new file named `settings.json` in `comfyui-lora-manager` folder
|
||||
3. Edit `settings.json` to include your correct model folder paths and CivitAI API key
|
||||
4. Run run.bat
|
||||
- To change the startup port, edit `run.bat` and modify the parameter (e.g. `--port 9001`)
|
||||
|
||||
### Option 3: **Manual Installation**
|
||||
|
||||
@@ -281,3 +268,6 @@ Join our Discord community for support, discussions, and updates:
|
||||
[Discord Server](https://discord.gg/vcqNrWVFvM)
|
||||
|
||||
---
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#willmiao/ComfyUI-Lora-Manager&Date)
|
||||
|
||||
@@ -1,18 +1,21 @@
|
||||
from .py.lora_manager import LoraManager
|
||||
from .py.nodes.lora_loader import LoraManagerLoader
|
||||
from .py.nodes.lora_loader import LoraManagerLoader, LoraManagerTextLoader
|
||||
from .py.nodes.trigger_word_toggle import TriggerWordToggle
|
||||
from .py.nodes.lora_stacker import LoraStacker
|
||||
from .py.nodes.save_image import SaveImage
|
||||
from .py.nodes.debug_metadata import DebugMetadata
|
||||
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelect
|
||||
# Import metadata collector to install hooks on startup
|
||||
from .py.metadata_collector import init as init_metadata_collector
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
LoraManagerLoader.NAME: LoraManagerLoader,
|
||||
LoraManagerTextLoader.NAME: LoraManagerTextLoader,
|
||||
TriggerWordToggle.NAME: TriggerWordToggle,
|
||||
LoraStacker.NAME: LoraStacker,
|
||||
SaveImage.NAME: SaveImage,
|
||||
DebugMetadata.NAME: DebugMetadata
|
||||
DebugMetadata.NAME: DebugMetadata,
|
||||
WanVideoLoraSelect.NAME: WanVideoLoraSelect
|
||||
}
|
||||
|
||||
WEB_DIRECTORY = "./web/comfyui"
|
||||
|
||||
182
docs/EventManagementImplementation.md
Normal file
182
docs/EventManagementImplementation.md
Normal file
@@ -0,0 +1,182 @@
|
||||
# Event Management Implementation Summary
|
||||
|
||||
## What Has Been Implemented
|
||||
|
||||
### 1. Enhanced EventManager Class
|
||||
- **Location**: `static/js/utils/EventManager.js`
|
||||
- **Features**:
|
||||
- Priority-based event handling
|
||||
- Conditional execution based on application state
|
||||
- Element filtering (target/exclude selectors)
|
||||
- Mouse button filtering
|
||||
- Automatic cleanup with cleanup functions
|
||||
- State tracking for app modes
|
||||
- Error handling for event handlers
|
||||
|
||||
### 2. BulkManager Integration
|
||||
- **Location**: `static/js/managers/BulkManager.js`
|
||||
- **Migrated Events**:
|
||||
- Global keyboard shortcuts (Ctrl+A, Escape, B key)
|
||||
- Marquee selection events (mousedown, mousemove, mouseup, contextmenu)
|
||||
- State synchronization with EventManager
|
||||
- **Benefits**:
|
||||
- Centralized priority handling
|
||||
- Conditional execution based on modal state
|
||||
- Better coordination with other components
|
||||
|
||||
### 3. UIHelpers Integration
|
||||
- **Location**: `static/js/utils/uiHelpers.js`
|
||||
- **Migrated Events**:
|
||||
- Mouse position tracking for node selector positioning
|
||||
- Node selector click events (outside clicks and selection)
|
||||
- State management for node selector
|
||||
- **Benefits**:
|
||||
- Reduced direct DOM listeners
|
||||
- Coordinated state tracking
|
||||
- Better cleanup
|
||||
|
||||
### 4. ModelCard Integration
|
||||
- **Location**: `static/js/components/shared/ModelCard.js`
|
||||
- **Migrated Events**:
|
||||
- Model card click delegation
|
||||
- Action button handling (star, globe, copy, etc.)
|
||||
- Better return value handling for event propagation
|
||||
- **Benefits**:
|
||||
- Single event listener for all model cards
|
||||
- Priority-based execution
|
||||
- Better event flow control
|
||||
|
||||
### 5. Documentation and Initialization
|
||||
- **EventManagerDocs.md**: Comprehensive documentation
|
||||
- **eventManagementInit.js**: Initialization and global handlers
|
||||
- **Features**:
|
||||
- Global escape key handling
|
||||
- Modal state synchronization
|
||||
- Error handling
|
||||
- Analytics integration points
|
||||
- Cleanup on page unload
|
||||
|
||||
## Application States Tracked
|
||||
|
||||
1. **bulkMode**: When bulk selection mode is active
|
||||
2. **marqueeActive**: When marquee selection is in progress
|
||||
3. **modalOpen**: When any modal dialog is open
|
||||
4. **nodeSelectorActive**: When node selector popup is visible
|
||||
|
||||
## Priority Levels Used
|
||||
|
||||
- **250+**: Critical system events (escape keys)
|
||||
- **200+**: High priority system events (modal close)
|
||||
- **100-199**: Application-level shortcuts (bulk operations)
|
||||
- **80-99**: UI interactions (marquee selection)
|
||||
- **60-79**: Component interactions (model cards)
|
||||
- **10-49**: Tracking and monitoring
|
||||
- **1-9**: Analytics and low-priority tasks
|
||||
|
||||
## Event Flow Examples
|
||||
|
||||
### Bulk Mode Toggle (B key)
|
||||
1. **Priority 100**: BulkManager keyboard handler catches 'b' key
|
||||
2. Toggles bulk mode state
|
||||
3. Updates EventManager state
|
||||
4. Updates UI accordingly
|
||||
5. Stops propagation (returns true)
|
||||
|
||||
### Marquee Selection
|
||||
1. **Priority 80**: BulkManager mousedown handler (only in .models-container, excluding cards/buttons)
|
||||
2. Starts marquee selection
|
||||
3. **Priority 90**: BulkManager mousemove handler (only when marquee active)
|
||||
4. Updates selection rectangle
|
||||
5. **Priority 90**: BulkManager mouseup handler ends selection
|
||||
|
||||
### Model Card Click
|
||||
1. **Priority 60**: ModelCard delegation handler checks for specific elements
|
||||
2. If action button: handles action and stops propagation
|
||||
3. If general card click: continues to other handlers
|
||||
4. Bulk selection may also handle the event if in bulk mode
|
||||
|
||||
## Remaining Event Listeners (Not Yet Migrated)
|
||||
|
||||
### High Priority for Migration
|
||||
1. **SearchManager keyboard events** - Global search shortcuts
|
||||
2. **ModalManager escape handling** - Already integrated with initialization
|
||||
3. **Scroll-based events** - Back to top, virtual scrolling
|
||||
4. **Resize events** - Panel positioning, responsive layouts
|
||||
|
||||
### Medium Priority
|
||||
1. **Form input events** - Tag inputs, settings forms
|
||||
2. **Component-specific events** - Recipe modal, showcase view
|
||||
3. **Sidebar events** - Resize handling, toggle events
|
||||
|
||||
### Low Priority (Can Remain As-Is)
|
||||
1. **VirtualScroller events** - Performance-critical, specialized
|
||||
2. **Component lifecycle events** - Modal open/close callbacks
|
||||
3. **One-time setup events** - Theme initialization, etc.
|
||||
|
||||
## Benefits Achieved
|
||||
|
||||
### Performance Improvements
|
||||
- **Reduced DOM listeners**: From ~15+ individual listeners to ~5 coordinated handlers
|
||||
- **Conditional execution**: Handlers only run when conditions are met
|
||||
- **Priority ordering**: Important events handled first
|
||||
- **Better memory management**: Automatic cleanup prevents leaks
|
||||
|
||||
### Coordination Improvements
|
||||
- **State synchronization**: All components aware of app state
|
||||
- **Event flow control**: Proper propagation stopping
|
||||
- **Conflict resolution**: Priority system prevents conflicts
|
||||
- **Debugging**: Centralized event handling for easier debugging
|
||||
|
||||
### Code Quality Improvements
|
||||
- **Consistent patterns**: All event handling follows same patterns
|
||||
- **Better separation of concerns**: Event logic separated from business logic
|
||||
- **Error handling**: Centralized error catching and reporting
|
||||
- **Documentation**: Clear patterns for future development
|
||||
|
||||
## Next Steps (Recommendations)
|
||||
|
||||
### 1. Migrate Search Events
|
||||
```javascript
|
||||
// In SearchManager.js
|
||||
eventManager.addHandler('keydown', 'search-shortcuts', (e) => {
|
||||
if ((e.ctrlKey || e.metaKey) && e.key === 'f') {
|
||||
this.focusSearchInput();
|
||||
return true;
|
||||
}
|
||||
}, { priority: 120 });
|
||||
```
|
||||
|
||||
### 2. Integrate Resize Events
|
||||
```javascript
|
||||
// Create ResizeManager
|
||||
eventManager.addHandler('resize', 'layout-resize', debounce((e) => {
|
||||
this.updateLayoutDimensions();
|
||||
}, 250), { priority: 50 });
|
||||
```
|
||||
|
||||
### 3. Add Debug Mode
|
||||
```javascript
|
||||
// In EventManager.js
|
||||
if (window.DEBUG_EVENTS) {
|
||||
console.log(`Event ${eventType} handled by ${source} (priority: ${priority})`);
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Create Event Analytics
|
||||
```javascript
|
||||
// Track event patterns for optimization
|
||||
eventManager.addHandler('*', 'analytics', (e) => {
|
||||
this.trackEventUsage(e.type, performance.now());
|
||||
}, { priority: 1 });
|
||||
```
|
||||
|
||||
## Testing Recommendations
|
||||
|
||||
1. **Verify bulk mode interactions** work correctly
|
||||
2. **Test marquee selection** in various scenarios
|
||||
3. **Check modal state synchronization**
|
||||
4. **Verify node selector** positioning and cleanup
|
||||
5. **Test keyboard shortcuts** don't conflict
|
||||
6. **Verify proper cleanup** when components are destroyed
|
||||
|
||||
The centralized event management system provides a solid foundation for coordinated, efficient event handling across the application while maintaining good performance and code organization.
|
||||
301
docs/EventManagerDocs.md
Normal file
301
docs/EventManagerDocs.md
Normal file
@@ -0,0 +1,301 @@
|
||||
# Centralized Event Management System
|
||||
|
||||
This document describes the centralized event management system that coordinates event handling across the ComfyUI LoRA Manager application.
|
||||
|
||||
## Overview
|
||||
|
||||
The `EventManager` class provides a centralized way to handle DOM events with priority-based execution, conditional execution based on application state, and proper cleanup mechanisms.
|
||||
|
||||
## Features
|
||||
|
||||
- **Priority-based execution**: Handlers with higher priority run first
|
||||
- **Conditional execution**: Handlers can be executed based on application state
|
||||
- **Element filtering**: Handlers can target specific elements or exclude others
|
||||
- **Automatic cleanup**: Cleanup functions are called when handlers are removed
|
||||
- **State tracking**: Tracks application states like bulk mode, modal open, etc.
|
||||
|
||||
## Basic Usage
|
||||
|
||||
### Importing
|
||||
|
||||
```javascript
|
||||
import { eventManager } from './EventManager.js';
|
||||
```
|
||||
|
||||
### Adding Event Handlers
|
||||
|
||||
```javascript
|
||||
eventManager.addHandler('click', 'myComponent', (event) => {
|
||||
console.log('Button clicked!');
|
||||
return true; // Stop propagation to other handlers
|
||||
}, {
|
||||
priority: 100,
|
||||
targetSelector: '.my-button',
|
||||
skipWhenModalOpen: true
|
||||
});
|
||||
```
|
||||
|
||||
### Removing Event Handlers
|
||||
|
||||
```javascript
|
||||
// Remove specific handler
|
||||
eventManager.removeHandler('click', 'myComponent');
|
||||
|
||||
// Remove all handlers for a component
|
||||
eventManager.removeAllHandlersForSource('myComponent');
|
||||
```
|
||||
|
||||
### Updating Application State
|
||||
|
||||
```javascript
|
||||
// Set state
|
||||
eventManager.setState('bulkMode', true);
|
||||
eventManager.setState('modalOpen', true);
|
||||
|
||||
// Get state
|
||||
const isBulkMode = eventManager.getState('bulkMode');
|
||||
```
|
||||
|
||||
## Available States
|
||||
|
||||
- `bulkMode`: Whether bulk selection mode is active
|
||||
- `marqueeActive`: Whether marquee selection is in progress
|
||||
- `modalOpen`: Whether any modal is currently open
|
||||
- `nodeSelectorActive`: Whether the node selector popup is active
|
||||
|
||||
## Handler Options
|
||||
|
||||
### Priority
|
||||
Higher numbers = higher priority. Handlers run in descending priority order.
|
||||
|
||||
```javascript
|
||||
{
|
||||
priority: 100 // High priority
|
||||
}
|
||||
```
|
||||
|
||||
### Conditional Execution
|
||||
|
||||
```javascript
|
||||
{
|
||||
onlyInBulkMode: true, // Only run when bulk mode is active
|
||||
onlyWhenMarqueeActive: true, // Only run when marquee selection is active
|
||||
skipWhenModalOpen: true, // Skip when any modal is open
|
||||
skipWhenNodeSelectorActive: true, // Skip when node selector is active
|
||||
onlyWhenNodeSelectorActive: true // Only run when node selector is active
|
||||
}
|
||||
```
|
||||
|
||||
### Element Filtering
|
||||
|
||||
```javascript
|
||||
{
|
||||
targetSelector: '.model-card', // Only handle events on matching elements
|
||||
excludeSelector: 'button, input', // Exclude events from these elements
|
||||
button: 0 // Only handle specific mouse button (0=left, 1=middle, 2=right)
|
||||
}
|
||||
```
|
||||
|
||||
### Cleanup Functions
|
||||
|
||||
```javascript
|
||||
{
|
||||
cleanup: () => {
|
||||
// Custom cleanup logic
|
||||
console.log('Handler cleaned up');
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Integration Examples
|
||||
|
||||
### BulkManager Integration
|
||||
|
||||
```javascript
|
||||
class BulkManager {
|
||||
registerEventHandlers() {
|
||||
// High priority keyboard shortcuts
|
||||
eventManager.addHandler('keydown', 'bulkManager-keyboard', (e) => {
|
||||
return this.handleGlobalKeyboard(e);
|
||||
}, {
|
||||
priority: 100,
|
||||
skipWhenModalOpen: true
|
||||
});
|
||||
|
||||
// Marquee selection
|
||||
eventManager.addHandler('mousedown', 'bulkManager-marquee-start', (e) => {
|
||||
return this.handleMarqueeStart(e);
|
||||
}, {
|
||||
priority: 80,
|
||||
skipWhenModalOpen: true,
|
||||
targetSelector: '.models-container',
|
||||
excludeSelector: '.model-card, button, input',
|
||||
button: 0
|
||||
});
|
||||
}
|
||||
|
||||
cleanup() {
|
||||
eventManager.removeAllHandlersForSource('bulkManager-keyboard');
|
||||
eventManager.removeAllHandlersForSource('bulkManager-marquee-start');
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Modal Integration
|
||||
|
||||
```javascript
|
||||
class ModalManager {
|
||||
showModal(modalId) {
|
||||
// Update state when modal opens
|
||||
eventManager.setState('modalOpen', true);
|
||||
this.displayModal(modalId);
|
||||
}
|
||||
|
||||
closeModal(modalId) {
|
||||
// Update state when modal closes
|
||||
eventManager.setState('modalOpen', false);
|
||||
this.hideModal(modalId);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Component Event Delegation
|
||||
|
||||
```javascript
|
||||
export function setupComponentEvents() {
|
||||
eventManager.addHandler('click', 'myComponent-actions', (event) => {
|
||||
const button = event.target.closest('.action-button');
|
||||
if (!button) return false;
|
||||
|
||||
this.handleAction(button.dataset.action);
|
||||
return true; // Stop propagation
|
||||
}, {
|
||||
priority: 60,
|
||||
targetSelector: '.component-container'
|
||||
});
|
||||
}
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Use Descriptive Source Names
|
||||
Use the format `componentName-purposeDescription`:
|
||||
```javascript
|
||||
// Good
|
||||
'bulkManager-marqueeSelection'
|
||||
'nodeSelector-clickOutside'
|
||||
'modelCard-delegation'
|
||||
|
||||
// Avoid
|
||||
'bulk'
|
||||
'click'
|
||||
'handler1'
|
||||
```
|
||||
|
||||
### 2. Set Appropriate Priorities
|
||||
- 200+: Critical system events (escape keys, critical modals)
|
||||
- 100-199: High priority application events (keyboard shortcuts)
|
||||
- 50-99: Normal UI interactions (buttons, cards)
|
||||
- 1-49: Low priority events (tracking, analytics)
|
||||
|
||||
### 3. Use Conditional Execution
|
||||
Instead of checking state inside handlers, use options:
|
||||
```javascript
|
||||
// Good
|
||||
eventManager.addHandler('click', 'bulk-action', handler, {
|
||||
onlyInBulkMode: true
|
||||
});
|
||||
|
||||
// Avoid
|
||||
eventManager.addHandler('click', 'bulk-action', (e) => {
|
||||
if (!state.bulkMode) return;
|
||||
// handler logic
|
||||
});
|
||||
```
|
||||
|
||||
### 4. Clean Up Properly
|
||||
Always clean up handlers when components are destroyed:
|
||||
```javascript
|
||||
class MyComponent {
|
||||
constructor() {
|
||||
this.registerEvents();
|
||||
}
|
||||
|
||||
destroy() {
|
||||
eventManager.removeAllHandlersForSource('myComponent');
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 5. Return Values Matter
|
||||
- Return `true` to stop event propagation to other handlers
|
||||
- Return `false` or `undefined` to continue with other handlers
|
||||
|
||||
## Migration Guide
|
||||
|
||||
### From Direct Event Listeners
|
||||
|
||||
**Before:**
|
||||
```javascript
|
||||
document.addEventListener('click', (e) => {
|
||||
if (e.target.closest('.my-button')) {
|
||||
this.handleClick(e);
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
**After:**
|
||||
```javascript
|
||||
eventManager.addHandler('click', 'myComponent-button', (e) => {
|
||||
this.handleClick(e);
|
||||
}, {
|
||||
targetSelector: '.my-button'
|
||||
});
|
||||
```
|
||||
|
||||
### From Event Delegation
|
||||
|
||||
**Before:**
|
||||
```javascript
|
||||
container.addEventListener('click', (e) => {
|
||||
const card = e.target.closest('.model-card');
|
||||
if (!card) return;
|
||||
|
||||
if (e.target.closest('.action-btn')) {
|
||||
this.handleAction(e);
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
**After:**
|
||||
```javascript
|
||||
eventManager.addHandler('click', 'container-actions', (e) => {
|
||||
const card = e.target.closest('.model-card');
|
||||
if (!card) return false;
|
||||
|
||||
if (e.target.closest('.action-btn')) {
|
||||
this.handleAction(e);
|
||||
return true;
|
||||
}
|
||||
}, {
|
||||
targetSelector: '.container'
|
||||
});
|
||||
```
|
||||
|
||||
## Performance Benefits
|
||||
|
||||
1. **Reduced DOM listeners**: Single listener per event type instead of multiple
|
||||
2. **Conditional execution**: Handlers only run when conditions are met
|
||||
3. **Priority ordering**: Important handlers run first, avoiding unnecessary work
|
||||
4. **Automatic cleanup**: Prevents memory leaks from orphaned listeners
|
||||
5. **Centralized debugging**: All event handling flows through one system
|
||||
|
||||
## Debugging
|
||||
|
||||
Enable debug logging to trace event handling:
|
||||
```javascript
|
||||
// Add to EventManager.js for debugging
|
||||
console.log(`Handling ${eventType} event with ${handlers.length} handlers`);
|
||||
```
|
||||
|
||||
The event manager provides a foundation for coordinated, efficient event handling across the entire application.
|
||||
BIN
example_workflows/nunchaku-flux.1-dev.jpg
Normal file
BIN
example_workflows/nunchaku-flux.1-dev.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 68 KiB |
1
example_workflows/nunchaku-flux.1-dev.json
Normal file
1
example_workflows/nunchaku-flux.1-dev.json
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
1174
locales/de.json
Normal file
1174
locales/de.json
Normal file
File diff suppressed because it is too large
Load Diff
1174
locales/en.json
Normal file
1174
locales/en.json
Normal file
File diff suppressed because it is too large
Load Diff
1174
locales/es.json
Normal file
1174
locales/es.json
Normal file
File diff suppressed because it is too large
Load Diff
1174
locales/fr.json
Normal file
1174
locales/fr.json
Normal file
File diff suppressed because it is too large
Load Diff
1174
locales/ja.json
Normal file
1174
locales/ja.json
Normal file
File diff suppressed because it is too large
Load Diff
1174
locales/ko.json
Normal file
1174
locales/ko.json
Normal file
File diff suppressed because it is too large
Load Diff
1174
locales/ru.json
Normal file
1174
locales/ru.json
Normal file
File diff suppressed because it is too large
Load Diff
1174
locales/zh-CN.json
Normal file
1174
locales/zh-CN.json
Normal file
File diff suppressed because it is too large
Load Diff
1174
locales/zh-TW.json
Normal file
1174
locales/zh-TW.json
Normal file
File diff suppressed because it is too large
Load Diff
209
py/config.py
209
py/config.py
@@ -5,6 +5,7 @@ from typing import List
|
||||
import logging
|
||||
import sys
|
||||
import json
|
||||
import urllib.parse
|
||||
|
||||
# Check if running in standalone mode
|
||||
standalone_mode = 'nodes' not in sys.modules
|
||||
@@ -17,13 +18,18 @@ class Config:
|
||||
def __init__(self):
|
||||
self.templates_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'templates')
|
||||
self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static')
|
||||
# 路径映射字典, target to link mapping
|
||||
self.i18n_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'locales')
|
||||
# Path mapping dictionary, target to link mapping
|
||||
self._path_mappings = {}
|
||||
# 静态路由映射字典, target to route mapping
|
||||
# Static route mapping dictionary, target to route mapping
|
||||
self._route_mappings = {}
|
||||
self.loras_roots = self._init_lora_paths()
|
||||
self.checkpoints_roots = self._init_checkpoint_paths()
|
||||
# 在初始化时扫描符号链接
|
||||
self.checkpoints_roots = None
|
||||
self.unet_roots = None
|
||||
self.embeddings_roots = None
|
||||
self.base_models_roots = self._init_checkpoint_paths()
|
||||
self.embeddings_roots = self._init_embedding_paths()
|
||||
# Scan symbolic links during initialization
|
||||
self._scan_symbolic_links()
|
||||
|
||||
if not standalone_mode:
|
||||
@@ -33,34 +39,37 @@ class Config:
|
||||
def save_folder_paths_to_settings(self):
|
||||
"""Save folder paths to settings.json for standalone mode to use later"""
|
||||
try:
|
||||
# Check if we're running in ComfyUI mode (not standalone)
|
||||
if hasattr(folder_paths, "get_folder_paths") and not isinstance(folder_paths, type):
|
||||
# Get all relevant paths
|
||||
lora_paths = folder_paths.get_folder_paths("loras")
|
||||
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
diffuser_paths = folder_paths.get_folder_paths("diffusers")
|
||||
unet_paths = folder_paths.get_folder_paths("unet")
|
||||
|
||||
# Load existing settings
|
||||
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
|
||||
settings = {}
|
||||
if os.path.exists(settings_path):
|
||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||
settings = json.load(f)
|
||||
|
||||
# Update settings with paths
|
||||
settings['folder_paths'] = {
|
||||
'loras': lora_paths,
|
||||
'checkpoints': checkpoint_paths,
|
||||
'diffusers': diffuser_paths,
|
||||
'unet': unet_paths
|
||||
}
|
||||
|
||||
# Save settings
|
||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(settings, f, indent=2)
|
||||
|
||||
logger.info("Saved folder paths to settings.json")
|
||||
# Check if we're running in ComfyUI mode (not standalone)
|
||||
# Load existing settings
|
||||
settings_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
|
||||
settings = {}
|
||||
if os.path.exists(settings_path):
|
||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||
settings = json.load(f)
|
||||
|
||||
# Update settings with paths
|
||||
settings['folder_paths'] = {
|
||||
'loras': self.loras_roots,
|
||||
'checkpoints': self.checkpoints_roots,
|
||||
'unet': self.unet_roots,
|
||||
'embeddings': self.embeddings_roots,
|
||||
}
|
||||
|
||||
# Add default roots if there's only one item and key doesn't exist
|
||||
if len(self.loras_roots) == 1 and "default_lora_root" not in settings:
|
||||
settings["default_lora_root"] = self.loras_roots[0]
|
||||
|
||||
if self.checkpoints_roots and len(self.checkpoints_roots) == 1 and "default_checkpoint_root" not in settings:
|
||||
settings["default_checkpoint_root"] = self.checkpoints_roots[0]
|
||||
|
||||
if self.embeddings_roots and len(self.embeddings_roots) == 1 and "default_embedding_root" not in settings:
|
||||
settings["default_embedding_root"] = self.embeddings_roots[0]
|
||||
|
||||
# Save settings
|
||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(settings, f, indent=2)
|
||||
|
||||
logger.info("Saved folder paths to settings.json")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to save folder paths: {e}")
|
||||
|
||||
@@ -82,15 +91,18 @@ class Config:
|
||||
return False
|
||||
|
||||
def _scan_symbolic_links(self):
|
||||
"""扫描所有 LoRA 和 Checkpoint 根目录中的符号链接"""
|
||||
"""Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories"""
|
||||
for root in self.loras_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
for root in self.checkpoints_roots:
|
||||
for root in self.base_models_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
for root in self.embeddings_roots:
|
||||
self._scan_directory_links(root)
|
||||
|
||||
def _scan_directory_links(self, root: str):
|
||||
"""递归扫描目录中的符号链接"""
|
||||
"""Recursively scan symbolic links in a directory"""
|
||||
try:
|
||||
with os.scandir(root) as it:
|
||||
for entry in it:
|
||||
@@ -105,40 +117,40 @@ class Config:
|
||||
logger.error(f"Error scanning links in {root}: {e}")
|
||||
|
||||
def add_path_mapping(self, link_path: str, target_path: str):
|
||||
"""添加符号链接路径映射
|
||||
target_path: 实际目标路径
|
||||
link_path: 符号链接路径
|
||||
"""Add a symbolic link path mapping
|
||||
target_path: actual target path
|
||||
link_path: symbolic link path
|
||||
"""
|
||||
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
||||
normalized_target = os.path.normpath(target_path).replace(os.sep, '/')
|
||||
# 保持原有的映射关系:目标路径 -> 链接路径
|
||||
# Keep the original mapping: target path -> link path
|
||||
self._path_mappings[normalized_target] = normalized_link
|
||||
logger.info(f"Added path mapping: {normalized_target} -> {normalized_link}")
|
||||
|
||||
def add_route_mapping(self, path: str, route: str):
|
||||
"""添加静态路由映射"""
|
||||
"""Add a static route mapping"""
|
||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||
self._route_mappings[normalized_path] = route
|
||||
# logger.info(f"Added route mapping: {normalized_path} -> {route}")
|
||||
|
||||
def map_path_to_link(self, path: str) -> str:
|
||||
"""将目标路径映射回符号链接路径"""
|
||||
"""Map a target path back to its symbolic link path"""
|
||||
normalized_path = os.path.normpath(path).replace(os.sep, '/')
|
||||
# 检查路径是否包含在任何映射的目标路径中
|
||||
# Check if the path is contained in any mapped target path
|
||||
for target_path, link_path in self._path_mappings.items():
|
||||
if normalized_path.startswith(target_path):
|
||||
# 如果路径以目标路径开头,则替换为链接路径
|
||||
# If the path starts with the target path, replace with link path
|
||||
mapped_path = normalized_path.replace(target_path, link_path, 1)
|
||||
return mapped_path
|
||||
return path
|
||||
|
||||
def map_link_to_path(self, link_path: str) -> str:
|
||||
"""将符号链接路径映射回实际路径"""
|
||||
"""Map a symbolic link path back to the actual path"""
|
||||
normalized_link = os.path.normpath(link_path).replace(os.sep, '/')
|
||||
# 检查路径是否包含在任何映射的目标路径中
|
||||
# Check if the path is contained in any mapped target path
|
||||
for target_path, link_path in self._path_mappings.items():
|
||||
if normalized_link.startswith(target_path):
|
||||
# 如果路径以目标路径开头,则替换为实际路径
|
||||
# If the path starts with the target path, replace with actual path
|
||||
mapped_path = normalized_link.replace(target_path, link_path, 1)
|
||||
return mapped_path
|
||||
return link_path
|
||||
@@ -177,47 +189,106 @@ class Config:
|
||||
"""Initialize and validate checkpoint paths from ComfyUI settings"""
|
||||
try:
|
||||
# Get checkpoint paths from folder_paths
|
||||
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
diffusion_paths = folder_paths.get_folder_paths("diffusers")
|
||||
unet_paths = folder_paths.get_folder_paths("unet")
|
||||
raw_checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
raw_unet_paths = folder_paths.get_folder_paths("unet")
|
||||
|
||||
# Combine all checkpoint-related paths
|
||||
all_paths = checkpoint_paths + diffusion_paths + unet_paths
|
||||
# Normalize and resolve symlinks for checkpoints, store mapping from resolved -> original
|
||||
checkpoint_map = {}
|
||||
for path in raw_checkpoint_paths:
|
||||
if os.path.exists(path):
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
checkpoint_map[real_path] = checkpoint_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
|
||||
|
||||
# Filter and normalize paths
|
||||
paths = sorted(set(path.replace(os.sep, "/")
|
||||
for path in all_paths
|
||||
if os.path.exists(path)), key=lambda p: p.lower())
|
||||
# Normalize and resolve symlinks for unet, store mapping from resolved -> original
|
||||
unet_map = {}
|
||||
for path in raw_unet_paths:
|
||||
if os.path.exists(path):
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
unet_map[real_path] = unet_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
|
||||
|
||||
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(paths) if paths else "[]"))
|
||||
# Merge both maps and deduplicate by real path
|
||||
merged_map = {}
|
||||
for real_path, orig_path in {**checkpoint_map, **unet_map}.items():
|
||||
if real_path not in merged_map:
|
||||
merged_map[real_path] = orig_path
|
||||
|
||||
# Now sort and use only the deduplicated real paths
|
||||
unique_paths = sorted(merged_map.values(), key=lambda p: p.lower())
|
||||
|
||||
if not paths:
|
||||
# Split back into checkpoints and unet roots for class properties
|
||||
self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_map.values()]
|
||||
self.unet_roots = [p for p in unique_paths if p in unet_map.values()]
|
||||
|
||||
all_paths = unique_paths
|
||||
|
||||
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(all_paths) if all_paths else "[]"))
|
||||
|
||||
if not all_paths:
|
||||
logger.warning("No valid checkpoint folders found in ComfyUI configuration")
|
||||
return []
|
||||
|
||||
# 初始化路径映射,与 LoRA 路径处理方式相同
|
||||
for path in paths:
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
if real_path != path:
|
||||
self.add_path_mapping(path, real_path)
|
||||
# Initialize path mappings
|
||||
for original_path in all_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return paths
|
||||
return all_paths
|
||||
except Exception as e:
|
||||
logger.warning(f"Error initializing checkpoint paths: {e}")
|
||||
return []
|
||||
|
||||
def _init_embedding_paths(self) -> List[str]:
|
||||
"""Initialize and validate embedding paths from ComfyUI settings"""
|
||||
try:
|
||||
raw_paths = folder_paths.get_folder_paths("embeddings")
|
||||
|
||||
# Normalize and resolve symlinks, store mapping from resolved -> original
|
||||
path_map = {}
|
||||
for path in raw_paths:
|
||||
if os.path.exists(path):
|
||||
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
|
||||
path_map[real_path] = path_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
|
||||
|
||||
# Now sort and use only the deduplicated real paths
|
||||
unique_paths = sorted(path_map.values(), key=lambda p: p.lower())
|
||||
logger.info("Found embedding roots:" + ("\n - " + "\n - ".join(unique_paths) if unique_paths else "[]"))
|
||||
|
||||
if not unique_paths:
|
||||
logger.warning("No valid embeddings folders found in ComfyUI configuration")
|
||||
return []
|
||||
|
||||
for original_path in unique_paths:
|
||||
real_path = os.path.normpath(os.path.realpath(original_path)).replace(os.sep, '/')
|
||||
if real_path != original_path:
|
||||
self.add_path_mapping(original_path, real_path)
|
||||
|
||||
return unique_paths
|
||||
except Exception as e:
|
||||
logger.warning(f"Error initializing embedding paths: {e}")
|
||||
return []
|
||||
|
||||
def get_preview_static_url(self, preview_path: str) -> str:
|
||||
"""Convert local preview path to static URL"""
|
||||
if not preview_path:
|
||||
return ""
|
||||
|
||||
real_path = os.path.realpath(preview_path).replace(os.sep, '/')
|
||||
|
||||
|
||||
# Find longest matching path (most specific match)
|
||||
best_match = ""
|
||||
best_route = ""
|
||||
|
||||
for path, route in self._route_mappings.items():
|
||||
if real_path.startswith(path):
|
||||
relative_path = os.path.relpath(real_path, path)
|
||||
return f'{route}/{relative_path.replace(os.sep, "/")}'
|
||||
|
||||
if real_path.startswith(path) and len(path) > len(best_match):
|
||||
best_match = path
|
||||
best_route = route
|
||||
|
||||
if best_match:
|
||||
relative_path = os.path.relpath(real_path, best_match).replace(os.sep, '/')
|
||||
safe_parts = [urllib.parse.quote(part) for part in relative_path.split('/')]
|
||||
safe_path = '/'.join(safe_parts)
|
||||
return f'{best_route}/{safe_path}'
|
||||
|
||||
return ""
|
||||
|
||||
# Global config instance
|
||||
|
||||
@@ -6,10 +6,8 @@ from pathlib import Path
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from .config import config
|
||||
from .routes.lora_routes import LoraRoutes
|
||||
from .routes.api_routes import ApiRoutes
|
||||
from .services.model_service_factory import ModelServiceFactory, register_default_model_types
|
||||
from .routes.recipe_routes import RecipeRoutes
|
||||
from .routes.checkpoints_routes import CheckpointsRoutes
|
||||
from .routes.stats_routes import StatsRoutes
|
||||
from .routes.update_routes import UpdateRoutes
|
||||
from .routes.misc_routes import MiscRoutes
|
||||
@@ -17,6 +15,7 @@ from .routes.example_images_routes import ExampleImagesRoutes
|
||||
from .services.service_registry import ServiceRegistry
|
||||
from .services.settings_manager import settings
|
||||
from .utils.example_images_migration import ExampleImagesMigration
|
||||
from .services.websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -28,12 +27,28 @@ class LoraManager:
|
||||
|
||||
@classmethod
|
||||
def add_routes(cls):
|
||||
"""Initialize and register all routes"""
|
||||
"""Initialize and register all routes using the new refactored architecture"""
|
||||
app = PromptServer.instance.app
|
||||
|
||||
# Configure aiohttp access logger to be less verbose
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Add specific suppression for connection reset errors
|
||||
class ConnectionResetFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
# Filter out connection reset errors that are not critical
|
||||
if "ConnectionResetError" in str(record.getMessage()):
|
||||
return False
|
||||
if "_call_connection_lost" in str(record.getMessage()):
|
||||
return False
|
||||
if "WinError 10054" in str(record.getMessage()):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Apply the filter to asyncio logger
|
||||
asyncio_logger = logging.getLogger("asyncio")
|
||||
asyncio_logger.addFilter(ConnectionResetFilter())
|
||||
|
||||
added_targets = set() # Track already added target paths
|
||||
|
||||
# Add static route for example images if the path exists in settings
|
||||
@@ -62,7 +77,7 @@ class LoraManager:
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for each checkpoint root
|
||||
for idx, root in enumerate(config.checkpoints_roots, start=1):
|
||||
for idx, root in enumerate(config.base_models_roots, start=1):
|
||||
preview_path = f'/checkpoints_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
@@ -79,21 +94,45 @@ class LoraManager:
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for each embedding root
|
||||
for idx, root in enumerate(config.embeddings_roots, start=1):
|
||||
preview_path = f'/embeddings_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
if root in config._path_mappings.values():
|
||||
for target, link in config._path_mappings.items():
|
||||
if link == root:
|
||||
real_root = target
|
||||
break
|
||||
# Add static route for original path
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {real_root}")
|
||||
|
||||
# Record route mapping
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(real_root)
|
||||
|
||||
# Add static routes for symlink target paths
|
||||
link_idx = {
|
||||
'lora': 1,
|
||||
'checkpoint': 1
|
||||
'checkpoint': 1,
|
||||
'embedding': 1
|
||||
}
|
||||
|
||||
for target_path, link_path in config._path_mappings.items():
|
||||
if target_path not in added_targets:
|
||||
# Determine if this is a checkpoint or lora link based on path
|
||||
is_checkpoint = any(cp_root in link_path for cp_root in config.checkpoints_roots)
|
||||
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.checkpoints_roots)
|
||||
# Determine if this is a checkpoint, lora, or embedding link based on path
|
||||
is_checkpoint = any(cp_root in link_path for cp_root in config.base_models_roots)
|
||||
is_checkpoint = is_checkpoint or any(cp_root in target_path for cp_root in config.base_models_roots)
|
||||
is_embedding = any(emb_root in link_path for emb_root in config.embeddings_roots)
|
||||
is_embedding = is_embedding or any(emb_root in target_path for emb_root in config.embeddings_roots)
|
||||
|
||||
if is_checkpoint:
|
||||
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
|
||||
link_idx["checkpoint"] += 1
|
||||
elif is_embedding:
|
||||
route_path = f'/embeddings_static/link_{link_idx["embedding"]}/preview'
|
||||
link_idx["embedding"] += 1
|
||||
else:
|
||||
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
|
||||
link_idx["lora"] += 1
|
||||
@@ -106,39 +145,46 @@ class LoraManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to add static route on initialization for {target_path}: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# Add static route for locales JSON files
|
||||
if os.path.exists(config.i18n_path):
|
||||
app.router.add_static('/locales', config.i18n_path)
|
||||
logger.info(f"Added static route for locales: /locales -> {config.i18n_path}")
|
||||
|
||||
# Add static route for plugin assets
|
||||
app.router.add_static('/loras_static', config.static_path)
|
||||
|
||||
# Setup feature routes
|
||||
lora_routes = LoraRoutes()
|
||||
checkpoints_routes = CheckpointsRoutes()
|
||||
stats_routes = StatsRoutes()
|
||||
# Register default model types with the factory
|
||||
register_default_model_types()
|
||||
|
||||
# Initialize routes
|
||||
lora_routes.setup_routes(app)
|
||||
checkpoints_routes.setup_routes(app)
|
||||
stats_routes.setup_routes(app) # Add statistics routes
|
||||
ApiRoutes.setup_routes(app)
|
||||
# Setup all model routes using the factory
|
||||
ModelServiceFactory.setup_all_routes(app)
|
||||
|
||||
# Setup non-model-specific routes
|
||||
stats_routes = StatsRoutes()
|
||||
stats_routes.setup_routes(app)
|
||||
RecipeRoutes.setup_routes(app)
|
||||
UpdateRoutes.setup_routes(app)
|
||||
MiscRoutes.setup_routes(app) # Register miscellaneous routes
|
||||
ExampleImagesRoutes.setup_routes(app) # Register example images routes
|
||||
MiscRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app)
|
||||
|
||||
# Setup WebSocket routes that are shared across all model types
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
|
||||
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
|
||||
|
||||
# Schedule service initialization
|
||||
app.on_startup.append(lambda app: cls._initialize_services())
|
||||
|
||||
# Add cleanup
|
||||
app.on_shutdown.append(cls._cleanup)
|
||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
||||
|
||||
logger.info(f"LoRA Manager: Set up routes for {len(ModelServiceFactory.get_registered_types())} model types: {', '.join(ModelServiceFactory.get_registered_types())}")
|
||||
|
||||
@classmethod
|
||||
async def _initialize_services(cls):
|
||||
"""Initialize all services using the ServiceRegistry"""
|
||||
try:
|
||||
# Ensure aiohttp access logger is configured with reduced verbosity
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Initialize CivitaiClient first to ensure it's ready for other services
|
||||
await ServiceRegistry.get_civitai_client()
|
||||
|
||||
@@ -151,28 +197,273 @@ class LoraManager:
|
||||
# Initialize scanners in background
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
# Initialize recipe scanner if needed
|
||||
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
|
||||
# Initialize metadata collector if not in standalone mode
|
||||
if not STANDALONE_MODE:
|
||||
from .metadata_collector import init as init_metadata
|
||||
init_metadata()
|
||||
logger.debug("Metadata collector initialized")
|
||||
|
||||
# Create low-priority initialization tasks
|
||||
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
|
||||
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
|
||||
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
||||
init_tasks = [
|
||||
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init'),
|
||||
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init'),
|
||||
asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init'),
|
||||
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
|
||||
]
|
||||
|
||||
await ExampleImagesMigration.check_and_run_migrations()
|
||||
|
||||
# Schedule post-initialization tasks to run after scanners complete
|
||||
asyncio.create_task(
|
||||
cls._run_post_initialization_tasks(init_tasks),
|
||||
name='post_init_tasks'
|
||||
)
|
||||
|
||||
logger.info("LoRA Manager: All services initialized and background tasks scheduled")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
async def _run_post_initialization_tasks(cls, init_tasks):
|
||||
"""Run post-initialization tasks after all scanners complete"""
|
||||
try:
|
||||
logger.debug("LoRA Manager: Waiting for scanner initialization to complete...")
|
||||
|
||||
# Wait for all scanner initialization tasks to complete
|
||||
await asyncio.gather(*init_tasks, return_exceptions=True)
|
||||
|
||||
logger.debug("LoRA Manager: Scanner initialization completed, starting post-initialization tasks...")
|
||||
|
||||
# Run post-initialization tasks
|
||||
post_tasks = [
|
||||
asyncio.create_task(cls._cleanup_backup_files(), name='cleanup_bak_files'),
|
||||
asyncio.create_task(cls._cleanup_example_images_folders(), name='cleanup_example_images'),
|
||||
# Add more post-initialization tasks here as needed
|
||||
# asyncio.create_task(cls._another_post_task(), name='another_task'),
|
||||
]
|
||||
|
||||
# Run all post-initialization tasks
|
||||
results = await asyncio.gather(*post_tasks, return_exceptions=True)
|
||||
|
||||
# Log results
|
||||
for i, result in enumerate(results):
|
||||
task_name = post_tasks[i].get_name()
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Post-initialization task '{task_name}' failed: {result}")
|
||||
else:
|
||||
logger.debug(f"Post-initialization task '{task_name}' completed successfully")
|
||||
|
||||
logger.debug("LoRA Manager: All post-initialization tasks completed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
async def _cleanup_backup_files(cls):
|
||||
"""Clean up .bak files in all model roots"""
|
||||
try:
|
||||
logger.debug("Starting cleanup of .bak files in model directories...")
|
||||
|
||||
# Collect all model roots
|
||||
all_roots = set()
|
||||
all_roots.update(config.loras_roots)
|
||||
all_roots.update(config.base_models_roots)
|
||||
all_roots.update(config.embeddings_roots)
|
||||
|
||||
total_deleted = 0
|
||||
total_size_freed = 0
|
||||
|
||||
for root_path in all_roots:
|
||||
if not os.path.exists(root_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
deleted_count, size_freed = await cls._cleanup_backup_files_in_directory(root_path)
|
||||
total_deleted += deleted_count
|
||||
total_size_freed += size_freed
|
||||
|
||||
if deleted_count > 0:
|
||||
logger.debug(f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024*1024):.2f} MB)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up .bak files in {root_path}: {e}")
|
||||
|
||||
# Yield control periodically
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
if total_deleted > 0:
|
||||
logger.debug(f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024*1024):.2f} MB total")
|
||||
else:
|
||||
logger.debug("Backup cleanup completed: no .bak files found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during backup file cleanup: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
async def _cleanup_backup_files_in_directory(cls, directory_path: str):
|
||||
"""Clean up .bak files in a specific directory recursively
|
||||
|
||||
Args:
|
||||
directory_path: Path to the directory to clean
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: (number of files deleted, total size freed in bytes)
|
||||
"""
|
||||
deleted_count = 0
|
||||
size_freed = 0
|
||||
visited_paths = set()
|
||||
|
||||
def cleanup_recursive(path):
|
||||
nonlocal deleted_count, size_freed
|
||||
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
for entry in it:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.bak'):
|
||||
file_size = entry.stat().st_size
|
||||
os.remove(entry.path)
|
||||
deleted_count += 1
|
||||
size_freed += file_size
|
||||
logger.debug(f"Deleted .bak file: {entry.path}")
|
||||
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
cleanup_recursive(entry.path)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not delete .bak file {entry.path}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory {path} for .bak files: {e}")
|
||||
|
||||
# Run the recursive cleanup in a thread pool to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, cleanup_recursive, directory_path)
|
||||
|
||||
return deleted_count, size_freed
|
||||
|
||||
@classmethod
|
||||
async def _cleanup_example_images_folders(cls):
|
||||
"""Clean up invalid or empty folders in example images directory"""
|
||||
try:
|
||||
example_images_path = settings.get('example_images_path')
|
||||
if not example_images_path or not os.path.exists(example_images_path):
|
||||
logger.debug("Example images path not configured or doesn't exist, skipping cleanup")
|
||||
return
|
||||
|
||||
logger.debug(f"Starting cleanup of example images folders in: {example_images_path}")
|
||||
|
||||
# Get all scanner instances to check hash validity
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
total_folders_checked = 0
|
||||
empty_folders_removed = 0
|
||||
invalid_hash_folders_removed = 0
|
||||
|
||||
# Scan the example images directory
|
||||
try:
|
||||
with os.scandir(example_images_path) as it:
|
||||
for entry in it:
|
||||
if not entry.is_dir(follow_symlinks=False):
|
||||
continue
|
||||
|
||||
folder_name = entry.name
|
||||
folder_path = entry.path
|
||||
total_folders_checked += 1
|
||||
|
||||
try:
|
||||
# Check if folder is empty
|
||||
is_empty = cls._is_folder_empty(folder_path)
|
||||
if is_empty:
|
||||
logger.debug(f"Removing empty example images folder: {folder_name}")
|
||||
await cls._remove_folder_safely(folder_path)
|
||||
empty_folders_removed += 1
|
||||
continue
|
||||
|
||||
# Check if folder name is a valid SHA256 hash (64 hex characters)
|
||||
if len(folder_name) != 64 or not all(c in '0123456789abcdefABCDEF' for c in folder_name):
|
||||
logger.debug(f"Removing invalid hash folder: {folder_name}")
|
||||
await cls._remove_folder_safely(folder_path)
|
||||
invalid_hash_folders_removed += 1
|
||||
continue
|
||||
|
||||
# Check if hash exists in any of the scanners
|
||||
hash_exists = (
|
||||
lora_scanner.has_hash(folder_name) or
|
||||
checkpoint_scanner.has_hash(folder_name) or
|
||||
embedding_scanner.has_hash(folder_name)
|
||||
)
|
||||
|
||||
if not hash_exists:
|
||||
logger.debug(f"Removing example images folder for deleted model: {folder_name}")
|
||||
await cls._remove_folder_safely(folder_path)
|
||||
invalid_hash_folders_removed += 1
|
||||
continue
|
||||
|
||||
logger.debug(f"Keeping valid example images folder: {folder_name}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing example images folder {folder_name}: {e}")
|
||||
|
||||
# Yield control periodically
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning example images directory: {e}")
|
||||
return
|
||||
|
||||
# Log final cleanup report
|
||||
total_removed = empty_folders_removed + invalid_hash_folders_removed
|
||||
if total_removed > 0:
|
||||
logger.info(f"Example images cleanup completed: checked {total_folders_checked} folders, "
|
||||
f"removed {empty_folders_removed} empty folders and {invalid_hash_folders_removed} "
|
||||
f"folders for deleted/invalid models (total: {total_removed} removed)")
|
||||
else:
|
||||
logger.info(f"Example images cleanup completed: checked {total_folders_checked} folders, "
|
||||
f"no cleanup needed")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during example images cleanup: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
def _is_folder_empty(cls, folder_path: str) -> bool:
|
||||
"""Check if a folder is empty
|
||||
|
||||
Args:
|
||||
folder_path: Path to the folder to check
|
||||
|
||||
Returns:
|
||||
bool: True if folder is empty, False otherwise
|
||||
"""
|
||||
try:
|
||||
with os.scandir(folder_path) as it:
|
||||
return not any(it)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error checking if folder is empty {folder_path}: {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def _remove_folder_safely(cls, folder_path: str):
|
||||
"""Safely remove a folder and all its contents
|
||||
|
||||
Args:
|
||||
folder_path: Path to the folder to remove
|
||||
"""
|
||||
try:
|
||||
import shutil
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(None, shutil.rmtree, folder_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to remove folder {folder_path}: {e}")
|
||||
|
||||
@classmethod
|
||||
async def _cleanup(cls, app):
|
||||
"""Cleanup resources using ServiceRegistry"""
|
||||
|
||||
@@ -26,98 +26,179 @@ class MetadataHook:
|
||||
print("Could not locate ComfyUI execution module, metadata collection disabled")
|
||||
return
|
||||
|
||||
# Store the original _map_node_over_list function
|
||||
original_map_node_over_list = execution._map_node_over_list
|
||||
# Detect whether we're using the new async version of ComfyUI
|
||||
is_async = False
|
||||
map_node_func_name = '_map_node_over_list'
|
||||
|
||||
# Define the wrapped _map_node_over_list function
|
||||
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# Only collect metadata when calling the main function of nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record inputs before execution
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Execute the original function
|
||||
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
||||
|
||||
# After execution, collect outputs for relevant nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record outputs after execution
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
# Also hook the execute function to track the current prompt_id
|
||||
original_execute = execution.execute
|
||||
if hasattr(execution, '_async_map_node_over_list'):
|
||||
is_async = inspect.iscoroutinefunction(execution._async_map_node_over_list)
|
||||
map_node_func_name = '_async_map_node_over_list'
|
||||
elif hasattr(execution, '_map_node_over_list'):
|
||||
is_async = inspect.iscoroutinefunction(execution._map_node_over_list)
|
||||
|
||||
def execute_with_prompt_tracking(*args, **kwargs):
|
||||
if len(args) >= 7: # Check if we have enough arguments
|
||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||
registry = MetadataRegistry()
|
||||
|
||||
# Start collection if this is a new prompt
|
||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||
registry.start_collection(prompt_id)
|
||||
|
||||
# Store the dynprompt reference for node lookups
|
||||
if hasattr(prompt, 'original_prompt'):
|
||||
registry.set_current_prompt(prompt)
|
||||
|
||||
# Execute the original function
|
||||
return original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the functions
|
||||
execution._map_node_over_list = map_node_over_list_with_metadata
|
||||
execution.execute = execute_with_prompt_tracking
|
||||
# Make map_node_over_list public to avoid it being hidden by hooks
|
||||
execution.map_node_over_list = original_map_node_over_list
|
||||
if is_async:
|
||||
print("Detected async ComfyUI execution, installing async metadata hooks")
|
||||
MetadataHook._install_async_hooks(execution, map_node_func_name)
|
||||
else:
|
||||
print("Detected sync ComfyUI execution, installing sync metadata hooks")
|
||||
MetadataHook._install_sync_hooks(execution)
|
||||
|
||||
print("Metadata collection hooks installed for runtime values")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error installing metadata hooks: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _install_sync_hooks(execution):
|
||||
"""Install hooks for synchronous execution model"""
|
||||
# Store the original _map_node_over_list function
|
||||
original_map_node_over_list = execution._map_node_over_list
|
||||
|
||||
# Define the wrapped _map_node_over_list function
|
||||
def map_node_over_list_with_metadata(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
|
||||
# Only collect metadata when calling the main function of nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection on GraphBuilder.set_default_prefix
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record inputs before execution
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Execute the original function
|
||||
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
||||
|
||||
# After execution, collect outputs for relevant nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
# Get the current prompt_id from the registry
|
||||
registry = MetadataRegistry()
|
||||
prompt_id = registry.current_prompt_id
|
||||
|
||||
if prompt_id is not None:
|
||||
# Get node class type
|
||||
class_type = obj.__class__.__name__
|
||||
|
||||
# Unique ID might be available through the obj if it has a unique_id field
|
||||
node_id = getattr(obj, 'unique_id', None)
|
||||
if node_id is None and pre_execute_cb:
|
||||
# Try to extract node_id through reflection
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if 'unique_id' in frame.f_locals:
|
||||
node_id = frame.f_locals['unique_id']
|
||||
break
|
||||
frame = frame.f_back
|
||||
|
||||
# Record outputs after execution
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
# Also hook the execute function to track the current prompt_id
|
||||
original_execute = execution.execute
|
||||
|
||||
def execute_with_prompt_tracking(*args, **kwargs):
|
||||
if len(args) >= 7: # Check if we have enough arguments
|
||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||
registry = MetadataRegistry()
|
||||
|
||||
# Start collection if this is a new prompt
|
||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||
registry.start_collection(prompt_id)
|
||||
|
||||
# Store the dynprompt reference for node lookups
|
||||
if hasattr(prompt, 'original_prompt'):
|
||||
registry.set_current_prompt(prompt)
|
||||
|
||||
# Execute the original function
|
||||
return original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the functions
|
||||
execution._map_node_over_list = map_node_over_list_with_metadata
|
||||
execution.execute = execute_with_prompt_tracking
|
||||
|
||||
@staticmethod
|
||||
def _install_async_hooks(execution, map_node_func_name='_async_map_node_over_list'):
|
||||
"""Install hooks for asynchronous execution model"""
|
||||
# Store the original _async_map_node_over_list function
|
||||
original_map_node_over_list = getattr(execution, map_node_func_name)
|
||||
|
||||
# Wrapped async function, compatible with both stable and nightly
|
||||
async def async_map_node_over_list_with_metadata(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, *args, **kwargs):
|
||||
hidden_inputs = kwargs.get('hidden_inputs', None)
|
||||
# Only collect metadata when calling the main function of nodes
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
registry = MetadataRegistry()
|
||||
if prompt_id is not None:
|
||||
class_type = obj.__class__.__name__
|
||||
node_id = unique_id
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Call original function with all args/kwargs
|
||||
results = await original_map_node_over_list(
|
||||
prompt_id, unique_id, obj, input_data_all, func,
|
||||
allow_interrupt, execution_block_cb, pre_execute_cb, *args, **kwargs
|
||||
)
|
||||
|
||||
if func == obj.FUNCTION and hasattr(obj, '__class__'):
|
||||
try:
|
||||
registry = MetadataRegistry()
|
||||
if prompt_id is not None:
|
||||
class_type = obj.__class__.__name__
|
||||
node_id = unique_id
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
# Also hook the execute function to track the current prompt_id
|
||||
original_execute = execution.execute
|
||||
|
||||
async def async_execute_with_prompt_tracking(*args, **kwargs):
|
||||
if len(args) >= 7: # Check if we have enough arguments
|
||||
server, prompt, caches, node_id, extra_data, executed, prompt_id = args[:7]
|
||||
registry = MetadataRegistry()
|
||||
|
||||
# Start collection if this is a new prompt
|
||||
if not registry.current_prompt_id or registry.current_prompt_id != prompt_id:
|
||||
registry.start_collection(prompt_id)
|
||||
|
||||
# Store the dynprompt reference for node lookups
|
||||
if hasattr(prompt, 'original_prompt'):
|
||||
registry.set_current_prompt(prompt)
|
||||
|
||||
# Execute the original function
|
||||
return await original_execute(*args, **kwargs)
|
||||
|
||||
# Replace the functions with async versions
|
||||
setattr(execution, map_node_func_name, async_map_node_over_list_with_metadata)
|
||||
execution.execute = async_execute_with_prompt_tracking
|
||||
|
||||
@@ -238,25 +238,45 @@ class MetadataProcessor:
|
||||
|
||||
pos_conditioning = metadata[PROMPTS][sampler_id].get("pos_conditioning")
|
||||
neg_conditioning = metadata[PROMPTS][sampler_id].get("neg_conditioning")
|
||||
|
||||
# Try to match conditioning objects with those stored by CLIPTextEncodeExtractor
|
||||
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
|
||||
# For nodes with single conditioning output
|
||||
if "conditioning" in prompt_data:
|
||||
if pos_conditioning is not None and id(prompt_data["conditioning"]) == id(pos_conditioning):
|
||||
result["prompt"] = prompt_data.get("text", "")
|
||||
|
||||
# Helper function to recursively find prompt text for a conditioning object
|
||||
def find_prompt_text_for_conditioning(conditioning_obj, is_positive=True):
|
||||
if conditioning_obj is None:
|
||||
return ""
|
||||
|
||||
if neg_conditioning is not None and id(prompt_data["conditioning"]) == id(neg_conditioning):
|
||||
result["negative_prompt"] = prompt_data.get("text", "")
|
||||
# Try to match conditioning objects with those stored by extractors
|
||||
for prompt_node_id, prompt_data in metadata[PROMPTS].items():
|
||||
# For nodes with single conditioning output
|
||||
if "conditioning" in prompt_data:
|
||||
if id(prompt_data["conditioning"]) == id(conditioning_obj):
|
||||
return prompt_data.get("text", "")
|
||||
|
||||
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
|
||||
if is_positive and "positive_encoded" in prompt_data:
|
||||
if id(prompt_data["positive_encoded"]) == id(conditioning_obj):
|
||||
if "positive_text" in prompt_data:
|
||||
return prompt_data["positive_text"]
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_pos_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=True)
|
||||
|
||||
if not is_positive and "negative_encoded" in prompt_data:
|
||||
if id(prompt_data["negative_encoded"]) == id(conditioning_obj):
|
||||
if "negative_text" in prompt_data:
|
||||
return prompt_data["negative_text"]
|
||||
else:
|
||||
orig_conditioning = prompt_data.get("orig_neg_cond", None)
|
||||
if orig_conditioning is not None:
|
||||
# Recursively find the prompt text for the original conditioning
|
||||
return find_prompt_text_for_conditioning(orig_conditioning, is_positive=False)
|
||||
|
||||
# For nodes with separate pos_conditioning and neg_conditioning outputs (like TSC_EfficientLoader)
|
||||
if "positive_encoded" in prompt_data:
|
||||
if pos_conditioning is not None and id(prompt_data["positive_encoded"]) == id(pos_conditioning):
|
||||
result["prompt"] = prompt_data.get("positive_text", "")
|
||||
|
||||
if "negative_encoded" in prompt_data:
|
||||
if neg_conditioning is not None and id(prompt_data["negative_encoded"]) == id(neg_conditioning):
|
||||
result["negative_prompt"] = prompt_data.get("negative_text", "")
|
||||
return ""
|
||||
|
||||
# Find prompt texts using the helper function
|
||||
result["prompt"] = find_prompt_text_for_conditioning(pos_conditioning, is_positive=True)
|
||||
result["negative_prompt"] = find_prompt_text_for_conditioning(neg_conditioning, is_positive=False)
|
||||
|
||||
return result
|
||||
|
||||
@@ -319,44 +339,8 @@ class MetadataProcessor:
|
||||
is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced"
|
||||
|
||||
if is_custom_advanced:
|
||||
# For SamplerCustomAdvanced, trace specific inputs
|
||||
|
||||
# 1. Trace sigmas input to find BasicScheduler
|
||||
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", "BasicScheduler", max_depth=5)
|
||||
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
|
||||
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
|
||||
params["steps"] = scheduler_params.get("steps")
|
||||
params["scheduler"] = scheduler_params.get("scheduler")
|
||||
|
||||
# 2. Trace sampler input to find KSamplerSelect
|
||||
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
|
||||
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
|
||||
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
|
||||
params["sampler"] = sampler_params.get("sampler_name")
|
||||
|
||||
# 3. Trace guider input for CFGGuider and CLIPTextEncode
|
||||
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
|
||||
if guider_node_id and guider_node_id in prompt.original_prompt:
|
||||
# Check if the guider node is a CFGGuider
|
||||
if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider":
|
||||
# Extract cfg value from the CFGGuider
|
||||
if guider_node_id in metadata.get(SAMPLING, {}):
|
||||
cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {})
|
||||
params["cfg_scale"] = cfg_params.get("cfg")
|
||||
|
||||
# Find CLIPTextEncode for positive prompt
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
|
||||
# Find CLIPTextEncode for negative prompt
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
else:
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
# For SamplerCustomAdvanced, use the new handler method
|
||||
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
|
||||
|
||||
else:
|
||||
# For standard samplers, match conditioning objects to prompts
|
||||
@@ -381,6 +365,9 @@ class MetadataProcessor:
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "negative", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
|
||||
# For SamplerCustom, handle any additional parameters
|
||||
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
|
||||
|
||||
# Size extraction is same for all sampler types
|
||||
# Check if the sampler itself has size information (from latent_image)
|
||||
@@ -434,3 +421,59 @@ class MetadataProcessor:
|
||||
"""Convert metadata to JSON string"""
|
||||
params = MetadataProcessor.to_dict(metadata, id)
|
||||
return json.dumps(params, indent=4)
|
||||
|
||||
@staticmethod
|
||||
def handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params):
|
||||
"""
|
||||
Handle parameter extraction for SamplerCustomAdvanced nodes
|
||||
|
||||
Parameters:
|
||||
- metadata: The workflow metadata
|
||||
- prompt: The prompt object containing node connections
|
||||
- primary_sampler_id: ID of the SamplerCustomAdvanced node
|
||||
- params: Parameters dictionary to update
|
||||
"""
|
||||
if not prompt.original_prompt or primary_sampler_id not in prompt.original_prompt:
|
||||
return
|
||||
|
||||
sampler_inputs = prompt.original_prompt[primary_sampler_id].get("inputs", {})
|
||||
|
||||
# 1. Trace sigmas input to find BasicScheduler (only if sigmas input exists)
|
||||
if "sigmas" in sampler_inputs:
|
||||
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", None, max_depth=5)
|
||||
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
|
||||
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
|
||||
params["steps"] = scheduler_params.get("steps")
|
||||
params["scheduler"] = scheduler_params.get("scheduler")
|
||||
|
||||
# 2. Trace sampler input to find KSamplerSelect (only if sampler input exists)
|
||||
if "sampler" in sampler_inputs:
|
||||
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
|
||||
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
|
||||
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
|
||||
params["sampler"] = sampler_params.get("sampler_name")
|
||||
|
||||
# 3. Trace guider input for CFGGuider and CLIPTextEncode
|
||||
if "guider" in sampler_inputs:
|
||||
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
|
||||
if guider_node_id and guider_node_id in prompt.original_prompt:
|
||||
# Check if the guider node is a CFGGuider
|
||||
if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider":
|
||||
# Extract cfg value from the CFGGuider
|
||||
if guider_node_id in metadata.get(SAMPLING, {}):
|
||||
cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {})
|
||||
params["cfg_scale"] = cfg_params.get("cfg")
|
||||
|
||||
# Find CLIPTextEncode for positive prompt
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
|
||||
# Find CLIPTextEncode for negative prompt
|
||||
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10)
|
||||
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
|
||||
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
|
||||
else:
|
||||
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
|
||||
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
|
||||
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
|
||||
|
||||
@@ -117,15 +117,15 @@ class CLIPTextEncodeExtractor(NodeMetadataExtractor):
|
||||
if isinstance(outputs[0], tuple) and len(outputs[0]) > 0:
|
||||
conditioning = outputs[0][0]
|
||||
metadata[PROMPTS][node_id]["conditioning"] = conditioning
|
||||
|
||||
class SamplerExtractor(NodeMetadataExtractor):
|
||||
|
||||
# Base Sampler Extractor to reduce code redundancy
|
||||
class BaseSamplerExtractor(NodeMetadataExtractor):
|
||||
"""Base extractor for sampler nodes with common functionality"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
def extract_sampling_params(node_id, inputs, metadata, param_keys):
|
||||
"""Extract sampling parameters from inputs"""
|
||||
sampling_params = {}
|
||||
for key in ["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]:
|
||||
for key in param_keys:
|
||||
if key in inputs:
|
||||
sampling_params[key] = inputs[key]
|
||||
|
||||
@@ -134,7 +134,10 @@ class SamplerExtractor(NodeMetadataExtractor):
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: True # Add sampler flag
|
||||
}
|
||||
|
||||
|
||||
@staticmethod
|
||||
def extract_conditioning(node_id, inputs, metadata):
|
||||
"""Extract conditioning objects from inputs"""
|
||||
# Store the conditioning objects directly in metadata for later matching
|
||||
pos_conditioning = inputs.get("positive", None)
|
||||
neg_conditioning = inputs.get("negative", None)
|
||||
@@ -146,7 +149,10 @@ class SamplerExtractor(NodeMetadataExtractor):
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
|
||||
@staticmethod
|
||||
def extract_latent_dimensions(node_id, inputs, metadata):
|
||||
"""Extract dimensions from latent image"""
|
||||
# Extract latent image dimensions if available
|
||||
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||
latent = inputs["latent_image"]
|
||||
@@ -167,59 +173,106 @@ class SamplerExtractor(NodeMetadataExtractor):
|
||||
"height": height,
|
||||
"node_id": node_id
|
||||
}
|
||||
|
||||
class KSamplerAdvancedExtractor(NodeMetadataExtractor):
|
||||
|
||||
class SamplerExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
sampling_params = {}
|
||||
for key in ["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]:
|
||||
if key in inputs:
|
||||
sampling_params[key] = inputs[key]
|
||||
|
||||
metadata[SAMPLING][node_id] = {
|
||||
"parameters": sampling_params,
|
||||
"node_id": node_id,
|
||||
IS_SAMPLER: True # Add sampler flag
|
||||
}
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]
|
||||
)
|
||||
|
||||
# Store the conditioning objects directly in metadata for later matching
|
||||
pos_conditioning = inputs.get("positive", None)
|
||||
neg_conditioning = inputs.get("negative", None)
|
||||
# Extract conditioning objects
|
||||
BaseSamplerExtractor.extract_conditioning(node_id, inputs, metadata)
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
# Save conditioning objects in metadata for later matching
|
||||
if pos_conditioning is not None or neg_conditioning is not None:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
class KSamplerAdvancedExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]
|
||||
)
|
||||
|
||||
# Extract latent image dimensions if available
|
||||
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||
latent = inputs["latent_image"]
|
||||
if isinstance(latent, dict) and "samples" in latent:
|
||||
# Extract dimensions from latent tensor
|
||||
samples = latent["samples"]
|
||||
if hasattr(samples, "shape") and len(samples.shape) >= 3:
|
||||
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
|
||||
# Multiply by 8 to get actual pixel dimensions
|
||||
height = int(samples.shape[2] * 8)
|
||||
width = int(samples.shape[3] * 8)
|
||||
|
||||
if SIZE not in metadata:
|
||||
metadata[SIZE] = {}
|
||||
|
||||
metadata[SIZE][node_id] = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"node_id": node_id
|
||||
}
|
||||
# Extract conditioning objects
|
||||
BaseSamplerExtractor.extract_conditioning(node_id, inputs, metadata)
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class KSamplerBasicPipeExtractor(BaseSamplerExtractor):
|
||||
"""Extractor for KSamplerBasicPipe and KSampler_inspire_pipe nodes"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["seed", "steps", "cfg", "sampler_name", "scheduler", "denoise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects from basic_pipe
|
||||
if "basic_pipe" in inputs and inputs["basic_pipe"] is not None:
|
||||
basic_pipe = inputs["basic_pipe"]
|
||||
# Typically, basic_pipe structure is (model, clip, vae, positive, negative)
|
||||
if isinstance(basic_pipe, tuple) and len(basic_pipe) >= 5:
|
||||
pos_conditioning = basic_pipe[3] # positive is at index 3
|
||||
neg_conditioning = basic_pipe[4] # negative is at index 4
|
||||
|
||||
# Save conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class KSamplerAdvancedBasicPipeExtractor(BaseSamplerExtractor):
|
||||
"""Extractor for KSamplerAdvancedBasicPipe nodes"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Extract common sampling parameters
|
||||
BaseSamplerExtractor.extract_sampling_params(
|
||||
node_id, inputs, metadata,
|
||||
["noise_seed", "steps", "cfg", "sampler_name", "scheduler", "add_noise"]
|
||||
)
|
||||
|
||||
# Extract conditioning objects from basic_pipe
|
||||
if "basic_pipe" in inputs and inputs["basic_pipe"] is not None:
|
||||
basic_pipe = inputs["basic_pipe"]
|
||||
# Typically, basic_pipe structure is (model, clip, vae, positive, negative)
|
||||
if isinstance(basic_pipe, tuple) and len(basic_pipe) >= 5:
|
||||
pos_conditioning = basic_pipe[3] # positive is at index 3
|
||||
neg_conditioning = basic_pipe[4] # negative is at index 4
|
||||
|
||||
# Save conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["pos_conditioning"] = pos_conditioning
|
||||
metadata[PROMPTS][node_id]["neg_conditioning"] = neg_conditioning
|
||||
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
class TSCSamplerBaseExtractor(NodeMetadataExtractor):
|
||||
"""Base extractor for handling TSC sampler node outputs"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Store vae_decode setting for later use in update
|
||||
@@ -273,7 +326,6 @@ class TSCSamplerBaseExtractor(NodeMetadataExtractor):
|
||||
metadata[IMAGES]["first_decode"] = metadata[IMAGES][node_id]
|
||||
|
||||
class TSCKSamplerExtractor(SamplerExtractor, TSCSamplerBaseExtractor):
|
||||
"""Extractor for TSC_KSampler nodes"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Call parent extract methods
|
||||
@@ -284,11 +336,10 @@ class TSCKSamplerExtractor(SamplerExtractor, TSCSamplerBaseExtractor):
|
||||
|
||||
|
||||
class TSCKSamplerAdvancedExtractor(KSamplerAdvancedExtractor, TSCSamplerBaseExtractor):
|
||||
"""Extractor for TSC_KSamplerAdvanced nodes"""
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
# Call parent extract methods
|
||||
SamplerExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
KSamplerAdvancedExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
TSCSamplerBaseExtractor.extract(node_id, inputs, outputs, metadata)
|
||||
|
||||
# Update method is inherited from TSCSamplerBaseExtractor
|
||||
@@ -461,7 +512,7 @@ class BasicSchedulerExtractor(NodeMetadataExtractor):
|
||||
IS_SAMPLER: False # Mark as non-primary sampler
|
||||
}
|
||||
|
||||
class SamplerCustomAdvancedExtractor(NodeMetadataExtractor):
|
||||
class SamplerCustomAdvancedExtractor(BaseSamplerExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
@@ -480,26 +531,8 @@ class SamplerCustomAdvancedExtractor(NodeMetadataExtractor):
|
||||
IS_SAMPLER: True # Add sampler flag
|
||||
}
|
||||
|
||||
# Extract latent image dimensions if available
|
||||
if "latent_image" in inputs and inputs["latent_image"] is not None:
|
||||
latent = inputs["latent_image"]
|
||||
if isinstance(latent, dict) and "samples" in latent:
|
||||
# Extract dimensions from latent tensor
|
||||
samples = latent["samples"]
|
||||
if hasattr(samples, "shape") and len(samples.shape) >= 3:
|
||||
# Correct shape interpretation: [batch_size, channels, height/8, width/8]
|
||||
# Multiply by 8 to get actual pixel dimensions
|
||||
height = int(samples.shape[2] * 8)
|
||||
width = int(samples.shape[3] * 8)
|
||||
|
||||
if SIZE not in metadata:
|
||||
metadata[SIZE] = {}
|
||||
|
||||
metadata[SIZE][node_id] = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"node_id": node_id
|
||||
}
|
||||
# Extract latent dimensions
|
||||
BaseSamplerExtractor.extract_latent_dimensions(node_id, inputs, metadata)
|
||||
|
||||
import json
|
||||
|
||||
@@ -569,21 +602,63 @@ class CFGGuiderExtractor(NodeMetadataExtractor):
|
||||
|
||||
metadata[SAMPLING][node_id]["parameters"]["cfg"] = cfg_value
|
||||
|
||||
class CR_ApplyControlNetStackExtractor(NodeMetadataExtractor):
|
||||
@staticmethod
|
||||
def extract(node_id, inputs, outputs, metadata):
|
||||
if not inputs:
|
||||
return
|
||||
|
||||
# Save the original conditioning inputs
|
||||
base_positive = inputs.get("base_positive")
|
||||
base_negative = inputs.get("base_negative")
|
||||
|
||||
if base_positive is not None or base_negative is not None:
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["orig_pos_cond"] = base_positive
|
||||
metadata[PROMPTS][node_id]["orig_neg_cond"] = base_negative
|
||||
|
||||
@staticmethod
|
||||
def update(node_id, outputs, metadata):
|
||||
# Extract transformed conditionings from outputs
|
||||
# outputs structure: [(base_positive, base_negative, show_help, )]
|
||||
if outputs and isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, tuple) and len(first_output) >= 2:
|
||||
transformed_positive = first_output[0]
|
||||
transformed_negative = first_output[1]
|
||||
|
||||
# Save transformed conditioning objects in metadata
|
||||
if node_id not in metadata[PROMPTS]:
|
||||
metadata[PROMPTS][node_id] = {"node_id": node_id}
|
||||
|
||||
metadata[PROMPTS][node_id]["positive_encoded"] = transformed_positive
|
||||
metadata[PROMPTS][node_id]["negative_encoded"] = transformed_negative
|
||||
|
||||
# Registry of node-specific extractors
|
||||
# Keys are node class names
|
||||
NODE_EXTRACTORS = {
|
||||
# Sampling
|
||||
"KSampler": SamplerExtractor,
|
||||
"KSamplerAdvanced": KSamplerAdvancedExtractor,
|
||||
"SamplerCustom": KSamplerAdvancedExtractor,
|
||||
"SamplerCustomAdvanced": SamplerCustomAdvancedExtractor,
|
||||
"ClownsharKSampler_Beta": SamplerExtractor,
|
||||
"TSC_KSampler": TSCKSamplerExtractor, # Efficient Nodes
|
||||
"TSC_KSamplerAdvanced": TSCKSamplerAdvancedExtractor, # Efficient Nodes
|
||||
"KSamplerBasicPipe": KSamplerBasicPipeExtractor, # comfyui-impact-pack
|
||||
"KSamplerAdvancedBasicPipe": KSamplerAdvancedBasicPipeExtractor, # comfyui-impact-pack
|
||||
"KSampler_inspire_pipe": KSamplerBasicPipeExtractor, # comfyui-inspire-pack
|
||||
"KSamplerAdvanced_inspire_pipe": KSamplerAdvancedBasicPipeExtractor, # comfyui-inspire-pack
|
||||
# Sampling Selectors
|
||||
"KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect
|
||||
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
|
||||
"AlignYourStepsScheduler": BasicSchedulerExtractor, # Add AlignYourStepsScheduler
|
||||
# Loaders
|
||||
"CheckpointLoaderSimple": CheckpointLoaderExtractor,
|
||||
"comfyLoader": CheckpointLoaderExtractor, # easy comfyLoader
|
||||
"CheckpointLoaderSimpleWithImages": CheckpointLoaderExtractor, # CheckpointLoader|pysssss
|
||||
"TSC_EfficientLoader": TSCCheckpointLoaderExtractor, # Efficient Nodes
|
||||
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
|
||||
@@ -594,6 +669,8 @@ NODE_EXTRACTORS = {
|
||||
"CLIPTextEncodeFlux": CLIPTextEncodeFluxExtractor, # Add CLIPTextEncodeFlux
|
||||
"WAS_Text_to_Conditioning": CLIPTextEncodeExtractor,
|
||||
"AdvancedCLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/BlenderNeko/ComfyUI_ADV_CLIP_emb
|
||||
"smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes
|
||||
"CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack
|
||||
# Latent
|
||||
"EmptyLatentImage": ImageSizeExtractor,
|
||||
# Flux
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import logging
|
||||
import re
|
||||
from nodes import LoraLoader
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
import asyncio
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list, nunchaku_load_lora
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraManagerLoader:
|
||||
NAME = "Lora Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
@@ -17,7 +18,8 @@ class LoraManagerLoader:
|
||||
"model": ("MODEL",),
|
||||
# "clip": ("CLIP",),
|
||||
"text": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"multiline": True,
|
||||
"pysssss.autocomplete": False,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||
@@ -37,19 +39,39 @@ class LoraManagerLoader:
|
||||
|
||||
clip = kwargs.get('clip', None)
|
||||
lora_stack = kwargs.get('lora_stack', None)
|
||||
|
||||
# Check if model is a Nunchaku Flux model - simplified approach
|
||||
is_nunchaku_model = False
|
||||
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
# Check if model is a Nunchaku Flux model using only class name
|
||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||
is_nunchaku_model = True
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
except (AttributeError, TypeError):
|
||||
# Not a model with the expected structure
|
||||
pass
|
||||
|
||||
# First process lora_stack if available
|
||||
if lora_stack:
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Apply the LoRA using the provided path and strengths
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# Use our custom function for Flux models
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged for Nunchaku models
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Extract lora name for trigger words lookup
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
all_trigger_words.extend(trigger_words)
|
||||
# Add clip strength to output if different from model strength
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
@@ -66,13 +88,19 @@ class LoraManagerLoader:
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Apply the LoRA using the resolved path with separate strengths
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# For Nunchaku models, use our custom function
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Include clip strength in output if different from model strength
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
@@ -102,4 +130,142 @@ class LoraManagerLoader:
|
||||
|
||||
formatted_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||
|
||||
class LoraManagerTextLoader:
|
||||
NAME = "LoRA Text Loader (LoraManager)"
|
||||
CATEGORY = "Lora Manager/loaders"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"lora_syntax": (IO.STRING, {
|
||||
"defaultInput": True,
|
||||
"forceInput": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation"
|
||||
}),
|
||||
},
|
||||
"optional": {
|
||||
"clip": ("CLIP",),
|
||||
"lora_stack": ("LORA_STACK",),
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING)
|
||||
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
|
||||
FUNCTION = "load_loras_from_text"
|
||||
|
||||
def parse_lora_syntax(self, text):
|
||||
"""Parse LoRA syntax from text input."""
|
||||
# Pattern to match <lora:name:strength> or <lora:name:model_strength:clip_strength>
|
||||
pattern = r'<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>'
|
||||
matches = re.findall(pattern, text, re.IGNORECASE)
|
||||
|
||||
loras = []
|
||||
for match in matches:
|
||||
lora_name = match[0].strip()
|
||||
model_strength = float(match[1])
|
||||
clip_strength = float(match[2]) if match[2] else model_strength
|
||||
|
||||
loras.append({
|
||||
'name': lora_name,
|
||||
'model_strength': model_strength,
|
||||
'clip_strength': clip_strength
|
||||
})
|
||||
|
||||
return loras
|
||||
|
||||
def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None):
|
||||
"""Load LoRAs based on text syntax input."""
|
||||
loaded_loras = []
|
||||
all_trigger_words = []
|
||||
|
||||
# Check if model is a Nunchaku Flux model - simplified approach
|
||||
is_nunchaku_model = False
|
||||
|
||||
try:
|
||||
model_wrapper = model.model.diffusion_model
|
||||
# Check if model is a Nunchaku Flux model using only class name
|
||||
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
|
||||
is_nunchaku_model = True
|
||||
logger.info("Detected Nunchaku Flux model")
|
||||
except (AttributeError, TypeError):
|
||||
# Not a model with the expected structure
|
||||
pass
|
||||
|
||||
# First process lora_stack if available
|
||||
if lora_stack:
|
||||
for lora_path, model_strength, clip_strength in lora_stack:
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# Use our custom function for Flux models
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged for Nunchaku models
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Extract lora name for trigger words lookup
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
all_trigger_words.extend(trigger_words)
|
||||
# Add clip strength to output if different from model strength (except for Nunchaku models)
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Parse and process LoRAs from text syntax
|
||||
parsed_loras = self.parse_lora_syntax(lora_syntax)
|
||||
for lora in parsed_loras:
|
||||
lora_name = lora['name']
|
||||
model_strength = lora['model_strength']
|
||||
clip_strength = lora['clip_strength']
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Apply the LoRA using the appropriate loader
|
||||
if is_nunchaku_model:
|
||||
# For Nunchaku models, use our custom function
|
||||
model = nunchaku_load_lora(model, lora_path, model_strength)
|
||||
# clip remains unchanged
|
||||
else:
|
||||
# Use default loader for standard models
|
||||
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
|
||||
|
||||
# Include clip strength in output if different from model strength and not a Nunchaku model
|
||||
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
|
||||
else:
|
||||
loaded_loras.append(f"{lora_name}: {model_strength}")
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# use ',, ' to separate trigger words for group mode
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format loaded_loras with support for both formats
|
||||
formatted_loras = []
|
||||
for item in loaded_loras:
|
||||
parts = item.split(":")
|
||||
lora_name = parts[0].strip()
|
||||
strength_parts = parts[1].strip().split(",")
|
||||
|
||||
if len(strength_parts) > 1:
|
||||
# Different model and clip strengths
|
||||
model_str = strength_parts[0].strip()
|
||||
clip_str = strength_parts[1].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}:{clip_str}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
model_str = strength_parts[0].strip()
|
||||
formatted_loras.append(f"<lora:{lora_name}:{model_str}>")
|
||||
|
||||
formatted_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (model, clip, trigger_words_text, formatted_loras_text)
|
||||
@@ -1,9 +1,8 @@
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
import asyncio
|
||||
import os
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_lora_info, extract_lora_name, get_loras_list
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, extract_lora_name, get_loras_list
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -18,6 +17,7 @@ class LoraStacker:
|
||||
"required": {
|
||||
"text": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"pysssss.autocomplete": False,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||
@@ -43,7 +43,7 @@ class LoraStacker:
|
||||
# Get trigger words from existing stack entries
|
||||
for lora_path, _, _ in lora_stack:
|
||||
lora_name = extract_lora_name(lora_path)
|
||||
_, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Process loras from kwargs with support for both old and new formats
|
||||
@@ -58,7 +58,7 @@ class LoraStacker:
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = asyncio.run(get_lora_info(lora_name))
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Add to stack without loading
|
||||
# replace '/' with os.sep to avoid different OS path format
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
import asyncio
|
||||
import re
|
||||
import numpy as np
|
||||
import folder_paths # type: ignore
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..services.checkpoint_scanner import CheckpointScanner
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||
from ..metadata_collector import get_metadata
|
||||
from PIL import Image, PngImagePlugin
|
||||
@@ -71,25 +69,20 @@ class SaveImage:
|
||||
FUNCTION = "process_image"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
async def get_lora_hash(self, lora_name):
|
||||
def get_lora_hash(self, lora_name):
|
||||
"""Get the lora hash from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
scanner = ServiceRegistry.get_service_sync("lora_scanner")
|
||||
|
||||
# Use the new direct filename lookup method
|
||||
hash_value = scanner.get_hash_by_filename(lora_name)
|
||||
if hash_value:
|
||||
return hash_value
|
||||
|
||||
# Fallback to old method for compatibility
|
||||
cache = await scanner.get_cached_data()
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
return item.get('sha256')
|
||||
return None
|
||||
|
||||
async def get_checkpoint_hash(self, checkpoint_path):
|
||||
def get_checkpoint_hash(self, checkpoint_path):
|
||||
"""Get the checkpoint hash from cache"""
|
||||
scanner = await CheckpointScanner.get_instance()
|
||||
scanner = ServiceRegistry.get_service_sync("checkpoint_scanner")
|
||||
|
||||
if not checkpoint_path:
|
||||
return None
|
||||
@@ -102,18 +95,10 @@ class SaveImage:
|
||||
hash_value = scanner.get_hash_by_filename(checkpoint_name)
|
||||
if hash_value:
|
||||
return hash_value
|
||||
|
||||
# Fallback to old method for compatibility
|
||||
cache = await scanner.get_cached_data()
|
||||
normalized_path = checkpoint_path.replace('\\', '/')
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == checkpoint_name and item.get('file_path').endswith(normalized_path):
|
||||
return item.get('sha256')
|
||||
|
||||
return None
|
||||
|
||||
async def format_metadata(self, metadata_dict):
|
||||
def format_metadata(self, metadata_dict):
|
||||
"""Format metadata in the requested format similar to userComment example"""
|
||||
if not metadata_dict:
|
||||
return ""
|
||||
@@ -140,7 +125,7 @@ class SaveImage:
|
||||
|
||||
# Get hash for each lora
|
||||
for lora_name, strength in lora_matches:
|
||||
hash_value = await self.get_lora_hash(lora_name)
|
||||
hash_value = self.get_lora_hash(lora_name)
|
||||
if hash_value:
|
||||
lora_hashes[lora_name] = hash_value
|
||||
else:
|
||||
@@ -226,7 +211,7 @@ class SaveImage:
|
||||
checkpoint = metadata_dict.get('checkpoint')
|
||||
if checkpoint is not None:
|
||||
# Get model hash
|
||||
model_hash = await self.get_checkpoint_hash(checkpoint)
|
||||
model_hash = self.get_checkpoint_hash(checkpoint)
|
||||
|
||||
# Extract basename without path
|
||||
checkpoint_name = os.path.basename(checkpoint)
|
||||
@@ -329,8 +314,7 @@ class SaveImage:
|
||||
raw_metadata = get_metadata()
|
||||
metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)
|
||||
|
||||
# Get or create metadata asynchronously
|
||||
metadata = asyncio.run(self.format_metadata(metadata_dict))
|
||||
metadata = self.format_metadata(metadata_dict)
|
||||
|
||||
# Process filename_prefix with pattern substitution
|
||||
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
|
||||
@@ -434,11 +418,15 @@ class SaveImage:
|
||||
# Make sure the output directory exists
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
|
||||
# Ensure images is always a list of images
|
||||
if len(images.shape) == 3: # Single image (height, width, channels)
|
||||
images = [images]
|
||||
else: # Multiple images (batch, height, width, channels)
|
||||
images = [img for img in images]
|
||||
# If images is already a list or array of images, do nothing; otherwise, convert to list
|
||||
if isinstance(images, (list, np.ndarray)):
|
||||
pass
|
||||
else:
|
||||
# Ensure images is always a list of images
|
||||
if len(images.shape) == 3: # Single image (height, width, channels)
|
||||
images = [images]
|
||||
else: # Multiple images (batch, height, width, channels)
|
||||
images = [img for img in images]
|
||||
|
||||
# Save all images
|
||||
results = self.save_images(
|
||||
|
||||
@@ -35,31 +35,11 @@ any_type = AnyType("*")
|
||||
# Common methods extracted from lora_loader.py and lora_stacker.py
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from ..services.lora_scanner import LoraScanner
|
||||
from ..config import config
|
||||
import copy
|
||||
import folder_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_lora_info(lora_name):
|
||||
"""Get the lora path and trigger words from cache"""
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
file_path = item.get('file_path')
|
||||
if file_path:
|
||||
for root in config.loras_roots:
|
||||
root = root.replace(os.sep, '/')
|
||||
if file_path.startswith(root):
|
||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||
# Get trigger words from civitai metadata
|
||||
civitai = item.get('civitai', {})
|
||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||
return relative_path, trigger_words
|
||||
return lora_name, [] # Fallback if not found
|
||||
|
||||
def extract_lora_name(lora_path):
|
||||
"""Extract the lora name from a lora path (e.g., 'IL\\aorunIllstrious.safetensors' -> 'aorunIllstrious')"""
|
||||
# Get the basename without extension
|
||||
@@ -81,4 +61,70 @@ def get_loras_list(kwargs):
|
||||
# Unexpected format
|
||||
else:
|
||||
logger.warning(f"Unexpected loras format: {type(loras_data)}")
|
||||
return []
|
||||
return []
|
||||
|
||||
def load_state_dict_in_safetensors(path, device="cpu", filter_prefix=""):
|
||||
"""Simplified version of load_state_dict_in_safetensors that just loads from a local path"""
|
||||
import safetensors.torch
|
||||
|
||||
state_dict = {}
|
||||
with safetensors.torch.safe_open(path, framework="pt", device=device) as f:
|
||||
for k in f.keys():
|
||||
if filter_prefix and not k.startswith(filter_prefix):
|
||||
continue
|
||||
state_dict[k.removeprefix(filter_prefix)] = f.get_tensor(k)
|
||||
return state_dict
|
||||
|
||||
def to_diffusers(input_lora):
|
||||
"""Simplified version of to_diffusers for Flux LoRA conversion"""
|
||||
import torch
|
||||
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
|
||||
from diffusers.loaders import FluxLoraLoaderMixin
|
||||
|
||||
if isinstance(input_lora, str):
|
||||
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
|
||||
else:
|
||||
tensors = {k: v for k, v in input_lora.items()}
|
||||
|
||||
# Convert FP8 tensors to BF16
|
||||
for k, v in tensors.items():
|
||||
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
|
||||
tensors[k] = v.to(torch.bfloat16)
|
||||
|
||||
new_tensors = FluxLoraLoaderMixin.lora_state_dict(tensors)
|
||||
new_tensors = convert_unet_state_dict_to_peft(new_tensors)
|
||||
|
||||
return new_tensors
|
||||
|
||||
def nunchaku_load_lora(model, lora_name, lora_strength):
|
||||
"""Load a Flux LoRA for Nunchaku model"""
|
||||
model_wrapper = model.model.diffusion_model
|
||||
transformer = model_wrapper.model
|
||||
|
||||
# Save the transformer temporarily
|
||||
model_wrapper.model = None
|
||||
ret_model = copy.deepcopy(model) # copy everything except the model
|
||||
ret_model_wrapper = ret_model.model.diffusion_model
|
||||
|
||||
# Restore the model and set it for the copy
|
||||
model_wrapper.model = transformer
|
||||
ret_model_wrapper.model = transformer
|
||||
|
||||
# Get full path to the LoRA file
|
||||
lora_path = folder_paths.get_full_path("loras", lora_name)
|
||||
ret_model_wrapper.loras.append((lora_path, lora_strength))
|
||||
|
||||
# Convert the LoRA to diffusers format
|
||||
sd = to_diffusers(lora_path)
|
||||
|
||||
# Handle embedding adjustment if needed
|
||||
if "transformer.x_embedder.lora_A.weight" in sd:
|
||||
new_in_channels = sd["transformer.x_embedder.lora_A.weight"].shape[1]
|
||||
assert new_in_channels % 4 == 0
|
||||
new_in_channels = new_in_channels // 4
|
||||
|
||||
old_in_channels = ret_model.model.model_config.unet_config["in_channels"]
|
||||
if old_in_channels < new_in_channels:
|
||||
ret_model.model.model_config.unet_config["in_channels"] = new_in_channels
|
||||
|
||||
return ret_model
|
||||
98
py/nodes/wanvideo_lora_select.py
Normal file
98
py/nodes/wanvideo_lora_select.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from comfy.comfy_types import IO # type: ignore
|
||||
import folder_paths # type: ignore
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import FlexibleOptionalInputType, any_type, get_loras_list
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class WanVideoLoraSelect:
|
||||
NAME = "WanVideo Lora Select (LoraManager)"
|
||||
CATEGORY = "Lora Manager/stackers"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load LORA models with less VRAM usage, slower loading. This affects ALL LoRAs, not just the current ones. No effect if merge_loras is False"}),
|
||||
"merge_loras": ("BOOLEAN", {"default": True, "tooltip": "Merge LoRAs into the model, otherwise they are loaded on the fly. Always disabled for GGUF and scaled fp8 models. This affects ALL LoRAs, not just the current one"}),
|
||||
"text": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"pysssss.autocomplete": False,
|
||||
"dynamicPrompts": True,
|
||||
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
|
||||
"placeholder": "LoRA syntax input: <lora:name:strength>"
|
||||
}),
|
||||
},
|
||||
"optional": FlexibleOptionalInputType(any_type),
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("WANVIDLORA", IO.STRING, IO.STRING)
|
||||
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
|
||||
FUNCTION = "process_loras"
|
||||
|
||||
def process_loras(self, text, low_mem_load=False, merge_loras=True, **kwargs):
|
||||
loras_list = []
|
||||
all_trigger_words = []
|
||||
active_loras = []
|
||||
|
||||
# Process existing prev_lora if available
|
||||
prev_lora = kwargs.get('prev_lora', None)
|
||||
if prev_lora is not None:
|
||||
loras_list.extend(prev_lora)
|
||||
|
||||
if not merge_loras:
|
||||
low_mem_load = False # Unmerged LoRAs don't need low_mem_load
|
||||
|
||||
# Get blocks if available
|
||||
blocks = kwargs.get('blocks', {})
|
||||
selected_blocks = blocks.get("selected_blocks", {})
|
||||
layer_filter = blocks.get("layer_filter", "")
|
||||
|
||||
# Process loras from kwargs with support for both old and new formats
|
||||
loras_from_widget = get_loras_list(kwargs)
|
||||
for lora in loras_from_widget:
|
||||
if not lora.get('active', False):
|
||||
continue
|
||||
|
||||
lora_name = lora['name']
|
||||
model_strength = float(lora['strength'])
|
||||
clip_strength = float(lora.get('clipStrength', model_strength))
|
||||
|
||||
# Get lora path and trigger words
|
||||
lora_path, trigger_words = get_lora_info(lora_name)
|
||||
|
||||
# Create lora item for WanVideo format
|
||||
lora_item = {
|
||||
"path": folder_paths.get_full_path("loras", lora_path),
|
||||
"strength": model_strength,
|
||||
"name": lora_path.split(".")[0],
|
||||
"blocks": selected_blocks,
|
||||
"layer_filter": layer_filter,
|
||||
"low_mem_load": low_mem_load,
|
||||
"merge_loras": merge_loras,
|
||||
}
|
||||
|
||||
# Add to list and collect active loras
|
||||
loras_list.append(lora_item)
|
||||
active_loras.append((lora_name, model_strength, clip_strength))
|
||||
|
||||
# Add trigger words to collection
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Format trigger_words for output
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Format active_loras for output
|
||||
formatted_loras = []
|
||||
for name, model_strength, clip_strength in active_loras:
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
# Different model and clip strengths
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
|
||||
else:
|
||||
# Same strength for both
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
|
||||
|
||||
active_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (loras_list, trigger_words_text, active_loras_text)
|
||||
128
py/nodes/wanvideo_lora_select_from_text.py
Normal file
128
py/nodes/wanvideo_lora_select_from_text.py
Normal file
@@ -0,0 +1,128 @@
|
||||
from comfy.comfy_types import IO
|
||||
import folder_paths
|
||||
from ..utils.utils import get_lora_info
|
||||
from .utils import any_type
|
||||
import logging
|
||||
|
||||
# 初始化日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 定义新节点的类
|
||||
class WanVideoLoraSelectFromText:
|
||||
# 节点在UI中显示的名称
|
||||
NAME = "WanVideo Lora Select From Text (LoraManager)"
|
||||
# 节点所属的分类
|
||||
CATEGORY = "Lora Manager/stackers"
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load LORA models with less VRAM usage, slower loading. This affects ALL LoRAs, not just the current ones. No effect if merge_loras is False"}),
|
||||
"merge_lora": ("BOOLEAN", {"default": True, "tooltip": "Merge LoRAs into the model, otherwise they are loaded on the fly. Always disabled for GGUF and scaled fp8 models. This affects ALL LoRAs, not just the current one"}),
|
||||
"lora_syntax": (IO.STRING, {
|
||||
"multiline": True,
|
||||
"defaultInput": True,
|
||||
"forceInput": True,
|
||||
"tooltip": "Connect a TEXT output for LoRA syntax: <lora:name:strength>"
|
||||
}),
|
||||
},
|
||||
|
||||
"optional": {
|
||||
"prev_lora": ("WANVIDLORA",),
|
||||
"blocks": ("BLOCKS",)
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("WANVIDLORA", IO.STRING, IO.STRING)
|
||||
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
|
||||
|
||||
FUNCTION = "process_loras_from_syntax"
|
||||
|
||||
def process_loras_from_syntax(self, lora_syntax, low_mem_load=False, merge_lora=True, **kwargs):
|
||||
text_to_process = lora_syntax
|
||||
|
||||
blocks = kwargs.get('blocks', {})
|
||||
selected_blocks = blocks.get("selected_blocks", {})
|
||||
layer_filter = blocks.get("layer_filter", "")
|
||||
|
||||
loras_list = []
|
||||
all_trigger_words = []
|
||||
active_loras = []
|
||||
|
||||
prev_lora = kwargs.get('prev_lora', None)
|
||||
if prev_lora is not None:
|
||||
loras_list.extend(prev_lora)
|
||||
|
||||
if not merge_lora:
|
||||
low_mem_load = False
|
||||
|
||||
parts = text_to_process.split('<lora:')
|
||||
for part in parts[1:]:
|
||||
end_index = part.find('>')
|
||||
if end_index == -1:
|
||||
continue
|
||||
|
||||
content = part[:end_index]
|
||||
lora_parts = content.split(':')
|
||||
|
||||
lora_name_raw = ""
|
||||
model_strength = 1.0
|
||||
clip_strength = 1.0
|
||||
|
||||
if len(lora_parts) == 2:
|
||||
lora_name_raw = lora_parts[0].strip()
|
||||
try:
|
||||
model_strength = float(lora_parts[1])
|
||||
clip_strength = model_strength
|
||||
except (ValueError, IndexError):
|
||||
logger.warning(f"Invalid strength for LoRA '{lora_name_raw}'. Skipping.")
|
||||
continue
|
||||
elif len(lora_parts) >= 3:
|
||||
lora_name_raw = lora_parts[0].strip()
|
||||
try:
|
||||
model_strength = float(lora_parts[1])
|
||||
clip_strength = float(lora_parts[2])
|
||||
except (ValueError, IndexError):
|
||||
logger.warning(f"Invalid strengths for LoRA '{lora_name_raw}'. Skipping.")
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
lora_path, trigger_words = get_lora_info(lora_name_raw)
|
||||
|
||||
lora_item = {
|
||||
"path": folder_paths.get_full_path("loras", lora_path),
|
||||
"strength": model_strength,
|
||||
"name": lora_path.split(".")[0],
|
||||
"blocks": selected_blocks,
|
||||
"layer_filter": layer_filter,
|
||||
"low_mem_load": low_mem_load,
|
||||
"merge_loras": merge_lora,
|
||||
}
|
||||
|
||||
loras_list.append(lora_item)
|
||||
active_loras.append((lora_name_raw, model_strength, clip_strength))
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
formatted_loras = []
|
||||
for name, model_strength, clip_strength in active_loras:
|
||||
if abs(model_strength - clip_strength) > 0.001:
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
|
||||
else:
|
||||
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
|
||||
|
||||
active_loras_text = " ".join(formatted_loras)
|
||||
|
||||
return (loras_list, trigger_words_text, active_loras_text)
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"WanVideoLoraSelectFromText": WanVideoLoraSelectFromText
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"WanVideoLoraSelectFromText": "WanVideo Lora Select From Text (LoraManager)"
|
||||
}
|
||||
,
|
||||
@@ -119,10 +119,10 @@ class RecipeMetadataParser(ABC):
|
||||
# Check if exists locally
|
||||
if recipe_scanner and lora_entry['hash']:
|
||||
lora_scanner = recipe_scanner._lora_scanner
|
||||
exists_locally = lora_scanner.has_lora_hash(lora_entry['hash'])
|
||||
exists_locally = lora_scanner.has_hash(lora_entry['hash'])
|
||||
if exists_locally:
|
||||
try:
|
||||
local_path = lora_scanner.get_lora_path_by_hash(lora_entry['hash'])
|
||||
local_path = lora_scanner.get_path_by_hash(lora_entry['hash'])
|
||||
lora_entry['existsLocally'] = True
|
||||
lora_entry['localPath'] = local_path
|
||||
lora_entry['file_name'] = os.path.splitext(os.path.basename(local_path))[0]
|
||||
|
||||
@@ -19,7 +19,7 @@ class AutomaticMetadataParser(RecipeMetadataParser):
|
||||
LORA_HASHES_REGEX = r', Lora hashes:\s*"([^"]+)"'
|
||||
CIVITAI_RESOURCES_REGEX = r', Civitai resources:\s*(\[\{.*?\}\])'
|
||||
CIVITAI_METADATA_REGEX = r', Civitai metadata:\s*(\{.*?\})'
|
||||
EXTRANETS_REGEX = r'<(lora|hypernet):([a-zA-Z0-9_\.\-]+):([0-9.]+)>'
|
||||
EXTRANETS_REGEX = r'<(lora|hypernet):([^:]+):(-?[0-9.]+)>'
|
||||
MODEL_HASH_PATTERN = r'Model hash: ([a-zA-Z0-9]+)'
|
||||
VAE_HASH_PATTERN = r'VAE hash: ([a-zA-Z0-9]+)'
|
||||
|
||||
@@ -181,13 +181,30 @@ class AutomaticMetadataParser(RecipeMetadataParser):
|
||||
# First use Civitai resources if available (more reliable source)
|
||||
if metadata.get("civitai_resources"):
|
||||
for resource in metadata.get("civitai_resources", []):
|
||||
# --- Added: Parse 'air' field if present ---
|
||||
air = resource.get("air")
|
||||
if air:
|
||||
# Format: urn:air:sdxl:lora:civitai:1221007@1375651
|
||||
# Or: urn:air:sdxl:checkpoint:civitai:623891@2019115
|
||||
air_pattern = r"urn:air:[^:]+:(?P<type>[^:]+):civitai:(?P<modelId>\d+)@(?P<modelVersionId>\d+)"
|
||||
air_match = re.match(air_pattern, air)
|
||||
if air_match:
|
||||
air_type = air_match.group("type")
|
||||
air_modelId = int(air_match.group("modelId"))
|
||||
air_modelVersionId = int(air_match.group("modelVersionId"))
|
||||
# checkpoint/lycoris/lora/hypernet
|
||||
resource["type"] = air_type
|
||||
resource["modelId"] = air_modelId
|
||||
resource["modelVersionId"] = air_modelVersionId
|
||||
# --- End added ---
|
||||
|
||||
if resource.get("type") in ["lora", "lycoris", "hypernet"] and resource.get("modelVersionId"):
|
||||
# Initialize lora entry
|
||||
lora_entry = {
|
||||
'id': resource.get("modelVersionId", 0),
|
||||
'modelId': resource.get("modelId", 0),
|
||||
'name': resource.get("modelName", "Unknown LoRA"),
|
||||
'version': resource.get("modelVersionName", ""),
|
||||
'version': resource.get("modelVersionName", resource.get("versionName", "")),
|
||||
'type': resource.get("type", "lora"),
|
||||
'weight': round(float(resource.get("weight", 1.0)), 2),
|
||||
'existsLocally': False,
|
||||
|
||||
@@ -50,6 +50,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
'from_civitai_image': True
|
||||
}
|
||||
|
||||
# Track already added LoRAs to prevent duplicates
|
||||
added_loras = {} # key: model_version_id or hash, value: index in result["loras"]
|
||||
|
||||
# Extract prompt and negative prompt
|
||||
if "prompt" in metadata:
|
||||
result["gen_params"]["prompt"] = metadata["prompt"]
|
||||
@@ -96,11 +99,22 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
for resource in metadata["resources"]:
|
||||
# Modified to process resources without a type field as potential LoRAs
|
||||
if resource.get("type", "lora") == "lora":
|
||||
lora_hash = resource.get("hash", "")
|
||||
|
||||
# Skip LoRAs without proper identification (hash or modelVersionId)
|
||||
if not lora_hash and not resource.get("modelVersionId"):
|
||||
logger.debug(f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId")
|
||||
continue
|
||||
|
||||
# Skip if we've already added this LoRA by hash
|
||||
if lora_hash and lora_hash in added_loras:
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': resource.get("name", "Unknown LoRA"),
|
||||
'type': "lora",
|
||||
'weight': float(resource.get("weight", 1.0)),
|
||||
'hash': resource.get("hash", ""),
|
||||
'hash': lora_hash,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': resource.get("name", "Unknown"),
|
||||
@@ -114,7 +128,6 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
# Try to get info from Civitai if hash is available
|
||||
if lora_entry['hash'] and civitai_client:
|
||||
try:
|
||||
lora_hash = lora_entry['hash']
|
||||
civitai_info = await civitai_client.get_model_by_hash(lora_hash)
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
@@ -129,43 +142,120 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# If we have a version ID from Civitai, track it for deduplication
|
||||
if 'id' in lora_entry and lora_entry['id']:
|
||||
added_loras[str(lora_entry['id'])] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
|
||||
|
||||
# Track by hash if we have it
|
||||
if lora_hash:
|
||||
added_loras[lora_hash] = len(result["loras"])
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Process civitaiResources array
|
||||
if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list):
|
||||
for resource in metadata["civitaiResources"]:
|
||||
# Modified to process resources without a type field as potential LoRAs
|
||||
if resource.get("type") in ["lora", "lycoris"] or "type" not in resource:
|
||||
# Initialize lora entry with the same structure as in automatic.py
|
||||
lora_entry = {
|
||||
'id': resource.get("modelVersionId", 0),
|
||||
'modelId': resource.get("modelId", 0),
|
||||
'name': resource.get("modelName", "Unknown LoRA"),
|
||||
'version': resource.get("modelVersionName", ""),
|
||||
'type': resource.get("type", "lora"),
|
||||
'weight': round(float(resource.get("weight", 1.0)), 2),
|
||||
'existsLocally': False,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
# Get unique identifier for deduplication
|
||||
version_id = str(resource.get("modelVersionId", ""))
|
||||
|
||||
# Skip if we've already added this LoRA
|
||||
if version_id and version_id in added_loras:
|
||||
continue
|
||||
|
||||
# Initialize lora entry
|
||||
lora_entry = {
|
||||
'id': resource.get("modelVersionId", 0),
|
||||
'modelId': resource.get("modelId", 0),
|
||||
'name': resource.get("modelName", "Unknown LoRA"),
|
||||
'version': resource.get("modelVersionName", ""),
|
||||
'type': resource.get("type", "lora"),
|
||||
'weight': round(float(resource.get("weight", 1.0)), 2),
|
||||
'existsLocally': False,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get info from Civitai if modelVersionId is available
|
||||
if version_id and civitai_client:
|
||||
try:
|
||||
# Use get_model_version_info instead of get_model_version
|
||||
civitai_info, error = await civitai_client.get_model_version_info(version_id)
|
||||
|
||||
if error:
|
||||
logger.warning(f"Error getting model version info: {error}")
|
||||
continue
|
||||
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model version {version_id}: {e}")
|
||||
|
||||
# Track this LoRA in our deduplication dict
|
||||
if version_id:
|
||||
added_loras[version_id] = len(result["loras"])
|
||||
|
||||
# Try to get info from Civitai if modelVersionId is available
|
||||
if resource.get('modelVersionId') and civitai_client:
|
||||
try:
|
||||
version_id = str(resource.get('modelVersionId'))
|
||||
# Use get_model_version_info instead of get_model_version
|
||||
civitai_info, error = await civitai_client.get_model_version_info(version_id)
|
||||
|
||||
if error:
|
||||
logger.warning(f"Error getting model version info: {error}")
|
||||
continue
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Process additionalResources array
|
||||
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list):
|
||||
for resource in metadata["additionalResources"]:
|
||||
# Skip resources that aren't LoRAs or LyCORIS
|
||||
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource:
|
||||
continue
|
||||
|
||||
lora_type = resource.get("type", "lora")
|
||||
name = resource.get("name", "")
|
||||
|
||||
# Extract ID from URN format if available
|
||||
version_id = None
|
||||
if name and "civitai:" in name:
|
||||
parts = name.split("@")
|
||||
if len(parts) > 1:
|
||||
version_id = parts[1]
|
||||
|
||||
# Skip if we've already added this LoRA
|
||||
if version_id in added_loras:
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': name,
|
||||
'type': lora_type,
|
||||
'weight': float(resource.get("strength", 1.0)),
|
||||
'hash': "",
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# If we have a version ID and civitai client, try to get more info
|
||||
if version_id and civitai_client:
|
||||
try:
|
||||
# Use get_model_version_info with the version ID
|
||||
civitai_info, error = await civitai_client.get_model_version_info(version_id)
|
||||
|
||||
if error:
|
||||
logger.warning(f"Error getting model version info: {error}")
|
||||
else:
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
@@ -177,65 +267,74 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model version {resource.get('modelVersionId')}: {e}")
|
||||
|
||||
|
||||
# Track this LoRA for deduplication
|
||||
if version_id:
|
||||
added_loras[version_id] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
# Process additionalResources array
|
||||
if "additionalResources" in metadata and isinstance(metadata["additionalResources"], list):
|
||||
for resource in metadata["additionalResources"]:
|
||||
# Modified to process resources without a type field as potential LoRAs
|
||||
if resource.get("type") in ["lora", "lycoris"] or "type" not in resource:
|
||||
lora_type = resource.get("type", "lora")
|
||||
name = resource.get("name", "")
|
||||
# Check for LoRA info in the format "Lora_0 Model hash", "Lora_0 Model name", etc.
|
||||
lora_index = 0
|
||||
while f"Lora_{lora_index} Model hash" in metadata and f"Lora_{lora_index} Model name" in metadata:
|
||||
lora_hash = metadata[f"Lora_{lora_index} Model hash"]
|
||||
lora_name = metadata[f"Lora_{lora_index} Model name"]
|
||||
lora_strength_model = float(metadata.get(f"Lora_{lora_index} Strength model", 1.0))
|
||||
|
||||
# Skip if we've already added this LoRA by hash
|
||||
if lora_hash and lora_hash in added_loras:
|
||||
lora_index += 1
|
||||
continue
|
||||
|
||||
lora_entry = {
|
||||
'name': lora_name,
|
||||
'type': "lora",
|
||||
'weight': lora_strength_model,
|
||||
'hash': lora_hash,
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': lora_name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
|
||||
# Try to get info from Civitai if hash is available
|
||||
if lora_entry['hash'] and civitai_client:
|
||||
try:
|
||||
civitai_info = await civitai_client.get_model_by_hash(lora_hash)
|
||||
|
||||
# Extract ID from URN format if available
|
||||
model_id = None
|
||||
if name and "civitai:" in name:
|
||||
parts = name.split("@")
|
||||
if len(parts) > 1:
|
||||
model_id = parts[1]
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts,
|
||||
lora_hash
|
||||
)
|
||||
|
||||
lora_entry = {
|
||||
'name': name,
|
||||
'type': lora_type,
|
||||
'weight': float(resource.get("strength", 1.0)),
|
||||
'hash': "",
|
||||
'existsLocally': False,
|
||||
'localPath': None,
|
||||
'file_name': name,
|
||||
'thumbnailUrl': '/loras_static/images/no-preview.png',
|
||||
'baseModel': '',
|
||||
'size': 0,
|
||||
'downloadUrl': '',
|
||||
'isDeleted': False
|
||||
}
|
||||
if populated_entry is None:
|
||||
lora_index += 1
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
|
||||
# If we have a model ID and civitai client, try to get more info
|
||||
if model_id and civitai_client:
|
||||
try:
|
||||
# Use get_model_version_info with the model ID
|
||||
civitai_info, error = await civitai_client.get_model_version_info(model_id)
|
||||
|
||||
if error:
|
||||
logger.warning(f"Error getting model version info: {error}")
|
||||
else:
|
||||
populated_entry = await self.populate_lora_from_civitai(
|
||||
lora_entry,
|
||||
civitai_info,
|
||||
recipe_scanner,
|
||||
base_model_counts
|
||||
)
|
||||
|
||||
if populated_entry is None:
|
||||
continue # Skip invalid LoRA types
|
||||
|
||||
lora_entry = populated_entry
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for model ID {model_id}: {e}")
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
# If we have a version ID from Civitai, track it for deduplication
|
||||
if 'id' in lora_entry and lora_entry['id']:
|
||||
added_loras[str(lora_entry['id'])] = len(result["loras"])
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
|
||||
|
||||
# Track by hash if we have it
|
||||
if lora_hash:
|
||||
added_loras[lora_hash] = len(result["loras"])
|
||||
|
||||
result["loras"].append(lora_entry)
|
||||
|
||||
lora_index += 1
|
||||
|
||||
# If base model wasn't found earlier, use the most common one from LoRAs
|
||||
if not result["base_model"] and base_model_counts:
|
||||
|
||||
@@ -55,7 +55,7 @@ class RecipeFormatParser(RecipeMetadataParser):
|
||||
# Check if this LoRA exists locally by SHA256 hash
|
||||
if lora.get('hash') and recipe_scanner:
|
||||
lora_scanner = recipe_scanner._lora_scanner
|
||||
exists_locally = lora_scanner.has_lora_hash(lora['hash'])
|
||||
exists_locally = lora_scanner.has_hash(lora['hash'])
|
||||
if exists_locally:
|
||||
lora_cache = await lora_scanner.get_cached_data()
|
||||
lora_item = next((item for item in lora_cache.raw_data if item['sha256'].lower() == lora['hash'].lower()), None)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1254
py/routes/base_model_routes.py
Normal file
1254
py/routes/base_model_routes.py
Normal file
File diff suppressed because it is too large
Load Diff
140
py/routes/checkpoint_routes.py
Normal file
140
py/routes/checkpoint_routes.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from ..services.checkpoint_service import CheckpointService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointRoutes(BaseModelRoutes):
|
||||
"""Checkpoint-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Checkpoint routes with Checkpoint service"""
|
||||
# Service will be initialized later via setup_routes
|
||||
self.service = None
|
||||
self.civitai_client = None
|
||||
self.template_name = "checkpoints.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.service = CheckpointService(checkpoint_scanner)
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Initialize parent with the service
|
||||
super().__init__(self.service)
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Setup Checkpoint routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'checkpoints' prefix (includes page route)
|
||||
super().setup_routes(app, 'checkpoints')
|
||||
|
||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||
"""Setup Checkpoint-specific routes"""
|
||||
# Checkpoint-specific CivitAI integration
|
||||
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint)
|
||||
|
||||
# Checkpoint info by name
|
||||
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
|
||||
|
||||
# Checkpoint roots and Unet roots
|
||||
app.router.add_get(f'/api/{prefix}/checkpoints_roots', self.get_checkpoints_roots)
|
||||
app.router.add_get(f'/api/{prefix}/unet_roots', self.get_unet_roots)
|
||||
|
||||
async def get_checkpoint_info(self, request: web.Request) -> web.Response:
|
||||
"""Get detailed information for a specific checkpoint by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
checkpoint_info = await self.service.get_model_info_by_name(name)
|
||||
|
||||
if checkpoint_info:
|
||||
return web.json_response(checkpoint_info)
|
||||
else:
|
||||
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai checkpoint model with local availability info"""
|
||||
try:
|
||||
model_id = request.match_info['model_id']
|
||||
response = await self.civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be Checkpoint
|
||||
if model_type.lower() != 'checkpoint':
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||
|
||||
# If no primary file found, try to find any model file
|
||||
if not model_file:
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.service.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.service.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching checkpoint model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def get_checkpoints_roots(self, request: web.Request) -> web.Response:
|
||||
"""Return the list of checkpoint roots from config"""
|
||||
try:
|
||||
roots = config.checkpoints_roots
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_unet_roots(self, request: web.Request) -> web.Response:
|
||||
"""Return the list of unet roots from config"""
|
||||
try:
|
||||
roots = config.unet_roots
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting unet roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
@@ -1,843 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..utils.utils import fuzzy_match
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointsRoutes:
|
||||
"""API routes for checkpoint management"""
|
||||
|
||||
def __init__(self):
|
||||
self.scanner = None # Will be initialized in setup_routes
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
self.download_manager = None # Will be initialized in setup_routes
|
||||
self._download_lock = asyncio.Lock()
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
def setup_routes(self, app):
|
||||
"""Register routes with the aiohttp app"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
app.router.add_get('/checkpoints', self.handle_checkpoints_page)
|
||||
app.router.add_get('/api/checkpoints', self.get_checkpoints)
|
||||
app.router.add_post('/api/checkpoints/fetch-all-civitai', self.fetch_all_civitai)
|
||||
app.router.add_get('/api/checkpoints/base-models', self.get_base_models)
|
||||
app.router.add_get('/api/checkpoints/top-tags', self.get_top_tags)
|
||||
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
|
||||
app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info)
|
||||
app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots)
|
||||
app.router.add_get('/api/checkpoints/civitai/versions/{model_id}', self.get_civitai_versions) # Add new route
|
||||
|
||||
# Add new routes for model management similar to LoRA routes
|
||||
app.router.add_post('/api/checkpoints/delete', self.delete_model)
|
||||
app.router.add_post('/api/checkpoints/exclude', self.exclude_model) # Add new exclude endpoint
|
||||
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
|
||||
app.router.add_post('/api/checkpoints/relink-civitai', self.relink_civitai) # Add new relink endpoint
|
||||
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
|
||||
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
|
||||
app.router.add_post('/api/checkpoints/save-metadata', self.save_metadata) # Add new route
|
||||
app.router.add_post('/api/checkpoints/rename', self.rename_checkpoint) # Add new rename endpoint
|
||||
|
||||
# Add new WebSocket endpoint for checkpoint progress
|
||||
app.router.add_get('/ws/checkpoint-progress', ws_manager.handle_checkpoint_connection)
|
||||
|
||||
# Add new routes for finding duplicates and filename conflicts
|
||||
app.router.add_get('/api/checkpoints/find-duplicates', self.find_duplicate_checkpoints)
|
||||
app.router.add_get('/api/checkpoints/find-filename-conflicts', self.find_filename_conflicts)
|
||||
|
||||
# Add new endpoint for bulk deleting checkpoints
|
||||
app.router.add_post('/api/checkpoints/bulk-delete', self.bulk_delete_checkpoints)
|
||||
|
||||
# Add new endpoint for verifying duplicates
|
||||
app.router.add_post('/api/checkpoints/verify-duplicates', self.verify_duplicates)
|
||||
|
||||
async def get_checkpoints(self, request):
|
||||
"""Get paginated checkpoint data"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
page = int(request.query.get('page', '1'))
|
||||
page_size = min(int(request.query.get('page_size', '20')), 100)
|
||||
sort_by = request.query.get('sort_by', 'name')
|
||||
folder = request.query.get('folder', None)
|
||||
search = request.query.get('search', None)
|
||||
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
|
||||
base_models = request.query.getall('base_model', [])
|
||||
tags = request.query.getall('tag', [])
|
||||
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true' # Add favorites_only parameter
|
||||
|
||||
# Process search options
|
||||
search_options = {
|
||||
'filename': request.query.get('search_filename', 'true').lower() == 'true',
|
||||
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
|
||||
'tags': request.query.get('search_tags', 'false').lower() == 'true',
|
||||
'recursive': request.query.get('recursive', 'false').lower() == 'true',
|
||||
}
|
||||
|
||||
# Process hash filters if provided
|
||||
hash_filters = {}
|
||||
if 'hash' in request.query:
|
||||
hash_filters['single_hash'] = request.query['hash']
|
||||
elif 'hashes' in request.query:
|
||||
try:
|
||||
hash_list = json.loads(request.query['hashes'])
|
||||
if isinstance(hash_list, list):
|
||||
hash_filters['multiple_hashes'] = hash_list
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Get data from scanner
|
||||
result = await self.get_paginated_data(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort_by=sort_by,
|
||||
folder=folder,
|
||||
search=search,
|
||||
fuzzy_search=fuzzy_search,
|
||||
base_models=base_models,
|
||||
tags=tags,
|
||||
search_options=search_options,
|
||||
hash_filters=hash_filters,
|
||||
favorites_only=favorites_only # Pass favorites_only parameter
|
||||
)
|
||||
|
||||
# Format response items
|
||||
formatted_result = {
|
||||
'items': [self._format_checkpoint_response(cp) for cp in result['items']],
|
||||
'total': result['total'],
|
||||
'page': result['page'],
|
||||
'page_size': result['page_size'],
|
||||
'total_pages': result['total_pages']
|
||||
}
|
||||
|
||||
# Return as JSON
|
||||
return web.json_response(formatted_result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_paginated_data(self, page, page_size, sort_by='name',
|
||||
folder=None, search=None, fuzzy_search=False,
|
||||
base_models=None, tags=None,
|
||||
search_options=None, hash_filters=None,
|
||||
favorites_only=False): # Add favorites_only parameter with default False
|
||||
"""Get paginated and filtered checkpoint data"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False,
|
||||
'recursive': False,
|
||||
}
|
||||
|
||||
# Get the base data set
|
||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||
|
||||
# Apply hash filtering if provided (highest priority)
|
||||
if hash_filters:
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower() # Ensure lowercase for matching
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
# Jump to pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# Apply SFW filtering if enabled in settings
|
||||
if settings.get('show_only_sfw', False):
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if not cp.get('preview_nsfw_level') or cp.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply favorites filtering if enabled
|
||||
if favorites_only:
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('favorite', False) is True
|
||||
]
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if search_options.get('recursive', False):
|
||||
# Recursive folder filtering - include all subfolders
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp['folder'].startswith(folder)
|
||||
]
|
||||
else:
|
||||
# Exact folder filtering
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if cp.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
filtered_data = [
|
||||
cp for cp in filtered_data
|
||||
if any(tag in cp.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
search_results = []
|
||||
|
||||
for cp in filtered_data:
|
||||
# Search by file name
|
||||
if search_options.get('filename', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(cp.get('file_name', ''), search):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
elif search.lower() in cp.get('file_name', '').lower():
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
# Search by model name
|
||||
if search_options.get('modelname', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(cp.get('model_name', ''), search):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
elif search.lower() in cp.get('model_name', '').lower():
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
# Search by tags
|
||||
if search_options.get('tags', False) and 'tags' in cp:
|
||||
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) for tag in cp['tags']):
|
||||
search_results.append(cp)
|
||||
continue
|
||||
|
||||
filtered_data = search_results
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _format_checkpoint_response(self, checkpoint):
|
||||
"""Format checkpoint data for API response"""
|
||||
return {
|
||||
"model_name": checkpoint["model_name"],
|
||||
"file_name": checkpoint["file_name"],
|
||||
"preview_url": config.get_preview_static_url(checkpoint.get("preview_url", "")),
|
||||
"preview_nsfw_level": checkpoint.get("preview_nsfw_level", 0),
|
||||
"base_model": checkpoint.get("base_model", ""),
|
||||
"folder": checkpoint["folder"],
|
||||
"sha256": checkpoint.get("sha256", ""),
|
||||
"file_path": checkpoint["file_path"].replace(os.sep, "/"),
|
||||
"file_size": checkpoint.get("size", 0),
|
||||
"modified": checkpoint.get("modified", ""),
|
||||
"tags": checkpoint.get("tags", []),
|
||||
"modelDescription": checkpoint.get("modelDescription", ""),
|
||||
"from_civitai": checkpoint.get("from_civitai", True),
|
||||
"notes": checkpoint.get("notes", ""),
|
||||
"model_type": checkpoint.get("model_type", "checkpoint"),
|
||||
"favorite": checkpoint.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint.get("civitai", {}))
|
||||
}
|
||||
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Fetch CivitAI metadata for all checkpoints in the background"""
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
# Prepare checkpoints to process
|
||||
to_process = [
|
||||
cp for cp in cache.raw_data
|
||||
if cp.get('sha256') and (not cp.get('civitai') or 'id' not in cp.get('civitai')) and cp.get('from_civitai', True)
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
# Send initial progress
|
||||
await ws_manager.broadcast({
|
||||
'status': 'started',
|
||||
'total': total_to_process,
|
||||
'processed': 0,
|
||||
'success': 0
|
||||
})
|
||||
|
||||
# Process each checkpoint
|
||||
for cp in to_process:
|
||||
try:
|
||||
original_name = cp.get('model_name')
|
||||
if await ModelRouteUtils.fetch_and_update_model(
|
||||
sha256=cp['sha256'],
|
||||
file_path=cp['file_path'],
|
||||
model_data=cp,
|
||||
update_cache_func=self.scanner.update_single_model_cache
|
||||
):
|
||||
success += 1
|
||||
if original_name != cp.get('model_name'):
|
||||
needs_resort = True
|
||||
|
||||
processed += 1
|
||||
|
||||
# Send progress update
|
||||
await ws_manager.broadcast({
|
||||
'status': 'processing',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success,
|
||||
'current_name': cp.get('model_name', 'Unknown')
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching CivitAI data for {cp['file_path']}: {e}")
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
# Send completion message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'completed',
|
||||
'total': total_to_process,
|
||||
'processed': processed,
|
||||
'success': success
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Successfully updated {success} of {processed} processed checkpoints (total: {total})"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# Send error message
|
||||
await ws_manager.broadcast({
|
||||
'status': 'error',
|
||||
'error': str(e)
|
||||
})
|
||||
logger.error(f"Error in fetch_all_civitai for checkpoints: {e}")
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
"""Handle request for top tags sorted by frequency"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Validate limit
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20 # Default to a reasonable limit
|
||||
|
||||
# Get top tags
|
||||
top_tags = await self.scanner.get_top_tags(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'tags': top_tags
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Internal server error'
|
||||
}, status=500)
|
||||
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
"""Get base models used in loras"""
|
||||
try:
|
||||
# Parse query parameters
|
||||
limit = int(request.query.get('limit', '20'))
|
||||
|
||||
# Validate limit
|
||||
if limit < 1 or limit > 100:
|
||||
limit = 20 # Default to a reasonable limit
|
||||
|
||||
# Get base models
|
||||
base_models = await self.scanner.get_base_models(limit)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'base_models': base_models
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving base models: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def scan_checkpoints(self, request):
|
||||
"""Force a rescan of checkpoint files"""
|
||||
try:
|
||||
# Get the full_rebuild parameter and convert to bool, default to False
|
||||
full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true'
|
||||
|
||||
await self.scanner.get_cached_data(force_refresh=True, rebuild_cache=full_rebuild)
|
||||
return web.json_response({"status": "success", "message": "Checkpoint scan completed"})
|
||||
except Exception as e:
|
||||
logger.error(f"Error in scan_checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_checkpoint_info(self, request):
|
||||
"""Get detailed information for a specific checkpoint by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
checkpoint_info = await self.scanner.get_model_info_by_name(name)
|
||||
|
||||
if checkpoint_info:
|
||||
return web.json_response(checkpoint_info)
|
||||
else:
|
||||
return web.json_response({"error": "Checkpoint not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def handle_checkpoints_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /checkpoints request"""
|
||||
try:
|
||||
# Check if the CheckpointScanner is initializing
|
||||
# It's initializing if the cache object doesn't exist yet,
|
||||
# OR if the scanner explicitly says it's initializing (background task running).
|
||||
is_initializing = (
|
||||
self.scanner._cache is None or
|
||||
(hasattr(self.scanner, '_is_initializing') and self.scanner._is_initializing)
|
||||
)
|
||||
|
||||
if is_initializing:
|
||||
# If still initializing, return loading page
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=[], # 空文件夹列表
|
||||
is_initializing=True, # 新增标志
|
||||
settings=settings, # Pass settings to template
|
||||
request=request # Pass the request object to the template
|
||||
)
|
||||
|
||||
logger.info("Checkpoints page is initializing, returning loading page")
|
||||
else:
|
||||
# 正常流程 - 获取已经初始化好的缓存数据
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=cache.folders,
|
||||
is_initializing=False,
|
||||
settings=settings, # Pass settings to template
|
||||
request=request # Pass the request object to the template
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading checkpoints cache data: {cache_error}")
|
||||
# 如果获取缓存失败,也显示初始化页面
|
||||
template = self.template_env.get_template('checkpoints.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Checkpoints cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling checkpoints request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading checkpoints page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint model deletion request"""
|
||||
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
|
||||
|
||||
async def exclude_model(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint model exclusion request"""
|
||||
return await ModelRouteUtils.handle_exclude_model(request, self.scanner)
|
||||
|
||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata fetch request for checkpoints"""
|
||||
response = await ModelRouteUtils.handle_fetch_civitai(request, self.scanner)
|
||||
|
||||
# If successful, format the metadata before returning
|
||||
if response.status == 200:
|
||||
data = json.loads(response.body.decode('utf-8'))
|
||||
if data.get("success") and data.get("metadata"):
|
||||
formatted_metadata = self._format_checkpoint_response(data["metadata"])
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"metadata": formatted_metadata
|
||||
})
|
||||
|
||||
# Otherwise, return the original response
|
||||
return response
|
||||
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
"""Handle preview image replacement for checkpoints"""
|
||||
return await ModelRouteUtils.handle_replace_preview(request, self.scanner)
|
||||
|
||||
async def download_checkpoint(self, request: web.Request) -> web.Response:
|
||||
"""Handle checkpoint download request"""
|
||||
async with self._download_lock:
|
||||
# Get the download manager from service registry if not already initialized
|
||||
if self.download_manager is None:
|
||||
self.download_manager = await ServiceRegistry.get_download_manager()
|
||||
|
||||
try:
|
||||
data = await request.json()
|
||||
|
||||
# Create progress callback that uses checkpoint-specific WebSocket
|
||||
async def progress_callback(progress):
|
||||
await ws_manager.broadcast_checkpoint_progress({
|
||||
'status': 'progress',
|
||||
'progress': progress
|
||||
})
|
||||
|
||||
# Check which identifier is provided
|
||||
download_url = data.get('download_url')
|
||||
model_hash = data.get('model_hash')
|
||||
model_version_id = data.get('model_version_id')
|
||||
|
||||
# Validate that at least one identifier is provided
|
||||
if not any([download_url, model_hash, model_version_id]):
|
||||
return web.Response(
|
||||
status=400,
|
||||
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
||||
)
|
||||
|
||||
result = await self.download_manager.download_from_civitai(
|
||||
download_url=download_url,
|
||||
model_hash=model_hash,
|
||||
model_version_id=model_version_id,
|
||||
save_dir=data.get('checkpoint_root'),
|
||||
relative_path=data.get('relative_path', ''),
|
||||
progress_callback=progress_callback,
|
||||
model_type="checkpoint"
|
||||
)
|
||||
|
||||
if not result.get('success', False):
|
||||
error_message = result.get('error', 'Unknown error')
|
||||
|
||||
# Return 401 for early access errors
|
||||
if 'early access' in error_message.lower():
|
||||
logger.warning(f"Early access download failed: {error_message}")
|
||||
return web.Response(
|
||||
status=401,
|
||||
text=f"Early Access Restriction: {error_message}"
|
||||
)
|
||||
|
||||
return web.Response(status=500, text=error_message)
|
||||
|
||||
return web.json_response(result)
|
||||
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
|
||||
# Check if this might be an early access error
|
||||
if '401' in error_message:
|
||||
logger.warning(f"Early access error (401): {error_message}")
|
||||
return web.Response(
|
||||
status=401,
|
||||
text="Early Access Restriction: This model requires purchase. Please ensure you have purchased early access and are logged in to Civitai."
|
||||
)
|
||||
|
||||
logger.error(f"Error downloading checkpoint: {error_message}")
|
||||
return web.Response(status=500, text=error_message)
|
||||
|
||||
async def get_checkpoint_roots(self, request):
|
||||
"""Return the checkpoint root directories"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
roots = self.scanner.get_model_roots()
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"roots": roots
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||
"""Handle saving metadata updates for checkpoints"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
return web.Response(text='File path is required', status=400)
|
||||
|
||||
# Remove file path from data to avoid saving it
|
||||
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
|
||||
|
||||
# Get metadata file path
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
|
||||
# Load existing metadata
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
|
||||
# Update metadata
|
||||
metadata.update(metadata_updates)
|
||||
|
||||
# Save updated metadata
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
# Update cache
|
||||
await self.scanner.update_single_model_cache(file_path, file_path, metadata)
|
||||
|
||||
# If model_name was updated, resort the cache
|
||||
if 'model_name' in metadata_updates:
|
||||
cache = await self.scanner.get_cached_data()
|
||||
await cache.resort(name_only=True)
|
||||
|
||||
return web.json_response({'success': True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint metadata: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai checkpoint model with local availability info"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Get the civitai client from service registry
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
model_id = request.match_info['model_id']
|
||||
response = await civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be Checkpoint
|
||||
if (model_type.lower() != 'checkpoint'):
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||
|
||||
# If no primary file found, try to find any model file
|
||||
if not model_file:
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.scanner.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching checkpoint model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def find_duplicate_checkpoints(self, request: web.Request) -> web.Response:
|
||||
"""Find checkpoints with duplicate SHA256 hashes"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Get duplicate hashes from hash index
|
||||
duplicates = self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
# Format the response
|
||||
result = []
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for sha256, paths in duplicates.items():
|
||||
group = {
|
||||
"hash": sha256,
|
||||
"models": []
|
||||
}
|
||||
# Find matching models for each path
|
||||
for path in paths:
|
||||
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||
if model:
|
||||
group["models"].append(self._format_checkpoint_response(model))
|
||||
|
||||
# Add the primary model too
|
||||
primary_path = self.scanner._hash_index.get_path(sha256)
|
||||
if primary_path and primary_path not in paths:
|
||||
primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
|
||||
if primary_model:
|
||||
group["models"].insert(0, self._format_checkpoint_response(primary_model))
|
||||
|
||||
if len(group["models"]) > 1: # Only include if we found multiple models
|
||||
result.append(group)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"duplicates": result,
|
||||
"count": len(result)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding duplicate checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
|
||||
"""Find checkpoints with conflicting filenames"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Get duplicate filenames from hash index
|
||||
duplicates = self.scanner._hash_index.get_duplicate_filenames()
|
||||
|
||||
# Format the response
|
||||
result = []
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for filename, paths in duplicates.items():
|
||||
group = {
|
||||
"filename": filename,
|
||||
"models": []
|
||||
}
|
||||
# Find matching models for each path
|
||||
for path in paths:
|
||||
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
|
||||
if model:
|
||||
group["models"].append(self._format_checkpoint_response(model))
|
||||
|
||||
# Find the model from the main index too
|
||||
hash_val = self.scanner._hash_index.get_hash_by_filename(filename)
|
||||
if hash_val:
|
||||
main_path = self.scanner._hash_index.get_path(hash_val)
|
||||
if main_path and main_path not in paths:
|
||||
main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
|
||||
if main_model:
|
||||
group["models"].insert(0, self._format_checkpoint_response(main_model))
|
||||
|
||||
if group["models"]:
|
||||
result.append(group)
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"conflicts": result,
|
||||
"count": len(result)
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding filename conflicts: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def bulk_delete_checkpoints(self, request: web.Request) -> web.Response:
|
||||
"""Handle bulk deletion of checkpoint models"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
return await ModelRouteUtils.handle_bulk_delete_models(request, self.scanner)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in bulk delete checkpoints: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def relink_civitai(self, request: web.Request) -> web.Response:
|
||||
"""Handle CivitAI metadata re-linking request by model version ID for checkpoints"""
|
||||
return await ModelRouteUtils.handle_relink_civitai(request, self.scanner)
|
||||
|
||||
async def verify_duplicates(self, request: web.Request) -> web.Response:
|
||||
"""Handle verification of duplicate checkpoint hashes"""
|
||||
return await ModelRouteUtils.handle_verify_duplicates(request, self.scanner)
|
||||
|
||||
async def rename_checkpoint(self, request: web.Request) -> web.Response:
|
||||
"""Handle renaming a checkpoint file and its associated files"""
|
||||
return await ModelRouteUtils.handle_rename_model(request, self.scanner)
|
||||
105
py/routes/embedding_routes.py
Normal file
105
py/routes/embedding_routes.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from ..services.embedding_service import EmbeddingService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingRoutes(BaseModelRoutes):
|
||||
"""Embedding-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Embedding routes with Embedding service"""
|
||||
# Service will be initialized later via setup_routes
|
||||
self.service = None
|
||||
self.civitai_client = None
|
||||
self.template_name = "embeddings.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
self.service = EmbeddingService(embedding_scanner)
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Initialize parent with the service
|
||||
super().__init__(self.service)
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Setup Embedding routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Setup common routes with 'embeddings' prefix (includes page route)
|
||||
super().setup_routes(app, 'embeddings')
|
||||
|
||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||
"""Setup Embedding-specific routes"""
|
||||
# Embedding-specific CivitAI integration
|
||||
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_embedding)
|
||||
|
||||
# Embedding info by name
|
||||
app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_embedding_info)
|
||||
|
||||
async def get_embedding_info(self, request: web.Request) -> web.Response:
|
||||
"""Get detailed information for a specific embedding by name"""
|
||||
try:
|
||||
name = request.match_info.get('name', '')
|
||||
embedding_info = await self.service.get_model_info_by_name(name)
|
||||
|
||||
if embedding_info:
|
||||
return web.json_response(embedding_info)
|
||||
else:
|
||||
return web.json_response({"error": "Embedding not found"}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_embedding_info: {e}", exc_info=True)
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
async def get_civitai_versions_embedding(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai embedding model with local availability info"""
|
||||
try:
|
||||
model_id = request.match_info['model_id']
|
||||
response = await self.civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be TextualInversion (Embedding)
|
||||
if model_type.lower() not in ['textualinversion', 'embedding']:
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected TextualInversion/Embedding, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||
|
||||
# If no primary file found, try to find any model file
|
||||
if not model_file:
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.service.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.service.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching embedding model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
@@ -1,8 +1,8 @@
|
||||
import logging
|
||||
from ..utils.example_images_download_manager import DownloadManager
|
||||
from ..utils.example_images_processor import ExampleImagesProcessor
|
||||
from ..utils.example_images_metadata import MetadataUpdater
|
||||
from ..utils.example_images_file_manager import ExampleImagesFileManager
|
||||
from ..services.websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,6 +21,7 @@ class ExampleImagesRoutes:
|
||||
app.router.add_get('/api/example-image-files', ExampleImagesRoutes.get_example_image_files)
|
||||
app.router.add_get('/api/has-example-images', ExampleImagesRoutes.has_example_images)
|
||||
app.router.add_post('/api/delete-example-image', ExampleImagesRoutes.delete_example_image)
|
||||
app.router.add_post('/api/force-download-example-images', ExampleImagesRoutes.force_download_example_images)
|
||||
|
||||
@staticmethod
|
||||
async def download_example_images(request):
|
||||
@@ -65,4 +66,9 @@ class ExampleImagesRoutes:
|
||||
@staticmethod
|
||||
async def delete_example_image(request):
|
||||
"""Delete a custom example image for a model"""
|
||||
return await ExampleImagesProcessor.delete_custom_image(request)
|
||||
return await ExampleImagesProcessor.delete_custom_image(request)
|
||||
|
||||
@staticmethod
|
||||
async def force_download_example_images(request):
|
||||
"""Force download example images for specific models"""
|
||||
return await DownloadManager.start_force_download(request)
|
||||
@@ -1,188 +1,329 @@
|
||||
import os
|
||||
from aiohttp import web
|
||||
import jinja2
|
||||
from typing import Dict
|
||||
import asyncio
|
||||
import logging
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
|
||||
from aiohttp import web
|
||||
from typing import Dict
|
||||
from server import PromptServer # type: ignore
|
||||
|
||||
from .base_model_routes import BaseModelRoutes
|
||||
from ..services.lora_service import LoraService
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..utils.utils import get_lora_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger('asyncio').setLevel(logging.CRITICAL)
|
||||
|
||||
class LoraRoutes:
|
||||
"""Route handlers for LoRA management endpoints"""
|
||||
class LoraRoutes(BaseModelRoutes):
|
||||
"""LoRA-specific route controller"""
|
||||
|
||||
def __init__(self):
|
||||
# Initialize service references as None, will be set during async init
|
||||
self.scanner = None
|
||||
self.recipe_scanner = None
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
async def init_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
"""Initialize LoRA routes with LoRA service"""
|
||||
# Service will be initialized later via setup_routes
|
||||
self.service = None
|
||||
self.civitai_client = None
|
||||
self.template_name = "loras.html"
|
||||
|
||||
async def initialize_services(self):
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.service = LoraService(lora_scanner)
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
# Initialize parent with the service
|
||||
super().__init__(self.service)
|
||||
|
||||
def format_lora_data(self, lora: Dict) -> Dict:
|
||||
"""Format LoRA data for template rendering"""
|
||||
return {
|
||||
"model_name": lora["model_name"],
|
||||
"file_name": lora["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora["preview_url"]),
|
||||
"preview_nsfw_level": lora.get("preview_nsfw_level", 0),
|
||||
"base_model": lora["base_model"],
|
||||
"folder": lora["folder"],
|
||||
"sha256": lora["sha256"],
|
||||
"file_path": lora["file_path"].replace(os.sep, "/"),
|
||||
"size": lora["size"],
|
||||
"tags": lora["tags"],
|
||||
"modelDescription": lora["modelDescription"],
|
||||
"usage_tips": lora["usage_tips"],
|
||||
"notes": lora["notes"],
|
||||
"modified": lora["modified"],
|
||||
"from_civitai": lora.get("from_civitai", True),
|
||||
"civitai": self._filter_civitai_data(lora.get("civitai", {}))
|
||||
}
|
||||
|
||||
def _filter_civitai_data(self, data: Dict) -> Dict:
|
||||
"""Filter relevant fields from CivitAI data"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
fields = [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images"
|
||||
]
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
|
||||
async def handle_loras_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Check if the LoraScanner is initializing
|
||||
# It's initializing if the cache object doesn't exist yet,
|
||||
# OR if the scanner explicitly says it's initializing (background task running).
|
||||
is_initializing = (
|
||||
self.scanner._cache is None or self.scanner.is_initializing()
|
||||
)
|
||||
|
||||
if is_initializing:
|
||||
# If still initializing, return loading page
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
|
||||
logger.info("Loras page is initializing, returning loading page")
|
||||
else:
|
||||
# Normal flow - get data from initialized cache
|
||||
try:
|
||||
cache = await self.scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=cache.folders,
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading cache data: {cache_error}")
|
||||
template = self.template_env.get_template('loras.html')
|
||||
rendered = template.render(
|
||||
folders=[],
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling loras request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading loras page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras/recipes request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# Skip initialization check and directly try to get cached data
|
||||
try:
|
||||
# Recipe scanner will initialize cache if needed
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
recipes=[], # Frontend will load recipes via API
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading recipe cache data: {cache_error}")
|
||||
# Still keep error handling - show initializing page on error
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request
|
||||
)
|
||||
logger.info("Recipe cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling recipes request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading recipes page",
|
||||
status=500
|
||||
)
|
||||
|
||||
def _format_recipe_file_url(self, file_path: str) -> str:
|
||||
"""Format file path for recipe image as a URL - same as in recipe_routes"""
|
||||
try:
|
||||
# Return the file URL directly for the first lora root's preview
|
||||
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/')
|
||||
if file_path.replace(os.sep, '/').startswith(recipes_dir):
|
||||
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/')
|
||||
return f"/loras_static/root1/preview/{relative_path}"
|
||||
|
||||
# If not in recipes dir, try to create a valid URL from the file path
|
||||
file_name = os.path.basename(file_path)
|
||||
return f"/loras_static/root1/preview/recipes/{file_name}"
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting recipe file URL: {e}", exc_info=True)
|
||||
return '/loras_static/images/no-preview.png' # Return default image on error
|
||||
|
||||
def setup_routes(self, app: web.Application):
|
||||
"""Register routes with the application"""
|
||||
# Add an app startup handler to initialize services
|
||||
app.on_startup.append(self._on_startup)
|
||||
"""Setup LoRA routes"""
|
||||
# Schedule service initialization on app startup
|
||||
app.on_startup.append(lambda _: self.initialize_services())
|
||||
|
||||
# Register routes
|
||||
app.router.add_get('/loras', self.handle_loras_page)
|
||||
app.router.add_get('/loras/recipes', self.handle_recipes_page)
|
||||
# Setup common routes with 'loras' prefix (includes page route)
|
||||
super().setup_routes(app, 'loras')
|
||||
|
||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
||||
"""Setup LoRA-specific routes"""
|
||||
# LoRA-specific query routes
|
||||
app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts)
|
||||
app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words)
|
||||
app.router.add_get(f'/api/{prefix}/usage-tips-by-path', self.get_lora_usage_tips_by_path)
|
||||
|
||||
async def _on_startup(self, app):
|
||||
"""Initialize services when the app starts"""
|
||||
await self.init_services()
|
||||
# CivitAI integration with LoRA-specific validation
|
||||
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_lora)
|
||||
app.router.add_get(f'/api/{prefix}/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
|
||||
app.router.add_get(f'/api/{prefix}/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
|
||||
|
||||
# ComfyUI integration
|
||||
app.router.add_post(f'/api/{prefix}/get_trigger_words', self.get_trigger_words)
|
||||
|
||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||
"""Parse LoRA-specific parameters"""
|
||||
params = {}
|
||||
|
||||
# LoRA-specific parameters
|
||||
if 'first_letter' in request.query:
|
||||
params['first_letter'] = request.query.get('first_letter')
|
||||
|
||||
# Handle fuzzy search parameter name variation
|
||||
if request.query.get('fuzzy') == 'true':
|
||||
params['fuzzy_search'] = True
|
||||
|
||||
# Handle additional filter parameters for LoRAs
|
||||
if 'lora_hash' in request.query:
|
||||
if not params.get('hash_filters'):
|
||||
params['hash_filters'] = {}
|
||||
params['hash_filters']['single_hash'] = request.query['lora_hash'].lower()
|
||||
elif 'lora_hashes' in request.query:
|
||||
if not params.get('hash_filters'):
|
||||
params['hash_filters'] = {}
|
||||
params['hash_filters']['multiple_hashes'] = [h.lower() for h in request.query['lora_hashes'].split(',')]
|
||||
|
||||
return params
|
||||
|
||||
# LoRA-specific route handlers
|
||||
async def get_letter_counts(self, request: web.Request) -> web.Response:
|
||||
"""Get count of LoRAs for each letter of the alphabet"""
|
||||
try:
|
||||
letter_counts = await self.service.get_letter_counts()
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'letter_counts': letter_counts
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting letter counts: {e}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_notes(self, request: web.Request) -> web.Response:
|
||||
"""Get notes for a specific LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
notes = await self.service.get_lora_notes(lora_name)
|
||||
if notes is not None:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'notes': notes
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'LoRA not found in cache'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora notes: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
trigger_words = await self.service.get_lora_trigger_words(lora_name)
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'trigger_words': trigger_words
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora trigger words: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_usage_tips_by_path(self, request: web.Request) -> web.Response:
|
||||
"""Get usage tips for a LoRA by its relative path"""
|
||||
try:
|
||||
relative_path = request.query.get('relative_path')
|
||||
if not relative_path:
|
||||
return web.Response(text='Relative path is required', status=400)
|
||||
|
||||
usage_tips = await self.service.get_lora_usage_tips_by_relative_path(relative_path)
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'usage_tips': usage_tips or ''
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora usage tips by path: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_preview_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the static preview URL for a LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
preview_url = await self.service.get_lora_preview_url(lora_name)
|
||||
if preview_url:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'preview_url': preview_url
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No preview URL found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora preview URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_lora_civitai_url(self, request: web.Request) -> web.Response:
|
||||
"""Get the Civitai URL for a LoRA file"""
|
||||
try:
|
||||
lora_name = request.query.get('name')
|
||||
if not lora_name:
|
||||
return web.Response(text='Lora file name is required', status=400)
|
||||
|
||||
result = await self.service.get_lora_civitai_url(lora_name)
|
||||
if result['civitai_url']:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
**result
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No Civitai data found for the specified lora'
|
||||
}, status=404)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lora Civitai URL: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
# CivitAI integration methods
|
||||
async def get_civitai_versions_lora(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai LoRA model with local availability info"""
|
||||
try:
|
||||
model_id = request.match_info['model_id']
|
||||
response = await self.civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be LORA, LoCon, or DORA
|
||||
from ..utils.constants import VALID_LORA_TYPES
|
||||
if model_type.lower() not in VALID_LORA_TYPES:
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the model file (type="Model") in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.service.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.service.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching LoRA model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
|
||||
"""Get CivitAI model details by model version ID"""
|
||||
try:
|
||||
model_version_id = request.match_info.get('modelVersionId')
|
||||
|
||||
# Get model details from Civitai API
|
||||
model, error_msg = await self.civitai_client.get_model_version_info(model_version_id)
|
||||
|
||||
if not model:
|
||||
# Log warning for failed model retrieval
|
||||
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
|
||||
|
||||
# Determine status code based on error message
|
||||
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
|
||||
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": error_msg or "Failed to fetch model information"
|
||||
}, status=status_code)
|
||||
|
||||
return web.json_response(model)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model details: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
|
||||
"""Get CivitAI model details by hash"""
|
||||
try:
|
||||
hash = request.match_info.get('hash')
|
||||
model = await self.civitai_client.get_model_by_hash(hash)
|
||||
return web.json_response(model)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model details by hash: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
async def get_trigger_words(self, request: web.Request) -> web.Response:
|
||||
"""Get trigger words for specified LoRA models"""
|
||||
try:
|
||||
json_data = await request.json()
|
||||
lora_names = json_data.get("lora_names", [])
|
||||
node_ids = json_data.get("node_ids", [])
|
||||
|
||||
all_trigger_words = []
|
||||
for lora_name in lora_names:
|
||||
_, trigger_words = get_lora_info(lora_name)
|
||||
all_trigger_words.extend(trigger_words)
|
||||
|
||||
# Format the trigger words
|
||||
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
|
||||
|
||||
# Send update to all connected trigger word toggle nodes
|
||||
for node_id in node_ids:
|
||||
PromptServer.instance.send_sync("trigger_word_update", {
|
||||
"id": node_id,
|
||||
"message": trigger_words_text
|
||||
})
|
||||
|
||||
return web.json_response({"success": True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting trigger words: {e}")
|
||||
return web.json_response({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}, status=500)
|
||||
|
||||
@@ -10,6 +10,7 @@ from ..utils.usage_stats import UsageStats
|
||||
from ..utils.lora_metadata import extract_trained_words
|
||||
from ..config import config
|
||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -90,6 +91,8 @@ class MiscRoutes:
|
||||
# Add new route for clearing cache
|
||||
app.router.add_post('/api/clear-cache', MiscRoutes.clear_cache)
|
||||
|
||||
app.router.add_get('/api/health-check', lambda request: web.json_response({'status': 'ok'}))
|
||||
|
||||
# Usage stats routes
|
||||
app.router.add_post('/api/update-usage-stats', MiscRoutes.update_usage_stats)
|
||||
app.router.add_get('/api/get-usage-stats', MiscRoutes.get_usage_stats)
|
||||
@@ -106,6 +109,9 @@ class MiscRoutes:
|
||||
# Node registry endpoints
|
||||
app.router.add_post('/api/register-nodes', MiscRoutes.register_nodes)
|
||||
app.router.add_get('/api/get-registry', MiscRoutes.get_registry)
|
||||
|
||||
# Add new route for checking if a model exists in the library
|
||||
app.router.add_get('/api/check-model-exists', MiscRoutes.check_model_exists)
|
||||
|
||||
@staticmethod
|
||||
async def clear_cache(request):
|
||||
@@ -160,6 +166,9 @@ class MiscRoutes:
|
||||
|
||||
# Validate and update settings
|
||||
for key, value in data.items():
|
||||
if value == settings.get(key):
|
||||
# No change, skip
|
||||
continue
|
||||
# Special handling for example_images_path - verify path exists
|
||||
if key == 'example_images_path' and value:
|
||||
if not os.path.exists(value):
|
||||
@@ -580,3 +589,111 @@ class MiscRoutes:
|
||||
'error': 'Internal Error',
|
||||
'message': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def check_model_exists(request):
|
||||
"""
|
||||
Check if a model with specified modelId and optionally modelVersionId exists in the library
|
||||
|
||||
Expects query parameters:
|
||||
- modelId: int - Civitai model ID (required)
|
||||
- modelVersionId: int - Civitai model version ID (optional)
|
||||
|
||||
Returns:
|
||||
- If modelVersionId is provided: JSON with a boolean 'exists' field
|
||||
- If modelVersionId is not provided: JSON with a list of modelVersionIds that exist in the library
|
||||
"""
|
||||
try:
|
||||
# Get the modelId and modelVersionId from query parameters
|
||||
model_id_str = request.query.get('modelId')
|
||||
model_version_id_str = request.query.get('modelVersionId')
|
||||
|
||||
# Validate modelId parameter (required)
|
||||
if not model_id_str:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing required parameter: modelId'
|
||||
}, status=400)
|
||||
|
||||
try:
|
||||
# Convert modelId to integer
|
||||
model_id = int(model_id_str)
|
||||
except ValueError:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Parameter modelId must be an integer'
|
||||
}, status=400)
|
||||
|
||||
# Get all scanners
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
# If modelVersionId is provided, check for specific version
|
||||
if model_version_id_str:
|
||||
try:
|
||||
model_version_id = int(model_version_id_str)
|
||||
except ValueError:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Parameter modelVersionId must be an integer'
|
||||
}, status=400)
|
||||
|
||||
# Check lora scanner first
|
||||
exists = False
|
||||
model_type = None
|
||||
|
||||
if await lora_scanner.check_model_version_exists(model_version_id):
|
||||
exists = True
|
||||
model_type = 'lora'
|
||||
elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_version_id):
|
||||
exists = True
|
||||
model_type = 'checkpoint'
|
||||
elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_version_id):
|
||||
exists = True
|
||||
model_type = 'embedding'
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'exists': exists,
|
||||
'modelType': model_type if exists else None
|
||||
})
|
||||
|
||||
# If modelVersionId is not provided, return all version IDs for the model
|
||||
else:
|
||||
lora_versions = await lora_scanner.get_model_versions_by_id(model_id)
|
||||
checkpoint_versions = []
|
||||
embedding_versions = []
|
||||
|
||||
# 优先lora,其次checkpoint,最后embedding
|
||||
if not lora_versions:
|
||||
checkpoint_versions = await checkpoint_scanner.get_model_versions_by_id(model_id)
|
||||
if not lora_versions and not checkpoint_versions:
|
||||
embedding_versions = await embedding_scanner.get_model_versions_by_id(model_id)
|
||||
|
||||
model_type = None
|
||||
versions = []
|
||||
|
||||
if lora_versions:
|
||||
model_type = 'lora'
|
||||
versions = lora_versions
|
||||
elif checkpoint_versions:
|
||||
model_type = 'checkpoint'
|
||||
versions = checkpoint_versions
|
||||
elif embedding_versions:
|
||||
model_type = 'embedding'
|
||||
versions = embedding_versions
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'modelId': model_id,
|
||||
'modelType': model_type,
|
||||
'versions': versions
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check model existence: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
import time
|
||||
import base64
|
||||
import jinja2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import io
|
||||
import logging
|
||||
from aiohttp import web
|
||||
@@ -16,12 +16,13 @@ from ..utils.exif_utils import ExifUtils
|
||||
from ..recipes import RecipeParserFactory
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..config import config
|
||||
|
||||
# Check if running in standalone mode
|
||||
standalone_mode = 'nodes' not in sys.modules
|
||||
|
||||
from ..utils.utils import download_civitai_image
|
||||
from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import
|
||||
|
||||
# Only import MetadataRegistry in non-standalone mode
|
||||
@@ -40,7 +41,10 @@ class RecipeRoutes:
|
||||
# Initialize service references as None, will be set during async init
|
||||
self.recipe_scanner = None
|
||||
self.civitai_client = None
|
||||
# Remove WorkflowParser instance
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True
|
||||
)
|
||||
|
||||
# Pre-warm the cache
|
||||
self._init_cache_task = None
|
||||
@@ -54,6 +58,8 @@ class RecipeRoutes:
|
||||
def setup_routes(cls, app: web.Application):
|
||||
"""Register API routes"""
|
||||
routes = cls()
|
||||
app.router.add_get('/loras/recipes', routes.handle_recipes_page)
|
||||
|
||||
app.router.add_get('/api/recipes', routes.get_recipes)
|
||||
app.router.add_get('/api/recipe/{recipe_id}', routes.get_recipe_detail)
|
||||
app.router.add_post('/api/recipes/analyze-image', routes.analyze_recipe_image)
|
||||
@@ -115,6 +121,61 @@ class RecipeRoutes:
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True)
|
||||
|
||||
async def handle_recipes_page(self, request: web.Request) -> web.Response:
|
||||
"""Handle GET /loras/recipes request"""
|
||||
try:
|
||||
# Ensure services are initialized
|
||||
await self.init_services()
|
||||
|
||||
# 获取用户语言设置
|
||||
user_language = settings.get('language', 'en')
|
||||
|
||||
# 设置服务端i18n语言
|
||||
server_i18n.set_locale(user_language)
|
||||
|
||||
# 为模板环境添加i18n过滤器
|
||||
if not hasattr(self.template_env, '_i18n_filter_added'):
|
||||
self.template_env.filters['t'] = server_i18n.create_template_filter()
|
||||
self.template_env._i18n_filter_added = True
|
||||
|
||||
# Skip initialization check and directly try to get cached data
|
||||
try:
|
||||
# Recipe scanner will initialize cache if needed
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=False)
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
recipes=[], # Frontend will load recipes via API
|
||||
is_initializing=False,
|
||||
settings=settings,
|
||||
request=request,
|
||||
# 添加服务端翻译函数
|
||||
t=server_i18n.get_translation,
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(f"Error loading recipe cache data: {cache_error}")
|
||||
# Still keep error handling - show initializing page on error
|
||||
template = self.template_env.get_template('recipes.html')
|
||||
rendered = template.render(
|
||||
is_initializing=True,
|
||||
settings=settings,
|
||||
request=request,
|
||||
# 添加服务端翻译函数
|
||||
t=server_i18n.get_translation,
|
||||
)
|
||||
logger.info("Recipe cache error, returning initialization page")
|
||||
|
||||
return web.Response(
|
||||
text=rendered,
|
||||
content_type='text/html'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling recipes request: {e}", exc_info=True)
|
||||
return web.Response(
|
||||
text="Error loading recipes page",
|
||||
status=500
|
||||
)
|
||||
|
||||
async def get_recipes(self, request: web.Request) -> web.Response:
|
||||
"""API endpoint for getting paginated recipes"""
|
||||
@@ -330,16 +391,6 @@ class RecipeRoutes:
|
||||
# Use meta field from image_info as metadata
|
||||
if 'meta' in image_info:
|
||||
metadata = image_info['meta']
|
||||
|
||||
else:
|
||||
# Not a Civitai image URL, use the original download method
|
||||
temp_path = download_civitai_image(url)
|
||||
|
||||
if not temp_path:
|
||||
return web.json_response({
|
||||
"error": "Failed to download image from URL",
|
||||
"loras": []
|
||||
}, status=400)
|
||||
|
||||
# If metadata wasn't obtained from Civitai API, extract it from the image
|
||||
if metadata is None:
|
||||
@@ -592,21 +643,6 @@ class RecipeRoutes:
|
||||
image = base64.b64decode(image_base64)
|
||||
except Exception as e:
|
||||
return web.json_response({"error": f"Invalid base64 image data: {str(e)}"}, status=400)
|
||||
elif image_url:
|
||||
# Download image from URL
|
||||
temp_path = download_civitai_image(image_url)
|
||||
if not temp_path:
|
||||
return web.json_response({"error": "Failed to download image from URL"}, status=400)
|
||||
|
||||
# Read the downloaded image
|
||||
with open(temp_path, 'rb') as f:
|
||||
image = f.read()
|
||||
|
||||
# Clean up temp file
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
return web.json_response({"error": "No image data provided"}, status=400)
|
||||
|
||||
@@ -1018,6 +1054,8 @@ class RecipeRoutes:
|
||||
shape_info = tensor_image.shape
|
||||
logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}")
|
||||
|
||||
import torch
|
||||
|
||||
# Convert tensor to numpy array
|
||||
if isinstance(tensor_image, torch.Tensor):
|
||||
image_np = tensor_image.cpu().numpy()
|
||||
@@ -1100,7 +1138,7 @@ class RecipeRoutes:
|
||||
for lora_name, lora_strength in lora_matches:
|
||||
try:
|
||||
# Get lora info from scanner
|
||||
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora_name)
|
||||
lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora_name)
|
||||
|
||||
# Create lora entry
|
||||
lora_entry = {
|
||||
@@ -1119,7 +1157,7 @@ class RecipeRoutes:
|
||||
# Get base model from lora scanner for the available loras
|
||||
base_model_counts = {}
|
||||
for lora in loras_data:
|
||||
lora_info = await self.recipe_scanner._lora_scanner.get_lora_info_by_name(lora.get("file_name", ""))
|
||||
lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora.get("file_name", ""))
|
||||
if lora_info and "base_model" in lora_info:
|
||||
base_model = lora_info["base_model"]
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
@@ -1209,7 +1247,7 @@ class RecipeRoutes:
|
||||
if lora.get("isDeleted", False):
|
||||
continue
|
||||
|
||||
if not self.recipe_scanner._lora_scanner.has_lora_hash(lora.get("hash", "")):
|
||||
if not self.recipe_scanner._lora_scanner.has_hash(lora.get("hash", "")):
|
||||
continue
|
||||
|
||||
# Get the strength
|
||||
@@ -1317,7 +1355,7 @@ class RecipeRoutes:
|
||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||
|
||||
# Find target LoRA by name
|
||||
target_lora = await lora_scanner.get_lora_info_by_name(target_name)
|
||||
target_lora = await lora_scanner.get_model_info_by_name(target_name)
|
||||
if not target_lora:
|
||||
return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404)
|
||||
|
||||
@@ -1429,9 +1467,9 @@ class RecipeRoutes:
|
||||
if 'loras' in recipe:
|
||||
for lora in recipe['loras']:
|
||||
if 'hash' in lora and lora['hash']:
|
||||
lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_lora_hash(lora['hash'].lower())
|
||||
lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_hash(lora['hash'].lower())
|
||||
lora['preview_url'] = self.recipe_scanner._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
|
||||
lora['localPath'] = self.recipe_scanner._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
|
||||
lora['localPath'] = self.recipe_scanner._lora_scanner.get_path_by_hash(lora['hash'].lower())
|
||||
|
||||
# Ensure file_url is set (needed by frontend)
|
||||
if 'file_path' in recipe:
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Dict, List, Any
|
||||
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.usage_stats import UsageStats
|
||||
|
||||
@@ -20,6 +21,7 @@ class StatsRoutes:
|
||||
def __init__(self):
|
||||
self.lora_scanner = None
|
||||
self.checkpoint_scanner = None
|
||||
self.embedding_scanner = None
|
||||
self.usage_stats = None
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
@@ -30,6 +32,7 @@ class StatsRoutes:
|
||||
"""Initialize services from ServiceRegistry"""
|
||||
self.lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
self.checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
self.embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
self.usage_stats = UsageStats()
|
||||
|
||||
async def handle_stats_page(self, request: web.Request) -> web.Response:
|
||||
@@ -49,13 +52,30 @@ class StatsRoutes:
|
||||
(hasattr(self.checkpoint_scanner, '_is_initializing') and self.checkpoint_scanner._is_initializing)
|
||||
)
|
||||
|
||||
is_initializing = lora_initializing or checkpoint_initializing
|
||||
embedding_initializing = (
|
||||
self.embedding_scanner._cache is None or
|
||||
(hasattr(self.embedding_scanner, 'is_initializing') and self.embedding_scanner.is_initializing())
|
||||
)
|
||||
|
||||
is_initializing = lora_initializing or checkpoint_initializing or embedding_initializing
|
||||
|
||||
# 获取用户语言设置
|
||||
user_language = settings.get('language', 'en')
|
||||
|
||||
# 设置服务端i18n语言
|
||||
server_i18n.set_locale(user_language)
|
||||
|
||||
# 为模板环境添加i18n过滤器
|
||||
if not hasattr(self.template_env, '_i18n_filter_added'):
|
||||
self.template_env.filters['t'] = server_i18n.create_template_filter()
|
||||
self.template_env._i18n_filter_added = True
|
||||
|
||||
template = self.template_env.get_template('statistics.html')
|
||||
rendered = template.render(
|
||||
is_initializing=is_initializing,
|
||||
settings=settings,
|
||||
request=request
|
||||
request=request,
|
||||
t=server_i18n.get_translation,
|
||||
)
|
||||
|
||||
return web.Response(
|
||||
@@ -85,21 +105,29 @@ class StatsRoutes:
|
||||
checkpoint_count = len(checkpoint_cache.raw_data)
|
||||
checkpoint_size = sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data)
|
||||
|
||||
# Get Embedding statistics
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
embedding_count = len(embedding_cache.raw_data)
|
||||
embedding_size = sum(emb.get('size', 0) for emb in embedding_cache.raw_data)
|
||||
|
||||
# Get usage statistics
|
||||
usage_data = await self.usage_stats.get_stats()
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'total_models': lora_count + checkpoint_count,
|
||||
'total_models': lora_count + checkpoint_count + embedding_count,
|
||||
'lora_count': lora_count,
|
||||
'checkpoint_count': checkpoint_count,
|
||||
'total_size': lora_size + checkpoint_size,
|
||||
'embedding_count': embedding_count,
|
||||
'total_size': lora_size + checkpoint_size + embedding_size,
|
||||
'lora_size': lora_size,
|
||||
'checkpoint_size': checkpoint_size,
|
||||
'embedding_size': embedding_size,
|
||||
'total_generations': usage_data.get('total_executions', 0),
|
||||
'unused_loras': self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {})),
|
||||
'unused_checkpoints': self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {}))
|
||||
'unused_checkpoints': self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {})),
|
||||
'unused_embeddings': self._count_unused_models(embedding_cache.raw_data, usage_data.get('embeddings', {}))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -121,14 +149,17 @@ class StatsRoutes:
|
||||
# Get model data for enrichment
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Create hash to model mapping
|
||||
lora_map = {lora['sha256']: lora for lora in lora_cache.raw_data}
|
||||
checkpoint_map = {cp['sha256']: cp for cp in checkpoint_cache.raw_data}
|
||||
embedding_map = {emb['sha256']: emb for emb in embedding_cache.raw_data}
|
||||
|
||||
# Prepare top used models
|
||||
top_loras = self._get_top_used_models(usage_data.get('loras', {}), lora_map, 10)
|
||||
top_checkpoints = self._get_top_used_models(usage_data.get('checkpoints', {}), checkpoint_map, 10)
|
||||
top_embeddings = self._get_top_used_models(usage_data.get('embeddings', {}), embedding_map, 10)
|
||||
|
||||
# Prepare usage timeline (last 30 days)
|
||||
timeline = self._get_usage_timeline(usage_data, 30)
|
||||
@@ -138,6 +169,7 @@ class StatsRoutes:
|
||||
'data': {
|
||||
'top_loras': top_loras,
|
||||
'top_checkpoints': top_checkpoints,
|
||||
'top_embeddings': top_embeddings,
|
||||
'usage_timeline': timeline,
|
||||
'total_executions': usage_data.get('total_executions', 0)
|
||||
}
|
||||
@@ -158,16 +190,19 @@ class StatsRoutes:
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Count by base model
|
||||
lora_base_models = Counter(lora.get('base_model', 'Unknown') for lora in lora_cache.raw_data)
|
||||
checkpoint_base_models = Counter(cp.get('base_model', 'Unknown') for cp in checkpoint_cache.raw_data)
|
||||
embedding_base_models = Counter(emb.get('base_model', 'Unknown') for emb in embedding_cache.raw_data)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'loras': dict(lora_base_models),
|
||||
'checkpoints': dict(checkpoint_base_models)
|
||||
'checkpoints': dict(checkpoint_base_models),
|
||||
'embeddings': dict(embedding_base_models)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -186,6 +221,7 @@ class StatsRoutes:
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Count tag frequencies
|
||||
all_tags = []
|
||||
@@ -193,6 +229,8 @@ class StatsRoutes:
|
||||
all_tags.extend(lora.get('tags', []))
|
||||
for cp in checkpoint_cache.raw_data:
|
||||
all_tags.extend(cp.get('tags', []))
|
||||
for emb in embedding_cache.raw_data:
|
||||
all_tags.extend(emb.get('tags', []))
|
||||
|
||||
tag_counts = Counter(all_tags)
|
||||
|
||||
@@ -225,6 +263,7 @@ class StatsRoutes:
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
# Create models with usage data
|
||||
lora_storage = []
|
||||
@@ -255,15 +294,31 @@ class StatsRoutes:
|
||||
'base_model': cp.get('base_model', 'Unknown')
|
||||
})
|
||||
|
||||
embedding_storage = []
|
||||
for emb in embedding_cache.raw_data:
|
||||
usage_count = 0
|
||||
if emb['sha256'] in usage_data.get('embeddings', {}):
|
||||
usage_count = usage_data['embeddings'][emb['sha256']].get('total', 0)
|
||||
|
||||
embedding_storage.append({
|
||||
'name': emb['model_name'],
|
||||
'size': emb.get('size', 0),
|
||||
'usage_count': usage_count,
|
||||
'folder': emb.get('folder', ''),
|
||||
'base_model': emb.get('base_model', 'Unknown')
|
||||
})
|
||||
|
||||
# Sort by size
|
||||
lora_storage.sort(key=lambda x: x['size'], reverse=True)
|
||||
checkpoint_storage.sort(key=lambda x: x['size'], reverse=True)
|
||||
embedding_storage.sort(key=lambda x: x['size'], reverse=True)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'data': {
|
||||
'loras': lora_storage[:20], # Top 20 by size
|
||||
'checkpoints': checkpoint_storage[:20]
|
||||
'checkpoints': checkpoint_storage[:20],
|
||||
'embeddings': embedding_storage[:20]
|
||||
}
|
||||
})
|
||||
|
||||
@@ -285,15 +340,18 @@ class StatsRoutes:
|
||||
# Get model data
|
||||
lora_cache = await self.lora_scanner.get_cached_data()
|
||||
checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
|
||||
embedding_cache = await self.embedding_scanner.get_cached_data()
|
||||
|
||||
insights = []
|
||||
|
||||
# Calculate unused models
|
||||
unused_loras = self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {}))
|
||||
unused_checkpoints = self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {}))
|
||||
unused_embeddings = self._count_unused_models(embedding_cache.raw_data, usage_data.get('embeddings', {}))
|
||||
|
||||
total_loras = len(lora_cache.raw_data)
|
||||
total_checkpoints = len(checkpoint_cache.raw_data)
|
||||
total_embeddings = len(embedding_cache.raw_data)
|
||||
|
||||
if total_loras > 0:
|
||||
unused_lora_percent = (unused_loras / total_loras) * 100
|
||||
@@ -315,9 +373,20 @@ class StatsRoutes:
|
||||
'suggestion': 'Review and consider removing checkpoints you no longer need.'
|
||||
})
|
||||
|
||||
if total_embeddings > 0:
|
||||
unused_embedding_percent = (unused_embeddings / total_embeddings) * 100
|
||||
if unused_embedding_percent > 50:
|
||||
insights.append({
|
||||
'type': 'warning',
|
||||
'title': 'High Number of Unused Embeddings',
|
||||
'description': f'{unused_embedding_percent:.1f}% of your embeddings ({unused_embeddings}/{total_embeddings}) have never been used.',
|
||||
'suggestion': 'Consider organizing or archiving unused embeddings to optimize your collection.'
|
||||
})
|
||||
|
||||
# Storage insights
|
||||
total_size = sum(lora.get('size', 0) for lora in lora_cache.raw_data) + \
|
||||
sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data)
|
||||
sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data) + \
|
||||
sum(emb.get('size', 0) for emb in embedding_cache.raw_data)
|
||||
|
||||
if total_size > 100 * 1024 * 1024 * 1024: # 100GB
|
||||
insights.append({
|
||||
@@ -390,6 +459,7 @@ class StatsRoutes:
|
||||
|
||||
lora_usage = 0
|
||||
checkpoint_usage = 0
|
||||
embedding_usage = 0
|
||||
|
||||
# Count usage for this date
|
||||
for model_usage in usage_data.get('loras', {}).values():
|
||||
@@ -400,11 +470,16 @@ class StatsRoutes:
|
||||
if isinstance(model_usage, dict) and 'history' in model_usage:
|
||||
checkpoint_usage += model_usage['history'].get(date_str, 0)
|
||||
|
||||
for model_usage in usage_data.get('embeddings', {}).values():
|
||||
if isinstance(model_usage, dict) and 'history' in model_usage:
|
||||
embedding_usage += model_usage['history'].get(date_str, 0)
|
||||
|
||||
timeline.append({
|
||||
'date': date_str,
|
||||
'lora_usage': lora_usage,
|
||||
'checkpoint_usage': checkpoint_usage,
|
||||
'total_usage': lora_usage + checkpoint_usage
|
||||
'embedding_usage': embedding_usage,
|
||||
'total_usage': lora_usage + checkpoint_usage + embedding_usage
|
||||
})
|
||||
|
||||
return list(reversed(timeline)) # Oldest to newest
|
||||
|
||||
@@ -2,10 +2,13 @@ import os
|
||||
import aiohttp
|
||||
import logging
|
||||
import toml
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
import git
|
||||
import zipfile
|
||||
import shutil
|
||||
import tempfile
|
||||
from aiohttp import web
|
||||
from typing import Dict, Any, List
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,6 +20,7 @@ class UpdateRoutes:
|
||||
"""Register update check routes"""
|
||||
app.router.add_get('/api/check-updates', UpdateRoutes.check_updates)
|
||||
app.router.add_get('/api/version-info', UpdateRoutes.get_version_info)
|
||||
app.router.add_post('/api/perform-update', UpdateRoutes.perform_update)
|
||||
|
||||
@staticmethod
|
||||
async def check_updates(request):
|
||||
@@ -25,6 +29,8 @@ class UpdateRoutes:
|
||||
Returns update status and version information
|
||||
"""
|
||||
try:
|
||||
nightly = request.query.get('nightly', 'false').lower() == 'true'
|
||||
|
||||
# Read local version from pyproject.toml
|
||||
local_version = UpdateRoutes._get_local_version()
|
||||
|
||||
@@ -32,13 +38,21 @@ class UpdateRoutes:
|
||||
git_info = UpdateRoutes._get_git_info()
|
||||
|
||||
# Fetch remote version from GitHub
|
||||
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
||||
if nightly:
|
||||
remote_version, changelog = await UpdateRoutes._get_nightly_version()
|
||||
else:
|
||||
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
||||
|
||||
# Compare versions
|
||||
update_available = UpdateRoutes._compare_versions(
|
||||
local_version.replace('v', ''),
|
||||
remote_version.replace('v', '')
|
||||
)
|
||||
if nightly:
|
||||
# For nightly, compare commit hashes
|
||||
update_available = UpdateRoutes._compare_nightly_versions(git_info, remote_version)
|
||||
else:
|
||||
# For stable, compare semantic versions
|
||||
update_available = UpdateRoutes._compare_versions(
|
||||
local_version.replace('v', ''),
|
||||
remote_version.replace('v', '')
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
@@ -46,7 +60,8 @@ class UpdateRoutes:
|
||||
'latest_version': remote_version,
|
||||
'update_available': update_available,
|
||||
'changelog': changelog,
|
||||
'git_info': git_info
|
||||
'git_info': git_info,
|
||||
'nightly': nightly
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
@@ -55,7 +70,7 @@ class UpdateRoutes:
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def get_version_info(request):
|
||||
"""
|
||||
@@ -84,6 +99,251 @@ class UpdateRoutes:
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def perform_update(request):
|
||||
"""
|
||||
Perform Git-based update to latest release tag or main branch.
|
||||
If .git is missing, fallback to ZIP download.
|
||||
"""
|
||||
try:
|
||||
body = await request.json() if request.has_body else {}
|
||||
nightly = body.get('nightly', False)
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
plugin_root = os.path.dirname(os.path.dirname(current_dir))
|
||||
|
||||
settings_path = os.path.join(plugin_root, 'settings.json')
|
||||
settings_backup = None
|
||||
if os.path.exists(settings_path):
|
||||
with open(settings_path, 'r', encoding='utf-8') as f:
|
||||
settings_backup = f.read()
|
||||
logger.info("Backed up settings.json")
|
||||
|
||||
git_folder = os.path.join(plugin_root, '.git')
|
||||
if os.path.exists(git_folder):
|
||||
# Git update
|
||||
success, new_version = await UpdateRoutes._perform_git_update(plugin_root, nightly)
|
||||
else:
|
||||
# Fallback: Download ZIP and replace files
|
||||
success, new_version = await UpdateRoutes._download_and_replace_zip(plugin_root)
|
||||
|
||||
if settings_backup and success:
|
||||
with open(settings_path, 'w', encoding='utf-8') as f:
|
||||
f.write(settings_backup)
|
||||
logger.info("Restored settings.json")
|
||||
|
||||
if success:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f'Successfully updated to {new_version}',
|
||||
'new_version': new_version
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Failed to complete update'
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to perform update: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
async def _download_and_replace_zip(plugin_root: str) -> tuple[bool, str]:
|
||||
"""
|
||||
Download latest release ZIP from GitHub and replace plugin files.
|
||||
Skips settings.json. Writes extracted file list to .tracking.
|
||||
"""
|
||||
repo_owner = "willmiao"
|
||||
repo_name = "ComfyUI-Lora-Manager"
|
||||
github_api = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(github_api) as resp:
|
||||
if resp.status != 200:
|
||||
logger.error(f"Failed to fetch release info: {resp.status}")
|
||||
return False, ""
|
||||
data = await resp.json()
|
||||
zip_url = data.get("zipball_url")
|
||||
version = data.get("tag_name", "unknown")
|
||||
|
||||
# Download ZIP
|
||||
async with session.get(zip_url) as zip_resp:
|
||||
if zip_resp.status != 200:
|
||||
logger.error(f"Failed to download ZIP: {zip_resp.status}")
|
||||
return False, ""
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
|
||||
tmp_zip.write(await zip_resp.read())
|
||||
zip_path = tmp_zip.name
|
||||
|
||||
UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json'])
|
||||
|
||||
# Extract ZIP to temp dir
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(tmp_dir)
|
||||
# Find extracted folder (GitHub ZIP contains a root folder)
|
||||
extracted_root = next(os.scandir(tmp_dir)).path
|
||||
|
||||
# Copy files, skipping settings.json
|
||||
for item in os.listdir(extracted_root):
|
||||
src = os.path.join(extracted_root, item)
|
||||
dst = os.path.join(plugin_root, item)
|
||||
if os.path.isdir(src):
|
||||
if os.path.exists(dst):
|
||||
shutil.rmtree(dst)
|
||||
shutil.copytree(src, dst, ignore=shutil.ignore_patterns('settings.json'))
|
||||
else:
|
||||
if item == 'settings.json':
|
||||
continue
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
# Write .tracking file: list all files under extracted_root, relative to extracted_root
|
||||
# for ComfyUI Manager to work properly
|
||||
tracking_info_file = os.path.join(plugin_root, '.tracking')
|
||||
tracking_files = []
|
||||
for root, dirs, files in os.walk(extracted_root):
|
||||
for file in files:
|
||||
rel_path = os.path.relpath(os.path.join(root, file), extracted_root)
|
||||
tracking_files.append(rel_path.replace("\\", "/"))
|
||||
with open(tracking_info_file, "w", encoding='utf-8') as file:
|
||||
file.write('\n'.join(tracking_files))
|
||||
|
||||
os.remove(zip_path)
|
||||
logger.info(f"Updated plugin via ZIP to {version}")
|
||||
return True, version
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ZIP update failed: {e}", exc_info=True)
|
||||
return False, ""
|
||||
|
||||
def _clean_plugin_folder(plugin_root, skip_files=None):
|
||||
skip_files = skip_files or []
|
||||
for item in os.listdir(plugin_root):
|
||||
if item in skip_files:
|
||||
continue
|
||||
path = os.path.join(plugin_root, item)
|
||||
if os.path.isdir(path):
|
||||
shutil.rmtree(path)
|
||||
else:
|
||||
os.remove(path)
|
||||
|
||||
@staticmethod
|
||||
async def _get_nightly_version() -> tuple[str, List[str]]:
|
||||
"""
|
||||
Fetch latest commit from main branch
|
||||
"""
|
||||
repo_owner = "willmiao"
|
||||
repo_name = "ComfyUI-Lora-Manager"
|
||||
|
||||
# Use GitHub API to fetch the latest commit from main branch
|
||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"Failed to fetch GitHub commit: {response.status}")
|
||||
return "main", []
|
||||
|
||||
data = await response.json()
|
||||
commit_sha = data.get('sha', '')[:7] # Short hash
|
||||
commit_message = data.get('commit', {}).get('message', '')
|
||||
|
||||
# Format as "main-{short_hash}"
|
||||
version = f"main-{commit_sha}"
|
||||
|
||||
# Use commit message as changelog
|
||||
changelog = [commit_message] if commit_message else []
|
||||
|
||||
return version, changelog
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching nightly version: {e}", exc_info=True)
|
||||
return "main", []
|
||||
|
||||
@staticmethod
|
||||
def _compare_nightly_versions(local_git_info: Dict[str, str], remote_version: str) -> bool:
|
||||
"""
|
||||
Compare local commit hash with remote main branch
|
||||
"""
|
||||
try:
|
||||
local_hash = local_git_info.get('short_hash', 'unknown')
|
||||
if local_hash == 'unknown':
|
||||
return True # Assume update available if we can't get local hash
|
||||
|
||||
# Extract remote hash from version string (format: "main-{hash}")
|
||||
if '-' in remote_version:
|
||||
remote_hash = remote_version.split('-')[-1]
|
||||
return local_hash != remote_hash
|
||||
|
||||
return True # Default to update available
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error comparing nightly versions: {e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _perform_git_update(plugin_root: str, nightly: bool = False) -> tuple[bool, str]:
|
||||
"""
|
||||
Perform Git-based update using GitPython
|
||||
|
||||
Args:
|
||||
plugin_root: Path to the plugin root directory
|
||||
nightly: Whether to update to main branch or latest release
|
||||
|
||||
Returns:
|
||||
tuple: (success, new_version)
|
||||
"""
|
||||
try:
|
||||
# Open the Git repository
|
||||
repo = git.Repo(plugin_root)
|
||||
|
||||
# Fetch latest changes
|
||||
origin = repo.remotes.origin
|
||||
origin.fetch()
|
||||
|
||||
if nightly:
|
||||
# Switch to main branch and pull latest
|
||||
main_branch = 'main'
|
||||
if main_branch not in [branch.name for branch in repo.branches]:
|
||||
# Create local main branch if it doesn't exist
|
||||
repo.create_head(main_branch, origin.refs.main)
|
||||
|
||||
repo.heads[main_branch].checkout()
|
||||
origin.pull(main_branch)
|
||||
|
||||
# Get new commit hash
|
||||
new_version = f"main-{repo.head.commit.hexsha[:7]}"
|
||||
|
||||
else:
|
||||
# Get latest release tag
|
||||
tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True)
|
||||
if not tags:
|
||||
logger.error("No tags found in repository")
|
||||
return False, ""
|
||||
|
||||
latest_tag = tags[0]
|
||||
|
||||
# Checkout to latest tag
|
||||
repo.git.checkout(latest_tag.name)
|
||||
|
||||
new_version = latest_tag.name
|
||||
|
||||
logger.info(f"Successfully updated to {new_version}")
|
||||
return True, new_version
|
||||
|
||||
except git.exc.GitError as e:
|
||||
logger.error(f"Git error during update: {e}")
|
||||
return False, ""
|
||||
except Exception as e:
|
||||
logger.error(f"Error during Git update: {e}")
|
||||
return False, ""
|
||||
|
||||
@staticmethod
|
||||
def _get_local_version() -> str:
|
||||
"""Get local plugin version from pyproject.toml"""
|
||||
@@ -112,65 +372,28 @@ class UpdateRoutes:
|
||||
"""Get Git repository information"""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
plugin_root = os.path.dirname(os.path.dirname(current_dir))
|
||||
|
||||
|
||||
git_info = {
|
||||
'commit_hash': 'unknown',
|
||||
'short_hash': 'unknown',
|
||||
'short_hash': 'stable',
|
||||
'branch': 'unknown',
|
||||
'commit_date': 'unknown'
|
||||
}
|
||||
|
||||
|
||||
try:
|
||||
# Check if we're in a git repository
|
||||
if not os.path.exists(os.path.join(plugin_root, '.git')):
|
||||
return git_info
|
||||
|
||||
# Get current commit hash
|
||||
result = subprocess.run(
|
||||
['git', 'rev-parse', 'HEAD'],
|
||||
cwd=plugin_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
check=False
|
||||
)
|
||||
if result.returncode == 0:
|
||||
git_info['commit_hash'] = result.stdout.strip()
|
||||
git_info['short_hash'] = git_info['commit_hash'][:7]
|
||||
|
||||
# Get current branch name
|
||||
result = subprocess.run(
|
||||
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
|
||||
cwd=plugin_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
check=False
|
||||
)
|
||||
if result.returncode == 0:
|
||||
git_info['branch'] = result.stdout.strip()
|
||||
|
||||
# Get commit date
|
||||
result = subprocess.run(
|
||||
['git', 'show', '-s', '--format=%ci', 'HEAD'],
|
||||
cwd=plugin_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
check=False
|
||||
)
|
||||
if result.returncode == 0:
|
||||
commit_date = result.stdout.strip()
|
||||
# Format the date nicely if possible
|
||||
try:
|
||||
date_obj = datetime.strptime(commit_date, '%Y-%m-%d %H:%M:%S %z')
|
||||
git_info['commit_date'] = date_obj.strftime('%Y-%m-%d')
|
||||
except:
|
||||
git_info['commit_date'] = commit_date
|
||||
|
||||
|
||||
repo = git.Repo(plugin_root)
|
||||
commit = repo.head.commit
|
||||
git_info['commit_hash'] = commit.hexsha
|
||||
git_info['short_hash'] = commit.hexsha[:7]
|
||||
git_info['branch'] = repo.active_branch.name if not repo.head.is_detached else 'detached'
|
||||
git_info['commit_date'] = commit.committed_datetime.strftime('%Y-%m-%d')
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting git info: {e}")
|
||||
|
||||
|
||||
return git_info
|
||||
|
||||
@staticmethod
|
||||
|
||||
451
py/services/base_model_service.py
Normal file
451
py/services/base_model_service.py
Normal file
@@ -0,0 +1,451 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Type
|
||||
import logging
|
||||
import os
|
||||
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from .settings_manager import settings
|
||||
from ..utils.utils import fuzzy_match
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseModelService(ABC):
|
||||
"""Base service class for all model types"""
|
||||
|
||||
def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]):
|
||||
"""Initialize the service
|
||||
|
||||
Args:
|
||||
model_type: Type of model (lora, checkpoint, etc.)
|
||||
scanner: Model scanner instance
|
||||
metadata_class: Metadata class for this model type
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.scanner = scanner
|
||||
self.metadata_class = metadata_class
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
||||
base_models: list = None, tags: list = None,
|
||||
search_options: dict = None, hash_filters: dict = None,
|
||||
favorites_only: bool = False, **kwargs) -> Dict:
|
||||
"""Get paginated and filtered model data
|
||||
|
||||
Args:
|
||||
page: Page number (1-based)
|
||||
page_size: Number of items per page
|
||||
sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc'
|
||||
folder: Folder filter
|
||||
search: Search term
|
||||
fuzzy_search: Whether to use fuzzy search
|
||||
base_models: List of base models to filter by
|
||||
tags: List of tags to filter by
|
||||
search_options: Search options dict
|
||||
hash_filters: Hash filtering options
|
||||
favorites_only: Filter for favorites only
|
||||
**kwargs: Additional model-specific filters
|
||||
|
||||
Returns:
|
||||
Dict containing paginated results
|
||||
"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Parse sort_by into sort_key and order
|
||||
if ':' in sort_by:
|
||||
sort_key, order = sort_by.split(':', 1)
|
||||
sort_key = sort_key.strip()
|
||||
order = order.strip().lower()
|
||||
if order not in ('asc', 'desc'):
|
||||
order = 'asc'
|
||||
else:
|
||||
sort_key = sort_by.strip()
|
||||
order = 'asc'
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False,
|
||||
'recursive': True,
|
||||
}
|
||||
|
||||
# Get the base data set using new sort logic
|
||||
filtered_data = await cache.get_sorted_data(sort_key, order)
|
||||
|
||||
# Apply hash filtering if provided (highest priority)
|
||||
if hash_filters:
|
||||
filtered_data = await self._apply_hash_filters(filtered_data, hash_filters)
|
||||
|
||||
# Jump to pagination for hash filters
|
||||
return self._paginate(filtered_data, page, page_size)
|
||||
|
||||
# Apply common filters
|
||||
filtered_data = await self._apply_common_filters(
|
||||
filtered_data, folder, base_models, tags, favorites_only, search_options
|
||||
)
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
filtered_data = await self._apply_search_filters(
|
||||
filtered_data, search, fuzzy_search, search_options
|
||||
)
|
||||
|
||||
# Apply model-specific filters
|
||||
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||
|
||||
return self._paginate(filtered_data, page, page_size)
|
||||
|
||||
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
||||
"""Apply hash-based filtering"""
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower()
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes)
|
||||
return [
|
||||
item for item in data
|
||||
if item.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
return data
|
||||
|
||||
async def _apply_common_filters(self, data: List[Dict], folder: str = None,
|
||||
base_models: list = None, tags: list = None,
|
||||
favorites_only: bool = False, search_options: dict = None) -> List[Dict]:
|
||||
"""Apply common filters that work across all model types"""
|
||||
# Apply SFW filtering if enabled in settings
|
||||
if settings.get('show_only_sfw', False):
|
||||
data = [
|
||||
item for item in data
|
||||
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply favorites filtering if enabled
|
||||
if favorites_only:
|
||||
data = [
|
||||
item for item in data
|
||||
if item.get('favorite', False) is True
|
||||
]
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if search_options and search_options.get('recursive', True):
|
||||
# Recursive folder filtering - include all subfolders
|
||||
# Ensure we match exact folder or its subfolders by checking path boundaries
|
||||
if folder == "":
|
||||
# Empty folder means root - include all items
|
||||
pass # Don't filter anything
|
||||
else:
|
||||
# Add trailing slash to ensure we match folder boundaries correctly
|
||||
folder_with_separator = folder + "/"
|
||||
data = [
|
||||
item for item in data
|
||||
if (item['folder'] == folder or
|
||||
item['folder'].startswith(folder_with_separator))
|
||||
]
|
||||
else:
|
||||
# Exact folder filtering
|
||||
data = [
|
||||
item for item in data
|
||||
if item['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
data = [
|
||||
item for item in data
|
||||
if item.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
data = [
|
||||
item for item in data
|
||||
if any(tag in item.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
return data
|
||||
|
||||
async def _apply_search_filters(self, data: List[Dict], search: str,
|
||||
fuzzy_search: bool, search_options: dict) -> List[Dict]:
|
||||
"""Apply search filtering"""
|
||||
search_results = []
|
||||
|
||||
for item in data:
|
||||
# Search by file name
|
||||
if search_options.get('filename', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(item.get('file_name', ''), search):
|
||||
search_results.append(item)
|
||||
continue
|
||||
elif search.lower() in item.get('file_name', '').lower():
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
# Search by model name
|
||||
if search_options.get('modelname', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(item.get('model_name', ''), search):
|
||||
search_results.append(item)
|
||||
continue
|
||||
elif search.lower() in item.get('model_name', '').lower():
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
# Search by tags
|
||||
if search_options.get('tags', False) and 'tags' in item:
|
||||
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower())
|
||||
for tag in item['tags']):
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
# Search by creator
|
||||
civitai = item.get('civitai')
|
||||
creator_username = ''
|
||||
if civitai and isinstance(civitai, dict):
|
||||
creator = civitai.get('creator')
|
||||
if creator and isinstance(creator, dict):
|
||||
creator_username = creator.get('username', '')
|
||||
if search_options.get('creator', False) and creator_username:
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(creator_username, search):
|
||||
search_results.append(item)
|
||||
continue
|
||||
elif search.lower() in creator_username.lower():
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
return search_results
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||
return data
|
||||
|
||||
def _paginate(self, data: List[Dict], page: int, page_size: int) -> Dict:
|
||||
"""Apply pagination to filtered data"""
|
||||
total_items = len(data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
return {
|
||||
'items': data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def format_response(self, model_data: Dict) -> Dict:
|
||||
"""Format model data for API response - must be implemented by subclasses"""
|
||||
pass
|
||||
|
||||
# Common service methods that delegate to scanner
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get top tags sorted by frequency"""
|
||||
return await self.scanner.get_top_tags(limit)
|
||||
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict]:
|
||||
"""Get base models sorted by frequency"""
|
||||
return await self.scanner.get_base_models(limit)
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
"""Check if a model with given hash exists"""
|
||||
return self.scanner.has_hash(sha256)
|
||||
|
||||
def get_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a model by its hash"""
|
||||
return self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a model by its file path"""
|
||||
return self.scanner.get_hash_by_path(file_path)
|
||||
|
||||
async def scan_models(self, force_refresh: bool = False, rebuild_cache: bool = False):
|
||||
"""Trigger model scanning"""
|
||||
return await self.scanner.get_cached_data(force_refresh=force_refresh, rebuild_cache=rebuild_cache)
|
||||
|
||||
async def get_model_info_by_name(self, name: str):
|
||||
"""Get model information by name"""
|
||||
return await self.scanner.get_model_info_by_name(name)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get model root directories"""
|
||||
return self.scanner.get_model_roots()
|
||||
|
||||
async def get_folder_tree(self, model_root: str) -> Dict:
|
||||
"""Get hierarchical folder tree for a specific model root"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Build tree structure from folders
|
||||
tree = {}
|
||||
|
||||
for folder in cache.folders:
|
||||
# Check if this folder belongs to the specified model root
|
||||
folder_belongs_to_root = False
|
||||
for root in self.scanner.get_model_roots():
|
||||
if root == model_root:
|
||||
folder_belongs_to_root = True
|
||||
break
|
||||
|
||||
if not folder_belongs_to_root:
|
||||
continue
|
||||
|
||||
# Split folder path into components
|
||||
parts = folder.split('/') if folder else []
|
||||
current_level = tree
|
||||
|
||||
for part in parts:
|
||||
if part not in current_level:
|
||||
current_level[part] = {}
|
||||
current_level = current_level[part]
|
||||
|
||||
return tree
|
||||
|
||||
async def get_unified_folder_tree(self) -> Dict:
|
||||
"""Get unified folder tree across all model roots"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
# Build unified tree structure by analyzing all relative paths
|
||||
unified_tree = {}
|
||||
|
||||
# Get all model roots for path normalization
|
||||
model_roots = self.scanner.get_model_roots()
|
||||
|
||||
for folder in cache.folders:
|
||||
if not folder: # Skip empty folders
|
||||
continue
|
||||
|
||||
# Find which root this folder belongs to by checking the actual file paths
|
||||
# This is a simplified approach - we'll use the folder as-is since it should already be relative
|
||||
relative_path = folder
|
||||
|
||||
# Split folder path into components
|
||||
parts = relative_path.split('/')
|
||||
current_level = unified_tree
|
||||
|
||||
for part in parts:
|
||||
if part not in current_level:
|
||||
current_level[part] = {}
|
||||
current_level = current_level[part]
|
||||
|
||||
return unified_tree
|
||||
|
||||
async def get_model_notes(self, model_name: str) -> Optional[str]:
|
||||
"""Get notes for a specific model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
return model.get('notes', '')
|
||||
|
||||
return None
|
||||
|
||||
async def get_model_preview_url(self, model_name: str) -> Optional[str]:
|
||||
"""Get the static preview URL for a model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
preview_url = model.get('preview_url')
|
||||
if preview_url:
|
||||
from ..config import config
|
||||
return config.get_preview_static_url(preview_url)
|
||||
|
||||
return None
|
||||
|
||||
async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]:
|
||||
"""Get the Civitai URL for a model file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model['file_name'] == model_name:
|
||||
civitai_data = model.get('civitai', {})
|
||||
model_id = civitai_data.get('modelId')
|
||||
version_id = civitai_data.get('id')
|
||||
|
||||
if model_id:
|
||||
civitai_url = f"https://civitai.com/models/{model_id}"
|
||||
if version_id:
|
||||
civitai_url += f"?modelVersionId={version_id}"
|
||||
|
||||
return {
|
||||
'civitai_url': civitai_url,
|
||||
'model_id': str(model_id),
|
||||
'version_id': str(version_id) if version_id else None
|
||||
}
|
||||
|
||||
return {'civitai_url': None, 'model_id': None, 'version_id': None}
|
||||
|
||||
async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
|
||||
"""Get filtered CivitAI metadata for a model by file path"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model.get('file_path') == file_path:
|
||||
return ModelRouteUtils.filter_civitai_data(model.get("civitai", {}))
|
||||
|
||||
return None
|
||||
|
||||
async def get_model_description(self, file_path: str) -> Optional[str]:
|
||||
"""Get model description by file path"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model.get('file_path') == file_path:
|
||||
return model.get('modelDescription', '')
|
||||
|
||||
return None
|
||||
|
||||
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
|
||||
"""Search model relative file paths for autocomplete functionality"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
matching_paths = []
|
||||
search_lower = search_term.lower()
|
||||
|
||||
# Get model roots for path calculation
|
||||
model_roots = self.scanner.get_model_roots()
|
||||
|
||||
for model in cache.raw_data:
|
||||
file_path = model.get('file_path', '')
|
||||
if not file_path:
|
||||
continue
|
||||
|
||||
# Calculate relative path from model root
|
||||
relative_path = None
|
||||
for root in model_roots:
|
||||
# Normalize paths for comparison
|
||||
normalized_root = os.path.normpath(root)
|
||||
normalized_file = os.path.normpath(file_path)
|
||||
|
||||
if normalized_file.startswith(normalized_root):
|
||||
# Remove root and leading separator to get relative path
|
||||
relative_path = normalized_file[len(normalized_root):].lstrip(os.sep)
|
||||
break
|
||||
|
||||
if relative_path and search_lower in relative_path.lower():
|
||||
matching_paths.append(relative_path)
|
||||
|
||||
if len(matching_paths) >= limit * 2: # Get more for better sorting
|
||||
break
|
||||
|
||||
# Sort by relevance (exact matches first, then by length)
|
||||
matching_paths.sort(key=lambda x: (
|
||||
not x.lower().startswith(search_lower), # Exact prefix matches first
|
||||
len(x), # Then by length (shorter first)
|
||||
x.lower() # Then alphabetically
|
||||
))
|
||||
|
||||
return matching_paths[:limit]
|
||||
@@ -1,131 +1,34 @@
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Dict, Optional, Set
|
||||
import folder_paths # type: ignore
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
from .service_registry import ServiceRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointScanner(ModelScanner):
|
||||
"""Service for scanning and managing checkpoint files"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(self, '_initialized'):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
|
||||
super().__init__(
|
||||
model_type="checkpoint",
|
||||
model_class=CheckpointMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
self._checkpoint_roots = self._init_checkpoint_roots()
|
||||
self._initialized = True
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'}
|
||||
super().__init__(
|
||||
model_type="checkpoint",
|
||||
model_class=CheckpointMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
def adjust_metadata(self, metadata, file_path, root_path):
|
||||
if hasattr(metadata, "model_type"):
|
||||
if root_path in config.checkpoints_roots:
|
||||
metadata.model_type = "checkpoint"
|
||||
elif root_path in config.unet_roots:
|
||||
metadata.model_type = "diffusion_model"
|
||||
return metadata
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
def _init_checkpoint_roots(self) -> List[str]:
|
||||
"""Initialize checkpoint roots from ComfyUI settings"""
|
||||
# Get both checkpoint and diffusion_models paths
|
||||
checkpoint_paths = folder_paths.get_folder_paths("checkpoints")
|
||||
diffusion_paths = folder_paths.get_folder_paths("diffusion_models")
|
||||
|
||||
# Combine, normalize and deduplicate paths
|
||||
all_paths = set()
|
||||
for path in checkpoint_paths + diffusion_paths:
|
||||
if os.path.exists(path):
|
||||
norm_path = path.replace(os.sep, "/")
|
||||
all_paths.add(norm_path)
|
||||
|
||||
# Sort for consistent order
|
||||
sorted_paths = sorted(all_paths, key=lambda p: p.lower())
|
||||
|
||||
return sorted_paths
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get checkpoint root directories"""
|
||||
return self._checkpoint_roots
|
||||
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
"""Scan all checkpoint directories and return metadata"""
|
||||
all_checkpoints = []
|
||||
|
||||
# Create scan tasks for each directory
|
||||
scan_tasks = []
|
||||
for root in self._checkpoint_roots:
|
||||
task = asyncio.create_task(self._scan_directory(root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
checkpoints = await task
|
||||
all_checkpoints.extend(checkpoints)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning checkpoint directory: {e}")
|
||||
|
||||
return all_checkpoints
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a directory for checkpoint files"""
|
||||
checkpoints = []
|
||||
original_root = root_path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True):
|
||||
# Check if file has supported extension
|
||||
ext = os.path.splitext(entry.name)[1].lower()
|
||||
if ext in self.file_extensions:
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
await self._process_single_file(file_path, original_root, checkpoints)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
# For directories, continue scanning with original path
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return checkpoints
|
||||
|
||||
async def _process_single_file(self, file_path: str, root_path: str, checkpoints: list):
|
||||
"""Process a single checkpoint file and add to results"""
|
||||
try:
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
checkpoints.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
return config.base_models_roots
|
||||
50
py/services/checkpoint_service.py
Normal file
50
py/services/checkpoint_service.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CheckpointService(BaseModelService):
|
||||
"""Checkpoint-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize Checkpoint service
|
||||
|
||||
Args:
|
||||
scanner: Checkpoint scanner instance
|
||||
"""
|
||||
super().__init__("checkpoint", scanner, CheckpointMetadata)
|
||||
|
||||
async def format_response(self, checkpoint_data: Dict) -> Dict:
|
||||
"""Format Checkpoint data for API response"""
|
||||
return {
|
||||
"model_name": checkpoint_data["model_name"],
|
||||
"file_name": checkpoint_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(checkpoint_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": checkpoint_data.get("preview_nsfw_level", 0),
|
||||
"base_model": checkpoint_data.get("base_model", ""),
|
||||
"folder": checkpoint_data["folder"],
|
||||
"sha256": checkpoint_data.get("sha256", ""),
|
||||
"file_path": checkpoint_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": checkpoint_data.get("size", 0),
|
||||
"modified": checkpoint_data.get("modified", ""),
|
||||
"tags": checkpoint_data.get("tags", []),
|
||||
"from_civitai": checkpoint_data.get("from_civitai", True),
|
||||
"notes": checkpoint_data.get("notes", ""),
|
||||
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
||||
"favorite": checkpoint_data.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find Checkpoints with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find Checkpoints with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
@@ -1,13 +1,11 @@
|
||||
from datetime import datetime
|
||||
import aiohttp
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from email.parser import Parser
|
||||
from typing import Optional, Dict, Tuple, List
|
||||
from urllib.parse import unquote
|
||||
from ..utils.models import LoraMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,8 +33,8 @@ class CivitaiClient:
|
||||
}
|
||||
self._session = None
|
||||
self._session_created_at = None
|
||||
# Set default buffer size to 1MB for higher throughput
|
||||
self.chunk_size = 1024 * 1024
|
||||
# Adjust chunk size based on storage type - consider making this configurable
|
||||
self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better HDD throughput
|
||||
|
||||
@property
|
||||
async def session(self) -> aiohttp.ClientSession:
|
||||
@@ -45,14 +43,14 @@ class CivitaiClient:
|
||||
# Optimize TCP connection parameters
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=True,
|
||||
limit=3, # Further reduced from 5 to 3
|
||||
ttl_dns_cache=0, # Disabled DNS caching completely
|
||||
limit=8, # Increase from 3 to 8 for better parallelism
|
||||
ttl_dns_cache=300, # Enable DNS caching with reasonable timeout
|
||||
force_close=False, # Keep connections for reuse
|
||||
enable_cleanup_closed=True
|
||||
)
|
||||
trust_env = True # Allow using system environment proxy settings
|
||||
# Configure timeout parameters
|
||||
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
|
||||
# Configure timeout parameters - increase read timeout for large files and remove sock_read timeout
|
||||
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=None)
|
||||
self._session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
trust_env=trust_env,
|
||||
@@ -104,7 +102,7 @@ class CivitaiClient:
|
||||
return headers
|
||||
|
||||
async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
|
||||
"""Download file with content-disposition support and progress tracking
|
||||
"""Download file with resumable downloads and retry mechanism
|
||||
|
||||
Args:
|
||||
url: Download URL
|
||||
@@ -115,73 +113,176 @@ class CivitaiClient:
|
||||
Returns:
|
||||
Tuple[bool, str]: (success, save_path or error message)
|
||||
"""
|
||||
logger.debug(f"Resolving DNS for: {url}")
|
||||
max_retries = 5
|
||||
retry_count = 0
|
||||
base_delay = 2.0 # Base delay for exponential backoff
|
||||
|
||||
# Initial setup
|
||||
session = await self._ensure_fresh_session()
|
||||
try:
|
||||
headers = self._get_request_headers()
|
||||
|
||||
# Add Range header to allow resumable downloads
|
||||
headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads
|
||||
|
||||
logger.debug(f"Starting download from: {url}")
|
||||
async with session.get(url, headers=headers, allow_redirects=True) as response:
|
||||
if response.status != 200:
|
||||
# Handle 401 unauthorized responses
|
||||
if response.status == 401:
|
||||
save_path = os.path.join(save_dir, default_filename)
|
||||
part_path = save_path + '.part'
|
||||
|
||||
# Get existing file size for resume
|
||||
resume_offset = 0
|
||||
if os.path.exists(part_path):
|
||||
resume_offset = os.path.getsize(part_path)
|
||||
logger.info(f"Resuming download from offset {resume_offset} bytes")
|
||||
|
||||
total_size = 0
|
||||
filename = default_filename
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
headers = self._get_request_headers()
|
||||
|
||||
# Add Range header for resume if we have partial data
|
||||
if resume_offset > 0:
|
||||
headers['Range'] = f'bytes={resume_offset}-'
|
||||
|
||||
# Add Range header to allow resumable downloads
|
||||
headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads
|
||||
|
||||
logger.debug(f"Download attempt {retry_count + 1}/{max_retries + 1} from: {url}")
|
||||
if resume_offset > 0:
|
||||
logger.debug(f"Requesting range from byte {resume_offset}")
|
||||
|
||||
async with session.get(url, headers=headers, allow_redirects=True) as response:
|
||||
# Handle different response codes
|
||||
if response.status == 200:
|
||||
# Full content response
|
||||
if resume_offset > 0:
|
||||
# Server doesn't support ranges, restart from beginning
|
||||
logger.warning("Server doesn't support range requests, restarting download")
|
||||
resume_offset = 0
|
||||
if os.path.exists(part_path):
|
||||
os.remove(part_path)
|
||||
elif response.status == 206:
|
||||
# Partial content response (resume successful)
|
||||
content_range = response.headers.get('Content-Range')
|
||||
if content_range:
|
||||
# Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048")
|
||||
range_parts = content_range.split('/')
|
||||
if len(range_parts) == 2:
|
||||
total_size = int(range_parts[1])
|
||||
logger.info(f"Successfully resumed download from byte {resume_offset}")
|
||||
elif response.status == 416:
|
||||
# Range not satisfiable - file might be complete or corrupted
|
||||
if os.path.exists(part_path):
|
||||
part_size = os.path.getsize(part_path)
|
||||
logger.warning(f"Range not satisfiable. Part file size: {part_size}")
|
||||
# Try to get actual file size
|
||||
head_response = await session.head(url, headers=self._get_request_headers())
|
||||
if head_response.status == 200:
|
||||
actual_size = int(head_response.headers.get('content-length', 0))
|
||||
if part_size == actual_size:
|
||||
# File is complete, just rename it
|
||||
os.rename(part_path, save_path)
|
||||
if progress_callback:
|
||||
await progress_callback(100)
|
||||
return True, save_path
|
||||
# Remove corrupted part file and restart
|
||||
os.remove(part_path)
|
||||
resume_offset = 0
|
||||
continue
|
||||
elif response.status == 401:
|
||||
logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
|
||||
|
||||
return False, "Invalid or missing CivitAI API key, or early access restriction."
|
||||
|
||||
# Handle other client errors that might be permission-related
|
||||
if response.status == 403:
|
||||
elif response.status == 403:
|
||||
logger.warning(f"Forbidden access to resource: {url} (Status 403)")
|
||||
return False, "Access forbidden: You don't have permission to download this file."
|
||||
else:
|
||||
logger.error(f"Download failed for {url} with status {response.status}")
|
||||
return False, f"Download failed with status {response.status}"
|
||||
|
||||
# Generic error response for other status codes
|
||||
logger.error(f"Download failed for {url} with status {response.status}")
|
||||
return False, f"Download failed with status {response.status}"
|
||||
# Get total file size for progress calculation (if not set from Content-Range)
|
||||
if total_size == 0:
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
if response.status == 206:
|
||||
# For partial content, add the offset to get total file size
|
||||
total_size += resume_offset
|
||||
|
||||
# Get filename from content-disposition header
|
||||
content_disposition = response.headers.get('Content-Disposition')
|
||||
filename = self._parse_content_disposition(content_disposition)
|
||||
if not filename:
|
||||
filename = default_filename
|
||||
|
||||
save_path = os.path.join(save_dir, filename)
|
||||
|
||||
# Get total file size for progress calculation
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
current_size = 0
|
||||
last_progress_report_time = datetime.now()
|
||||
current_size = resume_offset
|
||||
last_progress_report_time = datetime.now()
|
||||
|
||||
# Stream download to file with progress updates using larger buffer
|
||||
with open(save_path, 'wb') as f:
|
||||
async for chunk in response.content.iter_chunked(self.chunk_size):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
current_size += len(chunk)
|
||||
|
||||
# Limit progress update frequency to reduce overhead
|
||||
now = datetime.now()
|
||||
time_diff = (now - last_progress_report_time).total_seconds()
|
||||
|
||||
if progress_callback and total_size and time_diff >= 0.5:
|
||||
progress = (current_size / total_size) * 100
|
||||
await progress_callback(progress)
|
||||
last_progress_report_time = now
|
||||
|
||||
# Ensure 100% progress is reported
|
||||
if progress_callback:
|
||||
await progress_callback(100)
|
||||
# Stream download to file with progress updates using larger buffer
|
||||
loop = asyncio.get_running_loop()
|
||||
mode = 'ab' if resume_offset > 0 else 'wb'
|
||||
with open(part_path, mode) as f:
|
||||
async for chunk in response.content.iter_chunked(self.chunk_size):
|
||||
if chunk:
|
||||
# Run blocking file write in executor
|
||||
await loop.run_in_executor(None, f.write, chunk)
|
||||
current_size += len(chunk)
|
||||
|
||||
# Limit progress update frequency to reduce overhead
|
||||
now = datetime.now()
|
||||
time_diff = (now - last_progress_report_time).total_seconds()
|
||||
|
||||
if progress_callback and total_size and time_diff >= 1.0:
|
||||
progress = (current_size / total_size) * 100
|
||||
await progress_callback(progress)
|
||||
last_progress_report_time = now
|
||||
|
||||
# Download completed successfully
|
||||
# Verify file size if total_size was provided
|
||||
final_size = os.path.getsize(part_path)
|
||||
if total_size > 0 and final_size != total_size:
|
||||
logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}")
|
||||
# Don't treat this as fatal error, rename anyway
|
||||
|
||||
# Atomically rename .part to final file with retries
|
||||
max_rename_attempts = 5
|
||||
rename_attempt = 0
|
||||
rename_success = False
|
||||
|
||||
while rename_attempt < max_rename_attempts and not rename_success:
|
||||
try:
|
||||
os.rename(part_path, save_path)
|
||||
rename_success = True
|
||||
except PermissionError as e:
|
||||
rename_attempt += 1
|
||||
if rename_attempt < max_rename_attempts:
|
||||
logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})")
|
||||
await asyncio.sleep(2) # Wait before retrying
|
||||
else:
|
||||
logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}")
|
||||
return False, f"Failed to finalize download: {str(e)}"
|
||||
|
||||
# Ensure 100% progress is reported
|
||||
if progress_callback:
|
||||
await progress_callback(100)
|
||||
|
||||
return True, save_path
|
||||
return True, save_path
|
||||
|
||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError,
|
||||
aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
|
||||
retry_count += 1
|
||||
logger.warning(f"Network error during download (attempt {retry_count}/{max_retries + 1}): {e}")
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Network error during download: {e}")
|
||||
return False, f"Network error: {str(e)}"
|
||||
except Exception as e:
|
||||
logger.error(f"Download error: {e}")
|
||||
return False, str(e)
|
||||
if retry_count <= max_retries:
|
||||
# Calculate delay with exponential backoff
|
||||
delay = base_delay * (2 ** (retry_count - 1))
|
||||
logger.info(f"Retrying in {delay} seconds...")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Update resume offset for next attempt
|
||||
if os.path.exists(part_path):
|
||||
resume_offset = os.path.getsize(part_path)
|
||||
logger.info(f"Will resume from byte {resume_offset}")
|
||||
|
||||
# Refresh session to get new connection
|
||||
await self.close()
|
||||
session = await self._ensure_fresh_session()
|
||||
continue
|
||||
else:
|
||||
logger.error(f"Max retries exceeded for download: {e}")
|
||||
return False, f"Network error after {max_retries + 1} attempts: {str(e)}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected download error: {e}")
|
||||
return False, str(e)
|
||||
|
||||
return False, f"Download failed after {max_retries + 1} attempts"
|
||||
|
||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||
try:
|
||||
@@ -225,11 +326,11 @@ class CivitaiClient:
|
||||
logger.error(f"Error fetching model versions: {e}")
|
||||
return None
|
||||
|
||||
async def get_model_version(self, model_id: str, version_id: str = "") -> Optional[Dict]:
|
||||
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||
"""Get specific model version with additional metadata
|
||||
|
||||
Args:
|
||||
model_id: The Civitai model ID
|
||||
model_id: The Civitai model ID (optional if version_id is provided)
|
||||
version_id: Optional specific version ID to retrieve
|
||||
|
||||
Returns:
|
||||
@@ -237,52 +338,72 @@ class CivitaiClient:
|
||||
"""
|
||||
try:
|
||||
session = await self._ensure_fresh_session()
|
||||
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
headers = self._get_request_headers()
|
||||
|
||||
# Case 1: Only version_id is provided
|
||||
if model_id is None and version_id is not None:
|
||||
# First get the version info to extract model_id
|
||||
async with session.get(f"{self.base_url}/model-versions/{version_id}", headers=headers) as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
|
||||
data = await response.json()
|
||||
model_versions = data.get('modelVersions', [])
|
||||
|
||||
# Find matching version
|
||||
matched_version = None
|
||||
|
||||
if version_id:
|
||||
# If version_id provided, find exact match
|
||||
for version in model_versions:
|
||||
if str(version.get('id')) == str(version_id):
|
||||
matched_version = version
|
||||
break
|
||||
else:
|
||||
# If no version_id then use the first version
|
||||
matched_version = model_versions[0] if model_versions else None
|
||||
|
||||
# If no match found, return None
|
||||
if not matched_version:
|
||||
return None
|
||||
version = await response.json()
|
||||
model_id = version.get('modelId')
|
||||
|
||||
# Build result with modified fields
|
||||
result = matched_version.copy() # Copy to avoid modifying original
|
||||
if not model_id:
|
||||
logger.error(f"No modelId found in version {version_id}")
|
||||
return None
|
||||
|
||||
# Replace index with modelId
|
||||
if 'index' in result:
|
||||
del result['index']
|
||||
result['modelId'] = model_id
|
||||
# Now get the model data for additional metadata
|
||||
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
||||
if response.status != 200:
|
||||
return version # Return version without additional metadata
|
||||
|
||||
model_data = await response.json()
|
||||
|
||||
# Enrich version with model data
|
||||
version['model']['description'] = model_data.get("description")
|
||||
version['model']['tags'] = model_data.get("tags", [])
|
||||
version['creator'] = model_data.get("creator")
|
||||
|
||||
return version
|
||||
|
||||
# Case 2: model_id is provided (with or without version_id)
|
||||
elif model_id is not None:
|
||||
# Step 1: Get model data to find version_id if not provided and get additional metadata
|
||||
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
|
||||
data = await response.json()
|
||||
model_versions = data.get('modelVersions', [])
|
||||
|
||||
# Step 2: Determine the version_id to use
|
||||
target_version_id = version_id
|
||||
if target_version_id is None:
|
||||
target_version_id = model_versions[0].get('id')
|
||||
|
||||
# Add model field with metadata from top level
|
||||
result['model'] = {
|
||||
"name": data.get("name"),
|
||||
"type": data.get("type"),
|
||||
"nsfw": data.get("nsfw", False),
|
||||
"poi": data.get("poi", False),
|
||||
"description": data.get("description"),
|
||||
"tags": data.get("tags", [])
|
||||
}
|
||||
|
||||
# Add creator field from top level
|
||||
result['creator'] = data.get("creator")
|
||||
|
||||
return result
|
||||
# Step 3: Get detailed version info using the version_id
|
||||
async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
|
||||
version = await response.json()
|
||||
|
||||
# Step 4: Enrich version_info with model data
|
||||
# Add description and tags from model data
|
||||
version['model']['description'] = data.get("description")
|
||||
version['model']['tags'] = data.get("tags", [])
|
||||
|
||||
# Add creator from model data
|
||||
version['creator'] = data.get("creator")
|
||||
|
||||
return version
|
||||
|
||||
# Case 3: Neither model_id nor version_id provided
|
||||
else:
|
||||
logger.error("Either model_id or version_id must be provided")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model version: {e}")
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
import uuid
|
||||
from typing import Dict
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .service_registry import ServiceRegistry
|
||||
from .settings_manager import settings
|
||||
|
||||
# Download to temporary file first
|
||||
import tempfile
|
||||
@@ -33,6 +35,10 @@ class DownloadManager:
|
||||
self._initialized = True
|
||||
|
||||
self._civitai_client = None # Will be lazily initialized
|
||||
# Add download management
|
||||
self._active_downloads = OrderedDict() # download_id -> download_info
|
||||
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
||||
self._download_tasks = {} # download_id -> asyncio.Task
|
||||
|
||||
async def _get_civitai_client(self):
|
||||
"""Lazily initialize CivitaiClient from registry"""
|
||||
@@ -47,55 +53,218 @@ class DownloadManager:
|
||||
async def _get_checkpoint_scanner(self):
|
||||
"""Get the checkpoint scanner from registry"""
|
||||
return await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
async def download_from_civitai(self, download_url: str = None, model_hash: str = None,
|
||||
model_version_id: str = None, save_dir: str = None,
|
||||
relative_path: str = '', progress_callback=None,
|
||||
model_type: str = "lora") -> Dict:
|
||||
"""Download model from Civitai
|
||||
|
||||
async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
|
||||
save_dir: str = None, relative_path: str = '',
|
||||
progress_callback=None, use_default_paths: bool = False,
|
||||
download_id: str = None) -> Dict:
|
||||
"""Download model from Civitai with task tracking and concurrency control
|
||||
|
||||
Args:
|
||||
download_url: Direct download URL for the model
|
||||
model_hash: SHA256 hash of the model
|
||||
model_version_id: Civitai model version ID
|
||||
save_dir: Directory to save the model to
|
||||
model_id: Civitai model ID (optional if model_version_id is provided)
|
||||
model_version_id: Civitai model version ID (optional if model_id is provided)
|
||||
save_dir: Directory to save the model
|
||||
relative_path: Relative path within save_dir
|
||||
progress_callback: Callback function for progress updates
|
||||
model_type: Type of model ('lora' or 'checkpoint')
|
||||
use_default_paths: Flag to use default paths
|
||||
download_id: Unique identifier for this download task
|
||||
|
||||
Returns:
|
||||
Dict with download result
|
||||
"""
|
||||
# Validate that at least one identifier is provided
|
||||
if not model_id and not model_version_id:
|
||||
return {'success': False, 'error': 'Either model_id or model_version_id must be provided'}
|
||||
|
||||
# Use provided download_id or generate new one
|
||||
task_id = download_id or str(uuid.uuid4())
|
||||
|
||||
# Register download task in tracking dict
|
||||
self._active_downloads[task_id] = {
|
||||
'model_id': model_id,
|
||||
'model_version_id': model_version_id,
|
||||
'progress': 0,
|
||||
'status': 'queued'
|
||||
}
|
||||
|
||||
# Create tracking task
|
||||
download_task = asyncio.create_task(
|
||||
self._download_with_semaphore(
|
||||
task_id, model_id, model_version_id, save_dir,
|
||||
relative_path, progress_callback, use_default_paths
|
||||
)
|
||||
)
|
||||
|
||||
# Store task for tracking and cancellation
|
||||
self._download_tasks[task_id] = download_task
|
||||
|
||||
try:
|
||||
# Update save directory with relative path if provided
|
||||
if relative_path:
|
||||
save_dir = os.path.join(save_dir, relative_path)
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
# Wait for download to complete
|
||||
result = await download_task
|
||||
result['download_id'] = task_id # Include download_id in result
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
return {'success': False, 'error': 'Download was cancelled', 'download_id': task_id}
|
||||
finally:
|
||||
# Clean up task reference
|
||||
if task_id in self._download_tasks:
|
||||
del self._download_tasks[task_id]
|
||||
|
||||
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
|
||||
save_dir: str, relative_path: str,
|
||||
progress_callback=None, use_default_paths: bool = False):
|
||||
"""Execute download with semaphore to limit concurrency"""
|
||||
# Update status to waiting
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'waiting'
|
||||
|
||||
# Wrap progress callback to track progress in active_downloads
|
||||
original_callback = progress_callback
|
||||
async def tracking_callback(progress):
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['progress'] = progress
|
||||
if original_callback:
|
||||
await original_callback(progress)
|
||||
|
||||
# Acquire semaphore to limit concurrent downloads
|
||||
try:
|
||||
async with self._download_semaphore:
|
||||
# Update status to downloading
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'downloading'
|
||||
|
||||
# Use original download implementation
|
||||
try:
|
||||
# Check for cancellation before starting
|
||||
if asyncio.current_task().cancelled():
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
result = await self._execute_original_download(
|
||||
model_id, model_version_id, save_dir,
|
||||
relative_path, tracking_callback, use_default_paths,
|
||||
task_id
|
||||
)
|
||||
|
||||
# Update status based on result
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed'
|
||||
if not result['success']:
|
||||
self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error')
|
||||
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
# Handle cancellation
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'cancelled'
|
||||
logger.info(f"Download cancelled for task {task_id}")
|
||||
raise
|
||||
except Exception as e:
|
||||
# Handle other errors
|
||||
logger.error(f"Download error for task {task_id}: {str(e)}", exc_info=True)
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'failed'
|
||||
self._active_downloads[task_id]['error'] = str(e)
|
||||
return {'success': False, 'error': str(e)}
|
||||
finally:
|
||||
# Schedule cleanup of download record after delay
|
||||
asyncio.create_task(self._cleanup_download_record(task_id))
|
||||
|
||||
async def _cleanup_download_record(self, task_id: str):
|
||||
"""Keep completed downloads in history for a short time"""
|
||||
await asyncio.sleep(600) # Keep for 10 minutes
|
||||
if task_id in self._active_downloads:
|
||||
del self._active_downloads[task_id]
|
||||
|
||||
async def _execute_original_download(self, model_id, model_version_id, save_dir,
|
||||
relative_path, progress_callback, use_default_paths,
|
||||
download_id=None):
|
||||
"""Wrapper for original download_from_civitai implementation"""
|
||||
try:
|
||||
# Check if model version already exists in library
|
||||
if model_version_id is not None:
|
||||
# Check both scanners
|
||||
lora_scanner = await self._get_lora_scanner()
|
||||
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
# Check lora scanner first
|
||||
if await lora_scanner.check_model_version_exists(model_version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in lora library'}
|
||||
|
||||
# Check checkpoint scanner
|
||||
if await checkpoint_scanner.check_model_version_exists(model_version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
||||
|
||||
# Check embedding scanner
|
||||
if await embedding_scanner.check_model_version_exists(model_version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
||||
|
||||
# Get civitai client
|
||||
civitai_client = await self._get_civitai_client()
|
||||
|
||||
# Get version info based on the provided identifier
|
||||
version_info = None
|
||||
error_msg = None
|
||||
|
||||
if model_hash:
|
||||
# Get model by hash
|
||||
version_info = await civitai_client.get_model_by_hash(model_hash)
|
||||
elif model_version_id:
|
||||
# Use model version ID directly
|
||||
version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
|
||||
elif download_url:
|
||||
# Extract version ID from download URL
|
||||
version_id = download_url.split('/')[-1]
|
||||
version_info, error_msg = await civitai_client.get_model_version_info(version_id)
|
||||
|
||||
version_info = await civitai_client.get_model_version(model_id, model_version_id)
|
||||
|
||||
if not version_info:
|
||||
if error_msg and "model not found" in error_msg.lower():
|
||||
return {'success': False, 'error': f'Model not found on Civitai: {error_msg}'}
|
||||
return {'success': False, 'error': error_msg or 'Failed to fetch model metadata'}
|
||||
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
||||
|
||||
model_type_from_info = version_info.get('model', {}).get('type', '').lower()
|
||||
if model_type_from_info == 'checkpoint':
|
||||
model_type = 'checkpoint'
|
||||
elif model_type_from_info in VALID_LORA_TYPES:
|
||||
model_type = 'lora'
|
||||
elif model_type_from_info == 'textualinversion':
|
||||
model_type = 'embedding'
|
||||
else:
|
||||
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
|
||||
|
||||
# Case 2: model_version_id was None, check after getting version_info
|
||||
if model_version_id is None:
|
||||
version_id = version_info.get('id')
|
||||
|
||||
if model_type == 'lora':
|
||||
# Check lora scanner
|
||||
lora_scanner = await self._get_lora_scanner()
|
||||
if await lora_scanner.check_model_version_exists(version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in lora library'}
|
||||
elif model_type == 'checkpoint':
|
||||
# Check checkpoint scanner
|
||||
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||
if await checkpoint_scanner.check_model_version_exists(version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
||||
elif model_type == 'embedding':
|
||||
# Embeddings are not checked in scanners, but we can still check if it exists
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
if await embedding_scanner.check_model_version_exists(version_id):
|
||||
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
||||
|
||||
# Handle use_default_paths
|
||||
if use_default_paths:
|
||||
# Set save_dir based on model type
|
||||
if model_type == 'checkpoint':
|
||||
default_path = settings.get('default_checkpoint_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
|
||||
save_dir = default_path
|
||||
elif model_type == 'lora':
|
||||
default_path = settings.get('default_lora_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default lora root path not set in settings'}
|
||||
save_dir = default_path
|
||||
elif model_type == 'embedding':
|
||||
default_path = settings.get('default_embedding_root')
|
||||
if not default_path:
|
||||
return {'success': False, 'error': 'Default embedding root path not set in settings'}
|
||||
save_dir = default_path
|
||||
|
||||
# Calculate relative path using template
|
||||
relative_path = self._calculate_relative_path(version_info, model_type)
|
||||
|
||||
# Update save directory with relative path if provided
|
||||
if relative_path:
|
||||
save_dir = os.path.join(save_dir, relative_path)
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# Check if this is an early access model
|
||||
if version_info.get('earlyAccessEndsAt'):
|
||||
@@ -105,9 +274,9 @@ class DownloadManager:
|
||||
from datetime import datetime
|
||||
date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00'))
|
||||
formatted_date = date_obj.strftime('%Y-%m-%d')
|
||||
early_access_msg = f"This model requires early access payment (until {formatted_date}). "
|
||||
early_access_msg = f"This model requires payment (until {formatted_date}). "
|
||||
except:
|
||||
early_access_msg = "This model requires early access payment. "
|
||||
early_access_msg = "This model requires payment. "
|
||||
|
||||
early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai."
|
||||
logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}")
|
||||
@@ -133,21 +302,12 @@ class DownloadManager:
|
||||
if model_type == "checkpoint":
|
||||
metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
logger.info(f"Creating CheckpointMetadata for {file_name}")
|
||||
else:
|
||||
elif model_type == "lora":
|
||||
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
logger.info(f"Creating LoraMetadata for {file_name}")
|
||||
|
||||
# 5.1 Get and update model tags, description and creator info
|
||||
model_id = version_info.get('modelId')
|
||||
if model_id:
|
||||
model_metadata, _ = await civitai_client.get_model_metadata(str(model_id))
|
||||
if model_metadata:
|
||||
if model_metadata.get("tags"):
|
||||
metadata.tags = model_metadata.get("tags", [])
|
||||
if model_metadata.get("description"):
|
||||
metadata.modelDescription = model_metadata.get("description", "")
|
||||
if model_metadata.get("creator"):
|
||||
metadata.civitai["creator"] = model_metadata.get("creator")
|
||||
elif model_type == "embedding":
|
||||
metadata = EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path)
|
||||
logger.info(f"Creating EmbeddingMetadata for {file_name}")
|
||||
|
||||
# 6. Start download process
|
||||
result = await self._execute_download(
|
||||
@@ -157,9 +317,14 @@ class DownloadManager:
|
||||
version_info=version_info,
|
||||
relative_path=relative_path,
|
||||
progress_callback=progress_callback,
|
||||
model_type=model_type
|
||||
model_type=model_type,
|
||||
download_id=download_id
|
||||
)
|
||||
|
||||
# If early_access_msg exists and download failed, replace error message
|
||||
if 'early_access_msg' in locals() and not result.get('success', False):
|
||||
result['error'] = early_access_msg
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
@@ -170,15 +335,100 @@ class DownloadManager:
|
||||
return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."}
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
def _calculate_relative_path(self, version_info: Dict, model_type: str = 'lora') -> str:
|
||||
"""Calculate relative path using template from settings
|
||||
|
||||
Args:
|
||||
version_info: Version info from Civitai API
|
||||
model_type: Type of model ('lora', 'checkpoint', 'embedding')
|
||||
|
||||
Returns:
|
||||
Relative path string
|
||||
"""
|
||||
# Get path template from settings for specific model type
|
||||
path_template = settings.get_download_path_template(model_type)
|
||||
|
||||
# If template is empty, return empty path (flat structure)
|
||||
if not path_template:
|
||||
return ''
|
||||
|
||||
# Get base model name
|
||||
base_model = version_info.get('baseModel', '')
|
||||
|
||||
# Get author from creator data
|
||||
creator_info = version_info.get('creator')
|
||||
if creator_info and isinstance(creator_info, dict):
|
||||
author = creator_info.get('username') or 'Anonymous'
|
||||
else:
|
||||
author = 'Anonymous'
|
||||
|
||||
# Apply mapping if available
|
||||
base_model_mappings = settings.get('base_model_path_mappings', {})
|
||||
mapped_base_model = base_model_mappings.get(base_model, base_model)
|
||||
|
||||
# Get model tags
|
||||
model_tags = version_info.get('model', {}).get('tags', [])
|
||||
|
||||
# Find the first Civitai model tag that exists in model_tags
|
||||
first_tag = ''
|
||||
for civitai_tag in CIVITAI_MODEL_TAGS:
|
||||
if civitai_tag in model_tags:
|
||||
first_tag = civitai_tag
|
||||
break
|
||||
|
||||
# If no Civitai model tag found, fallback to first tag
|
||||
if not first_tag and model_tags:
|
||||
first_tag = model_tags[0]
|
||||
|
||||
# Format the template with available data
|
||||
formatted_path = path_template
|
||||
formatted_path = formatted_path.replace('{base_model}', mapped_base_model)
|
||||
formatted_path = formatted_path.replace('{first_tag}', first_tag)
|
||||
formatted_path = formatted_path.replace('{author}', author)
|
||||
|
||||
return formatted_path
|
||||
|
||||
async def _execute_download(self, download_url: str, save_dir: str,
|
||||
metadata, version_info: Dict,
|
||||
relative_path: str, progress_callback=None,
|
||||
model_type: str = "lora") -> Dict:
|
||||
metadata, version_info: Dict,
|
||||
relative_path: str, progress_callback=None,
|
||||
model_type: str = "lora", download_id: str = None) -> Dict:
|
||||
"""Execute the actual download process including preview images and model files"""
|
||||
try:
|
||||
civitai_client = await self._get_civitai_client()
|
||||
save_path = metadata.file_path
|
||||
|
||||
# Extract original filename details
|
||||
original_filename = os.path.basename(metadata.file_path)
|
||||
base_name, extension = os.path.splitext(original_filename)
|
||||
|
||||
# Check for filename conflicts and generate unique filename if needed
|
||||
# Use the hash from metadata for conflict resolution
|
||||
def hash_provider():
|
||||
return metadata.sha256
|
||||
|
||||
unique_filename = metadata.generate_unique_filename(
|
||||
save_dir,
|
||||
base_name,
|
||||
extension,
|
||||
hash_provider=hash_provider
|
||||
)
|
||||
|
||||
# Update paths if filename changed
|
||||
if unique_filename != original_filename:
|
||||
logger.info(f"Filename conflict detected. Changing '{original_filename}' to '{unique_filename}'")
|
||||
save_path = os.path.join(save_dir, unique_filename)
|
||||
# Update metadata with new file path and name
|
||||
metadata.file_path = save_path.replace(os.sep, '/')
|
||||
metadata.file_name = os.path.splitext(unique_filename)[0]
|
||||
else:
|
||||
save_path = metadata.file_path
|
||||
|
||||
part_path = save_path + '.part'
|
||||
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
||||
|
||||
# Store file paths in active_downloads for potential cleanup
|
||||
if download_id and download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]['file_path'] = save_path
|
||||
self._active_downloads[download_id]['part_path'] = part_path
|
||||
|
||||
# Download preview image if available
|
||||
images = version_info.get('images', [])
|
||||
@@ -245,25 +495,40 @@ class DownloadManager:
|
||||
)
|
||||
|
||||
if not success:
|
||||
# Clean up files on failure
|
||||
for path in [save_path, metadata_path, metadata.preview_url]:
|
||||
# Clean up files on failure, but preserve .part file for resume
|
||||
cleanup_files = [metadata_path]
|
||||
if metadata.preview_url and os.path.exists(metadata.preview_url):
|
||||
cleanup_files.append(metadata.preview_url)
|
||||
|
||||
for path in cleanup_files:
|
||||
if path and os.path.exists(path):
|
||||
os.remove(path)
|
||||
try:
|
||||
os.remove(path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup file {path}: {e}")
|
||||
|
||||
# Log but don't remove .part file to allow resume
|
||||
if os.path.exists(part_path):
|
||||
logger.info(f"Preserving partial download for resume: {part_path}")
|
||||
|
||||
return {'success': False, 'error': result}
|
||||
|
||||
# 4. Update file information (size and modified time)
|
||||
metadata.update_file_info(save_path)
|
||||
|
||||
# 5. Final metadata update
|
||||
await MetadataManager.save_metadata(save_path, metadata, True)
|
||||
await MetadataManager.save_metadata(save_path, metadata)
|
||||
|
||||
# 6. Update cache based on model type
|
||||
if model_type == "checkpoint":
|
||||
scanner = await self._get_checkpoint_scanner()
|
||||
logger.info(f"Updating checkpoint cache for {save_path}")
|
||||
else:
|
||||
elif model_type == "lora":
|
||||
scanner = await self._get_lora_scanner()
|
||||
logger.info(f"Updating lora cache for {save_path}")
|
||||
elif model_type == "embedding":
|
||||
scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
logger.info(f"Updating embedding cache for {save_path}")
|
||||
|
||||
# Convert metadata to dictionary
|
||||
metadata_dict = metadata.to_dict()
|
||||
@@ -281,10 +546,18 @@ class DownloadManager:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _execute_download: {e}", exc_info=True)
|
||||
# Clean up partial downloads
|
||||
for path in [save_path, metadata_path]:
|
||||
# Clean up partial downloads except .part file
|
||||
cleanup_files = [metadata_path]
|
||||
if hasattr(metadata, 'preview_url') and metadata.preview_url and os.path.exists(metadata.preview_url):
|
||||
cleanup_files.append(metadata.preview_url)
|
||||
|
||||
for path in cleanup_files:
|
||||
if path and os.path.exists(path):
|
||||
os.remove(path)
|
||||
try:
|
||||
os.remove(path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to cleanup file {path}: {e}")
|
||||
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
async def _handle_download_progress(self, file_progress: float, progress_callback):
|
||||
@@ -297,4 +570,99 @@ class DownloadManager:
|
||||
if progress_callback:
|
||||
# Scale file progress to 3-100 range (after preview download)
|
||||
overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download
|
||||
await progress_callback(round(overall_progress))
|
||||
await progress_callback(round(overall_progress))
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict:
|
||||
"""Cancel an active download by download_id
|
||||
|
||||
Args:
|
||||
download_id: The unique identifier of the download task
|
||||
|
||||
Returns:
|
||||
Dict: Status of the cancellation operation
|
||||
"""
|
||||
if download_id not in self._download_tasks:
|
||||
return {'success': False, 'error': 'Download task not found'}
|
||||
|
||||
try:
|
||||
# Get the task and cancel it
|
||||
task = self._download_tasks[download_id]
|
||||
task.cancel()
|
||||
|
||||
# Update status in active downloads
|
||||
if download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]['status'] = 'cancelling'
|
||||
|
||||
# Wait briefly for the task to acknowledge cancellation
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.shield(task), timeout=2.0)
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
pass
|
||||
|
||||
# Clean up ALL files including .part when user cancels
|
||||
download_info = self._active_downloads.get(download_id)
|
||||
if download_info:
|
||||
# Delete the main file
|
||||
if 'file_path' in download_info:
|
||||
file_path = download_info['file_path']
|
||||
if os.path.exists(file_path):
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
logger.debug(f"Deleted cancelled download: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting file: {e}")
|
||||
|
||||
# Delete the .part file (only on user cancellation)
|
||||
if 'part_path' in download_info:
|
||||
part_path = download_info['part_path']
|
||||
if os.path.exists(part_path):
|
||||
try:
|
||||
os.unlink(part_path)
|
||||
logger.debug(f"Deleted partial download: {part_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting part file: {e}")
|
||||
|
||||
# Delete metadata file if exists
|
||||
if 'file_path' in download_info:
|
||||
file_path = download_info['file_path']
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
os.unlink(metadata_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting metadata file: {e}")
|
||||
|
||||
# Delete preview file if exists (.webp or .mp4)
|
||||
for preview_ext in ['.webp', '.mp4']:
|
||||
preview_path = os.path.splitext(file_path)[0] + preview_ext
|
||||
if os.path.exists(preview_path):
|
||||
try:
|
||||
os.unlink(preview_path)
|
||||
logger.debug(f"Deleted preview file: {preview_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting preview file: {e}")
|
||||
|
||||
return {'success': True, 'message': 'Download cancelled successfully'}
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
||||
return {'success': False, 'error': str(e)}
|
||||
|
||||
async def get_active_downloads(self) -> Dict:
|
||||
"""Get information about all active downloads
|
||||
|
||||
Returns:
|
||||
Dict: List of active downloads and their status
|
||||
"""
|
||||
return {
|
||||
'downloads': [
|
||||
{
|
||||
'download_id': task_id,
|
||||
'model_id': info.get('model_id'),
|
||||
'model_version_id': info.get('model_version_id'),
|
||||
'progress': info.get('progress', 0),
|
||||
'status': info.get('status', 'unknown'),
|
||||
'error': info.get('error', None)
|
||||
}
|
||||
for task_id, info in self._active_downloads.items()
|
||||
]
|
||||
}
|
||||
26
py/services/embedding_scanner.py
Normal file
26
py/services/embedding_scanner.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import EmbeddingMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingScanner(ModelScanner):
|
||||
"""Service for scanning and managing embedding files"""
|
||||
|
||||
def __init__(self):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}
|
||||
super().__init__(
|
||||
model_type="embedding",
|
||||
model_class=EmbeddingMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex()
|
||||
)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get embedding root directories"""
|
||||
return config.embeddings_roots
|
||||
50
py/services/embedding_service.py
Normal file
50
py/services/embedding_service.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import EmbeddingMetadata
|
||||
from ..config import config
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingService(BaseModelService):
|
||||
"""Embedding-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize Embedding service
|
||||
|
||||
Args:
|
||||
scanner: Embedding scanner instance
|
||||
"""
|
||||
super().__init__("embedding", scanner, EmbeddingMetadata)
|
||||
|
||||
async def format_response(self, embedding_data: Dict) -> Dict:
|
||||
"""Format Embedding data for API response"""
|
||||
return {
|
||||
"model_name": embedding_data["model_name"],
|
||||
"file_name": embedding_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(embedding_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": embedding_data.get("preview_nsfw_level", 0),
|
||||
"base_model": embedding_data.get("base_model", ""),
|
||||
"folder": embedding_data["folder"],
|
||||
"sha256": embedding_data.get("sha256", ""),
|
||||
"file_path": embedding_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": embedding_data.get("size", 0),
|
||||
"modified": embedding_data.get("modified", ""),
|
||||
"tags": embedding_data.get("tags", []),
|
||||
"from_civitai": embedding_data.get("from_civitai", True),
|
||||
"notes": embedding_data.get("notes", ""),
|
||||
"model_type": embedding_data.get("model_type", "embedding"),
|
||||
"favorite": embedding_data.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find Embeddings with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find Embeddings with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
@@ -1,65 +0,0 @@
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
from natsort import natsorted
|
||||
|
||||
@dataclass
|
||||
class LoraCache:
|
||||
"""Cache structure for LoRA data"""
|
||||
raw_data: List[Dict]
|
||||
sorted_by_name: List[Dict]
|
||||
sorted_by_date: List[Dict]
|
||||
folders: List[str]
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def resort(self, name_only: bool = False):
|
||||
"""Resort all cached data views"""
|
||||
async with self._lock:
|
||||
self.sorted_by_name = natsorted(
|
||||
self.raw_data,
|
||||
key=lambda x: x['model_name'].lower() # Case-insensitive sort
|
||||
)
|
||||
if not name_only:
|
||||
self.sorted_by_date = sorted(
|
||||
self.raw_data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=True
|
||||
)
|
||||
# Update folder list
|
||||
all_folders = set(l['folder'] for l in self.raw_data)
|
||||
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
async def update_preview_url(self, file_path: str, preview_url: str) -> bool:
|
||||
"""Update preview_url for a specific lora in all cached data
|
||||
|
||||
Args:
|
||||
file_path: The file path of the lora to update
|
||||
preview_url: The new preview URL
|
||||
|
||||
Returns:
|
||||
bool: True if the update was successful, False if the lora wasn't found
|
||||
"""
|
||||
async with self._lock:
|
||||
# Update in raw_data
|
||||
for item in self.raw_data:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
else:
|
||||
return False # Lora not found
|
||||
|
||||
# Update in sorted lists (references to the same dict objects)
|
||||
for item in self.sorted_by_name:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
|
||||
for item in self.sorted_by_date:
|
||||
if item['file_path'] == file_path:
|
||||
item['preview_url'] = preview_url
|
||||
break
|
||||
|
||||
return True
|
||||
@@ -1,20 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
import shutil
|
||||
import time
|
||||
import re
|
||||
from typing import List, Dict, Optional, Set
|
||||
from typing import List
|
||||
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
from .model_scanner import ModelScanner
|
||||
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
|
||||
from .settings_manager import settings
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.utils import fuzzy_match
|
||||
from .service_registry import ServiceRegistry
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,404 +12,21 @@ logger = logging.getLogger(__name__)
|
||||
class LoraScanner(ModelScanner):
|
||||
"""Service for scanning and managing LoRA files"""
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
# Ensure initialization happens only once
|
||||
if not hasattr(self, '_initialized'):
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors'}
|
||||
|
||||
# Initialize parent class with ModelHashIndex
|
||||
super().__init__(
|
||||
model_type="lora",
|
||||
model_class=LoraMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||
)
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
async with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
# Define supported file extensions
|
||||
file_extensions = {'.safetensors'}
|
||||
|
||||
# Initialize parent class with ModelHashIndex
|
||||
super().__init__(
|
||||
model_type="lora",
|
||||
model_class=LoraMetadata,
|
||||
file_extensions=file_extensions,
|
||||
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||
)
|
||||
|
||||
def get_model_roots(self) -> List[str]:
|
||||
"""Get lora root directories"""
|
||||
return config.loras_roots
|
||||
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
"""Scan all LoRA directories and return metadata"""
|
||||
all_loras = []
|
||||
|
||||
# Create scan tasks for each directory
|
||||
scan_tasks = []
|
||||
for lora_root in self.get_model_roots():
|
||||
task = asyncio.create_task(self._scan_directory(lora_root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
loras = await task
|
||||
all_loras.extend(loras)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory: {e}")
|
||||
|
||||
return all_loras
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a single directory for LoRA files"""
|
||||
loras = []
|
||||
original_root = root_path # Save original root path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
"""Recursively scan directory, avoiding circular symlinks"""
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
|
||||
# Use original path instead of real path
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
await self._process_single_file(file_path, original_root, loras)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
# For directories, continue scanning with original path
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return loras
|
||||
|
||||
async def _process_single_file(self, file_path: str, root_path: str, loras: list):
|
||||
"""Process a single file and add to results list"""
|
||||
try:
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
loras.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
||||
base_models: list = None, tags: list = None,
|
||||
search_options: dict = None, hash_filters: dict = None,
|
||||
favorites_only: bool = False, first_letter: str = None) -> Dict:
|
||||
"""Get paginated and filtered lora data
|
||||
|
||||
Args:
|
||||
page: Current page number (1-based)
|
||||
page_size: Number of items per page
|
||||
sort_by: Sort method ('name' or 'date')
|
||||
folder: Filter by folder path
|
||||
search: Search term
|
||||
fuzzy_search: Use fuzzy matching for search
|
||||
base_models: List of base models to filter by
|
||||
tags: List of tags to filter by
|
||||
search_options: Dictionary with search options (filename, modelname, tags, recursive)
|
||||
hash_filters: Dictionary with hash filtering options (single_hash or multiple_hashes)
|
||||
favorites_only: Filter for favorite models only
|
||||
first_letter: Filter by first letter of model name
|
||||
"""
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False,
|
||||
'recursive': False,
|
||||
}
|
||||
|
||||
# Get the base data set
|
||||
filtered_data = cache.sorted_by_date if sort_by == 'date' else cache.sorted_by_name
|
||||
|
||||
# Apply hash filtering if provided (highest priority)
|
||||
if hash_filters:
|
||||
single_hash = hash_filters.get('single_hash')
|
||||
multiple_hashes = hash_filters.get('multiple_hashes')
|
||||
|
||||
if single_hash:
|
||||
# Filter by single hash
|
||||
single_hash = single_hash.lower() # Ensure lowercase for matching
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('sha256', '').lower() == single_hash
|
||||
]
|
||||
elif multiple_hashes:
|
||||
# Filter by multiple hashes
|
||||
hash_set = set(hash.lower() for hash in multiple_hashes) # Convert to set for faster lookup
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('sha256', '').lower() in hash_set
|
||||
]
|
||||
|
||||
|
||||
# Jump to pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
# Apply SFW filtering if enabled
|
||||
if settings.get('show_only_sfw', False):
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if not lora.get('preview_nsfw_level') or lora.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply favorites filtering if enabled
|
||||
if favorites_only:
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('favorite', False) is True
|
||||
]
|
||||
|
||||
# Apply first letter filtering
|
||||
if first_letter:
|
||||
filtered_data = self._filter_by_first_letter(filtered_data, first_letter)
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if search_options.get('recursive', False):
|
||||
# Recursive folder filtering - include all subfolders
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora['folder'].startswith(folder)
|
||||
]
|
||||
else:
|
||||
# Exact folder filtering
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if lora.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
filtered_data = [
|
||||
lora for lora in filtered_data
|
||||
if any(tag in lora.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
# Apply search filtering
|
||||
if search:
|
||||
search_results = []
|
||||
search_opts = search_options or {}
|
||||
|
||||
for lora in filtered_data:
|
||||
# Search by file name
|
||||
if search_opts.get('filename', True):
|
||||
if fuzzy_match(lora.get('file_name', ''), search):
|
||||
search_results.append(lora)
|
||||
continue
|
||||
|
||||
# Search by model name
|
||||
if search_opts.get('modelname', True):
|
||||
if fuzzy_match(lora.get('model_name', ''), search):
|
||||
search_results.append(lora)
|
||||
continue
|
||||
|
||||
# Search by tags
|
||||
if search_opts.get('tags', False) and 'tags' in lora:
|
||||
if any(fuzzy_match(tag, search) for tag in lora['tags']):
|
||||
search_results.append(lora)
|
||||
continue
|
||||
|
||||
filtered_data = search_results
|
||||
|
||||
# Calculate pagination
|
||||
total_items = len(filtered_data)
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = min(start_idx + page_size, total_items)
|
||||
|
||||
result = {
|
||||
'items': filtered_data[start_idx:end_idx],
|
||||
'total': total_items,
|
||||
'page': page,
|
||||
'page_size': page_size,
|
||||
'total_pages': (total_items + page_size - 1) // page_size
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _filter_by_first_letter(self, data, letter):
|
||||
"""Filter data by first letter of model name
|
||||
|
||||
Special handling:
|
||||
- '#': Numbers (0-9)
|
||||
- '@': Special characters (not alphanumeric)
|
||||
- '漢': CJK characters
|
||||
"""
|
||||
filtered_data = []
|
||||
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if letter == '#' and first_char.isdigit():
|
||||
filtered_data.append(lora)
|
||||
elif letter == '@' and not first_char.isalnum():
|
||||
# Special characters (not alphanumeric)
|
||||
filtered_data.append(lora)
|
||||
elif letter == '漢' and self._is_cjk_character(first_char):
|
||||
# CJK characters
|
||||
filtered_data.append(lora)
|
||||
elif letter.upper() == first_char:
|
||||
# Regular alphabet matching
|
||||
filtered_data.append(lora)
|
||||
|
||||
return filtered_data
|
||||
|
||||
def _is_cjk_character(self, char):
|
||||
"""Check if character is a CJK character"""
|
||||
# Define Unicode ranges for CJK characters
|
||||
cjk_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
||||
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||
(0x3300, 0x33FF), # CJK Compatibility
|
||||
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
||||
(0x3100, 0x312F), # Bopomofo
|
||||
(0x31A0, 0x31BF), # Bopomofo Extended
|
||||
(0x3040, 0x309F), # Hiragana
|
||||
(0x30A0, 0x30FF), # Katakana
|
||||
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||
(0x1100, 0x11FF), # Hangul Jamo
|
||||
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
||||
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
||||
]
|
||||
|
||||
code_point = ord(char)
|
||||
return any(start <= code_point <= end for start, end in cjk_ranges)
|
||||
|
||||
async def get_letter_counts(self):
|
||||
"""Get count of models for each letter of the alphabet"""
|
||||
cache = await self.get_cached_data()
|
||||
data = cache.sorted_by_name
|
||||
|
||||
# Define letter categories
|
||||
letters = {
|
||||
'#': 0, # Numbers
|
||||
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
||||
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
||||
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
||||
'Y': 0, 'Z': 0,
|
||||
'@': 0, # Special characters
|
||||
'漢': 0 # CJK characters
|
||||
}
|
||||
|
||||
# Count models for each letter
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if first_char.isdigit():
|
||||
letters['#'] += 1
|
||||
elif first_char in letters:
|
||||
letters[first_char] += 1
|
||||
elif self._is_cjk_character(first_char):
|
||||
letters['漢'] += 1
|
||||
elif not first_char.isalnum():
|
||||
letters['@'] += 1
|
||||
|
||||
return letters
|
||||
|
||||
# Lora-specific hash index functionality
|
||||
def has_lora_hash(self, sha256: str) -> bool:
|
||||
"""Check if a LoRA with given hash exists"""
|
||||
return self.has_hash(sha256)
|
||||
|
||||
def get_lora_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a LoRA by its hash"""
|
||||
return self.get_path_by_hash(sha256)
|
||||
|
||||
def get_lora_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a LoRA by its file path"""
|
||||
return self.get_hash_by_path(file_path)
|
||||
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||
"""Get top tags sorted by count"""
|
||||
# Make sure cache is initialized
|
||||
await self.get_cached_data()
|
||||
|
||||
# Sort tags by count in descending order
|
||||
sorted_tags = sorted(
|
||||
[{"tag": tag, "count": count} for tag, count in self._tags_count.items()],
|
||||
key=lambda x: x['count'],
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# Return limited number
|
||||
return sorted_tags[:limit]
|
||||
|
||||
async def get_base_models(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||
"""Get base models used in loras sorted by frequency"""
|
||||
# Make sure cache is initialized
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Count base model occurrences
|
||||
base_model_counts = {}
|
||||
for lora in cache.raw_data:
|
||||
if 'base_model' in lora and lora['base_model']:
|
||||
base_model = lora['base_model']
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
|
||||
# Sort base models by count
|
||||
sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()]
|
||||
sorted_models.sort(key=lambda x: x['count'], reverse=True)
|
||||
|
||||
# Return limited number
|
||||
return sorted_models[:limit]
|
||||
|
||||
async def diagnose_hash_index(self):
|
||||
"""Diagnostic method to verify hash index functionality"""
|
||||
@@ -456,19 +63,3 @@ class LoraScanner(ModelScanner):
|
||||
test_hash_result = self._hash_index.get_hash(test_path)
|
||||
print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
|
||||
|
||||
async def get_lora_info_by_name(self, name):
|
||||
"""Get LoRA information by name"""
|
||||
try:
|
||||
# Get cached data
|
||||
cache = await self.get_cached_data()
|
||||
|
||||
# Find the LoRA by name
|
||||
for lora in cache.raw_data:
|
||||
if lora.get("file_name") == name:
|
||||
return lora
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting LoRA info by name: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
182
py/services/lora_service.py
Normal file
182
py/services/lora_service.py
Normal file
@@ -0,0 +1,182 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LoraService(BaseModelService):
|
||||
"""LoRA-specific service implementation"""
|
||||
|
||||
def __init__(self, scanner):
|
||||
"""Initialize LoRA service
|
||||
|
||||
Args:
|
||||
scanner: LoRA scanner instance
|
||||
"""
|
||||
super().__init__("lora", scanner, LoraMetadata)
|
||||
|
||||
async def format_response(self, lora_data: Dict) -> Dict:
|
||||
"""Format LoRA data for API response"""
|
||||
return {
|
||||
"model_name": lora_data["model_name"],
|
||||
"file_name": lora_data["file_name"],
|
||||
"preview_url": config.get_preview_static_url(lora_data.get("preview_url", "")),
|
||||
"preview_nsfw_level": lora_data.get("preview_nsfw_level", 0),
|
||||
"base_model": lora_data.get("base_model", ""),
|
||||
"folder": lora_data["folder"],
|
||||
"sha256": lora_data.get("sha256", ""),
|
||||
"file_path": lora_data["file_path"].replace(os.sep, "/"),
|
||||
"file_size": lora_data.get("size", 0),
|
||||
"modified": lora_data.get("modified", ""),
|
||||
"tags": lora_data.get("tags", []),
|
||||
"from_civitai": lora_data.get("from_civitai", True),
|
||||
"usage_tips": lora_data.get("usage_tips", ""),
|
||||
"notes": lora_data.get("notes", ""),
|
||||
"favorite": lora_data.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}), minimal=True)
|
||||
}
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply LoRA-specific filters"""
|
||||
# Handle first_letter filter for LoRAs
|
||||
first_letter = kwargs.get('first_letter')
|
||||
if first_letter:
|
||||
data = self._filter_by_first_letter(data, first_letter)
|
||||
|
||||
return data
|
||||
|
||||
def _filter_by_first_letter(self, data: List[Dict], letter: str) -> List[Dict]:
|
||||
"""Filter data by first letter of model name
|
||||
|
||||
Special handling:
|
||||
- '#': Numbers (0-9)
|
||||
- '@': Special characters (not alphanumeric)
|
||||
- '漢': CJK characters
|
||||
"""
|
||||
filtered_data = []
|
||||
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if letter == '#' and first_char.isdigit():
|
||||
filtered_data.append(lora)
|
||||
elif letter == '@' and not first_char.isalnum():
|
||||
# Special characters (not alphanumeric)
|
||||
filtered_data.append(lora)
|
||||
elif letter == '漢' and self._is_cjk_character(first_char):
|
||||
# CJK characters
|
||||
filtered_data.append(lora)
|
||||
elif letter.upper() == first_char:
|
||||
# Regular alphabet matching
|
||||
filtered_data.append(lora)
|
||||
|
||||
return filtered_data
|
||||
|
||||
def _is_cjk_character(self, char: str) -> bool:
|
||||
"""Check if character is a CJK character"""
|
||||
# Define Unicode ranges for CJK characters
|
||||
cjk_ranges = [
|
||||
(0x4E00, 0x9FFF), # CJK Unified Ideographs
|
||||
(0x3400, 0x4DBF), # CJK Unified Ideographs Extension A
|
||||
(0x20000, 0x2A6DF), # CJK Unified Ideographs Extension B
|
||||
(0x2A700, 0x2B73F), # CJK Unified Ideographs Extension C
|
||||
(0x2B740, 0x2B81F), # CJK Unified Ideographs Extension D
|
||||
(0x2B820, 0x2CEAF), # CJK Unified Ideographs Extension E
|
||||
(0x2CEB0, 0x2EBEF), # CJK Unified Ideographs Extension F
|
||||
(0x30000, 0x3134F), # CJK Unified Ideographs Extension G
|
||||
(0xF900, 0xFAFF), # CJK Compatibility Ideographs
|
||||
(0x3300, 0x33FF), # CJK Compatibility
|
||||
(0x3200, 0x32FF), # Enclosed CJK Letters and Months
|
||||
(0x3100, 0x312F), # Bopomofo
|
||||
(0x31A0, 0x31BF), # Bopomofo Extended
|
||||
(0x3040, 0x309F), # Hiragana
|
||||
(0x30A0, 0x30FF), # Katakana
|
||||
(0x31F0, 0x31FF), # Katakana Phonetic Extensions
|
||||
(0xAC00, 0xD7AF), # Hangul Syllables
|
||||
(0x1100, 0x11FF), # Hangul Jamo
|
||||
(0xA960, 0xA97F), # Hangul Jamo Extended-A
|
||||
(0xD7B0, 0xD7FF), # Hangul Jamo Extended-B
|
||||
]
|
||||
|
||||
code_point = ord(char)
|
||||
return any(start <= code_point <= end for start, end in cjk_ranges)
|
||||
|
||||
# LoRA-specific methods
|
||||
async def get_letter_counts(self) -> Dict[str, int]:
|
||||
"""Get count of LoRAs for each letter of the alphabet"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
data = cache.raw_data
|
||||
|
||||
# Define letter categories
|
||||
letters = {
|
||||
'#': 0, # Numbers
|
||||
'A': 0, 'B': 0, 'C': 0, 'D': 0, 'E': 0, 'F': 0, 'G': 0, 'H': 0,
|
||||
'I': 0, 'J': 0, 'K': 0, 'L': 0, 'M': 0, 'N': 0, 'O': 0, 'P': 0,
|
||||
'Q': 0, 'R': 0, 'S': 0, 'T': 0, 'U': 0, 'V': 0, 'W': 0, 'X': 0,
|
||||
'Y': 0, 'Z': 0,
|
||||
'@': 0, # Special characters
|
||||
'漢': 0 # CJK characters
|
||||
}
|
||||
|
||||
# Count models for each letter
|
||||
for lora in data:
|
||||
model_name = lora.get('model_name', '')
|
||||
if not model_name:
|
||||
continue
|
||||
|
||||
first_char = model_name[0].upper()
|
||||
|
||||
if first_char.isdigit():
|
||||
letters['#'] += 1
|
||||
elif first_char in letters:
|
||||
letters[first_char] += 1
|
||||
elif self._is_cjk_character(first_char):
|
||||
letters['漢'] += 1
|
||||
elif not first_char.isalnum():
|
||||
letters['@'] += 1
|
||||
|
||||
return letters
|
||||
|
||||
async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
|
||||
"""Get trigger words for a specific LoRA file"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
if lora['file_name'] == lora_name:
|
||||
civitai_data = lora.get('civitai', {})
|
||||
return civitai_data.get('trainedWords', [])
|
||||
|
||||
return []
|
||||
|
||||
async def get_lora_usage_tips_by_relative_path(self, relative_path: str) -> Optional[str]:
|
||||
"""Get usage tips for a LoRA by its relative path"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
|
||||
for lora in cache.raw_data:
|
||||
file_path = lora.get('file_path', '')
|
||||
if file_path:
|
||||
# Convert to forward slashes and extract relative path
|
||||
file_path_normalized = file_path.replace('\\', '/')
|
||||
relative_path = relative_path.replace('\\', '/')
|
||||
# Find the relative path part by looking for the relative_path in the full path
|
||||
if file_path_normalized.endswith(relative_path) or relative_path in file_path_normalized:
|
||||
return lora.get('usage_tips', '')
|
||||
|
||||
return None
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
"""Find LoRAs with duplicate SHA256 hashes"""
|
||||
return self.scanner._hash_index.get_duplicate_hashes()
|
||||
|
||||
def find_duplicate_filenames(self) -> Dict:
|
||||
"""Find LoRAs with conflicting filenames"""
|
||||
return self.scanner._hash_index.get_duplicate_filenames()
|
||||
@@ -1,37 +1,85 @@
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Tuple
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
from natsort import natsorted
|
||||
|
||||
# Supported sort modes: (sort_key, order)
|
||||
# order: 'asc' for ascending, 'desc' for descending
|
||||
SUPPORTED_SORT_MODES = [
|
||||
('name', 'asc'),
|
||||
('name', 'desc'),
|
||||
('date', 'asc'),
|
||||
('date', 'desc'),
|
||||
('size', 'asc'),
|
||||
('size', 'desc'),
|
||||
]
|
||||
|
||||
@dataclass
|
||||
class ModelCache:
|
||||
"""Cache structure for model data"""
|
||||
"""Cache structure for model data with extensible sorting"""
|
||||
raw_data: List[Dict]
|
||||
sorted_by_name: List[Dict]
|
||||
sorted_by_date: List[Dict]
|
||||
folders: List[str]
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
# Cache for last sort: (sort_key, order) -> sorted list
|
||||
self._last_sort: Tuple[str, str] = (None, None)
|
||||
self._last_sorted_data: List[Dict] = []
|
||||
# Default sort on init
|
||||
asyncio.create_task(self.resort())
|
||||
|
||||
async def resort(self, name_only: bool = False):
|
||||
"""Resort all cached data views"""
|
||||
async def resort(self):
|
||||
"""Resort cached data according to last sort mode if set"""
|
||||
async with self._lock:
|
||||
self.sorted_by_name = natsorted(
|
||||
self.raw_data,
|
||||
key=lambda x: x['model_name'].lower() # Case-insensitive sort
|
||||
)
|
||||
if not name_only:
|
||||
self.sorted_by_date = sorted(
|
||||
self.raw_data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=True
|
||||
)
|
||||
# Update folder list
|
||||
if self._last_sort != (None, None):
|
||||
sort_key, order = self._last_sort
|
||||
sorted_data = self._sort_data(self.raw_data, sort_key, order)
|
||||
self._last_sorted_data = sorted_data
|
||||
# Update folder list
|
||||
# else: do nothing
|
||||
|
||||
all_folders = set(l['folder'] for l in self.raw_data)
|
||||
self.folders = sorted(list(all_folders), key=lambda x: x.lower())
|
||||
|
||||
def _sort_data(self, data: List[Dict], sort_key: str, order: str) -> List[Dict]:
|
||||
"""Sort data by sort_key and order"""
|
||||
reverse = (order == 'desc')
|
||||
if sort_key == 'name':
|
||||
# Natural sort by model_name, case-insensitive
|
||||
return natsorted(
|
||||
data,
|
||||
key=lambda x: x['model_name'].lower(),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'date':
|
||||
# Sort by modified timestamp
|
||||
return sorted(
|
||||
data,
|
||||
key=itemgetter('modified'),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'size':
|
||||
# Sort by file size
|
||||
return sorted(
|
||||
data,
|
||||
key=itemgetter('size'),
|
||||
reverse=reverse
|
||||
)
|
||||
else:
|
||||
# Fallback: no sort
|
||||
return list(data)
|
||||
|
||||
async def get_sorted_data(self, sort_key: str = 'name', order: str = 'asc') -> List[Dict]:
|
||||
"""Get sorted data by sort_key and order, using cache if possible"""
|
||||
async with self._lock:
|
||||
if (sort_key, order) == self._last_sort:
|
||||
return self._last_sorted_data
|
||||
sorted_data = self._sort_data(self.raw_data, sort_key, order)
|
||||
self._last_sort = (sort_key, order)
|
||||
self._last_sorted_data = sorted_data
|
||||
return sorted_data
|
||||
|
||||
async def update_preview_url(self, file_path: str, preview_url: str, preview_nsfw_level: int) -> bool:
|
||||
"""Update preview_url for a specific model in all cached data
|
||||
|
||||
|
||||
@@ -31,29 +31,34 @@ class ModelHashIndex:
|
||||
if file_path not in self._duplicate_hashes.get(sha256, []):
|
||||
self._duplicate_hashes.setdefault(sha256, []).append(file_path)
|
||||
|
||||
# Track duplicates by filename
|
||||
# Track duplicates by filename - FIXED LOGIC
|
||||
if filename in self._filename_to_hash:
|
||||
old_hash = self._filename_to_hash[filename]
|
||||
if old_hash != sha256: # Different models with the same name
|
||||
old_path = self._hash_to_path.get(old_hash)
|
||||
if old_path:
|
||||
if filename not in self._duplicate_filenames:
|
||||
self._duplicate_filenames[filename] = [old_path]
|
||||
if file_path not in self._duplicate_filenames.get(filename, []):
|
||||
self._duplicate_filenames.setdefault(filename, []).append(file_path)
|
||||
existing_hash = self._filename_to_hash[filename]
|
||||
existing_path = self._hash_to_path.get(existing_hash)
|
||||
|
||||
# If this is a different file with the same filename
|
||||
if existing_path and existing_path != file_path:
|
||||
# Initialize duplicates tracking if needed
|
||||
if filename not in self._duplicate_filenames:
|
||||
self._duplicate_filenames[filename] = [existing_path]
|
||||
|
||||
# Add current file to duplicates if not already present
|
||||
if file_path not in self._duplicate_filenames[filename]:
|
||||
self._duplicate_filenames[filename].append(file_path)
|
||||
|
||||
# Remove old path mapping if hash exists
|
||||
if sha256 in self._hash_to_path:
|
||||
old_path = self._hash_to_path[sha256]
|
||||
old_filename = self._get_filename_from_path(old_path)
|
||||
if old_filename in self._filename_to_hash:
|
||||
if old_filename in self._filename_to_hash and self._filename_to_hash[old_filename] == sha256:
|
||||
del self._filename_to_hash[old_filename]
|
||||
|
||||
# Remove old hash mapping if filename exists
|
||||
# Remove old hash mapping if filename exists and points to different hash
|
||||
if filename in self._filename_to_hash:
|
||||
old_hash = self._filename_to_hash[filename]
|
||||
if old_hash in self._hash_to_path:
|
||||
del self._hash_to_path[old_hash]
|
||||
if old_hash != sha256 and old_hash in self._hash_to_path:
|
||||
# Don't delete the old hash mapping, just update filename mapping
|
||||
pass
|
||||
|
||||
# Add new mappings
|
||||
self._hash_to_path[sha256] = file_path
|
||||
@@ -199,8 +204,6 @@ class ModelHashIndex:
|
||||
|
||||
def get_hash_by_filename(self, filename: str) -> Optional[str]:
|
||||
"""Get hash for a filename without extension"""
|
||||
# Strip extension if present to make the function more flexible
|
||||
filename = os.path.splitext(filename)[0]
|
||||
return self._filename_to_hash.get(filename)
|
||||
|
||||
def clear(self) -> None:
|
||||
|
||||
@@ -5,11 +5,10 @@ import asyncio
|
||||
import time
|
||||
import shutil
|
||||
from typing import List, Dict, Optional, Type, Set
|
||||
import msgpack # Add MessagePack import for efficient serialization
|
||||
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from ..config import config
|
||||
from ..utils.file_utils import find_preview_file
|
||||
from ..utils.file_utils import find_preview_file, get_preview_extension
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .model_cache import ModelCache
|
||||
from .model_hash_index import ModelHashIndex
|
||||
@@ -19,17 +18,33 @@ from .websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define cache version to handle future format changes
|
||||
# Version history:
|
||||
# 1 - Initial version
|
||||
# 2 - Added duplicate_filenames and duplicate_hashes tracking
|
||||
# 3 - Added _excluded_models list to cache
|
||||
CACHE_VERSION = 3
|
||||
|
||||
class ModelScanner:
|
||||
"""Base service for scanning and managing model files"""
|
||||
|
||||
_lock = asyncio.Lock()
|
||||
_instances = {} # Dictionary to store instances by class
|
||||
_locks = {} # Dictionary to store locks by class
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""Implement singleton pattern for each subclass"""
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = super().__new__(cls)
|
||||
return cls._instances[cls]
|
||||
|
||||
@classmethod
|
||||
def _get_lock(cls):
|
||||
"""Get or create a lock for this class"""
|
||||
if cls not in cls._locks:
|
||||
cls._locks[cls] = asyncio.Lock()
|
||||
return cls._locks[cls]
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance with async support"""
|
||||
lock = cls._get_lock()
|
||||
async with lock:
|
||||
if cls not in cls._instances:
|
||||
cls._instances[cls] = cls()
|
||||
return cls._instances[cls]
|
||||
|
||||
def __init__(self, model_type: str, model_class: Type[BaseModelMetadata], file_extensions: Set[str], hash_index: Optional[ModelHashIndex] = None):
|
||||
"""Initialize the scanner
|
||||
@@ -40,6 +55,10 @@ class ModelScanner:
|
||||
file_extensions: Set of supported file extensions including the dot (e.g. {'.safetensors'})
|
||||
hash_index: Hash index instance (optional)
|
||||
"""
|
||||
# Ensure initialization happens only once per instance
|
||||
if hasattr(self, '_initialized'):
|
||||
return
|
||||
|
||||
self.model_type = model_type
|
||||
self.model_class = model_class
|
||||
self.file_extensions = file_extensions
|
||||
@@ -48,202 +67,15 @@ class ModelScanner:
|
||||
self._tags_count = {} # Dictionary to store tag counts
|
||||
self._is_initializing = False # Flag to track initialization state
|
||||
self._excluded_models = [] # List to track excluded models
|
||||
self._dirs_last_modified = {} # Track directory modification times
|
||||
self._use_cache_files = False # Flag to control cache file usage, default to disabled
|
||||
|
||||
# Clear cache files if disabled
|
||||
if not self._use_cache_files:
|
||||
self._clear_cache_files()
|
||||
self._initialized = True
|
||||
|
||||
# Register this service
|
||||
asyncio.create_task(self._register_service())
|
||||
|
||||
def _clear_cache_files(self):
|
||||
"""Clear existing cache files if they exist"""
|
||||
try:
|
||||
cache_path = self._get_cache_file_path()
|
||||
if cache_path and os.path.exists(cache_path):
|
||||
os.remove(cache_path)
|
||||
logger.info(f"Cleared {self.model_type} cache file: {cache_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error clearing {self.model_type} cache file: {e}")
|
||||
|
||||
async def _register_service(self):
|
||||
"""Register this instance with the ServiceRegistry"""
|
||||
service_name = f"{self.model_type}_scanner"
|
||||
await ServiceRegistry.register_service(service_name, self)
|
||||
|
||||
def _get_cache_file_path(self) -> Optional[str]:
|
||||
"""Get the path to the cache file"""
|
||||
# Get the directory where this module is located
|
||||
current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
|
||||
|
||||
# Create a cache directory within the project if it doesn't exist
|
||||
cache_dir = os.path.join(current_dir, "cache")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
# Create filename based on model type
|
||||
cache_filename = f"lm_{self.model_type}_cache.msgpack"
|
||||
return os.path.join(cache_dir, cache_filename)
|
||||
|
||||
def _prepare_for_msgpack(self, data):
|
||||
"""Preprocess data to accommodate MessagePack serialization limitations
|
||||
|
||||
Converts integers exceeding safe range to strings
|
||||
|
||||
Args:
|
||||
data: Any type of data structure
|
||||
|
||||
Returns:
|
||||
Preprocessed data structure with large integers converted to strings
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
return {k: self._prepare_for_msgpack(v) for k, v in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [self._prepare_for_msgpack(item) for item in data]
|
||||
elif isinstance(data, int) and (data > 9007199254740991 or data < -9007199254740991):
|
||||
# Convert integers exceeding JavaScript's safe integer range (2^53-1) to strings
|
||||
return str(data)
|
||||
else:
|
||||
return data
|
||||
|
||||
async def _save_cache_to_disk(self) -> bool:
|
||||
"""Save cache data to disk using MessagePack"""
|
||||
if not self._use_cache_files:
|
||||
logger.debug(f"Cache files disabled for {self.model_type}, skipping save")
|
||||
return False
|
||||
|
||||
if self._cache is None or not self._cache.raw_data:
|
||||
logger.debug(f"No {self.model_type} cache data to save")
|
||||
return False
|
||||
|
||||
cache_path = self._get_cache_file_path()
|
||||
if not cache_path:
|
||||
logger.warning(f"Cannot determine {self.model_type} cache file location")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Create cache data structure
|
||||
cache_data = {
|
||||
"version": CACHE_VERSION,
|
||||
"timestamp": time.time(),
|
||||
"model_type": self.model_type,
|
||||
"raw_data": self._cache.raw_data,
|
||||
"hash_index": {
|
||||
"hash_to_path": self._hash_index._hash_to_path,
|
||||
"filename_to_hash": self._hash_index._filename_to_hash, # Fix: changed from path_to_hash to filename_to_hash
|
||||
"duplicate_hashes": self._hash_index._duplicate_hashes,
|
||||
"duplicate_filenames": self._hash_index._duplicate_filenames
|
||||
},
|
||||
"tags_count": self._tags_count,
|
||||
"dirs_last_modified": self._get_dirs_last_modified(),
|
||||
"excluded_models": self._excluded_models # Add excluded_models to cache data
|
||||
}
|
||||
|
||||
# Preprocess data to handle large integers
|
||||
processed_cache_data = self._prepare_for_msgpack(cache_data)
|
||||
|
||||
# Write to temporary file first (atomic operation)
|
||||
temp_path = f"{cache_path}.tmp"
|
||||
with open(temp_path, 'wb') as f:
|
||||
msgpack.pack(processed_cache_data, f)
|
||||
|
||||
# Replace the old file with the new one
|
||||
if os.path.exists(cache_path):
|
||||
os.replace(temp_path, cache_path)
|
||||
else:
|
||||
os.rename(temp_path, cache_path)
|
||||
|
||||
logger.info(f"Saved {self.model_type} cache with {len(self._cache.raw_data)} models to {cache_path}")
|
||||
logger.debug(f"Hash index stats - hash_to_path: {len(self._hash_index._hash_to_path)}, filename_to_hash: {len(self._hash_index._filename_to_hash)}, duplicate_hashes: {len(self._hash_index._duplicate_hashes)}, duplicate_filenames: {len(self._hash_index._duplicate_filenames)}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving {self.model_type} cache to disk: {e}")
|
||||
# Try to clean up temp file if it exists
|
||||
if 'temp_path' in locals() and os.path.exists(temp_path):
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _get_dirs_last_modified(self) -> Dict[str, float]:
|
||||
"""Get last modified time for all model directories"""
|
||||
dirs_info = {}
|
||||
for root in self.get_model_roots():
|
||||
if os.path.exists(root):
|
||||
dirs_info[root] = os.path.getmtime(root)
|
||||
# Also check immediate subdirectories for changes
|
||||
try:
|
||||
with os.scandir(root) as it:
|
||||
for entry in it:
|
||||
if entry.is_dir(follow_symlinks=True):
|
||||
dirs_info[entry.path] = entry.stat().st_mtime
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting directory info for {root}: {e}")
|
||||
return dirs_info
|
||||
|
||||
def _is_cache_valid(self, cache_data: Dict) -> bool:
|
||||
"""Validate if the loaded cache is still valid"""
|
||||
if not cache_data or cache_data.get("version") != CACHE_VERSION:
|
||||
logger.info(f"Cache invalid - version mismatch. Got: {cache_data.get('version')}, Expected: {CACHE_VERSION}")
|
||||
return False
|
||||
|
||||
if cache_data.get("model_type") != self.model_type:
|
||||
logger.info(f"Cache invalid - model type mismatch. Got: {cache_data.get('model_type')}, Expected: {self.model_type}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _load_cache_from_disk(self) -> bool:
|
||||
"""Load cache data from disk using MessagePack"""
|
||||
if not self._use_cache_files:
|
||||
logger.info(f"Cache files disabled for {self.model_type}, skipping load")
|
||||
return False
|
||||
|
||||
start_time = time.time()
|
||||
cache_path = self._get_cache_file_path()
|
||||
if not cache_path or not os.path.exists(cache_path):
|
||||
return False
|
||||
|
||||
try:
|
||||
with open(cache_path, 'rb') as f:
|
||||
cache_data = msgpack.unpack(f)
|
||||
|
||||
# Validate cache data
|
||||
if not self._is_cache_valid(cache_data):
|
||||
logger.info(f"{self.model_type.capitalize()} cache file found but invalid or outdated")
|
||||
return False
|
||||
|
||||
# Load data into memory
|
||||
self._cache = ModelCache(
|
||||
raw_data=cache_data["raw_data"],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
# Load hash index
|
||||
hash_index_data = cache_data.get("hash_index", {})
|
||||
self._hash_index._hash_to_path = hash_index_data.get("hash_to_path", {})
|
||||
self._hash_index._filename_to_hash = hash_index_data.get("filename_to_hash", {}) # Fix: changed from path_to_hash to filename_to_hash
|
||||
self._hash_index._duplicate_hashes = hash_index_data.get("duplicate_hashes", {})
|
||||
self._hash_index._duplicate_filenames = hash_index_data.get("duplicate_filenames", {})
|
||||
|
||||
# Load tags count
|
||||
self._tags_count = cache_data.get("tags_count", {})
|
||||
|
||||
# Load excluded models
|
||||
self._excluded_models = cache_data.get("excluded_models", [])
|
||||
|
||||
# Resort the cache
|
||||
await self._cache.resort()
|
||||
|
||||
logger.info(f"Loaded {self.model_type} cache from disk with {len(self._cache.raw_data)} models in {time.time() - start_time:.2f} seconds")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading {self.model_type} cache from disk: {e}")
|
||||
return False
|
||||
|
||||
async def initialize_in_background(self) -> None:
|
||||
"""Initialize cache in background using thread pool"""
|
||||
@@ -252,8 +84,6 @@ class ModelScanner:
|
||||
if self._cache is None:
|
||||
self._cache = ModelCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
@@ -271,21 +101,6 @@ class ModelScanner:
|
||||
'scanner_type': self.model_type,
|
||||
'pageType': page_type
|
||||
})
|
||||
|
||||
cache_loaded = await self._load_cache_from_disk()
|
||||
|
||||
if cache_loaded:
|
||||
# Cache loaded successfully, broadcast complete message
|
||||
await ws_manager.broadcast_init_progress({
|
||||
'stage': 'finalizing',
|
||||
'progress': 100,
|
||||
'status': 'complete',
|
||||
'details': f"Loaded {len(self._cache.raw_data)} {self.model_type} files from cache.",
|
||||
'scanner_type': self.model_type,
|
||||
'pageType': page_type
|
||||
})
|
||||
self._is_initializing = False
|
||||
return
|
||||
|
||||
# If cache loading failed, proceed with full scan
|
||||
await ws_manager.broadcast_init_progress({
|
||||
@@ -332,9 +147,6 @@ class ModelScanner:
|
||||
|
||||
logger.info(f"{self.model_type.capitalize()} cache initialized in {time.time() - start_time:.2f} seconds. Found {len(self._cache.raw_data)} models")
|
||||
|
||||
# Save the cache to disk after initialization
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
# Send completion message
|
||||
await asyncio.sleep(0.5) # Small delay to ensure final progress message is sent
|
||||
await ws_manager.broadcast_init_progress({
|
||||
@@ -490,6 +302,13 @@ class ModelScanner:
|
||||
for tag in model_data['tags']:
|
||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||
|
||||
# Log duplicate filename warnings after building the index
|
||||
# duplicate_filenames = self._hash_index.get_duplicate_filenames()
|
||||
# if duplicate_filenames:
|
||||
# logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:")
|
||||
# for filename, paths in duplicate_filenames.items():
|
||||
# logger.warning(f" Duplicate filename '{filename}': {paths}")
|
||||
|
||||
# Update cache
|
||||
self._cache.raw_data = raw_data
|
||||
loop.run_until_complete(self._cache.resort())
|
||||
@@ -509,40 +328,21 @@ class ModelScanner:
|
||||
|
||||
Args:
|
||||
force_refresh: Whether to refresh the cache
|
||||
rebuild_cache: Whether to completely rebuild the cache by reloading from disk first
|
||||
rebuild_cache: Whether to completely rebuild the cache
|
||||
"""
|
||||
# If cache is not initialized, return an empty cache
|
||||
# Actual initialization should be done via initialize_in_background
|
||||
if self._cache is None and not force_refresh:
|
||||
return ModelCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
# If force refresh is requested, initialize the cache directly
|
||||
if force_refresh:
|
||||
# If rebuild_cache is True, try to reload from disk before reconciliation
|
||||
if rebuild_cache:
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Attempting to rebuild cache from disk...")
|
||||
cache_loaded = await self._load_cache_from_disk()
|
||||
if cache_loaded:
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Successfully reloaded cache from disk")
|
||||
else:
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Could not reload cache from disk, proceeding with complete rebuild")
|
||||
# If loading from disk failed, do a complete rebuild and save to disk
|
||||
await self._initialize_cache()
|
||||
await self._save_cache_to_disk()
|
||||
return self._cache
|
||||
|
||||
if self._cache is None:
|
||||
# For initial creation, do a full initialization
|
||||
await self._initialize_cache()
|
||||
# Save the newly built cache
|
||||
await self._save_cache_to_disk()
|
||||
else:
|
||||
# For subsequent refreshes, use fast reconciliation
|
||||
await self._reconcile_cache()
|
||||
|
||||
return self._cache
|
||||
@@ -574,11 +374,16 @@ class ModelScanner:
|
||||
for tag in model_data['tags']:
|
||||
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
|
||||
|
||||
# Log duplicate filename warnings after building the index
|
||||
# duplicate_filenames = self._hash_index.get_duplicate_filenames()
|
||||
# if duplicate_filenames:
|
||||
# logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:")
|
||||
# for filename, paths in duplicate_filenames.items():
|
||||
# logger.warning(f" Duplicate filename '{filename}': {paths}")
|
||||
|
||||
# Update cache
|
||||
self._cache = ModelCache(
|
||||
raw_data=raw_data,
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
|
||||
@@ -592,8 +397,6 @@ class ModelScanner:
|
||||
if self._cache is None:
|
||||
self._cache = ModelCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[]
|
||||
)
|
||||
finally:
|
||||
@@ -735,19 +538,66 @@ class ModelScanner:
|
||||
# Resort cache
|
||||
await self._cache.resort()
|
||||
|
||||
# Save updated cache to disk
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
logger.info(f"{self.model_type.capitalize()} Scanner: Cache reconciliation completed in {time.time() - start_time:.2f} seconds. Added {total_added}, removed {total_removed} models.")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.model_type.capitalize()} Scanner: Error reconciling cache: {e}", exc_info=True)
|
||||
finally:
|
||||
self._is_initializing = False # Unset flag
|
||||
|
||||
# These methods should be implemented in child classes
|
||||
async def scan_all_models(self) -> List[Dict]:
|
||||
"""Scan all model directories and return metadata"""
|
||||
raise NotImplementedError("Subclasses must implement scan_all_models")
|
||||
all_models = []
|
||||
|
||||
# Create scan tasks for each directory
|
||||
scan_tasks = []
|
||||
for model_root in self.get_model_roots():
|
||||
task = asyncio.create_task(self._scan_directory(model_root))
|
||||
scan_tasks.append(task)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for task in scan_tasks:
|
||||
try:
|
||||
models = await task
|
||||
all_models.extend(models)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning directory: {e}")
|
||||
|
||||
return all_models
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Scan a single directory for model files"""
|
||||
models = []
|
||||
original_root = root_path # Save original root path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
"""Recursively scan directory, avoiding circular symlinks"""
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
result = await self._process_model_file(file_path, original_root)
|
||||
# Only add to models if result is not None (skip corrupted metadata)
|
||||
if result:
|
||||
models.append(result)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return models
|
||||
|
||||
def is_initializing(self) -> bool:
|
||||
"""Check if the scanner is currently initializing"""
|
||||
@@ -769,10 +619,18 @@ class ModelScanner:
|
||||
return os.path.dirname(rel_path).replace(os.path.sep, '/')
|
||||
return ''
|
||||
|
||||
# Common methods shared between scanners
|
||||
def adjust_metadata(self, metadata, file_path, root_path):
|
||||
"""Hook for subclasses: adjust metadata during scanning"""
|
||||
return metadata
|
||||
|
||||
async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
|
||||
"""Process a single model file and return its metadata"""
|
||||
metadata = await MetadataManager.load_metadata(file_path, self.model_class)
|
||||
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.model_class)
|
||||
|
||||
if should_skip:
|
||||
# Metadata file exists but cannot be parsed - skip this model
|
||||
logger.warning(f"Skipping model {file_path} due to corrupted metadata file")
|
||||
return None
|
||||
|
||||
if metadata is None:
|
||||
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
|
||||
@@ -788,7 +646,7 @@ class ModelScanner:
|
||||
|
||||
metadata = self.model_class.from_civitai_info(version_info, file_info, file_path)
|
||||
metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
|
||||
await MetadataManager.save_metadata(file_path, metadata, True)
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
logger.debug(f"Created metadata from .civitai.info for {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
|
||||
@@ -815,7 +673,7 @@ class ModelScanner:
|
||||
metadata.modelDescription = version_info['model']['description']
|
||||
|
||||
# Save the updated metadata
|
||||
await MetadataManager.save_metadata(file_path, metadata, True)
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
logger.debug(f"Updated metadata with civitai info for {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error restoring civitai data from .civitai.info for {file_path}: {e}")
|
||||
@@ -823,12 +681,23 @@ class ModelScanner:
|
||||
if metadata is None:
|
||||
metadata = await self._create_default_metadata(file_path)
|
||||
|
||||
# Hook: allow subclasses to adjust metadata
|
||||
metadata = self.adjust_metadata(metadata, file_path, root_path)
|
||||
|
||||
model_data = metadata.to_dict()
|
||||
|
||||
# Skip excluded models
|
||||
if model_data.get('exclude', False):
|
||||
self._excluded_models.append(model_data['file_path'])
|
||||
return None
|
||||
|
||||
# Check for duplicate filename before adding to hash index
|
||||
filename = os.path.splitext(os.path.basename(file_path))[0]
|
||||
existing_hash = self._hash_index.get_hash_by_filename(filename)
|
||||
if existing_hash and existing_hash != model_data.get('sha256', '').lower():
|
||||
existing_path = self._hash_index.get_path(existing_hash)
|
||||
if existing_path and existing_path != file_path:
|
||||
logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")
|
||||
|
||||
await self._fetch_missing_metadata(file_path, model_data)
|
||||
rel_path = os.path.relpath(file_path, root_path)
|
||||
@@ -884,54 +753,12 @@ class ModelScanner:
|
||||
|
||||
model_data['civitai']['creator'] = model_metadata['creator']
|
||||
|
||||
await MetadataManager.save_metadata(file_path, model_data, True)
|
||||
await MetadataManager.save_metadata(file_path, model_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
|
||||
|
||||
async def _scan_directory(self, root_path: str) -> List[Dict]:
|
||||
"""Base implementation for directory scanning"""
|
||||
models = []
|
||||
original_root = root_path
|
||||
|
||||
async def scan_recursive(path: str, visited_paths: set):
|
||||
try:
|
||||
real_path = os.path.realpath(path)
|
||||
if real_path in visited_paths:
|
||||
logger.debug(f"Skipping already visited path: {path}")
|
||||
return
|
||||
visited_paths.add(real_path)
|
||||
|
||||
with os.scandir(path) as it:
|
||||
entries = list(it)
|
||||
for entry in entries:
|
||||
try:
|
||||
if entry.is_file(follow_symlinks=True):
|
||||
ext = os.path.splitext(entry.name)[1].lower()
|
||||
if ext in self.file_extensions:
|
||||
file_path = entry.path.replace(os.sep, "/")
|
||||
await self._process_single_file(file_path, original_root, models)
|
||||
await asyncio.sleep(0)
|
||||
elif entry.is_dir(follow_symlinks=True):
|
||||
await scan_recursive(entry.path, visited_paths)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing entry {entry.path}: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning {path}: {e}")
|
||||
|
||||
await scan_recursive(root_path, set())
|
||||
return models
|
||||
|
||||
async def _process_single_file(self, file_path: str, root_path: str, models_list: list):
|
||||
"""Process a single file and add to results list"""
|
||||
try:
|
||||
result = await self._process_model_file(file_path, root_path)
|
||||
if result:
|
||||
models_list.append(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {file_path}: {e}")
|
||||
|
||||
async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool:
|
||||
"""Add a model to the cache and save to disk
|
||||
"""Add a model to the cache
|
||||
|
||||
Args:
|
||||
metadata_dict: The model metadata dictionary
|
||||
@@ -960,16 +787,21 @@ class ModelScanner:
|
||||
|
||||
# Update the hash index
|
||||
self._hash_index.add_entry(metadata_dict['sha256'], metadata_dict['file_path'])
|
||||
|
||||
# Save to disk
|
||||
await self._save_cache_to_disk()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding model to cache: {e}")
|
||||
return False
|
||||
|
||||
async def move_model(self, source_path: str, target_path: str) -> bool:
|
||||
"""Move a model and its associated files to a new location"""
|
||||
async def move_model(self, source_path: str, target_path: str) -> Optional[str]:
|
||||
"""Move a model and its associated files to a new location
|
||||
|
||||
Args:
|
||||
source_path: Original file path
|
||||
target_path: Target directory path
|
||||
|
||||
Returns:
|
||||
Optional[str]: New file path if successful, None if failed
|
||||
"""
|
||||
try:
|
||||
source_path = source_path.replace(os.sep, '/')
|
||||
target_path = target_path.replace(os.sep, '/')
|
||||
@@ -978,14 +810,28 @@ class ModelScanner:
|
||||
|
||||
if not file_ext or file_ext.lower() not in self.file_extensions:
|
||||
logger.error(f"Invalid file extension for model: {file_ext}")
|
||||
return False
|
||||
return None
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(source_path))[0]
|
||||
source_dir = os.path.dirname(source_path)
|
||||
|
||||
os.makedirs(target_path, exist_ok=True)
|
||||
|
||||
target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/')
|
||||
def get_source_hash():
|
||||
return self.get_hash_by_path(source_path)
|
||||
|
||||
# Check for filename conflicts and auto-rename if necessary
|
||||
from ..utils.models import BaseModelMetadata
|
||||
final_filename = BaseModelMetadata.generate_unique_filename(
|
||||
target_path, base_name, file_ext, get_source_hash
|
||||
)
|
||||
|
||||
target_file = os.path.join(target_path, final_filename).replace(os.sep, '/')
|
||||
final_base_name = os.path.splitext(final_filename)[0]
|
||||
|
||||
# Log if filename was changed due to conflict
|
||||
if final_filename != f"{base_name}{file_ext}":
|
||||
logger.info(f"Renamed {base_name}{file_ext} to {final_filename} to avoid filename conflict")
|
||||
|
||||
real_source = os.path.realpath(source_path)
|
||||
real_target = os.path.realpath(target_file)
|
||||
@@ -1002,12 +848,17 @@ class ModelScanner:
|
||||
for file in os.listdir(source_dir):
|
||||
if file.startswith(base_name + ".") and file != os.path.basename(source_path):
|
||||
source_file_path = os.path.join(source_dir, file)
|
||||
# Generate new filename with the same base name as the model file
|
||||
file_suffix = file[len(base_name):] # Get the part after base_name (e.g., ".metadata.json", ".preview.png")
|
||||
new_associated_filename = f"{final_base_name}{file_suffix}"
|
||||
target_associated_path = os.path.join(target_path, new_associated_filename)
|
||||
|
||||
# Store metadata file path for special handling
|
||||
if file == f"{base_name}.metadata.json":
|
||||
source_metadata = source_file_path
|
||||
moved_metadata_path = os.path.join(target_path, file)
|
||||
moved_metadata_path = target_associated_path
|
||||
else:
|
||||
files_to_move.append((source_file_path, os.path.join(target_path, file)))
|
||||
files_to_move.append((source_file_path, target_associated_path))
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing files in {source_dir}: {e}")
|
||||
|
||||
@@ -1029,11 +880,11 @@ class ModelScanner:
|
||||
|
||||
await self.update_single_model_cache(source_path, target_file, metadata)
|
||||
|
||||
return True
|
||||
return target_file
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error moving model: {e}", exc_info=True)
|
||||
return False
|
||||
return None
|
||||
|
||||
async def _update_metadata_paths(self, metadata_path: str, model_path: str) -> Dict:
|
||||
"""Update file paths in metadata file"""
|
||||
@@ -1042,12 +893,15 @@ class ModelScanner:
|
||||
metadata = json.load(f)
|
||||
|
||||
metadata['file_path'] = model_path.replace(os.sep, '/')
|
||||
# Update file_name to match the new filename
|
||||
metadata['file_name'] = os.path.splitext(os.path.basename(model_path))[0]
|
||||
|
||||
if 'preview_url' in metadata and metadata['preview_url']:
|
||||
preview_dir = os.path.dirname(model_path)
|
||||
preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
|
||||
preview_ext = os.path.splitext(metadata['preview_url'])[1]
|
||||
new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
|
||||
# Update preview filename to match the new base name
|
||||
new_base_name = os.path.splitext(os.path.basename(model_path))[0]
|
||||
preview_ext = get_preview_extension(metadata['preview_url'])
|
||||
new_preview_path = os.path.join(preview_dir, f"{new_base_name}{preview_ext}")
|
||||
metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
|
||||
|
||||
await MetadataManager.save_metadata(metadata_path, metadata)
|
||||
@@ -1102,9 +956,6 @@ class ModelScanner:
|
||||
|
||||
await cache.resort()
|
||||
|
||||
# Save the updated cache
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
return True
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
@@ -1117,8 +968,16 @@ class ModelScanner:
|
||||
|
||||
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a model by its file path"""
|
||||
return self._hash_index.get_hash(file_path)
|
||||
if self._cache is None or not self._cache.raw_data:
|
||||
return None
|
||||
|
||||
# Iterate through cache data to find matching file path
|
||||
for model_data in self._cache.raw_data:
|
||||
if model_data.get('file_path') == file_path:
|
||||
return model_data.get('sha256')
|
||||
|
||||
return None
|
||||
|
||||
def get_hash_by_filename(self, filename: str) -> Optional[str]:
|
||||
"""Get hash for a model by its filename without path"""
|
||||
return self._hash_index.get_hash_by_filename(filename)
|
||||
@@ -1198,11 +1057,7 @@ class ModelScanner:
|
||||
if self._cache is None:
|
||||
return False
|
||||
|
||||
updated = await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
|
||||
if updated:
|
||||
# Save updated cache to disk
|
||||
await self._save_cache_to_disk()
|
||||
return updated
|
||||
return await self._cache.update_preview_url(file_path, preview_url, preview_nsfw_level)
|
||||
|
||||
async def bulk_delete_models(self, file_paths: List[str]) -> Dict:
|
||||
"""Delete multiple models and update cache in a batch operation
|
||||
@@ -1334,9 +1189,6 @@ class ModelScanner:
|
||||
# Resort cache
|
||||
await self._cache.resort()
|
||||
|
||||
# Save updated cache to disk
|
||||
await self._save_cache_to_disk()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -1362,3 +1214,56 @@ class ModelScanner:
|
||||
if file_name in self._hash_index._duplicate_filenames:
|
||||
if len(self._hash_index._duplicate_filenames[file_name]) <= 1:
|
||||
del self._hash_index._duplicate_filenames[file_name]
|
||||
|
||||
async def check_model_version_exists(self, model_version_id: int) -> bool:
|
||||
"""Check if a specific model version exists in the cache
|
||||
|
||||
Args:
|
||||
model_version_id: Civitai model version ID
|
||||
|
||||
Returns:
|
||||
bool: True if the model version exists, False otherwise
|
||||
"""
|
||||
try:
|
||||
cache = await self.get_cached_data()
|
||||
if not cache or not cache.raw_data:
|
||||
return False
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('civitai') and item['civitai'].get('id') == model_version_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking model version existence: {e}")
|
||||
return False
|
||||
|
||||
async def get_model_versions_by_id(self, model_id: int) -> List[Dict]:
|
||||
"""Get all versions of a model by its ID
|
||||
|
||||
Args:
|
||||
model_id: Civitai model ID
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of version information dictionaries
|
||||
"""
|
||||
try:
|
||||
cache = await self.get_cached_data()
|
||||
if not cache or not cache.raw_data:
|
||||
return []
|
||||
|
||||
versions = []
|
||||
for item in cache.raw_data:
|
||||
if (item.get('civitai') and
|
||||
item['civitai'].get('modelId') == model_id and
|
||||
item['civitai'].get('id')):
|
||||
versions.append({
|
||||
'versionId': item['civitai'].get('id'),
|
||||
'name': item['civitai'].get('name'),
|
||||
'fileName': item.get('file_name', '')
|
||||
})
|
||||
|
||||
return versions
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model versions: {e}")
|
||||
return []
|
||||
|
||||
142
py/services/model_service_factory.py
Normal file
142
py/services/model_service_factory.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from typing import Dict, Type, Any
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ModelServiceFactory:
|
||||
"""Factory for managing model services and routes"""
|
||||
|
||||
_services: Dict[str, Type] = {}
|
||||
_routes: Dict[str, Type] = {}
|
||||
_initialized_services: Dict[str, Any] = {}
|
||||
_initialized_routes: Dict[str, Any] = {}
|
||||
|
||||
@classmethod
|
||||
def register_model_type(cls, model_type: str, service_class: Type, route_class: Type):
|
||||
"""Register a new model type with its service and route classes
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier (e.g., 'lora', 'checkpoint')
|
||||
service_class: The service class for this model type
|
||||
route_class: The route class for this model type
|
||||
"""
|
||||
cls._services[model_type] = service_class
|
||||
cls._routes[model_type] = route_class
|
||||
logger.info(f"Registered model type '{model_type}' with service {service_class.__name__} and routes {route_class.__name__}")
|
||||
|
||||
@classmethod
|
||||
def get_service_class(cls, model_type: str) -> Type:
|
||||
"""Get service class for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The service class for the model type
|
||||
|
||||
Raises:
|
||||
ValueError: If model type is not registered
|
||||
"""
|
||||
if model_type not in cls._services:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
return cls._services[model_type]
|
||||
|
||||
@classmethod
|
||||
def get_route_class(cls, model_type: str) -> Type:
|
||||
"""Get route class for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The route class for the model type
|
||||
|
||||
Raises:
|
||||
ValueError: If model type is not registered
|
||||
"""
|
||||
if model_type not in cls._routes:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
return cls._routes[model_type]
|
||||
|
||||
@classmethod
|
||||
def get_route_instance(cls, model_type: str):
|
||||
"""Get or create route instance for a model type
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
The route instance for the model type
|
||||
"""
|
||||
if model_type not in cls._initialized_routes:
|
||||
route_class = cls.get_route_class(model_type)
|
||||
cls._initialized_routes[model_type] = route_class()
|
||||
return cls._initialized_routes[model_type]
|
||||
|
||||
@classmethod
|
||||
def setup_all_routes(cls, app):
|
||||
"""Setup routes for all registered model types
|
||||
|
||||
Args:
|
||||
app: The aiohttp application instance
|
||||
"""
|
||||
logger.info(f"Setting up routes for {len(cls._services)} registered model types")
|
||||
|
||||
for model_type in cls._services.keys():
|
||||
try:
|
||||
routes_instance = cls.get_route_instance(model_type)
|
||||
routes_instance.setup_routes(app)
|
||||
logger.info(f"Successfully set up routes for {model_type}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup routes for {model_type}: {e}", exc_info=True)
|
||||
|
||||
@classmethod
|
||||
def get_registered_types(cls) -> list:
|
||||
"""Get list of all registered model types
|
||||
|
||||
Returns:
|
||||
List of registered model type identifiers
|
||||
"""
|
||||
return list(cls._services.keys())
|
||||
|
||||
@classmethod
|
||||
def is_registered(cls, model_type: str) -> bool:
|
||||
"""Check if a model type is registered
|
||||
|
||||
Args:
|
||||
model_type: The model type identifier
|
||||
|
||||
Returns:
|
||||
True if the model type is registered, False otherwise
|
||||
"""
|
||||
return model_type in cls._services
|
||||
|
||||
@classmethod
|
||||
def clear_registrations(cls):
|
||||
"""Clear all registrations - mainly for testing purposes"""
|
||||
cls._services.clear()
|
||||
cls._routes.clear()
|
||||
cls._initialized_services.clear()
|
||||
cls._initialized_routes.clear()
|
||||
logger.info("Cleared all model type registrations")
|
||||
|
||||
|
||||
def register_default_model_types():
|
||||
"""Register the default model types (LoRA, Checkpoint, and Embedding)"""
|
||||
from ..services.lora_service import LoraService
|
||||
from ..services.checkpoint_service import CheckpointService
|
||||
from ..services.embedding_service import EmbeddingService
|
||||
from ..routes.lora_routes import LoraRoutes
|
||||
from ..routes.checkpoint_routes import CheckpointRoutes
|
||||
from ..routes.embedding_routes import EmbeddingRoutes
|
||||
|
||||
# Register LoRA model type
|
||||
ModelServiceFactory.register_model_type('lora', LoraService, LoraRoutes)
|
||||
|
||||
# Register Checkpoint model type
|
||||
ModelServiceFactory.register_model_type('checkpoint', CheckpointService, CheckpointRoutes)
|
||||
|
||||
# Register Embedding model type
|
||||
ModelServiceFactory.register_model_type('embedding', EmbeddingService, EmbeddingRoutes)
|
||||
|
||||
logger.info("Registered default model types: lora, checkpoint, embedding")
|
||||
@@ -393,8 +393,8 @@ class RecipeScanner:
|
||||
if 'hash' in lora and (not lora.get('file_name') or not lora['file_name']):
|
||||
hash_value = lora['hash']
|
||||
|
||||
if self._lora_scanner.has_lora_hash(hash_value):
|
||||
lora_path = self._lora_scanner.get_lora_path_by_hash(hash_value)
|
||||
if self._lora_scanner.has_hash(hash_value):
|
||||
lora_path = self._lora_scanner.get_path_by_hash(hash_value)
|
||||
if lora_path:
|
||||
file_name = os.path.splitext(os.path.basename(lora_path))[0]
|
||||
lora['file_name'] = file_name
|
||||
@@ -465,7 +465,7 @@ class RecipeScanner:
|
||||
# Count occurrences of each base model
|
||||
for lora in loras:
|
||||
if 'hash' in lora:
|
||||
lora_path = self._lora_scanner.get_lora_path_by_hash(lora['hash'])
|
||||
lora_path = self._lora_scanner.get_path_by_hash(lora['hash'])
|
||||
if lora_path:
|
||||
base_model = await self._get_base_model_for_lora(lora_path)
|
||||
if base_model:
|
||||
@@ -603,9 +603,9 @@ class RecipeScanner:
|
||||
if 'loras' in item:
|
||||
for lora in item['loras']:
|
||||
if 'hash' in lora and lora['hash']:
|
||||
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora['hash'].lower())
|
||||
lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower())
|
||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
|
||||
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora['hash'].lower())
|
||||
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower())
|
||||
|
||||
result = {
|
||||
'items': paginated_items,
|
||||
@@ -655,9 +655,9 @@ class RecipeScanner:
|
||||
for lora in formatted_recipe['loras']:
|
||||
if 'hash' in lora and lora['hash']:
|
||||
lora_hash = lora['hash'].lower()
|
||||
lora['inLibrary'] = self._lora_scanner.has_lora_hash(lora_hash)
|
||||
lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash)
|
||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash)
|
||||
lora['localPath'] = self._lora_scanner.get_lora_path_by_hash(lora_hash)
|
||||
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash)
|
||||
|
||||
return formatted_recipe
|
||||
|
||||
|
||||
114
py/services/server_i18n.py
Normal file
114
py/services/server_i18n.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ServerI18nManager:
|
||||
"""Server-side internationalization manager for template rendering"""
|
||||
|
||||
def __init__(self):
|
||||
self.translations = {}
|
||||
self.current_locale = 'en'
|
||||
self._load_translations()
|
||||
|
||||
def _load_translations(self):
|
||||
"""Load all translation files from the locales directory"""
|
||||
i18n_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
'locales'
|
||||
)
|
||||
|
||||
if not os.path.exists(i18n_path):
|
||||
logger.warning(f"I18n directory not found: {i18n_path}")
|
||||
return
|
||||
|
||||
# Load all available locale files
|
||||
for filename in os.listdir(i18n_path):
|
||||
if filename.endswith('.json'):
|
||||
locale_code = filename[:-5] # Remove .json extension
|
||||
try:
|
||||
self._load_locale_file(i18n_path, filename, locale_code)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading locale file {filename}: {e}")
|
||||
|
||||
def _load_locale_file(self, path: str, filename: str, locale_code: str):
|
||||
"""Load a single locale JSON file"""
|
||||
file_path = os.path.join(path, filename)
|
||||
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
translations = json.load(f)
|
||||
|
||||
self.translations[locale_code] = translations
|
||||
logger.debug(f"Loaded translations for {locale_code} from {filename}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing locale file {filename}: {e}")
|
||||
|
||||
def set_locale(self, locale: str):
|
||||
"""Set the current locale"""
|
||||
if locale in self.translations:
|
||||
self.current_locale = locale
|
||||
else:
|
||||
logger.warning(f"Locale {locale} not found, using 'en'")
|
||||
self.current_locale = 'en'
|
||||
|
||||
def get_translation(self, key: str, params: Dict[str, Any] = None, **kwargs) -> str:
|
||||
"""Get translation for a key with optional parameters (supports both dict and keyword args)"""
|
||||
# Merge kwargs into params for convenience
|
||||
if params is None:
|
||||
params = {}
|
||||
if kwargs:
|
||||
params = {**params, **kwargs}
|
||||
|
||||
if self.current_locale not in self.translations:
|
||||
return key
|
||||
|
||||
# Navigate through nested object using dot notation
|
||||
keys = key.split('.')
|
||||
value = self.translations[self.current_locale]
|
||||
|
||||
for k in keys:
|
||||
if isinstance(value, dict) and k in value:
|
||||
value = value[k]
|
||||
else:
|
||||
# Fallback to English if current locale doesn't have the key
|
||||
if self.current_locale != 'en' and 'en' in self.translations:
|
||||
en_value = self.translations['en']
|
||||
for k in keys:
|
||||
if isinstance(en_value, dict) and k in en_value:
|
||||
en_value = en_value[k]
|
||||
else:
|
||||
return key
|
||||
value = en_value
|
||||
else:
|
||||
return key
|
||||
break
|
||||
|
||||
if not isinstance(value, str):
|
||||
return key
|
||||
|
||||
# Replace parameters if provided
|
||||
if params:
|
||||
for param_key, param_value in params.items():
|
||||
placeholder = f"{{{param_key}}}"
|
||||
double_placeholder = f"{{{{{param_key}}}}}"
|
||||
value = value.replace(placeholder, str(param_value))
|
||||
value = value.replace(double_placeholder, str(param_value))
|
||||
|
||||
return value
|
||||
|
||||
def get_available_locales(self) -> list:
|
||||
"""Get list of available locales"""
|
||||
return list(self.translations.keys())
|
||||
|
||||
def create_template_filter(self):
|
||||
"""Create a Jinja2 filter function for templates"""
|
||||
def t_filter(key: str, **params) -> str:
|
||||
return self.get_translation(key, params)
|
||||
return t_filter
|
||||
|
||||
# Create global instance
|
||||
server_i18n = ServerI18nManager()
|
||||
@@ -7,97 +7,209 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar('T') # Define a type variable for service types
|
||||
|
||||
class ServiceRegistry:
|
||||
"""Centralized registry for service singletons"""
|
||||
"""Central registry for managing singleton services"""
|
||||
|
||||
_instance = None
|
||||
_services: Dict[str, Any] = {}
|
||||
_lock = asyncio.Lock()
|
||||
_locks: Dict[str, asyncio.Lock] = {}
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
"""Get singleton instance of the registry"""
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
async def register_service(cls, name: str, service: Any) -> None:
|
||||
"""Register a service instance with the registry
|
||||
|
||||
Args:
|
||||
name: Service name identifier
|
||||
service: Service instance to register
|
||||
"""
|
||||
cls._services[name] = service
|
||||
logger.debug(f"Registered service: {name}")
|
||||
|
||||
@classmethod
|
||||
async def register_service(cls, service_name: str, service_instance: Any) -> None:
|
||||
"""Register a service instance with the registry"""
|
||||
registry = cls.get_instance()
|
||||
async with cls._lock:
|
||||
registry._services[service_name] = service_instance
|
||||
logger.debug(f"Registered service: {service_name}")
|
||||
async def get_service(cls, name: str) -> Optional[Any]:
|
||||
"""Get a service instance by name
|
||||
|
||||
Args:
|
||||
name: Service name identifier
|
||||
|
||||
Returns:
|
||||
Service instance or None if not found
|
||||
"""
|
||||
return cls._services.get(name)
|
||||
|
||||
@classmethod
|
||||
async def get_service(cls, service_name: str) -> Any:
|
||||
"""Get a service instance by name"""
|
||||
registry = cls.get_instance()
|
||||
async with cls._lock:
|
||||
if service_name not in registry._services:
|
||||
logger.debug(f"Service {service_name} not found in registry")
|
||||
return None
|
||||
return registry._services[service_name]
|
||||
def get_service_sync(cls, name: str) -> Optional[Any]:
|
||||
"""Synchronously get a service instance by name
|
||||
|
||||
Args:
|
||||
name: Service name identifier
|
||||
|
||||
Returns:
|
||||
Service instance or None if not found
|
||||
"""
|
||||
return cls._services.get(name)
|
||||
|
||||
@classmethod
|
||||
def _get_lock(cls, name: str) -> asyncio.Lock:
|
||||
"""Get or create a lock for a service
|
||||
|
||||
Args:
|
||||
name: Service name identifier
|
||||
|
||||
Returns:
|
||||
AsyncIO lock for the service
|
||||
"""
|
||||
if name not in cls._locks:
|
||||
cls._locks[name] = asyncio.Lock()
|
||||
return cls._locks[name]
|
||||
|
||||
# Convenience methods for common services
|
||||
@classmethod
|
||||
async def get_lora_scanner(cls):
|
||||
"""Get the LoraScanner instance"""
|
||||
from .lora_scanner import LoraScanner
|
||||
scanner = await cls.get_service("lora_scanner")
|
||||
if scanner is None:
|
||||
scanner = await LoraScanner.get_instance()
|
||||
await cls.register_service("lora_scanner", scanner)
|
||||
return scanner
|
||||
"""Get or create LoRA scanner instance"""
|
||||
service_name = "lora_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .lora_scanner import LoraScanner
|
||||
|
||||
scanner = await LoraScanner.get_instance()
|
||||
cls._services[service_name] = scanner
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_checkpoint_scanner(cls):
|
||||
"""Get the CheckpointScanner instance"""
|
||||
from .checkpoint_scanner import CheckpointScanner
|
||||
scanner = await cls.get_service("checkpoint_scanner")
|
||||
if scanner is None:
|
||||
"""Get or create Checkpoint scanner instance"""
|
||||
service_name = "checkpoint_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .checkpoint_scanner import CheckpointScanner
|
||||
|
||||
scanner = await CheckpointScanner.get_instance()
|
||||
await cls.register_service("checkpoint_scanner", scanner)
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_civitai_client(cls):
|
||||
"""Get the CivitaiClient instance"""
|
||||
from .civitai_client import CivitaiClient
|
||||
client = await cls.get_service("civitai_client")
|
||||
if client is None:
|
||||
client = await CivitaiClient.get_instance()
|
||||
await cls.register_service("civitai_client", client)
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
async def get_download_manager(cls):
|
||||
"""Get the DownloadManager instance"""
|
||||
from .download_manager import DownloadManager
|
||||
manager = await cls.get_service("download_manager")
|
||||
if manager is None:
|
||||
manager = await DownloadManager.get_instance()
|
||||
await cls.register_service("download_manager", manager)
|
||||
return manager
|
||||
|
||||
cls._services[service_name] = scanner
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_recipe_scanner(cls):
|
||||
"""Get the RecipeScanner instance"""
|
||||
from .recipe_scanner import RecipeScanner
|
||||
scanner = await cls.get_service("recipe_scanner")
|
||||
if scanner is None:
|
||||
lora_scanner = await cls.get_lora_scanner()
|
||||
scanner = RecipeScanner(lora_scanner)
|
||||
await cls.register_service("recipe_scanner", scanner)
|
||||
return scanner
|
||||
|
||||
"""Get or create Recipe scanner instance"""
|
||||
service_name = "recipe_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .recipe_scanner import RecipeScanner
|
||||
|
||||
scanner = await RecipeScanner.get_instance()
|
||||
cls._services[service_name] = scanner
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
async def get_civitai_client(cls):
|
||||
"""Get or create CivitAI client instance"""
|
||||
service_name = "civitai_client"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .civitai_client import CivitaiClient
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
cls._services[service_name] = client
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return client
|
||||
|
||||
@classmethod
|
||||
async def get_download_manager(cls):
|
||||
"""Get or create Download manager instance"""
|
||||
service_name = "download_manager"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .download_manager import DownloadManager
|
||||
|
||||
manager = DownloadManager()
|
||||
cls._services[service_name] = manager
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return manager
|
||||
|
||||
@classmethod
|
||||
async def get_websocket_manager(cls):
|
||||
"""Get the WebSocketManager instance"""
|
||||
from .websocket_manager import ws_manager
|
||||
manager = await cls.get_service("websocket_manager")
|
||||
if manager is None:
|
||||
# ws_manager is already a global instance in websocket_manager.py
|
||||
"""Get or create WebSocket manager instance"""
|
||||
service_name = "websocket_manager"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .websocket_manager import ws_manager
|
||||
await cls.register_service("websocket_manager", ws_manager)
|
||||
manager = ws_manager
|
||||
return manager
|
||||
|
||||
cls._services[service_name] = ws_manager
|
||||
logger.debug(f"Registered {service_name}")
|
||||
return ws_manager
|
||||
|
||||
@classmethod
|
||||
async def get_embedding_scanner(cls):
|
||||
"""Get or create Embedding scanner instance"""
|
||||
service_name = "embedding_scanner"
|
||||
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
async with cls._get_lock(service_name):
|
||||
# Double-check after acquiring lock
|
||||
if service_name in cls._services:
|
||||
return cls._services[service_name]
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from .embedding_scanner import EmbeddingScanner
|
||||
|
||||
scanner = await EmbeddingScanner.get_instance()
|
||||
cls._services[service_name] = scanner
|
||||
logger.debug(f"Created and registered {service_name}")
|
||||
return scanner
|
||||
|
||||
@classmethod
|
||||
def clear_services(cls):
|
||||
"""Clear all registered services - mainly for testing"""
|
||||
cls._services.clear()
|
||||
cls._locks.clear()
|
||||
logger.info("Cleared all registered services")
|
||||
@@ -9,6 +9,8 @@ class SettingsManager:
|
||||
def __init__(self):
|
||||
self.settings_file = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'settings.json')
|
||||
self.settings = self._load_settings()
|
||||
self._migrate_download_path_template()
|
||||
self._auto_set_default_roots()
|
||||
self._check_environment_variables()
|
||||
|
||||
def _load_settings(self) -> Dict[str, Any]:
|
||||
@@ -21,6 +23,46 @@ class SettingsManager:
|
||||
logger.error(f"Error loading settings: {e}")
|
||||
return self._get_default_settings()
|
||||
|
||||
def _migrate_download_path_template(self):
|
||||
"""Migrate old download_path_template to new download_path_templates"""
|
||||
old_template = self.settings.get('download_path_template')
|
||||
templates = self.settings.get('download_path_templates')
|
||||
|
||||
# If old template exists and new templates don't exist, migrate
|
||||
if old_template is not None and not templates:
|
||||
logger.info("Migrating download_path_template to download_path_templates")
|
||||
self.settings['download_path_templates'] = {
|
||||
'lora': old_template,
|
||||
'checkpoint': old_template,
|
||||
'embedding': old_template
|
||||
}
|
||||
# Remove old setting
|
||||
del self.settings['download_path_template']
|
||||
self._save_settings()
|
||||
logger.info("Migration completed")
|
||||
|
||||
def _auto_set_default_roots(self):
|
||||
"""Auto set default root paths if only one folder is present and default is empty."""
|
||||
folder_paths = self.settings.get('folder_paths', {})
|
||||
updated = False
|
||||
# loras
|
||||
loras = folder_paths.get('loras', [])
|
||||
if isinstance(loras, list) and len(loras) == 1 and not self.settings.get('default_lora_root'):
|
||||
self.settings['default_lora_root'] = loras[0]
|
||||
updated = True
|
||||
# checkpoints
|
||||
checkpoints = folder_paths.get('checkpoints', [])
|
||||
if isinstance(checkpoints, list) and len(checkpoints) == 1 and not self.settings.get('default_checkpoint_root'):
|
||||
self.settings['default_checkpoint_root'] = checkpoints[0]
|
||||
updated = True
|
||||
# embeddings
|
||||
embeddings = folder_paths.get('embeddings', [])
|
||||
if isinstance(embeddings, list) and len(embeddings) == 1 and not self.settings.get('default_embedding_root'):
|
||||
self.settings['default_embedding_root'] = embeddings[0]
|
||||
updated = True
|
||||
if updated:
|
||||
self._save_settings()
|
||||
|
||||
def _check_environment_variables(self) -> None:
|
||||
"""Check for environment variables and update settings if needed"""
|
||||
env_api_key = os.environ.get('CIVITAI_API_KEY')
|
||||
@@ -38,7 +80,8 @@ class SettingsManager:
|
||||
"""Return default settings"""
|
||||
return {
|
||||
"civitai_api_key": "",
|
||||
"show_only_sfw": False
|
||||
"show_only_sfw": False,
|
||||
"language": "en" # 添加默认语言设置
|
||||
}
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
@@ -58,4 +101,53 @@ class SettingsManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving settings: {e}")
|
||||
|
||||
def get_download_path_template(self, model_type: str) -> str:
|
||||
"""Get download path template for specific model type
|
||||
|
||||
Args:
|
||||
model_type: The type of model ('lora', 'checkpoint', 'embedding')
|
||||
|
||||
Returns:
|
||||
Template string for the model type, defaults to '{base_model}/{first_tag}'
|
||||
"""
|
||||
templates = self.settings.get('download_path_templates', {})
|
||||
|
||||
# Handle edge case where templates might be stored as JSON string
|
||||
if isinstance(templates, str):
|
||||
try:
|
||||
# Try to parse JSON string
|
||||
parsed_templates = json.loads(templates)
|
||||
if isinstance(parsed_templates, dict):
|
||||
# Update settings with parsed dictionary
|
||||
self.settings['download_path_templates'] = parsed_templates
|
||||
self._save_settings()
|
||||
templates = parsed_templates
|
||||
logger.info("Successfully parsed download_path_templates from JSON string")
|
||||
else:
|
||||
raise ValueError("Parsed JSON is not a dictionary")
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
# If parsing fails, set default values
|
||||
logger.warning(f"Failed to parse download_path_templates JSON string: {e}. Setting default values.")
|
||||
default_template = '{base_model}/{first_tag}'
|
||||
templates = {
|
||||
'lora': default_template,
|
||||
'checkpoint': default_template,
|
||||
'embedding': default_template
|
||||
}
|
||||
self.settings['download_path_templates'] = templates
|
||||
self._save_settings()
|
||||
|
||||
# Ensure templates is a dictionary
|
||||
if not isinstance(templates, dict):
|
||||
default_template = '{base_model}/{first_tag}'
|
||||
templates = {
|
||||
'lora': default_template,
|
||||
'checkpoint': default_template,
|
||||
'embedding': default_template
|
||||
}
|
||||
self.settings['download_path_templates'] = templates
|
||||
self._save_settings()
|
||||
|
||||
return templates.get(model_type, '{base_model}/{first_tag}')
|
||||
|
||||
settings = SettingsManager()
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from typing import Set, Dict, Optional
|
||||
from uuid import uuid4
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -10,7 +13,12 @@ class WebSocketManager:
|
||||
def __init__(self):
|
||||
self._websockets: Set[web.WebSocketResponse] = set()
|
||||
self._init_websockets: Set[web.WebSocketResponse] = set() # New set for initialization progress clients
|
||||
self._checkpoint_websockets: Set[web.WebSocketResponse] = set() # New set for checkpoint download progress
|
||||
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
|
||||
# Add progress tracking dictionary
|
||||
self._download_progress: Dict[str, Dict] = {}
|
||||
# Add auto-organize progress tracking
|
||||
self._auto_organize_progress: Optional[Dict] = None
|
||||
self._auto_organize_lock = asyncio.Lock()
|
||||
|
||||
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle new WebSocket connection"""
|
||||
@@ -39,21 +47,48 @@ class WebSocketManager:
|
||||
finally:
|
||||
self._init_websockets.discard(ws)
|
||||
return ws
|
||||
|
||||
async def handle_checkpoint_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle new WebSocket connection for checkpoint download progress"""
|
||||
|
||||
async def handle_download_connection(self, request: web.Request) -> web.WebSocketResponse:
|
||||
"""Handle new WebSocket connection for download progress"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
self._checkpoint_websockets.add(ws)
|
||||
|
||||
# Get download_id from query parameters
|
||||
download_id = request.query.get('id')
|
||||
|
||||
if not download_id:
|
||||
# Generate a new download ID if not provided
|
||||
download_id = str(uuid4())
|
||||
|
||||
# Store the websocket with its download ID
|
||||
self._download_websockets[download_id] = ws
|
||||
|
||||
try:
|
||||
# Send the download ID back to the client
|
||||
await ws.send_json({
|
||||
'type': 'download_id',
|
||||
'download_id': download_id
|
||||
})
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == web.WSMsgType.ERROR:
|
||||
logger.error(f'Checkpoint WebSocket error: {ws.exception()}')
|
||||
logger.error(f'Download WebSocket error: {ws.exception()}')
|
||||
finally:
|
||||
self._checkpoint_websockets.discard(ws)
|
||||
if download_id in self._download_websockets:
|
||||
del self._download_websockets[download_id]
|
||||
|
||||
# Schedule cleanup of completed downloads after WebSocket disconnection
|
||||
asyncio.create_task(self._delayed_cleanup(download_id))
|
||||
return ws
|
||||
|
||||
|
||||
async def _delayed_cleanup(self, download_id: str, delay_seconds: int = 300):
|
||||
"""Clean up download progress after a delay (5 minutes by default)"""
|
||||
await asyncio.sleep(delay_seconds)
|
||||
progress_data = self._download_progress.get(download_id)
|
||||
if progress_data and progress_data.get('progress', 0) >= 100:
|
||||
self.cleanup_download_progress(download_id)
|
||||
logger.debug(f"Delayed cleanup completed for download {download_id}")
|
||||
|
||||
async def broadcast(self, data: Dict):
|
||||
"""Broadcast message to all connected clients"""
|
||||
if not self._websockets:
|
||||
@@ -84,17 +119,72 @@ class WebSocketManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending initialization progress: {e}")
|
||||
|
||||
async def broadcast_checkpoint_progress(self, data: Dict):
|
||||
"""Broadcast checkpoint download progress to connected clients"""
|
||||
if not self._checkpoint_websockets:
|
||||
async def broadcast_download_progress(self, download_id: str, data: Dict):
|
||||
"""Send progress update to specific download client"""
|
||||
# Store simplified progress data in memory (only progress percentage)
|
||||
self._download_progress[download_id] = {
|
||||
'progress': data.get('progress', 0),
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
if download_id not in self._download_websockets:
|
||||
logger.debug(f"No WebSocket found for download ID: {download_id}")
|
||||
return
|
||||
|
||||
for ws in self._checkpoint_websockets:
|
||||
try:
|
||||
await ws.send_json(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending checkpoint progress: {e}")
|
||||
|
||||
ws = self._download_websockets[download_id]
|
||||
try:
|
||||
await ws.send_json(data)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending download progress: {e}")
|
||||
|
||||
async def broadcast_auto_organize_progress(self, data: Dict):
|
||||
"""Broadcast auto-organize progress to connected clients"""
|
||||
# Store progress data in memory
|
||||
self._auto_organize_progress = data
|
||||
|
||||
# Broadcast via WebSocket
|
||||
await self.broadcast(data)
|
||||
|
||||
def get_auto_organize_progress(self) -> Optional[Dict]:
|
||||
"""Get current auto-organize progress"""
|
||||
return self._auto_organize_progress
|
||||
|
||||
def cleanup_auto_organize_progress(self):
|
||||
"""Clear auto-organize progress data"""
|
||||
self._auto_organize_progress = None
|
||||
|
||||
def is_auto_organize_running(self) -> bool:
|
||||
"""Check if auto-organize is currently running"""
|
||||
if not self._auto_organize_progress:
|
||||
return False
|
||||
status = self._auto_organize_progress.get('status')
|
||||
return status in ['started', 'processing', 'cleaning']
|
||||
|
||||
async def get_auto_organize_lock(self):
|
||||
"""Get the auto-organize lock"""
|
||||
return self._auto_organize_lock
|
||||
|
||||
def get_download_progress(self, download_id: str) -> Optional[Dict]:
|
||||
"""Get progress information for a specific download"""
|
||||
return self._download_progress.get(download_id)
|
||||
|
||||
def cleanup_download_progress(self, download_id: str):
|
||||
"""Remove progress info for a specific download"""
|
||||
self._download_progress.pop(download_id, None)
|
||||
|
||||
def cleanup_old_downloads(self, max_age_hours: int = 24):
|
||||
"""Clean up old download progress entries"""
|
||||
cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
|
||||
to_remove = []
|
||||
|
||||
for download_id, progress_data in self._download_progress.items():
|
||||
if progress_data.get('timestamp', datetime.now()) < cutoff_time:
|
||||
to_remove.append(download_id)
|
||||
|
||||
for download_id in to_remove:
|
||||
self._download_progress.pop(download_id, None)
|
||||
logger.debug(f"Cleaned up old download progress for {download_id}")
|
||||
|
||||
def get_connected_clients_count(self) -> int:
|
||||
"""Get number of connected clients"""
|
||||
return len(self._websockets)
|
||||
@@ -102,10 +192,14 @@ class WebSocketManager:
|
||||
def get_init_clients_count(self) -> int:
|
||||
"""Get number of initialization progress clients"""
|
||||
return len(self._init_websockets)
|
||||
|
||||
def get_checkpoint_clients_count(self) -> int:
|
||||
"""Get number of checkpoint progress clients"""
|
||||
return len(self._checkpoint_websockets)
|
||||
|
||||
def get_download_clients_count(self) -> int:
|
||||
"""Get number of download progress clients"""
|
||||
return len(self._download_websockets)
|
||||
|
||||
def generate_download_id(self) -> str:
|
||||
"""Generate a unique download ID"""
|
||||
return str(uuid4())
|
||||
|
||||
# Global instance
|
||||
ws_manager = WebSocketManager()
|
||||
@@ -10,7 +10,8 @@ NSFW_LEVELS = {
|
||||
# Node type constants
|
||||
NODE_TYPES = {
|
||||
"Lora Loader (LoraManager)": 1,
|
||||
"Lora Stacker (LoraManager)": 2
|
||||
"Lora Stacker (LoraManager)": 2,
|
||||
"WanVideo Lora Select (LoraManager)": 3
|
||||
}
|
||||
|
||||
# Default ComfyUI node color when bgcolor is null
|
||||
@@ -45,4 +46,15 @@ SUPPORTED_MEDIA_EXTENSIONS = {
|
||||
}
|
||||
|
||||
# Valid Lora types
|
||||
VALID_LORA_TYPES = ['lora', 'locon', 'dora']
|
||||
VALID_LORA_TYPES = ['lora', 'locon', 'dora']
|
||||
|
||||
# Auto-organize settings
|
||||
AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system
|
||||
|
||||
# Civitai model tags in priority order for subfolder organization
|
||||
CIVITAI_MODEL_TAGS = [
|
||||
'character', 'style', 'concept', 'clothing',
|
||||
'realistic', 'anime', 'toon', 'furry',
|
||||
'poses', 'background', 'tool', 'vehicle', 'buildings',
|
||||
'objects', 'assets', 'animal', 'action'
|
||||
]
|
||||
@@ -6,8 +6,10 @@ import time
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .example_images_processor import ExampleImagesProcessor
|
||||
from .example_images_metadata import MetadataUpdater
|
||||
from ..services.websocket_manager import ws_manager # Add this import at the top
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,7 +26,8 @@ download_progress = {
|
||||
'start_time': None,
|
||||
'end_time': None,
|
||||
'processed_models': set(), # Track models that have been processed
|
||||
'refreshed_models': set() # Track models that had metadata refreshed
|
||||
'refreshed_models': set(), # Track models that had metadata refreshed
|
||||
'failed_models': set() # Track models that failed to download after metadata refresh
|
||||
}
|
||||
|
||||
class DownloadManager:
|
||||
@@ -50,6 +53,7 @@ class DownloadManager:
|
||||
response_progress = download_progress.copy()
|
||||
response_progress['processed_models'] = list(download_progress['processed_models'])
|
||||
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
|
||||
response_progress['failed_models'] = list(download_progress['failed_models'])
|
||||
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
@@ -64,6 +68,7 @@ class DownloadManager:
|
||||
optimize = data.get('optimize', True)
|
||||
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
||||
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
|
||||
delay = 0 # Temporary: Disable delay to speed up downloads
|
||||
|
||||
if not output_dir:
|
||||
return web.json_response({
|
||||
@@ -91,12 +96,15 @@ class DownloadManager:
|
||||
with open(progress_file, 'r', encoding='utf-8') as f:
|
||||
saved_progress = json.load(f)
|
||||
download_progress['processed_models'] = set(saved_progress.get('processed_models', []))
|
||||
logger.info(f"Loaded previous progress, {len(download_progress['processed_models'])} models already processed")
|
||||
download_progress['failed_models'] = set(saved_progress.get('failed_models', []))
|
||||
logger.debug(f"Loaded previous progress, {len(download_progress['processed_models'])} models already processed, {len(download_progress['failed_models'])} models marked as failed")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load progress file: {e}")
|
||||
download_progress['processed_models'] = set()
|
||||
download_progress['failed_models'] = set()
|
||||
else:
|
||||
download_progress['processed_models'] = set()
|
||||
download_progress['failed_models'] = set()
|
||||
|
||||
# Start the download task
|
||||
is_downloading = True
|
||||
@@ -113,6 +121,7 @@ class DownloadManager:
|
||||
response_progress = download_progress.copy()
|
||||
response_progress['processed_models'] = list(download_progress['processed_models'])
|
||||
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
|
||||
response_progress['failed_models'] = list(download_progress['failed_models'])
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
@@ -136,6 +145,7 @@ class DownloadManager:
|
||||
response_progress = download_progress.copy()
|
||||
response_progress['processed_models'] = list(download_progress['processed_models'])
|
||||
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
|
||||
response_progress['failed_models'] = list(download_progress['failed_models'])
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
@@ -214,6 +224,10 @@ class DownloadManager:
|
||||
if 'checkpoint' in model_types:
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
scanners.append(('checkpoint', checkpoint_scanner))
|
||||
|
||||
if 'embedding' in model_types:
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
scanners.append(('embedding', embedding_scanner))
|
||||
|
||||
# Get all models
|
||||
all_models = []
|
||||
@@ -226,7 +240,7 @@ class DownloadManager:
|
||||
|
||||
# Update total count
|
||||
download_progress['total'] = len(all_models)
|
||||
logger.info(f"Found {download_progress['total']} models to process")
|
||||
logger.debug(f"Found {download_progress['total']} models to process")
|
||||
|
||||
# Process each model
|
||||
for i, (scanner_type, model, scanner) in enumerate(all_models):
|
||||
@@ -246,7 +260,7 @@ class DownloadManager:
|
||||
# Mark as completed
|
||||
download_progress['status'] = 'completed'
|
||||
download_progress['end_time'] = time.time()
|
||||
logger.info(f"Example images download completed: {download_progress['completed']}/{download_progress['total']} models processed")
|
||||
logger.debug(f"Example images download completed: {download_progress['completed']}/{download_progress['total']} models processed")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error during example images download: {str(e)}"
|
||||
@@ -295,6 +309,11 @@ class DownloadManager:
|
||||
# Update current model info
|
||||
download_progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
||||
|
||||
# Skip if already in failed models
|
||||
if model_hash in download_progress['failed_models']:
|
||||
logger.debug(f"Skipping known failed model: {model_name}")
|
||||
return False
|
||||
|
||||
# Skip if already processed AND directory exists with files
|
||||
if model_hash in download_progress['processed_models']:
|
||||
model_dir = os.path.join(output_dir, model_hash)
|
||||
@@ -304,6 +323,8 @@ class DownloadManager:
|
||||
return False
|
||||
else:
|
||||
logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing")
|
||||
# Remove from processed models since we need to reprocess
|
||||
download_progress['processed_models'].discard(model_hash)
|
||||
|
||||
# Create model directory
|
||||
model_dir = os.path.join(output_dir, model_hash)
|
||||
@@ -347,12 +368,23 @@ class DownloadManager:
|
||||
success, _ = await ExampleImagesProcessor.download_model_images(
|
||||
model_hash, model_name, updated_images, model_dir, optimize, independent_session
|
||||
)
|
||||
|
||||
download_progress['refreshed_models'].add(model_hash)
|
||||
|
||||
# Only mark as processed if all images were downloaded successfully
|
||||
# Mark as processed if successful, or as failed if unsuccessful after refresh
|
||||
if success:
|
||||
download_progress['processed_models'].add(model_hash)
|
||||
else:
|
||||
# If we refreshed metadata and still failed, mark as permanently failed
|
||||
if model_hash in download_progress['refreshed_models']:
|
||||
download_progress['failed_models'].add(model_hash)
|
||||
logger.info(f"Marking model {model_name} as failed after metadata refresh")
|
||||
|
||||
return True # Return True to indicate a remote download happened
|
||||
else:
|
||||
# No civitai data or images available, mark as failed to avoid future attempts
|
||||
download_progress['failed_models'].add(model_hash)
|
||||
logger.debug(f"No civitai images available for model {model_name}, marking as failed")
|
||||
|
||||
# Save progress periodically
|
||||
if download_progress['completed'] % 10 == 0 or download_progress['completed'] == download_progress['total'] - 1:
|
||||
@@ -387,6 +419,7 @@ class DownloadManager:
|
||||
progress_data = {
|
||||
'processed_models': list(download_progress['processed_models']),
|
||||
'refreshed_models': list(download_progress['refreshed_models']),
|
||||
'failed_models': list(download_progress['failed_models']),
|
||||
'completed': download_progress['completed'],
|
||||
'total': download_progress['total'],
|
||||
'last_update': time.time()
|
||||
@@ -401,4 +434,364 @@ class DownloadManager:
|
||||
with open(progress_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(progress_data, f, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save progress file: {e}")
|
||||
logger.error(f"Failed to save progress file: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def start_force_download(request):
|
||||
"""
|
||||
Force download example images for specific models
|
||||
|
||||
Expects a JSON body with:
|
||||
{
|
||||
"model_hashes": ["hash1", "hash2", ...], # List of model hashes to download
|
||||
"output_dir": "path/to/output", # Base directory to save example images
|
||||
"optimize": true, # Whether to optimize images (default: true)
|
||||
"model_types": ["lora", "checkpoint"], # Model types to process (default: both)
|
||||
"delay": 1.0 # Delay between downloads (default: 1.0)
|
||||
}
|
||||
"""
|
||||
global download_task, is_downloading, download_progress
|
||||
|
||||
if is_downloading:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download already in progress'
|
||||
}, status=400)
|
||||
|
||||
try:
|
||||
# Parse the request body
|
||||
data = await request.json()
|
||||
model_hashes = data.get('model_hashes', [])
|
||||
output_dir = data.get('output_dir')
|
||||
optimize = data.get('optimize', True)
|
||||
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
||||
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
|
||||
|
||||
if not model_hashes:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing model_hashes parameter'
|
||||
}, status=400)
|
||||
|
||||
if not output_dir:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing output_dir parameter'
|
||||
}, status=400)
|
||||
|
||||
# Create the output directory
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Initialize progress tracking
|
||||
download_progress['total'] = len(model_hashes)
|
||||
download_progress['completed'] = 0
|
||||
download_progress['current_model'] = ''
|
||||
download_progress['status'] = 'running'
|
||||
download_progress['errors'] = []
|
||||
download_progress['last_error'] = None
|
||||
download_progress['start_time'] = time.time()
|
||||
download_progress['end_time'] = None
|
||||
download_progress['processed_models'] = set()
|
||||
download_progress['refreshed_models'] = set()
|
||||
download_progress['failed_models'] = set()
|
||||
|
||||
# Set download status to downloading
|
||||
is_downloading = True
|
||||
|
||||
# Execute the download function directly instead of creating a background task
|
||||
result = await DownloadManager._download_specific_models_example_images_sync(
|
||||
model_hashes,
|
||||
output_dir,
|
||||
optimize,
|
||||
model_types,
|
||||
delay
|
||||
)
|
||||
|
||||
# Set download status to not downloading
|
||||
is_downloading = False
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': 'Force download completed',
|
||||
'result': result
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
# Set download status to not downloading
|
||||
is_downloading = False
|
||||
logger.error(f"Failed during forced example images download: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay):
|
||||
"""Download example images for specific models only - synchronous version"""
|
||||
global download_progress
|
||||
|
||||
# Create independent download session
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=True,
|
||||
limit=3,
|
||||
force_close=False,
|
||||
enable_cleanup_closed=True
|
||||
)
|
||||
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
|
||||
independent_session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
trust_env=True,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
try:
|
||||
# Get scanners
|
||||
scanners = []
|
||||
if 'lora' in model_types:
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
scanners.append(('lora', lora_scanner))
|
||||
|
||||
if 'checkpoint' in model_types:
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
scanners.append(('checkpoint', checkpoint_scanner))
|
||||
|
||||
if 'embedding' in model_types:
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
scanners.append(('embedding', embedding_scanner))
|
||||
|
||||
# Find the specified models
|
||||
models_to_process = []
|
||||
for scanner_type, scanner in scanners:
|
||||
cache = await scanner.get_cached_data()
|
||||
if cache and cache.raw_data:
|
||||
for model in cache.raw_data:
|
||||
if model.get('sha256') in model_hashes:
|
||||
models_to_process.append((scanner_type, model, scanner))
|
||||
|
||||
# Update total count based on found models
|
||||
download_progress['total'] = len(models_to_process)
|
||||
logger.debug(f"Found {download_progress['total']} models to process")
|
||||
|
||||
# Send initial progress via WebSocket
|
||||
await ws_manager.broadcast({
|
||||
'type': 'example_images_progress',
|
||||
'processed': 0,
|
||||
'total': download_progress['total'],
|
||||
'status': 'running',
|
||||
'current_model': ''
|
||||
})
|
||||
|
||||
# Process each model
|
||||
success_count = 0
|
||||
for i, (scanner_type, model, scanner) in enumerate(models_to_process):
|
||||
# Force process this model regardless of previous status
|
||||
was_successful = await DownloadManager._process_specific_model(
|
||||
scanner_type, model, scanner,
|
||||
output_dir, optimize, independent_session
|
||||
)
|
||||
|
||||
if was_successful:
|
||||
success_count += 1
|
||||
|
||||
# Update progress
|
||||
download_progress['completed'] += 1
|
||||
|
||||
# Send progress update via WebSocket
|
||||
await ws_manager.broadcast({
|
||||
'type': 'example_images_progress',
|
||||
'processed': download_progress['completed'],
|
||||
'total': download_progress['total'],
|
||||
'status': 'running',
|
||||
'current_model': download_progress['current_model']
|
||||
})
|
||||
|
||||
# Only add delay after remote download, and not after processing the last model
|
||||
if was_successful and i < len(models_to_process) - 1 and download_progress['status'] == 'running':
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Mark as completed
|
||||
download_progress['status'] = 'completed'
|
||||
download_progress['end_time'] = time.time()
|
||||
logger.debug(f"Forced example images download completed: {download_progress['completed']}/{download_progress['total']} models processed")
|
||||
|
||||
# Send final progress via WebSocket
|
||||
await ws_manager.broadcast({
|
||||
'type': 'example_images_progress',
|
||||
'processed': download_progress['completed'],
|
||||
'total': download_progress['total'],
|
||||
'status': 'completed',
|
||||
'current_model': ''
|
||||
})
|
||||
|
||||
return {
|
||||
'total': download_progress['total'],
|
||||
'processed': download_progress['completed'],
|
||||
'successful': success_count,
|
||||
'errors': download_progress['errors']
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error during forced example images download: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
download_progress['errors'].append(error_msg)
|
||||
download_progress['last_error'] = error_msg
|
||||
download_progress['status'] = 'error'
|
||||
download_progress['end_time'] = time.time()
|
||||
|
||||
# Send error status via WebSocket
|
||||
await ws_manager.broadcast({
|
||||
'type': 'example_images_progress',
|
||||
'processed': download_progress['completed'],
|
||||
'total': download_progress['total'],
|
||||
'status': 'error',
|
||||
'error': error_msg,
|
||||
'current_model': ''
|
||||
})
|
||||
|
||||
raise
|
||||
|
||||
finally:
|
||||
# Close the independent session
|
||||
try:
|
||||
await independent_session.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing download session: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, independent_session):
|
||||
"""Process a specific model for forced download, ignoring previous download status"""
|
||||
global download_progress
|
||||
|
||||
# Check if download is paused
|
||||
while download_progress['status'] == 'paused':
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check if download should continue
|
||||
if download_progress['status'] != 'running':
|
||||
logger.info(f"Download stopped: {download_progress['status']}")
|
||||
return False
|
||||
|
||||
model_hash = model.get('sha256', '').lower()
|
||||
model_name = model.get('model_name', 'Unknown')
|
||||
model_file_path = model.get('file_path', '')
|
||||
model_file_name = model.get('file_name', '')
|
||||
|
||||
try:
|
||||
# Update current model info
|
||||
download_progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
||||
|
||||
# Create model directory
|
||||
model_dir = os.path.join(output_dir, model_hash)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
# First check for local example images - local processing doesn't need delay
|
||||
local_images_processed = await ExampleImagesProcessor.process_local_examples(
|
||||
model_file_path, model_file_name, model_name, model_dir, optimize
|
||||
)
|
||||
|
||||
# If we processed local images, update metadata
|
||||
if local_images_processed:
|
||||
await MetadataUpdater.update_metadata_from_local_examples(
|
||||
model_hash, model, scanner_type, scanner, model_dir
|
||||
)
|
||||
download_progress['processed_models'].add(model_hash)
|
||||
return False # Return False to indicate no remote download happened
|
||||
|
||||
# If no local images, try to download from remote
|
||||
elif model.get('civitai') and model.get('civitai', {}).get('images'):
|
||||
images = model.get('civitai', {}).get('images', [])
|
||||
|
||||
success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
|
||||
model_hash, model_name, images, model_dir, optimize, independent_session
|
||||
)
|
||||
|
||||
# If metadata is stale, try to refresh it
|
||||
if is_stale and model_hash not in download_progress['refreshed_models']:
|
||||
await MetadataUpdater.refresh_model_metadata(
|
||||
model_hash, model_name, scanner_type, scanner
|
||||
)
|
||||
|
||||
# Get the updated model data
|
||||
updated_model = await MetadataUpdater.get_updated_model(
|
||||
model_hash, scanner
|
||||
)
|
||||
|
||||
if updated_model and updated_model.get('civitai', {}).get('images'):
|
||||
# Retry download with updated metadata
|
||||
updated_images = updated_model.get('civitai', {}).get('images', [])
|
||||
success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
|
||||
model_hash, model_name, updated_images, model_dir, optimize, independent_session
|
||||
)
|
||||
|
||||
# Combine failed images from both attempts
|
||||
failed_images.extend(additional_failed_images)
|
||||
|
||||
download_progress['refreshed_models'].add(model_hash)
|
||||
|
||||
# For forced downloads, remove failed images from metadata
|
||||
if failed_images:
|
||||
# Create a copy of images excluding failed ones
|
||||
await DownloadManager._remove_failed_images_from_metadata(
|
||||
model_hash, model_name, failed_images, scanner
|
||||
)
|
||||
|
||||
# Mark as processed
|
||||
if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones
|
||||
download_progress['processed_models'].add(model_hash)
|
||||
|
||||
return True # Return True to indicate a remote download happened
|
||||
else:
|
||||
logger.debug(f"No civitai images available for model {model_name}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
download_progress['errors'].append(error_msg)
|
||||
download_progress['last_error'] = error_msg
|
||||
return False # Return False on exception
|
||||
|
||||
@staticmethod
|
||||
async def _remove_failed_images_from_metadata(model_hash, model_name, failed_images, scanner):
|
||||
"""Remove failed images from model metadata"""
|
||||
try:
|
||||
# Get current model data
|
||||
model_data = await MetadataUpdater.get_updated_model(model_hash, scanner)
|
||||
if not model_data:
|
||||
logger.warning(f"Could not find model data for {model_name} to remove failed images")
|
||||
return
|
||||
|
||||
if not model_data.get('civitai', {}).get('images'):
|
||||
logger.warning(f"No images in metadata for {model_name}")
|
||||
return
|
||||
|
||||
# Get current images
|
||||
current_images = model_data['civitai']['images']
|
||||
|
||||
# Filter out failed images
|
||||
updated_images = [img for img in current_images if img.get('url') not in failed_images]
|
||||
|
||||
# If images were removed, update metadata
|
||||
if len(updated_images) < len(current_images):
|
||||
removed_count = len(current_images) - len(updated_images)
|
||||
logger.info(f"Removing {removed_count} failed images from metadata for {model_name}")
|
||||
|
||||
# Update the images list
|
||||
model_data['civitai']['images'] = updated_images
|
||||
|
||||
# Save metadata to file
|
||||
file_path = model_data.get('file_path')
|
||||
if file_path:
|
||||
# Create a copy of model data without 'folder' field
|
||||
model_copy = model_data.copy()
|
||||
model_copy.pop('folder', None)
|
||||
|
||||
# Write metadata to file
|
||||
await MetadataManager.save_metadata(file_path, model_copy)
|
||||
logger.info(f"Saved updated metadata for {model_name} after removing failed images")
|
||||
|
||||
# Update the scanner cache
|
||||
await scanner.update_single_model_cache(file_path, file_path, model_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True)
|
||||
@@ -43,7 +43,15 @@ class ExampleImagesFileManager:
|
||||
|
||||
# Construct folder path for this model
|
||||
model_folder = os.path.join(example_images_path, model_hash)
|
||||
|
||||
model_folder = os.path.abspath(model_folder) # Get absolute path
|
||||
|
||||
# Path validation: ensure model_folder is under example_images_path
|
||||
if not model_folder.startswith(os.path.abspath(example_images_path)):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Invalid model folder path'
|
||||
}, status=400)
|
||||
|
||||
# Check if folder exists
|
||||
if not os.path.exists(model_folder):
|
||||
return web.json_response({
|
||||
|
||||
@@ -102,6 +102,78 @@ class ExampleImagesProcessor:
|
||||
|
||||
return model_success, False # (success, is_metadata_stale)
|
||||
|
||||
@staticmethod
|
||||
async def download_model_images_with_tracking(model_hash, model_name, model_images, model_dir, optimize, independent_session):
|
||||
"""Download images for a single model with tracking of failed image URLs
|
||||
|
||||
Returns:
|
||||
tuple: (success, is_stale_metadata, failed_images) - whether download was successful, whether metadata is stale, list of failed image URLs
|
||||
"""
|
||||
model_success = True
|
||||
failed_images = []
|
||||
|
||||
for i, image in enumerate(model_images):
|
||||
image_url = image.get('url')
|
||||
if not image_url:
|
||||
continue
|
||||
|
||||
# Get image filename from URL
|
||||
image_filename = os.path.basename(image_url.split('?')[0])
|
||||
image_ext = os.path.splitext(image_filename)[1].lower()
|
||||
|
||||
# Handle images and videos
|
||||
is_image = image_ext in SUPPORTED_MEDIA_EXTENSIONS['images']
|
||||
is_video = image_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']
|
||||
|
||||
if not (is_image or is_video):
|
||||
logger.debug(f"Skipping unsupported file type: {image_filename}")
|
||||
continue
|
||||
|
||||
# Use 0-based indexing instead of 1-based indexing
|
||||
save_filename = f"image_{i}{image_ext}"
|
||||
|
||||
# If optimizing images and this is a Civitai image, use their pre-optimized WebP version
|
||||
if is_image and optimize and 'civitai.com' in image_url:
|
||||
image_url = ExampleImagesProcessor.get_civitai_optimized_url(image_url)
|
||||
save_filename = f"image_{i}.webp"
|
||||
|
||||
# Check if already downloaded
|
||||
save_path = os.path.join(model_dir, save_filename)
|
||||
if os.path.exists(save_path):
|
||||
logger.debug(f"File already exists: {save_path}")
|
||||
continue
|
||||
|
||||
# Download the file
|
||||
try:
|
||||
logger.debug(f"Downloading {save_filename} for {model_name}")
|
||||
|
||||
# Download directly using the independent session
|
||||
async with independent_session.get(image_url, timeout=60) as response:
|
||||
if response.status == 200:
|
||||
with open(save_path, 'wb') as f:
|
||||
async for chunk in response.content.iter_chunked(8192):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
elif response.status == 404:
|
||||
error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale"
|
||||
logger.warning(error_msg)
|
||||
model_success = False # Mark the model as failed due to 404 error
|
||||
failed_images.append(image_url) # Track failed URL
|
||||
# Return early to trigger metadata refresh attempt
|
||||
return False, True, failed_images # (success, is_metadata_stale, failed_images)
|
||||
else:
|
||||
error_msg = f"Failed to download file: {image_url}, status code: {response.status}"
|
||||
logger.warning(error_msg)
|
||||
model_success = False # Mark the model as failed
|
||||
failed_images.append(image_url) # Track failed URL
|
||||
except Exception as e:
|
||||
error_msg = f"Error downloading file {image_url}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
model_success = False # Mark the model as failed
|
||||
failed_images.append(image_url) # Track failed URL
|
||||
|
||||
return model_success, False, failed_images # (success, is_metadata_stale, failed_images)
|
||||
|
||||
@staticmethod
|
||||
async def process_local_examples(model_file_path, model_file_name, model_name, model_dir, optimize):
|
||||
"""Process local example images
|
||||
@@ -251,12 +323,13 @@ class ExampleImagesProcessor:
|
||||
# Find the model and get current metadata
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
model_data = None
|
||||
scanner = None
|
||||
|
||||
# Check both scanners to find the model
|
||||
for scan_obj in [lora_scanner, checkpoint_scanner]:
|
||||
for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]:
|
||||
cache = await scan_obj.get_cached_data()
|
||||
for item in cache.raw_data:
|
||||
if item.get('sha256') == model_hash:
|
||||
@@ -384,12 +457,13 @@ class ExampleImagesProcessor:
|
||||
# Find the model and get current metadata
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
model_data = None
|
||||
scanner = None
|
||||
|
||||
# Check both scanners to find the model
|
||||
for scan_obj in [lora_scanner, checkpoint_scanner]:
|
||||
for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]:
|
||||
if scan_obj.has_hash(model_hash):
|
||||
cache = await scan_obj.get_cached_data()
|
||||
for item in cache.raw_data:
|
||||
|
||||
@@ -27,39 +27,58 @@ def find_preview_file(base_name: str, dir_path: str) -> str:
|
||||
full_pattern = os.path.join(dir_path, f"{base_name}{ext}")
|
||||
if os.path.exists(full_pattern):
|
||||
# Check if this is an image and not already webp
|
||||
if ext.lower().endswith(('.jpg', '.jpeg', '.png')) and not ext.lower().endswith('.webp'):
|
||||
try:
|
||||
# Optimize the image to webp format
|
||||
webp_path = os.path.join(dir_path, f"{base_name}.webp")
|
||||
# TODO: disable the optimization for now, maybe add a config option later
|
||||
# if ext.lower().endswith(('.jpg', '.jpeg', '.png')) and not ext.lower().endswith('.webp'):
|
||||
# try:
|
||||
# # Optimize the image to webp format
|
||||
# webp_path = os.path.join(dir_path, f"{base_name}.webp")
|
||||
|
||||
# Use ExifUtils to optimize the image
|
||||
with open(full_pattern, 'rb') as f:
|
||||
image_data = f.read()
|
||||
# # Use ExifUtils to optimize the image
|
||||
# with open(full_pattern, 'rb') as f:
|
||||
# image_data = f.read()
|
||||
|
||||
optimized_data, _ = ExifUtils.optimize_image(
|
||||
image_data=image_data,
|
||||
target_width=CARD_PREVIEW_WIDTH,
|
||||
format='webp',
|
||||
quality=85,
|
||||
preserve_metadata=False
|
||||
)
|
||||
# optimized_data, _ = ExifUtils.optimize_image(
|
||||
# image_data=image_data,
|
||||
# target_width=CARD_PREVIEW_WIDTH,
|
||||
# format='webp',
|
||||
# quality=85,
|
||||
# preserve_metadata=False
|
||||
# )
|
||||
|
||||
# Save the optimized webp file
|
||||
with open(webp_path, 'wb') as f:
|
||||
f.write(optimized_data)
|
||||
# # Save the optimized webp file
|
||||
# with open(webp_path, 'wb') as f:
|
||||
# f.write(optimized_data)
|
||||
|
||||
logger.debug(f"Optimized preview image from {full_pattern} to {webp_path}")
|
||||
return webp_path.replace(os.sep, "/")
|
||||
except Exception as e:
|
||||
logger.error(f"Error optimizing preview image {full_pattern}: {e}")
|
||||
# Fall back to original file if optimization fails
|
||||
return full_pattern.replace(os.sep, "/")
|
||||
# logger.debug(f"Optimized preview image from {full_pattern} to {webp_path}")
|
||||
# return webp_path.replace(os.sep, "/")
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error optimizing preview image {full_pattern}: {e}")
|
||||
# # Fall back to original file if optimization fails
|
||||
# return full_pattern.replace(os.sep, "/")
|
||||
|
||||
# Return the original path for webp images or non-image files
|
||||
return full_pattern.replace(os.sep, "/")
|
||||
|
||||
return ""
|
||||
|
||||
def get_preview_extension(preview_path: str) -> str:
|
||||
"""Get the complete preview extension from a preview file path
|
||||
|
||||
Args:
|
||||
preview_path: Path to the preview file
|
||||
|
||||
Returns:
|
||||
str: The complete extension (e.g., '.preview.png', '.png', '.webp')
|
||||
"""
|
||||
preview_path_lower = preview_path.lower()
|
||||
|
||||
# Check for compound extensions first (longer matches first)
|
||||
for ext in sorted(PREVIEW_EXTENSIONS, key=len, reverse=True):
|
||||
if preview_path_lower.endswith(ext.lower()):
|
||||
return ext
|
||||
|
||||
return os.path.splitext(preview_path)[1]
|
||||
|
||||
def normalize_path(path: str) -> str:
|
||||
"""Normalize file path to use forward slashes"""
|
||||
return path.replace(os.sep, "/") if path else path
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import logging
|
||||
from typing import Dict, Optional, Type, Union
|
||||
|
||||
@@ -16,7 +16,7 @@ class MetadataManager:
|
||||
|
||||
This class is responsible for:
|
||||
1. Loading metadata safely with fallback mechanisms
|
||||
2. Saving metadata with atomic operations and backups
|
||||
2. Saving metadata with atomic operations
|
||||
3. Creating default metadata for models
|
||||
4. Handling unknown fields gracefully
|
||||
"""
|
||||
@@ -24,81 +24,44 @@ class MetadataManager:
|
||||
@staticmethod
|
||||
async def load_metadata(file_path: str, model_class: Type[BaseModelMetadata] = LoraMetadata) -> Optional[BaseModelMetadata]:
|
||||
"""
|
||||
Load metadata with robust error handling and data preservation.
|
||||
Load metadata safely.
|
||||
|
||||
Args:
|
||||
file_path: Path to the model file
|
||||
model_class: Class to instantiate (LoraMetadata, CheckpointMetadata, etc.)
|
||||
|
||||
Returns:
|
||||
BaseModelMetadata instance or None if file doesn't exist
|
||||
tuple: (metadata, should_skip)
|
||||
- metadata: BaseModelMetadata instance or None
|
||||
- should_skip: True if corrupted metadata file exists and model should be skipped
|
||||
"""
|
||||
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
|
||||
backup_path = f"{metadata_path}.bak"
|
||||
|
||||
# Try loading the main metadata file
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Create model instance
|
||||
metadata = model_class.from_dict(data)
|
||||
|
||||
# Normalize paths
|
||||
await MetadataManager._normalize_metadata_paths(metadata, file_path)
|
||||
|
||||
return metadata
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# JSON parsing error - try to restore from backup
|
||||
logger.warning(f"Invalid JSON in metadata file: {metadata_path}")
|
||||
return await MetadataManager._restore_from_backup(backup_path, file_path, model_class)
|
||||
|
||||
except Exception as e:
|
||||
# Other errors might be due to unknown fields or schema changes
|
||||
logger.error(f"Error loading metadata from {metadata_path}: {str(e)}")
|
||||
return await MetadataManager._restore_from_backup(backup_path, file_path, model_class)
|
||||
# Check if metadata file exists
|
||||
if not os.path.exists(metadata_path):
|
||||
return None, False
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _restore_from_backup(backup_path: str, file_path: str, model_class: Type[BaseModelMetadata]) -> Optional[BaseModelMetadata]:
|
||||
"""
|
||||
Try to restore metadata from backup file
|
||||
|
||||
Args:
|
||||
backup_path: Path to backup file
|
||||
file_path: Path to the original model file
|
||||
model_class: Class to instantiate
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
Returns:
|
||||
BaseModelMetadata instance or None if restoration fails
|
||||
"""
|
||||
if os.path.exists(backup_path):
|
||||
try:
|
||||
logger.info(f"Attempting to restore metadata from backup: {backup_path}")
|
||||
with open(backup_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
# Create model instance
|
||||
metadata = model_class.from_dict(data)
|
||||
|
||||
# Process data similarly to normal loading
|
||||
metadata = model_class.from_dict(data)
|
||||
await MetadataManager._normalize_metadata_paths(metadata, file_path)
|
||||
return metadata
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to restore from backup: {str(e)}")
|
||||
|
||||
return None
|
||||
# Normalize paths
|
||||
await MetadataManager._normalize_metadata_paths(metadata, file_path)
|
||||
|
||||
return metadata, False
|
||||
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
error_type = "Invalid JSON" if isinstance(e, json.JSONDecodeError) else "Parse error"
|
||||
logger.error(f"{error_type} in metadata file: {metadata_path}. Error: {str(e)}. Skipping model to preserve existing data.")
|
||||
return None, True # should_skip = True
|
||||
|
||||
@staticmethod
|
||||
async def save_metadata(path: str, metadata: Union[BaseModelMetadata, Dict], create_backup: bool = False) -> bool:
|
||||
async def save_metadata(path: str, metadata: Union[BaseModelMetadata, Dict]) -> bool:
|
||||
"""
|
||||
Save metadata with atomic write operations and backup creation.
|
||||
Save metadata with atomic write operations.
|
||||
|
||||
Args:
|
||||
path: Path to the model file or directly to the metadata file
|
||||
metadata: Metadata to save (either BaseModelMetadata object or dict)
|
||||
create_backup: Whether to create a new backup of existing file if a backup doesn't already exist
|
||||
|
||||
Returns:
|
||||
bool: Success or failure
|
||||
@@ -111,19 +74,8 @@ class MetadataManager:
|
||||
file_path = path
|
||||
metadata_path = f"{os.path.splitext(file_path)[0]}.metadata.json"
|
||||
temp_path = f"{metadata_path}.tmp"
|
||||
backup_path = f"{metadata_path}.bak"
|
||||
|
||||
try:
|
||||
# Create backup if file exists and either:
|
||||
# 1. create_backup is True, OR
|
||||
# 2. backup file doesn't already exist
|
||||
if os.path.exists(metadata_path) and (create_backup or not os.path.exists(backup_path)):
|
||||
try:
|
||||
shutil.copy2(metadata_path, backup_path)
|
||||
logger.debug(f"Created metadata backup at: {backup_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create metadata backup: {str(e)}")
|
||||
|
||||
# Convert to dict if needed
|
||||
if isinstance(metadata, BaseModelMetadata):
|
||||
metadata_dict = metadata.to_dict()
|
||||
@@ -196,7 +148,7 @@ class MetadataManager:
|
||||
model_name=base_name,
|
||||
file_path=normalize_path(file_path),
|
||||
size=os.path.getsize(real_path),
|
||||
modified=os.path.getmtime(real_path),
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256=sha256,
|
||||
base_model="Unknown",
|
||||
preview_url=normalize_path(preview_url),
|
||||
@@ -205,13 +157,27 @@ class MetadataManager:
|
||||
model_type="checkpoint",
|
||||
from_civitai=True
|
||||
)
|
||||
elif model_class.__name__ == "EmbeddingMetadata":
|
||||
metadata = model_class(
|
||||
file_name=base_name,
|
||||
model_name=base_name,
|
||||
file_path=normalize_path(file_path),
|
||||
size=os.path.getsize(real_path),
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256=sha256,
|
||||
base_model="Unknown",
|
||||
preview_url=normalize_path(preview_url),
|
||||
tags=[],
|
||||
modelDescription="",
|
||||
from_civitai=True
|
||||
)
|
||||
else: # Default to LoraMetadata
|
||||
metadata = model_class(
|
||||
file_name=base_name,
|
||||
model_name=base_name,
|
||||
file_path=normalize_path(file_path),
|
||||
size=os.path.getsize(real_path),
|
||||
modified=os.path.getmtime(real_path),
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256=sha256,
|
||||
base_model="Unknown",
|
||||
preview_url=normalize_path(preview_url),
|
||||
@@ -222,10 +188,10 @@ class MetadataManager:
|
||||
)
|
||||
|
||||
# Try to extract model-specific metadata
|
||||
await MetadataManager._enrich_metadata(metadata, real_path)
|
||||
# await MetadataManager._enrich_metadata(metadata, real_path)
|
||||
|
||||
# Save the created metadata
|
||||
await MetadataManager.save_metadata(file_path, metadata, create_backup=False)
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
return metadata
|
||||
|
||||
@@ -266,6 +232,12 @@ class MetadataManager:
|
||||
"""
|
||||
need_update = False
|
||||
|
||||
# Check if file_name matches the actual file name
|
||||
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
if metadata.file_name != base_name:
|
||||
metadata.file_name = base_name
|
||||
need_update = True
|
||||
|
||||
# Check if file path is different from what's in metadata
|
||||
if normalize_path(file_path) != metadata.file_path:
|
||||
metadata.file_path = normalize_path(file_path)
|
||||
@@ -289,4 +261,4 @@ class MetadataManager:
|
||||
|
||||
# If path attributes were changed, save the metadata back to disk
|
||||
if need_update:
|
||||
await MetadataManager.save_metadata(file_path, metadata, create_backup=False)
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
@@ -11,7 +11,7 @@ class BaseModelMetadata:
|
||||
model_name: str # The model's name defined by the creator
|
||||
file_path: str # Full path to the model file
|
||||
size: int # File size in bytes
|
||||
modified: float # Last modified timestamp
|
||||
modified: float # Timestamp when the model was added to the management system
|
||||
sha256: str # SHA256 hash of the file
|
||||
base_model: str # Base model type (SD1.5/SD2.1/SDXL/etc.)
|
||||
preview_url: str # Preview image URL
|
||||
@@ -73,11 +73,6 @@ class BaseModelMetadata:
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def modified_datetime(self) -> datetime:
|
||||
"""Convert modified timestamp to datetime object"""
|
||||
return datetime.fromtimestamp(self.modified)
|
||||
|
||||
def update_civitai_info(self, civitai_data: Dict) -> None:
|
||||
"""Update Civitai information"""
|
||||
self.civitai = civitai_data
|
||||
@@ -88,6 +83,50 @@ class BaseModelMetadata:
|
||||
self.size = os.path.getsize(file_path)
|
||||
self.modified = os.path.getmtime(file_path)
|
||||
self.file_path = file_path.replace(os.sep, '/')
|
||||
# Update file_name when file_path changes
|
||||
self.file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
@staticmethod
|
||||
def generate_unique_filename(target_dir: str, base_name: str, extension: str, hash_provider: callable = None) -> str:
|
||||
"""Generate a unique filename to avoid conflicts
|
||||
|
||||
Args:
|
||||
target_dir: Target directory path
|
||||
base_name: Base filename without extension
|
||||
extension: File extension including the dot
|
||||
hash_provider: A callable that returns the SHA256 hash when needed
|
||||
|
||||
Returns:
|
||||
str: Unique filename that doesn't conflict with existing files
|
||||
"""
|
||||
original_filename = f"{base_name}{extension}"
|
||||
target_path = os.path.join(target_dir, original_filename)
|
||||
|
||||
# If no conflict, return original filename
|
||||
if not os.path.exists(target_path):
|
||||
return original_filename
|
||||
|
||||
# Only compute hash when needed
|
||||
if hash_provider:
|
||||
sha256_hash = hash_provider()
|
||||
else:
|
||||
sha256_hash = "0000"
|
||||
|
||||
# Generate short hash (first 4 characters of SHA256)
|
||||
short_hash = sha256_hash[:4] if sha256_hash else "0000"
|
||||
|
||||
# Try with short hash suffix
|
||||
unique_filename = f"{base_name}-{short_hash}{extension}"
|
||||
unique_path = os.path.join(target_dir, unique_filename)
|
||||
|
||||
# If still conflicts, add incremental number
|
||||
counter = 1
|
||||
while os.path.exists(unique_path):
|
||||
unique_filename = f"{base_name}-{short_hash}-{counter}{extension}"
|
||||
unique_path = os.path.join(target_dir, unique_filename)
|
||||
counter += 1
|
||||
|
||||
return unique_filename
|
||||
|
||||
@dataclass
|
||||
class LoraMetadata(BaseModelMetadata):
|
||||
@@ -128,7 +167,7 @@ class LoraMetadata(BaseModelMetadata):
|
||||
@dataclass
|
||||
class CheckpointMetadata(BaseModelMetadata):
|
||||
"""Represents the metadata structure for a Checkpoint model"""
|
||||
model_type: str = "checkpoint" # Model type (checkpoint, inpainting, etc.)
|
||||
model_type: str = "checkpoint" # Model type (checkpoint, diffusion_model, etc.)
|
||||
|
||||
@classmethod
|
||||
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'CheckpointMetadata':
|
||||
@@ -163,3 +202,41 @@ class CheckpointMetadata(BaseModelMetadata):
|
||||
modelDescription=description
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class EmbeddingMetadata(BaseModelMetadata):
|
||||
"""Represents the metadata structure for an Embedding model"""
|
||||
model_type: str = "embedding" # Model type (embedding, textual_inversion, etc.)
|
||||
|
||||
@classmethod
|
||||
def from_civitai_info(cls, version_info: Dict, file_info: Dict, save_path: str) -> 'EmbeddingMetadata':
|
||||
"""Create EmbeddingMetadata instance from Civitai version info"""
|
||||
file_name = file_info['name']
|
||||
base_model = determine_base_model(version_info.get('baseModel', ''))
|
||||
model_type = version_info.get('type', 'embedding')
|
||||
|
||||
# Extract tags and description if available
|
||||
tags = []
|
||||
description = ""
|
||||
if 'model' in version_info:
|
||||
if 'tags' in version_info['model']:
|
||||
tags = version_info['model']['tags']
|
||||
if 'description' in version_info['model']:
|
||||
description = version_info['model']['description']
|
||||
|
||||
return cls(
|
||||
file_name=os.path.splitext(file_name)[0],
|
||||
model_name=version_info.get('model').get('name', os.path.splitext(file_name)[0]),
|
||||
file_path=save_path.replace(os.sep, '/'),
|
||||
size=file_info.get('sizeKB', 0) * 1024,
|
||||
modified=datetime.now().timestamp(),
|
||||
sha256=file_info['hashes'].get('SHA256', '').lower(),
|
||||
base_model=base_model,
|
||||
preview_url=None, # Will be updated after preview download
|
||||
preview_nsfw_level=0,
|
||||
from_civitai=True,
|
||||
civitai=version_info,
|
||||
model_type=model_type,
|
||||
tags=tags,
|
||||
modelDescription=description
|
||||
)
|
||||
|
||||
|
||||
@@ -8,9 +8,11 @@ from .model_utils import determine_base_model
|
||||
from .constants import PREVIEW_EXTENSIONS, CARD_PREVIEW_WIDTH
|
||||
from ..config import config
|
||||
from ..services.civitai_client import CivitaiClient
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from ..services.download_manager import DownloadManager
|
||||
from ..services.websocket_manager import ws_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -154,7 +156,7 @@ class ModelRouteUtils:
|
||||
local_metadata['preview_nsfw_level'] = first_preview.get('nsfwLevel', 0)
|
||||
|
||||
# Save updated metadata
|
||||
await MetadataManager.save_metadata(metadata_path, local_metadata, True)
|
||||
await MetadataManager.save_metadata(metadata_path, local_metadata)
|
||||
|
||||
@staticmethod
|
||||
async def fetch_and_update_model(
|
||||
@@ -227,13 +229,13 @@ class ModelRouteUtils:
|
||||
await client.close()
|
||||
|
||||
@staticmethod
|
||||
def filter_civitai_data(data: Dict) -> Dict:
|
||||
def filter_civitai_data(data: Dict, minimal: bool = False) -> Dict:
|
||||
"""Filter relevant fields from CivitAI data"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
fields = [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
|
||||
fields = ["id", "modelId", "name", "trainedWords"] if minimal else [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images", "customImages", "creator"
|
||||
]
|
||||
@@ -327,8 +329,6 @@ class ModelRouteUtils:
|
||||
# Update hash index if available
|
||||
if hasattr(scanner, '_hash_index') and scanner._hash_index:
|
||||
scanner._hash_index.remove_by_path(file_path)
|
||||
|
||||
await scanner._save_cache_to_disk()
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
@@ -551,8 +551,6 @@ class ModelRouteUtils:
|
||||
|
||||
# Add to excluded models list
|
||||
scanner._excluded_models.append(file_path)
|
||||
|
||||
await scanner._save_cache_to_disk()
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
@@ -564,66 +562,77 @@ class ModelRouteUtils:
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response:
|
||||
"""Handle model download request
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
download_manager: Instance of DownloadManager
|
||||
model_type: Type of model ('lora' or 'checkpoint')
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
"""
|
||||
async def handle_download_model(request: web.Request) -> web.Response:
|
||||
"""Handle model download request"""
|
||||
try:
|
||||
download_manager = await ServiceRegistry.get_download_manager()
|
||||
data = await request.json()
|
||||
|
||||
# Create progress callback
|
||||
# Get or generate a download ID
|
||||
download_id = data.get('download_id', ws_manager.generate_download_id())
|
||||
|
||||
# Create progress callback with download ID
|
||||
async def progress_callback(progress):
|
||||
from ..services.websocket_manager import ws_manager
|
||||
await ws_manager.broadcast({
|
||||
await ws_manager.broadcast_download_progress(download_id, {
|
||||
'status': 'progress',
|
||||
'progress': progress
|
||||
'progress': progress,
|
||||
'download_id': download_id
|
||||
})
|
||||
|
||||
# Check which identifier is provided
|
||||
download_url = data.get('download_url')
|
||||
model_hash = data.get('model_hash')
|
||||
model_version_id = data.get('model_version_id')
|
||||
# Check which identifier is provided and convert to int
|
||||
model_id = None
|
||||
model_version_id = None
|
||||
|
||||
# Validate that at least one identifier is provided
|
||||
if not any([download_url, model_hash, model_version_id]):
|
||||
return web.Response(
|
||||
status=400,
|
||||
text="Missing required parameter: Please provide either 'download_url', 'hash', or 'modelVersionId'"
|
||||
)
|
||||
if data.get('model_id'):
|
||||
try:
|
||||
model_id = int(data.get('model_id'))
|
||||
except (TypeError, ValueError):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': "Invalid model_id: Must be an integer"
|
||||
}, status=400)
|
||||
|
||||
# Convert model_version_id to int if provided
|
||||
if data.get('model_version_id'):
|
||||
try:
|
||||
model_version_id = int(data.get('model_version_id'))
|
||||
except (TypeError, ValueError):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': "Invalid model_version_id: Must be an integer"
|
||||
}, status=400)
|
||||
|
||||
# Use the correct root directory based on model type
|
||||
root_key = 'checkpoint_root' if model_type == 'checkpoint' else 'lora_root'
|
||||
save_dir = data.get(root_key)
|
||||
# At least one identifier is required
|
||||
if not model_id and not model_version_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': "Missing required parameter: Please provide either 'model_id' or 'model_version_id'"
|
||||
}, status=400)
|
||||
|
||||
use_default_paths = data.get('use_default_paths', False)
|
||||
|
||||
# Pass the download_id to download_from_civitai
|
||||
result = await download_manager.download_from_civitai(
|
||||
download_url=download_url,
|
||||
model_hash=model_hash,
|
||||
model_id=model_id,
|
||||
model_version_id=model_version_id,
|
||||
save_dir=save_dir,
|
||||
save_dir=data.get('model_root'),
|
||||
relative_path=data.get('relative_path', ''),
|
||||
use_default_paths=use_default_paths,
|
||||
progress_callback=progress_callback,
|
||||
model_type=model_type
|
||||
download_id=download_id # Pass download_id explicitly
|
||||
)
|
||||
|
||||
# Include download_id in the response
|
||||
result['download_id'] = download_id
|
||||
|
||||
if not result.get('success', False):
|
||||
error_message = result.get('error', 'Unknown error')
|
||||
|
||||
# Return 401 for early access errors
|
||||
if 'early access' in error_message.lower():
|
||||
logger.warning(f"Early access download failed: {error_message}")
|
||||
return web.Response(
|
||||
status=401, # Use 401 status code to match Civitai's response
|
||||
text=f"Early Access Restriction: {error_message}"
|
||||
)
|
||||
|
||||
return web.Response(status=500, text=error_message)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': error_message,
|
||||
'download_id': download_id
|
||||
}, status=500)
|
||||
|
||||
return web.json_response(result)
|
||||
|
||||
@@ -633,13 +642,75 @@ class ModelRouteUtils:
|
||||
# Check if this might be an early access error
|
||||
if '401' in error_message:
|
||||
logger.warning(f"Early access error (401): {error_message}")
|
||||
return web.Response(
|
||||
status=401,
|
||||
text="Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
|
||||
)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
|
||||
}, status=401)
|
||||
|
||||
logger.error(f"Error downloading {model_type}: {error_message}")
|
||||
return web.Response(status=500, text=error_message)
|
||||
logger.error(f"Error downloading model: {error_message}")
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': error_message
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_cancel_download(request: web.Request) -> web.Response:
|
||||
"""Handle cancellation of a download task
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
"""
|
||||
try:
|
||||
download_manager = await ServiceRegistry.get_download_manager()
|
||||
download_id = request.match_info.get('download_id')
|
||||
if not download_id:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download ID is required'
|
||||
}, status=400)
|
||||
|
||||
result = await download_manager.cancel_download(download_id)
|
||||
|
||||
# Notify clients about cancellation via WebSocket
|
||||
await ws_manager.broadcast_download_progress(download_id, {
|
||||
'status': 'cancelled',
|
||||
'progress': 0,
|
||||
'download_id': download_id,
|
||||
'message': 'Download cancelled by user'
|
||||
})
|
||||
|
||||
return web.json_response(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_list_downloads(request: web.Request) -> web.Response:
|
||||
"""Get list of active downloads
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response with list of downloads
|
||||
"""
|
||||
try:
|
||||
download_manager = await ServiceRegistry.get_download_manager()
|
||||
result = await download_manager.get_active_downloads()
|
||||
return web.json_response(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing downloads: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_bulk_delete_models(request: web.Request, scanner) -> web.Response:
|
||||
@@ -693,8 +764,10 @@ class ModelRouteUtils:
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
model_id = data.get('model_id')
|
||||
model_version_id = data.get('model_version_id')
|
||||
model_id = int(data.get('model_id'))
|
||||
model_version_id = None
|
||||
if data.get('model_version_id'):
|
||||
model_version_id = int(data.get('model_version_id'))
|
||||
|
||||
if not file_path or not model_id:
|
||||
return web.json_response({"success": False, "error": "Both file_path and model_id are required"}, status=400)
|
||||
@@ -797,11 +870,11 @@ class ModelRouteUtils:
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
|
||||
# Compare hashes
|
||||
stored_hash = metadata.get('sha256', '').lower()
|
||||
stored_hash = metadata.get('sha256', '').lower();
|
||||
|
||||
# Set expected hash from first file if not yet set
|
||||
if not expected_hash:
|
||||
expected_hash = stored_hash
|
||||
expected_hash = stored_hash;
|
||||
|
||||
# Check if hash matches expected hash
|
||||
if actual_hash != expected_hash:
|
||||
@@ -905,10 +978,11 @@ class ModelRouteUtils:
|
||||
if os.path.exists(metadata_path):
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
hash_value = metadata.get('sha256')
|
||||
|
||||
logger.info(f"hash_value: {hash_value}, metadata_path: {metadata_path}, metadata: {metadata}")
|
||||
# Rename all files
|
||||
renamed_files = []
|
||||
new_metadata_path = None
|
||||
new_preview = None
|
||||
|
||||
for old_path, pattern in existing_files:
|
||||
# Get the file extension like .safetensors or .metadata.json
|
||||
@@ -955,6 +1029,7 @@ class ModelRouteUtils:
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'new_file_path': new_file_path,
|
||||
'new_preview_path': config.get_preview_static_url(new_preview),
|
||||
'renamed_files': renamed_files,
|
||||
'reload_required': False
|
||||
})
|
||||
@@ -965,3 +1040,116 @@ class ModelRouteUtils:
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_save_metadata(request: web.Request, scanner) -> web.Response:
|
||||
"""Handle saving metadata updates
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
scanner: The model scanner instance
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
return web.Response(text='File path is required', status=400)
|
||||
|
||||
# Remove file path from data to avoid saving it
|
||||
metadata_updates = {k: v for k, v in data.items() if k != 'file_path'}
|
||||
|
||||
# Get metadata file path
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
|
||||
# Load existing metadata
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
|
||||
# Handle nested updates (for civitai.trainedWords)
|
||||
for key, value in metadata_updates.items():
|
||||
if isinstance(value, dict) and key in metadata and isinstance(metadata[key], dict):
|
||||
# Deep update for nested dictionaries
|
||||
for nested_key, nested_value in value.items():
|
||||
metadata[key][nested_key] = nested_value
|
||||
else:
|
||||
# Regular update for top-level keys
|
||||
metadata[key] = value
|
||||
|
||||
# Save updated metadata
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
# Update cache
|
||||
await scanner.update_single_model_cache(file_path, file_path, metadata)
|
||||
|
||||
# If model_name was updated, resort the cache
|
||||
if 'model_name' in metadata_updates:
|
||||
cache = await scanner.get_cached_data()
|
||||
await cache.resort()
|
||||
|
||||
return web.json_response({'success': True})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving metadata: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_add_tags(request: web.Request, scanner) -> web.Response:
|
||||
"""Handle adding tags to model metadata
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
scanner: The model scanner instance
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
new_tags = data.get('tags', [])
|
||||
|
||||
if not file_path:
|
||||
return web.Response(text='File path is required', status=400)
|
||||
|
||||
if not isinstance(new_tags, list):
|
||||
return web.Response(text='Tags must be a list', status=400)
|
||||
|
||||
# Get metadata file path
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
|
||||
# Load existing metadata
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
|
||||
# Get existing tags (case insensitive)
|
||||
existing_tags = metadata.get('tags', [])
|
||||
existing_tags_lower = [tag.lower() for tag in existing_tags]
|
||||
|
||||
# Add new tags that don't already exist (case insensitive check)
|
||||
tags_added = []
|
||||
for tag in new_tags:
|
||||
if isinstance(tag, str) and tag.strip():
|
||||
tag_stripped = tag.strip()
|
||||
if tag_stripped.lower() not in existing_tags_lower:
|
||||
existing_tags.append(tag_stripped)
|
||||
existing_tags_lower.append(tag_stripped.lower())
|
||||
tags_added.append(tag_stripped)
|
||||
|
||||
# Update metadata with combined tags
|
||||
metadata['tags'] = existing_tags
|
||||
|
||||
# Save updated metadata
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
# Update cache
|
||||
await scanner.update_single_model_cache(file_path, file_path, metadata)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'tags': existing_tags
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding tags: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
@@ -1,85 +1,56 @@
|
||||
from difflib import SequenceMatcher
|
||||
import requests
|
||||
import tempfile
|
||||
import re
|
||||
from bs4 import BeautifulSoup
|
||||
import os
|
||||
from typing import Dict
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..config import config
|
||||
from ..services.settings_manager import settings
|
||||
from .constants import CIVITAI_MODEL_TAGS
|
||||
import asyncio
|
||||
|
||||
def download_twitter_image(url):
|
||||
"""Download image from a URL containing twitter:image meta tag
|
||||
def get_lora_info(lora_name):
|
||||
"""Get the lora path and trigger words from cache"""
|
||||
async def _get_lora_info_async():
|
||||
scanner = await ServiceRegistry.get_lora_scanner()
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
for item in cache.raw_data:
|
||||
if item.get('file_name') == lora_name:
|
||||
file_path = item.get('file_path')
|
||||
if file_path:
|
||||
for root in config.loras_roots:
|
||||
root = root.replace(os.sep, '/')
|
||||
if file_path.startswith(root):
|
||||
relative_path = os.path.relpath(file_path, root).replace(os.sep, '/')
|
||||
# Get trigger words from civitai metadata
|
||||
civitai = item.get('civitai', {})
|
||||
trigger_words = civitai.get('trainedWords', []) if civitai else []
|
||||
return relative_path, trigger_words
|
||||
return lora_name, []
|
||||
|
||||
Args:
|
||||
url (str): The URL to download image from
|
||||
|
||||
Returns:
|
||||
str: Path to downloaded temporary image file
|
||||
"""
|
||||
try:
|
||||
# Download page content
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
# Check if we're already in an event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in a running loop, we need to use a different approach
|
||||
# Create a new thread to run the async code
|
||||
import concurrent.futures
|
||||
|
||||
# Parse HTML
|
||||
soup = BeautifulSoup(response.text, 'html.parser')
|
||||
def run_in_thread():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(_get_lora_info_async())
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
# Find twitter:image meta tag
|
||||
meta_tag = soup.find('meta', attrs={'property': 'twitter:image'})
|
||||
if not meta_tag:
|
||||
return None
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
return future.result()
|
||||
|
||||
image_url = meta_tag['content']
|
||||
|
||||
# Download image
|
||||
image_response = requests.get(image_url)
|
||||
image_response.raise_for_status()
|
||||
|
||||
# Save to temp file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
|
||||
temp_file.write(image_response.content)
|
||||
return temp_file.name
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error downloading twitter image: {e}")
|
||||
return None
|
||||
except RuntimeError:
|
||||
# No event loop is running, we can use asyncio.run()
|
||||
return asyncio.run(_get_lora_info_async())
|
||||
|
||||
def download_civitai_image(url):
|
||||
"""Download image from a URL containing avatar image with specific class and style attributes
|
||||
|
||||
Args:
|
||||
url (str): The URL to download image from
|
||||
|
||||
Returns:
|
||||
str: Path to downloaded temporary image file
|
||||
"""
|
||||
try:
|
||||
# Download page content
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
# Parse HTML
|
||||
soup = BeautifulSoup(response.text, 'html.parser')
|
||||
|
||||
# Find image with specific class and style attributes
|
||||
image = soup.select_one('img.EdgeImage_image__iH4_q.max-h-full.w-auto.max-w-full')
|
||||
|
||||
if not image or 'src' not in image.attrs:
|
||||
return None
|
||||
|
||||
image_url = image['src']
|
||||
|
||||
# Download image
|
||||
image_response = requests.get(image_url)
|
||||
image_response.raise_for_status()
|
||||
|
||||
# Save to temp file
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
|
||||
temp_file.write(image_response.content)
|
||||
return temp_file.name
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error downloading civitai avatar: {e}")
|
||||
return None
|
||||
|
||||
def fuzzy_match(text: str, pattern: str, threshold: float = 0.7) -> bool:
|
||||
def fuzzy_match(text: str, pattern: str, threshold: float = 0.85) -> bool:
|
||||
"""
|
||||
Check if text matches pattern using fuzzy matching.
|
||||
Returns True if similarity ratio is above threshold.
|
||||
@@ -160,3 +131,95 @@ def calculate_recipe_fingerprint(loras):
|
||||
fingerprint = "|".join([f"{hash_value}:{strength}" for hash_value, strength in valid_loras])
|
||||
|
||||
return fingerprint
|
||||
|
||||
def calculate_relative_path_for_model(model_data: Dict, model_type: str = 'lora') -> str:
|
||||
"""Calculate relative path for existing model using template from settings
|
||||
|
||||
Args:
|
||||
model_data: Model data from scanner cache
|
||||
model_type: Type of model ('lora', 'checkpoint', 'embedding')
|
||||
|
||||
Returns:
|
||||
Relative path string (empty string for flat structure)
|
||||
"""
|
||||
# Get path template from settings for specific model type
|
||||
path_template = settings.get_download_path_template(model_type)
|
||||
|
||||
# If template is empty, return empty path (flat structure)
|
||||
if not path_template:
|
||||
return ''
|
||||
|
||||
# Get base model name from model metadata
|
||||
civitai_data = model_data.get('civitai', {})
|
||||
|
||||
# For CivitAI models, prefer civitai data only if 'id' exists; for non-CivitAI models, use model_data directly
|
||||
if civitai_data and civitai_data.get('id') is not None:
|
||||
base_model = civitai_data.get('baseModel', '')
|
||||
# Get author from civitai creator data
|
||||
creator_info = civitai_data.get('creator') or {}
|
||||
author = creator_info.get('username') or 'Anonymous'
|
||||
else:
|
||||
# Fallback to model_data fields for non-CivitAI models
|
||||
base_model = model_data.get('base_model', '')
|
||||
author = 'Anonymous' # Default for non-CivitAI models
|
||||
|
||||
model_tags = model_data.get('tags', [])
|
||||
|
||||
# Apply mapping if available
|
||||
base_model_mappings = settings.get('base_model_path_mappings', {})
|
||||
mapped_base_model = base_model_mappings.get(base_model, base_model)
|
||||
|
||||
# Find the first Civitai model tag that exists in model_tags
|
||||
first_tag = ''
|
||||
for civitai_tag in CIVITAI_MODEL_TAGS:
|
||||
if civitai_tag in model_tags:
|
||||
first_tag = civitai_tag
|
||||
break
|
||||
|
||||
# If no Civitai model tag found, fallback to first tag
|
||||
if not first_tag and model_tags:
|
||||
first_tag = model_tags[0]
|
||||
|
||||
if not first_tag:
|
||||
first_tag = 'no tags' # Default if no tags available
|
||||
|
||||
# Format the template with available data
|
||||
formatted_path = path_template
|
||||
formatted_path = formatted_path.replace('{base_model}', mapped_base_model)
|
||||
formatted_path = formatted_path.replace('{first_tag}', first_tag)
|
||||
formatted_path = formatted_path.replace('{author}', author)
|
||||
|
||||
return formatted_path
|
||||
|
||||
def remove_empty_dirs(path):
|
||||
"""Recursively remove empty directories starting from the given path.
|
||||
|
||||
Args:
|
||||
path (str): Root directory to start cleaning from
|
||||
|
||||
Returns:
|
||||
int: Number of empty directories removed
|
||||
"""
|
||||
removed_count = 0
|
||||
|
||||
if not os.path.isdir(path):
|
||||
return removed_count
|
||||
|
||||
# List all files in directory
|
||||
files = os.listdir(path)
|
||||
|
||||
# Process all subdirectories first
|
||||
for file in files:
|
||||
full_path = os.path.join(path, file)
|
||||
if os.path.isdir(full_path):
|
||||
removed_count += remove_empty_dirs(full_path)
|
||||
|
||||
# Check if directory is now empty (after processing subdirectories)
|
||||
if not os.listdir(path):
|
||||
try:
|
||||
os.rmdir(path)
|
||||
removed_count += 1
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return removed_count
|
||||
|
||||
@@ -1,21 +1,18 @@
|
||||
[project]
|
||||
name = "comfyui-lora-manager"
|
||||
description = "LoRA Manager for ComfyUI - Access it at http://localhost:8188/loras for managing LoRA models with previews and metadata integration."
|
||||
version = "0.8.19"
|
||||
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
|
||||
version = "0.9.2"
|
||||
license = {file = "LICENSE"}
|
||||
dependencies = [
|
||||
"aiohttp",
|
||||
"jinja2",
|
||||
"safetensors",
|
||||
"watchdog",
|
||||
"beautifulsoup4",
|
||||
"piexif",
|
||||
"Pillow",
|
||||
"olefile", # for getting rid of warning message
|
||||
"requests",
|
||||
"toml",
|
||||
"natsort",
|
||||
"msgpack"
|
||||
"GitPython"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
aiohttp
|
||||
jinja2
|
||||
safetensors
|
||||
watchdog
|
||||
beautifulsoup4
|
||||
piexif
|
||||
Pillow
|
||||
olefile
|
||||
requests
|
||||
toml
|
||||
numpy
|
||||
torch
|
||||
natsort
|
||||
msgpack
|
||||
GitPython
|
||||
|
||||
305
scripts/sync_translation_keys.py
Normal file
305
scripts/sync_translation_keys.py
Normal file
@@ -0,0 +1,305 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Translation Key Synchronization Script
|
||||
|
||||
This script synchronizes new translation keys from en.json to all other locale files
|
||||
while maintaining exact formatting consistency to pass test_i18n.py validation.
|
||||
|
||||
Features:
|
||||
- Preserves exact line-by-line formatting
|
||||
- Maintains proper indentation and structure
|
||||
- Adds missing keys with placeholder translations
|
||||
- Handles nested objects correctly
|
||||
- Ensures all locale files have identical structure
|
||||
|
||||
Usage:
|
||||
python scripts/sync_translation_keys.py [--dry-run] [--verbose]
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
import argparse
|
||||
from typing import Dict, List, Set, Tuple, Any, Optional
|
||||
from collections import OrderedDict
|
||||
|
||||
# Add the parent directory to the path so we can import modules if needed
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
class TranslationKeySynchronizer:
|
||||
"""Synchronizes translation keys across locale files while maintaining formatting."""
|
||||
|
||||
def __init__(self, locales_dir: str, verbose: bool = False):
|
||||
self.locales_dir = locales_dir
|
||||
self.verbose = verbose
|
||||
self.reference_locale = 'en'
|
||||
self.target_locales = ['zh-CN', 'zh-TW', 'ja', 'ru', 'de', 'fr', 'es', 'ko']
|
||||
|
||||
def log(self, message: str, level: str = 'INFO'):
|
||||
"""Log a message if verbose mode is enabled."""
|
||||
if self.verbose or level == 'ERROR':
|
||||
print(f"[{level}] {message}")
|
||||
|
||||
def load_json_preserve_order(self, file_path: str) -> Tuple[Dict[str, Any], List[str]]:
|
||||
"""
|
||||
Load a JSON file preserving the exact order and formatting.
|
||||
Returns both the parsed data and the original lines.
|
||||
"""
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
content = ''.join(lines)
|
||||
|
||||
# Parse JSON while preserving order
|
||||
data = json.loads(content, object_pairs_hook=OrderedDict)
|
||||
return data, lines
|
||||
|
||||
def get_all_leaf_keys(self, data: Any, prefix: str = '') -> Dict[str, Any]:
|
||||
"""
|
||||
Extract all leaf keys (non-object values) with their full paths.
|
||||
Returns a dictionary mapping full key paths to their values.
|
||||
"""
|
||||
keys = {}
|
||||
|
||||
if isinstance(data, (dict, OrderedDict)):
|
||||
for key, value in data.items():
|
||||
full_key = f"{prefix}.{key}" if prefix else key
|
||||
|
||||
if isinstance(value, (dict, OrderedDict)):
|
||||
# Recursively get nested keys
|
||||
keys.update(self.get_all_leaf_keys(value, full_key))
|
||||
else:
|
||||
# Leaf node - actual translatable value
|
||||
keys[full_key] = value
|
||||
|
||||
return keys
|
||||
|
||||
def merge_json_structures(self, reference_data: Dict[str, Any], target_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Merge the reference JSON structure with existing target translations.
|
||||
This creates a new structure that matches the reference exactly but preserves
|
||||
existing translations where available. Keys not in reference are removed.
|
||||
"""
|
||||
def merge_recursive(ref_obj, target_obj):
|
||||
if isinstance(ref_obj, (dict, OrderedDict)):
|
||||
result = OrderedDict()
|
||||
# Only include keys that exist in the reference
|
||||
for key, ref_value in ref_obj.items():
|
||||
if key in target_obj and isinstance(target_obj[key], type(ref_value)):
|
||||
# Key exists in target with same type
|
||||
if isinstance(ref_value, (dict, OrderedDict)):
|
||||
# Recursively merge nested objects
|
||||
result[key] = merge_recursive(ref_value, target_obj[key])
|
||||
else:
|
||||
# Use existing translation
|
||||
result[key] = target_obj[key]
|
||||
else:
|
||||
# Key missing in target or type mismatch
|
||||
if isinstance(ref_value, (dict, OrderedDict)):
|
||||
# Recursively handle nested objects
|
||||
result[key] = merge_recursive(ref_value, {})
|
||||
else:
|
||||
# Create placeholder translation
|
||||
result[key] = f"[TODO: Translate] {ref_value}"
|
||||
return result
|
||||
else:
|
||||
# For non-dict values, use reference (this shouldn't happen at root level)
|
||||
return ref_obj
|
||||
|
||||
return merge_recursive(reference_data, target_data)
|
||||
|
||||
def format_json_like_reference(self, data: Dict[str, Any], reference_lines: List[str]) -> List[str]:
|
||||
"""
|
||||
Format the merged JSON data to match the reference file's formatting exactly.
|
||||
"""
|
||||
# Use json.dumps with proper formatting to match the reference style
|
||||
formatted_json = json.dumps(data, indent=4, ensure_ascii=False, separators=(',', ': '))
|
||||
|
||||
# Split into lines and ensure consistent line endings
|
||||
formatted_lines = [line + '\n' for line in formatted_json.split('\n')]
|
||||
|
||||
# Make sure the last line doesn't have extra newlines
|
||||
if formatted_lines and formatted_lines[-1].strip() == '':
|
||||
formatted_lines = formatted_lines[:-1]
|
||||
|
||||
# Ensure the last line ends with just a newline
|
||||
if formatted_lines and not formatted_lines[-1].endswith('\n'):
|
||||
formatted_lines[-1] += '\n'
|
||||
|
||||
return formatted_lines
|
||||
|
||||
def synchronize_locale_simple(self, locale: str, reference_data: Dict[str, Any],
|
||||
reference_lines: List[str], dry_run: bool = False) -> bool:
|
||||
"""
|
||||
Synchronize a locale file using JSON structure merging.
|
||||
Handles both addition of missing keys and removal of obsolete keys.
|
||||
"""
|
||||
locale_file = os.path.join(self.locales_dir, f'{locale}.json')
|
||||
|
||||
if not os.path.exists(locale_file):
|
||||
self.log(f"Locale file {locale_file} does not exist!", 'ERROR')
|
||||
return False
|
||||
|
||||
try:
|
||||
target_data, _ = self.load_json_preserve_order(locale_file)
|
||||
except Exception as e:
|
||||
self.log(f"Error loading {locale_file}: {e}", 'ERROR')
|
||||
return False
|
||||
|
||||
# Get keys to check for differences
|
||||
ref_keys = self.get_all_leaf_keys(reference_data)
|
||||
target_keys = self.get_all_leaf_keys(target_data)
|
||||
missing_keys = set(ref_keys.keys()) - set(target_keys.keys())
|
||||
obsolete_keys = set(target_keys.keys()) - set(ref_keys.keys())
|
||||
|
||||
if not missing_keys and not obsolete_keys:
|
||||
self.log(f"Locale {locale} is already up to date")
|
||||
return False
|
||||
|
||||
# Report changes
|
||||
if missing_keys:
|
||||
self.log(f"Found {len(missing_keys)} missing keys in {locale}:")
|
||||
for key in sorted(missing_keys):
|
||||
self.log(f" + {key}")
|
||||
|
||||
if obsolete_keys:
|
||||
self.log(f"Found {len(obsolete_keys)} obsolete keys in {locale}:")
|
||||
for key in sorted(obsolete_keys):
|
||||
self.log(f" - {key}")
|
||||
|
||||
if dry_run:
|
||||
total_changes = len(missing_keys) + len(obsolete_keys)
|
||||
self.log(f"DRY RUN: Would update {locale} with {len(missing_keys)} additions and {len(obsolete_keys)} deletions ({total_changes} total changes)")
|
||||
return True
|
||||
|
||||
# Merge the structures (this will both add missing keys and remove obsolete ones)
|
||||
try:
|
||||
merged_data = self.merge_json_structures(reference_data, target_data)
|
||||
|
||||
# Format to match reference style
|
||||
new_lines = self.format_json_like_reference(merged_data, reference_lines)
|
||||
|
||||
# Validate that the result is valid JSON
|
||||
reconstructed_content = ''.join(new_lines)
|
||||
json.loads(reconstructed_content) # This will raise an exception if invalid
|
||||
|
||||
# Write the updated file
|
||||
with open(locale_file, 'w', encoding='utf-8') as f:
|
||||
f.writelines(new_lines)
|
||||
|
||||
total_changes = len(missing_keys) + len(obsolete_keys)
|
||||
self.log(f"Successfully updated {locale} with {len(missing_keys)} additions and {len(obsolete_keys)} deletions ({total_changes} total changes)")
|
||||
return True
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
self.log(f"Generated invalid JSON for {locale}: {e}", 'ERROR')
|
||||
return False
|
||||
except Exception as e:
|
||||
self.log(f"Error updating {locale_file}: {e}", 'ERROR')
|
||||
return False
|
||||
|
||||
def synchronize_all(self, dry_run: bool = False) -> bool:
|
||||
"""
|
||||
Synchronize all locale files with the reference.
|
||||
Returns True if all operations were successful.
|
||||
"""
|
||||
# Load reference file
|
||||
reference_file = os.path.join(self.locales_dir, f'{self.reference_locale}.json')
|
||||
|
||||
if not os.path.exists(reference_file):
|
||||
self.log(f"Reference file {reference_file} does not exist!", 'ERROR')
|
||||
return False
|
||||
|
||||
try:
|
||||
reference_data, reference_lines = self.load_json_preserve_order(reference_file)
|
||||
reference_keys = self.get_all_leaf_keys(reference_data)
|
||||
except Exception as e:
|
||||
self.log(f"Error loading reference file: {e}", 'ERROR')
|
||||
return False
|
||||
|
||||
self.log(f"Loaded reference file with {len(reference_keys)} keys")
|
||||
|
||||
success = True
|
||||
changes_made = False
|
||||
|
||||
# Synchronize each target locale
|
||||
for locale in self.target_locales:
|
||||
try:
|
||||
if self.synchronize_locale_simple(locale, reference_data, reference_lines, dry_run):
|
||||
changes_made = True
|
||||
except Exception as e:
|
||||
self.log(f"Error synchronizing {locale}: {e}", 'ERROR')
|
||||
success = False
|
||||
|
||||
if changes_made:
|
||||
self.log("Synchronization completed with changes")
|
||||
else:
|
||||
self.log("All locale files are already up to date")
|
||||
|
||||
return success
|
||||
|
||||
def main():
|
||||
"""Main entry point for the script."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Synchronize translation keys from en.json to all other locale files'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dry-run',
|
||||
action='store_true',
|
||||
help='Show what would be changed without making actual changes'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--verbose', '-v',
|
||||
action='store_true',
|
||||
help='Enable verbose output'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--locales-dir',
|
||||
default=None,
|
||||
help='Path to locales directory (default: auto-detect from script location)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Determine locales directory
|
||||
if args.locales_dir:
|
||||
locales_dir = args.locales_dir
|
||||
else:
|
||||
# Auto-detect based on script location
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
locales_dir = os.path.join(os.path.dirname(script_dir), 'locales')
|
||||
|
||||
if not os.path.exists(locales_dir):
|
||||
print(f"ERROR: Locales directory not found: {locales_dir}")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Translation Key Synchronization")
|
||||
print(f"Locales directory: {locales_dir}")
|
||||
print(f"Mode: {'DRY RUN' if args.dry_run else 'LIVE UPDATE'}")
|
||||
print("-" * 50)
|
||||
|
||||
# Create synchronizer and run
|
||||
synchronizer = TranslationKeySynchronizer(locales_dir, args.verbose)
|
||||
|
||||
try:
|
||||
success = synchronizer.synchronize_all(args.dry_run)
|
||||
|
||||
if success:
|
||||
print("\n✅ Synchronization completed successfully!")
|
||||
if not args.dry_run:
|
||||
print("💡 Run 'python test_i18n.py' to verify formatting consistency")
|
||||
else:
|
||||
print("\n❌ Synchronization completed with errors!")
|
||||
sys.exit(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n⚠️ Operation cancelled by user")
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"\n❌ Unexpected error: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@@ -9,6 +9,10 @@
|
||||
"checkpoints": [
|
||||
"C:/path/to/your/checkpoints_folder",
|
||||
"C:/path/to/another/checkpoints_folder"
|
||||
],
|
||||
"embeddings": [
|
||||
"C:/path/to/your/embeddings_folder",
|
||||
"C:/path/to/another/embeddings_folder"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
119
standalone.py
119
standalone.py
@@ -3,6 +3,26 @@ import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
# Create mock modules for py/nodes directory - add this before any other imports
|
||||
def mock_nodes_directory():
|
||||
"""Create mock modules for all Python files in the py/nodes directory"""
|
||||
nodes_dir = os.path.join(os.path.dirname(__file__), 'py', 'nodes')
|
||||
if os.path.exists(nodes_dir):
|
||||
# Create a mock module for the nodes package itself
|
||||
sys.modules['py.nodes'] = type('MockNodesModule', (), {})
|
||||
|
||||
# Create mock modules for all Python files in the nodes directory
|
||||
for file in os.listdir(nodes_dir):
|
||||
if file.endswith('.py') and file != '__init__.py':
|
||||
module_name = file[:-3] # Remove .py extension
|
||||
full_module_name = f'py.nodes.{module_name}'
|
||||
# Create empty module object
|
||||
sys.modules[full_module_name] = type(f'Mock{module_name.capitalize()}Module', (), {})
|
||||
print(f"Created mock module for: {full_module_name}")
|
||||
|
||||
# Run the mocking function before any other imports
|
||||
mock_nodes_directory()
|
||||
|
||||
# Create mock folder_paths module BEFORE any other imports
|
||||
class MockFolderPaths:
|
||||
@staticmethod
|
||||
@@ -86,6 +106,22 @@ logger = logging.getLogger("lora-manager-standalone")
|
||||
# Configure aiohttp access logger to be less verbose
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Add specific suppression for connection reset errors
|
||||
class ConnectionResetFilter(logging.Filter):
|
||||
def filter(self, record):
|
||||
# Filter out connection reset errors that are not critical
|
||||
if "ConnectionResetError" in str(record.getMessage()):
|
||||
return False
|
||||
if "_call_connection_lost" in str(record.getMessage()):
|
||||
return False
|
||||
if "WinError 10054" in str(record.getMessage()):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Apply the filter to asyncio logger
|
||||
asyncio_logger = logging.getLogger("asyncio")
|
||||
asyncio_logger.addFilter(ConnectionResetFilter())
|
||||
|
||||
# Now we can import the global config from our local modules
|
||||
from py.config import config
|
||||
|
||||
@@ -98,17 +134,6 @@ class StandaloneServer:
|
||||
|
||||
# Ensure the app's access logger is configured to reduce verbosity
|
||||
self.app._subapps = [] # Ensure this exists to avoid AttributeError
|
||||
|
||||
# Configure access logging for the app
|
||||
self.app.on_startup.append(self._configure_access_logger)
|
||||
|
||||
async def _configure_access_logger(self, app):
|
||||
"""Configure access logger to reduce verbosity"""
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# If using aiohttp>=3.8.0, configure access logger through app directly
|
||||
if hasattr(app, 'access_logger'):
|
||||
app.access_logger.setLevel(logging.WARNING)
|
||||
|
||||
async def setup(self):
|
||||
"""Set up the standalone server"""
|
||||
@@ -198,9 +223,6 @@ class StandaloneLoraManager(LoraManager):
|
||||
|
||||
# Store app in a global-like location for compatibility
|
||||
sys.modules['server'].PromptServer.instance = server_instance
|
||||
|
||||
# Configure aiohttp access logger to be less verbose
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
added_targets = set() # Track already added target paths
|
||||
|
||||
@@ -232,7 +254,7 @@ class StandaloneLoraManager(LoraManager):
|
||||
added_targets.add(os.path.normpath(real_root))
|
||||
|
||||
# Add static routes for each checkpoint root
|
||||
for idx, root in enumerate(config.checkpoints_roots, start=1):
|
||||
for idx, root in enumerate(config.base_models_roots, start=1):
|
||||
if not os.path.exists(root):
|
||||
logger.warning(f"Checkpoint root path does not exist: {root}")
|
||||
continue
|
||||
@@ -257,23 +279,50 @@ class StandaloneLoraManager(LoraManager):
|
||||
# Record route mapping
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(os.path.normpath(real_root))
|
||||
|
||||
# Add static routes for each embedding root
|
||||
for idx, root in enumerate(getattr(config, "embeddings_roots", []), start=1):
|
||||
if not os.path.exists(root):
|
||||
logger.warning(f"Embedding root path does not exist: {root}")
|
||||
continue
|
||||
|
||||
preview_path = f'/embeddings_static/root{idx}/preview'
|
||||
|
||||
real_root = root
|
||||
for target, link in config._path_mappings.items():
|
||||
if os.path.normpath(link) == os.path.normpath(root):
|
||||
real_root = target
|
||||
break
|
||||
|
||||
display_root = real_root.replace('\\', '/')
|
||||
app.router.add_static(preview_path, real_root)
|
||||
logger.info(f"Added static route {preview_path} -> {display_root}")
|
||||
|
||||
config.add_route_mapping(real_root, preview_path)
|
||||
added_targets.add(os.path.normpath(real_root))
|
||||
|
||||
# Add static routes for symlink target paths that aren't already covered
|
||||
link_idx = {
|
||||
'lora': 1,
|
||||
'checkpoint': 1
|
||||
'checkpoint': 1,
|
||||
'embedding': 1
|
||||
}
|
||||
|
||||
for target_path, link_path in config._path_mappings.items():
|
||||
norm_target = os.path.normpath(target_path)
|
||||
if norm_target not in added_targets:
|
||||
# Determine if this is a checkpoint or lora link based on path
|
||||
is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.checkpoints_roots)
|
||||
is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.checkpoints_roots)
|
||||
|
||||
# Determine if this is a checkpoint, lora, or embedding link based on path
|
||||
is_checkpoint = any(os.path.normpath(cp_root) in os.path.normpath(link_path) for cp_root in config.base_models_roots)
|
||||
is_checkpoint = is_checkpoint or any(os.path.normpath(cp_root) in norm_target for cp_root in config.base_models_roots)
|
||||
is_embedding = any(os.path.normpath(emb_root) in os.path.normpath(link_path) for emb_root in getattr(config, "embeddings_roots", []))
|
||||
is_embedding = is_embedding or any(os.path.normpath(emb_root) in norm_target for emb_root in getattr(config, "embeddings_roots", []))
|
||||
|
||||
if is_checkpoint:
|
||||
route_path = f'/checkpoints_static/link_{link_idx["checkpoint"]}/preview'
|
||||
link_idx["checkpoint"] += 1
|
||||
elif is_embedding:
|
||||
route_path = f'/embeddings_static/link_{link_idx["embedding"]}/preview'
|
||||
link_idx["embedding"] += 1
|
||||
else:
|
||||
route_path = f'/loras_static/link_{link_idx["lora"]}/preview'
|
||||
link_idx["lora"] += 1
|
||||
@@ -290,39 +339,48 @@ class StandaloneLoraManager(LoraManager):
|
||||
logger.warning(f"Failed to add static route on initialization for {target_path}: {e}")
|
||||
continue
|
||||
|
||||
# Add static route for locales JSON files
|
||||
if os.path.exists(config.i18n_path):
|
||||
app.router.add_static('/locales', config.i18n_path)
|
||||
logger.info(f"Added static route for locales: /locales -> {config.i18n_path}")
|
||||
|
||||
# Add static route for plugin assets
|
||||
app.router.add_static('/loras_static', config.static_path)
|
||||
|
||||
# Setup feature routes
|
||||
from py.routes.lora_routes import LoraRoutes
|
||||
from py.routes.api_routes import ApiRoutes
|
||||
from py.services.model_service_factory import ModelServiceFactory, register_default_model_types
|
||||
from py.routes.recipe_routes import RecipeRoutes
|
||||
from py.routes.checkpoints_routes import CheckpointsRoutes
|
||||
from py.routes.update_routes import UpdateRoutes
|
||||
from py.routes.misc_routes import MiscRoutes
|
||||
from py.routes.example_images_routes import ExampleImagesRoutes
|
||||
from py.routes.stats_routes import StatsRoutes
|
||||
from py.services.websocket_manager import ws_manager
|
||||
|
||||
lora_routes = LoraRoutes()
|
||||
checkpoints_routes = CheckpointsRoutes()
|
||||
|
||||
register_default_model_types()
|
||||
|
||||
# Setup all model routes using the factory
|
||||
ModelServiceFactory.setup_all_routes(app)
|
||||
|
||||
stats_routes = StatsRoutes()
|
||||
|
||||
# Initialize routes
|
||||
lora_routes.setup_routes(app)
|
||||
checkpoints_routes.setup_routes(app)
|
||||
stats_routes.setup_routes(app)
|
||||
ApiRoutes.setup_routes(app)
|
||||
RecipeRoutes.setup_routes(app)
|
||||
UpdateRoutes.setup_routes(app)
|
||||
MiscRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app)
|
||||
|
||||
# Setup WebSocket routes that are shared across all model types
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
app.router.add_get('/ws/download-progress', ws_manager.handle_download_connection)
|
||||
app.router.add_get('/ws/init-progress', ws_manager.handle_init_connection)
|
||||
|
||||
# Schedule service initialization
|
||||
app.on_startup.append(lambda app: cls._initialize_services())
|
||||
|
||||
# Add cleanup
|
||||
app.on_shutdown.append(cls._cleanup)
|
||||
app.on_shutdown.append(ApiRoutes.cleanup)
|
||||
|
||||
def parse_args():
|
||||
"""Parse command line arguments"""
|
||||
@@ -347,9 +405,6 @@ async def main():
|
||||
# Set log level
|
||||
logging.getLogger().setLevel(getattr(logging, args.log_level))
|
||||
|
||||
# Explicitly configure aiohttp access logger regardless of selected log level
|
||||
logging.getLogger('aiohttp.access').setLevel(logging.WARNING)
|
||||
|
||||
# Create the server instance
|
||||
server = StandaloneServer()
|
||||
|
||||
|
||||
@@ -46,12 +46,12 @@ html, body {
|
||||
|
||||
/* Composed Colors */
|
||||
--lora-accent: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h));
|
||||
--lora-surface: oklch(100% 0 0 / 0.98);
|
||||
--lora-surface: oklch(97% 0 0 / 0.95);
|
||||
--lora-border: oklch(90% 0.02 256 / 0.15);
|
||||
--lora-text: oklch(95% 0.02 256);
|
||||
--lora-error: oklch(75% 0.32 29);
|
||||
--lora-warning: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h)); /* Modified to be used with oklch() */
|
||||
--lora-success: oklch(var(--lora-success-l) var(--lora-success-c) var(--lora-success-h)); /* New green success color */
|
||||
--lora-warning: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
|
||||
--lora-success: oklch(var(--lora-success-l) var(--lora-success-c) var(--lora-success-h));
|
||||
|
||||
/* Spacing Scale */
|
||||
--space-1: calc(8px * 1);
|
||||
@@ -70,6 +70,11 @@ html, body {
|
||||
--border-radius-xs: 4px;
|
||||
|
||||
--scrollbar-width: 8px; /* 添加滚动条宽度变量 */
|
||||
|
||||
/* Shortcut styles */
|
||||
--shortcut-bg: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.12);
|
||||
--shortcut-border: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.25);
|
||||
--shortcut-text: var(--text-color);
|
||||
}
|
||||
|
||||
html[data-theme="dark"] {
|
||||
|
||||
245
static/css/components/banner.css
Normal file
245
static/css/components/banner.css
Normal file
@@ -0,0 +1,245 @@
|
||||
/* Banner Container */
|
||||
.banner-container {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
z-index: calc(var(--z-header) - 1);
|
||||
border-bottom: 1px solid var(--border-color);
|
||||
background: var(--card-bg);
|
||||
margin-bottom: var(--space-2);
|
||||
}
|
||||
|
||||
/* Individual Banner */
|
||||
.banner-item {
|
||||
position: relative;
|
||||
padding: var(--space-2) var(--space-3);
|
||||
background: linear-gradient(135deg,
|
||||
oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.05),
|
||||
oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.02)
|
||||
);
|
||||
border-left: 4px solid var(--lora-accent);
|
||||
animation: banner-slide-down 0.3s ease-in-out;
|
||||
}
|
||||
|
||||
/* Banner Content Layout */
|
||||
.banner-content {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
gap: var(--space-3);
|
||||
max-width: 1400px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
/* Banner Text Section */
|
||||
.banner-text {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.banner-title {
|
||||
margin: 0 0 4px 0;
|
||||
font-size: 1.1em;
|
||||
font-weight: 600;
|
||||
color: var(--text-color);
|
||||
line-height: 1.3;
|
||||
}
|
||||
|
||||
.banner-description {
|
||||
margin: 0;
|
||||
font-size: 0.9em;
|
||||
color: var(--text-muted);
|
||||
line-height: 1.4;
|
||||
}
|
||||
|
||||
/* Banner Actions */
|
||||
.banner-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: var(--space-1);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.banner-action {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 6px 12px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
text-decoration: none;
|
||||
font-size: 0.85em;
|
||||
font-weight: 500;
|
||||
transition: all 0.2s ease;
|
||||
white-space: nowrap;
|
||||
border: 1px solid transparent;
|
||||
}
|
||||
|
||||
.banner-action i {
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
/* Primary Action Button */
|
||||
.banner-action-primary {
|
||||
background: var(--lora-accent);
|
||||
color: white;
|
||||
border-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
.banner-action-primary:hover {
|
||||
background: oklch(calc(var(--lora-accent-l) - 5%) var(--lora-accent-c) var(--lora-accent-h));
|
||||
transform: translateY(-1px);
|
||||
box-shadow: 0 3px 6px oklch(var(--lora-accent) / 0.3);
|
||||
}
|
||||
|
||||
/* Secondary Action Button */
|
||||
.banner-action-secondary {
|
||||
background: var(--card-bg);
|
||||
color: var(--text-color);
|
||||
border-color: var(--border-color);
|
||||
}
|
||||
|
||||
.banner-action-secondary:hover {
|
||||
background: var(--lora-accent);
|
||||
color: white;
|
||||
border-color: var(--lora-accent);
|
||||
transform: translateY(-1px);
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
/* Tertiary Action Button */
|
||||
.banner-action-tertiary {
|
||||
background: transparent;
|
||||
color: var(--lora-accent);
|
||||
border-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
.banner-action-tertiary:hover {
|
||||
background: var(--lora-accent);
|
||||
color: white;
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
/* Dismiss Button */
|
||||
.banner-dismiss {
|
||||
position: absolute;
|
||||
top: 8px;
|
||||
right: 8px;
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
border: none;
|
||||
background: transparent;
|
||||
color: var(--text-muted);
|
||||
cursor: pointer;
|
||||
border-radius: 50%;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: all 0.2s ease;
|
||||
font-size: 0.8em;
|
||||
}
|
||||
|
||||
.banner-dismiss:hover {
|
||||
background: oklch(var(--lora-accent) / 0.1);
|
||||
color: var(--lora-accent);
|
||||
transform: scale(1.1);
|
||||
}
|
||||
|
||||
/* Animations */
|
||||
@keyframes banner-slide-down {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(-100%);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes banner-slide-up {
|
||||
from {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
max-height: 200px;
|
||||
}
|
||||
to {
|
||||
opacity: 0;
|
||||
transform: translateY(-20px);
|
||||
max-height: 0;
|
||||
padding-top: 0;
|
||||
padding-bottom: 0;
|
||||
}
|
||||
}
|
||||
|
||||
/* Responsive Design */
|
||||
@media (max-width: 768px) {
|
||||
.banner-content {
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
gap: var(--space-2);
|
||||
}
|
||||
|
||||
.banner-actions {
|
||||
width: 100%;
|
||||
flex-wrap: wrap;
|
||||
justify-content: flex-start;
|
||||
}
|
||||
|
||||
.banner-action {
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.banner-dismiss {
|
||||
top: 6px;
|
||||
right: 6px;
|
||||
}
|
||||
|
||||
.banner-item {
|
||||
padding: var(--space-2);
|
||||
}
|
||||
|
||||
.banner-title {
|
||||
font-size: 1em;
|
||||
}
|
||||
|
||||
.banner-description {
|
||||
font-size: 0.85em;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 480px) {
|
||||
.banner-actions {
|
||||
flex-direction: column;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.banner-action {
|
||||
width: 100%;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.banner-content {
|
||||
gap: var(--space-1);
|
||||
}
|
||||
}
|
||||
|
||||
/* Dark theme adjustments */
|
||||
[data-theme="dark"] .banner-item {
|
||||
background: linear-gradient(135deg,
|
||||
oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.08),
|
||||
oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.03)
|
||||
);
|
||||
}
|
||||
|
||||
/* Prevent text selection */
|
||||
.banner-item,
|
||||
.banner-title,
|
||||
.banner-description,
|
||||
.banner-action,
|
||||
.banner-dismiss {
|
||||
-webkit-user-select: none;
|
||||
-moz-user-select: none;
|
||||
-ms-user-select: none;
|
||||
user-select: none;
|
||||
}
|
||||
@@ -1,84 +1,10 @@
|
||||
/* Bulk Operations Styles */
|
||||
.bulk-operations-panel {
|
||||
position: fixed;
|
||||
bottom: 20px;
|
||||
left: 50%;
|
||||
transform: translateY(100px) translateX(-50%);
|
||||
background: var(--card-bg);
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-base);
|
||||
padding: 12px 16px;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
|
||||
z-index: var(--z-overlay);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
min-width: 300px;
|
||||
transition: all 0.4s cubic-bezier(0.175, 0.885, 0.32, 1.275);
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
.bulk-operations-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 12px;
|
||||
gap: 20px; /* Increase space between count and buttons */
|
||||
}
|
||||
|
||||
#selectedCount {
|
||||
font-weight: 500;
|
||||
background: var(--bg-color);
|
||||
padding: 6px 12px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
min-width: 80px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.bulk-operations-actions {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.bulk-operations-actions button {
|
||||
padding: 6px 12px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
background: var(--bg-color);
|
||||
border: 1px solid var(--border-color);
|
||||
color: var(--text-color);
|
||||
cursor: pointer;
|
||||
font-size: 14px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.bulk-operations-actions button:hover {
|
||||
background: var(--lora-accent);
|
||||
color: white;
|
||||
border-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
/* Danger button style - updated to use proper theme variables */
|
||||
.bulk-operations-actions button.danger-btn {
|
||||
background: oklch(70% 0.2 29); /* Light red background that works in both themes */
|
||||
color: oklch(98% 0.01 0); /* Almost white text for good contrast */
|
||||
border-color: var(--lora-error);
|
||||
}
|
||||
|
||||
.bulk-operations-actions button.danger-btn:hover {
|
||||
background: var(--lora-error);
|
||||
color: oklch(100% 0 0); /* Pure white text on hover for maximum contrast */
|
||||
}
|
||||
|
||||
/* Style for selected cards */
|
||||
.lora-card.selected {
|
||||
.model-card.selected {
|
||||
box-shadow: 0 0 0 2px var(--lora-accent);
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.lora-card.selected::after {
|
||||
.model-card.selected::after {
|
||||
content: "✓";
|
||||
position: absolute;
|
||||
top: 10px;
|
||||
@@ -95,201 +21,61 @@
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
/* Update bulk operations button to match others when active */
|
||||
#bulkOperationsBtn.active {
|
||||
background: var(--lora-accent);
|
||||
color: white;
|
||||
border-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.bulk-operations-panel {
|
||||
width: calc(100% - 40px);
|
||||
left: 20px;
|
||||
transform: none;
|
||||
border-radius: var(--border-radius-sm);
|
||||
}
|
||||
|
||||
.bulk-operations-actions {
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
}
|
||||
|
||||
.bulk-operations-panel.visible {
|
||||
transform: translateY(0) translateX(-50%);
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Thumbnail Strip Styles */
|
||||
.selected-thumbnails-strip {
|
||||
/* Marquee selection styles */
|
||||
.marquee-selection {
|
||||
position: fixed;
|
||||
bottom: 80px; /* Position above the bulk operations panel */
|
||||
left: 50%;
|
||||
transform: translateX(-50%) translateY(20px);
|
||||
background: var(--card-bg);
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-base);
|
||||
box-shadow: 0 4px 16px rgba(0, 0, 0, 0.15);
|
||||
z-index: calc(var(--z-overlay) - 1); /* Just below the bulk panel z-index */
|
||||
padding: 16px;
|
||||
max-width: 80%;
|
||||
width: auto;
|
||||
transition: all 0.3s ease;
|
||||
opacity: 0;
|
||||
overflow: hidden;
|
||||
border: 2px dashed var(--lora-accent, #007bff);
|
||||
background: rgba(0, 123, 255, 0.1);
|
||||
pointer-events: none;
|
||||
z-index: 9999;
|
||||
border-radius: 2px;
|
||||
}
|
||||
|
||||
.selected-thumbnails-strip.visible {
|
||||
opacity: 1;
|
||||
transform: translateX(-50%) translateY(0);
|
||||
/* Visual feedback when marquee selecting */
|
||||
.marquee-selecting {
|
||||
cursor: crosshair;
|
||||
user-select: none;
|
||||
-webkit-user-select: none;
|
||||
-moz-user-select: none;
|
||||
-ms-user-select: none;
|
||||
}
|
||||
|
||||
.thumbnails-container {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
overflow-x: auto;
|
||||
padding-bottom: 8px; /* Space for scrollbar */
|
||||
/* Prevent text selection during marquee */
|
||||
.marquee-selecting * {
|
||||
user-select: none;
|
||||
-webkit-user-select: none;
|
||||
-moz-user-select: none;
|
||||
-ms-user-select: none;
|
||||
}
|
||||
|
||||
/* Remove bulk base model modal specific styles - now using shared components */
|
||||
/* Use shared metadata editing styles instead */
|
||||
|
||||
/* Override for bulk base model select to ensure proper width */
|
||||
.bulk-base-model-select {
|
||||
width: 100%;
|
||||
max-width: 100%;
|
||||
align-items: flex-start;
|
||||
}
|
||||
|
||||
.selected-thumbnail {
|
||||
position: relative;
|
||||
width: 80px;
|
||||
min-width: 80px; /* Prevent shrinking */
|
||||
padding: 6px 10px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
overflow: hidden;
|
||||
cursor: pointer;
|
||||
background: var(--bg-color);
|
||||
transition: transform 0.2s ease, box-shadow 0.2s ease;
|
||||
}
|
||||
|
||||
.selected-thumbnail:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.selected-thumbnail img,
|
||||
.selected-thumbnail video {
|
||||
width: 100%;
|
||||
aspect-ratio: 1 / 1;
|
||||
object-fit: cover;
|
||||
display: block;
|
||||
}
|
||||
|
||||
.thumbnail-name {
|
||||
position: absolute;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
background: rgba(0, 0, 0, 0.6);
|
||||
color: white;
|
||||
font-size: 10px;
|
||||
padding: 3px 5px;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.thumbnail-remove {
|
||||
position: absolute;
|
||||
top: 3px;
|
||||
right: 3px;
|
||||
width: 18px;
|
||||
height: 18px;
|
||||
border-radius: 50%;
|
||||
background: rgba(0, 0, 0, 0.5);
|
||||
color: white;
|
||||
border: none;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
cursor: pointer;
|
||||
font-size: 10px;
|
||||
opacity: 0.7;
|
||||
transition: opacity 0.2s ease, background-color 0.2s ease;
|
||||
}
|
||||
|
||||
.thumbnail-remove:hover {
|
||||
opacity: 1;
|
||||
background: var(--lora-error);
|
||||
}
|
||||
|
||||
.strip-close-btn {
|
||||
position: absolute;
|
||||
top: 5px;
|
||||
right: 5px;
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
background: none;
|
||||
border: none;
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
opacity: 0.7;
|
||||
transition: opacity 0.2s ease;
|
||||
font-size: 0.95em;
|
||||
height: 32px;
|
||||
}
|
||||
|
||||
.strip-close-btn:hover {
|
||||
opacity: 1;
|
||||
.bulk-base-model-select:focus {
|
||||
border-color: var(--lora-accent);
|
||||
outline: none;
|
||||
}
|
||||
|
||||
/* Style the selectedCount to indicate it's clickable */
|
||||
.selectable-count {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 5px;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.2s ease;
|
||||
/* Dark theme support for bulk base model select */
|
||||
[data-theme="dark"] .bulk-base-model-select {
|
||||
background-color: rgba(30, 30, 30, 0.9);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.selectable-count:hover {
|
||||
background: var(--lora-border);
|
||||
}
|
||||
|
||||
.dropdown-caret {
|
||||
font-size: 12px;
|
||||
visibility: hidden; /* Will be shown via JS when items are selected */
|
||||
}
|
||||
|
||||
/* Scrollbar styling for the thumbnails container */
|
||||
.thumbnails-container::-webkit-scrollbar {
|
||||
height: 6px;
|
||||
}
|
||||
|
||||
.thumbnails-container::-webkit-scrollbar-track {
|
||||
background: var(--bg-color);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.thumbnails-container::-webkit-scrollbar-thumb {
|
||||
background: var(--border-color);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.thumbnails-container::-webkit-scrollbar-thumb:hover {
|
||||
background: var(--lora-accent);
|
||||
}
|
||||
|
||||
/* Mobile optimizations */
|
||||
@media (max-width: 768px) {
|
||||
.selected-thumbnails-strip {
|
||||
width: calc(100% - 40px);
|
||||
max-width: none;
|
||||
left: 20px;
|
||||
transform: translateY(20px);
|
||||
border-radius: var(--border-radius-sm);
|
||||
}
|
||||
|
||||
.selected-thumbnails-strip.visible {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
.selected-thumbnail {
|
||||
width: 70px;
|
||||
min-width: 70px;
|
||||
}
|
||||
[data-theme="dark"] .bulk-base-model-select option {
|
||||
background-color: #2d2d2d;
|
||||
color: var(--text-color);
|
||||
}
|
||||
@@ -14,7 +14,7 @@
|
||||
box-sizing: border-box; /* Include padding in width calculation */
|
||||
}
|
||||
|
||||
.lora-card {
|
||||
.model-card {
|
||||
background: var(--lora-surface);
|
||||
border: 1px solid var(--lora-border);
|
||||
border-radius: var(--border-radius-base);
|
||||
@@ -30,24 +30,24 @@
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.lora-card:hover {
|
||||
.model-card:hover {
|
||||
transform: translateY(-2px);
|
||||
background: oklch(100% 0 0 / 0.6);
|
||||
}
|
||||
|
||||
.lora-card:focus-visible {
|
||||
.model-card:focus-visible {
|
||||
outline: 2px solid var(--lora-accent);
|
||||
outline-offset: 2px;
|
||||
}
|
||||
|
||||
/* Responsive adjustments for 1440p screens (2K) */
|
||||
@media (min-width: 2000px) {
|
||||
@media (min-width: 2150px) {
|
||||
.card-grid {
|
||||
max-width: 1800px; /* Increased for 2K screens */
|
||||
grid-template-columns: repeat(auto-fill, minmax(270px, 1fr));
|
||||
}
|
||||
|
||||
.lora-card {
|
||||
.model-card {
|
||||
max-width: 270px;
|
||||
}
|
||||
}
|
||||
@@ -59,7 +59,7 @@
|
||||
grid-template-columns: repeat(auto-fill, minmax(280px, 1fr));
|
||||
}
|
||||
|
||||
.lora-card {
|
||||
.model-card {
|
||||
max-width: 280px;
|
||||
}
|
||||
}
|
||||
@@ -70,7 +70,7 @@
|
||||
grid-template-columns: repeat(auto-fill, minmax(240px, 1fr));
|
||||
}
|
||||
|
||||
.lora-card {
|
||||
.model-card {
|
||||
max-width: 240px;
|
||||
}
|
||||
}
|
||||
@@ -259,8 +259,8 @@
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
|
||||
.hover-reveal .lora-card:hover .card-header,
|
||||
.hover-reveal .lora-card:hover .card-footer {
|
||||
.hover-reveal .model-card:hover .card-header,
|
||||
.hover-reveal .model-card:hover .card-footer {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
@@ -345,7 +345,7 @@
|
||||
grid-template-columns: minmax(260px, 1fr); /* Adjusted minimum size for mobile */
|
||||
}
|
||||
|
||||
.lora-card {
|
||||
.model-card {
|
||||
max-width: 100%; /* Allow cards to fill available space on mobile */
|
||||
}
|
||||
}
|
||||
@@ -424,9 +424,36 @@
|
||||
font-size: 0.85em;
|
||||
}
|
||||
|
||||
/* Style for version name */
|
||||
.version-name {
|
||||
display: inline-block;
|
||||
color: rgba(255,255,255,0.8); /* Muted white */
|
||||
text-shadow: 1px 1px 2px rgba(0, 0, 0, 0.5);
|
||||
font-size: 0.85em;
|
||||
word-break: break-word;
|
||||
overflow: hidden;
|
||||
line-height: 1.4;
|
||||
margin-top: 2px;
|
||||
opacity: 0.8; /* Slightly transparent for better readability */
|
||||
border: 1px solid rgba(255,255,255,0.25); /* Subtle border */
|
||||
border-radius: var(--border-radius-xs);
|
||||
padding: 1px 6px;
|
||||
background: rgba(0,0,0,0.18); /* Optional: subtle background for contrast */
|
||||
}
|
||||
|
||||
/* Medium density adjustments for version name */
|
||||
.medium-density .version-name {
|
||||
font-size: 0.8em;
|
||||
}
|
||||
|
||||
/* Compact density adjustments for version name */
|
||||
.compact-density .version-name {
|
||||
font-size: 0.75em;
|
||||
}
|
||||
|
||||
/* Prevent text selection on cards and interactive elements */
|
||||
.lora-card,
|
||||
.lora-card *,
|
||||
.model-card,
|
||||
.model-card *,
|
||||
.card-actions,
|
||||
.card-actions i,
|
||||
.toggle-blur-btn,
|
||||
@@ -498,7 +525,7 @@
|
||||
}
|
||||
|
||||
/* For larger screens, allow more space for the cards */
|
||||
@media (min-width: 2000px) {
|
||||
@media (min-width: 2150px) {
|
||||
.card-grid.virtual-scroll {
|
||||
max-width: 1800px;
|
||||
}
|
||||
@@ -510,7 +537,7 @@
|
||||
}
|
||||
}
|
||||
|
||||
/* Add after the existing .lora-card:hover styles */
|
||||
/* Add after the existing .model-card:hover styles */
|
||||
|
||||
@keyframes update-pulse {
|
||||
0% { box-shadow: 0 0 0 0 var(--lora-accent-transparent); }
|
||||
@@ -523,7 +550,7 @@
|
||||
--lora-accent-transparent: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.6);
|
||||
}
|
||||
|
||||
.lora-card.updated {
|
||||
.model-card.updated {
|
||||
animation: update-pulse 1.2s ease-out;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,197 +0,0 @@
|
||||
/* Download Modal Styles */
|
||||
.download-step {
|
||||
margin: var(--space-2) 0;
|
||||
}
|
||||
|
||||
.input-group {
|
||||
margin-bottom: var(--space-2);
|
||||
}
|
||||
|
||||
.input-group label {
|
||||
display: block;
|
||||
margin-bottom: 8px;
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.input-group input,
|
||||
.input-group select {
|
||||
width: 100%;
|
||||
padding: 8px;
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-xs);
|
||||
background: var(--bg-color);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
/* Version List Styles */
|
||||
.version-list {
|
||||
max-height: 400px;
|
||||
overflow-y: auto;
|
||||
margin: var(--space-2) 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 12px;
|
||||
padding: 1px;
|
||||
}
|
||||
|
||||
.version-item {
|
||||
display: flex;
|
||||
gap: var(--space-2);
|
||||
padding: var(--space-2);
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-sm);
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
background: var(--bg-color);
|
||||
margin: 1px;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.version-item:hover {
|
||||
border-color: var(--lora-accent);
|
||||
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.version-item.selected {
|
||||
border: 2px solid var(--lora-accent);
|
||||
background: oklch(var(--lora-accent) / 0.05);
|
||||
}
|
||||
|
||||
.version-thumbnail {
|
||||
width: 80px;
|
||||
height: 80px;
|
||||
flex-shrink: 0;
|
||||
border-radius: var(--border-radius-xs);
|
||||
overflow: hidden;
|
||||
background: var(--bg-color);
|
||||
}
|
||||
|
||||
.version-thumbnail img {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
object-fit: cover;
|
||||
}
|
||||
|
||||
.version-content {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.version-header {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
justify-content: space-between;
|
||||
gap: var(--space-2);
|
||||
}
|
||||
|
||||
.version-content h3 {
|
||||
margin: 0;
|
||||
font-size: 1.1em;
|
||||
color: var(--text-color);
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.version-content .version-info {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
flex-direction: row !important;
|
||||
gap: 8px;
|
||||
align-items: center;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
.version-content .version-info .base-model {
|
||||
background: oklch(var(--lora-accent) / 0.1);
|
||||
color: var(--lora-accent);
|
||||
padding: 2px 8px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
}
|
||||
|
||||
.version-meta {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
font-size: 0.85em;
|
||||
color: var(--text-color);
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
.version-meta span {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
/* Folder Browser Styles */
|
||||
.folder-browser {
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-xs);
|
||||
padding: var(--space-1);
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.folder-item {
|
||||
padding: 8px;
|
||||
cursor: pointer;
|
||||
border-radius: var(--border-radius-xs);
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
|
||||
.folder-item:hover {
|
||||
background: var(--lora-surface);
|
||||
}
|
||||
|
||||
.folder-item.selected {
|
||||
background: oklch(var(--lora-accent) / 0.1);
|
||||
border: 1px solid var(--lora-accent);
|
||||
}
|
||||
|
||||
/* Path Preview Styles */
|
||||
.path-preview {
|
||||
margin-bottom: var(--space-3);
|
||||
padding: var(--space-2);
|
||||
background: var(--bg-color);
|
||||
border-radius: var(--border-radius-sm);
|
||||
border: 1px dashed var(--border-color);
|
||||
}
|
||||
|
||||
.path-preview label {
|
||||
display: block;
|
||||
margin-bottom: 8px;
|
||||
color: var(--text-color);
|
||||
font-size: 0.9em;
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.path-display {
|
||||
padding: var(--space-1);
|
||||
color: var(--text-color);
|
||||
font-family: monospace;
|
||||
font-size: 0.9em;
|
||||
line-height: 1.4;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-all;
|
||||
opacity: 0.85;
|
||||
background: var(--lora-surface);
|
||||
border-radius: var(--border-radius-xs);
|
||||
}
|
||||
|
||||
/* Dark theme adjustments */
|
||||
[data-theme="dark"] .version-item {
|
||||
background: var(--lora-surface);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .local-path {
|
||||
background: var(--lora-surface);
|
||||
border-color: var(--lora-border);
|
||||
}
|
||||
|
||||
/* Enhance the local badge to make it more noticeable */
|
||||
.version-item.exists-locally {
|
||||
background: oklch(var(--lora-accent) / 0.05);
|
||||
border-left: 4px solid var(--lora-accent);
|
||||
}
|
||||
@@ -27,7 +27,7 @@
|
||||
}
|
||||
|
||||
/* Responsive container for larger screens - match container in layout.css */
|
||||
@media (min-width: 2000px) {
|
||||
@media (min-width: 2150px) {
|
||||
.duplicates-banner .banner-content {
|
||||
max-width: 1800px;
|
||||
}
|
||||
@@ -130,7 +130,7 @@
|
||||
}
|
||||
|
||||
/* Add responsive container adjustments for duplicate groups - match container in banner */
|
||||
@media (min-width: 2000px) {
|
||||
@media (min-width: 2150px) {
|
||||
.duplicate-group {
|
||||
max-width: 1800px;
|
||||
}
|
||||
@@ -195,7 +195,7 @@
|
||||
}
|
||||
|
||||
/* Make cards in duplicate groups have consistent width */
|
||||
.card-group-container .lora-card {
|
||||
.card-group-container .model-card {
|
||||
flex: 0 0 auto;
|
||||
width: 240px;
|
||||
margin: 0;
|
||||
@@ -241,26 +241,26 @@
|
||||
}
|
||||
|
||||
/* Duplicate card styling */
|
||||
.lora-card.duplicate {
|
||||
.model-card.duplicate {
|
||||
position: relative;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.lora-card.duplicate:hover {
|
||||
.model-card.duplicate:hover {
|
||||
border-color: var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h);
|
||||
}
|
||||
|
||||
.lora-card.duplicate.latest {
|
||||
.model-card.duplicate.latest {
|
||||
border-style: solid;
|
||||
border-color: oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
|
||||
}
|
||||
|
||||
.lora-card.duplicate-selected {
|
||||
.model-card.duplicate-selected {
|
||||
border: 2px solid oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h));
|
||||
box-shadow: 0 0 8px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
.lora-card .selector-checkbox {
|
||||
.model-card .selector-checkbox {
|
||||
position: absolute;
|
||||
top: 10px;
|
||||
right: 10px;
|
||||
@@ -271,7 +271,7 @@
|
||||
}
|
||||
|
||||
/* Latest indicator */
|
||||
.lora-card.duplicate.latest::after {
|
||||
.model-card.duplicate.latest::after {
|
||||
content: "Latest";
|
||||
position: absolute;
|
||||
top: 10px;
|
||||
@@ -365,13 +365,13 @@
|
||||
}
|
||||
|
||||
/* Hash Mismatch Styling */
|
||||
.lora-card.duplicate.hash-mismatch {
|
||||
.model-card.duplicate.hash-mismatch {
|
||||
border: 2px dashed oklch(var(--lora-warning-l) var(--lora-warning-c) var(--lora-warning-h));
|
||||
opacity: 0.85;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.lora-card.duplicate.hash-mismatch::before {
|
||||
.model-card.duplicate.hash-mismatch::before {
|
||||
content: "";
|
||||
position: absolute;
|
||||
top: 0;
|
||||
@@ -389,7 +389,7 @@
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.lora-card.duplicate.hash-mismatch .card-preview {
|
||||
.model-card.duplicate.hash-mismatch .card-preview {
|
||||
filter: grayscale(20%);
|
||||
}
|
||||
|
||||
@@ -407,7 +407,7 @@
|
||||
}
|
||||
|
||||
/* Disabled checkbox style */
|
||||
.lora-card.duplicate.hash-mismatch .selector-checkbox {
|
||||
.model-card.duplicate.hash-mismatch .selector-checkbox {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
z-index: var(--z-header);
|
||||
height: 48px; /* Reduced height */
|
||||
width: 100%;
|
||||
box-shadow: 0 1px 3px rgba(0,0,0,0.05);
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1); /* Slightly stronger shadow */
|
||||
}
|
||||
|
||||
.header-container {
|
||||
@@ -19,6 +19,18 @@
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
/* Responsive header container for larger screens */
|
||||
@media (min-width: 2150px) {
|
||||
.header-container {
|
||||
max-width: 1800px;
|
||||
}
|
||||
}
|
||||
@media (min-width: 3000px) {
|
||||
.header-container {
|
||||
max-width: 2400px;
|
||||
}
|
||||
}
|
||||
|
||||
/* Logo and title styling */
|
||||
.header-branding {
|
||||
display: flex;
|
||||
@@ -31,7 +43,7 @@
|
||||
align-items: center;
|
||||
text-decoration: none;
|
||||
color: var(--text-color);
|
||||
gap: 8px;
|
||||
gap: 2px;
|
||||
}
|
||||
|
||||
.app-logo {
|
||||
@@ -223,11 +235,6 @@
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.update-badge.hidden,
|
||||
.update-badge:not(.visible) {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
/* Mobile adjustments */
|
||||
@media (max-width: 768px) {
|
||||
.app-title {
|
||||
|
||||
@@ -337,72 +337,7 @@
|
||||
margin-left: 8px;
|
||||
}
|
||||
|
||||
/* Location Selection Styles */
|
||||
.location-selection {
|
||||
margin: var(--space-2) 0;
|
||||
padding: var(--space-2);
|
||||
background: var(--lora-surface);
|
||||
border-radius: var(--border-radius-sm);
|
||||
}
|
||||
|
||||
/* Reuse folder browser and path preview styles from download-modal.css */
|
||||
.folder-browser {
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-xs);
|
||||
padding: var(--space-1);
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.folder-item {
|
||||
padding: 8px;
|
||||
cursor: pointer;
|
||||
border-radius: var(--border-radius-xs);
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
|
||||
.folder-item:hover {
|
||||
background: var(--lora-surface);
|
||||
}
|
||||
|
||||
.folder-item.selected {
|
||||
background: oklch(var(--lora-accent) / 0.1);
|
||||
border: 1px solid var(--lora-accent);
|
||||
}
|
||||
|
||||
.path-preview {
|
||||
margin-bottom: var(--space-3);
|
||||
padding: var(--space-2);
|
||||
background: var(--bg-color);
|
||||
border-radius: var(--border-radius-sm);
|
||||
border: 1px dashed var(--border-color);
|
||||
}
|
||||
|
||||
.path-preview label {
|
||||
display: block;
|
||||
margin-bottom: 8px;
|
||||
color: var(--text-color);
|
||||
font-size: 0.9em;
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.path-display {
|
||||
padding: var(--space-1);
|
||||
color: var(--text-color);
|
||||
font-family: monospace;
|
||||
font-size: 0.9em;
|
||||
line-height: 1.4;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-all;
|
||||
opacity: 0.85;
|
||||
background: var(--lora-surface);
|
||||
border-radius: var(--border-radius-xs);
|
||||
}
|
||||
|
||||
/* Input Group Styles */
|
||||
.input-group {
|
||||
margin-bottom: var(--space-2);
|
||||
}
|
||||
|
||||
.input-with-button {
|
||||
display: flex;
|
||||
@@ -430,22 +365,6 @@
|
||||
background: oklch(from var(--lora-accent) l c h / 0.9);
|
||||
}
|
||||
|
||||
.input-group label {
|
||||
display: block;
|
||||
margin-bottom: 8px;
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.input-group input,
|
||||
.input-group select {
|
||||
width: 100%;
|
||||
padding: 8px;
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-xs);
|
||||
background: var(--bg-color);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
/* Dark theme adjustments */
|
||||
[data-theme="dark"] .lora-item {
|
||||
background: var(--lora-surface);
|
||||
|
||||
@@ -40,10 +40,10 @@
|
||||
border-radius: var(--border-radius-xs);
|
||||
padding: 8px;
|
||||
position: absolute;
|
||||
z-index: 9999; /* 确保在卡片上方显示 */
|
||||
left: 120%; /* 将tooltip显示在图标右侧 */
|
||||
top: 50%; /* 垂直居中 */
|
||||
transform: translateY(-50%); /* 垂直居中 */
|
||||
z-index: 9999; /* Ensure tooltip appears above cards */
|
||||
left: 120%; /* Position tooltip to the right of the icon */
|
||||
top: 50%; /* Vertically center */
|
||||
transform: translateY(-15%); /* Vertically center */
|
||||
opacity: 0;
|
||||
transition: opacity 0.3s;
|
||||
box-shadow: 0 3px 8px rgba(0, 0, 0, 0.15);
|
||||
@@ -55,12 +55,12 @@
|
||||
.tooltip .tooltiptext::after {
|
||||
content: "";
|
||||
position: absolute;
|
||||
top: 50%; /* 箭头垂直居中 */
|
||||
right: 100%; /* 箭头在左侧 */
|
||||
top: 50%; /* Vertically center arrow */
|
||||
right: 100%; /* Arrow on the left side */
|
||||
margin-top: -5px;
|
||||
border-width: 5px;
|
||||
border-style: solid;
|
||||
border-color: transparent var(--lora-border) transparent transparent; /* 箭头指向左侧 */
|
||||
border-color: transparent var(--lora-border) transparent transparent; /* Arrow points left */
|
||||
}
|
||||
|
||||
.tooltip:hover .tooltiptext {
|
||||
|
||||
@@ -109,7 +109,7 @@
|
||||
}
|
||||
|
||||
@media (prefers-reduced-motion: reduce) {
|
||||
.lora-card,
|
||||
.model-card,
|
||||
.progress-bar,
|
||||
.current-item-bar {
|
||||
transition: none;
|
||||
|
||||
@@ -183,7 +183,11 @@
|
||||
outline: none;
|
||||
}
|
||||
|
||||
.edit-file-name-btn {
|
||||
/* 合并编辑按钮样式 */
|
||||
.edit-model-name-btn,
|
||||
.edit-file-name-btn,
|
||||
.edit-base-model-btn,
|
||||
.edit-model-description-btn {
|
||||
background: transparent;
|
||||
border: none;
|
||||
color: var(--text-color);
|
||||
@@ -195,17 +199,28 @@
|
||||
margin-left: var(--space-1);
|
||||
}
|
||||
|
||||
.edit-model-name-btn.visible,
|
||||
.edit-file-name-btn.visible,
|
||||
.file-name-wrapper:hover .edit-file-name-btn {
|
||||
.edit-base-model-btn.visible,
|
||||
.edit-model-description-btn.visible,
|
||||
.model-name-header:hover .edit-model-name-btn,
|
||||
.file-name-wrapper:hover .edit-file-name-btn,
|
||||
.base-model-display:hover .edit-base-model-btn,
|
||||
.model-name-header:hover .edit-model-description-btn {
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
.edit-file-name-btn:hover {
|
||||
.edit-model-name-btn:hover,
|
||||
.edit-file-name-btn:hover,
|
||||
.edit-base-model-btn:hover,
|
||||
.edit-model-description-btn:hover {
|
||||
opacity: 0.8 !important;
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .edit-file-name-btn:hover {
|
||||
[data-theme="dark"] .edit-model-name-btn:hover,
|
||||
[data-theme="dark"] .edit-file-name-btn:hover,
|
||||
[data-theme="dark"] .edit-base-model-btn:hover {
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
}
|
||||
|
||||
@@ -234,32 +249,6 @@
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.edit-base-model-btn {
|
||||
background: transparent;
|
||||
border: none;
|
||||
color: var(--text-color);
|
||||
opacity: 0;
|
||||
cursor: pointer;
|
||||
padding: 2px 5px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
transition: all 0.2s ease;
|
||||
margin-left: var(--space-1);
|
||||
}
|
||||
|
||||
.edit-base-model-btn.visible,
|
||||
.base-model-display:hover .edit-base-model-btn {
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
.edit-base-model-btn:hover {
|
||||
opacity: 0.8 !important;
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .edit-base-model-btn:hover {
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
}
|
||||
|
||||
.base-model-selector {
|
||||
width: 100%;
|
||||
padding: 3px 5px;
|
||||
@@ -316,32 +305,6 @@
|
||||
background: var(--bg-color);
|
||||
}
|
||||
|
||||
.edit-model-name-btn {
|
||||
background: transparent;
|
||||
border: none;
|
||||
color: var(--text-color);
|
||||
opacity: 0;
|
||||
cursor: pointer;
|
||||
padding: 2px 5px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
transition: all 0.2s ease;
|
||||
margin-left: var(--space-1);
|
||||
}
|
||||
|
||||
.edit-model-name-btn.visible,
|
||||
.model-name-header:hover .edit-model-name-btn {
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
.edit-model-name-btn:hover {
|
||||
opacity: 0.8 !important;
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .edit-model-name-btn:hover {
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
}
|
||||
|
||||
/* Tab System Styling */
|
||||
.showcase-tabs {
|
||||
display: flex;
|
||||
@@ -436,22 +399,24 @@
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
margin-bottom: var(--space-1);
|
||||
padding: 6px 10px;
|
||||
padding: 2px 10px;
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
border-radius: var(--border-radius-sm);
|
||||
max-width: fit-content;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .creator-info {
|
||||
[data-theme="dark"] .creator-info,
|
||||
[data-theme="dark"] .civitai-view {
|
||||
background: rgba(255, 255, 255, 0.03);
|
||||
border: 1px solid var(--lora-border);
|
||||
}
|
||||
|
||||
.creator-avatar {
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
width: 26px;
|
||||
height: 26px;
|
||||
border-radius: 50%;
|
||||
overflow: hidden;
|
||||
flex-shrink: 0;
|
||||
@@ -482,8 +447,40 @@
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
/* Optional: add hover effect for creator info */
|
||||
.creator-info:hover {
|
||||
/* Add hover effect for creator info */
|
||||
.creator-info:hover,
|
||||
.civitai-view:hover {
|
||||
background: oklch(var(--lora-accent-l) var(--lora-accent-c) var(--lora-accent-h) / 0.1);
|
||||
border-color: var(--lora-accent);
|
||||
transform: translateY(-1px);
|
||||
}
|
||||
|
||||
.creator-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
margin-bottom: var(--space-1);
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.civitai-view {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 6px 12px;
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
border-radius: var(--border-radius-sm);
|
||||
color: var(--text-color);
|
||||
cursor: pointer;
|
||||
font-weight: 500;
|
||||
font-size: 0.9em;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.civitai-view i {
|
||||
font-size: 20px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
@@ -176,11 +176,6 @@
|
||||
background: linear-gradient(45deg, #4a90e2, #357abd);
|
||||
}
|
||||
|
||||
/* Remove old node-color-indicator styles */
|
||||
.node-color-indicator {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.send-all-item {
|
||||
border-top: 1px solid var(--border-color);
|
||||
font-weight: 500;
|
||||
@@ -217,4 +212,24 @@
|
||||
font-size: 12px;
|
||||
color: var(--text-muted);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
/* Bulk Context Menu Header */
|
||||
.bulk-context-header {
|
||||
padding: 10px 12px;
|
||||
background: var(--card-bg); /* Use card background for subtlety */
|
||||
color: var(--text-color); /* Use standard text color */
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-weight: 500;
|
||||
font-size: 14px;
|
||||
border-radius: var(--border-radius-xs) var(--border-radius-xs) 0 0;
|
||||
border-bottom: 1px solid var(--border-color); /* Add subtle separator */
|
||||
}
|
||||
|
||||
.bulk-context-header i {
|
||||
width: 16px;
|
||||
text-align: center;
|
||||
color: var(--lora-accent); /* Accent only the icon for a hint of color */
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
278
static/css/components/modal/_base.css
Normal file
278
static/css/components/modal/_base.css
Normal file
@@ -0,0 +1,278 @@
|
||||
/* modal 基础样式 */
|
||||
.modal {
|
||||
display: none;
|
||||
position: fixed;
|
||||
top: 48px; /* Start below the header */
|
||||
left: 0;
|
||||
width: 100%;
|
||||
height: calc(100% - 48px); /* Adjust height to exclude header */
|
||||
background: rgba(0, 0, 0, 0.2); /* 调整为更淡的半透明黑色 */
|
||||
z-index: var(--z-modal);
|
||||
overflow: auto; /* Change from hidden to auto to allow scrolling */
|
||||
}
|
||||
|
||||
/* 当模态窗口打开时,禁止body滚动 */
|
||||
body.modal-open {
|
||||
position: fixed;
|
||||
width: 100%;
|
||||
padding-right: var(--scrollbar-width, 0px); /* 补偿滚动条消失导致的页面偏移 */
|
||||
}
|
||||
|
||||
/* modal-content 样式 */
|
||||
.modal-content {
|
||||
position: relative;
|
||||
max-width: 800px;
|
||||
height: auto;
|
||||
max-height: calc(90vh);
|
||||
margin: 1rem auto; /* Keep reduced top margin */
|
||||
background: var(--lora-surface);
|
||||
border-radius: var(--border-radius-base);
|
||||
padding: var(--space-3);
|
||||
border: 1px solid var(--lora-border);
|
||||
box-shadow:
|
||||
0 4px 6px -1px rgba(0, 0, 0, 0.1),
|
||||
0 2px 4px -1px rgba(0, 0, 0, 0.06),
|
||||
0 10px 15px -3px rgba(0, 0, 0, 0.05);
|
||||
overflow-y: auto;
|
||||
overflow-x: hidden; /* 防止水平滚动条 */
|
||||
}
|
||||
|
||||
.modal-content-large {
|
||||
min-height: 480px;
|
||||
}
|
||||
|
||||
/* 当 modal 打开时锁定 body */
|
||||
body.modal-open {
|
||||
overflow: hidden !important; /* 覆盖 base.css 中的 scroll */
|
||||
padding-right: var(--scrollbar-width, 8px); /* 使用滚动条宽度作为补偿 */
|
||||
}
|
||||
|
||||
@keyframes modalFadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(-20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.modal-actions {
|
||||
display: flex;
|
||||
gap: var(--space-2);
|
||||
justify-content: center;
|
||||
margin-top: var(--space-3);
|
||||
}
|
||||
|
||||
.cancel-btn, .delete-btn, .exclude-btn, .confirm-btn {
|
||||
padding: 8px var(--space-2);
|
||||
border-radius: 6px;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
font-weight: 500;
|
||||
min-width: 100px;
|
||||
}
|
||||
|
||||
.cancel-btn {
|
||||
background: var(--lora-surface);
|
||||
border: 1px solid var(--lora-border);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.delete-btn {
|
||||
background: var(--lora-error);
|
||||
color: white;
|
||||
}
|
||||
|
||||
/* Style for exclude button - different from delete button */
|
||||
.exclude-btn, .confirm-btn {
|
||||
background: var(--lora-accent, #4f46e5);
|
||||
color: white;
|
||||
}
|
||||
|
||||
.cancel-btn:hover {
|
||||
background: var(--lora-border);
|
||||
}
|
||||
|
||||
.delete-btn:hover {
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
.exclude-btn:hover, .confirm-btn:hover {
|
||||
opacity: 0.9;
|
||||
background: oklch(from var(--lora-accent, #4f46e5) l c h / 85%);
|
||||
}
|
||||
|
||||
.modal-content h2 {
|
||||
color: var(--text-color);
|
||||
margin-bottom: var(--space-1);
|
||||
font-size: 1.5em;
|
||||
}
|
||||
|
||||
.close {
|
||||
position: absolute;
|
||||
top: var(--space-2);
|
||||
right: var(--space-2);
|
||||
background: transparent;
|
||||
border: none;
|
||||
color: var(--text-color);
|
||||
font-size: 1.5em;
|
||||
cursor: pointer;
|
||||
opacity: 0.7;
|
||||
transition: opacity 0.2s;
|
||||
}
|
||||
|
||||
.close:hover {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* 统一各个 section 的样式 */
|
||||
.support-section,
|
||||
.changelog-section,
|
||||
.update-info,
|
||||
.info-item,
|
||||
.path-preview {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
border-radius: var(--border-radius-sm);
|
||||
padding: var(--space-2);
|
||||
}
|
||||
|
||||
/* 深色主题统一样式 */
|
||||
[data-theme="dark"] .modal-content {
|
||||
background: var(--lora-surface);
|
||||
border: 1px solid var(--lora-border);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .support-section,
|
||||
[data-theme="dark"] .changelog-section,
|
||||
[data-theme="dark"] .update-info,
|
||||
[data-theme="dark"] .info-item,
|
||||
[data-theme="dark"] .path-preview {
|
||||
background: rgba(255, 255, 255, 0.03);
|
||||
border: 1px solid var(--lora-border);
|
||||
}
|
||||
|
||||
.primary-btn {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 8px 16px;
|
||||
background-color: var(--lora-accent);
|
||||
color: var(--lora-text);
|
||||
border: none;
|
||||
border-radius: var(--border-radius-sm);
|
||||
cursor: pointer;
|
||||
transition: background-color 0.2s;
|
||||
font-size: 0.95em;
|
||||
}
|
||||
|
||||
.primary-btn:hover {
|
||||
background-color: oklch(from var(--lora-accent) l c h / 85%);
|
||||
color: var(--lora-text);
|
||||
}
|
||||
|
||||
/* Secondary button styles */
|
||||
.secondary-btn {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 8px 16px;
|
||||
background-color: var(--card-bg);
|
||||
color: var (--text-color);
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-sm);
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
font-size: 0.95em;
|
||||
}
|
||||
|
||||
.secondary-btn:hover {
|
||||
background-color: var(--border-color);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
/* Disabled button styles */
|
||||
.primary-btn.disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
background-color: var(--lora-accent);
|
||||
color: var(--lora-text);
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.secondary-btn.disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.restart-required-icon {
|
||||
color: var(--lora-warning);
|
||||
margin-left: 5px;
|
||||
font-size: 0.85em;
|
||||
vertical-align: text-bottom;
|
||||
}
|
||||
|
||||
/* Dark theme specific button adjustments */
|
||||
[data-theme="dark"] .primary-btn:hover {
|
||||
background-color: oklch(from var(--lora-accent) l c h / 75%);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .secondary-btn {
|
||||
background-color: var(--lora-surface);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .secondary-btn:hover {
|
||||
background-color: oklch(35% 0.02 256 / 0.98);
|
||||
}
|
||||
|
||||
.primary-btn.disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.primary-btn.disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
/* Add styles for delete preview image */
|
||||
.delete-preview {
|
||||
max-width: 150px;
|
||||
margin: 0 auto var(--space-2);
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.delete-preview img {
|
||||
width: 100%;
|
||||
height: auto;
|
||||
max-height: 150px;
|
||||
object-fit: contain;
|
||||
border-radius: var(--border-radius-sm);
|
||||
}
|
||||
|
||||
.delete-info {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.delete-info h3 {
|
||||
margin-bottom: var(--space-1);
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.delete-info p {
|
||||
margin: var(--space-1) 0;
|
||||
font-size: 0.9em;
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.delete-note {
|
||||
font-size: 0.85em;
|
||||
color: var(--text-color);
|
||||
opacity: 0.7;
|
||||
font-style: italic;
|
||||
margin-top: var(--space-1);
|
||||
text-align: center;
|
||||
}
|
||||
48
static/css/components/modal/delete-modal.css
Normal file
48
static/css/components/modal/delete-modal.css
Normal file
@@ -0,0 +1,48 @@
|
||||
/* Delete Modal specific styles */
|
||||
|
||||
.delete-message {
|
||||
color: var(--text-color);
|
||||
margin: var(--space-2) 0;
|
||||
}
|
||||
|
||||
/* Update delete modal styles */
|
||||
.delete-modal {
|
||||
display: none; /* Set initial display to none */
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background: rgba(0, 0, 0, 0.8);
|
||||
z-index: var(--z-overlay);
|
||||
}
|
||||
|
||||
/* Add new style for when modal is shown */
|
||||
.delete-modal.show {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.delete-modal-content {
|
||||
max-width: 500px;
|
||||
width: 90%;
|
||||
text-align: center;
|
||||
margin: 0 auto;
|
||||
position: relative;
|
||||
animation: modalFadeIn 0.2s ease-out;
|
||||
}
|
||||
|
||||
.delete-model-info,
|
||||
.exclude-model-info {
|
||||
/* Update info display styling */
|
||||
background: var(--lora-surface);
|
||||
border: 1px solid var(--lora-border);
|
||||
border-radius: var(--border-radius-sm);
|
||||
padding: var(--space-2);
|
||||
margin: var(--space-2) 0;
|
||||
color: var(--text-color);
|
||||
word-break: break-all;
|
||||
text-align: left;
|
||||
line-height: 1.5;
|
||||
}
|
||||
505
static/css/components/modal/download-modal.css
Normal file
505
static/css/components/modal/download-modal.css
Normal file
@@ -0,0 +1,505 @@
|
||||
/* Download Modal Styles */
|
||||
.input-group {
|
||||
margin-bottom: var(--space-2);
|
||||
}
|
||||
|
||||
.input-group label {
|
||||
display: block;
|
||||
margin-bottom: 8px;
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.input-group input,
|
||||
.input-group select {
|
||||
width: 100%;
|
||||
padding: 8px;
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-xs);
|
||||
background: var(--bg-color);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
/* Version List Styles */
|
||||
.version-list {
|
||||
max-height: 400px;
|
||||
overflow-y: auto;
|
||||
margin: var(--space-2) 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 12px;
|
||||
padding: 1px;
|
||||
}
|
||||
|
||||
.version-item {
|
||||
display: flex;
|
||||
gap: var(--space-2);
|
||||
padding: var(--space-2);
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-sm);
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
background: var(--bg-color);
|
||||
margin: 1px;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.version-item:hover {
|
||||
border-color: var(--lora-accent);
|
||||
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.version-item.selected {
|
||||
border: 2px solid var(--lora-accent);
|
||||
background: oklch(var(--lora-accent) / 0.05);
|
||||
}
|
||||
|
||||
.version-thumbnail {
|
||||
width: 80px;
|
||||
height: 80px;
|
||||
flex-shrink: 0;
|
||||
border-radius: var(--border-radius-xs);
|
||||
overflow: hidden;
|
||||
background: var(--bg-color);
|
||||
}
|
||||
|
||||
.version-thumbnail img {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
object-fit: cover;
|
||||
}
|
||||
|
||||
.version-content {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 8px;
|
||||
flex: 1;
|
||||
min-width: 0;
|
||||
}
|
||||
|
||||
.version-header {
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
justify-content: space-between;
|
||||
gap: var(--space-2);
|
||||
}
|
||||
|
||||
.version-content h3 {
|
||||
margin: 0;
|
||||
font-size: 1.1em;
|
||||
color: var(--text-color);
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.version-content .version-info {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
flex-direction: row !important;
|
||||
gap: 8px;
|
||||
align-items: center;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
.version-content .version-info .base-model {
|
||||
background: oklch(var(--lora-accent) / 0.1);
|
||||
color: var(--lora-accent);
|
||||
padding: 2px 8px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
}
|
||||
|
||||
.version-meta {
|
||||
display: flex;
|
||||
gap: 12px;
|
||||
font-size: 0.85em;
|
||||
color: var(--text-color);
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
.version-meta span {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.folder-item {
|
||||
padding: 8px;
|
||||
cursor: pointer;
|
||||
border-radius: var(--border-radius-xs);
|
||||
transition: background-color 0.2s;
|
||||
}
|
||||
|
||||
.folder-item:hover {
|
||||
background: var(--lora-surface);
|
||||
}
|
||||
|
||||
.folder-item.selected {
|
||||
background: oklch(var(--lora-accent) / 0.1);
|
||||
border: 1px solid var(--lora-accent);
|
||||
}
|
||||
|
||||
/* Path Input Styles */
|
||||
.path-input-container {
|
||||
position: relative;
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.path-input-container input {
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.create-folder-btn {
|
||||
padding: 8px;
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-xs);
|
||||
background: var(--bg-color);
|
||||
color: var(--text-color);
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 36px;
|
||||
height: 36px;
|
||||
}
|
||||
|
||||
.create-folder-btn:hover {
|
||||
border-color: var(--lora-accent);
|
||||
background: oklch(var(--lora-accent) / 0.05);
|
||||
}
|
||||
|
||||
.path-suggestions {
|
||||
position: absolute;
|
||||
top: 46%;
|
||||
left: 0;
|
||||
right: 0;
|
||||
z-index: 1000;
|
||||
margin: 0 24px;
|
||||
background: var(--bg-color);
|
||||
border: 1px solid var(--border-color);
|
||||
border-top: none;
|
||||
border-radius: 0 0 var(--border-radius-xs) var(--border-radius-xs);
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.path-suggestion {
|
||||
padding: 8px 12px;
|
||||
cursor: pointer;
|
||||
transition: background-color 0.2s;
|
||||
border-bottom: 1px solid var(--border-color);
|
||||
}
|
||||
|
||||
.path-suggestion:last-child {
|
||||
border-bottom: none;
|
||||
}
|
||||
|
||||
.path-suggestion:hover {
|
||||
background: var(--lora-surface);
|
||||
}
|
||||
|
||||
.path-suggestion.active {
|
||||
background: oklch(var(--lora-accent) / 0.1);
|
||||
color: var(--lora-accent);
|
||||
}
|
||||
|
||||
/* Breadcrumb Navigation Styles */
|
||||
.breadcrumb-nav {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
margin-bottom: var(--space-2);
|
||||
padding: var(--space-1);
|
||||
background: var(--lora-surface);
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
overflow-x: auto;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.breadcrumb-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
padding: 4px 8px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
color: var(--text-color);
|
||||
opacity: 0.7;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.breadcrumb-item:hover {
|
||||
background: var(--bg-color);
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.breadcrumb-item.active {
|
||||
background: oklch(var(--lora-accent) / 0.1);
|
||||
color: var(--lora-accent);
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.breadcrumb-separator {
|
||||
color: var(--text-color);
|
||||
opacity: 0.5;
|
||||
margin: 0 4px;
|
||||
}
|
||||
|
||||
/* Folder Tree Styles */
|
||||
.folder-tree-container {
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-xs);
|
||||
background: var(--bg-color);
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.folder-tree {
|
||||
padding: var(--space-1);
|
||||
}
|
||||
|
||||
.tree-node {
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.tree-node-content {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
padding: 4px 8px;
|
||||
cursor: pointer;
|
||||
border-radius: var(--border-radius-xs);
|
||||
transition: all 0.2s ease;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.tree-node-content:hover {
|
||||
background: var(--lora-surface);
|
||||
}
|
||||
|
||||
.tree-node-content.selected {
|
||||
background: oklch(var(--lora-accent) / 0.1);
|
||||
border: 1px solid var(--lora-accent);
|
||||
}
|
||||
|
||||
.tree-expand-icon {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
cursor: pointer;
|
||||
border-radius: 2px;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.tree-expand-icon:hover {
|
||||
background: var(--lora-surface);
|
||||
}
|
||||
|
||||
.tree-expand-icon.expanded {
|
||||
transform: rotate(90deg);
|
||||
}
|
||||
|
||||
.tree-folder-icon {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
color: var(--lora-accent);
|
||||
}
|
||||
|
||||
.tree-folder-name {
|
||||
flex: 1;
|
||||
font-size: 0.9em;
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.tree-children {
|
||||
margin-left: 20px;
|
||||
display: none;
|
||||
}
|
||||
|
||||
.tree-children.expanded {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.tree-node.has-children > .tree-node-content .tree-expand-icon {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.tree-node:not(.has-children) > .tree-node-content .tree-expand-icon {
|
||||
opacity: 0;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
/* Create folder inline form */
|
||||
.create-folder-form {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
margin-left: 20px;
|
||||
align-items: center;
|
||||
height: 21px;
|
||||
}
|
||||
|
||||
.create-folder-form input {
|
||||
flex: 1;
|
||||
padding: 4px 8px;
|
||||
border: 1px solid var(--lora-accent);
|
||||
border-radius: var(--border-radius-xs);
|
||||
background: var(--bg-color);
|
||||
color: var(--text-color);
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
.create-folder-form button {
|
||||
padding: 4px 8px;
|
||||
border: 1px solid var(--border-color);
|
||||
border-radius: var(--border-radius-xs);
|
||||
background: var(--bg-color);
|
||||
color: var(--text-color);
|
||||
cursor: pointer;
|
||||
font-size: 0.8em;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.create-folder-form button.confirm {
|
||||
background: var(--lora-accent);
|
||||
color: white;
|
||||
border-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
.create-folder-form button:hover {
|
||||
background: var(--lora-surface);
|
||||
}
|
||||
|
||||
/* Path Preview Styles */
|
||||
.path-preview {
|
||||
margin-bottom: var(--space-3);
|
||||
padding: var(--space-2);
|
||||
background: var(--bg-color);
|
||||
border-radius: var(--border-radius-sm);
|
||||
border: 1px dashed var(--border-color);
|
||||
}
|
||||
|
||||
.path-preview-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
margin-bottom: 12px;
|
||||
gap: var(--space-2);
|
||||
}
|
||||
|
||||
.path-preview-header label {
|
||||
margin: 0;
|
||||
color: var(--text-color);
|
||||
font-size: 0.9em;
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.path-display {
|
||||
padding: var(--space-1);
|
||||
color: var(--text-color);
|
||||
font-family: monospace;
|
||||
font-size: 0.9em;
|
||||
line-height: 1.4;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-all;
|
||||
opacity: 0.85;
|
||||
background: var(--lora-surface);
|
||||
border-radius: var(--border-radius-xs);
|
||||
}
|
||||
|
||||
/* Inline Toggle Styles */
|
||||
.inline-toggle-container {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.inline-toggle-label {
|
||||
font-size: 0.85em;
|
||||
color: var(--text-color);
|
||||
opacity: 0.9;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.inline-toggle-container .toggle-switch {
|
||||
position: relative;
|
||||
width: 36px;
|
||||
height: 18px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.inline-toggle-container .toggle-switch input {
|
||||
opacity: 0;
|
||||
width: 0;
|
||||
height: 0;
|
||||
position: absolute;
|
||||
}
|
||||
|
||||
.inline-toggle-container .toggle-slider {
|
||||
position: absolute;
|
||||
cursor: pointer;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: var(--border-color);
|
||||
transition: all 0.3s ease;
|
||||
border-radius: 18px;
|
||||
}
|
||||
|
||||
.inline-toggle-container .toggle-slider:before {
|
||||
position: absolute;
|
||||
content: "";
|
||||
height: 12px;
|
||||
width: 12px;
|
||||
left: 3px;
|
||||
bottom: 3px;
|
||||
background-color: white;
|
||||
transition: all 0.3s ease;
|
||||
border-radius: 50%;
|
||||
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.2);
|
||||
}
|
||||
|
||||
.inline-toggle-container .toggle-switch input:checked + .toggle-slider {
|
||||
background-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
.inline-toggle-container .toggle-switch input:checked + .toggle-slider:before {
|
||||
transform: translateX(18px);
|
||||
}
|
||||
|
||||
/* Dark theme adjustments */
|
||||
[data-theme="dark"] .version-item {
|
||||
background: var(--lora-surface);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .local-path {
|
||||
background: var(--lora-surface);
|
||||
border-color: var(--lora-border);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .toggle-slider:before {
|
||||
background-color: #f0f0f0;
|
||||
}
|
||||
|
||||
/* Enhance the local badge to make it more noticeable */
|
||||
.version-item.exists-locally {
|
||||
background: oklch(var(--lora-accent) / 0.05);
|
||||
border-left: 4px solid var(--lora-accent);
|
||||
}
|
||||
|
||||
.manual-path-selection.disabled {
|
||||
opacity: 0.5;
|
||||
pointer-events: none;
|
||||
user-select: none;
|
||||
}
|
||||
72
static/css/components/modal/example-access-modal.css
Normal file
72
static/css/components/modal/example-access-modal.css
Normal file
@@ -0,0 +1,72 @@
|
||||
/* Example Access Modal */
|
||||
.example-access-modal {
|
||||
max-width: 550px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.example-access-options {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-2);
|
||||
margin: var(--space-3) 0;
|
||||
}
|
||||
|
||||
.example-option-btn {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
padding: var(--space-2);
|
||||
border-radius: var(--border-radius-sm);
|
||||
border: 1px solid var(--lora-border);
|
||||
background-color: var(--lora-surface);
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.example-option-btn:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
|
||||
border-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
.example-option-btn i {
|
||||
font-size: 2em;
|
||||
margin-bottom: var(--space-1);
|
||||
color: var(--lora-accent);
|
||||
}
|
||||
|
||||
.option-title {
|
||||
font-weight: 500;
|
||||
margin-bottom: 4px;
|
||||
font-size: 1.1em;
|
||||
}
|
||||
|
||||
.option-desc {
|
||||
font-size: 0.9em;
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.example-option-btn.disabled {
|
||||
opacity: 0.5;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.example-option-btn.disabled i {
|
||||
color: var(--text-color);
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
.modal-footer-note {
|
||||
font-size: 0.9em;
|
||||
opacity: 0.7;
|
||||
margin-top: var(--space-2);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
/* Dark theme adjustments */
|
||||
[data-theme="dark"] .example-option-btn:hover {
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.25);
|
||||
}
|
||||
307
static/css/components/modal/help-modal.css
Normal file
307
static/css/components/modal/help-modal.css
Normal file
@@ -0,0 +1,307 @@
|
||||
/* Help Modal styles */
|
||||
.help-modal {
|
||||
max-width: 850px;
|
||||
}
|
||||
|
||||
.help-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-bottom: var(--space-2);
|
||||
}
|
||||
|
||||
.modal-help-icon {
|
||||
font-size: 24px;
|
||||
color: var(--lora-accent);
|
||||
margin-right: var(--space-2);
|
||||
vertical-align: text-bottom;
|
||||
}
|
||||
|
||||
/* Tab navigation styles */
|
||||
.help-tabs {
|
||||
display: flex;
|
||||
border-bottom: 1px solid var(--lora-border);
|
||||
margin-bottom: var(--space-2);
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.tab-btn {
|
||||
padding: 8px 16px;
|
||||
background: transparent;
|
||||
border: none;
|
||||
border-bottom: 2px solid transparent;
|
||||
color: var(--text-color);
|
||||
cursor: pointer;
|
||||
font-weight: 500;
|
||||
transition: all 0.2s;
|
||||
opacity: 0.7;
|
||||
}
|
||||
|
||||
.tab-btn:hover {
|
||||
background-color: rgba(0, 0, 0, 0.05);
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
.tab-btn.active {
|
||||
color: var(--lora-accent);
|
||||
border-bottom: 2px solid var(--lora-accent);
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* Add styles for tab with new content indicator */
|
||||
.tab-btn.has-new-content {
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.tab-btn.has-new-content::after {
|
||||
content: "";
|
||||
position: absolute;
|
||||
top: 4px;
|
||||
right: 4px;
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
background-color: var(--lora-accent);
|
||||
border-radius: 50%;
|
||||
animation: pulse 2s infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0% { opacity: 1; transform: scale(1); }
|
||||
50% { opacity: 0.7; transform: scale(1.1); }
|
||||
100% { opacity: 1; transform: scale(1); }
|
||||
}
|
||||
|
||||
/* Tab content styles */
|
||||
.help-content {
|
||||
padding: var(--space-1) 0;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.tab-pane {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.tab-pane.active {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.help-text {
|
||||
margin: var(--space-2) 0;
|
||||
}
|
||||
|
||||
.help-text ul {
|
||||
padding-left: 20px;
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
.help-text li {
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
/* Documentation link styles */
|
||||
.docs-section {
|
||||
margin-bottom: var(--space-3);
|
||||
}
|
||||
|
||||
.docs-section h4 {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
margin-bottom: var(--space-1);
|
||||
}
|
||||
|
||||
.docs-links {
|
||||
list-style-type: none;
|
||||
padding-left: var(--space-3);
|
||||
}
|
||||
|
||||
.docs-links li {
|
||||
margin-bottom: var(--space-1);
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.docs-links li:before {
|
||||
content: "•";
|
||||
position: absolute;
|
||||
left: -15px;
|
||||
color: var(--lora-accent);
|
||||
}
|
||||
|
||||
.docs-links a {
|
||||
color: var(--lora-accent);
|
||||
text-decoration: none;
|
||||
transition: color 0.2s;
|
||||
}
|
||||
|
||||
.docs-links a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
/* New content badge styles */
|
||||
.new-content-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
font-size: 0.7em;
|
||||
font-weight: 600;
|
||||
background-color: var(--lora-accent);
|
||||
color: var(--lora-text);
|
||||
padding: 2px 6px;
|
||||
border-radius: 10px;
|
||||
margin-left: 8px;
|
||||
vertical-align: middle;
|
||||
animation: fadeIn 0.5s ease-in-out;
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.2);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.new-content-badge.inline {
|
||||
font-size: 0.65em;
|
||||
padding: 1px 4px;
|
||||
margin-left: 6px;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
/* Dark theme adjustments for new content badge */
|
||||
[data-theme="dark"] .new-content-badge {
|
||||
box-shadow: 0 1px 3px rgba(0, 0, 0, 0.4);
|
||||
}
|
||||
|
||||
/* Update video list styles */
|
||||
.video-list {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-3);
|
||||
}
|
||||
|
||||
.video-item {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.video-info {
|
||||
padding: var(--space-1);
|
||||
}
|
||||
|
||||
.video-info h4 {
|
||||
margin-bottom: var(--space-1);
|
||||
}
|
||||
|
||||
.video-info p {
|
||||
font-size: 0.9em;
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
/* Dark theme adjustments */
|
||||
[data-theme="dark"] .tab-btn:hover {
|
||||
background-color: rgba(255, 255, 255, 0.05);
|
||||
}
|
||||
|
||||
/* Update date badge styles */
|
||||
.update-date-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
font-size: 0.75em;
|
||||
font-weight: 500;
|
||||
background-color: var(--lora-accent);
|
||||
color: var(--lora-text);
|
||||
padding: 4px 8px;
|
||||
border-radius: 12px;
|
||||
margin-left: 10px;
|
||||
vertical-align: middle;
|
||||
animation: fadeIn 0.5s ease-in-out;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.update-date-badge i {
|
||||
margin-right: 5px;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; transform: translateY(-5px); }
|
||||
to { opacity: 1; transform: translateY(0); }
|
||||
}
|
||||
|
||||
/* Dark theme adjustments */
|
||||
[data-theme="dark"] .update-date-badge {
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.3);
|
||||
}
|
||||
|
||||
/* Privacy-friendly video embed styles */
|
||||
.video-container {
|
||||
position: relative;
|
||||
width: 100%;
|
||||
padding-bottom: 56.25%; /* 16:9 aspect ratio */
|
||||
height: 0;
|
||||
margin-bottom: var(--space-2);
|
||||
border-radius: var(--border-radius-sm);
|
||||
overflow: hidden;
|
||||
background-color: rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
.video-thumbnail {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.video-thumbnail img {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
object-fit: cover;
|
||||
transition: filter 0.2s ease;
|
||||
}
|
||||
|
||||
.video-play-overlay {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
background-color: rgba(0, 0, 0, 0.5);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
|
||||
/* External link button styles */
|
||||
.external-link-btn {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
padding: 10px 20px;
|
||||
border-radius: var(--border-radius-sm);
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
background-color: var(--lora-accent);
|
||||
color: white;
|
||||
text-decoration: none;
|
||||
border: none;
|
||||
}
|
||||
|
||||
.external-link-btn:hover {
|
||||
background-color: oklch(from var(--lora-accent) l c h / 85%);
|
||||
}
|
||||
|
||||
.video-thumbnail i {
|
||||
font-size: 1.2em;
|
||||
}
|
||||
|
||||
/* Smaller video container for the updates tab */
|
||||
.video-item .video-container {
|
||||
padding-bottom: 40%; /* Shorter height for the playlist */
|
||||
}
|
||||
|
||||
/* Dark theme adjustments */
|
||||
[data-theme="dark"] .video-container {
|
||||
background-color: rgba(255, 255, 255, 0.03);
|
||||
}
|
||||
53
static/css/components/modal/relink-civitai-modal.css
Normal file
53
static/css/components/modal/relink-civitai-modal.css
Normal file
@@ -0,0 +1,53 @@
|
||||
/* Re-link to Civitai Modal styles */
|
||||
.warning-box {
|
||||
background-color: rgba(255, 193, 7, 0.1);
|
||||
border: 1px solid rgba(255, 193, 7, 0.5);
|
||||
border-radius: var(--border-radius-sm);
|
||||
padding: var(--space-2);
|
||||
margin-bottom: var(--space-3);
|
||||
}
|
||||
|
||||
.warning-box i {
|
||||
color: var(--lora-warning);
|
||||
margin-right: var(--space-1);
|
||||
}
|
||||
|
||||
.warning-box ul {
|
||||
padding-left: 20px;
|
||||
margin: var(--space-1) 0;
|
||||
}
|
||||
|
||||
.warning-box li {
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
.input-group {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
margin-bottom: var(--space-2);
|
||||
}
|
||||
|
||||
.input-group label {
|
||||
margin-bottom: var(--space-1);
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.input-group input {
|
||||
padding: 8px 12px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.input-error {
|
||||
color: var(--lora-error);
|
||||
font-size: 0.9em;
|
||||
min-height: 20px;
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .warning-box {
|
||||
background-color: rgba(255, 193, 7, 0.05);
|
||||
border-color: rgba(255, 193, 7, 0.3);
|
||||
}
|
||||
580
static/css/components/modal/settings-modal.css
Normal file
580
static/css/components/modal/settings-modal.css
Normal file
@@ -0,0 +1,580 @@
|
||||
/* Settings styles */
|
||||
.settings-toggle {
|
||||
width: 36px;
|
||||
height: 36px;
|
||||
border-radius: 50%;
|
||||
background: var(--card-bg);
|
||||
border: 1px solid var(--border-color);
|
||||
color: var(--text-color);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.settings-toggle:hover {
|
||||
background: var(--lora-accent);
|
||||
color: white;
|
||||
transform: translateY(-2px);
|
||||
}
|
||||
|
||||
.settings-modal {
|
||||
max-width: 650px; /* Further increased from 600px for more space */
|
||||
}
|
||||
|
||||
/* Settings Links */
|
||||
.settings-links {
|
||||
margin-top: var(--space-3);
|
||||
padding-top: var(--space-2);
|
||||
border-top: 1px solid var(--lora-border);
|
||||
display: flex;
|
||||
gap: var(--space-2);
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.settings-link {
|
||||
width: 36px;
|
||||
height: 36px;
|
||||
border-radius: 50%;
|
||||
background: var(--card-bg);
|
||||
border: 1px solid var(--border-color);
|
||||
color: var(--text-color);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease;
|
||||
text-decoration: none;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.settings-link:hover {
|
||||
background: var(--lora-accent);
|
||||
color: white;
|
||||
transform: translateY(-2px);
|
||||
}
|
||||
|
||||
.settings-link i {
|
||||
font-size: 1.1em;
|
||||
}
|
||||
|
||||
/* Tooltip styles */
|
||||
.settings-link::after {
|
||||
content: attr(title);
|
||||
position: absolute;
|
||||
bottom: calc(100% + 8px);
|
||||
left: 50%;
|
||||
transform: translateX(-50%);
|
||||
background: rgba(0, 0, 0, 0.8);
|
||||
color: white;
|
||||
padding: 4px 8px;
|
||||
border-radius: 4px;
|
||||
font-size: 0.8em;
|
||||
white-space: nowrap;
|
||||
opacity: 0;
|
||||
visibility: hidden;
|
||||
transition: opacity 0.2s, visibility 0.2s;
|
||||
pointer-events: none;
|
||||
}
|
||||
|
||||
.settings-link:hover::after {
|
||||
opacity: 1;
|
||||
visibility: visible;
|
||||
}
|
||||
|
||||
/* Responsive adjustment */
|
||||
@media (max-width: 480px) {
|
||||
.settings-links {
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
}
|
||||
|
||||
/* API key input specific styles */
|
||||
.api-key-input {
|
||||
width: 100%; /* Take full width of parent */
|
||||
position: relative;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.api-key-input input {
|
||||
width: 100%;
|
||||
padding: 6px 40px 6px 10px; /* Add left padding */
|
||||
height: 32px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.api-key-input .toggle-visibility {
|
||||
position: absolute;
|
||||
right: 8px;
|
||||
background: none;
|
||||
border: none;
|
||||
color: var(--text-color);
|
||||
opacity: 0.6;
|
||||
cursor: pointer;
|
||||
padding: 4px 8px;
|
||||
}
|
||||
|
||||
.api-key-input .toggle-visibility:hover {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.input-help {
|
||||
font-size: 0.85em;
|
||||
color: var(--text-color);
|
||||
opacity: 0.7;
|
||||
margin-top: 8px; /* Space between control and help */
|
||||
line-height: 1.4;
|
||||
width: 100%; /* Full width */
|
||||
}
|
||||
|
||||
/* Settings Styles */
|
||||
.settings-section {
|
||||
margin-top: var(--space-3);
|
||||
border-top: 1px solid var(--lora-border);
|
||||
padding-top: var(--space-2);
|
||||
}
|
||||
|
||||
.settings-section h3 {
|
||||
font-size: 1.1em;
|
||||
margin-bottom: var(--space-2);
|
||||
color: var(--text-color);
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
.setting-item {
|
||||
display: flex;
|
||||
flex-direction: column; /* Changed to column for help text placement */
|
||||
margin-bottom: var(--space-3); /* Increased to provide more spacing between items */
|
||||
padding: var(--space-1);
|
||||
border-radius: var(--border-radius-xs);
|
||||
}
|
||||
|
||||
.setting-item:hover {
|
||||
background: rgba(0, 0, 0, 0.02);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .setting-item:hover {
|
||||
background: rgba(255, 255, 255, 0.05);
|
||||
}
|
||||
|
||||
/* Control row with label and input together */
|
||||
.setting-row {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.setting-info {
|
||||
margin-bottom: 0;
|
||||
width: 35%; /* Increased from 30% to prevent wrapping */
|
||||
flex-shrink: 0; /* Prevent shrinking */
|
||||
}
|
||||
|
||||
.setting-info label {
|
||||
display: block;
|
||||
font-weight: 500;
|
||||
margin-bottom: 0;
|
||||
white-space: nowrap; /* Prevent label wrapping */
|
||||
}
|
||||
|
||||
.setting-control {
|
||||
width: 60%; /* Decreased slightly from 65% */
|
||||
margin-bottom: 0;
|
||||
display: flex;
|
||||
justify-content: flex-end; /* Right-align all controls */
|
||||
}
|
||||
|
||||
/* Select Control Styles */
|
||||
.select-control {
|
||||
width: 100%;
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
}
|
||||
|
||||
.select-control select {
|
||||
width: 100%;
|
||||
max-width: 100%; /* Increased from 200px */
|
||||
padding: 6px 10px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
font-size: 0.95em;
|
||||
height: 32px;
|
||||
}
|
||||
|
||||
/* Fix dark theme select dropdown text color */
|
||||
[data-theme="dark"] .select-control select {
|
||||
background-color: rgba(30, 30, 30, 0.9);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .select-control select option {
|
||||
background-color: #2d2d2d;
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
.select-control select:focus {
|
||||
border-color: var(--lora-accent);
|
||||
outline: none;
|
||||
}
|
||||
|
||||
/* Toggle Switch */
|
||||
.toggle-switch {
|
||||
position: relative;
|
||||
display: inline-block;
|
||||
width: 50px;
|
||||
height: 24px;
|
||||
cursor: pointer;
|
||||
margin-left: auto; /* Push to right side */
|
||||
}
|
||||
|
||||
.toggle-switch input {
|
||||
opacity: 0;
|
||||
width: 0;
|
||||
height: 0;
|
||||
}
|
||||
|
||||
.toggle-slider {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: var(--border-color);
|
||||
transition: .3s;
|
||||
border-radius: 24px;
|
||||
}
|
||||
|
||||
.toggle-slider:before {
|
||||
position: absolute;
|
||||
content: "";
|
||||
height: 18px;
|
||||
width: 18px;
|
||||
left: 3px;
|
||||
bottom: 3px;
|
||||
background-color: white;
|
||||
transition: .3s;
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider {
|
||||
background-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
input:checked + .toggle-slider:before {
|
||||
transform: translateX(26px);
|
||||
}
|
||||
|
||||
.toggle-label {
|
||||
margin-left: 60px;
|
||||
line-height: 24px;
|
||||
}
|
||||
|
||||
/* Add small animation for the toggle */
|
||||
.toggle-slider:active:before {
|
||||
width: 22px;
|
||||
}
|
||||
|
||||
/* Blur effect for NSFW content */
|
||||
.nsfw-blur {
|
||||
filter: blur(12px);
|
||||
transition: filter 0.3s ease;
|
||||
}
|
||||
|
||||
.nsfw-blur:hover {
|
||||
filter: blur(8px);
|
||||
}
|
||||
|
||||
/* Example Images Settings Styles */
|
||||
.download-buttons {
|
||||
justify-content: flex-start;
|
||||
gap: var(--space-2);
|
||||
}
|
||||
|
||||
.path-control {
|
||||
display: flex;
|
||||
gap: 8px;
|
||||
align-items: center;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.path-control input[type="text"] {
|
||||
flex: 1;
|
||||
padding: 6px 10px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var (--text-color);
|
||||
font-size: 0.95em;
|
||||
height: 32px;
|
||||
}
|
||||
|
||||
/* Add warning text style for settings */
|
||||
.warning-text {
|
||||
color: var(--lora-warning, #e67e22);
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .warning-text {
|
||||
color: var(--lora-warning, #f39c12);
|
||||
}
|
||||
|
||||
/* Add styles for list description */
|
||||
.list-description {
|
||||
margin: 8px 0;
|
||||
padding-left: 20px;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
.list-description li {
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
/* Path Template Settings Styles */
|
||||
.template-preview {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
border-radius: var(--border-radius-xs);
|
||||
padding: var(--space-1);
|
||||
margin-top: 8px;
|
||||
font-family: monospace;
|
||||
font-size: 1.1em;
|
||||
color: var(--lora-accent);
|
||||
display: none;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .template-preview {
|
||||
background: rgba(255, 255, 255, 0.03);
|
||||
border: 1px solid var(--lora-border);
|
||||
}
|
||||
|
||||
.template-preview:before {
|
||||
content: "Preview: ";
|
||||
opacity: 0.7;
|
||||
color: var(--text-color);
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
/* Base Model Mappings Styles - Updated to match other settings */
|
||||
.mappings-container {
|
||||
border: 1px solid var(--lora-border);
|
||||
border-radius: var(--border-radius-sm);
|
||||
padding: var(--space-2);
|
||||
background: rgba(0, 0, 0, 0.02);
|
||||
margin-top: 8px; /* Add consistent spacing */
|
||||
}
|
||||
|
||||
[data-theme="dark"] .mappings-container {
|
||||
background: rgba(255, 255, 255, 0.02);
|
||||
}
|
||||
|
||||
.add-mapping-btn {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 6px 12px;
|
||||
background: var(--lora-accent);
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: var(--border-radius-xs);
|
||||
cursor: pointer;
|
||||
font-size: 0.9em;
|
||||
transition: all 0.2s;
|
||||
height: 32px; /* Match other control heights */
|
||||
}
|
||||
|
||||
.add-mapping-btn:hover {
|
||||
background: oklch(from var(--lora-accent) l c h / 85%);
|
||||
}
|
||||
|
||||
.mapping-row {
|
||||
margin-bottom: var(--space-2);
|
||||
}
|
||||
|
||||
.mapping-row:last-child {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.mapping-controls {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr auto;
|
||||
gap: var(--space-1);
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.base-model-select,
|
||||
.path-value-input {
|
||||
padding: 6px 10px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
font-size: 0.9em;
|
||||
height: 32px;
|
||||
}
|
||||
|
||||
.path-value-input {
|
||||
height: 18px;
|
||||
}
|
||||
|
||||
.base-model-select:focus,
|
||||
.path-value-input:focus {
|
||||
border-color: var(--lora-accent);
|
||||
outline: none;
|
||||
box-shadow: 0 0 0 2px rgba(var(--lora-accent-rgb, 79, 70, 229), 0.1);
|
||||
}
|
||||
|
||||
.remove-mapping-btn {
|
||||
width: 32px;
|
||||
height: 32px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--lora-error);
|
||||
background: transparent;
|
||||
color: var(--lora-error);
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.remove-mapping-btn:hover {
|
||||
background: var(--lora-error);
|
||||
color: white;
|
||||
}
|
||||
|
||||
.mapping-empty-state {
|
||||
text-align: center;
|
||||
padding: var(--space-3);
|
||||
color: var(--text-color);
|
||||
opacity: 0.6;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
/* Responsive adjustments for mapping controls */
|
||||
@media (max-width: 768px) {
|
||||
.mapping-controls {
|
||||
grid-template-columns: 1fr;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.remove-mapping-btn {
|
||||
width: 100%;
|
||||
height: 36px;
|
||||
justify-self: stretch;
|
||||
}
|
||||
}
|
||||
|
||||
/* Dark theme specific adjustments */
|
||||
[data-theme="dark"] .base-model-select,
|
||||
[data-theme="dark"] .path-value-input {
|
||||
background-color: rgba(30, 30, 30, 0.9);
|
||||
}
|
||||
|
||||
[data-theme="dark"] .base-model-select option {
|
||||
background-color: #2d2d2d;
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
/* Template Configuration Styles */
|
||||
.placeholder-info {
|
||||
margin-top: var(--space-1);
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: center;
|
||||
gap: var(--space-1);
|
||||
}
|
||||
|
||||
.placeholder-tag {
|
||||
display: inline-block;
|
||||
background: var(--lora-accent);
|
||||
color: white;
|
||||
padding: 2px 6px;
|
||||
border-radius: 3px;
|
||||
font-family: monospace;
|
||||
font-size: 1em;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.template-custom-row {
|
||||
margin-top: 8px;
|
||||
animation: slideDown 0.2s ease-out;
|
||||
}
|
||||
|
||||
@keyframes slideDown {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(-10px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
.template-custom-input {
|
||||
width: 96%;
|
||||
padding: 6px 10px;
|
||||
border-radius: var(--border-radius-xs);
|
||||
border: 1px solid var(--border-color);
|
||||
background-color: var(--lora-surface);
|
||||
color: var(--text-color);
|
||||
font-size: 0.95em;
|
||||
font-family: monospace;
|
||||
height: 24px;
|
||||
transition: border-color 0.2s;
|
||||
}
|
||||
|
||||
.template-custom-input:focus {
|
||||
border-color: var(--lora-accent);
|
||||
outline: none;
|
||||
box-shadow: 0 0 0 2px rgba(var(--lora-accent-rgb, 79, 70, 229), 0.1);
|
||||
}
|
||||
|
||||
.template-custom-input::placeholder {
|
||||
color: var(--text-color);
|
||||
opacity: 0.5;
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
.template-validation {
|
||||
margin-top: 6px;
|
||||
font-size: 0.85em;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
min-height: 20px;
|
||||
}
|
||||
|
||||
.template-validation.valid {
|
||||
color: var(--lora-success, #22c55e);
|
||||
}
|
||||
|
||||
.template-validation.invalid {
|
||||
color: var(--lora-error, #ef4444);
|
||||
}
|
||||
|
||||
.template-validation i {
|
||||
width: 12px;
|
||||
}
|
||||
|
||||
/* Dark theme specific adjustments */
|
||||
[data-theme="dark"] .template-custom-input {
|
||||
background-color: rgba(30, 30, 30, 0.9);
|
||||
}
|
||||
|
||||
/* Responsive adjustments */
|
||||
@media (max-width: 768px) {
|
||||
.placeholder-info {
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
}
|
||||
}
|
||||
124
static/css/components/modal/update-modal.css
Normal file
124
static/css/components/modal/update-modal.css
Normal file
@@ -0,0 +1,124 @@
|
||||
/* Update Modal specific styles */
|
||||
.update-actions {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-2);
|
||||
align-items: stretch;
|
||||
flex-wrap: nowrap;
|
||||
}
|
||||
|
||||
.update-link {
|
||||
color: var(--lora-accent);
|
||||
text-decoration: none;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
font-size: 0.95em;
|
||||
}
|
||||
|
||||
.update-link:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
/* Update progress styles */
|
||||
.update-progress {
|
||||
background: rgba(0, 0, 0, 0.03);
|
||||
border: 1px solid var(--lora-border);
|
||||
border-radius: var(--border-radius-sm);
|
||||
padding: var(--space-2);
|
||||
margin: var(--space-2) 0;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .update-progress {
|
||||
background: rgba(255, 255, 255, 0.03);
|
||||
}
|
||||
|
||||
.progress-info {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--space-1);
|
||||
}
|
||||
|
||||
.progress-text {
|
||||
font-size: 0.9em;
|
||||
color: var(--text-color);
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.update-progress-bar {
|
||||
width: 100%;
|
||||
height: 8px;
|
||||
background-color: rgba(0, 0, 0, 0.1);
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .update-progress-bar {
|
||||
background-color: rgba(255, 255, 255, 0.1);
|
||||
}
|
||||
|
||||
.progress-fill {
|
||||
height: 100%;
|
||||
background-color: var(--lora-accent);
|
||||
width: 0%;
|
||||
transition: width 0.3s ease;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
/* Update button states */
|
||||
#updateBtn {
|
||||
min-width: 120px;
|
||||
}
|
||||
|
||||
#updateBtn.updating {
|
||||
background-color: var(--lora-warning);
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
#updateBtn.success {
|
||||
background-color: var(--lora-success);
|
||||
}
|
||||
|
||||
#updateBtn.error {
|
||||
background-color: var(--lora-error);
|
||||
}
|
||||
|
||||
/* Add styles for markdown elements in changelog */
|
||||
.changelog-item ul {
|
||||
padding-left: 20px;
|
||||
margin-top: 8px;
|
||||
}
|
||||
|
||||
.changelog-item li {
|
||||
margin-bottom: 6px;
|
||||
line-height: 1.4;
|
||||
}
|
||||
|
||||
.changelog-item strong {
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.changelog-item em {
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.changelog-item code {
|
||||
background: rgba(0, 0, 0, 0.05);
|
||||
padding: 2px 4px;
|
||||
border-radius: 3px;
|
||||
font-family: monospace;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
[data-theme="dark"] .changelog-item code {
|
||||
background: rgba(255, 255, 255, 0.1);
|
||||
}
|
||||
|
||||
.changelog-item a {
|
||||
color: var(--lora-accent);
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.changelog-item a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
@@ -445,69 +445,6 @@
|
||||
border-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
/* Switch styles */
|
||||
.search-option-switch {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 4px 0;
|
||||
}
|
||||
|
||||
.switch {
|
||||
position: relative;
|
||||
display: inline-block;
|
||||
width: 46px;
|
||||
height: 24px;
|
||||
}
|
||||
|
||||
.switch input {
|
||||
opacity: 0;
|
||||
width: 0;
|
||||
height: 0;
|
||||
}
|
||||
|
||||
.slider {
|
||||
position: absolute;
|
||||
cursor: pointer;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background-color: #ccc;
|
||||
transition: .4s;
|
||||
}
|
||||
|
||||
.slider:before {
|
||||
position: absolute;
|
||||
content: "";
|
||||
height: 18px;
|
||||
width: 18px;
|
||||
left: 3px;
|
||||
bottom: 3px;
|
||||
background-color: white;
|
||||
transition: .4s;
|
||||
}
|
||||
|
||||
input:checked + .slider {
|
||||
background-color: var(--lora-accent);
|
||||
}
|
||||
|
||||
input:focus + .slider {
|
||||
box-shadow: 0 0 1px var(--lora-accent);
|
||||
}
|
||||
|
||||
input:checked + .slider:before {
|
||||
transform: translateX(22px);
|
||||
}
|
||||
|
||||
.slider.round {
|
||||
border-radius: 34px;
|
||||
}
|
||||
|
||||
.slider.round:before {
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
/* Mobile adjustments */
|
||||
@media (max-width: 768px) {
|
||||
.search-options-panel,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user