mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
test(downloads): cover pause and resume flows
This commit is contained in:
@@ -43,6 +43,7 @@ class DownloadManager:
|
||||
self._active_downloads = OrderedDict() # download_id -> download_info
|
||||
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
||||
self._download_tasks = {} # download_id -> asyncio.Task
|
||||
self._pause_events: Dict[str, asyncio.Event] = {}
|
||||
|
||||
async def _get_lora_scanner(self):
|
||||
"""Get the lora scanner from registry"""
|
||||
@@ -88,11 +89,15 @@ class DownloadManager:
|
||||
'total_bytes': None,
|
||||
'bytes_per_second': 0.0,
|
||||
}
|
||||
|
||||
pause_event = asyncio.Event()
|
||||
pause_event.set()
|
||||
self._pause_events[task_id] = pause_event
|
||||
|
||||
# Create tracking task
|
||||
download_task = asyncio.create_task(
|
||||
self._download_with_semaphore(
|
||||
task_id, model_id, model_version_id, save_dir,
|
||||
task_id, model_id, model_version_id, save_dir,
|
||||
relative_path, progress_callback, use_default_paths, source
|
||||
)
|
||||
)
|
||||
@@ -111,9 +116,10 @@ class DownloadManager:
|
||||
# Clean up task reference
|
||||
if task_id in self._download_tasks:
|
||||
del self._download_tasks[task_id]
|
||||
self._pause_events.pop(task_id, None)
|
||||
|
||||
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
|
||||
save_dir: str, relative_path: str,
|
||||
save_dir: str, relative_path: str,
|
||||
progress_callback=None, use_default_paths: bool = False,
|
||||
source: str = None):
|
||||
"""Execute download with semaphore to limit concurrency"""
|
||||
@@ -140,6 +146,13 @@ class DownloadManager:
|
||||
# Acquire semaphore to limit concurrent downloads
|
||||
try:
|
||||
async with self._download_semaphore:
|
||||
pause_event = self._pause_events.get(task_id)
|
||||
if pause_event is not None and not pause_event.is_set():
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'paused'
|
||||
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
||||
await pause_event.wait()
|
||||
|
||||
# Update status to downloading
|
||||
if task_id in self._active_downloads:
|
||||
self._active_downloads[task_id]['status'] = 'downloading'
|
||||
@@ -189,9 +202,17 @@ class DownloadManager:
|
||||
if task_id in self._active_downloads:
|
||||
del self._active_downloads[task_id]
|
||||
|
||||
async def _execute_original_download(self, model_id, model_version_id, save_dir,
|
||||
relative_path, progress_callback, use_default_paths,
|
||||
download_id=None, source=None):
|
||||
async def _execute_original_download(
|
||||
self,
|
||||
model_id,
|
||||
model_version_id,
|
||||
save_dir,
|
||||
relative_path,
|
||||
progress_callback,
|
||||
use_default_paths,
|
||||
download_id=None,
|
||||
source=None,
|
||||
):
|
||||
"""Wrapper for original download_from_civitai implementation"""
|
||||
try:
|
||||
# Check if model version already exists in library
|
||||
@@ -345,7 +366,7 @@ class DownloadManager:
|
||||
relative_path=relative_path,
|
||||
progress_callback=progress_callback,
|
||||
model_type=model_type,
|
||||
download_id=download_id
|
||||
download_id=download_id,
|
||||
)
|
||||
|
||||
# If early_access_msg exists and download failed, replace error message
|
||||
@@ -410,10 +431,17 @@ class DownloadManager:
|
||||
|
||||
return formatted_path
|
||||
|
||||
async def _execute_download(self, download_urls: List[str], save_dir: str,
|
||||
metadata, version_info: Dict,
|
||||
relative_path: str, progress_callback=None,
|
||||
model_type: str = "lora", download_id: str = None) -> Dict:
|
||||
async def _execute_download(
|
||||
self,
|
||||
download_urls: List[str],
|
||||
save_dir: str,
|
||||
metadata,
|
||||
version_info: Dict,
|
||||
relative_path: str,
|
||||
progress_callback=None,
|
||||
model_type: str = "lora",
|
||||
download_id: str = None,
|
||||
) -> Dict:
|
||||
"""Execute the actual download process including preview images and model files"""
|
||||
try:
|
||||
# Extract original filename details
|
||||
@@ -444,6 +472,8 @@ class DownloadManager:
|
||||
|
||||
part_path = save_path + '.part'
|
||||
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
||||
|
||||
pause_event = self._pause_events.get(download_id) if download_id else None
|
||||
|
||||
# Store file paths in active_downloads for potential cleanup
|
||||
if download_id and download_id in self._active_downloads:
|
||||
@@ -558,15 +588,22 @@ class DownloadManager:
|
||||
last_error = None
|
||||
for download_url in download_urls:
|
||||
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
||||
success, result = await downloader.download_file(
|
||||
download_url,
|
||||
save_path, # Use full path instead of separate dir and filename
|
||||
progress_callback=lambda progress, snapshot=None: self._handle_download_progress(
|
||||
download_kwargs = {
|
||||
"progress_callback": lambda progress, snapshot=None: self._handle_download_progress(
|
||||
progress,
|
||||
progress_callback,
|
||||
snapshot,
|
||||
),
|
||||
use_auth=use_auth # Only use authentication for Civitai downloads
|
||||
"use_auth": use_auth, # Only use authentication for Civitai downloads
|
||||
}
|
||||
|
||||
if pause_event is not None:
|
||||
download_kwargs["pause_event"] = pause_event
|
||||
|
||||
success, result = await downloader.download_file(
|
||||
download_url,
|
||||
save_path, # Use full path instead of separate dir and filename
|
||||
**download_kwargs,
|
||||
)
|
||||
|
||||
if success:
|
||||
@@ -675,7 +712,7 @@ class DownloadManager:
|
||||
|
||||
async def cancel_download(self, download_id: str) -> Dict:
|
||||
"""Cancel an active download by download_id
|
||||
|
||||
|
||||
Args:
|
||||
download_id: The unique identifier of the download task
|
||||
|
||||
@@ -689,7 +726,11 @@ class DownloadManager:
|
||||
# Get the task and cancel it
|
||||
task = self._download_tasks[download_id]
|
||||
task.cancel()
|
||||
|
||||
|
||||
pause_event = self._pause_events.get(download_id)
|
||||
if pause_event is not None:
|
||||
pause_event.set()
|
||||
|
||||
# Update status in active downloads
|
||||
if download_id in self._active_downloads:
|
||||
self._active_downloads[download_id]['status'] = 'cancelling'
|
||||
@@ -756,6 +797,52 @@ class DownloadManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
||||
return {'success': False, 'error': str(e)}
|
||||
finally:
|
||||
self._pause_events.pop(download_id, None)
|
||||
|
||||
async def pause_download(self, download_id: str) -> Dict:
|
||||
"""Pause an active download without losing progress."""
|
||||
|
||||
if download_id not in self._download_tasks:
|
||||
return {'success': False, 'error': 'Download task not found'}
|
||||
|
||||
pause_event = self._pause_events.get(download_id)
|
||||
if pause_event is None:
|
||||
pause_event = asyncio.Event()
|
||||
pause_event.set()
|
||||
self._pause_events[download_id] = pause_event
|
||||
|
||||
if not pause_event.is_set():
|
||||
return {'success': False, 'error': 'Download is already paused'}
|
||||
|
||||
pause_event.clear()
|
||||
|
||||
download_info = self._active_downloads.get(download_id)
|
||||
if download_info is not None:
|
||||
download_info['status'] = 'paused'
|
||||
download_info['bytes_per_second'] = 0.0
|
||||
|
||||
return {'success': True, 'message': 'Download paused successfully'}
|
||||
|
||||
async def resume_download(self, download_id: str) -> Dict:
|
||||
"""Resume a previously paused download."""
|
||||
|
||||
pause_event = self._pause_events.get(download_id)
|
||||
if pause_event is None:
|
||||
return {'success': False, 'error': 'Download task not found'}
|
||||
|
||||
if pause_event.is_set():
|
||||
return {'success': False, 'error': 'Download is not paused'}
|
||||
|
||||
pause_event.set()
|
||||
|
||||
download_info = self._active_downloads.get(download_id)
|
||||
if download_info is not None:
|
||||
if download_info.get('status') == 'paused':
|
||||
download_info['status'] = 'downloading'
|
||||
download_info.setdefault('bytes_per_second', 0.0)
|
||||
|
||||
return {'success': True, 'message': 'Download resumed successfully'}
|
||||
|
||||
@staticmethod
|
||||
def _coerce_progress_value(progress) -> float:
|
||||
|
||||
Reference in New Issue
Block a user