diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index 3a04815a..f5457935 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -59,6 +59,7 @@ class ApiRoutes: app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash) app.router.add_post('/api/download-model', routes.download_model) app.router.add_get('/api/download-model-get', routes.download_model_get) # Add new GET endpoint + app.router.add_get('/api/cancel-download-get', routes.cancel_download_get) app.router.add_get('/api/download-progress/{download_id}', routes.get_download_progress) # Add new endpoint for download progress app.router.add_post('/api/move_model', routes.move_model) app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route @@ -500,6 +501,29 @@ class ApiRoutes: error_message = str(e) logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) return web.Response(status=500, text=error_message) + + async def cancel_download_get(self, request: web.Request) -> web.Response: + """Handle GET request for cancelling a download by download_id""" + try: + download_id = request.query.get('download_id') + if not download_id: + return web.json_response({ + 'success': False, + 'error': 'Download ID is required' + }, status=400) + if self.download_manager is None: + self.download_manager = await ServiceRegistry.get_download_manager() + # Create a mock request with match_info for compatibility + mock_request = type('MockRequest', (), { + 'match_info': {'download_id': download_id} + })() + return await ModelRouteUtils.handle_cancel_download(mock_request, self.download_manager) + except Exception as e: + logger.error(f"Error cancelling download via GET: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) async def get_download_progress(self, request: web.Request) -> web.Response: """Handle request for download progress by download_id""" diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 9abc5e2e..59bb7f39 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -45,14 +45,14 @@ class CivitaiClient: # Optimize TCP connection parameters connector = aiohttp.TCPConnector( ssl=True, - limit=3, # Further reduced from 5 to 3 - ttl_dns_cache=0, # Disabled DNS caching completely + limit=8, # Increase from 3 to 8 for better parallelism + ttl_dns_cache=300, # Enable DNS caching with reasonable timeout force_close=False, # Keep connections for reuse enable_cleanup_closed=True ) trust_env = True # Allow using system environment proxy settings - # Configure timeout parameters - timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60) + # Configure timeout parameters - increase read timeout for large files + timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=120) self._session = aiohttp.ClientSession( connector=connector, trust_env=trust_env, @@ -165,7 +165,7 @@ class CivitaiClient: now = datetime.now() time_diff = (now - last_progress_report_time).total_seconds() - if progress_callback and total_size and time_diff >= 0.5: + if progress_callback and total_size and time_diff >= 1.0: progress = (current_size / total_size) * 100 await progress_callback(progress) last_progress_report_time = now diff --git a/py/services/download_manager.py b/py/services/download_manager.py index d0ca7094..e98e5498 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -1,6 +1,8 @@ import logging import os import asyncio +from collections import OrderedDict +import uuid from typing import Dict from ..utils.models import LoraMetadata, CheckpointMetadata from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS @@ -33,6 +35,10 @@ class DownloadManager: self._initialized = True self._civitai_client = None # Will be lazily initialized + # Add download management + self._active_downloads = OrderedDict() # download_id -> download_info + self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads + self._download_tasks = {} # download_id -> asyncio.Task async def _get_civitai_client(self): """Lazily initialize CivitaiClient from registry""" @@ -47,27 +53,132 @@ class DownloadManager: async def _get_checkpoint_scanner(self): """Get the checkpoint scanner from registry""" return await ServiceRegistry.get_checkpoint_scanner() - - async def download_from_civitai(self, model_id: int, - model_version_id: int, save_dir: str = None, - relative_path: str = '', progress_callback=None, use_default_paths: bool = False) -> Dict: - """Download model from Civitai + + async def download_from_civitai(self, model_id: int, model_version_id: int, + save_dir: str = None, relative_path: str = '', + progress_callback=None, use_default_paths: bool = False, + download_id: str = None) -> Dict: + """Download model from Civitai with task tracking and concurrency control Args: model_id: Civitai model ID - model_version_id: Civitai model version ID (optional, if not provided, will download the latest version) - save_dir: Directory to save the model to + model_version_id: Civitai model version ID + save_dir: Directory to save the model relative_path: Relative path within save_dir progress_callback: Callback function for progress updates - use_default_paths: Flag to indicate whether to use default paths + use_default_paths: Flag to use default paths + download_id: Unique identifier for this download task Returns: Dict with download result """ + # Use provided download_id or generate new one + task_id = download_id or str(uuid.uuid4()) + + # Register download task in tracking dict + self._active_downloads[task_id] = { + 'model_id': model_id, + 'model_version_id': model_version_id, + 'progress': 0, + 'status': 'queued' + } + + # Create tracking task + download_task = asyncio.create_task( + self._download_with_semaphore( + task_id, model_id, model_version_id, save_dir, + relative_path, progress_callback, use_default_paths + ) + ) + + # Store task for tracking and cancellation + self._download_tasks[task_id] = download_task + + try: + # Wait for download to complete + result = await download_task + result['download_id'] = task_id # Include download_id in result + return result + except asyncio.CancelledError: + return {'success': False, 'error': 'Download was cancelled', 'download_id': task_id} + finally: + # Clean up task reference + if task_id in self._download_tasks: + del self._download_tasks[task_id] + + async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int, + save_dir: str, relative_path: str, + progress_callback=None, use_default_paths: bool = False): + """Execute download with semaphore to limit concurrency""" + # Update status to waiting + if task_id in self._active_downloads: + self._active_downloads[task_id]['status'] = 'waiting' + + # Wrap progress callback to track progress in active_downloads + original_callback = progress_callback + async def tracking_callback(progress): + if task_id in self._active_downloads: + self._active_downloads[task_id]['progress'] = progress + if original_callback: + await original_callback(progress) + + # Acquire semaphore to limit concurrent downloads + try: + async with self._download_semaphore: + # Update status to downloading + if task_id in self._active_downloads: + self._active_downloads[task_id]['status'] = 'downloading' + + # Use original download implementation + try: + # Check for cancellation before starting + if asyncio.current_task().cancelled(): + raise asyncio.CancelledError() + + result = await self._execute_original_download( + model_id, model_version_id, save_dir, + relative_path, tracking_callback, use_default_paths, + task_id + ) + + # Update status based on result + if task_id in self._active_downloads: + self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed' + if not result['success']: + self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error') + + return result + except asyncio.CancelledError: + # Handle cancellation + if task_id in self._active_downloads: + self._active_downloads[task_id]['status'] = 'cancelled' + logger.info(f"Download cancelled for task {task_id}") + raise + except Exception as e: + # Handle other errors + logger.error(f"Download error for task {task_id}: {str(e)}", exc_info=True) + if task_id in self._active_downloads: + self._active_downloads[task_id]['status'] = 'failed' + self._active_downloads[task_id]['error'] = str(e) + return {'success': False, 'error': str(e)} + finally: + # Schedule cleanup of download record after delay + asyncio.create_task(self._cleanup_download_record(task_id)) + + async def _cleanup_download_record(self, task_id: str): + """Keep completed downloads in history for a short time""" + await asyncio.sleep(600) # Keep for 10 minutes + if task_id in self._active_downloads: + del self._active_downloads[task_id] + + async def _execute_original_download(self, model_id, model_version_id, save_dir, + relative_path, progress_callback, use_default_paths, + download_id=None): + """Wrapper for original download_from_civitai implementation""" try: # Check if model version already exists in library if model_version_id is not None: - # Case 1: model_version_id is provided, check both scanners + # Check both scanners lora_scanner = await self._get_lora_scanner() checkpoint_scanner = await self._get_checkpoint_scanner() @@ -183,7 +294,8 @@ class DownloadManager: version_info=version_info, relative_path=relative_path, progress_callback=progress_callback, - model_type=model_type + model_type=model_type, + download_id=download_id ) return result @@ -243,12 +355,16 @@ class DownloadManager: async def _execute_download(self, download_url: str, save_dir: str, metadata, version_info: Dict, relative_path: str, progress_callback=None, - model_type: str = "lora") -> Dict: + model_type: str = "lora", download_id: str = None) -> Dict: """Execute the actual download process including preview images and model files""" try: civitai_client = await self._get_civitai_client() save_path = metadata.file_path metadata_path = os.path.splitext(save_path)[0] + '.metadata.json' + + # Store file path in active_downloads for potential cleanup + if download_id and download_id in self._active_downloads: + self._active_downloads[download_id]['file_path'] = save_path # Download preview image if available images = version_info.get('images', []) @@ -367,4 +483,86 @@ class DownloadManager: if progress_callback: # Scale file progress to 3-100 range (after preview download) overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download - await progress_callback(round(overall_progress)) \ No newline at end of file + await progress_callback(round(overall_progress)) + + async def cancel_download(self, download_id: str) -> Dict: + """Cancel an active download by download_id + + Args: + download_id: The unique identifier of the download task + + Returns: + Dict: Status of the cancellation operation + """ + if download_id not in self._download_tasks: + return {'success': False, 'error': 'Download task not found'} + + try: + # Get the task and cancel it + task = self._download_tasks[download_id] + task.cancel() + + # Update status in active downloads + if download_id in self._active_downloads: + self._active_downloads[download_id]['status'] = 'cancelling' + + # Wait briefly for the task to acknowledge cancellation + try: + await asyncio.wait_for(asyncio.shield(task), timeout=2.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + + # Clean up partial downloads + download_info = self._active_downloads.get(download_id) + if download_info and 'file_path' in download_info: + # Delete the partial file + file_path = download_info['file_path'] + if os.path.exists(file_path): + try: + os.unlink(file_path) + logger.debug(f"Deleted partial download: {file_path}") + except Exception as e: + logger.error(f"Error deleting partial file: {e}") + + # Delete metadata file if exists + metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + if os.path.exists(metadata_path): + try: + os.unlink(metadata_path) + except Exception as e: + logger.error(f"Error deleting metadata file: {e}") + + # Delete preview file if exists (.webp or .mp4) + for preview_ext in ['.webp', '.mp4']: + preview_path = os.path.splitext(file_path)[0] + preview_ext + if os.path.exists(preview_path): + try: + os.unlink(preview_path) + logger.debug(f"Deleted preview file: {preview_path}") + except Exception as e: + logger.error(f"Error deleting preview file: {e}") + + return {'success': True, 'message': 'Download cancelled successfully'} + except Exception as e: + logger.error(f"Error cancelling download: {e}", exc_info=True) + return {'success': False, 'error': str(e)} + + async def get_active_downloads(self) -> Dict: + """Get information about all active downloads + + Returns: + Dict: List of active downloads and their status + """ + return { + 'downloads': [ + { + 'download_id': task_id, + 'model_id': info.get('model_id'), + 'model_version_id': info.get('model_version_id'), + 'progress': info.get('progress', 0), + 'status': info.get('status', 'unknown'), + 'error': info.get('error', None) + } + for task_id, info in self._active_downloads.items() + ] + } \ No newline at end of file diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py index d5f8d9eb..e151d216 100644 --- a/py/utils/routes_common.py +++ b/py/utils/routes_common.py @@ -611,13 +611,15 @@ class ModelRouteUtils: use_default_paths = data.get('use_default_paths', False) + # Pass the download_id to download_from_civitai result = await download_manager.download_from_civitai( model_id=model_id, model_version_id=model_version_id, save_dir=data.get('model_root'), relative_path=data.get('relative_path', ''), use_default_paths=use_default_paths, - progress_callback=progress_callback + progress_callback=progress_callback, + download_id=download_id # Pass download_id explicitly ) # Include download_id in the response @@ -631,12 +633,14 @@ class ModelRouteUtils: logger.warning(f"Early access download failed: {error_message}") return web.json_response({ 'success': False, - 'error': f"Early Access Restriction: {error_message}" + 'error': f"Early Access Restriction: {error_message}", + 'download_id': download_id }, status=401) return web.json_response({ 'success': False, - 'error': error_message + 'error': error_message, + 'download_id': download_id }, status=500) return web.json_response(result) @@ -658,6 +662,65 @@ class ModelRouteUtils: 'error': error_message }, status=500) + @staticmethod + async def handle_cancel_download(request: web.Request, download_manager: DownloadManager) -> web.Response: + """Handle cancellation of a download task + + Args: + request: The aiohttp request + download_manager: The download manager instance + + Returns: + web.Response: The HTTP response + """ + try: + download_id = request.match_info.get('download_id') + if not download_id: + return web.json_response({ + 'success': False, + 'error': 'Download ID is required' + }, status=400) + + result = await download_manager.cancel_download(download_id) + + # Notify clients about cancellation via WebSocket + await ws_manager.broadcast_download_progress(download_id, { + 'status': 'cancelled', + 'progress': 0, + 'download_id': download_id, + 'message': 'Download cancelled by user' + }) + + return web.json_response(result) + + except Exception as e: + logger.error(f"Error cancelling download: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + + @staticmethod + async def handle_list_downloads(request: web.Request, download_manager: DownloadManager) -> web.Response: + """Get list of active downloads + + Args: + request: The aiohttp request + download_manager: The download manager instance + + Returns: + web.Response: The HTTP response with list of downloads + """ + try: + result = await download_manager.get_active_downloads() + return web.json_response(result) + except Exception as e: + logger.error(f"Error listing downloads: {e}", exc_info=True) + return web.json_response({ + 'success': False, + 'error': str(e) + }, status=500) + @staticmethod async def handle_bulk_delete_models(request: web.Request, scanner) -> web.Response: """Handle bulk deletion of models