From b060dc99fcdb62bc97c2ba258973afe5232d8ba2 Mon Sep 17 00:00:00 2001 From: Will Miao Date: Tue, 2 Jun 2026 20:38:47 +0800 Subject: [PATCH] feat(download): add skip-download endpoint that cancels in-memory tracking while preserving partial files on disk --- py/routes/handlers/model_handlers.py | 16 ++++++ py/routes/model_route_registrar.py | 1 + py/services/download_coordinator.py | 17 ++++++ py/services/download_manager.py | 83 ++++++++++++++++++++++++++++ 4 files changed, 117 insertions(+) diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 3724a01f..d38ad34b 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -1472,6 +1472,21 @@ class ModelDownloadHandler: ) return web.Response(status=500, text=str(exc)) + async def skip_download_get(self, request: web.Request) -> web.Response: + try: + download_id = request.query.get("download_id") + if not download_id: + return web.json_response( + {"success": False, "error": "Download ID is required"}, status=400 + ) + result = await self._download_coordinator.skip_download(download_id) + return web.json_response(result) + except Exception as exc: + self._logger.error( + "Error skipping download via GET: %s", exc, exc_info=True + ) + return web.json_response({"success": False, "error": str(exc)}, status=500) + async def cancel_download_get(self, request: web.Request) -> web.Response: try: download_id = request.query.get("download_id") @@ -2566,6 +2581,7 @@ class ModelHandlerSet: "download_model": self.download.download_model, "download_model_get": self.download.download_model_get, "cancel_download_get": self.download.cancel_download_get, + "skip_download_get": self.download.skip_download_get, "pause_download_get": self.download.pause_download_get, "resume_download_get": self.download.resume_download_get, "get_download_progress": self.download.get_download_progress, diff --git a/py/routes/model_route_registrar.py b/py/routes/model_route_registrar.py index cd0f208a..527e1f75 100644 --- a/py/routes/model_route_registrar.py +++ b/py/routes/model_route_registrar.py @@ -101,6 +101,7 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("POST", "/api/lm/download-model", "download_model"), RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"), RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"), + RouteDefinition("GET", "/api/lm/skip-download", "skip_download_get"), RouteDefinition("GET", "/api/lm/pause-download", "pause_download_get"), RouteDefinition("GET", "/api/lm/resume-download", "resume_download_get"), RouteDefinition( diff --git a/py/services/download_coordinator.py b/py/services/download_coordinator.py index 035f288c..ddfc859b 100644 --- a/py/services/download_coordinator.py +++ b/py/services/download_coordinator.py @@ -110,6 +110,23 @@ class DownloadCoordinator: return result + async def skip_download(self, download_id: str) -> Dict[str, Any]: + """Skip a download while preserving all partial files on disk.""" + download_manager = await self._download_manager_factory() + result = await download_manager.skip_download(download_id) + + await self._ws_manager.broadcast_download_progress( + download_id, + { + "status": "skipped", + "progress": 0, + "download_id": download_id, + "message": "Download skipped by user (partial files preserved)", + }, + ) + + return result + async def pause_download(self, download_id: str) -> Dict[str, Any]: """Pause an active download and notify listeners.""" diff --git a/py/services/download_manager.py b/py/services/download_manager.py index a9f7bb3a..fe79121f 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -2404,6 +2404,89 @@ class DownloadManager: self._download_tasks.pop(download_id, None) await self._aria2_state_store.remove(download_id) + async def skip_download(self, download_id: str) -> Dict: + """Skip a download while preserving all partial files on disk. + + Removes all in-memory tracking (asyncio task, semaphore, active/pause + state) but keeps partial files (.part / .aria2) on disk so that a + subsequent download-model-get request for the same save path can + auto-resume from the preserved partial download. + + Args: + download_id: The unique identifier of the download task + + Returns: + Dict: Status of the skip operation + """ + await self._restore_persisted_downloads() + + if download_id not in self._download_tasks and download_id not in self._active_downloads: + return {"success": False, "error": "Download task not found"} + + download_info = self._active_downloads.get(download_id) + task = self._download_tasks.get(download_id) + active_statuses = {"queued", "waiting", "downloading", "paused", "cancelling"} + if task is None and ( + not isinstance(download_info, dict) + or download_info.get("status") not in active_statuses + ): + return {"success": False, "error": "Download task not found"} + + backend = ( + self._active_downloads.get(download_id, {}).get("transfer_backend") + or "python" + ) + + try: + # For aria2: pause the transfer rather than force-removing it, so + # the .aria2 control file stays on disk for future resume + if backend == "aria2": + try: + aria2_downloader = await get_aria2_downloader() + pause_result = await aria2_downloader.pause_download(download_id) + if not pause_result.get("success"): + logger.warning( + "Failed to pause aria2 transfer for %s during skip: %s", + download_id, + pause_result.get("error"), + ) + except Exception as exc: + logger.warning( + "Failed to pause aria2 transfer for %s during skip: %s", + download_id, + exc, + ) + + # Cancel the asyncio task so the semaphore slot is released + if task is not None: + task.cancel() + + # Resume pause event so the task can exit cleanly + pause_control = self._pause_events.get(download_id) + if pause_control is not None: + pause_control.resume() + + # Wait briefly for task to acknowledge cancellation + if task is not None: + try: + await asyncio.wait_for(asyncio.shield(task), timeout=2.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + + logger.info(f"Download skipped for task {download_id} (partial files preserved)") + return {"success": True, "message": "Download skipped successfully"} + except Exception as e: + logger.error(f"Error skipping download: {e}", exc_info=True) + return {"success": False, "error": str(e)} + finally: + # Clean up local in-memory tracking only - NO file deletion + self._pause_events.pop(download_id, None) + self._download_tasks.pop(download_id, None) + if download_id in self._active_downloads: + del self._active_downloads[download_id] + # Preserve aria2 state store entry so the partial download + # info survives restarts and can be resumed later + async def pause_download(self, download_id: str) -> Dict: """Pause an active download without losing progress."""