test(downloads): cover pause and resume flows

This commit is contained in:
pixelpaws
2025-10-13 21:30:23 +08:00
parent 27370df93a
commit 055c1ca0d4
8 changed files with 388 additions and 27 deletions

View File

@@ -109,6 +109,56 @@ class DownloadCoordinator:
return result
async def pause_download(self, download_id: str) -> Dict[str, Any]:
"""Pause an active download and notify listeners."""
download_manager = await self._download_manager_factory()
result = await download_manager.pause_download(download_id)
if result.get("success"):
cached_progress = self._ws_manager.get_download_progress(download_id) or {}
payload: Dict[str, Any] = {
"status": "paused",
"progress": cached_progress.get("progress", 0),
"download_id": download_id,
"message": "Download paused by user",
}
for field in ("bytes_downloaded", "total_bytes", "bytes_per_second"):
if field in cached_progress:
payload[field] = cached_progress[field]
payload["bytes_per_second"] = 0.0
await self._ws_manager.broadcast_download_progress(download_id, payload)
return result
async def resume_download(self, download_id: str) -> Dict[str, Any]:
"""Resume a paused download and notify listeners."""
download_manager = await self._download_manager_factory()
result = await download_manager.resume_download(download_id)
if result.get("success"):
cached_progress = self._ws_manager.get_download_progress(download_id) or {}
payload: Dict[str, Any] = {
"status": "downloading",
"progress": cached_progress.get("progress", 0),
"download_id": download_id,
"message": "Download resumed by user",
}
for field in ("bytes_downloaded", "total_bytes"):
if field in cached_progress:
payload[field] = cached_progress[field]
payload["bytes_per_second"] = cached_progress.get("bytes_per_second", 0.0)
await self._ws_manager.broadcast_download_progress(download_id, payload)
return result
async def list_active_downloads(self) -> Dict[str, Any]:
"""Return the active download map from the underlying manager."""

View File

@@ -43,6 +43,7 @@ class DownloadManager:
self._active_downloads = OrderedDict() # download_id -> download_info
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
self._download_tasks = {} # download_id -> asyncio.Task
self._pause_events: Dict[str, asyncio.Event] = {}
async def _get_lora_scanner(self):
"""Get the lora scanner from registry"""
@@ -88,11 +89,15 @@ class DownloadManager:
'total_bytes': None,
'bytes_per_second': 0.0,
}
pause_event = asyncio.Event()
pause_event.set()
self._pause_events[task_id] = pause_event
# Create tracking task
download_task = asyncio.create_task(
self._download_with_semaphore(
task_id, model_id, model_version_id, save_dir,
task_id, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths, source
)
)
@@ -111,9 +116,10 @@ class DownloadManager:
# Clean up task reference
if task_id in self._download_tasks:
del self._download_tasks[task_id]
self._pause_events.pop(task_id, None)
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
save_dir: str, relative_path: str,
save_dir: str, relative_path: str,
progress_callback=None, use_default_paths: bool = False,
source: str = None):
"""Execute download with semaphore to limit concurrency"""
@@ -140,6 +146,13 @@ class DownloadManager:
# Acquire semaphore to limit concurrent downloads
try:
async with self._download_semaphore:
pause_event = self._pause_events.get(task_id)
if pause_event is not None and not pause_event.is_set():
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'paused'
self._active_downloads[task_id]['bytes_per_second'] = 0.0
await pause_event.wait()
# Update status to downloading
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'downloading'
@@ -189,9 +202,17 @@ class DownloadManager:
if task_id in self._active_downloads:
del self._active_downloads[task_id]
async def _execute_original_download(self, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths,
download_id=None, source=None):
async def _execute_original_download(
self,
model_id,
model_version_id,
save_dir,
relative_path,
progress_callback,
use_default_paths,
download_id=None,
source=None,
):
"""Wrapper for original download_from_civitai implementation"""
try:
# Check if model version already exists in library
@@ -345,7 +366,7 @@ class DownloadManager:
relative_path=relative_path,
progress_callback=progress_callback,
model_type=model_type,
download_id=download_id
download_id=download_id,
)
# If early_access_msg exists and download failed, replace error message
@@ -410,10 +431,17 @@ class DownloadManager:
return formatted_path
async def _execute_download(self, download_urls: List[str], save_dir: str,
metadata, version_info: Dict,
relative_path: str, progress_callback=None,
model_type: str = "lora", download_id: str = None) -> Dict:
async def _execute_download(
self,
download_urls: List[str],
save_dir: str,
metadata,
version_info: Dict,
relative_path: str,
progress_callback=None,
model_type: str = "lora",
download_id: str = None,
) -> Dict:
"""Execute the actual download process including preview images and model files"""
try:
# Extract original filename details
@@ -444,6 +472,8 @@ class DownloadManager:
part_path = save_path + '.part'
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
pause_event = self._pause_events.get(download_id) if download_id else None
# Store file paths in active_downloads for potential cleanup
if download_id and download_id in self._active_downloads:
@@ -558,15 +588,22 @@ class DownloadManager:
last_error = None
for download_url in download_urls:
use_auth = download_url.startswith("https://civitai.com/api/download/")
success, result = await downloader.download_file(
download_url,
save_path, # Use full path instead of separate dir and filename
progress_callback=lambda progress, snapshot=None: self._handle_download_progress(
download_kwargs = {
"progress_callback": lambda progress, snapshot=None: self._handle_download_progress(
progress,
progress_callback,
snapshot,
),
use_auth=use_auth # Only use authentication for Civitai downloads
"use_auth": use_auth, # Only use authentication for Civitai downloads
}
if pause_event is not None:
download_kwargs["pause_event"] = pause_event
success, result = await downloader.download_file(
download_url,
save_path, # Use full path instead of separate dir and filename
**download_kwargs,
)
if success:
@@ -675,7 +712,7 @@ class DownloadManager:
async def cancel_download(self, download_id: str) -> Dict:
"""Cancel an active download by download_id
Args:
download_id: The unique identifier of the download task
@@ -689,7 +726,11 @@ class DownloadManager:
# Get the task and cancel it
task = self._download_tasks[download_id]
task.cancel()
pause_event = self._pause_events.get(download_id)
if pause_event is not None:
pause_event.set()
# Update status in active downloads
if download_id in self._active_downloads:
self._active_downloads[download_id]['status'] = 'cancelling'
@@ -756,6 +797,52 @@ class DownloadManager:
except Exception as e:
logger.error(f"Error cancelling download: {e}", exc_info=True)
return {'success': False, 'error': str(e)}
finally:
self._pause_events.pop(download_id, None)
async def pause_download(self, download_id: str) -> Dict:
"""Pause an active download without losing progress."""
if download_id not in self._download_tasks:
return {'success': False, 'error': 'Download task not found'}
pause_event = self._pause_events.get(download_id)
if pause_event is None:
pause_event = asyncio.Event()
pause_event.set()
self._pause_events[download_id] = pause_event
if not pause_event.is_set():
return {'success': False, 'error': 'Download is already paused'}
pause_event.clear()
download_info = self._active_downloads.get(download_id)
if download_info is not None:
download_info['status'] = 'paused'
download_info['bytes_per_second'] = 0.0
return {'success': True, 'message': 'Download paused successfully'}
async def resume_download(self, download_id: str) -> Dict:
"""Resume a previously paused download."""
pause_event = self._pause_events.get(download_id)
if pause_event is None:
return {'success': False, 'error': 'Download task not found'}
if pause_event.is_set():
return {'success': False, 'error': 'Download is not paused'}
pause_event.set()
download_info = self._active_downloads.get(download_id)
if download_info is not None:
if download_info.get('status') == 'paused':
download_info['status'] = 'downloading'
download_info.setdefault('bytes_per_second', 0.0)
return {'success': True, 'message': 'Download resumed successfully'}
@staticmethod
def _coerce_progress_value(progress) -> float:

View File

@@ -175,7 +175,8 @@ class Downloader:
progress_callback: Optional[Callable[..., Awaitable[None]]] = None,
use_auth: bool = False,
custom_headers: Optional[Dict[str, str]] = None,
allow_resume: bool = True
allow_resume: bool = True,
pause_event: Optional[asyncio.Event] = None,
) -> Tuple[bool, str]:
"""
Download a file with resumable downloads and retry mechanism
@@ -187,6 +188,7 @@ class Downloader:
use_auth: Whether to include authentication headers (e.g., CivitAI API key)
custom_headers: Additional headers to include in request
allow_resume: Whether to support resumable downloads
pause_event: Optional event that, when cleared, will pause streaming until set again
Returns:
Tuple[bool, str]: (success, save_path or error message)
@@ -309,6 +311,8 @@ class Downloader:
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
with open(part_path, mode) as f:
async for chunk in response.content.iter_chunked(self.chunk_size):
if pause_event is not None and not pause_event.is_set():
await pause_event.wait()
if chunk:
# Run blocking file write in executor
await loop.run_in_executor(None, f.write, chunk)

View File

@@ -164,6 +164,11 @@ class WebSocketManager:
if field in data:
progress_entry[field] = data[field]
if 'status' in data:
progress_entry['status'] = data['status']
if 'message' in data:
progress_entry['message'] = data['message']
self._download_progress[download_id] = progress_entry
if download_id not in self._download_websockets: