From 235e6aea015e418b890bc7cb443c5ccd4b3bb3c4 Mon Sep 17 00:00:00 2001 From: TSC <112517630+LucianoCirino@users.noreply.github.com> Date: Thu, 10 Aug 2023 19:49:29 -0500 Subject: [PATCH] Improved server_logic WebSocket Handling Refined the server_logic function to better manage WebSocket disruptions: - Integrated `handle_websocket_failure()` directly into `server_logic` to capture and inform about unexpected disconnection scenarios. - Added conditionals to reset the `latest_image` when the WebSocket connection is deemed inactive. - Ensured that the error handling mechanism is more resilient to abrupt connection terminations, especially in the context of receiving messages. --- tsc_utils.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/tsc_utils.py b/tsc_utils.py index 7a9551c..f2c0ade 100644 --- a/tsc_utils.py +++ b/tsc_utils.py @@ -524,7 +524,7 @@ import base64 from io import BytesIO from torchvision import transforms -latest_image = None +latest_image = list() connected_client = None websocket_status = True @@ -538,17 +538,26 @@ def handle_websocket_failure(): f"preview is enabled and vae decoding is set to 'true`.") async def server_logic(websocket, path): - global latest_image, connected_client + global latest_image, connected_client, websocket_status + + # If websocket_status is False, set latest_image to an empty list + if not websocket_status: + latest_image = list() # Assign the connected client connected_client = websocket - async for message in websocket: - # If not a command, treat it as image data - if not message.startswith('{'): - image_data = base64.b64decode(message.split(",")[1]) - image = Image.open(BytesIO(image_data)) - latest_image = pil2tensor(image) + try: + async for message in websocket: + # If not a command, treat it as image data + if not message.startswith('{'): + image_data = base64.b64decode(message.split(",")[1]) + image = Image.open(BytesIO(image_data)) + latest_image = pil2tensor(image) + except (websockets.exceptions.ConnectionClosedError, asyncio.exceptions.CancelledError): + handle_websocket_failure() + except Exception: + handle_websocket_failure() def run_server(): loop = asyncio.new_event_loop()