mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat(routes): extract orchestration use cases
This commit is contained in:
@@ -19,7 +19,15 @@ from ..services.service_registry import ServiceRegistry
|
||||
from ..services.settings_manager import settings as default_settings
|
||||
from ..services.tag_update_service import TagUpdateService
|
||||
from ..services.websocket_manager import ws_manager as default_ws_manager
|
||||
from ..services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ..services.use_cases import (
|
||||
AutoOrganizeUseCase,
|
||||
BulkMetadataRefreshUseCase,
|
||||
DownloadModelUseCase,
|
||||
)
|
||||
from ..services.websocket_progress_callback import (
|
||||
WebSocketBroadcastCallback,
|
||||
WebSocketProgressCallback,
|
||||
)
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from ..utils.routes_common import ModelRouteUtils
|
||||
@@ -68,6 +76,7 @@ class BaseModelRoutes(ABC):
|
||||
self.model_file_service: ModelFileService | None = None
|
||||
self.model_move_service: ModelMoveService | None = None
|
||||
self.websocket_progress_callback = WebSocketProgressCallback()
|
||||
self.metadata_progress_callback = WebSocketBroadcastCallback()
|
||||
|
||||
self._handler_set: ModelHandlerSet | None = None
|
||||
self._handler_mapping: Dict[str, Callable[[web.Request], web.StreamResponse]] | None = None
|
||||
@@ -132,11 +141,19 @@ class BaseModelRoutes(ABC):
|
||||
tag_update_service=self._tag_update_service,
|
||||
)
|
||||
query = ModelQueryHandler(service=service, logger=logger)
|
||||
download_use_case = DownloadModelUseCase(download_coordinator=self._download_coordinator)
|
||||
download = ModelDownloadHandler(
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
download_use_case=download_use_case,
|
||||
download_coordinator=self._download_coordinator,
|
||||
)
|
||||
metadata_refresh_use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
settings_service=self._settings,
|
||||
logger=logger,
|
||||
)
|
||||
civitai = ModelCivitaiHandler(
|
||||
service=service,
|
||||
settings_service=self._settings,
|
||||
@@ -147,10 +164,16 @@ class BaseModelRoutes(ABC):
|
||||
expected_model_types=self._get_expected_model_types,
|
||||
find_model_file=self._find_model_file,
|
||||
metadata_sync=self._metadata_sync_service,
|
||||
metadata_refresh_use_case=metadata_refresh_use_case,
|
||||
metadata_progress_callback=self.metadata_progress_callback,
|
||||
)
|
||||
move = ModelMoveHandler(move_service=self._ensure_move_service(), logger=logger)
|
||||
auto_organize = ModelAutoOrganizeHandler(
|
||||
auto_organize_use_case = AutoOrganizeUseCase(
|
||||
file_service=self._ensure_file_service(),
|
||||
lock_provider=self._ws_manager,
|
||||
)
|
||||
auto_organize = ModelAutoOrganizeHandler(
|
||||
use_case=auto_organize_use_case,
|
||||
progress_callback=self.websocket_progress_callback,
|
||||
ws_manager=self._ws_manager,
|
||||
logger=logger,
|
||||
|
||||
@@ -14,10 +14,19 @@ import jinja2
|
||||
from ...config import config
|
||||
from ...services.download_coordinator import DownloadCoordinator
|
||||
from ...services.metadata_sync_service import MetadataSyncService
|
||||
from ...services.model_file_service import ModelFileService, ModelMoveService
|
||||
from ...services.model_file_service import ModelMoveService
|
||||
from ...services.preview_asset_service import PreviewAssetService
|
||||
from ...services.settings_manager import SettingsManager
|
||||
from ...services.tag_update_service import TagUpdateService
|
||||
from ...services.use_cases import (
|
||||
AutoOrganizeInProgressError,
|
||||
AutoOrganizeUseCase,
|
||||
BulkMetadataRefreshUseCase,
|
||||
DownloadModelEarlyAccessError,
|
||||
DownloadModelUseCase,
|
||||
DownloadModelValidationError,
|
||||
MetadataRefreshProgressReporter,
|
||||
)
|
||||
from ...services.websocket_manager import WebSocketManager
|
||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ...utils.file_utils import calculate_sha256
|
||||
@@ -600,33 +609,29 @@ class ModelDownloadHandler:
|
||||
*,
|
||||
ws_manager: WebSocketManager,
|
||||
logger: logging.Logger,
|
||||
download_use_case: DownloadModelUseCase,
|
||||
download_coordinator: DownloadCoordinator,
|
||||
) -> None:
|
||||
self._ws_manager = ws_manager
|
||||
self._logger = logger
|
||||
self._download_use_case = download_use_case
|
||||
self._download_coordinator = download_coordinator
|
||||
|
||||
async def download_model(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = await self._download_coordinator.schedule_download(payload)
|
||||
result = await self._download_use_case.execute(payload)
|
||||
if not result.get("success", False):
|
||||
return web.json_response(result, status=500)
|
||||
return web.json_response(result)
|
||||
except ValueError as exc:
|
||||
except DownloadModelValidationError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except DownloadModelEarlyAccessError as exc:
|
||||
self._logger.warning("Early access error: %s", exc)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=401)
|
||||
except Exception as exc:
|
||||
error_message = str(exc)
|
||||
if "401" in error_message:
|
||||
self._logger.warning("Early access error (401): %s", error_message)
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.",
|
||||
},
|
||||
status=401,
|
||||
)
|
||||
self._logger.error("Error downloading model: %s", error_message)
|
||||
self._logger.error("Error downloading model: %s", error_message, exc_info=True)
|
||||
return web.json_response({"success": False, "error": error_message}, status=500)
|
||||
|
||||
async def download_model_get(self, request: web.Request) -> web.Response:
|
||||
@@ -653,12 +658,15 @@ class ModelDownloadHandler:
|
||||
future.set_result(data)
|
||||
|
||||
mock_request = type("MockRequest", (), {"json": lambda self=None: future})()
|
||||
result = await self._download_coordinator.schedule_download(data)
|
||||
result = await self._download_use_case.execute(data)
|
||||
if not result.get("success", False):
|
||||
return web.json_response(result, status=500)
|
||||
return web.json_response(result)
|
||||
except ValueError as exc:
|
||||
except DownloadModelValidationError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except DownloadModelEarlyAccessError as exc:
|
||||
self._logger.warning("Early access error: %s", exc)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=401)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error downloading model via GET: %s", exc, exc_info=True)
|
||||
return web.Response(status=500, text=str(exc))
|
||||
@@ -703,6 +711,8 @@ class ModelCivitaiHandler:
|
||||
expected_model_types: Callable[[], str],
|
||||
find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]],
|
||||
metadata_sync: MetadataSyncService,
|
||||
metadata_refresh_use_case: BulkMetadataRefreshUseCase,
|
||||
metadata_progress_callback: MetadataRefreshProgressReporter,
|
||||
) -> None:
|
||||
self._service = service
|
||||
self._settings = settings_service
|
||||
@@ -713,75 +723,16 @@ class ModelCivitaiHandler:
|
||||
self._expected_model_types = expected_model_types
|
||||
self._find_model_file = find_model_file
|
||||
self._metadata_sync = metadata_sync
|
||||
self._metadata_refresh_use_case = metadata_refresh_use_case
|
||||
self._metadata_progress_callback = metadata_progress_callback
|
||||
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
cache = await self._service.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False)
|
||||
to_process = [
|
||||
model
|
||||
for model in cache.raw_data
|
||||
if model.get("sha256")
|
||||
and (not model.get("civitai") or not model["civitai"].get("id"))
|
||||
and (
|
||||
(enable_metadata_archive_db and not model.get("db_checked", False))
|
||||
or (not enable_metadata_archive_db and model.get("from_civitai") is True)
|
||||
)
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
await self._ws_manager.broadcast({
|
||||
"status": "started",
|
||||
"total": total_to_process,
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
})
|
||||
|
||||
for model in to_process:
|
||||
try:
|
||||
original_name = model.get("model_name")
|
||||
result, error = await self._metadata_sync.fetch_and_update_model(
|
||||
sha256=model["sha256"],
|
||||
file_path=model["file_path"],
|
||||
model_data=model,
|
||||
update_cache_func=self._service.scanner.update_single_model_cache,
|
||||
)
|
||||
if result:
|
||||
success += 1
|
||||
if original_name != model.get("model_name"):
|
||||
needs_resort = True
|
||||
processed += 1
|
||||
await self._ws_manager.broadcast({
|
||||
"status": "processing",
|
||||
"total": total_to_process,
|
||||
"processed": processed,
|
||||
"success": success,
|
||||
"current_name": model.get("model_name", "Unknown"),
|
||||
})
|
||||
except Exception as exc: # pragma: no cover - logging path
|
||||
self._logger.error("Error fetching CivitAI data for %s: %s", model["file_path"], exc)
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort()
|
||||
|
||||
await self._ws_manager.broadcast({
|
||||
"status": "completed",
|
||||
"total": total_to_process,
|
||||
"processed": processed,
|
||||
"success": success,
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Successfully updated {success} of {processed} processed {self._service.model_type}s (total: {total})",
|
||||
})
|
||||
result = await self._metadata_refresh_use_case.execute_with_error_handling(
|
||||
progress_callback=self._metadata_progress_callback
|
||||
)
|
||||
return web.json_response(result)
|
||||
except Exception as exc:
|
||||
await self._ws_manager.broadcast({"status": "error", "error": str(exc)})
|
||||
self._logger.error("Error in fetch_all_civitai for %ss: %s", self._service.model_type, exc)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
|
||||
@@ -887,31 +838,18 @@ class ModelAutoOrganizeHandler:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
file_service: ModelFileService,
|
||||
use_case: AutoOrganizeUseCase,
|
||||
progress_callback: WebSocketProgressCallback,
|
||||
ws_manager: WebSocketManager,
|
||||
logger: logging.Logger,
|
||||
) -> None:
|
||||
self._file_service = file_service
|
||||
self._use_case = use_case
|
||||
self._progress_callback = progress_callback
|
||||
self._ws_manager = ws_manager
|
||||
self._logger = logger
|
||||
|
||||
async def auto_organize_models(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
if self._ws_manager.is_auto_organize_running():
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
|
||||
status=409,
|
||||
)
|
||||
|
||||
auto_organize_lock = await self._ws_manager.get_auto_organize_lock()
|
||||
if auto_organize_lock.locked():
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
|
||||
status=409,
|
||||
)
|
||||
|
||||
file_paths = None
|
||||
if request.method == "POST":
|
||||
try:
|
||||
@@ -920,17 +858,24 @@ class ModelAutoOrganizeHandler:
|
||||
except Exception: # pragma: no cover - permissive path
|
||||
pass
|
||||
|
||||
async with auto_organize_lock:
|
||||
result = await self._file_service.auto_organize_models(
|
||||
file_paths=file_paths,
|
||||
progress_callback=self._progress_callback,
|
||||
)
|
||||
return web.json_response(result.to_dict())
|
||||
result = await self._use_case.execute(
|
||||
file_paths=file_paths,
|
||||
progress_callback=self._progress_callback,
|
||||
)
|
||||
return web.json_response(result.to_dict())
|
||||
except AutoOrganizeInProgressError:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
|
||||
status=409,
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error in auto_organize_models: %s", exc, exc_info=True)
|
||||
await self._ws_manager.broadcast_auto_organize_progress(
|
||||
{"type": "auto_organize_progress", "status": "error", "error": str(exc)}
|
||||
)
|
||||
try:
|
||||
await self._progress_callback.on_progress(
|
||||
{"type": "auto_organize_progress", "status": "error", "error": str(exc)}
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive reporting
|
||||
pass
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_auto_organize_progress(self, request: web.Request) -> web.Response:
|
||||
|
||||
25
py/services/use_cases/__init__.py
Normal file
25
py/services/use_cases/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Application-level orchestration services for model routes."""
|
||||
|
||||
from .auto_organize_use_case import (
|
||||
AutoOrganizeInProgressError,
|
||||
AutoOrganizeUseCase,
|
||||
)
|
||||
from .bulk_metadata_refresh_use_case import (
|
||||
BulkMetadataRefreshUseCase,
|
||||
MetadataRefreshProgressReporter,
|
||||
)
|
||||
from .download_model_use_case import (
|
||||
DownloadModelEarlyAccessError,
|
||||
DownloadModelUseCase,
|
||||
DownloadModelValidationError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AutoOrganizeInProgressError",
|
||||
"AutoOrganizeUseCase",
|
||||
"BulkMetadataRefreshUseCase",
|
||||
"MetadataRefreshProgressReporter",
|
||||
"DownloadModelEarlyAccessError",
|
||||
"DownloadModelUseCase",
|
||||
"DownloadModelValidationError",
|
||||
]
|
||||
56
py/services/use_cases/auto_organize_use_case.py
Normal file
56
py/services/use_cases/auto_organize_use_case.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Auto-organize use case orchestrating concurrency and progress handling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Protocol, Sequence
|
||||
|
||||
from ..model_file_service import AutoOrganizeResult, ModelFileService, ProgressCallback
|
||||
|
||||
|
||||
class AutoOrganizeLockProvider(Protocol):
|
||||
"""Minimal protocol for objects exposing auto-organize locking primitives."""
|
||||
|
||||
def is_auto_organize_running(self) -> bool:
|
||||
"""Return ``True`` when an auto-organize operation is in-flight."""
|
||||
|
||||
async def get_auto_organize_lock(self) -> asyncio.Lock:
|
||||
"""Return the asyncio lock guarding auto-organize operations."""
|
||||
|
||||
|
||||
class AutoOrganizeInProgressError(RuntimeError):
|
||||
"""Raised when an auto-organize run is already active."""
|
||||
|
||||
|
||||
class AutoOrganizeUseCase:
|
||||
"""Coordinate auto-organize execution behind a shared lock."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
file_service: ModelFileService,
|
||||
lock_provider: AutoOrganizeLockProvider,
|
||||
) -> None:
|
||||
self._file_service = file_service
|
||||
self._lock_provider = lock_provider
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
*,
|
||||
file_paths: Optional[Sequence[str]] = None,
|
||||
progress_callback: Optional[ProgressCallback] = None,
|
||||
) -> AutoOrganizeResult:
|
||||
"""Run the auto-organize routine guarded by a shared lock."""
|
||||
|
||||
if self._lock_provider.is_auto_organize_running():
|
||||
raise AutoOrganizeInProgressError("Auto-organize is already running")
|
||||
|
||||
lock = await self._lock_provider.get_auto_organize_lock()
|
||||
if lock.locked():
|
||||
raise AutoOrganizeInProgressError("Auto-organize is already running")
|
||||
|
||||
async with lock:
|
||||
return await self._file_service.auto_organize_models(
|
||||
file_paths=list(file_paths) if file_paths is not None else None,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
122
py/services/use_cases/bulk_metadata_refresh_use_case.py
Normal file
122
py/services/use_cases/bulk_metadata_refresh_use_case.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Use case encapsulating the bulk metadata refresh orchestration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Protocol, Sequence
|
||||
|
||||
from ..metadata_sync_service import MetadataSyncService
|
||||
|
||||
|
||||
class MetadataRefreshProgressReporter(Protocol):
|
||||
"""Protocol for progress reporters used during metadata refresh."""
|
||||
|
||||
async def on_progress(self, payload: Dict[str, Any]) -> None:
|
||||
"""Handle a metadata refresh progress update."""
|
||||
|
||||
|
||||
class BulkMetadataRefreshUseCase:
|
||||
"""Coordinate bulk metadata refreshes with progress emission."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
service,
|
||||
metadata_sync: MetadataSyncService,
|
||||
settings_service,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> None:
|
||||
self._service = service
|
||||
self._metadata_sync = metadata_sync
|
||||
self._settings = settings_service
|
||||
self._logger = logger or logging.getLogger(__name__)
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
*,
|
||||
progress_callback: Optional[MetadataRefreshProgressReporter] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Refresh metadata for all qualifying models."""
|
||||
|
||||
cache = await self._service.scanner.get_cached_data()
|
||||
total_models = len(cache.raw_data)
|
||||
|
||||
enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False)
|
||||
to_process: Sequence[Dict[str, Any]] = [
|
||||
model
|
||||
for model in cache.raw_data
|
||||
if model.get("sha256")
|
||||
and (not model.get("civitai") or not model["civitai"].get("id"))
|
||||
and (
|
||||
(enable_metadata_archive_db and not model.get("db_checked", False))
|
||||
or (not enable_metadata_archive_db and model.get("from_civitai") is True)
|
||||
)
|
||||
]
|
||||
|
||||
total_to_process = len(to_process)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
async def emit(status: str, **extra: Any) -> None:
|
||||
if progress_callback is None:
|
||||
return
|
||||
payload = {"status": status, "total": total_to_process, "processed": processed, "success": success}
|
||||
payload.update(extra)
|
||||
await progress_callback.on_progress(payload)
|
||||
|
||||
await emit("started")
|
||||
|
||||
for model in to_process:
|
||||
try:
|
||||
original_name = model.get("model_name")
|
||||
result, _ = await self._metadata_sync.fetch_and_update_model(
|
||||
sha256=model["sha256"],
|
||||
file_path=model["file_path"],
|
||||
model_data=model,
|
||||
update_cache_func=self._service.scanner.update_single_model_cache,
|
||||
)
|
||||
if result:
|
||||
success += 1
|
||||
if original_name != model.get("model_name"):
|
||||
needs_resort = True
|
||||
processed += 1
|
||||
await emit(
|
||||
"processing",
|
||||
processed=processed,
|
||||
success=success,
|
||||
current_name=model.get("model_name", "Unknown"),
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - logging path
|
||||
processed += 1
|
||||
self._logger.error(
|
||||
"Error fetching CivitAI data for %s: %s",
|
||||
model.get("file_path"),
|
||||
exc,
|
||||
)
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort()
|
||||
|
||||
await emit("completed", processed=processed, success=success)
|
||||
|
||||
message = (
|
||||
"Successfully updated "
|
||||
f"{success} of {processed} processed {self._service.model_type}s (total: {total_models})"
|
||||
)
|
||||
|
||||
return {"success": True, "message": message, "processed": processed, "updated": success, "total": total_models}
|
||||
|
||||
async def execute_with_error_handling(
|
||||
self,
|
||||
*,
|
||||
progress_callback: Optional[MetadataRefreshProgressReporter] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Wrapper providing progress notification on unexpected failures."""
|
||||
|
||||
try:
|
||||
return await self.execute(progress_callback=progress_callback)
|
||||
except Exception as exc:
|
||||
if progress_callback is not None:
|
||||
await progress_callback.on_progress({"status": "error", "error": str(exc)})
|
||||
raise
|
||||
37
py/services/use_cases/download_model_use_case.py
Normal file
37
py/services/use_cases/download_model_use_case.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Use case for scheduling model downloads with consistent error handling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from ..download_coordinator import DownloadCoordinator
|
||||
|
||||
|
||||
class DownloadModelValidationError(ValueError):
|
||||
"""Raised when incoming payload validation fails."""
|
||||
|
||||
|
||||
class DownloadModelEarlyAccessError(RuntimeError):
|
||||
"""Raised when the download is gated behind Civitai early access."""
|
||||
|
||||
|
||||
class DownloadModelUseCase:
|
||||
"""Coordinate download scheduling through the coordinator service."""
|
||||
|
||||
def __init__(self, *, download_coordinator: DownloadCoordinator) -> None:
|
||||
self._download_coordinator = download_coordinator
|
||||
|
||||
async def execute(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Schedule a download and normalize error conditions."""
|
||||
|
||||
try:
|
||||
return await self._download_coordinator.schedule_download(payload)
|
||||
except ValueError as exc:
|
||||
raise DownloadModelValidationError(str(exc)) from exc
|
||||
except Exception as exc: # pragma: no cover - defensive logging path
|
||||
message = str(exc)
|
||||
if "401" in message:
|
||||
raise DownloadModelEarlyAccessError(
|
||||
"Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com."
|
||||
) from exc
|
||||
raise
|
||||
@@ -1,11 +1,29 @@
|
||||
from typing import Dict, Any
|
||||
"""Progress callback implementations backed by the shared WebSocket manager."""
|
||||
|
||||
from typing import Any, Dict, Protocol
|
||||
|
||||
from .model_file_service import ProgressCallback
|
||||
from .websocket_manager import ws_manager
|
||||
|
||||
|
||||
class WebSocketProgressCallback(ProgressCallback):
|
||||
"""WebSocket implementation of progress callback"""
|
||||
|
||||
class ProgressReporter(Protocol):
|
||||
"""Protocol representing an async progress callback."""
|
||||
|
||||
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
|
||||
"""Send progress data via WebSocket"""
|
||||
await ws_manager.broadcast_auto_organize_progress(progress_data)
|
||||
"""Handle a progress update payload."""
|
||||
|
||||
|
||||
class WebSocketProgressCallback(ProgressCallback):
|
||||
"""WebSocket implementation of progress callback."""
|
||||
|
||||
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
|
||||
"""Send progress data via WebSocket."""
|
||||
await ws_manager.broadcast_auto_organize_progress(progress_data)
|
||||
|
||||
|
||||
class WebSocketBroadcastCallback:
|
||||
"""Generic WebSocket progress callback broadcasting to all clients."""
|
||||
|
||||
async def on_progress(self, progress_data: Dict[str, Any]) -> None:
|
||||
"""Send the provided payload to all connected clients."""
|
||||
await ws_manager.broadcast(progress_data)
|
||||
|
||||
@@ -28,6 +28,7 @@ spec.loader.exec_module(py_local)
|
||||
sys.modules.setdefault("py_local", py_local)
|
||||
|
||||
from py_local.routes.base_model_routes import BaseModelRoutes
|
||||
from py_local.services.model_file_service import AutoOrganizeResult
|
||||
from py_local.services.service_registry import ServiceRegistry
|
||||
from py_local.services.websocket_manager import ws_manager
|
||||
from py_local.utils.routes_common import ExifUtils
|
||||
@@ -222,6 +223,25 @@ def test_download_model_invokes_download_manager(
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_download_model_requires_identifier(mock_service, download_manager_stub):
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
try:
|
||||
response = await client.post(
|
||||
"/api/lm/download-model",
|
||||
json={"model_root": "/tmp"},
|
||||
)
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 400
|
||||
assert payload["success"] is False
|
||||
assert "Missing required" in payload["error"]
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_auto_organize_progress_returns_latest_snapshot(mock_service):
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
@@ -235,5 +255,65 @@ def test_auto_organize_progress_returns_latest_snapshot(mock_service):
|
||||
assert payload == {"success": True, "progress": {"status": "processing", "percent": 50}}
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.MonkeyPatch):
|
||||
async def fake_auto_organize(self, file_paths=None, progress_callback=None):
|
||||
result = AutoOrganizeResult()
|
||||
result.total = 1
|
||||
result.processed = 1
|
||||
result.success_count = 1
|
||||
result.skipped_count = 0
|
||||
result.failure_count = 0
|
||||
result.operation_type = "bulk"
|
||||
if progress_callback is not None:
|
||||
await progress_callback.on_progress({"type": "auto_organize_progress", "status": "started"})
|
||||
await progress_callback.on_progress({"type": "auto_organize_progress", "status": "completed"})
|
||||
return result
|
||||
|
||||
monkeypatch.setattr(
|
||||
py_local.services.model_file_service.ModelFileService,
|
||||
"auto_organize_models",
|
||||
fake_auto_organize,
|
||||
)
|
||||
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
try:
|
||||
response = await client.post("/api/lm/test-models/auto-organize", json={"file_paths": []})
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 200
|
||||
assert payload["success"] is True
|
||||
|
||||
progress = ws_manager.get_auto_organize_progress()
|
||||
assert progress is not None
|
||||
assert progress["status"] == "completed"
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_auto_organize_conflict_when_running(mock_service):
|
||||
async def scenario():
|
||||
client = await create_test_client(mock_service)
|
||||
try:
|
||||
await ws_manager.broadcast_auto_organize_progress(
|
||||
{"type": "auto_organize_progress", "status": "started"}
|
||||
)
|
||||
|
||||
response = await client.post("/api/lm/test-models/auto-organize")
|
||||
payload = await response.json()
|
||||
|
||||
assert response.status == 409
|
||||
assert payload == {
|
||||
"success": False,
|
||||
"error": "Auto-organize is already running. Please wait for it to complete.",
|
||||
}
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
asyncio.run(scenario())
|
||||
|
||||
191
tests/services/test_use_cases.py
Normal file
191
tests/services/test_use_cases.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from py_local.services.model_file_service import AutoOrganizeResult
|
||||
from py_local.services.use_cases import (
|
||||
AutoOrganizeInProgressError,
|
||||
AutoOrganizeUseCase,
|
||||
BulkMetadataRefreshUseCase,
|
||||
DownloadModelEarlyAccessError,
|
||||
DownloadModelUseCase,
|
||||
DownloadModelValidationError,
|
||||
)
|
||||
from tests.conftest import MockModelService, MockScanner
|
||||
|
||||
|
||||
class StubLockProvider:
|
||||
def __init__(self) -> None:
|
||||
self._lock = asyncio.Lock()
|
||||
self.running = False
|
||||
|
||||
def is_auto_organize_running(self) -> bool:
|
||||
return self.running
|
||||
|
||||
async def get_auto_organize_lock(self) -> asyncio.Lock:
|
||||
return self._lock
|
||||
|
||||
|
||||
class StubFileService:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def auto_organize_models(
|
||||
self,
|
||||
*,
|
||||
file_paths: Optional[List[str]] = None,
|
||||
progress_callback=None,
|
||||
) -> AutoOrganizeResult:
|
||||
result = AutoOrganizeResult()
|
||||
result.total = len(file_paths or [])
|
||||
self.calls.append({"file_paths": file_paths, "progress_callback": progress_callback})
|
||||
return result
|
||||
|
||||
|
||||
class StubMetadataSync:
|
||||
def __init__(self) -> None:
|
||||
self.calls: List[Dict[str, Any]] = []
|
||||
|
||||
async def fetch_and_update_model(self, **kwargs: Any):
|
||||
self.calls.append(kwargs)
|
||||
model_data = kwargs["model_data"]
|
||||
model_data["model_name"] = model_data.get("model_name", "model") + "-updated"
|
||||
return True, None
|
||||
|
||||
|
||||
@dataclass
|
||||
class StubSettings:
|
||||
enable_metadata_archive_db: bool = False
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
if key == "enable_metadata_archive_db":
|
||||
return self.enable_metadata_archive_db
|
||||
return default
|
||||
|
||||
|
||||
class ProgressCollector:
|
||||
def __init__(self) -> None:
|
||||
self.events: List[Dict[str, Any]] = []
|
||||
|
||||
async def on_progress(self, payload: Dict[str, Any]) -> None:
|
||||
self.events.append(payload)
|
||||
|
||||
|
||||
class StubDownloadCoordinator:
|
||||
def __init__(self, *, error: Optional[str] = None) -> None:
|
||||
self.error = error
|
||||
self.payloads: List[Dict[str, Any]] = []
|
||||
|
||||
async def schedule_download(self, payload: Dict[str, Any]) -> Dict[str, Any]:
|
||||
self.payloads.append(payload)
|
||||
if self.error == "validation":
|
||||
raise ValueError("Missing required parameter: Please provide either 'model_id' or 'model_version_id'")
|
||||
if self.error == "401":
|
||||
raise RuntimeError("401 Unauthorized")
|
||||
return {"success": True, "download_id": "abc123"}
|
||||
|
||||
|
||||
async def test_auto_organize_use_case_executes_with_lock() -> None:
|
||||
file_service = StubFileService()
|
||||
lock_provider = StubLockProvider()
|
||||
use_case = AutoOrganizeUseCase(file_service=file_service, lock_provider=lock_provider)
|
||||
|
||||
result = await use_case.execute(file_paths=["model1"], progress_callback=None)
|
||||
|
||||
assert isinstance(result, AutoOrganizeResult)
|
||||
assert file_service.calls[0]["file_paths"] == ["model1"]
|
||||
|
||||
|
||||
async def test_auto_organize_use_case_rejects_when_running() -> None:
|
||||
file_service = StubFileService()
|
||||
lock_provider = StubLockProvider()
|
||||
lock_provider.running = True
|
||||
use_case = AutoOrganizeUseCase(file_service=file_service, lock_provider=lock_provider)
|
||||
|
||||
with pytest.raises(AutoOrganizeInProgressError):
|
||||
await use_case.execute(file_paths=None, progress_callback=None)
|
||||
|
||||
|
||||
async def test_bulk_metadata_refresh_emits_progress_and_updates_cache() -> None:
|
||||
scanner = MockScanner()
|
||||
scanner._cache.raw_data = [
|
||||
{
|
||||
"file_path": "model1.safetensors",
|
||||
"sha256": "hash",
|
||||
"from_civitai": True,
|
||||
"model_name": "Demo",
|
||||
}
|
||||
]
|
||||
service = MockModelService(scanner)
|
||||
metadata_sync = StubMetadataSync()
|
||||
settings = StubSettings()
|
||||
progress = ProgressCollector()
|
||||
|
||||
use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=metadata_sync,
|
||||
settings_service=settings,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
result = await use_case.execute_with_error_handling(progress_callback=progress)
|
||||
|
||||
assert result["success"] is True
|
||||
assert progress.events[0]["status"] == "started"
|
||||
assert progress.events[-1]["status"] == "completed"
|
||||
assert metadata_sync.calls
|
||||
assert scanner._cache.resort_calls == 1
|
||||
|
||||
|
||||
async def test_bulk_metadata_refresh_reports_errors() -> None:
|
||||
class FailingScanner(MockScanner):
|
||||
async def get_cached_data(self, force_refresh: bool = False):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
service = MockModelService(FailingScanner())
|
||||
metadata_sync = StubMetadataSync()
|
||||
settings = StubSettings()
|
||||
progress = ProgressCollector()
|
||||
|
||||
use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=metadata_sync,
|
||||
settings_service=settings,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await use_case.execute_with_error_handling(progress_callback=progress)
|
||||
|
||||
assert progress.events
|
||||
assert progress.events[-1]["status"] == "error"
|
||||
assert progress.events[-1]["error"] == "boom"
|
||||
|
||||
|
||||
async def test_download_model_use_case_raises_validation_error() -> None:
|
||||
coordinator = StubDownloadCoordinator(error="validation")
|
||||
use_case = DownloadModelUseCase(download_coordinator=coordinator)
|
||||
|
||||
with pytest.raises(DownloadModelValidationError):
|
||||
await use_case.execute({})
|
||||
|
||||
|
||||
async def test_download_model_use_case_raises_early_access() -> None:
|
||||
coordinator = StubDownloadCoordinator(error="401")
|
||||
use_case = DownloadModelUseCase(download_coordinator=coordinator)
|
||||
|
||||
with pytest.raises(DownloadModelEarlyAccessError):
|
||||
await use_case.execute({"model_id": 1})
|
||||
|
||||
|
||||
async def test_download_model_use_case_returns_result() -> None:
|
||||
coordinator = StubDownloadCoordinator()
|
||||
use_case = DownloadModelUseCase(download_coordinator=coordinator)
|
||||
|
||||
result = await use_case.execute({"model_id": 1})
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["download_id"] == "abc123"
|
||||
Reference in New Issue
Block a user