mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Compare commits
67 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
12c88835f2 | ||
|
|
6f4453aaf3 | ||
|
|
4b4b8fe3c1 | ||
|
|
49e7c2e9f5 | ||
|
|
4653c273e3 | ||
|
|
ae145de2f2 | ||
|
|
dde7cf71c6 | ||
|
|
219cd242db | ||
|
|
e5b712c082 | ||
|
|
4d2c60d59b | ||
|
|
1d2c1b114b | ||
|
|
2bde936d05 | ||
|
|
cd3e32bf4b | ||
|
|
454536d631 | ||
|
|
656f1755fd | ||
|
|
8aa76ce5c1 | ||
|
|
49fa37f00d | ||
|
|
9f83548cf3 | ||
|
|
6054d95e85 | ||
|
|
8c9bb35824 | ||
|
|
3eacf9558a | ||
|
|
fee37172b4 | ||
|
|
e128c80eb1 | ||
|
|
5cc735ed57 | ||
|
|
43fcce6361 | ||
|
|
49b7126278 | ||
|
|
679cfb5c69 | ||
|
|
50616bc680 | ||
|
|
aaad270822 | ||
|
|
bd10280736 | ||
|
|
d477050239 | ||
|
|
85f79cd8d1 | ||
|
|
613cd81152 | ||
|
|
e0aba6c49a | ||
|
|
d78bcf2494 | ||
|
|
f7cffd2eba | ||
|
|
0d0b91aa80 | ||
|
|
42872e6d2d | ||
|
|
b91f06405d | ||
|
|
dac4c688d6 | ||
|
|
097a68ad18 | ||
|
|
4a98710db0 | ||
|
|
d033a374dd | ||
|
|
6aa23fe36a | ||
|
|
3220cfb79c | ||
|
|
b92e7aa446 | ||
|
|
c3b9c73541 | ||
|
|
81c6672880 | ||
|
|
08baf884d3 | ||
|
|
1c4096f3d5 | ||
|
|
66a3f3f59a | ||
|
|
624df1328b | ||
|
|
c063854b51 | ||
|
|
8cf99dd928 | ||
|
|
c07e885725 | ||
|
|
21772feadd | ||
|
|
2d00cfdd31 | ||
|
|
49e03d658b | ||
|
|
fec85bcc08 | ||
|
|
0e93a6bcb0 | ||
|
|
7e20f738fb | ||
|
|
24090e6077 | ||
|
|
1022b07f64 | ||
|
|
4faf912c6f | ||
|
|
56e4b24b07 | ||
|
|
12295d2fdc | ||
|
|
6261f7d18d |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -6,3 +6,5 @@ py/run_test.py
|
|||||||
.vscode/
|
.vscode/
|
||||||
cache/
|
cache/
|
||||||
civitai/
|
civitai/
|
||||||
|
node_modules/
|
||||||
|
coverage/
|
||||||
|
|||||||
19
AGENTS.md
Normal file
19
AGENTS.md
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# Repository Guidelines
|
||||||
|
|
||||||
|
## Project Structure & Module Organization
|
||||||
|
ComfyUI LoRA Manager pairs a Python backend with lightweight browser scripts. Backend modules live in `py/`, organized by responsibility: HTTP entry points under `routes/`, feature logic in `services/`, reusable helpers within `utils/`, and custom nodes in `nodes/`. Front-end widgets that extend the ComfyUI interface sit in `web/comfyui/`, while static images and templates are in `static/` and `templates/`. Shared localization files are stored in `locales/`, with workflow examples under `example_workflows/`. Tests currently reside alongside the source (`test_i18n.py`) until a dedicated `tests/` folder is introduced.
|
||||||
|
|
||||||
|
## Build, Test, and Development Commands
|
||||||
|
Install dependencies with `pip install -r requirements.txt` from the repo root. Launch the standalone server for iterative work via `python standalone.py --port 8188`; ComfyUI users can also load the extension directly through ComfyUI's custom node manager. Run backend checks with `python -m pytest test_i18n.py`, and target new test files explicitly (e.g. `python -m pytest tests/test_recipes.py` once added). Use `python scripts/sync_translation_keys.py` to reconcile locale keys after updating UI strings.
|
||||||
|
|
||||||
|
## Coding Style & Naming Conventions
|
||||||
|
Follow PEP 8 with four-space indentation and descriptive snake_case module/function names, mirroring files such as `py/services/settings_manager.py`. Classes remain PascalCase, constants UPPER_SNAKE_CASE, and loggers retrieved via `logging.getLogger(__name__)`. Prefer explicit type hints for new public APIs and docstrings that clarify side effects. JavaScript in `web/comfyui/` is modern ES modules; keep imports relative, favor camelCase functions, and mirror existing file suffixes like `_widget.js` for UI components.
|
||||||
|
|
||||||
|
## Testing Guidelines
|
||||||
|
Extend pytest coverage by co-locating tests near the code under test or in `tests/` with names like `test_<feature>.py`. When introducing new routes or services, add regression cases that mock ComfyUI dependencies (see the standalone mocking helpers in `standalone.py`). Prioritize deterministic fixtures for filesystem interactions and ensure translations include coverage when adding new locale keys. Always run `python -m pytest` before submitting work.
|
||||||
|
|
||||||
|
## Commit & Pull Request Guidelines
|
||||||
|
Commits follow the conventional pattern seen in `git log` (`feat(scope):`, `fix(scope):`, `chore(scope):`). Keep messages imperative and scoped to a single change. Pull requests should summarize the problem, detail the solution, list manual test evidence, and link any GitHub issues. Include UI screenshots or GIFs when front-end behavior changes, and call out migration steps (e.g., settings updates) in the PR description.
|
||||||
|
|
||||||
|
## Configuration & Localization Tips
|
||||||
|
Sample configuration defaults live in `settings.json.example`; copy it to `settings.json` and adjust model directories before running the standalone server. Whenever you add UI text, update `locales/<lang>.json` and run the translation sync script. Store reference assets in `civitai/` or `docs/` rather than mixing them with production templates, keeping the runtime folders (`static/`, `templates/`) deploy-ready.
|
||||||
39
__init__.py
39
__init__.py
@@ -1,13 +1,32 @@
|
|||||||
from .py.lora_manager import LoraManager
|
try: # pragma: no cover - import fallback for pytest collection
|
||||||
from .py.nodes.lora_loader import LoraManagerLoader, LoraManagerTextLoader
|
from .py.lora_manager import LoraManager
|
||||||
from .py.nodes.trigger_word_toggle import TriggerWordToggle
|
from .py.nodes.lora_loader import LoraManagerLoader, LoraManagerTextLoader
|
||||||
from .py.nodes.lora_stacker import LoraStacker
|
from .py.nodes.trigger_word_toggle import TriggerWordToggle
|
||||||
from .py.nodes.save_image import SaveImage
|
from .py.nodes.lora_stacker import LoraStacker
|
||||||
from .py.nodes.debug_metadata import DebugMetadata
|
from .py.nodes.save_image import SaveImage
|
||||||
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelect
|
from .py.nodes.debug_metadata import DebugMetadata
|
||||||
from .py.nodes.wanvideo_lora_select_from_text import WanVideoLoraSelectFromText
|
from .py.nodes.wanvideo_lora_select import WanVideoLoraSelect
|
||||||
# Import metadata collector to install hooks on startup
|
from .py.nodes.wanvideo_lora_select_from_text import WanVideoLoraSelectFromText
|
||||||
from .py.metadata_collector import init as init_metadata_collector
|
from .py.metadata_collector import init as init_metadata_collector
|
||||||
|
except ImportError: # pragma: no cover - allows running under pytest without package install
|
||||||
|
import importlib
|
||||||
|
import pathlib
|
||||||
|
import sys
|
||||||
|
|
||||||
|
package_root = pathlib.Path(__file__).resolve().parent
|
||||||
|
if str(package_root) not in sys.path:
|
||||||
|
sys.path.append(str(package_root))
|
||||||
|
|
||||||
|
LoraManager = importlib.import_module("py.lora_manager").LoraManager
|
||||||
|
LoraManagerLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerLoader
|
||||||
|
LoraManagerTextLoader = importlib.import_module("py.nodes.lora_loader").LoraManagerTextLoader
|
||||||
|
TriggerWordToggle = importlib.import_module("py.nodes.trigger_word_toggle").TriggerWordToggle
|
||||||
|
LoraStacker = importlib.import_module("py.nodes.lora_stacker").LoraStacker
|
||||||
|
SaveImage = importlib.import_module("py.nodes.save_image").SaveImage
|
||||||
|
DebugMetadata = importlib.import_module("py.nodes.debug_metadata").DebugMetadata
|
||||||
|
WanVideoLoraSelect = importlib.import_module("py.nodes.wanvideo_lora_select").WanVideoLoraSelect
|
||||||
|
WanVideoLoraSelectFromText = importlib.import_module("py.nodes.wanvideo_lora_select_from_text").WanVideoLoraSelectFromText
|
||||||
|
init_metadata_collector = importlib.import_module("py.metadata_collector").init
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
LoraManagerLoader.NAME: LoraManagerLoader,
|
LoraManagerLoader.NAME: LoraManagerLoader,
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
The **LoRA Manager Civitai Extension** is a Browser extension designed to work seamlessly with [LoRA Manager](https://github.com/willmiao/ComfyUI-Lora-Manager) to significantly enhance your browsing experience on [Civitai](https://civitai.com). With this extension, you can:
|
The **LoRA Manager Civitai Extension** is a Browser extension designed to work seamlessly with [LoRA Manager](https://github.com/willmiao/ComfyUI-Lora-Manager) to significantly enhance your browsing experience on [Civitai](https://civitai.com).
|
||||||
|
It also supports browsing on [CivArchive](https://civarchive.com/) (formerly CivitaiArchive).
|
||||||
|
|
||||||
|
With this extension, you can:
|
||||||
|
|
||||||
✅ Instantly see which models are already present in your local library
|
✅ Instantly see which models are already present in your local library
|
||||||
✅ Download new models with a single click
|
✅ Download new models with a single click
|
||||||
@@ -8,6 +11,7 @@ The **LoRA Manager Civitai Extension** is a Browser extension designed to work s
|
|||||||
✅ Keep your downloaded models automatically organized according to your custom settings
|
✅ Keep your downloaded models automatically organized according to your custom settings
|
||||||
|
|
||||||

|

|
||||||
|

|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
93
docs/architecture/example_images_routes.md
Normal file
93
docs/architecture/example_images_routes.md
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# Example image route architecture
|
||||||
|
|
||||||
|
The example image routing stack mirrors the layered model route stack described in
|
||||||
|
[`docs/architecture/model_routes.md`](model_routes.md). HTTP wiring, controller setup,
|
||||||
|
handler orchestration, and long-running workflows now live in clearly separated modules so
|
||||||
|
we can extend download/import behaviour without touching the entire feature surface.
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph TD
|
||||||
|
subgraph HTTP
|
||||||
|
A[ExampleImagesRouteRegistrar] -->|binds| B[ExampleImagesRoutes controller]
|
||||||
|
end
|
||||||
|
subgraph Application
|
||||||
|
B --> C[ExampleImagesHandlerSet]
|
||||||
|
C --> D1[Handlers]
|
||||||
|
D1 --> E1[Use cases]
|
||||||
|
E1 --> F1[Download manager / processor / file manager]
|
||||||
|
end
|
||||||
|
subgraph Side Effects
|
||||||
|
F1 --> G1[Filesystem]
|
||||||
|
F1 --> G2[Model metadata]
|
||||||
|
F1 --> G3[WebSocket progress]
|
||||||
|
end
|
||||||
|
```
|
||||||
|
|
||||||
|
## Layer responsibilities
|
||||||
|
|
||||||
|
| Layer | Module(s) | Responsibility |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| Registrar | `py/routes/example_images_route_registrar.py` | Declarative catalogue of every example image endpoint plus helpers that bind them to an `aiohttp` router. Keeps HTTP concerns symmetrical with the model registrar. |
|
||||||
|
| Controller | `py/routes/example_images_routes.py` | Lazily constructs `ExampleImagesHandlerSet`, injects defaults for the download manager, processor, and file manager, and exposes the registrar-ready mapping just like `BaseModelRoutes`. |
|
||||||
|
| Handler set | `py/routes/handlers/example_images_handlers.py` | Groups HTTP adapters by concern (downloads, imports/deletes, filesystem access). Each handler translates domain errors into HTTP responses and defers to a use case or utility service. |
|
||||||
|
| Use cases | `py/services/use_cases/example_images/*.py` | Encapsulate orchestration for downloads and imports. They validate input, translate concurrency/configuration errors, and keep handler logic declarative. |
|
||||||
|
| Supporting services | `py/utils/example_images_download_manager.py`, `py/utils/example_images_processor.py`, `py/utils/example_images_file_manager.py` | Execute long-running work: pull assets from Civitai, persist uploads, clean metadata, expose filesystem actions with guardrails, and broadcast progress snapshots. |
|
||||||
|
|
||||||
|
## Handler responsibilities & invariants
|
||||||
|
|
||||||
|
`ExampleImagesHandlerSet` flattens the handler objects into the `{"handler_name": coroutine}`
|
||||||
|
mapping consumed by the registrar. The table below outlines how each handler collaborates
|
||||||
|
with the use cases and utilities.
|
||||||
|
|
||||||
|
| Handler | Key endpoints | Collaborators | Contracts |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| `ExampleImagesDownloadHandler` | `/api/lm/download-example-images`, `/api/lm/example-images-status`, `/api/lm/pause-example-images`, `/api/lm/resume-example-images`, `/api/lm/force-download-example-images` | `DownloadExampleImagesUseCase`, `DownloadManager` | Delegates payload validation and concurrency checks to the use case; progress/status endpoints expose the same snapshot used for WebSocket broadcasts; pause/resume surface `DownloadNotRunningError` as HTTP 400 instead of 500. |
|
||||||
|
| `ExampleImagesManagementHandler` | `/api/lm/import-example-images`, `/api/lm/delete-example-image` | `ImportExampleImagesUseCase`, `ExampleImagesProcessor` | Multipart uploads are streamed to disk via the use case; validation failures return HTTP 400 with no filesystem side effects; deletion funnels through the processor to prune metadata and cached images consistently. |
|
||||||
|
| `ExampleImagesFileHandler` | `/api/lm/open-example-images-folder`, `/api/lm/example-image-files`, `/api/lm/has-example-images` | `ExampleImagesFileManager` | Centralises filesystem access, enforcing settings-based root paths and returning HTTP 400/404 for missing configuration or folders; responses always include `success`/`has_images` booleans for UI consumption. |
|
||||||
|
|
||||||
|
## Use case boundaries
|
||||||
|
|
||||||
|
| Use case | Entry point | Dependencies | Guarantees |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| `DownloadExampleImagesUseCase` | `execute(payload)` | `DownloadManager.start_download`, download configuration errors | Raises `DownloadExampleImagesInProgressError` when the manager reports an active job, rewraps configuration errors into `DownloadExampleImagesConfigurationError`, and lets `ExampleImagesDownloadError` bubble as 500s so handlers do not duplicate logging. |
|
||||||
|
| `ImportExampleImagesUseCase` | `execute(request)` | `ExampleImagesProcessor.import_images`, temporary file helpers | Supports multipart or JSON payloads, normalises file paths into a single list, cleans up temp files even on failure, and maps validation issues to `ImportExampleImagesValidationError` for HTTP 400 responses. |
|
||||||
|
|
||||||
|
## Maintaining critical invariants
|
||||||
|
|
||||||
|
* **Shared progress snapshots** - The download handler returns the same snapshot built by
|
||||||
|
`DownloadManager`, guaranteeing parity between HTTP polling endpoints and WebSocket
|
||||||
|
progress events.
|
||||||
|
* **Safe filesystem access** - All folder/file actions flow through
|
||||||
|
`ExampleImagesFileManager`, which validates the configured example image root and ensures
|
||||||
|
responses never leak absolute paths outside the allowed directory.
|
||||||
|
* **Metadata hygiene** - Import/delete operations run through `ExampleImagesProcessor`,
|
||||||
|
which updates model metadata via `MetadataManager` and notifies the relevant scanners so
|
||||||
|
cache state stays in sync.
|
||||||
|
|
||||||
|
## Migration notes
|
||||||
|
|
||||||
|
The refactor brings the example image stack in line with the model/recipe stacks:
|
||||||
|
|
||||||
|
1. `ExampleImagesRouteRegistrar` now owns the declarative route list. Downstream projects
|
||||||
|
should rely on `ExampleImagesRoutes.to_route_mapping()` instead of manually wiring
|
||||||
|
handler callables.
|
||||||
|
2. `ExampleImagesRoutes` caches its `ExampleImagesHandlerSet` just like
|
||||||
|
`BaseModelRoutes`. If you previously instantiated handlers directly, inject custom
|
||||||
|
collaborators via the controller constructor (`download_manager`, `processor`,
|
||||||
|
`file_manager`) to keep test seams predictable.
|
||||||
|
3. Tests that mocked `ExampleImagesRoutes.setup_routes` should switch to patching
|
||||||
|
`DownloadExampleImagesUseCase`/`ImportExampleImagesUseCase` at import time. The handlers
|
||||||
|
expect those abstractions to surface validation/concurrency errors, and bypassing them
|
||||||
|
will skip the HTTP-friendly error mapping.
|
||||||
|
|
||||||
|
## Extending the stack
|
||||||
|
|
||||||
|
1. Add the endpoint to `ROUTE_DEFINITIONS` with a unique `handler_name`.
|
||||||
|
2. Expose the coroutine on an existing handler class (or create a new handler and extend
|
||||||
|
`ExampleImagesHandlerSet`).
|
||||||
|
3. Wire additional services or factories inside `_build_handler_set` on
|
||||||
|
`ExampleImagesRoutes`, mirroring how the model stack introduces new use cases.
|
||||||
|
|
||||||
|
`tests/routes/test_example_images_routes.py` exercises registrar binding, download pause
|
||||||
|
flows, and import validations. Use it as a template when introducing new handler
|
||||||
|
collaborators or error mappings.
|
||||||
100
docs/architecture/model_routes.md
Normal file
100
docs/architecture/model_routes.md
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
# Base model route architecture
|
||||||
|
|
||||||
|
The model routing stack now splits HTTP wiring, orchestration logic, and
|
||||||
|
business rules into discrete layers. The goal is to make it obvious where a
|
||||||
|
new collaborator should live and which contract it must honour. The diagram
|
||||||
|
below captures the end-to-end flow for a typical request:
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph TD
|
||||||
|
subgraph HTTP
|
||||||
|
A[ModelRouteRegistrar] -->|binds| B[BaseModelRoutes handler proxy]
|
||||||
|
end
|
||||||
|
subgraph Application
|
||||||
|
B --> C[ModelHandlerSet]
|
||||||
|
C --> D1[Handlers]
|
||||||
|
D1 --> E1[Use cases]
|
||||||
|
E1 --> F1[Services / scanners]
|
||||||
|
end
|
||||||
|
subgraph Side Effects
|
||||||
|
F1 --> G1[Cache & metadata]
|
||||||
|
F1 --> G2[Filesystem]
|
||||||
|
F1 --> G3[WebSocket state]
|
||||||
|
end
|
||||||
|
```
|
||||||
|
|
||||||
|
Every box maps to a concrete module:
|
||||||
|
|
||||||
|
| Layer | Module(s) | Responsibility |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| Registrar | `py/routes/model_route_registrar.py` | Declarative list of routes shared by every model type and helper methods for binding them to an `aiohttp` application. |
|
||||||
|
| Route controller | `py/routes/base_model_routes.py` | Constructs the handler graph, injects shared services, exposes proxies that surface `503 Service not ready` when the model service has not been attached. |
|
||||||
|
| Handler set | `py/routes/handlers/model_handlers.py` | Thin HTTP adapters grouped by concern (page rendering, listings, mutations, queries, downloads, CivitAI integration, move operations, auto-organize). |
|
||||||
|
| Use cases | `py/services/use_cases/*.py` | Encapsulate long-running flows (`DownloadModelUseCase`, `BulkMetadataRefreshUseCase`, `AutoOrganizeUseCase`). They normalise validation errors and concurrency constraints before returning control to the handlers. |
|
||||||
|
| Services | `py/services/*.py` | Existing services and scanners that mutate caches, write metadata, move files, and broadcast WebSocket updates. |
|
||||||
|
|
||||||
|
## Handler responsibilities & contracts
|
||||||
|
|
||||||
|
`ModelHandlerSet` flattens the handler objects into the exact callables used by
|
||||||
|
the registrar. The table below highlights the separation of concerns within
|
||||||
|
the set and the invariants that must hold after each handler returns.
|
||||||
|
|
||||||
|
| Handler | Key endpoints | Collaborators | Contracts |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| `ModelPageView` | `/{prefix}` | `SettingsManager`, `server_i18n`, Jinja environment, `service.scanner` | Template is rendered with `is_initializing` flag when caches are cold; i18n filter is registered exactly once per environment instance. |
|
||||||
|
| `ModelListingHandler` | `/api/lm/{prefix}/list` | `service.get_paginated_data`, `service.format_response` | Listings respect pagination query parameters and cap `page_size` at 100; every item is formatted before response. |
|
||||||
|
| `ModelManagementHandler` | Mutations (delete, exclude, metadata, preview, tags, rename, bulk delete, duplicate verification) | `ModelLifecycleService`, `MetadataSyncService`, `PreviewAssetService`, `TagUpdateService`, scanner cache/index | Cache state mirrors filesystem changes: deletes prune cache & hash index, preview replacements synchronise metadata and cache NSFW levels, metadata saves trigger cache resort when names change. |
|
||||||
|
| `ModelQueryHandler` | Read-only queries (top tags, folders, duplicates, metadata, URLs) | Service query helpers & scanner cache | Outputs always wrapped in `{"success": True}` when no error; duplicate/filename grouping omits empty entries; invalid parameters (e.g. missing `model_root`) return HTTP 400. |
|
||||||
|
| `ModelDownloadHandler` | `/api/lm/download-model`, `/download-model-get`, `/download-progress/{id}`, `/cancel-download-get` | `DownloadModelUseCase`, `DownloadCoordinator`, `WebSocketManager` | Payload validation errors become HTTP 400 without mutating download progress cache; early-access failures surface as HTTP 401; successful downloads cache progress snapshots that back both WebSocket broadcasts and polling endpoints. |
|
||||||
|
| `ModelCivitaiHandler` | CivitAI metadata routes | `MetadataSyncService`, metadata provider factory, `BulkMetadataRefreshUseCase` | `fetch_all_civitai` streams progress via `WebSocketBroadcastCallback`; version lookups validate model type before returning; local availability fields derive from hash lookups without mutating cache state. |
|
||||||
|
| `ModelMoveHandler` | `move_model`, `move_models_bulk` | `ModelMoveService` | Moves execute atomically per request; bulk operations aggregate success/failure per file set. |
|
||||||
|
| `ModelAutoOrganizeHandler` | `/api/lm/{prefix}/auto-organize` (GET/POST), `/auto-organize-progress` | `AutoOrganizeUseCase`, `WebSocketProgressCallback`, `WebSocketManager` | Enforces single-flight execution using the shared lock; progress broadcasts remain available to polling clients until explicitly cleared; conflicts return HTTP 409 with a descriptive error. |
|
||||||
|
|
||||||
|
## Use case boundaries
|
||||||
|
|
||||||
|
Each use case exposes a narrow asynchronous API that hides the underlying
|
||||||
|
services. Their error mapping is essential for predictable HTTP responses.
|
||||||
|
|
||||||
|
| Use case | Entry point | Dependencies | Guarantees |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| `DownloadModelUseCase` | `execute(payload)` | `DownloadCoordinator.schedule_download` | Translates `ValueError` into `DownloadModelValidationError` for HTTP 400, recognises early-access errors (`"401"` in message) and surfaces them as `DownloadModelEarlyAccessError`, forwards success dictionaries untouched. |
|
||||||
|
| `AutoOrganizeUseCase` | `execute(file_paths, progress_callback)` | `ModelFileService.auto_organize_models`, `WebSocketManager` lock | Guarded by `ws_manager` lock + status checks; raises `AutoOrganizeInProgressError` before invoking the file service when another run is already active. |
|
||||||
|
| `BulkMetadataRefreshUseCase` | `execute_with_error_handling(progress_callback)` | `MetadataSyncService`, `SettingsManager`, `WebSocketBroadcastCallback` | Iterates through cached models, applies metadata sync, emits progress snapshots that handlers broadcast unchanged. |
|
||||||
|
|
||||||
|
## Maintaining legacy contracts
|
||||||
|
|
||||||
|
The refactor preserves the invariants called out in the previous architecture
|
||||||
|
notes. The most critical ones are reiterated here to emphasise the
|
||||||
|
collaboration points:
|
||||||
|
|
||||||
|
1. **Cache mutations** – Delete, exclude, rename, and bulk delete operations are
|
||||||
|
channelled through `ModelManagementHandler`. The handler delegates to
|
||||||
|
`ModelLifecycleService` or `MetadataSyncService`, and the scanner cache is
|
||||||
|
mutated in-place before the handler returns. The accompanying tests assert
|
||||||
|
that `scanner._cache.raw_data` and `scanner._hash_index` stay in sync after
|
||||||
|
each mutation.
|
||||||
|
2. **Preview updates** – `PreviewAssetService.replace_preview` writes the new
|
||||||
|
asset, `MetadataSyncService` persists the JSON metadata, and
|
||||||
|
`scanner.update_preview_in_cache` mirrors the change. The handler returns
|
||||||
|
the static URL produced by `config.get_preview_static_url`, keeping browser
|
||||||
|
clients in lockstep with disk state.
|
||||||
|
3. **Download progress** – `DownloadCoordinator.schedule_download` generates the
|
||||||
|
download identifier, registers a WebSocket progress callback, and caches the
|
||||||
|
latest numeric progress via `WebSocketManager`. Both `download_model`
|
||||||
|
responses and `/download-progress/{id}` polling read from the same cache to
|
||||||
|
guarantee consistent progress reporting across transports.
|
||||||
|
|
||||||
|
## Extending the stack
|
||||||
|
|
||||||
|
To add a new shared route:
|
||||||
|
|
||||||
|
1. Declare it in `COMMON_ROUTE_DEFINITIONS` using a unique handler name.
|
||||||
|
2. Implement the corresponding coroutine on one of the handlers inside
|
||||||
|
`ModelHandlerSet` (or introduce a new handler class when the concern does not
|
||||||
|
fit existing ones).
|
||||||
|
3. Inject additional dependencies in `BaseModelRoutes._create_handler_set` by
|
||||||
|
wiring services or use cases through the constructor parameters.
|
||||||
|
|
||||||
|
Model-specific routes should continue to be registered inside the subclass
|
||||||
|
implementation of `setup_specific_routes`, reusing the shared registrar where
|
||||||
|
possible.
|
||||||
89
docs/architecture/recipe_routes.md
Normal file
89
docs/architecture/recipe_routes.md
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
# Recipe route architecture
|
||||||
|
|
||||||
|
The recipe routing stack now mirrors the modular model route design. HTTP
|
||||||
|
bindings, controller wiring, handler orchestration, and business rules live in
|
||||||
|
separate layers so new behaviours can be added without re-threading the entire
|
||||||
|
feature. The diagram below outlines the flow for a typical request:
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph TD
|
||||||
|
subgraph HTTP
|
||||||
|
A[RecipeRouteRegistrar] -->|binds| B[RecipeRoutes controller]
|
||||||
|
end
|
||||||
|
subgraph Application
|
||||||
|
B --> C[RecipeHandlerSet]
|
||||||
|
C --> D1[Handlers]
|
||||||
|
D1 --> E1[Use cases]
|
||||||
|
E1 --> F1[Services / scanners]
|
||||||
|
end
|
||||||
|
subgraph Side Effects
|
||||||
|
F1 --> G1[Cache & fingerprint index]
|
||||||
|
F1 --> G2[Metadata files]
|
||||||
|
F1 --> G3[Temporary shares]
|
||||||
|
end
|
||||||
|
```
|
||||||
|
|
||||||
|
## Layer responsibilities
|
||||||
|
|
||||||
|
| Layer | Module(s) | Responsibility |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| Registrar | `py/routes/recipe_route_registrar.py` | Declarative list of every recipe endpoint and helper methods that bind them to an `aiohttp` application. |
|
||||||
|
| Controller | `py/routes/base_recipe_routes.py`, `py/routes/recipe_routes.py` | Lazily resolves scanners/clients from the service registry, wires shared templates/i18n, instantiates `RecipeHandlerSet`, and exposes a `{handler_name: coroutine}` mapping for the registrar. |
|
||||||
|
| Handler set | `py/routes/handlers/recipe_handlers.py` | Thin HTTP adapters grouped by concern (page view, listings, queries, mutations, sharing). They normalise responses and translate service exceptions into HTTP status codes. |
|
||||||
|
| Services & scanners | `py/services/recipes/*.py`, `py/services/recipe_scanner.py`, `py/services/service_registry.py` | Concrete business logic: metadata parsing, persistence, sharing, fingerprint/index maintenance, and cache refresh. |
|
||||||
|
|
||||||
|
## Handler responsibilities & invariants
|
||||||
|
|
||||||
|
`RecipeHandlerSet` flattens purpose-built handler objects into the callables the
|
||||||
|
registrar binds. Each handler is responsible for a narrow concern and enforces a
|
||||||
|
set of invariants before returning:
|
||||||
|
|
||||||
|
| Handler | Key endpoints | Collaborators | Contracts |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| `RecipePageView` | `/loras/recipes` | `SettingsManager`, `server_i18n`, Jinja environment, recipe scanner getter | Template rendered with `is_initializing` flag when caches are still warming; i18n filter registered exactly once per environment instance. |
|
||||||
|
| `RecipeListingHandler` | `/api/lm/recipes`, `/api/lm/recipe/{id}` | `recipe_scanner.get_paginated_data`, `recipe_scanner.get_recipe_by_id` | Listings respect pagination and search filters; every item receives a `file_url` fallback even when metadata is incomplete; missing recipes become HTTP 404. |
|
||||||
|
| `RecipeQueryHandler` | Tag/base-model stats, syntax, LoRA lookups | Recipe scanner cache, `format_recipe_file_url` helper | Cache snapshots are reused without forcing refresh; duplicate lookups collapse groups by fingerprint; syntax lookups return helpful errors when LoRAs are absent. |
|
||||||
|
| `RecipeManagementHandler` | Save, update, reconnect, bulk delete, widget ingest | `RecipePersistenceService`, `RecipeAnalysisService`, recipe scanner | Persistence results propagate HTTP status codes; fingerprint/index updates flow through the scanner before returning; validation errors surface as HTTP 400 without touching disk. |
|
||||||
|
| `RecipeAnalysisHandler` | Uploaded/local/remote analysis | `RecipeAnalysisService`, `civitai_client`, recipe scanner | Unsupported content types map to HTTP 400; download errors (`RecipeDownloadError`) are not retried; every response includes a `loras` array for client compatibility. |
|
||||||
|
| `RecipeSharingHandler` | Share + download | `RecipeSharingService`, recipe scanner | Share responses provide a stable download URL and filename; expired shares surface as HTTP 404; downloads stream via `web.FileResponse` with attachment headers. |
|
||||||
|
|
||||||
|
## Use case boundaries
|
||||||
|
|
||||||
|
The dedicated services encapsulate long-running work so handlers stay thin.
|
||||||
|
|
||||||
|
| Use case | Entry point | Dependencies | Guarantees |
|
||||||
|
| --- | --- | --- | --- |
|
||||||
|
| `RecipeAnalysisService` | `analyze_uploaded_image`, `analyze_remote_image`, `analyze_local_image`, `analyze_widget_metadata` | `ExifUtils`, `RecipeParserFactory`, downloader factory, optional metadata collector/processor | Normalises missing/invalid payloads into `RecipeValidationError`; generates consistent fingerprint data to keep duplicate detection stable; temporary files are cleaned up after every analysis path. |
|
||||||
|
| `RecipePersistenceService` | `save_recipe`, `delete_recipe`, `update_recipe`, `reconnect_lora`, `bulk_delete`, `save_recipe_from_widget` | `ExifUtils`, recipe scanner, card preview sizing constants | Writes images/JSON metadata atomically; updates scanner caches and hash indices before returning; recalculates fingerprints whenever LoRA assignments change. |
|
||||||
|
| `RecipeSharingService` | `share_recipe`, `prepare_download` | `tempfile`, recipe scanner | Copies originals to TTL-managed temp files; metadata lookups re-use the scanner; expired shares trigger cleanup and `RecipeNotFoundError`. |
|
||||||
|
|
||||||
|
## Maintaining critical invariants
|
||||||
|
|
||||||
|
* **Cache updates** – Mutations (`save`, `delete`, `bulk_delete`, `update`) call
|
||||||
|
back into the recipe scanner to mutate the in-memory cache and fingerprint
|
||||||
|
index before returning a response. Tests assert that these methods are invoked
|
||||||
|
even when stubbing persistence.
|
||||||
|
* **Fingerprint management** – `RecipePersistenceService` recomputes
|
||||||
|
fingerprints whenever LoRA metadata changes and duplicate lookups use those
|
||||||
|
fingerprints to group recipes. Handlers bubble the resulting IDs so clients
|
||||||
|
can merge duplicates without an extra fetch.
|
||||||
|
* **Metadata synchronisation** – Saving or reconnecting a recipe updates the
|
||||||
|
JSON sidecar, refreshes embedded metadata via `ExifUtils`, and instructs the
|
||||||
|
scanner to resort its cache. Sharing relies on this metadata to generate
|
||||||
|
filenames and ensure downloads stay in sync with on-disk state.
|
||||||
|
|
||||||
|
## Extending the stack
|
||||||
|
|
||||||
|
1. Declare the new endpoint in `ROUTE_DEFINITIONS` with a unique handler name.
|
||||||
|
2. Implement the coroutine on an existing handler or introduce a new handler
|
||||||
|
class inside `py/routes/handlers/recipe_handlers.py` when the concern does
|
||||||
|
not fit existing ones.
|
||||||
|
3. Wire additional collaborators inside
|
||||||
|
`BaseRecipeRoutes._create_handler_set` (inject new services or factories) and
|
||||||
|
expose helper getters on the handler owner if the handler needs to share
|
||||||
|
utilities.
|
||||||
|
|
||||||
|
Integration tests in `tests/routes/test_recipe_routes.py` exercise the listing,
|
||||||
|
mutation, analysis-error, and sharing paths end-to-end, ensuring the controller
|
||||||
|
and handler wiring remains valid as new capabilities are added.
|
||||||
|
|
||||||
23
docs/frontend-testing-roadmap.md
Normal file
23
docs/frontend-testing-roadmap.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Frontend Automation Testing Roadmap
|
||||||
|
|
||||||
|
This roadmap tracks the planned rollout of automated testing for the ComfyUI LoRA Manager frontend. Each phase builds on the infrastructure introduced in this change set and records progress so future contributors can quickly identify the next tasks.
|
||||||
|
|
||||||
|
## Phase Overview
|
||||||
|
|
||||||
|
| Phase | Goal | Primary Focus | Status | Notes |
|
||||||
|
| --- | --- | --- | --- | --- |
|
||||||
|
| Phase 0 | Establish baseline tooling | Add Node test runner, jsdom environment, and seed smoke tests | ✅ Complete | Vitest + jsdom configured, example state tests committed |
|
||||||
|
| Phase 1 | Cover state management logic | Unit test selectors, derived data helpers, and storage utilities under `static/js/state` and `static/js/utils` | ✅ Complete | Storage helpers and state selectors now exercised via deterministic suites |
|
||||||
|
| Phase 2 | Test AppCore orchestration | Simulate page bootstrapping, infinite scroll hooks, and manager registration using JSDOM DOM fixtures | 🟡 In Progress | AppCore initialization specs landed; expand to additional page wiring and scroll hooks |
|
||||||
|
| Phase 3 | Validate page-specific managers | Add focused suites for `loras`, `checkpoints`, `embeddings`, and `recipes` managers covering filtering, sorting, and bulk actions | ⚪ Not Started | Consider shared helpers for mocking API modules and storage |
|
||||||
|
| Phase 4 | Interaction-level regression tests | Exercise template fragments, modals, and menus to ensure UI wiring remains intact | ⚪ Not Started | Evaluate Playwright component testing or happy-path DOM snapshots |
|
||||||
|
| Phase 5 | Continuous integration & coverage | Integrate frontend tests into CI workflow and track coverage metrics | ⚪ Not Started | Align reporting directories with backend coverage for unified reporting |
|
||||||
|
|
||||||
|
## Next Steps Checklist
|
||||||
|
|
||||||
|
- [x] Expand unit tests for `storageHelpers` covering migrations and namespace behavior.
|
||||||
|
- [ ] Document DOM fixture strategy for reproducing template structures in tests.
|
||||||
|
- [x] Prototype AppCore initialization test that verifies manager bootstrapping with stubbed dependencies.
|
||||||
|
- [ ] Evaluate integrating coverage reporting once test surface grows (> 20 specs).
|
||||||
|
|
||||||
|
Maintaining this roadmap alongside code changes will make it easier to append new automated test tasks and update their progress.
|
||||||
@@ -122,6 +122,20 @@
|
|||||||
"noRemoteImagesAvailable": "Keine Remote-Beispielbilder für dieses Modell auf Civitai verfügbar"
|
"noRemoteImagesAvailable": "Keine Remote-Beispielbilder für dieses Modell auf Civitai verfügbar"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"globalContextMenu": {
|
||||||
|
"downloadExampleImages": {
|
||||||
|
"label": "Beispielbilder herunterladen",
|
||||||
|
"missingPath": "Bitte legen Sie einen Speicherort fest, bevor Sie Beispielbilder herunterladen.",
|
||||||
|
"unavailable": "Beispielbild-Downloads sind noch nicht verfügbar. Versuchen Sie es erneut, nachdem die Seite vollständig geladen ist."
|
||||||
|
},
|
||||||
|
"cleanupExampleImages": {
|
||||||
|
"label": "Beispielbild-Ordner bereinigen",
|
||||||
|
"success": "{count} Ordner wurden in den Papierkorb verschoben",
|
||||||
|
"none": "Keine Beispielbild-Ordner mussten bereinigt werden",
|
||||||
|
"partial": "Bereinigung abgeschlossen, {failures} Ordner übersprungen",
|
||||||
|
"error": "Fehler beim Bereinigen der Beispielbild-Ordner: {message}"
|
||||||
|
}
|
||||||
|
},
|
||||||
"header": {
|
"header": {
|
||||||
"appTitle": "LoRA Manager",
|
"appTitle": "LoRA Manager",
|
||||||
"navigation": {
|
"navigation": {
|
||||||
@@ -727,6 +741,7 @@
|
|||||||
"strengthMin": "Stärke Min",
|
"strengthMin": "Stärke Min",
|
||||||
"strengthMax": "Stärke Max",
|
"strengthMax": "Stärke Max",
|
||||||
"strength": "Stärke",
|
"strength": "Stärke",
|
||||||
|
"clipStrength": "Clip-Stärke",
|
||||||
"clipSkip": "Clip Skip",
|
"clipSkip": "Clip Skip",
|
||||||
"valuePlaceholder": "Wert",
|
"valuePlaceholder": "Wert",
|
||||||
"add": "Hinzufügen"
|
"add": "Hinzufügen"
|
||||||
|
|||||||
@@ -122,6 +122,20 @@
|
|||||||
"noRemoteImagesAvailable": "No remote example images available for this model on Civitai"
|
"noRemoteImagesAvailable": "No remote example images available for this model on Civitai"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"globalContextMenu": {
|
||||||
|
"downloadExampleImages": {
|
||||||
|
"label": "Download example images",
|
||||||
|
"missingPath": "Set a download location before downloading example images.",
|
||||||
|
"unavailable": "Example image downloads aren't available yet. Try again after the page finishes loading."
|
||||||
|
},
|
||||||
|
"cleanupExampleImages": {
|
||||||
|
"label": "Clean up example image folders",
|
||||||
|
"success": "Moved {count} folder(s) to the deleted folder",
|
||||||
|
"none": "No example image folders needed cleanup",
|
||||||
|
"partial": "Cleanup completed with {failures} folder(s) skipped",
|
||||||
|
"error": "Failed to clean example image folders: {message}"
|
||||||
|
}
|
||||||
|
},
|
||||||
"header": {
|
"header": {
|
||||||
"appTitle": "LoRA Manager",
|
"appTitle": "LoRA Manager",
|
||||||
"navigation": {
|
"navigation": {
|
||||||
@@ -727,6 +741,7 @@
|
|||||||
"strengthMin": "Strength Min",
|
"strengthMin": "Strength Min",
|
||||||
"strengthMax": "Strength Max",
|
"strengthMax": "Strength Max",
|
||||||
"strength": "Strength",
|
"strength": "Strength",
|
||||||
|
"clipStrength": "Clip Strength",
|
||||||
"clipSkip": "Clip Skip",
|
"clipSkip": "Clip Skip",
|
||||||
"valuePlaceholder": "Value",
|
"valuePlaceholder": "Value",
|
||||||
"add": "Add"
|
"add": "Add"
|
||||||
|
|||||||
@@ -122,6 +122,20 @@
|
|||||||
"noRemoteImagesAvailable": "No hay imágenes de ejemplo remotas disponibles para este modelo en Civitai"
|
"noRemoteImagesAvailable": "No hay imágenes de ejemplo remotas disponibles para este modelo en Civitai"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"globalContextMenu": {
|
||||||
|
"downloadExampleImages": {
|
||||||
|
"label": "Descargar imágenes de ejemplo",
|
||||||
|
"missingPath": "Establece una ubicación de descarga antes de descargar imágenes de ejemplo.",
|
||||||
|
"unavailable": "Las descargas de imágenes de ejemplo aún no están disponibles. Intenta de nuevo después de que la página termine de cargar."
|
||||||
|
},
|
||||||
|
"cleanupExampleImages": {
|
||||||
|
"label": "Limpiar carpetas de imágenes de ejemplo",
|
||||||
|
"success": "Se movieron {count} carpeta(s) a la carpeta de eliminados",
|
||||||
|
"none": "No hay carpetas de imágenes de ejemplo que necesiten limpieza",
|
||||||
|
"partial": "Limpieza completada con {failures} carpeta(s) omitidas",
|
||||||
|
"error": "No se pudieron limpiar las carpetas de imágenes de ejemplo: {message}"
|
||||||
|
}
|
||||||
|
},
|
||||||
"header": {
|
"header": {
|
||||||
"appTitle": "LoRA Manager",
|
"appTitle": "LoRA Manager",
|
||||||
"navigation": {
|
"navigation": {
|
||||||
@@ -727,6 +741,7 @@
|
|||||||
"strengthMin": "Fuerza mínima",
|
"strengthMin": "Fuerza mínima",
|
||||||
"strengthMax": "Fuerza máxima",
|
"strengthMax": "Fuerza máxima",
|
||||||
"strength": "Fuerza",
|
"strength": "Fuerza",
|
||||||
|
"clipStrength": "Fuerza de Clip",
|
||||||
"clipSkip": "Clip Skip",
|
"clipSkip": "Clip Skip",
|
||||||
"valuePlaceholder": "Valor",
|
"valuePlaceholder": "Valor",
|
||||||
"add": "Añadir"
|
"add": "Añadir"
|
||||||
|
|||||||
@@ -122,6 +122,20 @@
|
|||||||
"noRemoteImagesAvailable": "Aucune image d'exemple distante disponible pour ce modèle sur Civitai"
|
"noRemoteImagesAvailable": "Aucune image d'exemple distante disponible pour ce modèle sur Civitai"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"globalContextMenu": {
|
||||||
|
"downloadExampleImages": {
|
||||||
|
"label": "Télécharger les images d'exemple",
|
||||||
|
"missingPath": "Définissez un emplacement de téléchargement avant de télécharger les images d'exemple.",
|
||||||
|
"unavailable": "Le téléchargement des images d'exemple n'est pas encore disponible. Réessayez après le chargement complet de la page."
|
||||||
|
},
|
||||||
|
"cleanupExampleImages": {
|
||||||
|
"label": "Nettoyer les dossiers d'images d'exemple",
|
||||||
|
"success": "{count} dossier(s) déplacé(s) vers le dossier supprimé",
|
||||||
|
"none": "Aucun dossier d'images d'exemple à nettoyer",
|
||||||
|
"partial": "Nettoyage terminé avec {failures} dossier(s) ignoré(s)",
|
||||||
|
"error": "Échec du nettoyage des dossiers d'images d'exemple : {message}"
|
||||||
|
}
|
||||||
|
},
|
||||||
"header": {
|
"header": {
|
||||||
"appTitle": "LoRA Manager",
|
"appTitle": "LoRA Manager",
|
||||||
"navigation": {
|
"navigation": {
|
||||||
@@ -727,6 +741,7 @@
|
|||||||
"strengthMin": "Force Min",
|
"strengthMin": "Force Min",
|
||||||
"strengthMax": "Force Max",
|
"strengthMax": "Force Max",
|
||||||
"strength": "Force",
|
"strength": "Force",
|
||||||
|
"clipStrength": "Force Clip",
|
||||||
"clipSkip": "Clip Skip",
|
"clipSkip": "Clip Skip",
|
||||||
"valuePlaceholder": "Valeur",
|
"valuePlaceholder": "Valeur",
|
||||||
"add": "Ajouter"
|
"add": "Ajouter"
|
||||||
|
|||||||
@@ -122,6 +122,20 @@
|
|||||||
"noRemoteImagesAvailable": "このモデルのCivitaiでのリモート例画像は利用できません"
|
"noRemoteImagesAvailable": "このモデルのCivitaiでのリモート例画像は利用できません"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"globalContextMenu": {
|
||||||
|
"downloadExampleImages": {
|
||||||
|
"label": "例画像をダウンロード",
|
||||||
|
"missingPath": "例画像をダウンロードする前にダウンロード場所を設定してください。",
|
||||||
|
"unavailable": "例画像のダウンロードはまだ利用できません。ページの読み込みが完了してから再度お試しください。"
|
||||||
|
},
|
||||||
|
"cleanupExampleImages": {
|
||||||
|
"label": "例画像フォルダをクリーンアップ",
|
||||||
|
"success": "{count} 個のフォルダを削除フォルダに移動しました",
|
||||||
|
"none": "クリーンアップが必要な例画像フォルダはありません",
|
||||||
|
"partial": "クリーンアップが完了しましたが、{failures} 個のフォルダはスキップされました",
|
||||||
|
"error": "例画像フォルダのクリーンアップに失敗しました:{message}"
|
||||||
|
}
|
||||||
|
},
|
||||||
"header": {
|
"header": {
|
||||||
"appTitle": "LoRA Manager",
|
"appTitle": "LoRA Manager",
|
||||||
"navigation": {
|
"navigation": {
|
||||||
@@ -727,6 +741,7 @@
|
|||||||
"strengthMin": "強度最小",
|
"strengthMin": "強度最小",
|
||||||
"strengthMax": "強度最大",
|
"strengthMax": "強度最大",
|
||||||
"strength": "強度",
|
"strength": "強度",
|
||||||
|
"clipStrength": "クリップ強度",
|
||||||
"clipSkip": "Clip Skip",
|
"clipSkip": "Clip Skip",
|
||||||
"valuePlaceholder": "値",
|
"valuePlaceholder": "値",
|
||||||
"add": "追加"
|
"add": "追加"
|
||||||
|
|||||||
@@ -122,6 +122,20 @@
|
|||||||
"noRemoteImagesAvailable": "Civitai에서 이 모델의 원격 예시 이미지를 사용할 수 없습니다"
|
"noRemoteImagesAvailable": "Civitai에서 이 모델의 원격 예시 이미지를 사용할 수 없습니다"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"globalContextMenu": {
|
||||||
|
"downloadExampleImages": {
|
||||||
|
"label": "예시 이미지 다운로드",
|
||||||
|
"missingPath": "예시 이미지를 다운로드하기 전에 다운로드 위치를 설정하세요.",
|
||||||
|
"unavailable": "예시 이미지 다운로드는 아직 사용할 수 없습니다. 페이지 로딩이 완료된 후 다시 시도하세요."
|
||||||
|
},
|
||||||
|
"cleanupExampleImages": {
|
||||||
|
"label": "예시 이미지 폴더 정리",
|
||||||
|
"success": "{count}개의 폴더가 삭제 폴더로 이동되었습니다",
|
||||||
|
"none": "정리가 필요한 예시 이미지 폴더가 없습니다",
|
||||||
|
"partial": "정리가 완료되었으나 {failures}개의 폴더가 건너뛰어졌습니다",
|
||||||
|
"error": "예시 이미지 폴더 정리에 실패했습니다: {message}"
|
||||||
|
}
|
||||||
|
},
|
||||||
"header": {
|
"header": {
|
||||||
"appTitle": "LoRA Manager",
|
"appTitle": "LoRA Manager",
|
||||||
"navigation": {
|
"navigation": {
|
||||||
@@ -727,6 +741,7 @@
|
|||||||
"strengthMin": "최소 강도",
|
"strengthMin": "최소 강도",
|
||||||
"strengthMax": "최대 강도",
|
"strengthMax": "최대 강도",
|
||||||
"strength": "강도",
|
"strength": "강도",
|
||||||
|
"clipStrength": "클립 강도",
|
||||||
"clipSkip": "클립 스킵",
|
"clipSkip": "클립 스킵",
|
||||||
"valuePlaceholder": "값",
|
"valuePlaceholder": "값",
|
||||||
"add": "추가"
|
"add": "추가"
|
||||||
|
|||||||
@@ -122,6 +122,20 @@
|
|||||||
"noRemoteImagesAvailable": "Нет удаленных примеров изображений для этой модели на Civitai"
|
"noRemoteImagesAvailable": "Нет удаленных примеров изображений для этой модели на Civitai"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"globalContextMenu": {
|
||||||
|
"downloadExampleImages": {
|
||||||
|
"label": "Загрузить примеры изображений",
|
||||||
|
"missingPath": "Укажите место загрузки перед загрузкой примеров изображений.",
|
||||||
|
"unavailable": "Загрузка примеров изображений пока недоступна. Попробуйте снова после полной загрузки страницы."
|
||||||
|
},
|
||||||
|
"cleanupExampleImages": {
|
||||||
|
"label": "Очистить папки с примерами изображений",
|
||||||
|
"success": "Перемещено {count} папок в папку удалённых",
|
||||||
|
"none": "Нет папок с примерами изображений, требующих очистки",
|
||||||
|
"partial": "Очистка завершена, пропущено {failures} папок",
|
||||||
|
"error": "Не удалось очистить папки с примерами изображений: {message}"
|
||||||
|
}
|
||||||
|
},
|
||||||
"header": {
|
"header": {
|
||||||
"appTitle": "LoRA Manager",
|
"appTitle": "LoRA Manager",
|
||||||
"navigation": {
|
"navigation": {
|
||||||
@@ -727,6 +741,7 @@
|
|||||||
"strengthMin": "Мин. сила",
|
"strengthMin": "Мин. сила",
|
||||||
"strengthMax": "Макс. сила",
|
"strengthMax": "Макс. сила",
|
||||||
"strength": "Сила",
|
"strength": "Сила",
|
||||||
|
"clipStrength": "Сила клипа",
|
||||||
"clipSkip": "Clip Skip",
|
"clipSkip": "Clip Skip",
|
||||||
"valuePlaceholder": "Значение",
|
"valuePlaceholder": "Значение",
|
||||||
"add": "Добавить"
|
"add": "Добавить"
|
||||||
|
|||||||
@@ -122,6 +122,20 @@
|
|||||||
"noRemoteImagesAvailable": "此模型在 Civitai 上没有远程示例图片"
|
"noRemoteImagesAvailable": "此模型在 Civitai 上没有远程示例图片"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"globalContextMenu": {
|
||||||
|
"downloadExampleImages": {
|
||||||
|
"label": "下载示例图片",
|
||||||
|
"missingPath": "请先设置下载位置后再下载示例图片。",
|
||||||
|
"unavailable": "示例图片下载当前不可用。请在页面加载完成后重试。"
|
||||||
|
},
|
||||||
|
"cleanupExampleImages": {
|
||||||
|
"label": "清理示例图片文件夹",
|
||||||
|
"success": "已将 {count} 个文件夹移动到已删除文件夹",
|
||||||
|
"none": "没有需要清理的示例图片文件夹",
|
||||||
|
"partial": "清理完成,有 {failures} 个文件夹跳过",
|
||||||
|
"error": "清理示例图片文件夹失败:{message}"
|
||||||
|
}
|
||||||
|
},
|
||||||
"header": {
|
"header": {
|
||||||
"appTitle": "LoRA 管理器",
|
"appTitle": "LoRA 管理器",
|
||||||
"navigation": {
|
"navigation": {
|
||||||
@@ -727,6 +741,7 @@
|
|||||||
"strengthMin": "最小强度",
|
"strengthMin": "最小强度",
|
||||||
"strengthMax": "最大强度",
|
"strengthMax": "最大强度",
|
||||||
"strength": "强度",
|
"strength": "强度",
|
||||||
|
"clipStrength": "Clip 强度",
|
||||||
"clipSkip": "Clip Skip",
|
"clipSkip": "Clip Skip",
|
||||||
"valuePlaceholder": "数值",
|
"valuePlaceholder": "数值",
|
||||||
"add": "添加"
|
"add": "添加"
|
||||||
|
|||||||
@@ -122,6 +122,20 @@
|
|||||||
"noRemoteImagesAvailable": "此模型在 Civitai 上無遠端範例圖片"
|
"noRemoteImagesAvailable": "此模型在 Civitai 上無遠端範例圖片"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"globalContextMenu": {
|
||||||
|
"downloadExampleImages": {
|
||||||
|
"label": "下載範例圖片",
|
||||||
|
"missingPath": "請先設定下載位置再下載範例圖片。",
|
||||||
|
"unavailable": "範例圖片下載目前尚不可用。請在頁面載入完成後再試一次。"
|
||||||
|
},
|
||||||
|
"cleanupExampleImages": {
|
||||||
|
"label": "清理範例圖片資料夾",
|
||||||
|
"success": "已將 {count} 個資料夾移至已刪除資料夾",
|
||||||
|
"none": "沒有需要清理的範例圖片資料夾",
|
||||||
|
"partial": "清理完成,有 {failures} 個資料夾略過",
|
||||||
|
"error": "清理範例圖片資料夾失敗:{message}"
|
||||||
|
}
|
||||||
|
},
|
||||||
"header": {
|
"header": {
|
||||||
"appTitle": "LoRA 管理器",
|
"appTitle": "LoRA 管理器",
|
||||||
"navigation": {
|
"navigation": {
|
||||||
@@ -727,6 +741,7 @@
|
|||||||
"strengthMin": "最小強度",
|
"strengthMin": "最小強度",
|
||||||
"strengthMax": "最大強度",
|
"strengthMax": "最大強度",
|
||||||
"strength": "強度",
|
"strength": "強度",
|
||||||
|
"clipStrength": "Clip 強度",
|
||||||
"clipSkip": "Clip Skip",
|
"clipSkip": "Clip Skip",
|
||||||
"valuePlaceholder": "數值",
|
"valuePlaceholder": "數值",
|
||||||
"add": "新增"
|
"add": "新增"
|
||||||
|
|||||||
2572
package-lock.json
generated
Normal file
2572
package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
14
package.json
Normal file
14
package.json
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
{
|
||||||
|
"name": "comfyui-lora-manager-frontend",
|
||||||
|
"version": "0.1.0",
|
||||||
|
"private": true,
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"test": "vitest run",
|
||||||
|
"test:watch": "vitest"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"jsdom": "^24.0.0",
|
||||||
|
"vitest": "^1.6.0"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
"""Project namespace package."""
|
||||||
|
|
||||||
|
# pytest's internal compatibility layer still imports ``py.path.local`` from the
|
||||||
|
# historical ``py`` dependency. Because this project reuses the ``py`` package
|
||||||
|
# name, we expose a minimal shim so ``py.path.local`` resolves to ``pathlib.Path``
|
||||||
|
# during test runs without pulling in the external dependency.
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
path = SimpleNamespace(local=Path)
|
||||||
|
|
||||||
|
__all__ = ["path"]
|
||||||
|
|||||||
@@ -3,12 +3,11 @@ import platform
|
|||||||
import folder_paths # type: ignore
|
import folder_paths # type: ignore
|
||||||
from typing import List
|
from typing import List
|
||||||
import logging
|
import logging
|
||||||
import sys
|
|
||||||
import json
|
import json
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
|
||||||
# Check if running in standalone mode
|
# Use an environment variable to control standalone mode
|
||||||
standalone_mode = 'nodes' not in sys.modules
|
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from .services.service_registry import ServiceRegistry
|
|||||||
from .services.settings_manager import settings
|
from .services.settings_manager import settings
|
||||||
from .utils.example_images_migration import ExampleImagesMigration
|
from .utils.example_images_migration import ExampleImagesMigration
|
||||||
from .services.websocket_manager import ws_manager
|
from .services.websocket_manager import ws_manager
|
||||||
|
from .services.example_images_cleanup_service import ExampleImagesCleanupService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -166,7 +167,7 @@ class LoraManager:
|
|||||||
RecipeRoutes.setup_routes(app)
|
RecipeRoutes.setup_routes(app)
|
||||||
UpdateRoutes.setup_routes(app)
|
UpdateRoutes.setup_routes(app)
|
||||||
MiscRoutes.setup_routes(app)
|
MiscRoutes.setup_routes(app)
|
||||||
ExampleImagesRoutes.setup_routes(app)
|
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
|
||||||
|
|
||||||
# Setup WebSocket routes that are shared across all model types
|
# Setup WebSocket routes that are shared across all model types
|
||||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||||
@@ -240,7 +241,6 @@ class LoraManager:
|
|||||||
# Run post-initialization tasks
|
# Run post-initialization tasks
|
||||||
post_tasks = [
|
post_tasks = [
|
||||||
asyncio.create_task(cls._cleanup_backup_files(), name='cleanup_bak_files'),
|
asyncio.create_task(cls._cleanup_backup_files(), name='cleanup_bak_files'),
|
||||||
asyncio.create_task(cls._cleanup_example_images_folders(), name='cleanup_example_images'),
|
|
||||||
# Add more post-initialization tasks here as needed
|
# Add more post-initialization tasks here as needed
|
||||||
# asyncio.create_task(cls._another_post_task(), name='another_task'),
|
# asyncio.create_task(cls._another_post_task(), name='another_task'),
|
||||||
]
|
]
|
||||||
@@ -352,116 +352,37 @@ class LoraManager:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _cleanup_example_images_folders(cls):
|
async def _cleanup_example_images_folders(cls):
|
||||||
"""Clean up invalid or empty folders in example images directory"""
|
"""Invoke the example images cleanup service for manual execution."""
|
||||||
try:
|
try:
|
||||||
example_images_path = settings.get('example_images_path')
|
service = ExampleImagesCleanupService()
|
||||||
if not example_images_path or not os.path.exists(example_images_path):
|
result = await service.cleanup_example_image_folders()
|
||||||
logger.debug("Example images path not configured or doesn't exist, skipping cleanup")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug(f"Starting cleanup of example images folders in: {example_images_path}")
|
|
||||||
|
|
||||||
# Get all scanner instances to check hash validity
|
|
||||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
|
||||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
|
||||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
|
||||||
|
|
||||||
total_folders_checked = 0
|
|
||||||
empty_folders_removed = 0
|
|
||||||
orphaned_folders_removed = 0
|
|
||||||
|
|
||||||
# Scan the example images directory
|
|
||||||
try:
|
|
||||||
with os.scandir(example_images_path) as it:
|
|
||||||
for entry in it:
|
|
||||||
if not entry.is_dir(follow_symlinks=False):
|
|
||||||
continue
|
|
||||||
|
|
||||||
folder_name = entry.name
|
|
||||||
folder_path = entry.path
|
|
||||||
total_folders_checked += 1
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Check if folder is empty
|
|
||||||
is_empty = cls._is_folder_empty(folder_path)
|
|
||||||
if is_empty:
|
|
||||||
logger.debug(f"Removing empty example images folder: {folder_name}")
|
|
||||||
await cls._remove_folder_safely(folder_path)
|
|
||||||
empty_folders_removed += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if folder name is a valid SHA256 hash (64 hex characters)
|
|
||||||
if len(folder_name) != 64 or not all(c in '0123456789abcdefABCDEF' for c in folder_name):
|
|
||||||
# Skip non-hash folders to avoid deleting other content
|
|
||||||
logger.debug(f"Skipping non-hash folder: {folder_name}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if hash exists in any of the scanners
|
|
||||||
hash_exists = (
|
|
||||||
lora_scanner.has_hash(folder_name) or
|
|
||||||
checkpoint_scanner.has_hash(folder_name) or
|
|
||||||
embedding_scanner.has_hash(folder_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not hash_exists:
|
|
||||||
logger.debug(f"Removing example images folder for deleted model: {folder_name}")
|
|
||||||
await cls._remove_folder_safely(folder_path)
|
|
||||||
orphaned_folders_removed += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
except Exception as e:
|
if result.get('success'):
|
||||||
logger.error(f"Error processing example images folder {folder_name}: {e}")
|
logger.debug(
|
||||||
|
"Manual example images cleanup completed: moved=%s",
|
||||||
# Yield control periodically
|
result.get('moved_total'),
|
||||||
await asyncio.sleep(0.01)
|
)
|
||||||
|
elif result.get('partial_success'):
|
||||||
except Exception as e:
|
logger.warning(
|
||||||
logger.error(f"Error scanning example images directory: {e}")
|
"Manual example images cleanup partially succeeded: moved=%s failures=%s",
|
||||||
return
|
result.get('moved_total'),
|
||||||
|
result.get('move_failures'),
|
||||||
# Log final cleanup report
|
)
|
||||||
total_removed = empty_folders_removed + orphaned_folders_removed
|
|
||||||
if total_removed > 0:
|
|
||||||
logger.info(f"Example images cleanup completed: checked {total_folders_checked} folders, "
|
|
||||||
f"removed {empty_folders_removed} empty folders and {orphaned_folders_removed} "
|
|
||||||
f"folders for deleted models (total: {total_removed} removed)")
|
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Example images cleanup completed: checked {total_folders_checked} folders, "
|
logger.debug(
|
||||||
f"no cleanup needed")
|
"Manual example images cleanup skipped or failed: %s",
|
||||||
|
result.get('error', 'no changes'),
|
||||||
except Exception as e:
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e: # pragma: no cover - defensive guard
|
||||||
logger.error(f"Error during example images cleanup: {e}", exc_info=True)
|
logger.error(f"Error during example images cleanup: {e}", exc_info=True)
|
||||||
|
return {
|
||||||
@classmethod
|
'success': False,
|
||||||
def _is_folder_empty(cls, folder_path: str) -> bool:
|
'error': str(e),
|
||||||
"""Check if a folder is empty
|
'error_code': 'unexpected_error',
|
||||||
|
}
|
||||||
Args:
|
|
||||||
folder_path: Path to the folder to check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if folder is empty, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
with os.scandir(folder_path) as it:
|
|
||||||
return not any(it)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"Error checking if folder is empty {folder_path}: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def _remove_folder_safely(cls, folder_path: str):
|
|
||||||
"""Safely remove a folder and all its contents
|
|
||||||
|
|
||||||
Args:
|
|
||||||
folder_path: Path to the folder to remove
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
import shutil
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
await loop.run_in_executor(None, shutil.rmtree, folder_path)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to remove folder {folder_path}: {e}")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _cleanup(cls, app):
|
async def _cleanup(cls, app):
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import importlib
|
|
||||||
import sys
|
|
||||||
|
|
||||||
# Check if running in standalone mode
|
# Check if running in standalone mode
|
||||||
standalone_mode = 'nodes' not in sys.modules
|
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||||
|
|
||||||
if not standalone_mode:
|
if not standalone_mode:
|
||||||
from .metadata_hook import MetadataHook
|
from .metadata_hook import MetadataHook
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
import sys
|
import os
|
||||||
from .constants import IMAGES
|
from .constants import IMAGES
|
||||||
|
|
||||||
# Check if running in standalone mode
|
# Check if running in standalone mode
|
||||||
standalone_mode = 'nodes' not in sys.modules
|
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||||
|
|
||||||
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IS_SAMPLER
|
from .constants import MODELS, PROMPTS, SAMPLING, LORAS, SIZE, IS_SAMPLER
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
217
py/routes/base_recipe_routes.py
Normal file
217
py/routes/base_recipe_routes.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
"""Base infrastructure shared across recipe routes."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Callable, Mapping
|
||||||
|
|
||||||
|
import jinja2
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from ..config import config
|
||||||
|
from ..recipes import RecipeParserFactory
|
||||||
|
from ..services.downloader import get_downloader
|
||||||
|
from ..services.recipes import (
|
||||||
|
RecipeAnalysisService,
|
||||||
|
RecipePersistenceService,
|
||||||
|
RecipeSharingService,
|
||||||
|
)
|
||||||
|
from ..services.server_i18n import server_i18n
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
from ..services.settings_manager import settings
|
||||||
|
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||||
|
from ..utils.exif_utils import ExifUtils
|
||||||
|
from .handlers.recipe_handlers import (
|
||||||
|
RecipeAnalysisHandler,
|
||||||
|
RecipeHandlerSet,
|
||||||
|
RecipeListingHandler,
|
||||||
|
RecipeManagementHandler,
|
||||||
|
RecipePageView,
|
||||||
|
RecipeQueryHandler,
|
||||||
|
RecipeSharingHandler,
|
||||||
|
)
|
||||||
|
from .recipe_route_registrar import ROUTE_DEFINITIONS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRecipeRoutes:
|
||||||
|
"""Common dependency and startup wiring for recipe routes."""
|
||||||
|
|
||||||
|
_HANDLER_NAMES: tuple[str, ...] = tuple(
|
||||||
|
definition.handler_name for definition in ROUTE_DEFINITIONS
|
||||||
|
)
|
||||||
|
|
||||||
|
template_name: str = "recipes.html"
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.recipe_scanner = None
|
||||||
|
self.lora_scanner = None
|
||||||
|
self.civitai_client = None
|
||||||
|
self.settings = settings
|
||||||
|
self.server_i18n = server_i18n
|
||||||
|
self.template_env = jinja2.Environment(
|
||||||
|
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||||
|
autoescape=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._i18n_registered = False
|
||||||
|
self._startup_hooks_registered = False
|
||||||
|
self._handler_set: RecipeHandlerSet | None = None
|
||||||
|
self._handler_mapping: dict[str, Callable] | None = None
|
||||||
|
|
||||||
|
async def attach_dependencies(self, app: web.Application | None = None) -> None:
|
||||||
|
"""Resolve shared services from the registry."""
|
||||||
|
|
||||||
|
await self._ensure_services()
|
||||||
|
self._ensure_i18n_filter()
|
||||||
|
|
||||||
|
async def ensure_dependencies_ready(self) -> None:
|
||||||
|
"""Ensure dependencies are available for request handlers."""
|
||||||
|
|
||||||
|
if self.recipe_scanner is None or self.civitai_client is None:
|
||||||
|
await self.attach_dependencies()
|
||||||
|
|
||||||
|
def register_startup_hooks(self, app: web.Application) -> None:
|
||||||
|
"""Register startup hooks once for dependency wiring."""
|
||||||
|
|
||||||
|
if self._startup_hooks_registered:
|
||||||
|
return
|
||||||
|
|
||||||
|
app.on_startup.append(self.attach_dependencies)
|
||||||
|
app.on_startup.append(self.prewarm_cache)
|
||||||
|
self._startup_hooks_registered = True
|
||||||
|
|
||||||
|
async def prewarm_cache(self, app: web.Application | None = None) -> None:
|
||||||
|
"""Pre-load recipe and LoRA caches on startup."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.attach_dependencies(app)
|
||||||
|
|
||||||
|
if self.lora_scanner is not None:
|
||||||
|
await self.lora_scanner.get_cached_data()
|
||||||
|
hash_index = getattr(self.lora_scanner, "_hash_index", None)
|
||||||
|
if hash_index is not None and hasattr(hash_index, "_hash_to_path"):
|
||||||
|
_ = len(hash_index._hash_to_path)
|
||||||
|
|
||||||
|
if self.recipe_scanner is not None:
|
||||||
|
await self.recipe_scanner.get_cached_data(force_refresh=True)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Error pre-warming recipe cache: %s", exc, exc_info=True)
|
||||||
|
|
||||||
|
def to_route_mapping(self) -> Mapping[str, Callable]:
|
||||||
|
"""Return a mapping of handler name to coroutine for registrar binding."""
|
||||||
|
|
||||||
|
if self._handler_mapping is None:
|
||||||
|
handler_set = self._create_handler_set()
|
||||||
|
self._handler_set = handler_set
|
||||||
|
self._handler_mapping = handler_set.to_route_mapping()
|
||||||
|
return self._handler_mapping
|
||||||
|
|
||||||
|
# Internal helpers -------------------------------------------------
|
||||||
|
|
||||||
|
async def _ensure_services(self) -> None:
|
||||||
|
if self.recipe_scanner is None:
|
||||||
|
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||||
|
self.lora_scanner = getattr(self.recipe_scanner, "_lora_scanner", None)
|
||||||
|
|
||||||
|
if self.civitai_client is None:
|
||||||
|
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||||
|
|
||||||
|
def _ensure_i18n_filter(self) -> None:
|
||||||
|
if not self._i18n_registered:
|
||||||
|
self.template_env.filters["t"] = self.server_i18n.create_template_filter()
|
||||||
|
self._i18n_registered = True
|
||||||
|
|
||||||
|
def get_handler_owner(self):
|
||||||
|
"""Return the object supplying bound handler coroutines."""
|
||||||
|
|
||||||
|
if self._handler_set is None:
|
||||||
|
self._handler_set = self._create_handler_set()
|
||||||
|
return self._handler_set
|
||||||
|
|
||||||
|
def _create_handler_set(self) -> RecipeHandlerSet:
|
||||||
|
recipe_scanner_getter = lambda: self.recipe_scanner
|
||||||
|
civitai_client_getter = lambda: self.civitai_client
|
||||||
|
|
||||||
|
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||||
|
if not standalone_mode:
|
||||||
|
from ..metadata_collector import get_metadata # type: ignore[import-not-found]
|
||||||
|
from ..metadata_collector.metadata_processor import ( # type: ignore[import-not-found]
|
||||||
|
MetadataProcessor,
|
||||||
|
)
|
||||||
|
from ..metadata_collector.metadata_registry import ( # type: ignore[import-not-found]
|
||||||
|
MetadataRegistry,
|
||||||
|
)
|
||||||
|
else: # pragma: no cover - optional dependency path
|
||||||
|
get_metadata = None # type: ignore[assignment]
|
||||||
|
MetadataProcessor = None # type: ignore[assignment]
|
||||||
|
MetadataRegistry = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
analysis_service = RecipeAnalysisService(
|
||||||
|
exif_utils=ExifUtils,
|
||||||
|
recipe_parser_factory=RecipeParserFactory,
|
||||||
|
downloader_factory=get_downloader,
|
||||||
|
metadata_collector=get_metadata,
|
||||||
|
metadata_processor_cls=MetadataProcessor,
|
||||||
|
metadata_registry_cls=MetadataRegistry,
|
||||||
|
standalone_mode=standalone_mode,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
persistence_service = RecipePersistenceService(
|
||||||
|
exif_utils=ExifUtils,
|
||||||
|
card_preview_width=CARD_PREVIEW_WIDTH,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
sharing_service = RecipeSharingService(logger=logger)
|
||||||
|
|
||||||
|
page_view = RecipePageView(
|
||||||
|
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||||
|
settings_service=self.settings,
|
||||||
|
server_i18n=self.server_i18n,
|
||||||
|
template_env=self.template_env,
|
||||||
|
template_name=self.template_name,
|
||||||
|
recipe_scanner_getter=recipe_scanner_getter,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
listing = RecipeListingHandler(
|
||||||
|
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||||
|
recipe_scanner_getter=recipe_scanner_getter,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
query = RecipeQueryHandler(
|
||||||
|
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||||
|
recipe_scanner_getter=recipe_scanner_getter,
|
||||||
|
format_recipe_file_url=listing.format_recipe_file_url,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
management = RecipeManagementHandler(
|
||||||
|
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||||
|
recipe_scanner_getter=recipe_scanner_getter,
|
||||||
|
logger=logger,
|
||||||
|
persistence_service=persistence_service,
|
||||||
|
analysis_service=analysis_service,
|
||||||
|
)
|
||||||
|
analysis = RecipeAnalysisHandler(
|
||||||
|
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||||
|
recipe_scanner_getter=recipe_scanner_getter,
|
||||||
|
civitai_client_getter=civitai_client_getter,
|
||||||
|
logger=logger,
|
||||||
|
analysis_service=analysis_service,
|
||||||
|
)
|
||||||
|
sharing = RecipeSharingHandler(
|
||||||
|
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||||
|
recipe_scanner_getter=recipe_scanner_getter,
|
||||||
|
logger=logger,
|
||||||
|
sharing_service=sharing_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
return RecipeHandlerSet(
|
||||||
|
page_view=page_view,
|
||||||
|
listing=listing,
|
||||||
|
query=query,
|
||||||
|
management=management,
|
||||||
|
analysis=analysis,
|
||||||
|
sharing=sharing,
|
||||||
|
)
|
||||||
|
|
||||||
@@ -2,9 +2,9 @@ import logging
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from .base_model_routes import BaseModelRoutes
|
from .base_model_routes import BaseModelRoutes
|
||||||
|
from .model_route_registrar import ModelRouteRegistrar
|
||||||
from ..services.checkpoint_service import CheckpointService
|
from ..services.checkpoint_service import CheckpointService
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..services.metadata_service import get_default_metadata_provider
|
|
||||||
from ..config import config
|
from ..config import config
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -14,8 +14,7 @@ class CheckpointRoutes(BaseModelRoutes):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize Checkpoint routes with Checkpoint service"""
|
"""Initialize Checkpoint routes with Checkpoint service"""
|
||||||
# Service will be initialized later via setup_routes
|
super().__init__()
|
||||||
self.service = None
|
|
||||||
self.template_name = "checkpoints.html"
|
self.template_name = "checkpoints.html"
|
||||||
|
|
||||||
async def initialize_services(self):
|
async def initialize_services(self):
|
||||||
@@ -23,8 +22,8 @@ class CheckpointRoutes(BaseModelRoutes):
|
|||||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
self.service = CheckpointService(checkpoint_scanner)
|
self.service = CheckpointService(checkpoint_scanner)
|
||||||
|
|
||||||
# Initialize parent with the service
|
# Attach service dependencies
|
||||||
super().__init__(self.service)
|
self.attach_service(self.service)
|
||||||
|
|
||||||
def setup_routes(self, app: web.Application):
|
def setup_routes(self, app: web.Application):
|
||||||
"""Setup Checkpoint routes"""
|
"""Setup Checkpoint routes"""
|
||||||
@@ -34,14 +33,14 @@ class CheckpointRoutes(BaseModelRoutes):
|
|||||||
# Setup common routes with 'checkpoints' prefix (includes page route)
|
# Setup common routes with 'checkpoints' prefix (includes page route)
|
||||||
super().setup_routes(app, 'checkpoints')
|
super().setup_routes(app, 'checkpoints')
|
||||||
|
|
||||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
|
||||||
"""Setup Checkpoint-specific routes"""
|
"""Setup Checkpoint-specific routes"""
|
||||||
# Checkpoint info by name
|
# Checkpoint info by name
|
||||||
app.router.add_get(f'/api/lm/{prefix}/info/{{name}}', self.get_checkpoint_info)
|
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_checkpoint_info)
|
||||||
|
|
||||||
# Checkpoint roots and Unet roots
|
# Checkpoint roots and Unet roots
|
||||||
app.router.add_get(f'/api/lm/{prefix}/checkpoints_roots', self.get_checkpoints_roots)
|
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/checkpoints_roots', prefix, self.get_checkpoints_roots)
|
||||||
app.router.add_get(f'/api/lm/{prefix}/unet_roots', self.get_unet_roots)
|
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/unet_roots', prefix, self.get_unet_roots)
|
||||||
|
|
||||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||||
"""Validate CivitAI model type for Checkpoint"""
|
"""Validate CivitAI model type for Checkpoint"""
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ import logging
|
|||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from .base_model_routes import BaseModelRoutes
|
from .base_model_routes import BaseModelRoutes
|
||||||
|
from .model_route_registrar import ModelRouteRegistrar
|
||||||
from ..services.embedding_service import EmbeddingService
|
from ..services.embedding_service import EmbeddingService
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..services.metadata_service import get_default_metadata_provider
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -13,8 +13,7 @@ class EmbeddingRoutes(BaseModelRoutes):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize Embedding routes with Embedding service"""
|
"""Initialize Embedding routes with Embedding service"""
|
||||||
# Service will be initialized later via setup_routes
|
super().__init__()
|
||||||
self.service = None
|
|
||||||
self.template_name = "embeddings.html"
|
self.template_name = "embeddings.html"
|
||||||
|
|
||||||
async def initialize_services(self):
|
async def initialize_services(self):
|
||||||
@@ -22,8 +21,8 @@ class EmbeddingRoutes(BaseModelRoutes):
|
|||||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||||
self.service = EmbeddingService(embedding_scanner)
|
self.service = EmbeddingService(embedding_scanner)
|
||||||
|
|
||||||
# Initialize parent with the service
|
# Attach service dependencies
|
||||||
super().__init__(self.service)
|
self.attach_service(self.service)
|
||||||
|
|
||||||
def setup_routes(self, app: web.Application):
|
def setup_routes(self, app: web.Application):
|
||||||
"""Setup Embedding routes"""
|
"""Setup Embedding routes"""
|
||||||
@@ -33,10 +32,10 @@ class EmbeddingRoutes(BaseModelRoutes):
|
|||||||
# Setup common routes with 'embeddings' prefix (includes page route)
|
# Setup common routes with 'embeddings' prefix (includes page route)
|
||||||
super().setup_routes(app, 'embeddings')
|
super().setup_routes(app, 'embeddings')
|
||||||
|
|
||||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
|
||||||
"""Setup Embedding-specific routes"""
|
"""Setup Embedding-specific routes"""
|
||||||
# Embedding info by name
|
# Embedding info by name
|
||||||
app.router.add_get(f'/api/lm/{prefix}/info/{{name}}', self.get_embedding_info)
|
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/info/{name}', prefix, self.get_embedding_info)
|
||||||
|
|
||||||
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
def _validate_civitai_model_type(self, model_type: str) -> bool:
|
||||||
"""Validate CivitAI model type for Embedding"""
|
"""Validate CivitAI model type for Embedding"""
|
||||||
|
|||||||
62
py/routes/example_images_route_registrar.py
Normal file
62
py/routes/example_images_route_registrar.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""Route registrar for example image endpoints."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Iterable, Mapping
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RouteDefinition:
|
||||||
|
"""Declarative configuration for a HTTP route."""
|
||||||
|
|
||||||
|
method: str
|
||||||
|
path: str
|
||||||
|
handler_name: str
|
||||||
|
|
||||||
|
|
||||||
|
ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||||
|
RouteDefinition("POST", "/api/lm/download-example-images", "download_example_images"),
|
||||||
|
RouteDefinition("POST", "/api/lm/import-example-images", "import_example_images"),
|
||||||
|
RouteDefinition("GET", "/api/lm/example-images-status", "get_example_images_status"),
|
||||||
|
RouteDefinition("POST", "/api/lm/pause-example-images", "pause_example_images"),
|
||||||
|
RouteDefinition("POST", "/api/lm/resume-example-images", "resume_example_images"),
|
||||||
|
RouteDefinition("POST", "/api/lm/open-example-images-folder", "open_example_images_folder"),
|
||||||
|
RouteDefinition("GET", "/api/lm/example-image-files", "get_example_image_files"),
|
||||||
|
RouteDefinition("GET", "/api/lm/has-example-images", "has_example_images"),
|
||||||
|
RouteDefinition("POST", "/api/lm/delete-example-image", "delete_example_image"),
|
||||||
|
RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"),
|
||||||
|
RouteDefinition("POST", "/api/lm/cleanup-example-image-folders", "cleanup_example_image_folders"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleImagesRouteRegistrar:
|
||||||
|
"""Bind declarative example image routes to an aiohttp router."""
|
||||||
|
|
||||||
|
_METHOD_MAP = {
|
||||||
|
"GET": "add_get",
|
||||||
|
"POST": "add_post",
|
||||||
|
"PUT": "add_put",
|
||||||
|
"DELETE": "add_delete",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, app: web.Application) -> None:
|
||||||
|
self._app = app
|
||||||
|
|
||||||
|
def register_routes(
|
||||||
|
self,
|
||||||
|
handler_lookup: Mapping[str, Callable[[web.Request], object]],
|
||||||
|
*,
|
||||||
|
definitions: Iterable[RouteDefinition] = ROUTE_DEFINITIONS,
|
||||||
|
) -> None:
|
||||||
|
"""Register each route definition using the supplied handlers."""
|
||||||
|
|
||||||
|
for definition in definitions:
|
||||||
|
handler = handler_lookup[definition.handler_name]
|
||||||
|
self._bind_route(definition.method, definition.path, handler)
|
||||||
|
|
||||||
|
def _bind_route(self, method: str, path: str, handler: Callable[[web.Request], object]) -> None:
|
||||||
|
add_method_name = self._METHOD_MAP[method.upper()]
|
||||||
|
add_method = getattr(self._app.router, add_method_name)
|
||||||
|
add_method(path, handler)
|
||||||
@@ -1,74 +1,88 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from ..utils.example_images_download_manager import DownloadManager
|
from typing import Callable, Mapping
|
||||||
from ..utils.example_images_processor import ExampleImagesProcessor
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from .example_images_route_registrar import ExampleImagesRouteRegistrar
|
||||||
|
from .handlers.example_images_handlers import (
|
||||||
|
ExampleImagesDownloadHandler,
|
||||||
|
ExampleImagesFileHandler,
|
||||||
|
ExampleImagesHandlerSet,
|
||||||
|
ExampleImagesManagementHandler,
|
||||||
|
)
|
||||||
|
from ..services.use_cases.example_images import (
|
||||||
|
DownloadExampleImagesUseCase,
|
||||||
|
ImportExampleImagesUseCase,
|
||||||
|
)
|
||||||
|
from ..utils.example_images_download_manager import (
|
||||||
|
DownloadManager,
|
||||||
|
get_default_download_manager,
|
||||||
|
)
|
||||||
from ..utils.example_images_file_manager import ExampleImagesFileManager
|
from ..utils.example_images_file_manager import ExampleImagesFileManager
|
||||||
from ..services.websocket_manager import ws_manager
|
from ..utils.example_images_processor import ExampleImagesProcessor
|
||||||
|
from ..services.example_images_cleanup_service import ExampleImagesCleanupService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ExampleImagesRoutes:
|
class ExampleImagesRoutes:
|
||||||
"""Routes for example images related functionality"""
|
"""Route controller for example image endpoints."""
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def setup_routes(app):
|
|
||||||
"""Register example images routes"""
|
|
||||||
app.router.add_post('/api/lm/download-example-images', ExampleImagesRoutes.download_example_images)
|
|
||||||
app.router.add_post('/api/lm/import-example-images', ExampleImagesRoutes.import_example_images)
|
|
||||||
app.router.add_get('/api/lm/example-images-status', ExampleImagesRoutes.get_example_images_status)
|
|
||||||
app.router.add_post('/api/lm/pause-example-images', ExampleImagesRoutes.pause_example_images)
|
|
||||||
app.router.add_post('/api/lm/resume-example-images', ExampleImagesRoutes.resume_example_images)
|
|
||||||
app.router.add_post('/api/lm/open-example-images-folder', ExampleImagesRoutes.open_example_images_folder)
|
|
||||||
app.router.add_get('/api/lm/example-image-files', ExampleImagesRoutes.get_example_image_files)
|
|
||||||
app.router.add_get('/api/lm/has-example-images', ExampleImagesRoutes.has_example_images)
|
|
||||||
app.router.add_post('/api/lm/delete-example-image', ExampleImagesRoutes.delete_example_image)
|
|
||||||
app.router.add_post('/api/lm/force-download-example-images', ExampleImagesRoutes.force_download_example_images)
|
|
||||||
|
|
||||||
@staticmethod
|
def __init__(
|
||||||
async def download_example_images(request):
|
self,
|
||||||
"""Download example images for models from Civitai"""
|
*,
|
||||||
return await DownloadManager.start_download(request)
|
ws_manager,
|
||||||
|
download_manager: DownloadManager | None = None,
|
||||||
|
processor=ExampleImagesProcessor,
|
||||||
|
file_manager=ExampleImagesFileManager,
|
||||||
|
cleanup_service: ExampleImagesCleanupService | None = None,
|
||||||
|
) -> None:
|
||||||
|
if ws_manager is None:
|
||||||
|
raise ValueError("ws_manager is required")
|
||||||
|
self._download_manager = download_manager or get_default_download_manager(ws_manager)
|
||||||
|
self._processor = processor
|
||||||
|
self._file_manager = file_manager
|
||||||
|
self._cleanup_service = cleanup_service or ExampleImagesCleanupService()
|
||||||
|
self._handler_set: ExampleImagesHandlerSet | None = None
|
||||||
|
self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None
|
||||||
|
|
||||||
@staticmethod
|
@classmethod
|
||||||
async def get_example_images_status(request):
|
def setup_routes(cls, app: web.Application, *, ws_manager) -> None:
|
||||||
"""Get the current status of example images download"""
|
"""Register routes on the given aiohttp application using default wiring."""
|
||||||
return await DownloadManager.get_status(request)
|
|
||||||
|
|
||||||
@staticmethod
|
controller = cls(ws_manager=ws_manager)
|
||||||
async def pause_example_images(request):
|
controller.register(app)
|
||||||
"""Pause the example images download"""
|
|
||||||
return await DownloadManager.pause_download(request)
|
|
||||||
|
|
||||||
@staticmethod
|
def register(self, app: web.Application) -> None:
|
||||||
async def resume_example_images(request):
|
"""Bind the controller's handlers to the aiohttp router."""
|
||||||
"""Resume the example images download"""
|
|
||||||
return await DownloadManager.resume_download(request)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def open_example_images_folder(request):
|
|
||||||
"""Open the example images folder for a specific model"""
|
|
||||||
return await ExampleImagesFileManager.open_folder(request)
|
|
||||||
|
|
||||||
@staticmethod
|
registrar = ExampleImagesRouteRegistrar(app)
|
||||||
async def get_example_image_files(request):
|
registrar.register_routes(self.to_route_mapping())
|
||||||
"""Get list of example image files for a specific model"""
|
|
||||||
return await ExampleImagesFileManager.get_files(request)
|
|
||||||
|
|
||||||
@staticmethod
|
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||||
async def import_example_images(request):
|
"""Return the registrar-compatible mapping of handler names to callables."""
|
||||||
"""Import local example images for a model"""
|
|
||||||
return await ExampleImagesProcessor.import_images(request)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def has_example_images(request):
|
|
||||||
"""Check if example images folder exists and is not empty for a model"""
|
|
||||||
return await ExampleImagesFileManager.has_images(request)
|
|
||||||
|
|
||||||
@staticmethod
|
if self._handler_mapping is None:
|
||||||
async def delete_example_image(request):
|
handler_set = self._build_handler_set()
|
||||||
"""Delete a custom example image for a model"""
|
self._handler_set = handler_set
|
||||||
return await ExampleImagesProcessor.delete_custom_image(request)
|
self._handler_mapping = handler_set.to_route_mapping()
|
||||||
|
return self._handler_mapping
|
||||||
|
|
||||||
@staticmethod
|
def _build_handler_set(self) -> ExampleImagesHandlerSet:
|
||||||
async def force_download_example_images(request):
|
logger.debug("Building ExampleImagesHandlerSet with %s, %s, %s", self._download_manager, self._processor, self._file_manager)
|
||||||
"""Force download example images for specific models"""
|
download_use_case = DownloadExampleImagesUseCase(download_manager=self._download_manager)
|
||||||
return await DownloadManager.start_force_download(request)
|
download_handler = ExampleImagesDownloadHandler(download_use_case, self._download_manager)
|
||||||
|
import_use_case = ImportExampleImagesUseCase(processor=self._processor)
|
||||||
|
management_handler = ExampleImagesManagementHandler(
|
||||||
|
import_use_case,
|
||||||
|
self._processor,
|
||||||
|
self._cleanup_service,
|
||||||
|
)
|
||||||
|
file_handler = ExampleImagesFileHandler(self._file_manager)
|
||||||
|
return ExampleImagesHandlerSet(
|
||||||
|
download=download_handler,
|
||||||
|
management=management_handler,
|
||||||
|
files=file_handler,
|
||||||
|
)
|
||||||
|
|||||||
159
py/routes/handlers/example_images_handlers.py
Normal file
159
py/routes/handlers/example_images_handlers.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
"""Handler set for example image routes."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Mapping
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from ...services.use_cases.example_images import (
|
||||||
|
DownloadExampleImagesConfigurationError,
|
||||||
|
DownloadExampleImagesInProgressError,
|
||||||
|
DownloadExampleImagesUseCase,
|
||||||
|
ImportExampleImagesUseCase,
|
||||||
|
ImportExampleImagesValidationError,
|
||||||
|
)
|
||||||
|
from ...utils.example_images_download_manager import (
|
||||||
|
DownloadConfigurationError,
|
||||||
|
DownloadInProgressError,
|
||||||
|
DownloadNotRunningError,
|
||||||
|
ExampleImagesDownloadError,
|
||||||
|
)
|
||||||
|
from ...utils.example_images_processor import ExampleImagesImportError
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleImagesDownloadHandler:
|
||||||
|
"""HTTP adapters for download-related example image endpoints."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
download_use_case: DownloadExampleImagesUseCase,
|
||||||
|
download_manager,
|
||||||
|
) -> None:
|
||||||
|
self._download_use_case = download_use_case
|
||||||
|
self._download_manager = download_manager
|
||||||
|
|
||||||
|
async def download_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
result = await self._download_use_case.execute(payload)
|
||||||
|
return web.json_response(result)
|
||||||
|
except DownloadExampleImagesInProgressError as exc:
|
||||||
|
response = {
|
||||||
|
'success': False,
|
||||||
|
'error': str(exc),
|
||||||
|
'status': exc.progress,
|
||||||
|
}
|
||||||
|
return web.json_response(response, status=400)
|
||||||
|
except DownloadExampleImagesConfigurationError as exc:
|
||||||
|
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||||
|
except ExampleImagesDownloadError as exc:
|
||||||
|
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def get_example_images_status(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
result = await self._download_manager.get_status(request)
|
||||||
|
return web.json_response(result)
|
||||||
|
|
||||||
|
async def pause_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
try:
|
||||||
|
result = await self._download_manager.pause_download(request)
|
||||||
|
return web.json_response(result)
|
||||||
|
except DownloadNotRunningError as exc:
|
||||||
|
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||||
|
|
||||||
|
async def resume_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
try:
|
||||||
|
result = await self._download_manager.resume_download(request)
|
||||||
|
return web.json_response(result)
|
||||||
|
except DownloadNotRunningError as exc:
|
||||||
|
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||||
|
|
||||||
|
async def force_download_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
result = await self._download_manager.start_force_download(payload)
|
||||||
|
return web.json_response(result)
|
||||||
|
except DownloadInProgressError as exc:
|
||||||
|
response = {
|
||||||
|
'success': False,
|
||||||
|
'error': str(exc),
|
||||||
|
'status': exc.progress_snapshot,
|
||||||
|
}
|
||||||
|
return web.json_response(response, status=400)
|
||||||
|
except DownloadConfigurationError as exc:
|
||||||
|
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||||
|
except ExampleImagesDownloadError as exc:
|
||||||
|
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleImagesManagementHandler:
|
||||||
|
"""HTTP adapters for import/delete endpoints."""
|
||||||
|
|
||||||
|
def __init__(self, import_use_case: ImportExampleImagesUseCase, processor, cleanup_service) -> None:
|
||||||
|
self._import_use_case = import_use_case
|
||||||
|
self._processor = processor
|
||||||
|
self._cleanup_service = cleanup_service
|
||||||
|
|
||||||
|
async def import_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
try:
|
||||||
|
result = await self._import_use_case.execute(request)
|
||||||
|
return web.json_response(result)
|
||||||
|
except ImportExampleImagesValidationError as exc:
|
||||||
|
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||||
|
except ExampleImagesImportError as exc:
|
||||||
|
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def delete_example_image(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
return await self._processor.delete_custom_image(request)
|
||||||
|
|
||||||
|
async def cleanup_example_image_folders(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
result = await self._cleanup_service.cleanup_example_image_folders()
|
||||||
|
|
||||||
|
if result.get('success') or result.get('partial_success'):
|
||||||
|
return web.json_response(result, status=200)
|
||||||
|
|
||||||
|
error_code = result.get('error_code')
|
||||||
|
status = 400 if error_code in {'path_not_configured', 'path_not_found'} else 500
|
||||||
|
return web.json_response(result, status=status)
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleImagesFileHandler:
|
||||||
|
"""HTTP adapters for filesystem-centric endpoints."""
|
||||||
|
|
||||||
|
def __init__(self, file_manager) -> None:
|
||||||
|
self._file_manager = file_manager
|
||||||
|
|
||||||
|
async def open_example_images_folder(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
return await self._file_manager.open_folder(request)
|
||||||
|
|
||||||
|
async def get_example_image_files(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
return await self._file_manager.get_files(request)
|
||||||
|
|
||||||
|
async def has_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
return await self._file_manager.has_images(request)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ExampleImagesHandlerSet:
|
||||||
|
"""Aggregate of handlers exposed to the registrar."""
|
||||||
|
|
||||||
|
download: ExampleImagesDownloadHandler
|
||||||
|
management: ExampleImagesManagementHandler
|
||||||
|
files: ExampleImagesFileHandler
|
||||||
|
|
||||||
|
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||||
|
"""Flatten handler methods into the registrar mapping."""
|
||||||
|
|
||||||
|
return {
|
||||||
|
"download_example_images": self.download.download_example_images,
|
||||||
|
"get_example_images_status": self.download.get_example_images_status,
|
||||||
|
"pause_example_images": self.download.pause_example_images,
|
||||||
|
"resume_example_images": self.download.resume_example_images,
|
||||||
|
"force_download_example_images": self.download.force_download_example_images,
|
||||||
|
"import_example_images": self.management.import_example_images,
|
||||||
|
"delete_example_image": self.management.delete_example_image,
|
||||||
|
"cleanup_example_image_folders": self.management.cleanup_example_image_folders,
|
||||||
|
"open_example_images_folder": self.files.open_example_images_folder,
|
||||||
|
"get_example_image_files": self.files.get_example_image_files,
|
||||||
|
"has_example_images": self.files.has_example_images,
|
||||||
|
}
|
||||||
1020
py/routes/handlers/model_handlers.py
Normal file
1020
py/routes/handlers/model_handlers.py
Normal file
File diff suppressed because it is too large
Load Diff
725
py/routes/handlers/recipe_handlers.py
Normal file
725
py/routes/handlers/recipe_handlers.py
Normal file
@@ -0,0 +1,725 @@
|
|||||||
|
"""Dedicated handler objects for recipe-related routes."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from ...config import config
|
||||||
|
from ...services.server_i18n import server_i18n as default_server_i18n
|
||||||
|
from ...services.settings_manager import SettingsManager
|
||||||
|
from ...services.recipes import (
|
||||||
|
RecipeAnalysisService,
|
||||||
|
RecipeDownloadError,
|
||||||
|
RecipeNotFoundError,
|
||||||
|
RecipePersistenceService,
|
||||||
|
RecipeSharingService,
|
||||||
|
RecipeValidationError,
|
||||||
|
)
|
||||||
|
|
||||||
|
Logger = logging.Logger
|
||||||
|
EnsureDependenciesCallable = Callable[[], Awaitable[None]]
|
||||||
|
RecipeScannerGetter = Callable[[], Any]
|
||||||
|
CivitaiClientGetter = Callable[[], Any]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class RecipeHandlerSet:
    """Group of handlers providing recipe route implementations."""

    page_view: "RecipePageView"
    listing: "RecipeListingHandler"
    query: "RecipeQueryHandler"
    management: "RecipeManagementHandler"
    analysis: "RecipeAnalysisHandler"
    sharing: "RecipeSharingHandler"

    def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
        """Expose handler coroutines keyed by registrar handler names."""
        # Pair each registrar-visible handler name with the coroutine that
        # serves it; the keys must match what the route registrar looks up.
        route_pairs = (
            ("render_page", self.page_view.render_page),
            ("list_recipes", self.listing.list_recipes),
            ("get_recipe", self.listing.get_recipe),
            ("analyze_uploaded_image", self.analysis.analyze_uploaded_image),
            ("analyze_local_image", self.analysis.analyze_local_image),
            ("save_recipe", self.management.save_recipe),
            ("delete_recipe", self.management.delete_recipe),
            ("get_top_tags", self.query.get_top_tags),
            ("get_base_models", self.query.get_base_models),
            ("share_recipe", self.sharing.share_recipe),
            ("download_shared_recipe", self.sharing.download_shared_recipe),
            ("get_recipe_syntax", self.query.get_recipe_syntax),
            ("update_recipe", self.management.update_recipe),
            ("reconnect_lora", self.management.reconnect_lora),
            ("find_duplicates", self.query.find_duplicates),
            ("bulk_delete", self.management.bulk_delete),
            ("save_recipe_from_widget", self.management.save_recipe_from_widget),
            ("get_recipes_for_lora", self.query.get_recipes_for_lora),
            ("scan_recipes", self.query.scan_recipes),
        )
        return dict(route_pairs)
|
||||||
|
|
||||||
|
|
||||||
|
class RecipePageView:
    """Render the recipe shell page."""

    def __init__(
        self,
        *,
        ensure_dependencies_ready: EnsureDependenciesCallable,
        settings_service: SettingsManager,
        server_i18n=default_server_i18n,
        template_env,
        template_name: str,
        recipe_scanner_getter: RecipeScannerGetter,
        logger: Logger,
    ) -> None:
        # Awaited at the start of each request so lazily-created services exist.
        self._ensure_dependencies_ready = ensure_dependencies_ready
        self._settings = settings_service
        self._server_i18n = server_i18n
        # Jinja-style environment; only get_template(...).render(...) is used.
        self._template_env = template_env
        self._template_name = template_name
        # Called per request; may return None when the scanner is unavailable.
        self._recipe_scanner_getter = recipe_scanner_getter
        self._logger = logger

    async def render_page(self, request: web.Request) -> web.Response:
        """Render the recipes HTML page.

        Warms the recipe cache before rendering; if the cache load fails, the
        page is rendered with ``is_initializing=True`` instead of failing.
        Returns a 500 text response on any other error.
        """
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:  # pragma: no cover - defensive guard
                raise RuntimeError("Recipe scanner not available")

            # Switch server-side translations to the user's configured locale.
            user_language = self._settings.get("language", "en")
            self._server_i18n.set_locale(user_language)

            try:
                # Warm the cache; the page itself fetches recipes client-side,
                # so an empty list is passed to the template.
                await recipe_scanner.get_cached_data(force_refresh=False)
                rendered = self._template_env.get_template(self._template_name).render(
                    recipes=[],
                    is_initializing=False,
                    settings=self._settings,
                    request=request,
                    t=self._server_i18n.get_translation,
                )
            except Exception as cache_error:  # pragma: no cover - logging path
                self._logger.error("Error loading recipe cache data: %s", cache_error)
                # Fall back to the "initializing" view rather than erroring out.
                rendered = self._template_env.get_template(self._template_name).render(
                    is_initializing=True,
                    settings=self._settings,
                    request=request,
                    t=self._server_i18n.get_translation,
                )
            return web.Response(text=rendered, content_type="text/html")
        except Exception as exc:  # pragma: no cover - logging path
            self._logger.error("Error handling recipes request: %s", exc, exc_info=True)
            return web.Response(text="Error loading recipes page", status=500)
|
||||||
|
|
||||||
|
|
||||||
|
class RecipeListingHandler:
    """Provide listing and detail APIs for recipes."""

    def __init__(
        self,
        *,
        ensure_dependencies_ready: EnsureDependenciesCallable,
        recipe_scanner_getter: RecipeScannerGetter,
        logger: Logger,
    ) -> None:
        # Awaited before each request so lazily-created services exist.
        self._ensure_dependencies_ready = ensure_dependencies_ready
        # Called per request; may return None when the scanner is unavailable.
        self._recipe_scanner_getter = recipe_scanner_getter
        self._logger = logger

    async def list_recipes(self, request: web.Request) -> web.Response:
        """Return one page of recipes as JSON.

        Query parameters: ``page``, ``page_size``, ``sort_by``, ``search``,
        per-field ``search_*`` toggles (default true), comma-separated
        ``base_models`` and ``tags`` filters, and an optional ``lora_hash``.
        Each item is post-processed with a ``file_url`` plus defaulted
        ``loras``/``base_model`` fields.
        """
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            # Pagination and sorting; non-numeric values raise and become 500s.
            page = int(request.query.get("page", "1"))
            page_size = int(request.query.get("page_size", "20"))
            sort_by = request.query.get("sort_by", "date")
            search = request.query.get("search")

            # Which recipe fields the free-text search should match against.
            search_options = {
                "title": request.query.get("search_title", "true").lower() == "true",
                "tags": request.query.get("search_tags", "true").lower() == "true",
                "lora_name": request.query.get("search_lora_name", "true").lower() == "true",
                "lora_model": request.query.get("search_lora_model", "true").lower() == "true",
            }

            filters: Dict[str, list[str]] = {}
            base_models = request.query.get("base_models")
            if base_models:
                filters["base_model"] = base_models.split(",")

            tags = request.query.get("tags")
            if tags:
                filters["tags"] = tags.split(",")

            lora_hash = request.query.get("lora_hash")

            result = await recipe_scanner.get_paginated_data(
                page=page,
                page_size=page_size,
                sort_by=sort_by,
                search=search,
                filters=filters,
                search_options=search_options,
                lora_hash=lora_hash,
            )

            # Attach a browser-resolvable preview URL to every item; items
            # without a file path get the shared no-preview placeholder.
            for item in result.get("items", []):
                file_path = item.get("file_path")
                if file_path:
                    item["file_url"] = self.format_recipe_file_url(file_path)
                else:
                    item.setdefault("file_url", "/loras_static/images/no-preview.png")
                item.setdefault("loras", [])
                item.setdefault("base_model", "")

            return web.json_response(result)
        except Exception as exc:
            self._logger.error("Error retrieving recipes: %s", exc, exc_info=True)
            return web.json_response({"error": str(exc)}, status=500)

    async def get_recipe(self, request: web.Request) -> web.Response:
        """Return a single recipe (looked up by ``recipe_id``) as JSON, or 404."""
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            recipe_id = request.match_info["recipe_id"]
            recipe = await recipe_scanner.get_recipe_by_id(recipe_id)

            if not recipe:
                return web.json_response({"error": "Recipe not found"}, status=404)
            return web.json_response(recipe)
        except Exception as exc:
            self._logger.error("Error retrieving recipe details: %s", exc, exc_info=True)
            return web.json_response({"error": str(exc)}, status=500)

    def format_recipe_file_url(self, file_path: str) -> str:
        """Map an on-disk recipe image path to its static preview URL.

        Paths under ``<first lora root>/recipes`` keep their relative layout;
        anything else is served flat by basename. Falls back to the shared
        no-preview placeholder on error.
        """
        try:
            # Normalize separators so the prefix test works on Windows too.
            recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, "/")
            normalized_path = file_path.replace(os.sep, "/")
            if normalized_path.startswith(recipes_dir):
                relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, "/")
                return f"/loras_static/root1/preview/{relative_path}"

            file_name = os.path.basename(file_path)
            return f"/loras_static/root1/preview/recipes/{file_name}"
        except Exception as exc:  # pragma: no cover - logging path
            self._logger.error("Error formatting recipe file URL: %s", exc, exc_info=True)
            return "/loras_static/images/no-preview.png"
|
||||||
|
|
||||||
|
|
||||||
|
class RecipeQueryHandler:
    """Provide read-only insights on recipe data."""

    def __init__(
        self,
        *,
        ensure_dependencies_ready: EnsureDependenciesCallable,
        recipe_scanner_getter: RecipeScannerGetter,
        format_recipe_file_url: Callable[[str], str],
        logger: Logger,
    ) -> None:
        # Awaited before each request so lazily-created services exist.
        self._ensure_dependencies_ready = ensure_dependencies_ready
        # Called per request; may return None when the scanner is unavailable.
        self._recipe_scanner_getter = recipe_scanner_getter
        # Borrowed from the listing handler to build preview URLs consistently.
        self._format_recipe_file_url = format_recipe_file_url
        self._logger = logger

    async def get_top_tags(self, request: web.Request) -> web.Response:
        """Return the most frequent recipe tags, capped by ``?limit`` (default 20)."""
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            limit = int(request.query.get("limit", "20"))
            cache = await recipe_scanner.get_cached_data()

            # Count tag occurrences across all cached recipes.
            tag_counts: Dict[str, int] = {}
            for recipe in getattr(cache, "raw_data", []):
                for tag in recipe.get("tags", []) or []:
                    tag_counts[tag] = tag_counts.get(tag, 0) + 1

            sorted_tags = [{"tag": tag, "count": count} for tag, count in tag_counts.items()]
            sorted_tags.sort(key=lambda entry: entry["count"], reverse=True)
            return web.json_response({"success": True, "tags": sorted_tags[:limit]})
        except Exception as exc:
            self._logger.error("Error retrieving top tags: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def get_base_models(self, request: web.Request) -> web.Response:
        """Return every base model seen in the cache with its usage count."""
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            cache = await recipe_scanner.get_cached_data()

            base_model_counts: Dict[str, int] = {}
            for recipe in getattr(cache, "raw_data", []):
                base_model = recipe.get("base_model")
                if base_model:
                    base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1

            sorted_models = [{"name": model, "count": count} for model, count in base_model_counts.items()]
            sorted_models.sort(key=lambda entry: entry["count"], reverse=True)
            return web.json_response({"success": True, "base_models": sorted_models})
        except Exception as exc:
            self._logger.error("Error retrieving base models: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
        """Return all recipes that reference the LoRA given by ``?hash``."""
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            lora_hash = request.query.get("hash")
            if not lora_hash:
                return web.json_response({"success": False, "error": "Lora hash is required"}, status=400)

            matching_recipes = await recipe_scanner.get_recipes_for_lora(lora_hash)
            return web.json_response({"success": True, "recipes": matching_recipes})
        except Exception as exc:
            self._logger.error("Error getting recipes for Lora: %s", exc)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def scan_recipes(self, request: web.Request) -> web.Response:
        """Force a full rebuild of the recipe cache."""
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            self._logger.info("Manually triggering recipe cache rebuild")
            await recipe_scanner.get_cached_data(force_refresh=True)
            return web.json_response({"success": True, "message": "Recipe cache refreshed successfully"})
        except Exception as exc:
            self._logger.error("Error refreshing recipe cache: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def find_duplicates(self, request: web.Request) -> web.Response:
        """Group recipes sharing a fingerprint and report groups of two or more.

        Groups are sorted largest-first; recipes within a group are sorted by
        most recently modified.
        """
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            duplicate_groups = await recipe_scanner.find_all_duplicate_recipes()
            response_data = []

            for fingerprint, recipe_ids in duplicate_groups.items():
                # A single-member group is not a duplicate.
                if len(recipe_ids) <= 1:
                    continue

                recipes = []
                for recipe_id in recipe_ids:
                    recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
                    if recipe:
                        recipes.append(
                            {
                                "id": recipe.get("id"),
                                "title": recipe.get("title"),
                                "file_url": recipe.get("file_url")
                                or self._format_recipe_file_url(recipe.get("file_path", "")),
                                "modified": recipe.get("modified"),
                                "created_date": recipe.get("created_date"),
                                "lora_count": len(recipe.get("loras", [])),
                            }
                        )

                # Some ids may no longer resolve; only report groups that still
                # contain at least two live recipes.
                if len(recipes) >= 2:
                    recipes.sort(key=lambda entry: entry.get("modified", 0), reverse=True)
                    response_data.append(
                        {
                            "fingerprint": fingerprint,
                            "count": len(recipes),
                            "recipes": recipes,
                        }
                    )

            response_data.sort(key=lambda entry: entry["count"], reverse=True)
            return web.json_response({"success": True, "duplicate_groups": response_data})
        except Exception as exc:
            self._logger.error("Error finding duplicate recipes: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def get_recipe_syntax(self, request: web.Request) -> web.Response:
        """Return the space-joined LoRA syntax string for a recipe, or 404/400."""
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            recipe_id = request.match_info["recipe_id"]
            try:
                syntax_parts = await recipe_scanner.get_recipe_syntax_tokens(recipe_id)
            except RecipeNotFoundError:
                return web.json_response({"error": "Recipe not found"}, status=404)

            if not syntax_parts:
                return web.json_response({"error": "No LoRAs found in this recipe"}, status=400)

            return web.json_response({"success": True, "syntax": " ".join(syntax_parts)})
        except Exception as exc:
            self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True)
            return web.json_response({"error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
|
||||||
|
class RecipeManagementHandler:
    """Handle create/update/delete style recipe operations."""

    def __init__(
        self,
        *,
        ensure_dependencies_ready: EnsureDependenciesCallable,
        recipe_scanner_getter: RecipeScannerGetter,
        logger: Logger,
        persistence_service: RecipePersistenceService,
        analysis_service: RecipeAnalysisService,
    ) -> None:
        # Awaited before each request so lazily-created services exist.
        self._ensure_dependencies_ready = ensure_dependencies_ready
        # Called per request; may return None when the scanner is unavailable.
        self._recipe_scanner_getter = recipe_scanner_getter
        self._logger = logger
        # Performs the actual persistence work; handlers only translate
        # HTTP payloads/exceptions to and from it.
        self._persistence_service = persistence_service
        self._analysis_service = analysis_service

    async def save_recipe(self, request: web.Request) -> web.Response:
        """Persist a new recipe from a multipart form upload.

        Expected multipart fields: ``image`` or ``image_base64``, ``name``,
        JSON-encoded ``tags`` and ``metadata``. Validation errors map to 400.
        """
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            reader = await request.multipart()
            payload = await self._parse_save_payload(reader)

            result = await self._persistence_service.save_recipe(
                recipe_scanner=recipe_scanner,
                image_bytes=payload["image_bytes"],
                image_base64=payload["image_base64"],
                name=payload["name"],
                tags=payload["tags"],
                metadata=payload["metadata"],
            )
            return web.json_response(result.payload, status=result.status)
        except RecipeValidationError as exc:
            return web.json_response({"error": str(exc)}, status=400)
        except Exception as exc:
            self._logger.error("Error saving recipe: %s", exc, exc_info=True)
            return web.json_response({"error": str(exc)}, status=500)

    async def delete_recipe(self, request: web.Request) -> web.Response:
        """Delete the recipe identified by the ``recipe_id`` path segment."""
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            recipe_id = request.match_info["recipe_id"]
            result = await self._persistence_service.delete_recipe(
                recipe_scanner=recipe_scanner, recipe_id=recipe_id
            )
            return web.json_response(result.payload, status=result.status)
        except RecipeNotFoundError as exc:
            return web.json_response({"error": str(exc)}, status=404)
        except Exception as exc:
            self._logger.error("Error deleting recipe: %s", exc, exc_info=True)
            return web.json_response({"error": str(exc)}, status=500)

    async def update_recipe(self, request: web.Request) -> web.Response:
        """Apply a JSON body of field updates to an existing recipe."""
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            recipe_id = request.match_info["recipe_id"]
            data = await request.json()
            result = await self._persistence_service.update_recipe(
                recipe_scanner=recipe_scanner, recipe_id=recipe_id, updates=data
            )
            return web.json_response(result.payload, status=result.status)
        except RecipeValidationError as exc:
            return web.json_response({"error": str(exc)}, status=400)
        except RecipeNotFoundError as exc:
            return web.json_response({"error": str(exc)}, status=404)
        except Exception as exc:
            self._logger.error("Error updating recipe: %s", exc, exc_info=True)
            return web.json_response({"error": str(exc)}, status=500)

    async def reconnect_lora(self, request: web.Request) -> web.Response:
        """Re-point one LoRA entry of a recipe at a new target model.

        JSON body must contain ``recipe_id``, ``lora_index`` and
        ``target_name``; a missing field is rejected with 400.
        """
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            data = await request.json()
            # Validate before delegating so the error message names the field.
            for field in ("recipe_id", "lora_index", "target_name"):
                if field not in data:
                    raise RecipeValidationError(f"Missing required field: {field}")

            result = await self._persistence_service.reconnect_lora(
                recipe_scanner=recipe_scanner,
                recipe_id=data["recipe_id"],
                lora_index=int(data["lora_index"]),
                target_name=data["target_name"],
            )
            return web.json_response(result.payload, status=result.status)
        except RecipeValidationError as exc:
            return web.json_response({"error": str(exc)}, status=400)
        except RecipeNotFoundError as exc:
            return web.json_response({"error": str(exc)}, status=404)
        except Exception as exc:
            self._logger.error("Error reconnecting LoRA: %s", exc, exc_info=True)
            return web.json_response({"error": str(exc)}, status=500)

    async def bulk_delete(self, request: web.Request) -> web.Response:
        """Delete every recipe listed in the JSON body's ``recipe_ids``."""
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            data = await request.json()
            recipe_ids = data.get("recipe_ids", [])
            result = await self._persistence_service.bulk_delete(
                recipe_scanner=recipe_scanner, recipe_ids=recipe_ids
            )
            return web.json_response(result.payload, status=result.status)
        except RecipeValidationError as exc:
            return web.json_response({"success": False, "error": str(exc)}, status=400)
        except RecipeNotFoundError as exc:
            return web.json_response({"success": False, "error": str(exc)}, status=404)
        except Exception as exc:
            self._logger.error("Error performing bulk delete: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def save_recipe_from_widget(self, request: web.Request) -> web.Response:
        """Create a recipe from widget-supplied metadata.

        The analysis service extracts ``metadata`` and ``image_bytes``; both
        must be present or the request is rejected with 400.
        """
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            analysis = await self._analysis_service.analyze_widget_metadata(
                recipe_scanner=recipe_scanner
            )
            metadata = analysis.payload.get("metadata")
            image_bytes = analysis.payload.get("image_bytes")
            if not metadata or image_bytes is None:
                raise RecipeValidationError("Unable to extract metadata from widget")

            result = await self._persistence_service.save_recipe_from_widget(
                recipe_scanner=recipe_scanner,
                metadata=metadata,
                image_bytes=image_bytes,
            )
            return web.json_response(result.payload, status=result.status)
        except RecipeValidationError as exc:
            return web.json_response({"error": str(exc)}, status=400)
        except Exception as exc:
            self._logger.error("Error saving recipe from widget: %s", exc, exc_info=True)
            return web.json_response({"error": str(exc)}, status=500)

    async def _parse_save_payload(self, reader) -> dict[str, Any]:
        """Drain an aiohttp multipart reader into a save_recipe argument dict.

        Unknown fields are ignored; malformed JSON in ``tags``/``metadata``
        degrades to ``[]``/``{}`` rather than failing the upload.
        """
        image_bytes: Optional[bytes] = None
        image_base64: Optional[str] = None
        name: Optional[str] = None
        tags: list[str] = []
        metadata: Optional[Dict[str, Any]] = None

        while True:
            field = await reader.next()
            if field is None:
                break
            if field.name == "image":
                # Stream the image part chunk by chunk to avoid one big read.
                image_chunks = bytearray()
                while True:
                    chunk = await field.read_chunk()
                    if not chunk:
                        break
                    image_chunks.extend(chunk)
                image_bytes = bytes(image_chunks)
            elif field.name == "image_base64":
                image_base64 = await field.text()
            elif field.name == "name":
                name = await field.text()
            elif field.name == "tags":
                tags_text = await field.text()
                try:
                    parsed_tags = json.loads(tags_text)
                    # Only accept a JSON list; anything else becomes no tags.
                    tags = parsed_tags if isinstance(parsed_tags, list) else []
                except Exception:
                    tags = []
            elif field.name == "metadata":
                metadata_text = await field.text()
                try:
                    metadata = json.loads(metadata_text)
                except Exception:
                    metadata = {}

        return {
            "image_bytes": image_bytes,
            "image_base64": image_base64,
            "name": name,
            "tags": tags,
            "metadata": metadata,
        }
|
||||||
|
|
||||||
|
|
||||||
|
class RecipeAnalysisHandler:
    """Analyze images to extract recipe metadata."""

    def __init__(
        self,
        *,
        ensure_dependencies_ready: EnsureDependenciesCallable,
        recipe_scanner_getter: RecipeScannerGetter,
        civitai_client_getter: CivitaiClientGetter,
        logger: Logger,
        analysis_service: RecipeAnalysisService,
    ) -> None:
        # Awaited before each request so lazily-created services exist.
        self._ensure_dependencies_ready = ensure_dependencies_ready
        # Both getters are called per request and may return None.
        self._recipe_scanner_getter = recipe_scanner_getter
        self._civitai_client_getter = civitai_client_getter
        self._logger = logger
        # Performs the actual image/metadata analysis.
        self._analysis_service = analysis_service

    async def analyze_uploaded_image(self, request: web.Request) -> web.Response:
        """Analyze a recipe image supplied either as an upload or a URL.

        Dispatches on Content-Type: ``multipart/form-data`` must contain an
        ``image`` part (read in chunks); ``application/json`` must contain a
        ``url`` to a remote image. Any other content type is a 400. Error
        responses always include an empty ``loras`` list for the client.
        """
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            civitai_client = self._civitai_client_getter()
            if recipe_scanner is None or civitai_client is None:
                raise RuntimeError("Required services unavailable")

            content_type = request.headers.get("Content-Type", "")
            if "multipart/form-data" in content_type:
                reader = await request.multipart()
                # Only the first part is examined; it must be named "image".
                field = await reader.next()
                if field is None or field.name != "image":
                    raise RecipeValidationError("No image field found")
                image_chunks = bytearray()
                while True:
                    chunk = await field.read_chunk()
                    if not chunk:
                        break
                    image_chunks.extend(chunk)
                result = await self._analysis_service.analyze_uploaded_image(
                    image_bytes=bytes(image_chunks),
                    recipe_scanner=recipe_scanner,
                )
                return web.json_response(result.payload, status=result.status)

            if "application/json" in content_type:
                data = await request.json()
                result = await self._analysis_service.analyze_remote_image(
                    url=data.get("url"),
                    recipe_scanner=recipe_scanner,
                    civitai_client=civitai_client,
                )
                return web.json_response(result.payload, status=result.status)

            raise RecipeValidationError("Unsupported content type")
        except RecipeValidationError as exc:
            return web.json_response({"error": str(exc), "loras": []}, status=400)
        except RecipeDownloadError as exc:
            return web.json_response({"error": str(exc), "loras": []}, status=400)
        except RecipeNotFoundError as exc:
            return web.json_response({"error": str(exc), "loras": []}, status=404)
        except Exception as exc:
            self._logger.error("Error analyzing recipe image: %s", exc, exc_info=True)
            return web.json_response({"error": str(exc), "loras": []}, status=500)

    async def analyze_local_image(self, request: web.Request) -> web.Response:
        """Analyze an image already on disk; JSON body supplies its ``path``."""
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            data = await request.json()
            result = await self._analysis_service.analyze_local_image(
                file_path=data.get("path"),
                recipe_scanner=recipe_scanner,
            )
            return web.json_response(result.payload, status=result.status)
        except RecipeValidationError as exc:
            return web.json_response({"error": str(exc), "loras": []}, status=400)
        except RecipeNotFoundError as exc:
            return web.json_response({"error": str(exc), "loras": []}, status=404)
        except Exception as exc:
            self._logger.error("Error analyzing local image: %s", exc, exc_info=True)
            return web.json_response({"error": str(exc), "loras": []}, status=500)
|
||||||
|
|
||||||
|
|
||||||
|
class RecipeSharingHandler:
    """Serve endpoints related to recipe sharing."""

    def __init__(
        self,
        *,
        ensure_dependencies_ready: EnsureDependenciesCallable,
        recipe_scanner_getter: RecipeScannerGetter,
        logger: Logger,
        sharing_service: RecipeSharingService,
    ) -> None:
        # Awaited before each request so lazily-created services exist.
        self._ensure_dependencies_ready = ensure_dependencies_ready
        # Called per request; may return None when the scanner is unavailable.
        self._recipe_scanner_getter = recipe_scanner_getter
        self._logger = logger
        # Builds share payloads and download descriptors for recipes.
        self._sharing_service = sharing_service

    async def share_recipe(self, request: web.Request) -> web.Response:
        """Produce share information for the recipe named by ``recipe_id``."""
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            recipe_id = request.match_info["recipe_id"]
            result = await self._sharing_service.share_recipe(
                recipe_scanner=recipe_scanner, recipe_id=recipe_id
            )
            return web.json_response(result.payload, status=result.status)
        except RecipeNotFoundError as exc:
            return web.json_response({"error": str(exc)}, status=404)
        except Exception as exc:
            self._logger.error("Error sharing recipe: %s", exc, exc_info=True)
            return web.json_response({"error": str(exc)}, status=500)

    async def download_shared_recipe(self, request: web.Request) -> web.StreamResponse:
        """Stream a shared recipe file back as an attachment download."""
        try:
            await self._ensure_dependencies_ready()
            recipe_scanner = self._recipe_scanner_getter()
            if recipe_scanner is None:
                raise RuntimeError("Recipe scanner unavailable")

            recipe_id = request.match_info["recipe_id"]
            # The sharing service resolves the on-disk path and the filename
            # the browser should save the download as.
            download_info = await self._sharing_service.prepare_download(
                recipe_scanner=recipe_scanner, recipe_id=recipe_id
            )
            return web.FileResponse(
                download_info.file_path,
                headers={
                    "Content-Disposition": f'attachment; filename="{download_info.download_filename}"'
                },
            )
        except RecipeNotFoundError as exc:
            return web.json_response({"error": str(exc)}, status=404)
        except Exception as exc:
            self._logger.error("Error downloading shared recipe: %s", exc, exc_info=True)
            return web.json_response({"error": str(exc)}, status=500)
|
||||||
@@ -5,9 +5,9 @@ from typing import Dict
|
|||||||
from server import PromptServer # type: ignore
|
from server import PromptServer # type: ignore
|
||||||
|
|
||||||
from .base_model_routes import BaseModelRoutes
|
from .base_model_routes import BaseModelRoutes
|
||||||
|
from .model_route_registrar import ModelRouteRegistrar
|
||||||
from ..services.lora_service import LoraService
|
from ..services.lora_service import LoraService
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..services.metadata_service import get_default_metadata_provider
|
|
||||||
from ..utils.utils import get_lora_info
|
from ..utils.utils import get_lora_info
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -17,8 +17,7 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize LoRA routes with LoRA service"""
|
"""Initialize LoRA routes with LoRA service"""
|
||||||
# Service will be initialized later via setup_routes
|
super().__init__()
|
||||||
self.service = None
|
|
||||||
self.template_name = "loras.html"
|
self.template_name = "loras.html"
|
||||||
|
|
||||||
async def initialize_services(self):
|
async def initialize_services(self):
|
||||||
@@ -26,26 +25,26 @@ class LoraRoutes(BaseModelRoutes):
|
|||||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
self.service = LoraService(lora_scanner)
|
self.service = LoraService(lora_scanner)
|
||||||
|
|
||||||
# Initialize parent with the service
|
# Attach service dependencies
|
||||||
super().__init__(self.service)
|
self.attach_service(self.service)
|
||||||
|
|
||||||
def setup_routes(self, app: web.Application):
|
def setup_routes(self, app: web.Application):
|
||||||
"""Setup LoRA routes"""
|
"""Setup LoRA routes"""
|
||||||
# Schedule service initialization on app startup
|
# Schedule service initialization on app startup
|
||||||
app.on_startup.append(lambda _: self.initialize_services())
|
app.on_startup.append(lambda _: self.initialize_services())
|
||||||
|
|
||||||
# Setup common routes with 'loras' prefix (includes page route)
|
# Setup common routes with 'loras' prefix (includes page route)
|
||||||
super().setup_routes(app, 'loras')
|
super().setup_routes(app, 'loras')
|
||||||
|
|
||||||
def setup_specific_routes(self, app: web.Application, prefix: str):
|
def setup_specific_routes(self, registrar: ModelRouteRegistrar, prefix: str):
|
||||||
"""Setup LoRA-specific routes"""
|
"""Setup LoRA-specific routes"""
|
||||||
# LoRA-specific query routes
|
# LoRA-specific query routes
|
||||||
app.router.add_get(f'/api/lm/{prefix}/letter-counts', self.get_letter_counts)
|
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/letter-counts', prefix, self.get_letter_counts)
|
||||||
app.router.add_get(f'/api/lm/{prefix}/get-trigger-words', self.get_lora_trigger_words)
|
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/get-trigger-words', prefix, self.get_lora_trigger_words)
|
||||||
app.router.add_get(f'/api/lm/{prefix}/usage-tips-by-path', self.get_lora_usage_tips_by_path)
|
registrar.add_prefixed_route('GET', '/api/lm/{prefix}/usage-tips-by-path', prefix, self.get_lora_usage_tips_by_path)
|
||||||
|
|
||||||
# ComfyUI integration
|
# ComfyUI integration
|
||||||
app.router.add_post(f'/api/lm/{prefix}/get_trigger_words', self.get_trigger_words)
|
registrar.add_prefixed_route('POST', '/api/lm/{prefix}/get_trigger_words', prefix, self.get_trigger_words)
|
||||||
|
|
||||||
def _parse_specific_params(self, request: web.Request) -> Dict:
|
def _parse_specific_params(self, request: web.Request) -> Dict:
|
||||||
"""Parse LoRA-specific parameters"""
|
"""Parse LoRA-specific parameters"""
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from ..services.websocket_manager import ws_manager
|
|||||||
from ..services.downloader import get_downloader
|
from ..services.downloader import get_downloader
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
standalone_mode = 'nodes' not in sys.modules
|
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||||
|
|
||||||
# Node registry for tracking active workflow nodes
|
# Node registry for tracking active workflow nodes
|
||||||
class NodeRegistry:
|
class NodeRegistry:
|
||||||
@@ -179,7 +179,7 @@ class MiscRoutes:
|
|||||||
# Define keys that should be synced from backend to frontend
|
# Define keys that should be synced from backend to frontend
|
||||||
sync_keys = [
|
sync_keys = [
|
||||||
'civitai_api_key',
|
'civitai_api_key',
|
||||||
'default_lora_root',
|
'default_lora_root',
|
||||||
'default_checkpoint_root',
|
'default_checkpoint_root',
|
||||||
'default_embedding_root',
|
'default_embedding_root',
|
||||||
'base_model_path_mappings',
|
'base_model_path_mappings',
|
||||||
@@ -193,8 +193,15 @@ class MiscRoutes:
|
|||||||
'proxy_username',
|
'proxy_username',
|
||||||
'proxy_password',
|
'proxy_password',
|
||||||
'example_images_path',
|
'example_images_path',
|
||||||
'optimizeExampleImages',
|
'optimize_example_images',
|
||||||
'autoDownloadExampleImages'
|
'auto_download_example_images',
|
||||||
|
'blur_mature_content',
|
||||||
|
'autoplay_on_hover',
|
||||||
|
'display_density',
|
||||||
|
'card_info_display',
|
||||||
|
'include_trigger_words',
|
||||||
|
'show_only_sfw',
|
||||||
|
'compact_mode'
|
||||||
]
|
]
|
||||||
|
|
||||||
# Build response with only the keys that should be synced
|
# Build response with only the keys that should be synced
|
||||||
|
|||||||
99
py/routes/model_route_registrar.py
Normal file
99
py/routes/model_route_registrar.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""Route registrar for model endpoints."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Iterable, Mapping
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RouteDefinition:
|
||||||
|
"""Declarative definition for a HTTP route."""
|
||||||
|
|
||||||
|
method: str
|
||||||
|
path_template: str
|
||||||
|
handler_name: str
|
||||||
|
|
||||||
|
def build_path(self, prefix: str) -> str:
|
||||||
|
return self.path_template.replace("{prefix}", prefix)
|
||||||
|
|
||||||
|
|
||||||
|
COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/list", "get_models"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/delete", "delete_model"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/exclude", "exclude_model"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/fetch-civitai", "fetch_civitai"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/fetch-all-civitai", "fetch_all_civitai"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/relink-civitai", "relink_civitai"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/replace-preview", "replace_preview"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/save-metadata", "save_metadata"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/add-tags", "add_tags"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/rename", "rename_model"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/bulk-delete", "bulk_delete_models"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/verify-duplicates", "verify_duplicates"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/move_model", "move_model"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/move_models_bulk", "move_models_bulk"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize", "auto_organize_models"),
|
||||||
|
RouteDefinition("POST", "/api/lm/{prefix}/auto-organize", "auto_organize_models"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/scan", "scan_models"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/folder-tree", "get_folder_tree"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/unified-folder-tree", "get_unified_folder_tree"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/find-duplicates", "find_duplicate_models"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/find-filename-conflicts", "find_filename_conflicts"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/get-notes", "get_model_notes"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/preview-url", "get_model_preview_url"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/civitai-url", "get_model_civitai_url"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/metadata", "get_model_metadata"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/model-description", "get_model_description"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/relative-paths", "get_relative_paths"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
|
||||||
|
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
|
||||||
|
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
|
||||||
|
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
|
||||||
|
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
|
||||||
|
RouteDefinition("GET", "/api/lm/download-progress/{download_id}", "get_download_progress"),
|
||||||
|
RouteDefinition("GET", "/{prefix}", "handle_models_page"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRouteRegistrar:
|
||||||
|
"""Bind declarative definitions to an aiohttp router."""
|
||||||
|
|
||||||
|
_METHOD_MAP = {
|
||||||
|
"GET": "add_get",
|
||||||
|
"POST": "add_post",
|
||||||
|
"PUT": "add_put",
|
||||||
|
"DELETE": "add_delete",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, app: web.Application) -> None:
|
||||||
|
self._app = app
|
||||||
|
|
||||||
|
def register_common_routes(
|
||||||
|
self,
|
||||||
|
prefix: str,
|
||||||
|
handler_lookup: Mapping[str, Callable[[web.Request], object]],
|
||||||
|
*,
|
||||||
|
definitions: Iterable[RouteDefinition] = COMMON_ROUTE_DEFINITIONS,
|
||||||
|
) -> None:
|
||||||
|
for definition in definitions:
|
||||||
|
self._bind_route(definition.method, definition.build_path(prefix), handler_lookup[definition.handler_name])
|
||||||
|
|
||||||
|
def add_route(self, method: str, path: str, handler: Callable) -> None:
|
||||||
|
self._bind_route(method, path, handler)
|
||||||
|
|
||||||
|
def add_prefixed_route(self, method: str, path_template: str, prefix: str, handler: Callable) -> None:
|
||||||
|
self._bind_route(method, path_template.replace("{prefix}", prefix), handler)
|
||||||
|
|
||||||
|
def _bind_route(self, method: str, path: str, handler: Callable) -> None:
|
||||||
|
add_method_name = self._METHOD_MAP[method.upper()]
|
||||||
|
add_method = getattr(self._app.router, add_method_name)
|
||||||
|
add_method(path, handler)
|
||||||
|
|
||||||
64
py/routes/recipe_route_registrar.py
Normal file
64
py/routes/recipe_route_registrar.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""Route registrar for recipe endpoints."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Callable, Mapping
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class RouteDefinition:
|
||||||
|
"""Declarative definition for a recipe HTTP route."""
|
||||||
|
|
||||||
|
method: str
|
||||||
|
path: str
|
||||||
|
handler_name: str
|
||||||
|
|
||||||
|
|
||||||
|
ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||||
|
RouteDefinition("GET", "/loras/recipes", "render_page"),
|
||||||
|
RouteDefinition("GET", "/api/lm/recipes", "list_recipes"),
|
||||||
|
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"),
|
||||||
|
RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"),
|
||||||
|
RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"),
|
||||||
|
RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"),
|
||||||
|
RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"),
|
||||||
|
RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"),
|
||||||
|
RouteDefinition("GET", "/api/lm/recipes/base-models", "get_base_models"),
|
||||||
|
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share", "share_recipe"),
|
||||||
|
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"),
|
||||||
|
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"),
|
||||||
|
RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"),
|
||||||
|
RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"),
|
||||||
|
RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"),
|
||||||
|
RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"),
|
||||||
|
RouteDefinition("POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"),
|
||||||
|
RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"),
|
||||||
|
RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RecipeRouteRegistrar:
|
||||||
|
"""Bind declarative recipe definitions to an aiohttp router."""
|
||||||
|
|
||||||
|
_METHOD_MAP = {
|
||||||
|
"GET": "add_get",
|
||||||
|
"POST": "add_post",
|
||||||
|
"PUT": "add_put",
|
||||||
|
"DELETE": "add_delete",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, app: web.Application) -> None:
|
||||||
|
self._app = app
|
||||||
|
|
||||||
|
def register_routes(self, handler_lookup: Mapping[str, Callable[[web.Request], object]]) -> None:
|
||||||
|
for definition in ROUTE_DEFINITIONS:
|
||||||
|
handler = handler_lookup[definition.handler_name]
|
||||||
|
self._bind_route(definition.method, definition.path, handler)
|
||||||
|
|
||||||
|
def _bind_route(self, method: str, path: str, handler: Callable) -> None:
|
||||||
|
add_method_name = self._METHOD_MAP[method.upper()]
|
||||||
|
add_method = getattr(self._app.router, add_method_name)
|
||||||
|
add_method(path, handler)
|
||||||
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -4,99 +4,88 @@ import logging
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from ..utils.models import BaseModelMetadata
|
from ..utils.models import BaseModelMetadata
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
|
||||||
from ..utils.constants import NSFW_LEVELS
|
from .settings_manager import settings as default_settings
|
||||||
from .settings_manager import settings
|
|
||||||
from ..utils.utils import fuzzy_match
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class BaseModelService(ABC):
|
class BaseModelService(ABC):
|
||||||
"""Base service class for all model types"""
|
"""Base service class for all model types"""
|
||||||
|
|
||||||
def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]):
|
def __init__(
|
||||||
"""Initialize the service
|
self,
|
||||||
|
model_type: str,
|
||||||
|
scanner,
|
||||||
|
metadata_class: Type[BaseModelMetadata],
|
||||||
|
*,
|
||||||
|
cache_repository: Optional[ModelCacheRepository] = None,
|
||||||
|
filter_set: Optional[ModelFilterSet] = None,
|
||||||
|
search_strategy: Optional[SearchStrategy] = None,
|
||||||
|
settings_provider: Optional[SettingsProvider] = None,
|
||||||
|
):
|
||||||
|
"""Initialize the service.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_type: Type of model (lora, checkpoint, etc.)
|
model_type: Type of model (lora, checkpoint, etc.).
|
||||||
scanner: Model scanner instance
|
scanner: Model scanner instance.
|
||||||
metadata_class: Metadata class for this model type
|
metadata_class: Metadata class for this model type.
|
||||||
|
cache_repository: Custom repository for cache access (primarily for tests).
|
||||||
|
filter_set: Filter component controlling folder/tag/favorites logic.
|
||||||
|
search_strategy: Search component for fuzzy/text matching.
|
||||||
|
settings_provider: Settings object; defaults to the global settings manager.
|
||||||
"""
|
"""
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.scanner = scanner
|
self.scanner = scanner
|
||||||
self.metadata_class = metadata_class
|
self.metadata_class = metadata_class
|
||||||
|
self.settings = settings_provider or default_settings
|
||||||
|
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
|
||||||
|
self.filter_set = filter_set or ModelFilterSet(self.settings)
|
||||||
|
self.search_strategy = search_strategy or SearchStrategy()
|
||||||
|
|
||||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
async def get_paginated_data(
|
||||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
self,
|
||||||
base_models: list = None, tags: list = None,
|
page: int,
|
||||||
search_options: dict = None, hash_filters: dict = None,
|
page_size: int,
|
||||||
favorites_only: bool = False, **kwargs) -> Dict:
|
sort_by: str = 'name',
|
||||||
"""Get paginated and filtered model data
|
folder: str = None,
|
||||||
|
search: str = None,
|
||||||
Args:
|
fuzzy_search: bool = False,
|
||||||
page: Page number (1-based)
|
base_models: list = None,
|
||||||
page_size: Number of items per page
|
tags: list = None,
|
||||||
sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc'
|
search_options: dict = None,
|
||||||
folder: Folder filter
|
hash_filters: dict = None,
|
||||||
search: Search term
|
favorites_only: bool = False,
|
||||||
fuzzy_search: Whether to use fuzzy search
|
**kwargs,
|
||||||
base_models: List of base models to filter by
|
) -> Dict:
|
||||||
tags: List of tags to filter by
|
"""Get paginated and filtered model data"""
|
||||||
search_options: Search options dict
|
sort_params = self.cache_repository.parse_sort(sort_by)
|
||||||
hash_filters: Hash filtering options
|
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
|
||||||
favorites_only: Filter for favorites only
|
|
||||||
**kwargs: Additional model-specific filters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict containing paginated results
|
|
||||||
"""
|
|
||||||
cache = await self.scanner.get_cached_data()
|
|
||||||
|
|
||||||
# Parse sort_by into sort_key and order
|
|
||||||
if ':' in sort_by:
|
|
||||||
sort_key, order = sort_by.split(':', 1)
|
|
||||||
sort_key = sort_key.strip()
|
|
||||||
order = order.strip().lower()
|
|
||||||
if order not in ('asc', 'desc'):
|
|
||||||
order = 'asc'
|
|
||||||
else:
|
|
||||||
sort_key = sort_by.strip()
|
|
||||||
order = 'asc'
|
|
||||||
|
|
||||||
# Get default search options if not provided
|
|
||||||
if search_options is None:
|
|
||||||
search_options = {
|
|
||||||
'filename': True,
|
|
||||||
'modelname': True,
|
|
||||||
'tags': False,
|
|
||||||
'recursive': True,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Get the base data set using new sort logic
|
|
||||||
filtered_data = await cache.get_sorted_data(sort_key, order)
|
|
||||||
|
|
||||||
# Apply hash filtering if provided (highest priority)
|
|
||||||
if hash_filters:
|
if hash_filters:
|
||||||
filtered_data = await self._apply_hash_filters(filtered_data, hash_filters)
|
filtered_data = await self._apply_hash_filters(sorted_data, hash_filters)
|
||||||
|
|
||||||
# Jump to pagination for hash filters
|
|
||||||
return self._paginate(filtered_data, page, page_size)
|
return self._paginate(filtered_data, page, page_size)
|
||||||
|
|
||||||
# Apply common filters
|
|
||||||
filtered_data = await self._apply_common_filters(
|
filtered_data = await self._apply_common_filters(
|
||||||
filtered_data, folder, base_models, tags, favorites_only, search_options
|
sorted_data,
|
||||||
|
folder=folder,
|
||||||
|
base_models=base_models,
|
||||||
|
tags=tags,
|
||||||
|
favorites_only=favorites_only,
|
||||||
|
search_options=search_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply search filtering
|
|
||||||
if search:
|
if search:
|
||||||
filtered_data = await self._apply_search_filters(
|
filtered_data = await self._apply_search_filters(
|
||||||
filtered_data, search, fuzzy_search, search_options
|
filtered_data,
|
||||||
|
search,
|
||||||
|
fuzzy_search,
|
||||||
|
search_options,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply model-specific filters
|
|
||||||
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||||
|
|
||||||
return self._paginate(filtered_data, page, page_size)
|
return self._paginate(filtered_data, page, page_size)
|
||||||
|
|
||||||
|
|
||||||
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
||||||
"""Apply hash-based filtering"""
|
"""Apply hash-based filtering"""
|
||||||
@@ -120,113 +109,36 @@ class BaseModelService(ABC):
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
async def _apply_common_filters(self, data: List[Dict], folder: str = None,
|
async def _apply_common_filters(
|
||||||
base_models: list = None, tags: list = None,
|
self,
|
||||||
favorites_only: bool = False, search_options: dict = None) -> List[Dict]:
|
data: List[Dict],
|
||||||
|
folder: str = None,
|
||||||
|
base_models: list = None,
|
||||||
|
tags: list = None,
|
||||||
|
favorites_only: bool = False,
|
||||||
|
search_options: dict = None,
|
||||||
|
) -> List[Dict]:
|
||||||
"""Apply common filters that work across all model types"""
|
"""Apply common filters that work across all model types"""
|
||||||
# Apply SFW filtering if enabled in settings
|
normalized_options = self.search_strategy.normalize_options(search_options)
|
||||||
if settings.get('show_only_sfw', False):
|
criteria = FilterCriteria(
|
||||||
data = [
|
folder=folder,
|
||||||
item for item in data
|
base_models=base_models,
|
||||||
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
tags=tags,
|
||||||
]
|
favorites_only=favorites_only,
|
||||||
|
search_options=normalized_options,
|
||||||
# Apply favorites filtering if enabled
|
)
|
||||||
if favorites_only:
|
return self.filter_set.apply(data, criteria)
|
||||||
data = [
|
|
||||||
item for item in data
|
|
||||||
if item.get('favorite', False) is True
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply folder filtering
|
|
||||||
if folder is not None:
|
|
||||||
if search_options and search_options.get('recursive', True):
|
|
||||||
# Recursive folder filtering - include all subfolders
|
|
||||||
# Ensure we match exact folder or its subfolders by checking path boundaries
|
|
||||||
if folder == "":
|
|
||||||
# Empty folder means root - include all items
|
|
||||||
pass # Don't filter anything
|
|
||||||
else:
|
|
||||||
# Add trailing slash to ensure we match folder boundaries correctly
|
|
||||||
folder_with_separator = folder + "/"
|
|
||||||
data = [
|
|
||||||
item for item in data
|
|
||||||
if (item['folder'] == folder or
|
|
||||||
item['folder'].startswith(folder_with_separator))
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
# Exact folder filtering
|
|
||||||
data = [
|
|
||||||
item for item in data
|
|
||||||
if item['folder'] == folder
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply base model filtering
|
|
||||||
if base_models and len(base_models) > 0:
|
|
||||||
data = [
|
|
||||||
item for item in data
|
|
||||||
if item.get('base_model') in base_models
|
|
||||||
]
|
|
||||||
|
|
||||||
# Apply tag filtering
|
|
||||||
if tags and len(tags) > 0:
|
|
||||||
data = [
|
|
||||||
item for item in data
|
|
||||||
if any(tag in item.get('tags', []) for tag in tags)
|
|
||||||
]
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
async def _apply_search_filters(self, data: List[Dict], search: str,
|
async def _apply_search_filters(
|
||||||
fuzzy_search: bool, search_options: dict) -> List[Dict]:
|
self,
|
||||||
|
data: List[Dict],
|
||||||
|
search: str,
|
||||||
|
fuzzy_search: bool,
|
||||||
|
search_options: dict,
|
||||||
|
) -> List[Dict]:
|
||||||
"""Apply search filtering"""
|
"""Apply search filtering"""
|
||||||
search_results = []
|
normalized_options = self.search_strategy.normalize_options(search_options)
|
||||||
|
return self.search_strategy.apply(data, search, normalized_options, fuzzy_search)
|
||||||
for item in data:
|
|
||||||
# Search by file name
|
|
||||||
if search_options.get('filename', True):
|
|
||||||
if fuzzy_search:
|
|
||||||
if fuzzy_match(item.get('file_name', ''), search):
|
|
||||||
search_results.append(item)
|
|
||||||
continue
|
|
||||||
elif search.lower() in item.get('file_name', '').lower():
|
|
||||||
search_results.append(item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Search by model name
|
|
||||||
if search_options.get('modelname', True):
|
|
||||||
if fuzzy_search:
|
|
||||||
if fuzzy_match(item.get('model_name', ''), search):
|
|
||||||
search_results.append(item)
|
|
||||||
continue
|
|
||||||
elif search.lower() in item.get('model_name', '').lower():
|
|
||||||
search_results.append(item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Search by tags
|
|
||||||
if search_options.get('tags', False) and 'tags' in item:
|
|
||||||
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower())
|
|
||||||
for tag in item['tags']):
|
|
||||||
search_results.append(item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Search by creator
|
|
||||||
civitai = item.get('civitai')
|
|
||||||
creator_username = ''
|
|
||||||
if civitai and isinstance(civitai, dict):
|
|
||||||
creator = civitai.get('creator')
|
|
||||||
if creator and isinstance(creator, dict):
|
|
||||||
creator_username = creator.get('username', '')
|
|
||||||
if search_options.get('creator', False) and creator_username:
|
|
||||||
if fuzzy_search:
|
|
||||||
if fuzzy_match(creator_username, search):
|
|
||||||
search_results.append(item)
|
|
||||||
continue
|
|
||||||
elif search.lower() in creator_username.lower():
|
|
||||||
search_results.append(item)
|
|
||||||
continue
|
|
||||||
|
|
||||||
return search_results
|
|
||||||
|
|
||||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||||
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||||
@@ -284,6 +196,18 @@ class BaseModelService(ABC):
|
|||||||
"""Get model root directories"""
|
"""Get model root directories"""
|
||||||
return self.scanner.get_model_roots()
|
return self.scanner.get_model_roots()
|
||||||
|
|
||||||
|
def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict:
|
||||||
|
"""Filter relevant fields from CivitAI data"""
|
||||||
|
if not data:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
fields = ["id", "modelId", "name", "trainedWords"] if minimal else [
|
||||||
|
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||||
|
"publishedAt", "trainedWords", "baseModel", "description",
|
||||||
|
"model", "images", "customImages", "creator"
|
||||||
|
]
|
||||||
|
return {k: data[k] for k in fields if k in data}
|
||||||
|
|
||||||
async def get_folder_tree(self, model_root: str) -> Dict:
|
async def get_folder_tree(self, model_root: str) -> Dict:
|
||||||
"""Get hierarchical folder tree for a specific model root"""
|
"""Get hierarchical folder tree for a specific model root"""
|
||||||
cache = await self.scanner.get_cached_data()
|
cache = await self.scanner.get_cached_data()
|
||||||
@@ -394,7 +318,7 @@ class BaseModelService(ABC):
|
|||||||
|
|
||||||
for model in cache.raw_data:
|
for model in cache.raw_data:
|
||||||
if model.get('file_path') == file_path:
|
if model.get('file_path') == file_path:
|
||||||
return ModelRouteUtils.filter_civitai_data(model.get("civitai", {}))
|
return self.filter_civitai_data(model.get("civitai", {}))
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict
|
||||||
|
|
||||||
from .base_model_service import BaseModelService
|
from .base_model_service import BaseModelService
|
||||||
from ..utils.models import CheckpointMetadata
|
from ..utils.models import CheckpointMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,7 +37,7 @@ class CheckpointService(BaseModelService):
|
|||||||
"notes": checkpoint_data.get("notes", ""),
|
"notes": checkpoint_data.get("notes", ""),
|
||||||
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
||||||
"favorite": checkpoint_data.get("favorite", False),
|
"favorite": checkpoint_data.get("favorite", False),
|
||||||
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
||||||
}
|
}
|
||||||
|
|
||||||
def find_duplicate_hashes(self) -> Dict:
|
def find_duplicate_hashes(self) -> Dict:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import copy
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Optional, Dict, Tuple, List
|
from typing import Optional, Dict, Tuple, List
|
||||||
@@ -189,31 +190,76 @@ class CivitaiClient:
|
|||||||
)
|
)
|
||||||
if not success:
|
if not success:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
model_versions = data.get('modelVersions', [])
|
model_versions = data.get('modelVersions', [])
|
||||||
|
if not model_versions:
|
||||||
# Step 2: Determine the version_id to use
|
logger.warning(f"No model versions found for model {model_id}")
|
||||||
target_version_id = version_id
|
|
||||||
if target_version_id is None:
|
|
||||||
target_version_id = model_versions[0].get('id')
|
|
||||||
|
|
||||||
# Step 3: Get detailed version info using the version_id
|
|
||||||
success, version = await downloader.make_request(
|
|
||||||
'GET',
|
|
||||||
f"{self.base_url}/model-versions/{target_version_id}",
|
|
||||||
use_auth=True
|
|
||||||
)
|
|
||||||
if not success:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Step 2: Determine the target version entry to use
|
||||||
|
target_version = None
|
||||||
|
if version_id is not None:
|
||||||
|
target_version = next(
|
||||||
|
(item for item in model_versions if item.get('id') == version_id),
|
||||||
|
None
|
||||||
|
)
|
||||||
|
if target_version is None:
|
||||||
|
logger.warning(
|
||||||
|
f"Version {version_id} not found for model {model_id}, defaulting to first version"
|
||||||
|
)
|
||||||
|
if target_version is None:
|
||||||
|
target_version = model_versions[0]
|
||||||
|
|
||||||
|
target_version_id = target_version.get('id')
|
||||||
|
|
||||||
|
# Step 3: Get detailed version info using the SHA256 hash
|
||||||
|
model_hash = None
|
||||||
|
for file_info in target_version.get('files', []):
|
||||||
|
if file_info.get('type') == 'Model' and file_info.get('primary'):
|
||||||
|
model_hash = file_info.get('hashes', {}).get('SHA256')
|
||||||
|
if model_hash:
|
||||||
|
break
|
||||||
|
|
||||||
|
version = None
|
||||||
|
if model_hash:
|
||||||
|
success, version = await downloader.make_request(
|
||||||
|
'GET',
|
||||||
|
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||||
|
use_auth=True
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to fetch version by hash for model {model_id} version {target_version_id}: {version}"
|
||||||
|
)
|
||||||
|
version = None
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"No primary model hash found for model {model_id} version {target_version_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if version is None:
|
||||||
|
version = copy.deepcopy(target_version)
|
||||||
|
version.pop('index', None)
|
||||||
|
version['modelId'] = model_id
|
||||||
|
version['model'] = {
|
||||||
|
'name': data.get('name'),
|
||||||
|
'type': data.get('type'),
|
||||||
|
'nsfw': data.get('nsfw'),
|
||||||
|
'poi': data.get('poi')
|
||||||
|
}
|
||||||
|
|
||||||
# Step 4: Enrich version_info with model data
|
# Step 4: Enrich version_info with model data
|
||||||
# Add description and tags from model data
|
# Add description and tags from model data
|
||||||
version['model']['description'] = data.get("description")
|
model_info = version.get('model')
|
||||||
version['model']['tags'] = data.get("tags", [])
|
if not isinstance(model_info, dict):
|
||||||
|
model_info = {}
|
||||||
|
version['model'] = model_info
|
||||||
|
model_info['description'] = data.get("description")
|
||||||
|
model_info['tags'] = data.get("tags", [])
|
||||||
|
|
||||||
# Add creator from model data
|
# Add creator from model data
|
||||||
version['creator'] = data.get("creator")
|
version['creator'] = data.get("creator")
|
||||||
|
|
||||||
return version
|
return version
|
||||||
|
|
||||||
# Case 3: Neither model_id nor version_id provided
|
# Case 3: Neither model_id nor version_id provided
|
||||||
|
|||||||
100
py/services/download_coordinator.py
Normal file
100
py/services/download_coordinator.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""Service wrapper for coordinating download lifecycle events."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadCoordinator:
|
||||||
|
"""Manage download scheduling, cancellation and introspection."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ws_manager,
|
||||||
|
download_manager_factory: Callable[[], Awaitable],
|
||||||
|
) -> None:
|
||||||
|
self._ws_manager = ws_manager
|
||||||
|
self._download_manager_factory = download_manager_factory
|
||||||
|
|
||||||
|
async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Schedule a download using the provided payload."""
|
||||||
|
|
||||||
|
download_manager = await self._download_manager_factory()
|
||||||
|
|
||||||
|
download_id = payload.get("download_id") or self._ws_manager.generate_download_id()
|
||||||
|
payload.setdefault("download_id", download_id)
|
||||||
|
|
||||||
|
async def progress_callback(progress: Any) -> None:
|
||||||
|
await self._ws_manager.broadcast_download_progress(
|
||||||
|
download_id,
|
||||||
|
{
|
||||||
|
"status": "progress",
|
||||||
|
"progress": progress,
|
||||||
|
"download_id": download_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
model_id = self._parse_optional_int(payload.get("model_id"), "model_id")
|
||||||
|
model_version_id = self._parse_optional_int(
|
||||||
|
payload.get("model_version_id"), "model_version_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_id is None and model_version_id is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Missing required parameter: Please provide either 'model_id' or 'model_version_id'"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await download_manager.download_from_civitai(
|
||||||
|
model_id=model_id,
|
||||||
|
model_version_id=model_version_id,
|
||||||
|
save_dir=payload.get("model_root"),
|
||||||
|
relative_path=payload.get("relative_path", ""),
|
||||||
|
use_default_paths=payload.get("use_default_paths", False),
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
download_id=download_id,
|
||||||
|
source=payload.get("source"),
|
||||||
|
)
|
||||||
|
|
||||||
|
result["download_id"] = download_id
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||||
|
"""Cancel an active download and emit a broadcast event."""
|
||||||
|
|
||||||
|
download_manager = await self._download_manager_factory()
|
||||||
|
result = await download_manager.cancel_download(download_id)
|
||||||
|
|
||||||
|
await self._ws_manager.broadcast_download_progress(
|
||||||
|
download_id,
|
||||||
|
{
|
||||||
|
"status": "cancelled",
|
||||||
|
"progress": 0,
|
||||||
|
"download_id": download_id,
|
||||||
|
"message": "Download cancelled by user",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def list_active_downloads(self) -> Dict[str, Any]:
|
||||||
|
"""Return the active download map from the underlying manager."""
|
||||||
|
|
||||||
|
download_manager = await self._download_manager_factory()
|
||||||
|
return await download_manager.get_active_downloads()
|
||||||
|
|
||||||
|
def _parse_optional_int(self, value: Any, field: str) -> Optional[int]:
|
||||||
|
"""Parse an optional integer from user input."""
|
||||||
|
|
||||||
|
if value is None or value == "":
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError) as exc:
|
||||||
|
raise ValueError(f"Invalid {field}: Must be an integer") from exc
|
||||||
|
|
||||||
@@ -1,11 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict
|
||||||
|
|
||||||
from .base_model_service import BaseModelService
|
from .base_model_service import BaseModelService
|
||||||
from ..utils.models import EmbeddingMetadata
|
from ..utils.models import EmbeddingMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,7 +37,7 @@ class EmbeddingService(BaseModelService):
|
|||||||
"notes": embedding_data.get("notes", ""),
|
"notes": embedding_data.get("notes", ""),
|
||||||
"model_type": embedding_data.get("model_type", "embedding"),
|
"model_type": embedding_data.get("model_type", "embedding"),
|
||||||
"favorite": embedding_data.get("favorite", False),
|
"favorite": embedding_data.get("favorite", False),
|
||||||
"civitai": ModelRouteUtils.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
||||||
}
|
}
|
||||||
|
|
||||||
def find_duplicate_hashes(self) -> Dict:
|
def find_duplicate_hashes(self) -> Dict:
|
||||||
|
|||||||
246
py/services/example_images_cleanup_service.py
Normal file
246
py/services/example_images_cleanup_service.py
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
"""Service for cleaning up example image folders."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from .service_registry import ServiceRegistry
|
||||||
|
from .settings_manager import settings
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class CleanupResult:
|
||||||
|
"""Structured result returned from cleanup operations."""
|
||||||
|
|
||||||
|
success: bool
|
||||||
|
checked_folders: int
|
||||||
|
moved_empty_folders: int
|
||||||
|
moved_orphaned_folders: int
|
||||||
|
skipped_non_hash: int
|
||||||
|
move_failures: int
|
||||||
|
errors: List[str]
|
||||||
|
deleted_root: str | None
|
||||||
|
partial_success: bool
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, object]:
|
||||||
|
"""Convert the dataclass to a serialisable dictionary."""
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"success": self.success,
|
||||||
|
"checked_folders": self.checked_folders,
|
||||||
|
"moved_empty_folders": self.moved_empty_folders,
|
||||||
|
"moved_orphaned_folders": self.moved_orphaned_folders,
|
||||||
|
"moved_total": self.moved_empty_folders + self.moved_orphaned_folders,
|
||||||
|
"skipped_non_hash": self.skipped_non_hash,
|
||||||
|
"move_failures": self.move_failures,
|
||||||
|
"errors": self.errors,
|
||||||
|
"deleted_root": self.deleted_root,
|
||||||
|
"partial_success": self.partial_success,
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleImagesCleanupService:
|
||||||
|
"""Encapsulates logic for cleaning example image folders."""
|
||||||
|
|
||||||
|
DELETED_FOLDER_NAME = "_deleted"
|
||||||
|
|
||||||
|
def __init__(self, deleted_folder_name: str | None = None) -> None:
|
||||||
|
self._deleted_folder_name = deleted_folder_name or self.DELETED_FOLDER_NAME
|
||||||
|
|
||||||
|
async def cleanup_example_image_folders(self) -> Dict[str, object]:
|
||||||
|
"""Clean empty or orphaned example image folders by moving them under a deleted bucket."""
|
||||||
|
|
||||||
|
example_images_path = settings.get("example_images_path")
|
||||||
|
if not example_images_path:
|
||||||
|
logger.debug("Cleanup skipped: example images path not configured")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "Example images path is not configured.",
|
||||||
|
"error_code": "path_not_configured",
|
||||||
|
}
|
||||||
|
|
||||||
|
example_root = Path(example_images_path)
|
||||||
|
if not example_root.exists():
|
||||||
|
logger.debug("Cleanup skipped: example images path missing -> %s", example_root)
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "Example images path does not exist.",
|
||||||
|
"error_code": "path_not_found",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
|
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||||
|
except Exception as exc: # pragma: no cover - defensive guard
|
||||||
|
logger.error("Failed to acquire scanners for cleanup: %s", exc, exc_info=True)
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Failed to load model scanners: {exc}",
|
||||||
|
"error_code": "scanner_initialization_failed",
|
||||||
|
}
|
||||||
|
|
||||||
|
deleted_bucket = example_root / self._deleted_folder_name
|
||||||
|
deleted_bucket.mkdir(exist_ok=True)
|
||||||
|
|
||||||
|
checked_folders = 0
|
||||||
|
moved_empty = 0
|
||||||
|
moved_orphaned = 0
|
||||||
|
skipped_non_hash = 0
|
||||||
|
move_failures = 0
|
||||||
|
errors: List[str] = []
|
||||||
|
|
||||||
|
for entry in os.scandir(example_root):
|
||||||
|
if not entry.is_dir(follow_symlinks=False):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if entry.name == self._deleted_folder_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
checked_folders += 1
|
||||||
|
folder_path = Path(entry.path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if self._is_folder_empty(folder_path):
|
||||||
|
if await self._remove_empty_folder(folder_path):
|
||||||
|
moved_empty += 1
|
||||||
|
else:
|
||||||
|
move_failures += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not self._is_hash_folder(entry.name):
|
||||||
|
skipped_non_hash += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
hash_exists = (
|
||||||
|
lora_scanner.has_hash(entry.name)
|
||||||
|
or checkpoint_scanner.has_hash(entry.name)
|
||||||
|
or embedding_scanner.has_hash(entry.name)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not hash_exists:
|
||||||
|
if await self._move_folder(folder_path, deleted_bucket):
|
||||||
|
moved_orphaned += 1
|
||||||
|
else:
|
||||||
|
move_failures += 1
|
||||||
|
|
||||||
|
except Exception as exc: # pragma: no cover - filesystem guard
|
||||||
|
move_failures += 1
|
||||||
|
error_message = f"{entry.name}: {exc}"
|
||||||
|
errors.append(error_message)
|
||||||
|
logger.error("Error processing example images folder %s: %s", folder_path, exc, exc_info=True)
|
||||||
|
|
||||||
|
partial_success = move_failures > 0 and (moved_empty > 0 or moved_orphaned > 0)
|
||||||
|
success = move_failures == 0 and not errors
|
||||||
|
|
||||||
|
result = CleanupResult(
|
||||||
|
success=success,
|
||||||
|
checked_folders=checked_folders,
|
||||||
|
moved_empty_folders=moved_empty,
|
||||||
|
moved_orphaned_folders=moved_orphaned,
|
||||||
|
skipped_non_hash=skipped_non_hash,
|
||||||
|
move_failures=move_failures,
|
||||||
|
errors=errors,
|
||||||
|
deleted_root=str(deleted_bucket),
|
||||||
|
partial_success=partial_success,
|
||||||
|
)
|
||||||
|
|
||||||
|
summary = result.to_dict()
|
||||||
|
if success:
|
||||||
|
logger.info(
|
||||||
|
"Example images cleanup complete: checked=%s, moved_empty=%s, moved_orphaned=%s",
|
||||||
|
checked_folders,
|
||||||
|
moved_empty,
|
||||||
|
moved_orphaned,
|
||||||
|
)
|
||||||
|
elif partial_success:
|
||||||
|
logger.warning(
|
||||||
|
"Example images cleanup partially complete: moved=%s, failures=%s",
|
||||||
|
summary["moved_total"],
|
||||||
|
move_failures,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
"Example images cleanup failed: move_failures=%s, errors=%s",
|
||||||
|
move_failures,
|
||||||
|
errors,
|
||||||
|
)
|
||||||
|
|
||||||
|
return summary
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_folder_empty(folder_path: Path) -> bool:
|
||||||
|
try:
|
||||||
|
with os.scandir(folder_path) as iterator:
|
||||||
|
return not any(iterator)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return True
|
||||||
|
except OSError as exc: # pragma: no cover - defensive guard
|
||||||
|
logger.debug("Failed to inspect folder %s: %s", folder_path, exc)
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_hash_folder(name: str) -> bool:
|
||||||
|
if len(name) != 64:
|
||||||
|
return False
|
||||||
|
hex_chars = set("0123456789abcdefABCDEF")
|
||||||
|
return all(char in hex_chars for char in name)
|
||||||
|
|
||||||
|
async def _remove_empty_folder(self, folder_path: Path) -> bool:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
shutil.rmtree,
|
||||||
|
str(folder_path),
|
||||||
|
)
|
||||||
|
logger.debug("Removed empty example images folder %s", folder_path)
|
||||||
|
return True
|
||||||
|
except Exception as exc: # pragma: no cover - filesystem guard
|
||||||
|
logger.error("Failed to remove empty example images folder %s: %s", folder_path, exc, exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _move_folder(self, folder_path: Path, deleted_bucket: Path) -> bool:
|
||||||
|
destination = self._build_destination(folder_path.name, deleted_bucket)
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
shutil.move,
|
||||||
|
str(folder_path),
|
||||||
|
str(destination),
|
||||||
|
)
|
||||||
|
logger.debug("Moved example images folder %s -> %s", folder_path, destination)
|
||||||
|
return True
|
||||||
|
except Exception as exc: # pragma: no cover - filesystem guard
|
||||||
|
logger.error(
|
||||||
|
"Failed to move example images folder %s to %s: %s",
|
||||||
|
folder_path,
|
||||||
|
destination,
|
||||||
|
exc,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _build_destination(self, folder_name: str, deleted_bucket: Path) -> Path:
|
||||||
|
destination = deleted_bucket / folder_name
|
||||||
|
suffix = 1
|
||||||
|
|
||||||
|
while destination.exists():
|
||||||
|
destination = deleted_bucket / f"{folder_name}_{suffix}"
|
||||||
|
suffix += 1
|
||||||
|
|
||||||
|
return destination
|
||||||
@@ -5,7 +5,6 @@ from typing import Dict, List, Optional
|
|||||||
from .base_model_service import BaseModelService
|
from .base_model_service import BaseModelService
|
||||||
from ..utils.models import LoraMetadata
|
from ..utils.models import LoraMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -38,7 +37,7 @@ class LoraService(BaseModelService):
|
|||||||
"usage_tips": lora_data.get("usage_tips", ""),
|
"usage_tips": lora_data.get("usage_tips", ""),
|
||||||
"notes": lora_data.get("notes", ""),
|
"notes": lora_data.get("notes", ""),
|
||||||
"favorite": lora_data.get("favorite", False),
|
"favorite": lora_data.get("favorite", False),
|
||||||
"civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}), minimal=True)
|
"civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True)
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||||
|
|||||||
355
py/services/metadata_sync_service.py
Normal file
355
py/services/metadata_sync_service.py
Normal file
@@ -0,0 +1,355 @@
|
|||||||
|
"""Services for synchronising metadata with remote providers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
|
||||||
|
|
||||||
|
from ..services.settings_manager import SettingsManager
|
||||||
|
from ..utils.model_utils import determine_base_model
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataProviderProtocol:
|
||||||
|
"""Subset of metadata provider interface consumed by the sync service."""
|
||||||
|
|
||||||
|
async def get_model_by_hash(self, sha256: str) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
async def get_model_version(
|
||||||
|
self, model_id: int, model_version_id: Optional[int]
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataSyncService:
|
||||||
|
"""High level orchestration for metadata synchronisation flows."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
metadata_manager,
|
||||||
|
preview_service,
|
||||||
|
settings: SettingsManager,
|
||||||
|
default_metadata_provider_factory: Callable[[], Awaitable[MetadataProviderProtocol]],
|
||||||
|
metadata_provider_selector: Callable[[str], Awaitable[MetadataProviderProtocol]],
|
||||||
|
) -> None:
|
||||||
|
self._metadata_manager = metadata_manager
|
||||||
|
self._preview_service = preview_service
|
||||||
|
self._settings = settings
|
||||||
|
self._get_default_provider = default_metadata_provider_factory
|
||||||
|
self._get_provider = metadata_provider_selector
|
||||||
|
|
||||||
|
async def load_local_metadata(self, metadata_path: str) -> Dict[str, Any]:
|
||||||
|
"""Load metadata JSON from disk, returning an empty structure when missing."""
|
||||||
|
|
||||||
|
if not os.path.exists(metadata_path):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(metadata_path, "r", encoding="utf-8") as handle:
|
||||||
|
return json.load(handle)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.error("Error loading metadata from %s: %s", metadata_path, exc)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def mark_not_found_on_civitai(
|
||||||
|
self, metadata_path: str, local_metadata: Dict[str, Any]
|
||||||
|
) -> None:
|
||||||
|
"""Persist the not-found flag for a metadata payload."""
|
||||||
|
|
||||||
|
local_metadata["from_civitai"] = False
|
||||||
|
await self._metadata_manager.save_metadata(metadata_path, local_metadata)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_civitai_api_metadata(meta: Dict[str, Any]) -> bool:
|
||||||
|
"""Determine if the metadata originated from the CivitAI public API."""
|
||||||
|
|
||||||
|
if not isinstance(meta, dict):
|
||||||
|
return False
|
||||||
|
files = meta.get("files")
|
||||||
|
images = meta.get("images")
|
||||||
|
source = meta.get("source")
|
||||||
|
return bool(files) and bool(images) and source != "archive_db"
|
||||||
|
|
||||||
|
async def update_model_metadata(
|
||||||
|
self,
|
||||||
|
metadata_path: str,
|
||||||
|
local_metadata: Dict[str, Any],
|
||||||
|
civitai_metadata: Dict[str, Any],
|
||||||
|
metadata_provider: Optional[MetadataProviderProtocol] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Merge remote metadata into the local record and persist the result."""
|
||||||
|
|
||||||
|
existing_civitai = local_metadata.get("civitai") or {}
|
||||||
|
|
||||||
|
if (
|
||||||
|
civitai_metadata.get("source") == "archive_db"
|
||||||
|
and self.is_civitai_api_metadata(existing_civitai)
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"Skip civitai update for %s (%s)",
|
||||||
|
local_metadata.get("model_name", ""),
|
||||||
|
existing_civitai.get("name", ""),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
merged_civitai = existing_civitai.copy()
|
||||||
|
merged_civitai.update(civitai_metadata)
|
||||||
|
|
||||||
|
if civitai_metadata.get("source") == "archive_db":
|
||||||
|
model_name = civitai_metadata.get("model", {}).get("name", "")
|
||||||
|
version_name = civitai_metadata.get("name", "")
|
||||||
|
logger.info(
|
||||||
|
"Recovered metadata from archive_db for deleted model: %s (%s)",
|
||||||
|
model_name,
|
||||||
|
version_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
if "trainedWords" in existing_civitai:
|
||||||
|
existing_trained = existing_civitai.get("trainedWords", [])
|
||||||
|
new_trained = civitai_metadata.get("trainedWords", [])
|
||||||
|
merged_trained = list(set(existing_trained + new_trained))
|
||||||
|
merged_civitai["trainedWords"] = merged_trained
|
||||||
|
|
||||||
|
local_metadata["civitai"] = merged_civitai
|
||||||
|
|
||||||
|
if "model" in civitai_metadata and civitai_metadata["model"]:
|
||||||
|
model_data = civitai_metadata["model"]
|
||||||
|
|
||||||
|
if model_data.get("name"):
|
||||||
|
local_metadata["model_name"] = model_data["name"]
|
||||||
|
|
||||||
|
if not local_metadata.get("modelDescription") and model_data.get("description"):
|
||||||
|
local_metadata["modelDescription"] = model_data["description"]
|
||||||
|
|
||||||
|
if not local_metadata.get("tags") and model_data.get("tags"):
|
||||||
|
local_metadata["tags"] = model_data["tags"]
|
||||||
|
|
||||||
|
if model_data.get("creator") and not local_metadata.get("civitai", {}).get(
|
||||||
|
"creator"
|
||||||
|
):
|
||||||
|
local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"]
|
||||||
|
|
||||||
|
local_metadata["base_model"] = determine_base_model(
|
||||||
|
civitai_metadata.get("baseModel")
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._preview_service.ensure_preview_for_metadata(
|
||||||
|
metadata_path, local_metadata, civitai_metadata.get("images", [])
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._metadata_manager.save_metadata(metadata_path, local_metadata)
|
||||||
|
return local_metadata
|
||||||
|
|
||||||
|
async def fetch_and_update_model(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sha256: str,
|
||||||
|
file_path: str,
|
||||||
|
model_data: Dict[str, Any],
|
||||||
|
update_cache_func: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||||
|
) -> tuple[bool, Optional[str]]:
|
||||||
|
"""Fetch metadata for a model and update both disk and cache state."""
|
||||||
|
|
||||||
|
if not isinstance(model_data, dict):
|
||||||
|
error = f"Invalid model_data type: {type(model_data)}"
|
||||||
|
logger.error(error)
|
||||||
|
return False, error
|
||||||
|
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||||
|
enable_archive = self._settings.get("enable_metadata_archive_db", False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if model_data.get("civitai_deleted") is True:
|
||||||
|
if not enable_archive or model_data.get("db_checked") is True:
|
||||||
|
return (
|
||||||
|
False,
|
||||||
|
"CivitAI model is deleted and metadata archive DB is not enabled",
|
||||||
|
)
|
||||||
|
metadata_provider = await self._get_provider("sqlite")
|
||||||
|
else:
|
||||||
|
metadata_provider = await self._get_default_provider()
|
||||||
|
|
||||||
|
civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256)
|
||||||
|
if not civitai_metadata:
|
||||||
|
if error == "Model not found":
|
||||||
|
model_data["from_civitai"] = False
|
||||||
|
model_data["civitai_deleted"] = True
|
||||||
|
model_data["db_checked"] = enable_archive
|
||||||
|
model_data["last_checked_at"] = datetime.now().timestamp()
|
||||||
|
|
||||||
|
data_to_save = model_data.copy()
|
||||||
|
data_to_save.pop("folder", None)
|
||||||
|
await self._metadata_manager.save_metadata(file_path, data_to_save)
|
||||||
|
|
||||||
|
error_msg = (
|
||||||
|
f"Error fetching metadata: {error} (model_name={model_data.get('model_name', '')})"
|
||||||
|
)
|
||||||
|
logger.error(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
model_data["from_civitai"] = True
|
||||||
|
model_data["civitai_deleted"] = civitai_metadata.get("source") == "archive_db"
|
||||||
|
model_data["db_checked"] = enable_archive
|
||||||
|
model_data["last_checked_at"] = datetime.now().timestamp()
|
||||||
|
|
||||||
|
local_metadata = model_data.copy()
|
||||||
|
local_metadata.pop("folder", None)
|
||||||
|
|
||||||
|
await self.update_model_metadata(
|
||||||
|
metadata_path,
|
||||||
|
local_metadata,
|
||||||
|
civitai_metadata,
|
||||||
|
metadata_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
update_payload = {
|
||||||
|
"model_name": local_metadata.get("model_name"),
|
||||||
|
"preview_url": local_metadata.get("preview_url"),
|
||||||
|
"civitai": local_metadata.get("civitai"),
|
||||||
|
}
|
||||||
|
model_data.update(update_payload)
|
||||||
|
|
||||||
|
await update_cache_func(file_path, file_path, local_metadata)
|
||||||
|
return True, None
|
||||||
|
except KeyError as exc:
|
||||||
|
error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
except Exception as exc: # pragma: no cover - error path
|
||||||
|
error_msg = f"Error fetching metadata: {exc}"
|
||||||
|
logger.error(error_msg, exc_info=True)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
|
async def fetch_metadata_by_sha(
|
||||||
|
self, sha256: str, metadata_provider: Optional[MetadataProviderProtocol] = None
|
||||||
|
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||||
|
"""Fetch metadata for a SHA256 hash from the configured provider."""
|
||||||
|
|
||||||
|
provider = metadata_provider or await self._get_default_provider()
|
||||||
|
return await provider.get_model_by_hash(sha256)
|
||||||
|
|
||||||
|
async def relink_metadata(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
file_path: str,
|
||||||
|
metadata: Dict[str, Any],
|
||||||
|
model_id: int,
|
||||||
|
model_version_id: Optional[int],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Relink a local metadata record to a specific CivitAI model version."""
|
||||||
|
|
||||||
|
provider = await self._get_default_provider()
|
||||||
|
civitai_metadata = await provider.get_model_version(model_id, model_version_id)
|
||||||
|
if not civitai_metadata:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model version not found on CivitAI for ID: {model_id}"
|
||||||
|
+ (f" with version: {model_version_id}" if model_version_id else "")
|
||||||
|
)
|
||||||
|
|
||||||
|
primary_model_file: Optional[Dict[str, Any]] = None
|
||||||
|
for file_info in civitai_metadata.get("files", []):
|
||||||
|
if file_info.get("primary", False) and file_info.get("type") == "Model":
|
||||||
|
primary_model_file = file_info
|
||||||
|
break
|
||||||
|
|
||||||
|
if primary_model_file and primary_model_file.get("hashes", {}).get("SHA256"):
|
||||||
|
metadata["sha256"] = primary_model_file["hashes"]["SHA256"].lower()
|
||||||
|
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||||
|
await self.update_model_metadata(
|
||||||
|
metadata_path,
|
||||||
|
metadata,
|
||||||
|
civitai_metadata,
|
||||||
|
provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
async def save_metadata_updates(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
file_path: str,
|
||||||
|
updates: Dict[str, Any],
|
||||||
|
metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
|
||||||
|
update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Apply metadata updates and persist to disk and cache."""
|
||||||
|
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||||
|
metadata = await metadata_loader(metadata_path)
|
||||||
|
|
||||||
|
for key, value in updates.items():
|
||||||
|
if isinstance(value, dict) and isinstance(metadata.get(key), dict):
|
||||||
|
metadata[key].update(value)
|
||||||
|
else:
|
||||||
|
metadata[key] = value
|
||||||
|
|
||||||
|
await self._metadata_manager.save_metadata(file_path, metadata)
|
||||||
|
await update_cache(file_path, file_path, metadata)
|
||||||
|
|
||||||
|
if "model_name" in updates:
|
||||||
|
logger.debug("Metadata update touched model_name; cache resort required")
|
||||||
|
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
async def verify_duplicate_hashes(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
file_paths: Iterable[str],
|
||||||
|
metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
|
||||||
|
hash_calculator: Callable[[str], Awaitable[str]],
|
||||||
|
update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Verify a collection of files share the same SHA256 hash."""
|
||||||
|
|
||||||
|
file_paths = list(file_paths)
|
||||||
|
if not file_paths:
|
||||||
|
raise ValueError("No file paths provided for verification")
|
||||||
|
|
||||||
|
results = {
|
||||||
|
"verified_as_duplicates": True,
|
||||||
|
"mismatched_files": [],
|
||||||
|
"new_hash_map": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
expected_hash: Optional[str] = None
|
||||||
|
first_metadata_path = os.path.splitext(file_paths[0])[0] + ".metadata.json"
|
||||||
|
first_metadata = await metadata_loader(first_metadata_path)
|
||||||
|
if first_metadata and "sha256" in first_metadata:
|
||||||
|
expected_hash = first_metadata["sha256"].lower()
|
||||||
|
|
||||||
|
for path in file_paths:
|
||||||
|
if not os.path.exists(path):
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
actual_hash = await hash_calculator(path)
|
||||||
|
metadata_path = os.path.splitext(path)[0] + ".metadata.json"
|
||||||
|
metadata = await metadata_loader(metadata_path)
|
||||||
|
stored_hash = metadata.get("sha256", "").lower()
|
||||||
|
|
||||||
|
if not expected_hash:
|
||||||
|
expected_hash = stored_hash
|
||||||
|
|
||||||
|
if actual_hash != expected_hash:
|
||||||
|
results["verified_as_duplicates"] = False
|
||||||
|
results["mismatched_files"].append(path)
|
||||||
|
results["new_hash_map"][path] = actual_hash
|
||||||
|
|
||||||
|
if actual_hash != stored_hash:
|
||||||
|
metadata["sha256"] = actual_hash
|
||||||
|
await self._metadata_manager.save_metadata(path, metadata)
|
||||||
|
await update_cache(path, path, metadata)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive path
|
||||||
|
logger.error("Error verifying hash for %s: %s", path, exc)
|
||||||
|
results["mismatched_files"].append(path)
|
||||||
|
results["new_hash_map"][path] = "error_calculating_hash"
|
||||||
|
results["verified_as_duplicates"] = False
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
245
py/services/model_lifecycle_service.py
Normal file
245
py/services/model_lifecycle_service.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
"""Service routines for model lifecycle mutations."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Awaitable, Callable, Dict, Iterable, List, Optional
|
||||||
|
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]:
|
||||||
|
"""Delete the primary model artefacts within ``target_dir``."""
|
||||||
|
|
||||||
|
patterns = [
|
||||||
|
f"{file_name}.safetensors",
|
||||||
|
f"{file_name}.metadata.json",
|
||||||
|
]
|
||||||
|
for ext in PREVIEW_EXTENSIONS:
|
||||||
|
patterns.append(f"{file_name}{ext}")
|
||||||
|
|
||||||
|
deleted: List[str] = []
|
||||||
|
main_file = patterns[0]
|
||||||
|
main_path = os.path.join(target_dir, main_file).replace(os.sep, "/")
|
||||||
|
|
||||||
|
if os.path.exists(main_path):
|
||||||
|
os.remove(main_path)
|
||||||
|
deleted.append(main_path)
|
||||||
|
else:
|
||||||
|
logger.warning("Model file not found: %s", main_file)
|
||||||
|
|
||||||
|
for pattern in patterns[1:]:
|
||||||
|
path = os.path.join(target_dir, pattern)
|
||||||
|
if os.path.exists(path):
|
||||||
|
try:
|
||||||
|
os.remove(path)
|
||||||
|
deleted.append(pattern)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive path
|
||||||
|
logger.warning("Failed to delete %s: %s", pattern, exc)
|
||||||
|
|
||||||
|
return deleted
|
||||||
|
|
||||||
|
|
||||||
|
class ModelLifecycleService:
|
||||||
|
"""Co-ordinate destructive and mutating model operations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
scanner,
|
||||||
|
metadata_manager,
|
||||||
|
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||||
|
recipe_scanner_factory: Callable[[], Awaitable] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._scanner = scanner
|
||||||
|
self._metadata_manager = metadata_manager
|
||||||
|
self._metadata_loader = metadata_loader
|
||||||
|
self._recipe_scanner_factory = (
|
||||||
|
recipe_scanner_factory or ServiceRegistry.get_recipe_scanner
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_model(self, file_path: str) -> Dict[str, object]:
|
||||||
|
"""Delete a model file and associated artefacts."""
|
||||||
|
|
||||||
|
if not file_path:
|
||||||
|
raise ValueError("Model path is required")
|
||||||
|
|
||||||
|
target_dir = os.path.dirname(file_path)
|
||||||
|
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
|
||||||
|
deleted_files = await delete_model_artifacts(target_dir, file_name)
|
||||||
|
|
||||||
|
cache = await self._scanner.get_cached_data()
|
||||||
|
cache.raw_data = [item for item in cache.raw_data if item["file_path"] != file_path]
|
||||||
|
await cache.resort()
|
||||||
|
|
||||||
|
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
|
||||||
|
self._scanner._hash_index.remove_by_path(file_path)
|
||||||
|
|
||||||
|
return {"success": True, "deleted_files": deleted_files}
|
||||||
|
|
||||||
|
async def exclude_model(self, file_path: str) -> Dict[str, object]:
|
||||||
|
"""Mark a model as excluded and prune cache references."""
|
||||||
|
|
||||||
|
if not file_path:
|
||||||
|
raise ValueError("Model path is required")
|
||||||
|
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||||
|
metadata = await self._metadata_loader(metadata_path)
|
||||||
|
metadata["exclude"] = True
|
||||||
|
|
||||||
|
await self._metadata_manager.save_metadata(file_path, metadata)
|
||||||
|
|
||||||
|
cache = await self._scanner.get_cached_data()
|
||||||
|
model_to_remove = next(
|
||||||
|
(item for item in cache.raw_data if item["file_path"] == file_path),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if model_to_remove:
|
||||||
|
for tag in model_to_remove.get("tags", []):
|
||||||
|
if tag in getattr(self._scanner, "_tags_count", {}):
|
||||||
|
self._scanner._tags_count[tag] = max(
|
||||||
|
0, self._scanner._tags_count[tag] - 1
|
||||||
|
)
|
||||||
|
if self._scanner._tags_count[tag] == 0:
|
||||||
|
del self._scanner._tags_count[tag]
|
||||||
|
|
||||||
|
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
|
||||||
|
self._scanner._hash_index.remove_by_path(file_path)
|
||||||
|
|
||||||
|
cache.raw_data = [
|
||||||
|
item for item in cache.raw_data if item["file_path"] != file_path
|
||||||
|
]
|
||||||
|
await cache.resort()
|
||||||
|
|
||||||
|
excluded = getattr(self._scanner, "_excluded_models", None)
|
||||||
|
if isinstance(excluded, list):
|
||||||
|
excluded.append(file_path)
|
||||||
|
|
||||||
|
message = f"Model {os.path.basename(file_path)} excluded"
|
||||||
|
return {"success": True, "message": message}
|
||||||
|
|
||||||
|
async def bulk_delete_models(self, file_paths: Iterable[str]) -> Dict[str, object]:
|
||||||
|
"""Delete a collection of models via the scanner bulk operation."""
|
||||||
|
|
||||||
|
file_paths = list(file_paths)
|
||||||
|
if not file_paths:
|
||||||
|
raise ValueError("No file paths provided for deletion")
|
||||||
|
|
||||||
|
return await self._scanner.bulk_delete_models(file_paths)
|
||||||
|
|
||||||
|
async def rename_model(
|
||||||
|
self, *, file_path: str, new_file_name: str
|
||||||
|
) -> Dict[str, object]:
|
||||||
|
"""Rename a model and its companion artefacts."""
|
||||||
|
|
||||||
|
if not file_path or not new_file_name:
|
||||||
|
raise ValueError("File path and new file name are required")
|
||||||
|
|
||||||
|
invalid_chars = {"/", "\\", ":", "*", "?", '"', "<", ">", "|"}
|
||||||
|
if any(char in new_file_name for char in invalid_chars):
|
||||||
|
raise ValueError("Invalid characters in file name")
|
||||||
|
|
||||||
|
target_dir = os.path.dirname(file_path)
|
||||||
|
old_file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
new_file_path = os.path.join(target_dir, f"{new_file_name}.safetensors").replace(
|
||||||
|
os.sep, "/"
|
||||||
|
)
|
||||||
|
|
||||||
|
if os.path.exists(new_file_path):
|
||||||
|
raise ValueError("A file with this name already exists")
|
||||||
|
|
||||||
|
patterns = [
|
||||||
|
f"{old_file_name}.safetensors",
|
||||||
|
f"{old_file_name}.metadata.json",
|
||||||
|
f"{old_file_name}.metadata.json.bak",
|
||||||
|
]
|
||||||
|
for ext in PREVIEW_EXTENSIONS:
|
||||||
|
patterns.append(f"{old_file_name}{ext}")
|
||||||
|
|
||||||
|
existing_files: List[tuple[str, str]] = []
|
||||||
|
for pattern in patterns:
|
||||||
|
path = os.path.join(target_dir, pattern)
|
||||||
|
if os.path.exists(path):
|
||||||
|
existing_files.append((path, pattern))
|
||||||
|
|
||||||
|
metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
|
||||||
|
metadata: Optional[Dict[str, object]] = None
|
||||||
|
hash_value: Optional[str] = None
|
||||||
|
|
||||||
|
if os.path.exists(metadata_path):
|
||||||
|
metadata = await self._metadata_loader(metadata_path)
|
||||||
|
hash_value = metadata.get("sha256") if isinstance(metadata, dict) else None
|
||||||
|
|
||||||
|
renamed_files: List[str] = []
|
||||||
|
new_metadata_path: Optional[str] = None
|
||||||
|
new_preview: Optional[str] = None
|
||||||
|
|
||||||
|
for old_path, pattern in existing_files:
|
||||||
|
ext = self._get_multipart_ext(pattern)
|
||||||
|
new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(
|
||||||
|
os.sep, "/"
|
||||||
|
)
|
||||||
|
os.rename(old_path, new_path)
|
||||||
|
renamed_files.append(new_path)
|
||||||
|
|
||||||
|
if ext == ".metadata.json":
|
||||||
|
new_metadata_path = new_path
|
||||||
|
|
||||||
|
if metadata and new_metadata_path:
|
||||||
|
metadata["file_name"] = new_file_name
|
||||||
|
metadata["file_path"] = new_file_path
|
||||||
|
|
||||||
|
if metadata.get("preview_url"):
|
||||||
|
old_preview = str(metadata["preview_url"])
|
||||||
|
ext = self._get_multipart_ext(old_preview)
|
||||||
|
new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(
|
||||||
|
os.sep, "/"
|
||||||
|
)
|
||||||
|
metadata["preview_url"] = new_preview
|
||||||
|
|
||||||
|
await self._metadata_manager.save_metadata(new_file_path, metadata)
|
||||||
|
|
||||||
|
if metadata:
|
||||||
|
await self._scanner.update_single_model_cache(
|
||||||
|
file_path, new_file_path, metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
if hash_value and getattr(self._scanner, "model_type", "") == "lora":
|
||||||
|
recipe_scanner = await self._recipe_scanner_factory()
|
||||||
|
if recipe_scanner:
|
||||||
|
try:
|
||||||
|
await recipe_scanner.update_lora_filename_by_hash(
|
||||||
|
hash_value, new_file_name
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.error(
|
||||||
|
"Error updating recipe references for %s: %s",
|
||||||
|
file_path,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"new_file_path": new_file_path,
|
||||||
|
"new_preview_path": new_preview,
|
||||||
|
"renamed_files": renamed_files,
|
||||||
|
"reload_required": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_multipart_ext(filename: str) -> str:
|
||||||
|
"""Return the extension for files with compound suffixes."""
|
||||||
|
|
||||||
|
parts = filename.split(".")
|
||||||
|
if len(parts) == 3:
|
||||||
|
return "." + ".".join(parts[-2:])
|
||||||
|
if len(parts) >= 4:
|
||||||
|
return "." + ".".join(parts[-3:])
|
||||||
|
return os.path.splitext(filename)[1]
|
||||||
|
|
||||||
@@ -1,11 +1,41 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import json
|
import json
|
||||||
import aiosqlite
|
|
||||||
import logging
|
import logging
|
||||||
from bs4 import BeautifulSoup
|
from typing import Optional, Dict, Tuple, Any
|
||||||
from typing import Optional, Dict, Tuple
|
|
||||||
from .downloader import get_downloader
|
from .downloader import get_downloader
|
||||||
|
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
except ImportError as exc:
|
||||||
|
BeautifulSoup = None # type: ignore[assignment]
|
||||||
|
_BS4_IMPORT_ERROR = exc
|
||||||
|
else:
|
||||||
|
_BS4_IMPORT_ERROR = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import aiosqlite
|
||||||
|
except ImportError as exc:
|
||||||
|
aiosqlite = None # type: ignore[assignment]
|
||||||
|
_AIOSQLITE_IMPORT_ERROR = exc
|
||||||
|
else:
|
||||||
|
_AIOSQLITE_IMPORT_ERROR = None
|
||||||
|
|
||||||
|
def _require_beautifulsoup() -> Any:
|
||||||
|
if BeautifulSoup is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"BeautifulSoup (bs4) is required for CivArchiveModelMetadataProvider. "
|
||||||
|
"Install it with 'pip install beautifulsoup4'."
|
||||||
|
) from _BS4_IMPORT_ERROR
|
||||||
|
return BeautifulSoup
|
||||||
|
|
||||||
|
def _require_aiosqlite() -> Any:
|
||||||
|
if aiosqlite is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"aiosqlite is required for SQLiteModelMetadataProvider. "
|
||||||
|
"Install it with 'pip install aiosqlite'."
|
||||||
|
) from _AIOSQLITE_IMPORT_ERROR
|
||||||
|
return aiosqlite
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class ModelMetadataProvider(ABC):
|
class ModelMetadataProvider(ABC):
|
||||||
@@ -78,7 +108,8 @@ class CivArchiveModelMetadataProvider(ModelMetadataProvider):
|
|||||||
html_content = await response.text()
|
html_content = await response.text()
|
||||||
|
|
||||||
# Parse HTML to extract JSON data
|
# Parse HTML to extract JSON data
|
||||||
soup = BeautifulSoup(html_content, 'html.parser')
|
soup_parser = _require_beautifulsoup()
|
||||||
|
soup = soup_parser(html_content, 'html.parser')
|
||||||
script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
|
script_tag = soup.find('script', {'id': '__NEXT_DATA__', 'type': 'application/json'})
|
||||||
|
|
||||||
if not script_tag:
|
if not script_tag:
|
||||||
@@ -171,10 +202,11 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
|
|
||||||
def __init__(self, db_path: str):
|
def __init__(self, db_path: str):
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
|
self._aiosqlite = _require_aiosqlite()
|
||||||
|
|
||||||
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_by_hash(self, model_hash: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
"""Find model by hash value from SQLite database"""
|
"""Find model by hash value from SQLite database"""
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with self._aiosqlite.connect(self.db_path) as db:
|
||||||
# Look up in model_files table to get model_id and version_id
|
# Look up in model_files table to get model_id and version_id
|
||||||
query = """
|
query = """
|
||||||
SELECT model_id, version_id
|
SELECT model_id, version_id
|
||||||
@@ -182,7 +214,7 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
WHERE sha256 = ?
|
WHERE sha256 = ?
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
"""
|
"""
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = self._aiosqlite.Row
|
||||||
cursor = await db.execute(query, (model_hash.upper(),))
|
cursor = await db.execute(query, (model_hash.upper(),))
|
||||||
file_row = await cursor.fetchone()
|
file_row = await cursor.fetchone()
|
||||||
|
|
||||||
@@ -199,8 +231,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
|
|
||||||
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
async def get_model_versions(self, model_id: str) -> Optional[Dict]:
|
||||||
"""Get all versions of a model from SQLite database"""
|
"""Get all versions of a model from SQLite database"""
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with self._aiosqlite.connect(self.db_path) as db:
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = self._aiosqlite.Row
|
||||||
|
|
||||||
# First check if model exists
|
# First check if model exists
|
||||||
model_query = "SELECT * FROM models WHERE id = ?"
|
model_query = "SELECT * FROM models WHERE id = ?"
|
||||||
@@ -258,8 +290,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
if not model_id and not version_id:
|
if not model_id and not version_id:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with self._aiosqlite.connect(self.db_path) as db:
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = self._aiosqlite.Row
|
||||||
|
|
||||||
# Case 1: Only version_id is provided
|
# Case 1: Only version_id is provided
|
||||||
if model_id is None and version_id is not None:
|
if model_id is None and version_id is not None:
|
||||||
@@ -295,8 +327,8 @@ class SQLiteModelMetadataProvider(ModelMetadataProvider):
|
|||||||
|
|
||||||
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
async def get_model_version_info(self, version_id: str) -> Tuple[Optional[Dict], Optional[str]]:
|
||||||
"""Fetch model version metadata from SQLite database"""
|
"""Fetch model version metadata from SQLite database"""
|
||||||
async with aiosqlite.connect(self.db_path) as db:
|
async with self._aiosqlite.connect(self.db_path) as db:
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = self._aiosqlite.Row
|
||||||
|
|
||||||
# Get version details
|
# Get version details
|
||||||
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
version_query = "SELECT model_id FROM model_versions WHERE id = ?"
|
||||||
|
|||||||
196
py/services/model_query.py
Normal file
196
py/services/model_query.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable
|
||||||
|
|
||||||
|
from ..utils.constants import NSFW_LEVELS
|
||||||
|
from ..utils.utils import fuzzy_match as default_fuzzy_match
|
||||||
|
|
||||||
|
|
||||||
|
class SettingsProvider(Protocol):
|
||||||
|
"""Protocol describing the SettingsManager contract used by query helpers."""
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SortParams:
|
||||||
|
"""Normalized representation of sorting instructions."""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
order: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FilterCriteria:
|
||||||
|
"""Container for model list filtering options."""
|
||||||
|
|
||||||
|
folder: Optional[str] = None
|
||||||
|
base_models: Optional[Sequence[str]] = None
|
||||||
|
tags: Optional[Sequence[str]] = None
|
||||||
|
favorites_only: bool = False
|
||||||
|
search_options: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCacheRepository:
|
||||||
|
"""Adapter around scanner cache access and sort normalisation."""
|
||||||
|
|
||||||
|
def __init__(self, scanner) -> None:
|
||||||
|
self._scanner = scanner
|
||||||
|
|
||||||
|
async def get_cache(self):
|
||||||
|
"""Return the underlying cache instance from the scanner."""
|
||||||
|
return await self._scanner.get_cached_data()
|
||||||
|
|
||||||
|
async def fetch_sorted(self, params: SortParams) -> List[Dict[str, Any]]:
|
||||||
|
"""Fetch cached data pre-sorted according to ``params``."""
|
||||||
|
cache = await self.get_cache()
|
||||||
|
return await cache.get_sorted_data(params.key, params.order)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_sort(sort_by: str) -> SortParams:
|
||||||
|
"""Parse an incoming sort string into key/order primitives."""
|
||||||
|
if not sort_by:
|
||||||
|
return SortParams(key="name", order="asc")
|
||||||
|
|
||||||
|
if ":" in sort_by:
|
||||||
|
raw_key, raw_order = sort_by.split(":", 1)
|
||||||
|
sort_key = raw_key.strip().lower() or "name"
|
||||||
|
order = raw_order.strip().lower()
|
||||||
|
else:
|
||||||
|
sort_key = sort_by.strip().lower() or "name"
|
||||||
|
order = "asc"
|
||||||
|
|
||||||
|
if order not in ("asc", "desc"):
|
||||||
|
order = "asc"
|
||||||
|
|
||||||
|
return SortParams(key=sort_key, order=order)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelFilterSet:
|
||||||
|
"""Applies common filtering rules to the model collection."""
|
||||||
|
|
||||||
|
def __init__(self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None) -> None:
|
||||||
|
self._settings = settings
|
||||||
|
self._nsfw_levels = nsfw_levels or NSFW_LEVELS
|
||||||
|
|
||||||
|
def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]:
|
||||||
|
"""Return items that satisfy the provided criteria."""
|
||||||
|
items = list(data)
|
||||||
|
|
||||||
|
if self._settings.get("show_only_sfw", False):
|
||||||
|
threshold = self._nsfw_levels.get("R", 0)
|
||||||
|
items = [
|
||||||
|
item for item in items
|
||||||
|
if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold
|
||||||
|
]
|
||||||
|
|
||||||
|
if criteria.favorites_only:
|
||||||
|
items = [item for item in items if item.get("favorite", False)]
|
||||||
|
|
||||||
|
folder = criteria.folder
|
||||||
|
options = criteria.search_options or {}
|
||||||
|
recursive = bool(options.get("recursive", True))
|
||||||
|
if folder is not None:
|
||||||
|
if recursive:
|
||||||
|
if folder:
|
||||||
|
folder_with_sep = f"{folder}/"
|
||||||
|
items = [
|
||||||
|
item for item in items
|
||||||
|
if item.get("folder") == folder or item.get("folder", "").startswith(folder_with_sep)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
items = [item for item in items if item.get("folder") == folder]
|
||||||
|
|
||||||
|
base_models = criteria.base_models or []
|
||||||
|
if base_models:
|
||||||
|
base_model_set = set(base_models)
|
||||||
|
items = [item for item in items if item.get("base_model") in base_model_set]
|
||||||
|
|
||||||
|
tags = criteria.tags or []
|
||||||
|
if tags:
|
||||||
|
tag_set = set(tags)
|
||||||
|
items = [
|
||||||
|
item for item in items
|
||||||
|
if any(tag in tag_set for tag in item.get("tags", []))
|
||||||
|
]
|
||||||
|
|
||||||
|
return items
|
||||||
|
|
||||||
|
|
||||||
|
class SearchStrategy:
|
||||||
|
"""Encapsulates text and fuzzy matching behaviour for model queries."""
|
||||||
|
|
||||||
|
DEFAULT_OPTIONS: Dict[str, Any] = {
|
||||||
|
"filename": True,
|
||||||
|
"modelname": True,
|
||||||
|
"tags": False,
|
||||||
|
"recursive": True,
|
||||||
|
"creator": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None) -> None:
|
||||||
|
self._fuzzy_match = fuzzy_matcher or default_fuzzy_match
|
||||||
|
|
||||||
|
def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||||
|
"""Merge provided options with defaults without mutating input."""
|
||||||
|
normalized = dict(self.DEFAULT_OPTIONS)
|
||||||
|
if options:
|
||||||
|
normalized.update(options)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
def apply(
|
||||||
|
self,
|
||||||
|
data: Iterable[Dict[str, Any]],
|
||||||
|
search_term: str,
|
||||||
|
options: Dict[str, Any],
|
||||||
|
fuzzy: bool = False,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Return items matching the search term using the configured strategy."""
|
||||||
|
if not search_term:
|
||||||
|
return list(data)
|
||||||
|
|
||||||
|
search_lower = search_term.lower()
|
||||||
|
results: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
for item in data:
|
||||||
|
if options.get("filename", True):
|
||||||
|
candidate = item.get("file_name", "")
|
||||||
|
if self._matches(candidate, search_term, search_lower, fuzzy):
|
||||||
|
results.append(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if options.get("modelname", True):
|
||||||
|
candidate = item.get("model_name", "")
|
||||||
|
if self._matches(candidate, search_term, search_lower, fuzzy):
|
||||||
|
results.append(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if options.get("tags", False):
|
||||||
|
tags = item.get("tags", []) or []
|
||||||
|
if any(self._matches(tag, search_term, search_lower, fuzzy) for tag in tags):
|
||||||
|
results.append(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if options.get("creator", False):
|
||||||
|
creator_username = ""
|
||||||
|
civitai = item.get("civitai")
|
||||||
|
if isinstance(civitai, dict):
|
||||||
|
creator = civitai.get("creator")
|
||||||
|
if isinstance(creator, dict):
|
||||||
|
creator_username = creator.get("username", "")
|
||||||
|
if creator_username and self._matches(creator_username, search_term, search_lower, fuzzy):
|
||||||
|
results.append(item)
|
||||||
|
continue
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _matches(self, candidate: str, search_term: str, search_lower: str, fuzzy: bool) -> bool:
|
||||||
|
if not candidate:
|
||||||
|
return False
|
||||||
|
|
||||||
|
candidate_lower = candidate.lower()
|
||||||
|
if fuzzy:
|
||||||
|
return self._fuzzy_match(candidate, search_term)
|
||||||
|
return search_lower in candidate_lower
|
||||||
@@ -13,6 +13,7 @@ from ..utils.metadata_manager import MetadataManager
|
|||||||
from .model_cache import ModelCache
|
from .model_cache import ModelCache
|
||||||
from .model_hash_index import ModelHashIndex
|
from .model_hash_index import ModelHashIndex
|
||||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||||
|
from .model_lifecycle_service import delete_model_artifacts
|
||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
from .websocket_manager import ws_manager
|
from .websocket_manager import ws_manager
|
||||||
|
|
||||||
@@ -1040,10 +1041,8 @@ class ModelScanner:
|
|||||||
target_dir = os.path.dirname(file_path)
|
target_dir = os.path.dirname(file_path)
|
||||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
|
||||||
# Delete all associated files for the model
|
deleted_files = await delete_model_artifacts(
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
target_dir,
|
||||||
deleted_files = await ModelRouteUtils.delete_model_files(
|
|
||||||
target_dir,
|
|
||||||
file_name
|
file_name
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
168
py/services/preview_asset_service.py
Normal file
168
py/services/preview_asset_service.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Service for processing preview assets for models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Awaitable, Callable, Dict, Optional, Sequence
|
||||||
|
|
||||||
|
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PreviewAssetService:
|
||||||
|
"""Manage fetching and persisting preview assets."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
metadata_manager,
|
||||||
|
downloader_factory: Callable[[], Awaitable],
|
||||||
|
exif_utils,
|
||||||
|
) -> None:
|
||||||
|
self._metadata_manager = metadata_manager
|
||||||
|
self._downloader_factory = downloader_factory
|
||||||
|
self._exif_utils = exif_utils
|
||||||
|
|
||||||
|
async def ensure_preview_for_metadata(
|
||||||
|
self,
|
||||||
|
metadata_path: str,
|
||||||
|
local_metadata: Dict[str, object],
|
||||||
|
images: Sequence[Dict[str, object]] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Ensure preview assets exist for the supplied metadata entry."""
|
||||||
|
|
||||||
|
if local_metadata.get("preview_url") and os.path.exists(
|
||||||
|
str(local_metadata["preview_url"])
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
if not images:
|
||||||
|
return
|
||||||
|
|
||||||
|
first_preview = images[0]
|
||||||
|
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
|
||||||
|
preview_dir = os.path.dirname(metadata_path)
|
||||||
|
is_video = first_preview.get("type") == "video"
|
||||||
|
|
||||||
|
if is_video:
|
||||||
|
extension = ".mp4"
|
||||||
|
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||||
|
downloader = await self._downloader_factory()
|
||||||
|
success, result = await downloader.download_file(
|
||||||
|
first_preview["url"], preview_path, use_auth=False
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||||
|
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||||
|
else:
|
||||||
|
extension = ".webp"
|
||||||
|
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||||
|
downloader = await self._downloader_factory()
|
||||||
|
success, content, _headers = await downloader.download_to_memory(
|
||||||
|
first_preview["url"], use_auth=False
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
optimized_data, _ = self._exif_utils.optimize_image(
|
||||||
|
image_data=content,
|
||||||
|
target_width=CARD_PREVIEW_WIDTH,
|
||||||
|
format="webp",
|
||||||
|
quality=85,
|
||||||
|
preserve_metadata=False,
|
||||||
|
)
|
||||||
|
with open(preview_path, "wb") as handle:
|
||||||
|
handle.write(optimized_data)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive path
|
||||||
|
logger.error("Error optimizing preview image: %s", exc)
|
||||||
|
try:
|
||||||
|
with open(preview_path, "wb") as handle:
|
||||||
|
handle.write(content)
|
||||||
|
except Exception as save_exc:
|
||||||
|
logger.error("Error saving preview image: %s", save_exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||||
|
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||||
|
|
||||||
|
async def replace_preview(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
model_path: str,
|
||||||
|
preview_data: bytes,
|
||||||
|
content_type: str,
|
||||||
|
original_filename: Optional[str],
|
||||||
|
nsfw_level: int,
|
||||||
|
update_preview_in_cache: Callable[[str, str, int], Awaitable[bool]],
|
||||||
|
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||||
|
) -> Dict[str, object]:
|
||||||
|
"""Replace an existing preview asset for a model."""
|
||||||
|
|
||||||
|
base_name = os.path.splitext(os.path.basename(model_path))[0]
|
||||||
|
folder = os.path.dirname(model_path)
|
||||||
|
|
||||||
|
extension, optimized_data = await self._convert_preview(
|
||||||
|
preview_data, content_type, original_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
for ext in PREVIEW_EXTENSIONS:
|
||||||
|
existing_preview = os.path.join(folder, base_name + ext)
|
||||||
|
if os.path.exists(existing_preview):
|
||||||
|
try:
|
||||||
|
os.remove(existing_preview)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive path
|
||||||
|
logger.warning(
|
||||||
|
"Failed to delete existing preview %s: %s", existing_preview, exc
|
||||||
|
)
|
||||||
|
|
||||||
|
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, "/")
|
||||||
|
with open(preview_path, "wb") as handle:
|
||||||
|
handle.write(optimized_data)
|
||||||
|
|
||||||
|
metadata_path = os.path.splitext(model_path)[0] + ".metadata.json"
|
||||||
|
metadata = await metadata_loader(metadata_path)
|
||||||
|
metadata["preview_url"] = preview_path
|
||||||
|
metadata["preview_nsfw_level"] = nsfw_level
|
||||||
|
await self._metadata_manager.save_metadata(model_path, metadata)
|
||||||
|
|
||||||
|
await update_preview_in_cache(model_path, preview_path, nsfw_level)
|
||||||
|
|
||||||
|
return {"preview_path": preview_path, "preview_nsfw_level": nsfw_level}
|
||||||
|
|
||||||
|
async def _convert_preview(
|
||||||
|
self, data: bytes, content_type: str, original_filename: Optional[str]
|
||||||
|
) -> tuple[str, bytes]:
|
||||||
|
"""Convert preview bytes to the persisted representation."""
|
||||||
|
|
||||||
|
if content_type.startswith("video/"):
|
||||||
|
extension = self._resolve_video_extension(content_type, original_filename)
|
||||||
|
return extension, data
|
||||||
|
|
||||||
|
original_ext = (original_filename or "").lower()
|
||||||
|
if original_ext.endswith(".gif") or content_type.lower() == "image/gif":
|
||||||
|
return ".gif", data
|
||||||
|
|
||||||
|
optimized_data, _ = self._exif_utils.optimize_image(
|
||||||
|
image_data=data,
|
||||||
|
target_width=CARD_PREVIEW_WIDTH,
|
||||||
|
format="webp",
|
||||||
|
quality=85,
|
||||||
|
preserve_metadata=False,
|
||||||
|
)
|
||||||
|
return ".webp", optimized_data
|
||||||
|
|
||||||
|
def _resolve_video_extension(self, content_type: str, original_filename: Optional[str]) -> str:
|
||||||
|
"""Infer the best extension for a video preview."""
|
||||||
|
|
||||||
|
if original_filename:
|
||||||
|
extension = os.path.splitext(original_filename)[1].lower()
|
||||||
|
if extension in {".mp4", ".webm", ".mov", ".avi"}:
|
||||||
|
return extension
|
||||||
|
|
||||||
|
if "webm" in content_type:
|
||||||
|
return ".webm"
|
||||||
|
return ".mp4"
|
||||||
|
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Dict
|
from typing import Iterable, List, Dict, Optional
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
from natsort import natsorted
|
from natsort import natsorted
|
||||||
@@ -10,77 +10,115 @@ class RecipeCache:
|
|||||||
raw_data: List[Dict]
|
raw_data: List[Dict]
|
||||||
sorted_by_name: List[Dict]
|
sorted_by_name: List[Dict]
|
||||||
sorted_by_date: List[Dict]
|
sorted_by_date: List[Dict]
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
async def resort(self, name_only: bool = False):
|
async def resort(self, name_only: bool = False):
|
||||||
"""Resort all cached data views"""
|
"""Resort all cached data views"""
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self.sorted_by_name = natsorted(
|
self._resort_locked(name_only=name_only)
|
||||||
self.raw_data,
|
|
||||||
key=lambda x: x.get('title', '').lower() # Case-insensitive sort
|
async def update_recipe_metadata(self, recipe_id: str, metadata: Dict, *, resort: bool = True) -> bool:
|
||||||
)
|
|
||||||
if not name_only:
|
|
||||||
self.sorted_by_date = sorted(
|
|
||||||
self.raw_data,
|
|
||||||
key=itemgetter('created_date', 'file_path'),
|
|
||||||
reverse=True
|
|
||||||
)
|
|
||||||
|
|
||||||
async def update_recipe_metadata(self, recipe_id: str, metadata: Dict) -> bool:
|
|
||||||
"""Update metadata for a specific recipe in all cached data
|
"""Update metadata for a specific recipe in all cached data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
recipe_id: The ID of the recipe to update
|
recipe_id: The ID of the recipe to update
|
||||||
metadata: The new metadata
|
metadata: The new metadata
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the update was successful, False if the recipe wasn't found
|
bool: True if the update was successful, False if the recipe wasn't found
|
||||||
"""
|
"""
|
||||||
|
async with self._lock:
|
||||||
|
for item in self.raw_data:
|
||||||
|
if str(item.get('id')) == str(recipe_id):
|
||||||
|
item.update(metadata)
|
||||||
|
if resort:
|
||||||
|
self._resort_locked()
|
||||||
|
return True
|
||||||
|
return False # Recipe not found
|
||||||
|
|
||||||
|
async def add_recipe(self, recipe_data: Dict, *, resort: bool = False) -> None:
|
||||||
|
"""Add a new recipe to the cache."""
|
||||||
|
|
||||||
# Update in raw_data
|
|
||||||
for item in self.raw_data:
|
|
||||||
if item.get('id') == recipe_id:
|
|
||||||
item.update(metadata)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
return False # Recipe not found
|
|
||||||
|
|
||||||
# Resort to reflect changes
|
|
||||||
await self.resort()
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def add_recipe(self, recipe_data: Dict) -> None:
|
|
||||||
"""Add a new recipe to the cache
|
|
||||||
|
|
||||||
Args:
|
|
||||||
recipe_data: The recipe data to add
|
|
||||||
"""
|
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self.raw_data.append(recipe_data)
|
self.raw_data.append(recipe_data)
|
||||||
await self.resort()
|
if resort:
|
||||||
|
self._resort_locked()
|
||||||
|
|
||||||
|
async def remove_recipe(self, recipe_id: str, *, resort: bool = False) -> Optional[Dict]:
|
||||||
|
"""Remove a recipe from the cache by ID.
|
||||||
|
|
||||||
async def remove_recipe(self, recipe_id: str) -> bool:
|
|
||||||
"""Remove a recipe from the cache by ID
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
recipe_id: The ID of the recipe to remove
|
recipe_id: The ID of the recipe to remove
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: True if the recipe was found and removed, False otherwise
|
The removed recipe data if found, otherwise ``None``.
|
||||||
"""
|
"""
|
||||||
# Find the recipe in raw_data
|
|
||||||
recipe_index = next((i for i, recipe in enumerate(self.raw_data)
|
async with self._lock:
|
||||||
if recipe.get('id') == recipe_id), None)
|
for index, recipe in enumerate(self.raw_data):
|
||||||
|
if str(recipe.get('id')) == str(recipe_id):
|
||||||
if recipe_index is None:
|
removed = self.raw_data.pop(index)
|
||||||
return False
|
if resort:
|
||||||
|
self._resort_locked()
|
||||||
# Remove from raw_data
|
return removed
|
||||||
self.raw_data.pop(recipe_index)
|
return None
|
||||||
|
|
||||||
# Resort to update sorted lists
|
async def bulk_remove(self, recipe_ids: Iterable[str], *, resort: bool = False) -> List[Dict]:
|
||||||
await self.resort()
|
"""Remove multiple recipes from the cache."""
|
||||||
|
|
||||||
return True
|
id_set = {str(recipe_id) for recipe_id in recipe_ids}
|
||||||
|
if not id_set:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
removed = [item for item in self.raw_data if str(item.get('id')) in id_set]
|
||||||
|
if not removed:
|
||||||
|
return []
|
||||||
|
|
||||||
|
self.raw_data = [item for item in self.raw_data if str(item.get('id')) not in id_set]
|
||||||
|
if resort:
|
||||||
|
self._resort_locked()
|
||||||
|
return removed
|
||||||
|
|
||||||
|
async def replace_recipe(self, recipe_id: str, new_data: Dict, *, resort: bool = False) -> bool:
|
||||||
|
"""Replace cached data for a recipe."""
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
for index, recipe in enumerate(self.raw_data):
|
||||||
|
if str(recipe.get('id')) == str(recipe_id):
|
||||||
|
self.raw_data[index] = new_data
|
||||||
|
if resort:
|
||||||
|
self._resort_locked()
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_recipe(self, recipe_id: str) -> Optional[Dict]:
|
||||||
|
"""Return a shallow copy of a cached recipe."""
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
for recipe in self.raw_data:
|
||||||
|
if str(recipe.get('id')) == str(recipe_id):
|
||||||
|
return dict(recipe)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def snapshot(self) -> List[Dict]:
|
||||||
|
"""Return a copy of all cached recipes."""
|
||||||
|
|
||||||
|
async with self._lock:
|
||||||
|
return [dict(item) for item in self.raw_data]
|
||||||
|
|
||||||
|
def _resort_locked(self, *, name_only: bool = False) -> None:
|
||||||
|
"""Sort cached views. Caller must hold ``_lock``."""
|
||||||
|
|
||||||
|
self.sorted_by_name = natsorted(
|
||||||
|
self.raw_data,
|
||||||
|
key=lambda x: x.get('title', '').lower()
|
||||||
|
)
|
||||||
|
if not name_only:
|
||||||
|
self.sorted_by_date = sorted(
|
||||||
|
self.raw_data,
|
||||||
|
key=itemgetter('created_date', 'file_path'),
|
||||||
|
reverse=True
|
||||||
|
)
|
||||||
@@ -3,13 +3,14 @@ import logging
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import List, Dict, Optional, Any, Tuple
|
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from .recipe_cache import RecipeCache
|
from .recipe_cache import RecipeCache
|
||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
from .lora_scanner import LoraScanner
|
from .lora_scanner import LoraScanner
|
||||||
from .metadata_service import get_default_metadata_provider
|
from .metadata_service import get_default_metadata_provider
|
||||||
from ..utils.utils import fuzzy_match
|
from .recipes.errors import RecipeNotFoundError
|
||||||
|
from ..utils.utils import calculate_recipe_fingerprint, fuzzy_match
|
||||||
from natsort import natsorted
|
from natsort import natsorted
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
@@ -46,6 +47,8 @@ class RecipeScanner:
|
|||||||
self._initialization_lock = asyncio.Lock()
|
self._initialization_lock = asyncio.Lock()
|
||||||
self._initialization_task: Optional[asyncio.Task] = None
|
self._initialization_task: Optional[asyncio.Task] = None
|
||||||
self._is_initializing = False
|
self._is_initializing = False
|
||||||
|
self._mutation_lock = asyncio.Lock()
|
||||||
|
self._resort_tasks: Set[asyncio.Task] = set()
|
||||||
if lora_scanner:
|
if lora_scanner:
|
||||||
self._lora_scanner = lora_scanner
|
self._lora_scanner = lora_scanner
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
@@ -191,6 +194,22 @@ class RecipeScanner:
|
|||||||
# Clean up the event loop
|
# Clean up the event loop
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
|
def _schedule_resort(self, *, name_only: bool = False) -> None:
|
||||||
|
"""Schedule a background resort of the recipe cache."""
|
||||||
|
|
||||||
|
if not self._cache:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _resort_wrapper() -> None:
|
||||||
|
try:
|
||||||
|
await self._cache.resort(name_only=name_only)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.error("Recipe Scanner: error resorting cache: %s", exc, exc_info=True)
|
||||||
|
|
||||||
|
task = asyncio.create_task(_resort_wrapper())
|
||||||
|
self._resort_tasks.add(task)
|
||||||
|
task.add_done_callback(lambda finished: self._resort_tasks.discard(finished))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def recipes_dir(self) -> str:
|
def recipes_dir(self) -> str:
|
||||||
"""Get path to recipes directory"""
|
"""Get path to recipes directory"""
|
||||||
@@ -255,7 +274,45 @@ class RecipeScanner:
|
|||||||
|
|
||||||
# Return the cache (may be empty or partially initialized)
|
# Return the cache (may be empty or partially initialized)
|
||||||
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
|
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
|
||||||
|
|
||||||
|
async def refresh_cache(self, force: bool = False) -> RecipeCache:
|
||||||
|
"""Public helper to refresh or return the recipe cache."""
|
||||||
|
|
||||||
|
return await self.get_cached_data(force_refresh=force)
|
||||||
|
|
||||||
|
async def add_recipe(self, recipe_data: Dict[str, Any]) -> None:
|
||||||
|
"""Add a recipe to the in-memory cache."""
|
||||||
|
|
||||||
|
if not recipe_data:
|
||||||
|
return
|
||||||
|
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
await cache.add_recipe(recipe_data, resort=False)
|
||||||
|
self._schedule_resort()
|
||||||
|
|
||||||
|
async def remove_recipe(self, recipe_id: str) -> bool:
|
||||||
|
"""Remove a recipe from the cache by ID."""
|
||||||
|
|
||||||
|
if not recipe_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
removed = await cache.remove_recipe(recipe_id, resort=False)
|
||||||
|
if removed is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._schedule_resort()
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def bulk_remove(self, recipe_ids: Iterable[str]) -> int:
|
||||||
|
"""Remove multiple recipes from the cache."""
|
||||||
|
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
removed = await cache.bulk_remove(recipe_ids, resort=False)
|
||||||
|
if removed:
|
||||||
|
self._schedule_resort()
|
||||||
|
return len(removed)
|
||||||
|
|
||||||
async def scan_all_recipes(self) -> List[Dict]:
|
async def scan_all_recipes(self) -> List[Dict]:
|
||||||
"""Scan all recipe JSON files and return metadata"""
|
"""Scan all recipe JSON files and return metadata"""
|
||||||
recipes = []
|
recipes = []
|
||||||
@@ -326,7 +383,6 @@ class RecipeScanner:
|
|||||||
|
|
||||||
# Calculate and update fingerprint if missing
|
# Calculate and update fingerprint if missing
|
||||||
if 'loras' in recipe_data and 'fingerprint' not in recipe_data:
|
if 'loras' in recipe_data and 'fingerprint' not in recipe_data:
|
||||||
from ..utils.utils import calculate_recipe_fingerprint
|
|
||||||
fingerprint = calculate_recipe_fingerprint(recipe_data['loras'])
|
fingerprint = calculate_recipe_fingerprint(recipe_data['loras'])
|
||||||
recipe_data['fingerprint'] = fingerprint
|
recipe_data['fingerprint'] = fingerprint
|
||||||
|
|
||||||
@@ -497,9 +553,36 @@ class RecipeScanner:
|
|||||||
logger.error(f"Error getting base model for lora: {e}")
|
logger.error(f"Error getting base model for lora: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _enrich_lora_entry(self, lora: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Populate convenience fields for a LoRA entry."""
|
||||||
|
|
||||||
|
if not lora or not self._lora_scanner:
|
||||||
|
return lora
|
||||||
|
|
||||||
|
hash_value = (lora.get('hash') or '').lower()
|
||||||
|
if not hash_value:
|
||||||
|
return lora
|
||||||
|
|
||||||
|
try:
|
||||||
|
lora['inLibrary'] = self._lora_scanner.has_hash(hash_value)
|
||||||
|
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(hash_value)
|
||||||
|
lora['localPath'] = self._lora_scanner.get_path_by_hash(hash_value)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.debug("Error enriching lora entry %s: %s", hash_value, exc)
|
||||||
|
|
||||||
|
return lora
|
||||||
|
|
||||||
|
async def get_local_lora(self, name: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Lookup a local LoRA model by name."""
|
||||||
|
|
||||||
|
if not self._lora_scanner or not name:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return await self._lora_scanner.get_model_info_by_name(name)
|
||||||
|
|
||||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True):
|
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True):
|
||||||
"""Get paginated and filtered recipe data
|
"""Get paginated and filtered recipe data
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
page: Current page number (1-based)
|
page: Current page number (1-based)
|
||||||
page_size: Number of items per page
|
page_size: Number of items per page
|
||||||
@@ -598,16 +681,12 @@ class RecipeScanner:
|
|||||||
|
|
||||||
# Get paginated items
|
# Get paginated items
|
||||||
paginated_items = filtered_data[start_idx:end_idx]
|
paginated_items = filtered_data[start_idx:end_idx]
|
||||||
|
|
||||||
# Add inLibrary information for each lora
|
# Add inLibrary information for each lora
|
||||||
for item in paginated_items:
|
for item in paginated_items:
|
||||||
if 'loras' in item:
|
if 'loras' in item:
|
||||||
for lora in item['loras']:
|
item['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in item['loras']]
|
||||||
if 'hash' in lora and lora['hash']:
|
|
||||||
lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower())
|
|
||||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
|
|
||||||
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower())
|
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
'items': paginated_items,
|
'items': paginated_items,
|
||||||
'total': total_items,
|
'total': total_items,
|
||||||
@@ -653,13 +732,8 @@ class RecipeScanner:
|
|||||||
|
|
||||||
# Add lora metadata
|
# Add lora metadata
|
||||||
if 'loras' in formatted_recipe:
|
if 'loras' in formatted_recipe:
|
||||||
for lora in formatted_recipe['loras']:
|
formatted_recipe['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in formatted_recipe['loras']]
|
||||||
if 'hash' in lora and lora['hash']:
|
|
||||||
lora_hash = lora['hash'].lower()
|
|
||||||
lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash)
|
|
||||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash)
|
|
||||||
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash)
|
|
||||||
|
|
||||||
return formatted_recipe
|
return formatted_recipe
|
||||||
|
|
||||||
def _format_file_url(self, file_path: str) -> str:
|
def _format_file_url(self, file_path: str) -> str:
|
||||||
@@ -717,26 +791,159 @@ class RecipeScanner:
|
|||||||
# Save updated recipe
|
# Save updated recipe
|
||||||
with open(recipe_json_path, 'w', encoding='utf-8') as f:
|
with open(recipe_json_path, 'w', encoding='utf-8') as f:
|
||||||
json.dump(recipe_data, f, indent=4, ensure_ascii=False)
|
json.dump(recipe_data, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
# Update the cache if it exists
|
# Update the cache if it exists
|
||||||
if self._cache is not None:
|
if self._cache is not None:
|
||||||
await self._cache.update_recipe_metadata(recipe_id, metadata)
|
await self._cache.update_recipe_metadata(recipe_id, metadata, resort=False)
|
||||||
|
self._schedule_resort()
|
||||||
|
|
||||||
# If the recipe has an image, update its EXIF metadata
|
# If the recipe has an image, update its EXIF metadata
|
||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
image_path = recipe_data.get('file_path')
|
image_path = recipe_data.get('file_path')
|
||||||
if image_path and os.path.exists(image_path):
|
if image_path and os.path.exists(image_path):
|
||||||
ExifUtils.append_recipe_metadata(image_path, recipe_data)
|
ExifUtils.append_recipe_metadata(image_path, recipe_data)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import logging
|
import logging
|
||||||
logging.getLogger(__name__).error(f"Error updating recipe metadata: {e}", exc_info=True)
|
logging.getLogger(__name__).error(f"Error updating recipe metadata: {e}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def update_lora_entry(
|
||||||
|
self,
|
||||||
|
recipe_id: str,
|
||||||
|
lora_index: int,
|
||||||
|
*,
|
||||||
|
target_name: str,
|
||||||
|
target_lora: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||||
|
"""Update a specific LoRA entry within a recipe.
|
||||||
|
|
||||||
|
Returns the updated recipe data and the refreshed LoRA metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if target_name is None:
|
||||||
|
raise ValueError("target_name must be provided")
|
||||||
|
|
||||||
|
recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
|
||||||
|
if not os.path.exists(recipe_json_path):
|
||||||
|
raise RecipeNotFoundError("Recipe not found")
|
||||||
|
|
||||||
|
async with self._mutation_lock:
|
||||||
|
with open(recipe_json_path, 'r', encoding='utf-8') as file_obj:
|
||||||
|
recipe_data = json.load(file_obj)
|
||||||
|
|
||||||
|
loras = recipe_data.get('loras', [])
|
||||||
|
if lora_index >= len(loras):
|
||||||
|
raise RecipeNotFoundError("LoRA index out of range in recipe")
|
||||||
|
|
||||||
|
lora_entry = loras[lora_index]
|
||||||
|
lora_entry['isDeleted'] = False
|
||||||
|
lora_entry['exclude'] = False
|
||||||
|
lora_entry['file_name'] = target_name
|
||||||
|
|
||||||
|
if target_lora is not None:
|
||||||
|
sha_value = target_lora.get('sha256') or target_lora.get('sha')
|
||||||
|
if sha_value:
|
||||||
|
lora_entry['hash'] = sha_value.lower()
|
||||||
|
|
||||||
|
civitai_info = target_lora.get('civitai') or {}
|
||||||
|
if civitai_info:
|
||||||
|
lora_entry['modelName'] = civitai_info.get('model', {}).get('name', '')
|
||||||
|
lora_entry['modelVersionName'] = civitai_info.get('name', '')
|
||||||
|
lora_entry['modelVersionId'] = civitai_info.get('id')
|
||||||
|
|
||||||
|
recipe_data['fingerprint'] = calculate_recipe_fingerprint(recipe_data.get('loras', []))
|
||||||
|
recipe_data['modified'] = time.time()
|
||||||
|
|
||||||
|
with open(recipe_json_path, 'w', encoding='utf-8') as file_obj:
|
||||||
|
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
replaced = await cache.replace_recipe(recipe_id, recipe_data, resort=False)
|
||||||
|
if not replaced:
|
||||||
|
await cache.add_recipe(recipe_data, resort=False)
|
||||||
|
self._schedule_resort()
|
||||||
|
|
||||||
|
updated_lora = dict(lora_entry)
|
||||||
|
if target_lora is not None:
|
||||||
|
preview_url = target_lora.get('preview_url')
|
||||||
|
if preview_url:
|
||||||
|
updated_lora['preview_url'] = config.get_preview_static_url(preview_url)
|
||||||
|
if target_lora.get('file_path'):
|
||||||
|
updated_lora['localPath'] = target_lora['file_path']
|
||||||
|
|
||||||
|
updated_lora = self._enrich_lora_entry(updated_lora)
|
||||||
|
return recipe_data, updated_lora
|
||||||
|
|
||||||
|
async def get_recipes_for_lora(self, lora_hash: str) -> List[Dict[str, Any]]:
|
||||||
|
"""Return recipes that reference a given LoRA hash."""
|
||||||
|
|
||||||
|
if not lora_hash:
|
||||||
|
return []
|
||||||
|
|
||||||
|
normalized_hash = lora_hash.lower()
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
matching_recipes: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
for recipe in cache.raw_data:
|
||||||
|
loras = recipe.get('loras', [])
|
||||||
|
if any((entry.get('hash') or '').lower() == normalized_hash for entry in loras):
|
||||||
|
recipe_copy = {**recipe}
|
||||||
|
recipe_copy['loras'] = [self._enrich_lora_entry(dict(entry)) for entry in loras]
|
||||||
|
recipe_copy['file_url'] = self._format_file_url(recipe.get('file_path'))
|
||||||
|
matching_recipes.append(recipe_copy)
|
||||||
|
|
||||||
|
return matching_recipes
|
||||||
|
|
||||||
|
async def get_recipe_syntax_tokens(self, recipe_id: str) -> List[str]:
|
||||||
|
"""Build LoRA syntax tokens for a recipe."""
|
||||||
|
|
||||||
|
cache = await self.get_cached_data()
|
||||||
|
recipe = await cache.get_recipe(recipe_id)
|
||||||
|
if recipe is None:
|
||||||
|
raise RecipeNotFoundError("Recipe not found")
|
||||||
|
|
||||||
|
loras = recipe.get('loras', [])
|
||||||
|
if not loras:
|
||||||
|
return []
|
||||||
|
|
||||||
|
lora_cache = None
|
||||||
|
if self._lora_scanner is not None:
|
||||||
|
lora_cache = await self._lora_scanner.get_cached_data()
|
||||||
|
|
||||||
|
syntax_parts: List[str] = []
|
||||||
|
for lora in loras:
|
||||||
|
if lora.get('isDeleted', False):
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_name = None
|
||||||
|
hash_value = (lora.get('hash') or '').lower()
|
||||||
|
if hash_value and self._lora_scanner is not None and hasattr(self._lora_scanner, '_hash_index'):
|
||||||
|
file_path = self._lora_scanner._hash_index.get_path(hash_value)
|
||||||
|
if file_path:
|
||||||
|
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
|
||||||
|
if not file_name and lora.get('modelVersionId') and lora_cache is not None:
|
||||||
|
for cached_lora in getattr(lora_cache, 'raw_data', []):
|
||||||
|
civitai_info = cached_lora.get('civitai')
|
||||||
|
if civitai_info and civitai_info.get('id') == lora.get('modelVersionId'):
|
||||||
|
cached_path = cached_lora.get('path') or cached_lora.get('file_path')
|
||||||
|
if cached_path:
|
||||||
|
file_name = os.path.splitext(os.path.basename(cached_path))[0]
|
||||||
|
break
|
||||||
|
|
||||||
|
if not file_name:
|
||||||
|
file_name = lora.get('file_name', 'unknown-lora')
|
||||||
|
|
||||||
|
strength = lora.get('strength', 1.0)
|
||||||
|
syntax_parts.append(f"<lora:{file_name}:{strength}>")
|
||||||
|
|
||||||
|
return syntax_parts
|
||||||
|
|
||||||
async def update_lora_filename_by_hash(self, hash_value: str, new_file_name: str) -> Tuple[int, int]:
|
async def update_lora_filename_by_hash(self, hash_value: str, new_file_name: str) -> Tuple[int, int]:
|
||||||
"""Update file_name in all recipes that contain a LoRA with the specified hash.
|
"""Update file_name in all recipes that contain a LoRA with the specified hash.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hash_value: The SHA256 hash value of the LoRA
|
hash_value: The SHA256 hash value of the LoRA
|
||||||
new_file_name: The new file_name to set
|
new_file_name: The new file_name to set
|
||||||
|
|||||||
23
py/services/recipes/__init__.py
Normal file
23
py/services/recipes/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
"""Recipe service layer implementations."""
|
||||||
|
|
||||||
|
from .analysis_service import RecipeAnalysisService
|
||||||
|
from .persistence_service import RecipePersistenceService
|
||||||
|
from .sharing_service import RecipeSharingService
|
||||||
|
from .errors import (
|
||||||
|
RecipeServiceError,
|
||||||
|
RecipeValidationError,
|
||||||
|
RecipeNotFoundError,
|
||||||
|
RecipeDownloadError,
|
||||||
|
RecipeConflictError,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"RecipeAnalysisService",
|
||||||
|
"RecipePersistenceService",
|
||||||
|
"RecipeSharingService",
|
||||||
|
"RecipeServiceError",
|
||||||
|
"RecipeValidationError",
|
||||||
|
"RecipeNotFoundError",
|
||||||
|
"RecipeDownloadError",
|
||||||
|
"RecipeConflictError",
|
||||||
|
]
|
||||||
289
py/services/recipes/analysis_service.py
Normal file
289
py/services/recipes/analysis_service.py
Normal file
@@ -0,0 +1,289 @@
|
|||||||
|
"""Services responsible for recipe metadata analysis."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import tempfile
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from ...utils.utils import calculate_recipe_fingerprint
|
||||||
|
from .errors import (
|
||||||
|
RecipeDownloadError,
|
||||||
|
RecipeNotFoundError,
|
||||||
|
RecipeServiceError,
|
||||||
|
RecipeValidationError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AnalysisResult:
|
||||||
|
"""Return payload from analysis operations."""
|
||||||
|
|
||||||
|
payload: dict[str, Any]
|
||||||
|
status: int = 200
|
||||||
|
|
||||||
|
|
||||||
|
class RecipeAnalysisService:
|
||||||
|
"""Extract recipe metadata from various image sources."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
exif_utils,
|
||||||
|
recipe_parser_factory,
|
||||||
|
downloader_factory: Callable[[], Any],
|
||||||
|
metadata_collector: Optional[Callable[[], Any]] = None,
|
||||||
|
metadata_processor_cls: Optional[type] = None,
|
||||||
|
metadata_registry_cls: Optional[type] = None,
|
||||||
|
standalone_mode: bool = False,
|
||||||
|
logger,
|
||||||
|
) -> None:
|
||||||
|
self._exif_utils = exif_utils
|
||||||
|
self._recipe_parser_factory = recipe_parser_factory
|
||||||
|
self._downloader_factory = downloader_factory
|
||||||
|
self._metadata_collector = metadata_collector
|
||||||
|
self._metadata_processor_cls = metadata_processor_cls
|
||||||
|
self._metadata_registry_cls = metadata_registry_cls
|
||||||
|
self._standalone_mode = standalone_mode
|
||||||
|
self._logger = logger
|
||||||
|
|
||||||
|
async def analyze_uploaded_image(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
image_bytes: bytes | None,
|
||||||
|
recipe_scanner,
|
||||||
|
) -> AnalysisResult:
|
||||||
|
"""Analyze an uploaded image payload."""
|
||||||
|
|
||||||
|
if not image_bytes:
|
||||||
|
raise RecipeValidationError("No image data provided")
|
||||||
|
|
||||||
|
temp_path = self._write_temp_file(image_bytes)
|
||||||
|
try:
|
||||||
|
metadata = self._exif_utils.extract_image_metadata(temp_path)
|
||||||
|
if not metadata:
|
||||||
|
return AnalysisResult({"error": "No metadata found in this image", "loras": []})
|
||||||
|
|
||||||
|
return await self._parse_metadata(
|
||||||
|
metadata,
|
||||||
|
recipe_scanner=recipe_scanner,
|
||||||
|
image_path=None,
|
||||||
|
include_image_base64=False,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._safe_cleanup(temp_path)
|
||||||
|
|
||||||
|
async def analyze_remote_image(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
url: str | None,
|
||||||
|
recipe_scanner,
|
||||||
|
civitai_client,
|
||||||
|
) -> AnalysisResult:
|
||||||
|
"""Analyze an image accessible via URL, including Civitai integration."""
|
||||||
|
|
||||||
|
if not url:
|
||||||
|
raise RecipeValidationError("No URL provided")
|
||||||
|
|
||||||
|
if civitai_client is None:
|
||||||
|
raise RecipeServiceError("Civitai client unavailable")
|
||||||
|
|
||||||
|
temp_path = self._create_temp_path()
|
||||||
|
metadata: Optional[dict[str, Any]] = None
|
||||||
|
try:
|
||||||
|
civitai_match = re.match(r"https://civitai\.com/images/(\d+)", url)
|
||||||
|
if civitai_match:
|
||||||
|
image_info = await civitai_client.get_image_info(civitai_match.group(1))
|
||||||
|
if not image_info:
|
||||||
|
raise RecipeDownloadError("Failed to fetch image information from Civitai")
|
||||||
|
image_url = image_info.get("url")
|
||||||
|
if not image_url:
|
||||||
|
raise RecipeDownloadError("No image URL found in Civitai response")
|
||||||
|
await self._download_image(image_url, temp_path)
|
||||||
|
metadata = image_info.get("meta") if "meta" in image_info else None
|
||||||
|
else:
|
||||||
|
await self._download_image(url, temp_path)
|
||||||
|
|
||||||
|
if metadata is None:
|
||||||
|
metadata = self._exif_utils.extract_image_metadata(temp_path)
|
||||||
|
|
||||||
|
if not metadata:
|
||||||
|
return self._metadata_not_found_response(temp_path)
|
||||||
|
|
||||||
|
return await self._parse_metadata(
|
||||||
|
metadata,
|
||||||
|
recipe_scanner=recipe_scanner,
|
||||||
|
image_path=temp_path,
|
||||||
|
include_image_base64=True,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._safe_cleanup(temp_path)
|
||||||
|
|
||||||
|
async def analyze_local_image(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
file_path: str | None,
|
||||||
|
recipe_scanner,
|
||||||
|
) -> AnalysisResult:
|
||||||
|
"""Analyze a file already present on disk."""
|
||||||
|
|
||||||
|
if not file_path:
|
||||||
|
raise RecipeValidationError("No file path provided")
|
||||||
|
|
||||||
|
normalized_path = os.path.normpath(file_path.strip('"').strip("'"))
|
||||||
|
if not os.path.isfile(normalized_path):
|
||||||
|
raise RecipeNotFoundError("File not found")
|
||||||
|
|
||||||
|
metadata = self._exif_utils.extract_image_metadata(normalized_path)
|
||||||
|
if not metadata:
|
||||||
|
return self._metadata_not_found_response(normalized_path)
|
||||||
|
|
||||||
|
return await self._parse_metadata(
|
||||||
|
metadata,
|
||||||
|
recipe_scanner=recipe_scanner,
|
||||||
|
image_path=normalized_path,
|
||||||
|
include_image_base64=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def analyze_widget_metadata(self, *, recipe_scanner) -> AnalysisResult:
|
||||||
|
"""Analyse the most recent generation metadata for widget saves."""
|
||||||
|
|
||||||
|
if self._metadata_collector is None or self._metadata_processor_cls is None:
|
||||||
|
raise RecipeValidationError("Metadata collection not available")
|
||||||
|
|
||||||
|
raw_metadata = self._metadata_collector()
|
||||||
|
metadata_dict = self._metadata_processor_cls.to_dict(raw_metadata)
|
||||||
|
if not metadata_dict:
|
||||||
|
raise RecipeValidationError("No generation metadata found")
|
||||||
|
|
||||||
|
latest_image = None
|
||||||
|
if not self._standalone_mode and self._metadata_registry_cls is not None:
|
||||||
|
metadata_registry = self._metadata_registry_cls()
|
||||||
|
latest_image = metadata_registry.get_first_decoded_image()
|
||||||
|
|
||||||
|
if latest_image is None:
|
||||||
|
raise RecipeValidationError(
|
||||||
|
"No recent images found to use for recipe. Try generating an image first."
|
||||||
|
)
|
||||||
|
|
||||||
|
image_bytes = self._convert_tensor_to_png_bytes(latest_image)
|
||||||
|
if image_bytes is None:
|
||||||
|
raise RecipeValidationError("Cannot handle this data shape from metadata registry")
|
||||||
|
|
||||||
|
return AnalysisResult(
|
||||||
|
{
|
||||||
|
"metadata": metadata_dict,
|
||||||
|
"image_bytes": image_bytes,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Internal helpers -------------------------------------------------
|
||||||
|
|
||||||
|
async def _parse_metadata(
|
||||||
|
self,
|
||||||
|
metadata: dict[str, Any],
|
||||||
|
*,
|
||||||
|
recipe_scanner,
|
||||||
|
image_path: Optional[str],
|
||||||
|
include_image_base64: bool,
|
||||||
|
) -> AnalysisResult:
|
||||||
|
parser = self._recipe_parser_factory.create_parser(metadata)
|
||||||
|
if parser is None:
|
||||||
|
payload = {"error": "No parser found for this image", "loras": []}
|
||||||
|
if include_image_base64 and image_path:
|
||||||
|
payload["image_base64"] = self._encode_file(image_path)
|
||||||
|
return AnalysisResult(payload)
|
||||||
|
|
||||||
|
result = await parser.parse_metadata(metadata, recipe_scanner=recipe_scanner)
|
||||||
|
|
||||||
|
if include_image_base64 and image_path:
|
||||||
|
result["image_base64"] = self._encode_file(image_path)
|
||||||
|
|
||||||
|
if "error" in result and not result.get("loras"):
|
||||||
|
return AnalysisResult(result)
|
||||||
|
|
||||||
|
fingerprint = calculate_recipe_fingerprint(result.get("loras", []))
|
||||||
|
result["fingerprint"] = fingerprint
|
||||||
|
|
||||||
|
matching_recipes: list[str] = []
|
||||||
|
if fingerprint:
|
||||||
|
matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint)
|
||||||
|
result["matching_recipes"] = matching_recipes
|
||||||
|
|
||||||
|
return AnalysisResult(result)
|
||||||
|
|
||||||
|
async def _download_image(self, url: str, temp_path: str) -> None:
|
||||||
|
downloader = await self._downloader_factory()
|
||||||
|
success, result = await downloader.download_file(url, temp_path, use_auth=False)
|
||||||
|
if not success:
|
||||||
|
raise RecipeDownloadError(f"Failed to download image from URL: {result}")
|
||||||
|
|
||||||
|
def _metadata_not_found_response(self, path: str) -> AnalysisResult:
|
||||||
|
payload: dict[str, Any] = {"error": "No metadata found in this image", "loras": []}
|
||||||
|
if os.path.exists(path):
|
||||||
|
payload["image_base64"] = self._encode_file(path)
|
||||||
|
return AnalysisResult(payload)
|
||||||
|
|
||||||
|
def _write_temp_file(self, data: bytes) -> str:
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
||||||
|
temp_file.write(data)
|
||||||
|
return temp_file.name
|
||||||
|
|
||||||
|
def _create_temp_path(self) -> str:
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
||||||
|
return temp_file.name
|
||||||
|
|
||||||
|
def _safe_cleanup(self, path: Optional[str]) -> None:
|
||||||
|
if path and os.path.exists(path):
|
||||||
|
try:
|
||||||
|
os.unlink(path)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
self._logger.error("Error deleting temporary file: %s", exc)
|
||||||
|
|
||||||
|
def _encode_file(self, path: str) -> str:
|
||||||
|
with open(path, "rb") as image_file:
|
||||||
|
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||||
|
|
||||||
|
    def _convert_tensor_to_png_bytes(self, latest_image: Any) -> Optional[bytes]:
        """Convert a decoded image tensor into PNG bytes.

        Returns the encoded PNG, or ``None`` when the input cannot be handled
        (unsupported shape/channel count, or any conversion error). All
        failures are logged rather than raised.
        """
        try:
            # Some producers hand back a tuple; use the first entry if present.
            if isinstance(latest_image, tuple):
                tensor_image = latest_image[0] if latest_image else None
                if tensor_image is None:
                    return None
            else:
                tensor_image = latest_image

            if hasattr(tensor_image, "shape"):
                self._logger.debug(
                    "Tensor shape: %s, dtype: %s", tensor_image.shape, getattr(tensor_image, "dtype", None)
                )

            # Imported lazily; if torch is unavailable the ImportError lands in
            # the except-branch below and the method returns None.
            import torch  # type: ignore[import-not-found]

            if isinstance(tensor_image, torch.Tensor):
                image_np = tensor_image.cpu().numpy()
            else:
                image_np = np.array(tensor_image)

            # Strip leading batch-like dimensions until at most 3 remain
            # (assumes any extra axes are batch dims of interest index 0).
            while len(image_np.shape) > 3:
                image_np = image_np[0]

            # Floats that look normalised to [0, 1] are rescaled to uint8;
            # other dtypes pass through unchanged.
            if image_np.dtype in (np.float32, np.float64) and image_np.max() <= 1.0:
                image_np = (image_np * 255).astype(np.uint8)

            # Only 3-channel (H, W, 3) arrays are encoded; anything else
            # (grayscale, RGBA, ...) falls through and returns None.
            if len(image_np.shape) == 3 and image_np.shape[2] == 3:
                pil_image = Image.fromarray(image_np)
                img_byte_arr = io.BytesIO()
                pil_image.save(img_byte_arr, format="PNG")
                return img_byte_arr.getvalue()
        except Exception as exc:  # pragma: no cover - defensive logging path
            self._logger.error("Error processing image data: %s", exc, exc_info=True)
            return None

        return None
|
||||||
22
py/services/recipes/errors.py
Normal file
22
py/services/recipes/errors.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""Shared exceptions for recipe services."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
# Catching RecipeServiceError handles every recipe-service error type below.
class RecipeServiceError(Exception):
    """Base exception for recipe service failures."""


class RecipeValidationError(RecipeServiceError):
    """Raised when a request payload fails validation."""


class RecipeNotFoundError(RecipeServiceError):
    """Raised when a recipe resource cannot be located."""


class RecipeDownloadError(RecipeServiceError):
    """Raised when remote recipe assets cannot be downloaded."""


class RecipeConflictError(RecipeServiceError):
    """Raised when a conflicting recipe state is detected."""
|
||||||
400
py/services/recipes/persistence_service.py
Normal file
400
py/services/recipes/persistence_service.py
Normal file
@@ -0,0 +1,400 @@
|
|||||||
|
"""Services encapsulating recipe persistence workflows."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, Iterable, Optional
|
||||||
|
|
||||||
|
from ...config import config
|
||||||
|
from ...utils.utils import calculate_recipe_fingerprint
|
||||||
|
from .errors import RecipeNotFoundError, RecipeValidationError
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class PersistenceResult:
    """Immutable return payload from persistence operations."""

    # JSON-serialisable response body for the operation.
    payload: dict[str, Any]
    # HTTP-style status code; defaults to 200 (success).
    status: int = 200
|
||||||
|
|
||||||
|
|
||||||
|
class RecipePersistenceService:
|
||||||
|
"""Coordinate recipe persistence tasks across storage and caches."""
|
||||||
|
|
||||||
|
    def __init__(
        self,
        *,
        exif_utils,
        card_preview_width: int,
        logger,
    ) -> None:
        """Store injected collaborators.

        Args:
            exif_utils: Helper used to optimise preview images and embed
                recipe metadata into them.
            card_preview_width: Target width (px) for optimised preview images.
            logger: Logger instance for diagnostic output.
        """
        self._exif_utils = exif_utils
        self._card_preview_width = card_preview_width
        self._logger = logger
|
||||||
|
|
||||||
|
    async def save_recipe(
        self,
        *,
        recipe_scanner,
        image_bytes: bytes | None,
        image_base64: str | None,
        name: str | None,
        tags: Iterable[str],
        metadata: Optional[dict[str, Any]],
    ) -> PersistenceResult:
        """Persist a user uploaded recipe.

        Writes an optimised preview image plus a ``<id>.recipe.json`` sidecar
        into the scanner's recipes directory, embeds the recipe metadata into
        the image, and registers the recipe with the scanner.

        Args:
            recipe_scanner: Scanner owning the recipes directory and cache.
            image_bytes: Raw image payload; takes precedence over ``image_base64``.
            image_base64: Alternative base64-encoded image payload.
            name: Recipe title (required).
            tags: Optional tags to store with the recipe.
            metadata: Recipe metadata (required); read for ``loras``,
                ``gen_params``/``raw_metadata``, ``base_model`` and
                ``source_path``.

        Raises:
            RecipeValidationError: If required fields or image data are missing.
        """

        # Validate required fields up front so nothing is written on bad input.
        missing_fields = []
        if not name:
            missing_fields.append("name")
        if metadata is None:
            missing_fields.append("metadata")
        if missing_fields:
            raise RecipeValidationError(
                f"Missing required fields: {', '.join(missing_fields)}"
            )

        resolved_image_bytes = self._resolve_image_bytes(image_bytes, image_base64)
        recipes_dir = recipe_scanner.recipes_dir
        os.makedirs(recipes_dir, exist_ok=True)

        recipe_id = str(uuid.uuid4())
        # Re-encode the upload as a width-capped webp preview, keeping the
        # metadata embedded in the source image.
        optimized_image, extension = self._exif_utils.optimize_image(
            image_data=resolved_image_bytes,
            target_width=self._card_preview_width,
            format="webp",
            quality=85,
            preserve_metadata=True,
        )
        image_filename = f"{recipe_id}{extension}"
        image_path = os.path.join(recipes_dir, image_filename)
        with open(image_path, "wb") as file_obj:
            file_obj.write(optimized_image)

        current_time = time.time()
        loras_data = [self._normalise_lora_entry(lora) for lora in metadata.get("loras", [])]

        # Fall back to rebuilding generation parameters from raw_metadata when
        # the client did not send a pre-built gen_params block.
        gen_params = metadata.get("gen_params", {})
        if not gen_params and "raw_metadata" in metadata:
            raw_metadata = metadata.get("raw_metadata", {})
            gen_params = {
                "prompt": raw_metadata.get("prompt", ""),
                "negative_prompt": raw_metadata.get("negative_prompt", ""),
                "checkpoint": raw_metadata.get("checkpoint", {}),
                "steps": raw_metadata.get("steps", ""),
                "sampler": raw_metadata.get("sampler", ""),
                "cfg_scale": raw_metadata.get("cfg_scale", ""),
                "seed": raw_metadata.get("seed", ""),
                "size": raw_metadata.get("size", ""),
                "clip_skip": raw_metadata.get("clip_skip", ""),
            }

        # The fingerprint identifies recipes sharing an equivalent LoRA stack.
        fingerprint = calculate_recipe_fingerprint(loras_data)
        recipe_data: Dict[str, Any] = {
            "id": recipe_id,
            "file_path": image_path,
            "title": name,
            "modified": current_time,
            "created_date": current_time,
            "base_model": metadata.get("base_model", ""),
            "loras": loras_data,
            "gen_params": gen_params,
            "fingerprint": fingerprint,
        }

        tags_list = list(tags)
        if tags_list:
            recipe_data["tags"] = tags_list

        if metadata.get("source_path"):
            recipe_data["source_path"] = metadata.get("source_path")

        json_filename = f"{recipe_id}.recipe.json"
        json_path = os.path.join(recipes_dir, json_filename)
        with open(json_path, "w", encoding="utf-8") as file_obj:
            json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)

        # Embed the recipe data into the preview image file as well.
        self._exif_utils.append_recipe_metadata(image_path, recipe_data)

        # Find duplicate recipes before registering; exclude the new ID itself.
        matching_recipes = await self._find_matching_recipes(recipe_scanner, fingerprint, exclude_id=recipe_id)
        await recipe_scanner.add_recipe(recipe_data)

        return PersistenceResult(
            {
                "success": True,
                "recipe_id": recipe_id,
                "image_path": image_path,
                "json_path": json_path,
                "matching_recipes": matching_recipes,
            }
        )
|
||||||
|
|
||||||
|
async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult:
|
||||||
|
"""Delete an existing recipe."""
|
||||||
|
|
||||||
|
recipes_dir = recipe_scanner.recipes_dir
|
||||||
|
if not recipes_dir or not os.path.exists(recipes_dir):
|
||||||
|
raise RecipeNotFoundError("Recipes directory not found")
|
||||||
|
|
||||||
|
recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
|
||||||
|
if not os.path.exists(recipe_json_path):
|
||||||
|
raise RecipeNotFoundError("Recipe not found")
|
||||||
|
|
||||||
|
with open(recipe_json_path, "r", encoding="utf-8") as file_obj:
|
||||||
|
recipe_data = json.load(file_obj)
|
||||||
|
|
||||||
|
image_path = recipe_data.get("file_path")
|
||||||
|
os.remove(recipe_json_path)
|
||||||
|
if image_path and os.path.exists(image_path):
|
||||||
|
os.remove(image_path)
|
||||||
|
|
||||||
|
await recipe_scanner.remove_recipe(recipe_id)
|
||||||
|
return PersistenceResult({"success": True, "message": "Recipe deleted successfully"})
|
||||||
|
|
||||||
|
async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: dict[str, Any]) -> PersistenceResult:
|
||||||
|
"""Update persisted metadata for a recipe."""
|
||||||
|
|
||||||
|
if not any(key in updates for key in ("title", "tags", "source_path", "preview_nsfw_level")):
|
||||||
|
raise RecipeValidationError(
|
||||||
|
"At least one field to update must be provided (title or tags or source_path or preview_nsfw_level)"
|
||||||
|
)
|
||||||
|
|
||||||
|
success = await recipe_scanner.update_recipe_metadata(recipe_id, updates)
|
||||||
|
if not success:
|
||||||
|
raise RecipeNotFoundError("Recipe not found or update failed")
|
||||||
|
|
||||||
|
return PersistenceResult({"success": True, "recipe_id": recipe_id, "updates": updates})
|
||||||
|
|
||||||
|
async def reconnect_lora(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
recipe_scanner,
|
||||||
|
recipe_id: str,
|
||||||
|
lora_index: int,
|
||||||
|
target_name: str,
|
||||||
|
) -> PersistenceResult:
|
||||||
|
"""Reconnect a LoRA entry within an existing recipe."""
|
||||||
|
|
||||||
|
recipe_path = os.path.join(recipe_scanner.recipes_dir, f"{recipe_id}.recipe.json")
|
||||||
|
if not os.path.exists(recipe_path):
|
||||||
|
raise RecipeNotFoundError("Recipe not found")
|
||||||
|
|
||||||
|
target_lora = await recipe_scanner.get_local_lora(target_name)
|
||||||
|
if not target_lora:
|
||||||
|
raise RecipeNotFoundError(f"Local LoRA not found with name: {target_name}")
|
||||||
|
|
||||||
|
recipe_data, updated_lora = await recipe_scanner.update_lora_entry(
|
||||||
|
recipe_id,
|
||||||
|
lora_index,
|
||||||
|
target_name=target_name,
|
||||||
|
target_lora=target_lora,
|
||||||
|
)
|
||||||
|
|
||||||
|
image_path = recipe_data.get("file_path")
|
||||||
|
if image_path and os.path.exists(image_path):
|
||||||
|
self._exif_utils.append_recipe_metadata(image_path, recipe_data)
|
||||||
|
|
||||||
|
matching_recipes = []
|
||||||
|
if "fingerprint" in recipe_data:
|
||||||
|
matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(recipe_data["fingerprint"])
|
||||||
|
if recipe_id in matching_recipes:
|
||||||
|
matching_recipes.remove(recipe_id)
|
||||||
|
|
||||||
|
return PersistenceResult(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"recipe_id": recipe_id,
|
||||||
|
"updated_lora": updated_lora,
|
||||||
|
"matching_recipes": matching_recipes,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def bulk_delete(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
recipe_scanner,
|
||||||
|
recipe_ids: Iterable[str],
|
||||||
|
) -> PersistenceResult:
|
||||||
|
"""Delete multiple recipes in a single request."""
|
||||||
|
|
||||||
|
recipe_ids = list(recipe_ids)
|
||||||
|
if not recipe_ids:
|
||||||
|
raise RecipeValidationError("No recipe IDs provided")
|
||||||
|
|
||||||
|
recipes_dir = recipe_scanner.recipes_dir
|
||||||
|
if not recipes_dir or not os.path.exists(recipes_dir):
|
||||||
|
raise RecipeNotFoundError("Recipes directory not found")
|
||||||
|
|
||||||
|
deleted_recipes: list[str] = []
|
||||||
|
failed_recipes: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
for recipe_id in recipe_ids:
|
||||||
|
recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
|
||||||
|
if not os.path.exists(recipe_json_path):
|
||||||
|
failed_recipes.append({"id": recipe_id, "reason": "Recipe not found"})
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(recipe_json_path, "r", encoding="utf-8") as file_obj:
|
||||||
|
recipe_data = json.load(file_obj)
|
||||||
|
image_path = recipe_data.get("file_path")
|
||||||
|
os.remove(recipe_json_path)
|
||||||
|
if image_path and os.path.exists(image_path):
|
||||||
|
os.remove(image_path)
|
||||||
|
deleted_recipes.append(recipe_id)
|
||||||
|
except Exception as exc:
|
||||||
|
failed_recipes.append({"id": recipe_id, "reason": str(exc)})
|
||||||
|
|
||||||
|
if deleted_recipes:
|
||||||
|
await recipe_scanner.bulk_remove(deleted_recipes)
|
||||||
|
|
||||||
|
return PersistenceResult(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"deleted": deleted_recipes,
|
||||||
|
"failed": failed_recipes,
|
||||||
|
"total_deleted": len(deleted_recipes),
|
||||||
|
"total_failed": len(failed_recipes),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
    async def save_recipe_from_widget(
        self,
        *,
        recipe_scanner,
        metadata: dict[str, Any],
        image_bytes: bytes,
    ) -> PersistenceResult:
        """Save a recipe constructed from widget metadata.

        The LoRA stack is parsed out of ``metadata["loras"]`` — a prompt-style
        string of ``<lora:name:strength>`` tokens — and each referenced LoRA is
        looked up locally to enrich the stored entry.

        Raises:
            RecipeValidationError: If metadata is empty or contains no LoRAs.
        """

        if not metadata:
            raise RecipeValidationError("No generation metadata found")

        recipes_dir = recipe_scanner.recipes_dir
        os.makedirs(recipes_dir, exist_ok=True)

        recipe_id = str(uuid.uuid4())
        image_filename = f"{recipe_id}.png"
        image_path = os.path.join(recipes_dir, image_filename)
        with open(image_path, "wb") as file_obj:
            file_obj.write(image_bytes)

        # NOTE(review): the image is written before the LoRA check below, so a
        # payload without LoRA tokens leaves an orphaned image file on disk.
        lora_stack = metadata.get("loras", "")
        lora_matches = re.findall(r"<lora:([^:]+):([^>]+)>", lora_stack)
        if not lora_matches:
            raise RecipeValidationError("No LoRAs found in the generation metadata")

        loras_data = []
        base_model_counts: Dict[str, int] = {}

        for name, strength in lora_matches:
            # Enrich each entry from the local library when the LoRA is present.
            lora_info = await recipe_scanner.get_local_lora(name)
            lora_data = {
                "file_name": name,
                "strength": float(strength),
                "hash": (lora_info.get("sha256") or "").lower() if lora_info else "",
                "modelVersionId": lora_info.get("civitai", {}).get("id") if lora_info else 0,
                "modelName": lora_info.get("civitai", {}).get("model", {}).get("name") if lora_info else "",
                "modelVersionName": lora_info.get("civitai", {}).get("name") if lora_info else "",
                "isDeleted": False,
                "exclude": False,
            }
            loras_data.append(lora_data)

            if lora_info and "base_model" in lora_info:
                base_model = lora_info["base_model"]
                base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1

        recipe_name = self._derive_recipe_name(lora_matches)
        # Use the base model shared by most of the stack as the recipe's model.
        most_common_base_model = (
            max(base_model_counts.items(), key=lambda item: item[1])[0] if base_model_counts else ""
        )

        recipe_data = {
            "id": recipe_id,
            "file_path": image_path,
            "title": recipe_name,
            "modified": time.time(),
            "created_date": time.time(),
            "base_model": most_common_base_model,
            "loras": loras_data,
            "checkpoint": metadata.get("checkpoint", ""),
            # Everything except checkpoint/loras is treated as a generation param.
            "gen_params": {
                key: value
                for key, value in metadata.items()
                if key not in ["checkpoint", "loras"]
            },
            "loras_stack": lora_stack,
        }

        json_filename = f"{recipe_id}.recipe.json"
        json_path = os.path.join(recipes_dir, json_filename)
        with open(json_path, "w", encoding="utf-8") as file_obj:
            json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)

        # Embed the recipe data into the saved image as well.
        self._exif_utils.append_recipe_metadata(image_path, recipe_data)
        await recipe_scanner.add_recipe(recipe_data)

        return PersistenceResult(
            {
                "success": True,
                "recipe_id": recipe_id,
                "image_path": image_path,
                "json_path": json_path,
                "recipe_name": recipe_name,
            }
        )
|
||||||
|
|
||||||
|
# Helper methods ---------------------------------------------------
|
||||||
|
|
||||||
|
def _resolve_image_bytes(self, image_bytes: bytes | None, image_base64: str | None) -> bytes:
|
||||||
|
if image_bytes is not None:
|
||||||
|
return image_bytes
|
||||||
|
if image_base64:
|
||||||
|
try:
|
||||||
|
payload = image_base64.split(",", 1)[1] if "," in image_base64 else image_base64
|
||||||
|
return base64.b64decode(payload)
|
||||||
|
except Exception as exc: # pragma: no cover - validation guard
|
||||||
|
raise RecipeValidationError(f"Invalid base64 image data: {exc}") from exc
|
||||||
|
raise RecipeValidationError("No image data provided")
|
||||||
|
|
||||||
|
def _normalise_lora_entry(self, lora: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"file_name": lora.get("file_name", "")
|
||||||
|
or (
|
||||||
|
os.path.splitext(os.path.basename(lora.get("localPath", "")))[0]
|
||||||
|
if lora.get("localPath")
|
||||||
|
else ""
|
||||||
|
),
|
||||||
|
"hash": (lora.get("hash") or "").lower(),
|
||||||
|
"strength": float(lora.get("weight", 1.0)),
|
||||||
|
"modelVersionId": lora.get("id", 0),
|
||||||
|
"modelName": lora.get("name", ""),
|
||||||
|
"modelVersionName": lora.get("version", ""),
|
||||||
|
"isDeleted": lora.get("isDeleted", False),
|
||||||
|
"exclude": lora.get("exclude", False),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _find_matching_recipes(
|
||||||
|
self,
|
||||||
|
recipe_scanner,
|
||||||
|
fingerprint: str | None,
|
||||||
|
*,
|
||||||
|
exclude_id: Optional[str] = None,
|
||||||
|
) -> list[str]:
|
||||||
|
if not fingerprint:
|
||||||
|
return []
|
||||||
|
matches = await recipe_scanner.find_recipes_by_fingerprint(fingerprint)
|
||||||
|
if exclude_id and exclude_id in matches:
|
||||||
|
matches.remove(exclude_id)
|
||||||
|
return matches
|
||||||
|
|
||||||
|
def _derive_recipe_name(self, lora_matches: list[tuple[str, str]]) -> str:
|
||||||
|
recipe_name_parts = [f"{name.strip()}-{float(strength):.2f}" for name, strength in lora_matches[:3]]
|
||||||
|
recipe_name = "_".join(recipe_name_parts)
|
||||||
|
return recipe_name or "recipe"
|
||||||
105
py/services/recipes/sharing_service.py
Normal file
105
py/services/recipes/sharing_service.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""Services handling recipe sharing and downloads."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from .errors import RecipeNotFoundError
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SharingResult:
    """Immutable return payload for share operations."""

    # JSON-serialisable response body (success flag, download_url, filename).
    payload: dict[str, Any]
    # HTTP-style status code; defaults to 200 (success).
    status: int = 200
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class DownloadInfo:
    """Information required to stream a shared recipe file."""

    # Absolute path of the temporary copy on disk.
    file_path: str
    # Suggested filename for the client-side download.
    download_filename: str
|
||||||
|
|
||||||
|
|
||||||
|
class RecipeSharingService:
    """Prepare temporary recipe downloads with TTL-based cleanup.

    Each share copies the recipe image into a temp file tracked in
    ``_shared_recipes`` (keyed by recipe ID) and expires after ``ttl_seconds``.
    """

    def __init__(self, *, ttl_seconds: int = 300, logger) -> None:
        """Initialise the service.

        Args:
            ttl_seconds: Lifetime of a prepared download before it expires.
            logger: Logger used when temp-file cleanup fails.
        """
        self._ttl_seconds = ttl_seconds
        self._logger = logger
        # recipe_id -> {"path": temp copy, "timestamp": share time, "expires": epoch}
        self._shared_recipes: Dict[str, Dict[str, Any]] = {}

    async def share_recipe(self, *, recipe_scanner, recipe_id: str) -> SharingResult:
        """Prepare a temporary downloadable copy of a recipe image.

        Raises:
            RecipeNotFoundError: If the recipe or its image file is missing.
        """
        recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
        if not recipe:
            raise RecipeNotFoundError("Recipe not found")

        image_path = recipe.get("file_path")
        if not image_path or not os.path.exists(image_path):
            raise RecipeNotFoundError("Recipe image not found")

        ext = os.path.splitext(image_path)[1]
        with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as temp_file:
            temp_path = temp_file.name

        shutil.copy2(image_path, temp_path)

        # Fix: drop any previous share for this recipe first; otherwise the old
        # temp copy would be orphaned when its dict entry is overwritten below.
        self._cleanup_entry(recipe_id)

        timestamp = int(time.time())
        self._shared_recipes[recipe_id] = {
            "path": temp_path,
            "timestamp": timestamp,
            "expires": time.time() + self._ttl_seconds,
        }
        # Opportunistically evict any other expired shares.
        self._cleanup_shared_recipes()

        safe_title = recipe.get("title", "").replace(" ", "_").lower()
        filename = f"recipe_{safe_title}{ext}" if safe_title else f"recipe_{recipe_id}{ext}"
        url_path = f"/api/recipe/{recipe_id}/share/download?t={timestamp}"
        return SharingResult({"success": True, "download_url": url_path, "filename": filename})

    async def prepare_download(self, *, recipe_scanner, recipe_id: str) -> DownloadInfo:
        """Return file path and filename for a previously shared recipe.

        Raises:
            RecipeNotFoundError: If the share is missing, expired, or its
                temp file has disappeared.
        """
        shared_info = self._shared_recipes.get(recipe_id)
        if not shared_info or time.time() > shared_info.get("expires", 0):
            self._cleanup_entry(recipe_id)
            raise RecipeNotFoundError("Shared recipe not found or expired")

        file_path = shared_info["path"]
        if not os.path.exists(file_path):
            self._cleanup_entry(recipe_id)
            raise RecipeNotFoundError("Shared recipe file not found")

        # Derive a friendly filename from the recipe title when available.
        recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
        filename_base = (
            f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}" if recipe else recipe_id
        )
        ext = os.path.splitext(file_path)[1]
        download_filename = f"{filename_base}{ext}"
        return DownloadInfo(file_path=file_path, download_filename=download_filename)

    def _cleanup_shared_recipes(self) -> None:
        """Evict every expired share entry and delete its temp file."""
        now = time.time()
        for recipe_id in list(self._shared_recipes.keys()):
            shared = self._shared_recipes.get(recipe_id)
            if shared and now > shared.get("expires", 0):
                self._cleanup_entry(recipe_id)

    def _cleanup_entry(self, recipe_id: str) -> None:
        """Remove one share entry and best-effort delete its temp file."""
        shared_info = self._shared_recipes.pop(recipe_id, None)
        if not shared_info:
            return
        file_path = shared_info.get("path")
        if file_path and os.path.exists(file_path):
            try:
                os.unlink(file_path)
            except Exception as exc:  # pragma: no cover - defensive logging
                self._logger.error("Error cleaning up shared recipe %s: %s", recipe_id, exc)
|
||||||
@@ -5,10 +5,41 @@ from typing import Any, Dict
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Canonical default values for every supported setting key.
DEFAULT_SETTINGS: Dict[str, Any] = {
    "civitai_api_key": "",
    "language": "en",
    "show_only_sfw": False,  # Show only SFW content
    "enable_metadata_archive_db": False,  # Enable metadata archive database
    # App-level proxy configuration
    "proxy_enabled": False,
    "proxy_host": "",
    "proxy_port": "",
    "proxy_username": "",  # optional
    "proxy_password": "",  # optional
    "proxy_type": "http",  # Proxy type: http, https, socks4, socks5
    # Default model root folders
    "default_lora_root": "",
    "default_checkpoint_root": "",
    "default_embedding_root": "",
    "base_model_path_mappings": {},
    "download_path_templates": {},
    # Example image handling
    "example_images_path": "",
    "optimize_example_images": True,
    "auto_download_example_images": False,
    # Display preferences (migrated from legacy camelCase keys)
    "blur_mature_content": True,
    "autoplay_on_hover": False,
    "display_density": "default",
    "card_info_display": "always",
    "include_trigger_words": False,
    "compact_mode": False,
}
|
||||||
|
|
||||||
|
|
||||||
class SettingsManager:
|
class SettingsManager:
|
||||||
    def __init__(self):
        """Load settings from disk, then run migration and default-fill steps."""
        # settings.json lives three directory levels above this module.
        self.settings_file = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'settings.json')
        self.settings = self._load_settings()
        # Order matters: legacy keys are renamed before defaults are filled in,
        # so a migrated value is not shadowed by a freshly-added default.
        self._migrate_setting_keys()
        self._ensure_default_settings()
        self._migrate_download_path_template()
        self._auto_set_default_roots()
        self._check_environment_variables()
|
||||||
@@ -23,11 +54,49 @@ class SettingsManager:
|
|||||||
logger.error(f"Error loading settings: {e}")
|
logger.error(f"Error loading settings: {e}")
|
||||||
return self._get_default_settings()
|
return self._get_default_settings()
|
||||||
|
|
||||||
|
def _ensure_default_settings(self) -> None:
|
||||||
|
"""Ensure all default settings keys exist"""
|
||||||
|
updated = False
|
||||||
|
for key, value in self._get_default_settings().items():
|
||||||
|
if key not in self.settings:
|
||||||
|
if isinstance(value, dict):
|
||||||
|
self.settings[key] = value.copy()
|
||||||
|
else:
|
||||||
|
self.settings[key] = value
|
||||||
|
updated = True
|
||||||
|
if updated:
|
||||||
|
self._save_settings()
|
||||||
|
|
||||||
|
def _migrate_setting_keys(self) -> None:
|
||||||
|
"""Migrate legacy camelCase setting keys to snake_case"""
|
||||||
|
key_migrations = {
|
||||||
|
'optimizeExampleImages': 'optimize_example_images',
|
||||||
|
'autoDownloadExampleImages': 'auto_download_example_images',
|
||||||
|
'blurMatureContent': 'blur_mature_content',
|
||||||
|
'autoplayOnHover': 'autoplay_on_hover',
|
||||||
|
'displayDensity': 'display_density',
|
||||||
|
'cardInfoDisplay': 'card_info_display',
|
||||||
|
'includeTriggerWords': 'include_trigger_words',
|
||||||
|
'compactMode': 'compact_mode',
|
||||||
|
}
|
||||||
|
|
||||||
|
updated = False
|
||||||
|
for old_key, new_key in key_migrations.items():
|
||||||
|
if old_key in self.settings:
|
||||||
|
if new_key not in self.settings:
|
||||||
|
self.settings[new_key] = self.settings[old_key]
|
||||||
|
del self.settings[old_key]
|
||||||
|
updated = True
|
||||||
|
|
||||||
|
if updated:
|
||||||
|
logger.info("Migrated legacy setting keys to snake_case")
|
||||||
|
self._save_settings()
|
||||||
|
|
||||||
def _migrate_download_path_template(self):
|
def _migrate_download_path_template(self):
|
||||||
"""Migrate old download_path_template to new download_path_templates"""
|
"""Migrate old download_path_template to new download_path_templates"""
|
||||||
old_template = self.settings.get('download_path_template')
|
old_template = self.settings.get('download_path_template')
|
||||||
templates = self.settings.get('download_path_templates')
|
templates = self.settings.get('download_path_templates')
|
||||||
|
|
||||||
# If old template exists and new templates don't exist, migrate
|
# If old template exists and new templates don't exist, migrate
|
||||||
if old_template is not None and not templates:
|
if old_template is not None and not templates:
|
||||||
logger.info("Migrating download_path_template to download_path_templates")
|
logger.info("Migrating download_path_template to download_path_templates")
|
||||||
@@ -78,18 +147,11 @@ class SettingsManager:
|
|||||||
|
|
||||||
def _get_default_settings(self) -> Dict[str, Any]:
|
def _get_default_settings(self) -> Dict[str, Any]:
|
||||||
"""Return default settings"""
|
"""Return default settings"""
|
||||||
return {
|
defaults = DEFAULT_SETTINGS.copy()
|
||||||
"civitai_api_key": "",
|
# Ensure nested dicts are independent copies
|
||||||
"language": "en",
|
defaults['base_model_path_mappings'] = {}
|
||||||
"show_only_sfw": False, # Show only SFW content
|
defaults['download_path_templates'] = {}
|
||||||
"enable_metadata_archive_db": False, # Enable metadata archive database
|
return defaults
|
||||||
"proxy_enabled": False, # Enable app-level proxy
|
|
||||||
"proxy_host": "", # Proxy host
|
|
||||||
"proxy_port": "", # Proxy port
|
|
||||||
"proxy_username": "", # Proxy username (optional)
|
|
||||||
"proxy_password": "", # Proxy password (optional)
|
|
||||||
"proxy_type": "http" # Proxy type: http, https, socks4, socks5
|
|
||||||
}
|
|
||||||
|
|
||||||
def get(self, key: str, default: Any = None) -> Any:
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
"""Get setting value"""
|
"""Get setting value"""
|
||||||
|
|||||||
47
py/services/tag_update_service.py
Normal file
47
py/services/tag_update_service.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
"""Service for updating tag collections on metadata records."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from typing import Awaitable, Callable, Dict, List, Sequence
|
||||||
|
|
||||||
|
|
||||||
|
class TagUpdateService:
|
||||||
|
"""Encapsulate tag manipulation for models."""
|
||||||
|
|
||||||
|
def __init__(self, *, metadata_manager) -> None:
|
||||||
|
self._metadata_manager = metadata_manager
|
||||||
|
|
||||||
|
async def add_tags(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
file_path: str,
|
||||||
|
new_tags: Sequence[str],
|
||||||
|
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||||
|
update_cache: Callable[[str, str, Dict[str, object]], Awaitable[bool]],
|
||||||
|
) -> List[str]:
|
||||||
|
"""Add tags to a metadata entry while keeping case-insensitive uniqueness."""
|
||||||
|
|
||||||
|
base, _ = os.path.splitext(file_path)
|
||||||
|
metadata_path = f"{base}.metadata.json"
|
||||||
|
metadata = await metadata_loader(metadata_path)
|
||||||
|
|
||||||
|
existing_tags = list(metadata.get("tags", []))
|
||||||
|
existing_lower = [tag.lower() for tag in existing_tags]
|
||||||
|
|
||||||
|
tags_added: List[str] = []
|
||||||
|
for tag in new_tags:
|
||||||
|
if isinstance(tag, str) and tag.strip():
|
||||||
|
normalized = tag.strip()
|
||||||
|
if normalized.lower() not in existing_lower:
|
||||||
|
existing_tags.append(normalized)
|
||||||
|
existing_lower.append(normalized.lower())
|
||||||
|
tags_added.append(normalized)
|
||||||
|
|
||||||
|
metadata["tags"] = existing_tags
|
||||||
|
await self._metadata_manager.save_metadata(file_path, metadata)
|
||||||
|
await update_cache(file_path, file_path, metadata)
|
||||||
|
|
||||||
|
return existing_tags
|
||||||
|
|
||||||
37
py/services/use_cases/__init__.py
Normal file
37
py/services/use_cases/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Application-level orchestration services for model routes."""
|
||||||
|
|
||||||
|
from .auto_organize_use_case import (
|
||||||
|
AutoOrganizeInProgressError,
|
||||||
|
AutoOrganizeUseCase,
|
||||||
|
)
|
||||||
|
from .bulk_metadata_refresh_use_case import (
|
||||||
|
BulkMetadataRefreshUseCase,
|
||||||
|
MetadataRefreshProgressReporter,
|
||||||
|
)
|
||||||
|
from .download_model_use_case import (
|
||||||
|
DownloadModelEarlyAccessError,
|
||||||
|
DownloadModelUseCase,
|
||||||
|
DownloadModelValidationError,
|
||||||
|
)
|
||||||
|
from .example_images import (
|
||||||
|
DownloadExampleImagesConfigurationError,
|
||||||
|
DownloadExampleImagesInProgressError,
|
||||||
|
DownloadExampleImagesUseCase,
|
||||||
|
ImportExampleImagesUseCase,
|
||||||
|
ImportExampleImagesValidationError,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AutoOrganizeInProgressError",
|
||||||
|
"AutoOrganizeUseCase",
|
||||||
|
"BulkMetadataRefreshUseCase",
|
||||||
|
"MetadataRefreshProgressReporter",
|
||||||
|
"DownloadModelEarlyAccessError",
|
||||||
|
"DownloadModelUseCase",
|
||||||
|
"DownloadModelValidationError",
|
||||||
|
"DownloadExampleImagesConfigurationError",
|
||||||
|
"DownloadExampleImagesInProgressError",
|
||||||
|
"DownloadExampleImagesUseCase",
|
||||||
|
"ImportExampleImagesUseCase",
|
||||||
|
"ImportExampleImagesValidationError",
|
||||||
|
]
|
||||||
56
py/services/use_cases/auto_organize_use_case.py
Normal file
56
py/services/use_cases/auto_organize_use_case.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""Auto-organize use case orchestrating concurrency and progress handling."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional, Protocol, Sequence
|
||||||
|
|
||||||
|
from ..model_file_service import AutoOrganizeResult, ModelFileService, ProgressCallback
|
||||||
|
|
||||||
|
|
||||||
|
class AutoOrganizeLockProvider(Protocol):
|
||||||
|
"""Minimal protocol for objects exposing auto-organize locking primitives."""
|
||||||
|
|
||||||
|
def is_auto_organize_running(self) -> bool:
|
||||||
|
"""Return ``True`` when an auto-organize operation is in-flight."""
|
||||||
|
|
||||||
|
async def get_auto_organize_lock(self) -> asyncio.Lock:
|
||||||
|
"""Return the asyncio lock guarding auto-organize operations."""
|
||||||
|
|
||||||
|
|
||||||
|
class AutoOrganizeInProgressError(RuntimeError):
|
||||||
|
"""Raised when an auto-organize run is already active."""
|
||||||
|
|
||||||
|
|
||||||
|
class AutoOrganizeUseCase:
|
||||||
|
"""Coordinate auto-organize execution behind a shared lock."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
file_service: ModelFileService,
|
||||||
|
lock_provider: AutoOrganizeLockProvider,
|
||||||
|
) -> None:
|
||||||
|
self._file_service = file_service
|
||||||
|
self._lock_provider = lock_provider
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
file_paths: Optional[Sequence[str]] = None,
|
||||||
|
progress_callback: Optional[ProgressCallback] = None,
|
||||||
|
) -> AutoOrganizeResult:
|
||||||
|
"""Run the auto-organize routine guarded by a shared lock."""
|
||||||
|
|
||||||
|
if self._lock_provider.is_auto_organize_running():
|
||||||
|
raise AutoOrganizeInProgressError("Auto-organize is already running")
|
||||||
|
|
||||||
|
lock = await self._lock_provider.get_auto_organize_lock()
|
||||||
|
if lock.locked():
|
||||||
|
raise AutoOrganizeInProgressError("Auto-organize is already running")
|
||||||
|
|
||||||
|
async with lock:
|
||||||
|
return await self._file_service.auto_organize_models(
|
||||||
|
file_paths=list(file_paths) if file_paths is not None else None,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
122
py/services/use_cases/bulk_metadata_refresh_use_case.py
Normal file
122
py/services/use_cases/bulk_metadata_refresh_use_case.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""Use case encapsulating the bulk metadata refresh orchestration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional, Protocol, Sequence
|
||||||
|
|
||||||
|
from ..metadata_sync_service import MetadataSyncService
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataRefreshProgressReporter(Protocol):
|
||||||
|
"""Protocol for progress reporters used during metadata refresh."""
|
||||||
|
|
||||||
|
async def on_progress(self, payload: Dict[str, Any]) -> None:
|
||||||
|
"""Handle a metadata refresh progress update."""
|
||||||
|
|
||||||
|
|
||||||
|
class BulkMetadataRefreshUseCase:
|
||||||
|
"""Coordinate bulk metadata refreshes with progress emission."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
service,
|
||||||
|
metadata_sync: MetadataSyncService,
|
||||||
|
settings_service,
|
||||||
|
logger: Optional[logging.Logger] = None,
|
||||||
|
) -> None:
|
||||||
|
self._service = service
|
||||||
|
self._metadata_sync = metadata_sync
|
||||||
|
self._settings = settings_service
|
||||||
|
self._logger = logger or logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
progress_callback: Optional[MetadataRefreshProgressReporter] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Refresh metadata for all qualifying models."""
|
||||||
|
|
||||||
|
cache = await self._service.scanner.get_cached_data()
|
||||||
|
total_models = len(cache.raw_data)
|
||||||
|
|
||||||
|
enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False)
|
||||||
|
to_process: Sequence[Dict[str, Any]] = [
|
||||||
|
model
|
||||||
|
for model in cache.raw_data
|
||||||
|
if model.get("sha256")
|
||||||
|
and (not model.get("civitai") or not model["civitai"].get("id"))
|
||||||
|
and (
|
||||||
|
(enable_metadata_archive_db and not model.get("db_checked", False))
|
||||||
|
or (not enable_metadata_archive_db and model.get("from_civitai") is True)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
total_to_process = len(to_process)
|
||||||
|
processed = 0
|
||||||
|
success = 0
|
||||||
|
needs_resort = False
|
||||||
|
|
||||||
|
async def emit(status: str, **extra: Any) -> None:
|
||||||
|
if progress_callback is None:
|
||||||
|
return
|
||||||
|
payload = {"status": status, "total": total_to_process, "processed": processed, "success": success}
|
||||||
|
payload.update(extra)
|
||||||
|
await progress_callback.on_progress(payload)
|
||||||
|
|
||||||
|
await emit("started")
|
||||||
|
|
||||||
|
for model in to_process:
|
||||||
|
try:
|
||||||
|
original_name = model.get("model_name")
|
||||||
|
result, _ = await self._metadata_sync.fetch_and_update_model(
|
||||||
|
sha256=model["sha256"],
|
||||||
|
file_path=model["file_path"],
|
||||||
|
model_data=model,
|
||||||
|
update_cache_func=self._service.scanner.update_single_model_cache,
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
success += 1
|
||||||
|
if original_name != model.get("model_name"):
|
||||||
|
needs_resort = True
|
||||||
|
processed += 1
|
||||||
|
await emit(
|
||||||
|
"processing",
|
||||||
|
processed=processed,
|
||||||
|
success=success,
|
||||||
|
current_name=model.get("model_name", "Unknown"),
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover - logging path
|
||||||
|
processed += 1
|
||||||
|
self._logger.error(
|
||||||
|
"Error fetching CivitAI data for %s: %s",
|
||||||
|
model.get("file_path"),
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
if needs_resort:
|
||||||
|
await cache.resort()
|
||||||
|
|
||||||
|
await emit("completed", processed=processed, success=success)
|
||||||
|
|
||||||
|
message = (
|
||||||
|
"Successfully updated "
|
||||||
|
f"{success} of {processed} processed {self._service.model_type}s (total: {total_models})"
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"success": True, "message": message, "processed": processed, "updated": success, "total": total_models}
|
||||||
|
|
||||||
|
async def execute_with_error_handling(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
progress_callback: Optional[MetadataRefreshProgressReporter] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Wrapper providing progress notification on unexpected failures."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self.execute(progress_callback=progress_callback)
|
||||||
|
except Exception as exc:
|
||||||
|
if progress_callback is not None:
|
||||||
|
await progress_callback.on_progress({"status": "error", "error": str(exc)})
|
||||||
|
raise
|
||||||
37
py/services/use_cases/download_model_use_case.py
Normal file
37
py/services/use_cases/download_model_use_case.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Use case for scheduling model downloads with consistent error handling."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from ..download_coordinator import DownloadCoordinator
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadModelValidationError(ValueError):
|
||||||
|
"""Raised when incoming payload validation fails."""
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadModelEarlyAccessError(RuntimeError):
|
||||||
|
"""Raised when the download is gated behind Civitai early access."""
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadModelUseCase:
|
||||||
|
"""Coordinate download scheduling through the coordinator service."""
|
||||||
|
|
||||||
|
def __init__(self, *, download_coordinator: DownloadCoordinator) -> None:
|
||||||
|
self._download_coordinator = download_coordinator
|
||||||
|
|
||||||
|
async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Schedule a download and normalize error conditions."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self._download_coordinator.schedule_download(payload)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise DownloadModelValidationError(str(exc)) from exc
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging path
|
||||||
|
message = str(exc)
|
||||||
|
if "401" in message:
|
||||||
|
raise DownloadModelEarlyAccessError(
|
||||||
|
"Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
|
||||||
|
) from exc
|
||||||
|
raise
|
||||||
19
py/services/use_cases/example_images/__init__.py
Normal file
19
py/services/use_cases/example_images/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""Example image specific use case exports."""
|
||||||
|
|
||||||
|
from .download_example_images_use_case import (
|
||||||
|
DownloadExampleImagesUseCase,
|
||||||
|
DownloadExampleImagesInProgressError,
|
||||||
|
DownloadExampleImagesConfigurationError,
|
||||||
|
)
|
||||||
|
from .import_example_images_use_case import (
|
||||||
|
ImportExampleImagesUseCase,
|
||||||
|
ImportExampleImagesValidationError,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DownloadExampleImagesUseCase",
|
||||||
|
"DownloadExampleImagesInProgressError",
|
||||||
|
"DownloadExampleImagesConfigurationError",
|
||||||
|
"ImportExampleImagesUseCase",
|
||||||
|
"ImportExampleImagesValidationError",
|
||||||
|
]
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
"""Use case coordinating example image downloads."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from ....utils.example_images_download_manager import (
|
||||||
|
DownloadConfigurationError,
|
||||||
|
DownloadInProgressError,
|
||||||
|
ExampleImagesDownloadError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadExampleImagesInProgressError(RuntimeError):
|
||||||
|
"""Raised when a download is already running."""
|
||||||
|
|
||||||
|
def __init__(self, progress: Dict[str, Any]) -> None:
|
||||||
|
super().__init__("Download already in progress")
|
||||||
|
self.progress = progress
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadExampleImagesConfigurationError(ValueError):
|
||||||
|
"""Raised when settings prevent downloads from starting."""
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadExampleImagesUseCase:
|
||||||
|
"""Validate payloads and trigger the download manager."""
|
||||||
|
|
||||||
|
def __init__(self, *, download_manager) -> None:
|
||||||
|
self._download_manager = download_manager
|
||||||
|
|
||||||
|
async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Start a download and translate manager errors."""
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self._download_manager.start_download(payload)
|
||||||
|
except DownloadInProgressError as exc:
|
||||||
|
raise DownloadExampleImagesInProgressError(exc.progress_snapshot) from exc
|
||||||
|
except DownloadConfigurationError as exc:
|
||||||
|
raise DownloadExampleImagesConfigurationError(str(exc)) from exc
|
||||||
|
except ExampleImagesDownloadError:
|
||||||
|
raise
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
"""Use case for importing example images."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from contextlib import suppress
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from ....utils.example_images_processor import (
|
||||||
|
ExampleImagesImportError,
|
||||||
|
ExampleImagesProcessor,
|
||||||
|
ExampleImagesValidationError,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImportExampleImagesValidationError(ValueError):
|
||||||
|
"""Raised when request validation fails."""
|
||||||
|
|
||||||
|
|
||||||
|
class ImportExampleImagesUseCase:
|
||||||
|
"""Parse upload payloads and delegate to the processor service."""
|
||||||
|
|
||||||
|
def __init__(self, *, processor: ExampleImagesProcessor) -> None:
|
||||||
|
self._processor = processor
|
||||||
|
|
||||||
|
async def execute(self, request: web.Request) -> Dict[str, Any]:
|
||||||
|
model_hash: str | None = None
|
||||||
|
files_to_import: List[str] = []
|
||||||
|
temp_files: List[str] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
if request.content_type and "multipart/form-data" in request.content_type:
|
||||||
|
reader = await request.multipart()
|
||||||
|
|
||||||
|
first_field = await reader.next()
|
||||||
|
if first_field and first_field.name == "model_hash":
|
||||||
|
model_hash = await first_field.text()
|
||||||
|
else:
|
||||||
|
# Support clients that send files first and hash later
|
||||||
|
if first_field is not None:
|
||||||
|
await self._collect_upload_file(first_field, files_to_import, temp_files)
|
||||||
|
|
||||||
|
async for field in reader:
|
||||||
|
if field.name == "model_hash" and not model_hash:
|
||||||
|
model_hash = await field.text()
|
||||||
|
elif field.name == "files":
|
||||||
|
await self._collect_upload_file(field, files_to_import, temp_files)
|
||||||
|
else:
|
||||||
|
data = await request.json()
|
||||||
|
model_hash = data.get("model_hash")
|
||||||
|
files_to_import = list(data.get("file_paths", []))
|
||||||
|
|
||||||
|
result = await self._processor.import_images(model_hash, files_to_import)
|
||||||
|
return result
|
||||||
|
except ExampleImagesValidationError as exc:
|
||||||
|
raise ImportExampleImagesValidationError(str(exc)) from exc
|
||||||
|
except ExampleImagesImportError:
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
for path in temp_files:
|
||||||
|
with suppress(Exception):
|
||||||
|
os.remove(path)
|
||||||
|
|
||||||
|
async def _collect_upload_file(
|
||||||
|
self,
|
||||||
|
field: Any,
|
||||||
|
files_to_import: List[str],
|
||||||
|
temp_files: List[str],
|
||||||
|
) -> None:
|
||||||
|
"""Persist an uploaded file to disk and add it to the import list."""
|
||||||
|
|
||||||
|
filename = field.filename or "upload"
|
||||||
|
file_ext = os.path.splitext(filename)[1].lower()
|
||||||
|
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
|
||||||
|
temp_files.append(tmp_file.name)
|
||||||
|
while True:
|
||||||
|
chunk = await field.read_chunk()
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
tmp_file.write(chunk)
|
||||||
|
|
||||||
|
files_to_import.append(tmp_file.name)
|
||||||
@@ -1,11 +1,29 @@
|
|||||||
from typing import Dict, Any
|
"""Progress callback implementations backed by the shared WebSocket manager."""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Protocol
|
||||||
|
|
||||||
from .model_file_service import ProgressCallback
|
from .model_file_service import ProgressCallback
|
||||||
from .websocket_manager import ws_manager
|
from .websocket_manager import ws_manager
|
||||||
|
|
||||||
|
|
||||||
class WebSocketProgressCallback(ProgressCallback):
|
class ProgressReporter(Protocol):
|
||||||
"""WebSocket implementation of progress callback"""
|
"""Protocol representing an async progress callback."""
|
||||||
|
|
||||||
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
|
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
|
||||||
"""Send progress data via WebSocket"""
|
"""Handle a progress update payload."""
|
||||||
await ws_manager.broadcast_auto_organize_progress(progress_data)
|
|
||||||
|
|
||||||
|
class WebSocketProgressCallback(ProgressCallback):
|
||||||
|
"""WebSocket implementation of progress callback."""
|
||||||
|
|
||||||
|
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
|
||||||
|
"""Send progress data via WebSocket."""
|
||||||
|
await ws_manager.broadcast_auto_organize_progress(progress_data)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketBroadcastCallback:
|
||||||
|
"""Generic WebSocket progress callback broadcasting to all clients."""
|
||||||
|
|
||||||
|
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
|
||||||
|
"""Send the provided payload to all connected clients."""
|
||||||
|
await ws_manager.broadcast(progress_data)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,19 +1,39 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from ..utils.metadata_manager import MetadataManager
|
|
||||||
from ..utils.routes_common import ModelRouteUtils
|
from ..recipes.constants import GEN_PARAM_KEYS
|
||||||
|
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||||
|
from ..services.metadata_sync_service import MetadataSyncService
|
||||||
|
from ..services.preview_asset_service import PreviewAssetService
|
||||||
|
from ..services.settings_manager import settings
|
||||||
|
from ..services.downloader import get_downloader
|
||||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
||||||
from ..utils.exif_utils import ExifUtils
|
from ..utils.exif_utils import ExifUtils
|
||||||
from ..recipes.constants import GEN_PARAM_KEYS
|
from ..utils.metadata_manager import MetadataManager
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_preview_service = PreviewAssetService(
|
||||||
|
metadata_manager=MetadataManager,
|
||||||
|
downloader_factory=get_downloader,
|
||||||
|
exif_utils=ExifUtils,
|
||||||
|
)
|
||||||
|
|
||||||
|
_metadata_sync_service = MetadataSyncService(
|
||||||
|
metadata_manager=MetadataManager,
|
||||||
|
preview_service=_preview_service,
|
||||||
|
settings=settings,
|
||||||
|
default_metadata_provider_factory=get_default_metadata_provider,
|
||||||
|
metadata_provider_selector=get_metadata_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MetadataUpdater:
|
class MetadataUpdater:
|
||||||
"""Handles updating model metadata related to example images"""
|
"""Handles updating model metadata related to example images"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner):
|
async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner, progress: dict | None = None):
|
||||||
"""Refresh model metadata from CivitAI
|
"""Refresh model metadata from CivitAI
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -25,8 +45,6 @@ class MetadataUpdater:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if metadata was successfully refreshed, False otherwise
|
bool: True if metadata was successfully refreshed, False otherwise
|
||||||
"""
|
"""
|
||||||
from ..utils.example_images_download_manager import download_progress
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Find the model in the scanner cache
|
# Find the model in the scanner cache
|
||||||
cache = await scanner.get_cached_data()
|
cache = await scanner.get_cached_data()
|
||||||
@@ -47,17 +65,17 @@ class MetadataUpdater:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Track that we're refreshing this model
|
# Track that we're refreshing this model
|
||||||
download_progress['refreshed_models'].add(model_hash)
|
if progress is not None:
|
||||||
|
progress['refreshed_models'].add(model_hash)
|
||||||
|
|
||||||
# Use ModelRouteUtils to refresh metadata
|
|
||||||
async def update_cache_func(old_path, new_path, metadata):
|
async def update_cache_func(old_path, new_path, metadata):
|
||||||
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
||||||
|
|
||||||
success, error = await ModelRouteUtils.fetch_and_update_model(
|
success, error = await _metadata_sync_service.fetch_and_update_model(
|
||||||
model_hash,
|
sha256=model_hash,
|
||||||
file_path,
|
file_path=file_path,
|
||||||
model_data,
|
model_data=model_data,
|
||||||
update_cache_func
|
update_cache_func=update_cache_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
@@ -66,12 +84,13 @@ class MetadataUpdater:
|
|||||||
else:
|
else:
|
||||||
logger.warning(f"Failed to refresh metadata for {model_name}, {error}")
|
logger.warning(f"Failed to refresh metadata for {model_name}, {error}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error refreshing metadata for {model_name}: {str(e)}"
|
error_msg = f"Error refreshing metadata for {model_name}: {str(e)}"
|
||||||
logger.error(error_msg, exc_info=True)
|
logger.error(error_msg, exc_info=True)
|
||||||
download_progress['errors'].append(error_msg)
|
if progress is not None:
|
||||||
download_progress['last_error'] = error_msg
|
progress['errors'].append(error_msg)
|
||||||
|
progress['last_error'] = error_msg
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import tempfile
|
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
@@ -13,6 +12,14 @@ from ..utils.metadata_manager import MetadataManager
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleImagesImportError(RuntimeError):
|
||||||
|
"""Base error for example image import operations."""
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleImagesValidationError(ExampleImagesImportError):
|
||||||
|
"""Raised when input validation fails."""
|
||||||
|
|
||||||
class ExampleImagesProcessor:
|
class ExampleImagesProcessor:
|
||||||
"""Processes and manipulates example images"""
|
"""Processes and manipulates example images"""
|
||||||
|
|
||||||
@@ -299,90 +306,29 @@ class ExampleImagesProcessor:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def import_images(request):
|
async def import_images(model_hash: str, files_to_import: list[str]):
|
||||||
"""
|
"""Import local example images for a model."""
|
||||||
Import local example images
|
|
||||||
|
if not model_hash:
|
||||||
Accepts:
|
raise ExampleImagesValidationError('Missing model_hash parameter')
|
||||||
- multipart/form-data form with model_hash and files fields
|
|
||||||
or
|
if not files_to_import:
|
||||||
- JSON request with model_hash and file_paths
|
raise ExampleImagesValidationError('No files provided to import')
|
||||||
|
|
||||||
Returns:
|
|
||||||
- Success status and list of imported files
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
model_hash = None
|
|
||||||
files_to_import = []
|
|
||||||
temp_files_to_cleanup = []
|
|
||||||
|
|
||||||
# Check if it's a multipart form-data request (direct file upload)
|
|
||||||
if request.content_type and 'multipart/form-data' in request.content_type:
|
|
||||||
reader = await request.multipart()
|
|
||||||
|
|
||||||
# First get model_hash
|
|
||||||
field = await reader.next()
|
|
||||||
if field.name == 'model_hash':
|
|
||||||
model_hash = await field.text()
|
|
||||||
|
|
||||||
# Then process all files
|
|
||||||
while True:
|
|
||||||
field = await reader.next()
|
|
||||||
if field is None:
|
|
||||||
break
|
|
||||||
|
|
||||||
if field.name == 'files':
|
|
||||||
# Create a temporary file with appropriate suffix for type detection
|
|
||||||
file_name = field.filename
|
|
||||||
file_ext = os.path.splitext(file_name)[1].lower()
|
|
||||||
|
|
||||||
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
|
|
||||||
temp_path = tmp_file.name
|
|
||||||
temp_files_to_cleanup.append(temp_path) # Track for cleanup
|
|
||||||
|
|
||||||
# Write chunks to the temporary file
|
|
||||||
while True:
|
|
||||||
chunk = await field.read_chunk()
|
|
||||||
if not chunk:
|
|
||||||
break
|
|
||||||
tmp_file.write(chunk)
|
|
||||||
|
|
||||||
# Add to the list of files to process
|
|
||||||
files_to_import.append(temp_path)
|
|
||||||
else:
|
|
||||||
# Parse JSON request (legacy method using file paths)
|
|
||||||
data = await request.json()
|
|
||||||
model_hash = data.get('model_hash')
|
|
||||||
files_to_import = data.get('file_paths', [])
|
|
||||||
|
|
||||||
if not model_hash:
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': 'Missing model_hash parameter'
|
|
||||||
}, status=400)
|
|
||||||
|
|
||||||
if not files_to_import:
|
|
||||||
return web.json_response({
|
|
||||||
'success': False,
|
|
||||||
'error': 'No files provided to import'
|
|
||||||
}, status=400)
|
|
||||||
|
|
||||||
# Get example images path
|
# Get example images path
|
||||||
example_images_path = settings.get('example_images_path')
|
example_images_path = settings.get('example_images_path')
|
||||||
if not example_images_path:
|
if not example_images_path:
|
||||||
return web.json_response({
|
raise ExampleImagesValidationError('No example images path configured')
|
||||||
'success': False,
|
|
||||||
'error': 'No example images path configured'
|
|
||||||
}, status=400)
|
|
||||||
|
|
||||||
# Find the model and get current metadata
|
# Find the model and get current metadata
|
||||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||||
|
|
||||||
model_data = None
|
model_data = None
|
||||||
scanner = None
|
scanner = None
|
||||||
|
|
||||||
# Check both scanners to find the model
|
# Check both scanners to find the model
|
||||||
for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]:
|
for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]:
|
||||||
cache = await scan_obj.get_cached_data()
|
cache = await scan_obj.get_cached_data()
|
||||||
@@ -393,21 +339,20 @@ class ExampleImagesProcessor:
|
|||||||
break
|
break
|
||||||
if model_data:
|
if model_data:
|
||||||
break
|
break
|
||||||
|
|
||||||
if not model_data:
|
if not model_data:
|
||||||
return web.json_response({
|
raise ExampleImagesImportError(
|
||||||
'success': False,
|
f"Model with hash {model_hash} not found in cache"
|
||||||
'error': f"Model with hash {model_hash} not found in cache"
|
)
|
||||||
}, status=404)
|
|
||||||
|
|
||||||
# Create model folder
|
# Create model folder
|
||||||
model_folder = os.path.join(example_images_path, model_hash)
|
model_folder = os.path.join(example_images_path, model_hash)
|
||||||
os.makedirs(model_folder, exist_ok=True)
|
os.makedirs(model_folder, exist_ok=True)
|
||||||
|
|
||||||
imported_files = []
|
imported_files = []
|
||||||
errors = []
|
errors = []
|
||||||
newly_imported_paths = []
|
newly_imported_paths = []
|
||||||
|
|
||||||
# Process each file path
|
# Process each file path
|
||||||
for file_path in files_to_import:
|
for file_path in files_to_import:
|
||||||
try:
|
try:
|
||||||
@@ -415,26 +360,26 @@ class ExampleImagesProcessor:
|
|||||||
if not os.path.isfile(file_path):
|
if not os.path.isfile(file_path):
|
||||||
errors.append(f"File not found: {file_path}")
|
errors.append(f"File not found: {file_path}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check if file type is supported
|
# Check if file type is supported
|
||||||
file_ext = os.path.splitext(file_path)[1].lower()
|
file_ext = os.path.splitext(file_path)[1].lower()
|
||||||
if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
|
if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
|
||||||
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']):
|
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']):
|
||||||
errors.append(f"Unsupported file type: {file_path}")
|
errors.append(f"Unsupported file type: {file_path}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Generate new filename using short ID instead of UUID
|
# Generate new filename using short ID instead of UUID
|
||||||
short_id = ExampleImagesProcessor.generate_short_id()
|
short_id = ExampleImagesProcessor.generate_short_id()
|
||||||
new_filename = f"custom_{short_id}{file_ext}"
|
new_filename = f"custom_{short_id}{file_ext}"
|
||||||
|
|
||||||
dest_path = os.path.join(model_folder, new_filename)
|
dest_path = os.path.join(model_folder, new_filename)
|
||||||
|
|
||||||
# Copy the file
|
# Copy the file
|
||||||
import shutil
|
import shutil
|
||||||
shutil.copy2(file_path, dest_path)
|
shutil.copy2(file_path, dest_path)
|
||||||
# Store both the dest_path and the short_id
|
# Store both the dest_path and the short_id
|
||||||
newly_imported_paths.append((dest_path, short_id))
|
newly_imported_paths.append((dest_path, short_id))
|
||||||
|
|
||||||
# Add to imported files list
|
# Add to imported files list
|
||||||
imported_files.append({
|
imported_files.append({
|
||||||
'name': new_filename,
|
'name': new_filename,
|
||||||
@@ -444,39 +389,31 @@ class ExampleImagesProcessor:
|
|||||||
})
|
})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.append(f"Error importing {file_path}: {str(e)}")
|
errors.append(f"Error importing {file_path}: {str(e)}")
|
||||||
|
|
||||||
# Update metadata with new example images
|
# Update metadata with new example images
|
||||||
regular_images, custom_images = await MetadataUpdater.update_metadata_after_import(
|
regular_images, custom_images = await MetadataUpdater.update_metadata_after_import(
|
||||||
model_hash,
|
model_hash,
|
||||||
model_data,
|
model_data,
|
||||||
scanner,
|
scanner,
|
||||||
newly_imported_paths
|
newly_imported_paths
|
||||||
)
|
)
|
||||||
|
|
||||||
return web.json_response({
|
return {
|
||||||
'success': len(imported_files) > 0,
|
'success': len(imported_files) > 0,
|
||||||
'message': f'Successfully imported {len(imported_files)} files' +
|
'message': f'Successfully imported {len(imported_files)} files' +
|
||||||
(f' with {len(errors)} errors' if errors else ''),
|
(f' with {len(errors)} errors' if errors else ''),
|
||||||
'files': imported_files,
|
'files': imported_files,
|
||||||
'errors': errors,
|
'errors': errors,
|
||||||
'regular_images': regular_images,
|
'regular_images': regular_images,
|
||||||
'custom_images': custom_images,
|
'custom_images': custom_images,
|
||||||
"model_file_path": model_data.get('file_path', ''),
|
"model_file_path": model_data.get('file_path', ''),
|
||||||
})
|
}
|
||||||
|
|
||||||
|
except ExampleImagesImportError:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to import example images: {e}", exc_info=True)
|
logger.error(f"Failed to import example images: {e}", exc_info=True)
|
||||||
return web.json_response({
|
raise ExampleImagesImportError(str(e)) from e
|
||||||
'success': False,
|
|
||||||
'error': str(e)
|
|
||||||
}, status=500)
|
|
||||||
finally:
|
|
||||||
# Clean up temporary files
|
|
||||||
for temp_file in temp_files_to_cleanup:
|
|
||||||
try:
|
|
||||||
os.remove(temp_file)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to remove temporary file {temp_file}: {e}")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def delete_custom_image(request):
|
async def delete_custom_image(request):
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import sys
|
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
@@ -12,7 +11,7 @@ from ..config import config
|
|||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
|
||||||
# Check if running in standalone mode
|
# Check if running in standalone mode
|
||||||
standalone_mode = 'nodes' not in sys.modules
|
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||||
|
|
||||||
if not standalone_mode:
|
if not standalone_mode:
|
||||||
from ..metadata_collector.metadata_registry import MetadataRegistry
|
from ..metadata_collector.metadata_registry import MetadataRegistry
|
||||||
|
|||||||
11
pytest.ini
Normal file
11
pytest.ini
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
[pytest]
|
||||||
|
addopts = -v --import-mode=importlib
|
||||||
|
testpaths = tests
|
||||||
|
python_files = test_*.py
|
||||||
|
python_classes = Test*
|
||||||
|
python_functions = test_*
|
||||||
|
# Register async marker for coroutine-style tests
|
||||||
|
markers =
|
||||||
|
asyncio: execute test within asyncio event loop
|
||||||
|
# Skip problematic directories to avoid import conflicts
|
||||||
|
norecursedirs = .git .tox dist build *.egg __pycache__ py
|
||||||
110
refs/civitai_api_model_by_modelId.json
Normal file
110
refs/civitai_api_model_by_modelId.json
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
{
|
||||||
|
"id": 1231067,
|
||||||
|
"name": "Vivid Impressions Storybook Style",
|
||||||
|
"description": "<h3 id=\"if-you'd-like-to-support-me-feel-free-to-visit-my-ko-fi-page.-please-share-your-images-using-the-"+add-post"-button-below.-it-supports-the-creators.-thanks!-nnfwkvfly\">If you'd like to support me, feel free to visit my <a target=\"_blank\" rel=\"ugc\" href=\"https://ko-fi.com/pixelpawsai\">Ko-Fi</a> page. ❤️<br /><br />Please share your images using the \"<span style=\"color:rgb(250, 82, 82)\">+add post</span>\" button below. It supports the creators. Thanks! 💕</h3><h3 id=\"if-you-like-my-lora-please-like-comment-or-donate-some-buzz.-much-appreciated!-vyeqok3go\">If you like my LoRA, please<span style=\"color:rgb(230, 73, 128)\"> </span><span style=\"color:rgb(250, 82, 82)\">like</span>, <span style=\"color:rgb(250, 82, 82)\">comment</span>, or <span style=\"color:#fa5252\">donate some Buzz</span>. Much appreciated! ❤️</h3><h3 id=\"-lo912t8rj\"></h3><h3 id=\"trigger-word:-ppstorybook-wlggllim2\"><strong><span style=\"color:rgb(253, 126, 20)\">Trigger word: </span></strong>ppstorybook</h3><h3 id=\"strength:-0.8-experiment-as-you-like-luvhks6za\"><strong><span style=\"color:rgb(253, 126, 20)\">Strength: </span></strong>0.8, experiment as you like</h3>",
|
||||||
|
"allowNoCredit": true,
|
||||||
|
"allowCommercialUse": [
|
||||||
|
"Image",
|
||||||
|
"RentCivit",
|
||||||
|
"Rent",
|
||||||
|
"Sell"
|
||||||
|
],
|
||||||
|
"allowDerivatives": true,
|
||||||
|
"allowDifferentLicense": true,
|
||||||
|
"type": "LORA",
|
||||||
|
"minor": false,
|
||||||
|
"sfwOnly": false,
|
||||||
|
"poi": false,
|
||||||
|
"nsfw": false,
|
||||||
|
"nsfwLevel": 1,
|
||||||
|
"availability": "Public",
|
||||||
|
"cosmetic": null,
|
||||||
|
"supportsGeneration": true,
|
||||||
|
"stats": {
|
||||||
|
"downloadCount": 2183,
|
||||||
|
"favoriteCount": 0,
|
||||||
|
"thumbsUpCount": 416,
|
||||||
|
"thumbsDownCount": 0,
|
||||||
|
"commentCount": 12,
|
||||||
|
"ratingCount": 0,
|
||||||
|
"rating": 0,
|
||||||
|
"tippedAmountCount": 360
|
||||||
|
},
|
||||||
|
"creator": {
|
||||||
|
"username": "PixelPawsAI",
|
||||||
|
"image": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/f3a1aa7c-0159-4dd8-884a-1e7ceb350f96/width=96/PixelPawsAI.jpeg"
|
||||||
|
},
|
||||||
|
"tags": [
|
||||||
|
"style",
|
||||||
|
"illustration",
|
||||||
|
"storybook"
|
||||||
|
],
|
||||||
|
"modelVersions": [
|
||||||
|
{
|
||||||
|
"id": 1387174,
|
||||||
|
"index": 0,
|
||||||
|
"name": "v1.0",
|
||||||
|
"baseModel": "Flux.1 D",
|
||||||
|
"baseModelType": "Standard",
|
||||||
|
"createdAt": "2025-02-08T11:15:47.197Z",
|
||||||
|
"publishedAt": "2025-02-08T11:29:04.487Z",
|
||||||
|
"status": "Published",
|
||||||
|
"availability": "Public",
|
||||||
|
"nsfwLevel": 1,
|
||||||
|
"trainedWords": [
|
||||||
|
"ppstorybook"
|
||||||
|
],
|
||||||
|
"covered": true,
|
||||||
|
"stats": {
|
||||||
|
"downloadCount": 2183,
|
||||||
|
"ratingCount": 0,
|
||||||
|
"rating": 0,
|
||||||
|
"thumbsUpCount": 416,
|
||||||
|
"thumbsDownCount": 0
|
||||||
|
},
|
||||||
|
"files": [
|
||||||
|
{
|
||||||
|
"id": 1289799,
|
||||||
|
"sizeKB": 18829.1484375,
|
||||||
|
"name": "pp-storybook_rank2_bf16.safetensors",
|
||||||
|
"type": "Model",
|
||||||
|
"pickleScanResult": "Success",
|
||||||
|
"pickleScanMessage": "No Pickle imports",
|
||||||
|
"virusScanResult": "Success",
|
||||||
|
"virusScanMessage": null,
|
||||||
|
"scannedAt": "2025-02-08T11:21:04.247Z",
|
||||||
|
"metadata": {
|
||||||
|
"format": "SafeTensor"
|
||||||
|
},
|
||||||
|
"hashes": {
|
||||||
|
"AutoV1": "F414C813",
|
||||||
|
"AutoV2": "9753338AB6",
|
||||||
|
"SHA256": "9753338AB693CA82BF89ED77A5D1912879E40051463EC6E330FB9866CE798668",
|
||||||
|
"CRC32": "A65AE7B3",
|
||||||
|
"BLAKE3": "A5F8AB95AC2486345E4ACCAE541FF19D97ED53EFB0A7CC9226636975A0437591",
|
||||||
|
"AutoV3": "34A22376739D"
|
||||||
|
},
|
||||||
|
"downloadUrl": "https://civitai.com/api/download/models/1387174",
|
||||||
|
"primary": true
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"images": [
|
||||||
|
{
|
||||||
|
"url": "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/42b875cf-c62b-41fa-a349-383b7f074351/original=true/56547310.jpeg",
|
||||||
|
"nsfwLevel": 1,
|
||||||
|
"width": 832,
|
||||||
|
"height": 1216,
|
||||||
|
"hash": "U5IiO6s-4Vn+0~EO^5xa00VsL#IU_O?E7yWC",
|
||||||
|
"type": "image",
|
||||||
|
"minor": false,
|
||||||
|
"poi": false,
|
||||||
|
"hasMeta": true,
|
||||||
|
"hasPositivePrompt": true,
|
||||||
|
"onSite": false,
|
||||||
|
"remixOfId": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"downloadUrl": "https://civitai.com/api/download/models/1387174"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -4,6 +4,9 @@ import sys
|
|||||||
import json
|
import json
|
||||||
from py.middleware.cache_middleware import cache_control
|
from py.middleware.cache_middleware import cache_control
|
||||||
|
|
||||||
|
# Set environment variable to indicate standalone mode
|
||||||
|
os.environ["COMFYUI_LORA_MANAGER_STANDALONE"] = "1"
|
||||||
|
|
||||||
# Create mock modules for py/nodes directory - add this before any other imports
|
# Create mock modules for py/nodes directory - add this before any other imports
|
||||||
def mock_nodes_directory():
|
def mock_nodes_directory():
|
||||||
"""Create mock modules for all Python files in the py/nodes directory"""
|
"""Create mock modules for all Python files in the py/nodes directory"""
|
||||||
@@ -418,7 +421,7 @@ class StandaloneLoraManager(LoraManager):
|
|||||||
RecipeRoutes.setup_routes(app)
|
RecipeRoutes.setup_routes(app)
|
||||||
UpdateRoutes.setup_routes(app)
|
UpdateRoutes.setup_routes(app)
|
||||||
MiscRoutes.setup_routes(app)
|
MiscRoutes.setup_routes(app)
|
||||||
ExampleImagesRoutes.setup_routes(app)
|
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
|
||||||
|
|
||||||
# Setup WebSocket routes that are shared across all model types
|
# Setup WebSocket routes that are shared across all model types
|
||||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||||
|
|||||||
@@ -945,7 +945,7 @@ export class BaseModelApiClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Determine optimize setting
|
// Determine optimize setting
|
||||||
const optimize = state.global?.settings?.optimizeExampleImages ?? true;
|
const optimize = state.global?.settings?.optimize_example_images ?? true;
|
||||||
|
|
||||||
// Make the API request to start the download process
|
// Make the API request to start the download process
|
||||||
const response = await fetch(DOWNLOAD_ENDPOINTS.exampleImages, {
|
const response = await fetch(DOWNLOAD_ENDPOINTS.exampleImages, {
|
||||||
|
|||||||
104
static/js/components/ContextMenu/GlobalContextMenu.js
Normal file
104
static/js/components/ContextMenu/GlobalContextMenu.js
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
import { BaseContextMenu } from './BaseContextMenu.js';
|
||||||
|
import { showToast } from '../../utils/uiHelpers.js';
|
||||||
|
import { state } from '../../state/index.js';
|
||||||
|
|
||||||
|
export class GlobalContextMenu extends BaseContextMenu {
|
||||||
|
constructor() {
|
||||||
|
super('globalContextMenu');
|
||||||
|
this._cleanupInProgress = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
showMenu(x, y, origin = null) {
|
||||||
|
const contextOrigin = origin || { type: 'global' };
|
||||||
|
super.showMenu(x, y, contextOrigin);
|
||||||
|
}
|
||||||
|
|
||||||
|
handleMenuAction(action, menuItem) {
|
||||||
|
switch (action) {
|
||||||
|
case 'cleanup-example-images-folders':
|
||||||
|
this.cleanupExampleImagesFolders(menuItem).catch((error) => {
|
||||||
|
console.error('Failed to trigger example images cleanup:', error);
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
case 'download-example-images':
|
||||||
|
this.downloadExampleImages(menuItem).catch((error) => {
|
||||||
|
console.error('Failed to trigger example images download:', error);
|
||||||
|
});
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
console.warn(`Unhandled global context menu action: ${action}`);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async downloadExampleImages(menuItem) {
|
||||||
|
const exampleImagesManager = window.exampleImagesManager;
|
||||||
|
|
||||||
|
if (!exampleImagesManager) {
|
||||||
|
showToast('globalContextMenu.downloadExampleImages.unavailable', {}, 'error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const downloadPath = state?.global?.settings?.example_images_path;
|
||||||
|
if (!downloadPath) {
|
||||||
|
showToast('globalContextMenu.downloadExampleImages.missingPath', {}, 'warning');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
menuItem?.classList.add('disabled');
|
||||||
|
|
||||||
|
try {
|
||||||
|
await exampleImagesManager.handleDownloadButton();
|
||||||
|
} finally {
|
||||||
|
menuItem?.classList.remove('disabled');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async cleanupExampleImagesFolders(menuItem) {
|
||||||
|
if (this._cleanupInProgress) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
this._cleanupInProgress = true;
|
||||||
|
menuItem?.classList.add('disabled');
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await fetch('/api/lm/cleanup-example-image-folders', {
|
||||||
|
method: 'POST',
|
||||||
|
});
|
||||||
|
|
||||||
|
let payload;
|
||||||
|
try {
|
||||||
|
payload = await response.json();
|
||||||
|
} catch (parseError) {
|
||||||
|
payload = { error: 'Unexpected response format.' };
|
||||||
|
}
|
||||||
|
|
||||||
|
if (response.ok && (payload.success || payload.partial_success)) {
|
||||||
|
const movedTotal = payload.moved_total || 0;
|
||||||
|
|
||||||
|
if (movedTotal > 0) {
|
||||||
|
showToast('globalContextMenu.cleanupExampleImages.success', { count: movedTotal }, 'success');
|
||||||
|
} else {
|
||||||
|
showToast('globalContextMenu.cleanupExampleImages.none', {}, 'info');
|
||||||
|
}
|
||||||
|
|
||||||
|
if (payload.partial_success) {
|
||||||
|
showToast(
|
||||||
|
'globalContextMenu.cleanupExampleImages.partial',
|
||||||
|
{ failures: payload.move_failures ?? 0 },
|
||||||
|
'warning',
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const message = payload?.error || 'Unknown error';
|
||||||
|
showToast('globalContextMenu.cleanupExampleImages.error', { message }, 'error');
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
showToast('globalContextMenu.cleanupExampleImages.error', { message: error.message || 'Unknown error' }, 'error');
|
||||||
|
} finally {
|
||||||
|
this._cleanupInProgress = false;
|
||||||
|
menuItem?.classList.remove('disabled');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import { BaseContextMenu } from './BaseContextMenu.js';
|
import { BaseContextMenu } from './BaseContextMenu.js';
|
||||||
import { ModelContextMenuMixin } from './ModelContextMenuMixin.js';
|
import { ModelContextMenuMixin } from './ModelContextMenuMixin.js';
|
||||||
import { getModelApiClient, resetAndReload } from '../../api/modelApiFactory.js';
|
import { getModelApiClient, resetAndReload } from '../../api/modelApiFactory.js';
|
||||||
import { copyLoraSyntax, sendLoraToWorkflow } from '../../utils/uiHelpers.js';
|
import { copyLoraSyntax, sendLoraToWorkflow, buildLoraSyntax } from '../../utils/uiHelpers.js';
|
||||||
import { showExcludeModal, showDeleteModal } from '../../utils/modalUtils.js';
|
import { showExcludeModal, showDeleteModal } from '../../utils/modalUtils.js';
|
||||||
import { moveManager } from '../../managers/MoveManager.js';
|
import { moveManager } from '../../managers/MoveManager.js';
|
||||||
|
|
||||||
@@ -70,9 +70,8 @@ export class LoraContextMenu extends BaseContextMenu {
|
|||||||
sendLoraToWorkflow(replaceMode) {
|
sendLoraToWorkflow(replaceMode) {
|
||||||
const card = this.currentCard;
|
const card = this.currentCard;
|
||||||
const usageTips = JSON.parse(card.dataset.usage_tips || '{}');
|
const usageTips = JSON.parse(card.dataset.usage_tips || '{}');
|
||||||
const strength = usageTips.strength || 1;
|
const loraSyntax = buildLoraSyntax(card.dataset.file_name, usageTips);
|
||||||
const loraSyntax = `<lora:${card.dataset.file_name}:${strength}>`;
|
|
||||||
|
|
||||||
sendLoraToWorkflow(loraSyntax, replaceMode, 'lora');
|
sendLoraToWorkflow(loraSyntax, replaceMode, 'lora');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,12 +2,14 @@ export { LoraContextMenu } from './LoraContextMenu.js';
|
|||||||
export { RecipeContextMenu } from './RecipeContextMenu.js';
|
export { RecipeContextMenu } from './RecipeContextMenu.js';
|
||||||
export { CheckpointContextMenu } from './CheckpointContextMenu.js';
|
export { CheckpointContextMenu } from './CheckpointContextMenu.js';
|
||||||
export { EmbeddingContextMenu } from './EmbeddingContextMenu.js';
|
export { EmbeddingContextMenu } from './EmbeddingContextMenu.js';
|
||||||
|
export { GlobalContextMenu } from './GlobalContextMenu.js';
|
||||||
export { ModelContextMenuMixin } from './ModelContextMenuMixin.js';
|
export { ModelContextMenuMixin } from './ModelContextMenuMixin.js';
|
||||||
|
|
||||||
import { LoraContextMenu } from './LoraContextMenu.js';
|
import { LoraContextMenu } from './LoraContextMenu.js';
|
||||||
import { RecipeContextMenu } from './RecipeContextMenu.js';
|
import { RecipeContextMenu } from './RecipeContextMenu.js';
|
||||||
import { CheckpointContextMenu } from './CheckpointContextMenu.js';
|
import { CheckpointContextMenu } from './CheckpointContextMenu.js';
|
||||||
import { EmbeddingContextMenu } from './EmbeddingContextMenu.js';
|
import { EmbeddingContextMenu } from './EmbeddingContextMenu.js';
|
||||||
|
import { GlobalContextMenu } from './GlobalContextMenu.js';
|
||||||
|
|
||||||
// Factory method to create page-specific context menu instances
|
// Factory method to create page-specific context menu instances
|
||||||
export function createPageContextMenu(pageType) {
|
export function createPageContextMenu(pageType) {
|
||||||
@@ -23,4 +25,8 @@ export function createPageContextMenu(pageType) {
|
|||||||
default:
|
default:
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export function createGlobalContextMenu() {
|
||||||
|
return new GlobalContextMenu();
|
||||||
}
|
}
|
||||||
@@ -46,7 +46,7 @@ class RecipeCard {
|
|||||||
|
|
||||||
// NSFW blur logic - similar to LoraCard
|
// NSFW blur logic - similar to LoraCard
|
||||||
const nsfwLevel = this.recipe.preview_nsfw_level !== undefined ? this.recipe.preview_nsfw_level : 0;
|
const nsfwLevel = this.recipe.preview_nsfw_level !== undefined ? this.recipe.preview_nsfw_level : 0;
|
||||||
const shouldBlur = state.settings.blurMatureContent && nsfwLevel > NSFW_LEVELS.PG13;
|
const shouldBlur = state.settings.blur_mature_content && nsfwLevel > NSFW_LEVELS.PG13;
|
||||||
|
|
||||||
if (shouldBlur) {
|
if (shouldBlur) {
|
||||||
card.classList.add('nsfw-content');
|
card.classList.add('nsfw-content');
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { showToast, openCivitai, copyToClipboard, copyLoraSyntax, sendLoraToWorkflow, openExampleImagesFolder } from '../../utils/uiHelpers.js';
|
import { showToast, openCivitai, copyToClipboard, copyLoraSyntax, sendLoraToWorkflow, openExampleImagesFolder, buildLoraSyntax } from '../../utils/uiHelpers.js';
|
||||||
import { state, getCurrentPageState } from '../../state/index.js';
|
import { state, getCurrentPageState } from '../../state/index.js';
|
||||||
import { showModelModal } from './ModelModal.js';
|
import { showModelModal } from './ModelModal.js';
|
||||||
import { toggleShowcase } from './showcase/ShowcaseView.js';
|
import { toggleShowcase } from './showcase/ShowcaseView.js';
|
||||||
@@ -155,8 +155,7 @@ async function toggleFavorite(card) {
|
|||||||
function handleSendToWorkflow(card, replaceMode, modelType) {
|
function handleSendToWorkflow(card, replaceMode, modelType) {
|
||||||
if (modelType === MODEL_TYPES.LORA) {
|
if (modelType === MODEL_TYPES.LORA) {
|
||||||
const usageTips = JSON.parse(card.dataset.usage_tips || '{}');
|
const usageTips = JSON.parse(card.dataset.usage_tips || '{}');
|
||||||
const strength = usageTips.strength || 1;
|
const loraSyntax = buildLoraSyntax(card.dataset.file_name, usageTips);
|
||||||
const loraSyntax = `<lora:${card.dataset.file_name}:${strength}>`;
|
|
||||||
sendLoraToWorkflow(loraSyntax, replaceMode, 'lora');
|
sendLoraToWorkflow(loraSyntax, replaceMode, 'lora');
|
||||||
} else {
|
} else {
|
||||||
// Checkpoint send functionality - to be implemented
|
// Checkpoint send functionality - to be implemented
|
||||||
@@ -406,7 +405,7 @@ export function createModelCard(model, modelType) {
|
|||||||
card.dataset.nsfwLevel = nsfwLevel;
|
card.dataset.nsfwLevel = nsfwLevel;
|
||||||
|
|
||||||
// Determine if the preview should be blurred based on NSFW level and user settings
|
// Determine if the preview should be blurred based on NSFW level and user settings
|
||||||
const shouldBlur = state.settings.blurMatureContent && nsfwLevel > NSFW_LEVELS.PG13;
|
const shouldBlur = state.settings.blur_mature_content && nsfwLevel > NSFW_LEVELS.PG13;
|
||||||
if (shouldBlur) {
|
if (shouldBlur) {
|
||||||
card.classList.add('nsfw-content');
|
card.classList.add('nsfw-content');
|
||||||
}
|
}
|
||||||
@@ -434,7 +433,7 @@ export function createModelCard(model, modelType) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if autoplayOnHover is enabled for video previews
|
// Check if autoplayOnHover is enabled for video previews
|
||||||
const autoplayOnHover = state.global?.settings?.autoplayOnHover || false;
|
const autoplayOnHover = state.global?.settings?.autoplay_on_hover || false;
|
||||||
const isVideo = previewUrl.endsWith('.mp4');
|
const isVideo = previewUrl.endsWith('.mp4');
|
||||||
const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop';
|
const videoAttrs = autoplayOnHover ? 'controls muted loop' : 'controls autoplay muted loop';
|
||||||
|
|
||||||
|
|||||||
@@ -271,6 +271,7 @@ function renderLoraSpecificContent(lora, escapedWords) {
|
|||||||
<option value="strength_min">${translate('modals.model.usageTips.strengthMin', {}, 'Strength Min')}</option>
|
<option value="strength_min">${translate('modals.model.usageTips.strengthMin', {}, 'Strength Min')}</option>
|
||||||
<option value="strength_max">${translate('modals.model.usageTips.strengthMax', {}, 'Strength Max')}</option>
|
<option value="strength_max">${translate('modals.model.usageTips.strengthMax', {}, 'Strength Max')}</option>
|
||||||
<option value="strength">${translate('modals.model.usageTips.strength', {}, 'Strength')}</option>
|
<option value="strength">${translate('modals.model.usageTips.strength', {}, 'Strength')}</option>
|
||||||
|
<option value="clip_strength">${translate('modals.model.usageTips.clipStrength', {}, 'Clip Strength')}</option>
|
||||||
<option value="clip_skip">${translate('modals.model.usageTips.clipSkip', {}, 'Clip Skip')}</option>
|
<option value="clip_skip">${translate('modals.model.usageTips.clipSkip', {}, 'Clip Skip')}</option>
|
||||||
</select>
|
</select>
|
||||||
<input type="number" id="preset-value" step="0.01" placeholder="${translate('modals.model.usageTips.valuePlaceholder', {}, 'Value')}" style="display:none;">
|
<input type="number" id="preset-value" step="0.01" placeholder="${translate('modals.model.usageTips.valuePlaceholder', {}, 'Value')}" style="display:none;">
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ function renderMediaItem(img, index, exampleFiles) {
|
|||||||
|
|
||||||
// Check if media should be blurred
|
// Check if media should be blurred
|
||||||
const nsfwLevel = img.nsfwLevel !== undefined ? img.nsfwLevel : 0;
|
const nsfwLevel = img.nsfwLevel !== undefined ? img.nsfwLevel : 0;
|
||||||
const shouldBlur = state.settings.blurMatureContent && nsfwLevel > NSFW_LEVELS.PG13;
|
const shouldBlur = state.settings.blur_mature_content && nsfwLevel > NSFW_LEVELS.PG13;
|
||||||
|
|
||||||
// Determine NSFW warning text based on level
|
// Determine NSFW warning text based on level
|
||||||
let nsfwText = "Mature Content";
|
let nsfwText = "Mature Content";
|
||||||
|
|||||||
@@ -12,11 +12,10 @@ import { helpManager } from './managers/HelpManager.js';
|
|||||||
import { bannerService } from './managers/BannerService.js';
|
import { bannerService } from './managers/BannerService.js';
|
||||||
import { initTheme, initBackToTop } from './utils/uiHelpers.js';
|
import { initTheme, initBackToTop } from './utils/uiHelpers.js';
|
||||||
import { initializeInfiniteScroll } from './utils/infiniteScroll.js';
|
import { initializeInfiniteScroll } from './utils/infiniteScroll.js';
|
||||||
import { migrateStorageItems } from './utils/storageHelpers.js';
|
|
||||||
import { i18n } from './i18n/index.js';
|
import { i18n } from './i18n/index.js';
|
||||||
import { onboardingManager } from './managers/OnboardingManager.js';
|
import { onboardingManager } from './managers/OnboardingManager.js';
|
||||||
import { BulkContextMenu } from './components/ContextMenu/BulkContextMenu.js';
|
import { BulkContextMenu } from './components/ContextMenu/BulkContextMenu.js';
|
||||||
import { createPageContextMenu } from './components/ContextMenu/index.js';
|
import { createPageContextMenu, createGlobalContextMenu } from './components/ContextMenu/index.js';
|
||||||
import { initializeEventManagement } from './utils/eventManagementInit.js';
|
import { initializeEventManagement } from './utils/eventManagementInit.js';
|
||||||
|
|
||||||
// Core application class
|
// Core application class
|
||||||
@@ -74,7 +73,7 @@ export class AppCore {
|
|||||||
// Initialize the help manager
|
// Initialize the help manager
|
||||||
helpManager.initialize();
|
helpManager.initialize();
|
||||||
|
|
||||||
const cardInfoDisplay = state.global.settings.cardInfoDisplay || 'always';
|
const cardInfoDisplay = state.global.settings.card_info_display || 'always';
|
||||||
document.body.classList.toggle('hover-reveal', cardInfoDisplay === 'hover');
|
document.body.classList.toggle('hover-reveal', cardInfoDisplay === 'hover');
|
||||||
|
|
||||||
initializeEventManagement();
|
initializeEventManagement();
|
||||||
@@ -116,13 +115,12 @@ export class AppCore {
|
|||||||
initializeContextMenus(pageType) {
|
initializeContextMenus(pageType) {
|
||||||
// Create page-specific context menu
|
// Create page-specific context menu
|
||||||
window.pageContextMenu = createPageContextMenu(pageType);
|
window.pageContextMenu = createPageContextMenu(pageType);
|
||||||
|
|
||||||
|
if (!window.globalContextMenuInstance) {
|
||||||
|
window.globalContextMenuInstance = createGlobalContextMenu();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
document.addEventListener('DOMContentLoaded', () => {
|
|
||||||
// Migrate localStorage items to use the namespace prefix
|
|
||||||
migrateStorageItems();
|
|
||||||
});
|
|
||||||
|
|
||||||
// Create and export a singleton instance
|
// Create and export a singleton instance
|
||||||
export const appCore = new AppCore();
|
export const appCore = new AppCore();
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import { state } from '../state/index.js';
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Internationalization (i18n) system for LoRA Manager
|
* Internationalization (i18n) system for LoRA Manager
|
||||||
* Uses user-selected language from settings with fallback to English
|
* Uses user-selected language from settings with fallback to English
|
||||||
@@ -123,26 +125,12 @@ class I18nManager {
|
|||||||
* @returns {string} Language code
|
* @returns {string} Language code
|
||||||
*/
|
*/
|
||||||
getLanguageFromSettings() {
|
getLanguageFromSettings() {
|
||||||
// Check localStorage for user-selected language
|
const language = state?.global?.settings?.language;
|
||||||
const STORAGE_PREFIX = 'lora_manager_';
|
|
||||||
let userLanguage = null;
|
if (language && this.availableLocales[language]) {
|
||||||
|
return language;
|
||||||
try {
|
|
||||||
const settings = localStorage.getItem(STORAGE_PREFIX + 'settings');
|
|
||||||
if (settings) {
|
|
||||||
const parsedSettings = JSON.parse(settings);
|
|
||||||
userLanguage = parsedSettings.language;
|
|
||||||
}
|
|
||||||
} catch (e) {
|
|
||||||
console.warn('Failed to parse settings from localStorage:', e);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If user has selected a language, use it
|
|
||||||
if (userLanguage && this.availableLocales[userLanguage]) {
|
|
||||||
return userLanguage;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to English
|
|
||||||
return 'en';
|
return 'en';
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -165,18 +153,10 @@ class I18nManager {
|
|||||||
this.readyPromise = this.initializeWithLocale(languageCode);
|
this.readyPromise = this.initializeWithLocale(languageCode);
|
||||||
await this.readyPromise;
|
await this.readyPromise;
|
||||||
|
|
||||||
// Save to localStorage
|
if (state?.global?.settings) {
|
||||||
const STORAGE_PREFIX = 'lora_manager_';
|
state.global.settings.language = languageCode;
|
||||||
const currentSettings = localStorage.getItem(STORAGE_PREFIX + 'settings');
|
|
||||||
let settings = {};
|
|
||||||
|
|
||||||
if (currentSettings) {
|
|
||||||
settings = JSON.parse(currentSettings);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
settings.language = languageCode;
|
|
||||||
localStorage.setItem(STORAGE_PREFIX + 'settings', JSON.stringify(settings));
|
|
||||||
|
|
||||||
console.log(`Language changed to: ${languageCode}`);
|
console.log(`Language changed to: ${languageCode}`);
|
||||||
|
|
||||||
// Dispatch event to notify components of language change
|
// Dispatch event to notify components of language change
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { state, getCurrentPageState } from '../state/index.js';
|
import { state, getCurrentPageState } from '../state/index.js';
|
||||||
import { showToast, copyToClipboard, sendLoraToWorkflow } from '../utils/uiHelpers.js';
|
import { showToast, copyToClipboard, sendLoraToWorkflow, buildLoraSyntax } from '../utils/uiHelpers.js';
|
||||||
import { updateCardsForBulkMode } from '../components/shared/ModelCard.js';
|
import { updateCardsForBulkMode } from '../components/shared/ModelCard.js';
|
||||||
import { modalManager } from './ModalManager.js';
|
import { modalManager } from './ModalManager.js';
|
||||||
import { getModelApiClient, resetAndReload } from '../api/modelApiFactory.js';
|
import { getModelApiClient, resetAndReload } from '../api/modelApiFactory.js';
|
||||||
@@ -321,8 +321,7 @@ export class BulkManager {
|
|||||||
|
|
||||||
if (metadata) {
|
if (metadata) {
|
||||||
const usageTips = JSON.parse(metadata.usageTips || '{}');
|
const usageTips = JSON.parse(metadata.usageTips || '{}');
|
||||||
const strength = usageTips.strength || 1;
|
loraSyntaxes.push(buildLoraSyntax(metadata.fileName, usageTips));
|
||||||
loraSyntaxes.push(`<lora:${metadata.fileName}:${strength}>`);
|
|
||||||
} else {
|
} else {
|
||||||
missingLoras.push(filepath);
|
missingLoras.push(filepath);
|
||||||
}
|
}
|
||||||
@@ -361,8 +360,7 @@ export class BulkManager {
|
|||||||
|
|
||||||
if (metadata) {
|
if (metadata) {
|
||||||
const usageTips = JSON.parse(metadata.usageTips || '{}');
|
const usageTips = JSON.parse(metadata.usageTips || '{}');
|
||||||
const strength = usageTips.strength || 1;
|
loraSyntaxes.push(buildLoraSyntax(metadata.fileName, usageTips));
|
||||||
loraSyntaxes.push(`<lora:${metadata.fileName}:${strength}>`);
|
|
||||||
} else {
|
} else {
|
||||||
missingLoras.push(filepath);
|
missingLoras.push(filepath);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ export class ExampleImagesManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Setup auto download if enabled
|
// Setup auto download if enabled
|
||||||
if (state.global.settings.autoDownloadExampleImages) {
|
if (state.global.settings.auto_download_example_images) {
|
||||||
this.setupAutoDownload();
|
this.setupAutoDownload();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,7 +106,7 @@ export class ExampleImagesManager {
|
|||||||
showToast('toast.exampleImages.pathUpdateFailed', { message: error.message }, 'error');
|
showToast('toast.exampleImages.pathUpdateFailed', { message: error.message }, 'error');
|
||||||
}
|
}
|
||||||
// Setup or clear auto download based on path availability
|
// Setup or clear auto download based on path availability
|
||||||
if (state.global.settings.autoDownloadExampleImages) {
|
if (state.global.settings.auto_download_example_images) {
|
||||||
if (hasPath) {
|
if (hasPath) {
|
||||||
this.setupAutoDownload();
|
this.setupAutoDownload();
|
||||||
} else {
|
} else {
|
||||||
@@ -225,7 +225,7 @@ export class ExampleImagesManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const optimize = state.global.settings.optimizeExampleImages;
|
const optimize = state.global.settings.optimize_example_images;
|
||||||
|
|
||||||
const response = await fetch('/api/lm/download-example-images', {
|
const response = await fetch('/api/lm/download-example-images', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
@@ -677,7 +677,7 @@ export class ExampleImagesManager {
|
|||||||
|
|
||||||
canAutoDownload() {
|
canAutoDownload() {
|
||||||
// Check if auto download is enabled
|
// Check if auto download is enabled
|
||||||
if (!state.global.settings.autoDownloadExampleImages) {
|
if (!state.global.settings.auto_download_example_images) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -713,7 +713,7 @@ export class ExampleImagesManager {
|
|||||||
try {
|
try {
|
||||||
console.log('Performing auto download check...');
|
console.log('Performing auto download check...');
|
||||||
|
|
||||||
const optimize = state.global.settings.optimizeExampleImages;
|
const optimize = state.global.settings.optimize_example_images;
|
||||||
|
|
||||||
const response = await fetch('/api/lm/download-example-images', {
|
const response = await fetch('/api/lm/download-example-images', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
|
|||||||
@@ -182,9 +182,6 @@ export class OnboardingManager {
|
|||||||
// Update state
|
// Update state
|
||||||
state.global.settings.language = languageCode;
|
state.global.settings.language = languageCode;
|
||||||
|
|
||||||
// Save to localStorage
|
|
||||||
setStorageItem('settings', state.global.settings);
|
|
||||||
|
|
||||||
// Save to backend
|
// Save to backend
|
||||||
const response = await fetch('/api/lm/settings', {
|
const response = await fetch('/api/lm/settings', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import { modalManager } from './ModalManager.js';
|
import { modalManager } from './ModalManager.js';
|
||||||
import { showToast } from '../utils/uiHelpers.js';
|
import { showToast } from '../utils/uiHelpers.js';
|
||||||
import { state } from '../state/index.js';
|
import { state, createDefaultSettings } from '../state/index.js';
|
||||||
import { resetAndReload } from '../api/modelApiFactory.js';
|
import { resetAndReload } from '../api/modelApiFactory.js';
|
||||||
import { setStorageItem, getStorageItem } from '../utils/storageHelpers.js';
|
|
||||||
import { DOWNLOAD_PATH_TEMPLATES, MAPPABLE_BASE_MODELS, PATH_TEMPLATE_PLACEHOLDERS, DEFAULT_PATH_TEMPLATES } from '../utils/constants.js';
|
import { DOWNLOAD_PATH_TEMPLATES, MAPPABLE_BASE_MODELS, PATH_TEMPLATE_PLACEHOLDERS, DEFAULT_PATH_TEMPLATES } from '../utils/constants.js';
|
||||||
import { translate } from '../utils/i18nHelpers.js';
|
import { translate } from '../utils/i18nHelpers.js';
|
||||||
|
import { i18n } from '../i18n/index.js';
|
||||||
|
|
||||||
export class SettingsManager {
|
export class SettingsManager {
|
||||||
constructor() {
|
constructor() {
|
||||||
@@ -14,7 +14,9 @@ export class SettingsManager {
|
|||||||
|
|
||||||
// Add initialization to sync with modal state
|
// Add initialization to sync with modal state
|
||||||
this.currentPage = document.body.dataset.page || 'loras';
|
this.currentPage = document.body.dataset.page || 'loras';
|
||||||
|
|
||||||
|
this.backendSettingKeys = new Set(Object.keys(createDefaultSettings()));
|
||||||
|
|
||||||
// Start initialization but don't await here to avoid blocking constructor
|
// Start initialization but don't await here to avoid blocking constructor
|
||||||
this.initializationPromise = this.initializeSettings();
|
this.initializationPromise = this.initializeSettings();
|
||||||
|
|
||||||
@@ -29,177 +31,91 @@ export class SettingsManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async initializeSettings() {
|
async initializeSettings() {
|
||||||
// Load frontend-only settings from localStorage
|
// Reset to defaults before syncing
|
||||||
this.loadFrontendSettingsFromStorage();
|
state.global.settings = createDefaultSettings();
|
||||||
|
|
||||||
// Sync settings from backend to frontend
|
// Sync settings from backend to frontend
|
||||||
await this.syncSettingsFromBackend();
|
await this.syncSettingsFromBackend();
|
||||||
}
|
}
|
||||||
|
|
||||||
loadFrontendSettingsFromStorage() {
|
|
||||||
// Get saved settings from localStorage
|
|
||||||
const savedSettings = getStorageItem('settings');
|
|
||||||
|
|
||||||
// Frontend-only settings that should be stored in localStorage
|
|
||||||
const frontendOnlyKeys = [
|
|
||||||
'blurMatureContent',
|
|
||||||
'autoplayOnHover',
|
|
||||||
'displayDensity',
|
|
||||||
'cardInfoDisplay',
|
|
||||||
'includeTriggerWords'
|
|
||||||
];
|
|
||||||
|
|
||||||
// Apply saved frontend settings to state if available
|
|
||||||
if (savedSettings) {
|
|
||||||
const frontendSettings = {};
|
|
||||||
frontendOnlyKeys.forEach(key => {
|
|
||||||
if (savedSettings[key] !== undefined) {
|
|
||||||
frontendSettings[key] = savedSettings[key];
|
|
||||||
}
|
|
||||||
});
|
|
||||||
state.global.settings = { ...state.global.settings, ...frontendSettings };
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize default values for frontend settings if they don't exist
|
|
||||||
if (state.global.settings.blurMatureContent === undefined) {
|
|
||||||
state.global.settings.blurMatureContent = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (state.global.settings.show_only_sfw === undefined) {
|
|
||||||
state.global.settings.show_only_sfw = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (state.global.settings.autoplayOnHover === undefined) {
|
|
||||||
state.global.settings.autoplayOnHover = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (state.global.settings.cardInfoDisplay === undefined) {
|
|
||||||
state.global.settings.cardInfoDisplay = 'always';
|
|
||||||
}
|
|
||||||
|
|
||||||
if (state.global.settings.displayDensity === undefined) {
|
|
||||||
// Migrate legacy compactMode if it exists
|
|
||||||
if (state.global.settings.compactMode === true) {
|
|
||||||
state.global.settings.displayDensity = 'compact';
|
|
||||||
} else {
|
|
||||||
state.global.settings.displayDensity = 'default';
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (state.global.settings.includeTriggerWords === undefined) {
|
|
||||||
state.global.settings.includeTriggerWords = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save updated frontend settings to localStorage
|
|
||||||
this.saveFrontendSettingsToStorage();
|
|
||||||
}
|
|
||||||
|
|
||||||
async syncSettingsFromBackend() {
|
async syncSettingsFromBackend() {
|
||||||
try {
|
try {
|
||||||
const response = await fetch('/api/lm/settings');
|
const response = await fetch('/api/lm/settings');
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error(`HTTP error! status: ${response.status}`);
|
throw new Error(`HTTP error! status: ${response.status}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
const data = await response.json();
|
const data = await response.json();
|
||||||
if (data.success && data.settings) {
|
if (data.success && data.settings) {
|
||||||
// Merge backend settings with current state
|
state.global.settings = this.mergeSettingsWithDefaults(data.settings);
|
||||||
state.global.settings = { ...state.global.settings, ...data.settings };
|
|
||||||
|
|
||||||
// Set defaults for backend settings if they're null/undefined
|
|
||||||
this.setBackendSettingDefaults();
|
|
||||||
|
|
||||||
console.log('Settings synced from backend');
|
console.log('Settings synced from backend');
|
||||||
} else {
|
} else {
|
||||||
console.error('Failed to sync settings from backend:', data.error);
|
console.error('Failed to sync settings from backend:', data.error);
|
||||||
|
state.global.settings = this.mergeSettingsWithDefaults();
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Failed to sync settings from backend:', error);
|
console.error('Failed to sync settings from backend:', error);
|
||||||
// Set defaults if backend sync fails
|
state.global.settings = this.mergeSettingsWithDefaults();
|
||||||
this.setBackendSettingDefaults();
|
}
|
||||||
|
|
||||||
|
await this.applyLanguageSetting();
|
||||||
|
this.applyFrontendSettings();
|
||||||
|
}
|
||||||
|
|
||||||
|
async applyLanguageSetting() {
|
||||||
|
const desiredLanguage = state?.global?.settings?.language;
|
||||||
|
|
||||||
|
if (!desiredLanguage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (i18n.getCurrentLocale() !== desiredLanguage) {
|
||||||
|
await i18n.setLanguage(desiredLanguage);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.warn('Failed to apply language from settings:', error);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
setBackendSettingDefaults() {
|
mergeSettingsWithDefaults(backendSettings = {}) {
|
||||||
// Set defaults for backend settings
|
const defaults = createDefaultSettings();
|
||||||
const backendDefaults = {
|
const merged = { ...defaults, ...backendSettings };
|
||||||
civitai_api_key: '',
|
|
||||||
default_lora_root: '',
|
|
||||||
default_checkpoint_root: '',
|
|
||||||
default_embedding_root: '',
|
|
||||||
base_model_path_mappings: {},
|
|
||||||
download_path_templates: { ...DEFAULT_PATH_TEMPLATES },
|
|
||||||
enable_metadata_archive_db: false,
|
|
||||||
language: 'en',
|
|
||||||
show_only_sfw: false,
|
|
||||||
proxy_enabled: false,
|
|
||||||
proxy_type: 'http',
|
|
||||||
proxy_host: '',
|
|
||||||
proxy_port: '',
|
|
||||||
proxy_username: '',
|
|
||||||
proxy_password: '',
|
|
||||||
example_images_path: '',
|
|
||||||
optimizeExampleImages: true,
|
|
||||||
autoDownloadExampleImages: false
|
|
||||||
};
|
|
||||||
|
|
||||||
Object.keys(backendDefaults).forEach(key => {
|
const baseMappings = backendSettings?.base_model_path_mappings;
|
||||||
if (state.global.settings[key] === undefined || state.global.settings[key] === null) {
|
if (baseMappings && typeof baseMappings === 'object' && !Array.isArray(baseMappings)) {
|
||||||
state.global.settings[key] = backendDefaults[key];
|
merged.base_model_path_mappings = baseMappings;
|
||||||
|
} else {
|
||||||
|
merged.base_model_path_mappings = defaults.base_model_path_mappings;
|
||||||
|
}
|
||||||
|
|
||||||
|
let templates = backendSettings?.download_path_templates;
|
||||||
|
if (typeof templates === 'string') {
|
||||||
|
try {
|
||||||
|
const parsed = JSON.parse(templates);
|
||||||
|
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
|
||||||
|
templates = parsed;
|
||||||
|
}
|
||||||
|
} catch (parseError) {
|
||||||
|
console.warn('Failed to parse download_path_templates string from backend, using defaults');
|
||||||
|
templates = null;
|
||||||
}
|
}
|
||||||
});
|
}
|
||||||
|
|
||||||
// Ensure all model types have templates
|
if (!templates || typeof templates !== 'object' || Array.isArray(templates)) {
|
||||||
Object.keys(DEFAULT_PATH_TEMPLATES).forEach(modelType => {
|
templates = {};
|
||||||
if (!state.global.settings.download_path_templates[modelType]) {
|
}
|
||||||
state.global.settings.download_path_templates[modelType] = DEFAULT_PATH_TEMPLATES[modelType];
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
saveFrontendSettingsToStorage() {
|
merged.download_path_templates = { ...DEFAULT_PATH_TEMPLATES, ...templates };
|
||||||
// Save only frontend-specific settings to localStorage
|
|
||||||
const frontendOnlyKeys = [
|
|
||||||
'blurMatureContent',
|
|
||||||
'autoplayOnHover',
|
|
||||||
'displayDensity',
|
|
||||||
'cardInfoDisplay',
|
|
||||||
'includeTriggerWords'
|
|
||||||
];
|
|
||||||
|
|
||||||
const frontendSettings = {};
|
Object.keys(merged).forEach(key => this.backendSettingKeys.add(key));
|
||||||
frontendOnlyKeys.forEach(key => {
|
|
||||||
if (state.global.settings[key] !== undefined) {
|
|
||||||
frontendSettings[key] = state.global.settings[key];
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
setStorageItem('settings', frontendSettings);
|
return merged;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper method to determine if a setting should be saved to backend
|
// Helper method to determine if a setting should be saved to backend
|
||||||
isBackendSetting(settingKey) {
|
isBackendSetting(settingKey) {
|
||||||
const backendKeys = [
|
return this.backendSettingKeys.has(settingKey);
|
||||||
'civitai_api_key',
|
|
||||||
'default_lora_root',
|
|
||||||
'default_checkpoint_root',
|
|
||||||
'default_embedding_root',
|
|
||||||
'base_model_path_mappings',
|
|
||||||
'download_path_templates',
|
|
||||||
'enable_metadata_archive_db',
|
|
||||||
'language',
|
|
||||||
'show_only_sfw',
|
|
||||||
'proxy_enabled',
|
|
||||||
'proxy_type',
|
|
||||||
'proxy_host',
|
|
||||||
'proxy_port',
|
|
||||||
'proxy_username',
|
|
||||||
'proxy_password',
|
|
||||||
'example_images_path',
|
|
||||||
'optimizeExampleImages',
|
|
||||||
'autoDownloadExampleImages'
|
|
||||||
];
|
|
||||||
return backendKeys.includes(settingKey);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper method to save setting based on whether it's frontend or backend
|
// Helper method to save setting based on whether it's frontend or backend
|
||||||
@@ -207,36 +123,35 @@ export class SettingsManager {
|
|||||||
// Update state
|
// Update state
|
||||||
state.global.settings[settingKey] = value;
|
state.global.settings[settingKey] = value;
|
||||||
|
|
||||||
if (this.isBackendSetting(settingKey)) {
|
if (!this.isBackendSetting(settingKey)) {
|
||||||
// Save to backend
|
return;
|
||||||
try {
|
}
|
||||||
const payload = {};
|
|
||||||
payload[settingKey] = value;
|
|
||||||
|
|
||||||
const response = await fetch('/api/lm/settings', {
|
// Save to backend
|
||||||
method: 'POST',
|
try {
|
||||||
headers: {
|
const payload = {};
|
||||||
'Content-Type': 'application/json',
|
payload[settingKey] = value;
|
||||||
},
|
|
||||||
body: JSON.stringify(payload)
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!response.ok) {
|
const response = await fetch('/api/lm/settings', {
|
||||||
throw new Error('Failed to save setting to backend');
|
method: 'POST',
|
||||||
}
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify(payload)
|
||||||
|
});
|
||||||
|
|
||||||
// Parse response and check for success
|
if (!response.ok) {
|
||||||
const data = await response.json();
|
throw new Error('Failed to save setting to backend');
|
||||||
if (data.success === false) {
|
|
||||||
throw new Error(data.error || 'Failed to save setting to backend');
|
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
console.error(`Failed to save backend setting ${settingKey}:`, error);
|
|
||||||
throw error;
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// Save frontend settings to localStorage
|
// Parse response and check for success
|
||||||
this.saveFrontendSettingsToStorage();
|
const data = await response.json();
|
||||||
|
if (data.success === false) {
|
||||||
|
throw new Error(data.error || 'Failed to save setting to backend');
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to save backend setting ${settingKey}:`, error);
|
||||||
|
throw error;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -298,43 +213,42 @@ export class SettingsManager {
|
|||||||
// Set frontend settings from state
|
// Set frontend settings from state
|
||||||
const blurMatureContentCheckbox = document.getElementById('blurMatureContent');
|
const blurMatureContentCheckbox = document.getElementById('blurMatureContent');
|
||||||
if (blurMatureContentCheckbox) {
|
if (blurMatureContentCheckbox) {
|
||||||
blurMatureContentCheckbox.checked = state.global.settings.blurMatureContent;
|
blurMatureContentCheckbox.checked = state.global.settings.blur_mature_content ?? true;
|
||||||
}
|
}
|
||||||
|
|
||||||
const showOnlySFWCheckbox = document.getElementById('showOnlySFW');
|
const showOnlySFWCheckbox = document.getElementById('showOnlySFW');
|
||||||
if (showOnlySFWCheckbox) {
|
if (showOnlySFWCheckbox) {
|
||||||
// Sync with state (backend will set this via template)
|
showOnlySFWCheckbox.checked = state.global.settings.show_only_sfw ?? false;
|
||||||
state.global.settings.show_only_sfw = showOnlySFWCheckbox.checked;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set video autoplay on hover setting
|
// Set video autoplay on hover setting
|
||||||
const autoplayOnHoverCheckbox = document.getElementById('autoplayOnHover');
|
const autoplayOnHoverCheckbox = document.getElementById('autoplayOnHover');
|
||||||
if (autoplayOnHoverCheckbox) {
|
if (autoplayOnHoverCheckbox) {
|
||||||
autoplayOnHoverCheckbox.checked = state.global.settings.autoplayOnHover || false;
|
autoplayOnHoverCheckbox.checked = state.global.settings.autoplay_on_hover || false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set display density setting
|
// Set display density setting
|
||||||
const displayDensitySelect = document.getElementById('displayDensity');
|
const displayDensitySelect = document.getElementById('displayDensity');
|
||||||
if (displayDensitySelect) {
|
if (displayDensitySelect) {
|
||||||
displayDensitySelect.value = state.global.settings.displayDensity || 'default';
|
displayDensitySelect.value = state.global.settings.display_density || 'default';
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set card info display setting
|
// Set card info display setting
|
||||||
const cardInfoDisplaySelect = document.getElementById('cardInfoDisplay');
|
const cardInfoDisplaySelect = document.getElementById('cardInfoDisplay');
|
||||||
if (cardInfoDisplaySelect) {
|
if (cardInfoDisplaySelect) {
|
||||||
cardInfoDisplaySelect.value = state.global.settings.cardInfoDisplay || 'always';
|
cardInfoDisplaySelect.value = state.global.settings.card_info_display || 'always';
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set optimize example images setting
|
// Set optimize example images setting
|
||||||
const optimizeExampleImagesCheckbox = document.getElementById('optimizeExampleImages');
|
const optimizeExampleImagesCheckbox = document.getElementById('optimizeExampleImages');
|
||||||
if (optimizeExampleImagesCheckbox) {
|
if (optimizeExampleImagesCheckbox) {
|
||||||
optimizeExampleImagesCheckbox.checked = state.global.settings.optimizeExampleImages || false;
|
optimizeExampleImagesCheckbox.checked = state.global.settings.optimize_example_images ?? true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set auto download example images setting
|
// Set auto download example images setting
|
||||||
const autoDownloadExampleImagesCheckbox = document.getElementById('autoDownloadExampleImages');
|
const autoDownloadExampleImagesCheckbox = document.getElementById('autoDownloadExampleImages');
|
||||||
if (autoDownloadExampleImagesCheckbox) {
|
if (autoDownloadExampleImagesCheckbox) {
|
||||||
autoDownloadExampleImagesCheckbox.checked = state.global.settings.autoDownloadExampleImages || false;
|
autoDownloadExampleImagesCheckbox.checked = state.global.settings.auto_download_example_images || false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load download path templates
|
// Load download path templates
|
||||||
@@ -343,7 +257,7 @@ export class SettingsManager {
|
|||||||
// Set include trigger words setting
|
// Set include trigger words setting
|
||||||
const includeTriggerWordsCheckbox = document.getElementById('includeTriggerWords');
|
const includeTriggerWordsCheckbox = document.getElementById('includeTriggerWords');
|
||||||
if (includeTriggerWordsCheckbox) {
|
if (includeTriggerWordsCheckbox) {
|
||||||
includeTriggerWordsCheckbox.checked = state.global.settings.includeTriggerWords || false;
|
includeTriggerWordsCheckbox.checked = state.global.settings.include_trigger_words || false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Load metadata archive settings
|
// Load metadata archive settings
|
||||||
@@ -883,38 +797,17 @@ export class SettingsManager {
|
|||||||
if (!element) return;
|
if (!element) return;
|
||||||
|
|
||||||
const value = element.checked;
|
const value = element.checked;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Update frontend state with mapped keys
|
await this.saveSetting(settingKey, value);
|
||||||
if (settingKey === 'blur_mature_content') {
|
|
||||||
await this.saveSetting('blurMatureContent', value);
|
if (settingKey === 'proxy_enabled') {
|
||||||
} else if (settingKey === 'show_only_sfw') {
|
|
||||||
await this.saveSetting('show_only_sfw', value);
|
|
||||||
} else if (settingKey === 'autoplay_on_hover') {
|
|
||||||
await this.saveSetting('autoplayOnHover', value);
|
|
||||||
} else if (settingKey === 'optimize_example_images') {
|
|
||||||
await this.saveSetting('optimizeExampleImages', value);
|
|
||||||
} else if (settingKey === 'auto_download_example_images') {
|
|
||||||
await this.saveSetting('autoDownloadExampleImages', value);
|
|
||||||
} else if (settingKey === 'compact_mode') {
|
|
||||||
await this.saveSetting('compactMode', value);
|
|
||||||
} else if (settingKey === 'include_trigger_words') {
|
|
||||||
await this.saveSetting('includeTriggerWords', value);
|
|
||||||
} else if (settingKey === 'enable_metadata_archive_db') {
|
|
||||||
await this.saveSetting('enable_metadata_archive_db', value);
|
|
||||||
} else if (settingKey === 'proxy_enabled') {
|
|
||||||
await this.saveSetting('proxy_enabled', value);
|
|
||||||
|
|
||||||
// Toggle visibility of proxy settings group
|
|
||||||
const proxySettingsGroup = document.getElementById('proxySettingsGroup');
|
const proxySettingsGroup = document.getElementById('proxySettingsGroup');
|
||||||
if (proxySettingsGroup) {
|
if (proxySettingsGroup) {
|
||||||
proxySettingsGroup.style.display = value ? 'block' : 'none';
|
proxySettingsGroup.style.display = value ? 'block' : 'none';
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// For any other settings that might be added in the future
|
|
||||||
await this.saveSetting(settingKey, value);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Refresh metadata archive status when enable setting changes
|
// Refresh metadata archive status when enable setting changes
|
||||||
if (settingKey === 'enable_metadata_archive_db') {
|
if (settingKey === 'enable_metadata_archive_db') {
|
||||||
await this.updateMetadataArchiveStatus();
|
await this.updateMetadataArchiveStatus();
|
||||||
@@ -941,16 +834,11 @@ export class SettingsManager {
|
|||||||
// Recalculate layout when compact mode changes
|
// Recalculate layout when compact mode changes
|
||||||
if (settingKey === 'compact_mode' && state.virtualScroller) {
|
if (settingKey === 'compact_mode' && state.virtualScroller) {
|
||||||
state.virtualScroller.calculateLayout();
|
state.virtualScroller.calculateLayout();
|
||||||
showToast('toast.settings.compactModeToggled', {
|
showToast('toast.settings.compactModeToggled', {
|
||||||
state: value ? 'toast.settings.compactModeEnabled' : 'toast.settings.compactModeDisabled'
|
state: value ? 'toast.settings.compactModeEnabled' : 'toast.settings.compactModeDisabled'
|
||||||
}, 'success');
|
}, 'success');
|
||||||
}
|
}
|
||||||
|
|
||||||
// Special handling for metadata archive settings
|
|
||||||
if (settingKey === 'enable_metadata_archive_db') {
|
|
||||||
await this.updateMetadataArchiveStatus();
|
|
||||||
}
|
|
||||||
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
showToast('toast.settings.settingSaveFailed', { message: error.message }, 'error');
|
showToast('toast.settings.settingSaveFailed', { message: error.message }, 'error');
|
||||||
}
|
}
|
||||||
@@ -964,23 +852,8 @@ export class SettingsManager {
|
|||||||
|
|
||||||
try {
|
try {
|
||||||
// Update frontend state with mapped keys
|
// Update frontend state with mapped keys
|
||||||
if (settingKey === 'default_lora_root') {
|
await this.saveSetting(settingKey, value);
|
||||||
await this.saveSetting('default_lora_root', value);
|
|
||||||
} else if (settingKey === 'default_checkpoint_root') {
|
|
||||||
await this.saveSetting('default_checkpoint_root', value);
|
|
||||||
} else if (settingKey === 'default_embedding_root') {
|
|
||||||
await this.saveSetting('default_embedding_root', value);
|
|
||||||
} else if (settingKey === 'display_density') {
|
|
||||||
await this.saveSetting('displayDensity', value);
|
|
||||||
} else if (settingKey === 'card_info_display') {
|
|
||||||
await this.saveSetting('cardInfoDisplay', value);
|
|
||||||
} else if (settingKey === 'proxy_type') {
|
|
||||||
await this.saveSetting('proxy_type', value);
|
|
||||||
} else {
|
|
||||||
// For any other settings that might be added in the future
|
|
||||||
await this.saveSetting(settingKey, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply frontend settings immediately
|
// Apply frontend settings immediately
|
||||||
this.applyFrontendSettings();
|
this.applyFrontendSettings();
|
||||||
|
|
||||||
@@ -1296,13 +1169,13 @@ export class SettingsManager {
|
|||||||
async saveLanguageSetting() {
|
async saveLanguageSetting() {
|
||||||
const element = document.getElementById('languageSelect');
|
const element = document.getElementById('languageSelect');
|
||||||
if (!element) return;
|
if (!element) return;
|
||||||
|
|
||||||
const selectedLanguage = element.value;
|
const selectedLanguage = element.value;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Use the universal save method for language (frontend-only setting)
|
// Use the universal save method for language (frontend-only setting)
|
||||||
await this.saveSetting('language', selectedLanguage);
|
await this.saveSetting('language', selectedLanguage);
|
||||||
|
|
||||||
// Reload the page to apply the new language
|
// Reload the page to apply the new language
|
||||||
window.location.reload();
|
window.location.reload();
|
||||||
|
|
||||||
@@ -1347,7 +1220,7 @@ export class SettingsManager {
|
|||||||
|
|
||||||
applyFrontendSettings() {
|
applyFrontendSettings() {
|
||||||
// Apply autoplay setting to existing videos in card previews
|
// Apply autoplay setting to existing videos in card previews
|
||||||
const autoplayOnHover = state.global.settings.autoplayOnHover;
|
const autoplayOnHover = state.global.settings.autoplay_on_hover;
|
||||||
document.querySelectorAll('.card-preview video').forEach(video => {
|
document.querySelectorAll('.card-preview video').forEach(video => {
|
||||||
// Remove previous event listeners by cloning and replacing the element
|
// Remove previous event listeners by cloning and replacing the element
|
||||||
const videoParent = video.parentElement;
|
const videoParent = video.parentElement;
|
||||||
@@ -1377,17 +1250,17 @@ export class SettingsManager {
|
|||||||
// Apply display density class to grid
|
// Apply display density class to grid
|
||||||
const grid = document.querySelector('.card-grid');
|
const grid = document.querySelector('.card-grid');
|
||||||
if (grid) {
|
if (grid) {
|
||||||
const density = state.global.settings.displayDensity || 'default';
|
const density = state.global.settings.display_density || 'default';
|
||||||
|
|
||||||
// Remove all density classes first
|
// Remove all density classes first
|
||||||
grid.classList.remove('default-density', 'medium-density', 'compact-density');
|
grid.classList.remove('default-density', 'medium-density', 'compact-density');
|
||||||
|
|
||||||
// Add the appropriate density class
|
// Add the appropriate density class
|
||||||
grid.classList.add(`${density}-density`);
|
grid.classList.add(`${density}-density`);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply card info display setting
|
// Apply card info display setting
|
||||||
const cardInfoDisplay = state.global.settings.cardInfoDisplay || 'always';
|
const cardInfoDisplay = state.global.settings.card_info_display || 'always';
|
||||||
document.body.classList.toggle('hover-reveal', cardInfoDisplay === 'hover');
|
document.body.classList.toggle('hover-reveal', cardInfoDisplay === 'hover');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,20 +1,43 @@
|
|||||||
// Create the new hierarchical state structure
|
// Create the new hierarchical state structure
|
||||||
import { getStorageItem, getMapFromStorage } from '../utils/storageHelpers.js';
|
import { getStorageItem, getMapFromStorage } from '../utils/storageHelpers.js';
|
||||||
import { MODEL_TYPES } from '../api/apiConfig.js';
|
import { MODEL_TYPES } from '../api/apiConfig.js';
|
||||||
|
import { DEFAULT_PATH_TEMPLATES } from '../utils/constants.js';
|
||||||
|
|
||||||
// Load only frontend settings from localStorage with defaults
|
const DEFAULT_SETTINGS_BASE = Object.freeze({
|
||||||
// Backend settings will be loaded by SettingsManager from the backend
|
civitai_api_key: '',
|
||||||
const savedSettings = getStorageItem('settings', {
|
language: 'en',
|
||||||
blurMatureContent: true,
|
|
||||||
show_only_sfw: false,
|
show_only_sfw: false,
|
||||||
cardInfoDisplay: 'always',
|
enable_metadata_archive_db: false,
|
||||||
autoplayOnHover: false,
|
proxy_enabled: false,
|
||||||
displayDensity: 'default',
|
proxy_type: 'http',
|
||||||
optimizeExampleImages: true,
|
proxy_host: '',
|
||||||
autoDownloadExampleImages: true,
|
proxy_port: '',
|
||||||
includeTriggerWords: false
|
proxy_username: '',
|
||||||
|
proxy_password: '',
|
||||||
|
default_lora_root: '',
|
||||||
|
default_checkpoint_root: '',
|
||||||
|
default_embedding_root: '',
|
||||||
|
base_model_path_mappings: {},
|
||||||
|
download_path_templates: {},
|
||||||
|
example_images_path: '',
|
||||||
|
optimize_example_images: true,
|
||||||
|
auto_download_example_images: false,
|
||||||
|
blur_mature_content: true,
|
||||||
|
autoplay_on_hover: false,
|
||||||
|
display_density: 'default',
|
||||||
|
card_info_display: 'always',
|
||||||
|
include_trigger_words: false,
|
||||||
|
compact_mode: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
|
export function createDefaultSettings() {
|
||||||
|
return {
|
||||||
|
...DEFAULT_SETTINGS_BASE,
|
||||||
|
base_model_path_mappings: {},
|
||||||
|
download_path_templates: { ...DEFAULT_PATH_TEMPLATES },
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// Load preview versions from localStorage for each model type
|
// Load preview versions from localStorage for each model type
|
||||||
const loraPreviewVersions = getMapFromStorage('loras_preview_versions');
|
const loraPreviewVersions = getMapFromStorage('loras_preview_versions');
|
||||||
const checkpointPreviewVersions = getMapFromStorage('checkpoints_preview_versions');
|
const checkpointPreviewVersions = getMapFromStorage('checkpoints_preview_versions');
|
||||||
@@ -23,7 +46,7 @@ const embeddingPreviewVersions = getMapFromStorage('embeddings_preview_versions'
|
|||||||
export const state = {
|
export const state = {
|
||||||
// Global state
|
// Global state
|
||||||
global: {
|
global: {
|
||||||
settings: savedSettings,
|
settings: createDefaultSettings(),
|
||||||
loadingManager: null,
|
loadingManager: null,
|
||||||
observer: null,
|
observer: null,
|
||||||
},
|
},
|
||||||
|
|||||||
53
static/js/state/index.test.js
Normal file
53
static/js/state/index.test.js
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
import { describe, it, expect, beforeEach } from 'vitest';
|
||||||
|
import { createDefaultSettings, getCurrentPageState, initPageState, setCurrentPageType, state } from './index.js';
|
||||||
|
import { MODEL_TYPES } from '../api/apiConfig.js';
|
||||||
|
import { DEFAULT_PATH_TEMPLATES } from '../utils/constants.js';
|
||||||
|
|
||||||
|
describe('state module', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
// Reset to default page before each assertion
|
||||||
|
state.currentPageType = MODEL_TYPES.LORA;
|
||||||
|
});
|
||||||
|
|
||||||
|
it('creates default settings with immutable template copies', () => {
|
||||||
|
const defaultSettings = createDefaultSettings();
|
||||||
|
|
||||||
|
expect(defaultSettings).toMatchObject({
|
||||||
|
civitai_api_key: '',
|
||||||
|
language: 'en',
|
||||||
|
blur_mature_content: true
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(defaultSettings.download_path_templates).toEqual(DEFAULT_PATH_TEMPLATES);
|
||||||
|
|
||||||
|
// ensure nested objects are new references so tests can safely mutate
|
||||||
|
expect(defaultSettings.download_path_templates).not.toBe(DEFAULT_PATH_TEMPLATES);
|
||||||
|
expect(defaultSettings.base_model_path_mappings).toEqual({});
|
||||||
|
expect(Object.isFrozen(defaultSettings)).toBe(false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('switches current page type when valid', () => {
|
||||||
|
const didSwitch = setCurrentPageType(MODEL_TYPES.CHECKPOINT);
|
||||||
|
|
||||||
|
expect(didSwitch).toBe(true);
|
||||||
|
expect(state.currentPageType).toBe(MODEL_TYPES.CHECKPOINT);
|
||||||
|
expect(getCurrentPageState()).toBe(state.pages[MODEL_TYPES.CHECKPOINT]);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('rejects switching to an unknown page type', () => {
|
||||||
|
state.currentPageType = MODEL_TYPES.LORA;
|
||||||
|
|
||||||
|
const didSwitch = setCurrentPageType('invalid-page');
|
||||||
|
|
||||||
|
expect(didSwitch).toBe(false);
|
||||||
|
expect(state.currentPageType).toBe(MODEL_TYPES.LORA);
|
||||||
|
});
|
||||||
|
|
||||||
|
it('initializes and returns state for a known page', () => {
|
||||||
|
const pageState = initPageState(MODEL_TYPES.EMBEDDING);
|
||||||
|
|
||||||
|
expect(pageState).toBeDefined();
|
||||||
|
expect(pageState).toBe(state.pages[MODEL_TYPES.EMBEDDING]);
|
||||||
|
expect(state.currentPageType).toBe(MODEL_TYPES.EMBEDDING);
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -102,7 +102,7 @@ export class VirtualScroller {
|
|||||||
const availableContentWidth = containerWidth - paddingLeft - paddingRight;
|
const availableContentWidth = containerWidth - paddingLeft - paddingRight;
|
||||||
|
|
||||||
// Get display density setting
|
// Get display density setting
|
||||||
const displayDensity = state.global.settings?.displayDensity || 'default';
|
const displayDensity = state.global.settings?.display_density || 'default';
|
||||||
|
|
||||||
// Set exact column counts and grid widths to match CSS container widths
|
// Set exact column counts and grid widths to match CSS container widths
|
||||||
let maxColumns, maxGridWidth;
|
let maxColumns, maxGridWidth;
|
||||||
|
|||||||
@@ -86,29 +86,39 @@ function setupPageUnloadCleanup() {
|
|||||||
function registerContextMenuEvents() {
|
function registerContextMenuEvents() {
|
||||||
eventManager.addHandler('contextmenu', 'contextMenu-coordination', (e) => {
|
eventManager.addHandler('contextmenu', 'contextMenu-coordination', (e) => {
|
||||||
const card = e.target.closest('.model-card');
|
const card = e.target.closest('.model-card');
|
||||||
if (!card) {
|
const pageContent = e.target.closest('.page-content');
|
||||||
// Hide all menus if not right-clicking on a card
|
|
||||||
window.pageContextMenu?.hideMenu();
|
if (!pageContent) {
|
||||||
window.bulkManager?.bulkContextMenu?.hideMenu();
|
window.globalContextMenuInstance?.hideMenu();
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
e.preventDefault();
|
if (card) {
|
||||||
|
e.preventDefault();
|
||||||
// Hide all menus first
|
|
||||||
window.pageContextMenu?.hideMenu();
|
// Hide all menus first
|
||||||
window.bulkManager?.bulkContextMenu?.hideMenu();
|
window.pageContextMenu?.hideMenu();
|
||||||
|
window.bulkManager?.bulkContextMenu?.hideMenu();
|
||||||
// Determine which menu to show based on bulk mode and selection state
|
window.globalContextMenuInstance?.hideMenu();
|
||||||
if (state.bulkMode && card.classList.contains('selected')) {
|
|
||||||
// Show bulk menu for selected cards in bulk mode
|
// Determine which menu to show based on bulk mode and selection state
|
||||||
window.bulkManager?.bulkContextMenu?.showMenu(e.clientX, e.clientY, card);
|
if (state.bulkMode && card.classList.contains('selected')) {
|
||||||
} else if (!state.bulkMode) {
|
// Show bulk menu for selected cards in bulk mode
|
||||||
// Show regular menu when not in bulk mode
|
window.bulkManager?.bulkContextMenu?.showMenu(e.clientX, e.clientY, card);
|
||||||
window.pageContextMenu?.showMenu(e.clientX, e.clientY, card);
|
} else if (!state.bulkMode) {
|
||||||
|
// Show regular menu when not in bulk mode
|
||||||
|
window.pageContextMenu?.showMenu(e.clientX, e.clientY, card);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
e.preventDefault();
|
||||||
|
|
||||||
|
window.pageContextMenu?.hideMenu();
|
||||||
|
window.bulkManager?.bulkContextMenu?.hideMenu();
|
||||||
|
window.globalContextMenuInstance?.hideMenu();
|
||||||
|
|
||||||
|
window.globalContextMenuInstance?.showMenu(e.clientX, e.clientY, null);
|
||||||
}
|
}
|
||||||
// Don't show any menu for unselected cards in bulk mode
|
|
||||||
|
|
||||||
return true; // Stop propagation
|
return true; // Stop propagation
|
||||||
}, {
|
}, {
|
||||||
priority: 200, // Higher priority than bulk manager events
|
priority: 200, // Higher priority than bulk manager events
|
||||||
@@ -125,6 +135,7 @@ function registerGlobalClickHandlers() {
|
|||||||
if (!e.target.closest('.context-menu')) {
|
if (!e.target.closest('.context-menu')) {
|
||||||
window.pageContextMenu?.hideMenu();
|
window.pageContextMenu?.hideMenu();
|
||||||
window.bulkManager?.bulkContextMenu?.hideMenu();
|
window.bulkManager?.bulkContextMenu?.hideMenu();
|
||||||
|
window.globalContextMenuInstance?.hideMenu();
|
||||||
}
|
}
|
||||||
return false; // Allow other handlers to process
|
return false; // Allow other handlers to process
|
||||||
}, {
|
}, {
|
||||||
|
|||||||
@@ -116,64 +116,6 @@ export function removeSessionItem(key) {
|
|||||||
sessionStorage.removeItem(STORAGE_PREFIX + key);
|
sessionStorage.removeItem(STORAGE_PREFIX + key);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Migrate all existing localStorage items to use the prefix
|
|
||||||
* This should be called once during application initialization
|
|
||||||
*/
|
|
||||||
export function migrateStorageItems() {
|
|
||||||
// Check if migration has already been performed
|
|
||||||
if (localStorage.getItem(STORAGE_PREFIX + 'migration_completed')) {
|
|
||||||
console.log('Lora Manager: Storage migration already completed');
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
// List of known keys used in the application
|
|
||||||
const knownKeys = [
|
|
||||||
'nsfwBlurLevel',
|
|
||||||
'theme',
|
|
||||||
'activeFolder',
|
|
||||||
'folderTagsCollapsed',
|
|
||||||
'settings',
|
|
||||||
'loras_filters',
|
|
||||||
'recipes_filters',
|
|
||||||
'checkpoints_filters',
|
|
||||||
'loras_search_prefs',
|
|
||||||
'recipes_search_prefs',
|
|
||||||
'checkpoints_search_prefs',
|
|
||||||
'show_update_notifications',
|
|
||||||
'last_update_check',
|
|
||||||
'dismissed_banners'
|
|
||||||
];
|
|
||||||
|
|
||||||
// Migrate each known key
|
|
||||||
knownKeys.forEach(key => {
|
|
||||||
const prefixedKey = STORAGE_PREFIX + key;
|
|
||||||
|
|
||||||
// Only migrate if the prefixed key doesn't already exist
|
|
||||||
if (localStorage.getItem(prefixedKey) === null) {
|
|
||||||
const value = localStorage.getItem(key);
|
|
||||||
if (value !== null) {
|
|
||||||
try {
|
|
||||||
// Try to parse as JSON first
|
|
||||||
const parsedValue = JSON.parse(value);
|
|
||||||
setStorageItem(key, parsedValue);
|
|
||||||
} catch (e) {
|
|
||||||
// If not JSON, store as is
|
|
||||||
setStorageItem(key, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
// We can optionally remove the old key after migration
|
|
||||||
localStorage.removeItem(key);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Mark migration as completed
|
|
||||||
localStorage.setItem(STORAGE_PREFIX + 'migration_completed', 'true');
|
|
||||||
|
|
||||||
console.log('Lora Manager: Storage migration completed');
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Save a Map to localStorage
|
* Save a Map to localStorage
|
||||||
* @param {string} key - The localStorage key
|
* @param {string} key - The localStorage key
|
||||||
|
|||||||
111
static/js/utils/storageHelpers.test.js
Normal file
111
static/js/utils/storageHelpers.test.js
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
|
||||||
|
import * as storageHelpers from './storageHelpers.js';
|
||||||
|
|
||||||
|
const {
|
||||||
|
getStorageItem,
|
||||||
|
setStorageItem,
|
||||||
|
removeStorageItem,
|
||||||
|
getSessionItem,
|
||||||
|
setSessionItem,
|
||||||
|
removeSessionItem,
|
||||||
|
} = storageHelpers;
|
||||||
|
|
||||||
|
const createFakeStorage = () => {
|
||||||
|
const store = new Map();
|
||||||
|
return {
|
||||||
|
getItem: vi.fn((key) => (store.has(key) ? store.get(key) : null)),
|
||||||
|
setItem: vi.fn((key, value) => {
|
||||||
|
store.set(key, value);
|
||||||
|
}),
|
||||||
|
removeItem: vi.fn((key) => {
|
||||||
|
store.delete(key);
|
||||||
|
}),
|
||||||
|
clear: vi.fn(() => {
|
||||||
|
store.clear();
|
||||||
|
}),
|
||||||
|
key: vi.fn((index) => Array.from(store.keys())[index] ?? null),
|
||||||
|
get length() {
|
||||||
|
return store.size;
|
||||||
|
},
|
||||||
|
_store: store
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
let localStorageMock;
|
||||||
|
let sessionStorageMock;
|
||||||
|
let consoleLogMock;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
localStorageMock = createFakeStorage();
|
||||||
|
sessionStorageMock = createFakeStorage();
|
||||||
|
vi.stubGlobal('localStorage', localStorageMock);
|
||||||
|
vi.stubGlobal('sessionStorage', sessionStorageMock);
|
||||||
|
consoleLogMock = vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
vi.unstubAllGlobals();
|
||||||
|
vi.restoreAllMocks();
|
||||||
|
});
|
||||||
|
|
||||||
|
describe('storageHelpers namespace utilities', () => {
|
||||||
|
it('returns parsed JSON for prefixed localStorage items', () => {
|
||||||
|
localStorage.setItem('lora_manager_preferences', JSON.stringify({ theme: 'dark' }));
|
||||||
|
|
||||||
|
const result = getStorageItem('preferences');
|
||||||
|
|
||||||
|
expect(result).toEqual({ theme: 'dark' });
|
||||||
|
expect(localStorage.getItem).toHaveBeenCalledWith('lora_manager_preferences');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('falls back to legacy keys and migrates them to the namespace', () => {
|
||||||
|
localStorage.setItem('legacy_key', 'value');
|
||||||
|
|
||||||
|
const value = getStorageItem('legacy_key');
|
||||||
|
|
||||||
|
expect(value).toBe('value');
|
||||||
|
expect(localStorage.getItem('lora_manager_legacy_key')).toBe('value');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('serializes objects when setting prefixed localStorage values', () => {
|
||||||
|
const data = { ids: [1, 2, 3] };
|
||||||
|
|
||||||
|
setStorageItem('data', data);
|
||||||
|
|
||||||
|
expect(localStorage.setItem).toHaveBeenCalledWith('lora_manager_data', JSON.stringify(data));
|
||||||
|
expect(localStorage.getItem('lora_manager_data')).toEqual(JSON.stringify(data));
|
||||||
|
});
|
||||||
|
|
||||||
|
it('removes both prefixed and legacy localStorage entries', () => {
|
||||||
|
localStorage.setItem('lora_manager_temp', '123');
|
||||||
|
localStorage.setItem('temp', '456');
|
||||||
|
|
||||||
|
removeStorageItem('temp');
|
||||||
|
|
||||||
|
expect(localStorage.getItem('lora_manager_temp')).toBeNull();
|
||||||
|
expect(localStorage.getItem('temp')).toBeNull();
|
||||||
|
});
|
||||||
|
|
||||||
|
it('returns parsed JSON for session storage items', () => {
|
||||||
|
sessionStorage.setItem('lora_manager_session', JSON.stringify({ page: 'loras' }));
|
||||||
|
|
||||||
|
const session = getSessionItem('session');
|
||||||
|
|
||||||
|
expect(session).toEqual({ page: 'loras' });
|
||||||
|
});
|
||||||
|
|
||||||
|
it('stores primitives in session storage directly', () => {
|
||||||
|
setSessionItem('token', 'abc123');
|
||||||
|
|
||||||
|
expect(sessionStorage.setItem).toHaveBeenCalledWith('lora_manager_token', 'abc123');
|
||||||
|
expect(sessionStorage.getItem('lora_manager_token')).toBe('abc123');
|
||||||
|
});
|
||||||
|
|
||||||
|
it('removes session storage entries by namespace', () => {
|
||||||
|
sessionStorage.setItem('lora_manager_flag', '1');
|
||||||
|
|
||||||
|
removeSessionItem('flag');
|
||||||
|
|
||||||
|
expect(sessionStorage.getItem('lora_manager_flag')).toBeNull();
|
||||||
|
});
|
||||||
|
});
|
||||||
@@ -295,13 +295,51 @@ export function getNSFWLevelName(level) {
|
|||||||
return 'Unknown';
|
return 'Unknown';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function parseUsageTipNumber(value) {
|
||||||
|
if (typeof value === 'number' && Number.isFinite(value)) {
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (typeof value === 'string') {
|
||||||
|
const parsed = parseFloat(value);
|
||||||
|
if (Number.isFinite(parsed)) {
|
||||||
|
return parsed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getLoraStrengthsFromUsageTips(usageTips = {}) {
|
||||||
|
const parsedStrength = parseUsageTipNumber(usageTips.strength);
|
||||||
|
const clipStrengthSource = usageTips.clip_strength ?? usageTips.clipStrength;
|
||||||
|
const parsedClipStrength = parseUsageTipNumber(clipStrengthSource);
|
||||||
|
|
||||||
|
return {
|
||||||
|
strength: parsedStrength !== null ? parsedStrength : 1,
|
||||||
|
hasStrength: parsedStrength !== null,
|
||||||
|
clipStrength: parsedClipStrength,
|
||||||
|
hasClipStrength: parsedClipStrength !== null,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function buildLoraSyntax(fileName, usageTips = {}) {
|
||||||
|
const { strength, hasStrength, clipStrength, hasClipStrength } = getLoraStrengthsFromUsageTips(usageTips);
|
||||||
|
|
||||||
|
if (hasClipStrength) {
|
||||||
|
const modelStrength = hasStrength ? strength : 1;
|
||||||
|
return `<lora:${fileName}:${modelStrength}:${clipStrength}>`;
|
||||||
|
}
|
||||||
|
|
||||||
|
return `<lora:${fileName}:${strength}>`;
|
||||||
|
}
|
||||||
|
|
||||||
export function copyLoraSyntax(card) {
|
export function copyLoraSyntax(card) {
|
||||||
const usageTips = JSON.parse(card.dataset.usage_tips || "{}");
|
const usageTips = JSON.parse(card.dataset.usage_tips || "{}");
|
||||||
const strength = usageTips.strength || 1;
|
const baseSyntax = buildLoraSyntax(card.dataset.file_name, usageTips);
|
||||||
const baseSyntax = `<lora:${card.dataset.file_name}:${strength}>`;
|
|
||||||
|
|
||||||
// Check if trigger words should be included
|
// Check if trigger words should be included
|
||||||
const includeTriggerWords = state.global.settings.includeTriggerWords;
|
const includeTriggerWords = state.global.settings.include_trigger_words;
|
||||||
|
|
||||||
if (!includeTriggerWords) {
|
if (!includeTriggerWords) {
|
||||||
const message = translate('uiHelpers.lora.syntaxCopied', {}, 'LoRA syntax copied to clipboard');
|
const message = translate('uiHelpers.lora.syntaxCopied', {}, 'LoRA syntax copied to clipboard');
|
||||||
|
|||||||
@@ -84,6 +84,15 @@
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div id="globalContextMenu" class="context-menu">
|
||||||
|
<div class="context-menu-item" data-action="download-example-images">
|
||||||
|
<i class="fas fa-download"></i> <span>{{ t('globalContextMenu.downloadExampleImages.label') }}</span>
|
||||||
|
</div>
|
||||||
|
<div class="context-menu-item" data-action="cleanup-example-images-folders">
|
||||||
|
<i class="fas fa-trash-restore"></i> <span>{{ t('globalContextMenu.cleanupExampleImages.label') }}</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div id="nsfwLevelSelector" class="nsfw-level-selector">
|
<div id="nsfwLevelSelector" class="nsfw-level-selector">
|
||||||
<div class="nsfw-level-header">
|
<div class="nsfw-level-header">
|
||||||
<h3>{{ t('modals.contentRating.title') }}</h3>
|
<h3>{{ t('modals.contentRating.title') }}</h3>
|
||||||
@@ -103,4 +112,4 @@
|
|||||||
|
|
||||||
<div id="nodeSelector" class="node-selector">
|
<div id="nodeSelector" class="node-selector">
|
||||||
<!-- Dynamic node list will be populated here -->
|
<!-- Dynamic node list will be populated here -->
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
222
tests/conftest.py
Normal file
222
tests/conftest.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
import types
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional, Sequence
|
||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
from unittest import mock
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Mock ComfyUI modules before any imports from the main project
|
||||||
|
server_mock = types.SimpleNamespace()
|
||||||
|
server_mock.PromptServer = mock.MagicMock()
|
||||||
|
sys.modules['server'] = server_mock
|
||||||
|
|
||||||
|
folder_paths_mock = types.SimpleNamespace()
|
||||||
|
folder_paths_mock.get_folder_paths = mock.MagicMock(return_value=[])
|
||||||
|
folder_paths_mock.folder_names_and_paths = {}
|
||||||
|
sys.modules['folder_paths'] = folder_paths_mock
|
||||||
|
|
||||||
|
# Mock other ComfyUI modules that might be imported
|
||||||
|
comfy_mock = types.SimpleNamespace()
|
||||||
|
comfy_mock.utils = types.SimpleNamespace()
|
||||||
|
comfy_mock.model_management = types.SimpleNamespace()
|
||||||
|
comfy_mock.comfy_types = types.SimpleNamespace()
|
||||||
|
comfy_mock.comfy_types.IO = mock.MagicMock()
|
||||||
|
sys.modules['comfy'] = comfy_mock
|
||||||
|
sys.modules['comfy.utils'] = comfy_mock.utils
|
||||||
|
sys.modules['comfy.model_management'] = comfy_mock.model_management
|
||||||
|
sys.modules['comfy.comfy_types'] = comfy_mock.comfy_types
|
||||||
|
|
||||||
|
execution_mock = types.SimpleNamespace()
|
||||||
|
execution_mock.PromptExecutor = mock.MagicMock()
|
||||||
|
sys.modules['execution'] = execution_mock
|
||||||
|
|
||||||
|
# Mock ComfyUI nodes module
|
||||||
|
nodes_mock = types.SimpleNamespace()
|
||||||
|
nodes_mock.LoraLoader = mock.MagicMock()
|
||||||
|
nodes_mock.SaveImage = mock.MagicMock()
|
||||||
|
nodes_mock.NODE_CLASS_MAPPINGS = {}
|
||||||
|
sys.modules['nodes'] = nodes_mock
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_pyfunc_call(pyfuncitem):
|
||||||
|
"""Allow bare async tests to run without pytest.mark.asyncio."""
|
||||||
|
test_function = pyfuncitem.function
|
||||||
|
if inspect.iscoroutinefunction(test_function):
|
||||||
|
func = pyfuncitem.obj
|
||||||
|
signature = inspect.signature(func)
|
||||||
|
accepted_kwargs: Dict[str, Any] = {}
|
||||||
|
for name, parameter in signature.parameters.items():
|
||||||
|
if parameter.kind is inspect.Parameter.VAR_POSITIONAL:
|
||||||
|
continue
|
||||||
|
if parameter.kind is inspect.Parameter.VAR_KEYWORD:
|
||||||
|
accepted_kwargs = dict(pyfuncitem.funcargs)
|
||||||
|
break
|
||||||
|
if name in pyfuncitem.funcargs:
|
||||||
|
accepted_kwargs[name] = pyfuncitem.funcargs[name]
|
||||||
|
|
||||||
|
original_policy = asyncio.get_event_loop_policy()
|
||||||
|
policy = pyfuncitem.funcargs.get("event_loop_policy")
|
||||||
|
if policy is not None and policy is not original_policy:
|
||||||
|
asyncio.set_event_loop_policy(policy)
|
||||||
|
try:
|
||||||
|
asyncio.run(func(**accepted_kwargs))
|
||||||
|
finally:
|
||||||
|
if policy is not None and policy is not original_policy:
|
||||||
|
asyncio.set_event_loop_policy(original_policy)
|
||||||
|
return True
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockHashIndex:
|
||||||
|
"""Minimal hash index stub mirroring the scanner contract."""
|
||||||
|
|
||||||
|
removed_paths: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def remove_by_path(self, path: str) -> None:
|
||||||
|
self.removed_paths.append(path)
|
||||||
|
|
||||||
|
|
||||||
|
class MockCache:
|
||||||
|
"""Cache object with the attributes."""
|
||||||
|
|
||||||
|
def __init__(self, items: Optional[Sequence[Dict[str, Any]]] = None):
|
||||||
|
self.raw_data: List[Dict[str, Any]] = list(items or [])
|
||||||
|
self.resort_calls = 0
|
||||||
|
|
||||||
|
async def resort(self) -> None:
|
||||||
|
self.resort_calls += 1
|
||||||
|
# expects the coroutine interface but does not
|
||||||
|
# rely on the return value.
|
||||||
|
|
||||||
|
|
||||||
|
class MockScanner:
|
||||||
|
"""Scanner double that exposes the attributes used by route utilities."""
|
||||||
|
|
||||||
|
def __init__(self, cache: Optional[MockCache] = None, hash_index: Optional[MockHashIndex] = None):
|
||||||
|
self._cache = cache or MockCache()
|
||||||
|
self._hash_index = hash_index or MockHashIndex()
|
||||||
|
self._tags_count: Dict[str, int] = {}
|
||||||
|
self._excluded_models: List[str] = []
|
||||||
|
self.updated_models: List[Dict[str, Any]] = []
|
||||||
|
self.preview_updates: List[Dict[str, Any]] = []
|
||||||
|
self.bulk_deleted: List[Sequence[str]] = []
|
||||||
|
|
||||||
|
async def get_cached_data(self, force_refresh: bool = False):
|
||||||
|
return self._cache
|
||||||
|
|
||||||
|
async def update_single_model_cache(self, original_path: str, new_path: str, metadata: Dict[str, Any]) -> bool:
|
||||||
|
self.updated_models.append({
|
||||||
|
"original_path": original_path,
|
||||||
|
"new_path": new_path,
|
||||||
|
"metadata": metadata,
|
||||||
|
})
|
||||||
|
for item in self._cache.raw_data:
|
||||||
|
if item.get("file_path") == original_path:
|
||||||
|
item.update(metadata)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def update_preview_in_cache(self, model_path: str, preview_path: str, nsfw_level: int) -> bool:
|
||||||
|
self.preview_updates.append({
|
||||||
|
"model_path": model_path,
|
||||||
|
"preview_path": preview_path,
|
||||||
|
"nsfw_level": nsfw_level,
|
||||||
|
})
|
||||||
|
for item in self._cache.raw_data:
|
||||||
|
if item.get("file_path") == model_path:
|
||||||
|
item["preview_url"] = preview_path
|
||||||
|
item["preview_nsfw_level"] = nsfw_level
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def bulk_delete_models(self, file_paths: Sequence[str]) -> Dict[str, Any]:
|
||||||
|
self.bulk_deleted.append(tuple(file_paths))
|
||||||
|
self._cache.raw_data = [item for item in self._cache.raw_data if item.get("file_path") not in file_paths]
|
||||||
|
await self._cache.resort()
|
||||||
|
for path in file_paths:
|
||||||
|
self._hash_index.remove_by_path(path)
|
||||||
|
return {"success": True, "deleted": list(file_paths)}
|
||||||
|
|
||||||
|
|
||||||
|
class MockModelService:
|
||||||
|
"""Service stub consumed by the shared routes."""
|
||||||
|
|
||||||
|
def __init__(self, scanner: MockScanner):
|
||||||
|
self.scanner = scanner
|
||||||
|
self.model_type = "test-model"
|
||||||
|
self.paginated_items: List[Dict[str, Any]] = []
|
||||||
|
self.formatted: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
|
||||||
|
items = [dict(item) for item in self.paginated_items]
|
||||||
|
total = len(items)
|
||||||
|
page = params.get("page", 1)
|
||||||
|
page_size = params.get("page_size", 20)
|
||||||
|
return {
|
||||||
|
"items": items,
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"page_size": page_size,
|
||||||
|
"total_pages": max(1, (total + page_size - 1) // page_size),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def format_response(self, item: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
formatted = {**item, "formatted": True}
|
||||||
|
self.formatted.append(formatted)
|
||||||
|
return formatted
|
||||||
|
|
||||||
|
# Convenience helpers used by assorted routes. They are no-ops for the
|
||||||
|
# smoke tests but document the expected surface area of the real services.
|
||||||
|
def get_model_roots(self) -> List[str]:
|
||||||
|
return ["."]
|
||||||
|
|
||||||
|
async def scan_models(self, *_, **__): # pragma: no cover - behaviour exercised via mocks
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_model_notes(self, *_args, **_kwargs): # pragma: no cover
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_model_preview_url(self, *_args, **_kwargs): # pragma: no cover
|
||||||
|
return ""
|
||||||
|
|
||||||
|
async def get_model_civitai_url(self, *_args, **_kwargs): # pragma: no cover
|
||||||
|
return {"civitai_url": ""}
|
||||||
|
|
||||||
|
async def get_model_metadata(self, *_args, **_kwargs): # pragma: no cover
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def get_model_description(self, *_args, **_kwargs): # pragma: no cover
|
||||||
|
return ""
|
||||||
|
|
||||||
|
async def get_relative_paths(self, *_args, **_kwargs): # pragma: no cover
|
||||||
|
return []
|
||||||
|
|
||||||
|
def has_hash(self, *_args, **_kwargs): # pragma: no cover
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_path_by_hash(self, *_args, **_kwargs): # pragma: no cover
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_hash_index() -> MockHashIndex:
|
||||||
|
return MockHashIndex()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_cache() -> MockCache:
|
||||||
|
return MockCache()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockScanner:
|
||||||
|
return MockScanner(cache=mock_cache, hash_index=mock_hash_index)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_service(mock_scanner: MockScanner) -> MockModelService:
|
||||||
|
return MockModelService(scanner=mock_scanner)
|
||||||
|
|
||||||
|
|
||||||
16
tests/frontend/setup.js
Normal file
16
tests/frontend/setup.js
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
import { afterEach, beforeEach } from 'vitest';
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
// Ensure storage is clean before each test to avoid cross-test pollution
|
||||||
|
localStorage.clear();
|
||||||
|
sessionStorage.clear();
|
||||||
|
|
||||||
|
// Reset DOM state for modules that rely on body attributes
|
||||||
|
document.body.innerHTML = '';
|
||||||
|
document.body.dataset.page = '';
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
// Clean any dynamically attached globals by tests
|
||||||
|
delete document.body.dataset.page;
|
||||||
|
});
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user