From 055c1ca0d493a42cd2f39de1735046de58b30e9a Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Mon, 13 Oct 2025 21:30:23 +0800 Subject: [PATCH] test(downloads): cover pause and resume flows --- py/routes/handlers/model_handlers.py | 52 +++++++-- py/routes/model_route_registrar.py | 2 + py/services/download_coordinator.py | 50 ++++++++ py/services/download_manager.py | 121 +++++++++++++++++--- py/services/downloader.py | 6 +- py/services/websocket_manager.py | 5 + tests/services/test_download_coordinator.py | 117 +++++++++++++++++++ tests/services/test_download_manager.py | 62 ++++++++++ 8 files changed, 388 insertions(+), 27 deletions(-) create mode 100644 tests/services/test_download_coordinator.py diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py index 6f7f7747..16089b66 100644 --- a/py/routes/handlers/model_handlers.py +++ b/py/routes/handlers/model_handlers.py @@ -758,6 +758,30 @@ class ModelDownloadHandler: self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) + async def pause_download_get(self, request: web.Request) -> web.Response: + try: + download_id = request.query.get("download_id") + if not download_id: + return web.json_response({"success": False, "error": "Download ID is required"}, status=400) + result = await self._download_coordinator.pause_download(download_id) + status = 200 if result.get("success") else 400 + return web.json_response(result, status=status) + except Exception as exc: + self._logger.error("Error pausing download via GET: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + + async def resume_download_get(self, request: web.Request) -> web.Response: + try: + download_id = request.query.get("download_id") + if not download_id: + return web.json_response({"success": False, "error": "Download ID is required"}, status=400) + result = await self._download_coordinator.resume_download(download_id) + status = 200 if result.get("success") else 400 + return web.json_response(result, status=status) + except Exception as exc: + self._logger.error("Error resuming download via GET: %s", exc, exc_info=True) + return web.json_response({"success": False, "error": str(exc)}, status=500) + async def get_download_progress(self, request: web.Request) -> web.Response: try: download_id = request.match_info.get("download_id") @@ -766,15 +790,23 @@ class ModelDownloadHandler: progress_data = self._ws_manager.get_download_progress(download_id) if progress_data is None: return web.json_response({"success": False, "error": "Download ID not found"}, status=404) - return web.json_response( - { - "success": True, - "progress": progress_data.get("progress", 0), - "bytes_downloaded": progress_data.get("bytes_downloaded"), - "total_bytes": progress_data.get("total_bytes"), - "bytes_per_second": progress_data.get("bytes_per_second", 0.0), - } - ) + response_payload = { + "success": True, + "progress": progress_data.get("progress", 0), + "bytes_downloaded": progress_data.get("bytes_downloaded"), + "total_bytes": progress_data.get("total_bytes"), + "bytes_per_second": progress_data.get("bytes_per_second", 0.0), + } + + status = progress_data.get("status") + if status and status != "progress": + response_payload["status"] = status + if "message" in progress_data: + response_payload["message"] = progress_data["message"] + elif status is None and "message" in progress_data: + response_payload["message"] = progress_data["message"] + + return web.json_response(response_payload) except Exception as exc: self._logger.error("Error getting download progress: %s", exc, exc_info=True) return web.json_response({"success": False, "error": str(exc)}, status=500) @@ -1025,6 +1057,8 @@ class ModelHandlerSet: "download_model": self.download.download_model, "download_model_get": self.download.download_model_get, "cancel_download_get": self.download.cancel_download_get, + "pause_download_get": self.download.pause_download_get, + "resume_download_get": self.download.resume_download_get, "get_download_progress": self.download.get_download_progress, "get_civitai_versions": self.civitai.get_civitai_versions, "get_civitai_model_by_version": self.civitai.get_civitai_model_by_version, diff --git a/py/routes/model_route_registrar.py b/py/routes/model_route_registrar.py index 96f65fc5..105e5f09 100644 --- a/py/routes/model_route_registrar.py +++ b/py/routes/model_route_registrar.py @@ -58,6 +58,8 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = ( RouteDefinition("POST", "/api/lm/download-model", "download_model"), RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"), RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"), + RouteDefinition("GET", "/api/lm/pause-download", "pause_download_get"), + RouteDefinition("GET", "/api/lm/resume-download", "resume_download_get"), RouteDefinition("GET", "/api/lm/download-progress/{download_id}", "get_download_progress"), RouteDefinition("GET", "/{prefix}", "handle_models_page"), ) diff --git a/py/services/download_coordinator.py b/py/services/download_coordinator.py index 51700793..8a008412 100644 --- a/py/services/download_coordinator.py +++ b/py/services/download_coordinator.py @@ -109,6 +109,56 @@ class DownloadCoordinator: return result + async def pause_download(self, download_id: str) -> Dict[str, Any]: + """Pause an active download and notify listeners.""" + + download_manager = await self._download_manager_factory() + result = await download_manager.pause_download(download_id) + + if result.get("success"): + cached_progress = self._ws_manager.get_download_progress(download_id) or {} + payload: Dict[str, Any] = { + "status": "paused", + "progress": cached_progress.get("progress", 0), + "download_id": download_id, + "message": "Download paused by user", + } + + for field in ("bytes_downloaded", "total_bytes", "bytes_per_second"): + if field in cached_progress: + payload[field] = cached_progress[field] + + payload["bytes_per_second"] = 0.0 + + await self._ws_manager.broadcast_download_progress(download_id, payload) + + return result + + async def resume_download(self, download_id: str) -> Dict[str, Any]: + """Resume a paused download and notify listeners.""" + + download_manager = await self._download_manager_factory() + result = await download_manager.resume_download(download_id) + + if result.get("success"): + cached_progress = self._ws_manager.get_download_progress(download_id) or {} + payload: Dict[str, Any] = { + "status": "downloading", + "progress": cached_progress.get("progress", 0), + "download_id": download_id, + "message": "Download resumed by user", + } + + for field in ("bytes_downloaded", "total_bytes"): + if field in cached_progress: + payload[field] = cached_progress[field] + + payload["bytes_per_second"] = cached_progress.get("bytes_per_second", 0.0) + + await self._ws_manager.broadcast_download_progress(download_id, payload) + + return result + async def list_active_downloads(self) -> Dict[str, Any]: """Return the active download map from the underlying manager.""" diff --git a/py/services/download_manager.py b/py/services/download_manager.py index d15642b7..b8042414 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -43,6 +43,7 @@ class DownloadManager: self._active_downloads = OrderedDict() # download_id -> download_info self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads self._download_tasks = {} # download_id -> asyncio.Task + self._pause_events: Dict[str, asyncio.Event] = {} async def _get_lora_scanner(self): """Get the lora scanner from registry""" @@ -88,11 +89,15 @@ class DownloadManager: 'total_bytes': None, 'bytes_per_second': 0.0, } + + pause_event = asyncio.Event() + pause_event.set() + self._pause_events[task_id] = pause_event # Create tracking task download_task = asyncio.create_task( self._download_with_semaphore( - task_id, model_id, model_version_id, save_dir, + task_id, model_id, model_version_id, save_dir, relative_path, progress_callback, use_default_paths, source ) ) @@ -111,9 +116,10 @@ class DownloadManager: # Clean up task reference if task_id in self._download_tasks: del self._download_tasks[task_id] + self._pause_events.pop(task_id, None) async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int, - save_dir: str, relative_path: str, + save_dir: str, relative_path: str, progress_callback=None, use_default_paths: bool = False, source: str = None): """Execute download with semaphore to limit concurrency""" @@ -140,6 +146,13 @@ class DownloadManager: # Acquire semaphore to limit concurrent downloads try: async with self._download_semaphore: + pause_event = self._pause_events.get(task_id) + if pause_event is not None and not pause_event.is_set(): + if task_id in self._active_downloads: + self._active_downloads[task_id]['status'] = 'paused' + self._active_downloads[task_id]['bytes_per_second'] = 0.0 + await pause_event.wait() + # Update status to downloading if task_id in self._active_downloads: self._active_downloads[task_id]['status'] = 'downloading' @@ -189,9 +202,17 @@ class DownloadManager: if task_id in self._active_downloads: del self._active_downloads[task_id] - async def _execute_original_download(self, model_id, model_version_id, save_dir, - relative_path, progress_callback, use_default_paths, - download_id=None, source=None): + async def _execute_original_download( + self, + model_id, + model_version_id, + save_dir, + relative_path, + progress_callback, + use_default_paths, + download_id=None, + source=None, + ): """Wrapper for original download_from_civitai implementation""" try: # Check if model version already exists in library @@ -345,7 +366,7 @@ class DownloadManager: relative_path=relative_path, progress_callback=progress_callback, model_type=model_type, - download_id=download_id + download_id=download_id, ) # If early_access_msg exists and download failed, replace error message @@ -410,10 +431,17 @@ class DownloadManager: return formatted_path - async def _execute_download(self, download_urls: List[str], save_dir: str, - metadata, version_info: Dict, - relative_path: str, progress_callback=None, - model_type: str = "lora", download_id: str = None) -> Dict: + async def _execute_download( + self, + download_urls: List[str], + save_dir: str, + metadata, + version_info: Dict, + relative_path: str, + progress_callback=None, + model_type: str = "lora", + download_id: str = None, + ) -> Dict: """Execute the actual download process including preview images and model files""" try: # Extract original filename details @@ -444,6 +472,8 @@ class DownloadManager: part_path = save_path + '.part' metadata_path = os.path.splitext(save_path)[0] + '.metadata.json' + + pause_event = self._pause_events.get(download_id) if download_id else None # Store file paths in active_downloads for potential cleanup if download_id and download_id in self._active_downloads: @@ -558,15 +588,22 @@ class DownloadManager: last_error = None for download_url in download_urls: use_auth = download_url.startswith("https://civitai.com/api/download/") - success, result = await downloader.download_file( - download_url, - save_path, # Use full path instead of separate dir and filename - progress_callback=lambda progress, snapshot=None: self._handle_download_progress( + download_kwargs = { + "progress_callback": lambda progress, snapshot=None: self._handle_download_progress( progress, progress_callback, snapshot, ), - use_auth=use_auth # Only use authentication for Civitai downloads + "use_auth": use_auth, # Only use authentication for Civitai downloads + } + + if pause_event is not None: + download_kwargs["pause_event"] = pause_event + + success, result = await downloader.download_file( + download_url, + save_path, # Use full path instead of separate dir and filename + **download_kwargs, ) if success: @@ -675,7 +712,7 @@ class DownloadManager: async def cancel_download(self, download_id: str) -> Dict: """Cancel an active download by download_id - + Args: download_id: The unique identifier of the download task @@ -689,7 +726,11 @@ class DownloadManager: # Get the task and cancel it task = self._download_tasks[download_id] task.cancel() - + + pause_event = self._pause_events.get(download_id) + if pause_event is not None: + pause_event.set() + # Update status in active downloads if download_id in self._active_downloads: self._active_downloads[download_id]['status'] = 'cancelling' @@ -756,6 +797,52 @@ class DownloadManager: except Exception as e: logger.error(f"Error cancelling download: {e}", exc_info=True) return {'success': False, 'error': str(e)} + finally: + self._pause_events.pop(download_id, None) + + async def pause_download(self, download_id: str) -> Dict: + """Pause an active download without losing progress.""" + + if download_id not in self._download_tasks: + return {'success': False, 'error': 'Download task not found'} + + pause_event = self._pause_events.get(download_id) + if pause_event is None: + pause_event = asyncio.Event() + pause_event.set() + self._pause_events[download_id] = pause_event + + if not pause_event.is_set(): + return {'success': False, 'error': 'Download is already paused'} + + pause_event.clear() + + download_info = self._active_downloads.get(download_id) + if download_info is not None: + download_info['status'] = 'paused' + download_info['bytes_per_second'] = 0.0 + + return {'success': True, 'message': 'Download paused successfully'} + + async def resume_download(self, download_id: str) -> Dict: + """Resume a previously paused download.""" + + pause_event = self._pause_events.get(download_id) + if pause_event is None: + return {'success': False, 'error': 'Download task not found'} + + if pause_event.is_set(): + return {'success': False, 'error': 'Download is not paused'} + + pause_event.set() + + download_info = self._active_downloads.get(download_id) + if download_info is not None: + if download_info.get('status') == 'paused': + download_info['status'] = 'downloading' + download_info.setdefault('bytes_per_second', 0.0) + + return {'success': True, 'message': 'Download resumed successfully'} @staticmethod def _coerce_progress_value(progress) -> float: diff --git a/py/services/downloader.py b/py/services/downloader.py index 5c2a1f0c..dafef78e 100644 --- a/py/services/downloader.py +++ b/py/services/downloader.py @@ -175,7 +175,8 @@ class Downloader: progress_callback: Optional[Callable[..., Awaitable[None]]] = None, use_auth: bool = False, custom_headers: Optional[Dict[str, str]] = None, - allow_resume: bool = True + allow_resume: bool = True, + pause_event: Optional[asyncio.Event] = None, ) -> Tuple[bool, str]: """ Download a file with resumable downloads and retry mechanism @@ -187,6 +188,7 @@ class Downloader: use_auth: Whether to include authentication headers (e.g., CivitAI API key) custom_headers: Additional headers to include in request allow_resume: Whether to support resumable downloads + pause_event: Optional event that, when cleared, will pause streaming until set again Returns: Tuple[bool, str]: (success, save_path or error message) @@ -309,6 +311,8 @@ class Downloader: mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb' with open(part_path, mode) as f: async for chunk in response.content.iter_chunked(self.chunk_size): + if pause_event is not None and not pause_event.is_set(): + await pause_event.wait() if chunk: # Run blocking file write in executor await loop.run_in_executor(None, f.write, chunk) diff --git a/py/services/websocket_manager.py b/py/services/websocket_manager.py index b98dc5ba..4c786853 100644 --- a/py/services/websocket_manager.py +++ b/py/services/websocket_manager.py @@ -164,6 +164,11 @@ class WebSocketManager: if field in data: progress_entry[field] = data[field] + if 'status' in data: + progress_entry['status'] = data['status'] + if 'message' in data: + progress_entry['message'] = data['message'] + self._download_progress[download_id] = progress_entry if download_id not in self._download_websockets: diff --git a/tests/services/test_download_coordinator.py b/tests/services/test_download_coordinator.py new file mode 100644 index 00000000..617eb0d1 --- /dev/null +++ b/tests/services/test_download_coordinator.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Tuple +from unittest.mock import AsyncMock + +from py.services.download_coordinator import DownloadCoordinator + + +@dataclass +class StubWebSocketManager: + progress: Dict[str, Dict[str, Any]] = field(default_factory=dict) + broadcasts: List[Tuple[str, Dict[str, Any]]] = field(default_factory=list) + + def generate_download_id(self) -> str: + return "generated" + + def get_download_progress(self, download_id: str) -> Dict[str, Any] | None: + return self.progress.get(download_id) + + async def broadcast_download_progress(self, download_id: str, payload: Dict[str, Any]) -> None: + self.broadcasts.append((download_id, payload)) + + +async def test_pause_download_broadcasts_cached_state(): + ws_manager = StubWebSocketManager( + progress={ + "dl": { + "progress": 45, + "bytes_downloaded": 1024, + "total_bytes": 2048, + "bytes_per_second": 256.0, + } + } + ) + + download_manager = AsyncMock() + download_manager.pause_download = AsyncMock(return_value={"success": True}) + + async def factory(): + return download_manager + + coordinator = DownloadCoordinator(ws_manager=ws_manager, download_manager_factory=factory) + + result = await coordinator.pause_download("dl") + + assert result == {"success": True} + assert ws_manager.broadcasts == [ + ( + "dl", + { + "status": "paused", + "progress": 45, + "download_id": "dl", + "message": "Download paused by user", + "bytes_downloaded": 1024, + "total_bytes": 2048, + "bytes_per_second": 0.0, + }, + ) + ] + + +async def test_resume_download_broadcasts_cached_state(): + ws_manager = StubWebSocketManager( + progress={ + "dl": { + "progress": 75, + "bytes_downloaded": 2048, + "total_bytes": 4096, + "bytes_per_second": 512.0, + } + } + ) + + download_manager = AsyncMock() + download_manager.resume_download = AsyncMock(return_value={"success": True}) + + async def factory(): + return download_manager + + coordinator = DownloadCoordinator(ws_manager=ws_manager, download_manager_factory=factory) + + result = await coordinator.resume_download("dl") + + assert result == {"success": True} + assert ws_manager.broadcasts == [ + ( + "dl", + { + "status": "downloading", + "progress": 75, + "download_id": "dl", + "message": "Download resumed by user", + "bytes_downloaded": 2048, + "total_bytes": 4096, + "bytes_per_second": 512.0, + }, + ) + ] + + +async def test_pause_download_does_not_broadcast_on_failure(): + ws_manager = StubWebSocketManager() + + download_manager = AsyncMock() + download_manager.pause_download = AsyncMock(return_value={"success": False, "error": "nope"}) + + async def factory(): + return download_manager + + coordinator = DownloadCoordinator(ws_manager=ws_manager, download_manager_factory=factory) + + result = await coordinator.pause_download("dl") + + assert result == {"success": False, "error": "nope"} + assert ws_manager.broadcasts == [] diff --git a/tests/services/test_download_manager.py b/tests/services/test_download_manager.py index fad366f0..eedca8d7 100644 --- a/tests/services/test_download_manager.py +++ b/tests/services/test_download_manager.py @@ -1,3 +1,4 @@ +import asyncio import os from pathlib import Path from types import SimpleNamespace @@ -398,6 +399,67 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path): assert dummy_scanner.calls # ensure cache updated +async def test_pause_download_updates_state(): + manager = DownloadManager() + + download_id = "dl" + manager._download_tasks[download_id] = object() + pause_event = asyncio.Event() + pause_event.set() + manager._pause_events[download_id] = pause_event + manager._active_downloads[download_id] = { + "status": "downloading", + "bytes_per_second": 42.0, + } + + result = await manager.pause_download(download_id) + + assert result == {"success": True, "message": "Download paused successfully"} + assert download_id in manager._pause_events + assert manager._pause_events[download_id].is_set() is False + assert manager._active_downloads[download_id]["status"] == "paused" + assert manager._active_downloads[download_id]["bytes_per_second"] == 0.0 + + +async def test_pause_download_rejects_unknown_task(): + manager = DownloadManager() + + result = await manager.pause_download("missing") + + assert result == {"success": False, "error": "Download task not found"} + + +async def test_resume_download_sets_event_and_status(): + manager = DownloadManager() + + download_id = "dl" + pause_event = asyncio.Event() + manager._pause_events[download_id] = pause_event + manager._active_downloads[download_id] = { + "status": "paused", + "bytes_per_second": 0.0, + } + + result = await manager.resume_download(download_id) + + assert result == {"success": True, "message": "Download resumed successfully"} + assert manager._pause_events[download_id].is_set() is True + assert manager._active_downloads[download_id]["status"] == "downloading" + + +async def test_resume_download_rejects_when_not_paused(): + manager = DownloadManager() + + download_id = "dl" + pause_event = asyncio.Event() + pause_event.set() + manager._pause_events[download_id] = pause_event + + result = await manager.resume_download(download_id) + + assert result == {"success": False, "error": "Download is not paused"} + + @pytest.mark.asyncio async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_path): manager = DownloadManager()