Skip to content

Commit

Permalink
feat(x): add https and http/2 support to the CONNECT client
Browse files Browse the repository at this point in the history
  • Loading branch information
nikolaikabanenkov committed Jan 23, 2025
1 parent efa8083 commit 0e6749a
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 46 deletions.
116 changes: 76 additions & 40 deletions x/httpconnect/connect_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,31 @@ package httpconnect

import (
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/Jigsaw-Code/outline-sdk/transport"
"golang.org/x/net/http2"
"io"
"net"
"net/http"
)

// connectClient is a [transport.StreamDialer] implementation that dials [proxyAddr] with the given [dialer]
// and sends a CONNECT request to the dialed proxy.
type connectClient struct {
// ConnectClient is a [transport.StreamDialer] implementation that dials proxyAddr with the given dialer and sends a CONNECT request to the dialed proxy.
// By default, the client uses "http", but it can be changed to "https" with the [WithHTTPS] option.
type ConnectClient struct {
dialer transport.StreamDialer
proxyAddr string

headers http.Header
scheme string
tlsConfig *tls.Config
headers http.Header
}

var _ transport.StreamDialer = (*connectClient)(nil)
var _ transport.StreamDialer = (*ConnectClient)(nil)

type ClientOption func(c *connectClient)
type ClientOption func(c *ConnectClient)

func NewConnectClient(dialer transport.StreamDialer, proxyAddr string, opts ...ClientOption) (transport.StreamDialer, error) {
func NewConnectClient(dialer transport.StreamDialer, proxyAddr string, opts ...ClientOption) (*ConnectClient, error) {
if dialer == nil {
return nil, errors.New("dialer must not be nil")
}
Expand All @@ -46,10 +49,10 @@ func NewConnectClient(dialer transport.StreamDialer, proxyAddr string, opts ...C
return nil, fmt.Errorf("failed to parse proxy address %s: %w", proxyAddr, err)
}

cc := &connectClient{
cc := &ConnectClient{
dialer: dialer,
proxyAddr: proxyAddr,
headers: make(http.Header),
scheme: "http",
}

for _, opt := range opts {
Expand All @@ -59,69 +62,102 @@ func NewConnectClient(dialer transport.StreamDialer, proxyAddr string, opts ...C
return cc, nil
}

// WithHeaders appends the given [headers] to the CONNECT request
// WithHTTPS sets the scheme to "https" and the given tlsConfig to the transport
func WithHTTPS(tlsConfig *tls.Config) ClientOption {
return func(c *ConnectClient) {
c.scheme = "https"
c.tlsConfig = tlsConfig.Clone()
}
}

// WithHeaders appends the given headers to the CONNECT request
func WithHeaders(headers http.Header) ClientOption {
return func(c *connectClient) {
return func(c *ConnectClient) {
c.headers = headers.Clone()
}
}

// DialStream - connects to the proxy and sends a CONNECT request to it, closes the connection if the request fails
func (cc *connectClient) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) {
func (cc *ConnectClient) DialStream(ctx context.Context, remoteAddr string) (streamConn transport.StreamConn, err error) {
_, _, err = net.SplitHostPort(remoteAddr)
if err != nil {
return nil, fmt.Errorf("failed to parse remote address %s: %w", remoteAddr, err)
}

innerConn, err := cc.dialer.DialStream(ctx, cc.proxyAddr)
if err != nil {
return nil, fmt.Errorf("failed to dial proxy %s: %w", cc.proxyAddr, err)
}
defer func() {
if err != nil {
_ = innerConn.Close()
}
}()

conn, err := cc.doConnect(ctx, remoteAddr, innerConn)
roundTripper, err := cc.buildTransport(innerConn)
if err != nil {
return nil, fmt.Errorf("failed to build roundTripper: %w", err)
}

reader, writer, err := doConnect(ctx, roundTripper, cc.scheme, remoteAddr, cc.headers)
if err != nil {
_ = innerConn.Close()
return nil, fmt.Errorf("doConnect %s: %w", remoteAddr, err)
}

return conn, nil
return &pipeConn{
reader: reader,
writer: writer,
StreamConn: innerConn,
}, nil
}

func (cc *connectClient) doConnect(ctx context.Context, remoteAddr string, conn transport.StreamConn) (transport.StreamConn, error) {
_, _, err := net.SplitHostPort(remoteAddr)
func (cc *ConnectClient) buildTransport(conn transport.StreamConn) (http.RoundTripper, error) {
tr := &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return conn, nil
},
TLSClientConfig: cc.tlsConfig,
}

err := http2.ConfigureTransport(tr)
if err != nil {
return nil, fmt.Errorf("failed to parse remote address %s: %w", remoteAddr, err)
return nil, fmt.Errorf("failed to configure transport for HTTP/2: %w", err)
}

pr, pw := io.Pipe()
return tr, nil
}

req, err := http.NewRequestWithContext(ctx, http.MethodConnect, "http://"+remoteAddr, pr) // TODO: HTTPS support
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
func doConnect(
ctx context.Context,
roundTripper http.RoundTripper,
scheme, remoteAddr string,
headers http.Header,
) (io.ReadCloser, io.WriteCloser, error) {
if scheme != "http" && scheme != "https" {
return nil, nil, fmt.Errorf("unsupported scheme: %s", scheme)
}
req.ContentLength = -1 // -1 means length unknown
mergeHeaders(req.Header, cc.headers)

tr := &http.Transport{
// TODO: HTTP/2 support with [http2.ConfigureTransport]
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return conn, nil
},
pr, pw := io.Pipe()
remoteURL := fmt.Sprintf("%s://%s", scheme, remoteAddr)
req, err := http.NewRequestWithContext(ctx, http.MethodConnect, remoteURL, pr)
if err != nil {
return nil, nil, fmt.Errorf("failed to create request: %w", err)
}
req.ContentLength = -1 // -1 means unknown length
mergeHeaders(req.Header, headers)

hc := http.Client{
Transport: tr,
Transport: roundTripper,
}

resp, err := hc.Do(req)
if err != nil {
return nil, fmt.Errorf("do: %w", err)
return nil, nil, fmt.Errorf("do: %w", err)
}
if resp.StatusCode != http.StatusOK {
_ = resp.Body.Close()
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
return nil, nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}

return &pipeConn{
reader: resp.Body,
writer: pw,
StreamConn: conn,
}, nil
return resp.Body, pw, nil
}

func mergeHeaders(dst http.Header, src http.Header) {
Expand Down
89 changes: 84 additions & 5 deletions x/httpconnect/connect_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,27 @@ package httpconnect
import (
"bufio"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"github.com/Jigsaw-Code/outline-sdk/transport"
"github.com/Jigsaw-Code/outline-sdk/x/httpproxy"
"github.com/stretchr/testify/require"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
)

func TestConnectClientOk(t *testing.T) {
func Test_ConnectClient_HTTP_Ok(t *testing.T) {
t.Parallel()

creds := base64.StdEncoding.EncodeToString([]byte("username:password"))

targetSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method, "Method")
w.WriteHeader(http.StatusOK)
_, err := w.Write([]byte("HTTP/1.1 200 OK\r\n"))
require.NoError(t, err)
}))
defer targetSrv.Close()

Expand Down Expand Up @@ -78,7 +78,7 @@ func TestConnectClientOk(t *testing.T) {
require.Equal(t, http.StatusOK, resp.StatusCode)
}

func TestConnectClientFail(t *testing.T) {
func Test_ConnectClient_HTTP_Fail(t *testing.T) {
t.Parallel()

targetURL := "somehost:1234"
Expand Down Expand Up @@ -107,3 +107,82 @@ func TestConnectClientFail(t *testing.T) {
_, err = connClient.DialStream(context.Background(), targetURL)
require.Error(t, err, "unexpected status code: 400")
}

func Test_ConnectClient_HTTP2_Ok(t *testing.T) {
t.Parallel()

targetSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodGet, r.Method, "Method")
w.Header().Set("Content-Type", "text/plain")
_, err := w.Write([]byte("Hello, world!"))
require.NoError(t, err)
}))
defer targetSrv.Close()

targetURL, err := url.Parse(targetSrv.URL)
require.NoError(t, err)

tcpDialer := &transport.TCPDialer{Dialer: net.Dialer{}}
proxySrv := httptest.NewUnstartedServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
require.Equal(t, "HTTP/2.0", request.Proto, "Proto")
require.Equal(t, http.MethodConnect, request.Method, "Method")
require.Equal(t, targetURL.Host, request.URL.Host, "Host")

conn, err := tcpDialer.DialStream(request.Context(), request.URL.Host)
require.NoError(t, err, "DialStream")

writer.WriteHeader(http.StatusOK)
writer.(http.Flusher).Flush()

go func() {
_, _ = io.Copy(conn, request.Body)
require.NoError(t, err, "io.Copy")
}()

_, _ = io.Copy(writer, conn)
require.NoError(t, err, "io.Copy")
}))
proxySrv.EnableHTTP2 = true
proxySrv.StartTLS()
defer proxySrv.Close()

proxyURL, err := url.Parse(proxySrv.URL)
require.NoError(t, err, "Parse")

certs := x509.NewCertPool()
for _, c := range proxySrv.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
require.NoError(t, err, "x509.ParseCertificates")
for _, root := range roots {
certs.AddCert(root)
}
}

connClient, err := NewConnectClient(
tcpDialer,
proxyURL.Host,
WithHTTPS(&tls.Config{RootCAs: certs}),
)
require.NoError(t, err, "NewConnectClient")

streamConn, err := connClient.DialStream(context.Background(), targetURL.Host)
require.NoError(t, err, "DialStream")
require.NotNil(t, streamConn, "StreamConn")

req, err := http.NewRequest(http.MethodGet, targetSrv.URL, nil)
require.NoError(t, err, "NewRequest")
req.Header.Add("Connection", "close")

err = req.Write(streamConn)
require.NoError(t, err, "Write")

rd := bufio.NewReader(streamConn)
resp, err := http.ReadResponse(rd, req)
require.NoError(t, err, "ReadResponse")

body, err := io.ReadAll(resp.Body)
require.NoError(t, err, "ReadAll")
require.Equal(t, "Hello, world!", string(body))

require.Equal(t, http.StatusOK, resp.StatusCode)
}
3 changes: 2 additions & 1 deletion x/httpconnect/pipe_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ import (

var _ transport.StreamConn = (*pipeConn)(nil)

// pipeConn is a [transport.StreamConn] that overrides [Read], [Write] (and corresponding [Close]) functions with the given [reader] and [writer]
// pipeConn is a [transport.StreamConn] that overrides the Read and Write functions with the provided [io.ReadCloser] and [io.WriteCloser], respectively.
// The CloseRead, CloseWrite, and Close functions first close the [io.ReadCloser] and [io.WriteCloser], and then call the corresponding functions on the connection.
type pipeConn struct {
reader io.ReadCloser
writer io.WriteCloser
Expand Down

0 comments on commit 0e6749a

Please sign in to comment.