feat: Add download cancellation and tracking features in DownloadManager and API routes

This commit is contained in:
Will Miao
2025-07-21 15:38:20 +08:00
parent fa444dfb8a
commit de06c6b2f6
4 changed files with 305 additions and 20 deletions

View File

@@ -59,6 +59,7 @@ class ApiRoutes:
app.router.add_get('/api/civitai/model/hash/{hash}', routes.get_civitai_model_by_hash)
app.router.add_post('/api/download-model', routes.download_model)
app.router.add_get('/api/download-model-get', routes.download_model_get) # Add new GET endpoint
app.router.add_get('/api/cancel-download-get', routes.cancel_download_get)
app.router.add_get('/api/download-progress/{download_id}', routes.get_download_progress) # Add new endpoint for download progress
app.router.add_post('/api/move_model', routes.move_model)
app.router.add_get('/api/lora-model-description', routes.get_lora_model_description) # Add new route
@@ -500,6 +501,29 @@ class ApiRoutes:
error_message = str(e)
logger.error(f"Error downloading model via GET: {error_message}", exc_info=True)
return web.Response(status=500, text=error_message)
async def cancel_download_get(self, request: web.Request) -> web.Response:
"""Handle GET request for cancelling a download by download_id"""
try:
download_id = request.query.get('download_id')
if not download_id:
return web.json_response({
'success': False,
'error': 'Download ID is required'
}, status=400)
if self.download_manager is None:
self.download_manager = await ServiceRegistry.get_download_manager()
# Create a mock request with match_info for compatibility
mock_request = type('MockRequest', (), {
'match_info': {'download_id': download_id}
})()
return await ModelRouteUtils.handle_cancel_download(mock_request, self.download_manager)
except Exception as e:
logger.error(f"Error cancelling download via GET: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
async def get_download_progress(self, request: web.Request) -> web.Response:
"""Handle request for download progress by download_id"""

View File

@@ -45,14 +45,14 @@ class CivitaiClient:
# Optimize TCP connection parameters
connector = aiohttp.TCPConnector(
ssl=True,
limit=3, # Further reduced from 5 to 3
ttl_dns_cache=0, # Disabled DNS caching completely
limit=8, # Increase from 3 to 8 for better parallelism
ttl_dns_cache=300, # Enable DNS caching with reasonable timeout
force_close=False, # Keep connections for reuse
enable_cleanup_closed=True
)
trust_env = True # Allow using system environment proxy settings
# Configure timeout parameters
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
# Configure timeout parameters - increase read timeout for large files
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=120)
self._session = aiohttp.ClientSession(
connector=connector,
trust_env=trust_env,
@@ -165,7 +165,7 @@ class CivitaiClient:
now = datetime.now()
time_diff = (now - last_progress_report_time).total_seconds()
if progress_callback and total_size and time_diff >= 0.5:
if progress_callback and total_size and time_diff >= 1.0:
progress = (current_size / total_size) * 100
await progress_callback(progress)
last_progress_report_time = now

View File

@@ -1,6 +1,8 @@
import logging
import os
import asyncio
from collections import OrderedDict
import uuid
from typing import Dict
from ..utils.models import LoraMetadata, CheckpointMetadata
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
@@ -33,6 +35,10 @@ class DownloadManager:
self._initialized = True
self._civitai_client = None # Will be lazily initialized
# Add download management
self._active_downloads = OrderedDict() # download_id -> download_info
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
self._download_tasks = {} # download_id -> asyncio.Task
async def _get_civitai_client(self):
"""Lazily initialize CivitaiClient from registry"""
@@ -47,27 +53,132 @@ class DownloadManager:
async def _get_checkpoint_scanner(self):
"""Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner()
async def download_from_civitai(self, model_id: int,
model_version_id: int, save_dir: str = None,
relative_path: str = '', progress_callback=None, use_default_paths: bool = False) -> Dict:
"""Download model from Civitai
async def download_from_civitai(self, model_id: int, model_version_id: int,
save_dir: str = None, relative_path: str = '',
progress_callback=None, use_default_paths: bool = False,
download_id: str = None) -> Dict:
"""Download model from Civitai with task tracking and concurrency control
Args:
model_id: Civitai model ID
model_version_id: Civitai model version ID (optional, if not provided, will download the latest version)
save_dir: Directory to save the model to
model_version_id: Civitai model version ID
save_dir: Directory to save the model
relative_path: Relative path within save_dir
progress_callback: Callback function for progress updates
use_default_paths: Flag to indicate whether to use default paths
use_default_paths: Flag to use default paths
download_id: Unique identifier for this download task
Returns:
Dict with download result
"""
# Use provided download_id or generate new one
task_id = download_id or str(uuid.uuid4())
# Register download task in tracking dict
self._active_downloads[task_id] = {
'model_id': model_id,
'model_version_id': model_version_id,
'progress': 0,
'status': 'queued'
}
# Create tracking task
download_task = asyncio.create_task(
self._download_with_semaphore(
task_id, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths
)
)
# Store task for tracking and cancellation
self._download_tasks[task_id] = download_task
try:
# Wait for download to complete
result = await download_task
result['download_id'] = task_id # Include download_id in result
return result
except asyncio.CancelledError:
return {'success': False, 'error': 'Download was cancelled', 'download_id': task_id}
finally:
# Clean up task reference
if task_id in self._download_tasks:
del self._download_tasks[task_id]
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
save_dir: str, relative_path: str,
progress_callback=None, use_default_paths: bool = False):
"""Execute download with semaphore to limit concurrency"""
# Update status to waiting
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'waiting'
# Wrap progress callback to track progress in active_downloads
original_callback = progress_callback
async def tracking_callback(progress):
if task_id in self._active_downloads:
self._active_downloads[task_id]['progress'] = progress
if original_callback:
await original_callback(progress)
# Acquire semaphore to limit concurrent downloads
try:
async with self._download_semaphore:
# Update status to downloading
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'downloading'
# Use original download implementation
try:
# Check for cancellation before starting
if asyncio.current_task().cancelled():
raise asyncio.CancelledError()
result = await self._execute_original_download(
model_id, model_version_id, save_dir,
relative_path, tracking_callback, use_default_paths,
task_id
)
# Update status based on result
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed'
if not result['success']:
self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error')
return result
except asyncio.CancelledError:
# Handle cancellation
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'cancelled'
logger.info(f"Download cancelled for task {task_id}")
raise
except Exception as e:
# Handle other errors
logger.error(f"Download error for task {task_id}: {str(e)}", exc_info=True)
if task_id in self._active_downloads:
self._active_downloads[task_id]['status'] = 'failed'
self._active_downloads[task_id]['error'] = str(e)
return {'success': False, 'error': str(e)}
finally:
# Schedule cleanup of download record after delay
asyncio.create_task(self._cleanup_download_record(task_id))
async def _cleanup_download_record(self, task_id: str):
"""Keep completed downloads in history for a short time"""
await asyncio.sleep(600) # Keep for 10 minutes
if task_id in self._active_downloads:
del self._active_downloads[task_id]
async def _execute_original_download(self, model_id, model_version_id, save_dir,
relative_path, progress_callback, use_default_paths,
download_id=None):
"""Wrapper for original download_from_civitai implementation"""
try:
# Check if model version already exists in library
if model_version_id is not None:
# Case 1: model_version_id is provided, check both scanners
# Check both scanners
lora_scanner = await self._get_lora_scanner()
checkpoint_scanner = await self._get_checkpoint_scanner()
@@ -183,7 +294,8 @@ class DownloadManager:
version_info=version_info,
relative_path=relative_path,
progress_callback=progress_callback,
model_type=model_type
model_type=model_type,
download_id=download_id
)
return result
@@ -243,12 +355,16 @@ class DownloadManager:
async def _execute_download(self, download_url: str, save_dir: str,
metadata, version_info: Dict,
relative_path: str, progress_callback=None,
model_type: str = "lora") -> Dict:
model_type: str = "lora", download_id: str = None) -> Dict:
"""Execute the actual download process including preview images and model files"""
try:
civitai_client = await self._get_civitai_client()
save_path = metadata.file_path
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
# Store file path in active_downloads for potential cleanup
if download_id and download_id in self._active_downloads:
self._active_downloads[download_id]['file_path'] = save_path
# Download preview image if available
images = version_info.get('images', [])
@@ -367,4 +483,86 @@ class DownloadManager:
if progress_callback:
# Scale file progress to 3-100 range (after preview download)
overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download
await progress_callback(round(overall_progress))
await progress_callback(round(overall_progress))
async def cancel_download(self, download_id: str) -> Dict:
"""Cancel an active download by download_id
Args:
download_id: The unique identifier of the download task
Returns:
Dict: Status of the cancellation operation
"""
if download_id not in self._download_tasks:
return {'success': False, 'error': 'Download task not found'}
try:
# Get the task and cancel it
task = self._download_tasks[download_id]
task.cancel()
# Update status in active downloads
if download_id in self._active_downloads:
self._active_downloads[download_id]['status'] = 'cancelling'
# Wait briefly for the task to acknowledge cancellation
try:
await asyncio.wait_for(asyncio.shield(task), timeout=2.0)
except (asyncio.CancelledError, asyncio.TimeoutError):
pass
# Clean up partial downloads
download_info = self._active_downloads.get(download_id)
if download_info and 'file_path' in download_info:
# Delete the partial file
file_path = download_info['file_path']
if os.path.exists(file_path):
try:
os.unlink(file_path)
logger.debug(f"Deleted partial download: {file_path}")
except Exception as e:
logger.error(f"Error deleting partial file: {e}")
# Delete metadata file if exists
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
if os.path.exists(metadata_path):
try:
os.unlink(metadata_path)
except Exception as e:
logger.error(f"Error deleting metadata file: {e}")
# Delete preview file if exists (.webp or .mp4)
for preview_ext in ['.webp', '.mp4']:
preview_path = os.path.splitext(file_path)[0] + preview_ext
if os.path.exists(preview_path):
try:
os.unlink(preview_path)
logger.debug(f"Deleted preview file: {preview_path}")
except Exception as e:
logger.error(f"Error deleting preview file: {e}")
return {'success': True, 'message': 'Download cancelled successfully'}
except Exception as e:
logger.error(f"Error cancelling download: {e}", exc_info=True)
return {'success': False, 'error': str(e)}
async def get_active_downloads(self) -> Dict:
"""Get information about all active downloads
Returns:
Dict: List of active downloads and their status
"""
return {
'downloads': [
{
'download_id': task_id,
'model_id': info.get('model_id'),
'model_version_id': info.get('model_version_id'),
'progress': info.get('progress', 0),
'status': info.get('status', 'unknown'),
'error': info.get('error', None)
}
for task_id, info in self._active_downloads.items()
]
}

View File

@@ -611,13 +611,15 @@ class ModelRouteUtils:
use_default_paths = data.get('use_default_paths', False)
# Pass the download_id to download_from_civitai
result = await download_manager.download_from_civitai(
model_id=model_id,
model_version_id=model_version_id,
save_dir=data.get('model_root'),
relative_path=data.get('relative_path', ''),
use_default_paths=use_default_paths,
progress_callback=progress_callback
progress_callback=progress_callback,
download_id=download_id # Pass download_id explicitly
)
# Include download_id in the response
@@ -631,12 +633,14 @@ class ModelRouteUtils:
logger.warning(f"Early access download failed: {error_message}")
return web.json_response({
'success': False,
'error': f"Early Access Restriction: {error_message}"
'error': f"Early Access Restriction: {error_message}",
'download_id': download_id
}, status=401)
return web.json_response({
'success': False,
'error': error_message
'error': error_message,
'download_id': download_id
}, status=500)
return web.json_response(result)
@@ -658,6 +662,65 @@ class ModelRouteUtils:
'error': error_message
}, status=500)
@staticmethod
async def handle_cancel_download(request: web.Request, download_manager: DownloadManager) -> web.Response:
"""Handle cancellation of a download task
Args:
request: The aiohttp request
download_manager: The download manager instance
Returns:
web.Response: The HTTP response
"""
try:
download_id = request.match_info.get('download_id')
if not download_id:
return web.json_response({
'success': False,
'error': 'Download ID is required'
}, status=400)
result = await download_manager.cancel_download(download_id)
# Notify clients about cancellation via WebSocket
await ws_manager.broadcast_download_progress(download_id, {
'status': 'cancelled',
'progress': 0,
'download_id': download_id,
'message': 'Download cancelled by user'
})
return web.json_response(result)
except Exception as e:
logger.error(f"Error cancelling download: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def handle_list_downloads(request: web.Request, download_manager: DownloadManager) -> web.Response:
"""Get list of active downloads
Args:
request: The aiohttp request
download_manager: The download manager instance
Returns:
web.Response: The HTTP response with list of downloads
"""
try:
result = await download_manager.get_active_downloads()
return web.json_response(result)
except Exception as e:
logger.error(f"Error listing downloads: {e}", exc_info=True)
return web.json_response({
'success': False,
'error': str(e)
}, status=500)
@staticmethod
async def handle_bulk_delete_models(request: web.Request, scanner) -> web.Response:
"""Handle bulk deletion of models