Skip to content

Add WebSocket support for command execution and HVNC streaming #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
ApiResponseException,
)
import requests.exceptions as re
from .websocket import WebSocketConnection
from typing import Optional, Dict, Callable


class ApiConnection:
Expand All @@ -18,6 +20,7 @@ def __init__(self, base_url: str):
base_url = base_url[:-1]
self.base_url = base_url
self.session = Session()
self._ws_connections: Dict[str, WebSocketConnection] = {}

def request(self, method: str, endpoint: str, raw=False, **kwargs):
response: Response
Expand Down Expand Up @@ -53,7 +56,59 @@ def add_header(self, header, value):
self.headers[header] = value

def close_session(self):
"""Close HTTP session and WebSocket connections"""
self.session.close()
for ws in self._ws_connections.values():
ws.disconnect()
self._ws_connections.clear()

def _get_headers(self):
return self.headers

def connect_websocket(self, endpoint: str) -> WebSocketConnection:
"""Create and connect to a WebSocket endpoint

Args:
endpoint: WebSocket endpoint path (e.g., '/handler')

Returns:
WebSocket connection object
"""
# Convert HTTP URL to WebSocket URL
ws_url = self.base_url.replace('http://', 'ws://').replace('https://', 'wss://')
ws_url = f"{ws_url}{endpoint}"

# Create WebSocket connection
ws = WebSocketConnection(ws_url)

# Add headers
for header, value in self.headers.items():
ws.set_header(header, value)

# Store connection
self._ws_connections[endpoint] = ws

# Connect
ws.connect()
return ws

def get_websocket(self, endpoint: str) -> Optional[WebSocketConnection]:
"""Get an existing WebSocket connection

Args:
endpoint: WebSocket endpoint path

Returns:
WebSocket connection or None if not found
"""
return self._ws_connections.get(endpoint)

def disconnect_websocket(self, endpoint: str):
"""Disconnect a specific WebSocket connection

Args:
endpoint: WebSocket endpoint path
"""
if endpoint in self._ws_connections:
self._ws_connections[endpoint].disconnect()
del self._ws_connections[endpoint]
81 changes: 79 additions & 2 deletions endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,85 @@ class OperatorEndpoint(ApiEndpointTemplate):


class ListenerEndpoint(ApiEndpointTemplate):
def __init__(self, client: ApiConnection, endpoint: str):
super().__init__(client, endpoint)
self._ws = None

def connect_ws(self):
"""Connect to listener WebSocket endpoint"""
self._ws = self.client.connect_websocket("/listener")
return self._ws

def disconnect_ws(self):
"""Disconnect from listener WebSocket endpoint"""
if self._ws:
self.client.disconnect_websocket("/listener")
self._ws = None

def transmit(self, magick: str, data: str):
return self._post({"magick": magick, "data": data}, path="/transmit", raw=True)
"""Transmit data using WebSocket if connected, fallback to HTTP"""
if self._ws:
self._ws.send_message({
"type": "listener",
"action": "response",
"data": {
"magick": magick,
"payload": data
}
})
else:
return self._post({"magick": magick, "data": data}, path="/transmit", raw=True)

def on_message(self, callback):
"""Register callback for WebSocket messages"""
if self._ws:
self._ws.add_callback("listener", callback)

class HandlerEndpoint(ApiEndpointTemplate):
pass
def __init__(self, client: ApiConnection, endpoint: str):
super().__init__(client, endpoint)
self._ws = None

def connect_ws(self):
"""Connect to handler WebSocket endpoint"""
self._ws = self.client.connect_websocket("/handler")
return self._ws

def disconnect_ws(self):
"""Disconnect from handler WebSocket endpoint"""
if self._ws:
self.client.disconnect_websocket("/handler")
self._ws = None

def execute_command(self, command: str, args: list = None):
"""Execute a command using WebSocket"""
if not self._ws:
raise ApiResponseException("WebSocket not connected")

self._ws.send_message({
"type": "command",
"action": "execute",
"data": {
"command": command,
"args": args or []
}
})

def stream_image(self, image_data: str, metadata: dict):
"""Stream image data using WebSocket"""
if not self._ws:
raise ApiResponseException("WebSocket not connected")

self._ws.send_message({
"type": "image",
"action": "stream",
"data": {
"image_data": image_data,
"metadata": metadata
}
})

def on_message(self, callback):
"""Register callback for WebSocket messages"""
if self._ws:
self._ws.add_callback("handler", callback)
174 changes: 174 additions & 0 deletions websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import websockets
import asyncio
import json
from typing import Optional, Dict, Any, Callable
import threading
import queue
import logging

class WebSocketConnection:
"""Manages WebSocket connections with automatic reconnection and message queuing"""

def __init__(self, url: str):
self._url = url
self._ws: Optional[websockets.WebSocketClientProtocol] = None
self._connected = False
self._should_reconnect = True
self._reconnect_delay = 1 # Initial delay in seconds
self._max_reconnect_delay = 30
self._message_queue = queue.Queue()
self._send_thread = None
self._receive_thread = None
self._headers: Dict[str, str] = {}
self._callbacks: Dict[str, Callable] = {}
self._lock = threading.Lock()
self._event_loop = None

def set_header(self, name: str, value: str):
"""Set a header for the WebSocket connection"""
self._headers[name] = value

def connect(self):
"""Initialize connection to WebSocket server"""
self._should_reconnect = True
self._start_event_loop()
self._ensure_send_thread()
self._ensure_receive_thread()

def disconnect(self):
"""Disconnect from WebSocket server"""
self._should_reconnect = False
if self._ws:
asyncio.run_coroutine_threadsafe(self._ws.close(), self._event_loop)

def send_message(self, message: Any):
"""Queue a message to be sent
Args:
message: Can be string or dict (will be converted to JSON)
"""
try:
if isinstance(message, dict):
message = json.dumps(message)
self._message_queue.put(message)
except Exception as e:
logging.error(f"Error queueing message: {e}")

def add_callback(self, event_type: str, callback: Callable):
"""Add a callback for a specific event type"""
self._callbacks[event_type] = callback

def _start_event_loop(self):
"""Start the asyncio event loop in a separate thread"""
def run_event_loop():
self._event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._event_loop)
self._event_loop.run_forever()

thread = threading.Thread(target=run_event_loop, daemon=True)
thread.start()

def _ensure_send_thread(self):
"""Ensure the send thread is running"""
if not self._send_thread or not self._send_thread.is_alive():
self._send_thread = threading.Thread(
target=self._message_sender,
daemon=True
)
self._send_thread.start()

def _ensure_receive_thread(self):
"""Ensure the receive thread is running"""
if not self._receive_thread or not self._receive_thread.is_alive():
self._receive_thread = threading.Thread(
target=self._message_receiver,
daemon=True
)
self._receive_thread.start()

async def _connect(self):
"""Establish WebSocket connection"""
try:
extra_headers = [(k, v) for k, v in self._headers.items()]
self._ws = await websockets.connect(
self._url,
extra_headers=extra_headers,
ping_interval=20,
ping_timeout=20
)
self._connected = True
logging.info("WebSocket connected")
return True
except Exception as e:
logging.error(f"WebSocket connection error: {e}")
return False

def _message_sender(self):
"""Background thread for sending messages"""
while self._should_reconnect or not self._message_queue.empty():
try:
message = self._message_queue.get(timeout=1.0)
if self._ws and self._connected:
future = asyncio.run_coroutine_threadsafe(
self._ws.send(message),
self._event_loop
)
future.result() # Wait for the send to complete
self._message_queue.task_done()
except queue.Empty:
continue
except Exception as e:
logging.error(f"Error in message sender: {e}")
# Put the message back in the queue if send failed
try:
self._message_queue.put(message)
except:
pass

def _message_receiver(self):
"""Background thread for receiving messages"""
while self._should_reconnect:
try:
if not self._connected:
future = asyncio.run_coroutine_threadsafe(
self._connect(),
self._event_loop
)
if not future.result():
# Connection failed, wait before retry
delay = min(self._reconnect_delay * 2, self._max_reconnect_delay)
threading.Event().wait(delay)
continue

# Start receiving messages
while self._connected and self._ws:
try:
future = asyncio.run_coroutine_threadsafe(
self._ws.recv(),
self._event_loop
)
message = future.result()
self._handle_message(message)
except Exception as e:
logging.error(f"Error receiving message: {e}")
self._connected = False
break

except Exception as e:
logging.error(f"Error in message receiver: {e}")
self._connected = False

def _handle_message(self, message: str):
"""Handle incoming WebSocket messages"""
try:
data = json.loads(message)
event_type = data.get("type")
if event_type and event_type in self._callbacks:
self._callbacks[event_type](data)
except Exception as e:
logging.error(f"Error handling message: {e}")

def __del__(self):
"""Cleanup when object is destroyed"""
self.disconnect()
if self._event_loop:
self._event_loop.stop()