mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-26 15:38:52 -03:00
Merge pull request #564 from willmiao/codex/design-apis-for-pause-and-resume-download
test: add coverage for download pause and resume controls
This commit is contained in:
@@ -758,6 +758,30 @@ class ModelDownloadHandler:
|
|||||||
self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True)
|
self._logger.error("Error cancelling download via GET: %s", exc, exc_info=True)
|
||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def pause_download_get(self, request: web.Request) -> web.Response:
|
||||||
|
try:
|
||||||
|
download_id = request.query.get("download_id")
|
||||||
|
if not download_id:
|
||||||
|
return web.json_response({"success": False, "error": "Download ID is required"}, status=400)
|
||||||
|
result = await self._download_coordinator.pause_download(download_id)
|
||||||
|
status = 200 if result.get("success") else 400
|
||||||
|
return web.json_response(result, status=status)
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.error("Error pausing download via GET: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def resume_download_get(self, request: web.Request) -> web.Response:
|
||||||
|
try:
|
||||||
|
download_id = request.query.get("download_id")
|
||||||
|
if not download_id:
|
||||||
|
return web.json_response({"success": False, "error": "Download ID is required"}, status=400)
|
||||||
|
result = await self._download_coordinator.resume_download(download_id)
|
||||||
|
status = 200 if result.get("success") else 400
|
||||||
|
return web.json_response(result, status=status)
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.error("Error resuming download via GET: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
async def get_download_progress(self, request: web.Request) -> web.Response:
|
async def get_download_progress(self, request: web.Request) -> web.Response:
|
||||||
try:
|
try:
|
||||||
download_id = request.match_info.get("download_id")
|
download_id = request.match_info.get("download_id")
|
||||||
@@ -766,15 +790,23 @@ class ModelDownloadHandler:
|
|||||||
progress_data = self._ws_manager.get_download_progress(download_id)
|
progress_data = self._ws_manager.get_download_progress(download_id)
|
||||||
if progress_data is None:
|
if progress_data is None:
|
||||||
return web.json_response({"success": False, "error": "Download ID not found"}, status=404)
|
return web.json_response({"success": False, "error": "Download ID not found"}, status=404)
|
||||||
return web.json_response(
|
response_payload = {
|
||||||
{
|
|
||||||
"success": True,
|
"success": True,
|
||||||
"progress": progress_data.get("progress", 0),
|
"progress": progress_data.get("progress", 0),
|
||||||
"bytes_downloaded": progress_data.get("bytes_downloaded"),
|
"bytes_downloaded": progress_data.get("bytes_downloaded"),
|
||||||
"total_bytes": progress_data.get("total_bytes"),
|
"total_bytes": progress_data.get("total_bytes"),
|
||||||
"bytes_per_second": progress_data.get("bytes_per_second", 0.0),
|
"bytes_per_second": progress_data.get("bytes_per_second", 0.0),
|
||||||
}
|
}
|
||||||
)
|
|
||||||
|
status = progress_data.get("status")
|
||||||
|
if status and status != "progress":
|
||||||
|
response_payload["status"] = status
|
||||||
|
if "message" in progress_data:
|
||||||
|
response_payload["message"] = progress_data["message"]
|
||||||
|
elif status is None and "message" in progress_data:
|
||||||
|
response_payload["message"] = progress_data["message"]
|
||||||
|
|
||||||
|
return web.json_response(response_payload)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
self._logger.error("Error getting download progress: %s", exc, exc_info=True)
|
self._logger.error("Error getting download progress: %s", exc, exc_info=True)
|
||||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
@@ -1025,6 +1057,8 @@ class ModelHandlerSet:
|
|||||||
"download_model": self.download.download_model,
|
"download_model": self.download.download_model,
|
||||||
"download_model_get": self.download.download_model_get,
|
"download_model_get": self.download.download_model_get,
|
||||||
"cancel_download_get": self.download.cancel_download_get,
|
"cancel_download_get": self.download.cancel_download_get,
|
||||||
|
"pause_download_get": self.download.pause_download_get,
|
||||||
|
"resume_download_get": self.download.resume_download_get,
|
||||||
"get_download_progress": self.download.get_download_progress,
|
"get_download_progress": self.download.get_download_progress,
|
||||||
"get_civitai_versions": self.civitai.get_civitai_versions,
|
"get_civitai_versions": self.civitai.get_civitai_versions,
|
||||||
"get_civitai_model_by_version": self.civitai.get_civitai_model_by_version,
|
"get_civitai_model_by_version": self.civitai.get_civitai_model_by_version,
|
||||||
|
|||||||
@@ -58,6 +58,8 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
|
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
|
||||||
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
|
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
|
||||||
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
|
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
|
||||||
|
RouteDefinition("GET", "/api/lm/pause-download", "pause_download_get"),
|
||||||
|
RouteDefinition("GET", "/api/lm/resume-download", "resume_download_get"),
|
||||||
RouteDefinition("GET", "/api/lm/download-progress/{download_id}", "get_download_progress"),
|
RouteDefinition("GET", "/api/lm/download-progress/{download_id}", "get_download_progress"),
|
||||||
RouteDefinition("GET", "/{prefix}", "handle_models_page"),
|
RouteDefinition("GET", "/{prefix}", "handle_models_page"),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -109,6 +109,56 @@ class DownloadCoordinator:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
async def pause_download(self, download_id: str) -> Dict[str, Any]:
|
||||||
|
"""Pause an active download and notify listeners."""
|
||||||
|
|
||||||
|
download_manager = await self._download_manager_factory()
|
||||||
|
result = await download_manager.pause_download(download_id)
|
||||||
|
|
||||||
|
if result.get("success"):
|
||||||
|
cached_progress = self._ws_manager.get_download_progress(download_id) or {}
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"status": "paused",
|
||||||
|
"progress": cached_progress.get("progress", 0),
|
||||||
|
"download_id": download_id,
|
||||||
|
"message": "Download paused by user",
|
||||||
|
}
|
||||||
|
|
||||||
|
for field in ("bytes_downloaded", "total_bytes", "bytes_per_second"):
|
||||||
|
if field in cached_progress:
|
||||||
|
payload[field] = cached_progress[field]
|
||||||
|
|
||||||
|
payload["bytes_per_second"] = 0.0
|
||||||
|
|
||||||
|
await self._ws_manager.broadcast_download_progress(download_id, payload)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def resume_download(self, download_id: str) -> Dict[str, Any]:
|
||||||
|
"""Resume a paused download and notify listeners."""
|
||||||
|
|
||||||
|
download_manager = await self._download_manager_factory()
|
||||||
|
result = await download_manager.resume_download(download_id)
|
||||||
|
|
||||||
|
if result.get("success"):
|
||||||
|
cached_progress = self._ws_manager.get_download_progress(download_id) or {}
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
"status": "downloading",
|
||||||
|
"progress": cached_progress.get("progress", 0),
|
||||||
|
"download_id": download_id,
|
||||||
|
"message": "Download resumed by user",
|
||||||
|
}
|
||||||
|
|
||||||
|
for field in ("bytes_downloaded", "total_bytes"):
|
||||||
|
if field in cached_progress:
|
||||||
|
payload[field] = cached_progress[field]
|
||||||
|
|
||||||
|
payload["bytes_per_second"] = cached_progress.get("bytes_per_second", 0.0)
|
||||||
|
|
||||||
|
await self._ws_manager.broadcast_download_progress(download_id, payload)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
async def list_active_downloads(self) -> Dict[str, Any]:
|
async def list_active_downloads(self) -> Dict[str, Any]:
|
||||||
"""Return the active download map from the underlying manager."""
|
"""Return the active download map from the underlying manager."""
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ class DownloadManager:
|
|||||||
self._active_downloads = OrderedDict() # download_id -> download_info
|
self._active_downloads = OrderedDict() # download_id -> download_info
|
||||||
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
||||||
self._download_tasks = {} # download_id -> asyncio.Task
|
self._download_tasks = {} # download_id -> asyncio.Task
|
||||||
|
self._pause_events: Dict[str, asyncio.Event] = {}
|
||||||
|
|
||||||
async def _get_lora_scanner(self):
|
async def _get_lora_scanner(self):
|
||||||
"""Get the lora scanner from registry"""
|
"""Get the lora scanner from registry"""
|
||||||
@@ -89,6 +90,10 @@ class DownloadManager:
|
|||||||
'bytes_per_second': 0.0,
|
'bytes_per_second': 0.0,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pause_event = asyncio.Event()
|
||||||
|
pause_event.set()
|
||||||
|
self._pause_events[task_id] = pause_event
|
||||||
|
|
||||||
# Create tracking task
|
# Create tracking task
|
||||||
download_task = asyncio.create_task(
|
download_task = asyncio.create_task(
|
||||||
self._download_with_semaphore(
|
self._download_with_semaphore(
|
||||||
@@ -111,6 +116,7 @@ class DownloadManager:
|
|||||||
# Clean up task reference
|
# Clean up task reference
|
||||||
if task_id in self._download_tasks:
|
if task_id in self._download_tasks:
|
||||||
del self._download_tasks[task_id]
|
del self._download_tasks[task_id]
|
||||||
|
self._pause_events.pop(task_id, None)
|
||||||
|
|
||||||
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
|
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
|
||||||
save_dir: str, relative_path: str,
|
save_dir: str, relative_path: str,
|
||||||
@@ -140,6 +146,13 @@ class DownloadManager:
|
|||||||
# Acquire semaphore to limit concurrent downloads
|
# Acquire semaphore to limit concurrent downloads
|
||||||
try:
|
try:
|
||||||
async with self._download_semaphore:
|
async with self._download_semaphore:
|
||||||
|
pause_event = self._pause_events.get(task_id)
|
||||||
|
if pause_event is not None and not pause_event.is_set():
|
||||||
|
if task_id in self._active_downloads:
|
||||||
|
self._active_downloads[task_id]['status'] = 'paused'
|
||||||
|
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
||||||
|
await pause_event.wait()
|
||||||
|
|
||||||
# Update status to downloading
|
# Update status to downloading
|
||||||
if task_id in self._active_downloads:
|
if task_id in self._active_downloads:
|
||||||
self._active_downloads[task_id]['status'] = 'downloading'
|
self._active_downloads[task_id]['status'] = 'downloading'
|
||||||
@@ -189,9 +202,17 @@ class DownloadManager:
|
|||||||
if task_id in self._active_downloads:
|
if task_id in self._active_downloads:
|
||||||
del self._active_downloads[task_id]
|
del self._active_downloads[task_id]
|
||||||
|
|
||||||
async def _execute_original_download(self, model_id, model_version_id, save_dir,
|
async def _execute_original_download(
|
||||||
relative_path, progress_callback, use_default_paths,
|
self,
|
||||||
download_id=None, source=None):
|
model_id,
|
||||||
|
model_version_id,
|
||||||
|
save_dir,
|
||||||
|
relative_path,
|
||||||
|
progress_callback,
|
||||||
|
use_default_paths,
|
||||||
|
download_id=None,
|
||||||
|
source=None,
|
||||||
|
):
|
||||||
"""Wrapper for original download_from_civitai implementation"""
|
"""Wrapper for original download_from_civitai implementation"""
|
||||||
try:
|
try:
|
||||||
# Check if model version already exists in library
|
# Check if model version already exists in library
|
||||||
@@ -345,7 +366,7 @@ class DownloadManager:
|
|||||||
relative_path=relative_path,
|
relative_path=relative_path,
|
||||||
progress_callback=progress_callback,
|
progress_callback=progress_callback,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
download_id=download_id
|
download_id=download_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If early_access_msg exists and download failed, replace error message
|
# If early_access_msg exists and download failed, replace error message
|
||||||
@@ -410,10 +431,17 @@ class DownloadManager:
|
|||||||
|
|
||||||
return formatted_path
|
return formatted_path
|
||||||
|
|
||||||
async def _execute_download(self, download_urls: List[str], save_dir: str,
|
async def _execute_download(
|
||||||
metadata, version_info: Dict,
|
self,
|
||||||
relative_path: str, progress_callback=None,
|
download_urls: List[str],
|
||||||
model_type: str = "lora", download_id: str = None) -> Dict:
|
save_dir: str,
|
||||||
|
metadata,
|
||||||
|
version_info: Dict,
|
||||||
|
relative_path: str,
|
||||||
|
progress_callback=None,
|
||||||
|
model_type: str = "lora",
|
||||||
|
download_id: str = None,
|
||||||
|
) -> Dict:
|
||||||
"""Execute the actual download process including preview images and model files"""
|
"""Execute the actual download process including preview images and model files"""
|
||||||
try:
|
try:
|
||||||
# Extract original filename details
|
# Extract original filename details
|
||||||
@@ -445,6 +473,8 @@ class DownloadManager:
|
|||||||
part_path = save_path + '.part'
|
part_path = save_path + '.part'
|
||||||
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
||||||
|
|
||||||
|
pause_event = self._pause_events.get(download_id) if download_id else None
|
||||||
|
|
||||||
# Store file paths in active_downloads for potential cleanup
|
# Store file paths in active_downloads for potential cleanup
|
||||||
if download_id and download_id in self._active_downloads:
|
if download_id and download_id in self._active_downloads:
|
||||||
self._active_downloads[download_id]['file_path'] = save_path
|
self._active_downloads[download_id]['file_path'] = save_path
|
||||||
@@ -558,15 +588,22 @@ class DownloadManager:
|
|||||||
last_error = None
|
last_error = None
|
||||||
for download_url in download_urls:
|
for download_url in download_urls:
|
||||||
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
||||||
success, result = await downloader.download_file(
|
download_kwargs = {
|
||||||
download_url,
|
"progress_callback": lambda progress, snapshot=None: self._handle_download_progress(
|
||||||
save_path, # Use full path instead of separate dir and filename
|
|
||||||
progress_callback=lambda progress, snapshot=None: self._handle_download_progress(
|
|
||||||
progress,
|
progress,
|
||||||
progress_callback,
|
progress_callback,
|
||||||
snapshot,
|
snapshot,
|
||||||
),
|
),
|
||||||
use_auth=use_auth # Only use authentication for Civitai downloads
|
"use_auth": use_auth, # Only use authentication for Civitai downloads
|
||||||
|
}
|
||||||
|
|
||||||
|
if pause_event is not None:
|
||||||
|
download_kwargs["pause_event"] = pause_event
|
||||||
|
|
||||||
|
success, result = await downloader.download_file(
|
||||||
|
download_url,
|
||||||
|
save_path, # Use full path instead of separate dir and filename
|
||||||
|
**download_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
@@ -690,6 +727,10 @@ class DownloadManager:
|
|||||||
task = self._download_tasks[download_id]
|
task = self._download_tasks[download_id]
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
|
pause_event = self._pause_events.get(download_id)
|
||||||
|
if pause_event is not None:
|
||||||
|
pause_event.set()
|
||||||
|
|
||||||
# Update status in active downloads
|
# Update status in active downloads
|
||||||
if download_id in self._active_downloads:
|
if download_id in self._active_downloads:
|
||||||
self._active_downloads[download_id]['status'] = 'cancelling'
|
self._active_downloads[download_id]['status'] = 'cancelling'
|
||||||
@@ -756,6 +797,52 @@ class DownloadManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
||||||
return {'success': False, 'error': str(e)}
|
return {'success': False, 'error': str(e)}
|
||||||
|
finally:
|
||||||
|
self._pause_events.pop(download_id, None)
|
||||||
|
|
||||||
|
async def pause_download(self, download_id: str) -> Dict:
|
||||||
|
"""Pause an active download without losing progress."""
|
||||||
|
|
||||||
|
if download_id not in self._download_tasks:
|
||||||
|
return {'success': False, 'error': 'Download task not found'}
|
||||||
|
|
||||||
|
pause_event = self._pause_events.get(download_id)
|
||||||
|
if pause_event is None:
|
||||||
|
pause_event = asyncio.Event()
|
||||||
|
pause_event.set()
|
||||||
|
self._pause_events[download_id] = pause_event
|
||||||
|
|
||||||
|
if not pause_event.is_set():
|
||||||
|
return {'success': False, 'error': 'Download is already paused'}
|
||||||
|
|
||||||
|
pause_event.clear()
|
||||||
|
|
||||||
|
download_info = self._active_downloads.get(download_id)
|
||||||
|
if download_info is not None:
|
||||||
|
download_info['status'] = 'paused'
|
||||||
|
download_info['bytes_per_second'] = 0.0
|
||||||
|
|
||||||
|
return {'success': True, 'message': 'Download paused successfully'}
|
||||||
|
|
||||||
|
async def resume_download(self, download_id: str) -> Dict:
|
||||||
|
"""Resume a previously paused download."""
|
||||||
|
|
||||||
|
pause_event = self._pause_events.get(download_id)
|
||||||
|
if pause_event is None:
|
||||||
|
return {'success': False, 'error': 'Download task not found'}
|
||||||
|
|
||||||
|
if pause_event.is_set():
|
||||||
|
return {'success': False, 'error': 'Download is not paused'}
|
||||||
|
|
||||||
|
pause_event.set()
|
||||||
|
|
||||||
|
download_info = self._active_downloads.get(download_id)
|
||||||
|
if download_info is not None:
|
||||||
|
if download_info.get('status') == 'paused':
|
||||||
|
download_info['status'] = 'downloading'
|
||||||
|
download_info.setdefault('bytes_per_second', 0.0)
|
||||||
|
|
||||||
|
return {'success': True, 'message': 'Download resumed successfully'}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _coerce_progress_value(progress) -> float:
|
def _coerce_progress_value(progress) -> float:
|
||||||
|
|||||||
@@ -175,7 +175,8 @@ class Downloader:
|
|||||||
progress_callback: Optional[Callable[..., Awaitable[None]]] = None,
|
progress_callback: Optional[Callable[..., Awaitable[None]]] = None,
|
||||||
use_auth: bool = False,
|
use_auth: bool = False,
|
||||||
custom_headers: Optional[Dict[str, str]] = None,
|
custom_headers: Optional[Dict[str, str]] = None,
|
||||||
allow_resume: bool = True
|
allow_resume: bool = True,
|
||||||
|
pause_event: Optional[asyncio.Event] = None,
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
Download a file with resumable downloads and retry mechanism
|
Download a file with resumable downloads and retry mechanism
|
||||||
@@ -187,6 +188,7 @@ class Downloader:
|
|||||||
use_auth: Whether to include authentication headers (e.g., CivitAI API key)
|
use_auth: Whether to include authentication headers (e.g., CivitAI API key)
|
||||||
custom_headers: Additional headers to include in request
|
custom_headers: Additional headers to include in request
|
||||||
allow_resume: Whether to support resumable downloads
|
allow_resume: Whether to support resumable downloads
|
||||||
|
pause_event: Optional event that, when cleared, will pause streaming until set again
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, str]: (success, save_path or error message)
|
Tuple[bool, str]: (success, save_path or error message)
|
||||||
@@ -309,6 +311,8 @@ class Downloader:
|
|||||||
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
|
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
|
||||||
with open(part_path, mode) as f:
|
with open(part_path, mode) as f:
|
||||||
async for chunk in response.content.iter_chunked(self.chunk_size):
|
async for chunk in response.content.iter_chunked(self.chunk_size):
|
||||||
|
if pause_event is not None and not pause_event.is_set():
|
||||||
|
await pause_event.wait()
|
||||||
if chunk:
|
if chunk:
|
||||||
# Run blocking file write in executor
|
# Run blocking file write in executor
|
||||||
await loop.run_in_executor(None, f.write, chunk)
|
await loop.run_in_executor(None, f.write, chunk)
|
||||||
|
|||||||
@@ -164,6 +164,11 @@ class WebSocketManager:
|
|||||||
if field in data:
|
if field in data:
|
||||||
progress_entry[field] = data[field]
|
progress_entry[field] = data[field]
|
||||||
|
|
||||||
|
if 'status' in data:
|
||||||
|
progress_entry['status'] = data['status']
|
||||||
|
if 'message' in data:
|
||||||
|
progress_entry['message'] = data['message']
|
||||||
|
|
||||||
self._download_progress[download_id] = progress_entry
|
self._download_progress[download_id] = progress_entry
|
||||||
|
|
||||||
if download_id not in self._download_websockets:
|
if download_id not in self._download_websockets:
|
||||||
|
|||||||
117
tests/services/test_download_coordinator.py
Normal file
117
tests/services/test_download_coordinator.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
from py.services.download_coordinator import DownloadCoordinator
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StubWebSocketManager:
|
||||||
|
progress: Dict[str, Dict[str, Any]] = field(default_factory=dict)
|
||||||
|
broadcasts: List[Tuple[str, Dict[str, Any]]] = field(default_factory=list)
|
||||||
|
|
||||||
|
def generate_download_id(self) -> str:
|
||||||
|
return "generated"
|
||||||
|
|
||||||
|
def get_download_progress(self, download_id: str) -> Dict[str, Any] | None:
|
||||||
|
return self.progress.get(download_id)
|
||||||
|
|
||||||
|
async def broadcast_download_progress(self, download_id: str, payload: Dict[str, Any]) -> None:
|
||||||
|
self.broadcasts.append((download_id, payload))
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pause_download_broadcasts_cached_state():
|
||||||
|
ws_manager = StubWebSocketManager(
|
||||||
|
progress={
|
||||||
|
"dl": {
|
||||||
|
"progress": 45,
|
||||||
|
"bytes_downloaded": 1024,
|
||||||
|
"total_bytes": 2048,
|
||||||
|
"bytes_per_second": 256.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
download_manager = AsyncMock()
|
||||||
|
download_manager.pause_download = AsyncMock(return_value={"success": True})
|
||||||
|
|
||||||
|
async def factory():
|
||||||
|
return download_manager
|
||||||
|
|
||||||
|
coordinator = DownloadCoordinator(ws_manager=ws_manager, download_manager_factory=factory)
|
||||||
|
|
||||||
|
result = await coordinator.pause_download("dl")
|
||||||
|
|
||||||
|
assert result == {"success": True}
|
||||||
|
assert ws_manager.broadcasts == [
|
||||||
|
(
|
||||||
|
"dl",
|
||||||
|
{
|
||||||
|
"status": "paused",
|
||||||
|
"progress": 45,
|
||||||
|
"download_id": "dl",
|
||||||
|
"message": "Download paused by user",
|
||||||
|
"bytes_downloaded": 1024,
|
||||||
|
"total_bytes": 2048,
|
||||||
|
"bytes_per_second": 0.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resume_download_broadcasts_cached_state():
|
||||||
|
ws_manager = StubWebSocketManager(
|
||||||
|
progress={
|
||||||
|
"dl": {
|
||||||
|
"progress": 75,
|
||||||
|
"bytes_downloaded": 2048,
|
||||||
|
"total_bytes": 4096,
|
||||||
|
"bytes_per_second": 512.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
download_manager = AsyncMock()
|
||||||
|
download_manager.resume_download = AsyncMock(return_value={"success": True})
|
||||||
|
|
||||||
|
async def factory():
|
||||||
|
return download_manager
|
||||||
|
|
||||||
|
coordinator = DownloadCoordinator(ws_manager=ws_manager, download_manager_factory=factory)
|
||||||
|
|
||||||
|
result = await coordinator.resume_download("dl")
|
||||||
|
|
||||||
|
assert result == {"success": True}
|
||||||
|
assert ws_manager.broadcasts == [
|
||||||
|
(
|
||||||
|
"dl",
|
||||||
|
{
|
||||||
|
"status": "downloading",
|
||||||
|
"progress": 75,
|
||||||
|
"download_id": "dl",
|
||||||
|
"message": "Download resumed by user",
|
||||||
|
"bytes_downloaded": 2048,
|
||||||
|
"total_bytes": 4096,
|
||||||
|
"bytes_per_second": 512.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pause_download_does_not_broadcast_on_failure():
|
||||||
|
ws_manager = StubWebSocketManager()
|
||||||
|
|
||||||
|
download_manager = AsyncMock()
|
||||||
|
download_manager.pause_download = AsyncMock(return_value={"success": False, "error": "nope"})
|
||||||
|
|
||||||
|
async def factory():
|
||||||
|
return download_manager
|
||||||
|
|
||||||
|
coordinator = DownloadCoordinator(ws_manager=ws_manager, download_manager_factory=factory)
|
||||||
|
|
||||||
|
result = await coordinator.pause_download("dl")
|
||||||
|
|
||||||
|
assert result == {"success": False, "error": "nope"}
|
||||||
|
assert ws_manager.broadcasts == []
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
@@ -398,6 +399,67 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path):
|
|||||||
assert dummy_scanner.calls # ensure cache updated
|
assert dummy_scanner.calls # ensure cache updated
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pause_download_updates_state():
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
download_id = "dl"
|
||||||
|
manager._download_tasks[download_id] = object()
|
||||||
|
pause_event = asyncio.Event()
|
||||||
|
pause_event.set()
|
||||||
|
manager._pause_events[download_id] = pause_event
|
||||||
|
manager._active_downloads[download_id] = {
|
||||||
|
"status": "downloading",
|
||||||
|
"bytes_per_second": 42.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await manager.pause_download(download_id)
|
||||||
|
|
||||||
|
assert result == {"success": True, "message": "Download paused successfully"}
|
||||||
|
assert download_id in manager._pause_events
|
||||||
|
assert manager._pause_events[download_id].is_set() is False
|
||||||
|
assert manager._active_downloads[download_id]["status"] == "paused"
|
||||||
|
assert manager._active_downloads[download_id]["bytes_per_second"] == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pause_download_rejects_unknown_task():
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
result = await manager.pause_download("missing")
|
||||||
|
|
||||||
|
assert result == {"success": False, "error": "Download task not found"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resume_download_sets_event_and_status():
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
download_id = "dl"
|
||||||
|
pause_event = asyncio.Event()
|
||||||
|
manager._pause_events[download_id] = pause_event
|
||||||
|
manager._active_downloads[download_id] = {
|
||||||
|
"status": "paused",
|
||||||
|
"bytes_per_second": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await manager.resume_download(download_id)
|
||||||
|
|
||||||
|
assert result == {"success": True, "message": "Download resumed successfully"}
|
||||||
|
assert manager._pause_events[download_id].is_set() is True
|
||||||
|
assert manager._active_downloads[download_id]["status"] == "downloading"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resume_download_rejects_when_not_paused():
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
download_id = "dl"
|
||||||
|
pause_event = asyncio.Event()
|
||||||
|
pause_event.set()
|
||||||
|
manager._pause_events[download_id] = pause_event
|
||||||
|
|
||||||
|
result = await manager.resume_download(download_id)
|
||||||
|
|
||||||
|
assert result == {"success": False, "error": "Download is not paused"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_path):
|
async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_path):
|
||||||
manager = DownloadManager()
|
manager = DownloadManager()
|
||||||
|
|||||||
Reference in New Issue
Block a user