mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
feat(example-images): add use case orchestration
This commit is contained in:
@@ -3,7 +3,6 @@ import os
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from aiohttp import web
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
from ..utils.metadata_manager import MetadataManager
|
||||
from .example_images_processor import ExampleImagesProcessor
|
||||
@@ -12,6 +11,30 @@ from ..services.websocket_manager import ws_manager # Add this import at the to
|
||||
from ..services.downloader import get_downloader
|
||||
from ..services.settings_manager import settings
|
||||
|
||||
|
||||
class ExampleImagesDownloadError(RuntimeError):
|
||||
"""Base error for example image download operations."""
|
||||
|
||||
|
||||
class DownloadInProgressError(ExampleImagesDownloadError):
|
||||
"""Raised when a download is already running."""
|
||||
|
||||
def __init__(self, progress_snapshot: dict) -> None:
|
||||
super().__init__("Download already in progress")
|
||||
self.progress_snapshot = progress_snapshot
|
||||
|
||||
|
||||
class DownloadNotRunningError(ExampleImagesDownloadError):
|
||||
"""Raised when pause/resume is requested without an active download."""
|
||||
|
||||
def __init__(self, message: str = "No download in progress") -> None:
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class DownloadConfigurationError(ExampleImagesDownloadError):
|
||||
"""Raised when configuration prevents starting a download."""
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Download status tracking
|
||||
@@ -31,11 +54,21 @@ download_progress = {
|
||||
'failed_models': set() # Track models that failed to download after metadata refresh
|
||||
}
|
||||
|
||||
|
||||
def _serialize_progress() -> dict:
|
||||
"""Return a JSON-serialisable snapshot of the current progress."""
|
||||
|
||||
snapshot = download_progress.copy()
|
||||
snapshot['processed_models'] = list(download_progress['processed_models'])
|
||||
snapshot['refreshed_models'] = list(download_progress['refreshed_models'])
|
||||
snapshot['failed_models'] = list(download_progress['failed_models'])
|
||||
return snapshot
|
||||
|
||||
class DownloadManager:
|
||||
"""Manages downloading example images for models"""
|
||||
|
||||
@staticmethod
|
||||
async def start_download(request):
|
||||
async def start_download(options: dict):
|
||||
"""
|
||||
Start downloading example images for models
|
||||
|
||||
@@ -50,25 +83,14 @@ class DownloadManager:
|
||||
global download_task, is_downloading, download_progress
|
||||
|
||||
if is_downloading:
|
||||
# Create a copy for JSON serialization
|
||||
response_progress = download_progress.copy()
|
||||
response_progress['processed_models'] = list(download_progress['processed_models'])
|
||||
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
|
||||
response_progress['failed_models'] = list(download_progress['failed_models'])
|
||||
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download already in progress',
|
||||
'status': response_progress
|
||||
}, status=400)
|
||||
|
||||
raise DownloadInProgressError(_serialize_progress())
|
||||
|
||||
try:
|
||||
# Parse the request body
|
||||
data = await request.json()
|
||||
data = options or {}
|
||||
auto_mode = data.get('auto_mode', False)
|
||||
optimize = data.get('optimize', True)
|
||||
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
||||
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
|
||||
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
|
||||
|
||||
# Get output directory from settings
|
||||
output_dir = settings.get('example_images_path')
|
||||
@@ -78,15 +100,11 @@ class DownloadManager:
|
||||
if auto_mode:
|
||||
# For auto mode, just log and return success to avoid showing error toasts
|
||||
logger.debug(error_msg)
|
||||
return web.json_response({
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Example images path not configured, skipping auto download'
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': error_msg
|
||||
}, status=400)
|
||||
}
|
||||
raise DownloadConfigurationError(error_msg)
|
||||
|
||||
# Create the output directory
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
@@ -129,41 +147,29 @@ class DownloadManager:
|
||||
)
|
||||
)
|
||||
|
||||
# Create a copy for JSON serialization
|
||||
response_progress = download_progress.copy()
|
||||
response_progress['processed_models'] = list(download_progress['processed_models'])
|
||||
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
|
||||
response_progress['failed_models'] = list(download_progress['failed_models'])
|
||||
|
||||
return web.json_response({
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Download started',
|
||||
'status': response_progress
|
||||
})
|
||||
|
||||
'status': _serialize_progress()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start example images download: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
raise ExampleImagesDownloadError(str(e)) from e
|
||||
|
||||
@staticmethod
|
||||
async def get_status(request):
|
||||
"""Get the current status of example images download"""
|
||||
global download_progress
|
||||
|
||||
|
||||
# Create a copy of the progress dict with the set converted to a list for JSON serialization
|
||||
response_progress = download_progress.copy()
|
||||
response_progress['processed_models'] = list(download_progress['processed_models'])
|
||||
response_progress['refreshed_models'] = list(download_progress['refreshed_models'])
|
||||
response_progress['failed_models'] = list(download_progress['failed_models'])
|
||||
|
||||
return web.json_response({
|
||||
response_progress = _serialize_progress()
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'is_downloading': is_downloading,
|
||||
'status': response_progress
|
||||
})
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def pause_download(request):
|
||||
@@ -171,17 +177,14 @@ class DownloadManager:
|
||||
global download_progress
|
||||
|
||||
if not is_downloading:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No download in progress'
|
||||
}, status=400)
|
||||
|
||||
raise DownloadNotRunningError()
|
||||
|
||||
download_progress['status'] = 'paused'
|
||||
|
||||
return web.json_response({
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Download paused'
|
||||
})
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def resume_download(request):
|
||||
@@ -189,23 +192,19 @@ class DownloadManager:
|
||||
global download_progress
|
||||
|
||||
if not is_downloading:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'No download in progress'
|
||||
}, status=400)
|
||||
|
||||
raise DownloadNotRunningError()
|
||||
|
||||
if download_progress['status'] == 'paused':
|
||||
download_progress['status'] = 'running'
|
||||
|
||||
return web.json_response({
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Download resumed'
|
||||
})
|
||||
else:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': f"Download is in '{download_progress['status']}' state, cannot resume"
|
||||
}, status=400)
|
||||
}
|
||||
|
||||
raise DownloadNotRunningError(
|
||||
f"Download is in '{download_progress['status']}' state, cannot resume"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _download_all_example_images(output_dir, optimize, model_types, delay):
|
||||
@@ -432,7 +431,7 @@ class DownloadManager:
|
||||
logger.error(f"Failed to save progress file: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def start_force_download(request):
|
||||
async def start_force_download(options: dict):
|
||||
"""
|
||||
Force download example images for specific models
|
||||
|
||||
@@ -447,33 +446,23 @@ class DownloadManager:
|
||||
global download_task, is_downloading, download_progress
|
||||
|
||||
if is_downloading:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Download already in progress'
|
||||
}, status=400)
|
||||
raise DownloadInProgressError(_serialize_progress())
|
||||
|
||||
try:
|
||||
# Parse the request body
|
||||
data = await request.json()
|
||||
data = options or {}
|
||||
model_hashes = data.get('model_hashes', [])
|
||||
optimize = data.get('optimize', True)
|
||||
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
||||
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
|
||||
|
||||
delay = float(data.get('delay', 0.2)) # Default to 0.2 seconds
|
||||
|
||||
if not model_hashes:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Missing model_hashes parameter'
|
||||
}, status=400)
|
||||
|
||||
raise DownloadConfigurationError('Missing model_hashes parameter')
|
||||
|
||||
# Get output directory from settings
|
||||
output_dir = settings.get('example_images_path')
|
||||
|
||||
|
||||
if not output_dir:
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': 'Example images path not configured in settings'
|
||||
}, status=400)
|
||||
raise DownloadConfigurationError('Example images path not configured in settings')
|
||||
|
||||
# Create the output directory
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
@@ -506,20 +495,17 @@ class DownloadManager:
|
||||
# Set download status to not downloading
|
||||
is_downloading = False
|
||||
|
||||
return web.json_response({
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Force download completed',
|
||||
'result': result
|
||||
})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# Set download status to not downloading
|
||||
is_downloading = False
|
||||
logger.error(f"Failed during forced example images download: {e}", exc_info=True)
|
||||
return web.json_response({
|
||||
'success': False,
|
||||
'error': str(e)
|
||||
}, status=500)
|
||||
raise ExampleImagesDownloadError(str(e)) from e
|
||||
|
||||
@staticmethod
|
||||
async def _download_specific_models_example_images_sync(model_hashes, output_dir, optimize, model_types, delay):
|
||||
|
||||
Reference in New Issue
Block a user