fix: preserve resumable downloads across retries

This commit is contained in:
Shuangrui CHEN
2026-04-17 03:35:41 +08:00
parent 89fd2b43d6
commit fa049a28c8
2 changed files with 175 additions and 29 deletions

View File

@@ -138,7 +138,7 @@ class Downloader:
self.chunk_size = ( self.chunk_size = (
16 * 1024 * 1024 16 * 1024 * 1024
) # 16MB chunks to balance I/O reduction and memory usage ) # 16MB chunks to balance I/O reduction and memory usage
self.max_retries = 5 self.max_retries = self._resolve_max_retries()
self.base_delay = 2.0 # Base delay for exponential backoff self.base_delay = 2.0 # Base delay for exponential backoff
self.session_timeout = 300 # 5 minutes self.session_timeout = 300 # 5 minutes
self.stall_timeout = self._resolve_stall_timeout() self.stall_timeout = self._resolve_stall_timeout()
@@ -192,6 +192,18 @@ class Downloader:
return max(30.0, timeout_value) return max(30.0, timeout_value)
def _resolve_max_retries(self) -> int:
"""Determine max retry count from environment while preserving defaults."""
default_retries = 5
raw_value = os.environ.get("COMFYUI_DOWNLOAD_MAX_RETRIES")
try:
retries = int(raw_value)
except (TypeError, ValueError):
retries = default_retries
return max(0, retries)
def _should_refresh_session(self) -> bool: def _should_refresh_session(self) -> bool:
"""Check if session should be refreshed""" """Check if session should be refreshed"""
if self._session is None: if self._session is None:
@@ -334,6 +346,7 @@ class Downloader:
logger.info(f"Resuming download from offset {resume_offset} bytes") logger.info(f"Resuming download from offset {resume_offset} bytes")
total_size = 0 total_size = 0
range_redirect_retry_urls: set[str] = set()
while retry_count <= self.max_retries: while retry_count <= self.max_retries:
try: try:
@@ -372,6 +385,23 @@ class Downloader:
if response.status == 200: if response.status == 200:
# Full content response # Full content response
if resume_offset > 0: if resume_offset > 0:
redirected_url = str(response.url)
if (
allow_resume
and response.history
and redirected_url
and redirected_url != url
and redirected_url not in range_redirect_retry_urls
):
range_redirect_retry_urls.add(redirected_url)
logger.info(
"Range request was not honored after redirect; retrying final URL directly: %s",
redirected_url,
)
url = redirected_url
response.release()
continue
# Server doesn't support ranges, restart from beginning # Server doesn't support ranges, restart from beginning
logger.warning( logger.warning(
"Server doesn't support range requests, restarting download" "Server doesn't support range requests, restarting download"
@@ -571,37 +601,53 @@ class Downloader:
expected_size = total_size if total_size > 0 else None expected_size = total_size if total_size > 0 else None
integrity_error: Optional[str] = None integrity_error: Optional[str] = None
resumable_incomplete = False
if final_size <= 0: if final_size <= 0:
integrity_error = "Downloaded file is empty" integrity_error = "Downloaded file is empty"
elif expected_size is not None and final_size != expected_size: elif expected_size is not None and final_size != expected_size:
integrity_error = f"File size mismatch. Expected: {expected_size}, Got: {final_size}" integrity_error = f"File size mismatch. Expected: {expected_size}, Got: {final_size}"
resumable_incomplete = (
allow_resume
and part_path != save_path
and final_size > 0
and final_size < expected_size
)
if integrity_error is not None: if integrity_error is not None:
logger.error( log_fn = logger.warning if resumable_incomplete else logger.error
log_fn(
"Download integrity check failed for %s: %s", "Download integrity check failed for %s: %s",
save_path, save_path,
integrity_error, integrity_error,
) )
# Remove the corrupted payload so future attempts start fresh if resumable_incomplete:
if os.path.exists(part_path): logger.info(
try: "Preserving incomplete download for resume: %s (%s/%s bytes)",
os.remove(part_path) part_path,
except OSError as remove_error: final_size,
logger.warning( expected_size,
"Failed to delete corrupted download %s: %s", )
part_path, else:
remove_error, # Remove corrupted payloads that cannot be safely resumed.
) if os.path.exists(part_path):
if part_path != save_path and os.path.exists(save_path): try:
try: os.remove(part_path)
os.remove(save_path) except OSError as remove_error:
except OSError as remove_error: logger.warning(
logger.warning( "Failed to delete corrupted download %s: %s",
"Failed to delete target file %s after integrity error: %s", part_path,
save_path, remove_error,
remove_error, )
) if part_path != save_path and os.path.exists(save_path):
try:
os.remove(save_path)
except OSError as remove_error:
logger.warning(
"Failed to delete target file %s after integrity error: %s",
save_path,
remove_error,
)
retry_count += 1 retry_count += 1
if retry_count <= self.max_retries: if retry_count <= self.max_retries:
@@ -611,8 +657,16 @@ class Downloader:
delay, delay,
) )
await asyncio.sleep(delay) await asyncio.sleep(delay)
resume_offset = 0 if resumable_incomplete and os.path.exists(part_path):
total_size = 0 resume_offset = os.path.getsize(part_path)
total_size = expected_size or 0
logger.info(
"Will resume incomplete download from byte %s",
resume_offset,
)
else:
resume_offset = 0
total_size = 0
await self._create_session() await self._create_session()
continue continue

View File

@@ -30,10 +30,21 @@ class FakeStream:
class FakeResponse: class FakeResponse:
def __init__(self, status, headers, chunks): def __init__(
self,
status,
headers,
chunks,
*,
url="https://example.com/file",
history=None,
):
self.status = status self.status = status
self.headers = headers self.headers = headers
self.content = FakeStream(chunks) self.content = FakeStream(chunks)
self.url = url
self.history = history or []
self.released = False
async def __aenter__(self): async def __aenter__(self):
return self return self
@@ -41,14 +52,25 @@ class FakeResponse:
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
return False return False
def release(self):
self.released = True
class FakeSession: class FakeSession:
def __init__(self, responses): def __init__(self, responses):
self._responses = list(responses) self._responses = list(responses)
self._get_calls = 0 self._get_calls = 0
self.requests = []
def get(self, url, headers=None, allow_redirects=True, proxy=None): # noqa: D401 - signature mirrors aiohttp def get(self, url, headers=None, allow_redirects=True, proxy=None): # noqa: D401 - signature mirrors aiohttp
del url, headers, allow_redirects, proxy self.requests.append(
{
"url": url,
"headers": headers or {},
"allow_redirects": allow_redirects,
"proxy": proxy,
}
)
response_factory = self._responses[self._get_calls] response_factory = self._responses[self._get_calls]
self._get_calls += 1 self._get_calls += 1
return response_factory() return response_factory()
@@ -75,7 +97,7 @@ def _build_downloader(responses, *, max_retries=0):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_download_file_fails_when_size_mismatch(tmp_path): async def test_download_file_preserves_incomplete_part_when_size_mismatch(tmp_path):
target_path = tmp_path / "model" / "file.bin" target_path = tmp_path / "model" / "file.bin"
target_path.parent.mkdir() target_path.parent.mkdir()
@@ -94,7 +116,7 @@ async def test_download_file_fails_when_size_mismatch(tmp_path):
assert success is False assert success is False
assert "mismatch" in message.lower() assert "mismatch" in message.lower()
assert not target_path.exists() assert not target_path.exists()
assert not Path(str(target_path) + ".part").exists() assert Path(str(target_path) + ".part").read_bytes() == b"abc"
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -136,7 +158,9 @@ async def test_download_file_succeeds_when_sizes_match(tmp_path):
downloader = _build_downloader(responses) downloader = _build_downloader(responses)
success, result_path = await downloader.download_file("https://example.com/file", str(target_path)) success, result_path = await downloader.download_file(
"https://example.com/file", str(target_path)
)
assert success is True assert success is True
assert Path(result_path).read_bytes() == payload assert Path(result_path).read_bytes() == payload
@@ -166,9 +190,77 @@ async def test_download_file_recovers_from_stall(tmp_path):
downloader = _build_downloader(responses, max_retries=1) downloader = _build_downloader(responses, max_retries=1)
downloader.stall_timeout = 0.05 downloader.stall_timeout = 0.05
success, result_path = await downloader.download_file("https://example.com/file", str(target_path)) success, result_path = await downloader.download_file(
"https://example.com/file", str(target_path)
)
assert success is True assert success is True
assert Path(result_path).read_bytes() == payload assert Path(result_path).read_bytes() == payload
assert downloader._session._get_calls == 2 assert downloader._session._get_calls == 2
assert not Path(str(target_path) + ".part").exists() assert not Path(str(target_path) + ".part").exists()
@pytest.mark.asyncio
async def test_download_file_resumes_after_incomplete_integrity_check(tmp_path):
    """A truncated 200 body leaves the .part file intact; the retry issues a
    Range request and completes the payload from the recorded offset."""
    destination = tmp_path / "model" / "file.bin"
    destination.parent.mkdir()

    def _truncated_full():
        # First attempt claims 6 bytes but delivers only 3.
        return FakeResponse(
            status=200,
            headers={"content-length": "6"},
            chunks=[b"abc"],
        )

    def _resumed_tail():
        # Second attempt serves the remaining bytes as a partial response.
        return FakeResponse(
            status=206,
            headers={"content-length": "3", "Content-Range": "bytes 3-5/6"},
            chunks=[b"def"],
        )

    downloader = _build_downloader([_truncated_full, _resumed_tail], max_retries=1)
    success, result_path = await downloader.download_file(
        "https://example.com/file", str(destination)
    )

    assert success is True
    assert Path(result_path).read_bytes() == b"abcdef"
    assert downloader._session._get_calls == 2
    assert downloader._session.requests[1]["headers"]["Range"] == "bytes=3-"
    assert not Path(str(destination) + ".part").exists()
@pytest.mark.asyncio
async def test_download_file_retries_redirected_url_when_range_not_honored(tmp_path):
    """A 200 after a redirect releases the response and re-issues the Range
    request directly against the final URL instead of restarting from zero."""
    destination = tmp_path / "model" / "file.bin"
    destination.parent.mkdir()
    Path(str(destination) + ".part").write_bytes(b"abc")

    final_url = "https://download.example.com/file.bin"
    # Redirected response that ignored the Range header (full-content 200).
    ignored_range_response = FakeResponse(
        status=200,
        headers={"content-length": "6"},
        chunks=[],
        url=final_url,
        history=[object()],
    )

    def _partial_tail():
        return FakeResponse(
            status=206,
            headers={"content-length": "3", "Content-Range": "bytes 3-5/6"},
            chunks=[b"def"],
            url=final_url,
        )

    downloader = _build_downloader(
        [lambda: ignored_range_response, _partial_tail], max_retries=0
    )
    success, result_path = await downloader.download_file(
        "https://example.com/file", str(destination)
    )

    assert success is True
    assert Path(result_path).read_bytes() == b"abcdef"
    assert ignored_range_response.released is True
    issued = downloader._session.requests
    assert issued[0]["headers"]["Range"] == "bytes=3-"
    assert issued[1]["url"] == final_url
    assert issued[1]["headers"]["Range"] == "bytes=3-"