fix(download): recover stalled transfers automatically

This commit is contained in:
pixelpaws
2025-10-23 17:25:38 +08:00
parent 2eae8a7729
commit faa26651dd
4 changed files with 303 additions and 81 deletions

View File

@@ -15,7 +15,7 @@ from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry from .service_registry import ServiceRegistry
from .settings_manager import get_settings_manager from .settings_manager import get_settings_manager
from .metadata_service import get_default_metadata_provider from .metadata_service import get_default_metadata_provider
from .downloader import get_downloader, DownloadProgress from .downloader import get_downloader, DownloadProgress, DownloadStreamControl
# Download to temporary file first # Download to temporary file first
import tempfile import tempfile
@@ -44,7 +44,7 @@ class DownloadManager:
self._active_downloads = OrderedDict() # download_id -> download_info self._active_downloads = OrderedDict() # download_id -> download_info
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
self._download_tasks = {} # download_id -> asyncio.Task self._download_tasks = {} # download_id -> asyncio.Task
self._pause_events: Dict[str, asyncio.Event] = {} self._pause_events: Dict[str, DownloadStreamControl] = {}
async def _get_lora_scanner(self): async def _get_lora_scanner(self):
"""Get the lora scanner from registry""" """Get the lora scanner from registry"""
@@ -89,11 +89,11 @@ class DownloadManager:
'bytes_downloaded': 0, 'bytes_downloaded': 0,
'total_bytes': None, 'total_bytes': None,
'bytes_per_second': 0.0, 'bytes_per_second': 0.0,
'last_progress_timestamp': None,
} }
pause_event = asyncio.Event() pause_control = DownloadStreamControl()
pause_event.set() self._pause_events[task_id] = pause_control
self._pause_events[task_id] = pause_event
# Create tracking task # Create tracking task
download_task = asyncio.create_task( download_task = asyncio.create_task(
@@ -140,19 +140,23 @@ class DownloadManager:
info['bytes_downloaded'] = snapshot.bytes_downloaded info['bytes_downloaded'] = snapshot.bytes_downloaded
info['total_bytes'] = snapshot.total_bytes info['total_bytes'] = snapshot.total_bytes
info['bytes_per_second'] = snapshot.bytes_per_second info['bytes_per_second'] = snapshot.bytes_per_second
pause_control = self._pause_events.get(task_id)
if isinstance(pause_control, DownloadStreamControl):
pause_control.mark_progress(snapshot.timestamp)
info['last_progress_timestamp'] = pause_control.last_progress_timestamp
if original_callback: if original_callback:
await self._dispatch_progress(original_callback, snapshot, progress_value) await self._dispatch_progress(original_callback, snapshot, progress_value)
# Acquire semaphore to limit concurrent downloads # Acquire semaphore to limit concurrent downloads
try: try:
async with self._download_semaphore: async with self._download_semaphore:
pause_event = self._pause_events.get(task_id) pause_control = self._pause_events.get(task_id)
if pause_event is not None and not pause_event.is_set(): if pause_control is not None and pause_control.is_paused():
if task_id in self._active_downloads: if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'paused' self._active_downloads[task_id]['status'] = 'paused'
self._active_downloads[task_id]['bytes_per_second'] = 0.0 self._active_downloads[task_id]['bytes_per_second'] = 0.0
await pause_event.wait() await pause_control.wait()
# Update status to downloading # Update status to downloading
if task_id in self._active_downloads: if task_id in self._active_downloads:
@@ -478,8 +482,8 @@ class DownloadManager:
part_path = save_path + '.part' part_path = save_path + '.part'
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json' metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
pause_event = self._pause_events.get(download_id) if download_id else None pause_control = self._pause_events.get(download_id) if download_id else None
# Store file paths in active_downloads for potential cleanup # Store file paths in active_downloads for potential cleanup
if download_id and download_id in self._active_downloads: if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]['file_path'] = save_path self._active_downloads[download_id]['file_path'] = save_path
@@ -590,6 +594,8 @@ class DownloadManager:
# Download model file with progress tracking using downloader # Download model file with progress tracking using downloader
downloader = await get_downloader() downloader = await get_downloader()
if pause_control is not None:
pause_control.update_stall_timeout(downloader.stall_timeout)
last_error = None last_error = None
for download_url in download_urls: for download_url in download_urls:
use_auth = download_url.startswith("https://civitai.com/api/download/") use_auth = download_url.startswith("https://civitai.com/api/download/")
@@ -602,8 +608,8 @@ class DownloadManager:
"use_auth": use_auth, # Only use authentication for Civitai downloads "use_auth": use_auth, # Only use authentication for Civitai downloads
} }
if pause_event is not None: if pause_control is not None:
download_kwargs["pause_event"] = pause_event download_kwargs["pause_event"] = pause_control
success, result = await downloader.download_file( success, result = await downloader.download_file(
download_url, download_url,
@@ -756,9 +762,9 @@ class DownloadManager:
task = self._download_tasks[download_id] task = self._download_tasks[download_id]
task.cancel() task.cancel()
pause_event = self._pause_events.get(download_id) pause_control = self._pause_events.get(download_id)
if pause_event is not None: if pause_control is not None:
pause_event.set() pause_control.resume()
# Update status in active downloads # Update status in active downloads
if download_id in self._active_downloads: if download_id in self._active_downloads:
@@ -835,16 +841,14 @@ class DownloadManager:
if download_id not in self._download_tasks: if download_id not in self._download_tasks:
return {'success': False, 'error': 'Download task not found'} return {'success': False, 'error': 'Download task not found'}
pause_event = self._pause_events.get(download_id) pause_control = self._pause_events.get(download_id)
if pause_event is None: if pause_control is None:
pause_event = asyncio.Event() return {'success': False, 'error': 'Download task not found'}
pause_event.set()
self._pause_events[download_id] = pause_event
if not pause_event.is_set(): if pause_control.is_paused():
return {'success': False, 'error': 'Download is already paused'} return {'success': False, 'error': 'Download is already paused'}
pause_event.clear() pause_control.pause()
download_info = self._active_downloads.get(download_id) download_info = self._active_downloads.get(download_id)
if download_info is not None: if download_info is not None:
@@ -856,16 +860,28 @@ class DownloadManager:
async def resume_download(self, download_id: str) -> Dict: async def resume_download(self, download_id: str) -> Dict:
"""Resume a previously paused download.""" """Resume a previously paused download."""
pause_event = self._pause_events.get(download_id) pause_control = self._pause_events.get(download_id)
if pause_event is None: if pause_control is None:
return {'success': False, 'error': 'Download task not found'} return {'success': False, 'error': 'Download task not found'}
if pause_event.is_set(): if pause_control.is_set():
return {'success': False, 'error': 'Download is not paused'} return {'success': False, 'error': 'Download is not paused'}
pause_event.set()
download_info = self._active_downloads.get(download_id) download_info = self._active_downloads.get(download_id)
force_reconnect = False
if pause_control is not None:
elapsed = pause_control.time_since_last_progress()
threshold = max(30.0, pause_control.stall_timeout / 2.0)
if elapsed is not None and elapsed >= threshold:
force_reconnect = True
logger.info(
"Forcing reconnect for download %s after %.1f seconds without progress",
download_id,
elapsed,
)
pause_control.resume(force_reconnect=force_reconnect)
if download_info is not None: if download_info is not None:
if download_info.get('status') == 'paused': if download_info.get('status') == 'paused':
download_info['status'] = 'downloading' download_info['status'] = 'downloading'

View File

@@ -36,6 +36,73 @@ class DownloadProgress:
timestamp: float timestamp: float
class DownloadStreamControl:
    """Synchronize pause/resume requests and reconnect hints for a download.

    Wraps an :class:`asyncio.Event` (set == streaming allowed, cleared ==
    paused) and records the timestamp of the last observed progress so callers
    can decide whether a resumed stream should be reconnected from scratch.
    """

    def __init__(self, *, stall_timeout: Optional[float] = None) -> None:
        # Event is set while the download is allowed to stream.
        self._event = asyncio.Event()
        self._event.set()
        # True when the consumer should drop the current HTTP stream and
        # open a fresh one (e.g. after a long pause without progress).
        self._reconnect_requested = False
        # Unix timestamp of the most recent progress report; None until the
        # first progress is observed.
        self.last_progress_timestamp: Optional[float] = None
        # Seconds without data before the stream is considered stalled.
        self.stall_timeout: float = float(stall_timeout) if stall_timeout is not None else 120.0

    def is_set(self) -> bool:
        """Return True when streaming is allowed (i.e. not paused)."""
        return self._event.is_set()

    def is_paused(self) -> bool:
        """Return True when the download is currently paused."""
        return not self._event.is_set()

    def set(self) -> None:
        """Allow streaming to proceed (low-level; prefer :meth:`resume`)."""
        self._event.set()

    def clear(self) -> None:
        """Block streaming (low-level; prefer :meth:`pause`)."""
        self._event.clear()

    async def wait(self) -> None:
        """Block until streaming is allowed again."""
        await self._event.wait()

    def pause(self) -> None:
        """Request that the download pause at the next chunk boundary."""
        self.clear()

    def resume(self, *, force_reconnect: bool = False) -> None:
        """Resume streaming, optionally requesting a fresh HTTP connection."""
        if force_reconnect:
            self._reconnect_requested = True
        self.set()

    def request_reconnect(self) -> None:
        """Ask the consumer to reopen the stream, unblocking it if paused."""
        self._reconnect_requested = True
        self.set()

    def has_reconnect_request(self) -> bool:
        """Return True if a reconnect was requested and not yet consumed."""
        return self._reconnect_requested

    def consume_reconnect_request(self) -> bool:
        """Return the pending reconnect flag and clear it atomically."""
        reconnect = self._reconnect_requested
        self._reconnect_requested = False
        return reconnect

    def mark_progress(self, timestamp: Optional[float] = None) -> None:
        """Record progress at *timestamp* (defaults to now) and clear any
        pending reconnect request, since the stream is evidently alive."""
        # Compare against None explicitly: the previous `timestamp or now()`
        # form silently replaced a legitimate falsy timestamp (0.0) with the
        # current time.
        self.last_progress_timestamp = (
            timestamp if timestamp is not None else datetime.now().timestamp()
        )
        self._reconnect_requested = False

    def time_since_last_progress(self, *, now: Optional[float] = None) -> Optional[float]:
        """Seconds since the last recorded progress, or None if never seen."""
        if self.last_progress_timestamp is None:
            return None
        reference = now if now is not None else datetime.now().timestamp()
        # Clamp to zero in case the caller supplies an earlier reference time.
        return max(0.0, reference - self.last_progress_timestamp)

    def update_stall_timeout(self, stall_timeout: float) -> None:
        """Replace the stall timeout (seconds) used for reconnect heuristics."""
        self.stall_timeout = float(stall_timeout)
class DownloadRestartRequested(Exception):
    """Raised when a caller explicitly requests a fresh HTTP stream.

    Listed in the retryable-exception tuple of ``Downloader.download_file``,
    so raising it aborts the current stream and re-enters the retry loop
    (reconnecting with a Range request) instead of failing the download.
    """
class DownloadStalledError(Exception):
    """Raised when download progress stalls beyond the configured timeout.

    Raised after ``asyncio.wait_for`` times out while reading the response
    body; caught by the retry loop in ``Downloader.download_file`` so a
    stalled stream is reopened rather than aborting the whole transfer.
    """
class Downloader: class Downloader:
"""Unified downloader for all HTTP/HTTPS downloads in the application.""" """Unified downloader for all HTTP/HTTPS downloads in the application."""
@@ -67,6 +134,7 @@ class Downloader:
self.max_retries = 5 self.max_retries = 5
self.base_delay = 2.0 # Base delay for exponential backoff self.base_delay = 2.0 # Base delay for exponential backoff
self.session_timeout = 300 # 5 minutes self.session_timeout = 300 # 5 minutes
self.stall_timeout = self._resolve_stall_timeout()
# Default headers # Default headers
self.default_headers = { self.default_headers = {
@@ -82,14 +150,38 @@ class Downloader:
if self._session is None or self._should_refresh_session(): if self._session is None or self._should_refresh_session():
await self._create_session() await self._create_session()
return self._session return self._session
@property @property
def proxy_url(self) -> Optional[str]: def proxy_url(self) -> Optional[str]:
"""Get the current proxy URL (initialize if needed)""" """Get the current proxy URL (initialize if needed)"""
if not hasattr(self, '_proxy_url'): if not hasattr(self, '_proxy_url'):
self._proxy_url = None self._proxy_url = None
return self._proxy_url return self._proxy_url
def _resolve_stall_timeout(self) -> float:
    """Determine the stall timeout from settings or environment.

    Resolution order: the ``download_stall_timeout_seconds`` setting, then
    the ``COMFYUI_DOWNLOAD_STALL_TIMEOUT`` environment variable, then a
    120-second default. The result is clamped to at least 30 seconds.
    """
    default_timeout = 120.0

    configured = None
    try:
        configured = get_settings_manager().get('download_stall_timeout_seconds')
    except Exception as exc:  # pragma: no cover - defensive guard
        logger.debug("Failed to read stall timeout from settings: %s", exc)

    # Fall back to the environment when the setting is absent or blank.
    if configured in (None, ""):
        configured = os.environ.get('COMFYUI_DOWNLOAD_STALL_TIMEOUT')

    try:
        resolved = float(configured)
    except (TypeError, ValueError):
        resolved = default_timeout

    # Never go below 30 seconds to avoid spurious stall detection.
    return max(30.0, resolved)
def _should_refresh_session(self) -> bool: def _should_refresh_session(self) -> bool:
"""Check if session should be refreshed""" """Check if session should be refreshed"""
if self._session is None: if self._session is None:
@@ -181,7 +273,7 @@ class Downloader:
use_auth: bool = False, use_auth: bool = False,
custom_headers: Optional[Dict[str, str]] = None, custom_headers: Optional[Dict[str, str]] = None,
allow_resume: bool = True, allow_resume: bool = True,
pause_event: Optional[asyncio.Event] = None, pause_event: Optional[DownloadStreamControl] = None,
) -> Tuple[bool, str]: ) -> Tuple[bool, str]:
""" """
Download a file with resumable downloads and retry mechanism Download a file with resumable downloads and retry mechanism
@@ -193,7 +285,7 @@ class Downloader:
use_auth: Whether to include authentication headers (e.g., CivitAI API key) use_auth: Whether to include authentication headers (e.g., CivitAI API key)
custom_headers: Additional headers to include in request custom_headers: Additional headers to include in request
allow_resume: Whether to support resumable downloads allow_resume: Whether to support resumable downloads
pause_event: Optional event that, when cleared, will pause streaming until set again pause_event: Optional stream control used to pause/resume and request reconnects
Returns: Returns:
Tuple[bool, str]: (success, save_path or error message) Tuple[bool, str]: (success, save_path or error message)
@@ -307,51 +399,88 @@ class Downloader:
last_progress_report_time = datetime.now() last_progress_report_time = datetime.now()
progress_samples: deque[tuple[datetime, int]] = deque() progress_samples: deque[tuple[datetime, int]] = deque()
progress_samples.append((last_progress_report_time, current_size)) progress_samples.append((last_progress_report_time, current_size))
# Ensure directory exists # Ensure directory exists
os.makedirs(os.path.dirname(save_path), exist_ok=True) os.makedirs(os.path.dirname(save_path), exist_ok=True)
# Stream download to file with progress updates # Stream download to file with progress updates
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb' mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
control = pause_event
if control is not None:
control.update_stall_timeout(self.stall_timeout)
with open(part_path, mode) as f: with open(part_path, mode) as f:
async for chunk in response.content.iter_chunked(self.chunk_size): while True:
if pause_event is not None and not pause_event.is_set(): active_stall_timeout = control.stall_timeout if control else self.stall_timeout
await pause_event.wait()
if chunk:
# Run blocking file write in executor
await loop.run_in_executor(None, f.write, chunk)
current_size += len(chunk)
# Limit progress update frequency to reduce overhead if control is not None:
now = datetime.now() if control.is_paused():
time_diff = (now - last_progress_report_time).total_seconds() await control.wait()
resume_time = datetime.now()
last_progress_report_time = resume_time
if control.consume_reconnect_request():
raise DownloadRestartRequested(
"Reconnect requested after resume"
)
elif control.consume_reconnect_request():
raise DownloadRestartRequested("Reconnect requested")
if progress_callback and time_diff >= 1.0: try:
progress_samples.append((now, current_size)) chunk = await asyncio.wait_for(
cutoff = now - timedelta(seconds=5) response.content.read(self.chunk_size),
while progress_samples and progress_samples[0][0] < cutoff: timeout=active_stall_timeout,
progress_samples.popleft() )
except asyncio.TimeoutError as exc:
logger.warning(
"Download stalled for %.1f seconds without progress from %s",
active_stall_timeout,
url,
)
raise DownloadStalledError(
f"No data received for {active_stall_timeout:.1f} seconds"
) from exc
percent = (current_size / total_size) * 100 if total_size else 0.0 if not chunk:
bytes_per_second = 0.0 break
if len(progress_samples) >= 2:
first_time, first_bytes = progress_samples[0]
last_time, last_bytes = progress_samples[-1]
elapsed = (last_time - first_time).total_seconds()
if elapsed > 0:
bytes_per_second = (last_bytes - first_bytes) / elapsed
progress_snapshot = DownloadProgress( # Run blocking file write in executor
percent_complete=percent, await loop.run_in_executor(None, f.write, chunk)
bytes_downloaded=current_size, current_size += len(chunk)
total_bytes=total_size or None,
bytes_per_second=bytes_per_second,
timestamp=now.timestamp(),
)
await self._dispatch_progress_callback(progress_callback, progress_snapshot) now = datetime.now()
last_progress_report_time = now if control is not None:
control.mark_progress(timestamp=now.timestamp())
# Limit progress update frequency to reduce overhead
time_diff = (now - last_progress_report_time).total_seconds()
if progress_callback and time_diff >= 1.0:
progress_samples.append((now, current_size))
cutoff = now - timedelta(seconds=5)
while progress_samples and progress_samples[0][0] < cutoff:
progress_samples.popleft()
percent = (current_size / total_size) * 100 if total_size else 0.0
bytes_per_second = 0.0
if len(progress_samples) >= 2:
first_time, first_bytes = progress_samples[0]
last_time, last_bytes = progress_samples[-1]
elapsed = (last_time - first_time).total_seconds()
if elapsed > 0:
bytes_per_second = (last_bytes - first_bytes) / elapsed
progress_snapshot = DownloadProgress(
percent_complete=percent,
bytes_downloaded=current_size,
total_bytes=total_size or None,
bytes_per_second=bytes_per_second,
timestamp=now.timestamp(),
)
await self._dispatch_progress_callback(progress_callback, progress_snapshot)
last_progress_report_time = now
# Download completed successfully # Download completed successfully
# Verify file size integrity before finalizing # Verify file size integrity before finalizing
@@ -447,11 +576,17 @@ class Downloader:
return True, save_path return True, save_path
except (aiohttp.ClientError, aiohttp.ClientPayloadError, except (
aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e: aiohttp.ClientError,
aiohttp.ClientPayloadError,
aiohttp.ServerDisconnectedError,
asyncio.TimeoutError,
DownloadStalledError,
DownloadRestartRequested,
) as e:
retry_count += 1 retry_count += 1
logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}") logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}")
if retry_count <= self.max_retries: if retry_count <= self.max_retries:
# Calculate delay with exponential backoff # Calculate delay with exponential backoff
delay = self.base_delay * (2 ** (retry_count - 1)) delay = self.base_delay * (2 ** (retry_count - 1))

View File

@@ -1,5 +1,6 @@
import asyncio import asyncio
import os import os
from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from types import SimpleNamespace from types import SimpleNamespace
@@ -8,6 +9,7 @@ from unittest.mock import AsyncMock
import pytest import pytest
from py.services.download_manager import DownloadManager from py.services.download_manager import DownloadManager
from py.services.downloader import DownloadStreamControl
from py.services import download_manager from py.services import download_manager
from py.services.service_registry import ServiceRegistry from py.services.service_registry import ServiceRegistry
from py.services.settings_manager import SettingsManager, get_settings_manager from py.services.settings_manager import SettingsManager, get_settings_manager
@@ -528,9 +530,8 @@ async def test_pause_download_updates_state():
download_id = "dl" download_id = "dl"
manager._download_tasks[download_id] = object() manager._download_tasks[download_id] = object()
pause_event = asyncio.Event() pause_control = DownloadStreamControl()
pause_event.set() manager._pause_events[download_id] = pause_control
manager._pause_events[download_id] = pause_event
manager._active_downloads[download_id] = { manager._active_downloads[download_id] = {
"status": "downloading", "status": "downloading",
"bytes_per_second": 42.0, "bytes_per_second": 42.0,
@@ -557,8 +558,10 @@ async def test_resume_download_sets_event_and_status():
manager = DownloadManager() manager = DownloadManager()
download_id = "dl" download_id = "dl"
pause_event = asyncio.Event() pause_control = DownloadStreamControl()
manager._pause_events[download_id] = pause_event pause_control.pause()
pause_control.mark_progress()
manager._pause_events[download_id] = pause_control
manager._active_downloads[download_id] = { manager._active_downloads[download_id] = {
"status": "paused", "status": "paused",
"bytes_per_second": 0.0, "bytes_per_second": 0.0,
@@ -571,13 +574,32 @@ async def test_resume_download_sets_event_and_status():
assert manager._active_downloads[download_id]["status"] == "downloading" assert manager._active_downloads[download_id]["status"] == "downloading"
async def test_resume_download_requests_reconnect_for_stalled_stream():
    """Resuming after a long stall should flag the stream for reconnect."""
    manager = DownloadManager()
    task_id = "dl"

    control = DownloadStreamControl(stall_timeout=40)
    control.pause()
    # Pretend the last progress report happened two minutes ago — well past
    # the reconnect threshold derived from the stall timeout.
    control.last_progress_timestamp = datetime.now().timestamp() - 120
    manager._pause_events[task_id] = control
    manager._active_downloads[task_id] = {
        "status": "paused",
        "bytes_per_second": 0.0,
    }

    outcome = await manager.resume_download(task_id)

    assert outcome == {"success": True, "message": "Download resumed successfully"}
    assert control.is_set() is True
    assert control.has_reconnect_request() is True
async def test_resume_download_rejects_when_not_paused(): async def test_resume_download_rejects_when_not_paused():
manager = DownloadManager() manager = DownloadManager()
download_id = "dl" download_id = "dl"
pause_event = asyncio.Event() pause_control = DownloadStreamControl()
pause_event.set() manager._pause_events[download_id] = pause_control
manager._pause_events[download_id] = pause_event
result = await manager.resume_download(download_id) result = await manager.resume_download(download_id)

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Sequence
import pytest import pytest
@@ -8,13 +9,24 @@ from py.services.downloader import Downloader
class FakeStream: class FakeStream:
def __init__(self, chunks): def __init__(self, chunks: Sequence[Sequence] | Sequence[bytes]):
self._chunks = list(chunks) self._chunks = list(chunks)
async def iter_chunked(self, _chunk_size): async def read(self, _chunk_size: int) -> bytes:
for chunk in self._chunks: if not self._chunks:
await asyncio.sleep(0) await asyncio.sleep(0)
yield chunk return b""
item = self._chunks.pop(0)
delay = 0.0
payload = item
if isinstance(item, tuple):
payload = item[0]
delay = item[1]
await asyncio.sleep(delay)
return payload
class FakeResponse: class FakeResponse:
@@ -53,6 +65,12 @@ def _build_downloader(responses, *, max_retries=0):
downloader._session = fake_session downloader._session = fake_session
downloader._session_created_at = datetime.now() downloader._session_created_at = datetime.now()
downloader._proxy_url = None downloader._proxy_url = None
async def _noop_create_session():
downloader._session = fake_session
downloader._session_created_at = datetime.now()
downloader._proxy_url = None
downloader._create_session = _noop_create_session # type: ignore[assignment]
return downloader return downloader
@@ -123,3 +141,34 @@ async def test_download_file_succeeds_when_sizes_match(tmp_path):
assert success is True assert success is True
assert Path(result_path).read_bytes() == payload assert Path(result_path).read_bytes() == payload
assert not Path(str(target_path) + ".part").exists() assert not Path(str(target_path) + ".part").exists()
@pytest.mark.asyncio
async def test_download_file_recovers_from_stall(tmp_path):
    """A stalled first stream triggers a ranged retry that completes the file."""
    destination = tmp_path / "model" / "file.bin"
    destination.parent.mkdir()
    payload = b"abcdef"

    # First response delivers half the payload, then delays the second chunk
    # longer than the (tiny) stall timeout; the retry serves the remaining
    # bytes via a 206 partial-content response.
    responses = [
        lambda: FakeResponse(
            status=200,
            headers={"content-length": str(len(payload))},
            chunks=[(b"abc", 0.0), (b"def", 0.1)],
        ),
        lambda: FakeResponse(
            status=206,
            headers={"content-length": "3", "Content-Range": "bytes 3-5/6"},
            chunks=[(b"def", 0.0)],
        ),
    ]

    downloader = _build_downloader(responses, max_retries=1)
    downloader.stall_timeout = 0.05

    ok, final_path = await downloader.download_file("https://example.com/file", str(destination))

    assert ok is True
    assert Path(final_path).read_bytes() == payload
    # Two GETs prove the stall forced a reconnect rather than a plain failure.
    assert downloader._session._get_calls == 2
    assert not Path(str(destination) + ".part").exists()