Improved server_logic WebSocket Handling

Refined the server_logic function to better manage WebSocket disruptions:

- Integrated `handle_websocket_failure()` directly into `server_logic` to capture and inform about unexpected disconnection scenarios.
- Added conditionals to reset the `latest_image` when the WebSocket connection is deemed inactive.
- Ensured that the error handling mechanism is more resilient to abrupt connection terminations, especially in the context of receiving messages.
This commit is contained in:
TSC
2023-08-10 19:49:29 -05:00
committed by GitHub
parent 60e0b1bcc0
commit 235e6aea01

View File

@@ -524,7 +524,7 @@ import base64
from io import BytesIO from io import BytesIO
from torchvision import transforms from torchvision import transforms
latest_image = None latest_image = list()
connected_client = None connected_client = None
websocket_status = True websocket_status = True
@@ -538,17 +538,26 @@ def handle_websocket_failure():
f"preview is enabled and vae decoding is set to 'true`.") f"preview is enabled and vae decoding is set to 'true`.")
async def server_logic(websocket, path): async def server_logic(websocket, path):
global latest_image, connected_client global latest_image, connected_client, websocket_status
# If websocket_status is False, set latest_image to an empty list
if not websocket_status:
latest_image = list()
# Assign the connected client # Assign the connected client
connected_client = websocket connected_client = websocket
async for message in websocket: try:
# If not a command, treat it as image data async for message in websocket:
if not message.startswith('{'): # If not a command, treat it as image data
image_data = base64.b64decode(message.split(",")[1]) if not message.startswith('{'):
image = Image.open(BytesIO(image_data)) image_data = base64.b64decode(message.split(",")[1])
latest_image = pil2tensor(image) image = Image.open(BytesIO(image_data))
latest_image = pil2tensor(image)
except (websockets.exceptions.ConnectionClosedError, asyncio.exceptions.CancelledError):
handle_websocket_failure()
except Exception:
handle_websocket_failure()
def run_server(): def run_server():
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()