diff --git a/docs/architecture/example_images_routes.md b/docs/architecture/example_images_routes.md new file mode 100644 index 00000000..128530f6 --- /dev/null +++ b/docs/architecture/example_images_routes.md @@ -0,0 +1,93 @@ +# Example image route architecture + +The example image routing stack mirrors the layered model route stack described in +[`docs/architecture/model_routes.md`](model_routes.md). HTTP wiring, controller setup, +handler orchestration, and long-running workflows now live in clearly separated modules so +we can extend download/import behaviour without touching the entire feature surface. + +```mermaid +graph TD + subgraph HTTP + A[ExampleImagesRouteRegistrar] -->|binds| B[ExampleImagesRoutes controller] + end + subgraph Application + B --> C[ExampleImagesHandlerSet] + C --> D1[Handlers] + D1 --> E1[Use cases] + E1 --> F1[Download manager / processor / file manager] + end + subgraph Side Effects + F1 --> G1[Filesystem] + F1 --> G2[Model metadata] + F1 --> G3[WebSocket progress] + end +``` + +## Layer responsibilities + +| Layer | Module(s) | Responsibility | +| --- | --- | --- | +| Registrar | `py/routes/example_images_route_registrar.py` | Declarative catalogue of every example image endpoint plus helpers that bind them to an `aiohttp` router. Keeps HTTP concerns symmetrical with the model registrar. | +| Controller | `py/routes/example_images_routes.py` | Lazily constructs `ExampleImagesHandlerSet`, injects defaults for the download manager, processor, and file manager, and exposes the registrar-ready mapping just like `BaseModelRoutes`. | +| Handler set | `py/routes/handlers/example_images_handlers.py` | Groups HTTP adapters by concern (downloads, imports/deletes, filesystem access). Each handler translates domain errors into HTTP responses and defers to a use case or utility service. | +| Use cases | `py/services/use_cases/example_images/*.py` | Encapsulate orchestration for downloads and imports. They validate input, translate concurrency/configuration errors, and keep handler logic declarative. | +| Supporting services | `py/utils/example_images_download_manager.py`, `py/utils/example_images_processor.py`, `py/utils/example_images_file_manager.py` | Execute long-running work: pull assets from Civitai, persist uploads, clean metadata, expose filesystem actions with guardrails, and broadcast progress snapshots. | + +## Handler responsibilities & invariants + +`ExampleImagesHandlerSet` flattens the handler objects into the `{"handler_name": coroutine}` +mapping consumed by the registrar. The table below outlines how each handler collaborates +with the use cases and utilities. + +| Handler | Key endpoints | Collaborators | Contracts | +| --- | --- | --- | --- | +| `ExampleImagesDownloadHandler` | `/api/lm/download-example-images`, `/api/lm/example-images-status`, `/api/lm/pause-example-images`, `/api/lm/resume-example-images`, `/api/lm/force-download-example-images` | `DownloadExampleImagesUseCase`, `DownloadManager` | Delegates payload validation and concurrency checks to the use case; progress/status endpoints expose the same snapshot used for WebSocket broadcasts; pause/resume surface `DownloadNotRunningError` as HTTP 400 instead of 500. | +| `ExampleImagesManagementHandler` | `/api/lm/import-example-images`, `/api/lm/delete-example-image` | `ImportExampleImagesUseCase`, `ExampleImagesProcessor` | Multipart uploads are streamed to disk via the use case; validation failures return HTTP 400 with no filesystem side effects; deletion funnels through the processor to prune metadata and cached images consistently. | +| `ExampleImagesFileHandler` | `/api/lm/open-example-images-folder`, `/api/lm/example-image-files`, `/api/lm/has-example-images` | `ExampleImagesFileManager` | Centralises filesystem access, enforcing settings-based root paths and returning HTTP 400/404 for missing configuration or folders; responses always include `success`/`has_images` booleans for UI consumption. | + +## Use case boundaries + +| Use case | Entry point | Dependencies | Guarantees | +| --- | --- | --- | --- | +| `DownloadExampleImagesUseCase` | `execute(payload)` | `DownloadManager.start_download`, download configuration errors | Raises `DownloadExampleImagesInProgressError` when the manager reports an active job, rewraps configuration errors into `DownloadExampleImagesConfigurationError`, and lets `ExampleImagesDownloadError` bubble as 500s so handlers do not duplicate logging. | +| `ImportExampleImagesUseCase` | `execute(request)` | `ExampleImagesProcessor.import_images`, temporary file helpers | Supports multipart or JSON payloads, normalises file paths into a single list, cleans up temp files even on failure, and maps validation issues to `ImportExampleImagesValidationError` for HTTP 400 responses. | + +## Maintaining critical invariants + +* **Shared progress snapshots** - The download handler returns the same snapshot built by + `DownloadManager`, guaranteeing parity between HTTP polling endpoints and WebSocket + progress events. +* **Safe filesystem access** - All folder/file actions flow through + `ExampleImagesFileManager`, which validates the configured example image root and ensures + responses never leak absolute paths outside the allowed directory. +* **Metadata hygiene** - Import/delete operations run through `ExampleImagesProcessor`, + which updates model metadata via `MetadataManager` and notifies the relevant scanners so + cache state stays in sync. + +## Migration notes + +The refactor brings the example image stack in line with the model/recipe stacks: + +1. `ExampleImagesRouteRegistrar` now owns the declarative route list. Downstream projects + should rely on `ExampleImagesRoutes.to_route_mapping()` instead of manually wiring + handler callables. +2. `ExampleImagesRoutes` caches its `ExampleImagesHandlerSet` just like + `BaseModelRoutes`. If you previously instantiated handlers directly, inject custom + collaborators via the controller constructor (`download_manager`, `processor`, + `file_manager`) to keep test seams predictable. +3. Tests that mocked `ExampleImagesRoutes.setup_routes` should switch to patching + `DownloadExampleImagesUseCase`/`ImportExampleImagesUseCase` at import time. The handlers + expect those abstractions to surface validation/concurrency errors, and bypassing them + will skip the HTTP-friendly error mapping. + +## Extending the stack + +1. Add the endpoint to `ROUTE_DEFINITIONS` with a unique `handler_name`. +2. Expose the coroutine on an existing handler class (or create a new handler and extend + `ExampleImagesHandlerSet`). +3. Wire additional services or factories inside `_build_handler_set` on + `ExampleImagesRoutes`, mirroring how the model stack introduces new use cases. + +`tests/routes/test_example_images_routes.py` exercises registrar binding, download pause +flows, and import validations. Use it as a template when introducing new handler +collaborators or error mappings. diff --git a/docs/architecture/model_routes.md b/docs/architecture/model_routes.md index 00329299..a9fbf967 100644 --- a/docs/architecture/model_routes.md +++ b/docs/architecture/model_routes.md @@ -1,127 +1,100 @@ # Base model route architecture -The `BaseModelRoutes` controller centralizes HTTP endpoints that every model type -(LoRAs, checkpoints, embeddings, etc.) share. Each handler either forwards the -request to the injected service, delegates to a utility in -`ModelRouteUtils`, or orchestrates long‑running operations via helper services -such as the download or WebSocket managers. The table below lists every handler -exposed in `py/routes/base_model_routes.py`, the collaborators it leans on, and -any cache or WebSocket side effects implemented in -`py/utils/routes_common.py`. +The model routing stack now splits HTTP wiring, orchestration logic, and +business rules into discrete layers. The goal is to make it obvious where a +new collaborator should live and which contract it must honour. The diagram +below captures the end-to-end flow for a typical request: -## Contents +```mermaid +graph TD + subgraph HTTP + A[ModelRouteRegistrar] -->|binds| B[BaseModelRoutes handler proxy] + end + subgraph Application + B --> C[ModelHandlerSet] + C --> D1[Handlers] + D1 --> E1[Use cases] + E1 --> F1[Services / scanners] + end + subgraph Side Effects + F1 --> G1[Cache & metadata] + F1 --> G2[Filesystem] + F1 --> G3[WebSocket state] + end +``` -- [Handler catalogue](#handler-catalogue) -- [Dependency map and contracts](#dependency-map-and-contracts) - - [Cache and metadata mutations](#cache-and-metadata-mutations) - - [Download and WebSocket flows](#download-and-websocket-flows) - - [Read-only queries](#read-only-queries) - - [Template rendering and initialization](#template-rendering-and-initialization) +Every box maps to a concrete module: -## Handler catalogue +| Layer | Module(s) | Responsibility | +| --- | --- | --- | +| Registrar | `py/routes/model_route_registrar.py` | Declarative list of routes shared by every model type and helper methods for binding them to an `aiohttp` application. | +| Route controller | `py/routes/base_model_routes.py` | Constructs the handler graph, injects shared services, exposes proxies that surface `503 Service not ready` when the model service has not been attached. | +| Handler set | `py/routes/handlers/model_handlers.py` | Thin HTTP adapters grouped by concern (page rendering, listings, mutations, queries, downloads, CivitAI integration, move operations, auto-organize). | +| Use cases | `py/services/use_cases/*.py` | Encapsulate long-running flows (`DownloadModelUseCase`, `BulkMetadataRefreshUseCase`, `AutoOrganizeUseCase`). They normalise validation errors and concurrency constraints before returning control to the handlers. | +| Services | `py/services/*.py` | Existing services and scanners that mutate caches, write metadata, move files, and broadcast WebSocket updates. | -The routes exposed by `BaseModelRoutes` combine HTTP wiring with a handful of -shared helper classes. Services surface filesystem and metadata operations, -`ModelRouteUtils` bundles cache-sensitive mutations, and `ws_manager` -coordinates fan-out to browser clients. The tables below expand the existing -catalogue into explicit dependency maps and invariants so refactors can reason -about the expectations each collaborator must uphold. +## Handler responsibilities & contracts -## Dependency map and contracts +`ModelHandlerSet` flattens the handler objects into the exact callables used by +the registrar. The table below highlights the separation of concerns within +the set and the invariants that must hold after each handler returns. -### Cache and metadata mutations - -| Endpoint(s) | Delegate(s) | State touched | Invariants / contracts | +| Handler | Key endpoints | Collaborators | Contracts | | --- | --- | --- | --- | -| `/api/lm/{prefix}/delete` | `ModelRouteUtils.handle_delete_model()` | Removes files from disk, prunes `scanner._cache.raw_data`, awaits `scanner._cache.resort()`, calls `scanner._hash_index.remove_by_path()`. | Cache and hash index must no longer reference the deleted path; resort must complete before responding to keep pagination deterministic. | -| `/api/lm/{prefix}/exclude` | `ModelRouteUtils.handle_exclude_model()` | Mutates metadata records, `scanner._cache.raw_data`, `scanner._hash_index`, `scanner._tags_count`, and `scanner._excluded_models`. | Excluded models remain discoverable via exclusion list while being hidden from listings; tag counts stay balanced after removal. | -| `/api/lm/{prefix}/fetch-civitai` | `ModelRouteUtils.fetch_and_update_model()` | Reads `scanner._cache.raw_data`, writes metadata JSON through `MetadataManager`, syncs cache via `scanner.update_single_model_cache`. | Requires a cached SHA256 hash; cache entries must reflect merged metadata before formatted response is returned. | -| `/api/lm/{prefix}/fetch-all-civitai` | `ModelRouteUtils.fetch_and_update_model()`, `ws_manager.broadcast()` | Iterates over cache, updates metadata files and cache records, optionally awaits `scanner._cache.resort()`. | Progress broadcasts follow started → processing → completed; if any model name changes, cache resort must run once before completion broadcast. | -| `/api/lm/{prefix}/relink-civitai` | `ModelRouteUtils.handle_relink_civitai()` | Updates metadata on disk and resynchronizes the cache entry. | The new association must propagate to `scanner.update_single_model_cache` so duplicate resolution and listings reflect the change immediately. | -| `/api/lm/{prefix}/replace-preview` | `ModelRouteUtils.handle_replace_preview()` | Writes optimized preview file, persists metadata via `MetadataManager`, updates cache with `scanner.update_preview_in_cache()`. | Preview path stored in metadata and cache must match the normalized file system path; NSFW level integer is synchronized across metadata and cache. | -| `/api/lm/{prefix}/save-metadata` | `ModelRouteUtils.handle_save_metadata()` | Writes metadata JSON and ensures cache entry mirrors the latest content. | Metadata persistence must be atomic—cache data should match on-disk metadata before response emits success. | -| `/api/lm/{prefix}/add-tags` | `ModelRouteUtils.handle_add_tags()` | Updates metadata tags, increments `scanner._tags_count`, and patches cached item. | Tag frequency map remains in sync with cache and metadata after increments. | -| `/api/lm/{prefix}/rename` | `ModelRouteUtils.handle_rename_model()` | Renames files, metadata, previews; updates cache indices and hash mappings. | File moves succeed or rollback as a unit so cache state never points to a missing file; hash index entries track the new path. | -| `/api/lm/{prefix}/bulk-delete` | `ModelRouteUtils.handle_bulk_delete_models()` | Delegates to `scanner.bulk_delete_models()` to delete files, trim cache, resort, and drop hash index entries. | Every requested path is removed from cache and index; resort happens once after bulk deletion. | -| `/api/lm/{prefix}/verify-duplicates` | `ModelRouteUtils.handle_verify_duplicates()` | Recomputes hashes, updates metadata and cached entries if discrepancies found. | Hash metadata stored in cache must mirror recomputed values to guarantee future duplicate checks operate on current data. | -| `/api/lm/{prefix}/scan` | `service.scan_models()` | Rescans filesystem, rebuilding scanner cache. | Scanner replaces its cache atomically so subsequent requests observe a consistent snapshot. | -| `/api/lm/{prefix}/move_model` | `ModelMoveService.move_model()` | Moves files/directories and notifies scanner via service layer conventions. | Move operations respect filesystem invariants (target path exists, metadata follows file) and emit success/failure without leaving partial moves. | -| `/api/lm/{prefix}/move_models_bulk` | `ModelMoveService.move_models_bulk()` | Batch move behavior as above. | Aggregated result enumerates successes/failures while preserving per-model atomicity. | -| `/api/lm/{prefix}/auto-organize` (GET/POST) | `ModelFileService.auto_organize_models()`, `ws_manager.get_auto_organize_lock()`, `WebSocketProgressCallback` | Writes organized files, updates metadata, and streams progress snapshots. | Only one auto-organize job may run; lock must guard reentrancy and WebSocket updates must include latest progress payload consumed by polling route. | +| `ModelPageView` | `/{prefix}` | `SettingsManager`, `server_i18n`, Jinja environment, `service.scanner` | Template is rendered with `is_initializing` flag when caches are cold; i18n filter is registered exactly once per environment instance. | +| `ModelListingHandler` | `/api/lm/{prefix}/list` | `service.get_paginated_data`, `service.format_response` | Listings respect pagination query parameters and cap `page_size` at 100; every item is formatted before response. | +| `ModelManagementHandler` | Mutations (delete, exclude, metadata, preview, tags, rename, bulk delete, duplicate verification) | `ModelLifecycleService`, `MetadataSyncService`, `PreviewAssetService`, `TagUpdateService`, scanner cache/index | Cache state mirrors filesystem changes: deletes prune cache & hash index, preview replacements synchronise metadata and cache NSFW levels, metadata saves trigger cache resort when names change. | +| `ModelQueryHandler` | Read-only queries (top tags, folders, duplicates, metadata, URLs) | Service query helpers & scanner cache | Outputs always wrapped in `{"success": True}` when no error; duplicate/filename grouping omits empty entries; invalid parameters (e.g. missing `model_root`) return HTTP 400. | +| `ModelDownloadHandler` | `/api/lm/download-model`, `/download-model-get`, `/download-progress/{id}`, `/cancel-download-get` | `DownloadModelUseCase`, `DownloadCoordinator`, `WebSocketManager` | Payload validation errors become HTTP 400 without mutating download progress cache; early-access failures surface as HTTP 401; successful downloads cache progress snapshots that back both WebSocket broadcasts and polling endpoints. | +| `ModelCivitaiHandler` | CivitAI metadata routes | `MetadataSyncService`, metadata provider factory, `BulkMetadataRefreshUseCase` | `fetch_all_civitai` streams progress via `WebSocketBroadcastCallback`; version lookups validate model type before returning; local availability fields derive from hash lookups without mutating cache state. | +| `ModelMoveHandler` | `move_model`, `move_models_bulk` | `ModelMoveService` | Moves execute atomically per request; bulk operations aggregate success/failure per file set. | +| `ModelAutoOrganizeHandler` | `/api/lm/{prefix}/auto-organize` (GET/POST), `/auto-organize-progress` | `AutoOrganizeUseCase`, `WebSocketProgressCallback`, `WebSocketManager` | Enforces single-flight execution using the shared lock; progress broadcasts remain available to polling clients until explicitly cleared; conflicts return HTTP 409 with a descriptive error. | -### Download and WebSocket flows +## Use case boundaries -| Endpoint(s) | Delegate(s) | State touched | Invariants / contracts | +Each use case exposes a narrow asynchronous API that hides the underlying +services. Their error mapping is essential for predictable HTTP responses. + +| Use case | Entry point | Dependencies | Guarantees | | --- | --- | --- | --- | -| `/api/lm/download-model` (POST) & `/api/lm/download-model-get` (GET) | `ModelRouteUtils.handle_download_model()`, `ServiceRegistry.get_download_manager()` | Schedules downloads, registers `ws_manager.broadcast_download_progress()` callback that stores progress in `ws_manager._download_progress`. | Download IDs remain stable across POST/GET helpers; every progress callback persists a timestamped entry so `/download-progress` and WebSocket clients share consistent snapshots. | -| `/api/lm/cancel-download-get` | `ModelRouteUtils.handle_cancel_download()` | Signals download manager, prunes `ws_manager._download_progress`, and emits cancellation broadcast. | Cancel requests must tolerate missing IDs gracefully while ensuring cached progress is removed once cancellation succeeds. | -| `/api/lm/download-progress/{download_id}` | `ws_manager.get_download_progress()` | Reads cached progress dictionary. | Returns `404` when progress is absent; successful payload surfaces the numeric `progress` stored during broadcasts. | -| `/api/lm/{prefix}/fetch-all-civitai` | `ws_manager.broadcast()` | Broadcast loop described above. | Broadcast cadence cannot skip completion/error messages so clients know when to clear UI spinners. | -| `/api/lm/{prefix}/auto-organize-progress` | `ws_manager.get_auto_organize_progress()` | Reads cached progress snapshot. | Route returns cached payload verbatim; absence yields `404` to signal idle state. | +| `DownloadModelUseCase` | `execute(payload)` | `DownloadCoordinator.schedule_download` | Translates `ValueError` into `DownloadModelValidationError` for HTTP 400, recognises early-access errors (`"401"` in message) and surfaces them as `DownloadModelEarlyAccessError`, forwards success dictionaries untouched. | +| `AutoOrganizeUseCase` | `execute(file_paths, progress_callback)` | `ModelFileService.auto_organize_models`, `WebSocketManager` lock | Guarded by `ws_manager` lock + status checks; raises `AutoOrganizeInProgressError` before invoking the file service when another run is already active. | +| `BulkMetadataRefreshUseCase` | `execute_with_error_handling(progress_callback)` | `MetadataSyncService`, `SettingsManager`, `WebSocketBroadcastCallback` | Iterates through cached models, applies metadata sync, emits progress snapshots that handlers broadcast unchanged. | -### Read-only queries +## Maintaining legacy contracts -| Endpoint(s) | Delegate(s) | State touched | Invariants / contracts | -| --- | --- | --- | --- | -| `/api/lm/{prefix}/list` | `service.get_paginated_data()`, `service.format_response()` | Reads service-managed pagination data. | Formatting must be applied to every item before response; pagination metadata echoes service result. | -| `/api/lm/{prefix}/top-tags` | `service.get_top_tags()` | Reads aggregated tag counts. | Limit parameter bounded to `[1, 100]`; response always wraps tags in `{success: True}` envelope. | -| `/api/lm/{prefix}/base-models` | `service.get_base_models()` | Reads service data. | Same limit handling as tags. | -| `/api/lm/{prefix}/roots` | `service.get_model_roots()` | Reads configured roots. | Always returns `{success: True, roots: [...]}`. | -| `/api/lm/{prefix}/folders` | `service.scanner.get_cached_data()` | Reads folder summaries from cache. | Cache access must tolerate initialization phases by surfacing errors via HTTP 500. | -| `/api/lm/{prefix}/folder-tree` | `service.get_folder_tree()` | Reads derived tree for requested root. | Rejects missing `model_root` with HTTP 400. | -| `/api/lm/{prefix}/unified-folder-tree` | `service.get_unified_folder_tree()` | Aggregated folder tree. | Returns `{success: True, tree: ...}` or 500 on error. | -| `/api/lm/{prefix}/find-duplicates` | `service.find_duplicate_hashes()`, `service.scanner.get_cached_data()`, `service.get_path_by_hash()` | Reads cache and hash index to format duplicates. | Only returns groups with more than one resolved model. | -| `/api/lm/{prefix}/find-filename-conflicts` | `service.find_duplicate_filenames()`, `service.scanner.get_cached_data()`, `service.scanner.get_hash_by_filename()` | Similar read-only assembly. | Includes resolved main index entry when available; empty `models` groups are omitted. | -| `/api/lm/{prefix}/get-notes` | `service.get_model_notes()` | Reads persisted notes. | Missing notes produce HTTP 404 with explicit error message. | -| `/api/lm/{prefix}/preview-url` | `service.get_model_preview_url()` | Resolves static URL. | Successful responses wrap URL in `{success: True}`; missing preview yields 404 error payload. | -| `/api/lm/{prefix}/civitai-url` | `service.get_model_civitai_url()` | Returns remote permalink info. | Response envelope matches preview pattern. | -| `/api/lm/{prefix}/metadata` | `service.get_model_metadata()` | Reads metadata JSON. | Responds with raw metadata dict or 500 on failure. | -| `/api/lm/{prefix}/model-description` | `service.get_model_description()` | Returns formatted description string. | Always JSON with success boolean. | -| `/api/lm/{prefix}/relative-paths` | `service.get_relative_paths()` | Resolves filesystem suggestions. | Maintains read-only contract. | -| `/api/lm/{prefix}/civitai/versions/{model_id}` | `get_default_metadata_provider()`, `service.has_hash()`, `service.get_path_by_hash()` | Reads remote API, cross-references cache. | Versions payload includes `existsLocally`/`localPath` only when hashes match local indices. | -| `/api/lm/{prefix}/civitai/model/version/{modelVersionId}` | `get_default_metadata_provider()` | Remote metadata lookup. | Errors propagate as JSON with `{success: False}` payload. | -| `/api/lm/{prefix}/civitai/model/hash/{hash}` | `get_default_metadata_provider()` | Remote metadata lookup. | Missing hashes return 404 with `{success: False}`. | +The refactor preserves the invariants called out in the previous architecture +notes. The most critical ones are reiterated here to emphasise the +collaboration points: -### Template rendering and initialization +1. **Cache mutations** – Delete, exclude, rename, and bulk delete operations are + channelled through `ModelManagementHandler`. The handler delegates to + `ModelLifecycleService` or `MetadataSyncService`, and the scanner cache is + mutated in-place before the handler returns. The accompanying tests assert + that `scanner._cache.raw_data` and `scanner._hash_index` stay in sync after + each mutation. +2. **Preview updates** – `PreviewAssetService.replace_preview` writes the new + asset, `MetadataSyncService` persists the JSON metadata, and + `scanner.update_preview_in_cache` mirrors the change. The handler returns + the static URL produced by `config.get_preview_static_url`, keeping browser + clients in lockstep with disk state. +3. **Download progress** – `DownloadCoordinator.schedule_download` generates the + download identifier, registers a WebSocket progress callback, and caches the + latest numeric progress via `WebSocketManager`. Both `download_model` + responses and `/download-progress/{id}` polling read from the same cache to + guarantee consistent progress reporting across transports. -| Endpoint(s) | Delegate(s) | State touched | Invariants / contracts | -| --- | --- | --- | --- | -| `/{prefix}` | `handle_models_page` | Reads configuration via `settings`, sets locale with `server_i18n`, pulls cached folders through `service.scanner.get_cached_data()`, renders Jinja template. | Template rendering must tolerate scanner initialization by flagging `is_initializing`; i18n filter is attached exactly once per environment to avoid duplicate registration errors. | +## Extending the stack -### Contract sequences +To add a new shared route: -The following high-level sequences show how the collaborating services work -together for the most stateful operations: +1. Declare it in `COMMON_ROUTE_DEFINITIONS` using a unique handler name. +2. Implement the corresponding coroutine on one of the handlers inside + `ModelHandlerSet` (or introduce a new handler class when the concern does not + fit existing ones). +3. Inject additional dependencies in `BaseModelRoutes._create_handler_set` by + wiring services or use cases through the constructor parameters. -``` -delete_model request - → BaseModelRoutes.delete_model - → ModelRouteUtils.handle_delete_model - → filesystem delete + metadata cleanup - → scanner._cache.raw_data prune - → await scanner._cache.resort() - → scanner._hash_index.remove_by_path() -``` - -``` -replace_preview request - → BaseModelRoutes.replace_preview - → ModelRouteUtils.handle_replace_preview - → ExifUtils.optimize_image / config.get_preview_static_url - → MetadataManager.save_metadata - → scanner.update_preview_in_cache(model_path, preview_path, nsfw_level) -``` - -``` -download_model request - → BaseModelRoutes.download_model - → ModelRouteUtils.handle_download_model - → ServiceRegistry.get_download_manager().download_from_civitai(..., progress_callback) - → ws_manager.broadcast_download_progress(download_id, data) - → ws_manager._download_progress[download_id] updated with timestamp - → /api/lm/download-progress/{id} polls ws_manager.get_download_progress -``` - -These contracts complement the tables above: if any collaborator changes its -behavior, the invariants called out here must continue to hold for the routes -to remain predictable. +Model-specific routes should continue to be registered inside the subclass +implementation of `setup_specific_routes`, reusing the shared registrar where +possible. diff --git a/docs/architecture/recipe_routes.md b/docs/architecture/recipe_routes.md new file mode 100644 index 00000000..0bdb7c90 --- /dev/null +++ b/docs/architecture/recipe_routes.md @@ -0,0 +1,89 @@ +# Recipe route architecture + +The recipe routing stack now mirrors the modular model route design. HTTP +bindings, controller wiring, handler orchestration, and business rules live in +separate layers so new behaviours can be added without re-threading the entire +feature. The diagram below outlines the flow for a typical request: + +```mermaid +graph TD + subgraph HTTP + A[RecipeRouteRegistrar] -->|binds| B[RecipeRoutes controller] + end + subgraph Application + B --> C[RecipeHandlerSet] + C --> D1[Handlers] + D1 --> E1[Use cases] + E1 --> F1[Services / scanners] + end + subgraph Side Effects + F1 --> G1[Cache & fingerprint index] + F1 --> G2[Metadata files] + F1 --> G3[Temporary shares] + end +``` + +## Layer responsibilities + +| Layer | Module(s) | Responsibility | +| --- | --- | --- | +| Registrar | `py/routes/recipe_route_registrar.py` | Declarative list of every recipe endpoint and helper methods that bind them to an `aiohttp` application. | +| Controller | `py/routes/base_recipe_routes.py`, `py/routes/recipe_routes.py` | Lazily resolves scanners/clients from the service registry, wires shared templates/i18n, instantiates `RecipeHandlerSet`, and exposes a `{handler_name: coroutine}` mapping for the registrar. | +| Handler set | `py/routes/handlers/recipe_handlers.py` | Thin HTTP adapters grouped by concern (page view, listings, queries, mutations, sharing). They normalise responses and translate service exceptions into HTTP status codes. | +| Services & scanners | `py/services/recipes/*.py`, `py/services/recipe_scanner.py`, `py/services/service_registry.py` | Concrete business logic: metadata parsing, persistence, sharing, fingerprint/index maintenance, and cache refresh. | + +## Handler responsibilities & invariants + +`RecipeHandlerSet` flattens purpose-built handler objects into the callables the +registrar binds. Each handler is responsible for a narrow concern and enforces a +set of invariants before returning: + +| Handler | Key endpoints | Collaborators | Contracts | +| --- | --- | --- | --- | +| `RecipePageView` | `/loras/recipes` | `SettingsManager`, `server_i18n`, Jinja environment, recipe scanner getter | Template rendered with `is_initializing` flag when caches are still warming; i18n filter registered exactly once per environment instance. | +| `RecipeListingHandler` | `/api/lm/recipes`, `/api/lm/recipe/{id}` | `recipe_scanner.get_paginated_data`, `recipe_scanner.get_recipe_by_id` | Listings respect pagination and search filters; every item receives a `file_url` fallback even when metadata is incomplete; missing recipes become HTTP 404. | +| `RecipeQueryHandler` | Tag/base-model stats, syntax, LoRA lookups | Recipe scanner cache, `format_recipe_file_url` helper | Cache snapshots are reused without forcing refresh; duplicate lookups collapse groups by fingerprint; syntax lookups return helpful errors when LoRAs are absent. | +| `RecipeManagementHandler` | Save, update, reconnect, bulk delete, widget ingest | `RecipePersistenceService`, `RecipeAnalysisService`, recipe scanner | Persistence results propagate HTTP status codes; fingerprint/index updates flow through the scanner before returning; validation errors surface as HTTP 400 without touching disk. | +| `RecipeAnalysisHandler` | Uploaded/local/remote analysis | `RecipeAnalysisService`, `civitai_client`, recipe scanner | Unsupported content types map to HTTP 400; download errors (`RecipeDownloadError`) are not retried; every response includes a `loras` array for client compatibility. | +| `RecipeSharingHandler` | Share + download | `RecipeSharingService`, recipe scanner | Share responses provide a stable download URL and filename; expired shares surface as HTTP 404; downloads stream via `web.FileResponse` with attachment headers. | + +## Use case boundaries + +The dedicated services encapsulate long-running work so handlers stay thin. + +| Use case | Entry point | Dependencies | Guarantees | +| --- | --- | --- | --- | +| `RecipeAnalysisService` | `analyze_uploaded_image`, `analyze_remote_image`, `analyze_local_image`, `analyze_widget_metadata` | `ExifUtils`, `RecipeParserFactory`, downloader factory, optional metadata collector/processor | Normalises missing/invalid payloads into `RecipeValidationError`; generates consistent fingerprint data to keep duplicate detection stable; temporary files are cleaned up after every analysis path. | +| `RecipePersistenceService` | `save_recipe`, `delete_recipe`, `update_recipe`, `reconnect_lora`, `bulk_delete`, `save_recipe_from_widget` | `ExifUtils`, recipe scanner, card preview sizing constants | Writes images/JSON metadata atomically; updates scanner caches and hash indices before returning; recalculates fingerprints whenever LoRA assignments change. | +| `RecipeSharingService` | `share_recipe`, `prepare_download` | `tempfile`, recipe scanner | Copies originals to TTL-managed temp files; metadata lookups re-use the scanner; expired shares trigger cleanup and `RecipeNotFoundError`. | + +## Maintaining critical invariants + +* **Cache updates** – Mutations (`save`, `delete`, `bulk_delete`, `update`) call + back into the recipe scanner to mutate the in-memory cache and fingerprint + index before returning a response. Tests assert that these methods are invoked + even when stubbing persistence. +* **Fingerprint management** – `RecipePersistenceService` recomputes + fingerprints whenever LoRA metadata changes and duplicate lookups use those + fingerprints to group recipes. Handlers bubble the resulting IDs so clients + can merge duplicates without an extra fetch. +* **Metadata synchronisation** – Saving or reconnecting a recipe updates the + JSON sidecar, refreshes embedded metadata via `ExifUtils`, and instructs the + scanner to resort its cache. Sharing relies on this metadata to generate + filenames and ensure downloads stay in sync with on-disk state. + +## Extending the stack + +1. Declare the new endpoint in `ROUTE_DEFINITIONS` with a unique handler name. +2. Implement the coroutine on an existing handler or introduce a new handler + class inside `py/routes/handlers/recipe_handlers.py` when the concern does + not fit existing ones. +3. Wire additional collaborators inside + `BaseRecipeRoutes._create_handler_set` (inject new services or factories) and + expose helper getters on the handler owner if the handler needs to share + utilities. + +Integration tests in `tests/routes/test_recipe_routes.py` exercise the listing, +mutation, analysis-error, and sharing paths end-to-end, ensuring the controller +and handler wiring remains valid as new capabilities are added. + diff --git a/py/lora_manager.py b/py/lora_manager.py index 1a99d508..ed37f27d 100644 --- a/py/lora_manager.py +++ b/py/lora_manager.py @@ -166,7 +166,7 @@ class LoraManager: RecipeRoutes.setup_routes(app) UpdateRoutes.setup_routes(app) MiscRoutes.setup_routes(app) - ExampleImagesRoutes.setup_routes(app) + ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager) # Setup WebSocket routes that are shared across all model types app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) diff --git a/py/routes/base_model_routes.py b/py/routes/base_model_routes.py index 458a5e87..84b9f43f 100644 --- a/py/routes/base_model_routes.py +++ b/py/routes/base_model_routes.py @@ -8,13 +8,29 @@ import jinja2 from aiohttp import web from ..config import config -from ..services.metadata_service import get_default_metadata_provider +from ..services.download_coordinator import DownloadCoordinator +from ..services.downloader import get_downloader +from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider +from ..services.metadata_sync_service import MetadataSyncService from ..services.model_file_service import ModelFileService, ModelMoveService -from ..services.settings_manager import settings as default_settings -from ..services.websocket_manager import ws_manager as default_ws_manager -from ..services.websocket_progress_callback import WebSocketProgressCallback +from ..services.model_lifecycle_service import ModelLifecycleService +from ..services.preview_asset_service import PreviewAssetService from ..services.server_i18n import server_i18n as default_server_i18n -from ..utils.routes_common import ModelRouteUtils +from ..services.service_registry import ServiceRegistry +from ..services.settings_manager import settings as default_settings +from ..services.tag_update_service import TagUpdateService +from ..services.websocket_manager import ws_manager as default_ws_manager +from ..services.use_cases import ( + AutoOrganizeUseCase, + BulkMetadataRefreshUseCase, + DownloadModelUseCase, +) +from ..services.websocket_progress_callback import ( + WebSocketBroadcastCallback, + WebSocketProgressCallback, +) +from ..utils.exif_utils import ExifUtils +from ..utils.metadata_manager import MetadataManager from .model_route_registrar import COMMON_ROUTE_DEFINITIONS, ModelRouteRegistrar from .handlers.model_handlers import ( ModelAutoOrganizeHandler, @@ -59,11 +75,31 @@ class BaseModelRoutes(ABC): self.model_file_service: ModelFileService | None = None self.model_move_service: ModelMoveService | None = None + self.model_lifecycle_service: ModelLifecycleService | None = None self.websocket_progress_callback = WebSocketProgressCallback() + self.metadata_progress_callback = WebSocketBroadcastCallback() self._handler_set: ModelHandlerSet | None = None self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None + self._preview_service = PreviewAssetService( + metadata_manager=MetadataManager, + downloader_factory=get_downloader, + exif_utils=ExifUtils, + ) + self._metadata_sync_service = MetadataSyncService( + metadata_manager=MetadataManager, + preview_service=self._preview_service, + settings=settings_service, + default_metadata_provider_factory=metadata_provider_factory, + metadata_provider_selector=get_metadata_provider, + ) + self._tag_update_service = TagUpdateService(metadata_manager=MetadataManager) + self._download_coordinator = DownloadCoordinator( + ws_manager=self._ws_manager, + download_manager_factory=ServiceRegistry.get_download_manager, + ) + if service is not None: self.attach_service(service) @@ -73,6 +109,12 @@ class BaseModelRoutes(ABC): self.model_type = service.model_type self.model_file_service = ModelFileService(service.scanner, service.model_type) self.model_move_service = ModelMoveService(service.scanner) + self.model_lifecycle_service = ModelLifecycleService( + scanner=service.scanner, + metadata_manager=MetadataManager, + metadata_loader=self._metadata_sync_service.load_local_metadata, + recipe_scanner_factory=ServiceRegistry.get_recipe_scanner, + ) self._handler_set = None self._handler_mapping = None @@ -98,9 +140,28 @@ class BaseModelRoutes(ABC): parse_specific_params=self._parse_specific_params, logger=logger, ) - management = ModelManagementHandler(service=service, logger=logger) + management = ModelManagementHandler( + service=service, + logger=logger, + metadata_sync=self._metadata_sync_service, + preview_service=self._preview_service, + tag_update_service=self._tag_update_service, + lifecycle_service=self._ensure_lifecycle_service(), + ) query = ModelQueryHandler(service=service, logger=logger) - download = ModelDownloadHandler(ws_manager=self._ws_manager, logger=logger) + download_use_case = DownloadModelUseCase(download_coordinator=self._download_coordinator) + download = ModelDownloadHandler( + ws_manager=self._ws_manager, + logger=logger, + download_use_case=download_use_case, + download_coordinator=self._download_coordinator, + ) + metadata_refresh_use_case = BulkMetadataRefreshUseCase( + service=service, + metadata_sync=self._metadata_sync_service, + settings_service=self._settings, + logger=logger, + ) civitai = ModelCivitaiHandler( service=service, settings_service=self._settings, @@ -110,10 +171,17 @@ class BaseModelRoutes(ABC): validate_model_type=self._validate_civitai_model_type, expected_model_types=self._get_expected_model_types, find_model_file=self._find_model_file, + metadata_sync=self._metadata_sync_service, + metadata_refresh_use_case=metadata_refresh_use_case, + metadata_progress_callback=self.metadata_progress_callback, ) move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger) - auto_organize = ModelAutoOrganizeHandler( + auto_organize_use_case = AutoOrganizeUseCase( file_service=self._ensure_file_service(), + lock_provider=self._ws_manager, + ) + auto_organize = ModelAutoOrganizeHandler( + use_case=auto_organize_use_case, progress_callback=self.websocket_progress_callback, ws_manager=self._ws_manager, logger=logger, @@ -167,10 +235,6 @@ class BaseModelRoutes(ABC): """Expose handlers for subclasses or tests.""" return self._ensure_handler_mapping()[name] - @property - def utils(self) -> ModelRouteUtils: # pragma: no cover - compatibility shim - return ModelRouteUtils - def _ensure_service(self): if self.service is None: raise RuntimeError("Model service has not been attached") @@ -188,6 +252,17 @@ class BaseModelRoutes(ABC): self.model_move_service = ModelMoveService(service.scanner) return self.model_move_service + def _ensure_lifecycle_service(self) -> ModelLifecycleService: + if self.model_lifecycle_service is None: + service = self._ensure_service() + self.model_lifecycle_service = ModelLifecycleService( + scanner=service.scanner, + metadata_manager=MetadataManager, + metadata_loader=self._metadata_sync_service.load_local_metadata, + recipe_scanner_factory=ServiceRegistry.get_recipe_scanner, + ) + return self.model_lifecycle_service + def _make_handler_proxy(self, name: str) -> Callable[[web.Request], web.StreamResponse]: async def proxy(request: web.Request) -> web.StreamResponse: try: diff --git a/py/routes/base_recipe_routes.py b/py/routes/base_recipe_routes.py new file mode 100644 index 00000000..4447bb7b --- /dev/null +++ b/py/routes/base_recipe_routes.py @@ -0,0 +1,217 @@ +"""Base infrastructure shared across recipe routes.""" +from __future__ import annotations + +import logging +import os +from typing import Callable, Mapping + +import jinja2 +from aiohttp import web + +from ..config import config +from ..recipes import RecipeParserFactory +from ..services.downloader import get_downloader +from ..services.recipes import ( + RecipeAnalysisService, + RecipePersistenceService, + RecipeSharingService, +) +from ..services.server_i18n import server_i18n +from ..services.service_registry import ServiceRegistry +from ..services.settings_manager import settings +from ..utils.constants import CARD_PREVIEW_WIDTH +from ..utils.exif_utils import ExifUtils +from .handlers.recipe_handlers import ( + RecipeAnalysisHandler, + RecipeHandlerSet, + RecipeListingHandler, + RecipeManagementHandler, + RecipePageView, + RecipeQueryHandler, + RecipeSharingHandler, +) +from .recipe_route_registrar import ROUTE_DEFINITIONS + +logger = logging.getLogger(__name__) + + +class BaseRecipeRoutes: + """Common dependency and startup wiring for recipe routes.""" + + _HANDLER_NAMES: tuple[str, ...] = tuple( + definition.handler_name for definition in ROUTE_DEFINITIONS + ) + + template_name: str = "recipes.html" + + def __init__(self) -> None: + self.recipe_scanner = None + self.lora_scanner = None + self.civitai_client = None + self.settings = settings + self.server_i18n = server_i18n + self.template_env = jinja2.Environment( + loader=jinja2.FileSystemLoader(config.templates_path), + autoescape=True, + ) + + self._i18n_registered = False + self._startup_hooks_registered = False + self._handler_set: RecipeHandlerSet | None = None + self._handler_mapping: dict[str, Callable] | None = None + + async def attach_dependencies(self, app: web.Application | None = None) -> None: + """Resolve shared services from the registry.""" + + await self._ensure_services() + self._ensure_i18n_filter() + + async def ensure_dependencies_ready(self) -> None: + """Ensure dependencies are available for request handlers.""" + + if self.recipe_scanner is None or self.civitai_client is None: + await self.attach_dependencies() + + def register_startup_hooks(self, app: web.Application) -> None: + """Register startup hooks once for dependency wiring.""" + + if self._startup_hooks_registered: + return + + app.on_startup.append(self.attach_dependencies) + app.on_startup.append(self.prewarm_cache) + self._startup_hooks_registered = True + + async def prewarm_cache(self, app: web.Application | None = None) -> None: + """Pre-load recipe and LoRA caches on startup.""" + + try: + await self.attach_dependencies(app) + + if self.lora_scanner is not None: + await self.lora_scanner.get_cached_data() + hash_index = getattr(self.lora_scanner, "_hash_index", None) + if hash_index is not None and hasattr(hash_index, "_hash_to_path"): + _ = len(hash_index._hash_to_path) + + if self.recipe_scanner is not None: + await self.recipe_scanner.get_cached_data(force_refresh=True) + except Exception as exc: + logger.error("Error pre-warming recipe cache: %s", exc, exc_info=True) + + def to_route_mapping(self) -> Mapping[str, Callable]: + """Return a mapping of handler name to coroutine for registrar binding.""" + + if self._handler_mapping is None: + handler_set = self._create_handler_set() + self._handler_set = handler_set + self._handler_mapping = handler_set.to_route_mapping() + return self._handler_mapping + + # Internal helpers ------------------------------------------------- + + async def _ensure_services(self) -> None: + if self.recipe_scanner is None: + self.recipe_scanner = await ServiceRegistry.get_recipe_scanner() + self.lora_scanner = getattr(self.recipe_scanner, "_lora_scanner", None) + + if self.civitai_client is None: + self.civitai_client = await ServiceRegistry.get_civitai_client() + + def _ensure_i18n_filter(self) -> None: + if not self._i18n_registered: + self.template_env.filters["t"] = self.server_i18n.create_template_filter() + self._i18n_registered = True + + def get_handler_owner(self): + """Return the object supplying bound handler coroutines.""" + + if self._handler_set is None: + self._handler_set = self._create_handler_set() + return self._handler_set + + def _create_handler_set(self) -> RecipeHandlerSet: + recipe_scanner_getter = lambda: self.recipe_scanner + civitai_client_getter = lambda: self.civitai_client + + standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" + if not standalone_mode: + from ..metadata_collector import get_metadata # type: ignore[import-not-found] + from ..metadata_collector.metadata_processor import ( # type: ignore[import-not-found] + MetadataProcessor, + ) + from ..metadata_collector.metadata_registry import ( # type: ignore[import-not-found] + MetadataRegistry, + ) + else: # pragma: no cover - optional dependency path + get_metadata = None # type: ignore[assignment] + MetadataProcessor = None # type: ignore[assignment] + MetadataRegistry = None # type: ignore[assignment] + + analysis_service = RecipeAnalysisService( + exif_utils=ExifUtils, + recipe_parser_factory=RecipeParserFactory, + downloader_factory=get_downloader, + metadata_collector=get_metadata, + metadata_processor_cls=MetadataProcessor, + metadata_registry_cls=MetadataRegistry, + standalone_mode=standalone_mode, + logger=logger, + ) + persistence_service = RecipePersistenceService( + exif_utils=ExifUtils, + card_preview_width=CARD_PREVIEW_WIDTH, + logger=logger, + ) + sharing_service = RecipeSharingService(logger=logger) + + page_view = RecipePageView( + ensure_dependencies_ready=self.ensure_dependencies_ready, + settings_service=self.settings, + server_i18n=self.server_i18n, + template_env=self.template_env, + template_name=self.template_name, + recipe_scanner_getter=recipe_scanner_getter, + logger=logger, + ) + listing = RecipeListingHandler( + ensure_dependencies_ready=self.ensure_dependencies_ready, + recipe_scanner_getter=recipe_scanner_getter, + logger=logger, + ) + query = RecipeQueryHandler( + ensure_dependencies_ready=self.ensure_dependencies_ready, + recipe_scanner_getter=recipe_scanner_getter, + format_recipe_file_url=listing.format_recipe_file_url, + logger=logger, + ) + management = RecipeManagementHandler( + ensure_dependencies_ready=self.ensure_dependencies_ready, + recipe_scanner_getter=recipe_scanner_getter, + logger=logger, + persistence_service=persistence_service, + analysis_service=analysis_service, + ) + analysis = RecipeAnalysisHandler( + ensure_dependencies_ready=self.ensure_dependencies_ready, + recipe_scanner_getter=recipe_scanner_getter, + civitai_client_getter=civitai_client_getter, + logger=logger, + analysis_service=analysis_service, + ) + sharing = RecipeSharingHandler( + ensure_dependencies_ready=self.ensure_dependencies_ready, + recipe_scanner_getter=recipe_scanner_getter, + logger=logger, + sharing_service=sharing_service, + ) + + return RecipeHandlerSet( + page_view=page_view, + listing=listing, + query=query, + management=management, + analysis=analysis, + sharing=sharing, + ) + diff --git a/py/routes/example_images_route_registrar.py b/py/routes/example_images_route_registrar.py new file mode 100644 index 00000000..d0f1fab0 --- /dev/null +++ b/py/routes/example_images_route_registrar.py @@ -0,0 +1,61 @@ +"""Route registrar for example image endpoints.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Iterable, Mapping + +from aiohttp import web + + +@dataclass(frozen=True) +class RouteDefinition: + """Declarative configuration for a HTTP route.""" + + method: str + path: str + handler_name: str + + +ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( + RouteDefinition("POST", "/api/lm/download-example-images", "download_example_images"), + RouteDefinition("POST", "/api/lm/import-example-images", "import_example_images"), + RouteDefinition("GET", "/api/lm/example-images-status", "get_example_images_status"), + RouteDefinition("POST", "/api/lm/pause-example-images", "pause_example_images"), + RouteDefinition("POST", "/api/lm/resume-example-images", "resume_example_images"), + RouteDefinition("POST", "/api/lm/open-example-images-folder", "open_example_images_folder"), + RouteDefinition("GET", "/api/lm/example-image-files", "get_example_image_files"), + RouteDefinition("GET", "/api/lm/has-example-images", "has_example_images"), + RouteDefinition("POST", "/api/lm/delete-example-image", "delete_example_image"), + RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"), +) + + +class ExampleImagesRouteRegistrar: + """Bind declarative example image routes to an aiohttp router.""" + + _METHOD_MAP = { + "GET": "add_get", + "POST": "add_post", + "PUT": "add_put", + "DELETE": "add_delete", + } + + def __init__(self, app: web.Application) -> None: + self._app = app + + def register_routes( + self, + handler_lookup: Mapping[str, Callable[[web.Request], object]], + *, + definitions: Iterable[RouteDefinition] = ROUTE_DEFINITIONS, + ) -> None: + """Register each route definition using the supplied handlers.""" + + for definition in definitions: + handler = handler_lookup[definition.handler_name] + self._bind_route(definition.method, definition.path, handler) + + def _bind_route(self, method: str, path: str, handler: Callable[[web.Request], object]) -> None: + add_method_name = self._METHOD_MAP[method.upper()] + add_method = getattr(self._app.router, add_method_name) + add_method(path, handler) diff --git a/py/routes/example_images_routes.py b/py/routes/example_images_routes.py index 07cb0e71..5073410d 100644 --- a/py/routes/example_images_routes.py +++ b/py/routes/example_images_routes.py @@ -1,74 +1,81 @@ +from __future__ import annotations + import logging -from ..utils.example_images_download_manager import DownloadManager -from ..utils.example_images_processor import ExampleImagesProcessor +from typing import Callable, Mapping + +from aiohttp import web + +from .example_images_route_registrar import ExampleImagesRouteRegistrar +from .handlers.example_images_handlers import ( + ExampleImagesDownloadHandler, + ExampleImagesFileHandler, + ExampleImagesHandlerSet, + ExampleImagesManagementHandler, +) +from ..services.use_cases.example_images import ( + DownloadExampleImagesUseCase, + ImportExampleImagesUseCase, +) +from ..utils.example_images_download_manager import ( + DownloadManager, + get_default_download_manager, +) from ..utils.example_images_file_manager import ExampleImagesFileManager -from ..services.websocket_manager import ws_manager +from ..utils.example_images_processor import ExampleImagesProcessor logger = logging.getLogger(__name__) + class ExampleImagesRoutes: - """Routes for example images related functionality""" - - @staticmethod - def setup_routes(app): - """Register example images routes""" - app.router.add_post('/api/lm/download-example-images', ExampleImagesRoutes.download_example_images) - app.router.add_post('/api/lm/import-example-images', ExampleImagesRoutes.import_example_images) - app.router.add_get('/api/lm/example-images-status', ExampleImagesRoutes.get_example_images_status) - app.router.add_post('/api/lm/pause-example-images', ExampleImagesRoutes.pause_example_images) - app.router.add_post('/api/lm/resume-example-images', ExampleImagesRoutes.resume_example_images) - app.router.add_post('/api/lm/open-example-images-folder', ExampleImagesRoutes.open_example_images_folder) - app.router.add_get('/api/lm/example-image-files', ExampleImagesRoutes.get_example_image_files) - app.router.add_get('/api/lm/has-example-images', ExampleImagesRoutes.has_example_images) - app.router.add_post('/api/lm/delete-example-image', ExampleImagesRoutes.delete_example_image) - app.router.add_post('/api/lm/force-download-example-images', ExampleImagesRoutes.force_download_example_images) + """Route controller for example image endpoints.""" - @staticmethod - async def download_example_images(request): - """Download example images for models from Civitai""" - return await DownloadManager.start_download(request) + def __init__( + self, + *, + ws_manager, + download_manager: DownloadManager | None = None, + processor=ExampleImagesProcessor, + file_manager=ExampleImagesFileManager, + ) -> None: + if ws_manager is None: + raise ValueError("ws_manager is required") + self._download_manager = download_manager or get_default_download_manager(ws_manager) + self._processor = processor + self._file_manager = file_manager + self._handler_set: ExampleImagesHandlerSet | None = None + self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None - @staticmethod - async def get_example_images_status(request): - """Get the current status of example images download""" - return await DownloadManager.get_status(request) + @classmethod + def setup_routes(cls, app: web.Application, *, ws_manager) -> None: + """Register routes on the given aiohttp application using default wiring.""" - @staticmethod - async def pause_example_images(request): - """Pause the example images download""" - return await DownloadManager.pause_download(request) + controller = cls(ws_manager=ws_manager) + controller.register(app) - @staticmethod - async def resume_example_images(request): - """Resume the example images download""" - return await DownloadManager.resume_download(request) - - @staticmethod - async def open_example_images_folder(request): - """Open the example images folder for a specific model""" - return await ExampleImagesFileManager.open_folder(request) + def register(self, app: web.Application) -> None: + """Bind the controller's handlers to the aiohttp router.""" - @staticmethod - async def get_example_image_files(request): - """Get list of example image files for a specific model""" - return await ExampleImagesFileManager.get_files(request) + registrar = ExampleImagesRouteRegistrar(app) + registrar.register_routes(self.to_route_mapping()) - @staticmethod - async def import_example_images(request): - """Import local example images for a model""" - return await ExampleImagesProcessor.import_images(request) - - @staticmethod - async def has_example_images(request): - """Check if example images folder exists and is not empty for a model""" - return await ExampleImagesFileManager.has_images(request) + def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]: + """Return the registrar-compatible mapping of handler names to callables.""" - @staticmethod - async def delete_example_image(request): - """Delete a custom example image for a model""" - return await ExampleImagesProcessor.delete_custom_image(request) + if self._handler_mapping is None: + handler_set = self._build_handler_set() + self._handler_set = handler_set + self._handler_mapping = handler_set.to_route_mapping() + return self._handler_mapping - @staticmethod - async def force_download_example_images(request): - """Force download example images for specific models""" - return await DownloadManager.start_force_download(request) \ No newline at end of file + def _build_handler_set(self) -> ExampleImagesHandlerSet: + logger.debug("Building ExampleImagesHandlerSet with %s, %s, %s", self._download_manager, self._processor, self._file_manager) + download_use_case = DownloadExampleImagesUseCase(download_manager=self._download_manager) + download_handler = ExampleImagesDownloadHandler(download_use_case, self._download_manager) + import_use_case = ImportExampleImagesUseCase(processor=self._processor) + management_handler = ExampleImagesManagementHandler(import_use_case, self._processor) + file_handler = ExampleImagesFileHandler(self._file_manager) + return ExampleImagesHandlerSet( + download=download_handler, + management=management_handler, + files=file_handler, + ) diff --git a/py/routes/handlers/example_images_handlers.py b/py/routes/handlers/example_images_handlers.py new file mode 100644 index 00000000..fd39de04 --- /dev/null +++ b/py/routes/handlers/example_images_handlers.py @@ -0,0 +1,147 @@ +"""Handler set for example image routes.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Mapping + +from aiohttp import web + +from ...services.use_cases.example_images import ( + DownloadExampleImagesConfigurationError, + DownloadExampleImagesInProgressError, + DownloadExampleImagesUseCase, + ImportExampleImagesUseCase, + ImportExampleImagesValidationError, +) +from ...utils.example_images_download_manager import ( + DownloadConfigurationError, + DownloadInProgressError, + DownloadNotRunningError, + ExampleImagesDownloadError, +) +from ...utils.example_images_processor import ExampleImagesImportError + + +class ExampleImagesDownloadHandler: + """HTTP adapters for download-related example image endpoints.""" + + def __init__( + self, + download_use_case: DownloadExampleImagesUseCase, + download_manager, + ) -> None: + self._download_use_case = download_use_case + self._download_manager = download_manager + + async def download_example_images(self, request: web.Request) -> web.StreamResponse: + try: + payload = await request.json() + result = await self._download_use_case.execute(payload) + return web.json_response(result) + except DownloadExampleImagesInProgressError as exc: + response = { + 'success': False, + 'error': str(exc), + 'status': exc.progress, + } + return web.json_response(response, status=400) + except DownloadExampleImagesConfigurationError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) + except ExampleImagesDownloadError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=500) + + async def get_example_images_status(self, request: web.Request) -> web.StreamResponse: + result = await self._download_manager.get_status(request) + return web.json_response(result) + + async def pause_example_images(self, request: web.Request) -> web.StreamResponse: + try: + result = await self._download_manager.pause_download(request) + return web.json_response(result) + except DownloadNotRunningError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) + + async def resume_example_images(self, request: web.Request) -> web.StreamResponse: + try: + result = await self._download_manager.resume_download(request) + return web.json_response(result) + except DownloadNotRunningError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) + + async def force_download_example_images(self, request: web.Request) -> web.StreamResponse: + try: + payload = await request.json() + result = await self._download_manager.start_force_download(payload) + return web.json_response(result) + except DownloadInProgressError as exc: + response = { + 'success': False, + 'error': str(exc), + 'status': exc.progress_snapshot, + } + return web.json_response(response, status=400) + except DownloadConfigurationError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) + except ExampleImagesDownloadError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=500) + + +class ExampleImagesManagementHandler: + """HTTP adapters for import/delete endpoints.""" + + def __init__(self, import_use_case: ImportExampleImagesUseCase, processor) -> None: + self._import_use_case = import_use_case + self._processor = processor + + async def import_example_images(self, request: web.Request) -> web.StreamResponse: + try: + result = await self._import_use_case.execute(request) + return web.json_response(result) + except ImportExampleImagesValidationError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=400) + except ExampleImagesImportError as exc: + return web.json_response({'success': False, 'error': str(exc)}, status=500) + + async def delete_example_image(self, request: web.Request) -> web.StreamResponse: + return await self._processor.delete_custom_image(request) + + +class ExampleImagesFileHandler: + """HTTP adapters for filesystem-centric endpoints.""" + + def __init__(self, file_manager) -> None: + self._file_manager = file_manager + + async def open_example_images_folder(self, request: web.Request) -> web.StreamResponse: + return await self._file_manager.open_folder(request) + + async def get_example_image_files(self, request: web.Request) -> web.StreamResponse: + return await self._file_manager.get_files(request) + + async def has_example_images(self, request: web.Request) -> web.StreamResponse: + return await self._file_manager.has_images(request) + + +@dataclass(frozen=True) +class ExampleImagesHandlerSet: + """Aggregate of handlers exposed to the registrar.""" + + download: ExampleImagesDownloadHandler + management: ExampleImagesManagementHandler + files: ExampleImagesFileHandler + + def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], web.StreamResponse]]: + """Flatten handler methods into the registrar mapping.""" + + return { + "download_example_images": self.download.download_example_images, + "get_example_images_status": self.download.get_example_images_status, + "pause_example_images": self.download.pause_example_images, + "resume_example_images": self.download.resume_example_images, + "force_download_example_images": self.download.force_download_example_images, + "import_example_images": self.management.import_example_images, + "delete_example_image": self.management.delete_example_image, + "open_example_images_folder": self.files.open_example_images_folder, + "get_example_image_files": self.files.get_example_image_files, + "has_example_images": self.files.has_example_images, + } diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 66a7123a..a6fe4091 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -4,17 +4,32 @@ from __future__ import annotations import asyncio import json import logging +import os from dataclasses import dataclass from typing import Awaitable, Callable, Dict, Iterable, Mapping, Optional from aiohttp import web import jinja2 -from ...services.model_file_service import ModelFileService, ModelMoveService -from ...services.websocket_progress_callback import WebSocketProgressCallback -from ...services.websocket_manager import WebSocketManager +from ...config import config +from ...services.download_coordinator import DownloadCoordinator +from ...services.metadata_sync_service import MetadataSyncService +from ...services.model_file_service import ModelMoveService +from ...services.preview_asset_service import PreviewAssetService from ...services.settings_manager import SettingsManager -from ...utils.routes_common import ModelRouteUtils +from ...services.tag_update_service import TagUpdateService +from ...services.use_cases import ( + AutoOrganizeInProgressError, + AutoOrganizeUseCase, + BulkMetadataRefreshUseCase, + DownloadModelEarlyAccessError, + DownloadModelUseCase, + DownloadModelValidationError, + MetadataRefreshProgressReporter, +) +from ...services.websocket_manager import WebSocketManager +from ...services.websocket_progress_callback import WebSocketProgressCallback +from ...utils.file_utils import calculate_sha256 class ModelPageView: @@ -168,15 +183,52 @@ class ModelListingHandler: class ModelManagementHandler: """Handle mutation operations on models.""" - def __init__(self, *, service, logger: logging.Logger) -> None: + def __init__( + self, + *, + service, + logger: logging.Logger, + metadata_sync: MetadataSyncService, + preview_service: PreviewAssetService, + tag_update_service: TagUpdateService, + lifecycle_service, + ) -> None: self._service = service self._logger = logger + self._metadata_sync = metadata_sync + self._preview_service = preview_service + self._tag_update_service = tag_update_service + self._lifecycle_service = lifecycle_service async def delete_model(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_delete_model(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + if not file_path: + return web.Response(text="Model path is required", status=400) + + result = await self._lifecycle_service.delete_model(file_path) + return web.json_response(result) + except ValueError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except Exception as exc: + self._logger.error("Error deleting model: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) async def exclude_model(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_exclude_model(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + if not file_path: + return web.Response(text="Model path is required", status=400) + + result = await self._lifecycle_service.exclude_model(file_path) + return web.json_response(result) + except ValueError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except Exception as exc: + self._logger.error("Error excluding model: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) async def fetch_civitai(self, request: web.Request) -> web.Response: try: @@ -192,7 +244,7 @@ class ModelManagementHandler: if not model_data.get("sha256"): return web.json_response({"success": False, "error": "No SHA256 hash found"}, status=400) - success, error = await ModelRouteUtils.fetch_and_update_model( + success, error = await self._metadata_sync.fetch_and_update_model( sha256=model_data["sha256"], file_path=file_path, model_data=model_data, @@ -208,25 +260,221 @@ class ModelManagementHandler: return web.json_response({"success": False, "error": str(exc)}, status=500) async def relink_civitai(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_relink_civitai(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + model_id = data.get("model_id") + model_version_id = data.get("model_version_id") + + if not file_path or model_id is None: + return web.json_response( + {"success": False, "error": "Both file_path and model_id are required"}, + status=400, + ) + + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + local_metadata = await self._metadata_sync.load_local_metadata(metadata_path) + + updated_metadata = await self._metadata_sync.relink_metadata( + file_path=file_path, + metadata=local_metadata, + model_id=int(model_id), + model_version_id=int(model_version_id) if model_version_id else None, + ) + + await self._service.scanner.update_single_model_cache( + file_path, file_path, updated_metadata + ) + + message = ( + f"Model successfully re-linked to Civitai model {model_id}" + + (f" version {model_version_id}" if model_version_id else "") + ) + return web.json_response( + {"success": True, "message": message, "hash": updated_metadata.get("sha256", "")} + ) + except Exception as exc: + self._logger.error("Error re-linking to CivitAI: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) async def replace_preview(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_replace_preview(request, self._service.scanner) + try: + reader = await request.multipart() + + field = await reader.next() + if field is None or field.name != "preview_file": + raise ValueError("Expected 'preview_file' field") + content_type = field.headers.get("Content-Type", "image/png") + content_disposition = field.headers.get("Content-Disposition", "") + + original_filename = None + import re + + match = re.search(r'filename="(.*?)"', content_disposition) + if match: + original_filename = match.group(1) + + preview_data = await field.read() + + field = await reader.next() + if field is None or field.name != "model_path": + raise ValueError("Expected 'model_path' field") + model_path = (await field.read()).decode() + + nsfw_level = 0 + field = await reader.next() + if field and field.name == "nsfw_level": + try: + nsfw_level = int((await field.read()).decode()) + except (ValueError, TypeError): + self._logger.warning("Invalid NSFW level format, using default 0") + + result = await self._preview_service.replace_preview( + model_path=model_path, + preview_data=preview_data, + content_type=content_type, + original_filename=original_filename, + nsfw_level=nsfw_level, + update_preview_in_cache=self._service.scanner.update_preview_in_cache, + metadata_loader=self._metadata_sync.load_local_metadata, + ) + + return web.json_response( + { + "success": True, + "preview_url": config.get_preview_static_url(result["preview_path"]), + "preview_nsfw_level": result["preview_nsfw_level"], + } + ) + except Exception as exc: + self._logger.error("Error replacing preview: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) async def save_metadata(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_save_metadata(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + if not file_path: + return web.Response(text="File path is required", status=400) + + metadata_updates = {k: v for k, v in data.items() if k != "file_path"} + + await self._metadata_sync.save_metadata_updates( + file_path=file_path, + updates=metadata_updates, + metadata_loader=self._metadata_sync.load_local_metadata, + update_cache=self._service.scanner.update_single_model_cache, + ) + + if "model_name" in metadata_updates: + cache = await self._service.scanner.get_cached_data() + await cache.resort() + + return web.json_response({"success": True}) + except Exception as exc: + self._logger.error("Error saving metadata: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) async def add_tags(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_add_tags(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + new_tags = data.get("tags", []) + + if not file_path: + return web.Response(text="File path is required", status=400) + + if not isinstance(new_tags, list): + return web.Response(text="Tags must be a list", status=400) + + tags = await self._tag_update_service.add_tags( + file_path=file_path, + new_tags=new_tags, + metadata_loader=self._metadata_sync.load_local_metadata, + update_cache=self._service.scanner.update_single_model_cache, + ) + + return web.json_response({"success": True, "tags": tags}) + except Exception as exc: + self._logger.error("Error adding tags: %s", exc, exc_info=True) + return web.Response(text=str(exc), status=500) async def rename_model(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_rename_model(request, self._service.scanner) + try: + data = await request.json() + file_path = data.get("file_path") + new_file_name = data.get("new_file_name") + + if not file_path or not new_file_name: + return web.json_response( + { + "success": False, + "error": "File path and new file name are required", + }, + status=400, + ) + + result = await self._lifecycle_service.rename_model( + file_path=file_path, new_file_name=new_file_name + ) + + return web.json_response( + { + **result, + "new_preview_path": config.get_preview_static_url( + result.get("new_preview_path") + ), + } + ) + except ValueError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except Exception as exc: + self._logger.error("Error renaming model: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) async def bulk_delete_models(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_bulk_delete_models(request, self._service.scanner) + try: + data = await request.json() + file_paths = data.get("file_paths", []) + if not file_paths: + return web.json_response( + { + "success": False, + "error": "No file paths provided for deletion", + }, + status=400, + ) + + result = await self._lifecycle_service.bulk_delete_models(file_paths) + return web.json_response(result) + except ValueError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except Exception as exc: + self._logger.error("Error in bulk delete: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) async def verify_duplicates(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_verify_duplicates(request, self._service.scanner) + try: + data = await request.json() + file_paths = data.get("file_paths", []) + + if not file_paths: + return web.json_response( + {"success": False, "error": "No file paths provided for verification"}, + status=400, + ) + + results = await self._metadata_sync.verify_duplicate_hashes( + file_paths=file_paths, + metadata_loader=self._metadata_sync.load_local_metadata, + hash_calculator=calculate_sha256, + update_cache=self._service.scanner.update_single_model_cache, + ) + + return web.json_response({"success": True, **results}) + except Exception as exc: + self._logger.error("Error verifying duplicate models: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) class ModelQueryHandler: @@ -429,12 +677,35 @@ class ModelQueryHandler: class ModelDownloadHandler: """Coordinate downloads and progress reporting.""" - def __init__(self, *, ws_manager: WebSocketManager, logger: logging.Logger) -> None: + def __init__( + self, + *, + ws_manager: WebSocketManager, + logger: logging.Logger, + download_use_case: DownloadModelUseCase, + download_coordinator: DownloadCoordinator, + ) -> None: self._ws_manager = ws_manager self._logger = logger + self._download_use_case = download_use_case + self._download_coordinator = download_coordinator async def download_model(self, request: web.Request) -> web.Response: - return await ModelRouteUtils.handle_download_model(request) + try: + payload = await request.json() + result = await self._download_use_case.execute(payload) + if not result.get("success", False): + return web.json_response(result, status=500) + return web.json_response(result) + except DownloadModelValidationError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except DownloadModelEarlyAccessError as exc: + self._logger.warning("Early access error: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=401) + except Exception as exc: + error_message = str(exc) + self._logger.error("Error downloading model: %s", error_message, exc_info=True) + return web.json_response({"success": False, "error": error_message}, status=500) async def download_model_get(self, request: web.Request) -> web.Response: try: @@ -460,7 +731,15 @@ class ModelDownloadHandler: future.set_result(data) mock_request = type("MockRequest", (), {"json": lambda self=None: future})() - return await ModelRouteUtils.handle_download_model(mock_request) + result = await self._download_use_case.execute(data) + if not result.get("success", False): + return web.json_response(result, status=500) + return web.json_response(result) + except DownloadModelValidationError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except DownloadModelEarlyAccessError as exc: + self._logger.warning("Early access error: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=401) except Exception as exc: self._logger.error("Error downloading model via GET: %s", exc, exc_info=True) return web.Response(status=500, text=str(exc)) @@ -470,8 +749,8 @@ class ModelDownloadHandler: download_id = request.query.get("download_id") if not download_id: return web.json_response({"success": False, "error": "Download ID is required"}, status=400) - mock_request = type("MockRequest", (), {"match_info": {"download_id": download_id}})() - return await ModelRouteUtils.handle_cancel_download(mock_request) + result = await self._download_coordinator.cancel_download(download_id) + return web.json_response(result) except Exception as exc: self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) @@ -504,6 +783,9 @@ class ModelCivitaiHandler: validate_model_type: Callable[[str], bool], expected_model_types: Callable[[], str], find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]], + metadata_sync: MetadataSyncService, + metadata_refresh_use_case: BulkMetadataRefreshUseCase, + metadata_progress_callback: MetadataRefreshProgressReporter, ) -> None: self._service = service self._settings = settings_service @@ -513,75 +795,17 @@ class ModelCivitaiHandler: self._validate_model_type = validate_model_type self._expected_model_types = expected_model_types self._find_model_file = find_model_file + self._metadata_sync = metadata_sync + self._metadata_refresh_use_case = metadata_refresh_use_case + self._metadata_progress_callback = metadata_progress_callback async def fetch_all_civitai(self, request: web.Request) -> web.Response: try: - cache = await self._service.scanner.get_cached_data() - total = len(cache.raw_data) - processed = 0 - success = 0 - needs_resort = False - - enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False) - to_process = [ - model - for model in cache.raw_data - if model.get("sha256") - and (not model.get("civitai") or not model["civitai"].get("id")) - and ( - (enable_metadata_archive_db and not model.get("db_checked", False)) - or (not enable_metadata_archive_db and model.get("from_civitai") is True) - ) - ] - total_to_process = len(to_process) - - await self._ws_manager.broadcast({ - "status": "started", - "total": total_to_process, - "processed": 0, - "success": 0, - }) - - for model in to_process: - try: - original_name = model.get("model_name") - result, error = await ModelRouteUtils.fetch_and_update_model( - sha256=model["sha256"], - file_path=model["file_path"], - model_data=model, - update_cache_func=self._service.scanner.update_single_model_cache, - ) - if result: - success += 1 - if original_name != model.get("model_name"): - needs_resort = True - processed += 1 - await self._ws_manager.broadcast({ - "status": "processing", - "total": total_to_process, - "processed": processed, - "success": success, - "current_name": model.get("model_name", "Unknown"), - }) - except Exception as exc: # pragma: no cover - logging path - self._logger.error("Error fetching CivitAI data for %s: %s", model["file_path"], exc) - - if needs_resort: - await cache.resort() - - await self._ws_manager.broadcast({ - "status": "completed", - "total": total_to_process, - "processed": processed, - "success": success, - }) - - return web.json_response({ - "success": True, - "message": f"Successfully updated {success} of {processed} processed {self._service.model_type}s (total: {total})", - }) + result = await self._metadata_refresh_use_case.execute_with_error_handling( + progress_callback=self._metadata_progress_callback + ) + return web.json_response(result) except Exception as exc: - await self._ws_manager.broadcast({"status": "error", "error": str(exc)}) self._logger.error("Error in fetch_all_civitai for %ss: %s", self._service.model_type, exc) return web.Response(text=str(exc), status=500) @@ -687,31 +911,18 @@ class ModelAutoOrganizeHandler: def __init__( self, *, - file_service: ModelFileService, + use_case: AutoOrganizeUseCase, progress_callback: WebSocketProgressCallback, ws_manager: WebSocketManager, logger: logging.Logger, ) -> None: - self._file_service = file_service + self._use_case = use_case self._progress_callback = progress_callback self._ws_manager = ws_manager self._logger = logger async def auto_organize_models(self, request: web.Request) -> web.Response: try: - if self._ws_manager.is_auto_organize_running(): - return web.json_response( - {"success": False, "error": "Auto-organize is already running. Please wait for it to complete."}, - status=409, - ) - - auto_organize_lock = await self._ws_manager.get_auto_organize_lock() - if auto_organize_lock.locked(): - return web.json_response( - {"success": False, "error": "Auto-organize is already running. Please wait for it to complete."}, - status=409, - ) - file_paths = None if request.method == "POST": try: @@ -720,17 +931,24 @@ class ModelAutoOrganizeHandler: except Exception: # pragma: no cover - permissive path pass - async with auto_organize_lock: - result = await self._file_service.auto_organize_models( - file_paths=file_paths, - progress_callback=self._progress_callback, - ) - return web.json_response(result.to_dict()) + result = await self._use_case.execute( + file_paths=file_paths, + progress_callback=self._progress_callback, + ) + return web.json_response(result.to_dict()) + except AutoOrganizeInProgressError: + return web.json_response( + {"success": False, "error": "Auto-organize is already running. Please wait for it to complete."}, + status=409, + ) except Exception as exc: self._logger.error("Error in auto_organize_models: %s", exc, exc_info=True) - await self._ws_manager.broadcast_auto_organize_progress( - {"type": "auto_organize_progress", "status": "error", "error": str(exc)} - ) + try: + await self._progress_callback.on_progress( + {"type": "auto_organize_progress", "status": "error", "error": str(exc)} + ) + except Exception: # pragma: no cover - defensive reporting + pass return web.json_response({"success": False, "error": str(exc)}, status=500) async def get_auto_organize_progress(self, request: web.Request) -> web.Response: diff --git a/py/routes/handlers/recipe_handlers.py b/py/routes/handlers/recipe_handlers.py new file mode 100644 index 00000000..aa912477 --- /dev/null +++ b/py/routes/handlers/recipe_handlers.py @@ -0,0 +1,725 @@ +"""Dedicated handler objects for recipe-related routes.""" +from __future__ import annotations + +import json +import logging +import os +from dataclasses import dataclass +from typing import Any, Awaitable, Callable, Dict, Mapping, Optional + +from aiohttp import web + +from ...config import config +from ...services.server_i18n import server_i18n as default_server_i18n +from ...services.settings_manager import SettingsManager +from ...services.recipes import ( + RecipeAnalysisService, + RecipeDownloadError, + RecipeNotFoundError, + RecipePersistenceService, + RecipeSharingService, + RecipeValidationError, +) + +Logger = logging.Logger +EnsureDependenciesCallable = Callable[[], Awaitable[None]] +RecipeScannerGetter = Callable[[], Any] +CivitaiClientGetter = Callable[[], Any] + + +@dataclass(frozen=True) +class RecipeHandlerSet: + """Group of handlers providing recipe route implementations.""" + + page_view: "RecipePageView" + listing: "RecipeListingHandler" + query: "RecipeQueryHandler" + management: "RecipeManagementHandler" + analysis: "RecipeAnalysisHandler" + sharing: "RecipeSharingHandler" + + def to_route_mapping(self) -> Mapping[str, Callable[[web.Request], Awaitable[web.StreamResponse]]]: + """Expose handler coroutines keyed by registrar handler names.""" + + return { + "render_page": self.page_view.render_page, + "list_recipes": self.listing.list_recipes, + "get_recipe": self.listing.get_recipe, + "analyze_uploaded_image": self.analysis.analyze_uploaded_image, + "analyze_local_image": self.analysis.analyze_local_image, + "save_recipe": self.management.save_recipe, + "delete_recipe": self.management.delete_recipe, + "get_top_tags": self.query.get_top_tags, + "get_base_models": self.query.get_base_models, + "share_recipe": self.sharing.share_recipe, + "download_shared_recipe": self.sharing.download_shared_recipe, + "get_recipe_syntax": self.query.get_recipe_syntax, + "update_recipe": self.management.update_recipe, + "reconnect_lora": self.management.reconnect_lora, + "find_duplicates": self.query.find_duplicates, + "bulk_delete": self.management.bulk_delete, + "save_recipe_from_widget": self.management.save_recipe_from_widget, + "get_recipes_for_lora": self.query.get_recipes_for_lora, + "scan_recipes": self.query.scan_recipes, + } + + +class RecipePageView: + """Render the recipe shell page.""" + + def __init__( + self, + *, + ensure_dependencies_ready: EnsureDependenciesCallable, + settings_service: SettingsManager, + server_i18n=default_server_i18n, + template_env, + template_name: str, + recipe_scanner_getter: RecipeScannerGetter, + logger: Logger, + ) -> None: + self._ensure_dependencies_ready = ensure_dependencies_ready + self._settings = settings_service + self._server_i18n = server_i18n + self._template_env = template_env + self._template_name = template_name + self._recipe_scanner_getter = recipe_scanner_getter + self._logger = logger + + async def render_page(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: # pragma: no cover - defensive guard + raise RuntimeError("Recipe scanner not available") + + user_language = self._settings.get("language", "en") + self._server_i18n.set_locale(user_language) + + try: + await recipe_scanner.get_cached_data(force_refresh=False) + rendered = self._template_env.get_template(self._template_name).render( + recipes=[], + is_initializing=False, + settings=self._settings, + request=request, + t=self._server_i18n.get_translation, + ) + except Exception as cache_error: # pragma: no cover - logging path + self._logger.error("Error loading recipe cache data: %s", cache_error) + rendered = self._template_env.get_template(self._template_name).render( + is_initializing=True, + settings=self._settings, + request=request, + t=self._server_i18n.get_translation, + ) + return web.Response(text=rendered, content_type="text/html") + except Exception as exc: # pragma: no cover - logging path + self._logger.error("Error handling recipes request: %s", exc, exc_info=True) + return web.Response(text="Error loading recipes page", status=500) + + +class RecipeListingHandler: + """Provide listing and detail APIs for recipes.""" + + def __init__( + self, + *, + ensure_dependencies_ready: EnsureDependenciesCallable, + recipe_scanner_getter: RecipeScannerGetter, + logger: Logger, + ) -> None: + self._ensure_dependencies_ready = ensure_dependencies_ready + self._recipe_scanner_getter = recipe_scanner_getter + self._logger = logger + + async def list_recipes(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + page = int(request.query.get("page", "1")) + page_size = int(request.query.get("page_size", "20")) + sort_by = request.query.get("sort_by", "date") + search = request.query.get("search") + + search_options = { + "title": request.query.get("search_title", "true").lower() == "true", + "tags": request.query.get("search_tags", "true").lower() == "true", + "lora_name": request.query.get("search_lora_name", "true").lower() == "true", + "lora_model": request.query.get("search_lora_model", "true").lower() == "true", + } + + filters: Dict[str, list[str]] = {} + base_models = request.query.get("base_models") + if base_models: + filters["base_model"] = base_models.split(",") + + tags = request.query.get("tags") + if tags: + filters["tags"] = tags.split(",") + + lora_hash = request.query.get("lora_hash") + + result = await recipe_scanner.get_paginated_data( + page=page, + page_size=page_size, + sort_by=sort_by, + search=search, + filters=filters, + search_options=search_options, + lora_hash=lora_hash, + ) + + for item in result.get("items", []): + file_path = item.get("file_path") + if file_path: + item["file_url"] = self.format_recipe_file_url(file_path) + else: + item.setdefault("file_url", "/loras_static/images/no-preview.png") + item.setdefault("loras", []) + item.setdefault("base_model", "") + + return web.json_response(result) + except Exception as exc: + self._logger.error("Error retrieving recipes: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + async def get_recipe(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + recipe_id = request.match_info["recipe_id"] + recipe = await recipe_scanner.get_recipe_by_id(recipe_id) + + if not recipe: + return web.json_response({"error": "Recipe not found"}, status=404) + return web.json_response(recipe) + except Exception as exc: + self._logger.error("Error retrieving recipe details: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + def format_recipe_file_url(self, file_path: str) -> str: + try: + recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, "/") + normalized_path = file_path.replace(os.sep, "/") + if normalized_path.startswith(recipes_dir): + relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, "/") + return f"/loras_static/root1/preview/{relative_path}" + + file_name = os.path.basename(file_path) + return f"/loras_static/root1/preview/recipes/{file_name}" + except Exception as exc: # pragma: no cover - logging path + self._logger.error("Error formatting recipe file URL: %s", exc, exc_info=True) + return "/loras_static/images/no-preview.png" + + +class RecipeQueryHandler: + """Provide read-only insights on recipe data.""" + + def __init__( + self, + *, + ensure_dependencies_ready: EnsureDependenciesCallable, + recipe_scanner_getter: RecipeScannerGetter, + format_recipe_file_url: Callable[[str], str], + logger: Logger, + ) -> None: + self._ensure_dependencies_ready = ensure_dependencies_ready + self._recipe_scanner_getter = recipe_scanner_getter + self._format_recipe_file_url = format_recipe_file_url + self._logger = logger + + async def get_top_tags(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + limit = int(request.query.get("limit", "20")) + cache = await recipe_scanner.get_cached_data() + + tag_counts: Dict[str, int] = {} + for recipe in getattr(cache, "raw_data", []): + for tag in recipe.get("tags", []) or []: + tag_counts[tag] = tag_counts.get(tag, 0) + 1 + + sorted_tags = [{"tag": tag, "count": count} for tag, count in tag_counts.items()] + sorted_tags.sort(key=lambda entry: entry["count"], reverse=True) + return web.json_response({"success": True, "tags": sorted_tags[:limit]}) + except Exception as exc: + self._logger.error("Error retrieving top tags: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_base_models(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + cache = await recipe_scanner.get_cached_data() + + base_model_counts: Dict[str, int] = {} + for recipe in getattr(cache, "raw_data", []): + base_model = recipe.get("base_model") + if base_model: + base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 + + sorted_models = [{"name": model, "count": count} for model, count in base_model_counts.items()] + sorted_models.sort(key=lambda entry: entry["count"], reverse=True) + return web.json_response({"success": True, "base_models": sorted_models}) + except Exception as exc: + self._logger.error("Error retrieving base models: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_recipes_for_lora(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + lora_hash = request.query.get("hash") + if not lora_hash: + return web.json_response({"success": False, "error": "Lora hash is required"}, status=400) + + matching_recipes = await recipe_scanner.get_recipes_for_lora(lora_hash) + return web.json_response({"success": True, "recipes": matching_recipes}) + except Exception as exc: + self._logger.error("Error getting recipes for Lora: %s", exc) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def scan_recipes(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + self._logger.info("Manually triggering recipe cache rebuild") + await recipe_scanner.get_cached_data(force_refresh=True) + return web.json_response({"success": True, "message": "Recipe cache refreshed successfully"}) + except Exception as exc: + self._logger.error("Error refreshing recipe cache: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def find_duplicates(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + duplicate_groups = await recipe_scanner.find_all_duplicate_recipes() + response_data = [] + + for fingerprint, recipe_ids in duplicate_groups.items(): + if len(recipe_ids) <= 1: + continue + + recipes = [] + for recipe_id in recipe_ids: + recipe = await recipe_scanner.get_recipe_by_id(recipe_id) + if recipe: + recipes.append( + { + "id": recipe.get("id"), + "title": recipe.get("title"), + "file_url": recipe.get("file_url") + or self._format_recipe_file_url(recipe.get("file_path", "")), + "modified": recipe.get("modified"), + "created_date": recipe.get("created_date"), + "lora_count": len(recipe.get("loras", [])), + } + ) + + if len(recipes) >= 2: + recipes.sort(key=lambda entry: entry.get("modified", 0), reverse=True) + response_data.append( + { + "fingerprint": fingerprint, + "count": len(recipes), + "recipes": recipes, + } + ) + + response_data.sort(key=lambda entry: entry["count"], reverse=True) + return web.json_response({"success": True, "duplicate_groups": response_data}) + except Exception as exc: + self._logger.error("Error finding duplicate recipes: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def get_recipe_syntax(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + recipe_id = request.match_info["recipe_id"] + try: + syntax_parts = await recipe_scanner.get_recipe_syntax_tokens(recipe_id) + except RecipeNotFoundError: + return web.json_response({"error": "Recipe not found"}, status=404) + + if not syntax_parts: + return web.json_response({"error": "No LoRAs found in this recipe"}, status=400) + + return web.json_response({"success": True, "syntax": " ".join(syntax_parts)}) + except Exception as exc: + self._logger.error("Error generating recipe syntax: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + +class RecipeManagementHandler: + """Handle create/update/delete style recipe operations.""" + + def __init__( + self, + *, + ensure_dependencies_ready: EnsureDependenciesCallable, + recipe_scanner_getter: RecipeScannerGetter, + logger: Logger, + persistence_service: RecipePersistenceService, + analysis_service: RecipeAnalysisService, + ) -> None: + self._ensure_dependencies_ready = ensure_dependencies_ready + self._recipe_scanner_getter = recipe_scanner_getter + self._logger = logger + self._persistence_service = persistence_service + self._analysis_service = analysis_service + + async def save_recipe(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + reader = await request.multipart() + payload = await self._parse_save_payload(reader) + + result = await self._persistence_service.save_recipe( + recipe_scanner=recipe_scanner, + image_bytes=payload["image_bytes"], + image_base64=payload["image_base64"], + name=payload["name"], + tags=payload["tags"], + metadata=payload["metadata"], + ) + return web.json_response(result.payload, status=result.status) + except RecipeValidationError as exc: + return web.json_response({"error": str(exc)}, status=400) + except Exception as exc: + self._logger.error("Error saving recipe: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + async def delete_recipe(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + recipe_id = request.match_info["recipe_id"] + result = await self._persistence_service.delete_recipe( + recipe_scanner=recipe_scanner, recipe_id=recipe_id + ) + return web.json_response(result.payload, status=result.status) + except RecipeNotFoundError as exc: + return web.json_response({"error": str(exc)}, status=404) + except Exception as exc: + self._logger.error("Error deleting recipe: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + async def update_recipe(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + recipe_id = request.match_info["recipe_id"] + data = await request.json() + result = await self._persistence_service.update_recipe( + recipe_scanner=recipe_scanner, recipe_id=recipe_id, updates=data + ) + return web.json_response(result.payload, status=result.status) + except RecipeValidationError as exc: + return web.json_response({"error": str(exc)}, status=400) + except RecipeNotFoundError as exc: + return web.json_response({"error": str(exc)}, status=404) + except Exception as exc: + self._logger.error("Error updating recipe: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + async def reconnect_lora(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + data = await request.json() + for field in ("recipe_id", "lora_index", "target_name"): + if field not in data: + raise RecipeValidationError(f"Missing required field: {field}") + + result = await self._persistence_service.reconnect_lora( + recipe_scanner=recipe_scanner, + recipe_id=data["recipe_id"], + lora_index=int(data["lora_index"]), + target_name=data["target_name"], + ) + return web.json_response(result.payload, status=result.status) + except RecipeValidationError as exc: + return web.json_response({"error": str(exc)}, status=400) + except RecipeNotFoundError as exc: + return web.json_response({"error": str(exc)}, status=404) + except Exception as exc: + self._logger.error("Error reconnecting LoRA: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + async def bulk_delete(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + data = await request.json() + recipe_ids = data.get("recipe_ids", []) + result = await self._persistence_service.bulk_delete( + recipe_scanner=recipe_scanner, recipe_ids=recipe_ids + ) + return web.json_response(result.payload, status=result.status) + except RecipeValidationError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=400) + except RecipeNotFoundError as exc: + return web.json_response({"success": False, "error": str(exc)}, status=404) + except Exception as exc: + self._logger.error("Error performing bulk delete: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def save_recipe_from_widget(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + analysis = await self._analysis_service.analyze_widget_metadata( + recipe_scanner=recipe_scanner + ) + metadata = analysis.payload.get("metadata") + image_bytes = analysis.payload.get("image_bytes") + if not metadata or image_bytes is None: + raise RecipeValidationError("Unable to extract metadata from widget") + + result = await self._persistence_service.save_recipe_from_widget( + recipe_scanner=recipe_scanner, + metadata=metadata, + image_bytes=image_bytes, + ) + return web.json_response(result.payload, status=result.status) + except RecipeValidationError as exc: + return web.json_response({"error": str(exc)}, status=400) + except Exception as exc: + self._logger.error("Error saving recipe from widget: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + async def _parse_save_payload(self, reader) -> dict[str, Any]: + image_bytes: Optional[bytes] = None + image_base64: Optional[str] = None + name: Optional[str] = None + tags: list[str] = [] + metadata: Optional[Dict[str, Any]] = None + + while True: + field = await reader.next() + if field is None: + break + if field.name == "image": + image_chunks = bytearray() + while True: + chunk = await field.read_chunk() + if not chunk: + break + image_chunks.extend(chunk) + image_bytes = bytes(image_chunks) + elif field.name == "image_base64": + image_base64 = await field.text() + elif field.name == "name": + name = await field.text() + elif field.name == "tags": + tags_text = await field.text() + try: + parsed_tags = json.loads(tags_text) + tags = parsed_tags if isinstance(parsed_tags, list) else [] + except Exception: + tags = [] + elif field.name == "metadata": + metadata_text = await field.text() + try: + metadata = json.loads(metadata_text) + except Exception: + metadata = {} + + return { + "image_bytes": image_bytes, + "image_base64": image_base64, + "name": name, + "tags": tags, + "metadata": metadata, + } + + +class RecipeAnalysisHandler: + """Analyze images to extract recipe metadata.""" + + def __init__( + self, + *, + ensure_dependencies_ready: EnsureDependenciesCallable, + recipe_scanner_getter: RecipeScannerGetter, + civitai_client_getter: CivitaiClientGetter, + logger: Logger, + analysis_service: RecipeAnalysisService, + ) -> None: + self._ensure_dependencies_ready = ensure_dependencies_ready + self._recipe_scanner_getter = recipe_scanner_getter + self._civitai_client_getter = civitai_client_getter + self._logger = logger + self._analysis_service = analysis_service + + async def analyze_uploaded_image(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + civitai_client = self._civitai_client_getter() + if recipe_scanner is None or civitai_client is None: + raise RuntimeError("Required services unavailable") + + content_type = request.headers.get("Content-Type", "") + if "multipart/form-data" in content_type: + reader = await request.multipart() + field = await reader.next() + if field is None or field.name != "image": + raise RecipeValidationError("No image field found") + image_chunks = bytearray() + while True: + chunk = await field.read_chunk() + if not chunk: + break + image_chunks.extend(chunk) + result = await self._analysis_service.analyze_uploaded_image( + image_bytes=bytes(image_chunks), + recipe_scanner=recipe_scanner, + ) + return web.json_response(result.payload, status=result.status) + + if "application/json" in content_type: + data = await request.json() + result = await self._analysis_service.analyze_remote_image( + url=data.get("url"), + recipe_scanner=recipe_scanner, + civitai_client=civitai_client, + ) + return web.json_response(result.payload, status=result.status) + + raise RecipeValidationError("Unsupported content type") + except RecipeValidationError as exc: + return web.json_response({"error": str(exc), "loras": []}, status=400) + except RecipeDownloadError as exc: + return web.json_response({"error": str(exc), "loras": []}, status=400) + except RecipeNotFoundError as exc: + return web.json_response({"error": str(exc), "loras": []}, status=404) + except Exception as exc: + self._logger.error("Error analyzing recipe image: %s", exc, exc_info=True) + return web.json_response({"error": str(exc), "loras": []}, status=500) + + async def analyze_local_image(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + data = await request.json() + result = await self._analysis_service.analyze_local_image( + file_path=data.get("path"), + recipe_scanner=recipe_scanner, + ) + return web.json_response(result.payload, status=result.status) + except RecipeValidationError as exc: + return web.json_response({"error": str(exc), "loras": []}, status=400) + except RecipeNotFoundError as exc: + return web.json_response({"error": str(exc), "loras": []}, status=404) + except Exception as exc: + self._logger.error("Error analyzing local image: %s", exc, exc_info=True) + return web.json_response({"error": str(exc), "loras": []}, status=500) + + +class RecipeSharingHandler: + """Serve endpoints related to recipe sharing.""" + + def __init__( + self, + *, + ensure_dependencies_ready: EnsureDependenciesCallable, + recipe_scanner_getter: RecipeScannerGetter, + logger: Logger, + sharing_service: RecipeSharingService, + ) -> None: + self._ensure_dependencies_ready = ensure_dependencies_ready + self._recipe_scanner_getter = recipe_scanner_getter + self._logger = logger + self._sharing_service = sharing_service + + async def share_recipe(self, request: web.Request) -> web.Response: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + recipe_id = request.match_info["recipe_id"] + result = await self._sharing_service.share_recipe( + recipe_scanner=recipe_scanner, recipe_id=recipe_id + ) + return web.json_response(result.payload, status=result.status) + except RecipeNotFoundError as exc: + return web.json_response({"error": str(exc)}, status=404) + except Exception as exc: + self._logger.error("Error sharing recipe: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) + + async def download_shared_recipe(self, request: web.Request) -> web.StreamResponse: + try: + await self._ensure_dependencies_ready() + recipe_scanner = self._recipe_scanner_getter() + if recipe_scanner is None: + raise RuntimeError("Recipe scanner unavailable") + + recipe_id = request.match_info["recipe_id"] + download_info = await self._sharing_service.prepare_download( + recipe_scanner=recipe_scanner, recipe_id=recipe_id + ) + return web.FileResponse( + download_info.file_path, + headers={ + "Content-Disposition": f'attachment; filename="{download_info.download_filename}"' + }, + ) + except RecipeNotFoundError as exc: + return web.json_response({"error": str(exc)}, status=404) + except Exception as exc: + self._logger.error("Error downloading shared recipe: %s", exc, exc_info=True) + return web.json_response({"error": str(exc)}, status=500) diff --git a/py/routes/recipe_route_registrar.py b/py/routes/recipe_route_registrar.py new file mode 100644 index 00000000..471edf19 --- /dev/null +++ b/py/routes/recipe_route_registrar.py @@ -0,0 +1,64 @@ +"""Route registrar for recipe endpoints.""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Mapping + +from aiohttp import web + + +@dataclass(frozen=True) +class RouteDefinition: + """Declarative definition for a recipe HTTP route.""" + + method: str + path: str + handler_name: str + + +ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( + RouteDefinition("GET", "/loras/recipes", "render_page"), + RouteDefinition("GET", "/api/lm/recipes", "list_recipes"), + RouteDefinition("GET", "/api/lm/recipe/{recipe_id}", "get_recipe"), + RouteDefinition("POST", "/api/lm/recipes/analyze-image", "analyze_uploaded_image"), + RouteDefinition("POST", "/api/lm/recipes/analyze-local-image", "analyze_local_image"), + RouteDefinition("POST", "/api/lm/recipes/save", "save_recipe"), + RouteDefinition("DELETE", "/api/lm/recipe/{recipe_id}", "delete_recipe"), + RouteDefinition("GET", "/api/lm/recipes/top-tags", "get_top_tags"), + RouteDefinition("GET", "/api/lm/recipes/base-models", "get_base_models"), + RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share", "share_recipe"), + RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/share/download", "download_shared_recipe"), + RouteDefinition("GET", "/api/lm/recipe/{recipe_id}/syntax", "get_recipe_syntax"), + RouteDefinition("PUT", "/api/lm/recipe/{recipe_id}/update", "update_recipe"), + RouteDefinition("POST", "/api/lm/recipe/lora/reconnect", "reconnect_lora"), + RouteDefinition("GET", "/api/lm/recipes/find-duplicates", "find_duplicates"), + RouteDefinition("POST", "/api/lm/recipes/bulk-delete", "bulk_delete"), + RouteDefinition("POST", "/api/lm/recipes/save-from-widget", "save_recipe_from_widget"), + RouteDefinition("GET", "/api/lm/recipes/for-lora", "get_recipes_for_lora"), + RouteDefinition("GET", "/api/lm/recipes/scan", "scan_recipes"), +) + + +class RecipeRouteRegistrar: + """Bind declarative recipe definitions to an aiohttp router.""" + + _METHOD_MAP = { + "GET": "add_get", + "POST": "add_post", + "PUT": "add_put", + "DELETE": "add_delete", + } + + def __init__(self, app: web.Application) -> None: + self._app = app + + def register_routes(self, handler_lookup: Mapping[str, Callable[[web.Request], object]]) -> None: + for definition in ROUTE_DEFINITIONS: + handler = handler_lookup[definition.handler_name] + self._bind_route(definition.method, definition.path, handler) + + def _bind_route(self, method: str, path: str, handler: Callable) -> None: + add_method_name = self._METHOD_MAP[method.upper()] + add_method = getattr(self._app.router, add_method_name) + add_method(path, handler) + diff --git a/py/routes/recipe_routes.py b/py/routes/recipe_routes.py index 21214d99..2c233d01 100644 --- a/py/routes/recipe_routes.py +++ b/py/routes/recipe_routes.py @@ -1,1652 +1,21 @@ -import os -import time -import base64 -import jinja2 -import numpy as np -from PIL import Image -import io -import logging +"""Concrete recipe route configuration.""" + from aiohttp import web -from typing import Dict -import tempfile -import json -import asyncio -import sys -from ..utils.exif_utils import ExifUtils -from ..recipes import RecipeParserFactory -from ..utils.constants import CARD_PREVIEW_WIDTH -from ..services.settings_manager import settings -from ..services.server_i18n import server_i18n -from ..config import config +from .base_recipe_routes import BaseRecipeRoutes +from .recipe_route_registrar import RecipeRouteRegistrar -# Check if running in standalone mode -standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" -from ..services.service_registry import ServiceRegistry # Add ServiceRegistry import -from ..services.downloader import get_downloader +class RecipeRoutes(BaseRecipeRoutes): + """API route handlers for Recipe management.""" -# Only import MetadataRegistry in non-standalone mode -if not standalone_mode: - # Import metadata_collector functions and classes conditionally - from ..metadata_collector import get_metadata # Add MetadataCollector import - from ..metadata_collector.metadata_processor import MetadataProcessor # Add MetadataProcessor import - from ..metadata_collector.metadata_registry import MetadataRegistry - -logger = logging.getLogger(__name__) - -class RecipeRoutes: - """API route handlers for Recipe management""" - - def __init__(self): - # Initialize service references as None, will be set during async init - self.recipe_scanner = None - self.civitai_client = None - self.template_env = jinja2.Environment( - loader=jinja2.FileSystemLoader(config.templates_path), - autoescape=True - ) - - # Pre-warm the cache - self._init_cache_task = None - - async def init_services(self): - """Initialize services from ServiceRegistry""" - self.recipe_scanner = await ServiceRegistry.get_recipe_scanner() - self.civitai_client = await ServiceRegistry.get_civitai_client() + template_name = "recipes.html" @classmethod def setup_routes(cls, app: web.Application): - """Register API routes""" + """Register API routes using the declarative registrar.""" + routes = cls() - app.router.add_get('/loras/recipes', routes.handle_recipes_page) - - app.router.add_get('/api/lm/recipes', routes.get_recipes) - app.router.add_get('/api/lm/recipe/{recipe_id}', routes.get_recipe_detail) - app.router.add_post('/api/lm/recipes/analyze-image', routes.analyze_recipe_image) - app.router.add_post('/api/lm/recipes/analyze-local-image', routes.analyze_local_image) - app.router.add_post('/api/lm/recipes/save', routes.save_recipe) - app.router.add_delete('/api/lm/recipe/{recipe_id}', routes.delete_recipe) - - # Add new filter-related endpoints - app.router.add_get('/api/lm/recipes/top-tags', routes.get_top_tags) - app.router.add_get('/api/lm/recipes/base-models', routes.get_base_models) - - # Add new sharing endpoints - app.router.add_get('/api/lm/recipe/{recipe_id}/share', routes.share_recipe) - app.router.add_get('/api/lm/recipe/{recipe_id}/share/download', routes.download_shared_recipe) - - # Add new endpoint for getting recipe syntax - app.router.add_get('/api/lm/recipe/{recipe_id}/syntax', routes.get_recipe_syntax) - - # Add new endpoint for updating recipe metadata (name, tags and source_path) - app.router.add_put('/api/lm/recipe/{recipe_id}/update', routes.update_recipe) - - # Add new endpoint for reconnecting deleted LoRAs - app.router.add_post('/api/lm/recipe/lora/reconnect', routes.reconnect_lora) - - # Add new endpoint for finding duplicate recipes - app.router.add_get('/api/lm/recipes/find-duplicates', routes.find_duplicates) - - # Add new endpoint for bulk deletion of recipes - app.router.add_post('/api/lm/recipes/bulk-delete', routes.bulk_delete) - - # Start cache initialization - app.on_startup.append(routes._init_cache) - - app.router.add_post('/api/lm/recipes/save-from-widget', routes.save_recipe_from_widget) - - # Add route to get recipes for a specific Lora - app.router.add_get('/api/lm/recipes/for-lora', routes.get_recipes_for_lora) - - # Add new endpoint for scanning and rebuilding the recipe cache - app.router.add_get('/api/lm/recipes/scan', routes.scan_recipes) - - async def _init_cache(self, app): - """Initialize cache on startup""" - try: - # Initialize services first - await self.init_services() - - # Now that services are initialized, get the lora scanner - lora_scanner = self.recipe_scanner._lora_scanner - - # Get lora cache to ensure it's initialized - lora_cache = await lora_scanner.get_cached_data() - - # Verify hash index is built - if hasattr(lora_scanner, '_hash_index'): - hash_index_size = len(lora_scanner._hash_index._hash_to_path) if hasattr(lora_scanner._hash_index, '_hash_to_path') else 0 - - # Now that lora scanner is initialized, initialize recipe cache - await self.recipe_scanner.get_cached_data(force_refresh=True) - except Exception as e: - logger.error(f"Error pre-warming recipe cache: {e}", exc_info=True) - - async def handle_recipes_page(self, request: web.Request) -> web.Response: - """Handle GET /loras/recipes request""" - try: - # Ensure services are initialized - await self.init_services() - - # 获取用户语言设置 - user_language = settings.get('language', 'en') - - # 设置服务端i18n语言 - server_i18n.set_locale(user_language) - - # 为模板环境添加i18n过滤器 - if not hasattr(self.template_env, '_i18n_filter_added'): - self.template_env.filters['t'] = server_i18n.create_template_filter() - self.template_env._i18n_filter_added = True - - # Skip initialization check and directly try to get cached data - try: - # Recipe scanner will initialize cache if needed - await self.recipe_scanner.get_cached_data(force_refresh=False) - template = self.template_env.get_template('recipes.html') - rendered = template.render( - recipes=[], # Frontend will load recipes via API - is_initializing=False, - settings=settings, - request=request, - # 添加服务端翻译函数 - t=server_i18n.get_translation, - ) - except Exception as cache_error: - logger.error(f"Error loading recipe cache data: {cache_error}") - # Still keep error handling - show initializing page on error - template = self.template_env.get_template('recipes.html') - rendered = template.render( - is_initializing=True, - settings=settings, - request=request, - # 添加服务端翻译函数 - t=server_i18n.get_translation, - ) - logger.info("Recipe cache error, returning initialization page") - - return web.Response( - text=rendered, - content_type='text/html' - ) - - except Exception as e: - logger.error(f"Error handling recipes request: {e}", exc_info=True) - return web.Response( - text="Error loading recipes page", - status=500 - ) - - async def get_recipes(self, request: web.Request) -> web.Response: - """API endpoint for getting paginated recipes""" - try: - # Ensure services are initialized - await self.init_services() - - # Get query parameters with defaults - page = int(request.query.get('page', '1')) - page_size = int(request.query.get('page_size', '20')) - sort_by = request.query.get('sort_by', 'date') - search = request.query.get('search', None) - - # Get search options (renamed for better clarity) - search_title = request.query.get('search_title', 'true').lower() == 'true' - search_tags = request.query.get('search_tags', 'true').lower() == 'true' - search_lora_name = request.query.get('search_lora_name', 'true').lower() == 'true' - search_lora_model = request.query.get('search_lora_model', 'true').lower() == 'true' - - # Get filter parameters - base_models = request.query.get('base_models', None) - tags = request.query.get('tags', None) - - # New parameter: get LoRA hash filter - lora_hash = request.query.get('lora_hash', None) - - # Parse filter parameters - filters = {} - if base_models: - filters['base_model'] = base_models.split(',') - if tags: - filters['tags'] = tags.split(',') - - # Add search options to filters - search_options = { - 'title': search_title, - 'tags': search_tags, - 'lora_name': search_lora_name, - 'lora_model': search_lora_model - } - - # Get paginated data with the new lora_hash parameter - result = await self.recipe_scanner.get_paginated_data( - page=page, - page_size=page_size, - sort_by=sort_by, - search=search, - filters=filters, - search_options=search_options, - lora_hash=lora_hash - ) - - # Format the response data with static URLs for file paths - for item in result['items']: - # Always ensure file_url is set - if 'file_path' in item: - item['file_url'] = self._format_recipe_file_url(item['file_path']) - else: - item['file_url'] = '/loras_static/images/no-preview.png' - - # 确保 loras 数组存在 - if 'loras' not in item: - item['loras'] = [] - - # 确保有 base_model 字段 - if 'base_model' not in item: - item['base_model'] = "" - - return web.json_response(result) - except Exception as e: - logger.error(f"Error retrieving recipes: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def get_recipe_detail(self, request: web.Request) -> web.Response: - """Get detailed information about a specific recipe""" - try: - # Ensure services are initialized - await self.init_services() - - recipe_id = request.match_info['recipe_id'] - - # Use the new get_recipe_by_id method from recipe_scanner - recipe = await self.recipe_scanner.get_recipe_by_id(recipe_id) - - if not recipe: - return web.json_response({"error": "Recipe not found"}, status=404) - - return web.json_response(recipe) - except Exception as e: - logger.error(f"Error retrieving recipe details: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - def _format_recipe_file_url(self, file_path: str) -> str: - """Format file path for recipe image as a URL""" - try: - # Return the file URL directly for the first lora root's preview - recipes_dir = os.path.join(config.loras_roots[0], "recipes").replace(os.sep, '/') - if file_path.replace(os.sep, '/').startswith(recipes_dir): - relative_path = os.path.relpath(file_path, config.loras_roots[0]).replace(os.sep, '/') - return f"/loras_static/root1/preview/{relative_path}" - - # If not in recipes dir, try to create a valid URL from the file path - file_name = os.path.basename(file_path) - return f"/loras_static/root1/preview/recipes/{file_name}" - except Exception as e: - logger.error(f"Error formatting recipe file URL: {e}", exc_info=True) - return '/loras_static/images/no-preview.png' # Return default image on error - - def _format_recipe_data(self, recipe: Dict) -> Dict: - """Format recipe data for API response""" - formatted = {**recipe} # Copy all fields - - # Format file paths to URLs - if 'file_path' in formatted: - formatted['file_url'] = self._format_recipe_file_url(formatted['file_path']) - - # Format dates for display - for date_field in ['created_date', 'modified']: - if date_field in formatted: - formatted[f"{date_field}_formatted"] = self._format_timestamp(formatted[date_field]) - - return formatted - - def _format_timestamp(self, timestamp: float) -> str: - """Format timestamp for display""" - from datetime import datetime - return datetime.fromtimestamp(timestamp).strftime('%Y-%m-%d %H:%M:%S') - - async def analyze_recipe_image(self, request: web.Request) -> web.Response: - """Analyze an uploaded image or URL for recipe metadata""" - temp_path = None - try: - # Ensure services are initialized - await self.init_services() - - # Check if request contains multipart data (image) or JSON data (url) - content_type = request.headers.get('Content-Type', '') - - is_url_mode = False - metadata = None # Initialize metadata variable - - if 'multipart/form-data' in content_type: - # Handle image upload - reader = await request.multipart() - field = await reader.next() - - if field.name != 'image': - return web.json_response({ - "error": "No image field found", - "loras": [] - }, status=400) - - # Create a temporary file to store the uploaded image - with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file: - while True: - chunk = await field.read_chunk() - if not chunk: - break - temp_file.write(chunk) - temp_path = temp_file.name - - elif 'application/json' in content_type: - # Handle URL input - data = await request.json() - url = data.get('url') - is_url_mode = True - - if not url: - return web.json_response({ - "error": "No URL provided", - "loras": [] - }, status=400) - - # Check if this is a Civitai image URL - import re - civitai_image_match = re.match(r'https://civitai\.com/images/(\d+)', url) - - if civitai_image_match: - # Extract image ID and fetch image info using get_image_info - image_id = civitai_image_match.group(1) - image_info = await self.civitai_client.get_image_info(image_id) - - if not image_info: - return web.json_response({ - "error": "Failed to fetch image information from Civitai", - "loras": [] - }, status=400) - - # Get image URL from response - image_url = image_info.get('url') - if not image_url: - return web.json_response({ - "error": "No image URL found in Civitai response", - "loras": [] - }, status=400) - - # Download image using unified downloader - downloader = await get_downloader() - # Create a temporary file to save the downloaded image - with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file: - temp_path = temp_file.name - - success, result = await downloader.download_file( - image_url, - temp_path, - use_auth=False # Image downloads typically don't need auth - ) - - if not success: - return web.json_response({ - "error": f"Failed to download image from URL: {result}", - "loras": [] - }, status=400) - - # Use meta field from image_info as metadata - if 'meta' in image_info: - metadata = image_info['meta'] - - # If metadata wasn't obtained from Civitai API, extract it from the image - if metadata is None: - # Extract metadata from the image using ExifUtils - metadata = ExifUtils.extract_image_metadata(temp_path) - - # If no metadata found, return a more specific error - if not metadata: - result = { - "error": "No metadata found in this image", - "loras": [] # Return empty loras array to prevent client-side errors - } - - # For URL mode, include the image data as base64 - if is_url_mode and temp_path: - with open(temp_path, "rb") as image_file: - result["image_base64"] = base64.b64encode(image_file.read()).decode('utf-8') - - return web.json_response(result, status=200) - - # Use the parser factory to get the appropriate parser - parser = RecipeParserFactory.create_parser(metadata) - - if parser is None: - result = { - "error": "No parser found for this image", - "loras": [] # Return empty loras array to prevent client-side errors - } - - # For URL mode, include the image data as base64 - if is_url_mode and temp_path: - with open(temp_path, "rb") as image_file: - result["image_base64"] = base64.b64encode(image_file.read()).decode('utf-8') - - return web.json_response(result, status=200) - - # Parse the metadata - result = await parser.parse_metadata( - metadata, - recipe_scanner=self.recipe_scanner - ) - - # For URL mode, include the image data as base64 - if is_url_mode and temp_path: - with open(temp_path, "rb") as image_file: - result["image_base64"] = base64.b64encode(image_file.read()).decode('utf-8') - - # Check for errors - if "error" in result and not result.get("loras"): - return web.json_response(result, status=200) - - # Calculate fingerprint from parsed loras - from ..utils.utils import calculate_recipe_fingerprint - fingerprint = calculate_recipe_fingerprint(result.get("loras", [])) - - # Add fingerprint to result - result["fingerprint"] = fingerprint - - # Find matching recipes with the same fingerprint - matching_recipes = [] - if fingerprint: - matching_recipes = await self.recipe_scanner.find_recipes_by_fingerprint(fingerprint) - - # Add matching recipes to result - result["matching_recipes"] = matching_recipes - - return web.json_response(result) - - except Exception as e: - logger.error(f"Error analyzing recipe image: {e}", exc_info=True) - return web.json_response({ - "error": str(e), - "loras": [] # Return empty loras array to prevent client-side errors - }, status=500) - finally: - # Clean up the temporary file in the finally block - if temp_path and os.path.exists(temp_path): - try: - os.unlink(temp_path) - except Exception as e: - logger.error(f"Error deleting temporary file: {e}") - - async def analyze_local_image(self, request: web.Request) -> web.Response: - """Analyze a local image file for recipe metadata""" - try: - # Ensure services are initialized - await self.init_services() - - # Get JSON data from request - data = await request.json() - file_path = data.get('path') - - if not file_path: - return web.json_response({ - 'error': 'No file path provided', - 'loras': [] - }, status=400) - - # Normalize file path for cross-platform compatibility - file_path = os.path.normpath(file_path.strip('"').strip("'")) - - # Validate that the file exists - if not os.path.isfile(file_path): - return web.json_response({ - 'error': 'File not found', - 'loras': [] - }, status=404) - - # Extract metadata from the image using ExifUtils - metadata = ExifUtils.extract_image_metadata(file_path) - - # If no metadata found, return error - if not metadata: - # Get base64 image data - with open(file_path, "rb") as image_file: - image_base64 = base64.b64encode(image_file.read()).decode('utf-8') - - return web.json_response({ - "error": "No metadata found in this image", - "loras": [], # Return empty loras array to prevent client-side errors - "image_base64": image_base64 - }, status=200) - - # Use the parser factory to get the appropriate parser - parser = RecipeParserFactory.create_parser(metadata) - - if parser is None: - # Get base64 image data - with open(file_path, "rb") as image_file: - image_base64 = base64.b64encode(image_file.read()).decode('utf-8') - - return web.json_response({ - "error": "No parser found for this image", - "loras": [], # Return empty loras array to prevent client-side errors - "image_base64": image_base64 - }, status=200) - - # Parse the metadata - result = await parser.parse_metadata( - metadata, - recipe_scanner=self.recipe_scanner - ) - - # Add base64 image data to result - with open(file_path, "rb") as image_file: - result["image_base64"] = base64.b64encode(image_file.read()).decode('utf-8') - - # Check for errors - if "error" in result and not result.get("loras"): - return web.json_response(result, status=200) - - # Calculate fingerprint from parsed loras - from ..utils.utils import calculate_recipe_fingerprint - fingerprint = calculate_recipe_fingerprint(result.get("loras", [])) - - # Add fingerprint to result - result["fingerprint"] = fingerprint - - # Find matching recipes with the same fingerprint - matching_recipes = [] - if fingerprint: - matching_recipes = await self.recipe_scanner.find_recipes_by_fingerprint(fingerprint) - - # Add matching recipes to result - result["matching_recipes"] = matching_recipes - - return web.json_response(result) - - except Exception as e: - logger.error(f"Error analyzing local image: {e}", exc_info=True) - return web.json_response({ - 'error': str(e), - 'loras': [] # Return empty loras array to prevent client-side errors - }, status=500) - - async def save_recipe(self, request: web.Request) -> web.Response: - """Save a recipe to the recipes folder""" - try: - # Ensure services are initialized - await self.init_services() - - reader = await request.multipart() - - # Process form data - image = None - image_base64 = None - image_url = None - name = None - tags = [] - metadata = None - - while True: - field = await reader.next() - if field is None: - break - - if field.name == 'image': - # Read image data - image_data = b'' - while True: - chunk = await field.read_chunk() - if not chunk: - break - image_data += chunk - image = image_data - - elif field.name == 'image_base64': - # Get base64 image data - image_base64 = await field.text() - - elif field.name == 'image_url': - # Get image URL - image_url = await field.text() - - elif field.name == 'name': - name = await field.text() - - elif field.name == 'tags': - tags_text = await field.text() - try: - tags = json.loads(tags_text) - except: - tags = [] - - elif field.name == 'metadata': - metadata_text = await field.text() - try: - metadata = json.loads(metadata_text) - except: - metadata = {} - - missing_fields = [] - if not name: - missing_fields.append("name") - if not metadata: - missing_fields.append("metadata") - if missing_fields: - return web.json_response({"error": f"Missing required fields: {', '.join(missing_fields)}"}, status=400) - - # Handle different image sources - if not image: - if image_base64: - # Convert base64 to binary - try: - # Remove potential data URL prefix - if ',' in image_base64: - image_base64 = image_base64.split(',', 1)[1] - image = base64.b64decode(image_base64) - except Exception as e: - return web.json_response({"error": f"Invalid base64 image data: {str(e)}"}, status=400) - else: - return web.json_response({"error": "No image data provided"}, status=400) - - # Create recipes directory if it doesn't exist - recipes_dir = self.recipe_scanner.recipes_dir - os.makedirs(recipes_dir, exist_ok=True) - - # Generate UUID for the recipe - import uuid - recipe_id = str(uuid.uuid4()) - - # Optimize the image (resize and convert to WebP) - optimized_image, extension = ExifUtils.optimize_image( - image_data=image, - target_width=CARD_PREVIEW_WIDTH, - format='webp', - quality=85, - preserve_metadata=True - ) - - # Save the optimized image - image_filename = f"{recipe_id}{extension}" - image_path = os.path.join(recipes_dir, image_filename) - with open(image_path, 'wb') as f: - f.write(optimized_image) - - # Create the recipe data structure - current_time = time.time() - - # Format loras data according to the recipe.json format - loras_data = [] - for lora in metadata.get("loras", []): - # Modified: Always include deleted LoRAs in the recipe metadata - # Even if they're marked to be excluded, we still keep their identifying information - # The exclude flag will only be used to determine if they should be included in recipe syntax - - # Convert frontend lora format to recipe format - lora_entry = { - "file_name": lora.get("file_name", "") or os.path.splitext(os.path.basename(lora.get("localPath", "")))[0] if lora.get("localPath") else "", - "hash": lora.get("hash", "").lower() if lora.get("hash") else "", - "strength": float(lora.get("weight", 1.0)), - "modelVersionId": lora.get("id", 0), - "modelName": lora.get("name", ""), - "modelVersionName": lora.get("version", ""), - "isDeleted": lora.get("isDeleted", False), # Preserve deletion status in saved recipe - "exclude": lora.get("exclude", False) # Add exclude flag to the recipe - } - loras_data.append(lora_entry) - - # Format gen_params according to the recipe.json format - gen_params = metadata.get("gen_params", {}) - if not gen_params and "raw_metadata" in metadata: - # Extract from raw metadata if available - raw_metadata = metadata.get("raw_metadata", {}) - gen_params = { - "prompt": raw_metadata.get("prompt", ""), - "negative_prompt": raw_metadata.get("negative_prompt", ""), - "checkpoint": raw_metadata.get("checkpoint", {}), - "steps": raw_metadata.get("steps", ""), - "sampler": raw_metadata.get("sampler", ""), - "cfg_scale": raw_metadata.get("cfg_scale", ""), - "seed": raw_metadata.get("seed", ""), - "size": raw_metadata.get("size", ""), - "clip_skip": raw_metadata.get("clip_skip", "") - } - - # Calculate recipe fingerprint - from ..utils.utils import calculate_recipe_fingerprint - fingerprint = calculate_recipe_fingerprint(loras_data) - - # Create the recipe data structure - recipe_data = { - "id": recipe_id, - "file_path": image_path, - "title": name, - "modified": current_time, - "created_date": current_time, - "base_model": metadata.get("base_model", ""), - "loras": loras_data, - "gen_params": gen_params, - "fingerprint": fingerprint - } - - # Add tags if provided - if tags: - recipe_data["tags"] = tags - - # Add source_path if provided in metadata - if metadata.get("source_path"): - recipe_data["source_path"] = metadata.get("source_path") - - # Save the recipe JSON - json_filename = f"{recipe_id}.recipe.json" - json_path = os.path.join(recipes_dir, json_filename) - with open(json_path, 'w', encoding='utf-8') as f: - json.dump(recipe_data, f, indent=4, ensure_ascii=False) - - # Add recipe metadata to the image - ExifUtils.append_recipe_metadata(image_path, recipe_data) - - # Check for duplicates - matching_recipes = [] - if fingerprint: - matching_recipes = await self.recipe_scanner.find_recipes_by_fingerprint(fingerprint) - # Remove current recipe from matches - if recipe_id in matching_recipes: - matching_recipes.remove(recipe_id) - - # Simplified cache update approach - # Instead of trying to update the cache directly, just set it to None - # to force a refresh on the next get_cached_data call - if self.recipe_scanner._cache is not None: - # Add the recipe to the raw data if the cache exists - # This is a simple direct update without locks or timeouts - self.recipe_scanner._cache.raw_data.append(recipe_data) - # Schedule a background task to resort the cache - asyncio.create_task(self.recipe_scanner._cache.resort()) - logger.info(f"Added recipe {recipe_id} to cache") - - return web.json_response({ - 'success': True, - 'recipe_id': recipe_id, - 'image_path': image_path, - 'json_path': json_path, - 'matching_recipes': matching_recipes - }) - - except Exception as e: - logger.error(f"Error saving recipe: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def delete_recipe(self, request: web.Request) -> web.Response: - """Delete a recipe by ID""" - try: - # Ensure services are initialized - await self.init_services() - - recipe_id = request.match_info['recipe_id'] - - # Get recipes directory - recipes_dir = self.recipe_scanner.recipes_dir - if not recipes_dir or not os.path.exists(recipes_dir): - return web.json_response({"error": "Recipes directory not found"}, status=404) - - # Find recipe JSON file - recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json") - if not os.path.exists(recipe_json_path): - return web.json_response({"error": "Recipe not found"}, status=404) - - # Load recipe data to get image path - with open(recipe_json_path, 'r', encoding='utf-8') as f: - recipe_data = json.load(f) - - # Get image path - image_path = recipe_data.get('file_path') - - # Delete recipe JSON file - os.remove(recipe_json_path) - logger.info(f"Deleted recipe JSON file: {recipe_json_path}") - - # Delete recipe image if it exists - if image_path and os.path.exists(image_path): - os.remove(image_path) - logger.info(f"Deleted recipe image: {image_path}") - - # Simplified cache update approach - if self.recipe_scanner._cache is not None: - # Remove the recipe from raw_data if it exists - self.recipe_scanner._cache.raw_data = [ - r for r in self.recipe_scanner._cache.raw_data - if str(r.get('id', '')) != recipe_id - ] - # Schedule a background task to resort the cache - asyncio.create_task(self.recipe_scanner._cache.resort()) - logger.info(f"Removed recipe {recipe_id} from cache") - - return web.json_response({"success": True, "message": "Recipe deleted successfully"}) - except Exception as e: - logger.error(f"Error deleting recipe: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def get_top_tags(self, request: web.Request) -> web.Response: - """Get top tags used in recipes""" - try: - # Ensure services are initialized - await self.init_services() - - # Get limit parameter with default - limit = int(request.query.get('limit', '20')) - - # Get all recipes from cache - cache = await self.recipe_scanner.get_cached_data() - - # Count tag occurrences - tag_counts = {} - for recipe in cache.raw_data: - if 'tags' in recipe and recipe['tags']: - for tag in recipe['tags']: - tag_counts[tag] = tag_counts.get(tag, 0) + 1 - - # Sort tags by count and limit results - sorted_tags = [{'tag': tag, 'count': count} for tag, count in tag_counts.items()] - sorted_tags.sort(key=lambda x: x['count'], reverse=True) - top_tags = sorted_tags[:limit] - - return web.json_response({ - 'success': True, - 'tags': top_tags - }) - except Exception as e: - logger.error(f"Error retrieving top tags: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def get_base_models(self, request: web.Request) -> web.Response: - """Get base models used in recipes""" - try: - # Ensure services are initialized - await self.init_services() - - # Get all recipes from cache - cache = await self.recipe_scanner.get_cached_data() - - # Count base model occurrences - base_model_counts = {} - for recipe in cache.raw_data: - if 'base_model' in recipe and recipe['base_model']: - base_model = recipe['base_model'] - base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 - - # Sort base models by count - sorted_models = [{'name': model, 'count': count} for model, count in base_model_counts.items()] - sorted_models.sort(key=lambda x: x['count'], reverse=True) - - return web.json_response({ - 'success': True, - 'base_models': sorted_models - }) - except Exception as e: - logger.error(f"Error retrieving base models: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e)} - , status=500) - - async def share_recipe(self, request: web.Request) -> web.Response: - """Process a recipe image for sharing by adding metadata to EXIF""" - try: - # Ensure services are initialized - await self.init_services() - - recipe_id = request.match_info['recipe_id'] - - # Get all recipes from cache - cache = await self.recipe_scanner.get_cached_data() - - # Find the specific recipe - recipe = next((r for r in cache.raw_data if str(r.get('id', '')) == recipe_id), None) - - if not recipe: - return web.json_response({"error": "Recipe not found"}, status=404) - - # Get the image path - image_path = recipe.get('file_path') - if not image_path or not os.path.exists(image_path): - return web.json_response({"error": "Recipe image not found"}, status=404) - - # Create a temporary copy of the image to modify - import tempfile - import shutil - - # Create temp file with same extension - ext = os.path.splitext(image_path)[1] - with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as temp_file: - temp_path = temp_file.name - - # Copy the original image to temp file - shutil.copy2(image_path, temp_path) - processed_path = temp_path - - # Create a URL for the processed image - # Use a timestamp to prevent caching - timestamp = int(time.time()) - url_path = f"/api/recipe/{recipe_id}/share/download?t={timestamp}" - - # Store the temp path in a dictionary to serve later - if not hasattr(self, '_shared_recipes'): - self._shared_recipes = {} - - self._shared_recipes[recipe_id] = { - 'path': processed_path, - 'timestamp': timestamp, - 'expires': time.time() + 300 # Expire after 5 minutes - } - - # Clean up old entries - self._cleanup_shared_recipes() - - return web.json_response({ - 'success': True, - 'download_url': url_path, - 'filename': f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}{ext}" - }) - except Exception as e: - logger.error(f"Error sharing recipe: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def download_shared_recipe(self, request: web.Request) -> web.Response: - """Serve a processed recipe image for download""" - try: - # Ensure services are initialized - await self.init_services() - - recipe_id = request.match_info['recipe_id'] - - # Check if we have this shared recipe - if not hasattr(self, '_shared_recipes') or recipe_id not in self._shared_recipes: - return web.json_response({"error": "Shared recipe not found or expired"}, status=404) - - shared_info = self._shared_recipes[recipe_id] - file_path = shared_info['path'] - - if not os.path.exists(file_path): - return web.json_response({"error": "Shared recipe file not found"}, status=404) - - # Get recipe to determine filename - cache = await self.recipe_scanner.get_cached_data() - recipe = next((r for r in cache.raw_data if str(r.get('id', '')) == recipe_id), None) - - # Set filename for download - filename = f"recipe_{recipe.get('title', '').replace(' ', '_').lower() if recipe else recipe_id}" - ext = os.path.splitext(file_path)[1] - download_filename = f"{filename}{ext}" - - # Serve the file - return web.FileResponse( - file_path, - headers={ - 'Content-Disposition': f'attachment; filename="{download_filename}"' - } - ) - except Exception as e: - logger.error(f"Error downloading shared recipe: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - def _cleanup_shared_recipes(self): - """Clean up expired shared recipes""" - if not hasattr(self, '_shared_recipes'): - return - - current_time = time.time() - expired_ids = [rid for rid, info in self._shared_recipes.items() - if current_time > info.get('expires', 0)] - - for rid in expired_ids: - try: - # Delete the temporary file - file_path = self._shared_recipes[rid]['path'] - if os.path.exists(file_path): - os.unlink(file_path) - - # Remove from dictionary - del self._shared_recipes[rid] - except Exception as e: - logger.error(f"Error cleaning up shared recipe {rid}: {e}") - - async def save_recipe_from_widget(self, request: web.Request) -> web.Response: - """Save a recipe from the LoRAs widget""" - try: - # Ensure services are initialized - await self.init_services() - - # Get metadata using the metadata collector instead of workflow parsing - raw_metadata = get_metadata() - metadata_dict = MetadataProcessor.to_dict(raw_metadata) - - # Check if we have valid metadata - if not metadata_dict: - return web.json_response({"error": "No generation metadata found"}, status=400) - - # Get the most recent image from metadata registry instead of temp directory - if not standalone_mode: - metadata_registry = MetadataRegistry() - latest_image = metadata_registry.get_first_decoded_image() - else: - latest_image = None - - if latest_image is None: - return web.json_response({"error": "No recent images found to use for recipe. Try generating an image first."}, status=400) - - # Convert the image data to bytes - handle tuple and tensor cases - logger.debug(f"Image type: {type(latest_image)}") - - try: - # Handle the tuple case first - if isinstance(latest_image, tuple): - # Extract the tensor from the tuple - if len(latest_image) > 0: - tensor_image = latest_image[0] - else: - return web.json_response({"error": "Empty image tuple received"}, status=400) - else: - tensor_image = latest_image - - # Get the shape info for debugging - if hasattr(tensor_image, 'shape'): - shape_info = tensor_image.shape - logger.debug(f"Tensor shape: {shape_info}, dtype: {tensor_image.dtype}") - - import torch - - # Convert tensor to numpy array - if isinstance(tensor_image, torch.Tensor): - image_np = tensor_image.cpu().numpy() - else: - image_np = np.array(tensor_image) - - # Handle different tensor shapes - # Case: (1, 1, H, W, 3) or (1, H, W, 3) - batch or multi-batch - if len(image_np.shape) > 3: - # Remove batch dimensions until we get to (H, W, 3) - while len(image_np.shape) > 3: - image_np = image_np[0] - - # If values are in [0, 1] range, convert to [0, 255] - if image_np.dtype == np.float32 or image_np.dtype == np.float64: - if image_np.max() <= 1.0: - image_np = (image_np * 255).astype(np.uint8) - - # Ensure image is in the right format (HWC with RGB channels) - if len(image_np.shape) == 3 and image_np.shape[2] == 3: - pil_image = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() - pil_image.save(img_byte_arr, format='PNG') - image = img_byte_arr.getvalue() - else: - return web.json_response({"error": f"Cannot handle this data shape: {image_np.shape}, {image_np.dtype}"}, status=400) - except Exception as e: - logger.error(f"Error processing image data: {str(e)}", exc_info=True) - return web.json_response({"error": f"Error processing image: {str(e)}"}, status=400) - - # Get the lora stack from the metadata - lora_stack = metadata_dict.get("loras", "") - - # Parse the lora stack format: " ..." - import re - lora_matches = re.findall(r']+)>', lora_stack) - - # Check if any loras were found - if not lora_matches: - return web.json_response({"error": "No LoRAs found in the generation metadata"}, status=400) - - # Generate recipe name from the first 3 loras (or less if fewer are available) - loras_for_name = lora_matches[:3] # Take at most 3 loras for the name - - recipe_name_parts = [] - for lora_name, lora_strength in loras_for_name: - # Get the basename without path or extension - basename = os.path.basename(lora_name) - basename = os.path.splitext(basename)[0] - recipe_name_parts.append(f"{basename}:{lora_strength}") - - recipe_name = " ".join(recipe_name_parts) - - # Create recipes directory if it doesn't exist - recipes_dir = self.recipe_scanner.recipes_dir - os.makedirs(recipes_dir, exist_ok=True) - - # Generate UUID for the recipe - import uuid - recipe_id = str(uuid.uuid4()) - - # Optimize the image (resize and convert to WebP) - optimized_image, extension = ExifUtils.optimize_image( - image_data=image, - target_width=CARD_PREVIEW_WIDTH, - format='webp', - quality=85, - preserve_metadata=True - ) - - # Save the optimized image - image_filename = f"{recipe_id}{extension}" - image_path = os.path.join(recipes_dir, image_filename) - with open(image_path, 'wb') as f: - f.write(optimized_image) - - # Format loras data from the lora stack - loras_data = [] - - for lora_name, lora_strength in lora_matches: - try: - # Get lora info from scanner - lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora_name) - - # Create lora entry - lora_entry = { - "file_name": lora_name, - "hash": lora_info.get("sha256", "").lower() if lora_info else "", - "strength": float(lora_strength), - "modelVersionId": lora_info.get("civitai", {}).get("id", 0) if lora_info else 0, - "modelName": lora_info.get("civitai", {}).get("model", {}).get("name", "") if lora_info else lora_name, - "modelVersionName": lora_info.get("civitai", {}).get("name", "") if lora_info else "", - "isDeleted": False - } - loras_data.append(lora_entry) - except Exception as e: - logger.warning(f"Error processing LoRA {lora_name}: {e}") - - # Get base model from lora scanner for the available loras - base_model_counts = {} - for lora in loras_data: - lora_info = await self.recipe_scanner._lora_scanner.get_model_info_by_name(lora.get("file_name", "")) - if lora_info and "base_model" in lora_info: - base_model = lora_info["base_model"] - base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 - - # Get most common base model - most_common_base_model = "" - if base_model_counts: - most_common_base_model = max(base_model_counts.items(), key=lambda x: x[1])[0] - - # Create the recipe data structure - recipe_data = { - "id": recipe_id, - "file_path": image_path, - "title": recipe_name, # Use generated recipe name - "modified": time.time(), - "created_date": time.time(), - "base_model": most_common_base_model, - "loras": loras_data, - "checkpoint": metadata_dict.get("checkpoint", ""), - "gen_params": {key: value for key, value in metadata_dict.items() - if key not in ['checkpoint', 'loras']}, - "loras_stack": lora_stack # Include the original lora stack string - } - - # Save the recipe JSON - json_filename = f"{recipe_id}.recipe.json" - json_path = os.path.join(recipes_dir, json_filename) - with open(json_path, 'w', encoding='utf-8') as f: - json.dump(recipe_data, f, indent=4, ensure_ascii=False) - - # Add recipe metadata to the image - ExifUtils.append_recipe_metadata(image_path, recipe_data) - - # Update cache - if self.recipe_scanner._cache is not None: - # Add the recipe to the raw data if the cache exists - self.recipe_scanner._cache.raw_data.append(recipe_data) - # Schedule a background task to resort the cache - asyncio.create_task(self.recipe_scanner._cache.resort()) - logger.info(f"Added recipe {recipe_id} to cache") - - return web.json_response({ - 'success': True, - 'recipe_id': recipe_id, - 'image_path': image_path, - 'json_path': json_path, - 'recipe_name': recipe_name # Include the generated recipe name in the response - }) - - except Exception as e: - logger.error(f"Error saving recipe from widget: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def get_recipe_syntax(self, request: web.Request) -> web.Response: - """Generate recipe syntax for LoRAs in the recipe, looking up proper file names using hash_index""" - try: - # Ensure services are initialized - await self.init_services() - - recipe_id = request.match_info['recipe_id'] - - # Get all recipes from cache - cache = await self.recipe_scanner.get_cached_data() - - # Find the specific recipe - recipe = next((r for r in cache.raw_data if str(r.get('id', '')) == recipe_id), None) - - if not recipe: - return web.json_response({"error": "Recipe not found"}, status=404) - - # Get the loras from the recipe - loras = recipe.get('loras', []) - - if not loras: - return web.json_response({"error": "No LoRAs found in this recipe"}, status=400) - - # Generate recipe syntax for all LoRAs that: - # 1. Are in the library (not deleted) OR - # 2. Are deleted but not marked for exclusion - lora_syntax_parts = [] - - # Access the hash_index from lora_scanner - hash_index = self.recipe_scanner._lora_scanner._hash_index - - for lora in loras: - # Skip loras that are deleted AND marked for exclusion - if lora.get("isDeleted", False): - continue - - if not self.recipe_scanner._lora_scanner.has_hash(lora.get("hash", "")): - continue - - # Get the strength - strength = lora.get("strength", 1.0) - - # Try to find the actual file name for this lora - file_name = None - hash_value = lora.get("hash", "").lower() - - if hash_value and hasattr(hash_index, "_hash_to_path"): - # Look up the file path from the hash - file_path = hash_index._hash_to_path.get(hash_value) - - if file_path: - # Extract the file name without extension from the path - file_name = os.path.splitext(os.path.basename(file_path))[0] - - # If hash lookup failed, fall back to modelVersionId lookup - if not file_name and lora.get("modelVersionId"): - # Search for files with matching modelVersionId - all_loras = await self.recipe_scanner._lora_scanner.get_cached_data() - for cached_lora in all_loras.raw_data: - if not cached_lora.get("civitai"): - continue - if cached_lora.get("civitai", {}).get("id") == lora.get("modelVersionId"): - file_name = os.path.splitext(os.path.basename(cached_lora["path"]))[0] - break - - # If all lookups failed, use the file_name from the recipe - if not file_name: - file_name = lora.get("file_name", "unknown-lora") - - # Add to syntax parts - lora_syntax_parts.append(f"") - - # Join the LoRA syntax parts - lora_syntax = " ".join(lora_syntax_parts) - - return web.json_response({ - 'success': True, - 'syntax': lora_syntax - }) - except Exception as e: - logger.error(f"Error generating recipe syntax: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def update_recipe(self, request: web.Request) -> web.Response: - """Update recipe metadata (name and tags)""" - try: - # Ensure services are initialized - await self.init_services() - - recipe_id = request.match_info['recipe_id'] - data = await request.json() - - # Validate required fields - if 'title' not in data and 'tags' not in data and 'source_path' not in data and 'preview_nsfw_level' not in data: - return web.json_response({ - "error": "At least one field to update must be provided (title or tags or source_path or preview_nsfw_level)" - }, status=400) - - # Use the recipe scanner's update method - success = await self.recipe_scanner.update_recipe_metadata(recipe_id, data) - - if not success: - return web.json_response({"error": "Recipe not found or update failed"}, status=404) - - return web.json_response({ - "success": True, - "recipe_id": recipe_id, - "updates": data - }) - except Exception as e: - logger.error(f"Error updating recipe: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def reconnect_lora(self, request: web.Request) -> web.Response: - """Reconnect a deleted LoRA in a recipe to a local LoRA file""" - try: - # Ensure services are initialized - await self.init_services() - - # Parse request data - data = await request.json() - - # Validate required fields - required_fields = ['recipe_id', 'lora_index', 'target_name'] - for field in required_fields: - if field not in data: - return web.json_response({ - "error": f"Missing required field: {field}" - }, status=400) - - recipe_id = data['recipe_id'] - lora_index = int(data['lora_index']) - target_name = data['target_name'] - - # Get recipe scanner - scanner = self.recipe_scanner - lora_scanner = scanner._lora_scanner - - # Check if recipe exists - recipe_path = os.path.join(scanner.recipes_dir, f"{recipe_id}.recipe.json") - if not os.path.exists(recipe_path): - return web.json_response({"error": "Recipe not found"}, status=404) - - # Find target LoRA by name - target_lora = await lora_scanner.get_model_info_by_name(target_name) - if not target_lora: - return web.json_response({"error": f"Local LoRA not found with name: {target_name}"}, status=404) - - # Load recipe data - with open(recipe_path, 'r', encoding='utf-8') as f: - recipe_data = json.load(f) - - lora = recipe_data.get("loras", [])[lora_index] if lora_index < len(recipe_data.get('loras', [])) else None - - if lora is None: - return web.json_response({"error": "LoRA index out of range in recipe"}, status=404) - - # Update LoRA data - lora['isDeleted'] = False - lora['exclude'] = False - lora['file_name'] = target_name - - # Update with information from the target LoRA - if 'sha256' in target_lora: - lora['hash'] = target_lora['sha256'].lower() - if target_lora.get("civitai"): - lora['modelName'] = target_lora['civitai']['model']['name'] - lora['modelVersionName'] = target_lora['civitai']['name'] - lora['modelVersionId'] = target_lora['civitai']['id'] - - updated_lora = dict(lora) # Make a copy for response - - # Recalculate recipe fingerprint after updating LoRA - from ..utils.utils import calculate_recipe_fingerprint - recipe_data['fingerprint'] = calculate_recipe_fingerprint(recipe_data.get('loras', [])) - - # Save updated recipe - with open(recipe_path, 'w', encoding='utf-8') as f: - json.dump(recipe_data, f, indent=4, ensure_ascii=False) - - updated_lora['inLibrary'] = True - updated_lora['preview_url'] = config.get_preview_static_url(target_lora['preview_url']) - updated_lora['localPath'] = target_lora['file_path'] - - # Update in cache if it exists - if scanner._cache is not None: - for cache_item in scanner._cache.raw_data: - if cache_item.get('id') == recipe_id: - # Replace loras array with updated version - cache_item['loras'] = recipe_data['loras'] - # Update fingerprint in cache - cache_item['fingerprint'] = recipe_data['fingerprint'] - - # Resort the cache - asyncio.create_task(scanner._cache.resort()) - break - - # Update EXIF metadata if image exists - image_path = recipe_data.get('file_path') - if image_path and os.path.exists(image_path): - from ..utils.exif_utils import ExifUtils - ExifUtils.append_recipe_metadata(image_path, recipe_data) - - # Find other recipes with the same fingerprint - matching_recipes = [] - if 'fingerprint' in recipe_data: - matching_recipes = await scanner.find_recipes_by_fingerprint(recipe_data['fingerprint']) - # Remove current recipe from matches - if recipe_id in matching_recipes: - matching_recipes.remove(recipe_id) - - return web.json_response({ - "success": True, - "recipe_id": recipe_id, - "updated_lora": updated_lora, - "matching_recipes": matching_recipes - }) - - except Exception as e: - logger.error(f"Error reconnecting LoRA: {e}", exc_info=True) - return web.json_response({"error": str(e)}, status=500) - - async def get_recipes_for_lora(self, request: web.Request) -> web.Response: - """Get recipes that use a specific Lora""" - try: - # Ensure services are initialized - await self.init_services() - - lora_hash = request.query.get('hash') - - # Hash is required - if not lora_hash: - return web.json_response({'success': False, 'error': 'Lora hash is required'}, status=400) - - # Log the search parameters - logger.debug(f"Getting recipes for Lora by hash: {lora_hash}") - - # Get all recipes from cache - cache = await self.recipe_scanner.get_cached_data() - - # Filter recipes that use this Lora by hash - matching_recipes = [] - for recipe in cache.raw_data: - # Check if any of the recipe's loras match this hash - loras = recipe.get('loras', []) - for lora in loras: - if lora.get('hash', '').lower() == lora_hash.lower(): - matching_recipes.append(recipe) - break # No need to check other loras in this recipe - - # Process the recipes similar to get_paginated_data to ensure all needed data is available - for recipe in matching_recipes: - # Add inLibrary information for each lora - if 'loras' in recipe: - for lora in recipe['loras']: - if 'hash' in lora and lora['hash']: - lora['inLibrary'] = self.recipe_scanner._lora_scanner.has_hash(lora['hash'].lower()) - lora['preview_url'] = self.recipe_scanner._lora_scanner.get_preview_url_by_hash(lora['hash'].lower()) - lora['localPath'] = self.recipe_scanner._lora_scanner.get_path_by_hash(lora['hash'].lower()) - - # Ensure file_url is set (needed by frontend) - if 'file_path' in recipe: - recipe['file_url'] = self._format_recipe_file_url(recipe['file_path']) - else: - recipe['file_url'] = '/loras_static/images/no-preview.png' - - return web.json_response({'success': True, 'recipes': matching_recipes}) - except Exception as e: - logger.error(f"Error getting recipes for Lora: {str(e)}") - return web.json_response({'success': False, 'error': str(e)}, status=500) - - async def scan_recipes(self, request: web.Request) -> web.Response: - """API endpoint for scanning and rebuilding the recipe cache""" - try: - # Ensure services are initialized - await self.init_services() - - # Force refresh the recipe cache - logger.info("Manually triggering recipe cache rebuild") - await self.recipe_scanner.get_cached_data(force_refresh=True) - - return web.json_response({ - 'success': True, - 'message': 'Recipe cache refreshed successfully' - }) - except Exception as e: - logger.error(f"Error refreshing recipe cache: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def find_duplicates(self, request: web.Request) -> web.Response: - """Find all duplicate recipes based on fingerprints""" - try: - # Ensure services are initialized - await self.init_services() - - # Get all duplicate recipes - duplicate_groups = await self.recipe_scanner.find_all_duplicate_recipes() - - # Create response data with additional recipe information - response_data = [] - - for fingerprint, recipe_ids in duplicate_groups.items(): - # Skip groups with only one recipe (not duplicates) - if len(recipe_ids) <= 1: - continue - - # Get recipe details for each recipe in the group - recipes = [] - for recipe_id in recipe_ids: - recipe = await self.recipe_scanner.get_recipe_by_id(recipe_id) - if recipe: - # Add only needed fields to keep response size manageable - recipes.append({ - 'id': recipe.get('id'), - 'title': recipe.get('title'), - 'file_url': recipe.get('file_url') or self._format_recipe_file_url(recipe.get('file_path', '')), - 'modified': recipe.get('modified'), - 'created_date': recipe.get('created_date'), - 'lora_count': len(recipe.get('loras', [])), - }) - - # Only include groups with at least 2 valid recipes - if len(recipes) >= 2: - # Sort recipes by modified date (newest first) - recipes.sort(key=lambda x: x.get('modified', 0), reverse=True) - - response_data.append({ - 'fingerprint': fingerprint, - 'count': len(recipes), - 'recipes': recipes - }) - - # Sort groups by count (highest first) - response_data.sort(key=lambda x: x['count'], reverse=True) - - return web.json_response({ - 'success': True, - 'duplicate_groups': response_data - }) - - except Exception as e: - logger.error(f"Error finding duplicate recipes: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - async def bulk_delete(self, request: web.Request) -> web.Response: - """Delete multiple recipes by ID""" - try: - # Ensure services are initialized - await self.init_services() - - # Parse request data - data = await request.json() - recipe_ids = data.get('recipe_ids', []) - - if not recipe_ids: - return web.json_response({ - 'success': False, - 'error': 'No recipe IDs provided' - }, status=400) - - # Get recipes directory - recipes_dir = self.recipe_scanner.recipes_dir - if not recipes_dir or not os.path.exists(recipes_dir): - return web.json_response({ - 'success': False, - 'error': 'Recipes directory not found' - }, status=404) - - # Track deleted and failed recipes - deleted_recipes = [] - failed_recipes = [] - - # Process each recipe ID - for recipe_id in recipe_ids: - # Find recipe JSON file - recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json") - - if not os.path.exists(recipe_json_path): - failed_recipes.append({ - 'id': recipe_id, - 'reason': 'Recipe not found' - }) - continue - - try: - # Load recipe data to get image path - with open(recipe_json_path, 'r', encoding='utf-8') as f: - recipe_data = json.load(f) - - # Get image path - image_path = recipe_data.get('file_path') - - # Delete recipe JSON file - os.remove(recipe_json_path) - - # Delete recipe image if it exists - if image_path and os.path.exists(image_path): - os.remove(image_path) - - deleted_recipes.append(recipe_id) - - except Exception as e: - failed_recipes.append({ - 'id': recipe_id, - 'reason': str(e) - }) - - # Update cache if any recipes were deleted - if deleted_recipes and self.recipe_scanner._cache is not None: - # Remove deleted recipes from raw_data - self.recipe_scanner._cache.raw_data = [ - r for r in self.recipe_scanner._cache.raw_data - if r.get('id') not in deleted_recipes - ] - # Resort the cache - asyncio.create_task(self.recipe_scanner._cache.resort()) - logger.info(f"Removed {len(deleted_recipes)} recipes from cache") - - return web.json_response({ - 'success': True, - 'deleted': deleted_recipes, - 'failed': failed_recipes, - 'total_deleted': len(deleted_recipes), - 'total_failed': len(failed_recipes) - }) - - except Exception as e: - logger.error(f"Error performing bulk delete: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) + registrar = RecipeRouteRegistrar(app) + registrar.register_routes(routes.to_route_mapping()) + routes.register_startup_hooks(app) diff --git a/py/services/base_model_service.py b/py/services/base_model_service.py index ed1fc930..2c2c0ad8 100644 --- a/py/services/base_model_service.py +++ b/py/services/base_model_service.py @@ -4,99 +4,88 @@ import logging import os from ..utils.models import BaseModelMetadata -from ..utils.routes_common import ModelRouteUtils -from ..utils.constants import NSFW_LEVELS -from .settings_manager import settings -from ..utils.utils import fuzzy_match +from .model_query import FilterCriteria, ModelCacheRepository, ModelFilterSet, SearchStrategy, SettingsProvider +from .settings_manager import settings as default_settings logger = logging.getLogger(__name__) class BaseModelService(ABC): """Base service class for all model types""" - def __init__(self, model_type: str, scanner, metadata_class: Type[BaseModelMetadata]): - """Initialize the service - + def __init__( + self, + model_type: str, + scanner, + metadata_class: Type[BaseModelMetadata], + *, + cache_repository: Optional[ModelCacheRepository] = None, + filter_set: Optional[ModelFilterSet] = None, + search_strategy: Optional[SearchStrategy] = None, + settings_provider: Optional[SettingsProvider] = None, + ): + """Initialize the service. + Args: - model_type: Type of model (lora, checkpoint, etc.) - scanner: Model scanner instance - metadata_class: Metadata class for this model type + model_type: Type of model (lora, checkpoint, etc.). + scanner: Model scanner instance. + metadata_class: Metadata class for this model type. + cache_repository: Custom repository for cache access (primarily for tests). + filter_set: Filter component controlling folder/tag/favorites logic. + search_strategy: Search component for fuzzy/text matching. + settings_provider: Settings object; defaults to the global settings manager. """ self.model_type = model_type self.scanner = scanner self.metadata_class = metadata_class + self.settings = settings_provider or default_settings + self.cache_repository = cache_repository or ModelCacheRepository(scanner) + self.filter_set = filter_set or ModelFilterSet(self.settings) + self.search_strategy = search_strategy or SearchStrategy() - async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'name', - folder: str = None, search: str = None, fuzzy_search: bool = False, - base_models: list = None, tags: list = None, - search_options: dict = None, hash_filters: dict = None, - favorites_only: bool = False, **kwargs) -> Dict: - """Get paginated and filtered model data - - Args: - page: Page number (1-based) - page_size: Number of items per page - sort_by: Sort criteria, e.g. 'name', 'name:asc', 'name:desc', 'date', 'date:asc', 'date:desc' - folder: Folder filter - search: Search term - fuzzy_search: Whether to use fuzzy search - base_models: List of base models to filter by - tags: List of tags to filter by - search_options: Search options dict - hash_filters: Hash filtering options - favorites_only: Filter for favorites only - **kwargs: Additional model-specific filters - - Returns: - Dict containing paginated results - """ - cache = await self.scanner.get_cached_data() + async def get_paginated_data( + self, + page: int, + page_size: int, + sort_by: str = 'name', + folder: str = None, + search: str = None, + fuzzy_search: bool = False, + base_models: list = None, + tags: list = None, + search_options: dict = None, + hash_filters: dict = None, + favorites_only: bool = False, + **kwargs, + ) -> Dict: + """Get paginated and filtered model data""" + sort_params = self.cache_repository.parse_sort(sort_by) + sorted_data = await self.cache_repository.fetch_sorted(sort_params) - # Parse sort_by into sort_key and order - if ':' in sort_by: - sort_key, order = sort_by.split(':', 1) - sort_key = sort_key.strip() - order = order.strip().lower() - if order not in ('asc', 'desc'): - order = 'asc' - else: - sort_key = sort_by.strip() - order = 'asc' - - # Get default search options if not provided - if search_options is None: - search_options = { - 'filename': True, - 'modelname': True, - 'tags': False, - 'recursive': True, - } - - # Get the base data set using new sort logic - filtered_data = await cache.get_sorted_data(sort_key, order) - - # Apply hash filtering if provided (highest priority) if hash_filters: - filtered_data = await self._apply_hash_filters(filtered_data, hash_filters) - - # Jump to pagination for hash filters + filtered_data = await self._apply_hash_filters(sorted_data, hash_filters) return self._paginate(filtered_data, page, page_size) - - # Apply common filters + filtered_data = await self._apply_common_filters( - filtered_data, folder, base_models, tags, favorites_only, search_options + sorted_data, + folder=folder, + base_models=base_models, + tags=tags, + favorites_only=favorites_only, + search_options=search_options, ) - - # Apply search filtering + if search: filtered_data = await self._apply_search_filters( - filtered_data, search, fuzzy_search, search_options + filtered_data, + search, + fuzzy_search, + search_options, ) - - # Apply model-specific filters + filtered_data = await self._apply_specific_filters(filtered_data, **kwargs) - + return self._paginate(filtered_data, page, page_size) + async def _apply_hash_filters(self, data: List[Dict], hash_filters: Dict) -> List[Dict]: """Apply hash-based filtering""" @@ -120,113 +109,36 @@ class BaseModelService(ABC): return data - async def _apply_common_filters(self, data: List[Dict], folder: str = None, - base_models: list = None, tags: list = None, - favorites_only: bool = False, search_options: dict = None) -> List[Dict]: + async def _apply_common_filters( + self, + data: List[Dict], + folder: str = None, + base_models: list = None, + tags: list = None, + favorites_only: bool = False, + search_options: dict = None, + ) -> List[Dict]: """Apply common filters that work across all model types""" - # Apply SFW filtering if enabled in settings - if settings.get('show_only_sfw', False): - data = [ - item for item in data - if not item.get('preview_nsfw_level') or item.get('preview_nsfw_level') < NSFW_LEVELS['R'] - ] - - # Apply favorites filtering if enabled - if favorites_only: - data = [ - item for item in data - if item.get('favorite', False) is True - ] - - # Apply folder filtering - if folder is not None: - if search_options and search_options.get('recursive', True): - # Recursive folder filtering - include all subfolders - # Ensure we match exact folder or its subfolders by checking path boundaries - if folder == "": - # Empty folder means root - include all items - pass # Don't filter anything - else: - # Add trailing slash to ensure we match folder boundaries correctly - folder_with_separator = folder + "/" - data = [ - item for item in data - if (item['folder'] == folder or - item['folder'].startswith(folder_with_separator)) - ] - else: - # Exact folder filtering - data = [ - item for item in data - if item['folder'] == folder - ] - - # Apply base model filtering - if base_models and len(base_models) > 0: - data = [ - item for item in data - if item.get('base_model') in base_models - ] - - # Apply tag filtering - if tags and len(tags) > 0: - data = [ - item for item in data - if any(tag in item.get('tags', []) for tag in tags) - ] - - return data + normalized_options = self.search_strategy.normalize_options(search_options) + criteria = FilterCriteria( + folder=folder, + base_models=base_models, + tags=tags, + favorites_only=favorites_only, + search_options=normalized_options, + ) + return self.filter_set.apply(data, criteria) - async def _apply_search_filters(self, data: List[Dict], search: str, - fuzzy_search: bool, search_options: dict) -> List[Dict]: + async def _apply_search_filters( + self, + data: List[Dict], + search: str, + fuzzy_search: bool, + search_options: dict, + ) -> List[Dict]: """Apply search filtering""" - search_results = [] - - for item in data: - # Search by file name - if search_options.get('filename', True): - if fuzzy_search: - if fuzzy_match(item.get('file_name', ''), search): - search_results.append(item) - continue - elif search.lower() in item.get('file_name', '').lower(): - search_results.append(item) - continue - - # Search by model name - if search_options.get('modelname', True): - if fuzzy_search: - if fuzzy_match(item.get('model_name', ''), search): - search_results.append(item) - continue - elif search.lower() in item.get('model_name', '').lower(): - search_results.append(item) - continue - - # Search by tags - if search_options.get('tags', False) and 'tags' in item: - if any((fuzzy_match(tag, search) if fuzzy_search else search.lower() in tag.lower()) - for tag in item['tags']): - search_results.append(item) - continue - - # Search by creator - civitai = item.get('civitai') - creator_username = '' - if civitai and isinstance(civitai, dict): - creator = civitai.get('creator') - if creator and isinstance(creator, dict): - creator_username = creator.get('username', '') - if search_options.get('creator', False) and creator_username: - if fuzzy_search: - if fuzzy_match(creator_username, search): - search_results.append(item) - continue - elif search.lower() in creator_username.lower(): - search_results.append(item) - continue - - return search_results + normalized_options = self.search_strategy.normalize_options(search_options) + return self.search_strategy.apply(data, search, normalized_options, fuzzy_search) async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: """Apply model-specific filters - to be overridden by subclasses if needed""" @@ -284,6 +196,18 @@ class BaseModelService(ABC): """Get model root directories""" return self.scanner.get_model_roots() + def filter_civitai_data(self, data: Dict, minimal: bool = False) -> Dict: + """Filter relevant fields from CivitAI data""" + if not data: + return {} + + fields = ["id", "modelId", "name", "trainedWords"] if minimal else [ + "id", "modelId", "name", "createdAt", "updatedAt", + "publishedAt", "trainedWords", "baseModel", "description", + "model", "images", "customImages", "creator" + ] + return {k: data[k] for k in fields if k in data} + async def get_folder_tree(self, model_root: str) -> Dict: """Get hierarchical folder tree for a specific model root""" cache = await self.scanner.get_cached_data() @@ -394,7 +318,7 @@ class BaseModelService(ABC): for model in cache.raw_data: if model.get('file_path') == file_path: - return ModelRouteUtils.filter_civitai_data(model.get("civitai", {})) + return self.filter_civitai_data(model.get("civitai", {})) return None diff --git a/py/services/checkpoint_service.py b/py/services/checkpoint_service.py index ef3dc4a8..2f7b8a96 100644 --- a/py/services/checkpoint_service.py +++ b/py/services/checkpoint_service.py @@ -1,11 +1,10 @@ import os import logging -from typing import Dict, List, Optional +from typing import Dict from .base_model_service import BaseModelService from ..utils.models import CheckpointMetadata from ..config import config -from ..utils.routes_common import ModelRouteUtils logger = logging.getLogger(__name__) @@ -38,7 +37,7 @@ class CheckpointService(BaseModelService): "notes": checkpoint_data.get("notes", ""), "model_type": checkpoint_data.get("model_type", "checkpoint"), "favorite": checkpoint_data.get("favorite", False), - "civitai": ModelRouteUtils.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True) + "civitai": self.filter_civitai_data(checkpoint_data.get("civitai", {}), minimal=True) } def find_duplicate_hashes(self) -> Dict: diff --git a/py/services/download_coordinator.py b/py/services/download_coordinator.py new file mode 100644 index 00000000..4cf866e5 --- /dev/null +++ b/py/services/download_coordinator.py @@ -0,0 +1,100 @@ +"""Service wrapper for coordinating download lifecycle events.""" + +from __future__ import annotations + +import logging +from typing import Any, Awaitable, Callable, Dict, Optional + + +logger = logging.getLogger(__name__) + + +class DownloadCoordinator: + """Manage download scheduling, cancellation and introspection.""" + + def __init__( + self, + *, + ws_manager, + download_manager_factory: Callable[[], Awaitable], + ) -> None: + self._ws_manager = ws_manager + self._download_manager_factory = download_manager_factory + + async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """Schedule a download using the provided payload.""" + + download_manager = await self._download_manager_factory() + + download_id = payload.get("download_id") or self._ws_manager.generate_download_id() + payload.setdefault("download_id", download_id) + + async def progress_callback(progress: Any) -> None: + await self._ws_manager.broadcast_download_progress( + download_id, + { + "status": "progress", + "progress": progress, + "download_id": download_id, + }, + ) + + model_id = self._parse_optional_int(payload.get("model_id"), "model_id") + model_version_id = self._parse_optional_int( + payload.get("model_version_id"), "model_version_id" + ) + + if model_id is None and model_version_id is None: + raise ValueError( + "Missing required parameter: Please provide either 'model_id' or 'model_version_id'" + ) + + result = await download_manager.download_from_civitai( + model_id=model_id, + model_version_id=model_version_id, + save_dir=payload.get("model_root"), + relative_path=payload.get("relative_path", ""), + use_default_paths=payload.get("use_default_paths", False), + progress_callback=progress_callback, + download_id=download_id, + source=payload.get("source"), + ) + + result["download_id"] = download_id + return result + + async def cancel_download(self, download_id: str) -> Dict[str, Any]: + """Cancel an active download and emit a broadcast event.""" + + download_manager = await self._download_manager_factory() + result = await download_manager.cancel_download(download_id) + + await self._ws_manager.broadcast_download_progress( + download_id, + { + "status": "cancelled", + "progress": 0, + "download_id": download_id, + "message": "Download cancelled by user", + }, + ) + + return result + + async def list_active_downloads(self) -> Dict[str, Any]: + """Return the active download map from the underlying manager.""" + + download_manager = await self._download_manager_factory() + return await download_manager.get_active_downloads() + + def _parse_optional_int(self, value: Any, field: str) -> Optional[int]: + """Parse an optional integer from user input.""" + + if value is None or value == "": + return None + + try: + return int(value) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid {field}: Must be an integer") from exc + diff --git a/py/services/embedding_service.py b/py/services/embedding_service.py index bab067d9..46396fc5 100644 --- a/py/services/embedding_service.py +++ b/py/services/embedding_service.py @@ -1,11 +1,10 @@ import os import logging -from typing import Dict, List, Optional +from typing import Dict from .base_model_service import BaseModelService from ..utils.models import EmbeddingMetadata from ..config import config -from ..utils.routes_common import ModelRouteUtils logger = logging.getLogger(__name__) @@ -38,7 +37,7 @@ class EmbeddingService(BaseModelService): "notes": embedding_data.get("notes", ""), "model_type": embedding_data.get("model_type", "embedding"), "favorite": embedding_data.get("favorite", False), - "civitai": ModelRouteUtils.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True) + "civitai": self.filter_civitai_data(embedding_data.get("civitai", {}), minimal=True) } def find_duplicate_hashes(self) -> Dict: diff --git a/py/services/lora_service.py b/py/services/lora_service.py index d1e522a3..551c4d3c 100644 --- a/py/services/lora_service.py +++ b/py/services/lora_service.py @@ -5,7 +5,6 @@ from typing import Dict, List, Optional from .base_model_service import BaseModelService from ..utils.models import LoraMetadata from ..config import config -from ..utils.routes_common import ModelRouteUtils logger = logging.getLogger(__name__) @@ -38,7 +37,7 @@ class LoraService(BaseModelService): "usage_tips": lora_data.get("usage_tips", ""), "notes": lora_data.get("notes", ""), "favorite": lora_data.get("favorite", False), - "civitai": ModelRouteUtils.filter_civitai_data(lora_data.get("civitai", {}), minimal=True) + "civitai": self.filter_civitai_data(lora_data.get("civitai", {}), minimal=True) } async def _apply_specific_filters(self, data: List[Dict], **kwargs) -> List[Dict]: diff --git a/py/services/metadata_sync_service.py b/py/services/metadata_sync_service.py new file mode 100644 index 00000000..aaf2f248 --- /dev/null +++ b/py/services/metadata_sync_service.py @@ -0,0 +1,355 @@ +"""Services for synchronising metadata with remote providers.""" + +from __future__ import annotations + +import json +import logging +import os +from datetime import datetime +from typing import Any, Awaitable, Callable, Dict, Iterable, Optional + +from ..services.settings_manager import SettingsManager +from ..utils.model_utils import determine_base_model + +logger = logging.getLogger(__name__) + + +class MetadataProviderProtocol: + """Subset of metadata provider interface consumed by the sync service.""" + + async def get_model_by_hash(self, sha256: str) -> tuple[Optional[Dict[str, Any]], Optional[str]]: + ... + + async def get_model_version( + self, model_id: int, model_version_id: Optional[int] + ) -> Optional[Dict[str, Any]]: + ... + + +class MetadataSyncService: + """High level orchestration for metadata synchronisation flows.""" + + def __init__( + self, + *, + metadata_manager, + preview_service, + settings: SettingsManager, + default_metadata_provider_factory: Callable[[], Awaitable[MetadataProviderProtocol]], + metadata_provider_selector: Callable[[str], Awaitable[MetadataProviderProtocol]], + ) -> None: + self._metadata_manager = metadata_manager + self._preview_service = preview_service + self._settings = settings + self._get_default_provider = default_metadata_provider_factory + self._get_provider = metadata_provider_selector + + async def load_local_metadata(self, metadata_path: str) -> Dict[str, Any]: + """Load metadata JSON from disk, returning an empty structure when missing.""" + + if not os.path.exists(metadata_path): + return {} + + try: + with open(metadata_path, "r", encoding="utf-8") as handle: + return json.load(handle) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Error loading metadata from %s: %s", metadata_path, exc) + return {} + + async def mark_not_found_on_civitai( + self, metadata_path: str, local_metadata: Dict[str, Any] + ) -> None: + """Persist the not-found flag for a metadata payload.""" + + local_metadata["from_civitai"] = False + await self._metadata_manager.save_metadata(metadata_path, local_metadata) + + @staticmethod + def is_civitai_api_metadata(meta: Dict[str, Any]) -> bool: + """Determine if the metadata originated from the CivitAI public API.""" + + if not isinstance(meta, dict): + return False + files = meta.get("files") + images = meta.get("images") + source = meta.get("source") + return bool(files) and bool(images) and source != "archive_db" + + async def update_model_metadata( + self, + metadata_path: str, + local_metadata: Dict[str, Any], + civitai_metadata: Dict[str, Any], + metadata_provider: Optional[MetadataProviderProtocol] = None, + ) -> Dict[str, Any]: + """Merge remote metadata into the local record and persist the result.""" + + existing_civitai = local_metadata.get("civitai") or {} + + if ( + civitai_metadata.get("source") == "archive_db" + and self.is_civitai_api_metadata(existing_civitai) + ): + logger.info( + "Skip civitai update for %s (%s)", + local_metadata.get("model_name", ""), + existing_civitai.get("name", ""), + ) + else: + merged_civitai = existing_civitai.copy() + merged_civitai.update(civitai_metadata) + + if civitai_metadata.get("source") == "archive_db": + model_name = civitai_metadata.get("model", {}).get("name", "") + version_name = civitai_metadata.get("name", "") + logger.info( + "Recovered metadata from archive_db for deleted model: %s (%s)", + model_name, + version_name, + ) + + if "trainedWords" in existing_civitai: + existing_trained = existing_civitai.get("trainedWords", []) + new_trained = civitai_metadata.get("trainedWords", []) + merged_trained = list(set(existing_trained + new_trained)) + merged_civitai["trainedWords"] = merged_trained + + local_metadata["civitai"] = merged_civitai + + if "model" in civitai_metadata and civitai_metadata["model"]: + model_data = civitai_metadata["model"] + + if model_data.get("name"): + local_metadata["model_name"] = model_data["name"] + + if not local_metadata.get("modelDescription") and model_data.get("description"): + local_metadata["modelDescription"] = model_data["description"] + + if not local_metadata.get("tags") and model_data.get("tags"): + local_metadata["tags"] = model_data["tags"] + + if model_data.get("creator") and not local_metadata.get("civitai", {}).get( + "creator" + ): + local_metadata.setdefault("civitai", {})["creator"] = model_data["creator"] + + local_metadata["base_model"] = determine_base_model( + civitai_metadata.get("baseModel") + ) + + await self._preview_service.ensure_preview_for_metadata( + metadata_path, local_metadata, civitai_metadata.get("images", []) + ) + + await self._metadata_manager.save_metadata(metadata_path, local_metadata) + return local_metadata + + async def fetch_and_update_model( + self, + *, + sha256: str, + file_path: str, + model_data: Dict[str, Any], + update_cache_func: Callable[[str, str, Dict[str, Any]], Awaitable[bool]], + ) -> tuple[bool, Optional[str]]: + """Fetch metadata for a model and update both disk and cache state.""" + + if not isinstance(model_data, dict): + error = f"Invalid model_data type: {type(model_data)}" + logger.error(error) + return False, error + + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + enable_archive = self._settings.get("enable_metadata_archive_db", False) + + try: + if model_data.get("civitai_deleted") is True: + if not enable_archive or model_data.get("db_checked") is True: + return ( + False, + "CivitAI model is deleted and metadata archive DB is not enabled", + ) + metadata_provider = await self._get_provider("sqlite") + else: + metadata_provider = await self._get_default_provider() + + civitai_metadata, error = await metadata_provider.get_model_by_hash(sha256) + if not civitai_metadata: + if error == "Model not found": + model_data["from_civitai"] = False + model_data["civitai_deleted"] = True + model_data["db_checked"] = enable_archive + model_data["last_checked_at"] = datetime.now().timestamp() + + data_to_save = model_data.copy() + data_to_save.pop("folder", None) + await self._metadata_manager.save_metadata(file_path, data_to_save) + + error_msg = ( + f"Error fetching metadata: {error} (model_name={model_data.get('model_name', '')})" + ) + logger.error(error_msg) + return False, error_msg + + model_data["from_civitai"] = True + model_data["civitai_deleted"] = civitai_metadata.get("source") == "archive_db" + model_data["db_checked"] = enable_archive + model_data["last_checked_at"] = datetime.now().timestamp() + + local_metadata = model_data.copy() + local_metadata.pop("folder", None) + + await self.update_model_metadata( + metadata_path, + local_metadata, + civitai_metadata, + metadata_provider, + ) + + update_payload = { + "model_name": local_metadata.get("model_name"), + "preview_url": local_metadata.get("preview_url"), + "civitai": local_metadata.get("civitai"), + } + model_data.update(update_payload) + + await update_cache_func(file_path, file_path, local_metadata) + return True, None + except KeyError as exc: + error_msg = f"Error fetching metadata - Missing key: {exc} in model_data={model_data}" + logger.error(error_msg) + return False, error_msg + except Exception as exc: # pragma: no cover - error path + error_msg = f"Error fetching metadata: {exc}" + logger.error(error_msg, exc_info=True) + return False, error_msg + + async def fetch_metadata_by_sha( + self, sha256: str, metadata_provider: Optional[MetadataProviderProtocol] = None + ) -> tuple[Optional[Dict[str, Any]], Optional[str]]: + """Fetch metadata for a SHA256 hash from the configured provider.""" + + provider = metadata_provider or await self._get_default_provider() + return await provider.get_model_by_hash(sha256) + + async def relink_metadata( + self, + *, + file_path: str, + metadata: Dict[str, Any], + model_id: int, + model_version_id: Optional[int], + ) -> Dict[str, Any]: + """Relink a local metadata record to a specific CivitAI model version.""" + + provider = await self._get_default_provider() + civitai_metadata = await provider.get_model_version(model_id, model_version_id) + if not civitai_metadata: + raise ValueError( + f"Model version not found on CivitAI for ID: {model_id}" + + (f" with version: {model_version_id}" if model_version_id else "") + ) + + primary_model_file: Optional[Dict[str, Any]] = None + for file_info in civitai_metadata.get("files", []): + if file_info.get("primary", False) and file_info.get("type") == "Model": + primary_model_file = file_info + break + + if primary_model_file and primary_model_file.get("hashes", {}).get("SHA256"): + metadata["sha256"] = primary_model_file["hashes"]["SHA256"].lower() + + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + await self.update_model_metadata( + metadata_path, + metadata, + civitai_metadata, + provider, + ) + + return metadata + + async def save_metadata_updates( + self, + *, + file_path: str, + updates: Dict[str, Any], + metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]], + update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]], + ) -> Dict[str, Any]: + """Apply metadata updates and persist to disk and cache.""" + + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + metadata = await metadata_loader(metadata_path) + + for key, value in updates.items(): + if isinstance(value, dict) and isinstance(metadata.get(key), dict): + metadata[key].update(value) + else: + metadata[key] = value + + await self._metadata_manager.save_metadata(file_path, metadata) + await update_cache(file_path, file_path, metadata) + + if "model_name" in updates: + logger.debug("Metadata update touched model_name; cache resort required") + + return metadata + + async def verify_duplicate_hashes( + self, + *, + file_paths: Iterable[str], + metadata_loader: Callable[[str], Awaitable[Dict[str, Any]]], + hash_calculator: Callable[[str], Awaitable[str]], + update_cache: Callable[[str, str, Dict[str, Any]], Awaitable[bool]], + ) -> Dict[str, Any]: + """Verify a collection of files share the same SHA256 hash.""" + + file_paths = list(file_paths) + if not file_paths: + raise ValueError("No file paths provided for verification") + + results = { + "verified_as_duplicates": True, + "mismatched_files": [], + "new_hash_map": {}, + } + + expected_hash: Optional[str] = None + first_metadata_path = os.path.splitext(file_paths[0])[0] + ".metadata.json" + first_metadata = await metadata_loader(first_metadata_path) + if first_metadata and "sha256" in first_metadata: + expected_hash = first_metadata["sha256"].lower() + + for path in file_paths: + if not os.path.exists(path): + continue + + try: + actual_hash = await hash_calculator(path) + metadata_path = os.path.splitext(path)[0] + ".metadata.json" + metadata = await metadata_loader(metadata_path) + stored_hash = metadata.get("sha256", "").lower() + + if not expected_hash: + expected_hash = stored_hash + + if actual_hash != expected_hash: + results["verified_as_duplicates"] = False + results["mismatched_files"].append(path) + results["new_hash_map"][path] = actual_hash + + if actual_hash != stored_hash: + metadata["sha256"] = actual_hash + await self._metadata_manager.save_metadata(path, metadata) + await update_cache(path, path, metadata) + except Exception as exc: # pragma: no cover - defensive path + logger.error("Error verifying hash for %s: %s", path, exc) + results["mismatched_files"].append(path) + results["new_hash_map"][path] = "error_calculating_hash" + results["verified_as_duplicates"] = False + + return results + diff --git a/py/services/model_lifecycle_service.py b/py/services/model_lifecycle_service.py new file mode 100644 index 00000000..9aa87b04 --- /dev/null +++ b/py/services/model_lifecycle_service.py @@ -0,0 +1,245 @@ +"""Service routines for model lifecycle mutations.""" + +from __future__ import annotations + +import logging +import os +from typing import Awaitable, Callable, Dict, Iterable, List, Optional + +from ..services.service_registry import ServiceRegistry +from ..utils.constants import PREVIEW_EXTENSIONS + +logger = logging.getLogger(__name__) + + +async def delete_model_artifacts(target_dir: str, file_name: str) -> List[str]: + """Delete the primary model artefacts within ``target_dir``.""" + + patterns = [ + f"{file_name}.safetensors", + f"{file_name}.metadata.json", + ] + for ext in PREVIEW_EXTENSIONS: + patterns.append(f"{file_name}{ext}") + + deleted: List[str] = [] + main_file = patterns[0] + main_path = os.path.join(target_dir, main_file).replace(os.sep, "/") + + if os.path.exists(main_path): + os.remove(main_path) + deleted.append(main_path) + else: + logger.warning("Model file not found: %s", main_file) + + for pattern in patterns[1:]: + path = os.path.join(target_dir, pattern) + if os.path.exists(path): + try: + os.remove(path) + deleted.append(pattern) + except Exception as exc: # pragma: no cover - defensive path + logger.warning("Failed to delete %s: %s", pattern, exc) + + return deleted + + +class ModelLifecycleService: + """Co-ordinate destructive and mutating model operations.""" + + def __init__( + self, + *, + scanner, + metadata_manager, + metadata_loader: Callable[[str], Awaitable[Dict[str, object]]], + recipe_scanner_factory: Callable[[], Awaitable] | None = None, + ) -> None: + self._scanner = scanner + self._metadata_manager = metadata_manager + self._metadata_loader = metadata_loader + self._recipe_scanner_factory = ( + recipe_scanner_factory or ServiceRegistry.get_recipe_scanner + ) + + async def delete_model(self, file_path: str) -> Dict[str, object]: + """Delete a model file and associated artefacts.""" + + if not file_path: + raise ValueError("Model path is required") + + target_dir = os.path.dirname(file_path) + file_name = os.path.splitext(os.path.basename(file_path))[0] + + deleted_files = await delete_model_artifacts(target_dir, file_name) + + cache = await self._scanner.get_cached_data() + cache.raw_data = [item for item in cache.raw_data if item["file_path"] != file_path] + await cache.resort() + + if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index: + self._scanner._hash_index.remove_by_path(file_path) + + return {"success": True, "deleted_files": deleted_files} + + async def exclude_model(self, file_path: str) -> Dict[str, object]: + """Mark a model as excluded and prune cache references.""" + + if not file_path: + raise ValueError("Model path is required") + + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" + metadata = await self._metadata_loader(metadata_path) + metadata["exclude"] = True + + await self._metadata_manager.save_metadata(file_path, metadata) + + cache = await self._scanner.get_cached_data() + model_to_remove = next( + (item for item in cache.raw_data if item["file_path"] == file_path), + None, + ) + + if model_to_remove: + for tag in model_to_remove.get("tags", []): + if tag in getattr(self._scanner, "_tags_count", {}): + self._scanner._tags_count[tag] = max( + 0, self._scanner._tags_count[tag] - 1 + ) + if self._scanner._tags_count[tag] == 0: + del self._scanner._tags_count[tag] + + if hasattr(self._scanner, "_hash_index") and self._scanner._hash_index: + self._scanner._hash_index.remove_by_path(file_path) + + cache.raw_data = [ + item for item in cache.raw_data if item["file_path"] != file_path + ] + await cache.resort() + + excluded = getattr(self._scanner, "_excluded_models", None) + if isinstance(excluded, list): + excluded.append(file_path) + + message = f"Model {os.path.basename(file_path)} excluded" + return {"success": True, "message": message} + + async def bulk_delete_models(self, file_paths: Iterable[str]) -> Dict[str, object]: + """Delete a collection of models via the scanner bulk operation.""" + + file_paths = list(file_paths) + if not file_paths: + raise ValueError("No file paths provided for deletion") + + return await self._scanner.bulk_delete_models(file_paths) + + async def rename_model( + self, *, file_path: str, new_file_name: str + ) -> Dict[str, object]: + """Rename a model and its companion artefacts.""" + + if not file_path or not new_file_name: + raise ValueError("File path and new file name are required") + + invalid_chars = {"/", "\\", ":", "*", "?", '"', "<", ">", "|"} + if any(char in new_file_name for char in invalid_chars): + raise ValueError("Invalid characters in file name") + + target_dir = os.path.dirname(file_path) + old_file_name = os.path.splitext(os.path.basename(file_path))[0] + new_file_path = os.path.join(target_dir, f"{new_file_name}.safetensors").replace( + os.sep, "/" + ) + + if os.path.exists(new_file_path): + raise ValueError("A file with this name already exists") + + patterns = [ + f"{old_file_name}.safetensors", + f"{old_file_name}.metadata.json", + f"{old_file_name}.metadata.json.bak", + ] + for ext in PREVIEW_EXTENSIONS: + patterns.append(f"{old_file_name}{ext}") + + existing_files: List[tuple[str, str]] = [] + for pattern in patterns: + path = os.path.join(target_dir, pattern) + if os.path.exists(path): + existing_files.append((path, pattern)) + + metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json") + metadata: Optional[Dict[str, object]] = None + hash_value: Optional[str] = None + + if os.path.exists(metadata_path): + metadata = await self._metadata_loader(metadata_path) + hash_value = metadata.get("sha256") if isinstance(metadata, dict) else None + + renamed_files: List[str] = [] + new_metadata_path: Optional[str] = None + new_preview: Optional[str] = None + + for old_path, pattern in existing_files: + ext = self._get_multipart_ext(pattern) + new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace( + os.sep, "/" + ) + os.rename(old_path, new_path) + renamed_files.append(new_path) + + if ext == ".metadata.json": + new_metadata_path = new_path + + if metadata and new_metadata_path: + metadata["file_name"] = new_file_name + metadata["file_path"] = new_file_path + + if metadata.get("preview_url"): + old_preview = str(metadata["preview_url"]) + ext = self._get_multipart_ext(old_preview) + new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace( + os.sep, "/" + ) + metadata["preview_url"] = new_preview + + await self._metadata_manager.save_metadata(new_file_path, metadata) + + if metadata: + await self._scanner.update_single_model_cache( + file_path, new_file_path, metadata + ) + + if hash_value and getattr(self._scanner, "model_type", "") == "lora": + recipe_scanner = await self._recipe_scanner_factory() + if recipe_scanner: + try: + await recipe_scanner.update_lora_filename_by_hash( + hash_value, new_file_name + ) + except Exception as exc: # pragma: no cover - defensive logging + logger.error( + "Error updating recipe references for %s: %s", + file_path, + exc, + ) + + return { + "success": True, + "new_file_path": new_file_path, + "new_preview_path": new_preview, + "renamed_files": renamed_files, + "reload_required": False, + } + + @staticmethod + def _get_multipart_ext(filename: str) -> str: + """Return the extension for files with compound suffixes.""" + + parts = filename.split(".") + if len(parts) == 3: + return "." + ".".join(parts[-2:]) + if len(parts) >= 4: + return "." + ".".join(parts[-3:]) + return os.path.splitext(filename)[1] + diff --git a/py/services/model_query.py b/py/services/model_query.py new file mode 100644 index 00000000..df7bb67a --- /dev/null +++ b/py/services/model_query.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Protocol, Callable + +from ..utils.constants import NSFW_LEVELS +from ..utils.utils import fuzzy_match as default_fuzzy_match + + +class SettingsProvider(Protocol): + """Protocol describing the SettingsManager contract used by query helpers.""" + + def get(self, key: str, default: Any = None) -> Any: + ... + + +@dataclass(frozen=True) +class SortParams: + """Normalized representation of sorting instructions.""" + + key: str + order: str + + +@dataclass(frozen=True) +class FilterCriteria: + """Container for model list filtering options.""" + + folder: Optional[str] = None + base_models: Optional[Sequence[str]] = None + tags: Optional[Sequence[str]] = None + favorites_only: bool = False + search_options: Optional[Dict[str, Any]] = None + + +class ModelCacheRepository: + """Adapter around scanner cache access and sort normalisation.""" + + def __init__(self, scanner) -> None: + self._scanner = scanner + + async def get_cache(self): + """Return the underlying cache instance from the scanner.""" + return await self._scanner.get_cached_data() + + async def fetch_sorted(self, params: SortParams) -> List[Dict[str, Any]]: + """Fetch cached data pre-sorted according to ``params``.""" + cache = await self.get_cache() + return await cache.get_sorted_data(params.key, params.order) + + @staticmethod + def parse_sort(sort_by: str) -> SortParams: + """Parse an incoming sort string into key/order primitives.""" + if not sort_by: + return SortParams(key="name", order="asc") + + if ":" in sort_by: + raw_key, raw_order = sort_by.split(":", 1) + sort_key = raw_key.strip().lower() or "name" + order = raw_order.strip().lower() + else: + sort_key = sort_by.strip().lower() or "name" + order = "asc" + + if order not in ("asc", "desc"): + order = "asc" + + return SortParams(key=sort_key, order=order) + + +class ModelFilterSet: + """Applies common filtering rules to the model collection.""" + + def __init__(self, settings: SettingsProvider, nsfw_levels: Optional[Dict[str, int]] = None) -> None: + self._settings = settings + self._nsfw_levels = nsfw_levels or NSFW_LEVELS + + def apply(self, data: Iterable[Dict[str, Any]], criteria: FilterCriteria) -> List[Dict[str, Any]]: + """Return items that satisfy the provided criteria.""" + items = list(data) + + if self._settings.get("show_only_sfw", False): + threshold = self._nsfw_levels.get("R", 0) + items = [ + item for item in items + if not item.get("preview_nsfw_level") or item.get("preview_nsfw_level") < threshold + ] + + if criteria.favorites_only: + items = [item for item in items if item.get("favorite", False)] + + folder = criteria.folder + options = criteria.search_options or {} + recursive = bool(options.get("recursive", True)) + if folder is not None: + if recursive: + if folder: + folder_with_sep = f"{folder}/" + items = [ + item for item in items + if item.get("folder") == folder or item.get("folder", "").startswith(folder_with_sep) + ] + else: + items = [item for item in items if item.get("folder") == folder] + + base_models = criteria.base_models or [] + if base_models: + base_model_set = set(base_models) + items = [item for item in items if item.get("base_model") in base_model_set] + + tags = criteria.tags or [] + if tags: + tag_set = set(tags) + items = [ + item for item in items + if any(tag in tag_set for tag in item.get("tags", [])) + ] + + return items + + +class SearchStrategy: + """Encapsulates text and fuzzy matching behaviour for model queries.""" + + DEFAULT_OPTIONS: Dict[str, Any] = { + "filename": True, + "modelname": True, + "tags": False, + "recursive": True, + "creator": False, + } + + def __init__(self, fuzzy_matcher: Optional[Callable[[str, str], bool]] = None) -> None: + self._fuzzy_match = fuzzy_matcher or default_fuzzy_match + + def normalize_options(self, options: Optional[Dict[str, Any]]) -> Dict[str, Any]: + """Merge provided options with defaults without mutating input.""" + normalized = dict(self.DEFAULT_OPTIONS) + if options: + normalized.update(options) + return normalized + + def apply( + self, + data: Iterable[Dict[str, Any]], + search_term: str, + options: Dict[str, Any], + fuzzy: bool = False, + ) -> List[Dict[str, Any]]: + """Return items matching the search term using the configured strategy.""" + if not search_term: + return list(data) + + search_lower = search_term.lower() + results: List[Dict[str, Any]] = [] + + for item in data: + if options.get("filename", True): + candidate = item.get("file_name", "") + if self._matches(candidate, search_term, search_lower, fuzzy): + results.append(item) + continue + + if options.get("modelname", True): + candidate = item.get("model_name", "") + if self._matches(candidate, search_term, search_lower, fuzzy): + results.append(item) + continue + + if options.get("tags", False): + tags = item.get("tags", []) or [] + if any(self._matches(tag, search_term, search_lower, fuzzy) for tag in tags): + results.append(item) + continue + + if options.get("creator", False): + creator_username = "" + civitai = item.get("civitai") + if isinstance(civitai, dict): + creator = civitai.get("creator") + if isinstance(creator, dict): + creator_username = creator.get("username", "") + if creator_username and self._matches(creator_username, search_term, search_lower, fuzzy): + results.append(item) + continue + + return results + + def _matches(self, candidate: str, search_term: str, search_lower: str, fuzzy: bool) -> bool: + if not candidate: + return False + + candidate_lower = candidate.lower() + if fuzzy: + return self._fuzzy_match(candidate, search_term) + return search_lower in candidate_lower diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index f0ae3177..51aa4507 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -13,6 +13,7 @@ from ..utils.metadata_manager import MetadataManager from .model_cache import ModelCache from .model_hash_index import ModelHashIndex from ..utils.constants import PREVIEW_EXTENSIONS +from .model_lifecycle_service import delete_model_artifacts from .service_registry import ServiceRegistry from .websocket_manager import ws_manager @@ -1040,10 +1041,8 @@ class ModelScanner: target_dir = os.path.dirname(file_path) file_name = os.path.splitext(os.path.basename(file_path))[0] - # Delete all associated files for the model - from ..utils.routes_common import ModelRouteUtils - deleted_files = await ModelRouteUtils.delete_model_files( - target_dir, + deleted_files = await delete_model_artifacts( + target_dir, file_name ) diff --git a/py/services/preview_asset_service.py b/py/services/preview_asset_service.py new file mode 100644 index 00000000..42baadac --- /dev/null +++ b/py/services/preview_asset_service.py @@ -0,0 +1,168 @@ +"""Service for processing preview assets for models.""" + +from __future__ import annotations + +import logging +import os +from typing import Awaitable, Callable, Dict, Optional, Sequence + +from ..utils.constants import CARD_PREVIEW_WIDTH, PREVIEW_EXTENSIONS + +logger = logging.getLogger(__name__) + + +class PreviewAssetService: + """Manage fetching and persisting preview assets.""" + + def __init__( + self, + *, + metadata_manager, + downloader_factory: Callable[[], Awaitable], + exif_utils, + ) -> None: + self._metadata_manager = metadata_manager + self._downloader_factory = downloader_factory + self._exif_utils = exif_utils + + async def ensure_preview_for_metadata( + self, + metadata_path: str, + local_metadata: Dict[str, object], + images: Sequence[Dict[str, object]] | None, + ) -> None: + """Ensure preview assets exist for the supplied metadata entry.""" + + if local_metadata.get("preview_url") and os.path.exists( + str(local_metadata["preview_url"]) + ): + return + + if not images: + return + + first_preview = images[0] + base_name = os.path.splitext(os.path.splitext(os.path.basename(metadata_path))[0])[0] + preview_dir = os.path.dirname(metadata_path) + is_video = first_preview.get("type") == "video" + + if is_video: + extension = ".mp4" + preview_path = os.path.join(preview_dir, base_name + extension) + downloader = await self._downloader_factory() + success, result = await downloader.download_file( + first_preview["url"], preview_path, use_auth=False + ) + if success: + local_metadata["preview_url"] = preview_path.replace(os.sep, "/") + local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0) + else: + extension = ".webp" + preview_path = os.path.join(preview_dir, base_name + extension) + downloader = await self._downloader_factory() + success, content, _headers = await downloader.download_to_memory( + first_preview["url"], use_auth=False + ) + if not success: + return + + try: + optimized_data, _ = self._exif_utils.optimize_image( + image_data=content, + target_width=CARD_PREVIEW_WIDTH, + format="webp", + quality=85, + preserve_metadata=False, + ) + with open(preview_path, "wb") as handle: + handle.write(optimized_data) + except Exception as exc: # pragma: no cover - defensive path + logger.error("Error optimizing preview image: %s", exc) + try: + with open(preview_path, "wb") as handle: + handle.write(content) + except Exception as save_exc: + logger.error("Error saving preview image: %s", save_exc) + return + + local_metadata["preview_url"] = preview_path.replace(os.sep, "/") + local_metadata["preview_nsfw_level"] = first_preview.get("nsfwLevel", 0) + + async def replace_preview( + self, + *, + model_path: str, + preview_data: bytes, + content_type: str, + original_filename: Optional[str], + nsfw_level: int, + update_preview_in_cache: Callable[[str, str, int], Awaitable[bool]], + metadata_loader: Callable[[str], Awaitable[Dict[str, object]]], + ) -> Dict[str, object]: + """Replace an existing preview asset for a model.""" + + base_name = os.path.splitext(os.path.basename(model_path))[0] + folder = os.path.dirname(model_path) + + extension, optimized_data = await self._convert_preview( + preview_data, content_type, original_filename + ) + + for ext in PREVIEW_EXTENSIONS: + existing_preview = os.path.join(folder, base_name + ext) + if os.path.exists(existing_preview): + try: + os.remove(existing_preview) + except Exception as exc: # pragma: no cover - defensive path + logger.warning( + "Failed to delete existing preview %s: %s", existing_preview, exc + ) + + preview_path = os.path.join(folder, base_name + extension).replace(os.sep, "/") + with open(preview_path, "wb") as handle: + handle.write(optimized_data) + + metadata_path = os.path.splitext(model_path)[0] + ".metadata.json" + metadata = await metadata_loader(metadata_path) + metadata["preview_url"] = preview_path + metadata["preview_nsfw_level"] = nsfw_level + await self._metadata_manager.save_metadata(model_path, metadata) + + await update_preview_in_cache(model_path, preview_path, nsfw_level) + + return {"preview_path": preview_path, "preview_nsfw_level": nsfw_level} + + async def _convert_preview( + self, data: bytes, content_type: str, original_filename: Optional[str] + ) -> tuple[str, bytes]: + """Convert preview bytes to the persisted representation.""" + + if content_type.startswith("video/"): + extension = self._resolve_video_extension(content_type, original_filename) + return extension, data + + original_ext = (original_filename or "").lower() + if original_ext.endswith(".gif") or content_type.lower() == "image/gif": + return ".gif", data + + optimized_data, _ = self._exif_utils.optimize_image( + image_data=data, + target_width=CARD_PREVIEW_WIDTH, + format="webp", + quality=85, + preserve_metadata=False, + ) + return ".webp", optimized_data + + def _resolve_video_extension(self, content_type: str, original_filename: Optional[str]) -> str: + """Infer the best extension for a video preview.""" + + if original_filename: + extension = os.path.splitext(original_filename)[1].lower() + if extension in {".mp4", ".webm", ".mov", ".avi"}: + return extension + + if "webm" in content_type: + return ".webm" + return ".mp4" + diff --git a/py/services/recipe_cache.py b/py/services/recipe_cache.py index b1f52246..ac28b3aa 100644 --- a/py/services/recipe_cache.py +++ b/py/services/recipe_cache.py @@ -1,5 +1,5 @@ import asyncio -from typing import List, Dict +from typing import Iterable, List, Dict, Optional from dataclasses import dataclass from operator import itemgetter from natsort import natsorted @@ -10,77 +10,115 @@ class RecipeCache: raw_data: List[Dict] sorted_by_name: List[Dict] sorted_by_date: List[Dict] - + def __post_init__(self): self._lock = asyncio.Lock() async def resort(self, name_only: bool = False): """Resort all cached data views""" async with self._lock: - self.sorted_by_name = natsorted( - self.raw_data, - key=lambda x: x.get('title', '').lower() # Case-insensitive sort - ) - if not name_only: - self.sorted_by_date = sorted( - self.raw_data, - key=itemgetter('created_date', 'file_path'), - reverse=True - ) - - async def update_recipe_metadata(self, recipe_id: str, metadata: Dict) -> bool: + self._resort_locked(name_only=name_only) + + async def update_recipe_metadata(self, recipe_id: str, metadata: Dict, *, resort: bool = True) -> bool: """Update metadata for a specific recipe in all cached data - + Args: recipe_id: The ID of the recipe to update metadata: The new metadata - + Returns: bool: True if the update was successful, False if the recipe wasn't found """ + async with self._lock: + for item in self.raw_data: + if str(item.get('id')) == str(recipe_id): + item.update(metadata) + if resort: + self._resort_locked() + return True + return False # Recipe not found + + async def add_recipe(self, recipe_data: Dict, *, resort: bool = False) -> None: + """Add a new recipe to the cache.""" - # Update in raw_data - for item in self.raw_data: - if item.get('id') == recipe_id: - item.update(metadata) - break - else: - return False # Recipe not found - - # Resort to reflect changes - await self.resort() - return True - - async def add_recipe(self, recipe_data: Dict) -> None: - """Add a new recipe to the cache - - Args: - recipe_data: The recipe data to add - """ async with self._lock: self.raw_data.append(recipe_data) - await self.resort() + if resort: + self._resort_locked() + + async def remove_recipe(self, recipe_id: str, *, resort: bool = False) -> Optional[Dict]: + """Remove a recipe from the cache by ID. - async def remove_recipe(self, recipe_id: str) -> bool: - """Remove a recipe from the cache by ID - Args: recipe_id: The ID of the recipe to remove - + Returns: - bool: True if the recipe was found and removed, False otherwise + The removed recipe data if found, otherwise ``None``. """ - # Find the recipe in raw_data - recipe_index = next((i for i, recipe in enumerate(self.raw_data) - if recipe.get('id') == recipe_id), None) - - if recipe_index is None: - return False - - # Remove from raw_data - self.raw_data.pop(recipe_index) - - # Resort to update sorted lists - await self.resort() - - return True \ No newline at end of file + + async with self._lock: + for index, recipe in enumerate(self.raw_data): + if str(recipe.get('id')) == str(recipe_id): + removed = self.raw_data.pop(index) + if resort: + self._resort_locked() + return removed + return None + + async def bulk_remove(self, recipe_ids: Iterable[str], *, resort: bool = False) -> List[Dict]: + """Remove multiple recipes from the cache.""" + + id_set = {str(recipe_id) for recipe_id in recipe_ids} + if not id_set: + return [] + + async with self._lock: + removed = [item for item in self.raw_data if str(item.get('id')) in id_set] + if not removed: + return [] + + self.raw_data = [item for item in self.raw_data if str(item.get('id')) not in id_set] + if resort: + self._resort_locked() + return removed + + async def replace_recipe(self, recipe_id: str, new_data: Dict, *, resort: bool = False) -> bool: + """Replace cached data for a recipe.""" + + async with self._lock: + for index, recipe in enumerate(self.raw_data): + if str(recipe.get('id')) == str(recipe_id): + self.raw_data[index] = new_data + if resort: + self._resort_locked() + return True + return False + + async def get_recipe(self, recipe_id: str) -> Optional[Dict]: + """Return a shallow copy of a cached recipe.""" + + async with self._lock: + for recipe in self.raw_data: + if str(recipe.get('id')) == str(recipe_id): + return dict(recipe) + return None + + async def snapshot(self) -> List[Dict]: + """Return a copy of all cached recipes.""" + + async with self._lock: + return [dict(item) for item in self.raw_data] + + def _resort_locked(self, *, name_only: bool = False) -> None: + """Sort cached views. Caller must hold ``_lock``.""" + + self.sorted_by_name = natsorted( + self.raw_data, + key=lambda x: x.get('title', '').lower() + ) + if not name_only: + self.sorted_by_date = sorted( + self.raw_data, + key=itemgetter('created_date', 'file_path'), + reverse=True + ) \ No newline at end of file diff --git a/py/services/recipe_scanner.py b/py/services/recipe_scanner.py index ca5a20ac..9a82b237 100644 --- a/py/services/recipe_scanner.py +++ b/py/services/recipe_scanner.py @@ -3,13 +3,14 @@ import logging import asyncio import json import time -from typing import List, Dict, Optional, Any, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from ..config import config from .recipe_cache import RecipeCache from .service_registry import ServiceRegistry from .lora_scanner import LoraScanner from .metadata_service import get_default_metadata_provider -from ..utils.utils import fuzzy_match +from .recipes.errors import RecipeNotFoundError +from ..utils.utils import calculate_recipe_fingerprint, fuzzy_match from natsort import natsorted import sys @@ -46,6 +47,8 @@ class RecipeScanner: self._initialization_lock = asyncio.Lock() self._initialization_task: Optional[asyncio.Task] = None self._is_initializing = False + self._mutation_lock = asyncio.Lock() + self._resort_tasks: Set[asyncio.Task] = set() if lora_scanner: self._lora_scanner = lora_scanner self._initialized = True @@ -191,6 +194,22 @@ class RecipeScanner: # Clean up the event loop loop.close() + def _schedule_resort(self, *, name_only: bool = False) -> None: + """Schedule a background resort of the recipe cache.""" + + if not self._cache: + return + + async def _resort_wrapper() -> None: + try: + await self._cache.resort(name_only=name_only) + except Exception as exc: # pragma: no cover - defensive logging + logger.error("Recipe Scanner: error resorting cache: %s", exc, exc_info=True) + + task = asyncio.create_task(_resort_wrapper()) + self._resort_tasks.add(task) + task.add_done_callback(lambda finished: self._resort_tasks.discard(finished)) + @property def recipes_dir(self) -> str: """Get path to recipes directory""" @@ -255,7 +274,45 @@ class RecipeScanner: # Return the cache (may be empty or partially initialized) return self._cache or RecipeCache(raw_data=[], sorted_by_name=[], sorted_by_date=[]) - + + async def refresh_cache(self, force: bool = False) -> RecipeCache: + """Public helper to refresh or return the recipe cache.""" + + return await self.get_cached_data(force_refresh=force) + + async def add_recipe(self, recipe_data: Dict[str, Any]) -> None: + """Add a recipe to the in-memory cache.""" + + if not recipe_data: + return + + cache = await self.get_cached_data() + await cache.add_recipe(recipe_data, resort=False) + self._schedule_resort() + + async def remove_recipe(self, recipe_id: str) -> bool: + """Remove a recipe from the cache by ID.""" + + if not recipe_id: + return False + + cache = await self.get_cached_data() + removed = await cache.remove_recipe(recipe_id, resort=False) + if removed is None: + return False + + self._schedule_resort() + return True + + async def bulk_remove(self, recipe_ids: Iterable[str]) -> int: + """Remove multiple recipes from the cache.""" + + cache = await self.get_cached_data() + removed = await cache.bulk_remove(recipe_ids, resort=False) + if removed: + self._schedule_resort() + return len(removed) + async def scan_all_recipes(self) -> List[Dict]: """Scan all recipe JSON files and return metadata""" recipes = [] @@ -326,7 +383,6 @@ class RecipeScanner: # Calculate and update fingerprint if missing if 'loras' in recipe_data and 'fingerprint' not in recipe_data: - from ..utils.utils import calculate_recipe_fingerprint fingerprint = calculate_recipe_fingerprint(recipe_data['loras']) recipe_data['fingerprint'] = fingerprint @@ -497,9 +553,36 @@ class RecipeScanner: logger.error(f"Error getting base model for lora: {e}") return None + def _enrich_lora_entry(self, lora: Dict[str, Any]) -> Dict[str, Any]: + """Populate convenience fields for a LoRA entry.""" + + if not lora or not self._lora_scanner: + return lora + + hash_value = (lora.get('hash') or '').lower() + if not hash_value: + return lora + + try: + lora['inLibrary'] = self._lora_scanner.has_hash(hash_value) + lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(hash_value) + lora['localPath'] = self._lora_scanner.get_path_by_hash(hash_value) + except Exception as exc: # pragma: no cover - defensive logging + logger.debug("Error enriching lora entry %s: %s", hash_value, exc) + + return lora + + async def get_local_lora(self, name: str) -> Optional[Dict[str, Any]]: + """Lookup a local LoRA model by name.""" + + if not self._lora_scanner or not name: + return None + + return await self._lora_scanner.get_model_info_by_name(name) + async def get_paginated_data(self, page: int, page_size: int, sort_by: str = 'date', search: str = None, filters: dict = None, search_options: dict = None, lora_hash: str = None, bypass_filters: bool = True): """Get paginated and filtered recipe data - + Args: page: Current page number (1-based) page_size: Number of items per page @@ -598,16 +681,12 @@ class RecipeScanner: # Get paginated items paginated_items = filtered_data[start_idx:end_idx] - + # Add inLibrary information for each lora for item in paginated_items: if 'loras' in item: - for lora in item['loras']: - if 'hash' in lora and lora['hash']: - lora['inLibrary'] = self._lora_scanner.has_hash(lora['hash'].lower()) - lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora['hash'].lower()) - lora['localPath'] = self._lora_scanner.get_path_by_hash(lora['hash'].lower()) - + item['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in item['loras']] + result = { 'items': paginated_items, 'total': total_items, @@ -653,13 +732,8 @@ class RecipeScanner: # Add lora metadata if 'loras' in formatted_recipe: - for lora in formatted_recipe['loras']: - if 'hash' in lora and lora['hash']: - lora_hash = lora['hash'].lower() - lora['inLibrary'] = self._lora_scanner.has_hash(lora_hash) - lora['preview_url'] = self._lora_scanner.get_preview_url_by_hash(lora_hash) - lora['localPath'] = self._lora_scanner.get_path_by_hash(lora_hash) - + formatted_recipe['loras'] = [self._enrich_lora_entry(dict(lora)) for lora in formatted_recipe['loras']] + return formatted_recipe def _format_file_url(self, file_path: str) -> str: @@ -717,26 +791,159 @@ class RecipeScanner: # Save updated recipe with open(recipe_json_path, 'w', encoding='utf-8') as f: json.dump(recipe_data, f, indent=4, ensure_ascii=False) - + # Update the cache if it exists if self._cache is not None: - await self._cache.update_recipe_metadata(recipe_id, metadata) - + await self._cache.update_recipe_metadata(recipe_id, metadata, resort=False) + self._schedule_resort() + # If the recipe has an image, update its EXIF metadata from ..utils.exif_utils import ExifUtils image_path = recipe_data.get('file_path') if image_path and os.path.exists(image_path): ExifUtils.append_recipe_metadata(image_path, recipe_data) - + return True except Exception as e: import logging logging.getLogger(__name__).error(f"Error updating recipe metadata: {e}", exc_info=True) return False + async def update_lora_entry( + self, + recipe_id: str, + lora_index: int, + *, + target_name: str, + target_lora: Optional[Dict[str, Any]] = None, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Update a specific LoRA entry within a recipe. + + Returns the updated recipe data and the refreshed LoRA metadata. + """ + + if target_name is None: + raise ValueError("target_name must be provided") + + recipe_json_path = os.path.join(self.recipes_dir, f"{recipe_id}.recipe.json") + if not os.path.exists(recipe_json_path): + raise RecipeNotFoundError("Recipe not found") + + async with self._mutation_lock: + with open(recipe_json_path, 'r', encoding='utf-8') as file_obj: + recipe_data = json.load(file_obj) + + loras = recipe_data.get('loras', []) + if lora_index >= len(loras): + raise RecipeNotFoundError("LoRA index out of range in recipe") + + lora_entry = loras[lora_index] + lora_entry['isDeleted'] = False + lora_entry['exclude'] = False + lora_entry['file_name'] = target_name + + if target_lora is not None: + sha_value = target_lora.get('sha256') or target_lora.get('sha') + if sha_value: + lora_entry['hash'] = sha_value.lower() + + civitai_info = target_lora.get('civitai') or {} + if civitai_info: + lora_entry['modelName'] = civitai_info.get('model', {}).get('name', '') + lora_entry['modelVersionName'] = civitai_info.get('name', '') + lora_entry['modelVersionId'] = civitai_info.get('id') + + recipe_data['fingerprint'] = calculate_recipe_fingerprint(recipe_data.get('loras', [])) + recipe_data['modified'] = time.time() + + with open(recipe_json_path, 'w', encoding='utf-8') as file_obj: + json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) + + cache = await self.get_cached_data() + replaced = await cache.replace_recipe(recipe_id, recipe_data, resort=False) + if not replaced: + await cache.add_recipe(recipe_data, resort=False) + self._schedule_resort() + + updated_lora = dict(lora_entry) + if target_lora is not None: + preview_url = target_lora.get('preview_url') + if preview_url: + updated_lora['preview_url'] = config.get_preview_static_url(preview_url) + if target_lora.get('file_path'): + updated_lora['localPath'] = target_lora['file_path'] + + updated_lora = self._enrich_lora_entry(updated_lora) + return recipe_data, updated_lora + + async def get_recipes_for_lora(self, lora_hash: str) -> List[Dict[str, Any]]: + """Return recipes that reference a given LoRA hash.""" + + if not lora_hash: + return [] + + normalized_hash = lora_hash.lower() + cache = await self.get_cached_data() + matching_recipes: List[Dict[str, Any]] = [] + + for recipe in cache.raw_data: + loras = recipe.get('loras', []) + if any((entry.get('hash') or '').lower() == normalized_hash for entry in loras): + recipe_copy = {**recipe} + recipe_copy['loras'] = [self._enrich_lora_entry(dict(entry)) for entry in loras] + recipe_copy['file_url'] = self._format_file_url(recipe.get('file_path')) + matching_recipes.append(recipe_copy) + + return matching_recipes + + async def get_recipe_syntax_tokens(self, recipe_id: str) -> List[str]: + """Build LoRA syntax tokens for a recipe.""" + + cache = await self.get_cached_data() + recipe = await cache.get_recipe(recipe_id) + if recipe is None: + raise RecipeNotFoundError("Recipe not found") + + loras = recipe.get('loras', []) + if not loras: + return [] + + lora_cache = None + if self._lora_scanner is not None: + lora_cache = await self._lora_scanner.get_cached_data() + + syntax_parts: List[str] = [] + for lora in loras: + if lora.get('isDeleted', False): + continue + + file_name = None + hash_value = (lora.get('hash') or '').lower() + if hash_value and self._lora_scanner is not None and hasattr(self._lora_scanner, '_hash_index'): + file_path = self._lora_scanner._hash_index.get_path(hash_value) + if file_path: + file_name = os.path.splitext(os.path.basename(file_path))[0] + + if not file_name and lora.get('modelVersionId') and lora_cache is not None: + for cached_lora in getattr(lora_cache, 'raw_data', []): + civitai_info = cached_lora.get('civitai') + if civitai_info and civitai_info.get('id') == lora.get('modelVersionId'): + cached_path = cached_lora.get('path') or cached_lora.get('file_path') + if cached_path: + file_name = os.path.splitext(os.path.basename(cached_path))[0] + break + + if not file_name: + file_name = lora.get('file_name', 'unknown-lora') + + strength = lora.get('strength', 1.0) + syntax_parts.append(f"") + + return syntax_parts + async def update_lora_filename_by_hash(self, hash_value: str, new_file_name: str) -> Tuple[int, int]: """Update file_name in all recipes that contain a LoRA with the specified hash. - + Args: hash_value: The SHA256 hash value of the LoRA new_file_name: The new file_name to set diff --git a/py/services/recipes/__init__.py b/py/services/recipes/__init__.py new file mode 100644 index 00000000..8009b7c3 --- /dev/null +++ b/py/services/recipes/__init__.py @@ -0,0 +1,23 @@ +"""Recipe service layer implementations.""" + +from .analysis_service import RecipeAnalysisService +from .persistence_service import RecipePersistenceService +from .sharing_service import RecipeSharingService +from .errors import ( + RecipeServiceError, + RecipeValidationError, + RecipeNotFoundError, + RecipeDownloadError, + RecipeConflictError, +) + +__all__ = [ + "RecipeAnalysisService", + "RecipePersistenceService", + "RecipeSharingService", + "RecipeServiceError", + "RecipeValidationError", + "RecipeNotFoundError", + "RecipeDownloadError", + "RecipeConflictError", +] diff --git a/py/services/recipes/analysis_service.py b/py/services/recipes/analysis_service.py new file mode 100644 index 00000000..77d80e34 --- /dev/null +++ b/py/services/recipes/analysis_service.py @@ -0,0 +1,289 @@ +"""Services responsible for recipe metadata analysis.""" +from __future__ import annotations + +import base64 +import io +import os +import re +import tempfile +from dataclasses import dataclass +from typing import Any, Callable, Optional + +import numpy as np +from PIL import Image + +from ...utils.utils import calculate_recipe_fingerprint +from .errors import ( + RecipeDownloadError, + RecipeNotFoundError, + RecipeServiceError, + RecipeValidationError, +) + + +@dataclass(frozen=True) +class AnalysisResult: + """Return payload from analysis operations.""" + + payload: dict[str, Any] + status: int = 200 + + +class RecipeAnalysisService: + """Extract recipe metadata from various image sources.""" + + def __init__( + self, + *, + exif_utils, + recipe_parser_factory, + downloader_factory: Callable[[], Any], + metadata_collector: Optional[Callable[[], Any]] = None, + metadata_processor_cls: Optional[type] = None, + metadata_registry_cls: Optional[type] = None, + standalone_mode: bool = False, + logger, + ) -> None: + self._exif_utils = exif_utils + self._recipe_parser_factory = recipe_parser_factory + self._downloader_factory = downloader_factory + self._metadata_collector = metadata_collector + self._metadata_processor_cls = metadata_processor_cls + self._metadata_registry_cls = metadata_registry_cls + self._standalone_mode = standalone_mode + self._logger = logger + + async def analyze_uploaded_image( + self, + *, + image_bytes: bytes | None, + recipe_scanner, + ) -> AnalysisResult: + """Analyze an uploaded image payload.""" + + if not image_bytes: + raise RecipeValidationError("No image data provided") + + temp_path = self._write_temp_file(image_bytes) + try: + metadata = self._exif_utils.extract_image_metadata(temp_path) + if not metadata: + return AnalysisResult({"error": "No metadata found in this image", "loras": []}) + + return await self._parse_metadata( + metadata, + recipe_scanner=recipe_scanner, + image_path=None, + include_image_base64=False, + ) + finally: + self._safe_cleanup(temp_path) + + async def analyze_remote_image( + self, + *, + url: str | None, + recipe_scanner, + civitai_client, + ) -> AnalysisResult: + """Analyze an image accessible via URL, including Civitai integration.""" + + if not url: + raise RecipeValidationError("No URL provided") + + if civitai_client is None: + raise RecipeServiceError("Civitai client unavailable") + + temp_path = self._create_temp_path() + metadata: Optional[dict[str, Any]] = None + try: + civitai_match = re.match(r"https://civitai\.com/images/(\d+)", url) + if civitai_match: + image_info = await civitai_client.get_image_info(civitai_match.group(1)) + if not image_info: + raise RecipeDownloadError("Failed to fetch image information from Civitai") + image_url = image_info.get("url") + if not image_url: + raise RecipeDownloadError("No image URL found in Civitai response") + await self._download_image(image_url, temp_path) + metadata = image_info.get("meta") if "meta" in image_info else None + else: + await self._download_image(url, temp_path) + + if metadata is None: + metadata = self._exif_utils.extract_image_metadata(temp_path) + + if not metadata: + return self._metadata_not_found_response(temp_path) + + return await self._parse_metadata( + metadata, + recipe_scanner=recipe_scanner, + image_path=temp_path, + include_image_base64=True, + ) + finally: + self._safe_cleanup(temp_path) + + async def analyze_local_image( + self, + *, + file_path: str | None, + recipe_scanner, + ) -> AnalysisResult: + """Analyze a file already present on disk.""" + + if not file_path: + raise RecipeValidationError("No file path provided") + + normalized_path = os.path.normpath(file_path.strip('"').strip("'")) + if not os.path.isfile(normalized_path): + raise RecipeNotFoundError("File not found") + + metadata = self._exif_utils.extract_image_metadata(normalized_path) + if not metadata: + return self._metadata_not_found_response(normalized_path) + + return await self._parse_metadata( + metadata, + recipe_scanner=recipe_scanner, + image_path=normalized_path, + include_image_base64=True, + ) + + async def analyze_widget_metadata(self, *, recipe_scanner) -> AnalysisResult: + """Analyse the most recent generation metadata for widget saves.""" + + if self._metadata_collector is None or self._metadata_processor_cls is None: + raise RecipeValidationError("Metadata collection not available") + + raw_metadata = self._metadata_collector() + metadata_dict = self._metadata_processor_cls.to_dict(raw_metadata) + if not metadata_dict: + raise RecipeValidationError("No generation metadata found") + + latest_image = None + if not self._standalone_mode and self._metadata_registry_cls is not None: + metadata_registry = self._metadata_registry_cls() + latest_image = metadata_registry.get_first_decoded_image() + + if latest_image is None: + raise RecipeValidationError( + "No recent images found to use for recipe. Try generating an image first." + ) + + image_bytes = self._convert_tensor_to_png_bytes(latest_image) + if image_bytes is None: + raise RecipeValidationError("Cannot handle this data shape from metadata registry") + + return AnalysisResult( + { + "metadata": metadata_dict, + "image_bytes": image_bytes, + } + ) + + # Internal helpers ------------------------------------------------- + + async def _parse_metadata( + self, + metadata: dict[str, Any], + *, + recipe_scanner, + image_path: Optional[str], + include_image_base64: bool, + ) -> AnalysisResult: + parser = self._recipe_parser_factory.create_parser(metadata) + if parser is None: + payload = {"error": "No parser found for this image", "loras": []} + if include_image_base64 and image_path: + payload["image_base64"] = self._encode_file(image_path) + return AnalysisResult(payload) + + result = await parser.parse_metadata(metadata, recipe_scanner=recipe_scanner) + + if include_image_base64 and image_path: + result["image_base64"] = self._encode_file(image_path) + + if "error" in result and not result.get("loras"): + return AnalysisResult(result) + + fingerprint = calculate_recipe_fingerprint(result.get("loras", [])) + result["fingerprint"] = fingerprint + + matching_recipes: list[str] = [] + if fingerprint: + matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(fingerprint) + result["matching_recipes"] = matching_recipes + + return AnalysisResult(result) + + async def _download_image(self, url: str, temp_path: str) -> None: + downloader = await self._downloader_factory() + success, result = await downloader.download_file(url, temp_path, use_auth=False) + if not success: + raise RecipeDownloadError(f"Failed to download image from URL: {result}") + + def _metadata_not_found_response(self, path: str) -> AnalysisResult: + payload: dict[str, Any] = {"error": "No metadata found in this image", "loras": []} + if os.path.exists(path): + payload["image_base64"] = self._encode_file(path) + return AnalysisResult(payload) + + def _write_temp_file(self, data: bytes) -> str: + with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file: + temp_file.write(data) + return temp_file.name + + def _create_temp_path(self) -> str: + with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file: + return temp_file.name + + def _safe_cleanup(self, path: Optional[str]) -> None: + if path and os.path.exists(path): + try: + os.unlink(path) + except Exception as exc: # pragma: no cover - defensive logging + self._logger.error("Error deleting temporary file: %s", exc) + + def _encode_file(self, path: str) -> str: + with open(path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + def _convert_tensor_to_png_bytes(self, latest_image: Any) -> Optional[bytes]: + try: + if isinstance(latest_image, tuple): + tensor_image = latest_image[0] if latest_image else None + if tensor_image is None: + return None + else: + tensor_image = latest_image + + if hasattr(tensor_image, "shape"): + self._logger.debug( + "Tensor shape: %s, dtype: %s", tensor_image.shape, getattr(tensor_image, "dtype", None) + ) + + import torch # type: ignore[import-not-found] + + if isinstance(tensor_image, torch.Tensor): + image_np = tensor_image.cpu().numpy() + else: + image_np = np.array(tensor_image) + + while len(image_np.shape) > 3: + image_np = image_np[0] + + if image_np.dtype in (np.float32, np.float64) and image_np.max() <= 1.0: + image_np = (image_np * 255).astype(np.uint8) + + if len(image_np.shape) == 3 and image_np.shape[2] == 3: + pil_image = Image.fromarray(image_np) + img_byte_arr = io.BytesIO() + pil_image.save(img_byte_arr, format="PNG") + return img_byte_arr.getvalue() + except Exception as exc: # pragma: no cover - defensive logging path + self._logger.error("Error processing image data: %s", exc, exc_info=True) + return None + + return None diff --git a/py/services/recipes/errors.py b/py/services/recipes/errors.py new file mode 100644 index 00000000..9e5d9720 --- /dev/null +++ b/py/services/recipes/errors.py @@ -0,0 +1,22 @@ +"""Shared exceptions for recipe services.""" +from __future__ import annotations + + +class RecipeServiceError(Exception): + """Base exception for recipe service failures.""" + + +class RecipeValidationError(RecipeServiceError): + """Raised when a request payload fails validation.""" + + +class RecipeNotFoundError(RecipeServiceError): + """Raised when a recipe resource cannot be located.""" + + +class RecipeDownloadError(RecipeServiceError): + """Raised when remote recipe assets cannot be downloaded.""" + + +class RecipeConflictError(RecipeServiceError): + """Raised when a conflicting recipe state is detected.""" diff --git a/py/services/recipes/persistence_service.py b/py/services/recipes/persistence_service.py new file mode 100644 index 00000000..078ac906 --- /dev/null +++ b/py/services/recipes/persistence_service.py @@ -0,0 +1,400 @@ +"""Services encapsulating recipe persistence workflows.""" +from __future__ import annotations + +import base64 +import json +import os +import re +import time +import uuid +from dataclasses import dataclass +from typing import Any, Dict, Iterable, Optional + +from ...config import config +from ...utils.utils import calculate_recipe_fingerprint +from .errors import RecipeNotFoundError, RecipeValidationError + + +@dataclass(frozen=True) +class PersistenceResult: + """Return payload from persistence operations.""" + + payload: dict[str, Any] + status: int = 200 + + +class RecipePersistenceService: + """Coordinate recipe persistence tasks across storage and caches.""" + + def __init__( + self, + *, + exif_utils, + card_preview_width: int, + logger, + ) -> None: + self._exif_utils = exif_utils + self._card_preview_width = card_preview_width + self._logger = logger + + async def save_recipe( + self, + *, + recipe_scanner, + image_bytes: bytes | None, + image_base64: str | None, + name: str | None, + tags: Iterable[str], + metadata: Optional[dict[str, Any]], + ) -> PersistenceResult: + """Persist a user uploaded recipe.""" + + missing_fields = [] + if not name: + missing_fields.append("name") + if metadata is None: + missing_fields.append("metadata") + if missing_fields: + raise RecipeValidationError( + f"Missing required fields: {', '.join(missing_fields)}" + ) + + resolved_image_bytes = self._resolve_image_bytes(image_bytes, image_base64) + recipes_dir = recipe_scanner.recipes_dir + os.makedirs(recipes_dir, exist_ok=True) + + recipe_id = str(uuid.uuid4()) + optimized_image, extension = self._exif_utils.optimize_image( + image_data=resolved_image_bytes, + target_width=self._card_preview_width, + format="webp", + quality=85, + preserve_metadata=True, + ) + image_filename = f"{recipe_id}{extension}" + image_path = os.path.join(recipes_dir, image_filename) + with open(image_path, "wb") as file_obj: + file_obj.write(optimized_image) + + current_time = time.time() + loras_data = [self._normalise_lora_entry(lora) for lora in metadata.get("loras", [])] + + gen_params = metadata.get("gen_params", {}) + if not gen_params and "raw_metadata" in metadata: + raw_metadata = metadata.get("raw_metadata", {}) + gen_params = { + "prompt": raw_metadata.get("prompt", ""), + "negative_prompt": raw_metadata.get("negative_prompt", ""), + "checkpoint": raw_metadata.get("checkpoint", {}), + "steps": raw_metadata.get("steps", ""), + "sampler": raw_metadata.get("sampler", ""), + "cfg_scale": raw_metadata.get("cfg_scale", ""), + "seed": raw_metadata.get("seed", ""), + "size": raw_metadata.get("size", ""), + "clip_skip": raw_metadata.get("clip_skip", ""), + } + + fingerprint = calculate_recipe_fingerprint(loras_data) + recipe_data: Dict[str, Any] = { + "id": recipe_id, + "file_path": image_path, + "title": name, + "modified": current_time, + "created_date": current_time, + "base_model": metadata.get("base_model", ""), + "loras": loras_data, + "gen_params": gen_params, + "fingerprint": fingerprint, + } + + tags_list = list(tags) + if tags_list: + recipe_data["tags"] = tags_list + + if metadata.get("source_path"): + recipe_data["source_path"] = metadata.get("source_path") + + json_filename = f"{recipe_id}.recipe.json" + json_path = os.path.join(recipes_dir, json_filename) + with open(json_path, "w", encoding="utf-8") as file_obj: + json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) + + self._exif_utils.append_recipe_metadata(image_path, recipe_data) + + matching_recipes = await self._find_matching_recipes(recipe_scanner, fingerprint, exclude_id=recipe_id) + await recipe_scanner.add_recipe(recipe_data) + + return PersistenceResult( + { + "success": True, + "recipe_id": recipe_id, + "image_path": image_path, + "json_path": json_path, + "matching_recipes": matching_recipes, + } + ) + + async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> PersistenceResult: + """Delete an existing recipe.""" + + recipes_dir = recipe_scanner.recipes_dir + if not recipes_dir or not os.path.exists(recipes_dir): + raise RecipeNotFoundError("Recipes directory not found") + + recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json") + if not os.path.exists(recipe_json_path): + raise RecipeNotFoundError("Recipe not found") + + with open(recipe_json_path, "r", encoding="utf-8") as file_obj: + recipe_data = json.load(file_obj) + + image_path = recipe_data.get("file_path") + os.remove(recipe_json_path) + if image_path and os.path.exists(image_path): + os.remove(image_path) + + await recipe_scanner.remove_recipe(recipe_id) + return PersistenceResult({"success": True, "message": "Recipe deleted successfully"}) + + async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: dict[str, Any]) -> PersistenceResult: + """Update persisted metadata for a recipe.""" + + if not any(key in updates for key in ("title", "tags", "source_path", "preview_nsfw_level")): + raise RecipeValidationError( + "At least one field to update must be provided (title or tags or source_path or preview_nsfw_level)" + ) + + success = await recipe_scanner.update_recipe_metadata(recipe_id, updates) + if not success: + raise RecipeNotFoundError("Recipe not found or update failed") + + return PersistenceResult({"success": True, "recipe_id": recipe_id, "updates": updates}) + + async def reconnect_lora( + self, + *, + recipe_scanner, + recipe_id: str, + lora_index: int, + target_name: str, + ) -> PersistenceResult: + """Reconnect a LoRA entry within an existing recipe.""" + + recipe_path = os.path.join(recipe_scanner.recipes_dir, f"{recipe_id}.recipe.json") + if not os.path.exists(recipe_path): + raise RecipeNotFoundError("Recipe not found") + + target_lora = await recipe_scanner.get_local_lora(target_name) + if not target_lora: + raise RecipeNotFoundError(f"Local LoRA not found with name: {target_name}") + + recipe_data, updated_lora = await recipe_scanner.update_lora_entry( + recipe_id, + lora_index, + target_name=target_name, + target_lora=target_lora, + ) + + image_path = recipe_data.get("file_path") + if image_path and os.path.exists(image_path): + self._exif_utils.append_recipe_metadata(image_path, recipe_data) + + matching_recipes = [] + if "fingerprint" in recipe_data: + matching_recipes = await recipe_scanner.find_recipes_by_fingerprint(recipe_data["fingerprint"]) + if recipe_id in matching_recipes: + matching_recipes.remove(recipe_id) + + return PersistenceResult( + { + "success": True, + "recipe_id": recipe_id, + "updated_lora": updated_lora, + "matching_recipes": matching_recipes, + } + ) + + async def bulk_delete( + self, + *, + recipe_scanner, + recipe_ids: Iterable[str], + ) -> PersistenceResult: + """Delete multiple recipes in a single request.""" + + recipe_ids = list(recipe_ids) + if not recipe_ids: + raise RecipeValidationError("No recipe IDs provided") + + recipes_dir = recipe_scanner.recipes_dir + if not recipes_dir or not os.path.exists(recipes_dir): + raise RecipeNotFoundError("Recipes directory not found") + + deleted_recipes: list[str] = [] + failed_recipes: list[dict[str, Any]] = [] + + for recipe_id in recipe_ids: + recipe_json_path = os.path.join(recipes_dir, f"{recipe_id}.recipe.json") + if not os.path.exists(recipe_json_path): + failed_recipes.append({"id": recipe_id, "reason": "Recipe not found"}) + continue + + try: + with open(recipe_json_path, "r", encoding="utf-8") as file_obj: + recipe_data = json.load(file_obj) + image_path = recipe_data.get("file_path") + os.remove(recipe_json_path) + if image_path and os.path.exists(image_path): + os.remove(image_path) + deleted_recipes.append(recipe_id) + except Exception as exc: + failed_recipes.append({"id": recipe_id, "reason": str(exc)}) + + if deleted_recipes: + await recipe_scanner.bulk_remove(deleted_recipes) + + return PersistenceResult( + { + "success": True, + "deleted": deleted_recipes, + "failed": failed_recipes, + "total_deleted": len(deleted_recipes), + "total_failed": len(failed_recipes), + } + ) + + async def save_recipe_from_widget( + self, + *, + recipe_scanner, + metadata: dict[str, Any], + image_bytes: bytes, + ) -> PersistenceResult: + """Save a recipe constructed from widget metadata.""" + + if not metadata: + raise RecipeValidationError("No generation metadata found") + + recipes_dir = recipe_scanner.recipes_dir + os.makedirs(recipes_dir, exist_ok=True) + + recipe_id = str(uuid.uuid4()) + image_filename = f"{recipe_id}.png" + image_path = os.path.join(recipes_dir, image_filename) + with open(image_path, "wb") as file_obj: + file_obj.write(image_bytes) + + lora_stack = metadata.get("loras", "") + lora_matches = re.findall(r"]+)>", lora_stack) + if not lora_matches: + raise RecipeValidationError("No LoRAs found in the generation metadata") + + loras_data = [] + base_model_counts: Dict[str, int] = {} + + for name, strength in lora_matches: + lora_info = await recipe_scanner.get_local_lora(name) + lora_data = { + "file_name": name, + "strength": float(strength), + "hash": (lora_info.get("sha256") or "").lower() if lora_info else "", + "modelVersionId": lora_info.get("civitai", {}).get("id") if lora_info else 0, + "modelName": lora_info.get("civitai", {}).get("model", {}).get("name") if lora_info else "", + "modelVersionName": lora_info.get("civitai", {}).get("name") if lora_info else "", + "isDeleted": False, + "exclude": False, + } + loras_data.append(lora_data) + + if lora_info and "base_model" in lora_info: + base_model = lora_info["base_model"] + base_model_counts[base_model] = base_model_counts.get(base_model, 0) + 1 + + recipe_name = self._derive_recipe_name(lora_matches) + most_common_base_model = ( + max(base_model_counts.items(), key=lambda item: item[1])[0] if base_model_counts else "" + ) + + recipe_data = { + "id": recipe_id, + "file_path": image_path, + "title": recipe_name, + "modified": time.time(), + "created_date": time.time(), + "base_model": most_common_base_model, + "loras": loras_data, + "checkpoint": metadata.get("checkpoint", ""), + "gen_params": { + key: value + for key, value in metadata.items() + if key not in ["checkpoint", "loras"] + }, + "loras_stack": lora_stack, + } + + json_filename = f"{recipe_id}.recipe.json" + json_path = os.path.join(recipes_dir, json_filename) + with open(json_path, "w", encoding="utf-8") as file_obj: + json.dump(recipe_data, file_obj, indent=4, ensure_ascii=False) + + self._exif_utils.append_recipe_metadata(image_path, recipe_data) + await recipe_scanner.add_recipe(recipe_data) + + return PersistenceResult( + { + "success": True, + "recipe_id": recipe_id, + "image_path": image_path, + "json_path": json_path, + "recipe_name": recipe_name, + } + ) + + # Helper methods --------------------------------------------------- + + def _resolve_image_bytes(self, image_bytes: bytes | None, image_base64: str | None) -> bytes: + if image_bytes is not None: + return image_bytes + if image_base64: + try: + payload = image_base64.split(",", 1)[1] if "," in image_base64 else image_base64 + return base64.b64decode(payload) + except Exception as exc: # pragma: no cover - validation guard + raise RecipeValidationError(f"Invalid base64 image data: {exc}") from exc + raise RecipeValidationError("No image data provided") + + def _normalise_lora_entry(self, lora: dict[str, Any]) -> dict[str, Any]: + return { + "file_name": lora.get("file_name", "") + or ( + os.path.splitext(os.path.basename(lora.get("localPath", "")))[0] + if lora.get("localPath") + else "" + ), + "hash": (lora.get("hash") or "").lower(), + "strength": float(lora.get("weight", 1.0)), + "modelVersionId": lora.get("id", 0), + "modelName": lora.get("name", ""), + "modelVersionName": lora.get("version", ""), + "isDeleted": lora.get("isDeleted", False), + "exclude": lora.get("exclude", False), + } + + async def _find_matching_recipes( + self, + recipe_scanner, + fingerprint: str | None, + *, + exclude_id: Optional[str] = None, + ) -> list[str]: + if not fingerprint: + return [] + matches = await recipe_scanner.find_recipes_by_fingerprint(fingerprint) + if exclude_id and exclude_id in matches: + matches.remove(exclude_id) + return matches + + def _derive_recipe_name(self, lora_matches: list[tuple[str, str]]) -> str: + recipe_name_parts = [f"{name.strip()}-{float(strength):.2f}" for name, strength in lora_matches[:3]] + recipe_name = "_".join(recipe_name_parts) + return recipe_name or "recipe" diff --git a/py/services/recipes/sharing_service.py b/py/services/recipes/sharing_service.py new file mode 100644 index 00000000..47ab9718 --- /dev/null +++ b/py/services/recipes/sharing_service.py @@ -0,0 +1,105 @@ +"""Services handling recipe sharing and downloads.""" +from __future__ import annotations + +import os +import shutil +import tempfile +import time +from dataclasses import dataclass +from typing import Any, Dict + +from .errors import RecipeNotFoundError + + +@dataclass(frozen=True) +class SharingResult: + """Return payload for share operations.""" + + payload: dict[str, Any] + status: int = 200 + + +@dataclass(frozen=True) +class DownloadInfo: + """Information required to stream a shared recipe file.""" + + file_path: str + download_filename: str + + +class RecipeSharingService: + """Prepare temporary recipe downloads with TTL cleanup.""" + + def __init__(self, *, ttl_seconds: int = 300, logger) -> None: + self._ttl_seconds = ttl_seconds + self._logger = logger + self._shared_recipes: Dict[str, Dict[str, Any]] = {} + + async def share_recipe(self, *, recipe_scanner, recipe_id: str) -> SharingResult: + """Prepare a temporary downloadable copy of a recipe image.""" + + recipe = await recipe_scanner.get_recipe_by_id(recipe_id) + if not recipe: + raise RecipeNotFoundError("Recipe not found") + + image_path = recipe.get("file_path") + if not image_path or not os.path.exists(image_path): + raise RecipeNotFoundError("Recipe image not found") + + ext = os.path.splitext(image_path)[1] + with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as temp_file: + temp_path = temp_file.name + + shutil.copy2(image_path, temp_path) + timestamp = int(time.time()) + self._shared_recipes[recipe_id] = { + "path": temp_path, + "timestamp": timestamp, + "expires": time.time() + self._ttl_seconds, + } + self._cleanup_shared_recipes() + + safe_title = recipe.get("title", "").replace(" ", "_").lower() + filename = f"recipe_{safe_title}{ext}" if safe_title else f"recipe_{recipe_id}{ext}" + url_path = f"/api/recipe/{recipe_id}/share/download?t={timestamp}" + return SharingResult({"success": True, "download_url": url_path, "filename": filename}) + + async def prepare_download(self, *, recipe_scanner, recipe_id: str) -> DownloadInfo: + """Return file path and filename for a prepared shared recipe.""" + + shared_info = self._shared_recipes.get(recipe_id) + if not shared_info or time.time() > shared_info.get("expires", 0): + self._cleanup_entry(recipe_id) + raise RecipeNotFoundError("Shared recipe not found or expired") + + file_path = shared_info["path"] + if not os.path.exists(file_path): + self._cleanup_entry(recipe_id) + raise RecipeNotFoundError("Shared recipe file not found") + + recipe = await recipe_scanner.get_recipe_by_id(recipe_id) + filename_base = ( + f"recipe_{recipe.get('title', '').replace(' ', '_').lower()}" if recipe else recipe_id + ) + ext = os.path.splitext(file_path)[1] + download_filename = f"{filename_base}{ext}" + return DownloadInfo(file_path=file_path, download_filename=download_filename) + + def _cleanup_shared_recipes(self) -> None: + for recipe_id in list(self._shared_recipes.keys()): + shared = self._shared_recipes.get(recipe_id) + if not shared: + continue + if time.time() > shared.get("expires", 0): + self._cleanup_entry(recipe_id) + + def _cleanup_entry(self, recipe_id: str) -> None: + shared_info = self._shared_recipes.pop(recipe_id, None) + if not shared_info: + return + file_path = shared_info.get("path") + if file_path and os.path.exists(file_path): + try: + os.unlink(file_path) + except Exception as exc: # pragma: no cover - defensive logging + self._logger.error("Error cleaning up shared recipe %s: %s", recipe_id, exc) diff --git a/py/services/tag_update_service.py b/py/services/tag_update_service.py new file mode 100644 index 00000000..d560e7d6 --- /dev/null +++ b/py/services/tag_update_service.py @@ -0,0 +1,47 @@ +"""Service for updating tag collections on metadata records.""" + +from __future__ import annotations + +import os + +from typing import Awaitable, Callable, Dict, List, Sequence + + +class TagUpdateService: + """Encapsulate tag manipulation for models.""" + + def __init__(self, *, metadata_manager) -> None: + self._metadata_manager = metadata_manager + + async def add_tags( + self, + *, + file_path: str, + new_tags: Sequence[str], + metadata_loader: Callable[[str], Awaitable[Dict[str, object]]], + update_cache: Callable[[str, str, Dict[str, object]], Awaitable[bool]], + ) -> List[str]: + """Add tags to a metadata entry while keeping case-insensitive uniqueness.""" + + base, _ = os.path.splitext(file_path) + metadata_path = f"{base}.metadata.json" + metadata = await metadata_loader(metadata_path) + + existing_tags = list(metadata.get("tags", [])) + existing_lower = [tag.lower() for tag in existing_tags] + + tags_added: List[str] = [] + for tag in new_tags: + if isinstance(tag, str) and tag.strip(): + normalized = tag.strip() + if normalized.lower() not in existing_lower: + existing_tags.append(normalized) + existing_lower.append(normalized.lower()) + tags_added.append(normalized) + + metadata["tags"] = existing_tags + await self._metadata_manager.save_metadata(file_path, metadata) + await update_cache(file_path, file_path, metadata) + + return existing_tags + diff --git a/py/services/use_cases/__init__.py b/py/services/use_cases/__init__.py new file mode 100644 index 00000000..8a43318c --- /dev/null +++ b/py/services/use_cases/__init__.py @@ -0,0 +1,37 @@ +"""Application-level orchestration services for model routes.""" + +from .auto_organize_use_case import ( + AutoOrganizeInProgressError, + AutoOrganizeUseCase, +) +from .bulk_metadata_refresh_use_case import ( + BulkMetadataRefreshUseCase, + MetadataRefreshProgressReporter, +) +from .download_model_use_case import ( + DownloadModelEarlyAccessError, + DownloadModelUseCase, + DownloadModelValidationError, +) +from .example_images import ( + DownloadExampleImagesConfigurationError, + DownloadExampleImagesInProgressError, + DownloadExampleImagesUseCase, + ImportExampleImagesUseCase, + ImportExampleImagesValidationError, +) + +__all__ = [ + "AutoOrganizeInProgressError", + "AutoOrganizeUseCase", + "BulkMetadataRefreshUseCase", + "MetadataRefreshProgressReporter", + "DownloadModelEarlyAccessError", + "DownloadModelUseCase", + "DownloadModelValidationError", + "DownloadExampleImagesConfigurationError", + "DownloadExampleImagesInProgressError", + "DownloadExampleImagesUseCase", + "ImportExampleImagesUseCase", + "ImportExampleImagesValidationError", +] diff --git a/py/services/use_cases/auto_organize_use_case.py b/py/services/use_cases/auto_organize_use_case.py new file mode 100644 index 00000000..0914739f --- /dev/null +++ b/py/services/use_cases/auto_organize_use_case.py @@ -0,0 +1,56 @@ +"""Auto-organize use case orchestrating concurrency and progress handling.""" + +from __future__ import annotations + +import asyncio +from typing import Optional, Protocol, Sequence + +from ..model_file_service import AutoOrganizeResult, ModelFileService, ProgressCallback + + +class AutoOrganizeLockProvider(Protocol): + """Minimal protocol for objects exposing auto-organize locking primitives.""" + + def is_auto_organize_running(self) -> bool: + """Return ``True`` when an auto-organize operation is in-flight.""" + + async def get_auto_organize_lock(self) -> asyncio.Lock: + """Return the asyncio lock guarding auto-organize operations.""" + + +class AutoOrganizeInProgressError(RuntimeError): + """Raised when an auto-organize run is already active.""" + + +class AutoOrganizeUseCase: + """Coordinate auto-organize execution behind a shared lock.""" + + def __init__( + self, + *, + file_service: ModelFileService, + lock_provider: AutoOrganizeLockProvider, + ) -> None: + self._file_service = file_service + self._lock_provider = lock_provider + + async def execute( + self, + *, + file_paths: Optional[Sequence[str]] = None, + progress_callback: Optional[ProgressCallback] = None, + ) -> AutoOrganizeResult: + """Run the auto-organize routine guarded by a shared lock.""" + + if self._lock_provider.is_auto_organize_running(): + raise AutoOrganizeInProgressError("Auto-organize is already running") + + lock = await self._lock_provider.get_auto_organize_lock() + if lock.locked(): + raise AutoOrganizeInProgressError("Auto-organize is already running") + + async with lock: + return await self._file_service.auto_organize_models( + file_paths=list(file_paths) if file_paths is not None else None, + progress_callback=progress_callback, + ) diff --git a/py/services/use_cases/bulk_metadata_refresh_use_case.py b/py/services/use_cases/bulk_metadata_refresh_use_case.py new file mode 100644 index 00000000..6a809955 --- /dev/null +++ b/py/services/use_cases/bulk_metadata_refresh_use_case.py @@ -0,0 +1,122 @@ +"""Use case encapsulating the bulk metadata refresh orchestration.""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, Optional, Protocol, Sequence + +from ..metadata_sync_service import MetadataSyncService + + +class MetadataRefreshProgressReporter(Protocol): + """Protocol for progress reporters used during metadata refresh.""" + + async def on_progress(self, payload: Dict[str, Any]) -> None: + """Handle a metadata refresh progress update.""" + + +class BulkMetadataRefreshUseCase: + """Coordinate bulk metadata refreshes with progress emission.""" + + def __init__( + self, + *, + service, + metadata_sync: MetadataSyncService, + settings_service, + logger: Optional[logging.Logger] = None, + ) -> None: + self._service = service + self._metadata_sync = metadata_sync + self._settings = settings_service + self._logger = logger or logging.getLogger(__name__) + + async def execute( + self, + *, + progress_callback: Optional[MetadataRefreshProgressReporter] = None, + ) -> Dict[str, Any]: + """Refresh metadata for all qualifying models.""" + + cache = await self._service.scanner.get_cached_data() + total_models = len(cache.raw_data) + + enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False) + to_process: Sequence[Dict[str, Any]] = [ + model + for model in cache.raw_data + if model.get("sha256") + and (not model.get("civitai") or not model["civitai"].get("id")) + and ( + (enable_metadata_archive_db and not model.get("db_checked", False)) + or (not enable_metadata_archive_db and model.get("from_civitai") is True) + ) + ] + + total_to_process = len(to_process) + processed = 0 + success = 0 + needs_resort = False + + async def emit(status: str, **extra: Any) -> None: + if progress_callback is None: + return + payload = {"status": status, "total": total_to_process, "processed": processed, "success": success} + payload.update(extra) + await progress_callback.on_progress(payload) + + await emit("started") + + for model in to_process: + try: + original_name = model.get("model_name") + result, _ = await self._metadata_sync.fetch_and_update_model( + sha256=model["sha256"], + file_path=model["file_path"], + model_data=model, + update_cache_func=self._service.scanner.update_single_model_cache, + ) + if result: + success += 1 + if original_name != model.get("model_name"): + needs_resort = True + processed += 1 + await emit( + "processing", + processed=processed, + success=success, + current_name=model.get("model_name", "Unknown"), + ) + except Exception as exc: # pragma: no cover - logging path + processed += 1 + self._logger.error( + "Error fetching CivitAI data for %s: %s", + model.get("file_path"), + exc, + ) + + if needs_resort: + await cache.resort() + + await emit("completed", processed=processed, success=success) + + message = ( + "Successfully updated " + f"{success} of {processed} processed {self._service.model_type}s (total: {total_models})" + ) + + return {"success": True, "message": message, "processed": processed, "updated": success, "total": total_models} + + async def execute_with_error_handling( + self, + *, + progress_callback: Optional[MetadataRefreshProgressReporter] = None, + ) -> Dict[str, Any]: + """Wrapper providing progress notification on unexpected failures.""" + + try: + return await self.execute(progress_callback=progress_callback) + except Exception as exc: + if progress_callback is not None: + await progress_callback.on_progress({"status": "error", "error": str(exc)}) + raise diff --git a/py/services/use_cases/download_model_use_case.py b/py/services/use_cases/download_model_use_case.py new file mode 100644 index 00000000..5aa25bda --- /dev/null +++ b/py/services/use_cases/download_model_use_case.py @@ -0,0 +1,37 @@ +"""Use case for scheduling model downloads with consistent error handling.""" + +from __future__ import annotations + +from typing import Any, Dict + +from ..download_coordinator import DownloadCoordinator + + +class DownloadModelValidationError(ValueError): + """Raised when incoming payload validation fails.""" + + +class DownloadModelEarlyAccessError(RuntimeError): + """Raised when the download is gated behind Civitai early access.""" + + +class DownloadModelUseCase: + """Coordinate download scheduling through the coordinator service.""" + + def __init__(self, *, download_coordinator: DownloadCoordinator) -> None: + self._download_coordinator = download_coordinator + + async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """Schedule a download and normalize error conditions.""" + + try: + return await self._download_coordinator.schedule_download(payload) + except ValueError as exc: + raise DownloadModelValidationError(str(exc)) from exc + except Exception as exc: # pragma: no cover - defensive logging path + message = str(exc) + if "401" in message: + raise DownloadModelEarlyAccessError( + "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com." + ) from exc + raise diff --git a/py/services/use_cases/example_images/__init__.py b/py/services/use_cases/example_images/__init__.py new file mode 100644 index 00000000..820de618 --- /dev/null +++ b/py/services/use_cases/example_images/__init__.py @@ -0,0 +1,19 @@ +"""Example image specific use case exports.""" + +from .download_example_images_use_case import ( + DownloadExampleImagesUseCase, + DownloadExampleImagesInProgressError, + DownloadExampleImagesConfigurationError, +) +from .import_example_images_use_case import ( + ImportExampleImagesUseCase, + ImportExampleImagesValidationError, +) + +__all__ = [ + "DownloadExampleImagesUseCase", + "DownloadExampleImagesInProgressError", + "DownloadExampleImagesConfigurationError", + "ImportExampleImagesUseCase", + "ImportExampleImagesValidationError", +] diff --git a/py/services/use_cases/example_images/download_example_images_use_case.py b/py/services/use_cases/example_images/download_example_images_use_case.py new file mode 100644 index 00000000..e9a51e13 --- /dev/null +++ b/py/services/use_cases/example_images/download_example_images_use_case.py @@ -0,0 +1,42 @@ +"""Use case coordinating example image downloads.""" + +from __future__ import annotations + +from typing import Any, Dict + +from ....utils.example_images_download_manager import ( + DownloadConfigurationError, + DownloadInProgressError, + ExampleImagesDownloadError, +) + + +class DownloadExampleImagesInProgressError(RuntimeError): + """Raised when a download is already running.""" + + def __init__(self, progress: Dict[str, Any]) -> None: + super().__init__("Download already in progress") + self.progress = progress + + +class DownloadExampleImagesConfigurationError(ValueError): + """Raised when settings prevent downloads from starting.""" + + +class DownloadExampleImagesUseCase: + """Validate payloads and trigger the download manager.""" + + def __init__(self, *, download_manager) -> None: + self._download_manager = download_manager + + async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """Start a download and translate manager errors.""" + + try: + return await self._download_manager.start_download(payload) + except DownloadInProgressError as exc: + raise DownloadExampleImagesInProgressError(exc.progress_snapshot) from exc + except DownloadConfigurationError as exc: + raise DownloadExampleImagesConfigurationError(str(exc)) from exc + except ExampleImagesDownloadError: + raise diff --git a/py/services/use_cases/example_images/import_example_images_use_case.py b/py/services/use_cases/example_images/import_example_images_use_case.py new file mode 100644 index 00000000..547b2f4e --- /dev/null +++ b/py/services/use_cases/example_images/import_example_images_use_case.py @@ -0,0 +1,86 @@ +"""Use case for importing example images.""" + +from __future__ import annotations + +import os +import tempfile +from contextlib import suppress +from typing import Any, Dict, List + +from aiohttp import web + +from ....utils.example_images_processor import ( + ExampleImagesImportError, + ExampleImagesProcessor, + ExampleImagesValidationError, +) + + +class ImportExampleImagesValidationError(ValueError): + """Raised when request validation fails.""" + + +class ImportExampleImagesUseCase: + """Parse upload payloads and delegate to the processor service.""" + + def __init__(self, *, processor: ExampleImagesProcessor) -> None: + self._processor = processor + + async def execute(self, request: web.Request) -> Dict[str, Any]: + model_hash: str | None = None + files_to_import: List[str] = [] + temp_files: List[str] = [] + + try: + if request.content_type and "multipart/form-data" in request.content_type: + reader = await request.multipart() + + first_field = await reader.next() + if first_field and first_field.name == "model_hash": + model_hash = await first_field.text() + else: + # Support clients that send files first and hash later + if first_field is not None: + await self._collect_upload_file(first_field, files_to_import, temp_files) + + async for field in reader: + if field.name == "model_hash" and not model_hash: + model_hash = await field.text() + elif field.name == "files": + await self._collect_upload_file(field, files_to_import, temp_files) + else: + data = await request.json() + model_hash = data.get("model_hash") + files_to_import = list(data.get("file_paths", [])) + + result = await self._processor.import_images(model_hash, files_to_import) + return result + except ExampleImagesValidationError as exc: + raise ImportExampleImagesValidationError(str(exc)) from exc + except ExampleImagesImportError: + raise + finally: + for path in temp_files: + with suppress(Exception): + os.remove(path) + + async def _collect_upload_file( + self, + field: Any, + files_to_import: List[str], + temp_files: List[str], + ) -> None: + """Persist an uploaded file to disk and add it to the import list.""" + + filename = field.filename or "upload" + file_ext = os.path.splitext(filename)[1].lower() + + with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file: + temp_files.append(tmp_file.name) + while True: + chunk = await field.read_chunk() + if not chunk: + break + tmp_file.write(chunk) + + files_to_import.append(tmp_file.name) diff --git a/py/services/websocket_progress_callback.py b/py/services/websocket_progress_callback.py index 1a390f30..21423044 100644 --- a/py/services/websocket_progress_callback.py +++ b/py/services/websocket_progress_callback.py @@ -1,11 +1,29 @@ -from typing import Dict, Any +"""Progress callback implementations backed by the shared WebSocket manager.""" + +from typing import Any, Dict, Protocol + from .model_file_service import ProgressCallback from .websocket_manager import ws_manager -class WebSocketProgressCallback(ProgressCallback): - """WebSocket implementation of progress callback""" - +class ProgressReporter(Protocol): + """Protocol representing an async progress callback.""" + async def on_progress(self, progress_data: Dict[str, Any]) -> None: - """Send progress data via WebSocket""" - await ws_manager.broadcast_auto_organize_progress(progress_data) \ No newline at end of file + """Handle a progress update payload.""" + + +class WebSocketProgressCallback(ProgressCallback): + """WebSocket implementation of progress callback.""" + + async def on_progress(self, progress_data: Dict[str, Any]) -> None: + """Send progress data via WebSocket.""" + await ws_manager.broadcast_auto_organize_progress(progress_data) + + +class WebSocketBroadcastCallback: + """Generic WebSocket progress callback broadcasting to all clients.""" + + async def on_progress(self, progress_data: Dict[str, Any]) -> None: + """Send the provided payload to all connected clients.""" + await ws_manager.broadcast(progress_data) diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index 842192f2..9ddf03a4 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -1,218 +1,216 @@ +from __future__ import annotations + import logging import os import asyncio import json import time -from aiohttp import web +from typing import Any, Dict + from ..services.service_registry import ServiceRegistry from ..utils.metadata_manager import MetadataManager from .example_images_processor import ExampleImagesProcessor from .example_images_metadata import MetadataUpdater -from ..services.websocket_manager import ws_manager # Add this import at the top from ..services.downloader import get_downloader from ..services.settings_manager import settings + +class ExampleImagesDownloadError(RuntimeError): + """Base error for example image download operations.""" + + +class DownloadInProgressError(ExampleImagesDownloadError): + """Raised when a download is already running.""" + + def __init__(self, progress_snapshot: dict) -> None: + super().__init__("Download already in progress") + self.progress_snapshot = progress_snapshot + + +class DownloadNotRunningError(ExampleImagesDownloadError): + """Raised when pause/resume is requested without an active download.""" + + def __init__(self, message: str = "No download in progress") -> None: + super().__init__(message) + + +class DownloadConfigurationError(ExampleImagesDownloadError): + """Raised when configuration prevents starting a download.""" + + logger = logging.getLogger(__name__) -# Download status tracking -download_task = None -is_downloading = False -download_progress = { - 'total': 0, - 'completed': 0, - 'current_model': '', - 'status': 'idle', # idle, running, paused, completed, error - 'errors': [], - 'last_error': None, - 'start_time': None, - 'end_time': None, - 'processed_models': set(), # Track models that have been processed - 'refreshed_models': set(), # Track models that had metadata refreshed - 'failed_models': set() # Track models that failed to download after metadata refresh -} + +class _DownloadProgress(dict): + """Mutable mapping maintaining download progress with set-aware serialisation.""" + + def __init__(self) -> None: + super().__init__() + self.reset() + + def reset(self) -> None: + """Reset the progress dictionary to its initial state.""" + + self.update( + total=0, + completed=0, + current_model='', + status='idle', + errors=[], + last_error=None, + start_time=None, + end_time=None, + processed_models=set(), + refreshed_models=set(), + failed_models=set(), + ) + + def snapshot(self) -> dict: + """Return a JSON-serialisable snapshot of the current progress.""" + + snapshot = dict(self) + snapshot['processed_models'] = list(self['processed_models']) + snapshot['refreshed_models'] = list(self['refreshed_models']) + snapshot['failed_models'] = list(self['failed_models']) + return snapshot class DownloadManager: - """Manages downloading example images for models""" - - @staticmethod - async def start_download(request): - """ - Start downloading example images for models - - Expects a JSON body with: - { - "optimize": true, # Whether to optimize images (default: true) - "model_types": ["lora", "checkpoint"], # Model types to process (default: both) - "delay": 1.0, # Delay between downloads to avoid rate limiting (default: 1.0) - "auto_mode": false # Flag to indicate automatic download (default: false) - } - """ - global download_task, is_downloading, download_progress - - if is_downloading: - # Create a copy for JSON serialization - response_progress = download_progress.copy() - response_progress['processed_models'] = list(download_progress['processed_models']) - response_progress['refreshed_models'] = list(download_progress['refreshed_models']) - response_progress['failed_models'] = list(download_progress['failed_models']) - - return web.json_response({ - 'success': False, - 'error': 'Download already in progress', - 'status': response_progress - }, status=400) - - try: - # Parse the request body - data = await request.json() - auto_mode = data.get('auto_mode', False) - optimize = data.get('optimize', True) - model_types = data.get('model_types', ['lora', 'checkpoint']) - delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds - - # Get output directory from settings - output_dir = settings.get('example_images_path') + """Manages downloading example images for models.""" - if not output_dir: - error_msg = 'Example images path not configured in settings' - if auto_mode: - # For auto mode, just log and return success to avoid showing error toasts - logger.debug(error_msg) - return web.json_response({ - 'success': True, - 'message': 'Example images path not configured, skipping auto download' - }) + def __init__(self, *, ws_manager, state_lock: asyncio.Lock | None = None) -> None: + self._download_task: asyncio.Task | None = None + self._is_downloading = False + self._progress = _DownloadProgress() + self._ws_manager = ws_manager + self._state_lock = state_lock or asyncio.Lock() + + async def start_download(self, options: dict): + """Start downloading example images for models.""" + + async with self._state_lock: + if self._is_downloading: + raise DownloadInProgressError(self._progress.snapshot()) + + try: + data = options or {} + auto_mode = data.get('auto_mode', False) + optimize = data.get('optimize', True) + model_types = data.get('model_types', ['lora', 'checkpoint']) + delay = float(data.get('delay', 0.2)) + + output_dir = settings.get('example_images_path') + + if not output_dir: + error_msg = 'Example images path not configured in settings' + if auto_mode: + logger.debug(error_msg) + return { + 'success': True, + 'message': 'Example images path not configured, skipping auto download' + } + raise DownloadConfigurationError(error_msg) + + os.makedirs(output_dir, exist_ok=True) + + self._progress.reset() + self._progress['status'] = 'running' + self._progress['start_time'] = time.time() + self._progress['end_time'] = None + + progress_file = os.path.join(output_dir, '.download_progress.json') + if os.path.exists(progress_file): + try: + with open(progress_file, 'r', encoding='utf-8') as f: + saved_progress = json.load(f) + self._progress['processed_models'] = set(saved_progress.get('processed_models', [])) + self._progress['failed_models'] = set(saved_progress.get('failed_models', [])) + logger.debug( + "Loaded previous progress, %s models already processed, %s models marked as failed", + len(self._progress['processed_models']), + len(self._progress['failed_models']), + ) + except Exception as e: + logger.error(f"Failed to load progress file: {e}") + self._progress['processed_models'] = set() + self._progress['failed_models'] = set() else: - return web.json_response({ - 'success': False, - 'error': error_msg - }, status=400) - - # Create the output directory - os.makedirs(output_dir, exist_ok=True) - - # Initialize progress tracking - download_progress['total'] = 0 - download_progress['completed'] = 0 - download_progress['current_model'] = '' - download_progress['status'] = 'running' - download_progress['errors'] = [] - download_progress['last_error'] = None - download_progress['start_time'] = time.time() - download_progress['end_time'] = None - - # Get the processed models list from a file if it exists - progress_file = os.path.join(output_dir, '.download_progress.json') - if os.path.exists(progress_file): - try: - with open(progress_file, 'r', encoding='utf-8') as f: - saved_progress = json.load(f) - download_progress['processed_models'] = set(saved_progress.get('processed_models', [])) - download_progress['failed_models'] = set(saved_progress.get('failed_models', [])) - logger.debug(f"Loaded previous progress, {len(download_progress['processed_models'])} models already processed, {len(download_progress['failed_models'])} models marked as failed") - except Exception as e: - logger.error(f"Failed to load progress file: {e}") - download_progress['processed_models'] = set() - download_progress['failed_models'] = set() - else: - download_progress['processed_models'] = set() - download_progress['failed_models'] = set() - - # Start the download task - is_downloading = True - download_task = asyncio.create_task( - DownloadManager._download_all_example_images( - output_dir, - optimize, - model_types, - delay - ) - ) - - # Create a copy for JSON serialization - response_progress = download_progress.copy() - response_progress['processed_models'] = list(download_progress['processed_models']) - response_progress['refreshed_models'] = list(download_progress['refreshed_models']) - response_progress['failed_models'] = list(download_progress['failed_models']) - - return web.json_response({ - 'success': True, - 'message': 'Download started', - 'status': response_progress - }) - - except Exception as e: - logger.error(f"Failed to start example images download: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - - @staticmethod - async def get_status(request): - """Get the current status of example images download""" - global download_progress - - # Create a copy of the progress dict with the set converted to a list for JSON serialization - response_progress = download_progress.copy() - response_progress['processed_models'] = list(download_progress['processed_models']) - response_progress['refreshed_models'] = list(download_progress['refreshed_models']) - response_progress['failed_models'] = list(download_progress['failed_models']) - - return web.json_response({ - 'success': True, - 'is_downloading': is_downloading, - 'status': response_progress - }) + self._progress['processed_models'] = set() + self._progress['failed_models'] = set() - @staticmethod - async def pause_download(request): - """Pause the example images download""" - global download_progress - - if not is_downloading: - return web.json_response({ - 'success': False, - 'error': 'No download in progress' - }, status=400) - - download_progress['status'] = 'paused' - - return web.json_response({ + self._is_downloading = True + self._download_task = asyncio.create_task( + self._download_all_example_images( + output_dir, + optimize, + model_types, + delay + ) + ) + + snapshot = self._progress.snapshot() + except Exception as e: + self._is_downloading = False + self._download_task = None + logger.error(f"Failed to start example images download: {e}", exc_info=True) + raise ExampleImagesDownloadError(str(e)) from e + + await self._broadcast_progress(status='running') + + return { + 'success': True, + 'message': 'Download started', + 'status': snapshot + } + + async def get_status(self, request): + """Get the current status of example images download.""" + + return { + 'success': True, + 'is_downloading': self._is_downloading, + 'status': self._progress.snapshot(), + } + + async def pause_download(self, request): + """Pause the example images download.""" + + async with self._state_lock: + if not self._is_downloading: + raise DownloadNotRunningError() + + self._progress['status'] = 'paused' + + await self._broadcast_progress(status='paused') + + return { 'success': True, 'message': 'Download paused' - }) + } - @staticmethod - async def resume_download(request): - """Resume the example images download""" - global download_progress - - if not is_downloading: - return web.json_response({ - 'success': False, - 'error': 'No download in progress' - }, status=400) - - if download_progress['status'] == 'paused': - download_progress['status'] = 'running' - - return web.json_response({ - 'success': True, - 'message': 'Download resumed' - }) - else: - return web.json_response({ - 'success': False, - 'error': f"Download is in '{download_progress['status']}' state, cannot resume" - }, status=400) + async def resume_download(self, request): + """Resume the example images download.""" + + async with self._state_lock: + if not self._is_downloading: + raise DownloadNotRunningError() + + if self._progress['status'] == 'paused': + self._progress['status'] = 'running' + else: + raise DownloadNotRunningError( + f"Download is in '{self._progress['status']}' state, cannot resume" + ) + + await self._broadcast_progress(status='running') + + return { + 'success': True, + 'message': 'Download resumed' + } - @staticmethod - async def _download_all_example_images(output_dir, optimize, model_types, delay): - """Download example images for all models""" - global is_downloading, download_progress - - # Get unified downloader + async def _download_all_example_images(self, output_dir, optimize, model_types, delay): + """Download example images for all models.""" + downloader = await get_downloader() try: @@ -240,59 +238,67 @@ class DownloadManager: all_models.append((scanner_type, model, scanner)) # Update total count - download_progress['total'] = len(all_models) - logger.debug(f"Found {download_progress['total']} models to process") + self._progress['total'] = len(all_models) + logger.debug(f"Found {self._progress['total']} models to process") + await self._broadcast_progress(status='running') # Process each model for i, (scanner_type, model, scanner) in enumerate(all_models): # Main logic for processing model is here, but actual operations are delegated to other classes - was_remote_download = await DownloadManager._process_model( - scanner_type, model, scanner, + was_remote_download = await self._process_model( + scanner_type, model, scanner, output_dir, optimize, downloader ) # Update progress - download_progress['completed'] += 1 + self._progress['completed'] += 1 + await self._broadcast_progress(status='running') # Only add delay after remote download of models, and not after processing the last model - if was_remote_download and i < len(all_models) - 1 and download_progress['status'] == 'running': + if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running': await asyncio.sleep(delay) # Mark as completed - download_progress['status'] = 'completed' - download_progress['end_time'] = time.time() - logger.debug(f"Example images download completed: {download_progress['completed']}/{download_progress['total']} models processed") - + self._progress['status'] = 'completed' + self._progress['end_time'] = time.time() + logger.debug( + "Example images download completed: %s/%s models processed", + self._progress['completed'], + self._progress['total'], + ) + await self._broadcast_progress(status='completed') + except Exception as e: error_msg = f"Error during example images download: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg - download_progress['status'] = 'error' - download_progress['end_time'] = time.time() - + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg + self._progress['status'] = 'error' + self._progress['end_time'] = time.time() + await self._broadcast_progress(status='error', extra={'error': error_msg}) + finally: # Save final progress to file try: - DownloadManager._save_progress(output_dir) + self._save_progress(output_dir) except Exception as e: logger.error(f"Failed to save progress file: {e}") - + # Set download status to not downloading - is_downloading = False + async with self._state_lock: + self._is_downloading = False + self._download_task = None - @staticmethod - async def _process_model(scanner_type, model, scanner, output_dir, optimize, downloader): - """Process a single model download""" - global download_progress - + async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader): + """Process a single model download.""" + # Check if download is paused - while download_progress['status'] == 'paused': + while self._progress['status'] == 'paused': await asyncio.sleep(1) - + # Check if download should continue - if download_progress['status'] != 'running': - logger.info(f"Download stopped: {download_progress['status']}") + if self._progress['status'] != 'running': + logger.info(f"Download stopped: {self._progress['status']}") return False # Return False to indicate no remote download happened model_hash = model.get('sha256', '').lower() @@ -302,15 +308,16 @@ class DownloadManager: try: # Update current model info - download_progress['current_model'] = f"{model_name} ({model_hash[:8]})" + self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" + await self._broadcast_progress(status='running') # Skip if already in failed models - if model_hash in download_progress['failed_models']: + if model_hash in self._progress['failed_models']: logger.debug(f"Skipping known failed model: {model_name}") return False # Skip if already processed AND directory exists with files - if model_hash in download_progress['processed_models']: + if model_hash in self._progress['processed_models']: model_dir = os.path.join(output_dir, model_hash) has_files = os.path.exists(model_dir) and any(os.listdir(model_dir)) if has_files: @@ -319,7 +326,7 @@ class DownloadManager: else: logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing") # Remove from processed models since we need to reprocess - download_progress['processed_models'].discard(model_hash) + self._progress['processed_models'].discard(model_hash) # Create model directory model_dir = os.path.join(output_dir, model_hash) @@ -335,7 +342,7 @@ class DownloadManager: await MetadataUpdater.update_metadata_from_local_examples( model_hash, model, scanner_type, scanner, model_dir ) - download_progress['processed_models'].add(model_hash) + self._progress['processed_models'].add(model_hash) return False # Return False to indicate no remote download happened # If no local images, try to download from remote @@ -347,57 +354,55 @@ class DownloadManager: ) # If metadata is stale, try to refresh it - if is_stale and model_hash not in download_progress['refreshed_models']: + if is_stale and model_hash not in self._progress['refreshed_models']: await MetadataUpdater.refresh_model_metadata( - model_hash, model_name, scanner_type, scanner + model_hash, model_name, scanner_type, scanner, self._progress ) - + # Get the updated model data updated_model = await MetadataUpdater.get_updated_model( model_hash, scanner ) - + if updated_model and updated_model.get('civitai', {}).get('images'): # Retry download with updated metadata updated_images = updated_model.get('civitai', {}).get('images', []) success, _ = await ExampleImagesProcessor.download_model_images( model_hash, model_name, updated_images, model_dir, optimize, downloader ) - - download_progress['refreshed_models'].add(model_hash) + + self._progress['refreshed_models'].add(model_hash) # Mark as processed if successful, or as failed if unsuccessful after refresh if success: - download_progress['processed_models'].add(model_hash) + self._progress['processed_models'].add(model_hash) else: # If we refreshed metadata and still failed, mark as permanently failed - if model_hash in download_progress['refreshed_models']: - download_progress['failed_models'].add(model_hash) + if model_hash in self._progress['refreshed_models']: + self._progress['failed_models'].add(model_hash) logger.info(f"Marking model {model_name} as failed after metadata refresh") return True # Return True to indicate a remote download happened else: # No civitai data or images available, mark as failed to avoid future attempts - download_progress['failed_models'].add(model_hash) + self._progress['failed_models'].add(model_hash) logger.debug(f"No civitai images available for model {model_name}, marking as failed") # Save progress periodically - if download_progress['completed'] % 10 == 0 or download_progress['completed'] == download_progress['total'] - 1: - DownloadManager._save_progress(output_dir) + if self._progress['completed'] % 10 == 0 or self._progress['completed'] == self._progress['total'] - 1: + self._save_progress(output_dir) return False # Default return if no conditions met except Exception as e: error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg return False # Return False on exception - @staticmethod - def _save_progress(output_dir): - """Save download progress to file""" - global download_progress + def _save_progress(self, output_dir): + """Save download progress to file.""" try: progress_file = os.path.join(output_dir, '.download_progress.json') @@ -412,11 +417,11 @@ class DownloadManager: # Create new progress data progress_data = { - 'processed_models': list(download_progress['processed_models']), - 'refreshed_models': list(download_progress['refreshed_models']), - 'failed_models': list(download_progress['failed_models']), - 'completed': download_progress['completed'], - 'total': download_progress['total'], + 'processed_models': list(self._progress['processed_models']), + 'refreshed_models': list(self._progress['refreshed_models']), + 'failed_models': list(self._progress['failed_models']), + 'completed': self._progress['completed'], + 'total': self._progress['total'], 'last_update': time.time() } @@ -431,102 +436,67 @@ class DownloadManager: except Exception as e: logger.error(f"Failed to save progress file: {e}") - @staticmethod - async def start_force_download(request): - """ - Force download example images for specific models - - Expects a JSON body with: - { - "model_hashes": ["hash1", "hash2", ...], # List of model hashes to download - "optimize": true, # Whether to optimize images (default: true) - "model_types": ["lora", "checkpoint"], # Model types to process (default: both) - "delay": 1.0 # Delay between downloads (default: 1.0) - } - """ - global download_task, is_downloading, download_progress + async def start_force_download(self, options: dict): + """Force download example images for specific models.""" - if is_downloading: - return web.json_response({ - 'success': False, - 'error': 'Download already in progress' - }, status=400) + async with self._state_lock: + if self._is_downloading: + raise DownloadInProgressError(self._progress.snapshot()) - try: - # Parse the request body - data = await request.json() + data = options or {} model_hashes = data.get('model_hashes', []) optimize = data.get('optimize', True) model_types = data.get('model_types', ['lora', 'checkpoint']) - delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds - + delay = float(data.get('delay', 0.2)) + if not model_hashes: - return web.json_response({ - 'success': False, - 'error': 'Missing model_hashes parameter' - }, status=400) - - # Get output directory from settings + raise DownloadConfigurationError('Missing model_hashes parameter') + output_dir = settings.get('example_images_path') - + if not output_dir: - return web.json_response({ - 'success': False, - 'error': 'Example images path not configured in settings' - }, status=400) - - # Create the output directory + raise DownloadConfigurationError('Example images path not configured in settings') + os.makedirs(output_dir, exist_ok=True) - - # Initialize progress tracking - download_progress['total'] = len(model_hashes) - download_progress['completed'] = 0 - download_progress['current_model'] = '' - download_progress['status'] = 'running' - download_progress['errors'] = [] - download_progress['last_error'] = None - download_progress['start_time'] = time.time() - download_progress['end_time'] = None - download_progress['processed_models'] = set() - download_progress['refreshed_models'] = set() - download_progress['failed_models'] = set() - # Set download status to downloading - is_downloading = True + self._progress.reset() + self._progress['total'] = len(model_hashes) + self._progress['status'] = 'running' + self._progress['start_time'] = time.time() + self._progress['end_time'] = None - # Execute the download function directly instead of creating a background task - result = await DownloadManager._download_specific_models_example_images_sync( + self._is_downloading = True + + await self._broadcast_progress(status='running') + + try: + result = await self._download_specific_models_example_images_sync( model_hashes, - output_dir, - optimize, + output_dir, + optimize, model_types, delay ) - # Set download status to not downloading - is_downloading = False + async with self._state_lock: + self._is_downloading = False - return web.json_response({ + return { 'success': True, 'message': 'Force download completed', 'result': result - }) + } except Exception as e: - # Set download status to not downloading - is_downloading = False + async with self._state_lock: + self._is_downloading = False logger.error(f"Failed during forced example images download: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) + await self._broadcast_progress(status='error', extra={'error': str(e)}) + raise ExampleImagesDownloadError(str(e)) from e - @staticmethod - async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay): - """Download example images for specific models only - synchronous version""" - global download_progress - - # Get unified downloader + async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay): + """Download example images for specific models only - synchronous version.""" + downloader = await get_downloader() try: @@ -554,24 +524,18 @@ class DownloadManager: models_to_process.append((scanner_type, model, scanner)) # Update total count based on found models - download_progress['total'] = len(models_to_process) - logger.debug(f"Found {download_progress['total']} models to process") - + self._progress['total'] = len(models_to_process) + logger.debug(f"Found {self._progress['total']} models to process") + # Send initial progress via WebSocket - await ws_manager.broadcast({ - 'type': 'example_images_progress', - 'processed': 0, - 'total': download_progress['total'], - 'status': 'running', - 'current_model': '' - }) + await self._broadcast_progress(status='running') # Process each model success_count = 0 for i, (scanner_type, model, scanner) in enumerate(models_to_process): # Force process this model regardless of previous status - was_successful = await DownloadManager._process_specific_model( - scanner_type, model, scanner, + was_successful = await self._process_specific_model( + scanner_type, model, scanner, output_dir, optimize, downloader ) @@ -579,59 +543,44 @@ class DownloadManager: success_count += 1 # Update progress - download_progress['completed'] += 1 - + self._progress['completed'] += 1 + # Send progress update via WebSocket - await ws_manager.broadcast({ - 'type': 'example_images_progress', - 'processed': download_progress['completed'], - 'total': download_progress['total'], - 'status': 'running', - 'current_model': download_progress['current_model'] - }) + await self._broadcast_progress(status='running') # Only add delay after remote download, and not after processing the last model - if was_successful and i < len(models_to_process) - 1 and download_progress['status'] == 'running': + if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running': await asyncio.sleep(delay) # Mark as completed - download_progress['status'] = 'completed' - download_progress['end_time'] = time.time() - logger.debug(f"Forced example images download completed: {download_progress['completed']}/{download_progress['total']} models processed") - + self._progress['status'] = 'completed' + self._progress['end_time'] = time.time() + logger.debug( + "Forced example images download completed: %s/%s models processed", + self._progress['completed'], + self._progress['total'], + ) + # Send final progress via WebSocket - await ws_manager.broadcast({ - 'type': 'example_images_progress', - 'processed': download_progress['completed'], - 'total': download_progress['total'], - 'status': 'completed', - 'current_model': '' - }) + await self._broadcast_progress(status='completed') return { - 'total': download_progress['total'], - 'processed': download_progress['completed'], + 'total': self._progress['total'], + 'processed': self._progress['completed'], 'successful': success_count, - 'errors': download_progress['errors'] + 'errors': self._progress['errors'] } except Exception as e: error_msg = f"Error during forced example images download: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg - download_progress['status'] = 'error' - download_progress['end_time'] = time.time() - + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg + self._progress['status'] = 'error' + self._progress['end_time'] = time.time() + # Send error status via WebSocket - await ws_manager.broadcast({ - 'type': 'example_images_progress', - 'processed': download_progress['completed'], - 'total': download_progress['total'], - 'status': 'error', - 'error': error_msg, - 'current_model': '' - }) + await self._broadcast_progress(status='error', extra={'error': error_msg}) raise @@ -639,18 +588,16 @@ class DownloadManager: # No need to close any sessions since we use the global downloader pass - @staticmethod - async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, downloader): - """Process a specific model for forced download, ignoring previous download status""" - global download_progress - + async def _process_specific_model(self, scanner_type, model, scanner, output_dir, optimize, downloader): + """Process a specific model for forced download, ignoring previous download status.""" + # Check if download is paused - while download_progress['status'] == 'paused': + while self._progress['status'] == 'paused': await asyncio.sleep(1) # Check if download should continue - if download_progress['status'] != 'running': - logger.info(f"Download stopped: {download_progress['status']}") + if self._progress['status'] != 'running': + logger.info(f"Download stopped: {self._progress['status']}") return False model_hash = model.get('sha256', '').lower() @@ -660,7 +607,8 @@ class DownloadManager: try: # Update current model info - download_progress['current_model'] = f"{model_name} ({model_hash[:8]})" + self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" + await self._broadcast_progress(status='running') # Create model directory model_dir = os.path.join(output_dir, model_hash) @@ -676,7 +624,7 @@ class DownloadManager: await MetadataUpdater.update_metadata_from_local_examples( model_hash, model, scanner_type, scanner, model_dir ) - download_progress['processed_models'].add(model_hash) + self._progress['processed_models'].add(model_hash) return False # Return False to indicate no remote download happened # If no local images, try to download from remote @@ -688,9 +636,9 @@ class DownloadManager: ) # If metadata is stale, try to refresh it - if is_stale and model_hash not in download_progress['refreshed_models']: + if is_stale and model_hash not in self._progress['refreshed_models']: await MetadataUpdater.refresh_model_metadata( - model_hash, model_name, scanner_type, scanner + model_hash, model_name, scanner_type, scanner, self._progress ) # Get the updated model data @@ -708,18 +656,18 @@ class DownloadManager: # Combine failed images from both attempts failed_images.extend(additional_failed_images) - download_progress['refreshed_models'].add(model_hash) + self._progress['refreshed_models'].add(model_hash) # For forced downloads, remove failed images from metadata if failed_images: # Create a copy of images excluding failed ones - await DownloadManager._remove_failed_images_from_metadata( + await self._remove_failed_images_from_metadata( model_hash, model_name, failed_images, scanner ) # Mark as processed if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones - download_progress['processed_models'].add(model_hash) + self._progress['processed_models'].add(model_hash) return True # Return True to indicate a remote download happened else: @@ -729,12 +677,11 @@ class DownloadManager: except Exception as e: error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg + self._progress['errors'].append(error_msg) + self._progress['last_error'] = error_msg return False # Return False on exception - @staticmethod - async def _remove_failed_images_from_metadata(model_hash, model_name, failed_images, scanner): + async def _remove_failed_images_from_metadata(self, model_hash, model_name, failed_images, scanner): """Remove failed images from model metadata""" try: # Get current model data @@ -776,4 +723,55 @@ class DownloadManager: await scanner.update_single_model_cache(file_path, file_path, model_data) except Exception as e: - logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True) \ No newline at end of file + logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True) + + async def _broadcast_progress( + self, + *, + status: str | None = None, + extra: Dict[str, Any] | None = None, + ) -> None: + payload = self._build_progress_payload(status=status, extra=extra) + try: + await self._ws_manager.broadcast(payload) + except Exception as exc: # pragma: no cover - defensive logging + logger.warning("Failed to broadcast example image progress: %s", exc) + + def _build_progress_payload( + self, + *, + status: str | None = None, + extra: Dict[str, Any] | None = None, + ) -> Dict[str, Any]: + payload: Dict[str, Any] = { + 'type': 'example_images_progress', + 'processed': self._progress['completed'], + 'total': self._progress['total'], + 'status': status or self._progress['status'], + 'current_model': self._progress['current_model'], + } + + if self._progress['errors']: + payload['errors'] = list(self._progress['errors']) + if self._progress['last_error']: + payload['last_error'] = self._progress['last_error'] + + if extra: + payload.update(extra) + + return payload + + +_default_download_manager: DownloadManager | None = None + + +def get_default_download_manager(ws_manager) -> DownloadManager: + """Return the singleton download manager used by default routes.""" + + global _default_download_manager + if ( + _default_download_manager is None + or getattr(_default_download_manager, "_ws_manager", None) is not ws_manager + ): + _default_download_manager = DownloadManager(ws_manager=ws_manager) + return _default_download_manager diff --git a/py/utils/example_images_metadata.py b/py/utils/example_images_metadata.py index 71566bff..780eb43b 100644 --- a/py/utils/example_images_metadata.py +++ b/py/utils/example_images_metadata.py @@ -1,19 +1,39 @@ import logging import os import re -from ..utils.metadata_manager import MetadataManager -from ..utils.routes_common import ModelRouteUtils + +from ..recipes.constants import GEN_PARAM_KEYS +from ..services.metadata_service import get_default_metadata_provider, get_metadata_provider +from ..services.metadata_sync_service import MetadataSyncService +from ..services.preview_asset_service import PreviewAssetService +from ..services.settings_manager import settings +from ..services.downloader import get_downloader from ..utils.constants import SUPPORTED_MEDIA_EXTENSIONS from ..utils.exif_utils import ExifUtils -from ..recipes.constants import GEN_PARAM_KEYS +from ..utils.metadata_manager import MetadataManager logger = logging.getLogger(__name__) +_preview_service = PreviewAssetService( + metadata_manager=MetadataManager, + downloader_factory=get_downloader, + exif_utils=ExifUtils, +) + +_metadata_sync_service = MetadataSyncService( + metadata_manager=MetadataManager, + preview_service=_preview_service, + settings=settings, + default_metadata_provider_factory=get_default_metadata_provider, + metadata_provider_selector=get_metadata_provider, +) + + class MetadataUpdater: """Handles updating model metadata related to example images""" @staticmethod - async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner): + async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner, progress: dict | None = None): """Refresh model metadata from CivitAI Args: @@ -25,8 +45,6 @@ class MetadataUpdater: Returns: bool: True if metadata was successfully refreshed, False otherwise """ - from ..utils.example_images_download_manager import download_progress - try: # Find the model in the scanner cache cache = await scanner.get_cached_data() @@ -47,17 +65,17 @@ class MetadataUpdater: return False # Track that we're refreshing this model - download_progress['refreshed_models'].add(model_hash) + if progress is not None: + progress['refreshed_models'].add(model_hash) - # Use ModelRouteUtils to refresh metadata async def update_cache_func(old_path, new_path, metadata): return await scanner.update_single_model_cache(old_path, new_path, metadata) - success, error = await ModelRouteUtils.fetch_and_update_model( - model_hash, - file_path, - model_data, - update_cache_func + success, error = await _metadata_sync_service.fetch_and_update_model( + sha256=model_hash, + file_path=file_path, + model_data=model_data, + update_cache_func=update_cache_func, ) if success: @@ -66,12 +84,13 @@ class MetadataUpdater: else: logger.warning(f"Failed to refresh metadata for {model_name}, {error}") return False - + except Exception as e: error_msg = f"Error refreshing metadata for {model_name}: {str(e)}" logger.error(error_msg, exc_info=True) - download_progress['errors'].append(error_msg) - download_progress['last_error'] = error_msg + if progress is not None: + progress['errors'].append(error_msg) + progress['last_error'] = error_msg return False @staticmethod diff --git a/py/utils/example_images_processor.py b/py/utils/example_images_processor.py index f1cfd2bf..7f108ef9 100644 --- a/py/utils/example_images_processor.py +++ b/py/utils/example_images_processor.py @@ -1,7 +1,6 @@ import logging import os import re -import tempfile import random import string from aiohttp import web @@ -13,6 +12,14 @@ from ..utils.metadata_manager import MetadataManager logger = logging.getLogger(__name__) + +class ExampleImagesImportError(RuntimeError): + """Base error for example image import operations.""" + + +class ExampleImagesValidationError(ExampleImagesImportError): + """Raised when input validation fails.""" + class ExampleImagesProcessor: """Processes and manipulates example images""" @@ -299,90 +306,29 @@ class ExampleImagesProcessor: return False @staticmethod - async def import_images(request): - """ - Import local example images - - Accepts: - - multipart/form-data form with model_hash and files fields - or - - JSON request with model_hash and file_paths - - Returns: - - Success status and list of imported files - """ + async def import_images(model_hash: str, files_to_import: list[str]): + """Import local example images for a model.""" + + if not model_hash: + raise ExampleImagesValidationError('Missing model_hash parameter') + + if not files_to_import: + raise ExampleImagesValidationError('No files provided to import') + try: - model_hash = None - files_to_import = [] - temp_files_to_cleanup = [] - - # Check if it's a multipart form-data request (direct file upload) - if request.content_type and 'multipart/form-data' in request.content_type: - reader = await request.multipart() - - # First get model_hash - field = await reader.next() - if field.name == 'model_hash': - model_hash = await field.text() - - # Then process all files - while True: - field = await reader.next() - if field is None: - break - - if field.name == 'files': - # Create a temporary file with appropriate suffix for type detection - file_name = field.filename - file_ext = os.path.splitext(file_name)[1].lower() - - with tempfile.NamedTemporaryFile(suffix=file_ext, delete=False) as tmp_file: - temp_path = tmp_file.name - temp_files_to_cleanup.append(temp_path) # Track for cleanup - - # Write chunks to the temporary file - while True: - chunk = await field.read_chunk() - if not chunk: - break - tmp_file.write(chunk) - - # Add to the list of files to process - files_to_import.append(temp_path) - else: - # Parse JSON request (legacy method using file paths) - data = await request.json() - model_hash = data.get('model_hash') - files_to_import = data.get('file_paths', []) - - if not model_hash: - return web.json_response({ - 'success': False, - 'error': 'Missing model_hash parameter' - }, status=400) - - if not files_to_import: - return web.json_response({ - 'success': False, - 'error': 'No files provided to import' - }, status=400) - # Get example images path example_images_path = settings.get('example_images_path') if not example_images_path: - return web.json_response({ - 'success': False, - 'error': 'No example images path configured' - }, status=400) - + raise ExampleImagesValidationError('No example images path configured') + # Find the model and get current metadata lora_scanner = await ServiceRegistry.get_lora_scanner() checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner() embedding_scanner = await ServiceRegistry.get_embedding_scanner() - + model_data = None scanner = None - + # Check both scanners to find the model for scan_obj in [lora_scanner, checkpoint_scanner, embedding_scanner]: cache = await scan_obj.get_cached_data() @@ -393,21 +339,20 @@ class ExampleImagesProcessor: break if model_data: break - + if not model_data: - return web.json_response({ - 'success': False, - 'error': f"Model with hash {model_hash} not found in cache" - }, status=404) - + raise ExampleImagesImportError( + f"Model with hash {model_hash} not found in cache" + ) + # Create model folder model_folder = os.path.join(example_images_path, model_hash) os.makedirs(model_folder, exist_ok=True) - + imported_files = [] errors = [] newly_imported_paths = [] - + # Process each file path for file_path in files_to_import: try: @@ -415,26 +360,26 @@ class ExampleImagesProcessor: if not os.path.isfile(file_path): errors.append(f"File not found: {file_path}") continue - + # Check if file type is supported file_ext = os.path.splitext(file_path)[1].lower() - if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or + if not (file_ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or file_ext in SUPPORTED_MEDIA_EXTENSIONS['videos']): errors.append(f"Unsupported file type: {file_path}") continue - + # Generate new filename using short ID instead of UUID short_id = ExampleImagesProcessor.generate_short_id() new_filename = f"custom_{short_id}{file_ext}" - + dest_path = os.path.join(model_folder, new_filename) - + # Copy the file import shutil shutil.copy2(file_path, dest_path) # Store both the dest_path and the short_id newly_imported_paths.append((dest_path, short_id)) - + # Add to imported files list imported_files.append({ 'name': new_filename, @@ -444,39 +389,31 @@ class ExampleImagesProcessor: }) except Exception as e: errors.append(f"Error importing {file_path}: {str(e)}") - + # Update metadata with new example images regular_images, custom_images = await MetadataUpdater.update_metadata_after_import( - model_hash, + model_hash, model_data, scanner, newly_imported_paths ) - - return web.json_response({ + + return { 'success': len(imported_files) > 0, - 'message': f'Successfully imported {len(imported_files)} files' + + 'message': f'Successfully imported {len(imported_files)} files' + (f' with {len(errors)} errors' if errors else ''), 'files': imported_files, 'errors': errors, 'regular_images': regular_images, 'custom_images': custom_images, "model_file_path": model_data.get('file_path', ''), - }) - + } + + except ExampleImagesImportError: + raise except Exception as e: logger.error(f"Failed to import example images: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - finally: - # Clean up temporary files - for temp_file in temp_files_to_cleanup: - try: - os.remove(temp_file) - except Exception as e: - logger.error(f"Failed to remove temporary file {temp_file}: {e}") + raise ExampleImagesImportError(str(e)) from e @staticmethod async def delete_custom_image(request): diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index b5f6af30..642bfcad 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -1,7 +1,7 @@ import os import json import logging -from typing import Dict, List, Callable, Awaitable +from typing import Dict, Callable, Awaitable from aiohttp import web from datetime import datetime @@ -18,7 +18,7 @@ from ..services.settings_manager import settings logger = logging.getLogger(__name__) - +# TODO: retire this class class ModelRouteUtils: """Shared utilities for model routes (LoRAs, Checkpoints, etc.)""" @@ -284,104 +284,6 @@ class ModelRouteUtils: ] return {k: data[k] for k in fields if k in data} - @staticmethod - async def delete_model_files(target_dir: str, file_name: str) -> List[str]: - """Delete model and associated files - - Args: - target_dir: Directory containing the model files - file_name: Base name of the model file without extension - - Returns: - List of deleted file paths - """ - patterns = [ - f"{file_name}.safetensors", # Required - f"{file_name}.metadata.json", - ] - - # Add all preview file extensions - for ext in PREVIEW_EXTENSIONS: - patterns.append(f"{file_name}{ext}") - - deleted = [] - main_file = patterns[0] - main_path = os.path.join(target_dir, main_file).replace(os.sep, '/') - - if os.path.exists(main_path): - # Delete file - os.remove(main_path) - deleted.append(main_path) - else: - logger.warning(f"Model file not found: {main_file}") - - # Delete optional files - for pattern in patterns[1:]: - path = os.path.join(target_dir, pattern) - if os.path.exists(path): - try: - os.remove(path) - deleted.append(pattern) - except Exception as e: - logger.warning(f"Failed to delete {pattern}: {e}") - - return deleted - - @staticmethod - def get_multipart_ext(filename): - """Get extension that may have multiple parts like .metadata.json or .metadata.json.bak""" - parts = filename.split(".") - if len(parts) == 3: # If contains 2-part extension - return "." + ".".join(parts[-2:]) # Take the last two parts, like ".metadata.json" - elif len(parts) >= 4: # If contains 3-part or more extensions - return "." + ".".join(parts[-3:]) # Take the last three parts, like ".metadata.json.bak" - return os.path.splitext(filename)[1] # Otherwise take the regular extension, like ".safetensors" - - # New common endpoint handlers - - @staticmethod - async def handle_delete_model(request: web.Request, scanner) -> web.Response: - """Handle model deletion request - - Args: - request: The aiohttp request - scanner: The model scanner instance with cache management methods - - Returns: - web.Response: The HTTP response - """ - try: - data = await request.json() - file_path = data.get('file_path') - if not file_path: - return web.Response(text='Model path is required', status=400) - - target_dir = os.path.dirname(file_path) - file_name = os.path.splitext(os.path.basename(file_path))[0] - - deleted_files = await ModelRouteUtils.delete_model_files( - target_dir, - file_name - ) - - # Remove from cache - cache = await scanner.get_cached_data() - cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path] - await cache.resort() - - # Update hash index if available - if hasattr(scanner, '_hash_index') and scanner._hash_index: - scanner._hash_index.remove_by_path(file_path) - - return web.json_response({ - 'success': True, - 'deleted_files': deleted_files - }) - - except Exception as e: - logger.error(f"Error deleting model: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - @staticmethod async def handle_fetch_civitai(request: web.Request, scanner) -> web.Response: """Handle CivitAI metadata fetch request @@ -544,64 +446,6 @@ class ModelRouteUtils: logger.error(f"Error replacing preview: {e}", exc_info=True) return web.Response(text=str(e), status=500) - @staticmethod - async def handle_exclude_model(request: web.Request, scanner) -> web.Response: - """Handle model exclusion request - - Args: - request: The aiohttp request - scanner: The model scanner instance with cache management methods - - Returns: - web.Response: The HTTP response - """ - try: - data = await request.json() - file_path = data.get('file_path') - if not file_path: - return web.Response(text='Model path is required', status=400) - - # Update metadata to mark as excluded - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' - metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - metadata['exclude'] = True - - # Save updated metadata - await MetadataManager.save_metadata(file_path, metadata) - - # Update cache - cache = await scanner.get_cached_data() - - # Find and remove model from cache - model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None) - if model_to_remove: - # Update tags count - for tag in model_to_remove.get('tags', []): - if tag in scanner._tags_count: - scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1) - if scanner._tags_count[tag] == 0: - del scanner._tags_count[tag] - - # Remove from hash index if available - if hasattr(scanner, '_hash_index') and scanner._hash_index: - scanner._hash_index.remove_by_path(file_path) - - # Remove from cache data - cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path] - await cache.resort() - - # Add to excluded models list - scanner._excluded_models.append(file_path) - - return web.json_response({ - 'success': True, - 'message': f"Model {os.path.basename(file_path)} excluded" - }) - - except Exception as e: - logger.error(f"Error excluding model: {e}", exc_info=True) - return web.Response(text=str(e), status=500) - @staticmethod async def handle_download_model(request: web.Request) -> web.Response: """Handle model download request""" @@ -755,44 +599,6 @@ class ModelRouteUtils: 'error': str(e) }, status=500) - @staticmethod - async def handle_bulk_delete_models(request: web.Request, scanner) -> web.Response: - """Handle bulk deletion of models - - Args: - request: The aiohttp request - scanner: The model scanner instance with cache management methods - - Returns: - web.Response: The HTTP response - """ - try: - data = await request.json() - file_paths = data.get('file_paths', []) - - if not file_paths: - return web.json_response({ - 'success': False, - 'error': 'No file paths provided for deletion' - }, status=400) - - # Use the scanner's bulk delete method to handle all cache and file operations - result = await scanner.bulk_delete_models(file_paths) - - return web.json_response({ - 'success': result.get('success', False), - 'total_deleted': result.get('total_deleted', 0), - 'total_attempted': result.get('total_attempted', len(file_paths)), - 'results': result.get('results', []) - }) - - except Exception as e: - logger.error(f"Error in bulk delete: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - @staticmethod async def handle_relink_civitai(request: web.Request, scanner) -> web.Response: """Handle CivitAI metadata re-linking request by model ID and/or version ID @@ -948,137 +754,6 @@ class ModelRouteUtils: 'error': str(e) }, status=500) - @staticmethod - async def handle_rename_model(request: web.Request, scanner) -> web.Response: - """Handle renaming a model file and its associated files - - Args: - request: The aiohttp request - scanner: The model scanner instance - - Returns: - web.Response: The HTTP response - """ - try: - data = await request.json() - file_path = data.get('file_path') - new_file_name = data.get('new_file_name') - - if not file_path or not new_file_name: - return web.json_response({ - 'success': False, - 'error': 'File path and new file name are required' - }, status=400) - - # Validate the new file name (no path separators or invalid characters) - invalid_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|'] - if any(char in new_file_name for char in invalid_chars): - return web.json_response({ - 'success': False, - 'error': 'Invalid characters in file name' - }, status=400) - - # Get the directory and current file name - target_dir = os.path.dirname(file_path) - old_file_name = os.path.splitext(os.path.basename(file_path))[0] - - # Check if the target file already exists - new_file_path = os.path.join(target_dir, f"{new_file_name}.safetensors").replace(os.sep, '/') - if os.path.exists(new_file_path): - return web.json_response({ - 'success': False, - 'error': 'A file with this name already exists' - }, status=400) - - # Define the patterns for associated files - patterns = [ - f"{old_file_name}.safetensors", # Required - f"{old_file_name}.metadata.json", - f"{old_file_name}.metadata.json.bak", - ] - - # Add all preview file extensions - for ext in PREVIEW_EXTENSIONS: - patterns.append(f"{old_file_name}{ext}") - - # Find all matching files - existing_files = [] - for pattern in patterns: - path = os.path.join(target_dir, pattern) - if os.path.exists(path): - existing_files.append((path, pattern)) - - # Get the hash from the main file to update hash index - hash_value = None - metadata = None - metadata_path = os.path.join(target_dir, f"{old_file_name}.metadata.json") - - if os.path.exists(metadata_path): - metadata = await ModelRouteUtils.load_local_metadata(metadata_path) - hash_value = metadata.get('sha256') - logger.info(f"hash_value: {hash_value}, metadata_path: {metadata_path}, metadata: {metadata}") - # Rename all files - renamed_files = [] - new_metadata_path = None - new_preview = None - - for old_path, pattern in existing_files: - # Get the file extension like .safetensors or .metadata.json - ext = ModelRouteUtils.get_multipart_ext(pattern) - - # Create the new path - new_path = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/') - - # Rename the file - os.rename(old_path, new_path) - renamed_files.append(new_path) - - # Keep track of metadata path for later update - if ext == '.metadata.json': - new_metadata_path = new_path - - # Update the metadata file with new file name and paths - if new_metadata_path and metadata: - # Update file_name, file_path and preview_url in metadata - metadata['file_name'] = new_file_name - metadata['file_path'] = new_file_path - - # Update preview_url if it exists - if 'preview_url' in metadata and metadata['preview_url']: - old_preview = metadata['preview_url'] - ext = ModelRouteUtils.get_multipart_ext(old_preview) - new_preview = os.path.join(target_dir, f"{new_file_name}{ext}").replace(os.sep, '/') - metadata['preview_url'] = new_preview - - # Save updated metadata - await MetadataManager.save_metadata(new_file_path, metadata) - - # Update the scanner cache - if metadata: - await scanner.update_single_model_cache(file_path, new_file_path, metadata) - - # Update recipe files and cache if hash is available and recipe_scanner exists - if hash_value and hasattr(scanner, 'update_lora_filename_by_hash'): - recipe_scanner = await ServiceRegistry.get_recipe_scanner() - if recipe_scanner: - recipes_updated, cache_updated = await recipe_scanner.update_lora_filename_by_hash(hash_value, new_file_name) - logger.info(f"Updated {recipes_updated} recipe files and {cache_updated} cache entries for renamed model") - - return web.json_response({ - 'success': True, - 'new_file_path': new_file_path, - 'new_preview_path': config.get_preview_static_url(new_preview), - 'renamed_files': renamed_files, - 'reload_required': False - }) - - except Exception as e: - logger.error(f"Error renaming model: {e}", exc_info=True) - return web.json_response({ - 'success': False, - 'error': str(e) - }, status=500) - @staticmethod async def handle_save_metadata(request: web.Request, scanner) -> web.Response: """Handle saving metadata updates diff --git a/pytest.ini b/pytest.ini index 44f4dc04..6f82885c 100644 --- a/pytest.ini +++ b/pytest.ini @@ -4,5 +4,8 @@ testpaths = tests python_files = test_*.py python_classes = Test* python_functions = test_* +# Register async marker for coroutine-style tests +markers = + asyncio: execute test within asyncio event loop # Skip problematic directories to avoid import conflicts norecursedirs = .git .tox dist build *.egg __pycache__ py \ No newline at end of file diff --git a/run_tests.py b/run_tests.py deleted file mode 100644 index af5f96ff..00000000 --- a/run_tests.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 -""" -Test runner script for ComfyUI-Lora-Manager. - -This script runs pytest from the tests directory to avoid import issues -with the root __init__.py file. -""" -import subprocess -import sys -import os -from pathlib import Path - -# Set environment variable to indicate standalone mode -# HF_HUB_DISABLE_TELEMETRY is from ComfyUI main.py -standalone_mode = os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0" - -def main(): - """Run pytest from the tests directory to avoid import issues.""" - # Get the script directory - script_dir = Path(__file__).parent.absolute() - tests_dir = script_dir / "tests" - - if not tests_dir.exists(): - print(f"Error: Tests directory not found at {tests_dir}") - return 1 - - # Change to tests directory - original_cwd = os.getcwd() - os.chdir(tests_dir) - - try: - # Build pytest command - cmd = [ - sys.executable, "-m", "pytest", - "-v", - "--rootdir=.", - ] + sys.argv[1:] # Pass any additional arguments - - print(f"Running: {' '.join(cmd)}") - print(f"Working directory: {tests_dir}") - - # Run pytest - result = subprocess.run(cmd, cwd=tests_dir) - return result.returncode - finally: - # Restore original working directory - os.chdir(original_cwd) - -if __name__ == "__main__": - sys.exit(main()) diff --git a/standalone.py b/standalone.py index a6259851..95c45ca7 100644 --- a/standalone.py +++ b/standalone.py @@ -421,7 +421,7 @@ class StandaloneLoraManager(LoraManager): RecipeRoutes.setup_routes(app) UpdateRoutes.setup_routes(app) MiscRoutes.setup_routes(app) - ExampleImagesRoutes.setup_routes(app) + ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager) # Setup WebSocket routes that are shared across all model types app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) diff --git a/tests/conftest.py b/tests/conftest.py index dfe99691..58263c8a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ import types from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Sequence +import asyncio +import inspect from unittest import mock import sys @@ -39,6 +41,35 @@ nodes_mock.NODE_CLASS_MAPPINGS = {} sys.modules['nodes'] = nodes_mock +def pytest_pyfunc_call(pyfuncitem): + """Allow bare async tests to run without pytest.mark.asyncio.""" + test_function = pyfuncitem.function + if inspect.iscoroutinefunction(test_function): + func = pyfuncitem.obj + signature = inspect.signature(func) + accepted_kwargs: Dict[str, Any] = {} + for name, parameter in signature.parameters.items(): + if parameter.kind is inspect.Parameter.VAR_POSITIONAL: + continue + if parameter.kind is inspect.Parameter.VAR_KEYWORD: + accepted_kwargs = dict(pyfuncitem.funcargs) + break + if name in pyfuncitem.funcargs: + accepted_kwargs[name] = pyfuncitem.funcargs[name] + + original_policy = asyncio.get_event_loop_policy() + policy = pyfuncitem.funcargs.get("event_loop_policy") + if policy is not None and policy is not original_policy: + asyncio.set_event_loop_policy(policy) + try: + asyncio.run(func(**accepted_kwargs)) + finally: + if policy is not None and policy is not original_policy: + asyncio.set_event_loop_policy(original_policy) + return True + return None + + @dataclass class MockHashIndex: """Minimal hash index stub mirroring the scanner contract.""" @@ -50,7 +81,7 @@ class MockHashIndex: class MockCache: - """Cache object with the attributes consumed by ``ModelRouteUtils``.""" + """Cache object with the attributes.""" def __init__(self, items: Optional[Sequence[Dict[str, Any]]] = None): self.raw_data: List[Dict[str, Any]] = list(items or []) @@ -58,7 +89,7 @@ class MockCache: async def resort(self) -> None: self.resort_calls += 1 - # ``ModelRouteUtils`` expects the coroutine interface but does not + # expects the coroutine interface but does not # rely on the return value. @@ -187,3 +218,5 @@ def mock_scanner(mock_cache: MockCache, mock_hash_index: MockHashIndex) -> MockS @pytest.fixture def mock_service(mock_scanner: MockScanner) -> MockModelService: return MockModelService(scanner=mock_scanner) + + diff --git a/tests/routes/test_base_model_routes_smoke.py b/tests/routes/test_base_model_routes_smoke.py index 2b9ed805..136bd0a8 100644 --- a/tests/routes/test_base_model_routes_smoke.py +++ b/tests/routes/test_base_model_routes_smoke.py @@ -28,6 +28,7 @@ spec.loader.exec_module(py_local) sys.modules.setdefault("py_local", py_local) from py_local.routes.base_model_routes import BaseModelRoutes +from py_local.services.model_file_service import AutoOrganizeResult from py_local.services.service_registry import ServiceRegistry from py_local.services.websocket_manager import ws_manager from py_local.utils.routes_common import ExifUtils @@ -66,12 +67,24 @@ def download_manager_stub(): class FakeDownloadManager: def __init__(self): self.calls = [] + self.error = None + self.cancelled = [] + self.active_downloads = {} async def download_from_civitai(self, **kwargs): self.calls.append(kwargs) + if self.error is not None: + raise self.error await kwargs["progress_callback"](42) return {"success": True, "path": "/tmp/model.safetensors"} + async def cancel_download(self, download_id): + self.cancelled.append(download_id) + return {"success": True, "download_id": download_id} + + async def get_active_downloads(self): + return self.active_downloads + stub = FakeDownloadManager() previous = ServiceRegistry._services.get("download_manager") asyncio.run(ServiceRegistry.register_service("download_manager", stub)) @@ -103,6 +116,21 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner): asyncio.run(scenario()) +def test_routes_return_service_not_ready_when_unattached(): + async def scenario(): + client = await create_test_client(None) + try: + response = await client.get("/api/lm/test-models/list") + payload = await response.json() + + assert response.status == 503 + assert payload == {"success": False, "error": "Service not ready"} + finally: + await client.close() + + asyncio.run(scenario()) + + def test_delete_model_updates_cache_and_hash_index(mock_service, mock_scanner, tmp_path: Path): model_path = tmp_path / "sample.safetensors" model_path.write_bytes(b"model") @@ -222,6 +250,69 @@ def test_download_model_invokes_download_manager( asyncio.run(scenario()) +def test_download_model_requires_identifier(mock_service, download_manager_stub): + async def scenario(): + client = await create_test_client(mock_service) + try: + response = await client.post( + "/api/lm/download-model", + json={"model_root": "/tmp"}, + ) + payload = await response.json() + + assert response.status == 400 + assert payload["success"] is False + assert "Missing required" in payload["error"] + finally: + await client.close() + + asyncio.run(scenario()) + + +def test_download_model_maps_validation_errors(mock_service, download_manager_stub): + download_manager_stub.error = ValueError("Invalid relative path") + + async def scenario(): + client = await create_test_client(mock_service) + try: + response = await client.post( + "/api/lm/download-model", + json={"model_version_id": 123}, + ) + payload = await response.json() + + assert response.status == 400 + assert payload == {"success": False, "error": "Invalid relative path"} + assert ws_manager._download_progress == {} + finally: + await client.close() + + asyncio.run(scenario()) + + +def test_download_model_maps_early_access_errors(mock_service, download_manager_stub): + download_manager_stub.error = RuntimeError("401 early access") + + async def scenario(): + client = await create_test_client(mock_service) + try: + response = await client.post( + "/api/lm/download-model", + json={"model_id": 4}, + ) + payload = await response.json() + + assert response.status == 401 + assert payload == { + "success": False, + "error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.", + } + finally: + await client.close() + + asyncio.run(scenario()) + + def test_auto_organize_progress_returns_latest_snapshot(mock_service): async def scenario(): client = await create_test_client(mock_service) @@ -235,5 +326,65 @@ def test_auto_organize_progress_returns_latest_snapshot(mock_service): assert payload == {"success": True, "progress": {"status": "processing", "percent": 50}} finally: await client.close() + + asyncio.run(scenario()) + + +def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.MonkeyPatch): + async def fake_auto_organize(self, file_paths=None, progress_callback=None): + result = AutoOrganizeResult() + result.total = 1 + result.processed = 1 + result.success_count = 1 + result.skipped_count = 0 + result.failure_count = 0 + result.operation_type = "bulk" + if progress_callback is not None: + await progress_callback.on_progress({"type": "auto_organize_progress", "status": "started"}) + await progress_callback.on_progress({"type": "auto_organize_progress", "status": "completed"}) + return result + + monkeypatch.setattr( + py_local.services.model_file_service.ModelFileService, + "auto_organize_models", + fake_auto_organize, + ) + + async def scenario(): + client = await create_test_client(mock_service) + try: + response = await client.post("/api/lm/test-models/auto-organize", json={"file_paths": []}) + payload = await response.json() + + assert response.status == 200 + assert payload["success"] is True + + progress = ws_manager.get_auto_organize_progress() + assert progress is not None + assert progress["status"] == "completed" + finally: + await client.close() + + asyncio.run(scenario()) + + +def test_auto_organize_conflict_when_running(mock_service): + async def scenario(): + client = await create_test_client(mock_service) + try: + await ws_manager.broadcast_auto_organize_progress( + {"type": "auto_organize_progress", "status": "started"} + ) + + response = await client.post("/api/lm/test-models/auto-organize") + payload = await response.json() + + assert response.status == 409 + assert payload == { + "success": False, + "error": "Auto-organize is already running. Please wait for it to complete.", + } + finally: + await client.close() asyncio.run(scenario()) diff --git a/tests/routes/test_example_images_routes.py b/tests/routes/test_example_images_routes.py new file mode 100644 index 00000000..9a316499 --- /dev/null +++ b/tests/routes/test_example_images_routes.py @@ -0,0 +1,431 @@ +from __future__ import annotations + +import json +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +from aiohttp import web +from aiohttp.test_utils import TestClient, TestServer +import pytest + +from py.routes.example_images_route_registrar import ROUTE_DEFINITIONS +from py.routes.example_images_routes import ExampleImagesRoutes +from py.routes.handlers.example_images_handlers import ( + ExampleImagesDownloadHandler, + ExampleImagesFileHandler, + ExampleImagesHandlerSet, + ExampleImagesManagementHandler, +) + + +@dataclass +class ExampleImagesHarness: + """Container exposing the aiohttp client and stubbed collaborators.""" + + client: TestClient + download_manager: "StubDownloadManager" + processor: "StubExampleImagesProcessor" + file_manager: "StubExampleImagesFileManager" + controller: ExampleImagesRoutes + + +class StubDownloadManager: + def __init__(self) -> None: + self.calls: List[Tuple[str, Any]] = [] + + async def start_download(self, payload: Any) -> dict: + self.calls.append(("start_download", payload)) + return {"operation": "start_download", "payload": payload} + + async def get_status(self, request: web.Request) -> dict: + self.calls.append(("get_status", dict(request.query))) + return {"operation": "get_status"} + + async def pause_download(self, request: web.Request) -> dict: + self.calls.append(("pause_download", None)) + return {"operation": "pause_download"} + + async def resume_download(self, request: web.Request) -> dict: + self.calls.append(("resume_download", None)) + return {"operation": "resume_download"} + + async def start_force_download(self, payload: Any) -> dict: + self.calls.append(("start_force_download", payload)) + return {"operation": "start_force_download", "payload": payload} + + +class StubExampleImagesProcessor: + def __init__(self) -> None: + self.calls: List[Tuple[str, Any]] = [] + + async def import_images(self, model_hash: str, files: List[str]) -> dict: + payload = {"model_hash": model_hash, "file_paths": files} + self.calls.append(("import_images", payload)) + return {"operation": "import_images", "payload": payload} + + async def delete_custom_image(self, request: web.Request) -> web.StreamResponse: + payload = await request.json() + self.calls.append(("delete_custom_image", payload)) + return web.json_response({"operation": "delete_custom_image", "payload": payload}) + + +class StubExampleImagesFileManager: + def __init__(self) -> None: + self.calls: List[Tuple[str, Any]] = [] + + async def open_folder(self, request: web.Request) -> web.StreamResponse: + payload = await request.json() + self.calls.append(("open_folder", payload)) + return web.json_response({"operation": "open_folder", "payload": payload}) + + async def get_files(self, request: web.Request) -> web.StreamResponse: + self.calls.append(("get_files", dict(request.query))) + return web.json_response({"operation": "get_files", "query": dict(request.query)}) + + async def has_images(self, request: web.Request) -> web.StreamResponse: + self.calls.append(("has_images", dict(request.query))) + return web.json_response({"operation": "has_images", "query": dict(request.query)}) + + +class StubWebSocketManager: + def __init__(self) -> None: + self.broadcast_calls: List[Dict[str, Any]] = [] + + async def broadcast(self, payload: Dict[str, Any]) -> None: + self.broadcast_calls.append(payload) + + +@asynccontextmanager +async def example_images_app() -> ExampleImagesHarness: + """Yield an ExampleImagesRoutes app wired with stubbed collaborators.""" + + download_manager = StubDownloadManager() + processor = StubExampleImagesProcessor() + file_manager = StubExampleImagesFileManager() + ws_manager = StubWebSocketManager() + + controller = ExampleImagesRoutes( + ws_manager=ws_manager, + download_manager=download_manager, + processor=processor, + file_manager=file_manager, + ) + + app = web.Application() + controller.register(app) + + server = TestServer(app) + client = TestClient(server) + await client.start_server() + + try: + yield ExampleImagesHarness( + client=client, + download_manager=download_manager, + processor=processor, + file_manager=file_manager, + controller=controller, + ) + finally: + await client.close() + + +async def test_setup_routes_registers_all_definitions(): + async with example_images_app() as harness: + registered = { + (route.method, route.resource.canonical) + for route in harness.client.app.router.routes() + if route.resource.canonical + } + + expected = {(definition.method, definition.path) for definition in ROUTE_DEFINITIONS} + + assert expected <= registered + + +@pytest.mark.parametrize( + "endpoint, payload", + [ + ("/api/lm/download-example-images", {"model_types": ["lora"], "optimize": False}), + ("/api/lm/force-download-example-images", {"model_hashes": ["abc123"]}), + ], +) +async def test_download_routes_delegate_to_manager(endpoint, payload): + async with example_images_app() as harness: + response = await harness.client.post(endpoint, json=payload) + body = await response.json() + + assert response.status == 200 + assert body["payload"] == payload + assert body["operation"].startswith("start") + + expected_call = body["operation"], payload + assert expected_call in harness.download_manager.calls + + +async def test_status_route_returns_manager_payload(): + async with example_images_app() as harness: + response = await harness.client.get( + "/api/lm/example-images-status", params={"detail": "true"} + ) + body = await response.json() + + assert response.status == 200 + assert body == {"operation": "get_status"} + assert harness.download_manager.calls == [("get_status", {"detail": "true"})] + + +async def test_pause_and_resume_routes_delegate(): + async with example_images_app() as harness: + pause_response = await harness.client.post("/api/lm/pause-example-images") + resume_response = await harness.client.post("/api/lm/resume-example-images") + + assert pause_response.status == 200 + assert await pause_response.json() == {"operation": "pause_download"} + assert resume_response.status == 200 + assert await resume_response.json() == {"operation": "resume_download"} + + assert harness.download_manager.calls[-2:] == [ + ("pause_download", None), + ("resume_download", None), + ] + + +async def test_import_route_delegates_to_processor(): + payload = {"model_hash": "abc123", "file_paths": ["/path/image.png"]} + async with example_images_app() as harness: + response = await harness.client.post( + "/api/lm/import-example-images", json=payload + ) + body = await response.json() + + assert response.status == 200 + assert body == {"operation": "import_images", "payload": payload} + expected_call = ("import_images", payload) + assert expected_call in harness.processor.calls + + +async def test_delete_route_delegates_to_processor(): + payload = {"model_hash": "abc123", "short_id": "xyz"} + async with example_images_app() as harness: + response = await harness.client.post( + "/api/lm/delete-example-image", json=payload + ) + body = await response.json() + + assert response.status == 200 + assert body == {"operation": "delete_custom_image", "payload": payload} + assert harness.processor.calls == [("delete_custom_image", payload)] + + +async def test_file_routes_delegate_to_file_manager(): + open_payload = {"model_hash": "abc123"} + files_params = {"model_hash": "def456"} + + async with example_images_app() as harness: + open_response = await harness.client.post( + "/api/lm/open-example-images-folder", json=open_payload + ) + files_response = await harness.client.get( + "/api/lm/example-image-files", params=files_params + ) + has_response = await harness.client.get( + "/api/lm/has-example-images", params=files_params + ) + + assert open_response.status == 200 + assert files_response.status == 200 + assert has_response.status == 200 + + assert await open_response.json() == {"operation": "open_folder", "payload": open_payload} + assert await files_response.json() == { + "operation": "get_files", + "query": files_params, + } + assert await has_response.json() == { + "operation": "has_images", + "query": files_params, + } + + assert harness.file_manager.calls == [ + ("open_folder", open_payload), + ("get_files", files_params), + ("has_images", files_params), + ] + + +@pytest.mark.asyncio +async def test_download_handler_methods_delegate() -> None: + class Recorder: + def __init__(self) -> None: + self.calls: List[Tuple[str, Any]] = [] + + async def get_status(self, request) -> dict: + self.calls.append(("get_status", request)) + return {"status": "ok"} + + async def pause_download(self, request) -> dict: + self.calls.append(("pause_download", request)) + return {"status": "paused"} + + async def resume_download(self, request) -> dict: + self.calls.append(("resume_download", request)) + return {"status": "running"} + + async def start_force_download(self, payload) -> dict: + self.calls.append(("start_force_download", payload)) + return {"status": "force", "payload": payload} + + class StubDownloadUseCase: + def __init__(self) -> None: + self.payloads: List[Any] = [] + + async def execute(self, payload: dict) -> dict: + self.payloads.append(payload) + return {"status": "started", "payload": payload} + + class DummyRequest: + def __init__(self, payload: dict) -> None: + self._payload = payload + self.query = {} + + async def json(self) -> dict: + return self._payload + + recorder = Recorder() + use_case = StubDownloadUseCase() + handler = ExampleImagesDownloadHandler(use_case, recorder) + request = DummyRequest({"foo": "bar"}) + + download_response = await handler.download_example_images(request) + assert json.loads(download_response.text) == {"status": "started", "payload": {"foo": "bar"}} + status_response = await handler.get_example_images_status(request) + assert json.loads(status_response.text) == {"status": "ok"} + pause_response = await handler.pause_example_images(request) + assert json.loads(pause_response.text) == {"status": "paused"} + resume_response = await handler.resume_example_images(request) + assert json.loads(resume_response.text) == {"status": "running"} + force_response = await handler.force_download_example_images(request) + assert json.loads(force_response.text) == {"status": "force", "payload": {"foo": "bar"}} + + assert use_case.payloads == [{"foo": "bar"}] + assert recorder.calls == [ + ("get_status", request), + ("pause_download", request), + ("resume_download", request), + ("start_force_download", {"foo": "bar"}), + ] + + +@pytest.mark.asyncio +async def test_management_handler_methods_delegate() -> None: + class StubImportUseCase: + def __init__(self) -> None: + self.requests: List[Any] = [] + + async def execute(self, request: Any) -> dict: + self.requests.append(request) + return {"status": "imported"} + + class Recorder: + def __init__(self) -> None: + self.calls: List[Tuple[str, Any]] = [] + + async def delete_custom_image(self, request) -> str: + self.calls.append(("delete_custom_image", request)) + return "delete" + + recorder = Recorder() + use_case = StubImportUseCase() + handler = ExampleImagesManagementHandler(use_case, recorder) + request = object() + + import_response = await handler.import_example_images(request) + assert json.loads(import_response.text) == {"status": "imported"} + assert await handler.delete_example_image(request) == "delete" + assert use_case.requests == [request] + assert recorder.calls == [("delete_custom_image", request)] + + +@pytest.mark.asyncio +async def test_file_handler_methods_delegate() -> None: + class Recorder: + def __init__(self) -> None: + self.calls: List[Tuple[str, Any]] = [] + + async def open_folder(self, request) -> str: + self.calls.append(("open_folder", request)) + return "open" + + async def get_files(self, request) -> str: + self.calls.append(("get_files", request)) + return "files" + + async def has_images(self, request) -> str: + self.calls.append(("has_images", request)) + return "has" + + recorder = Recorder() + handler = ExampleImagesFileHandler(recorder) + request = object() + + assert await handler.open_example_images_folder(request) == "open" + assert await handler.get_example_image_files(request) == "files" + assert await handler.has_example_images(request) == "has" + assert recorder.calls == [ + ("open_folder", request), + ("get_files", request), + ("has_images", request), + ] + + +def test_handler_set_route_mapping_includes_all_handlers() -> None: + class DummyUseCase: + async def execute(self, payload): + return payload + + class DummyManager: + async def get_status(self, request): + return {} + + async def pause_download(self, request): + return {} + + async def resume_download(self, request): + return {} + + async def start_force_download(self, payload): + return payload + + class DummyProcessor: + async def delete_custom_image(self, request): + return {} + + download = ExampleImagesDownloadHandler(DummyUseCase(), DummyManager()) + management = ExampleImagesManagementHandler(DummyUseCase(), DummyProcessor()) + files = ExampleImagesFileHandler(object()) + handler_set = ExampleImagesHandlerSet( + download=download, + management=management, + files=files, + ) + + mapping = handler_set.to_route_mapping() + + expected_keys = { + "download_example_images", + "get_example_images_status", + "pause_example_images", + "resume_example_images", + "force_download_example_images", + "import_example_images", + "delete_example_image", + "open_example_images_folder", + "get_example_image_files", + "has_example_images", + } + + assert mapping.keys() == expected_keys + for key in expected_keys: + assert callable(mapping[key]) diff --git a/tests/routes/test_recipe_route_scaffolding.py b/tests/routes/test_recipe_route_scaffolding.py new file mode 100644 index 00000000..59765d36 --- /dev/null +++ b/tests/routes/test_recipe_route_scaffolding.py @@ -0,0 +1,236 @@ +"""Smoke tests for the recipe routing scaffolding. + +The cases keep the registrar/controller contract aligned with +``docs/architecture/recipe_routes.md`` so future refactors can focus on handler +logic. +""" + +from __future__ import annotations + +import asyncio +import importlib.util +import sys +import types +from collections import Counter +from pathlib import Path +from typing import Any, Awaitable, Callable, Dict + +import pytest +from aiohttp import web + + +REPO_ROOT = Path(__file__).resolve().parents[2] +PY_PACKAGE_PATH = REPO_ROOT / "py" + +spec = importlib.util.spec_from_file_location( + "py_local", + PY_PACKAGE_PATH / "__init__.py", + submodule_search_locations=[str(PY_PACKAGE_PATH)], +) +py_local = importlib.util.module_from_spec(spec) +assert spec.loader is not None +spec.loader.exec_module(py_local) +sys.modules.setdefault("py_local", py_local) + +base_routes_module = importlib.import_module("py_local.routes.base_recipe_routes") +recipe_routes_module = importlib.import_module("py_local.routes.recipe_routes") +registrar_module = importlib.import_module("py_local.routes.recipe_route_registrar") + + +@pytest.fixture(autouse=True) +def reset_service_registry(monkeypatch: pytest.MonkeyPatch): + """Ensure each test starts from a clean registry state.""" + + services_module = importlib.import_module("py_local.services.service_registry") + registry = services_module.ServiceRegistry + previous_services = dict(registry._services) + previous_locks = dict(registry._locks) + registry._services.clear() + registry._locks.clear() + try: + yield + finally: + registry._services = previous_services + registry._locks = previous_locks + + +def _make_stub_scanner(): + class _StubScanner: + def __init__(self): + self._cache = types.SimpleNamespace() + + async def _lora_get_cached_data(): # pragma: no cover - smoke hook + return None + + self._lora_scanner = types.SimpleNamespace( + get_cached_data=_lora_get_cached_data, + _hash_index=types.SimpleNamespace(_hash_to_path={}), + ) + + async def get_cached_data(self, force_refresh: bool = False): + return self._cache + + return _StubScanner() + + +def test_attach_dependencies_resolves_services_once(monkeypatch: pytest.MonkeyPatch): + base_module = base_routes_module + services_module = importlib.import_module("py_local.services.service_registry") + registry = services_module.ServiceRegistry + server_i18n = importlib.import_module("py_local.services.server_i18n").server_i18n + + scanner = _make_stub_scanner() + civitai_client = object() + filter_calls = Counter() + + async def fake_get_recipe_scanner(): + return scanner + + async def fake_get_civitai_client(): + return civitai_client + + def fake_create_filter(): + filter_calls["create_filter"] += 1 + return object() + + monkeypatch.setattr(registry, "get_recipe_scanner", fake_get_recipe_scanner) + monkeypatch.setattr(registry, "get_civitai_client", fake_get_civitai_client) + monkeypatch.setattr(server_i18n, "create_template_filter", fake_create_filter) + + async def scenario(): + routes = base_module.BaseRecipeRoutes() + + await routes.attach_dependencies() + await routes.attach_dependencies() # idempotent + + assert routes.recipe_scanner is scanner + assert routes.lora_scanner is scanner._lora_scanner + assert routes.civitai_client is civitai_client + assert routes.template_env.filters["t"] is not None + assert filter_calls["create_filter"] == 1 + + asyncio.run(scenario()) + + +def test_register_startup_hooks_appends_once(): + routes = base_routes_module.BaseRecipeRoutes() + + app = web.Application() + routes.register_startup_hooks(app) + routes.register_startup_hooks(app) + + startup_bound_to_routes = [ + callback for callback in app.on_startup if getattr(callback, "__self__", None) is routes + ] + + assert routes.attach_dependencies in startup_bound_to_routes + assert routes.prewarm_cache in startup_bound_to_routes + assert len(startup_bound_to_routes) == 2 + + +def test_to_route_mapping_uses_handler_set(): + class DummyHandlerSet: + def __init__(self): + self.calls = 0 + + def to_route_mapping(self): + self.calls += 1 + + async def render_page(request): # pragma: no cover - simple coroutine + return web.Response(text="ok") + + return {"render_page": render_page} + + class DummyRoutes(base_routes_module.BaseRecipeRoutes): + def __init__(self): + super().__init__() + self.created = 0 + + def _create_handler_set(self): # noqa: D401 - simple override for test + self.created += 1 + return DummyHandlerSet() + + routes = DummyRoutes() + mapping = routes.to_route_mapping() + + assert set(mapping.keys()) == {"render_page"} + assert asyncio.iscoroutinefunction(mapping["render_page"]) + # Cached mapping reused on subsequent calls + assert routes.to_route_mapping() is mapping + # Handler set cached for get_handler_owner callers + assert isinstance(routes.get_handler_owner(), DummyHandlerSet) + assert routes.created == 1 + + +def test_recipe_route_registrar_binds_every_route(): + class FakeRouter: + def __init__(self): + self.calls: list[tuple[str, str, Callable[..., Awaitable[Any]]]] = [] + + def add_get(self, path, handler): + self.calls.append(("GET", path, handler)) + + def add_post(self, path, handler): + self.calls.append(("POST", path, handler)) + + def add_put(self, path, handler): + self.calls.append(("PUT", path, handler)) + + def add_delete(self, path, handler): + self.calls.append(("DELETE", path, handler)) + + class FakeApp: + def __init__(self): + self.router = FakeRouter() + + app = FakeApp() + registrar = registrar_module.RecipeRouteRegistrar(app) + + handler_mapping = { + definition.handler_name: object() + for definition in registrar_module.ROUTE_DEFINITIONS + } + + registrar.register_routes(handler_mapping) + + assert { + (method, path) + for method, path, _ in app.router.calls + } == {(d.method, d.path) for d in registrar_module.ROUTE_DEFINITIONS} + + +def test_recipe_routes_setup_routes_uses_registrar(monkeypatch: pytest.MonkeyPatch): + registered_mappings: list[Dict[str, Callable[..., Awaitable[Any]]]] = [] + + class DummyRegistrar: + def __init__(self, app): + self.app = app + + def register_routes(self, mapping): + registered_mappings.append(mapping) + + monkeypatch.setattr(recipe_routes_module, "RecipeRouteRegistrar", DummyRegistrar) + + expected_mapping = {name: object() for name in ("render_page", "list_recipes")} + + def fake_to_route_mapping(self): + return expected_mapping + + monkeypatch.setattr(base_routes_module.BaseRecipeRoutes, "to_route_mapping", fake_to_route_mapping) + monkeypatch.setattr( + base_routes_module.BaseRecipeRoutes, + "_HANDLER_NAMES", + tuple(expected_mapping.keys()), + ) + + app = web.Application() + recipe_routes_module.RecipeRoutes.setup_routes(app) + + assert registered_mappings == [expected_mapping] + recipe_callbacks = { + cb + for cb in app.on_startup + if isinstance(getattr(cb, "__self__", None), recipe_routes_module.RecipeRoutes) + } + assert {type(cb.__self__) for cb in recipe_callbacks} == {recipe_routes_module.RecipeRoutes} + assert {cb.__name__ for cb in recipe_callbacks} == {"attach_dependencies", "prewarm_cache"} diff --git a/tests/routes/test_recipe_routes.py b/tests/routes/test_recipe_routes.py new file mode 100644 index 00000000..467cb5b5 --- /dev/null +++ b/tests/routes/test_recipe_routes.py @@ -0,0 +1,330 @@ +"""Integration smoke tests for the recipe route stack.""" +from __future__ import annotations + +import json +from contextlib import asynccontextmanager +from dataclasses import dataclass +from pathlib import Path +from types import SimpleNamespace +from typing import Any, AsyncIterator, Dict, List, Optional + +from aiohttp import FormData, web +from aiohttp.test_utils import TestClient, TestServer + +from py.config import config +from py.routes import base_recipe_routes +from py.routes.recipe_routes import RecipeRoutes +from py.services.recipes import RecipeValidationError +from py.services.service_registry import ServiceRegistry + + +@dataclass +class RecipeRouteHarness: + """Container exposing the aiohttp client and stubbed collaborators.""" + + client: TestClient + scanner: "StubRecipeScanner" + analysis: "StubAnalysisService" + persistence: "StubPersistenceService" + sharing: "StubSharingService" + tmp_dir: Path + + +class StubRecipeScanner: + """Minimal scanner double with the surface used by the handlers.""" + + def __init__(self, base_dir: Path) -> None: + self.recipes_dir = str(base_dir / "recipes") + self.listing_items: List[Dict[str, Any]] = [] + self.cached_raw: List[Dict[str, Any]] = [] + self.recipes: Dict[str, Dict[str, Any]] = {} + self.removed: List[str] = [] + + async def _noop_get_cached_data(force_refresh: bool = False) -> None: # noqa: ARG001 - signature mirrors real scanner + return None + + self._lora_scanner = SimpleNamespace( # mimic BaseRecipeRoutes expectations + get_cached_data=_noop_get_cached_data, + _hash_index=SimpleNamespace(_hash_to_path={}), + ) + + async def get_cached_data(self, force_refresh: bool = False) -> SimpleNamespace: # noqa: ARG002 - flag unused by stub + return SimpleNamespace(raw_data=list(self.cached_raw)) + + async def get_paginated_data(self, **params: Any) -> Dict[str, Any]: + items = [dict(item) for item in self.listing_items] + page = int(params.get("page", 1)) + page_size = int(params.get("page_size", 20)) + return { + "items": items, + "total": len(items), + "page": page, + "page_size": page_size, + "total_pages": max(1, (len(items) + page_size - 1) // max(page_size, 1)), + } + + async def get_recipe_by_id(self, recipe_id: str) -> Optional[Dict[str, Any]]: + return self.recipes.get(recipe_id) + + async def remove_recipe(self, recipe_id: str) -> None: + self.removed.append(recipe_id) + self.recipes.pop(recipe_id, None) + + +class StubAnalysisService: + """Captures calls made by analysis routes while returning canned responses.""" + + instances: List["StubAnalysisService"] = [] + + def __init__(self, **_: Any) -> None: + self.raise_for_uploaded: Optional[Exception] = None + self.raise_for_remote: Optional[Exception] = None + self.raise_for_local: Optional[Exception] = None + self.upload_calls: List[bytes] = [] + self.remote_calls: List[Optional[str]] = [] + self.local_calls: List[Optional[str]] = [] + self.result = SimpleNamespace(payload={"loras": []}, status=200) + StubAnalysisService.instances.append(self) + + async def analyze_uploaded_image(self, *, image_bytes: bytes | None, recipe_scanner) -> SimpleNamespace: # noqa: D401 - mirrors real signature + if self.raise_for_uploaded: + raise self.raise_for_uploaded + self.upload_calls.append(image_bytes or b"") + return self.result + + async def analyze_remote_image(self, *, url: Optional[str], recipe_scanner, civitai_client) -> SimpleNamespace: # noqa: D401 + if self.raise_for_remote: + raise self.raise_for_remote + self.remote_calls.append(url) + return self.result + + async def analyze_local_image(self, *, file_path: Optional[str], recipe_scanner) -> SimpleNamespace: # noqa: D401 + if self.raise_for_local: + raise self.raise_for_local + self.local_calls.append(file_path) + return self.result + + async def analyze_widget_metadata(self, *, recipe_scanner) -> SimpleNamespace: + return SimpleNamespace(payload={"metadata": {}, "image_bytes": b""}, status=200) + + +class StubPersistenceService: + """Stub for persistence operations to avoid filesystem writes.""" + + instances: List["StubPersistenceService"] = [] + + def __init__(self, **_: Any) -> None: + self.save_calls: List[Dict[str, Any]] = [] + self.delete_calls: List[str] = [] + self.save_result = SimpleNamespace(payload={"success": True, "recipe_id": "stub-id"}, status=200) + self.delete_result = SimpleNamespace(payload={"success": True}, status=200) + StubPersistenceService.instances.append(self) + + async def save_recipe(self, *, recipe_scanner, image_bytes, image_base64, name, tags, metadata) -> SimpleNamespace: # noqa: D401 + self.save_calls.append( + { + "recipe_scanner": recipe_scanner, + "image_bytes": image_bytes, + "image_base64": image_base64, + "name": name, + "tags": list(tags), + "metadata": metadata, + } + ) + return self.save_result + + async def delete_recipe(self, *, recipe_scanner, recipe_id: str) -> SimpleNamespace: + self.delete_calls.append(recipe_id) + await recipe_scanner.remove_recipe(recipe_id) + return self.delete_result + + async def update_recipe(self, *, recipe_scanner, recipe_id: str, updates: Dict[str, Any]) -> SimpleNamespace: # pragma: no cover - unused by smoke tests + return SimpleNamespace(payload={"success": True, "recipe_id": recipe_id, "updates": updates}, status=200) + + async def reconnect_lora(self, *, recipe_scanner, recipe_id: str, lora_index: int, target_name: str) -> SimpleNamespace: # pragma: no cover + return SimpleNamespace(payload={"success": True}, status=200) + + async def bulk_delete(self, *, recipe_scanner, recipe_ids: List[str]) -> SimpleNamespace: # pragma: no cover + return SimpleNamespace(payload={"success": True, "deleted": recipe_ids}, status=200) + + async def save_recipe_from_widget(self, *, recipe_scanner, metadata: Dict[str, Any], image_bytes: bytes) -> SimpleNamespace: # pragma: no cover + return SimpleNamespace(payload={"success": True}, status=200) + + +class StubSharingService: + """Share service stub recording requests and returning canned responses.""" + + instances: List["StubSharingService"] = [] + + def __init__(self, *, ttl_seconds: int = 300, logger) -> None: # noqa: ARG002 - ttl unused in stub + self.share_calls: List[str] = [] + self.download_calls: List[str] = [] + self.share_result = SimpleNamespace( + payload={"success": True, "download_url": "/share/stub", "filename": "recipe.png"}, + status=200, + ) + self.download_info = SimpleNamespace(file_path="", download_filename="") + StubSharingService.instances.append(self) + + async def share_recipe(self, *, recipe_scanner, recipe_id: str) -> SimpleNamespace: + self.share_calls.append(recipe_id) + return self.share_result + + async def prepare_download(self, *, recipe_scanner, recipe_id: str) -> SimpleNamespace: + self.download_calls.append(recipe_id) + return self.download_info + + +@asynccontextmanager +async def recipe_harness(monkeypatch, tmp_path: Path) -> AsyncIterator[RecipeRouteHarness]: + """Context manager that yields a fully wired recipe route harness.""" + + StubAnalysisService.instances.clear() + StubPersistenceService.instances.clear() + StubSharingService.instances.clear() + + scanner = StubRecipeScanner(tmp_path) + + async def fake_get_recipe_scanner(): + return scanner + + async def fake_get_civitai_client(): + return object() + + monkeypatch.setattr(ServiceRegistry, "get_recipe_scanner", fake_get_recipe_scanner) + monkeypatch.setattr(ServiceRegistry, "get_civitai_client", fake_get_civitai_client) + monkeypatch.setattr(base_recipe_routes, "RecipeAnalysisService", StubAnalysisService) + monkeypatch.setattr(base_recipe_routes, "RecipePersistenceService", StubPersistenceService) + monkeypatch.setattr(base_recipe_routes, "RecipeSharingService", StubSharingService) + monkeypatch.setattr(config, "loras_roots", [str(tmp_path)], raising=False) + + app = web.Application() + RecipeRoutes.setup_routes(app) + + server = TestServer(app) + client = TestClient(server) + await client.start_server() + + harness = RecipeRouteHarness( + client=client, + scanner=scanner, + analysis=StubAnalysisService.instances[-1], + persistence=StubPersistenceService.instances[-1], + sharing=StubSharingService.instances[-1], + tmp_dir=tmp_path, + ) + + try: + yield harness + finally: + await client.close() + StubAnalysisService.instances.clear() + StubPersistenceService.instances.clear() + StubSharingService.instances.clear() + + +async def test_list_recipes_provides_file_urls(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + recipe_path = harness.tmp_dir / "recipes" / "demo.png" + harness.scanner.listing_items = [ + { + "id": "recipe-1", + "file_path": str(recipe_path), + "title": "Demo", + "loras": [], + } + ] + harness.scanner.cached_raw = list(harness.scanner.listing_items) + + response = await harness.client.get("/api/lm/recipes") + payload = await response.json() + + assert response.status == 200 + assert payload["items"][0]["file_url"].endswith("demo.png") + assert payload["items"][0]["loras"] == [] + + +async def test_save_and_delete_recipe_round_trip(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + form = FormData() + form.add_field("image", b"stub", filename="sample.png", content_type="image/png") + form.add_field("name", "Test Recipe") + form.add_field("tags", json.dumps(["tag-a"])) + form.add_field("metadata", json.dumps({"loras": []})) + form.add_field("image_base64", "aW1hZ2U=") + + harness.persistence.save_result = SimpleNamespace( + payload={"success": True, "recipe_id": "saved-id"}, + status=201, + ) + + save_response = await harness.client.post("/api/lm/recipes/save", data=form) + save_payload = await save_response.json() + + assert save_response.status == 201 + assert save_payload["recipe_id"] == "saved-id" + assert harness.persistence.save_calls[-1]["name"] == "Test Recipe" + + harness.persistence.delete_result = SimpleNamespace(payload={"success": True}, status=200) + + delete_response = await harness.client.delete("/api/lm/recipe/saved-id") + delete_payload = await delete_response.json() + + assert delete_response.status == 200 + assert delete_payload["success"] is True + assert harness.persistence.delete_calls == ["saved-id"] + + +async def test_analyze_uploaded_image_error_path(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + harness.analysis.raise_for_uploaded = RecipeValidationError("No image data provided") + + form = FormData() + form.add_field("image", b"", filename="empty.png", content_type="image/png") + + response = await harness.client.post("/api/lm/recipes/analyze-image", data=form) + payload = await response.json() + + assert response.status == 400 + assert payload["error"] == "No image data provided" + assert payload["loras"] == [] + + +async def test_share_and_download_recipe(monkeypatch, tmp_path: Path) -> None: + async with recipe_harness(monkeypatch, tmp_path) as harness: + recipe_id = "share-me" + download_path = harness.tmp_dir / "recipes" / "share.png" + download_path.parent.mkdir(parents=True, exist_ok=True) + download_path.write_bytes(b"stub") + + harness.scanner.recipes[recipe_id] = { + "id": recipe_id, + "title": "Shared", + "file_path": str(download_path), + } + + harness.sharing.share_result = SimpleNamespace( + payload={"success": True, "download_url": "/api/share", "filename": "share.png"}, + status=200, + ) + harness.sharing.download_info = SimpleNamespace( + file_path=str(download_path), + download_filename="share.png", + ) + + share_response = await harness.client.get(f"/api/lm/recipe/{recipe_id}/share") + share_payload = await share_response.json() + + assert share_response.status == 200 + assert share_payload["filename"] == "share.png" + assert harness.sharing.share_calls == [recipe_id] + + download_response = await harness.client.get(f"/api/lm/recipe/{recipe_id}/share/download") + body = await download_response.read() + + assert download_response.status == 200 + assert download_response.headers["Content-Disposition"] == 'attachment; filename="share.png"' + assert body == b"stub" + + download_path.unlink(missing_ok=True) + diff --git a/tests/services/test_base_model_service.py b/tests/services/test_base_model_service.py new file mode 100644 index 00000000..fc28a54e --- /dev/null +++ b/tests/services/test_base_model_service.py @@ -0,0 +1,296 @@ +import pytest + +import importlib +import importlib.util +import sys +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[2] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + + +def import_from(module_name: str): + existing = sys.modules.get("py") + if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"): + sys.modules.pop("py", None) + spec = importlib.util.spec_from_file_location("py", ROOT / "py/__init__.py") + module = importlib.util.module_from_spec(spec) + assert spec and spec.loader + spec.loader.exec_module(module) # type: ignore[union-attr] + module.__path__ = [str(ROOT / "py")] + sys.modules["py"] = module + return importlib.import_module(module_name) + + +BaseModelService = import_from("py.services.base_model_service").BaseModelService +model_query_module = import_from("py.services.model_query") +ModelCacheRepository = model_query_module.ModelCacheRepository +ModelFilterSet = model_query_module.ModelFilterSet +SearchStrategy = model_query_module.SearchStrategy +SortParams = model_query_module.SortParams +BaseModelMetadata = import_from("py.utils.models").BaseModelMetadata + + +class StubSettings: + def __init__(self, values): + self._values = dict(values) + + def get(self, key, default=None): + return self._values.get(key, default) + + +class DummyService(BaseModelService): + async def format_response(self, model_data): + return model_data + + +class StubRepository: + def __init__(self, data): + self._data = list(data) + self.parse_sort_calls = [] + self.fetch_sorted_calls = [] + + def parse_sort(self, sort_by): + params = ModelCacheRepository.parse_sort(sort_by) + self.parse_sort_calls.append(sort_by) + return params + + async def fetch_sorted(self, params): + self.fetch_sorted_calls.append(params) + return list(self._data) + + +class StubFilterSet: + def __init__(self, result): + self.result = list(result) + self.calls = [] + + def apply(self, data, criteria): + self.calls.append((list(data), criteria)) + return list(self.result) + + +class StubSearchStrategy: + def __init__(self, search_result): + self.search_result = list(search_result) + self.normalize_calls = [] + self.apply_calls = [] + + def normalize_options(self, options): + self.normalize_calls.append(options) + normalized = {"recursive": True} + if options: + normalized.update(options) + return normalized + + def apply(self, data, search_term, options, fuzzy): + self.apply_calls.append((list(data), search_term, options, fuzzy)) + return list(self.search_result) + + +@pytest.mark.asyncio +async def test_get_paginated_data_uses_injected_collaborators(): + data = [ + {"model_name": "Alpha", "folder": "root"}, + {"model_name": "Beta", "folder": "root"}, + ] + repository = StubRepository(data) + filter_set = StubFilterSet([{"model_name": "Filtered"}]) + search_strategy = StubSearchStrategy([{"model_name": "SearchResult"}]) + settings = StubSettings({}) + + service = DummyService( + model_type="stub", + scanner=object(), + metadata_class=BaseModelMetadata, + cache_repository=repository, + filter_set=filter_set, + search_strategy=search_strategy, + settings_provider=settings, + ) + + response = await service.get_paginated_data( + page=1, + page_size=5, + sort_by="name:desc", + folder="root", + search="query", + fuzzy_search=True, + base_models=["base"], + tags=["tag"], + search_options={"recursive": False}, + favorites_only=True, + ) + + assert repository.parse_sort_calls == ["name:desc"] + assert repository.fetch_sorted_calls and isinstance(repository.fetch_sorted_calls[0], SortParams) + sort_params = repository.fetch_sorted_calls[0] + assert sort_params.key == "name" and sort_params.order == "desc" + + assert filter_set.calls, "FilterSet should be invoked" + call_data, criteria = filter_set.calls[0] + assert call_data == data + assert criteria.folder == "root" + assert criteria.base_models == ["base"] + assert criteria.tags == ["tag"] + assert criteria.favorites_only is True + assert criteria.search_options.get("recursive") is False + + assert search_strategy.normalize_calls == [{"recursive": False}, {"recursive": False}] + assert search_strategy.apply_calls == [([{"model_name": "Filtered"}], "query", {"recursive": False}, True)] + + assert response["items"] == search_strategy.search_result + assert response["total"] == len(search_strategy.search_result) + assert response["page"] == 1 + assert response["page_size"] == 5 + + +class FakeCache: + def __init__(self, items): + self.items = list(items) + + async def get_sorted_data(self, sort_key, order): + if sort_key == "name": + data = sorted(self.items, key=lambda x: x["model_name"].lower()) + if order == "desc": + data.reverse() + else: + data = list(self.items) + return data + + +class FakeScanner: + def __init__(self, cache): + self._cache = cache + + async def get_cached_data(self, *_, **__): + return self._cache + + +@pytest.mark.asyncio +async def test_get_paginated_data_filters_and_searches_combination(): + items = [ + { + "model_name": "Alpha", + "file_name": "alpha.safetensors", + "folder": "root/sub", + "tags": ["tag1"], + "base_model": "v1", + "favorite": True, + "preview_nsfw_level": 0, + }, + { + "model_name": "Beta", + "file_name": "beta.safetensors", + "folder": "root", + "tags": ["tag2"], + "base_model": "v2", + "favorite": False, + "preview_nsfw_level": 999, + }, + { + "model_name": "Gamma", + "file_name": "gamma.safetensors", + "folder": "root/sub2", + "tags": ["tag1", "tag3"], + "base_model": "v1", + "favorite": True, + "preview_nsfw_level": 0, + "civitai": {"creator": {"username": "artist"}}, + }, + ] + + cache = FakeCache(items) + scanner = FakeScanner(cache) + settings = StubSettings({"show_only_sfw": True}) + + service = DummyService( + model_type="stub", + scanner=scanner, + metadata_class=BaseModelMetadata, + cache_repository=ModelCacheRepository(scanner), + filter_set=ModelFilterSet(settings), + search_strategy=SearchStrategy(), + settings_provider=settings, + ) + + response = await service.get_paginated_data( + page=1, + page_size=1, + sort_by="name:asc", + folder="root", + search="artist", + base_models=["v1"], + tags=["tag1"], + search_options={"creator": True, "tags": True}, + favorites_only=True, + ) + + assert response["items"] == [items[2]] + assert response["total"] == 1 + assert response["page"] == 1 + assert response["page_size"] == 1 + assert response["total_pages"] == 1 + + +class PassThroughFilterSet: + def __init__(self): + self.calls = [] + + def apply(self, data, criteria): + self.calls.append(criteria) + return list(data) + + +class NoSearchStrategy: + def __init__(self): + self.normalize_calls = [] + self.apply_called = False + + def normalize_options(self, options): + self.normalize_calls.append(options) + return {"recursive": True} + + def apply(self, *args, **kwargs): + self.apply_called = True + pytest.fail("Search should not be invoked when no search term is provided") + + +@pytest.mark.asyncio +async def test_get_paginated_data_paginates_without_search(): + items = [ + {"model_name": name, "folder": "root"} + for name in ["Alpha", "Beta", "Gamma", "Delta", "Epsilon"] + ] + + repository = StubRepository(items) + filter_set = PassThroughFilterSet() + search_strategy = NoSearchStrategy() + settings = StubSettings({}) + + service = DummyService( + model_type="stub", + scanner=object(), + metadata_class=BaseModelMetadata, + cache_repository=repository, + filter_set=filter_set, + search_strategy=search_strategy, + settings_provider=settings, + ) + + response = await service.get_paginated_data( + page=2, + page_size=2, + sort_by="name:asc", + ) + + assert repository.parse_sort_calls == ["name:asc"] + assert len(repository.fetch_sorted_calls) == 1 + assert filter_set.calls and filter_set.calls[0].favorites_only is False + assert search_strategy.apply_called is False + assert response["items"] == items[2:4] + assert response["total"] == len(items) + assert response["page"] == 2 + assert response["page_size"] == 2 + assert response["total_pages"] == 3 diff --git a/tests/services/test_example_images_download_manager_async.py b/tests/services/test_example_images_download_manager_async.py new file mode 100644 index 00000000..7eef56fb --- /dev/null +++ b/tests/services/test_example_images_download_manager_async.py @@ -0,0 +1,228 @@ +from __future__ import annotations + +import asyncio +from types import SimpleNamespace + +import pytest + +from py.services.settings_manager import settings +from py.utils import example_images_download_manager as download_module + + +class RecordingWebSocketManager: + """Collects broadcast payloads for assertions.""" + + def __init__(self) -> None: + self.payloads: list[dict] = [] + + async def broadcast(self, payload: dict) -> None: + self.payloads.append(payload) + + +class StubScanner: + """Scanner double returning predetermined cache contents.""" + + def __init__(self, models: list[dict]) -> None: + self._cache = SimpleNamespace(raw_data=models) + + async def get_cached_data(self): + return self._cache + + +def _patch_scanner(monkeypatch: pytest.MonkeyPatch, scanner: StubScanner) -> None: + async def _get_lora_scanner(cls): + return scanner + + monkeypatch.setattr( + download_module.ServiceRegistry, + "get_lora_scanner", + classmethod(_get_lora_scanner), + ) + + +@pytest.mark.usefixtures("tmp_path") +async def test_start_download_rejects_parallel_runs(monkeypatch: pytest.MonkeyPatch, tmp_path): + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path)) + + model = { + "sha256": "abc123", + "model_name": "Example", + "file_path": str(tmp_path / "example.safetensors"), + "file_name": "example.safetensors", + } + _patch_scanner(monkeypatch, StubScanner([model])) + + started = asyncio.Event() + release = asyncio.Event() + + async def fake_process_local_examples(*_args, **_kwargs): + started.set() + await release.wait() + return True + + async def fake_update_metadata(*_args, **_kwargs): + return True + + async def fake_get_downloader(): + return object() + + monkeypatch.setattr( + download_module.ExampleImagesProcessor, + "process_local_examples", + staticmethod(fake_process_local_examples), + ) + monkeypatch.setattr( + download_module.MetadataUpdater, + "update_metadata_from_local_examples", + staticmethod(fake_update_metadata), + ) + monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader) + + try: + result = await manager.start_download({"model_types": ["lora"], "delay": 0}) + assert result["success"] is True + + await asyncio.wait_for(started.wait(), timeout=1) + + with pytest.raises(download_module.DownloadInProgressError) as exc: + await manager.start_download({"model_types": ["lora"], "delay": 0}) + + snapshot = exc.value.progress_snapshot + assert snapshot["status"] == "running" + assert snapshot["current_model"] == "Example (abc123)" + + statuses = [payload["status"] for payload in ws_manager.payloads] + assert "running" in statuses + + finally: + release.set() + if manager._download_task is not None: + await asyncio.wait_for(manager._download_task, timeout=1) + + +@pytest.mark.usefixtures("tmp_path") +async def test_pause_resume_blocks_processing(monkeypatch: pytest.MonkeyPatch, tmp_path): + ws_manager = RecordingWebSocketManager() + manager = download_module.DownloadManager(ws_manager=ws_manager) + + monkeypatch.setitem(settings.settings, "example_images_path", str(tmp_path)) + + models = [ + { + "sha256": "hash-one", + "model_name": "Model One", + "file_path": str(tmp_path / "model-one.safetensors"), + "file_name": "model-one.safetensors", + "civitai": {"images": [{"url": "https://example.com/one.png"}]}, + }, + { + "sha256": "hash-two", + "model_name": "Model Two", + "file_path": str(tmp_path / "model-two.safetensors"), + "file_name": "model-two.safetensors", + "civitai": {"images": [{"url": "https://example.com/two.png"}]}, + }, + ] + _patch_scanner(monkeypatch, StubScanner(models)) + + async def fake_process_local_examples(*_args, **_kwargs): + return False + + async def fake_update_metadata(*_args, **_kwargs): + return True + + first_call_started = asyncio.Event() + first_release = asyncio.Event() + second_call_started = asyncio.Event() + call_order: list[str] = [] + + async def fake_download_model_images(model_hash, *_args, **_kwargs): + call_order.append(model_hash) + if len(call_order) == 1: + first_call_started.set() + await first_release.wait() + else: + second_call_started.set() + return True, False + + async def fake_get_downloader(): + class _Downloader: + async def download_to_memory(self, *_a, **_kw): + return True, b"", {} + + return _Downloader() + + monkeypatch.setattr( + download_module.ExampleImagesProcessor, + "process_local_examples", + staticmethod(fake_process_local_examples), + ) + monkeypatch.setattr( + download_module.MetadataUpdater, + "update_metadata_from_local_examples", + staticmethod(fake_update_metadata), + ) + monkeypatch.setattr( + download_module.ExampleImagesProcessor, + "download_model_images", + staticmethod(fake_download_model_images), + ) + monkeypatch.setattr(download_module, "get_downloader", fake_get_downloader) + + original_sleep = download_module.asyncio.sleep + pause_gate = asyncio.Event() + resume_gate = asyncio.Event() + + async def fake_sleep(delay: float): + if delay == 1: + pause_gate.set() + await resume_gate.wait() + else: + await original_sleep(delay) + + monkeypatch.setattr(download_module.asyncio, "sleep", fake_sleep) + + try: + await manager.start_download({"model_types": ["lora"], "delay": 0}) + + await asyncio.wait_for(first_call_started.wait(), timeout=1) + + await manager.pause_download({}) + + first_release.set() + + await asyncio.wait_for(pause_gate.wait(), timeout=1) + assert manager._progress["status"] == "paused" + assert not second_call_started.is_set() + + statuses = [payload["status"] for payload in ws_manager.payloads] + paused_index = statuses.index("paused") + + await asyncio.sleep(0) + assert not second_call_started.is_set() + + await manager.resume_download({}) + resume_gate.set() + + await asyncio.wait_for(second_call_started.wait(), timeout=1) + + if manager._download_task is not None: + await asyncio.wait_for(manager._download_task, timeout=1) + + statuses_after = [payload["status"] for payload in ws_manager.payloads] + running_after = next( + i for i, status in enumerate(statuses_after[paused_index + 1 :], start=paused_index + 1) if status == "running" + ) + assert running_after > paused_index + assert "completed" in statuses_after[running_after:] + assert call_order == ["hash-one", "hash-two"] + + finally: + first_release.set() + resume_gate.set() + if manager._download_task is not None: + await asyncio.wait_for(manager._download_task, timeout=1) + monkeypatch.setattr(download_module.asyncio, "sleep", original_sleep) diff --git a/tests/services/test_recipe_scanner.py b/tests/services/test_recipe_scanner.py new file mode 100644 index 00000000..63c18f25 --- /dev/null +++ b/tests/services/test_recipe_scanner.py @@ -0,0 +1,185 @@ +import asyncio +import json +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from py.config import config +from py.services.recipe_scanner import RecipeScanner +from py.utils.utils import calculate_recipe_fingerprint + + +class StubHashIndex: + def __init__(self) -> None: + self._hash_to_path: dict[str, str] = {} + + def get_path(self, hash_value: str) -> str | None: + return self._hash_to_path.get(hash_value) + + +class StubLoraScanner: + def __init__(self) -> None: + self._hash_index = StubHashIndex() + self._hash_meta: dict[str, dict[str, str]] = {} + self._models_by_name: dict[str, dict] = {} + self._cache = SimpleNamespace(raw_data=[]) + + async def get_cached_data(self): + return self._cache + + def has_hash(self, hash_value: str) -> bool: + return hash_value.lower() in self._hash_meta + + def get_preview_url_by_hash(self, hash_value: str) -> str: + meta = self._hash_meta.get(hash_value.lower()) + return meta.get("preview_url", "") if meta else "" + + def get_path_by_hash(self, hash_value: str) -> str | None: + meta = self._hash_meta.get(hash_value.lower()) + return meta.get("path") if meta else None + + async def get_model_info_by_name(self, name: str): + return self._models_by_name.get(name) + + def register_model(self, name: str, info: dict) -> None: + self._models_by_name[name] = info + hash_value = (info.get("sha256") or "").lower() + if hash_value: + self._hash_meta[hash_value] = { + "path": info.get("file_path", ""), + "preview_url": info.get("preview_url", ""), + } + self._hash_index._hash_to_path[hash_value] = info.get("file_path", "") + self._cache.raw_data.append({ + "sha256": info.get("sha256", ""), + "path": info.get("file_path", ""), + "civitai": info.get("civitai", {}), + }) + + +@pytest.fixture +def recipe_scanner(tmp_path: Path, monkeypatch): + RecipeScanner._instance = None + monkeypatch.setattr(config, "loras_roots", [str(tmp_path)]) + stub = StubLoraScanner() + scanner = RecipeScanner(lora_scanner=stub) + asyncio.run(scanner.refresh_cache(force=True)) + yield scanner, stub + RecipeScanner._instance = None + + +async def test_add_recipe_during_concurrent_reads(recipe_scanner): + scanner, _ = recipe_scanner + + initial_recipe = { + "id": "one", + "file_path": "path/a.png", + "title": "First", + "modified": 1.0, + "created_date": 1.0, + "loras": [], + } + await scanner.add_recipe(initial_recipe) + + new_recipe = { + "id": "two", + "file_path": "path/b.png", + "title": "Second", + "modified": 2.0, + "created_date": 2.0, + "loras": [], + } + + async def reader_task(): + for _ in range(5): + cache = await scanner.get_cached_data() + _ = [item["id"] for item in cache.raw_data] + await asyncio.sleep(0) + + await asyncio.gather(reader_task(), reader_task(), scanner.add_recipe(new_recipe)) + await asyncio.sleep(0) + cache = await scanner.get_cached_data() + + assert {item["id"] for item in cache.raw_data} == {"one", "two"} + assert len(cache.sorted_by_name) == len(cache.raw_data) + + +async def test_remove_recipe_during_reads(recipe_scanner): + scanner, _ = recipe_scanner + + recipe_ids = ["alpha", "beta", "gamma"] + for index, recipe_id in enumerate(recipe_ids): + await scanner.add_recipe({ + "id": recipe_id, + "file_path": f"path/{recipe_id}.png", + "title": recipe_id, + "modified": float(index), + "created_date": float(index), + "loras": [], + }) + + async def reader_task(): + for _ in range(5): + cache = await scanner.get_cached_data() + _ = list(cache.sorted_by_date) + await asyncio.sleep(0) + + await asyncio.gather(reader_task(), scanner.remove_recipe("beta")) + await asyncio.sleep(0) + cache = await scanner.get_cached_data() + + assert {item["id"] for item in cache.raw_data} == {"alpha", "gamma"} + + +async def test_update_lora_entry_updates_cache_and_file(tmp_path: Path, recipe_scanner): + scanner, stub = recipe_scanner + recipes_dir = Path(config.loras_roots[0]) / "recipes" + recipes_dir.mkdir(parents=True, exist_ok=True) + + recipe_id = "recipe-1" + recipe_path = recipes_dir / f"{recipe_id}.recipe.json" + recipe_data = { + "id": recipe_id, + "file_path": str(tmp_path / "image.png"), + "title": "Original", + "modified": 0.0, + "created_date": 0.0, + "loras": [ + {"file_name": "old", "strength": 1.0, "hash": "", "isDeleted": True, "exclude": True}, + ], + } + recipe_path.write_text(json.dumps(recipe_data)) + + await scanner.add_recipe(dict(recipe_data)) + + target_hash = "abc123" + target_info = { + "sha256": target_hash, + "file_path": str(tmp_path / "loras" / "target.safetensors"), + "preview_url": "preview.png", + "civitai": {"id": 42, "name": "v1", "model": {"name": "Target"}}, + } + stub.register_model("target", target_info) + + updated_recipe, updated_lora = await scanner.update_lora_entry( + recipe_id, + 0, + target_name="target", + target_lora=target_info, + ) + + assert updated_lora["inLibrary"] is True + assert updated_lora["localPath"] == target_info["file_path"] + assert updated_lora["hash"] == target_hash + + with recipe_path.open("r", encoding="utf-8") as file_obj: + persisted = json.load(file_obj) + + expected_fingerprint = calculate_recipe_fingerprint(persisted["loras"]) + assert persisted["fingerprint"] == expected_fingerprint + + cache = await scanner.get_cached_data() + cached_recipe = next(item for item in cache.raw_data if item["id"] == recipe_id) + assert cached_recipe["loras"][0]["hash"] == target_hash + assert cached_recipe["fingerprint"] == expected_fingerprint diff --git a/tests/services/test_recipe_services.py b/tests/services/test_recipe_services.py new file mode 100644 index 00000000..81a15424 --- /dev/null +++ b/tests/services/test_recipe_services.py @@ -0,0 +1,150 @@ +import logging +import os +from types import SimpleNamespace + +import pytest + +from py.services.recipes.analysis_service import RecipeAnalysisService +from py.services.recipes.errors import RecipeDownloadError, RecipeNotFoundError +from py.services.recipes.persistence_service import RecipePersistenceService + + +class DummyExifUtils: + def optimize_image(self, image_data, target_width, format, quality, preserve_metadata): + return image_data, ".webp" + + def append_recipe_metadata(self, image_path, recipe_data): + self.appended = (image_path, recipe_data) + + def extract_image_metadata(self, path): + return {} + + +@pytest.mark.asyncio +async def test_analyze_remote_image_download_failure_cleans_temp(tmp_path, monkeypatch): + exif_utils = DummyExifUtils() + + class DummyFactory: + def create_parser(self, metadata): + return None + + async def downloader_factory(): + class Downloader: + async def download_file(self, url, path, use_auth=False): + return False, "failure" + + return Downloader() + + service = RecipeAnalysisService( + exif_utils=exif_utils, + recipe_parser_factory=DummyFactory(), + downloader_factory=downloader_factory, + metadata_collector=None, + metadata_processor_cls=None, + metadata_registry_cls=None, + standalone_mode=False, + logger=logging.getLogger("test"), + ) + + temp_path = tmp_path / "temp.jpg" + + def create_temp_path(): + temp_path.write_bytes(b"") + return str(temp_path) + + monkeypatch.setattr(service, "_create_temp_path", create_temp_path) + + with pytest.raises(RecipeDownloadError): + await service.analyze_remote_image( + url="https://example.com/image.jpg", + recipe_scanner=SimpleNamespace(), + civitai_client=SimpleNamespace(), + ) + + assert not temp_path.exists(), "temporary file should be cleaned after failure" + + +@pytest.mark.asyncio +async def test_analyze_local_image_missing_file(tmp_path): + async def downloader_factory(): + return SimpleNamespace() + + service = RecipeAnalysisService( + exif_utils=DummyExifUtils(), + recipe_parser_factory=SimpleNamespace(create_parser=lambda metadata: None), + downloader_factory=downloader_factory, + metadata_collector=None, + metadata_processor_cls=None, + metadata_registry_cls=None, + standalone_mode=False, + logger=logging.getLogger("test"), + ) + + with pytest.raises(RecipeNotFoundError): + await service.analyze_local_image( + file_path=str(tmp_path / "missing.png"), + recipe_scanner=SimpleNamespace(), + ) + + +@pytest.mark.asyncio +async def test_save_recipe_reports_duplicates(tmp_path): + exif_utils = DummyExifUtils() + + class DummyCache: + def __init__(self): + self.raw_data = [] + + async def resort(self): + pass + + class DummyScanner: + def __init__(self, root): + self.recipes_dir = str(root) + self._cache = DummyCache() + self.last_fingerprint = None + + async def find_recipes_by_fingerprint(self, fingerprint): + self.last_fingerprint = fingerprint + return ["existing"] + + async def add_recipe(self, recipe_data): + self._cache.raw_data.append(recipe_data) + await self._cache.resort() + + scanner = DummyScanner(tmp_path) + service = RecipePersistenceService( + exif_utils=exif_utils, + card_preview_width=512, + logger=logging.getLogger("test"), + ) + + metadata = { + "base_model": "sd", + "loras": [ + { + "file_name": "sample", + "hash": "abc123", + "weight": 0.5, + "id": 1, + "name": "Sample", + "version": "v1", + "isDeleted": False, + "exclude": False, + } + ], + } + + result = await service.save_recipe( + recipe_scanner=scanner, + image_bytes=b"image-bytes", + image_base64=None, + name="My Recipe", + tags=["tag"], + metadata=metadata, + ) + + assert result.payload["matching_recipes"] == ["existing"] + assert scanner.last_fingerprint is not None + assert os.path.exists(result.payload["json_path"]) + assert scanner._cache.raw_data diff --git a/tests/services/test_route_support_services.py b/tests/services/test_route_support_services.py new file mode 100644 index 00000000..978438c3 --- /dev/null +++ b/tests/services/test_route_support_services.py @@ -0,0 +1,273 @@ +import asyncio +import json +import os +import sys +from pathlib import Path +from typing import Any, Dict, List + +ROOT = Path(__file__).resolve().parents[2] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +import importlib +import importlib.util + +import pytest + + +def import_from(module_name: str): + existing = sys.modules.get("py") + if existing is None or getattr(existing, "__file__", "") != str(ROOT / "py/__init__.py"): + sys.modules.pop("py", None) + spec = importlib.util.spec_from_file_location("py", ROOT / "py/__init__.py") + module = importlib.util.module_from_spec(spec) + assert spec and spec.loader + spec.loader.exec_module(module) # type: ignore[union-attr] + module.__path__ = [str(ROOT / "py")] + sys.modules["py"] = module + return importlib.import_module(module_name) + + +DownloadCoordinator = import_from("py.services.download_coordinator").DownloadCoordinator +MetadataSyncService = import_from("py.services.metadata_sync_service").MetadataSyncService +PreviewAssetService = import_from("py.services.preview_asset_service").PreviewAssetService +TagUpdateService = import_from("py.services.tag_update_service").TagUpdateService + + +class DummySettings: + def __init__(self, values: Dict[str, Any] | None = None) -> None: + self._values = values or {} + + def get(self, key: str, default: Any = None) -> Any: + return self._values.get(key, default) + + +class RecordingMetadataManager: + def __init__(self) -> None: + self.saved: List[tuple[str, Dict[str, Any]]] = [] + + async def save_metadata(self, path: str, metadata: Dict[str, Any]) -> bool: + self.saved.append((path, json.loads(json.dumps(metadata)))) + metadata_path = path if path.endswith(".metadata.json") else f"{os.path.splitext(path)[0]}.metadata.json" + Path(metadata_path).write_text(json.dumps(metadata)) + return True + + +class RecordingPreviewService: + def __init__(self) -> None: + self.calls: List[tuple[str, List[Dict[str, Any]]]] = [] + + async def ensure_preview_for_metadata( + self, metadata_path: str, local_metadata: Dict[str, Any], images + ) -> None: + self.calls.append((metadata_path, list(images or []))) + local_metadata["preview_url"] = "preview.webp" + local_metadata["preview_nsfw_level"] = 1 + + +class DummyProvider: + def __init__(self, payload: Dict[str, Any]) -> None: + self.payload = payload + + async def get_model_by_hash(self, sha256: str): + return self.payload, None + + async def get_model_version(self, model_id: int, model_version_id: int | None): + return self.payload + + +class FakeExifUtils: + @staticmethod + def optimize_image(**kwargs): + return kwargs["image_data"], {} + + +def test_metadata_sync_merges_remote_fields(tmp_path: Path) -> None: + manager = RecordingMetadataManager() + preview = RecordingPreviewService() + provider = DummyProvider({ + "baseModel": "SD15", + "model": {"name": "Merged", "description": "desc", "tags": ["tag"], "creator": {"username": "user"}}, + "trainedWords": ["word"], + "images": [{"url": "http://example", "nsfwLevel": 2, "type": "image"}], + }) + + service = MetadataSyncService( + metadata_manager=manager, + preview_service=preview, + settings=DummySettings(), + default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider), + metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider), + ) + + metadata_path = str(tmp_path / "model.metadata.json") + local_metadata = {"civitai": {"trainedWords": ["existing"]}} + + updated = asyncio.run(service.update_model_metadata(metadata_path, local_metadata, provider.payload)) + + assert updated["model_name"] == "Merged" + assert updated["modelDescription"] == "desc" + assert set(updated["civitai"]["trainedWords"]) == {"existing", "word"} + assert manager.saved + assert preview.calls + + +def test_metadata_sync_fetch_and_update_updates_cache(tmp_path: Path) -> None: + manager = RecordingMetadataManager() + preview = RecordingPreviewService() + provider = DummyProvider({ + "baseModel": "SDXL", + "model": {"name": "Updated"}, + "images": [], + }) + + update_cache_calls: List[Dict[str, Any]] = [] + + async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool: + update_cache_calls.append({"original": original, "metadata": metadata}) + return True + + service = MetadataSyncService( + metadata_manager=manager, + preview_service=preview, + settings=DummySettings(), + default_metadata_provider_factory=lambda: asyncio.sleep(0, result=provider), + metadata_provider_selector=lambda _name=None: asyncio.sleep(0, result=provider), + ) + + model_data = {"sha256": "abc", "file_path": str(tmp_path / "model.safetensors")} + success, error = asyncio.run( + service.fetch_and_update_model( + sha256="abc", + file_path=str(tmp_path / "model.safetensors"), + model_data=model_data, + update_cache_func=update_cache, + ) + ) + + assert success is True + assert error is None + assert update_cache_calls + assert manager.saved + + +def test_preview_asset_service_replace_preview(tmp_path: Path) -> None: + metadata_path = tmp_path / "sample.metadata.json" + metadata_path.write_text(json.dumps({})) + + async def metadata_loader(path: str) -> Dict[str, Any]: + return json.loads(Path(path).read_text()) + + manager = RecordingMetadataManager() + + service = PreviewAssetService( + metadata_manager=manager, + downloader_factory=lambda: asyncio.sleep(0, result=None), + exif_utils=FakeExifUtils(), + ) + + preview_calls: List[Dict[str, Any]] = [] + + async def update_preview(model_path: str, preview_path: str, nsfw: int) -> bool: + preview_calls.append({"model_path": model_path, "preview_path": preview_path, "nsfw": nsfw}) + return True + + model_path = str(tmp_path / "sample.safetensors") + Path(model_path).write_bytes(b"model") + + result = asyncio.run( + service.replace_preview( + model_path=model_path, + preview_data=b"image-bytes", + content_type="image/png", + original_filename="preview.png", + nsfw_level=2, + update_preview_in_cache=update_preview, + metadata_loader=metadata_loader, + ) + ) + + assert result["preview_nsfw_level"] == 2 + assert preview_calls + saved_metadata = json.loads(metadata_path.read_text()) + assert saved_metadata["preview_nsfw_level"] == 2 + + +def test_download_coordinator_emits_progress() -> None: + class WSStub: + def __init__(self) -> None: + self.progress_events: List[Dict[str, Any]] = [] + self.counter = 0 + + def generate_download_id(self) -> str: + self.counter += 1 + return f"dl-{self.counter}" + + async def broadcast_download_progress(self, download_id: str, payload: Dict[str, Any]) -> None: + self.progress_events.append({"id": download_id, **payload}) + + class DownloadManagerStub: + def __init__(self) -> None: + self.calls: List[Dict[str, Any]] = [] + + async def download_from_civitai(self, **kwargs) -> Dict[str, Any]: + self.calls.append(kwargs) + await kwargs["progress_callback"](10) + return {"success": True} + + async def cancel_download(self, download_id: str) -> Dict[str, Any]: + return {"success": True, "download_id": download_id} + + async def get_active_downloads(self) -> Dict[str, Any]: + return {"active": []} + + ws_stub = WSStub() + manager_stub = DownloadManagerStub() + + coordinator = DownloadCoordinator( + ws_manager=ws_stub, + download_manager_factory=lambda: asyncio.sleep(0, result=manager_stub), + ) + + result = asyncio.run(coordinator.schedule_download({"model_id": 1})) + + assert result["success"] is True + assert manager_stub.calls + assert ws_stub.progress_events + + cancel_result = asyncio.run(coordinator.cancel_download(result["download_id"])) + assert cancel_result["success"] is True + + active = asyncio.run(coordinator.list_active_downloads()) + assert active == {"active": []} + + +def test_tag_update_service_adds_unique_tags(tmp_path: Path) -> None: + metadata_path = tmp_path / "model.metadata.json" + metadata_path.write_text(json.dumps({"tags": ["Existing"]})) + + async def loader(path: str) -> Dict[str, Any]: + return json.loads(Path(path).read_text()) + + manager = RecordingMetadataManager() + + service = TagUpdateService(metadata_manager=manager) + + cache_updates: List[Dict[str, Any]] = [] + + async def update_cache(original: str, new: str, metadata: Dict[str, Any]) -> bool: + cache_updates.append(metadata) + return True + + tags = asyncio.run( + service.add_tags( + file_path=str(tmp_path / "model.safetensors"), + new_tags=["New", "existing"], + metadata_loader=loader, + update_cache=update_cache, + ) + ) + + assert tags == ["Existing", "New"] + assert manager.saved + assert cache_updates diff --git a/tests/services/test_use_cases.py b/tests/services/test_use_cases.py new file mode 100644 index 00000000..cfd0f10c --- /dev/null +++ b/tests/services/test_use_cases.py @@ -0,0 +1,317 @@ +import asyncio +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import pytest + +from py_local.services.model_file_service import AutoOrganizeResult +from py_local.services.use_cases import ( + AutoOrganizeInProgressError, + AutoOrganizeUseCase, + BulkMetadataRefreshUseCase, + DownloadExampleImagesConfigurationError, + DownloadExampleImagesInProgressError, + DownloadExampleImagesUseCase, + DownloadModelEarlyAccessError, + DownloadModelUseCase, + DownloadModelValidationError, + ImportExampleImagesUseCase, + ImportExampleImagesValidationError, +) +from py_local.utils.example_images_download_manager import ( + DownloadConfigurationError, + DownloadInProgressError, + ExampleImagesDownloadError, +) +from py_local.utils.example_images_processor import ( + ExampleImagesImportError, + ExampleImagesValidationError, +) +from tests.conftest import MockModelService, MockScanner + + +class StubLockProvider: + def __init__(self) -> None: + self._lock = asyncio.Lock() + self.running = False + + def is_auto_organize_running(self) -> bool: + return self.running + + async def get_auto_organize_lock(self) -> asyncio.Lock: + return self._lock + + +class StubFileService: + def __init__(self) -> None: + self.calls: List[Dict[str, Any]] = [] + + async def auto_organize_models( + self, + *, + file_paths: Optional[List[str]] = None, + progress_callback=None, + ) -> AutoOrganizeResult: + result = AutoOrganizeResult() + result.total = len(file_paths or []) + self.calls.append({"file_paths": file_paths, "progress_callback": progress_callback}) + return result + + +class StubMetadataSync: + def __init__(self) -> None: + self.calls: List[Dict[str, Any]] = [] + + async def fetch_and_update_model(self, **kwargs: Any): + self.calls.append(kwargs) + model_data = kwargs["model_data"] + model_data["model_name"] = model_data.get("model_name", "model") + "-updated" + return True, None + + +@dataclass +class StubSettings: + enable_metadata_archive_db: bool = False + + def get(self, key: str, default: Any = None) -> Any: + if key == "enable_metadata_archive_db": + return self.enable_metadata_archive_db + return default + + +class ProgressCollector: + def __init__(self) -> None: + self.events: List[Dict[str, Any]] = [] + + async def on_progress(self, payload: Dict[str, Any]) -> None: + self.events.append(payload) + + +class StubDownloadCoordinator: + def __init__(self, *, error: Optional[str] = None) -> None: + self.error = error + self.payloads: List[Dict[str, Any]] = [] + + async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]: + self.payloads.append(payload) + if self.error == "validation": + raise ValueError("Missing required parameter: Please provide either 'model_id' or 'model_version_id'") + if self.error == "401": + raise RuntimeError("401 Unauthorized") + return {"success": True, "download_id": "abc123"} + + +class StubExampleImagesDownloadManager: + def __init__(self) -> None: + self.payloads: List[Dict[str, Any]] = [] + self.error: Optional[str] = None + self.progress_snapshot = {"status": "running"} + + async def start_download(self, payload: Dict[str, Any]) -> Dict[str, Any]: + self.payloads.append(payload) + if self.error == "in_progress": + raise DownloadInProgressError(self.progress_snapshot) + if self.error == "configuration": + raise DownloadConfigurationError("path missing") + if self.error == "generic": + raise ExampleImagesDownloadError("boom") + return {"success": True, "message": "ok"} + + +class StubExampleImagesProcessor: + def __init__(self) -> None: + self.calls: List[Dict[str, Any]] = [] + self.error: Optional[str] = None + self.response: Dict[str, Any] = {"success": True} + + async def import_images(self, model_hash: str, files: List[str]) -> Dict[str, Any]: + self.calls.append({"model_hash": model_hash, "files": files}) + if self.error == "validation": + raise ExampleImagesValidationError("missing") + if self.error == "generic": + raise ExampleImagesImportError("boom") + return self.response + + +async def test_auto_organize_use_case_executes_with_lock() -> None: + file_service = StubFileService() + lock_provider = StubLockProvider() + use_case = AutoOrganizeUseCase(file_service=file_service, lock_provider=lock_provider) + + result = await use_case.execute(file_paths=["model1"], progress_callback=None) + + assert isinstance(result, AutoOrganizeResult) + assert file_service.calls[0]["file_paths"] == ["model1"] + + +async def test_auto_organize_use_case_rejects_when_running() -> None: + file_service = StubFileService() + lock_provider = StubLockProvider() + lock_provider.running = True + use_case = AutoOrganizeUseCase(file_service=file_service, lock_provider=lock_provider) + + with pytest.raises(AutoOrganizeInProgressError): + await use_case.execute(file_paths=None, progress_callback=None) + + +async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None: + scanner = MockScanner() + scanner._cache.raw_data = [ + { + "file_path": "model1.safetensors", + "sha256": "hash", + "from_civitai": True, + "model_name": "Demo", + } + ] + service = MockModelService(scanner) + metadata_sync = StubMetadataSync() + settings = StubSettings() + progress = ProgressCollector() + + use_case = BulkMetadataRefreshUseCase( + service=service, + metadata_sync=metadata_sync, + settings_service=settings, + logger=logging.getLogger("test"), + ) + + result = await use_case.execute_with_error_handling(progress_callback=progress) + + assert result["success"] is True + assert progress.events[0]["status"] == "started" + assert progress.events[-1]["status"] == "completed" + assert metadata_sync.calls + assert scanner._cache.resort_calls == 1 + + +async def test_bulk_metadata_refresh_reports_errors() -> None: + class FailingScanner(MockScanner): + async def get_cached_data(self, force_refresh: bool = False): + raise RuntimeError("boom") + + service = MockModelService(FailingScanner()) + metadata_sync = StubMetadataSync() + settings = StubSettings() + progress = ProgressCollector() + + use_case = BulkMetadataRefreshUseCase( + service=service, + metadata_sync=metadata_sync, + settings_service=settings, + logger=logging.getLogger("test"), + ) + + with pytest.raises(RuntimeError): + await use_case.execute_with_error_handling(progress_callback=progress) + + assert progress.events + assert progress.events[-1]["status"] == "error" + assert progress.events[-1]["error"] == "boom" + + +async def test_download_model_use_case_raises_validation_error() -> None: + coordinator = StubDownloadCoordinator(error="validation") + use_case = DownloadModelUseCase(download_coordinator=coordinator) + + with pytest.raises(DownloadModelValidationError): + await use_case.execute({}) + + +async def test_download_model_use_case_raises_early_access() -> None: + coordinator = StubDownloadCoordinator(error="401") + use_case = DownloadModelUseCase(download_coordinator=coordinator) + + with pytest.raises(DownloadModelEarlyAccessError): + await use_case.execute({"model_id": 1}) + + +async def test_download_model_use_case_returns_result() -> None: + coordinator = StubDownloadCoordinator() + use_case = DownloadModelUseCase(download_coordinator=coordinator) + + result = await use_case.execute({"model_id": 1}) + + assert result["success"] is True + assert result["download_id"] == "abc123" + + +async def test_download_example_images_use_case_triggers_manager() -> None: + manager = StubExampleImagesDownloadManager() + use_case = DownloadExampleImagesUseCase(download_manager=manager) + + payload = {"optimize": True} + result = await use_case.execute(payload) + + assert manager.payloads == [payload] + assert result == {"success": True, "message": "ok"} + + +async def test_download_example_images_use_case_maps_in_progress() -> None: + manager = StubExampleImagesDownloadManager() + manager.error = "in_progress" + use_case = DownloadExampleImagesUseCase(download_manager=manager) + + with pytest.raises(DownloadExampleImagesInProgressError) as exc: + await use_case.execute({}) + + assert exc.value.progress == manager.progress_snapshot + + +async def test_download_example_images_use_case_maps_configuration() -> None: + manager = StubExampleImagesDownloadManager() + manager.error = "configuration" + use_case = DownloadExampleImagesUseCase(download_manager=manager) + + with pytest.raises(DownloadExampleImagesConfigurationError): + await use_case.execute({}) + + +async def test_download_example_images_use_case_propagates_generic_error() -> None: + manager = StubExampleImagesDownloadManager() + manager.error = "generic" + use_case = DownloadExampleImagesUseCase(download_manager=manager) + + with pytest.raises(ExampleImagesDownloadError): + await use_case.execute({}) + + +class DummyJsonRequest: + def __init__(self, payload: Dict[str, Any]) -> None: + self._payload = payload + self.content_type = "application/json" + + async def json(self) -> Dict[str, Any]: + return self._payload + + +async def test_import_example_images_use_case_delegates() -> None: + processor = StubExampleImagesProcessor() + use_case = ImportExampleImagesUseCase(processor=processor) + + request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]}) + result = await use_case.execute(request) + + assert processor.calls == [{"model_hash": "abc", "files": ["/tmp/file"]}] + assert result == {"success": True} + + +async def test_import_example_images_use_case_maps_validation_error() -> None: + processor = StubExampleImagesProcessor() + processor.error = "validation" + use_case = ImportExampleImagesUseCase(processor=processor) + request = DummyJsonRequest({"model_hash": None, "file_paths": []}) + + with pytest.raises(ImportExampleImagesValidationError): + await use_case.execute(request) + + +async def test_import_example_images_use_case_propagates_generic_error() -> None: + processor = StubExampleImagesProcessor() + processor.error = "generic" + use_case = ImportExampleImagesUseCase(processor=processor) + request = DummyJsonRequest({"model_hash": "abc", "file_paths": ["/tmp/file"]}) + + with pytest.raises(ExampleImagesImportError): + await use_case.execute(request)