Compare commits

..

1 Commits

Author SHA1 Message Date
Will Miao
68c5f79a67 Refactor showcase and modal components for improved functionality and performance
- Removed unused showcase toggle functionality from ModelCard and ModelModal.
- Simplified metadata panel handling in MediaUtils and MetadataPanel, transitioning to button-based visibility instead of hover.
- Enhanced showcase rendering logic in ShowcaseView to support new layout and navigation features.
- Updated event handling for media controls and thumbnail navigation to streamline user interactions.
- Improved example image import functionality and error handling.
- Cleaned up redundant code and comments across various components for better readability and maintainability.
2025-07-27 15:52:09 +08:00
285 changed files with 12356 additions and 45812 deletions

View File

@@ -1 +0,0 @@
Always use English for comments.

3
.gitignore vendored
View File

@@ -5,6 +5,3 @@ output/*
py/run_test.py py/run_test.py
.vscode/ .vscode/
cache/ cache/
civitai/
node_modules/
coverage/

View File

@@ -1,19 +0,0 @@
# Repository Guidelines
## Project Structure & Module Organization
ComfyUI LoRA Manager pairs a Python backend with lightweight browser scripts. Backend modules live in `py/`, organized by responsibility: HTTP entry points under `routes/`, feature logic in `services/`, reusable helpers within `utils/`, and custom nodes in `nodes/`. Front-end widgets that extend the ComfyUI interface sit in `web/comfyui/`, while static images and templates are in `static/` and `templates/`. Shared localization files are stored in `locales/`, with workflow examples under `example_workflows/`. Tests currently reside alongside the source (`test_i18n.py`) until a dedicated `tests/` folder is introduced.
## Build, Test, and Development Commands
Install dependencies with `pip install -r requirements.txt` from the repo root. Launch the standalone server for iterative work via `python standalone.py --port 8188`; ComfyUI users can also load the extension directly through ComfyUI's custom node manager. Run backend checks with `python -m pytest test_i18n.py`, and target new test files explicitly (e.g. `python -m pytest tests/test_recipes.py` once added). Use `python scripts/sync_translation_keys.py` to reconcile locale keys after updating UI strings.
## Coding Style & Naming Conventions
Follow PEP 8 with four-space indentation and descriptive snake_case module/function names, mirroring files such as `py/services/settings_manager.py`. Classes remain PascalCase, constants UPPER_SNAKE_CASE, and loggers retrieved via `logging.getLogger(__name__)`. Prefer explicit type hints for new public APIs and docstrings that clarify side effects. JavaScript in `web/comfyui/` is modern ES modules; keep imports relative, favor camelCase functions, and mirror existing file suffixes like `_widget.js` for UI components.
## Testing Guidelines
Extend pytest coverage by co-locating tests near the code under test or in `tests/` with names like `test_<feature>.py`. When introducing new routes or services, add regression cases that mock ComfyUI dependencies (see the standalone mocking helpers in `standalone.py`). Prioritize deterministic fixtures for filesystem interactions and ensure translations include coverage when adding new locale keys. Always run `python -m pytest` before submitting work.
## Commit & Pull Request Guidelines
Commits follow the conventional pattern seen in `git log` (`feat(scope):`, `fix(scope):`, `chore(scope):`). Keep messages imperative and scoped to a single change. Pull requests should summarize the problem, detail the solution, list manual test evidence, and link any GitHub issues. Include UI screenshots or GIFs when front-end behavior changes, and call out migration steps (e.g., settings updates) in the PR description.
## Configuration & Localization Tips
Sample configuration defaults live in `settings.json.example`; copy it to `settings.json` and adjust model directories before running the standalone server. Whenever you add UI text, update `locales/<lang>.json` and run the translation sync script. Store reference assets in `civitai/` or `docs/` rather than mixing them with production templates, keeping the runtime folders (`static/`, `templates/`) deploy-ready.

110
README.md
View File

@@ -34,53 +34,79 @@ Enhance your Civitai browsing experience with our companion browser extension! S
## Release Notes ## Release Notes
### v0.9.3 ### v0.8.20
* **Metadata Archive Database Support** - Added the ability to download and utilize a metadata archive database, enabling access to metadata for models that have been deleted from CivitAI. * **LM Civitai Extension** - Released [browser extension through Chrome Web Store](https://chromewebstore.google.com/detail/lm-civitai-extension/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb) that works seamlessly with LoRA Manager to enhance Civitai browsing experience, showing which models are already in your local library, enabling one-click downloads, and providing queue and parallel download support
* **App-Level Proxy Settings** - Introduced support for configuring a global proxy within the application, making it easier to use the manager behind network restrictions. * **Enhanced Lora Loader** - Added support for nunchaku, improving convenience when working with ComfyUI-nunchaku workflows, plus new template workflows for quick onboarding
* **Bug Fixes** - Various bug fixes for improved stability and reliability. * **WanVideo Integration** - Introduced WanVideo Lora Select (LoraManager) node compatible with ComfyUI-WanVideoWrapper for streamlined lora usage in video workflows, including a template workflow to help you get started quickly
### v0.9.2 ### v0.8.19
* **Bulk Auto-Organization Action** - Added a new bulk auto-organization feature. You can now select multiple models and automatically organize them according to your current path template settings for streamlined management. * **Analytics Dashboard** - Added new Statistics page providing comprehensive visual analysis of model collection and usage patterns for better library insights
* **Bug Fixes** - Addressed several bugs to improve stability and reliability. * **Target Node Selection** - Enhanced workflow integration with intelligent target choosing when sending LoRAs/recipes to workflows with multiple loader/stacker nodes; a visual selector now appears showing node color, type, ID, and title for precise targeting
* **Enhanced NSFW Controls** - Added support for setting NSFW levels on recipes with automatic content blurring based on user preferences
* **Customizable Card Display** - New display settings allowing users to choose whether card information and action buttons are always visible or only revealed on hover
* **Expanded Compatibility** - Added support for efficiency-nodes-comfyui in Save Recipe and Save Image nodes, plus fixed compatibility with ComfyUI_Custom_Nodes_AlekPet
### v0.9.1 ### v0.8.18
* **Enhanced Bulk Operations** - Improved bulk operations with Marquee Selection and a bulk operation context menu, providing a more intuitive, desktop-application-like user experience. * **Custom Example Images** - Added ability to import your own example images for LoRAs and checkpoints with automatic metadata extraction from embedded information
* **New Bulk Actions** - Added bulk operations for adding tags and setting base models to multiple models simultaneously. * **Enhanced Example Management** - New action buttons to set specific examples as previews or delete custom examples
* **Improved Duplicate Detection** - Enhanced "Find Duplicates" with hash verification feature to eliminate false positives when identifying duplicate models
* **Tag Management** - Added tag editing functionality allowing users to customize and manage model tags
* **Advanced Selection Controls** - Implemented Ctrl+A shortcut for quickly selecting all filtered LoRAs, automatically entering bulk mode when needed
* **Note**: Cache file functionality temporarily disabled pending rework
### v0.9.0 ### v0.8.17
* **UI Overhaul for Enhanced Navigation** - Replaced the top flat folder tags with a new folder sidebar and breadcrumb navigation system for a more intuitive folder browsing and selection experience. * **Duplicate Model Detection** - Added "Find Duplicates" functionality for LoRAs and checkpoints using model file hash detection, enabling convenient viewing and batch deletion of duplicate models
* **Dual-Mode Folder Sidebar** - The new folder sidebar offers two display modes: 'List Mode,' which mirrors the classic folder view, and 'Tree Mode,' which presents a hierarchical folder structure for effortless navigation through nested directories. * **Enhanced URL Recipe Imports** - Optimized import recipe via URL functionality using CivitAI API calls instead of web scraping, now supporting all rated images (including NSFW) for recipe imports
* **Internationalization Support** - Introduced multi-language support, now available in English, Simplified Chinese, Traditional Chinese, Spanish, Japanese, Korean, French, Russian, and German. Feedback from native speakers is welcome to improve the translations. * **Improved TriggerWord Control** - Enhanced TriggerWord Toggle node with new default_active switch to set the initial state (active/inactive) when trigger words are added
* **Automatic Filename Conflict Resolution** - Implemented automatic file renaming (`original name + short hash`) to prevent conflicts when downloading or moving models. * **Centralized Example Management** - Added "Migrate Existing Example Images" feature to consolidate downloaded example images from model folders into central storage with customizable naming patterns
* **Performance Optimizations & Bug Fixes** - Various performance improvements and bug fixes for a more stable and responsive experience. * **Intelligent Word Suggestions** - Implemented smart trigger word suggestions by reading class tokens and tag frequency from safetensors files, displaying recommendations when editing trigger words
* **Model Version Management** - Added "Re-link to CivitAI" context menu option for connecting models to different CivitAI versions when needed
### v0.8.30 ### v0.8.16
* **Automatic Model Path Correction** - Added auto-correction for model paths in built-in nodes such as Load Checkpoint, Load Diffusion Model, Load LoRA, and other custom nodes with similar functionality. Workflows containing outdated or incorrect model paths will now be automatically updated to reflect the current location of your models. * **Dramatic Startup Speed Improvement** - Added cache serialization mechanism for significantly faster loading times, especially beneficial for large model collections
* **Node UI Enhancements** - Improved node interface for a smoother and more intuitive user experience. * **Enhanced Refresh Options** - Extended functionality with "Full Rebuild (complete)" option alongside "Quick Refresh (incremental)" to fix potential memory cache issues without requiring application restart
* **Bug Fixes** - Addressed various bugs to enhance stability and reliability. * **Customizable Display Density** - Replaced compact mode with adjustable display density settings for personalized layout customization
* **Model Creator Information** - Added creator details to model information panels for better attribution
* **Improved WebP Support** - Enhanced Save Image node with workflow embedding capability for WebP format images
* **Direct Example Access** - Added "Open Example Images Folder" button to card interfaces for convenient browsing of downloaded model examples
* **Enhanced Compatibility** - Full ComfyUI Desktop support for "Send lora or recipe to workflow" functionality
* **Cache Management** - Added settings to clear existing cache files when needed
* **Bug Fixes & Stability** - Various improvements for overall reliability and performance
### v0.8.29 ### v0.8.15
* **Enhanced Recipe Imports** - Improved recipe importing with new target folder selection, featuring path input autocomplete and interactive folder tree navigation. Added a "Use Default Path" option when downloading missing LoRAs. * **Enhanced One-Click Integration** - Replaced copy button with direct send button allowing LoRAs/recipes to be sent directly to your current ComfyUI workflow without needing to paste
* **WanVideo Lora Select Node Update** - Updated the WanVideo Lora Select node with a 'merge_loras' option to match the counterpart node in the WanVideoWrapper node package. * **Flexible Workflow Integration** - Click to append LoRAs/recipes to existing loader nodes or Shift+click to replace content, with additional right-click menu options for "Send to Workflow (Append)" or "Send to Workflow (Replace)"
* **Autocomplete Conflict Resolution** - Resolved an autocomplete feature conflict in LoRA nodes with pysssss autocomplete. * **Improved LoRA Loader Controls** - Added header drag functionality for proportional strength adjustment of all LoRAs simultaneously (including CLIP strengths when expanded)
* **Improved Download Functionality** - Enhanced download functionality with resumable downloads and improved error handling. * **Keyboard Navigation Support** - Implemented Page Up/Down for page scrolling, Home key to jump to top, and End key to jump to bottom for faster browsing through large collections
* **Bug Fixes** - Addressed several bugs for improved stability and performance.
### v0.8.28 ### v0.8.14
* **Autocomplete for Node Inputs** - Instantly find and add LoRAs by filename directly in Lora Loader, Lora Stacker, and WanVideo Lora Select nodes. Autocomplete suggestions include preview tooltips and preset weights, allowing you to quickly select LoRAs without opening the LoRA Manager UI. * **Virtualized Scrolling** - Completely rebuilt rendering mechanism for smooth browsing with no lag or freezing, now supporting virtually unlimited model collections with optimized layouts for large displays, improving space utilization and user experience
* **Duplicate Notification Control** - Added a switch to duplicates mode, enabling users to turn off duplicate model notifications for a more streamlined experience. * **Compact Display Mode** - Added space-efficient view option that displays more cards per row (7 on 1080p, 8 on 2K, 10 on 4K)
* **Download Example Images from Context Menu** - Introduced a new context menu option to download example images for individual models. * **Enhanced LoRA Node Functionality** - Comprehensive improvements to LoRA loader/stacker nodes including real-time trigger word updates (reflecting any change anywhere in the LoRA chain for precise updates) and expanded context menu with "Copy Notes" and "Copy Trigger Words" options for faster workflow
### v0.8.27 ### v0.8.13
* **User Experience Enhancements** - Improved the model download target folder selection with path input autocomplete and interactive folder tree navigation, making it easier and faster to choose where models are saved. * **Enhanced Recipe Management** - Added "Find duplicates" feature to identify and batch delete duplicate recipes with duplicate detection notifications during imports
* **Default Path Option for Downloads** - Added a "Use Default Path" option when downloading models. When enabled, models are automatically organized and stored according to your configured path template settings. * **Improved Source Tracking** - Source URLs are now saved with recipes imported via URL, allowing users to view original content with one click or manually edit links
* **Advanced Download Path Templates** - Expanded path template settings, allowing users to set individual templates for LoRA, checkpoint, and embedding models for greater flexibility. Introduced the `{author}` placeholder, enabling automatic organization of model files by creator name. * **Advanced LoRA Control** - Double-click LoRAs in Loader/Stacker nodes to access expanded CLIP strength controls for more precise adjustments of model and CLIP strength separately
* **Bug Fixes & Stability Improvements** - Addressed various bugs and improved overall stability for a smoother experience. * **Lycoris Model Support** - Added compatibility with Lycoris models for expanded creative options
* **Bug Fixes & UX Improvements** - Resolved various issues and enhanced overall user experience with numerous optimizations
### v0.8.26 ### v0.8.12
* **Creator Search Option** - Added ability to search models by creator name, making it easier to find models from specific authors. * **Enhanced Model Discovery** - Added alphabetical navigation bar to LoRAs page for faster browsing through large collections
* **Enhanced Node Usability** - Improved user experience for Lora Loader, Lora Stacker, and WanVideo Lora Select nodes by fixing the maximum height of the text input area. Users can now freely and conveniently adjust the LoRA region within these nodes. * **Optimized Example Images** - Improved download logic to automatically refresh stale metadata before fetching example images
* **Compatibility Fixes** - Resolved compatibility issues with ComfyUI and certain custom nodes, including ComfyUI-Custom-Scripts, ensuring smoother integration and operation. * **Model Exclusion System** - New right-click option to exclude specific LoRAs or checkpoints from management
* **Improved Showcase Experience** - Enhanced interaction in LoRA and checkpoint showcase areas for better usability
### v0.8.11
* **Offline Image Support** - Added functionality to download and save all model example images locally, ensuring access even when offline or if images are removed from CivitAI or the site is down
* **Resilient Download System** - Implemented pause/resume capability with checkpoint recovery that persists through restarts or unexpected exits
* **Bug Fixes & Stability** - Resolved various issues to enhance overall reliability and performance
### v0.8.10
* **Standalone Mode** - Run LoRA Manager independently from ComfyUI for a lightweight experience that works even with other stable diffusion interfaces
* **Portable Edition** - New one-click portable version for easy startup and updates in standalone mode
* **Enhanced Metadata Collection** - Added support for SamplerCustomAdvanced node in the metadata collector module
* **Improved UI Organization** - Optimized Lora Loader node height to display up to 5 LoRAs at once with scrolling capability for larger collections
[View Update History](./update_logs.md) [View Update History](./update_logs.md)
@@ -139,11 +165,10 @@ Enhance your Civitai browsing experience with our companion browser extension! S
### Option 2: **Portable Standalone Edition** (No ComfyUI required) ### Option 2: **Portable Standalone Edition** (No ComfyUI required)
1. Download the [Portable Package](https://github.com/willmiao/ComfyUI-Lora-Manager/releases/download/v0.9.2/lora_manager_portable.7z) 1. Download the [Portable Package](https://github.com/willmiao/ComfyUI-Lora-Manager/releases/download/v0.8.15/lora_manager_portable.7z)
2. Copy the provided `settings.json.example` file to create a new file named `settings.json` in `comfyui-lora-manager` folder 2. Copy the provided `settings.json.example` file to create a new file named `settings.json` in `comfyui-lora-manager` folder
3. Edit `settings.json` to include your correct model folder paths and CivitAI API key 3. Edit `settings.json` to include your correct model folder paths and CivitAI API key
4. Run run.bat 4. Run run.bat
- To change the startup port, edit `run.bat` and modify the parameter (e.g. `--port 9001`)
### Option 3: **Manual Installation** ### Option 3: **Manual Installation**
@@ -273,6 +298,3 @@ Join our Discord community for support, discussions, and updates:
[Discord Server](https://discord.gg/vcqNrWVFvM) [Discord Server](https://discord.gg/vcqNrWVFvM)
--- ---
## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=willmiao/ComfyUI-Lora-Manager&type=Date)](https://star-history.com/#willmiao/ComfyUI-Lora-Manager&Date)

View File

@@ -1,42 +1,20 @@
try: # pragma: no cover - import fallback for pytest collection from .py.lora_manager import LoraManager
from .py.lora_manager import LoraManager from .py.nodes.lora_loader import LoraManagerLoader
from .py.nodes.lora_loader import LoraManagerLoader, LoraManagerTextLoader from .py.nodes.trigger_word_toggle import TriggerWordToggle
from .py.nodes.trigger_word_toggle import TriggerWordToggle from .py.nodes.lora_stacker import LoraStacker
from .py.nodes.lora_stacker import LoraStacker from .py.nodes.save_image import SaveImage
from .py.nodes.save_image import SaveImage from .py.nodes.debug_metadata import DebugMetadata
from .py.nodes.debug_metadata import DebugMetadata from .py.nodes.wanvideo_lora_select import WanVideoLoraSelect
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelect # Import metadata collector to install hooks on startup
from .py.nodes.wanvideo_lora_select_from_text import WanVideoLoraSelectFromText from .py.metadata_collector import init as init_metadata_collector
from .py.metadata_collector import init as init_metadata_collector
except ImportError: # pragma: no cover - allows running under pytest without package install
import importlib
import pathlib
import sys
package_root = pathlib.Path(__file__).resolve().parent
if str(package_root) not in sys.path:
sys.path.append(str(package_root))
LoraManager = importlib.import_module("py.lora_manager").LoraManager
LoraManagerLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerLoader
LoraManagerTextLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerTextLoader
TriggerWordToggle = importlib.import_module("py.nodes.trigger_word_toggle").TriggerWordToggle
LoraStacker = importlib.import_module("py.nodes.lora_stacker").LoraStacker
SaveImage = importlib.import_module("py.nodes.save_image").SaveImage
DebugMetadata = importlib.import_module("py.nodes.debug_metadata").DebugMetadata
WanVideoLoraSelect = importlib.import_module("py.nodes.wanvideo_lora_select").WanVideoLoraSelect
WanVideoLoraSelectFromText = importlib.import_module("py.nodes.wanvideo_lora_select_from_text").WanVideoLoraSelectFromText
init_metadata_collector = importlib.import_module("py.metadata_collector").init
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
LoraManagerLoader.NAME: LoraManagerLoader, LoraManagerLoader.NAME: LoraManagerLoader,
LoraManagerTextLoader.NAME: LoraManagerTextLoader,
TriggerWordToggle.NAME: TriggerWordToggle, TriggerWordToggle.NAME: TriggerWordToggle,
LoraStacker.NAME: LoraStacker, LoraStacker.NAME: LoraStacker,
SaveImage.NAME: SaveImage, SaveImage.NAME: SaveImage,
DebugMetadata.NAME: DebugMetadata, DebugMetadata.NAME: DebugMetadata,
WanVideoLoraSelect.NAME: WanVideoLoraSelect, WanVideoLoraSelect.NAME: WanVideoLoraSelect
WanVideoLoraSelectFromText.NAME: WanVideoLoraSelectFromText
} }
WEB_DIRECTORY = "./web/comfyui" WEB_DIRECTORY = "./web/comfyui"

View File

@@ -1,180 +0,0 @@
## Overview
The **LoRA Manager Civitai Extension** is a Browser extension designed to work seamlessly with [LoRA Manager](https://github.com/willmiao/ComfyUI-Lora-Manager) to significantly enhance your browsing experience on [Civitai](https://civitai.com).
It also supports browsing on [CivArchive](https://civarchive.com/) (formerly CivitaiArchive).
With this extension, you can:
✅ Instantly see which models are already present in your local library
✅ Download new models with a single click
✅ Manage downloads efficiently with queue and parallel download support
✅ Keep your downloaded models automatically organized according to your custom settings
![Civitai Models page](https://github.com/willmiao/ComfyUI-Lora-Manager/blob/main/wiki-images/civitai-models-page.png)
![CivArchive Models page](https://github.com/willmiao/ComfyUI-Lora-Manager/blob/main/wiki-images/civarchive-models-page.png)
---
## Why Are All Features for Supporters Only?
I love building tools for the Stable Diffusion and ComfyUI communities, and LoRA Manager is a passion project that I've poured countless hours into. When I created this companion extension, my hope was to offer its core features for free, as a thank-you to all of you.
Unfortunately, I've reached a point where I need to be realistic. The level of support from the free model has been far lower than what's needed to justify the continuous development and maintenance for both projects. It was a difficult decision, but I've chosen to make the extension's features exclusive to supporters.
This change is crucial for me to be able to continue dedicating my time to improving the free and open-source LoRA Manager, which I'm committed to keeping available for everyone.
Your support does more than just unlock a few features—it allows me to keep innovating and ensures the core LoRA Manager project thrives. I'm incredibly grateful for your understanding and any support you can offer. ❤️
(_For those who previously supported me on Ko-fi with a one-time donation, I'll be sending out license keys individually as a thank-you._)
---
## Installation
### Supported Browsers & Installation Methods
| Browser | Installation Method |
|--------------------|-------------------------------------------------------------------------------------|
| **Google Chrome** | [Chrome Web Store link](https://chromewebstore.google.com/detail/capigligggeijgmocnaflanlbghnamgm?utm_source=item-share-cb) |
| **Microsoft Edge** | Install via Chrome Web Store (compatible) |
| **Brave Browser** | Install via Chrome Web Store (compatible) |
| **Opera** | Install via Chrome Web Store (compatible) |
| **Firefox** | <div id="firefox-install" class="install-ok"><a href="https://github.com/willmiao/lm-civitai-extension-firefox/releases/latest/download/extension.xpi">📦 Install Firefox Extension (reviewed and verified by Mozilla)</a></div> |
For non-Chrome browsers (e.g., Microsoft Edge), you can typically install extensions from the Chrome Web Store by following these steps: open the extensions Chrome Web Store page, click 'Get extension', then click 'Allow' when prompted to enable installations from other stores, and finally click 'Add extension' to complete the installation.
---
## Privacy & Security
I understand concerns around browser extensions and privacy, and I want to be fully transparent about how the **LM Civitai Extension** works:
- **Reviewed and Verified**
This extension has been **manually reviewed and approved by the Chrome Web Store**. The Firefox version uses the **exact same code** (only the packaging format differs) and has passed **Mozillas Add-on review**.
- **Minimal Network Access**
The only external server this extension connects to is:
**`https://willmiao.shop`** — used solely for **license validation**.
It does **not collect, transmit, or store any personal or usage data**.
No browsing history, no user IDs, no analytics, no hidden trackers.
- **Local-Only Model Detection**
Model detection and LoRA Manager communication all happen **locally** within your browser, directly interacting with your local LoRA Manager backend.
I value your trust and are committed to keeping your local setup private and secure. If you have any questions, feel free to reach out!
---
## How to Use
After installing the extension, you'll automatically receive a **7-day trial** to explore all features.
When the extension is correctly installed and your license is valid:
- Open **Civitai**, and you'll see visual indicators added by the extension on model cards, showing:
- ✅ Models already present in your local library
- ⬇️ A download button for models not in your library
Clicking the download button adds the corresponding model version to the download queue, waiting to be downloaded. You can set up to **5 models to download simultaneously**.
### Visual Indicators Appear On:
- **Home Page** — Featured models
- **Models Page**
- **Creator Profiles** — If the creator has set their models to be visible
- **Recommended Resources** — On individual model pages
### Version Buttons on Model Pages
On a specific model page, visual indicators also appear on version buttons, showing which versions are already in your local library.
When switching to a specific version by clicking a version button:
- Clicking the download button will open a dropdown:
- Download via **LoRA Manager**
- Download via **Original Download** (browser download)
You can check **Remember my choice** to set your preferred default. You can change this setting anytime in the extension's settings.
![Civitai Model Page](https://github.com/willmiao/ComfyUI-Lora-Manager/blob/main/wiki-images/civitai-model-page.png)
### Resources on Image Pages (2025-08-05) — now shows in-library indicators for image resources. Import image as recipe coming soon!
![Civitai Image Page](https://github.com/willmiao/ComfyUI-Lora-Manager/blob/main/wiki-images/civitai-image-page.jpg)
---
## Model Download Location & LoRA Manager Settings
To use the **one-click download function**, you must first set:
- Your **Default LoRAs Root**
- Your **Default Checkpoints Root**
These are set within LoRA Manager's settings.
When everything is configured, downloaded model files will be placed in:
`<Default_Models_Root>/<Base_Model_of_the_Model>/<First_Tag_of_the_Model>`
### Update: Default Path Customization (2025-07-21)
A new setting to customize the default download path has been added in the nightly version. You can now personalize where models are saved when downloading via the LM Civitai Extension.
![Default Path Customization](https://github.com/willmiao/ComfyUI-Lora-Manager/blob/main/wiki-images/default-path-customization.png)
The previous YAML path mapping file will be deprecated—settings will now be unified in settings.json to simplify configuration.
---
## Backend Port Configuration
If your **ComfyUI** or **LoRA Manager** backend is running on a port **other than the default 8188**, you must configure the backend port in the extension's settings.
After correctly setting and saving the port, you'll see in the extension's header area:
- A **Healthy** status with the tooltip: `Connected to LoRA Manager on port xxxx`
---
## Advanced Usage
### Connecting to a Remote LoRA Manager
If your LoRA Manager is running on another computer, you can still connect from your browser using port forwarding.
> **Why can't you set a remote IP directly?**
>
> For privacy and security, the extension only requests access to `http://127.0.0.1/*`. Supporting remote IPs would require much broader permissions, which may be rejected by browser stores and could raise user concerns.
**Solution: Port Forwarding with `socat`**
On your browser computer, run:
`socat TCP-LISTEN:8188,bind=127.0.0.1,fork TCP:REMOTE.IP.ADDRESS.HERE:8188`
- Replace `REMOTE.IP.ADDRESS.HERE` with the IP of the machine running LoRA Manager.
- Adjust the port if needed.
This lets the extension connect to `127.0.0.1:8188` as usual, with traffic forwarded to your remote server.
_Thanks to user **Temikus** for sharing this solution!_
---
## Roadmap
The extension will evolve alongside **LoRA Manager** improvements. Planned features include:
- [x] Support for **additional model types** (e.g., embeddings)
- [ ] One-click **Recipe Import**
- [x] Display of in-library status for all resources in the **Resources Used** section of the image page
- [x] One-click **Auto-organize Models**
**Stay tuned — and thank you for your support!**
---

View File

@@ -1,93 +0,0 @@
# Example image route architecture
The example image routing stack mirrors the layered model route stack described in
[`docs/architecture/model_routes.md`](model_routes.md). HTTP wiring, controller setup,
handler orchestration, and long-running workflows now live in clearly separated modules so
we can extend download/import behaviour without touching the entire feature surface.
```mermaid
graph TD
subgraph HTTP
A[ExampleImagesRouteRegistrar] -->|binds| B[ExampleImagesRoutes controller]
end
subgraph Application
B --> C[ExampleImagesHandlerSet]
C --> D1[Handlers]
D1 --> E1[Use cases]
E1 --> F1[Download manager / processor / file manager]
end
subgraph Side Effects
F1 --> G1[Filesystem]
F1 --> G2[Model metadata]
F1 --> G3[WebSocket progress]
end
```
## Layer responsibilities
| Layer | Module(s) | Responsibility |
| --- | --- | --- |
| Registrar | `py/routes/example_images_route_registrar.py` | Declarative catalogue of every example image endpoint plus helpers that bind them to an `aiohttp` router. Keeps HTTP concerns symmetrical with the model registrar. |
| Controller | `py/routes/example_images_routes.py` | Lazily constructs `ExampleImagesHandlerSet`, injects defaults for the download manager, processor, and file manager, and exposes the registrar-ready mapping just like `BaseModelRoutes`. |
| Handler set | `py/routes/handlers/example_images_handlers.py` | Groups HTTP adapters by concern (downloads, imports/deletes, filesystem access). Each handler translates domain errors into HTTP responses and defers to a use case or utility service. |
| Use cases | `py/services/use_cases/example_images/*.py` | Encapsulate orchestration for downloads and imports. They validate input, translate concurrency/configuration errors, and keep handler logic declarative. |
| Supporting services | `py/utils/example_images_download_manager.py`, `py/utils/example_images_processor.py`, `py/utils/example_images_file_manager.py` | Execute long-running work: pull assets from Civitai, persist uploads, clean metadata, expose filesystem actions with guardrails, and broadcast progress snapshots. |
## Handler responsibilities & invariants
`ExampleImagesHandlerSet` flattens the handler objects into the `{"handler_name": coroutine}`
mapping consumed by the registrar. The table below outlines how each handler collaborates
with the use cases and utilities.
| Handler | Key endpoints | Collaborators | Contracts |
| --- | --- | --- | --- |
| `ExampleImagesDownloadHandler` | `/api/lm/download-example-images`, `/api/lm/example-images-status`, `/api/lm/pause-example-images`, `/api/lm/resume-example-images`, `/api/lm/force-download-example-images` | `DownloadExampleImagesUseCase`, `DownloadManager` | Delegates payload validation and concurrency checks to the use case; progress/status endpoints expose the same snapshot used for WebSocket broadcasts; pause/resume surface `DownloadNotRunningError` as HTTP 400 instead of 500. |
| `ExampleImagesManagementHandler` | `/api/lm/import-example-images`, `/api/lm/delete-example-image` | `ImportExampleImagesUseCase`, `ExampleImagesProcessor` | Multipart uploads are streamed to disk via the use case; validation failures return HTTP 400 with no filesystem side effects; deletion funnels through the processor to prune metadata and cached images consistently. |
| `ExampleImagesFileHandler` | `/api/lm/open-example-images-folder`, `/api/lm/example-image-files`, `/api/lm/has-example-images` | `ExampleImagesFileManager` | Centralises filesystem access, enforcing settings-based root paths and returning HTTP 400/404 for missing configuration or folders; responses always include `success`/`has_images` booleans for UI consumption. |
## Use case boundaries
| Use case | Entry point | Dependencies | Guarantees |
| --- | --- | --- | --- |
| `DownloadExampleImagesUseCase` | `execute(payload)` | `DownloadManager.start_download`, download configuration errors | Raises `DownloadExampleImagesInProgressError` when the manager reports an active job, rewraps configuration errors into `DownloadExampleImagesConfigurationError`, and lets `ExampleImagesDownloadError` bubble as 500s so handlers do not duplicate logging. |
| `ImportExampleImagesUseCase` | `execute(request)` | `ExampleImagesProcessor.import_images`, temporary file helpers | Supports multipart or JSON payloads, normalises file paths into a single list, cleans up temp files even on failure, and maps validation issues to `ImportExampleImagesValidationError` for HTTP 400 responses. |
## Maintaining critical invariants
* **Shared progress snapshots** - The download handler returns the same snapshot built by
`DownloadManager`, guaranteeing parity between HTTP polling endpoints and WebSocket
progress events.
* **Safe filesystem access** - All folder/file actions flow through
`ExampleImagesFileManager`, which validates the configured example image root and ensures
responses never leak absolute paths outside the allowed directory.
* **Metadata hygiene** - Import/delete operations run through `ExampleImagesProcessor`,
which updates model metadata via `MetadataManager` and notifies the relevant scanners so
cache state stays in sync.
## Migration notes
The refactor brings the example image stack in line with the model/recipe stacks:
1. `ExampleImagesRouteRegistrar` now owns the declarative route list. Downstream projects
should rely on `ExampleImagesRoutes.to_route_mapping()` instead of manually wiring
handler callables.
2. `ExampleImagesRoutes` caches its `ExampleImagesHandlerSet` just like
`BaseModelRoutes`. If you previously instantiated handlers directly, inject custom
collaborators via the controller constructor (`download_manager`, `processor`,
`file_manager`) to keep test seams predictable.
3. Tests that mocked `ExampleImagesRoutes.setup_routes` should switch to patching
`DownloadExampleImagesUseCase`/`ImportExampleImagesUseCase` at import time. The handlers
expect those abstractions to surface validation/concurrency errors, and bypassing them
will skip the HTTP-friendly error mapping.
## Extending the stack
1. Add the endpoint to `ROUTE_DEFINITIONS` with a unique `handler_name`.
2. Expose the coroutine on an existing handler class (or create a new handler and extend
`ExampleImagesHandlerSet`).
3. Wire additional services or factories inside `_build_handler_set` on
`ExampleImagesRoutes`, mirroring how the model stack introduces new use cases.
`tests/routes/test_example_images_routes.py` exercises registrar binding, download pause
flows, and import validations. Use it as a template when introducing new handler
collaborators or error mappings.

View File

@@ -1,100 +0,0 @@
# Base model route architecture
The model routing stack now splits HTTP wiring, orchestration logic, and
business rules into discrete layers. The goal is to make it obvious where a
new collaborator should live and which contract it must honour. The diagram
below captures the end-to-end flow for a typical request:
```mermaid
graph TD
subgraph HTTP
A[ModelRouteRegistrar] -->|binds| B[BaseModelRoutes handler proxy]
end
subgraph Application
B --> C[ModelHandlerSet]
C --> D1[Handlers]
D1 --> E1[Use cases]
E1 --> F1[Services / scanners]
end
subgraph Side Effects
F1 --> G1[Cache & metadata]
F1 --> G2[Filesystem]
F1 --> G3[WebSocket state]
end
```
Every box maps to a concrete module:
| Layer | Module(s) | Responsibility |
| --- | --- | --- |
| Registrar | `py/routes/model_route_registrar.py` | Declarative list of routes shared by every model type and helper methods for binding them to an `aiohttp` application. |
| Route controller | `py/routes/base_model_routes.py` | Constructs the handler graph, injects shared services, exposes proxies that surface `503 Service not ready` when the model service has not been attached. |
| Handler set | `py/routes/handlers/model_handlers.py` | Thin HTTP adapters grouped by concern (page rendering, listings, mutations, queries, downloads, CivitAI integration, move operations, auto-organize). |
| Use cases | `py/services/use_cases/*.py` | Encapsulate long-running flows (`DownloadModelUseCase`, `BulkMetadataRefreshUseCase`, `AutoOrganizeUseCase`). They normalise validation errors and concurrency constraints before returning control to the handlers. |
| Services | `py/services/*.py` | Existing services and scanners that mutate caches, write metadata, move files, and broadcast WebSocket updates. |
## Handler responsibilities & contracts
`ModelHandlerSet` flattens the handler objects into the exact callables used by
the registrar. The table below highlights the separation of concerns within
the set and the invariants that must hold after each handler returns.
| Handler | Key endpoints | Collaborators | Contracts |
| --- | --- | --- | --- |
| `ModelPageView` | `/{prefix}` | `SettingsManager`, `server_i18n`, Jinja environment, `service.scanner` | Template is rendered with `is_initializing` flag when caches are cold; i18n filter is registered exactly once per environment instance. |
| `ModelListingHandler` | `/api/lm/{prefix}/list` | `service.get_paginated_data`, `service.format_response` | Listings respect pagination query parameters and cap `page_size` at 100; every item is formatted before response. |
| `ModelManagementHandler` | Mutations (delete, exclude, metadata, preview, tags, rename, bulk delete, duplicate verification) | `ModelLifecycleService`, `MetadataSyncService`, `PreviewAssetService`, `TagUpdateService`, scanner cache/index | Cache state mirrors filesystem changes: deletes prune cache & hash index, preview replacements synchronise metadata and cache NSFW levels, metadata saves trigger cache resort when names change. |
| `ModelQueryHandler` | Read-only queries (top tags, folders, duplicates, metadata, URLs) | Service query helpers & scanner cache | Outputs always wrapped in `{"success": True}` when no error; duplicate/filename grouping omits empty entries; invalid parameters (e.g. missing `model_root`) return HTTP 400. |
| `ModelDownloadHandler` | `/api/lm/download-model`, `/download-model-get`, `/download-progress/{id}`, `/cancel-download-get` | `DownloadModelUseCase`, `DownloadCoordinator`, `WebSocketManager` | Payload validation errors become HTTP 400 without mutating download progress cache; early-access failures surface as HTTP 401; successful downloads cache progress snapshots that back both WebSocket broadcasts and polling endpoints. |
| `ModelCivitaiHandler` | CivitAI metadata routes | `MetadataSyncService`, metadata provider factory, `BulkMetadataRefreshUseCase` | `fetch_all_civitai` streams progress via `WebSocketBroadcastCallback`; version lookups validate model type before returning; local availability fields derive from hash lookups without mutating cache state. |
| `ModelMoveHandler` | `move_model`, `move_models_bulk` | `ModelMoveService` | Moves execute atomically per request; bulk operations aggregate success/failure per file set. |
| `ModelAutoOrganizeHandler` | `/api/lm/{prefix}/auto-organize` (GET/POST), `/auto-organize-progress` | `AutoOrganizeUseCase`, `WebSocketProgressCallback`, `WebSocketManager` | Enforces single-flight execution using the shared lock; progress broadcasts remain available to polling clients until explicitly cleared; conflicts return HTTP 409 with a descriptive error. |
## Use case boundaries
Each use case exposes a narrow asynchronous API that hides the underlying
services. Their error mapping is essential for predictable HTTP responses.
| Use case | Entry point | Dependencies | Guarantees |
| --- | --- | --- | --- |
| `DownloadModelUseCase` | `execute(payload)` | `DownloadCoordinator.schedule_download` | Translates `ValueError` into `DownloadModelValidationError` for HTTP 400, recognises early-access errors (`"401"` in message) and surfaces them as `DownloadModelEarlyAccessError`, forwards success dictionaries untouched. |
| `AutoOrganizeUseCase` | `execute(file_paths, progress_callback)` | `ModelFileService.auto_organize_models`, `WebSocketManager` lock | Guarded by `ws_manager` lock + status checks; raises `AutoOrganizeInProgressError` before invoking the file service when another run is already active. |
| `BulkMetadataRefreshUseCase` | `execute_with_error_handling(progress_callback)` | `MetadataSyncService`, `SettingsManager`, `WebSocketBroadcastCallback` | Iterates through cached models, applies metadata sync, emits progress snapshots that handlers broadcast unchanged. |
## Maintaining legacy contracts
The refactor preserves the invariants called out in the previous architecture
notes. The most critical ones are reiterated here to emphasise the
collaboration points:
1. **Cache mutations** Delete, exclude, rename, and bulk delete operations are
channelled through `ModelManagementHandler`. The handler delegates to
`ModelLifecycleService` or `MetadataSyncService`, and the scanner cache is
mutated in-place before the handler returns. The accompanying tests assert
that `scanner._cache.raw_data` and `scanner._hash_index` stay in sync after
each mutation.
2. **Preview updates** `PreviewAssetService.replace_preview` writes the new
asset, `MetadataSyncService` persists the JSON metadata, and
`scanner.update_preview_in_cache` mirrors the change. The handler returns
the static URL produced by `config.get_preview_static_url`, keeping browser
clients in lockstep with disk state.
3. **Download progress** `DownloadCoordinator.schedule_download` generates the
download identifier, registers a WebSocket progress callback, and caches the
latest numeric progress via `WebSocketManager`. Both `download_model`
responses and `/download-progress/{id}` polling read from the same cache to
guarantee consistent progress reporting across transports.
## Extending the stack
To add a new shared route:
1. Declare it in `COMMON_ROUTE_DEFINITIONS` using a unique handler name.
2. Implement the corresponding coroutine on one of the handlers inside
`ModelHandlerSet` (or introduce a new handler class when the concern does not
fit existing ones).
3. Inject additional dependencies in `BaseModelRoutes._create_handler_set` by
wiring services or use cases through the constructor parameters.
Model-specific routes should continue to be registered inside the subclass
implementation of `setup_specific_routes`, reusing the shared registrar where
possible.

View File

@@ -1,89 +0,0 @@
# Recipe route architecture
The recipe routing stack now mirrors the modular model route design. HTTP
bindings, controller wiring, handler orchestration, and business rules live in
separate layers so new behaviours can be added without re-threading the entire
feature. The diagram below outlines the flow for a typical request:
```mermaid
graph TD
subgraph HTTP
A[RecipeRouteRegistrar] -->|binds| B[RecipeRoutes controller]
end
subgraph Application
B --> C[RecipeHandlerSet]
C --> D1[Handlers]
D1 --> E1[Use cases]
E1 --> F1[Services / scanners]
end
subgraph Side Effects
F1 --> G1[Cache & fingerprint index]
F1 --> G2[Metadata files]
F1 --> G3[Temporary shares]
end
```
## Layer responsibilities
| Layer | Module(s) | Responsibility |
| --- | --- | --- |
| Registrar | `py/routes/recipe_route_registrar.py` | Declarative list of every recipe endpoint and helper methods that bind them to an `aiohttp` application. |
| Controller | `py/routes/base_recipe_routes.py`, `py/routes/recipe_routes.py` | Lazily resolves scanners/clients from the service registry, wires shared templates/i18n, instantiates `RecipeHandlerSet`, and exposes a `{handler_name: coroutine}` mapping for the registrar. |
| Handler set | `py/routes/handlers/recipe_handlers.py` | Thin HTTP adapters grouped by concern (page view, listings, queries, mutations, sharing). They normalise responses and translate service exceptions into HTTP status codes. |
| Services & scanners | `py/services/recipes/*.py`, `py/services/recipe_scanner.py`, `py/services/service_registry.py` | Concrete business logic: metadata parsing, persistence, sharing, fingerprint/index maintenance, and cache refresh. |
## Handler responsibilities & invariants
`RecipeHandlerSet` flattens purpose-built handler objects into the callables the
registrar binds. Each handler is responsible for a narrow concern and enforces a
set of invariants before returning:
| Handler | Key endpoints | Collaborators | Contracts |
| --- | --- | --- | --- |
| `RecipePageView` | `/loras/recipes` | `SettingsManager`, `server_i18n`, Jinja environment, recipe scanner getter | Template rendered with `is_initializing` flag when caches are still warming; i18n filter registered exactly once per environment instance. |
| `RecipeListingHandler` | `/api/lm/recipes`, `/api/lm/recipe/{id}` | `recipe_scanner.get_paginated_data`, `recipe_scanner.get_recipe_by_id` | Listings respect pagination and search filters; every item receives a `file_url` fallback even when metadata is incomplete; missing recipes become HTTP 404. |
| `RecipeQueryHandler` | Tag/base-model stats, syntax, LoRA lookups | Recipe scanner cache, `format_recipe_file_url` helper | Cache snapshots are reused without forcing refresh; duplicate lookups collapse groups by fingerprint; syntax lookups return helpful errors when LoRAs are absent. |
| `RecipeManagementHandler` | Save, update, reconnect, bulk delete, widget ingest | `RecipePersistenceService`, `RecipeAnalysisService`, recipe scanner | Persistence results propagate HTTP status codes; fingerprint/index updates flow through the scanner before returning; validation errors surface as HTTP 400 without touching disk. |
| `RecipeAnalysisHandler` | Uploaded/local/remote analysis | `RecipeAnalysisService`, `civitai_client`, recipe scanner | Unsupported content types map to HTTP 400; download errors (`RecipeDownloadError`) are not retried; every response includes a `loras` array for client compatibility. |
| `RecipeSharingHandler` | Share + download | `RecipeSharingService`, recipe scanner | Share responses provide a stable download URL and filename; expired shares surface as HTTP 404; downloads stream via `web.FileResponse` with attachment headers. |
## Use case boundaries
The dedicated services encapsulate long-running work so handlers stay thin.
| Use case | Entry point | Dependencies | Guarantees |
| --- | --- | --- | --- |
| `RecipeAnalysisService` | `analyze_uploaded_image`, `analyze_remote_image`, `analyze_local_image`, `analyze_widget_metadata` | `ExifUtils`, `RecipeParserFactory`, downloader factory, optional metadata collector/processor | Normalises missing/invalid payloads into `RecipeValidationError`; generates consistent fingerprint data to keep duplicate detection stable; temporary files are cleaned up after every analysis path. |
| `RecipePersistenceService` | `save_recipe`, `delete_recipe`, `update_recipe`, `reconnect_lora`, `bulk_delete`, `save_recipe_from_widget` | `ExifUtils`, recipe scanner, card preview sizing constants | Writes images/JSON metadata atomically; updates scanner caches and hash indices before returning; recalculates fingerprints whenever LoRA assignments change. |
| `RecipeSharingService` | `share_recipe`, `prepare_download` | `tempfile`, recipe scanner | Copies originals to TTL-managed temp files; metadata lookups re-use the scanner; expired shares trigger cleanup and `RecipeNotFoundError`. |
## Maintaining critical invariants
* **Cache updates** Mutations (`save`, `delete`, `bulk_delete`, `update`) call
back into the recipe scanner to mutate the in-memory cache and fingerprint
index before returning a response. Tests assert that these methods are invoked
even when stubbing persistence.
* **Fingerprint management** `RecipePersistenceService` recomputes
fingerprints whenever LoRA metadata changes and duplicate lookups use those
fingerprints to group recipes. Handlers bubble the resulting IDs so clients
can merge duplicates without an extra fetch.
* **Metadata synchronisation** Saving or reconnecting a recipe updates the
JSON sidecar, refreshes embedded metadata via `ExifUtils`, and instructs the
scanner to resort its cache. Sharing relies on this metadata to generate
filenames and ensure downloads stay in sync with on-disk state.
## Extending the stack
1. Declare the new endpoint in `ROUTE_DEFINITIONS` with a unique handler name.
2. Implement the coroutine on an existing handler or introduce a new handler
class inside `py/routes/handlers/recipe_handlers.py` when the concern does
not fit existing ones.
3. Wire additional collaborators inside
`BaseRecipeRoutes._create_handler_set` (inject new services or factories) and
expose helper getters on the handler owner if the handler needs to share
utilities.
Integration tests in `tests/routes/test_recipe_routes.py` exercise the listing,
mutation, analysis-error, and sharing paths end-to-end, ensuring the controller
and handler wiring remains valid as new capabilities are added.

View File

@@ -1,23 +0,0 @@
# Frontend Automation Testing Roadmap
This roadmap tracks the planned rollout of automated testing for the ComfyUI LoRA Manager frontend. Each phase builds on the infrastructure introduced in this change set and records progress so future contributors can quickly identify the next tasks.
## Phase Overview
| Phase | Goal | Primary Focus | Status | Notes |
| --- | --- | --- | --- | --- |
| Phase 0 | Establish baseline tooling | Add Node test runner, jsdom environment, and seed smoke tests | ✅ Complete | Vitest + jsdom configured, example state tests committed |
| Phase 1 | Cover state management logic | Unit test selectors, derived data helpers, and storage utilities under `static/js/state` and `static/js/utils` | ✅ Complete | Storage helpers and state selectors now exercised via deterministic suites |
| Phase 2 | Test AppCore orchestration | Simulate page bootstrapping, infinite scroll hooks, and manager registration using JSDOM DOM fixtures | 🟡 In Progress | AppCore initialization specs landed; expand to additional page wiring and scroll hooks |
| Phase 3 | Validate page-specific managers | Add focused suites for `loras`, `checkpoints`, `embeddings`, and `recipes` managers covering filtering, sorting, and bulk actions | ⚪ Not Started | Consider shared helpers for mocking API modules and storage |
| Phase 4 | Interaction-level regression tests | Exercise template fragments, modals, and menus to ensure UI wiring remains intact | ⚪ Not Started | Evaluate Playwright component testing or happy-path DOM snapshots |
| Phase 5 | Continuous integration & coverage | Integrate frontend tests into CI workflow and track coverage metrics | ⚪ Not Started | Align reporting directories with backend coverage for unified reporting |
## Next Steps Checklist
- [x] Expand unit tests for `storageHelpers` covering migrations and namespace behavior.
- [ ] Document DOM fixture strategy for reproducing template structures in tests.
- [x] Prototype AppCore initialization test that verifies manager bootstrapping with stubbed dependencies.
- [ ] Evaluate integrating coverage reporting once test surface grows (> 20 specs).
Maintaining this roadmap alongside code changes will make it easier to append new automated test tasks and update their progress.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

2572
package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,14 +0,0 @@
{
"name": "comfyui-lora-manager-frontend",
"version": "0.1.0",
"private": true,
"type": "module",
"scripts": {
"test": "vitest run",
"test:watch": "vitest"
},
"devDependencies": {
"jsdom": "^24.0.0",
"vitest": "^1.6.0"
}
}

View File

@@ -1,12 +0,0 @@
"""Project namespace package."""
# pytest's internal compatibility layer still imports ``py.path.local`` from the
# historical ``py`` dependency. Because this project reuses the ``py`` package
# name, we expose a minimal shim so ``py.path.local`` resolves to ``pathlib.Path``
# during test runs without pulling in the external dependency.
from pathlib import Path
from types import SimpleNamespace
path = SimpleNamespace(local=Path)
__all__ = ["path"]

View File

@@ -3,11 +3,11 @@ import platform
import folder_paths # type: ignore import folder_paths # type: ignore
from typing import List from typing import List
import logging import logging
import sys
import json import json
import urllib.parse
# Use an environment variable to control standalone mode # Check if running in standalone mode
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" standalone_mode = 'nodes' not in sys.modules
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -17,7 +17,6 @@ class Config:
def __init__(self): def __init__(self):
self.templates_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'templates') self.templates_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'templates')
self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static') self.static_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'static')
self.i18n_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'locales')
# Path mapping dictionary, target to link mapping # Path mapping dictionary, target to link mapping
self._path_mappings = {} self._path_mappings = {}
# Static route mapping dictionary, target to route mapping # Static route mapping dictionary, target to route mapping
@@ -61,9 +60,6 @@ class Config:
if self.checkpoints_roots and len(self.checkpoints_roots) == 1 and "default_checkpoint_root" not in settings: if self.checkpoints_roots and len(self.checkpoints_roots) == 1 and "default_checkpoint_root" not in settings:
settings["default_checkpoint_root"] = self.checkpoints_roots[0] settings["default_checkpoint_root"] = self.checkpoints_roots[0]
if self.embeddings_roots and len(self.embeddings_roots) == 1 and "default_embedding_root" not in settings:
settings["default_embedding_root"] = self.embeddings_roots[0]
# Save settings # Save settings
with open(settings_path, 'w', encoding='utf-8') as f: with open(settings_path, 'w', encoding='utf-8') as f:
json.dump(settings, f, indent=2) json.dump(settings, f, indent=2)
@@ -205,20 +201,16 @@ class Config:
real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/') real_path = os.path.normpath(os.path.realpath(path)).replace(os.sep, '/')
unet_map[real_path] = unet_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen unet_map[real_path] = unet_map.get(real_path, path.replace(os.sep, "/")) # preserve first seen
# Merge both maps and deduplicate by real path
merged_map = {}
for real_path, orig_path in {**checkpoint_map, **unet_map}.items():
if real_path not in merged_map:
merged_map[real_path] = orig_path
# Now sort and use only the deduplicated real paths # Now sort and use only the deduplicated real paths
unique_paths = sorted(merged_map.values(), key=lambda p: p.lower()) unique_checkpoint_paths = sorted(checkpoint_map.values(), key=lambda p: p.lower())
unique_unet_paths = sorted(unet_map.values(), key=lambda p: p.lower())
# Split back into checkpoints and unet roots for class properties # Store individual paths in class properties
self.checkpoints_roots = [p for p in unique_paths if p in checkpoint_map.values()] self.checkpoints_roots = unique_checkpoint_paths
self.unet_roots = [p for p in unique_paths if p in unet_map.values()] self.unet_roots = unique_unet_paths
all_paths = unique_paths # Combine all checkpoint-related paths for return value
all_paths = unique_checkpoint_paths + unique_unet_paths
logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(all_paths) if all_paths else "[]")) logger.info("Found checkpoint roots:" + ("\n - " + "\n - ".join(all_paths) if all_paths else "[]"))
@@ -268,25 +260,16 @@ class Config:
return [] return []
def get_preview_static_url(self, preview_path: str) -> str: def get_preview_static_url(self, preview_path: str) -> str:
"""Convert local preview path to static URL"""
if not preview_path: if not preview_path:
return "" return ""
real_path = os.path.realpath(preview_path).replace(os.sep, '/') real_path = os.path.realpath(preview_path).replace(os.sep, '/')
# Find longest matching path (most specific match)
best_match = ""
best_route = ""
for path, route in self._route_mappings.items(): for path, route in self._route_mappings.items():
if real_path.startswith(path) and len(path) > len(best_match): if real_path.startswith(path):
best_match = path relative_path = os.path.relpath(real_path, path)
best_route = route return f'{route}/{relative_path.replace(os.sep, "/")}'
if best_match:
relative_path = os.path.relpath(real_path, best_match).replace(os.sep, '/')
safe_parts = [urllib.parse.quote(part) for part in relative_path.split('/')]
safe_path = '/'.join(safe_parts)
return f'{best_route}/{safe_path}'
return "" return ""

View File

@@ -16,7 +16,6 @@ from .services.service_registry import ServiceRegistry
from .services.settings_manager import settings from .services.settings_manager import settings
from .utils.example_images_migration import ExampleImagesMigration from .utils.example_images_migration import ExampleImagesMigration
from .services.websocket_manager import ws_manager from .services.websocket_manager import ws_manager
from .services.example_images_cleanup_service import ExampleImagesCleanupService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -147,11 +146,6 @@ class LoraManager:
logger.warning(f"Failed to add static route on initialization for {target_path}: {e}") logger.warning(f"Failed to add static route on initialization for {target_path}: {e}")
continue continue
# Add static route for locales JSON files
if os.path.exists(config.i18n_path):
app.router.add_static('/locales', config.i18n_path)
logger.info(f"Added static route for locales: /locales -> {config.i18n_path}")
# Add static route for plugin assets # Add static route for plugin assets
app.router.add_static('/loras_static', config.static_path) app.router.add_static('/loras_static', config.static_path)
@@ -167,7 +161,7 @@ class LoraManager:
RecipeRoutes.setup_routes(app) RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app) UpdateRoutes.setup_routes(app)
MiscRoutes.setup_routes(app) MiscRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager) ExampleImagesRoutes.setup_routes(app)
# Setup WebSocket routes that are shared across all model types # Setup WebSocket routes that are shared across all model types
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
@@ -192,9 +186,6 @@ class LoraManager:
# Register DownloadManager with ServiceRegistry # Register DownloadManager with ServiceRegistry
await ServiceRegistry.get_download_manager() await ServiceRegistry.get_download_manager()
from .services.metadata_service import initialize_metadata_providers
await initialize_metadata_providers()
# Initialize WebSocket manager # Initialize WebSocket manager
await ServiceRegistry.get_websocket_manager() await ServiceRegistry.get_websocket_manager()
@@ -207,188 +198,29 @@ class LoraManager:
recipe_scanner = await ServiceRegistry.get_recipe_scanner() recipe_scanner = await ServiceRegistry.get_recipe_scanner()
# Create low-priority initialization tasks # Create low-priority initialization tasks
init_tasks = [ asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init')
asyncio.create_task(lora_scanner.initialize_in_background(), name='lora_cache_init'), asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init')
asyncio.create_task(checkpoint_scanner.initialize_in_background(), name='checkpoint_cache_init'), asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init')
asyncio.create_task(embedding_scanner.initialize_in_background(), name='embedding_cache_init'), asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
asyncio.create_task(recipe_scanner.initialize_in_background(), name='recipe_cache_init')
]
await ExampleImagesMigration.check_and_run_migrations() await ExampleImagesMigration.check_and_run_migrations()
# Schedule post-initialization tasks to run after scanners complete logger.info("LoRA Manager: All services initialized and background tasks scheduled")
asyncio.create_task(
cls._run_post_initialization_tasks(init_tasks),
name='post_init_tasks'
)
logger.debug("LoRA Manager: All services initialized and background tasks scheduled")
except Exception as e: except Exception as e:
logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True) logger.error(f"LoRA Manager: Error initializing services: {e}", exc_info=True)
@classmethod
async def _run_post_initialization_tasks(cls, init_tasks):
"""Run post-initialization tasks after all scanners complete"""
try:
logger.debug("LoRA Manager: Waiting for scanner initialization to complete...")
# Wait for all scanner initialization tasks to complete
await asyncio.gather(*init_tasks, return_exceptions=True)
logger.debug("LoRA Manager: Scanner initialization completed, starting post-initialization tasks...")
# Run post-initialization tasks
post_tasks = [
asyncio.create_task(cls._cleanup_backup_files(), name='cleanup_bak_files'),
# Add more post-initialization tasks here as needed
# asyncio.create_task(cls._another_post_task(), name='another_task'),
]
# Run all post-initialization tasks
results = await asyncio.gather(*post_tasks, return_exceptions=True)
# Log results
for i, result in enumerate(results):
task_name = post_tasks[i].get_name()
if isinstance(result, Exception):
logger.error(f"Post-initialization task '{task_name}' failed: {result}")
else:
logger.debug(f"Post-initialization task '{task_name}' completed successfully")
logger.debug("LoRA Manager: All post-initialization tasks completed")
except Exception as e:
logger.error(f"LoRA Manager: Error in post-initialization tasks: {e}", exc_info=True)
@classmethod
async def _cleanup_backup_files(cls):
"""Clean up .bak files in all model roots"""
try:
logger.debug("Starting cleanup of .bak files in model directories...")
# Collect all model roots
all_roots = set()
all_roots.update(config.loras_roots)
all_roots.update(config.base_models_roots)
all_roots.update(config.embeddings_roots)
total_deleted = 0
total_size_freed = 0
for root_path in all_roots:
if not os.path.exists(root_path):
continue
try:
deleted_count, size_freed = await cls._cleanup_backup_files_in_directory(root_path)
total_deleted += deleted_count
total_size_freed += size_freed
if deleted_count > 0:
logger.debug(f"Cleaned up {deleted_count} .bak files in {root_path} (freed {size_freed / (1024*1024):.2f} MB)")
except Exception as e:
logger.error(f"Error cleaning up .bak files in {root_path}: {e}")
# Yield control periodically
await asyncio.sleep(0.01)
if total_deleted > 0:
logger.debug(f"Backup cleanup completed: removed {total_deleted} .bak files, freed {total_size_freed / (1024*1024):.2f} MB total")
else:
logger.debug("Backup cleanup completed: no .bak files found")
except Exception as e:
logger.error(f"Error during backup file cleanup: {e}", exc_info=True)
@classmethod
async def _cleanup_backup_files_in_directory(cls, directory_path: str):
"""Clean up .bak files in a specific directory recursively
Args:
directory_path: Path to the directory to clean
Returns:
Tuple[int, int]: (number of files deleted, total size freed in bytes)
"""
deleted_count = 0
size_freed = 0
visited_paths = set()
def cleanup_recursive(path):
nonlocal deleted_count, size_freed
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
return
visited_paths.add(real_path)
with os.scandir(path) as it:
for entry in it:
try:
if entry.is_file(follow_symlinks=True) and entry.name.endswith('.bak'):
file_size = entry.stat().st_size
os.remove(entry.path)
deleted_count += 1
size_freed += file_size
logger.debug(f"Deleted .bak file: {entry.path}")
elif entry.is_dir(follow_symlinks=True):
cleanup_recursive(entry.path)
except Exception as e:
logger.warning(f"Could not delete .bak file {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning directory {path} for .bak files: {e}")
# Run the recursive cleanup in a thread pool to avoid blocking
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, cleanup_recursive, directory_path)
return deleted_count, size_freed
@classmethod
async def _cleanup_example_images_folders(cls):
"""Invoke the example images cleanup service for manual execution."""
try:
service = ExampleImagesCleanupService()
result = await service.cleanup_example_image_folders()
if result.get('success'):
logger.debug(
"Manual example images cleanup completed: moved=%s",
result.get('moved_total'),
)
elif result.get('partial_success'):
logger.warning(
"Manual example images cleanup partially succeeded: moved=%s failures=%s",
result.get('moved_total'),
result.get('move_failures'),
)
else:
logger.debug(
"Manual example images cleanup skipped or failed: %s",
result.get('error', 'no changes'),
)
return result
except Exception as e: # pragma: no cover - defensive guard
logger.error(f"Error during example images cleanup: {e}", exc_info=True)
return {
'success': False,
'error': str(e),
'error_code': 'unexpected_error',
}
@classmethod @classmethod
async def _cleanup(cls, app): async def _cleanup(cls, app):
"""Cleanup resources using ServiceRegistry""" """Cleanup resources using ServiceRegistry"""
try: try:
logger.info("LoRA Manager: Cleaning up services") logger.info("LoRA Manager: Cleaning up services")
# Close CivitaiClient gracefully
civitai_client = await ServiceRegistry.get_service("civitai_client")
if civitai_client:
await civitai_client.close()
logger.info("Closed CivitaiClient connection")
except Exception as e: except Exception as e:
logger.error(f"Error during cleanup: {e}", exc_info=True) logger.error(f"Error during cleanup: {e}", exc_info=True)

View File

@@ -1,7 +1,9 @@
import os import os
import importlib
import sys
# Check if running in standalone mode # Check if running in standalone mode
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" standalone_mode = 'nodes' not in sys.modules
if not standalone_mode: if not standalone_mode:
from .metadata_hook import MetadataHook from .metadata_hook import MetadataHook

View File

@@ -146,33 +146,45 @@ class MetadataHook:
# Store the original _async_map_node_over_list function # Store the original _async_map_node_over_list function
original_map_node_over_list = getattr(execution, map_node_func_name) original_map_node_over_list = getattr(execution, map_node_func_name)
# Wrapped async function, compatible with both stable and nightly # Define the wrapped async function - NOTE: Updated signature with prompt_id and unique_id!
async def async_map_node_over_list_with_metadata(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, *args, **kwargs): async def async_map_node_over_list_with_metadata(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
hidden_inputs = kwargs.get('hidden_inputs', None)
# Only collect metadata when calling the main function of nodes # Only collect metadata when calling the main function of nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'): if func == obj.FUNCTION and hasattr(obj, '__class__'):
try: try:
# Get the current prompt_id from the registry
registry = MetadataRegistry() registry = MetadataRegistry()
# We now have prompt_id directly from the function parameters
if prompt_id is not None: if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__ class_type = obj.__class__.__name__
# Use the passed unique_id parameter instead of trying to extract it
node_id = unique_id node_id = unique_id
# Record inputs before execution
if node_id is not None: if node_id is not None:
registry.record_node_execution(node_id, class_type, input_data_all, None) registry.record_node_execution(node_id, class_type, input_data_all, None)
except Exception as e: except Exception as e:
print(f"Error collecting metadata (pre-execution): {str(e)}") print(f"Error collecting metadata (pre-execution): {str(e)}")
# Call original function with all args/kwargs # Execute the original async function with ALL parameters in the correct order
results = await original_map_node_over_list( results = await original_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
prompt_id, unique_id, obj, input_data_all, func,
allow_interrupt, execution_block_cb, pre_execute_cb, *args, **kwargs
)
# After execution, collect outputs for relevant nodes
if func == obj.FUNCTION and hasattr(obj, '__class__'): if func == obj.FUNCTION and hasattr(obj, '__class__'):
try: try:
# Get the current prompt_id from the registry
registry = MetadataRegistry() registry = MetadataRegistry()
if prompt_id is not None: if prompt_id is not None:
# Get node class type
class_type = obj.__class__.__name__ class_type = obj.__class__.__name__
# Use the passed unique_id parameter
node_id = unique_id node_id = unique_id
# Record outputs after execution
if node_id is not None: if node_id is not None:
registry.update_node_execution(node_id, class_type, results) registry.update_node_execution(node_id, class_type, results)
except Exception as e: except Exception as e:

View File

@@ -1,9 +1,9 @@
import json import json
import os import sys
from .constants import IMAGES from .constants import IMAGES
# Check if running in standalone mode # Check if running in standalone mode
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" standalone_mode = 'nodes' not in sys.modules
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IS_SAMPLER from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IS_SAMPLER
@@ -295,7 +295,7 @@ class MetadataProcessor:
"seed": None, "seed": None,
"steps": None, "steps": None,
"cfg_scale": None, "cfg_scale": None,
# "guidance": None, # Add guidance parameter "guidance": None, # Add guidance parameter
"sampler": None, "sampler": None,
"scheduler": None, "scheduler": None,
"checkpoint": None, "checkpoint": None,
@@ -339,8 +339,44 @@ class MetadataProcessor:
is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced" is_custom_advanced = prompt.original_prompt[primary_sampler_id].get("class_type") == "SamplerCustomAdvanced"
if is_custom_advanced: if is_custom_advanced:
# For SamplerCustomAdvanced, use the new handler method # For SamplerCustomAdvanced, trace specific inputs
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
# 1. Trace sigmas input to find BasicScheduler
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", "BasicScheduler", max_depth=5)
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
params["steps"] = scheduler_params.get("steps")
params["scheduler"] = scheduler_params.get("scheduler")
# 2. Trace sampler input to find KSamplerSelect
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
params["sampler"] = sampler_params.get("sampler_name")
# 3. Trace guider input for CFGGuider and CLIPTextEncode
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
if guider_node_id and guider_node_id in prompt.original_prompt:
# Check if the guider node is a CFGGuider
if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider":
# Extract cfg value from the CFGGuider
if guider_node_id in metadata.get(SAMPLING, {}):
cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {})
params["cfg_scale"] = cfg_params.get("cfg")
# Find CLIPTextEncode for positive prompt
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
# Find CLIPTextEncode for negative prompt
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10)
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
else:
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
else: else:
# For standard samplers, match conditioning objects to prompts # For standard samplers, match conditioning objects to prompts
@@ -366,9 +402,6 @@ class MetadataProcessor:
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}): if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "") params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
# For SamplerCustom, handle any additional parameters
MetadataProcessor.handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params)
# Size extraction is same for all sampler types # Size extraction is same for all sampler types
# Check if the sampler itself has size information (from latent_image) # Check if the sampler itself has size information (from latent_image)
if primary_sampler_id in metadata.get(SIZE, {}): if primary_sampler_id in metadata.get(SIZE, {}):
@@ -421,59 +454,3 @@ class MetadataProcessor:
"""Convert metadata to JSON string""" """Convert metadata to JSON string"""
params = MetadataProcessor.to_dict(metadata, id) params = MetadataProcessor.to_dict(metadata, id)
return json.dumps(params, indent=4) return json.dumps(params, indent=4)
@staticmethod
def handle_custom_advanced_sampler(metadata, prompt, primary_sampler_id, params):
"""
Handle parameter extraction for SamplerCustomAdvanced nodes
Parameters:
- metadata: The workflow metadata
- prompt: The prompt object containing node connections
- primary_sampler_id: ID of the SamplerCustomAdvanced node
- params: Parameters dictionary to update
"""
if not prompt.original_prompt or primary_sampler_id not in prompt.original_prompt:
return
sampler_inputs = prompt.original_prompt[primary_sampler_id].get("inputs", {})
# 1. Trace sigmas input to find BasicScheduler (only if sigmas input exists)
if "sigmas" in sampler_inputs:
scheduler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sigmas", None, max_depth=5)
if scheduler_node_id and scheduler_node_id in metadata.get(SAMPLING, {}):
scheduler_params = metadata[SAMPLING][scheduler_node_id].get("parameters", {})
params["steps"] = scheduler_params.get("steps")
params["scheduler"] = scheduler_params.get("scheduler")
# 2. Trace sampler input to find KSamplerSelect (only if sampler input exists)
if "sampler" in sampler_inputs:
sampler_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "sampler", "KSamplerSelect", max_depth=5)
if sampler_node_id and sampler_node_id in metadata.get(SAMPLING, {}):
sampler_params = metadata[SAMPLING][sampler_node_id].get("parameters", {})
params["sampler"] = sampler_params.get("sampler_name")
# 3. Trace guider input for CFGGuider and CLIPTextEncode
if "guider" in sampler_inputs:
guider_node_id = MetadataProcessor.trace_node_input(prompt, primary_sampler_id, "guider", max_depth=5)
if guider_node_id and guider_node_id in prompt.original_prompt:
# Check if the guider node is a CFGGuider
if prompt.original_prompt[guider_node_id].get("class_type") == "CFGGuider":
# Extract cfg value from the CFGGuider
if guider_node_id in metadata.get(SAMPLING, {}):
cfg_params = metadata[SAMPLING][guider_node_id].get("parameters", {})
params["cfg_scale"] = cfg_params.get("cfg")
# Find CLIPTextEncode for positive prompt
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "positive", "CLIPTextEncode", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")
# Find CLIPTextEncode for negative prompt
negative_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "negative", "CLIPTextEncode", max_depth=10)
if negative_node_id and negative_node_id in metadata.get(PROMPTS, {}):
params["negative_prompt"] = metadata[PROMPTS][negative_node_id].get("text", "")
else:
positive_node_id = MetadataProcessor.trace_node_input(prompt, guider_node_id, "conditioning", max_depth=10)
if positive_node_id and positive_node_id in metadata.get(PROMPTS, {}):
params["prompt"] = metadata[PROMPTS][positive_node_id].get("text", "")

View File

@@ -642,9 +642,7 @@ NODE_EXTRACTORS = {
# Sampling # Sampling
"KSampler": SamplerExtractor, "KSampler": SamplerExtractor,
"KSamplerAdvanced": KSamplerAdvancedExtractor, "KSamplerAdvanced": KSamplerAdvancedExtractor,
"SamplerCustom": KSamplerAdvancedExtractor,
"SamplerCustomAdvanced": SamplerCustomAdvancedExtractor, "SamplerCustomAdvanced": SamplerCustomAdvancedExtractor,
"ClownsharKSampler_Beta": SamplerExtractor,
"TSC_KSampler": TSCKSamplerExtractor, # Efficient Nodes "TSC_KSampler": TSCKSamplerExtractor, # Efficient Nodes
"TSC_KSamplerAdvanced": TSCKSamplerAdvancedExtractor, # Efficient Nodes "TSC_KSamplerAdvanced": TSCKSamplerAdvancedExtractor, # Efficient Nodes
"KSamplerBasicPipe": KSamplerBasicPipeExtractor, # comfyui-impact-pack "KSamplerBasicPipe": KSamplerBasicPipeExtractor, # comfyui-impact-pack
@@ -654,11 +652,9 @@ NODE_EXTRACTORS = {
# Sampling Selectors # Sampling Selectors
"KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect "KSamplerSelect": KSamplerSelectExtractor, # Add KSamplerSelect
"BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler "BasicScheduler": BasicSchedulerExtractor, # Add BasicScheduler
"AlignYourStepsScheduler": BasicSchedulerExtractor, # Add AlignYourStepsScheduler
# Loaders # Loaders
"CheckpointLoaderSimple": CheckpointLoaderExtractor, "CheckpointLoaderSimple": CheckpointLoaderExtractor,
"comfyLoader": CheckpointLoaderExtractor, # easy comfyLoader "comfyLoader": CheckpointLoaderExtractor, # easy comfyLoader
"CheckpointLoaderSimpleWithImages": CheckpointLoaderExtractor, # CheckpointLoader|pysssss
"TSC_EfficientLoader": TSCCheckpointLoaderExtractor, # Efficient Nodes "TSC_EfficientLoader": TSCCheckpointLoaderExtractor, # Efficient Nodes
"UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor "UNETLoader": UNETLoaderExtractor, # Updated to use dedicated extractor
"UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor "UnetLoaderGGUF": UNETLoaderExtractor, # Updated to use dedicated extractor
@@ -671,7 +667,6 @@ NODE_EXTRACTORS = {
"AdvancedCLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/BlenderNeko/ComfyUI_ADV_CLIP_emb "AdvancedCLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/BlenderNeko/ComfyUI_ADV_CLIP_emb
"smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes "smZ_CLIPTextEncode": CLIPTextEncodeExtractor, # From https://github.com/shiimizu/ComfyUI_smZNodes
"CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack "CR_ApplyControlNetStack": CR_ApplyControlNetStackExtractor, # Add CR_ApplyControlNetStack
"PCTextEncode": CLIPTextEncodeExtractor, # From https://github.com/asagi4/comfyui-prompt-control
# Latent # Latent
"EmptyLatentImage": ImageSizeExtractor, "EmptyLatentImage": ImageSizeExtractor,
# Flux # Flux

View File

@@ -1 +0,0 @@
"""Server middleware modules"""

View File

@@ -1,53 +0,0 @@
"""Cache control middleware for ComfyUI server"""
from aiohttp import web
from typing import Callable, Awaitable
# Time in seconds
ONE_HOUR: int = 3600
ONE_DAY: int = 86400
IMG_EXTENSIONS = (
".jpg",
".jpeg",
".png",
".ppm",
".bmp",
".pgm",
".tif",
".tiff",
".webp",
".mp4"
)
@web.middleware
async def cache_control(
request: web.Request, handler: Callable[[web.Request], Awaitable[web.Response]]
) -> web.Response:
"""Cache control middleware that sets appropriate cache headers based on file type and response status"""
response: web.Response = await handler(request)
if (
request.path.endswith(".js")
or request.path.endswith(".css")
or request.path.endswith("index.json")
):
response.headers.setdefault("Cache-Control", "no-cache")
return response
# Early return for non-image files - no cache headers needed
if not request.path.lower().endswith(IMG_EXTENSIONS):
return response
# Handle image files
if response.status == 404:
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_HOUR}")
elif response.status in (200, 201, 202, 203, 204, 205, 206, 301, 308):
# Success responses and permanent redirects - cache for 1 day
response.headers.setdefault("Cache-Control", f"public, max-age={ONE_DAY}")
elif response.status in (302, 303, 307):
# Temporary redirects - no cache
response.headers.setdefault("Cache-Control", "no-cache")
# Note: 304 Not Modified falls through - no cache headers set
return response

View File

@@ -1,5 +1,4 @@
import logging import logging
import re
from nodes import LoraLoader from nodes import LoraLoader
from comfy.comfy_types import IO # type: ignore from comfy.comfy_types import IO # type: ignore
from ..utils.utils import get_lora_info from ..utils.utils import get_lora_info
@@ -19,7 +18,6 @@ class LoraManagerLoader:
# "clip": ("CLIP",), # "clip": ("CLIP",),
"text": (IO.STRING, { "text": (IO.STRING, {
"multiline": True, "multiline": True,
"pysssss.autocomplete": False,
"dynamicPrompts": True, "dynamicPrompts": True,
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation", "tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
"placeholder": "LoRA syntax input: <lora:name:strength>" "placeholder": "LoRA syntax input: <lora:name:strength>"
@@ -111,144 +109,6 @@ class LoraManagerLoader:
# use ',, ' to separate trigger words for group mode # use ',, ' to separate trigger words for group mode
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else "" trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Format loaded_loras with support for both formats
formatted_loras = []
for item in loaded_loras:
parts = item.split(":")
lora_name = parts[0]
strength_parts = parts[1].strip().split(",")
if len(strength_parts) > 1:
# Different model and clip strengths
model_str = strength_parts[0].strip()
clip_str = strength_parts[1].strip()
formatted_loras.append(f"<lora:{lora_name}:{model_str}:{clip_str}>")
else:
# Same strength for both
model_str = strength_parts[0].strip()
formatted_loras.append(f"<lora:{lora_name}:{model_str}>")
formatted_loras_text = " ".join(formatted_loras)
return (model, clip, trigger_words_text, formatted_loras_text)
class LoraManagerTextLoader:
NAME = "LoRA Text Loader (LoraManager)"
CATEGORY = "Lora Manager/loaders"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL",),
"lora_syntax": (IO.STRING, {
"defaultInput": True,
"forceInput": True,
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation"
}),
},
"optional": {
"clip": ("CLIP",),
"lora_stack": ("LORA_STACK",),
}
}
RETURN_TYPES = ("MODEL", "CLIP", IO.STRING, IO.STRING)
RETURN_NAMES = ("MODEL", "CLIP", "trigger_words", "loaded_loras")
FUNCTION = "load_loras_from_text"
def parse_lora_syntax(self, text):
"""Parse LoRA syntax from text input."""
# Pattern to match <lora:name:strength> or <lora:name:model_strength:clip_strength>
pattern = r'<lora:([^:>]+):([^:>]+)(?::([^:>]+))?>'
matches = re.findall(pattern, text, re.IGNORECASE)
loras = []
for match in matches:
lora_name = match[0]
model_strength = float(match[1])
clip_strength = float(match[2]) if match[2] else model_strength
loras.append({
'name': lora_name,
'model_strength': model_strength,
'clip_strength': clip_strength
})
return loras
def load_loras_from_text(self, model, lora_syntax, clip=None, lora_stack=None):
"""Load LoRAs based on text syntax input."""
loaded_loras = []
all_trigger_words = []
# Check if model is a Nunchaku Flux model - simplified approach
is_nunchaku_model = False
try:
model_wrapper = model.model.diffusion_model
# Check if model is a Nunchaku Flux model using only class name
if model_wrapper.__class__.__name__ == "ComfyFluxWrapper":
is_nunchaku_model = True
logger.info("Detected Nunchaku Flux model")
except (AttributeError, TypeError):
# Not a model with the expected structure
pass
# First process lora_stack if available
if lora_stack:
for lora_path, model_strength, clip_strength in lora_stack:
# Apply the LoRA using the appropriate loader
if is_nunchaku_model:
# Use our custom function for Flux models
model = nunchaku_load_lora(model, lora_path, model_strength)
# clip remains unchanged for Nunchaku models
else:
# Use default loader for standard models
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Extract lora name for trigger words lookup
lora_name = extract_lora_name(lora_path)
_, trigger_words = get_lora_info(lora_name)
all_trigger_words.extend(trigger_words)
# Add clip strength to output if different from model strength (except for Nunchaku models)
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
else:
loaded_loras.append(f"{lora_name}: {model_strength}")
# Parse and process LoRAs from text syntax
parsed_loras = self.parse_lora_syntax(lora_syntax)
for lora in parsed_loras:
lora_name = lora['name']
model_strength = lora['model_strength']
clip_strength = lora['clip_strength']
# Get lora path and trigger words
lora_path, trigger_words = get_lora_info(lora_name)
# Apply the LoRA using the appropriate loader
if is_nunchaku_model:
# For Nunchaku models, use our custom function
model = nunchaku_load_lora(model, lora_path, model_strength)
# clip remains unchanged
else:
# Use default loader for standard models
model, clip = LoraLoader().load_lora(model, clip, lora_path, model_strength, clip_strength)
# Include clip strength in output if different from model strength and not a Nunchaku model
if not is_nunchaku_model and abs(model_strength - clip_strength) > 0.001:
loaded_loras.append(f"{lora_name}: {model_strength},{clip_strength}")
else:
loaded_loras.append(f"{lora_name}: {model_strength}")
# Add trigger words to collection
all_trigger_words.extend(trigger_words)
# use ',, ' to separate trigger words for group mode
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
# Format loaded_loras with support for both formats # Format loaded_loras with support for both formats
formatted_loras = [] formatted_loras = []
for item in loaded_loras: for item in loaded_loras:

View File

@@ -17,7 +17,6 @@ class LoraStacker:
"required": { "required": {
"text": (IO.STRING, { "text": (IO.STRING, {
"multiline": True, "multiline": True,
"pysssss.autocomplete": False,
"dynamicPrompts": True, "dynamicPrompts": True,
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation", "tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
"placeholder": "LoRA syntax input: <lora:name:strength>" "placeholder": "LoRA syntax input: <lora:name:strength>"

View File

@@ -1,5 +1,6 @@
import json import json
import os import os
import asyncio
import re import re
import numpy as np import numpy as np
import folder_paths # type: ignore import folder_paths # type: ignore
@@ -418,15 +419,11 @@ class SaveImage:
# Make sure the output directory exists # Make sure the output directory exists
os.makedirs(self.output_dir, exist_ok=True) os.makedirs(self.output_dir, exist_ok=True)
# If images is already a list or array of images, do nothing; otherwise, convert to list # Ensure images is always a list of images
if isinstance(images, (list, np.ndarray)): if len(images.shape) == 3: # Single image (height, width, channels)
pass images = [images]
else: else: # Multiple images (batch, height, width, channels)
# Ensure images is always a list of images images = [img for img in images]
if len(images.shape) == 3: # Single image (height, width, channels)
images = [images]
else: # Multiple images (batch, height, width, channels)
images = [img for img in images]
# Save all images # Save all images
results = self.save_images( results = self.save_images(

View File

@@ -14,11 +14,9 @@ class WanVideoLoraSelect:
def INPUT_TYPES(cls): def INPUT_TYPES(cls):
return { return {
"required": { "required": {
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load LORA models with less VRAM usage, slower loading. This affects ALL LoRAs, not just the current ones. No effect if merge_loras is False"}), "low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load the LORA model with less VRAM usage, slower loading"}),
"merge_loras": ("BOOLEAN", {"default": True, "tooltip": "Merge LoRAs into the model, otherwise they are loaded on the fly. Always disabled for GGUF and scaled fp8 models. This affects ALL LoRAs, not just the current one"}),
"text": (IO.STRING, { "text": (IO.STRING, {
"multiline": True, "multiline": True,
"pysssss.autocomplete": False,
"dynamicPrompts": True, "dynamicPrompts": True,
"tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation", "tooltip": "Format: <lora:lora_name:strength> separated by spaces or punctuation",
"placeholder": "LoRA syntax input: <lora:name:strength>" "placeholder": "LoRA syntax input: <lora:name:strength>"
@@ -31,7 +29,7 @@ class WanVideoLoraSelect:
RETURN_NAMES = ("lora", "trigger_words", "active_loras") RETURN_NAMES = ("lora", "trigger_words", "active_loras")
FUNCTION = "process_loras" FUNCTION = "process_loras"
def process_loras(self, text, low_mem_load=False, merge_loras=True, **kwargs): def process_loras(self, text, low_mem_load=False, **kwargs):
loras_list = [] loras_list = []
all_trigger_words = [] all_trigger_words = []
active_loras = [] active_loras = []
@@ -41,9 +39,6 @@ class WanVideoLoraSelect:
if prev_lora is not None: if prev_lora is not None:
loras_list.extend(prev_lora) loras_list.extend(prev_lora)
if not merge_loras:
low_mem_load = False # Unmerged LoRAs don't need low_mem_load
# Get blocks if available # Get blocks if available
blocks = kwargs.get('blocks', {}) blocks = kwargs.get('blocks', {})
selected_blocks = blocks.get("selected_blocks", {}) selected_blocks = blocks.get("selected_blocks", {})
@@ -70,7 +65,6 @@ class WanVideoLoraSelect:
"blocks": selected_blocks, "blocks": selected_blocks,
"layer_filter": layer_filter, "layer_filter": layer_filter,
"low_mem_load": low_mem_load, "low_mem_load": low_mem_load,
"merge_loras": merge_loras,
} }
# Add to list and collect active loras # Add to list and collect active loras

View File

@@ -1,127 +0,0 @@
from comfy.comfy_types import IO
import folder_paths
from ..utils.utils import get_lora_info
from .utils import any_type
import logging
# 初始化日志记录器
logger = logging.getLogger(__name__)
# 定义新节点的类
class WanVideoLoraSelectFromText:
# 节点在UI中显示的名称
NAME = "WanVideo Lora Select From Text (LoraManager)"
# 节点所属的分类
CATEGORY = "Lora Manager/stackers"
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"low_mem_load": ("BOOLEAN", {"default": False, "tooltip": "Load LORA models with less VRAM usage, slower loading. This affects ALL LoRAs, not just the current ones. No effect if merge_loras is False"}),
"merge_lora": ("BOOLEAN", {"default": True, "tooltip": "Merge LoRAs into the model, otherwise they are loaded on the fly. Always disabled for GGUF and scaled fp8 models. This affects ALL LoRAs, not just the current one"}),
"lora_syntax": (IO.STRING, {
"multiline": True,
"defaultInput": True,
"forceInput": True,
"tooltip": "Connect a TEXT output for LoRA syntax: <lora:name:strength>"
}),
},
"optional": {
"prev_lora": ("WANVIDLORA",),
"blocks": ("BLOCKS",)
}
}
RETURN_TYPES = ("WANVIDLORA", IO.STRING, IO.STRING)
RETURN_NAMES = ("lora", "trigger_words", "active_loras")
FUNCTION = "process_loras_from_syntax"
def process_loras_from_syntax(self, lora_syntax, low_mem_load=False, merge_lora=True, **kwargs):
text_to_process = lora_syntax
blocks = kwargs.get('blocks', {})
selected_blocks = blocks.get("selected_blocks", {})
layer_filter = blocks.get("layer_filter", "")
loras_list = []
all_trigger_words = []
active_loras = []
prev_lora = kwargs.get('prev_lora', None)
if prev_lora is not None:
loras_list.extend(prev_lora)
if not merge_lora:
low_mem_load = False
parts = text_to_process.split('<lora:')
for part in parts[1:]:
end_index = part.find('>')
if end_index == -1:
continue
content = part[:end_index]
lora_parts = content.split(':')
lora_name_raw = ""
model_strength = 1.0
clip_strength = 1.0
if len(lora_parts) == 2:
lora_name_raw = lora_parts[0].strip()
try:
model_strength = float(lora_parts[1])
clip_strength = model_strength
except (ValueError, IndexError):
logger.warning(f"Invalid strength for LoRA '{lora_name_raw}'. Skipping.")
continue
elif len(lora_parts) >= 3:
lora_name_raw = lora_parts[0].strip()
try:
model_strength = float(lora_parts[1])
clip_strength = float(lora_parts[2])
except (ValueError, IndexError):
logger.warning(f"Invalid strengths for LoRA '{lora_name_raw}'. Skipping.")
continue
else:
continue
lora_path, trigger_words = get_lora_info(lora_name_raw)
lora_item = {
"path": folder_paths.get_full_path("loras", lora_path),
"strength": model_strength,
"name": lora_path.split(".")[0],
"blocks": selected_blocks,
"layer_filter": layer_filter,
"low_mem_load": low_mem_load,
"merge_loras": merge_lora,
}
loras_list.append(lora_item)
active_loras.append((lora_name_raw, model_strength, clip_strength))
all_trigger_words.extend(trigger_words)
trigger_words_text = ",, ".join(all_trigger_words) if all_trigger_words else ""
formatted_loras = []
for name, model_strength, clip_strength in active_loras:
if abs(model_strength - clip_strength) > 0.001:
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}:{str(clip_strength).strip()}>")
else:
formatted_loras.append(f"<lora:{name}:{str(model_strength).strip()}>")
active_loras_text = " ".join(formatted_loras)
return (loras_list, trigger_words_text, active_loras_text)
NODE_CLASS_MAPPINGS = {
"WanVideoLoraSelectFromText": WanVideoLoraSelectFromText
}
NODE_DISPLAY_NAME_MAPPINGS = {
"WanVideoLoraSelectFromText": "WanVideo Lora Select From Text (LoraManager)"
}

View File

@@ -55,7 +55,7 @@ class RecipeMetadataParser(ABC):
# Unpack the tuple to get the actual data # Unpack the tuple to get the actual data
civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None) civitai_info, error_msg = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
if not civitai_info or error_msg == "Model not found": if not civitai_info or civitai_info.get("error") == "Model not found":
# Model not found or deleted # Model not found or deleted
lora_entry['isDeleted'] = True lora_entry['isDeleted'] = True
lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png' lora_entry['thumbnailUrl'] = '/loras_static/images/no-preview.png'
@@ -119,10 +119,10 @@ class RecipeMetadataParser(ABC):
# Check if exists locally # Check if exists locally
if recipe_scanner and lora_entry['hash']: if recipe_scanner and lora_entry['hash']:
lora_scanner = recipe_scanner._lora_scanner lora_scanner = recipe_scanner._lora_scanner
exists_locally = lora_scanner.has_hash(lora_entry['hash']) exists_locally = lora_scanner.has_lora_hash(lora_entry['hash'])
if exists_locally: if exists_locally:
try: try:
local_path = lora_scanner.get_path_by_hash(lora_entry['hash']) local_path = lora_scanner.get_lora_path_by_hash(lora_entry['hash'])
lora_entry['existsLocally'] = True lora_entry['existsLocally'] = True
lora_entry['localPath'] = local_path lora_entry['localPath'] = local_path
lora_entry['file_name'] = os.path.splitext(os.path.basename(local_path))[0] lora_entry['file_name'] = os.path.splitext(os.path.basename(local_path))[0]

View File

@@ -6,7 +6,6 @@ import logging
from typing import Dict, Any from typing import Dict, Any
from ..base import RecipeMetadataParser from ..base import RecipeMetadataParser
from ..constants import GEN_PARAM_KEYS from ..constants import GEN_PARAM_KEYS
from ...services.metadata_service import get_default_metadata_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -31,9 +30,6 @@ class AutomaticMetadataParser(RecipeMetadataParser):
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]: async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
"""Parse metadata from Automatic1111 format""" """Parse metadata from Automatic1111 format"""
try: try:
# Get metadata provider instead of using civitai_client directly
metadata_provider = await get_default_metadata_provider()
# Split on Negative prompt if it exists # Split on Negative prompt if it exists
if "Negative prompt:" in user_comment: if "Negative prompt:" in user_comment:
parts = user_comment.split('Negative prompt:', 1) parts = user_comment.split('Negative prompt:', 1)
@@ -185,30 +181,13 @@ class AutomaticMetadataParser(RecipeMetadataParser):
# First use Civitai resources if available (more reliable source) # First use Civitai resources if available (more reliable source)
if metadata.get("civitai_resources"): if metadata.get("civitai_resources"):
for resource in metadata.get("civitai_resources", []): for resource in metadata.get("civitai_resources", []):
# --- Added: Parse 'air' field if present ---
air = resource.get("air")
if air:
# Format: urn:air:sdxl:lora:civitai:1221007@1375651
# Or: urn:air:sdxl:checkpoint:civitai:623891@2019115
air_pattern = r"urn:air:[^:]+:(?P<type>[^:]+):civitai:(?P<modelId>\d+)@(?P<modelVersionId>\d+)"
air_match = re.match(air_pattern, air)
if air_match:
air_type = air_match.group("type")
air_modelId = int(air_match.group("modelId"))
air_modelVersionId = int(air_match.group("modelVersionId"))
# checkpoint/lycoris/lora/hypernet
resource["type"] = air_type
resource["modelId"] = air_modelId
resource["modelVersionId"] = air_modelVersionId
# --- End added ---
if resource.get("type") in ["lora", "lycoris", "hypernet"] and resource.get("modelVersionId"): if resource.get("type") in ["lora", "lycoris", "hypernet"] and resource.get("modelVersionId"):
# Initialize lora entry # Initialize lora entry
lora_entry = { lora_entry = {
'id': resource.get("modelVersionId", 0), 'id': resource.get("modelVersionId", 0),
'modelId': resource.get("modelId", 0), 'modelId': resource.get("modelId", 0),
'name': resource.get("modelName", "Unknown LoRA"), 'name': resource.get("modelName", "Unknown LoRA"),
'version': resource.get("modelVersionName", resource.get("versionName", "")), 'version': resource.get("modelVersionName", ""),
'type': resource.get("type", "lora"), 'type': resource.get("type", "lora"),
'weight': round(float(resource.get("weight", 1.0)), 2), 'weight': round(float(resource.get("weight", 1.0)), 2),
'existsLocally': False, 'existsLocally': False,
@@ -220,9 +199,9 @@ class AutomaticMetadataParser(RecipeMetadataParser):
} }
# Get additional info from Civitai # Get additional info from Civitai
if metadata_provider: if civitai_client:
try: try:
civitai_info = await metadata_provider.get_model_version_info(resource.get("modelVersionId")) civitai_info = await civitai_client.get_model_version_info(resource.get("modelVersionId"))
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
civitai_info, civitai_info,
@@ -275,11 +254,11 @@ class AutomaticMetadataParser(RecipeMetadataParser):
} }
# Try to get info from Civitai # Try to get info from Civitai
if metadata_provider: if civitai_client:
try: try:
if lora_hash: if lora_hash:
# If we have hash, use it for lookup # If we have hash, use it for lookup
civitai_info = await metadata_provider.get_model_by_hash(lora_hash) civitai_info = await civitai_client.get_model_by_hash(lora_hash)
else: else:
civitai_info = None civitai_info = None

View File

@@ -5,7 +5,6 @@ import logging
from typing import Dict, Any, Union from typing import Dict, Any, Union
from ..base import RecipeMetadataParser from ..base import RecipeMetadataParser
from ..constants import GEN_PARAM_KEYS from ..constants import GEN_PARAM_KEYS
from ...services.metadata_service import get_default_metadata_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -37,15 +36,12 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
Args: Args:
metadata: The metadata from the image (dict) metadata: The metadata from the image (dict)
recipe_scanner: Optional recipe scanner service recipe_scanner: Optional recipe scanner service
civitai_client: Optional Civitai API client (deprecated, use metadata_provider instead) civitai_client: Optional Civitai API client
Returns: Returns:
Dict containing parsed recipe data Dict containing parsed recipe data
""" """
try: try:
# Get metadata provider instead of using civitai_client directly
metadata_provider = await get_default_metadata_provider()
# Initialize result structure # Initialize result structure
result = { result = {
'base_model': None, 'base_model': None,
@@ -57,14 +53,6 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Track already added LoRAs to prevent duplicates # Track already added LoRAs to prevent duplicates
added_loras = {} # key: model_version_id or hash, value: index in result["loras"] added_loras = {} # key: model_version_id or hash, value: index in result["loras"]
# Extract hash information from hashes field for LoRA matching
lora_hashes = {}
if "hashes" in metadata and isinstance(metadata["hashes"], dict):
for key, hash_value in metadata["hashes"].items():
if key.startswith("LORA:"):
lora_name = key.replace("LORA:", "")
lora_hashes[lora_name] = hash_value
# Extract prompt and negative prompt # Extract prompt and negative prompt
if "prompt" in metadata: if "prompt" in metadata:
result["gen_params"]["prompt"] = metadata["prompt"] result["gen_params"]["prompt"] = metadata["prompt"]
@@ -89,9 +77,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Extract base model information - directly if available # Extract base model information - directly if available
if "baseModel" in metadata: if "baseModel" in metadata:
result["base_model"] = metadata["baseModel"] result["base_model"] = metadata["baseModel"]
elif "Model hash" in metadata and metadata_provider: elif "Model hash" in metadata and civitai_client:
model_hash = metadata["Model hash"] model_hash = metadata["Model hash"]
model_info, error = await metadata_provider.get_model_by_hash(model_hash) model_info = await civitai_client.get_model_by_hash(model_hash)
if model_info: if model_info:
result["base_model"] = model_info.get("baseModel", "") result["base_model"] = model_info.get("baseModel", "")
elif "Model" in metadata and isinstance(metadata.get("resources"), list): elif "Model" in metadata and isinstance(metadata.get("resources"), list):
@@ -99,8 +87,8 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
for resource in metadata.get("resources", []): for resource in metadata.get("resources", []):
if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"): if resource.get("type") == "model" and resource.get("name") == metadata.get("Model"):
# This is likely the checkpoint model # This is likely the checkpoint model
if metadata_provider and resource.get("hash"): if civitai_client and resource.get("hash"):
model_info, error = await metadata_provider.get_model_by_hash(resource.get("hash")) model_info = await civitai_client.get_model_by_hash(resource.get("hash"))
if model_info: if model_info:
result["base_model"] = model_info.get("baseModel", "") result["base_model"] = model_info.get("baseModel", "")
@@ -113,15 +101,6 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
if resource.get("type", "lora") == "lora": if resource.get("type", "lora") == "lora":
lora_hash = resource.get("hash", "") lora_hash = resource.get("hash", "")
# Try to get hash from the hashes field if not present in resource
if not lora_hash and resource.get("name"):
lora_hash = lora_hashes.get(resource["name"], "")
# Skip LoRAs without proper identification (hash or modelVersionId)
if not lora_hash and not resource.get("modelVersionId"):
logger.debug(f"Skipping LoRA resource '{resource.get('name', 'Unknown')}' - no hash or modelVersionId")
continue
# Skip if we've already added this LoRA by hash # Skip if we've already added this LoRA by hash
if lora_hash and lora_hash in added_loras: if lora_hash and lora_hash in added_loras:
continue continue
@@ -142,9 +121,9 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
} }
# Try to get info from Civitai if hash is available # Try to get info from Civitai if hash is available
if lora_entry['hash'] and metadata_provider: if lora_entry['hash'] and civitai_client:
try: try:
civitai_info = await metadata_provider.get_model_by_hash(lora_hash) civitai_info = await civitai_client.get_model_by_hash(lora_hash)
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
@@ -174,6 +153,10 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
# Process civitaiResources array # Process civitaiResources array
if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list): if "civitaiResources" in metadata and isinstance(metadata["civitaiResources"], list):
for resource in metadata["civitaiResources"]: for resource in metadata["civitaiResources"]:
# Skip resources that aren't LoRAs or LyCORIS
if resource.get("type") not in ["lora", "lycoris"] and "type" not in resource:
continue
# Get unique identifier for deduplication # Get unique identifier for deduplication
version_id = str(resource.get("modelVersionId", "")) version_id = str(resource.get("modelVersionId", ""))
@@ -198,10 +181,14 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
} }
# Try to get info from Civitai if modelVersionId is available # Try to get info from Civitai if modelVersionId is available
if version_id and metadata_provider: if version_id and civitai_client:
try: try:
# Use get_model_version_info instead of get_model_version # Use get_model_version_info instead of get_model_version
civitai_info = await metadata_provider.get_model_version_info(version_id) civitai_info, error = await civitai_client.get_model_version_info(version_id)
if error:
logger.warning(f"Error getting model version info: {error}")
continue
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
@@ -259,92 +246,35 @@ class CivitaiApiMetadataParser(RecipeMetadataParser):
'isDeleted': False 'isDeleted': False
} }
# If we have a version ID and metadata provider, try to get more info # If we have a version ID and civitai client, try to get more info
if version_id and metadata_provider: if version_id and civitai_client:
try: try:
# Use get_model_version_info with the version ID # Use get_model_version_info with the version ID
civitai_info = await metadata_provider.get_model_version_info(version_id) civitai_info, error = await civitai_client.get_model_version_info(version_id)
populated_entry = await self.populate_lora_from_civitai( if error:
lora_entry, logger.warning(f"Error getting model version info: {error}")
civitai_info, else:
recipe_scanner, populated_entry = await self.populate_lora_from_civitai(
base_model_counts lora_entry,
) civitai_info,
recipe_scanner,
base_model_counts
)
if populated_entry is None: if populated_entry is None:
continue # Skip invalid LoRA types continue # Skip invalid LoRA types
lora_entry = populated_entry lora_entry = populated_entry
# Track this LoRA for deduplication # Track this LoRA for deduplication
if version_id: if version_id:
added_loras[version_id] = len(result["loras"]) added_loras[version_id] = len(result["loras"])
except Exception as e: except Exception as e:
logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}") logger.error(f"Error fetching Civitai info for model ID {version_id}: {e}")
result["loras"].append(lora_entry) result["loras"].append(lora_entry)
# Check for LoRA info in the format "Lora_0 Model hash", "Lora_0 Model name", etc.
lora_index = 0
while f"Lora_{lora_index} Model hash" in metadata and f"Lora_{lora_index} Model name" in metadata:
lora_hash = metadata[f"Lora_{lora_index} Model hash"]
lora_name = metadata[f"Lora_{lora_index} Model name"]
lora_strength_model = float(metadata.get(f"Lora_{lora_index} Strength model", 1.0))
# Skip if we've already added this LoRA by hash
if lora_hash and lora_hash in added_loras:
lora_index += 1
continue
lora_entry = {
'name': lora_name,
'type': "lora",
'weight': lora_strength_model,
'hash': lora_hash,
'existsLocally': False,
'localPath': None,
'file_name': lora_name,
'thumbnailUrl': '/loras_static/images/no-preview.png',
'baseModel': '',
'size': 0,
'downloadUrl': '',
'isDeleted': False
}
# Try to get info from Civitai if hash is available
if lora_entry['hash'] and metadata_provider:
try:
civitai_info = await metadata_provider.get_model_by_hash(lora_hash)
populated_entry = await self.populate_lora_from_civitai(
lora_entry,
civitai_info,
recipe_scanner,
base_model_counts,
lora_hash
)
if populated_entry is None:
lora_index += 1
continue # Skip invalid LoRA types
lora_entry = populated_entry
# If we have a version ID from Civitai, track it for deduplication
if 'id' in lora_entry and lora_entry['id']:
added_loras[str(lora_entry['id'])] = len(result["loras"])
except Exception as e:
logger.error(f"Error fetching Civitai info for LoRA hash {lora_entry['hash']}: {e}")
# Track by hash if we have it
if lora_hash:
added_loras[lora_hash] = len(result["loras"])
result["loras"].append(lora_entry)
lora_index += 1
# If base model wasn't found earlier, use the most common one from LoRAs # If base model wasn't found earlier, use the most common one from LoRAs
if not result["base_model"] and base_model_counts: if not result["base_model"] and base_model_counts:
result["base_model"] = max(base_model_counts.items(), key=lambda x: x[1])[0] result["base_model"] = max(base_model_counts.items(), key=lambda x: x[1])[0]

View File

@@ -6,7 +6,6 @@ import logging
from typing import Dict, Any from typing import Dict, Any
from ..base import RecipeMetadataParser from ..base import RecipeMetadataParser
from ..constants import GEN_PARAM_KEYS from ..constants import GEN_PARAM_KEYS
from ...services.metadata_service import get_default_metadata_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -27,9 +26,6 @@ class ComfyMetadataParser(RecipeMetadataParser):
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]: async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
"""Parse metadata from Civitai ComfyUI metadata format""" """Parse metadata from Civitai ComfyUI metadata format"""
try: try:
# Get metadata provider instead of using civitai_client directly
metadata_provider = await get_default_metadata_provider()
data = json.loads(user_comment) data = json.loads(user_comment)
loras = [] loras = []
@@ -77,10 +73,10 @@ class ComfyMetadataParser(RecipeMetadataParser):
'isDeleted': False 'isDeleted': False
} }
# Get additional info from Civitai if metadata provider is available # Get additional info from Civitai if client is available
if metadata_provider: if civitai_client:
try: try:
civitai_info_tuple = await metadata_provider.get_model_version_info(model_version_id) civitai_info_tuple = await civitai_client.get_model_version_info(model_version_id)
# Populate lora entry with Civitai info # Populate lora entry with Civitai info
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,
@@ -120,9 +116,9 @@ class ComfyMetadataParser(RecipeMetadataParser):
} }
# Get additional checkpoint info from Civitai # Get additional checkpoint info from Civitai
if metadata_provider: if civitai_client:
try: try:
civitai_info_tuple = await metadata_provider.get_model_version_info(checkpoint_version_id) civitai_info_tuple = await civitai_client.get_model_version_info(checkpoint_version_id)
civitai_info, _ = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None) civitai_info, _ = civitai_info_tuple if isinstance(civitai_info_tuple, tuple) else (civitai_info_tuple, None)
# Populate checkpoint with Civitai info # Populate checkpoint with Civitai info
checkpoint = await self.populate_checkpoint_from_civitai(checkpoint, civitai_info) checkpoint = await self.populate_checkpoint_from_civitai(checkpoint, civitai_info)

View File

@@ -5,7 +5,6 @@ import logging
from typing import Dict, Any from typing import Dict, Any
from ..base import RecipeMetadataParser from ..base import RecipeMetadataParser
from ..constants import GEN_PARAM_KEYS from ..constants import GEN_PARAM_KEYS
from ...services.metadata_service import get_default_metadata_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -19,11 +18,8 @@ class MetaFormatParser(RecipeMetadataParser):
return re.search(self.METADATA_MARKER, user_comment, re.IGNORECASE | re.DOTALL) is not None return re.search(self.METADATA_MARKER, user_comment, re.IGNORECASE | re.DOTALL) is not None
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]: async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
"""Parse metadata from images with meta format metadata (Lora_N Model hash format)""" """Parse metadata from images with meta format metadata"""
try: try:
# Get metadata provider instead of using civitai_client directly
metadata_provider = await get_default_metadata_provider()
# Extract prompt and negative prompt # Extract prompt and negative prompt
parts = user_comment.split('Negative prompt:', 1) parts = user_comment.split('Negative prompt:', 1)
prompt = parts[0].strip() prompt = parts[0].strip()
@@ -126,9 +122,9 @@ class MetaFormatParser(RecipeMetadataParser):
} }
# Get info from Civitai by hash if available # Get info from Civitai by hash if available
if metadata_provider and hash_value: if civitai_client and hash_value:
try: try:
civitai_info = await metadata_provider.get_model_by_hash(hash_value) civitai_info = await civitai_client.get_model_by_hash(hash_value)
# Populate lora entry with Civitai info # Populate lora entry with Civitai info
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,

View File

@@ -7,7 +7,6 @@ from typing import Dict, Any
from ...config import config from ...config import config
from ..base import RecipeMetadataParser from ..base import RecipeMetadataParser
from ..constants import GEN_PARAM_KEYS from ..constants import GEN_PARAM_KEYS
from ...services.metadata_service import get_default_metadata_provider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -24,9 +23,6 @@ class RecipeFormatParser(RecipeMetadataParser):
async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]: async def parse_metadata(self, user_comment: str, recipe_scanner=None, civitai_client=None) -> Dict[str, Any]:
"""Parse metadata from images with dedicated recipe metadata format""" """Parse metadata from images with dedicated recipe metadata format"""
try: try:
# Get metadata provider instead of using civitai_client directly
metadata_provider = await get_default_metadata_provider()
# Extract recipe metadata from user comment # Extract recipe metadata from user comment
try: try:
# Look for recipe metadata section # Look for recipe metadata section
@@ -59,7 +55,7 @@ class RecipeFormatParser(RecipeMetadataParser):
# Check if this LoRA exists locally by SHA256 hash # Check if this LoRA exists locally by SHA256 hash
if lora.get('hash') and recipe_scanner: if lora.get('hash') and recipe_scanner:
lora_scanner = recipe_scanner._lora_scanner lora_scanner = recipe_scanner._lora_scanner
exists_locally = lora_scanner.has_hash(lora['hash']) exists_locally = lora_scanner.has_lora_hash(lora['hash'])
if exists_locally: if exists_locally:
lora_cache = await lora_scanner.get_cached_data() lora_cache = await lora_scanner.get_cached_data()
lora_item = next((item for item in lora_cache.raw_data if item['sha256'].lower() == lora['hash'].lower()), None) lora_item = next((item for item in lora_cache.raw_data if item['sha256'].lower() == lora['hash'].lower()), None)
@@ -75,9 +71,9 @@ class RecipeFormatParser(RecipeMetadataParser):
lora_entry['localPath'] = None lora_entry['localPath'] = None
# Try to get additional info from Civitai if we have a model version ID # Try to get additional info from Civitai if we have a model version ID
if lora.get('modelVersionId') and metadata_provider: if lora.get('modelVersionId') and civitai_client:
try: try:
civitai_info_tuple = await metadata_provider.get_model_version_info(lora['modelVersionId']) civitai_info_tuple = await civitai_client.get_model_version_info(lora['modelVersionId'])
# Populate lora entry with Civitai info # Populate lora entry with Civitai info
populated_entry = await self.populate_lora_from_civitai( populated_entry = await self.populate_lora_from_civitai(
lora_entry, lora_entry,

View File

@@ -1,275 +1,619 @@
from __future__ import annotations
import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, Dict, Mapping import asyncio
import json
import logging
from aiohttp import web
from typing import Dict
import jinja2 import jinja2
from aiohttp import web
from ..utils.routes_common import ModelRouteUtils
from ..services.websocket_manager import ws_manager
from ..services.settings_manager import settings
from ..config import config from ..config import config
from ..services.download_coordinator import DownloadCoordinator
from ..services.downloader import get_downloader
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
from ..services.metadata_sync_service import MetadataSyncService
from ..services.model_file_service import ModelFileService, ModelMoveService
from ..services.model_lifecycle_service import ModelLifecycleService
from ..services.preview_asset_service import PreviewAssetService
from ..services.server_i18n import server_i18n as default_server_i18n
from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings as default_settings
from ..services.tag_update_service import TagUpdateService
from ..services.websocket_manager import ws_manager as default_ws_manager
from ..services.use_cases import (
AutoOrganizeUseCase,
BulkMetadataRefreshUseCase,
DownloadModelUseCase,
)
from ..services.websocket_progress_callback import (
WebSocketBroadcastCallback,
WebSocketProgressCallback,
)
from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager
from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar
from .handlers.model_handlers import (
ModelAutoOrganizeHandler,
ModelCivitaiHandler,
ModelDownloadHandler,
ModelHandlerSet,
ModelListingHandler,
ModelManagementHandler,
ModelMoveHandler,
ModelPageView,
ModelQueryHandler,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BaseModelRoutes(ABC): class BaseModelRoutes(ABC):
"""Base route controller for all model types.""" """Base route controller for all model types"""
template_name: str | None = None def __init__(self, service):
"""Initialize the route controller
def __init__( Args:
self, service: Model service instance (LoraService, CheckpointService, etc.)
service=None, """
*,
settings_service=default_settings,
ws_manager=default_ws_manager,
server_i18n=default_server_i18n,
metadata_provider_factory=get_default_metadata_provider,
) -> None:
self.service = None
self.model_type = ""
self._settings = settings_service
self._ws_manager = ws_manager
self._server_i18n = server_i18n
self._metadata_provider_factory = metadata_provider_factory
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True,
)
self.model_file_service: ModelFileService | None = None
self.model_move_service: ModelMoveService | None = None
self.model_lifecycle_service: ModelLifecycleService | None = None
self.websocket_progress_callback = WebSocketProgressCallback()
self.metadata_progress_callback = WebSocketBroadcastCallback()
self._handler_set: ModelHandlerSet | None = None
self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None
self._preview_service = PreviewAssetService(
metadata_manager=MetadataManager,
downloader_factory=get_downloader,
exif_utils=ExifUtils,
)
self._metadata_sync_service = MetadataSyncService(
metadata_manager=MetadataManager,
preview_service=self._preview_service,
settings=settings_service,
default_metadata_provider_factory=metadata_provider_factory,
metadata_provider_selector=get_metadata_provider,
)
self._tag_update_service = TagUpdateService(metadata_manager=MetadataManager)
self._download_coordinator = DownloadCoordinator(
ws_manager=self._ws_manager,
download_manager_factory=ServiceRegistry.get_download_manager,
)
if service is not None:
self.attach_service(service)
def attach_service(self, service) -> None:
"""Attach a model service and rebuild handler dependencies."""
self.service = service self.service = service
self.model_type = service.model_type self.model_type = service.model_type
self.model_file_service = ModelFileService(service.scanner, service.model_type) self.template_env = jinja2.Environment(
self.model_move_service = ModelMoveService(service.scanner) loader=jinja2.FileSystemLoader(config.templates_path),
self.model_lifecycle_service = ModelLifecycleService( autoescape=True
scanner=service.scanner,
metadata_manager=MetadataManager,
metadata_loader=self._metadata_sync_service.load_local_metadata,
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
)
self._handler_set = None
self._handler_mapping = None
def _ensure_handler_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
if self._handler_mapping is None:
handler_set = self._create_handler_set()
self._handler_set = handler_set
self._handler_mapping = handler_set.to_route_mapping()
return self._handler_mapping
def _create_handler_set(self) -> ModelHandlerSet:
service = self._ensure_service()
page_view = ModelPageView(
template_env=self.template_env,
template_name=self.template_name or "",
service=service,
settings_service=self._settings,
server_i18n=self._server_i18n,
logger=logger,
)
listing = ModelListingHandler(
service=service,
parse_specific_params=self._parse_specific_params,
logger=logger,
)
management = ModelManagementHandler(
service=service,
logger=logger,
metadata_sync=self._metadata_sync_service,
preview_service=self._preview_service,
tag_update_service=self._tag_update_service,
lifecycle_service=self._ensure_lifecycle_service(),
)
query = ModelQueryHandler(service=service, logger=logger)
download_use_case = DownloadModelUseCase(download_coordinator=self._download_coordinator)
download = ModelDownloadHandler(
ws_manager=self._ws_manager,
logger=logger,
download_use_case=download_use_case,
download_coordinator=self._download_coordinator,
)
metadata_refresh_use_case = BulkMetadataRefreshUseCase(
service=service,
metadata_sync=self._metadata_sync_service,
settings_service=self._settings,
logger=logger,
)
civitai = ModelCivitaiHandler(
service=service,
settings_service=self._settings,
ws_manager=self._ws_manager,
logger=logger,
metadata_provider_factory=self._metadata_provider_factory,
validate_model_type=self._validate_civitai_model_type,
expected_model_types=self._get_expected_model_types,
find_model_file=self._find_model_file,
metadata_sync=self._metadata_sync_service,
metadata_refresh_use_case=metadata_refresh_use_case,
metadata_progress_callback=self.metadata_progress_callback,
)
move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger)
auto_organize_use_case = AutoOrganizeUseCase(
file_service=self._ensure_file_service(),
lock_provider=self._ws_manager,
)
auto_organize = ModelAutoOrganizeHandler(
use_case=auto_organize_use_case,
progress_callback=self.websocket_progress_callback,
ws_manager=self._ws_manager,
logger=logger,
)
return ModelHandlerSet(
page_view=page_view,
listing=listing,
management=management,
query=query,
download=download,
civitai=civitai,
move=move,
auto_organize=auto_organize,
) )
@property def setup_routes(self, app: web.Application, prefix: str):
def route_handlers(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]: """Setup common routes for the model type
return self._ensure_handler_mapping()
def setup_routes(self, app: web.Application, prefix: str) -> None: Args:
registrar = ModelRouteRegistrar(app) app: aiohttp application
handler_lookup = { prefix: URL prefix (e.g., 'loras', 'checkpoints')
definition.handler_name: self._make_handler_proxy(definition.handler_name) """
for definition in COMMON_ROUTE_DEFINITIONS # Common model management routes
} app.router.add_get(f'/api/{prefix}', self.get_models)
registrar.register_common_routes(prefix, handler_lookup) app.router.add_post(f'/api/{prefix}/delete', self.delete_model)
self.setup_specific_routes(registrar, prefix) app.router.add_post(f'/api/{prefix}/exclude', self.exclude_model)
app.router.add_post(f'/api/{prefix}/fetch-civitai', self.fetch_civitai)
app.router.add_post(f'/api/{prefix}/relink-civitai', self.relink_civitai)
app.router.add_post(f'/api/{prefix}/replace-preview', self.replace_preview)
app.router.add_post(f'/api/{prefix}/save-metadata', self.save_metadata)
app.router.add_post(f'/api/{prefix}/rename', self.rename_model)
app.router.add_post(f'/api/{prefix}/bulk-delete', self.bulk_delete_models)
app.router.add_post(f'/api/{prefix}/verify-duplicates', self.verify_duplicates)
# Common query routes
app.router.add_get(f'/api/{prefix}/top-tags', self.get_top_tags)
app.router.add_get(f'/api/{prefix}/base-models', self.get_base_models)
app.router.add_get(f'/api/{prefix}/scan', self.scan_models)
app.router.add_get(f'/api/{prefix}/roots', self.get_model_roots)
app.router.add_get(f'/api/{prefix}/folders', self.get_folders)
app.router.add_get(f'/api/{prefix}/find-duplicates', self.find_duplicate_models)
app.router.add_get(f'/api/{prefix}/find-filename-conflicts', self.find_filename_conflicts)
# Common Download management
app.router.add_post(f'/api/download-model', self.download_model)
app.router.add_get(f'/api/download-model-get', self.download_model_get)
app.router.add_get(f'/api/cancel-download-get', self.cancel_download_get)
app.router.add_get(f'/api/download-progress/{{download_id}}', self.get_download_progress)
# CivitAI integration routes
app.router.add_post(f'/api/{prefix}/fetch-all-civitai', self.fetch_all_civitai)
# app.router.add_get(f'/api/civitai/versions/{{model_id}}', self.get_civitai_versions)
# Add generic page route
app.router.add_get(f'/{prefix}', self.handle_models_page)
# Setup model-specific routes
self.setup_specific_routes(app, prefix)
@abstractmethod @abstractmethod
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str) -> None: def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup model-specific routes.""" """Setup model-specific routes - to be implemented by subclasses"""
raise NotImplementedError pass
async def handle_models_page(self, request: web.Request) -> web.Response:
"""
Generic handler for model pages (e.g., /loras, /checkpoints).
Subclasses should set self.template_env and template_name.
"""
try:
# Check if the scanner is initializing
is_initializing = (
self.service.scanner._cache is None or
(hasattr(self.service.scanner, 'is_initializing') and callable(self.service.scanner.is_initializing) and self.service.scanner.is_initializing()) or
(hasattr(self.service.scanner, '_is_initializing') and self.service.scanner._is_initializing)
)
template_name = getattr(self, "template_name", None)
if not self.template_env or not template_name:
return web.Response(text="Template environment or template name not set", status=500)
if is_initializing:
rendered = self.template_env.get_template(template_name).render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
else:
try:
cache = await self.service.scanner.get_cached_data(force_refresh=False)
rendered = self.template_env.get_template(template_name).render(
folders=getattr(cache, "folders", []),
is_initializing=False,
settings=settings,
request=request
)
except Exception as cache_error:
logger.error(f"Error loading cache data: {cache_error}")
rendered = self.template_env.get_template(template_name).render(
folders=[],
is_initializing=True,
settings=settings,
request=request
)
return web.Response(
text=rendered,
content_type='text/html'
)
except Exception as e:
logger.error(f"Error handling models page: {e}", exc_info=True)
return web.Response(
text="Error loading models page",
status=500
)
async def get_models(self, request: web.Request) -> web.Response:
"""Get paginated model data"""
try:
# Parse common query parameters
params = self._parse_common_params(request)
# Get data from service
result = await self.service.get_paginated_data(**params)
# Format response items
formatted_result = {
'items': [await self.service.format_response(item) for item in result['items']],
'total': result['total'],
'page': result['page'],
'page_size': result['page_size'],
'total_pages': result['total_pages']
}
return web.json_response(formatted_result)
except Exception as e:
logger.error(f"Error in get_{self.model_type}s: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
def _parse_common_params(self, request: web.Request) -> Dict:
"""Parse common query parameters"""
# Parse basic pagination and sorting
page = int(request.query.get('page', '1'))
page_size = min(int(request.query.get('page_size', '20')), 100)
sort_by = request.query.get('sort_by', 'name')
folder = request.query.get('folder', None)
search = request.query.get('search', None)
fuzzy_search = request.query.get('fuzzy_search', 'false').lower() == 'true'
# Parse filter arrays
base_models = request.query.getall('base_model', [])
tags = request.query.getall('tag', [])
favorites_only = request.query.get('favorites_only', 'false').lower() == 'true'
# Parse search options
search_options = {
'filename': request.query.get('search_filename', 'true').lower() == 'true',
'modelname': request.query.get('search_modelname', 'true').lower() == 'true',
'tags': request.query.get('search_tags', 'false').lower() == 'true',
'recursive': request.query.get('recursive', 'false').lower() == 'true',
}
# Parse hash filters if provided
hash_filters = {}
if 'hash' in request.query:
hash_filters['single_hash'] = request.query['hash']
elif 'hashes' in request.query:
try:
hash_list = json.loads(request.query['hashes'])
if isinstance(hash_list, list):
hash_filters['multiple_hashes'] = hash_list
except (json.JSONDecodeError, TypeError):
pass
return {
'page': page,
'page_size': page_size,
'sort_by': sort_by,
'folder': folder,
'search': search,
'fuzzy_search': fuzzy_search,
'base_models': base_models,
'tags': tags,
'search_options': search_options,
'hash_filters': hash_filters,
'favorites_only': favorites_only,
# Add model-specific parameters
**self._parse_specific_params(request)
}
def _parse_specific_params(self, request: web.Request) -> Dict: def _parse_specific_params(self, request: web.Request) -> Dict:
"""Parse model-specific parameters - to be overridden by subclasses.""" """Parse model-specific parameters - to be overridden by subclasses"""
return {} return {}
def _validate_civitai_model_type(self, model_type: str) -> bool: # Common route handlers
"""Validate CivitAI model type - to be overridden by subclasses.""" async def delete_model(self, request: web.Request) -> web.Response:
return True """Handle model deletion request"""
return await ModelRouteUtils.handle_delete_model(request, self.service.scanner)
def _get_expected_model_types(self) -> str: async def exclude_model(self, request: web.Request) -> web.Response:
"""Get expected model types string for error messages - to be overridden by subclasses.""" """Handle model exclusion request"""
return "any model type" return await ModelRouteUtils.handle_exclude_model(request, self.service.scanner)
def _find_model_file(self, files): async def fetch_civitai(self, request: web.Request) -> web.Response:
"""Find the appropriate model file from the files list - can be overridden by subclasses.""" """Handle CivitAI metadata fetch request"""
return next((file for file in files if file.get("type") == "Model" and file.get("primary") is True), None) response = await ModelRouteUtils.handle_fetch_civitai(request, self.service.scanner)
def get_handler(self, name: str) -> Callable[[web.Request], web.StreamResponse]: # If successful, format the metadata before returning
"""Expose handlers for subclasses or tests.""" if response.status == 200:
return self._ensure_handler_mapping()[name] data = json.loads(response.body.decode('utf-8'))
if data.get("success") and data.get("metadata"):
formatted_metadata = await self.service.format_response(data["metadata"])
return web.json_response({
"success": True,
"metadata": formatted_metadata
})
def _ensure_service(self): return response
if self.service is None:
raise RuntimeError("Model service has not been attached")
return self.service
def _ensure_file_service(self) -> ModelFileService: async def relink_civitai(self, request: web.Request) -> web.Response:
if self.model_file_service is None: """Handle CivitAI metadata re-linking request"""
service = self._ensure_service() return await ModelRouteUtils.handle_relink_civitai(request, self.service.scanner)
self.model_file_service = ModelFileService(service.scanner, service.model_type)
return self.model_file_service
def _ensure_move_service(self) -> ModelMoveService: async def replace_preview(self, request: web.Request) -> web.Response:
if self.model_move_service is None: """Handle preview image replacement"""
service = self._ensure_service() return await ModelRouteUtils.handle_replace_preview(request, self.service.scanner)
self.model_move_service = ModelMoveService(service.scanner)
return self.model_move_service
def _ensure_lifecycle_service(self) -> ModelLifecycleService: async def save_metadata(self, request: web.Request) -> web.Response:
if self.model_lifecycle_service is None: """Handle saving metadata updates"""
service = self._ensure_service() return await ModelRouteUtils.handle_save_metadata(request, self.service.scanner)
self.model_lifecycle_service = ModelLifecycleService(
scanner=service.scanner,
metadata_manager=MetadataManager,
metadata_loader=self._metadata_sync_service.load_local_metadata,
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
)
return self.model_lifecycle_service
def _make_handler_proxy(self, name: str) -> Callable[[web.Request], web.StreamResponse]: async def rename_model(self, request: web.Request) -> web.Response:
async def proxy(request: web.Request) -> web.StreamResponse: """Handle renaming a model file and its associated files"""
try: return await ModelRouteUtils.handle_rename_model(request, self.service.scanner)
handler = self.get_handler(name)
except RuntimeError:
return web.json_response({"success": False, "error": "Service not ready"}, status=503)
return await handler(request)
return proxy async def bulk_delete_models(self, request: web.Request) -> web.Response:
"""Handle bulk deletion of models"""
return await ModelRouteUtils.handle_bulk_delete_models(request, self.service.scanner)
async def verify_duplicates(self, request: web.Request) -> web.Response:
"""Handle verification of duplicate model hashes"""
return await ModelRouteUtils.handle_verify_duplicates(request, self.service.scanner)
async def get_top_tags(self, request: web.Request) -> web.Response:
"""Handle request for top tags sorted by frequency"""
try:
limit = int(request.query.get('limit', '20'))
if limit < 1 or limit > 100:
limit = 20
top_tags = await self.service.get_top_tags(limit)
return web.json_response({
'success': True,
'tags': top_tags
})
except Exception as e:
logger.error(f"Error getting top tags: {str(e)}", exc_info=True)
return web.json_response({
'success': False,
'error': 'Internal server error'
}, status=500)
async def get_base_models(self, request: web.Request) -> web.Response:
"""Get base models used in models"""
try:
limit = int(request.query.get('limit', '20'))
if limit < 1 or limit > 100:
limit = 20
base_models = await self.service.get_base_models(limit)
return web.json_response({
'success': True,
'base_models': base_models
})
except Exception as e:
logger.error(f"Error retrieving base models: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def scan_models(self, request: web.Request) -> web.Response:
"""Force a rescan of model files"""
try:
full_rebuild = request.query.get('full_rebuild', 'false').lower() == 'true'
await self.service.scan_models(force_refresh=True, rebuild_cache=full_rebuild)
return web.json_response({
"status": "success",
"message": f"{self.model_type.capitalize()} scan completed"
})
except Exception as e:
logger.error(f"Error in scan_{self.model_type}s: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500)
async def get_model_roots(self, request: web.Request) -> web.Response:
"""Return the model root directories"""
try:
roots = self.service.get_model_roots()
return web.json_response({
"success": True,
"roots": roots
})
except Exception as e:
logger.error(f"Error getting {self.model_type} roots: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_folders(self, request: web.Request) -> web.Response:
"""Get all folders in the cache"""
try:
cache = await self.service.scanner.get_cached_data()
return web.json_response({
'folders': cache.folders
})
except Exception as e:
logger.error(f"Error getting folders: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def find_duplicate_models(self, request: web.Request) -> web.Response:
"""Find models with duplicate SHA256 hashes"""
try:
# Get duplicate hashes from service
duplicates = self.service.find_duplicate_hashes()
# Format the response
result = []
cache = await self.service.scanner.get_cached_data()
for sha256, paths in duplicates.items():
group = {
"hash": sha256,
"models": []
}
# Find matching models for each path
for path in paths:
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
if model:
group["models"].append(await self.service.format_response(model))
# Add the primary model too
primary_path = self.service.get_path_by_hash(sha256)
if primary_path and primary_path not in paths:
primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
if primary_model:
group["models"].insert(0, await self.service.format_response(primary_model))
if len(group["models"]) > 1: # Only include if we found multiple models
result.append(group)
return web.json_response({
"success": True,
"duplicates": result,
"count": len(result)
})
except Exception as e:
logger.error(f"Error finding duplicate {self.model_type}s: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def find_filename_conflicts(self, request: web.Request) -> web.Response:
"""Find models with conflicting filenames"""
try:
# Get duplicate filenames from service
duplicates = self.service.find_duplicate_filenames()
# Format the response
result = []
cache = await self.service.scanner.get_cached_data()
for filename, paths in duplicates.items():
group = {
"filename": filename,
"models": []
}
# Find matching models for each path
for path in paths:
model = next((m for m in cache.raw_data if m['file_path'] == path), None)
if model:
group["models"].append(await self.service.format_response(model))
# Find the model from the main index too
hash_val = self.service.scanner._hash_index.get_hash_by_filename(filename)
if hash_val:
main_path = self.service.get_path_by_hash(hash_val)
if main_path and main_path not in paths:
main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
if main_model:
group["models"].insert(0, await self.service.format_response(main_model))
if group["models"]:
result.append(group)
return web.json_response({
"success": True,
"conflicts": result,
"count": len(result)
})
except Exception as e:
logger.error(f"Error finding filename conflicts for {self.model_type}s: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
# Download management methods
async def download_model(self, request: web.Request) -> web.Response:
"""Handle model download request"""
return await ModelRouteUtils.handle_download_model(request)
async def download_model_get(self, request: web.Request) -> web.Response:
"""Handle model download request via GET method"""
try:
# Extract query parameters
model_id = request.query.get('model_id')
if not model_id:
return web.Response(
status=400,
text="Missing required parameter: Please provide 'model_id'"
)
# Get optional parameters
model_version_id = request.query.get('model_version_id')
download_id = request.query.get('download_id')
use_default_paths = request.query.get('use_default_paths', 'false').lower() == 'true'
# Create a data dictionary that mimics what would be received from a POST request
data = {
'model_id': model_id
}
# Add optional parameters only if they are provided
if model_version_id:
data['model_version_id'] = model_version_id
if download_id:
data['download_id'] = download_id
data['use_default_paths'] = use_default_paths
# Create a mock request object with the data
future = asyncio.get_event_loop().create_future()
future.set_result(data)
mock_request = type('MockRequest', (), {
'json': lambda self=None: future
})()
# Call the existing download handler
return await ModelRouteUtils.handle_download_model(mock_request)
except Exception as e:
error_message = str(e)
logger.error(f"Error downloading model via GET: {error_message}", exc_info=True)
return web.Response(status=500, text=error_message)
async def cancel_download_get(self, request: web.Request) -> web.Response:
"""Handle GET request for cancelling a download by download_id"""
try:
download_id = request.query.get('download_id')
if not download_id:
return web.json_response({
'success': False,
'error': 'Download ID is required'
}, status=400)
# Create a mock request with match_info for compatibility
mock_request = type('MockRequest', (), {
'match_info': {'download_id': download_id}
})()
return await ModelRouteUtils.handle_cancel_download(mock_request)
except Exception as e:
logger.error(f"Error cancelling download via GET: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_download_progress(self, request: web.Request) -> web.Response:
"""Handle request for download progress by download_id"""
try:
# Get download_id from URL path
download_id = request.match_info.get('download_id')
if not download_id:
return web.json_response({
'success': False,
'error': 'Download ID is required'
}, status=400)
progress_data = ws_manager.get_download_progress(download_id)
if progress_data is None:
return web.json_response({
'success': False,
'error': 'Download ID not found'
}, status=404)
return web.json_response({
'success': True,
'progress': progress_data.get('progress', 0)
})
except Exception as e:
logger.error(f"Error getting download progress: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
"""Fetch CivitAI metadata for all models in the background"""
try:
cache = await self.service.scanner.get_cached_data()
total = len(cache.raw_data)
processed = 0
success = 0
needs_resort = False
# Prepare models to process
to_process = [
model for model in cache.raw_data
if model.get('sha256') and (not model.get('civitai') or 'id' not in model.get('civitai')) and model.get('from_civitai', True)
]
total_to_process = len(to_process)
# Send initial progress
await ws_manager.broadcast({
'status': 'started',
'total': total_to_process,
'processed': 0,
'success': 0
})
# Process each model
for model in to_process:
try:
original_name = model.get('model_name')
if await ModelRouteUtils.fetch_and_update_model(
sha256=model['sha256'],
file_path=model['file_path'],
model_data=model,
update_cache_func=self.service.scanner.update_single_model_cache
):
success += 1
if original_name != model.get('model_name'):
needs_resort = True
processed += 1
# Send progress update
await ws_manager.broadcast({
'status': 'processing',
'total': total_to_process,
'processed': processed,
'success': success,
'current_name': model.get('model_name', 'Unknown')
})
except Exception as e:
logger.error(f"Error fetching CivitAI data for {model['file_path']}: {e}")
if needs_resort:
await cache.resort()
# Send completion message
await ws_manager.broadcast({
'status': 'completed',
'total': total_to_process,
'processed': processed,
'success': success
})
return web.json_response({
"success": True,
"message": f"Successfully updated {success} of {processed} processed {self.model_type}s (total: {total})"
})
except Exception as e:
# Send error message
await ws_manager.broadcast({
'status': 'error',
'error': str(e)
})
logger.error(f"Error in fetch_all_civitai for {self.model_type}s: {e}")
return web.Response(text=str(e), status=500)
async def get_civitai_versions(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai model with local availability info"""
# This will be implemented by subclasses as they need CivitAI client access
return web.json_response({
"error": "Not implemented in base class"
}, status=501)

View File

@@ -1,217 +0,0 @@
"""Base infrastructure shared across recipe routes."""
from __future__ import annotations
import logging
import os
from typing import Callable, Mapping
import jinja2
from aiohttp import web
from ..config import config
from ..recipes import RecipeParserFactory
from ..services.downloader import get_downloader
from ..services.recipes import (
RecipeAnalysisService,
RecipePersistenceService,
RecipeSharingService,
)
from ..services.server_i18n import server_i18n
from ..services.service_registry import ServiceRegistry
from ..services.settings_manager import settings
from ..utils.constants import CARD_PREVIEW_WIDTH
from ..utils.exif_utils import ExifUtils
from .handlers.recipe_handlers import (
RecipeAnalysisHandler,
RecipeHandlerSet,
RecipeListingHandler,
RecipeManagementHandler,
RecipePageView,
RecipeQueryHandler,
RecipeSharingHandler,
)
from .recipe_route_registrar import ROUTE_DEFINITIONS
logger = logging.getLogger(__name__)
class BaseRecipeRoutes:
"""Common dependency and startup wiring for recipe routes."""
_HANDLER_NAMES: tuple[str, ...] = tuple(
definition.handler_name for definition in ROUTE_DEFINITIONS
)
template_name: str = "recipes.html"
def __init__(self) -> None:
self.recipe_scanner = None
self.lora_scanner = None
self.civitai_client = None
self.settings = settings
self.server_i18n = server_i18n
self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path),
autoescape=True,
)
self._i18n_registered = False
self._startup_hooks_registered = False
self._handler_set: RecipeHandlerSet | None = None
self._handler_mapping: dict[str, Callable] | None = None
async def attach_dependencies(self, app: web.Application | None = None) -> None:
"""Resolve shared services from the registry."""
await self._ensure_services()
self._ensure_i18n_filter()
async def ensure_dependencies_ready(self) -> None:
"""Ensure dependencies are available for request handlers."""
if self.recipe_scanner is None or self.civitai_client is None:
await self.attach_dependencies()
def register_startup_hooks(self, app: web.Application) -> None:
"""Register startup hooks once for dependency wiring."""
if self._startup_hooks_registered:
return
app.on_startup.append(self.attach_dependencies)
app.on_startup.append(self.prewarm_cache)
self._startup_hooks_registered = True
async def prewarm_cache(self, app: web.Application | None = None) -> None:
"""Pre-load recipe and LoRA caches on startup."""
try:
await self.attach_dependencies(app)
if self.lora_scanner is not None:
await self.lora_scanner.get_cached_data()
hash_index = getattr(self.lora_scanner, "_hash_index", None)
if hash_index is not None and hasattr(hash_index, "_hash_to_path"):
_ = len(hash_index._hash_to_path)
if self.recipe_scanner is not None:
await self.recipe_scanner.get_cached_data(force_refresh=True)
except Exception as exc:
logger.error("Error pre-warming recipe cache: %s", exc, exc_info=True)
def to_route_mapping(self) -> Mapping[str, Callable]:
"""Return a mapping of handler name to coroutine for registrar binding."""
if self._handler_mapping is None:
handler_set = self._create_handler_set()
self._handler_set = handler_set
self._handler_mapping = handler_set.to_route_mapping()
return self._handler_mapping
# Internal helpers -------------------------------------------------
async def _ensure_services(self) -> None:
if self.recipe_scanner is None:
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
self.lora_scanner = getattr(self.recipe_scanner, "_lora_scanner", None)
if self.civitai_client is None:
self.civitai_client = await ServiceRegistry.get_civitai_client()
def _ensure_i18n_filter(self) -> None:
if not self._i18n_registered:
self.template_env.filters["t"] = self.server_i18n.create_template_filter()
self._i18n_registered = True
def get_handler_owner(self):
"""Return the object supplying bound handler coroutines."""
if self._handler_set is None:
self._handler_set = self._create_handler_set()
return self._handler_set
def _create_handler_set(self) -> RecipeHandlerSet:
recipe_scanner_getter = lambda: self.recipe_scanner
civitai_client_getter = lambda: self.civitai_client
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
if not standalone_mode:
from ..metadata_collector import get_metadata # type: ignore[import-not-found]
from ..metadata_collector.metadata_processor import ( # type: ignore[import-not-found]
MetadataProcessor,
)
from ..metadata_collector.metadata_registry import ( # type: ignore[import-not-found]
MetadataRegistry,
)
else: # pragma: no cover - optional dependency path
get_metadata = None # type: ignore[assignment]
MetadataProcessor = None # type: ignore[assignment]
MetadataRegistry = None # type: ignore[assignment]
analysis_service = RecipeAnalysisService(
exif_utils=ExifUtils,
recipe_parser_factory=RecipeParserFactory,
downloader_factory=get_downloader,
metadata_collector=get_metadata,
metadata_processor_cls=MetadataProcessor,
metadata_registry_cls=MetadataRegistry,
standalone_mode=standalone_mode,
logger=logger,
)
persistence_service = RecipePersistenceService(
exif_utils=ExifUtils,
card_preview_width=CARD_PREVIEW_WIDTH,
logger=logger,
)
sharing_service = RecipeSharingService(logger=logger)
page_view = RecipePageView(
ensure_dependencies_ready=self.ensure_dependencies_ready,
settings_service=self.settings,
server_i18n=self.server_i18n,
template_env=self.template_env,
template_name=self.template_name,
recipe_scanner_getter=recipe_scanner_getter,
logger=logger,
)
listing = RecipeListingHandler(
ensure_dependencies_ready=self.ensure_dependencies_ready,
recipe_scanner_getter=recipe_scanner_getter,
logger=logger,
)
query = RecipeQueryHandler(
ensure_dependencies_ready=self.ensure_dependencies_ready,
recipe_scanner_getter=recipe_scanner_getter,
format_recipe_file_url=listing.format_recipe_file_url,
logger=logger,
)
management = RecipeManagementHandler(
ensure_dependencies_ready=self.ensure_dependencies_ready,
recipe_scanner_getter=recipe_scanner_getter,
logger=logger,
persistence_service=persistence_service,
analysis_service=analysis_service,
)
analysis = RecipeAnalysisHandler(
ensure_dependencies_ready=self.ensure_dependencies_ready,
recipe_scanner_getter=recipe_scanner_getter,
civitai_client_getter=civitai_client_getter,
logger=logger,
analysis_service=analysis_service,
)
sharing = RecipeSharingHandler(
ensure_dependencies_ready=self.ensure_dependencies_ready,
recipe_scanner_getter=recipe_scanner_getter,
logger=logger,
sharing_service=sharing_service,
)
return RecipeHandlerSet(
page_view=page_view,
listing=listing,
query=query,
management=management,
analysis=analysis,
sharing=sharing,
)

View File

@@ -2,10 +2,8 @@ import logging
from aiohttp import web from aiohttp import web
from .base_model_routes import BaseModelRoutes from .base_model_routes import BaseModelRoutes
from .model_route_registrar import ModelRouteRegistrar
from ..services.checkpoint_service import CheckpointService from ..services.checkpoint_service import CheckpointService
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..config import config
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -14,16 +12,19 @@ class CheckpointRoutes(BaseModelRoutes):
def __init__(self): def __init__(self):
"""Initialize Checkpoint routes with Checkpoint service""" """Initialize Checkpoint routes with Checkpoint service"""
super().__init__() # Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.template_name = "checkpoints.html" self.template_name = "checkpoints.html"
async def initialize_services(self): async def initialize_services(self):
"""Initialize services from ServiceRegistry""" """Initialize services from ServiceRegistry"""
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
self.service = CheckpointService(checkpoint_scanner) self.service = CheckpointService(checkpoint_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Attach service dependencies # Initialize parent with the service
self.attach_service(self.service) super().__init__(self.service)
def setup_routes(self, app: web.Application): def setup_routes(self, app: web.Application):
"""Setup Checkpoint routes""" """Setup Checkpoint routes"""
@@ -33,22 +34,13 @@ class CheckpointRoutes(BaseModelRoutes):
# Setup common routes with 'checkpoints' prefix (includes page route) # Setup common routes with 'checkpoints' prefix (includes page route)
super().setup_routes(app, 'checkpoints') super().setup_routes(app, 'checkpoints')
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str): def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup Checkpoint-specific routes""" """Setup Checkpoint-specific routes"""
# Checkpoint-specific CivitAI integration
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_checkpoint)
# Checkpoint info by name # Checkpoint info by name
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_checkpoint_info) app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_checkpoint_info)
# Checkpoint roots and Unet roots
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/checkpoints_roots', prefix, self.get_checkpoints_roots)
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/unet_roots', prefix, self.get_unet_roots)
def _validate_civitai_model_type(self, model_type: str) -> bool:
"""Validate CivitAI model type for Checkpoint"""
return model_type.lower() == 'checkpoint'
def _get_expected_model_types(self) -> str:
"""Get expected model types string for error messages"""
return "Checkpoint"
async def get_checkpoint_info(self, request: web.Request) -> web.Response: async def get_checkpoint_info(self, request: web.Request) -> web.Response:
"""Get detailed information for a specific checkpoint by name""" """Get detailed information for a specific checkpoint by name"""
@@ -65,32 +57,49 @@ class CheckpointRoutes(BaseModelRoutes):
logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True) logger.error(f"Error in get_checkpoint_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500) return web.json_response({"error": str(e)}, status=500)
async def get_checkpoints_roots(self, request: web.Request) -> web.Response: async def get_civitai_versions_checkpoint(self, request: web.Request) -> web.Response:
"""Return the list of checkpoint roots from config""" """Get available versions for a Civitai checkpoint model with local availability info"""
try: try:
roots = config.checkpoints_roots model_id = request.match_info['model_id']
return web.json_response({ response = await self.civitai_client.get_model_versions(model_id)
"success": True, if not response or not response.get('modelVersions'):
"roots": roots return web.Response(status=404, text="Model not found")
})
except Exception as e:
logger.error(f"Error getting checkpoint roots: {e}", exc_info=True)
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_unet_roots(self, request: web.Request) -> web.Response: versions = response.get('modelVersions', [])
"""Return the list of unet roots from config""" model_type = response.get('type', '')
try:
roots = config.unet_roots # Check model type - should be Checkpoint
return web.json_response({ if model_type.lower() != 'checkpoint':
"success": True, return web.json_response({
"roots": roots 'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
}) }, status=400)
# Check local availability for each version
for version in versions:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model' and file.get('primary') == True), None)
# If no primary file found, try to find any model file
if not model_file:
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e: except Exception as e:
logger.error(f"Error getting unet roots: {e}", exc_info=True) logger.error(f"Error fetching checkpoint model versions: {e}")
return web.json_response({ return web.Response(status=500, text=str(e))
"success": False,
"error": str(e)
}, status=500)

View File

@@ -2,7 +2,6 @@ import logging
from aiohttp import web from aiohttp import web
from .base_model_routes import BaseModelRoutes from .base_model_routes import BaseModelRoutes
from .model_route_registrar import ModelRouteRegistrar
from ..services.embedding_service import EmbeddingService from ..services.embedding_service import EmbeddingService
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
@@ -13,16 +12,19 @@ class EmbeddingRoutes(BaseModelRoutes):
def __init__(self): def __init__(self):
"""Initialize Embedding routes with Embedding service""" """Initialize Embedding routes with Embedding service"""
super().__init__() # Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.template_name = "embeddings.html" self.template_name = "embeddings.html"
async def initialize_services(self): async def initialize_services(self):
"""Initialize services from ServiceRegistry""" """Initialize services from ServiceRegistry"""
embedding_scanner = await ServiceRegistry.get_embedding_scanner() embedding_scanner = await ServiceRegistry.get_embedding_scanner()
self.service = EmbeddingService(embedding_scanner) self.service = EmbeddingService(embedding_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Attach service dependencies # Initialize parent with the service
self.attach_service(self.service) super().__init__(self.service)
def setup_routes(self, app: web.Application): def setup_routes(self, app: web.Application):
"""Setup Embedding routes""" """Setup Embedding routes"""
@@ -32,18 +34,13 @@ class EmbeddingRoutes(BaseModelRoutes):
# Setup common routes with 'embeddings' prefix (includes page route) # Setup common routes with 'embeddings' prefix (includes page route)
super().setup_routes(app, 'embeddings') super().setup_routes(app, 'embeddings')
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str): def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup Embedding-specific routes""" """Setup Embedding-specific routes"""
# Embedding-specific CivitAI integration
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_embedding)
# Embedding info by name # Embedding info by name
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_embedding_info) app.router.add_get(f'/api/{prefix}/info/{{name}}', self.get_embedding_info)
def _validate_civitai_model_type(self, model_type: str) -> bool:
"""Validate CivitAI model type for Embedding"""
return model_type.lower() == 'textualinversion'
def _get_expected_model_types(self) -> str:
"""Get expected model types string for error messages"""
return "TextualInversion"
async def get_embedding_info(self, request: web.Request) -> web.Response: async def get_embedding_info(self, request: web.Request) -> web.Response:
"""Get detailed information for a specific embedding by name""" """Get detailed information for a specific embedding by name"""
@@ -59,3 +56,50 @@ class EmbeddingRoutes(BaseModelRoutes):
except Exception as e: except Exception as e:
logger.error(f"Error in get_embedding_info: {e}", exc_info=True) logger.error(f"Error in get_embedding_info: {e}", exc_info=True)
return web.json_response({"error": str(e)}, status=500) return web.json_response({"error": str(e)}, status=500)
async def get_civitai_versions_embedding(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai embedding model with local availability info"""
try:
model_id = request.match_info['model_id']
response = await self.civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be TextualInversion (Embedding)
if model_type.lower() not in ['textualinversion', 'embedding']:
return web.json_response({
'error': f"Model type mismatch. Expected TextualInversion/Embedding, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the primary model file (type="Model" and primary=true) in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model' and file.get('primary') == True), None)
# If no primary file found, try to find any model file
if not model_file:
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching embedding model versions: {e}")
return web.Response(status=500, text=str(e))

View File

@@ -1,62 +0,0 @@
"""Route registrar for example image endpoints."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Iterable, Mapping
from aiohttp import web
@dataclass(frozen=True)
class RouteDefinition:
"""Declarative configuration for a HTTP route."""
method: str
path: str
handler_name: str
ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("POST", "/api/lm/download-example-images", "download_example_images"),
RouteDefinition("POST", "/api/lm/import-example-images", "import_example_images"),
RouteDefinition("GET", "/api/lm/example-images-status", "get_example_images_status"),
RouteDefinition("POST", "/api/lm/pause-example-images", "pause_example_images"),
RouteDefinition("POST", "/api/lm/resume-example-images", "resume_example_images"),
RouteDefinition("POST", "/api/lm/open-example-images-folder", "open_example_images_folder"),
RouteDefinition("GET", "/api/lm/example-image-files", "get_example_image_files"),
RouteDefinition("GET", "/api/lm/has-example-images", "has_example_images"),
RouteDefinition("POST", "/api/lm/delete-example-image", "delete_example_image"),
RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"),
RouteDefinition("POST", "/api/lm/cleanup-example-image-folders", "cleanup_example_image_folders"),
)
class ExampleImagesRouteRegistrar:
"""Bind declarative example image routes to an aiohttp router."""
_METHOD_MAP = {
"GET": "add_get",
"POST": "add_post",
"PUT": "add_put",
"DELETE": "add_delete",
}
def __init__(self, app: web.Application) -> None:
self._app = app
def register_routes(
self,
handler_lookup: Mapping[str, Callable[[web.Request], object]],
*,
definitions: Iterable[RouteDefinition] = ROUTE_DEFINITIONS,
) -> None:
"""Register each route definition using the supplied handlers."""
for definition in definitions:
handler = handler_lookup[definition.handler_name]
self._bind_route(definition.method, definition.path, handler)
def _bind_route(self, method: str, path: str, handler: Callable[[web.Request], object]) -> None:
add_method_name = self._METHOD_MAP[method.upper()]
add_method = getattr(self._app.router, add_method_name)
add_method(path, handler)

View File

@@ -1,88 +1,67 @@
from __future__ import annotations
import logging import logging
from typing import Callable, Mapping from ..utils.example_images_download_manager import DownloadManager
from aiohttp import web
from .example_images_route_registrar import ExampleImagesRouteRegistrar
from .handlers.example_images_handlers import (
ExampleImagesDownloadHandler,
ExampleImagesFileHandler,
ExampleImagesHandlerSet,
ExampleImagesManagementHandler,
)
from ..services.use_cases.example_images import (
DownloadExampleImagesUseCase,
ImportExampleImagesUseCase,
)
from ..utils.example_images_download_manager import (
DownloadManager,
get_default_download_manager,
)
from ..utils.example_images_file_manager import ExampleImagesFileManager
from ..utils.example_images_processor import ExampleImagesProcessor from ..utils.example_images_processor import ExampleImagesProcessor
from ..services.example_images_cleanup_service import ExampleImagesCleanupService from ..utils.example_images_file_manager import ExampleImagesFileManager
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ExampleImagesRoutes: class ExampleImagesRoutes:
"""Route controller for example image endpoints.""" """Routes for example images related functionality"""
def __init__( @staticmethod
self, def setup_routes(app):
*, """Register example images routes"""
ws_manager, app.router.add_post('/api/download-example-images', ExampleImagesRoutes.download_example_images)
download_manager: DownloadManager | None = None, app.router.add_post('/api/import-example-images', ExampleImagesRoutes.import_example_images)
processor=ExampleImagesProcessor, app.router.add_get('/api/example-images-status', ExampleImagesRoutes.get_example_images_status)
file_manager=ExampleImagesFileManager, app.router.add_post('/api/pause-example-images', ExampleImagesRoutes.pause_example_images)
cleanup_service: ExampleImagesCleanupService | None = None, app.router.add_post('/api/resume-example-images', ExampleImagesRoutes.resume_example_images)
) -> None: app.router.add_post('/api/open-example-images-folder', ExampleImagesRoutes.open_example_images_folder)
if ws_manager is None: app.router.add_get('/api/example-image-files', ExampleImagesRoutes.get_example_image_files)
raise ValueError("ws_manager is required") app.router.add_get('/api/has-example-images', ExampleImagesRoutes.has_example_images)
self._download_manager = download_manager or get_default_download_manager(ws_manager) app.router.add_post('/api/delete-example-image', ExampleImagesRoutes.delete_example_image)
self._processor = processor
self._file_manager = file_manager
self._cleanup_service = cleanup_service or ExampleImagesCleanupService()
self._handler_set: ExampleImagesHandlerSet | None = None
self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None
@classmethod @staticmethod
def setup_routes(cls, app: web.Application, *, ws_manager) -> None: async def download_example_images(request):
"""Register routes on the given aiohttp application using default wiring.""" """Download example images for models from Civitai"""
return await DownloadManager.start_download(request)
controller = cls(ws_manager=ws_manager) @staticmethod
controller.register(app) async def get_example_images_status(request):
"""Get the current status of example images download"""
return await DownloadManager.get_status(request)
def register(self, app: web.Application) -> None: @staticmethod
"""Bind the controller's handlers to the aiohttp router.""" async def pause_example_images(request):
"""Pause the example images download"""
return await DownloadManager.pause_download(request)
registrar = ExampleImagesRouteRegistrar(app) @staticmethod
registrar.register_routes(self.to_route_mapping()) async def resume_example_images(request):
"""Resume the example images download"""
return await DownloadManager.resume_download(request)
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]: @staticmethod
"""Return the registrar-compatible mapping of handler names to callables.""" async def open_example_images_folder(request):
"""Open the example images folder for a specific model"""
return await ExampleImagesFileManager.open_folder(request)
if self._handler_mapping is None: @staticmethod
handler_set = self._build_handler_set() async def get_example_image_files(request):
self._handler_set = handler_set """Get list of example image files for a specific model"""
self._handler_mapping = handler_set.to_route_mapping() return await ExampleImagesFileManager.get_files(request)
return self._handler_mapping
def _build_handler_set(self) -> ExampleImagesHandlerSet: @staticmethod
logger.debug("Building ExampleImagesHandlerSet with %s, %s, %s", self._download_manager, self._processor, self._file_manager) async def import_example_images(request):
download_use_case = DownloadExampleImagesUseCase(download_manager=self._download_manager) """Import local example images for a model"""
download_handler = ExampleImagesDownloadHandler(download_use_case, self._download_manager) return await ExampleImagesProcessor.import_images(request)
import_use_case = ImportExampleImagesUseCase(processor=self._processor)
management_handler = ExampleImagesManagementHandler( @staticmethod
import_use_case, async def has_example_images(request):
self._processor, """Check if example images folder exists and is not empty for a model"""
self._cleanup_service, return await ExampleImagesFileManager.has_images(request)
)
file_handler = ExampleImagesFileHandler(self._file_manager) @staticmethod
return ExampleImagesHandlerSet( async def delete_example_image(request):
download=download_handler, """Delete a custom example image for a model"""
management=management_handler, return await ExampleImagesProcessor.delete_custom_image(request)
files=file_handler,
)

View File

@@ -1,159 +0,0 @@
"""Handler set for example image routes."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Mapping
from aiohttp import web
from ...services.use_cases.example_images import (
DownloadExampleImagesConfigurationError,
DownloadExampleImagesInProgressError,
DownloadExampleImagesUseCase,
ImportExampleImagesUseCase,
ImportExampleImagesValidationError,
)
from ...utils.example_images_download_manager import (
DownloadConfigurationError,
DownloadInProgressError,
DownloadNotRunningError,
ExampleImagesDownloadError,
)
from ...utils.example_images_processor import ExampleImagesImportError
class ExampleImagesDownloadHandler:
"""HTTP adapters for download-related example image endpoints."""
def __init__(
self,
download_use_case: DownloadExampleImagesUseCase,
download_manager,
) -> None:
self._download_use_case = download_use_case
self._download_manager = download_manager
async def download_example_images(self, request: web.Request) -> web.StreamResponse:
try:
payload = await request.json()
result = await self._download_use_case.execute(payload)
return web.json_response(result)
except DownloadExampleImagesInProgressError as exc:
response = {
'success': False,
'error': str(exc),
'status': exc.progress,
}
return web.json_response(response, status=400)
except DownloadExampleImagesConfigurationError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
except ExampleImagesDownloadError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=500)
async def get_example_images_status(self, request: web.Request) -> web.StreamResponse:
result = await self._download_manager.get_status(request)
return web.json_response(result)
async def pause_example_images(self, request: web.Request) -> web.StreamResponse:
try:
result = await self._download_manager.pause_download(request)
return web.json_response(result)
except DownloadNotRunningError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
async def resume_example_images(self, request: web.Request) -> web.StreamResponse:
try:
result = await self._download_manager.resume_download(request)
return web.json_response(result)
except DownloadNotRunningError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
async def force_download_example_images(self, request: web.Request) -> web.StreamResponse:
try:
payload = await request.json()
result = await self._download_manager.start_force_download(payload)
return web.json_response(result)
except DownloadInProgressError as exc:
response = {
'success': False,
'error': str(exc),
'status': exc.progress_snapshot,
}
return web.json_response(response, status=400)
except DownloadConfigurationError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
except ExampleImagesDownloadError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=500)
class ExampleImagesManagementHandler:
"""HTTP adapters for import/delete endpoints."""
def __init__(self, import_use_case: ImportExampleImagesUseCase, processor, cleanup_service) -> None:
self._import_use_case = import_use_case
self._processor = processor
self._cleanup_service = cleanup_service
async def import_example_images(self, request: web.Request) -> web.StreamResponse:
try:
result = await self._import_use_case.execute(request)
return web.json_response(result)
except ImportExampleImagesValidationError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=400)
except ExampleImagesImportError as exc:
return web.json_response({'success': False, 'error': str(exc)}, status=500)
async def delete_example_image(self, request: web.Request) -> web.StreamResponse:
return await self._processor.delete_custom_image(request)
async def cleanup_example_image_folders(self, request: web.Request) -> web.StreamResponse:
result = await self._cleanup_service.cleanup_example_image_folders()
if result.get('success') or result.get('partial_success'):
return web.json_response(result, status=200)
error_code = result.get('error_code')
status = 400 if error_code in {'path_not_configured', 'path_not_found'} else 500
return web.json_response(result, status=status)
class ExampleImagesFileHandler:
"""HTTP adapters for filesystem-centric endpoints."""
def __init__(self, file_manager) -> None:
self._file_manager = file_manager
async def open_example_images_folder(self, request: web.Request) -> web.StreamResponse:
return await self._file_manager.open_folder(request)
async def get_example_image_files(self, request: web.Request) -> web.StreamResponse:
return await self._file_manager.get_files(request)
async def has_example_images(self, request: web.Request) -> web.StreamResponse:
return await self._file_manager.has_images(request)
@dataclass(frozen=True)
class ExampleImagesHandlerSet:
"""Aggregate of handlers exposed to the registrar."""
download: ExampleImagesDownloadHandler
management: ExampleImagesManagementHandler
files: ExampleImagesFileHandler
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
"""Flatten handler methods into the registrar mapping."""
return {
"download_example_images": self.download.download_example_images,
"get_example_images_status": self.download.get_example_images_status,
"pause_example_images": self.download.pause_example_images,
"resume_example_images": self.download.resume_example_images,
"force_download_example_images": self.download.force_download_example_images,
"import_example_images": self.management.import_example_images,
"delete_example_image": self.management.delete_example_image,
"cleanup_example_image_folders": self.management.cleanup_example_image_folders,
"open_example_images_folder": self.files.open_example_images_folder,
"get_example_image_files": self.files.get_example_image_files,
"has_example_images": self.files.has_example_images,
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,725 +0,0 @@
"""Dedicated handler objects for recipe-related routes."""
from __future__ import annotations
import json
import logging
import os
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional
from aiohttp import web
from ...config import config
from ...services.server_i18n import server_i18n as default_server_i18n
from ...services.settings_manager import SettingsManager
from ...services.recipes import (
RecipeAnalysisService,
RecipeDownloadError,
RecipeNotFoundError,
RecipePersistenceService,
RecipeSharingService,
RecipeValidationError,
)
Logger = logging.Logger
EnsureDependenciesCallable = Callable[[], Awaitable[None]]
RecipeScannerGetter = Callable[[], Any]
CivitaiClientGetter = Callable[[], Any]
@dataclass(frozen=True)
class RecipeHandlerSet:
"""Group of handlers providing recipe route implementations."""
page_view: "RecipePageView"
listing: "RecipeListingHandler"
query: "RecipeQueryHandler"
management: "RecipeManagementHandler"
analysis: "RecipeAnalysisHandler"
sharing: "RecipeSharingHandler"
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
"""Expose handler coroutines keyed by registrar handler names."""
return {
"render_page": self.page_view.render_page,
"list_recipes": self.listing.list_recipes,
"get_recipe": self.listing.get_recipe,
"analyze_uploaded_image": self.analysis.analyze_uploaded_image,
"analyze_local_image": self.analysis.analyze_local_image,
"save_recipe": self.management.save_recipe,
"delete_recipe": self.management.delete_recipe,
"get_top_tags": self.query.get_top_tags,
"get_base_models": self.query.get_base_models,
"share_recipe": self.sharing.share_recipe,
"download_shared_recipe": self.sharing.download_shared_recipe,
"get_recipe_syntax": self.query.get_recipe_syntax,
"update_recipe": self.management.update_recipe,
"reconnect_lora": self.management.reconnect_lora,
"find_duplicates": self.query.find_duplicates,
"bulk_delete": self.management.bulk_delete,
"save_recipe_from_widget": self.management.save_recipe_from_widget,
"get_recipes_for_lora": self.query.get_recipes_for_lora,
"scan_recipes": self.query.scan_recipes,
}
class RecipePageView:
"""Render the recipe shell page."""
def __init__(
self,
*,
ensure_dependencies_ready: EnsureDependenciesCallable,
settings_service: SettingsManager,
server_i18n=default_server_i18n,
template_env,
template_name: str,
recipe_scanner_getter: RecipeScannerGetter,
logger: Logger,
) -> None:
self._ensure_dependencies_ready = ensure_dependencies_ready
self._settings = settings_service
self._server_i18n = server_i18n
self._template_env = template_env
self._template_name = template_name
self._recipe_scanner_getter = recipe_scanner_getter
self._logger = logger
async def render_page(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None: # pragma: no cover - defensive guard
raise RuntimeError("Recipe scanner not available")
user_language = self._settings.get("language", "en")
self._server_i18n.set_locale(user_language)
try:
await recipe_scanner.get_cached_data(force_refresh=False)
rendered = self._template_env.get_template(self._template_name).render(
recipes=[],
is_initializing=False,
settings=self._settings,
request=request,
t=self._server_i18n.get_translation,
)
except Exception as cache_error: # pragma: no cover - logging path
self._logger.error("Error loading recipe cache data: %s", cache_error)
rendered = self._template_env.get_template(self._template_name).render(
is_initializing=True,
settings=self._settings,
request=request,
t=self._server_i18n.get_translation,
)
return web.Response(text=rendered, content_type="text/html")
except Exception as exc: # pragma: no cover - logging path
self._logger.error("Error handling recipes request: %s", exc, exc_info=True)
return web.Response(text="Error loading recipes page", status=500)
class RecipeListingHandler:
"""Provide listing and detail APIs for recipes."""
def __init__(
self,
*,
ensure_dependencies_ready: EnsureDependenciesCallable,
recipe_scanner_getter: RecipeScannerGetter,
logger: Logger,
) -> None:
self._ensure_dependencies_ready = ensure_dependencies_ready
self._recipe_scanner_getter = recipe_scanner_getter
self._logger = logger
async def list_recipes(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
page = int(request.query.get("page", "1"))
page_size = int(request.query.get("page_size", "20"))
sort_by = request.query.get("sort_by", "date")
search = request.query.get("search")
search_options = {
"title": request.query.get("search_title", "true").lower() == "true",
"tags": request.query.get("search_tags", "true").lower() == "true",
"lora_name": request.query.get("search_lora_name", "true").lower() == "true",
"lora_model": request.query.get("search_lora_model", "true").lower() == "true",
}
filters: Dict[str, list[str]] = {}
base_models = request.query.get("base_models")
if base_models:
filters["base_model"] = base_models.split(",")
tags = request.query.get("tags")
if tags:
filters["tags"] = tags.split(",")
lora_hash = request.query.get("lora_hash")
result = await recipe_scanner.get_paginated_data(
page=page,
page_size=page_size,
sort_by=sort_by,
search=search,
filters=filters,
search_options=search_options,
lora_hash=lora_hash,
)
for item in result.get("items", []):
file_path = item.get("file_path")
if file_path:
item["file_url"] = self.format_recipe_file_url(file_path)
else:
item.setdefault("file_url", "/loras_static/images/no-preview.png")
item.setdefault("loras", [])
item.setdefault("base_model", "")
return web.json_response(result)
except Exception as exc:
self._logger.error("Error retrieving recipes: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
async def get_recipe(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
recipe_id = request.match_info["recipe_id"]
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
if not recipe:
return web.json_response({"error": "Recipe not found"}, status=404)
return web.json_response(recipe)
except Exception as exc:
self._logger.error("Error retrieving recipe details: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
def format_recipe_file_url(self, file_path: str) -> str:
try:
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, "/")
normalized_path = file_path.replace(os.sep, "/")
if normalized_path.startswith(recipes_dir):
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, "/")
return f"/loras_static/root1/preview/{relative_path}"
file_name = os.path.basename(file_path)
return f"/loras_static/root1/preview/recipes/{file_name}"
except Exception as exc: # pragma: no cover - logging path
self._logger.error("Error formatting recipe file URL: %s", exc, exc_info=True)
return "/loras_static/images/no-preview.png"
class RecipeQueryHandler:
"""Provide read-only insights on recipe data."""
def __init__(
self,
*,
ensure_dependencies_ready: EnsureDependenciesCallable,
recipe_scanner_getter: RecipeScannerGetter,
format_recipe_file_url: Callable[[str], str],
logger: Logger,
) -> None:
self._ensure_dependencies_ready = ensure_dependencies_ready
self._recipe_scanner_getter = recipe_scanner_getter
self._format_recipe_file_url = format_recipe_file_url
self._logger = logger
async def get_top_tags(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
limit = int(request.query.get("limit", "20"))
cache = await recipe_scanner.get_cached_data()
tag_counts: Dict[str, int] = {}
for recipe in getattr(cache, "raw_data", []):
for tag in recipe.get("tags", []) or []:
tag_counts[tag] = tag_counts.get(tag, 0) + 1
sorted_tags = [{"tag": tag, "count": count} for tag, count in tag_counts.items()]
sorted_tags.sort(key=lambda entry: entry["count"], reverse=True)
return web.json_response({"success": True, "tags": sorted_tags[:limit]})
except Exception as exc:
self._logger.error("Error retrieving top tags: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_base_models(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
cache = await recipe_scanner.get_cached_data()
base_model_counts: Dict[str, int] = {}
for recipe in getattr(cache, "raw_data", []):
base_model = recipe.get("base_model")
if base_model:
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
sorted_models = [{"name": model, "count": count} for model, count in base_model_counts.items()]
sorted_models.sort(key=lambda entry: entry["count"], reverse=True)
return web.json_response({"success": True, "base_models": sorted_models})
except Exception as exc:
self._logger.error("Error retrieving base models: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
lora_hash = request.query.get("hash")
if not lora_hash:
return web.json_response({"success": False, "error": "Lora hash is required"}, status=400)
matching_recipes = await recipe_scanner.get_recipes_for_lora(lora_hash)
return web.json_response({"success": True, "recipes": matching_recipes})
except Exception as exc:
self._logger.error("Error getting recipes for Lora: %s", exc)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def scan_recipes(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
self._logger.info("Manually triggering recipe cache rebuild")
await recipe_scanner.get_cached_data(force_refresh=True)
return web.json_response({"success": True, "message": "Recipe cache refreshed successfully"})
except Exception as exc:
self._logger.error("Error refreshing recipe cache: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def find_duplicates(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
duplicate_groups = await recipe_scanner.find_all_duplicate_recipes()
response_data = []
for fingerprint, recipe_ids in duplicate_groups.items():
if len(recipe_ids) <= 1:
continue
recipes = []
for recipe_id in recipe_ids:
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
if recipe:
recipes.append(
{
"id": recipe.get("id"),
"title": recipe.get("title"),
"file_url": recipe.get("file_url")
or self._format_recipe_file_url(recipe.get("file_path", "")),
"modified": recipe.get("modified"),
"created_date": recipe.get("created_date"),
"lora_count": len(recipe.get("loras", [])),
}
)
if len(recipes) >= 2:
recipes.sort(key=lambda entry: entry.get("modified", 0), reverse=True)
response_data.append(
{
"fingerprint": fingerprint,
"count": len(recipes),
"recipes": recipes,
}
)
response_data.sort(key=lambda entry: entry["count"], reverse=True)
return web.json_response({"success": True, "duplicate_groups": response_data})
except Exception as exc:
self._logger.error("Error finding duplicate recipes: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def get_recipe_syntax(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
recipe_id = request.match_info["recipe_id"]
try:
syntax_parts = await recipe_scanner.get_recipe_syntax_tokens(recipe_id)
except RecipeNotFoundError:
return web.json_response({"error": "Recipe not found"}, status=404)
if not syntax_parts:
return web.json_response({"error": "No LoRAs found in this recipe"}, status=400)
return web.json_response({"success": True, "syntax": " ".join(syntax_parts)})
except Exception as exc:
self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
class RecipeManagementHandler:
"""Handle create/update/delete style recipe operations."""
def __init__(
self,
*,
ensure_dependencies_ready: EnsureDependenciesCallable,
recipe_scanner_getter: RecipeScannerGetter,
logger: Logger,
persistence_service: RecipePersistenceService,
analysis_service: RecipeAnalysisService,
) -> None:
self._ensure_dependencies_ready = ensure_dependencies_ready
self._recipe_scanner_getter = recipe_scanner_getter
self._logger = logger
self._persistence_service = persistence_service
self._analysis_service = analysis_service
async def save_recipe(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
reader = await request.multipart()
payload = await self._parse_save_payload(reader)
result = await self._persistence_service.save_recipe(
recipe_scanner=recipe_scanner,
image_bytes=payload["image_bytes"],
image_base64=payload["image_base64"],
name=payload["name"],
tags=payload["tags"],
metadata=payload["metadata"],
)
return web.json_response(result.payload, status=result.status)
except RecipeValidationError as exc:
return web.json_response({"error": str(exc)}, status=400)
except Exception as exc:
self._logger.error("Error saving recipe: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
async def delete_recipe(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
recipe_id = request.match_info["recipe_id"]
result = await self._persistence_service.delete_recipe(
recipe_scanner=recipe_scanner, recipe_id=recipe_id
)
return web.json_response(result.payload, status=result.status)
except RecipeNotFoundError as exc:
return web.json_response({"error": str(exc)}, status=404)
except Exception as exc:
self._logger.error("Error deleting recipe: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
async def update_recipe(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
recipe_id = request.match_info["recipe_id"]
data = await request.json()
result = await self._persistence_service.update_recipe(
recipe_scanner=recipe_scanner, recipe_id=recipe_id, updates=data
)
return web.json_response(result.payload, status=result.status)
except RecipeValidationError as exc:
return web.json_response({"error": str(exc)}, status=400)
except RecipeNotFoundError as exc:
return web.json_response({"error": str(exc)}, status=404)
except Exception as exc:
self._logger.error("Error updating recipe: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
async def reconnect_lora(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
data = await request.json()
for field in ("recipe_id", "lora_index", "target_name"):
if field not in data:
raise RecipeValidationError(f"Missing required field: {field}")
result = await self._persistence_service.reconnect_lora(
recipe_scanner=recipe_scanner,
recipe_id=data["recipe_id"],
lora_index=int(data["lora_index"]),
target_name=data["target_name"],
)
return web.json_response(result.payload, status=result.status)
except RecipeValidationError as exc:
return web.json_response({"error": str(exc)}, status=400)
except RecipeNotFoundError as exc:
return web.json_response({"error": str(exc)}, status=404)
except Exception as exc:
self._logger.error("Error reconnecting LoRA: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
async def bulk_delete(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
data = await request.json()
recipe_ids = data.get("recipe_ids", [])
result = await self._persistence_service.bulk_delete(
recipe_scanner=recipe_scanner, recipe_ids=recipe_ids
)
return web.json_response(result.payload, status=result.status)
except RecipeValidationError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=400)
except RecipeNotFoundError as exc:
return web.json_response({"success": False, "error": str(exc)}, status=404)
except Exception as exc:
self._logger.error("Error performing bulk delete: %s", exc, exc_info=True)
return web.json_response({"success": False, "error": str(exc)}, status=500)
async def save_recipe_from_widget(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
analysis = await self._analysis_service.analyze_widget_metadata(
recipe_scanner=recipe_scanner
)
metadata = analysis.payload.get("metadata")
image_bytes = analysis.payload.get("image_bytes")
if not metadata or image_bytes is None:
raise RecipeValidationError("Unable to extract metadata from widget")
result = await self._persistence_service.save_recipe_from_widget(
recipe_scanner=recipe_scanner,
metadata=metadata,
image_bytes=image_bytes,
)
return web.json_response(result.payload, status=result.status)
except RecipeValidationError as exc:
return web.json_response({"error": str(exc)}, status=400)
except Exception as exc:
self._logger.error("Error saving recipe from widget: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
async def _parse_save_payload(self, reader) -> dict[str, Any]:
image_bytes: Optional[bytes] = None
image_base64: Optional[str] = None
name: Optional[str] = None
tags: list[str] = []
metadata: Optional[Dict[str, Any]] = None
while True:
field = await reader.next()
if field is None:
break
if field.name == "image":
image_chunks = bytearray()
while True:
chunk = await field.read_chunk()
if not chunk:
break
image_chunks.extend(chunk)
image_bytes = bytes(image_chunks)
elif field.name == "image_base64":
image_base64 = await field.text()
elif field.name == "name":
name = await field.text()
elif field.name == "tags":
tags_text = await field.text()
try:
parsed_tags = json.loads(tags_text)
tags = parsed_tags if isinstance(parsed_tags, list) else []
except Exception:
tags = []
elif field.name == "metadata":
metadata_text = await field.text()
try:
metadata = json.loads(metadata_text)
except Exception:
metadata = {}
return {
"image_bytes": image_bytes,
"image_base64": image_base64,
"name": name,
"tags": tags,
"metadata": metadata,
}
class RecipeAnalysisHandler:
"""Analyze images to extract recipe metadata."""
def __init__(
self,
*,
ensure_dependencies_ready: EnsureDependenciesCallable,
recipe_scanner_getter: RecipeScannerGetter,
civitai_client_getter: CivitaiClientGetter,
logger: Logger,
analysis_service: RecipeAnalysisService,
) -> None:
self._ensure_dependencies_ready = ensure_dependencies_ready
self._recipe_scanner_getter = recipe_scanner_getter
self._civitai_client_getter = civitai_client_getter
self._logger = logger
self._analysis_service = analysis_service
async def analyze_uploaded_image(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
civitai_client = self._civitai_client_getter()
if recipe_scanner is None or civitai_client is None:
raise RuntimeError("Required services unavailable")
content_type = request.headers.get("Content-Type", "")
if "multipart/form-data" in content_type:
reader = await request.multipart()
field = await reader.next()
if field is None or field.name != "image":
raise RecipeValidationError("No image field found")
image_chunks = bytearray()
while True:
chunk = await field.read_chunk()
if not chunk:
break
image_chunks.extend(chunk)
result = await self._analysis_service.analyze_uploaded_image(
image_bytes=bytes(image_chunks),
recipe_scanner=recipe_scanner,
)
return web.json_response(result.payload, status=result.status)
if "application/json" in content_type:
data = await request.json()
result = await self._analysis_service.analyze_remote_image(
url=data.get("url"),
recipe_scanner=recipe_scanner,
civitai_client=civitai_client,
)
return web.json_response(result.payload, status=result.status)
raise RecipeValidationError("Unsupported content type")
except RecipeValidationError as exc:
return web.json_response({"error": str(exc), "loras": []}, status=400)
except RecipeDownloadError as exc:
return web.json_response({"error": str(exc), "loras": []}, status=400)
except RecipeNotFoundError as exc:
return web.json_response({"error": str(exc), "loras": []}, status=404)
except Exception as exc:
self._logger.error("Error analyzing recipe image: %s", exc, exc_info=True)
return web.json_response({"error": str(exc), "loras": []}, status=500)
async def analyze_local_image(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
data = await request.json()
result = await self._analysis_service.analyze_local_image(
file_path=data.get("path"),
recipe_scanner=recipe_scanner,
)
return web.json_response(result.payload, status=result.status)
except RecipeValidationError as exc:
return web.json_response({"error": str(exc), "loras": []}, status=400)
except RecipeNotFoundError as exc:
return web.json_response({"error": str(exc), "loras": []}, status=404)
except Exception as exc:
self._logger.error("Error analyzing local image: %s", exc, exc_info=True)
return web.json_response({"error": str(exc), "loras": []}, status=500)
class RecipeSharingHandler:
"""Serve endpoints related to recipe sharing."""
def __init__(
self,
*,
ensure_dependencies_ready: EnsureDependenciesCallable,
recipe_scanner_getter: RecipeScannerGetter,
logger: Logger,
sharing_service: RecipeSharingService,
) -> None:
self._ensure_dependencies_ready = ensure_dependencies_ready
self._recipe_scanner_getter = recipe_scanner_getter
self._logger = logger
self._sharing_service = sharing_service
async def share_recipe(self, request: web.Request) -> web.Response:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
recipe_id = request.match_info["recipe_id"]
result = await self._sharing_service.share_recipe(
recipe_scanner=recipe_scanner, recipe_id=recipe_id
)
return web.json_response(result.payload, status=result.status)
except RecipeNotFoundError as exc:
return web.json_response({"error": str(exc)}, status=404)
except Exception as exc:
self._logger.error("Error sharing recipe: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)
async def download_shared_recipe(self, request: web.Request) -> web.StreamResponse:
try:
await self._ensure_dependencies_ready()
recipe_scanner = self._recipe_scanner_getter()
if recipe_scanner is None:
raise RuntimeError("Recipe scanner unavailable")
recipe_id = request.match_info["recipe_id"]
download_info = await self._sharing_service.prepare_download(
recipe_scanner=recipe_scanner, recipe_id=recipe_id
)
return web.FileResponse(
download_info.file_path,
headers={
"Content-Disposition": f'attachment; filename="{download_info.download_filename}"'
},
)
except RecipeNotFoundError as exc:
return web.json_response({"error": str(exc)}, status=404)
except Exception as exc:
self._logger.error("Error downloading shared recipe: %s", exc, exc_info=True)
return web.json_response({"error": str(exc)}, status=500)

View File

@@ -5,9 +5,9 @@ from typing import Dict
from server import PromptServer # type: ignore from server import PromptServer # type: ignore
from .base_model_routes import BaseModelRoutes from .base_model_routes import BaseModelRoutes
from .model_route_registrar import ModelRouteRegistrar
from ..services.lora_service import LoraService from ..services.lora_service import LoraService
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..utils.routes_common import ModelRouteUtils
from ..utils.utils import get_lora_info from ..utils.utils import get_lora_info
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -17,16 +17,19 @@ class LoraRoutes(BaseModelRoutes):
def __init__(self): def __init__(self):
"""Initialize LoRA routes with LoRA service""" """Initialize LoRA routes with LoRA service"""
super().__init__() # Service will be initialized later via setup_routes
self.service = None
self.civitai_client = None
self.template_name = "loras.html" self.template_name = "loras.html"
async def initialize_services(self): async def initialize_services(self):
"""Initialize services from ServiceRegistry""" """Initialize services from ServiceRegistry"""
lora_scanner = await ServiceRegistry.get_lora_scanner() lora_scanner = await ServiceRegistry.get_lora_scanner()
self.service = LoraService(lora_scanner) self.service = LoraService(lora_scanner)
self.civitai_client = await ServiceRegistry.get_civitai_client()
# Attach service dependencies # Initialize parent with the service
self.attach_service(self.service) super().__init__(self.service)
def setup_routes(self, app: web.Application): def setup_routes(self, app: web.Application):
"""Setup LoRA routes""" """Setup LoRA routes"""
@@ -36,15 +39,27 @@ class LoraRoutes(BaseModelRoutes):
# Setup common routes with 'loras' prefix (includes page route) # Setup common routes with 'loras' prefix (includes page route)
super().setup_routes(app, 'loras') super().setup_routes(app, 'loras')
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str): def setup_specific_routes(self, app: web.Application, prefix: str):
"""Setup LoRA-specific routes""" """Setup LoRA-specific routes"""
# LoRA-specific query routes # LoRA-specific query routes
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/letter-counts', prefix, self.get_letter_counts) app.router.add_get(f'/api/{prefix}/letter-counts', self.get_letter_counts)
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/get-trigger-words', prefix, self.get_lora_trigger_words) app.router.add_get(f'/api/{prefix}/get-notes', self.get_lora_notes)
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/usage-tips-by-path', prefix, self.get_lora_usage_tips_by_path) app.router.add_get(f'/api/{prefix}/get-trigger-words', self.get_lora_trigger_words)
app.router.add_get(f'/api/{prefix}/preview-url', self.get_lora_preview_url)
app.router.add_get(f'/api/{prefix}/civitai-url', self.get_lora_civitai_url)
app.router.add_get(f'/api/{prefix}/model-description', self.get_lora_model_description)
# LoRA-specific management routes
app.router.add_post(f'/api/{prefix}/move_model', self.move_model)
app.router.add_post(f'/api/{prefix}/move_models_bulk', self.move_models_bulk)
# CivitAI integration with LoRA-specific validation
app.router.add_get(f'/api/{prefix}/civitai/versions/{{model_id}}', self.get_civitai_versions_lora)
app.router.add_get(f'/api/{prefix}/civitai/model/version/{{modelVersionId}}', self.get_civitai_model_by_version)
app.router.add_get(f'/api/{prefix}/civitai/model/hash/{{hash}}', self.get_civitai_model_by_hash)
# ComfyUI integration # ComfyUI integration
registrar.add_prefixed_route('POST', '/api/lm/{prefix}/get_trigger_words', prefix, self.get_trigger_words) app.router.add_post(f'/api/{prefix}/get_trigger_words', self.get_trigger_words)
def _parse_specific_params(self, request: web.Request) -> Dict: def _parse_specific_params(self, request: web.Request) -> Dict:
"""Parse LoRA-specific parameters""" """Parse LoRA-specific parameters"""
@@ -70,15 +85,6 @@ class LoraRoutes(BaseModelRoutes):
return params return params
def _validate_civitai_model_type(self, model_type: str) -> bool:
"""Validate CivitAI model type for LoRA"""
from ..utils.constants import VALID_LORA_TYPES
return model_type.lower() in VALID_LORA_TYPES
def _get_expected_model_types(self) -> str:
"""Get expected model types string for error messages"""
return "LORA, LoCon, or DORA"
# LoRA-specific route handlers # LoRA-specific route handlers
async def get_letter_counts(self, request: web.Request) -> web.Response: async def get_letter_counts(self, request: web.Request) -> web.Response:
"""Get count of LoRAs for each letter of the alphabet""" """Get count of LoRAs for each letter of the alphabet"""
@@ -141,26 +147,6 @@ class LoraRoutes(BaseModelRoutes):
'error': str(e) 'error': str(e)
}, status=500) }, status=500)
async def get_lora_usage_tips_by_path(self, request: web.Request) -> web.Response:
"""Get usage tips for a LoRA by its relative path"""
try:
relative_path = request.query.get('relative_path')
if not relative_path:
return web.Response(text='Relative path is required', status=400)
usage_tips = await self.service.get_lora_usage_tips_by_relative_path(relative_path)
return web.json_response({
'success': True,
'usage_tips': usage_tips or ''
})
except Exception as e:
logger.error(f"Error getting lora usage tips by path: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_lora_preview_url(self, request: web.Request) -> web.Response: async def get_lora_preview_url(self, request: web.Request) -> web.Response:
"""Get the static preview URL for a LoRA file""" """Get the static preview URL for a LoRA file"""
try: try:
@@ -213,6 +199,258 @@ class LoraRoutes(BaseModelRoutes):
'error': str(e) 'error': str(e)
}, status=500) }, status=500)
# CivitAI integration methods
async def get_civitai_versions_lora(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai LoRA model with local availability info"""
try:
model_id = request.match_info['model_id']
response = await self.civitai_client.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.Response(status=404, text="Model not found")
versions = response.get('modelVersions', [])
model_type = response.get('type', '')
# Check model type - should be LORA, LoCon, or DORA
from ..utils.constants import VALID_LORA_TYPES
if model_type.lower() not in VALID_LORA_TYPES:
return web.json_response({
'error': f"Model type mismatch. Expected LORA or LoCon, got {model_type}"
}, status=400)
# Check local availability for each version
for version in versions:
# Find the model file (type="Model") in the files list
model_file = next((file for file in version.get('files', [])
if file.get('type') == 'Model'), None)
if model_file:
sha256 = model_file.get('hashes', {}).get('SHA256')
if sha256:
# Set existsLocally and localPath at the version level
version['existsLocally'] = self.service.has_hash(sha256)
if version['existsLocally']:
version['localPath'] = self.service.get_path_by_hash(sha256)
# Also set the model file size at the version level for easier access
version['modelSizeKB'] = model_file.get('sizeKB')
else:
# No model file found in this version
version['existsLocally'] = False
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching LoRA model versions: {e}")
return web.Response(status=500, text=str(e))
async def get_civitai_model_by_version(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by model version ID"""
try:
model_version_id = request.match_info.get('modelVersionId')
# Get model details from Civitai API
model, error_msg = await self.civitai_client.get_model_version_info(model_version_id)
if not model:
# Log warning for failed model retrieval
logger.warning(f"Failed to fetch model version {model_version_id}: {error_msg}")
# Determine status code based on error message
status_code = 404 if error_msg and "not found" in error_msg.lower() else 500
return web.json_response({
"success": False,
"error": error_msg or "Failed to fetch model information"
}, status=status_code)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
async def get_civitai_model_by_hash(self, request: web.Request) -> web.Response:
"""Get CivitAI model details by hash"""
try:
hash = request.match_info.get('hash')
model = await self.civitai_client.get_model_by_hash(hash)
return web.json_response(model)
except Exception as e:
logger.error(f"Error fetching model details by hash: {e}")
return web.json_response({
"success": False,
"error": str(e)
}, status=500)
# Model management methods
async def move_model(self, request: web.Request) -> web.Response:
"""Handle model move request"""
try:
data = await request.json()
file_path = data.get('file_path') # full path of the model file
target_path = data.get('target_path') # folder path to move the model to
if not file_path or not target_path:
return web.Response(text='File path and target path are required', status=400)
# Check if source and destination are the same
import os
source_dir = os.path.dirname(file_path)
if os.path.normpath(source_dir) == os.path.normpath(target_path):
logger.info(f"Source and target directories are the same: {source_dir}")
return web.json_response({'success': True, 'message': 'Source and target directories are the same'})
# Check if target file already exists
file_name = os.path.basename(file_path)
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
if os.path.exists(target_file_path):
return web.json_response({
'success': False,
'error': f"Target file already exists: {target_file_path}"
}, status=409) # 409 Conflict
# Call scanner to handle the move operation
success = await self.service.scanner.move_model(file_path, target_path)
if success:
return web.json_response({'success': True, 'new_file_path': target_file_path})
else:
return web.Response(text='Failed to move model', status=500)
except Exception as e:
logger.error(f"Error moving model: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
async def move_models_bulk(self, request: web.Request) -> web.Response:
"""Handle bulk model move request"""
try:
data = await request.json()
file_paths = data.get('file_paths', []) # list of full paths of the model files
target_path = data.get('target_path') # folder path to move the models to
if not file_paths or not target_path:
return web.Response(text='File paths and target path are required', status=400)
results = []
import os
for file_path in file_paths:
# Check if source and destination are the same
source_dir = os.path.dirname(file_path)
if os.path.normpath(source_dir) == os.path.normpath(target_path):
results.append({
"path": file_path,
"success": True,
"message": "Source and target directories are the same"
})
continue
# Check if target file already exists
file_name = os.path.basename(file_path)
target_file_path = os.path.join(target_path, file_name).replace(os.sep, '/')
if os.path.exists(target_file_path):
results.append({
"path": file_path,
"success": False,
"message": f"Target file already exists: {target_file_path}"
})
continue
# Try to move the model
success = await self.service.scanner.move_model(file_path, target_path)
results.append({
"path": file_path,
"success": success,
"message": "Success" if success else "Failed to move model"
})
# Count successes and failures
success_count = sum(1 for r in results if r["success"])
failure_count = len(results) - success_count
return web.json_response({
'success': True,
'message': f'Moved {success_count} of {len(file_paths)} models',
'results': results,
'success_count': success_count,
'failure_count': failure_count
})
except Exception as e:
logger.error(f"Error moving models in bulk: {e}", exc_info=True)
return web.Response(text=str(e), status=500)
async def get_lora_model_description(self, request: web.Request) -> web.Response:
"""Get model description for a Lora model"""
try:
# Get parameters
model_id = request.query.get('model_id')
file_path = request.query.get('file_path')
if not model_id:
return web.json_response({
'success': False,
'error': 'Model ID is required'
}, status=400)
# Check if we already have the description stored in metadata
description = None
tags = []
creator = {}
if file_path:
import os
from ..utils.metadata_manager import MetadataManager
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
description = metadata.get('modelDescription')
tags = metadata.get('tags', [])
creator = metadata.get('creator', {})
# If description is not in metadata, fetch from CivitAI
if not description:
logger.info(f"Fetching model metadata for model ID: {model_id}")
model_metadata, _ = await self.civitai_client.get_model_metadata(model_id)
if model_metadata:
description = model_metadata.get('description')
tags = model_metadata.get('tags', [])
creator = model_metadata.get('creator', {})
# Save the metadata to file if we have a file path and got metadata
if file_path:
try:
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
metadata['modelDescription'] = description
metadata['tags'] = tags
# Ensure the civitai dict exists
if 'civitai' not in metadata:
metadata['civitai'] = {}
# Store creator in the civitai nested structure
metadata['civitai']['creator'] = creator
await MetadataManager.save_metadata(file_path, metadata, True)
except Exception as e:
logger.error(f"Error saving model metadata: {e}")
return web.json_response({
'success': True,
'description': description or "<p>No model description available.</p>",
'tags': tags,
'creator': creator
})
except Exception as e:
logger.error(f"Error getting model metadata: {e}")
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_trigger_words(self, request: web.Request) -> web.Response: async def get_trigger_words(self, request: web.Request) -> web.Response:
"""Get trigger words for specified LoRA models""" """Get trigger words for specified LoRA models"""
try: try:

View File

@@ -1,10 +1,9 @@
import json
import logging import logging
import os import os
import sys import sys
import threading import threading
import asyncio import asyncio
import subprocess
import re
from server import PromptServer # type: ignore from server import PromptServer # type: ignore
from aiohttp import web from aiohttp import web
from ..services.settings_manager import settings from ..services.settings_manager import settings
@@ -13,12 +12,11 @@ from ..utils.lora_metadata import extract_trained_words
from ..config import config from ..config import config
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS, NODE_TYPES, DEFAULT_NODE_COLOR
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..services.metadata_service import get_metadata_archive_manager, update_metadata_providers, get_metadata_provider import re
from ..services.websocket_manager import ws_manager
from ..services.downloader import get_downloader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" standalone_mode = 'nodes' not in sys.modules
# Node registry for tracking active workflow nodes # Node registry for tracking active workflow nodes
class NodeRegistry: class NodeRegistry:
@@ -86,138 +84,76 @@ node_registry = NodeRegistry()
class MiscRoutes: class MiscRoutes:
"""Miscellaneous routes for various utility functions""" """Miscellaneous routes for various utility functions"""
@staticmethod
def is_dedicated_example_images_folder(folder_path):
"""
Check if a folder is a dedicated example images folder.
A dedicated folder should either be:
1. Empty
2. Only contain .download_progress.json file and/or folders with valid SHA256 hash names (64 hex characters)
Args:
folder_path (str): Path to the folder to check
Returns:
bool: True if the folder is dedicated, False otherwise
"""
try:
if not os.path.exists(folder_path) or not os.path.isdir(folder_path):
return False
items = os.listdir(folder_path)
# Empty folder is considered dedicated
if not items:
return True
# Check each item in the folder
for item in items:
item_path = os.path.join(folder_path, item)
# Allow .download_progress.json file
if item == '.download_progress.json' and os.path.isfile(item_path):
continue
# Allow folders with valid SHA256 hash names (64 hex characters)
if os.path.isdir(item_path):
# Check if the folder name is a valid SHA256 hash
if re.match(r'^[a-fA-F0-9]{64}$', item):
continue
# If we encounter anything else, it's not a dedicated folder
return False
return True
except Exception as e:
logger.error(f"Error checking if folder is dedicated: {e}")
return False
@staticmethod @staticmethod
def setup_routes(app): def setup_routes(app):
"""Register miscellaneous routes""" """Register miscellaneous routes"""
app.router.add_get('/api/lm/settings', MiscRoutes.get_settings) app.router.add_post('/api/settings', MiscRoutes.update_settings)
app.router.add_post('/api/lm/settings', MiscRoutes.update_settings)
app.router.add_get('/api/lm/health-check', lambda request: web.json_response({'status': 'ok'})) # Add new route for clearing cache
app.router.add_post('/api/clear-cache', MiscRoutes.clear_cache)
app.router.add_post('/api/lm/open-file-location', MiscRoutes.open_file_location) app.router.add_get('/api/health-check', lambda request: web.json_response({'status': 'ok'}))
# Usage stats routes # Usage stats routes
app.router.add_post('/api/lm/update-usage-stats', MiscRoutes.update_usage_stats) app.router.add_post('/api/update-usage-stats', MiscRoutes.update_usage_stats)
app.router.add_get('/api/lm/get-usage-stats', MiscRoutes.get_usage_stats) app.router.add_get('/api/get-usage-stats', MiscRoutes.get_usage_stats)
# Lora code update endpoint # Lora code update endpoint
app.router.add_post('/api/lm/update-lora-code', MiscRoutes.update_lora_code) app.router.add_post('/api/update-lora-code', MiscRoutes.update_lora_code)
# Add new route for getting trained words # Add new route for getting trained words
app.router.add_get('/api/lm/trained-words', MiscRoutes.get_trained_words) app.router.add_get('/api/trained-words', MiscRoutes.get_trained_words)
# Add new route for getting model example files # Add new route for getting model example files
app.router.add_get('/api/lm/model-example-files', MiscRoutes.get_model_example_files) app.router.add_get('/api/model-example-files', MiscRoutes.get_model_example_files)
# Node registry endpoints # Node registry endpoints
app.router.add_post('/api/lm/register-nodes', MiscRoutes.register_nodes) app.router.add_post('/api/register-nodes', MiscRoutes.register_nodes)
app.router.add_get('/api/lm/get-registry', MiscRoutes.get_registry) app.router.add_get('/api/get-registry', MiscRoutes.get_registry)
# Add new route for checking if a model exists in the library # Add new route for checking if a model exists in the library
app.router.add_get('/api/lm/check-model-exists', MiscRoutes.check_model_exists) app.router.add_get('/api/check-model-exists', MiscRoutes.check_model_exists)
# Add routes for metadata archive database management
app.router.add_post('/api/lm/download-metadata-archive', MiscRoutes.download_metadata_archive)
app.router.add_post('/api/lm/remove-metadata-archive', MiscRoutes.remove_metadata_archive)
app.router.add_get('/api/lm/metadata-archive-status', MiscRoutes.get_metadata_archive_status)
# Add route for checking model versions in library
app.router.add_get('/api/lm/model-versions-status', MiscRoutes.get_model_versions_status)
@staticmethod @staticmethod
async def get_settings(request): async def clear_cache(request):
"""Get application settings that should be synced to frontend""" """Clear all cache files from the cache folder"""
try: try:
# Define keys that should be synced from backend to frontend # Get the cache folder path (relative to project directory)
sync_keys = [ project_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
'civitai_api_key', cache_folder = os.path.join(project_dir, 'cache')
'default_lora_root',
'default_checkpoint_root',
'default_embedding_root',
'base_model_path_mappings',
'download_path_templates',
'enable_metadata_archive_db',
'language',
'proxy_enabled',
'proxy_type',
'proxy_host',
'proxy_port',
'proxy_username',
'proxy_password',
'example_images_path',
'optimize_example_images',
'auto_download_example_images',
'blur_mature_content',
'autoplay_on_hover',
'display_density',
'card_info_display',
'include_trigger_words',
'show_only_sfw',
'compact_mode'
]
# Build response with only the keys that should be synced # Check if cache folder exists
response_data = {} if not os.path.exists(cache_folder):
for key in sync_keys: logger.info("Cache folder does not exist, nothing to clear")
value = settings.get(key) return web.json_response({'success': True, 'message': 'No cache folder found'})
if value is not None:
response_data[key] = value # Get list of cache files before deleting for reporting
cache_files = [f for f in os.listdir(cache_folder) if os.path.isfile(os.path.join(cache_folder, f))]
deleted_files = []
# Delete each .msgpack file in the cache folder
for filename in cache_files:
if filename.endswith('.msgpack'):
file_path = os.path.join(cache_folder, filename)
try:
os.remove(file_path)
deleted_files.append(filename)
logger.info(f"Deleted cache file: {filename}")
except Exception as e:
logger.error(f"Failed to delete {filename}: {e}")
return web.json_response({
'success': False,
'error': f"Failed to delete {filename}: {str(e)}"
}, status=500)
return web.json_response({ return web.json_response({
'success': True, 'success': True,
'settings': response_data 'message': f"Successfully cleared {len(deleted_files)} cache files",
'deleted_files': deleted_files
}) })
except Exception as e: except Exception as e:
logger.error(f"Error getting settings: {e}", exc_info=True) logger.error(f"Error clearing cache files: {e}", exc_info=True)
return web.json_response({ return web.json_response({
'success': False, 'success': False,
'error': str(e) 'error': str(e)
@@ -228,15 +164,10 @@ class MiscRoutes:
"""Update application settings""" """Update application settings"""
try: try:
data = await request.json() data = await request.json()
proxy_keys = {'proxy_enabled', 'proxy_host', 'proxy_port', 'proxy_username', 'proxy_password', 'proxy_type'}
proxy_changed = False
# Validate and update settings # Validate and update settings
for key, value in data.items(): for key, value in data.items():
if value == settings.get(key): # Special handling for example_images_path - verify path exists
# No change, skip
continue
# Special handling for example_images_path - verify path exists and is dedicated
if key == 'example_images_path' and value: if key == 'example_images_path' and value:
if not os.path.exists(value): if not os.path.exists(value):
return web.json_response({ return web.json_response({
@@ -244,34 +175,23 @@ class MiscRoutes:
'error': f"Path does not exist: {value}" 'error': f"Path does not exist: {value}"
}) })
# Check if folder is dedicated for example images
if not MiscRoutes.is_dedicated_example_images_folder(value):
return web.json_response({
'success': False,
'error': "Please set a dedicated folder for example images."
})
# Path changed - server restart required for new path to take effect # Path changed - server restart required for new path to take effect
old_path = settings.get('example_images_path') old_path = settings.get('example_images_path')
if old_path != value: if old_path != value:
logger.info(f"Example images path changed to {value} - server restart required") logger.info(f"Example images path changed to {value} - server restart required")
# Handle deletion for proxy credentials # Special handling for base_model_path_mappings - parse JSON string
if value == '__DELETE__' and key in ('proxy_username', 'proxy_password'): if key == 'base_model_path_mappings' and value:
settings.delete(key) try:
else: value = json.loads(value)
# Save to settings except json.JSONDecodeError:
settings.set(key, value) return web.json_response({
'success': False,
'error': f"Invalid JSON format for base_model_path_mappings: {value}"
})
if key == 'enable_metadata_archive_db': # Save to settings
await update_metadata_providers() settings.set(key, value)
if key in proxy_keys:
proxy_changed = True
if proxy_changed:
downloader = await get_downloader()
await downloader.refresh_session()
return web.json_response({'success': True}) return web.json_response({'success': True})
except Exception as e: except Exception as e:
@@ -731,13 +651,13 @@ class MiscRoutes:
exists = False exists = False
model_type = None model_type = None
if await lora_scanner.check_model_version_exists(model_version_id): if await lora_scanner.check_model_version_exists(model_id, model_version_id):
exists = True exists = True
model_type = 'lora' model_type = 'lora'
elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_version_id): elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
exists = True exists = True
model_type = 'checkpoint' model_type = 'checkpoint'
elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_version_id): elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_id, model_version_id):
exists = True exists = True
model_type = 'embedding' model_type = 'embedding'
@@ -785,274 +705,3 @@ class MiscRoutes:
'success': False, 'success': False,
'error': str(e) 'error': str(e)
}, status=500) }, status=500)
@staticmethod
async def download_metadata_archive(request):
"""Download and extract the metadata archive database"""
try:
archive_manager = await get_metadata_archive_manager()
# Get the download_id from query parameters if provided
download_id = request.query.get('download_id')
# Progress callback to send updates via WebSocket
def progress_callback(stage, message):
data = {
'stage': stage,
'message': message,
'type': 'metadata_archive_download'
}
if download_id:
# Send to specific download WebSocket if download_id is provided
asyncio.create_task(ws_manager.broadcast_download_progress(download_id, data))
else:
# Fallback to general broadcast
asyncio.create_task(ws_manager.broadcast(data))
# Download and extract in background
success = await archive_manager.download_and_extract_database(progress_callback)
if success:
# Update settings to enable metadata archive
settings.set('enable_metadata_archive_db', True)
# Update metadata providers
await update_metadata_providers()
return web.json_response({
'success': True,
'message': 'Metadata archive database downloaded and extracted successfully'
})
else:
return web.json_response({
'success': False,
'error': 'Failed to download and extract metadata archive database'
}, status=500)
except Exception as e:
logger.error(f"Error downloading metadata archive: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def remove_metadata_archive(request):
"""Remove the metadata archive database"""
try:
archive_manager = await get_metadata_archive_manager()
success = await archive_manager.remove_database()
if success:
# Update settings to disable metadata archive
settings.set('enable_metadata_archive_db', False)
# Update metadata providers
await update_metadata_providers()
return web.json_response({
'success': True,
'message': 'Metadata archive database removed successfully'
})
else:
return web.json_response({
'success': False,
'error': 'Failed to remove metadata archive database'
}, status=500)
except Exception as e:
logger.error(f"Error removing metadata archive: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def get_metadata_archive_status(request):
"""Get the status of metadata archive database"""
try:
archive_manager = await get_metadata_archive_manager()
is_available = archive_manager.is_database_available()
is_enabled = settings.get('enable_metadata_archive_db', False)
db_size = 0
if is_available:
db_path = archive_manager.get_database_path()
if db_path and os.path.exists(db_path):
db_size = os.path.getsize(db_path)
return web.json_response({
'success': True,
'isAvailable': is_available,
'isEnabled': is_enabled,
'databaseSize': db_size,
'databasePath': archive_manager.get_database_path() if is_available else None
})
except Exception as e:
logger.error(f"Error getting metadata archive status: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def get_model_versions_status(request):
"""
Get all versions of a model from metadata provider and check their library status
Expects query parameters:
- modelId: int - Civitai model ID (required)
Returns:
- JSON with model type and versions list, each version includes 'inLibrary' flag
"""
try:
# Get the modelId from query parameters
model_id_str = request.query.get('modelId')
# Validate modelId parameter (required)
if not model_id_str:
return web.json_response({
'success': False,
'error': 'Missing required parameter: modelId'
}, status=400)
try:
# Convert modelId to integer
model_id = int(model_id_str)
except ValueError:
return web.json_response({
'success': False,
'error': 'Parameter modelId must be an integer'
}, status=400)
# Get metadata provider
metadata_provider = await get_metadata_provider()
if not metadata_provider:
return web.json_response({
'success': False,
'error': 'Metadata provider not available'
}, status=503)
# Get model versions from metadata provider
response = await metadata_provider.get_model_versions(model_id)
if not response or not response.get('modelVersions'):
return web.json_response({
'success': False,
'error': 'Model not found'
}, status=404)
versions = response.get('modelVersions', [])
model_name = response.get('name', '')
model_type = response.get('type', '').lower()
# Determine scanner based on model type
scanner = None
normalized_type = None
if model_type in ['lora', 'locon', 'dora']:
scanner = await ServiceRegistry.get_lora_scanner()
normalized_type = 'lora'
elif model_type == 'checkpoint':
scanner = await ServiceRegistry.get_checkpoint_scanner()
normalized_type = 'checkpoint'
elif model_type == 'textualinversion':
scanner = await ServiceRegistry.get_embedding_scanner()
normalized_type = 'embedding'
else:
return web.json_response({
'success': False,
'error': f'Model type "{model_type}" is not supported'
}, status=400)
if not scanner:
return web.json_response({
'success': False,
'error': f'Scanner for type "{normalized_type}" is not available'
}, status=503)
# Get local versions from scanner
local_versions = await scanner.get_model_versions_by_id(model_id)
local_version_ids = set(version['versionId'] for version in local_versions)
# Add inLibrary flag to each version
enriched_versions = []
for version in versions:
version_id = version.get('id')
enriched_version = {
'id': version_id,
'name': version.get('name', ''),
'thumbnailUrl': version.get('images')[0]['url'] if version.get('images') else None,
'inLibrary': version_id in local_version_ids
}
enriched_versions.append(enriched_version)
return web.json_response({
'success': True,
'modelId': model_id,
'modelName': model_name,
'modelType': model_type,
'versions': enriched_versions
})
except Exception as e:
logger.error(f"Failed to get model versions status: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def open_file_location(request):
"""
Open the folder containing the specified file and select the file in the file explorer.
Expects a JSON request body with:
{
"file_path": "absolute/path/to/file"
}
"""
try:
data = await request.json()
file_path = data.get('file_path')
if not file_path:
return web.json_response({
'success': False,
'error': 'Missing file_path parameter'
}, status=400)
file_path = os.path.abspath(file_path)
if not os.path.isfile(file_path):
return web.json_response({
'success': False,
'error': 'File does not exist'
}, status=404)
# Open the folder and select the file
if os.name == 'nt': # Windows
# explorer /select,"C:\path\to\file"
subprocess.Popen(['explorer', '/select,', file_path])
elif os.name == 'posix':
if sys.platform == 'darwin': # macOS
subprocess.Popen(['open', '-R', file_path])
else: # Linux (selecting file is not standard, just open folder)
folder = os.path.dirname(file_path)
subprocess.Popen(['xdg-open', folder])
return web.json_response({
'success': True,
'message': f'Opened folder and selected file: {file_path}'
})
except Exception as e:
logger.error(f"Failed to open file location: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)

View File

@@ -1,99 +0,0 @@
"""Route registrar for model endpoints."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Iterable, Mapping
from aiohttp import web
@dataclass(frozen=True)
class RouteDefinition:
"""Declarative definition for a HTTP route."""
method: str
path_template: str
handler_name: str
def build_path(self, prefix: str) -> str:
return self.path_template.replace("{prefix}", prefix)
COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/api/lm/{prefix}/list", "get_models"),
RouteDefinition("POST", "/api/lm/{prefix}/delete", "delete_model"),
RouteDefinition("POST", "/api/lm/{prefix}/exclude", "exclude_model"),
RouteDefinition("POST", "/api/lm/{prefix}/fetch-civitai", "fetch_civitai"),
RouteDefinition("POST", "/api/lm/{prefix}/fetch-all-civitai", "fetch_all_civitai"),
RouteDefinition("POST", "/api/lm/{prefix}/relink-civitai", "relink_civitai"),
RouteDefinition("POST", "/api/lm/{prefix}/replace-preview", "replace_preview"),
RouteDefinition("POST", "/api/lm/{prefix}/save-metadata", "save_metadata"),
RouteDefinition("POST", "/api/lm/{prefix}/add-tags", "add_tags"),
RouteDefinition("POST", "/api/lm/{prefix}/rename", "rename_model"),
RouteDefinition("POST", "/api/lm/{prefix}/bulk-delete", "bulk_delete_models"),
RouteDefinition("POST", "/api/lm/{prefix}/verify-duplicates", "verify_duplicates"),
RouteDefinition("POST", "/api/lm/{prefix}/move_model", "move_model"),
RouteDefinition("POST", "/api/lm/{prefix}/move_models_bulk", "move_models_bulk"),
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize", "auto_organize_models"),
RouteDefinition("POST", "/api/lm/{prefix}/auto-organize", "auto_organize_models"),
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"),
RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"),
RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"),
RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"),
RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"),
RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"),
RouteDefinition("GET", "/api/lm/{prefix}/folder-tree", "get_folder_tree"),
RouteDefinition("GET", "/api/lm/{prefix}/unified-folder-tree", "get_unified_folder_tree"),
RouteDefinition("GET", "/api/lm/{prefix}/find-duplicates", "find_duplicate_models"),
RouteDefinition("GET", "/api/lm/{prefix}/find-filename-conflicts", "find_filename_conflicts"),
RouteDefinition("GET", "/api/lm/{prefix}/get-notes", "get_model_notes"),
RouteDefinition("GET", "/api/lm/{prefix}/preview-url", "get_model_preview_url"),
RouteDefinition("GET", "/api/lm/{prefix}/civitai-url", "get_model_civitai_url"),
RouteDefinition("GET", "/api/lm/{prefix}/metadata", "get_model_metadata"),
RouteDefinition("GET", "/api/lm/{prefix}/model-description", "get_model_description"),
RouteDefinition("GET", "/api/lm/{prefix}/relative-paths", "get_relative_paths"),
RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"),
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
RouteDefinition("GET", "/api/lm/download-progress/{download_id}", "get_download_progress"),
RouteDefinition("GET", "/{prefix}", "handle_models_page"),
)
class ModelRouteRegistrar:
"""Bind declarative definitions to an aiohttp router."""
_METHOD_MAP = {
"GET": "add_get",
"POST": "add_post",
"PUT": "add_put",
"DELETE": "add_delete",
}
def __init__(self, app: web.Application) -> None:
self._app = app
def register_common_routes(
self,
prefix: str,
handler_lookup: Mapping[str, Callable[[web.Request], object]],
*,
definitions: Iterable[RouteDefinition] = COMMON_ROUTE_DEFINITIONS,
) -> None:
for definition in definitions:
self._bind_route(definition.method, definition.build_path(prefix), handler_lookup[definition.handler_name])
def add_route(self, method: str, path: str, handler: Callable) -> None:
self._bind_route(method, path, handler)
def add_prefixed_route(self, method: str, path_template: str, prefix: str, handler: Callable) -> None:
self._bind_route(method, path_template.replace("{prefix}", prefix), handler)
def _bind_route(self, method: str, path: str, handler: Callable) -> None:
add_method_name = self._METHOD_MAP[method.upper()]
add_method = getattr(self._app.router, add_method_name)
add_method(path, handler)

View File

@@ -1,64 +0,0 @@
"""Route registrar for recipe endpoints."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Mapping
from aiohttp import web
@dataclass(frozen=True)
class RouteDefinition:
"""Declarative definition for a recipe HTTP route."""
method: str
path: str
handler_name: str
ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
RouteDefinition("GET", "/loras/recipes", "render_page"),
RouteDefinition("GET", "/api/lm/recipes", "list_recipes"),
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"),
RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"),
RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"),
RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"),
RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"),
RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"),
RouteDefinition("GET", "/api/lm/recipes/base-models", "get_base_models"),
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share", "share_recipe"),
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"),
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"),
RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"),
RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"),
RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"),
RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"),
RouteDefinition("POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"),
RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"),
RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"),
)
class RecipeRouteRegistrar:
"""Bind declarative recipe definitions to an aiohttp router."""
_METHOD_MAP = {
"GET": "add_get",
"POST": "add_post",
"PUT": "add_put",
"DELETE": "add_delete",
}
def __init__(self, app: web.Application) -> None:
self._app = app
def register_routes(self, handler_lookup: Mapping[str, Callable[[web.Request], object]]) -> None:
for definition in ROUTE_DEFINITIONS:
handler = handler_lookup[definition.handler_name]
self._bind_route(definition.method, definition.path, handler)
def _bind_route(self, method: str, path: str, handler: Callable) -> None:
add_method_name = self._METHOD_MAP[method.upper()]
add_method = getattr(self._app.router, add_method_name)
add_method(path, handler)

File diff suppressed because it is too large Load Diff

View File

@@ -9,7 +9,6 @@ from typing import Dict, List, Any
from ..config import config from ..config import config
from ..services.settings_manager import settings from ..services.settings_manager import settings
from ..services.server_i18n import server_i18n
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..utils.usage_stats import UsageStats from ..utils.usage_stats import UsageStats
@@ -21,7 +20,6 @@ class StatsRoutes:
def __init__(self): def __init__(self):
self.lora_scanner = None self.lora_scanner = None
self.checkpoint_scanner = None self.checkpoint_scanner = None
self.embedding_scanner = None
self.usage_stats = None self.usage_stats = None
self.template_env = jinja2.Environment( self.template_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(config.templates_path), loader=jinja2.FileSystemLoader(config.templates_path),
@@ -32,14 +30,7 @@ class StatsRoutes:
"""Initialize services from ServiceRegistry""" """Initialize services from ServiceRegistry"""
self.lora_scanner = await ServiceRegistry.get_lora_scanner() self.lora_scanner = await ServiceRegistry.get_lora_scanner()
self.checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() self.checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
self.embedding_scanner = await ServiceRegistry.get_embedding_scanner() self.usage_stats = UsageStats()
# Only initialize usage stats if we have valid paths configured
try:
self.usage_stats = UsageStats()
except RuntimeError as e:
logger.warning(f"Could not initialize usage statistics: {e}")
self.usage_stats = None
async def handle_stats_page(self, request: web.Request) -> web.Response: async def handle_stats_page(self, request: web.Request) -> web.Response:
"""Handle GET /statistics request""" """Handle GET /statistics request"""
@@ -58,30 +49,13 @@ class StatsRoutes:
(hasattr(self.checkpoint_scanner, '_is_initializing') and self.checkpoint_scanner._is_initializing) (hasattr(self.checkpoint_scanner, '_is_initializing') and self.checkpoint_scanner._is_initializing)
) )
embedding_initializing = ( is_initializing = lora_initializing or checkpoint_initializing
self.embedding_scanner._cache is None or
(hasattr(self.embedding_scanner, 'is_initializing') and self.embedding_scanner.is_initializing())
)
is_initializing = lora_initializing or checkpoint_initializing or embedding_initializing
# 获取用户语言设置
user_language = settings.get('language', 'en')
# 设置服务端i18n语言
server_i18n.set_locale(user_language)
# 为模板环境添加i18n过滤器
if not hasattr(self.template_env, '_i18n_filter_added'):
self.template_env.filters['t'] = server_i18n.create_template_filter()
self.template_env._i18n_filter_added = True
template = self.template_env.get_template('statistics.html') template = self.template_env.get_template('statistics.html')
rendered = template.render( rendered = template.render(
is_initializing=is_initializing, is_initializing=is_initializing,
settings=settings, settings=settings,
request=request, request=request
t=server_i18n.get_translation,
) )
return web.Response( return web.Response(
@@ -111,29 +85,21 @@ class StatsRoutes:
checkpoint_count = len(checkpoint_cache.raw_data) checkpoint_count = len(checkpoint_cache.raw_data)
checkpoint_size = sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data) checkpoint_size = sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data)
# Get Embedding statistics
embedding_cache = await self.embedding_scanner.get_cached_data()
embedding_count = len(embedding_cache.raw_data)
embedding_size = sum(emb.get('size', 0) for emb in embedding_cache.raw_data)
# Get usage statistics # Get usage statistics
usage_data = await self.usage_stats.get_stats() usage_data = await self.usage_stats.get_stats()
return web.json_response({ return web.json_response({
'success': True, 'success': True,
'data': { 'data': {
'total_models': lora_count + checkpoint_count + embedding_count, 'total_models': lora_count + checkpoint_count,
'lora_count': lora_count, 'lora_count': lora_count,
'checkpoint_count': checkpoint_count, 'checkpoint_count': checkpoint_count,
'embedding_count': embedding_count, 'total_size': lora_size + checkpoint_size,
'total_size': lora_size + checkpoint_size + embedding_size,
'lora_size': lora_size, 'lora_size': lora_size,
'checkpoint_size': checkpoint_size, 'checkpoint_size': checkpoint_size,
'embedding_size': embedding_size,
'total_generations': usage_data.get('total_executions', 0), 'total_generations': usage_data.get('total_executions', 0),
'unused_loras': self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {})), 'unused_loras': self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {})),
'unused_checkpoints': self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {})), 'unused_checkpoints': self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {}))
'unused_embeddings': self._count_unused_models(embedding_cache.raw_data, usage_data.get('embeddings', {}))
} }
}) })
@@ -155,17 +121,14 @@ class StatsRoutes:
# Get model data for enrichment # Get model data for enrichment
lora_cache = await self.lora_scanner.get_cached_data() lora_cache = await self.lora_scanner.get_cached_data()
checkpoint_cache = await self.checkpoint_scanner.get_cached_data() checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
embedding_cache = await self.embedding_scanner.get_cached_data()
# Create hash to model mapping # Create hash to model mapping
lora_map = {lora['sha256']: lora for lora in lora_cache.raw_data} lora_map = {lora['sha256']: lora for lora in lora_cache.raw_data}
checkpoint_map = {cp['sha256']: cp for cp in checkpoint_cache.raw_data} checkpoint_map = {cp['sha256']: cp for cp in checkpoint_cache.raw_data}
embedding_map = {emb['sha256']: emb for emb in embedding_cache.raw_data}
# Prepare top used models # Prepare top used models
top_loras = self._get_top_used_models(usage_data.get('loras', {}), lora_map, 10) top_loras = self._get_top_used_models(usage_data.get('loras', {}), lora_map, 10)
top_checkpoints = self._get_top_used_models(usage_data.get('checkpoints', {}), checkpoint_map, 10) top_checkpoints = self._get_top_used_models(usage_data.get('checkpoints', {}), checkpoint_map, 10)
top_embeddings = self._get_top_used_models(usage_data.get('embeddings', {}), embedding_map, 10)
# Prepare usage timeline (last 30 days) # Prepare usage timeline (last 30 days)
timeline = self._get_usage_timeline(usage_data, 30) timeline = self._get_usage_timeline(usage_data, 30)
@@ -175,7 +138,6 @@ class StatsRoutes:
'data': { 'data': {
'top_loras': top_loras, 'top_loras': top_loras,
'top_checkpoints': top_checkpoints, 'top_checkpoints': top_checkpoints,
'top_embeddings': top_embeddings,
'usage_timeline': timeline, 'usage_timeline': timeline,
'total_executions': usage_data.get('total_executions', 0) 'total_executions': usage_data.get('total_executions', 0)
} }
@@ -196,19 +158,16 @@ class StatsRoutes:
# Get model data # Get model data
lora_cache = await self.lora_scanner.get_cached_data() lora_cache = await self.lora_scanner.get_cached_data()
checkpoint_cache = await self.checkpoint_scanner.get_cached_data() checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
embedding_cache = await self.embedding_scanner.get_cached_data()
# Count by base model # Count by base model
lora_base_models = Counter(lora.get('base_model', 'Unknown') for lora in lora_cache.raw_data) lora_base_models = Counter(lora.get('base_model', 'Unknown') for lora in lora_cache.raw_data)
checkpoint_base_models = Counter(cp.get('base_model', 'Unknown') for cp in checkpoint_cache.raw_data) checkpoint_base_models = Counter(cp.get('base_model', 'Unknown') for cp in checkpoint_cache.raw_data)
embedding_base_models = Counter(emb.get('base_model', 'Unknown') for emb in embedding_cache.raw_data)
return web.json_response({ return web.json_response({
'success': True, 'success': True,
'data': { 'data': {
'loras': dict(lora_base_models), 'loras': dict(lora_base_models),
'checkpoints': dict(checkpoint_base_models), 'checkpoints': dict(checkpoint_base_models)
'embeddings': dict(embedding_base_models)
} }
}) })
@@ -227,7 +186,6 @@ class StatsRoutes:
# Get model data # Get model data
lora_cache = await self.lora_scanner.get_cached_data() lora_cache = await self.lora_scanner.get_cached_data()
checkpoint_cache = await self.checkpoint_scanner.get_cached_data() checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
embedding_cache = await self.embedding_scanner.get_cached_data()
# Count tag frequencies # Count tag frequencies
all_tags = [] all_tags = []
@@ -235,8 +193,6 @@ class StatsRoutes:
all_tags.extend(lora.get('tags', [])) all_tags.extend(lora.get('tags', []))
for cp in checkpoint_cache.raw_data: for cp in checkpoint_cache.raw_data:
all_tags.extend(cp.get('tags', [])) all_tags.extend(cp.get('tags', []))
for emb in embedding_cache.raw_data:
all_tags.extend(emb.get('tags', []))
tag_counts = Counter(all_tags) tag_counts = Counter(all_tags)
@@ -269,7 +225,6 @@ class StatsRoutes:
# Get model data # Get model data
lora_cache = await self.lora_scanner.get_cached_data() lora_cache = await self.lora_scanner.get_cached_data()
checkpoint_cache = await self.checkpoint_scanner.get_cached_data() checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
embedding_cache = await self.embedding_scanner.get_cached_data()
# Create models with usage data # Create models with usage data
lora_storage = [] lora_storage = []
@@ -300,31 +255,15 @@ class StatsRoutes:
'base_model': cp.get('base_model', 'Unknown') 'base_model': cp.get('base_model', 'Unknown')
}) })
embedding_storage = []
for emb in embedding_cache.raw_data:
usage_count = 0
if emb['sha256'] in usage_data.get('embeddings', {}):
usage_count = usage_data['embeddings'][emb['sha256']].get('total', 0)
embedding_storage.append({
'name': emb['model_name'],
'size': emb.get('size', 0),
'usage_count': usage_count,
'folder': emb.get('folder', ''),
'base_model': emb.get('base_model', 'Unknown')
})
# Sort by size # Sort by size
lora_storage.sort(key=lambda x: x['size'], reverse=True) lora_storage.sort(key=lambda x: x['size'], reverse=True)
checkpoint_storage.sort(key=lambda x: x['size'], reverse=True) checkpoint_storage.sort(key=lambda x: x['size'], reverse=True)
embedding_storage.sort(key=lambda x: x['size'], reverse=True)
return web.json_response({ return web.json_response({
'success': True, 'success': True,
'data': { 'data': {
'loras': lora_storage[:20], # Top 20 by size 'loras': lora_storage[:20], # Top 20 by size
'checkpoints': checkpoint_storage[:20], 'checkpoints': checkpoint_storage[:20]
'embeddings': embedding_storage[:20]
} }
}) })
@@ -346,18 +285,15 @@ class StatsRoutes:
# Get model data # Get model data
lora_cache = await self.lora_scanner.get_cached_data() lora_cache = await self.lora_scanner.get_cached_data()
checkpoint_cache = await self.checkpoint_scanner.get_cached_data() checkpoint_cache = await self.checkpoint_scanner.get_cached_data()
embedding_cache = await self.embedding_scanner.get_cached_data()
insights = [] insights = []
# Calculate unused models # Calculate unused models
unused_loras = self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {})) unused_loras = self._count_unused_models(lora_cache.raw_data, usage_data.get('loras', {}))
unused_checkpoints = self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {})) unused_checkpoints = self._count_unused_models(checkpoint_cache.raw_data, usage_data.get('checkpoints', {}))
unused_embeddings = self._count_unused_models(embedding_cache.raw_data, usage_data.get('embeddings', {}))
total_loras = len(lora_cache.raw_data) total_loras = len(lora_cache.raw_data)
total_checkpoints = len(checkpoint_cache.raw_data) total_checkpoints = len(checkpoint_cache.raw_data)
total_embeddings = len(embedding_cache.raw_data)
if total_loras > 0: if total_loras > 0:
unused_lora_percent = (unused_loras / total_loras) * 100 unused_lora_percent = (unused_loras / total_loras) * 100
@@ -379,20 +315,9 @@ class StatsRoutes:
'suggestion': 'Review and consider removing checkpoints you no longer need.' 'suggestion': 'Review and consider removing checkpoints you no longer need.'
}) })
if total_embeddings > 0:
unused_embedding_percent = (unused_embeddings / total_embeddings) * 100
if unused_embedding_percent > 50:
insights.append({
'type': 'warning',
'title': 'High Number of Unused Embeddings',
'description': f'{unused_embedding_percent:.1f}% of your embeddings ({unused_embeddings}/{total_embeddings}) have never been used.',
'suggestion': 'Consider organizing or archiving unused embeddings to optimize your collection.'
})
# Storage insights # Storage insights
total_size = sum(lora.get('size', 0) for lora in lora_cache.raw_data) + \ total_size = sum(lora.get('size', 0) for lora in lora_cache.raw_data) + \
sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data) + \ sum(cp.get('size', 0) for cp in checkpoint_cache.raw_data)
sum(emb.get('size', 0) for emb in embedding_cache.raw_data)
if total_size > 100 * 1024 * 1024 * 1024: # 100GB if total_size > 100 * 1024 * 1024 * 1024: # 100GB
insights.append({ insights.append({
@@ -465,7 +390,6 @@ class StatsRoutes:
lora_usage = 0 lora_usage = 0
checkpoint_usage = 0 checkpoint_usage = 0
embedding_usage = 0
# Count usage for this date # Count usage for this date
for model_usage in usage_data.get('loras', {}).values(): for model_usage in usage_data.get('loras', {}).values():
@@ -476,16 +400,11 @@ class StatsRoutes:
if isinstance(model_usage, dict) and 'history' in model_usage: if isinstance(model_usage, dict) and 'history' in model_usage:
checkpoint_usage += model_usage['history'].get(date_str, 0) checkpoint_usage += model_usage['history'].get(date_str, 0)
for model_usage in usage_data.get('embeddings', {}).values():
if isinstance(model_usage, dict) and 'history' in model_usage:
embedding_usage += model_usage['history'].get(date_str, 0)
timeline.append({ timeline.append({
'date': date_str, 'date': date_str,
'lora_usage': lora_usage, 'lora_usage': lora_usage,
'checkpoint_usage': checkpoint_usage, 'checkpoint_usage': checkpoint_usage,
'embedding_usage': embedding_usage, 'total_usage': lora_usage + checkpoint_usage
'total_usage': lora_usage + checkpoint_usage + embedding_usage
}) })
return list(reversed(timeline)) # Oldest to newest return list(reversed(timeline)) # Oldest to newest
@@ -507,12 +426,12 @@ class StatsRoutes:
app.router.add_get('/statistics', self.handle_stats_page) app.router.add_get('/statistics', self.handle_stats_page)
# Register API routes # Register API routes
app.router.add_get('/api/lm/stats/collection-overview', self.get_collection_overview) app.router.add_get('/api/stats/collection-overview', self.get_collection_overview)
app.router.add_get('/api/lm/stats/usage-analytics', self.get_usage_analytics) app.router.add_get('/api/stats/usage-analytics', self.get_usage_analytics)
app.router.add_get('/api/lm/stats/base-model-distribution', self.get_base_model_distribution) app.router.add_get('/api/stats/base-model-distribution', self.get_base_model_distribution)
app.router.add_get('/api/lm/stats/tag-analytics', self.get_tag_analytics) app.router.add_get('/api/stats/tag-analytics', self.get_tag_analytics)
app.router.add_get('/api/lm/stats/storage-analytics', self.get_storage_analytics) app.router.add_get('/api/stats/storage-analytics', self.get_storage_analytics)
app.router.add_get('/api/lm/stats/insights', self.get_insights) app.router.add_get('/api/stats/insights', self.get_insights)
async def _on_startup(self, app): async def _on_startup(self, app):
"""Initialize services when the app starts""" """Initialize services when the app starts"""

View File

@@ -1,13 +1,13 @@
import os import os
import subprocess
import aiohttp
import logging import logging
import toml import toml
import git import git
import zipfile from datetime import datetime
import shutil
import tempfile
from aiohttp import web from aiohttp import web
from typing import Dict, List from typing import Dict, List
from ..services.downloader import get_downloader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -17,9 +17,9 @@ class UpdateRoutes:
@staticmethod @staticmethod
def setup_routes(app): def setup_routes(app):
"""Register update check routes""" """Register update check routes"""
app.router.add_get('/api/lm/check-updates', UpdateRoutes.check_updates) app.router.add_get('/api/check-updates', UpdateRoutes.check_updates)
app.router.add_get('/api/lm/version-info', UpdateRoutes.get_version_info) app.router.add_get('/api/version-info', UpdateRoutes.get_version_info)
app.router.add_post('/api/lm/perform-update', UpdateRoutes.perform_update) app.router.add_post('/api/perform-update', UpdateRoutes.perform_update)
@staticmethod @staticmethod
async def check_updates(request): async def check_updates(request):
@@ -101,16 +101,18 @@ class UpdateRoutes:
@staticmethod @staticmethod
async def perform_update(request): async def perform_update(request):
""" """
Perform Git-based update to latest release tag or main branch. Perform Git-based update to latest release tag or main branch
If .git is missing, fallback to ZIP download.
""" """
try: try:
# Parse request body
body = await request.json() if request.has_body else {} body = await request.json() if request.has_body else {}
nightly = body.get('nightly', False) nightly = body.get('nightly', False)
# Get current plugin directory
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
plugin_root = os.path.dirname(os.path.dirname(current_dir)) plugin_root = os.path.dirname(os.path.dirname(current_dir))
# Backup settings.json if it exists
settings_path = os.path.join(plugin_root, 'settings.json') settings_path = os.path.join(plugin_root, 'settings.json')
settings_backup = None settings_backup = None
if os.path.exists(settings_path): if os.path.exists(settings_path):
@@ -118,14 +120,10 @@ class UpdateRoutes:
settings_backup = f.read() settings_backup = f.read()
logger.info("Backed up settings.json") logger.info("Backed up settings.json")
git_folder = os.path.join(plugin_root, '.git') # Perform Git update
if os.path.exists(git_folder): success, new_version = await UpdateRoutes._perform_git_update(plugin_root, nightly)
# Git update
success, new_version = await UpdateRoutes._perform_git_update(plugin_root, nightly)
else:
# Fallback: Download ZIP and replace files
success, new_version = await UpdateRoutes._download_and_replace_zip(plugin_root)
# Restore settings.json if we backed it up
if settings_backup and success: if settings_backup and success:
with open(settings_path, 'w', encoding='utf-8') as f: with open(settings_path, 'w', encoding='utf-8') as f:
f.write(settings_backup) f.write(settings_backup)
@@ -140,7 +138,7 @@ class UpdateRoutes:
else: else:
return web.json_response({ return web.json_response({
'success': False, 'success': False,
'error': 'Failed to complete update' 'error': 'Failed to complete Git update'
}) })
except Exception as e: except Exception as e:
@@ -150,109 +148,6 @@ class UpdateRoutes:
'error': str(e) 'error': str(e)
}) })
@staticmethod
async def _download_and_replace_zip(plugin_root: str) -> tuple[bool, str]:
"""
Download latest release ZIP from GitHub and replace plugin files.
Skips settings.json and civitai folder. Writes extracted file list to .tracking.
"""
repo_owner = "willmiao"
repo_name = "ComfyUI-Lora-Manager"
github_api = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
try:
downloader = await get_downloader()
# Get release info
success, data = await downloader.make_request(
'GET',
github_api,
use_auth=False
)
if not success:
logger.error(f"Failed to fetch release info: {data}")
return False, ""
zip_url = data.get("zipball_url")
version = data.get("tag_name", "unknown")
# Download ZIP to temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
tmp_zip_path = tmp_zip.name
success, result = await downloader.download_file(
url=zip_url,
save_path=tmp_zip_path,
use_auth=False,
allow_resume=False
)
if not success:
logger.error(f"Failed to download ZIP: {result}")
return False, ""
zip_path = tmp_zip_path
# Skip both settings.json and civitai folder
UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json', 'civitai'])
# Extract ZIP to temp dir
with tempfile.TemporaryDirectory() as tmp_dir:
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(tmp_dir)
# Find extracted folder (GitHub ZIP contains a root folder)
extracted_root = next(os.scandir(tmp_dir)).path
# Copy files, skipping settings.json and civitai folder
for item in os.listdir(extracted_root):
if item == 'settings.json' or item == 'civitai':
continue
src = os.path.join(extracted_root, item)
dst = os.path.join(plugin_root, item)
if os.path.isdir(src):
if os.path.exists(dst):
shutil.rmtree(dst)
shutil.copytree(src, dst, ignore=shutil.ignore_patterns('settings.json', 'civitai'))
else:
shutil.copy2(src, dst)
# Write .tracking file: list all files under extracted_root, relative to extracted_root
# for ComfyUI Manager to work properly
tracking_info_file = os.path.join(plugin_root, '.tracking')
tracking_files = []
for root, dirs, files in os.walk(extracted_root):
# Skip civitai folder and its contents
rel_root = os.path.relpath(root, extracted_root)
if rel_root == 'civitai' or rel_root.startswith('civitai' + os.sep):
continue
for file in files:
rel_path = os.path.relpath(os.path.join(root, file), extracted_root)
# Skip settings.json and any file under civitai
if rel_path == 'settings.json' or rel_path.startswith('civitai' + os.sep):
continue
tracking_files.append(rel_path.replace("\\", "/"))
with open(tracking_info_file, "w", encoding='utf-8') as file:
file.write('\n'.join(tracking_files))
os.remove(zip_path)
logger.info(f"Updated plugin via ZIP to {version}")
return True, version
except Exception as e:
logger.error(f"ZIP update failed: {e}", exc_info=True)
return False, ""
def _clean_plugin_folder(plugin_root, skip_files=None):
skip_files = skip_files or []
for item in os.listdir(plugin_root):
if item in skip_files:
continue
path = os.path.join(plugin_root, item)
if os.path.isdir(path):
shutil.rmtree(path)
else:
os.remove(path)
@staticmethod @staticmethod
async def _get_nightly_version() -> tuple[str, List[str]]: async def _get_nightly_version() -> tuple[str, List[str]]:
""" """
@@ -265,23 +160,23 @@ class UpdateRoutes:
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main" github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
try: try:
downloader = await get_downloader() async with aiohttp.ClientSession() as session:
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'}) async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response:
if response.status != 200:
logger.warning(f"Failed to fetch GitHub commit: {response.status}")
return "main", []
if not success: data = await response.json()
logger.warning(f"Failed to fetch GitHub commit: {data}") commit_sha = data.get('sha', '')[:7] # Short hash
return "main", [] commit_message = data.get('commit', {}).get('message', '')
commit_sha = data.get('sha', '')[:7] # Short hash # Format as "main-{short_hash}"
commit_message = data.get('commit', {}).get('message', '') version = f"main-{commit_sha}"
# Format as "main-{short_hash}" # Use commit message as changelog
version = f"main-{commit_sha}" changelog = [commit_message] if commit_message else []
# Use commit message as changelog return version, changelog
changelog = [commit_message] if commit_message else []
return version, changelog
except Exception as e: except Exception as e:
logger.error(f"Error fetching nightly version: {e}", exc_info=True) logger.error(f"Error fetching nightly version: {e}", exc_info=True)
@@ -396,7 +291,7 @@ class UpdateRoutes:
git_info = { git_info = {
'commit_hash': 'unknown', 'commit_hash': 'unknown',
'short_hash': 'stable', 'short_hash': 'unknown',
'branch': 'unknown', 'branch': 'unknown',
'commit_date': 'unknown' 'commit_date': 'unknown'
} }
@@ -406,12 +301,49 @@ class UpdateRoutes:
if not os.path.exists(os.path.join(plugin_root, '.git')): if not os.path.exists(os.path.join(plugin_root, '.git')):
return git_info return git_info
repo = git.Repo(plugin_root) # Get current commit hash
commit = repo.head.commit result = subprocess.run(
git_info['commit_hash'] = commit.hexsha ['git', 'rev-parse', 'HEAD'],
git_info['short_hash'] = commit.hexsha[:7] cwd=plugin_root,
git_info['branch'] = repo.active_branch.name if not repo.head.is_detached else 'detached' stdout=subprocess.PIPE,
git_info['commit_date'] = commit.committed_datetime.strftime('%Y-%m-%d') stderr=subprocess.PIPE,
text=True,
check=False
)
if result.returncode == 0:
git_info['commit_hash'] = result.stdout.strip()
git_info['short_hash'] = git_info['commit_hash'][:7]
# Get current branch name
result = subprocess.run(
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
cwd=plugin_root,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
check=False
)
if result.returncode == 0:
git_info['branch'] = result.stdout.strip()
# Get commit date
result = subprocess.run(
['git', 'show', '-s', '--format=%ci', 'HEAD'],
cwd=plugin_root,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
check=False
)
if result.returncode == 0:
commit_date = result.stdout.strip()
# Format the date nicely if possible
try:
date_obj = datetime.strptime(commit_date, '%Y-%m-%d %H:%M:%S %z')
git_info['commit_date'] = date_obj.strftime('%Y-%m-%d')
except:
git_info['commit_date'] = commit_date
except Exception as e: except Exception as e:
logger.warning(f"Error getting git info: {e}") logger.warning(f"Error getting git info: {e}")
@@ -431,22 +363,22 @@ class UpdateRoutes:
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest" github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
try: try:
downloader = await get_downloader() async with aiohttp.ClientSession() as session:
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'}) async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response:
if response.status != 200:
logger.warning(f"Failed to fetch GitHub release: {response.status}")
return "v0.0.0", []
if not success: data = await response.json()
logger.warning(f"Failed to fetch GitHub release: {data}") version = data.get('tag_name', '')
return "v0.0.0", [] if not version.startswith('v'):
version = f"v{version}"
version = data.get('tag_name', '') # Extract changelog from release notes
if not version.startswith('v'): body = data.get('body', '')
version = f"v{version}" changelog = UpdateRoutes._parse_changelog(body)
# Extract changelog from release notes return version, changelog
body = data.get('body', '')
changelog = UpdateRoutes._parse_changelog(body)
return version, changelog
except Exception as e: except Exception as e:
logger.error(f"Error fetching remote version: {e}", exc_info=True) logger.error(f"Error fetching remote version: {e}", exc_info=True)

View File

@@ -1,92 +1,101 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Type from typing import Dict, List, Optional, Type
import logging import logging
import os
from ..utils.models import BaseModelMetadata from ..utils.models import BaseModelMetadata
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider from ..utils.constants import NSFW_LEVELS
from .settings_manager import settings as default_settings from .settings_manager import settings
from ..utils.utils import fuzzy_match
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BaseModelService(ABC): class BaseModelService(ABC):
"""Base service class for all model types""" """Base service class for all model types"""
def __init__( def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]):
self, """Initialize the service
model_type: str,
scanner,
metadata_class: Type[BaseModelMetadata],
*,
cache_repository: Optional[ModelCacheRepository] = None,
filter_set: Optional[ModelFilterSet] = None,
search_strategy: Optional[SearchStrategy] = None,
settings_provider: Optional[SettingsProvider] = None,
):
"""Initialize the service.
Args: Args:
model_type: Type of model (lora, checkpoint, etc.). model_type: Type of model (lora, checkpoint, etc.)
scanner: Model scanner instance. scanner: Model scanner instance
metadata_class: Metadata class for this model type. metadata_class: Metadata class for this model type
cache_repository: Custom repository for cache access (primarily for tests).
filter_set: Filter component controlling folder/tag/favorites logic.
search_strategy: Search component for fuzzy/text matching.
settings_provider: Settings object; defaults to the global settings manager.
""" """
self.model_type = model_type self.model_type = model_type
self.scanner = scanner self.scanner = scanner
self.metadata_class = metadata_class self.metadata_class = metadata_class
self.settings = settings_provider or default_settings
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
self.filter_set = filter_set or ModelFilterSet(self.settings)
self.search_strategy = search_strategy or SearchStrategy()
async def get_paginated_data( async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
self, folder: str = None, search: str = None, fuzzy_search: bool = False,
page: int, base_models: list = None, tags: list = None,
page_size: int, search_options: dict = None, hash_filters: dict = None,
sort_by: str = 'name', favorites_only: bool = False, **kwargs) -> Dict:
folder: str = None, """Get paginated and filtered model data
search: str = None,
fuzzy_search: bool = False,
base_models: list = None,
tags: list = None,
search_options: dict = None,
hash_filters: dict = None,
favorites_only: bool = False,
**kwargs,
) -> Dict:
"""Get paginated and filtered model data"""
sort_params = self.cache_repository.parse_sort(sort_by)
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
Args:
page: Page number (1-based)
page_size: Number of items per page
sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc'
folder: Folder filter
search: Search term
fuzzy_search: Whether to use fuzzy search
base_models: List of base models to filter by
tags: List of tags to filter by
search_options: Search options dict
hash_filters: Hash filtering options
favorites_only: Filter for favorites only
**kwargs: Additional model-specific filters
Returns:
Dict containing paginated results
"""
cache = await self.scanner.get_cached_data()
# Parse sort_by into sort_key and order
if ':' in sort_by:
sort_key, order = sort_by.split(':', 1)
sort_key = sort_key.strip()
order = order.strip().lower()
if order not in ('asc', 'desc'):
order = 'asc'
else:
sort_key = sort_by.strip()
order = 'asc'
# Get default search options if not provided
if search_options is None:
search_options = {
'filename': True,
'modelname': True,
'tags': False,
'recursive': False,
}
# Get the base data set using new sort logic
filtered_data = await cache.get_sorted_data(sort_key, order)
# Apply hash filtering if provided (highest priority)
if hash_filters: if hash_filters:
filtered_data = await self._apply_hash_filters(sorted_data, hash_filters) filtered_data = await self._apply_hash_filters(filtered_data, hash_filters)
# Jump to pagination for hash filters
return self._paginate(filtered_data, page, page_size) return self._paginate(filtered_data, page, page_size)
# Apply common filters
filtered_data = await self._apply_common_filters( filtered_data = await self._apply_common_filters(
sorted_data, filtered_data, folder, base_models, tags, favorites_only, search_options
folder=folder,
base_models=base_models,
tags=tags,
favorites_only=favorites_only,
search_options=search_options,
) )
# Apply search filtering
if search: if search:
filtered_data = await self._apply_search_filters( filtered_data = await self._apply_search_filters(
filtered_data, filtered_data, search, fuzzy_search, search_options
search,
fuzzy_search,
search_options,
) )
# Apply model-specific filters
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs) filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
return self._paginate(filtered_data, page, page_size) return self._paginate(filtered_data, page, page_size)
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]: async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
"""Apply hash-based filtering""" """Apply hash-based filtering"""
single_hash = hash_filters.get('single_hash') single_hash = hash_filters.get('single_hash')
@@ -109,36 +118,89 @@ class BaseModelService(ABC):
return data return data
async def _apply_common_filters( async def _apply_common_filters(self, data: List[Dict], folder: str = None,
self, base_models: list = None, tags: list = None,
data: List[Dict], favorites_only: bool = False, search_options: dict = None) -> List[Dict]:
folder: str = None,
base_models: list = None,
tags: list = None,
favorites_only: bool = False,
search_options: dict = None,
) -> List[Dict]:
"""Apply common filters that work across all model types""" """Apply common filters that work across all model types"""
normalized_options = self.search_strategy.normalize_options(search_options) # Apply SFW filtering if enabled in settings
criteria = FilterCriteria( if settings.get('show_only_sfw', False):
folder=folder, data = [
base_models=base_models, item for item in data
tags=tags, if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
favorites_only=favorites_only, ]
search_options=normalized_options,
)
return self.filter_set.apply(data, criteria)
async def _apply_search_filters( # Apply favorites filtering if enabled
self, if favorites_only:
data: List[Dict], data = [
search: str, item for item in data
fuzzy_search: bool, if item.get('favorite', False) is True
search_options: dict, ]
) -> List[Dict]:
# Apply folder filtering
if folder is not None:
if search_options and search_options.get('recursive', False):
# Recursive folder filtering - include all subfolders
data = [
item for item in data
if item['folder'].startswith(folder)
]
else:
# Exact folder filtering
data = [
item for item in data
if item['folder'] == folder
]
# Apply base model filtering
if base_models and len(base_models) > 0:
data = [
item for item in data
if item.get('base_model') in base_models
]
# Apply tag filtering
if tags and len(tags) > 0:
data = [
item for item in data
if any(tag in item.get('tags', []) for tag in tags)
]
return data
async def _apply_search_filters(self, data: List[Dict], search: str,
fuzzy_search: bool, search_options: dict) -> List[Dict]:
"""Apply search filtering""" """Apply search filtering"""
normalized_options = self.search_strategy.normalize_options(search_options) search_results = []
return self.search_strategy.apply(data, search, normalized_options, fuzzy_search)
for item in data:
# Search by file name
if search_options.get('filename', True):
if fuzzy_search:
if fuzzy_match(item.get('file_name', ''), search):
search_results.append(item)
continue
elif search.lower() in item.get('file_name', '').lower():
search_results.append(item)
continue
# Search by model name
if search_options.get('modelname', True):
if fuzzy_search:
if fuzzy_match(item.get('model_name', ''), search):
search_results.append(item)
continue
elif search.lower() in item.get('model_name', '').lower():
search_results.append(item)
continue
# Search by tags
if search_options.get('tags', False) and 'tags' in item:
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower())
for tag in item['tags']):
search_results.append(item)
continue
return search_results
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
"""Apply model-specific filters - to be overridden by subclasses if needed""" """Apply model-specific filters - to be overridden by subclasses if needed"""
@@ -195,181 +257,3 @@ class BaseModelService(ABC):
def get_model_roots(self) -> List[str]: def get_model_roots(self) -> List[str]:
"""Get model root directories""" """Get model root directories"""
return self.scanner.get_model_roots() return self.scanner.get_model_roots()
def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict:
"""Filter relevant fields from CivitAI data"""
if not data:
return {}
fields = ["id", "modelId", "name", "trainedWords"] if minimal else [
"id", "modelId", "name", "createdAt", "updatedAt",
"publishedAt", "trainedWords", "baseModel", "description",
"model", "images", "customImages", "creator"
]
return {k: data[k] for k in fields if k in data}
async def get_folder_tree(self, model_root: str) -> Dict:
"""Get hierarchical folder tree for a specific model root"""
cache = await self.scanner.get_cached_data()
# Build tree structure from folders
tree = {}
for folder in cache.folders:
# Check if this folder belongs to the specified model root
folder_belongs_to_root = False
for root in self.scanner.get_model_roots():
if root == model_root:
folder_belongs_to_root = True
break
if not folder_belongs_to_root:
continue
# Split folder path into components
parts = folder.split('/') if folder else []
current_level = tree
for part in parts:
if part not in current_level:
current_level[part] = {}
current_level = current_level[part]
return tree
async def get_unified_folder_tree(self) -> Dict:
"""Get unified folder tree across all model roots"""
cache = await self.scanner.get_cached_data()
# Build unified tree structure by analyzing all relative paths
unified_tree = {}
# Get all model roots for path normalization
model_roots = self.scanner.get_model_roots()
for folder in cache.folders:
if not folder: # Skip empty folders
continue
# Find which root this folder belongs to by checking the actual file paths
# This is a simplified approach - we'll use the folder as-is since it should already be relative
relative_path = folder
# Split folder path into components
parts = relative_path.split('/')
current_level = unified_tree
for part in parts:
if part not in current_level:
current_level[part] = {}
current_level = current_level[part]
return unified_tree
async def get_model_notes(self, model_name: str) -> Optional[str]:
"""Get notes for a specific model file"""
cache = await self.scanner.get_cached_data()
for model in cache.raw_data:
if model['file_name'] == model_name:
return model.get('notes', '')
return None
async def get_model_preview_url(self, model_name: str) -> Optional[str]:
"""Get the static preview URL for a model file"""
cache = await self.scanner.get_cached_data()
for model in cache.raw_data:
if model['file_name'] == model_name:
preview_url = model.get('preview_url')
if preview_url:
from ..config import config
return config.get_preview_static_url(preview_url)
return '/loras_static/images/no-preview.png'
async def get_model_civitai_url(self, model_name: str) -> Dict[str, Optional[str]]:
"""Get the Civitai URL for a model file"""
cache = await self.scanner.get_cached_data()
for model in cache.raw_data:
if model['file_name'] == model_name:
civitai_data = model.get('civitai', {})
model_id = civitai_data.get('modelId')
version_id = civitai_data.get('id')
if model_id:
civitai_url = f"https://civitai.com/models/{model_id}"
if version_id:
civitai_url += f"?modelVersionId={version_id}"
return {
'civitai_url': civitai_url,
'model_id': str(model_id),
'version_id': str(version_id) if version_id else None
}
return {'civitai_url': None, 'model_id': None, 'version_id': None}
async def get_model_metadata(self, file_path: str) -> Optional[Dict]:
"""Get filtered CivitAI metadata for a model by file path"""
cache = await self.scanner.get_cached_data()
for model in cache.raw_data:
if model.get('file_path') == file_path:
return self.filter_civitai_data(model.get("civitai", {}))
return None
async def get_model_description(self, file_path: str) -> Optional[str]:
"""Get model description by file path"""
cache = await self.scanner.get_cached_data()
for model in cache.raw_data:
if model.get('file_path') == file_path:
return model.get('modelDescription', '')
return None
async def search_relative_paths(self, search_term: str, limit: int = 15) -> List[str]:
"""Search model relative file paths for autocomplete functionality"""
cache = await self.scanner.get_cached_data()
matching_paths = []
search_lower = search_term.lower()
# Get model roots for path calculation
model_roots = self.scanner.get_model_roots()
for model in cache.raw_data:
file_path = model.get('file_path', '')
if not file_path:
continue
# Calculate relative path from model root
relative_path = None
for root in model_roots:
# Normalize paths for comparison
normalized_root = os.path.normpath(root)
normalized_file = os.path.normpath(file_path)
if normalized_file.startswith(normalized_root):
# Remove root and leading separator to get relative path
relative_path = normalized_file[len(normalized_root):].lstrip(os.sep)
break
if relative_path and search_lower in relative_path.lower():
matching_paths.append(relative_path)
if len(matching_paths) >= limit * 2: # Get more for better sorting
break
# Sort by relevance (exact matches first, then by length)
matching_paths.sort(key=lambda x: (
not x.lower().startswith(search_lower), # Exact prefix matches first
len(x), # Then by length (shorter first)
x.lower() # Then alphabetically
))
return matching_paths[:limit]

View File

@@ -13,7 +13,7 @@ class CheckpointScanner(ModelScanner):
def __init__(self): def __init__(self):
# Define supported file extensions # Define supported file extensions
file_extensions = {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'} file_extensions = {'.safetensors', '.ckpt', '.pt', '.pth', '.sft', '.gguf'}
super().__init__( super().__init__(
model_type="checkpoint", model_type="checkpoint",
model_class=CheckpointMetadata, model_class=CheckpointMetadata,
@@ -21,14 +21,6 @@ class CheckpointScanner(ModelScanner):
hash_index=ModelHashIndex() hash_index=ModelHashIndex()
) )
def adjust_metadata(self, metadata, file_path, root_path):
if hasattr(metadata, "model_type"):
if root_path in config.checkpoints_roots:
metadata.model_type = "checkpoint"
elif root_path in config.unet_roots:
metadata.model_type = "diffusion_model"
return metadata
def get_model_roots(self) -> List[str]: def get_model_roots(self) -> List[str]:
"""Get checkpoint root directories""" """Get checkpoint root directories"""
return config.base_models_roots return config.base_models_roots

View File

@@ -1,10 +1,11 @@
import os import os
import logging import logging
from typing import Dict from typing import Dict, List, Optional
from .base_model_service import BaseModelService from .base_model_service import BaseModelService
from ..utils.models import CheckpointMetadata from ..utils.models import CheckpointMetadata
from ..config import config from ..config import config
from ..utils.routes_common import ModelRouteUtils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -33,11 +34,12 @@ class CheckpointService(BaseModelService):
"file_size": checkpoint_data.get("size", 0), "file_size": checkpoint_data.get("size", 0),
"modified": checkpoint_data.get("modified", ""), "modified": checkpoint_data.get("modified", ""),
"tags": checkpoint_data.get("tags", []), "tags": checkpoint_data.get("tags", []),
"modelDescription": checkpoint_data.get("modelDescription", ""),
"from_civitai": checkpoint_data.get("from_civitai", True), "from_civitai": checkpoint_data.get("from_civitai", True),
"notes": checkpoint_data.get("notes", ""), "notes": checkpoint_data.get("notes", ""),
"model_type": checkpoint_data.get("model_type", "checkpoint"), "model_type": checkpoint_data.get("model_type", "checkpoint"),
"favorite": checkpoint_data.get("favorite", False), "favorite": checkpoint_data.get("favorite", False),
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True) "civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}))
} }
def find_duplicate_hashes(self) -> Dict: def find_duplicate_hashes(self) -> Dict:

View File

@@ -1,10 +1,11 @@
from datetime import datetime
import aiohttp
import os import os
import copy
import logging import logging
import asyncio import asyncio
from email.parser import Parser
from typing import Optional, Dict, Tuple, List from typing import Optional, Dict, Tuple, List
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager from urllib.parse import unquote
from .downloader import get_downloader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -18,11 +19,6 @@ class CivitaiClient:
async with cls._lock: async with cls._lock:
if cls._instance is None: if cls._instance is None:
cls._instance = cls() cls._instance = cls()
# Register this client as a metadata provider
provider_manager = await ModelMetadataProviderManager.get_instance()
provider_manager.register_provider('civitai', CivitaiModelMetadataProvider(cls._instance), True)
return cls._instance return cls._instance
def __init__(self): def __init__(self):
@@ -32,9 +28,81 @@ class CivitaiClient:
self._initialized = True self._initialized = True
self.base_url = "https://civitai.com/api/v1" self.base_url = "https://civitai.com/api/v1"
self.headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
}
self._session = None
self._session_created_at = None
# Set default buffer size to 1MB for higher throughput
self.chunk_size = 1024 * 1024
async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]: @property
"""Download file with resumable downloads and retry mechanism async def session(self) -> aiohttp.ClientSession:
"""Lazy initialize the session"""
if self._session is None:
# Optimize TCP connection parameters
connector = aiohttp.TCPConnector(
ssl=True,
limit=8, # Increase from 3 to 8 for better parallelism
ttl_dns_cache=300, # Enable DNS caching with reasonable timeout
force_close=False, # Keep connections for reuse
enable_cleanup_closed=True
)
trust_env = True # Allow using system environment proxy settings
# Configure timeout parameters - increase read timeout for large files
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=120)
self._session = aiohttp.ClientSession(
connector=connector,
trust_env=trust_env,
timeout=timeout
)
self._session_created_at = datetime.now()
return self._session
async def _ensure_fresh_session(self):
"""Refresh session if it's been open too long"""
if self._session is not None:
if not hasattr(self, '_session_created_at') or \
(datetime.now() - self._session_created_at).total_seconds() > 300: # 5 minutes
await self.close()
self._session = None
return await self.session
def _parse_content_disposition(self, header: str) -> str:
"""Parse filename from content-disposition header"""
if not header:
return None
# Handle quoted filenames
if 'filename="' in header:
start = header.index('filename="') + 10
end = header.index('"', start)
return unquote(header[start:end])
# Fallback to original parsing
disposition = Parser().parsestr(f'Content-Disposition: {header}')
filename = disposition.get_param('filename')
if filename:
return unquote(filename)
return None
def _get_request_headers(self) -> dict:
"""Get request headers with optional API key"""
headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
'Content-Type': 'application/json'
}
from .settings_manager import settings
api_key = settings.get('civitai_api_key')
if (api_key):
headers['Authorization'] = f'Bearer {api_key}'
return headers
async def _download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
"""Download file with content-disposition support and progress tracking
Args: Args:
url: Download URL url: Download URL
@@ -45,228 +113,160 @@ class CivitaiClient:
Returns: Returns:
Tuple[bool, str]: (success, save_path or error message) Tuple[bool, str]: (success, save_path or error message)
""" """
downloader = await get_downloader() logger.debug(f"Resolving DNS for: {url}")
save_path = os.path.join(save_dir, default_filename) session = await self._ensure_fresh_session()
# Use unified downloader with CivitAI authentication
success, result = await downloader.download_file(
url=url,
save_path=save_path,
progress_callback=progress_callback,
use_auth=True, # Enable CivitAI authentication
allow_resume=True
)
return success, result
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
try: try:
downloader = await get_downloader() headers = self._get_request_headers()
success, result = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
)
if success:
# Get model ID from version data
model_id = result.get('modelId')
if model_id:
# Fetch additional model metadata
success_model, data = await downloader.make_request(
'GET',
f"{self.base_url}/models/{model_id}",
use_auth=True
)
if success_model:
# Enrich version_info with model data
result['model']['description'] = data.get("description")
result['model']['tags'] = data.get("tags", [])
# Add creator from model data # Add Range header to allow resumable downloads
result['creator'] = data.get("creator") headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads
return result, None logger.debug(f"Starting download from: {url}")
async with session.get(url, headers=headers, allow_redirects=True) as response:
if response.status != 200:
# Handle 401 unauthorized responses
if response.status == 401:
logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
# Handle specific error cases return False, "Invalid or missing CivitAI API key, or early access restriction."
if "not found" in str(result):
return None, "Model not found"
# Other error cases # Handle other client errors that might be permission-related
logger.error(f"Failed to fetch model info for {model_hash[:10]}: {result}") if response.status == 403:
return None, str(result) logger.warning(f"Forbidden access to resource: {url} (Status 403)")
return False, "Access forbidden: You don't have permission to download this file."
# Generic error response for other status codes
logger.error(f"Download failed for {url} with status {response.status}")
return False, f"Download failed with status {response.status}"
# Get filename from content-disposition header
content_disposition = response.headers.get('Content-Disposition')
filename = self._parse_content_disposition(content_disposition)
if not filename:
filename = default_filename
save_path = os.path.join(save_dir, filename)
# Get total file size for progress calculation
total_size = int(response.headers.get('content-length', 0))
current_size = 0
last_progress_report_time = datetime.now()
# Stream download to file with progress updates using larger buffer
with open(save_path, 'wb') as f:
async for chunk in response.content.iter_chunked(self.chunk_size):
if chunk:
f.write(chunk)
current_size += len(chunk)
# Limit progress update frequency to reduce overhead
now = datetime.now()
time_diff = (now - last_progress_report_time).total_seconds()
if progress_callback and total_size and time_diff >= 1.0:
progress = (current_size / total_size) * 100
await progress_callback(progress)
last_progress_report_time = now
# Ensure 100% progress is reported
if progress_callback:
await progress_callback(100)
return True, save_path
except aiohttp.ClientError as e:
logger.error(f"Network error during download: {e}")
return False, f"Network error: {str(e)}"
except Exception as e:
logger.error(f"Download error: {e}")
return False, str(e)
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
try:
session = await self._ensure_fresh_session()
async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response:
if response.status == 200:
return await response.json()
return None
except Exception as e: except Exception as e:
logger.error(f"API Error: {str(e)}") logger.error(f"API Error: {str(e)}")
return None, str(e) return None
async def download_preview_image(self, image_url: str, save_path: str): async def download_preview_image(self, image_url: str, save_path: str):
try: try:
downloader = await get_downloader() session = await self._ensure_fresh_session()
success, content, headers = await downloader.download_to_memory( async with session.get(image_url) as response:
image_url, if response.status == 200:
use_auth=False # Preview images don't need auth content = await response.read()
) with open(save_path, 'wb') as f:
if success: f.write(content)
# Ensure directory exists return True
os.makedirs(os.path.dirname(save_path), exist_ok=True) return False
with open(save_path, 'wb') as f:
f.write(content)
return True
return False
except Exception as e: except Exception as e:
logger.error(f"Download Error: {str(e)}") print(f"Download Error: {str(e)}")
return False return False
async def get_model_versions(self, model_id: str) -> List[Dict]: async def get_model_versions(self, model_id: str) -> List[Dict]:
"""Get all versions of a model with local availability info""" """Get all versions of a model with local availability info"""
try: try:
downloader = await get_downloader() session = await self._ensure_fresh_session() # Use fresh session
success, result = await downloader.make_request( async with session.get(f"{self.base_url}/models/{model_id}") as response:
'GET', if response.status != 200:
f"{self.base_url}/models/{model_id}", return None
use_auth=True data = await response.json()
)
if success:
# Also return model type along with versions # Also return model type along with versions
return { return {
'modelVersions': result.get('modelVersions', []), 'modelVersions': data.get('modelVersions', []),
'type': result.get('type', ''), 'type': data.get('type', '')
'name': result.get('name', '')
} }
return None
except Exception as e: except Exception as e:
logger.error(f"Error fetching model versions: {e}") logger.error(f"Error fetching model versions: {e}")
return None return None
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]: async def get_model_version(self, model_id: int, version_id: int = None) -> Optional[Dict]:
"""Get specific model version with additional metadata """Get specific model version with additional metadata
Args: Args:
model_id: The Civitai model ID (optional if version_id is provided) model_id: The Civitai model ID
version_id: Optional specific version ID to retrieve version_id: Optional specific version ID to retrieve
Returns: Returns:
Optional[Dict]: The model version data with additional fields or None if not found Optional[Dict]: The model version data with additional fields or None if not found
""" """
try: try:
downloader = await get_downloader() session = await self._ensure_fresh_session()
# Case 1: Only version_id is provided # Step 1: Get model data to find version_id if not provided and get additional metadata
if model_id is None and version_id is not None: async with session.get(f"{self.base_url}/models/{model_id}") as response:
# First get the version info to extract model_id if response.status != 200:
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/{version_id}",
use_auth=True
)
if not success:
return None
model_id = version.get('modelId')
if not model_id:
logger.error(f"No modelId found in version {version_id}")
return None
# Now get the model data for additional metadata
success, model_data = await downloader.make_request(
'GET',
f"{self.base_url}/models/{model_id}",
use_auth=True
)
if success:
# Enrich version with model data
version['model']['description'] = model_data.get("description")
version['model']['tags'] = model_data.get("tags", [])
version['creator'] = model_data.get("creator")
return version
# Case 2: model_id is provided (with or without version_id)
elif model_id is not None:
# Step 1: Get model data to find version_id if not provided and get additional metadata
success, data = await downloader.make_request(
'GET',
f"{self.base_url}/models/{model_id}",
use_auth=True
)
if not success:
return None return None
data = await response.json()
model_versions = data.get('modelVersions', []) model_versions = data.get('modelVersions', [])
if not model_versions:
logger.warning(f"No model versions found for model {model_id}") # Step 2: Determine the version_id to use
target_version_id = version_id
if target_version_id is None:
target_version_id = model_versions[0].get('id')
# Step 3: Get detailed version info using the version_id
headers = self._get_request_headers()
async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
if response.status != 200:
return None return None
# Step 2: Determine the target version entry to use version = await response.json()
target_version = None
if version_id is not None:
target_version = next(
(item for item in model_versions if item.get('id') == version_id),
None
)
if target_version is None:
logger.warning(
f"Version {version_id} not found for model {model_id}, defaulting to first version"
)
if target_version is None:
target_version = model_versions[0]
target_version_id = target_version.get('id')
# Step 3: Get detailed version info using the SHA256 hash
model_hash = None
for file_info in target_version.get('files', []):
if file_info.get('type') == 'Model' and file_info.get('primary'):
model_hash = file_info.get('hashes', {}).get('SHA256')
if model_hash:
break
version = None
if model_hash:
success, version = await downloader.make_request(
'GET',
f"{self.base_url}/model-versions/by-hash/{model_hash}",
use_auth=True
)
if not success:
logger.warning(
f"Failed to fetch version by hash for model {model_id} version {target_version_id}: {version}"
)
version = None
else:
logger.warning(
f"No primary model hash found for model {model_id} version {target_version_id}"
)
if version is None:
version = copy.deepcopy(target_version)
version.pop('index', None)
version['modelId'] = model_id
version['model'] = {
'name': data.get('name'),
'type': data.get('type'),
'nsfw': data.get('nsfw'),
'poi': data.get('poi')
}
# Step 4: Enrich version_info with model data # Step 4: Enrich version_info with model data
# Add description and tags from model data # Add description and tags from model data
model_info = version.get('model') version['model']['description'] = data.get("description")
if not isinstance(model_info, dict): version['model']['tags'] = data.get("tags", [])
model_info = {}
version['model'] = model_info
model_info['description'] = data.get("description")
model_info['tags'] = data.get("tags", [])
# Add creator from model data # Add creator from model data
version['creator'] = data.get("creator") version['creator'] = data.get("creator")
return version return version
# Case 3: Neither model_id nor version_id provided
else:
logger.error("Either model_id or version_id must be provided")
return None
except Exception as e: except Exception as e:
logger.error(f"Error fetching model version: {e}") logger.error(f"Error fetching model version: {e}")
return None return None
@@ -283,34 +283,116 @@ class CivitaiClient:
- An error message if there was an error, or None on success - An error message if there was an error, or None on success
""" """
try: try:
downloader = await get_downloader() session = await self._ensure_fresh_session()
url = f"{self.base_url}/model-versions/{version_id}" url = f"{self.base_url}/model-versions/{version_id}"
headers = self._get_request_headers()
logger.debug(f"Resolving DNS for model version info: {url}") logger.debug(f"Resolving DNS for model version info: {url}")
success, result = await downloader.make_request( async with session.get(url, headers=headers) as response:
'GET', if response.status == 200:
url, logger.debug(f"Successfully fetched model version info for: {version_id}")
use_auth=True return await response.json(), None
)
if success: # Handle specific error cases
logger.debug(f"Successfully fetched model version info for: {version_id}") if response.status == 404:
return result, None # Try to parse the error message
try:
error_data = await response.json()
error_msg = error_data.get('error', f"Model not found (status 404)")
logger.warning(f"Model version not found: {version_id} - {error_msg}")
return None, error_msg
except:
return None, "Model not found (status 404)"
# Handle specific error cases # Other error cases
if "not found" in str(result): logger.error(f"Failed to fetch model info for {version_id} (status {response.status})")
error_msg = f"Model not found" return None, f"Failed to fetch model info (status {response.status})"
logger.warning(f"Model version not found: {version_id} - {error_msg}")
return None, error_msg
# Other error cases
logger.error(f"Failed to fetch model info for {version_id}: {result}")
return None, str(result)
except Exception as e: except Exception as e:
error_msg = f"Error fetching model version info: {e}" error_msg = f"Error fetching model version info: {e}"
logger.error(error_msg) logger.error(error_msg)
return None, error_msg return None, error_msg
async def get_model_metadata(self, model_id: str) -> Tuple[Optional[Dict], int]:
"""Fetch model metadata (description, tags, and creator info) from Civitai API
Args:
model_id: The Civitai model ID
Returns:
Tuple[Optional[Dict], int]: A tuple containing:
- A dictionary with model metadata or None if not found
- The HTTP status code from the request
"""
try:
session = await self._ensure_fresh_session()
headers = self._get_request_headers()
url = f"{self.base_url}/models/{model_id}"
async with session.get(url, headers=headers) as response:
status_code = response.status
if status_code != 200:
logger.warning(f"Failed to fetch model metadata: Status {status_code}")
return None, status_code
data = await response.json()
# Extract relevant metadata
metadata = {
"description": data.get("description") or "No model description available",
"tags": data.get("tags", []),
"creator": {
"username": data.get("creator", {}).get("username"),
"image": data.get("creator", {}).get("image")
}
}
if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]:
return metadata, status_code
else:
logger.warning(f"No metadata found for model {model_id}")
return None, status_code
except Exception as e:
logger.error(f"Error fetching model metadata: {e}", exc_info=True)
return None, 0
# Keep old method for backward compatibility, delegating to the new one
async def get_model_description(self, model_id: str) -> Optional[str]:
"""Fetch the model description from Civitai API (Legacy method)"""
metadata, _ = await self.get_model_metadata(model_id)
return metadata.get("description") if metadata else None
async def close(self):
"""Close the session if it exists"""
if self._session is not None:
await self._session.close()
self._session = None
async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
"""Get hash from Civitai API"""
try:
session = await self._ensure_fresh_session()
if not session:
return None
version_info = await session.get(f"{self.base_url}/model-versions/{model_version_id}")
if not version_info or not version_info.json().get('files'):
return None
# Get hash from the first file
for file_info in version_info.json().get('files', []):
if file_info.get('hashes', {}).get('SHA256'):
# Convert hash to lowercase to standardize
hash_value = file_info['hashes']['SHA256'].lower()
return hash_value
return None
except Exception as e:
logger.error(f"Error getting hash from Civitai: {e}")
return None
async def get_image_info(self, image_id: str) -> Optional[Dict]: async def get_image_info(self, image_id: str) -> Optional[Dict]:
"""Fetch image information from Civitai API """Fetch image information from Civitai API
@@ -321,25 +403,22 @@ class CivitaiClient:
Optional[Dict]: The image data or None if not found Optional[Dict]: The image data or None if not found
""" """
try: try:
downloader = await get_downloader() session = await self._ensure_fresh_session()
headers = self._get_request_headers()
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X" url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
logger.debug(f"Fetching image info for ID: {image_id}") logger.debug(f"Fetching image info for ID: {image_id}")
success, result = await downloader.make_request( async with session.get(url, headers=headers) as response:
'GET', if response.status == 200:
url, data = await response.json()
use_auth=True if data and "items" in data and len(data["items"]) > 0:
) logger.debug(f"Successfully fetched image info for ID: {image_id}")
return data["items"][0]
logger.warning(f"No image found with ID: {image_id}")
return None
if success: logger.error(f"Failed to fetch image info for ID: {image_id} (status {response.status})")
if result and "items" in result and len(result["items"]) > 0:
logger.debug(f"Successfully fetched image info for ID: {image_id}")
return result["items"][0]
logger.warning(f"No image found with ID: {image_id}")
return None return None
logger.error(f"Failed to fetch image info for ID: {image_id}: {result}")
return None
except Exception as e: except Exception as e:
error_msg = f"Error fetching image info: {e}" error_msg = f"Error fetching image info: {e}"
logger.error(error_msg) logger.error(error_msg)

View File

@@ -1,100 +0,0 @@
"""Service wrapper for coordinating download lifecycle events."""
from __future__ import annotations
import logging
from typing import Any, Awaitable, Callable, Dict, Optional
logger = logging.getLogger(__name__)
class DownloadCoordinator:
"""Manage download scheduling, cancellation and introspection."""
def __init__(
self,
*,
ws_manager,
download_manager_factory: Callable[[], Awaitable],
) -> None:
self._ws_manager = ws_manager
self._download_manager_factory = download_manager_factory
async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""Schedule a download using the provided payload."""
download_manager = await self._download_manager_factory()
download_id = payload.get("download_id") or self._ws_manager.generate_download_id()
payload.setdefault("download_id", download_id)
async def progress_callback(progress: Any) -> None:
await self._ws_manager.broadcast_download_progress(
download_id,
{
"status": "progress",
"progress": progress,
"download_id": download_id,
},
)
model_id = self._parse_optional_int(payload.get("model_id"), "model_id")
model_version_id = self._parse_optional_int(
payload.get("model_version_id"), "model_version_id"
)
if model_id is None and model_version_id is None:
raise ValueError(
"Missing required parameter: Please provide either 'model_id' or 'model_version_id'"
)
result = await download_manager.download_from_civitai(
model_id=model_id,
model_version_id=model_version_id,
save_dir=payload.get("model_root"),
relative_path=payload.get("relative_path", ""),
use_default_paths=payload.get("use_default_paths", False),
progress_callback=progress_callback,
download_id=download_id,
source=payload.get("source"),
)
result["download_id"] = download_id
return result
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
"""Cancel an active download and emit a broadcast event."""
download_manager = await self._download_manager_factory()
result = await download_manager.cancel_download(download_id)
await self._ws_manager.broadcast_download_progress(
download_id,
{
"status": "cancelled",
"progress": 0,
"download_id": download_id,
"message": "Download cancelled by user",
},
)
return result
async def list_active_downloads(self) -> Dict[str, Any]:
"""Return the active download map from the underlying manager."""
download_manager = await self._download_manager_factory()
return await download_manager.get_active_downloads()
def _parse_optional_int(self, value: Any, field: str) -> Optional[int]:
"""Parse an optional integer from user input."""
if value is None or value == "":
return None
try:
return int(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"Invalid {field}: Must be an integer") from exc

View File

@@ -10,8 +10,6 @@ from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry from .service_registry import ServiceRegistry
from .settings_manager import settings from .settings_manager import settings
from .metadata_service import get_default_metadata_provider
from .downloader import get_downloader
# Download to temporary file first # Download to temporary file first
import tempfile import tempfile
@@ -36,11 +34,18 @@ class DownloadManager:
return return
self._initialized = True self._initialized = True
self._civitai_client = None # Will be lazily initialized
# Add download management # Add download management
self._active_downloads = OrderedDict() # download_id -> download_info self._active_downloads = OrderedDict() # download_id -> download_info
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
self._download_tasks = {} # download_id -> asyncio.Task self._download_tasks = {} # download_id -> asyncio.Task
async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry"""
if self._civitai_client is None:
self._civitai_client = await ServiceRegistry.get_civitai_client()
return self._civitai_client
async def _get_lora_scanner(self): async def _get_lora_scanner(self):
"""Get the lora scanner from registry""" """Get the lora scanner from registry"""
return await ServiceRegistry.get_lora_scanner() return await ServiceRegistry.get_lora_scanner()
@@ -49,29 +54,24 @@ class DownloadManager:
"""Get the checkpoint scanner from registry""" """Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner() return await ServiceRegistry.get_checkpoint_scanner()
async def download_from_civitai(self, model_id: int = None, model_version_id: int = None, async def download_from_civitai(self, model_id: int, model_version_id: int,
save_dir: str = None, relative_path: str = '', save_dir: str = None, relative_path: str = '',
progress_callback=None, use_default_paths: bool = False, progress_callback=None, use_default_paths: bool = False,
download_id: str = None, source: str = None) -> Dict: download_id: str = None) -> Dict:
"""Download model from Civitai with task tracking and concurrency control """Download model from Civitai with task tracking and concurrency control
Args: Args:
model_id: Civitai model ID (optional if model_version_id is provided) model_id: Civitai model ID
model_version_id: Civitai model version ID (optional if model_id is provided) model_version_id: Civitai model version ID
save_dir: Directory to save the model save_dir: Directory to save the model
relative_path: Relative path within save_dir relative_path: Relative path within save_dir
progress_callback: Callback function for progress updates progress_callback: Callback function for progress updates
use_default_paths: Flag to use default paths use_default_paths: Flag to use default paths
download_id: Unique identifier for this download task download_id: Unique identifier for this download task
source: Optional source parameter to specify metadata provider
Returns: Returns:
Dict with download result Dict with download result
""" """
# Validate that at least one identifier is provided
if not model_id and not model_version_id:
return {'success': False, 'error': 'Either model_id or model_version_id must be provided'}
# Use provided download_id or generate new one # Use provided download_id or generate new one
task_id = download_id or str(uuid.uuid4()) task_id = download_id or str(uuid.uuid4())
@@ -87,7 +87,7 @@ class DownloadManager:
download_task = asyncio.create_task( download_task = asyncio.create_task(
self._download_with_semaphore( self._download_with_semaphore(
task_id, model_id, model_version_id, save_dir, task_id, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths, source relative_path, progress_callback, use_default_paths
) )
) )
@@ -108,8 +108,7 @@ class DownloadManager:
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int, async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
save_dir: str, relative_path: str, save_dir: str, relative_path: str,
progress_callback=None, use_default_paths: bool = False, progress_callback=None, use_default_paths: bool = False):
source: str = None):
"""Execute download with semaphore to limit concurrency""" """Execute download with semaphore to limit concurrency"""
# Update status to waiting # Update status to waiting
if task_id in self._active_downloads: if task_id in self._active_downloads:
@@ -139,7 +138,7 @@ class DownloadManager:
result = await self._execute_original_download( result = await self._execute_original_download(
model_id, model_version_id, save_dir, model_id, model_version_id, save_dir,
relative_path, tracking_callback, use_default_paths, relative_path, tracking_callback, use_default_paths,
task_id, source task_id
) )
# Update status based on result # Update status based on result
@@ -174,7 +173,7 @@ class DownloadManager:
async def _execute_original_download(self, model_id, model_version_id, save_dir, async def _execute_original_download(self, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths, relative_path, progress_callback, use_default_paths,
download_id=None, source=None): download_id=None):
"""Wrapper for original download_from_civitai implementation""" """Wrapper for original download_from_civitai implementation"""
try: try:
# Check if model version already exists in library # Check if model version already exists in library
@@ -182,29 +181,20 @@ class DownloadManager:
# Check both scanners # Check both scanners
lora_scanner = await self._get_lora_scanner() lora_scanner = await self._get_lora_scanner()
checkpoint_scanner = await self._get_checkpoint_scanner() checkpoint_scanner = await self._get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
# Check lora scanner first # Check lora scanner first
if await lora_scanner.check_model_version_exists(model_version_id): if await lora_scanner.check_model_version_exists(model_id, model_version_id):
return {'success': False, 'error': 'Model version already exists in lora library'} return {'success': False, 'error': 'Model version already exists in lora library'}
# Check checkpoint scanner # Check checkpoint scanner
if await checkpoint_scanner.check_model_version_exists(model_version_id): if await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
return {'success': False, 'error': 'Model version already exists in checkpoint library'} return {'success': False, 'error': 'Model version already exists in checkpoint library'}
# Check embedding scanner # Get civitai client
if await embedding_scanner.check_model_version_exists(model_version_id): civitai_client = await self._get_civitai_client()
return {'success': False, 'error': 'Model version already exists in embedding library'}
# Get metadata provider based on source parameter
if source == 'civarchive':
from .metadata_service import get_metadata_provider
metadata_provider = await get_metadata_provider('civarchive')
else:
metadata_provider = await get_default_metadata_provider()
# Get version info based on the provided identifier # Get version info based on the provided identifier
version_info = await metadata_provider.get_model_version(model_id, model_version_id) version_info = await civitai_client.get_model_version(model_id, model_version_id)
if not version_info: if not version_info:
return {'success': False, 'error': 'Failed to fetch model metadata'} return {'success': False, 'error': 'Failed to fetch model metadata'}
@@ -221,22 +211,23 @@ class DownloadManager:
# Case 2: model_version_id was None, check after getting version_info # Case 2: model_version_id was None, check after getting version_info
if model_version_id is None: if model_version_id is None:
version_model_id = version_info.get('modelId')
version_id = version_info.get('id') version_id = version_info.get('id')
if model_type == 'lora': if model_type == 'lora':
# Check lora scanner # Check lora scanner
lora_scanner = await self._get_lora_scanner() lora_scanner = await self._get_lora_scanner()
if await lora_scanner.check_model_version_exists(version_id): if await lora_scanner.check_model_version_exists(version_model_id, version_id):
return {'success': False, 'error': 'Model version already exists in lora library'} return {'success': False, 'error': 'Model version already exists in lora library'}
elif model_type == 'checkpoint': elif model_type == 'checkpoint':
# Check checkpoint scanner # Check checkpoint scanner
checkpoint_scanner = await self._get_checkpoint_scanner() checkpoint_scanner = await self._get_checkpoint_scanner()
if await checkpoint_scanner.check_model_version_exists(version_id): if await checkpoint_scanner.check_model_version_exists(version_model_id, version_id):
return {'success': False, 'error': 'Model version already exists in checkpoint library'} return {'success': False, 'error': 'Model version already exists in checkpoint library'}
elif model_type == 'embedding': elif model_type == 'embedding':
# Embeddings are not checked in scanners, but we can still check if it exists # Embeddings are not checked in scanners, but we can still check if it exists
embedding_scanner = await ServiceRegistry.get_embedding_scanner() embedding_scanner = await ServiceRegistry.get_embedding_scanner()
if await embedding_scanner.check_model_version_exists(version_id): if await embedding_scanner.check_model_version_exists(version_model_id, version_id):
return {'success': False, 'error': 'Model version already exists in embedding library'} return {'success': False, 'error': 'Model version already exists in embedding library'}
# Handle use_default_paths # Handle use_default_paths
@@ -259,7 +250,7 @@ class DownloadManager:
save_dir = default_path save_dir = default_path
# Calculate relative path using template # Calculate relative path using template
relative_path = self._calculate_relative_path(version_info, model_type) relative_path = self._calculate_relative_path(version_info)
# Update save directory with relative path if provided # Update save directory with relative path if provided
if relative_path: if relative_path:
@@ -275,9 +266,9 @@ class DownloadManager:
from datetime import datetime from datetime import datetime
date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00')) date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00'))
formatted_date = date_obj.strftime('%Y-%m-%d') formatted_date = date_obj.strftime('%Y-%m-%d')
early_access_msg = f"This model requires payment (until {formatted_date}). " early_access_msg = f"This model requires early access payment (until {formatted_date}). "
except: except:
early_access_msg = "This model requires payment. " early_access_msg = "This model requires early access payment. "
early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai." early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai."
logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}") logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}")
@@ -294,8 +285,6 @@ class DownloadManager:
file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None) file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
if not file_info: if not file_info:
return {'success': False, 'error': 'No primary file found in metadata'} return {'success': False, 'error': 'No primary file found in metadata'}
if not file_info.get('downloadUrl'):
return {'success': False, 'error': 'No download URL found for primary file'}
# 3. Prepare download # 3. Prepare download
file_name = file_info['name'] file_name = file_info['name']
@@ -324,10 +313,6 @@ class DownloadManager:
download_id=download_id download_id=download_id
) )
# If early_access_msg exists and download failed, replace error message
if 'early_access_msg' in locals() and not result.get('success', False):
result['error'] = early_access_msg
return result return result
except Exception as e: except Exception as e:
@@ -338,18 +323,17 @@ class DownloadManager:
return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."} return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."}
return {'success': False, 'error': str(e)} return {'success': False, 'error': str(e)}
def _calculate_relative_path(self, version_info: Dict, model_type: str = 'lora') -> str: def _calculate_relative_path(self, version_info: Dict) -> str:
"""Calculate relative path using template from settings """Calculate relative path using template from settings
Args: Args:
version_info: Version info from Civitai API version_info: Version info from Civitai API
model_type: Type of model ('lora', 'checkpoint', 'embedding')
Returns: Returns:
Relative path string Relative path string
""" """
# Get path template from settings for specific model type # Get path template from settings, default to '{base_model}/{first_tag}'
path_template = settings.get_download_path_template(model_type) path_template = settings.get('download_path_template', '{base_model}/{first_tag}')
# If template is empty, return empty path (flat structure) # If template is empty, return empty path (flat structure)
if not path_template: if not path_template:
@@ -358,13 +342,6 @@ class DownloadManager:
# Get base model name # Get base model name
base_model = version_info.get('baseModel', '') base_model = version_info.get('baseModel', '')
# Get author from creator data
creator_info = version_info.get('creator')
if creator_info and isinstance(creator_info, dict):
author = creator_info.get('username') or 'Anonymous'
else:
author = 'Anonymous'
# Apply mapping if available # Apply mapping if available
base_model_mappings = settings.get('base_model_path_mappings', {}) base_model_mappings = settings.get('base_model_path_mappings', {})
mapped_base_model = base_model_mappings.get(base_model, base_model) mapped_base_model = base_model_mappings.get(base_model, base_model)
@@ -387,49 +364,22 @@ class DownloadManager:
formatted_path = path_template formatted_path = path_template
formatted_path = formatted_path.replace('{base_model}', mapped_base_model) formatted_path = formatted_path.replace('{base_model}', mapped_base_model)
formatted_path = formatted_path.replace('{first_tag}', first_tag) formatted_path = formatted_path.replace('{first_tag}', first_tag)
formatted_path = formatted_path.replace('{author}', author)
return formatted_path return formatted_path
async def _execute_download(self, download_url: str, save_dir: str, async def _execute_download(self, download_url: str, save_dir: str,
metadata, version_info: Dict, metadata, version_info: Dict,
relative_path: str, progress_callback=None, relative_path: str, progress_callback=None,
model_type: str = "lora", download_id: str = None) -> Dict: model_type: str = "lora", download_id: str = None) -> Dict:
"""Execute the actual download process including preview images and model files""" """Execute the actual download process including preview images and model files"""
try: try:
# Extract original filename details civitai_client = await self._get_civitai_client()
original_filename = os.path.basename(metadata.file_path) save_path = metadata.file_path
base_name, extension = os.path.splitext(original_filename)
# Check for filename conflicts and generate unique filename if needed
# Use the hash from metadata for conflict resolution
def hash_provider():
return metadata.sha256
unique_filename = metadata.generate_unique_filename(
save_dir,
base_name,
extension,
hash_provider=hash_provider
)
# Update paths if filename changed
if unique_filename != original_filename:
logger.info(f"Filename conflict detected. Changing '{original_filename}' to '{unique_filename}'")
save_path = os.path.join(save_dir, unique_filename)
# Update metadata with new file path and name
metadata.file_path = save_path.replace(os.sep, '/')
metadata.file_name = os.path.splitext(unique_filename)[0]
else:
save_path = metadata.file_path
part_path = save_path + '.part'
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json' metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
# Store file paths in active_downloads for potential cleanup # Store file path in active_downloads for potential cleanup
if download_id and download_id in self._active_downloads: if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]['file_path'] = save_path self._active_downloads[download_id]['file_path'] = save_path
self._active_downloads[download_id]['part_path'] = part_path
# Download preview image if available # Download preview image if available
images = version_info.get('images', []) images = version_info.get('images', [])
@@ -446,14 +396,8 @@ class DownloadManager:
preview_ext = '.mp4' preview_ext = '.mp4'
preview_path = os.path.splitext(save_path)[0] + preview_ext preview_path = os.path.splitext(save_path)[0] + preview_ext
# Download video directly using downloader # Download video directly
downloader = await get_downloader() if await civitai_client.download_preview_image(images[0]['url'], preview_path):
success, result = await downloader.download_file(
images[0]['url'],
preview_path,
use_auth=False # Preview images typically don't need auth
)
if success:
metadata.preview_url = preview_path.replace(os.sep, '/') metadata.preview_url = preview_path.replace(os.sep, '/')
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0) metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
else: else:
@@ -461,16 +405,8 @@ class DownloadManager:
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file: with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
temp_path = temp_file.name temp_path = temp_file.name
# Download the original image to temp path using downloader # Download the original image to temp path
downloader = await get_downloader() if await civitai_client.download_preview_image(images[0]['url'], temp_path):
success, content, headers = await downloader.download_to_memory(
images[0]['url'],
use_auth=False
)
if success:
# Save to temp file
with open(temp_path, 'wb') as f:
f.write(content)
# Optimize and convert to WebP # Optimize and convert to WebP
preview_path = os.path.splitext(save_path)[0] + '.webp' preview_path = os.path.splitext(save_path)[0] + '.webp'
@@ -501,41 +437,26 @@ class DownloadManager:
if progress_callback: if progress_callback:
await progress_callback(3) # 3% progress after preview download await progress_callback(3) # 3% progress after preview download
# Download model file with progress tracking using downloader # Download model file with progress tracking
downloader = await get_downloader() success, result = await civitai_client._download_file(
# Determine if the download URL is from Civitai
use_auth = download_url.startswith("https://civitai.com/api/download/")
success, result = await downloader.download_file(
download_url, download_url,
save_path, # Use full path instead of separate dir and filename save_dir,
progress_callback=lambda p: self._handle_download_progress(p, progress_callback), os.path.basename(save_path),
use_auth=use_auth # Only use authentication for Civitai downloads progress_callback=lambda p: self._handle_download_progress(p, progress_callback)
) )
if not success: if not success:
# Clean up files on failure, but preserve .part file for resume # Clean up files on failure
cleanup_files = [metadata_path] for path in [save_path, metadata_path, metadata.preview_url]:
if metadata.preview_url and os.path.exists(metadata.preview_url):
cleanup_files.append(metadata.preview_url)
for path in cleanup_files:
if path and os.path.exists(path): if path and os.path.exists(path):
try: os.remove(path)
os.remove(path)
except Exception as e:
logger.warning(f"Failed to cleanup file {path}: {e}")
# Log but don't remove .part file to allow resume
if os.path.exists(part_path):
logger.info(f"Preserving partial download for resume: {part_path}")
return {'success': False, 'error': result} return {'success': False, 'error': result}
# 4. Update file information (size and modified time) # 4. Update file information (size and modified time)
metadata.update_file_info(save_path) metadata.update_file_info(save_path)
# 5. Final metadata update # 5. Final metadata update
await MetadataManager.save_metadata(save_path, metadata) await MetadataManager.save_metadata(save_path, metadata, True)
# 6. Update cache based on model type # 6. Update cache based on model type
if model_type == "checkpoint": if model_type == "checkpoint":
@@ -564,18 +485,10 @@ class DownloadManager:
except Exception as e: except Exception as e:
logger.error(f"Error in _execute_download: {e}", exc_info=True) logger.error(f"Error in _execute_download: {e}", exc_info=True)
# Clean up partial downloads except .part file # Clean up partial downloads
cleanup_files = [metadata_path] for path in [save_path, metadata_path]:
if hasattr(metadata, 'preview_url') and metadata.preview_url and os.path.exists(metadata.preview_url):
cleanup_files.append(metadata.preview_url)
for path in cleanup_files:
if path and os.path.exists(path): if path and os.path.exists(path):
try: os.remove(path)
os.remove(path)
except Exception as e:
logger.warning(f"Failed to cleanup file {path}: {e}")
return {'success': False, 'error': str(e)} return {'success': False, 'error': str(e)}
async def _handle_download_progress(self, file_progress: float, progress_callback): async def _handle_download_progress(self, file_progress: float, progress_callback):
@@ -617,48 +530,35 @@ class DownloadManager:
except (asyncio.CancelledError, asyncio.TimeoutError): except (asyncio.CancelledError, asyncio.TimeoutError):
pass pass
# Clean up ALL files including .part when user cancels # Clean up partial downloads
download_info = self._active_downloads.get(download_id) download_info = self._active_downloads.get(download_id)
if download_info: if download_info and 'file_path' in download_info:
# Delete the main file # Delete the partial file
if 'file_path' in download_info: file_path = download_info['file_path']
file_path = download_info['file_path'] if os.path.exists(file_path):
if os.path.exists(file_path): try:
try: os.unlink(file_path)
os.unlink(file_path) logger.debug(f"Deleted partial download: {file_path}")
logger.debug(f"Deleted cancelled download: {file_path}") except Exception as e:
except Exception as e: logger.error(f"Error deleting partial file: {e}")
logger.error(f"Error deleting file: {e}")
# Delete the .part file (only on user cancellation)
if 'part_path' in download_info:
part_path = download_info['part_path']
if os.path.exists(part_path):
try:
os.unlink(part_path)
logger.debug(f"Deleted partial download: {part_path}")
except Exception as e:
logger.error(f"Error deleting part file: {e}")
# Delete metadata file if exists # Delete metadata file if exists
if 'file_path' in download_info: metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
file_path = download_info['file_path'] if os.path.exists(metadata_path):
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' try:
if os.path.exists(metadata_path): os.unlink(metadata_path)
try: except Exception as e:
os.unlink(metadata_path) logger.error(f"Error deleting metadata file: {e}")
except Exception as e:
logger.error(f"Error deleting metadata file: {e}")
# Delete preview file if exists (.webp or .mp4) # Delete preview file if exists (.webp or .mp4)
for preview_ext in ['.webp', '.mp4']: for preview_ext in ['.webp', '.mp4']:
preview_path = os.path.splitext(file_path)[0] + preview_ext preview_path = os.path.splitext(file_path)[0] + preview_ext
if os.path.exists(preview_path): if os.path.exists(preview_path):
try: try:
os.unlink(preview_path) os.unlink(preview_path)
logger.debug(f"Deleted preview file: {preview_path}") logger.debug(f"Deleted preview file: {preview_path}")
except Exception as e: except Exception as e:
logger.error(f"Error deleting preview file: {e}") logger.error(f"Error deleting preview file: {e}")
return {'success': True, 'message': 'Download cancelled successfully'} return {'success': True, 'message': 'Download cancelled successfully'}
except Exception as e: except Exception as e:

View File

@@ -1,539 +0,0 @@
"""
Unified download manager for all HTTP/HTTPS downloads in the application.
This module provides a centralized download service with:
- Singleton pattern for global session management
- Support for authenticated downloads (e.g., CivitAI API key)
- Resumable downloads with automatic retry
- Progress tracking and callbacks
- Optimized connection pooling and timeouts
- Unified error handling and logging
"""
import os
import logging
import asyncio
import aiohttp
from datetime import datetime
from typing import Optional, Dict, Tuple, Callable, Union
from ..services.settings_manager import settings
logger = logging.getLogger(__name__)
class Downloader:
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls):
"""Get singleton instance of Downloader"""
async with cls._lock:
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
"""Initialize the downloader with optimal settings"""
# Check if already initialized for singleton pattern
if hasattr(self, '_initialized'):
return
self._initialized = True
# Session management
self._session = None
self._session_created_at = None
self._proxy_url = None # Store proxy URL for current session
# Configuration
self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better throughput
self.max_retries = 5
self.base_delay = 2.0 # Base delay for exponential backoff
self.session_timeout = 300 # 5 minutes
# Default headers
self.default_headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
}
@property
async def session(self) -> aiohttp.ClientSession:
"""Get or create the global aiohttp session with optimized settings"""
if self._session is None or self._should_refresh_session():
await self._create_session()
return self._session
@property
def proxy_url(self) -> Optional[str]:
"""Get the current proxy URL (initialize if needed)"""
if not hasattr(self, '_proxy_url'):
self._proxy_url = None
return self._proxy_url
def _should_refresh_session(self) -> bool:
"""Check if session should be refreshed"""
if self._session is None:
return True
if not hasattr(self, '_session_created_at') or self._session_created_at is None:
return True
# Refresh if session is older than timeout
if (datetime.now() - self._session_created_at).total_seconds() > self.session_timeout:
return True
return False
async def _create_session(self):
"""Create a new aiohttp session with optimized settings"""
# Close existing session if any
if self._session is not None:
await self._session.close()
# Check for app-level proxy settings
proxy_url = None
if settings.get('proxy_enabled', False):
proxy_host = settings.get('proxy_host', '').strip()
proxy_port = settings.get('proxy_port', '').strip()
proxy_type = settings.get('proxy_type', 'http').lower()
proxy_username = settings.get('proxy_username', '').strip()
proxy_password = settings.get('proxy_password', '').strip()
if proxy_host and proxy_port:
# Build proxy URL
if proxy_username and proxy_password:
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
else:
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
logger.debug(f"Using app-level proxy: {proxy_type}://{proxy_host}:{proxy_port}")
logger.debug("Proxy mode: app-level proxy is active.")
else:
logger.debug("Proxy mode: system-level proxy (trust_env) will be used if configured in environment.")
# Optimize TCP connection parameters
connector = aiohttp.TCPConnector(
ssl=True,
limit=8, # Concurrent connections
ttl_dns_cache=300, # DNS cache timeout
force_close=False, # Keep connections for reuse
enable_cleanup_closed=True
)
# Configure timeout parameters
timeout = aiohttp.ClientTimeout(
total=None, # No total timeout for large downloads
connect=60, # Connection timeout
sock_read=300 # 5 minute socket read timeout
)
self._session = aiohttp.ClientSession(
connector=connector,
trust_env=proxy_url is None, # Only use system proxy if no app-level proxy is set
timeout=timeout
)
# Store proxy URL for use in requests
self._proxy_url = proxy_url
self._session_created_at = datetime.now()
logger.debug("Created new HTTP session with proxy settings. App-level proxy: %s, System-level proxy (trust_env): %s", bool(proxy_url), proxy_url is None)
def _get_auth_headers(self, use_auth: bool = False) -> Dict[str, str]:
"""Get headers with optional authentication"""
headers = self.default_headers.copy()
if use_auth:
# Add CivitAI API key if available
api_key = settings.get('civitai_api_key')
if api_key:
headers['Authorization'] = f'Bearer {api_key}'
headers['Content-Type'] = 'application/json'
return headers
async def download_file(
self,
url: str,
save_path: str,
progress_callback: Optional[Callable[[float], None]] = None,
use_auth: bool = False,
custom_headers: Optional[Dict[str, str]] = None,
allow_resume: bool = True
) -> Tuple[bool, str]:
"""
Download a file with resumable downloads and retry mechanism
Args:
url: Download URL
save_path: Full path where the file should be saved
progress_callback: Optional callback for progress updates (0-100)
use_auth: Whether to include authentication headers (e.g., CivitAI API key)
custom_headers: Additional headers to include in request
allow_resume: Whether to support resumable downloads
Returns:
Tuple[bool, str]: (success, save_path or error message)
"""
retry_count = 0
part_path = save_path + '.part' if allow_resume else save_path
# Prepare headers
headers = self._get_auth_headers(use_auth)
if custom_headers:
headers.update(custom_headers)
# Get existing file size for resume
resume_offset = 0
if allow_resume and os.path.exists(part_path):
resume_offset = os.path.getsize(part_path)
logger.info(f"Resuming download from offset {resume_offset} bytes")
total_size = 0
while retry_count <= self.max_retries:
try:
session = await self.session
# Debug log for proxy mode at request time
if self.proxy_url:
logger.debug(f"[download_file] Using app-level proxy: {self.proxy_url}")
else:
logger.debug("[download_file] Using system-level proxy (trust_env) if configured.")
# Add Range header for resume if we have partial data
request_headers = headers.copy()
if allow_resume and resume_offset > 0:
request_headers['Range'] = f'bytes={resume_offset}-'
# Disable compression for better chunked downloads
request_headers['Accept-Encoding'] = 'identity'
logger.debug(f"Download attempt {retry_count + 1}/{self.max_retries + 1} from: {url}")
if resume_offset > 0:
logger.debug(f"Requesting range from byte {resume_offset}")
async with session.get(url, headers=request_headers, allow_redirects=True, proxy=self.proxy_url) as response:
# Handle different response codes
if response.status == 200:
# Full content response
if resume_offset > 0:
# Server doesn't support ranges, restart from beginning
logger.warning("Server doesn't support range requests, restarting download")
resume_offset = 0
if os.path.exists(part_path):
os.remove(part_path)
elif response.status == 206:
# Partial content response (resume successful)
content_range = response.headers.get('Content-Range')
if content_range:
# Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048")
range_parts = content_range.split('/')
if len(range_parts) == 2:
total_size = int(range_parts[1])
logger.info(f"Successfully resumed download from byte {resume_offset}")
elif response.status == 416:
# Range not satisfiable - file might be complete or corrupted
if allow_resume and os.path.exists(part_path):
part_size = os.path.getsize(part_path)
logger.warning(f"Range not satisfiable. Part file size: {part_size}")
# Try to get actual file size
head_response = await session.head(url, headers=headers, proxy=self.proxy_url)
if head_response.status == 200:
actual_size = int(head_response.headers.get('content-length', 0))
if part_size == actual_size:
# File is complete, just rename it
if allow_resume:
os.rename(part_path, save_path)
if progress_callback:
await progress_callback(100)
return True, save_path
# Remove corrupted part file and restart
os.remove(part_path)
resume_offset = 0
continue
elif response.status == 401:
logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
return False, "Invalid or missing API key, or early access restriction."
elif response.status == 403:
logger.warning(f"Forbidden access to resource: {url} (Status 403)")
return False, "Access forbidden: You don't have permission to download this file."
elif response.status == 404:
logger.warning(f"Resource not found: {url} (Status 404)")
return False, "File not found - the download link may be invalid or expired."
else:
logger.error(f"Download failed for {url} with status {response.status}")
return False, f"Download failed with status {response.status}"
# Get total file size for progress calculation (if not set from Content-Range)
if total_size == 0:
total_size = int(response.headers.get('content-length', 0))
if response.status == 206:
# For partial content, add the offset to get total file size
total_size += resume_offset
current_size = resume_offset
last_progress_report_time = datetime.now()
# Ensure directory exists
os.makedirs(os.path.dirname(save_path), exist_ok=True)
# Stream download to file with progress updates
loop = asyncio.get_running_loop()
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
with open(part_path, mode) as f:
async for chunk in response.content.iter_chunked(self.chunk_size):
if chunk:
# Run blocking file write in executor
await loop.run_in_executor(None, f.write, chunk)
current_size += len(chunk)
# Limit progress update frequency to reduce overhead
now = datetime.now()
time_diff = (now - last_progress_report_time).total_seconds()
if progress_callback and total_size and time_diff >= 1.0:
progress = (current_size / total_size) * 100
await progress_callback(progress)
last_progress_report_time = now
# Download completed successfully
# Verify file size if total_size was provided
final_size = os.path.getsize(part_path)
if total_size > 0 and final_size != total_size:
logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}")
# Don't treat this as fatal error, continue anyway
# Atomically rename .part to final file (only if using resume)
if allow_resume and part_path != save_path:
max_rename_attempts = 5
rename_attempt = 0
rename_success = False
while rename_attempt < max_rename_attempts and not rename_success:
try:
# If the destination file exists, remove it first (Windows safe)
if os.path.exists(save_path):
os.remove(save_path)
os.rename(part_path, save_path)
rename_success = True
except PermissionError as e:
rename_attempt += 1
if rename_attempt < max_rename_attempts:
logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})")
await asyncio.sleep(2)
else:
logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}")
return False, f"Failed to finalize download: {str(e)}"
# Ensure 100% progress is reported
if progress_callback:
await progress_callback(100)
return True, save_path
except (aiohttp.ClientError, aiohttp.ClientPayloadError,
aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
retry_count += 1
logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}")
if retry_count <= self.max_retries:
# Calculate delay with exponential backoff
delay = self.base_delay * (2 ** (retry_count - 1))
logger.info(f"Retrying in {delay} seconds...")
await asyncio.sleep(delay)
# Update resume offset for next attempt
if allow_resume and os.path.exists(part_path):
resume_offset = os.path.getsize(part_path)
logger.info(f"Will resume from byte {resume_offset}")
# Refresh session to get new connection
await self._create_session()
continue
else:
logger.error(f"Max retries exceeded for download: {e}")
return False, f"Network error after {self.max_retries + 1} attempts: {str(e)}"
except Exception as e:
logger.error(f"Unexpected download error: {e}")
return False, str(e)
return False, f"Download failed after {self.max_retries + 1} attempts"
async def download_to_memory(
self,
url: str,
use_auth: bool = False,
custom_headers: Optional[Dict[str, str]] = None,
return_headers: bool = False
) -> Tuple[bool, Union[bytes, str], Optional[Dict]]:
"""
Download a file to memory (for small files like preview images)
Args:
url: Download URL
use_auth: Whether to include authentication headers
custom_headers: Additional headers to include in request
return_headers: Whether to return response headers along with content
Returns:
Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested)
"""
try:
session = await self.session
# Debug log for proxy mode at request time
if self.proxy_url:
logger.debug(f"[download_to_memory] Using app-level proxy: {self.proxy_url}")
else:
logger.debug("[download_to_memory] Using system-level proxy (trust_env) if configured.")
# Prepare headers
headers = self._get_auth_headers(use_auth)
if custom_headers:
headers.update(custom_headers)
async with session.get(url, headers=headers, proxy=self.proxy_url) as response:
if response.status == 200:
content = await response.read()
if return_headers:
return True, content, dict(response.headers)
else:
return True, content, None
elif response.status == 401:
error_msg = "Unauthorized access - invalid or missing API key"
return False, error_msg, None
elif response.status == 403:
error_msg = "Access forbidden"
return False, error_msg, None
elif response.status == 404:
error_msg = "File not found"
return False, error_msg, None
else:
error_msg = f"Download failed with status {response.status}"
return False, error_msg, None
except Exception as e:
logger.error(f"Error downloading to memory from {url}: {e}")
return False, str(e), None
async def get_response_headers(
self,
url: str,
use_auth: bool = False,
custom_headers: Optional[Dict[str, str]] = None
) -> Tuple[bool, Union[Dict, str]]:
"""
Get response headers without downloading the full content
Args:
url: URL to check
use_auth: Whether to include authentication headers
custom_headers: Additional headers to include in request
Returns:
Tuple[bool, Union[Dict, str]]: (success, headers dict or error message)
"""
try:
session = await self.session
# Debug log for proxy mode at request time
if self.proxy_url:
logger.debug(f"[get_response_headers] Using app-level proxy: {self.proxy_url}")
else:
logger.debug("[get_response_headers] Using system-level proxy (trust_env) if configured.")
# Prepare headers
headers = self._get_auth_headers(use_auth)
if custom_headers:
headers.update(custom_headers)
async with session.head(url, headers=headers, proxy=self.proxy_url) as response:
if response.status == 200:
return True, dict(response.headers)
else:
return False, f"Head request failed with status {response.status}"
except Exception as e:
logger.error(f"Error getting headers from {url}: {e}")
return False, str(e)
async def make_request(
self,
method: str,
url: str,
use_auth: bool = False,
custom_headers: Optional[Dict[str, str]] = None,
**kwargs
) -> Tuple[bool, Union[Dict, str]]:
"""
Make a generic HTTP request and return JSON response
Args:
method: HTTP method (GET, POST, etc.)
url: Request URL
use_auth: Whether to include authentication headers
custom_headers: Additional headers to include in request
**kwargs: Additional arguments for aiohttp request
Returns:
Tuple[bool, Union[Dict, str]]: (success, response data or error message)
"""
try:
session = await self.session
# Debug log for proxy mode at request time
if self.proxy_url:
logger.debug(f"[make_request] Using app-level proxy: {self.proxy_url}")
else:
logger.debug("[make_request] Using system-level proxy (trust_env) if configured.")
# Prepare headers
headers = self._get_auth_headers(use_auth)
if custom_headers:
headers.update(custom_headers)
# Add proxy to kwargs if not already present
if 'proxy' not in kwargs:
kwargs['proxy'] = self.proxy_url
async with session.request(method, url, headers=headers, **kwargs) as response:
if response.status == 200:
# Try to parse as JSON, fall back to text
try:
data = await response.json()
return True, data
except:
text = await response.text()
return True, text
elif response.status == 401:
return False, "Unauthorized access - invalid or missing API key"
elif response.status == 403:
return False, "Access forbidden"
elif response.status == 404:
return False, "Resource not found"
else:
return False, f"Request failed with status {response.status}"
except Exception as e:
logger.error(f"Error making {method} request to {url}: {e}")
return False, str(e)
async def close(self):
"""Close the HTTP session"""
if self._session is not None:
await self._session.close()
self._session = None
self._session_created_at = None
self._proxy_url = None
logger.debug("Closed HTTP session")
async def refresh_session(self):
"""Force refresh the HTTP session (useful when proxy settings change)"""
await self._create_session()
logger.info("HTTP session refreshed due to settings change")
# Global instance accessor
async def get_downloader() -> Downloader:
"""Get the global downloader instance"""
return await Downloader.get_instance()

View File

@@ -1,10 +1,11 @@
import os import os
import logging import logging
from typing import Dict from typing import Dict, List, Optional
from .base_model_service import BaseModelService from .base_model_service import BaseModelService
from ..utils.models import EmbeddingMetadata from ..utils.models import EmbeddingMetadata
from ..config import config from ..config import config
from ..utils.routes_common import ModelRouteUtils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -33,11 +34,12 @@ class EmbeddingService(BaseModelService):
"file_size": embedding_data.get("size", 0), "file_size": embedding_data.get("size", 0),
"modified": embedding_data.get("modified", ""), "modified": embedding_data.get("modified", ""),
"tags": embedding_data.get("tags", []), "tags": embedding_data.get("tags", []),
"modelDescription": embedding_data.get("modelDescription", ""),
"from_civitai": embedding_data.get("from_civitai", True), "from_civitai": embedding_data.get("from_civitai", True),
"notes": embedding_data.get("notes", ""), "notes": embedding_data.get("notes", ""),
"model_type": embedding_data.get("model_type", "embedding"), "model_type": embedding_data.get("model_type", "embedding"),
"favorite": embedding_data.get("favorite", False), "favorite": embedding_data.get("favorite", False),
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True) "civitai": ModelRouteUtils.filter_civitai_data(embedding_data.get("civitai", {}))
} }
def find_duplicate_hashes(self) -> Dict: def find_duplicate_hashes(self) -> Dict:

View File

@@ -1,246 +0,0 @@
"""Service for cleaning up example image folders."""
from __future__ import annotations
import asyncio
import logging
import os
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List
from .service_registry import ServiceRegistry
from .settings_manager import settings
logger = logging.getLogger(__name__)
@dataclass(slots=True)
class CleanupResult:
"""Structured result returned from cleanup operations."""
success: bool
checked_folders: int
moved_empty_folders: int
moved_orphaned_folders: int
skipped_non_hash: int
move_failures: int
errors: List[str]
deleted_root: str | None
partial_success: bool
def to_dict(self) -> Dict[str, object]:
"""Convert the dataclass to a serialisable dictionary."""
data = {
"success": self.success,
"checked_folders": self.checked_folders,
"moved_empty_folders": self.moved_empty_folders,
"moved_orphaned_folders": self.moved_orphaned_folders,
"moved_total": self.moved_empty_folders + self.moved_orphaned_folders,
"skipped_non_hash": self.skipped_non_hash,
"move_failures": self.move_failures,
"errors": self.errors,
"deleted_root": self.deleted_root,
"partial_success": self.partial_success,
}
return data
class ExampleImagesCleanupService:
"""Encapsulates logic for cleaning example image folders."""
DELETED_FOLDER_NAME = "_deleted"
def __init__(self, deleted_folder_name: str | None = None) -> None:
self._deleted_folder_name = deleted_folder_name or self.DELETED_FOLDER_NAME
async def cleanup_example_image_folders(self) -> Dict[str, object]:
"""Clean empty or orphaned example image folders by moving them under a deleted bucket."""
example_images_path = settings.get("example_images_path")
if not example_images_path:
logger.debug("Cleanup skipped: example images path not configured")
return {
"success": False,
"error": "Example images path is not configured.",
"error_code": "path_not_configured",
}
example_root = Path(example_images_path)
if not example_root.exists():
logger.debug("Cleanup skipped: example images path missing -> %s", example_root)
return {
"success": False,
"error": "Example images path does not exist.",
"error_code": "path_not_found",
}
try:
lora_scanner = await ServiceRegistry.get_lora_scanner()
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
except Exception as exc: # pragma: no cover - defensive guard
logger.error("Failed to acquire scanners for cleanup: %s", exc, exc_info=True)
return {
"success": False,
"error": f"Failed to load model scanners: {exc}",
"error_code": "scanner_initialization_failed",
}
deleted_bucket = example_root / self._deleted_folder_name
deleted_bucket.mkdir(exist_ok=True)
checked_folders = 0
moved_empty = 0
moved_orphaned = 0
skipped_non_hash = 0
move_failures = 0
errors: List[str] = []
for entry in os.scandir(example_root):
if not entry.is_dir(follow_symlinks=False):
continue
if entry.name == self._deleted_folder_name:
continue
checked_folders += 1
folder_path = Path(entry.path)
try:
if self._is_folder_empty(folder_path):
if await self._remove_empty_folder(folder_path):
moved_empty += 1
else:
move_failures += 1
continue
if not self._is_hash_folder(entry.name):
skipped_non_hash += 1
continue
hash_exists = (
lora_scanner.has_hash(entry.name)
or checkpoint_scanner.has_hash(entry.name)
or embedding_scanner.has_hash(entry.name)
)
if not hash_exists:
if await self._move_folder(folder_path, deleted_bucket):
moved_orphaned += 1
else:
move_failures += 1
except Exception as exc: # pragma: no cover - filesystem guard
move_failures += 1
error_message = f"{entry.name}: {exc}"
errors.append(error_message)
logger.error("Error processing example images folder %s: %s", folder_path, exc, exc_info=True)
partial_success = move_failures > 0 and (moved_empty > 0 or moved_orphaned > 0)
success = move_failures == 0 and not errors
result = CleanupResult(
success=success,
checked_folders=checked_folders,
moved_empty_folders=moved_empty,
moved_orphaned_folders=moved_orphaned,
skipped_non_hash=skipped_non_hash,
move_failures=move_failures,
errors=errors,
deleted_root=str(deleted_bucket),
partial_success=partial_success,
)
summary = result.to_dict()
if success:
logger.info(
"Example images cleanup complete: checked=%s, moved_empty=%s, moved_orphaned=%s",
checked_folders,
moved_empty,
moved_orphaned,
)
elif partial_success:
logger.warning(
"Example images cleanup partially complete: moved=%s, failures=%s",
summary["moved_total"],
move_failures,
)
else:
logger.error(
"Example images cleanup failed: move_failures=%s, errors=%s",
move_failures,
errors,
)
return summary
@staticmethod
def _is_folder_empty(folder_path: Path) -> bool:
try:
with os.scandir(folder_path) as iterator:
return not any(iterator)
except FileNotFoundError:
return True
except OSError as exc: # pragma: no cover - defensive guard
logger.debug("Failed to inspect folder %s: %s", folder_path, exc)
return False
@staticmethod
def _is_hash_folder(name: str) -> bool:
if len(name) != 64:
return False
hex_chars = set("0123456789abcdefABCDEF")
return all(char in hex_chars for char in name)
async def _remove_empty_folder(self, folder_path: Path) -> bool:
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(
None,
shutil.rmtree,
str(folder_path),
)
logger.debug("Removed empty example images folder %s", folder_path)
return True
except Exception as exc: # pragma: no cover - filesystem guard
logger.error("Failed to remove empty example images folder %s: %s", folder_path, exc, exc_info=True)
return False
async def _move_folder(self, folder_path: Path, deleted_bucket: Path) -> bool:
destination = self._build_destination(folder_path.name, deleted_bucket)
loop = asyncio.get_running_loop()
try:
await loop.run_in_executor(
None,
shutil.move,
str(folder_path),
str(destination),
)
logger.debug("Moved example images folder %s -> %s", folder_path, destination)
return True
except Exception as exc: # pragma: no cover - filesystem guard
logger.error(
"Failed to move example images folder %s to %s: %s",
folder_path,
destination,
exc,
exc_info=True,
)
return False
def _build_destination(self, folder_name: str, deleted_bucket: Path) -> Path:
destination = deleted_bucket / folder_name
suffix = 1
while destination.exists():
destination = deleted_bucket / f"{folder_name}_{suffix}"
suffix += 1
return destination

View File

@@ -5,6 +5,7 @@ from typing import Dict, List, Optional
from .base_model_service import BaseModelService from .base_model_service import BaseModelService
from ..utils.models import LoraMetadata from ..utils.models import LoraMetadata
from ..config import config from ..config import config
from ..utils.routes_common import ModelRouteUtils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -33,11 +34,12 @@ class LoraService(BaseModelService):
"file_size": lora_data.get("size", 0), "file_size": lora_data.get("size", 0),
"modified": lora_data.get("modified", ""), "modified": lora_data.get("modified", ""),
"tags": lora_data.get("tags", []), "tags": lora_data.get("tags", []),
"modelDescription": lora_data.get("modelDescription", ""),
"from_civitai": lora_data.get("from_civitai", True), "from_civitai": lora_data.get("from_civitai", True),
"usage_tips": lora_data.get("usage_tips", ""), "usage_tips": lora_data.get("usage_tips", ""),
"notes": lora_data.get("notes", ""), "notes": lora_data.get("notes", ""),
"favorite": lora_data.get("favorite", False), "favorite": lora_data.get("favorite", False),
"civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True) "civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}))
} }
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
@@ -145,6 +147,16 @@ class LoraService(BaseModelService):
return letters return letters
async def get_lora_notes(self, lora_name: str) -> Optional[str]:
"""Get notes for a specific LoRA file"""
cache = await self.scanner.get_cached_data()
for lora in cache.raw_data:
if lora['file_name'] == lora_name:
return lora.get('notes', '')
return None
async def get_lora_trigger_words(self, lora_name: str) -> List[str]: async def get_lora_trigger_words(self, lora_name: str) -> List[str]:
"""Get trigger words for a specific LoRA file""" """Get trigger words for a specific LoRA file"""
cache = await self.scanner.get_cached_data() cache = await self.scanner.get_cached_data()
@@ -156,22 +168,41 @@ class LoraService(BaseModelService):
return [] return []
async def get_lora_usage_tips_by_relative_path(self, relative_path: str) -> Optional[str]: async def get_lora_preview_url(self, lora_name: str) -> Optional[str]:
"""Get usage tips for a LoRA by its relative path""" """Get the static preview URL for a LoRA file"""
cache = await self.scanner.get_cached_data() cache = await self.scanner.get_cached_data()
for lora in cache.raw_data: for lora in cache.raw_data:
file_path = lora.get('file_path', '') if lora['file_name'] == lora_name:
if file_path: preview_url = lora.get('preview_url')
# Convert to forward slashes and extract relative path if preview_url:
file_path_normalized = file_path.replace('\\', '/') return config.get_preview_static_url(preview_url)
relative_path = relative_path.replace('\\', '/')
# Find the relative path part by looking for the relative_path in the full path
if file_path_normalized.endswith(relative_path) or relative_path in file_path_normalized:
return lora.get('usage_tips', '')
return None return None
async def get_lora_civitai_url(self, lora_name: str) -> Dict[str, Optional[str]]:
"""Get the Civitai URL for a LoRA file"""
cache = await self.scanner.get_cached_data()
for lora in cache.raw_data:
if lora['file_name'] == lora_name:
civitai_data = lora.get('civitai', {})
model_id = civitai_data.get('modelId')
version_id = civitai_data.get('id')
if model_id:
civitai_url = f"https://civitai.com/models/{model_id}"
if version_id:
civitai_url += f"?modelVersionId={version_id}"
return {
'civitai_url': civitai_url,
'model_id': str(model_id),
'version_id': str(version_id) if version_id else None
}
return {'civitai_url': None, 'model_id': None, 'version_id': None}
def find_duplicate_hashes(self) -> Dict: def find_duplicate_hashes(self) -> Dict:
"""Find LoRAs with duplicate SHA256 hashes""" """Find LoRAs with duplicate SHA256 hashes"""
return self.scanner._hash_index.get_duplicate_hashes() return self.scanner._hash_index.get_duplicate_hashes()

View File

@@ -1,151 +0,0 @@
import zipfile
import logging
import asyncio
from pathlib import Path
from typing import Optional
from .downloader import get_downloader
logger = logging.getLogger(__name__)
class MetadataArchiveManager:
"""Manages downloading and extracting Civitai metadata archive database"""
DOWNLOAD_URLS = [
"https://github.com/willmiao/civitai-metadata-archive-db/releases/download/db-2025-08-08/civitai.zip",
"https://huggingface.co/datasets/willmiao/civitai-metadata-archive-db/blob/main/civitai.zip"
]
def __init__(self, base_path: str):
"""Initialize with base path where files will be stored"""
self.base_path = Path(base_path)
self.civitai_folder = self.base_path / "civitai"
self.archive_path = self.base_path / "civitai.zip"
self.db_path = self.civitai_folder / "civitai.sqlite"
def is_database_available(self) -> bool:
"""Check if the SQLite database is available and valid"""
return self.db_path.exists() and self.db_path.stat().st_size > 0
def get_database_path(self) -> Optional[str]:
"""Get the path to the SQLite database if available"""
if self.is_database_available():
return str(self.db_path)
return None
async def download_and_extract_database(self, progress_callback=None) -> bool:
"""Download and extract the metadata archive database
Args:
progress_callback: Optional callback function to report progress
Returns:
bool: True if successful, False otherwise
"""
try:
# Create directories if they don't exist
self.base_path.mkdir(parents=True, exist_ok=True)
self.civitai_folder.mkdir(parents=True, exist_ok=True)
# Download the archive
if not await self._download_archive(progress_callback):
return False
# Extract the archive
if not await self._extract_archive(progress_callback):
return False
# Clean up the archive file
if self.archive_path.exists():
self.archive_path.unlink()
logger.info(f"Successfully downloaded and extracted metadata database to {self.db_path}")
return True
except Exception as e:
logger.error(f"Error downloading and extracting metadata database: {e}", exc_info=True)
return False
async def _download_archive(self, progress_callback=None) -> bool:
"""Download the zip archive from one of the available URLs"""
downloader = await get_downloader()
for url in self.DOWNLOAD_URLS:
try:
logger.info(f"Attempting to download from {url}")
if progress_callback:
progress_callback("download", f"Downloading from {url}")
# Custom progress callback to report download progress
async def download_progress(progress):
if progress_callback:
progress_callback("download", f"Downloading archive... {progress:.1f}%")
success, result = await downloader.download_file(
url=url,
save_path=str(self.archive_path),
progress_callback=download_progress,
use_auth=False, # Public download, no auth needed
allow_resume=True
)
if success:
logger.info(f"Successfully downloaded archive from {url}")
return True
else:
logger.warning(f"Failed to download from {url}: {result}")
continue
except Exception as e:
logger.warning(f"Error downloading from {url}: {e}")
continue
logger.error("Failed to download archive from any URL")
return False
async def _extract_archive(self, progress_callback=None) -> bool:
"""Extract the zip archive to the civitai folder"""
try:
if progress_callback:
progress_callback("extract", "Extracting archive...")
# Run extraction in thread pool to avoid blocking
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._extract_zip_sync)
if progress_callback:
progress_callback("extract", "Extraction completed")
return True
except Exception as e:
logger.error(f"Error extracting archive: {e}", exc_info=True)
return False
def _extract_zip_sync(self):
"""Synchronous zip extraction (runs in thread pool)"""
with zipfile.ZipFile(self.archive_path, 'r') as archive:
archive.extractall(path=self.base_path)
async def remove_database(self) -> bool:
"""Remove the metadata database and folder"""
try:
if self.civitai_folder.exists():
# Remove all files in the civitai folder
for file_path in self.civitai_folder.iterdir():
if file_path.is_file():
file_path.unlink()
# Remove the folder itself
self.civitai_folder.rmdir()
# Also remove the archive file if it exists
if self.archive_path.exists():
self.archive_path.unlink()
logger.info("Successfully removed metadata database")
return True
except Exception as e:
logger.error(f"Error removing metadata database: {e}", exc_info=True)
return False

View File

@@ -1,117 +0,0 @@
import os
import logging
from .model_metadata_provider import (
ModelMetadataProviderManager,
SQLiteModelMetadataProvider,
CivitaiModelMetadataProvider,
FallbackMetadataProvider
)
from .settings_manager import settings
from .metadata_archive_manager import MetadataArchiveManager
from .service_registry import ServiceRegistry
logger = logging.getLogger(__name__)
async def initialize_metadata_providers():
"""Initialize and configure all metadata providers based on settings"""
provider_manager = await ModelMetadataProviderManager.get_instance()
# Clear existing providers to allow reinitialization
provider_manager.providers.clear()
provider_manager.default_provider = None
# Get settings
enable_archive_db = settings.get('enable_metadata_archive_db', False)
providers = []
# Initialize archive database provider if enabled
if enable_archive_db:
try:
# Initialize archive manager
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
archive_manager = MetadataArchiveManager(base_path)
db_path = archive_manager.get_database_path()
if db_path and os.path.exists(db_path):
sqlite_provider = SQLiteModelMetadataProvider(db_path)
provider_manager.register_provider('sqlite', sqlite_provider)
providers.append(('sqlite', sqlite_provider))
logger.info(f"SQLite metadata provider registered with database: {db_path}")
else:
logger.warning("Metadata archive database is enabled but database file not found")
except Exception as e:
logger.error(f"Failed to initialize SQLite metadata provider: {e}")
# Initialize Civitai API provider (always available as fallback)
try:
civitai_client = await ServiceRegistry.get_civitai_client()
civitai_provider = CivitaiModelMetadataProvider(civitai_client)
provider_manager.register_provider('civitai_api', civitai_provider)
providers.append(('civitai_api', civitai_provider))
logger.debug("Civitai API metadata provider registered")
except Exception as e:
logger.error(f"Failed to initialize Civitai API metadata provider: {e}")
# Register CivArchive provider, but do NOT add to fallback providers
try:
from .model_metadata_provider import CivArchiveModelMetadataProvider
civarchive_provider = CivArchiveModelMetadataProvider()
provider_manager.register_provider('civarchive', civarchive_provider)
logger.debug("CivArchive metadata provider registered (not included in fallback)")
except Exception as e:
logger.error(f"Failed to initialize CivArchive metadata provider: {e}")
# Set up fallback provider based on available providers
if len(providers) > 1:
# Always use Civitai API first, then Archive DB
ordered_providers = []
ordered_providers.extend([p[1] for p in providers if p[0] == 'civitai_api'])
ordered_providers.extend([p[1] for p in providers if p[0] == 'sqlite'])
if ordered_providers:
fallback_provider = FallbackMetadataProvider(ordered_providers)
provider_manager.register_provider('fallback', fallback_provider, is_default=True)
logger.info(f"Fallback metadata provider registered with {len(ordered_providers)} providers, Civitai API first")
elif len(providers) == 1:
# Only one provider available, set it as default
provider_name, provider = providers[0]
provider_manager.register_provider(provider_name, provider, is_default=True)
logger.debug(f"Single metadata provider registered as default: {provider_name}")
else:
logger.warning("No metadata providers available - this may cause metadata lookup failures")
return provider_manager
async def update_metadata_providers():
"""Update metadata providers based on current settings"""
try:
# Get current settings
enable_archive_db = settings.get('enable_metadata_archive_db', False)
# Reinitialize all providers with new settings
provider_manager = await initialize_metadata_providers()
logger.info(f"Updated metadata providers, archive_db enabled: {enable_archive_db}")
return provider_manager
except Exception as e:
logger.error(f"Failed to update metadata providers: {e}")
return await ModelMetadataProviderManager.get_instance()
async def get_metadata_archive_manager():
"""Get metadata archive manager instance"""
base_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
return MetadataArchiveManager(base_path)
async def get_metadata_provider(provider_name: str = None):
"""Get a specific metadata provider or default provider"""
provider_manager = await ModelMetadataProviderManager.get_instance()
if provider_name:
return provider_manager._get_provider(provider_name)
return provider_manager._get_provider()
async def get_default_metadata_provider():
"""Get the default metadata provider (fallback or single provider)"""
return await get_metadata_provider()

View File

@@ -1,355 +0,0 @@
"""Services for synchronising metadata with remote providers."""
from __future__ import annotations
import json
import logging
import os
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
from ..services.settings_manager import SettingsManager
from ..utils.model_utils import determine_base_model
logger = logging.getLogger(__name__)
class MetadataProviderProtocol:
"""Subset of metadata provider interface consumed by the sync service."""
async def get_model_by_hash(self, sha256: str) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
...
async def get_model_version(
self, model_id: int, model_version_id: Optional[int]
) -> Optional[Dict[str, Any]]:
...
class MetadataSyncService:
"""High level orchestration for metadata synchronisation flows."""
def __init__(
self,
*,
metadata_manager,
preview_service,
settings: SettingsManager,
default_metadata_provider_factory: Callable[[], Awaitable[MetadataProviderProtocol]],
metadata_provider_selector: Callable[[str], Awaitable[MetadataProviderProtocol]],
) -> None:
self._metadata_manager = metadata_manager
self._preview_service = preview_service
self._settings = settings
self._get_default_provider = default_metadata_provider_factory
self._get_provider = metadata_provider_selector
async def load_local_metadata(self, metadata_path: str) -> Dict[str, Any]:
"""Load metadata JSON from disk, returning an empty structure when missing."""
if not os.path.exists(metadata_path):
return {}
try:
with open(metadata_path, "r", encoding="utf-8") as handle:
return json.load(handle)
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Error loading metadata from %s: %s", metadata_path, exc)
return {}
async def mark_not_found_on_civitai(
self, metadata_path: str, local_metadata: Dict[str, Any]
) -> None:
"""Persist the not-found flag for a metadata payload."""
local_metadata["from_civitai"] = False
await self._metadata_manager.save_metadata(metadata_path, local_metadata)
@staticmethod
def is_civitai_api_metadata(meta: Dict[str, Any]) -> bool:
"""Determine if the metadata originated from the CivitAI public API."""
if not isinstance(meta, dict):
return False
files = meta.get("files")
images = meta.get("images")
source = meta.get("source")
return bool(files) and bool(images) and source != "archive_db"
async def update_model_metadata(
self,
metadata_path: str,
local_metadata: Dict[str, Any],
civitai_metadata: Dict[str, Any],
metadata_provider: Optional[MetadataProviderProtocol] = None,
) -> Dict[str, Any]:
"""Merge remote metadata into the local record and persist the result."""
existing_civitai = local_metadata.get("civitai") or {}
if (
civitai_metadata.get("source") == "archive_db"
and self.is_civitai_api_metadata(existing_civitai)
):
logger.info(
"Skip civitai update for %s (%s)",
local_metadata.get("model_name", ""),
existing_civitai.get("name", ""),
)
else:
merged_civitai = existing_civitai.copy()
merged_civitai.update(civitai_metadata)
if civitai_metadata.get("source") == "archive_db":
model_name = civitai_metadata.get("model", {}).get("name", "")
version_name = civitai_metadata.get("name", "")
logger.info(
"Recovered metadata from archive_db for deleted model: %s (%s)",
model_name,
version_name,
)
if "trainedWords" in existing_civitai:
existing_trained = existing_civitai.get("trainedWords", [])
new_trained = civitai_metadata.get("trainedWords", [])
merged_trained = list(set(existing_trained + new_trained))
merged_civitai["trainedWords"] = merged_trained
local_metadata["civitai"] = merged_civitai
if "model" in civitai_metadata and civitai_metadata["model"]:
model_data = civitai_metadata["model"]
if model_data.get("name"):
local_metadata["model_name"] = model_data["name"]
if not local_metadata.get("modelDescription") and model_data.get("description"):
local_metadata["modelDescription"] = model_data["description"]
if not local_metadata.get("tags") and model_data.get("tags"):
local_metadata["tags"] = model_data["tags"]
if model_data.get("creator") and not local_metadata.get("civitai", {}).get(
"creator"
):
local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"]
local_metadata["base_model"] = determine_base_model(
civitai_metadata.get("baseModel")
)
await self._preview_service.ensure_preview_for_metadata(
metadata_path, local_metadata, civitai_metadata.get("images", [])
)
await self._metadata_manager.save_metadata(metadata_path, local_metadata)
return local_metadata
async def fetch_and_update_model(
self,
*,
sha256: str,
file_path: str,
model_data: Dict[str, Any],
update_cache_func: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
) -> tuple[bool, Optional[str]]:
"""Fetch metadata for a model and update both disk and cache state."""
if not isinstance(model_data, dict):
error = f"Invalid model_data type: {type(model_data)}"
logger.error(error)
return False, error
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
enable_archive = self._settings.get("enable_metadata_archive_db", False)
try:
if model_data.get("civitai_deleted") is True:
if not enable_archive or model_data.get("db_checked") is True:
return (
False,
"CivitAI model is deleted and metadata archive DB is not enabled",
)
metadata_provider = await self._get_provider("sqlite")
else:
metadata_provider = await self._get_default_provider()
civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256)
if not civitai_metadata:
if error == "Model not found":
model_data["from_civitai"] = False
model_data["civitai_deleted"] = True
model_data["db_checked"] = enable_archive
model_data["last_checked_at"] = datetime.now().timestamp()
data_to_save = model_data.copy()
data_to_save.pop("folder", None)
await self._metadata_manager.save_metadata(file_path, data_to_save)
error_msg = (
f"Error fetching metadata: {error} (model_name={model_data.get('model_name', '')})"
)
logger.error(error_msg)
return False, error_msg
model_data["from_civitai"] = True
model_data["civitai_deleted"] = civitai_metadata.get("source") == "archive_db"
model_data["db_checked"] = enable_archive
model_data["last_checked_at"] = datetime.now().timestamp()
local_metadata = model_data.copy()
local_metadata.pop("folder", None)
await self.update_model_metadata(
metadata_path,
local_metadata,
civitai_metadata,
metadata_provider,
)
update_payload = {
"model_name": local_metadata.get("model_name"),
"preview_url": local_metadata.get("preview_url"),
"civitai": local_metadata.get("civitai"),
}
model_data.update(update_payload)
await update_cache_func(file_path, file_path, local_metadata)
return True, None
except KeyError as exc:
error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}"
logger.error(error_msg)
return False, error_msg
except Exception as exc: # pragma: no cover - error path
error_msg = f"Error fetching metadata: {exc}"
logger.error(error_msg, exc_info=True)
return False, error_msg
async def fetch_metadata_by_sha(
self, sha256: str, metadata_provider: Optional[MetadataProviderProtocol] = None
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
"""Fetch metadata for a SHA256 hash from the configured provider."""
provider = metadata_provider or await self._get_default_provider()
return await provider.get_model_by_hash(sha256)
async def relink_metadata(
self,
*,
file_path: str,
metadata: Dict[str, Any],
model_id: int,
model_version_id: Optional[int],
) -> Dict[str, Any]:
"""Relink a local metadata record to a specific CivitAI model version."""
provider = await self._get_default_provider()
civitai_metadata = await provider.get_model_version(model_id, model_version_id)
if not civitai_metadata:
raise ValueError(
f"Model version not found on CivitAI for ID: {model_id}"
+ (f" with version: {model_version_id}" if model_version_id else "")
)
primary_model_file: Optional[Dict[str, Any]] = None
for file_info in civitai_metadata.get("files", []):
if file_info.get("primary", False) and file_info.get("type") == "Model":
primary_model_file = file_info
break
if primary_model_file and primary_model_file.get("hashes", {}).get("SHA256"):
metadata["sha256"] = primary_model_file["hashes"]["SHA256"].lower()
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
await self.update_model_metadata(
metadata_path,
metadata,
civitai_metadata,
provider,
)
return metadata
async def save_metadata_updates(
self,
*,
file_path: str,
updates: Dict[str, Any],
metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
) -> Dict[str, Any]:
"""Apply metadata updates and persist to disk and cache."""
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
metadata = await metadata_loader(metadata_path)
for key, value in updates.items():
if isinstance(value, dict) and isinstance(metadata.get(key), dict):
metadata[key].update(value)
else:
metadata[key] = value
await self._metadata_manager.save_metadata(file_path, metadata)
await update_cache(file_path, file_path, metadata)
if "model_name" in updates:
logger.debug("Metadata update touched model_name; cache resort required")
return metadata
async def verify_duplicate_hashes(
self,
*,
file_paths: Iterable[str],
metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
hash_calculator: Callable[[str], Awaitable[str]],
update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
) -> Dict[str, Any]:
"""Verify a collection of files share the same SHA256 hash."""
file_paths = list(file_paths)
if not file_paths:
raise ValueError("No file paths provided for verification")
results = {
"verified_as_duplicates": True,
"mismatched_files": [],
"new_hash_map": {},
}
expected_hash: Optional[str] = None
first_metadata_path = os.path.splitext(file_paths[0])[0] + ".metadata.json"
first_metadata = await metadata_loader(first_metadata_path)
if first_metadata and "sha256" in first_metadata:
expected_hash = first_metadata["sha256"].lower()
for path in file_paths:
if not os.path.exists(path):
continue
try:
actual_hash = await hash_calculator(path)
metadata_path = os.path.splitext(path)[0] + ".metadata.json"
metadata = await metadata_loader(metadata_path)
stored_hash = metadata.get("sha256", "").lower()
if not expected_hash:
expected_hash = stored_hash
if actual_hash != expected_hash:
results["verified_as_duplicates"] = False
results["mismatched_files"].append(path)
results["new_hash_map"][path] = actual_hash
if actual_hash != stored_hash:
metadata["sha256"] = actual_hash
await self._metadata_manager.save_metadata(path, metadata)
await update_cache(path, path, metadata)
except Exception as exc: # pragma: no cover - defensive path
logger.error("Error verifying hash for %s: %s", path, exc)
results["mismatched_files"].append(path)
results["new_hash_map"][path] = "error_calculating_hash"
results["verified_as_duplicates"] = False
return results

View File

@@ -1,463 +0,0 @@
import asyncio
import os
import logging
from typing import List, Dict, Optional, Any, Set
from abc import ABC, abstractmethod
from ..utils.utils import calculate_relative_path_for_model, remove_empty_dirs
from ..utils.constants import AUTO_ORGANIZE_BATCH_SIZE
from ..services.settings_manager import settings
logger = logging.getLogger(__name__)
class ProgressCallback(ABC):
"""Abstract callback interface for progress reporting"""
@abstractmethod
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
"""Called when progress is updated"""
pass
class AutoOrganizeResult:
"""Result object for auto-organize operations"""
def __init__(self):
self.total: int = 0
self.processed: int = 0
self.success_count: int = 0
self.failure_count: int = 0
self.skipped_count: int = 0
self.operation_type: str = 'unknown'
self.cleanup_counts: Dict[str, int] = {}
self.results: List[Dict[str, Any]] = []
self.results_truncated: bool = False
self.sample_results: List[Dict[str, Any]] = []
self.is_flat_structure: bool = False
def to_dict(self) -> Dict[str, Any]:
"""Convert result to dictionary"""
result = {
'success': True,
'message': f'Auto-organize {self.operation_type} completed: {self.success_count} moved, {self.skipped_count} skipped, {self.failure_count} failed out of {self.total} total',
'summary': {
'total': self.total,
'success': self.success_count,
'skipped': self.skipped_count,
'failures': self.failure_count,
'organization_type': 'flat' if self.is_flat_structure else 'structured',
'cleaned_dirs': self.cleanup_counts,
'operation_type': self.operation_type
}
}
if self.results_truncated:
result['results_truncated'] = True
result['sample_results'] = self.sample_results
else:
result['results'] = self.results
return result
class ModelFileService:
"""Service for handling model file operations and organization"""
def __init__(self, scanner, model_type: str):
"""Initialize the service
Args:
scanner: Model scanner instance
model_type: Type of model (e.g., 'lora', 'checkpoint')
"""
self.scanner = scanner
self.model_type = model_type
def get_model_roots(self) -> List[str]:
"""Get model root directories"""
return self.scanner.get_model_roots()
async def auto_organize_models(
self,
file_paths: Optional[List[str]] = None,
progress_callback: Optional[ProgressCallback] = None
) -> AutoOrganizeResult:
"""Auto-organize models based on current settings
Args:
file_paths: Optional list of specific file paths to organize.
If None, organizes all models.
progress_callback: Optional callback for progress updates
Returns:
AutoOrganizeResult object with operation results
"""
result = AutoOrganizeResult()
source_directories: Set[str] = set()
try:
# Get all models from cache
cache = await self.scanner.get_cached_data()
all_models = cache.raw_data
# Filter models if specific file paths are provided
if file_paths:
all_models = [model for model in all_models if model.get('file_path') in file_paths]
result.operation_type = 'bulk'
else:
result.operation_type = 'all'
# Get model roots for this scanner
model_roots = self.get_model_roots()
if not model_roots:
raise ValueError('No model roots configured')
# Check if flat structure is configured for this model type
path_template = settings.get_download_path_template(self.model_type)
result.is_flat_structure = not path_template
# Initialize tracking
result.total = len(all_models)
# Send initial progress
if progress_callback:
await progress_callback.on_progress({
'type': 'auto_organize_progress',
'status': 'started',
'total': result.total,
'processed': 0,
'success': 0,
'failures': 0,
'skipped': 0,
'operation_type': result.operation_type
})
# Process models in batches
await self._process_models_in_batches(
all_models,
model_roots,
result,
progress_callback,
source_directories # Pass the set to track source directories
)
# Send cleanup progress
if progress_callback:
await progress_callback.on_progress({
'type': 'auto_organize_progress',
'status': 'cleaning',
'total': result.total,
'processed': result.processed,
'success': result.success_count,
'failures': result.failure_count,
'skipped': result.skipped_count,
'message': 'Cleaning up empty directories...',
'operation_type': result.operation_type
})
# Clean up empty directories - only in affected directories for bulk operations
cleanup_paths = list(source_directories) if result.operation_type == 'bulk' else model_roots
result.cleanup_counts = await self._cleanup_empty_directories(cleanup_paths)
# Send completion message
if progress_callback:
await progress_callback.on_progress({
'type': 'auto_organize_progress',
'status': 'completed',
'total': result.total,
'processed': result.processed,
'success': result.success_count,
'failures': result.failure_count,
'skipped': result.skipped_count,
'cleanup': result.cleanup_counts,
'operation_type': result.operation_type
})
return result
except Exception as e:
logger.error(f"Error in auto_organize_models: {e}", exc_info=True)
# Send error message
if progress_callback:
await progress_callback.on_progress({
'type': 'auto_organize_progress',
'status': 'error',
'error': str(e),
'operation_type': result.operation_type
})
raise e
async def _process_models_in_batches(
self,
all_models: List[Dict[str, Any]],
model_roots: List[str],
result: AutoOrganizeResult,
progress_callback: Optional[ProgressCallback],
source_directories: Optional[Set[str]] = None
) -> None:
"""Process models in batches to avoid overwhelming the system"""
for i in range(0, result.total, AUTO_ORGANIZE_BATCH_SIZE):
batch = all_models[i:i + AUTO_ORGANIZE_BATCH_SIZE]
for model in batch:
await self._process_single_model(model, model_roots, result, source_directories)
result.processed += 1
# Send progress update after each batch
if progress_callback:
await progress_callback.on_progress({
'type': 'auto_organize_progress',
'status': 'processing',
'total': result.total,
'processed': result.processed,
'success': result.success_count,
'failures': result.failure_count,
'skipped': result.skipped_count,
'operation_type': result.operation_type
})
# Small delay between batches
await asyncio.sleep(0.1)
async def _process_single_model(
self,
model: Dict[str, Any],
model_roots: List[str],
result: AutoOrganizeResult,
source_directories: Optional[Set[str]] = None
) -> None:
"""Process a single model for organization"""
try:
file_path = model.get('file_path')
model_name = model.get('model_name', 'Unknown')
if not file_path:
self._add_result(result, model_name, False, "No file path found")
result.failure_count += 1
return
# Find which model root this file belongs to
current_root = self._find_model_root(file_path, model_roots)
if not current_root:
self._add_result(result, model_name, False,
"Model file not found in any configured root directory")
result.failure_count += 1
return
# Determine target directory
target_dir = await self._calculate_target_directory(
model, current_root, result.is_flat_structure
)
if target_dir is None:
self._add_result(result, model_name, False,
"Skipped - insufficient metadata for organization")
result.skipped_count += 1
return
current_dir = os.path.dirname(file_path)
# Skip if already in correct location
if current_dir.replace(os.sep, '/') == target_dir.replace(os.sep, '/'):
result.skipped_count += 1
return
# Check for conflicts
file_name = os.path.basename(file_path)
target_file_path = os.path.join(target_dir, file_name)
if os.path.exists(target_file_path):
self._add_result(result, model_name, False,
f"Target file already exists: {target_file_path}")
result.failure_count += 1
return
# Store the source directory for potential cleanup
if source_directories is not None:
source_directories.add(current_dir)
# Perform the move
success = await self.scanner.move_model(file_path, target_dir)
if success:
result.success_count += 1
else:
self._add_result(result, model_name, False, "Failed to move model")
result.failure_count += 1
except Exception as e:
logger.error(f"Error processing model {model.get('model_name', 'Unknown')}: {e}", exc_info=True)
self._add_result(result, model.get('model_name', 'Unknown'), False, f"Error: {str(e)}")
result.failure_count += 1
def _find_model_root(self, file_path: str, model_roots: List[str]) -> Optional[str]:
"""Find which model root the file belongs to"""
for root in model_roots:
# Normalize paths for comparison
normalized_root = os.path.normpath(root).replace(os.sep, '/')
normalized_file = os.path.normpath(file_path).replace(os.sep, '/')
if normalized_file.startswith(normalized_root):
return root
return None
async def _calculate_target_directory(
self,
model: Dict[str, Any],
current_root: str,
is_flat_structure: bool
) -> Optional[str]:
"""Calculate the target directory for a model"""
if is_flat_structure:
file_path = model.get('file_path')
current_dir = os.path.dirname(file_path)
# Check if already in root directory
if os.path.normpath(current_dir) == os.path.normpath(current_root):
return None # Signal to skip
return current_root
else:
# Calculate new relative path based on settings
new_relative_path = calculate_relative_path_for_model(model, self.model_type)
if not new_relative_path:
return None # Signal to skip
return os.path.join(current_root, new_relative_path).replace(os.sep, '/')
def _add_result(
self,
result: AutoOrganizeResult,
model_name: str,
success: bool,
message: str
) -> None:
"""Add a result entry if under the limit"""
if len(result.results) < 100: # Limit detailed results
result.results.append({
"model": model_name,
"success": success,
"message": message
})
elif len(result.results) == 100:
# Mark as truncated and save sample
result.results_truncated = True
result.sample_results = result.results[:50]
async def _cleanup_empty_directories(self, paths: List[str]) -> Dict[str, int]:
"""Clean up empty directories after organizing
Args:
paths: List of paths to check for empty directories
Returns:
Dictionary with counts of removed directories by root path
"""
cleanup_counts = {}
for path in paths:
removed = remove_empty_dirs(path)
cleanup_counts[path] = removed
return cleanup_counts
class ModelMoveService:
"""Service for handling individual model moves"""
def __init__(self, scanner):
"""Initialize the service
Args:
scanner: Model scanner instance
"""
self.scanner = scanner
async def move_model(self, file_path: str, target_path: str) -> Dict[str, Any]:
"""Move a single model file
Args:
file_path: Source file path
target_path: Target directory path
Returns:
Dictionary with move result
"""
try:
source_dir = os.path.dirname(file_path)
if os.path.normpath(source_dir) == os.path.normpath(target_path):
logger.info(f"Source and target directories are the same: {source_dir}")
return {
'success': True,
'message': 'Source and target directories are the same',
'original_file_path': file_path,
'new_file_path': file_path
}
new_file_path = await self.scanner.move_model(file_path, target_path)
if new_file_path:
return {
'success': True,
'original_file_path': file_path,
'new_file_path': new_file_path
}
else:
return {
'success': False,
'error': 'Failed to move model',
'original_file_path': file_path,
'new_file_path': None
}
except Exception as e:
logger.error(f"Error moving model: {e}", exc_info=True)
return {
'success': False,
'error': str(e),
'original_file_path': file_path,
'new_file_path': None
}
async def move_models_bulk(self, file_paths: List[str], target_path: str) -> Dict[str, Any]:
"""Move multiple model files
Args:
file_paths: List of source file paths
target_path: Target directory path
Returns:
Dictionary with bulk move results
"""
try:
results = []
for file_path in file_paths:
result = await self.move_model(file_path, target_path)
results.append({
"original_file_path": file_path,
"new_file_path": result.get('new_file_path'),
"success": result['success'],
"message": result.get('message', result.get('error', 'Unknown'))
})
success_count = sum(1 for r in results if r["success"])
failure_count = len(results) - success_count
return {
'success': True,
'message': f'Moved {success_count} of {len(file_paths)} models',
'results': results,
'success_count': success_count,
'failure_count': failure_count
}
except Exception as e:
logger.error(f"Error moving models in bulk: {e}", exc_info=True)
return {
'success': False,
'error': str(e),
'results': [],
'success_count': 0,
'failure_count': len(file_paths)
}

View File

@@ -31,34 +31,29 @@ class ModelHashIndex:
if file_path not in self._duplicate_hashes.get(sha256, []): if file_path not in self._duplicate_hashes.get(sha256, []):
self._duplicate_hashes.setdefault(sha256, []).append(file_path) self._duplicate_hashes.setdefault(sha256, []).append(file_path)
# Track duplicates by filename - FIXED LOGIC # Track duplicates by filename
if filename in self._filename_to_hash: if filename in self._filename_to_hash:
existing_hash = self._filename_to_hash[filename] old_hash = self._filename_to_hash[filename]
existing_path = self._hash_to_path.get(existing_hash) if old_hash != sha256: # Different models with the same name
old_path = self._hash_to_path.get(old_hash)
# If this is a different file with the same filename if old_path:
if existing_path and existing_path != file_path: if filename not in self._duplicate_filenames:
# Initialize duplicates tracking if needed self._duplicate_filenames[filename] = [old_path]
if filename not in self._duplicate_filenames: if file_path not in self._duplicate_filenames.get(filename, []):
self._duplicate_filenames[filename] = [existing_path] self._duplicate_filenames.setdefault(filename, []).append(file_path)
# Add current file to duplicates if not already present
if file_path not in self._duplicate_filenames[filename]:
self._duplicate_filenames[filename].append(file_path)
# Remove old path mapping if hash exists # Remove old path mapping if hash exists
if sha256 in self._hash_to_path: if sha256 in self._hash_to_path:
old_path = self._hash_to_path[sha256] old_path = self._hash_to_path[sha256]
old_filename = self._get_filename_from_path(old_path) old_filename = self._get_filename_from_path(old_path)
if old_filename in self._filename_to_hash and self._filename_to_hash[old_filename] == sha256: if old_filename in self._filename_to_hash:
del self._filename_to_hash[old_filename] del self._filename_to_hash[old_filename]
# Remove old hash mapping if filename exists and points to different hash # Remove old hash mapping if filename exists
if filename in self._filename_to_hash: if filename in self._filename_to_hash:
old_hash = self._filename_to_hash[filename] old_hash = self._filename_to_hash[filename]
if old_hash != sha256 and old_hash in self._hash_to_path: if old_hash in self._hash_to_path:
# Don't delete the old hash mapping, just update filename mapping del self._hash_to_path[old_hash]
pass
# Add new mappings # Add new mappings
self._hash_to_path[sha256] = file_path self._hash_to_path[sha256] = file_path
@@ -204,6 +199,8 @@ class ModelHashIndex:
def get_hash_by_filename(self, filename: str) -> Optional[str]: def get_hash_by_filename(self, filename: str) -> Optional[str]:
"""Get hash for a filename without extension""" """Get hash for a filename without extension"""
# Strip extension if present to make the function more flexible
filename = os.path.splitext(filename)[0]
return self._filename_to_hash.get(filename) return self._filename_to_hash.get(filename)
def clear(self) -> None: def clear(self) -> None:

View File

@@ -1,245 +0,0 @@
"""Service routines for model lifecycle mutations."""
from __future__ import annotations
import logging
import os
from typing import Awaitable, Callable, Dict, Iterable, List, Optional
from ..services.service_registry import ServiceRegistry
from ..utils.constants import PREVIEW_EXTENSIONS
logger = logging.getLogger(__name__)
async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]:
"""Delete the primary model artefacts within ``target_dir``."""
patterns = [
f"{file_name}.safetensors",
f"{file_name}.metadata.json",
]
for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{file_name}{ext}")
deleted: List[str] = []
main_file = patterns[0]
main_path = os.path.join(target_dir, main_file).replace(os.sep, "/")
if os.path.exists(main_path):
os.remove(main_path)
deleted.append(main_path)
else:
logger.warning("Model file not found: %s", main_file)
for pattern in patterns[1:]:
path = os.path.join(target_dir, pattern)
if os.path.exists(path):
try:
os.remove(path)
deleted.append(pattern)
except Exception as exc: # pragma: no cover - defensive path
logger.warning("Failed to delete %s: %s", pattern, exc)
return deleted
class ModelLifecycleService:
"""Co-ordinate destructive and mutating model operations."""
def __init__(
self,
*,
scanner,
metadata_manager,
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
recipe_scanner_factory: Callable[[], Awaitable] | None = None,
) -> None:
self._scanner = scanner
self._metadata_manager = metadata_manager
self._metadata_loader = metadata_loader
self._recipe_scanner_factory = (
recipe_scanner_factory or ServiceRegistry.get_recipe_scanner
)
async def delete_model(self, file_path: str) -> Dict[str, object]:
"""Delete a model file and associated artefacts."""
if not file_path:
raise ValueError("Model path is required")
target_dir = os.path.dirname(file_path)
file_name = os.path.splitext(os.path.basename(file_path))[0]
deleted_files = await delete_model_artifacts(target_dir, file_name)
cache = await self._scanner.get_cached_data()
cache.raw_data = [item for item in cache.raw_data if item["file_path"] != file_path]
await cache.resort()
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
self._scanner._hash_index.remove_by_path(file_path)
return {"success": True, "deleted_files": deleted_files}
async def exclude_model(self, file_path: str) -> Dict[str, object]:
"""Mark a model as excluded and prune cache references."""
if not file_path:
raise ValueError("Model path is required")
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
metadata = await self._metadata_loader(metadata_path)
metadata["exclude"] = True
await self._metadata_manager.save_metadata(file_path, metadata)
cache = await self._scanner.get_cached_data()
model_to_remove = next(
(item for item in cache.raw_data if item["file_path"] == file_path),
None,
)
if model_to_remove:
for tag in model_to_remove.get("tags", []):
if tag in getattr(self._scanner, "_tags_count", {}):
self._scanner._tags_count[tag] = max(
0, self._scanner._tags_count[tag] - 1
)
if self._scanner._tags_count[tag] == 0:
del self._scanner._tags_count[tag]
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
self._scanner._hash_index.remove_by_path(file_path)
cache.raw_data = [
item for item in cache.raw_data if item["file_path"] != file_path
]
await cache.resort()
excluded = getattr(self._scanner, "_excluded_models", None)
if isinstance(excluded, list):
excluded.append(file_path)
message = f"Model {os.path.basename(file_path)} excluded"
return {"success": True, "message": message}
async def bulk_delete_models(self, file_paths: Iterable[str]) -> Dict[str, object]:
"""Delete a collection of models via the scanner bulk operation."""
file_paths = list(file_paths)
if not file_paths:
raise ValueError("No file paths provided for deletion")
return await self._scanner.bulk_delete_models(file_paths)
async def rename_model(
self, *, file_path: str, new_file_name: str
) -> Dict[str, object]:
"""Rename a model and its companion artefacts."""
if not file_path or not new_file_name:
raise ValueError("File path and new file name are required")
invalid_chars = {"/", "\\", ":", "*", "?", '"', "<", ">", "|"}
if any(char in new_file_name for char in invalid_chars):
raise ValueError("Invalid characters in file name")
target_dir = os.path.dirname(file_path)
old_file_name = os.path.splitext(os.path.basename(file_path))[0]
new_file_path = os.path.join(target_dir, f"{new_file_name}.safetensors").replace(
os.sep, "/"
)
if os.path.exists(new_file_path):
raise ValueError("A file with this name already exists")
patterns = [
f"{old_file_name}.safetensors",
f"{old_file_name}.metadata.json",
f"{old_file_name}.metadata.json.bak",
]
for ext in PREVIEW_EXTENSIONS:
patterns.append(f"{old_file_name}{ext}")
existing_files: List[tuple[str, str]] = []
for pattern in patterns:
path = os.path.join(target_dir, pattern)
if os.path.exists(path):
existing_files.append((path, pattern))
metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
metadata: Optional[Dict[str, object]] = None
hash_value: Optional[str] = None
if os.path.exists(metadata_path):
metadata = await self._metadata_loader(metadata_path)
hash_value = metadata.get("sha256") if isinstance(metadata, dict) else None
renamed_files: List[str] = []
new_metadata_path: Optional[str] = None
new_preview: Optional[str] = None
for old_path, pattern in existing_files:
ext = self._get_multipart_ext(pattern)
new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(
os.sep, "/"
)
os.rename(old_path, new_path)
renamed_files.append(new_path)
if ext == ".metadata.json":
new_metadata_path = new_path
if metadata and new_metadata_path:
metadata["file_name"] = new_file_name
metadata["file_path"] = new_file_path
if metadata.get("preview_url"):
old_preview = str(metadata["preview_url"])
ext = self._get_multipart_ext(old_preview)
new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(
os.sep, "/"
)
metadata["preview_url"] = new_preview
await self._metadata_manager.save_metadata(new_file_path, metadata)
if metadata:
await self._scanner.update_single_model_cache(
file_path, new_file_path, metadata
)
if hash_value and getattr(self._scanner, "model_type", "") == "lora":
recipe_scanner = await self._recipe_scanner_factory()
if recipe_scanner:
try:
await recipe_scanner.update_lora_filename_by_hash(
hash_value, new_file_name
)
except Exception as exc: # pragma: no cover - defensive logging
logger.error(
"Error updating recipe references for %s: %s",
file_path,
exc,
)
return {
"success": True,
"new_file_path": new_file_path,
"new_preview_path": new_preview,
"renamed_files": renamed_files,
"reload_required": False,
}
@staticmethod
def _get_multipart_ext(filename: str) -> str:
"""Return the extension for files with compound suffixes."""
parts = filename.split(".")
if len(parts) == 3:
return "." + ".".join(parts[-2:])
if len(parts) >= 4:
return "." + ".".join(parts[-3:])
return os.path.splitext(filename)[1]

View File

@@ -1,495 +0,0 @@
from abc import ABC, abstractmethod
import json
import logging
from typing import Optional, Dict, Tuple, Any
from .downloader import get_downloader
try:
from bs4 import BeautifulSoup
except ImportError as exc:
BeautifulSoup = None # type: ignore[assignment]
_BS4_IMPORT_ERROR = exc
else:
_BS4_IMPORT_ERROR = None
try:
import aiosqlite
except ImportError as exc:
aiosqlite = None # type: ignore[assignment]
_AIOSQLITE_IMPORT_ERROR = exc
else:
_AIOSQLITE_IMPORT_ERROR = None
def _require_beautifulsoup() -> Any:
if BeautifulSoup is None:
raise RuntimeError(
"BeautifulSoup (bs4) is required for CivArchiveModelMetadataProvider. "
"Install it with 'pip install beautifulsoup4'."
) from _BS4_IMPORT_ERROR
return BeautifulSoup
def _require_aiosqlite() -> Any:
if aiosqlite is None:
raise RuntimeError(
"aiosqlite is required for SQLiteModelMetadataProvider. "
"Install it with 'pip install aiosqlite'."
) from _AIOSQLITE_IMPORT_ERROR
return aiosqlite
logger = logging.getLogger(__name__)
class ModelMetadataProvider(ABC):
"""Base abstract class for all model metadata providers"""
@abstractmethod
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Find model by hash value"""
pass
@abstractmethod
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
"""Get all versions of a model with their details"""
pass
@abstractmethod
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
"""Get specific model version with additional metadata"""
pass
@abstractmethod
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata"""
pass
class CivitaiModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses Civitai API for metadata"""
def __init__(self, civitai_client):
self.client = civitai_client
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
return await self.client.get_model_by_hash(model_hash)
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
return await self.client.get_model_versions(model_id)
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
return await self.client.get_model_version(model_id, version_id)
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
return await self.client.get_model_version_info(version_id)
class CivArchiveModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses CivArchive HTML page parsing for metadata"""
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Not supported by CivArchive provider"""
return None, "CivArchive provider does not support hash lookup"
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
"""Not supported by CivArchive provider"""
return None
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
"""Get specific model version by parsing CivArchive HTML page"""
if model_id is None or version_id is None:
return None
try:
# Construct CivArchive URL
url = f"https://civarchive.com/models/{model_id}?modelVersionId={version_id}"
downloader = await get_downloader()
session = await downloader.session
async with session.get(url) as response:
if response.status != 200:
return None
html_content = await response.text()
# Parse HTML to extract JSON data
soup_parser = _require_beautifulsoup()
soup = soup_parser(html_content, 'html.parser')
script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
if not script_tag:
return None
# Parse JSON content
json_data = json.loads(script_tag.string)
model_data = json_data.get('props', {}).get('pageProps', {}).get('model')
if not model_data or 'version' not in model_data:
return None
# Extract version data as base
version = model_data['version'].copy()
# Restructure stats
if 'downloadCount' in version and 'ratingCount' in version and 'rating' in version:
version['stats'] = {
'downloadCount': version.pop('downloadCount'),
'ratingCount': version.pop('ratingCount'),
'rating': version.pop('rating')
}
# Rename trigger to trainedWords
if 'trigger' in version:
version['trainedWords'] = version.pop('trigger')
# Transform files data to expected format
if 'files' in version:
transformed_files = []
for file_data in version['files']:
# Find first available mirror (deletedAt is null)
available_mirror = None
for mirror in file_data.get('mirrors', []):
if mirror.get('deletedAt') is None:
available_mirror = mirror
break
# Create transformed file entry
transformed_file = {
'id': file_data.get('id'),
'sizeKB': file_data.get('sizeKB'),
'name': available_mirror.get('filename', file_data.get('name')) if available_mirror else file_data.get('name'),
'type': file_data.get('type'),
'downloadUrl': available_mirror.get('url') if available_mirror else None,
'primary': True,
'mirrors': file_data.get('mirrors', [])
}
# Transform hash format
if 'sha256' in file_data:
transformed_file['hashes'] = {
'SHA256': file_data['sha256'].upper()
}
transformed_files.append(transformed_file)
version['files'] = transformed_files
# Add model information
version['model'] = {
'name': model_data.get('name'),
'type': model_data.get('type'),
'nsfw': model_data.get('is_nsfw', False),
'description': model_data.get('description'),
'tags': model_data.get('tags', [])
}
version['creator'] = {
'username': model_data.get('username'),
'image': ''
}
# Add source identifier
version['source'] = 'civarchive'
version['is_deleted'] = json_data.get('query', {}).get('is_deleted', False)
return version
except Exception as e:
logger.error(f"Error fetching CivArchive model version {model_id}/{version_id}: {e}")
return None
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Not supported by CivArchive provider - requires both model_id and version_id"""
return None, "CivArchive provider requires both model_id and version_id"
class SQLiteModelMetadataProvider(ModelMetadataProvider):
"""Provider that uses SQLite database for metadata"""
def __init__(self, db_path: str):
self.db_path = db_path
self._aiosqlite = _require_aiosqlite()
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Find model by hash value from SQLite database"""
async with self._aiosqlite.connect(self.db_path) as db:
# Look up in model_files table to get model_id and version_id
query = """
SELECT model_id, version_id
FROM model_files
WHERE sha256 = ?
LIMIT 1
"""
db.row_factory = self._aiosqlite.Row
cursor = await db.execute(query, (model_hash.upper(),))
file_row = await cursor.fetchone()
if not file_row:
return None, "Model not found"
# Get version details
model_id = file_row['model_id']
version_id = file_row['version_id']
# Build response in the same format as Civitai API
result = await self._get_version_with_model_data(db, model_id, version_id)
return result, None if result else "Error retrieving model data"
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
"""Get all versions of a model from SQLite database"""
async with self._aiosqlite.connect(self.db_path) as db:
db.row_factory = self._aiosqlite.Row
# First check if model exists
model_query = "SELECT * FROM models WHERE id = ?"
cursor = await db.execute(model_query, (model_id,))
model_row = await cursor.fetchone()
if not model_row:
return None
model_data = json.loads(model_row['data'])
model_type = model_row['type']
model_name = model_row['name']
# Get all versions for this model
versions_query = """
SELECT id, name, base_model, data, position, published_at
FROM model_versions
WHERE model_id = ?
ORDER BY position ASC
"""
cursor = await db.execute(versions_query, (model_id,))
version_rows = await cursor.fetchall()
if not version_rows:
return {'modelVersions': [], 'type': model_type}
# Format versions similar to Civitai API
model_versions = []
for row in version_rows:
version_data = json.loads(row['data'])
# Add fields from the row to ensure we have the basic fields
version_entry = {
'id': row['id'],
'modelId': int(model_id),
'name': row['name'],
'baseModel': row['base_model'],
'model': {
'name': model_row['name'],
'type': model_type,
},
'source': 'archive_db'
}
# Update with any additional data
version_entry.update(version_data)
model_versions.append(version_entry)
return {
'modelVersions': model_versions,
'type': model_type,
'name': model_name
}
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
"""Get specific model version with additional metadata from SQLite database"""
if not model_id and not version_id:
return None
async with self._aiosqlite.connect(self.db_path) as db:
db.row_factory = self._aiosqlite.Row
# Case 1: Only version_id is provided
if model_id is None and version_id is not None:
# First get the version info to extract model_id
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
cursor = await db.execute(version_query, (version_id,))
version_row = await cursor.fetchone()
if not version_row:
return None
model_id = version_row['model_id']
# Case 2: model_id is provided but version_id is not
elif model_id is not None and version_id is None:
# Find the latest version
version_query = """
SELECT id FROM model_versions
WHERE model_id = ?
ORDER BY position ASC
LIMIT 1
"""
cursor = await db.execute(version_query, (model_id,))
version_row = await cursor.fetchone()
if not version_row:
return None
version_id = version_row['id']
# Now we have both model_id and version_id, get the full data
return await self._get_version_with_model_data(db, model_id, version_id)
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version metadata from SQLite database"""
async with self._aiosqlite.connect(self.db_path) as db:
db.row_factory = self._aiosqlite.Row
# Get version details
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
cursor = await db.execute(version_query, (version_id,))
version_row = await cursor.fetchone()
if not version_row:
return None, "Model version not found"
model_id = version_row['model_id']
# Build complete version data with model info
version_data = await self._get_version_with_model_data(db, model_id, version_id)
return version_data, None
async def _get_version_with_model_data(self, db, model_id, version_id) -> Optional[Dict]:
"""Helper to build version data with model information"""
# Get version details
version_query = "SELECT name, base_model, data FROM model_versions WHERE id = ? AND model_id = ?"
cursor = await db.execute(version_query, (version_id, model_id))
version_row = await cursor.fetchone()
if not version_row:
return None
# Get model details
model_query = "SELECT name, type, data, username FROM models WHERE id = ?"
cursor = await db.execute(model_query, (model_id,))
model_row = await cursor.fetchone()
if not model_row:
return None
# Parse JSON data
try:
version_data = json.loads(version_row['data'])
model_data = json.loads(model_row['data'])
# Build response
result = {
"id": int(version_id),
"modelId": int(model_id),
"name": version_row['name'],
"baseModel": version_row['base_model'],
"model": {
"name": model_row['name'],
"description": model_data.get("description"),
"type": model_row['type'],
"tags": model_data.get("tags", [])
},
"creator": {
"username": model_row['username'] or model_data.get("creator", {}).get("username"),
"image": model_data.get("creator", {}).get("image")
},
"source": "archive_db"
}
# Add any additional fields from version data
result.update(version_data)
return result
except json.JSONDecodeError:
return None
class FallbackMetadataProvider(ModelMetadataProvider):
"""Try providers in order, return first successful result."""
def __init__(self, providers: list):
self.providers = providers
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
for provider in self.providers:
try:
result, error = await provider.get_model_by_hash(model_hash)
if result:
return result, error
except Exception as e:
logger.debug(f"Provider failed for get_model_by_hash: {e}")
continue
return None, "Model not found"
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
for provider in self.providers:
try:
result = await provider.get_model_versions(model_id)
if result:
return result
except Exception as e:
logger.debug(f"Provider failed for get_model_versions: {e}")
continue
return None
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
for provider in self.providers:
try:
result = await provider.get_model_version(model_id, version_id)
if result:
return result
except Exception as e:
logger.debug(f"Provider failed for get_model_version: {e}")
continue
return None
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
for provider in self.providers:
try:
result, error = await provider.get_model_version_info(version_id)
if result:
return result, error
except Exception as e:
logger.debug(f"Provider failed for get_model_version_info: {e}")
continue
return None, "No provider could retrieve the data"
class ModelMetadataProviderManager:
"""Manager for selecting and using model metadata providers"""
_instance = None
@classmethod
async def get_instance(cls):
"""Get singleton instance of ModelMetadataProviderManager"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
def __init__(self):
self.providers = {}
self.default_provider = None
def register_provider(self, name: str, provider: ModelMetadataProvider, is_default: bool = False):
"""Register a metadata provider"""
self.providers[name] = provider
if is_default or self.default_provider is None:
self.default_provider = name
async def get_model_by_hash(self, model_hash: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]:
"""Find model by hash using specified or default provider"""
provider = self._get_provider(provider_name)
return await provider.get_model_by_hash(model_hash)
async def get_model_versions(self, model_id: str, provider_name: str = None) -> Optional[Dict]:
"""Get model versions using specified or default provider"""
provider = self._get_provider(provider_name)
return await provider.get_model_versions(model_id)
async def get_model_version(self, model_id: int = None, version_id: int = None, provider_name: str = None) -> Optional[Dict]:
"""Get specific model version using specified or default provider"""
provider = self._get_provider(provider_name)
return await provider.get_model_version(model_id, version_id)
async def get_model_version_info(self, version_id: str, provider_name: str = None) -> Tuple[Optional[Dict], Optional[str]]:
"""Fetch model version info using specified or default provider"""
provider = self._get_provider(provider_name)
return await provider.get_model_version_info(version_id)
def _get_provider(self, provider_name: str = None) -> ModelMetadataProvider:
"""Get provider by name or default provider"""
if provider_name and provider_name in self.providers:
return self.providers[provider_name]
if self.default_provider is None:
raise ValueError("No default provider set and no valid provider specified")
return self.providers[self.default_provider]

View File

@@ -1,196 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable
from ..utils.constants import NSFW_LEVELS
from ..utils.utils import fuzzy_match as default_fuzzy_match
class SettingsProvider(Protocol):
"""Protocol describing the SettingsManager contract used by query helpers."""
def get(self, key: str, default: Any = None) -> Any:
...
@dataclass(frozen=True)
class SortParams:
"""Normalized representation of sorting instructions."""
key: str
order: str
@dataclass(frozen=True)
class FilterCriteria:
"""Container for model list filtering options."""
folder: Optional[str] = None
base_models: Optional[Sequence[str]] = None
tags: Optional[Sequence[str]] = None
favorites_only: bool = False
search_options: Optional[Dict[str, Any]] = None
class ModelCacheRepository:
"""Adapter around scanner cache access and sort normalisation."""
def __init__(self, scanner) -> None:
self._scanner = scanner
async def get_cache(self):
"""Return the underlying cache instance from the scanner."""
return await self._scanner.get_cached_data()
async def fetch_sorted(self, params: SortParams) -> List[Dict[str, Any]]:
"""Fetch cached data pre-sorted according to ``params``."""
cache = await self.get_cache()
return await cache.get_sorted_data(params.key, params.order)
@staticmethod
def parse_sort(sort_by: str) -> SortParams:
"""Parse an incoming sort string into key/order primitives."""
if not sort_by:
return SortParams(key="name", order="asc")
if ":" in sort_by:
raw_key, raw_order = sort_by.split(":", 1)
sort_key = raw_key.strip().lower() or "name"
order = raw_order.strip().lower()
else:
sort_key = sort_by.strip().lower() or "name"
order = "asc"
if order not in ("asc", "desc"):
order = "asc"
return SortParams(key=sort_key, order=order)
class ModelFilterSet:
"""Applies common filtering rules to the model collection."""
def __init__(self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None) -> None:
self._settings = settings
self._nsfw_levels = nsfw_levels or NSFW_LEVELS
def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]:
"""Return items that satisfy the provided criteria."""
items = list(data)
if self._settings.get("show_only_sfw", False):
threshold = self._nsfw_levels.get("R", 0)
items = [
item for item in items
if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold
]
if criteria.favorites_only:
items = [item for item in items if item.get("favorite", False)]
folder = criteria.folder
options = criteria.search_options or {}
recursive = bool(options.get("recursive", True))
if folder is not None:
if recursive:
if folder:
folder_with_sep = f"{folder}/"
items = [
item for item in items
if item.get("folder") == folder or item.get("folder", "").startswith(folder_with_sep)
]
else:
items = [item for item in items if item.get("folder") == folder]
base_models = criteria.base_models or []
if base_models:
base_model_set = set(base_models)
items = [item for item in items if item.get("base_model") in base_model_set]
tags = criteria.tags or []
if tags:
tag_set = set(tags)
items = [
item for item in items
if any(tag in tag_set for tag in item.get("tags", []))
]
return items
class SearchStrategy:
"""Encapsulates text and fuzzy matching behaviour for model queries."""
DEFAULT_OPTIONS: Dict[str, Any] = {
"filename": True,
"modelname": True,
"tags": False,
"recursive": True,
"creator": False,
}
def __init__(self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None) -> None:
self._fuzzy_match = fuzzy_matcher or default_fuzzy_match
def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""Merge provided options with defaults without mutating input."""
normalized = dict(self.DEFAULT_OPTIONS)
if options:
normalized.update(options)
return normalized
def apply(
self,
data: Iterable[Dict[str, Any]],
search_term: str,
options: Dict[str, Any],
fuzzy: bool = False,
) -> List[Dict[str, Any]]:
"""Return items matching the search term using the configured strategy."""
if not search_term:
return list(data)
search_lower = search_term.lower()
results: List[Dict[str, Any]] = []
for item in data:
if options.get("filename", True):
candidate = item.get("file_name", "")
if self._matches(candidate, search_term, search_lower, fuzzy):
results.append(item)
continue
if options.get("modelname", True):
candidate = item.get("model_name", "")
if self._matches(candidate, search_term, search_lower, fuzzy):
results.append(item)
continue
if options.get("tags", False):
tags = item.get("tags", []) or []
if any(self._matches(tag, search_term, search_lower, fuzzy) for tag in tags):
results.append(item)
continue
if options.get("creator", False):
creator_username = ""
civitai = item.get("civitai")
if isinstance(civitai, dict):
creator = civitai.get("creator")
if isinstance(creator, dict):
creator_username = creator.get("username", "")
if creator_username and self._matches(creator_username, search_term, search_lower, fuzzy):
results.append(item)
continue
return results
def _matches(self, candidate: str, search_term: str, search_lower: str, fuzzy: bool) -> bool:
if not candidate:
return False
candidate_lower = candidate.lower()
if fuzzy:
return self._fuzzy_match(candidate, search_term)
return search_lower in candidate_lower

View File

@@ -8,12 +8,11 @@ from typing import List, Dict, Optional, Type, Set
from ..utils.models import BaseModelMetadata from ..utils.models import BaseModelMetadata
from ..config import config from ..config import config
from ..utils.file_utils import find_preview_file, get_preview_extension from ..utils.file_utils import find_preview_file
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .model_cache import ModelCache from .model_cache import ModelCache
from .model_hash_index import ModelHashIndex from .model_hash_index import ModelHashIndex
from ..utils.constants import PREVIEW_EXTENSIONS from ..utils.constants import PREVIEW_EXTENSIONS
from .model_lifecycle_service import delete_model_artifacts
from .service_registry import ServiceRegistry from .service_registry import ServiceRegistry
from .websocket_manager import ws_manager from .websocket_manager import ws_manager
@@ -303,13 +302,6 @@ class ModelScanner:
for tag in model_data['tags']: for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Log duplicate filename warnings after building the index
# duplicate_filenames = self._hash_index.get_duplicate_filenames()
# if duplicate_filenames:
# logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:")
# for filename, paths in duplicate_filenames.items():
# logger.warning(f" Duplicate filename '{filename}': {paths}")
# Update cache # Update cache
self._cache.raw_data = raw_data self._cache.raw_data = raw_data
loop.run_until_complete(self._cache.resort()) loop.run_until_complete(self._cache.resort())
@@ -375,13 +367,6 @@ class ModelScanner:
for tag in model_data['tags']: for tag in model_data['tags']:
self._tags_count[tag] = self._tags_count.get(tag, 0) + 1 self._tags_count[tag] = self._tags_count.get(tag, 0) + 1
# Log duplicate filename warnings after building the index
# duplicate_filenames = self._hash_index.get_duplicate_filenames()
# if duplicate_filenames:
# logger.warning(f"Found {len(duplicate_filenames)} filename(s) with duplicates during {self.model_type} cache build:")
# for filename, paths in duplicate_filenames.items():
# logger.warning(f" Duplicate filename '{filename}': {paths}")
# Update cache # Update cache
self._cache = ModelCache( self._cache = ModelCache(
raw_data=raw_data, raw_data=raw_data,
@@ -584,13 +569,12 @@ class ModelScanner:
for entry in entries: for entry in entries:
try: try:
if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions): if entry.is_file(follow_symlinks=True) and any(entry.name.endswith(ext) for ext in self.file_extensions):
# Use original path instead of real path
file_path = entry.path.replace(os.sep, "/") file_path = entry.path.replace(os.sep, "/")
result = await self._process_model_file(file_path, original_root) await self._process_single_file(file_path, original_root, models)
# Only add to models if result is not None (skip corrupted metadata)
if result:
models.append(result)
await asyncio.sleep(0) await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True): elif entry.is_dir(follow_symlinks=True):
# For directories, continue scanning with original path
await scan_recursive(entry.path, visited_paths) await scan_recursive(entry.path, visited_paths)
except Exception as e: except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}") logger.error(f"Error processing entry {entry.path}: {e}")
@@ -600,6 +584,15 @@ class ModelScanner:
await scan_recursive(root_path, set()) await scan_recursive(root_path, set())
return models return models
async def _process_single_file(self, file_path: str, root_path: str, models: list):
"""Process a single file and add to results list"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
models.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
def is_initializing(self) -> bool: def is_initializing(self) -> bool:
"""Check if the scanner is currently initializing""" """Check if the scanner is currently initializing"""
return self._is_initializing return self._is_initializing
@@ -620,18 +613,10 @@ class ModelScanner:
return os.path.dirname(rel_path).replace(os.path.sep, '/') return os.path.dirname(rel_path).replace(os.path.sep, '/')
return '' return ''
def adjust_metadata(self, metadata, file_path, root_path): # Common methods shared between scanners
"""Hook for subclasses: adjust metadata during scanning"""
return metadata
async def _process_model_file(self, file_path: str, root_path: str) -> Dict: async def _process_model_file(self, file_path: str, root_path: str) -> Dict:
"""Process a single model file and return its metadata""" """Process a single model file and return its metadata"""
metadata, should_skip = await MetadataManager.load_metadata(file_path, self.model_class) metadata = await MetadataManager.load_metadata(file_path, self.model_class)
if should_skip:
# Metadata file exists but cannot be parsed - skip this model
logger.warning(f"Skipping model {file_path} due to corrupted metadata file")
return None
if metadata is None: if metadata is None:
civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info" civitai_info_path = f"{os.path.splitext(file_path)[0]}.civitai.info"
@@ -647,7 +632,7 @@ class ModelScanner:
metadata = self.model_class.from_civitai_info(version_info, file_info, file_path) metadata = self.model_class.from_civitai_info(version_info, file_info, file_path)
metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path)) metadata.preview_url = find_preview_file(file_name, os.path.dirname(file_path))
await MetadataManager.save_metadata(file_path, metadata) await MetadataManager.save_metadata(file_path, metadata, True)
logger.debug(f"Created metadata from .civitai.info for {file_path}") logger.debug(f"Created metadata from .civitai.info for {file_path}")
except Exception as e: except Exception as e:
logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}") logger.error(f"Error creating metadata from .civitai.info for {file_path}: {e}")
@@ -674,7 +659,7 @@ class ModelScanner:
metadata.modelDescription = version_info['model']['description'] metadata.modelDescription = version_info['model']['description']
# Save the updated metadata # Save the updated metadata
await MetadataManager.save_metadata(file_path, metadata) await MetadataManager.save_metadata(file_path, metadata, True)
logger.debug(f"Updated metadata with civitai info for {file_path}") logger.debug(f"Updated metadata with civitai info for {file_path}")
except Exception as e: except Exception as e:
logger.error(f"Error restoring civitai data from .civitai.info for {file_path}: {e}") logger.error(f"Error restoring civitai data from .civitai.info for {file_path}: {e}")
@@ -682,9 +667,6 @@ class ModelScanner:
if metadata is None: if metadata is None:
metadata = await self._create_default_metadata(file_path) metadata = await self._create_default_metadata(file_path)
# Hook: allow subclasses to adjust metadata
metadata = self.adjust_metadata(metadata, file_path, root_path)
model_data = metadata.to_dict() model_data = metadata.to_dict()
# Skip excluded models # Skip excluded models
@@ -692,20 +674,106 @@ class ModelScanner:
self._excluded_models.append(model_data['file_path']) self._excluded_models.append(model_data['file_path'])
return None return None
# Check for duplicate filename before adding to hash index await self._fetch_missing_metadata(file_path, model_data)
filename = os.path.splitext(os.path.basename(file_path))[0]
existing_hash = self._hash_index.get_hash_by_filename(filename)
if existing_hash and existing_hash != model_data.get('sha256', '').lower():
existing_path = self._hash_index.get_path(existing_hash)
if existing_path and existing_path != file_path:
logger.warning(f"Duplicate filename detected: '{filename}' - files: '{existing_path}' and '{file_path}'")
rel_path = os.path.relpath(file_path, root_path) rel_path = os.path.relpath(file_path, root_path)
folder = os.path.dirname(rel_path) folder = os.path.dirname(rel_path)
model_data['folder'] = folder.replace(os.path.sep, '/') model_data['folder'] = folder.replace(os.path.sep, '/')
return model_data return model_data
async def _fetch_missing_metadata(self, file_path: str, model_data: Dict) -> None:
"""Fetch missing description and tags from Civitai if needed"""
try:
if model_data.get('civitai_deleted', False):
logger.debug(f"Skipping metadata fetch for {file_path}: marked as deleted on Civitai")
return
needs_metadata_update = False
model_id = None
if model_data.get('civitai'):
model_id = model_data['civitai'].get('modelId')
if model_id:
model_id = str(model_id)
tags_missing = not model_data.get('tags') or len(model_data.get('tags', [])) == 0
desc_missing = not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")
# TODO: not for now, but later we should check if the creator is missing
# creator_missing = not model_data.get('civitai', {}).get('creator')
creator_missing = False
needs_metadata_update = tags_missing or desc_missing or creator_missing
if needs_metadata_update and model_id:
logger.debug(f"Fetching missing metadata for {file_path} with model ID {model_id}")
from ..services.civitai_client import CivitaiClient
client = CivitaiClient()
model_metadata, status_code = await client.get_model_metadata(model_id)
await client.close()
if status_code == 404:
logger.warning(f"Model {model_id} appears to be deleted from Civitai (404 response)")
model_data['civitai_deleted'] = True
await MetadataManager.save_metadata(file_path, model_data)
elif model_metadata:
logger.debug(f"Updating metadata for {file_path} with model ID {model_id}")
if model_metadata.get('tags') and (not model_data.get('tags') or len(model_data.get('tags', [])) == 0):
model_data['tags'] = model_metadata['tags']
if model_metadata.get('description') and (not model_data.get('modelDescription') or model_data.get('modelDescription') in (None, "")):
model_data['modelDescription'] = model_metadata['description']
model_data['civitai']['creator'] = model_metadata['creator']
await MetadataManager.save_metadata(file_path, model_data, True)
except Exception as e:
logger.error(f"Failed to update metadata from Civitai for {file_path}: {e}")
async def _scan_directory(self, root_path: str) -> List[Dict]:
"""Base implementation for directory scanning"""
models = []
original_root = root_path
async def scan_recursive(path: str, visited_paths: set):
try:
real_path = os.path.realpath(path)
if real_path in visited_paths:
logger.debug(f"Skipping already visited path: {path}")
return
visited_paths.add(real_path)
with os.scandir(path) as it:
entries = list(it)
for entry in entries:
try:
if entry.is_file(follow_symlinks=True):
ext = os.path.splitext(entry.name)[1].lower()
if ext in self.file_extensions:
file_path = entry.path.replace(os.sep, "/")
await self._process_single_file(file_path, original_root, models)
await asyncio.sleep(0)
elif entry.is_dir(follow_symlinks=True):
await scan_recursive(entry.path, visited_paths)
except Exception as e:
logger.error(f"Error processing entry {entry.path}: {e}")
except Exception as e:
logger.error(f"Error scanning {path}: {e}")
await scan_recursive(root_path, set())
return models
async def _process_single_file(self, file_path: str, root_path: str, models_list: list):
"""Process a single file and add to results list"""
try:
result = await self._process_model_file(file_path, root_path)
if result:
models_list.append(result)
except Exception as e:
logger.error(f"Error processing {file_path}: {e}")
async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool: async def add_model_to_cache(self, metadata_dict: Dict, folder: str = '') -> bool:
"""Add a model to the cache """Add a model to the cache
@@ -741,16 +809,8 @@ class ModelScanner:
logger.error(f"Error adding model to cache: {e}") logger.error(f"Error adding model to cache: {e}")
return False return False
async def move_model(self, source_path: str, target_path: str) -> Optional[str]: async def move_model(self, source_path: str, target_path: str) -> bool:
"""Move a model and its associated files to a new location """Move a model and its associated files to a new location"""
Args:
source_path: Original file path
target_path: Target directory path
Returns:
Optional[str]: New file path if successful, None if failed
"""
try: try:
source_path = source_path.replace(os.sep, '/') source_path = source_path.replace(os.sep, '/')
target_path = target_path.replace(os.sep, '/') target_path = target_path.replace(os.sep, '/')
@@ -759,28 +819,14 @@ class ModelScanner:
if not file_ext or file_ext.lower() not in self.file_extensions: if not file_ext or file_ext.lower() not in self.file_extensions:
logger.error(f"Invalid file extension for model: {file_ext}") logger.error(f"Invalid file extension for model: {file_ext}")
return None return False
base_name = os.path.splitext(os.path.basename(source_path))[0] base_name = os.path.splitext(os.path.basename(source_path))[0]
source_dir = os.path.dirname(source_path) source_dir = os.path.dirname(source_path)
os.makedirs(target_path, exist_ok=True) os.makedirs(target_path, exist_ok=True)
def get_source_hash(): target_file = os.path.join(target_path, f"{base_name}{file_ext}").replace(os.sep, '/')
return self.get_hash_by_path(source_path)
# Check for filename conflicts and auto-rename if necessary
from ..utils.models import BaseModelMetadata
final_filename = BaseModelMetadata.generate_unique_filename(
target_path, base_name, file_ext, get_source_hash
)
target_file = os.path.join(target_path, final_filename).replace(os.sep, '/')
final_base_name = os.path.splitext(final_filename)[0]
# Log if filename was changed due to conflict
if final_filename != f"{base_name}{file_ext}":
logger.info(f"Renamed {base_name}{file_ext} to {final_filename} to avoid filename conflict")
real_source = os.path.realpath(source_path) real_source = os.path.realpath(source_path)
real_target = os.path.realpath(target_file) real_target = os.path.realpath(target_file)
@@ -797,17 +843,12 @@ class ModelScanner:
for file in os.listdir(source_dir): for file in os.listdir(source_dir):
if file.startswith(base_name + ".") and file != os.path.basename(source_path): if file.startswith(base_name + ".") and file != os.path.basename(source_path):
source_file_path = os.path.join(source_dir, file) source_file_path = os.path.join(source_dir, file)
# Generate new filename with the same base name as the model file
file_suffix = file[len(base_name):] # Get the part after base_name (e.g., ".metadata.json", ".preview.png")
new_associated_filename = f"{final_base_name}{file_suffix}"
target_associated_path = os.path.join(target_path, new_associated_filename)
# Store metadata file path for special handling # Store metadata file path for special handling
if file == f"{base_name}.metadata.json": if file == f"{base_name}.metadata.json":
source_metadata = source_file_path source_metadata = source_file_path
moved_metadata_path = target_associated_path moved_metadata_path = os.path.join(target_path, file)
else: else:
files_to_move.append((source_file_path, target_associated_path)) files_to_move.append((source_file_path, os.path.join(target_path, file)))
except Exception as e: except Exception as e:
logger.error(f"Error listing files in {source_dir}: {e}") logger.error(f"Error listing files in {source_dir}: {e}")
@@ -829,11 +870,11 @@ class ModelScanner:
await self.update_single_model_cache(source_path, target_file, metadata) await self.update_single_model_cache(source_path, target_file, metadata)
return target_file return True
except Exception as e: except Exception as e:
logger.error(f"Error moving model: {e}", exc_info=True) logger.error(f"Error moving model: {e}", exc_info=True)
return None return False
async def _update_metadata_paths(self, metadata_path: str, model_path: str) -> Dict: async def _update_metadata_paths(self, metadata_path: str, model_path: str) -> Dict:
"""Update file paths in metadata file""" """Update file paths in metadata file"""
@@ -842,15 +883,12 @@ class ModelScanner:
metadata = json.load(f) metadata = json.load(f)
metadata['file_path'] = model_path.replace(os.sep, '/') metadata['file_path'] = model_path.replace(os.sep, '/')
# Update file_name to match the new filename
metadata['file_name'] = os.path.splitext(os.path.basename(model_path))[0]
if 'preview_url' in metadata and metadata['preview_url']: if 'preview_url' in metadata and metadata['preview_url']:
preview_dir = os.path.dirname(model_path) preview_dir = os.path.dirname(model_path)
# Update preview filename to match the new base name preview_name = os.path.splitext(os.path.basename(metadata['preview_url']))[0]
new_base_name = os.path.splitext(os.path.basename(model_path))[0] preview_ext = os.path.splitext(metadata['preview_url'])[1]
preview_ext = get_preview_extension(metadata['preview_url']) new_preview_path = os.path.join(preview_dir, f"{preview_name}{preview_ext}")
new_preview_path = os.path.join(preview_dir, f"{new_base_name}{preview_ext}")
metadata['preview_url'] = new_preview_path.replace(os.sep, '/') metadata['preview_url'] = new_preview_path.replace(os.sep, '/')
await MetadataManager.save_metadata(metadata_path, metadata) await MetadataManager.save_metadata(metadata_path, metadata)
@@ -917,15 +955,7 @@ class ModelScanner:
def get_hash_by_path(self, file_path: str) -> Optional[str]: def get_hash_by_path(self, file_path: str) -> Optional[str]:
"""Get hash for a model by its file path""" """Get hash for a model by its file path"""
if self._cache is None or not self._cache.raw_data: return self._hash_index.get_hash(file_path)
return None
# Iterate through cache data to find matching file path
for model_data in self._cache.raw_data:
if model_data.get('file_path') == file_path:
return model_data.get('sha256')
return None
def get_hash_by_filename(self, filename: str) -> Optional[str]: def get_hash_by_filename(self, filename: str) -> Optional[str]:
"""Get hash for a model by its filename without path""" """Get hash for a model by its filename without path"""
@@ -1041,7 +1071,9 @@ class ModelScanner:
target_dir = os.path.dirname(file_path) target_dir = os.path.dirname(file_path)
file_name = os.path.splitext(os.path.basename(file_path))[0] file_name = os.path.splitext(os.path.basename(file_path))[0]
deleted_files = await delete_model_artifacts( # Delete all associated files for the model
from ..utils.routes_common import ModelRouteUtils
deleted_files = await ModelRouteUtils.delete_model_files(
target_dir, target_dir,
file_name file_name
) )
@@ -1162,10 +1194,11 @@ class ModelScanner:
if len(self._hash_index._duplicate_filenames[file_name]) <= 1: if len(self._hash_index._duplicate_filenames[file_name]) <= 1:
del self._hash_index._duplicate_filenames[file_name] del self._hash_index._duplicate_filenames[file_name]
async def check_model_version_exists(self, model_version_id: int) -> bool: async def check_model_version_exists(self, model_id: int, model_version_id: int) -> bool:
"""Check if a specific model version exists in the cache """Check if a specific model version exists in the cache
Args: Args:
model_id: Civitai model ID
model_version_id: Civitai model version ID model_version_id: Civitai model version ID
Returns: Returns:
@@ -1177,7 +1210,9 @@ class ModelScanner:
return False return False
for item in cache.raw_data: for item in cache.raw_data:
if item.get('civitai') and item['civitai'].get('id') == model_version_id: if (item.get('civitai') and
item['civitai'].get('modelId') == model_id and
item['civitai'].get('id') == model_version_id):
return True return True
return False return False

View File

@@ -1,168 +0,0 @@
"""Service for processing preview assets for models."""
from __future__ import annotations
import logging
import os
from typing import Awaitable, Callable, Dict, Optional, Sequence
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
logger = logging.getLogger(__name__)
class PreviewAssetService:
"""Manage fetching and persisting preview assets."""
def __init__(
self,
*,
metadata_manager,
downloader_factory: Callable[[], Awaitable],
exif_utils,
) -> None:
self._metadata_manager = metadata_manager
self._downloader_factory = downloader_factory
self._exif_utils = exif_utils
async def ensure_preview_for_metadata(
self,
metadata_path: str,
local_metadata: Dict[str, object],
images: Sequence[Dict[str, object]] | None,
) -> None:
"""Ensure preview assets exist for the supplied metadata entry."""
if local_metadata.get("preview_url") and os.path.exists(
str(local_metadata["preview_url"])
):
return
if not images:
return
first_preview = images[0]
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
preview_dir = os.path.dirname(metadata_path)
is_video = first_preview.get("type") == "video"
if is_video:
extension = ".mp4"
preview_path = os.path.join(preview_dir, base_name + extension)
downloader = await self._downloader_factory()
success, result = await downloader.download_file(
first_preview["url"], preview_path, use_auth=False
)
if success:
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
else:
extension = ".webp"
preview_path = os.path.join(preview_dir, base_name + extension)
downloader = await self._downloader_factory()
success, content, _headers = await downloader.download_to_memory(
first_preview["url"], use_auth=False
)
if not success:
return
try:
optimized_data, _ = self._exif_utils.optimize_image(
image_data=content,
target_width=CARD_PREVIEW_WIDTH,
format="webp",
quality=85,
preserve_metadata=False,
)
with open(preview_path, "wb") as handle:
handle.write(optimized_data)
except Exception as exc: # pragma: no cover - defensive path
logger.error("Error optimizing preview image: %s", exc)
try:
with open(preview_path, "wb") as handle:
handle.write(content)
except Exception as save_exc:
logger.error("Error saving preview image: %s", save_exc)
return
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
async def replace_preview(
self,
*,
model_path: str,
preview_data: bytes,
content_type: str,
original_filename: Optional[str],
nsfw_level: int,
update_preview_in_cache: Callable[[str, str, int], Awaitable[bool]],
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
) -> Dict[str, object]:
"""Replace an existing preview asset for a model."""
base_name = os.path.splitext(os.path.basename(model_path))[0]
folder = os.path.dirname(model_path)
extension, optimized_data = await self._convert_preview(
preview_data, content_type, original_filename
)
for ext in PREVIEW_EXTENSIONS:
existing_preview = os.path.join(folder, base_name + ext)
if os.path.exists(existing_preview):
try:
os.remove(existing_preview)
except Exception as exc: # pragma: no cover - defensive path
logger.warning(
"Failed to delete existing preview %s: %s", existing_preview, exc
)
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, "/")
with open(preview_path, "wb") as handle:
handle.write(optimized_data)
metadata_path = os.path.splitext(model_path)[0] + ".metadata.json"
metadata = await metadata_loader(metadata_path)
metadata["preview_url"] = preview_path
metadata["preview_nsfw_level"] = nsfw_level
await self._metadata_manager.save_metadata(model_path, metadata)
await update_preview_in_cache(model_path, preview_path, nsfw_level)
return {"preview_path": preview_path, "preview_nsfw_level": nsfw_level}
async def _convert_preview(
self, data: bytes, content_type: str, original_filename: Optional[str]
) -> tuple[str, bytes]:
"""Convert preview bytes to the persisted representation."""
if content_type.startswith("video/"):
extension = self._resolve_video_extension(content_type, original_filename)
return extension, data
original_ext = (original_filename or "").lower()
if original_ext.endswith(".gif") or content_type.lower() == "image/gif":
return ".gif", data
optimized_data, _ = self._exif_utils.optimize_image(
image_data=data,
target_width=CARD_PREVIEW_WIDTH,
format="webp",
quality=85,
preserve_metadata=False,
)
return ".webp", optimized_data
def _resolve_video_extension(self, content_type: str, original_filename: Optional[str]) -> str:
"""Infer the best extension for a video preview."""
if original_filename:
extension = os.path.splitext(original_filename)[1].lower()
if extension in {".mp4", ".webm", ".mov", ".avi"}:
return extension
if "webm" in content_type:
return ".webm"
return ".mp4"

View File

@@ -1,5 +1,5 @@
import asyncio import asyncio
from typing import Iterable, List, Dict, Optional from typing import List, Dict
from dataclasses import dataclass from dataclasses import dataclass
from operator import itemgetter from operator import itemgetter
from natsort import natsorted from natsort import natsorted
@@ -17,9 +17,18 @@ class RecipeCache:
async def resort(self, name_only: bool = False): async def resort(self, name_only: bool = False):
"""Resort all cached data views""" """Resort all cached data views"""
async with self._lock: async with self._lock:
self._resort_locked(name_only=name_only) self.sorted_by_name = natsorted(
self.raw_data,
key=lambda x: x.get('title', '').lower() # Case-insensitive sort
)
if not name_only:
self.sorted_by_date = sorted(
self.raw_data,
key=itemgetter('created_date', 'file_path'),
reverse=True
)
async def update_recipe_metadata(self, recipe_id: str, metadata: Dict, *, resort: bool = True) -> bool: async def update_recipe_metadata(self, recipe_id: str, metadata: Dict) -> bool:
"""Update metadata for a specific recipe in all cached data """Update metadata for a specific recipe in all cached data
Args: Args:
@@ -29,96 +38,49 @@ class RecipeCache:
Returns: Returns:
bool: True if the update was successful, False if the recipe wasn't found bool: True if the update was successful, False if the recipe wasn't found
""" """
async with self._lock:
for item in self.raw_data:
if str(item.get('id')) == str(recipe_id):
item.update(metadata)
if resort:
self._resort_locked()
return True
return False # Recipe not found
async def add_recipe(self, recipe_data: Dict, *, resort: bool = False) -> None: # Update in raw_data
"""Add a new recipe to the cache.""" for item in self.raw_data:
if item.get('id') == recipe_id:
item.update(metadata)
break
else:
return False # Recipe not found
# Resort to reflect changes
await self.resort()
return True
async def add_recipe(self, recipe_data: Dict) -> None:
"""Add a new recipe to the cache
Args:
recipe_data: The recipe data to add
"""
async with self._lock: async with self._lock:
self.raw_data.append(recipe_data) self.raw_data.append(recipe_data)
if resort: await self.resort()
self._resort_locked()
async def remove_recipe(self, recipe_id: str, *, resort: bool = False) -> Optional[Dict]: async def remove_recipe(self, recipe_id: str) -> bool:
"""Remove a recipe from the cache by ID. """Remove a recipe from the cache by ID
Args: Args:
recipe_id: The ID of the recipe to remove recipe_id: The ID of the recipe to remove
Returns: Returns:
The removed recipe data if found, otherwise ``None``. bool: True if the recipe was found and removed, False otherwise
""" """
# Find the recipe in raw_data
recipe_index = next((i for i, recipe in enumerate(self.raw_data)
if recipe.get('id') == recipe_id), None)
async with self._lock: if recipe_index is None:
for index, recipe in enumerate(self.raw_data): return False
if str(recipe.get('id')) == str(recipe_id):
removed = self.raw_data.pop(index)
if resort:
self._resort_locked()
return removed
return None
async def bulk_remove(self, recipe_ids: Iterable[str], *, resort: bool = False) -> List[Dict]: # Remove from raw_data
"""Remove multiple recipes from the cache.""" self.raw_data.pop(recipe_index)
id_set = {str(recipe_id) for recipe_id in recipe_ids} # Resort to update sorted lists
if not id_set: await self.resort()
return []
async with self._lock: return True
removed = [item for item in self.raw_data if str(item.get('id')) in id_set]
if not removed:
return []
self.raw_data = [item for item in self.raw_data if str(item.get('id')) not in id_set]
if resort:
self._resort_locked()
return removed
async def replace_recipe(self, recipe_id: str, new_data: Dict, *, resort: bool = False) -> bool:
"""Replace cached data for a recipe."""
async with self._lock:
for index, recipe in enumerate(self.raw_data):
if str(recipe.get('id')) == str(recipe_id):
self.raw_data[index] = new_data
if resort:
self._resort_locked()
return True
return False
async def get_recipe(self, recipe_id: str) -> Optional[Dict]:
"""Return a shallow copy of a cached recipe."""
async with self._lock:
for recipe in self.raw_data:
if str(recipe.get('id')) == str(recipe_id):
return dict(recipe)
return None
async def snapshot(self) -> List[Dict]:
"""Return a copy of all cached recipes."""
async with self._lock:
return [dict(item) for item in self.raw_data]
def _resort_locked(self, *, name_only: bool = False) -> None:
"""Sort cached views. Caller must hold ``_lock``."""
self.sorted_by_name = natsorted(
self.raw_data,
key=lambda x: x.get('title', '').lower()
)
if not name_only:
self.sorted_by_date = sorted(
self.raw_data,
key=itemgetter('created_date', 'file_path'),
reverse=True
)

View File

@@ -3,14 +3,12 @@ import logging
import asyncio import asyncio
import json import json
import time import time
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from typing import List, Dict, Optional, Any, Tuple
from ..config import config from ..config import config
from .recipe_cache import RecipeCache from .recipe_cache import RecipeCache
from .service_registry import ServiceRegistry from .service_registry import ServiceRegistry
from .lora_scanner import LoraScanner from .lora_scanner import LoraScanner
from .metadata_service import get_default_metadata_provider from ..utils.utils import fuzzy_match
from .recipes.errors import RecipeNotFoundError
from ..utils.utils import calculate_recipe_fingerprint, fuzzy_match
from natsort import natsorted from natsort import natsorted
import sys import sys
@@ -47,8 +45,6 @@ class RecipeScanner:
self._initialization_lock = asyncio.Lock() self._initialization_lock = asyncio.Lock()
self._initialization_task: Optional[asyncio.Task] = None self._initialization_task: Optional[asyncio.Task] = None
self._is_initializing = False self._is_initializing = False
self._mutation_lock = asyncio.Lock()
self._resort_tasks: Set[asyncio.Task] = set()
if lora_scanner: if lora_scanner:
self._lora_scanner = lora_scanner self._lora_scanner = lora_scanner
self._initialized = True self._initialized = True
@@ -194,22 +190,6 @@ class RecipeScanner:
# Clean up the event loop # Clean up the event loop
loop.close() loop.close()
def _schedule_resort(self, *, name_only: bool = False) -> None:
"""Schedule a background resort of the recipe cache."""
if not self._cache:
return
async def _resort_wrapper() -> None:
try:
await self._cache.resort(name_only=name_only)
except Exception as exc: # pragma: no cover - defensive logging
logger.error("Recipe Scanner: error resorting cache: %s", exc, exc_info=True)
task = asyncio.create_task(_resort_wrapper())
self._resort_tasks.add(task)
task.add_done_callback(lambda finished: self._resort_tasks.discard(finished))
@property @property
def recipes_dir(self) -> str: def recipes_dir(self) -> str:
"""Get path to recipes directory""" """Get path to recipes directory"""
@@ -275,44 +255,6 @@ class RecipeScanner:
# Return the cache (may be empty or partially initialized) # Return the cache (may be empty or partially initialized)
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[]) return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
async def refresh_cache(self, force: bool = False) -> RecipeCache:
"""Public helper to refresh or return the recipe cache."""
return await self.get_cached_data(force_refresh=force)
async def add_recipe(self, recipe_data: Dict[str, Any]) -> None:
"""Add a recipe to the in-memory cache."""
if not recipe_data:
return
cache = await self.get_cached_data()
await cache.add_recipe(recipe_data, resort=False)
self._schedule_resort()
async def remove_recipe(self, recipe_id: str) -> bool:
"""Remove a recipe from the cache by ID."""
if not recipe_id:
return False
cache = await self.get_cached_data()
removed = await cache.remove_recipe(recipe_id, resort=False)
if removed is None:
return False
self._schedule_resort()
return True
async def bulk_remove(self, recipe_ids: Iterable[str]) -> int:
"""Remove multiple recipes from the cache."""
cache = await self.get_cached_data()
removed = await cache.bulk_remove(recipe_ids, resort=False)
if removed:
self._schedule_resort()
return len(removed)
async def scan_all_recipes(self) -> List[Dict]: async def scan_all_recipes(self) -> List[Dict]:
"""Scan all recipe JSON files and return metadata""" """Scan all recipe JSON files and return metadata"""
recipes = [] recipes = []
@@ -383,6 +325,7 @@ class RecipeScanner:
# Calculate and update fingerprint if missing # Calculate and update fingerprint if missing
if 'loras' in recipe_data and 'fingerprint' not in recipe_data: if 'loras' in recipe_data and 'fingerprint' not in recipe_data:
from ..utils.utils import calculate_recipe_fingerprint
fingerprint = calculate_recipe_fingerprint(recipe_data['loras']) fingerprint = calculate_recipe_fingerprint(recipe_data['loras'])
recipe_data['fingerprint'] = fingerprint recipe_data['fingerprint'] = fingerprint
@@ -488,13 +431,13 @@ class RecipeScanner:
async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]: async def _get_hash_from_civitai(self, model_version_id: str) -> Optional[str]:
"""Get hash from Civitai API""" """Get hash from Civitai API"""
try: try:
# Get metadata provider instead of civitai client directly # Get CivitaiClient from ServiceRegistry
metadata_provider = await get_default_metadata_provider() civitai_client = await self._get_civitai_client()
if not metadata_provider: if not civitai_client:
logger.error("Failed to get metadata provider") logger.error("Failed to get CivitaiClient from ServiceRegistry")
return None return None
version_info, error_msg = await metadata_provider.get_model_version_info(model_version_id) version_info, error_msg = await civitai_client.get_model_version_info(model_version_id)
if not version_info: if not version_info:
if error_msg and "model not found" in error_msg.lower(): if error_msg and "model not found" in error_msg.lower():
@@ -553,33 +496,6 @@ class RecipeScanner:
logger.error(f"Error getting base model for lora: {e}") logger.error(f"Error getting base model for lora: {e}")
return None return None
def _enrich_lora_entry(self, lora: Dict[str, Any]) -> Dict[str, Any]:
"""Populate convenience fields for a LoRA entry."""
if not lora or not self._lora_scanner:
return lora
hash_value = (lora.get('hash') or '').lower()
if not hash_value:
return lora
try:
lora['inLibrary'] = self._lora_scanner.has_hash(hash_value)
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(hash_value)
lora['localPath'] = self._lora_scanner.get_path_by_hash(hash_value)
except Exception as exc: # pragma: no cover - defensive logging
logger.debug("Error enriching lora entry %s: %s", hash_value, exc)
return lora
async def get_local_lora(self, name: str) -> Optional[Dict[str, Any]]:
"""Lookup a local LoRA model by name."""
if not self._lora_scanner or not name:
return None
return await self._lora_scanner.get_model_info_by_name(name)
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True): async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True):
"""Get paginated and filtered recipe data """Get paginated and filtered recipe data
@@ -685,7 +601,11 @@ class RecipeScanner:
# Add inLibrary information for each lora # Add inLibrary information for each lora
for item in paginated_items: for item in paginated_items:
if 'loras' in item: if 'loras' in item:
item['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in item['loras']] for lora in item['loras']:
if 'hash' in lora and lora['hash']:
lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower())
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower())
result = { result = {
'items': paginated_items, 'items': paginated_items,
@@ -732,7 +652,12 @@ class RecipeScanner:
# Add lora metadata # Add lora metadata
if 'loras' in formatted_recipe: if 'loras' in formatted_recipe:
formatted_recipe['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in formatted_recipe['loras']] for lora in formatted_recipe['loras']:
if 'hash' in lora and lora['hash']:
lora_hash = lora['hash'].lower()
lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash)
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash)
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash)
return formatted_recipe return formatted_recipe
@@ -794,8 +719,7 @@ class RecipeScanner:
# Update the cache if it exists # Update the cache if it exists
if self._cache is not None: if self._cache is not None:
await self._cache.update_recipe_metadata(recipe_id, metadata, resort=False) await self._cache.update_recipe_metadata(recipe_id, metadata)
self._schedule_resort()
# If the recipe has an image, update its EXIF metadata # If the recipe has an image, update its EXIF metadata
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
@@ -809,138 +733,6 @@ class RecipeScanner:
logging.getLogger(__name__).error(f"Error updating recipe metadata: {e}", exc_info=True) logging.getLogger(__name__).error(f"Error updating recipe metadata: {e}", exc_info=True)
return False return False
async def update_lora_entry(
self,
recipe_id: str,
lora_index: int,
*,
target_name: str,
target_lora: Optional[Dict[str, Any]] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Update a specific LoRA entry within a recipe.
Returns the updated recipe data and the refreshed LoRA metadata.
"""
if target_name is None:
raise ValueError("target_name must be provided")
recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
if not os.path.exists(recipe_json_path):
raise RecipeNotFoundError("Recipe not found")
async with self._mutation_lock:
with open(recipe_json_path, 'r', encoding='utf-8') as file_obj:
recipe_data = json.load(file_obj)
loras = recipe_data.get('loras', [])
if lora_index >= len(loras):
raise RecipeNotFoundError("LoRA index out of range in recipe")
lora_entry = loras[lora_index]
lora_entry['isDeleted'] = False
lora_entry['exclude'] = False
lora_entry['file_name'] = target_name
if target_lora is not None:
sha_value = target_lora.get('sha256') or target_lora.get('sha')
if sha_value:
lora_entry['hash'] = sha_value.lower()
civitai_info = target_lora.get('civitai') or {}
if civitai_info:
lora_entry['modelName'] = civitai_info.get('model', {}).get('name', '')
lora_entry['modelVersionName'] = civitai_info.get('name', '')
lora_entry['modelVersionId'] = civitai_info.get('id')
recipe_data['fingerprint'] = calculate_recipe_fingerprint(recipe_data.get('loras', []))
recipe_data['modified'] = time.time()
with open(recipe_json_path, 'w', encoding='utf-8') as file_obj:
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
cache = await self.get_cached_data()
replaced = await cache.replace_recipe(recipe_id, recipe_data, resort=False)
if not replaced:
await cache.add_recipe(recipe_data, resort=False)
self._schedule_resort()
updated_lora = dict(lora_entry)
if target_lora is not None:
preview_url = target_lora.get('preview_url')
if preview_url:
updated_lora['preview_url'] = config.get_preview_static_url(preview_url)
if target_lora.get('file_path'):
updated_lora['localPath'] = target_lora['file_path']
updated_lora = self._enrich_lora_entry(updated_lora)
return recipe_data, updated_lora
async def get_recipes_for_lora(self, lora_hash: str) -> List[Dict[str, Any]]:
"""Return recipes that reference a given LoRA hash."""
if not lora_hash:
return []
normalized_hash = lora_hash.lower()
cache = await self.get_cached_data()
matching_recipes: List[Dict[str, Any]] = []
for recipe in cache.raw_data:
loras = recipe.get('loras', [])
if any((entry.get('hash') or '').lower() == normalized_hash for entry in loras):
recipe_copy = {**recipe}
recipe_copy['loras'] = [self._enrich_lora_entry(dict(entry)) for entry in loras]
recipe_copy['file_url'] = self._format_file_url(recipe.get('file_path'))
matching_recipes.append(recipe_copy)
return matching_recipes
async def get_recipe_syntax_tokens(self, recipe_id: str) -> List[str]:
"""Build LoRA syntax tokens for a recipe."""
cache = await self.get_cached_data()
recipe = await cache.get_recipe(recipe_id)
if recipe is None:
raise RecipeNotFoundError("Recipe not found")
loras = recipe.get('loras', [])
if not loras:
return []
lora_cache = None
if self._lora_scanner is not None:
lora_cache = await self._lora_scanner.get_cached_data()
syntax_parts: List[str] = []
for lora in loras:
if lora.get('isDeleted', False):
continue
file_name = None
hash_value = (lora.get('hash') or '').lower()
if hash_value and self._lora_scanner is not None and hasattr(self._lora_scanner, '_hash_index'):
file_path = self._lora_scanner._hash_index.get_path(hash_value)
if file_path:
file_name = os.path.splitext(os.path.basename(file_path))[0]
if not file_name and lora.get('modelVersionId') and lora_cache is not None:
for cached_lora in getattr(lora_cache, 'raw_data', []):
civitai_info = cached_lora.get('civitai')
if civitai_info and civitai_info.get('id') == lora.get('modelVersionId'):
cached_path = cached_lora.get('path') or cached_lora.get('file_path')
if cached_path:
file_name = os.path.splitext(os.path.basename(cached_path))[0]
break
if not file_name:
file_name = lora.get('file_name', 'unknown-lora')
strength = lora.get('strength', 1.0)
syntax_parts.append(f"<lora:{file_name}:{strength}>")
return syntax_parts
async def update_lora_filename_by_hash(self, hash_value: str, new_file_name: str) -> Tuple[int, int]: async def update_lora_filename_by_hash(self, hash_value: str, new_file_name: str) -> Tuple[int, int]:
"""Update file_name in all recipes that contain a LoRA with the specified hash. """Update file_name in all recipes that contain a LoRA with the specified hash.

View File

@@ -1,23 +0,0 @@
"""Recipe service layer implementations."""
from .analysis_service import RecipeAnalysisService
from .persistence_service import RecipePersistenceService
from .sharing_service import RecipeSharingService
from .errors import (
RecipeServiceError,
RecipeValidationError,
RecipeNotFoundError,
RecipeDownloadError,
RecipeConflictError,
)
__all__ = [
"RecipeAnalysisService",
"RecipePersistenceService",
"RecipeSharingService",
"RecipeServiceError",
"RecipeValidationError",
"RecipeNotFoundError",
"RecipeDownloadError",
"RecipeConflictError",
]

View File

@@ -1,289 +0,0 @@
"""Services responsible for recipe metadata analysis."""
from __future__ import annotations
import base64
import io
import os
import re
import tempfile
from dataclasses import dataclass
from typing import Any, Callable, Optional
import numpy as np
from PIL import Image
from ...utils.utils import calculate_recipe_fingerprint
from .errors import (
RecipeDownloadError,
RecipeNotFoundError,
RecipeServiceError,
RecipeValidationError,
)
@dataclass(frozen=True)
class AnalysisResult:
"""Return payload from analysis operations."""
payload: dict[str, Any]
status: int = 200
class RecipeAnalysisService:
"""Extract recipe metadata from various image sources."""
def __init__(
self,
*,
exif_utils,
recipe_parser_factory,
downloader_factory: Callable[[], Any],
metadata_collector: Optional[Callable[[], Any]] = None,
metadata_processor_cls: Optional[type] = None,
metadata_registry_cls: Optional[type] = None,
standalone_mode: bool = False,
logger,
) -> None:
self._exif_utils = exif_utils
self._recipe_parser_factory = recipe_parser_factory
self._downloader_factory = downloader_factory
self._metadata_collector = metadata_collector
self._metadata_processor_cls = metadata_processor_cls
self._metadata_registry_cls = metadata_registry_cls
self._standalone_mode = standalone_mode
self._logger = logger
async def analyze_uploaded_image(
self,
*,
image_bytes: bytes | None,
recipe_scanner,
) -> AnalysisResult:
"""Analyze an uploaded image payload."""
if not image_bytes:
raise RecipeValidationError("No image data provided")
temp_path = self._write_temp_file(image_bytes)
try:
metadata = self._exif_utils.extract_image_metadata(temp_path)
if not metadata:
return AnalysisResult({"error": "No metadata found in this image", "loras": []})
return await self._parse_metadata(
metadata,
recipe_scanner=recipe_scanner,
image_path=None,
include_image_base64=False,
)
finally:
self._safe_cleanup(temp_path)
async def analyze_remote_image(
self,
*,
url: str | None,
recipe_scanner,
civitai_client,
) -> AnalysisResult:
"""Analyze an image accessible via URL, including Civitai integration."""
if not url:
raise RecipeValidationError("No URL provided")
if civitai_client is None:
raise RecipeServiceError("Civitai client unavailable")
temp_path = self._create_temp_path()
metadata: Optional[dict[str, Any]] = None
try:
civitai_match = re.match(r"https://civitai\.com/images/(\d+)", url)
if civitai_match:
image_info = await civitai_client.get_image_info(civitai_match.group(1))
if not image_info:
raise RecipeDownloadError("Failed to fetch image information from Civitai")
image_url = image_info.get("url")
if not image_url:
raise RecipeDownloadError("No image URL found in Civitai response")
await self._download_image(image_url, temp_path)
metadata = image_info.get("meta") if "meta" in image_info else None
else:
await self._download_image(url, temp_path)
if metadata is None:
metadata = self._exif_utils.extract_image_metadata(temp_path)
if not metadata:
return self._metadata_not_found_response(temp_path)
return await self._parse_metadata(
metadata,
recipe_scanner=recipe_scanner,
image_path=temp_path,
include_image_base64=True,
)
finally:
self._safe_cleanup(temp_path)
async def analyze_local_image(
self,
*,
file_path: str | None,
recipe_scanner,
) -> AnalysisResult:
"""Analyze a file already present on disk."""
if not file_path:
raise RecipeValidationError("No file path provided")
normalized_path = os.path.normpath(file_path.strip('"').strip("'"))
if not os.path.isfile(normalized_path):
raise RecipeNotFoundError("File not found")
metadata = self._exif_utils.extract_image_metadata(normalized_path)
if not metadata:
return self._metadata_not_found_response(normalized_path)
return await self._parse_metadata(
metadata,
recipe_scanner=recipe_scanner,
image_path=normalized_path,
include_image_base64=True,
)
async def analyze_widget_metadata(self, *, recipe_scanner) -> AnalysisResult:
"""Analyse the most recent generation metadata for widget saves."""
if self._metadata_collector is None or self._metadata_processor_cls is None:
raise RecipeValidationError("Metadata collection not available")
raw_metadata = self._metadata_collector()
metadata_dict = self._metadata_processor_cls.to_dict(raw_metadata)
if not metadata_dict:
raise RecipeValidationError("No generation metadata found")
latest_image = None
if not self._standalone_mode and self._metadata_registry_cls is not None:
metadata_registry = self._metadata_registry_cls()
latest_image = metadata_registry.get_first_decoded_image()
if latest_image is None:
raise RecipeValidationError(
"No recent images found to use for recipe. Try generating an image first."
)
image_bytes = self._convert_tensor_to_png_bytes(latest_image)
if image_bytes is None:
raise RecipeValidationError("Cannot handle this data shape from metadata registry")
return AnalysisResult(
{
"metadata": metadata_dict,
"image_bytes": image_bytes,
}
)
# Internal helpers -------------------------------------------------
async def _parse_metadata(
self,
metadata: dict[str, Any],
*,
recipe_scanner,
image_path: Optional[str],
include_image_base64: bool,
) -> AnalysisResult:
parser = self._recipe_parser_factory.create_parser(metadata)
if parser is None:
payload = {"error": "No parser found for this image", "loras": []}
if include_image_base64 and image_path:
payload["image_base64"] = self._encode_file(image_path)
return AnalysisResult(payload)
result = await parser.parse_metadata(metadata, recipe_scanner=recipe_scanner)
if include_image_base64 and image_path:
result["image_base64"] = self._encode_file(image_path)
if "error" in result and not result.get("loras"):
return AnalysisResult(result)
fingerprint = calculate_recipe_fingerprint(result.get("loras", []))
result["fingerprint"] = fingerprint
matching_recipes: list[str] = []
if fingerprint:
matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint)
result["matching_recipes"] = matching_recipes
return AnalysisResult(result)
async def _download_image(self, url: str, temp_path: str) -> None:
downloader = await self._downloader_factory()
success, result = await downloader.download_file(url, temp_path, use_auth=False)
if not success:
raise RecipeDownloadError(f"Failed to download image from URL: {result}")
def _metadata_not_found_response(self, path: str) -> AnalysisResult:
payload: dict[str, Any] = {"error": "No metadata found in this image", "loras": []}
if os.path.exists(path):
payload["image_base64"] = self._encode_file(path)
return AnalysisResult(payload)
def _write_temp_file(self, data: bytes) -> str:
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
temp_file.write(data)
return temp_file.name
def _create_temp_path(self) -> str:
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
return temp_file.name
def _safe_cleanup(self, path: Optional[str]) -> None:
if path and os.path.exists(path):
try:
os.unlink(path)
except Exception as exc: # pragma: no cover - defensive logging
self._logger.error("Error deleting temporary file: %s", exc)
def _encode_file(self, path: str) -> str:
with open(path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
def _convert_tensor_to_png_bytes(self, latest_image: Any) -> Optional[bytes]:
try:
if isinstance(latest_image, tuple):
tensor_image = latest_image[0] if latest_image else None
if tensor_image is None:
return None
else:
tensor_image = latest_image
if hasattr(tensor_image, "shape"):
self._logger.debug(
"Tensor shape: %s, dtype: %s", tensor_image.shape, getattr(tensor_image, "dtype", None)
)
import torch # type: ignore[import-not-found]
if isinstance(tensor_image, torch.Tensor):
image_np = tensor_image.cpu().numpy()
else:
image_np = np.array(tensor_image)
while len(image_np.shape) > 3:
image_np = image_np[0]
if image_np.dtype in (np.float32, np.float64) and image_np.max() <= 1.0:
image_np = (image_np * 255).astype(np.uint8)
if len(image_np.shape) == 3 and image_np.shape[2] == 3:
pil_image = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
pil_image.save(img_byte_arr, format="PNG")
return img_byte_arr.getvalue()
except Exception as exc: # pragma: no cover - defensive logging path
self._logger.error("Error processing image data: %s", exc, exc_info=True)
return None
return None

View File

@@ -1,22 +0,0 @@
"""Shared exceptions for recipe services."""
from __future__ import annotations
class RecipeServiceError(Exception):
"""Base exception for recipe service failures."""
class RecipeValidationError(RecipeServiceError):
"""Raised when a request payload fails validation."""
class RecipeNotFoundError(RecipeServiceError):
"""Raised when a recipe resource cannot be located."""
class RecipeDownloadError(RecipeServiceError):
"""Raised when remote recipe assets cannot be downloaded."""
class RecipeConflictError(RecipeServiceError):
"""Raised when a conflicting recipe state is detected."""

View File

@@ -1,400 +0,0 @@
"""Services encapsulating recipe persistence workflows."""
from __future__ import annotations
import base64
import json
import os
import re
import time
import uuid
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Optional
from ...config import config
from ...utils.utils import calculate_recipe_fingerprint
from .errors import RecipeNotFoundError, RecipeValidationError
@dataclass(frozen=True)
class PersistenceResult:
"""Return payload from persistence operations."""
payload: dict[str, Any]
status: int = 200
class RecipePersistenceService:
"""Coordinate recipe persistence tasks across storage and caches."""
def __init__(
self,
*,
exif_utils,
card_preview_width: int,
logger,
) -> None:
self._exif_utils = exif_utils
self._card_preview_width = card_preview_width
self._logger = logger
async def save_recipe(
self,
*,
recipe_scanner,
image_bytes: bytes | None,
image_base64: str | None,
name: str | None,
tags: Iterable[str],
metadata: Optional[dict[str, Any]],
) -> PersistenceResult:
"""Persist a user uploaded recipe."""
missing_fields = []
if not name:
missing_fields.append("name")
if metadata is None:
missing_fields.append("metadata")
if missing_fields:
raise RecipeValidationError(
f"Missing required fields: {', '.join(missing_fields)}"
)
resolved_image_bytes = self._resolve_image_bytes(image_bytes, image_base64)
recipes_dir = recipe_scanner.recipes_dir
os.makedirs(recipes_dir, exist_ok=True)
recipe_id = str(uuid.uuid4())
optimized_image, extension = self._exif_utils.optimize_image(
image_data=resolved_image_bytes,
target_width=self._card_preview_width,
format="webp",
quality=85,
preserve_metadata=True,
)
image_filename = f"{recipe_id}{extension}"
image_path = os.path.join(recipes_dir, image_filename)
with open(image_path, "wb") as file_obj:
file_obj.write(optimized_image)
current_time = time.time()
loras_data = [self._normalise_lora_entry(lora) for lora in metadata.get("loras", [])]
gen_params = metadata.get("gen_params", {})
if not gen_params and "raw_metadata" in metadata:
raw_metadata = metadata.get("raw_metadata", {})
gen_params = {
"prompt": raw_metadata.get("prompt", ""),
"negative_prompt": raw_metadata.get("negative_prompt", ""),
"checkpoint": raw_metadata.get("checkpoint", {}),
"steps": raw_metadata.get("steps", ""),
"sampler": raw_metadata.get("sampler", ""),
"cfg_scale": raw_metadata.get("cfg_scale", ""),
"seed": raw_metadata.get("seed", ""),
"size": raw_metadata.get("size", ""),
"clip_skip": raw_metadata.get("clip_skip", ""),
}
fingerprint = calculate_recipe_fingerprint(loras_data)
recipe_data: Dict[str, Any] = {
"id": recipe_id,
"file_path": image_path,
"title": name,
"modified": current_time,
"created_date": current_time,
"base_model": metadata.get("base_model", ""),
"loras": loras_data,
"gen_params": gen_params,
"fingerprint": fingerprint,
}
tags_list = list(tags)
if tags_list:
recipe_data["tags"] = tags_list
if metadata.get("source_path"):
recipe_data["source_path"] = metadata.get("source_path")
json_filename = f"{recipe_id}.recipe.json"
json_path = os.path.join(recipes_dir, json_filename)
with open(json_path, "w", encoding="utf-8") as file_obj:
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
self._exif_utils.append_recipe_metadata(image_path, recipe_data)
matching_recipes = await self._find_matching_recipes(recipe_scanner, fingerprint, exclude_id=recipe_id)
await recipe_scanner.add_recipe(recipe_data)
return PersistenceResult(
{
"success": True,
"recipe_id": recipe_id,
"image_path": image_path,
"json_path": json_path,
"matching_recipes": matching_recipes,
}
)
async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult:
"""Delete an existing recipe."""
recipes_dir = recipe_scanner.recipes_dir
if not recipes_dir or not os.path.exists(recipes_dir):
raise RecipeNotFoundError("Recipes directory not found")
recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
if not os.path.exists(recipe_json_path):
raise RecipeNotFoundError("Recipe not found")
with open(recipe_json_path, "r", encoding="utf-8") as file_obj:
recipe_data = json.load(file_obj)
image_path = recipe_data.get("file_path")
os.remove(recipe_json_path)
if image_path and os.path.exists(image_path):
os.remove(image_path)
await recipe_scanner.remove_recipe(recipe_id)
return PersistenceResult({"success": True, "message": "Recipe deleted successfully"})
async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: dict[str, Any]) -> PersistenceResult:
"""Update persisted metadata for a recipe."""
if not any(key in updates for key in ("title", "tags", "source_path", "preview_nsfw_level")):
raise RecipeValidationError(
"At least one field to update must be provided (title or tags or source_path or preview_nsfw_level)"
)
success = await recipe_scanner.update_recipe_metadata(recipe_id, updates)
if not success:
raise RecipeNotFoundError("Recipe not found or update failed")
return PersistenceResult({"success": True, "recipe_id": recipe_id, "updates": updates})
async def reconnect_lora(
self,
*,
recipe_scanner,
recipe_id: str,
lora_index: int,
target_name: str,
) -> PersistenceResult:
"""Reconnect a LoRA entry within an existing recipe."""
recipe_path = os.path.join(recipe_scanner.recipes_dir, f"{recipe_id}.recipe.json")
if not os.path.exists(recipe_path):
raise RecipeNotFoundError("Recipe not found")
target_lora = await recipe_scanner.get_local_lora(target_name)
if not target_lora:
raise RecipeNotFoundError(f"Local LoRA not found with name: {target_name}")
recipe_data, updated_lora = await recipe_scanner.update_lora_entry(
recipe_id,
lora_index,
target_name=target_name,
target_lora=target_lora,
)
image_path = recipe_data.get("file_path")
if image_path and os.path.exists(image_path):
self._exif_utils.append_recipe_metadata(image_path, recipe_data)
matching_recipes = []
if "fingerprint" in recipe_data:
matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(recipe_data["fingerprint"])
if recipe_id in matching_recipes:
matching_recipes.remove(recipe_id)
return PersistenceResult(
{
"success": True,
"recipe_id": recipe_id,
"updated_lora": updated_lora,
"matching_recipes": matching_recipes,
}
)
async def bulk_delete(
self,
*,
recipe_scanner,
recipe_ids: Iterable[str],
) -> PersistenceResult:
"""Delete multiple recipes in a single request."""
recipe_ids = list(recipe_ids)
if not recipe_ids:
raise RecipeValidationError("No recipe IDs provided")
recipes_dir = recipe_scanner.recipes_dir
if not recipes_dir or not os.path.exists(recipes_dir):
raise RecipeNotFoundError("Recipes directory not found")
deleted_recipes: list[str] = []
failed_recipes: list[dict[str, Any]] = []
for recipe_id in recipe_ids:
recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
if not os.path.exists(recipe_json_path):
failed_recipes.append({"id": recipe_id, "reason": "Recipe not found"})
continue
try:
with open(recipe_json_path, "r", encoding="utf-8") as file_obj:
recipe_data = json.load(file_obj)
image_path = recipe_data.get("file_path")
os.remove(recipe_json_path)
if image_path and os.path.exists(image_path):
os.remove(image_path)
deleted_recipes.append(recipe_id)
except Exception as exc:
failed_recipes.append({"id": recipe_id, "reason": str(exc)})
if deleted_recipes:
await recipe_scanner.bulk_remove(deleted_recipes)
return PersistenceResult(
{
"success": True,
"deleted": deleted_recipes,
"failed": failed_recipes,
"total_deleted": len(deleted_recipes),
"total_failed": len(failed_recipes),
}
)
async def save_recipe_from_widget(
self,
*,
recipe_scanner,
metadata: dict[str, Any],
image_bytes: bytes,
) -> PersistenceResult:
"""Save a recipe constructed from widget metadata."""
if not metadata:
raise RecipeValidationError("No generation metadata found")
recipes_dir = recipe_scanner.recipes_dir
os.makedirs(recipes_dir, exist_ok=True)
recipe_id = str(uuid.uuid4())
image_filename = f"{recipe_id}.png"
image_path = os.path.join(recipes_dir, image_filename)
with open(image_path, "wb") as file_obj:
file_obj.write(image_bytes)
lora_stack = metadata.get("loras", "")
lora_matches = re.findall(r"<lora:([^:]+):([^>]+)>", lora_stack)
if not lora_matches:
raise RecipeValidationError("No LoRAs found in the generation metadata")
loras_data = []
base_model_counts: Dict[str, int] = {}
for name, strength in lora_matches:
lora_info = await recipe_scanner.get_local_lora(name)
lora_data = {
"file_name": name,
"strength": float(strength),
"hash": (lora_info.get("sha256") or "").lower() if lora_info else "",
"modelVersionId": lora_info.get("civitai", {}).get("id") if lora_info else 0,
"modelName": lora_info.get("civitai", {}).get("model", {}).get("name") if lora_info else "",
"modelVersionName": lora_info.get("civitai", {}).get("name") if lora_info else "",
"isDeleted": False,
"exclude": False,
}
loras_data.append(lora_data)
if lora_info and "base_model" in lora_info:
base_model = lora_info["base_model"]
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
recipe_name = self._derive_recipe_name(lora_matches)
most_common_base_model = (
max(base_model_counts.items(), key=lambda item: item[1])[0] if base_model_counts else ""
)
recipe_data = {
"id": recipe_id,
"file_path": image_path,
"title": recipe_name,
"modified": time.time(),
"created_date": time.time(),
"base_model": most_common_base_model,
"loras": loras_data,
"checkpoint": metadata.get("checkpoint", ""),
"gen_params": {
key: value
for key, value in metadata.items()
if key not in ["checkpoint", "loras"]
},
"loras_stack": lora_stack,
}
json_filename = f"{recipe_id}.recipe.json"
json_path = os.path.join(recipes_dir, json_filename)
with open(json_path, "w", encoding="utf-8") as file_obj:
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
self._exif_utils.append_recipe_metadata(image_path, recipe_data)
await recipe_scanner.add_recipe(recipe_data)
return PersistenceResult(
{
"success": True,
"recipe_id": recipe_id,
"image_path": image_path,
"json_path": json_path,
"recipe_name": recipe_name,
}
)
# Helper methods ---------------------------------------------------
def _resolve_image_bytes(self, image_bytes: bytes | None, image_base64: str | None) -> bytes:
if image_bytes is not None:
return image_bytes
if image_base64:
try:
payload = image_base64.split(",", 1)[1] if "," in image_base64 else image_base64
return base64.b64decode(payload)
except Exception as exc: # pragma: no cover - validation guard
raise RecipeValidationError(f"Invalid base64 image data: {exc}") from exc
raise RecipeValidationError("No image data provided")
def _normalise_lora_entry(self, lora: dict[str, Any]) -> dict[str, Any]:
return {
"file_name": lora.get("file_name", "")
or (
os.path.splitext(os.path.basename(lora.get("localPath", "")))[0]
if lora.get("localPath")
else ""
),
"hash": (lora.get("hash") or "").lower(),
"strength": float(lora.get("weight", 1.0)),
"modelVersionId": lora.get("id", 0),
"modelName": lora.get("name", ""),
"modelVersionName": lora.get("version", ""),
"isDeleted": lora.get("isDeleted", False),
"exclude": lora.get("exclude", False),
}
async def _find_matching_recipes(
self,
recipe_scanner,
fingerprint: str | None,
*,
exclude_id: Optional[str] = None,
) -> list[str]:
if not fingerprint:
return []
matches = await recipe_scanner.find_recipes_by_fingerprint(fingerprint)
if exclude_id and exclude_id in matches:
matches.remove(exclude_id)
return matches
def _derive_recipe_name(self, lora_matches: list[tuple[str, str]]) -> str:
recipe_name_parts = [f"{name.strip()}-{float(strength):.2f}" for name, strength in lora_matches[:3]]
recipe_name = "_".join(recipe_name_parts)
return recipe_name or "recipe"

View File

@@ -1,105 +0,0 @@
"""Services handling recipe sharing and downloads."""
from __future__ import annotations
import os
import shutil
import tempfile
import time
from dataclasses import dataclass
from typing import Any, Dict
from .errors import RecipeNotFoundError
@dataclass(frozen=True)
class SharingResult:
"""Return payload for share operations."""
payload: dict[str, Any]
status: int = 200
@dataclass(frozen=True)
class DownloadInfo:
"""Information required to stream a shared recipe file."""
file_path: str
download_filename: str
class RecipeSharingService:
"""Prepare temporary recipe downloads with TTL cleanup."""
def __init__(self, *, ttl_seconds: int = 300, logger) -> None:
self._ttl_seconds = ttl_seconds
self._logger = logger
self._shared_recipes: Dict[str, Dict[str, Any]] = {}
async def share_recipe(self, *, recipe_scanner, recipe_id: str) -> SharingResult:
"""Prepare a temporary downloadable copy of a recipe image."""
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
if not recipe:
raise RecipeNotFoundError("Recipe not found")
image_path = recipe.get("file_path")
if not image_path or not os.path.exists(image_path):
raise RecipeNotFoundError("Recipe image not found")
ext = os.path.splitext(image_path)[1]
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as temp_file:
temp_path = temp_file.name
shutil.copy2(image_path, temp_path)
timestamp = int(time.time())
self._shared_recipes[recipe_id] = {
"path": temp_path,
"timestamp": timestamp,
"expires": time.time() + self._ttl_seconds,
}
self._cleanup_shared_recipes()
safe_title = recipe.get("title", "").replace(" ", "_").lower()
filename = f"recipe_{safe_title}{ext}" if safe_title else f"recipe_{recipe_id}{ext}"
url_path = f"/api/recipe/{recipe_id}/share/download?t={timestamp}"
return SharingResult({"success": True, "download_url": url_path, "filename": filename})
async def prepare_download(self, *, recipe_scanner, recipe_id: str) -> DownloadInfo:
"""Return file path and filename for a prepared shared recipe."""
shared_info = self._shared_recipes.get(recipe_id)
if not shared_info or time.time() > shared_info.get("expires", 0):
self._cleanup_entry(recipe_id)
raise RecipeNotFoundError("Shared recipe not found or expired")
file_path = shared_info["path"]
if not os.path.exists(file_path):
self._cleanup_entry(recipe_id)
raise RecipeNotFoundError("Shared recipe file not found")
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
filename_base = (
f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}" if recipe else recipe_id
)
ext = os.path.splitext(file_path)[1]
download_filename = f"{filename_base}{ext}"
return DownloadInfo(file_path=file_path, download_filename=download_filename)
def _cleanup_shared_recipes(self) -> None:
for recipe_id in list(self._shared_recipes.keys()):
shared = self._shared_recipes.get(recipe_id)
if not shared:
continue
if time.time() > shared.get("expires", 0):
self._cleanup_entry(recipe_id)
def _cleanup_entry(self, recipe_id: str) -> None:
shared_info = self._shared_recipes.pop(recipe_id, None)
if not shared_info:
return
file_path = shared_info.get("path")
if file_path and os.path.exists(file_path):
try:
os.unlink(file_path)
except Exception as exc: # pragma: no cover - defensive logging
self._logger.error("Error cleaning up shared recipe %s: %s", recipe_id, exc)

View File

@@ -1,114 +0,0 @@
import os
import json
import logging
from typing import Dict, Any, Optional
logger = logging.getLogger(__name__)
class ServerI18nManager:
"""Server-side internationalization manager for template rendering"""
def __init__(self):
self.translations = {}
self.current_locale = 'en'
self._load_translations()
def _load_translations(self):
"""Load all translation files from the locales directory"""
i18n_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
'locales'
)
if not os.path.exists(i18n_path):
logger.warning(f"I18n directory not found: {i18n_path}")
return
# Load all available locale files
for filename in os.listdir(i18n_path):
if filename.endswith('.json'):
locale_code = filename[:-5] # Remove .json extension
try:
self._load_locale_file(i18n_path, filename, locale_code)
except Exception as e:
logger.error(f"Error loading locale file {filename}: {e}")
def _load_locale_file(self, path: str, filename: str, locale_code: str):
"""Load a single locale JSON file"""
file_path = os.path.join(path, filename)
try:
with open(file_path, 'r', encoding='utf-8') as f:
translations = json.load(f)
self.translations[locale_code] = translations
logger.debug(f"Loaded translations for {locale_code} from {filename}")
except Exception as e:
logger.error(f"Error parsing locale file {filename}: {e}")
def set_locale(self, locale: str):
"""Set the current locale"""
if locale in self.translations:
self.current_locale = locale
else:
logger.warning(f"Locale {locale} not found, using 'en'")
self.current_locale = 'en'
def get_translation(self, key: str, params: Dict[str, Any] = None, **kwargs) -> str:
"""Get translation for a key with optional parameters (supports both dict and keyword args)"""
# Merge kwargs into params for convenience
if params is None:
params = {}
if kwargs:
params = {**params, **kwargs}
if self.current_locale not in self.translations:
return key
# Navigate through nested object using dot notation
keys = key.split('.')
value = self.translations[self.current_locale]
for k in keys:
if isinstance(value, dict) and k in value:
value = value[k]
else:
# Fallback to English if current locale doesn't have the key
if self.current_locale != 'en' and 'en' in self.translations:
en_value = self.translations['en']
for k in keys:
if isinstance(en_value, dict) and k in en_value:
en_value = en_value[k]
else:
return key
value = en_value
else:
return key
break
if not isinstance(value, str):
return key
# Replace parameters if provided
if params:
for param_key, param_value in params.items():
placeholder = f"{{{param_key}}}"
double_placeholder = f"{{{{{param_key}}}}}"
value = value.replace(placeholder, str(param_value))
value = value.replace(double_placeholder, str(param_value))
return value
def get_available_locales(self) -> list:
"""Get list of available locales"""
return list(self.translations.keys())
def create_template_filter(self):
"""Create a Jinja2 filter function for templates"""
def t_filter(key: str, **params) -> str:
return self.get_translation(key, params)
return t_filter
# Create global instance
server_i18n = ServerI18nManager()

View File

@@ -5,43 +5,10 @@ from typing import Any, Dict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_SETTINGS: Dict[str, Any] = {
"civitai_api_key": "",
"language": "en",
"show_only_sfw": False,
"enable_metadata_archive_db": False,
"proxy_enabled": False,
"proxy_host": "",
"proxy_port": "",
"proxy_username": "",
"proxy_password": "",
"proxy_type": "http",
"default_lora_root": "",
"default_checkpoint_root": "",
"default_embedding_root": "",
"base_model_path_mappings": {},
"download_path_templates": {},
"example_images_path": "",
"optimize_example_images": True,
"auto_download_example_images": False,
"blur_mature_content": True,
"autoplay_on_hover": False,
"display_density": "default",
"card_info_display": "always",
"include_trigger_words": False,
"compact_mode": False,
}
class SettingsManager: class SettingsManager:
def __init__(self): def __init__(self):
self.settings_file = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'settings.json') self.settings_file = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'settings.json')
self.settings = self._load_settings() self.settings = self._load_settings()
self._migrate_setting_keys()
self._ensure_default_settings()
self._migrate_download_path_template()
self._auto_set_default_roots()
self._check_environment_variables() self._check_environment_variables()
def _load_settings(self) -> Dict[str, Any]: def _load_settings(self) -> Dict[str, Any]:
@@ -54,84 +21,6 @@ class SettingsManager:
logger.error(f"Error loading settings: {e}") logger.error(f"Error loading settings: {e}")
return self._get_default_settings() return self._get_default_settings()
def _ensure_default_settings(self) -> None:
"""Ensure all default settings keys exist"""
updated = False
for key, value in self._get_default_settings().items():
if key not in self.settings:
if isinstance(value, dict):
self.settings[key] = value.copy()
else:
self.settings[key] = value
updated = True
if updated:
self._save_settings()
def _migrate_setting_keys(self) -> None:
"""Migrate legacy camelCase setting keys to snake_case"""
key_migrations = {
'optimizeExampleImages': 'optimize_example_images',
'autoDownloadExampleImages': 'auto_download_example_images',
'blurMatureContent': 'blur_mature_content',
'autoplayOnHover': 'autoplay_on_hover',
'displayDensity': 'display_density',
'cardInfoDisplay': 'card_info_display',
'includeTriggerWords': 'include_trigger_words',
'compactMode': 'compact_mode',
}
updated = False
for old_key, new_key in key_migrations.items():
if old_key in self.settings:
if new_key not in self.settings:
self.settings[new_key] = self.settings[old_key]
del self.settings[old_key]
updated = True
if updated:
logger.info("Migrated legacy setting keys to snake_case")
self._save_settings()
def _migrate_download_path_template(self):
"""Migrate old download_path_template to new download_path_templates"""
old_template = self.settings.get('download_path_template')
templates = self.settings.get('download_path_templates')
# If old template exists and new templates don't exist, migrate
if old_template is not None and not templates:
logger.info("Migrating download_path_template to download_path_templates")
self.settings['download_path_templates'] = {
'lora': old_template,
'checkpoint': old_template,
'embedding': old_template
}
# Remove old setting
del self.settings['download_path_template']
self._save_settings()
logger.info("Migration completed")
def _auto_set_default_roots(self):
"""Auto set default root paths if only one folder is present and default is empty."""
folder_paths = self.settings.get('folder_paths', {})
updated = False
# loras
loras = folder_paths.get('loras', [])
if isinstance(loras, list) and len(loras) == 1 and not self.settings.get('default_lora_root'):
self.settings['default_lora_root'] = loras[0]
updated = True
# checkpoints
checkpoints = folder_paths.get('checkpoints', [])
if isinstance(checkpoints, list) and len(checkpoints) == 1 and not self.settings.get('default_checkpoint_root'):
self.settings['default_checkpoint_root'] = checkpoints[0]
updated = True
# embeddings
embeddings = folder_paths.get('embeddings', [])
if isinstance(embeddings, list) and len(embeddings) == 1 and not self.settings.get('default_embedding_root'):
self.settings['default_embedding_root'] = embeddings[0]
updated = True
if updated:
self._save_settings()
def _check_environment_variables(self) -> None: def _check_environment_variables(self) -> None:
"""Check for environment variables and update settings if needed""" """Check for environment variables and update settings if needed"""
env_api_key = os.environ.get('CIVITAI_API_KEY') env_api_key = os.environ.get('CIVITAI_API_KEY')
@@ -147,11 +36,10 @@ class SettingsManager:
def _get_default_settings(self) -> Dict[str, Any]: def _get_default_settings(self) -> Dict[str, Any]:
"""Return default settings""" """Return default settings"""
defaults = DEFAULT_SETTINGS.copy() return {
# Ensure nested dicts are independent copies "civitai_api_key": "",
defaults['base_model_path_mappings'] = {} "show_only_sfw": False
defaults['download_path_templates'] = {} }
return defaults
def get(self, key: str, default: Any = None) -> Any: def get(self, key: str, default: Any = None) -> Any:
"""Get setting value""" """Get setting value"""
@@ -162,13 +50,6 @@ class SettingsManager:
self.settings[key] = value self.settings[key] = value
self._save_settings() self._save_settings()
def delete(self, key: str) -> None:
"""Delete setting key and save"""
if key in self.settings:
del self.settings[key]
self._save_settings()
logger.info(f"Deleted setting: {key}")
def _save_settings(self) -> None: def _save_settings(self) -> None:
"""Save settings to file""" """Save settings to file"""
try: try:
@@ -177,53 +58,4 @@ class SettingsManager:
except Exception as e: except Exception as e:
logger.error(f"Error saving settings: {e}") logger.error(f"Error saving settings: {e}")
def get_download_path_template(self, model_type: str) -> str:
"""Get download path template for specific model type
Args:
model_type: The type of model ('lora', 'checkpoint', 'embedding')
Returns:
Template string for the model type, defaults to '{base_model}/{first_tag}'
"""
templates = self.settings.get('download_path_templates', {})
# Handle edge case where templates might be stored as JSON string
if isinstance(templates, str):
try:
# Try to parse JSON string
parsed_templates = json.loads(templates)
if isinstance(parsed_templates, dict):
# Update settings with parsed dictionary
self.settings['download_path_templates'] = parsed_templates
self._save_settings()
templates = parsed_templates
logger.info("Successfully parsed download_path_templates from JSON string")
else:
raise ValueError("Parsed JSON is not a dictionary")
except (json.JSONDecodeError, ValueError) as e:
# If parsing fails, set default values
logger.warning(f"Failed to parse download_path_templates JSON string: {e}. Setting default values.")
default_template = '{base_model}/{first_tag}'
templates = {
'lora': default_template,
'checkpoint': default_template,
'embedding': default_template
}
self.settings['download_path_templates'] = templates
self._save_settings()
# Ensure templates is a dictionary
if not isinstance(templates, dict):
default_template = '{base_model}/{first_tag}'
templates = {
'lora': default_template,
'checkpoint': default_template,
'embedding': default_template
}
self.settings['download_path_templates'] = templates
self._save_settings()
return templates.get(model_type, '{base_model}/{first_tag}')
settings = SettingsManager() settings = SettingsManager()

View File

@@ -1,47 +0,0 @@
"""Service for updating tag collections on metadata records."""
from __future__ import annotations
import os
from typing import Awaitable, Callable, Dict, List, Sequence
class TagUpdateService:
"""Encapsulate tag manipulation for models."""
def __init__(self, *, metadata_manager) -> None:
self._metadata_manager = metadata_manager
async def add_tags(
self,
*,
file_path: str,
new_tags: Sequence[str],
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
update_cache: Callable[[str, str, Dict[str, object]], Awaitable[bool]],
) -> List[str]:
"""Add tags to a metadata entry while keeping case-insensitive uniqueness."""
base, _ = os.path.splitext(file_path)
metadata_path = f"{base}.metadata.json"
metadata = await metadata_loader(metadata_path)
existing_tags = list(metadata.get("tags", []))
existing_lower = [tag.lower() for tag in existing_tags]
tags_added: List[str] = []
for tag in new_tags:
if isinstance(tag, str) and tag.strip():
normalized = tag.strip()
if normalized.lower() not in existing_lower:
existing_tags.append(normalized)
existing_lower.append(normalized.lower())
tags_added.append(normalized)
metadata["tags"] = existing_tags
await self._metadata_manager.save_metadata(file_path, metadata)
await update_cache(file_path, file_path, metadata)
return existing_tags

View File

@@ -1,37 +0,0 @@
"""Application-level orchestration services for model routes."""
from .auto_organize_use_case import (
AutoOrganizeInProgressError,
AutoOrganizeUseCase,
)
from .bulk_metadata_refresh_use_case import (
BulkMetadataRefreshUseCase,
MetadataRefreshProgressReporter,
)
from .download_model_use_case import (
DownloadModelEarlyAccessError,
DownloadModelUseCase,
DownloadModelValidationError,
)
from .example_images import (
DownloadExampleImagesConfigurationError,
DownloadExampleImagesInProgressError,
DownloadExampleImagesUseCase,
ImportExampleImagesUseCase,
ImportExampleImagesValidationError,
)
__all__ = [
"AutoOrganizeInProgressError",
"AutoOrganizeUseCase",
"BulkMetadataRefreshUseCase",
"MetadataRefreshProgressReporter",
"DownloadModelEarlyAccessError",
"DownloadModelUseCase",
"DownloadModelValidationError",
"DownloadExampleImagesConfigurationError",
"DownloadExampleImagesInProgressError",
"DownloadExampleImagesUseCase",
"ImportExampleImagesUseCase",
"ImportExampleImagesValidationError",
]

View File

@@ -1,56 +0,0 @@
"""Auto-organize use case orchestrating concurrency and progress handling."""
from __future__ import annotations
import asyncio
from typing import Optional, Protocol, Sequence
from ..model_file_service import AutoOrganizeResult, ModelFileService, ProgressCallback
class AutoOrganizeLockProvider(Protocol):
"""Minimal protocol for objects exposing auto-organize locking primitives."""
def is_auto_organize_running(self) -> bool:
"""Return ``True`` when an auto-organize operation is in-flight."""
async def get_auto_organize_lock(self) -> asyncio.Lock:
"""Return the asyncio lock guarding auto-organize operations."""
class AutoOrganizeInProgressError(RuntimeError):
"""Raised when an auto-organize run is already active."""
class AutoOrganizeUseCase:
"""Coordinate auto-organize execution behind a shared lock."""
def __init__(
self,
*,
file_service: ModelFileService,
lock_provider: AutoOrganizeLockProvider,
) -> None:
self._file_service = file_service
self._lock_provider = lock_provider
async def execute(
self,
*,
file_paths: Optional[Sequence[str]] = None,
progress_callback: Optional[ProgressCallback] = None,
) -> AutoOrganizeResult:
"""Run the auto-organize routine guarded by a shared lock."""
if self._lock_provider.is_auto_organize_running():
raise AutoOrganizeInProgressError("Auto-organize is already running")
lock = await self._lock_provider.get_auto_organize_lock()
if lock.locked():
raise AutoOrganizeInProgressError("Auto-organize is already running")
async with lock:
return await self._file_service.auto_organize_models(
file_paths=list(file_paths) if file_paths is not None else None,
progress_callback=progress_callback,
)

View File

@@ -1,122 +0,0 @@
"""Use case encapsulating the bulk metadata refresh orchestration."""
from __future__ import annotations
import logging
from typing import Any, Dict, Optional, Protocol, Sequence
from ..metadata_sync_service import MetadataSyncService
class MetadataRefreshProgressReporter(Protocol):
"""Protocol for progress reporters used during metadata refresh."""
async def on_progress(self, payload: Dict[str, Any]) -> None:
"""Handle a metadata refresh progress update."""
class BulkMetadataRefreshUseCase:
"""Coordinate bulk metadata refreshes with progress emission."""
def __init__(
self,
*,
service,
metadata_sync: MetadataSyncService,
settings_service,
logger: Optional[logging.Logger] = None,
) -> None:
self._service = service
self._metadata_sync = metadata_sync
self._settings = settings_service
self._logger = logger or logging.getLogger(__name__)
async def execute(
self,
*,
progress_callback: Optional[MetadataRefreshProgressReporter] = None,
) -> Dict[str, Any]:
"""Refresh metadata for all qualifying models."""
cache = await self._service.scanner.get_cached_data()
total_models = len(cache.raw_data)
enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False)
to_process: Sequence[Dict[str, Any]] = [
model
for model in cache.raw_data
if model.get("sha256")
and (not model.get("civitai") or not model["civitai"].get("id"))
and (
(enable_metadata_archive_db and not model.get("db_checked", False))
or (not enable_metadata_archive_db and model.get("from_civitai") is True)
)
]
total_to_process = len(to_process)
processed = 0
success = 0
needs_resort = False
async def emit(status: str, **extra: Any) -> None:
if progress_callback is None:
return
payload = {"status": status, "total": total_to_process, "processed": processed, "success": success}
payload.update(extra)
await progress_callback.on_progress(payload)
await emit("started")
for model in to_process:
try:
original_name = model.get("model_name")
result, _ = await self._metadata_sync.fetch_and_update_model(
sha256=model["sha256"],
file_path=model["file_path"],
model_data=model,
update_cache_func=self._service.scanner.update_single_model_cache,
)
if result:
success += 1
if original_name != model.get("model_name"):
needs_resort = True
processed += 1
await emit(
"processing",
processed=processed,
success=success,
current_name=model.get("model_name", "Unknown"),
)
except Exception as exc: # pragma: no cover - logging path
processed += 1
self._logger.error(
"Error fetching CivitAI data for %s: %s",
model.get("file_path"),
exc,
)
if needs_resort:
await cache.resort()
await emit("completed", processed=processed, success=success)
message = (
"Successfully updated "
f"{success} of {processed} processed {self._service.model_type}s (total: {total_models})"
)
return {"success": True, "message": message, "processed": processed, "updated": success, "total": total_models}
async def execute_with_error_handling(
self,
*,
progress_callback: Optional[MetadataRefreshProgressReporter] = None,
) -> Dict[str, Any]:
"""Wrapper providing progress notification on unexpected failures."""
try:
return await self.execute(progress_callback=progress_callback)
except Exception as exc:
if progress_callback is not None:
await progress_callback.on_progress({"status": "error", "error": str(exc)})
raise

View File

@@ -1,37 +0,0 @@
"""Use case for scheduling model downloads with consistent error handling."""
from __future__ import annotations
from typing import Any, Dict
from ..download_coordinator import DownloadCoordinator
class DownloadModelValidationError(ValueError):
"""Raised when incoming payload validation fails."""
class DownloadModelEarlyAccessError(RuntimeError):
"""Raised when the download is gated behind Civitai early access."""
class DownloadModelUseCase:
"""Coordinate download scheduling through the coordinator service."""
def __init__(self, *, download_coordinator: DownloadCoordinator) -> None:
self._download_coordinator = download_coordinator
async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""Schedule a download and normalize error conditions."""
try:
return await self._download_coordinator.schedule_download(payload)
except ValueError as exc:
raise DownloadModelValidationError(str(exc)) from exc
except Exception as exc: # pragma: no cover - defensive logging path
message = str(exc)
if "401" in message:
raise DownloadModelEarlyAccessError(
"Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
) from exc
raise

View File

@@ -1,19 +0,0 @@
"""Example image specific use case exports."""
from .download_example_images_use_case import (
DownloadExampleImagesUseCase,
DownloadExampleImagesInProgressError,
DownloadExampleImagesConfigurationError,
)
from .import_example_images_use_case import (
ImportExampleImagesUseCase,
ImportExampleImagesValidationError,
)
__all__ = [
"DownloadExampleImagesUseCase",
"DownloadExampleImagesInProgressError",
"DownloadExampleImagesConfigurationError",
"ImportExampleImagesUseCase",
"ImportExampleImagesValidationError",
]

View File

@@ -1,42 +0,0 @@
"""Use case coordinating example image downloads."""
from __future__ import annotations
from typing import Any, Dict
from ....utils.example_images_download_manager import (
DownloadConfigurationError,
DownloadInProgressError,
ExampleImagesDownloadError,
)
class DownloadExampleImagesInProgressError(RuntimeError):
"""Raised when a download is already running."""
def __init__(self, progress: Dict[str, Any]) -> None:
super().__init__("Download already in progress")
self.progress = progress
class DownloadExampleImagesConfigurationError(ValueError):
"""Raised when settings prevent downloads from starting."""
class DownloadExampleImagesUseCase:
"""Validate payloads and trigger the download manager."""
def __init__(self, *, download_manager) -> None:
self._download_manager = download_manager
async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
"""Start a download and translate manager errors."""
try:
return await self._download_manager.start_download(payload)
except DownloadInProgressError as exc:
raise DownloadExampleImagesInProgressError(exc.progress_snapshot) from exc
except DownloadConfigurationError as exc:
raise DownloadExampleImagesConfigurationError(str(exc)) from exc
except ExampleImagesDownloadError:
raise

View File

@@ -1,86 +0,0 @@
"""Use case for importing example images."""
from __future__ import annotations
import os
import tempfile
from contextlib import suppress
from typing import Any, Dict, List
from aiohttp import web
from ....utils.example_images_processor import (
ExampleImagesImportError,
ExampleImagesProcessor,
ExampleImagesValidationError,
)
class ImportExampleImagesValidationError(ValueError):
"""Raised when request validation fails."""
class ImportExampleImagesUseCase:
"""Parse upload payloads and delegate to the processor service."""
def __init__(self, *, processor: ExampleImagesProcessor) -> None:
self._processor = processor
async def execute(self, request: web.Request) -> Dict[str, Any]:
model_hash: str | None = None
files_to_import: List[str] = []
temp_files: List[str] = []
try:
if request.content_type and "multipart/form-data" in request.content_type:
reader = await request.multipart()
first_field = await reader.next()
if first_field and first_field.name == "model_hash":
model_hash = await first_field.text()
else:
# Support clients that send files first and hash later
if first_field is not None:
await self._collect_upload_file(first_field, files_to_import, temp_files)
async for field in reader:
if field.name == "model_hash" and not model_hash:
model_hash = await field.text()
elif field.name == "files":
await self._collect_upload_file(field, files_to_import, temp_files)
else:
data = await request.json()
model_hash = data.get("model_hash")
files_to_import = list(data.get("file_paths", []))
result = await self._processor.import_images(model_hash, files_to_import)
return result
except ExampleImagesValidationError as exc:
raise ImportExampleImagesValidationError(str(exc)) from exc
except ExampleImagesImportError:
raise
finally:
for path in temp_files:
with suppress(Exception):
os.remove(path)
async def _collect_upload_file(
self,
field: Any,
files_to_import: List[str],
temp_files: List[str],
) -> None:
"""Persist an uploaded file to disk and add it to the import list."""
filename = field.filename or "upload"
file_ext = os.path.splitext(filename)[1].lower()
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
temp_files.append(tmp_file.name)
while True:
chunk = await field.read_chunk()
if not chunk:
break
tmp_file.write(chunk)
files_to_import.append(tmp_file.name)

View File

@@ -16,9 +16,6 @@ class WebSocketManager:
self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients
# Add progress tracking dictionary # Add progress tracking dictionary
self._download_progress: Dict[str, Dict] = {} self._download_progress: Dict[str, Dict] = {}
# Add auto-organize progress tracking
self._auto_organize_progress: Optional[Dict] = None
self._auto_organize_lock = asyncio.Lock()
async def handle_connection(self, request: web.Request) -> web.WebSocketResponse: async def handle_connection(self, request: web.Request) -> web.WebSocketResponse:
"""Handle new WebSocket connection""" """Handle new WebSocket connection"""
@@ -137,33 +134,6 @@ class WebSocketManager:
except Exception as e: except Exception as e:
logger.error(f"Error sending download progress: {e}") logger.error(f"Error sending download progress: {e}")
async def broadcast_auto_organize_progress(self, data: Dict):
"""Broadcast auto-organize progress to connected clients"""
# Store progress data in memory
self._auto_organize_progress = data
# Broadcast via WebSocket
await self.broadcast(data)
def get_auto_organize_progress(self) -> Optional[Dict]:
"""Get current auto-organize progress"""
return self._auto_organize_progress
def cleanup_auto_organize_progress(self):
"""Clear auto-organize progress data"""
self._auto_organize_progress = None
def is_auto_organize_running(self) -> bool:
"""Check if auto-organize is currently running"""
if not self._auto_organize_progress:
return False
status = self._auto_organize_progress.get('status')
return status in ['started', 'processing', 'cleaning']
async def get_auto_organize_lock(self):
"""Get the auto-organize lock"""
return self._auto_organize_lock
def get_download_progress(self, download_id: str) -> Optional[Dict]: def get_download_progress(self, download_id: str) -> Optional[Dict]:
"""Get progress information for a specific download""" """Get progress information for a specific download"""
return self._download_progress.get(download_id) return self._download_progress.get(download_id)

View File

@@ -1,29 +0,0 @@
"""Progress callback implementations backed by the shared WebSocket manager."""
from typing import Any, Dict, Protocol
from .model_file_service import ProgressCallback
from .websocket_manager import ws_manager
class ProgressReporter(Protocol):
"""Protocol representing an async progress callback."""
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
"""Handle a progress update payload."""
class WebSocketProgressCallback(ProgressCallback):
"""WebSocket implementation of progress callback."""
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
"""Send progress data via WebSocket."""
await ws_manager.broadcast_auto_organize_progress(progress_data)
class WebSocketBroadcastCallback:
"""Generic WebSocket progress callback broadcasting to all clients."""
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
"""Send the provided payload to all connected clients."""
await ws_manager.broadcast(progress_data)

View File

@@ -48,13 +48,9 @@ SUPPORTED_MEDIA_EXTENSIONS = {
# Valid Lora types # Valid Lora types
VALID_LORA_TYPES = ['lora', 'locon', 'dora'] VALID_LORA_TYPES = ['lora', 'locon', 'dora']
# Auto-organize settings
AUTO_ORGANIZE_BATCH_SIZE = 50 # Process models in batches to avoid overwhelming the system
# Civitai model tags in priority order for subfolder organization # Civitai model tags in priority order for subfolder organization
CIVITAI_MODEL_TAGS = [ CIVITAI_MODEL_TAGS = [
'character', 'concept', 'clothing', 'character', 'style', 'concept', 'clothing', 'base model',
'realistic', 'anime', 'toon', 'furry', 'style',
'poses', 'background', 'tool', 'vehicle', 'buildings', 'poses', 'background', 'tool', 'vehicle', 'buildings',
'objects', 'assets', 'animal', 'action' 'objects', 'assets', 'animal', 'action'
] ]

View File

@@ -1,217 +1,208 @@
from __future__ import annotations
import logging import logging
import os import os
import asyncio import asyncio
import json import json
import time import time
from typing import Any, Dict import aiohttp
from aiohttp import web
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..utils.metadata_manager import MetadataManager
from .example_images_processor import ExampleImagesProcessor from .example_images_processor import ExampleImagesProcessor
from .example_images_metadata import MetadataUpdater from .example_images_metadata import MetadataUpdater
from ..services.downloader import get_downloader
from ..services.settings_manager import settings
class ExampleImagesDownloadError(RuntimeError):
"""Base error for example image download operations."""
class DownloadInProgressError(ExampleImagesDownloadError):
"""Raised when a download is already running."""
def __init__(self, progress_snapshot: dict) -> None:
super().__init__("Download already in progress")
self.progress_snapshot = progress_snapshot
class DownloadNotRunningError(ExampleImagesDownloadError):
"""Raised when pause/resume is requested without an active download."""
def __init__(self, message: str = "No download in progress") -> None:
super().__init__(message)
class DownloadConfigurationError(ExampleImagesDownloadError):
"""Raised when configuration prevents starting a download."""
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Download status tracking
class _DownloadProgress(dict): download_task = None
"""Mutable mapping maintaining download progress with set-aware serialisation.""" is_downloading = False
download_progress = {
def __init__(self) -> None: 'total': 0,
super().__init__() 'completed': 0,
self.reset() 'current_model': '',
'status': 'idle', # idle, running, paused, completed, error
def reset(self) -> None: 'errors': [],
"""Reset the progress dictionary to its initial state.""" 'last_error': None,
'start_time': None,
self.update( 'end_time': None,
total=0, 'processed_models': set(), # Track models that have been processed
completed=0, 'refreshed_models': set() # Track models that had metadata refreshed
current_model='', }
status='idle',
errors=[],
last_error=None,
start_time=None,
end_time=None,
processed_models=set(),
refreshed_models=set(),
failed_models=set(),
)
def snapshot(self) -> dict:
"""Return a JSON-serialisable snapshot of the current progress."""
snapshot = dict(self)
snapshot['processed_models'] = list(self['processed_models'])
snapshot['refreshed_models'] = list(self['refreshed_models'])
snapshot['failed_models'] = list(self['failed_models'])
return snapshot
class DownloadManager: class DownloadManager:
"""Manages downloading example images for models.""" """Manages downloading example images for models"""
def __init__(self, *, ws_manager, state_lock: asyncio.Lock | None = None) -> None: @staticmethod
self._download_task: asyncio.Task | None = None async def start_download(request):
self._is_downloading = False """
self._progress = _DownloadProgress() Start downloading example images for models
self._ws_manager = ws_manager
self._state_lock = state_lock or asyncio.Lock()
async def start_download(self, options: dict): Expects a JSON body with:
"""Start downloading example images for models.""" {
"output_dir": "path/to/output", # Base directory to save example images
"optimize": true, # Whether to optimize images (default: true)
"model_types": ["lora", "checkpoint"], # Model types to process (default: both)
"delay": 1.0 # Delay between downloads to avoid rate limiting (default: 1.0)
}
"""
global download_task, is_downloading, download_progress
async with self._state_lock: if is_downloading:
if self._is_downloading: # Create a copy for JSON serialization
raise DownloadInProgressError(self._progress.snapshot()) response_progress = download_progress.copy()
response_progress['processed_models'] = list(download_progress['processed_models'])
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
try: return web.json_response({
data = options or {} 'success': False,
auto_mode = data.get('auto_mode', False) 'error': 'Download already in progress',
optimize = data.get('optimize', True) 'status': response_progress
model_types = data.get('model_types', ['lora', 'checkpoint']) }, status=400)
delay = float(data.get('delay', 0.2))
output_dir = settings.get('example_images_path') try:
# Parse the request body
data = await request.json()
output_dir = data.get('output_dir')
optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
if not output_dir: if not output_dir:
error_msg = 'Example images path not configured in settings' return web.json_response({
if auto_mode: 'success': False,
logger.debug(error_msg) 'error': 'Missing output_dir parameter'
return { }, status=400)
'success': True,
'message': 'Example images path not configured, skipping auto download'
}
raise DownloadConfigurationError(error_msg)
os.makedirs(output_dir, exist_ok=True) # Create the output directory
os.makedirs(output_dir, exist_ok=True)
self._progress.reset() # Initialize progress tracking
self._progress['status'] = 'running' download_progress['total'] = 0
self._progress['start_time'] = time.time() download_progress['completed'] = 0
self._progress['end_time'] = None download_progress['current_model'] = ''
download_progress['status'] = 'running'
download_progress['errors'] = []
download_progress['last_error'] = None
download_progress['start_time'] = time.time()
download_progress['end_time'] = None
progress_file = os.path.join(output_dir, '.download_progress.json') # Get the processed models list from a file if it exists
if os.path.exists(progress_file): progress_file = os.path.join(output_dir, '.download_progress.json')
try: if os.path.exists(progress_file):
with open(progress_file, 'r', encoding='utf-8') as f: try:
saved_progress = json.load(f) with open(progress_file, 'r', encoding='utf-8') as f:
self._progress['processed_models'] = set(saved_progress.get('processed_models', [])) saved_progress = json.load(f)
self._progress['failed_models'] = set(saved_progress.get('failed_models', [])) download_progress['processed_models'] = set(saved_progress.get('processed_models', []))
logger.debug( logger.info(f"Loaded previous progress, {len(download_progress['processed_models'])} models already processed")
"Loaded previous progress, %s models already processed, %s models marked as failed", except Exception as e:
len(self._progress['processed_models']), logger.error(f"Failed to load progress file: {e}")
len(self._progress['failed_models']), download_progress['processed_models'] = set()
) else:
except Exception as e: download_progress['processed_models'] = set()
logger.error(f"Failed to load progress file: {e}")
self._progress['processed_models'] = set()
self._progress['failed_models'] = set()
else:
self._progress['processed_models'] = set()
self._progress['failed_models'] = set()
self._is_downloading = True # Start the download task
self._download_task = asyncio.create_task( is_downloading = True
self._download_all_example_images( download_task = asyncio.create_task(
output_dir, DownloadManager._download_all_example_images(
optimize, output_dir,
model_types, optimize,
delay model_types,
) delay
) )
)
snapshot = self._progress.snapshot() # Create a copy for JSON serialization
except Exception as e: response_progress = download_progress.copy()
self._is_downloading = False response_progress['processed_models'] = list(download_progress['processed_models'])
self._download_task = None response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
logger.error(f"Failed to start example images download: {e}", exc_info=True)
raise ExampleImagesDownloadError(str(e)) from e
await self._broadcast_progress(status='running') return web.json_response({
'success': True,
'message': 'Download started',
'status': response_progress
})
return { except Exception as e:
logger.error(f"Failed to start example images download: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def get_status(request):
"""Get the current status of example images download"""
global download_progress
# Create a copy of the progress dict with the set converted to a list for JSON serialization
response_progress = download_progress.copy()
response_progress['processed_models'] = list(download_progress['processed_models'])
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
return web.json_response({
'success': True, 'success': True,
'message': 'Download started', 'is_downloading': is_downloading,
'status': snapshot 'status': response_progress
} })
async def get_status(self, request): @staticmethod
"""Get the current status of example images download.""" async def pause_download(request):
"""Pause the example images download"""
global download_progress
return { if not is_downloading:
'success': True, return web.json_response({
'is_downloading': self._is_downloading, 'success': False,
'status': self._progress.snapshot(), 'error': 'No download in progress'
} }, status=400)
async def pause_download(self, request): download_progress['status'] = 'paused'
"""Pause the example images download."""
async with self._state_lock: return web.json_response({
if not self._is_downloading:
raise DownloadNotRunningError()
self._progress['status'] = 'paused'
await self._broadcast_progress(status='paused')
return {
'success': True, 'success': True,
'message': 'Download paused' 'message': 'Download paused'
} })
async def resume_download(self, request): @staticmethod
"""Resume the example images download.""" async def resume_download(request):
"""Resume the example images download"""
global download_progress
async with self._state_lock: if not is_downloading:
if not self._is_downloading: return web.json_response({
raise DownloadNotRunningError() 'success': False,
'error': 'No download in progress'
}, status=400)
if self._progress['status'] == 'paused': if download_progress['status'] == 'paused':
self._progress['status'] = 'running' download_progress['status'] = 'running'
else:
raise DownloadNotRunningError(
f"Download is in '{self._progress['status']}' state, cannot resume"
)
await self._broadcast_progress(status='running') return web.json_response({
'success': True,
'message': 'Download resumed'
})
else:
return web.json_response({
'success': False,
'error': f"Download is in '{download_progress['status']}' state, cannot resume"
}, status=400)
return { @staticmethod
'success': True, async def _download_all_example_images(output_dir, optimize, model_types, delay):
'message': 'Download resumed' """Download example images for all models"""
} global is_downloading, download_progress
async def _download_all_example_images(self, output_dir, optimize, model_types, delay): # Create independent download session
"""Download example images for all models.""" connector = aiohttp.TCPConnector(
ssl=True,
downloader = await get_downloader() limit=3,
force_close=False,
enable_cleanup_closed=True
)
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
independent_session = aiohttp.ClientSession(
connector=connector,
trust_env=True,
timeout=timeout
)
try: try:
# Get scanners # Get scanners
@@ -238,67 +229,65 @@ class DownloadManager:
all_models.append((scanner_type, model, scanner)) all_models.append((scanner_type, model, scanner))
# Update total count # Update total count
self._progress['total'] = len(all_models) download_progress['total'] = len(all_models)
logger.debug(f"Found {self._progress['total']} models to process") logger.info(f"Found {download_progress['total']} models to process")
await self._broadcast_progress(status='running')
# Process each model # Process each model
for i, (scanner_type, model, scanner) in enumerate(all_models): for i, (scanner_type, model, scanner) in enumerate(all_models):
# Main logic for processing model is here, but actual operations are delegated to other classes # Main logic for processing model is here, but actual operations are delegated to other classes
was_remote_download = await self._process_model( was_remote_download = await DownloadManager._process_model(
scanner_type, model, scanner, scanner_type, model, scanner,
output_dir, optimize, downloader output_dir, optimize, independent_session
) )
# Update progress # Update progress
self._progress['completed'] += 1 download_progress['completed'] += 1
await self._broadcast_progress(status='running')
# Only add delay after remote download of models, and not after processing the last model # Only add delay after remote download of models, and not after processing the last model
if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running': if was_remote_download and i < len(all_models) - 1 and download_progress['status'] == 'running':
await asyncio.sleep(delay) await asyncio.sleep(delay)
# Mark as completed # Mark as completed
self._progress['status'] = 'completed' download_progress['status'] = 'completed'
self._progress['end_time'] = time.time() download_progress['end_time'] = time.time()
logger.debug( logger.info(f"Example images download completed: {download_progress['completed']}/{download_progress['total']} models processed")
"Example images download completed: %s/%s models processed",
self._progress['completed'],
self._progress['total'],
)
await self._broadcast_progress(status='completed')
except Exception as e: except Exception as e:
error_msg = f"Error during example images download: {str(e)}" error_msg = f"Error during example images download: {str(e)}"
logger.error(error_msg, exc_info=True) logger.error(error_msg, exc_info=True)
self._progress['errors'].append(error_msg) download_progress['errors'].append(error_msg)
self._progress['last_error'] = error_msg download_progress['last_error'] = error_msg
self._progress['status'] = 'error' download_progress['status'] = 'error'
self._progress['end_time'] = time.time() download_progress['end_time'] = time.time()
await self._broadcast_progress(status='error', extra={'error': error_msg})
finally: finally:
# Close the independent session
try:
await independent_session.close()
except Exception as e:
logger.error(f"Error closing download session: {e}")
# Save final progress to file # Save final progress to file
try: try:
self._save_progress(output_dir) DownloadManager._save_progress(output_dir)
except Exception as e: except Exception as e:
logger.error(f"Failed to save progress file: {e}") logger.error(f"Failed to save progress file: {e}")
# Set download status to not downloading # Set download status to not downloading
async with self._state_lock: is_downloading = False
self._is_downloading = False
self._download_task = None
async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader): @staticmethod
"""Process a single model download.""" async def _process_model(scanner_type, model, scanner, output_dir, optimize, independent_session):
"""Process a single model download"""
global download_progress
# Check if download is paused # Check if download is paused
while self._progress['status'] == 'paused': while download_progress['status'] == 'paused':
await asyncio.sleep(1) await asyncio.sleep(1)
# Check if download should continue # Check if download should continue
if self._progress['status'] != 'running': if download_progress['status'] != 'running':
logger.info(f"Download stopped: {self._progress['status']}") logger.info(f"Download stopped: {download_progress['status']}")
return False # Return False to indicate no remote download happened return False # Return False to indicate no remote download happened
model_hash = model.get('sha256', '').lower() model_hash = model.get('sha256', '').lower()
@@ -308,16 +297,10 @@ class DownloadManager:
try: try:
# Update current model info # Update current model info
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" download_progress['current_model'] = f"{model_name} ({model_hash[:8]})"
await self._broadcast_progress(status='running')
# Skip if already in failed models
if model_hash in self._progress['failed_models']:
logger.debug(f"Skipping known failed model: {model_name}")
return False
# Skip if already processed AND directory exists with files # Skip if already processed AND directory exists with files
if model_hash in self._progress['processed_models']: if model_hash in download_progress['processed_models']:
model_dir = os.path.join(output_dir, model_hash) model_dir = os.path.join(output_dir, model_hash)
has_files = os.path.exists(model_dir) and any(os.listdir(model_dir)) has_files = os.path.exists(model_dir) and any(os.listdir(model_dir))
if has_files: if has_files:
@@ -325,8 +308,6 @@ class DownloadManager:
return False return False
else: else:
logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing") logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing")
# Remove from processed models since we need to reprocess
self._progress['processed_models'].discard(model_hash)
# Create model directory # Create model directory
model_dir = os.path.join(output_dir, model_hash) model_dir = os.path.join(output_dir, model_hash)
@@ -342,7 +323,7 @@ class DownloadManager:
await MetadataUpdater.update_metadata_from_local_examples( await MetadataUpdater.update_metadata_from_local_examples(
model_hash, model, scanner_type, scanner, model_dir model_hash, model, scanner_type, scanner, model_dir
) )
self._progress['processed_models'].add(model_hash) download_progress['processed_models'].add(model_hash)
return False # Return False to indicate no remote download happened return False # Return False to indicate no remote download happened
# If no local images, try to download from remote # If no local images, try to download from remote
@@ -350,13 +331,13 @@ class DownloadManager:
images = model.get('civitai', {}).get('images', []) images = model.get('civitai', {}).get('images', [])
success, is_stale = await ExampleImagesProcessor.download_model_images( success, is_stale = await ExampleImagesProcessor.download_model_images(
model_hash, model_name, images, model_dir, optimize, downloader model_hash, model_name, images, model_dir, optimize, independent_session
) )
# If metadata is stale, try to refresh it # If metadata is stale, try to refresh it
if is_stale and model_hash not in self._progress['refreshed_models']: if is_stale and model_hash not in download_progress['refreshed_models']:
await MetadataUpdater.refresh_model_metadata( await MetadataUpdater.refresh_model_metadata(
model_hash, model_name, scanner_type, scanner, self._progress model_hash, model_name, scanner_type, scanner
) )
# Get the updated model data # Get the updated model data
@@ -368,41 +349,32 @@ class DownloadManager:
# Retry download with updated metadata # Retry download with updated metadata
updated_images = updated_model.get('civitai', {}).get('images', []) updated_images = updated_model.get('civitai', {}).get('images', [])
success, _ = await ExampleImagesProcessor.download_model_images( success, _ = await ExampleImagesProcessor.download_model_images(
model_hash, model_name, updated_images, model_dir, optimize, downloader model_hash, model_name, updated_images, model_dir, optimize, independent_session
) )
self._progress['refreshed_models'].add(model_hash) # Only mark as processed if all images were downloaded successfully
# Mark as processed if successful, or as failed if unsuccessful after refresh
if success: if success:
self._progress['processed_models'].add(model_hash) download_progress['processed_models'].add(model_hash)
else:
# If we refreshed metadata and still failed, mark as permanently failed
if model_hash in self._progress['refreshed_models']:
self._progress['failed_models'].add(model_hash)
logger.info(f"Marking model {model_name} as failed after metadata refresh")
return True # Return True to indicate a remote download happened return True # Return True to indicate a remote download happened
else:
# No civitai data or images available, mark as failed to avoid future attempts
self._progress['failed_models'].add(model_hash)
logger.debug(f"No civitai images available for model {model_name}, marking as failed")
# Save progress periodically # Save progress periodically
if self._progress['completed'] % 10 == 0 or self._progress['completed'] == self._progress['total'] - 1: if download_progress['completed'] % 10 == 0 or download_progress['completed'] == download_progress['total'] - 1:
self._save_progress(output_dir) DownloadManager._save_progress(output_dir)
return False # Default return if no conditions met return False # Default return if no conditions met
except Exception as e: except Exception as e:
error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
logger.error(error_msg, exc_info=True) logger.error(error_msg, exc_info=True)
self._progress['errors'].append(error_msg) download_progress['errors'].append(error_msg)
self._progress['last_error'] = error_msg download_progress['last_error'] = error_msg
return False # Return False on exception return False # Return False on exception
def _save_progress(self, output_dir): @staticmethod
"""Save download progress to file.""" def _save_progress(output_dir):
"""Save download progress to file"""
global download_progress
try: try:
progress_file = os.path.join(output_dir, '.download_progress.json') progress_file = os.path.join(output_dir, '.download_progress.json')
@@ -417,11 +389,10 @@ class DownloadManager:
# Create new progress data # Create new progress data
progress_data = { progress_data = {
'processed_models': list(self._progress['processed_models']), 'processed_models': list(download_progress['processed_models']),
'refreshed_models': list(self._progress['refreshed_models']), 'refreshed_models': list(download_progress['refreshed_models']),
'failed_models': list(self._progress['failed_models']), 'completed': download_progress['completed'],
'completed': self._progress['completed'], 'total': download_progress['total'],
'total': self._progress['total'],
'last_update': time.time() 'last_update': time.time()
} }
@@ -435,343 +406,3 @@ class DownloadManager:
json.dump(progress_data, f, indent=2) json.dump(progress_data, f, indent=2)
except Exception as e: except Exception as e:
logger.error(f"Failed to save progress file: {e}") logger.error(f"Failed to save progress file: {e}")
async def start_force_download(self, options: dict):
"""Force download example images for specific models."""
async with self._state_lock:
if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
data = options or {}
model_hashes = data.get('model_hashes', [])
optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2))
if not model_hashes:
raise DownloadConfigurationError('Missing model_hashes parameter')
output_dir = settings.get('example_images_path')
if not output_dir:
raise DownloadConfigurationError('Example images path not configured in settings')
os.makedirs(output_dir, exist_ok=True)
self._progress.reset()
self._progress['total'] = len(model_hashes)
self._progress['status'] = 'running'
self._progress['start_time'] = time.time()
self._progress['end_time'] = None
self._is_downloading = True
await self._broadcast_progress(status='running')
try:
result = await self._download_specific_models_example_images_sync(
model_hashes,
output_dir,
optimize,
model_types,
delay
)
async with self._state_lock:
self._is_downloading = False
return {
'success': True,
'message': 'Force download completed',
'result': result
}
except Exception as e:
async with self._state_lock:
self._is_downloading = False
logger.error(f"Failed during forced example images download: {e}", exc_info=True)
await self._broadcast_progress(status='error', extra={'error': str(e)})
raise ExampleImagesDownloadError(str(e)) from e
async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay):
"""Download example images for specific models only - synchronous version."""
downloader = await get_downloader()
try:
# Get scanners
scanners = []
if 'lora' in model_types:
lora_scanner = await ServiceRegistry.get_lora_scanner()
scanners.append(('lora', lora_scanner))
if 'checkpoint' in model_types:
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
scanners.append(('checkpoint', checkpoint_scanner))
if 'embedding' in model_types:
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
scanners.append(('embedding', embedding_scanner))
# Find the specified models
models_to_process = []
for scanner_type, scanner in scanners:
cache = await scanner.get_cached_data()
if cache and cache.raw_data:
for model in cache.raw_data:
if model.get('sha256') in model_hashes:
models_to_process.append((scanner_type, model, scanner))
# Update total count based on found models
self._progress['total'] = len(models_to_process)
logger.debug(f"Found {self._progress['total']} models to process")
# Send initial progress via WebSocket
await self._broadcast_progress(status='running')
# Process each model
success_count = 0
for i, (scanner_type, model, scanner) in enumerate(models_to_process):
# Force process this model regardless of previous status
was_successful = await self._process_specific_model(
scanner_type, model, scanner,
output_dir, optimize, downloader
)
if was_successful:
success_count += 1
# Update progress
self._progress['completed'] += 1
# Send progress update via WebSocket
await self._broadcast_progress(status='running')
# Only add delay after remote download, and not after processing the last model
if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running':
await asyncio.sleep(delay)
# Mark as completed
self._progress['status'] = 'completed'
self._progress['end_time'] = time.time()
logger.debug(
"Forced example images download completed: %s/%s models processed",
self._progress['completed'],
self._progress['total'],
)
# Send final progress via WebSocket
await self._broadcast_progress(status='completed')
return {
'total': self._progress['total'],
'processed': self._progress['completed'],
'successful': success_count,
'errors': self._progress['errors']
}
except Exception as e:
error_msg = f"Error during forced example images download: {str(e)}"
logger.error(error_msg, exc_info=True)
self._progress['errors'].append(error_msg)
self._progress['last_error'] = error_msg
self._progress['status'] = 'error'
self._progress['end_time'] = time.time()
# Send error status via WebSocket
await self._broadcast_progress(status='error', extra={'error': error_msg})
raise
finally:
# No need to close any sessions since we use the global downloader
pass
async def _process_specific_model(self, scanner_type, model, scanner, output_dir, optimize, downloader):
"""Process a specific model for forced download, ignoring previous download status."""
# Check if download is paused
while self._progress['status'] == 'paused':
await asyncio.sleep(1)
# Check if download should continue
if self._progress['status'] != 'running':
logger.info(f"Download stopped: {self._progress['status']}")
return False
model_hash = model.get('sha256', '').lower()
model_name = model.get('model_name', 'Unknown')
model_file_path = model.get('file_path', '')
model_file_name = model.get('file_name', '')
try:
# Update current model info
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
await self._broadcast_progress(status='running')
# Create model directory
model_dir = os.path.join(output_dir, model_hash)
os.makedirs(model_dir, exist_ok=True)
# First check for local example images - local processing doesn't need delay
local_images_processed = await ExampleImagesProcessor.process_local_examples(
model_file_path, model_file_name, model_name, model_dir, optimize
)
# If we processed local images, update metadata
if local_images_processed:
await MetadataUpdater.update_metadata_from_local_examples(
model_hash, model, scanner_type, scanner, model_dir
)
self._progress['processed_models'].add(model_hash)
return False # Return False to indicate no remote download happened
# If no local images, try to download from remote
elif model.get('civitai') and model.get('civitai', {}).get('images'):
images = model.get('civitai', {}).get('images', [])
success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
model_hash, model_name, images, model_dir, optimize, downloader
)
# If metadata is stale, try to refresh it
if is_stale and model_hash not in self._progress['refreshed_models']:
await MetadataUpdater.refresh_model_metadata(
model_hash, model_name, scanner_type, scanner, self._progress
)
# Get the updated model data
updated_model = await MetadataUpdater.get_updated_model(
model_hash, scanner
)
if updated_model and updated_model.get('civitai', {}).get('images'):
# Retry download with updated metadata
updated_images = updated_model.get('civitai', {}).get('images', [])
success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
model_hash, model_name, updated_images, model_dir, optimize, downloader
)
# Combine failed images from both attempts
failed_images.extend(additional_failed_images)
self._progress['refreshed_models'].add(model_hash)
# For forced downloads, remove failed images from metadata
if failed_images:
# Create a copy of images excluding failed ones
await self._remove_failed_images_from_metadata(
model_hash, model_name, failed_images, scanner
)
# Mark as processed
if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones
self._progress['processed_models'].add(model_hash)
return True # Return True to indicate a remote download happened
else:
logger.debug(f"No civitai images available for model {model_name}")
return False
except Exception as e:
error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
logger.error(error_msg, exc_info=True)
self._progress['errors'].append(error_msg)
self._progress['last_error'] = error_msg
return False # Return False on exception
async def _remove_failed_images_from_metadata(self, model_hash, model_name, failed_images, scanner):
"""Remove failed images from model metadata"""
try:
# Get current model data
model_data = await MetadataUpdater.get_updated_model(model_hash, scanner)
if not model_data:
logger.warning(f"Could not find model data for {model_name} to remove failed images")
return
if not model_data.get('civitai', {}).get('images'):
logger.warning(f"No images in metadata for {model_name}")
return
# Get current images
current_images = model_data['civitai']['images']
# Filter out failed images
updated_images = [img for img in current_images if img.get('url') not in failed_images]
# If images were removed, update metadata
if len(updated_images) < len(current_images):
removed_count = len(current_images) - len(updated_images)
logger.info(f"Removing {removed_count} failed images from metadata for {model_name}")
# Update the images list
model_data['civitai']['images'] = updated_images
# Save metadata to file
file_path = model_data.get('file_path')
if file_path:
# Create a copy of model data without 'folder' field
model_copy = model_data.copy()
model_copy.pop('folder', None)
# Write metadata to file
await MetadataManager.save_metadata(file_path, model_copy)
logger.info(f"Saved updated metadata for {model_name} after removing failed images")
# Update the scanner cache
await scanner.update_single_model_cache(file_path, file_path, model_data)
except Exception as e:
logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True)
async def _broadcast_progress(
self,
*,
status: str | None = None,
extra: Dict[str, Any] | None = None,
) -> None:
payload = self._build_progress_payload(status=status, extra=extra)
try:
await self._ws_manager.broadcast(payload)
except Exception as exc: # pragma: no cover - defensive logging
logger.warning("Failed to broadcast example image progress: %s", exc)
def _build_progress_payload(
self,
*,
status: str | None = None,
extra: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
payload: Dict[str, Any] = {
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': status or self._progress['status'],
'current_model': self._progress['current_model'],
}
if self._progress['errors']:
payload['errors'] = list(self._progress['errors'])
if self._progress['last_error']:
payload['last_error'] = self._progress['last_error']
if extra:
payload.update(extra)
return payload
_default_download_manager: DownloadManager | None = None
def get_default_download_manager(ws_manager) -> DownloadManager:
"""Return the singleton download manager used by default routes."""
global _default_download_manager
if (
_default_download_manager is None
or getattr(_default_download_manager, "_ws_manager", None) is not ws_manager
):
_default_download_manager = DownloadManager(ws_manager=ws_manager)
return _default_download_manager

View File

@@ -1,5 +1,6 @@
import logging import logging
import os import os
import re
import sys import sys
import subprocess import subprocess
from aiohttp import web from aiohttp import web
@@ -42,14 +43,6 @@ class ExampleImagesFileManager:
# Construct folder path for this model # Construct folder path for this model
model_folder = os.path.join(example_images_path, model_hash) model_folder = os.path.join(example_images_path, model_hash)
model_folder = os.path.abspath(model_folder) # Get absolute path
# Path validation: ensure model_folder is under example_images_path
if not model_folder.startswith(os.path.abspath(example_images_path)):
return web.json_response({
'success': False,
'error': 'Invalid model folder path'
}, status=400)
# Check if folder exists # Check if folder exists
if not os.path.exists(model_folder): if not os.path.exists(model_folder):

View File

@@ -1,39 +1,19 @@
import logging import logging
import os import os
import re import re
from ..utils.metadata_manager import MetadataManager
from ..recipes.constants import GEN_PARAM_KEYS from ..utils.routes_common import ModelRouteUtils
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
from ..services.metadata_sync_service import MetadataSyncService
from ..services.preview_asset_service import PreviewAssetService
from ..services.settings_manager import settings
from ..services.downloader import get_downloader
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
from ..utils.metadata_manager import MetadataManager from ..recipes.constants import GEN_PARAM_KEYS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_preview_service = PreviewAssetService(
metadata_manager=MetadataManager,
downloader_factory=get_downloader,
exif_utils=ExifUtils,
)
_metadata_sync_service = MetadataSyncService(
metadata_manager=MetadataManager,
preview_service=_preview_service,
settings=settings,
default_metadata_provider_factory=get_default_metadata_provider,
metadata_provider_selector=get_metadata_provider,
)
class MetadataUpdater: class MetadataUpdater:
"""Handles updating model metadata related to example images""" """Handles updating model metadata related to example images"""
@staticmethod @staticmethod
async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner, progress: dict | None = None): async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner):
"""Refresh model metadata from CivitAI """Refresh model metadata from CivitAI
Args: Args:
@@ -45,6 +25,8 @@ class MetadataUpdater:
Returns: Returns:
bool: True if metadata was successfully refreshed, False otherwise bool: True if metadata was successfully refreshed, False otherwise
""" """
from ..utils.example_images_download_manager import download_progress
try: try:
# Find the model in the scanner cache # Find the model in the scanner cache
cache = await scanner.get_cached_data() cache = await scanner.get_cached_data()
@@ -65,32 +47,31 @@ class MetadataUpdater:
return False return False
# Track that we're refreshing this model # Track that we're refreshing this model
if progress is not None: download_progress['refreshed_models'].add(model_hash)
progress['refreshed_models'].add(model_hash)
# Use ModelRouteUtils to refresh metadata
async def update_cache_func(old_path, new_path, metadata): async def update_cache_func(old_path, new_path, metadata):
return await scanner.update_single_model_cache(old_path, new_path, metadata) return await scanner.update_single_model_cache(old_path, new_path, metadata)
success, error = await _metadata_sync_service.fetch_and_update_model( success = await ModelRouteUtils.fetch_and_update_model(
sha256=model_hash, model_hash,
file_path=file_path, file_path,
model_data=model_data, model_data,
update_cache_func=update_cache_func, update_cache_func
) )
if success: if success:
logger.info(f"Successfully refreshed metadata for {model_name}") logger.info(f"Successfully refreshed metadata for {model_name}")
return True return True
else: else:
logger.warning(f"Failed to refresh metadata for {model_name}, {error}") logger.warning(f"Failed to refresh metadata for {model_name}")
return False return False
except Exception as e: except Exception as e:
error_msg = f"Error refreshing metadata for {model_name}: {str(e)}" error_msg = f"Error refreshing metadata for {model_name}: {str(e)}"
logger.error(error_msg, exc_info=True) logger.error(error_msg, exc_info=True)
if progress is not None: download_progress['errors'].append(error_msg)
progress['errors'].append(error_msg) download_progress['last_error'] = error_msg
progress['last_error'] = error_msg
return False return False
@staticmethod @staticmethod

Some files were not shown because too many files have changed in this diff Show More