mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
Refactor download logic to use unified downloader service
- Introduced a new `Downloader` class to centralize HTTP/HTTPS download management. - Replaced direct `aiohttp` session handling with the unified downloader in `MetadataArchiveManager`, `DownloadManager`, and `ExampleImagesProcessor`. - Added support for resumable downloads, progress tracking, and error handling in the new downloader. - Updated methods to utilize the downloader's capabilities for downloading files and images, improving code maintainability and readability.
This commit is contained in:
@@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import aiohttp
|
|
||||||
import logging
|
import logging
|
||||||
import toml
|
import toml
|
||||||
import git
|
import git
|
||||||
@@ -8,7 +7,7 @@ import shutil
|
|||||||
import tempfile
|
import tempfile
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
from ..services.downloader import get_downloader, Downloader
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -162,28 +161,42 @@ class UpdateRoutes:
|
|||||||
github_api = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
github_api = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
downloader = await get_downloader()
|
||||||
async with session.get(github_api) as resp:
|
|
||||||
if resp.status != 200:
|
|
||||||
logger.error(f"Failed to fetch release info: {resp.status}")
|
|
||||||
return False, ""
|
|
||||||
data = await resp.json()
|
|
||||||
zip_url = data.get("zipball_url")
|
|
||||||
version = data.get("tag_name", "unknown")
|
|
||||||
|
|
||||||
# Download ZIP
|
# Get release info
|
||||||
async with session.get(zip_url) as zip_resp:
|
success, data = await downloader.make_request(
|
||||||
if zip_resp.status != 200:
|
'GET',
|
||||||
logger.error(f"Failed to download ZIP: {zip_resp.status}")
|
github_api,
|
||||||
return False, ""
|
use_auth=False
|
||||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
|
)
|
||||||
tmp_zip.write(await zip_resp.read())
|
if not success:
|
||||||
zip_path = tmp_zip.name
|
logger.error(f"Failed to fetch release info: {data}")
|
||||||
|
return False, ""
|
||||||
|
|
||||||
UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json'])
|
zip_url = data.get("zipball_url")
|
||||||
|
version = data.get("tag_name", "unknown")
|
||||||
|
|
||||||
# Extract ZIP to temp dir
|
# Download ZIP to temporary file
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
|
||||||
|
tmp_zip_path = tmp_zip.name
|
||||||
|
|
||||||
|
success, result = await downloader.download_file(
|
||||||
|
url=zip_url,
|
||||||
|
save_path=tmp_zip_path,
|
||||||
|
use_auth=False,
|
||||||
|
allow_resume=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
logger.error(f"Failed to download ZIP: {result}")
|
||||||
|
return False, ""
|
||||||
|
|
||||||
|
zip_path = tmp_zip_path
|
||||||
|
|
||||||
|
UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json'])
|
||||||
|
|
||||||
|
# Extract ZIP to temp dir
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
||||||
zip_ref.extractall(tmp_dir)
|
zip_ref.extractall(tmp_dir)
|
||||||
# Find extracted folder (GitHub ZIP contains a root folder)
|
# Find extracted folder (GitHub ZIP contains a root folder)
|
||||||
@@ -213,9 +226,9 @@ class UpdateRoutes:
|
|||||||
with open(tracking_info_file, "w", encoding='utf-8') as file:
|
with open(tracking_info_file, "w", encoding='utf-8') as file:
|
||||||
file.write('\n'.join(tracking_files))
|
file.write('\n'.join(tracking_files))
|
||||||
|
|
||||||
os.remove(zip_path)
|
os.remove(zip_path)
|
||||||
logger.info(f"Updated plugin via ZIP to {version}")
|
logger.info(f"Updated plugin via ZIP to {version}")
|
||||||
return True, version
|
return True, version
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"ZIP update failed: {e}", exc_info=True)
|
logger.error(f"ZIP update failed: {e}", exc_info=True)
|
||||||
@@ -244,23 +257,23 @@ class UpdateRoutes:
|
|||||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
|
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
downloader = await Downloader.get_instance()
|
||||||
async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response:
|
success, data = await downloader.make_request('GET', github_url, headers={'Accept': 'application/vnd.github+json'})
|
||||||
if response.status != 200:
|
|
||||||
logger.warning(f"Failed to fetch GitHub commit: {response.status}")
|
|
||||||
return "main", []
|
|
||||||
|
|
||||||
data = await response.json()
|
if not success:
|
||||||
commit_sha = data.get('sha', '')[:7] # Short hash
|
logger.warning(f"Failed to fetch GitHub commit: {data}")
|
||||||
commit_message = data.get('commit', {}).get('message', '')
|
return "main", []
|
||||||
|
|
||||||
# Format as "main-{short_hash}"
|
commit_sha = data.get('sha', '')[:7] # Short hash
|
||||||
version = f"main-{commit_sha}"
|
commit_message = data.get('commit', {}).get('message', '')
|
||||||
|
|
||||||
# Use commit message as changelog
|
# Format as "main-{short_hash}"
|
||||||
changelog = [commit_message] if commit_message else []
|
version = f"main-{commit_sha}"
|
||||||
|
|
||||||
return version, changelog
|
# Use commit message as changelog
|
||||||
|
changelog = [commit_message] if commit_message else []
|
||||||
|
|
||||||
|
return version, changelog
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching nightly version: {e}", exc_info=True)
|
logger.error(f"Error fetching nightly version: {e}", exc_info=True)
|
||||||
@@ -410,22 +423,22 @@ class UpdateRoutes:
|
|||||||
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
downloader = await Downloader.get_instance()
|
||||||
async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response:
|
success, data = await downloader.make_request('GET', github_url, headers={'Accept': 'application/vnd.github+json'})
|
||||||
if response.status != 200:
|
|
||||||
logger.warning(f"Failed to fetch GitHub release: {response.status}")
|
|
||||||
return "v0.0.0", []
|
|
||||||
|
|
||||||
data = await response.json()
|
if not success:
|
||||||
version = data.get('tag_name', '')
|
logger.warning(f"Failed to fetch GitHub release: {data}")
|
||||||
if not version.startswith('v'):
|
return "v0.0.0", []
|
||||||
version = f"v{version}"
|
|
||||||
|
|
||||||
# Extract changelog from release notes
|
version = data.get('tag_name', '')
|
||||||
body = data.get('body', '')
|
if not version.startswith('v'):
|
||||||
changelog = UpdateRoutes._parse_changelog(body)
|
version = f"v{version}"
|
||||||
|
|
||||||
return version, changelog
|
# Extract changelog from release notes
|
||||||
|
body = data.get('body', '')
|
||||||
|
changelog = UpdateRoutes._parse_changelog(body)
|
||||||
|
|
||||||
|
return version, changelog
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching remote version: {e}", exc_info=True)
|
logger.error(f"Error fetching remote version: {e}", exc_info=True)
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import aiohttp
|
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Optional, Dict, Tuple, List
|
from typing import Optional, Dict, Tuple, List
|
||||||
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
|
||||||
|
from .downloader import get_downloader
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -32,60 +32,6 @@ class CivitaiClient:
|
|||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
self.base_url = "https://civitai.com/api/v1"
|
self.base_url = "https://civitai.com/api/v1"
|
||||||
self.headers = {
|
|
||||||
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
|
|
||||||
}
|
|
||||||
self._session = None
|
|
||||||
self._session_created_at = None
|
|
||||||
# Adjust chunk size based on storage type - consider making this configurable
|
|
||||||
self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better HDD throughput
|
|
||||||
|
|
||||||
@property
|
|
||||||
async def session(self) -> aiohttp.ClientSession:
|
|
||||||
"""Lazy initialize the session"""
|
|
||||||
if self._session is None:
|
|
||||||
# Optimize TCP connection parameters
|
|
||||||
connector = aiohttp.TCPConnector(
|
|
||||||
ssl=True,
|
|
||||||
limit=8, # Increase from 3 to 8 for better parallelism
|
|
||||||
ttl_dns_cache=300, # Enable DNS caching with reasonable timeout
|
|
||||||
force_close=False, # Keep connections for reuse
|
|
||||||
enable_cleanup_closed=True
|
|
||||||
)
|
|
||||||
trust_env = True # Allow using system environment proxy settings
|
|
||||||
# Configure timeout parameters - increase read timeout for large files and remove sock_read timeout
|
|
||||||
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=None)
|
|
||||||
self._session = aiohttp.ClientSession(
|
|
||||||
connector=connector,
|
|
||||||
trust_env=trust_env,
|
|
||||||
timeout=timeout
|
|
||||||
)
|
|
||||||
self._session_created_at = datetime.now()
|
|
||||||
return self._session
|
|
||||||
|
|
||||||
async def _ensure_fresh_session(self):
|
|
||||||
"""Refresh session if it's been open too long"""
|
|
||||||
if self._session is not None:
|
|
||||||
if not hasattr(self, '_session_created_at') or \
|
|
||||||
(datetime.now() - self._session_created_at).total_seconds() > 300: # 5 minutes
|
|
||||||
await self.close()
|
|
||||||
self._session = None
|
|
||||||
|
|
||||||
return await self.session
|
|
||||||
|
|
||||||
def _get_request_headers(self) -> dict:
|
|
||||||
"""Get request headers with optional API key"""
|
|
||||||
headers = {
|
|
||||||
'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
|
|
||||||
'Content-Type': 'application/json'
|
|
||||||
}
|
|
||||||
|
|
||||||
from .settings_manager import settings
|
|
||||||
api_key = settings.get('civitai_api_key')
|
|
||||||
if (api_key):
|
|
||||||
headers['Authorization'] = f'Bearer {api_key}'
|
|
||||||
|
|
||||||
return headers
|
|
||||||
|
|
||||||
async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
|
async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
|
||||||
"""Download file with resumable downloads and retry mechanism
|
"""Download file with resumable downloads and retry mechanism
|
||||||
@@ -99,214 +45,69 @@ class CivitaiClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, str]: (success, save_path or error message)
|
Tuple[bool, str]: (success, save_path or error message)
|
||||||
"""
|
"""
|
||||||
max_retries = 5
|
downloader = await get_downloader()
|
||||||
retry_count = 0
|
|
||||||
base_delay = 2.0 # Base delay for exponential backoff
|
|
||||||
|
|
||||||
# Initial setup
|
|
||||||
session = await self._ensure_fresh_session()
|
|
||||||
save_path = os.path.join(save_dir, default_filename)
|
save_path = os.path.join(save_dir, default_filename)
|
||||||
part_path = save_path + '.part'
|
|
||||||
|
|
||||||
# Get existing file size for resume
|
# Use unified downloader with CivitAI authentication
|
||||||
resume_offset = 0
|
success, result = await downloader.download_file(
|
||||||
if os.path.exists(part_path):
|
url=url,
|
||||||
resume_offset = os.path.getsize(part_path)
|
save_path=save_path,
|
||||||
logger.info(f"Resuming download from offset {resume_offset} bytes")
|
progress_callback=progress_callback,
|
||||||
|
use_auth=True, # Enable CivitAI authentication
|
||||||
|
allow_resume=True
|
||||||
|
)
|
||||||
|
|
||||||
total_size = 0
|
return success, result
|
||||||
|
|
||||||
while retry_count <= max_retries:
|
|
||||||
try:
|
|
||||||
headers = self._get_request_headers()
|
|
||||||
|
|
||||||
# Add Range header for resume if we have partial data
|
|
||||||
if resume_offset > 0:
|
|
||||||
headers['Range'] = f'bytes={resume_offset}-'
|
|
||||||
|
|
||||||
# Add Range header to allow resumable downloads
|
|
||||||
headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads
|
|
||||||
|
|
||||||
logger.debug(f"Download attempt {retry_count + 1}/{max_retries + 1} from: {url}")
|
|
||||||
if resume_offset > 0:
|
|
||||||
logger.debug(f"Requesting range from byte {resume_offset}")
|
|
||||||
|
|
||||||
async with session.get(url, headers=headers, allow_redirects=True) as response:
|
|
||||||
# Handle different response codes
|
|
||||||
if response.status == 200:
|
|
||||||
# Full content response
|
|
||||||
if resume_offset > 0:
|
|
||||||
# Server doesn't support ranges, restart from beginning
|
|
||||||
logger.warning("Server doesn't support range requests, restarting download")
|
|
||||||
resume_offset = 0
|
|
||||||
if os.path.exists(part_path):
|
|
||||||
os.remove(part_path)
|
|
||||||
elif response.status == 206:
|
|
||||||
# Partial content response (resume successful)
|
|
||||||
content_range = response.headers.get('Content-Range')
|
|
||||||
if content_range:
|
|
||||||
# Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048")
|
|
||||||
range_parts = content_range.split('/')
|
|
||||||
if len(range_parts) == 2:
|
|
||||||
total_size = int(range_parts[1])
|
|
||||||
logger.info(f"Successfully resumed download from byte {resume_offset}")
|
|
||||||
elif response.status == 416:
|
|
||||||
# Range not satisfiable - file might be complete or corrupted
|
|
||||||
if os.path.exists(part_path):
|
|
||||||
part_size = os.path.getsize(part_path)
|
|
||||||
logger.warning(f"Range not satisfiable. Part file size: {part_size}")
|
|
||||||
# Try to get actual file size
|
|
||||||
head_response = await session.head(url, headers=self._get_request_headers())
|
|
||||||
if head_response.status == 200:
|
|
||||||
actual_size = int(head_response.headers.get('content-length', 0))
|
|
||||||
if part_size == actual_size:
|
|
||||||
# File is complete, just rename it
|
|
||||||
os.rename(part_path, save_path)
|
|
||||||
if progress_callback:
|
|
||||||
await progress_callback(100)
|
|
||||||
return True, save_path
|
|
||||||
# Remove corrupted part file and restart
|
|
||||||
os.remove(part_path)
|
|
||||||
resume_offset = 0
|
|
||||||
continue
|
|
||||||
elif response.status == 401:
|
|
||||||
logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
|
|
||||||
return False, "Invalid or missing CivitAI API key, or early access restriction."
|
|
||||||
elif response.status == 403:
|
|
||||||
logger.warning(f"Forbidden access to resource: {url} (Status 403)")
|
|
||||||
return False, "Access forbidden: You don't have permission to download this file."
|
|
||||||
else:
|
|
||||||
logger.error(f"Download failed for {url} with status {response.status}")
|
|
||||||
return False, f"Download failed with status {response.status}"
|
|
||||||
|
|
||||||
# Get total file size for progress calculation (if not set from Content-Range)
|
|
||||||
if total_size == 0:
|
|
||||||
total_size = int(response.headers.get('content-length', 0))
|
|
||||||
if response.status == 206:
|
|
||||||
# For partial content, add the offset to get total file size
|
|
||||||
total_size += resume_offset
|
|
||||||
|
|
||||||
current_size = resume_offset
|
|
||||||
last_progress_report_time = datetime.now()
|
|
||||||
|
|
||||||
# Stream download to file with progress updates using larger buffer
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
mode = 'ab' if resume_offset > 0 else 'wb'
|
|
||||||
with open(part_path, mode) as f:
|
|
||||||
async for chunk in response.content.iter_chunked(self.chunk_size):
|
|
||||||
if chunk:
|
|
||||||
# Run blocking file write in executor
|
|
||||||
await loop.run_in_executor(None, f.write, chunk)
|
|
||||||
current_size += len(chunk)
|
|
||||||
|
|
||||||
# Limit progress update frequency to reduce overhead
|
|
||||||
now = datetime.now()
|
|
||||||
time_diff = (now - last_progress_report_time).total_seconds()
|
|
||||||
|
|
||||||
if progress_callback and total_size and time_diff >= 1.0:
|
|
||||||
progress = (current_size / total_size) * 100
|
|
||||||
await progress_callback(progress)
|
|
||||||
last_progress_report_time = now
|
|
||||||
|
|
||||||
# Download completed successfully
|
|
||||||
# Verify file size if total_size was provided
|
|
||||||
final_size = os.path.getsize(part_path)
|
|
||||||
if total_size > 0 and final_size != total_size:
|
|
||||||
logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}")
|
|
||||||
# Don't treat this as fatal error, rename anyway
|
|
||||||
|
|
||||||
# Atomically rename .part to final file with retries
|
|
||||||
max_rename_attempts = 5
|
|
||||||
rename_attempt = 0
|
|
||||||
rename_success = False
|
|
||||||
|
|
||||||
while rename_attempt < max_rename_attempts and not rename_success:
|
|
||||||
try:
|
|
||||||
os.rename(part_path, save_path)
|
|
||||||
rename_success = True
|
|
||||||
except PermissionError as e:
|
|
||||||
rename_attempt += 1
|
|
||||||
if rename_attempt < max_rename_attempts:
|
|
||||||
logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})")
|
|
||||||
await asyncio.sleep(2) # Wait before retrying
|
|
||||||
else:
|
|
||||||
logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}")
|
|
||||||
return False, f"Failed to finalize download: {str(e)}"
|
|
||||||
|
|
||||||
# Ensure 100% progress is reported
|
|
||||||
if progress_callback:
|
|
||||||
await progress_callback(100)
|
|
||||||
|
|
||||||
return True, save_path
|
|
||||||
|
|
||||||
except (aiohttp.ClientError, aiohttp.ClientPayloadError,
|
|
||||||
aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
|
|
||||||
retry_count += 1
|
|
||||||
logger.warning(f"Network error during download (attempt {retry_count}/{max_retries + 1}): {e}")
|
|
||||||
|
|
||||||
if retry_count <= max_retries:
|
|
||||||
# Calculate delay with exponential backoff
|
|
||||||
delay = base_delay * (2 ** (retry_count - 1))
|
|
||||||
logger.info(f"Retrying in {delay} seconds...")
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
# Update resume offset for next attempt
|
|
||||||
if os.path.exists(part_path):
|
|
||||||
resume_offset = os.path.getsize(part_path)
|
|
||||||
logger.info(f"Will resume from byte {resume_offset}")
|
|
||||||
|
|
||||||
# Refresh session to get new connection
|
|
||||||
await self.close()
|
|
||||||
session = await self._ensure_fresh_session()
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
logger.error(f"Max retries exceeded for download: {e}")
|
|
||||||
return False, f"Network error after {max_retries + 1} attempts: {str(e)}"
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Unexpected download error: {e}")
|
|
||||||
return False, str(e)
|
|
||||||
|
|
||||||
return False, f"Download failed after {max_retries + 1} attempts"
|
|
||||||
|
|
||||||
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
|
||||||
try:
|
try:
|
||||||
session = await self._ensure_fresh_session()
|
downloader = await get_downloader()
|
||||||
async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response:
|
success, result = await downloader.make_request(
|
||||||
if response.status == 200:
|
'GET',
|
||||||
return await response.json()
|
f"{self.base_url}/model-versions/by-hash/{model_hash}",
|
||||||
return None
|
use_auth=True
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
return result
|
||||||
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"API Error: {str(e)}")
|
logger.error(f"API Error: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def download_preview_image(self, image_url: str, save_path: str):
|
async def download_preview_image(self, image_url: str, save_path: str):
|
||||||
try:
|
try:
|
||||||
session = await self._ensure_fresh_session()
|
downloader = await get_downloader()
|
||||||
async with session.get(image_url) as response:
|
success, content = await downloader.download_to_memory(
|
||||||
if response.status == 200:
|
image_url,
|
||||||
content = await response.read()
|
use_auth=False # Preview images don't need auth
|
||||||
with open(save_path, 'wb') as f:
|
)
|
||||||
f.write(content)
|
if success:
|
||||||
return True
|
# Ensure directory exists
|
||||||
return False
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||||
|
with open(save_path, 'wb') as f:
|
||||||
|
f.write(content)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Download Error: {str(e)}")
|
logger.error(f"Download Error: {str(e)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def get_model_versions(self, model_id: str) -> List[Dict]:
|
async def get_model_versions(self, model_id: str) -> List[Dict]:
|
||||||
"""Get all versions of a model with local availability info"""
|
"""Get all versions of a model with local availability info"""
|
||||||
try:
|
try:
|
||||||
session = await self._ensure_fresh_session() # Use fresh session
|
downloader = await get_downloader()
|
||||||
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
success, result = await downloader.make_request(
|
||||||
if response.status != 200:
|
'GET',
|
||||||
return None
|
f"{self.base_url}/models/{model_id}",
|
||||||
data = await response.json()
|
use_auth=True
|
||||||
|
)
|
||||||
|
if success:
|
||||||
# Also return model type along with versions
|
# Also return model type along with versions
|
||||||
return {
|
return {
|
||||||
'modelVersions': data.get('modelVersions', []),
|
'modelVersions': result.get('modelVersions', []),
|
||||||
'type': data.get('type', '')
|
'type': result.get('type', '')
|
||||||
}
|
}
|
||||||
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching model versions: {e}")
|
logger.error(f"Error fetching model versions: {e}")
|
||||||
return None
|
return None
|
||||||
@@ -322,68 +123,74 @@ class CivitaiClient:
|
|||||||
Optional[Dict]: The model version data with additional fields or None if not found
|
Optional[Dict]: The model version data with additional fields or None if not found
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
session = await self._ensure_fresh_session()
|
downloader = await get_downloader()
|
||||||
headers = self._get_request_headers()
|
|
||||||
|
|
||||||
# Case 1: Only version_id is provided
|
# Case 1: Only version_id is provided
|
||||||
if model_id is None and version_id is not None:
|
if model_id is None and version_id is not None:
|
||||||
# First get the version info to extract model_id
|
# First get the version info to extract model_id
|
||||||
async with session.get(f"{self.base_url}/model-versions/{version_id}", headers=headers) as response:
|
success, version = await downloader.make_request(
|
||||||
if response.status != 200:
|
'GET',
|
||||||
return None
|
f"{self.base_url}/model-versions/{version_id}",
|
||||||
|
use_auth=True
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
|
return None
|
||||||
|
|
||||||
version = await response.json()
|
model_id = version.get('modelId')
|
||||||
model_id = version.get('modelId')
|
if not model_id:
|
||||||
|
logger.error(f"No modelId found in version {version_id}")
|
||||||
if not model_id:
|
return None
|
||||||
logger.error(f"No modelId found in version {version_id}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Now get the model data for additional metadata
|
# Now get the model data for additional metadata
|
||||||
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
success, model_data = await downloader.make_request(
|
||||||
if response.status != 200:
|
'GET',
|
||||||
return version # Return version without additional metadata
|
f"{self.base_url}/models/{model_id}",
|
||||||
|
use_auth=True
|
||||||
model_data = await response.json()
|
)
|
||||||
|
if success:
|
||||||
# Enrich version with model data
|
# Enrich version with model data
|
||||||
version['model']['description'] = model_data.get("description")
|
version['model']['description'] = model_data.get("description")
|
||||||
version['model']['tags'] = model_data.get("tags", [])
|
version['model']['tags'] = model_data.get("tags", [])
|
||||||
version['creator'] = model_data.get("creator")
|
version['creator'] = model_data.get("creator")
|
||||||
|
|
||||||
return version
|
return version
|
||||||
|
|
||||||
# Case 2: model_id is provided (with or without version_id)
|
# Case 2: model_id is provided (with or without version_id)
|
||||||
elif model_id is not None:
|
elif model_id is not None:
|
||||||
# Step 1: Get model data to find version_id if not provided and get additional metadata
|
# Step 1: Get model data to find version_id if not provided and get additional metadata
|
||||||
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
success, data = await downloader.make_request(
|
||||||
if response.status != 200:
|
'GET',
|
||||||
return None
|
f"{self.base_url}/models/{model_id}",
|
||||||
|
use_auth=True
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
|
return None
|
||||||
|
|
||||||
data = await response.json()
|
model_versions = data.get('modelVersions', [])
|
||||||
model_versions = data.get('modelVersions', [])
|
|
||||||
|
|
||||||
# Step 2: Determine the version_id to use
|
# Step 2: Determine the version_id to use
|
||||||
target_version_id = version_id
|
target_version_id = version_id
|
||||||
if target_version_id is None:
|
if target_version_id is None:
|
||||||
target_version_id = model_versions[0].get('id')
|
target_version_id = model_versions[0].get('id')
|
||||||
|
|
||||||
# Step 3: Get detailed version info using the version_id
|
# Step 3: Get detailed version info using the version_id
|
||||||
async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
|
success, version = await downloader.make_request(
|
||||||
if response.status != 200:
|
'GET',
|
||||||
return None
|
f"{self.base_url}/model-versions/{target_version_id}",
|
||||||
|
use_auth=True
|
||||||
|
)
|
||||||
|
if not success:
|
||||||
|
return None
|
||||||
|
|
||||||
version = await response.json()
|
# Step 4: Enrich version_info with model data
|
||||||
|
# Add description and tags from model data
|
||||||
|
version['model']['description'] = data.get("description")
|
||||||
|
version['model']['tags'] = data.get("tags", [])
|
||||||
|
|
||||||
# Step 4: Enrich version_info with model data
|
# Add creator from model data
|
||||||
# Add description and tags from model data
|
version['creator'] = data.get("creator")
|
||||||
version['model']['description'] = data.get("description")
|
|
||||||
version['model']['tags'] = data.get("tags", [])
|
|
||||||
|
|
||||||
# Add creator from model data
|
return version
|
||||||
version['creator'] = data.get("creator")
|
|
||||||
|
|
||||||
return version
|
|
||||||
|
|
||||||
# Case 3: Neither model_id nor version_id provided
|
# Case 3: Neither model_id nor version_id provided
|
||||||
else:
|
else:
|
||||||
@@ -406,30 +213,29 @@ class CivitaiClient:
|
|||||||
- An error message if there was an error, or None on success
|
- An error message if there was an error, or None on success
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
session = await self._ensure_fresh_session()
|
downloader = await get_downloader()
|
||||||
url = f"{self.base_url}/model-versions/{version_id}"
|
url = f"{self.base_url}/model-versions/{version_id}"
|
||||||
headers = self._get_request_headers()
|
|
||||||
|
|
||||||
logger.debug(f"Resolving DNS for model version info: {url}")
|
logger.debug(f"Resolving DNS for model version info: {url}")
|
||||||
async with session.get(url, headers=headers) as response:
|
success, result = await downloader.make_request(
|
||||||
if response.status == 200:
|
'GET',
|
||||||
logger.debug(f"Successfully fetched model version info for: {version_id}")
|
url,
|
||||||
return await response.json(), None
|
use_auth=True
|
||||||
|
)
|
||||||
|
|
||||||
# Handle specific error cases
|
if success:
|
||||||
if response.status == 404:
|
logger.debug(f"Successfully fetched model version info for: {version_id}")
|
||||||
# Try to parse the error message
|
return result, None
|
||||||
try:
|
|
||||||
error_data = await response.json()
|
|
||||||
error_msg = error_data.get('error', f"Model not found (status 404)")
|
|
||||||
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
|
||||||
return None, error_msg
|
|
||||||
except:
|
|
||||||
return None, "Model not found (status 404)"
|
|
||||||
|
|
||||||
# Other error cases
|
# Handle specific error cases
|
||||||
logger.error(f"Failed to fetch model info for {version_id} (status {response.status})")
|
if "404" in str(result):
|
||||||
return None, f"Failed to fetch model info (status {response.status})"
|
error_msg = f"Model not found (status 404)"
|
||||||
|
logger.warning(f"Model version not found: {version_id} - {error_msg}")
|
||||||
|
return None, error_msg
|
||||||
|
|
||||||
|
# Other error cases
|
||||||
|
logger.error(f"Failed to fetch model info for {version_id}: {result}")
|
||||||
|
return None, str(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error fetching model version info: {e}"
|
error_msg = f"Error fetching model version info: {e}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
@@ -444,48 +250,50 @@ class CivitaiClient:
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple[Optional[Dict], int]: A tuple containing:
|
Tuple[Optional[Dict], int]: A tuple containing:
|
||||||
- A dictionary with model metadata or None if not found
|
- A dictionary with model metadata or None if not found
|
||||||
- The HTTP status code from the request
|
- The HTTP status code from the request (0 for exceptions)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
session = await self._ensure_fresh_session()
|
downloader = await get_downloader()
|
||||||
headers = self._get_request_headers()
|
|
||||||
url = f"{self.base_url}/models/{model_id}"
|
url = f"{self.base_url}/models/{model_id}"
|
||||||
|
|
||||||
async with session.get(url, headers=headers) as response:
|
success, result = await downloader.make_request(
|
||||||
status_code = response.status
|
'GET',
|
||||||
|
url,
|
||||||
|
use_auth=True
|
||||||
|
)
|
||||||
|
|
||||||
if status_code != 200:
|
if not success:
|
||||||
logger.warning(f"Failed to fetch model metadata: Status {status_code}")
|
# Try to extract status code from error message
|
||||||
return None, status_code
|
status_code = 0
|
||||||
|
if "404" in str(result):
|
||||||
|
status_code = 404
|
||||||
|
elif "401" in str(result):
|
||||||
|
status_code = 401
|
||||||
|
elif "403" in str(result):
|
||||||
|
status_code = 403
|
||||||
|
logger.warning(f"Failed to fetch model metadata: {result}")
|
||||||
|
return None, status_code
|
||||||
|
|
||||||
data = await response.json()
|
# Extract relevant metadata
|
||||||
|
metadata = {
|
||||||
# Extract relevant metadata
|
"description": result.get("description") or "No model description available",
|
||||||
metadata = {
|
"tags": result.get("tags", []),
|
||||||
"description": data.get("description") or "No model description available",
|
"creator": {
|
||||||
"tags": data.get("tags", []),
|
"username": result.get("creator", {}).get("username"),
|
||||||
"creator": {
|
"image": result.get("creator", {}).get("image")
|
||||||
"username": data.get("creator", {}).get("username"),
|
|
||||||
"image": data.get("creator", {}).get("image")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]:
|
if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]:
|
||||||
return metadata, status_code
|
return metadata, 200
|
||||||
else:
|
else:
|
||||||
logger.warning(f"No metadata found for model {model_id}")
|
logger.warning(f"No metadata found for model {model_id}")
|
||||||
return None, status_code
|
return None, 200
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching model metadata: {e}", exc_info=True)
|
logger.error(f"Error fetching model metadata: {e}", exc_info=True)
|
||||||
return None, 0
|
return None, 0
|
||||||
|
|
||||||
async def close(self):
|
|
||||||
"""Close the session if it exists"""
|
|
||||||
if self._session is not None:
|
|
||||||
await self._session.close()
|
|
||||||
self._session = None
|
|
||||||
|
|
||||||
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
async def get_image_info(self, image_id: str) -> Optional[Dict]:
|
||||||
"""Fetch image information from Civitai API
|
"""Fetch image information from Civitai API
|
||||||
|
|
||||||
@@ -496,22 +304,25 @@ class CivitaiClient:
|
|||||||
Optional[Dict]: The image data or None if not found
|
Optional[Dict]: The image data or None if not found
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
session = await self._ensure_fresh_session()
|
downloader = await get_downloader()
|
||||||
headers = self._get_request_headers()
|
|
||||||
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
|
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
|
||||||
|
|
||||||
logger.debug(f"Fetching image info for ID: {image_id}")
|
logger.debug(f"Fetching image info for ID: {image_id}")
|
||||||
async with session.get(url, headers=headers) as response:
|
success, result = await downloader.make_request(
|
||||||
if response.status == 200:
|
'GET',
|
||||||
data = await response.json()
|
url,
|
||||||
if data and "items" in data and len(data["items"]) > 0:
|
use_auth=True
|
||||||
logger.debug(f"Successfully fetched image info for ID: {image_id}")
|
)
|
||||||
return data["items"][0]
|
|
||||||
logger.warning(f"No image found with ID: {image_id}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
logger.error(f"Failed to fetch image info for ID: {image_id} (status {response.status})")
|
if success:
|
||||||
|
if result and "items" in result and len(result["items"]) > 0:
|
||||||
|
logger.debug(f"Successfully fetched image info for ID: {image_id}")
|
||||||
|
return result["items"][0]
|
||||||
|
logger.warning(f"No image found with ID: {image_id}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
logger.error(f"Failed to fetch image info for ID: {image_id}: {result}")
|
||||||
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error fetching image info: {e}"
|
error_msg = f"Error fetching image info: {e}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
|
|||||||
465
py/services/downloader.py
Normal file
465
py/services/downloader.py
Normal file
@@ -0,0 +1,465 @@
|
|||||||
|
"""
|
||||||
|
Unified download manager for all HTTP/HTTPS downloads in the application.
|
||||||
|
|
||||||
|
This module provides a centralized download service with:
|
||||||
|
- Singleton pattern for global session management
|
||||||
|
- Support for authenticated downloads (e.g., CivitAI API key)
|
||||||
|
- Resumable downloads with automatic retry
|
||||||
|
- Progress tracking and callbacks
|
||||||
|
- Optimized connection pooling and timeouts
|
||||||
|
- Unified error handling and logging
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Dict, Tuple, Callable, Union
|
||||||
|
from ..services.settings_manager import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Downloader:
|
||||||
|
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
|
||||||
|
|
||||||
|
_instance = None
|
||||||
|
_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_instance(cls):
|
||||||
|
"""Get singleton instance of Downloader"""
|
||||||
|
async with cls._lock:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = cls()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the downloader with optimal settings"""
|
||||||
|
# Check if already initialized for singleton pattern
|
||||||
|
if hasattr(self, '_initialized'):
|
||||||
|
return
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
# Session management
|
||||||
|
self._session = None
|
||||||
|
self._session_created_at = None
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better throughput
|
||||||
|
self.max_retries = 5
|
||||||
|
self.base_delay = 2.0 # Base delay for exponential backoff
|
||||||
|
self.session_timeout = 300 # 5 minutes
|
||||||
|
|
||||||
|
# Default headers
|
||||||
|
self.default_headers = {
|
||||||
|
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def session(self) -> aiohttp.ClientSession:
|
||||||
|
"""Get or create the global aiohttp session with optimized settings"""
|
||||||
|
if self._session is None or self._should_refresh_session():
|
||||||
|
await self._create_session()
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
def _should_refresh_session(self) -> bool:
|
||||||
|
"""Check if session should be refreshed"""
|
||||||
|
if self._session is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not hasattr(self, '_session_created_at') or self._session_created_at is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Refresh if session is older than timeout
|
||||||
|
if (datetime.now() - self._session_created_at).total_seconds() > self.session_timeout:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _create_session(self):
|
||||||
|
"""Create a new aiohttp session with optimized settings"""
|
||||||
|
# Close existing session if any
|
||||||
|
if self._session is not None:
|
||||||
|
await self._session.close()
|
||||||
|
|
||||||
|
# Optimize TCP connection parameters
|
||||||
|
connector = aiohttp.TCPConnector(
|
||||||
|
ssl=True,
|
||||||
|
limit=8, # Concurrent connections
|
||||||
|
ttl_dns_cache=300, # DNS cache timeout
|
||||||
|
force_close=False, # Keep connections for reuse
|
||||||
|
enable_cleanup_closed=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Configure timeout parameters
|
||||||
|
timeout = aiohttp.ClientTimeout(
|
||||||
|
total=None, # No total timeout for large downloads
|
||||||
|
connect=60, # Connection timeout
|
||||||
|
sock_read=None # No socket read timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
self._session = aiohttp.ClientSession(
|
||||||
|
connector=connector,
|
||||||
|
trust_env=True, # Use system proxy settings
|
||||||
|
timeout=timeout
|
||||||
|
)
|
||||||
|
self._session_created_at = datetime.now()
|
||||||
|
|
||||||
|
logger.debug("Created new HTTP session")
|
||||||
|
|
||||||
|
def _get_auth_headers(self, use_auth: bool = False) -> Dict[str, str]:
|
||||||
|
"""Get headers with optional authentication"""
|
||||||
|
headers = self.default_headers.copy()
|
||||||
|
|
||||||
|
if use_auth:
|
||||||
|
# Add CivitAI API key if available
|
||||||
|
api_key = settings.get('civitai_api_key')
|
||||||
|
if api_key:
|
||||||
|
headers['Authorization'] = f'Bearer {api_key}'
|
||||||
|
headers['Content-Type'] = 'application/json'
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
async def download_file(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
save_path: str,
|
||||||
|
progress_callback: Optional[Callable[[float], None]] = None,
|
||||||
|
use_auth: bool = False,
|
||||||
|
custom_headers: Optional[Dict[str, str]] = None,
|
||||||
|
allow_resume: bool = True
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
|
"""
|
||||||
|
Download a file with resumable downloads and retry mechanism
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: Download URL
|
||||||
|
save_path: Full path where the file should be saved
|
||||||
|
progress_callback: Optional callback for progress updates (0-100)
|
||||||
|
use_auth: Whether to include authentication headers (e.g., CivitAI API key)
|
||||||
|
custom_headers: Additional headers to include in request
|
||||||
|
allow_resume: Whether to support resumable downloads
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, str]: (success, save_path or error message)
|
||||||
|
"""
|
||||||
|
retry_count = 0
|
||||||
|
part_path = save_path + '.part' if allow_resume else save_path
|
||||||
|
|
||||||
|
# Prepare headers
|
||||||
|
headers = self._get_auth_headers(use_auth)
|
||||||
|
if custom_headers:
|
||||||
|
headers.update(custom_headers)
|
||||||
|
|
||||||
|
# Get existing file size for resume
|
||||||
|
resume_offset = 0
|
||||||
|
if allow_resume and os.path.exists(part_path):
|
||||||
|
resume_offset = os.path.getsize(part_path)
|
||||||
|
logger.info(f"Resuming download from offset {resume_offset} bytes")
|
||||||
|
|
||||||
|
total_size = 0
|
||||||
|
|
||||||
|
while retry_count <= self.max_retries:
|
||||||
|
try:
|
||||||
|
session = await self.session
|
||||||
|
|
||||||
|
# Add Range header for resume if we have partial data
|
||||||
|
request_headers = headers.copy()
|
||||||
|
if allow_resume and resume_offset > 0:
|
||||||
|
request_headers['Range'] = f'bytes={resume_offset}-'
|
||||||
|
|
||||||
|
# Disable compression for better chunked downloads
|
||||||
|
request_headers['Accept-Encoding'] = 'identity'
|
||||||
|
|
||||||
|
logger.debug(f"Download attempt {retry_count + 1}/{self.max_retries + 1} from: {url}")
|
||||||
|
if resume_offset > 0:
|
||||||
|
logger.debug(f"Requesting range from byte {resume_offset}")
|
||||||
|
|
||||||
|
async with session.get(url, headers=request_headers, allow_redirects=True) as response:
|
||||||
|
# Handle different response codes
|
||||||
|
if response.status == 200:
|
||||||
|
# Full content response
|
||||||
|
if resume_offset > 0:
|
||||||
|
# Server doesn't support ranges, restart from beginning
|
||||||
|
logger.warning("Server doesn't support range requests, restarting download")
|
||||||
|
resume_offset = 0
|
||||||
|
if os.path.exists(part_path):
|
||||||
|
os.remove(part_path)
|
||||||
|
elif response.status == 206:
|
||||||
|
# Partial content response (resume successful)
|
||||||
|
content_range = response.headers.get('Content-Range')
|
||||||
|
if content_range:
|
||||||
|
# Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048")
|
||||||
|
range_parts = content_range.split('/')
|
||||||
|
if len(range_parts) == 2:
|
||||||
|
total_size = int(range_parts[1])
|
||||||
|
logger.info(f"Successfully resumed download from byte {resume_offset}")
|
||||||
|
elif response.status == 416:
|
||||||
|
# Range not satisfiable - file might be complete or corrupted
|
||||||
|
if allow_resume and os.path.exists(part_path):
|
||||||
|
part_size = os.path.getsize(part_path)
|
||||||
|
logger.warning(f"Range not satisfiable. Part file size: {part_size}")
|
||||||
|
# Try to get actual file size
|
||||||
|
head_response = await session.head(url, headers=headers)
|
||||||
|
if head_response.status == 200:
|
||||||
|
actual_size = int(head_response.headers.get('content-length', 0))
|
||||||
|
if part_size == actual_size:
|
||||||
|
# File is complete, just rename it
|
||||||
|
if allow_resume:
|
||||||
|
os.rename(part_path, save_path)
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback(100)
|
||||||
|
return True, save_path
|
||||||
|
# Remove corrupted part file and restart
|
||||||
|
os.remove(part_path)
|
||||||
|
resume_offset = 0
|
||||||
|
continue
|
||||||
|
elif response.status == 401:
|
||||||
|
logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
|
||||||
|
return False, "Invalid or missing API key, or early access restriction."
|
||||||
|
elif response.status == 403:
|
||||||
|
logger.warning(f"Forbidden access to resource: {url} (Status 403)")
|
||||||
|
return False, "Access forbidden: You don't have permission to download this file."
|
||||||
|
elif response.status == 404:
|
||||||
|
logger.warning(f"Resource not found: {url} (Status 404)")
|
||||||
|
return False, "File not found - the download link may be invalid or expired."
|
||||||
|
else:
|
||||||
|
logger.error(f"Download failed for {url} with status {response.status}")
|
||||||
|
return False, f"Download failed with status {response.status}"
|
||||||
|
|
||||||
|
# Get total file size for progress calculation (if not set from Content-Range)
|
||||||
|
if total_size == 0:
|
||||||
|
total_size = int(response.headers.get('content-length', 0))
|
||||||
|
if response.status == 206:
|
||||||
|
# For partial content, add the offset to get total file size
|
||||||
|
total_size += resume_offset
|
||||||
|
|
||||||
|
current_size = resume_offset
|
||||||
|
last_progress_report_time = datetime.now()
|
||||||
|
|
||||||
|
# Ensure directory exists
|
||||||
|
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||||
|
|
||||||
|
# Stream download to file with progress updates
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
|
||||||
|
with open(part_path, mode) as f:
|
||||||
|
async for chunk in response.content.iter_chunked(self.chunk_size):
|
||||||
|
if chunk:
|
||||||
|
# Run blocking file write in executor
|
||||||
|
await loop.run_in_executor(None, f.write, chunk)
|
||||||
|
current_size += len(chunk)
|
||||||
|
|
||||||
|
# Limit progress update frequency to reduce overhead
|
||||||
|
now = datetime.now()
|
||||||
|
time_diff = (now - last_progress_report_time).total_seconds()
|
||||||
|
|
||||||
|
if progress_callback and total_size and time_diff >= 1.0:
|
||||||
|
progress = (current_size / total_size) * 100
|
||||||
|
await progress_callback(progress)
|
||||||
|
last_progress_report_time = now
|
||||||
|
|
||||||
|
# Download completed successfully
|
||||||
|
# Verify file size if total_size was provided
|
||||||
|
final_size = os.path.getsize(part_path)
|
||||||
|
if total_size > 0 and final_size != total_size:
|
||||||
|
logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}")
|
||||||
|
# Don't treat this as fatal error, continue anyway
|
||||||
|
|
||||||
|
# Atomically rename .part to final file (only if using resume)
|
||||||
|
if allow_resume and part_path != save_path:
|
||||||
|
max_rename_attempts = 5
|
||||||
|
rename_attempt = 0
|
||||||
|
rename_success = False
|
||||||
|
|
||||||
|
while rename_attempt < max_rename_attempts and not rename_success:
|
||||||
|
try:
|
||||||
|
os.rename(part_path, save_path)
|
||||||
|
rename_success = True
|
||||||
|
except PermissionError as e:
|
||||||
|
rename_attempt += 1
|
||||||
|
if rename_attempt < max_rename_attempts:
|
||||||
|
logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})")
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
else:
|
||||||
|
logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}")
|
||||||
|
return False, f"Failed to finalize download: {str(e)}"
|
||||||
|
|
||||||
|
# Ensure 100% progress is reported
|
||||||
|
if progress_callback:
|
||||||
|
await progress_callback(100)
|
||||||
|
|
||||||
|
return True, save_path
|
||||||
|
|
||||||
|
except (aiohttp.ClientError, aiohttp.ClientPayloadError,
|
||||||
|
aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
|
||||||
|
retry_count += 1
|
||||||
|
logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}")
|
||||||
|
|
||||||
|
if retry_count <= self.max_retries:
|
||||||
|
# Calculate delay with exponential backoff
|
||||||
|
delay = self.base_delay * (2 ** (retry_count - 1))
|
||||||
|
logger.info(f"Retrying in {delay} seconds...")
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
# Update resume offset for next attempt
|
||||||
|
if allow_resume and os.path.exists(part_path):
|
||||||
|
resume_offset = os.path.getsize(part_path)
|
||||||
|
logger.info(f"Will resume from byte {resume_offset}")
|
||||||
|
|
||||||
|
# Refresh session to get new connection
|
||||||
|
await self._create_session()
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
logger.error(f"Max retries exceeded for download: {e}")
|
||||||
|
return False, f"Network error after {self.max_retries + 1} attempts: {str(e)}"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected download error: {e}")
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
return False, f"Download failed after {self.max_retries + 1} attempts"
|
||||||
|
|
||||||
|
async def download_to_memory(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
use_auth: bool = False,
|
||||||
|
custom_headers: Optional[Dict[str, str]] = None
|
||||||
|
) -> Tuple[bool, Union[bytes, str]]:
|
||||||
|
"""
|
||||||
|
Download a file to memory (for small files like preview images)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: Download URL
|
||||||
|
use_auth: Whether to include authentication headers
|
||||||
|
custom_headers: Additional headers to include in request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, Union[bytes, str]]: (success, content or error message)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
session = await self.session
|
||||||
|
|
||||||
|
# Prepare headers
|
||||||
|
headers = self._get_auth_headers(use_auth)
|
||||||
|
if custom_headers:
|
||||||
|
headers.update(custom_headers)
|
||||||
|
|
||||||
|
async with session.get(url, headers=headers) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
content = await response.read()
|
||||||
|
return True, content
|
||||||
|
elif response.status == 401:
|
||||||
|
return False, "Unauthorized access - invalid or missing API key"
|
||||||
|
elif response.status == 403:
|
||||||
|
return False, "Access forbidden"
|
||||||
|
elif response.status == 404:
|
||||||
|
return False, "File not found"
|
||||||
|
else:
|
||||||
|
return False, f"Download failed with status {response.status}"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error downloading to memory from {url}: {e}")
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
async def get_response_headers(
|
||||||
|
self,
|
||||||
|
url: str,
|
||||||
|
use_auth: bool = False,
|
||||||
|
custom_headers: Optional[Dict[str, str]] = None
|
||||||
|
) -> Tuple[bool, Union[Dict, str]]:
|
||||||
|
"""
|
||||||
|
Get response headers without downloading the full content
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: URL to check
|
||||||
|
use_auth: Whether to include authentication headers
|
||||||
|
custom_headers: Additional headers to include in request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, Union[Dict, str]]: (success, headers dict or error message)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
session = await self.session
|
||||||
|
|
||||||
|
# Prepare headers
|
||||||
|
headers = self._get_auth_headers(use_auth)
|
||||||
|
if custom_headers:
|
||||||
|
headers.update(custom_headers)
|
||||||
|
|
||||||
|
async with session.head(url, headers=headers) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
return True, dict(response.headers)
|
||||||
|
else:
|
||||||
|
return False, f"Head request failed with status {response.status}"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting headers from {url}: {e}")
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
async def make_request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
url: str,
|
||||||
|
use_auth: bool = False,
|
||||||
|
custom_headers: Optional[Dict[str, str]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> Tuple[bool, Union[Dict, str]]:
|
||||||
|
"""
|
||||||
|
Make a generic HTTP request and return JSON response
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: HTTP method (GET, POST, etc.)
|
||||||
|
url: Request URL
|
||||||
|
use_auth: Whether to include authentication headers
|
||||||
|
custom_headers: Additional headers to include in request
|
||||||
|
**kwargs: Additional arguments for aiohttp request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bool, Union[Dict, str]]: (success, response data or error message)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
session = await self.session
|
||||||
|
|
||||||
|
# Prepare headers
|
||||||
|
headers = self._get_auth_headers(use_auth)
|
||||||
|
if custom_headers:
|
||||||
|
headers.update(custom_headers)
|
||||||
|
|
||||||
|
async with session.request(method, url, headers=headers, **kwargs) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
# Try to parse as JSON, fall back to text
|
||||||
|
try:
|
||||||
|
data = await response.json()
|
||||||
|
return True, data
|
||||||
|
except:
|
||||||
|
text = await response.text()
|
||||||
|
return True, text
|
||||||
|
elif response.status == 401:
|
||||||
|
return False, "Unauthorized access - invalid or missing API key"
|
||||||
|
elif response.status == 403:
|
||||||
|
return False, "Access forbidden"
|
||||||
|
elif response.status == 404:
|
||||||
|
return False, "Resource not found"
|
||||||
|
else:
|
||||||
|
return False, f"Request failed with status {response.status}"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error making {method} request to {url}: {e}")
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close the HTTP session"""
|
||||||
|
if self._session is not None:
|
||||||
|
await self._session.close()
|
||||||
|
self._session = None
|
||||||
|
self._session_created_at = None
|
||||||
|
logger.debug("Closed HTTP session")
|
||||||
|
|
||||||
|
|
||||||
|
# Global instance accessor
|
||||||
|
async def get_downloader() -> Downloader:
|
||||||
|
"""Get the global downloader instance"""
|
||||||
|
return await Downloader.get_instance()
|
||||||
@@ -1,9 +1,9 @@
|
|||||||
import zipfile
|
import zipfile
|
||||||
import aiohttp
|
|
||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from .downloader import get_downloader
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -67,6 +67,8 @@ class MetadataArchiveManager:
|
|||||||
|
|
||||||
async def _download_archive(self, progress_callback=None) -> bool:
|
async def _download_archive(self, progress_callback=None) -> bool:
|
||||||
"""Download the zip archive from one of the available URLs"""
|
"""Download the zip archive from one of the available URLs"""
|
||||||
|
downloader = await get_downloader()
|
||||||
|
|
||||||
for url in self.DOWNLOAD_URLS:
|
for url in self.DOWNLOAD_URLS:
|
||||||
try:
|
try:
|
||||||
logger.info(f"Attempting to download from {url}")
|
logger.info(f"Attempting to download from {url}")
|
||||||
@@ -74,26 +76,25 @@ class MetadataArchiveManager:
|
|||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback("download", f"Downloading from {url}")
|
progress_callback("download", f"Downloading from {url}")
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
# Custom progress callback to report download progress
|
||||||
async with session.get(url) as response:
|
async def download_progress(progress):
|
||||||
if response.status == 200:
|
if progress_callback:
|
||||||
total_size = int(response.headers.get('content-length', 0))
|
progress_callback("download", f"Downloaded {progress:.1f}%")
|
||||||
downloaded = 0
|
|
||||||
|
|
||||||
with open(self.archive_path, 'wb') as f:
|
success, result = await downloader.download_file(
|
||||||
async for chunk in response.content.iter_chunked(8192):
|
url=url,
|
||||||
f.write(chunk)
|
save_path=str(self.archive_path),
|
||||||
downloaded += len(chunk)
|
progress_callback=download_progress,
|
||||||
|
use_auth=False, # Public download, no auth needed
|
||||||
|
allow_resume=True
|
||||||
|
)
|
||||||
|
|
||||||
if progress_callback and total_size > 0:
|
if success:
|
||||||
percentage = (downloaded / total_size) * 100
|
logger.info(f"Successfully downloaded archive from {url}")
|
||||||
progress_callback("download", f"Downloaded {percentage:.1f}%")
|
return True
|
||||||
|
else:
|
||||||
logger.info(f"Successfully downloaded archive from {url}")
|
logger.warning(f"Failed to download from {url}: {result}")
|
||||||
return True
|
continue
|
||||||
else:
|
|
||||||
logger.warning(f"Failed to download from {url}: HTTP {response.status}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error downloading from {url}: {e}")
|
logger.warning(f"Error downloading from {url}: {e}")
|
||||||
|
|||||||
@@ -3,13 +3,13 @@ import os
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import aiohttp
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from .example_images_processor import ExampleImagesProcessor
|
from .example_images_processor import ExampleImagesProcessor
|
||||||
from .example_images_metadata import MetadataUpdater
|
from .example_images_metadata import MetadataUpdater
|
||||||
from ..services.websocket_manager import ws_manager # Add this import at the top
|
from ..services.websocket_manager import ws_manager # Add this import at the top
|
||||||
|
from ..services.downloader import get_downloader
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -199,19 +199,8 @@ class DownloadManager:
|
|||||||
"""Download example images for all models"""
|
"""Download example images for all models"""
|
||||||
global is_downloading, download_progress
|
global is_downloading, download_progress
|
||||||
|
|
||||||
# Create independent download session
|
# Get unified downloader
|
||||||
connector = aiohttp.TCPConnector(
|
downloader = await get_downloader()
|
||||||
ssl=True,
|
|
||||||
limit=3,
|
|
||||||
force_close=False,
|
|
||||||
enable_cleanup_closed=True
|
|
||||||
)
|
|
||||||
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
|
|
||||||
independent_session = aiohttp.ClientSession(
|
|
||||||
connector=connector,
|
|
||||||
trust_env=True,
|
|
||||||
timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get scanners
|
# Get scanners
|
||||||
@@ -246,7 +235,7 @@ class DownloadManager:
|
|||||||
# Main logic for processing model is here, but actual operations are delegated to other classes
|
# Main logic for processing model is here, but actual operations are delegated to other classes
|
||||||
was_remote_download = await DownloadManager._process_model(
|
was_remote_download = await DownloadManager._process_model(
|
||||||
scanner_type, model, scanner,
|
scanner_type, model, scanner,
|
||||||
output_dir, optimize, independent_session
|
output_dir, optimize, downloader
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update progress
|
# Update progress
|
||||||
@@ -270,12 +259,6 @@ class DownloadManager:
|
|||||||
download_progress['end_time'] = time.time()
|
download_progress['end_time'] = time.time()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Close the independent session
|
|
||||||
try:
|
|
||||||
await independent_session.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error closing download session: {e}")
|
|
||||||
|
|
||||||
# Save final progress to file
|
# Save final progress to file
|
||||||
try:
|
try:
|
||||||
DownloadManager._save_progress(output_dir)
|
DownloadManager._save_progress(output_dir)
|
||||||
@@ -286,7 +269,7 @@ class DownloadManager:
|
|||||||
is_downloading = False
|
is_downloading = False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _process_model(scanner_type, model, scanner, output_dir, optimize, independent_session):
|
async def _process_model(scanner_type, model, scanner, output_dir, optimize, downloader):
|
||||||
"""Process a single model download"""
|
"""Process a single model download"""
|
||||||
global download_progress
|
global download_progress
|
||||||
|
|
||||||
@@ -347,7 +330,7 @@ class DownloadManager:
|
|||||||
images = model.get('civitai', {}).get('images', [])
|
images = model.get('civitai', {}).get('images', [])
|
||||||
|
|
||||||
success, is_stale = await ExampleImagesProcessor.download_model_images(
|
success, is_stale = await ExampleImagesProcessor.download_model_images(
|
||||||
model_hash, model_name, images, model_dir, optimize, independent_session
|
model_hash, model_name, images, model_dir, optimize, downloader
|
||||||
)
|
)
|
||||||
|
|
||||||
# If metadata is stale, try to refresh it
|
# If metadata is stale, try to refresh it
|
||||||
@@ -365,7 +348,7 @@ class DownloadManager:
|
|||||||
# Retry download with updated metadata
|
# Retry download with updated metadata
|
||||||
updated_images = updated_model.get('civitai', {}).get('images', [])
|
updated_images = updated_model.get('civitai', {}).get('images', [])
|
||||||
success, _ = await ExampleImagesProcessor.download_model_images(
|
success, _ = await ExampleImagesProcessor.download_model_images(
|
||||||
model_hash, model_name, updated_images, model_dir, optimize, independent_session
|
model_hash, model_name, updated_images, model_dir, optimize, downloader
|
||||||
)
|
)
|
||||||
|
|
||||||
download_progress['refreshed_models'].add(model_hash)
|
download_progress['refreshed_models'].add(model_hash)
|
||||||
@@ -529,19 +512,8 @@ class DownloadManager:
|
|||||||
"""Download example images for specific models only - synchronous version"""
|
"""Download example images for specific models only - synchronous version"""
|
||||||
global download_progress
|
global download_progress
|
||||||
|
|
||||||
# Create independent download session
|
# Get unified downloader
|
||||||
connector = aiohttp.TCPConnector(
|
downloader = await get_downloader()
|
||||||
ssl=True,
|
|
||||||
limit=3,
|
|
||||||
force_close=False,
|
|
||||||
enable_cleanup_closed=True
|
|
||||||
)
|
|
||||||
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
|
|
||||||
independent_session = aiohttp.ClientSession(
|
|
||||||
connector=connector,
|
|
||||||
trust_env=True,
|
|
||||||
timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get scanners
|
# Get scanners
|
||||||
@@ -586,7 +558,7 @@ class DownloadManager:
|
|||||||
# Force process this model regardless of previous status
|
# Force process this model regardless of previous status
|
||||||
was_successful = await DownloadManager._process_specific_model(
|
was_successful = await DownloadManager._process_specific_model(
|
||||||
scanner_type, model, scanner,
|
scanner_type, model, scanner,
|
||||||
output_dir, optimize, independent_session
|
output_dir, optimize, downloader
|
||||||
)
|
)
|
||||||
|
|
||||||
if was_successful:
|
if was_successful:
|
||||||
@@ -650,14 +622,11 @@ class DownloadManager:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Close the independent session
|
# No need to close any sessions since we use the global downloader
|
||||||
try:
|
pass
|
||||||
await independent_session.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error closing download session: {e}")
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, independent_session):
|
async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, downloader):
|
||||||
"""Process a specific model for forced download, ignoring previous download status"""
|
"""Process a specific model for forced download, ignoring previous download status"""
|
||||||
global download_progress
|
global download_progress
|
||||||
|
|
||||||
@@ -701,7 +670,7 @@ class DownloadManager:
|
|||||||
images = model.get('civitai', {}).get('images', [])
|
images = model.get('civitai', {}).get('images', [])
|
||||||
|
|
||||||
success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
|
success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
|
||||||
model_hash, model_name, images, model_dir, optimize, independent_session
|
model_hash, model_name, images, model_dir, optimize, downloader
|
||||||
)
|
)
|
||||||
|
|
||||||
# If metadata is stale, try to refresh it
|
# If metadata is stale, try to refresh it
|
||||||
@@ -719,7 +688,7 @@ class DownloadManager:
|
|||||||
# Retry download with updated metadata
|
# Retry download with updated metadata
|
||||||
updated_images = updated_model.get('civitai', {}).get('images', [])
|
updated_images = updated_model.get('civitai', {}).get('images', [])
|
||||||
success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
|
success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
|
||||||
model_hash, model_name, updated_images, model_dir, optimize, independent_session
|
model_hash, model_name, updated_images, model_dir, optimize, downloader
|
||||||
)
|
)
|
||||||
|
|
||||||
# Combine failed images from both attempts
|
# Combine failed images from both attempts
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ class ExampleImagesProcessor:
|
|||||||
return image_url
|
return image_url
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def download_model_images(model_hash, model_name, model_images, model_dir, optimize, independent_session):
|
async def download_model_images(model_hash, model_name, model_images, model_dir, optimize, downloader):
|
||||||
"""Download images for a single model
|
"""Download images for a single model
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -78,23 +78,25 @@ class ExampleImagesProcessor:
|
|||||||
try:
|
try:
|
||||||
logger.debug(f"Downloading {save_filename} for {model_name}")
|
logger.debug(f"Downloading {save_filename} for {model_name}")
|
||||||
|
|
||||||
# Download directly using the independent session
|
# Download using the unified downloader
|
||||||
async with independent_session.get(image_url, timeout=60) as response:
|
success, content = await downloader.download_to_memory(
|
||||||
if response.status == 200:
|
image_url,
|
||||||
with open(save_path, 'wb') as f:
|
use_auth=False # Example images don't need auth
|
||||||
async for chunk in response.content.iter_chunked(8192):
|
)
|
||||||
if chunk:
|
|
||||||
f.write(chunk)
|
if success:
|
||||||
elif response.status == 404:
|
with open(save_path, 'wb') as f:
|
||||||
error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale"
|
f.write(content)
|
||||||
logger.warning(error_msg)
|
elif "404" in str(content):
|
||||||
model_success = False # Mark the model as failed due to 404 error
|
error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale"
|
||||||
# Return early to trigger metadata refresh attempt
|
logger.warning(error_msg)
|
||||||
return False, True # (success, is_metadata_stale)
|
model_success = False # Mark the model as failed due to 404 error
|
||||||
else:
|
# Return early to trigger metadata refresh attempt
|
||||||
error_msg = f"Failed to download file: {image_url}, status code: {response.status}"
|
return False, True # (success, is_metadata_stale)
|
||||||
logger.warning(error_msg)
|
else:
|
||||||
model_success = False # Mark the model as failed
|
error_msg = f"Failed to download file: {image_url}, error: {content}"
|
||||||
|
logger.warning(error_msg)
|
||||||
|
model_success = False # Mark the model as failed
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error downloading file {image_url}: {str(e)}"
|
error_msg = f"Error downloading file {image_url}: {str(e)}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
@@ -103,7 +105,7 @@ class ExampleImagesProcessor:
|
|||||||
return model_success, False # (success, is_metadata_stale)
|
return model_success, False # (success, is_metadata_stale)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def download_model_images_with_tracking(model_hash, model_name, model_images, model_dir, optimize, independent_session):
|
async def download_model_images_with_tracking(model_hash, model_name, model_images, model_dir, optimize, downloader):
|
||||||
"""Download images for a single model with tracking of failed image URLs
|
"""Download images for a single model with tracking of failed image URLs
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -147,25 +149,27 @@ class ExampleImagesProcessor:
|
|||||||
try:
|
try:
|
||||||
logger.debug(f"Downloading {save_filename} for {model_name}")
|
logger.debug(f"Downloading {save_filename} for {model_name}")
|
||||||
|
|
||||||
# Download directly using the independent session
|
# Download using the unified downloader
|
||||||
async with independent_session.get(image_url, timeout=60) as response:
|
success, content = await downloader.download_to_memory(
|
||||||
if response.status == 200:
|
image_url,
|
||||||
with open(save_path, 'wb') as f:
|
use_auth=False # Example images don't need auth
|
||||||
async for chunk in response.content.iter_chunked(8192):
|
)
|
||||||
if chunk:
|
|
||||||
f.write(chunk)
|
if success:
|
||||||
elif response.status == 404:
|
with open(save_path, 'wb') as f:
|
||||||
error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale"
|
f.write(content)
|
||||||
logger.warning(error_msg)
|
elif "404" in str(content):
|
||||||
model_success = False # Mark the model as failed due to 404 error
|
error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale"
|
||||||
failed_images.append(image_url) # Track failed URL
|
logger.warning(error_msg)
|
||||||
# Return early to trigger metadata refresh attempt
|
model_success = False # Mark the model as failed due to 404 error
|
||||||
return False, True, failed_images # (success, is_metadata_stale, failed_images)
|
failed_images.append(image_url) # Track failed URL
|
||||||
else:
|
# Return early to trigger metadata refresh attempt
|
||||||
error_msg = f"Failed to download file: {image_url}, status code: {response.status}"
|
return False, True, failed_images # (success, is_metadata_stale, failed_images)
|
||||||
logger.warning(error_msg)
|
else:
|
||||||
model_success = False # Mark the model as failed
|
error_msg = f"Failed to download file: {image_url}, error: {content}"
|
||||||
failed_images.append(image_url) # Track failed URL
|
logger.warning(error_msg)
|
||||||
|
model_success = False # Mark the model as failed
|
||||||
|
failed_images.append(image_url) # Track failed URL
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error downloading file {image_url}: {str(e)}"
|
error_msg = f"Error downloading file {image_url}: {str(e)}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
|
|||||||
Reference in New Issue
Block a user