mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
93
docs/architecture/example_images_routes.md
Normal file
93
docs/architecture/example_images_routes.md
Normal file
@@ -0,0 +1,93 @@
|
||||
# Example image route architecture
|
||||
|
||||
The example image routing stack mirrors the layered model route stack described in
|
||||
[`docs/architecture/model_routes.md`](model_routes.md). HTTP wiring, controller setup,
|
||||
handler orchestration, and long-running workflows now live in clearly separated modules so
|
||||
we can extend download/import behaviour without touching the entire feature surface.
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph HTTP
|
||||
A[ExampleImagesRouteRegistrar] -->|binds| B[ExampleImagesRoutes controller]
|
||||
end
|
||||
subgraph Application
|
||||
B --> C[ExampleImagesHandlerSet]
|
||||
C --> D1[Handlers]
|
||||
D1 --> E1[Use cases]
|
||||
E1 --> F1[Download manager / processor / file manager]
|
||||
end
|
||||
subgraph Side Effects
|
||||
F1 --> G1[Filesystem]
|
||||
F1 --> G2[Model metadata]
|
||||
F1 --> G3[WebSocket progress]
|
||||
end
|
||||
```
|
||||
|
||||
## Layer responsibilities
|
||||
|
||||
| Layer | Module(s) | Responsibility |
|
||||
| --- | --- | --- |
|
||||
| Registrar | `py/routes/example_images_route_registrar.py` | Declarative catalogue of every example image endpoint plus helpers that bind them to an `aiohttp` router. Keeps HTTP concerns symmetrical with the model registrar. |
|
||||
| Controller | `py/routes/example_images_routes.py` | Lazily constructs `ExampleImagesHandlerSet`, injects defaults for the download manager, processor, and file manager, and exposes the registrar-ready mapping just like `BaseModelRoutes`. |
|
||||
| Handler set | `py/routes/handlers/example_images_handlers.py` | Groups HTTP adapters by concern (downloads, imports/deletes, filesystem access). Each handler translates domain errors into HTTP responses and defers to a use case or utility service. |
|
||||
| Use cases | `py/services/use_cases/example_images/*.py` | Encapsulate orchestration for downloads and imports. They validate input, translate concurrency/configuration errors, and keep handler logic declarative. |
|
||||
| Supporting services | `py/utils/example_images_download_manager.py`, `py/utils/example_images_processor.py`, `py/utils/example_images_file_manager.py` | Execute long-running work: pull assets from Civitai, persist uploads, clean metadata, expose filesystem actions with guardrails, and broadcast progress snapshots. |
|
||||
|
||||
## Handler responsibilities & invariants
|
||||
|
||||
`ExampleImagesHandlerSet` flattens the handler objects into the `{"handler_name": coroutine}`
|
||||
mapping consumed by the registrar. The table below outlines how each handler collaborates
|
||||
with the use cases and utilities.
|
||||
|
||||
| Handler | Key endpoints | Collaborators | Contracts |
|
||||
| --- | --- | --- | --- |
|
||||
| `ExampleImagesDownloadHandler` | `/api/lm/download-example-images`, `/api/lm/example-images-status`, `/api/lm/pause-example-images`, `/api/lm/resume-example-images`, `/api/lm/force-download-example-images` | `DownloadExampleImagesUseCase`, `DownloadManager` | Delegates payload validation and concurrency checks to the use case; progress/status endpoints expose the same snapshot used for WebSocket broadcasts; pause/resume surface `DownloadNotRunningError` as HTTP 400 instead of 500. |
|
||||
| `ExampleImagesManagementHandler` | `/api/lm/import-example-images`, `/api/lm/delete-example-image` | `ImportExampleImagesUseCase`, `ExampleImagesProcessor` | Multipart uploads are streamed to disk via the use case; validation failures return HTTP 400 with no filesystem side effects; deletion funnels through the processor to prune metadata and cached images consistently. |
|
||||
| `ExampleImagesFileHandler` | `/api/lm/open-example-images-folder`, `/api/lm/example-image-files`, `/api/lm/has-example-images` | `ExampleImagesFileManager` | Centralises filesystem access, enforcing settings-based root paths and returning HTTP 400/404 for missing configuration or folders; responses always include `success`/`has_images` booleans for UI consumption. |
|
||||
|
||||
## Use case boundaries
|
||||
|
||||
| Use case | Entry point | Dependencies | Guarantees |
|
||||
| --- | --- | --- | --- |
|
||||
| `DownloadExampleImagesUseCase` | `execute(payload)` | `DownloadManager.start_download`, download configuration errors | Raises `DownloadExampleImagesInProgressError` when the manager reports an active job, rewraps configuration errors into `DownloadExampleImagesConfigurationError`, and lets `ExampleImagesDownloadError` bubble as 500s so handlers do not duplicate logging. |
|
||||
| `ImportExampleImagesUseCase` | `execute(request)` | `ExampleImagesProcessor.import_images`, temporary file helpers | Supports multipart or JSON payloads, normalises file paths into a single list, cleans up temp files even on failure, and maps validation issues to `ImportExampleImagesValidationError` for HTTP 400 responses. |
|
||||
|
||||
## Maintaining critical invariants
|
||||
|
||||
* **Shared progress snapshots** - The download handler returns the same snapshot built by
|
||||
`DownloadManager`, guaranteeing parity between HTTP polling endpoints and WebSocket
|
||||
progress events.
|
||||
* **Safe filesystem access** - All folder/file actions flow through
|
||||
`ExampleImagesFileManager`, which validates the configured example image root and ensures
|
||||
responses never leak absolute paths outside the allowed directory.
|
||||
* **Metadata hygiene** - Import/delete operations run through `ExampleImagesProcessor`,
|
||||
which updates model metadata via `MetadataManager` and notifies the relevant scanners so
|
||||
cache state stays in sync.
|
||||
|
||||
## Migration notes
|
||||
|
||||
The refactor brings the example image stack in line with the model/recipe stacks:
|
||||
|
||||
1. `ExampleImagesRouteRegistrar` now owns the declarative route list. Downstream projects
|
||||
should rely on `ExampleImagesRoutes.to_route_mapping()` instead of manually wiring
|
||||
handler callables.
|
||||
2. `ExampleImagesRoutes` caches its `ExampleImagesHandlerSet` just like
|
||||
`BaseModelRoutes`. If you previously instantiated handlers directly, inject custom
|
||||
collaborators via the controller constructor (`download_manager`, `processor`,
|
||||
`file_manager`) to keep test seams predictable.
|
||||
3. Tests that mocked `ExampleImagesRoutes.setup_routes` should switch to patching
|
||||
`DownloadExampleImagesUseCase`/`ImportExampleImagesUseCase` at import time. The handlers
|
||||
expect those abstractions to surface validation/concurrency errors, and bypassing them
|
||||
will skip the HTTP-friendly error mapping.
|
||||
|
||||
## Extending the stack
|
||||
|
||||
1. Add the endpoint to `ROUTE_DEFINITIONS` with a unique `handler_name`.
|
||||
2. Expose the coroutine on an existing handler class (or create a new handler and extend
|
||||
`ExampleImagesHandlerSet`).
|
||||
3. Wire additional services or factories inside `_build_handler_set` on
|
||||
`ExampleImagesRoutes`, mirroring how the model stack introduces new use cases.
|
||||
|
||||
`tests/routes/test_example_images_routes.py` exercises registrar binding, download pause
|
||||
flows, and import validations. Use it as a template when introducing new handler
|
||||
collaborators or error mappings.
|
||||
@@ -1,127 +1,100 @@
|
||||
# Base model route architecture
|
||||
|
||||
The `BaseModelRoutes` controller centralizes HTTP endpoints that every model type
|
||||
(LoRAs, checkpoints, embeddings, etc.) share. Each handler either forwards the
|
||||
request to the injected service, delegates to a utility in
|
||||
`ModelRouteUtils`, or orchestrates long‑running operations via helper services
|
||||
such as the download or WebSocket managers. The table below lists every handler
|
||||
exposed in `py/routes/base_model_routes.py`, the collaborators it leans on, and
|
||||
any cache or WebSocket side effects implemented in
|
||||
`py/utils/routes_common.py`.
|
||||
The model routing stack now splits HTTP wiring, orchestration logic, and
|
||||
business rules into discrete layers. The goal is to make it obvious where a
|
||||
new collaborator should live and which contract it must honour. The diagram
|
||||
below captures the end-to-end flow for a typical request:
|
||||
|
||||
## Contents
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph HTTP
|
||||
A[ModelRouteRegistrar] -->|binds| B[BaseModelRoutes handler proxy]
|
||||
end
|
||||
subgraph Application
|
||||
B --> C[ModelHandlerSet]
|
||||
C --> D1[Handlers]
|
||||
D1 --> E1[Use cases]
|
||||
E1 --> F1[Services / scanners]
|
||||
end
|
||||
subgraph Side Effects
|
||||
F1 --> G1[Cache & metadata]
|
||||
F1 --> G2[Filesystem]
|
||||
F1 --> G3[WebSocket state]
|
||||
end
|
||||
```
|
||||
|
||||
- [Handler catalogue](#handler-catalogue)
|
||||
- [Dependency map and contracts](#dependency-map-and-contracts)
|
||||
- [Cache and metadata mutations](#cache-and-metadata-mutations)
|
||||
- [Download and WebSocket flows](#download-and-websocket-flows)
|
||||
- [Read-only queries](#read-only-queries)
|
||||
- [Template rendering and initialization](#template-rendering-and-initialization)
|
||||
Every box maps to a concrete module:
|
||||
|
||||
## Handler catalogue
|
||||
| Layer | Module(s) | Responsibility |
|
||||
| --- | --- | --- |
|
||||
| Registrar | `py/routes/model_route_registrar.py` | Declarative list of routes shared by every model type and helper methods for binding them to an `aiohttp` application. |
|
||||
| Route controller | `py/routes/base_model_routes.py` | Constructs the handler graph, injects shared services, exposes proxies that surface `503 Service not ready` when the model service has not been attached. |
|
||||
| Handler set | `py/routes/handlers/model_handlers.py` | Thin HTTP adapters grouped by concern (page rendering, listings, mutations, queries, downloads, CivitAI integration, move operations, auto-organize). |
|
||||
| Use cases | `py/services/use_cases/*.py` | Encapsulate long-running flows (`DownloadModelUseCase`, `BulkMetadataRefreshUseCase`, `AutoOrganizeUseCase`). They normalise validation errors and concurrency constraints before returning control to the handlers. |
|
||||
| Services | `py/services/*.py` | Existing services and scanners that mutate caches, write metadata, move files, and broadcast WebSocket updates. |
|
||||
|
||||
The routes exposed by `BaseModelRoutes` combine HTTP wiring with a handful of
|
||||
shared helper classes. Services surface filesystem and metadata operations,
|
||||
`ModelRouteUtils` bundles cache-sensitive mutations, and `ws_manager`
|
||||
coordinates fan-out to browser clients. The tables below expand the existing
|
||||
catalogue into explicit dependency maps and invariants so refactors can reason
|
||||
about the expectations each collaborator must uphold.
|
||||
## Handler responsibilities & contracts
|
||||
|
||||
## Dependency map and contracts
|
||||
`ModelHandlerSet` flattens the handler objects into the exact callables used by
|
||||
the registrar. The table below highlights the separation of concerns within
|
||||
the set and the invariants that must hold after each handler returns.
|
||||
|
||||
### Cache and metadata mutations
|
||||
|
||||
| Endpoint(s) | Delegate(s) | State touched | Invariants / contracts |
|
||||
| Handler | Key endpoints | Collaborators | Contracts |
|
||||
| --- | --- | --- | --- |
|
||||
| `/api/lm/{prefix}/delete` | `ModelRouteUtils.handle_delete_model()` | Removes files from disk, prunes `scanner._cache.raw_data`, awaits `scanner._cache.resort()`, calls `scanner._hash_index.remove_by_path()`. | Cache and hash index must no longer reference the deleted path; resort must complete before responding to keep pagination deterministic. |
|
||||
| `/api/lm/{prefix}/exclude` | `ModelRouteUtils.handle_exclude_model()` | Mutates metadata records, `scanner._cache.raw_data`, `scanner._hash_index`, `scanner._tags_count`, and `scanner._excluded_models`. | Excluded models remain discoverable via exclusion list while being hidden from listings; tag counts stay balanced after removal. |
|
||||
| `/api/lm/{prefix}/fetch-civitai` | `ModelRouteUtils.fetch_and_update_model()` | Reads `scanner._cache.raw_data`, writes metadata JSON through `MetadataManager`, syncs cache via `scanner.update_single_model_cache`. | Requires a cached SHA256 hash; cache entries must reflect merged metadata before formatted response is returned. |
|
||||
| `/api/lm/{prefix}/fetch-all-civitai` | `ModelRouteUtils.fetch_and_update_model()`, `ws_manager.broadcast()` | Iterates over cache, updates metadata files and cache records, optionally awaits `scanner._cache.resort()`. | Progress broadcasts follow started → processing → completed; if any model name changes, cache resort must run once before completion broadcast. |
|
||||
| `/api/lm/{prefix}/relink-civitai` | `ModelRouteUtils.handle_relink_civitai()` | Updates metadata on disk and resynchronizes the cache entry. | The new association must propagate to `scanner.update_single_model_cache` so duplicate resolution and listings reflect the change immediately. |
|
||||
| `/api/lm/{prefix}/replace-preview` | `ModelRouteUtils.handle_replace_preview()` | Writes optimized preview file, persists metadata via `MetadataManager`, updates cache with `scanner.update_preview_in_cache()`. | Preview path stored in metadata and cache must match the normalized file system path; NSFW level integer is synchronized across metadata and cache. |
|
||||
| `/api/lm/{prefix}/save-metadata` | `ModelRouteUtils.handle_save_metadata()` | Writes metadata JSON and ensures cache entry mirrors the latest content. | Metadata persistence must be atomic—cache data should match on-disk metadata before response emits success. |
|
||||
| `/api/lm/{prefix}/add-tags` | `ModelRouteUtils.handle_add_tags()` | Updates metadata tags, increments `scanner._tags_count`, and patches cached item. | Tag frequency map remains in sync with cache and metadata after increments. |
|
||||
| `/api/lm/{prefix}/rename` | `ModelRouteUtils.handle_rename_model()` | Renames files, metadata, previews; updates cache indices and hash mappings. | File moves succeed or rollback as a unit so cache state never points to a missing file; hash index entries track the new path. |
|
||||
| `/api/lm/{prefix}/bulk-delete` | `ModelRouteUtils.handle_bulk_delete_models()` | Delegates to `scanner.bulk_delete_models()` to delete files, trim cache, resort, and drop hash index entries. | Every requested path is removed from cache and index; resort happens once after bulk deletion. |
|
||||
| `/api/lm/{prefix}/verify-duplicates` | `ModelRouteUtils.handle_verify_duplicates()` | Recomputes hashes, updates metadata and cached entries if discrepancies found. | Hash metadata stored in cache must mirror recomputed values to guarantee future duplicate checks operate on current data. |
|
||||
| `/api/lm/{prefix}/scan` | `service.scan_models()` | Rescans filesystem, rebuilding scanner cache. | Scanner replaces its cache atomically so subsequent requests observe a consistent snapshot. |
|
||||
| `/api/lm/{prefix}/move_model` | `ModelMoveService.move_model()` | Moves files/directories and notifies scanner via service layer conventions. | Move operations respect filesystem invariants (target path exists, metadata follows file) and emit success/failure without leaving partial moves. |
|
||||
| `/api/lm/{prefix}/move_models_bulk` | `ModelMoveService.move_models_bulk()` | Batch move behavior as above. | Aggregated result enumerates successes/failures while preserving per-model atomicity. |
|
||||
| `/api/lm/{prefix}/auto-organize` (GET/POST) | `ModelFileService.auto_organize_models()`, `ws_manager.get_auto_organize_lock()`, `WebSocketProgressCallback` | Writes organized files, updates metadata, and streams progress snapshots. | Only one auto-organize job may run; lock must guard reentrancy and WebSocket updates must include latest progress payload consumed by polling route. |
|
||||
| `ModelPageView` | `/{prefix}` | `SettingsManager`, `server_i18n`, Jinja environment, `service.scanner` | Template is rendered with `is_initializing` flag when caches are cold; i18n filter is registered exactly once per environment instance. |
|
||||
| `ModelListingHandler` | `/api/lm/{prefix}/list` | `service.get_paginated_data`, `service.format_response` | Listings respect pagination query parameters and cap `page_size` at 100; every item is formatted before response. |
|
||||
| `ModelManagementHandler` | Mutations (delete, exclude, metadata, preview, tags, rename, bulk delete, duplicate verification) | `ModelLifecycleService`, `MetadataSyncService`, `PreviewAssetService`, `TagUpdateService`, scanner cache/index | Cache state mirrors filesystem changes: deletes prune cache & hash index, preview replacements synchronise metadata and cache NSFW levels, metadata saves trigger cache resort when names change. |
|
||||
| `ModelQueryHandler` | Read-only queries (top tags, folders, duplicates, metadata, URLs) | Service query helpers & scanner cache | Outputs always wrapped in `{"success": True}` when no error; duplicate/filename grouping omits empty entries; invalid parameters (e.g. missing `model_root`) return HTTP 400. |
|
||||
| `ModelDownloadHandler` | `/api/lm/download-model`, `/download-model-get`, `/download-progress/{id}`, `/cancel-download-get` | `DownloadModelUseCase`, `DownloadCoordinator`, `WebSocketManager` | Payload validation errors become HTTP 400 without mutating download progress cache; early-access failures surface as HTTP 401; successful downloads cache progress snapshots that back both WebSocket broadcasts and polling endpoints. |
|
||||
| `ModelCivitaiHandler` | CivitAI metadata routes | `MetadataSyncService`, metadata provider factory, `BulkMetadataRefreshUseCase` | `fetch_all_civitai` streams progress via `WebSocketBroadcastCallback`; version lookups validate model type before returning; local availability fields derive from hash lookups without mutating cache state. |
|
||||
| `ModelMoveHandler` | `move_model`, `move_models_bulk` | `ModelMoveService` | Moves execute atomically per request; bulk operations aggregate success/failure per file set. |
|
||||
| `ModelAutoOrganizeHandler` | `/api/lm/{prefix}/auto-organize` (GET/POST), `/auto-organize-progress` | `AutoOrganizeUseCase`, `WebSocketProgressCallback`, `WebSocketManager` | Enforces single-flight execution using the shared lock; progress broadcasts remain available to polling clients until explicitly cleared; conflicts return HTTP 409 with a descriptive error. |
|
||||
|
||||
### Download and WebSocket flows
|
||||
## Use case boundaries
|
||||
|
||||
| Endpoint(s) | Delegate(s) | State touched | Invariants / contracts |
|
||||
Each use case exposes a narrow asynchronous API that hides the underlying
|
||||
services. Their error mapping is essential for predictable HTTP responses.
|
||||
|
||||
| Use case | Entry point | Dependencies | Guarantees |
|
||||
| --- | --- | --- | --- |
|
||||
| `/api/lm/download-model` (POST) & `/api/lm/download-model-get` (GET) | `ModelRouteUtils.handle_download_model()`, `ServiceRegistry.get_download_manager()` | Schedules downloads, registers `ws_manager.broadcast_download_progress()` callback that stores progress in `ws_manager._download_progress`. | Download IDs remain stable across POST/GET helpers; every progress callback persists a timestamped entry so `/download-progress` and WebSocket clients share consistent snapshots. |
|
||||
| `/api/lm/cancel-download-get` | `ModelRouteUtils.handle_cancel_download()` | Signals download manager, prunes `ws_manager._download_progress`, and emits cancellation broadcast. | Cancel requests must tolerate missing IDs gracefully while ensuring cached progress is removed once cancellation succeeds. |
|
||||
| `/api/lm/download-progress/{download_id}` | `ws_manager.get_download_progress()` | Reads cached progress dictionary. | Returns `404` when progress is absent; successful payload surfaces the numeric `progress` stored during broadcasts. |
|
||||
| `/api/lm/{prefix}/fetch-all-civitai` | `ws_manager.broadcast()` | Broadcast loop described above. | Broadcast cadence cannot skip completion/error messages so clients know when to clear UI spinners. |
|
||||
| `/api/lm/{prefix}/auto-organize-progress` | `ws_manager.get_auto_organize_progress()` | Reads cached progress snapshot. | Route returns cached payload verbatim; absence yields `404` to signal idle state. |
|
||||
| `DownloadModelUseCase` | `execute(payload)` | `DownloadCoordinator.schedule_download` | Translates `ValueError` into `DownloadModelValidationError` for HTTP 400, recognises early-access errors (`"401"` in message) and surfaces them as `DownloadModelEarlyAccessError`, forwards success dictionaries untouched. |
|
||||
| `AutoOrganizeUseCase` | `execute(file_paths, progress_callback)` | `ModelFileService.auto_organize_models`, `WebSocketManager` lock | Guarded by `ws_manager` lock + status checks; raises `AutoOrganizeInProgressError` before invoking the file service when another run is already active. |
|
||||
| `BulkMetadataRefreshUseCase` | `execute_with_error_handling(progress_callback)` | `MetadataSyncService`, `SettingsManager`, `WebSocketBroadcastCallback` | Iterates through cached models, applies metadata sync, emits progress snapshots that handlers broadcast unchanged. |
|
||||
|
||||
### Read-only queries
|
||||
## Maintaining legacy contracts
|
||||
|
||||
| Endpoint(s) | Delegate(s) | State touched | Invariants / contracts |
|
||||
| --- | --- | --- | --- |
|
||||
| `/api/lm/{prefix}/list` | `service.get_paginated_data()`, `service.format_response()` | Reads service-managed pagination data. | Formatting must be applied to every item before response; pagination metadata echoes service result. |
|
||||
| `/api/lm/{prefix}/top-tags` | `service.get_top_tags()` | Reads aggregated tag counts. | Limit parameter bounded to `[1, 100]`; response always wraps tags in `{success: True}` envelope. |
|
||||
| `/api/lm/{prefix}/base-models` | `service.get_base_models()` | Reads service data. | Same limit handling as tags. |
|
||||
| `/api/lm/{prefix}/roots` | `service.get_model_roots()` | Reads configured roots. | Always returns `{success: True, roots: [...]}`. |
|
||||
| `/api/lm/{prefix}/folders` | `service.scanner.get_cached_data()` | Reads folder summaries from cache. | Cache access must tolerate initialization phases by surfacing errors via HTTP 500. |
|
||||
| `/api/lm/{prefix}/folder-tree` | `service.get_folder_tree()` | Reads derived tree for requested root. | Rejects missing `model_root` with HTTP 400. |
|
||||
| `/api/lm/{prefix}/unified-folder-tree` | `service.get_unified_folder_tree()` | Aggregated folder tree. | Returns `{success: True, tree: ...}` or 500 on error. |
|
||||
| `/api/lm/{prefix}/find-duplicates` | `service.find_duplicate_hashes()`, `service.scanner.get_cached_data()`, `service.get_path_by_hash()` | Reads cache and hash index to format duplicates. | Only returns groups with more than one resolved model. |
|
||||
| `/api/lm/{prefix}/find-filename-conflicts` | `service.find_duplicate_filenames()`, `service.scanner.get_cached_data()`, `service.scanner.get_hash_by_filename()` | Similar read-only assembly. | Includes resolved main index entry when available; empty `models` groups are omitted. |
|
||||
| `/api/lm/{prefix}/get-notes` | `service.get_model_notes()` | Reads persisted notes. | Missing notes produce HTTP 404 with explicit error message. |
|
||||
| `/api/lm/{prefix}/preview-url` | `service.get_model_preview_url()` | Resolves static URL. | Successful responses wrap URL in `{success: True}`; missing preview yields 404 error payload. |
|
||||
| `/api/lm/{prefix}/civitai-url` | `service.get_model_civitai_url()` | Returns remote permalink info. | Response envelope matches preview pattern. |
|
||||
| `/api/lm/{prefix}/metadata` | `service.get_model_metadata()` | Reads metadata JSON. | Responds with raw metadata dict or 500 on failure. |
|
||||
| `/api/lm/{prefix}/model-description` | `service.get_model_description()` | Returns formatted description string. | Always JSON with success boolean. |
|
||||
| `/api/lm/{prefix}/relative-paths` | `service.get_relative_paths()` | Resolves filesystem suggestions. | Maintains read-only contract. |
|
||||
| `/api/lm/{prefix}/civitai/versions/{model_id}` | `get_default_metadata_provider()`, `service.has_hash()`, `service.get_path_by_hash()` | Reads remote API, cross-references cache. | Versions payload includes `existsLocally`/`localPath` only when hashes match local indices. |
|
||||
| `/api/lm/{prefix}/civitai/model/version/{modelVersionId}` | `get_default_metadata_provider()` | Remote metadata lookup. | Errors propagate as JSON with `{success: False}` payload. |
|
||||
| `/api/lm/{prefix}/civitai/model/hash/{hash}` | `get_default_metadata_provider()` | Remote metadata lookup. | Missing hashes return 404 with `{success: False}`. |
|
||||
The refactor preserves the invariants called out in the previous architecture
|
||||
notes. The most critical ones are reiterated here to emphasise the
|
||||
collaboration points:
|
||||
|
||||
### Template rendering and initialization
|
||||
1. **Cache mutations** – Delete, exclude, rename, and bulk delete operations are
|
||||
channelled through `ModelManagementHandler`. The handler delegates to
|
||||
`ModelLifecycleService` or `MetadataSyncService`, and the scanner cache is
|
||||
mutated in-place before the handler returns. The accompanying tests assert
|
||||
that `scanner._cache.raw_data` and `scanner._hash_index` stay in sync after
|
||||
each mutation.
|
||||
2. **Preview updates** – `PreviewAssetService.replace_preview` writes the new
|
||||
asset, `MetadataSyncService` persists the JSON metadata, and
|
||||
`scanner.update_preview_in_cache` mirrors the change. The handler returns
|
||||
the static URL produced by `config.get_preview_static_url`, keeping browser
|
||||
clients in lockstep with disk state.
|
||||
3. **Download progress** – `DownloadCoordinator.schedule_download` generates the
|
||||
download identifier, registers a WebSocket progress callback, and caches the
|
||||
latest numeric progress via `WebSocketManager`. Both `download_model`
|
||||
responses and `/download-progress/{id}` polling read from the same cache to
|
||||
guarantee consistent progress reporting across transports.
|
||||
|
||||
| Endpoint(s) | Delegate(s) | State touched | Invariants / contracts |
|
||||
| --- | --- | --- | --- |
|
||||
| `/{prefix}` | `handle_models_page` | Reads configuration via `settings`, sets locale with `server_i18n`, pulls cached folders through `service.scanner.get_cached_data()`, renders Jinja template. | Template rendering must tolerate scanner initialization by flagging `is_initializing`; i18n filter is attached exactly once per environment to avoid duplicate registration errors. |
|
||||
## Extending the stack
|
||||
|
||||
### Contract sequences
|
||||
To add a new shared route:
|
||||
|
||||
The following high-level sequences show how the collaborating services work
|
||||
together for the most stateful operations:
|
||||
1. Declare it in `COMMON_ROUTE_DEFINITIONS` using a unique handler name.
|
||||
2. Implement the corresponding coroutine on one of the handlers inside
|
||||
`ModelHandlerSet` (or introduce a new handler class when the concern does not
|
||||
fit existing ones).
|
||||
3. Inject additional dependencies in `BaseModelRoutes._create_handler_set` by
|
||||
wiring services or use cases through the constructor parameters.
|
||||
|
||||
```
|
||||
delete_model request
|
||||
→ BaseModelRoutes.delete_model
|
||||
→ ModelRouteUtils.handle_delete_model
|
||||
→ filesystem delete + metadata cleanup
|
||||
→ scanner._cache.raw_data prune
|
||||
→ await scanner._cache.resort()
|
||||
→ scanner._hash_index.remove_by_path()
|
||||
```
|
||||
|
||||
```
|
||||
replace_preview request
|
||||
→ BaseModelRoutes.replace_preview
|
||||
→ ModelRouteUtils.handle_replace_preview
|
||||
→ ExifUtils.optimize_image / config.get_preview_static_url
|
||||
→ MetadataManager.save_metadata
|
||||
→ scanner.update_preview_in_cache(model_path, preview_path, nsfw_level)
|
||||
```
|
||||
|
||||
```
|
||||
download_model request
|
||||
→ BaseModelRoutes.download_model
|
||||
→ ModelRouteUtils.handle_download_model
|
||||
→ ServiceRegistry.get_download_manager().download_from_civitai(..., progress_callback)
|
||||
→ ws_manager.broadcast_download_progress(download_id, data)
|
||||
→ ws_manager._download_progress[download_id] updated with timestamp
|
||||
→ /api/lm/download-progress/{id} polls ws_manager.get_download_progress
|
||||
```
|
||||
|
||||
These contracts complement the tables above: if any collaborator changes its
|
||||
behavior, the invariants called out here must continue to hold for the routes
|
||||
to remain predictable.
|
||||
Model-specific routes should continue to be registered inside the subclass
|
||||
implementation of `setup_specific_routes`, reusing the shared registrar where
|
||||
possible.
|
||||
|
||||
89
docs/architecture/recipe_routes.md
Normal file
89
docs/architecture/recipe_routes.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# Recipe route architecture
|
||||
|
||||
The recipe routing stack now mirrors the modular model route design. HTTP
|
||||
bindings, controller wiring, handler orchestration, and business rules live in
|
||||
separate layers so new behaviours can be added without re-threading the entire
|
||||
feature. The diagram below outlines the flow for a typical request:
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph HTTP
|
||||
A[RecipeRouteRegistrar] -->|binds| B[RecipeRoutes controller]
|
||||
end
|
||||
subgraph Application
|
||||
B --> C[RecipeHandlerSet]
|
||||
C --> D1[Handlers]
|
||||
D1 --> E1[Use cases]
|
||||
E1 --> F1[Services / scanners]
|
||||
end
|
||||
subgraph Side Effects
|
||||
F1 --> G1[Cache & fingerprint index]
|
||||
F1 --> G2[Metadata files]
|
||||
F1 --> G3[Temporary shares]
|
||||
end
|
||||
```
|
||||
|
||||
## Layer responsibilities
|
||||
|
||||
| Layer | Module(s) | Responsibility |
|
||||
| --- | --- | --- |
|
||||
| Registrar | `py/routes/recipe_route_registrar.py` | Declarative list of every recipe endpoint and helper methods that bind them to an `aiohttp` application. |
|
||||
| Controller | `py/routes/base_recipe_routes.py`, `py/routes/recipe_routes.py` | Lazily resolves scanners/clients from the service registry, wires shared templates/i18n, instantiates `RecipeHandlerSet`, and exposes a `{handler_name: coroutine}` mapping for the registrar. |
|
||||
| Handler set | `py/routes/handlers/recipe_handlers.py` | Thin HTTP adapters grouped by concern (page view, listings, queries, mutations, sharing). They normalise responses and translate service exceptions into HTTP status codes. |
|
||||
| Services & scanners | `py/services/recipes/*.py`, `py/services/recipe_scanner.py`, `py/services/service_registry.py` | Concrete business logic: metadata parsing, persistence, sharing, fingerprint/index maintenance, and cache refresh. |
|
||||
|
||||
## Handler responsibilities & invariants
|
||||
|
||||
`RecipeHandlerSet` flattens purpose-built handler objects into the callables the
|
||||
registrar binds. Each handler is responsible for a narrow concern and enforces a
|
||||
set of invariants before returning:
|
||||
|
||||
| Handler | Key endpoints | Collaborators | Contracts |
|
||||
| --- | --- | --- | --- |
|
||||
| `RecipePageView` | `/loras/recipes` | `SettingsManager`, `server_i18n`, Jinja environment, recipe scanner getter | Template rendered with `is_initializing` flag when caches are still warming; i18n filter registered exactly once per environment instance. |
|
||||
| `RecipeListingHandler` | `/api/lm/recipes`, `/api/lm/recipe/{id}` | `recipe_scanner.get_paginated_data`, `recipe_scanner.get_recipe_by_id` | Listings respect pagination and search filters; every item receives a `file_url` fallback even when metadata is incomplete; missing recipes become HTTP 404. |
|
||||
| `RecipeQueryHandler` | Tag/base-model stats, syntax, LoRA lookups | Recipe scanner cache, `format_recipe_file_url` helper | Cache snapshots are reused without forcing refresh; duplicate lookups collapse groups by fingerprint; syntax lookups return helpful errors when LoRAs are absent. |
|
||||
| `RecipeManagementHandler` | Save, update, reconnect, bulk delete, widget ingest | `RecipePersistenceService`, `RecipeAnalysisService`, recipe scanner | Persistence results propagate HTTP status codes; fingerprint/index updates flow through the scanner before returning; validation errors surface as HTTP 400 without touching disk. |
|
||||
| `RecipeAnalysisHandler` | Uploaded/local/remote analysis | `RecipeAnalysisService`, `civitai_client`, recipe scanner | Unsupported content types map to HTTP 400; download errors (`RecipeDownloadError`) are not retried; every response includes a `loras` array for client compatibility. |
|
||||
| `RecipeSharingHandler` | Share + download | `RecipeSharingService`, recipe scanner | Share responses provide a stable download URL and filename; expired shares surface as HTTP 404; downloads stream via `web.FileResponse` with attachment headers. |
|
||||
|
||||
## Use case boundaries
|
||||
|
||||
The dedicated services encapsulate long-running work so handlers stay thin.
|
||||
|
||||
| Use case | Entry point | Dependencies | Guarantees |
|
||||
| --- | --- | --- | --- |
|
||||
| `RecipeAnalysisService` | `analyze_uploaded_image`, `analyze_remote_image`, `analyze_local_image`, `analyze_widget_metadata` | `ExifUtils`, `RecipeParserFactory`, downloader factory, optional metadata collector/processor | Normalises missing/invalid payloads into `RecipeValidationError`; generates consistent fingerprint data to keep duplicate detection stable; temporary files are cleaned up after every analysis path. |
|
||||
| `RecipePersistenceService` | `save_recipe`, `delete_recipe`, `update_recipe`, `reconnect_lora`, `bulk_delete`, `save_recipe_from_widget` | `ExifUtils`, recipe scanner, card preview sizing constants | Writes images/JSON metadata atomically; updates scanner caches and hash indices before returning; recalculates fingerprints whenever LoRA assignments change. |
|
||||
| `RecipeSharingService` | `share_recipe`, `prepare_download` | `tempfile`, recipe scanner | Copies originals to TTL-managed temp files; metadata lookups re-use the scanner; expired shares trigger cleanup and `RecipeNotFoundError`. |
|
||||
|
||||
## Maintaining critical invariants
|
||||
|
||||
* **Cache updates** – Mutations (`save`, `delete`, `bulk_delete`, `update`) call
|
||||
back into the recipe scanner to mutate the in-memory cache and fingerprint
|
||||
index before returning a response. Tests assert that these methods are invoked
|
||||
even when stubbing persistence.
|
||||
* **Fingerprint management** – `RecipePersistenceService` recomputes
|
||||
fingerprints whenever LoRA metadata changes and duplicate lookups use those
|
||||
fingerprints to group recipes. Handlers bubble the resulting IDs so clients
|
||||
can merge duplicates without an extra fetch.
|
||||
* **Metadata synchronisation** – Saving or reconnecting a recipe updates the
|
||||
JSON sidecar, refreshes embedded metadata via `ExifUtils`, and instructs the
|
||||
scanner to resort its cache. Sharing relies on this metadata to generate
|
||||
filenames and ensure downloads stay in sync with on-disk state.
|
||||
|
||||
## Extending the stack
|
||||
|
||||
1. Declare the new endpoint in `ROUTE_DEFINITIONS` with a unique handler name.
|
||||
2. Implement the coroutine on an existing handler or introduce a new handler
|
||||
class inside `py/routes/handlers/recipe_handlers.py` when the concern does
|
||||
not fit existing ones.
|
||||
3. Wire additional collaborators inside
|
||||
`BaseRecipeRoutes._create_handler_set` (inject new services or factories) and
|
||||
expose helper getters on the handler owner if the handler needs to share
|
||||
utilities.
|
||||
|
||||
Integration tests in `tests/routes/test_recipe_routes.py` exercise the listing,
|
||||
mutation, analysis-error, and sharing paths end-to-end, ensuring the controller
|
||||
and handler wiring remains valid as new capabilities are added.
|
||||
|
||||
@@ -166,7 +166,7 @@ class LoraManager:
|
||||
RecipeRoutes.setup_routes(app)
|
||||
UpdateRoutes.setup_routes(app)
|
||||
MiscRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
|
||||
|
||||
# Setup WebSocket routes that are shared across all model types
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
|
||||
@@ -8,13 +8,29 @@ import jinja2
|
||||
from aiohttp import web
|
||||
|
||||
from ..config import config
|
||||
from ..services.metadata_service import get_default_metadata_provider
|
||||
from ..services.download_coordinator import DownloadCoordinator
|
||||
from ..services.downloader import get_downloader
|
||||
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||
from ..services.metadata_sync_service import MetadataSyncService
|
||||
from ..services.model_file_service import ModelFileService, ModelMoveService
|
||||
from ..services.settings_manager import settings as default_settings
|
||||
from ..services.websocket_manager import ws_manager as default_ws_manager
|
||||
from ..services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ..services.model_lifecycle_service import ModelLifecycleService
|
||||
from ..services.preview_asset_service import PreviewAssetService
|
||||
from ..services.server_i18n import server_i18n as default_server_i18n
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import settings as default_settings
|
||||
from ..services.tag_update_service import TagUpdateService
|
||||
from ..services.websocket_manager import ws_manager as default_ws_manager
|
||||
from ..services.use_cases import (
|
||||
AutoOrganizeUseCase,
|
||||
BulkMetadataRefreshUseCase,
|
||||
DownloadModelUseCase,
|
||||
)
|
||||
from ..services.websocket_progress_callback import (
|
||||
WebSocketBroadcastCallback,
|
||||
WebSocketProgressCallback,
|
||||
)
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar
|
||||
from .handlers.model_handlers import (
|
||||
ModelAutoOrganizeHandler,
|
||||
@@ -59,11 +75,31 @@ class BaseModelRoutes(ABC):
|
||||
|
||||
self.model_file_service: ModelFileService | None = None
|
||||
self.model_move_service: ModelMoveService | None = None
|
||||
self.model_lifecycle_service: ModelLifecycleService | None = None
|
||||
self.websocket_progress_callback = WebSocketProgressCallback()
|
||||
self.metadata_progress_callback = WebSocketBroadcastCallback()
|
||||
|
||||
self._handler_set: ModelHandlerSet | None = None
|
||||
self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None
|
||||
|
||||
self._preview_service = PreviewAssetService(
|
||||
metadata_manager=MetadataManager,
|
||||
downloader_factory=get_downloader,
|
||||
exif_utils=ExifUtils,
|
||||
)
|
||||
self._metadata_sync_service = MetadataSyncService(
|
||||
metadata_manager=MetadataManager,
|
||||
preview_service=self._preview_service,
|
||||
settings=settings_service,
|
||||
default_metadata_provider_factory=metadata_provider_factory,
|
||||
metadata_provider_selector=get_metadata_provider,
|
||||
)
|
||||
self._tag_update_service = TagUpdateService(metadata_manager=MetadataManager)
|
||||
self._download_coordinator = DownloadCoordinator(
|
||||
ws_manager=self._ws_manager,
|
||||
download_manager_factory=ServiceRegistry.get_download_manager,
|
||||
)
|
||||
|
||||
if service is not None:
|
||||
self.attach_service(service)
|
||||
|
||||
@@ -73,6 +109,12 @@ class BaseModelRoutes(ABC):
|
||||
self.model_type = service.model_type
|
||||
self.model_file_service = ModelFileService(service.scanner, service.model_type)
|
||||
self.model_move_service = ModelMoveService(service.scanner)
|
||||
self.model_lifecycle_service = ModelLifecycleService(
|
||||
scanner=service.scanner,
|
||||
metadata_manager=MetadataManager,
|
||||
metadata_loader=self._metadata_sync_service.load_local_metadata,
|
||||
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
|
||||
)
|
||||
self._handler_set = None
|
||||
self._handler_mapping = None
|
||||
|
||||
@@ -98,9 +140,28 @@ class BaseModelRoutes(ABC):
|
||||
parse_specific_params=self._parse_specific_params,
|
||||
logger=logger,
|
||||
)
|
||||
management = ModelManagementHandler(service=service, logger=logger)
|
||||
management = ModelManagementHandler(
|
||||
service=service,
|
||||
logger=logger,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
preview_service=self._preview_service,
|
||||
tag_update_service=self._tag_update_service,
|
||||
lifecycle_service=self._ensure_lifecycle_service(),
|
||||
)
|
||||
query = ModelQueryHandler(service=service, logger=logger)
|
||||
download = ModelDownloadHandler(ws_manager=self._ws_manager, logger=logger)
|
||||
download_use_case = DownloadModelUseCase(download_coordinator=self._download_coordinator)
|
||||
download = ModelDownloadHandler(
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
download_use_case=download_use_case,
|
||||
download_coordinator=self._download_coordinator,
|
||||
)
|
||||
metadata_refresh_use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
settings_service=self._settings,
|
||||
logger=logger,
|
||||
)
|
||||
civitai = ModelCivitaiHandler(
|
||||
service=service,
|
||||
settings_service=self._settings,
|
||||
@@ -110,10 +171,17 @@ class BaseModelRoutes(ABC):
|
||||
validate_model_type=self._validate_civitai_model_type,
|
||||
expected_model_types=self._get_expected_model_types,
|
||||
find_model_file=self._find_model_file,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
metadata_refresh_use_case=metadata_refresh_use_case,
|
||||
metadata_progress_callback=self.metadata_progress_callback,
|
||||
)
|
||||
move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger)
|
||||
auto_organize = ModelAutoOrganizeHandler(
|
||||
auto_organize_use_case = AutoOrganizeUseCase(
|
||||
file_service=self._ensure_file_service(),
|
||||
lock_provider=self._ws_manager,
|
||||
)
|
||||
auto_organize = ModelAutoOrganizeHandler(
|
||||
use_case=auto_organize_use_case,
|
||||
progress_callback=self.websocket_progress_callback,
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
@@ -167,10 +235,6 @@ class BaseModelRoutes(ABC):
|
||||
"""Expose handlers for subclasses or tests."""
|
||||
return self._ensure_handler_mapping()[name]
|
||||
|
||||
@property
|
||||
def utils(self) -> ModelRouteUtils: # pragma: no cover - compatibility shim
|
||||
return ModelRouteUtils
|
||||
|
||||
def _ensure_service(self):
|
||||
if self.service is None:
|
||||
raise RuntimeError("Model service has not been attached")
|
||||
@@ -188,6 +252,17 @@ class BaseModelRoutes(ABC):
|
||||
self.model_move_service = ModelMoveService(service.scanner)
|
||||
return self.model_move_service
|
||||
|
||||
def _ensure_lifecycle_service(self) -> ModelLifecycleService:
|
||||
if self.model_lifecycle_service is None:
|
||||
service = self._ensure_service()
|
||||
self.model_lifecycle_service = ModelLifecycleService(
|
||||
scanner=service.scanner,
|
||||
metadata_manager=MetadataManager,
|
||||
metadata_loader=self._metadata_sync_service.load_local_metadata,
|
||||
recipe_scanner_factory=ServiceRegistry.get_recipe_scanner,
|
||||
)
|
||||
return self.model_lifecycle_service
|
||||
|
||||
def _make_handler_proxy(self, name: str) -> Callable[[web.Request], web.StreamResponse]:
|
||||
async def proxy(request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
|
||||
217
py/routes/base_recipe_routes.py
Normal file
217
py/routes/base_recipe_routes.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Base infrastructure shared across recipe routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, Mapping
|
||||
|
||||
import jinja2
|
||||
from aiohttp import web
|
||||
|
||||
from ..config import config
|
||||
from ..recipes import RecipeParserFactory
|
||||
from ..services.downloader import get_downloader
|
||||
from ..services.recipes import (
|
||||
RecipeAnalysisService,
|
||||
RecipePersistenceService,
|
||||
RecipeSharingService,
|
||||
)
|
||||
from ..services.server_i18n import server_i18n
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import settings
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from .handlers.recipe_handlers import (
|
||||
RecipeAnalysisHandler,
|
||||
RecipeHandlerSet,
|
||||
RecipeListingHandler,
|
||||
RecipeManagementHandler,
|
||||
RecipePageView,
|
||||
RecipeQueryHandler,
|
||||
RecipeSharingHandler,
|
||||
)
|
||||
from .recipe_route_registrar import ROUTE_DEFINITIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseRecipeRoutes:
|
||||
"""Common dependency and startup wiring for recipe routes."""
|
||||
|
||||
_HANDLER_NAMES: tuple[str, ...] = tuple(
|
||||
definition.handler_name for definition in ROUTE_DEFINITIONS
|
||||
)
|
||||
|
||||
template_name: str = "recipes.html"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.recipe_scanner = None
|
||||
self.lora_scanner = None
|
||||
self.civitai_client = None
|
||||
self.settings = settings
|
||||
self.server_i18n = server_i18n
|
||||
self.template_env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(config.templates_path),
|
||||
autoescape=True,
|
||||
)
|
||||
|
||||
self._i18n_registered = False
|
||||
self._startup_hooks_registered = False
|
||||
self._handler_set: RecipeHandlerSet | None = None
|
||||
self._handler_mapping: dict[str, Callable] | None = None
|
||||
|
||||
async def attach_dependencies(self, app: web.Application | None = None) -> None:
|
||||
"""Resolve shared services from the registry."""
|
||||
|
||||
await self._ensure_services()
|
||||
self._ensure_i18n_filter()
|
||||
|
||||
async def ensure_dependencies_ready(self) -> None:
|
||||
"""Ensure dependencies are available for request handlers."""
|
||||
|
||||
if self.recipe_scanner is None or self.civitai_client is None:
|
||||
await self.attach_dependencies()
|
||||
|
||||
def register_startup_hooks(self, app: web.Application) -> None:
|
||||
"""Register startup hooks once for dependency wiring."""
|
||||
|
||||
if self._startup_hooks_registered:
|
||||
return
|
||||
|
||||
app.on_startup.append(self.attach_dependencies)
|
||||
app.on_startup.append(self.prewarm_cache)
|
||||
self._startup_hooks_registered = True
|
||||
|
||||
async def prewarm_cache(self, app: web.Application | None = None) -> None:
|
||||
"""Pre-load recipe and LoRA caches on startup."""
|
||||
|
||||
try:
|
||||
await self.attach_dependencies(app)
|
||||
|
||||
if self.lora_scanner is not None:
|
||||
await self.lora_scanner.get_cached_data()
|
||||
hash_index = getattr(self.lora_scanner, "_hash_index", None)
|
||||
if hash_index is not None and hasattr(hash_index, "_hash_to_path"):
|
||||
_ = len(hash_index._hash_to_path)
|
||||
|
||||
if self.recipe_scanner is not None:
|
||||
await self.recipe_scanner.get_cached_data(force_refresh=True)
|
||||
except Exception as exc:
|
||||
logger.error("Error pre-warming recipe cache: %s", exc, exc_info=True)
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable]:
|
||||
"""Return a mapping of handler name to coroutine for registrar binding."""
|
||||
|
||||
if self._handler_mapping is None:
|
||||
handler_set = self._create_handler_set()
|
||||
self._handler_set = handler_set
|
||||
self._handler_mapping = handler_set.to_route_mapping()
|
||||
return self._handler_mapping
|
||||
|
||||
# Internal helpers -------------------------------------------------
|
||||
|
||||
async def _ensure_services(self) -> None:
|
||||
if self.recipe_scanner is None:
|
||||
self.recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
self.lora_scanner = getattr(self.recipe_scanner, "_lora_scanner", None)
|
||||
|
||||
if self.civitai_client is None:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
def _ensure_i18n_filter(self) -> None:
|
||||
if not self._i18n_registered:
|
||||
self.template_env.filters["t"] = self.server_i18n.create_template_filter()
|
||||
self._i18n_registered = True
|
||||
|
||||
def get_handler_owner(self):
|
||||
"""Return the object supplying bound handler coroutines."""
|
||||
|
||||
if self._handler_set is None:
|
||||
self._handler_set = self._create_handler_set()
|
||||
return self._handler_set
|
||||
|
||||
def _create_handler_set(self) -> RecipeHandlerSet:
|
||||
recipe_scanner_getter = lambda: self.recipe_scanner
|
||||
civitai_client_getter = lambda: self.civitai_client
|
||||
|
||||
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
if not standalone_mode:
|
||||
from ..metadata_collector import get_metadata # type: ignore[import-not-found]
|
||||
from ..metadata_collector.metadata_processor import ( # type: ignore[import-not-found]
|
||||
MetadataProcessor,
|
||||
)
|
||||
from ..metadata_collector.metadata_registry import ( # type: ignore[import-not-found]
|
||||
MetadataRegistry,
|
||||
)
|
||||
else: # pragma: no cover - optional dependency path
|
||||
get_metadata = None # type: ignore[assignment]
|
||||
MetadataProcessor = None # type: ignore[assignment]
|
||||
MetadataRegistry = None # type: ignore[assignment]
|
||||
|
||||
analysis_service = RecipeAnalysisService(
|
||||
exif_utils=ExifUtils,
|
||||
recipe_parser_factory=RecipeParserFactory,
|
||||
downloader_factory=get_downloader,
|
||||
metadata_collector=get_metadata,
|
||||
metadata_processor_cls=MetadataProcessor,
|
||||
metadata_registry_cls=MetadataRegistry,
|
||||
standalone_mode=standalone_mode,
|
||||
logger=logger,
|
||||
)
|
||||
persistence_service = RecipePersistenceService(
|
||||
exif_utils=ExifUtils,
|
||||
card_preview_width=CARD_PREVIEW_WIDTH,
|
||||
logger=logger,
|
||||
)
|
||||
sharing_service = RecipeSharingService(logger=logger)
|
||||
|
||||
page_view = RecipePageView(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
settings_service=self.settings,
|
||||
server_i18n=self.server_i18n,
|
||||
template_env=self.template_env,
|
||||
template_name=self.template_name,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
)
|
||||
listing = RecipeListingHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
)
|
||||
query = RecipeQueryHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
format_recipe_file_url=listing.format_recipe_file_url,
|
||||
logger=logger,
|
||||
)
|
||||
management = RecipeManagementHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
persistence_service=persistence_service,
|
||||
analysis_service=analysis_service,
|
||||
)
|
||||
analysis = RecipeAnalysisHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
civitai_client_getter=civitai_client_getter,
|
||||
logger=logger,
|
||||
analysis_service=analysis_service,
|
||||
)
|
||||
sharing = RecipeSharingHandler(
|
||||
ensure_dependencies_ready=self.ensure_dependencies_ready,
|
||||
recipe_scanner_getter=recipe_scanner_getter,
|
||||
logger=logger,
|
||||
sharing_service=sharing_service,
|
||||
)
|
||||
|
||||
return RecipeHandlerSet(
|
||||
page_view=page_view,
|
||||
listing=listing,
|
||||
query=query,
|
||||
management=management,
|
||||
analysis=analysis,
|
||||
sharing=sharing,
|
||||
)
|
||||
|
||||
61
py/routes/example_images_route_registrar.py
Normal file
61
py/routes/example_images_route_registrar.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Route registrar for example image endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouteDefinition:
|
||||
"""Declarative configuration for a HTTP route."""
|
||||
|
||||
method: str
|
||||
path: str
|
||||
handler_name: str
|
||||
|
||||
|
||||
ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("POST", "/api/lm/download-example-images", "download_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/import-example-images", "import_example_images"),
|
||||
RouteDefinition("GET", "/api/lm/example-images-status", "get_example_images_status"),
|
||||
RouteDefinition("POST", "/api/lm/pause-example-images", "pause_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/resume-example-images", "resume_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/open-example-images-folder", "open_example_images_folder"),
|
||||
RouteDefinition("GET", "/api/lm/example-image-files", "get_example_image_files"),
|
||||
RouteDefinition("GET", "/api/lm/has-example-images", "has_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/delete-example-image", "delete_example_image"),
|
||||
RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"),
|
||||
)
|
||||
|
||||
|
||||
class ExampleImagesRouteRegistrar:
|
||||
"""Bind declarative example image routes to an aiohttp router."""
|
||||
|
||||
_METHOD_MAP = {
|
||||
"GET": "add_get",
|
||||
"POST": "add_post",
|
||||
"PUT": "add_put",
|
||||
"DELETE": "add_delete",
|
||||
}
|
||||
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_routes(
|
||||
self,
|
||||
handler_lookup: Mapping[str, Callable[[web.Request], object]],
|
||||
*,
|
||||
definitions: Iterable[RouteDefinition] = ROUTE_DEFINITIONS,
|
||||
) -> None:
|
||||
"""Register each route definition using the supplied handlers."""
|
||||
|
||||
for definition in definitions:
|
||||
handler = handler_lookup[definition.handler_name]
|
||||
self._bind_route(definition.method, definition.path, handler)
|
||||
|
||||
def _bind_route(self, method: str, path: str, handler: Callable[[web.Request], object]) -> None:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
@@ -1,74 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from ..utils.example_images_download_manager import DownloadManager
|
||||
from ..utils.example_images_processor import ExampleImagesProcessor
|
||||
from typing import Callable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from .example_images_route_registrar import ExampleImagesRouteRegistrar
|
||||
from .handlers.example_images_handlers import (
|
||||
ExampleImagesDownloadHandler,
|
||||
ExampleImagesFileHandler,
|
||||
ExampleImagesHandlerSet,
|
||||
ExampleImagesManagementHandler,
|
||||
)
|
||||
from ..services.use_cases.example_images import (
|
||||
DownloadExampleImagesUseCase,
|
||||
ImportExampleImagesUseCase,
|
||||
)
|
||||
from ..utils.example_images_download_manager import (
|
||||
DownloadManager,
|
||||
get_default_download_manager,
|
||||
)
|
||||
from ..utils.example_images_file_manager import ExampleImagesFileManager
|
||||
from ..services.websocket_manager import ws_manager
|
||||
from ..utils.example_images_processor import ExampleImagesProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExampleImagesRoutes:
|
||||
"""Routes for example images related functionality"""
|
||||
|
||||
@staticmethod
|
||||
def setup_routes(app):
|
||||
"""Register example images routes"""
|
||||
app.router.add_post('/api/lm/download-example-images', ExampleImagesRoutes.download_example_images)
|
||||
app.router.add_post('/api/lm/import-example-images', ExampleImagesRoutes.import_example_images)
|
||||
app.router.add_get('/api/lm/example-images-status', ExampleImagesRoutes.get_example_images_status)
|
||||
app.router.add_post('/api/lm/pause-example-images', ExampleImagesRoutes.pause_example_images)
|
||||
app.router.add_post('/api/lm/resume-example-images', ExampleImagesRoutes.resume_example_images)
|
||||
app.router.add_post('/api/lm/open-example-images-folder', ExampleImagesRoutes.open_example_images_folder)
|
||||
app.router.add_get('/api/lm/example-image-files', ExampleImagesRoutes.get_example_image_files)
|
||||
app.router.add_get('/api/lm/has-example-images', ExampleImagesRoutes.has_example_images)
|
||||
app.router.add_post('/api/lm/delete-example-image', ExampleImagesRoutes.delete_example_image)
|
||||
app.router.add_post('/api/lm/force-download-example-images', ExampleImagesRoutes.force_download_example_images)
|
||||
"""Route controller for example image endpoints."""
|
||||
|
||||
@staticmethod
|
||||
async def download_example_images(request):
|
||||
"""Download example images for models from Civitai"""
|
||||
return await DownloadManager.start_download(request)
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ws_manager,
|
||||
download_manager: DownloadManager | None = None,
|
||||
processor=ExampleImagesProcessor,
|
||||
file_manager=ExampleImagesFileManager,
|
||||
) -> None:
|
||||
if ws_manager is None:
|
||||
raise ValueError("ws_manager is required")
|
||||
self._download_manager = download_manager or get_default_download_manager(ws_manager)
|
||||
self._processor = processor
|
||||
self._file_manager = file_manager
|
||||
self._handler_set: ExampleImagesHandlerSet | None = None
|
||||
self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None
|
||||
|
||||
@staticmethod
|
||||
async def get_example_images_status(request):
|
||||
"""Get the current status of example images download"""
|
||||
return await DownloadManager.get_status(request)
|
||||
@classmethod
|
||||
def setup_routes(cls, app: web.Application, *, ws_manager) -> None:
|
||||
"""Register routes on the given aiohttp application using default wiring."""
|
||||
|
||||
@staticmethod
|
||||
async def pause_example_images(request):
|
||||
"""Pause the example images download"""
|
||||
return await DownloadManager.pause_download(request)
|
||||
controller = cls(ws_manager=ws_manager)
|
||||
controller.register(app)
|
||||
|
||||
@staticmethod
|
||||
async def resume_example_images(request):
|
||||
"""Resume the example images download"""
|
||||
return await DownloadManager.resume_download(request)
|
||||
|
||||
@staticmethod
|
||||
async def open_example_images_folder(request):
|
||||
"""Open the example images folder for a specific model"""
|
||||
return await ExampleImagesFileManager.open_folder(request)
|
||||
def register(self, app: web.Application) -> None:
|
||||
"""Bind the controller's handlers to the aiohttp router."""
|
||||
|
||||
@staticmethod
|
||||
async def get_example_image_files(request):
|
||||
"""Get list of example image files for a specific model"""
|
||||
return await ExampleImagesFileManager.get_files(request)
|
||||
registrar = ExampleImagesRouteRegistrar(app)
|
||||
registrar.register_routes(self.to_route_mapping())
|
||||
|
||||
@staticmethod
|
||||
async def import_example_images(request):
|
||||
"""Import local example images for a model"""
|
||||
return await ExampleImagesProcessor.import_images(request)
|
||||
|
||||
@staticmethod
|
||||
async def has_example_images(request):
|
||||
"""Check if example images folder exists and is not empty for a model"""
|
||||
return await ExampleImagesFileManager.has_images(request)
|
||||
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||
"""Return the registrar-compatible mapping of handler names to callables."""
|
||||
|
||||
@staticmethod
|
||||
async def delete_example_image(request):
|
||||
"""Delete a custom example image for a model"""
|
||||
return await ExampleImagesProcessor.delete_custom_image(request)
|
||||
if self._handler_mapping is None:
|
||||
handler_set = self._build_handler_set()
|
||||
self._handler_set = handler_set
|
||||
self._handler_mapping = handler_set.to_route_mapping()
|
||||
return self._handler_mapping
|
||||
|
||||
@staticmethod
|
||||
async def force_download_example_images(request):
|
||||
"""Force download example images for specific models"""
|
||||
return await DownloadManager.start_force_download(request)
|
||||
def _build_handler_set(self) -> ExampleImagesHandlerSet:
|
||||
logger.debug("Building ExampleImagesHandlerSet with %s, %s, %s", self._download_manager, self._processor, self._file_manager)
|
||||
download_use_case = DownloadExampleImagesUseCase(download_manager=self._download_manager)
|
||||
download_handler = ExampleImagesDownloadHandler(download_use_case, self._download_manager)
|
||||
import_use_case = ImportExampleImagesUseCase(processor=self._processor)
|
||||
management_handler = ExampleImagesManagementHandler(import_use_case, self._processor)
|
||||
file_handler = ExampleImagesFileHandler(self._file_manager)
|
||||
return ExampleImagesHandlerSet(
|
||||
download=download_handler,
|
||||
management=management_handler,
|
||||
files=file_handler,
|
||||
)
|
||||
|
||||
147
py/routes/handlers/example_images_handlers.py
Normal file
147
py/routes/handlers/example_images_handlers.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""Handler set for example image routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...services.use_cases.example_images import (
|
||||
DownloadExampleImagesConfigurationError,
|
||||
DownloadExampleImagesInProgressError,
|
||||
DownloadExampleImagesUseCase,
|
||||
ImportExampleImagesUseCase,
|
||||
ImportExampleImagesValidationError,
|
||||
)
|
||||
from ...utils.example_images_download_manager import (
|
||||
DownloadConfigurationError,
|
||||
DownloadInProgressError,
|
||||
DownloadNotRunningError,
|
||||
ExampleImagesDownloadError,
|
||||
)
|
||||
from ...utils.example_images_processor import ExampleImagesImportError
|
||||
|
||||
|
||||
class ExampleImagesDownloadHandler:
|
||||
"""HTTP adapters for download-related example image endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
download_use_case: DownloadExampleImagesUseCase,
|
||||
download_manager,
|
||||
) -> None:
|
||||
self._download_use_case = download_use_case
|
||||
self._download_manager = download_manager
|
||||
|
||||
async def download_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = await self._download_use_case.execute(payload)
|
||||
return web.json_response(result)
|
||||
except DownloadExampleImagesInProgressError as exc:
|
||||
response = {
|
||||
'success': False,
|
||||
'error': str(exc),
|
||||
'status': exc.progress,
|
||||
}
|
||||
return web.json_response(response, status=400)
|
||||
except DownloadExampleImagesConfigurationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesDownloadError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
async def get_example_images_status(self, request: web.Request) -> web.StreamResponse:
|
||||
result = await self._download_manager.get_status(request)
|
||||
return web.json_response(result)
|
||||
|
||||
async def pause_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
result = await self._download_manager.pause_download(request)
|
||||
return web.json_response(result)
|
||||
except DownloadNotRunningError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
|
||||
async def resume_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
result = await self._download_manager.resume_download(request)
|
||||
return web.json_response(result)
|
||||
except DownloadNotRunningError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
|
||||
async def force_download_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = await self._download_manager.start_force_download(payload)
|
||||
return web.json_response(result)
|
||||
except DownloadInProgressError as exc:
|
||||
response = {
|
||||
'success': False,
|
||||
'error': str(exc),
|
||||
'status': exc.progress_snapshot,
|
||||
}
|
||||
return web.json_response(response, status=400)
|
||||
except DownloadConfigurationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesDownloadError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
|
||||
class ExampleImagesManagementHandler:
|
||||
"""HTTP adapters for import/delete endpoints."""
|
||||
|
||||
def __init__(self, import_use_case: ImportExampleImagesUseCase, processor) -> None:
|
||||
self._import_use_case = import_use_case
|
||||
self._processor = processor
|
||||
|
||||
async def import_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
result = await self._import_use_case.execute(request)
|
||||
return web.json_response(result)
|
||||
except ImportExampleImagesValidationError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
except ExampleImagesImportError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
async def delete_example_image(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._processor.delete_custom_image(request)
|
||||
|
||||
|
||||
class ExampleImagesFileHandler:
|
||||
"""HTTP adapters for filesystem-centric endpoints."""
|
||||
|
||||
def __init__(self, file_manager) -> None:
|
||||
self._file_manager = file_manager
|
||||
|
||||
async def open_example_images_folder(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._file_manager.open_folder(request)
|
||||
|
||||
async def get_example_image_files(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._file_manager.get_files(request)
|
||||
|
||||
async def has_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
return await self._file_manager.has_images(request)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ExampleImagesHandlerSet:
|
||||
"""Aggregate of handlers exposed to the registrar."""
|
||||
|
||||
download: ExampleImagesDownloadHandler
|
||||
management: ExampleImagesManagementHandler
|
||||
files: ExampleImagesFileHandler
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]:
|
||||
"""Flatten handler methods into the registrar mapping."""
|
||||
|
||||
return {
|
||||
"download_example_images": self.download.download_example_images,
|
||||
"get_example_images_status": self.download.get_example_images_status,
|
||||
"pause_example_images": self.download.pause_example_images,
|
||||
"resume_example_images": self.download.resume_example_images,
|
||||
"force_download_example_images": self.download.force_download_example_images,
|
||||
"import_example_images": self.management.import_example_images,
|
||||
"delete_example_image": self.management.delete_example_image,
|
||||
"open_example_images_folder": self.files.open_example_images_folder,
|
||||
"get_example_image_files": self.files.get_example_image_files,
|
||||
"has_example_images": self.files.has_example_images,
|
||||
}
|
||||
@@ -4,17 +4,32 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional
|
||||
|
||||
from aiohttp import web
|
||||
import jinja2
|
||||
|
||||
from ...services.model_file_service import ModelFileService, ModelMoveService
|
||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ...services.websocket_manager import WebSocketManager
|
||||
from ...config import config
|
||||
from ...services.download_coordinator import DownloadCoordinator
|
||||
from ...services.metadata_sync_service import MetadataSyncService
|
||||
from ...services.model_file_service import ModelMoveService
|
||||
from ...services.preview_asset_service import PreviewAssetService
|
||||
from ...services.settings_manager import SettingsManager
|
||||
from ...utils.routes_common import ModelRouteUtils
|
||||
from ...services.tag_update_service import TagUpdateService
|
||||
from ...services.use_cases import (
|
||||
AutoOrganizeInProgressError,
|
||||
AutoOrganizeUseCase,
|
||||
BulkMetadataRefreshUseCase,
|
||||
DownloadModelEarlyAccessError,
|
||||
DownloadModelUseCase,
|
||||
DownloadModelValidationError,
|
||||
MetadataRefreshProgressReporter,
|
||||
)
|
||||
from ...services.websocket_manager import WebSocketManager
|
||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ...utils.file_utils import calculate_sha256
|
||||
|
||||
|
||||
class ModelPageView:
|
||||
@@ -168,15 +183,52 @@ class ModelListingHandler:
|
||||
class ModelManagementHandler:
|
||||
"""Handle mutation operations on models."""
|
||||
|
||||
def __init__(self, *, service, logger: logging.Logger) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
service,
|
||||
logger: logging.Logger,
|
||||
metadata_sync: MetadataSyncService,
|
||||
preview_service: PreviewAssetService,
|
||||
tag_update_service: TagUpdateService,
|
||||
lifecycle_service,
|
||||
) -> None:
|
||||
self._service = service
|
||||
self._logger = logger
|
||||
self._metadata_sync = metadata_sync
|
||||
self._preview_service = preview_service
|
||||
self._tag_update_service = tag_update_service
|
||||
self._lifecycle_service = lifecycle_service
|
||||
|
||||
async def delete_model(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_delete_model(request, self._service.scanner)
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get("file_path")
|
||||
if not file_path:
|
||||
return web.Response(text="Model path is required", status=400)
|
||||
|
||||
result = await self._lifecycle_service.delete_model(file_path)
|
||||
return web.json_response(result)
|
||||
except ValueError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error deleting model: %s", exc, exc_info=True)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
|
||||
async def exclude_model(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_exclude_model(request, self._service.scanner)
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get("file_path")
|
||||
if not file_path:
|
||||
return web.Response(text="Model path is required", status=400)
|
||||
|
||||
result = await self._lifecycle_service.exclude_model(file_path)
|
||||
return web.json_response(result)
|
||||
except ValueError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error excluding model: %s", exc, exc_info=True)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
|
||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
@@ -192,7 +244,7 @@ class ModelManagementHandler:
|
||||
if not model_data.get("sha256"):
|
||||
return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400)
|
||||
|
||||
success, error = await ModelRouteUtils.fetch_and_update_model(
|
||||
success, error = await self._metadata_sync.fetch_and_update_model(
|
||||
sha256=model_data["sha256"],
|
||||
file_path=file_path,
|
||||
model_data=model_data,
|
||||
@@ -208,25 +260,221 @@ class ModelManagementHandler:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def relink_civitai(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_relink_civitai(request, self._service.scanner)
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get("file_path")
|
||||
model_id = data.get("model_id")
|
||||
model_version_id = data.get("model_version_id")
|
||||
|
||||
if not file_path or model_id is None:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Both file_path and model_id are required"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||
local_metadata = await self._metadata_sync.load_local_metadata(metadata_path)
|
||||
|
||||
updated_metadata = await self._metadata_sync.relink_metadata(
|
||||
file_path=file_path,
|
||||
metadata=local_metadata,
|
||||
model_id=int(model_id),
|
||||
model_version_id=int(model_version_id) if model_version_id else None,
|
||||
)
|
||||
|
||||
await self._service.scanner.update_single_model_cache(
|
||||
file_path, file_path, updated_metadata
|
||||
)
|
||||
|
||||
message = (
|
||||
f"Model successfully re-linked to Civitai model {model_id}"
|
||||
+ (f" version {model_version_id}" if model_version_id else "")
|
||||
)
|
||||
return web.json_response(
|
||||
{"success": True, "message": message, "hash": updated_metadata.get("sha256", "")}
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def replace_preview(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_replace_preview(request, self._service.scanner)
|
||||
try:
|
||||
reader = await request.multipart()
|
||||
|
||||
field = await reader.next()
|
||||
if field is None or field.name != "preview_file":
|
||||
raise ValueError("Expected 'preview_file' field")
|
||||
content_type = field.headers.get("Content-Type", "image/png")
|
||||
content_disposition = field.headers.get("Content-Disposition", "")
|
||||
|
||||
original_filename = None
|
||||
import re
|
||||
|
||||
match = re.search(r'filename="(.*?)"', content_disposition)
|
||||
if match:
|
||||
original_filename = match.group(1)
|
||||
|
||||
preview_data = await field.read()
|
||||
|
||||
field = await reader.next()
|
||||
if field is None or field.name != "model_path":
|
||||
raise ValueError("Expected 'model_path' field")
|
||||
model_path = (await field.read()).decode()
|
||||
|
||||
nsfw_level = 0
|
||||
field = await reader.next()
|
||||
if field and field.name == "nsfw_level":
|
||||
try:
|
||||
nsfw_level = int((await field.read()).decode())
|
||||
except (ValueError, TypeError):
|
||||
self._logger.warning("Invalid NSFW level format, using default 0")
|
||||
|
||||
result = await self._preview_service.replace_preview(
|
||||
model_path=model_path,
|
||||
preview_data=preview_data,
|
||||
content_type=content_type,
|
||||
original_filename=original_filename,
|
||||
nsfw_level=nsfw_level,
|
||||
update_preview_in_cache=self._service.scanner.update_preview_in_cache,
|
||||
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
"success": True,
|
||||
"preview_url": config.get_preview_static_url(result["preview_path"]),
|
||||
"preview_nsfw_level": result["preview_nsfw_level"],
|
||||
}
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error replacing preview: %s", exc, exc_info=True)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
|
||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_save_metadata(request, self._service.scanner)
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get("file_path")
|
||||
if not file_path:
|
||||
return web.Response(text="File path is required", status=400)
|
||||
|
||||
metadata_updates = {k: v for k, v in data.items() if k != "file_path"}
|
||||
|
||||
await self._metadata_sync.save_metadata_updates(
|
||||
file_path=file_path,
|
||||
updates=metadata_updates,
|
||||
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||
update_cache=self._service.scanner.update_single_model_cache,
|
||||
)
|
||||
|
||||
if "model_name" in metadata_updates:
|
||||
cache = await self._service.scanner.get_cached_data()
|
||||
await cache.resort()
|
||||
|
||||
return web.json_response({"success": True})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error saving metadata: %s", exc, exc_info=True)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
|
||||
async def add_tags(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_add_tags(request, self._service.scanner)
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get("file_path")
|
||||
new_tags = data.get("tags", [])
|
||||
|
||||
if not file_path:
|
||||
return web.Response(text="File path is required", status=400)
|
||||
|
||||
if not isinstance(new_tags, list):
|
||||
return web.Response(text="Tags must be a list", status=400)
|
||||
|
||||
tags = await self._tag_update_service.add_tags(
|
||||
file_path=file_path,
|
||||
new_tags=new_tags,
|
||||
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||
update_cache=self._service.scanner.update_single_model_cache,
|
||||
)
|
||||
|
||||
return web.json_response({"success": True, "tags": tags})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error adding tags: %s", exc, exc_info=True)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
|
||||
async def rename_model(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_rename_model(request, self._service.scanner)
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get("file_path")
|
||||
new_file_name = data.get("new_file_name")
|
||||
|
||||
if not file_path or not new_file_name:
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "File path and new file name are required",
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
|
||||
result = await self._lifecycle_service.rename_model(
|
||||
file_path=file_path, new_file_name=new_file_name
|
||||
)
|
||||
|
||||
return web.json_response(
|
||||
{
|
||||
**result,
|
||||
"new_preview_path": config.get_preview_static_url(
|
||||
result.get("new_preview_path")
|
||||
),
|
||||
}
|
||||
)
|
||||
except ValueError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error renaming model: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def bulk_delete_models(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_bulk_delete_models(request, self._service.scanner)
|
||||
try:
|
||||
data = await request.json()
|
||||
file_paths = data.get("file_paths", [])
|
||||
if not file_paths:
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "No file paths provided for deletion",
|
||||
},
|
||||
status=400,
|
||||
)
|
||||
|
||||
result = await self._lifecycle_service.bulk_delete_models(file_paths)
|
||||
return web.json_response(result)
|
||||
except ValueError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error in bulk delete: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def verify_duplicates(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_verify_duplicates(request, self._service.scanner)
|
||||
try:
|
||||
data = await request.json()
|
||||
file_paths = data.get("file_paths", [])
|
||||
|
||||
if not file_paths:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "No file paths provided for verification"},
|
||||
status=400,
|
||||
)
|
||||
|
||||
results = await self._metadata_sync.verify_duplicate_hashes(
|
||||
file_paths=file_paths,
|
||||
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||
hash_calculator=calculate_sha256,
|
||||
update_cache=self._service.scanner.update_single_model_cache,
|
||||
)
|
||||
|
||||
return web.json_response({"success": True, **results})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error verifying duplicate models: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class ModelQueryHandler:
|
||||
@@ -429,12 +677,35 @@ class ModelQueryHandler:
|
||||
class ModelDownloadHandler:
|
||||
"""Coordinate downloads and progress reporting."""
|
||||
|
||||
def __init__(self, *, ws_manager: WebSocketManager, logger: logging.Logger) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ws_manager: WebSocketManager,
|
||||
logger: logging.Logger,
|
||||
download_use_case: DownloadModelUseCase,
|
||||
download_coordinator: DownloadCoordinator,
|
||||
) -> None:
|
||||
self._ws_manager = ws_manager
|
||||
self._logger = logger
|
||||
self._download_use_case = download_use_case
|
||||
self._download_coordinator = download_coordinator
|
||||
|
||||
async def download_model(self, request: web.Request) -> web.Response:
|
||||
return await ModelRouteUtils.handle_download_model(request)
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = await self._download_use_case.execute(payload)
|
||||
if not result.get("success", False):
|
||||
return web.json_response(result, status=500)
|
||||
return web.json_response(result)
|
||||
except DownloadModelValidationError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except DownloadModelEarlyAccessError as exc:
|
||||
self._logger.warning("Early access error: %s", exc)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=401)
|
||||
except Exception as exc:
|
||||
error_message = str(exc)
|
||||
self._logger.error("Error downloading model: %s", error_message, exc_info=True)
|
||||
return web.json_response({"success": False, "error": error_message}, status=500)
|
||||
|
||||
async def download_model_get(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
@@ -460,7 +731,15 @@ class ModelDownloadHandler:
|
||||
future.set_result(data)
|
||||
|
||||
mock_request = type("MockRequest", (), {"json": lambda self=None: future})()
|
||||
return await ModelRouteUtils.handle_download_model(mock_request)
|
||||
result = await self._download_use_case.execute(data)
|
||||
if not result.get("success", False):
|
||||
return web.json_response(result, status=500)
|
||||
return web.json_response(result)
|
||||
except DownloadModelValidationError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except DownloadModelEarlyAccessError as exc:
|
||||
self._logger.warning("Early access error: %s", exc)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=401)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error downloading model via GET: %s", exc, exc_info=True)
|
||||
return web.Response(status=500, text=str(exc))
|
||||
@@ -470,8 +749,8 @@ class ModelDownloadHandler:
|
||||
download_id = request.query.get("download_id")
|
||||
if not download_id:
|
||||
return web.json_response({"success": False, "error": "Download ID is required"}, status=400)
|
||||
mock_request = type("MockRequest", (), {"match_info": {"download_id": download_id}})()
|
||||
return await ModelRouteUtils.handle_cancel_download(mock_request)
|
||||
result = await self._download_coordinator.cancel_download(download_id)
|
||||
return web.json_response(result)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
@@ -504,6 +783,9 @@ class ModelCivitaiHandler:
|
||||
validate_model_type: Callable[[str], bool],
|
||||
expected_model_types: Callable[[], str],
|
||||
find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]],
|
||||
metadata_sync: MetadataSyncService,
|
||||
metadata_refresh_use_case: BulkMetadataRefreshUseCase,
|
||||
metadata_progress_callback: MetadataRefreshProgressReporter,
|
||||
) -> None:
|
||||
self._service = service
|
||||
self._settings = settings_service
|
||||
@@ -513,75 +795,17 @@ class ModelCivitaiHandler:
|
||||
self._validate_model_type = validate_model_type
|
||||
self._expected_model_types = expected_model_types
|
||||
self._find_model_file = find_model_file
|
||||
self._metadata_sync = metadata_sync
|
||||
self._metadata_refresh_use_case = metadata_refresh_use_case
|
||||
self._metadata_progress_callback = metadata_progress_callback
|
||||
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
cache = await self._service.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False)
|
||||
to_process = [
|
||||
model
|
||||
for model in cache.raw_data
|
||||
if model.get("sha256")
|
||||
and (not model.get("civitai") or not model["civitai"].get("id"))
|
||||
and (
|
||||
(enable_metadata_archive_db and not model.get("db_checked", False))
|
||||
or (not enable_metadata_archive_db and model.get("from_civitai") is True)
|
||||
)
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
await self._ws_manager.broadcast({
|
||||
"status": "started",
|
||||
"total": total_to_process,
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
})
|
||||
|
||||
for model in to_process:
|
||||
try:
|
||||
original_name = model.get("model_name")
|
||||
result, error = await ModelRouteUtils.fetch_and_update_model(
|
||||
sha256=model["sha256"],
|
||||
file_path=model["file_path"],
|
||||
model_data=model,
|
||||
update_cache_func=self._service.scanner.update_single_model_cache,
|
||||
)
|
||||
if result:
|
||||
success += 1
|
||||
if original_name != model.get("model_name"):
|
||||
needs_resort = True
|
||||
processed += 1
|
||||
await self._ws_manager.broadcast({
|
||||
"status": "processing",
|
||||
"total": total_to_process,
|
||||
"processed": processed,
|
||||
"success": success,
|
||||
"current_name": model.get("model_name", "Unknown"),
|
||||
})
|
||||
except Exception as exc: # pragma: no cover - logging path
|
||||
self._logger.error("Error fetching CivitAI data for %s: %s", model["file_path"], exc)
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort()
|
||||
|
||||
await self._ws_manager.broadcast({
|
||||
"status": "completed",
|
||||
"total": total_to_process,
|
||||
"processed": processed,
|
||||
"success": success,
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Successfully updated {success} of {processed} processed {self._service.model_type}s (total: {total})",
|
||||
})
|
||||
result = await self._metadata_refresh_use_case.execute_with_error_handling(
|
||||
progress_callback=self._metadata_progress_callback
|
||||
)
|
||||
return web.json_response(result)
|
||||
except Exception as exc:
|
||||
await self._ws_manager.broadcast({"status": "error", "error": str(exc)})
|
||||
self._logger.error("Error in fetch_all_civitai for %ss: %s", self._service.model_type, exc)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
|
||||
@@ -687,31 +911,18 @@ class ModelAutoOrganizeHandler:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
file_service: ModelFileService,
|
||||
use_case: AutoOrganizeUseCase,
|
||||
progress_callback: WebSocketProgressCallback,
|
||||
ws_manager: WebSocketManager,
|
||||
logger: logging.Logger,
|
||||
) -> None:
|
||||
self._file_service = file_service
|
||||
self._use_case = use_case
|
||||
self._progress_callback = progress_callback
|
||||
self._ws_manager = ws_manager
|
||||
self._logger = logger
|
||||
|
||||
async def auto_organize_models(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
if self._ws_manager.is_auto_organize_running():
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
|
||||
status=409,
|
||||
)
|
||||
|
||||
auto_organize_lock = await self._ws_manager.get_auto_organize_lock()
|
||||
if auto_organize_lock.locked():
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
|
||||
status=409,
|
||||
)
|
||||
|
||||
file_paths = None
|
||||
if request.method == "POST":
|
||||
try:
|
||||
@@ -720,17 +931,24 @@ class ModelAutoOrganizeHandler:
|
||||
except Exception: # pragma: no cover - permissive path
|
||||
pass
|
||||
|
||||
async with auto_organize_lock:
|
||||
result = await self._file_service.auto_organize_models(
|
||||
file_paths=file_paths,
|
||||
progress_callback=self._progress_callback,
|
||||
)
|
||||
return web.json_response(result.to_dict())
|
||||
result = await self._use_case.execute(
|
||||
file_paths=file_paths,
|
||||
progress_callback=self._progress_callback,
|
||||
)
|
||||
return web.json_response(result.to_dict())
|
||||
except AutoOrganizeInProgressError:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
|
||||
status=409,
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error in auto_organize_models: %s", exc, exc_info=True)
|
||||
await self._ws_manager.broadcast_auto_organize_progress(
|
||||
{"type": "auto_organize_progress", "status": "error", "error": str(exc)}
|
||||
)
|
||||
try:
|
||||
await self._progress_callback.on_progress(
|
||||
{"type": "auto_organize_progress", "status": "error", "error": str(exc)}
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive reporting
|
||||
pass
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_auto_organize_progress(self, request: web.Request) -> web.Response:
|
||||
|
||||
725
py/routes/handlers/recipe_handlers.py
Normal file
725
py/routes/handlers/recipe_handlers.py
Normal file
@@ -0,0 +1,725 @@
|
||||
"""Dedicated handler objects for recipe-related routes."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ...config import config
|
||||
from ...services.server_i18n import server_i18n as default_server_i18n
|
||||
from ...services.settings_manager import SettingsManager
|
||||
from ...services.recipes import (
|
||||
RecipeAnalysisService,
|
||||
RecipeDownloadError,
|
||||
RecipeNotFoundError,
|
||||
RecipePersistenceService,
|
||||
RecipeSharingService,
|
||||
RecipeValidationError,
|
||||
)
|
||||
|
||||
Logger = logging.Logger
|
||||
EnsureDependenciesCallable = Callable[[], Awaitable[None]]
|
||||
RecipeScannerGetter = Callable[[], Any]
|
||||
CivitaiClientGetter = Callable[[], Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RecipeHandlerSet:
|
||||
"""Group of handlers providing recipe route implementations."""
|
||||
|
||||
page_view: "RecipePageView"
|
||||
listing: "RecipeListingHandler"
|
||||
query: "RecipeQueryHandler"
|
||||
management: "RecipeManagementHandler"
|
||||
analysis: "RecipeAnalysisHandler"
|
||||
sharing: "RecipeSharingHandler"
|
||||
|
||||
def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]:
|
||||
"""Expose handler coroutines keyed by registrar handler names."""
|
||||
|
||||
return {
|
||||
"render_page": self.page_view.render_page,
|
||||
"list_recipes": self.listing.list_recipes,
|
||||
"get_recipe": self.listing.get_recipe,
|
||||
"analyze_uploaded_image": self.analysis.analyze_uploaded_image,
|
||||
"analyze_local_image": self.analysis.analyze_local_image,
|
||||
"save_recipe": self.management.save_recipe,
|
||||
"delete_recipe": self.management.delete_recipe,
|
||||
"get_top_tags": self.query.get_top_tags,
|
||||
"get_base_models": self.query.get_base_models,
|
||||
"share_recipe": self.sharing.share_recipe,
|
||||
"download_shared_recipe": self.sharing.download_shared_recipe,
|
||||
"get_recipe_syntax": self.query.get_recipe_syntax,
|
||||
"update_recipe": self.management.update_recipe,
|
||||
"reconnect_lora": self.management.reconnect_lora,
|
||||
"find_duplicates": self.query.find_duplicates,
|
||||
"bulk_delete": self.management.bulk_delete,
|
||||
"save_recipe_from_widget": self.management.save_recipe_from_widget,
|
||||
"get_recipes_for_lora": self.query.get_recipes_for_lora,
|
||||
"scan_recipes": self.query.scan_recipes,
|
||||
}
|
||||
|
||||
|
||||
class RecipePageView:
|
||||
"""Render the recipe shell page."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
settings_service: SettingsManager,
|
||||
server_i18n=default_server_i18n,
|
||||
template_env,
|
||||
template_name: str,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._settings = settings_service
|
||||
self._server_i18n = server_i18n
|
||||
self._template_env = template_env
|
||||
self._template_name = template_name
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
|
||||
async def render_page(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None: # pragma: no cover - defensive guard
|
||||
raise RuntimeError("Recipe scanner not available")
|
||||
|
||||
user_language = self._settings.get("language", "en")
|
||||
self._server_i18n.set_locale(user_language)
|
||||
|
||||
try:
|
||||
await recipe_scanner.get_cached_data(force_refresh=False)
|
||||
rendered = self._template_env.get_template(self._template_name).render(
|
||||
recipes=[],
|
||||
is_initializing=False,
|
||||
settings=self._settings,
|
||||
request=request,
|
||||
t=self._server_i18n.get_translation,
|
||||
)
|
||||
except Exception as cache_error: # pragma: no cover - logging path
|
||||
self._logger.error("Error loading recipe cache data: %s", cache_error)
|
||||
rendered = self._template_env.get_template(self._template_name).render(
|
||||
is_initializing=True,
|
||||
settings=self._settings,
|
||||
request=request,
|
||||
t=self._server_i18n.get_translation,
|
||||
)
|
||||
return web.Response(text=rendered, content_type="text/html")
|
||||
except Exception as exc: # pragma: no cover - logging path
|
||||
self._logger.error("Error handling recipes request: %s", exc, exc_info=True)
|
||||
return web.Response(text="Error loading recipes page", status=500)
|
||||
|
||||
|
||||
class RecipeListingHandler:
|
||||
"""Provide listing and detail APIs for recipes."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
|
||||
async def list_recipes(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
page = int(request.query.get("page", "1"))
|
||||
page_size = int(request.query.get("page_size", "20"))
|
||||
sort_by = request.query.get("sort_by", "date")
|
||||
search = request.query.get("search")
|
||||
|
||||
search_options = {
|
||||
"title": request.query.get("search_title", "true").lower() == "true",
|
||||
"tags": request.query.get("search_tags", "true").lower() == "true",
|
||||
"lora_name": request.query.get("search_lora_name", "true").lower() == "true",
|
||||
"lora_model": request.query.get("search_lora_model", "true").lower() == "true",
|
||||
}
|
||||
|
||||
filters: Dict[str, list[str]] = {}
|
||||
base_models = request.query.get("base_models")
|
||||
if base_models:
|
||||
filters["base_model"] = base_models.split(",")
|
||||
|
||||
tags = request.query.get("tags")
|
||||
if tags:
|
||||
filters["tags"] = tags.split(",")
|
||||
|
||||
lora_hash = request.query.get("lora_hash")
|
||||
|
||||
result = await recipe_scanner.get_paginated_data(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort_by=sort_by,
|
||||
search=search,
|
||||
filters=filters,
|
||||
search_options=search_options,
|
||||
lora_hash=lora_hash,
|
||||
)
|
||||
|
||||
for item in result.get("items", []):
|
||||
file_path = item.get("file_path")
|
||||
if file_path:
|
||||
item["file_url"] = self.format_recipe_file_url(file_path)
|
||||
else:
|
||||
item.setdefault("file_url", "/loras_static/images/no-preview.png")
|
||||
item.setdefault("loras", [])
|
||||
item.setdefault("base_model", "")
|
||||
|
||||
return web.json_response(result)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving recipes: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||
|
||||
if not recipe:
|
||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||
return web.json_response(recipe)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving recipe details: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
def format_recipe_file_url(self, file_path: str) -> str:
|
||||
try:
|
||||
recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, "/")
|
||||
normalized_path = file_path.replace(os.sep, "/")
|
||||
if normalized_path.startswith(recipes_dir):
|
||||
relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, "/")
|
||||
return f"/loras_static/root1/preview/{relative_path}"
|
||||
|
||||
file_name = os.path.basename(file_path)
|
||||
return f"/loras_static/root1/preview/recipes/{file_name}"
|
||||
except Exception as exc: # pragma: no cover - logging path
|
||||
self._logger.error("Error formatting recipe file URL: %s", exc, exc_info=True)
|
||||
return "/loras_static/images/no-preview.png"
|
||||
|
||||
|
||||
class RecipeQueryHandler:
|
||||
"""Provide read-only insights on recipe data."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
format_recipe_file_url: Callable[[str], str],
|
||||
logger: Logger,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._format_recipe_file_url = format_recipe_file_url
|
||||
self._logger = logger
|
||||
|
||||
async def get_top_tags(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
limit = int(request.query.get("limit", "20"))
|
||||
cache = await recipe_scanner.get_cached_data()
|
||||
|
||||
tag_counts: Dict[str, int] = {}
|
||||
for recipe in getattr(cache, "raw_data", []):
|
||||
for tag in recipe.get("tags", []) or []:
|
||||
tag_counts[tag] = tag_counts.get(tag, 0) + 1
|
||||
|
||||
sorted_tags = [{"tag": tag, "count": count} for tag, count in tag_counts.items()]
|
||||
sorted_tags.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "tags": sorted_tags[:limit]})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving top tags: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_base_models(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
cache = await recipe_scanner.get_cached_data()
|
||||
|
||||
base_model_counts: Dict[str, int] = {}
|
||||
for recipe in getattr(cache, "raw_data", []):
|
||||
base_model = recipe.get("base_model")
|
||||
if base_model:
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
|
||||
sorted_models = [{"name": model, "count": count} for model, count in base_model_counts.items()]
|
||||
sorted_models.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "base_models": sorted_models})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error retrieving base models: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipes_for_lora(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
lora_hash = request.query.get("hash")
|
||||
if not lora_hash:
|
||||
return web.json_response({"success": False, "error": "Lora hash is required"}, status=400)
|
||||
|
||||
matching_recipes = await recipe_scanner.get_recipes_for_lora(lora_hash)
|
||||
return web.json_response({"success": True, "recipes": matching_recipes})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error getting recipes for Lora: %s", exc)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def scan_recipes(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
self._logger.info("Manually triggering recipe cache rebuild")
|
||||
await recipe_scanner.get_cached_data(force_refresh=True)
|
||||
return web.json_response({"success": True, "message": "Recipe cache refreshed successfully"})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error refreshing recipe cache: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def find_duplicates(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
duplicate_groups = await recipe_scanner.find_all_duplicate_recipes()
|
||||
response_data = []
|
||||
|
||||
for fingerprint, recipe_ids in duplicate_groups.items():
|
||||
if len(recipe_ids) <= 1:
|
||||
continue
|
||||
|
||||
recipes = []
|
||||
for recipe_id in recipe_ids:
|
||||
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||
if recipe:
|
||||
recipes.append(
|
||||
{
|
||||
"id": recipe.get("id"),
|
||||
"title": recipe.get("title"),
|
||||
"file_url": recipe.get("file_url")
|
||||
or self._format_recipe_file_url(recipe.get("file_path", "")),
|
||||
"modified": recipe.get("modified"),
|
||||
"created_date": recipe.get("created_date"),
|
||||
"lora_count": len(recipe.get("loras", [])),
|
||||
}
|
||||
)
|
||||
|
||||
if len(recipes) >= 2:
|
||||
recipes.sort(key=lambda entry: entry.get("modified", 0), reverse=True)
|
||||
response_data.append(
|
||||
{
|
||||
"fingerprint": fingerprint,
|
||||
"count": len(recipes),
|
||||
"recipes": recipes,
|
||||
}
|
||||
)
|
||||
|
||||
response_data.sort(key=lambda entry: entry["count"], reverse=True)
|
||||
return web.json_response({"success": True, "duplicate_groups": response_data})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error finding duplicate recipes: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_recipe_syntax(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
try:
|
||||
syntax_parts = await recipe_scanner.get_recipe_syntax_tokens(recipe_id)
|
||||
except RecipeNotFoundError:
|
||||
return web.json_response({"error": "Recipe not found"}, status=404)
|
||||
|
||||
if not syntax_parts:
|
||||
return web.json_response({"error": "No LoRAs found in this recipe"}, status=400)
|
||||
|
||||
return web.json_response({"success": True, "syntax": " ".join(syntax_parts)})
|
||||
except Exception as exc:
|
||||
self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class RecipeManagementHandler:
|
||||
"""Handle create/update/delete style recipe operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
persistence_service: RecipePersistenceService,
|
||||
analysis_service: RecipeAnalysisService,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
self._persistence_service = persistence_service
|
||||
self._analysis_service = analysis_service
|
||||
|
||||
async def save_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
reader = await request.multipart()
|
||||
payload = await self._parse_save_payload(reader)
|
||||
|
||||
result = await self._persistence_service.save_recipe(
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_bytes=payload["image_bytes"],
|
||||
image_base64=payload["image_base64"],
|
||||
name=payload["name"],
|
||||
tags=payload["tags"],
|
||||
metadata=payload["metadata"],
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error saving recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def delete_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
result = await self._persistence_service.delete_recipe(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error deleting recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def update_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
data = await request.json()
|
||||
result = await self._persistence_service.update_recipe(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id, updates=data
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error updating recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def reconnect_lora(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
data = await request.json()
|
||||
for field in ("recipe_id", "lora_index", "target_name"):
|
||||
if field not in data:
|
||||
raise RecipeValidationError(f"Missing required field: {field}")
|
||||
|
||||
result = await self._persistence_service.reconnect_lora(
|
||||
recipe_scanner=recipe_scanner,
|
||||
recipe_id=data["recipe_id"],
|
||||
lora_index=int(data["lora_index"]),
|
||||
target_name=data["target_name"],
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error reconnecting LoRA: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def bulk_delete(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
data = await request.json()
|
||||
recipe_ids = data.get("recipe_ids", [])
|
||||
result = await self._persistence_service.bulk_delete(
|
||||
recipe_scanner=recipe_scanner, recipe_ids=recipe_ids
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error performing bulk delete: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def save_recipe_from_widget(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
analysis = await self._analysis_service.analyze_widget_metadata(
|
||||
recipe_scanner=recipe_scanner
|
||||
)
|
||||
metadata = analysis.payload.get("metadata")
|
||||
image_bytes = analysis.payload.get("image_bytes")
|
||||
if not metadata or image_bytes is None:
|
||||
raise RecipeValidationError("Unable to extract metadata from widget")
|
||||
|
||||
result = await self._persistence_service.save_recipe_from_widget(
|
||||
recipe_scanner=recipe_scanner,
|
||||
metadata=metadata,
|
||||
image_bytes=image_bytes,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=400)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error saving recipe from widget: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def _parse_save_payload(self, reader) -> dict[str, Any]:
|
||||
image_bytes: Optional[bytes] = None
|
||||
image_base64: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
tags: list[str] = []
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
if field.name == "image":
|
||||
image_chunks = bytearray()
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
image_chunks.extend(chunk)
|
||||
image_bytes = bytes(image_chunks)
|
||||
elif field.name == "image_base64":
|
||||
image_base64 = await field.text()
|
||||
elif field.name == "name":
|
||||
name = await field.text()
|
||||
elif field.name == "tags":
|
||||
tags_text = await field.text()
|
||||
try:
|
||||
parsed_tags = json.loads(tags_text)
|
||||
tags = parsed_tags if isinstance(parsed_tags, list) else []
|
||||
except Exception:
|
||||
tags = []
|
||||
elif field.name == "metadata":
|
||||
metadata_text = await field.text()
|
||||
try:
|
||||
metadata = json.loads(metadata_text)
|
||||
except Exception:
|
||||
metadata = {}
|
||||
|
||||
return {
|
||||
"image_bytes": image_bytes,
|
||||
"image_base64": image_base64,
|
||||
"name": name,
|
||||
"tags": tags,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
|
||||
class RecipeAnalysisHandler:
|
||||
"""Analyze images to extract recipe metadata."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
civitai_client_getter: CivitaiClientGetter,
|
||||
logger: Logger,
|
||||
analysis_service: RecipeAnalysisService,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._civitai_client_getter = civitai_client_getter
|
||||
self._logger = logger
|
||||
self._analysis_service = analysis_service
|
||||
|
||||
async def analyze_uploaded_image(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
civitai_client = self._civitai_client_getter()
|
||||
if recipe_scanner is None or civitai_client is None:
|
||||
raise RuntimeError("Required services unavailable")
|
||||
|
||||
content_type = request.headers.get("Content-Type", "")
|
||||
if "multipart/form-data" in content_type:
|
||||
reader = await request.multipart()
|
||||
field = await reader.next()
|
||||
if field is None or field.name != "image":
|
||||
raise RecipeValidationError("No image field found")
|
||||
image_chunks = bytearray()
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
image_chunks.extend(chunk)
|
||||
result = await self._analysis_service.analyze_uploaded_image(
|
||||
image_bytes=bytes(image_chunks),
|
||||
recipe_scanner=recipe_scanner,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
|
||||
if "application/json" in content_type:
|
||||
data = await request.json()
|
||||
result = await self._analysis_service.analyze_remote_image(
|
||||
url=data.get("url"),
|
||||
recipe_scanner=recipe_scanner,
|
||||
civitai_client=civitai_client,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
|
||||
raise RecipeValidationError("Unsupported content type")
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=400)
|
||||
except RecipeDownloadError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error analyzing recipe image: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=500)
|
||||
|
||||
async def analyze_local_image(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
data = await request.json()
|
||||
result = await self._analysis_service.analyze_local_image(
|
||||
file_path=data.get("path"),
|
||||
recipe_scanner=recipe_scanner,
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeValidationError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=400)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error analyzing local image: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc), "loras": []}, status=500)
|
||||
|
||||
|
||||
class RecipeSharingHandler:
|
||||
"""Serve endpoints related to recipe sharing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ensure_dependencies_ready: EnsureDependenciesCallable,
|
||||
recipe_scanner_getter: RecipeScannerGetter,
|
||||
logger: Logger,
|
||||
sharing_service: RecipeSharingService,
|
||||
) -> None:
|
||||
self._ensure_dependencies_ready = ensure_dependencies_ready
|
||||
self._recipe_scanner_getter = recipe_scanner_getter
|
||||
self._logger = logger
|
||||
self._sharing_service = sharing_service
|
||||
|
||||
async def share_recipe(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
result = await self._sharing_service.share_recipe(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id
|
||||
)
|
||||
return web.json_response(result.payload, status=result.status)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error sharing recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
|
||||
async def download_shared_recipe(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
await self._ensure_dependencies_ready()
|
||||
recipe_scanner = self._recipe_scanner_getter()
|
||||
if recipe_scanner is None:
|
||||
raise RuntimeError("Recipe scanner unavailable")
|
||||
|
||||
recipe_id = request.match_info["recipe_id"]
|
||||
download_info = await self._sharing_service.prepare_download(
|
||||
recipe_scanner=recipe_scanner, recipe_id=recipe_id
|
||||
)
|
||||
return web.FileResponse(
|
||||
download_info.file_path,
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{download_info.download_filename}"'
|
||||
},
|
||||
)
|
||||
except RecipeNotFoundError as exc:
|
||||
return web.json_response({"error": str(exc)}, status=404)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error downloading shared recipe: %s", exc, exc_info=True)
|
||||
return web.json_response({"error": str(exc)}, status=500)
|
||||
64
py/routes/recipe_route_registrar.py
Normal file
64
py/routes/recipe_route_registrar.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Route registrar for recipe endpoints."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Mapping
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RouteDefinition:
|
||||
"""Declarative definition for a recipe HTTP route."""
|
||||
|
||||
method: str
|
||||
path: str
|
||||
handler_name: str
|
||||
|
||||
|
||||
ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/loras/recipes", "render_page"),
|
||||
RouteDefinition("GET", "/api/lm/recipes", "list_recipes"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"),
|
||||
RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/base-models", "get_base_models"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share", "share_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"),
|
||||
RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"),
|
||||
RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"),
|
||||
RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"),
|
||||
RouteDefinition("POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"),
|
||||
RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"),
|
||||
)
|
||||
|
||||
|
||||
class RecipeRouteRegistrar:
|
||||
"""Bind declarative recipe definitions to an aiohttp router."""
|
||||
|
||||
_METHOD_MAP = {
|
||||
"GET": "add_get",
|
||||
"POST": "add_post",
|
||||
"PUT": "add_put",
|
||||
"DELETE": "add_delete",
|
||||
}
|
||||
|
||||
def __init__(self, app: web.Application) -> None:
|
||||
self._app = app
|
||||
|
||||
def register_routes(self, handler_lookup: Mapping[str, Callable[[web.Request], object]]) -> None:
|
||||
for definition in ROUTE_DEFINITIONS:
|
||||
handler = handler_lookup[definition.handler_name]
|
||||
self._bind_route(definition.method, definition.path, handler)
|
||||
|
||||
def _bind_route(self, method: str, path: str, handler: Callable) -> None:
|
||||
add_method_name = self._METHOD_MAP[method.upper()]
|
||||
add_method = getattr(self._app.router, add_method_name)
|
||||
add_method(path, handler)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,99 +4,88 @@ import logging
|
||||
import os
|
||||
|
||||
from ..utils.models import BaseModelMetadata
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from .settings_manager import settings
|
||||
from ..utils.utils import fuzzy_match
|
||||
from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider
|
||||
from .settings_manager import settings as default_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseModelService(ABC):
|
||||
"""Base service class for all model types"""
|
||||
|
||||
def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]):
|
||||
"""Initialize the service
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_type: str,
|
||||
scanner,
|
||||
metadata_class: Type[BaseModelMetadata],
|
||||
*,
|
||||
cache_repository: Optional[ModelCacheRepository] = None,
|
||||
filter_set: Optional[ModelFilterSet] = None,
|
||||
search_strategy: Optional[SearchStrategy] = None,
|
||||
settings_provider: Optional[SettingsProvider] = None,
|
||||
):
|
||||
"""Initialize the service.
|
||||
|
||||
Args:
|
||||
model_type: Type of model (lora, checkpoint, etc.)
|
||||
scanner: Model scanner instance
|
||||
metadata_class: Metadata class for this model type
|
||||
model_type: Type of model (lora, checkpoint, etc.).
|
||||
scanner: Model scanner instance.
|
||||
metadata_class: Metadata class for this model type.
|
||||
cache_repository: Custom repository for cache access (primarily for tests).
|
||||
filter_set: Filter component controlling folder/tag/favorites logic.
|
||||
search_strategy: Search component for fuzzy/text matching.
|
||||
settings_provider: Settings object; defaults to the global settings manager.
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.scanner = scanner
|
||||
self.metadata_class = metadata_class
|
||||
self.settings = settings_provider or default_settings
|
||||
self.cache_repository = cache_repository or ModelCacheRepository(scanner)
|
||||
self.filter_set = filter_set or ModelFilterSet(self.settings)
|
||||
self.search_strategy = search_strategy or SearchStrategy()
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name',
|
||||
folder: str = None, search: str = None, fuzzy_search: bool = False,
|
||||
base_models: list = None, tags: list = None,
|
||||
search_options: dict = None, hash_filters: dict = None,
|
||||
favorites_only: bool = False, **kwargs) -> Dict:
|
||||
"""Get paginated and filtered model data
|
||||
|
||||
Args:
|
||||
page: Page number (1-based)
|
||||
page_size: Number of items per page
|
||||
sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc'
|
||||
folder: Folder filter
|
||||
search: Search term
|
||||
fuzzy_search: Whether to use fuzzy search
|
||||
base_models: List of base models to filter by
|
||||
tags: List of tags to filter by
|
||||
search_options: Search options dict
|
||||
hash_filters: Hash filtering options
|
||||
favorites_only: Filter for favorites only
|
||||
**kwargs: Additional model-specific filters
|
||||
|
||||
Returns:
|
||||
Dict containing paginated results
|
||||
"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
async def get_paginated_data(
|
||||
self,
|
||||
page: int,
|
||||
page_size: int,
|
||||
sort_by: str = 'name',
|
||||
folder: str = None,
|
||||
search: str = None,
|
||||
fuzzy_search: bool = False,
|
||||
base_models: list = None,
|
||||
tags: list = None,
|
||||
search_options: dict = None,
|
||||
hash_filters: dict = None,
|
||||
favorites_only: bool = False,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
"""Get paginated and filtered model data"""
|
||||
sort_params = self.cache_repository.parse_sort(sort_by)
|
||||
sorted_data = await self.cache_repository.fetch_sorted(sort_params)
|
||||
|
||||
# Parse sort_by into sort_key and order
|
||||
if ':' in sort_by:
|
||||
sort_key, order = sort_by.split(':', 1)
|
||||
sort_key = sort_key.strip()
|
||||
order = order.strip().lower()
|
||||
if order not in ('asc', 'desc'):
|
||||
order = 'asc'
|
||||
else:
|
||||
sort_key = sort_by.strip()
|
||||
order = 'asc'
|
||||
|
||||
# Get default search options if not provided
|
||||
if search_options is None:
|
||||
search_options = {
|
||||
'filename': True,
|
||||
'modelname': True,
|
||||
'tags': False,
|
||||
'recursive': True,
|
||||
}
|
||||
|
||||
# Get the base data set using new sort logic
|
||||
filtered_data = await cache.get_sorted_data(sort_key, order)
|
||||
|
||||
# Apply hash filtering if provided (highest priority)
|
||||
if hash_filters:
|
||||
filtered_data = await self._apply_hash_filters(filtered_data, hash_filters)
|
||||
|
||||
# Jump to pagination for hash filters
|
||||
filtered_data = await self._apply_hash_filters(sorted_data, hash_filters)
|
||||
return self._paginate(filtered_data, page, page_size)
|
||||
|
||||
# Apply common filters
|
||||
|
||||
filtered_data = await self._apply_common_filters(
|
||||
filtered_data, folder, base_models, tags, favorites_only, search_options
|
||||
sorted_data,
|
||||
folder=folder,
|
||||
base_models=base_models,
|
||||
tags=tags,
|
||||
favorites_only=favorites_only,
|
||||
search_options=search_options,
|
||||
)
|
||||
|
||||
# Apply search filtering
|
||||
|
||||
if search:
|
||||
filtered_data = await self._apply_search_filters(
|
||||
filtered_data, search, fuzzy_search, search_options
|
||||
filtered_data,
|
||||
search,
|
||||
fuzzy_search,
|
||||
search_options,
|
||||
)
|
||||
|
||||
# Apply model-specific filters
|
||||
|
||||
filtered_data = await self._apply_specific_filters(filtered_data, **kwargs)
|
||||
|
||||
|
||||
return self._paginate(filtered_data, page, page_size)
|
||||
|
||||
|
||||
async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]:
|
||||
"""Apply hash-based filtering"""
|
||||
@@ -120,113 +109,36 @@ class BaseModelService(ABC):
|
||||
|
||||
return data
|
||||
|
||||
async def _apply_common_filters(self, data: List[Dict], folder: str = None,
|
||||
base_models: list = None, tags: list = None,
|
||||
favorites_only: bool = False, search_options: dict = None) -> List[Dict]:
|
||||
async def _apply_common_filters(
|
||||
self,
|
||||
data: List[Dict],
|
||||
folder: str = None,
|
||||
base_models: list = None,
|
||||
tags: list = None,
|
||||
favorites_only: bool = False,
|
||||
search_options: dict = None,
|
||||
) -> List[Dict]:
|
||||
"""Apply common filters that work across all model types"""
|
||||
# Apply SFW filtering if enabled in settings
|
||||
if settings.get('show_only_sfw', False):
|
||||
data = [
|
||||
item for item in data
|
||||
if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R']
|
||||
]
|
||||
|
||||
# Apply favorites filtering if enabled
|
||||
if favorites_only:
|
||||
data = [
|
||||
item for item in data
|
||||
if item.get('favorite', False) is True
|
||||
]
|
||||
|
||||
# Apply folder filtering
|
||||
if folder is not None:
|
||||
if search_options and search_options.get('recursive', True):
|
||||
# Recursive folder filtering - include all subfolders
|
||||
# Ensure we match exact folder or its subfolders by checking path boundaries
|
||||
if folder == "":
|
||||
# Empty folder means root - include all items
|
||||
pass # Don't filter anything
|
||||
else:
|
||||
# Add trailing slash to ensure we match folder boundaries correctly
|
||||
folder_with_separator = folder + "/"
|
||||
data = [
|
||||
item for item in data
|
||||
if (item['folder'] == folder or
|
||||
item['folder'].startswith(folder_with_separator))
|
||||
]
|
||||
else:
|
||||
# Exact folder filtering
|
||||
data = [
|
||||
item for item in data
|
||||
if item['folder'] == folder
|
||||
]
|
||||
|
||||
# Apply base model filtering
|
||||
if base_models and len(base_models) > 0:
|
||||
data = [
|
||||
item for item in data
|
||||
if item.get('base_model') in base_models
|
||||
]
|
||||
|
||||
# Apply tag filtering
|
||||
if tags and len(tags) > 0:
|
||||
data = [
|
||||
item for item in data
|
||||
if any(tag in item.get('tags', []) for tag in tags)
|
||||
]
|
||||
|
||||
return data
|
||||
normalized_options = self.search_strategy.normalize_options(search_options)
|
||||
criteria = FilterCriteria(
|
||||
folder=folder,
|
||||
base_models=base_models,
|
||||
tags=tags,
|
||||
favorites_only=favorites_only,
|
||||
search_options=normalized_options,
|
||||
)
|
||||
return self.filter_set.apply(data, criteria)
|
||||
|
||||
async def _apply_search_filters(self, data: List[Dict], search: str,
|
||||
fuzzy_search: bool, search_options: dict) -> List[Dict]:
|
||||
async def _apply_search_filters(
|
||||
self,
|
||||
data: List[Dict],
|
||||
search: str,
|
||||
fuzzy_search: bool,
|
||||
search_options: dict,
|
||||
) -> List[Dict]:
|
||||
"""Apply search filtering"""
|
||||
search_results = []
|
||||
|
||||
for item in data:
|
||||
# Search by file name
|
||||
if search_options.get('filename', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(item.get('file_name', ''), search):
|
||||
search_results.append(item)
|
||||
continue
|
||||
elif search.lower() in item.get('file_name', '').lower():
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
# Search by model name
|
||||
if search_options.get('modelname', True):
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(item.get('model_name', ''), search):
|
||||
search_results.append(item)
|
||||
continue
|
||||
elif search.lower() in item.get('model_name', '').lower():
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
# Search by tags
|
||||
if search_options.get('tags', False) and 'tags' in item:
|
||||
if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower())
|
||||
for tag in item['tags']):
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
# Search by creator
|
||||
civitai = item.get('civitai')
|
||||
creator_username = ''
|
||||
if civitai and isinstance(civitai, dict):
|
||||
creator = civitai.get('creator')
|
||||
if creator and isinstance(creator, dict):
|
||||
creator_username = creator.get('username', '')
|
||||
if search_options.get('creator', False) and creator_username:
|
||||
if fuzzy_search:
|
||||
if fuzzy_match(creator_username, search):
|
||||
search_results.append(item)
|
||||
continue
|
||||
elif search.lower() in creator_username.lower():
|
||||
search_results.append(item)
|
||||
continue
|
||||
|
||||
return search_results
|
||||
normalized_options = self.search_strategy.normalize_options(search_options)
|
||||
return self.search_strategy.apply(data, search, normalized_options, fuzzy_search)
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
"""Apply model-specific filters - to be overridden by subclasses if needed"""
|
||||
@@ -284,6 +196,18 @@ class BaseModelService(ABC):
|
||||
"""Get model root directories"""
|
||||
return self.scanner.get_model_roots()
|
||||
|
||||
def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict:
|
||||
"""Filter relevant fields from CivitAI data"""
|
||||
if not data:
|
||||
return {}
|
||||
|
||||
fields = ["id", "modelId", "name", "trainedWords"] if minimal else [
|
||||
"id", "modelId", "name", "createdAt", "updatedAt",
|
||||
"publishedAt", "trainedWords", "baseModel", "description",
|
||||
"model", "images", "customImages", "creator"
|
||||
]
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
|
||||
async def get_folder_tree(self, model_root: str) -> Dict:
|
||||
"""Get hierarchical folder tree for a specific model root"""
|
||||
cache = await self.scanner.get_cached_data()
|
||||
@@ -394,7 +318,7 @@ class BaseModelService(ABC):
|
||||
|
||||
for model in cache.raw_data:
|
||||
if model.get('file_path') == file_path:
|
||||
return ModelRouteUtils.filter_civitai_data(model.get("civitai", {}))
|
||||
return self.filter_civitai_data(model.get("civitai", {}))
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import CheckpointMetadata
|
||||
from ..config import config
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,7 +37,7 @@ class CheckpointService(BaseModelService):
|
||||
"notes": checkpoint_data.get("notes", ""),
|
||||
"model_type": checkpoint_data.get("model_type", "checkpoint"),
|
||||
"favorite": checkpoint_data.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
||||
"civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True)
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
|
||||
100
py/services/download_coordinator.py
Normal file
100
py/services/download_coordinator.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Service wrapper for coordinating download lifecycle events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, Optional
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DownloadCoordinator:
|
||||
"""Manage download scheduling, cancellation and introspection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
ws_manager,
|
||||
download_manager_factory: Callable[[], Awaitable],
|
||||
) -> None:
|
||||
self._ws_manager = ws_manager
|
||||
self._download_manager_factory = download_manager_factory
|
||||
|
||||
async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Schedule a download using the provided payload."""
|
||||
|
||||
download_manager = await self._download_manager_factory()
|
||||
|
||||
download_id = payload.get("download_id") or self._ws_manager.generate_download_id()
|
||||
payload.setdefault("download_id", download_id)
|
||||
|
||||
async def progress_callback(progress: Any) -> None:
|
||||
await self._ws_manager.broadcast_download_progress(
|
||||
download_id,
|
||||
{
|
||||
"status": "progress",
|
||||
"progress": progress,
|
||||
"download_id": download_id,
|
||||
},
|
||||
)
|
||||
|
||||
model_id = self._parse_optional_int(payload.get("model_id"), "model_id")
|
||||
model_version_id = self._parse_optional_int(
|
||||
payload.get("model_version_id"), "model_version_id"
|
||||
)
|
||||
|
||||
if model_id is None and model_version_id is None:
|
||||
raise ValueError(
|
||||
"Missing required parameter: Please provide either 'model_id' or 'model_version_id'"
|
||||
)
|
||||
|
||||
result = await download_manager.download_from_civitai(
|
||||
model_id=model_id,
|
||||
model_version_id=model_version_id,
|
||||
save_dir=payload.get("model_root"),
|
||||
relative_path=payload.get("relative_path", ""),
|
||||
use_default_paths=payload.get("use_default_paths", False),
|
||||
progress_callback=progress_callback,
|
||||
download_id=download_id,
|
||||
source=payload.get("source"),
|
||||
)
|
||||
|
||||
result["download_id"] = download_id
|
||||
return result
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||
"""Cancel an active download and emit a broadcast event."""
|
||||
|
||||
download_manager = await self._download_manager_factory()
|
||||
result = await download_manager.cancel_download(download_id)
|
||||
|
||||
await self._ws_manager.broadcast_download_progress(
|
||||
download_id,
|
||||
{
|
||||
"status": "cancelled",
|
||||
"progress": 0,
|
||||
"download_id": download_id,
|
||||
"message": "Download cancelled by user",
|
||||
},
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def list_active_downloads(self) -> Dict[str, Any]:
|
||||
"""Return the active download map from the underlying manager."""
|
||||
|
||||
download_manager = await self._download_manager_factory()
|
||||
return await download_manager.get_active_downloads()
|
||||
|
||||
def _parse_optional_int(self, value: Any, field: str) -> Optional[int]:
|
||||
"""Parse an optional integer from user input."""
|
||||
|
||||
if value is None or value == "":
|
||||
return None
|
||||
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Invalid {field}: Must be an integer") from exc
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict
|
||||
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import EmbeddingMetadata
|
||||
from ..config import config
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,7 +37,7 @@ class EmbeddingService(BaseModelService):
|
||||
"notes": embedding_data.get("notes", ""),
|
||||
"model_type": embedding_data.get("model_type", "embedding"),
|
||||
"favorite": embedding_data.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
||||
"civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True)
|
||||
}
|
||||
|
||||
def find_duplicate_hashes(self) -> Dict:
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Dict, List, Optional
|
||||
from .base_model_service import BaseModelService
|
||||
from ..utils.models import LoraMetadata
|
||||
from ..config import config
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -38,7 +37,7 @@ class LoraService(BaseModelService):
|
||||
"usage_tips": lora_data.get("usage_tips", ""),
|
||||
"notes": lora_data.get("notes", ""),
|
||||
"favorite": lora_data.get("favorite", False),
|
||||
"civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}), minimal=True)
|
||||
"civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True)
|
||||
}
|
||||
|
||||
async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]:
|
||||
|
||||
355
py/services/metadata_sync_service.py
Normal file
355
py/services/metadata_sync_service.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""Services for synchronising metadata with remote providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Awaitable, Callable, Dict, Iterable, Optional
|
||||
|
||||
from ..services.settings_manager import SettingsManager
|
||||
from ..utils.model_utils import determine_base_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetadataProviderProtocol:
|
||||
"""Subset of metadata provider interface consumed by the sync service."""
|
||||
|
||||
async def get_model_by_hash(self, sha256: str) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||
...
|
||||
|
||||
async def get_model_version(
|
||||
self, model_id: int, model_version_id: Optional[int]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
|
||||
class MetadataSyncService:
|
||||
"""High level orchestration for metadata synchronisation flows."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metadata_manager,
|
||||
preview_service,
|
||||
settings: SettingsManager,
|
||||
default_metadata_provider_factory: Callable[[], Awaitable[MetadataProviderProtocol]],
|
||||
metadata_provider_selector: Callable[[str], Awaitable[MetadataProviderProtocol]],
|
||||
) -> None:
|
||||
self._metadata_manager = metadata_manager
|
||||
self._preview_service = preview_service
|
||||
self._settings = settings
|
||||
self._get_default_provider = default_metadata_provider_factory
|
||||
self._get_provider = metadata_provider_selector
|
||||
|
||||
async def load_local_metadata(self, metadata_path: str) -> Dict[str, Any]:
|
||||
"""Load metadata JSON from disk, returning an empty structure when missing."""
|
||||
|
||||
if not os.path.exists(metadata_path):
|
||||
return {}
|
||||
|
||||
try:
|
||||
with open(metadata_path, "r", encoding="utf-8") as handle:
|
||||
return json.load(handle)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Error loading metadata from %s: %s", metadata_path, exc)
|
||||
return {}
|
||||
|
||||
async def mark_not_found_on_civitai(
|
||||
self, metadata_path: str, local_metadata: Dict[str, Any]
|
||||
) -> None:
|
||||
"""Persist the not-found flag for a metadata payload."""
|
||||
|
||||
local_metadata["from_civitai"] = False
|
||||
await self._metadata_manager.save_metadata(metadata_path, local_metadata)
|
||||
|
||||
@staticmethod
|
||||
def is_civitai_api_metadata(meta: Dict[str, Any]) -> bool:
|
||||
"""Determine if the metadata originated from the CivitAI public API."""
|
||||
|
||||
if not isinstance(meta, dict):
|
||||
return False
|
||||
files = meta.get("files")
|
||||
images = meta.get("images")
|
||||
source = meta.get("source")
|
||||
return bool(files) and bool(images) and source != "archive_db"
|
||||
|
||||
async def update_model_metadata(
|
||||
self,
|
||||
metadata_path: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
civitai_metadata: Dict[str, Any],
|
||||
metadata_provider: Optional[MetadataProviderProtocol] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Merge remote metadata into the local record and persist the result."""
|
||||
|
||||
existing_civitai = local_metadata.get("civitai") or {}
|
||||
|
||||
if (
|
||||
civitai_metadata.get("source") == "archive_db"
|
||||
and self.is_civitai_api_metadata(existing_civitai)
|
||||
):
|
||||
logger.info(
|
||||
"Skip civitai update for %s (%s)",
|
||||
local_metadata.get("model_name", ""),
|
||||
existing_civitai.get("name", ""),
|
||||
)
|
||||
else:
|
||||
merged_civitai = existing_civitai.copy()
|
||||
merged_civitai.update(civitai_metadata)
|
||||
|
||||
if civitai_metadata.get("source") == "archive_db":
|
||||
model_name = civitai_metadata.get("model", {}).get("name", "")
|
||||
version_name = civitai_metadata.get("name", "")
|
||||
logger.info(
|
||||
"Recovered metadata from archive_db for deleted model: %s (%s)",
|
||||
model_name,
|
||||
version_name,
|
||||
)
|
||||
|
||||
if "trainedWords" in existing_civitai:
|
||||
existing_trained = existing_civitai.get("trainedWords", [])
|
||||
new_trained = civitai_metadata.get("trainedWords", [])
|
||||
merged_trained = list(set(existing_trained + new_trained))
|
||||
merged_civitai["trainedWords"] = merged_trained
|
||||
|
||||
local_metadata["civitai"] = merged_civitai
|
||||
|
||||
if "model" in civitai_metadata and civitai_metadata["model"]:
|
||||
model_data = civitai_metadata["model"]
|
||||
|
||||
if model_data.get("name"):
|
||||
local_metadata["model_name"] = model_data["name"]
|
||||
|
||||
if not local_metadata.get("modelDescription") and model_data.get("description"):
|
||||
local_metadata["modelDescription"] = model_data["description"]
|
||||
|
||||
if not local_metadata.get("tags") and model_data.get("tags"):
|
||||
local_metadata["tags"] = model_data["tags"]
|
||||
|
||||
if model_data.get("creator") and not local_metadata.get("civitai", {}).get(
|
||||
"creator"
|
||||
):
|
||||
local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"]
|
||||
|
||||
local_metadata["base_model"] = determine_base_model(
|
||||
civitai_metadata.get("baseModel")
|
||||
)
|
||||
|
||||
await self._preview_service.ensure_preview_for_metadata(
|
||||
metadata_path, local_metadata, civitai_metadata.get("images", [])
|
||||
)
|
||||
|
||||
await self._metadata_manager.save_metadata(metadata_path, local_metadata)
|
||||
return local_metadata
|
||||
|
||||
async def fetch_and_update_model(
|
||||
self,
|
||||
*,
|
||||
sha256: str,
|
||||
file_path: str,
|
||||
model_data: Dict[str, Any],
|
||||
update_cache_func: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
"""Fetch metadata for a model and update both disk and cache state."""
|
||||
|
||||
if not isinstance(model_data, dict):
|
||||
error = f"Invalid model_data type: {type(model_data)}"
|
||||
logger.error(error)
|
||||
return False, error
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||
enable_archive = self._settings.get("enable_metadata_archive_db", False)
|
||||
|
||||
try:
|
||||
if model_data.get("civitai_deleted") is True:
|
||||
if not enable_archive or model_data.get("db_checked") is True:
|
||||
return (
|
||||
False,
|
||||
"CivitAI model is deleted and metadata archive DB is not enabled",
|
||||
)
|
||||
metadata_provider = await self._get_provider("sqlite")
|
||||
else:
|
||||
metadata_provider = await self._get_default_provider()
|
||||
|
||||
civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256)
|
||||
if not civitai_metadata:
|
||||
if error == "Model not found":
|
||||
model_data["from_civitai"] = False
|
||||
model_data["civitai_deleted"] = True
|
||||
model_data["db_checked"] = enable_archive
|
||||
model_data["last_checked_at"] = datetime.now().timestamp()
|
||||
|
||||
data_to_save = model_data.copy()
|
||||
data_to_save.pop("folder", None)
|
||||
await self._metadata_manager.save_metadata(file_path, data_to_save)
|
||||
|
||||
error_msg = (
|
||||
f"Error fetching metadata: {error} (model_name={model_data.get('model_name', '')})"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
model_data["from_civitai"] = True
|
||||
model_data["civitai_deleted"] = civitai_metadata.get("source") == "archive_db"
|
||||
model_data["db_checked"] = enable_archive
|
||||
model_data["last_checked_at"] = datetime.now().timestamp()
|
||||
|
||||
local_metadata = model_data.copy()
|
||||
local_metadata.pop("folder", None)
|
||||
|
||||
await self.update_model_metadata(
|
||||
metadata_path,
|
||||
local_metadata,
|
||||
civitai_metadata,
|
||||
metadata_provider,
|
||||
)
|
||||
|
||||
update_payload = {
|
||||
"model_name": local_metadata.get("model_name"),
|
||||
"preview_url": local_metadata.get("preview_url"),
|
||||
"civitai": local_metadata.get("civitai"),
|
||||
}
|
||||
model_data.update(update_payload)
|
||||
|
||||
await update_cache_func(file_path, file_path, local_metadata)
|
||||
return True, None
|
||||
except KeyError as exc:
|
||||
error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
except Exception as exc: # pragma: no cover - error path
|
||||
error_msg = f"Error fetching metadata: {exc}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return False, error_msg
|
||||
|
||||
async def fetch_metadata_by_sha(
|
||||
self, sha256: str, metadata_provider: Optional[MetadataProviderProtocol] = None
|
||||
) -> tuple[Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""Fetch metadata for a SHA256 hash from the configured provider."""
|
||||
|
||||
provider = metadata_provider or await self._get_default_provider()
|
||||
return await provider.get_model_by_hash(sha256)
|
||||
|
||||
async def relink_metadata(
|
||||
self,
|
||||
*,
|
||||
file_path: str,
|
||||
metadata: Dict[str, Any],
|
||||
model_id: int,
|
||||
model_version_id: Optional[int],
|
||||
) -> Dict[str, Any]:
|
||||
"""Relink a local metadata record to a specific CivitAI model version."""
|
||||
|
||||
provider = await self._get_default_provider()
|
||||
civitai_metadata = await provider.get_model_version(model_id, model_version_id)
|
||||
if not civitai_metadata:
|
||||
raise ValueError(
|
||||
f"Model version not found on CivitAI for ID: {model_id}"
|
||||
+ (f" with version: {model_version_id}" if model_version_id else "")
|
||||
)
|
||||
|
||||
primary_model_file: Optional[Dict[str, Any]] = None
|
||||
for file_info in civitai_metadata.get("files", []):
|
||||
if file_info.get("primary", False) and file_info.get("type") == "Model":
|
||||
primary_model_file = file_info
|
||||
break
|
||||
|
||||
if primary_model_file and primary_model_file.get("hashes", {}).get("SHA256"):
|
||||
metadata["sha256"] = primary_model_file["hashes"]["SHA256"].lower()
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||
await self.update_model_metadata(
|
||||
metadata_path,
|
||||
metadata,
|
||||
civitai_metadata,
|
||||
provider,
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
async def save_metadata_updates(
|
||||
self,
|
||||
*,
|
||||
file_path: str,
|
||||
updates: Dict[str, Any],
|
||||
metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
|
||||
update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||
) -> Dict[str, Any]:
|
||||
"""Apply metadata updates and persist to disk and cache."""
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||
metadata = await metadata_loader(metadata_path)
|
||||
|
||||
for key, value in updates.items():
|
||||
if isinstance(value, dict) and isinstance(metadata.get(key), dict):
|
||||
metadata[key].update(value)
|
||||
else:
|
||||
metadata[key] = value
|
||||
|
||||
await self._metadata_manager.save_metadata(file_path, metadata)
|
||||
await update_cache(file_path, file_path, metadata)
|
||||
|
||||
if "model_name" in updates:
|
||||
logger.debug("Metadata update touched model_name; cache resort required")
|
||||
|
||||
return metadata
|
||||
|
||||
async def verify_duplicate_hashes(
|
||||
self,
|
||||
*,
|
||||
file_paths: Iterable[str],
|
||||
metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]],
|
||||
hash_calculator: Callable[[str], Awaitable[str]],
|
||||
update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]],
|
||||
) -> Dict[str, Any]:
|
||||
"""Verify a collection of files share the same SHA256 hash."""
|
||||
|
||||
file_paths = list(file_paths)
|
||||
if not file_paths:
|
||||
raise ValueError("No file paths provided for verification")
|
||||
|
||||
results = {
|
||||
"verified_as_duplicates": True,
|
||||
"mismatched_files": [],
|
||||
"new_hash_map": {},
|
||||
}
|
||||
|
||||
expected_hash: Optional[str] = None
|
||||
first_metadata_path = os.path.splitext(file_paths[0])[0] + ".metadata.json"
|
||||
first_metadata = await metadata_loader(first_metadata_path)
|
||||
if first_metadata and "sha256" in first_metadata:
|
||||
expected_hash = first_metadata["sha256"].lower()
|
||||
|
||||
for path in file_paths:
|
||||
if not os.path.exists(path):
|
||||
continue
|
||||
|
||||
try:
|
||||
actual_hash = await hash_calculator(path)
|
||||
metadata_path = os.path.splitext(path)[0] + ".metadata.json"
|
||||
metadata = await metadata_loader(metadata_path)
|
||||
stored_hash = metadata.get("sha256", "").lower()
|
||||
|
||||
if not expected_hash:
|
||||
expected_hash = stored_hash
|
||||
|
||||
if actual_hash != expected_hash:
|
||||
results["verified_as_duplicates"] = False
|
||||
results["mismatched_files"].append(path)
|
||||
results["new_hash_map"][path] = actual_hash
|
||||
|
||||
if actual_hash != stored_hash:
|
||||
metadata["sha256"] = actual_hash
|
||||
await self._metadata_manager.save_metadata(path, metadata)
|
||||
await update_cache(path, path, metadata)
|
||||
except Exception as exc: # pragma: no cover - defensive path
|
||||
logger.error("Error verifying hash for %s: %s", path, exc)
|
||||
results["mismatched_files"].append(path)
|
||||
results["new_hash_map"][path] = "error_calculating_hash"
|
||||
results["verified_as_duplicates"] = False
|
||||
|
||||
return results
|
||||
|
||||
245
py/services/model_lifecycle_service.py
Normal file
245
py/services/model_lifecycle_service.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Service routines for model lifecycle mutations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Awaitable, Callable, Dict, Iterable, List, Optional
|
||||
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]:
|
||||
"""Delete the primary model artefacts within ``target_dir``."""
|
||||
|
||||
patterns = [
|
||||
f"{file_name}.safetensors",
|
||||
f"{file_name}.metadata.json",
|
||||
]
|
||||
for ext in PREVIEW_EXTENSIONS:
|
||||
patterns.append(f"{file_name}{ext}")
|
||||
|
||||
deleted: List[str] = []
|
||||
main_file = patterns[0]
|
||||
main_path = os.path.join(target_dir, main_file).replace(os.sep, "/")
|
||||
|
||||
if os.path.exists(main_path):
|
||||
os.remove(main_path)
|
||||
deleted.append(main_path)
|
||||
else:
|
||||
logger.warning("Model file not found: %s", main_file)
|
||||
|
||||
for pattern in patterns[1:]:
|
||||
path = os.path.join(target_dir, pattern)
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
os.remove(path)
|
||||
deleted.append(pattern)
|
||||
except Exception as exc: # pragma: no cover - defensive path
|
||||
logger.warning("Failed to delete %s: %s", pattern, exc)
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
class ModelLifecycleService:
|
||||
"""Co-ordinate destructive and mutating model operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
scanner,
|
||||
metadata_manager,
|
||||
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||
recipe_scanner_factory: Callable[[], Awaitable] | None = None,
|
||||
) -> None:
|
||||
self._scanner = scanner
|
||||
self._metadata_manager = metadata_manager
|
||||
self._metadata_loader = metadata_loader
|
||||
self._recipe_scanner_factory = (
|
||||
recipe_scanner_factory or ServiceRegistry.get_recipe_scanner
|
||||
)
|
||||
|
||||
async def delete_model(self, file_path: str) -> Dict[str, object]:
|
||||
"""Delete a model file and associated artefacts."""
|
||||
|
||||
if not file_path:
|
||||
raise ValueError("Model path is required")
|
||||
|
||||
target_dir = os.path.dirname(file_path)
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
deleted_files = await delete_model_artifacts(target_dir, file_name)
|
||||
|
||||
cache = await self._scanner.get_cached_data()
|
||||
cache.raw_data = [item for item in cache.raw_data if item["file_path"] != file_path]
|
||||
await cache.resort()
|
||||
|
||||
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
|
||||
self._scanner._hash_index.remove_by_path(file_path)
|
||||
|
||||
return {"success": True, "deleted_files": deleted_files}
|
||||
|
||||
async def exclude_model(self, file_path: str) -> Dict[str, object]:
|
||||
"""Mark a model as excluded and prune cache references."""
|
||||
|
||||
if not file_path:
|
||||
raise ValueError("Model path is required")
|
||||
|
||||
metadata_path = os.path.splitext(file_path)[0] + ".metadata.json"
|
||||
metadata = await self._metadata_loader(metadata_path)
|
||||
metadata["exclude"] = True
|
||||
|
||||
await self._metadata_manager.save_metadata(file_path, metadata)
|
||||
|
||||
cache = await self._scanner.get_cached_data()
|
||||
model_to_remove = next(
|
||||
(item for item in cache.raw_data if item["file_path"] == file_path),
|
||||
None,
|
||||
)
|
||||
|
||||
if model_to_remove:
|
||||
for tag in model_to_remove.get("tags", []):
|
||||
if tag in getattr(self._scanner, "_tags_count", {}):
|
||||
self._scanner._tags_count[tag] = max(
|
||||
0, self._scanner._tags_count[tag] - 1
|
||||
)
|
||||
if self._scanner._tags_count[tag] == 0:
|
||||
del self._scanner._tags_count[tag]
|
||||
|
||||
if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index:
|
||||
self._scanner._hash_index.remove_by_path(file_path)
|
||||
|
||||
cache.raw_data = [
|
||||
item for item in cache.raw_data if item["file_path"] != file_path
|
||||
]
|
||||
await cache.resort()
|
||||
|
||||
excluded = getattr(self._scanner, "_excluded_models", None)
|
||||
if isinstance(excluded, list):
|
||||
excluded.append(file_path)
|
||||
|
||||
message = f"Model {os.path.basename(file_path)} excluded"
|
||||
return {"success": True, "message": message}
|
||||
|
||||
async def bulk_delete_models(self, file_paths: Iterable[str]) -> Dict[str, object]:
|
||||
"""Delete a collection of models via the scanner bulk operation."""
|
||||
|
||||
file_paths = list(file_paths)
|
||||
if not file_paths:
|
||||
raise ValueError("No file paths provided for deletion")
|
||||
|
||||
return await self._scanner.bulk_delete_models(file_paths)
|
||||
|
||||
async def rename_model(
|
||||
self, *, file_path: str, new_file_name: str
|
||||
) -> Dict[str, object]:
|
||||
"""Rename a model and its companion artefacts."""
|
||||
|
||||
if not file_path or not new_file_name:
|
||||
raise ValueError("File path and new file name are required")
|
||||
|
||||
invalid_chars = {"/", "\\", ":", "*", "?", '"', "<", ">", "|"}
|
||||
if any(char in new_file_name for char in invalid_chars):
|
||||
raise ValueError("Invalid characters in file name")
|
||||
|
||||
target_dir = os.path.dirname(file_path)
|
||||
old_file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
new_file_path = os.path.join(target_dir, f"{new_file_name}.safetensors").replace(
|
||||
os.sep, "/"
|
||||
)
|
||||
|
||||
if os.path.exists(new_file_path):
|
||||
raise ValueError("A file with this name already exists")
|
||||
|
||||
patterns = [
|
||||
f"{old_file_name}.safetensors",
|
||||
f"{old_file_name}.metadata.json",
|
||||
f"{old_file_name}.metadata.json.bak",
|
||||
]
|
||||
for ext in PREVIEW_EXTENSIONS:
|
||||
patterns.append(f"{old_file_name}{ext}")
|
||||
|
||||
existing_files: List[tuple[str, str]] = []
|
||||
for pattern in patterns:
|
||||
path = os.path.join(target_dir, pattern)
|
||||
if os.path.exists(path):
|
||||
existing_files.append((path, pattern))
|
||||
|
||||
metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
|
||||
metadata: Optional[Dict[str, object]] = None
|
||||
hash_value: Optional[str] = None
|
||||
|
||||
if os.path.exists(metadata_path):
|
||||
metadata = await self._metadata_loader(metadata_path)
|
||||
hash_value = metadata.get("sha256") if isinstance(metadata, dict) else None
|
||||
|
||||
renamed_files: List[str] = []
|
||||
new_metadata_path: Optional[str] = None
|
||||
new_preview: Optional[str] = None
|
||||
|
||||
for old_path, pattern in existing_files:
|
||||
ext = self._get_multipart_ext(pattern)
|
||||
new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(
|
||||
os.sep, "/"
|
||||
)
|
||||
os.rename(old_path, new_path)
|
||||
renamed_files.append(new_path)
|
||||
|
||||
if ext == ".metadata.json":
|
||||
new_metadata_path = new_path
|
||||
|
||||
if metadata and new_metadata_path:
|
||||
metadata["file_name"] = new_file_name
|
||||
metadata["file_path"] = new_file_path
|
||||
|
||||
if metadata.get("preview_url"):
|
||||
old_preview = str(metadata["preview_url"])
|
||||
ext = self._get_multipart_ext(old_preview)
|
||||
new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(
|
||||
os.sep, "/"
|
||||
)
|
||||
metadata["preview_url"] = new_preview
|
||||
|
||||
await self._metadata_manager.save_metadata(new_file_path, metadata)
|
||||
|
||||
if metadata:
|
||||
await self._scanner.update_single_model_cache(
|
||||
file_path, new_file_path, metadata
|
||||
)
|
||||
|
||||
if hash_value and getattr(self._scanner, "model_type", "") == "lora":
|
||||
recipe_scanner = await self._recipe_scanner_factory()
|
||||
if recipe_scanner:
|
||||
try:
|
||||
await recipe_scanner.update_lora_filename_by_hash(
|
||||
hash_value, new_file_name
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error(
|
||||
"Error updating recipe references for %s: %s",
|
||||
file_path,
|
||||
exc,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"new_file_path": new_file_path,
|
||||
"new_preview_path": new_preview,
|
||||
"renamed_files": renamed_files,
|
||||
"reload_required": False,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _get_multipart_ext(filename: str) -> str:
|
||||
"""Return the extension for files with compound suffixes."""
|
||||
|
||||
parts = filename.split(".")
|
||||
if len(parts) == 3:
|
||||
return "." + ".".join(parts[-2:])
|
||||
if len(parts) >= 4:
|
||||
return "." + ".".join(parts[-3:])
|
||||
return os.path.splitext(filename)[1]
|
||||
|
||||
196
py/services/model_query.py
Normal file
196
py/services/model_query.py
Normal file
@@ -0,0 +1,196 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable
|
||||
|
||||
from ..utils.constants import NSFW_LEVELS
|
||||
from ..utils.utils import fuzzy_match as default_fuzzy_match
|
||||
|
||||
|
||||
class SettingsProvider(Protocol):
|
||||
"""Protocol describing the SettingsManager contract used by query helpers."""
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
...
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SortParams:
|
||||
"""Normalized representation of sorting instructions."""
|
||||
|
||||
key: str
|
||||
order: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FilterCriteria:
|
||||
"""Container for model list filtering options."""
|
||||
|
||||
folder: Optional[str] = None
|
||||
base_models: Optional[Sequence[str]] = None
|
||||
tags: Optional[Sequence[str]] = None
|
||||
favorites_only: bool = False
|
||||
search_options: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ModelCacheRepository:
|
||||
"""Adapter around scanner cache access and sort normalisation."""
|
||||
|
||||
def __init__(self, scanner) -> None:
|
||||
self._scanner = scanner
|
||||
|
||||
async def get_cache(self):
|
||||
"""Return the underlying cache instance from the scanner."""
|
||||
return await self._scanner.get_cached_data()
|
||||
|
||||
async def fetch_sorted(self, params: SortParams) -> List[Dict[str, Any]]:
|
||||
"""Fetch cached data pre-sorted according to ``params``."""
|
||||
cache = await self.get_cache()
|
||||
return await cache.get_sorted_data(params.key, params.order)
|
||||
|
||||
@staticmethod
|
||||
def parse_sort(sort_by: str) -> SortParams:
|
||||
"""Parse an incoming sort string into key/order primitives."""
|
||||
if not sort_by:
|
||||
return SortParams(key="name", order="asc")
|
||||
|
||||
if ":" in sort_by:
|
||||
raw_key, raw_order = sort_by.split(":", 1)
|
||||
sort_key = raw_key.strip().lower() or "name"
|
||||
order = raw_order.strip().lower()
|
||||
else:
|
||||
sort_key = sort_by.strip().lower() or "name"
|
||||
order = "asc"
|
||||
|
||||
if order not in ("asc", "desc"):
|
||||
order = "asc"
|
||||
|
||||
return SortParams(key=sort_key, order=order)
|
||||
|
||||
|
||||
class ModelFilterSet:
|
||||
"""Applies common filtering rules to the model collection."""
|
||||
|
||||
def __init__(self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None) -> None:
|
||||
self._settings = settings
|
||||
self._nsfw_levels = nsfw_levels or NSFW_LEVELS
|
||||
|
||||
def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]:
|
||||
"""Return items that satisfy the provided criteria."""
|
||||
items = list(data)
|
||||
|
||||
if self._settings.get("show_only_sfw", False):
|
||||
threshold = self._nsfw_levels.get("R", 0)
|
||||
items = [
|
||||
item for item in items
|
||||
if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold
|
||||
]
|
||||
|
||||
if criteria.favorites_only:
|
||||
items = [item for item in items if item.get("favorite", False)]
|
||||
|
||||
folder = criteria.folder
|
||||
options = criteria.search_options or {}
|
||||
recursive = bool(options.get("recursive", True))
|
||||
if folder is not None:
|
||||
if recursive:
|
||||
if folder:
|
||||
folder_with_sep = f"{folder}/"
|
||||
items = [
|
||||
item for item in items
|
||||
if item.get("folder") == folder or item.get("folder", "").startswith(folder_with_sep)
|
||||
]
|
||||
else:
|
||||
items = [item for item in items if item.get("folder") == folder]
|
||||
|
||||
base_models = criteria.base_models or []
|
||||
if base_models:
|
||||
base_model_set = set(base_models)
|
||||
items = [item for item in items if item.get("base_model") in base_model_set]
|
||||
|
||||
tags = criteria.tags or []
|
||||
if tags:
|
||||
tag_set = set(tags)
|
||||
items = [
|
||||
item for item in items
|
||||
if any(tag in tag_set for tag in item.get("tags", []))
|
||||
]
|
||||
|
||||
return items
|
||||
|
||||
|
||||
class SearchStrategy:
|
||||
"""Encapsulates text and fuzzy matching behaviour for model queries."""
|
||||
|
||||
DEFAULT_OPTIONS: Dict[str, Any] = {
|
||||
"filename": True,
|
||||
"modelname": True,
|
||||
"tags": False,
|
||||
"recursive": True,
|
||||
"creator": False,
|
||||
}
|
||||
|
||||
def __init__(self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None) -> None:
|
||||
self._fuzzy_match = fuzzy_matcher or default_fuzzy_match
|
||||
|
||||
def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Merge provided options with defaults without mutating input."""
|
||||
normalized = dict(self.DEFAULT_OPTIONS)
|
||||
if options:
|
||||
normalized.update(options)
|
||||
return normalized
|
||||
|
||||
def apply(
|
||||
self,
|
||||
data: Iterable[Dict[str, Any]],
|
||||
search_term: str,
|
||||
options: Dict[str, Any],
|
||||
fuzzy: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Return items matching the search term using the configured strategy."""
|
||||
if not search_term:
|
||||
return list(data)
|
||||
|
||||
search_lower = search_term.lower()
|
||||
results: List[Dict[str, Any]] = []
|
||||
|
||||
for item in data:
|
||||
if options.get("filename", True):
|
||||
candidate = item.get("file_name", "")
|
||||
if self._matches(candidate, search_term, search_lower, fuzzy):
|
||||
results.append(item)
|
||||
continue
|
||||
|
||||
if options.get("modelname", True):
|
||||
candidate = item.get("model_name", "")
|
||||
if self._matches(candidate, search_term, search_lower, fuzzy):
|
||||
results.append(item)
|
||||
continue
|
||||
|
||||
if options.get("tags", False):
|
||||
tags = item.get("tags", []) or []
|
||||
if any(self._matches(tag, search_term, search_lower, fuzzy) for tag in tags):
|
||||
results.append(item)
|
||||
continue
|
||||
|
||||
if options.get("creator", False):
|
||||
creator_username = ""
|
||||
civitai = item.get("civitai")
|
||||
if isinstance(civitai, dict):
|
||||
creator = civitai.get("creator")
|
||||
if isinstance(creator, dict):
|
||||
creator_username = creator.get("username", "")
|
||||
if creator_username and self._matches(creator_username, search_term, search_lower, fuzzy):
|
||||
results.append(item)
|
||||
continue
|
||||
|
||||
return results
|
||||
|
||||
def _matches(self, candidate: str, search_term: str, search_lower: str, fuzzy: bool) -> bool:
|
||||
if not candidate:
|
||||
return False
|
||||
|
||||
candidate_lower = candidate.lower()
|
||||
if fuzzy:
|
||||
return self._fuzzy_match(candidate, search_term)
|
||||
return search_lower in candidate_lower
|
||||
@@ -13,6 +13,7 @@ from ..utils.metadata_manager import MetadataManager
|
||||
from .model_cache import ModelCache
|
||||
from .model_hash_index import ModelHashIndex
|
||||
from ..utils.constants import PREVIEW_EXTENSIONS
|
||||
from .model_lifecycle_service import delete_model_artifacts
|
||||
from .service_registry import ServiceRegistry
|
||||
from .websocket_manager import ws_manager
|
||||
|
||||
@@ -1040,10 +1041,8 @@ class ModelScanner:
|
||||
target_dir = os.path.dirname(file_path)
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
# Delete all associated files for the model
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
deleted_files = await ModelRouteUtils.delete_model_files(
|
||||
target_dir,
|
||||
deleted_files = await delete_model_artifacts(
|
||||
target_dir,
|
||||
file_name
|
||||
)
|
||||
|
||||
|
||||
168
py/services/preview_asset_service.py
Normal file
168
py/services/preview_asset_service.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Service for processing preview assets for models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Awaitable, Callable, Dict, Optional, Sequence
|
||||
|
||||
from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreviewAssetService:
|
||||
"""Manage fetching and persisting preview assets."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
metadata_manager,
|
||||
downloader_factory: Callable[[], Awaitable],
|
||||
exif_utils,
|
||||
) -> None:
|
||||
self._metadata_manager = metadata_manager
|
||||
self._downloader_factory = downloader_factory
|
||||
self._exif_utils = exif_utils
|
||||
|
||||
async def ensure_preview_for_metadata(
|
||||
self,
|
||||
metadata_path: str,
|
||||
local_metadata: Dict[str, object],
|
||||
images: Sequence[Dict[str, object]] | None,
|
||||
) -> None:
|
||||
"""Ensure preview assets exist for the supplied metadata entry."""
|
||||
|
||||
if local_metadata.get("preview_url") and os.path.exists(
|
||||
str(local_metadata["preview_url"])
|
||||
):
|
||||
return
|
||||
|
||||
if not images:
|
||||
return
|
||||
|
||||
first_preview = images[0]
|
||||
base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0]
|
||||
preview_dir = os.path.dirname(metadata_path)
|
||||
is_video = first_preview.get("type") == "video"
|
||||
|
||||
if is_video:
|
||||
extension = ".mp4"
|
||||
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||
downloader = await self._downloader_factory()
|
||||
success, result = await downloader.download_file(
|
||||
first_preview["url"], preview_path, use_auth=False
|
||||
)
|
||||
if success:
|
||||
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||
else:
|
||||
extension = ".webp"
|
||||
preview_path = os.path.join(preview_dir, base_name + extension)
|
||||
downloader = await self._downloader_factory()
|
||||
success, content, _headers = await downloader.download_to_memory(
|
||||
first_preview["url"], use_auth=False
|
||||
)
|
||||
if not success:
|
||||
return
|
||||
|
||||
try:
|
||||
optimized_data, _ = self._exif_utils.optimize_image(
|
||||
image_data=content,
|
||||
target_width=CARD_PREVIEW_WIDTH,
|
||||
format="webp",
|
||||
quality=85,
|
||||
preserve_metadata=False,
|
||||
)
|
||||
with open(preview_path, "wb") as handle:
|
||||
handle.write(optimized_data)
|
||||
except Exception as exc: # pragma: no cover - defensive path
|
||||
logger.error("Error optimizing preview image: %s", exc)
|
||||
try:
|
||||
with open(preview_path, "wb") as handle:
|
||||
handle.write(content)
|
||||
except Exception as save_exc:
|
||||
logger.error("Error saving preview image: %s", save_exc)
|
||||
return
|
||||
|
||||
local_metadata["preview_url"] = preview_path.replace(os.sep, "/")
|
||||
local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0)
|
||||
|
||||
async def replace_preview(
|
||||
self,
|
||||
*,
|
||||
model_path: str,
|
||||
preview_data: bytes,
|
||||
content_type: str,
|
||||
original_filename: Optional[str],
|
||||
nsfw_level: int,
|
||||
update_preview_in_cache: Callable[[str, str, int], Awaitable[bool]],
|
||||
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||
) -> Dict[str, object]:
|
||||
"""Replace an existing preview asset for a model."""
|
||||
|
||||
base_name = os.path.splitext(os.path.basename(model_path))[0]
|
||||
folder = os.path.dirname(model_path)
|
||||
|
||||
extension, optimized_data = await self._convert_preview(
|
||||
preview_data, content_type, original_filename
|
||||
)
|
||||
|
||||
for ext in PREVIEW_EXTENSIONS:
|
||||
existing_preview = os.path.join(folder, base_name + ext)
|
||||
if os.path.exists(existing_preview):
|
||||
try:
|
||||
os.remove(existing_preview)
|
||||
except Exception as exc: # pragma: no cover - defensive path
|
||||
logger.warning(
|
||||
"Failed to delete existing preview %s: %s", existing_preview, exc
|
||||
)
|
||||
|
||||
preview_path = os.path.join(folder, base_name + extension).replace(os.sep, "/")
|
||||
with open(preview_path, "wb") as handle:
|
||||
handle.write(optimized_data)
|
||||
|
||||
metadata_path = os.path.splitext(model_path)[0] + ".metadata.json"
|
||||
metadata = await metadata_loader(metadata_path)
|
||||
metadata["preview_url"] = preview_path
|
||||
metadata["preview_nsfw_level"] = nsfw_level
|
||||
await self._metadata_manager.save_metadata(model_path, metadata)
|
||||
|
||||
await update_preview_in_cache(model_path, preview_path, nsfw_level)
|
||||
|
||||
return {"preview_path": preview_path, "preview_nsfw_level": nsfw_level}
|
||||
|
||||
async def _convert_preview(
|
||||
self, data: bytes, content_type: str, original_filename: Optional[str]
|
||||
) -> tuple[str, bytes]:
|
||||
"""Convert preview bytes to the persisted representation."""
|
||||
|
||||
if content_type.startswith("video/"):
|
||||
extension = self._resolve_video_extension(content_type, original_filename)
|
||||
return extension, data
|
||||
|
||||
original_ext = (original_filename or "").lower()
|
||||
if original_ext.endswith(".gif") or content_type.lower() == "image/gif":
|
||||
return ".gif", data
|
||||
|
||||
optimized_data, _ = self._exif_utils.optimize_image(
|
||||
image_data=data,
|
||||
target_width=CARD_PREVIEW_WIDTH,
|
||||
format="webp",
|
||||
quality=85,
|
||||
preserve_metadata=False,
|
||||
)
|
||||
return ".webp", optimized_data
|
||||
|
||||
def _resolve_video_extension(self, content_type: str, original_filename: Optional[str]) -> str:
|
||||
"""Infer the best extension for a video preview."""
|
||||
|
||||
if original_filename:
|
||||
extension = os.path.splitext(original_filename)[1].lower()
|
||||
if extension in {".mp4", ".webm", ".mov", ".avi"}:
|
||||
return extension
|
||||
|
||||
if "webm" in content_type:
|
||||
return ".webm"
|
||||
return ".mp4"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
from typing import Iterable, List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from operator import itemgetter
|
||||
from natsort import natsorted
|
||||
@@ -10,77 +10,115 @@ class RecipeCache:
|
||||
raw_data: List[Dict]
|
||||
sorted_by_name: List[Dict]
|
||||
sorted_by_date: List[Dict]
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def resort(self, name_only: bool = False):
|
||||
"""Resort all cached data views"""
|
||||
async with self._lock:
|
||||
self.sorted_by_name = natsorted(
|
||||
self.raw_data,
|
||||
key=lambda x: x.get('title', '').lower() # Case-insensitive sort
|
||||
)
|
||||
if not name_only:
|
||||
self.sorted_by_date = sorted(
|
||||
self.raw_data,
|
||||
key=itemgetter('created_date', 'file_path'),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
async def update_recipe_metadata(self, recipe_id: str, metadata: Dict) -> bool:
|
||||
self._resort_locked(name_only=name_only)
|
||||
|
||||
async def update_recipe_metadata(self, recipe_id: str, metadata: Dict, *, resort: bool = True) -> bool:
|
||||
"""Update metadata for a specific recipe in all cached data
|
||||
|
||||
|
||||
Args:
|
||||
recipe_id: The ID of the recipe to update
|
||||
metadata: The new metadata
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if the update was successful, False if the recipe wasn't found
|
||||
"""
|
||||
async with self._lock:
|
||||
for item in self.raw_data:
|
||||
if str(item.get('id')) == str(recipe_id):
|
||||
item.update(metadata)
|
||||
if resort:
|
||||
self._resort_locked()
|
||||
return True
|
||||
return False # Recipe not found
|
||||
|
||||
async def add_recipe(self, recipe_data: Dict, *, resort: bool = False) -> None:
|
||||
"""Add a new recipe to the cache."""
|
||||
|
||||
# Update in raw_data
|
||||
for item in self.raw_data:
|
||||
if item.get('id') == recipe_id:
|
||||
item.update(metadata)
|
||||
break
|
||||
else:
|
||||
return False # Recipe not found
|
||||
|
||||
# Resort to reflect changes
|
||||
await self.resort()
|
||||
return True
|
||||
|
||||
async def add_recipe(self, recipe_data: Dict) -> None:
|
||||
"""Add a new recipe to the cache
|
||||
|
||||
Args:
|
||||
recipe_data: The recipe data to add
|
||||
"""
|
||||
async with self._lock:
|
||||
self.raw_data.append(recipe_data)
|
||||
await self.resort()
|
||||
if resort:
|
||||
self._resort_locked()
|
||||
|
||||
async def remove_recipe(self, recipe_id: str, *, resort: bool = False) -> Optional[Dict]:
|
||||
"""Remove a recipe from the cache by ID.
|
||||
|
||||
async def remove_recipe(self, recipe_id: str) -> bool:
|
||||
"""Remove a recipe from the cache by ID
|
||||
|
||||
Args:
|
||||
recipe_id: The ID of the recipe to remove
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True if the recipe was found and removed, False otherwise
|
||||
The removed recipe data if found, otherwise ``None``.
|
||||
"""
|
||||
# Find the recipe in raw_data
|
||||
recipe_index = next((i for i, recipe in enumerate(self.raw_data)
|
||||
if recipe.get('id') == recipe_id), None)
|
||||
|
||||
if recipe_index is None:
|
||||
return False
|
||||
|
||||
# Remove from raw_data
|
||||
self.raw_data.pop(recipe_index)
|
||||
|
||||
# Resort to update sorted lists
|
||||
await self.resort()
|
||||
|
||||
return True
|
||||
|
||||
async with self._lock:
|
||||
for index, recipe in enumerate(self.raw_data):
|
||||
if str(recipe.get('id')) == str(recipe_id):
|
||||
removed = self.raw_data.pop(index)
|
||||
if resort:
|
||||
self._resort_locked()
|
||||
return removed
|
||||
return None
|
||||
|
||||
async def bulk_remove(self, recipe_ids: Iterable[str], *, resort: bool = False) -> List[Dict]:
|
||||
"""Remove multiple recipes from the cache."""
|
||||
|
||||
id_set = {str(recipe_id) for recipe_id in recipe_ids}
|
||||
if not id_set:
|
||||
return []
|
||||
|
||||
async with self._lock:
|
||||
removed = [item for item in self.raw_data if str(item.get('id')) in id_set]
|
||||
if not removed:
|
||||
return []
|
||||
|
||||
self.raw_data = [item for item in self.raw_data if str(item.get('id')) not in id_set]
|
||||
if resort:
|
||||
self._resort_locked()
|
||||
return removed
|
||||
|
||||
async def replace_recipe(self, recipe_id: str, new_data: Dict, *, resort: bool = False) -> bool:
|
||||
"""Replace cached data for a recipe."""
|
||||
|
||||
async with self._lock:
|
||||
for index, recipe in enumerate(self.raw_data):
|
||||
if str(recipe.get('id')) == str(recipe_id):
|
||||
self.raw_data[index] = new_data
|
||||
if resort:
|
||||
self._resort_locked()
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_recipe(self, recipe_id: str) -> Optional[Dict]:
|
||||
"""Return a shallow copy of a cached recipe."""
|
||||
|
||||
async with self._lock:
|
||||
for recipe in self.raw_data:
|
||||
if str(recipe.get('id')) == str(recipe_id):
|
||||
return dict(recipe)
|
||||
return None
|
||||
|
||||
async def snapshot(self) -> List[Dict]:
|
||||
"""Return a copy of all cached recipes."""
|
||||
|
||||
async with self._lock:
|
||||
return [dict(item) for item in self.raw_data]
|
||||
|
||||
def _resort_locked(self, *, name_only: bool = False) -> None:
|
||||
"""Sort cached views. Caller must hold ``_lock``."""
|
||||
|
||||
self.sorted_by_name = natsorted(
|
||||
self.raw_data,
|
||||
key=lambda x: x.get('title', '').lower()
|
||||
)
|
||||
if not name_only:
|
||||
self.sorted_by_date = sorted(
|
||||
self.raw_data,
|
||||
key=itemgetter('created_date', 'file_path'),
|
||||
reverse=True
|
||||
)
|
||||
@@ -3,13 +3,14 @@ import logging
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||
from ..config import config
|
||||
from .recipe_cache import RecipeCache
|
||||
from .service_registry import ServiceRegistry
|
||||
from .lora_scanner import LoraScanner
|
||||
from .metadata_service import get_default_metadata_provider
|
||||
from ..utils.utils import fuzzy_match
|
||||
from .recipes.errors import RecipeNotFoundError
|
||||
from ..utils.utils import calculate_recipe_fingerprint, fuzzy_match
|
||||
from natsort import natsorted
|
||||
import sys
|
||||
|
||||
@@ -46,6 +47,8 @@ class RecipeScanner:
|
||||
self._initialization_lock = asyncio.Lock()
|
||||
self._initialization_task: Optional[asyncio.Task] = None
|
||||
self._is_initializing = False
|
||||
self._mutation_lock = asyncio.Lock()
|
||||
self._resort_tasks: Set[asyncio.Task] = set()
|
||||
if lora_scanner:
|
||||
self._lora_scanner = lora_scanner
|
||||
self._initialized = True
|
||||
@@ -191,6 +194,22 @@ class RecipeScanner:
|
||||
# Clean up the event loop
|
||||
loop.close()
|
||||
|
||||
def _schedule_resort(self, *, name_only: bool = False) -> None:
|
||||
"""Schedule a background resort of the recipe cache."""
|
||||
|
||||
if not self._cache:
|
||||
return
|
||||
|
||||
async def _resort_wrapper() -> None:
|
||||
try:
|
||||
await self._cache.resort(name_only=name_only)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Recipe Scanner: error resorting cache: %s", exc, exc_info=True)
|
||||
|
||||
task = asyncio.create_task(_resort_wrapper())
|
||||
self._resort_tasks.add(task)
|
||||
task.add_done_callback(lambda finished: self._resort_tasks.discard(finished))
|
||||
|
||||
@property
|
||||
def recipes_dir(self) -> str:
|
||||
"""Get path to recipes directory"""
|
||||
@@ -255,7 +274,45 @@ class RecipeScanner:
|
||||
|
||||
# Return the cache (may be empty or partially initialized)
|
||||
return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[])
|
||||
|
||||
|
||||
async def refresh_cache(self, force: bool = False) -> RecipeCache:
|
||||
"""Public helper to refresh or return the recipe cache."""
|
||||
|
||||
return await self.get_cached_data(force_refresh=force)
|
||||
|
||||
async def add_recipe(self, recipe_data: Dict[str, Any]) -> None:
|
||||
"""Add a recipe to the in-memory cache."""
|
||||
|
||||
if not recipe_data:
|
||||
return
|
||||
|
||||
cache = await self.get_cached_data()
|
||||
await cache.add_recipe(recipe_data, resort=False)
|
||||
self._schedule_resort()
|
||||
|
||||
async def remove_recipe(self, recipe_id: str) -> bool:
|
||||
"""Remove a recipe from the cache by ID."""
|
||||
|
||||
if not recipe_id:
|
||||
return False
|
||||
|
||||
cache = await self.get_cached_data()
|
||||
removed = await cache.remove_recipe(recipe_id, resort=False)
|
||||
if removed is None:
|
||||
return False
|
||||
|
||||
self._schedule_resort()
|
||||
return True
|
||||
|
||||
async def bulk_remove(self, recipe_ids: Iterable[str]) -> int:
|
||||
"""Remove multiple recipes from the cache."""
|
||||
|
||||
cache = await self.get_cached_data()
|
||||
removed = await cache.bulk_remove(recipe_ids, resort=False)
|
||||
if removed:
|
||||
self._schedule_resort()
|
||||
return len(removed)
|
||||
|
||||
async def scan_all_recipes(self) -> List[Dict]:
|
||||
"""Scan all recipe JSON files and return metadata"""
|
||||
recipes = []
|
||||
@@ -326,7 +383,6 @@ class RecipeScanner:
|
||||
|
||||
# Calculate and update fingerprint if missing
|
||||
if 'loras' in recipe_data and 'fingerprint' not in recipe_data:
|
||||
from ..utils.utils import calculate_recipe_fingerprint
|
||||
fingerprint = calculate_recipe_fingerprint(recipe_data['loras'])
|
||||
recipe_data['fingerprint'] = fingerprint
|
||||
|
||||
@@ -497,9 +553,36 @@ class RecipeScanner:
|
||||
logger.error(f"Error getting base model for lora: {e}")
|
||||
return None
|
||||
|
||||
def _enrich_lora_entry(self, lora: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Populate convenience fields for a LoRA entry."""
|
||||
|
||||
if not lora or not self._lora_scanner:
|
||||
return lora
|
||||
|
||||
hash_value = (lora.get('hash') or '').lower()
|
||||
if not hash_value:
|
||||
return lora
|
||||
|
||||
try:
|
||||
lora['inLibrary'] = self._lora_scanner.has_hash(hash_value)
|
||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(hash_value)
|
||||
lora['localPath'] = self._lora_scanner.get_path_by_hash(hash_value)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.debug("Error enriching lora entry %s: %s", hash_value, exc)
|
||||
|
||||
return lora
|
||||
|
||||
async def get_local_lora(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Lookup a local LoRA model by name."""
|
||||
|
||||
if not self._lora_scanner or not name:
|
||||
return None
|
||||
|
||||
return await self._lora_scanner.get_model_info_by_name(name)
|
||||
|
||||
async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True):
|
||||
"""Get paginated and filtered recipe data
|
||||
|
||||
|
||||
Args:
|
||||
page: Current page number (1-based)
|
||||
page_size: Number of items per page
|
||||
@@ -598,16 +681,12 @@ class RecipeScanner:
|
||||
|
||||
# Get paginated items
|
||||
paginated_items = filtered_data[start_idx:end_idx]
|
||||
|
||||
|
||||
# Add inLibrary information for each lora
|
||||
for item in paginated_items:
|
||||
if 'loras' in item:
|
||||
for lora in item['loras']:
|
||||
if 'hash' in lora and lora['hash']:
|
||||
lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower())
|
||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower())
|
||||
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower())
|
||||
|
||||
item['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in item['loras']]
|
||||
|
||||
result = {
|
||||
'items': paginated_items,
|
||||
'total': total_items,
|
||||
@@ -653,13 +732,8 @@ class RecipeScanner:
|
||||
|
||||
# Add lora metadata
|
||||
if 'loras' in formatted_recipe:
|
||||
for lora in formatted_recipe['loras']:
|
||||
if 'hash' in lora and lora['hash']:
|
||||
lora_hash = lora['hash'].lower()
|
||||
lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash)
|
||||
lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash)
|
||||
lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash)
|
||||
|
||||
formatted_recipe['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in formatted_recipe['loras']]
|
||||
|
||||
return formatted_recipe
|
||||
|
||||
def _format_file_url(self, file_path: str) -> str:
|
||||
@@ -717,26 +791,159 @@ class RecipeScanner:
|
||||
# Save updated recipe
|
||||
with open(recipe_json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(recipe_data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
# Update the cache if it exists
|
||||
if self._cache is not None:
|
||||
await self._cache.update_recipe_metadata(recipe_id, metadata)
|
||||
|
||||
await self._cache.update_recipe_metadata(recipe_id, metadata, resort=False)
|
||||
self._schedule_resort()
|
||||
|
||||
# If the recipe has an image, update its EXIF metadata
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
image_path = recipe_data.get('file_path')
|
||||
if image_path and os.path.exists(image_path):
|
||||
ExifUtils.append_recipe_metadata(image_path, recipe_data)
|
||||
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).error(f"Error updating recipe metadata: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def update_lora_entry(
|
||||
self,
|
||||
recipe_id: str,
|
||||
lora_index: int,
|
||||
*,
|
||||
target_name: str,
|
||||
target_lora: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""Update a specific LoRA entry within a recipe.
|
||||
|
||||
Returns the updated recipe data and the refreshed LoRA metadata.
|
||||
"""
|
||||
|
||||
if target_name is None:
|
||||
raise ValueError("target_name must be provided")
|
||||
|
||||
recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json")
|
||||
if not os.path.exists(recipe_json_path):
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
async with self._mutation_lock:
|
||||
with open(recipe_json_path, 'r', encoding='utf-8') as file_obj:
|
||||
recipe_data = json.load(file_obj)
|
||||
|
||||
loras = recipe_data.get('loras', [])
|
||||
if lora_index >= len(loras):
|
||||
raise RecipeNotFoundError("LoRA index out of range in recipe")
|
||||
|
||||
lora_entry = loras[lora_index]
|
||||
lora_entry['isDeleted'] = False
|
||||
lora_entry['exclude'] = False
|
||||
lora_entry['file_name'] = target_name
|
||||
|
||||
if target_lora is not None:
|
||||
sha_value = target_lora.get('sha256') or target_lora.get('sha')
|
||||
if sha_value:
|
||||
lora_entry['hash'] = sha_value.lower()
|
||||
|
||||
civitai_info = target_lora.get('civitai') or {}
|
||||
if civitai_info:
|
||||
lora_entry['modelName'] = civitai_info.get('model', {}).get('name', '')
|
||||
lora_entry['modelVersionName'] = civitai_info.get('name', '')
|
||||
lora_entry['modelVersionId'] = civitai_info.get('id')
|
||||
|
||||
recipe_data['fingerprint'] = calculate_recipe_fingerprint(recipe_data.get('loras', []))
|
||||
recipe_data['modified'] = time.time()
|
||||
|
||||
with open(recipe_json_path, 'w', encoding='utf-8') as file_obj:
|
||||
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
||||
|
||||
cache = await self.get_cached_data()
|
||||
replaced = await cache.replace_recipe(recipe_id, recipe_data, resort=False)
|
||||
if not replaced:
|
||||
await cache.add_recipe(recipe_data, resort=False)
|
||||
self._schedule_resort()
|
||||
|
||||
updated_lora = dict(lora_entry)
|
||||
if target_lora is not None:
|
||||
preview_url = target_lora.get('preview_url')
|
||||
if preview_url:
|
||||
updated_lora['preview_url'] = config.get_preview_static_url(preview_url)
|
||||
if target_lora.get('file_path'):
|
||||
updated_lora['localPath'] = target_lora['file_path']
|
||||
|
||||
updated_lora = self._enrich_lora_entry(updated_lora)
|
||||
return recipe_data, updated_lora
|
||||
|
||||
async def get_recipes_for_lora(self, lora_hash: str) -> List[Dict[str, Any]]:
|
||||
"""Return recipes that reference a given LoRA hash."""
|
||||
|
||||
if not lora_hash:
|
||||
return []
|
||||
|
||||
normalized_hash = lora_hash.lower()
|
||||
cache = await self.get_cached_data()
|
||||
matching_recipes: List[Dict[str, Any]] = []
|
||||
|
||||
for recipe in cache.raw_data:
|
||||
loras = recipe.get('loras', [])
|
||||
if any((entry.get('hash') or '').lower() == normalized_hash for entry in loras):
|
||||
recipe_copy = {**recipe}
|
||||
recipe_copy['loras'] = [self._enrich_lora_entry(dict(entry)) for entry in loras]
|
||||
recipe_copy['file_url'] = self._format_file_url(recipe.get('file_path'))
|
||||
matching_recipes.append(recipe_copy)
|
||||
|
||||
return matching_recipes
|
||||
|
||||
async def get_recipe_syntax_tokens(self, recipe_id: str) -> List[str]:
|
||||
"""Build LoRA syntax tokens for a recipe."""
|
||||
|
||||
cache = await self.get_cached_data()
|
||||
recipe = await cache.get_recipe(recipe_id)
|
||||
if recipe is None:
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
loras = recipe.get('loras', [])
|
||||
if not loras:
|
||||
return []
|
||||
|
||||
lora_cache = None
|
||||
if self._lora_scanner is not None:
|
||||
lora_cache = await self._lora_scanner.get_cached_data()
|
||||
|
||||
syntax_parts: List[str] = []
|
||||
for lora in loras:
|
||||
if lora.get('isDeleted', False):
|
||||
continue
|
||||
|
||||
file_name = None
|
||||
hash_value = (lora.get('hash') or '').lower()
|
||||
if hash_value and self._lora_scanner is not None and hasattr(self._lora_scanner, '_hash_index'):
|
||||
file_path = self._lora_scanner._hash_index.get_path(hash_value)
|
||||
if file_path:
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
if not file_name and lora.get('modelVersionId') and lora_cache is not None:
|
||||
for cached_lora in getattr(lora_cache, 'raw_data', []):
|
||||
civitai_info = cached_lora.get('civitai')
|
||||
if civitai_info and civitai_info.get('id') == lora.get('modelVersionId'):
|
||||
cached_path = cached_lora.get('path') or cached_lora.get('file_path')
|
||||
if cached_path:
|
||||
file_name = os.path.splitext(os.path.basename(cached_path))[0]
|
||||
break
|
||||
|
||||
if not file_name:
|
||||
file_name = lora.get('file_name', 'unknown-lora')
|
||||
|
||||
strength = lora.get('strength', 1.0)
|
||||
syntax_parts.append(f"<lora:{file_name}:{strength}>")
|
||||
|
||||
return syntax_parts
|
||||
|
||||
async def update_lora_filename_by_hash(self, hash_value: str, new_file_name: str) -> Tuple[int, int]:
|
||||
"""Update file_name in all recipes that contain a LoRA with the specified hash.
|
||||
|
||||
|
||||
Args:
|
||||
hash_value: The SHA256 hash value of the LoRA
|
||||
new_file_name: The new file_name to set
|
||||
|
||||
23
py/services/recipes/__init__.py
Normal file
23
py/services/recipes/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Recipe service layer implementations."""
|
||||
|
||||
from .analysis_service import RecipeAnalysisService
|
||||
from .persistence_service import RecipePersistenceService
|
||||
from .sharing_service import RecipeSharingService
|
||||
from .errors import (
|
||||
RecipeServiceError,
|
||||
RecipeValidationError,
|
||||
RecipeNotFoundError,
|
||||
RecipeDownloadError,
|
||||
RecipeConflictError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RecipeAnalysisService",
|
||||
"RecipePersistenceService",
|
||||
"RecipeSharingService",
|
||||
"RecipeServiceError",
|
||||
"RecipeValidationError",
|
||||
"RecipeNotFoundError",
|
||||
"RecipeDownloadError",
|
||||
"RecipeConflictError",
|
||||
]
|
||||
289
py/services/recipes/analysis_service.py
Normal file
289
py/services/recipes/analysis_service.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""Services responsible for recipe metadata analysis."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ...utils.utils import calculate_recipe_fingerprint
|
||||
from .errors import (
|
||||
RecipeDownloadError,
|
||||
RecipeNotFoundError,
|
||||
RecipeServiceError,
|
||||
RecipeValidationError,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AnalysisResult:
|
||||
"""Return payload from analysis operations."""
|
||||
|
||||
payload: dict[str, Any]
|
||||
status: int = 200
|
||||
|
||||
|
||||
class RecipeAnalysisService:
|
||||
"""Extract recipe metadata from various image sources."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
exif_utils,
|
||||
recipe_parser_factory,
|
||||
downloader_factory: Callable[[], Any],
|
||||
metadata_collector: Optional[Callable[[], Any]] = None,
|
||||
metadata_processor_cls: Optional[type] = None,
|
||||
metadata_registry_cls: Optional[type] = None,
|
||||
standalone_mode: bool = False,
|
||||
logger,
|
||||
) -> None:
|
||||
self._exif_utils = exif_utils
|
||||
self._recipe_parser_factory = recipe_parser_factory
|
||||
self._downloader_factory = downloader_factory
|
||||
self._metadata_collector = metadata_collector
|
||||
self._metadata_processor_cls = metadata_processor_cls
|
||||
self._metadata_registry_cls = metadata_registry_cls
|
||||
self._standalone_mode = standalone_mode
|
||||
self._logger = logger
|
||||
|
||||
async def analyze_uploaded_image(
|
||||
self,
|
||||
*,
|
||||
image_bytes: bytes | None,
|
||||
recipe_scanner,
|
||||
) -> AnalysisResult:
|
||||
"""Analyze an uploaded image payload."""
|
||||
|
||||
if not image_bytes:
|
||||
raise RecipeValidationError("No image data provided")
|
||||
|
||||
temp_path = self._write_temp_file(image_bytes)
|
||||
try:
|
||||
metadata = self._exif_utils.extract_image_metadata(temp_path)
|
||||
if not metadata:
|
||||
return AnalysisResult({"error": "No metadata found in this image", "loras": []})
|
||||
|
||||
return await self._parse_metadata(
|
||||
metadata,
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_path=None,
|
||||
include_image_base64=False,
|
||||
)
|
||||
finally:
|
||||
self._safe_cleanup(temp_path)
|
||||
|
||||
async def analyze_remote_image(
|
||||
self,
|
||||
*,
|
||||
url: str | None,
|
||||
recipe_scanner,
|
||||
civitai_client,
|
||||
) -> AnalysisResult:
|
||||
"""Analyze an image accessible via URL, including Civitai integration."""
|
||||
|
||||
if not url:
|
||||
raise RecipeValidationError("No URL provided")
|
||||
|
||||
if civitai_client is None:
|
||||
raise RecipeServiceError("Civitai client unavailable")
|
||||
|
||||
temp_path = self._create_temp_path()
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
try:
|
||||
civitai_match = re.match(r"https://civitai\.com/images/(\d+)", url)
|
||||
if civitai_match:
|
||||
image_info = await civitai_client.get_image_info(civitai_match.group(1))
|
||||
if not image_info:
|
||||
raise RecipeDownloadError("Failed to fetch image information from Civitai")
|
||||
image_url = image_info.get("url")
|
||||
if not image_url:
|
||||
raise RecipeDownloadError("No image URL found in Civitai response")
|
||||
await self._download_image(image_url, temp_path)
|
||||
metadata = image_info.get("meta") if "meta" in image_info else None
|
||||
else:
|
||||
await self._download_image(url, temp_path)
|
||||
|
||||
if metadata is None:
|
||||
metadata = self._exif_utils.extract_image_metadata(temp_path)
|
||||
|
||||
if not metadata:
|
||||
return self._metadata_not_found_response(temp_path)
|
||||
|
||||
return await self._parse_metadata(
|
||||
metadata,
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_path=temp_path,
|
||||
include_image_base64=True,
|
||||
)
|
||||
finally:
|
||||
self._safe_cleanup(temp_path)
|
||||
|
||||
async def analyze_local_image(
|
||||
self,
|
||||
*,
|
||||
file_path: str | None,
|
||||
recipe_scanner,
|
||||
) -> AnalysisResult:
|
||||
"""Analyze a file already present on disk."""
|
||||
|
||||
if not file_path:
|
||||
raise RecipeValidationError("No file path provided")
|
||||
|
||||
normalized_path = os.path.normpath(file_path.strip('"').strip("'"))
|
||||
if not os.path.isfile(normalized_path):
|
||||
raise RecipeNotFoundError("File not found")
|
||||
|
||||
metadata = self._exif_utils.extract_image_metadata(normalized_path)
|
||||
if not metadata:
|
||||
return self._metadata_not_found_response(normalized_path)
|
||||
|
||||
return await self._parse_metadata(
|
||||
metadata,
|
||||
recipe_scanner=recipe_scanner,
|
||||
image_path=normalized_path,
|
||||
include_image_base64=True,
|
||||
)
|
||||
|
||||
async def analyze_widget_metadata(self, *, recipe_scanner) -> AnalysisResult:
|
||||
"""Analyse the most recent generation metadata for widget saves."""
|
||||
|
||||
if self._metadata_collector is None or self._metadata_processor_cls is None:
|
||||
raise RecipeValidationError("Metadata collection not available")
|
||||
|
||||
raw_metadata = self._metadata_collector()
|
||||
metadata_dict = self._metadata_processor_cls.to_dict(raw_metadata)
|
||||
if not metadata_dict:
|
||||
raise RecipeValidationError("No generation metadata found")
|
||||
|
||||
latest_image = None
|
||||
if not self._standalone_mode and self._metadata_registry_cls is not None:
|
||||
metadata_registry = self._metadata_registry_cls()
|
||||
latest_image = metadata_registry.get_first_decoded_image()
|
||||
|
||||
if latest_image is None:
|
||||
raise RecipeValidationError(
|
||||
"No recent images found to use for recipe. Try generating an image first."
|
||||
)
|
||||
|
||||
image_bytes = self._convert_tensor_to_png_bytes(latest_image)
|
||||
if image_bytes is None:
|
||||
raise RecipeValidationError("Cannot handle this data shape from metadata registry")
|
||||
|
||||
return AnalysisResult(
|
||||
{
|
||||
"metadata": metadata_dict,
|
||||
"image_bytes": image_bytes,
|
||||
}
|
||||
)
|
||||
|
||||
# Internal helpers -------------------------------------------------
|
||||
|
||||
async def _parse_metadata(
|
||||
self,
|
||||
metadata: dict[str, Any],
|
||||
*,
|
||||
recipe_scanner,
|
||||
image_path: Optional[str],
|
||||
include_image_base64: bool,
|
||||
) -> AnalysisResult:
|
||||
parser = self._recipe_parser_factory.create_parser(metadata)
|
||||
if parser is None:
|
||||
payload = {"error": "No parser found for this image", "loras": []}
|
||||
if include_image_base64 and image_path:
|
||||
payload["image_base64"] = self._encode_file(image_path)
|
||||
return AnalysisResult(payload)
|
||||
|
||||
result = await parser.parse_metadata(metadata, recipe_scanner=recipe_scanner)
|
||||
|
||||
if include_image_base64 and image_path:
|
||||
result["image_base64"] = self._encode_file(image_path)
|
||||
|
||||
if "error" in result and not result.get("loras"):
|
||||
return AnalysisResult(result)
|
||||
|
||||
fingerprint = calculate_recipe_fingerprint(result.get("loras", []))
|
||||
result["fingerprint"] = fingerprint
|
||||
|
||||
matching_recipes: list[str] = []
|
||||
if fingerprint:
|
||||
matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint)
|
||||
result["matching_recipes"] = matching_recipes
|
||||
|
||||
return AnalysisResult(result)
|
||||
|
||||
async def _download_image(self, url: str, temp_path: str) -> None:
|
||||
downloader = await self._downloader_factory()
|
||||
success, result = await downloader.download_file(url, temp_path, use_auth=False)
|
||||
if not success:
|
||||
raise RecipeDownloadError(f"Failed to download image from URL: {result}")
|
||||
|
||||
def _metadata_not_found_response(self, path: str) -> AnalysisResult:
|
||||
payload: dict[str, Any] = {"error": "No metadata found in this image", "loras": []}
|
||||
if os.path.exists(path):
|
||||
payload["image_base64"] = self._encode_file(path)
|
||||
return AnalysisResult(payload)
|
||||
|
||||
def _write_temp_file(self, data: bytes) -> str:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
||||
temp_file.write(data)
|
||||
return temp_file.name
|
||||
|
||||
def _create_temp_path(self) -> str:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
||||
return temp_file.name
|
||||
|
||||
def _safe_cleanup(self, path: Optional[str]) -> None:
|
||||
if path and os.path.exists(path):
|
||||
try:
|
||||
os.unlink(path)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._logger.error("Error deleting temporary file: %s", exc)
|
||||
|
||||
def _encode_file(self, path: str) -> str:
|
||||
with open(path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
|
||||
def _convert_tensor_to_png_bytes(self, latest_image: Any) -> Optional[bytes]:
|
||||
try:
|
||||
if isinstance(latest_image, tuple):
|
||||
tensor_image = latest_image[0] if latest_image else None
|
||||
if tensor_image is None:
|
||||
return None
|
||||
else:
|
||||
tensor_image = latest_image
|
||||
|
||||
if hasattr(tensor_image, "shape"):
|
||||
self._logger.debug(
|
||||
"Tensor shape: %s, dtype: %s", tensor_image.shape, getattr(tensor_image, "dtype", None)
|
||||
)
|
||||
|
||||
import torch # type: ignore[import-not-found]
|
||||
|
||||
if isinstance(tensor_image, torch.Tensor):
|
||||
image_np = tensor_image.cpu().numpy()
|
||||
else:
|
||||
image_np = np.array(tensor_image)
|
||||
|
||||
while len(image_np.shape) > 3:
|
||||
image_np = image_np[0]
|
||||
|
||||
if image_np.dtype in (np.float32, np.float64) and image_np.max() <= 1.0:
|
||||
image_np = (image_np * 255).astype(np.uint8)
|
||||
|
||||
if len(image_np.shape) == 3 and image_np.shape[2] == 3:
|
||||
pil_image = Image.fromarray(image_np)
|
||||
img_byte_arr = io.BytesIO()
|
||||
pil_image.save(img_byte_arr, format="PNG")
|
||||
return img_byte_arr.getvalue()
|
||||
except Exception as exc: # pragma: no cover - defensive logging path
|
||||
self._logger.error("Error processing image data: %s", exc, exc_info=True)
|
||||
return None
|
||||
|
||||
return None
|
||||
22
py/services/recipes/errors.py
Normal file
22
py/services/recipes/errors.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Shared exceptions for recipe services."""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class RecipeServiceError(Exception):
|
||||
"""Base exception for recipe service failures."""
|
||||
|
||||
|
||||
class RecipeValidationError(RecipeServiceError):
|
||||
"""Raised when a request payload fails validation."""
|
||||
|
||||
|
||||
class RecipeNotFoundError(RecipeServiceError):
|
||||
"""Raised when a recipe resource cannot be located."""
|
||||
|
||||
|
||||
class RecipeDownloadError(RecipeServiceError):
|
||||
"""Raised when remote recipe assets cannot be downloaded."""
|
||||
|
||||
|
||||
class RecipeConflictError(RecipeServiceError):
|
||||
"""Raised when a conflicting recipe state is detected."""
|
||||
400
py/services/recipes/persistence_service.py
Normal file
400
py/services/recipes/persistence_service.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""Services encapsulating recipe persistence workflows."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
from ...config import config
|
||||
from ...utils.utils import calculate_recipe_fingerprint
|
||||
from .errors import RecipeNotFoundError, RecipeValidationError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PersistenceResult:
|
||||
"""Return payload from persistence operations."""
|
||||
|
||||
payload: dict[str, Any]
|
||||
status: int = 200
|
||||
|
||||
|
||||
class RecipePersistenceService:
|
||||
"""Coordinate recipe persistence tasks across storage and caches."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
exif_utils,
|
||||
card_preview_width: int,
|
||||
logger,
|
||||
) -> None:
|
||||
self._exif_utils = exif_utils
|
||||
self._card_preview_width = card_preview_width
|
||||
self._logger = logger
|
||||
|
||||
async def save_recipe(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner,
|
||||
image_bytes: bytes | None,
|
||||
image_base64: str | None,
|
||||
name: str | None,
|
||||
tags: Iterable[str],
|
||||
metadata: Optional[dict[str, Any]],
|
||||
) -> PersistenceResult:
|
||||
"""Persist a user uploaded recipe."""
|
||||
|
||||
missing_fields = []
|
||||
if not name:
|
||||
missing_fields.append("name")
|
||||
if metadata is None:
|
||||
missing_fields.append("metadata")
|
||||
if missing_fields:
|
||||
raise RecipeValidationError(
|
||||
f"Missing required fields: {', '.join(missing_fields)}"
|
||||
)
|
||||
|
||||
resolved_image_bytes = self._resolve_image_bytes(image_bytes, image_base64)
|
||||
recipes_dir = recipe_scanner.recipes_dir
|
||||
os.makedirs(recipes_dir, exist_ok=True)
|
||||
|
||||
recipe_id = str(uuid.uuid4())
|
||||
optimized_image, extension = self._exif_utils.optimize_image(
|
||||
image_data=resolved_image_bytes,
|
||||
target_width=self._card_preview_width,
|
||||
format="webp",
|
||||
quality=85,
|
||||
preserve_metadata=True,
|
||||
)
|
||||
image_filename = f"{recipe_id}{extension}"
|
||||
image_path = os.path.join(recipes_dir, image_filename)
|
||||
with open(image_path, "wb") as file_obj:
|
||||
file_obj.write(optimized_image)
|
||||
|
||||
current_time = time.time()
|
||||
loras_data = [self._normalise_lora_entry(lora) for lora in metadata.get("loras", [])]
|
||||
|
||||
gen_params = metadata.get("gen_params", {})
|
||||
if not gen_params and "raw_metadata" in metadata:
|
||||
raw_metadata = metadata.get("raw_metadata", {})
|
||||
gen_params = {
|
||||
"prompt": raw_metadata.get("prompt", ""),
|
||||
"negative_prompt": raw_metadata.get("negative_prompt", ""),
|
||||
"checkpoint": raw_metadata.get("checkpoint", {}),
|
||||
"steps": raw_metadata.get("steps", ""),
|
||||
"sampler": raw_metadata.get("sampler", ""),
|
||||
"cfg_scale": raw_metadata.get("cfg_scale", ""),
|
||||
"seed": raw_metadata.get("seed", ""),
|
||||
"size": raw_metadata.get("size", ""),
|
||||
"clip_skip": raw_metadata.get("clip_skip", ""),
|
||||
}
|
||||
|
||||
fingerprint = calculate_recipe_fingerprint(loras_data)
|
||||
recipe_data: Dict[str, Any] = {
|
||||
"id": recipe_id,
|
||||
"file_path": image_path,
|
||||
"title": name,
|
||||
"modified": current_time,
|
||||
"created_date": current_time,
|
||||
"base_model": metadata.get("base_model", ""),
|
||||
"loras": loras_data,
|
||||
"gen_params": gen_params,
|
||||
"fingerprint": fingerprint,
|
||||
}
|
||||
|
||||
tags_list = list(tags)
|
||||
if tags_list:
|
||||
recipe_data["tags"] = tags_list
|
||||
|
||||
if metadata.get("source_path"):
|
||||
recipe_data["source_path"] = metadata.get("source_path")
|
||||
|
||||
json_filename = f"{recipe_id}.recipe.json"
|
||||
json_path = os.path.join(recipes_dir, json_filename)
|
||||
with open(json_path, "w", encoding="utf-8") as file_obj:
|
||||
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
||||
|
||||
self._exif_utils.append_recipe_metadata(image_path, recipe_data)
|
||||
|
||||
matching_recipes = await self._find_matching_recipes(recipe_scanner, fingerprint, exclude_id=recipe_id)
|
||||
await recipe_scanner.add_recipe(recipe_data)
|
||||
|
||||
return PersistenceResult(
|
||||
{
|
||||
"success": True,
|
||||
"recipe_id": recipe_id,
|
||||
"image_path": image_path,
|
||||
"json_path": json_path,
|
||||
"matching_recipes": matching_recipes,
|
||||
}
|
||||
)
|
||||
|
||||
async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult:
|
||||
"""Delete an existing recipe."""
|
||||
|
||||
recipes_dir = recipe_scanner.recipes_dir
|
||||
if not recipes_dir or not os.path.exists(recipes_dir):
|
||||
raise RecipeNotFoundError("Recipes directory not found")
|
||||
|
||||
recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
|
||||
if not os.path.exists(recipe_json_path):
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
with open(recipe_json_path, "r", encoding="utf-8") as file_obj:
|
||||
recipe_data = json.load(file_obj)
|
||||
|
||||
image_path = recipe_data.get("file_path")
|
||||
os.remove(recipe_json_path)
|
||||
if image_path and os.path.exists(image_path):
|
||||
os.remove(image_path)
|
||||
|
||||
await recipe_scanner.remove_recipe(recipe_id)
|
||||
return PersistenceResult({"success": True, "message": "Recipe deleted successfully"})
|
||||
|
||||
async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: dict[str, Any]) -> PersistenceResult:
|
||||
"""Update persisted metadata for a recipe."""
|
||||
|
||||
if not any(key in updates for key in ("title", "tags", "source_path", "preview_nsfw_level")):
|
||||
raise RecipeValidationError(
|
||||
"At least one field to update must be provided (title or tags or source_path or preview_nsfw_level)"
|
||||
)
|
||||
|
||||
success = await recipe_scanner.update_recipe_metadata(recipe_id, updates)
|
||||
if not success:
|
||||
raise RecipeNotFoundError("Recipe not found or update failed")
|
||||
|
||||
return PersistenceResult({"success": True, "recipe_id": recipe_id, "updates": updates})
|
||||
|
||||
async def reconnect_lora(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner,
|
||||
recipe_id: str,
|
||||
lora_index: int,
|
||||
target_name: str,
|
||||
) -> PersistenceResult:
|
||||
"""Reconnect a LoRA entry within an existing recipe."""
|
||||
|
||||
recipe_path = os.path.join(recipe_scanner.recipes_dir, f"{recipe_id}.recipe.json")
|
||||
if not os.path.exists(recipe_path):
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
target_lora = await recipe_scanner.get_local_lora(target_name)
|
||||
if not target_lora:
|
||||
raise RecipeNotFoundError(f"Local LoRA not found with name: {target_name}")
|
||||
|
||||
recipe_data, updated_lora = await recipe_scanner.update_lora_entry(
|
||||
recipe_id,
|
||||
lora_index,
|
||||
target_name=target_name,
|
||||
target_lora=target_lora,
|
||||
)
|
||||
|
||||
image_path = recipe_data.get("file_path")
|
||||
if image_path and os.path.exists(image_path):
|
||||
self._exif_utils.append_recipe_metadata(image_path, recipe_data)
|
||||
|
||||
matching_recipes = []
|
||||
if "fingerprint" in recipe_data:
|
||||
matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(recipe_data["fingerprint"])
|
||||
if recipe_id in matching_recipes:
|
||||
matching_recipes.remove(recipe_id)
|
||||
|
||||
return PersistenceResult(
|
||||
{
|
||||
"success": True,
|
||||
"recipe_id": recipe_id,
|
||||
"updated_lora": updated_lora,
|
||||
"matching_recipes": matching_recipes,
|
||||
}
|
||||
)
|
||||
|
||||
async def bulk_delete(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner,
|
||||
recipe_ids: Iterable[str],
|
||||
) -> PersistenceResult:
|
||||
"""Delete multiple recipes in a single request."""
|
||||
|
||||
recipe_ids = list(recipe_ids)
|
||||
if not recipe_ids:
|
||||
raise RecipeValidationError("No recipe IDs provided")
|
||||
|
||||
recipes_dir = recipe_scanner.recipes_dir
|
||||
if not recipes_dir or not os.path.exists(recipes_dir):
|
||||
raise RecipeNotFoundError("Recipes directory not found")
|
||||
|
||||
deleted_recipes: list[str] = []
|
||||
failed_recipes: list[dict[str, Any]] = []
|
||||
|
||||
for recipe_id in recipe_ids:
|
||||
recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json")
|
||||
if not os.path.exists(recipe_json_path):
|
||||
failed_recipes.append({"id": recipe_id, "reason": "Recipe not found"})
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(recipe_json_path, "r", encoding="utf-8") as file_obj:
|
||||
recipe_data = json.load(file_obj)
|
||||
image_path = recipe_data.get("file_path")
|
||||
os.remove(recipe_json_path)
|
||||
if image_path and os.path.exists(image_path):
|
||||
os.remove(image_path)
|
||||
deleted_recipes.append(recipe_id)
|
||||
except Exception as exc:
|
||||
failed_recipes.append({"id": recipe_id, "reason": str(exc)})
|
||||
|
||||
if deleted_recipes:
|
||||
await recipe_scanner.bulk_remove(deleted_recipes)
|
||||
|
||||
return PersistenceResult(
|
||||
{
|
||||
"success": True,
|
||||
"deleted": deleted_recipes,
|
||||
"failed": failed_recipes,
|
||||
"total_deleted": len(deleted_recipes),
|
||||
"total_failed": len(failed_recipes),
|
||||
}
|
||||
)
|
||||
|
||||
async def save_recipe_from_widget(
|
||||
self,
|
||||
*,
|
||||
recipe_scanner,
|
||||
metadata: dict[str, Any],
|
||||
image_bytes: bytes,
|
||||
) -> PersistenceResult:
|
||||
"""Save a recipe constructed from widget metadata."""
|
||||
|
||||
if not metadata:
|
||||
raise RecipeValidationError("No generation metadata found")
|
||||
|
||||
recipes_dir = recipe_scanner.recipes_dir
|
||||
os.makedirs(recipes_dir, exist_ok=True)
|
||||
|
||||
recipe_id = str(uuid.uuid4())
|
||||
image_filename = f"{recipe_id}.png"
|
||||
image_path = os.path.join(recipes_dir, image_filename)
|
||||
with open(image_path, "wb") as file_obj:
|
||||
file_obj.write(image_bytes)
|
||||
|
||||
lora_stack = metadata.get("loras", "")
|
||||
lora_matches = re.findall(r"<lora:([^:]+):([^>]+)>", lora_stack)
|
||||
if not lora_matches:
|
||||
raise RecipeValidationError("No LoRAs found in the generation metadata")
|
||||
|
||||
loras_data = []
|
||||
base_model_counts: Dict[str, int] = {}
|
||||
|
||||
for name, strength in lora_matches:
|
||||
lora_info = await recipe_scanner.get_local_lora(name)
|
||||
lora_data = {
|
||||
"file_name": name,
|
||||
"strength": float(strength),
|
||||
"hash": (lora_info.get("sha256") or "").lower() if lora_info else "",
|
||||
"modelVersionId": lora_info.get("civitai", {}).get("id") if lora_info else 0,
|
||||
"modelName": lora_info.get("civitai", {}).get("model", {}).get("name") if lora_info else "",
|
||||
"modelVersionName": lora_info.get("civitai", {}).get("name") if lora_info else "",
|
||||
"isDeleted": False,
|
||||
"exclude": False,
|
||||
}
|
||||
loras_data.append(lora_data)
|
||||
|
||||
if lora_info and "base_model" in lora_info:
|
||||
base_model = lora_info["base_model"]
|
||||
base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1
|
||||
|
||||
recipe_name = self._derive_recipe_name(lora_matches)
|
||||
most_common_base_model = (
|
||||
max(base_model_counts.items(), key=lambda item: item[1])[0] if base_model_counts else ""
|
||||
)
|
||||
|
||||
recipe_data = {
|
||||
"id": recipe_id,
|
||||
"file_path": image_path,
|
||||
"title": recipe_name,
|
||||
"modified": time.time(),
|
||||
"created_date": time.time(),
|
||||
"base_model": most_common_base_model,
|
||||
"loras": loras_data,
|
||||
"checkpoint": metadata.get("checkpoint", ""),
|
||||
"gen_params": {
|
||||
key: value
|
||||
for key, value in metadata.items()
|
||||
if key not in ["checkpoint", "loras"]
|
||||
},
|
||||
"loras_stack": lora_stack,
|
||||
}
|
||||
|
||||
json_filename = f"{recipe_id}.recipe.json"
|
||||
json_path = os.path.join(recipes_dir, json_filename)
|
||||
with open(json_path, "w", encoding="utf-8") as file_obj:
|
||||
json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False)
|
||||
|
||||
self._exif_utils.append_recipe_metadata(image_path, recipe_data)
|
||||
await recipe_scanner.add_recipe(recipe_data)
|
||||
|
||||
return PersistenceResult(
|
||||
{
|
||||
"success": True,
|
||||
"recipe_id": recipe_id,
|
||||
"image_path": image_path,
|
||||
"json_path": json_path,
|
||||
"recipe_name": recipe_name,
|
||||
}
|
||||
)
|
||||
|
||||
# Helper methods ---------------------------------------------------
|
||||
|
||||
def _resolve_image_bytes(self, image_bytes: bytes | None, image_base64: str | None) -> bytes:
|
||||
if image_bytes is not None:
|
||||
return image_bytes
|
||||
if image_base64:
|
||||
try:
|
||||
payload = image_base64.split(",", 1)[1] if "," in image_base64 else image_base64
|
||||
return base64.b64decode(payload)
|
||||
except Exception as exc: # pragma: no cover - validation guard
|
||||
raise RecipeValidationError(f"Invalid base64 image data: {exc}") from exc
|
||||
raise RecipeValidationError("No image data provided")
|
||||
|
||||
def _normalise_lora_entry(self, lora: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
"file_name": lora.get("file_name", "")
|
||||
or (
|
||||
os.path.splitext(os.path.basename(lora.get("localPath", "")))[0]
|
||||
if lora.get("localPath")
|
||||
else ""
|
||||
),
|
||||
"hash": (lora.get("hash") or "").lower(),
|
||||
"strength": float(lora.get("weight", 1.0)),
|
||||
"modelVersionId": lora.get("id", 0),
|
||||
"modelName": lora.get("name", ""),
|
||||
"modelVersionName": lora.get("version", ""),
|
||||
"isDeleted": lora.get("isDeleted", False),
|
||||
"exclude": lora.get("exclude", False),
|
||||
}
|
||||
|
||||
async def _find_matching_recipes(
|
||||
self,
|
||||
recipe_scanner,
|
||||
fingerprint: str | None,
|
||||
*,
|
||||
exclude_id: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
if not fingerprint:
|
||||
return []
|
||||
matches = await recipe_scanner.find_recipes_by_fingerprint(fingerprint)
|
||||
if exclude_id and exclude_id in matches:
|
||||
matches.remove(exclude_id)
|
||||
return matches
|
||||
|
||||
def _derive_recipe_name(self, lora_matches: list[tuple[str, str]]) -> str:
|
||||
recipe_name_parts = [f"{name.strip()}-{float(strength):.2f}" for name, strength in lora_matches[:3]]
|
||||
recipe_name = "_".join(recipe_name_parts)
|
||||
return recipe_name or "recipe"
|
||||
105
py/services/recipes/sharing_service.py
Normal file
105
py/services/recipes/sharing_service.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Services handling recipe sharing and downloads."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
||||
from .errors import RecipeNotFoundError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SharingResult:
|
||||
"""Return payload for share operations."""
|
||||
|
||||
payload: dict[str, Any]
|
||||
status: int = 200
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DownloadInfo:
|
||||
"""Information required to stream a shared recipe file."""
|
||||
|
||||
file_path: str
|
||||
download_filename: str
|
||||
|
||||
|
||||
class RecipeSharingService:
|
||||
"""Prepare temporary recipe downloads with TTL cleanup."""
|
||||
|
||||
def __init__(self, *, ttl_seconds: int = 300, logger) -> None:
|
||||
self._ttl_seconds = ttl_seconds
|
||||
self._logger = logger
|
||||
self._shared_recipes: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
async def share_recipe(self, *, recipe_scanner, recipe_id: str) -> SharingResult:
|
||||
"""Prepare a temporary downloadable copy of a recipe image."""
|
||||
|
||||
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||
if not recipe:
|
||||
raise RecipeNotFoundError("Recipe not found")
|
||||
|
||||
image_path = recipe.get("file_path")
|
||||
if not image_path or not os.path.exists(image_path):
|
||||
raise RecipeNotFoundError("Recipe image not found")
|
||||
|
||||
ext = os.path.splitext(image_path)[1]
|
||||
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as temp_file:
|
||||
temp_path = temp_file.name
|
||||
|
||||
shutil.copy2(image_path, temp_path)
|
||||
timestamp = int(time.time())
|
||||
self._shared_recipes[recipe_id] = {
|
||||
"path": temp_path,
|
||||
"timestamp": timestamp,
|
||||
"expires": time.time() + self._ttl_seconds,
|
||||
}
|
||||
self._cleanup_shared_recipes()
|
||||
|
||||
safe_title = recipe.get("title", "").replace(" ", "_").lower()
|
||||
filename = f"recipe_{safe_title}{ext}" if safe_title else f"recipe_{recipe_id}{ext}"
|
||||
url_path = f"/api/recipe/{recipe_id}/share/download?t={timestamp}"
|
||||
return SharingResult({"success": True, "download_url": url_path, "filename": filename})
|
||||
|
||||
async def prepare_download(self, *, recipe_scanner, recipe_id: str) -> DownloadInfo:
|
||||
"""Return file path and filename for a prepared shared recipe."""
|
||||
|
||||
shared_info = self._shared_recipes.get(recipe_id)
|
||||
if not shared_info or time.time() > shared_info.get("expires", 0):
|
||||
self._cleanup_entry(recipe_id)
|
||||
raise RecipeNotFoundError("Shared recipe not found or expired")
|
||||
|
||||
file_path = shared_info["path"]
|
||||
if not os.path.exists(file_path):
|
||||
self._cleanup_entry(recipe_id)
|
||||
raise RecipeNotFoundError("Shared recipe file not found")
|
||||
|
||||
recipe = await recipe_scanner.get_recipe_by_id(recipe_id)
|
||||
filename_base = (
|
||||
f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}" if recipe else recipe_id
|
||||
)
|
||||
ext = os.path.splitext(file_path)[1]
|
||||
download_filename = f"{filename_base}{ext}"
|
||||
return DownloadInfo(file_path=file_path, download_filename=download_filename)
|
||||
|
||||
def _cleanup_shared_recipes(self) -> None:
|
||||
for recipe_id in list(self._shared_recipes.keys()):
|
||||
shared = self._shared_recipes.get(recipe_id)
|
||||
if not shared:
|
||||
continue
|
||||
if time.time() > shared.get("expires", 0):
|
||||
self._cleanup_entry(recipe_id)
|
||||
|
||||
def _cleanup_entry(self, recipe_id: str) -> None:
|
||||
shared_info = self._shared_recipes.pop(recipe_id, None)
|
||||
if not shared_info:
|
||||
return
|
||||
file_path = shared_info.get("path")
|
||||
if file_path and os.path.exists(file_path):
|
||||
try:
|
||||
os.unlink(file_path)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
self._logger.error("Error cleaning up shared recipe %s: %s", recipe_id, exc)
|
||||
47
py/services/tag_update_service.py
Normal file
47
py/services/tag_update_service.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Service for updating tag collections on metadata records."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from typing import Awaitable, Callable, Dict, List, Sequence
|
||||
|
||||
|
||||
class TagUpdateService:
|
||||
"""Encapsulate tag manipulation for models."""
|
||||
|
||||
def __init__(self, *, metadata_manager) -> None:
|
||||
self._metadata_manager = metadata_manager
|
||||
|
||||
async def add_tags(
|
||||
self,
|
||||
*,
|
||||
file_path: str,
|
||||
new_tags: Sequence[str],
|
||||
metadata_loader: Callable[[str], Awaitable[Dict[str, object]]],
|
||||
update_cache: Callable[[str, str, Dict[str, object]], Awaitable[bool]],
|
||||
) -> List[str]:
|
||||
"""Add tags to a metadata entry while keeping case-insensitive uniqueness."""
|
||||
|
||||
base, _ = os.path.splitext(file_path)
|
||||
metadata_path = f"{base}.metadata.json"
|
||||
metadata = await metadata_loader(metadata_path)
|
||||
|
||||
existing_tags = list(metadata.get("tags", []))
|
||||
existing_lower = [tag.lower() for tag in existing_tags]
|
||||
|
||||
tags_added: List[str] = []
|
||||
for tag in new_tags:
|
||||
if isinstance(tag, str) and tag.strip():
|
||||
normalized = tag.strip()
|
||||
if normalized.lower() not in existing_lower:
|
||||
existing_tags.append(normalized)
|
||||
existing_lower.append(normalized.lower())
|
||||
tags_added.append(normalized)
|
||||
|
||||
metadata["tags"] = existing_tags
|
||||
await self._metadata_manager.save_metadata(file_path, metadata)
|
||||
await update_cache(file_path, file_path, metadata)
|
||||
|
||||
return existing_tags
|
||||
|
||||
37
py/services/use_cases/__init__.py
Normal file
37
py/services/use_cases/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Application-level orchestration services for model routes."""
|
||||
|
||||
from .auto_organize_use_case import (
|
||||
AutoOrganizeInProgressError,
|
||||
AutoOrganizeUseCase,
|
||||
)
|
||||
from .bulk_metadata_refresh_use_case import (
|
||||
BulkMetadataRefreshUseCase,
|
||||
MetadataRefreshProgressReporter,
|
||||
)
|
||||
from .download_model_use_case import (
|
||||
DownloadModelEarlyAccessError,
|
||||
DownloadModelUseCase,
|
||||
DownloadModelValidationError,
|
||||
)
|
||||
from .example_images import (
|
||||
DownloadExampleImagesConfigurationError,
|
||||
DownloadExampleImagesInProgressError,
|
||||
DownloadExampleImagesUseCase,
|
||||
ImportExampleImagesUseCase,
|
||||
ImportExampleImagesValidationError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AutoOrganizeInProgressError",
|
||||
"AutoOrganizeUseCase",
|
||||
"BulkMetadataRefreshUseCase",
|
||||
"MetadataRefreshProgressReporter",
|
||||
"DownloadModelEarlyAccessError",
|
||||
"DownloadModelUseCase",
|
||||
"DownloadModelValidationError",
|
||||
"DownloadExampleImagesConfigurationError",
|
||||
"DownloadExampleImagesInProgressError",
|
||||
"DownloadExampleImagesUseCase",
|
||||
"ImportExampleImagesUseCase",
|
||||
"ImportExampleImagesValidationError",
|
||||
]
|
||||
56
py/services/use_cases/auto_organize_use_case.py
Normal file
56
py/services/use_cases/auto_organize_use_case.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Auto-organize use case orchestrating concurrency and progress handling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Protocol, Sequence
|
||||
|
||||
from ..model_file_service import AutoOrganizeResult, ModelFileService, ProgressCallback
|
||||
|
||||
|
||||
class AutoOrganizeLockProvider(Protocol):
|
||||
"""Minimal protocol for objects exposing auto-organize locking primitives."""
|
||||
|
||||
def is_auto_organize_running(self) -> bool:
|
||||
"""Return ``True`` when an auto-organize operation is in-flight."""
|
||||
|
||||
async def get_auto_organize_lock(self) -> asyncio.Lock:
|
||||
"""Return the asyncio lock guarding auto-organize operations."""
|
||||
|
||||
|
||||
class AutoOrganizeInProgressError(RuntimeError):
|
||||
"""Raised when an auto-organize run is already active."""
|
||||
|
||||
|
||||
class AutoOrganizeUseCase:
|
||||
"""Coordinate auto-organize execution behind a shared lock."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
file_service: ModelFileService,
|
||||
lock_provider: AutoOrganizeLockProvider,
|
||||
) -> None:
|
||||
self._file_service = file_service
|
||||
self._lock_provider = lock_provider
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
*,
|
||||
file_paths: Optional[Sequence[str]] = None,
|
||||
progress_callback: Optional[ProgressCallback] = None,
|
||||
) -> AutoOrganizeResult:
|
||||
"""Run the auto-organize routine guarded by a shared lock."""
|
||||
|
||||
if self._lock_provider.is_auto_organize_running():
|
||||
raise AutoOrganizeInProgressError("Auto-organize is already running")
|
||||
|
||||
lock = await self._lock_provider.get_auto_organize_lock()
|
||||
if lock.locked():
|
||||
raise AutoOrganizeInProgressError("Auto-organize is already running")
|
||||
|
||||
async with lock:
|
||||
return await self._file_service.auto_organize_models(
|
||||
file_paths=list(file_paths) if file_paths is not None else None,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
122
py/services/use_cases/bulk_metadata_refresh_use_case.py
Normal file
122
py/services/use_cases/bulk_metadata_refresh_use_case.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Use case encapsulating the bulk metadata refresh orchestration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Protocol, Sequence
|
||||
|
||||
from ..metadata_sync_service import MetadataSyncService
|
||||
|
||||
|
||||
class MetadataRefreshProgressReporter(Protocol):
|
||||
"""Protocol for progress reporters used during metadata refresh."""
|
||||
|
||||
async def on_progress(self, payload: Dict[str, Any]) -> None:
|
||||
"""Handle a metadata refresh progress update."""
|
||||
|
||||
|
||||
class BulkMetadataRefreshUseCase:
|
||||
"""Coordinate bulk metadata refreshes with progress emission."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
service,
|
||||
metadata_sync: MetadataSyncService,
|
||||
settings_service,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._service = service
|
||||
self._metadata_sync = metadata_sync
|
||||
self._settings = settings_service
|
||||
self._logger = logger or logging.getLogger(__name__)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
*,
|
||||
progress_callback: Optional[MetadataRefreshProgressReporter] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Refresh metadata for all qualifying models."""
|
||||
|
||||
cache = await self._service.scanner.get_cached_data()
|
||||
total_models = len(cache.raw_data)
|
||||
|
||||
enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False)
|
||||
to_process: Sequence[Dict[str, Any]] = [
|
||||
model
|
||||
for model in cache.raw_data
|
||||
if model.get("sha256")
|
||||
and (not model.get("civitai") or not model["civitai"].get("id"))
|
||||
and (
|
||||
(enable_metadata_archive_db and not model.get("db_checked", False))
|
||||
or (not enable_metadata_archive_db and model.get("from_civitai") is True)
|
||||
)
|
||||
]
|
||||
|
||||
total_to_process = len(to_process)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
async def emit(status: str, **extra: Any) -> None:
|
||||
if progress_callback is None:
|
||||
return
|
||||
payload = {"status": status, "total": total_to_process, "processed": processed, "success": success}
|
||||
payload.update(extra)
|
||||
await progress_callback.on_progress(payload)
|
||||
|
||||
await emit("started")
|
||||
|
||||
for model in to_process:
|
||||
try:
|
||||
original_name = model.get("model_name")
|
||||
result, _ = await self._metadata_sync.fetch_and_update_model(
|
||||
sha256=model["sha256"],
|
||||
file_path=model["file_path"],
|
||||
model_data=model,
|
||||
update_cache_func=self._service.scanner.update_single_model_cache,
|
||||
)
|
||||
if result:
|
||||
success += 1
|
||||
if original_name != model.get("model_name"):
|
||||
needs_resort = True
|
||||
processed += 1
|
||||
await emit(
|
||||
"processing",
|
||||
processed=processed,
|
||||
success=success,
|
||||
current_name=model.get("model_name", "Unknown"),
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - logging path
|
||||
processed += 1
|
||||
self._logger.error(
|
||||
"Error fetching CivitAI data for %s: %s",
|
||||
model.get("file_path"),
|
||||
exc,
|
||||
)
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort()
|
||||
|
||||
await emit("completed", processed=processed, success=success)
|
||||
|
||||
message = (
|
||||
"Successfully updated "
|
||||
f"{success} of {processed} processed {self._service.model_type}s (total: {total_models})"
|
||||
)
|
||||
|
||||
return {"success": True, "message": message, "processed": processed, "updated": success, "total": total_models}
|
||||
|
||||
async def execute_with_error_handling(
|
||||
self,
|
||||
*,
|
||||
progress_callback: Optional[MetadataRefreshProgressReporter] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Wrapper providing progress notification on unexpected failures."""
|
||||
|
||||
try:
|
||||
return await self.execute(progress_callback=progress_callback)
|
||||
except Exception as exc:
|
||||
if progress_callback is not None:
|
||||
await progress_callback.on_progress({"status": "error", "error": str(exc)})
|
||||
raise
|
||||
37
py/services/use_cases/download_model_use_case.py
Normal file
37
py/services/use_cases/download_model_use_case.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Use case for scheduling model downloads with consistent error handling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from ..download_coordinator import DownloadCoordinator
|
||||
|
||||
|
||||
class DownloadModelValidationError(ValueError):
|
||||
"""Raised when incoming payload validation fails."""
|
||||
|
||||
|
||||
class DownloadModelEarlyAccessError(RuntimeError):
|
||||
"""Raised when the download is gated behind Civitai early access."""
|
||||
|
||||
|
||||
class DownloadModelUseCase:
|
||||
"""Coordinate download scheduling through the coordinator service."""
|
||||
|
||||
def __init__(self, *, download_coordinator: DownloadCoordinator) -> None:
|
||||
self._download_coordinator = download_coordinator
|
||||
|
||||
async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Schedule a download and normalize error conditions."""
|
||||
|
||||
try:
|
||||
return await self._download_coordinator.schedule_download(payload)
|
||||
except ValueError as exc:
|
||||
raise DownloadModelValidationError(str(exc)) from exc
|
||||
except Exception as exc: # pragma: no cover - defensive logging path
|
||||
message = str(exc)
|
||||
if "401" in message:
|
||||
raise DownloadModelEarlyAccessError(
|
||||
"Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
|
||||
) from exc
|
||||
raise
|
||||
19
py/services/use_cases/example_images/__init__.py
Normal file
19
py/services/use_cases/example_images/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Example image specific use case exports."""
|
||||
|
||||
from .download_example_images_use_case import (
|
||||
DownloadExampleImagesUseCase,
|
||||
DownloadExampleImagesInProgressError,
|
||||
DownloadExampleImagesConfigurationError,
|
||||
)
|
||||
from .import_example_images_use_case import (
|
||||
ImportExampleImagesUseCase,
|
||||
ImportExampleImagesValidationError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DownloadExampleImagesUseCase",
|
||||
"DownloadExampleImagesInProgressError",
|
||||
"DownloadExampleImagesConfigurationError",
|
||||
"ImportExampleImagesUseCase",
|
||||
"ImportExampleImagesValidationError",
|
||||
]
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Use case coordinating example image downloads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from ....utils.example_images_download_manager import (
|
||||
DownloadConfigurationError,
|
||||
DownloadInProgressError,
|
||||
ExampleImagesDownloadError,
|
||||
)
|
||||
|
||||
|
||||
class DownloadExampleImagesInProgressError(RuntimeError):
|
||||
"""Raised when a download is already running."""
|
||||
|
||||
def __init__(self, progress: Dict[str, Any]) -> None:
|
||||
super().__init__("Download already in progress")
|
||||
self.progress = progress
|
||||
|
||||
|
||||
class DownloadExampleImagesConfigurationError(ValueError):
|
||||
"""Raised when settings prevent downloads from starting."""
|
||||
|
||||
|
||||
class DownloadExampleImagesUseCase:
|
||||
"""Validate payloads and trigger the download manager."""
|
||||
|
||||
def __init__(self, *, download_manager) -> None:
|
||||
self._download_manager = download_manager
|
||||
|
||||
async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Start a download and translate manager errors."""
|
||||
|
||||
try:
|
||||
return await self._download_manager.start_download(payload)
|
||||
except DownloadInProgressError as exc:
|
||||
raise DownloadExampleImagesInProgressError(exc.progress_snapshot) from exc
|
||||
except DownloadConfigurationError as exc:
|
||||
raise DownloadExampleImagesConfigurationError(str(exc)) from exc
|
||||
except ExampleImagesDownloadError:
|
||||
raise
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Use case for importing example images."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from contextlib import suppress
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from ....utils.example_images_processor import (
|
||||
ExampleImagesImportError,
|
||||
ExampleImagesProcessor,
|
||||
ExampleImagesValidationError,
|
||||
)
|
||||
|
||||
|
||||
class ImportExampleImagesValidationError(ValueError):
|
||||
"""Raised when request validation fails."""
|
||||
|
||||
|
||||
class ImportExampleImagesUseCase:
|
||||
"""Parse upload payloads and delegate to the processor service."""
|
||||
|
||||
def __init__(self, *, processor: ExampleImagesProcessor) -> None:
|
||||
self._processor = processor
|
||||
|
||||
async def execute(self, request: web.Request) -> Dict[str, Any]:
|
||||
model_hash: str | None = None
|
||||
files_to_import: List[str] = []
|
||||
temp_files: List[str] = []
|
||||
|
||||
try:
|
||||
if request.content_type and "multipart/form-data" in request.content_type:
|
||||
reader = await request.multipart()
|
||||
|
||||
first_field = await reader.next()
|
||||
if first_field and first_field.name == "model_hash":
|
||||
model_hash = await first_field.text()
|
||||
else:
|
||||
# Support clients that send files first and hash later
|
||||
if first_field is not None:
|
||||
await self._collect_upload_file(first_field, files_to_import, temp_files)
|
||||
|
||||
async for field in reader:
|
||||
if field.name == "model_hash" and not model_hash:
|
||||
model_hash = await field.text()
|
||||
elif field.name == "files":
|
||||
await self._collect_upload_file(field, files_to_import, temp_files)
|
||||
else:
|
||||
data = await request.json()
|
||||
model_hash = data.get("model_hash")
|
||||
files_to_import = list(data.get("file_paths", []))
|
||||
|
||||
result = await self._processor.import_images(model_hash, files_to_import)
|
||||
return result
|
||||
except ExampleImagesValidationError as exc:
|
||||
raise ImportExampleImagesValidationError(str(exc)) from exc
|
||||
except ExampleImagesImportError:
|
||||
raise
|
||||
finally:
|
||||
for path in temp_files:
|
||||
with suppress(Exception):
|
||||
os.remove(path)
|
||||
|
||||
async def _collect_upload_file(
|
||||
self,
|
||||
field: Any,
|
||||
files_to_import: List[str],
|
||||
temp_files: List[str],
|
||||
) -> None:
|
||||
"""Persist an uploaded file to disk and add it to the import list."""
|
||||
|
||||
filename = field.filename or "upload"
|
||||
file_ext = os.path.splitext(filename)[1].lower()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
|
||||
temp_files.append(tmp_file.name)
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
tmp_file.write(chunk)
|
||||
|
||||
files_to_import.append(tmp_file.name)
|
||||
@@ -1,11 +1,29 @@
|
||||
from typing import Dict, Any
|
||||
"""Progress callback implementations backed by the shared WebSocket manager."""
|
||||
|
||||
from typing import Any, Dict, Protocol
|
||||
|
||||
from .model_file_service import ProgressCallback
|
||||
from .websocket_manager import ws_manager
|
||||
|
||||
|
||||
class WebSocketProgressCallback(ProgressCallback):
|
||||
"""WebSocket implementation of progress callback"""
|
||||
|
||||
class ProgressReporter(Protocol):
|
||||
"""Protocol representing an async progress callback."""
|
||||
|
||||
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
|
||||
"""Send progress data via WebSocket"""
|
||||
await ws_manager.broadcast_auto_organize_progress(progress_data)
|
||||
"""Handle a progress update payload."""
|
||||
|
||||
|
||||
class WebSocketProgressCallback(ProgressCallback):
|
||||
"""WebSocket implementation of progress callback."""
|
||||
|
||||
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
|
||||
"""Send progress data via WebSocket."""
|
||||
await ws_manager.broadcast_auto_organize_progress(progress_data)
|
||||
|
||||
|
||||
class WebSocketBroadcastCallback:
|
||||
"""Generic WebSocket progress callback broadcasting to all clients."""
|
||||
|
||||
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
|
||||
"""Send the provided payload to all connected clients."""
|
||||
await ws_manager.broadcast(progress_data)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,19 +1,39 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
|
||||
from ..recipes.constants import GEN_PARAM_KEYS
|
||||
from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider
|
||||
from ..services.metadata_sync_service import MetadataSyncService
|
||||
from ..services.preview_asset_service import PreviewAssetService
|
||||
from ..services.settings_manager import settings
|
||||
from ..services.downloader import get_downloader
|
||||
from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..recipes.constants import GEN_PARAM_KEYS
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_preview_service = PreviewAssetService(
|
||||
metadata_manager=MetadataManager,
|
||||
downloader_factory=get_downloader,
|
||||
exif_utils=ExifUtils,
|
||||
)
|
||||
|
||||
_metadata_sync_service = MetadataSyncService(
|
||||
metadata_manager=MetadataManager,
|
||||
preview_service=_preview_service,
|
||||
settings=settings,
|
||||
default_metadata_provider_factory=get_default_metadata_provider,
|
||||
metadata_provider_selector=get_metadata_provider,
|
||||
)
|
||||
|
||||
|
||||
class MetadataUpdater:
|
||||
"""Handles updating model metadata related to example images"""
|
||||
|
||||
@staticmethod
|
||||
async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner):
|
||||
async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner, progress: dict | None = None):
|
||||
"""Refresh model metadata from CivitAI
|
||||
|
||||
Args:
|
||||
@@ -25,8 +45,6 @@ class MetadataUpdater:
|
||||
Returns:
|
||||
bool: True if metadata was successfully refreshed, False otherwise
|
||||
"""
|
||||
from ..utils.example_images_download_manager import download_progress
|
||||
|
||||
try:
|
||||
# Find the model in the scanner cache
|
||||
cache = await scanner.get_cached_data()
|
||||
@@ -47,17 +65,17 @@ class MetadataUpdater:
|
||||
return False
|
||||
|
||||
# Track that we're refreshing this model
|
||||
download_progress['refreshed_models'].add(model_hash)
|
||||
if progress is not None:
|
||||
progress['refreshed_models'].add(model_hash)
|
||||
|
||||
# Use ModelRouteUtils to refresh metadata
|
||||
async def update_cache_func(old_path, new_path, metadata):
|
||||
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
||||
|
||||
success, error = await ModelRouteUtils.fetch_and_update_model(
|
||||
model_hash,
|
||||
file_path,
|
||||
model_data,
|
||||
update_cache_func
|
||||
success, error = await _metadata_sync_service.fetch_and_update_model(
|
||||
sha256=model_hash,
|
||||
file_path=file_path,
|
||||
model_data=model_data,
|
||||
update_cache_func=update_cache_func,
|
||||
)
|
||||
|
||||
if success:
|
||||
@@ -66,12 +84,13 @@ class MetadataUpdater:
|
||||
else:
|
||||
logger.warning(f"Failed to refresh metadata for {model_name}, {error}")
|
||||
return False
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error refreshing metadata for {model_name}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
download_progress['errors'].append(error_msg)
|
||||
download_progress['last_error'] = error_msg
|
||||
if progress is not None:
|
||||
progress['errors'].append(error_msg)
|
||||
progress['last_error'] = error_msg
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import random
|
||||
import string
|
||||
from aiohttp import web
|
||||
@@ -13,6 +12,14 @@ from ..utils.metadata_manager import MetadataManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExampleImagesImportError(RuntimeError):
|
||||
"""Base error for example image import operations."""
|
||||
|
||||
|
||||
class ExampleImagesValidationError(ExampleImagesImportError):
|
||||
"""Raised when input validation fails."""
|
||||
|
||||
class ExampleImagesProcessor:
|
||||
"""Processes and manipulates example images"""
|
||||
|
||||
@@ -299,90 +306,29 @@ class ExampleImagesProcessor:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def import_images(request):
|
||||
"""
|
||||
Import local example images
|
||||
|
||||
Accepts:
|
||||
- multipart/form-data form with model_hash and files fields
|
||||
or
|
||||
- JSON request with model_hash and file_paths
|
||||
|
||||
Returns:
|
||||
- Success status and list of imported files
|
||||
"""
|
||||
async def import_images(model_hash: str, files_to_import: list[str]):
|
||||
"""Import local example images for a model."""
|
||||
|
||||
if not model_hash:
|
||||
raise ExampleImagesValidationError('Missing model_hash parameter')
|
||||
|
||||
if not files_to_import:
|
||||
raise ExampleImagesValidationError('No files provided to import')
|
||||
|
||||
try:
|
||||
model_hash = None
|
||||
files_to_import = []
|
||||
temp_files_to_cleanup = []
|
||||
|
||||
# Check if it's a multipart form-data request (direct file upload)
|
||||
if request.content_type and 'multipart/form-data' in request.content_type:
|
||||
reader = await request.multipart()
|
||||
|
||||
# First get model_hash
|
||||
field = await reader.next()
|
||||
if field.name == 'model_hash':
|
||||
model_hash = await field.text()
|
||||
|
||||
# Then process all files
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
if field.name == 'files':
|
||||
# Create a temporary file with appropriate suffix for type detection
|
||||
file_name = field.filename
|
||||
file_ext = os.path.splitext(file_name)[1].lower()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file:
|
||||
temp_path = tmp_file.name
|
||||
temp_files_to_cleanup.append(temp_path) # Track for cleanup
|
||||
|
||||
# Write chunks to the temporary file
|
||||
while True:
|
||||
chunk = await field.read_chunk()
|
||||
if not chunk:
|
||||
break
|
||||
tmp_file.write(chunk)
|
||||
|
||||
# Add to the list of files to process
|
||||
files_to_import.append(temp_path)
|
||||
else:
|
||||
# Parse JSON request (legacy method using file paths)
|
||||
data = await request.json()
|
||||
model_hash = data.get('model_hash')
|
||||
files_to_import = data.get('file_paths', [])
|
||||
|
||||
if not model_hash:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing model_hash parameter'
|
||||
}, status=400)
|
||||
|
||||
if not files_to_import:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No files provided to import'
|
||||
}, status=400)
|
||||
|
||||
# Get example images path
|
||||
example_images_path = settings.get('example_images_path')
|
||||
if not example_images_path:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No example images path configured'
|
||||
}, status=400)
|
||||
|
||||
raise ExampleImagesValidationError('No example images path configured')
|
||||
|
||||
# Find the model and get current metadata
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
|
||||
|
||||
model_data = None
|
||||
scanner = None
|
||||
|
||||
|
||||
# Check both scanners to find the model
|
||||
for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]:
|
||||
cache = await scan_obj.get_cached_data()
|
||||
@@ -393,21 +339,20 @@ class ExampleImagesProcessor:
|
||||
break
|
||||
if model_data:
|
||||
break
|
||||
|
||||
|
||||
if not model_data:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f"Model with hash {model_hash} not found in cache"
|
||||
}, status=404)
|
||||
|
||||
raise ExampleImagesImportError(
|
||||
f"Model with hash {model_hash} not found in cache"
|
||||
)
|
||||
|
||||
# Create model folder
|
||||
model_folder = os.path.join(example_images_path, model_hash)
|
||||
os.makedirs(model_folder, exist_ok=True)
|
||||
|
||||
|
||||
imported_files = []
|
||||
errors = []
|
||||
newly_imported_paths = []
|
||||
|
||||
|
||||
# Process each file path
|
||||
for file_path in files_to_import:
|
||||
try:
|
||||
@@ -415,26 +360,26 @@ class ExampleImagesProcessor:
|
||||
if not os.path.isfile(file_path):
|
||||
errors.append(f"File not found: {file_path}")
|
||||
continue
|
||||
|
||||
|
||||
# Check if file type is supported
|
||||
file_ext = os.path.splitext(file_path)[1].lower()
|
||||
if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
|
||||
if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or
|
||||
file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']):
|
||||
errors.append(f"Unsupported file type: {file_path}")
|
||||
continue
|
||||
|
||||
|
||||
# Generate new filename using short ID instead of UUID
|
||||
short_id = ExampleImagesProcessor.generate_short_id()
|
||||
new_filename = f"custom_{short_id}{file_ext}"
|
||||
|
||||
|
||||
dest_path = os.path.join(model_folder, new_filename)
|
||||
|
||||
|
||||
# Copy the file
|
||||
import shutil
|
||||
shutil.copy2(file_path, dest_path)
|
||||
# Store both the dest_path and the short_id
|
||||
newly_imported_paths.append((dest_path, short_id))
|
||||
|
||||
|
||||
# Add to imported files list
|
||||
imported_files.append({
|
||||
'name': new_filename,
|
||||
@@ -444,39 +389,31 @@ class ExampleImagesProcessor:
|
||||
})
|
||||
except Exception as e:
|
||||
errors.append(f"Error importing {file_path}: {str(e)}")
|
||||
|
||||
|
||||
# Update metadata with new example images
|
||||
regular_images, custom_images = await MetadataUpdater.update_metadata_after_import(
|
||||
model_hash,
|
||||
model_hash,
|
||||
model_data,
|
||||
scanner,
|
||||
newly_imported_paths
|
||||
)
|
||||
|
||||
return web.json_response({
|
||||
|
||||
return {
|
||||
'success': len(imported_files) > 0,
|
||||
'message': f'Successfully imported {len(imported_files)} files' +
|
||||
'message': f'Successfully imported {len(imported_files)} files' +
|
||||
(f' with {len(errors)} errors' if errors else ''),
|
||||
'files': imported_files,
|
||||
'errors': errors,
|
||||
'regular_images': regular_images,
|
||||
'custom_images': custom_images,
|
||||
"model_file_path": model_data.get('file_path', ''),
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
except ExampleImagesImportError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to import example images: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
finally:
|
||||
# Clean up temporary files
|
||||
for temp_file in temp_files_to_cleanup:
|
||||
try:
|
||||
os.remove(temp_file)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove temporary file {temp_file}: {e}")
|
||||
raise ExampleImagesImportError(str(e)) from e
|
||||
|
||||
@staticmethod
|
||||
async def delete_custom_image(request):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Callable, Awaitable
|
||||
from typing import Dict, Callable, Awaitable
|
||||
from aiohttp import web
|
||||
from datetime import datetime
|
||||
|
||||
@@ -18,7 +18,7 @@ from ..services.settings_manager import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: retire this class
|
||||
class ModelRouteUtils:
|
||||
"""Shared utilities for model routes (LoRAs, Checkpoints, etc.)"""
|
||||
|
||||
@@ -284,104 +284,6 @@ class ModelRouteUtils:
|
||||
]
|
||||
return {k: data[k] for k in fields if k in data}
|
||||
|
||||
@staticmethod
|
||||
async def delete_model_files(target_dir: str, file_name: str) -> List[str]:
|
||||
"""Delete model and associated files
|
||||
|
||||
Args:
|
||||
target_dir: Directory containing the model files
|
||||
file_name: Base name of the model file without extension
|
||||
|
||||
Returns:
|
||||
List of deleted file paths
|
||||
"""
|
||||
patterns = [
|
||||
f"{file_name}.safetensors", # Required
|
||||
f"{file_name}.metadata.json",
|
||||
]
|
||||
|
||||
# Add all preview file extensions
|
||||
for ext in PREVIEW_EXTENSIONS:
|
||||
patterns.append(f"{file_name}{ext}")
|
||||
|
||||
deleted = []
|
||||
main_file = patterns[0]
|
||||
main_path = os.path.join(target_dir, main_file).replace(os.sep, '/')
|
||||
|
||||
if os.path.exists(main_path):
|
||||
# Delete file
|
||||
os.remove(main_path)
|
||||
deleted.append(main_path)
|
||||
else:
|
||||
logger.warning(f"Model file not found: {main_file}")
|
||||
|
||||
# Delete optional files
|
||||
for pattern in patterns[1:]:
|
||||
path = os.path.join(target_dir, pattern)
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
os.remove(path)
|
||||
deleted.append(pattern)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete {pattern}: {e}")
|
||||
|
||||
return deleted
|
||||
|
||||
@staticmethod
|
||||
def get_multipart_ext(filename):
|
||||
"""Get extension that may have multiple parts like .metadata.json or .metadata.json.bak"""
|
||||
parts = filename.split(".")
|
||||
if len(parts) == 3: # If contains 2-part extension
|
||||
return "." + ".".join(parts[-2:]) # Take the last two parts, like ".metadata.json"
|
||||
elif len(parts) >= 4: # If contains 3-part or more extensions
|
||||
return "." + ".".join(parts[-3:]) # Take the last three parts, like ".metadata.json.bak"
|
||||
return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors"
|
||||
|
||||
# New common endpoint handlers
|
||||
|
||||
@staticmethod
|
||||
async def handle_delete_model(request: web.Request, scanner) -> web.Response:
|
||||
"""Handle model deletion request
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
scanner: The model scanner instance with cache management methods
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
return web.Response(text='Model path is required', status=400)
|
||||
|
||||
target_dir = os.path.dirname(file_path)
|
||||
file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
deleted_files = await ModelRouteUtils.delete_model_files(
|
||||
target_dir,
|
||||
file_name
|
||||
)
|
||||
|
||||
# Remove from cache
|
||||
cache = await scanner.get_cached_data()
|
||||
cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path]
|
||||
await cache.resort()
|
||||
|
||||
# Update hash index if available
|
||||
if hasattr(scanner, '_hash_index') and scanner._hash_index:
|
||||
scanner._hash_index.remove_by_path(file_path)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'deleted_files': deleted_files
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting model: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_fetch_civitai(request: web.Request, scanner) -> web.Response:
|
||||
"""Handle CivitAI metadata fetch request
|
||||
@@ -544,64 +446,6 @@ class ModelRouteUtils:
|
||||
logger.error(f"Error replacing preview: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_exclude_model(request: web.Request, scanner) -> web.Response:
|
||||
"""Handle model exclusion request
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
scanner: The model scanner instance with cache management methods
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
if not file_path:
|
||||
return web.Response(text='Model path is required', status=400)
|
||||
|
||||
# Update metadata to mark as excluded
|
||||
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
metadata['exclude'] = True
|
||||
|
||||
# Save updated metadata
|
||||
await MetadataManager.save_metadata(file_path, metadata)
|
||||
|
||||
# Update cache
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
# Find and remove model from cache
|
||||
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||
if model_to_remove:
|
||||
# Update tags count
|
||||
for tag in model_to_remove.get('tags', []):
|
||||
if tag in scanner._tags_count:
|
||||
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
|
||||
if scanner._tags_count[tag] == 0:
|
||||
del scanner._tags_count[tag]
|
||||
|
||||
# Remove from hash index if available
|
||||
if hasattr(scanner, '_hash_index') and scanner._hash_index:
|
||||
scanner._hash_index.remove_by_path(file_path)
|
||||
|
||||
# Remove from cache data
|
||||
cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path]
|
||||
await cache.resort()
|
||||
|
||||
# Add to excluded models list
|
||||
scanner._excluded_models.append(file_path)
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'message': f"Model {os.path.basename(file_path)} excluded"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error excluding model: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_download_model(request: web.Request) -> web.Response:
|
||||
"""Handle model download request"""
|
||||
@@ -755,44 +599,6 @@ class ModelRouteUtils:
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_bulk_delete_models(request: web.Request, scanner) -> web.Response:
|
||||
"""Handle bulk deletion of models
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
scanner: The model scanner instance with cache management methods
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_paths = data.get('file_paths', [])
|
||||
|
||||
if not file_paths:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No file paths provided for deletion'
|
||||
}, status=400)
|
||||
|
||||
# Use the scanner's bulk delete method to handle all cache and file operations
|
||||
result = await scanner.bulk_delete_models(file_paths)
|
||||
|
||||
return web.json_response({
|
||||
'success': result.get('success', False),
|
||||
'total_deleted': result.get('total_deleted', 0),
|
||||
'total_attempted': result.get('total_attempted', len(file_paths)),
|
||||
'results': result.get('results', [])
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in bulk delete: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_relink_civitai(request: web.Request, scanner) -> web.Response:
|
||||
"""Handle CivitAI metadata re-linking request by model ID and/or version ID
|
||||
@@ -948,137 +754,6 @@ class ModelRouteUtils:
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_rename_model(request: web.Request, scanner) -> web.Response:
|
||||
"""Handle renaming a model file and its associated files
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
scanner: The model scanner instance
|
||||
|
||||
Returns:
|
||||
web.Response: The HTTP response
|
||||
"""
|
||||
try:
|
||||
data = await request.json()
|
||||
file_path = data.get('file_path')
|
||||
new_file_name = data.get('new_file_name')
|
||||
|
||||
if not file_path or not new_file_name:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'File path and new file name are required'
|
||||
}, status=400)
|
||||
|
||||
# Validate the new file name (no path separators or invalid characters)
|
||||
invalid_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|']
|
||||
if any(char in new_file_name for char in invalid_chars):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Invalid characters in file name'
|
||||
}, status=400)
|
||||
|
||||
# Get the directory and current file name
|
||||
target_dir = os.path.dirname(file_path)
|
||||
old_file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
|
||||
# Check if the target file already exists
|
||||
new_file_path = os.path.join(target_dir, f"{new_file_name}.safetensors").replace(os.sep, '/')
|
||||
if os.path.exists(new_file_path):
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'A file with this name already exists'
|
||||
}, status=400)
|
||||
|
||||
# Define the patterns for associated files
|
||||
patterns = [
|
||||
f"{old_file_name}.safetensors", # Required
|
||||
f"{old_file_name}.metadata.json",
|
||||
f"{old_file_name}.metadata.json.bak",
|
||||
]
|
||||
|
||||
# Add all preview file extensions
|
||||
for ext in PREVIEW_EXTENSIONS:
|
||||
patterns.append(f"{old_file_name}{ext}")
|
||||
|
||||
# Find all matching files
|
||||
existing_files = []
|
||||
for pattern in patterns:
|
||||
path = os.path.join(target_dir, pattern)
|
||||
if os.path.exists(path):
|
||||
existing_files.append((path, pattern))
|
||||
|
||||
# Get the hash from the main file to update hash index
|
||||
hash_value = None
|
||||
metadata = None
|
||||
metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json")
|
||||
|
||||
if os.path.exists(metadata_path):
|
||||
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||
hash_value = metadata.get('sha256')
|
||||
logger.info(f"hash_value: {hash_value}, metadata_path: {metadata_path}, metadata: {metadata}")
|
||||
# Rename all files
|
||||
renamed_files = []
|
||||
new_metadata_path = None
|
||||
new_preview = None
|
||||
|
||||
for old_path, pattern in existing_files:
|
||||
# Get the file extension like .safetensors or .metadata.json
|
||||
ext = ModelRouteUtils.get_multipart_ext(pattern)
|
||||
|
||||
# Create the new path
|
||||
new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/')
|
||||
|
||||
# Rename the file
|
||||
os.rename(old_path, new_path)
|
||||
renamed_files.append(new_path)
|
||||
|
||||
# Keep track of metadata path for later update
|
||||
if ext == '.metadata.json':
|
||||
new_metadata_path = new_path
|
||||
|
||||
# Update the metadata file with new file name and paths
|
||||
if new_metadata_path and metadata:
|
||||
# Update file_name, file_path and preview_url in metadata
|
||||
metadata['file_name'] = new_file_name
|
||||
metadata['file_path'] = new_file_path
|
||||
|
||||
# Update preview_url if it exists
|
||||
if 'preview_url' in metadata and metadata['preview_url']:
|
||||
old_preview = metadata['preview_url']
|
||||
ext = ModelRouteUtils.get_multipart_ext(old_preview)
|
||||
new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/')
|
||||
metadata['preview_url'] = new_preview
|
||||
|
||||
# Save updated metadata
|
||||
await MetadataManager.save_metadata(new_file_path, metadata)
|
||||
|
||||
# Update the scanner cache
|
||||
if metadata:
|
||||
await scanner.update_single_model_cache(file_path, new_file_path, metadata)
|
||||
|
||||
# Update recipe files and cache if hash is available and recipe_scanner exists
|
||||
if hash_value and hasattr(scanner, 'update_lora_filename_by_hash'):
|
||||
recipe_scanner = await ServiceRegistry.get_recipe_scanner()
|
||||
if recipe_scanner:
|
||||
recipes_updated, cache_updated = await recipe_scanner.update_lora_filename_by_hash(hash_value, new_file_name)
|
||||
logger.info(f"Updated {recipes_updated} recipe files and {cache_updated} cache entries for renamed model")
|
||||
|
||||
return web.json_response({
|
||||
'success': True,
|
||||
'new_file_path': new_file_path,
|
||||
'new_preview_path': config.get_preview_static_url(new_preview),
|
||||
'renamed_files': renamed_files,
|
||||
'reload_required': False
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error renaming model: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
|
||||
@staticmethod
|
||||
async def handle_save_metadata(request: web.Request, scanner) -> web.Response:
|
||||
"""Handle saving metadata updates
|
||||
|
||||
@@ -4,5 +4,8 @@ testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
# Register async marker for coroutine-style tests
|
||||
markers =
|
||||
asyncio: execute test within asyncio event loop
|
||||
# Skip problematic directories to avoid import conflicts
|
||||
norecursedirs = .git .tox dist build *.egg __pycache__ py
|
||||
50
run_tests.py
50
run_tests.py
@@ -1,50 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test runner script for ComfyUI-Lora-Manager.
|
||||
|
||||
This script runs pytest from the tests directory to avoid import issues
|
||||
with the root __init__.py file.
|
||||
"""
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Set environment variable to indicate standalone mode
|
||||
# HF_HUB_DISABLE_TELEMETRY is from ComfyUI main.py
|
||||
standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
|
||||
def main():
|
||||
"""Run pytest from the tests directory to avoid import issues."""
|
||||
# Get the script directory
|
||||
script_dir = Path(__file__).parent.absolute()
|
||||
tests_dir = script_dir / "tests"
|
||||
|
||||
if not tests_dir.exists():
|
||||
print(f"Error: Tests directory not found at {tests_dir}")
|
||||
return 1
|
||||
|
||||
# Change to tests directory
|
||||
original_cwd = os.getcwd()
|
||||
os.chdir(tests_dir)
|
||||
|
||||
try:
|
||||
# Build pytest command
|
||||
cmd = [
|
||||
sys.executable, "-m", "pytest",
|
||||
"-v",
|
||||
"--rootdir=.",
|
||||
] + sys.argv[1:] # Pass any additional arguments
|
||||
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
print(f"Working directory: {tests_dir}")
|
||||
|
||||
# Run pytest
|
||||
result = subprocess.run(cmd, cwd=tests_dir)
|
||||
return result.returncode
|
||||
finally:
|
||||
# Restore original working directory
|
||||
os.chdir(original_cwd)
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -421,7 +421,7 @@ class StandaloneLoraManager(LoraManager):
|
||||
RecipeRoutes.setup_routes(app)
|
||||
UpdateRoutes.setup_routes(app)
|
||||
MiscRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app)
|
||||
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
|
||||
|
||||
# Setup WebSocket routes that are shared across all model types
|
||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import types
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
import asyncio
|
||||
import inspect
|
||||
from unittest import mock
|
||||
import sys
|
||||
|
||||
@@ -39,6 +41,35 @@ nodes_mock.NODE_CLASS_MAPPINGS = {}
|
||||
sys.modules['nodes'] = nodes_mock
|
||||
|
||||
|
||||
def pytest_pyfunc_call(pyfuncitem):
|
||||
"""Allow bare async tests to run without pytest.mark.asyncio."""
|
||||
test_function = pyfuncitem.function
|
||||
if inspect.iscoroutinefunction(test_function):
|
||||
func = pyfuncitem.obj
|
||||
signature = inspect.signature(func)
|
||||
accepted_kwargs: Dict[str, Any] = {}
|
||||
for name, parameter in signature.parameters.items():
|
||||
if parameter.kind is inspect.Parameter.VAR_POSITIONAL:
|
||||
continue
|
||||
if parameter.kind is inspect.Parameter.VAR_KEYWORD:
|
||||
accepted_kwargs = dict(pyfuncitem.funcargs)
|
||||
break
|
||||
if name in pyfuncitem.funcargs:
|
||||
accepted_kwargs[name] = pyfuncitem.funcargs[name]
|
||||
|
||||
original_policy = asyncio.get_event_loop_policy()
|
||||
policy = pyfuncitem.funcargs.get("event_loop_policy")
|
||||
if policy is not None and policy is not original_policy:
|
||||
asyncio.set_event_loop_policy(policy)
|
||||
try:
|
||||
asyncio.run(func(**accepted_kwargs))
|
||||
finally:
|
||||
if policy is not None and policy is not original_policy:
|
||||
asyncio.set_event_loop_policy(original_policy)
|
||||
return True
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockHashIndex:
|
||||
"""Minimal hash index stub mirroring the scanner contract."""
|
||||
@@ -50,7 +81,7 @@ class MockHashIndex:
|
||||
|
||||
|
||||
class MockCache:
|
||||
"""Cache object with the attributes consumed by ``ModelRouteUtils``."""
|
||||
"""Cache object with the attributes."""
|
||||
|
||||
def __init__(self, items: Optional[Sequence[Dict[str, Any]]] = None):
|
||||
self.raw_data: List[Dict[str, Any]] = list(items or [])
|
||||
@@ -58,7 +89,7 @@ class MockCache:
|
||||
|
||||
async def resort(self) -> None:
|
||||
self.resort_calls += 1
|
||||
# ``ModelRouteUtils`` expects the coroutine interface but does not
|
||||
# expects the coroutine interface but does not
|
||||
# rely on the return value.
|
||||
|
||||
|
||||
@@ -187,3 +218,5 @@ def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockS
|
||||
@pytest.fixture
|
||||
def mock_service(mock_scanner: MockScanner) -> MockModelService:
|
||||
return MockModelService(scanner=mock_scanner)
|
||||
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ spec.loader.exec_module(py_local)
|
||||
sys.modules.setdefault("py_local", py_local)
|
||||
|
||||
from py_local.routes.base_model_routes import BaseModelRoutes
|
||||
from py_local.services.model_file_service import AutoOrganizeResult
|
||||
from py_local.services.service_registry import ServiceRegistry
|
||||
from py_local.services.websocket_manager import ws_manager
|
||||
from py_local.utils.routes_common import ExifUtils
|
||||
@@ -66,12 +67,24 @@ def download_manager_stub():
|
||||
class FakeDownloadManager:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
self.error = None
|
||||
self.cancelled = []
|
||||
self.active_downloads = {}
|
||||
|
||||
async def download_from_civitai(self, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
if self.error is not None:
|
||||
raise self.error
|
||||
await kwargs["progress_callback"](42)
|
||||
return {"success": True, "path": "/tmp/model.safetensors"}
|
||||
|
||||
async def cancel_download(self, download_id):
|
||||
self.cancelled.append(download_id)
|
||||
return {"success": True, "download_id": download_id}
|
||||
|
||||
async def get_active_downloads(self):
|
||||
return self.active_downloads
|
||||
|
||||
stub = FakeDownloadManager()
|
||||
previous = ServiceRegistry._services.get("download_manager")
|
||||
asyncio.run(ServiceRegistry.register_service("download_manager", stub))
|
||||
@@ -103,6 +116,21 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner):
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_routes_return_service_not_ready_when_unattached():
|
||||
async def scenario():
|
||||
client = await create_test_client(None)
|
||||
try:
|
||||
response = await client.get("/api/lm/test-models/list")
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 503
|
||||
assert payload == {"success": False, "error": "Service not ready"}
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_delete_model_updates_cache_and_hash_index(mock_service, mock_scanner, tmp_path: Path):
|
||||
model_path = tmp_path / "sample.safetensors"
|
||||
model_path.write_bytes(b"model")
|
||||
@@ -222,6 +250,69 @@ def test_download_model_invokes_download_manager(
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_download_model_requires_identifier(mock_service, download_manager_stub):
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
try:
|
||||
response = await client.post(
|
||||
"/api/lm/download-model",
|
||||
json={"model_root": "/tmp"},
|
||||
)
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 400
|
||||
assert payload["success"] is False
|
||||
assert "Missing required" in payload["error"]
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_download_model_maps_validation_errors(mock_service, download_manager_stub):
|
||||
download_manager_stub.error = ValueError("Invalid relative path")
|
||||
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
try:
|
||||
response = await client.post(
|
||||
"/api/lm/download-model",
|
||||
json={"model_version_id": 123},
|
||||
)
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 400
|
||||
assert payload == {"success": False, "error": "Invalid relative path"}
|
||||
assert ws_manager._download_progress == {}
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_download_model_maps_early_access_errors(mock_service, download_manager_stub):
|
||||
download_manager_stub.error = RuntimeError("401 early access")
|
||||
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
try:
|
||||
response = await client.post(
|
||||
"/api/lm/download-model",
|
||||
json={"model_id": 4},
|
||||
)
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 401
|
||||
assert payload == {
|
||||
"success": False,
|
||||
"error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.",
|
||||
}
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_auto_organize_progress_returns_latest_snapshot(mock_service):
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
@@ -235,5 +326,65 @@ def test_auto_organize_progress_returns_latest_snapshot(mock_service):
|
||||
assert payload == {"success": True, "progress": {"status": "processing", "percent": 50}}
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.MonkeyPatch):
|
||||
async def fake_auto_organize(self, file_paths=None, progress_callback=None):
|
||||
result = AutoOrganizeResult()
|
||||
result.total = 1
|
||||
result.processed = 1
|
||||
result.success_count = 1
|
||||
result.skipped_count = 0
|
||||
result.failure_count = 0
|
||||
result.operation_type = "bulk"
|
||||
if progress_callback is not None:
|
||||
await progress_callback.on_progress({"type": "auto_organize_progress", "status": "started"})
|
||||
await progress_callback.on_progress({"type": "auto_organize_progress", "status": "completed"})
|
||||
return result
|
||||
|
||||
monkeypatch.setattr(
|
||||
py_local.services.model_file_service.ModelFileService,
|
||||
"auto_organize_models",
|
||||
fake_auto_organize,
|
||||
)
|
||||
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
try:
|
||||
response = await client.post("/api/lm/test-models/auto-organize", json={"file_paths": []})
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert payload["success"] is True
|
||||
|
||||
progress = ws_manager.get_auto_organize_progress()
|
||||
assert progress is not None
|
||||
assert progress["status"] == "completed"
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_auto_organize_conflict_when_running(mock_service):
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
try:
|
||||
await ws_manager.broadcast_auto_organize_progress(
|
||||
{"type": "auto_organize_progress", "status": "started"}
|
||||
)
|
||||
|
||||
response = await client.post("/api/lm/test-models/auto-organize")
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 409
|
||||
assert payload == {
|
||||
"success": False,
|
||||
"error": "Auto-organize is already running. Please wait for it to complete.",
|
||||
}
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
431
tests/routes/test_example_images_routes.py
Normal file
431
tests/routes/test_example_images_routes.py
Normal file
@@ -0,0 +1,431 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
import pytest
|
||||
|
||||
from py.routes.example_images_route_registrar import ROUTE_DEFINITIONS
|
||||
from py.routes.example_images_routes import ExampleImagesRoutes
|
||||
from py.routes.handlers.example_images_handlers import (
|
||||
ExampleImagesDownloadHandler,
|
||||
ExampleImagesFileHandler,
|
||||
ExampleImagesHandlerSet,
|
||||
ExampleImagesManagementHandler,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExampleImagesHarness:
|
||||
"""Container exposing the aiohttp client and stubbed collaborators."""
|
||||
|
||||
client: TestClient
|
||||
download_manager: "StubDownloadManager"
|
||||
processor: "StubExampleImagesProcessor"
|
||||
file_manager: "StubExampleImagesFileManager"
|
||||
controller: ExampleImagesRoutes
|
||||
|
||||
|
||||
class StubDownloadManager:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[str, Any]] = []
|
||||
|
||||
async def start_download(self, payload: Any) -> dict:
|
||||
self.calls.append(("start_download", payload))
|
||||
return {"operation": "start_download", "payload": payload}
|
||||
|
||||
async def get_status(self, request: web.Request) -> dict:
|
||||
self.calls.append(("get_status", dict(request.query)))
|
||||
return {"operation": "get_status"}
|
||||
|
||||
async def pause_download(self, request: web.Request) -> dict:
|
||||
self.calls.append(("pause_download", None))
|
||||
return {"operation": "pause_download"}
|
||||
|
||||
async def resume_download(self, request: web.Request) -> dict:
|
||||
self.calls.append(("resume_download", None))
|
||||
return {"operation": "resume_download"}
|
||||
|
||||
async def start_force_download(self, payload: Any) -> dict:
|
||||
self.calls.append(("start_force_download", payload))
|
||||
return {"operation": "start_force_download", "payload": payload}
|
||||
|
||||
|
||||
class StubExampleImagesProcessor:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[str, Any]] = []
|
||||
|
||||
async def import_images(self, model_hash: str, files: List[str]) -> dict:
|
||||
payload = {"model_hash": model_hash, "file_paths": files}
|
||||
self.calls.append(("import_images", payload))
|
||||
return {"operation": "import_images", "payload": payload}
|
||||
|
||||
async def delete_custom_image(self, request: web.Request) -> web.StreamResponse:
|
||||
payload = await request.json()
|
||||
self.calls.append(("delete_custom_image", payload))
|
||||
return web.json_response({"operation": "delete_custom_image", "payload": payload})
|
||||
|
||||
|
||||
class StubExampleImagesFileManager:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[str, Any]] = []
|
||||
|
||||
async def open_folder(self, request: web.Request) -> web.StreamResponse:
|
||||
payload = await request.json()
|
||||
self.calls.append(("open_folder", payload))
|
||||
return web.json_response({"operation": "open_folder", "payload": payload})
|
||||
|
||||
async def get_files(self, request: web.Request) -> web.StreamResponse:
|
||||
self.calls.append(("get_files", dict(request.query)))
|
||||
return web.json_response({"operation": "get_files", "query": dict(request.query)})
|
||||
|
||||
async def has_images(self, request: web.Request) -> web.StreamResponse:
|
||||
self.calls.append(("has_images", dict(request.query)))
|
||||
return web.json_response({"operation": "has_images", "query": dict(request.query)})
|
||||
|
||||
|
||||
class StubWebSocketManager:
|
||||
def __init__(self) -> None:
|
||||
self.broadcast_calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def broadcast(self, payload: Dict[str, Any]) -> None:
|
||||
self.broadcast_calls.append(payload)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def example_images_app() -> ExampleImagesHarness:
|
||||
"""Yield an ExampleImagesRoutes app wired with stubbed collaborators."""
|
||||
|
||||
download_manager = StubDownloadManager()
|
||||
processor = StubExampleImagesProcessor()
|
||||
file_manager = StubExampleImagesFileManager()
|
||||
ws_manager = StubWebSocketManager()
|
||||
|
||||
controller = ExampleImagesRoutes(
|
||||
ws_manager=ws_manager,
|
||||
download_manager=download_manager,
|
||||
processor=processor,
|
||||
file_manager=file_manager,
|
||||
)
|
||||
|
||||
app = web.Application()
|
||||
controller.register(app)
|
||||
|
||||
server = TestServer(app)
|
||||
client = TestClient(server)
|
||||
await client.start_server()
|
||||
|
||||
try:
|
||||
yield ExampleImagesHarness(
|
||||
client=client,
|
||||
download_manager=download_manager,
|
||||
processor=processor,
|
||||
file_manager=file_manager,
|
||||
controller=controller,
|
||||
)
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
|
||||
async def test_setup_routes_registers_all_definitions():
|
||||
async with example_images_app() as harness:
|
||||
registered = {
|
||||
(route.method, route.resource.canonical)
|
||||
for route in harness.client.app.router.routes()
|
||||
if route.resource.canonical
|
||||
}
|
||||
|
||||
expected = {(definition.method, definition.path) for definition in ROUTE_DEFINITIONS}
|
||||
|
||||
assert expected <= registered
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint, payload",
|
||||
[
|
||||
("/api/lm/download-example-images", {"model_types": ["lora"], "optimize": False}),
|
||||
("/api/lm/force-download-example-images", {"model_hashes": ["abc123"]}),
|
||||
],
|
||||
)
|
||||
async def test_download_routes_delegate_to_manager(endpoint, payload):
|
||||
async with example_images_app() as harness:
|
||||
response = await harness.client.post(endpoint, json=payload)
|
||||
body = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert body["payload"] == payload
|
||||
assert body["operation"].startswith("start")
|
||||
|
||||
expected_call = body["operation"], payload
|
||||
assert expected_call in harness.download_manager.calls
|
||||
|
||||
|
||||
async def test_status_route_returns_manager_payload():
|
||||
async with example_images_app() as harness:
|
||||
response = await harness.client.get(
|
||||
"/api/lm/example-images-status", params={"detail": "true"}
|
||||
)
|
||||
body = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert body == {"operation": "get_status"}
|
||||
assert harness.download_manager.calls == [("get_status", {"detail": "true"})]
|
||||
|
||||
|
||||
async def test_pause_and_resume_routes_delegate():
|
||||
async with example_images_app() as harness:
|
||||
pause_response = await harness.client.post("/api/lm/pause-example-images")
|
||||
resume_response = await harness.client.post("/api/lm/resume-example-images")
|
||||
|
||||
assert pause_response.status == 200
|
||||
assert await pause_response.json() == {"operation": "pause_download"}
|
||||
assert resume_response.status == 200
|
||||
assert await resume_response.json() == {"operation": "resume_download"}
|
||||
|
||||
assert harness.download_manager.calls[-2:] == [
|
||||
("pause_download", None),
|
||||
("resume_download", None),
|
||||
]
|
||||
|
||||
|
||||
async def test_import_route_delegates_to_processor():
|
||||
payload = {"model_hash": "abc123", "file_paths": ["/path/image.png"]}
|
||||
async with example_images_app() as harness:
|
||||
response = await harness.client.post(
|
||||
"/api/lm/import-example-images", json=payload
|
||||
)
|
||||
body = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert body == {"operation": "import_images", "payload": payload}
|
||||
expected_call = ("import_images", payload)
|
||||
assert expected_call in harness.processor.calls
|
||||
|
||||
|
||||
async def test_delete_route_delegates_to_processor():
|
||||
payload = {"model_hash": "abc123", "short_id": "xyz"}
|
||||
async with example_images_app() as harness:
|
||||
response = await harness.client.post(
|
||||
"/api/lm/delete-example-image", json=payload
|
||||
)
|
||||
body = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert body == {"operation": "delete_custom_image", "payload": payload}
|
||||
assert harness.processor.calls == [("delete_custom_image", payload)]
|
||||
|
||||
|
||||
async def test_file_routes_delegate_to_file_manager():
|
||||
open_payload = {"model_hash": "abc123"}
|
||||
files_params = {"model_hash": "def456"}
|
||||
|
||||
async with example_images_app() as harness:
|
||||
open_response = await harness.client.post(
|
||||
"/api/lm/open-example-images-folder", json=open_payload
|
||||
)
|
||||
files_response = await harness.client.get(
|
||||
"/api/lm/example-image-files", params=files_params
|
||||
)
|
||||
has_response = await harness.client.get(
|
||||
"/api/lm/has-example-images", params=files_params
|
||||
)
|
||||
|
||||
assert open_response.status == 200
|
||||
assert files_response.status == 200
|
||||
assert has_response.status == 200
|
||||
|
||||
assert await open_response.json() == {"operation": "open_folder", "payload": open_payload}
|
||||
assert await files_response.json() == {
|
||||
"operation": "get_files",
|
||||
"query": files_params,
|
||||
}
|
||||
assert await has_response.json() == {
|
||||
"operation": "has_images",
|
||||
"query": files_params,
|
||||
}
|
||||
|
||||
assert harness.file_manager.calls == [
|
||||
("open_folder", open_payload),
|
||||
("get_files", files_params),
|
||||
("has_images", files_params),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_handler_methods_delegate() -> None:
|
||||
class Recorder:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[str, Any]] = []
|
||||
|
||||
async def get_status(self, request) -> dict:
|
||||
self.calls.append(("get_status", request))
|
||||
return {"status": "ok"}
|
||||
|
||||
async def pause_download(self, request) -> dict:
|
||||
self.calls.append(("pause_download", request))
|
||||
return {"status": "paused"}
|
||||
|
||||
async def resume_download(self, request) -> dict:
|
||||
self.calls.append(("resume_download", request))
|
||||
return {"status": "running"}
|
||||
|
||||
async def start_force_download(self, payload) -> dict:
|
||||
self.calls.append(("start_force_download", payload))
|
||||
return {"status": "force", "payload": payload}
|
||||
|
||||
class StubDownloadUseCase:
|
||||
def __init__(self) -> None:
|
||||
self.payloads: List[Any] = []
|
||||
|
||||
async def execute(self, payload: dict) -> dict:
|
||||
self.payloads.append(payload)
|
||||
return {"status": "started", "payload": payload}
|
||||
|
||||
class DummyRequest:
|
||||
def __init__(self, payload: dict) -> None:
|
||||
self._payload = payload
|
||||
self.query = {}
|
||||
|
||||
async def json(self) -> dict:
|
||||
return self._payload
|
||||
|
||||
recorder = Recorder()
|
||||
use_case = StubDownloadUseCase()
|
||||
handler = ExampleImagesDownloadHandler(use_case, recorder)
|
||||
request = DummyRequest({"foo": "bar"})
|
||||
|
||||
download_response = await handler.download_example_images(request)
|
||||
assert json.loads(download_response.text) == {"status": "started", "payload": {"foo": "bar"}}
|
||||
status_response = await handler.get_example_images_status(request)
|
||||
assert json.loads(status_response.text) == {"status": "ok"}
|
||||
pause_response = await handler.pause_example_images(request)
|
||||
assert json.loads(pause_response.text) == {"status": "paused"}
|
||||
resume_response = await handler.resume_example_images(request)
|
||||
assert json.loads(resume_response.text) == {"status": "running"}
|
||||
force_response = await handler.force_download_example_images(request)
|
||||
assert json.loads(force_response.text) == {"status": "force", "payload": {"foo": "bar"}}
|
||||
|
||||
assert use_case.payloads == [{"foo": "bar"}]
|
||||
assert recorder.calls == [
|
||||
("get_status", request),
|
||||
("pause_download", request),
|
||||
("resume_download", request),
|
||||
("start_force_download", {"foo": "bar"}),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_management_handler_methods_delegate() -> None:
|
||||
class StubImportUseCase:
|
||||
def __init__(self) -> None:
|
||||
self.requests: List[Any] = []
|
||||
|
||||
async def execute(self, request: Any) -> dict:
|
||||
self.requests.append(request)
|
||||
return {"status": "imported"}
|
||||
|
||||
class Recorder:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[str, Any]] = []
|
||||
|
||||
async def delete_custom_image(self, request) -> str:
|
||||
self.calls.append(("delete_custom_image", request))
|
||||
return "delete"
|
||||
|
||||
recorder = Recorder()
|
||||
use_case = StubImportUseCase()
|
||||
handler = ExampleImagesManagementHandler(use_case, recorder)
|
||||
request = object()
|
||||
|
||||
import_response = await handler.import_example_images(request)
|
||||
assert json.loads(import_response.text) == {"status": "imported"}
|
||||
assert await handler.delete_example_image(request) == "delete"
|
||||
assert use_case.requests == [request]
|
||||
assert recorder.calls == [("delete_custom_image", request)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_handler_methods_delegate() -> None:
|
||||
class Recorder:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Tuple[str, Any]] = []
|
||||
|
||||
async def open_folder(self, request) -> str:
|
||||
self.calls.append(("open_folder", request))
|
||||
return "open"
|
||||
|
||||
async def get_files(self, request) -> str:
|
||||
self.calls.append(("get_files", request))
|
||||
return "files"
|
||||
|
||||
async def has_images(self, request) -> str:
|
||||
self.calls.append(("has_images", request))
|
||||
return "has"
|
||||
|
||||
recorder = Recorder()
|
||||
handler = ExampleImagesFileHandler(recorder)
|
||||
request = object()
|
||||
|
||||
assert await handler.open_example_images_folder(request) == "open"
|
||||
assert await handler.get_example_image_files(request) == "files"
|
||||
assert await handler.has_example_images(request) == "has"
|
||||
assert recorder.calls == [
|
||||
("open_folder", request),
|
||||
("get_files", request),
|
||||
("has_images", request),
|
||||
]
|
||||
|
||||
|
||||
def test_handler_set_route_mapping_includes_all_handlers() -> None:
|
||||
class DummyUseCase:
|
||||
async def execute(self, payload):
|
||||
return payload
|
||||
|
||||
class DummyManager:
|
||||
async def get_status(self, request):
|
||||
return {}
|
||||
|
||||
async def pause_download(self, request):
|
||||
return {}
|
||||
|
||||
async def resume_download(self, request):
|
||||
return {}
|
||||
|
||||
async def start_force_download(self, payload):
|
||||
return payload
|
||||
|
||||
class DummyProcessor:
|
||||
async def delete_custom_image(self, request):
|
||||
return {}
|
||||
|
||||
download = ExampleImagesDownloadHandler(DummyUseCase(), DummyManager())
|
||||
management = ExampleImagesManagementHandler(DummyUseCase(), DummyProcessor())
|
||||
files = ExampleImagesFileHandler(object())
|
||||
handler_set = ExampleImagesHandlerSet(
|
||||
download=download,
|
||||
management=management,
|
||||
files=files,
|
||||
)
|
||||
|
||||
mapping = handler_set.to_route_mapping()
|
||||
|
||||
expected_keys = {
|
||||
"download_example_images",
|
||||
"get_example_images_status",
|
||||
"pause_example_images",
|
||||
"resume_example_images",
|
||||
"force_download_example_images",
|
||||
"import_example_images",
|
||||
"delete_example_image",
|
||||
"open_example_images_folder",
|
||||
"get_example_image_files",
|
||||
"has_example_images",
|
||||
}
|
||||
|
||||
assert mapping.keys() == expected_keys
|
||||
for key in expected_keys:
|
||||
assert callable(mapping[key])
|
||||
236
tests/routes/test_recipe_route_scaffolding.py
Normal file
236
tests/routes/test_recipe_route_scaffolding.py
Normal file
@@ -0,0 +1,236 @@
|
||||
"""Smoke tests for the recipe routing scaffolding.
|
||||
|
||||
The cases keep the registrar/controller contract aligned with
|
||||
``docs/architecture/recipe_routes.md`` so future refactors can focus on handler
|
||||
logic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable, Dict
|
||||
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
PY_PACKAGE_PATH = REPO_ROOT / "py"
|
||||
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"py_local",
|
||||
PY_PACKAGE_PATH / "__init__.py",
|
||||
submodule_search_locations=[str(PY_PACKAGE_PATH)],
|
||||
)
|
||||
py_local = importlib.util.module_from_spec(spec)
|
||||
assert spec.loader is not None
|
||||
spec.loader.exec_module(py_local)
|
||||
sys.modules.setdefault("py_local", py_local)
|
||||
|
||||
base_routes_module = importlib.import_module("py_local.routes.base_recipe_routes")
|
||||
recipe_routes_module = importlib.import_module("py_local.routes.recipe_routes")
|
||||
registrar_module = importlib.import_module("py_local.routes.recipe_route_registrar")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_service_registry(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Ensure each test starts from a clean registry state."""
|
||||
|
||||
services_module = importlib.import_module("py_local.services.service_registry")
|
||||
registry = services_module.ServiceRegistry
|
||||
previous_services = dict(registry._services)
|
||||
previous_locks = dict(registry._locks)
|
||||
registry._services.clear()
|
||||
registry._locks.clear()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
registry._services = previous_services
|
||||
registry._locks = previous_locks
|
||||
|
||||
|
||||
def _make_stub_scanner():
|
||||
class _StubScanner:
|
||||
def __init__(self):
|
||||
self._cache = types.SimpleNamespace()
|
||||
|
||||
async def _lora_get_cached_data(): # pragma: no cover - smoke hook
|
||||
return None
|
||||
|
||||
self._lora_scanner = types.SimpleNamespace(
|
||||
get_cached_data=_lora_get_cached_data,
|
||||
_hash_index=types.SimpleNamespace(_hash_to_path={}),
|
||||
)
|
||||
|
||||
async def get_cached_data(self, force_refresh: bool = False):
|
||||
return self._cache
|
||||
|
||||
return _StubScanner()
|
||||
|
||||
|
||||
def test_attach_dependencies_resolves_services_once(monkeypatch: pytest.MonkeyPatch):
|
||||
base_module = base_routes_module
|
||||
services_module = importlib.import_module("py_local.services.service_registry")
|
||||
registry = services_module.ServiceRegistry
|
||||
server_i18n = importlib.import_module("py_local.services.server_i18n").server_i18n
|
||||
|
||||
scanner = _make_stub_scanner()
|
||||
civitai_client = object()
|
||||
filter_calls = Counter()
|
||||
|
||||
async def fake_get_recipe_scanner():
|
||||
return scanner
|
||||
|
||||
async def fake_get_civitai_client():
|
||||
return civitai_client
|
||||
|
||||
def fake_create_filter():
|
||||
filter_calls["create_filter"] += 1
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(registry, "get_recipe_scanner", fake_get_recipe_scanner)
|
||||
monkeypatch.setattr(registry, "get_civitai_client", fake_get_civitai_client)
|
||||
monkeypatch.setattr(server_i18n, "create_template_filter", fake_create_filter)
|
||||
|
||||
async def scenario():
|
||||
routes = base_module.BaseRecipeRoutes()
|
||||
|
||||
await routes.attach_dependencies()
|
||||
await routes.attach_dependencies() # idempotent
|
||||
|
||||
assert routes.recipe_scanner is scanner
|
||||
assert routes.lora_scanner is scanner._lora_scanner
|
||||
assert routes.civitai_client is civitai_client
|
||||
assert routes.template_env.filters["t"] is not None
|
||||
assert filter_calls["create_filter"] == 1
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_register_startup_hooks_appends_once():
|
||||
routes = base_routes_module.BaseRecipeRoutes()
|
||||
|
||||
app = web.Application()
|
||||
routes.register_startup_hooks(app)
|
||||
routes.register_startup_hooks(app)
|
||||
|
||||
startup_bound_to_routes = [
|
||||
callback for callback in app.on_startup if getattr(callback, "__self__", None) is routes
|
||||
]
|
||||
|
||||
assert routes.attach_dependencies in startup_bound_to_routes
|
||||
assert routes.prewarm_cache in startup_bound_to_routes
|
||||
assert len(startup_bound_to_routes) == 2
|
||||
|
||||
|
||||
def test_to_route_mapping_uses_handler_set():
|
||||
class DummyHandlerSet:
|
||||
def __init__(self):
|
||||
self.calls = 0
|
||||
|
||||
def to_route_mapping(self):
|
||||
self.calls += 1
|
||||
|
||||
async def render_page(request): # pragma: no cover - simple coroutine
|
||||
return web.Response(text="ok")
|
||||
|
||||
return {"render_page": render_page}
|
||||
|
||||
class DummyRoutes(base_routes_module.BaseRecipeRoutes):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.created = 0
|
||||
|
||||
def _create_handler_set(self): # noqa: D401 - simple override for test
|
||||
self.created += 1
|
||||
return DummyHandlerSet()
|
||||
|
||||
routes = DummyRoutes()
|
||||
mapping = routes.to_route_mapping()
|
||||
|
||||
assert set(mapping.keys()) == {"render_page"}
|
||||
assert asyncio.iscoroutinefunction(mapping["render_page"])
|
||||
# Cached mapping reused on subsequent calls
|
||||
assert routes.to_route_mapping() is mapping
|
||||
# Handler set cached for get_handler_owner callers
|
||||
assert isinstance(routes.get_handler_owner(), DummyHandlerSet)
|
||||
assert routes.created == 1
|
||||
|
||||
|
||||
def test_recipe_route_registrar_binds_every_route():
|
||||
class FakeRouter:
|
||||
def __init__(self):
|
||||
self.calls: list[tuple[str, str, Callable[..., Awaitable[Any]]]] = []
|
||||
|
||||
def add_get(self, path, handler):
|
||||
self.calls.append(("GET", path, handler))
|
||||
|
||||
def add_post(self, path, handler):
|
||||
self.calls.append(("POST", path, handler))
|
||||
|
||||
def add_put(self, path, handler):
|
||||
self.calls.append(("PUT", path, handler))
|
||||
|
||||
def add_delete(self, path, handler):
|
||||
self.calls.append(("DELETE", path, handler))
|
||||
|
||||
class FakeApp:
|
||||
def __init__(self):
|
||||
self.router = FakeRouter()
|
||||
|
||||
app = FakeApp()
|
||||
registrar = registrar_module.RecipeRouteRegistrar(app)
|
||||
|
||||
handler_mapping = {
|
||||
definition.handler_name: object()
|
||||
for definition in registrar_module.ROUTE_DEFINITIONS
|
||||
}
|
||||
|
||||
registrar.register_routes(handler_mapping)
|
||||
|
||||
assert {
|
||||
(method, path)
|
||||
for method, path, _ in app.router.calls
|
||||
} == {(d.method, d.path) for d in registrar_module.ROUTE_DEFINITIONS}
|
||||
|
||||
|
||||
def test_recipe_routes_setup_routes_uses_registrar(monkeypatch: pytest.MonkeyPatch):
|
||||
registered_mappings: list[Dict[str, Callable[..., Awaitable[Any]]]] = []
|
||||
|
||||
class DummyRegistrar:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
def register_routes(self, mapping):
|
||||
registered_mappings.append(mapping)
|
||||
|
||||
monkeypatch.setattr(recipe_routes_module, "RecipeRouteRegistrar", DummyRegistrar)
|
||||
|
||||
expected_mapping = {name: object() for name in ("render_page", "list_recipes")}
|
||||
|
||||
def fake_to_route_mapping(self):
|
||||
return expected_mapping
|
||||
|
||||
monkeypatch.setattr(base_routes_module.BaseRecipeRoutes, "to_route_mapping", fake_to_route_mapping)
|
||||
monkeypatch.setattr(
|
||||
base_routes_module.BaseRecipeRoutes,
|
||||
"_HANDLER_NAMES",
|
||||
tuple(expected_mapping.keys()),
|
||||
)
|
||||
|
||||
app = web.Application()
|
||||
recipe_routes_module.RecipeRoutes.setup_routes(app)
|
||||
|
||||
assert registered_mappings == [expected_mapping]
|
||||
recipe_callbacks = {
|
||||
cb
|
||||
for cb in app.on_startup
|
||||
if isinstance(getattr(cb, "__self__", None), recipe_routes_module.RecipeRoutes)
|
||||
}
|
||||
assert {type(cb.__self__) for cb in recipe_callbacks} == {recipe_routes_module.RecipeRoutes}
|
||||
assert {cb.__name__ for cb in recipe_callbacks} == {"attach_dependencies", "prewarm_cache"}
|
||||
330
tests/routes/test_recipe_routes.py
Normal file
330
tests/routes/test_recipe_routes.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""Integration smoke tests for the recipe route stack."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||
|
||||
from aiohttp import FormData, web
|
||||
from aiohttp.test_utils import TestClient, TestServer
|
||||
|
||||
from py.config import config
|
||||
from py.routes import base_recipe_routes
|
||||
from py.routes.recipe_routes import RecipeRoutes
|
||||
from py.services.recipes import RecipeValidationError
|
||||
from py.services.service_registry import ServiceRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
class RecipeRouteHarness:
|
||||
"""Container exposing the aiohttp client and stubbed collaborators."""
|
||||
|
||||
client: TestClient
|
||||
scanner: "StubRecipeScanner"
|
||||
analysis: "StubAnalysisService"
|
||||
persistence: "StubPersistenceService"
|
||||
sharing: "StubSharingService"
|
||||
tmp_dir: Path
|
||||
|
||||
|
||||
class StubRecipeScanner:
|
||||
"""Minimal scanner double with the surface used by the handlers."""
|
||||
|
||||
def __init__(self, base_dir: Path) -> None:
|
||||
self.recipes_dir = str(base_dir / "recipes")
|
||||
self.listing_items: List[Dict[str, Any]] = []
|
||||
self.cached_raw: List[Dict[str, Any]] = []
|
||||
self.recipes: Dict[str, Dict[str, Any]] = {}
|
||||
self.removed: List[str] = []
|
||||
|
||||
async def _noop_get_cached_data(force_refresh: bool = False) -> None: # noqa: ARG001 - signature mirrors real scanner
|
||||
return None
|
||||
|
||||
self._lora_scanner = SimpleNamespace( # mimic BaseRecipeRoutes expectations
|
||||
get_cached_data=_noop_get_cached_data,
|
||||
_hash_index=SimpleNamespace(_hash_to_path={}),
|
||||
)
|
||||
|
||||
async def get_cached_data(self, force_refresh: bool = False) -> SimpleNamespace: # noqa: ARG002 - flag unused by stub
|
||||
return SimpleNamespace(raw_data=list(self.cached_raw))
|
||||
|
||||
async def get_paginated_data(self, **params: Any) -> Dict[str, Any]:
|
||||
items = [dict(item) for item in self.listing_items]
|
||||
page = int(params.get("page", 1))
|
||||
page_size = int(params.get("page_size", 20))
|
||||
return {
|
||||
"items": items,
|
||||
"total": len(items),
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
"total_pages": max(1, (len(items) + page_size - 1) // max(page_size, 1)),
|
||||
}
|
||||
|
||||
async def get_recipe_by_id(self, recipe_id: str) -> Optional[Dict[str, Any]]:
|
||||
return self.recipes.get(recipe_id)
|
||||
|
||||
async def remove_recipe(self, recipe_id: str) -> None:
|
||||
self.removed.append(recipe_id)
|
||||
self.recipes.pop(recipe_id, None)
|
||||
|
||||
|
||||
class StubAnalysisService:
|
||||
"""Captures calls made by analysis routes while returning canned responses."""
|
||||
|
||||
instances: List["StubAnalysisService"] = []
|
||||
|
||||
def __init__(self, **_: Any) -> None:
|
||||
self.raise_for_uploaded: Optional[Exception] = None
|
||||
self.raise_for_remote: Optional[Exception] = None
|
||||
self.raise_for_local: Optional[Exception] = None
|
||||
self.upload_calls: List[bytes] = []
|
||||
self.remote_calls: List[Optional[str]] = []
|
||||
self.local_calls: List[Optional[str]] = []
|
||||
self.result = SimpleNamespace(payload={"loras": []}, status=200)
|
||||
StubAnalysisService.instances.append(self)
|
||||
|
||||
async def analyze_uploaded_image(self, *, image_bytes: bytes | None, recipe_scanner) -> SimpleNamespace: # noqa: D401 - mirrors real signature
|
||||
if self.raise_for_uploaded:
|
||||
raise self.raise_for_uploaded
|
||||
self.upload_calls.append(image_bytes or b"")
|
||||
return self.result
|
||||
|
||||
async def analyze_remote_image(self, *, url: Optional[str], recipe_scanner, civitai_client) -> SimpleNamespace: # noqa: D401
|
||||
if self.raise_for_remote:
|
||||
raise self.raise_for_remote
|
||||
self.remote_calls.append(url)
|
||||
return self.result
|
||||
|
||||
async def analyze_local_image(self, *, file_path: Optional[str], recipe_scanner) -> SimpleNamespace: # noqa: D401
|
||||
if self.raise_for_local:
|
||||
raise self.raise_for_local
|
||||
self.local_calls.append(file_path)
|
||||
return self.result
|
||||
|
||||
async def analyze_widget_metadata(self, *, recipe_scanner) -> SimpleNamespace:
|
||||
return SimpleNamespace(payload={"metadata": {}, "image_bytes": b""}, status=200)
|
||||
|
||||
|
||||
class StubPersistenceService:
|
||||
"""Stub for persistence operations to avoid filesystem writes."""
|
||||
|
||||
instances: List["StubPersistenceService"] = []
|
||||
|
||||
def __init__(self, **_: Any) -> None:
|
||||
self.save_calls: List[Dict[str, Any]] = []
|
||||
self.delete_calls: List[str] = []
|
||||
self.save_result = SimpleNamespace(payload={"success": True, "recipe_id": "stub-id"}, status=200)
|
||||
self.delete_result = SimpleNamespace(payload={"success": True}, status=200)
|
||||
StubPersistenceService.instances.append(self)
|
||||
|
||||
async def save_recipe(self, *, recipe_scanner, image_bytes, image_base64, name, tags, metadata) -> SimpleNamespace: # noqa: D401
|
||||
self.save_calls.append(
|
||||
{
|
||||
"recipe_scanner": recipe_scanner,
|
||||
"image_bytes": image_bytes,
|
||||
"image_base64": image_base64,
|
||||
"name": name,
|
||||
"tags": list(tags),
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
return self.save_result
|
||||
|
||||
async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> SimpleNamespace:
|
||||
self.delete_calls.append(recipe_id)
|
||||
await recipe_scanner.remove_recipe(recipe_id)
|
||||
return self.delete_result
|
||||
|
||||
async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: Dict[str, Any]) -> SimpleNamespace: # pragma: no cover - unused by smoke tests
|
||||
return SimpleNamespace(payload={"success": True, "recipe_id": recipe_id, "updates": updates}, status=200)
|
||||
|
||||
async def reconnect_lora(self, *, recipe_scanner, recipe_id: str, lora_index: int, target_name: str) -> SimpleNamespace: # pragma: no cover
|
||||
return SimpleNamespace(payload={"success": True}, status=200)
|
||||
|
||||
async def bulk_delete(self, *, recipe_scanner, recipe_ids: List[str]) -> SimpleNamespace: # pragma: no cover
|
||||
return SimpleNamespace(payload={"success": True, "deleted": recipe_ids}, status=200)
|
||||
|
||||
async def save_recipe_from_widget(self, *, recipe_scanner, metadata: Dict[str, Any], image_bytes: bytes) -> SimpleNamespace: # pragma: no cover
|
||||
return SimpleNamespace(payload={"success": True}, status=200)
|
||||
|
||||
|
||||
class StubSharingService:
|
||||
"""Share service stub recording requests and returning canned responses."""
|
||||
|
||||
instances: List["StubSharingService"] = []
|
||||
|
||||
def __init__(self, *, ttl_seconds: int = 300, logger) -> None: # noqa: ARG002 - ttl unused in stub
|
||||
self.share_calls: List[str] = []
|
||||
self.download_calls: List[str] = []
|
||||
self.share_result = SimpleNamespace(
|
||||
payload={"success": True, "download_url": "/share/stub", "filename": "recipe.png"},
|
||||
status=200,
|
||||
)
|
||||
self.download_info = SimpleNamespace(file_path="", download_filename="")
|
||||
StubSharingService.instances.append(self)
|
||||
|
||||
async def share_recipe(self, *, recipe_scanner, recipe_id: str) -> SimpleNamespace:
|
||||
self.share_calls.append(recipe_id)
|
||||
return self.share_result
|
||||
|
||||
async def prepare_download(self, *, recipe_scanner, recipe_id: str) -> SimpleNamespace:
|
||||
self.download_calls.append(recipe_id)
|
||||
return self.download_info
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRouteHarness]:
|
||||
"""Context manager that yields a fully wired recipe route harness."""
|
||||
|
||||
StubAnalysisService.instances.clear()
|
||||
StubPersistenceService.instances.clear()
|
||||
StubSharingService.instances.clear()
|
||||
|
||||
scanner = StubRecipeScanner(tmp_path)
|
||||
|
||||
async def fake_get_recipe_scanner():
|
||||
return scanner
|
||||
|
||||
async def fake_get_civitai_client():
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(ServiceRegistry, "get_recipe_scanner", fake_get_recipe_scanner)
|
||||
monkeypatch.setattr(ServiceRegistry, "get_civitai_client", fake_get_civitai_client)
|
||||
monkeypatch.setattr(base_recipe_routes, "RecipeAnalysisService", StubAnalysisService)
|
||||
monkeypatch.setattr(base_recipe_routes, "RecipePersistenceService", StubPersistenceService)
|
||||
monkeypatch.setattr(base_recipe_routes, "RecipeSharingService", StubSharingService)
|
||||
monkeypatch.setattr(config, "loras_roots", [str(tmp_path)], raising=False)
|
||||
|
||||
app = web.Application()
|
||||
RecipeRoutes.setup_routes(app)
|
||||
|
||||
server = TestServer(app)
|
||||
client = TestClient(server)
|
||||
await client.start_server()
|
||||
|
||||
harness = RecipeRouteHarness(
|
||||
client=client,
|
||||
scanner=scanner,
|
||||
analysis=StubAnalysisService.instances[-1],
|
||||
persistence=StubPersistenceService.instances[-1],
|
||||
sharing=StubSharingService.instances[-1],
|
||||
tmp_dir=tmp_path,
|
||||
)
|
||||
|
||||
try:
|
||||
yield harness
|
||||
finally:
|
||||
await client.close()
|
||||
StubAnalysisService.instances.clear()
|
||||
StubPersistenceService.instances.clear()
|
||||
StubSharingService.instances.clear()
|
||||
|
||||
|
||||
async def test_list_recipes_provides_file_urls(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
recipe_path = harness.tmp_dir / "recipes" / "demo.png"
|
||||
harness.scanner.listing_items = [
|
||||
{
|
||||
"id": "recipe-1",
|
||||
"file_path": str(recipe_path),
|
||||
"title": "Demo",
|
||||
"loras": [],
|
||||
}
|
||||
]
|
||||
harness.scanner.cached_raw = list(harness.scanner.listing_items)
|
||||
|
||||
response = await harness.client.get("/api/lm/recipes")
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert payload["items"][0]["file_url"].endswith("demo.png")
|
||||
assert payload["items"][0]["loras"] == []
|
||||
|
||||
|
||||
async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
form = FormData()
|
||||
form.add_field("image", b"stub", filename="sample.png", content_type="image/png")
|
||||
form.add_field("name", "Test Recipe")
|
||||
form.add_field("tags", json.dumps(["tag-a"]))
|
||||
form.add_field("metadata", json.dumps({"loras": []}))
|
||||
form.add_field("image_base64", "aW1hZ2U=")
|
||||
|
||||
harness.persistence.save_result = SimpleNamespace(
|
||||
payload={"success": True, "recipe_id": "saved-id"},
|
||||
status=201,
|
||||
)
|
||||
|
||||
save_response = await harness.client.post("/api/lm/recipes/save", data=form)
|
||||
save_payload = await save_response.json()
|
||||
|
||||
assert save_response.status == 201
|
||||
assert save_payload["recipe_id"] == "saved-id"
|
||||
assert harness.persistence.save_calls[-1]["name"] == "Test Recipe"
|
||||
|
||||
harness.persistence.delete_result = SimpleNamespace(payload={"success": True}, status=200)
|
||||
|
||||
delete_response = await harness.client.delete("/api/lm/recipe/saved-id")
|
||||
delete_payload = await delete_response.json()
|
||||
|
||||
assert delete_response.status == 200
|
||||
assert delete_payload["success"] is True
|
||||
assert harness.persistence.delete_calls == ["saved-id"]
|
||||
|
||||
|
||||
async def test_analyze_uploaded_image_error_path(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
harness.analysis.raise_for_uploaded = RecipeValidationError("No image data provided")
|
||||
|
||||
form = FormData()
|
||||
form.add_field("image", b"", filename="empty.png", content_type="image/png")
|
||||
|
||||
response = await harness.client.post("/api/lm/recipes/analyze-image", data=form)
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 400
|
||||
assert payload["error"] == "No image data provided"
|
||||
assert payload["loras"] == []
|
||||
|
||||
|
||||
async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None:
|
||||
async with recipe_harness(monkeypatch, tmp_path) as harness:
|
||||
recipe_id = "share-me"
|
||||
download_path = harness.tmp_dir / "recipes" / "share.png"
|
||||
download_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
download_path.write_bytes(b"stub")
|
||||
|
||||
harness.scanner.recipes[recipe_id] = {
|
||||
"id": recipe_id,
|
||||
"title": "Shared",
|
||||
"file_path": str(download_path),
|
||||
}
|
||||
|
||||
harness.sharing.share_result = SimpleNamespace(
|
||||
payload={"success": True, "download_url": "/api/share", "filename": "share.png"},
|
||||
status=200,
|
||||
)
|
||||
harness.sharing.download_info = SimpleNamespace(
|
||||
file_path=str(download_path),
|
||||
download_filename="share.png",
|
||||
)
|
||||
|
||||
share_response = await harness.client.get(f"/api/lm/recipe/{recipe_id}/share")
|
||||
share_payload = await share_response.json()
|
||||
|
||||
assert share_response.status == 200
|
||||
assert share_payload["filename"] == "share.png"
|
||||
assert harness.sharing.share_calls == [recipe_id]
|
||||
|
||||
download_response = await harness.client.get(f"/api/lm/recipe/{recipe_id}/share/download")
|
||||
body = await download_response.read()
|
||||
|
||||
assert download_response.status == 200
|
||||
assert download_response.headers["Content-Disposition"] == 'attachment; filename="share.png"'
|
||||
assert body == b"stub"
|
||||
|
||||
download_path.unlink(missing_ok=True)
|
||||
|
||||
296
tests/services/test_base_model_service.py
Normal file
296
tests/services/test_base_model_service.py
Normal file
@@ -0,0 +1,296 @@
|
||||
import pytest
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
|
||||
def import_from(module_name: str):
|
||||
existing = sys.modules.get("py")
|
||||
if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"):
|
||||
sys.modules.pop("py", None)
|
||||
spec = importlib.util.spec_from_file_location("py", ROOT / "py/__init__.py")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
spec.loader.exec_module(module) # type: ignore[union-attr]
|
||||
module.__path__ = [str(ROOT / "py")]
|
||||
sys.modules["py"] = module
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
|
||||
BaseModelService = import_from("py.services.base_model_service").BaseModelService
|
||||
model_query_module = import_from("py.services.model_query")
|
||||
ModelCacheRepository = model_query_module.ModelCacheRepository
|
||||
ModelFilterSet = model_query_module.ModelFilterSet
|
||||
SearchStrategy = model_query_module.SearchStrategy
|
||||
SortParams = model_query_module.SortParams
|
||||
BaseModelMetadata = import_from("py.utils.models").BaseModelMetadata
|
||||
|
||||
|
||||
class StubSettings:
|
||||
def __init__(self, values):
|
||||
self._values = dict(values)
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self._values.get(key, default)
|
||||
|
||||
|
||||
class DummyService(BaseModelService):
|
||||
async def format_response(self, model_data):
|
||||
return model_data
|
||||
|
||||
|
||||
class StubRepository:
|
||||
def __init__(self, data):
|
||||
self._data = list(data)
|
||||
self.parse_sort_calls = []
|
||||
self.fetch_sorted_calls = []
|
||||
|
||||
def parse_sort(self, sort_by):
|
||||
params = ModelCacheRepository.parse_sort(sort_by)
|
||||
self.parse_sort_calls.append(sort_by)
|
||||
return params
|
||||
|
||||
async def fetch_sorted(self, params):
|
||||
self.fetch_sorted_calls.append(params)
|
||||
return list(self._data)
|
||||
|
||||
|
||||
class StubFilterSet:
|
||||
def __init__(self, result):
|
||||
self.result = list(result)
|
||||
self.calls = []
|
||||
|
||||
def apply(self, data, criteria):
|
||||
self.calls.append((list(data), criteria))
|
||||
return list(self.result)
|
||||
|
||||
|
||||
class StubSearchStrategy:
|
||||
def __init__(self, search_result):
|
||||
self.search_result = list(search_result)
|
||||
self.normalize_calls = []
|
||||
self.apply_calls = []
|
||||
|
||||
def normalize_options(self, options):
|
||||
self.normalize_calls.append(options)
|
||||
normalized = {"recursive": True}
|
||||
if options:
|
||||
normalized.update(options)
|
||||
return normalized
|
||||
|
||||
def apply(self, data, search_term, options, fuzzy):
|
||||
self.apply_calls.append((list(data), search_term, options, fuzzy))
|
||||
return list(self.search_result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_paginated_data_uses_injected_collaborators():
|
||||
data = [
|
||||
{"model_name": "Alpha", "folder": "root"},
|
||||
{"model_name": "Beta", "folder": "root"},
|
||||
]
|
||||
repository = StubRepository(data)
|
||||
filter_set = StubFilterSet([{"model_name": "Filtered"}])
|
||||
search_strategy = StubSearchStrategy([{"model_name": "SearchResult"}])
|
||||
settings = StubSettings({})
|
||||
|
||||
service = DummyService(
|
||||
model_type="stub",
|
||||
scanner=object(),
|
||||
metadata_class=BaseModelMetadata,
|
||||
cache_repository=repository,
|
||||
filter_set=filter_set,
|
||||
search_strategy=search_strategy,
|
||||
settings_provider=settings,
|
||||
)
|
||||
|
||||
response = await service.get_paginated_data(
|
||||
page=1,
|
||||
page_size=5,
|
||||
sort_by="name:desc",
|
||||
folder="root",
|
||||
search="query",
|
||||
fuzzy_search=True,
|
||||
base_models=["base"],
|
||||
tags=["tag"],
|
||||
search_options={"recursive": False},
|
||||
favorites_only=True,
|
||||
)
|
||||
|
||||
assert repository.parse_sort_calls == ["name:desc"]
|
||||
assert repository.fetch_sorted_calls and isinstance(repository.fetch_sorted_calls[0], SortParams)
|
||||
sort_params = repository.fetch_sorted_calls[0]
|
||||
assert sort_params.key == "name" and sort_params.order == "desc"
|
||||
|
||||
assert filter_set.calls, "FilterSet should be invoked"
|
||||
call_data, criteria = filter_set.calls[0]
|
||||
assert call_data == data
|
||||
assert criteria.folder == "root"
|
||||
assert criteria.base_models == ["base"]
|
||||
assert criteria.tags == ["tag"]
|
||||
assert criteria.favorites_only is True
|
||||
assert criteria.search_options.get("recursive") is False
|
||||
|
||||
assert search_strategy.normalize_calls == [{"recursive": False}, {"recursive": False}]
|
||||
assert search_strategy.apply_calls == [([{"model_name": "Filtered"}], "query", {"recursive": False}, True)]
|
||||
|
||||
assert response["items"] == search_strategy.search_result
|
||||
assert response["total"] == len(search_strategy.search_result)
|
||||
assert response["page"] == 1
|
||||
assert response["page_size"] == 5
|
||||
|
||||
|
||||
class FakeCache:
|
||||
def __init__(self, items):
|
||||
self.items = list(items)
|
||||
|
||||
async def get_sorted_data(self, sort_key, order):
|
||||
if sort_key == "name":
|
||||
data = sorted(self.items, key=lambda x: x["model_name"].lower())
|
||||
if order == "desc":
|
||||
data.reverse()
|
||||
else:
|
||||
data = list(self.items)
|
||||
return data
|
||||
|
||||
|
||||
class FakeScanner:
|
||||
def __init__(self, cache):
|
||||
self._cache = cache
|
||||
|
||||
async def get_cached_data(self, *_, **__):
|
||||
return self._cache
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_paginated_data_filters_and_searches_combination():
|
||||
items = [
|
||||
{
|
||||
"model_name": "Alpha",
|
||||
"file_name": "alpha.safetensors",
|
||||
"folder": "root/sub",
|
||||
"tags": ["tag1"],
|
||||
"base_model": "v1",
|
||||
"favorite": True,
|
||||
"preview_nsfw_level": 0,
|
||||
},
|
||||
{
|
||||
"model_name": "Beta",
|
||||
"file_name": "beta.safetensors",
|
||||
"folder": "root",
|
||||
"tags": ["tag2"],
|
||||
"base_model": "v2",
|
||||
"favorite": False,
|
||||
"preview_nsfw_level": 999,
|
||||
},
|
||||
{
|
||||
"model_name": "Gamma",
|
||||
"file_name": "gamma.safetensors",
|
||||
"folder": "root/sub2",
|
||||
"tags": ["tag1", "tag3"],
|
||||
"base_model": "v1",
|
||||
"favorite": True,
|
||||
"preview_nsfw_level": 0,
|
||||
"civitai": {"creator": {"username": "artist"}},
|
||||
},
|
||||
]
|
||||
|
||||
cache = FakeCache(items)
|
||||
scanner = FakeScanner(cache)
|
||||
settings = StubSettings({"show_only_sfw": True})
|
||||
|
||||
service = DummyService(
|
||||
model_type="stub",
|
||||
scanner=scanner,
|
||||
metadata_class=BaseModelMetadata,
|
||||
cache_repository=ModelCacheRepository(scanner),
|
||||
filter_set=ModelFilterSet(settings),
|
||||
search_strategy=SearchStrategy(),
|
||||
settings_provider=settings,
|
||||
)
|
||||
|
||||
response = await service.get_paginated_data(
|
||||
page=1,
|
||||
page_size=1,
|
||||
sort_by="name:asc",
|
||||
folder="root",
|
||||
search="artist",
|
||||
base_models=["v1"],
|
||||
tags=["tag1"],
|
||||
search_options={"creator": True, "tags": True},
|
||||
favorites_only=True,
|
||||
)
|
||||
|
||||
assert response["items"] == [items[2]]
|
||||
assert response["total"] == 1
|
||||
assert response["page"] == 1
|
||||
assert response["page_size"] == 1
|
||||
assert response["total_pages"] == 1
|
||||
|
||||
|
||||
class PassThroughFilterSet:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def apply(self, data, criteria):
|
||||
self.calls.append(criteria)
|
||||
return list(data)
|
||||
|
||||
|
||||
class NoSearchStrategy:
|
||||
def __init__(self):
|
||||
self.normalize_calls = []
|
||||
self.apply_called = False
|
||||
|
||||
def normalize_options(self, options):
|
||||
self.normalize_calls.append(options)
|
||||
return {"recursive": True}
|
||||
|
||||
def apply(self, *args, **kwargs):
|
||||
self.apply_called = True
|
||||
pytest.fail("Search should not be invoked when no search term is provided")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_paginated_data_paginates_without_search():
|
||||
items = [
|
||||
{"model_name": name, "folder": "root"}
|
||||
for name in ["Alpha", "Beta", "Gamma", "Delta", "Epsilon"]
|
||||
]
|
||||
|
||||
repository = StubRepository(items)
|
||||
filter_set = PassThroughFilterSet()
|
||||
search_strategy = NoSearchStrategy()
|
||||
settings = StubSettings({})
|
||||
|
||||
service = DummyService(
|
||||
model_type="stub",
|
||||
scanner=object(),
|
||||
metadata_class=BaseModelMetadata,
|
||||
cache_repository=repository,
|
||||
filter_set=filter_set,
|
||||
search_strategy=search_strategy,
|
||||
settings_provider=settings,
|
||||
)
|
||||
|
||||
response = await service.get_paginated_data(
|
||||
page=2,
|
||||
page_size=2,
|
||||
sort_by="name:asc",
|
||||
)
|
||||
|
||||
assert repository.parse_sort_calls == ["name:asc"]
|
||||
assert len(repository.fetch_sorted_calls) == 1
|
||||
assert filter_set.calls and filter_set.calls[0].favorites_only is False
|
||||
assert search_strategy.apply_called is False
|
||||
assert response["items"] == items[2:4]
|
||||
assert response["total"] == len(items)
|
||||
assert response["page"] == 2
|
||||
assert response["page_size"] == 2
|
||||
assert response["total_pages"] == 3
|
||||
228
tests/services/test_example_images_download_manager_async.py
Normal file
228
tests/services/test_example_images_download_manager_async.py
Normal file
@@ -0,0 +1,228 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.settings_manager import settings
|
||||
from py.utils import example_images_download_manager as download_module
|
||||
|
||||
|
||||
class RecordingWebSocketManager:
|
||||
"""Collects broadcast payloads for assertions."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.payloads: list[dict] = []
|
||||
|
||||
async def broadcast(self, payload: dict) -> None:
|
||||
self.payloads.append(payload)
|
||||
|
||||
|
||||
class StubScanner:
|
||||
"""Scanner double returning predetermined cache contents."""
|
||||
|
||||
def __init__(self, models: list[dict]) -> None:
|
||||
self._cache = SimpleNamespace(raw_data=models)
|
||||
|
||||
async def get_cached_data(self):
|
||||
return self._cache
|
||||
|
||||
|
||||
def _patch_scanner(monkeypatch: pytest.MonkeyPatch, scanner: StubScanner) -> None:
|
||||
async def _get_lora_scanner(cls):
|
||||
return scanner
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_module.ServiceRegistry,
|
||||
"get_lora_scanner",
|
||||
classmethod(_get_lora_scanner),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_start_download_rejects_parallel_runs(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
model = {
|
||||
"sha256": "abc123",
|
||||
"model_name": "Example",
|
||||
"file_path": str(tmp_path / "example.safetensors"),
|
||||
"file_name": "example.safetensors",
|
||||
}
|
||||
_patch_scanner(monkeypatch, StubScanner([model]))
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def fake_process_local_examples(*_args, **_kwargs):
|
||||
started.set()
|
||||
await release.wait()
|
||||
return True
|
||||
|
||||
async def fake_update_metadata(*_args, **_kwargs):
|
||||
return True
|
||||
|
||||
async def fake_get_downloader():
|
||||
return object()
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_module.ExampleImagesProcessor,
|
||||
"process_local_examples",
|
||||
staticmethod(fake_process_local_examples),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
download_module.MetadataUpdater,
|
||||
"update_metadata_from_local_examples",
|
||||
staticmethod(fake_update_metadata),
|
||||
)
|
||||
monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)
|
||||
|
||||
try:
|
||||
result = await manager.start_download({"model_types": ["lora"], "delay": 0})
|
||||
assert result["success"] is True
|
||||
|
||||
await asyncio.wait_for(started.wait(), timeout=1)
|
||||
|
||||
with pytest.raises(download_module.DownloadInProgressError) as exc:
|
||||
await manager.start_download({"model_types": ["lora"], "delay": 0})
|
||||
|
||||
snapshot = exc.value.progress_snapshot
|
||||
assert snapshot["status"] == "running"
|
||||
assert snapshot["current_model"] == "Example (abc123)"
|
||||
|
||||
statuses = [payload["status"] for payload in ws_manager.payloads]
|
||||
assert "running" in statuses
|
||||
|
||||
finally:
|
||||
release.set()
|
||||
if manager._download_task is not None:
|
||||
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
models = [
|
||||
{
|
||||
"sha256": "hash-one",
|
||||
"model_name": "Model One",
|
||||
"file_path": str(tmp_path / "model-one.safetensors"),
|
||||
"file_name": "model-one.safetensors",
|
||||
"civitai": {"images": [{"url": "https://example.com/one.png"}]},
|
||||
},
|
||||
{
|
||||
"sha256": "hash-two",
|
||||
"model_name": "Model Two",
|
||||
"file_path": str(tmp_path / "model-two.safetensors"),
|
||||
"file_name": "model-two.safetensors",
|
||||
"civitai": {"images": [{"url": "https://example.com/two.png"}]},
|
||||
},
|
||||
]
|
||||
_patch_scanner(monkeypatch, StubScanner(models))
|
||||
|
||||
async def fake_process_local_examples(*_args, **_kwargs):
|
||||
return False
|
||||
|
||||
async def fake_update_metadata(*_args, **_kwargs):
|
||||
return True
|
||||
|
||||
first_call_started = asyncio.Event()
|
||||
first_release = asyncio.Event()
|
||||
second_call_started = asyncio.Event()
|
||||
call_order: list[str] = []
|
||||
|
||||
async def fake_download_model_images(model_hash, *_args, **_kwargs):
|
||||
call_order.append(model_hash)
|
||||
if len(call_order) == 1:
|
||||
first_call_started.set()
|
||||
await first_release.wait()
|
||||
else:
|
||||
second_call_started.set()
|
||||
return True, False
|
||||
|
||||
async def fake_get_downloader():
|
||||
class _Downloader:
|
||||
async def download_to_memory(self, *_a, **_kw):
|
||||
return True, b"", {}
|
||||
|
||||
return _Downloader()
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_module.ExampleImagesProcessor,
|
||||
"process_local_examples",
|
||||
staticmethod(fake_process_local_examples),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
download_module.MetadataUpdater,
|
||||
"update_metadata_from_local_examples",
|
||||
staticmethod(fake_update_metadata),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
download_module.ExampleImagesProcessor,
|
||||
"download_model_images",
|
||||
staticmethod(fake_download_model_images),
|
||||
)
|
||||
monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader)
|
||||
|
||||
original_sleep = download_module.asyncio.sleep
|
||||
pause_gate = asyncio.Event()
|
||||
resume_gate = asyncio.Event()
|
||||
|
||||
async def fake_sleep(delay: float):
|
||||
if delay == 1:
|
||||
pause_gate.set()
|
||||
await resume_gate.wait()
|
||||
else:
|
||||
await original_sleep(delay)
|
||||
|
||||
monkeypatch.setattr(download_module.asyncio, "sleep", fake_sleep)
|
||||
|
||||
try:
|
||||
await manager.start_download({"model_types": ["lora"], "delay": 0})
|
||||
|
||||
await asyncio.wait_for(first_call_started.wait(), timeout=1)
|
||||
|
||||
await manager.pause_download({})
|
||||
|
||||
first_release.set()
|
||||
|
||||
await asyncio.wait_for(pause_gate.wait(), timeout=1)
|
||||
assert manager._progress["status"] == "paused"
|
||||
assert not second_call_started.is_set()
|
||||
|
||||
statuses = [payload["status"] for payload in ws_manager.payloads]
|
||||
paused_index = statuses.index("paused")
|
||||
|
||||
await asyncio.sleep(0)
|
||||
assert not second_call_started.is_set()
|
||||
|
||||
await manager.resume_download({})
|
||||
resume_gate.set()
|
||||
|
||||
await asyncio.wait_for(second_call_started.wait(), timeout=1)
|
||||
|
||||
if manager._download_task is not None:
|
||||
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||
|
||||
statuses_after = [payload["status"] for payload in ws_manager.payloads]
|
||||
running_after = next(
|
||||
i for i, status in enumerate(statuses_after[paused_index + 1 :], start=paused_index + 1) if status == "running"
|
||||
)
|
||||
assert running_after > paused_index
|
||||
assert "completed" in statuses_after[running_after:]
|
||||
assert call_order == ["hash-one", "hash-two"]
|
||||
|
||||
finally:
|
||||
first_release.set()
|
||||
resume_gate.set()
|
||||
if manager._download_task is not None:
|
||||
await asyncio.wait_for(manager._download_task, timeout=1)
|
||||
monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep)
|
||||
185
tests/services/test_recipe_scanner.py
Normal file
185
tests/services/test_recipe_scanner.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.config import config
|
||||
from py.services.recipe_scanner import RecipeScanner
|
||||
from py.utils.utils import calculate_recipe_fingerprint
|
||||
|
||||
|
||||
class StubHashIndex:
|
||||
def __init__(self) -> None:
|
||||
self._hash_to_path: dict[str, str] = {}
|
||||
|
||||
def get_path(self, hash_value: str) -> str | None:
|
||||
return self._hash_to_path.get(hash_value)
|
||||
|
||||
|
||||
class StubLoraScanner:
|
||||
def __init__(self) -> None:
|
||||
self._hash_index = StubHashIndex()
|
||||
self._hash_meta: dict[str, dict[str, str]] = {}
|
||||
self._models_by_name: dict[str, dict] = {}
|
||||
self._cache = SimpleNamespace(raw_data=[])
|
||||
|
||||
async def get_cached_data(self):
|
||||
return self._cache
|
||||
|
||||
def has_hash(self, hash_value: str) -> bool:
|
||||
return hash_value.lower() in self._hash_meta
|
||||
|
||||
def get_preview_url_by_hash(self, hash_value: str) -> str:
|
||||
meta = self._hash_meta.get(hash_value.lower())
|
||||
return meta.get("preview_url", "") if meta else ""
|
||||
|
||||
def get_path_by_hash(self, hash_value: str) -> str | None:
|
||||
meta = self._hash_meta.get(hash_value.lower())
|
||||
return meta.get("path") if meta else None
|
||||
|
||||
async def get_model_info_by_name(self, name: str):
|
||||
return self._models_by_name.get(name)
|
||||
|
||||
def register_model(self, name: str, info: dict) -> None:
|
||||
self._models_by_name[name] = info
|
||||
hash_value = (info.get("sha256") or "").lower()
|
||||
if hash_value:
|
||||
self._hash_meta[hash_value] = {
|
||||
"path": info.get("file_path", ""),
|
||||
"preview_url": info.get("preview_url", ""),
|
||||
}
|
||||
self._hash_index._hash_to_path[hash_value] = info.get("file_path", "")
|
||||
self._cache.raw_data.append({
|
||||
"sha256": info.get("sha256", ""),
|
||||
"path": info.get("file_path", ""),
|
||||
"civitai": info.get("civitai", {}),
|
||||
})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def recipe_scanner(tmp_path: Path, monkeypatch):
|
||||
RecipeScanner._instance = None
|
||||
monkeypatch.setattr(config, "loras_roots", [str(tmp_path)])
|
||||
stub = StubLoraScanner()
|
||||
scanner = RecipeScanner(lora_scanner=stub)
|
||||
asyncio.run(scanner.refresh_cache(force=True))
|
||||
yield scanner, stub
|
||||
RecipeScanner._instance = None
|
||||
|
||||
|
||||
async def test_add_recipe_during_concurrent_reads(recipe_scanner):
|
||||
scanner, _ = recipe_scanner
|
||||
|
||||
initial_recipe = {
|
||||
"id": "one",
|
||||
"file_path": "path/a.png",
|
||||
"title": "First",
|
||||
"modified": 1.0,
|
||||
"created_date": 1.0,
|
||||
"loras": [],
|
||||
}
|
||||
await scanner.add_recipe(initial_recipe)
|
||||
|
||||
new_recipe = {
|
||||
"id": "two",
|
||||
"file_path": "path/b.png",
|
||||
"title": "Second",
|
||||
"modified": 2.0,
|
||||
"created_date": 2.0,
|
||||
"loras": [],
|
||||
}
|
||||
|
||||
async def reader_task():
|
||||
for _ in range(5):
|
||||
cache = await scanner.get_cached_data()
|
||||
_ = [item["id"] for item in cache.raw_data]
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await asyncio.gather(reader_task(), reader_task(), scanner.add_recipe(new_recipe))
|
||||
await asyncio.sleep(0)
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
assert {item["id"] for item in cache.raw_data} == {"one", "two"}
|
||||
assert len(cache.sorted_by_name) == len(cache.raw_data)
|
||||
|
||||
|
||||
async def test_remove_recipe_during_reads(recipe_scanner):
|
||||
scanner, _ = recipe_scanner
|
||||
|
||||
recipe_ids = ["alpha", "beta", "gamma"]
|
||||
for index, recipe_id in enumerate(recipe_ids):
|
||||
await scanner.add_recipe({
|
||||
"id": recipe_id,
|
||||
"file_path": f"path/{recipe_id}.png",
|
||||
"title": recipe_id,
|
||||
"modified": float(index),
|
||||
"created_date": float(index),
|
||||
"loras": [],
|
||||
})
|
||||
|
||||
async def reader_task():
|
||||
for _ in range(5):
|
||||
cache = await scanner.get_cached_data()
|
||||
_ = list(cache.sorted_by_date)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
await asyncio.gather(reader_task(), scanner.remove_recipe("beta"))
|
||||
await asyncio.sleep(0)
|
||||
cache = await scanner.get_cached_data()
|
||||
|
||||
assert {item["id"] for item in cache.raw_data} == {"alpha", "gamma"}
|
||||
|
||||
|
||||
async def test_update_lora_entry_updates_cache_and_file(tmp_path: Path, recipe_scanner):
|
||||
scanner, stub = recipe_scanner
|
||||
recipes_dir = Path(config.loras_roots[0]) / "recipes"
|
||||
recipes_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
recipe_id = "recipe-1"
|
||||
recipe_path = recipes_dir / f"{recipe_id}.recipe.json"
|
||||
recipe_data = {
|
||||
"id": recipe_id,
|
||||
"file_path": str(tmp_path / "image.png"),
|
||||
"title": "Original",
|
||||
"modified": 0.0,
|
||||
"created_date": 0.0,
|
||||
"loras": [
|
||||
{"file_name": "old", "strength": 1.0, "hash": "", "isDeleted": True, "exclude": True},
|
||||
],
|
||||
}
|
||||
recipe_path.write_text(json.dumps(recipe_data))
|
||||
|
||||
await scanner.add_recipe(dict(recipe_data))
|
||||
|
||||
target_hash = "abc123"
|
||||
target_info = {
|
||||
"sha256": target_hash,
|
||||
"file_path": str(tmp_path / "loras" / "target.safetensors"),
|
||||
"preview_url": "preview.png",
|
||||
"civitai": {"id": 42, "name": "v1", "model": {"name": "Target"}},
|
||||
}
|
||||
stub.register_model("target", target_info)
|
||||
|
||||
updated_recipe, updated_lora = await scanner.update_lora_entry(
|
||||
recipe_id,
|
||||
0,
|
||||
target_name="target",
|
||||
target_lora=target_info,
|
||||
)
|
||||
|
||||
assert updated_lora["inLibrary"] is True
|
||||
assert updated_lora["localPath"] == target_info["file_path"]
|
||||
assert updated_lora["hash"] == target_hash
|
||||
|
||||
with recipe_path.open("r", encoding="utf-8") as file_obj:
|
||||
persisted = json.load(file_obj)
|
||||
|
||||
expected_fingerprint = calculate_recipe_fingerprint(persisted["loras"])
|
||||
assert persisted["fingerprint"] == expected_fingerprint
|
||||
|
||||
cache = await scanner.get_cached_data()
|
||||
cached_recipe = next(item for item in cache.raw_data if item["id"] == recipe_id)
|
||||
assert cached_recipe["loras"][0]["hash"] == target_hash
|
||||
assert cached_recipe["fingerprint"] == expected_fingerprint
|
||||
150
tests/services/test_recipe_services.py
Normal file
150
tests/services/test_recipe_services.py
Normal file
@@ -0,0 +1,150 @@
|
||||
import logging
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.recipes.analysis_service import RecipeAnalysisService
|
||||
from py.services.recipes.errors import RecipeDownloadError, RecipeNotFoundError
|
||||
from py.services.recipes.persistence_service import RecipePersistenceService
|
||||
|
||||
|
||||
class DummyExifUtils:
|
||||
def optimize_image(self, image_data, target_width, format, quality, preserve_metadata):
|
||||
return image_data, ".webp"
|
||||
|
||||
def append_recipe_metadata(self, image_path, recipe_data):
|
||||
self.appended = (image_path, recipe_data)
|
||||
|
||||
def extract_image_metadata(self, path):
|
||||
return {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_remote_image_download_failure_cleans_temp(tmp_path, monkeypatch):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
class DummyFactory:
|
||||
def create_parser(self, metadata):
|
||||
return None
|
||||
|
||||
async def downloader_factory():
|
||||
class Downloader:
|
||||
async def download_file(self, url, path, use_auth=False):
|
||||
return False, "failure"
|
||||
|
||||
return Downloader()
|
||||
|
||||
service = RecipeAnalysisService(
|
||||
exif_utils=exif_utils,
|
||||
recipe_parser_factory=DummyFactory(),
|
||||
downloader_factory=downloader_factory,
|
||||
metadata_collector=None,
|
||||
metadata_processor_cls=None,
|
||||
metadata_registry_cls=None,
|
||||
standalone_mode=False,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
temp_path = tmp_path / "temp.jpg"
|
||||
|
||||
def create_temp_path():
|
||||
temp_path.write_bytes(b"")
|
||||
return str(temp_path)
|
||||
|
||||
monkeypatch.setattr(service, "_create_temp_path", create_temp_path)
|
||||
|
||||
with pytest.raises(RecipeDownloadError):
|
||||
await service.analyze_remote_image(
|
||||
url="https://example.com/image.jpg",
|
||||
recipe_scanner=SimpleNamespace(),
|
||||
civitai_client=SimpleNamespace(),
|
||||
)
|
||||
|
||||
assert not temp_path.exists(), "temporary file should be cleaned after failure"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_local_image_missing_file(tmp_path):
|
||||
async def downloader_factory():
|
||||
return SimpleNamespace()
|
||||
|
||||
service = RecipeAnalysisService(
|
||||
exif_utils=DummyExifUtils(),
|
||||
recipe_parser_factory=SimpleNamespace(create_parser=lambda metadata: None),
|
||||
downloader_factory=downloader_factory,
|
||||
metadata_collector=None,
|
||||
metadata_processor_cls=None,
|
||||
metadata_registry_cls=None,
|
||||
standalone_mode=False,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
with pytest.raises(RecipeNotFoundError):
|
||||
await service.analyze_local_image(
|
||||
file_path=str(tmp_path / "missing.png"),
|
||||
recipe_scanner=SimpleNamespace(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_recipe_reports_duplicates(tmp_path):
|
||||
exif_utils = DummyExifUtils()
|
||||
|
||||
class DummyCache:
|
||||
def __init__(self):
|
||||
self.raw_data = []
|
||||
|
||||
async def resort(self):
|
||||
pass
|
||||
|
||||
class DummyScanner:
|
||||
def __init__(self, root):
|
||||
self.recipes_dir = str(root)
|
||||
self._cache = DummyCache()
|
||||
self.last_fingerprint = None
|
||||
|
||||
async def find_recipes_by_fingerprint(self, fingerprint):
|
||||
self.last_fingerprint = fingerprint
|
||||
return ["existing"]
|
||||
|
||||
async def add_recipe(self, recipe_data):
|
||||
self._cache.raw_data.append(recipe_data)
|
||||
await self._cache.resort()
|
||||
|
||||
scanner = DummyScanner(tmp_path)
|
||||
service = RecipePersistenceService(
|
||||
exif_utils=exif_utils,
|
||||
card_preview_width=512,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"base_model": "sd",
|
||||
"loras": [
|
||||
{
|
||||
"file_name": "sample",
|
||||
"hash": "abc123",
|
||||
"weight": 0.5,
|
||||
"id": 1,
|
||||
"name": "Sample",
|
||||
"version": "v1",
|
||||
"isDeleted": False,
|
||||
"exclude": False,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
result = await service.save_recipe(
|
||||
recipe_scanner=scanner,
|
||||
image_bytes=b"image-bytes",
|
||||
image_base64=None,
|
||||
name="My Recipe",
|
||||
tags=["tag"],
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
assert result.payload["matching_recipes"] == ["existing"]
|
||||
assert scanner.last_fingerprint is not None
|
||||
assert os.path.exists(result.payload["json_path"])
|
||||
assert scanner._cache.raw_data
|
||||
273
tests/services/test_route_support_services.py
Normal file
273
tests/services/test_route_support_services.py
Normal file
@@ -0,0 +1,273 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(ROOT))
|
||||
|
||||
import importlib
|
||||
import importlib.util
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def import_from(module_name: str):
|
||||
existing = sys.modules.get("py")
|
||||
if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"):
|
||||
sys.modules.pop("py", None)
|
||||
spec = importlib.util.spec_from_file_location("py", ROOT / "py/__init__.py")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
assert spec and spec.loader
|
||||
spec.loader.exec_module(module) # type: ignore[union-attr]
|
||||
module.__path__ = [str(ROOT / "py")]
|
||||
sys.modules["py"] = module
|
||||
return importlib.import_module(module_name)
|
||||
|
||||
|
||||
DownloadCoordinator = import_from("py.services.download_coordinator").DownloadCoordinator
|
||||
MetadataSyncService = import_from("py.services.metadata_sync_service").MetadataSyncService
|
||||
PreviewAssetService = import_from("py.services.preview_asset_service").PreviewAssetService
|
||||
TagUpdateService = import_from("py.services.tag_update_service").TagUpdateService
|
||||
|
||||
|
||||
class DummySettings:
|
||||
def __init__(self, values: Dict[str, Any] | None = None) -> None:
|
||||
self._values = values or {}
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self._values.get(key, default)
|
||||
|
||||
|
||||
class RecordingMetadataManager:
|
||||
def __init__(self) -> None:
|
||||
self.saved: List[tuple[str, Dict[str, Any]]] = []
|
||||
|
||||
async def save_metadata(self, path: str, metadata: Dict[str, Any]) -> bool:
|
||||
self.saved.append((path, json.loads(json.dumps(metadata))))
|
||||
metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json"
|
||||
Path(metadata_path).write_text(json.dumps(metadata))
|
||||
return True
|
||||
|
||||
|
||||
class RecordingPreviewService:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[tuple[str, List[Dict[str, Any]]]] = []
|
||||
|
||||
async def ensure_preview_for_metadata(
|
||||
self, metadata_path: str, local_metadata: Dict[str, Any], images
|
||||
) -> None:
|
||||
self.calls.append((metadata_path, list(images or [])))
|
||||
local_metadata["preview_url"] = "preview.webp"
|
||||
local_metadata["preview_nsfw_level"] = 1
|
||||
|
||||
|
||||
class DummyProvider:
|
||||
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||
self.payload = payload
|
||||
|
||||
async def get_model_by_hash(self, sha256: str):
|
||||
return self.payload, None
|
||||
|
||||
async def get_model_version(self, model_id: int, model_version_id: int | None):
|
||||
return self.payload
|
||||
|
||||
|
||||
class FakeExifUtils:
|
||||
@staticmethod
|
||||
def optimize_image(**kwargs):
|
||||
return kwargs["image_data"], {}
|
||||
|
||||
|
||||
def test_metadata_sync_merges_remote_fields(tmp_path: Path) -> None:
|
||||
manager = RecordingMetadataManager()
|
||||
preview = RecordingPreviewService()
|
||||
provider = DummyProvider({
|
||||
"baseModel": "SD15",
|
||||
"model": {"name": "Merged", "description": "desc", "tags": ["tag"], "creator": {"username": "user"}},
|
||||
"trainedWords": ["word"],
|
||||
"images": [{"url": "http://example", "nsfwLevel": 2, "type": "image"}],
|
||||
})
|
||||
|
||||
service = MetadataSyncService(
|
||||
metadata_manager=manager,
|
||||
preview_service=preview,
|
||||
settings=DummySettings(),
|
||||
default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider),
|
||||
metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider),
|
||||
)
|
||||
|
||||
metadata_path = str(tmp_path / "model.metadata.json")
|
||||
local_metadata = {"civitai": {"trainedWords": ["existing"]}}
|
||||
|
||||
updated = asyncio.run(service.update_model_metadata(metadata_path, local_metadata, provider.payload))
|
||||
|
||||
assert updated["model_name"] == "Merged"
|
||||
assert updated["modelDescription"] == "desc"
|
||||
assert set(updated["civitai"]["trainedWords"]) == {"existing", "word"}
|
||||
assert manager.saved
|
||||
assert preview.calls
|
||||
|
||||
|
||||
def test_metadata_sync_fetch_and_update_updates_cache(tmp_path: Path) -> None:
|
||||
manager = RecordingMetadataManager()
|
||||
preview = RecordingPreviewService()
|
||||
provider = DummyProvider({
|
||||
"baseModel": "SDXL",
|
||||
"model": {"name": "Updated"},
|
||||
"images": [],
|
||||
})
|
||||
|
||||
update_cache_calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
|
||||
update_cache_calls.append({"original": original, "metadata": metadata})
|
||||
return True
|
||||
|
||||
service = MetadataSyncService(
|
||||
metadata_manager=manager,
|
||||
preview_service=preview,
|
||||
settings=DummySettings(),
|
||||
default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider),
|
||||
metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider),
|
||||
)
|
||||
|
||||
model_data = {"sha256": "abc", "file_path": str(tmp_path / "model.safetensors")}
|
||||
success, error = asyncio.run(
|
||||
service.fetch_and_update_model(
|
||||
sha256="abc",
|
||||
file_path=str(tmp_path / "model.safetensors"),
|
||||
model_data=model_data,
|
||||
update_cache_func=update_cache,
|
||||
)
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert error is None
|
||||
assert update_cache_calls
|
||||
assert manager.saved
|
||||
|
||||
|
||||
def test_preview_asset_service_replace_preview(tmp_path: Path) -> None:
|
||||
metadata_path = tmp_path / "sample.metadata.json"
|
||||
metadata_path.write_text(json.dumps({}))
|
||||
|
||||
async def metadata_loader(path: str) -> Dict[str, Any]:
|
||||
return json.loads(Path(path).read_text())
|
||||
|
||||
manager = RecordingMetadataManager()
|
||||
|
||||
service = PreviewAssetService(
|
||||
metadata_manager=manager,
|
||||
downloader_factory=lambda: asyncio.sleep(0, result=None),
|
||||
exif_utils=FakeExifUtils(),
|
||||
)
|
||||
|
||||
preview_calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def update_preview(model_path: str, preview_path: str, nsfw: int) -> bool:
|
||||
preview_calls.append({"model_path": model_path, "preview_path": preview_path, "nsfw": nsfw})
|
||||
return True
|
||||
|
||||
model_path = str(tmp_path / "sample.safetensors")
|
||||
Path(model_path).write_bytes(b"model")
|
||||
|
||||
result = asyncio.run(
|
||||
service.replace_preview(
|
||||
model_path=model_path,
|
||||
preview_data=b"image-bytes",
|
||||
content_type="image/png",
|
||||
original_filename="preview.png",
|
||||
nsfw_level=2,
|
||||
update_preview_in_cache=update_preview,
|
||||
metadata_loader=metadata_loader,
|
||||
)
|
||||
)
|
||||
|
||||
assert result["preview_nsfw_level"] == 2
|
||||
assert preview_calls
|
||||
saved_metadata = json.loads(metadata_path.read_text())
|
||||
assert saved_metadata["preview_nsfw_level"] == 2
|
||||
|
||||
|
||||
def test_download_coordinator_emits_progress() -> None:
|
||||
class WSStub:
|
||||
def __init__(self) -> None:
|
||||
self.progress_events: List[Dict[str, Any]] = []
|
||||
self.counter = 0
|
||||
|
||||
def generate_download_id(self) -> str:
|
||||
self.counter += 1
|
||||
return f"dl-{self.counter}"
|
||||
|
||||
async def broadcast_download_progress(self, download_id: str, payload: Dict[str, Any]) -> None:
|
||||
self.progress_events.append({"id": download_id, **payload})
|
||||
|
||||
class DownloadManagerStub:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def download_from_civitai(self, **kwargs) -> Dict[str, Any]:
|
||||
self.calls.append(kwargs)
|
||||
await kwargs["progress_callback"](10)
|
||||
return {"success": True}
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict[str, Any]:
|
||||
return {"success": True, "download_id": download_id}
|
||||
|
||||
async def get_active_downloads(self) -> Dict[str, Any]:
|
||||
return {"active": []}
|
||||
|
||||
ws_stub = WSStub()
|
||||
manager_stub = DownloadManagerStub()
|
||||
|
||||
coordinator = DownloadCoordinator(
|
||||
ws_manager=ws_stub,
|
||||
download_manager_factory=lambda: asyncio.sleep(0, result=manager_stub),
|
||||
)
|
||||
|
||||
result = asyncio.run(coordinator.schedule_download({"model_id": 1}))
|
||||
|
||||
assert result["success"] is True
|
||||
assert manager_stub.calls
|
||||
assert ws_stub.progress_events
|
||||
|
||||
cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"]))
|
||||
assert cancel_result["success"] is True
|
||||
|
||||
active = asyncio.run(coordinator.list_active_downloads())
|
||||
assert active == {"active": []}
|
||||
|
||||
|
||||
def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None:
|
||||
metadata_path = tmp_path / "model.metadata.json"
|
||||
metadata_path.write_text(json.dumps({"tags": ["Existing"]}))
|
||||
|
||||
async def loader(path: str) -> Dict[str, Any]:
|
||||
return json.loads(Path(path).read_text())
|
||||
|
||||
manager = RecordingMetadataManager()
|
||||
|
||||
service = TagUpdateService(metadata_manager=manager)
|
||||
|
||||
cache_updates: List[Dict[str, Any]] = []
|
||||
|
||||
async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool:
|
||||
cache_updates.append(metadata)
|
||||
return True
|
||||
|
||||
tags = asyncio.run(
|
||||
service.add_tags(
|
||||
file_path=str(tmp_path / "model.safetensors"),
|
||||
new_tags=["New", "existing"],
|
||||
metadata_loader=loader,
|
||||
update_cache=update_cache,
|
||||
)
|
||||
)
|
||||
|
||||
assert tags == ["Existing", "New"]
|
||||
assert manager.saved
|
||||
assert cache_updates
|
||||
317
tests/services/test_use_cases.py
Normal file
317
tests/services/test_use_cases.py
Normal file
@@ -0,0 +1,317 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from py_local.services.model_file_service import AutoOrganizeResult
|
||||
from py_local.services.use_cases import (
|
||||
AutoOrganizeInProgressError,
|
||||
AutoOrganizeUseCase,
|
||||
BulkMetadataRefreshUseCase,
|
||||
DownloadExampleImagesConfigurationError,
|
||||
DownloadExampleImagesInProgressError,
|
||||
DownloadExampleImagesUseCase,
|
||||
DownloadModelEarlyAccessError,
|
||||
DownloadModelUseCase,
|
||||
DownloadModelValidationError,
|
||||
ImportExampleImagesUseCase,
|
||||
ImportExampleImagesValidationError,
|
||||
)
|
||||
from py_local.utils.example_images_download_manager import (
|
||||
DownloadConfigurationError,
|
||||
DownloadInProgressError,
|
||||
ExampleImagesDownloadError,
|
||||
)
|
||||
from py_local.utils.example_images_processor import (
|
||||
ExampleImagesImportError,
|
||||
ExampleImagesValidationError,
|
||||
)
|
||||
from tests.conftest import MockModelService, MockScanner
|
||||
|
||||
|
||||
class StubLockProvider:
|
||||
def __init__(self) -> None:
|
||||
self._lock = asyncio.Lock()
|
||||
self.running = False
|
||||
|
||||
def is_auto_organize_running(self) -> bool:
|
||||
return self.running
|
||||
|
||||
async def get_auto_organize_lock(self) -> asyncio.Lock:
|
||||
return self._lock
|
||||
|
||||
|
||||
class StubFileService:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def auto_organize_models(
|
||||
self,
|
||||
*,
|
||||
file_paths: Optional[List[str]] = None,
|
||||
progress_callback=None,
|
||||
) -> AutoOrganizeResult:
|
||||
result = AutoOrganizeResult()
|
||||
result.total = len(file_paths or [])
|
||||
self.calls.append({"file_paths": file_paths, "progress_callback": progress_callback})
|
||||
return result
|
||||
|
||||
|
||||
class StubMetadataSync:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def fetch_and_update_model(self, **kwargs: Any):
|
||||
self.calls.append(kwargs)
|
||||
model_data = kwargs["model_data"]
|
||||
model_data["model_name"] = model_data.get("model_name", "model") + "-updated"
|
||||
return True, None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubSettings:
|
||||
enable_metadata_archive_db: bool = False
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
if key == "enable_metadata_archive_db":
|
||||
return self.enable_metadata_archive_db
|
||||
return default
|
||||
|
||||
|
||||
class ProgressCollector:
|
||||
def __init__(self) -> None:
|
||||
self.events: List[Dict[str, Any]] = []
|
||||
|
||||
async def on_progress(self, payload: Dict[str, Any]) -> None:
|
||||
self.events.append(payload)
|
||||
|
||||
|
||||
class StubDownloadCoordinator:
|
||||
def __init__(self, *, error: Optional[str] = None) -> None:
|
||||
self.error = error
|
||||
self.payloads: List[Dict[str, Any]] = []
|
||||
|
||||
async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
self.payloads.append(payload)
|
||||
if self.error == "validation":
|
||||
raise ValueError("Missing required parameter: Please provide either 'model_id' or 'model_version_id'")
|
||||
if self.error == "401":
|
||||
raise RuntimeError("401 Unauthorized")
|
||||
return {"success": True, "download_id": "abc123"}
|
||||
|
||||
|
||||
class StubExampleImagesDownloadManager:
|
||||
def __init__(self) -> None:
|
||||
self.payloads: List[Dict[str, Any]] = []
|
||||
self.error: Optional[str] = None
|
||||
self.progress_snapshot = {"status": "running"}
|
||||
|
||||
async def start_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
self.payloads.append(payload)
|
||||
if self.error == "in_progress":
|
||||
raise DownloadInProgressError(self.progress_snapshot)
|
||||
if self.error == "configuration":
|
||||
raise DownloadConfigurationError("path missing")
|
||||
if self.error == "generic":
|
||||
raise ExampleImagesDownloadError("boom")
|
||||
return {"success": True, "message": "ok"}
|
||||
|
||||
|
||||
class StubExampleImagesProcessor:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Dict[str, Any]] = []
|
||||
self.error: Optional[str] = None
|
||||
self.response: Dict[str, Any] = {"success": True}
|
||||
|
||||
async def import_images(self, model_hash: str, files: List[str]) -> Dict[str, Any]:
|
||||
self.calls.append({"model_hash": model_hash, "files": files})
|
||||
if self.error == "validation":
|
||||
raise ExampleImagesValidationError("missing")
|
||||
if self.error == "generic":
|
||||
raise ExampleImagesImportError("boom")
|
||||
return self.response
|
||||
|
||||
|
||||
async def test_auto_organize_use_case_executes_with_lock() -> None:
|
||||
file_service = StubFileService()
|
||||
lock_provider = StubLockProvider()
|
||||
use_case = AutoOrganizeUseCase(file_service=file_service, lock_provider=lock_provider)
|
||||
|
||||
result = await use_case.execute(file_paths=["model1"], progress_callback=None)
|
||||
|
||||
assert isinstance(result, AutoOrganizeResult)
|
||||
assert file_service.calls[0]["file_paths"] == ["model1"]
|
||||
|
||||
|
||||
async def test_auto_organize_use_case_rejects_when_running() -> None:
|
||||
file_service = StubFileService()
|
||||
lock_provider = StubLockProvider()
|
||||
lock_provider.running = True
|
||||
use_case = AutoOrganizeUseCase(file_service=file_service, lock_provider=lock_provider)
|
||||
|
||||
with pytest.raises(AutoOrganizeInProgressError):
|
||||
await use_case.execute(file_paths=None, progress_callback=None)
|
||||
|
||||
|
||||
async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
|
||||
scanner = MockScanner()
|
||||
scanner._cache.raw_data = [
|
||||
{
|
||||
"file_path": "model1.safetensors",
|
||||
"sha256": "hash",
|
||||
"from_civitai": True,
|
||||
"model_name": "Demo",
|
||||
}
|
||||
]
|
||||
service = MockModelService(scanner)
|
||||
metadata_sync = StubMetadataSync()
|
||||
settings = StubSettings()
|
||||
progress = ProgressCollector()
|
||||
|
||||
use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=metadata_sync,
|
||||
settings_service=settings,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
result = await use_case.execute_with_error_handling(progress_callback=progress)
|
||||
|
||||
assert result["success"] is True
|
||||
assert progress.events[0]["status"] == "started"
|
||||
assert progress.events[-1]["status"] == "completed"
|
||||
assert metadata_sync.calls
|
||||
assert scanner._cache.resort_calls == 1
|
||||
|
||||
|
||||
async def test_bulk_metadata_refresh_reports_errors() -> None:
|
||||
class FailingScanner(MockScanner):
|
||||
async def get_cached_data(self, force_refresh: bool = False):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
service = MockModelService(FailingScanner())
|
||||
metadata_sync = StubMetadataSync()
|
||||
settings = StubSettings()
|
||||
progress = ProgressCollector()
|
||||
|
||||
use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=metadata_sync,
|
||||
settings_service=settings,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await use_case.execute_with_error_handling(progress_callback=progress)
|
||||
|
||||
assert progress.events
|
||||
assert progress.events[-1]["status"] == "error"
|
||||
assert progress.events[-1]["error"] == "boom"
|
||||
|
||||
|
||||
async def test_download_model_use_case_raises_validation_error() -> None:
|
||||
coordinator = StubDownloadCoordinator(error="validation")
|
||||
use_case = DownloadModelUseCase(download_coordinator=coordinator)
|
||||
|
||||
with pytest.raises(DownloadModelValidationError):
|
||||
await use_case.execute({})
|
||||
|
||||
|
||||
async def test_download_model_use_case_raises_early_access() -> None:
|
||||
coordinator = StubDownloadCoordinator(error="401")
|
||||
use_case = DownloadModelUseCase(download_coordinator=coordinator)
|
||||
|
||||
with pytest.raises(DownloadModelEarlyAccessError):
|
||||
await use_case.execute({"model_id": 1})
|
||||
|
||||
|
||||
async def test_download_model_use_case_returns_result() -> None:
|
||||
coordinator = StubDownloadCoordinator()
|
||||
use_case = DownloadModelUseCase(download_coordinator=coordinator)
|
||||
|
||||
result = await use_case.execute({"model_id": 1})
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["download_id"] == "abc123"
|
||||
|
||||
|
||||
async def test_download_example_images_use_case_triggers_manager() -> None:
|
||||
manager = StubExampleImagesDownloadManager()
|
||||
use_case = DownloadExampleImagesUseCase(download_manager=manager)
|
||||
|
||||
payload = {"optimize": True}
|
||||
result = await use_case.execute(payload)
|
||||
|
||||
assert manager.payloads == [payload]
|
||||
assert result == {"success": True, "message": "ok"}
|
||||
|
||||
|
||||
async def test_download_example_images_use_case_maps_in_progress() -> None:
|
||||
manager = StubExampleImagesDownloadManager()
|
||||
manager.error = "in_progress"
|
||||
use_case = DownloadExampleImagesUseCase(download_manager=manager)
|
||||
|
||||
with pytest.raises(DownloadExampleImagesInProgressError) as exc:
|
||||
await use_case.execute({})
|
||||
|
||||
assert exc.value.progress == manager.progress_snapshot
|
||||
|
||||
|
||||
async def test_download_example_images_use_case_maps_configuration() -> None:
|
||||
manager = StubExampleImagesDownloadManager()
|
||||
manager.error = "configuration"
|
||||
use_case = DownloadExampleImagesUseCase(download_manager=manager)
|
||||
|
||||
with pytest.raises(DownloadExampleImagesConfigurationError):
|
||||
await use_case.execute({})
|
||||
|
||||
|
||||
async def test_download_example_images_use_case_propagates_generic_error() -> None:
|
||||
manager = StubExampleImagesDownloadManager()
|
||||
manager.error = "generic"
|
||||
use_case = DownloadExampleImagesUseCase(download_manager=manager)
|
||||
|
||||
with pytest.raises(ExampleImagesDownloadError):
|
||||
await use_case.execute({})
|
||||
|
||||
|
||||
class DummyJsonRequest:
|
||||
def __init__(self, payload: Dict[str, Any]) -> None:
|
||||
self._payload = payload
|
||||
self.content_type = "application/json"
|
||||
|
||||
async def json(self) -> Dict[str, Any]:
|
||||
return self._payload
|
||||
|
||||
|
||||
async def test_import_example_images_use_case_delegates() -> None:
|
||||
processor = StubExampleImagesProcessor()
|
||||
use_case = ImportExampleImagesUseCase(processor=processor)
|
||||
|
||||
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
|
||||
result = await use_case.execute(request)
|
||||
|
||||
assert processor.calls == [{"model_hash": "abc", "files": ["/tmp/file"]}]
|
||||
assert result == {"success": True}
|
||||
|
||||
|
||||
async def test_import_example_images_use_case_maps_validation_error() -> None:
|
||||
processor = StubExampleImagesProcessor()
|
||||
processor.error = "validation"
|
||||
use_case = ImportExampleImagesUseCase(processor=processor)
|
||||
request = DummyJsonRequest({"model_hash": None, "file_paths": []})
|
||||
|
||||
with pytest.raises(ImportExampleImagesValidationError):
|
||||
await use_case.execute(request)
|
||||
|
||||
|
||||
async def test_import_example_images_use_case_propagates_generic_error() -> None:
|
||||
processor = StubExampleImagesProcessor()
|
||||
processor.error = "generic"
|
||||
use_case = ImportExampleImagesUseCase(processor=processor)
|
||||
request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]})
|
||||
|
||||
with pytest.raises(ExampleImagesImportError):
|
||||
await use_case.execute(request)
|
||||
Reference in New Issue
Block a user