feat: Add download cancellation and tracking features in DownloadManager and API routes

This commit is contained in:
Will Miao
2025-07-21 15:38:20 +08:00
parent fa444dfb8a
commit de06c6b2f6
4 changed files with 305 additions and 20 deletions

View File

@@ -59,6 +59,7 @@ class ApiRoutes:
app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash) app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash)
app.router.add_post('/api/download-model', routes.download_model) app.router.add_post('/api/download-model', routes.download_model)
app.router.add_get('/api/download-model-get', routes.download_model_get) # Add new GET endpoint app.router.add_get('/api/download-model-get', routes.download_model_get) # Add new GET endpoint
app.router.add_get('/api/cancel-download-get', routes.cancel_download_get)
app.router.add_get('/api/download-progress/{download_id}', routes.get_download_progress) # Add new endpoint for download progress app.router.add_get('/api/download-progress/{download_id}', routes.get_download_progress) # Add new endpoint for download progress
app.router.add_post('/api/move_model', routes.move_model) app.router.add_post('/api/move_model', routes.move_model)
app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route
@@ -501,6 +502,29 @@ class ApiRoutes:
logger.error(f"Error downloading model via GET: {error_message}", exc_info=True) logger.error(f"Error downloading model via GET: {error_message}", exc_info=True)
return web.Response(status=500, text=error_message) return web.Response(status=500, text=error_message)
async def cancel_download_get(self, request: web.Request) -> web.Response:
"""Handle GET request for cancelling a download by download_id"""
try:
download_id = request.query.get('download_id')
if not download_id:
return web.json_response({
'success': False,
'error': 'Download ID is required'
}, status=400)
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
# Create a mock request with match_info for compatibility
mock_request = type('MockRequest', (), {
'match_info': {'download_id': download_id}
})()
return await ModelRouteUtils.handle_cancel_download(mock_request, self.download_manager)
except Exception as e:
logger.error(f"Error cancelling download via GET: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_download_progress(self, request: web.Request) -> web.Response: async def get_download_progress(self, request: web.Request) -> web.Response:
"""Handle request for download progress by download_id""" """Handle request for download progress by download_id"""
try: try:

View File

@@ -45,14 +45,14 @@ class CivitaiClient:
# Optimize TCP connection parameters # Optimize TCP connection parameters
connector = aiohttp.TCPConnector( connector = aiohttp.TCPConnector(
ssl=True, ssl=True,
limit=3, # Further reduced from 5 to 3 limit=8, # Increase from 3 to 8 for better parallelism
ttl_dns_cache=0, # Disabled DNS caching completely ttl_dns_cache=300, # Enable DNS caching with reasonable timeout
force_close=False, # Keep connections for reuse force_close=False, # Keep connections for reuse
enable_cleanup_closed=True enable_cleanup_closed=True
) )
trust_env = True # Allow using system environment proxy settings trust_env = True # Allow using system environment proxy settings
# Configure timeout parameters # Configure timeout parameters - increase read timeout for large files
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60) timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=120)
self._session = aiohttp.ClientSession( self._session = aiohttp.ClientSession(
connector=connector, connector=connector,
trust_env=trust_env, trust_env=trust_env,
@@ -165,7 +165,7 @@ class CivitaiClient:
now = datetime.now() now = datetime.now()
time_diff = (now - last_progress_report_time).total_seconds() time_diff = (now - last_progress_report_time).total_seconds()
if progress_callback and total_size and time_diff >= 0.5: if progress_callback and total_size and time_diff >= 1.0:
progress = (current_size / total_size) * 100 progress = (current_size / total_size) * 100
await progress_callback(progress) await progress_callback(progress)
last_progress_report_time = now last_progress_report_time = now

View File

@@ -1,6 +1,8 @@
import logging import logging
import os import os
import asyncio import asyncio
from collections import OrderedDict
import uuid
from typing import Dict from typing import Dict
from ..utils.models import LoraMetadata, CheckpointMetadata from ..utils.models import LoraMetadata, CheckpointMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
@@ -33,6 +35,10 @@ class DownloadManager:
self._initialized = True self._initialized = True
self._civitai_client = None # Will be lazily initialized self._civitai_client = None # Will be lazily initialized
# Add download management
self._active_downloads = OrderedDict() # download_id -> download_info
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
self._download_tasks = {} # download_id -> asyncio.Task
async def _get_civitai_client(self): async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry""" """Lazily initialize CivitaiClient from registry"""
@@ -48,26 +54,131 @@ class DownloadManager:
"""Get the checkpoint scanner from registry""" """Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner() return await ServiceRegistry.get_checkpoint_scanner()
async def download_from_civitai(self, model_id: int, async def download_from_civitai(self, model_id: int, model_version_id: int,
model_version_id: int, save_dir: str = None, save_dir: str = None, relative_path: str = '',
relative_path: str = '', progress_callback=None, use_default_paths: bool = False) -> Dict: progress_callback=None, use_default_paths: bool = False,
"""Download model from Civitai download_id: str = None) -> Dict:
"""Download model from Civitai with task tracking and concurrency control
Args: Args:
model_id: Civitai model ID model_id: Civitai model ID
model_version_id: Civitai model version ID (optional, if not provided, will download the latest version) model_version_id: Civitai model version ID
save_dir: Directory to save the model to save_dir: Directory to save the model
relative_path: Relative path within save_dir relative_path: Relative path within save_dir
progress_callback: Callback function for progress updates progress_callback: Callback function for progress updates
use_default_paths: Flag to indicate whether to use default paths use_default_paths: Flag to use default paths
download_id: Unique identifier for this download task
Returns: Returns:
Dict with download result Dict with download result
""" """
# Use provided download_id or generate new one
task_id = download_id or str(uuid.uuid4())
# Register download task in tracking dict
self._active_downloads[task_id] = {
'model_id': model_id,
'model_version_id': model_version_id,
'progress': 0,
'status': 'queued'
}
# Create tracking task
download_task = asyncio.create_task(
self._download_with_semaphore(
task_id, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths
)
)
# Store task for tracking and cancellation
self._download_tasks[task_id] = download_task
try:
# Wait for download to complete
result = await download_task
result['download_id'] = task_id # Include download_id in result
return result
except asyncio.CancelledError:
return {'success': False, 'error': 'Download was cancelled', 'download_id': task_id}
finally:
# Clean up task reference
if task_id in self._download_tasks:
del self._download_tasks[task_id]
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
save_dir: str, relative_path: str,
progress_callback=None, use_default_paths: bool = False):
"""Execute download with semaphore to limit concurrency"""
# Update status to waiting
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'waiting'
# Wrap progress callback to track progress in active_downloads
original_callback = progress_callback
async def tracking_callback(progress):
if task_id in self._active_downloads:
self._active_downloads[task_id]['progress'] = progress
if original_callback:
await original_callback(progress)
# Acquire semaphore to limit concurrent downloads
try:
async with self._download_semaphore:
# Update status to downloading
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'downloading'
# Use original download implementation
try:
# Check for cancellation before starting
if asyncio.current_task().cancelled():
raise asyncio.CancelledError()
result = await self._execute_original_download(
model_id, model_version_id, save_dir,
relative_path, tracking_callback, use_default_paths,
task_id
)
# Update status based on result
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed'
if not result['success']:
self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error')
return result
except asyncio.CancelledError:
# Handle cancellation
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'cancelled'
logger.info(f"Download cancelled for task {task_id}")
raise
except Exception as e:
# Handle other errors
logger.error(f"Download error for task {task_id}: {str(e)}", exc_info=True)
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'failed'
self._active_downloads[task_id]['error'] = str(e)
return {'success': False, 'error': str(e)}
finally:
# Schedule cleanup of download record after delay
asyncio.create_task(self._cleanup_download_record(task_id))
async def _cleanup_download_record(self, task_id: str):
"""Keep completed downloads in history for a short time"""
await asyncio.sleep(600) # Keep for 10 minutes
if task_id in self._active_downloads:
del self._active_downloads[task_id]
async def _execute_original_download(self, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths,
download_id=None):
"""Wrapper for original download_from_civitai implementation"""
try: try:
# Check if model version already exists in library # Check if model version already exists in library
if model_version_id is not None: if model_version_id is not None:
# Case 1: model_version_id is provided, check both scanners # Check both scanners
lora_scanner = await self._get_lora_scanner() lora_scanner = await self._get_lora_scanner()
checkpoint_scanner = await self._get_checkpoint_scanner() checkpoint_scanner = await self._get_checkpoint_scanner()
@@ -183,7 +294,8 @@ class DownloadManager:
version_info=version_info, version_info=version_info,
relative_path=relative_path, relative_path=relative_path,
progress_callback=progress_callback, progress_callback=progress_callback,
model_type=model_type model_type=model_type,
download_id=download_id
) )
return result return result
@@ -243,13 +355,17 @@ class DownloadManager:
async def _execute_download(self, download_url: str, save_dir: str, async def _execute_download(self, download_url: str, save_dir: str,
metadata, version_info: Dict, metadata, version_info: Dict,
relative_path: str, progress_callback=None, relative_path: str, progress_callback=None,
model_type: str = "lora") -> Dict: model_type: str = "lora", download_id: str = None) -> Dict:
"""Execute the actual download process including preview images and model files""" """Execute the actual download process including preview images and model files"""
try: try:
civitai_client = await self._get_civitai_client() civitai_client = await self._get_civitai_client()
save_path = metadata.file_path save_path = metadata.file_path
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json' metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
# Store file path in active_downloads for potential cleanup
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]['file_path'] = save_path
# Download preview image if available # Download preview image if available
images = version_info.get('images', []) images = version_info.get('images', [])
if images: if images:
@@ -368,3 +484,85 @@ class DownloadManager:
# Scale file progress to 3-100 range (after preview download) # Scale file progress to 3-100 range (after preview download)
overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download
await progress_callback(round(overall_progress)) await progress_callback(round(overall_progress))
async def cancel_download(self, download_id: str) -> Dict:
"""Cancel an active download by download_id
Args:
download_id: The unique identifier of the download task
Returns:
Dict: Status of the cancellation operation
"""
if download_id not in self._download_tasks:
return {'success': False, 'error': 'Download task not found'}
try:
# Get the task and cancel it
task = self._download_tasks[download_id]
task.cancel()
# Update status in active downloads
if download_id in self._active_downloads:
self._active_downloads[download_id]['status'] = 'cancelling'
# Wait briefly for the task to acknowledge cancellation
try:
await asyncio.wait_for(asyncio.shield(task), timeout=2.0)
except (asyncio.CancelledError, asyncio.TimeoutError):
pass
# Clean up partial downloads
download_info = self._active_downloads.get(download_id)
if download_info and 'file_path' in download_info:
# Delete the partial file
file_path = download_info['file_path']
if os.path.exists(file_path):
try:
os.unlink(file_path)
logger.debug(f"Deleted partial download: {file_path}")
except Exception as e:
logger.error(f"Error deleting partial file: {e}")
# Delete metadata file if exists
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
try:
os.unlink(metadata_path)
except Exception as e:
logger.error(f"Error deleting metadata file: {e}")
# Delete preview file if exists (.webp or .mp4)
for preview_ext in ['.webp', '.mp4']:
preview_path = os.path.splitext(file_path)[0] + preview_ext
if os.path.exists(preview_path):
try:
os.unlink(preview_path)
logger.debug(f"Deleted preview file: {preview_path}")
except Exception as e:
logger.error(f"Error deleting preview file: {e}")
return {'success': True, 'message': 'Download cancelled successfully'}
except Exception as e:
logger.error(f"Error cancelling download: {e}", exc_info=True)
return {'success': False, 'error': str(e)}
async def get_active_downloads(self) -> Dict:
"""Get information about all active downloads
Returns:
Dict: List of active downloads and their status
"""
return {
'downloads': [
{
'download_id': task_id,
'model_id': info.get('model_id'),
'model_version_id': info.get('model_version_id'),
'progress': info.get('progress', 0),
'status': info.get('status', 'unknown'),
'error': info.get('error', None)
}
for task_id, info in self._active_downloads.items()
]
}

View File

@@ -611,13 +611,15 @@ class ModelRouteUtils:
use_default_paths = data.get('use_default_paths', False) use_default_paths = data.get('use_default_paths', False)
# Pass the download_id to download_from_civitai
result = await download_manager.download_from_civitai( result = await download_manager.download_from_civitai(
model_id=model_id, model_id=model_id,
model_version_id=model_version_id, model_version_id=model_version_id,
save_dir=data.get('model_root'), save_dir=data.get('model_root'),
relative_path=data.get('relative_path', ''), relative_path=data.get('relative_path', ''),
use_default_paths=use_default_paths, use_default_paths=use_default_paths,
progress_callback=progress_callback progress_callback=progress_callback,
download_id=download_id # Pass download_id explicitly
) )
# Include download_id in the response # Include download_id in the response
@@ -631,12 +633,14 @@ class ModelRouteUtils:
logger.warning(f"Early access download failed: {error_message}") logger.warning(f"Early access download failed: {error_message}")
return web.json_response({ return web.json_response({
'success': False, 'success': False,
'error': f"Early Access Restriction: {error_message}" 'error': f"Early Access Restriction: {error_message}",
'download_id': download_id
}, status=401) }, status=401)
return web.json_response({ return web.json_response({
'success': False, 'success': False,
'error': error_message 'error': error_message,
'download_id': download_id
}, status=500) }, status=500)
return web.json_response(result) return web.json_response(result)
@@ -658,6 +662,65 @@ class ModelRouteUtils:
'error': error_message 'error': error_message
}, status=500) }, status=500)
@staticmethod
async def handle_cancel_download(request: web.Request, download_manager: DownloadManager) -> web.Response:
"""Handle cancellation of a download task
Args:
request: The aiohttp request
download_manager: The download manager instance
Returns:
web.Response: The HTTP response
"""
try:
download_id = request.match_info.get('download_id')
if not download_id:
return web.json_response({
'success': False,
'error': 'Download ID is required'
}, status=400)
result = await download_manager.cancel_download(download_id)
# Notify clients about cancellation via WebSocket
await ws_manager.broadcast_download_progress(download_id, {
'status': 'cancelled',
'progress': 0,
'download_id': download_id,
'message': 'Download cancelled by user'
})
return web.json_response(result)
except Exception as e:
logger.error(f"Error cancelling download: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def handle_list_downloads(request: web.Request, download_manager: DownloadManager) -> web.Response:
"""Get list of active downloads
Args:
request: The aiohttp request
download_manager: The download manager instance
Returns:
web.Response: The HTTP response with list of downloads
"""
try:
result = await download_manager.get_active_downloads()
return web.json_response(result)
except Exception as e:
logger.error(f"Error listing downloads: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod @staticmethod
async def handle_bulk_delete_models(request: web.Request, scanner) -> web.Response: async def handle_bulk_delete_models(request: web.Request, scanner) -> web.Response:
"""Handle bulk deletion of models """Handle bulk deletion of models