diff --git a/py/services/websocket_manager.py b/py/services/websocket_manager.py index 0d2389a7..223d8b76 100644 --- a/py/services/websocket_manager.py +++ b/py/services/websocket_manager.py @@ -16,6 +16,8 @@ class WebSocketManager: self._download_websockets: Dict[str, web.WebSocketResponse] = {} # New dict for download-specific clients # Add progress tracking dictionary self._download_progress: Dict[str, Dict] = {} + # Cache last initialization progress payloads + self._last_init_progress: Dict[str, Dict] = {} # Add auto-organize progress tracking self._auto_organize_progress: Optional[Dict] = None self._auto_organize_lock = asyncio.Lock() @@ -39,8 +41,10 @@ class WebSocketManager: ws = web.WebSocketResponse() await ws.prepare(request) self._init_websockets.add(ws) - + try: + await self._send_cached_init_progress(ws) + async for msg in ws: if msg.type == web.WSMsgType.ERROR: logger.error(f'Init WebSocket error: {ws.exception()}') @@ -102,23 +106,53 @@ class WebSocketManager: async def broadcast_init_progress(self, data: Dict): """Broadcast initialization progress to connected clients""" + payload = dict(data) if data else {} + + if 'stage' not in payload: + payload['stage'] = 'processing' + if 'progress' not in payload: + payload['progress'] = 0 + if 'details' not in payload: + payload['details'] = 'Processing...' + + key = self._get_init_progress_key(payload) + self._last_init_progress[key] = dict(payload) + if not self._init_websockets: return - - # Ensure data has all required fields - if 'stage' not in data: - data['stage'] = 'processing' - if 'progress' not in data: - data['progress'] = 0 - if 'details' not in data: - data['details'] = 'Processing...' - - for ws in self._init_websockets: + + stale_clients = [] + for ws in list(self._init_websockets): try: - await ws.send_json(data) + await ws.send_json(payload) except Exception as e: logger.error(f"Error sending initialization progress: {e}") - + stale_clients.append(ws) + + for ws in stale_clients: + self._init_websockets.discard(ws) + + async def _send_cached_init_progress(self, ws: web.WebSocketResponse) -> None: + """Send cached initialization progress payloads to a new client""" + if not self._last_init_progress: + return + + for payload in list(self._last_init_progress.values()): + try: + await ws.send_json(payload) + except Exception as e: + logger.debug(f'Error sending cached initialization progress: {e}') + + def _get_init_progress_key(self, data: Dict) -> str: + """Return a stable key for caching initialization progress payloads""" + page_type = data.get('pageType') + if page_type: + return f'page:{page_type}' + scanner_type = data.get('scanner_type') + if scanner_type: + return f'scanner:{scanner_type}' + return 'global' + async def broadcast_download_progress(self, download_id: str, data: Dict): """Send progress update to specific download client""" # Store simplified progress data in memory (only progress percentage) @@ -202,4 +236,5 @@ class WebSocketManager: return str(uuid4()) # Global instance -ws_manager = WebSocketManager() \ No newline at end of file +ws_manager = WebSocketManager() + diff --git a/static/js/managers/SettingsManager.js b/static/js/managers/SettingsManager.js index 34bad67d..d5f3f5d4 100644 --- a/static/js/managers/SettingsManager.js +++ b/static/js/managers/SettingsManager.js @@ -482,13 +482,17 @@ export class SettingsManager { librarySelect.disabled = true; try { + state.loadingManager.showSimpleLoading('Activating library...'); await this.activateLibrary(selectedLibrary); + // Add a short delay before reloading the page + await new Promise(resolve => setTimeout(resolve, 300)); window.location.reload(); } catch (error) { console.error('Failed to activate library:', error); showToast('toast.settings.libraryActivateFailed', { message: error.message }, 'error'); await this.loadLibraries(); } finally { + state.loadingManager.hide(); if (!document.hidden) { librarySelect.disabled = librarySelect.options.length <= 1; } diff --git a/tests/services/test_websocket_manager.py b/tests/services/test_websocket_manager.py index b85c2197..9fc5250e 100644 --- a/tests/services/test_websocket_manager.py +++ b/tests/services/test_websocket_manager.py @@ -36,6 +36,25 @@ async def test_broadcast_init_progress_adds_defaults(manager): ] +async def test_broadcast_init_progress_caches_payload(manager): + await manager.broadcast_init_progress({'pageType': 'loras', 'progress': 42}) + + cached = manager._last_init_progress.get('page:loras') + assert cached is not None + assert cached['progress'] == 42 + assert cached['stage'] == 'processing' + assert cached['details'] == 'Processing...' + + +async def test_send_cached_progress_to_new_client(manager): + await manager.broadcast_init_progress({'pageType': 'loras', 'progress': 87}) + + ws = DummyWebSocket() + await manager._send_cached_init_progress(ws) + + assert ws.messages[-1]['progress'] == 87 + assert ws.messages[-1]['pageType'] == 'loras' + async def test_broadcast_download_progress_tracks_state(manager): ws = DummyWebSocket() download_id = "abc"