fix(aria2): drain stderr pipe to prevent aria2 freeze, retry RPC status on transient failure

Root cause: aria2c subprocess stderr pipe (64 KB buffer) was never
drained. When enough error/warning output accumulated, aria2's write()
blocked, freezing the entire process including its RPC handler. The
tellStatus call then timed out after 30s with asyncio.TimeoutError().
Because str(TimeoutError()) is empty, the resulting error message was
just 'Failed to query aria2 download status: ' with no detail.

Fixes:
- Drain stderr in a background task so pipe never fills up
- Retry get_status() RPC calls up to 3 times on transient failure
- In the failure path, preserve the downloaded file (e.g. .safetensors)
  when its .aria2 control file is absent (the download was likely
  complete on disk)
This commit is contained in:
Will Miao
2026-06-26 08:25:05 +08:00
parent 0ac10dfd42
commit 3a2941d751
3 changed files with 152 additions and 2 deletions

View File

@@ -84,6 +84,7 @@ class Aria2Downloader:
self._transfers: Dict[str, Aria2Transfer] = {} self._transfers: Dict[str, Aria2Transfer] = {}
self._poll_interval = 0.5 self._poll_interval = 0.5
self._state_store = Aria2TransferStateStore() self._state_store = Aria2TransferStateStore()
self._stderr_reader_task: Optional[asyncio.Task] = None
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
@@ -115,7 +116,7 @@ class Aria2Downloader:
try: try:
while True: while True:
status = await self.get_status(download_id) status = await self._get_status_with_retry(download_id)
if status is None: if status is None:
return False, "aria2 download not found" return False, "aria2 download not found"
@@ -136,6 +137,35 @@ class Aria2Downloader:
finally: finally:
self._transfers.pop(download_id, None) self._transfers.pop(download_id, None)
async def _get_status_with_retry(
self, download_id: str, *, max_retries: int = 3, retry_delay: float = 1.0
) -> Optional[Dict[str, Any]]:
"""Call get_status with retry for transient RPC failures.
Only retries on :exc:`Aria2Error` (RPC-level failure). Returns
``None`` immediately when the download_id is not tracked (a missing
transfer is not a transient condition, so retrying is pointless).
A single failed RPC call should not immediately fail the download,
because aria2 may be temporarily busy (e.g. finalizing multiple
concurrent downloads) and a retry will often succeed.
"""
last_exc: Optional[Exception] = None
for attempt in range(max_retries):
try:
return await self.get_status(download_id)
except Aria2Error as exc:
last_exc = exc
if attempt < max_retries - 1:
logger.warning(
"aria2 get_status transient failure (attempt %d/%d) for %s: %s",
attempt + 1, max_retries, download_id, exc,
)
await asyncio.sleep(retry_delay)
raise Aria2Error(
f"Failed to query aria2 download status after {max_retries} attempts: {last_exc}"
) from last_exc
async def _schedule_download( async def _schedule_download(
self, self,
url: str, url: str,
@@ -312,6 +342,16 @@ class Aria2Downloader:
async def close(self) -> None: async def close(self) -> None:
"""Shut down the RPC process and session.""" """Shut down the RPC process and session."""
# Cancel the background stderr reader first so it stops reading
# from the pipe before the subprocess is terminated.
if self._stderr_reader_task is not None:
self._stderr_reader_task.cancel()
try:
await asyncio.wait_for(self._stderr_reader_task, timeout=2.0)
except (asyncio.CancelledError, asyncio.TimeoutError):
pass
self._stderr_reader_task = None
if self._rpc_session is not None: if self._rpc_session is not None:
await self._rpc_session.close() await self._rpc_session.close()
self._rpc_session = None self._rpc_session = None
@@ -331,6 +371,23 @@ class Aria2Downloader:
process.kill() process.kill()
await process.wait() await process.wait()
async def _drain_stderr(self) -> None:
"""Continuously drain aria2's stderr pipe so it never blocks.
When the 64 KB pipe buffer fills up, aria2's ``write()`` to stderr
blocks, which freezes the entire ``aria2c`` process — including its
RPC handler. This background task reads lines from stderr as they
arrive and forwards them to Python's logger.
"""
try:
assert self._process is not None and self._process.stderr is not None
async for line in self._process.stderr:
text = line.decode("utf-8", errors="replace").rstrip()
if text:
logger.debug("aria2 stderr: %s", text)
except Exception:
pass
async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None: async def _dispatch_progress(self, callback, snapshot: DownloadProgress) -> None:
try: try:
result = callback(snapshot, snapshot) result = callback(snapshot, snapshot)
@@ -463,6 +520,14 @@ class Aria2Downloader:
stderr=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
) )
# Drain aria2's stderr in a background task so the pipe buffer
# never fills up. If the pipe blocks, aria2 itself freezes and
# cannot respond to RPC — this was the root cause of the
# "Failed to query aria2 download status" timeout bug.
self._stderr_reader_task = asyncio.create_task(
self._drain_stderr()
)
await self._wait_until_ready() await self._wait_until_ready()
def _resolve_executable(self) -> str: def _resolve_executable(self) -> str:

View File

@@ -2029,7 +2029,21 @@ class DownloadManager:
break break
last_error = result last_error = result
if os.path.exists(save_path): # For aria2: if the .aria2 control file is missing, aria2 considers
# the download complete. A transient RPC failure may have made us
# think the download failed even though the file is fully on disk.
# Keep the file so a retry can find it already complete.
if (
transfer_backend == "aria2"
and os.path.exists(save_path)
and not os.path.exists(f"{save_path}.aria2")
):
logger.warning(
"aria2 download reported failure but .aria2 file is absent "
"for %s — the file is likely complete. Preserving it for retry.",
save_path,
)
elif os.path.exists(save_path):
try: try:
os.remove(save_path) os.remove(save_path)
except Exception as e: except Exception as e:

View File

@@ -352,3 +352,74 @@ async def test_resolve_authenticated_redirect_url_returns_location(monkeypatch):
) )
assert result == "https://signed.example.com/file.safetensors" assert result == "https://signed.example.com/file.safetensors"
@pytest.mark.asyncio
async def test_get_status_with_retry_passes_through_success(monkeypatch):
    """The first successful status is returned as-is, with no retry."""
    seen_ids = []

    async def stub_get_status(_id):
        # Record every invocation so the call count can be checked below.
        seen_ids.append(_id)
        return {"status": "active", "completedLength": "50", "totalLength": "100"}

    downloader = Aria2Downloader()
    monkeypatch.setattr(downloader, "get_status", stub_get_status)

    status = await downloader._get_status_with_retry("dummy")

    assert status is not None
    assert status["status"] == "active"
    assert len(seen_ids) == 1
@pytest.mark.asyncio
async def test_get_status_with_retry_succeeds_after_transient_failure(monkeypatch):
    """One Aria2Error on the first attempt is retried; the second call wins."""
    attempts = []

    async def flaky_get_status(_id):
        attempts.append(_id)
        # Fail only on the very first attempt, then succeed.
        if len(attempts) == 1:
            raise Aria2Error("timeout")
        return {"status": "complete", "completedLength": "100", "totalLength": "100"}

    downloader = Aria2Downloader()
    monkeypatch.setattr(downloader, "get_status", flaky_get_status)
    # Replace the retry back-off sleep so the test does not actually wait.
    monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())

    status = await downloader._get_status_with_retry("dummy")

    assert status is not None
    assert status["status"] == "complete"
    assert len(attempts) == 2
@pytest.mark.asyncio
async def test_get_status_with_retry_raises_after_all_retries_exhausted(monkeypatch):
    """When every attempt fails, a descriptive Aria2Error is raised."""

    async def always_failing_get_status(_id):
        raise Aria2Error("connection reset")

    downloader = Aria2Downloader()
    monkeypatch.setattr(downloader, "get_status", always_failing_get_status)
    # Replace the retry back-off sleep so the test does not actually wait.
    monkeypatch.setattr("py.services.aria2_downloader.asyncio.sleep", AsyncMock())

    with pytest.raises(Aria2Error) as raised:
        await downloader._get_status_with_retry("dummy")

    # The message must name the attempt count and carry the underlying error.
    message = str(raised.value)
    for fragment in ("after 3 attempts", "connection reset"):
        assert fragment in message
@pytest.mark.asyncio
async def test_get_status_with_retry_returns_none_when_not_tracked(monkeypatch):
    """An untracked download id yields None instead of raising."""
    # get_status is left unpatched here: this test expects it to return
    # None for an id with no transfer, and the retry wrapper to pass that
    # through unchanged.
    downloader = Aria2Downloader()

    outcome = await downloader._get_status_with_retry("nonexistent")

    assert outcome is None