mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
feat(routes): extract orchestration use cases
This commit is contained in:
@@ -14,10 +14,19 @@ import jinja2
|
||||
from ...config import config
|
||||
from ...services.download_coordinator import DownloadCoordinator
|
||||
from ...services.metadata_sync_service import MetadataSyncService
|
||||
from ...services.model_file_service import ModelFileService, ModelMoveService
|
||||
from ...services.model_file_service import ModelMoveService
|
||||
from ...services.preview_asset_service import PreviewAssetService
|
||||
from ...services.settings_manager import SettingsManager
|
||||
from ...services.tag_update_service import TagUpdateService
|
||||
from ...services.use_cases import (
|
||||
AutoOrganizeInProgressError,
|
||||
AutoOrganizeUseCase,
|
||||
BulkMetadataRefreshUseCase,
|
||||
DownloadModelEarlyAccessError,
|
||||
DownloadModelUseCase,
|
||||
DownloadModelValidationError,
|
||||
MetadataRefreshProgressReporter,
|
||||
)
|
||||
from ...services.websocket_manager import WebSocketManager
|
||||
from ...services.websocket_progress_callback import WebSocketProgressCallback
|
||||
from ...utils.file_utils import calculate_sha256
|
||||
@@ -600,33 +609,29 @@ class ModelDownloadHandler:
|
||||
*,
|
||||
ws_manager: WebSocketManager,
|
||||
logger: logging.Logger,
|
||||
download_use_case: DownloadModelUseCase,
|
||||
download_coordinator: DownloadCoordinator,
|
||||
) -> None:
|
||||
self._ws_manager = ws_manager
|
||||
self._logger = logger
|
||||
self._download_use_case = download_use_case
|
||||
self._download_coordinator = download_coordinator
|
||||
|
||||
async def download_model(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
payload = await request.json()
|
||||
result = await self._download_coordinator.schedule_download(payload)
|
||||
result = await self._download_use_case.execute(payload)
|
||||
if not result.get("success", False):
|
||||
return web.json_response(result, status=500)
|
||||
return web.json_response(result)
|
||||
except ValueError as exc:
|
||||
except DownloadModelValidationError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except DownloadModelEarlyAccessError as exc:
|
||||
self._logger.warning("Early access error: %s", exc)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=401)
|
||||
except Exception as exc:
|
||||
error_message = str(exc)
|
||||
if "401" in error_message:
|
||||
self._logger.warning("Early access error (401): %s", error_message)
|
||||
return web.json_response(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Early Access Restriction: This model requires purchase. Please buy early access on Civitai.com.",
|
||||
},
|
||||
status=401,
|
||||
)
|
||||
self._logger.error("Error downloading model: %s", error_message)
|
||||
self._logger.error("Error downloading model: %s", error_message, exc_info=True)
|
||||
return web.json_response({"success": False, "error": error_message}, status=500)
|
||||
|
||||
async def download_model_get(self, request: web.Request) -> web.Response:
|
||||
@@ -653,12 +658,15 @@ class ModelDownloadHandler:
|
||||
future.set_result(data)
|
||||
|
||||
mock_request = type("MockRequest", (), {"json": lambda self=None: future})()
|
||||
result = await self._download_coordinator.schedule_download(data)
|
||||
result = await self._download_use_case.execute(data)
|
||||
if not result.get("success", False):
|
||||
return web.json_response(result, status=500)
|
||||
return web.json_response(result)
|
||||
except ValueError as exc:
|
||||
except DownloadModelValidationError as exc:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=400)
|
||||
except DownloadModelEarlyAccessError as exc:
|
||||
self._logger.warning("Early access error: %s", exc)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=401)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error downloading model via GET: %s", exc, exc_info=True)
|
||||
return web.Response(status=500, text=str(exc))
|
||||
@@ -703,6 +711,8 @@ class ModelCivitaiHandler:
|
||||
expected_model_types: Callable[[], str],
|
||||
find_model_file: Callable[[Iterable[Mapping[str, object]]], Optional[Mapping[str, object]]],
|
||||
metadata_sync: MetadataSyncService,
|
||||
metadata_refresh_use_case: BulkMetadataRefreshUseCase,
|
||||
metadata_progress_callback: MetadataRefreshProgressReporter,
|
||||
) -> None:
|
||||
self._service = service
|
||||
self._settings = settings_service
|
||||
@@ -713,75 +723,16 @@ class ModelCivitaiHandler:
|
||||
self._expected_model_types = expected_model_types
|
||||
self._find_model_file = find_model_file
|
||||
self._metadata_sync = metadata_sync
|
||||
self._metadata_refresh_use_case = metadata_refresh_use_case
|
||||
self._metadata_progress_callback = metadata_progress_callback
|
||||
|
||||
async def fetch_all_civitai(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
cache = await self._service.scanner.get_cached_data()
|
||||
total = len(cache.raw_data)
|
||||
processed = 0
|
||||
success = 0
|
||||
needs_resort = False
|
||||
|
||||
enable_metadata_archive_db = self._settings.get("enable_metadata_archive_db", False)
|
||||
to_process = [
|
||||
model
|
||||
for model in cache.raw_data
|
||||
if model.get("sha256")
|
||||
and (not model.get("civitai") or not model["civitai"].get("id"))
|
||||
and (
|
||||
(enable_metadata_archive_db and not model.get("db_checked", False))
|
||||
or (not enable_metadata_archive_db and model.get("from_civitai") is True)
|
||||
)
|
||||
]
|
||||
total_to_process = len(to_process)
|
||||
|
||||
await self._ws_manager.broadcast({
|
||||
"status": "started",
|
||||
"total": total_to_process,
|
||||
"processed": 0,
|
||||
"success": 0,
|
||||
})
|
||||
|
||||
for model in to_process:
|
||||
try:
|
||||
original_name = model.get("model_name")
|
||||
result, error = await self._metadata_sync.fetch_and_update_model(
|
||||
sha256=model["sha256"],
|
||||
file_path=model["file_path"],
|
||||
model_data=model,
|
||||
update_cache_func=self._service.scanner.update_single_model_cache,
|
||||
)
|
||||
if result:
|
||||
success += 1
|
||||
if original_name != model.get("model_name"):
|
||||
needs_resort = True
|
||||
processed += 1
|
||||
await self._ws_manager.broadcast({
|
||||
"status": "processing",
|
||||
"total": total_to_process,
|
||||
"processed": processed,
|
||||
"success": success,
|
||||
"current_name": model.get("model_name", "Unknown"),
|
||||
})
|
||||
except Exception as exc: # pragma: no cover - logging path
|
||||
self._logger.error("Error fetching CivitAI data for %s: %s", model["file_path"], exc)
|
||||
|
||||
if needs_resort:
|
||||
await cache.resort()
|
||||
|
||||
await self._ws_manager.broadcast({
|
||||
"status": "completed",
|
||||
"total": total_to_process,
|
||||
"processed": processed,
|
||||
"success": success,
|
||||
})
|
||||
|
||||
return web.json_response({
|
||||
"success": True,
|
||||
"message": f"Successfully updated {success} of {processed} processed {self._service.model_type}s (total: {total})",
|
||||
})
|
||||
result = await self._metadata_refresh_use_case.execute_with_error_handling(
|
||||
progress_callback=self._metadata_progress_callback
|
||||
)
|
||||
return web.json_response(result)
|
||||
except Exception as exc:
|
||||
await self._ws_manager.broadcast({"status": "error", "error": str(exc)})
|
||||
self._logger.error("Error in fetch_all_civitai for %ss: %s", self._service.model_type, exc)
|
||||
return web.Response(text=str(exc), status=500)
|
||||
|
||||
@@ -887,31 +838,18 @@ class ModelAutoOrganizeHandler:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
file_service: ModelFileService,
|
||||
use_case: AutoOrganizeUseCase,
|
||||
progress_callback: WebSocketProgressCallback,
|
||||
ws_manager: WebSocketManager,
|
||||
logger: logging.Logger,
|
||||
) -> None:
|
||||
self._file_service = file_service
|
||||
self._use_case = use_case
|
||||
self._progress_callback = progress_callback
|
||||
self._ws_manager = ws_manager
|
||||
self._logger = logger
|
||||
|
||||
async def auto_organize_models(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
if self._ws_manager.is_auto_organize_running():
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
|
||||
status=409,
|
||||
)
|
||||
|
||||
auto_organize_lock = await self._ws_manager.get_auto_organize_lock()
|
||||
if auto_organize_lock.locked():
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
|
||||
status=409,
|
||||
)
|
||||
|
||||
file_paths = None
|
||||
if request.method == "POST":
|
||||
try:
|
||||
@@ -920,17 +858,24 @@ class ModelAutoOrganizeHandler:
|
||||
except Exception: # pragma: no cover - permissive path
|
||||
pass
|
||||
|
||||
async with auto_organize_lock:
|
||||
result = await self._file_service.auto_organize_models(
|
||||
file_paths=file_paths,
|
||||
progress_callback=self._progress_callback,
|
||||
)
|
||||
return web.json_response(result.to_dict())
|
||||
result = await self._use_case.execute(
|
||||
file_paths=file_paths,
|
||||
progress_callback=self._progress_callback,
|
||||
)
|
||||
return web.json_response(result.to_dict())
|
||||
except AutoOrganizeInProgressError:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Auto-organize is already running. Please wait for it to complete."},
|
||||
status=409,
|
||||
)
|
||||
except Exception as exc:
|
||||
self._logger.error("Error in auto_organize_models: %s", exc, exc_info=True)
|
||||
await self._ws_manager.broadcast_auto_organize_progress(
|
||||
{"type": "auto_organize_progress", "status": "error", "error": str(exc)}
|
||||
)
|
||||
try:
|
||||
await self._progress_callback.on_progress(
|
||||
{"type": "auto_organize_progress", "status": "error", "error": str(exc)}
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive reporting
|
||||
pass
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def get_auto_organize_progress(self, request: web.Request) -> web.Response:
|
||||
|
||||
Reference in New Issue
Block a user