mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 06:32:12 -03:00
feat(example-images): add stop control for download panel
This commit is contained in:
@@ -22,6 +22,7 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("GET", "/api/lm/example-images-status", "get_example_images_status"),
|
||||
RouteDefinition("POST", "/api/lm/pause-example-images", "pause_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/resume-example-images", "resume_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/stop-example-images", "stop_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/open-example-images-folder", "open_example_images_folder"),
|
||||
RouteDefinition("GET", "/api/lm/example-image-files", "get_example_image_files"),
|
||||
RouteDefinition("GET", "/api/lm/has-example-images", "has_example_images"),
|
||||
|
||||
@@ -68,6 +68,13 @@ class ExampleImagesDownloadHandler:
|
||||
except DownloadNotRunningError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
|
||||
async def stop_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
result = await self._download_manager.stop_download(request)
|
||||
return web.json_response(result)
|
||||
except DownloadNotRunningError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=400)
|
||||
|
||||
async def force_download_example_images(self, request: web.Request) -> web.StreamResponse:
|
||||
try:
|
||||
payload = await request.json()
|
||||
@@ -149,6 +156,7 @@ class ExampleImagesHandlerSet:
|
||||
"get_example_images_status": self.download.get_example_images_status,
|
||||
"pause_example_images": self.download.pause_example_images,
|
||||
"resume_example_images": self.download.resume_example_images,
|
||||
"stop_example_images": self.download.stop_example_images,
|
||||
"force_download_example_images": self.download.force_download_example_images,
|
||||
"import_example_images": self.management.import_example_images,
|
||||
"delete_example_image": self.management.delete_example_image,
|
||||
|
||||
@@ -105,6 +105,7 @@ class DownloadManager:
|
||||
self._progress = _DownloadProgress()
|
||||
self._ws_manager = ws_manager
|
||||
self._state_lock = state_lock or asyncio.Lock()
|
||||
self._stop_requested = False
|
||||
|
||||
def _resolve_output_dir(self, library_name: str | None = None) -> str:
|
||||
base_path = get_settings_manager().get('example_images_path')
|
||||
@@ -145,6 +146,7 @@ class DownloadManager:
|
||||
raise DownloadConfigurationError('Example images path not configured in settings')
|
||||
|
||||
self._progress.reset()
|
||||
self._stop_requested = False
|
||||
self._progress['status'] = 'running'
|
||||
self._progress['start_time'] = time.time()
|
||||
self._progress['end_time'] = None
|
||||
@@ -267,6 +269,27 @@ class DownloadManager:
|
||||
'success': True,
|
||||
'message': 'Download resumed'
|
||||
}
|
||||
|
||||
async def stop_download(self, request):
|
||||
"""Stop the example images download after the current model completes."""
|
||||
|
||||
async with self._state_lock:
|
||||
if not self._is_downloading:
|
||||
raise DownloadNotRunningError()
|
||||
|
||||
if self._progress['status'] in {'completed', 'error', 'stopped'}:
|
||||
raise DownloadNotRunningError()
|
||||
|
||||
if self._progress['status'] != 'stopping':
|
||||
self._stop_requested = True
|
||||
self._progress['status'] = 'stopping'
|
||||
|
||||
await self._broadcast_progress(status='stopping')
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Download stopping'
|
||||
}
|
||||
|
||||
async def _download_all_example_images(
|
||||
self,
|
||||
@@ -311,6 +334,12 @@ class DownloadManager:
|
||||
|
||||
# Process each model
|
||||
for i, (scanner_type, model, scanner) in enumerate(all_models):
|
||||
async with self._state_lock:
|
||||
current_status = self._progress['status']
|
||||
|
||||
if current_status not in {'running', 'paused', 'stopping'}:
|
||||
break
|
||||
|
||||
# Main logic for processing model is here, but actual operations are delegated to other classes
|
||||
was_remote_download = await self._process_model(
|
||||
scanner_type,
|
||||
@@ -321,24 +350,59 @@ class DownloadManager:
|
||||
downloader,
|
||||
library_name,
|
||||
)
|
||||
|
||||
|
||||
# Update progress
|
||||
self._progress['completed'] += 1
|
||||
await self._broadcast_progress(status='running')
|
||||
|
||||
|
||||
async with self._state_lock:
|
||||
current_status = self._progress['status']
|
||||
should_stop = self._stop_requested and current_status == 'stopping'
|
||||
|
||||
broadcast_status = 'running' if current_status == 'running' else current_status
|
||||
await self._broadcast_progress(status=broadcast_status)
|
||||
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
# Only add delay after remote download of models, and not after processing the last model
|
||||
if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running':
|
||||
if (
|
||||
was_remote_download
|
||||
and i < len(all_models) - 1
|
||||
and current_status == 'running'
|
||||
):
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Mark as completed
|
||||
self._progress['status'] = 'completed'
|
||||
self._progress['end_time'] = time.time()
|
||||
logger.debug(
|
||||
"Example images download completed: %s/%s models processed",
|
||||
self._progress['completed'],
|
||||
self._progress['total'],
|
||||
)
|
||||
await self._broadcast_progress(status='completed')
|
||||
|
||||
async with self._state_lock:
|
||||
if self._stop_requested and self._progress['status'] == 'stopping':
|
||||
self._progress['status'] = 'stopped'
|
||||
self._progress['end_time'] = time.time()
|
||||
self._stop_requested = False
|
||||
final_status = 'stopped'
|
||||
elif self._progress['status'] not in {'error', 'stopped'}:
|
||||
self._progress['status'] = 'completed'
|
||||
self._progress['end_time'] = time.time()
|
||||
self._stop_requested = False
|
||||
final_status = 'completed'
|
||||
else:
|
||||
final_status = self._progress['status']
|
||||
self._stop_requested = False
|
||||
if self._progress['end_time'] is None:
|
||||
self._progress['end_time'] = time.time()
|
||||
|
||||
if final_status == 'completed':
|
||||
logger.debug(
|
||||
"Example images download completed: %s/%s models processed",
|
||||
self._progress['completed'],
|
||||
self._progress['total'],
|
||||
)
|
||||
elif final_status == 'stopped':
|
||||
logger.debug(
|
||||
"Example images download stopped: %s/%s models processed",
|
||||
self._progress['completed'],
|
||||
self._progress['total'],
|
||||
)
|
||||
|
||||
await self._broadcast_progress(status=final_status)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error during example images download: {str(e)}"
|
||||
@@ -360,6 +424,7 @@ class DownloadManager:
|
||||
async with self._state_lock:
|
||||
self._is_downloading = False
|
||||
self._download_task = None
|
||||
self._stop_requested = False
|
||||
|
||||
async def _process_model(
|
||||
self,
|
||||
@@ -378,7 +443,7 @@ class DownloadManager:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check if download should continue
|
||||
if self._progress['status'] != 'running':
|
||||
if self._progress['status'] not in {'running', 'stopping'}:
|
||||
logger.info(f"Download stopped: {self._progress['status']}")
|
||||
return False # Return False to indicate no remote download happened
|
||||
|
||||
@@ -567,6 +632,7 @@ class DownloadManager:
|
||||
raise DownloadConfigurationError('Example images path not configured in settings')
|
||||
|
||||
self._progress.reset()
|
||||
self._stop_requested = False
|
||||
self._progress['total'] = len(model_hashes)
|
||||
self._progress['status'] = 'running'
|
||||
self._progress['start_time'] = time.time()
|
||||
@@ -588,10 +654,15 @@ class DownloadManager:
|
||||
|
||||
async with self._state_lock:
|
||||
self._is_downloading = False
|
||||
final_status = self._progress['status']
|
||||
|
||||
message = 'Force download completed'
|
||||
if final_status == 'stopped':
|
||||
message = 'Force download stopped'
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'message': 'Force download completed',
|
||||
'message': message,
|
||||
'result': result
|
||||
}
|
||||
|
||||
@@ -649,6 +720,12 @@ class DownloadManager:
|
||||
# Process each model
|
||||
success_count = 0
|
||||
for i, (scanner_type, model, scanner) in enumerate(models_to_process):
|
||||
async with self._state_lock:
|
||||
current_status = self._progress['status']
|
||||
|
||||
if current_status not in {'running', 'paused', 'stopping'}:
|
||||
break
|
||||
|
||||
# Force process this model regardless of previous status
|
||||
was_successful = await self._process_specific_model(
|
||||
scanner_type,
|
||||
@@ -659,32 +736,65 @@ class DownloadManager:
|
||||
downloader,
|
||||
library_name,
|
||||
)
|
||||
|
||||
|
||||
if was_successful:
|
||||
success_count += 1
|
||||
|
||||
|
||||
# Update progress
|
||||
self._progress['completed'] += 1
|
||||
|
||||
async with self._state_lock:
|
||||
current_status = self._progress['status']
|
||||
should_stop = self._stop_requested and current_status == 'stopping'
|
||||
|
||||
broadcast_status = 'running' if current_status == 'running' else current_status
|
||||
# Send progress update via WebSocket
|
||||
await self._broadcast_progress(status='running')
|
||||
|
||||
await self._broadcast_progress(status=broadcast_status)
|
||||
|
||||
if should_stop:
|
||||
break
|
||||
|
||||
# Only add delay after remote download, and not after processing the last model
|
||||
if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running':
|
||||
if (
|
||||
was_successful
|
||||
and i < len(models_to_process) - 1
|
||||
and current_status == 'running'
|
||||
):
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Mark as completed
|
||||
self._progress['status'] = 'completed'
|
||||
self._progress['end_time'] = time.time()
|
||||
logger.debug(
|
||||
"Forced example images download completed: %s/%s models processed",
|
||||
self._progress['completed'],
|
||||
self._progress['total'],
|
||||
)
|
||||
|
||||
async with self._state_lock:
|
||||
if self._stop_requested and self._progress['status'] == 'stopping':
|
||||
self._progress['status'] = 'stopped'
|
||||
self._progress['end_time'] = time.time()
|
||||
self._stop_requested = False
|
||||
final_status = 'stopped'
|
||||
elif self._progress['status'] not in {'error', 'stopped'}:
|
||||
self._progress['status'] = 'completed'
|
||||
self._progress['end_time'] = time.time()
|
||||
self._stop_requested = False
|
||||
final_status = 'completed'
|
||||
else:
|
||||
final_status = self._progress['status']
|
||||
self._stop_requested = False
|
||||
if self._progress['end_time'] is None:
|
||||
self._progress['end_time'] = time.time()
|
||||
|
||||
if final_status == 'completed':
|
||||
logger.debug(
|
||||
"Forced example images download completed: %s/%s models processed",
|
||||
self._progress['completed'],
|
||||
self._progress['total'],
|
||||
)
|
||||
elif final_status == 'stopped':
|
||||
logger.debug(
|
||||
"Forced example images download stopped: %s/%s models processed",
|
||||
self._progress['completed'],
|
||||
self._progress['total'],
|
||||
)
|
||||
|
||||
# Send final progress via WebSocket
|
||||
await self._broadcast_progress(status='completed')
|
||||
|
||||
await self._broadcast_progress(status=final_status)
|
||||
|
||||
return {
|
||||
'total': self._progress['total'],
|
||||
'processed': self._progress['completed'],
|
||||
@@ -726,7 +836,7 @@ class DownloadManager:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Check if download should continue
|
||||
if self._progress['status'] != 'running':
|
||||
if self._progress['status'] not in {'running', 'stopping'}:
|
||||
logger.info(f"Download stopped: {self._progress['status']}")
|
||||
return False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user