From 839accf88a2c18eef47c7891106817c6f5d8f552 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 9 Jan 2025 01:28:57 +0000 Subject: [PATCH] Add WebSocket support for command execution and HVNC streaming --- base.py | 55 ++++++++++++++++ endpoint.py | 81 +++++++++++++++++++++++- websocket.py | 174 +++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 308 insertions(+), 2 deletions(-) create mode 100644 websocket.py diff --git a/base.py b/base.py index 059528c..9f677c4 100644 --- a/base.py +++ b/base.py @@ -7,6 +7,8 @@ ApiResponseException, ) import requests.exceptions as re +from .websocket import WebSocketConnection +from typing import Optional, Dict, Callable class ApiConnection: @@ -18,6 +20,7 @@ def __init__(self, base_url: str): base_url = base_url[:-1] self.base_url = base_url self.session = Session() + self._ws_connections: Dict[str, WebSocketConnection] = {} def request(self, method: str, endpoint: str, raw=False, **kwargs): response: Response @@ -53,7 +56,59 @@ def add_header(self, header, value): self.headers[header] = value def close_session(self): + """Close HTTP session and WebSocket connections""" self.session.close() + for ws in self._ws_connections.values(): + ws.disconnect() + self._ws_connections.clear() def _get_headers(self): return self.headers + + def connect_websocket(self, endpoint: str) -> WebSocketConnection: + """Create and connect to a WebSocket endpoint + + Args: + endpoint: WebSocket endpoint path (e.g., '/handler') + + Returns: + WebSocket connection object + """ + # Convert HTTP URL to WebSocket URL + ws_url = self.base_url.replace('http://', 'ws://').replace('https://', 'wss://') + ws_url = f"{ws_url}{endpoint}" + + # Create WebSocket connection + ws = WebSocketConnection(ws_url) + + # Add headers + for header, value in self.headers.items(): + ws.set_header(header, value) + + # Store connection + self._ws_connections[endpoint] = ws + + # Connect + ws.connect() + return ws + + def get_websocket(self, endpoint: str) -> Optional[WebSocketConnection]: + """Get an existing WebSocket connection + + Args: + endpoint: WebSocket endpoint path + + Returns: + WebSocket connection or None if not found + """ + return self._ws_connections.get(endpoint) + + def disconnect_websocket(self, endpoint: str): + """Disconnect a specific WebSocket connection + + Args: + endpoint: WebSocket endpoint path + """ + if endpoint in self._ws_connections: + self._ws_connections[endpoint].disconnect() + del self._ws_connections[endpoint] diff --git a/endpoint.py b/endpoint.py index a8e6a9f..5bc2740 100644 --- a/endpoint.py +++ b/endpoint.py @@ -57,8 +57,85 @@ class OperatorEndpoint(ApiEndpointTemplate): class ListenerEndpoint(ApiEndpointTemplate): + def __init__(self, client: ApiConnection, endpoint: str): + super().__init__(client, endpoint) + self._ws = None + + def connect_ws(self): + """Connect to listener WebSocket endpoint""" + self._ws = self.client.connect_websocket("/listener") + return self._ws + + def disconnect_ws(self): + """Disconnect from listener WebSocket endpoint""" + if self._ws: + self.client.disconnect_websocket("/listener") + self._ws = None + def transmit(self, magick: str, data: str): - return self._post({"magick": magick, "data": data}, path="/transmit", raw=True) + """Transmit data using WebSocket if connected, fallback to HTTP""" + if self._ws: + self._ws.send_message({ + "type": "listener", + "action": "response", + "data": { + "magick": magick, + "payload": data + } + }) + else: + return self._post({"magick": magick, "data": data}, path="/transmit", raw=True) + + def on_message(self, callback): + """Register callback for WebSocket messages""" + if self._ws: + self._ws.add_callback("listener", callback) class HandlerEndpoint(ApiEndpointTemplate): - pass + def __init__(self, client: ApiConnection, endpoint: str): + super().__init__(client, endpoint) + self._ws = None + + def connect_ws(self): + """Connect to handler WebSocket endpoint""" + self._ws = self.client.connect_websocket("/handler") + return self._ws + + def disconnect_ws(self): + """Disconnect from handler WebSocket endpoint""" + if self._ws: + self.client.disconnect_websocket("/handler") + self._ws = None + + def execute_command(self, command: str, args: list = None): + """Execute a command using WebSocket""" + if not self._ws: + raise ApiResponseException("WebSocket not connected") + + self._ws.send_message({ + "type": "command", + "action": "execute", + "data": { + "command": command, + "args": args or [] + } + }) + + def stream_image(self, image_data: str, metadata: dict): + """Stream image data using WebSocket""" + if not self._ws: + raise ApiResponseException("WebSocket not connected") + + self._ws.send_message({ + "type": "image", + "action": "stream", + "data": { + "image_data": image_data, + "metadata": metadata + } + }) + + def on_message(self, callback): + """Register callback for WebSocket messages""" + if self._ws: + self._ws.add_callback("handler", callback) diff --git a/websocket.py b/websocket.py new file mode 100644 index 0000000..6a119d8 --- /dev/null +++ b/websocket.py @@ -0,0 +1,174 @@ +import websockets +import asyncio +import json +from typing import Optional, Dict, Any, Callable +import threading +import queue +import logging + +class WebSocketConnection: + """Manages WebSocket connections with automatic reconnection and message queuing""" + + def __init__(self, url: str): + self._url = url + self._ws: Optional[websockets.WebSocketClientProtocol] = None + self._connected = False + self._should_reconnect = True + self._reconnect_delay = 1 # Initial delay in seconds + self._max_reconnect_delay = 30 + self._message_queue = queue.Queue() + self._send_thread = None + self._receive_thread = None + self._headers: Dict[str, str] = {} + self._callbacks: Dict[str, Callable] = {} + self._lock = threading.Lock() + self._event_loop = None + + def set_header(self, name: str, value: str): + """Set a header for the WebSocket connection""" + self._headers[name] = value + + def connect(self): + """Initialize connection to WebSocket server""" + self._should_reconnect = True + self._start_event_loop() + self._ensure_send_thread() + self._ensure_receive_thread() + + def disconnect(self): + """Disconnect from WebSocket server""" + self._should_reconnect = False + if self._ws: + asyncio.run_coroutine_threadsafe(self._ws.close(), self._event_loop) + + def send_message(self, message: Any): + """Queue a message to be sent + Args: + message: Can be string or dict (will be converted to JSON) + """ + try: + if isinstance(message, dict): + message = json.dumps(message) + self._message_queue.put(message) + except Exception as e: + logging.error(f"Error queueing message: {e}") + + def add_callback(self, event_type: str, callback: Callable): + """Add a callback for a specific event type""" + self._callbacks[event_type] = callback + + def _start_event_loop(self): + """Start the asyncio event loop in a separate thread""" + def run_event_loop(): + self._event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._event_loop) + self._event_loop.run_forever() + + thread = threading.Thread(target=run_event_loop, daemon=True) + thread.start() + + def _ensure_send_thread(self): + """Ensure the send thread is running""" + if not self._send_thread or not self._send_thread.is_alive(): + self._send_thread = threading.Thread( + target=self._message_sender, + daemon=True + ) + self._send_thread.start() + + def _ensure_receive_thread(self): + """Ensure the receive thread is running""" + if not self._receive_thread or not self._receive_thread.is_alive(): + self._receive_thread = threading.Thread( + target=self._message_receiver, + daemon=True + ) + self._receive_thread.start() + + async def _connect(self): + """Establish WebSocket connection""" + try: + extra_headers = [(k, v) for k, v in self._headers.items()] + self._ws = await websockets.connect( + self._url, + extra_headers=extra_headers, + ping_interval=20, + ping_timeout=20 + ) + self._connected = True + logging.info("WebSocket connected") + return True + except Exception as e: + logging.error(f"WebSocket connection error: {e}") + return False + + def _message_sender(self): + """Background thread for sending messages""" + while self._should_reconnect or not self._message_queue.empty(): + try: + message = self._message_queue.get(timeout=1.0) + if self._ws and self._connected: + future = asyncio.run_coroutine_threadsafe( + self._ws.send(message), + self._event_loop + ) + future.result() # Wait for the send to complete + self._message_queue.task_done() + except queue.Empty: + continue + except Exception as e: + logging.error(f"Error in message sender: {e}") + # Put the message back in the queue if send failed + try: + self._message_queue.put(message) + except: + pass + + def _message_receiver(self): + """Background thread for receiving messages""" + while self._should_reconnect: + try: + if not self._connected: + future = asyncio.run_coroutine_threadsafe( + self._connect(), + self._event_loop + ) + if not future.result(): + # Connection failed, wait before retry + delay = min(self._reconnect_delay * 2, self._max_reconnect_delay) + threading.Event().wait(delay) + continue + + # Start receiving messages + while self._connected and self._ws: + try: + future = asyncio.run_coroutine_threadsafe( + self._ws.recv(), + self._event_loop + ) + message = future.result() + self._handle_message(message) + except Exception as e: + logging.error(f"Error receiving message: {e}") + self._connected = False + break + + except Exception as e: + logging.error(f"Error in message receiver: {e}") + self._connected = False + + def _handle_message(self, message: str): + """Handle incoming WebSocket messages""" + try: + data = json.loads(message) + event_type = data.get("type") + if event_type and event_type in self._callbacks: + self._callbacks[event_type](data) + except Exception as e: + logging.error(f"Error handling message: {e}") + + def __del__(self): + """Cleanup when object is destroyed""" + self.disconnect() + if self._event_loop: + self._event_loop.stop() \ No newline at end of file