fix(download): recover stalled transfers automatically

This commit is contained in:
pixelpaws
2025-10-23 17:25:38 +08:00
parent 2eae8a7729
commit faa26651dd
4 changed files with 303 additions and 81 deletions

View File

@@ -15,7 +15,7 @@ from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry
from .settings_manager import get_settings_manager
from .metadata_service import get_default_metadata_provider
from .downloader import get_downloader, DownloadProgress
from .downloader import get_downloader, DownloadProgress, DownloadStreamControl
# Download to temporary file first
import tempfile
@@ -44,7 +44,7 @@ class DownloadManager:
self._active_downloads = OrderedDict() # download_id -> download_info
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
self._download_tasks = {} # download_id -> asyncio.Task
self._pause_events: Dict[str, asyncio.Event] = {}
self._pause_events: Dict[str, DownloadStreamControl] = {}
async def _get_lora_scanner(self):
"""Get the lora scanner from registry"""
@@ -89,11 +89,11 @@ class DownloadManager:
'bytes_downloaded': 0,
'total_bytes': None,
'bytes_per_second': 0.0,
'last_progress_timestamp': None,
}
pause_event = asyncio.Event()
pause_event.set()
self._pause_events[task_id] = pause_event
pause_control = DownloadStreamControl()
self._pause_events[task_id] = pause_control
# Create tracking task
download_task = asyncio.create_task(
@@ -140,19 +140,23 @@ class DownloadManager:
info['bytes_downloaded'] = snapshot.bytes_downloaded
info['total_bytes'] = snapshot.total_bytes
info['bytes_per_second'] = snapshot.bytes_per_second
pause_control = self._pause_events.get(task_id)
if isinstance(pause_control, DownloadStreamControl):
pause_control.mark_progress(snapshot.timestamp)
info['last_progress_timestamp'] = pause_control.last_progress_timestamp
if original_callback:
await self._dispatch_progress(original_callback, snapshot, progress_value)
# Acquire semaphore to limit concurrent downloads
try:
async with self._download_semaphore:
pause_event = self._pause_events.get(task_id)
if pause_event is not None and not pause_event.is_set():
pause_control = self._pause_events.get(task_id)
if pause_control is not None and pause_control.is_paused():
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'paused'
self._active_downloads[task_id]['bytes_per_second'] = 0.0
await pause_event.wait()
await pause_control.wait()
# Update status to downloading
if task_id in self._active_downloads:
@@ -478,8 +482,8 @@ class DownloadManager:
part_path = save_path + '.part'
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
pause_event = self._pause_events.get(download_id) if download_id else None
pause_control = self._pause_events.get(download_id) if download_id else None
# Store file paths in active_downloads for potential cleanup
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]['file_path'] = save_path
@@ -590,6 +594,8 @@ class DownloadManager:
# Download model file with progress tracking using downloader
downloader = await get_downloader()
if pause_control is not None:
pause_control.update_stall_timeout(downloader.stall_timeout)
last_error = None
for download_url in download_urls:
use_auth = download_url.startswith("https://civitai.com/api/download/")
@@ -602,8 +608,8 @@ class DownloadManager:
"use_auth": use_auth, # Only use authentication for Civitai downloads
}
if pause_event is not None:
download_kwargs["pause_event"] = pause_event
if pause_control is not None:
download_kwargs["pause_event"] = pause_control
success, result = await downloader.download_file(
download_url,
@@ -756,9 +762,9 @@ class DownloadManager:
task = self._download_tasks[download_id]
task.cancel()
pause_event = self._pause_events.get(download_id)
if pause_event is not None:
pause_event.set()
pause_control = self._pause_events.get(download_id)
if pause_control is not None:
pause_control.resume()
# Update status in active downloads
if download_id in self._active_downloads:
@@ -835,16 +841,14 @@ class DownloadManager:
if download_id not in self._download_tasks:
return {'success': False, 'error': 'Download task not found'}
pause_event = self._pause_events.get(download_id)
if pause_event is None:
pause_event = asyncio.Event()
pause_event.set()
self._pause_events[download_id] = pause_event
pause_control = self._pause_events.get(download_id)
if pause_control is None:
return {'success': False, 'error': 'Download task not found'}
if not pause_event.is_set():
if pause_control.is_paused():
return {'success': False, 'error': 'Download is already paused'}
pause_event.clear()
pause_control.pause()
download_info = self._active_downloads.get(download_id)
if download_info is not None:
@@ -856,16 +860,28 @@ class DownloadManager:
async def resume_download(self, download_id: str) -> Dict:
"""Resume a previously paused download."""
pause_event = self._pause_events.get(download_id)
if pause_event is None:
pause_control = self._pause_events.get(download_id)
if pause_control is None:
return {'success': False, 'error': 'Download task not found'}
if pause_event.is_set():
if pause_control.is_set():
return {'success': False, 'error': 'Download is not paused'}
pause_event.set()
download_info = self._active_downloads.get(download_id)
force_reconnect = False
if pause_control is not None:
elapsed = pause_control.time_since_last_progress()
threshold = max(30.0, pause_control.stall_timeout / 2.0)
if elapsed is not None and elapsed >= threshold:
force_reconnect = True
logger.info(
"Forcing reconnect for download %s after %.1f seconds without progress",
download_id,
elapsed,
)
pause_control.resume(force_reconnect=force_reconnect)
if download_info is not None:
if download_info.get('status') == 'paused':
download_info['status'] = 'downloading'