fix(download): recover stalled transfers automatically

This commit is contained in:
pixelpaws
2025-10-23 17:25:38 +08:00
parent 2eae8a7729
commit faa26651dd
4 changed files with 303 additions and 81 deletions

View File

@@ -36,6 +36,73 @@ class DownloadProgress:
timestamp: float
class DownloadStreamControl:
    """Synchronize pause/resume requests and reconnect hints for a download.

    The control wraps an ``asyncio.Event`` that is *set* while the download
    may stream and *cleared* while it is paused. Alongside that it keeps a
    one-shot "reconnect requested" flag, the timestamp of the most recent
    progress, and the stall timeout (in seconds) to apply to the stream.
    """

    def __init__(self, *, stall_timeout: Optional[float] = None) -> None:
        """Create a control in the running (not paused) state.

        Args:
            stall_timeout: Seconds without data before the stream counts as
                stalled. Defaults to 120 seconds when omitted.
        """
        self._event = asyncio.Event()
        # A fresh control starts "set" so streaming is allowed immediately.
        self._event.set()
        self._reconnect_requested = False
        # POSIX timestamp of the last recorded progress; None until the first one.
        self.last_progress_timestamp: Optional[float] = None
        self.stall_timeout: float = float(stall_timeout) if stall_timeout is not None else 120.0

    def is_set(self) -> bool:
        """Return True while streaming is allowed (mirrors ``asyncio.Event``)."""
        return self._event.is_set()

    def is_paused(self) -> bool:
        """Return True while the download is paused."""
        return not self._event.is_set()

    def set(self) -> None:
        """Allow streaming to continue (mirrors ``asyncio.Event.set``)."""
        self._event.set()

    def clear(self) -> None:
        """Block streaming until ``set`` is called (mirrors ``asyncio.Event.clear``)."""
        self._event.clear()

    async def wait(self) -> None:
        """Suspend the caller until the control is set again."""
        await self._event.wait()

    def pause(self) -> None:
        """Pause the download."""
        self.clear()

    def resume(self, *, force_reconnect: bool = False) -> None:
        """Resume the download.

        Args:
            force_reconnect: When True, also record a reconnect request.
        """
        if force_reconnect:
            self._reconnect_requested = True
        self.set()

    def request_reconnect(self) -> None:
        """Record a reconnect request and wake any waiter on the control."""
        self._reconnect_requested = True
        self.set()

    def has_reconnect_request(self) -> bool:
        """Return True if a reconnect request is pending (does not clear it)."""
        return self._reconnect_requested

    def consume_reconnect_request(self) -> bool:
        """Return the pending reconnect flag and reset it to False."""
        reconnect = self._reconnect_requested
        self._reconnect_requested = False
        return reconnect

    def mark_progress(self, timestamp: Optional[float] = None) -> None:
        """Record that progress happened and drop any pending reconnect request.

        Args:
            timestamp: POSIX timestamp of the progress. The current time is
                used only when this is None.
        """
        # Compare against None explicitly: a truthiness test would discard a
        # legitimate timestamp of 0.0 and silently substitute the wall clock.
        if timestamp is None:
            timestamp = datetime.now().timestamp()
        self.last_progress_timestamp = timestamp
        self._reconnect_requested = False

    def time_since_last_progress(self, *, now: Optional[float] = None) -> Optional[float]:
        """Return the seconds elapsed since the last recorded progress.

        Args:
            now: Reference POSIX timestamp. The current time is used when None.

        Returns:
            None if no progress has been recorded yet, otherwise the elapsed
            seconds, clamped so the result is never negative.
        """
        if self.last_progress_timestamp is None:
            return None
        reference = now if now is not None else datetime.now().timestamp()
        return max(0.0, reference - self.last_progress_timestamp)

    def update_stall_timeout(self, stall_timeout: float) -> None:
        """Replace the stall timeout (seconds), coercing the value to float."""
        self.stall_timeout = float(stall_timeout)
class DownloadRestartRequested(Exception):
    """Signal that the current HTTP stream should be abandoned and reopened.

    Raised when a caller has explicitly asked for a fresh connection rather
    than because of a network failure.
    """
class DownloadStalledError(Exception):
    """Signal that a download stopped making progress.

    Raised once no data has arrived for longer than the configured stall
    timeout.
    """
class Downloader:
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
@@ -67,6 +134,7 @@ class Downloader:
self.max_retries = 5
self.base_delay = 2.0 # Base delay for exponential backoff
self.session_timeout = 300 # 5 minutes
self.stall_timeout = self._resolve_stall_timeout()
# Default headers
self.default_headers = {
@@ -82,14 +150,38 @@ class Downloader:
if self._session is None or self._should_refresh_session():
await self._create_session()
return self._session
@property
def proxy_url(self) -> Optional[str]:
"""Get the current proxy URL (initialize if needed)"""
if not hasattr(self, '_proxy_url'):
self._proxy_url = None
return self._proxy_url
def _resolve_stall_timeout(self) -> float:
    """Determine the stall timeout from settings or environment.

    The ``download_stall_timeout_seconds`` setting takes precedence. When it
    is unset (``None`` or an empty string) the
    ``COMFYUI_DOWNLOAD_STALL_TIMEOUT`` environment variable is used instead.

    Returns:
        The timeout in seconds, never below 30. A value that is missing,
        unparseable or non-finite (``inf``/``nan``) falls back to the
        120 second default.
    """
    import math  # local import: the file's import block is not visible from here

    default_timeout = 120.0
    minimum_timeout = 30.0

    settings_timeout = None
    try:
        settings_manager = get_settings_manager()
        settings_timeout = settings_manager.get('download_stall_timeout_seconds')
    except Exception as exc:  # pragma: no cover - defensive guard
        # Settings are best-effort here; fall through to the environment/default.
        logger.debug("Failed to read stall timeout from settings: %s", exc)

    raw_value = (
        settings_timeout
        if settings_timeout not in (None, "")
        else os.environ.get('COMFYUI_DOWNLOAD_STALL_TIMEOUT')
    )

    try:
        timeout_value = float(raw_value)
    except (TypeError, ValueError, OverflowError):
        # None, non-numeric text, or an int too large for a float.
        timeout_value = default_timeout

    # float() accepts "inf" and "nan". An infinite timeout would silently
    # disable stall detection, so treat non-finite values as invalid.
    if not math.isfinite(timeout_value):
        timeout_value = default_timeout

    return max(minimum_timeout, timeout_value)
def _should_refresh_session(self) -> bool:
"""Check if session should be refreshed"""
if self._session is None:
@@ -181,7 +273,7 @@ class Downloader:
use_auth: bool = False,
custom_headers: Optional[Dict[str, str]] = None,
allow_resume: bool = True,
pause_event: Optional[asyncio.Event] = None,
pause_event: Optional[DownloadStreamControl] = None,
) -> Tuple[bool, str]:
"""
Download a file with resumable downloads and retry mechanism
@@ -193,7 +285,7 @@ class Downloader:
use_auth: Whether to include authentication headers (e.g., CivitAI API key)
custom_headers: Additional headers to include in request
allow_resume: Whether to support resumable downloads
pause_event: Optional event that, when cleared, will pause streaming until set again
pause_event: Optional stream control used to pause/resume and request reconnects
Returns:
Tuple[bool, str]: (success, save_path or error message)
@@ -307,51 +399,88 @@ class Downloader:
last_progress_report_time = datetime.now()
progress_samples: deque[tuple[datetime, int]] = deque()
progress_samples.append((last_progress_report_time, current_size))
# Ensure directory exists
os.makedirs(os.path.dirname(save_path), exist_ok=True)
# Stream download to file with progress updates
loop = asyncio.get_running_loop()
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
control = pause_event
if control is not None:
control.update_stall_timeout(self.stall_timeout)
with open(part_path, mode) as f:
async for chunk in response.content.iter_chunked(self.chunk_size):
if pause_event is not None and not pause_event.is_set():
await pause_event.wait()
if chunk:
# Run blocking file write in executor
await loop.run_in_executor(None, f.write, chunk)
current_size += len(chunk)
while True:
active_stall_timeout = control.stall_timeout if control else self.stall_timeout
# Limit progress update frequency to reduce overhead
now = datetime.now()
time_diff = (now - last_progress_report_time).total_seconds()
if control is not None:
if control.is_paused():
await control.wait()
resume_time = datetime.now()
last_progress_report_time = resume_time
if control.consume_reconnect_request():
raise DownloadRestartRequested(
"Reconnect requested after resume"
)
elif control.consume_reconnect_request():
raise DownloadRestartRequested("Reconnect requested")
if progress_callback and time_diff >= 1.0:
progress_samples.append((now, current_size))
cutoff = now - timedelta(seconds=5)
while progress_samples and progress_samples[0][0] < cutoff:
progress_samples.popleft()
try:
chunk = await asyncio.wait_for(
response.content.read(self.chunk_size),
timeout=active_stall_timeout,
)
except asyncio.TimeoutError as exc:
logger.warning(
"Download stalled for %.1f seconds without progress from %s",
active_stall_timeout,
url,
)
raise DownloadStalledError(
f"No data received for {active_stall_timeout:.1f} seconds"
) from exc
percent = (current_size / total_size) * 100 if total_size else 0.0
bytes_per_second = 0.0
if len(progress_samples) >= 2:
first_time, first_bytes = progress_samples[0]
last_time, last_bytes = progress_samples[-1]
elapsed = (last_time - first_time).total_seconds()
if elapsed > 0:
bytes_per_second = (last_bytes - first_bytes) / elapsed
if not chunk:
break
progress_snapshot = DownloadProgress(
percent_complete=percent,
bytes_downloaded=current_size,
total_bytes=total_size or None,
bytes_per_second=bytes_per_second,
timestamp=now.timestamp(),
)
# Run blocking file write in executor
await loop.run_in_executor(None, f.write, chunk)
current_size += len(chunk)
await self._dispatch_progress_callback(progress_callback, progress_snapshot)
last_progress_report_time = now
now = datetime.now()
if control is not None:
control.mark_progress(timestamp=now.timestamp())
# Limit progress update frequency to reduce overhead
time_diff = (now - last_progress_report_time).total_seconds()
if progress_callback and time_diff >= 1.0:
progress_samples.append((now, current_size))
cutoff = now - timedelta(seconds=5)
while progress_samples and progress_samples[0][0] < cutoff:
progress_samples.popleft()
percent = (current_size / total_size) * 100 if total_size else 0.0
bytes_per_second = 0.0
if len(progress_samples) >= 2:
first_time, first_bytes = progress_samples[0]
last_time, last_bytes = progress_samples[-1]
elapsed = (last_time - first_time).total_seconds()
if elapsed > 0:
bytes_per_second = (last_bytes - first_bytes) / elapsed
progress_snapshot = DownloadProgress(
percent_complete=percent,
bytes_downloaded=current_size,
total_bytes=total_size or None,
bytes_per_second=bytes_per_second,
timestamp=now.timestamp(),
)
await self._dispatch_progress_callback(progress_callback, progress_snapshot)
last_progress_report_time = now
# Download completed successfully
# Verify file size integrity before finalizing
@@ -447,11 +576,17 @@ class Downloader:
return True, save_path
except (aiohttp.ClientError, aiohttp.ClientPayloadError,
aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
except (
aiohttp.ClientError,
aiohttp.ClientPayloadError,
aiohttp.ServerDisconnectedError,
asyncio.TimeoutError,
DownloadStalledError,
DownloadRestartRequested,
) as e:
retry_count += 1
logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}")
if retry_count <= self.max_retries:
# Calculate delay with exponential backoff
delay = self.base_delay * (2 ** (retry_count - 1))