Merge pull request #464 from willmiao/codex/refactor-websocket-integration-for-downloading

refactor: align example image downloads with websocket manager
This commit is contained in:
pixelpaws
2025-09-23 14:43:37 +08:00
committed by GitHub
5 changed files with 191 additions and 125 deletions

View File

@@ -166,7 +166,7 @@ class LoraManager:
RecipeRoutes.setup_routes(app) RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app) UpdateRoutes.setup_routes(app)
MiscRoutes.setup_routes(app) MiscRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app) ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
# Setup WebSocket routes that are shared across all model types # Setup WebSocket routes that are shared across all model types
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)

View File

@@ -32,21 +32,24 @@ class ExampleImagesRoutes:
def __init__( def __init__(
self, self,
*, *,
ws_manager,
download_manager: DownloadManager | None = None, download_manager: DownloadManager | None = None,
processor=ExampleImagesProcessor, processor=ExampleImagesProcessor,
file_manager=ExampleImagesFileManager, file_manager=ExampleImagesFileManager,
) -> None: ) -> None:
self._download_manager = download_manager or get_default_download_manager() if ws_manager is None:
raise ValueError("ws_manager is required")
self._download_manager = download_manager or get_default_download_manager(ws_manager)
self._processor = processor self._processor = processor
self._file_manager = file_manager self._file_manager = file_manager
self._handler_set: ExampleImagesHandlerSet | None = None self._handler_set: ExampleImagesHandlerSet | None = None
self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None
@classmethod @classmethod
def setup_routes(cls, app: web.Application) -> None: def setup_routes(cls, app: web.Application, *, ws_manager) -> None:
"""Register routes on the given aiohttp application using default wiring.""" """Register routes on the given aiohttp application using default wiring."""
controller = cls() controller = cls(ws_manager=ws_manager)
controller.register(app) controller.register(app)
def register(self, app: web.Application) -> None: def register(self, app: web.Application) -> None:

View File

@@ -5,11 +5,12 @@ import os
import asyncio import asyncio
import json import json
import time import time
from typing import Any, Dict
from ..services.service_registry import ServiceRegistry from ..services.service_registry import ServiceRegistry
from ..utils.metadata_manager import MetadataManager from ..utils.metadata_manager import MetadataManager
from .example_images_processor import ExampleImagesProcessor from .example_images_processor import ExampleImagesProcessor
from .example_images_metadata import MetadataUpdater from .example_images_metadata import MetadataUpdater
from ..services.websocket_manager import ws_manager # Add this import at the top
from ..services.downloader import get_downloader from ..services.downloader import get_downloader
from ..services.settings_manager import settings from ..services.settings_manager import settings
@@ -76,82 +77,90 @@ class _DownloadProgress(dict):
class DownloadManager: class DownloadManager:
"""Manages downloading example images for models.""" """Manages downloading example images for models."""
def __init__(self) -> None: def __init__(self, *, ws_manager, state_lock: asyncio.Lock | None = None) -> None:
self._download_task: asyncio.Task | None = None self._download_task: asyncio.Task | None = None
self._is_downloading = False self._is_downloading = False
self._progress = _DownloadProgress() self._progress = _DownloadProgress()
self._ws_manager = ws_manager
self._state_lock = state_lock or asyncio.Lock()
async def start_download(self, options: dict): async def start_download(self, options: dict):
"""Start downloading example images for models.""" """Start downloading example images for models."""
if self._is_downloading: async with self._state_lock:
raise DownloadInProgressError(self._progress.snapshot()) if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
try: try:
data = options or {} data = options or {}
auto_mode = data.get('auto_mode', False) auto_mode = data.get('auto_mode', False)
optimize = data.get('optimize', True) optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint']) model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2)) delay = float(data.get('delay', 0.2))
output_dir = settings.get('example_images_path') output_dir = settings.get('example_images_path')
if not output_dir: if not output_dir:
error_msg = 'Example images path not configured in settings' error_msg = 'Example images path not configured in settings'
if auto_mode: if auto_mode:
logger.debug(error_msg) logger.debug(error_msg)
return { return {
'success': True, 'success': True,
'message': 'Example images path not configured, skipping auto download' 'message': 'Example images path not configured, skipping auto download'
} }
raise DownloadConfigurationError(error_msg) raise DownloadConfigurationError(error_msg)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
self._progress.reset() self._progress.reset()
self._progress['status'] = 'running' self._progress['status'] = 'running'
self._progress['start_time'] = time.time() self._progress['start_time'] = time.time()
self._progress['end_time'] = None self._progress['end_time'] = None
progress_file = os.path.join(output_dir, '.download_progress.json') progress_file = os.path.join(output_dir, '.download_progress.json')
if os.path.exists(progress_file): if os.path.exists(progress_file):
try: try:
with open(progress_file, 'r', encoding='utf-8') as f: with open(progress_file, 'r', encoding='utf-8') as f:
saved_progress = json.load(f) saved_progress = json.load(f)
self._progress['processed_models'] = set(saved_progress.get('processed_models', [])) self._progress['processed_models'] = set(saved_progress.get('processed_models', []))
self._progress['failed_models'] = set(saved_progress.get('failed_models', [])) self._progress['failed_models'] = set(saved_progress.get('failed_models', []))
logger.debug( logger.debug(
"Loaded previous progress, %s models already processed, %s models marked as failed", "Loaded previous progress, %s models already processed, %s models marked as failed",
len(self._progress['processed_models']), len(self._progress['processed_models']),
len(self._progress['failed_models']), len(self._progress['failed_models']),
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to load progress file: {e}") logger.error(f"Failed to load progress file: {e}")
self._progress['processed_models'] = set()
self._progress['failed_models'] = set()
else:
self._progress['processed_models'] = set() self._progress['processed_models'] = set()
self._progress['failed_models'] = set() self._progress['failed_models'] = set()
else:
self._progress['processed_models'] = set()
self._progress['failed_models'] = set()
self._is_downloading = True self._is_downloading = True
self._download_task = asyncio.create_task( self._download_task = asyncio.create_task(
self._download_all_example_images( self._download_all_example_images(
output_dir, output_dir,
optimize, optimize,
model_types, model_types,
delay delay
)
) )
)
return { snapshot = self._progress.snapshot()
'success': True, except Exception as e:
'message': 'Download started', self._is_downloading = False
'status': self._progress.snapshot() self._download_task = None
} logger.error(f"Failed to start example images download: {e}", exc_info=True)
raise ExampleImagesDownloadError(str(e)) from e
except Exception as e: await self._broadcast_progress(status='running')
logger.error(f"Failed to start example images download: {e}", exc_info=True)
raise ExampleImagesDownloadError(str(e)) from e return {
'success': True,
'message': 'Download started',
'status': snapshot
}
async def get_status(self, request): async def get_status(self, request):
"""Get the current status of example images download.""" """Get the current status of example images download."""
@@ -165,10 +174,13 @@ class DownloadManager:
async def pause_download(self, request): async def pause_download(self, request):
"""Pause the example images download.""" """Pause the example images download."""
if not self._is_downloading: async with self._state_lock:
raise DownloadNotRunningError() if not self._is_downloading:
raise DownloadNotRunningError()
self._progress['status'] = 'paused' self._progress['status'] = 'paused'
await self._broadcast_progress(status='paused')
return { return {
'success': True, 'success': True,
@@ -178,20 +190,23 @@ class DownloadManager:
async def resume_download(self, request): async def resume_download(self, request):
"""Resume the example images download.""" """Resume the example images download."""
if not self._is_downloading: async with self._state_lock:
raise DownloadNotRunningError() if not self._is_downloading:
raise DownloadNotRunningError()
if self._progress['status'] == 'paused': if self._progress['status'] == 'paused':
self._progress['status'] = 'running' self._progress['status'] = 'running'
else:
raise DownloadNotRunningError(
f"Download is in '{self._progress['status']}' state, cannot resume"
)
return { await self._broadcast_progress(status='running')
'success': True,
'message': 'Download resumed'
}
raise DownloadNotRunningError( return {
f"Download is in '{self._progress['status']}' state, cannot resume" 'success': True,
) 'message': 'Download resumed'
}
async def _download_all_example_images(self, output_dir, optimize, model_types, delay): async def _download_all_example_images(self, output_dir, optimize, model_types, delay):
"""Download example images for all models.""" """Download example images for all models."""
@@ -225,6 +240,7 @@ class DownloadManager:
# Update total count # Update total count
self._progress['total'] = len(all_models) self._progress['total'] = len(all_models)
logger.debug(f"Found {self._progress['total']} models to process") logger.debug(f"Found {self._progress['total']} models to process")
await self._broadcast_progress(status='running')
# Process each model # Process each model
for i, (scanner_type, model, scanner) in enumerate(all_models): for i, (scanner_type, model, scanner) in enumerate(all_models):
@@ -236,6 +252,7 @@ class DownloadManager:
# Update progress # Update progress
self._progress['completed'] += 1 self._progress['completed'] += 1
await self._broadcast_progress(status='running')
# Only add delay after remote download of models, and not after processing the last model # Only add delay after remote download of models, and not after processing the last model
if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running': if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running':
@@ -244,8 +261,13 @@ class DownloadManager:
# Mark as completed # Mark as completed
self._progress['status'] = 'completed' self._progress['status'] = 'completed'
self._progress['end_time'] = time.time() self._progress['end_time'] = time.time()
logger.debug(f"Example images download completed: {self._progress['completed']}/{self._progress['total']} models processed") logger.debug(
"Example images download completed: %s/%s models processed",
self._progress['completed'],
self._progress['total'],
)
await self._broadcast_progress(status='completed')
except Exception as e: except Exception as e:
error_msg = f"Error during example images download: {str(e)}" error_msg = f"Error during example images download: {str(e)}"
logger.error(error_msg, exc_info=True) logger.error(error_msg, exc_info=True)
@@ -253,7 +275,8 @@ class DownloadManager:
self._progress['last_error'] = error_msg self._progress['last_error'] = error_msg
self._progress['status'] = 'error' self._progress['status'] = 'error'
self._progress['end_time'] = time.time() self._progress['end_time'] = time.time()
await self._broadcast_progress(status='error', extra={'error': error_msg})
finally: finally:
# Save final progress to file # Save final progress to file
try: try:
@@ -262,8 +285,9 @@ class DownloadManager:
logger.error(f"Failed to save progress file: {e}") logger.error(f"Failed to save progress file: {e}")
# Set download status to not downloading # Set download status to not downloading
self._is_downloading = False async with self._state_lock:
self._download_task = None self._is_downloading = False
self._download_task = None
async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader): async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader):
"""Process a single model download.""" """Process a single model download."""
@@ -285,6 +309,7 @@ class DownloadManager:
try: try:
# Update current model info # Update current model info
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
await self._broadcast_progress(status='running')
# Skip if already in failed models # Skip if already in failed models
if model_hash in self._progress['failed_models']: if model_hash in self._progress['failed_models']:
@@ -414,10 +439,10 @@ class DownloadManager:
async def start_force_download(self, options: dict): async def start_force_download(self, options: dict):
"""Force download example images for specific models.""" """Force download example images for specific models."""
if self._is_downloading: async with self._state_lock:
raise DownloadInProgressError(self._progress.snapshot()) if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
try:
data = options or {} data = options or {}
model_hashes = data.get('model_hashes', []) model_hashes = data.get('model_hashes', [])
optimize = data.get('optimize', True) optimize = data.get('optimize', True)
@@ -442,6 +467,9 @@ class DownloadManager:
self._is_downloading = True self._is_downloading = True
await self._broadcast_progress(status='running')
try:
result = await self._download_specific_models_example_images_sync( result = await self._download_specific_models_example_images_sync(
model_hashes, model_hashes,
output_dir, output_dir,
@@ -450,7 +478,8 @@ class DownloadManager:
delay delay
) )
self._is_downloading = False async with self._state_lock:
self._is_downloading = False
return { return {
'success': True, 'success': True,
@@ -459,8 +488,10 @@ class DownloadManager:
} }
except Exception as e: except Exception as e:
self._is_downloading = False async with self._state_lock:
self._is_downloading = False
logger.error(f"Failed during forced example images download: {e}", exc_info=True) logger.error(f"Failed during forced example images download: {e}", exc_info=True)
await self._broadcast_progress(status='error', extra={'error': str(e)})
raise ExampleImagesDownloadError(str(e)) from e raise ExampleImagesDownloadError(str(e)) from e
async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay): async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay):
@@ -495,15 +526,9 @@ class DownloadManager:
# Update total count based on found models # Update total count based on found models
self._progress['total'] = len(models_to_process) self._progress['total'] = len(models_to_process)
logger.debug(f"Found {self._progress['total']} models to process") logger.debug(f"Found {self._progress['total']} models to process")
# Send initial progress via WebSocket # Send initial progress via WebSocket
await ws_manager.broadcast({ await self._broadcast_progress(status='running')
'type': 'example_images_progress',
'processed': 0,
'total': self._progress['total'],
'status': 'running',
'current_model': ''
})
# Process each model # Process each model
success_count = 0 success_count = 0
@@ -519,15 +544,9 @@ class DownloadManager:
# Update progress # Update progress
self._progress['completed'] += 1 self._progress['completed'] += 1
# Send progress update via WebSocket # Send progress update via WebSocket
await ws_manager.broadcast({ await self._broadcast_progress(status='running')
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': 'running',
'current_model': self._progress['current_model']
})
# Only add delay after remote download, and not after processing the last model # Only add delay after remote download, and not after processing the last model
if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running': if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running':
@@ -536,16 +555,14 @@ class DownloadManager:
# Mark as completed # Mark as completed
self._progress['status'] = 'completed' self._progress['status'] = 'completed'
self._progress['end_time'] = time.time() self._progress['end_time'] = time.time()
logger.debug(f"Forced example images download completed: {self._progress['completed']}/{self._progress['total']} models processed") logger.debug(
"Forced example images download completed: %s/%s models processed",
self._progress['completed'],
self._progress['total'],
)
# Send final progress via WebSocket # Send final progress via WebSocket
await ws_manager.broadcast({ await self._broadcast_progress(status='completed')
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': 'completed',
'current_model': ''
})
return { return {
'total': self._progress['total'], 'total': self._progress['total'],
@@ -561,16 +578,9 @@ class DownloadManager:
self._progress['last_error'] = error_msg self._progress['last_error'] = error_msg
self._progress['status'] = 'error' self._progress['status'] = 'error'
self._progress['end_time'] = time.time() self._progress['end_time'] = time.time()
# Send error status via WebSocket # Send error status via WebSocket
await ws_manager.broadcast({ await self._broadcast_progress(status='error', extra={'error': error_msg})
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': 'error',
'error': error_msg,
'current_model': ''
})
raise raise
@@ -598,6 +608,7 @@ class DownloadManager:
try: try:
# Update current model info # Update current model info
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})" self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
await self._broadcast_progress(status='running')
# Create model directory # Create model directory
model_dir = os.path.join(output_dir, model_hash) model_dir = os.path.join(output_dir, model_hash)
@@ -714,11 +725,53 @@ class DownloadManager:
except Exception as e: except Exception as e:
logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True) logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True)
async def _broadcast_progress(
self,
*,
status: str | None = None,
extra: Dict[str, Any] | None = None,
) -> None:
payload = self._build_progress_payload(status=status, extra=extra)
try:
await self._ws_manager.broadcast(payload)
except Exception as exc: # pragma: no cover - defensive logging
logger.warning("Failed to broadcast example image progress: %s", exc)
default_download_manager = DownloadManager() def _build_progress_payload(
self,
*,
status: str | None = None,
extra: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
payload: Dict[str, Any] = {
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': status or self._progress['status'],
'current_model': self._progress['current_model'],
}
if self._progress['errors']:
payload['errors'] = list(self._progress['errors'])
if self._progress['last_error']:
payload['last_error'] = self._progress['last_error']
if extra:
payload.update(extra)
return payload
def get_default_download_manager() -> DownloadManager: _default_download_manager: DownloadManager | None = None
def get_default_download_manager(ws_manager) -> DownloadManager:
"""Return the singleton download manager used by default routes.""" """Return the singleton download manager used by default routes."""
return default_download_manager global _default_download_manager
if (
_default_download_manager is None
or getattr(_default_download_manager, "_ws_manager", None) is not ws_manager
):
_default_download_manager = DownloadManager(ws_manager=ws_manager)
return _default_download_manager

View File

@@ -421,7 +421,7 @@ class StandaloneLoraManager(LoraManager):
RecipeRoutes.setup_routes(app) RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app) UpdateRoutes.setup_routes(app)
MiscRoutes.setup_routes(app) MiscRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app) ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
# Setup WebSocket routes that are shared across all model types # Setup WebSocket routes that are shared across all model types
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection) app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import json import json
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, List, Tuple from typing import Any, Dict, List, Tuple
from aiohttp import web from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer from aiohttp.test_utils import TestClient, TestServer
@@ -88,6 +88,14 @@ class StubExampleImagesFileManager:
return web.json_response({"operation": "has_images", "query": dict(request.query)}) return web.json_response({"operation": "has_images", "query": dict(request.query)})
class StubWebSocketManager:
def __init__(self) -> None:
self.broadcast_calls: List[Dict[str, Any]] = []
async def broadcast(self, payload: Dict[str, Any]) -> None:
self.broadcast_calls.append(payload)
@asynccontextmanager @asynccontextmanager
async def example_images_app() -> ExampleImagesHarness: async def example_images_app() -> ExampleImagesHarness:
"""Yield an ExampleImagesRoutes app wired with stubbed collaborators.""" """Yield an ExampleImagesRoutes app wired with stubbed collaborators."""
@@ -95,8 +103,10 @@ async def example_images_app() -> ExampleImagesHarness:
download_manager = StubDownloadManager() download_manager = StubDownloadManager()
processor = StubExampleImagesProcessor() processor = StubExampleImagesProcessor()
file_manager = StubExampleImagesFileManager() file_manager = StubExampleImagesFileManager()
ws_manager = StubWebSocketManager()
controller = ExampleImagesRoutes( controller = ExampleImagesRoutes(
ws_manager=ws_manager,
download_manager=download_manager, download_manager=download_manager,
processor=processor, processor=processor,
file_manager=file_manager, file_manager=file_manager,