Merge pull request #463 from willmiao/codex/refactor-downloadmanager-to-instance-based

refactor: convert example image download manager to service instance
This commit is contained in:
pixelpaws
2025-09-23 13:08:08 +08:00
committed by GitHub
3 changed files with 225 additions and 263 deletions

View File

@@ -16,7 +16,10 @@ from ..services.use_cases.example_images import (
DownloadExampleImagesUseCase, DownloadExampleImagesUseCase,
ImportExampleImagesUseCase, ImportExampleImagesUseCase,
) )
from ..utils.example_images_download_manager import DownloadManager from ..utils.example_images_download_manager import (
DownloadManager,
get_default_download_manager,
)
from ..utils.example_images_file_manager import ExampleImagesFileManager from ..utils.example_images_file_manager import ExampleImagesFileManager
from ..utils.example_images_processor import ExampleImagesProcessor from ..utils.example_images_processor import ExampleImagesProcessor
@@ -29,11 +32,11 @@ class ExampleImagesRoutes:
def __init__( def __init__(
self, self,
*, *,
download_manager=DownloadManager, download_manager: DownloadManager | None = None,
processor=ExampleImagesProcessor, processor=ExampleImagesProcessor,
file_manager=ExampleImagesFileManager, file_manager=ExampleImagesFileManager,
) -> None: ) -> None:
self._download_manager = download_manager self._download_manager = download_manager or get_default_download_manager()
self._processor = processor self._processor = processor
self._file_manager = file_manager self._file_manager = file_manager
self._handler_set: ExampleImagesHandlerSet | None = None self._handler_set: ExampleImagesHandlerSet | None = None

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import logging import logging
import os import os
import asyncio import asyncio
@@ -37,165 +39,150 @@ class DownloadConfigurationError(ExampleImagesDownloadError):
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Download status tracking
download_task = None
is_downloading = False
download_progress = {
'total': 0,
'completed': 0,
'current_model': '',
'status': 'idle', # idle, running, paused, completed, error
'errors': [],
'last_error': None,
'start_time': None,
'end_time': None,
'processed_models': set(), # Track models that have been processed
'refreshed_models': set(), # Track models that had metadata refreshed
'failed_models': set() # Track models that failed to download after metadata refresh
}
class _DownloadProgress(dict):
"""Mutable mapping maintaining download progress with set-aware serialisation."""
def _serialize_progress() -> dict: def __init__(self) -> None:
"""Return a JSON-serialisable snapshot of the current progress.""" super().__init__()
self.reset()
snapshot = download_progress.copy() def reset(self) -> None:
snapshot['processed_models'] = list(download_progress['processed_models']) """Reset the progress dictionary to its initial state."""
snapshot['refreshed_models'] = list(download_progress['refreshed_models'])
snapshot['failed_models'] = list(download_progress['failed_models']) self.update(
return snapshot total=0,
completed=0,
current_model='',
status='idle',
errors=[],
last_error=None,
start_time=None,
end_time=None,
processed_models=set(),
refreshed_models=set(),
failed_models=set(),
)
def snapshot(self) -> dict:
"""Return a JSON-serialisable snapshot of the current progress."""
snapshot = dict(self)
snapshot['processed_models'] = list(self['processed_models'])
snapshot['refreshed_models'] = list(self['refreshed_models'])
snapshot['failed_models'] = list(self['failed_models'])
return snapshot
class DownloadManager: class DownloadManager:
"""Manages downloading example images for models""" """Manages downloading example images for models."""
@staticmethod def __init__(self) -> None:
async def start_download(options: dict): self._download_task: asyncio.Task | None = None
""" self._is_downloading = False
Start downloading example images for models self._progress = _DownloadProgress()
Expects a JSON body with: async def start_download(self, options: dict):
{ """Start downloading example images for models."""
"optimize": true, # Whether to optimize images (default: true)
"model_types": ["lora", "checkpoint"], # Model types to process (default: both) if self._is_downloading:
"delay": 1.0, # Delay between downloads to avoid rate limiting (default: 1.0) raise DownloadInProgressError(self._progress.snapshot())
"auto_mode": false # Flag to indicate automatic download (default: false)
}
"""
global download_task, is_downloading, download_progress
if is_downloading:
raise DownloadInProgressError(_serialize_progress())
try: try:
data = options or {} data = options or {}
auto_mode = data.get('auto_mode', False) auto_mode = data.get('auto_mode', False)
optimize = data.get('optimize', True) optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint']) model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds delay = float(data.get('delay', 0.2))
# Get output directory from settings
output_dir = settings.get('example_images_path') output_dir = settings.get('example_images_path')
if not output_dir: if not output_dir:
error_msg = 'Example images path not configured in settings' error_msg = 'Example images path not configured in settings'
if auto_mode: if auto_mode:
# For auto mode, just log and return success to avoid showing error toasts
logger.debug(error_msg) logger.debug(error_msg)
return { return {
'success': True, 'success': True,
'message': 'Example images path not configured, skipping auto download' 'message': 'Example images path not configured, skipping auto download'
} }
raise DownloadConfigurationError(error_msg) raise DownloadConfigurationError(error_msg)
# Create the output directory
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
# Initialize progress tracking self._progress.reset()
download_progress['total'] = 0 self._progress['status'] = 'running'
download_progress['completed'] = 0 self._progress['start_time'] = time.time()
download_progress['current_model'] = '' self._progress['end_time'] = None
download_progress['status'] = 'running'
download_progress['errors'] = []
download_progress['last_error'] = None
download_progress['start_time'] = time.time()
download_progress['end_time'] = None
# Get the processed models list from a file if it exists
progress_file = os.path.join(output_dir, '.download_progress.json') progress_file = os.path.join(output_dir, '.download_progress.json')
if os.path.exists(progress_file): if os.path.exists(progress_file):
try: try:
with open(progress_file, 'r', encoding='utf-8') as f: with open(progress_file, 'r', encoding='utf-8') as f:
saved_progress = json.load(f) saved_progress = json.load(f)
download_progress['processed_models'] = set(saved_progress.get('processed_models', [])) self._progress['processed_models'] = set(saved_progress.get('processed_models', []))
download_progress['failed_models'] = set(saved_progress.get('failed_models', [])) self._progress['failed_models'] = set(saved_progress.get('failed_models', []))
logger.debug(f"Loaded previous progress, {len(download_progress['processed_models'])} models already processed, {len(download_progress['failed_models'])} models marked as failed") logger.debug(
"Loaded previous progress, %s models already processed, %s models marked as failed",
len(self._progress['processed_models']),
len(self._progress['failed_models']),
)
except Exception as e: except Exception as e:
logger.error(f"Failed to load progress file: {e}") logger.error(f"Failed to load progress file: {e}")
download_progress['processed_models'] = set() self._progress['processed_models'] = set()
download_progress['failed_models'] = set() self._progress['failed_models'] = set()
else: else:
download_progress['processed_models'] = set() self._progress['processed_models'] = set()
download_progress['failed_models'] = set() self._progress['failed_models'] = set()
# Start the download task self._is_downloading = True
is_downloading = True self._download_task = asyncio.create_task(
download_task = asyncio.create_task( self._download_all_example_images(
DownloadManager._download_all_example_images( output_dir,
output_dir, optimize,
optimize,
model_types, model_types,
delay delay
) )
) )
return { return {
'success': True, 'success': True,
'message': 'Download started', 'message': 'Download started',
'status': _serialize_progress() 'status': self._progress.snapshot()
} }
except Exception as e: except Exception as e:
logger.error(f"Failed to start example images download: {e}", exc_info=True) logger.error(f"Failed to start example images download: {e}", exc_info=True)
raise ExampleImagesDownloadError(str(e)) from e raise ExampleImagesDownloadError(str(e)) from e
@staticmethod async def get_status(self, request):
async def get_status(request): """Get the current status of example images download."""
"""Get the current status of example images download"""
global download_progress
# Create a copy of the progress dict with the set converted to a list for JSON serialization
response_progress = _serialize_progress()
return { return {
'success': True, 'success': True,
'is_downloading': is_downloading, 'is_downloading': self._is_downloading,
'status': response_progress 'status': self._progress.snapshot(),
} }
@staticmethod async def pause_download(self, request):
async def pause_download(request): """Pause the example images download."""
"""Pause the example images download"""
global download_progress if not self._is_downloading:
if not is_downloading:
raise DownloadNotRunningError() raise DownloadNotRunningError()
download_progress['status'] = 'paused' self._progress['status'] = 'paused'
return { return {
'success': True, 'success': True,
'message': 'Download paused' 'message': 'Download paused'
} }
@staticmethod async def resume_download(self, request):
async def resume_download(request): """Resume the example images download."""
"""Resume the example images download"""
global download_progress if not self._is_downloading:
if not is_downloading:
raise DownloadNotRunningError() raise DownloadNotRunningError()
if download_progress['status'] == 'paused': if self._progress['status'] == 'paused':
download_progress['status'] = 'running' self._progress['status'] = 'running'
return { return {
'success': True, 'success': True,
@@ -203,15 +190,12 @@ class DownloadManager:
} }
raise DownloadNotRunningError( raise DownloadNotRunningError(
f"Download is in '{download_progress['status']}' state, cannot resume" f"Download is in '{self._progress['status']}' state, cannot resume"
) )
@staticmethod async def _download_all_example_images(self, output_dir, optimize, model_types, delay):
async def _download_all_example_images(output_dir, optimize, model_types, delay): """Download example images for all models."""
"""Download example images for all models"""
global is_downloading, download_progress
# Get unified downloader
downloader = await get_downloader() downloader = await get_downloader()
try: try:
@@ -239,59 +223,58 @@ class DownloadManager:
all_models.append((scanner_type, model, scanner)) all_models.append((scanner_type, model, scanner))
# Update total count # Update total count
download_progress['total'] = len(all_models) self._progress['total'] = len(all_models)
logger.debug(f"Found {download_progress['total']} models to process") logger.debug(f"Found {self._progress['total']} models to process")
# Process each model # Process each model
for i, (scanner_type, model, scanner) in enumerate(all_models): for i, (scanner_type, model, scanner) in enumerate(all_models):
# Main logic for processing model is here, but actual operations are delegated to other classes # Main logic for processing model is here, but actual operations are delegated to other classes
was_remote_download = await DownloadManager._process_model( was_remote_download = await self._process_model(
scanner_type, model, scanner, scanner_type, model, scanner,
output_dir, optimize, downloader output_dir, optimize, downloader
) )
# Update progress # Update progress
download_progress['completed'] += 1 self._progress['completed'] += 1
# Only add delay after remote download of models, and not after processing the last model # Only add delay after remote download of models, and not after processing the last model
if was_remote_download and i < len(all_models) - 1 and download_progress['status'] == 'running': if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running':
await asyncio.sleep(delay) await asyncio.sleep(delay)
# Mark as completed # Mark as completed
download_progress['status'] = 'completed' self._progress['status'] = 'completed'
download_progress['end_time'] = time.time() self._progress['end_time'] = time.time()
logger.debug(f"Example images download completed: {download_progress['completed']}/{download_progress['total']} models processed") logger.debug(f"Example images download completed: {self._progress['completed']}/{self._progress['total']} models processed")
except Exception as e: except Exception as e:
error_msg = f"Error during example images download: {str(e)}" error_msg = f"Error during example images download: {str(e)}"
logger.error(error_msg, exc_info=True) logger.error(error_msg, exc_info=True)
download_progress['errors'].append(error_msg) self._progress['errors'].append(error_msg)
download_progress['last_error'] = error_msg self._progress['last_error'] = error_msg
download_progress['status'] = 'error' self._progress['status'] = 'error'
download_progress['end_time'] = time.time() self._progress['end_time'] = time.time()
finally: finally:
# Save final progress to file # Save final progress to file
try: try:
DownloadManager._save_progress(output_dir) self._save_progress(output_dir)
except Exception as e: except Exception as e:
logger.error(f"Failed to save progress file: {e}") logger.error(f"Failed to save progress file: {e}")
# Set download status to not downloading # Set download status to not downloading
is_downloading = False self._is_downloading = False
self._download_task = None
@staticmethod async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader):
async def _process_model(scanner_type, model, scanner, output_dir, optimize, downloader): """Process a single model download."""
"""Process a single model download"""
global download_progress
# Check if download is paused # Check if download is paused
while download_progress['status'] == 'paused': while self._progress['status'] == 'paused':
await asyncio.sleep(1) await asyncio.sleep(1)
# Check if download should continue # Check if download should continue
if download_progress['status'] != 'running': if self._progress['status'] != 'running':
logger.info(f"Download stopped: {download_progress['status']}") logger.info(f"Download stopped: {self._progress['status']}")
return False # Return False to indicate no remote download happened return False # Return False to indicate no remote download happened
model_hash = model.get('sha256', '').lower() model_hash = model.get('sha256', '').lower()
@@ -301,15 +284,15 @@ class DownloadManager:
try: try:
# Update current model info # Update current model info
download_progress['current_model'] = f"{model_name} ({model_hash[:8]})" self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
# Skip if already in failed models # Skip if already in failed models
if model_hash in download_progress['failed_models']: if model_hash in self._progress['failed_models']:
logger.debug(f"Skipping known failed model: {model_name}") logger.debug(f"Skipping known failed model: {model_name}")
return False return False
# Skip if already processed AND directory exists with files # Skip if already processed AND directory exists with files
if model_hash in download_progress['processed_models']: if model_hash in self._progress['processed_models']:
model_dir = os.path.join(output_dir, model_hash) model_dir = os.path.join(output_dir, model_hash)
has_files = os.path.exists(model_dir) and any(os.listdir(model_dir)) has_files = os.path.exists(model_dir) and any(os.listdir(model_dir))
if has_files: if has_files:
@@ -318,7 +301,7 @@ class DownloadManager:
else: else:
logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing") logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing")
# Remove from processed models since we need to reprocess # Remove from processed models since we need to reprocess
download_progress['processed_models'].discard(model_hash) self._progress['processed_models'].discard(model_hash)
# Create model directory # Create model directory
model_dir = os.path.join(output_dir, model_hash) model_dir = os.path.join(output_dir, model_hash)
@@ -334,7 +317,7 @@ class DownloadManager:
await MetadataUpdater.update_metadata_from_local_examples( await MetadataUpdater.update_metadata_from_local_examples(
model_hash, model, scanner_type, scanner, model_dir model_hash, model, scanner_type, scanner, model_dir
) )
download_progress['processed_models'].add(model_hash) self._progress['processed_models'].add(model_hash)
return False # Return False to indicate no remote download happened return False # Return False to indicate no remote download happened
# If no local images, try to download from remote # If no local images, try to download from remote
@@ -346,57 +329,55 @@ class DownloadManager:
) )
# If metadata is stale, try to refresh it # If metadata is stale, try to refresh it
if is_stale and model_hash not in download_progress['refreshed_models']: if is_stale and model_hash not in self._progress['refreshed_models']:
await MetadataUpdater.refresh_model_metadata( await MetadataUpdater.refresh_model_metadata(
model_hash, model_name, scanner_type, scanner model_hash, model_name, scanner_type, scanner, self._progress
) )
# Get the updated model data # Get the updated model data
updated_model = await MetadataUpdater.get_updated_model( updated_model = await MetadataUpdater.get_updated_model(
model_hash, scanner model_hash, scanner
) )
if updated_model and updated_model.get('civitai', {}).get('images'): if updated_model and updated_model.get('civitai', {}).get('images'):
# Retry download with updated metadata # Retry download with updated metadata
updated_images = updated_model.get('civitai', {}).get('images', []) updated_images = updated_model.get('civitai', {}).get('images', [])
success, _ = await ExampleImagesProcessor.download_model_images( success, _ = await ExampleImagesProcessor.download_model_images(
model_hash, model_name, updated_images, model_dir, optimize, downloader model_hash, model_name, updated_images, model_dir, optimize, downloader
) )
download_progress['refreshed_models'].add(model_hash) self._progress['refreshed_models'].add(model_hash)
# Mark as processed if successful, or as failed if unsuccessful after refresh # Mark as processed if successful, or as failed if unsuccessful after refresh
if success: if success:
download_progress['processed_models'].add(model_hash) self._progress['processed_models'].add(model_hash)
else: else:
# If we refreshed metadata and still failed, mark as permanently failed # If we refreshed metadata and still failed, mark as permanently failed
if model_hash in download_progress['refreshed_models']: if model_hash in self._progress['refreshed_models']:
download_progress['failed_models'].add(model_hash) self._progress['failed_models'].add(model_hash)
logger.info(f"Marking model {model_name} as failed after metadata refresh") logger.info(f"Marking model {model_name} as failed after metadata refresh")
return True # Return True to indicate a remote download happened return True # Return True to indicate a remote download happened
else: else:
# No civitai data or images available, mark as failed to avoid future attempts # No civitai data or images available, mark as failed to avoid future attempts
download_progress['failed_models'].add(model_hash) self._progress['failed_models'].add(model_hash)
logger.debug(f"No civitai images available for model {model_name}, marking as failed") logger.debug(f"No civitai images available for model {model_name}, marking as failed")
# Save progress periodically # Save progress periodically
if download_progress['completed'] % 10 == 0 or download_progress['completed'] == download_progress['total'] - 1: if self._progress['completed'] % 10 == 0 or self._progress['completed'] == self._progress['total'] - 1:
DownloadManager._save_progress(output_dir) self._save_progress(output_dir)
return False # Default return if no conditions met return False # Default return if no conditions met
except Exception as e: except Exception as e:
error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
logger.error(error_msg, exc_info=True) logger.error(error_msg, exc_info=True)
download_progress['errors'].append(error_msg) self._progress['errors'].append(error_msg)
download_progress['last_error'] = error_msg self._progress['last_error'] = error_msg
return False # Return False on exception return False # Return False on exception
@staticmethod def _save_progress(self, output_dir):
def _save_progress(output_dir): """Save download progress to file."""
"""Save download progress to file"""
global download_progress
try: try:
progress_file = os.path.join(output_dir, '.download_progress.json') progress_file = os.path.join(output_dir, '.download_progress.json')
@@ -411,11 +392,11 @@ class DownloadManager:
# Create new progress data # Create new progress data
progress_data = { progress_data = {
'processed_models': list(download_progress['processed_models']), 'processed_models': list(self._progress['processed_models']),
'refreshed_models': list(download_progress['refreshed_models']), 'refreshed_models': list(self._progress['refreshed_models']),
'failed_models': list(download_progress['failed_models']), 'failed_models': list(self._progress['failed_models']),
'completed': download_progress['completed'], 'completed': self._progress['completed'],
'total': download_progress['total'], 'total': self._progress['total'],
'last_update': time.time() 'last_update': time.time()
} }
@@ -430,70 +411,46 @@ class DownloadManager:
except Exception as e: except Exception as e:
logger.error(f"Failed to save progress file: {e}") logger.error(f"Failed to save progress file: {e}")
@staticmethod async def start_force_download(self, options: dict):
async def start_force_download(options: dict): """Force download example images for specific models."""
"""
Force download example images for specific models
Expects a JSON body with:
{
"model_hashes": ["hash1", "hash2", ...], # List of model hashes to download
"optimize": true, # Whether to optimize images (default: true)
"model_types": ["lora", "checkpoint"], # Model types to process (default: both)
"delay": 1.0 # Delay between downloads (default: 1.0)
}
"""
global download_task, is_downloading, download_progress
if is_downloading: if self._is_downloading:
raise DownloadInProgressError(_serialize_progress()) raise DownloadInProgressError(self._progress.snapshot())
try: try:
data = options or {} data = options or {}
model_hashes = data.get('model_hashes', []) model_hashes = data.get('model_hashes', [])
optimize = data.get('optimize', True) optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint']) model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds delay = float(data.get('delay', 0.2))
if not model_hashes: if not model_hashes:
raise DownloadConfigurationError('Missing model_hashes parameter') raise DownloadConfigurationError('Missing model_hashes parameter')
# Get output directory from settings
output_dir = settings.get('example_images_path') output_dir = settings.get('example_images_path')
if not output_dir: if not output_dir:
raise DownloadConfigurationError('Example images path not configured in settings') raise DownloadConfigurationError('Example images path not configured in settings')
# Create the output directory
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
# Initialize progress tracking
download_progress['total'] = len(model_hashes)
download_progress['completed'] = 0
download_progress['current_model'] = ''
download_progress['status'] = 'running'
download_progress['errors'] = []
download_progress['last_error'] = None
download_progress['start_time'] = time.time()
download_progress['end_time'] = None
download_progress['processed_models'] = set()
download_progress['refreshed_models'] = set()
download_progress['failed_models'] = set()
# Set download status to downloading self._progress.reset()
is_downloading = True self._progress['total'] = len(model_hashes)
self._progress['status'] = 'running'
self._progress['start_time'] = time.time()
self._progress['end_time'] = None
# Execute the download function directly instead of creating a background task self._is_downloading = True
result = await DownloadManager._download_specific_models_example_images_sync(
result = await self._download_specific_models_example_images_sync(
model_hashes, model_hashes,
output_dir, output_dir,
optimize, optimize,
model_types, model_types,
delay delay
) )
# Set download status to not downloading self._is_downloading = False
is_downloading = False
return { return {
'success': True, 'success': True,
@@ -502,17 +459,13 @@ class DownloadManager:
} }
except Exception as e: except Exception as e:
# Set download status to not downloading self._is_downloading = False
is_downloading = False
logger.error(f"Failed during forced example images download: {e}", exc_info=True) logger.error(f"Failed during forced example images download: {e}", exc_info=True)
raise ExampleImagesDownloadError(str(e)) from e raise ExampleImagesDownloadError(str(e)) from e
@staticmethod async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay):
async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay): """Download example images for specific models only - synchronous version."""
"""Download example images for specific models only - synchronous version"""
global download_progress
# Get unified downloader
downloader = await get_downloader() downloader = await get_downloader()
try: try:
@@ -540,14 +493,14 @@ class DownloadManager:
models_to_process.append((scanner_type, model, scanner)) models_to_process.append((scanner_type, model, scanner))
# Update total count based on found models # Update total count based on found models
download_progress['total'] = len(models_to_process) self._progress['total'] = len(models_to_process)
logger.debug(f"Found {download_progress['total']} models to process") logger.debug(f"Found {self._progress['total']} models to process")
# Send initial progress via WebSocket # Send initial progress via WebSocket
await ws_manager.broadcast({ await ws_manager.broadcast({
'type': 'example_images_progress', 'type': 'example_images_progress',
'processed': 0, 'processed': 0,
'total': download_progress['total'], 'total': self._progress['total'],
'status': 'running', 'status': 'running',
'current_model': '' 'current_model': ''
}) })
@@ -556,8 +509,8 @@ class DownloadManager:
success_count = 0 success_count = 0
for i, (scanner_type, model, scanner) in enumerate(models_to_process): for i, (scanner_type, model, scanner) in enumerate(models_to_process):
# Force process this model regardless of previous status # Force process this model regardless of previous status
was_successful = await DownloadManager._process_specific_model( was_successful = await self._process_specific_model(
scanner_type, model, scanner, scanner_type, model, scanner,
output_dir, optimize, downloader output_dir, optimize, downloader
) )
@@ -565,55 +518,55 @@ class DownloadManager:
success_count += 1 success_count += 1
# Update progress # Update progress
download_progress['completed'] += 1 self._progress['completed'] += 1
# Send progress update via WebSocket # Send progress update via WebSocket
await ws_manager.broadcast({ await ws_manager.broadcast({
'type': 'example_images_progress', 'type': 'example_images_progress',
'processed': download_progress['completed'], 'processed': self._progress['completed'],
'total': download_progress['total'], 'total': self._progress['total'],
'status': 'running', 'status': 'running',
'current_model': download_progress['current_model'] 'current_model': self._progress['current_model']
}) })
# Only add delay after remote download, and not after processing the last model # Only add delay after remote download, and not after processing the last model
if was_successful and i < len(models_to_process) - 1 and download_progress['status'] == 'running': if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running':
await asyncio.sleep(delay) await asyncio.sleep(delay)
# Mark as completed # Mark as completed
download_progress['status'] = 'completed' self._progress['status'] = 'completed'
download_progress['end_time'] = time.time() self._progress['end_time'] = time.time()
logger.debug(f"Forced example images download completed: {download_progress['completed']}/{download_progress['total']} models processed") logger.debug(f"Forced example images download completed: {self._progress['completed']}/{self._progress['total']} models processed")
# Send final progress via WebSocket # Send final progress via WebSocket
await ws_manager.broadcast({ await ws_manager.broadcast({
'type': 'example_images_progress', 'type': 'example_images_progress',
'processed': download_progress['completed'], 'processed': self._progress['completed'],
'total': download_progress['total'], 'total': self._progress['total'],
'status': 'completed', 'status': 'completed',
'current_model': '' 'current_model': ''
}) })
return { return {
'total': download_progress['total'], 'total': self._progress['total'],
'processed': download_progress['completed'], 'processed': self._progress['completed'],
'successful': success_count, 'successful': success_count,
'errors': download_progress['errors'] 'errors': self._progress['errors']
} }
except Exception as e: except Exception as e:
error_msg = f"Error during forced example images download: {str(e)}" error_msg = f"Error during forced example images download: {str(e)}"
logger.error(error_msg, exc_info=True) logger.error(error_msg, exc_info=True)
download_progress['errors'].append(error_msg) self._progress['errors'].append(error_msg)
download_progress['last_error'] = error_msg self._progress['last_error'] = error_msg
download_progress['status'] = 'error' self._progress['status'] = 'error'
download_progress['end_time'] = time.time() self._progress['end_time'] = time.time()
# Send error status via WebSocket # Send error status via WebSocket
await ws_manager.broadcast({ await ws_manager.broadcast({
'type': 'example_images_progress', 'type': 'example_images_progress',
'processed': download_progress['completed'], 'processed': self._progress['completed'],
'total': download_progress['total'], 'total': self._progress['total'],
'status': 'error', 'status': 'error',
'error': error_msg, 'error': error_msg,
'current_model': '' 'current_model': ''
@@ -625,18 +578,16 @@ class DownloadManager:
# No need to close any sessions since we use the global downloader # No need to close any sessions since we use the global downloader
pass pass
@staticmethod async def _process_specific_model(self, scanner_type, model, scanner, output_dir, optimize, downloader):
async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, downloader): """Process a specific model for forced download, ignoring previous download status."""
"""Process a specific model for forced download, ignoring previous download status"""
global download_progress
# Check if download is paused # Check if download is paused
while download_progress['status'] == 'paused': while self._progress['status'] == 'paused':
await asyncio.sleep(1) await asyncio.sleep(1)
# Check if download should continue # Check if download should continue
if download_progress['status'] != 'running': if self._progress['status'] != 'running':
logger.info(f"Download stopped: {download_progress['status']}") logger.info(f"Download stopped: {self._progress['status']}")
return False return False
model_hash = model.get('sha256', '').lower() model_hash = model.get('sha256', '').lower()
@@ -646,7 +597,7 @@ class DownloadManager:
try: try:
# Update current model info # Update current model info
download_progress['current_model'] = f"{model_name} ({model_hash[:8]})" self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
# Create model directory # Create model directory
model_dir = os.path.join(output_dir, model_hash) model_dir = os.path.join(output_dir, model_hash)
@@ -662,7 +613,7 @@ class DownloadManager:
await MetadataUpdater.update_metadata_from_local_examples( await MetadataUpdater.update_metadata_from_local_examples(
model_hash, model, scanner_type, scanner, model_dir model_hash, model, scanner_type, scanner, model_dir
) )
download_progress['processed_models'].add(model_hash) self._progress['processed_models'].add(model_hash)
return False # Return False to indicate no remote download happened return False # Return False to indicate no remote download happened
# If no local images, try to download from remote # If no local images, try to download from remote
@@ -674,9 +625,9 @@ class DownloadManager:
) )
# If metadata is stale, try to refresh it # If metadata is stale, try to refresh it
if is_stale and model_hash not in download_progress['refreshed_models']: if is_stale and model_hash not in self._progress['refreshed_models']:
await MetadataUpdater.refresh_model_metadata( await MetadataUpdater.refresh_model_metadata(
model_hash, model_name, scanner_type, scanner model_hash, model_name, scanner_type, scanner, self._progress
) )
# Get the updated model data # Get the updated model data
@@ -694,18 +645,18 @@ class DownloadManager:
# Combine failed images from both attempts # Combine failed images from both attempts
failed_images.extend(additional_failed_images) failed_images.extend(additional_failed_images)
download_progress['refreshed_models'].add(model_hash) self._progress['refreshed_models'].add(model_hash)
# For forced downloads, remove failed images from metadata # For forced downloads, remove failed images from metadata
if failed_images: if failed_images:
# Create a copy of images excluding failed ones # Create a copy of images excluding failed ones
await DownloadManager._remove_failed_images_from_metadata( await self._remove_failed_images_from_metadata(
model_hash, model_name, failed_images, scanner model_hash, model_name, failed_images, scanner
) )
# Mark as processed # Mark as processed
if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones
download_progress['processed_models'].add(model_hash) self._progress['processed_models'].add(model_hash)
return True # Return True to indicate a remote download happened return True # Return True to indicate a remote download happened
else: else:
@@ -715,12 +666,11 @@ class DownloadManager:
except Exception as e: except Exception as e:
error_msg = f"Error processing model {model.get('model_name')}: {str(e)}" error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
logger.error(error_msg, exc_info=True) logger.error(error_msg, exc_info=True)
download_progress['errors'].append(error_msg) self._progress['errors'].append(error_msg)
download_progress['last_error'] = error_msg self._progress['last_error'] = error_msg
return False # Return False on exception return False # Return False on exception
@staticmethod async def _remove_failed_images_from_metadata(self, model_hash, model_name, failed_images, scanner):
async def _remove_failed_images_from_metadata(model_hash, model_name, failed_images, scanner):
"""Remove failed images from model metadata""" """Remove failed images from model metadata"""
try: try:
# Get current model data # Get current model data
@@ -762,4 +712,13 @@ class DownloadManager:
await scanner.update_single_model_cache(file_path, file_path, model_data) await scanner.update_single_model_cache(file_path, file_path, model_data)
except Exception as e: except Exception as e:
logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True) logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True)
# Shared module-level manager: routes that are not handed an explicit
# DownloadManager instance fall back to this singleton.
default_download_manager = DownloadManager()


def get_default_download_manager() -> DownloadManager:
    """Return the shared :class:`DownloadManager` singleton.

    Used by the example-image routes as the default service instance
    when no manager is injected at construction time.
    """
    manager = default_download_manager
    return manager

View File

@@ -33,7 +33,7 @@ class MetadataUpdater:
"""Handles updating model metadata related to example images""" """Handles updating model metadata related to example images"""
@staticmethod @staticmethod
async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner): async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner, progress: dict | None = None):
"""Refresh model metadata from CivitAI """Refresh model metadata from CivitAI
Args: Args:
@@ -45,8 +45,6 @@ class MetadataUpdater:
Returns: Returns:
bool: True if metadata was successfully refreshed, False otherwise bool: True if metadata was successfully refreshed, False otherwise
""" """
from ..utils.example_images_download_manager import download_progress
try: try:
# Find the model in the scanner cache # Find the model in the scanner cache
cache = await scanner.get_cached_data() cache = await scanner.get_cached_data()
@@ -67,7 +65,8 @@ class MetadataUpdater:
return False return False
# Track that we're refreshing this model # Track that we're refreshing this model
download_progress['refreshed_models'].add(model_hash) if progress is not None:
progress['refreshed_models'].add(model_hash)
async def update_cache_func(old_path, new_path, metadata): async def update_cache_func(old_path, new_path, metadata):
return await scanner.update_single_model_cache(old_path, new_path, metadata) return await scanner.update_single_model_cache(old_path, new_path, metadata)
@@ -85,12 +84,13 @@ class MetadataUpdater:
else: else:
logger.warning(f"Failed to refresh metadata for {model_name}, {error}") logger.warning(f"Failed to refresh metadata for {model_name}, {error}")
return False return False
except Exception as e: except Exception as e:
error_msg = f"Error refreshing metadata for {model_name}: {str(e)}" error_msg = f"Error refreshing metadata for {model_name}: {str(e)}"
logger.error(error_msg, exc_info=True) logger.error(error_msg, exc_info=True)
download_progress['errors'].append(error_msg) if progress is not None:
download_progress['last_error'] = error_msg progress['errors'].append(error_msg)
progress['last_error'] = error_msg
return False return False
@staticmethod @staticmethod