diff --git a/py/services/aria2_downloader.py b/py/services/aria2_downloader.py index 3b4cddd9..617cdb9d 100644 --- a/py/services/aria2_downloader.py +++ b/py/services/aria2_downloader.py @@ -84,6 +84,7 @@ class Aria2Downloader: self._transfers: Dict[str, Aria2Transfer] = {} self._poll_interval = 0.5 self._state_store = Aria2TransferStateStore() + self._stderr_reader_task: Optional[asyncio.Task] = None @property def is_running(self) -> bool: @@ -115,7 +116,7 @@ class Aria2Downloader: try: while True: - status = await self.get_status(download_id) + status = await self._get_status_with_retry(download_id) if status is None: return False, "aria2 download not found" @@ -136,6 +137,35 @@ class Aria2Downloader: finally: self._transfers.pop(download_id, None) + async def _get_status_with_retry( + self, download_id: str, *, max_retries: int = 3, retry_delay: float = 1.0 + ) -> Optional[Dict[str, Any]]: + """Call get_status with retry for transient RPC failures. + + Only retries on :exc:`Aria2Error` (RPC-level failure). Returns + ``None`` immediately when the download_id is not tracked (a missing + transfer is not a transient condition, so retrying is pointless). + + A single failed RPC call should not immediately fail the download, + because aria2 may be temporarily busy (e.g. finalizing multiple + concurrent downloads) and a retry will often succeed. + """ + last_exc: Optional[Exception] = None + for attempt in range(max_retries): + try: + return await self.get_status(download_id) + except Aria2Error as exc: + last_exc = exc + if attempt < max_retries - 1: + logger.warning( + "aria2 get_status transient failure (attempt %d/%d) for %s: %s", + attempt + 1, max_retries, download_id, exc, + ) + await asyncio.sleep(retry_delay) + raise Aria2Error( + f"Failed to query aria2 download status after {max_retries} attempts: {last_exc}" + ) from last_exc + async def _schedule_download( self, url: str, @@ -312,6 +342,16 @@ class Aria2Downloader: async def close(self) -> None: """Shut down the RPC process and session.""" + # Cancel the background stderr reader first so it stops reading + # from the pipe before the subprocess is terminated. + if self._stderr_reader_task is not None: + self._stderr_reader_task.cancel() + try: + await asyncio.wait_for(self._stderr_reader_task, timeout=2.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + self._stderr_reader_task = None + if self._rpc_session is not None: await self._rpc_session.close() self._rpc_session = None @@ -331,6 +371,23 @@ class Aria2Downloader: process.kill() await process.wait() + async def _drain_stderr(self) -> None: + """Continuously drain aria2's stderr pipe so it never blocks. + + When the 64 KB pipe buffer fills up, aria2's ``write()`` to stderr + blocks, which freezes the entire ``aria2c`` process — including its + RPC handler. This background task reads lines from stderr as they + arrive and forwards them to Python's logger. + """ + try: + assert self._process is not None and self._process.stderr is not None + async for line in self._process.stderr: + text = line.decode("utf-8", errors="replace").rstrip() + if text: + logger.debug("aria2 stderr: %s", text) + except Exception: + pass + async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None: try: result = callback(snapshot, snapshot) @@ -463,6 +520,14 @@ class Aria2Downloader: stderr=asyncio.subprocess.PIPE, ) + # Drain aria2's stderr in a background task so the pipe buffer + # never fills up. If the pipe blocks, aria2 itself freezes and + # cannot respond to RPC — this was the root cause of the + # "Failed to query aria2 download status" timeout bug. + self._stderr_reader_task = asyncio.create_task( + self._drain_stderr() + ) + await self._wait_until_ready() def _resolve_executable(self) -> str: diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 3e395610..adcde582 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -2029,7 +2029,21 @@ class DownloadManager: break last_error = result - if os.path.exists(save_path): + # For aria2: if the .aria2 control file is missing, aria2 considers + # the download complete. A transient RPC failure may have made us + # think the download failed even though the file is fully on disk. + # Keep the file so a retry can find it already complete. + if ( + transfer_backend == "aria2" + and os.path.exists(save_path) + and not os.path.exists(f"{save_path}.aria2") + ): + logger.warning( + "aria2 download reported failure but .aria2 file is absent " + "for %s — the file is likely complete. Preserving it for retry.", + save_path, + ) + elif os.path.exists(save_path): try: os.remove(save_path) except Exception as e: diff --git a/tests/services/test_aria2_downloader.py b/tests/services/test_aria2_downloader.py index 268aa91e..6c300fb8 100644 --- a/tests/services/test_aria2_downloader.py +++ b/tests/services/test_aria2_downloader.py @@ -352,3 +352,74 @@ async def test_resolve_authenticated_redirect_url_returns_location(monkeypatch): ) assert result == "https://signed.example.com/file.safetensors" + + +@pytest.mark.asyncio +async def test_get_status_with_retry_passes_through_success(monkeypatch): + """A successful first call returns immediately, no retries.""" + downloader = Aria2Downloader() + call_count = 0 + + async def fake_get_status(_id): + nonlocal call_count + call_count += 1 + return {"status": "active", "completedLength": "50", "totalLength": "100"} + + monkeypatch.setattr(downloader, "get_status", fake_get_status) + + result = await downloader._get_status_with_retry("dummy") + assert result is not None + assert result["status"] == "active" + assert call_count == 1 + + +@pytest.mark.asyncio +async def test_get_status_with_retry_succeeds_after_transient_failure(monkeypatch): + """A transient Aria2Error on the first call is retried and succeeds.""" + downloader = Aria2Downloader() + call_count = 0 + + async def fake_get_status(_id): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise Aria2Error("timeout") + return {"status": "complete", "completedLength": "100", "totalLength": "100"} + + monkeypatch.setattr(downloader, "get_status", fake_get_status) + monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock()) + + result = await downloader._get_status_with_retry("dummy") + assert result is not None + assert result["status"] == "complete" + assert call_count == 2 + + +@pytest.mark.asyncio +async def test_get_status_with_retry_raises_after_all_retries_exhausted(monkeypatch): + """All retry attempts fail → Aria2Error with a descriptive message.""" + downloader = Aria2Downloader() + + async def fake_get_status(_id): + raise Aria2Error("connection reset") + + monkeypatch.setattr(downloader, "get_status", fake_get_status) + monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock()) + + with pytest.raises(Aria2Error) as exc_info: + await downloader._get_status_with_retry("dummy") + + msg = str(exc_info.value) + assert "after 3 attempts" in msg + assert "connection reset" in msg + + +@pytest.mark.asyncio +async def test_get_status_with_retry_returns_none_when_not_tracked(monkeypatch): + """No transfer in _transfers → get_status returns None → no retry needed.""" + downloader = Aria2Downloader() + + # get_status returns None when the download_id has no transfer; + # _get_status_with_retry should propagate that without raising. + result = await downloader._get_status_with_retry("nonexistent") + assert result is None