fix: create independent session for downloading example images to prevent interference

This commit is contained in:
Will Miao
2025-04-30 13:35:12 +08:00
parent 26d9a9caa6
commit f36febf10a

View File

@@ -296,9 +296,22 @@ class MiscRoutes:
"""
global is_downloading, download_progress
# Get CivitAI client with proxy support
civitai_client = await ServiceRegistry.get_civitai_client()
session = await civitai_client.session
# Create an independent session for downloading example images
# This avoids interference with the CivitAI client's session
connector = aiohttp.TCPConnector(
ssl=True,
limit=3,
force_close=False,
enable_cleanup_closed=True
)
timeout = aiohttp.ClientTimeout(total=None, connect=60, sock_read=60)
# Create a dedicated session just for this download task
independent_session = aiohttp.ClientSession(
connector=connector,
trust_env=True,
timeout=timeout
)
try:
# Get the scanners
@@ -336,6 +349,8 @@ class MiscRoutes:
logger.info(f"Download stopped: {download_progress['status']}")
break
model_success = True # Track if all images for this model download successfully
try:
# Update current model info
model_hash = model.get('sha256', '').lower()
@@ -391,8 +406,8 @@ class MiscRoutes:
try:
logger.debug(f"Downloading {save_filename} for {model_name}")
# Direct download using session from CivitAI client
async with session.get(image_url, timeout=60) as response:
# Direct download using the independent session
async with independent_session.get(image_url, timeout=60) as response:
if response.status == 200:
if is_image and optimize:
# For images, optimize if requested
@@ -420,7 +435,11 @@ class MiscRoutes:
if chunk:
f.write(chunk)
else:
logger.warning(f"Failed to download file: {image_url}, status code: {response.status}")
error_msg = f"Failed to download file: {image_url}, status code: {response.status}"
logger.warning(error_msg)
download_progress['errors'].append(error_msg)
download_progress['last_error'] = error_msg
model_success = False # Mark model as failed
# Add a delay between downloads
await asyncio.sleep(delay)
@@ -429,10 +448,13 @@ class MiscRoutes:
logger.error(error_msg)
download_progress['errors'].append(error_msg)
download_progress['last_error'] = error_msg
# Continue with next file
model_success = False # Mark model as failed
# Mark model as processed
download_progress['processed_models'].add(model_hash)
# Only mark model as processed if all images downloaded successfully
if model_success:
download_progress['processed_models'].add(model_hash)
else:
logger.warning(f"Model {model_name} had download errors, will not mark as completed")
# Save progress to file periodically
if download_progress['completed'] % 10 == 0 or download_progress['completed'] == download_progress['total'] - 1:
@@ -468,6 +490,12 @@ class MiscRoutes:
download_progress['end_time'] = time.time()
finally:
# Close the independent session
try:
await independent_session.close()
except Exception as e:
logger.error(f"Error closing download session: {e}")
# Save final progress to file
try:
progress_file = os.path.join(output_dir, '.download_progress.json')