refactor(example-images): inject websocket manager

This commit is contained in:
pixelpaws
2025-09-23 14:40:43 +08:00
parent 49b7126278
commit 43fcce6361
5 changed files with 191 additions and 125 deletions

View File

@@ -5,11 +5,12 @@ import os
import asyncio
import json
import time
from typing import Any, Dict
from ..services.service_registry import ServiceRegistry
from ..utils.metadata_manager import MetadataManager
from .example_images_processor import ExampleImagesProcessor
from .example_images_metadata import MetadataUpdater
from ..services.websocket_manager import ws_manager # Add this import at the top
from ..services.downloader import get_downloader
from ..services.settings_manager import settings
@@ -76,82 +77,90 @@ class _DownloadProgress(dict):
class DownloadManager:
"""Manages downloading example images for models."""
def __init__(self) -> None:
def __init__(self, *, ws_manager, state_lock: asyncio.Lock | None = None) -> None:
self._download_task: asyncio.Task | None = None
self._is_downloading = False
self._progress = _DownloadProgress()
self._ws_manager = ws_manager
self._state_lock = state_lock or asyncio.Lock()
async def start_download(self, options: dict):
"""Start downloading example images for models."""
if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
async with self._state_lock:
if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
try:
data = options or {}
auto_mode = data.get('auto_mode', False)
optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2))
try:
data = options or {}
auto_mode = data.get('auto_mode', False)
optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2))
output_dir = settings.get('example_images_path')
output_dir = settings.get('example_images_path')
if not output_dir:
error_msg = 'Example images path not configured in settings'
if auto_mode:
logger.debug(error_msg)
return {
'success': True,
'message': 'Example images path not configured, skipping auto download'
}
raise DownloadConfigurationError(error_msg)
if not output_dir:
error_msg = 'Example images path not configured in settings'
if auto_mode:
logger.debug(error_msg)
return {
'success': True,
'message': 'Example images path not configured, skipping auto download'
}
raise DownloadConfigurationError(error_msg)
os.makedirs(output_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)
self._progress.reset()
self._progress['status'] = 'running'
self._progress['start_time'] = time.time()
self._progress['end_time'] = None
self._progress.reset()
self._progress['status'] = 'running'
self._progress['start_time'] = time.time()
self._progress['end_time'] = None
progress_file = os.path.join(output_dir, '.download_progress.json')
if os.path.exists(progress_file):
try:
with open(progress_file, 'r', encoding='utf-8') as f:
saved_progress = json.load(f)
self._progress['processed_models'] = set(saved_progress.get('processed_models', []))
self._progress['failed_models'] = set(saved_progress.get('failed_models', []))
logger.debug(
"Loaded previous progress, %s models already processed, %s models marked as failed",
len(self._progress['processed_models']),
len(self._progress['failed_models']),
)
except Exception as e:
logger.error(f"Failed to load progress file: {e}")
progress_file = os.path.join(output_dir, '.download_progress.json')
if os.path.exists(progress_file):
try:
with open(progress_file, 'r', encoding='utf-8') as f:
saved_progress = json.load(f)
self._progress['processed_models'] = set(saved_progress.get('processed_models', []))
self._progress['failed_models'] = set(saved_progress.get('failed_models', []))
logger.debug(
"Loaded previous progress, %s models already processed, %s models marked as failed",
len(self._progress['processed_models']),
len(self._progress['failed_models']),
)
except Exception as e:
logger.error(f"Failed to load progress file: {e}")
self._progress['processed_models'] = set()
self._progress['failed_models'] = set()
else:
self._progress['processed_models'] = set()
self._progress['failed_models'] = set()
else:
self._progress['processed_models'] = set()
self._progress['failed_models'] = set()
self._is_downloading = True
self._download_task = asyncio.create_task(
self._download_all_example_images(
output_dir,
optimize,
model_types,
delay
self._is_downloading = True
self._download_task = asyncio.create_task(
self._download_all_example_images(
output_dir,
optimize,
model_types,
delay
)
)
)
return {
'success': True,
'message': 'Download started',
'status': self._progress.snapshot()
}
snapshot = self._progress.snapshot()
except Exception as e:
self._is_downloading = False
self._download_task = None
logger.error(f"Failed to start example images download: {e}", exc_info=True)
raise ExampleImagesDownloadError(str(e)) from e
except Exception as e:
logger.error(f"Failed to start example images download: {e}", exc_info=True)
raise ExampleImagesDownloadError(str(e)) from e
await self._broadcast_progress(status='running')
return {
'success': True,
'message': 'Download started',
'status': snapshot
}
async def get_status(self, request):
"""Get the current status of example images download."""
@@ -165,10 +174,13 @@ class DownloadManager:
async def pause_download(self, request):
"""Pause the example images download."""
if not self._is_downloading:
raise DownloadNotRunningError()
async with self._state_lock:
if not self._is_downloading:
raise DownloadNotRunningError()
self._progress['status'] = 'paused'
self._progress['status'] = 'paused'
await self._broadcast_progress(status='paused')
return {
'success': True,
@@ -178,20 +190,23 @@ class DownloadManager:
async def resume_download(self, request):
"""Resume the example images download."""
if not self._is_downloading:
raise DownloadNotRunningError()
async with self._state_lock:
if not self._is_downloading:
raise DownloadNotRunningError()
if self._progress['status'] == 'paused':
self._progress['status'] = 'running'
if self._progress['status'] == 'paused':
self._progress['status'] = 'running'
else:
raise DownloadNotRunningError(
f"Download is in '{self._progress['status']}' state, cannot resume"
)
return {
'success': True,
'message': 'Download resumed'
}
await self._broadcast_progress(status='running')
raise DownloadNotRunningError(
f"Download is in '{self._progress['status']}' state, cannot resume"
)
return {
'success': True,
'message': 'Download resumed'
}
async def _download_all_example_images(self, output_dir, optimize, model_types, delay):
"""Download example images for all models."""
@@ -225,6 +240,7 @@ class DownloadManager:
# Update total count
self._progress['total'] = len(all_models)
logger.debug(f"Found {self._progress['total']} models to process")
await self._broadcast_progress(status='running')
# Process each model
for i, (scanner_type, model, scanner) in enumerate(all_models):
@@ -236,6 +252,7 @@ class DownloadManager:
# Update progress
self._progress['completed'] += 1
await self._broadcast_progress(status='running')
# Only add delay after remote download of models, and not after processing the last model
if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running':
@@ -244,8 +261,13 @@ class DownloadManager:
# Mark as completed
self._progress['status'] = 'completed'
self._progress['end_time'] = time.time()
logger.debug(f"Example images download completed: {self._progress['completed']}/{self._progress['total']} models processed")
logger.debug(
"Example images download completed: %s/%s models processed",
self._progress['completed'],
self._progress['total'],
)
await self._broadcast_progress(status='completed')
except Exception as e:
error_msg = f"Error during example images download: {str(e)}"
logger.error(error_msg, exc_info=True)
@@ -253,7 +275,8 @@ class DownloadManager:
self._progress['last_error'] = error_msg
self._progress['status'] = 'error'
self._progress['end_time'] = time.time()
await self._broadcast_progress(status='error', extra={'error': error_msg})
finally:
# Save final progress to file
try:
@@ -262,8 +285,9 @@ class DownloadManager:
logger.error(f"Failed to save progress file: {e}")
# Set download status to not downloading
self._is_downloading = False
self._download_task = None
async with self._state_lock:
self._is_downloading = False
self._download_task = None
async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader):
"""Process a single model download."""
@@ -285,6 +309,7 @@ class DownloadManager:
try:
# Update current model info
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
await self._broadcast_progress(status='running')
# Skip if already in failed models
if model_hash in self._progress['failed_models']:
@@ -414,10 +439,10 @@ class DownloadManager:
async def start_force_download(self, options: dict):
"""Force download example images for specific models."""
if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
async with self._state_lock:
if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
try:
data = options or {}
model_hashes = data.get('model_hashes', [])
optimize = data.get('optimize', True)
@@ -442,6 +467,9 @@ class DownloadManager:
self._is_downloading = True
await self._broadcast_progress(status='running')
try:
result = await self._download_specific_models_example_images_sync(
model_hashes,
output_dir,
@@ -450,7 +478,8 @@ class DownloadManager:
delay
)
self._is_downloading = False
async with self._state_lock:
self._is_downloading = False
return {
'success': True,
@@ -459,8 +488,10 @@ class DownloadManager:
}
except Exception as e:
self._is_downloading = False
async with self._state_lock:
self._is_downloading = False
logger.error(f"Failed during forced example images download: {e}", exc_info=True)
await self._broadcast_progress(status='error', extra={'error': str(e)})
raise ExampleImagesDownloadError(str(e)) from e
async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay):
@@ -495,15 +526,9 @@ class DownloadManager:
# Update total count based on found models
self._progress['total'] = len(models_to_process)
logger.debug(f"Found {self._progress['total']} models to process")
# Send initial progress via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': 0,
'total': self._progress['total'],
'status': 'running',
'current_model': ''
})
await self._broadcast_progress(status='running')
# Process each model
success_count = 0
@@ -519,15 +544,9 @@ class DownloadManager:
# Update progress
self._progress['completed'] += 1
# Send progress update via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': 'running',
'current_model': self._progress['current_model']
})
await self._broadcast_progress(status='running')
# Only add delay after remote download, and not after processing the last model
if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running':
@@ -536,16 +555,14 @@ class DownloadManager:
# Mark as completed
self._progress['status'] = 'completed'
self._progress['end_time'] = time.time()
logger.debug(f"Forced example images download completed: {self._progress['completed']}/{self._progress['total']} models processed")
logger.debug(
"Forced example images download completed: %s/%s models processed",
self._progress['completed'],
self._progress['total'],
)
# Send final progress via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': 'completed',
'current_model': ''
})
await self._broadcast_progress(status='completed')
return {
'total': self._progress['total'],
@@ -561,16 +578,9 @@ class DownloadManager:
self._progress['last_error'] = error_msg
self._progress['status'] = 'error'
self._progress['end_time'] = time.time()
# Send error status via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': 'error',
'error': error_msg,
'current_model': ''
})
await self._broadcast_progress(status='error', extra={'error': error_msg})
raise
@@ -598,6 +608,7 @@ class DownloadManager:
try:
# Update current model info
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
await self._broadcast_progress(status='running')
# Create model directory
model_dir = os.path.join(output_dir, model_hash)
@@ -714,11 +725,53 @@ class DownloadManager:
except Exception as e:
logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True)
async def _broadcast_progress(
self,
*,
status: str | None = None,
extra: Dict[str, Any] | None = None,
) -> None:
payload = self._build_progress_payload(status=status, extra=extra)
try:
await self._ws_manager.broadcast(payload)
except Exception as exc: # pragma: no cover - defensive logging
logger.warning("Failed to broadcast example image progress: %s", exc)
default_download_manager = DownloadManager()
def _build_progress_payload(
self,
*,
status: str | None = None,
extra: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
payload: Dict[str, Any] = {
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': status or self._progress['status'],
'current_model': self._progress['current_model'],
}
if self._progress['errors']:
payload['errors'] = list(self._progress['errors'])
if self._progress['last_error']:
payload['last_error'] = self._progress['last_error']
if extra:
payload.update(extra)
return payload
def get_default_download_manager() -> DownloadManager:
_default_download_manager: DownloadManager | None = None
def get_default_download_manager(ws_manager) -> DownloadManager:
"""Return the singleton download manager used by default routes."""
return default_download_manager
global _default_download_manager
if (
_default_download_manager is None
or getattr(_default_download_manager, "_ws_manager", None) is not ws_manager
):
_default_download_manager = DownloadManager(ws_manager=ws_manager)
return _default_download_manager