mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
Merge pull request #589 from willmiao/codex/fix-download-stalling-issues
Fix stalled downloads by adding stall detection and reconnect logic
This commit is contained in:
@@ -15,7 +15,7 @@ from ..utils.metadata_manager import MetadataManager
|
|||||||
from .service_registry import ServiceRegistry
|
from .service_registry import ServiceRegistry
|
||||||
from .settings_manager import get_settings_manager
|
from .settings_manager import get_settings_manager
|
||||||
from .metadata_service import get_default_metadata_provider
|
from .metadata_service import get_default_metadata_provider
|
||||||
from .downloader import get_downloader, DownloadProgress
|
from .downloader import get_downloader, DownloadProgress, DownloadStreamControl
|
||||||
|
|
||||||
# Download to temporary file first
|
# Download to temporary file first
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -44,7 +44,7 @@ class DownloadManager:
|
|||||||
self._active_downloads = OrderedDict() # download_id -> download_info
|
self._active_downloads = OrderedDict() # download_id -> download_info
|
||||||
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
||||||
self._download_tasks = {} # download_id -> asyncio.Task
|
self._download_tasks = {} # download_id -> asyncio.Task
|
||||||
self._pause_events: Dict[str, asyncio.Event] = {}
|
self._pause_events: Dict[str, DownloadStreamControl] = {}
|
||||||
|
|
||||||
async def _get_lora_scanner(self):
|
async def _get_lora_scanner(self):
|
||||||
"""Get the lora scanner from registry"""
|
"""Get the lora scanner from registry"""
|
||||||
@@ -89,11 +89,11 @@ class DownloadManager:
|
|||||||
'bytes_downloaded': 0,
|
'bytes_downloaded': 0,
|
||||||
'total_bytes': None,
|
'total_bytes': None,
|
||||||
'bytes_per_second': 0.0,
|
'bytes_per_second': 0.0,
|
||||||
|
'last_progress_timestamp': None,
|
||||||
}
|
}
|
||||||
|
|
||||||
pause_event = asyncio.Event()
|
pause_control = DownloadStreamControl()
|
||||||
pause_event.set()
|
self._pause_events[task_id] = pause_control
|
||||||
self._pause_events[task_id] = pause_event
|
|
||||||
|
|
||||||
# Create tracking task
|
# Create tracking task
|
||||||
download_task = asyncio.create_task(
|
download_task = asyncio.create_task(
|
||||||
@@ -140,6 +140,10 @@ class DownloadManager:
|
|||||||
info['bytes_downloaded'] = snapshot.bytes_downloaded
|
info['bytes_downloaded'] = snapshot.bytes_downloaded
|
||||||
info['total_bytes'] = snapshot.total_bytes
|
info['total_bytes'] = snapshot.total_bytes
|
||||||
info['bytes_per_second'] = snapshot.bytes_per_second
|
info['bytes_per_second'] = snapshot.bytes_per_second
|
||||||
|
pause_control = self._pause_events.get(task_id)
|
||||||
|
if isinstance(pause_control, DownloadStreamControl):
|
||||||
|
pause_control.mark_progress(snapshot.timestamp)
|
||||||
|
info['last_progress_timestamp'] = pause_control.last_progress_timestamp
|
||||||
|
|
||||||
if original_callback:
|
if original_callback:
|
||||||
await self._dispatch_progress(original_callback, snapshot, progress_value)
|
await self._dispatch_progress(original_callback, snapshot, progress_value)
|
||||||
@@ -147,12 +151,12 @@ class DownloadManager:
|
|||||||
# Acquire semaphore to limit concurrent downloads
|
# Acquire semaphore to limit concurrent downloads
|
||||||
try:
|
try:
|
||||||
async with self._download_semaphore:
|
async with self._download_semaphore:
|
||||||
pause_event = self._pause_events.get(task_id)
|
pause_control = self._pause_events.get(task_id)
|
||||||
if pause_event is not None and not pause_event.is_set():
|
if pause_control is not None and pause_control.is_paused():
|
||||||
if task_id in self._active_downloads:
|
if task_id in self._active_downloads:
|
||||||
self._active_downloads[task_id]['status'] = 'paused'
|
self._active_downloads[task_id]['status'] = 'paused'
|
||||||
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
||||||
await pause_event.wait()
|
await pause_control.wait()
|
||||||
|
|
||||||
# Update status to downloading
|
# Update status to downloading
|
||||||
if task_id in self._active_downloads:
|
if task_id in self._active_downloads:
|
||||||
@@ -478,7 +482,7 @@ class DownloadManager:
|
|||||||
part_path = save_path + '.part'
|
part_path = save_path + '.part'
|
||||||
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
||||||
|
|
||||||
pause_event = self._pause_events.get(download_id) if download_id else None
|
pause_control = self._pause_events.get(download_id) if download_id else None
|
||||||
|
|
||||||
# Store file paths in active_downloads for potential cleanup
|
# Store file paths in active_downloads for potential cleanup
|
||||||
if download_id and download_id in self._active_downloads:
|
if download_id and download_id in self._active_downloads:
|
||||||
@@ -590,6 +594,8 @@ class DownloadManager:
|
|||||||
|
|
||||||
# Download model file with progress tracking using downloader
|
# Download model file with progress tracking using downloader
|
||||||
downloader = await get_downloader()
|
downloader = await get_downloader()
|
||||||
|
if pause_control is not None:
|
||||||
|
pause_control.update_stall_timeout(downloader.stall_timeout)
|
||||||
last_error = None
|
last_error = None
|
||||||
for download_url in download_urls:
|
for download_url in download_urls:
|
||||||
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
||||||
@@ -602,8 +608,8 @@ class DownloadManager:
|
|||||||
"use_auth": use_auth, # Only use authentication for Civitai downloads
|
"use_auth": use_auth, # Only use authentication for Civitai downloads
|
||||||
}
|
}
|
||||||
|
|
||||||
if pause_event is not None:
|
if pause_control is not None:
|
||||||
download_kwargs["pause_event"] = pause_event
|
download_kwargs["pause_event"] = pause_control
|
||||||
|
|
||||||
success, result = await downloader.download_file(
|
success, result = await downloader.download_file(
|
||||||
download_url,
|
download_url,
|
||||||
@@ -756,9 +762,9 @@ class DownloadManager:
|
|||||||
task = self._download_tasks[download_id]
|
task = self._download_tasks[download_id]
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
pause_event = self._pause_events.get(download_id)
|
pause_control = self._pause_events.get(download_id)
|
||||||
if pause_event is not None:
|
if pause_control is not None:
|
||||||
pause_event.set()
|
pause_control.resume()
|
||||||
|
|
||||||
# Update status in active downloads
|
# Update status in active downloads
|
||||||
if download_id in self._active_downloads:
|
if download_id in self._active_downloads:
|
||||||
@@ -835,16 +841,14 @@ class DownloadManager:
|
|||||||
if download_id not in self._download_tasks:
|
if download_id not in self._download_tasks:
|
||||||
return {'success': False, 'error': 'Download task not found'}
|
return {'success': False, 'error': 'Download task not found'}
|
||||||
|
|
||||||
pause_event = self._pause_events.get(download_id)
|
pause_control = self._pause_events.get(download_id)
|
||||||
if pause_event is None:
|
if pause_control is None:
|
||||||
pause_event = asyncio.Event()
|
return {'success': False, 'error': 'Download task not found'}
|
||||||
pause_event.set()
|
|
||||||
self._pause_events[download_id] = pause_event
|
|
||||||
|
|
||||||
if not pause_event.is_set():
|
if pause_control.is_paused():
|
||||||
return {'success': False, 'error': 'Download is already paused'}
|
return {'success': False, 'error': 'Download is already paused'}
|
||||||
|
|
||||||
pause_event.clear()
|
pause_control.pause()
|
||||||
|
|
||||||
download_info = self._active_downloads.get(download_id)
|
download_info = self._active_downloads.get(download_id)
|
||||||
if download_info is not None:
|
if download_info is not None:
|
||||||
@@ -856,16 +860,28 @@ class DownloadManager:
|
|||||||
async def resume_download(self, download_id: str) -> Dict:
|
async def resume_download(self, download_id: str) -> Dict:
|
||||||
"""Resume a previously paused download."""
|
"""Resume a previously paused download."""
|
||||||
|
|
||||||
pause_event = self._pause_events.get(download_id)
|
pause_control = self._pause_events.get(download_id)
|
||||||
if pause_event is None:
|
if pause_control is None:
|
||||||
return {'success': False, 'error': 'Download task not found'}
|
return {'success': False, 'error': 'Download task not found'}
|
||||||
|
|
||||||
if pause_event.is_set():
|
if pause_control.is_set():
|
||||||
return {'success': False, 'error': 'Download is not paused'}
|
return {'success': False, 'error': 'Download is not paused'}
|
||||||
|
|
||||||
pause_event.set()
|
|
||||||
|
|
||||||
download_info = self._active_downloads.get(download_id)
|
download_info = self._active_downloads.get(download_id)
|
||||||
|
force_reconnect = False
|
||||||
|
if pause_control is not None:
|
||||||
|
elapsed = pause_control.time_since_last_progress()
|
||||||
|
threshold = max(30.0, pause_control.stall_timeout / 2.0)
|
||||||
|
if elapsed is not None and elapsed >= threshold:
|
||||||
|
force_reconnect = True
|
||||||
|
logger.info(
|
||||||
|
"Forcing reconnect for download %s after %.1f seconds without progress",
|
||||||
|
download_id,
|
||||||
|
elapsed,
|
||||||
|
)
|
||||||
|
|
||||||
|
pause_control.resume(force_reconnect=force_reconnect)
|
||||||
|
|
||||||
if download_info is not None:
|
if download_info is not None:
|
||||||
if download_info.get('status') == 'paused':
|
if download_info.get('status') == 'paused':
|
||||||
download_info['status'] = 'downloading'
|
download_info['status'] = 'downloading'
|
||||||
|
|||||||
@@ -36,6 +36,73 @@ class DownloadProgress:
|
|||||||
timestamp: float
|
timestamp: float
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadStreamControl:
|
||||||
|
"""Synchronize pause/resume requests and reconnect hints for a download."""
|
||||||
|
|
||||||
|
def __init__(self, *, stall_timeout: Optional[float] = None) -> None:
|
||||||
|
self._event = asyncio.Event()
|
||||||
|
self._event.set()
|
||||||
|
self._reconnect_requested = False
|
||||||
|
self.last_progress_timestamp: Optional[float] = None
|
||||||
|
self.stall_timeout: float = float(stall_timeout) if stall_timeout is not None else 120.0
|
||||||
|
|
||||||
|
def is_set(self) -> bool:
|
||||||
|
return self._event.is_set()
|
||||||
|
|
||||||
|
def is_paused(self) -> bool:
|
||||||
|
return not self._event.is_set()
|
||||||
|
|
||||||
|
def set(self) -> None:
|
||||||
|
self._event.set()
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
self._event.clear()
|
||||||
|
|
||||||
|
async def wait(self) -> None:
|
||||||
|
await self._event.wait()
|
||||||
|
|
||||||
|
def pause(self) -> None:
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
def resume(self, *, force_reconnect: bool = False) -> None:
|
||||||
|
if force_reconnect:
|
||||||
|
self._reconnect_requested = True
|
||||||
|
self.set()
|
||||||
|
|
||||||
|
def request_reconnect(self) -> None:
|
||||||
|
self._reconnect_requested = True
|
||||||
|
self.set()
|
||||||
|
|
||||||
|
def has_reconnect_request(self) -> bool:
|
||||||
|
return self._reconnect_requested
|
||||||
|
|
||||||
|
def consume_reconnect_request(self) -> bool:
|
||||||
|
reconnect = self._reconnect_requested
|
||||||
|
self._reconnect_requested = False
|
||||||
|
return reconnect
|
||||||
|
|
||||||
|
def mark_progress(self, timestamp: Optional[float] = None) -> None:
|
||||||
|
self.last_progress_timestamp = timestamp or datetime.now().timestamp()
|
||||||
|
self._reconnect_requested = False
|
||||||
|
|
||||||
|
def time_since_last_progress(self, *, now: Optional[float] = None) -> Optional[float]:
|
||||||
|
if self.last_progress_timestamp is None:
|
||||||
|
return None
|
||||||
|
reference = now if now is not None else datetime.now().timestamp()
|
||||||
|
return max(0.0, reference - self.last_progress_timestamp)
|
||||||
|
|
||||||
|
def update_stall_timeout(self, stall_timeout: float) -> None:
|
||||||
|
self.stall_timeout = float(stall_timeout)
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadRestartRequested(Exception):
|
||||||
|
"""Raised when a caller explicitly requests a fresh HTTP stream."""
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadStalledError(Exception):
|
||||||
|
"""Raised when download progress stalls beyond the configured timeout."""
|
||||||
|
|
||||||
|
|
||||||
class Downloader:
|
class Downloader:
|
||||||
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
|
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
|
||||||
|
|
||||||
@@ -67,6 +134,7 @@ class Downloader:
|
|||||||
self.max_retries = 5
|
self.max_retries = 5
|
||||||
self.base_delay = 2.0 # Base delay for exponential backoff
|
self.base_delay = 2.0 # Base delay for exponential backoff
|
||||||
self.session_timeout = 300 # 5 minutes
|
self.session_timeout = 300 # 5 minutes
|
||||||
|
self.stall_timeout = self._resolve_stall_timeout()
|
||||||
|
|
||||||
# Default headers
|
# Default headers
|
||||||
self.default_headers = {
|
self.default_headers = {
|
||||||
@@ -90,6 +158,30 @@ class Downloader:
|
|||||||
self._proxy_url = None
|
self._proxy_url = None
|
||||||
return self._proxy_url
|
return self._proxy_url
|
||||||
|
|
||||||
|
def _resolve_stall_timeout(self) -> float:
|
||||||
|
"""Determine the stall timeout from settings or environment."""
|
||||||
|
default_timeout = 120.0
|
||||||
|
settings_timeout = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
settings_manager = get_settings_manager()
|
||||||
|
settings_timeout = settings_manager.get('download_stall_timeout_seconds')
|
||||||
|
except Exception as exc: # pragma: no cover - defensive guard
|
||||||
|
logger.debug("Failed to read stall timeout from settings: %s", exc)
|
||||||
|
|
||||||
|
raw_value = (
|
||||||
|
settings_timeout
|
||||||
|
if settings_timeout not in (None, "")
|
||||||
|
else os.environ.get('COMFYUI_DOWNLOAD_STALL_TIMEOUT')
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
timeout_value = float(raw_value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
timeout_value = default_timeout
|
||||||
|
|
||||||
|
return max(30.0, timeout_value)
|
||||||
|
|
||||||
def _should_refresh_session(self) -> bool:
|
def _should_refresh_session(self) -> bool:
|
||||||
"""Check if session should be refreshed"""
|
"""Check if session should be refreshed"""
|
||||||
if self._session is None:
|
if self._session is None:
|
||||||
@@ -181,7 +273,7 @@ class Downloader:
|
|||||||
use_auth: bool = False,
|
use_auth: bool = False,
|
||||||
custom_headers: Optional[Dict[str, str]] = None,
|
custom_headers: Optional[Dict[str, str]] = None,
|
||||||
allow_resume: bool = True,
|
allow_resume: bool = True,
|
||||||
pause_event: Optional[asyncio.Event] = None,
|
pause_event: Optional[DownloadStreamControl] = None,
|
||||||
) -> Tuple[bool, str]:
|
) -> Tuple[bool, str]:
|
||||||
"""
|
"""
|
||||||
Download a file with resumable downloads and retry mechanism
|
Download a file with resumable downloads and retry mechanism
|
||||||
@@ -193,7 +285,7 @@ class Downloader:
|
|||||||
use_auth: Whether to include authentication headers (e.g., CivitAI API key)
|
use_auth: Whether to include authentication headers (e.g., CivitAI API key)
|
||||||
custom_headers: Additional headers to include in request
|
custom_headers: Additional headers to include in request
|
||||||
allow_resume: Whether to support resumable downloads
|
allow_resume: Whether to support resumable downloads
|
||||||
pause_event: Optional event that, when cleared, will pause streaming until set again
|
pause_event: Optional stream control used to pause/resume and request reconnects
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, str]: (success, save_path or error message)
|
Tuple[bool, str]: (success, save_path or error message)
|
||||||
@@ -314,44 +406,81 @@ class Downloader:
|
|||||||
# Stream download to file with progress updates
|
# Stream download to file with progress updates
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
|
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
|
||||||
|
control = pause_event
|
||||||
|
|
||||||
|
if control is not None:
|
||||||
|
control.update_stall_timeout(self.stall_timeout)
|
||||||
|
|
||||||
with open(part_path, mode) as f:
|
with open(part_path, mode) as f:
|
||||||
async for chunk in response.content.iter_chunked(self.chunk_size):
|
while True:
|
||||||
if pause_event is not None and not pause_event.is_set():
|
active_stall_timeout = control.stall_timeout if control else self.stall_timeout
|
||||||
await pause_event.wait()
|
|
||||||
if chunk:
|
|
||||||
# Run blocking file write in executor
|
|
||||||
await loop.run_in_executor(None, f.write, chunk)
|
|
||||||
current_size += len(chunk)
|
|
||||||
|
|
||||||
# Limit progress update frequency to reduce overhead
|
if control is not None:
|
||||||
now = datetime.now()
|
if control.is_paused():
|
||||||
time_diff = (now - last_progress_report_time).total_seconds()
|
await control.wait()
|
||||||
|
resume_time = datetime.now()
|
||||||
|
last_progress_report_time = resume_time
|
||||||
|
if control.consume_reconnect_request():
|
||||||
|
raise DownloadRestartRequested(
|
||||||
|
"Reconnect requested after resume"
|
||||||
|
)
|
||||||
|
elif control.consume_reconnect_request():
|
||||||
|
raise DownloadRestartRequested("Reconnect requested")
|
||||||
|
|
||||||
if progress_callback and time_diff >= 1.0:
|
try:
|
||||||
progress_samples.append((now, current_size))
|
chunk = await asyncio.wait_for(
|
||||||
cutoff = now - timedelta(seconds=5)
|
response.content.read(self.chunk_size),
|
||||||
while progress_samples and progress_samples[0][0] < cutoff:
|
timeout=active_stall_timeout,
|
||||||
progress_samples.popleft()
|
)
|
||||||
|
except asyncio.TimeoutError as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Download stalled for %.1f seconds without progress from %s",
|
||||||
|
active_stall_timeout,
|
||||||
|
url,
|
||||||
|
)
|
||||||
|
raise DownloadStalledError(
|
||||||
|
f"No data received for {active_stall_timeout:.1f} seconds"
|
||||||
|
) from exc
|
||||||
|
|
||||||
percent = (current_size / total_size) * 100 if total_size else 0.0
|
if not chunk:
|
||||||
bytes_per_second = 0.0
|
break
|
||||||
if len(progress_samples) >= 2:
|
|
||||||
first_time, first_bytes = progress_samples[0]
|
|
||||||
last_time, last_bytes = progress_samples[-1]
|
|
||||||
elapsed = (last_time - first_time).total_seconds()
|
|
||||||
if elapsed > 0:
|
|
||||||
bytes_per_second = (last_bytes - first_bytes) / elapsed
|
|
||||||
|
|
||||||
progress_snapshot = DownloadProgress(
|
# Run blocking file write in executor
|
||||||
percent_complete=percent,
|
await loop.run_in_executor(None, f.write, chunk)
|
||||||
bytes_downloaded=current_size,
|
current_size += len(chunk)
|
||||||
total_bytes=total_size or None,
|
|
||||||
bytes_per_second=bytes_per_second,
|
|
||||||
timestamp=now.timestamp(),
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._dispatch_progress_callback(progress_callback, progress_snapshot)
|
now = datetime.now()
|
||||||
last_progress_report_time = now
|
if control is not None:
|
||||||
|
control.mark_progress(timestamp=now.timestamp())
|
||||||
|
|
||||||
|
# Limit progress update frequency to reduce overhead
|
||||||
|
time_diff = (now - last_progress_report_time).total_seconds()
|
||||||
|
|
||||||
|
if progress_callback and time_diff >= 1.0:
|
||||||
|
progress_samples.append((now, current_size))
|
||||||
|
cutoff = now - timedelta(seconds=5)
|
||||||
|
while progress_samples and progress_samples[0][0] < cutoff:
|
||||||
|
progress_samples.popleft()
|
||||||
|
|
||||||
|
percent = (current_size / total_size) * 100 if total_size else 0.0
|
||||||
|
bytes_per_second = 0.0
|
||||||
|
if len(progress_samples) >= 2:
|
||||||
|
first_time, first_bytes = progress_samples[0]
|
||||||
|
last_time, last_bytes = progress_samples[-1]
|
||||||
|
elapsed = (last_time - first_time).total_seconds()
|
||||||
|
if elapsed > 0:
|
||||||
|
bytes_per_second = (last_bytes - first_bytes) / elapsed
|
||||||
|
|
||||||
|
progress_snapshot = DownloadProgress(
|
||||||
|
percent_complete=percent,
|
||||||
|
bytes_downloaded=current_size,
|
||||||
|
total_bytes=total_size or None,
|
||||||
|
bytes_per_second=bytes_per_second,
|
||||||
|
timestamp=now.timestamp(),
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._dispatch_progress_callback(progress_callback, progress_snapshot)
|
||||||
|
last_progress_report_time = now
|
||||||
|
|
||||||
# Download completed successfully
|
# Download completed successfully
|
||||||
# Verify file size integrity before finalizing
|
# Verify file size integrity before finalizing
|
||||||
@@ -447,8 +576,14 @@ class Downloader:
|
|||||||
|
|
||||||
return True, save_path
|
return True, save_path
|
||||||
|
|
||||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError,
|
except (
|
||||||
aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
|
aiohttp.ClientError,
|
||||||
|
aiohttp.ClientPayloadError,
|
||||||
|
aiohttp.ServerDisconnectedError,
|
||||||
|
asyncio.TimeoutError,
|
||||||
|
DownloadStalledError,
|
||||||
|
DownloadRestartRequested,
|
||||||
|
) as e:
|
||||||
retry_count += 1
|
retry_count += 1
|
||||||
logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}")
|
logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
@@ -8,6 +9,7 @@ from unittest.mock import AsyncMock
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from py.services.download_manager import DownloadManager
|
from py.services.download_manager import DownloadManager
|
||||||
|
from py.services.downloader import DownloadStreamControl
|
||||||
from py.services import download_manager
|
from py.services import download_manager
|
||||||
from py.services.service_registry import ServiceRegistry
|
from py.services.service_registry import ServiceRegistry
|
||||||
from py.services.settings_manager import SettingsManager, get_settings_manager
|
from py.services.settings_manager import SettingsManager, get_settings_manager
|
||||||
@@ -528,9 +530,8 @@ async def test_pause_download_updates_state():
|
|||||||
|
|
||||||
download_id = "dl"
|
download_id = "dl"
|
||||||
manager._download_tasks[download_id] = object()
|
manager._download_tasks[download_id] = object()
|
||||||
pause_event = asyncio.Event()
|
pause_control = DownloadStreamControl()
|
||||||
pause_event.set()
|
manager._pause_events[download_id] = pause_control
|
||||||
manager._pause_events[download_id] = pause_event
|
|
||||||
manager._active_downloads[download_id] = {
|
manager._active_downloads[download_id] = {
|
||||||
"status": "downloading",
|
"status": "downloading",
|
||||||
"bytes_per_second": 42.0,
|
"bytes_per_second": 42.0,
|
||||||
@@ -557,8 +558,10 @@ async def test_resume_download_sets_event_and_status():
|
|||||||
manager = DownloadManager()
|
manager = DownloadManager()
|
||||||
|
|
||||||
download_id = "dl"
|
download_id = "dl"
|
||||||
pause_event = asyncio.Event()
|
pause_control = DownloadStreamControl()
|
||||||
manager._pause_events[download_id] = pause_event
|
pause_control.pause()
|
||||||
|
pause_control.mark_progress()
|
||||||
|
manager._pause_events[download_id] = pause_control
|
||||||
manager._active_downloads[download_id] = {
|
manager._active_downloads[download_id] = {
|
||||||
"status": "paused",
|
"status": "paused",
|
||||||
"bytes_per_second": 0.0,
|
"bytes_per_second": 0.0,
|
||||||
@@ -571,13 +574,32 @@ async def test_resume_download_sets_event_and_status():
|
|||||||
assert manager._active_downloads[download_id]["status"] == "downloading"
|
assert manager._active_downloads[download_id]["status"] == "downloading"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_resume_download_requests_reconnect_for_stalled_stream():
|
||||||
|
manager = DownloadManager()
|
||||||
|
|
||||||
|
download_id = "dl"
|
||||||
|
pause_control = DownloadStreamControl(stall_timeout=40)
|
||||||
|
pause_control.pause()
|
||||||
|
pause_control.last_progress_timestamp = (datetime.now().timestamp() - 120)
|
||||||
|
manager._pause_events[download_id] = pause_control
|
||||||
|
manager._active_downloads[download_id] = {
|
||||||
|
"status": "paused",
|
||||||
|
"bytes_per_second": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await manager.resume_download(download_id)
|
||||||
|
|
||||||
|
assert result == {"success": True, "message": "Download resumed successfully"}
|
||||||
|
assert pause_control.is_set() is True
|
||||||
|
assert pause_control.has_reconnect_request() is True
|
||||||
|
|
||||||
|
|
||||||
async def test_resume_download_rejects_when_not_paused():
|
async def test_resume_download_rejects_when_not_paused():
|
||||||
manager = DownloadManager()
|
manager = DownloadManager()
|
||||||
|
|
||||||
download_id = "dl"
|
download_id = "dl"
|
||||||
pause_event = asyncio.Event()
|
pause_control = DownloadStreamControl()
|
||||||
pause_event.set()
|
manager._pause_events[download_id] = pause_control
|
||||||
manager._pause_events[download_id] = pause_event
|
|
||||||
|
|
||||||
result = await manager.resume_download(download_id)
|
result = await manager.resume_download(download_id)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -8,13 +9,24 @@ from py.services.downloader import Downloader
|
|||||||
|
|
||||||
|
|
||||||
class FakeStream:
|
class FakeStream:
|
||||||
def __init__(self, chunks):
|
def __init__(self, chunks: Sequence[Sequence] | Sequence[bytes]):
|
||||||
self._chunks = list(chunks)
|
self._chunks = list(chunks)
|
||||||
|
|
||||||
async def iter_chunked(self, _chunk_size):
|
async def read(self, _chunk_size: int) -> bytes:
|
||||||
for chunk in self._chunks:
|
if not self._chunks:
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
yield chunk
|
return b""
|
||||||
|
|
||||||
|
item = self._chunks.pop(0)
|
||||||
|
delay = 0.0
|
||||||
|
payload = item
|
||||||
|
|
||||||
|
if isinstance(item, tuple):
|
||||||
|
payload = item[0]
|
||||||
|
delay = item[1]
|
||||||
|
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
class FakeResponse:
|
class FakeResponse:
|
||||||
@@ -53,6 +65,12 @@ def _build_downloader(responses, *, max_retries=0):
|
|||||||
downloader._session = fake_session
|
downloader._session = fake_session
|
||||||
downloader._session_created_at = datetime.now()
|
downloader._session_created_at = datetime.now()
|
||||||
downloader._proxy_url = None
|
downloader._proxy_url = None
|
||||||
|
async def _noop_create_session():
|
||||||
|
downloader._session = fake_session
|
||||||
|
downloader._session_created_at = datetime.now()
|
||||||
|
downloader._proxy_url = None
|
||||||
|
|
||||||
|
downloader._create_session = _noop_create_session # type: ignore[assignment]
|
||||||
return downloader
|
return downloader
|
||||||
|
|
||||||
|
|
||||||
@@ -123,3 +141,34 @@ async def test_download_file_succeeds_when_sizes_match(tmp_path):
|
|||||||
assert success is True
|
assert success is True
|
||||||
assert Path(result_path).read_bytes() == payload
|
assert Path(result_path).read_bytes() == payload
|
||||||
assert not Path(str(target_path) + ".part").exists()
|
assert not Path(str(target_path) + ".part").exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_file_recovers_from_stall(tmp_path):
|
||||||
|
target_path = tmp_path / "model" / "file.bin"
|
||||||
|
target_path.parent.mkdir()
|
||||||
|
|
||||||
|
payload = b"abcdef"
|
||||||
|
|
||||||
|
responses = [
|
||||||
|
lambda: FakeResponse(
|
||||||
|
status=200,
|
||||||
|
headers={"content-length": str(len(payload))},
|
||||||
|
chunks=[(b"abc", 0.0), (b"def", 0.1)],
|
||||||
|
),
|
||||||
|
lambda: FakeResponse(
|
||||||
|
status=206,
|
||||||
|
headers={"content-length": "3", "Content-Range": "bytes 3-5/6"},
|
||||||
|
chunks=[(b"def", 0.0)],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
downloader = _build_downloader(responses, max_retries=1)
|
||||||
|
downloader.stall_timeout = 0.05
|
||||||
|
|
||||||
|
success, result_path = await downloader.download_file("https://example.com/file", str(target_path))
|
||||||
|
|
||||||
|
assert success is True
|
||||||
|
assert Path(result_path).read_bytes() == payload
|
||||||
|
assert downloader._session._get_calls == 2
|
||||||
|
assert not Path(str(target_path) + ".part").exists()
|
||||||
|
|||||||
Reference in New Issue
Block a user