diff --git a/py/services/downloader.py b/py/services/downloader.py index 3b07cb1f..55b6b40e 100644 --- a/py/services/downloader.py +++ b/py/services/downloader.py @@ -138,7 +138,7 @@ class Downloader: self.chunk_size = ( 16 * 1024 * 1024 ) # 16MB chunks to balance I/O reduction and memory usage - self.max_retries = 5 + self.max_retries = self._resolve_max_retries() self.base_delay = 2.0 # Base delay for exponential backoff self.session_timeout = 300 # 5 minutes self.stall_timeout = self._resolve_stall_timeout() @@ -192,6 +192,18 @@ class Downloader: return max(30.0, timeout_value) + def _resolve_max_retries(self) -> int: + """Determine max retry count from environment while preserving defaults.""" + default_retries = 5 + raw_value = os.environ.get("COMFYUI_DOWNLOAD_MAX_RETRIES") + + try: + retries = int(raw_value) + except (TypeError, ValueError): + retries = default_retries + + return max(0, retries) + def _should_refresh_session(self) -> bool: """Check if session should be refreshed""" if self._session is None: @@ -334,6 +346,7 @@ class Downloader: logger.info(f"Resuming download from offset {resume_offset} bytes") total_size = 0 + range_redirect_retry_urls: set[str] = set() while retry_count <= self.max_retries: try: @@ -372,6 +385,23 @@ class Downloader: if response.status == 200: # Full content response if resume_offset > 0: + redirected_url = str(response.url) + if ( + allow_resume + and response.history + and redirected_url + and redirected_url != url + and redirected_url not in range_redirect_retry_urls + ): + range_redirect_retry_urls.add(redirected_url) + logger.info( + "Range request was not honored after redirect; retrying final URL directly: %s", + redirected_url, + ) + url = redirected_url + response.release() + continue + # Server doesn't support ranges, restart from beginning logger.warning( "Server doesn't support range requests, restarting download" @@ -571,37 +601,53 @@ class Downloader: expected_size = total_size if total_size > 0 else None integrity_error: Optional[str] = None + resumable_incomplete = False if final_size <= 0: integrity_error = "Downloaded file is empty" elif expected_size is not None and final_size != expected_size: integrity_error = f"File size mismatch. Expected: {expected_size}, Got: {final_size}" + resumable_incomplete = ( + allow_resume + and part_path != save_path + and final_size > 0 + and final_size < expected_size + ) if integrity_error is not None: - logger.error( + log_fn = logger.warning if resumable_incomplete else logger.error + log_fn( "Download integrity check failed for %s: %s", save_path, integrity_error, ) - # Remove the corrupted payload so future attempts start fresh - if os.path.exists(part_path): - try: - os.remove(part_path) - except OSError as remove_error: - logger.warning( - "Failed to delete corrupted download %s: %s", - part_path, - remove_error, - ) - if part_path != save_path and os.path.exists(save_path): - try: - os.remove(save_path) - except OSError as remove_error: - logger.warning( - "Failed to delete target file %s after integrity error: %s", - save_path, - remove_error, - ) + if resumable_incomplete: + logger.info( + "Preserving incomplete download for resume: %s (%s/%s bytes)", + part_path, + final_size, + expected_size, + ) + else: + # Remove corrupted payloads that cannot be safely resumed. + if os.path.exists(part_path): + try: + os.remove(part_path) + except OSError as remove_error: + logger.warning( + "Failed to delete corrupted download %s: %s", + part_path, + remove_error, + ) + if part_path != save_path and os.path.exists(save_path): + try: + os.remove(save_path) + except OSError as remove_error: + logger.warning( + "Failed to delete target file %s after integrity error: %s", + save_path, + remove_error, + ) retry_count += 1 if retry_count <= self.max_retries: @@ -611,8 +657,16 @@ class Downloader: delay, ) await asyncio.sleep(delay) - resume_offset = 0 - total_size = 0 + if resumable_incomplete and os.path.exists(part_path): + resume_offset = os.path.getsize(part_path) + total_size = expected_size or 0 + logger.info( + "Will resume incomplete download from byte %s", + resume_offset, + ) + else: + resume_offset = 0 + total_size = 0 await self._create_session() continue diff --git a/tests/services/test_downloader.py b/tests/services/test_downloader.py index 156ab276..84bd858b 100644 --- a/tests/services/test_downloader.py +++ b/tests/services/test_downloader.py @@ -30,10 +30,21 @@ class FakeStream: class FakeResponse: - def __init__(self, status, headers, chunks): + def __init__( + self, + status, + headers, + chunks, + *, + url="https://example.com/file", + history=None, + ): self.status = status self.headers = headers self.content = FakeStream(chunks) + self.url = url + self.history = history or [] + self.released = False async def __aenter__(self): return self @@ -41,14 +52,25 @@ class FakeResponse: async def __aexit__(self, exc_type, exc, tb): return False + def release(self): + self.released = True + class FakeSession: def __init__(self, responses): self._responses = list(responses) self._get_calls = 0 + self.requests = [] def get(self, url, headers=None, allow_redirects=True, proxy=None): # noqa: D401 - signature mirrors aiohttp - del url, headers, allow_redirects, proxy + self.requests.append( + { + "url": url, + "headers": headers or {}, + "allow_redirects": allow_redirects, + "proxy": proxy, + } + ) response_factory = self._responses[self._get_calls] self._get_calls += 1 return response_factory() @@ -75,7 +97,7 @@ def _build_downloader(responses, *, max_retries=0): @pytest.mark.asyncio -async def test_download_file_fails_when_size_mismatch(tmp_path): +async def test_download_file_preserves_incomplete_part_when_size_mismatch(tmp_path): target_path = tmp_path / "model" / "file.bin" target_path.parent.mkdir() @@ -94,7 +116,7 @@ async def test_download_file_fails_when_size_mismatch(tmp_path): assert success is False assert "mismatch" in message.lower() assert not target_path.exists() - assert not Path(str(target_path) + ".part").exists() + assert Path(str(target_path) + ".part").read_bytes() == b"abc" @pytest.mark.asyncio @@ -136,7 +158,9 @@ async def test_download_file_succeeds_when_sizes_match(tmp_path): downloader = _build_downloader(responses) - success, result_path = await downloader.download_file("https://example.com/file", str(target_path)) + success, result_path = await downloader.download_file( + "https://example.com/file", str(target_path) + ) assert success is True assert Path(result_path).read_bytes() == payload @@ -166,9 +190,77 @@ async def test_download_file_recovers_from_stall(tmp_path): downloader = _build_downloader(responses, max_retries=1) downloader.stall_timeout = 0.05 - success, result_path = await downloader.download_file("https://example.com/file", str(target_path)) + success, result_path = await downloader.download_file( + "https://example.com/file", str(target_path) + ) assert success is True assert Path(result_path).read_bytes() == payload assert downloader._session._get_calls == 2 assert not Path(str(target_path) + ".part").exists() + + +@pytest.mark.asyncio +async def test_download_file_resumes_after_incomplete_integrity_check(tmp_path): + target_path = tmp_path / "model" / "file.bin" + target_path.parent.mkdir() + + responses = [ + lambda: FakeResponse( + status=200, + headers={"content-length": "6"}, + chunks=[b"abc"], + ), + lambda: FakeResponse( + status=206, + headers={"content-length": "3", "Content-Range": "bytes 3-5/6"}, + chunks=[b"def"], + ), + ] + + downloader = _build_downloader(responses, max_retries=1) + + success, result_path = await downloader.download_file("https://example.com/file", str(target_path)) + + assert success is True + assert Path(result_path).read_bytes() == b"abcdef" + assert downloader._session._get_calls == 2 + assert downloader._session.requests[1]["headers"]["Range"] == "bytes=3-" + assert not Path(str(target_path) + ".part").exists() + + +@pytest.mark.asyncio +async def test_download_file_retries_redirected_url_when_range_not_honored(tmp_path): + target_path = tmp_path / "model" / "file.bin" + target_path.parent.mkdir() + Path(str(target_path) + ".part").write_bytes(b"abc") + + redirected_url = "https://download.example.com/file.bin" + first_response = FakeResponse( + status=200, + headers={"content-length": "6"}, + chunks=[], + url=redirected_url, + history=[object()], + ) + + responses = [ + lambda: first_response, + lambda: FakeResponse( + status=206, + headers={"content-length": "3", "Content-Range": "bytes 3-5/6"}, + chunks=[b"def"], + url=redirected_url, + ), + ] + + downloader = _build_downloader(responses, max_retries=0) + + success, result_path = await downloader.download_file("https://example.com/file", str(target_path)) + + assert success is True + assert Path(result_path).read_bytes() == b"abcdef" + assert first_response.released is True + assert downloader._session.requests[0]["headers"]["Range"] == "bytes=3-" + assert downloader._session.requests[1]["url"] == redirected_url + assert downloader._session.requests[1]["headers"]["Range"] == "bytes=3-"