mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(routes): extract orchestration use cases
This commit is contained in:
25
py/services/use_cases/__init__.py
Normal file
25
py/services/use_cases/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Application-level orchestration services for model routes."""
|
||||
|
||||
from .auto_organize_use_case import (
|
||||
AutoOrganizeInProgressError,
|
||||
AutoOrganizeUseCase,
|
||||
)
|
||||
from .bulk_metadata_refresh_use_case import (
|
||||
BulkMetadataRefreshUseCase,
|
||||
MetadataRefreshProgressReporter,
|
||||
)
|
||||
from .download_model_use_case import (
|
||||
DownloadModelEarlyAccessError,
|
||||
DownloadModelUseCase,
|
||||
DownloadModelValidationError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AutoOrganizeInProgressError",
|
||||
"AutoOrganizeUseCase",
|
||||
"BulkMetadataRefreshUseCase",
|
||||
"MetadataRefreshProgressReporter",
|
||||
"DownloadModelEarlyAccessError",
|
||||
"DownloadModelUseCase",
|
||||
"DownloadModelValidationError",
|
||||
]
|
||||
56
py/services/use_cases/auto_organize_use_case.py
Normal file
56
py/services/use_cases/auto_organize_use_case.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Auto-organize use case orchestrating concurrency and progress handling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Protocol, Sequence
|
||||
|
||||
from ..model_file_service import AutoOrganizeResult, ModelFileService, ProgressCallback
|
||||
|
||||
|
||||
class AutoOrganizeLockProvider(Protocol):
    """Structural interface for services exposing auto-organize locking."""

    def is_auto_organize_running(self) -> bool:
        """Report whether an auto-organize operation is currently active."""

    async def get_auto_organize_lock(self) -> asyncio.Lock:
        """Provide the asyncio lock that serializes auto-organize runs."""
||||
class AutoOrganizeInProgressError(RuntimeError):
    """Signal that another auto-organize run already holds the lock."""
||||
class AutoOrganizeUseCase:
    """Run the auto-organize routine while enforcing single-flight execution."""

    def __init__(
        self,
        *,
        file_service: ModelFileService,
        lock_provider: AutoOrganizeLockProvider,
    ) -> None:
        # Both collaborators are injected; this class owns no state of its own.
        self._file_service = file_service
        self._lock_provider = lock_provider

    async def execute(
        self,
        *,
        file_paths: Optional[Sequence[str]] = None,
        progress_callback: Optional[ProgressCallback] = None,
    ) -> AutoOrganizeResult:
        """Run the auto-organize routine guarded by a shared lock.

        Raises ``AutoOrganizeInProgressError`` when another run is active,
        either per the provider's own flag or because the shared lock is
        already held.
        """

        provider = self._lock_provider
        if provider.is_auto_organize_running():
            raise AutoOrganizeInProgressError("Auto-organize is already running")

        guard = await provider.get_auto_organize_lock()
        # Fail fast instead of queueing behind an in-flight run.
        if guard.locked():
            raise AutoOrganizeInProgressError("Auto-organize is already running")

        # Normalize the optional sequence to a concrete list (or None)
        # before handing it to the file service.
        paths = None if file_paths is None else list(file_paths)
        async with guard:
            return await self._file_service.auto_organize_models(
                file_paths=paths,
                progress_callback=progress_callback,
            )
122
py/services/use_cases/bulk_metadata_refresh_use_case.py
Normal file
122
py/services/use_cases/bulk_metadata_refresh_use_case.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Use case encapsulating the bulk metadata refresh orchestration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Protocol, Sequence
|
||||
|
||||
from ..metadata_sync_service import MetadataSyncService
|
||||
|
||||
|
||||
class MetadataRefreshProgressReporter(Protocol):
    """Structural interface for consumers of metadata-refresh progress."""

    async def on_progress(self, payload: Dict[str, Any]) -> None:
        """Receive one progress-update payload emitted during the refresh."""
||||
class BulkMetadataRefreshUseCase:
    """Coordinate bulk metadata refreshes with progress emission."""

    def __init__(
        self,
        *,
        service,
        metadata_sync: MetadataSyncService,
        settings_service,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        # ``service`` is duck-typed: this class only uses ``service.scanner``
        # (get_cached_data / update_single_model_cache) and
        # ``service.model_type`` — presumably a per-model-type route service;
        # confirm against the route wiring.
        self._service = service
        self._metadata_sync = metadata_sync
        # ``settings_service`` must expose ``get(key, default)``.
        self._settings = settings_service
        self._logger = logger or logging.getLogger(__name__)

    async def execute(
        self,
        *,
        progress_callback: Optional[MetadataRefreshProgressReporter] = None,
    ) -> Dict[str, Any]:
        """Refresh metadata for all qualifying models.

        Returns a summary dict with ``success``/``message``/``processed``/
        ``updated``/``total`` keys. When ``progress_callback`` is given,
        it receives dict payloads carrying ``status``, ``total``,
        ``processed`` and ``success`` plus per-status extras.
        """

        cache = await self._service.scanner.get_cached_data()
        total_models = len(cache.raw_data)

        enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False)
        # A model qualifies when it has a hash, has no linked CivitAI id,
        # and either has not yet been checked against the archive DB
        # (archive mode) or is explicitly flagged ``from_civitai``.
        to_process: Sequence[Dict[str, Any]] = [
            model
            for model in cache.raw_data
            if model.get("sha256")
            and (not model.get("civitai") or not model["civitai"].get("id"))
            and (
                (enable_metadata_archive_db and not model.get("db_checked", False))
                or (not enable_metadata_archive_db and model.get("from_civitai") is True)
            )
        ]

        total_to_process = len(to_process)
        processed = 0
        success = 0
        needs_resort = False

        # NOTE: ``emit`` closes over ``processed``/``success`` and reads
        # their *current* values at call time; the explicit keyword
        # overrides passed below are therefore redundant but harmless.
        async def emit(status: str, **extra: Any) -> None:
            if progress_callback is None:
                return
            payload = {"status": status, "total": total_to_process, "processed": processed, "success": success}
            payload.update(extra)
            await progress_callback.on_progress(payload)

        await emit("started")

        for model in to_process:
            try:
                # The sync call may rename the model in place; a rename
                # invalidates the cache's sort order (handled below).
                original_name = model.get("model_name")
                result, _ = await self._metadata_sync.fetch_and_update_model(
                    sha256=model["sha256"],
                    file_path=model["file_path"],
                    model_data=model,
                    update_cache_func=self._service.scanner.update_single_model_cache,
                )
                if result:
                    success += 1
                    if original_name != model.get("model_name"):
                        needs_resort = True
                processed += 1
                await emit(
                    "processing",
                    processed=processed,
                    success=success,
                    current_name=model.get("model_name", "Unknown"),
                )
            except Exception as exc:  # pragma: no cover - logging path
                # Per-model failures are logged and still counted as
                # processed so one bad entry cannot abort the whole run.
                processed += 1
                self._logger.error(
                    "Error fetching CivitAI data for %s: %s",
                    model.get("file_path"),
                    exc,
                )

        if needs_resort:
            await cache.resort()

        await emit("completed", processed=processed, success=success)

        message = (
            "Successfully updated "
            f"{success} of {processed} processed {self._service.model_type}s (total: {total_models})"
        )

        return {"success": True, "message": message, "processed": processed, "updated": success, "total": total_models}

    async def execute_with_error_handling(
        self,
        *,
        progress_callback: Optional[MetadataRefreshProgressReporter] = None,
    ) -> Dict[str, Any]:
        """Wrapper providing progress notification on unexpected failures."""

        # Surface any failure on the progress channel before re-raising so
        # a UI listener sees the terminal error state.
        try:
            return await self.execute(progress_callback=progress_callback)
        except Exception as exc:
            if progress_callback is not None:
                await progress_callback.on_progress({"status": "error", "error": str(exc)})
            raise
||||
37
py/services/use_cases/download_model_use_case.py
Normal file
37
py/services/use_cases/download_model_use_case.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Use case for scheduling model downloads with consistent error handling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from ..download_coordinator import DownloadCoordinator
|
||||
|
||||
|
||||
class DownloadModelValidationError(ValueError):
    """Signal that the incoming download payload failed validation."""


class DownloadModelEarlyAccessError(RuntimeError):
    """Signal that Civitai gates the requested download behind early access."""
||||
class DownloadModelUseCase:
    """Schedule model downloads, translating coordinator errors to domain ones."""

    def __init__(self, *, download_coordinator: DownloadCoordinator) -> None:
        self._download_coordinator = download_coordinator

    async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Schedule a download and normalize error conditions.

        Raises ``DownloadModelValidationError`` on payload validation
        failures and ``DownloadModelEarlyAccessError`` when the upstream
        responds with a 401 (early-access gated model).
        """

        try:
            result = await self._download_coordinator.schedule_download(payload)
        except ValueError as exc:
            # Coordinator-side validation failures become the domain error.
            raise DownloadModelValidationError(str(exc)) from exc
        except Exception as exc:  # pragma: no cover - defensive logging path
            # A 401 from Civitai indicates an early-access paywall; anything
            # else propagates unchanged.
            if "401" in str(exc):
                raise DownloadModelEarlyAccessError(
                    "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
                ) from exc
            raise
        return result
@@ -1,11 +1,29 @@
|
||||
from typing import Dict, Any
|
||||
"""Progress callback implementations backed by the shared WebSocket manager."""
|
||||
|
||||
from typing import Any, Dict, Protocol
|
||||
|
||||
from .model_file_service import ProgressCallback
|
||||
from .websocket_manager import ws_manager
|
||||
|
||||
|
||||
class ProgressReporter(Protocol):
    """Structural type for asynchronous progress callbacks."""

    async def on_progress(self, progress_data: Dict[str, Any]) -> None:
        """Consume a single progress-update payload."""
||||
class WebSocketProgressCallback(ProgressCallback):
    """Progress callback that relays auto-organize updates over WebSocket."""

    async def on_progress(self, progress_data: Dict[str, Any]) -> None:
        """Broadcast an auto-organize progress payload to connected clients."""
        await ws_manager.broadcast_auto_organize_progress(progress_data)
||||
class WebSocketBroadcastCallback:
    """Progress callback that fans each payload out to every connected client."""

    async def on_progress(self, progress_data: Dict[str, Any]) -> None:
        """Forward the payload verbatim through the shared WebSocket manager."""
        await ws_manager.broadcast(progress_data)
||||
Reference in New Issue
Block a user