mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-22 05:32:12 -03:00
1203 lines
52 KiB
Python
1203 lines
52 KiB
Python
import copy
|
|
import logging
|
|
import os
|
|
import asyncio
|
|
import inspect
|
|
import shutil
|
|
import zipfile
|
|
from collections import OrderedDict
|
|
import uuid
|
|
from typing import Dict, List, Optional, Tuple
|
|
from urllib.parse import urlparse
|
|
from ..utils.models import LoraMetadata, CheckpointMetadata, EmbeddingMetadata
|
|
from ..utils.constants import CARD_PREVIEW_WIDTH, VALID_LORA_TYPES
|
|
from ..utils.civitai_utils import rewrite_preview_url
|
|
from ..utils.preview_selection import select_preview_media
|
|
from ..utils.utils import sanitize_folder_name
|
|
from ..utils.exif_utils import ExifUtils
|
|
from ..utils.file_utils import calculate_sha256
|
|
from ..utils.metadata_manager import MetadataManager
|
|
from .service_registry import ServiceRegistry
|
|
from .settings_manager import get_settings_manager
|
|
from .metadata_service import get_default_metadata_provider
|
|
from .downloader import get_downloader, DownloadProgress, DownloadStreamControl
|
|
|
|
# Download to temporary file first
|
|
import tempfile
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
class DownloadManager:
|
|
_instance = None
|
|
_lock = asyncio.Lock()
|
|
|
|
@classmethod
|
|
async def get_instance(cls):
|
|
"""Get singleton instance of DownloadManager"""
|
|
async with cls._lock:
|
|
if cls._instance is None:
|
|
cls._instance = cls()
|
|
return cls._instance
|
|
|
|
def __init__(self):
|
|
# Check if already initialized for singleton pattern
|
|
if hasattr(self, '_initialized'):
|
|
return
|
|
self._initialized = True
|
|
|
|
# Add download management
|
|
self._active_downloads = OrderedDict() # download_id -> download_info
|
|
self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads
|
|
self._download_tasks = {} # download_id -> asyncio.Task
|
|
self._pause_events: Dict[str, DownloadStreamControl] = {}
|
|
|
|
async def _get_lora_scanner(self):
|
|
"""Get the lora scanner from registry"""
|
|
return await ServiceRegistry.get_lora_scanner()
|
|
|
|
async def _get_checkpoint_scanner(self):
|
|
"""Get the checkpoint scanner from registry"""
|
|
return await ServiceRegistry.get_checkpoint_scanner()
|
|
|
|
async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
|
|
save_dir: str = None, relative_path: str = '',
|
|
progress_callback=None, use_default_paths: bool = False,
|
|
download_id: str = None, source: str = None) -> Dict:
|
|
"""Download model from Civitai with task tracking and concurrency control
|
|
|
|
Args:
|
|
model_id: Civitai model ID (optional if model_version_id is provided)
|
|
model_version_id: Civitai model version ID (optional if model_id is provided)
|
|
save_dir: Directory to save the model
|
|
relative_path: Relative path within save_dir
|
|
progress_callback: Callback function for progress updates
|
|
use_default_paths: Flag to use default paths
|
|
download_id: Unique identifier for this download task
|
|
source: Optional source parameter to specify metadata provider
|
|
|
|
Returns:
|
|
Dict with download result
|
|
"""
|
|
# Validate that at least one identifier is provided
|
|
if not model_id and not model_version_id:
|
|
return {'success': False, 'error': 'Either model_id or model_version_id must be provided'}
|
|
|
|
# Use provided download_id or generate new one
|
|
task_id = download_id or str(uuid.uuid4())
|
|
|
|
# Register download task in tracking dict
|
|
self._active_downloads[task_id] = {
|
|
'model_id': model_id,
|
|
'model_version_id': model_version_id,
|
|
'progress': 0,
|
|
'status': 'queued',
|
|
'bytes_downloaded': 0,
|
|
'total_bytes': None,
|
|
'bytes_per_second': 0.0,
|
|
'last_progress_timestamp': None,
|
|
}
|
|
|
|
pause_control = DownloadStreamControl()
|
|
self._pause_events[task_id] = pause_control
|
|
|
|
# Create tracking task
|
|
download_task = asyncio.create_task(
|
|
self._download_with_semaphore(
|
|
task_id, model_id, model_version_id, save_dir,
|
|
relative_path, progress_callback, use_default_paths, source
|
|
)
|
|
)
|
|
|
|
# Store task for tracking and cancellation
|
|
self._download_tasks[task_id] = download_task
|
|
|
|
try:
|
|
# Wait for download to complete
|
|
result = await download_task
|
|
result['download_id'] = task_id # Include download_id in result
|
|
return result
|
|
except asyncio.CancelledError:
|
|
return {'success': False, 'error': 'Download was cancelled', 'download_id': task_id}
|
|
finally:
|
|
# Clean up task reference
|
|
if task_id in self._download_tasks:
|
|
del self._download_tasks[task_id]
|
|
self._pause_events.pop(task_id, None)
|
|
|
|
async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int,
|
|
save_dir: str, relative_path: str,
|
|
progress_callback=None, use_default_paths: bool = False,
|
|
source: str = None):
|
|
"""Execute download with semaphore to limit concurrency"""
|
|
# Update status to waiting
|
|
if task_id in self._active_downloads:
|
|
self._active_downloads[task_id]['status'] = 'waiting'
|
|
|
|
# Wrap progress callback to track progress in active_downloads
|
|
original_callback = progress_callback
|
|
async def tracking_callback(progress, metrics=None):
|
|
progress_value, snapshot = self._normalize_progress(progress, metrics)
|
|
|
|
if task_id in self._active_downloads:
|
|
info = self._active_downloads[task_id]
|
|
info['progress'] = round(progress_value)
|
|
if snapshot is not None:
|
|
info['bytes_downloaded'] = snapshot.bytes_downloaded
|
|
info['total_bytes'] = snapshot.total_bytes
|
|
info['bytes_per_second'] = snapshot.bytes_per_second
|
|
pause_control = self._pause_events.get(task_id)
|
|
if isinstance(pause_control, DownloadStreamControl):
|
|
pause_control.mark_progress(snapshot.timestamp)
|
|
info['last_progress_timestamp'] = pause_control.last_progress_timestamp
|
|
|
|
if original_callback:
|
|
await self._dispatch_progress(original_callback, snapshot, progress_value)
|
|
|
|
# Acquire semaphore to limit concurrent downloads
|
|
try:
|
|
async with self._download_semaphore:
|
|
pause_control = self._pause_events.get(task_id)
|
|
if pause_control is not None and pause_control.is_paused():
|
|
if task_id in self._active_downloads:
|
|
self._active_downloads[task_id]['status'] = 'paused'
|
|
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
|
await pause_control.wait()
|
|
|
|
# Update status to downloading
|
|
if task_id in self._active_downloads:
|
|
self._active_downloads[task_id]['status'] = 'downloading'
|
|
|
|
# Use original download implementation
|
|
try:
|
|
# Check for cancellation before starting
|
|
if asyncio.current_task().cancelled():
|
|
raise asyncio.CancelledError()
|
|
|
|
result = await self._execute_original_download(
|
|
model_id, model_version_id, save_dir,
|
|
relative_path, tracking_callback, use_default_paths,
|
|
task_id, source
|
|
)
|
|
|
|
# Update status based on result
|
|
if task_id in self._active_downloads:
|
|
self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed'
|
|
if not result['success']:
|
|
self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error')
|
|
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
|
|
|
return result
|
|
except asyncio.CancelledError:
|
|
# Handle cancellation
|
|
if task_id in self._active_downloads:
|
|
self._active_downloads[task_id]['status'] = 'cancelled'
|
|
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
|
logger.info(f"Download cancelled for task {task_id}")
|
|
raise
|
|
except Exception as e:
|
|
# Handle other errors
|
|
logger.error(f"Download error for task {task_id}: {str(e)}", exc_info=True)
|
|
if task_id in self._active_downloads:
|
|
self._active_downloads[task_id]['status'] = 'failed'
|
|
self._active_downloads[task_id]['error'] = str(e)
|
|
self._active_downloads[task_id]['bytes_per_second'] = 0.0
|
|
return {'success': False, 'error': str(e)}
|
|
finally:
|
|
# Schedule cleanup of download record after delay
|
|
asyncio.create_task(self._cleanup_download_record(task_id))
|
|
|
|
async def _cleanup_download_record(self, task_id: str):
|
|
"""Keep completed downloads in history for a short time"""
|
|
await asyncio.sleep(600) # Keep for 10 minutes
|
|
if task_id in self._active_downloads:
|
|
del self._active_downloads[task_id]
|
|
|
|
async def _execute_original_download(
|
|
self,
|
|
model_id,
|
|
model_version_id,
|
|
save_dir,
|
|
relative_path,
|
|
progress_callback,
|
|
use_default_paths,
|
|
download_id=None,
|
|
source=None,
|
|
):
|
|
"""Wrapper for original download_from_civitai implementation"""
|
|
try:
|
|
# Check if model version already exists in library
|
|
if model_version_id is not None:
|
|
# Check both scanners
|
|
lora_scanner = await self._get_lora_scanner()
|
|
checkpoint_scanner = await self._get_checkpoint_scanner()
|
|
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
|
|
|
# Check lora scanner first
|
|
if await lora_scanner.check_model_version_exists(model_version_id):
|
|
return {'success': False, 'error': 'Model version already exists in lora library'}
|
|
|
|
# Check checkpoint scanner
|
|
if await checkpoint_scanner.check_model_version_exists(model_version_id):
|
|
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
|
|
|
# Check embedding scanner
|
|
if await embedding_scanner.check_model_version_exists(model_version_id):
|
|
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
|
|
|
metadata_provider = await get_default_metadata_provider()
|
|
|
|
# Get version info based on the provided identifier
|
|
version_info = await metadata_provider.get_model_version(model_id, model_version_id)
|
|
|
|
if not version_info:
|
|
return {'success': False, 'error': 'Failed to fetch model metadata'}
|
|
|
|
model_type_from_info = version_info.get('model', {}).get('type', '').lower()
|
|
if model_type_from_info == 'checkpoint':
|
|
model_type = 'checkpoint'
|
|
elif model_type_from_info in VALID_LORA_TYPES:
|
|
model_type = 'lora'
|
|
elif model_type_from_info == 'textualinversion':
|
|
model_type = 'embedding'
|
|
else:
|
|
return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'}
|
|
|
|
# Case 2: model_version_id was None, check after getting version_info
|
|
if model_version_id is None:
|
|
version_id = version_info.get('id')
|
|
|
|
if model_type == 'lora':
|
|
# Check lora scanner
|
|
lora_scanner = await self._get_lora_scanner()
|
|
if await lora_scanner.check_model_version_exists(version_id):
|
|
return {'success': False, 'error': 'Model version already exists in lora library'}
|
|
elif model_type == 'checkpoint':
|
|
# Check checkpoint scanner
|
|
checkpoint_scanner = await self._get_checkpoint_scanner()
|
|
if await checkpoint_scanner.check_model_version_exists(version_id):
|
|
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
|
elif model_type == 'embedding':
|
|
# Embeddings are not checked in scanners, but we can still check if it exists
|
|
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
|
if await embedding_scanner.check_model_version_exists(version_id):
|
|
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
|
|
|
# Handle use_default_paths
|
|
if use_default_paths:
|
|
settings_manager = get_settings_manager()
|
|
# Set save_dir based on model type
|
|
if model_type == 'checkpoint':
|
|
default_path = settings_manager.get('default_checkpoint_root')
|
|
if not default_path:
|
|
return {'success': False, 'error': 'Default checkpoint root path not set in settings'}
|
|
save_dir = default_path
|
|
elif model_type == 'lora':
|
|
default_path = settings_manager.get('default_lora_root')
|
|
if not default_path:
|
|
return {'success': False, 'error': 'Default lora root path not set in settings'}
|
|
save_dir = default_path
|
|
elif model_type == 'embedding':
|
|
default_path = settings_manager.get('default_embedding_root')
|
|
if not default_path:
|
|
return {'success': False, 'error': 'Default embedding root path not set in settings'}
|
|
save_dir = default_path
|
|
|
|
# Calculate relative path using template
|
|
relative_path = self._calculate_relative_path(version_info, model_type)
|
|
|
|
# Update save directory with relative path if provided
|
|
if relative_path:
|
|
save_dir = os.path.join(save_dir, relative_path)
|
|
# Create directory if it doesn't exist
|
|
os.makedirs(save_dir, exist_ok=True)
|
|
|
|
# Check if this is an early access model
|
|
if version_info.get('earlyAccessEndsAt'):
|
|
early_access_date = version_info.get('earlyAccessEndsAt', '')
|
|
# Convert to a readable date if possible
|
|
try:
|
|
from datetime import datetime
|
|
date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00'))
|
|
formatted_date = date_obj.strftime('%Y-%m-%d')
|
|
early_access_msg = f"This model requires payment (until {formatted_date}). "
|
|
except:
|
|
early_access_msg = "This model requires payment. "
|
|
|
|
early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai."
|
|
logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}")
|
|
|
|
# We'll still try to download, but log a warning and prepare for potential failure
|
|
if progress_callback:
|
|
await progress_callback(1) # Show minimal progress to indicate we're trying
|
|
|
|
# Report initial progress
|
|
if progress_callback:
|
|
await progress_callback(0)
|
|
|
|
# 2. Get file information
|
|
file_info = next((f for f in version_info.get('files', []) if f.get('primary') and f.get('type') in ('Model', 'Negative')), None)
|
|
if not file_info:
|
|
return {'success': False, 'error': 'No primary file found in metadata'}
|
|
mirrors = file_info.get('mirrors') or []
|
|
download_urls = []
|
|
if mirrors:
|
|
for mirror in mirrors:
|
|
if mirror.get('deletedAt') is None and mirror.get('url'):
|
|
download_urls.append(mirror['url'])
|
|
else:
|
|
download_url = file_info.get('downloadUrl')
|
|
if download_url:
|
|
download_urls.append(download_url)
|
|
|
|
if not download_urls:
|
|
return {'success': False, 'error': 'No mirror URL found'}
|
|
|
|
# 3. Prepare download
|
|
file_name = file_info['name']
|
|
save_path = os.path.join(save_dir, file_name)
|
|
|
|
# 5. Prepare metadata based on model type
|
|
if model_type == "checkpoint":
|
|
metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path)
|
|
logger.info(f"Creating CheckpointMetadata for {file_name}")
|
|
elif model_type == "lora":
|
|
metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path)
|
|
logger.info(f"Creating LoraMetadata for {file_name}")
|
|
elif model_type == "embedding":
|
|
metadata = EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path)
|
|
logger.info(f"Creating EmbeddingMetadata for {file_name}")
|
|
|
|
# 6. Start download process
|
|
result = await self._execute_download(
|
|
download_urls=download_urls,
|
|
save_dir=save_dir,
|
|
metadata=metadata,
|
|
version_info=version_info,
|
|
relative_path=relative_path,
|
|
progress_callback=progress_callback,
|
|
model_type=model_type,
|
|
download_id=download_id,
|
|
)
|
|
|
|
if result.get('success', False):
|
|
resolved_model_id = (
|
|
model_id
|
|
or version_info.get('modelId')
|
|
or (version_info.get('model') or {}).get('id')
|
|
)
|
|
await self._sync_downloaded_version(
|
|
model_type,
|
|
resolved_model_id,
|
|
version_info,
|
|
model_version_id,
|
|
)
|
|
|
|
# If early_access_msg exists and download failed, replace error message
|
|
if 'early_access_msg' in locals() and not result.get('success', False):
|
|
result['error'] = early_access_msg
|
|
|
|
return result
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in download_from_civitai: {e}", exc_info=True)
|
|
# Check if this might be an early access error
|
|
error_str = str(e).lower()
|
|
if "403" in error_str or "401" in error_str or "unauthorized" in error_str or "early access" in error_str:
|
|
return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."}
|
|
return {'success': False, 'error': str(e)}
|
|
|
|
async def _sync_downloaded_version(
|
|
self,
|
|
model_type: str,
|
|
model_id_value,
|
|
version_info: Dict,
|
|
fallback_version_id=None,
|
|
) -> None:
|
|
"""Ensure update tracking reflects a newly downloaded version."""
|
|
|
|
try:
|
|
update_service = await ServiceRegistry.get_model_update_service()
|
|
except Exception as exc:
|
|
logger.debug("Skipping update sync; failed to acquire update service: %s", exc)
|
|
return
|
|
|
|
if update_service is None:
|
|
return
|
|
|
|
resolved_model_id = model_id_value
|
|
if resolved_model_id is None:
|
|
resolved_model_id = version_info.get('modelId')
|
|
if resolved_model_id is None:
|
|
model_info = version_info.get('model')
|
|
if isinstance(model_info, dict):
|
|
resolved_model_id = model_info.get('id')
|
|
try:
|
|
resolved_model_id = int(resolved_model_id)
|
|
except (TypeError, ValueError):
|
|
logger.debug("Skipping update sync; invalid model id: %s", resolved_model_id)
|
|
return
|
|
|
|
version_id = version_info.get('id')
|
|
if version_id is None:
|
|
version_id = fallback_version_id
|
|
try:
|
|
version_id = int(version_id)
|
|
except (TypeError, ValueError):
|
|
logger.debug(
|
|
"Skipping update sync; invalid version id for model %s: %s",
|
|
resolved_model_id,
|
|
version_id,
|
|
)
|
|
return
|
|
|
|
version_ids = set()
|
|
scanner = None
|
|
try:
|
|
if model_type == 'lora':
|
|
scanner = await self._get_lora_scanner()
|
|
elif model_type == 'checkpoint':
|
|
scanner = await self._get_checkpoint_scanner()
|
|
elif model_type == 'embedding':
|
|
scanner = await ServiceRegistry.get_embedding_scanner()
|
|
except Exception as exc:
|
|
logger.debug("Failed to acquire scanner for %s models: %s", model_type, exc)
|
|
|
|
if scanner is not None:
|
|
try:
|
|
local_versions = await scanner.get_model_versions_by_id(resolved_model_id)
|
|
except Exception as exc:
|
|
logger.debug(
|
|
"Failed to collect local versions for %s model %s: %s",
|
|
model_type,
|
|
resolved_model_id,
|
|
exc,
|
|
)
|
|
else:
|
|
for entry in local_versions or []:
|
|
vid = entry.get('versionId')
|
|
try:
|
|
version_ids.add(int(vid))
|
|
except (TypeError, ValueError):
|
|
continue
|
|
|
|
version_ids.add(version_id)
|
|
|
|
try:
|
|
await update_service.update_in_library_versions(
|
|
model_type,
|
|
resolved_model_id,
|
|
sorted(version_ids),
|
|
)
|
|
except Exception as exc:
|
|
logger.debug(
|
|
"Failed to update in-library versions for %s model %s: %s",
|
|
model_type,
|
|
resolved_model_id,
|
|
exc,
|
|
)
|
|
|
|
def _calculate_relative_path(self, version_info: Dict, model_type: str = 'lora') -> str:
|
|
"""Calculate relative path using template from settings
|
|
|
|
Args:
|
|
version_info: Version info from Civitai API
|
|
model_type: Type of model ('lora', 'checkpoint', 'embedding')
|
|
|
|
Returns:
|
|
Relative path string
|
|
"""
|
|
# Get path template from settings for specific model type
|
|
settings_manager = get_settings_manager()
|
|
path_template = settings_manager.get_download_path_template(model_type)
|
|
|
|
# If template is empty, return empty path (flat structure)
|
|
if not path_template:
|
|
return ''
|
|
|
|
# Get base model name
|
|
base_model = version_info.get('baseModel', '')
|
|
|
|
# Get author from creator data
|
|
creator_info = version_info.get('creator')
|
|
if creator_info and isinstance(creator_info, dict):
|
|
author = creator_info.get('username') or 'Anonymous'
|
|
else:
|
|
author = 'Anonymous'
|
|
|
|
# Apply mapping if available
|
|
base_model_mappings = settings_manager.get('base_model_path_mappings', {})
|
|
mapped_base_model = base_model_mappings.get(base_model, base_model)
|
|
|
|
model_info = version_info.get('model') or {}
|
|
|
|
# Get model tags
|
|
model_tags = model_info.get('tags', [])
|
|
|
|
first_tag = settings_manager.resolve_priority_tag_for_model(model_tags, model_type)
|
|
|
|
# Format the template with available data
|
|
formatted_path = path_template
|
|
formatted_path = formatted_path.replace('{base_model}', mapped_base_model)
|
|
formatted_path = formatted_path.replace('{first_tag}', first_tag)
|
|
formatted_path = formatted_path.replace('{author}', author)
|
|
formatted_path = formatted_path.replace('{model_name}', sanitize_folder_name(model_info.get('name', '')))
|
|
formatted_path = formatted_path.replace('{version_name}', sanitize_folder_name(version_info.get('name', '')))
|
|
|
|
if model_type == 'embedding':
|
|
formatted_path = formatted_path.replace(' ', '_')
|
|
|
|
return formatted_path
|
|
|
|
async def _execute_download(
|
|
self,
|
|
download_urls: List[str],
|
|
save_dir: str,
|
|
metadata,
|
|
version_info: Dict,
|
|
relative_path: str,
|
|
progress_callback=None,
|
|
model_type: str = "lora",
|
|
download_id: str = None,
|
|
) -> Dict:
|
|
"""Execute the actual download process including preview images and model files"""
|
|
metadata_entries: List = []
|
|
metadata_files_for_cleanup: List[str] = []
|
|
extracted_paths: List[str] = []
|
|
metadata_path = ""
|
|
preview_targets: List[str] = []
|
|
preview_path: str | None = None
|
|
preview_nsfw_level = 0
|
|
try:
|
|
# Extract original filename details
|
|
original_filename = os.path.basename(metadata.file_path)
|
|
base_name, extension = os.path.splitext(original_filename)
|
|
|
|
# Check for filename conflicts and generate unique filename if needed
|
|
# Use the hash from metadata for conflict resolution
|
|
def hash_provider():
|
|
return metadata.sha256
|
|
|
|
unique_filename = metadata.generate_unique_filename(
|
|
save_dir,
|
|
base_name,
|
|
extension,
|
|
hash_provider=hash_provider
|
|
)
|
|
|
|
# Update paths if filename changed
|
|
if unique_filename != original_filename:
|
|
logger.info(f"Filename conflict detected. Changing '{original_filename}' to '{unique_filename}'")
|
|
save_path = os.path.join(save_dir, unique_filename)
|
|
# Update metadata with new file path and name
|
|
metadata.file_path = save_path.replace(os.sep, '/')
|
|
metadata.file_name = os.path.splitext(unique_filename)[0]
|
|
else:
|
|
save_path = metadata.file_path
|
|
|
|
part_path = save_path + '.part'
|
|
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
|
|
|
|
pause_control = self._pause_events.get(download_id) if download_id else None
|
|
|
|
# Store file paths in active_downloads for potential cleanup
|
|
if download_id and download_id in self._active_downloads:
|
|
self._active_downloads[download_id]['file_path'] = save_path
|
|
self._active_downloads[download_id]['part_path'] = part_path
|
|
|
|
# Download preview image if available
|
|
images = version_info.get('images', [])
|
|
if images:
|
|
if progress_callback:
|
|
await progress_callback(1) # 1% progress for starting preview download
|
|
|
|
settings_manager = get_settings_manager()
|
|
blur_mature_content = bool(
|
|
settings_manager.get('blur_mature_content', True)
|
|
)
|
|
selected_image, nsfw_level = select_preview_media(
|
|
images,
|
|
blur_mature_content=blur_mature_content,
|
|
)
|
|
|
|
preview_url = selected_image.get('url') if selected_image else None
|
|
media_type = (
|
|
(selected_image.get('type') or '').lower()
|
|
if selected_image
|
|
else ''
|
|
)
|
|
|
|
def _extension_from_url(url: str, fallback: str) -> str:
|
|
try:
|
|
parsed = urlparse(url)
|
|
except ValueError:
|
|
return fallback
|
|
ext = os.path.splitext(parsed.path)[1]
|
|
return ext or fallback
|
|
|
|
preview_downloaded = False
|
|
preview_path = None
|
|
|
|
if preview_url:
|
|
downloader = await get_downloader()
|
|
|
|
if media_type == 'video':
|
|
preview_ext = _extension_from_url(preview_url, '.mp4')
|
|
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
|
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='video')
|
|
attempt_urls: List[str] = []
|
|
if rewritten:
|
|
attempt_urls.append(rewritten_url)
|
|
attempt_urls.append(preview_url)
|
|
|
|
seen_attempts = set()
|
|
for attempt in attempt_urls:
|
|
if not attempt or attempt in seen_attempts:
|
|
continue
|
|
seen_attempts.add(attempt)
|
|
success, _ = await downloader.download_file(
|
|
attempt,
|
|
preview_path,
|
|
use_auth=False
|
|
)
|
|
if success:
|
|
preview_downloaded = True
|
|
break
|
|
else:
|
|
rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='image')
|
|
if rewritten:
|
|
preview_ext = _extension_from_url(preview_url, '.png')
|
|
preview_path = os.path.splitext(save_path)[0] + preview_ext
|
|
success, _ = await downloader.download_file(
|
|
rewritten_url,
|
|
preview_path,
|
|
use_auth=False
|
|
)
|
|
if success:
|
|
preview_downloaded = True
|
|
|
|
if not preview_downloaded:
|
|
temp_path: str | None = None
|
|
try:
|
|
with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
|
|
temp_path = temp_file.name
|
|
|
|
success, content, _ = await downloader.download_to_memory(
|
|
preview_url,
|
|
use_auth=False
|
|
)
|
|
if success:
|
|
with open(temp_path, 'wb') as temp_file_handle:
|
|
temp_file_handle.write(content)
|
|
preview_path = os.path.splitext(save_path)[0] + '.webp'
|
|
|
|
optimized_data, _ = ExifUtils.optimize_image(
|
|
image_data=temp_path,
|
|
target_width=CARD_PREVIEW_WIDTH,
|
|
format='webp',
|
|
quality=85,
|
|
preserve_metadata=False
|
|
)
|
|
|
|
with open(preview_path, 'wb') as preview_file:
|
|
preview_file.write(optimized_data)
|
|
|
|
preview_downloaded = True
|
|
finally:
|
|
if temp_path and os.path.exists(temp_path):
|
|
try:
|
|
os.unlink(temp_path)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to delete temp file: {e}")
|
|
|
|
if preview_downloaded and preview_path:
|
|
preview_nsfw_level = nsfw_level
|
|
metadata.preview_url = preview_path.replace(os.sep, '/')
|
|
metadata.preview_nsfw_level = nsfw_level
|
|
|
|
if progress_callback:
|
|
await progress_callback(3) # 3% progress after preview download
|
|
|
|
# Download model file with progress tracking using downloader
|
|
downloader = await get_downloader()
|
|
if pause_control is not None:
|
|
pause_control.update_stall_timeout(downloader.stall_timeout)
|
|
last_error = None
|
|
for download_url in download_urls:
|
|
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
|
download_kwargs = {
|
|
"progress_callback": lambda progress, snapshot=None: self._handle_download_progress(
|
|
progress,
|
|
progress_callback,
|
|
snapshot,
|
|
),
|
|
"use_auth": use_auth, # Only use authentication for Civitai downloads
|
|
}
|
|
|
|
if pause_control is not None:
|
|
download_kwargs["pause_event"] = pause_control
|
|
|
|
success, result = await downloader.download_file(
|
|
download_url,
|
|
save_path, # Use full path instead of separate dir and filename
|
|
**download_kwargs,
|
|
)
|
|
|
|
if success:
|
|
break
|
|
|
|
last_error = result
|
|
if os.path.exists(save_path):
|
|
try:
|
|
os.remove(save_path)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to remove incomplete file {save_path}: {e}")
|
|
else:
|
|
# Clean up files on failure, but preserve .part file for resume
|
|
cleanup_files = [metadata_path]
|
|
preview_path_value = getattr(metadata, 'preview_url', None)
|
|
if preview_path_value and os.path.exists(preview_path_value):
|
|
cleanup_files.append(preview_path_value)
|
|
|
|
for path in cleanup_files:
|
|
if path and os.path.exists(path):
|
|
try:
|
|
os.remove(path)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to cleanup file {path}: {e}")
|
|
|
|
# Log but don't remove .part file to allow resume
|
|
if os.path.exists(part_path):
|
|
logger.info(f"Preserving partial download for resume: {part_path}")
|
|
|
|
return {'success': False, 'error': last_error or 'Failed to download file'}
|
|
|
|
# 4. Handle archive extraction and prepare per-file metadata
|
|
actual_file_paths = [save_path]
|
|
if zipfile.is_zipfile(save_path):
|
|
extracted_paths = await self._extract_safetensors_from_archive(save_path)
|
|
if not extracted_paths:
|
|
return {'success': False, 'error': 'Zip archive does not contain any safetensors files'}
|
|
actual_file_paths = extracted_paths
|
|
try:
|
|
os.remove(save_path)
|
|
except OSError as exc:
|
|
logger.warning(f"Unable to delete temporary archive {save_path}: {exc}")
|
|
if download_id and download_id in self._active_downloads:
|
|
self._active_downloads[download_id]['file_path'] = extracted_paths[0]
|
|
self._active_downloads[download_id]['extracted_paths'] = extracted_paths
|
|
|
|
metadata_entries = await self._build_metadata_entries(metadata, actual_file_paths)
|
|
if preview_path:
|
|
preview_targets = self._distribute_preview_to_entries(preview_path, metadata_entries)
|
|
for entry, target in zip(metadata_entries, preview_targets):
|
|
entry.preview_url = target.replace(os.sep, "/")
|
|
entry.preview_nsfw_level = preview_nsfw_level
|
|
if download_id and download_id in self._active_downloads and preview_targets:
|
|
self._active_downloads[download_id]["preview_path"] = preview_targets[0]
|
|
|
|
scanner = None
|
|
if model_type == "checkpoint":
|
|
scanner = await self._get_checkpoint_scanner()
|
|
logger.info(f"Updating checkpoint cache for {actual_file_paths[0]}")
|
|
elif model_type == "lora":
|
|
scanner = await self._get_lora_scanner()
|
|
logger.info(f"Updating lora cache for {actual_file_paths[0]}")
|
|
elif model_type == "embedding":
|
|
scanner = await ServiceRegistry.get_embedding_scanner()
|
|
logger.info(f"Updating embedding cache for {actual_file_paths[0]}")
|
|
|
|
adjust_cached_entry = (
|
|
getattr(scanner, "adjust_cached_entry", None) if scanner is not None else None
|
|
)
|
|
|
|
for index, entry in enumerate(metadata_entries):
|
|
file_path_for_adjust = getattr(entry, "file_path", actual_file_paths[index])
|
|
normalized_file_path = (
|
|
file_path_for_adjust.replace(os.sep, "/")
|
|
if isinstance(file_path_for_adjust, str)
|
|
else str(file_path_for_adjust)
|
|
)
|
|
|
|
if scanner is not None:
|
|
find_root = getattr(scanner, "_find_root_for_file", None)
|
|
adjust_root = None
|
|
if callable(find_root):
|
|
try:
|
|
adjust_root = find_root(normalized_file_path)
|
|
except TypeError:
|
|
adjust_root = None
|
|
|
|
adjust_metadata = getattr(scanner, "adjust_metadata", None)
|
|
if callable(adjust_metadata):
|
|
adjusted_entry = adjust_metadata(entry, normalized_file_path, adjust_root)
|
|
if adjusted_entry is not None:
|
|
entry = adjusted_entry
|
|
metadata_entries[index] = entry
|
|
|
|
metadata_file_path = os.path.splitext(entry.file_path)[0] + '.metadata.json'
|
|
metadata_files_for_cleanup.append(metadata_file_path)
|
|
|
|
await MetadataManager.save_metadata(entry.file_path, entry)
|
|
|
|
metadata_dict = entry.to_dict()
|
|
if callable(adjust_cached_entry):
|
|
metadata_dict = adjust_cached_entry(metadata_dict)
|
|
|
|
if scanner is not None:
|
|
await scanner.add_model_to_cache(metadata_dict, relative_path)
|
|
|
|
# Report 100% completion
|
|
if progress_callback:
|
|
await progress_callback(100)
|
|
|
|
return {'success': True}
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error in _execute_download: {e}", exc_info=True)
|
|
cleanup_targets = {
|
|
path
|
|
for path in [save_path, metadata_path, *metadata_files_for_cleanup, *extracted_paths]
|
|
if path
|
|
}
|
|
preview_candidate = (
|
|
metadata_entries[0].preview_url
|
|
if metadata_entries
|
|
else getattr(metadata, "preview_url", None)
|
|
)
|
|
if preview_candidate:
|
|
cleanup_targets.add(preview_candidate)
|
|
|
|
cleanup_targets.update(preview_targets)
|
|
for path in cleanup_targets:
|
|
if path and os.path.exists(path):
|
|
try:
|
|
os.remove(path)
|
|
except Exception as exc:
|
|
logger.warning(f"Failed to cleanup file {path}: {exc}")
|
|
|
|
return {'success': False, 'error': str(e)}
|
|
|
|
async def _extract_safetensors_from_archive(self, archive_path: str) -> List[str]:
|
|
if not zipfile.is_zipfile(archive_path):
|
|
return []
|
|
|
|
target_dir = os.path.dirname(archive_path)
|
|
|
|
def _extract_sync() -> List[str]:
|
|
extracted_files: List[str] = []
|
|
with zipfile.ZipFile(archive_path, "r") as archive:
|
|
for info in archive.infolist():
|
|
if info.is_dir():
|
|
continue
|
|
if not info.filename.lower().endswith(".safetensors"):
|
|
continue
|
|
file_name = os.path.basename(info.filename)
|
|
if not file_name:
|
|
continue
|
|
dest_path = self._resolve_extracted_destination(target_dir, file_name)
|
|
with archive.open(info) as source, open(dest_path, "wb") as target:
|
|
shutil.copyfileobj(source, target)
|
|
extracted_files.append(dest_path)
|
|
return extracted_files
|
|
|
|
return await asyncio.to_thread(_extract_sync)
|
|
|
|
async def _build_metadata_entries(self, base_metadata, file_paths: List[str]) -> List:
|
|
if not file_paths:
|
|
return []
|
|
|
|
entries: List = []
|
|
for index, file_path in enumerate(file_paths):
|
|
entry = base_metadata if index == 0 else copy.deepcopy(base_metadata)
|
|
entry.update_file_info(file_path)
|
|
entry.sha256 = await calculate_sha256(file_path)
|
|
entries.append(entry)
|
|
|
|
return entries
|
|
|
|
def _resolve_extracted_destination(self, target_dir: str, filename: str) -> str:
|
|
base_name, extension = os.path.splitext(filename)
|
|
candidate = filename
|
|
destination = os.path.join(target_dir, candidate)
|
|
counter = 1
|
|
|
|
while os.path.exists(destination):
|
|
candidate = f"{base_name}-{counter}{extension}"
|
|
destination = os.path.join(target_dir, candidate)
|
|
counter += 1
|
|
|
|
return destination
|
|
|
|
def _distribute_preview_to_entries(self, preview_path: str, entries: List) -> List[str]:
|
|
if not preview_path or not entries:
|
|
return []
|
|
|
|
if not os.path.exists(preview_path):
|
|
return []
|
|
|
|
extension = os.path.splitext(preview_path)[1] or ".webp"
|
|
|
|
targets = [
|
|
os.path.splitext(entry.file_path)[0] + extension for entry in entries
|
|
]
|
|
|
|
if not targets:
|
|
return []
|
|
|
|
first_target = targets[0]
|
|
if preview_path != first_target:
|
|
os.replace(preview_path, first_target)
|
|
source_path = first_target
|
|
|
|
for target in targets[1:]:
|
|
shutil.copyfile(source_path, target)
|
|
|
|
return targets
|
|
|
|
async def _handle_download_progress(
|
|
self,
|
|
progress_update,
|
|
progress_callback,
|
|
snapshot=None,
|
|
):
|
|
"""Convert file download progress to overall progress."""
|
|
|
|
if not progress_callback:
|
|
return
|
|
|
|
file_progress, original_snapshot = self._normalize_progress(progress_update, snapshot)
|
|
overall_progress = 3 + (file_progress * 0.97)
|
|
overall_progress = max(0.0, min(overall_progress, 100.0))
|
|
rounded_progress = round(overall_progress)
|
|
|
|
normalized_snapshot: Optional[DownloadProgress] = None
|
|
if original_snapshot is not None:
|
|
normalized_snapshot = DownloadProgress(
|
|
percent_complete=overall_progress,
|
|
bytes_downloaded=original_snapshot.bytes_downloaded,
|
|
total_bytes=original_snapshot.total_bytes,
|
|
bytes_per_second=original_snapshot.bytes_per_second,
|
|
timestamp=original_snapshot.timestamp,
|
|
)
|
|
|
|
await self._dispatch_progress(progress_callback, normalized_snapshot, rounded_progress)
|
|
|
|
async def cancel_download(self, download_id: str) -> Dict:
|
|
"""Cancel an active download by download_id
|
|
|
|
Args:
|
|
download_id: The unique identifier of the download task
|
|
|
|
Returns:
|
|
Dict: Status of the cancellation operation
|
|
"""
|
|
if download_id not in self._download_tasks:
|
|
return {'success': False, 'error': 'Download task not found'}
|
|
|
|
try:
|
|
# Get the task and cancel it
|
|
task = self._download_tasks[download_id]
|
|
task.cancel()
|
|
|
|
pause_control = self._pause_events.get(download_id)
|
|
if pause_control is not None:
|
|
pause_control.resume()
|
|
|
|
# Update status in active downloads
|
|
if download_id in self._active_downloads:
|
|
self._active_downloads[download_id]['status'] = 'cancelling'
|
|
self._active_downloads[download_id]['bytes_per_second'] = 0.0
|
|
|
|
# Wait briefly for the task to acknowledge cancellation
|
|
try:
|
|
await asyncio.wait_for(asyncio.shield(task), timeout=2.0)
|
|
except (asyncio.CancelledError, asyncio.TimeoutError):
|
|
pass
|
|
|
|
# Clean up ALL files including .part when user cancels
|
|
download_info = self._active_downloads.get(download_id)
|
|
if download_info:
|
|
target_files = set()
|
|
primary_path = download_info.get('file_path')
|
|
if primary_path:
|
|
target_files.add(primary_path)
|
|
|
|
for extra_path in download_info.get('extracted_paths', []):
|
|
if extra_path:
|
|
target_files.add(extra_path)
|
|
|
|
for file_path in target_files:
|
|
if os.path.exists(file_path):
|
|
try:
|
|
os.unlink(file_path)
|
|
logger.debug(f"Deleted cancelled download: {file_path}")
|
|
except Exception as e:
|
|
logger.error(f"Error deleting file: {e}")
|
|
|
|
# Delete the .part file (only on user cancellation)
|
|
if 'part_path' in download_info:
|
|
part_path = download_info['part_path']
|
|
if os.path.exists(part_path):
|
|
try:
|
|
os.unlink(part_path)
|
|
logger.debug(f"Deleted partial download: {part_path}")
|
|
except Exception as e:
|
|
logger.error(f"Error deleting part file: {e}")
|
|
|
|
# Delete metadata files for each resolved path
|
|
for file_path in target_files:
|
|
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
|
if os.path.exists(metadata_path):
|
|
try:
|
|
os.unlink(metadata_path)
|
|
except Exception as e:
|
|
logger.error(f"Error deleting metadata file: {e}")
|
|
|
|
preview_path_value = download_info.get('preview_path')
|
|
if preview_path_value and os.path.exists(preview_path_value):
|
|
try:
|
|
os.unlink(preview_path_value)
|
|
logger.debug(f"Deleted preview file: {preview_path_value}")
|
|
except Exception as e:
|
|
logger.error(f"Error deleting preview file: {preview_path_value}")
|
|
|
|
# Delete preview file if exists (.webp or .mp4) for legacy paths
|
|
for file_path in target_files:
|
|
for preview_ext in ['.webp', '.mp4']:
|
|
preview_path = os.path.splitext(file_path)[0] + preview_ext
|
|
if os.path.exists(preview_path):
|
|
try:
|
|
os.unlink(preview_path)
|
|
logger.debug(f"Deleted preview file: {preview_path}")
|
|
except Exception as e:
|
|
logger.error(f"Error deleting preview file: {preview_path}")
|
|
return {'success': True, 'message': 'Download cancelled successfully'}
|
|
except Exception as e:
|
|
logger.error(f"Error cancelling download: {e}", exc_info=True)
|
|
return {'success': False, 'error': str(e)}
|
|
finally:
|
|
self._pause_events.pop(download_id, None)
|
|
|
|
async def pause_download(self, download_id: str) -> Dict:
|
|
"""Pause an active download without losing progress."""
|
|
|
|
if download_id not in self._download_tasks:
|
|
return {'success': False, 'error': 'Download task not found'}
|
|
|
|
pause_control = self._pause_events.get(download_id)
|
|
if pause_control is None:
|
|
return {'success': False, 'error': 'Download task not found'}
|
|
|
|
if pause_control.is_paused():
|
|
return {'success': False, 'error': 'Download is already paused'}
|
|
|
|
pause_control.pause()
|
|
|
|
download_info = self._active_downloads.get(download_id)
|
|
if download_info is not None:
|
|
download_info['status'] = 'paused'
|
|
download_info['bytes_per_second'] = 0.0
|
|
|
|
return {'success': True, 'message': 'Download paused successfully'}
|
|
|
|
async def resume_download(self, download_id: str) -> Dict:
|
|
"""Resume a previously paused download."""
|
|
|
|
pause_control = self._pause_events.get(download_id)
|
|
if pause_control is None:
|
|
return {'success': False, 'error': 'Download task not found'}
|
|
|
|
if pause_control.is_set():
|
|
return {'success': False, 'error': 'Download is not paused'}
|
|
|
|
download_info = self._active_downloads.get(download_id)
|
|
force_reconnect = False
|
|
if pause_control is not None:
|
|
elapsed = pause_control.time_since_last_progress()
|
|
threshold = max(30.0, pause_control.stall_timeout / 2.0)
|
|
if elapsed is not None and elapsed >= threshold:
|
|
force_reconnect = True
|
|
logger.info(
|
|
"Forcing reconnect for download %s after %.1f seconds without progress",
|
|
download_id,
|
|
elapsed,
|
|
)
|
|
|
|
pause_control.resume(force_reconnect=force_reconnect)
|
|
|
|
if download_info is not None:
|
|
if download_info.get('status') == 'paused':
|
|
download_info['status'] = 'downloading'
|
|
download_info.setdefault('bytes_per_second', 0.0)
|
|
|
|
return {'success': True, 'message': 'Download resumed successfully'}
|
|
|
|
@staticmethod
|
|
def _coerce_progress_value(progress) -> float:
|
|
try:
|
|
return float(progress)
|
|
except (TypeError, ValueError):
|
|
return 0.0
|
|
|
|
@classmethod
|
|
def _normalize_progress(
|
|
cls,
|
|
progress,
|
|
snapshot: Optional[DownloadProgress] = None,
|
|
) -> Tuple[float, Optional[DownloadProgress]]:
|
|
if isinstance(progress, DownloadProgress):
|
|
return progress.percent_complete, progress
|
|
|
|
if isinstance(snapshot, DownloadProgress):
|
|
return snapshot.percent_complete, snapshot
|
|
|
|
if isinstance(progress, dict):
|
|
if 'percent_complete' in progress:
|
|
return cls._coerce_progress_value(progress['percent_complete']), snapshot
|
|
if 'progress' in progress:
|
|
return cls._coerce_progress_value(progress['progress']), snapshot
|
|
|
|
return cls._coerce_progress_value(progress), None
|
|
|
|
async def _dispatch_progress(
|
|
self,
|
|
callback,
|
|
snapshot: Optional[DownloadProgress],
|
|
progress_value: float,
|
|
) -> None:
|
|
try:
|
|
if snapshot is not None:
|
|
result = callback(snapshot, snapshot)
|
|
else:
|
|
result = callback(progress_value)
|
|
except TypeError:
|
|
result = callback(progress_value)
|
|
|
|
if inspect.isawaitable(result):
|
|
await result
|
|
elif asyncio.iscoroutine(result):
|
|
await result
|
|
|
|
async def get_active_downloads(self) -> Dict:
|
|
"""Get information about all active downloads
|
|
|
|
Returns:
|
|
Dict: List of active downloads and their status
|
|
"""
|
|
return {
|
|
'downloads': [
|
|
{
|
|
'download_id': task_id,
|
|
'model_id': info.get('model_id'),
|
|
'model_version_id': info.get('model_version_id'),
|
|
'progress': info.get('progress', 0),
|
|
'status': info.get('status', 'unknown'),
|
|
'error': info.get('error', None),
|
|
'bytes_downloaded': info.get('bytes_downloaded', 0),
|
|
'total_bytes': info.get('total_bytes'),
|
|
'bytes_per_second': info.get('bytes_per_second', 0.0),
|
|
}
|
|
for task_id, info in self._active_downloads.items()
|
|
]
|
|
}
|