Merge pull request #464 from willmiao/codex/refactor-websocket-integration-for-downloading

refactor: align example image downloads with websocket manager
This commit is contained in:
pixelpaws
2025-09-23 14:43:37 +08:00
committed by GitHub
5 changed files with 191 additions and 125 deletions

View File

@@ -166,7 +166,7 @@ class LoraManager:
RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app)
MiscRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
# Setup WebSocket routes that are shared across all model types
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)

View File

@@ -32,21 +32,24 @@ class ExampleImagesRoutes:
def __init__(
self,
*,
ws_manager,
download_manager: DownloadManager | None = None,
processor=ExampleImagesProcessor,
file_manager=ExampleImagesFileManager,
) -> None:
self._download_manager = download_manager or get_default_download_manager()
if ws_manager is None:
raise ValueError("ws_manager is required")
self._download_manager = download_manager or get_default_download_manager(ws_manager)
self._processor = processor
self._file_manager = file_manager
self._handler_set: ExampleImagesHandlerSet | None = None
self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None
@classmethod
def setup_routes(cls, app: web.Application) -> None:
def setup_routes(cls, app: web.Application, *, ws_manager) -> None:
"""Register routes on the given aiohttp application using default wiring."""
controller = cls()
controller = cls(ws_manager=ws_manager)
controller.register(app)
def register(self, app: web.Application) -> None:

View File

@@ -5,11 +5,12 @@ import os
import asyncio
import json
import time
from typing import Any, Dict
from ..services.service_registry import ServiceRegistry
from ..utils.metadata_manager import MetadataManager
from .example_images_processor import ExampleImagesProcessor
from .example_images_metadata import MetadataUpdater
from ..services.websocket_manager import ws_manager # Add this import at the top
from ..services.downloader import get_downloader
from ..services.settings_manager import settings
@@ -76,82 +77,90 @@ class _DownloadProgress(dict):
class DownloadManager:
"""Manages downloading example images for models."""
def __init__(self) -> None:
def __init__(self, *, ws_manager, state_lock: asyncio.Lock | None = None) -> None:
self._download_task: asyncio.Task | None = None
self._is_downloading = False
self._progress = _DownloadProgress()
self._ws_manager = ws_manager
self._state_lock = state_lock or asyncio.Lock()
async def start_download(self, options: dict):
"""Start downloading example images for models."""
if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
async with self._state_lock:
if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
try:
data = options or {}
auto_mode = data.get('auto_mode', False)
optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2))
try:
data = options or {}
auto_mode = data.get('auto_mode', False)
optimize = data.get('optimize', True)
model_types = data.get('model_types', ['lora', 'checkpoint'])
delay = float(data.get('delay', 0.2))
output_dir = settings.get('example_images_path')
output_dir = settings.get('example_images_path')
if not output_dir:
error_msg = 'Example images path not configured in settings'
if auto_mode:
logger.debug(error_msg)
return {
'success': True,
'message': 'Example images path not configured, skipping auto download'
}
raise DownloadConfigurationError(error_msg)
if not output_dir:
error_msg = 'Example images path not configured in settings'
if auto_mode:
logger.debug(error_msg)
return {
'success': True,
'message': 'Example images path not configured, skipping auto download'
}
raise DownloadConfigurationError(error_msg)
os.makedirs(output_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)
self._progress.reset()
self._progress['status'] = 'running'
self._progress['start_time'] = time.time()
self._progress['end_time'] = None
self._progress.reset()
self._progress['status'] = 'running'
self._progress['start_time'] = time.time()
self._progress['end_time'] = None
progress_file = os.path.join(output_dir, '.download_progress.json')
if os.path.exists(progress_file):
try:
with open(progress_file, 'r', encoding='utf-8') as f:
saved_progress = json.load(f)
self._progress['processed_models'] = set(saved_progress.get('processed_models', []))
self._progress['failed_models'] = set(saved_progress.get('failed_models', []))
logger.debug(
"Loaded previous progress, %s models already processed, %s models marked as failed",
len(self._progress['processed_models']),
len(self._progress['failed_models']),
)
except Exception as e:
logger.error(f"Failed to load progress file: {e}")
progress_file = os.path.join(output_dir, '.download_progress.json')
if os.path.exists(progress_file):
try:
with open(progress_file, 'r', encoding='utf-8') as f:
saved_progress = json.load(f)
self._progress['processed_models'] = set(saved_progress.get('processed_models', []))
self._progress['failed_models'] = set(saved_progress.get('failed_models', []))
logger.debug(
"Loaded previous progress, %s models already processed, %s models marked as failed",
len(self._progress['processed_models']),
len(self._progress['failed_models']),
)
except Exception as e:
logger.error(f"Failed to load progress file: {e}")
self._progress['processed_models'] = set()
self._progress['failed_models'] = set()
else:
self._progress['processed_models'] = set()
self._progress['failed_models'] = set()
else:
self._progress['processed_models'] = set()
self._progress['failed_models'] = set()
self._is_downloading = True
self._download_task = asyncio.create_task(
self._download_all_example_images(
output_dir,
optimize,
model_types,
delay
self._is_downloading = True
self._download_task = asyncio.create_task(
self._download_all_example_images(
output_dir,
optimize,
model_types,
delay
)
)
)
return {
'success': True,
'message': 'Download started',
'status': self._progress.snapshot()
}
snapshot = self._progress.snapshot()
except Exception as e:
self._is_downloading = False
self._download_task = None
logger.error(f"Failed to start example images download: {e}", exc_info=True)
raise ExampleImagesDownloadError(str(e)) from e
except Exception as e:
logger.error(f"Failed to start example images download: {e}", exc_info=True)
raise ExampleImagesDownloadError(str(e)) from e
await self._broadcast_progress(status='running')
return {
'success': True,
'message': 'Download started',
'status': snapshot
}
async def get_status(self, request):
"""Get the current status of example images download."""
@@ -165,10 +174,13 @@ class DownloadManager:
async def pause_download(self, request):
"""Pause the example images download."""
if not self._is_downloading:
raise DownloadNotRunningError()
async with self._state_lock:
if not self._is_downloading:
raise DownloadNotRunningError()
self._progress['status'] = 'paused'
self._progress['status'] = 'paused'
await self._broadcast_progress(status='paused')
return {
'success': True,
@@ -178,20 +190,23 @@ class DownloadManager:
async def resume_download(self, request):
"""Resume the example images download."""
if not self._is_downloading:
raise DownloadNotRunningError()
async with self._state_lock:
if not self._is_downloading:
raise DownloadNotRunningError()
if self._progress['status'] == 'paused':
self._progress['status'] = 'running'
if self._progress['status'] == 'paused':
self._progress['status'] = 'running'
else:
raise DownloadNotRunningError(
f"Download is in '{self._progress['status']}' state, cannot resume"
)
return {
'success': True,
'message': 'Download resumed'
}
await self._broadcast_progress(status='running')
raise DownloadNotRunningError(
f"Download is in '{self._progress['status']}' state, cannot resume"
)
return {
'success': True,
'message': 'Download resumed'
}
async def _download_all_example_images(self, output_dir, optimize, model_types, delay):
"""Download example images for all models."""
@@ -225,6 +240,7 @@ class DownloadManager:
# Update total count
self._progress['total'] = len(all_models)
logger.debug(f"Found {self._progress['total']} models to process")
await self._broadcast_progress(status='running')
# Process each model
for i, (scanner_type, model, scanner) in enumerate(all_models):
@@ -236,6 +252,7 @@ class DownloadManager:
# Update progress
self._progress['completed'] += 1
await self._broadcast_progress(status='running')
# Only add delay after remote download of models, and not after processing the last model
if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running':
@@ -244,8 +261,13 @@ class DownloadManager:
# Mark as completed
self._progress['status'] = 'completed'
self._progress['end_time'] = time.time()
logger.debug(f"Example images download completed: {self._progress['completed']}/{self._progress['total']} models processed")
logger.debug(
"Example images download completed: %s/%s models processed",
self._progress['completed'],
self._progress['total'],
)
await self._broadcast_progress(status='completed')
except Exception as e:
error_msg = f"Error during example images download: {str(e)}"
logger.error(error_msg, exc_info=True)
@@ -253,7 +275,8 @@ class DownloadManager:
self._progress['last_error'] = error_msg
self._progress['status'] = 'error'
self._progress['end_time'] = time.time()
await self._broadcast_progress(status='error', extra={'error': error_msg})
finally:
# Save final progress to file
try:
@@ -262,8 +285,9 @@ class DownloadManager:
logger.error(f"Failed to save progress file: {e}")
# Set download status to not downloading
self._is_downloading = False
self._download_task = None
async with self._state_lock:
self._is_downloading = False
self._download_task = None
async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader):
"""Process a single model download."""
@@ -285,6 +309,7 @@ class DownloadManager:
try:
# Update current model info
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
await self._broadcast_progress(status='running')
# Skip if already in failed models
if model_hash in self._progress['failed_models']:
@@ -414,10 +439,10 @@ class DownloadManager:
async def start_force_download(self, options: dict):
"""Force download example images for specific models."""
if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
async with self._state_lock:
if self._is_downloading:
raise DownloadInProgressError(self._progress.snapshot())
try:
data = options or {}
model_hashes = data.get('model_hashes', [])
optimize = data.get('optimize', True)
@@ -442,6 +467,9 @@ class DownloadManager:
self._is_downloading = True
await self._broadcast_progress(status='running')
try:
result = await self._download_specific_models_example_images_sync(
model_hashes,
output_dir,
@@ -450,7 +478,8 @@ class DownloadManager:
delay
)
self._is_downloading = False
async with self._state_lock:
self._is_downloading = False
return {
'success': True,
@@ -459,8 +488,10 @@ class DownloadManager:
}
except Exception as e:
self._is_downloading = False
async with self._state_lock:
self._is_downloading = False
logger.error(f"Failed during forced example images download: {e}", exc_info=True)
await self._broadcast_progress(status='error', extra={'error': str(e)})
raise ExampleImagesDownloadError(str(e)) from e
async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay):
@@ -495,15 +526,9 @@ class DownloadManager:
# Update total count based on found models
self._progress['total'] = len(models_to_process)
logger.debug(f"Found {self._progress['total']} models to process")
# Send initial progress via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': 0,
'total': self._progress['total'],
'status': 'running',
'current_model': ''
})
await self._broadcast_progress(status='running')
# Process each model
success_count = 0
@@ -519,15 +544,9 @@ class DownloadManager:
# Update progress
self._progress['completed'] += 1
# Send progress update via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': 'running',
'current_model': self._progress['current_model']
})
await self._broadcast_progress(status='running')
# Only add delay after remote download, and not after processing the last model
if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running':
@@ -536,16 +555,14 @@ class DownloadManager:
# Mark as completed
self._progress['status'] = 'completed'
self._progress['end_time'] = time.time()
logger.debug(f"Forced example images download completed: {self._progress['completed']}/{self._progress['total']} models processed")
logger.debug(
"Forced example images download completed: %s/%s models processed",
self._progress['completed'],
self._progress['total'],
)
# Send final progress via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': 'completed',
'current_model': ''
})
await self._broadcast_progress(status='completed')
return {
'total': self._progress['total'],
@@ -561,16 +578,9 @@ class DownloadManager:
self._progress['last_error'] = error_msg
self._progress['status'] = 'error'
self._progress['end_time'] = time.time()
# Send error status via WebSocket
await ws_manager.broadcast({
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': 'error',
'error': error_msg,
'current_model': ''
})
await self._broadcast_progress(status='error', extra={'error': error_msg})
raise
@@ -598,6 +608,7 @@ class DownloadManager:
try:
# Update current model info
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
await self._broadcast_progress(status='running')
# Create model directory
model_dir = os.path.join(output_dir, model_hash)
@@ -714,11 +725,53 @@ class DownloadManager:
except Exception as e:
logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True)
async def _broadcast_progress(
self,
*,
status: str | None = None,
extra: Dict[str, Any] | None = None,
) -> None:
payload = self._build_progress_payload(status=status, extra=extra)
try:
await self._ws_manager.broadcast(payload)
except Exception as exc: # pragma: no cover - defensive logging
logger.warning("Failed to broadcast example image progress: %s", exc)
default_download_manager = DownloadManager()
def _build_progress_payload(
self,
*,
status: str | None = None,
extra: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
payload: Dict[str, Any] = {
'type': 'example_images_progress',
'processed': self._progress['completed'],
'total': self._progress['total'],
'status': status or self._progress['status'],
'current_model': self._progress['current_model'],
}
if self._progress['errors']:
payload['errors'] = list(self._progress['errors'])
if self._progress['last_error']:
payload['last_error'] = self._progress['last_error']
if extra:
payload.update(extra)
return payload
def get_default_download_manager() -> DownloadManager:
_default_download_manager: DownloadManager | None = None
def get_default_download_manager(ws_manager) -> DownloadManager:
"""Return the singleton download manager used by default routes."""
return default_download_manager
global _default_download_manager
if (
_default_download_manager is None
or getattr(_default_download_manager, "_ws_manager", None) is not ws_manager
):
_default_download_manager = DownloadManager(ws_manager=ws_manager)
return _default_download_manager

View File

@@ -421,7 +421,7 @@ class StandaloneLoraManager(LoraManager):
RecipeRoutes.setup_routes(app)
UpdateRoutes.setup_routes(app)
MiscRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app)
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
# Setup WebSocket routes that are shared across all model types
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import json
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List, Tuple
from typing import Any, Dict, List, Tuple
from aiohttp import web
from aiohttp.test_utils import TestClient, TestServer
@@ -88,6 +88,14 @@ class StubExampleImagesFileManager:
return web.json_response({"operation": "has_images", "query": dict(request.query)})
class StubWebSocketManager:
def __init__(self) -> None:
self.broadcast_calls: List[Dict[str, Any]] = []
async def broadcast(self, payload: Dict[str, Any]) -> None:
self.broadcast_calls.append(payload)
@asynccontextmanager
async def example_images_app() -> ExampleImagesHarness:
"""Yield an ExampleImagesRoutes app wired with stubbed collaborators."""
@@ -95,8 +103,10 @@ async def example_images_app() -> ExampleImagesHarness:
download_manager = StubDownloadManager()
processor = StubExampleImagesProcessor()
file_manager = StubExampleImagesFileManager()
ws_manager = StubWebSocketManager()
controller = ExampleImagesRoutes(
ws_manager=ws_manager,
download_manager=download_manager,
processor=processor,
file_manager=file_manager,