fix: create independent session for downloading example images to prevent interference

This commit is contained in:
Will Miao
2025-04-30 13:35:12 +08:00
parent 26d9a9caa6
commit f36febf10a

View File

@@ -296,9 +296,22 @@ class MiscRoutes:
""" """
global is_downloading, download_progress global is_downloading, download_progress
# Get CivitAI client with proxy support # Create an independent session for downloading example images
civitai_client = await ServiceRegistry.get_civitai_client() # This avoids interference with the CivitAI client's session
session = await civitai_client.session connector = aiohttp.TCPConnector(
ssl=True,
limit=3,
force_close=False,
enable_cleanup_closed=True
)
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
# Create a dedicated session just for this download task
independent_session = aiohttp.ClientSession(
connector=connector,
trust_env=True,
timeout=timeout
)
try: try:
# Get the scanners # Get the scanners
@@ -336,6 +349,8 @@ class MiscRoutes:
                    logger.info(f"Download stopped: {download_progress['status']}")
                    break
+               model_success = True  # Track if all images for this model download successfully
                try:
                    # Update current model info
                    model_hash = model.get('sha256', '').lower()
@@ -391,8 +406,8 @@ class MiscRoutes:
                try:
                    logger.debug(f"Downloading {save_filename} for {model_name}")
-                   # Direct download using session from CivitAI client
-                   async with session.get(image_url, timeout=60) as response:
+                   # Direct download using the independent session
+                   async with independent_session.get(image_url, timeout=60) as response:
                        if response.status == 200:
                            if is_image and optimize:
                                # For images, optimize if requested
@@ -420,7 +435,11 @@ class MiscRoutes:
                                    if chunk:
                                        f.write(chunk)
                        else:
-                           logger.warning(f"Failed to download file: {image_url}, status code: {response.status}")
+                           error_msg = f"Failed to download file: {image_url}, status code: {response.status}"
+                           logger.warning(error_msg)
+                           download_progress['errors'].append(error_msg)
+                           download_progress['last_error'] = error_msg
+                           model_success = False  # Mark model as failed
                    # Add a delay between downloads
                    await asyncio.sleep(delay)
@@ -429,10 +448,13 @@ class MiscRoutes:
                    logger.error(error_msg)
                    download_progress['errors'].append(error_msg)
                    download_progress['last_error'] = error_msg
-                   # Continue with next file
+                   model_success = False  # Mark model as failed
-               # Mark model as processed
-               download_progress['processed_models'].add(model_hash)
+               # Only mark model as processed if all images downloaded successfully
+               if model_success:
+                   download_progress['processed_models'].add(model_hash)
+               else:
+                   logger.warning(f"Model {model_name} had download errors, will not mark as completed")
                # Save progress to file periodically
                if download_progress['completed'] % 10 == 0 or download_progress['completed'] == download_progress['total'] - 1:
@@ -468,6 +490,12 @@ class MiscRoutes:
            download_progress['end_time'] = time.time()
        finally:
+           # Close the independent session
+           try:
+               await independent_session.close()
+           except Exception as e:
+               logger.error(f"Error closing download session: {e}")
            # Save final progress to file
            try:
                progress_file = os.path.join(output_dir, '.download_progress.json')