From dde2b2a9605a05fa01ab66775380db9ea80e3765 Mon Sep 17 00:00:00 2001 From: pixelpaws Date: Thu, 23 Oct 2025 11:55:39 +0800 Subject: [PATCH] fix(downloader): enforce file size integrity checks --- py/services/downloader.py | 64 +++++++++++++-- tests/services/test_downloader.py | 125 ++++++++++++++++++++++++++++++ 2 files changed, 182 insertions(+), 7 deletions(-) create mode 100644 tests/services/test_downloader.py diff --git a/py/services/downloader.py b/py/services/downloader.py index aadfdc89..0f0b7045 100644 --- a/py/services/downloader.py +++ b/py/services/downloader.py @@ -354,12 +354,60 @@ class Downloader: last_progress_report_time = now # Download completed successfully - # Verify file size if total_size was provided - final_size = os.path.getsize(part_path) - if total_size > 0 and final_size != total_size: - logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}") - # Don't treat this as fatal error, continue anyway - + # Verify file size integrity before finalizing + final_size = os.path.getsize(part_path) if os.path.exists(part_path) else 0 + expected_size = total_size if total_size > 0 else None + + integrity_error: Optional[str] = None + if final_size <= 0: + integrity_error = "Downloaded file is empty" + elif expected_size is not None and final_size != expected_size: + integrity_error = ( + f"File size mismatch. Expected: {expected_size}, Got: {final_size}" + ) + + if integrity_error is not None: + logger.error( + "Download integrity check failed for %s: %s", + save_path, + integrity_error, + ) + + # Remove the corrupted payload so future attempts start fresh + if os.path.exists(part_path): + try: + os.remove(part_path) + except OSError as remove_error: + logger.warning( + "Failed to delete corrupted download %s: %s", + part_path, + remove_error, + ) + if part_path != save_path and os.path.exists(save_path): + try: + os.remove(save_path) + except OSError as remove_error: + logger.warning( + "Failed to delete target file %s after integrity error: %s", + save_path, + remove_error, + ) + + retry_count += 1 + if retry_count <= self.max_retries: + delay = self.base_delay * (2 ** (retry_count - 1)) + logger.info( + "Retrying download in %s seconds due to integrity check failure", + delay, + ) + await asyncio.sleep(delay) + resume_offset = 0 + total_size = 0 + await self._create_session() + continue + + return False, integrity_error + # Atomically rename .part to final file (only if using resume) if allow_resume and part_path != save_path: max_rename_attempts = 5 @@ -382,7 +430,9 @@ class Downloader: else: logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}") return False, f"Failed to finalize download: {str(e)}" - + + final_size = os.path.getsize(save_path) + # Ensure 100% progress is reported if progress_callback: final_snapshot = DownloadProgress( diff --git a/tests/services/test_downloader.py b/tests/services/test_downloader.py new file mode 100644 index 00000000..61469957 --- /dev/null +++ b/tests/services/test_downloader.py @@ -0,0 +1,125 @@ +import asyncio +from datetime import datetime +from pathlib import Path + +import pytest + +from py.services.downloader import Downloader + + +class FakeStream: + def __init__(self, chunks): + self._chunks = list(chunks) + + async def iter_chunked(self, _chunk_size): + for chunk in self._chunks: + await asyncio.sleep(0) + yield chunk + + +class FakeResponse: + def __init__(self, status, headers, chunks): + self.status = status + self.headers = headers + self.content = FakeStream(chunks) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + +class FakeSession: + def __init__(self, responses): + self._responses = list(responses) + self._get_calls = 0 + + def get(self, url, headers=None, allow_redirects=True, proxy=None): # noqa: D401 - signature mirrors aiohttp + del url, headers, allow_redirects, proxy + response_factory = self._responses[self._get_calls] + self._get_calls += 1 + return response_factory() + + async def close(self): + return None + + +def _build_downloader(responses, *, max_retries=0): + downloader = Downloader() + downloader.max_retries = max_retries + downloader.base_delay = 0 + fake_session = FakeSession(responses) + downloader._session = fake_session + downloader._session_created_at = datetime.now() + downloader._proxy_url = None + return downloader + + +@pytest.mark.asyncio +async def test_download_file_fails_when_size_mismatch(tmp_path): + target_path = tmp_path / "model" / "file.bin" + target_path.parent.mkdir() + + responses = [ + lambda: FakeResponse( + status=200, + headers={"content-length": "10"}, + chunks=[b"abc"], + ) + ] + + downloader = _build_downloader(responses) + + success, message = await downloader.download_file("https://example.com/file", str(target_path)) + + assert success is False + assert "mismatch" in message.lower() + assert not target_path.exists() + assert not Path(str(target_path) + ".part").exists() + + +@pytest.mark.asyncio +async def test_download_file_fails_when_zero_bytes(tmp_path): + target_path = tmp_path / "model" / "file.bin" + target_path.parent.mkdir() + + responses = [ + lambda: FakeResponse( + status=200, + headers={"content-length": "0"}, + chunks=[], + ) + ] + + downloader = _build_downloader(responses) + + success, message = await downloader.download_file("https://example.com/file", str(target_path)) + + assert success is False + assert "empty" in message.lower() + assert not target_path.exists() + assert not Path(str(target_path) + ".part").exists() + + +@pytest.mark.asyncio +async def test_download_file_succeeds_when_sizes_match(tmp_path): + target_path = tmp_path / "model" / "file.bin" + target_path.parent.mkdir() + + payload = b"abcdef" + responses = [ + lambda: FakeResponse( + status=200, + headers={"content-length": str(len(payload))}, + chunks=[payload], + ) + ] + + downloader = _build_downloader(responses) + + success, result_path = await downloader.download_file("https://example.com/file", str(target_path)) + + assert success is True + assert Path(result_path).read_bytes() == payload + assert not Path(str(target_path) + ".part").exists()