Merge pull request #589 from willmiao/codex/fix-download-stalling-issues

Fix stalled downloads by adding stall detection and reconnect logic
This commit is contained in:
pixelpaws
2025-10-23 17:40:31 +08:00
committed by GitHub
4 changed files with 303 additions and 81 deletions

View File

@@ -15,7 +15,7 @@ from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry
from .settings_manager import get_settings_manager
from .metadata_service import get_default_metadata_provider
from .downloader import get_downloader, DownloadProgress
from .downloader import get_downloader, DownloadProgress, DownloadStreamControl
# Download to temporary file first
import tempfile
@@ -44,7 +44,7 @@ class DownloadManager:
self._active_downloads = OrderedDict() # download_id -> download_info
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
self._download_tasks = {} # download_id -> asyncio.Task
self._pause_events: Dict[str, asyncio.Event] = {}
self._pause_events: Dict[str, DownloadStreamControl] = {}
async def _get_lora_scanner(self):
"""Get the lora scanner from registry"""
@@ -89,11 +89,11 @@ class DownloadManager:
'bytes_downloaded': 0,
'total_bytes': None,
'bytes_per_second': 0.0,
'last_progress_timestamp': None,
}
pause_event = asyncio.Event()
pause_event.set()
self._pause_events[task_id] = pause_event
pause_control = DownloadStreamControl()
self._pause_events[task_id] = pause_control
# Create tracking task
download_task = asyncio.create_task(
@@ -140,19 +140,23 @@ class DownloadManager:
info['bytes_downloaded'] = snapshot.bytes_downloaded
info['total_bytes'] = snapshot.total_bytes
info['bytes_per_second'] = snapshot.bytes_per_second
pause_control = self._pause_events.get(task_id)
if isinstance(pause_control, DownloadStreamControl):
pause_control.mark_progress(snapshot.timestamp)
info['last_progress_timestamp'] = pause_control.last_progress_timestamp
if original_callback:
await self._dispatch_progress(original_callback, snapshot, progress_value)
# Acquire semaphore to limit concurrent downloads
try:
async with self._download_semaphore:
pause_event = self._pause_events.get(task_id)
if pause_event is not None and not pause_event.is_set():
pause_control = self._pause_events.get(task_id)
if pause_control is not None and pause_control.is_paused():
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'paused'
self._active_downloads[task_id]['bytes_per_second'] = 0.0
await pause_event.wait()
await pause_control.wait()
# Update status to downloading
if task_id in self._active_downloads:
@@ -478,8 +482,8 @@ class DownloadManager:
part_path = save_path + '.part'
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
pause_event = self._pause_events.get(download_id) if download_id else None
pause_control = self._pause_events.get(download_id) if download_id else None
# Store file paths in active_downloads for potential cleanup
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]['file_path'] = save_path
@@ -590,6 +594,8 @@ class DownloadManager:
# Download model file with progress tracking using downloader
downloader = await get_downloader()
if pause_control is not None:
pause_control.update_stall_timeout(downloader.stall_timeout)
last_error = None
for download_url in download_urls:
use_auth = download_url.startswith("https://civitai.com/api/download/")
@@ -602,8 +608,8 @@ class DownloadManager:
"use_auth": use_auth, # Only use authentication for Civitai downloads
}
if pause_event is not None:
download_kwargs["pause_event"] = pause_event
if pause_control is not None:
download_kwargs["pause_event"] = pause_control
success, result = await downloader.download_file(
download_url,
@@ -756,9 +762,9 @@ class DownloadManager:
task = self._download_tasks[download_id]
task.cancel()
pause_event = self._pause_events.get(download_id)
if pause_event is not None:
pause_event.set()
pause_control = self._pause_events.get(download_id)
if pause_control is not None:
pause_control.resume()
# Update status in active downloads
if download_id in self._active_downloads:
@@ -835,16 +841,14 @@ class DownloadManager:
if download_id not in self._download_tasks:
return {'success': False, 'error': 'Download task not found'}
pause_event = self._pause_events.get(download_id)
if pause_event is None:
pause_event = asyncio.Event()
pause_event.set()
self._pause_events[download_id] = pause_event
pause_control = self._pause_events.get(download_id)
if pause_control is None:
return {'success': False, 'error': 'Download task not found'}
if not pause_event.is_set():
if pause_control.is_paused():
return {'success': False, 'error': 'Download is already paused'}
pause_event.clear()
pause_control.pause()
download_info = self._active_downloads.get(download_id)
if download_info is not None:
@@ -856,16 +860,28 @@ class DownloadManager:
async def resume_download(self, download_id: str) -> Dict:
"""Resume a previously paused download."""
pause_event = self._pause_events.get(download_id)
if pause_event is None:
pause_control = self._pause_events.get(download_id)
if pause_control is None:
return {'success': False, 'error': 'Download task not found'}
if pause_event.is_set():
if pause_control.is_set():
return {'success': False, 'error': 'Download is not paused'}
pause_event.set()
download_info = self._active_downloads.get(download_id)
force_reconnect = False
if pause_control is not None:
elapsed = pause_control.time_since_last_progress()
threshold = max(30.0, pause_control.stall_timeout / 2.0)
if elapsed is not None and elapsed >= threshold:
force_reconnect = True
logger.info(
"Forcing reconnect for download %s after %.1f seconds without progress",
download_id,
elapsed,
)
pause_control.resume(force_reconnect=force_reconnect)
if download_info is not None:
if download_info.get('status') == 'paused':
download_info['status'] = 'downloading'

View File

@@ -36,6 +36,73 @@ class DownloadProgress:
timestamp: float
class DownloadStreamControl:
    """Synchronize pause/resume requests and reconnect hints for a download.

    The control wraps an ``asyncio.Event`` that is *set* while the download
    may run and *cleared* while it is paused. It additionally carries a
    one-shot "reconnect requested" flag, the timestamp of the most recent
    progress report, and the stall timeout that applies to the stream.
    """

    def __init__(self, *, stall_timeout: Optional[float] = None) -> None:
        """Create a control in the running (not paused) state.

        Args:
            stall_timeout: Stall timeout in seconds; 120 seconds is used
                when omitted.
        """
        self._event = asyncio.Event()
        self._event.set()
        self._reconnect_requested = False
        self.last_progress_timestamp: Optional[float] = None
        self.stall_timeout: float = float(stall_timeout) if stall_timeout is not None else 120.0

    def is_set(self) -> bool:
        """Return ``True`` while the download is allowed to run."""
        return self._event.is_set()

    def is_paused(self) -> bool:
        """Return ``True`` while the download is paused."""
        return not self._event.is_set()

    def set(self) -> None:
        """Allow the download to run (``asyncio.Event``-compatible name)."""
        self._event.set()

    def clear(self) -> None:
        """Pause the download (``asyncio.Event``-compatible name)."""
        self._event.clear()

    async def wait(self) -> None:
        """Block until the control is in the running state."""
        await self._event.wait()

    def pause(self) -> None:
        """Pause the download."""
        self.clear()

    def resume(self, *, force_reconnect: bool = False) -> None:
        """Resume the download, optionally flagging the stream for reconnect.

        Args:
            force_reconnect: When true, a reconnect request is recorded
                before waiters are released.
        """
        if force_reconnect:
            self._reconnect_requested = True
        self.set()

    def request_reconnect(self) -> None:
        """Record a reconnect request and release any paused waiter."""
        self._reconnect_requested = True
        self.set()

    def has_reconnect_request(self) -> bool:
        """Return whether a reconnect request is pending (without clearing it)."""
        return self._reconnect_requested

    def consume_reconnect_request(self) -> bool:
        """Return the pending reconnect flag and reset it to ``False``."""
        reconnect = self._reconnect_requested
        self._reconnect_requested = False
        return reconnect

    def mark_progress(self, timestamp: Optional[float] = None) -> None:
        """Record that progress was made and drop any pending reconnect request.

        Args:
            timestamp: POSIX timestamp of the progress event; the current
                time is used when ``None``.
        """
        # Compare against None explicitly: a truthiness test would discard a
        # legitimate 0.0 timestamp and silently substitute the current time.
        if timestamp is None:
            timestamp = datetime.now().timestamp()
        self.last_progress_timestamp = timestamp
        self._reconnect_requested = False

    def time_since_last_progress(self, *, now: Optional[float] = None) -> Optional[float]:
        """Return seconds elapsed since the last recorded progress.

        Args:
            now: Reference POSIX timestamp; the current time when ``None``.

        Returns:
            ``None`` when no progress has been recorded yet, otherwise the
            elapsed time clamped to be non-negative.
        """
        if self.last_progress_timestamp is None:
            return None
        reference = now if now is not None else datetime.now().timestamp()
        return max(0.0, reference - self.last_progress_timestamp)

    def update_stall_timeout(self, stall_timeout: float) -> None:
        """Replace the stall timeout (seconds), coercing the value to ``float``."""
        self.stall_timeout = float(stall_timeout)
class DownloadRestartRequested(Exception):
    """Signal that the current HTTP stream should be abandoned for a fresh one.

    Raised when a caller has explicitly asked for a reconnect.
    """
class DownloadStalledError(Exception):
    """Signal that no download progress was seen within the stall timeout."""
class Downloader:
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
@@ -67,6 +134,7 @@ class Downloader:
self.max_retries = 5
self.base_delay = 2.0 # Base delay for exponential backoff
self.session_timeout = 300 # 5 minutes
self.stall_timeout = self._resolve_stall_timeout()
# Default headers
self.default_headers = {
@@ -82,14 +150,38 @@ class Downloader:
if self._session is None or self._should_refresh_session():
await self._create_session()
return self._session
@property
def proxy_url(self) -> Optional[str]:
    """Return the current proxy URL, creating the backing attribute if absent.

    When ``_proxy_url`` has never been assigned it is initialised to
    ``None`` on first access, so later reads find the attribute in place.
    """
    try:
        return self._proxy_url
    except AttributeError:
        # First access: materialise the attribute with its default value.
        self._proxy_url = None
        return self._proxy_url
def _resolve_stall_timeout(self) -> float:
    """Determine the stall timeout (in seconds) from settings or environment.

    The ``download_stall_timeout_seconds`` setting takes precedence. When it
    is unset (``None`` or an empty string) the
    ``COMFYUI_DOWNLOAD_STALL_TIMEOUT`` environment variable is consulted.
    A value that is missing, unparseable or non-finite falls back to 120
    seconds, and the result is never lower than 30 seconds.

    Returns:
        A finite timeout in seconds, always >= 30.0.
    """
    import math  # local import keeps this method self-contained

    default_timeout = 120.0
    settings_timeout = None

    try:
        settings_manager = get_settings_manager()
        settings_timeout = settings_manager.get('download_stall_timeout_seconds')
    except Exception as exc:  # pragma: no cover - defensive guard
        logger.debug("Failed to read stall timeout from settings: %s", exc)

    raw_value = (
        settings_timeout
        if settings_timeout not in (None, "")
        else os.environ.get('COMFYUI_DOWNLOAD_STALL_TIMEOUT')
    )

    try:
        timeout_value = float(raw_value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: an int too large to be represented as a float.
        timeout_value = default_timeout

    # "inf" and "nan" parse successfully, but an infinite timeout would
    # disable stall detection and NaN would collapse to the 30 s floor via
    # max(); treat both as invalid configuration.
    if not math.isfinite(timeout_value):
        timeout_value = default_timeout

    return max(30.0, timeout_value)
def _should_refresh_session(self) -> bool:
"""Check if session should be refreshed"""
if self._session is None:
@@ -181,7 +273,7 @@ class Downloader:
use_auth: bool = False,
custom_headers: Optional[Dict[str, str]] = None,
allow_resume: bool = True,
pause_event: Optional[asyncio.Event] = None,
pause_event: Optional[DownloadStreamControl] = None,
) -> Tuple[bool, str]:
"""
Download a file with resumable downloads and retry mechanism
@@ -193,7 +285,7 @@ class Downloader:
use_auth: Whether to include authentication headers (e.g., CivitAI API key)
custom_headers: Additional headers to include in request
allow_resume: Whether to support resumable downloads
pause_event: Optional event that, when cleared, will pause streaming until set again
pause_event: Optional stream control used to pause/resume and request reconnects
Returns:
Tuple[bool, str]: (success, save_path or error message)
@@ -307,51 +399,88 @@ class Downloader:
last_progress_report_time = datetime.now()
progress_samples: deque[tuple[datetime, int]] = deque()
progress_samples.append((last_progress_report_time, current_size))
# Ensure directory exists
os.makedirs(os.path.dirname(save_path), exist_ok=True)
# Stream download to file with progress updates
loop = asyncio.get_running_loop()
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
control = pause_event
if control is not None:
control.update_stall_timeout(self.stall_timeout)
with open(part_path, mode) as f:
async for chunk in response.content.iter_chunked(self.chunk_size):
if pause_event is not None and not pause_event.is_set():
await pause_event.wait()
if chunk:
# Run blocking file write in executor
await loop.run_in_executor(None, f.write, chunk)
current_size += len(chunk)
while True:
active_stall_timeout = control.stall_timeout if control else self.stall_timeout
# Limit progress update frequency to reduce overhead
now = datetime.now()
time_diff = (now - last_progress_report_time).total_seconds()
if control is not None:
if control.is_paused():
await control.wait()
resume_time = datetime.now()
last_progress_report_time = resume_time
if control.consume_reconnect_request():
raise DownloadRestartRequested(
"Reconnect requested after resume"
)
elif control.consume_reconnect_request():
raise DownloadRestartRequested("Reconnect requested")
if progress_callback and time_diff >= 1.0:
progress_samples.append((now, current_size))
cutoff = now - timedelta(seconds=5)
while progress_samples and progress_samples[0][0] < cutoff:
progress_samples.popleft()
try:
chunk = await asyncio.wait_for(
response.content.read(self.chunk_size),
timeout=active_stall_timeout,
)
except asyncio.TimeoutError as exc:
logger.warning(
"Download stalled for %.1f seconds without progress from %s",
active_stall_timeout,
url,
)
raise DownloadStalledError(
f"No data received for {active_stall_timeout:.1f} seconds"
) from exc
percent = (current_size / total_size) * 100 if total_size else 0.0
bytes_per_second = 0.0
if len(progress_samples) >= 2:
first_time, first_bytes = progress_samples[0]
last_time, last_bytes = progress_samples[-1]
elapsed = (last_time - first_time).total_seconds()
if elapsed > 0:
bytes_per_second = (last_bytes - first_bytes) / elapsed
if not chunk:
break
progress_snapshot = DownloadProgress(
percent_complete=percent,
bytes_downloaded=current_size,
total_bytes=total_size or None,
bytes_per_second=bytes_per_second,
timestamp=now.timestamp(),
)
# Run blocking file write in executor
await loop.run_in_executor(None, f.write, chunk)
current_size += len(chunk)
await self._dispatch_progress_callback(progress_callback, progress_snapshot)
last_progress_report_time = now
now = datetime.now()
if control is not None:
control.mark_progress(timestamp=now.timestamp())
# Limit progress update frequency to reduce overhead
time_diff = (now - last_progress_report_time).total_seconds()
if progress_callback and time_diff >= 1.0:
progress_samples.append((now, current_size))
cutoff = now - timedelta(seconds=5)
while progress_samples and progress_samples[0][0] < cutoff:
progress_samples.popleft()
percent = (current_size / total_size) * 100 if total_size else 0.0
bytes_per_second = 0.0
if len(progress_samples) >= 2:
first_time, first_bytes = progress_samples[0]
last_time, last_bytes = progress_samples[-1]
elapsed = (last_time - first_time).total_seconds()
if elapsed > 0:
bytes_per_second = (last_bytes - first_bytes) / elapsed
progress_snapshot = DownloadProgress(
percent_complete=percent,
bytes_downloaded=current_size,
total_bytes=total_size or None,
bytes_per_second=bytes_per_second,
timestamp=now.timestamp(),
)
await self._dispatch_progress_callback(progress_callback, progress_snapshot)
last_progress_report_time = now
# Download completed successfully
# Verify file size integrity before finalizing
@@ -447,11 +576,17 @@ class Downloader:
return True, save_path
except (aiohttp.ClientError, aiohttp.ClientPayloadError,
aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
except (
aiohttp.ClientError,
aiohttp.ClientPayloadError,
aiohttp.ServerDisconnectedError,
asyncio.TimeoutError,
DownloadStalledError,
DownloadRestartRequested,
) as e:
retry_count += 1
logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}")
if retry_count <= self.max_retries:
# Calculate delay with exponential backoff
delay = self.base_delay * (2 ** (retry_count - 1))

View File

@@ -1,5 +1,6 @@
import asyncio
import os
from datetime import datetime
from pathlib import Path
from typing import Optional
from types import SimpleNamespace
@@ -8,6 +9,7 @@ from unittest.mock import AsyncMock
import pytest
from py.services.download_manager import DownloadManager
from py.services.downloader import DownloadStreamControl
from py.services import download_manager
from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import SettingsManager, get_settings_manager
@@ -528,9 +530,8 @@ async def test_pause_download_updates_state():
download_id = "dl"
manager._download_tasks[download_id] = object()
pause_event = asyncio.Event()
pause_event.set()
manager._pause_events[download_id] = pause_event
pause_control = DownloadStreamControl()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"status": "downloading",
"bytes_per_second": 42.0,
@@ -557,8 +558,10 @@ async def test_resume_download_sets_event_and_status():
manager = DownloadManager()
download_id = "dl"
pause_event = asyncio.Event()
manager._pause_events[download_id] = pause_event
pause_control = DownloadStreamControl()
pause_control.pause()
pause_control.mark_progress()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = {
"status": "paused",
"bytes_per_second": 0.0,
@@ -571,13 +574,32 @@ async def test_resume_download_sets_event_and_status():
assert manager._active_downloads[download_id]["status"] == "downloading"
async def test_resume_download_requests_reconnect_for_stalled_stream():
    """Resuming a paused download whose last progress is stale flags a reconnect."""
    mgr = DownloadManager()
    task_key = "dl"

    # Paused control whose most recent progress lies 120 seconds in the past.
    stale_control = DownloadStreamControl(stall_timeout=40)
    stale_control.pause()
    two_minutes_ago = datetime.now().timestamp() - 120
    stale_control.last_progress_timestamp = two_minutes_ago

    mgr._pause_events[task_key] = stale_control
    mgr._active_downloads[task_key] = {"status": "paused", "bytes_per_second": 0.0}

    outcome = await mgr.resume_download(task_key)

    expected = {"success": True, "message": "Download resumed successfully"}
    assert outcome == expected
    assert stale_control.is_set() is True
    assert stale_control.has_reconnect_request() is True
async def test_resume_download_rejects_when_not_paused():
manager = DownloadManager()
download_id = "dl"
pause_event = asyncio.Event()
pause_event.set()
manager._pause_events[download_id] = pause_event
pause_control = DownloadStreamControl()
manager._pause_events[download_id] = pause_control
result = await manager.resume_download(download_id)

View File

@@ -1,6 +1,7 @@
import asyncio
from datetime import datetime
from pathlib import Path
from typing import Sequence
import pytest
@@ -8,13 +9,24 @@ from py.services.downloader import Downloader
class FakeStream:
def __init__(self, chunks):
def __init__(self, chunks: Sequence[Sequence] | Sequence[bytes]):
self._chunks = list(chunks)
async def iter_chunked(self, _chunk_size):
for chunk in self._chunks:
async def read(self, _chunk_size: int) -> bytes:
if not self._chunks:
await asyncio.sleep(0)
yield chunk
return b""
item = self._chunks.pop(0)
delay = 0.0
payload = item
if isinstance(item, tuple):
payload = item[0]
delay = item[1]
await asyncio.sleep(delay)
return payload
class FakeResponse:
@@ -53,6 +65,12 @@ def _build_downloader(responses, *, max_retries=0):
downloader._session = fake_session
downloader._session_created_at = datetime.now()
downloader._proxy_url = None
async def _noop_create_session():
downloader._session = fake_session
downloader._session_created_at = datetime.now()
downloader._proxy_url = None
downloader._create_session = _noop_create_session # type: ignore[assignment]
return downloader
@@ -123,3 +141,34 @@ async def test_download_file_succeeds_when_sizes_match(tmp_path):
assert success is True
assert Path(result_path).read_bytes() == payload
assert not Path(str(target_path) + ".part").exists()
@pytest.mark.asyncio
async def test_download_file_recovers_from_stall(tmp_path):
    """A stream that stalls mid-file is retried and the file still completes."""
    payload = b"abcdef"
    destination = tmp_path / "model" / "file.bin"
    destination.parent.mkdir()

    def stalling_response():
        # The second chunk is delayed by 0.1 s, longer than the 0.05 s stall
        # timeout configured below.
        return FakeResponse(
            status=200,
            headers={"content-length": str(len(payload))},
            chunks=[(b"abc", 0.0), (b"def", 0.1)],
        )

    def resumed_response():
        # Partial-content reply that delivers the remaining three bytes.
        return FakeResponse(
            status=206,
            headers={"content-length": "3", "Content-Range": "bytes 3-5/6"},
            chunks=[(b"def", 0.0)],
        )

    downloader = _build_downloader([stalling_response, resumed_response], max_retries=1)
    downloader.stall_timeout = 0.05

    success, result_path = await downloader.download_file(
        "https://example.com/file",
        str(destination),
    )

    part_file = Path(str(destination) + ".part")
    assert success is True
    assert Path(result_path).read_bytes() == payload
    assert downloader._session._get_calls == 2
    assert not part_file.exists()