优化下载性能:移除 SHA256 计算并使用 16MB chunks

- 移除下载后的 SHA256 计算,直接使用 API 返回的 hash 值
- 将 chunk size 从 4MB 调整为 16MB,使下载过程中的读写调用次数减少约 75%(传输的数据总量不变)
- 这有助于缓解 ComfyUI 执行期间的卡顿问题
This commit is contained in:
Will Miao
2026-03-25 19:29:48 +08:00
parent c5c1b8fd2a
commit 6f2a01dc86
2 changed files with 252 additions and 151 deletions

View File

@@ -19,7 +19,6 @@ from ..utils.civitai_utils import rewrite_preview_url
from ..utils.preview_selection import select_preview_media from ..utils.preview_selection import select_preview_media
from ..utils.utils import sanitize_folder_name from ..utils.utils import sanitize_folder_name
from ..utils.exif_utils import ExifUtils from ..utils.exif_utils import ExifUtils
from ..utils.file_utils import calculate_sha256
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .service_registry import ServiceRegistry from .service_registry import ServiceRegistry
from .settings_manager import get_settings_manager from .settings_manager import get_settings_manager
@@ -965,11 +964,12 @@ class DownloadManager:
for download_url in download_urls: for download_url in download_urls:
use_auth = download_url.startswith("https://civitai.com/api/download/") use_auth = download_url.startswith("https://civitai.com/api/download/")
download_kwargs = { download_kwargs = {
"progress_callback": lambda progress, "progress_callback": lambda progress, snapshot=None: (
snapshot=None: self._handle_download_progress( self._handle_download_progress(
progress, progress,
progress_callback, progress_callback,
snapshot, snapshot,
)
), ),
"use_auth": use_auth, # Only use authentication for Civitai downloads "use_auth": use_auth, # Only use authentication for Civitai downloads
} }
@@ -1238,7 +1238,8 @@ class DownloadManager:
entry.file_name = os.path.splitext(os.path.basename(file_path))[0] entry.file_name = os.path.splitext(os.path.basename(file_path))[0]
# Update size to actual downloaded file size # Update size to actual downloaded file size
entry.size = os.path.getsize(file_path) entry.size = os.path.getsize(file_path)
entry.sha256 = await calculate_sha256(file_path) # Use SHA256 from API metadata (already set in from_civitai_info)
# Do not recalculate to avoid blocking during ComfyUI execution
entries.append(entry) entries.append(entry)
return entries return entries

View File

@@ -44,7 +44,9 @@ class DownloadStreamControl:
self._event.set() self._event.set()
self._reconnect_requested = False self._reconnect_requested = False
self.last_progress_timestamp: Optional[float] = None self.last_progress_timestamp: Optional[float] = None
self.stall_timeout: float = float(stall_timeout) if stall_timeout is not None else 120.0 self.stall_timeout: float = (
float(stall_timeout) if stall_timeout is not None else 120.0
)
def is_set(self) -> bool: def is_set(self) -> bool:
return self._event.is_set() return self._event.is_set()
@@ -85,7 +87,9 @@ class DownloadStreamControl:
self.last_progress_timestamp = timestamp or datetime.now().timestamp() self.last_progress_timestamp = timestamp or datetime.now().timestamp()
self._reconnect_requested = False self._reconnect_requested = False
def time_since_last_progress(self, *, now: Optional[float] = None) -> Optional[float]: def time_since_last_progress(
self, *, now: Optional[float] = None
) -> Optional[float]:
if self.last_progress_timestamp is None: if self.last_progress_timestamp is None:
return None return None
reference = now if now is not None else datetime.now().timestamp() reference = now if now is not None else datetime.now().timestamp()
@@ -105,10 +109,10 @@ class DownloadStalledError(Exception):
class Downloader: class Downloader:
"""Unified downloader for all HTTP/HTTPS downloads in the application.""" """Unified downloader for all HTTP/HTTPS downloads in the application."""
_instance = None _instance = None
_lock = asyncio.Lock() _lock = asyncio.Lock()
@classmethod @classmethod
async def get_instance(cls): async def get_instance(cls):
"""Get singleton instance of Downloader""" """Get singleton instance of Downloader"""
@@ -116,35 +120,37 @@ class Downloader:
if cls._instance is None: if cls._instance is None:
cls._instance = cls() cls._instance = cls()
return cls._instance return cls._instance
def __init__(self): def __init__(self):
"""Initialize the downloader with optimal settings""" """Initialize the downloader with optimal settings"""
# Check if already initialized for singleton pattern # Check if already initialized for singleton pattern
if hasattr(self, '_initialized'): if hasattr(self, "_initialized"):
return return
self._initialized = True self._initialized = True
# Session management # Session management
self._session = None self._session = None
self._session_created_at = None self._session_created_at = None
self._proxy_url = None # Store proxy URL for current session self._proxy_url = None # Store proxy URL for current session
self._session_lock = asyncio.Lock() self._session_lock = asyncio.Lock()
# Configuration # Configuration
self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better throughput self.chunk_size = (
16 * 1024 * 1024
) # 16MB chunks to balance I/O reduction and memory usage
self.max_retries = 5 self.max_retries = 5
self.base_delay = 2.0 # Base delay for exponential backoff self.base_delay = 2.0 # Base delay for exponential backoff
self.session_timeout = 300 # 5 minutes self.session_timeout = 300 # 5 minutes
self.stall_timeout = self._resolve_stall_timeout() self.stall_timeout = self._resolve_stall_timeout()
# Default headers # Default headers
self.default_headers = { self.default_headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0', "User-Agent": "ComfyUI-LoRA-Manager/1.0",
# Explicitly request uncompressed payloads so aiohttp doesn't need optional # Explicitly request uncompressed payloads so aiohttp doesn't need optional
# decoders (e.g. zstandard) that may be missing in runtime environments. # decoders (e.g. zstandard) that may be missing in runtime environments.
'Accept-Encoding': 'identity', "Accept-Encoding": "identity",
} }
@property @property
async def session(self) -> aiohttp.ClientSession: async def session(self) -> aiohttp.ClientSession:
"""Get or create the global aiohttp session with optimized settings""" """Get or create the global aiohttp session with optimized settings"""
@@ -158,7 +164,7 @@ class Downloader:
@property @property
def proxy_url(self) -> Optional[str]: def proxy_url(self) -> Optional[str]:
"""Get the current proxy URL (initialize if needed)""" """Get the current proxy URL (initialize if needed)"""
if not hasattr(self, '_proxy_url'): if not hasattr(self, "_proxy_url"):
self._proxy_url = None self._proxy_url = None
return self._proxy_url return self._proxy_url
@@ -169,14 +175,14 @@ class Downloader:
try: try:
settings_manager = get_settings_manager() settings_manager = get_settings_manager()
settings_timeout = settings_manager.get('download_stall_timeout_seconds') settings_timeout = settings_manager.get("download_stall_timeout_seconds")
except Exception as exc: # pragma: no cover - defensive guard except Exception as exc: # pragma: no cover - defensive guard
logger.debug("Failed to read stall timeout from settings: %s", exc) logger.debug("Failed to read stall timeout from settings: %s", exc)
raw_value = ( raw_value = (
settings_timeout settings_timeout
if settings_timeout not in (None, "") if settings_timeout not in (None, "")
else os.environ.get('COMFYUI_DOWNLOAD_STALL_TIMEOUT') else os.environ.get("COMFYUI_DOWNLOAD_STALL_TIMEOUT")
) )
try: try:
@@ -190,93 +196,104 @@ class Downloader:
"""Check if session should be refreshed""" """Check if session should be refreshed"""
if self._session is None: if self._session is None:
return True return True
if not hasattr(self, '_session_created_at') or self._session_created_at is None: if not hasattr(self, "_session_created_at") or self._session_created_at is None:
return True return True
# Refresh if session is older than timeout # Refresh if session is older than timeout
if (datetime.now() - self._session_created_at).total_seconds() > self.session_timeout: if (
datetime.now() - self._session_created_at
).total_seconds() > self.session_timeout:
return True return True
return False return False
async def _create_session(self): async def _create_session(self):
"""Create a new aiohttp session with optimized settings. """Create a new aiohttp session with optimized settings.
Note: This is private and caller MUST hold self._session_lock. Note: This is private and caller MUST hold self._session_lock.
""" """
# Close existing session if any # Close existing session if any
if self._session is not None: if self._session is not None:
try: try:
await self._session.close() await self._session.close()
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
logger.warning(f"Error closing previous session: {e}") logger.warning(f"Error closing previous session: {e}")
finally: finally:
self._session = None self._session = None
# Check for app-level proxy settings # Check for app-level proxy settings
proxy_url = None proxy_url = None
settings_manager = get_settings_manager() settings_manager = get_settings_manager()
if settings_manager.get('proxy_enabled', False): if settings_manager.get("proxy_enabled", False):
proxy_host = settings_manager.get('proxy_host', '').strip() proxy_host = settings_manager.get("proxy_host", "").strip()
proxy_port = settings_manager.get('proxy_port', '').strip() proxy_port = settings_manager.get("proxy_port", "").strip()
proxy_type = settings_manager.get('proxy_type', 'http').lower() proxy_type = settings_manager.get("proxy_type", "http").lower()
proxy_username = settings_manager.get('proxy_username', '').strip() proxy_username = settings_manager.get("proxy_username", "").strip()
proxy_password = settings_manager.get('proxy_password', '').strip() proxy_password = settings_manager.get("proxy_password", "").strip()
if proxy_host and proxy_port: if proxy_host and proxy_port:
# Build proxy URL # Build proxy URL
if proxy_username and proxy_password: if proxy_username and proxy_password:
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}" proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
else: else:
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}" proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
logger.debug(f"Using app-level proxy: {proxy_type}://{proxy_host}:{proxy_port}") logger.debug(
f"Using app-level proxy: {proxy_type}://{proxy_host}:{proxy_port}"
)
logger.debug("Proxy mode: app-level proxy is active.") logger.debug("Proxy mode: app-level proxy is active.")
else: else:
logger.debug("Proxy mode: system-level proxy (trust_env) will be used if configured in environment.") logger.debug(
"Proxy mode: system-level proxy (trust_env) will be used if configured in environment."
)
# Optimize TCP connection parameters # Optimize TCP connection parameters
connector = aiohttp.TCPConnector( connector = aiohttp.TCPConnector(
ssl=True, ssl=True,
limit=8, # Concurrent connections limit=8, # Concurrent connections
ttl_dns_cache=300, # DNS cache timeout ttl_dns_cache=300, # DNS cache timeout
force_close=False, # Keep connections for reuse force_close=False, # Keep connections for reuse
enable_cleanup_closed=True enable_cleanup_closed=True,
) )
# Configure timeout parameters # Configure timeout parameters
timeout = aiohttp.ClientTimeout( timeout = aiohttp.ClientTimeout(
total=None, # No total timeout for large downloads total=None, # No total timeout for large downloads
connect=60, # Connection timeout connect=60, # Connection timeout
sock_read=300 # 5 minute socket read timeout sock_read=300, # 5 minute socket read timeout
) )
self._session = aiohttp.ClientSession( self._session = aiohttp.ClientSession(
connector=connector, connector=connector,
trust_env=proxy_url is None, # Only use system proxy if no app-level proxy is set trust_env=proxy_url
timeout=timeout is None, # Only use system proxy if no app-level proxy is set
timeout=timeout,
) )
# Store proxy URL for use in requests # Store proxy URL for use in requests
self._proxy_url = proxy_url self._proxy_url = proxy_url
self._session_created_at = datetime.now() self._session_created_at = datetime.now()
logger.debug("Created new HTTP session with proxy settings. App-level proxy: %s, System-level proxy (trust_env): %s", bool(proxy_url), proxy_url is None) logger.debug(
"Created new HTTP session with proxy settings. App-level proxy: %s, System-level proxy (trust_env): %s",
bool(proxy_url),
proxy_url is None,
)
def _get_auth_headers(self, use_auth: bool = False) -> Dict[str, str]: def _get_auth_headers(self, use_auth: bool = False) -> Dict[str, str]:
"""Get headers with optional authentication""" """Get headers with optional authentication"""
headers = self.default_headers.copy() headers = self.default_headers.copy()
if use_auth: if use_auth:
# Add CivitAI API key if available # Add CivitAI API key if available
settings_manager = get_settings_manager() settings_manager = get_settings_manager()
api_key = settings_manager.get('civitai_api_key') api_key = settings_manager.get("civitai_api_key")
if api_key: if api_key:
headers['Authorization'] = f'Bearer {api_key}' headers["Authorization"] = f"Bearer {api_key}"
headers['Content-Type'] = 'application/json' headers["Content-Type"] = "application/json"
return headers return headers
async def download_file( async def download_file(
self, self,
url: str, url: str,
@@ -289,7 +306,7 @@ class Downloader:
) -> Tuple[bool, str]: ) -> Tuple[bool, str]:
""" """
Download a file with resumable downloads and retry mechanism Download a file with resumable downloads and retry mechanism
Args: Args:
url: Download URL url: Download URL
save_path: Full path where the file should be saved save_path: Full path where the file should be saved
@@ -298,75 +315,96 @@ class Downloader:
custom_headers: Additional headers to include in request custom_headers: Additional headers to include in request
allow_resume: Whether to support resumable downloads allow_resume: Whether to support resumable downloads
pause_event: Optional stream control used to pause/resume and request reconnects pause_event: Optional stream control used to pause/resume and request reconnects
Returns: Returns:
Tuple[bool, str]: (success, save_path or error message) Tuple[bool, str]: (success, save_path or error message)
""" """
retry_count = 0 retry_count = 0
part_path = save_path + '.part' if allow_resume else save_path part_path = save_path + ".part" if allow_resume else save_path
# Prepare headers # Prepare headers
headers = self._get_auth_headers(use_auth) headers = self._get_auth_headers(use_auth)
if custom_headers: if custom_headers:
headers.update(custom_headers) headers.update(custom_headers)
# Get existing file size for resume # Get existing file size for resume
resume_offset = 0 resume_offset = 0
if allow_resume and os.path.exists(part_path): if allow_resume and os.path.exists(part_path):
resume_offset = os.path.getsize(part_path) resume_offset = os.path.getsize(part_path)
logger.info(f"Resuming download from offset {resume_offset} bytes") logger.info(f"Resuming download from offset {resume_offset} bytes")
total_size = 0 total_size = 0
while retry_count <= self.max_retries: while retry_count <= self.max_retries:
try: try:
session = await self.session session = await self.session
# Debug log for proxy mode at request time # Debug log for proxy mode at request time
if self.proxy_url: if self.proxy_url:
logger.debug(f"[download_file] Using app-level proxy: {self.proxy_url}") logger.debug(
f"[download_file] Using app-level proxy: {self.proxy_url}"
)
else: else:
logger.debug("[download_file] Using system-level proxy (trust_env) if configured.") logger.debug(
"[download_file] Using system-level proxy (trust_env) if configured."
)
# Add Range header for resume if we have partial data # Add Range header for resume if we have partial data
request_headers = headers.copy() request_headers = headers.copy()
if allow_resume and resume_offset > 0: if allow_resume and resume_offset > 0:
request_headers['Range'] = f'bytes={resume_offset}-' request_headers["Range"] = f"bytes={resume_offset}-"
# Disable compression for better chunked downloads # Disable compression for better chunked downloads
request_headers['Accept-Encoding'] = 'identity' request_headers["Accept-Encoding"] = "identity"
logger.debug(f"Download attempt {retry_count + 1}/{self.max_retries + 1} from: {url}") logger.debug(
f"Download attempt {retry_count + 1}/{self.max_retries + 1} from: {url}"
)
if resume_offset > 0: if resume_offset > 0:
logger.debug(f"Requesting range from byte {resume_offset}") logger.debug(f"Requesting range from byte {resume_offset}")
async with session.get(url, headers=request_headers, allow_redirects=True, proxy=self.proxy_url) as response: async with session.get(
url,
headers=request_headers,
allow_redirects=True,
proxy=self.proxy_url,
) as response:
# Handle different response codes # Handle different response codes
if response.status == 200: if response.status == 200:
# Full content response # Full content response
if resume_offset > 0: if resume_offset > 0:
# Server doesn't support ranges, restart from beginning # Server doesn't support ranges, restart from beginning
logger.warning("Server doesn't support range requests, restarting download") logger.warning(
"Server doesn't support range requests, restarting download"
)
resume_offset = 0 resume_offset = 0
if os.path.exists(part_path): if os.path.exists(part_path):
os.remove(part_path) os.remove(part_path)
elif response.status == 206: elif response.status == 206:
# Partial content response (resume successful) # Partial content response (resume successful)
content_range = response.headers.get('Content-Range') content_range = response.headers.get("Content-Range")
if content_range: if content_range:
# Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048") # Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048")
range_parts = content_range.split('/') range_parts = content_range.split("/")
if len(range_parts) == 2: if len(range_parts) == 2:
total_size = int(range_parts[1]) total_size = int(range_parts[1])
logger.info(f"Successfully resumed download from byte {resume_offset}") logger.info(
f"Successfully resumed download from byte {resume_offset}"
)
elif response.status == 416: elif response.status == 416:
# Range not satisfiable - file might be complete or corrupted # Range not satisfiable - file might be complete or corrupted
if allow_resume and os.path.exists(part_path): if allow_resume and os.path.exists(part_path):
part_size = os.path.getsize(part_path) part_size = os.path.getsize(part_path)
logger.warning(f"Range not satisfiable. Part file size: {part_size}") logger.warning(
f"Range not satisfiable. Part file size: {part_size}"
)
# Try to get actual file size # Try to get actual file size
head_response = await session.head(url, headers=headers, proxy=self.proxy_url) head_response = await session.head(
url, headers=headers, proxy=self.proxy_url
)
if head_response.status == 200: if head_response.status == 200:
actual_size = int(head_response.headers.get('content-length', 0)) actual_size = int(
head_response.headers.get("content-length", 0)
)
if part_size == actual_size: if part_size == actual_size:
# File is complete, just rename it # File is complete, just rename it
if allow_resume: if allow_resume:
@@ -388,25 +426,40 @@ class Downloader:
resume_offset = 0 resume_offset = 0
continue continue
elif response.status == 401: elif response.status == 401:
logger.warning(f"Unauthorized access to resource: {url} (Status 401)") logger.warning(
return False, "Invalid or missing API key, or early access restriction." f"Unauthorized access to resource: {url} (Status 401)"
)
return (
False,
"Invalid or missing API key, or early access restriction.",
)
elif response.status == 403: elif response.status == 403:
logger.warning(f"Forbidden access to resource: {url} (Status 403)") logger.warning(
return False, "Access forbidden: You don't have permission to download this file." f"Forbidden access to resource: {url} (Status 403)"
)
return (
False,
"Access forbidden: You don't have permission to download this file.",
)
elif response.status == 404: elif response.status == 404:
logger.warning(f"Resource not found: {url} (Status 404)") logger.warning(f"Resource not found: {url} (Status 404)")
return False, "File not found - the download link may be invalid or expired." return (
False,
"File not found - the download link may be invalid or expired.",
)
else: else:
logger.error(f"Download failed for {url} with status {response.status}") logger.error(
f"Download failed for {url} with status {response.status}"
)
return False, f"Download failed with status {response.status}" return False, f"Download failed with status {response.status}"
# Get total file size for progress calculation (if not set from Content-Range) # Get total file size for progress calculation (if not set from Content-Range)
if total_size == 0: if total_size == 0:
total_size = int(response.headers.get('content-length', 0)) total_size = int(response.headers.get("content-length", 0))
if response.status == 206: if response.status == 206:
# For partial content, add the offset to get total file size # For partial content, add the offset to get total file size
total_size += resume_offset total_size += resume_offset
current_size = resume_offset current_size = resume_offset
last_progress_report_time = datetime.now() last_progress_report_time = datetime.now()
progress_samples: deque[tuple[datetime, int]] = deque() progress_samples: deque[tuple[datetime, int]] = deque()
@@ -417,7 +470,7 @@ class Downloader:
# Stream download to file with progress updates # Stream download to file with progress updates
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb' mode = "ab" if (allow_resume and resume_offset > 0) else "wb"
control = pause_event control = pause_event
if control is not None: if control is not None:
@@ -425,7 +478,9 @@ class Downloader:
with open(part_path, mode) as f: with open(part_path, mode) as f:
while True: while True:
active_stall_timeout = control.stall_timeout if control else self.stall_timeout active_stall_timeout = (
control.stall_timeout if control else self.stall_timeout
)
if control is not None: if control is not None:
if control.is_paused(): if control.is_paused():
@@ -437,7 +492,9 @@ class Downloader:
"Reconnect requested after resume" "Reconnect requested after resume"
) )
elif control.consume_reconnect_request(): elif control.consume_reconnect_request():
raise DownloadRestartRequested("Reconnect requested") raise DownloadRestartRequested(
"Reconnect requested"
)
try: try:
chunk = await asyncio.wait_for( chunk = await asyncio.wait_for(
@@ -466,22 +523,32 @@ class Downloader:
control.mark_progress(timestamp=now.timestamp()) control.mark_progress(timestamp=now.timestamp())
# Limit progress update frequency to reduce overhead # Limit progress update frequency to reduce overhead
time_diff = (now - last_progress_report_time).total_seconds() time_diff = (
now - last_progress_report_time
).total_seconds()
if progress_callback and time_diff >= 1.0: if progress_callback and time_diff >= 1.0:
progress_samples.append((now, current_size)) progress_samples.append((now, current_size))
cutoff = now - timedelta(seconds=5) cutoff = now - timedelta(seconds=5)
while progress_samples and progress_samples[0][0] < cutoff: while (
progress_samples and progress_samples[0][0] < cutoff
):
progress_samples.popleft() progress_samples.popleft()
percent = (current_size / total_size) * 100 if total_size else 0.0 percent = (
(current_size / total_size) * 100
if total_size
else 0.0
)
bytes_per_second = 0.0 bytes_per_second = 0.0
if len(progress_samples) >= 2: if len(progress_samples) >= 2:
first_time, first_bytes = progress_samples[0] first_time, first_bytes = progress_samples[0]
last_time, last_bytes = progress_samples[-1] last_time, last_bytes = progress_samples[-1]
elapsed = (last_time - first_time).total_seconds() elapsed = (last_time - first_time).total_seconds()
if elapsed > 0: if elapsed > 0:
bytes_per_second = (last_bytes - first_bytes) / elapsed bytes_per_second = (
last_bytes - first_bytes
) / elapsed
progress_snapshot = DownloadProgress( progress_snapshot = DownloadProgress(
percent_complete=percent, percent_complete=percent,
@@ -491,21 +558,23 @@ class Downloader:
timestamp=now.timestamp(), timestamp=now.timestamp(),
) )
await self._dispatch_progress_callback(progress_callback, progress_snapshot) await self._dispatch_progress_callback(
progress_callback, progress_snapshot
)
last_progress_report_time = now last_progress_report_time = now
# Download completed successfully # Download completed successfully
# Verify file size integrity before finalizing # Verify file size integrity before finalizing
final_size = os.path.getsize(part_path) if os.path.exists(part_path) else 0 final_size = (
os.path.getsize(part_path) if os.path.exists(part_path) else 0
)
expected_size = total_size if total_size > 0 else None expected_size = total_size if total_size > 0 else None
integrity_error: Optional[str] = None integrity_error: Optional[str] = None
if final_size <= 0: if final_size <= 0:
integrity_error = "Downloaded file is empty" integrity_error = "Downloaded file is empty"
elif expected_size is not None and final_size != expected_size: elif expected_size is not None and final_size != expected_size:
integrity_error = ( integrity_error = f"File size mismatch. Expected: {expected_size}, Got: {final_size}"
f"File size mismatch. Expected: {expected_size}, Got: {final_size}"
)
if integrity_error is not None: if integrity_error is not None:
logger.error( logger.error(
@@ -554,8 +623,10 @@ class Downloader:
max_rename_attempts = 5 max_rename_attempts = 5
rename_attempt = 0 rename_attempt = 0
rename_success = False rename_success = False
while rename_attempt < max_rename_attempts and not rename_success: while (
rename_attempt < max_rename_attempts and not rename_success
):
try: try:
# If the destination file exists, remove it first (Windows safe) # If the destination file exists, remove it first (Windows safe)
if os.path.exists(save_path): if os.path.exists(save_path):
@@ -566,11 +637,18 @@ class Downloader:
except PermissionError as e: except PermissionError as e:
rename_attempt += 1 rename_attempt += 1
if rename_attempt < max_rename_attempts: if rename_attempt < max_rename_attempts:
logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})") logger.info(
f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})"
)
await asyncio.sleep(2) await asyncio.sleep(2)
else: else:
logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}") logger.error(
return False, f"Failed to finalize download: {str(e)}" f"Failed to rename file after {max_rename_attempts} attempts: {e}"
)
return (
False,
f"Failed to finalize download: {str(e)}",
)
final_size = os.path.getsize(save_path) final_size = os.path.getsize(save_path)
@@ -583,11 +661,12 @@ class Downloader:
bytes_per_second=0.0, bytes_per_second=0.0,
timestamp=datetime.now().timestamp(), timestamp=datetime.now().timestamp(),
) )
await self._dispatch_progress_callback(progress_callback, final_snapshot) await self._dispatch_progress_callback(
progress_callback, final_snapshot
)
return True, save_path return True, save_path
except ( except (
aiohttp.ClientError, aiohttp.ClientError,
aiohttp.ClientPayloadError, aiohttp.ClientPayloadError,
@@ -597,30 +676,35 @@ class Downloader:
DownloadRestartRequested, DownloadRestartRequested,
) as e: ) as e:
retry_count += 1 retry_count += 1
logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}") logger.warning(
f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}"
)
if retry_count <= self.max_retries: if retry_count <= self.max_retries:
# Calculate delay with exponential backoff # Calculate delay with exponential backoff
delay = self.base_delay * (2 ** (retry_count - 1)) delay = self.base_delay * (2 ** (retry_count - 1))
logger.info(f"Retrying in {delay} seconds...") logger.info(f"Retrying in {delay} seconds...")
await asyncio.sleep(delay) await asyncio.sleep(delay)
# Update resume offset for next attempt # Update resume offset for next attempt
if allow_resume and os.path.exists(part_path): if allow_resume and os.path.exists(part_path):
resume_offset = os.path.getsize(part_path) resume_offset = os.path.getsize(part_path)
logger.info(f"Will resume from byte {resume_offset}") logger.info(f"Will resume from byte {resume_offset}")
# Refresh session to get new connection # Refresh session to get new connection
await self._create_session() await self._create_session()
continue continue
else: else:
logger.error(f"Max retries exceeded for download: {e}") logger.error(f"Max retries exceeded for download: {e}")
return False, f"Network error after {self.max_retries + 1} attempts: {str(e)}" return (
False,
f"Network error after {self.max_retries + 1} attempts: {str(e)}",
)
except Exception as e: except Exception as e:
logger.error(f"Unexpected download error: {e}") logger.error(f"Unexpected download error: {e}")
return False, str(e) return False, str(e)
return False, f"Download failed after {self.max_retries + 1} attempts" return False, f"Download failed after {self.max_retries + 1} attempts"
async def _dispatch_progress_callback( async def _dispatch_progress_callback(
@@ -645,17 +729,17 @@ class Downloader:
url: str, url: str,
use_auth: bool = False, use_auth: bool = False,
custom_headers: Optional[Dict[str, str]] = None, custom_headers: Optional[Dict[str, str]] = None,
return_headers: bool = False return_headers: bool = False,
) -> Tuple[bool, Union[bytes, str], Optional[Dict]]: ) -> Tuple[bool, Union[bytes, str], Optional[Dict]]:
""" """
Download a file to memory (for small files like preview images) Download a file to memory (for small files like preview images)
Args: Args:
url: Download URL url: Download URL
use_auth: Whether to include authentication headers use_auth: Whether to include authentication headers
custom_headers: Additional headers to include in request custom_headers: Additional headers to include in request
return_headers: Whether to return response headers along with content return_headers: Whether to return response headers along with content
Returns: Returns:
Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested) Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested)
""" """
@@ -663,16 +747,22 @@ class Downloader:
session = await self.session session = await self.session
# Debug log for proxy mode at request time # Debug log for proxy mode at request time
if self.proxy_url: if self.proxy_url:
logger.debug(f"[download_to_memory] Using app-level proxy: {self.proxy_url}") logger.debug(
f"[download_to_memory] Using app-level proxy: {self.proxy_url}"
)
else: else:
logger.debug("[download_to_memory] Using system-level proxy (trust_env) if configured.") logger.debug(
"[download_to_memory] Using system-level proxy (trust_env) if configured."
)
# Prepare headers # Prepare headers
headers = self._get_auth_headers(use_auth) headers = self._get_auth_headers(use_auth)
if custom_headers: if custom_headers:
headers.update(custom_headers) headers.update(custom_headers)
async with session.get(url, headers=headers, proxy=self.proxy_url) as response: async with session.get(
url, headers=headers, proxy=self.proxy_url
) as response:
if response.status == 200: if response.status == 200:
content = await response.read() content = await response.read()
if return_headers: if return_headers:
@@ -691,25 +781,25 @@ class Downloader:
else: else:
error_msg = f"Download failed with status {response.status}" error_msg = f"Download failed with status {response.status}"
return False, error_msg, None return False, error_msg, None
except Exception as e: except Exception as e:
logger.error(f"Error downloading to memory from {url}: {e}") logger.error(f"Error downloading to memory from {url}: {e}")
return False, str(e), None return False, str(e), None
async def get_response_headers(
    self,
    url: str,
    use_auth: bool = False,
    custom_headers: Optional[Dict[str, str]] = None,
) -> Tuple[bool, Union[Dict, str]]:
    """
    Get response headers without downloading the full content.

    Sends a single HEAD request through the shared session and reports the
    outcome as a ``(success, payload)`` tuple instead of raising.

    Args:
        url: URL to check
        use_auth: Whether to include authentication headers
        custom_headers: Additional headers to include in request; these are
            merged on top of the auth headers, so they win on key clashes

    Returns:
        Tuple[bool, Union[Dict, str]]: ``(True, headers dict)`` when the
        response status is exactly 200; otherwise ``(False, error message)``,
        where the message is either the non-200 status description or the
        string form of the exception that was caught.
    """
    try:
        session = await self.session

        # Debug log for proxy mode at request time.
        # Lazy %-style arguments: the message is only formatted when DEBUG
        # logging is actually enabled.
        if self.proxy_url:
            logger.debug(
                "[get_response_headers] Using app-level proxy: %s",
                self.proxy_url,
            )
        else:
            logger.debug(
                "[get_response_headers] Using system-level proxy (trust_env) if configured."
            )

        # Prepare headers: start from the auth headers, then layer the
        # caller-supplied ones on top.
        headers = self._get_auth_headers(use_auth)
        if custom_headers:
            headers.update(custom_headers)

        async with session.head(
            url, headers=headers, proxy=self.proxy_url
        ) as response:
            # Only an exact 200 counts as success; every other status is
            # reported back to the caller as a failure message.
            if response.status != 200:
                return False, f"Head request failed with status {response.status}"
            return True, dict(response.headers)

    except Exception as e:
        # Boundary handler: callers expect a (False, message) tuple rather
        # than a propagated exception.
        logger.error("Error getting headers from %s: %s", url, e)
        return False, str(e)
async def make_request( async def make_request(
self, self,
method: str, method: str,
url: str, url: str,
use_auth: bool = False, use_auth: bool = False,
custom_headers: Optional[Dict[str, str]] = None, custom_headers: Optional[Dict[str, str]] = None,
**kwargs **kwargs,
) -> Tuple[bool, Union[Dict, str]]: ) -> Tuple[bool, Union[Dict, str]]:
""" """
Make a generic HTTP request and return JSON response Make a generic HTTP request and return JSON response
Args: Args:
method: HTTP method (GET, POST, etc.) method: HTTP method (GET, POST, etc.)
url: Request URL url: Request URL
use_auth: Whether to include authentication headers use_auth: Whether to include authentication headers
custom_headers: Additional headers to include in request custom_headers: Additional headers to include in request
**kwargs: Additional arguments for aiohttp request **kwargs: Additional arguments for aiohttp request
Returns: Returns:
Tuple[bool, Union[Dict, str]]: (success, response data or error message) Tuple[bool, Union[Dict, str]]: (success, response data or error message)
""" """
@@ -763,18 +859,22 @@ class Downloader:
if self.proxy_url: if self.proxy_url:
logger.debug(f"[make_request] Using app-level proxy: {self.proxy_url}") logger.debug(f"[make_request] Using app-level proxy: {self.proxy_url}")
else: else:
logger.debug("[make_request] Using system-level proxy (trust_env) if configured.") logger.debug(
"[make_request] Using system-level proxy (trust_env) if configured."
)
# Prepare headers # Prepare headers
headers = self._get_auth_headers(use_auth) headers = self._get_auth_headers(use_auth)
if custom_headers: if custom_headers:
headers.update(custom_headers) headers.update(custom_headers)
# Add proxy to kwargs if not already present # Add proxy to kwargs if not already present
if 'proxy' not in kwargs: if "proxy" not in kwargs:
kwargs['proxy'] = self.proxy_url kwargs["proxy"] = self.proxy_url
async with session.request(method, url, headers=headers, **kwargs) as response: async with session.request(
method, url, headers=headers, **kwargs
) as response:
if response.status == 200: if response.status == 200:
# Try to parse as JSON, fall back to text # Try to parse as JSON, fall back to text
try: try:
@@ -804,11 +904,11 @@ class Downloader:
) )
else: else:
return False, f"Request failed with status {response.status}" return False, f"Request failed with status {response.status}"
except Exception as e: except Exception as e:
logger.error(f"Error making {method} request to {url}: {e}") logger.error(f"Error making {method} request to {url}: {e}")
return False, str(e) return False, str(e)
async def close(self): async def close(self):
"""Close the HTTP session""" """Close the HTTP session"""
if self._session is not None: if self._session is not None:
@@ -817,7 +917,7 @@ class Downloader:
self._session_created_at = None self._session_created_at = None
self._proxy_url = None self._proxy_url = None
logger.debug("Closed HTTP session") logger.debug("Closed HTTP session")
async def refresh_session(self): async def refresh_session(self):
"""Force refresh the HTTP session (useful when proxy settings change)""" """Force refresh the HTTP session (useful when proxy settings change)"""
async with self._session_lock: async with self._session_lock: