mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
Merge pull request #463 from willmiao/codex/refactor-downloadmanager-to-instance-based
refactor: convert example image download manager to service instance
This commit is contained in:
@@ -16,7 +16,10 @@ from ..services.use_cases.example_images import (
|
|||||||
DownloadExampleImagesUseCase,
|
DownloadExampleImagesUseCase,
|
||||||
ImportExampleImagesUseCase,
|
ImportExampleImagesUseCase,
|
||||||
)
|
)
|
||||||
from ..utils.example_images_download_manager import DownloadManager
|
from ..utils.example_images_download_manager import (
|
||||||
|
DownloadManager,
|
||||||
|
get_default_download_manager,
|
||||||
|
)
|
||||||
from ..utils.example_images_file_manager import ExampleImagesFileManager
|
from ..utils.example_images_file_manager import ExampleImagesFileManager
|
||||||
from ..utils.example_images_processor import ExampleImagesProcessor
|
from ..utils.example_images_processor import ExampleImagesProcessor
|
||||||
|
|
||||||
@@ -29,11 +32,11 @@ class ExampleImagesRoutes:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
download_manager=DownloadManager,
|
download_manager: DownloadManager | None = None,
|
||||||
processor=ExampleImagesProcessor,
|
processor=ExampleImagesProcessor,
|
||||||
file_manager=ExampleImagesFileManager,
|
file_manager=ExampleImagesFileManager,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._download_manager = download_manager
|
self._download_manager = download_manager or get_default_download_manager()
|
||||||
self._processor = processor
|
self._processor = processor
|
||||||
self._file_manager = file_manager
|
self._file_manager = file_manager
|
||||||
self._handler_set: ExampleImagesHandlerSet | None = None
|
self._handler_set: ExampleImagesHandlerSet | None = None
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -37,68 +39,66 @@ class DownloadConfigurationError(ExampleImagesDownloadError):
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Download status tracking
|
|
||||||
download_task = None
|
|
||||||
is_downloading = False
|
|
||||||
download_progress = {
|
|
||||||
'total': 0,
|
|
||||||
'completed': 0,
|
|
||||||
'current_model': '',
|
|
||||||
'status': 'idle', # idle, running, paused, completed, error
|
|
||||||
'errors': [],
|
|
||||||
'last_error': None,
|
|
||||||
'start_time': None,
|
|
||||||
'end_time': None,
|
|
||||||
'processed_models': set(), # Track models that have been processed
|
|
||||||
'refreshed_models': set(), # Track models that had metadata refreshed
|
|
||||||
'failed_models': set() # Track models that failed to download after metadata refresh
|
|
||||||
}
|
|
||||||
|
|
||||||
|
class _DownloadProgress(dict):
|
||||||
|
"""Mutable mapping maintaining download progress with set-aware serialisation."""
|
||||||
|
|
||||||
def _serialize_progress() -> dict:
|
def __init__(self) -> None:
|
||||||
"""Return a JSON-serialisable snapshot of the current progress."""
|
super().__init__()
|
||||||
|
self.reset()
|
||||||
|
|
||||||
snapshot = download_progress.copy()
|
def reset(self) -> None:
|
||||||
snapshot['processed_models'] = list(download_progress['processed_models'])
|
"""Reset the progress dictionary to its initial state."""
|
||||||
snapshot['refreshed_models'] = list(download_progress['refreshed_models'])
|
|
||||||
snapshot['failed_models'] = list(download_progress['failed_models'])
|
self.update(
|
||||||
return snapshot
|
total=0,
|
||||||
|
completed=0,
|
||||||
|
current_model='',
|
||||||
|
status='idle',
|
||||||
|
errors=[],
|
||||||
|
last_error=None,
|
||||||
|
start_time=None,
|
||||||
|
end_time=None,
|
||||||
|
processed_models=set(),
|
||||||
|
refreshed_models=set(),
|
||||||
|
failed_models=set(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def snapshot(self) -> dict:
|
||||||
|
"""Return a JSON-serialisable snapshot of the current progress."""
|
||||||
|
|
||||||
|
snapshot = dict(self)
|
||||||
|
snapshot['processed_models'] = list(self['processed_models'])
|
||||||
|
snapshot['refreshed_models'] = list(self['refreshed_models'])
|
||||||
|
snapshot['failed_models'] = list(self['failed_models'])
|
||||||
|
return snapshot
|
||||||
|
|
||||||
class DownloadManager:
|
class DownloadManager:
|
||||||
"""Manages downloading example images for models"""
|
"""Manages downloading example images for models."""
|
||||||
|
|
||||||
@staticmethod
|
def __init__(self) -> None:
|
||||||
async def start_download(options: dict):
|
self._download_task: asyncio.Task | None = None
|
||||||
"""
|
self._is_downloading = False
|
||||||
Start downloading example images for models
|
self._progress = _DownloadProgress()
|
||||||
|
|
||||||
Expects a JSON body with:
|
async def start_download(self, options: dict):
|
||||||
{
|
"""Start downloading example images for models."""
|
||||||
"optimize": true, # Whether to optimize images (default: true)
|
|
||||||
"model_types": ["lora", "checkpoint"], # Model types to process (default: both)
|
|
||||||
"delay": 1.0, # Delay between downloads to avoid rate limiting (default: 1.0)
|
|
||||||
"auto_mode": false # Flag to indicate automatic download (default: false)
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
global download_task, is_downloading, download_progress
|
|
||||||
|
|
||||||
if is_downloading:
|
if self._is_downloading:
|
||||||
raise DownloadInProgressError(_serialize_progress())
|
raise DownloadInProgressError(self._progress.snapshot())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = options or {}
|
data = options or {}
|
||||||
auto_mode = data.get('auto_mode', False)
|
auto_mode = data.get('auto_mode', False)
|
||||||
optimize = data.get('optimize', True)
|
optimize = data.get('optimize', True)
|
||||||
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
||||||
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
|
delay = float(data.get('delay', 0.2))
|
||||||
|
|
||||||
# Get output directory from settings
|
|
||||||
output_dir = settings.get('example_images_path')
|
output_dir = settings.get('example_images_path')
|
||||||
|
|
||||||
if not output_dir:
|
if not output_dir:
|
||||||
error_msg = 'Example images path not configured in settings'
|
error_msg = 'Example images path not configured in settings'
|
||||||
if auto_mode:
|
if auto_mode:
|
||||||
# For auto mode, just log and return success to avoid showing error toasts
|
|
||||||
logger.debug(error_msg)
|
logger.debug(error_msg)
|
||||||
return {
|
return {
|
||||||
'success': True,
|
'success': True,
|
||||||
@@ -106,40 +106,36 @@ class DownloadManager:
|
|||||||
}
|
}
|
||||||
raise DownloadConfigurationError(error_msg)
|
raise DownloadConfigurationError(error_msg)
|
||||||
|
|
||||||
# Create the output directory
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
# Initialize progress tracking
|
self._progress.reset()
|
||||||
download_progress['total'] = 0
|
self._progress['status'] = 'running'
|
||||||
download_progress['completed'] = 0
|
self._progress['start_time'] = time.time()
|
||||||
download_progress['current_model'] = ''
|
self._progress['end_time'] = None
|
||||||
download_progress['status'] = 'running'
|
|
||||||
download_progress['errors'] = []
|
|
||||||
download_progress['last_error'] = None
|
|
||||||
download_progress['start_time'] = time.time()
|
|
||||||
download_progress['end_time'] = None
|
|
||||||
|
|
||||||
# Get the processed models list from a file if it exists
|
|
||||||
progress_file = os.path.join(output_dir, '.download_progress.json')
|
progress_file = os.path.join(output_dir, '.download_progress.json')
|
||||||
if os.path.exists(progress_file):
|
if os.path.exists(progress_file):
|
||||||
try:
|
try:
|
||||||
with open(progress_file, 'r', encoding='utf-8') as f:
|
with open(progress_file, 'r', encoding='utf-8') as f:
|
||||||
saved_progress = json.load(f)
|
saved_progress = json.load(f)
|
||||||
download_progress['processed_models'] = set(saved_progress.get('processed_models', []))
|
self._progress['processed_models'] = set(saved_progress.get('processed_models', []))
|
||||||
download_progress['failed_models'] = set(saved_progress.get('failed_models', []))
|
self._progress['failed_models'] = set(saved_progress.get('failed_models', []))
|
||||||
logger.debug(f"Loaded previous progress, {len(download_progress['processed_models'])} models already processed, {len(download_progress['failed_models'])} models marked as failed")
|
logger.debug(
|
||||||
|
"Loaded previous progress, %s models already processed, %s models marked as failed",
|
||||||
|
len(self._progress['processed_models']),
|
||||||
|
len(self._progress['failed_models']),
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load progress file: {e}")
|
logger.error(f"Failed to load progress file: {e}")
|
||||||
download_progress['processed_models'] = set()
|
self._progress['processed_models'] = set()
|
||||||
download_progress['failed_models'] = set()
|
self._progress['failed_models'] = set()
|
||||||
else:
|
else:
|
||||||
download_progress['processed_models'] = set()
|
self._progress['processed_models'] = set()
|
||||||
download_progress['failed_models'] = set()
|
self._progress['failed_models'] = set()
|
||||||
|
|
||||||
# Start the download task
|
self._is_downloading = True
|
||||||
is_downloading = True
|
self._download_task = asyncio.create_task(
|
||||||
download_task = asyncio.create_task(
|
self._download_all_example_images(
|
||||||
DownloadManager._download_all_example_images(
|
|
||||||
output_dir,
|
output_dir,
|
||||||
optimize,
|
optimize,
|
||||||
model_types,
|
model_types,
|
||||||
@@ -150,52 +146,43 @@ class DownloadManager:
|
|||||||
return {
|
return {
|
||||||
'success': True,
|
'success': True,
|
||||||
'message': 'Download started',
|
'message': 'Download started',
|
||||||
'status': _serialize_progress()
|
'status': self._progress.snapshot()
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to start example images download: {e}", exc_info=True)
|
logger.error(f"Failed to start example images download: {e}", exc_info=True)
|
||||||
raise ExampleImagesDownloadError(str(e)) from e
|
raise ExampleImagesDownloadError(str(e)) from e
|
||||||
|
|
||||||
@staticmethod
|
async def get_status(self, request):
|
||||||
async def get_status(request):
|
"""Get the current status of example images download."""
|
||||||
"""Get the current status of example images download"""
|
|
||||||
global download_progress
|
|
||||||
|
|
||||||
# Create a copy of the progress dict with the set converted to a list for JSON serialization
|
|
||||||
response_progress = _serialize_progress()
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'success': True,
|
'success': True,
|
||||||
'is_downloading': is_downloading,
|
'is_downloading': self._is_downloading,
|
||||||
'status': response_progress
|
'status': self._progress.snapshot(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
async def pause_download(self, request):
|
||||||
async def pause_download(request):
|
"""Pause the example images download."""
|
||||||
"""Pause the example images download"""
|
|
||||||
global download_progress
|
|
||||||
|
|
||||||
if not is_downloading:
|
if not self._is_downloading:
|
||||||
raise DownloadNotRunningError()
|
raise DownloadNotRunningError()
|
||||||
|
|
||||||
download_progress['status'] = 'paused'
|
self._progress['status'] = 'paused'
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'success': True,
|
'success': True,
|
||||||
'message': 'Download paused'
|
'message': 'Download paused'
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
async def resume_download(self, request):
|
||||||
async def resume_download(request):
|
"""Resume the example images download."""
|
||||||
"""Resume the example images download"""
|
|
||||||
global download_progress
|
|
||||||
|
|
||||||
if not is_downloading:
|
if not self._is_downloading:
|
||||||
raise DownloadNotRunningError()
|
raise DownloadNotRunningError()
|
||||||
|
|
||||||
if download_progress['status'] == 'paused':
|
if self._progress['status'] == 'paused':
|
||||||
download_progress['status'] = 'running'
|
self._progress['status'] = 'running'
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'success': True,
|
'success': True,
|
||||||
@@ -203,15 +190,12 @@ class DownloadManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
raise DownloadNotRunningError(
|
raise DownloadNotRunningError(
|
||||||
f"Download is in '{download_progress['status']}' state, cannot resume"
|
f"Download is in '{self._progress['status']}' state, cannot resume"
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
async def _download_all_example_images(self, output_dir, optimize, model_types, delay):
|
||||||
async def _download_all_example_images(output_dir, optimize, model_types, delay):
|
"""Download example images for all models."""
|
||||||
"""Download example images for all models"""
|
|
||||||
global is_downloading, download_progress
|
|
||||||
|
|
||||||
# Get unified downloader
|
|
||||||
downloader = await get_downloader()
|
downloader = await get_downloader()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -239,59 +223,58 @@ class DownloadManager:
|
|||||||
all_models.append((scanner_type, model, scanner))
|
all_models.append((scanner_type, model, scanner))
|
||||||
|
|
||||||
# Update total count
|
# Update total count
|
||||||
download_progress['total'] = len(all_models)
|
self._progress['total'] = len(all_models)
|
||||||
logger.debug(f"Found {download_progress['total']} models to process")
|
logger.debug(f"Found {self._progress['total']} models to process")
|
||||||
|
|
||||||
# Process each model
|
# Process each model
|
||||||
for i, (scanner_type, model, scanner) in enumerate(all_models):
|
for i, (scanner_type, model, scanner) in enumerate(all_models):
|
||||||
# Main logic for processing model is here, but actual operations are delegated to other classes
|
# Main logic for processing model is here, but actual operations are delegated to other classes
|
||||||
was_remote_download = await DownloadManager._process_model(
|
was_remote_download = await self._process_model(
|
||||||
scanner_type, model, scanner,
|
scanner_type, model, scanner,
|
||||||
output_dir, optimize, downloader
|
output_dir, optimize, downloader
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update progress
|
# Update progress
|
||||||
download_progress['completed'] += 1
|
self._progress['completed'] += 1
|
||||||
|
|
||||||
# Only add delay after remote download of models, and not after processing the last model
|
# Only add delay after remote download of models, and not after processing the last model
|
||||||
if was_remote_download and i < len(all_models) - 1 and download_progress['status'] == 'running':
|
if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running':
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
# Mark as completed
|
# Mark as completed
|
||||||
download_progress['status'] = 'completed'
|
self._progress['status'] = 'completed'
|
||||||
download_progress['end_time'] = time.time()
|
self._progress['end_time'] = time.time()
|
||||||
logger.debug(f"Example images download completed: {download_progress['completed']}/{download_progress['total']} models processed")
|
logger.debug(f"Example images download completed: {self._progress['completed']}/{self._progress['total']} models processed")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error during example images download: {str(e)}"
|
error_msg = f"Error during example images download: {str(e)}"
|
||||||
logger.error(error_msg, exc_info=True)
|
logger.error(error_msg, exc_info=True)
|
||||||
download_progress['errors'].append(error_msg)
|
self._progress['errors'].append(error_msg)
|
||||||
download_progress['last_error'] = error_msg
|
self._progress['last_error'] = error_msg
|
||||||
download_progress['status'] = 'error'
|
self._progress['status'] = 'error'
|
||||||
download_progress['end_time'] = time.time()
|
self._progress['end_time'] = time.time()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Save final progress to file
|
# Save final progress to file
|
||||||
try:
|
try:
|
||||||
DownloadManager._save_progress(output_dir)
|
self._save_progress(output_dir)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to save progress file: {e}")
|
logger.error(f"Failed to save progress file: {e}")
|
||||||
|
|
||||||
# Set download status to not downloading
|
# Set download status to not downloading
|
||||||
is_downloading = False
|
self._is_downloading = False
|
||||||
|
self._download_task = None
|
||||||
|
|
||||||
@staticmethod
|
async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader):
|
||||||
async def _process_model(scanner_type, model, scanner, output_dir, optimize, downloader):
|
"""Process a single model download."""
|
||||||
"""Process a single model download"""
|
|
||||||
global download_progress
|
|
||||||
|
|
||||||
# Check if download is paused
|
# Check if download is paused
|
||||||
while download_progress['status'] == 'paused':
|
while self._progress['status'] == 'paused':
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
# Check if download should continue
|
# Check if download should continue
|
||||||
if download_progress['status'] != 'running':
|
if self._progress['status'] != 'running':
|
||||||
logger.info(f"Download stopped: {download_progress['status']}")
|
logger.info(f"Download stopped: {self._progress['status']}")
|
||||||
return False # Return False to indicate no remote download happened
|
return False # Return False to indicate no remote download happened
|
||||||
|
|
||||||
model_hash = model.get('sha256', '').lower()
|
model_hash = model.get('sha256', '').lower()
|
||||||
@@ -301,15 +284,15 @@ class DownloadManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Update current model info
|
# Update current model info
|
||||||
download_progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
||||||
|
|
||||||
# Skip if already in failed models
|
# Skip if already in failed models
|
||||||
if model_hash in download_progress['failed_models']:
|
if model_hash in self._progress['failed_models']:
|
||||||
logger.debug(f"Skipping known failed model: {model_name}")
|
logger.debug(f"Skipping known failed model: {model_name}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Skip if already processed AND directory exists with files
|
# Skip if already processed AND directory exists with files
|
||||||
if model_hash in download_progress['processed_models']:
|
if model_hash in self._progress['processed_models']:
|
||||||
model_dir = os.path.join(output_dir, model_hash)
|
model_dir = os.path.join(output_dir, model_hash)
|
||||||
has_files = os.path.exists(model_dir) and any(os.listdir(model_dir))
|
has_files = os.path.exists(model_dir) and any(os.listdir(model_dir))
|
||||||
if has_files:
|
if has_files:
|
||||||
@@ -318,7 +301,7 @@ class DownloadManager:
|
|||||||
else:
|
else:
|
||||||
logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing")
|
logger.info(f"Model {model_name} marked as processed but folder empty or missing, reprocessing")
|
||||||
# Remove from processed models since we need to reprocess
|
# Remove from processed models since we need to reprocess
|
||||||
download_progress['processed_models'].discard(model_hash)
|
self._progress['processed_models'].discard(model_hash)
|
||||||
|
|
||||||
# Create model directory
|
# Create model directory
|
||||||
model_dir = os.path.join(output_dir, model_hash)
|
model_dir = os.path.join(output_dir, model_hash)
|
||||||
@@ -334,7 +317,7 @@ class DownloadManager:
|
|||||||
await MetadataUpdater.update_metadata_from_local_examples(
|
await MetadataUpdater.update_metadata_from_local_examples(
|
||||||
model_hash, model, scanner_type, scanner, model_dir
|
model_hash, model, scanner_type, scanner, model_dir
|
||||||
)
|
)
|
||||||
download_progress['processed_models'].add(model_hash)
|
self._progress['processed_models'].add(model_hash)
|
||||||
return False # Return False to indicate no remote download happened
|
return False # Return False to indicate no remote download happened
|
||||||
|
|
||||||
# If no local images, try to download from remote
|
# If no local images, try to download from remote
|
||||||
@@ -346,9 +329,9 @@ class DownloadManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# If metadata is stale, try to refresh it
|
# If metadata is stale, try to refresh it
|
||||||
if is_stale and model_hash not in download_progress['refreshed_models']:
|
if is_stale and model_hash not in self._progress['refreshed_models']:
|
||||||
await MetadataUpdater.refresh_model_metadata(
|
await MetadataUpdater.refresh_model_metadata(
|
||||||
model_hash, model_name, scanner_type, scanner
|
model_hash, model_name, scanner_type, scanner, self._progress
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the updated model data
|
# Get the updated model data
|
||||||
@@ -363,40 +346,38 @@ class DownloadManager:
|
|||||||
model_hash, model_name, updated_images, model_dir, optimize, downloader
|
model_hash, model_name, updated_images, model_dir, optimize, downloader
|
||||||
)
|
)
|
||||||
|
|
||||||
download_progress['refreshed_models'].add(model_hash)
|
self._progress['refreshed_models'].add(model_hash)
|
||||||
|
|
||||||
# Mark as processed if successful, or as failed if unsuccessful after refresh
|
# Mark as processed if successful, or as failed if unsuccessful after refresh
|
||||||
if success:
|
if success:
|
||||||
download_progress['processed_models'].add(model_hash)
|
self._progress['processed_models'].add(model_hash)
|
||||||
else:
|
else:
|
||||||
# If we refreshed metadata and still failed, mark as permanently failed
|
# If we refreshed metadata and still failed, mark as permanently failed
|
||||||
if model_hash in download_progress['refreshed_models']:
|
if model_hash in self._progress['refreshed_models']:
|
||||||
download_progress['failed_models'].add(model_hash)
|
self._progress['failed_models'].add(model_hash)
|
||||||
logger.info(f"Marking model {model_name} as failed after metadata refresh")
|
logger.info(f"Marking model {model_name} as failed after metadata refresh")
|
||||||
|
|
||||||
return True # Return True to indicate a remote download happened
|
return True # Return True to indicate a remote download happened
|
||||||
else:
|
else:
|
||||||
# No civitai data or images available, mark as failed to avoid future attempts
|
# No civitai data or images available, mark as failed to avoid future attempts
|
||||||
download_progress['failed_models'].add(model_hash)
|
self._progress['failed_models'].add(model_hash)
|
||||||
logger.debug(f"No civitai images available for model {model_name}, marking as failed")
|
logger.debug(f"No civitai images available for model {model_name}, marking as failed")
|
||||||
|
|
||||||
# Save progress periodically
|
# Save progress periodically
|
||||||
if download_progress['completed'] % 10 == 0 or download_progress['completed'] == download_progress['total'] - 1:
|
if self._progress['completed'] % 10 == 0 or self._progress['completed'] == self._progress['total'] - 1:
|
||||||
DownloadManager._save_progress(output_dir)
|
self._save_progress(output_dir)
|
||||||
|
|
||||||
return False # Default return if no conditions met
|
return False # Default return if no conditions met
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
|
error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
|
||||||
logger.error(error_msg, exc_info=True)
|
logger.error(error_msg, exc_info=True)
|
||||||
download_progress['errors'].append(error_msg)
|
self._progress['errors'].append(error_msg)
|
||||||
download_progress['last_error'] = error_msg
|
self._progress['last_error'] = error_msg
|
||||||
return False # Return False on exception
|
return False # Return False on exception
|
||||||
|
|
||||||
@staticmethod
|
def _save_progress(self, output_dir):
|
||||||
def _save_progress(output_dir):
|
"""Save download progress to file."""
|
||||||
"""Save download progress to file"""
|
|
||||||
global download_progress
|
|
||||||
try:
|
try:
|
||||||
progress_file = os.path.join(output_dir, '.download_progress.json')
|
progress_file = os.path.join(output_dir, '.download_progress.json')
|
||||||
|
|
||||||
@@ -411,11 +392,11 @@ class DownloadManager:
|
|||||||
|
|
||||||
# Create new progress data
|
# Create new progress data
|
||||||
progress_data = {
|
progress_data = {
|
||||||
'processed_models': list(download_progress['processed_models']),
|
'processed_models': list(self._progress['processed_models']),
|
||||||
'refreshed_models': list(download_progress['refreshed_models']),
|
'refreshed_models': list(self._progress['refreshed_models']),
|
||||||
'failed_models': list(download_progress['failed_models']),
|
'failed_models': list(self._progress['failed_models']),
|
||||||
'completed': download_progress['completed'],
|
'completed': self._progress['completed'],
|
||||||
'total': download_progress['total'],
|
'total': self._progress['total'],
|
||||||
'last_update': time.time()
|
'last_update': time.time()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -430,61 +411,38 @@ class DownloadManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to save progress file: {e}")
|
logger.error(f"Failed to save progress file: {e}")
|
||||||
|
|
||||||
@staticmethod
|
async def start_force_download(self, options: dict):
|
||||||
async def start_force_download(options: dict):
|
"""Force download example images for specific models."""
|
||||||
"""
|
|
||||||
Force download example images for specific models
|
|
||||||
|
|
||||||
Expects a JSON body with:
|
if self._is_downloading:
|
||||||
{
|
raise DownloadInProgressError(self._progress.snapshot())
|
||||||
"model_hashes": ["hash1", "hash2", ...], # List of model hashes to download
|
|
||||||
"optimize": true, # Whether to optimize images (default: true)
|
|
||||||
"model_types": ["lora", "checkpoint"], # Model types to process (default: both)
|
|
||||||
"delay": 1.0 # Delay between downloads (default: 1.0)
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
global download_task, is_downloading, download_progress
|
|
||||||
|
|
||||||
if is_downloading:
|
|
||||||
raise DownloadInProgressError(_serialize_progress())
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = options or {}
|
data = options or {}
|
||||||
model_hashes = data.get('model_hashes', [])
|
model_hashes = data.get('model_hashes', [])
|
||||||
optimize = data.get('optimize', True)
|
optimize = data.get('optimize', True)
|
||||||
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
||||||
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
|
delay = float(data.get('delay', 0.2))
|
||||||
|
|
||||||
if not model_hashes:
|
if not model_hashes:
|
||||||
raise DownloadConfigurationError('Missing model_hashes parameter')
|
raise DownloadConfigurationError('Missing model_hashes parameter')
|
||||||
|
|
||||||
# Get output directory from settings
|
|
||||||
output_dir = settings.get('example_images_path')
|
output_dir = settings.get('example_images_path')
|
||||||
|
|
||||||
if not output_dir:
|
if not output_dir:
|
||||||
raise DownloadConfigurationError('Example images path not configured in settings')
|
raise DownloadConfigurationError('Example images path not configured in settings')
|
||||||
|
|
||||||
# Create the output directory
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
# Initialize progress tracking
|
self._progress.reset()
|
||||||
download_progress['total'] = len(model_hashes)
|
self._progress['total'] = len(model_hashes)
|
||||||
download_progress['completed'] = 0
|
self._progress['status'] = 'running'
|
||||||
download_progress['current_model'] = ''
|
self._progress['start_time'] = time.time()
|
||||||
download_progress['status'] = 'running'
|
self._progress['end_time'] = None
|
||||||
download_progress['errors'] = []
|
|
||||||
download_progress['last_error'] = None
|
|
||||||
download_progress['start_time'] = time.time()
|
|
||||||
download_progress['end_time'] = None
|
|
||||||
download_progress['processed_models'] = set()
|
|
||||||
download_progress['refreshed_models'] = set()
|
|
||||||
download_progress['failed_models'] = set()
|
|
||||||
|
|
||||||
# Set download status to downloading
|
self._is_downloading = True
|
||||||
is_downloading = True
|
|
||||||
|
|
||||||
# Execute the download function directly instead of creating a background task
|
result = await self._download_specific_models_example_images_sync(
|
||||||
result = await DownloadManager._download_specific_models_example_images_sync(
|
|
||||||
model_hashes,
|
model_hashes,
|
||||||
output_dir,
|
output_dir,
|
||||||
optimize,
|
optimize,
|
||||||
@@ -492,8 +450,7 @@ class DownloadManager:
|
|||||||
delay
|
delay
|
||||||
)
|
)
|
||||||
|
|
||||||
# Set download status to not downloading
|
self._is_downloading = False
|
||||||
is_downloading = False
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'success': True,
|
'success': True,
|
||||||
@@ -502,17 +459,13 @@ class DownloadManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Set download status to not downloading
|
self._is_downloading = False
|
||||||
is_downloading = False
|
|
||||||
logger.error(f"Failed during forced example images download: {e}", exc_info=True)
|
logger.error(f"Failed during forced example images download: {e}", exc_info=True)
|
||||||
raise ExampleImagesDownloadError(str(e)) from e
|
raise ExampleImagesDownloadError(str(e)) from e
|
||||||
|
|
||||||
@staticmethod
|
async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay):
|
||||||
async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay):
|
"""Download example images for specific models only - synchronous version."""
|
||||||
"""Download example images for specific models only - synchronous version"""
|
|
||||||
global download_progress
|
|
||||||
|
|
||||||
# Get unified downloader
|
|
||||||
downloader = await get_downloader()
|
downloader = await get_downloader()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -540,14 +493,14 @@ class DownloadManager:
|
|||||||
models_to_process.append((scanner_type, model, scanner))
|
models_to_process.append((scanner_type, model, scanner))
|
||||||
|
|
||||||
# Update total count based on found models
|
# Update total count based on found models
|
||||||
download_progress['total'] = len(models_to_process)
|
self._progress['total'] = len(models_to_process)
|
||||||
logger.debug(f"Found {download_progress['total']} models to process")
|
logger.debug(f"Found {self._progress['total']} models to process")
|
||||||
|
|
||||||
# Send initial progress via WebSocket
|
# Send initial progress via WebSocket
|
||||||
await ws_manager.broadcast({
|
await ws_manager.broadcast({
|
||||||
'type': 'example_images_progress',
|
'type': 'example_images_progress',
|
||||||
'processed': 0,
|
'processed': 0,
|
||||||
'total': download_progress['total'],
|
'total': self._progress['total'],
|
||||||
'status': 'running',
|
'status': 'running',
|
||||||
'current_model': ''
|
'current_model': ''
|
||||||
})
|
})
|
||||||
@@ -556,7 +509,7 @@ class DownloadManager:
|
|||||||
success_count = 0
|
success_count = 0
|
||||||
for i, (scanner_type, model, scanner) in enumerate(models_to_process):
|
for i, (scanner_type, model, scanner) in enumerate(models_to_process):
|
||||||
# Force process this model regardless of previous status
|
# Force process this model regardless of previous status
|
||||||
was_successful = await DownloadManager._process_specific_model(
|
was_successful = await self._process_specific_model(
|
||||||
scanner_type, model, scanner,
|
scanner_type, model, scanner,
|
||||||
output_dir, optimize, downloader
|
output_dir, optimize, downloader
|
||||||
)
|
)
|
||||||
@@ -565,55 +518,55 @@ class DownloadManager:
|
|||||||
success_count += 1
|
success_count += 1
|
||||||
|
|
||||||
# Update progress
|
# Update progress
|
||||||
download_progress['completed'] += 1
|
self._progress['completed'] += 1
|
||||||
|
|
||||||
# Send progress update via WebSocket
|
# Send progress update via WebSocket
|
||||||
await ws_manager.broadcast({
|
await ws_manager.broadcast({
|
||||||
'type': 'example_images_progress',
|
'type': 'example_images_progress',
|
||||||
'processed': download_progress['completed'],
|
'processed': self._progress['completed'],
|
||||||
'total': download_progress['total'],
|
'total': self._progress['total'],
|
||||||
'status': 'running',
|
'status': 'running',
|
||||||
'current_model': download_progress['current_model']
|
'current_model': self._progress['current_model']
|
||||||
})
|
})
|
||||||
|
|
||||||
# Only add delay after remote download, and not after processing the last model
|
# Only add delay after remote download, and not after processing the last model
|
||||||
if was_successful and i < len(models_to_process) - 1 and download_progress['status'] == 'running':
|
if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running':
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
# Mark as completed
|
# Mark as completed
|
||||||
download_progress['status'] = 'completed'
|
self._progress['status'] = 'completed'
|
||||||
download_progress['end_time'] = time.time()
|
self._progress['end_time'] = time.time()
|
||||||
logger.debug(f"Forced example images download completed: {download_progress['completed']}/{download_progress['total']} models processed")
|
logger.debug(f"Forced example images download completed: {self._progress['completed']}/{self._progress['total']} models processed")
|
||||||
|
|
||||||
# Send final progress via WebSocket
|
# Send final progress via WebSocket
|
||||||
await ws_manager.broadcast({
|
await ws_manager.broadcast({
|
||||||
'type': 'example_images_progress',
|
'type': 'example_images_progress',
|
||||||
'processed': download_progress['completed'],
|
'processed': self._progress['completed'],
|
||||||
'total': download_progress['total'],
|
'total': self._progress['total'],
|
||||||
'status': 'completed',
|
'status': 'completed',
|
||||||
'current_model': ''
|
'current_model': ''
|
||||||
})
|
})
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'total': download_progress['total'],
|
'total': self._progress['total'],
|
||||||
'processed': download_progress['completed'],
|
'processed': self._progress['completed'],
|
||||||
'successful': success_count,
|
'successful': success_count,
|
||||||
'errors': download_progress['errors']
|
'errors': self._progress['errors']
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error during forced example images download: {str(e)}"
|
error_msg = f"Error during forced example images download: {str(e)}"
|
||||||
logger.error(error_msg, exc_info=True)
|
logger.error(error_msg, exc_info=True)
|
||||||
download_progress['errors'].append(error_msg)
|
self._progress['errors'].append(error_msg)
|
||||||
download_progress['last_error'] = error_msg
|
self._progress['last_error'] = error_msg
|
||||||
download_progress['status'] = 'error'
|
self._progress['status'] = 'error'
|
||||||
download_progress['end_time'] = time.time()
|
self._progress['end_time'] = time.time()
|
||||||
|
|
||||||
# Send error status via WebSocket
|
# Send error status via WebSocket
|
||||||
await ws_manager.broadcast({
|
await ws_manager.broadcast({
|
||||||
'type': 'example_images_progress',
|
'type': 'example_images_progress',
|
||||||
'processed': download_progress['completed'],
|
'processed': self._progress['completed'],
|
||||||
'total': download_progress['total'],
|
'total': self._progress['total'],
|
||||||
'status': 'error',
|
'status': 'error',
|
||||||
'error': error_msg,
|
'error': error_msg,
|
||||||
'current_model': ''
|
'current_model': ''
|
||||||
@@ -625,18 +578,16 @@ class DownloadManager:
|
|||||||
# No need to close any sessions since we use the global downloader
|
# No need to close any sessions since we use the global downloader
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@staticmethod
|
async def _process_specific_model(self, scanner_type, model, scanner, output_dir, optimize, downloader):
|
||||||
async def _process_specific_model(scanner_type, model, scanner, output_dir, optimize, downloader):
|
"""Process a specific model for forced download, ignoring previous download status."""
|
||||||
"""Process a specific model for forced download, ignoring previous download status"""
|
|
||||||
global download_progress
|
|
||||||
|
|
||||||
# Check if download is paused
|
# Check if download is paused
|
||||||
while download_progress['status'] == 'paused':
|
while self._progress['status'] == 'paused':
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
# Check if download should continue
|
# Check if download should continue
|
||||||
if download_progress['status'] != 'running':
|
if self._progress['status'] != 'running':
|
||||||
logger.info(f"Download stopped: {download_progress['status']}")
|
logger.info(f"Download stopped: {self._progress['status']}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
model_hash = model.get('sha256', '').lower()
|
model_hash = model.get('sha256', '').lower()
|
||||||
@@ -646,7 +597,7 @@ class DownloadManager:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Update current model info
|
# Update current model info
|
||||||
download_progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
||||||
|
|
||||||
# Create model directory
|
# Create model directory
|
||||||
model_dir = os.path.join(output_dir, model_hash)
|
model_dir = os.path.join(output_dir, model_hash)
|
||||||
@@ -662,7 +613,7 @@ class DownloadManager:
|
|||||||
await MetadataUpdater.update_metadata_from_local_examples(
|
await MetadataUpdater.update_metadata_from_local_examples(
|
||||||
model_hash, model, scanner_type, scanner, model_dir
|
model_hash, model, scanner_type, scanner, model_dir
|
||||||
)
|
)
|
||||||
download_progress['processed_models'].add(model_hash)
|
self._progress['processed_models'].add(model_hash)
|
||||||
return False # Return False to indicate no remote download happened
|
return False # Return False to indicate no remote download happened
|
||||||
|
|
||||||
# If no local images, try to download from remote
|
# If no local images, try to download from remote
|
||||||
@@ -674,9 +625,9 @@ class DownloadManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# If metadata is stale, try to refresh it
|
# If metadata is stale, try to refresh it
|
||||||
if is_stale and model_hash not in download_progress['refreshed_models']:
|
if is_stale and model_hash not in self._progress['refreshed_models']:
|
||||||
await MetadataUpdater.refresh_model_metadata(
|
await MetadataUpdater.refresh_model_metadata(
|
||||||
model_hash, model_name, scanner_type, scanner
|
model_hash, model_name, scanner_type, scanner, self._progress
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the updated model data
|
# Get the updated model data
|
||||||
@@ -694,18 +645,18 @@ class DownloadManager:
|
|||||||
# Combine failed images from both attempts
|
# Combine failed images from both attempts
|
||||||
failed_images.extend(additional_failed_images)
|
failed_images.extend(additional_failed_images)
|
||||||
|
|
||||||
download_progress['refreshed_models'].add(model_hash)
|
self._progress['refreshed_models'].add(model_hash)
|
||||||
|
|
||||||
# For forced downloads, remove failed images from metadata
|
# For forced downloads, remove failed images from metadata
|
||||||
if failed_images:
|
if failed_images:
|
||||||
# Create a copy of images excluding failed ones
|
# Create a copy of images excluding failed ones
|
||||||
await DownloadManager._remove_failed_images_from_metadata(
|
await self._remove_failed_images_from_metadata(
|
||||||
model_hash, model_name, failed_images, scanner
|
model_hash, model_name, failed_images, scanner
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mark as processed
|
# Mark as processed
|
||||||
if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones
|
if success or failed_images: # Mark as processed if we successfully downloaded some images or removed failed ones
|
||||||
download_progress['processed_models'].add(model_hash)
|
self._progress['processed_models'].add(model_hash)
|
||||||
|
|
||||||
return True # Return True to indicate a remote download happened
|
return True # Return True to indicate a remote download happened
|
||||||
else:
|
else:
|
||||||
@@ -715,12 +666,11 @@ class DownloadManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
|
error_msg = f"Error processing model {model.get('model_name')}: {str(e)}"
|
||||||
logger.error(error_msg, exc_info=True)
|
logger.error(error_msg, exc_info=True)
|
||||||
download_progress['errors'].append(error_msg)
|
self._progress['errors'].append(error_msg)
|
||||||
download_progress['last_error'] = error_msg
|
self._progress['last_error'] = error_msg
|
||||||
return False # Return False on exception
|
return False # Return False on exception
|
||||||
|
|
||||||
@staticmethod
|
async def _remove_failed_images_from_metadata(self, model_hash, model_name, failed_images, scanner):
|
||||||
async def _remove_failed_images_from_metadata(model_hash, model_name, failed_images, scanner):
|
|
||||||
"""Remove failed images from model metadata"""
|
"""Remove failed images from model metadata"""
|
||||||
try:
|
try:
|
||||||
# Get current model data
|
# Get current model data
|
||||||
@@ -763,3 +713,12 @@ class DownloadManager:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True)
|
logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
default_download_manager = DownloadManager()
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_download_manager() -> DownloadManager:
|
||||||
|
"""Return the singleton download manager used by default routes."""
|
||||||
|
|
||||||
|
return default_download_manager
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ class MetadataUpdater:
|
|||||||
"""Handles updating model metadata related to example images"""
|
"""Handles updating model metadata related to example images"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner):
|
async def refresh_model_metadata(model_hash, model_name, scanner_type, scanner, progress: dict | None = None):
|
||||||
"""Refresh model metadata from CivitAI
|
"""Refresh model metadata from CivitAI
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -45,8 +45,6 @@ class MetadataUpdater:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if metadata was successfully refreshed, False otherwise
|
bool: True if metadata was successfully refreshed, False otherwise
|
||||||
"""
|
"""
|
||||||
from ..utils.example_images_download_manager import download_progress
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Find the model in the scanner cache
|
# Find the model in the scanner cache
|
||||||
cache = await scanner.get_cached_data()
|
cache = await scanner.get_cached_data()
|
||||||
@@ -67,7 +65,8 @@ class MetadataUpdater:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# Track that we're refreshing this model
|
# Track that we're refreshing this model
|
||||||
download_progress['refreshed_models'].add(model_hash)
|
if progress is not None:
|
||||||
|
progress['refreshed_models'].add(model_hash)
|
||||||
|
|
||||||
async def update_cache_func(old_path, new_path, metadata):
|
async def update_cache_func(old_path, new_path, metadata):
|
||||||
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
return await scanner.update_single_model_cache(old_path, new_path, metadata)
|
||||||
@@ -89,8 +88,9 @@ class MetadataUpdater:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error refreshing metadata for {model_name}: {str(e)}"
|
error_msg = f"Error refreshing metadata for {model_name}: {str(e)}"
|
||||||
logger.error(error_msg, exc_info=True)
|
logger.error(error_msg, exc_info=True)
|
||||||
download_progress['errors'].append(error_msg)
|
if progress is not None:
|
||||||
download_progress['last_error'] = error_msg
|
progress['errors'].append(error_msg)
|
||||||
|
progress['last_error'] = error_msg
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user