mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
fix(download): recover stalled transfers automatically
This commit is contained in:
@@ -15,7 +15,7 @@ from ..utils.metadata_manager import MetadataManager
|
||||
from .service_registry import ServiceRegistry
|
||||
from .settings_manager import get_settings_manager
|
||||
from .metadata_service import get_default_metadata_provider
|
||||
from .downloader import get_downloader, DownloadProgress
|
||||
from .downloader import get_downloader, DownloadProgress, DownloadStreamControl
|
||||
|
||||
# Download to temporary file first
|
||||
import tempfile
|
||||
@@ -44,7 +44,7 @@ class DownloadManager:
|
||||
self._active_downloads = OrderedDict() # download_id -> download_info
|
||||
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
||||
self._download_tasks = {} # download_id -> asyncio.Task
|
||||
self._pause_events: Dict[str, asyncio.Event] = {}
|
||||
self._pause_events: Dict[str, DownloadStreamControl] = {}
|
||||
|
||||
async def _get_lora_scanner(self):
|
||||
"""Get the lora scanner from registry"""
|
||||
@@ -89,11 +89,11 @@ class DownloadManager:
|
||||
'bytes_downloaded': 0,
|
||||
'total_bytes': None,
|
||||
'bytes_per_second': 0.0,
|
||||
'last_progress_timestamp': None,
|
||||
}
|
||||
|
||||
pause_event = asyncio.Event()
|
||||
pause_event.set()
|
||||
self._pause_events[task_id] = pause_event
|
||||
pause_control = DownloadStreamControl()
|
||||
self._pause_events[task_id] = pause_control
|
||||
|
||||
# Create tracking task
|
||||
download_task = asyncio.create_task(
|
||||
@@ -140,19 +140,23 @@ class DownloadManager:
|
||||
info['bytes_downloaded'] = snapshot.bytes_downloaded
|
||||
info['total_bytes'] = snapshot.total_bytes
|
||||
info['bytes_per_second'] = snapshot.bytes_per_second
|
||||
pause_control = self._pause_events.get(task_id)
|
||||
if isinstance(pause_control, DownloadStreamControl):
|
||||
pause_control.mark_progress(snapshot.timestamp)
|
||||
info['last_progress_timestamp'] = pause_control.last_progress_timestamp
|
||||
|
||||
if original_callback:
|
||||
await self._dispatch_progress(original_callback, snapshot, progress_value)
|
||||
|
||||
|
||||
# Acquire semaphore to limit concurrent downloads
|
||||
try:
|
||||
async with self._download_semaphore:
|
||||
pause_event = self._pause_events.get(task_id)
|
||||
if pause_event is not None and not pause_event.is_set():
|
||||
pause_control = self._pause_events.get(task_id)
|
||||
if pause_control is not None and pause_control.is_paused():
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'paused'
|
||||
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
||||
await pause_event.wait()
|
||||
await pause_control.wait()
|
||||
|
||||
# Update status to downloading
|
||||
if task_id in self._active_downloads:
|
||||
@@ -478,8 +482,8 @@ class DownloadManager:
|
||||
part_path = save_path + '.part'
|
||||
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
||||
|
||||
pause_event = self._pause_events.get(download_id) if download_id else None
|
||||
|
||||
pause_control = self._pause_events.get(download_id) if download_id else None
|
||||
|
||||
# Store file paths in active_downloads for potential cleanup
|
||||
if download_id and download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]['file_path'] = save_path
|
||||
@@ -590,6 +594,8 @@ class DownloadManager:
|
||||
|
||||
# Download model file with progress tracking using downloader
|
||||
downloader = await get_downloader()
|
||||
if pause_control is not None:
|
||||
pause_control.update_stall_timeout(downloader.stall_timeout)
|
||||
last_error = None
|
||||
for download_url in download_urls:
|
||||
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
||||
@@ -602,8 +608,8 @@ class DownloadManager:
|
||||
"use_auth": use_auth, # Only use authentication for Civitai downloads
|
||||
}
|
||||
|
||||
if pause_event is not None:
|
||||
download_kwargs["pause_event"] = pause_event
|
||||
if pause_control is not None:
|
||||
download_kwargs["pause_event"] = pause_control
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
download_url,
|
||||
@@ -756,9 +762,9 @@ class DownloadManager:
|
||||
task = self._download_tasks[download_id]
|
||||
task.cancel()
|
||||
|
||||
pause_event = self._pause_events.get(download_id)
|
||||
if pause_event is not None:
|
||||
pause_event.set()
|
||||
pause_control = self._pause_events.get(download_id)
|
||||
if pause_control is not None:
|
||||
pause_control.resume()
|
||||
|
||||
# Update status in active downloads
|
||||
if download_id in self._active_downloads:
|
||||
@@ -835,16 +841,14 @@ class DownloadManager:
|
||||
if download_id not in self._download_tasks:
|
||||
return {'success': False, 'error': 'Download task not found'}
|
||||
|
||||
pause_event = self._pause_events.get(download_id)
|
||||
if pause_event is None:
|
||||
pause_event = asyncio.Event()
|
||||
pause_event.set()
|
||||
self._pause_events[download_id] = pause_event
|
||||
pause_control = self._pause_events.get(download_id)
|
||||
if pause_control is None:
|
||||
return {'success': False, 'error': 'Download task not found'}
|
||||
|
||||
if not pause_event.is_set():
|
||||
if pause_control.is_paused():
|
||||
return {'success': False, 'error': 'Download is already paused'}
|
||||
|
||||
pause_event.clear()
|
||||
pause_control.pause()
|
||||
|
||||
download_info = self._active_downloads.get(download_id)
|
||||
if download_info is not None:
|
||||
@@ -856,16 +860,28 @@ class DownloadManager:
|
||||
async def resume_download(self, download_id: str) -> Dict:
|
||||
"""Resume a previously paused download."""
|
||||
|
||||
pause_event = self._pause_events.get(download_id)
|
||||
if pause_event is None:
|
||||
pause_control = self._pause_events.get(download_id)
|
||||
if pause_control is None:
|
||||
return {'success': False, 'error': 'Download task not found'}
|
||||
|
||||
if pause_event.is_set():
|
||||
if pause_control.is_set():
|
||||
return {'success': False, 'error': 'Download is not paused'}
|
||||
|
||||
pause_event.set()
|
||||
|
||||
download_info = self._active_downloads.get(download_id)
|
||||
force_reconnect = False
|
||||
if pause_control is not None:
|
||||
elapsed = pause_control.time_since_last_progress()
|
||||
threshold = max(30.0, pause_control.stall_timeout / 2.0)
|
||||
if elapsed is not None and elapsed >= threshold:
|
||||
force_reconnect = True
|
||||
logger.info(
|
||||
"Forcing reconnect for download %s after %.1f seconds without progress",
|
||||
download_id,
|
||||
elapsed,
|
||||
)
|
||||
|
||||
pause_control.resume(force_reconnect=force_reconnect)
|
||||
|
||||
if download_info is not None:
|
||||
if download_info.get('status') == 'paused':
|
||||
download_info['status'] = 'downloading'
|
||||
|
||||
@@ -36,6 +36,73 @@ class DownloadProgress:
|
||||
timestamp: float
|
||||
|
||||
|
||||
class DownloadStreamControl:
|
||||
"""Synchronize pause/resume requests and reconnect hints for a download."""
|
||||
|
||||
def __init__(self, *, stall_timeout: Optional[float] = None) -> None:
|
||||
self._event = asyncio.Event()
|
||||
self._event.set()
|
||||
self._reconnect_requested = False
|
||||
self.last_progress_timestamp: Optional[float] = None
|
||||
self.stall_timeout: float = float(stall_timeout) if stall_timeout is not None else 120.0
|
||||
|
||||
def is_set(self) -> bool:
|
||||
return self._event.is_set()
|
||||
|
||||
def is_paused(self) -> bool:
|
||||
return not self._event.is_set()
|
||||
|
||||
def set(self) -> None:
|
||||
self._event.set()
|
||||
|
||||
def clear(self) -> None:
|
||||
self._event.clear()
|
||||
|
||||
async def wait(self) -> None:
|
||||
await self._event.wait()
|
||||
|
||||
def pause(self) -> None:
|
||||
self.clear()
|
||||
|
||||
def resume(self, *, force_reconnect: bool = False) -> None:
|
||||
if force_reconnect:
|
||||
self._reconnect_requested = True
|
||||
self.set()
|
||||
|
||||
def request_reconnect(self) -> None:
|
||||
self._reconnect_requested = True
|
||||
self.set()
|
||||
|
||||
def has_reconnect_request(self) -> bool:
|
||||
return self._reconnect_requested
|
||||
|
||||
def consume_reconnect_request(self) -> bool:
|
||||
reconnect = self._reconnect_requested
|
||||
self._reconnect_requested = False
|
||||
return reconnect
|
||||
|
||||
def mark_progress(self, timestamp: Optional[float] = None) -> None:
|
||||
self.last_progress_timestamp = timestamp or datetime.now().timestamp()
|
||||
self._reconnect_requested = False
|
||||
|
||||
def time_since_last_progress(self, *, now: Optional[float] = None) -> Optional[float]:
|
||||
if self.last_progress_timestamp is None:
|
||||
return None
|
||||
reference = now if now is not None else datetime.now().timestamp()
|
||||
return max(0.0, reference - self.last_progress_timestamp)
|
||||
|
||||
def update_stall_timeout(self, stall_timeout: float) -> None:
|
||||
self.stall_timeout = float(stall_timeout)
|
||||
|
||||
|
||||
class DownloadRestartRequested(Exception):
|
||||
"""Raised when a caller explicitly requests a fresh HTTP stream."""
|
||||
|
||||
|
||||
class DownloadStalledError(Exception):
|
||||
"""Raised when download progress stalls beyond the configured timeout."""
|
||||
|
||||
|
||||
class Downloader:
|
||||
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
|
||||
|
||||
@@ -67,6 +134,7 @@ class Downloader:
|
||||
self.max_retries = 5
|
||||
self.base_delay = 2.0 # Base delay for exponential backoff
|
||||
self.session_timeout = 300 # 5 minutes
|
||||
self.stall_timeout = self._resolve_stall_timeout()
|
||||
|
||||
# Default headers
|
||||
self.default_headers = {
|
||||
@@ -82,14 +150,38 @@ class Downloader:
|
||||
if self._session is None or self._should_refresh_session():
|
||||
await self._create_session()
|
||||
return self._session
|
||||
|
||||
|
||||
@property
|
||||
def proxy_url(self) -> Optional[str]:
|
||||
"""Get the current proxy URL (initialize if needed)"""
|
||||
if not hasattr(self, '_proxy_url'):
|
||||
self._proxy_url = None
|
||||
return self._proxy_url
|
||||
|
||||
|
||||
def _resolve_stall_timeout(self) -> float:
|
||||
"""Determine the stall timeout from settings or environment."""
|
||||
default_timeout = 120.0
|
||||
settings_timeout = None
|
||||
|
||||
try:
|
||||
settings_manager = get_settings_manager()
|
||||
settings_timeout = settings_manager.get('download_stall_timeout_seconds')
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("Failed to read stall timeout from settings: %s", exc)
|
||||
|
||||
raw_value = (
|
||||
settings_timeout
|
||||
if settings_timeout not in (None, "")
|
||||
else os.environ.get('COMFYUI_DOWNLOAD_STALL_TIMEOUT')
|
||||
)
|
||||
|
||||
try:
|
||||
timeout_value = float(raw_value)
|
||||
except (TypeError, ValueError):
|
||||
timeout_value = default_timeout
|
||||
|
||||
return max(30.0, timeout_value)
|
||||
|
||||
def _should_refresh_session(self) -> bool:
|
||||
"""Check if session should be refreshed"""
|
||||
if self._session is None:
|
||||
@@ -181,7 +273,7 @@ class Downloader:
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
allow_resume: bool = True,
|
||||
pause_event: Optional[asyncio.Event] = None,
|
||||
pause_event: Optional[DownloadStreamControl] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Download a file with resumable downloads and retry mechanism
|
||||
@@ -193,7 +285,7 @@ class Downloader:
|
||||
use_auth: Whether to include authentication headers (e.g., CivitAI API key)
|
||||
custom_headers: Additional headers to include in request
|
||||
allow_resume: Whether to support resumable downloads
|
||||
pause_event: Optional event that, when cleared, will pause streaming until set again
|
||||
pause_event: Optional stream control used to pause/resume and request reconnects
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (success, save_path or error message)
|
||||
@@ -307,51 +399,88 @@ class Downloader:
|
||||
last_progress_report_time = datetime.now()
|
||||
progress_samples: deque[tuple[datetime, int]] = deque()
|
||||
progress_samples.append((last_progress_report_time, current_size))
|
||||
|
||||
|
||||
# Ensure directory exists
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
|
||||
|
||||
# Stream download to file with progress updates
|
||||
loop = asyncio.get_running_loop()
|
||||
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
|
||||
control = pause_event
|
||||
|
||||
if control is not None:
|
||||
control.update_stall_timeout(self.stall_timeout)
|
||||
|
||||
with open(part_path, mode) as f:
|
||||
async for chunk in response.content.iter_chunked(self.chunk_size):
|
||||
if pause_event is not None and not pause_event.is_set():
|
||||
await pause_event.wait()
|
||||
if chunk:
|
||||
# Run blocking file write in executor
|
||||
await loop.run_in_executor(None, f.write, chunk)
|
||||
current_size += len(chunk)
|
||||
while True:
|
||||
active_stall_timeout = control.stall_timeout if control else self.stall_timeout
|
||||
|
||||
# Limit progress update frequency to reduce overhead
|
||||
now = datetime.now()
|
||||
time_diff = (now - last_progress_report_time).total_seconds()
|
||||
if control is not None:
|
||||
if control.is_paused():
|
||||
await control.wait()
|
||||
resume_time = datetime.now()
|
||||
last_progress_report_time = resume_time
|
||||
if control.consume_reconnect_request():
|
||||
raise DownloadRestartRequested(
|
||||
"Reconnect requested after resume"
|
||||
)
|
||||
elif control.consume_reconnect_request():
|
||||
raise DownloadRestartRequested("Reconnect requested")
|
||||
|
||||
if progress_callback and time_diff >= 1.0:
|
||||
progress_samples.append((now, current_size))
|
||||
cutoff = now - timedelta(seconds=5)
|
||||
while progress_samples and progress_samples[0][0] < cutoff:
|
||||
progress_samples.popleft()
|
||||
try:
|
||||
chunk = await asyncio.wait_for(
|
||||
response.content.read(self.chunk_size),
|
||||
timeout=active_stall_timeout,
|
||||
)
|
||||
except asyncio.TimeoutError as exc:
|
||||
logger.warning(
|
||||
"Download stalled for %.1f seconds without progress from %s",
|
||||
active_stall_timeout,
|
||||
url,
|
||||
)
|
||||
raise DownloadStalledError(
|
||||
f"No data received for {active_stall_timeout:.1f} seconds"
|
||||
) from exc
|
||||
|
||||
percent = (current_size / total_size) * 100 if total_size else 0.0
|
||||
bytes_per_second = 0.0
|
||||
if len(progress_samples) >= 2:
|
||||
first_time, first_bytes = progress_samples[0]
|
||||
last_time, last_bytes = progress_samples[-1]
|
||||
elapsed = (last_time - first_time).total_seconds()
|
||||
if elapsed > 0:
|
||||
bytes_per_second = (last_bytes - first_bytes) / elapsed
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
progress_snapshot = DownloadProgress(
|
||||
percent_complete=percent,
|
||||
bytes_downloaded=current_size,
|
||||
total_bytes=total_size or None,
|
||||
bytes_per_second=bytes_per_second,
|
||||
timestamp=now.timestamp(),
|
||||
)
|
||||
# Run blocking file write in executor
|
||||
await loop.run_in_executor(None, f.write, chunk)
|
||||
current_size += len(chunk)
|
||||
|
||||
await self._dispatch_progress_callback(progress_callback, progress_snapshot)
|
||||
last_progress_report_time = now
|
||||
now = datetime.now()
|
||||
if control is not None:
|
||||
control.mark_progress(timestamp=now.timestamp())
|
||||
|
||||
# Limit progress update frequency to reduce overhead
|
||||
time_diff = (now - last_progress_report_time).total_seconds()
|
||||
|
||||
if progress_callback and time_diff >= 1.0:
|
||||
progress_samples.append((now, current_size))
|
||||
cutoff = now - timedelta(seconds=5)
|
||||
while progress_samples and progress_samples[0][0] < cutoff:
|
||||
progress_samples.popleft()
|
||||
|
||||
percent = (current_size / total_size) * 100 if total_size else 0.0
|
||||
bytes_per_second = 0.0
|
||||
if len(progress_samples) >= 2:
|
||||
first_time, first_bytes = progress_samples[0]
|
||||
last_time, last_bytes = progress_samples[-1]
|
||||
elapsed = (last_time - first_time).total_seconds()
|
||||
if elapsed > 0:
|
||||
bytes_per_second = (last_bytes - first_bytes) / elapsed
|
||||
|
||||
progress_snapshot = DownloadProgress(
|
||||
percent_complete=percent,
|
||||
bytes_downloaded=current_size,
|
||||
total_bytes=total_size or None,
|
||||
bytes_per_second=bytes_per_second,
|
||||
timestamp=now.timestamp(),
|
||||
)
|
||||
|
||||
await self._dispatch_progress_callback(progress_callback, progress_snapshot)
|
||||
last_progress_report_time = now
|
||||
|
||||
# Download completed successfully
|
||||
# Verify file size integrity before finalizing
|
||||
@@ -447,11 +576,17 @@ class Downloader:
|
||||
|
||||
return True, save_path
|
||||
|
||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError,
|
||||
aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
|
||||
except (
|
||||
aiohttp.ClientError,
|
||||
aiohttp.ClientPayloadError,
|
||||
aiohttp.ServerDisconnectedError,
|
||||
asyncio.TimeoutError,
|
||||
DownloadStalledError,
|
||||
DownloadRestartRequested,
|
||||
) as e:
|
||||
retry_count += 1
|
||||
logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}")
|
||||
|
||||
|
||||
if retry_count <= self.max_retries:
|
||||
# Calculate delay with exponential backoff
|
||||
delay = self.base_delay * (2 ** (retry_count - 1))
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from types import SimpleNamespace
|
||||
@@ -8,6 +9,7 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
|
||||
from py.services.download_manager import DownloadManager
|
||||
from py.services.downloader import DownloadStreamControl
|
||||
from py.services import download_manager
|
||||
from py.services.service_registry import ServiceRegistry
|
||||
from py.services.settings_manager import SettingsManager, get_settings_manager
|
||||
@@ -528,9 +530,8 @@ async def test_pause_download_updates_state():
|
||||
|
||||
download_id = "dl"
|
||||
manager._download_tasks[download_id] = object()
|
||||
pause_event = asyncio.Event()
|
||||
pause_event.set()
|
||||
manager._pause_events[download_id] = pause_event
|
||||
pause_control = DownloadStreamControl()
|
||||
manager._pause_events[download_id] = pause_control
|
||||
manager._active_downloads[download_id] = {
|
||||
"status": "downloading",
|
||||
"bytes_per_second": 42.0,
|
||||
@@ -557,8 +558,10 @@ async def test_resume_download_sets_event_and_status():
|
||||
manager = DownloadManager()
|
||||
|
||||
download_id = "dl"
|
||||
pause_event = asyncio.Event()
|
||||
manager._pause_events[download_id] = pause_event
|
||||
pause_control = DownloadStreamControl()
|
||||
pause_control.pause()
|
||||
pause_control.mark_progress()
|
||||
manager._pause_events[download_id] = pause_control
|
||||
manager._active_downloads[download_id] = {
|
||||
"status": "paused",
|
||||
"bytes_per_second": 0.0,
|
||||
@@ -571,13 +574,32 @@ async def test_resume_download_sets_event_and_status():
|
||||
assert manager._active_downloads[download_id]["status"] == "downloading"
|
||||
|
||||
|
||||
async def test_resume_download_requests_reconnect_for_stalled_stream():
|
||||
manager = DownloadManager()
|
||||
|
||||
download_id = "dl"
|
||||
pause_control = DownloadStreamControl(stall_timeout=40)
|
||||
pause_control.pause()
|
||||
pause_control.last_progress_timestamp = (datetime.now().timestamp() - 120)
|
||||
manager._pause_events[download_id] = pause_control
|
||||
manager._active_downloads[download_id] = {
|
||||
"status": "paused",
|
||||
"bytes_per_second": 0.0,
|
||||
}
|
||||
|
||||
result = await manager.resume_download(download_id)
|
||||
|
||||
assert result == {"success": True, "message": "Download resumed successfully"}
|
||||
assert pause_control.is_set() is True
|
||||
assert pause_control.has_reconnect_request() is True
|
||||
|
||||
|
||||
async def test_resume_download_rejects_when_not_paused():
|
||||
manager = DownloadManager()
|
||||
|
||||
download_id = "dl"
|
||||
pause_event = asyncio.Event()
|
||||
pause_event.set()
|
||||
manager._pause_events[download_id] = pause_event
|
||||
pause_control = DownloadStreamControl()
|
||||
manager._pause_events[download_id] = pause_control
|
||||
|
||||
result = await manager.resume_download(download_id)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Sequence
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -8,13 +9,24 @@ from py.services.downloader import Downloader
|
||||
|
||||
|
||||
class FakeStream:
|
||||
def __init__(self, chunks):
|
||||
def __init__(self, chunks: Sequence[Sequence] | Sequence[bytes]):
|
||||
self._chunks = list(chunks)
|
||||
|
||||
async def iter_chunked(self, _chunk_size):
|
||||
for chunk in self._chunks:
|
||||
async def read(self, _chunk_size: int) -> bytes:
|
||||
if not self._chunks:
|
||||
await asyncio.sleep(0)
|
||||
yield chunk
|
||||
return b""
|
||||
|
||||
item = self._chunks.pop(0)
|
||||
delay = 0.0
|
||||
payload = item
|
||||
|
||||
if isinstance(item, tuple):
|
||||
payload = item[0]
|
||||
delay = item[1]
|
||||
|
||||
await asyncio.sleep(delay)
|
||||
return payload
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
@@ -53,6 +65,12 @@ def _build_downloader(responses, *, max_retries=0):
|
||||
downloader._session = fake_session
|
||||
downloader._session_created_at = datetime.now()
|
||||
downloader._proxy_url = None
|
||||
async def _noop_create_session():
|
||||
downloader._session = fake_session
|
||||
downloader._session_created_at = datetime.now()
|
||||
downloader._proxy_url = None
|
||||
|
||||
downloader._create_session = _noop_create_session # type: ignore[assignment]
|
||||
return downloader
|
||||
|
||||
|
||||
@@ -123,3 +141,34 @@ async def test_download_file_succeeds_when_sizes_match(tmp_path):
|
||||
assert success is True
|
||||
assert Path(result_path).read_bytes() == payload
|
||||
assert not Path(str(target_path) + ".part").exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_file_recovers_from_stall(tmp_path):
|
||||
target_path = tmp_path / "model" / "file.bin"
|
||||
target_path.parent.mkdir()
|
||||
|
||||
payload = b"abcdef"
|
||||
|
||||
responses = [
|
||||
lambda: FakeResponse(
|
||||
status=200,
|
||||
headers={"content-length": str(len(payload))},
|
||||
chunks=[(b"abc", 0.0), (b"def", 0.1)],
|
||||
),
|
||||
lambda: FakeResponse(
|
||||
status=206,
|
||||
headers={"content-length": "3", "Content-Range": "bytes 3-5/6"},
|
||||
chunks=[(b"def", 0.0)],
|
||||
),
|
||||
]
|
||||
|
||||
downloader = _build_downloader(responses, max_retries=1)
|
||||
downloader.stall_timeout = 0.05
|
||||
|
||||
success, result_path = await downloader.download_file("https://example.com/file", str(target_path))
|
||||
|
||||
assert success is True
|
||||
assert Path(result_path).read_bytes() == payload
|
||||
assert downloader._session._get_calls == 2
|
||||
assert not Path(str(target_path) + ".part").exists()
|
||||
|
||||
Reference in New Issue
Block a user