diff --git a/tsc_utils.py b/tsc_utils.py index 7a9551c..f2c0ade 100644 --- a/tsc_utils.py +++ b/tsc_utils.py @@ -524,7 +524,7 @@ import base64 from io import BytesIO from torchvision import transforms -latest_image = None +latest_image = list() connected_client = None websocket_status = True @@ -538,17 +538,26 @@ def handle_websocket_failure(): f"preview is enabled and vae decoding is set to 'true`.") async def server_logic(websocket, path): - global latest_image, connected_client + global latest_image, connected_client, websocket_status + + # If websocket_status is False, set latest_image to an empty list + if not websocket_status: + latest_image = list() # Assign the connected client connected_client = websocket - async for message in websocket: - # If not a command, treat it as image data - if not message.startswith('{'): - image_data = base64.b64decode(message.split(",")[1]) - image = Image.open(BytesIO(image_data)) - latest_image = pil2tensor(image) + try: + async for message in websocket: + # If not a command, treat it as image data + if not message.startswith('{'): + image_data = base64.b64decode(message.split(",")[1]) + image = Image.open(BytesIO(image_data)) + latest_image = pil2tensor(image) + except (websockets.exceptions.ConnectionClosedError, asyncio.exceptions.CancelledError): + handle_websocket_failure() + except Exception: + handle_websocket_failure() def run_server(): loop = asyncio.new_event_loop()