mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-05-06 16:36:45 -03:00
Compare commits
3 Commits
e97648c70b
...
de3d0571f8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
de3d0571f8 | ||
|
|
6f2a01dc86 | ||
|
|
c5c1b8fd2a |
@@ -490,14 +490,33 @@ class CivitaiClient:
|
||||
"""
|
||||
try:
|
||||
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
|
||||
requested_id = int(image_id)
|
||||
|
||||
logger.debug(f"Fetching image info for ID: {image_id}")
|
||||
success, result = await self._make_request("GET", url, use_auth=True)
|
||||
|
||||
if success:
|
||||
if result and "items" in result and len(result["items"]) > 0:
|
||||
logger.debug(f"Successfully fetched image info for ID: {image_id}")
|
||||
return result["items"][0]
|
||||
if result and "items" in result and isinstance(result["items"], list):
|
||||
items = result["items"]
|
||||
|
||||
# First, try to find the item with matching ID
|
||||
for item in items:
|
||||
if isinstance(item, dict) and item.get("id") == requested_id:
|
||||
logger.debug(f"Successfully fetched image info for ID: {image_id}")
|
||||
return item
|
||||
|
||||
# No matching ID found - log warning with details about returned items
|
||||
returned_ids = [
|
||||
item.get("id") for item in items
|
||||
if isinstance(item, dict) and "id" in item
|
||||
]
|
||||
logger.warning(
|
||||
f"CivitAI API returned no matching image for requested ID {image_id}. "
|
||||
f"Returned {len(items)} item(s) with IDs: {returned_ids}. "
|
||||
f"This may indicate the image was deleted, hidden, or there is a database lag."
|
||||
)
|
||||
return None
|
||||
|
||||
logger.warning(f"No image found with ID: {image_id}")
|
||||
return None
|
||||
|
||||
@@ -505,6 +524,10 @@ class CivitaiClient:
|
||||
return None
|
||||
except RateLimitError:
|
||||
raise
|
||||
except ValueError as e:
|
||||
error_msg = f"Invalid image ID format: {image_id}"
|
||||
logger.error(error_msg)
|
||||
return None
|
||||
except Exception as e:
|
||||
error_msg = f"Error fetching image info: {e}"
|
||||
logger.error(error_msg)
|
||||
|
||||
@@ -19,7 +19,6 @@ from ..utils.civitai_utils import rewrite_preview_url
|
||||
from ..utils.preview_selection import select_preview_media
|
||||
from ..utils.utils import sanitize_folder_name
|
||||
from ..utils.exif_utils import ExifUtils
|
||||
from ..utils.file_utils import calculate_sha256
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .service_registry import ServiceRegistry
|
||||
from .settings_manager import get_settings_manager
|
||||
@@ -965,11 +964,12 @@ class DownloadManager:
|
||||
for download_url in download_urls:
|
||||
use_auth = download_url.startswith("https://civitai.com/api/download/")
|
||||
download_kwargs = {
|
||||
"progress_callback": lambda progress,
|
||||
snapshot=None: self._handle_download_progress(
|
||||
progress,
|
||||
progress_callback,
|
||||
snapshot,
|
||||
"progress_callback": lambda progress, snapshot=None: (
|
||||
self._handle_download_progress(
|
||||
progress,
|
||||
progress_callback,
|
||||
snapshot,
|
||||
)
|
||||
),
|
||||
"use_auth": use_auth, # Only use authentication for Civitai downloads
|
||||
}
|
||||
@@ -1238,7 +1238,8 @@ class DownloadManager:
|
||||
entry.file_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
# Update size to actual downloaded file size
|
||||
entry.size = os.path.getsize(file_path)
|
||||
entry.sha256 = await calculate_sha256(file_path)
|
||||
# Use SHA256 from API metadata (already set in from_civitai_info)
|
||||
# Do not recalculate to avoid blocking during ComfyUI execution
|
||||
entries.append(entry)
|
||||
|
||||
return entries
|
||||
|
||||
@@ -44,7 +44,9 @@ class DownloadStreamControl:
|
||||
self._event.set()
|
||||
self._reconnect_requested = False
|
||||
self.last_progress_timestamp: Optional[float] = None
|
||||
self.stall_timeout: float = float(stall_timeout) if stall_timeout is not None else 120.0
|
||||
self.stall_timeout: float = (
|
||||
float(stall_timeout) if stall_timeout is not None else 120.0
|
||||
)
|
||||
|
||||
def is_set(self) -> bool:
|
||||
return self._event.is_set()
|
||||
@@ -85,7 +87,9 @@ class DownloadStreamControl:
|
||||
self.last_progress_timestamp = timestamp or datetime.now().timestamp()
|
||||
self._reconnect_requested = False
|
||||
|
||||
def time_since_last_progress(self, *, now: Optional[float] = None) -> Optional[float]:
|
||||
def time_since_last_progress(
|
||||
self, *, now: Optional[float] = None
|
||||
) -> Optional[float]:
|
||||
if self.last_progress_timestamp is None:
|
||||
return None
|
||||
reference = now if now is not None else datetime.now().timestamp()
|
||||
@@ -105,10 +109,10 @@ class DownloadStalledError(Exception):
|
||||
|
||||
class Downloader:
|
||||
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
|
||||
|
||||
|
||||
_instance = None
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
@classmethod
|
||||
async def get_instance(cls):
|
||||
"""Get singleton instance of Downloader"""
|
||||
@@ -116,35 +120,37 @@ class Downloader:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls()
|
||||
return cls._instance
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the downloader with optimal settings"""
|
||||
# Check if already initialized for singleton pattern
|
||||
if hasattr(self, '_initialized'):
|
||||
if hasattr(self, "_initialized"):
|
||||
return
|
||||
self._initialized = True
|
||||
|
||||
|
||||
# Session management
|
||||
self._session = None
|
||||
self._session_created_at = None
|
||||
self._proxy_url = None # Store proxy URL for current session
|
||||
self._session_lock = asyncio.Lock()
|
||||
|
||||
|
||||
# Configuration
|
||||
self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better throughput
|
||||
self.chunk_size = (
|
||||
16 * 1024 * 1024
|
||||
) # 16MB chunks to balance I/O reduction and memory usage
|
||||
self.max_retries = 5
|
||||
self.base_delay = 2.0 # Base delay for exponential backoff
|
||||
self.session_timeout = 300 # 5 minutes
|
||||
self.stall_timeout = self._resolve_stall_timeout()
|
||||
|
||||
|
||||
# Default headers
|
||||
self.default_headers = {
|
||||
'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
|
||||
"User-Agent": "ComfyUI-LoRA-Manager/1.0",
|
||||
# Explicitly request uncompressed payloads so aiohttp doesn't need optional
|
||||
# decoders (e.g. zstandard) that may be missing in runtime environments.
|
||||
'Accept-Encoding': 'identity',
|
||||
"Accept-Encoding": "identity",
|
||||
}
|
||||
|
||||
|
||||
@property
|
||||
async def session(self) -> aiohttp.ClientSession:
|
||||
"""Get or create the global aiohttp session with optimized settings"""
|
||||
@@ -158,7 +164,7 @@ class Downloader:
|
||||
@property
|
||||
def proxy_url(self) -> Optional[str]:
|
||||
"""Get the current proxy URL (initialize if needed)"""
|
||||
if not hasattr(self, '_proxy_url'):
|
||||
if not hasattr(self, "_proxy_url"):
|
||||
self._proxy_url = None
|
||||
return self._proxy_url
|
||||
|
||||
@@ -169,14 +175,14 @@ class Downloader:
|
||||
|
||||
try:
|
||||
settings_manager = get_settings_manager()
|
||||
settings_timeout = settings_manager.get('download_stall_timeout_seconds')
|
||||
settings_timeout = settings_manager.get("download_stall_timeout_seconds")
|
||||
except Exception as exc: # pragma: no cover - defensive guard
|
||||
logger.debug("Failed to read stall timeout from settings: %s", exc)
|
||||
|
||||
raw_value = (
|
||||
settings_timeout
|
||||
if settings_timeout not in (None, "")
|
||||
else os.environ.get('COMFYUI_DOWNLOAD_STALL_TIMEOUT')
|
||||
else os.environ.get("COMFYUI_DOWNLOAD_STALL_TIMEOUT")
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -190,93 +196,104 @@ class Downloader:
|
||||
"""Check if session should be refreshed"""
|
||||
if self._session is None:
|
||||
return True
|
||||
|
||||
if not hasattr(self, '_session_created_at') or self._session_created_at is None:
|
||||
|
||||
if not hasattr(self, "_session_created_at") or self._session_created_at is None:
|
||||
return True
|
||||
|
||||
|
||||
# Refresh if session is older than timeout
|
||||
if (datetime.now() - self._session_created_at).total_seconds() > self.session_timeout:
|
||||
if (
|
||||
datetime.now() - self._session_created_at
|
||||
).total_seconds() > self.session_timeout:
|
||||
return True
|
||||
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def _create_session(self):
|
||||
"""Create a new aiohttp session with optimized settings.
|
||||
|
||||
|
||||
Note: This is private and caller MUST hold self._session_lock.
|
||||
"""
|
||||
# Close existing session if any
|
||||
if self._session is not None:
|
||||
try:
|
||||
await self._session.close()
|
||||
except Exception as e: # pragma: no cover
|
||||
except Exception as e: # pragma: no cover
|
||||
logger.warning(f"Error closing previous session: {e}")
|
||||
finally:
|
||||
self._session = None
|
||||
|
||||
|
||||
# Check for app-level proxy settings
|
||||
proxy_url = None
|
||||
settings_manager = get_settings_manager()
|
||||
if settings_manager.get('proxy_enabled', False):
|
||||
proxy_host = settings_manager.get('proxy_host', '').strip()
|
||||
proxy_port = settings_manager.get('proxy_port', '').strip()
|
||||
proxy_type = settings_manager.get('proxy_type', 'http').lower()
|
||||
proxy_username = settings_manager.get('proxy_username', '').strip()
|
||||
proxy_password = settings_manager.get('proxy_password', '').strip()
|
||||
|
||||
if settings_manager.get("proxy_enabled", False):
|
||||
proxy_host = settings_manager.get("proxy_host", "").strip()
|
||||
proxy_port = settings_manager.get("proxy_port", "").strip()
|
||||
proxy_type = settings_manager.get("proxy_type", "http").lower()
|
||||
proxy_username = settings_manager.get("proxy_username", "").strip()
|
||||
proxy_password = settings_manager.get("proxy_password", "").strip()
|
||||
|
||||
if proxy_host and proxy_port:
|
||||
# Build proxy URL
|
||||
if proxy_username and proxy_password:
|
||||
proxy_url = f"{proxy_type}://{proxy_username}:{proxy_password}@{proxy_host}:{proxy_port}"
|
||||
else:
|
||||
proxy_url = f"{proxy_type}://{proxy_host}:{proxy_port}"
|
||||
|
||||
logger.debug(f"Using app-level proxy: {proxy_type}://{proxy_host}:{proxy_port}")
|
||||
|
||||
logger.debug(
|
||||
f"Using app-level proxy: {proxy_type}://{proxy_host}:{proxy_port}"
|
||||
)
|
||||
logger.debug("Proxy mode: app-level proxy is active.")
|
||||
else:
|
||||
logger.debug("Proxy mode: system-level proxy (trust_env) will be used if configured in environment.")
|
||||
logger.debug(
|
||||
"Proxy mode: system-level proxy (trust_env) will be used if configured in environment."
|
||||
)
|
||||
# Optimize TCP connection parameters
|
||||
connector = aiohttp.TCPConnector(
|
||||
ssl=True,
|
||||
limit=8, # Concurrent connections
|
||||
ttl_dns_cache=300, # DNS cache timeout
|
||||
force_close=False, # Keep connections for reuse
|
||||
enable_cleanup_closed=True
|
||||
enable_cleanup_closed=True,
|
||||
)
|
||||
|
||||
|
||||
# Configure timeout parameters
|
||||
timeout = aiohttp.ClientTimeout(
|
||||
total=None, # No total timeout for large downloads
|
||||
connect=60, # Connection timeout
|
||||
sock_read=300 # 5 minute socket read timeout
|
||||
sock_read=300, # 5 minute socket read timeout
|
||||
)
|
||||
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
trust_env=proxy_url is None, # Only use system proxy if no app-level proxy is set
|
||||
timeout=timeout
|
||||
trust_env=proxy_url
|
||||
is None, # Only use system proxy if no app-level proxy is set
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
||||
# Store proxy URL for use in requests
|
||||
self._proxy_url = proxy_url
|
||||
self._session_created_at = datetime.now()
|
||||
|
||||
logger.debug("Created new HTTP session with proxy settings. App-level proxy: %s, System-level proxy (trust_env): %s", bool(proxy_url), proxy_url is None)
|
||||
|
||||
|
||||
logger.debug(
|
||||
"Created new HTTP session with proxy settings. App-level proxy: %s, System-level proxy (trust_env): %s",
|
||||
bool(proxy_url),
|
||||
proxy_url is None,
|
||||
)
|
||||
|
||||
def _get_auth_headers(self, use_auth: bool = False) -> Dict[str, str]:
|
||||
"""Get headers with optional authentication"""
|
||||
headers = self.default_headers.copy()
|
||||
|
||||
|
||||
if use_auth:
|
||||
# Add CivitAI API key if available
|
||||
settings_manager = get_settings_manager()
|
||||
api_key = settings_manager.get('civitai_api_key')
|
||||
api_key = settings_manager.get("civitai_api_key")
|
||||
if api_key:
|
||||
headers['Authorization'] = f'Bearer {api_key}'
|
||||
headers['Content-Type'] = 'application/json'
|
||||
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
return headers
|
||||
|
||||
|
||||
async def download_file(
|
||||
self,
|
||||
url: str,
|
||||
@@ -289,7 +306,7 @@ class Downloader:
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
Download a file with resumable downloads and retry mechanism
|
||||
|
||||
|
||||
Args:
|
||||
url: Download URL
|
||||
save_path: Full path where the file should be saved
|
||||
@@ -298,75 +315,96 @@ class Downloader:
|
||||
custom_headers: Additional headers to include in request
|
||||
allow_resume: Whether to support resumable downloads
|
||||
pause_event: Optional stream control used to pause/resume and request reconnects
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (success, save_path or error message)
|
||||
"""
|
||||
retry_count = 0
|
||||
part_path = save_path + '.part' if allow_resume else save_path
|
||||
|
||||
part_path = save_path + ".part" if allow_resume else save_path
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
|
||||
# Get existing file size for resume
|
||||
resume_offset = 0
|
||||
if allow_resume and os.path.exists(part_path):
|
||||
resume_offset = os.path.getsize(part_path)
|
||||
logger.info(f"Resuming download from offset {resume_offset} bytes")
|
||||
|
||||
|
||||
total_size = 0
|
||||
|
||||
|
||||
while retry_count <= self.max_retries:
|
||||
try:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[download_file] Using app-level proxy: {self.proxy_url}")
|
||||
logger.debug(
|
||||
f"[download_file] Using app-level proxy: {self.proxy_url}"
|
||||
)
|
||||
else:
|
||||
logger.debug("[download_file] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
logger.debug(
|
||||
"[download_file] Using system-level proxy (trust_env) if configured."
|
||||
)
|
||||
|
||||
# Add Range header for resume if we have partial data
|
||||
request_headers = headers.copy()
|
||||
if allow_resume and resume_offset > 0:
|
||||
request_headers['Range'] = f'bytes={resume_offset}-'
|
||||
|
||||
request_headers["Range"] = f"bytes={resume_offset}-"
|
||||
|
||||
# Disable compression for better chunked downloads
|
||||
request_headers['Accept-Encoding'] = 'identity'
|
||||
|
||||
logger.debug(f"Download attempt {retry_count + 1}/{self.max_retries + 1} from: {url}")
|
||||
request_headers["Accept-Encoding"] = "identity"
|
||||
|
||||
logger.debug(
|
||||
f"Download attempt {retry_count + 1}/{self.max_retries + 1} from: {url}"
|
||||
)
|
||||
if resume_offset > 0:
|
||||
logger.debug(f"Requesting range from byte {resume_offset}")
|
||||
|
||||
async with session.get(url, headers=request_headers, allow_redirects=True, proxy=self.proxy_url) as response:
|
||||
|
||||
async with session.get(
|
||||
url,
|
||||
headers=request_headers,
|
||||
allow_redirects=True,
|
||||
proxy=self.proxy_url,
|
||||
) as response:
|
||||
# Handle different response codes
|
||||
if response.status == 200:
|
||||
# Full content response
|
||||
if resume_offset > 0:
|
||||
# Server doesn't support ranges, restart from beginning
|
||||
logger.warning("Server doesn't support range requests, restarting download")
|
||||
logger.warning(
|
||||
"Server doesn't support range requests, restarting download"
|
||||
)
|
||||
resume_offset = 0
|
||||
if os.path.exists(part_path):
|
||||
os.remove(part_path)
|
||||
elif response.status == 206:
|
||||
# Partial content response (resume successful)
|
||||
content_range = response.headers.get('Content-Range')
|
||||
content_range = response.headers.get("Content-Range")
|
||||
if content_range:
|
||||
# Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048")
|
||||
range_parts = content_range.split('/')
|
||||
range_parts = content_range.split("/")
|
||||
if len(range_parts) == 2:
|
||||
total_size = int(range_parts[1])
|
||||
logger.info(f"Successfully resumed download from byte {resume_offset}")
|
||||
logger.info(
|
||||
f"Successfully resumed download from byte {resume_offset}"
|
||||
)
|
||||
elif response.status == 416:
|
||||
# Range not satisfiable - file might be complete or corrupted
|
||||
if allow_resume and os.path.exists(part_path):
|
||||
part_size = os.path.getsize(part_path)
|
||||
logger.warning(f"Range not satisfiable. Part file size: {part_size}")
|
||||
logger.warning(
|
||||
f"Range not satisfiable. Part file size: {part_size}"
|
||||
)
|
||||
# Try to get actual file size
|
||||
head_response = await session.head(url, headers=headers, proxy=self.proxy_url)
|
||||
head_response = await session.head(
|
||||
url, headers=headers, proxy=self.proxy_url
|
||||
)
|
||||
if head_response.status == 200:
|
||||
actual_size = int(head_response.headers.get('content-length', 0))
|
||||
actual_size = int(
|
||||
head_response.headers.get("content-length", 0)
|
||||
)
|
||||
if part_size == actual_size:
|
||||
# File is complete, just rename it
|
||||
if allow_resume:
|
||||
@@ -388,25 +426,40 @@ class Downloader:
|
||||
resume_offset = 0
|
||||
continue
|
||||
elif response.status == 401:
|
||||
logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
|
||||
return False, "Invalid or missing API key, or early access restriction."
|
||||
logger.warning(
|
||||
f"Unauthorized access to resource: {url} (Status 401)"
|
||||
)
|
||||
return (
|
||||
False,
|
||||
"Invalid or missing API key, or early access restriction.",
|
||||
)
|
||||
elif response.status == 403:
|
||||
logger.warning(f"Forbidden access to resource: {url} (Status 403)")
|
||||
return False, "Access forbidden: You don't have permission to download this file."
|
||||
logger.warning(
|
||||
f"Forbidden access to resource: {url} (Status 403)"
|
||||
)
|
||||
return (
|
||||
False,
|
||||
"Access forbidden: You don't have permission to download this file.",
|
||||
)
|
||||
elif response.status == 404:
|
||||
logger.warning(f"Resource not found: {url} (Status 404)")
|
||||
return False, "File not found - the download link may be invalid or expired."
|
||||
return (
|
||||
False,
|
||||
"File not found - the download link may be invalid or expired.",
|
||||
)
|
||||
else:
|
||||
logger.error(f"Download failed for {url} with status {response.status}")
|
||||
logger.error(
|
||||
f"Download failed for {url} with status {response.status}"
|
||||
)
|
||||
return False, f"Download failed with status {response.status}"
|
||||
|
||||
|
||||
# Get total file size for progress calculation (if not set from Content-Range)
|
||||
if total_size == 0:
|
||||
total_size = int(response.headers.get('content-length', 0))
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
if response.status == 206:
|
||||
# For partial content, add the offset to get total file size
|
||||
total_size += resume_offset
|
||||
|
||||
|
||||
current_size = resume_offset
|
||||
last_progress_report_time = datetime.now()
|
||||
progress_samples: deque[tuple[datetime, int]] = deque()
|
||||
@@ -417,7 +470,7 @@ class Downloader:
|
||||
|
||||
# Stream download to file with progress updates
|
||||
loop = asyncio.get_running_loop()
|
||||
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
|
||||
mode = "ab" if (allow_resume and resume_offset > 0) else "wb"
|
||||
control = pause_event
|
||||
|
||||
if control is not None:
|
||||
@@ -425,7 +478,9 @@ class Downloader:
|
||||
|
||||
with open(part_path, mode) as f:
|
||||
while True:
|
||||
active_stall_timeout = control.stall_timeout if control else self.stall_timeout
|
||||
active_stall_timeout = (
|
||||
control.stall_timeout if control else self.stall_timeout
|
||||
)
|
||||
|
||||
if control is not None:
|
||||
if control.is_paused():
|
||||
@@ -437,7 +492,9 @@ class Downloader:
|
||||
"Reconnect requested after resume"
|
||||
)
|
||||
elif control.consume_reconnect_request():
|
||||
raise DownloadRestartRequested("Reconnect requested")
|
||||
raise DownloadRestartRequested(
|
||||
"Reconnect requested"
|
||||
)
|
||||
|
||||
try:
|
||||
chunk = await asyncio.wait_for(
|
||||
@@ -466,22 +523,32 @@ class Downloader:
|
||||
control.mark_progress(timestamp=now.timestamp())
|
||||
|
||||
# Limit progress update frequency to reduce overhead
|
||||
time_diff = (now - last_progress_report_time).total_seconds()
|
||||
time_diff = (
|
||||
now - last_progress_report_time
|
||||
).total_seconds()
|
||||
|
||||
if progress_callback and time_diff >= 1.0:
|
||||
progress_samples.append((now, current_size))
|
||||
cutoff = now - timedelta(seconds=5)
|
||||
while progress_samples and progress_samples[0][0] < cutoff:
|
||||
while (
|
||||
progress_samples and progress_samples[0][0] < cutoff
|
||||
):
|
||||
progress_samples.popleft()
|
||||
|
||||
percent = (current_size / total_size) * 100 if total_size else 0.0
|
||||
percent = (
|
||||
(current_size / total_size) * 100
|
||||
if total_size
|
||||
else 0.0
|
||||
)
|
||||
bytes_per_second = 0.0
|
||||
if len(progress_samples) >= 2:
|
||||
first_time, first_bytes = progress_samples[0]
|
||||
last_time, last_bytes = progress_samples[-1]
|
||||
elapsed = (last_time - first_time).total_seconds()
|
||||
if elapsed > 0:
|
||||
bytes_per_second = (last_bytes - first_bytes) / elapsed
|
||||
bytes_per_second = (
|
||||
last_bytes - first_bytes
|
||||
) / elapsed
|
||||
|
||||
progress_snapshot = DownloadProgress(
|
||||
percent_complete=percent,
|
||||
@@ -491,21 +558,23 @@ class Downloader:
|
||||
timestamp=now.timestamp(),
|
||||
)
|
||||
|
||||
await self._dispatch_progress_callback(progress_callback, progress_snapshot)
|
||||
await self._dispatch_progress_callback(
|
||||
progress_callback, progress_snapshot
|
||||
)
|
||||
last_progress_report_time = now
|
||||
|
||||
|
||||
# Download completed successfully
|
||||
# Verify file size integrity before finalizing
|
||||
final_size = os.path.getsize(part_path) if os.path.exists(part_path) else 0
|
||||
final_size = (
|
||||
os.path.getsize(part_path) if os.path.exists(part_path) else 0
|
||||
)
|
||||
expected_size = total_size if total_size > 0 else None
|
||||
|
||||
integrity_error: Optional[str] = None
|
||||
if final_size <= 0:
|
||||
integrity_error = "Downloaded file is empty"
|
||||
elif expected_size is not None and final_size != expected_size:
|
||||
integrity_error = (
|
||||
f"File size mismatch. Expected: {expected_size}, Got: {final_size}"
|
||||
)
|
||||
integrity_error = f"File size mismatch. Expected: {expected_size}, Got: {final_size}"
|
||||
|
||||
if integrity_error is not None:
|
||||
logger.error(
|
||||
@@ -554,8 +623,10 @@ class Downloader:
|
||||
max_rename_attempts = 5
|
||||
rename_attempt = 0
|
||||
rename_success = False
|
||||
|
||||
while rename_attempt < max_rename_attempts and not rename_success:
|
||||
|
||||
while (
|
||||
rename_attempt < max_rename_attempts and not rename_success
|
||||
):
|
||||
try:
|
||||
# If the destination file exists, remove it first (Windows safe)
|
||||
if os.path.exists(save_path):
|
||||
@@ -566,11 +637,18 @@ class Downloader:
|
||||
except PermissionError as e:
|
||||
rename_attempt += 1
|
||||
if rename_attempt < max_rename_attempts:
|
||||
logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})")
|
||||
logger.info(
|
||||
f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})"
|
||||
)
|
||||
await asyncio.sleep(2)
|
||||
else:
|
||||
logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}")
|
||||
return False, f"Failed to finalize download: {str(e)}"
|
||||
logger.error(
|
||||
f"Failed to rename file after {max_rename_attempts} attempts: {e}"
|
||||
)
|
||||
return (
|
||||
False,
|
||||
f"Failed to finalize download: {str(e)}",
|
||||
)
|
||||
|
||||
final_size = os.path.getsize(save_path)
|
||||
|
||||
@@ -583,11 +661,12 @@ class Downloader:
|
||||
bytes_per_second=0.0,
|
||||
timestamp=datetime.now().timestamp(),
|
||||
)
|
||||
await self._dispatch_progress_callback(progress_callback, final_snapshot)
|
||||
await self._dispatch_progress_callback(
|
||||
progress_callback, final_snapshot
|
||||
)
|
||||
|
||||
|
||||
return True, save_path
|
||||
|
||||
|
||||
except (
|
||||
aiohttp.ClientError,
|
||||
aiohttp.ClientPayloadError,
|
||||
@@ -597,30 +676,35 @@ class Downloader:
|
||||
DownloadRestartRequested,
|
||||
) as e:
|
||||
retry_count += 1
|
||||
logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}")
|
||||
logger.warning(
|
||||
f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}"
|
||||
)
|
||||
|
||||
if retry_count <= self.max_retries:
|
||||
# Calculate delay with exponential backoff
|
||||
delay = self.base_delay * (2 ** (retry_count - 1))
|
||||
logger.info(f"Retrying in {delay} seconds...")
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
|
||||
# Update resume offset for next attempt
|
||||
if allow_resume and os.path.exists(part_path):
|
||||
resume_offset = os.path.getsize(part_path)
|
||||
logger.info(f"Will resume from byte {resume_offset}")
|
||||
|
||||
|
||||
# Refresh session to get new connection
|
||||
await self._create_session()
|
||||
continue
|
||||
else:
|
||||
logger.error(f"Max retries exceeded for download: {e}")
|
||||
return False, f"Network error after {self.max_retries + 1} attempts: {str(e)}"
|
||||
|
||||
return (
|
||||
False,
|
||||
f"Network error after {self.max_retries + 1} attempts: {str(e)}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected download error: {e}")
|
||||
return False, str(e)
|
||||
|
||||
|
||||
return False, f"Download failed after {self.max_retries + 1} attempts"
|
||||
|
||||
async def _dispatch_progress_callback(
|
||||
@@ -645,17 +729,17 @@ class Downloader:
|
||||
url: str,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
return_headers: bool = False
|
||||
return_headers: bool = False,
|
||||
) -> Tuple[bool, Union[bytes, str], Optional[Dict]]:
|
||||
"""
|
||||
Download a file to memory (for small files like preview images)
|
||||
|
||||
|
||||
Args:
|
||||
url: Download URL
|
||||
use_auth: Whether to include authentication headers
|
||||
custom_headers: Additional headers to include in request
|
||||
return_headers: Whether to return response headers along with content
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[bytes, str], Optional[Dict]]: (success, content or error message, response headers if requested)
|
||||
"""
|
||||
@@ -663,16 +747,22 @@ class Downloader:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[download_to_memory] Using app-level proxy: {self.proxy_url}")
|
||||
logger.debug(
|
||||
f"[download_to_memory] Using app-level proxy: {self.proxy_url}"
|
||||
)
|
||||
else:
|
||||
logger.debug("[download_to_memory] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
logger.debug(
|
||||
"[download_to_memory] Using system-level proxy (trust_env) if configured."
|
||||
)
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
async with session.get(url, headers=headers, proxy=self.proxy_url) as response:
|
||||
|
||||
async with session.get(
|
||||
url, headers=headers, proxy=self.proxy_url
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
content = await response.read()
|
||||
if return_headers:
|
||||
@@ -691,25 +781,25 @@ class Downloader:
|
||||
else:
|
||||
error_msg = f"Download failed with status {response.status}"
|
||||
return False, error_msg, None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading to memory from {url}: {e}")
|
||||
return False, str(e), None
|
||||
|
||||
|
||||
async def get_response_headers(
|
||||
self,
|
||||
url: str,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
) -> Tuple[bool, Union[Dict, str]]:
|
||||
"""
|
||||
Get response headers without downloading the full content
|
||||
|
||||
|
||||
Args:
|
||||
url: URL to check
|
||||
use_auth: Whether to include authentication headers
|
||||
custom_headers: Additional headers to include in request
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict, str]]: (success, headers dict or error message)
|
||||
"""
|
||||
@@ -717,43 +807,49 @@ class Downloader:
|
||||
session = await self.session
|
||||
# Debug log for proxy mode at request time
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[get_response_headers] Using app-level proxy: {self.proxy_url}")
|
||||
logger.debug(
|
||||
f"[get_response_headers] Using app-level proxy: {self.proxy_url}"
|
||||
)
|
||||
else:
|
||||
logger.debug("[get_response_headers] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
logger.debug(
|
||||
"[get_response_headers] Using system-level proxy (trust_env) if configured."
|
||||
)
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
async with session.head(url, headers=headers, proxy=self.proxy_url) as response:
|
||||
|
||||
async with session.head(
|
||||
url, headers=headers, proxy=self.proxy_url
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return True, dict(response.headers)
|
||||
else:
|
||||
return False, f"Head request failed with status {response.status}"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting headers from {url}: {e}")
|
||||
return False, str(e)
|
||||
|
||||
|
||||
async def make_request(
|
||||
self,
|
||||
method: str,
|
||||
url: str,
|
||||
use_auth: bool = False,
|
||||
custom_headers: Optional[Dict[str, str]] = None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> Tuple[bool, Union[Dict, str]]:
|
||||
"""
|
||||
Make a generic HTTP request and return JSON response
|
||||
|
||||
|
||||
Args:
|
||||
method: HTTP method (GET, POST, etc.)
|
||||
url: Request URL
|
||||
use_auth: Whether to include authentication headers
|
||||
custom_headers: Additional headers to include in request
|
||||
**kwargs: Additional arguments for aiohttp request
|
||||
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Union[Dict, str]]: (success, response data or error message)
|
||||
"""
|
||||
@@ -763,18 +859,22 @@ class Downloader:
|
||||
if self.proxy_url:
|
||||
logger.debug(f"[make_request] Using app-level proxy: {self.proxy_url}")
|
||||
else:
|
||||
logger.debug("[make_request] Using system-level proxy (trust_env) if configured.")
|
||||
|
||||
logger.debug(
|
||||
"[make_request] Using system-level proxy (trust_env) if configured."
|
||||
)
|
||||
|
||||
# Prepare headers
|
||||
headers = self._get_auth_headers(use_auth)
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
|
||||
# Add proxy to kwargs if not already present
|
||||
if 'proxy' not in kwargs:
|
||||
kwargs['proxy'] = self.proxy_url
|
||||
|
||||
async with session.request(method, url, headers=headers, **kwargs) as response:
|
||||
if "proxy" not in kwargs:
|
||||
kwargs["proxy"] = self.proxy_url
|
||||
|
||||
async with session.request(
|
||||
method, url, headers=headers, **kwargs
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
# Try to parse as JSON, fall back to text
|
||||
try:
|
||||
@@ -804,11 +904,11 @@ class Downloader:
|
||||
)
|
||||
else:
|
||||
return False, f"Request failed with status {response.status}"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error making {method} request to {url}: {e}")
|
||||
return False, str(e)
|
||||
|
||||
|
||||
async def close(self):
|
||||
"""Close the HTTP session"""
|
||||
if self._session is not None:
|
||||
@@ -817,7 +917,7 @@ class Downloader:
|
||||
self._session_created_at = None
|
||||
self._proxy_url = None
|
||||
logger.debug("Closed HTTP session")
|
||||
|
||||
|
||||
async def refresh_session(self):
|
||||
"""Force refresh the HTTP session (useful when proxy settings change)"""
|
||||
async with self._session_lock:
|
||||
|
||||
@@ -687,7 +687,7 @@
|
||||
padding: 12px 16px;
|
||||
background: oklch(var(--lora-warning) / 0.1);
|
||||
border: 1px solid var(--lora-warning);
|
||||
border-radius: var(--border-radius-sm) var(--border-radius-sm) 0 0;
|
||||
border-radius: var(--border-radius-sm);
|
||||
color: var(--text-color);
|
||||
}
|
||||
|
||||
|
||||
@@ -484,9 +484,11 @@ async def test_get_model_version_info_success(monkeypatch, downloader):
|
||||
assert result["images"][0]["meta"]["other"] == "keep"
|
||||
|
||||
|
||||
async def test_get_image_info_returns_first_item(monkeypatch, downloader):
|
||||
async def test_get_image_info_returns_matching_item(monkeypatch, downloader):
|
||||
"""When API returns multiple items, return the one matching the requested ID."""
|
||||
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||
return True, {"items": [{"id": 1}, {"id": 2}]}
|
||||
# Requested ID is 42, but it's the second item in the response
|
||||
return True, {"items": [{"id": 41}, {"id": 42, "name": "target"}, {"id": 43}]}
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
@@ -494,7 +496,25 @@ async def test_get_image_info_returns_first_item(monkeypatch, downloader):
|
||||
|
||||
result = await client.get_image_info("42")
|
||||
|
||||
assert result == {"id": 1}
|
||||
assert result == {"id": 42, "name": "target"}
|
||||
|
||||
|
||||
async def test_get_image_info_returns_none_when_id_mismatch(monkeypatch, downloader, caplog):
|
||||
"""When API returns items but none match the requested ID, return None and log warning."""
|
||||
async def fake_make_request(method, url, use_auth=True, **kwargs):
|
||||
# Requested ID is 999, but API returns different IDs (simulating deleted/hidden image)
|
||||
return True, {"items": [{"id": 1}, {"id": 2}, {"id": 3}]}
|
||||
|
||||
downloader.make_request = fake_make_request
|
||||
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_image_info("999")
|
||||
|
||||
assert result is None
|
||||
# Verify warning was logged
|
||||
assert "CivitAI API returned no matching image for requested ID 999" in caplog.text
|
||||
assert "Returned 3 item(s) with IDs: [1, 2, 3]" in caplog.text
|
||||
|
||||
|
||||
async def test_get_image_info_handles_missing(monkeypatch, downloader):
|
||||
@@ -508,3 +528,13 @@ async def test_get_image_info_handles_missing(monkeypatch, downloader):
|
||||
result = await client.get_image_info("42")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
async def test_get_image_info_handles_invalid_id(monkeypatch, downloader, caplog):
|
||||
"""When given a non-numeric image ID, return None and log error."""
|
||||
client = await CivitaiClient.get_instance()
|
||||
|
||||
result = await client.get_image_info("not-a-number")
|
||||
|
||||
assert result is None
|
||||
assert "Invalid image ID format" in caplog.text
|
||||
|
||||
Reference in New Issue
Block a user