mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
Merge pull request #464 from willmiao/codex/refactor-websocket-integration-for-downloading
refactor: align example image downloads with websocket manager
This commit is contained in:
@@ -166,7 +166,7 @@ class LoraManager:
|
|||||||
RecipeRoutes.setup_routes(app)
|
RecipeRoutes.setup_routes(app)
|
||||||
UpdateRoutes.setup_routes(app)
|
UpdateRoutes.setup_routes(app)
|
||||||
MiscRoutes.setup_routes(app)
|
MiscRoutes.setup_routes(app)
|
||||||
ExampleImagesRoutes.setup_routes(app)
|
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
|
||||||
|
|
||||||
# Setup WebSocket routes that are shared across all model types
|
# Setup WebSocket routes that are shared across all model types
|
||||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||||
|
|||||||
@@ -32,21 +32,24 @@ class ExampleImagesRoutes:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
ws_manager,
|
||||||
download_manager: DownloadManager | None = None,
|
download_manager: DownloadManager | None = None,
|
||||||
processor=ExampleImagesProcessor,
|
processor=ExampleImagesProcessor,
|
||||||
file_manager=ExampleImagesFileManager,
|
file_manager=ExampleImagesFileManager,
|
||||||
) -> None:
|
) -> None:
|
||||||
self._download_manager = download_manager or get_default_download_manager()
|
if ws_manager is None:
|
||||||
|
raise ValueError("ws_manager is required")
|
||||||
|
self._download_manager = download_manager or get_default_download_manager(ws_manager)
|
||||||
self._processor = processor
|
self._processor = processor
|
||||||
self._file_manager = file_manager
|
self._file_manager = file_manager
|
||||||
self._handler_set: ExampleImagesHandlerSet | None = None
|
self._handler_set: ExampleImagesHandlerSet | None = None
|
||||||
self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None
|
self._handler_mapping: Mapping[str, Callable[[web.Request], web.StreamResponse]] | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setup_routes(cls, app: web.Application) -> None:
|
def setup_routes(cls, app: web.Application, *, ws_manager) -> None:
|
||||||
"""Register routes on the given aiohttp application using default wiring."""
|
"""Register routes on the given aiohttp application using default wiring."""
|
||||||
|
|
||||||
controller = cls()
|
controller = cls(ws_manager=ws_manager)
|
||||||
controller.register(app)
|
controller.register(app)
|
||||||
|
|
||||||
def register(self, app: web.Application) -> None:
|
def register(self, app: web.Application) -> None:
|
||||||
|
|||||||
@@ -5,11 +5,12 @@ import os
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
from ..services.service_registry import ServiceRegistry
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..utils.metadata_manager import MetadataManager
|
from ..utils.metadata_manager import MetadataManager
|
||||||
from .example_images_processor import ExampleImagesProcessor
|
from .example_images_processor import ExampleImagesProcessor
|
||||||
from .example_images_metadata import MetadataUpdater
|
from .example_images_metadata import MetadataUpdater
|
||||||
from ..services.websocket_manager import ws_manager # Add this import at the top
|
|
||||||
from ..services.downloader import get_downloader
|
from ..services.downloader import get_downloader
|
||||||
from ..services.settings_manager import settings
|
from ..services.settings_manager import settings
|
||||||
|
|
||||||
@@ -76,82 +77,90 @@ class _DownloadProgress(dict):
|
|||||||
class DownloadManager:
|
class DownloadManager:
|
||||||
"""Manages downloading example images for models."""
|
"""Manages downloading example images for models."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self, *, ws_manager, state_lock: asyncio.Lock | None = None) -> None:
|
||||||
self._download_task: asyncio.Task | None = None
|
self._download_task: asyncio.Task | None = None
|
||||||
self._is_downloading = False
|
self._is_downloading = False
|
||||||
self._progress = _DownloadProgress()
|
self._progress = _DownloadProgress()
|
||||||
|
self._ws_manager = ws_manager
|
||||||
|
self._state_lock = state_lock or asyncio.Lock()
|
||||||
|
|
||||||
async def start_download(self, options: dict):
|
async def start_download(self, options: dict):
|
||||||
"""Start downloading example images for models."""
|
"""Start downloading example images for models."""
|
||||||
|
|
||||||
if self._is_downloading:
|
async with self._state_lock:
|
||||||
raise DownloadInProgressError(self._progress.snapshot())
|
if self._is_downloading:
|
||||||
|
raise DownloadInProgressError(self._progress.snapshot())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = options or {}
|
data = options or {}
|
||||||
auto_mode = data.get('auto_mode', False)
|
auto_mode = data.get('auto_mode', False)
|
||||||
optimize = data.get('optimize', True)
|
optimize = data.get('optimize', True)
|
||||||
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
model_types = data.get('model_types', ['lora', 'checkpoint'])
|
||||||
delay = float(data.get('delay', 0.2))
|
delay = float(data.get('delay', 0.2))
|
||||||
|
|
||||||
output_dir = settings.get('example_images_path')
|
output_dir = settings.get('example_images_path')
|
||||||
|
|
||||||
if not output_dir:
|
if not output_dir:
|
||||||
error_msg = 'Example images path not configured in settings'
|
error_msg = 'Example images path not configured in settings'
|
||||||
if auto_mode:
|
if auto_mode:
|
||||||
logger.debug(error_msg)
|
logger.debug(error_msg)
|
||||||
return {
|
return {
|
||||||
'success': True,
|
'success': True,
|
||||||
'message': 'Example images path not configured, skipping auto download'
|
'message': 'Example images path not configured, skipping auto download'
|
||||||
}
|
}
|
||||||
raise DownloadConfigurationError(error_msg)
|
raise DownloadConfigurationError(error_msg)
|
||||||
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
self._progress.reset()
|
self._progress.reset()
|
||||||
self._progress['status'] = 'running'
|
self._progress['status'] = 'running'
|
||||||
self._progress['start_time'] = time.time()
|
self._progress['start_time'] = time.time()
|
||||||
self._progress['end_time'] = None
|
self._progress['end_time'] = None
|
||||||
|
|
||||||
progress_file = os.path.join(output_dir, '.download_progress.json')
|
progress_file = os.path.join(output_dir, '.download_progress.json')
|
||||||
if os.path.exists(progress_file):
|
if os.path.exists(progress_file):
|
||||||
try:
|
try:
|
||||||
with open(progress_file, 'r', encoding='utf-8') as f:
|
with open(progress_file, 'r', encoding='utf-8') as f:
|
||||||
saved_progress = json.load(f)
|
saved_progress = json.load(f)
|
||||||
self._progress['processed_models'] = set(saved_progress.get('processed_models', []))
|
self._progress['processed_models'] = set(saved_progress.get('processed_models', []))
|
||||||
self._progress['failed_models'] = set(saved_progress.get('failed_models', []))
|
self._progress['failed_models'] = set(saved_progress.get('failed_models', []))
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Loaded previous progress, %s models already processed, %s models marked as failed",
|
"Loaded previous progress, %s models already processed, %s models marked as failed",
|
||||||
len(self._progress['processed_models']),
|
len(self._progress['processed_models']),
|
||||||
len(self._progress['failed_models']),
|
len(self._progress['failed_models']),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to load progress file: {e}")
|
logger.error(f"Failed to load progress file: {e}")
|
||||||
|
self._progress['processed_models'] = set()
|
||||||
|
self._progress['failed_models'] = set()
|
||||||
|
else:
|
||||||
self._progress['processed_models'] = set()
|
self._progress['processed_models'] = set()
|
||||||
self._progress['failed_models'] = set()
|
self._progress['failed_models'] = set()
|
||||||
else:
|
|
||||||
self._progress['processed_models'] = set()
|
|
||||||
self._progress['failed_models'] = set()
|
|
||||||
|
|
||||||
self._is_downloading = True
|
self._is_downloading = True
|
||||||
self._download_task = asyncio.create_task(
|
self._download_task = asyncio.create_task(
|
||||||
self._download_all_example_images(
|
self._download_all_example_images(
|
||||||
output_dir,
|
output_dir,
|
||||||
optimize,
|
optimize,
|
||||||
model_types,
|
model_types,
|
||||||
delay
|
delay
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
snapshot = self._progress.snapshot()
|
||||||
'success': True,
|
except Exception as e:
|
||||||
'message': 'Download started',
|
self._is_downloading = False
|
||||||
'status': self._progress.snapshot()
|
self._download_task = None
|
||||||
}
|
logger.error(f"Failed to start example images download: {e}", exc_info=True)
|
||||||
|
raise ExampleImagesDownloadError(str(e)) from e
|
||||||
|
|
||||||
except Exception as e:
|
await self._broadcast_progress(status='running')
|
||||||
logger.error(f"Failed to start example images download: {e}", exc_info=True)
|
|
||||||
raise ExampleImagesDownloadError(str(e)) from e
|
return {
|
||||||
|
'success': True,
|
||||||
|
'message': 'Download started',
|
||||||
|
'status': snapshot
|
||||||
|
}
|
||||||
|
|
||||||
async def get_status(self, request):
|
async def get_status(self, request):
|
||||||
"""Get the current status of example images download."""
|
"""Get the current status of example images download."""
|
||||||
@@ -165,10 +174,13 @@ class DownloadManager:
|
|||||||
async def pause_download(self, request):
|
async def pause_download(self, request):
|
||||||
"""Pause the example images download."""
|
"""Pause the example images download."""
|
||||||
|
|
||||||
if not self._is_downloading:
|
async with self._state_lock:
|
||||||
raise DownloadNotRunningError()
|
if not self._is_downloading:
|
||||||
|
raise DownloadNotRunningError()
|
||||||
|
|
||||||
self._progress['status'] = 'paused'
|
self._progress['status'] = 'paused'
|
||||||
|
|
||||||
|
await self._broadcast_progress(status='paused')
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'success': True,
|
'success': True,
|
||||||
@@ -178,20 +190,23 @@ class DownloadManager:
|
|||||||
async def resume_download(self, request):
|
async def resume_download(self, request):
|
||||||
"""Resume the example images download."""
|
"""Resume the example images download."""
|
||||||
|
|
||||||
if not self._is_downloading:
|
async with self._state_lock:
|
||||||
raise DownloadNotRunningError()
|
if not self._is_downloading:
|
||||||
|
raise DownloadNotRunningError()
|
||||||
|
|
||||||
if self._progress['status'] == 'paused':
|
if self._progress['status'] == 'paused':
|
||||||
self._progress['status'] = 'running'
|
self._progress['status'] = 'running'
|
||||||
|
else:
|
||||||
|
raise DownloadNotRunningError(
|
||||||
|
f"Download is in '{self._progress['status']}' state, cannot resume"
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
await self._broadcast_progress(status='running')
|
||||||
'success': True,
|
|
||||||
'message': 'Download resumed'
|
|
||||||
}
|
|
||||||
|
|
||||||
raise DownloadNotRunningError(
|
return {
|
||||||
f"Download is in '{self._progress['status']}' state, cannot resume"
|
'success': True,
|
||||||
)
|
'message': 'Download resumed'
|
||||||
|
}
|
||||||
|
|
||||||
async def _download_all_example_images(self, output_dir, optimize, model_types, delay):
|
async def _download_all_example_images(self, output_dir, optimize, model_types, delay):
|
||||||
"""Download example images for all models."""
|
"""Download example images for all models."""
|
||||||
@@ -225,6 +240,7 @@ class DownloadManager:
|
|||||||
# Update total count
|
# Update total count
|
||||||
self._progress['total'] = len(all_models)
|
self._progress['total'] = len(all_models)
|
||||||
logger.debug(f"Found {self._progress['total']} models to process")
|
logger.debug(f"Found {self._progress['total']} models to process")
|
||||||
|
await self._broadcast_progress(status='running')
|
||||||
|
|
||||||
# Process each model
|
# Process each model
|
||||||
for i, (scanner_type, model, scanner) in enumerate(all_models):
|
for i, (scanner_type, model, scanner) in enumerate(all_models):
|
||||||
@@ -236,6 +252,7 @@ class DownloadManager:
|
|||||||
|
|
||||||
# Update progress
|
# Update progress
|
||||||
self._progress['completed'] += 1
|
self._progress['completed'] += 1
|
||||||
|
await self._broadcast_progress(status='running')
|
||||||
|
|
||||||
# Only add delay after remote download of models, and not after processing the last model
|
# Only add delay after remote download of models, and not after processing the last model
|
||||||
if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running':
|
if was_remote_download and i < len(all_models) - 1 and self._progress['status'] == 'running':
|
||||||
@@ -244,8 +261,13 @@ class DownloadManager:
|
|||||||
# Mark as completed
|
# Mark as completed
|
||||||
self._progress['status'] = 'completed'
|
self._progress['status'] = 'completed'
|
||||||
self._progress['end_time'] = time.time()
|
self._progress['end_time'] = time.time()
|
||||||
logger.debug(f"Example images download completed: {self._progress['completed']}/{self._progress['total']} models processed")
|
logger.debug(
|
||||||
|
"Example images download completed: %s/%s models processed",
|
||||||
|
self._progress['completed'],
|
||||||
|
self._progress['total'],
|
||||||
|
)
|
||||||
|
await self._broadcast_progress(status='completed')
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Error during example images download: {str(e)}"
|
error_msg = f"Error during example images download: {str(e)}"
|
||||||
logger.error(error_msg, exc_info=True)
|
logger.error(error_msg, exc_info=True)
|
||||||
@@ -253,7 +275,8 @@ class DownloadManager:
|
|||||||
self._progress['last_error'] = error_msg
|
self._progress['last_error'] = error_msg
|
||||||
self._progress['status'] = 'error'
|
self._progress['status'] = 'error'
|
||||||
self._progress['end_time'] = time.time()
|
self._progress['end_time'] = time.time()
|
||||||
|
await self._broadcast_progress(status='error', extra={'error': error_msg})
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Save final progress to file
|
# Save final progress to file
|
||||||
try:
|
try:
|
||||||
@@ -262,8 +285,9 @@ class DownloadManager:
|
|||||||
logger.error(f"Failed to save progress file: {e}")
|
logger.error(f"Failed to save progress file: {e}")
|
||||||
|
|
||||||
# Set download status to not downloading
|
# Set download status to not downloading
|
||||||
self._is_downloading = False
|
async with self._state_lock:
|
||||||
self._download_task = None
|
self._is_downloading = False
|
||||||
|
self._download_task = None
|
||||||
|
|
||||||
async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader):
|
async def _process_model(self, scanner_type, model, scanner, output_dir, optimize, downloader):
|
||||||
"""Process a single model download."""
|
"""Process a single model download."""
|
||||||
@@ -285,6 +309,7 @@ class DownloadManager:
|
|||||||
try:
|
try:
|
||||||
# Update current model info
|
# Update current model info
|
||||||
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
||||||
|
await self._broadcast_progress(status='running')
|
||||||
|
|
||||||
# Skip if already in failed models
|
# Skip if already in failed models
|
||||||
if model_hash in self._progress['failed_models']:
|
if model_hash in self._progress['failed_models']:
|
||||||
@@ -414,10 +439,10 @@ class DownloadManager:
|
|||||||
async def start_force_download(self, options: dict):
|
async def start_force_download(self, options: dict):
|
||||||
"""Force download example images for specific models."""
|
"""Force download example images for specific models."""
|
||||||
|
|
||||||
if self._is_downloading:
|
async with self._state_lock:
|
||||||
raise DownloadInProgressError(self._progress.snapshot())
|
if self._is_downloading:
|
||||||
|
raise DownloadInProgressError(self._progress.snapshot())
|
||||||
|
|
||||||
try:
|
|
||||||
data = options or {}
|
data = options or {}
|
||||||
model_hashes = data.get('model_hashes', [])
|
model_hashes = data.get('model_hashes', [])
|
||||||
optimize = data.get('optimize', True)
|
optimize = data.get('optimize', True)
|
||||||
@@ -442,6 +467,9 @@ class DownloadManager:
|
|||||||
|
|
||||||
self._is_downloading = True
|
self._is_downloading = True
|
||||||
|
|
||||||
|
await self._broadcast_progress(status='running')
|
||||||
|
|
||||||
|
try:
|
||||||
result = await self._download_specific_models_example_images_sync(
|
result = await self._download_specific_models_example_images_sync(
|
||||||
model_hashes,
|
model_hashes,
|
||||||
output_dir,
|
output_dir,
|
||||||
@@ -450,7 +478,8 @@ class DownloadManager:
|
|||||||
delay
|
delay
|
||||||
)
|
)
|
||||||
|
|
||||||
self._is_downloading = False
|
async with self._state_lock:
|
||||||
|
self._is_downloading = False
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'success': True,
|
'success': True,
|
||||||
@@ -459,8 +488,10 @@ class DownloadManager:
|
|||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._is_downloading = False
|
async with self._state_lock:
|
||||||
|
self._is_downloading = False
|
||||||
logger.error(f"Failed during forced example images download: {e}", exc_info=True)
|
logger.error(f"Failed during forced example images download: {e}", exc_info=True)
|
||||||
|
await self._broadcast_progress(status='error', extra={'error': str(e)})
|
||||||
raise ExampleImagesDownloadError(str(e)) from e
|
raise ExampleImagesDownloadError(str(e)) from e
|
||||||
|
|
||||||
async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay):
|
async def _download_specific_models_example_images_sync(self, model_hashes, output_dir, optimize, model_types, delay):
|
||||||
@@ -495,15 +526,9 @@ class DownloadManager:
|
|||||||
# Update total count based on found models
|
# Update total count based on found models
|
||||||
self._progress['total'] = len(models_to_process)
|
self._progress['total'] = len(models_to_process)
|
||||||
logger.debug(f"Found {self._progress['total']} models to process")
|
logger.debug(f"Found {self._progress['total']} models to process")
|
||||||
|
|
||||||
# Send initial progress via WebSocket
|
# Send initial progress via WebSocket
|
||||||
await ws_manager.broadcast({
|
await self._broadcast_progress(status='running')
|
||||||
'type': 'example_images_progress',
|
|
||||||
'processed': 0,
|
|
||||||
'total': self._progress['total'],
|
|
||||||
'status': 'running',
|
|
||||||
'current_model': ''
|
|
||||||
})
|
|
||||||
|
|
||||||
# Process each model
|
# Process each model
|
||||||
success_count = 0
|
success_count = 0
|
||||||
@@ -519,15 +544,9 @@ class DownloadManager:
|
|||||||
|
|
||||||
# Update progress
|
# Update progress
|
||||||
self._progress['completed'] += 1
|
self._progress['completed'] += 1
|
||||||
|
|
||||||
# Send progress update via WebSocket
|
# Send progress update via WebSocket
|
||||||
await ws_manager.broadcast({
|
await self._broadcast_progress(status='running')
|
||||||
'type': 'example_images_progress',
|
|
||||||
'processed': self._progress['completed'],
|
|
||||||
'total': self._progress['total'],
|
|
||||||
'status': 'running',
|
|
||||||
'current_model': self._progress['current_model']
|
|
||||||
})
|
|
||||||
|
|
||||||
# Only add delay after remote download, and not after processing the last model
|
# Only add delay after remote download, and not after processing the last model
|
||||||
if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running':
|
if was_successful and i < len(models_to_process) - 1 and self._progress['status'] == 'running':
|
||||||
@@ -536,16 +555,14 @@ class DownloadManager:
|
|||||||
# Mark as completed
|
# Mark as completed
|
||||||
self._progress['status'] = 'completed'
|
self._progress['status'] = 'completed'
|
||||||
self._progress['end_time'] = time.time()
|
self._progress['end_time'] = time.time()
|
||||||
logger.debug(f"Forced example images download completed: {self._progress['completed']}/{self._progress['total']} models processed")
|
logger.debug(
|
||||||
|
"Forced example images download completed: %s/%s models processed",
|
||||||
|
self._progress['completed'],
|
||||||
|
self._progress['total'],
|
||||||
|
)
|
||||||
|
|
||||||
# Send final progress via WebSocket
|
# Send final progress via WebSocket
|
||||||
await ws_manager.broadcast({
|
await self._broadcast_progress(status='completed')
|
||||||
'type': 'example_images_progress',
|
|
||||||
'processed': self._progress['completed'],
|
|
||||||
'total': self._progress['total'],
|
|
||||||
'status': 'completed',
|
|
||||||
'current_model': ''
|
|
||||||
})
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'total': self._progress['total'],
|
'total': self._progress['total'],
|
||||||
@@ -561,16 +578,9 @@ class DownloadManager:
|
|||||||
self._progress['last_error'] = error_msg
|
self._progress['last_error'] = error_msg
|
||||||
self._progress['status'] = 'error'
|
self._progress['status'] = 'error'
|
||||||
self._progress['end_time'] = time.time()
|
self._progress['end_time'] = time.time()
|
||||||
|
|
||||||
# Send error status via WebSocket
|
# Send error status via WebSocket
|
||||||
await ws_manager.broadcast({
|
await self._broadcast_progress(status='error', extra={'error': error_msg})
|
||||||
'type': 'example_images_progress',
|
|
||||||
'processed': self._progress['completed'],
|
|
||||||
'total': self._progress['total'],
|
|
||||||
'status': 'error',
|
|
||||||
'error': error_msg,
|
|
||||||
'current_model': ''
|
|
||||||
})
|
|
||||||
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -598,6 +608,7 @@ class DownloadManager:
|
|||||||
try:
|
try:
|
||||||
# Update current model info
|
# Update current model info
|
||||||
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
self._progress['current_model'] = f"{model_name} ({model_hash[:8]})"
|
||||||
|
await self._broadcast_progress(status='running')
|
||||||
|
|
||||||
# Create model directory
|
# Create model directory
|
||||||
model_dir = os.path.join(output_dir, model_hash)
|
model_dir = os.path.join(output_dir, model_hash)
|
||||||
@@ -714,11 +725,53 @@ class DownloadManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True)
|
logger.error(f"Error removing failed images from metadata for {model_name}: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def _broadcast_progress(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
status: str | None = None,
|
||||||
|
extra: Dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
payload = self._build_progress_payload(status=status, extra=extra)
|
||||||
|
try:
|
||||||
|
await self._ws_manager.broadcast(payload)
|
||||||
|
except Exception as exc: # pragma: no cover - defensive logging
|
||||||
|
logger.warning("Failed to broadcast example image progress: %s", exc)
|
||||||
|
|
||||||
default_download_manager = DownloadManager()
|
def _build_progress_payload(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
status: str | None = None,
|
||||||
|
extra: Dict[str, Any] | None = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
payload: Dict[str, Any] = {
|
||||||
|
'type': 'example_images_progress',
|
||||||
|
'processed': self._progress['completed'],
|
||||||
|
'total': self._progress['total'],
|
||||||
|
'status': status or self._progress['status'],
|
||||||
|
'current_model': self._progress['current_model'],
|
||||||
|
}
|
||||||
|
|
||||||
|
if self._progress['errors']:
|
||||||
|
payload['errors'] = list(self._progress['errors'])
|
||||||
|
if self._progress['last_error']:
|
||||||
|
payload['last_error'] = self._progress['last_error']
|
||||||
|
|
||||||
|
if extra:
|
||||||
|
payload.update(extra)
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
def get_default_download_manager() -> DownloadManager:
|
_default_download_manager: DownloadManager | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_download_manager(ws_manager) -> DownloadManager:
|
||||||
"""Return the singleton download manager used by default routes."""
|
"""Return the singleton download manager used by default routes."""
|
||||||
|
|
||||||
return default_download_manager
|
global _default_download_manager
|
||||||
|
if (
|
||||||
|
_default_download_manager is None
|
||||||
|
or getattr(_default_download_manager, "_ws_manager", None) is not ws_manager
|
||||||
|
):
|
||||||
|
_default_download_manager = DownloadManager(ws_manager=ws_manager)
|
||||||
|
return _default_download_manager
|
||||||
|
|||||||
@@ -421,7 +421,7 @@ class StandaloneLoraManager(LoraManager):
|
|||||||
RecipeRoutes.setup_routes(app)
|
RecipeRoutes.setup_routes(app)
|
||||||
UpdateRoutes.setup_routes(app)
|
UpdateRoutes.setup_routes(app)
|
||||||
MiscRoutes.setup_routes(app)
|
MiscRoutes.setup_routes(app)
|
||||||
ExampleImagesRoutes.setup_routes(app)
|
ExampleImagesRoutes.setup_routes(app, ws_manager=ws_manager)
|
||||||
|
|
||||||
# Setup WebSocket routes that are shared across all model types
|
# Setup WebSocket routes that are shared across all model types
|
||||||
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, List, Tuple
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from aiohttp.test_utils import TestClient, TestServer
|
from aiohttp.test_utils import TestClient, TestServer
|
||||||
@@ -88,6 +88,14 @@ class StubExampleImagesFileManager:
|
|||||||
return web.json_response({"operation": "has_images", "query": dict(request.query)})
|
return web.json_response({"operation": "has_images", "query": dict(request.query)})
|
||||||
|
|
||||||
|
|
||||||
|
class StubWebSocketManager:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.broadcast_calls: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
async def broadcast(self, payload: Dict[str, Any]) -> None:
|
||||||
|
self.broadcast_calls.append(payload)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def example_images_app() -> ExampleImagesHarness:
|
async def example_images_app() -> ExampleImagesHarness:
|
||||||
"""Yield an ExampleImagesRoutes app wired with stubbed collaborators."""
|
"""Yield an ExampleImagesRoutes app wired with stubbed collaborators."""
|
||||||
@@ -95,8 +103,10 @@ async def example_images_app() -> ExampleImagesHarness:
|
|||||||
download_manager = StubDownloadManager()
|
download_manager = StubDownloadManager()
|
||||||
processor = StubExampleImagesProcessor()
|
processor = StubExampleImagesProcessor()
|
||||||
file_manager = StubExampleImagesFileManager()
|
file_manager = StubExampleImagesFileManager()
|
||||||
|
ws_manager = StubWebSocketManager()
|
||||||
|
|
||||||
controller = ExampleImagesRoutes(
|
controller = ExampleImagesRoutes(
|
||||||
|
ws_manager=ws_manager,
|
||||||
download_manager=download_manager,
|
download_manager=download_manager,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
file_manager=file_manager,
|
file_manager=file_manager,
|
||||||
|
|||||||
Reference in New Issue
Block a user