Refactor download logic to use unified downloader service

- Introduced a new `Downloader` class to centralize HTTP/HTTPS download management.
- Replaced direct `aiohttp` session handling with the unified downloader in `MetadataArchiveManager`, `DownloadManager`, and `ExampleImagesProcessor`.
- Added support for resumable downloads, progress tracking, and error handling in the new downloader.
- Updated methods to utilize the downloader's capabilities for downloading files and images, improving code maintainability and readability.
This commit is contained in:
Will Miao
2025-09-09 10:34:14 +08:00
parent 821827a375
commit 14721c265f
6 changed files with 777 additions and 514 deletions

View File

@@ -1,5 +1,4 @@
import os import os
import aiohttp
import logging import logging
import toml import toml
import git import git
@@ -8,7 +7,7 @@ import shutil
import tempfile import tempfile
from aiohttp import web from aiohttp import web
from typing import Dict, List from typing import Dict, List
from ..services.downloader import get_downloader, Downloader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -162,28 +161,42 @@ class UpdateRoutes:
github_api = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest" github_api = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
try: try:
async with aiohttp.ClientSession() as session: downloader = await get_downloader()
async with session.get(github_api) as resp:
if resp.status != 200:
logger.error(f"Failed to fetch release info: {resp.status}")
return False, ""
data = await resp.json()
zip_url = data.get("zipball_url")
version = data.get("tag_name", "unknown")
# Download ZIP # Get release info
async with session.get(zip_url) as zip_resp: success, data = await downloader.make_request(
if zip_resp.status != 200: 'GET',
logger.error(f"Failed to download ZIP: {zip_resp.status}") github_api,
return False, "" use_auth=False
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip: )
tmp_zip.write(await zip_resp.read()) if not success:
zip_path = tmp_zip.name logger.error(f"Failed to fetch release info: {data}")
return False, ""
UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json']) zip_url = data.get("zipball_url")
version = data.get("tag_name", "unknown")
# Extract ZIP to temp dir # Download ZIP to temporary file
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
tmp_zip_path = tmp_zip.name
success, result = await downloader.download_file(
url=zip_url,
save_path=tmp_zip_path,
use_auth=False,
allow_resume=False
)
if not success:
logger.error(f"Failed to download ZIP: {result}")
return False, ""
zip_path = tmp_zip_path
UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json'])
# Extract ZIP to temp dir
with tempfile.TemporaryDirectory() as tmp_dir:
with zipfile.ZipFile(zip_path, 'r') as zip_ref: with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(tmp_dir) zip_ref.extractall(tmp_dir)
# Find extracted folder (GitHub ZIP contains a root folder) # Find extracted folder (GitHub ZIP contains a root folder)
@@ -213,9 +226,9 @@ class UpdateRoutes:
with open(tracking_info_file, "w", encoding='utf-8') as file: with open(tracking_info_file, "w", encoding='utf-8') as file:
file.write('\n'.join(tracking_files)) file.write('\n'.join(tracking_files))
os.remove(zip_path) os.remove(zip_path)
logger.info(f"Updated plugin via ZIP to {version}") logger.info(f"Updated plugin via ZIP to {version}")
return True, version return True, version
except Exception as e: except Exception as e:
logger.error(f"ZIP update failed: {e}", exc_info=True) logger.error(f"ZIP update failed: {e}", exc_info=True)
@@ -244,23 +257,23 @@ class UpdateRoutes:
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main" github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
try: try:
async with aiohttp.ClientSession() as session: downloader = await Downloader.get_instance()
async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response: success, data = await downloader.make_request('GET', github_url, headers={'Accept': 'application/vnd.github+json'})
if response.status != 200:
logger.warning(f"Failed to fetch GitHub commit: {response.status}")
return "main", []
data = await response.json() if not success:
commit_sha = data.get('sha', '')[:7] # Short hash logger.warning(f"Failed to fetch GitHub commit: {data}")
commit_message = data.get('commit', {}).get('message', '') return "main", []
# Format as "main-{short_hash}" commit_sha = data.get('sha', '')[:7] # Short hash
version = f"main-{commit_sha}" commit_message = data.get('commit', {}).get('message', '')
# Use commit message as changelog # Format as "main-{short_hash}"
changelog = [commit_message] if commit_message else [] version = f"main-{commit_sha}"
return version, changelog # Use commit message as changelog
changelog = [commit_message] if commit_message else []
return version, changelog
except Exception as e: except Exception as e:
logger.error(f"Error fetching nightly version: {e}", exc_info=True) logger.error(f"Error fetching nightly version: {e}", exc_info=True)
@@ -410,22 +423,22 @@ class UpdateRoutes:
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest" github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
try: try:
async with aiohttp.ClientSession() as session: downloader = await Downloader.get_instance()
async with session.get(github_url, headers={'Accept': 'application/vnd.github+json'}) as response: success, data = await downloader.make_request('GET', github_url, headers={'Accept': 'application/vnd.github+json'})
if response.status != 200:
logger.warning(f"Failed to fetch GitHub release: {response.status}")
return "v0.0.0", []
data = await response.json() if not success:
version = data.get('tag_name', '') logger.warning(f"Failed to fetch GitHub release: {data}")
if not version.startswith('v'): return "v0.0.0", []
version = f"v{version}"
# Extract changelog from release notes version = data.get('tag_name', '')
body = data.get('body', '') if not version.startswith('v'):
changelog = UpdateRoutes._parse_changelog(body) version = f"v{version}"
return version, changelog # Extract changelog from release notes
body = data.get('body', '')
changelog = UpdateRoutes._parse_changelog(body)
return version, changelog
except Exception as e: except Exception as e:
logger.error(f"Error fetching remote version: {e}", exc_info=True) logger.error(f"Error fetching remote version: {e}", exc_info=True)

View File

@@ -1,10 +1,10 @@
from datetime import datetime from datetime import datetime
import aiohttp
import os import os
import logging import logging
import asyncio import asyncio
from typing import Optional, Dict, Tuple, List from typing import Optional, Dict, Tuple, List
from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager from .model_metadata_provider import CivitaiModelMetadataProvider, ModelMetadataProviderManager
from .downloader import get_downloader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -32,60 +32,6 @@ class CivitaiClient:
self._initialized = True self._initialized = True
self.base_url = "https://civitai.com/api/v1" self.base_url = "https://civitai.com/api/v1"
self.headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
}
self._session = None
self._session_created_at = None
# Adjust chunk size based on storage type - consider making this configurable
self.chunk_size = 4 * 1024 * 1024 # 4MB chunks for better HDD throughput
@property
async def session(self) -> aiohttp.ClientSession:
"""Lazy initialize the session"""
if self._session is None:
# Optimize TCP connection parameters
connector = aiohttp.TCPConnector(
ssl=True,
limit=8, # Increase from 3 to 8 for better parallelism
ttl_dns_cache=300, # Enable DNS caching with reasonable timeout
force_close=False, # Keep connections for reuse
enable_cleanup_closed=True
)
trust_env = True # Allow using system environment proxy settings
# Configure timeout parameters - increase read timeout for large files and remove sock_read timeout
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=None)
self._session = aiohttp.ClientSession(
connector=connector,
trust_env=trust_env,
timeout=timeout
)
self._session_created_at = datetime.now()
return self._session
async def _ensure_fresh_session(self):
"""Refresh session if it's been open too long"""
if self._session is not None:
if not hasattr(self, '_session_created_at') or \
(datetime.now() - self._session_created_at).total_seconds() > 300: # 5 minutes
await self.close()
self._session = None
return await self.session
def _get_request_headers(self) -> dict:
"""Get request headers with optional API key"""
headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
'Content-Type': 'application/json'
}
from .settings_manager import settings
api_key = settings.get('civitai_api_key')
if (api_key):
headers['Authorization'] = f'Bearer {api_key}'
return headers
async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]: async def download_file(self, url: str, save_dir: str, default_filename: str, progress_callback=None) -> Tuple[bool, str]:
"""Download file with resumable downloads and retry mechanism """Download file with resumable downloads and retry mechanism
@@ -99,214 +45,69 @@ class CivitaiClient:
Returns: Returns:
Tuple[bool, str]: (success, save_path or error message) Tuple[bool, str]: (success, save_path or error message)
""" """
max_retries = 5 downloader = await get_downloader()
retry_count = 0
base_delay = 2.0 # Base delay for exponential backoff
# Initial setup
session = await self._ensure_fresh_session()
save_path = os.path.join(save_dir, default_filename) save_path = os.path.join(save_dir, default_filename)
part_path = save_path + '.part'
# Get existing file size for resume # Use unified downloader with CivitAI authentication
resume_offset = 0 success, result = await downloader.download_file(
if os.path.exists(part_path): url=url,
resume_offset = os.path.getsize(part_path) save_path=save_path,
logger.info(f"Resuming download from offset {resume_offset} bytes") progress_callback=progress_callback,
use_auth=True, # Enable CivitAI authentication
allow_resume=True
)
total_size = 0 return success, result
while retry_count <= max_retries:
try:
headers = self._get_request_headers()
# Add Range header for resume if we have partial data
if resume_offset > 0:
headers['Range'] = f'bytes={resume_offset}-'
# Add Range header to allow resumable downloads
headers['Accept-Encoding'] = 'identity' # Disable compression for better chunked downloads
logger.debug(f"Download attempt {retry_count + 1}/{max_retries + 1} from: {url}")
if resume_offset > 0:
logger.debug(f"Requesting range from byte {resume_offset}")
async with session.get(url, headers=headers, allow_redirects=True) as response:
# Handle different response codes
if response.status == 200:
# Full content response
if resume_offset > 0:
# Server doesn't support ranges, restart from beginning
logger.warning("Server doesn't support range requests, restarting download")
resume_offset = 0
if os.path.exists(part_path):
os.remove(part_path)
elif response.status == 206:
# Partial content response (resume successful)
content_range = response.headers.get('Content-Range')
if content_range:
# Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048")
range_parts = content_range.split('/')
if len(range_parts) == 2:
total_size = int(range_parts[1])
logger.info(f"Successfully resumed download from byte {resume_offset}")
elif response.status == 416:
# Range not satisfiable - file might be complete or corrupted
if os.path.exists(part_path):
part_size = os.path.getsize(part_path)
logger.warning(f"Range not satisfiable. Part file size: {part_size}")
# Try to get actual file size
head_response = await session.head(url, headers=self._get_request_headers())
if head_response.status == 200:
actual_size = int(head_response.headers.get('content-length', 0))
if part_size == actual_size:
# File is complete, just rename it
os.rename(part_path, save_path)
if progress_callback:
await progress_callback(100)
return True, save_path
# Remove corrupted part file and restart
os.remove(part_path)
resume_offset = 0
continue
elif response.status == 401:
logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
return False, "Invalid or missing CivitAI API key, or early access restriction."
elif response.status == 403:
logger.warning(f"Forbidden access to resource: {url} (Status 403)")
return False, "Access forbidden: You don't have permission to download this file."
else:
logger.error(f"Download failed for {url} with status {response.status}")
return False, f"Download failed with status {response.status}"
# Get total file size for progress calculation (if not set from Content-Range)
if total_size == 0:
total_size = int(response.headers.get('content-length', 0))
if response.status == 206:
# For partial content, add the offset to get total file size
total_size += resume_offset
current_size = resume_offset
last_progress_report_time = datetime.now()
# Stream download to file with progress updates using larger buffer
loop = asyncio.get_running_loop()
mode = 'ab' if resume_offset > 0 else 'wb'
with open(part_path, mode) as f:
async for chunk in response.content.iter_chunked(self.chunk_size):
if chunk:
# Run blocking file write in executor
await loop.run_in_executor(None, f.write, chunk)
current_size += len(chunk)
# Limit progress update frequency to reduce overhead
now = datetime.now()
time_diff = (now - last_progress_report_time).total_seconds()
if progress_callback and total_size and time_diff >= 1.0:
progress = (current_size / total_size) * 100
await progress_callback(progress)
last_progress_report_time = now
# Download completed successfully
# Verify file size if total_size was provided
final_size = os.path.getsize(part_path)
if total_size > 0 and final_size != total_size:
logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}")
# Don't treat this as fatal error, rename anyway
# Atomically rename .part to final file with retries
max_rename_attempts = 5
rename_attempt = 0
rename_success = False
while rename_attempt < max_rename_attempts and not rename_success:
try:
os.rename(part_path, save_path)
rename_success = True
except PermissionError as e:
rename_attempt += 1
if rename_attempt < max_rename_attempts:
logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})")
await asyncio.sleep(2) # Wait before retrying
else:
logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}")
return False, f"Failed to finalize download: {str(e)}"
# Ensure 100% progress is reported
if progress_callback:
await progress_callback(100)
return True, save_path
except (aiohttp.ClientError, aiohttp.ClientPayloadError,
aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
retry_count += 1
logger.warning(f"Network error during download (attempt {retry_count}/{max_retries + 1}): {e}")
if retry_count <= max_retries:
# Calculate delay with exponential backoff
delay = base_delay * (2 ** (retry_count - 1))
logger.info(f"Retrying in {delay} seconds...")
await asyncio.sleep(delay)
# Update resume offset for next attempt
if os.path.exists(part_path):
resume_offset = os.path.getsize(part_path)
logger.info(f"Will resume from byte {resume_offset}")
# Refresh session to get new connection
await self.close()
session = await self._ensure_fresh_session()
continue
else:
logger.error(f"Max retries exceeded for download: {e}")
return False, f"Network error after {max_retries + 1} attempts: {str(e)}"
except Exception as e:
logger.error(f"Unexpected download error: {e}")
return False, str(e)
return False, f"Download failed after {max_retries + 1} attempts"
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
try: try:
session = await self._ensure_fresh_session() downloader = await get_downloader()
async with session.get(f"{self.base_url}/model-versions/by-hash/{model_hash}") as response: success, result = await downloader.make_request(
if response.status == 200: 'GET',
return await response.json() f"{self.base_url}/model-versions/by-hash/{model_hash}",
return None use_auth=True
)
if success:
return result
return None
except Exception as e: except Exception as e:
logger.error(f"API Error: {str(e)}") logger.error(f"API Error: {str(e)}")
return None return None
async def download_preview_image(self, image_url: str, save_path: str): async def download_preview_image(self, image_url: str, save_path: str):
try: try:
session = await self._ensure_fresh_session() downloader = await get_downloader()
async with session.get(image_url) as response: success, content = await downloader.download_to_memory(
if response.status == 200: image_url,
content = await response.read() use_auth=False # Preview images don't need auth
with open(save_path, 'wb') as f: )
f.write(content) if success:
return True # Ensure directory exists
return False os.makedirs(os.path.dirname(save_path), exist_ok=True)
with open(save_path, 'wb') as f:
f.write(content)
return True
return False
except Exception as e: except Exception as e:
print(f"Download Error: {str(e)}") logger.error(f"Download Error: {str(e)}")
return False return False
async def get_model_versions(self, model_id: str) -> List[Dict]: async def get_model_versions(self, model_id: str) -> List[Dict]:
"""Get all versions of a model with local availability info""" """Get all versions of a model with local availability info"""
try: try:
session = await self._ensure_fresh_session() # Use fresh session downloader = await get_downloader()
async with session.get(f"{self.base_url}/models/{model_id}") as response: success, result = await downloader.make_request(
if response.status != 200: 'GET',
return None f"{self.base_url}/models/{model_id}",
data = await response.json() use_auth=True
)
if success:
# Also return model type along with versions # Also return model type along with versions
return { return {
'modelVersions': data.get('modelVersions', []), 'modelVersions': result.get('modelVersions', []),
'type': data.get('type', '') 'type': result.get('type', '')
} }
return None
except Exception as e: except Exception as e:
logger.error(f"Error fetching model versions: {e}") logger.error(f"Error fetching model versions: {e}")
return None return None
@@ -322,68 +123,74 @@ class CivitaiClient:
Optional[Dict]: The model version data with additional fields or None if not found Optional[Dict]: The model version data with additional fields or None if not found
""" """
try: try:
session = await self._ensure_fresh_session() downloader = await get_downloader()
headers = self._get_request_headers()
# Case 1: Only version_id is provided # Case 1: Only version_id is provided
if model_id is None and version_id is not None: if model_id is None and version_id is not None:
# First get the version info to extract model_id # First get the version info to extract model_id
async with session.get(f"{self.base_url}/model-versions/{version_id}", headers=headers) as response: success, version = await downloader.make_request(
if response.status != 200: 'GET',
return None f"{self.base_url}/model-versions/{version_id}",
use_auth=True
)
if not success:
return None
version = await response.json() model_id = version.get('modelId')
model_id = version.get('modelId') if not model_id:
logger.error(f"No modelId found in version {version_id}")
if not model_id: return None
logger.error(f"No modelId found in version {version_id}")
return None
# Now get the model data for additional metadata # Now get the model data for additional metadata
async with session.get(f"{self.base_url}/models/{model_id}") as response: success, model_data = await downloader.make_request(
if response.status != 200: 'GET',
return version # Return version without additional metadata f"{self.base_url}/models/{model_id}",
use_auth=True
model_data = await response.json() )
if success:
# Enrich version with model data # Enrich version with model data
version['model']['description'] = model_data.get("description") version['model']['description'] = model_data.get("description")
version['model']['tags'] = model_data.get("tags", []) version['model']['tags'] = model_data.get("tags", [])
version['creator'] = model_data.get("creator") version['creator'] = model_data.get("creator")
return version return version
# Case 2: model_id is provided (with or without version_id) # Case 2: model_id is provided (with or without version_id)
elif model_id is not None: elif model_id is not None:
# Step 1: Get model data to find version_id if not provided and get additional metadata # Step 1: Get model data to find version_id if not provided and get additional metadata
async with session.get(f"{self.base_url}/models/{model_id}") as response: success, data = await downloader.make_request(
if response.status != 200: 'GET',
return None f"{self.base_url}/models/{model_id}",
use_auth=True
)
if not success:
return None
data = await response.json() model_versions = data.get('modelVersions', [])
model_versions = data.get('modelVersions', [])
# Step 2: Determine the version_id to use # Step 2: Determine the version_id to use
target_version_id = version_id target_version_id = version_id
if target_version_id is None: if target_version_id is None:
target_version_id = model_versions[0].get('id') target_version_id = model_versions[0].get('id')
# Step 3: Get detailed version info using the version_id # Step 3: Get detailed version info using the version_id
async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response: success, version = await downloader.make_request(
if response.status != 200: 'GET',
return None f"{self.base_url}/model-versions/{target_version_id}",
use_auth=True
)
if not success:
return None
version = await response.json() # Step 4: Enrich version_info with model data
# Add description and tags from model data
version['model']['description'] = data.get("description")
version['model']['tags'] = data.get("tags", [])
# Step 4: Enrich version_info with model data # Add creator from model data
# Add description and tags from model data version['creator'] = data.get("creator")
version['model']['description'] = data.get("description")
version['model']['tags'] = data.get("tags", [])
# Add creator from model data return version
version['creator'] = data.get("creator")
return version
# Case 3: Neither model_id nor version_id provided # Case 3: Neither model_id nor version_id provided
else: else:
@@ -406,30 +213,29 @@ class CivitaiClient:
- An error message if there was an error, or None on success - An error message if there was an error, or None on success
""" """
try: try:
session = await self._ensure_fresh_session() downloader = await get_downloader()
url = f"{self.base_url}/model-versions/{version_id}" url = f"{self.base_url}/model-versions/{version_id}"
headers = self._get_request_headers()
logger.debug(f"Resolving DNS for model version info: {url}") logger.debug(f"Resolving DNS for model version info: {url}")
async with session.get(url, headers=headers) as response: success, result = await downloader.make_request(
if response.status == 200: 'GET',
logger.debug(f"Successfully fetched model version info for: {version_id}") url,
return await response.json(), None use_auth=True
)
# Handle specific error cases if success:
if response.status == 404: logger.debug(f"Successfully fetched model version info for: {version_id}")
# Try to parse the error message return result, None
try:
error_data = await response.json()
error_msg = error_data.get('error', f"Model not found (status 404)")
logger.warning(f"Model version not found: {version_id} - {error_msg}")
return None, error_msg
except:
return None, "Model not found (status 404)"
# Other error cases # Handle specific error cases
logger.error(f"Failed to fetch model info for {version_id} (status {response.status})") if "404" in str(result):
return None, f"Failed to fetch model info (status {response.status})" error_msg = f"Model not found (status 404)"
logger.warning(f"Model version not found: {version_id} - {error_msg}")
return None, error_msg
# Other error cases
logger.error(f"Failed to fetch model info for {version_id}: {result}")
return None, str(result)
except Exception as e: except Exception as e:
error_msg = f"Error fetching model version info: {e}" error_msg = f"Error fetching model version info: {e}"
logger.error(error_msg) logger.error(error_msg)
@@ -444,48 +250,50 @@ class CivitaiClient:
Returns: Returns:
Tuple[Optional[Dict], int]: A tuple containing: Tuple[Optional[Dict], int]: A tuple containing:
- A dictionary with model metadata or None if not found - A dictionary with model metadata or None if not found
- The HTTP status code from the request - The HTTP status code from the request (0 for exceptions)
""" """
try: try:
session = await self._ensure_fresh_session() downloader = await get_downloader()
headers = self._get_request_headers()
url = f"{self.base_url}/models/{model_id}" url = f"{self.base_url}/models/{model_id}"
async with session.get(url, headers=headers) as response: success, result = await downloader.make_request(
status_code = response.status 'GET',
url,
use_auth=True
)
if status_code != 200: if not success:
logger.warning(f"Failed to fetch model metadata: Status {status_code}") # Try to extract status code from error message
return None, status_code status_code = 0
if "404" in str(result):
status_code = 404
elif "401" in str(result):
status_code = 401
elif "403" in str(result):
status_code = 403
logger.warning(f"Failed to fetch model metadata: {result}")
return None, status_code
data = await response.json() # Extract relevant metadata
metadata = {
# Extract relevant metadata "description": result.get("description") or "No model description available",
metadata = { "tags": result.get("tags", []),
"description": data.get("description") or "No model description available", "creator": {
"tags": data.get("tags", []), "username": result.get("creator", {}).get("username"),
"creator": { "image": result.get("creator", {}).get("image")
"username": data.get("creator", {}).get("username"),
"image": data.get("creator", {}).get("image")
}
} }
}
if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]: if metadata["description"] or metadata["tags"] or metadata["creator"]["username"]:
return metadata, status_code return metadata, 200
else: else:
logger.warning(f"No metadata found for model {model_id}") logger.warning(f"No metadata found for model {model_id}")
return None, status_code return None, 200
except Exception as e: except Exception as e:
logger.error(f"Error fetching model metadata: {e}", exc_info=True) logger.error(f"Error fetching model metadata: {e}", exc_info=True)
return None, 0 return None, 0
async def close(self):
"""Close the session if it exists"""
if self._session is not None:
await self._session.close()
self._session = None
async def get_image_info(self, image_id: str) -> Optional[Dict]: async def get_image_info(self, image_id: str) -> Optional[Dict]:
"""Fetch image information from Civitai API """Fetch image information from Civitai API
@@ -496,22 +304,25 @@ class CivitaiClient:
Optional[Dict]: The image data or None if not found Optional[Dict]: The image data or None if not found
""" """
try: try:
session = await self._ensure_fresh_session() downloader = await get_downloader()
headers = self._get_request_headers()
url = f"{self.base_url}/images?imageId={image_id}&nsfw=X" url = f"{self.base_url}/images?imageId={image_id}&nsfw=X"
logger.debug(f"Fetching image info for ID: {image_id}") logger.debug(f"Fetching image info for ID: {image_id}")
async with session.get(url, headers=headers) as response: success, result = await downloader.make_request(
if response.status == 200: 'GET',
data = await response.json() url,
if data and "items" in data and len(data["items"]) > 0: use_auth=True
logger.debug(f"Successfully fetched image info for ID: {image_id}") )
return data["items"][0]
logger.warning(f"No image found with ID: {image_id}")
return None
logger.error(f"Failed to fetch image info for ID: {image_id} (status {response.status})") if success:
if result and "items" in result and len(result["items"]) > 0:
logger.debug(f"Successfully fetched image info for ID: {image_id}")
return result["items"][0]
logger.warning(f"No image found with ID: {image_id}")
return None return None
logger.error(f"Failed to fetch image info for ID: {image_id}: {result}")
return None
except Exception as e: except Exception as e:
error_msg = f"Error fetching image info: {e}" error_msg = f"Error fetching image info: {e}"
logger.error(error_msg) logger.error(error_msg)

465
py/services/downloader.py Normal file
View File

@@ -0,0 +1,465 @@
"""
Unified download manager for all HTTP/HTTPS downloads in the application.
This module provides a centralized download service with:
- Singleton pattern for global session management
- Support for authenticated downloads (e.g., CivitAI API key)
- Resumable downloads with automatic retry
- Progress tracking and callbacks
- Optimized connection pooling and timeouts
- Unified error handling and logging
"""
import os
import logging
import asyncio
import aiohttp
from datetime import datetime
from typing import Optional, Dict, Tuple, Callable, Union
from ..services.settings_manager import settings
logger = logging.getLogger(__name__)
class Downloader:
"""Unified downloader for all HTTP/HTTPS downloads in the application."""
_instance = None
_lock = asyncio.Lock()
@classmethod
async def get_instance(cls):
    """Return the process-wide Downloader singleton, creating it on first use.

    Construction is serialized through an asyncio lock so concurrent
    first callers cannot race to build two instances.
    """
    async with cls._lock:
        if cls._instance is None:
            cls._instance = cls()
    return cls._instance
def __init__(self):
    """Initialize the downloader with optimal settings.

    Sets up lazy session state, transfer tuning parameters, and the
    default request headers shared by all downloads.
    """
    # Check if already initialized for singleton pattern
    if hasattr(self, '_initialized'):
        return
    self._initialized = True

    # Session management: the aiohttp session is created lazily on first use
    self._session = None
    self._session_created_at = None  # creation timestamp, used to decide when to recycle

    # Configuration
    self.chunk_size = 4 * 1024 * 1024  # 4MB chunks for better throughput
    self.max_retries = 5
    self.base_delay = 2.0  # Base delay for exponential backoff (seconds, doubled per retry)
    self.session_timeout = 300  # 5 minutes; sessions older than this are refreshed

    # Default headers applied to every request unless overridden
    self.default_headers = {
        'User-Agent': 'ComfyUI-LoRA-Manager/1.0'
    }
@property
async def session(self) -> aiohttp.ClientSession:
    """Return the shared aiohttp session, (re)creating it when absent or stale."""
    needs_rebuild = self._session is None or self._should_refresh_session()
    if needs_rebuild:
        await self._create_session()
    return self._session
def _should_refresh_session(self) -> bool:
    """Return True when the session is missing, untracked, or older than its lifetime."""
    created = getattr(self, '_session_created_at', None)
    # No session, or no recorded creation time, always forces a rebuild.
    if self._session is None or created is None:
        return True
    # Otherwise refresh only once the session exceeds the configured lifetime.
    age_seconds = (datetime.now() - created).total_seconds()
    return age_seconds > self.session_timeout
async def _create_session(self):
    """Replace the current aiohttp session with a freshly configured one.

    Closes any previous session first, then builds a new one with pooled
    connections, DNS caching, and timeouts suited to large downloads.
    """
    # Dispose of the old session before building its replacement.
    previous = self._session
    if previous is not None:
        await previous.close()

    # Connection pool tuned for parallel downloads with DNS caching.
    tcp = aiohttp.TCPConnector(
        ssl=True,
        limit=8,                   # concurrent connections
        ttl_dns_cache=300,         # DNS cache lifetime (seconds)
        force_close=False,         # reuse connections across requests
        enable_cleanup_closed=True,
    )

    # No overall deadline so large files can finish; only connect is bounded.
    limits = aiohttp.ClientTimeout(
        total=None,
        connect=60,
        sock_read=None,
    )

    self._session = aiohttp.ClientSession(
        connector=tcp,
        trust_env=True,  # honor system proxy environment variables
        timeout=limits,
    )
    self._session_created_at = datetime.now()
    logger.debug("Created new HTTP session")
def _get_auth_headers(self, use_auth: bool = False) -> Dict[str, str]:
    """Build request headers, optionally adding CivitAI bearer authentication.

    Always returns a fresh dict so callers can mutate it safely.
    """
    merged = dict(self.default_headers)
    if use_auth:
        # Only attach credentials when a CivitAI API key is configured.
        api_key = settings.get('civitai_api_key')
        if api_key:
            merged['Authorization'] = f'Bearer {api_key}'
            merged['Content-Type'] = 'application/json'
    return merged
async def download_file(
    self,
    url: str,
    save_path: str,
    # NOTE(review): callbacks are awaited below, so despite the annotation the
    # callable must be async (return an awaitable) — TODO confirm with callers.
    progress_callback: Optional[Callable[[float], None]] = None,
    use_auth: bool = False,
    custom_headers: Optional[Dict[str, str]] = None,
    allow_resume: bool = True
) -> Tuple[bool, str]:
    """
    Download a file with resumable downloads and retry mechanism

    Writes to ``save_path + '.part'`` while downloading (when resume is
    enabled) and atomically renames to ``save_path`` on completion. On
    transient network errors it retries up to ``self.max_retries`` times with
    exponential backoff, resuming from the bytes already on disk via an HTTP
    Range request.

    Args:
        url: Download URL
        save_path: Full path where the file should be saved
        progress_callback: Optional callback for progress updates (0-100)
        use_auth: Whether to include authentication headers (e.g., CivitAI API key)
        custom_headers: Additional headers to include in request
        allow_resume: Whether to support resumable downloads
    Returns:
        Tuple[bool, str]: (success, save_path or error message)
    """
    retry_count = 0
    # Without resume support we write directly to the final path.
    part_path = save_path + '.part' if allow_resume else save_path

    # Prepare headers
    headers = self._get_auth_headers(use_auth)
    if custom_headers:
        headers.update(custom_headers)

    # Get existing file size for resume
    resume_offset = 0
    if allow_resume and os.path.exists(part_path):
        resume_offset = os.path.getsize(part_path)
        logger.info(f"Resuming download from offset {resume_offset} bytes")

    total_size = 0

    while retry_count <= self.max_retries:
        try:
            session = await self.session

            # Add Range header for resume if we have partial data
            request_headers = headers.copy()
            if allow_resume and resume_offset > 0:
                request_headers['Range'] = f'bytes={resume_offset}-'

            # Disable compression for better chunked downloads
            request_headers['Accept-Encoding'] = 'identity'

            logger.debug(f"Download attempt {retry_count + 1}/{self.max_retries + 1} from: {url}")
            if resume_offset > 0:
                logger.debug(f"Requesting range from byte {resume_offset}")

            async with session.get(url, headers=request_headers, allow_redirects=True) as response:
                # Handle different response codes
                if response.status == 200:
                    # Full content response
                    if resume_offset > 0:
                        # Server doesn't support ranges, restart from beginning
                        logger.warning("Server doesn't support range requests, restarting download")
                        resume_offset = 0
                        if os.path.exists(part_path):
                            os.remove(part_path)
                elif response.status == 206:
                    # Partial content response (resume successful)
                    content_range = response.headers.get('Content-Range')
                    if content_range:
                        # Parse total size from Content-Range header (e.g., "bytes 1024-2047/2048")
                        range_parts = content_range.split('/')
                        if len(range_parts) == 2:
                            total_size = int(range_parts[1])
                    logger.info(f"Successfully resumed download from byte {resume_offset}")
                elif response.status == 416:
                    # Range not satisfiable - file might be complete or corrupted
                    # NOTE(review): when allow_resume is False this branch does
                    # nothing and execution falls through to the streaming code
                    # below with a 416 response body — TODO confirm intended.
                    if allow_resume and os.path.exists(part_path):
                        part_size = os.path.getsize(part_path)
                        logger.warning(f"Range not satisfiable. Part file size: {part_size}")
                        # Try to get actual file size
                        # NOTE(review): this HEAD response is never released
                        # (no context manager) — relies on pool cleanup.
                        head_response = await session.head(url, headers=headers)
                        if head_response.status == 200:
                            actual_size = int(head_response.headers.get('content-length', 0))
                            if part_size == actual_size:
                                # File is complete, just rename it
                                if allow_resume:
                                    os.rename(part_path, save_path)
                                if progress_callback:
                                    await progress_callback(100)
                                return True, save_path
                        # Remove corrupted part file and restart
                        os.remove(part_path)
                        resume_offset = 0
                        # Restart the while loop; retry_count is intentionally
                        # not incremented (this is not a network failure).
                        continue
                elif response.status == 401:
                    logger.warning(f"Unauthorized access to resource: {url} (Status 401)")
                    return False, "Invalid or missing API key, or early access restriction."
                elif response.status == 403:
                    logger.warning(f"Forbidden access to resource: {url} (Status 403)")
                    return False, "Access forbidden: You don't have permission to download this file."
                elif response.status == 404:
                    logger.warning(f"Resource not found: {url} (Status 404)")
                    return False, "File not found - the download link may be invalid or expired."
                else:
                    logger.error(f"Download failed for {url} with status {response.status}")
                    return False, f"Download failed with status {response.status}"

                # Get total file size for progress calculation (if not set from Content-Range)
                if total_size == 0:
                    total_size = int(response.headers.get('content-length', 0))
                    if response.status == 206:
                        # For partial content, add the offset to get total file size
                        total_size += resume_offset

                current_size = resume_offset
                last_progress_report_time = datetime.now()

                # Ensure directory exists
                os.makedirs(os.path.dirname(save_path), exist_ok=True)

                # Stream download to file with progress updates
                loop = asyncio.get_running_loop()
                # Append when resuming, otherwise truncate and start fresh.
                mode = 'ab' if (allow_resume and resume_offset > 0) else 'wb'
                with open(part_path, mode) as f:
                    async for chunk in response.content.iter_chunked(self.chunk_size):
                        if chunk:
                            # Run blocking file write in executor
                            await loop.run_in_executor(None, f.write, chunk)
                            current_size += len(chunk)

                            # Limit progress update frequency to reduce overhead
                            now = datetime.now()
                            time_diff = (now - last_progress_report_time).total_seconds()
                            if progress_callback and total_size and time_diff >= 1.0:
                                progress = (current_size / total_size) * 100
                                await progress_callback(progress)
                                last_progress_report_time = now

                # Download completed successfully
                # Verify file size if total_size was provided
                final_size = os.path.getsize(part_path)
                if total_size > 0 and final_size != total_size:
                    logger.warning(f"File size mismatch. Expected: {total_size}, Got: {final_size}")
                    # Don't treat this as fatal error, continue anyway

                # Atomically rename .part to final file (only if using resume)
                if allow_resume and part_path != save_path:
                    # Windows can hold the file open briefly (AV scanners,
                    # indexers), so retry the rename a few times.
                    max_rename_attempts = 5
                    rename_attempt = 0
                    rename_success = False

                    while rename_attempt < max_rename_attempts and not rename_success:
                        try:
                            os.rename(part_path, save_path)
                            rename_success = True
                        except PermissionError as e:
                            rename_attempt += 1
                            if rename_attempt < max_rename_attempts:
                                logger.info(f"File still in use, retrying rename in 2 seconds (attempt {rename_attempt}/{max_rename_attempts})")
                                await asyncio.sleep(2)
                            else:
                                logger.error(f"Failed to rename file after {max_rename_attempts} attempts: {e}")
                                return False, f"Failed to finalize download: {str(e)}"

                # Ensure 100% progress is reported
                if progress_callback:
                    await progress_callback(100)

                return True, save_path

        except (aiohttp.ClientError, aiohttp.ClientPayloadError,
                aiohttp.ServerDisconnectedError, asyncio.TimeoutError) as e:
            # Transient network failure: back off, pick up any bytes already
            # written, and retry with a fresh session/connection.
            retry_count += 1
            logger.warning(f"Network error during download (attempt {retry_count}/{self.max_retries + 1}): {e}")

            if retry_count <= self.max_retries:
                # Calculate delay with exponential backoff
                delay = self.base_delay * (2 ** (retry_count - 1))
                logger.info(f"Retrying in {delay} seconds...")
                await asyncio.sleep(delay)

                # Update resume offset for next attempt
                if allow_resume and os.path.exists(part_path):
                    resume_offset = os.path.getsize(part_path)
                    logger.info(f"Will resume from byte {resume_offset}")

                # Refresh session to get new connection
                await self._create_session()
                continue
            else:
                logger.error(f"Max retries exceeded for download: {e}")
                return False, f"Network error after {self.max_retries + 1} attempts: {str(e)}"

        except Exception as e:
            # Non-network errors (disk, permissions, bugs) are not retried.
            logger.error(f"Unexpected download error: {e}")
            return False, str(e)

    return False, f"Download failed after {self.max_retries + 1} attempts"
async def download_to_memory(
    self,
    url: str,
    use_auth: bool = False,
    custom_headers: Optional[Dict[str, str]] = None
) -> Tuple[bool, Union[bytes, str]]:
    """
    Fetch a URL entirely into memory (intended for small payloads such as
    preview images).

    Args:
        url: Download URL
        use_auth: Whether to include authentication headers
        custom_headers: Additional headers to include in request
    Returns:
        Tuple[bool, Union[bytes, str]]: (success, content or error message)
    """
    try:
        session = await self.session

        # Merge default/auth headers with any caller-supplied ones.
        request_headers = self._get_auth_headers(use_auth)
        if custom_headers:
            request_headers.update(custom_headers)

        async with session.get(url, headers=request_headers) as response:
            status = response.status
            if status == 200:
                return True, await response.read()
            if status == 401:
                return False, "Unauthorized access - invalid or missing API key"
            if status == 403:
                return False, "Access forbidden"
            if status == 404:
                return False, "File not found"
            return False, f"Download failed with status {status}"
    except Exception as e:
        logger.error(f"Error downloading to memory from {url}: {e}")
        return False, str(e)
async def get_response_headers(
    self,
    url: str,
    use_auth: bool = False,
    custom_headers: Optional[Dict[str, str]] = None
) -> Tuple[bool, Union[Dict, str]]:
    """
    Issue a HEAD request and return the response headers without downloading
    the body.

    Args:
        url: URL to check
        use_auth: Whether to include authentication headers
        custom_headers: Additional headers to include in request
    Returns:
        Tuple[bool, Union[Dict, str]]: (success, headers dict or error message)
    """
    try:
        session = await self.session

        # Merge default/auth headers with any caller-supplied ones.
        request_headers = self._get_auth_headers(use_auth)
        if custom_headers:
            request_headers.update(custom_headers)

        async with session.head(url, headers=request_headers) as response:
            if response.status != 200:
                return False, f"Head request failed with status {response.status}"
            return True, dict(response.headers)
    except Exception as e:
        logger.error(f"Error getting headers from {url}: {e}")
        return False, str(e)
async def make_request(
    self,
    method: str,
    url: str,
    use_auth: bool = False,
    custom_headers: Optional[Dict[str, str]] = None,
    **kwargs
) -> Tuple[bool, Union[Dict, str]]:
    """
    Make a generic HTTP request and return the parsed response.

    Args:
        method: HTTP method (GET, POST, etc.)
        url: Request URL
        use_auth: Whether to include authentication headers
        custom_headers: Additional headers to include in request
        **kwargs: Additional arguments for aiohttp request
    Returns:
        Tuple[bool, Union[Dict, str]]: (success, response data or error message).
        On success the data is the decoded JSON body when the response parses
        as JSON, otherwise the raw response text.
    """
    try:
        session = await self.session

        # Prepare headers
        headers = self._get_auth_headers(use_auth)
        if custom_headers:
            headers.update(custom_headers)

        async with session.request(method, url, headers=headers, **kwargs) as response:
            if response.status == 200:
                # Try to parse as JSON, fall back to text.
                # Catch only JSON-decoding failures: response.json() raises
                # aiohttp.ContentTypeError for a non-JSON content type and
                # json.JSONDecodeError (a ValueError) for malformed bodies.
                # The previous bare `except:` also swallowed cancellation
                # and connection errors, masking real failures.
                try:
                    data = await response.json()
                    return True, data
                except (aiohttp.ContentTypeError, ValueError):
                    text = await response.text()
                    return True, text
            elif response.status == 401:
                return False, "Unauthorized access - invalid or missing API key"
            elif response.status == 403:
                return False, "Access forbidden"
            elif response.status == 404:
                return False, "Resource not found"
            else:
                return False, f"Request failed with status {response.status}"
    except Exception as e:
        logger.error(f"Error making {method} request to {url}: {e}")
        return False, str(e)
async def close(self):
    """Shut down the pooled HTTP session, if one exists, and clear cached state."""
    if self._session is None:
        return
    await self._session.close()
    self._session = None
    self._session_created_at = None
    logger.debug("Closed HTTP session")
# Global instance accessor
async def get_downloader() -> Downloader:
    """Return the process-wide :class:`Downloader` singleton."""
    instance = await Downloader.get_instance()
    return instance

View File

@@ -1,9 +1,9 @@
import zipfile import zipfile
import aiohttp
import logging import logging
import asyncio import asyncio
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from .downloader import get_downloader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -67,6 +67,8 @@ class MetadataArchiveManager:
async def _download_archive(self, progress_callback=None) -> bool: async def _download_archive(self, progress_callback=None) -> bool:
"""Download the zip archive from one of the available URLs""" """Download the zip archive from one of the available URLs"""
downloader = await get_downloader()
for url in self.DOWNLOAD_URLS: for url in self.DOWNLOAD_URLS:
try: try:
logger.info(f"Attempting to download from {url}") logger.info(f"Attempting to download from {url}")
@@ -74,26 +76,25 @@ class MetadataArchiveManager:
if progress_callback: if progress_callback:
progress_callback("download", f"Downloading from {url}") progress_callback("download", f"Downloading from {url}")
async with aiohttp.ClientSession() as session: # Custom progress callback to report download progress
async with session.get(url) as response: async def download_progress(progress):
if response.status == 200: if progress_callback:
total_size = int(response.headers.get('content-length', 0)) progress_callback("download", f"Downloaded {progress:.1f}%")
downloaded = 0
with open(self.archive_path, 'wb') as f: success, result = await downloader.download_file(
async for chunk in response.content.iter_chunked(8192): url=url,
f.write(chunk) save_path=str(self.archive_path),
downloaded += len(chunk) progress_callback=download_progress,
use_auth=False, # Public download, no auth needed
allow_resume=True
)
if progress_callback and total_size > 0: if success:
percentage = (downloaded / total_size) * 100 logger.info(f"Successfully downloaded archive from {url}")
progress_callback("download", f"Downloaded {percentage:.1f}%") return True
else:
logger.info(f"Successfully downloaded archive from {url}") logger.warning(f"Failed to download from {url}: {result}")
return True continue
else:
logger.warning(f"Failed to download from {url}: HTTP {response.status}")
continue
except Exception as e: except Exception as e:
logger.warning(f"Error downloading from {url}: {e}") logger.warning(f"Error downloading from {url}: {e}")

View File

@@ -3,13 +3,13 @@ import os
import asyncio import asyncio
import json import json
import time import time
import aiohttp
from aiohttp import web from aiohttp import web
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .example_images_processor import ExampleImagesProcessor from .example_images_processor import ExampleImagesProcessor
from .example_images_metadata import MetadataUpdater from .example_images_metadata import MetadataUpdater
from ..services.websocket_manager import ws_manager # Add this import at the top from ..services.websocket_manager import ws_manager # Add this import at the top
from ..services.downloader import get_downloader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -199,19 +199,8 @@ class DownloadManager:
"""Download example images for all models""" """Download example images for all models"""
global is_downloading, download_progress global is_downloading, download_progress
# Create independent download session # Get unified downloader
connector = aiohttp.TCPConnector( downloader = await get_downloader()
ssl=True,
limit=3,
force_close=False,
enable_cleanup_closed=True
)
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
independent_session = aiohttp.ClientSession(
connector=connector,
trust_env=True,
timeout=timeout
)
try: try:
# Get scanners # Get scanners
@@ -246,7 +235,7 @@ class DownloadManager:
# Main logic for processing model is here, but actual operations are delegated to other classes # Main logic for processing model is here, but actual operations are delegated to other classes
was_remote_download = await DownloadManager._process_model( was_remote_download = await DownloadManager._process_model(
scanner_type, model, scanner, scanner_type, model, scanner,
output_dir, optimize, independent_session output_dir, optimize, downloader
) )
# Update progress # Update progress
@@ -270,12 +259,6 @@ class DownloadManager:
download_progress['end_time'] = time.time() download_progress['end_time'] = time.time()
finally: finally:
# Close the independent session
try:
await independent_session.close()
except Exception as e:
logger.error(f"Error closing download session: {e}")
# Save final progress to file # Save final progress to file
try: try:
DownloadManager._save_progress(output_dir) DownloadManager._save_progress(output_dir)
@@ -286,7 +269,7 @@ class DownloadManager:
is_downloading = False is_downloading = False
@staticmethod @staticmethod
async def _process_model(scanner_type, model, scanner, output_dir, optimize, independent_session): async def _process_model(scanner_type, model, scanner, output_dir, optimize, downloader):
"""Process a single model download""" """Process a single model download"""
global download_progress global download_progress
@@ -347,7 +330,7 @@ class DownloadManager:
images = model.get('civitai', {}).get('images', []) images = model.get('civitai', {}).get('images', [])
success, is_stale = await ExampleImagesProcessor.download_model_images( success, is_stale = await ExampleImagesProcessor.download_model_images(
model_hash, model_name, images, model_dir, optimize, independent_session model_hash, model_name, images, model_dir, optimize, downloader
) )
# If metadata is stale, try to refresh it # If metadata is stale, try to refresh it
@@ -365,7 +348,7 @@ class DownloadManager:
# Retry download with updated metadata # Retry download with updated metadata
updated_images = updated_model.get('civitai', {}).get('images', []) updated_images = updated_model.get('civitai', {}).get('images', [])
success, _ = await ExampleImagesProcessor.download_model_images( success, _ = await ExampleImagesProcessor.download_model_images(
model_hash, model_name, updated_images, model_dir, optimize, independent_session model_hash, model_name, updated_images, model_dir, optimize, downloader
) )
download_progress['refreshed_models'].add(model_hash) download_progress['refreshed_models'].add(model_hash)
@@ -529,19 +512,8 @@ class DownloadManager:
"""Download example images for specific models only - synchronous version""" """Download example images for specific models only - synchronous version"""
global download_progress global download_progress
# Create independent download session # Get unified downloader
connector = aiohttp.TCPConnector( downloader = await get_downloader()
ssl=True,
limit=3,
force_close=False,
enable_cleanup_closed=True
)
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
independent_session = aiohttp.ClientSession(
connector=connector,
trust_env=True,
timeout=timeout
)
try: try:
# Get scanners # Get scanners
@@ -586,7 +558,7 @@ class DownloadManager:
# Force process this model regardless of previous status # Force process this model regardless of previous status
was_successful = await DownloadManager._process_specific_model( was_successful = await DownloadManager._process_specific_model(
scanner_type, model, scanner, scanner_type, model, scanner,
output_dir, optimize, independent_session output_dir, optimize, downloader
) )
if was_successful: if was_successful:
@@ -650,14 +622,11 @@ class DownloadManager:
raise raise
finally: finally:
# Close the independent session # No need to close any sessions since we use the global downloader
try: pass
await independent_session.close()
except Exception as e:
logger.error(f"Error closing download session: {e}")
@staticmethod @staticmethod
async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, independent_session): async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, downloader):
"""Process a specific model for forced download, ignoring previous download status""" """Process a specific model for forced download, ignoring previous download status"""
global download_progress global download_progress
@@ -701,7 +670,7 @@ class DownloadManager:
images = model.get('civitai', {}).get('images', []) images = model.get('civitai', {}).get('images', [])
success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( success, is_stale, failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
model_hash, model_name, images, model_dir, optimize, independent_session model_hash, model_name, images, model_dir, optimize, downloader
) )
# If metadata is stale, try to refresh it # If metadata is stale, try to refresh it
@@ -719,7 +688,7 @@ class DownloadManager:
# Retry download with updated metadata # Retry download with updated metadata
updated_images = updated_model.get('civitai', {}).get('images', []) updated_images = updated_model.get('civitai', {}).get('images', [])
success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking( success, _, additional_failed_images = await ExampleImagesProcessor.download_model_images_with_tracking(
model_hash, model_name, updated_images, model_dir, optimize, independent_session model_hash, model_name, updated_images, model_dir, optimize, downloader
) )
# Combine failed images from both attempts # Combine failed images from both attempts

View File

@@ -35,7 +35,7 @@ class ExampleImagesProcessor:
return image_url return image_url
@staticmethod @staticmethod
async def download_model_images(model_hash, model_name, model_images, model_dir, optimize, independent_session): async def download_model_images(model_hash, model_name, model_images, model_dir, optimize, downloader):
"""Download images for a single model """Download images for a single model
Returns: Returns:
@@ -78,23 +78,25 @@ class ExampleImagesProcessor:
try: try:
logger.debug(f"Downloading {save_filename} for {model_name}") logger.debug(f"Downloading {save_filename} for {model_name}")
# Download directly using the independent session # Download using the unified downloader
async with independent_session.get(image_url, timeout=60) as response: success, content = await downloader.download_to_memory(
if response.status == 200: image_url,
with open(save_path, 'wb') as f: use_auth=False # Example images don't need auth
async for chunk in response.content.iter_chunked(8192): )
if chunk:
f.write(chunk) if success:
elif response.status == 404: with open(save_path, 'wb') as f:
error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale" f.write(content)
logger.warning(error_msg) elif "404" in str(content):
model_success = False # Mark the model as failed due to 404 error error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale"
# Return early to trigger metadata refresh attempt logger.warning(error_msg)
return False, True # (success, is_metadata_stale) model_success = False # Mark the model as failed due to 404 error
else: # Return early to trigger metadata refresh attempt
error_msg = f"Failed to download file: {image_url}, status code: {response.status}" return False, True # (success, is_metadata_stale)
logger.warning(error_msg) else:
model_success = False # Mark the model as failed error_msg = f"Failed to download file: {image_url}, error: {content}"
logger.warning(error_msg)
model_success = False # Mark the model as failed
except Exception as e: except Exception as e:
error_msg = f"Error downloading file {image_url}: {str(e)}" error_msg = f"Error downloading file {image_url}: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
@@ -103,7 +105,7 @@ class ExampleImagesProcessor:
return model_success, False # (success, is_metadata_stale) return model_success, False # (success, is_metadata_stale)
@staticmethod @staticmethod
async def download_model_images_with_tracking(model_hash, model_name, model_images, model_dir, optimize, independent_session): async def download_model_images_with_tracking(model_hash, model_name, model_images, model_dir, optimize, downloader):
"""Download images for a single model with tracking of failed image URLs """Download images for a single model with tracking of failed image URLs
Returns: Returns:
@@ -147,25 +149,27 @@ class ExampleImagesProcessor:
try: try:
logger.debug(f"Downloading {save_filename} for {model_name}") logger.debug(f"Downloading {save_filename} for {model_name}")
# Download directly using the independent session # Download using the unified downloader
async with independent_session.get(image_url, timeout=60) as response: success, content = await downloader.download_to_memory(
if response.status == 200: image_url,
with open(save_path, 'wb') as f: use_auth=False # Example images don't need auth
async for chunk in response.content.iter_chunked(8192): )
if chunk:
f.write(chunk) if success:
elif response.status == 404: with open(save_path, 'wb') as f:
error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale" f.write(content)
logger.warning(error_msg) elif "404" in str(content):
model_success = False # Mark the model as failed due to 404 error error_msg = f"Failed to download file: {image_url}, status code: 404 - Model metadata might be stale"
failed_images.append(image_url) # Track failed URL logger.warning(error_msg)
# Return early to trigger metadata refresh attempt model_success = False # Mark the model as failed due to 404 error
return False, True, failed_images # (success, is_metadata_stale, failed_images) failed_images.append(image_url) # Track failed URL
else: # Return early to trigger metadata refresh attempt
error_msg = f"Failed to download file: {image_url}, status code: {response.status}" return False, True, failed_images # (success, is_metadata_stale, failed_images)
logger.warning(error_msg) else:
model_success = False # Mark the model as failed error_msg = f"Failed to download file: {image_url}, error: {content}"
failed_images.append(image_url) # Track failed URL logger.warning(error_msg)
model_success = False # Mark the model as failed
failed_images.append(image_url) # Track failed URL
except Exception as e: except Exception as e:
error_msg = f"Error downloading file {image_url}: {str(e)}" error_msg = f"Error downloading file {image_url}: {str(e)}"
logger.error(error_msg) logger.error(error_msg)