Skip to content

Commit

Permalink
Refactor readers to allow multiframe clienthello (#251)
Browse files Browse the repository at this point in the history
* Refactor readers to allow multiframe clienthello

* lint

* limit payload_reader

* Change typeddict to attrs
  • Loading branch information
ludeeus authored Apr 23, 2024
1 parent 00f559d commit bb27a71
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 51 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ convention = "pep257"
[tool.ruff.lint.pylint]
max-args = 15
max-branches = 30
max-returns = 7
max-returns = 8
max-statements = 80

[tool.setuptools]
Expand Down
4 changes: 4 additions & 0 deletions snitun/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class ParseSNIError(SniTunError):
"""Invalid ClientHello data."""


class ParseSNIIncompleteError(ParseSNIError):
"""Incomplete ClientHello data."""


class MultiplexerTransportError(SniTunError):
"""Raise if multiplexer have an problem with peer."""

Expand Down
5 changes: 2 additions & 3 deletions snitun/server/listener_sni.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
ParseSNIError,
)
from ..multiplexer.core import Multiplexer
from ..utils.server import MAX_READ_SIZE
from .peer_manager import PeerManager
from .sni import parse_tls_sni
from .sni import parse_tls_sni, payload_reader

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -64,7 +63,7 @@ async def handle_connection(
if data is None:
try:
async with async_timeout.timeout(2):
client_hello = await reader.read(MAX_READ_SIZE)
client_hello = await payload_reader(reader)
except asyncio.TimeoutError:
_LOGGER.warning("Abort SNI handshake")
writer.close()
Expand Down
70 changes: 54 additions & 16 deletions snitun/server/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from typing import Awaitable, Iterable

import async_timeout
import attr

from ..exceptions import ParseSNIIncompleteError
from ..utils.server import MAX_READ_SIZE
from .listener_peer import PeerListener
from .listener_sni import SNIProxy
Expand All @@ -28,6 +30,14 @@
WORKER_STALE_MAX = 30


@attr.s(slots=True)
class PartialData:
"""Partial data class."""

data: bytes = attr.ib(default=b"")
count: int = attr.ib(default=1)


class SniTunServer:
"""SniTunServer helper class for Dual port Asyncio."""

Expand Down Expand Up @@ -224,6 +234,19 @@ def run(self) -> None:
connections: dict[int, socket.socket] = {}
worker_lb = cycle(self._workers)
stale: dict[int, int] = {}
partial: dict[int, PartialData] = {}

def _register_parial(fileno: int, data: bytes) -> bool:
"""Register partial data."""
if current := partial.get(fileno):
current.data = data
current.count += 1
else:
partial[fileno] = current = PartialData(data=data)

return len(current.data) < MAX_READ_SIZE and current.count < 4

_LOGGER.warning("Server started, fd: %s", fd_server)

while self._running:
events = self._poller.poll(1)
Expand All @@ -242,14 +265,21 @@ def run(self) -> None:

# Read hello & forward to worker
elif event & select.EPOLLIN:
partialdata = partial[fileno].data if fileno in partial else b""
if data := self._process(con, worker_lb, partialdata):
if _register_parial(fileno, data):
continue
self._close_socket(con, shutdown=False)

self._poller.unregister(fileno)
con = connections.pop(fileno)
self._process(con, worker_lb)
partial.pop(fileno, None)

# Close
else:
self._poller.unregister(fileno)
con = connections.pop(fileno)
partial.pop(fileno, None)
self._close_socket(con, shutdown=False)

# cleanup stale connection
Expand All @@ -259,6 +289,7 @@ def run(self) -> None:
elif stale[fileno] >= WORKER_STALE_MAX:
self._poller.unregister(fileno)
con = connections.pop(fileno)
partial.pop(fileno, None)
self._close_socket(con)
else:
stale[fileno] += 1
Expand All @@ -270,48 +301,55 @@ def run(self) -> None:
_LOGGER.critical("Worker '%s' crashed!", worker.name)
os.kill(os.getpid(), signal.SIGINT)

def _process(self, con: socket.socket, workers_lb: Iterable[ServerWorker]) -> None:
def _process(
self,
con: socket.socket,
workers_lb: Iterable[ServerWorker],
data: bytes,
) -> None | bytes:
"""Process connection & helo."""
data = b""
try:
data = con.recv(MAX_READ_SIZE)
data += con.recv(MAX_READ_SIZE)
except OSError as err:
_LOGGER.warning("Receive fails: %s", err)
self._close_socket(con, shutdown=False)
return
return None

# No data received
if not data:
self._close_socket(con)
return
return None

# Peer connection
if data.startswith(b"gA"):
next(workers_lb).handover_connection(con, data)
_LOGGER.debug("Handover new peer connection: %s", data)
return
return None

# TLS/SSL connection
if data[0] != 0x16:
_LOGGER.warning("No valid ClientHello found: %s", data)
self._close_socket(con)
return
return None

try:
hostname = parse_tls_sni(data)
except ParseSNIIncompleteError:
return data
except ParseSNIError:
_LOGGER.warning("Receive invalid ClientHello on public Interface")
else:
for worker in self._workers:
if not worker.is_responsible_peer(hostname):
continue
worker.handover_connection(con, data, sni=hostname)
return None
for worker in self._workers:
if not worker.is_responsible_peer(hostname):
continue
worker.handover_connection(con, data, sni=hostname)

_LOGGER.info("Handover %s to %s", hostname, worker.name)
return
_LOGGER.debug("No responsible worker for %s", hostname)
_LOGGER.info("Handover %s to %s", hostname, worker.name)
return None
_LOGGER.debug("No responsible worker for %s", hostname)

self._close_socket(con)
return None

@staticmethod
def _close_socket(con: socket.socket, shutdown: bool = True) -> None:
Expand Down
58 changes: 46 additions & 12 deletions snitun/server/sni.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
"""TLS ClientHello parser."""

from __future__ import annotations

import asyncio
import logging

from ..exceptions import ParseSNIError
from ..exceptions import ParseSNIError, ParseSNIIncompleteError
from ..utils.server import MAX_READ_SIZE

_LOGGER = logging.getLogger(__name__)

Expand All @@ -10,9 +15,38 @@
TLS_HANDSHAKE_TYPE_CLIENT_HELLO = 0x01


async def payload_reader(reader: asyncio.StreamReader) -> bytes | None:
"""Read data from reader."""
try:
header = await reader.read(6)
except ConnectionResetError:
raise ParseSNIError from None

if not header:
raise ParseSNIError
if len(header) < 5:
raise ParseSNIError

if (
header[0] != TLS_HANDSHAKE_CONTENT_TYPE
or header[5] != TLS_HANDSHAKE_TYPE_CLIENT_HELLO
):
return None

tls_size = (header[3] << 8) + header[4] + TLS_HEADER_LEN
data = header
while (data_size := len(data)) < tls_size and data_size < MAX_READ_SIZE:
try:
data += await reader.read(MAX_READ_SIZE)
except ConnectionResetError:
raise ParseSNIError from None

return data


def parse_tls_sni(data: bytes) -> str:
"""Parse TLS SNI extention."""
if len(data) < TLS_HEADER_LEN:
if (data_size := len(data)) < TLS_HEADER_LEN:
_LOGGER.debug("Invalid TLS header")
raise ParseSNIError

Expand All @@ -28,9 +62,9 @@ def parse_tls_sni(data: bytes) -> str:

# Calculate TLS record size
tls_size = (data[3] << 8) + data[4] + TLS_HEADER_LEN
if len(data) < tls_size:
if data_size < tls_size:
_LOGGER.debug("Can't calculate the TLS record size")
raise ParseSNIError
raise ParseSNIIncompleteError

# Check if handshake is a ClientHello
pos = TLS_HEADER_LEN
Expand Down Expand Up @@ -63,15 +97,15 @@ def parse_tls_sni(data: bytes) -> str:
raise ParseSNIError from None

# Check data buffer + extension size
if pos + 2 > len(data):
if pos + 2 > data_size:
_LOGGER.debug("Mismatch Extension TLS header")
raise ParseSNIError

# Process extension
return _parse_extension(data, pos)
return _parse_extension(data, pos, data_size)


def _parse_extension(data: bytes, pos: int) -> str:
def _parse_extension(data: bytes, pos: int, data_size: int) -> str:
"""Parse TLS ClientHello Extension."""
# Seek Extension start
try:
Expand All @@ -81,28 +115,28 @@ def _parse_extension(data: bytes, pos: int) -> str:
raise ParseSNIError from None

# Check data buffer + extension size
if pos + tls_extension_size > len(data):
if pos + tls_extension_size > data_size:
_LOGGER.debug("Mismatch Extension TLS header")
raise ParseSNIError

# Loop over extension until we have our SNI
while pos + 4 <= len(data):
while pos + 4 <= data_size:
# SNI?
if data[pos] == 0x00 and data[pos + 1] == 0x00:
return _parse_host_name(data, pos + 4)
return _parse_host_name(data, pos + 4, data_size)

pos += 4 + (data[pos + 2] << 8) + data[pos + 3]

_LOGGER.debug("Can't find any ServerName Extension")
raise ParseSNIError


def _parse_host_name(data: bytes, pos: int) -> str:
def _parse_host_name(data: bytes, pos: int, data_size: int) -> str:
"""Parse TLS ServerName Extension."""
# Seek list size
pos += 2

while pos + 3 < len(data):
while pos + 3 < data_size:
size = (data[pos + 1] << 8) + data[pos + 2]

# Unknown server name type
Expand Down
32 changes: 28 additions & 4 deletions tests/server/test_listener_sni.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
"""Test for SSL SNI proxy."""

from __future__ import annotations

import asyncio
import ipaddress

import pytest

from snitun.server.listener_sni import SNIProxy

from .const_tls import TLS_1_2
Expand All @@ -17,9 +22,24 @@ async def test_proxy_up_down():
await proxy.stop()


async def test_sni_proxy_flow(multiplexer_client, test_client_ssl):
@pytest.mark.parametrize(
"payloads",
[
[TLS_1_2],
[TLS_1_2[:6], TLS_1_2[6:]],
[TLS_1_2[:6], TLS_1_2[6:20], TLS_1_2[20:]],
[TLS_1_2[:6], TLS_1_2[6:20], TLS_1_2[20:32], TLS_1_2[32:]],
],
)
async def test_sni_proxy_flow(
multiplexer_client,
test_client_ssl,
payloads: list[bytes],
):
"""Test a normal flow of connection and exchange data."""
test_client_ssl.writer.write(TLS_1_2)
for payload in payloads:
test_client_ssl.writer.write(payload)
await asyncio.sleep(0.1)
await test_client_ssl.writer.drain()
await asyncio.sleep(0.1)

Expand All @@ -42,7 +62,9 @@ async def test_sni_proxy_flow(multiplexer_client, test_client_ssl):


async def test_sni_proxy_flow_close_by_client(
multiplexer_client, test_client_ssl, event_loop
multiplexer_client,
test_client_ssl,
event_loop,
):
"""Test a normal flow of connection data and close by client."""
loop = event_loop
Expand Down Expand Up @@ -74,7 +96,9 @@ async def test_sni_proxy_flow_close_by_client(


async def test_sni_proxy_flow_close_by_server(
multiplexer_client, test_client_ssl, event_loop
multiplexer_client,
test_client_ssl,
event_loop,
):
"""Test a normal flow of connection data and close by server."""
loop = event_loop
Expand Down
Loading

0 comments on commit bb27a71

Please sign in to comment.