# Mirror of https://github.com/willmiao/ComfyUI-Lora-Manager.git
# Synced 2026-03-22 13:42:12 -03:00
# 686 lines, 32 KiB, Python
import logging
|
|
import os
|
|
import asyncio
|
|
from collections import OrderedDict
|
|
import uuid
|
|
from typing import Dict
|
|
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
|
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES, CIVITAI_MODEL_TAGS
|
|
from ..utils.exif_utils import ExifUtils
|
|
from ..utils.metadata_manager import MetadataManager
|
|
from .service_registry import ServiceRegistry
|
|
from .settings_manager import settings
|
|
from .metadata_service import get_default_metadata_provider
|
|
from .downloader import get_downloader
|
|
|
|
# tempfile: used to stage preview images on disk before WebP conversion
|
|
import tempfile
|
|
|
|
# Module-level logger, following the standard per-module logging convention.
logger = logging.getLogger(__name__)
|
|
|
|
class DownloadManager:
    """Singleton that coordinates model downloads from Civitai.

    Tracks per-task progress/status in an ordered dict, limits concurrency
    with a semaphore, and supports user cancellation with on-disk cleanup.
    """

    # Lazily created singleton instance (see get_instance)
    _instance = None
    # Guards singleton creation across concurrent get_instance calls
    _lock = asyncio.Lock()
|
|
|
|
@classmethod
|
|
async def get_instance(cls):
|
|
"""Get singleton instance of DownloadManager"""
|
|
async with cls._lock:
|
|
if cls._instance is None:
|
|
cls._instance = cls()
|
|
return cls._instance
|
|
|
|
def __init__(self):
|
|
# Check if already initialized for singleton pattern
|
|
if hasattr(self, '_initialized'):
|
|
return
|
|
self._initialized = True
|
|
|
|
# Add download management
|
|
self._active_downloads = OrderedDict() # download_id -> download_info
|
|
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
|
self._download_tasks = {} # download_id -> asyncio.Task
|
|
|
|
async def _get_lora_scanner(self):
|
|
"""Get the lora scanner from registry"""
|
|
return await ServiceRegistry.get_lora_scanner()
|
|
|
|
async def _get_checkpoint_scanner(self):
|
|
"""Get the checkpoint scanner from registry"""
|
|
return await ServiceRegistry.get_checkpoint_scanner()
|
|
|
|
    async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
                                    save_dir: str = None, relative_path: str = '',
                                    progress_callback=None, use_default_paths: bool = False,
                                    download_id: str = None, source: str = None) -> Dict:
        """Download model from Civitai with task tracking and concurrency control

        Args:
            model_id: Civitai model ID (optional if model_version_id is provided)
            model_version_id: Civitai model version ID (optional if model_id is provided)
            save_dir: Directory to save the model
            relative_path: Relative path within save_dir
            progress_callback: Callback function for progress updates
            use_default_paths: Flag to use default paths
            download_id: Unique identifier for this download task
            source: Optional source parameter to specify metadata provider

        Returns:
            Dict with download result (always contains 'success'; 'download_id'
            is added before returning)
        """
        # Validate that at least one identifier is provided
        if not model_id and not model_version_id:
            return {'success': False, 'error': 'Either model_id or model_version_id must be provided'}

        # Use provided download_id or generate new one
        task_id = download_id or str(uuid.uuid4())

        # Register the task in the tracking dict immediately, so
        # get_active_downloads / cancel_download can see it while queued
        self._active_downloads[task_id] = {
            'model_id': model_id,
            'model_version_id': model_version_id,
            'progress': 0,
            'status': 'queued'
        }

        # Run the actual work (under the concurrency semaphore) in its own
        # task so it can be cancelled independently via cancel_download()
        download_task = asyncio.create_task(
            self._download_with_semaphore(
                task_id, model_id, model_version_id, save_dir,
                relative_path, progress_callback, use_default_paths, source
            )
        )

        # Store task for tracking and cancellation
        self._download_tasks[task_id] = download_task

        try:
            # Wait for download to complete
            result = await download_task
            result['download_id'] = task_id  # Include download_id in result
            return result
        except asyncio.CancelledError:
            # Cancellation is reported as a normal failure dict to the caller
            return {'success': False, 'error': 'Download was cancelled', 'download_id': task_id}
        finally:
            # Clean up the task reference; the _active_downloads entry is
            # kept for a while by _cleanup_download_record for status queries
            if task_id in self._download_tasks:
                del self._download_tasks[task_id]
|
|
|
|
    async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
                                       save_dir: str, relative_path: str,
                                       progress_callback=None, use_default_paths: bool = False,
                                       source: str = None):
        """Execute download with semaphore to limit concurrency.

        Keeps self._active_downloads[task_id] updated through the lifecycle:
        'waiting' -> 'downloading' -> 'completed' / 'failed' / 'cancelled'.
        """
        # Update status to waiting (queued behind the semaphore)
        if task_id in self._active_downloads:
            self._active_downloads[task_id]['status'] = 'waiting'

        # Wrap the caller's progress callback so progress is also recorded
        # in the tracking dict for get_active_downloads()
        original_callback = progress_callback
        async def tracking_callback(progress):
            if task_id in self._active_downloads:
                self._active_downloads[task_id]['progress'] = progress
            if original_callback:
                await original_callback(progress)

        # Acquire semaphore to limit concurrent downloads
        try:
            async with self._download_semaphore:
                # Update status to downloading
                if task_id in self._active_downloads:
                    self._active_downloads[task_id]['status'] = 'downloading'

                # Use original download implementation
                try:
                    # Check for cancellation before starting.
                    # NOTE(review): Task.cancelled() only returns True after a
                    # task has *finished* cancelling, so this pre-check is
                    # effectively a no-op; cancellation is actually delivered
                    # at the next await point below. Confirm before relying on it.
                    if asyncio.current_task().cancelled():
                        raise asyncio.CancelledError()

                    result = await self._execute_original_download(
                        model_id, model_version_id, save_dir,
                        relative_path, tracking_callback, use_default_paths,
                        task_id, source
                    )

                    # Update status based on result
                    if task_id in self._active_downloads:
                        self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed'
                        if not result['success']:
                            self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error')

                    return result
                except asyncio.CancelledError:
                    # Mark the record cancelled, then propagate so the
                    # awaiting caller observes the CancelledError
                    if task_id in self._active_downloads:
                        self._active_downloads[task_id]['status'] = 'cancelled'
                    logger.info(f"Download cancelled for task {task_id}")
                    raise
                except Exception as e:
                    # Handle other errors: record failure and return a dict
                    logger.error(f"Download error for task {task_id}: {str(e)}", exc_info=True)
                    if task_id in self._active_downloads:
                        self._active_downloads[task_id]['status'] = 'failed'
                        self._active_downloads[task_id]['error'] = str(e)
                    return {'success': False, 'error': str(e)}
        finally:
            # Schedule delayed removal of the tracking record.
            # NOTE(review): fire-and-forget task with no saved reference — the
            # event loop may garbage-collect it before it runs; consider
            # keeping a reference if cleanup reliability matters.
            asyncio.create_task(self._cleanup_download_record(task_id))
|
|
|
|
async def _cleanup_download_record(self, task_id: str):
|
|
"""Keep completed downloads in history for a short time"""
|
|
await asyncio.sleep(600) # Keep for 10 minutes
|
|
if task_id in self._active_downloads:
|
|
del self._active_downloads[task_id]
|
|
|
|
    async def _execute_original_download(self, model_id, model_version_id, save_dir,
                                         relative_path, progress_callback, use_default_paths,
                                         download_id=None, source=None):
        """Wrapper for original download_from_civitai implementation.

        Fetches version metadata, refuses versions already present in any
        library, resolves the destination directory, builds the model-type
        specific metadata object, and delegates the transfer to
        _execute_download.
        """
        try:
            # Check if model version already exists in library
            if model_version_id is not None:
                # Check all three scanners (lora / checkpoint / embedding)
                lora_scanner = await self._get_lora_scanner()
                checkpoint_scanner = await self._get_checkpoint_scanner()
                embedding_scanner = await ServiceRegistry.get_embedding_scanner()

                # Check lora scanner first
                if await lora_scanner.check_model_version_exists(model_version_id):
                    return {'success': False, 'error': 'Model version already exists in lora library'}

                # Check checkpoint scanner
                if await checkpoint_scanner.check_model_version_exists(model_version_id):
                    return {'success': False, 'error': 'Model version already exists in checkpoint library'}

                # Check embedding scanner
                if await embedding_scanner.check_model_version_exists(model_version_id):
                    return {'success': False, 'error': 'Model version already exists in embedding library'}

            # Get metadata provider based on source parameter
            if source == 'civarchive':
                from .metadata_service import get_metadata_provider
                metadata_provider = await get_metadata_provider('civarchive')
            else:
                metadata_provider = await get_default_metadata_provider()

            # Get version info based on the provided identifier
            version_info = await metadata_provider.get_model_version(model_id, model_version_id)

            if not version_info:
                return {'success': False, 'error': 'Failed to fetch model metadata'}

            # Map the provider's model type onto the three supported kinds
            model_type_from_info = version_info.get('model', {}).get('type', '').lower()
            if model_type_from_info == 'checkpoint':
                model_type = 'checkpoint'
            elif model_type_from_info in VALID_LORA_TYPES:
                model_type = 'lora'
            elif model_type_from_info == 'textualinversion':
                model_type = 'embedding'
            else:
                return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}

            # Case 2: model_version_id was None — re-check for duplicates now
            # that version_info reveals the concrete version id
            if model_version_id is None:
                version_id = version_info.get('id')

                if model_type == 'lora':
                    # Check lora scanner
                    lora_scanner = await self._get_lora_scanner()
                    if await lora_scanner.check_model_version_exists(version_id):
                        return {'success': False, 'error': 'Model version already exists in lora library'}
                elif model_type == 'checkpoint':
                    # Check checkpoint scanner
                    checkpoint_scanner = await self._get_checkpoint_scanner()
                    if await checkpoint_scanner.check_model_version_exists(version_id):
                        return {'success': False, 'error': 'Model version already exists in checkpoint library'}
                elif model_type == 'embedding':
                    # Check embedding scanner
                    embedding_scanner = await ServiceRegistry.get_embedding_scanner()
                    if await embedding_scanner.check_model_version_exists(version_id):
                        return {'success': False, 'error': 'Model version already exists in embedding library'}

            # Handle use_default_paths: derive save_dir from settings and the
            # configured path template instead of the caller-supplied values
            if use_default_paths:
                # Set save_dir based on model type
                if model_type == 'checkpoint':
                    default_path = settings.get('default_checkpoint_root')
                    if not default_path:
                        return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
                    save_dir = default_path
                elif model_type == 'lora':
                    default_path = settings.get('default_lora_root')
                    if not default_path:
                        return {'success': False, 'error': 'Default lora root path not set in settings'}
                    save_dir = default_path
                elif model_type == 'embedding':
                    default_path = settings.get('default_embedding_root')
                    if not default_path:
                        return {'success': False, 'error': 'Default embedding root path not set in settings'}
                    save_dir = default_path

                # Calculate relative path using template
                relative_path = self._calculate_relative_path(version_info, model_type)

            # Update save directory with relative path if provided
            if relative_path:
                save_dir = os.path.join(save_dir, relative_path)
                # Create directory if it doesn't exist
                os.makedirs(save_dir, exist_ok=True)

            # Check if this is an early access model (may require payment)
            if version_info.get('earlyAccessEndsAt'):
                early_access_date = version_info.get('earlyAccessEndsAt', '')
                # Convert to a readable date if possible
                try:
                    from datetime import datetime
                    date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00'))
                    formatted_date = date_obj.strftime('%Y-%m-%d')
                    early_access_msg = f"This model requires payment (until {formatted_date}). "
                except:
                    # NOTE(review): bare except — would be clearer narrowed to
                    # `except (ValueError, AttributeError):`
                    early_access_msg = "This model requires payment. "

                early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai."
                logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}")

                # We'll still try to download, but log a warning and prepare for potential failure
                if progress_callback:
                    await progress_callback(1)  # Show minimal progress to indicate we're trying

            # Report initial progress
            if progress_callback:
                await progress_callback(0)

            # 2. Get file information (the version's primary file)
            file_info = next((f for f in version_info.get('files', []) if f.get('primary')), None)
            if not file_info:
                return {'success': False, 'error': 'No primary file found in metadata'}
            if not file_info.get('downloadUrl'):
                return {'success': False, 'error': 'No download URL found for primary file'}

            # 3. Prepare download
            file_name = file_info['name']
            save_path = os.path.join(save_dir, file_name)

            # 5. Prepare metadata based on model type
            if model_type == "checkpoint":
                metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path)
                logger.info(f"Creating CheckpointMetadata for {file_name}")
            elif model_type == "lora":
                metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
                logger.info(f"Creating LoraMetadata for {file_name}")
            elif model_type == "embedding":
                metadata = EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path)
                logger.info(f"Creating EmbeddingMetadata for {file_name}")

            # 6. Start download process
            result = await self._execute_download(
                download_url=file_info.get('downloadUrl', ''),
                save_dir=save_dir,
                metadata=metadata,
                version_info=version_info,
                relative_path=relative_path,
                progress_callback=progress_callback,
                model_type=model_type,
                download_id=download_id
            )

            # If early_access_msg exists and download failed, replace error message
            if 'early_access_msg' in locals() and not result.get('success', False):
                result['error'] = early_access_msg

            return result

        except Exception as e:
            logger.error(f"Error in download_from_civitai: {e}", exc_info=True)
            # Check if this might be an early access error
            error_str = str(e).lower()
            if "403" in error_str or "401" in error_str or "unauthorized" in error_str or "early access" in error_str:
                return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."}
            return {'success': False, 'error': str(e)}
|
|
|
|
def _calculate_relative_path(self, version_info: Dict, model_type: str = 'lora') -> str:
|
|
"""Calculate relative path using template from settings
|
|
|
|
Args:
|
|
version_info: Version info from Civitai API
|
|
model_type: Type of model ('lora', 'checkpoint', 'embedding')
|
|
|
|
Returns:
|
|
Relative path string
|
|
"""
|
|
# Get path template from settings for specific model type
|
|
path_template = settings.get_download_path_template(model_type)
|
|
|
|
# If template is empty, return empty path (flat structure)
|
|
if not path_template:
|
|
return ''
|
|
|
|
# Get base model name
|
|
base_model = version_info.get('baseModel', '')
|
|
|
|
# Get author from creator data
|
|
creator_info = version_info.get('creator')
|
|
if creator_info and isinstance(creator_info, dict):
|
|
author = creator_info.get('username') or 'Anonymous'
|
|
else:
|
|
author = 'Anonymous'
|
|
|
|
# Apply mapping if available
|
|
base_model_mappings = settings.get('base_model_path_mappings', {})
|
|
mapped_base_model = base_model_mappings.get(base_model, base_model)
|
|
|
|
# Get model tags
|
|
model_tags = version_info.get('model', {}).get('tags', [])
|
|
|
|
# Find the first Civitai model tag that exists in model_tags
|
|
first_tag = ''
|
|
for civitai_tag in CIVITAI_MODEL_TAGS:
|
|
if civitai_tag in model_tags:
|
|
first_tag = civitai_tag
|
|
break
|
|
|
|
# If no Civitai model tag found, fallback to first tag
|
|
if not first_tag and model_tags:
|
|
first_tag = model_tags[0]
|
|
|
|
# Format the template with available data
|
|
formatted_path = path_template
|
|
formatted_path = formatted_path.replace('{base_model}', mapped_base_model)
|
|
formatted_path = formatted_path.replace('{first_tag}', first_tag)
|
|
formatted_path = formatted_path.replace('{author}', author)
|
|
|
|
return formatted_path
|
|
|
|
async def _execute_download(self, download_url: str, save_dir: str,
|
|
metadata, version_info: Dict,
|
|
relative_path: str, progress_callback=None,
|
|
model_type: str = "lora", download_id: str = None) -> Dict:
|
|
"""Execute the actual download process including preview images and model files"""
|
|
try:
|
|
# Extract original filename details
|
|
original_filename = os.path.basename(metadata.file_path)
|
|
base_name, extension = os.path.splitext(original_filename)
|
|
|
|
# Check for filename conflicts and generate unique filename if needed
|
|
# Use the hash from metadata for conflict resolution
|
|
def hash_provider():
|
|
return metadata.sha256
|
|
|
|
unique_filename = metadata.generate_unique_filename(
|
|
save_dir,
|
|
base_name,
|
|
extension,
|
|
hash_provider=hash_provider
|
|
)
|
|
|
|
# Update paths if filename changed
|
|
if unique_filename != original_filename:
|
|
logger.info(f"Filename conflict detected. Changing '{original_filename}' to '{unique_filename}'")
|
|
save_path = os.path.join(save_dir, unique_filename)
|
|
# Update metadata with new file path and name
|
|
metadata.file_path = save_path.replace(os.sep, '/')
|
|
metadata.file_name = os.path.splitext(unique_filename)[0]
|
|
else:
|
|
save_path = metadata.file_path
|
|
|
|
part_path = save_path + '.part'
|
|
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
|
|
|
# Store file paths in active_downloads for potential cleanup
|
|
if download_id and download_id in self._active_downloads:
|
|
self._active_downloads[download_id]['file_path'] = save_path
|
|
self._active_downloads[download_id]['part_path'] = part_path
|
|
|
|
# Download preview image if available
|
|
images = version_info.get('images', [])
|
|
if images:
|
|
# Report preview download progress
|
|
if progress_callback:
|
|
await progress_callback(1) # 1% progress for starting preview download
|
|
|
|
# Check if it's a video or an image
|
|
is_video = images[0].get('type') == 'video'
|
|
|
|
if (is_video):
|
|
# For videos, use .mp4 extension
|
|
preview_ext = '.mp4'
|
|
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
|
|
|
# Download video directly using downloader
|
|
downloader = await get_downloader()
|
|
success, result = await downloader.download_file(
|
|
images[0]['url'],
|
|
preview_path,
|
|
use_auth=False # Preview images typically don't need auth
|
|
)
|
|
if success:
|
|
metadata.preview_url = preview_path.replace(os.sep, '/')
|
|
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
|
else:
|
|
# For images, use WebP format for better performance
|
|
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
|
|
temp_path = temp_file.name
|
|
|
|
# Download the original image to temp path using downloader
|
|
downloader = await get_downloader()
|
|
success, content, headers = await downloader.download_to_memory(
|
|
images[0]['url'],
|
|
use_auth=False
|
|
)
|
|
if success:
|
|
# Save to temp file
|
|
with open(temp_path, 'wb') as f:
|
|
f.write(content)
|
|
# Optimize and convert to WebP
|
|
preview_path = os.path.splitext(save_path)[0] + '.webp'
|
|
|
|
# Use ExifUtils to optimize and convert the image
|
|
optimized_data, _ = ExifUtils.optimize_image(
|
|
image_data=temp_path,
|
|
target_width=CARD_PREVIEW_WIDTH,
|
|
format='webp',
|
|
quality=85,
|
|
preserve_metadata=False
|
|
)
|
|
|
|
# Save the optimized image
|
|
with open(preview_path, 'wb') as f:
|
|
f.write(optimized_data)
|
|
|
|
# Update metadata
|
|
metadata.preview_url = preview_path.replace(os.sep, '/')
|
|
metadata.preview_nsfw_level = images[0].get('nsfwLevel', 0)
|
|
|
|
# Remove temporary file
|
|
try:
|
|
os.unlink(temp_path)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to delete temp file: {e}")
|
|
|
|
# Report preview download completion
|
|
if progress_callback:
|
|
await progress_callback(3) # 3% progress after preview download
|
|
|
|
# Download model file with progress tracking using downloader
|
|
downloader = await get_downloader()
|
|
# Determine if the download URL is from Civitai
|
|
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
|
success, result = await downloader.download_file(
|
|
download_url,
|
|
save_path, # Use full path instead of separate dir and filename
|
|
progress_callback=lambda p: self._handle_download_progress(p, progress_callback),
|
|
use_auth=use_auth # Only use authentication for Civitai downloads
|
|
)
|
|
|
|
if not success:
|
|
# Clean up files on failure, but preserve .part file for resume
|
|
cleanup_files = [metadata_path]
|
|
if metadata.preview_url and os.path.exists(metadata.preview_url):
|
|
cleanup_files.append(metadata.preview_url)
|
|
|
|
for path in cleanup_files:
|
|
if path and os.path.exists(path):
|
|
try:
|
|
os.remove(path)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to cleanup file {path}: {e}")
|
|
|
|
# Log but don't remove .part file to allow resume
|
|
if os.path.exists(part_path):
|
|
logger.info(f"Preserving partial download for resume: {part_path}")
|
|
|
|
return {'success': False, 'error': result}
|
|
|
|
# 4. Update file information (size and modified time)
|
|
metadata.update_file_info(save_path)
|
|
|
|
# 5. Final metadata update
|
|
await MetadataManager.save_metadata(save_path, metadata)
|
|
|
|
# 6. Update cache based on model type
|
|
if model_type == "checkpoint":
|
|
scanner = await self._get_checkpoint_scanner()
|
|
logger.info(f"Updating checkpoint cache for {save_path}")
|
|
elif model_type == "lora":
|
|
scanner = await self._get_lora_scanner()
|
|
logger.info(f"Updating lora cache for {save_path}")
|
|
elif model_type == "embedding":
|
|
scanner = await ServiceRegistry.get_embedding_scanner()
|
|
logger.info(f"Updating embedding cache for {save_path}")
|
|
|
|
# Convert metadata to dictionary
|
|
metadata_dict = metadata.to_dict()
|
|
|
|
# Add model to cache and save to disk in a single operation
|
|
await scanner.add_model_to_cache(metadata_dict, relative_path)
|
|
|
|
# Report 100% completion
|
|
if progress_callback:
|
|
await progress_callback(100)
|
|
|
|
return {
|
|
'success': True
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in _execute_download: {e}", exc_info=True)
|
|
# Clean up partial downloads except .part file
|
|
cleanup_files = [metadata_path]
|
|
if hasattr(metadata, 'preview_url') and metadata.preview_url and os.path.exists(metadata.preview_url):
|
|
cleanup_files.append(metadata.preview_url)
|
|
|
|
for path in cleanup_files:
|
|
if path and os.path.exists(path):
|
|
try:
|
|
os.remove(path)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to cleanup file {path}: {e}")
|
|
|
|
return {'success': False, 'error': str(e)}
|
|
|
|
async def _handle_download_progress(self, file_progress: float, progress_callback):
|
|
"""Convert file download progress to overall progress
|
|
|
|
Args:
|
|
file_progress: Progress of file download (0-100)
|
|
progress_callback: Callback function for progress updates
|
|
"""
|
|
if progress_callback:
|
|
# Scale file progress to 3-100 range (after preview download)
|
|
overall_progress = 3 + (file_progress * 0.97) # 97% of progress for file download
|
|
await progress_callback(round(overall_progress))
|
|
|
|
    async def cancel_download(self, download_id: str) -> Dict:
        """Cancel an active download by download_id

        Cancels the underlying asyncio task, then removes on-disk artifacts
        (model file, .part file, metadata and preview) — a user-initiated
        cancel deliberately does not leave resumable state behind.

        Args:
            download_id: The unique identifier of the download task

        Returns:
            Dict: Status of the cancellation operation
        """
        if download_id not in self._download_tasks:
            return {'success': False, 'error': 'Download task not found'}

        try:
            # Get the task and cancel it
            task = self._download_tasks[download_id]
            task.cancel()

            # Update status in active downloads
            if download_id in self._active_downloads:
                self._active_downloads[download_id]['status'] = 'cancelling'

            # Wait briefly for the task to acknowledge cancellation; shield()
            # keeps wait_for's own timeout from re-cancelling the task
            try:
                await asyncio.wait_for(asyncio.shield(task), timeout=2.0)
            except (asyncio.CancelledError, asyncio.TimeoutError):
                pass

            # Clean up ALL files including .part when user cancels
            download_info = self._active_downloads.get(download_id)
            if download_info:
                # Delete the main file
                if 'file_path' in download_info:
                    file_path = download_info['file_path']
                    if os.path.exists(file_path):
                        try:
                            os.unlink(file_path)
                            logger.debug(f"Deleted cancelled download: {file_path}")
                        except Exception as e:
                            logger.error(f"Error deleting file: {e}")

                # Delete the .part file (only on user cancellation)
                if 'part_path' in download_info:
                    part_path = download_info['part_path']
                    if os.path.exists(part_path):
                        try:
                            os.unlink(part_path)
                            logger.debug(f"Deleted partial download: {part_path}")
                        except Exception as e:
                            logger.error(f"Error deleting part file: {e}")

                # Delete metadata file if exists
                if 'file_path' in download_info:
                    file_path = download_info['file_path']
                    metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
                    if os.path.exists(metadata_path):
                        try:
                            os.unlink(metadata_path)
                        except Exception as e:
                            logger.error(f"Error deleting metadata file: {e}")

                    # Delete preview file if exists (.webp or .mp4).
                    # NOTE(review): kept inside the 'file_path' guard above,
                    # since the loop depends on file_path being bound.
                    for preview_ext in ['.webp', '.mp4']:
                        preview_path = os.path.splitext(file_path)[0] + preview_ext
                        if os.path.exists(preview_path):
                            try:
                                os.unlink(preview_path)
                                logger.debug(f"Deleted preview file: {preview_path}")
                            except Exception as e:
                                logger.error(f"Error deleting preview file: {e}")

            return {'success': True, 'message': 'Download cancelled successfully'}
        except Exception as e:
            logger.error(f"Error cancelling download: {e}", exc_info=True)
            return {'success': False, 'error': str(e)}
|
|
|
|
async def get_active_downloads(self) -> Dict:
|
|
"""Get information about all active downloads
|
|
|
|
Returns:
|
|
Dict: List of active downloads and their status
|
|
"""
|
|
return {
|
|
'downloads': [
|
|
{
|
|
'download_id': task_id,
|
|
'model_id': info.get('model_id'),
|
|
'model_version_id': info.get('model_version_id'),
|
|
'progress': info.get('progress', 0),
|
|
'status': info.get('status', 'unknown'),
|
|
'error': info.get('error', None)
|
|
}
|
|
for task_id, info in self._active_downloads.items()
|
|
]
|
|
} |