diff --git a/py/services/download_manager.py b/py/services/download_manager.py index 1715937b..a9018662 100644 --- a/py/services/download_manager.py +++ b/py/services/download_manager.py @@ -19,7 +19,7 @@ from ..utils.file_utils import calculate_sha256 from ..utils.metadata_manager import MetadataManager from .service_registry import ServiceRegistry from .settings_manager import get_settings_manager -from .metadata_service import get_default_metadata_provider +from .metadata_service import get_default_metadata_provider, get_metadata_provider from .downloader import get_downloader, DownloadProgress, DownloadStreamControl # Download to temporary file first @@ -27,10 +27,11 @@ import tempfile logger = logging.getLogger(__name__) + class DownloadManager: _instance = None _lock = asyncio.Lock() - + @classmethod async def get_instance(cls): """Get singleton instance of DownloadManager""" @@ -41,30 +42,37 @@ class DownloadManager: def __init__(self): # Check if already initialized for singleton pattern - if hasattr(self, '_initialized'): + if hasattr(self, "_initialized"): return self._initialized = True - + # Add download management self._active_downloads = OrderedDict() # download_id -> download_info self._download_semaphore = asyncio.Semaphore(5) # Limit concurrent downloads self._download_tasks = {} # download_id -> asyncio.Task self._pause_events: Dict[str, DownloadStreamControl] = {} - + async def _get_lora_scanner(self): """Get the lora scanner from registry""" return await ServiceRegistry.get_lora_scanner() - + async def _get_checkpoint_scanner(self): """Get the checkpoint scanner from registry""" return await ServiceRegistry.get_checkpoint_scanner() - - async def download_from_civitai(self, model_id: int = None, model_version_id: int = None, - save_dir: str = None, relative_path: str = '', - progress_callback=None, use_default_paths: bool = False, - download_id: str = None, source: str = None) -> Dict: + + async def download_from_civitai( + self, + model_id: int = None, + model_version_id: int = None, + save_dir: str = None, + relative_path: str = "", + progress_callback=None, + use_default_paths: bool = False, + download_id: str = None, + source: str = None, + ) -> Dict: """Download model from Civitai with task tracking and concurrency control - + Args: model_id: Civitai model ID (optional if model_version_id is provided) model_version_id: Civitai model version ID (optional if model_id is provided) @@ -74,84 +82,109 @@ class DownloadManager: use_default_paths: Flag to use default paths download_id: Unique identifier for this download task source: Optional source parameter to specify metadata provider - + Returns: Dict with download result """ # Validate that at least one identifier is provided if not model_id and not model_version_id: - return {'success': False, 'error': 'Either model_id or model_version_id must be provided'} - + return { + "success": False, + "error": "Either model_id or model_version_id must be provided", + } + # Use provided download_id or generate new one task_id = download_id or str(uuid.uuid4()) - + # Register download task in tracking dict self._active_downloads[task_id] = { - 'model_id': model_id, - 'model_version_id': model_version_id, - 'progress': 0, - 'status': 'queued', - 'bytes_downloaded': 0, - 'total_bytes': None, - 'bytes_per_second': 0.0, - 'last_progress_timestamp': None, + "model_id": model_id, + "model_version_id": model_version_id, + "progress": 0, + "status": "queued", + "bytes_downloaded": 0, + "total_bytes": None, + "bytes_per_second": 0.0, + 
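
# A minimal, self-contained sketch of the concurrency pattern this patch keeps:
# an OrderedDict task registry plus an asyncio.Semaphore capping parallel
# downloads at five. Names here (TinyDownloadPool, fake transfer) are
# illustrative only, not part of the patch.
import asyncio
import uuid
from collections import OrderedDict


class TinyDownloadPool:
    def __init__(self, max_concurrent: int = 5):
        self.active = OrderedDict()          # download_id -> info dict
        self._sem = asyncio.Semaphore(max_concurrent)

    async def download(self, url: str) -> dict:
        task_id = str(uuid.uuid4())
        self.active[task_id] = {"status": "queued", "progress": 0}
        async with self._sem:                # at most five bodies run at once
            self.active[task_id]["status"] = "downloading"
            await asyncio.sleep(0.01)        # stand-in for the real transfer
            self.active[task_id]["status"] = "completed"
        return {"success": True, "download_id": task_id}


async def _demo():
    pool = TinyDownloadPool()
    results = await asyncio.gather(*(pool.download(f"u{i}") for i in range(8)))
    assert all(r["success"] for r in results)

asyncio.run(_demo())
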
"last_progress_timestamp": None, } pause_control = DownloadStreamControl() self._pause_events[task_id] = pause_control - + # Create tracking task download_task = asyncio.create_task( self._download_with_semaphore( - task_id, model_id, model_version_id, save_dir, - relative_path, progress_callback, use_default_paths, source + task_id, + model_id, + model_version_id, + save_dir, + relative_path, + progress_callback, + use_default_paths, + source, ) ) - + # Store task for tracking and cancellation self._download_tasks[task_id] = download_task - + try: # Wait for download to complete result = await download_task - result['download_id'] = task_id # Include download_id in result + result["download_id"] = task_id # Include download_id in result return result except asyncio.CancelledError: - return {'success': False, 'error': 'Download was cancelled', 'download_id': task_id} + return { + "success": False, + "error": "Download was cancelled", + "download_id": task_id, + } finally: # Clean up task reference if task_id in self._download_tasks: del self._download_tasks[task_id] self._pause_events.pop(task_id, None) - async def _download_with_semaphore(self, task_id: str, model_id: int, model_version_id: int, - save_dir: str, relative_path: str, - progress_callback=None, use_default_paths: bool = False, - source: str = None): + async def _download_with_semaphore( + self, + task_id: str, + model_id: int, + model_version_id: int, + save_dir: str, + relative_path: str, + progress_callback=None, + use_default_paths: bool = False, + source: str = None, + ): """Execute download with semaphore to limit concurrency""" # Update status to waiting if task_id in self._active_downloads: - self._active_downloads[task_id]['status'] = 'waiting' - + self._active_downloads[task_id]["status"] = "waiting" + # Wrap progress callback to track progress in active_downloads original_callback = progress_callback + async def tracking_callback(progress, metrics=None): progress_value, snapshot = self._normalize_progress(progress, metrics) if task_id in self._active_downloads: info = self._active_downloads[task_id] - info['progress'] = round(progress_value) + info["progress"] = round(progress_value) if snapshot is not None: - info['bytes_downloaded'] = snapshot.bytes_downloaded - info['total_bytes'] = snapshot.total_bytes - info['bytes_per_second'] = snapshot.bytes_per_second + info["bytes_downloaded"] = snapshot.bytes_downloaded + info["total_bytes"] = snapshot.total_bytes + info["bytes_per_second"] = snapshot.bytes_per_second pause_control = self._pause_events.get(task_id) if isinstance(pause_control, DownloadStreamControl): pause_control.mark_progress(snapshot.timestamp) - info['last_progress_timestamp'] = pause_control.last_progress_timestamp + info["last_progress_timestamp"] = ( + pause_control.last_progress_timestamp + ) if original_callback: - await self._dispatch_progress(original_callback, snapshot, progress_value) + await self._dispatch_progress( + original_callback, snapshot, progress_value + ) # Acquire semaphore to limit concurrent downloads try: @@ -159,49 +192,60 @@ class DownloadManager: pause_control = self._pause_events.get(task_id) if pause_control is not None and pause_control.is_paused(): if task_id in self._active_downloads: - self._active_downloads[task_id]['status'] = 'paused' - self._active_downloads[task_id]['bytes_per_second'] = 0.0 + self._active_downloads[task_id]["status"] = "paused" + self._active_downloads[task_id]["bytes_per_second"] = 0.0 await pause_control.wait() # Update status to downloading if 
task_id in self._active_downloads: - self._active_downloads[task_id]['status'] = 'downloading' - + self._active_downloads[task_id]["status"] = "downloading" + # Use original download implementation try: # Check for cancellation before starting if asyncio.current_task().cancelled(): raise asyncio.CancelledError() - + result = await self._execute_original_download( - model_id, model_version_id, save_dir, - relative_path, tracking_callback, use_default_paths, - task_id, source + model_id, + model_version_id, + save_dir, + relative_path, + tracking_callback, + use_default_paths, + task_id, + source, ) - + # Update status based on result if task_id in self._active_downloads: - self._active_downloads[task_id]['status'] = 'completed' if result['success'] else 'failed' - if not result['success']: - self._active_downloads[task_id]['error'] = result.get('error', 'Unknown error') - self._active_downloads[task_id]['bytes_per_second'] = 0.0 - + self._active_downloads[task_id]["status"] = ( + "completed" if result["success"] else "failed" + ) + if not result["success"]: + self._active_downloads[task_id]["error"] = result.get( + "error", "Unknown error" + ) + self._active_downloads[task_id]["bytes_per_second"] = 0.0 + return result except asyncio.CancelledError: # Handle cancellation if task_id in self._active_downloads: - self._active_downloads[task_id]['status'] = 'cancelled' - self._active_downloads[task_id]['bytes_per_second'] = 0.0 + self._active_downloads[task_id]["status"] = "cancelled" + self._active_downloads[task_id]["bytes_per_second"] = 0.0 logger.info(f"Download cancelled for task {task_id}") raise except Exception as e: # Handle other errors - logger.error(f"Download error for task {task_id}: {str(e)}", exc_info=True) + logger.error( + f"Download error for task {task_id}: {str(e)}", exc_info=True + ) if task_id in self._active_downloads: - self._active_downloads[task_id]['status'] = 'failed' - self._active_downloads[task_id]['error'] = str(e) - self._active_downloads[task_id]['bytes_per_second'] = 0.0 - return {'success': False, 'error': str(e)} + self._active_downloads[task_id]["status"] = "failed" + self._active_downloads[task_id]["error"] = str(e) + self._active_downloads[task_id]["bytes_per_second"] = 0.0 + return {"success": False, "error": str(e)} finally: # Schedule cleanup of download record after delay asyncio.create_task(self._cleanup_download_record(task_id)) @@ -231,75 +275,130 @@ class DownloadManager: lora_scanner = await self._get_lora_scanner() checkpoint_scanner = await self._get_checkpoint_scanner() embedding_scanner = await ServiceRegistry.get_embedding_scanner() - + # Check lora scanner first if await lora_scanner.check_model_version_exists(model_version_id): - return {'success': False, 'error': 'Model version already exists in lora library'} - + return { + "success": False, + "error": "Model version already exists in lora library", + } + # Check checkpoint scanner - if await checkpoint_scanner.check_model_version_exists(model_version_id): - return {'success': False, 'error': 'Model version already exists in checkpoint library'} - + if await checkpoint_scanner.check_model_version_exists( + model_version_id + ): + return { + "success": False, + "error": "Model version already exists in checkpoint library", + } + # Check embedding scanner if await embedding_scanner.check_model_version_exists(model_version_id): - return {'success': False, 'error': 'Model version already exists in embedding library'} - - metadata_provider = await get_default_metadata_provider() + return { + 
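
# A sketch of the pause gate the manager waits on before entering the
# "downloading" state. The real DownloadStreamControl lives in .downloader;
# the event-based behaviour below is an assumption about its contract
# (pause()/resume()/wait()/is_paused()), not a copy of its code.
import asyncio


class PauseGate:
    def __init__(self):
        self._event = asyncio.Event()
        self._event.set()                  # not paused initially

    def pause(self):
        self._event.clear()

    def resume(self):
        self._event.set()

    def is_paused(self) -> bool:
        return not self._event.is_set()

    async def wait(self):
        await self._event.wait()           # blocks only while paused


async def _demo():
    gate = PauseGate()
    gate.pause()
    asyncio.get_running_loop().call_later(0.05, gate.resume)
    await gate.wait()                      # returns once resume() fires

asyncio.run(_demo())
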
"success": False, + "error": "Model version already exists in embedding library", + } + + # Use CivArchive provider directly when source is 'civarchive' + # This prioritizes CivArchive metadata (with mirror availability info) over Civitai + if source == "civarchive": + metadata_provider = await get_metadata_provider("civarchive_api") + if not metadata_provider: + logger.warning( + "CivArchive provider not available, falling back to default provider" + ) + metadata_provider = await get_default_metadata_provider() + else: + metadata_provider = await get_default_metadata_provider() # Get version info based on the provided identifier - version_info = await metadata_provider.get_model_version(model_id, model_version_id) - - if not version_info: - return {'success': False, 'error': 'Failed to fetch model metadata'} + version_info = await metadata_provider.get_model_version( + model_id, model_version_id + ) - model_type_from_info = version_info.get('model', {}).get('type', '').lower() - if model_type_from_info == 'checkpoint': - model_type = 'checkpoint' + if not version_info: + # If CivArchive provider failed and source was 'civarchive', try default provider as fallback + if source == "civarchive": + logger.info( + "CivArchive metadata fetch failed, trying default provider" + ) + metadata_provider = await get_default_metadata_provider() + version_info = await metadata_provider.get_model_version( + model_id, model_version_id + ) + + if not version_info: + return {"success": False, "error": "Failed to fetch model metadata"} + + model_type_from_info = version_info.get("model", {}).get("type", "").lower() + if model_type_from_info == "checkpoint": + model_type = "checkpoint" elif model_type_from_info in VALID_LORA_TYPES: - model_type = 'lora' - elif model_type_from_info == 'textualinversion': - model_type = 'embedding' + model_type = "lora" + elif model_type_from_info == "textualinversion": + model_type = "embedding" else: - return {'success': False, 'error': f'Model type "{model_type_from_info}" is not supported for download'} - + return { + "success": False, + "error": f'Model type "{model_type_from_info}" is not supported for download', + } + # Case 2: model_version_id was None, check after getting version_info if model_version_id is None: - version_id = version_info.get('id') - - if model_type == 'lora': + version_id = version_info.get("id") + + if model_type == "lora": # Check lora scanner lora_scanner = await self._get_lora_scanner() if await lora_scanner.check_model_version_exists(version_id): - return {'success': False, 'error': 'Model version already exists in lora library'} - elif model_type == 'checkpoint': + return { + "success": False, + "error": "Model version already exists in lora library", + } + elif model_type == "checkpoint": # Check checkpoint scanner checkpoint_scanner = await self._get_checkpoint_scanner() if await checkpoint_scanner.check_model_version_exists(version_id): - return {'success': False, 'error': 'Model version already exists in checkpoint library'} - elif model_type == 'embedding': + return { + "success": False, + "error": "Model version already exists in checkpoint library", + } + elif model_type == "embedding": # Embeddings are not checked in scanners, but we can still check if it exists embedding_scanner = await ServiceRegistry.get_embedding_scanner() if await embedding_scanner.check_model_version_exists(version_id): - return {'success': False, 'error': 'Model version already exists in embedding library'} - + return { + "success": False, + "error": "Model 
version already exists in embedding library", + } + # Handle use_default_paths if use_default_paths: settings_manager = get_settings_manager() # Set save_dir based on model type - if model_type == 'checkpoint': - default_path = settings_manager.get('default_checkpoint_root') + if model_type == "checkpoint": + default_path = settings_manager.get("default_checkpoint_root") if not default_path: - return {'success': False, 'error': 'Default checkpoint root path not set in settings'} + return { + "success": False, + "error": "Default checkpoint root path not set in settings", + } save_dir = default_path - elif model_type == 'lora': - default_path = settings_manager.get('default_lora_root') + elif model_type == "lora": + default_path = settings_manager.get("default_lora_root") if not default_path: - return {'success': False, 'error': 'Default lora root path not set in settings'} + return { + "success": False, + "error": "Default lora root path not set in settings", + } save_dir = default_path - elif model_type == 'embedding': - default_path = settings_manager.get('default_embedding_root') + elif model_type == "embedding": + default_path = settings_manager.get("default_embedding_root") if not default_path: - return {'success': False, 'error': 'Default embedding root path not set in settings'} + return { + "success": False, + "error": "Default embedding root path not set in settings", + } save_dir = default_path # Calculate relative path using template @@ -312,61 +411,98 @@ class DownloadManager: os.makedirs(save_dir, exist_ok=True) # Check if this is an early access model - if version_info.get('earlyAccessEndsAt'): - early_access_date = version_info.get('earlyAccessEndsAt', '') + if version_info.get("earlyAccessEndsAt"): + early_access_date = version_info.get("earlyAccessEndsAt", "") # Convert to a readable date if possible try: from datetime import datetime - date_obj = datetime.fromisoformat(early_access_date.replace('Z', '+00:00')) - formatted_date = date_obj.strftime('%Y-%m-%d') - early_access_msg = f"This model requires payment (until {formatted_date}). " + + date_obj = datetime.fromisoformat( + early_access_date.replace("Z", "+00:00") + ) + formatted_date = date_obj.strftime("%Y-%m-%d") + early_access_msg = ( + f"This model requires payment (until {formatted_date}). " + ) except: early_access_msg = "This model requires payment. " - + early_access_msg += "Please ensure you have purchased early access and are logged in to Civitai." - logger.warning(f"Early access model detected: {version_info.get('name', 'Unknown')}") - + logger.warning( + f"Early access model detected: {version_info.get('name', 'Unknown')}" + ) + # We'll still try to download, but log a warning and prepare for potential failure if progress_callback: - await progress_callback(1) # Show minimal progress to indicate we're trying + await progress_callback( + 1 + ) # Show minimal progress to indicate we're trying # Report initial progress if progress_callback: await progress_callback(0) # 2. 
Get file information - file_info = next((f for f in version_info.get('files', []) if f.get('primary') and f.get('type') in ('Model', 'Negative')), None) + file_info = next( + ( + f + for f in version_info.get("files", []) + if f.get("primary") and f.get("type") in ("Model", "Negative") + ), + None, + ) if not file_info: - return {'success': False, 'error': 'No primary file found in metadata'} - mirrors = file_info.get('mirrors') or [] + return {"success": False, "error": "No primary file found in metadata"} + mirrors = file_info.get("mirrors") or [] download_urls = [] if mirrors: for mirror in mirrors: - if mirror.get('deletedAt') is None and mirror.get('url'): - download_urls.append(mirror['url']) + if mirror.get("deletedAt") is None and mirror.get("url"): + download_urls.append(mirror["url"]) + + # When source is 'civarchive', prioritize non-Civitai URLs + # This avoids failed downloads from deleted Civitai models + if source == "civarchive" and len(download_urls) > 1: + civitai_urls = [ + u + for u in download_urls + if u.startswith("https://civitai.com/api/download/") + ] + non_civitai_urls = [ + u + for u in download_urls + if not u.startswith("https://civitai.com/api/download/") + ] + download_urls = non_civitai_urls + civitai_urls else: - download_url = file_info.get('downloadUrl') + download_url = file_info.get("downloadUrl") if download_url: download_urls.append(download_url) if not download_urls: - return {'success': False, 'error': 'No mirror URL found'} + return {"success": False, "error": "No mirror URL found"} # 3. Prepare download - file_name = file_info['name'] + file_name = file_info["name"] save_path = os.path.join(save_dir, file_name) # 5. Prepare metadata based on model type if model_type == "checkpoint": - metadata = CheckpointMetadata.from_civitai_info(version_info, file_info, save_path) + metadata = CheckpointMetadata.from_civitai_info( + version_info, file_info, save_path + ) logger.info(f"Creating CheckpointMetadata for {file_name}") elif model_type == "lora": - metadata = LoraMetadata.from_civitai_info(version_info, file_info, save_path) + metadata = LoraMetadata.from_civitai_info( + version_info, file_info, save_path + ) logger.info(f"Creating LoraMetadata for {file_name}") elif model_type == "embedding": - metadata = EmbeddingMetadata.from_civitai_info(version_info, file_info, save_path) + metadata = EmbeddingMetadata.from_civitai_info( + version_info, file_info, save_path + ) logger.info(f"Creating EmbeddingMetadata for {file_name}") - + # 6. 
Start download process result = await self._execute_download( download_urls=download_urls, @@ -379,11 +515,11 @@ class DownloadManager: download_id=download_id, ) - if result.get('success', False): + if result.get("success", False): resolved_model_id = ( model_id - or version_info.get('modelId') - or (version_info.get('model') or {}).get('id') + or version_info.get("modelId") + or (version_info.get("model") or {}).get("id") ) await self._sync_downloaded_version( model_type, @@ -393,8 +529,8 @@ class DownloadManager: ) # If early_access_msg exists and download failed, replace error message - if 'early_access_msg' in locals() and not result.get('success', False): - result['error'] = early_access_msg + if "early_access_msg" in locals() and not result.get("success", False): + result["error"] = early_access_msg return result @@ -402,9 +538,17 @@ class DownloadManager: logger.error(f"Error in download_from_civitai: {e}", exc_info=True) # Check if this might be an early access error error_str = str(e).lower() - if "403" in error_str or "401" in error_str or "unauthorized" in error_str or "early access" in error_str: - return {'success': False, 'error': f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai."} - return {'success': False, 'error': str(e)} + if ( + "403" in error_str + or "401" in error_str + or "unauthorized" in error_str + or "early access" in error_str + ): + return { + "success": False, + "error": f"Early access restriction: {str(e)}. Please ensure you have purchased early access and are logged in to Civitai.", + } + return {"success": False, "error": str(e)} async def _sync_downloaded_version( self, @@ -418,7 +562,9 @@ class DownloadManager: try: update_service = await ServiceRegistry.get_model_update_service() except Exception as exc: - logger.debug("Skipping update sync; failed to acquire update service: %s", exc) + logger.debug( + "Skipping update sync; failed to acquire update service: %s", exc + ) return if update_service is None: @@ -426,18 +572,20 @@ class DownloadManager: resolved_model_id = model_id_value if resolved_model_id is None: - resolved_model_id = version_info.get('modelId') + resolved_model_id = version_info.get("modelId") if resolved_model_id is None: - model_info = version_info.get('model') + model_info = version_info.get("model") if isinstance(model_info, dict): - resolved_model_id = model_info.get('id') + resolved_model_id = model_info.get("id") try: resolved_model_id = int(resolved_model_id) except (TypeError, ValueError): - logger.debug("Skipping update sync; invalid model id: %s", resolved_model_id) + logger.debug( + "Skipping update sync; invalid model id: %s", resolved_model_id + ) return - version_id = version_info.get('id') + version_id = version_info.get("id") if version_id is None: version_id = fallback_version_id try: @@ -453,18 +601,20 @@ class DownloadManager: version_ids = set() scanner = None try: - if model_type == 'lora': + if model_type == "lora": scanner = await self._get_lora_scanner() - elif model_type == 'checkpoint': + elif model_type == "checkpoint": scanner = await self._get_checkpoint_scanner() - elif model_type == 'embedding': + elif model_type == "embedding": scanner = await ServiceRegistry.get_embedding_scanner() except Exception as exc: logger.debug("Failed to acquire scanner for %s models: %s", model_type, exc) if scanner is not None: try: - local_versions = await scanner.get_model_versions_by_id(resolved_model_id) + local_versions = await 
scanner.get_model_versions_by_id( + resolved_model_id + ) except Exception as exc: logger.debug( "Failed to collect local versions for %s model %s: %s", @@ -474,7 +624,7 @@ class DownloadManager: ) else: for entry in local_versions or []: - vid = entry.get('versionId') + vid = entry.get("versionId") try: version_ids.add(int(vid)) except (TypeError, ValueError): @@ -497,55 +647,63 @@ class DownloadManager: exc, ) - def _calculate_relative_path(self, version_info: Dict, model_type: str = 'lora') -> str: + def _calculate_relative_path( + self, version_info: Dict, model_type: str = "lora" + ) -> str: """Calculate relative path using template from settings - + Args: version_info: Version info from Civitai API model_type: Type of model ('lora', 'checkpoint', 'embedding') - + Returns: Relative path string """ # Get path template from settings for specific model type settings_manager = get_settings_manager() path_template = settings_manager.get_download_path_template(model_type) - + # If template is empty, return empty path (flat structure) if not path_template: - return '' - + return "" + # Get base model name - base_model = version_info.get('baseModel', '') - + base_model = version_info.get("baseModel", "") + # Get author from creator data - creator_info = version_info.get('creator') + creator_info = version_info.get("creator") if creator_info and isinstance(creator_info, dict): - author = creator_info.get('username') or 'Anonymous' + author = creator_info.get("username") or "Anonymous" else: - author = 'Anonymous' - + author = "Anonymous" + # Apply mapping if available - base_model_mappings = settings_manager.get('base_model_path_mappings', {}) + base_model_mappings = settings_manager.get("base_model_path_mappings", {}) mapped_base_model = base_model_mappings.get(base_model, base_model) - - model_info = version_info.get('model') or {} + + model_info = version_info.get("model") or {} # Get model tags - model_tags = model_info.get('tags', []) + model_tags = model_info.get("tags", []) - first_tag = settings_manager.resolve_priority_tag_for_model(model_tags, model_type) + first_tag = settings_manager.resolve_priority_tag_for_model( + model_tags, model_type + ) # Format the template with available data formatted_path = path_template - formatted_path = formatted_path.replace('{base_model}', mapped_base_model) - formatted_path = formatted_path.replace('{first_tag}', first_tag) - formatted_path = formatted_path.replace('{author}', author) - formatted_path = formatted_path.replace('{model_name}', sanitize_folder_name(model_info.get('name', ''))) - formatted_path = formatted_path.replace('{version_name}', sanitize_folder_name(version_info.get('name', ''))) + formatted_path = formatted_path.replace("{base_model}", mapped_base_model) + formatted_path = formatted_path.replace("{first_tag}", first_tag) + formatted_path = formatted_path.replace("{author}", author) + formatted_path = formatted_path.replace( + "{model_name}", sanitize_folder_name(model_info.get("name", "")) + ) + formatted_path = formatted_path.replace( + "{version_name}", sanitize_folder_name(version_info.get("name", "")) + ) - if model_type == 'embedding': - formatted_path = formatted_path.replace(' ', '_') + if model_type == "embedding": + formatted_path = formatted_path.replace(" ", "_") return formatted_path @@ -572,59 +730,58 @@ class DownloadManager: # Extract original filename details original_filename = os.path.basename(metadata.file_path) base_name, extension = os.path.splitext(original_filename) - + # Check for filename conflicts and 
generate unique filename if needed # Use the hash from metadata for conflict resolution def hash_provider(): return metadata.sha256 - + unique_filename = metadata.generate_unique_filename( - save_dir, - base_name, - extension, - hash_provider=hash_provider + save_dir, base_name, extension, hash_provider=hash_provider ) - + # Update paths if filename changed if unique_filename != original_filename: - logger.info(f"Filename conflict detected. Changing '{original_filename}' to '{unique_filename}'") + logger.info( + f"Filename conflict detected. Changing '{original_filename}' to '{unique_filename}'" + ) save_path = os.path.join(save_dir, unique_filename) # Update metadata with new file path and name - metadata.file_path = save_path.replace(os.sep, '/') + metadata.file_path = save_path.replace(os.sep, "/") metadata.file_name = os.path.splitext(unique_filename)[0] else: save_path = metadata.file_path - - part_path = save_path + '.part' - metadata_path = os.path.splitext(save_path)[0] + '.metadata.json' + + part_path = save_path + ".part" + metadata_path = os.path.splitext(save_path)[0] + ".metadata.json" pause_control = self._pause_events.get(download_id) if download_id else None # Store file paths in active_downloads for potential cleanup if download_id and download_id in self._active_downloads: - self._active_downloads[download_id]['file_path'] = save_path - self._active_downloads[download_id]['part_path'] = part_path + self._active_downloads[download_id]["file_path"] = save_path + self._active_downloads[download_id]["part_path"] = part_path # Download preview image if available - images = version_info.get('images', []) + images = version_info.get("images", []) if images: if progress_callback: - await progress_callback(1) # 1% progress for starting preview download + await progress_callback( + 1 + ) # 1% progress for starting preview download settings_manager = get_settings_manager() blur_mature_content = bool( - settings_manager.get('blur_mature_content', True) + settings_manager.get("blur_mature_content", True) ) selected_image, nsfw_level = select_preview_media( images, blur_mature_content=blur_mature_content, ) - preview_url = selected_image.get('url') if selected_image else None + preview_url = selected_image.get("url") if selected_image else None media_type = ( - (selected_image.get('type') or '').lower() - if selected_image - else '' + (selected_image.get("type") or "").lower() if selected_image else "" ) def _extension_from_url(url: str, fallback: str) -> str: @@ -641,10 +798,12 @@ class DownloadManager: if preview_url: downloader = await get_downloader() - if media_type == 'video': - preview_ext = _extension_from_url(preview_url, '.mp4') + if media_type == "video": + preview_ext = _extension_from_url(preview_url, ".mp4") preview_path = os.path.splitext(save_path)[0] + preview_ext - rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='video') + rewritten_url, rewritten = rewrite_preview_url( + preview_url, media_type="video" + ) attempt_urls: List[str] = [] if rewritten: attempt_urls.append(rewritten_url) @@ -656,22 +815,20 @@ class DownloadManager: continue seen_attempts.add(attempt) success, _ = await downloader.download_file( - attempt, - preview_path, - use_auth=False + attempt, preview_path, use_auth=False ) if success: preview_downloaded = True break else: - rewritten_url, rewritten = rewrite_preview_url(preview_url, media_type='image') + rewritten_url, rewritten = rewrite_preview_url( + preview_url, media_type="image" + ) if rewritten: - preview_ext = 
_extension_from_url(preview_url, '.png') + preview_ext = _extension_from_url(preview_url, ".png") preview_path = os.path.splitext(save_path)[0] + preview_ext success, _ = await downloader.download_file( - rewritten_url, - preview_path, - use_auth=False + rewritten_url, preview_path, use_auth=False ) if success: preview_downloaded = True @@ -679,27 +836,34 @@ class DownloadManager: if not preview_downloaded: temp_path: str | None = None try: - with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file: + with tempfile.NamedTemporaryFile( + suffix=".png", delete=False + ) as temp_file: temp_path = temp_file.name - success, content, _ = await downloader.download_to_memory( - preview_url, - use_auth=False + ( + success, + content, + _, + ) = await downloader.download_to_memory( + preview_url, use_auth=False ) if success: - with open(temp_path, 'wb') as temp_file_handle: + with open(temp_path, "wb") as temp_file_handle: temp_file_handle.write(content) - preview_path = os.path.splitext(save_path)[0] + '.webp' + preview_path = ( + os.path.splitext(save_path)[0] + ".webp" + ) optimized_data, _ = ExifUtils.optimize_image( image_data=temp_path, target_width=CARD_PREVIEW_WIDTH, - format='webp', + format="webp", quality=85, - preserve_metadata=False + preserve_metadata=False, ) - with open(preview_path, 'wb') as preview_file: + with open(preview_path, "wb") as preview_file: preview_file.write(optimized_data) preview_downloaded = True @@ -708,11 +872,13 @@ class DownloadManager: try: os.unlink(temp_path) except Exception as e: - logger.warning(f"Failed to delete temp file: {e}") + logger.warning( + f"Failed to delete temp file: {e}" + ) if preview_downloaded and preview_path: preview_nsfw_level = nsfw_level - metadata.preview_url = preview_path.replace(os.sep, '/') + metadata.preview_url = preview_path.replace(os.sep, "/") metadata.preview_nsfw_level = nsfw_level if progress_callback: @@ -726,7 +892,8 @@ class DownloadManager: for download_url in download_urls: use_auth = download_url.startswith("https://civitai.com/api/download/") download_kwargs = { - "progress_callback": lambda progress, snapshot=None: self._handle_download_progress( + "progress_callback": lambda progress, + snapshot=None: self._handle_download_progress( progress, progress_callback, snapshot, @@ -751,11 +918,13 @@ class DownloadManager: try: os.remove(save_path) except Exception as e: - logger.warning(f"Failed to remove incomplete file {save_path}: {e}") + logger.warning( + f"Failed to remove incomplete file {save_path}: {e}" + ) else: # Clean up files on failure, but preserve .part file for resume cleanup_files = [metadata_path] - preview_path_value = getattr(metadata, 'preview_url', None) + preview_path_value = getattr(metadata, "preview_url", None) if preview_path_value and os.path.exists(preview_path_value): cleanup_files.append(preview_path_value) @@ -770,38 +939,59 @@ class DownloadManager: if os.path.exists(part_path): logger.info(f"Preserving partial download for resume: {part_path}") - return {'success': False, 'error': last_error or 'Failed to download file'} + return { + "success": False, + "error": last_error or "Failed to download file", + } # 4. 
Handle archive extraction and prepare per-file metadata actual_file_paths = [save_path] if zipfile.is_zipfile(save_path): - supported_extensions = self._get_supported_extensions_for_type(model_type) + supported_extensions = self._get_supported_extensions_for_type( + model_type + ) extracted_paths = await self._extract_model_files_from_archive( save_path, supported_extensions ) if not extracted_paths: supported_text = ", ".join(sorted(supported_extensions)) return { - 'success': False, - 'error': f'Zip archive does not contain any supported model files ({supported_text})', + "success": False, + "error": f"Zip archive does not contain any supported model files ({supported_text})", } actual_file_paths = extracted_paths try: os.remove(save_path) except OSError as exc: - logger.warning(f"Unable to delete temporary archive {save_path}: {exc}") + logger.warning( + f"Unable to delete temporary archive {save_path}: {exc}" + ) if download_id and download_id in self._active_downloads: - self._active_downloads[download_id]['file_path'] = extracted_paths[0] - self._active_downloads[download_id]['extracted_paths'] = extracted_paths + self._active_downloads[download_id]["file_path"] = extracted_paths[ + 0 + ] + self._active_downloads[download_id]["extracted_paths"] = ( + extracted_paths + ) - metadata_entries = await self._build_metadata_entries(metadata, actual_file_paths) + metadata_entries = await self._build_metadata_entries( + metadata, actual_file_paths + ) if preview_path: - preview_targets = self._distribute_preview_to_entries(preview_path, metadata_entries) + preview_targets = self._distribute_preview_to_entries( + preview_path, metadata_entries + ) for entry, target in zip(metadata_entries, preview_targets): entry.preview_url = target.replace(os.sep, "/") entry.preview_nsfw_level = preview_nsfw_level - if download_id and download_id in self._active_downloads and preview_targets: - self._active_downloads[download_id]["preview_path"] = preview_targets[0] + if ( + download_id + and download_id in self._active_downloads + and preview_targets + ): + self._active_downloads[download_id]["preview_path"] = ( + preview_targets[0] + ) scanner = None if model_type == "checkpoint": @@ -815,11 +1005,15 @@ class DownloadManager: logger.info(f"Updating embedding cache for {actual_file_paths[0]}") adjust_cached_entry = ( - getattr(scanner, "adjust_cached_entry", None) if scanner is not None else None + getattr(scanner, "adjust_cached_entry", None) + if scanner is not None + else None ) for index, entry in enumerate(metadata_entries): - file_path_for_adjust = getattr(entry, "file_path", actual_file_paths[index]) + file_path_for_adjust = getattr( + entry, "file_path", actual_file_paths[index] + ) normalized_file_path = ( file_path_for_adjust.replace(os.sep, "/") if isinstance(file_path_for_adjust, str) @@ -837,12 +1031,16 @@ class DownloadManager: adjust_metadata = getattr(scanner, "adjust_metadata", None) if callable(adjust_metadata): - adjusted_entry = adjust_metadata(entry, normalized_file_path, adjust_root) + adjusted_entry = adjust_metadata( + entry, normalized_file_path, adjust_root + ) if adjusted_entry is not None: entry = adjusted_entry metadata_entries[index] = entry - metadata_file_path = os.path.splitext(entry.file_path)[0] + '.metadata.json' + metadata_file_path = ( + os.path.splitext(entry.file_path)[0] + ".metadata.json" + ) metadata_files_for_cleanup.append(metadata_file_path) await MetadataManager.save_metadata(entry.file_path, entry) @@ -858,13 +1056,18 @@ class DownloadManager: if 
progress_callback: await progress_callback(100) - return {'success': True} + return {"success": True} except Exception as e: logger.error(f"Error in _execute_download: {e}", exc_info=True) cleanup_targets = { path - for path in [save_path, metadata_path, *metadata_files_for_cleanup, *extracted_paths] + for path in [ + save_path, + metadata_path, + *metadata_files_for_cleanup, + *extracted_paths, + ] if path } preview_candidate = ( @@ -883,14 +1086,33 @@ class DownloadManager: except Exception as exc: logger.warning(f"Failed to cleanup file {path}: {exc}") - return {'success': False, 'error': str(e)} + return {"success": False, "error": str(e)} def _get_supported_extensions_for_type(self, model_type: str) -> Set[str]: if model_type == "checkpoint": - return {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft', '.gguf'} + return { + ".ckpt", + ".pt", + ".pt2", + ".bin", + ".pth", + ".safetensors", + ".pkl", + ".sft", + ".gguf", + } if model_type == "embedding": - return {'.ckpt', '.pt', '.pt2', '.bin', '.pth', '.safetensors', '.pkl', '.sft'} - return {'.safetensors'} + return { + ".ckpt", + ".pt", + ".pt2", + ".bin", + ".pth", + ".safetensors", + ".pkl", + ".sft", + } + return {".safetensors"} async def _extract_model_files_from_archive( self, @@ -901,7 +1123,9 @@ class DownloadManager: return [] target_dir = os.path.dirname(archive_path) - normalized_extensions = {ext.lower() for ext in allowed_extensions or {'.safetensors'}} + normalized_extensions = { + ext.lower() for ext in allowed_extensions or {".safetensors"} + } def _extract_sync() -> List[str]: extracted_files: List[str] = [] @@ -915,7 +1139,9 @@ class DownloadManager: file_name = os.path.basename(info.filename) if not file_name: continue - dest_path = self._resolve_extracted_destination(target_dir, file_name) + dest_path = self._resolve_extracted_destination( + target_dir, file_name + ) with archive.open(info) as source, open(dest_path, "wb") as target: shutil.copyfileobj(source, target) extracted_files.append(dest_path) @@ -923,7 +1149,9 @@ class DownloadManager: return await asyncio.to_thread(_extract_sync) - async def _build_metadata_entries(self, base_metadata, file_paths: List[str]) -> List: + async def _build_metadata_entries( + self, base_metadata, file_paths: List[str] + ) -> List: if not file_paths: return [] @@ -949,7 +1177,9 @@ class DownloadManager: return destination - def _distribute_preview_to_entries(self, preview_path: str, entries: List) -> List[str]: + def _distribute_preview_to_entries( + self, preview_path: str, entries: List + ) -> List[str]: if not preview_path or not entries: return [] @@ -986,7 +1216,9 @@ class DownloadManager: if not progress_callback: return - file_progress, original_snapshot = self._normalize_progress(progress_update, snapshot) + file_progress, original_snapshot = self._normalize_progress( + progress_update, snapshot + ) overall_progress = 3 + (file_progress * 0.97) overall_progress = max(0.0, min(overall_progress, 100.0)) rounded_progress = round(overall_progress) @@ -1001,20 +1233,22 @@ class DownloadManager: timestamp=original_snapshot.timestamp, ) - await self._dispatch_progress(progress_callback, normalized_snapshot, rounded_progress) - + await self._dispatch_progress( + progress_callback, normalized_snapshot, rounded_progress + ) + async def cancel_download(self, download_id: str) -> Dict: """Cancel an active download by download_id Args: download_id: The unique identifier of the download task - + Returns: Dict: Status of the cancellation operation """ if 
download_id not in self._download_tasks: - return {'success': False, 'error': 'Download task not found'} - + return {"success": False, "error": "Download task not found"} + try: # Get the task and cancel it task = self._download_tasks[download_id] @@ -1026,24 +1260,24 @@ class DownloadManager: # Update status in active downloads if download_id in self._active_downloads: - self._active_downloads[download_id]['status'] = 'cancelling' - self._active_downloads[download_id]['bytes_per_second'] = 0.0 - + self._active_downloads[download_id]["status"] = "cancelling" + self._active_downloads[download_id]["bytes_per_second"] = 0.0 + # Wait briefly for the task to acknowledge cancellation try: await asyncio.wait_for(asyncio.shield(task), timeout=2.0) except (asyncio.CancelledError, asyncio.TimeoutError): pass - + # Clean up ALL files including .part when user cancels download_info = self._active_downloads.get(download_id) if download_info: target_files = set() - primary_path = download_info.get('file_path') + primary_path = download_info.get("file_path") if primary_path: target_files.add(primary_path) - for extra_path in download_info.get('extracted_paths', []): + for extra_path in download_info.get("extracted_paths", []): if extra_path: target_files.add(extra_path) @@ -1056,8 +1290,8 @@ class DownloadManager: logger.error(f"Error deleting file: {e}") # Delete the .part file (only on user cancellation) - if 'part_path' in download_info: - part_path = download_info['part_path'] + if "part_path" in download_info: + part_path = download_info["part_path"] if os.path.exists(part_path): try: os.unlink(part_path) @@ -1067,35 +1301,39 @@ class DownloadManager: # Delete metadata files for each resolved path for file_path in target_files: - metadata_path = os.path.splitext(file_path)[0] + '.metadata.json' + metadata_path = os.path.splitext(file_path)[0] + ".metadata.json" if os.path.exists(metadata_path): try: os.unlink(metadata_path) except Exception as e: logger.error(f"Error deleting metadata file: {e}") - preview_path_value = download_info.get('preview_path') + preview_path_value = download_info.get("preview_path") if preview_path_value and os.path.exists(preview_path_value): try: os.unlink(preview_path_value) logger.debug(f"Deleted preview file: {preview_path_value}") except Exception as e: - logger.error(f"Error deleting preview file: {preview_path_value}") + logger.error( + f"Error deleting preview file: {preview_path_value}" + ) # Delete preview file if exists (.webp or .mp4) for legacy paths for file_path in target_files: - for preview_ext in ['.webp', '.mp4']: + for preview_ext in [".webp", ".mp4"]: preview_path = os.path.splitext(file_path)[0] + preview_ext if os.path.exists(preview_path): try: os.unlink(preview_path) logger.debug(f"Deleted preview file: {preview_path}") except Exception as e: - logger.error(f"Error deleting preview file: {preview_path}") - return {'success': True, 'message': 'Download cancelled successfully'} + logger.error( + f"Error deleting preview file: {preview_path}" + ) + return {"success": True, "message": "Download cancelled successfully"} except Exception as e: logger.error(f"Error cancelling download: {e}", exc_info=True) - return {'success': False, 'error': str(e)} + return {"success": False, "error": str(e)} finally: self._pause_events.pop(download_id, None) @@ -1103,33 +1341,33 @@ class DownloadManager: """Pause an active download without losing progress.""" if download_id not in self._download_tasks: - return {'success': False, 'error': 'Download task not found'} + 
return {"success": False, "error": "Download task not found"} pause_control = self._pause_events.get(download_id) if pause_control is None: - return {'success': False, 'error': 'Download task not found'} + return {"success": False, "error": "Download task not found"} if pause_control.is_paused(): - return {'success': False, 'error': 'Download is already paused'} + return {"success": False, "error": "Download is already paused"} pause_control.pause() download_info = self._active_downloads.get(download_id) if download_info is not None: - download_info['status'] = 'paused' - download_info['bytes_per_second'] = 0.0 + download_info["status"] = "paused" + download_info["bytes_per_second"] = 0.0 - return {'success': True, 'message': 'Download paused successfully'} + return {"success": True, "message": "Download paused successfully"} async def resume_download(self, download_id: str) -> Dict: """Resume a previously paused download.""" pause_control = self._pause_events.get(download_id) if pause_control is None: - return {'success': False, 'error': 'Download task not found'} + return {"success": False, "error": "Download task not found"} if pause_control.is_set(): - return {'success': False, 'error': 'Download is not paused'} + return {"success": False, "error": "Download is not paused"} download_info = self._active_downloads.get(download_id) force_reconnect = False @@ -1147,11 +1385,11 @@ class DownloadManager: pause_control.resume(force_reconnect=force_reconnect) if download_info is not None: - if download_info.get('status') == 'paused': - download_info['status'] = 'downloading' - download_info.setdefault('bytes_per_second', 0.0) + if download_info.get("status") == "paused": + download_info["status"] = "downloading" + download_info.setdefault("bytes_per_second", 0.0) - return {'success': True, 'message': 'Download resumed successfully'} + return {"success": True, "message": "Download resumed successfully"} @staticmethod def _coerce_progress_value(progress) -> float: @@ -1173,10 +1411,12 @@ class DownloadManager: return snapshot.percent_complete, snapshot if isinstance(progress, dict): - if 'percent_complete' in progress: - return cls._coerce_progress_value(progress['percent_complete']), snapshot - if 'progress' in progress: - return cls._coerce_progress_value(progress['progress']), snapshot + if "percent_complete" in progress: + return cls._coerce_progress_value( + progress["percent_complete"] + ), snapshot + if "progress" in progress: + return cls._coerce_progress_value(progress["progress"]), snapshot return cls._coerce_progress_value(progress), None @@ -1201,22 +1441,22 @@ class DownloadManager: async def get_active_downloads(self) -> Dict: """Get information about all active downloads - + Returns: Dict: List of active downloads and their status """ return { - 'downloads': [ + "downloads": [ { - 'download_id': task_id, - 'model_id': info.get('model_id'), - 'model_version_id': info.get('model_version_id'), - 'progress': info.get('progress', 0), - 'status': info.get('status', 'unknown'), - 'error': info.get('error', None), - 'bytes_downloaded': info.get('bytes_downloaded', 0), - 'total_bytes': info.get('total_bytes'), - 'bytes_per_second': info.get('bytes_per_second', 0.0), + "download_id": task_id, + "model_id": info.get("model_id"), + "model_version_id": info.get("model_version_id"), + "progress": info.get("progress", 0), + "status": info.get("status", "unknown"), + "error": info.get("error", None), + "bytes_downloaded": info.get("bytes_downloaded", 0), + "total_bytes": info.get("total_bytes"), 
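
# A sketch of what _normalize_progress has to cope with: callbacks may hand
# back a bare number, a dict carrying "percent_complete" or "progress", or a
# snapshot object. Reducing everything to one clamped float mirrors the
# dict-key handling above; the clamp and names are illustrative, not the
# class's own code.
from typing import Any


def coerce_progress(progress: Any) -> float:
    if isinstance(progress, dict):
        for key in ("percent_complete", "progress"):
            if key in progress:
                return coerce_progress(progress[key])
        return 0.0
    try:
        value = float(progress)
    except (TypeError, ValueError):
        return 0.0
    return max(0.0, min(value, 100.0))   # clamp to [0, 100]


assert coerce_progress({"percent_complete": 42}) == 42.0
assert coerce_progress("not a number") == 0.0
assert coerce_progress(250) == 100.0
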
+ "bytes_per_second": info.get("bytes_per_second", 0.0), } for task_id, info in self._active_downloads.items() ] diff --git a/tests/services/test_download_manager.py b/tests/services/test_download_manager.py index 669f8bf8..c176ed45 100644 --- a/tests/services/test_download_manager.py +++ b/tests/services/test_download_manager.py @@ -87,9 +87,19 @@ def scanners(monkeypatch): checkpoint_scanner = DummyScanner() embedding_scanner = DummyScanner() - monkeypatch.setattr(ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner)) - monkeypatch.setattr(ServiceRegistry, "get_checkpoint_scanner", AsyncMock(return_value=checkpoint_scanner)) - monkeypatch.setattr(ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=embedding_scanner)) + monkeypatch.setattr( + ServiceRegistry, "get_lora_scanner", AsyncMock(return_value=lora_scanner) + ) + monkeypatch.setattr( + ServiceRegistry, + "get_checkpoint_scanner", + AsyncMock(return_value=checkpoint_scanner), + ) + monkeypatch.setattr( + ServiceRegistry, + "get_embedding_scanner", + AsyncMock(return_value=embedding_scanner), + ) return SimpleNamespace( lora=lora_scanner, @@ -148,7 +158,9 @@ async def test_download_requires_identifier(): } -async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata_provider, tmp_path): +async def test_successful_download_uses_defaults( + monkeypatch, scanners, metadata_provider, tmp_path +): manager = DownloadManager() captured = {} @@ -178,7 +190,9 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata ) return {"success": True} - monkeypatch.setattr(DownloadManager, "_execute_download", fake_execute_download, raising=False) + monkeypatch.setattr( + DownloadManager, "_execute_download", fake_execute_download, raising=False + ) result = await manager.download_from_civitai( model_version_id=99, @@ -194,15 +208,19 @@ async def test_successful_download_uses_defaults(monkeypatch, scanners, metadata assert manager._active_downloads[result["download_id"]]["status"] == "completed" assert captured["relative_path"] == "MappedModel/fantasy" - expected_dir = Path(get_settings_manager().get("default_lora_root")) / "MappedModel" / "fantasy" + expected_dir = ( + Path(get_settings_manager().get("default_lora_root")) + / "MappedModel" + / "fantasy" + ) assert captured["save_dir"] == expected_dir assert captured["model_type"] == "lora" - assert captured["download_urls"] == [ - "https://example.invalid/file.safetensors" - ] + assert captured["download_urls"] == ["https://example.invalid/file.safetensors"] -async def test_download_uses_active_mirrors(monkeypatch, scanners, metadata_provider, tmp_path): +async def test_download_uses_active_mirrors( + monkeypatch, scanners, metadata_provider, tmp_path +): manager = DownloadManager() metadata_with_mirrors = { @@ -216,8 +234,14 @@ async def test_download_uses_active_mirrors(monkeypatch, scanners, metadata_prov "primary": True, "downloadUrl": "https://example.invalid/file.safetensors", "mirrors": [ - {"url": "https://mirror.example/file.safetensors", "deletedAt": None}, - {"url": "https://mirror.example/old.safetensors", "deletedAt": "2024-01-01"}, + { + "url": "https://mirror.example/file.safetensors", + "deletedAt": None, + }, + { + "url": "https://mirror.example/old.safetensors", + "deletedAt": "2024-01-01", + }, ], "name": "file.safetensors", } @@ -243,7 +267,9 @@ async def test_download_uses_active_mirrors(monkeypatch, scanners, metadata_prov captured["download_urls"] = download_urls return {"success": True} - 
monkeypatch.setattr(DownloadManager, "_execute_download", fake_execute_download, raising=False) + monkeypatch.setattr( + DownloadManager, "_execute_download", fake_execute_download, raising=False + ) result = await manager.download_from_civitai( model_version_id=99, @@ -257,7 +283,9 @@ async def test_download_uses_active_mirrors(monkeypatch, scanners, metadata_prov assert captured["download_urls"] == ["https://mirror.example/file.safetensors"] -async def test_download_aborts_when_version_exists(monkeypatch, scanners, metadata_provider): +async def test_download_aborts_when_version_exists( + monkeypatch, scanners, metadata_provider +): scanners.lora.exists = True manager = DownloadManager() @@ -280,7 +308,9 @@ async def test_download_handles_metadata_errors(monkeypatch, scanners): monkeypatch.setattr( download_manager, "get_default_metadata_provider", - AsyncMock(return_value=SimpleNamespace(get_model_version=AsyncMock(return_value=None))), + AsyncMock( + return_value=SimpleNamespace(get_model_version=AsyncMock(return_value=None)) + ), ) manager = DownloadManager() @@ -331,7 +361,9 @@ def test_embedding_relative_path_replaces_spaces(): def test_relative_path_supports_model_and_version_placeholders(): manager = DownloadManager() settings_manager = get_settings_manager() - settings_manager.settings["download_path_templates"]["lora"] = "{model_name}/{version_name}" + settings_manager.settings["download_path_templates"]["lora"] = ( + "{model_name}/{version_name}" + ) version_info = { "baseModel": "BaseModel", @@ -347,7 +379,9 @@ def test_relative_path_supports_model_and_version_placeholders(): def test_relative_path_sanitizes_model_and_version_placeholders(): manager = DownloadManager() settings_manager = get_settings_manager() - settings_manager.settings["download_path_templates"]["lora"] = "{model_name}/{version_name}" + settings_manager.settings["download_path_templates"]["lora"] = ( + "{model_name}/{version_name}" + ) version_info = { "baseModel": "BaseModel", @@ -403,7 +437,9 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path): return True, "second success" dummy_downloader = DummyDownloader() - monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader)) + monkeypatch.setattr( + download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader) + ) class DummyScanner: def __init__(self): @@ -413,9 +449,17 @@ async def test_execute_download_retries_urls(monkeypatch, tmp_path): self.calls.append((metadata_dict, relative_path)) dummy_scanner = DummyScanner() - monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)) - monkeypatch.setattr(DownloadManager, "_get_checkpoint_scanner", AsyncMock(return_value=dummy_scanner)) - monkeypatch.setattr(ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)) + monkeypatch.setattr( + DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) + ) + monkeypatch.setattr( + DownloadManager, + "_get_checkpoint_scanner", + AsyncMock(return_value=dummy_scanner), + ) + monkeypatch.setattr( + ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner) + ) monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) @@ -470,7 +514,9 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p download_urls = ["https://example.invalid/model.safetensors"] class DummyDownloader: - async def download_file(self, _url, path, progress_callback=None, 
use_auth=None): + async def download_file( + self, _url, path, progress_callback=None, use_auth=None + ): Path(path).write_text("content") return True, "ok" @@ -488,7 +534,9 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p def _find_root_for_file(self, file_path: str): return self.root if file_path.startswith(self.root) else None - def adjust_metadata(self, metadata_obj, _file_path: str, root_path: Optional[str]): + def adjust_metadata( + self, metadata_obj, _file_path: str, root_path: Optional[str] + ): if root_path: metadata_obj.model_type = "diffusion_model" return metadata_obj @@ -503,7 +551,11 @@ async def test_execute_download_adjusts_checkpoint_model_type(monkeypatch, tmp_p return True dummy_scanner = DummyCheckpointScanner(root_dir) - monkeypatch.setattr(DownloadManager, "_get_checkpoint_scanner", AsyncMock(return_value=dummy_scanner)) + monkeypatch.setattr( + DownloadManager, + "_get_checkpoint_scanner", + AsyncMock(return_value=dummy_scanner), + ) monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) result = await manager._execute_download( @@ -560,9 +612,13 @@ async def test_execute_download_extracts_zip_single_model(monkeypatch, tmp_path) archive.writestr("docs/readme.txt", b"ignore") return True, "ok" - monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())) + monkeypatch.setattr( + download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()) + ) dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) - monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)) + monkeypatch.setattr( + DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) + ) monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) hash_calculator = AsyncMock(return_value="hash-single") monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator) @@ -624,9 +680,13 @@ async def test_execute_download_extracts_zip_multiple_models(monkeypatch, tmp_pa archive.writestr("readme.md", b"ignore") return True, "ok" - monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())) + monkeypatch.setattr( + download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()) + ) dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) - monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)) + monkeypatch.setattr( + DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) + ) monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) hash_calculator = AsyncMock(side_effect=["hash-one", "hash-two"]) monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator) @@ -694,9 +754,13 @@ async def test_execute_download_extracts_zip_pt_embedding(monkeypatch, tmp_path) archive.writestr("docs/readme.txt", b"ignore") return True, "ok" - monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader())) + monkeypatch.setattr( + download_manager, "get_downloader", AsyncMock(return_value=DummyDownloader()) + ) dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) - monkeypatch.setattr(ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner)) + monkeypatch.setattr( + ServiceRegistry, "get_embedding_scanner", AsyncMock(return_value=dummy_scanner) + ) 
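
# The stubbing pattern these tests lean on, shown in isolation: an async
# service getter is swapped for an AsyncMock so no real scanner is ever
# constructed. FakeRegistry and get_thing are illustrative stand-ins for
# ServiceRegistry and its getters.
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock


class FakeRegistry:
    get_thing = None                       # replaced below, as monkeypatch would


async def _demo():
    stub = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None))
    FakeRegistry.get_thing = AsyncMock(return_value=stub)   # the setattr step
    scanner = await FakeRegistry.get_thing()
    await scanner.add_model_to_cache({"name": "x"}, "rel/path")
    scanner.add_model_to_cache.assert_awaited_once()

asyncio.run(_demo())
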
monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) hash_calculator = AsyncMock(return_value="hash-pt") monkeypatch.setattr(download_manager, "calculate_sha256", hash_calculator) @@ -815,7 +879,7 @@ async def test_resume_download_requests_reconnect_for_stalled_stream(): download_id = "dl" pause_control = DownloadStreamControl(stall_timeout=40) pause_control.pause() - pause_control.last_progress_timestamp = (datetime.now().timestamp() - 120) + pause_control.last_progress_timestamp = datetime.now().timestamp() - 120 manager._pause_events[download_id] = pause_control manager._active_downloads[download_id] = { "status": "paused", @@ -899,7 +963,9 @@ async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_ return False, b"", {} dummy_downloader = DummyDownloader() - monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader)) + monkeypatch.setattr( + download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader) + ) optimize_called = {"value": False} @@ -907,11 +973,15 @@ async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_ optimize_called["value"] = True return b"", {} - monkeypatch.setattr(download_manager.ExifUtils, "optimize_image", staticmethod(fake_optimize_image)) + monkeypatch.setattr( + download_manager.ExifUtils, "optimize_image", staticmethod(fake_optimize_image) + ) monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) - monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)) + monkeypatch.setattr( + DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) + ) result = await manager._execute_download( download_urls=download_urls, @@ -925,7 +995,9 @@ async def test_execute_download_uses_rewritten_civitai_preview(monkeypatch, tmp_ ) assert result == {"success": True} - preview_urls = [url for url, _ in dummy_downloader.file_calls if url.endswith(".jpeg")] + preview_urls = [ + url for url, _ in dummy_downloader.file_calls if url.endswith(".jpeg") + ] assert any("width=450,optimized=true" in url for url in preview_urls) assert dummy_downloader.memory_calls == 0 assert optimize_called["value"] is False @@ -1021,12 +1093,20 @@ async def test_execute_download_respects_blur_setting(monkeypatch, tmp_path): lambda: StubSettingsManager(True), ) - monkeypatch.setattr(download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader)) - monkeypatch.setattr(download_manager.ExifUtils, "optimize_image", staticmethod(lambda **_kwargs: (b"", {}))) + monkeypatch.setattr( + download_manager, "get_downloader", AsyncMock(return_value=dummy_downloader) + ) + monkeypatch.setattr( + download_manager.ExifUtils, + "optimize_image", + staticmethod(lambda **_kwargs: (b"", {})), + ) monkeypatch.setattr(MetadataManager, "save_metadata", AsyncMock(return_value=True)) dummy_scanner = SimpleNamespace(add_model_to_cache=AsyncMock(return_value=None)) - monkeypatch.setattr(DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner)) + monkeypatch.setattr( + DownloadManager, "_get_lora_scanner", AsyncMock(return_value=dummy_scanner) + ) result = await manager._execute_download( download_urls=download_urls, @@ -1040,10 +1120,283 @@ async def test_execute_download_respects_blur_setting(monkeypatch, tmp_path): ) assert result == {"success": True} - preview_urls = [url for url, _ in 
dummy_downloader.file_calls if url.endswith(".jpeg")] + preview_urls = [ + url for url, _ in dummy_downloader.file_calls if url.endswith(".jpeg") + ] assert preview_urls assert all("nsfw.jpeg" not in url for url in preview_urls) assert any("safe.jpeg" in url for url in preview_urls) assert metadata.preview_nsfw_level == 1 stored_preview = manager._active_downloads["dl"].get("preview_path") assert stored_preview and stored_preview.endswith(".jpeg") + + +async def test_civarchive_source_uses_civarchive_provider( + monkeypatch, scanners, tmp_path +): + manager = DownloadManager() + + captured_providers = [] + + class CivArchiveProvider: + async def get_model_version(self, model_id, model_version_id): + captured_providers.append("civarchive") + return { + "id": 119514, + "model": {"type": "LoRA", "tags": ["celebrity"]}, + "baseModel": "SD 1.5", + "creator": {"username": "dogu_cat"}, + "source": "civarchive", + "files": [ + { + "type": "Model", + "primary": True, + "mirrors": [ + { + "url": "https://huggingface.co/file.safetensors", + "deletedAt": None, + }, + { + "url": "https://civitai.com/api/download/models/119514", + "deletedAt": "2025-05-23T00:00:00.000Z", + }, + ], + "name": "file.safetensors", + } + ], + } + + class DefaultProvider: + async def get_model_version(self, model_id, model_version_id): + captured_providers.append("default") + return { + "id": 119514, + "model": {"type": "LoRA", "tags": ["celebrity"]}, + "baseModel": "SD 1.5", + "creator": {"username": "dogu_cat"}, + "files": [ + { + "type": "Model", + "primary": True, + "downloadUrl": "https://civitai.com/api/download/models/119514", + "name": "file.safetensors", + } + ], + } + + async def get_metadata_provider(provider_name): + if provider_name == "civarchive_api": + return CivArchiveProvider() + return None + + async def get_default_metadata_provider(): + return DefaultProvider() + + monkeypatch.setattr( + download_manager, "get_metadata_provider", get_metadata_provider + ) + monkeypatch.setattr( + download_manager, "get_default_metadata_provider", get_default_metadata_provider + ) + + captured = {} + + async def fake_execute_download( + self, + *, + download_urls, + save_dir, + metadata, + version_info, + relative_path, + progress_callback, + model_type, + download_id, + ): + captured["download_urls"] = download_urls + captured["version_info"] = version_info + return {"success": True} + + monkeypatch.setattr( + DownloadManager, "_execute_download", fake_execute_download, raising=False + ) + + result = await manager.download_from_civitai( + model_id=110828, + model_version_id=119514, + save_dir=str(tmp_path), + use_default_paths=True, + progress_callback=None, + source="civarchive", + ) + + assert result["success"] is True + assert captured_providers == ["civarchive"] + assert captured["version_info"]["source"] == "civarchive" + + +async def test_civarchive_source_prioritizes_non_civitai_urls( + monkeypatch, scanners, tmp_path +): + manager = DownloadManager() + + class CivArchiveProvider: + async def get_model_version(self, model_id, model_version_id): + return { + "id": 119514, + "model": {"type": "LoRA", "tags": ["celebrity"]}, + "baseModel": "SD 1.5", + "creator": {"username": "dogu_cat"}, + "source": "civarchive", + "files": [ + { + "type": "Model", + "primary": True, + "mirrors": [ + { + "url": "https://huggingface.co/file.safetensors", + "deletedAt": None, + "source": "huggingface", + }, + { + "url": "https://civitai.com/api/download/models/119514", + "deletedAt": None, + "source": "civitai", + }, + { + "url": 
"https://another-mirror.org/file.safetensors", + "deletedAt": None, + "source": "other", + }, + ], + "name": "file.safetensors", + } + ], + } + + async def get_metadata_provider(provider_name): + if provider_name == "civarchive_api": + return CivArchiveProvider() + return None + + monkeypatch.setattr( + download_manager, "get_metadata_provider", get_metadata_provider + ) + + captured = {} + + async def fake_execute_download( + self, + *, + download_urls, + save_dir, + metadata, + version_info, + relative_path, + progress_callback, + model_type, + download_id, + ): + captured["download_urls"] = download_urls + return {"success": True} + + monkeypatch.setattr( + DownloadManager, "_execute_download", fake_execute_download, raising=False + ) + + result = await manager.download_from_civitai( + model_id=110828, + model_version_id=119514, + save_dir=str(tmp_path), + use_default_paths=True, + progress_callback=None, + source="civarchive", + ) + + assert result["success"] is True + assert captured["download_urls"] == [ + "https://huggingface.co/file.safetensors", + "https://another-mirror.org/file.safetensors", + "https://civitai.com/api/download/models/119514", + ] + assert captured["download_urls"][0] == "https://huggingface.co/file.safetensors" + assert captured["download_urls"][1] == "https://another-mirror.org/file.safetensors" + + +async def test_civarchive_source_fallback_to_default_provider( + monkeypatch, scanners, tmp_path +): + manager = DownloadManager() + + class CivArchiveProvider: + async def get_model_version(self, model_id, model_version_id): + return None + + class DefaultProvider: + async def get_model_version(self, model_id, model_version_id): + return { + "id": 119514, + "model": {"type": "LoRA", "tags": ["celebrity"]}, + "baseModel": "SD 1.5", + "creator": {"username": "dogu_cat"}, + "files": [ + { + "type": "Model", + "primary": True, + "downloadUrl": "https://civitai.com/api/download/models/119514", + "name": "file.safetensors", + } + ], + } + + captured_providers = [] + + async def get_metadata_provider(provider_name): + if provider_name == "civarchive_api": + captured_providers.append("civarchive_api") + return CivArchiveProvider() + return None + + async def get_default_metadata_provider(): + captured_providers.append("default") + return DefaultProvider() + + monkeypatch.setattr( + download_manager, "get_metadata_provider", get_metadata_provider + ) + monkeypatch.setattr( + download_manager, "get_default_metadata_provider", get_default_metadata_provider + ) + + captured = {} + + async def fake_execute_download( + self, + *, + download_urls, + save_dir, + metadata, + version_info, + relative_path, + progress_callback, + model_type, + download_id, + ): + captured["download_urls"] = download_urls + return {"success": True} + + monkeypatch.setattr( + DownloadManager, "_execute_download", fake_execute_download, raising=False + ) + + result = await manager.download_from_civitai( + model_id=110828, + model_version_id=119514, + save_dir=str(tmp_path), + use_default_paths=True, + progress_callback=None, + source="civarchive", + ) + + assert result["success"] is True + assert captured_providers == ["civarchive_api", "default"]