From 14721c265f5d9bfc910474e08908fc66f96f1648 Mon Sep 17 00:00:00 2001 From: Will Miao <13051207myq@gmail.com> Date: Tue, 9 Sep 2025 10:34:14 +0800 Subject: [PATCH] Refactor download logic to use unified downloader service - Introduced a new `Downloader` class to centralize HTTP/HTTPS download management. - Replaced direct `aiohttp` session handling with the unified downloader in `MetadataArchiveManager`, `DownloadManager`, and `ExampleImagesProcessor`. - Added support for resumable downloads, progress tracking, and error handling in the new downloader. - Updated methods to utilize the downloader's capabilities for downloading files and images, improving code maintainability and readability. --- py/routes/update_routes.py | 127 ++--- py/services/civitai_client.py | 515 +++++++------------- py/services/downloader.py | 465 ++++++++++++++++++ py/services/metadata_archive_manager.py | 43 +- py/utils/example_images_download_manager.py | 61 +-- py/utils/example_images_processor.py | 80 +-- 6 files changed, 777 insertions(+), 514 deletions(-) create mode 100644 py/services/downloader.py diff --git a/py/routes/update_routes.py b/py/routes/update_routes.py index 9d60036e..66ef603a 100644 --- a/py/routes/update_routes.py +++ b/py/routes/update_routes.py @@ -1,5 +1,4 @@ import os -import aiohttp import logging import toml import git @@ -8,7 +7,7 @@ import shutil import tempfile from aiohttp import web from typing import Dict, List - +from ..services.downloader import get_downloader, Downloader logger = logging.getLogger(__name__) @@ -162,28 +161,42 @@ class UpdateRoutes: github_api = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest" try: - async with aiohttp.ClientSession() as session: - async with session.get(github_api) as resp: - if resp.status != 200: - logger.error(f"Failed to fetch release info: {resp.status}") - return False, "" - data = await resp.json() - zip_url = data.get("zipball_url") - version = data.get("tag_name", "unknown") + downloader = await get_downloader() + + # Get release info + success, data = await downloader.make_request( + 'GET', + github_api, + use_auth=False + ) + if not success: + logger.error(f"Failed to fetch release info: {data}") + return False, "" + + zip_url = data.get("zipball_url") + version = data.get("tag_name", "unknown") - # Download ZIP - async with session.get(zip_url) as zip_resp: - if zip_resp.status != 200: - logger.error(f"Failed to download ZIP: {zip_resp.status}") - return False, "" - with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip: - tmp_zip.write(await zip_resp.read()) - zip_path = tmp_zip.name + # Download ZIP to temporary file + with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip: + tmp_zip_path = tmp_zip.name + + success, result = await downloader.download_file( + url=zip_url, + save_path=tmp_zip_path, + use_auth=False, + allow_resume=False + ) + + if not success: + logger.error(f"Failed to download ZIP: {result}") + return False, "" - UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json']) + zip_path = tmp_zip_path - # Extract ZIP to temp dir - with tempfile.TemporaryDirectory() as tmp_dir: + UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json']) + + # Extract ZIP to temp dir + with tempfile.TemporaryDirectory() as tmp_dir: with zipfile.ZipFile(zip_path, 'r') as zip_ref: zip_ref.extractall(tmp_dir) # Find extracted folder (GitHub ZIP contains a root folder) @@ -213,9 +226,9 @@ class UpdateRoutes: with open(tracking_info_file, "w", encoding='utf-8') as file: file.write('\n'.join(tracking_files)) - os.remove(zip_path) - logger.info(f"Updated plugin via ZIP to {version}") - return True, version + os.remove(zip_path) + logger.info(f"Updated plugin via ZIP to {version}") + return True, version except Exception as e: logger.error(f"ZIP update failed: {e}", exc_info=True) @@ -244,23 +257,23 @@ class UpdateRoutes: github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main" try: - async with aiohttp.ClientSession() as session: - async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response: - if response.status != 200: - logger.warning(f"Failed to fetch GitHub commit: {response.status}") - return "main", [] - - data = await response.json() - commit_sha = data.get('sha', '')[:7] # Short hash - commit_message = data.get('commit', {}).get('message', '') - - # Format as "main-{short_hash}" - version = f"main-{commit_sha}" - - # Use commit message as changelog - changelog = [commit_message] if commit_message else [] - - return version, changelog + downloader = await Downloader.get_instance() + success, data = await downloader.make_request('GET', github_url, headers={'Accept': 'application/vnd.github+json'}) + + if not success: + logger.warning(f"Failed to fetch GitHub commit: {data}") + return "main", [] + + commit_sha = data.get('sha', '')[:7] # Short hash + commit_message = data.get('commit', {}).get('message', '') + + # Format as "main-{short_hash}" + version = f"main-{commit_sha}" + + # Use commit message as changelog + changelog = [commit_message] if commit_message else [] + + return version, changelog except Exception as e: logger.error(f"Error fetching nightly version: {e}", exc_info=True) @@ -410,22 +423,22 @@ class UpdateRoutes: github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest" try: - async with aiohttp.ClientSession() as session: - async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response: - if response.status != 200: - logger.warning(f"Failed to fetch GitHub release: {response.status}") - return "v0.0.0", [] - - data = await response.json() - version = data.get('tag_name', '') - if not version.startswith('v'): - version = f"v{version}" - - # Extract changelog from release notes - body = data.get('body', '') - changelog = UpdateRoutes._parse_changelog(body) - - return version, changelog + downloader = await Downloader.get_instance() + success, data = await downloader.make_request('GET', github_url, headers={'Accept': 'application/vnd.github+json'}) + + if not success: + logger.warning(f"Failed to fetch GitHub release: {data}") + return "v0.0.0", [] + + version = data.get('tag_name', '') + if not version.startswith('v'): + version = f"v{version}" + + # Extract changelog from release notes + body = data.get('body', '') + changelog = UpdateRoutes._parse_changelog(body) + + return version, changelog except Exception as e: logger.error(f"Error fetching remote version: {e}", exc_info=True) diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 64a6f2f5..6f87fbe4 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -1,10 +1,10 @@ from datetime import datetime -import aiohttp import os import logging import asyncio from typing import Optional, Dict, Tuple, List from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager +from .downloader import get_downloader logger = logging.getLogger(__name__) @@ -32,61 +32,7 @@ class CivitaiClient: self._initialized = True self.base_url = "https://civitai.com/api/v1" - self.headers = { - 'User-Agent': 'ComfyUI-LoRA-Manager/1.0' - } - self._session = None - self._session_created_at = None - # Adjust chunk size based on storage type - consider making this configurable - self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better HDD throughput - @property - async def session(self) -> aiohttp.ClientSession: - """Lazy initialize the session""" - if self._session is None: - # Optimize TCP connection parameters - connector = aiohttp.TCPConnector( - ssl=True, - limit=8, # Increase from 3 to 8 for better parallelism - ttl_dns_cache=300, # Enable DNS caching with reasonable timeout - force_close=False, # Keep connections for reuse - enable_cleanup_closed=True - ) - trust_env = True # Allow using system environment proxy settings - # Configure timeout parameters - increase read timeout for large files and remove sock_read timeout - timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=None) - self._session = aiohttp.ClientSession( - connector=connector, - trust_env=trust_env, - timeout=timeout - ) - self._session_created_at = datetime.now() - return self._session - - async def _ensure_fresh_session(self): - """Refresh session if it's been open too long""" - if self._session is not None: - if not hasattr(self, '_session_created_at') or \ - (datetime.now() - self._session_created_at).total_seconds() > 300: # 5 minutes - await self.close() - self._session = None - - return await self.session - - def _get_request_headers(self) -> dict: - """Get request headers with optional API key""" - headers = { - 'User-Agent': 'ComfyUI-LoRA-Manager/1.0', - 'Content-Type': 'application/json' - } - - from .settings_manager import settings - api_key = settings.get('civitai_api_key') - if (api_key): - headers['Authorization'] = f'Bearer {api_key}' - - return headers - async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]: """Download file with resumable downloads and retry mechanism @@ -99,214 +45,69 @@ class CivitaiClient: Returns: Tuple[bool, str]: (success, save_path or error message) """ - max_retries = 5 - retry_count = 0 - base_delay = 2.0 # Base delay for exponential backoff - - # Initial setup - session = await self._ensure_fresh_session() + downloader = await get_downloader() save_path = os.path.join(save_dir, default_filename) - part_path = save_path + '.part' - # Get existing file size for resume - resume_offset = 0 - if os.path.exists(part_path): - resume_offset = os.path.getsize(part_path) - logger.info(f"Resuming download from offset {resume_offset} bytes") + # Use unified downloader with CivitAI authentication + success, result = await downloader.download_file( + url=url, + save_path=save_path, + progress_callback=progress_callback, + use_auth=True, # Enable CivitAI authentication + allow_resume=True + ) - total_size = 0 - - while retry_count <= max_retries: - try: - headers = self._get_request_headers() - - # Add Range header for resume if we have partial data - if resume_offset > 0: - headers['Range'] = f'bytes={resume_offset}-' - - # Add Range header to allow resumable downloads - headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads - - logger.debug(f"Download attempt {retry_count + 1}/{max_retries + 1} from: {url}") - if resume_offset > 0: - logger.debug(f"Requesting range from byte {resume_offset}") - - async with session.get(url, headers=headers, allow_redirects=True) as response: - # Handle different response codes - if response.status == 200: - # Full content response - if resume_offset > 0: - # Server doesn't support ranges, restart from beginning - logger.warning("Server doesn't support range requests, restarting download") - resume_offset = 0 - if os.path.exists(part_path): - os.remove(part_path) - elif response.status == 206: - # Partial content response (resume successful) - content_range = response.headers.get('Content-Range') - if content_range: - # Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048") - range_parts = content_range.split('/') - if len(range_parts) == 2: - total_size = int(range_parts[1]) - logger.info(f"Successfully resumed download from byte {resume_offset}") - elif response.status == 416: - # Range not satisfiable - file might be complete or corrupted - if os.path.exists(part_path): - part_size = os.path.getsize(part_path) - logger.warning(f"Range not satisfiable. Part file size: {part_size}") - # Try to get actual file size - head_response = await session.head(url, headers=self._get_request_headers()) - if head_response.status == 200: - actual_size = int(head_response.headers.get('content-length', 0)) - if part_size == actual_size: - # File is complete, just rename it - os.rename(part_path, save_path) - if progress_callback: - await progress_callback(100) - return True, save_path - # Remove corrupted part file and restart - os.remove(part_path) - resume_offset = 0 - continue - elif response.status == 401: - logger.warning(f"Unauthorized access to resource: {url} (Status 401)") - return False, "Invalid or missing CivitAI API key, or early access restriction." - elif response.status == 403: - logger.warning(f"Forbidden access to resource: {url} (Status 403)") - return False, "Access forbidden: You don't have permission to download this file." - else: - logger.error(f"Download failed for {url} with status {response.status}") - return False, f"Download failed with status {response.status}" - - # Get total file size for progress calculation (if not set from Content-Range) - if total_size == 0: - total_size = int(response.headers.get('content-length', 0)) - if response.status == 206: - # For partial content, add the offset to get total file size - total_size += resume_offset - - current_size = resume_offset - last_progress_report_time = datetime.now() - - # Stream download to file with progress updates using larger buffer - loop = asyncio.get_running_loop() - mode = 'ab' if resume_offset > 0 else 'wb' - with open(part_path, mode) as f: - async for chunk in response.content.iter_chunked(self.chunk_size): - if chunk: - # Run blocking file write in executor - await loop.run_in_executor(None, f.write, chunk) - current_size += len(chunk) - - # Limit progress update frequency to reduce overhead - now = datetime.now() - time_diff = (now - last_progress_report_time).total_seconds() - - if progress_callback and total_size and time_diff >= 1.0: - progress = (current_size / total_size) * 100 - await progress_callback(progress) - last_progress_report_time = now - - # Download completed successfully - # Verify file size if total_size was provided - final_size = os.path.getsize(part_path) - if total_size > 0 and final_size != total_size: - logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}") - # Don't treat this as fatal error, rename anyway - - # Atomically rename .part to final file with retries - max_rename_attempts = 5 - rename_attempt = 0 - rename_success = False - - while rename_attempt < max_rename_attempts and not rename_success: - try: - os.rename(part_path, save_path) - rename_success = True - except PermissionError as e: - rename_attempt += 1 - if rename_attempt < max_rename_attempts: - logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})") - await asyncio.sleep(2) # Wait before retrying - else: - logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}") - return False, f"Failed to finalize download: {str(e)}" - - # Ensure 100% progress is reported - if progress_callback: - await progress_callback(100) - - return True, save_path - - except (aiohttp.ClientError, aiohttp.ClientPayloadError, - aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e: - retry_count += 1 - logger.warning(f"Network error during download (attempt {retry_count}/{max_retries + 1}): {e}") - - if retry_count <= max_retries: - # Calculate delay with exponential backoff - delay = base_delay * (2 ** (retry_count - 1)) - logger.info(f"Retrying in {delay} seconds...") - await asyncio.sleep(delay) - - # Update resume offset for next attempt - if os.path.exists(part_path): - resume_offset = os.path.getsize(part_path) - logger.info(f"Will resume from byte {resume_offset}") - - # Refresh session to get new connection - await self.close() - session = await self._ensure_fresh_session() - continue - else: - logger.error(f"Max retries exceeded for download: {e}") - return False, f"Network error after {max_retries + 1} attempts: {str(e)}" - - except Exception as e: - logger.error(f"Unexpected download error: {e}") - return False, str(e) - - return False, f"Download failed after {max_retries + 1} attempts" + return success, result async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: try: - session = await self._ensure_fresh_session() - async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response: - if response.status == 200: - return await response.json() - return None + downloader = await get_downloader() + success, result = await downloader.make_request( + 'GET', + f"{self.base_url}/model-versions/by-hash/{model_hash}", + use_auth=True + ) + if success: + return result + return None except Exception as e: logger.error(f"API Error: {str(e)}") return None async def download_preview_image(self, image_url: str, save_path: str): try: - session = await self._ensure_fresh_session() - async with session.get(image_url) as response: - if response.status == 200: - content = await response.read() - with open(save_path, 'wb') as f: - f.write(content) - return True - return False + downloader = await get_downloader() + success, content = await downloader.download_to_memory( + image_url, + use_auth=False # Preview images don't need auth + ) + if success: + # Ensure directory exists + os.makedirs(os.path.dirname(save_path), exist_ok=True) + with open(save_path, 'wb') as f: + f.write(content) + return True + return False except Exception as e: - print(f"Download Error: {str(e)}") + logger.error(f"Download Error: {str(e)}") return False async def get_model_versions(self, model_id: str) -> List[Dict]: """Get all versions of a model with local availability info""" try: - session = await self._ensure_fresh_session() # Use fresh session - async with session.get(f"{self.base_url}/models/{model_id}") as response: - if response.status != 200: - return None - data = await response.json() + downloader = await get_downloader() + success, result = await downloader.make_request( + 'GET', + f"{self.base_url}/models/{model_id}", + use_auth=True + ) + if success: # Also return model type along with versions return { - 'modelVersions': data.get('modelVersions', []), - 'type': data.get('type', '') + 'modelVersions': result.get('modelVersions', []), + 'type': result.get('type', '') } + return None except Exception as e: logger.error(f"Error fetching model versions: {e}") return None @@ -322,68 +123,74 @@ class CivitaiClient: Optional[Dict]: The model version data with additional fields or None if not found """ try: - session = await self._ensure_fresh_session() - headers = self._get_request_headers() + downloader = await get_downloader() # Case 1: Only version_id is provided if model_id is None and version_id is not None: # First get the version info to extract model_id - async with session.get(f"{self.base_url}/model-versions/{version_id}", headers=headers) as response: - if response.status != 200: - return None - - version = await response.json() - model_id = version.get('modelId') - - if not model_id: - logger.error(f"No modelId found in version {version_id}") - return None + success, version = await downloader.make_request( + 'GET', + f"{self.base_url}/model-versions/{version_id}", + use_auth=True + ) + if not success: + return None + model_id = version.get('modelId') + if not model_id: + logger.error(f"No modelId found in version {version_id}") + return None + # Now get the model data for additional metadata - async with session.get(f"{self.base_url}/models/{model_id}") as response: - if response.status != 200: - return version # Return version without additional metadata - - model_data = await response.json() - + success, model_data = await downloader.make_request( + 'GET', + f"{self.base_url}/models/{model_id}", + use_auth=True + ) + if success: # Enrich version with model data version['model']['description'] = model_data.get("description") version['model']['tags'] = model_data.get("tags", []) version['creator'] = model_data.get("creator") - - return version + + return version # Case 2: model_id is provided (with or without version_id) elif model_id is not None: # Step 1: Get model data to find version_id if not provided and get additional metadata - async with session.get(f"{self.base_url}/models/{model_id}") as response: - if response.status != 200: - return None - - data = await response.json() - model_versions = data.get('modelVersions', []) + success, data = await downloader.make_request( + 'GET', + f"{self.base_url}/models/{model_id}", + use_auth=True + ) + if not success: + return None - # Step 2: Determine the version_id to use - target_version_id = version_id - if target_version_id is None: - target_version_id = model_versions[0].get('id') + model_versions = data.get('modelVersions', []) + # Step 2: Determine the version_id to use + target_version_id = version_id + if target_version_id is None: + target_version_id = model_versions[0].get('id') + # Step 3: Get detailed version info using the version_id - async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response: - if response.status != 200: - return None - - version = await response.json() - - # Step 4: Enrich version_info with model data - # Add description and tags from model data - version['model']['description'] = data.get("description") - version['model']['tags'] = data.get("tags", []) - - # Add creator from model data - version['creator'] = data.get("creator") - - return version + success, version = await downloader.make_request( + 'GET', + f"{self.base_url}/model-versions/{target_version_id}", + use_auth=True + ) + if not success: + return None + + # Step 4: Enrich version_info with model data + # Add description and tags from model data + version['model']['description'] = data.get("description") + version['model']['tags'] = data.get("tags", []) + + # Add creator from model data + version['creator'] = data.get("creator") + + return version # Case 3: Neither model_id nor version_id provided else: @@ -406,30 +213,29 @@ class CivitaiClient: - An error message if there was an error, or None on success """ try: - session = await self._ensure_fresh_session() + downloader = await get_downloader() url = f"{self.base_url}/model-versions/{version_id}" - headers = self._get_request_headers() logger.debug(f"Resolving DNS for model version info: {url}") - async with session.get(url, headers=headers) as response: - if response.status == 200: - logger.debug(f"Successfully fetched model version info for: {version_id}") - return await response.json(), None - - # Handle specific error cases - if response.status == 404: - # Try to parse the error message - try: - error_data = await response.json() - error_msg = error_data.get('error', f"Model not found (status 404)") - logger.warning(f"Model version not found: {version_id} - {error_msg}") - return None, error_msg - except: - return None, "Model not found (status 404)" - - # Other error cases - logger.error(f"Failed to fetch model info for {version_id} (status {response.status})") - return None, f"Failed to fetch model info (status {response.status})" + success, result = await downloader.make_request( + 'GET', + url, + use_auth=True + ) + + if success: + logger.debug(f"Successfully fetched model version info for: {version_id}") + return result, None + + # Handle specific error cases + if "404" in str(result): + error_msg = f"Model not found (status 404)" + logger.warning(f"Model version not found: {version_id} - {error_msg}") + return None, error_msg + + # Other error cases + logger.error(f"Failed to fetch model info for {version_id}: {result}") + return None, str(result) except Exception as e: error_msg = f"Error fetching model version info: {e}" logger.error(error_msg) @@ -444,48 +250,50 @@ class CivitaiClient: Returns: Tuple[Optional[Dict], int]: A tuple containing: - A dictionary with model metadata or None if not found - - The HTTP status code from the request + - The HTTP status code from the request (0 for exceptions) """ try: - session = await self._ensure_fresh_session() - headers = self._get_request_headers() + downloader = await get_downloader() url = f"{self.base_url}/models/{model_id}" - async with session.get(url, headers=headers) as response: - status_code = response.status - - if status_code != 200: - logger.warning(f"Failed to fetch model metadata: Status {status_code}") - return None, status_code - - data = await response.json() - - # Extract relevant metadata - metadata = { - "description": data.get("description") or "No model description available", - "tags": data.get("tags", []), - "creator": { - "username": data.get("creator", {}).get("username"), - "image": data.get("creator", {}).get("image") - } + success, result = await downloader.make_request( + 'GET', + url, + use_auth=True + ) + + if not success: + # Try to extract status code from error message + status_code = 0 + if "404" in str(result): + status_code = 404 + elif "401" in str(result): + status_code = 401 + elif "403" in str(result): + status_code = 403 + logger.warning(f"Failed to fetch model metadata: {result}") + return None, status_code + + # Extract relevant metadata + metadata = { + "description": result.get("description") or "No model description available", + "tags": result.get("tags", []), + "creator": { + "username": result.get("creator", {}).get("username"), + "image": result.get("creator", {}).get("image") } - - if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]: - return metadata, status_code - else: - logger.warning(f"No metadata found for model {model_id}") - return None, status_code + } + + if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]: + return metadata, 200 + else: + logger.warning(f"No metadata found for model {model_id}") + return None, 200 except Exception as e: logger.error(f"Error fetching model metadata: {e}", exc_info=True) return None, 0 - async def close(self): - """Close the session if it exists""" - if self._session is not None: - await self._session.close() - self._session = None - async def get_image_info(self, image_id: str) -> Optional[Dict]: """Fetch image information from Civitai API @@ -496,22 +304,25 @@ class CivitaiClient: Optional[Dict]: The image data or None if not found """ try: - session = await self._ensure_fresh_session() - headers = self._get_request_headers() + downloader = await get_downloader() url = f"{self.base_url}/images?imageId={image_id}&nsfw=X" logger.debug(f"Fetching image info for ID: {image_id}") - async with session.get(url, headers=headers) as response: - if response.status == 200: - data = await response.json() - if data and "items" in data and len(data["items"]) > 0: - logger.debug(f"Successfully fetched image info for ID: {image_id}") - return data["items"][0] - logger.warning(f"No image found with ID: {image_id}") - return None - - logger.error(f"Failed to fetch image info for ID: {image_id} (status {response.status})") + success, result = await downloader.make_request( + 'GET', + url, + use_auth=True + ) + + if success: + if result and "items" in result and len(result["items"]) > 0: + logger.debug(f"Successfully fetched image info for ID: {image_id}") + return result["items"][0] + logger.warning(f"No image found with ID: {image_id}") return None + + logger.error(f"Failed to fetch image info for ID: {image_id}: {result}") + return None except Exception as e: error_msg = f"Error fetching image info: {e}" logger.error(error_msg) diff --git a/py/services/downloader.py b/py/services/downloader.py new file mode 100644 index 00000000..cb7f5ef1 --- /dev/null +++ b/py/services/downloader.py @@ -0,0 +1,465 @@ +""" +Unified download manager for all HTTP/HTTPS downloads in the application. + +This module provides a centralized download service with: +- Singleton pattern for global session management +- Support for authenticated downloads (e.g., CivitAI API key) +- Resumable downloads with automatic retry +- Progress tracking and callbacks +- Optimized connection pooling and timeouts +- Unified error handling and logging +""" + +import os +import logging +import asyncio +import aiohttp +from datetime import datetime +from typing import Optional, Dict, Tuple, Callable, Union +from ..services.settings_manager import settings + +logger = logging.getLogger(__name__) + + +class Downloader: + """Unified downloader for all HTTP/HTTPS downloads in the application.""" + + _instance = None + _lock = asyncio.Lock() + + @classmethod + async def get_instance(cls): + """Get singleton instance of Downloader""" + async with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def __init__(self): + """Initialize the downloader with optimal settings""" + # Check if already initialized for singleton pattern + if hasattr(self, '_initialized'): + return + self._initialized = True + + # Session management + self._session = None + self._session_created_at = None + + # Configuration + self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better throughput + self.max_retries = 5 + self.base_delay = 2.0 # Base delay for exponential backoff + self.session_timeout = 300 # 5 minutes + + # Default headers + self.default_headers = { + 'User-Agent': 'ComfyUI-LoRA-Manager/1.0' + } + + @property + async def session(self) -> aiohttp.ClientSession: + """Get or create the global aiohttp session with optimized settings""" + if self._session is None or self._should_refresh_session(): + await self._create_session() + return self._session + + def _should_refresh_session(self) -> bool: + """Check if session should be refreshed""" + if self._session is None: + return True + + if not hasattr(self, '_session_created_at') or self._session_created_at is None: + return True + + # Refresh if session is older than timeout + if (datetime.now() - self._session_created_at).total_seconds() > self.session_timeout: + return True + + return False + + async def _create_session(self): + """Create a new aiohttp session with optimized settings""" + # Close existing session if any + if self._session is not None: + await self._session.close() + + # Optimize TCP connection parameters + connector = aiohttp.TCPConnector( + ssl=True, + limit=8, # Concurrent connections + ttl_dns_cache=300, # DNS cache timeout + force_close=False, # Keep connections for reuse + enable_cleanup_closed=True + ) + + # Configure timeout parameters + timeout = aiohttp.ClientTimeout( + total=None, # No total timeout for large downloads + connect=60, # Connection timeout + sock_read=None # No socket read timeout + ) + + self._session = aiohttp.ClientSession( + connector=connector, + trust_env=True, # Use system proxy settings + timeout=timeout + ) + self._session_created_at = datetime.now() + + logger.debug("Created new HTTP session") + + def _get_auth_headers(self, use_auth: bool = False) -> Dict[str, str]: + """Get headers with optional authentication""" + headers = self.default_headers.copy() + + if use_auth: + # Add CivitAI API key if available + api_key = settings.get('civitai_api_key') + if api_key: + headers['Authorization'] = f'Bearer {api_key}' + headers['Content-Type'] = 'application/json' + + return headers + + async def download_file( + self, + url: str, + save_path: str, + progress_callback: Optional[Callable[[float], None]] = None, + use_auth: bool = False, + custom_headers: Optional[Dict[str, str]] = None, + allow_resume: bool = True + ) -> Tuple[bool, str]: + """ + Download a file with resumable downloads and retry mechanism + + Args: + url: Download URL + save_path: Full path where the file should be saved + progress_callback: Optional callback for progress updates (0-100) + use_auth: Whether to include authentication headers (e.g., CivitAI API key) + custom_headers: Additional headers to include in request + allow_resume: Whether to support resumable downloads + + Returns: + Tuple[bool, str]: (success, save_path or error message) + """ + retry_count = 0 + part_path = save_path + '.part' if allow_resume else save_path + + # Prepare headers + headers = self._get_auth_headers(use_auth) + if custom_headers: + headers.update(custom_headers) + + # Get existing file size for resume + resume_offset = 0 + if allow_resume and os.path.exists(part_path): + resume_offset = os.path.getsize(part_path) + logger.info(f"Resuming download from offset {resume_offset} bytes") + + total_size = 0 + + while retry_count <= self.max_retries: + try: + session = await self.session + + # Add Range header for resume if we have partial data + request_headers = headers.copy() + if allow_resume and resume_offset > 0: + request_headers['Range'] = f'bytes={resume_offset}-' + + # Disable compression for better chunked downloads + request_headers['Accept-Encoding'] = 'identity' + + logger.debug(f"Download attempt {retry_count + 1}/{self.max_retries + 1} from: {url}") + if resume_offset > 0: + logger.debug(f"Requesting range from byte {resume_offset}") + + async with session.get(url, headers=request_headers, allow_redirects=True) as response: + # Handle different response codes + if response.status == 200: + # Full content response + if resume_offset > 0: + # Server doesn't support ranges, restart from beginning + logger.warning("Server doesn't support range requests, restarting download") + resume_offset = 0 + if os.path.exists(part_path): + os.remove(part_path) + elif response.status == 206: + # Partial content response (resume successful) + content_range = response.headers.get('Content-Range') + if content_range: + # Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048") + range_parts = content_range.split('/') + if len(range_parts) == 2: + total_size = int(range_parts[1]) + logger.info(f"Successfully resumed download from byte {resume_offset}") + elif response.status == 416: + # Range not satisfiable - file might be complete or corrupted + if allow_resume and os.path.exists(part_path): + part_size = os.path.getsize(part_path) + logger.warning(f"Range not satisfiable. Part file size: {part_size}") + # Try to get actual file size + head_response = await session.head(url, headers=headers) + if head_response.status == 200: + actual_size = int(head_response.headers.get('content-length', 0)) + if part_size == actual_size: + # File is complete, just rename it + if allow_resume: + os.rename(part_path, save_path) + if progress_callback: + await progress_callback(100) + return True, save_path + # Remove corrupted part file and restart + os.remove(part_path) + resume_offset = 0 + continue + elif response.status == 401: + logger.warning(f"Unauthorized access to resource: {url} (Status 401)") + return False, "Invalid or missing API key, or early access restriction." + elif response.status == 403: + logger.warning(f"Forbidden access to resource: {url} (Status 403)") + return False, "Access forbidden: You don't have permission to download this file." + elif response.status == 404: + logger.warning(f"Resource not found: {url} (Status 404)") + return False, "File not found - the download link may be invalid or expired." + else: + logger.error(f"Download failed for {url} with status {response.status}") + return False, f"Download failed with status {response.status}" + + # Get total file size for progress calculation (if not set from Content-Range) + if total_size == 0: + total_size = int(response.headers.get('content-length', 0)) + if response.status == 206: + # For partial content, add the offset to get total file size + total_size += resume_offset + + current_size = resume_offset + last_progress_report_time = datetime.now() + + # Ensure directory exists + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + # Stream download to file with progress updates + loop = asyncio.get_running_loop() + mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb' + with open(part_path, mode) as f: + async for chunk in response.content.iter_chunked(self.chunk_size): + if chunk: + # Run blocking file write in executor + await loop.run_in_executor(None, f.write, chunk) + current_size += len(chunk) + + # Limit progress update frequency to reduce overhead + now = datetime.now() + time_diff = (now - last_progress_report_time).total_seconds() + + if progress_callback and total_size and time_diff >= 1.0: + progress = (current_size / total_size) * 100 + await progress_callback(progress) + last_progress_report_time = now + + # Download completed successfully + # Verify file size if total_size was provided + final_size = os.path.getsize(part_path) + if total_size > 0 and final_size != total_size: + logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}") + # Don't treat this as fatal error, continue anyway + + # Atomically rename .part to final file (only if using resume) + if allow_resume and part_path != save_path: + max_rename_attempts = 5 + rename_attempt = 0 + rename_success = False + + while rename_attempt < max_rename_attempts and not rename_success: + try: + os.rename(part_path, save_path) + rename_success = True + except PermissionError as e: + rename_attempt += 1 + if rename_attempt < max_rename_attempts: + logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})") + await asyncio.sleep(2) + else: + logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}") + return False, f"Failed to finalize download: {str(e)}" + + # Ensure 100% progress is reported + if progress_callback: + await progress_callback(100) + + return True, save_path + + except (aiohttp.ClientError, aiohttp.ClientPayloadError, + aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e: + retry_count += 1 + logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}") + + if retry_count <= self.max_retries: + # Calculate delay with exponential backoff + delay = self.base_delay * (2 ** (retry_count - 1)) + logger.info(f"Retrying in {delay} seconds...") + await asyncio.sleep(delay) + + # Update resume offset for next attempt + if allow_resume and os.path.exists(part_path): + resume_offset = os.path.getsize(part_path) + logger.info(f"Will resume from byte {resume_offset}") + + # Refresh session to get new connection + await self._create_session() + continue + else: + logger.error(f"Max retries exceeded for download: {e}") + return False, f"Network error after {self.max_retries + 1} attempts: {str(e)}" + + except Exception as e: + logger.error(f"Unexpected download error: {e}") + return False, str(e) + + return False, f"Download failed after {self.max_retries + 1} attempts" + + async def download_to_memory( + self, + url: str, + use_auth: bool = False, + custom_headers: Optional[Dict[str, str]] = None + ) -> Tuple[bool, Union[bytes, str]]: + """ + Download a file to memory (for small files like preview images) + + Args: + url: Download URL + use_auth: Whether to include authentication headers + custom_headers: Additional headers to include in request + + Returns: + Tuple[bool, Union[bytes, str]]: (success, content or error message) + """ + try: + session = await self.session + + # Prepare headers + headers = self._get_auth_headers(use_auth) + if custom_headers: + headers.update(custom_headers) + + async with session.get(url, headers=headers) as response: + if response.status == 200: + content = await response.read() + return True, content + elif response.status == 401: + return False, "Unauthorized access - invalid or missing API key" + elif response.status == 403: + return False, "Access forbidden" + elif response.status == 404: + return False, "File not found" + else: + return False, f"Download failed with status {response.status}" + + except Exception as e: + logger.error(f"Error downloading to memory from {url}: {e}") + return False, str(e) + + async def get_response_headers( + self, + url: str, + use_auth: bool = False, + custom_headers: Optional[Dict[str, str]] = None + ) -> Tuple[bool, Union[Dict, str]]: + """ + Get response headers without downloading the full content + + Args: + url: URL to check + use_auth: Whether to include authentication headers + custom_headers: Additional headers to include in request + + Returns: + Tuple[bool, Union[Dict, str]]: (success, headers dict or error message) + """ + try: + session = await self.session + + # Prepare headers + headers = self._get_auth_headers(use_auth) + if custom_headers: + headers.update(custom_headers) + + async with session.head(url, headers=headers) as response: + if response.status == 200: + return True, dict(response.headers) + else: + return False, f"Head request failed with status {response.status}" + + except Exception as e: + logger.error(f"Error getting headers from {url}: {e}") + return False, str(e) + + async def make_request( + self, + method: str, + url: str, + use_auth: bool = False, + custom_headers: Optional[Dict[str, str]] = None, + **kwargs + ) -> Tuple[bool, Union[Dict, str]]: + """ + Make a generic HTTP request and return JSON response + + Args: + method: HTTP method (GET, POST, etc.) + url: Request URL + use_auth: Whether to include authentication headers + custom_headers: Additional headers to include in request + **kwargs: Additional arguments for aiohttp request + + Returns: + Tuple[bool, Union[Dict, str]]: (success, response data or error message) + """ + try: + session = await self.session + + # Prepare headers + headers = self._get_auth_headers(use_auth) + if custom_headers: + headers.update(custom_headers) + + async with session.request(method, url, headers=headers, **kwargs) as response: + if response.status == 200: + # Try to parse as JSON, fall back to text + try: + data = await response.json() + return True, data + except: + text = await response.text() + return True, text + elif response.status == 401: + return False, "Unauthorized access - invalid or missing API key" + elif response.status == 403: + return False, "Access forbidden" + elif response.status == 404: + return False, "Resource not found" + else: + return False, f"Request failed with status {response.status}" + + except Exception as e: + logger.error(f"Error making {method} request to {url}: {e}") + return False, str(e) + + async def close(self): + """Close the HTTP session""" + if self._session is not None: + await self._session.close() + self._session = None + self._session_created_at = None + logger.debug("Closed HTTP session") + + +# Global instance accessor +async def get_downloader() -> Downloader: + """Get the global downloader instance""" + return await Downloader.get_instance() diff --git a/py/services/metadata_archive_manager.py b/py/services/metadata_archive_manager.py index 3daf761b..a1ba9b74 100644 --- a/py/services/metadata_archive_manager.py +++ b/py/services/metadata_archive_manager.py @@ -1,9 +1,9 @@ import zipfile -import aiohttp import logging import asyncio from pathlib import Path from typing import Optional +from .downloader import get_downloader logger = logging.getLogger(__name__) @@ -67,6 +67,8 @@ class MetadataArchiveManager: async def _download_archive(self, progress_callback=None) -> bool: """Download the zip archive from one of the available URLs""" + downloader = await get_downloader() + for url in self.DOWNLOAD_URLS: try: logger.info(f"Attempting to download from {url}") @@ -74,26 +76,25 @@ class MetadataArchiveManager: if progress_callback: progress_callback("download", f"Downloading from {url}") - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status == 200: - total_size = int(response.headers.get('content-length', 0)) - downloaded = 0 - - with open(self.archive_path, 'wb') as f: - async for chunk in response.content.iter_chunked(8192): - f.write(chunk) - downloaded += len(chunk) - - if progress_callback and total_size > 0: - percentage = (downloaded / total_size) * 100 - progress_callback("download", f"Downloaded {percentage:.1f}%") - - logger.info(f"Successfully downloaded archive from {url}") - return True - else: - logger.warning(f"Failed to download from {url}: HTTP {response.status}") - continue + # Custom progress callback to report download progress + async def download_progress(progress): + if progress_callback: + progress_callback("download", f"Downloaded {progress:.1f}%") + + success, result = await downloader.download_file( + url=url, + save_path=str(self.archive_path), + progress_callback=download_progress, + use_auth=False, # Public download, no auth needed + allow_resume=True + ) + + if success: + logger.info(f"Successfully downloaded archive from {url}") + return True + else: + logger.warning(f"Failed to download from {url}: {result}") + continue except Exception as e: logger.warning(f"Error downloading from {url}: {e}") diff --git a/py/utils/example_images_download_manager.py b/py/utils/example_images_download_manager.py index db1b93f0..e3f46244 100644 --- a/py/utils/example_images_download_manager.py +++ b/py/utils/example_images_download_manager.py @@ -3,13 +3,13 @@ import os import asyncio import json import time -import aiohttp from aiohttp import web from ..services.service_registry import ServiceRegistry from ..utils.metadata_manager import MetadataManager from .example_images_processor import ExampleImagesProcessor from .example_images_metadata import MetadataUpdater from ..services.websocket_manager import ws_manager # Add this import at the top +from ..services.downloader import get_downloader logger = logging.getLogger(__name__) @@ -199,19 +199,8 @@ class DownloadManager: """Download example images for all models""" global is_downloading, download_progress - # Create independent download session - connector = aiohttp.TCPConnector( - ssl=True, - limit=3, - force_close=False, - enable_cleanup_closed=True - ) - timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60) - independent_session = aiohttp.ClientSession( - connector=connector, - trust_env=True, - timeout=timeout - ) + # Get unified downloader + downloader = await get_downloader() try: # Get scanners @@ -246,7 +235,7 @@ class DownloadManager: # Main logic for processing model is here, but actual operations are delegated to other classes was_remote_download = await DownloadManager._process_model( scanner_type, model, scanner, - output_dir, optimize, independent_session + output_dir, optimize, downloader ) # Update progress @@ -270,12 +259,6 @@ class DownloadManager: download_progress['end_time'] = time.time() finally: - # Close the independent session - try: - await independent_session.close() - except Exception as e: - logger.error(f"Error closing download session: {e}") - # Save final progress to file try: DownloadManager._save_progress(output_dir) @@ -286,7 +269,7 @@ class DownloadManager: is_downloading = False @staticmethod - async def _process_model(scanner_type, model, scanner, output_dir, optimize, independent_session): + async def _process_model(scanner_type, model, scanner, output_dir, optimize, downloader): """Process a single model download""" global download_progress @@ -347,7 +330,7 @@ class DownloadManager: images = model.get('civitai', {}).get('images', []) success, is_stale = await ExampleImagesProcessor.download_model_images( - model_hash, model_name, images, model_dir, optimize, independent_session + model_hash, model_name, images, model_dir, optimize, downloader ) # If metadata is stale, try to refresh it @@ -365,7 +348,7 @@ class DownloadManager: # Retry download with updated metadata updated_images = updated_model.get('civitai', {}).get('images', []) success, _ = await ExampleImagesProcessor.download_model_images( - model_hash, model_name, updated_images, model_dir, optimize, independent_session + model_hash, model_name, updated_images, model_dir, optimize, downloader ) download_progress['refreshed_models'].add(model_hash) @@ -529,19 +512,8 @@ class DownloadManager: """Download example images for specific models only - synchronous version""" global download_progress - # Create independent download session - connector = aiohttp.TCPConnector( - ssl=True, - limit=3, - force_close=False, - enable_cleanup_closed=True - ) - timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60) - independent_session = aiohttp.ClientSession( - connector=connector, - trust_env=True, - timeout=timeout - ) + # Get unified downloader + downloader = await get_downloader() try: # Get scanners @@ -586,7 +558,7 @@ class DownloadManager: # Force process this model regardless of previous status was_successful = await DownloadManager._process_specific_model( scanner_type, model, scanner, - output_dir, optimize, independent_session + output_dir, optimize, downloader ) if was_successful: @@ -650,14 +622,11 @@ class DownloadManager: raise finally: - # Close the independent session - try: - await independent_session.close() - except Exception as e: - logger.error(f"Error closing download session: {e}") + # No need to close any sessions since we use the global downloader + pass @staticmethod - async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, independent_session): + async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, downloader): """Process a specific model for forced download, ignoring previous download status""" global download_progress @@ -701,7 +670,7 @@ class DownloadManager: images = model.get('civitai', {}).get('images', []) success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( - model_hash, model_name, images, model_dir, optimize, independent_session + model_hash, model_name, images, model_dir, optimize, downloader ) # If metadata is stale, try to refresh it @@ -719,7 +688,7 @@ class DownloadManager: # Retry download with updated metadata updated_images = updated_model.get('civitai', {}).get('images', []) success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( - model_hash, model_name, updated_images, model_dir, optimize, independent_session + model_hash, model_name, updated_images, model_dir, optimize, downloader ) # Combine failed images from both attempts diff --git a/py/utils/example_images_processor.py b/py/utils/example_images_processor.py index 6d14e621..9dba4e2c 100644 --- a/py/utils/example_images_processor.py +++ b/py/utils/example_images_processor.py @@ -35,7 +35,7 @@ class ExampleImagesProcessor: return image_url @staticmethod - async def download_model_images(model_hash, model_name, model_images, model_dir, optimize, independent_session): + async def download_model_images(model_hash, model_name, model_images, model_dir, optimize, downloader): """Download images for a single model Returns: @@ -78,23 +78,25 @@ class ExampleImagesProcessor: try: logger.debug(f"Downloading {save_filename} for {model_name}") - # Download directly using the independent session - async with independent_session.get(image_url, timeout=60) as response: - if response.status == 200: - with open(save_path, 'wb') as f: - async for chunk in response.content.iter_chunked(8192): - if chunk: - f.write(chunk) - elif response.status == 404: - error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale" - logger.warning(error_msg) - model_success = False # Mark the model as failed due to 404 error - # Return early to trigger metadata refresh attempt - return False, True # (success, is_metadata_stale) - else: - error_msg = f"Failed to download file: {image_url}, status code: {response.status}" - logger.warning(error_msg) - model_success = False # Mark the model as failed + # Download using the unified downloader + success, content = await downloader.download_to_memory( + image_url, + use_auth=False # Example images don't need auth + ) + + if success: + with open(save_path, 'wb') as f: + f.write(content) + elif "404" in str(content): + error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale" + logger.warning(error_msg) + model_success = False # Mark the model as failed due to 404 error + # Return early to trigger metadata refresh attempt + return False, True # (success, is_metadata_stale) + else: + error_msg = f"Failed to download file: {image_url}, error: {content}" + logger.warning(error_msg) + model_success = False # Mark the model as failed except Exception as e: error_msg = f"Error downloading file {image_url}: {str(e)}" logger.error(error_msg) @@ -103,7 +105,7 @@ class ExampleImagesProcessor: return model_success, False # (success, is_metadata_stale) @staticmethod - async def download_model_images_with_tracking(model_hash, model_name, model_images, model_dir, optimize, independent_session): + async def download_model_images_with_tracking(model_hash, model_name, model_images, model_dir, optimize, downloader): """Download images for a single model with tracking of failed image URLs Returns: @@ -147,25 +149,27 @@ class ExampleImagesProcessor: try: logger.debug(f"Downloading {save_filename} for {model_name}") - # Download directly using the independent session - async with independent_session.get(image_url, timeout=60) as response: - if response.status == 200: - with open(save_path, 'wb') as f: - async for chunk in response.content.iter_chunked(8192): - if chunk: - f.write(chunk) - elif response.status == 404: - error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale" - logger.warning(error_msg) - model_success = False # Mark the model as failed due to 404 error - failed_images.append(image_url) # Track failed URL - # Return early to trigger metadata refresh attempt - return False, True, failed_images # (success, is_metadata_stale, failed_images) - else: - error_msg = f"Failed to download file: {image_url}, status code: {response.status}" - logger.warning(error_msg) - model_success = False # Mark the model as failed - failed_images.append(image_url) # Track failed URL + # Download using the unified downloader + success, content = await downloader.download_to_memory( + image_url, + use_auth=False # Example images don't need auth + ) + + if success: + with open(save_path, 'wb') as f: + f.write(content) + elif "404" in str(content): + error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale" + logger.warning(error_msg) + model_success = False # Mark the model as failed due to 404 error + failed_images.append(image_url) # Track failed URL + # Return early to trigger metadata refresh attempt + return False, True, failed_images # (success, is_metadata_stale, failed_images) + else: + error_msg = f"Failed to download file: {image_url}, error: {content}" + logger.warning(error_msg) + model_success = False # Mark the model as failed + failed_images.append(image_url) # Track failed URL except Exception as e: error_msg = f"Error downloading file {image_url}: {str(e)}" logger.error(error_msg)