refactor(example-images): inject websocket manager

This commit is contained in:
pixelpaws
2025-09-23 14:40:43 +08:00
parent 49b7126278
commit 43fcce6361
5 changed files with 191 additions and 125 deletions

View File

@@ -166,7 +166,7 @@ class LoraManager:
RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app)
MiscRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
# Setup WebSocket routes that are shared across all model types
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)

View File

@@ -32,21 +32,24 @@ class ExampleImagesRoutes:
def __init__(
self,
*,
ws_manager,
download_manager: DownloadManager | None = None,
processor=ExampleImagesProcessor,
file_manager=ExampleImagesFileManager,
) -> None:
self._download_manager = download_manager or get_default_download_manager()
if ws_manager is None:
raise ValueError("ws_manager is required")
self._download_manager = download_manager or get_default_download_manager(ws_manager)
self._processor = processor
self._file_manager = file_manager
self._handler_set: ExampleImagesHandlerSet | None = None
self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None
@classmethod
def setup_routes(cls, app: web.Application) -> None:
def setup_routes(cls, app: web.Application, *, ws_manager) -> None:
"""Register routes on the given aiohttp application using default wiring."""
controller = cls()
controller = cls(ws_manager=ws_manager)
controller.register(app)
def register(self, app: web.Application) -> None:

View File

@@ -5,11 +5,12 @@ import os
import asyncio
import json
import time
from typing import Any, Dict
from ..services.service_registry import ServiceRegistry
from ..utils.metadata_manager import MetadataManager
from .example_images_processor import ExampleImagesProcessor
from .example_images_metadata import MetadataUpdater
from ..services.websocket_manager import ws_manager # Add this import at the top
from ..services.downloader import get_downloader
from ..services.settings_manager import settings
@@ -76,14 +77,17 @@ class _DownloadProgress(dict):
class DownloadManager:
"""Manages downloading example images for models."""
def __init__(self) -> None:
def __init__(self, *, ws_manager, state_lock: asyncio.Lock | None = None) -> None:
self._download_task: asyncio.Task | None = None
self._is_downloading = False
self._progress = _DownloadProgress()
self._ws_manager = ws_manager
self._state_lock = state_lock or asyncio.Lock()
async def start_download(self, options: dict):
"""Start downloading example images for models."""
async with self._state_lock:
if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
@@ -143,16 +147,21 @@ class DownloadManager:
)
)
snapshot = self._progress.snapshot()
except Exception as e:
self._is_downloading = False
self._download_task = None
logger.error(f"Failed to start example images download: {e}", exc_info=True)
raise ExampleImagesDownloadError(str(e)) from e
await self._broadcast_progress(status='running')
return {
'success': True,
'message': 'Download started',
'status': self._progress.snapshot()
'status': snapshot
}
except Exception as e:
logger.error(f"Failed to start example images download: {e}", exc_info=True)
raise ExampleImagesDownloadError(str(e)) from e
async def get_status(self, request):
"""Get the current status of example images download."""
@@ -165,11 +174,14 @@ class DownloadManager:
async def pause_download(self, request):
"""Pause the example images download."""
async with self._state_lock:
if not self._is_downloading:
raise DownloadNotRunningError()
self._progress['status'] = 'paused'
await self._broadcast_progress(status='paused')
return {
'success': True,
'message': 'Download paused'
@@ -178,21 +190,24 @@ class DownloadManager:
async def resume_download(self, request):
"""Resume the example images download."""
async with self._state_lock:
if not self._is_downloading:
raise DownloadNotRunningError()
if self._progress['status'] == 'paused':
self._progress['status'] = 'running'
else:
raise DownloadNotRunningError(
f"Download is in '{self._progress['status']}' state, cannot resume"
)
await self._broadcast_progress(status='running')
return {
'success': True,
'message': 'Download resumed'
}
raise DownloadNotRunningError(
f"Download is in '{self._progress['status']}' state, cannot resume"
)
async def _download_all_example_images(self, output_dir, optimize, model_types, delay):
"""Download example images for all models."""
@@ -225,6 +240,7 @@ class DownloadManager:
# Update total count
self._progress['total'] = len(all_models)
logger.debug(f"Found {self._progress['total']} models to process")
await self._broadcast_progress(status='running')
# Process each model
for i, (scanner_type, model, scanner) in enumerate(all_models):
@@ -236,6 +252,7 @@ class DownloadManager:
# Update progress
self._progress['completed'] += 1
await self._broadcast_progress(status='running')
# Only add delay after remote download of models, and not after processing the last model
if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running':
@@ -244,7 +261,12 @@ class DownloadManager:
# Mark as completed
self._progress['status'] = 'completed'
self._progress['end_time'] = time.time()
logger.debug(f"Example images download completed: {self._progress['completed']}/{self._progress['total']} models processed")
logger.debug(
"Example images download completed: %s/%s models processed",
self._progress['completed'],
self._progress['total'],
)
await self._broadcast_progress(status='completed')
except Exception as e:
error_msg = f"Error during example images download: {str(e)}"
@@ -253,6 +275,7 @@ class DownloadManager:
self._progress['last_error'] = error_msg
self._progress['status'] = 'error'
self._progress['end_time'] = time.time()
await self._broadcast_progress(status='error', extra={'error': error_msg})
finally:
# Save final progress to file
@@ -262,6 +285,7 @@ class DownloadManager:
logger.error(f"Failed to save progress file: {e}")
# Set download status to not downloading
async with self._state_lock:
self._is_downloading = False
self._download_task = None
@@ -285,6 +309,7 @@ class DownloadManager:
try:
# Update current model info
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
await self._broadcast_progress(status='running')
# Skip if already in failed models
if model_hash in self._progress['failed_models']:
@@ -414,10 +439,10 @@ class DownloadManager:
async def start_force_download(self, options: dict):
"""Force download example images for specific models."""
async with self._state_lock:
if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
try:
data = options or {}
model_hashes = data.get('model_hashes', [])
optimize = data.get('optimize', True)
@@ -442,6 +467,9 @@ class DownloadManager:
self._is_downloading = True
await self._broadcast_progress(status='running')
try:
result = await self._download_specific_models_example_images_sync(
model_hashes,
output_dir,
@@ -450,6 +478,7 @@ class DownloadManager:
delay
)
async with self._state_lock:
self._is_downloading = False
return {
@@ -459,8 +488,10 @@ class DownloadManager:
}
except Exception as e:
async with self._state_lock:
self._is_downloading = False
logger.error(f"Failed during forced example images download: {e}", exc_info=True)
await self._broadcast_progress(status='error', extra={'error': str(e)})
raise ExampleImagesDownloadError(str(e)) from e
async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay):
@@ -497,13 +528,7 @@ class DownloadManager:
logger.debug(f"Found {self._progress['total']} models to process")
# Send initial progress via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': 0,
'total': self._progress['total'],
'status': 'running',
'current_model': ''
})
await self._broadcast_progress(status='running')
# Process each model
success_count = 0
@@ -521,13 +546,7 @@ class DownloadManager:
self._progress['completed'] += 1
# Send progress update via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': 'running',
'current_model': self._progress['current_model']
})
await self._broadcast_progress(status='running')
# Only add delay after remote download, and not after processing the last model
if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running':
@@ -536,16 +555,14 @@ class DownloadManager:
# Mark as completed
self._progress['status'] = 'completed'
self._progress['end_time'] = time.time()
logger.debug(f"Forced example images download completed: {self._progress['completed']}/{self._progress['total']} models processed")
logger.debug(
"Forced example images download completed: %s/%s models processed",
self._progress['completed'],
self._progress['total'],
)
# Send final progress via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': 'completed',
'current_model': ''
})
await self._broadcast_progress(status='completed')
return {
'total': self._progress['total'],
@@ -563,14 +580,7 @@ class DownloadManager:
self._progress['end_time'] = time.time()
# Send error status via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': 'error',
'error': error_msg,
'current_model': ''
})
await self._broadcast_progress(status='error', extra={'error': error_msg})
raise
@@ -598,6 +608,7 @@ class DownloadManager:
try:
# Update current model info
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
await self._broadcast_progress(status='running')
# Create model directory
model_dir = os.path.join(output_dir, model_hash)
@@ -714,11 +725,53 @@ class DownloadManager:
except Exception as e:
logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True)
async def _broadcast_progress(
self,
*,
status: str | None = None,
extra: Dict[str, Any] | None = None,
) -> None:
payload = self._build_progress_payload(status=status, extra=extra)
try:
await self._ws_manager.broadcast(payload)
except Exception as exc: # pragma: no cover - defensive logging
logger.warning("Failed to broadcast example image progress: %s", exc)
default_download_manager = DownloadManager()
def _build_progress_payload(
self,
*,
status: str | None = None,
extra: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
payload: Dict[str, Any] = {
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': status or self._progress['status'],
'current_model': self._progress['current_model'],
}
if self._progress['errors']:
payload['errors'] = list(self._progress['errors'])
if self._progress['last_error']:
payload['last_error'] = self._progress['last_error']
if extra:
payload.update(extra)
return payload
def get_default_download_manager() -> DownloadManager:
_default_download_manager: DownloadManager | None = None
def get_default_download_manager(ws_manager) -> DownloadManager:
"""Return the singleton download manager used by default routes."""
return default_download_manager
global _default_download_manager
if (
_default_download_manager is None
or getattr(_default_download_manager, "_ws_manager", None) is not ws_manager
):
_default_download_manager = DownloadManager(ws_manager=ws_manager)
return _default_download_manager

View File

@@ -421,7 +421,7 @@ class StandaloneLoraManager(LoraManager):
RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app)
MiscRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
# Setup WebSocket routes that are shared across all model types
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import json
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List, Tuple
from typing import Any, Dict, List, Tuple
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
@@ -88,6 +88,14 @@ class StubExampleImagesFileManager:
return web.json_response({"operation": "has_images", "query": dict(request.query)})
class StubWebSocketManager:
def __init__(self) -> None:
self.broadcast_calls: List[Dict[str, Any]] = []
async def broadcast(self, payload: Dict[str, Any]) -> None:
self.broadcast_calls.append(payload)
@asynccontextmanager
async def example_images_app() -> ExampleImagesHarness:
"""Yield an ExampleImagesRoutes app wired with stubbed collaborators."""
@@ -95,8 +103,10 @@ async def example_images_app() -> ExampleImagesHarness:
download_manager = StubDownloadManager()
processor = StubExampleImagesProcessor()
file_manager = StubExampleImagesFileManager()
ws_manager = StubWebSocketManager()
controller = ExampleImagesRoutes(
ws_manager=ws_manager,
download_manager=download_manager,
processor=processor,
file_manager=file_manager,