diff --git a/py/services/download_manager.py b/py/services/download_manager.py index a0b7b6fe..3dffbc01 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -15,7 +15,7 @@ from ..utils.metadata_manager import MetadataManager from .service_registry import ServiceRegistry from .settings_manager import get_settings_manager from .metadata_service import get_default_metadata_provider -from .downloader import get_downloader, DownloadProgress +from .downloader import get_downloader, DownloadProgress, DownloadStreamControl # Download to temporary file first import tempfile @@ -44,7 +44,7 @@ class DownloadManager: self._active_downloads = OrderedDict() # download_id -> download_info self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads self._download_tasks = {} # download_id -> asyncio.Task - self._pause_events: Dict[str, asyncio.Event] = {} + self._pause_events: Dict[str, DownloadStreamControl] = {} async def _get_lora_scanner(self): """Get the lora scanner from registry""" @@ -89,11 +89,11 @@ class DownloadManager: 'bytes_downloaded': 0, 'total_bytes': None, 'bytes_per_second': 0.0, + 'last_progress_timestamp': None, } - pause_event = asyncio.Event() - pause_event.set() - self._pause_events[task_id] = pause_event + pause_control = DownloadStreamControl() + self._pause_events[task_id] = pause_control # Create tracking task download_task = asyncio.create_task( @@ -140,19 +140,23 @@ class DownloadManager: info['bytes_downloaded'] = snapshot.bytes_downloaded info['total_bytes'] = snapshot.total_bytes info['bytes_per_second'] = snapshot.bytes_per_second + pause_control = self._pause_events.get(task_id) + if isinstance(pause_control, DownloadStreamControl): + pause_control.mark_progress(snapshot.timestamp) + info['last_progress_timestamp'] = pause_control.last_progress_timestamp if original_callback: await self._dispatch_progress(original_callback, snapshot, progress_value) - + # Acquire semaphore to limit concurrent 
downloads try: async with self._download_semaphore: - pause_event = self._pause_events.get(task_id) - if pause_event is not None and not pause_event.is_set(): + pause_control = self._pause_events.get(task_id) + if pause_control is not None and pause_control.is_paused(): if task_id in self._active_downloads: self._active_downloads[task_id]['status'] = 'paused' self._active_downloads[task_id]['bytes_per_second'] = 0.0 - await pause_event.wait() + await pause_control.wait() # Update status to downloading if task_id in self._active_downloads: @@ -478,8 +482,8 @@ class DownloadManager: part_path = save_path + '.part' metadata_path = os.path.splitext(save_path)[0] + '.metadata.json' - pause_event = self._pause_events.get(download_id) if download_id else None - + pause_control = self._pause_events.get(download_id) if download_id else None + # Store file paths in active_downloads for potential cleanup if download_id and download_id in self._active_downloads: self._active_downloads[download_id]['file_path'] = save_path @@ -590,6 +594,8 @@ class DownloadManager: # Download model file with progress tracking using downloader downloader = await get_downloader() + if pause_control is not None: + pause_control.update_stall_timeout(downloader.stall_timeout) last_error = None for download_url in download_urls: use_auth = download_url.startswith("https://civitai.com/api/download/") @@ -602,8 +608,8 @@ class DownloadManager: "use_auth": use_auth, # Only use authentication for Civitai downloads } - if pause_event is not None: - download_kwargs["pause_event"] = pause_event + if pause_control is not None: + download_kwargs["pause_event"] = pause_control success, result = await downloader.download_file( download_url, @@ -756,9 +762,9 @@ class DownloadManager: task = self._download_tasks[download_id] task.cancel() - pause_event = self._pause_events.get(download_id) - if pause_event is not None: - pause_event.set() + pause_control = self._pause_events.get(download_id) + if pause_control is 
not None: +            pause_control.resume()       # Update status in active downloads         if download_id in self._active_downloads: @@ -835,16 +841,14 @@ class DownloadManager:          if download_id not in self._download_tasks:             return {'success': False, 'error': 'Download task not found'}  -        pause_event = self._pause_events.get(download_id) -        if pause_event is None: -            pause_event = asyncio.Event() -            pause_event.set() -            self._pause_events[download_id] = pause_event +        pause_control = self._pause_events.get(download_id) +        if pause_control is None: +            return {'success': False, 'error': 'Download task not found'}  -        if not pause_event.is_set(): +        if pause_control.is_paused():             return {'success': False, 'error': 'Download is already paused'}  -        pause_event.clear() +        pause_control.pause()          download_info = self._active_downloads.get(download_id)         if download_info is not None: @@ -856,16 +860,28 @@ class DownloadManager:      async def resume_download(self, download_id: str) -> Dict:         """Resume a previously paused download.""" -        pause_event = self._pause_events.get(download_id) -        if pause_event is None: +        pause_control = self._pause_events.get(download_id) +        if pause_control is None:             return {'success': False, 'error': 'Download task not found'}  -        if pause_event.is_set(): +        if pause_control.is_set():             return {'success': False, 'error': 'Download is not paused'}  -        pause_event.set() -        download_info = self._active_downloads.get(download_id) +        force_reconnect = False +        if pause_control is not None: +            elapsed = pause_control.time_since_last_progress() +            threshold = max(30.0, pause_control.stall_timeout / 2.0) +            if elapsed is not None and elapsed >= threshold: +                force_reconnect = True +                logger.info( +                    "Forcing reconnect for download %s after %.1f seconds without progress", +                    download_id, +                    elapsed, +                ) + +        pause_control.resume(force_reconnect=force_reconnect) +        download_info = self._active_downloads.get(download_id) +        if download_info is not None:             if download_info.get('status') == 'paused':                 download_info['status'] = 'downloading' diff --git a/py/services/downloader.py 
b/py/services/downloader.py index 0f0b7045..8a7a80fc 100644 --- a/py/services/downloader.py +++ b/py/services/downloader.py @@ -36,6 +36,73 @@ class DownloadProgress: timestamp: float +class DownloadStreamControl: + """Synchronize pause/resume requests and reconnect hints for a download.""" + + def __init__(self, *, stall_timeout: Optional[float] = None) -> None: + self._event = asyncio.Event() + self._event.set() + self._reconnect_requested = False + self.last_progress_timestamp: Optional[float] = None + self.stall_timeout: float = float(stall_timeout) if stall_timeout is not None else 120.0 + + def is_set(self) -> bool: + return self._event.is_set() + + def is_paused(self) -> bool: + return not self._event.is_set() + + def set(self) -> None: + self._event.set() + + def clear(self) -> None: + self._event.clear() + + async def wait(self) -> None: + await self._event.wait() + + def pause(self) -> None: + self.clear() + + def resume(self, *, force_reconnect: bool = False) -> None: + if force_reconnect: + self._reconnect_requested = True + self.set() + + def request_reconnect(self) -> None: + self._reconnect_requested = True + self.set() + + def has_reconnect_request(self) -> bool: + return self._reconnect_requested + + def consume_reconnect_request(self) -> bool: + reconnect = self._reconnect_requested + self._reconnect_requested = False + return reconnect + + def mark_progress(self, timestamp: Optional[float] = None) -> None: + self.last_progress_timestamp = timestamp or datetime.now().timestamp() + self._reconnect_requested = False + + def time_since_last_progress(self, *, now: Optional[float] = None) -> Optional[float]: + if self.last_progress_timestamp is None: + return None + reference = now if now is not None else datetime.now().timestamp() + return max(0.0, reference - self.last_progress_timestamp) + + def update_stall_timeout(self, stall_timeout: float) -> None: + self.stall_timeout = float(stall_timeout) + + +class DownloadRestartRequested(Exception): + 
"""Raised when a caller explicitly requests a fresh HTTP stream.""" + + +class DownloadStalledError(Exception): + """Raised when download progress stalls beyond the configured timeout.""" + + class Downloader: """Unified downloader for all HTTP/HTTPS downloads in the application.""" @@ -67,6 +134,7 @@ class Downloader: self.max_retries = 5 self.base_delay = 2.0 # Base delay for exponential backoff self.session_timeout = 300 # 5 minutes + self.stall_timeout = self._resolve_stall_timeout() # Default headers self.default_headers = { @@ -82,14 +150,38 @@ class Downloader: if self._session is None or self._should_refresh_session(): await self._create_session() return self._session - + @property def proxy_url(self) -> Optional[str]: """Get the current proxy URL (initialize if needed)""" if not hasattr(self, '_proxy_url'): self._proxy_url = None return self._proxy_url - + + def _resolve_stall_timeout(self) -> float: + """Determine the stall timeout from settings or environment.""" + default_timeout = 120.0 + settings_timeout = None + + try: + settings_manager = get_settings_manager() + settings_timeout = settings_manager.get('download_stall_timeout_seconds') + except Exception as exc: # pragma: no cover - defensive guard + logger.debug("Failed to read stall timeout from settings: %s", exc) + + raw_value = ( + settings_timeout + if settings_timeout not in (None, "") + else os.environ.get('COMFYUI_DOWNLOAD_STALL_TIMEOUT') + ) + + try: + timeout_value = float(raw_value) + except (TypeError, ValueError): + timeout_value = default_timeout + + return max(30.0, timeout_value) + def _should_refresh_session(self) -> bool: """Check if session should be refreshed""" if self._session is None: @@ -181,7 +273,7 @@ class Downloader: use_auth: bool = False, custom_headers: Optional[Dict[str, str]] = None, allow_resume: bool = True, - pause_event: Optional[asyncio.Event] = None, + pause_event: Optional[DownloadStreamControl] = None, ) -> Tuple[bool, str]: """ Download a file with 
resumable downloads and retry mechanism @@ -193,7 +285,7 @@ class Downloader: use_auth: Whether to include authentication headers (e.g., CivitAI API key) custom_headers: Additional headers to include in request allow_resume: Whether to support resumable downloads - pause_event: Optional event that, when cleared, will pause streaming until set again + pause_event: Optional stream control used to pause/resume and request reconnects Returns: Tuple[bool, str]: (success, save_path or error message) @@ -307,51 +399,88 @@ class Downloader: last_progress_report_time = datetime.now() progress_samples: deque[tuple[datetime, int]] = deque() progress_samples.append((last_progress_report_time, current_size)) - + # Ensure directory exists os.makedirs(os.path.dirname(save_path), exist_ok=True) - + # Stream download to file with progress updates loop = asyncio.get_running_loop() mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb' + control = pause_event + + if control is not None: + control.update_stall_timeout(self.stall_timeout) + with open(part_path, mode) as f: - async for chunk in response.content.iter_chunked(self.chunk_size): - if pause_event is not None and not pause_event.is_set(): - await pause_event.wait() - if chunk: - # Run blocking file write in executor - await loop.run_in_executor(None, f.write, chunk) - current_size += len(chunk) + while True: + active_stall_timeout = control.stall_timeout if control else self.stall_timeout - # Limit progress update frequency to reduce overhead - now = datetime.now() - time_diff = (now - last_progress_report_time).total_seconds() + if control is not None: + if control.is_paused(): + await control.wait() + resume_time = datetime.now() + last_progress_report_time = resume_time + if control.consume_reconnect_request(): + raise DownloadRestartRequested( + "Reconnect requested after resume" + ) + elif control.consume_reconnect_request(): + raise DownloadRestartRequested("Reconnect requested") - if progress_callback and 
time_diff >= 1.0: - progress_samples.append((now, current_size)) - cutoff = now - timedelta(seconds=5) - while progress_samples and progress_samples[0][0] < cutoff: - progress_samples.popleft() + try: + chunk = await asyncio.wait_for( + response.content.read(self.chunk_size), + timeout=active_stall_timeout, + ) + except asyncio.TimeoutError as exc: + logger.warning( + "Download stalled for %.1f seconds without progress from %s", + active_stall_timeout, + url, + ) + raise DownloadStalledError( + f"No data received for {active_stall_timeout:.1f} seconds" + ) from exc - percent = (current_size / total_size) * 100 if total_size else 0.0 - bytes_per_second = 0.0 - if len(progress_samples) >= 2: - first_time, first_bytes = progress_samples[0] - last_time, last_bytes = progress_samples[-1] - elapsed = (last_time - first_time).total_seconds() - if elapsed > 0: - bytes_per_second = (last_bytes - first_bytes) / elapsed + if not chunk: + break - progress_snapshot = DownloadProgress( - percent_complete=percent, - bytes_downloaded=current_size, - total_bytes=total_size or None, - bytes_per_second=bytes_per_second, - timestamp=now.timestamp(), - ) + # Run blocking file write in executor + await loop.run_in_executor(None, f.write, chunk) + current_size += len(chunk) - await self._dispatch_progress_callback(progress_callback, progress_snapshot) - last_progress_report_time = now + now = datetime.now() + if control is not None: + control.mark_progress(timestamp=now.timestamp()) + + # Limit progress update frequency to reduce overhead + time_diff = (now - last_progress_report_time).total_seconds() + + if progress_callback and time_diff >= 1.0: + progress_samples.append((now, current_size)) + cutoff = now - timedelta(seconds=5) + while progress_samples and progress_samples[0][0] < cutoff: + progress_samples.popleft() + + percent = (current_size / total_size) * 100 if total_size else 0.0 + bytes_per_second = 0.0 + if len(progress_samples) >= 2: + first_time, first_bytes = 
progress_samples[0] + last_time, last_bytes = progress_samples[-1] + elapsed = (last_time - first_time).total_seconds() + if elapsed > 0: + bytes_per_second = (last_bytes - first_bytes) / elapsed + + progress_snapshot = DownloadProgress( + percent_complete=percent, + bytes_downloaded=current_size, + total_bytes=total_size or None, + bytes_per_second=bytes_per_second, + timestamp=now.timestamp(), + ) + + await self._dispatch_progress_callback(progress_callback, progress_snapshot) + last_progress_report_time = now # Download completed successfully # Verify file size integrity before finalizing @@ -447,11 +576,17 @@ class Downloader: return True, save_path - except (aiohttp.ClientError, aiohttp.ClientPayloadError, - aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e: + except ( + aiohttp.ClientError, + aiohttp.ClientPayloadError, + aiohttp.ServerDisconnectedError, + asyncio.TimeoutError, + DownloadStalledError, + DownloadRestartRequested, + ) as e: retry_count += 1 logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}") - + if retry_count <= self.max_retries: # Calculate delay with exponential backoff delay = self.base_delay * (2 ** (retry_count - 1)) diff --git a/tests/services/test_download_manager.py b/tests/services/test_download_manager.py index 488b9a61..7ecd163e 100644 --- a/tests/services/test_download_manager.py +++ b/tests/services/test_download_manager.py @@ -1,5 +1,6 @@ import asyncio import os +from datetime import datetime from pathlib import Path from typing import Optional from types import SimpleNamespace @@ -8,6 +9,7 @@ from unittest.mock import AsyncMock import pytest from py.services.download_manager import DownloadManager +from py.services.downloader import DownloadStreamControl from py.services import download_manager from py.services.service_registry import ServiceRegistry from py.services.settings_manager import SettingsManager, get_settings_manager @@ -528,9 +530,8 @@ async def 
test_pause_download_updates_state(): download_id = "dl" manager._download_tasks[download_id] = object() - pause_event = asyncio.Event() - pause_event.set() - manager._pause_events[download_id] = pause_event + pause_control = DownloadStreamControl() + manager._pause_events[download_id] = pause_control manager._active_downloads[download_id] = { "status": "downloading", "bytes_per_second": 42.0, @@ -557,8 +558,10 @@ async def test_resume_download_sets_event_and_status(): manager = DownloadManager() download_id = "dl" - pause_event = asyncio.Event() - manager._pause_events[download_id] = pause_event + pause_control = DownloadStreamControl() + pause_control.pause() + pause_control.mark_progress() + manager._pause_events[download_id] = pause_control manager._active_downloads[download_id] = { "status": "paused", "bytes_per_second": 0.0, @@ -571,13 +574,32 @@ async def test_resume_download_sets_event_and_status(): assert manager._active_downloads[download_id]["status"] == "downloading" +async def test_resume_download_requests_reconnect_for_stalled_stream(): + manager = DownloadManager() + + download_id = "dl" + pause_control = DownloadStreamControl(stall_timeout=40) + pause_control.pause() + pause_control.last_progress_timestamp = (datetime.now().timestamp() - 120) + manager._pause_events[download_id] = pause_control + manager._active_downloads[download_id] = { + "status": "paused", + "bytes_per_second": 0.0, + } + + result = await manager.resume_download(download_id) + + assert result == {"success": True, "message": "Download resumed successfully"} + assert pause_control.is_set() is True + assert pause_control.has_reconnect_request() is True + + async def test_resume_download_rejects_when_not_paused(): manager = DownloadManager() download_id = "dl" - pause_event = asyncio.Event() - pause_event.set() - manager._pause_events[download_id] = pause_event + pause_control = DownloadStreamControl() + manager._pause_events[download_id] = pause_control result = await 
manager.resume_download(download_id) diff --git a/tests/services/test_downloader.py b/tests/services/test_downloader.py index 61469957..156ab276 100644 --- a/tests/services/test_downloader.py +++ b/tests/services/test_downloader.py @@ -1,6 +1,7 @@ import asyncio from datetime import datetime from pathlib import Path +from typing import Sequence import pytest @@ -8,13 +9,24 @@ from py.services.downloader import Downloader class FakeStream: - def __init__(self, chunks): + def __init__(self, chunks: Sequence[Sequence] | Sequence[bytes]): self._chunks = list(chunks) - async def iter_chunked(self, _chunk_size): - for chunk in self._chunks: + async def read(self, _chunk_size: int) -> bytes: + if not self._chunks: await asyncio.sleep(0) - yield chunk + return b"" + + item = self._chunks.pop(0) + delay = 0.0 + payload = item + + if isinstance(item, tuple): + payload = item[0] + delay = item[1] + + await asyncio.sleep(delay) + return payload class FakeResponse: @@ -53,6 +65,12 @@ def _build_downloader(responses, *, max_retries=0): downloader._session = fake_session downloader._session_created_at = datetime.now() downloader._proxy_url = None + async def _noop_create_session(): + downloader._session = fake_session + downloader._session_created_at = datetime.now() + downloader._proxy_url = None + + downloader._create_session = _noop_create_session # type: ignore[assignment] return downloader @@ -123,3 +141,34 @@ async def test_download_file_succeeds_when_sizes_match(tmp_path): assert success is True assert Path(result_path).read_bytes() == payload assert not Path(str(target_path) + ".part").exists() + + +@pytest.mark.asyncio +async def test_download_file_recovers_from_stall(tmp_path): + target_path = tmp_path / "model" / "file.bin" + target_path.parent.mkdir() + + payload = b"abcdef" + + responses = [ + lambda: FakeResponse( + status=200, + headers={"content-length": str(len(payload))}, + chunks=[(b"abc", 0.0), (b"def", 0.1)], + ), + lambda: FakeResponse( + status=206, + 
headers={"content-length": "3", "Content-Range": "bytes 3-5/6"}, + chunks=[(b"def", 0.0)], + ), + ] + + downloader = _build_downloader(responses, max_retries=1) + downloader.stall_timeout = 0.05 + + success, result_path = await downloader.download_file("https://example.com/file", str(target_path)) + + assert success is True + assert Path(result_path).read_bytes() == payload + assert downloader._session._get_calls == 2 + assert not Path(str(target_path) + ".part").exists()