diff --git a/pkg/core/README.md b/pkg/core/README.md new file mode 100644 index 0000000000..c406fb3a02 --- /dev/null +++ b/pkg/core/README.md @@ -0,0 +1,19 @@ +# A pluggable transport implementation based on Hysteria + +## Hysteria +[Hysteria](https://github.com/HyNetwork/hysteria) uses a custom version of QUIC protocol ([RFC 9000 - QUIC: A UDP-Based Multiplexed and Secure Transport](https://www.rfc-editor.org/rfc/rfc9000.html)): + +* a custom congestion control ([RFC 9002 - QUIC Loss Detection and Congestion Control](https://www.rfc-editor.org/rfc/rfc9002.html)) +* tweaked QUIC parameters +* an obfuscation layer +* non-standard transports (e.g. [faketcp](https://github.com/wangyu-/udp2raw)) + +## Usage + +* Follow [Custom CA](https://hysteria.network/docs/custom-ca/) doc to generate certificates +* See [server side implementation example](https://github.com/apernet/hysteria/pull/340/files#diff-8a9b6ccee2487fc2b424d9f4b3cad2ebde2cc27b1cf1aa078e0de084872edbaaR62-R155) in the `transport_test.go` file +* See [client side implementation example](https://github.com/apernet/hysteria/pull/340/files#diff-8a9b6ccee2487fc2b424d9f4b3cad2ebde2cc27b1cf1aa078e0de084872edbaaR157-R229) in the `transport_test.go` file + +## Implementation + +The implementation uses [Pluggable Transport Specification v3.0 - Go Transport API](https://github.com/Pluggable-Transports/Pluggable-Transports-spec/blob/main/releases/PTSpecV3.0/Pluggable%20Transport%20Specification%20v3.0%20-%20Go%20Transport%20API%20v3.0.md) \ No newline at end of file diff --git a/pkg/core/client.go b/pkg/core/client.go index f4c2ab677f..b8a79f8b22 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -6,6 +6,12 @@ import ( "crypto/tls" "errors" "fmt" + "math/rand" + "net" + "strconv" + "sync" + "time" + "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/congestion" "github.com/lunixbochs/struc" @@ -13,11 +19,6 @@ import ( "github.com/tobyxdd/hysteria/pkg/pmtud_fix" "github.com/tobyxdd/hysteria/pkg/transport" "github.com/tobyxdd/hysteria/pkg/utils" - "math/rand" - "net" - "strconv" - "sync" - "time" ) var ( @@ -183,6 +184,20 @@ func (c *Client) openStreamWithReconnect() (quic.Connection, quic.Stream, error) return c.quicSession, &wrappedQUICStream{stream}, err } +// Implement Pluggable Transport Client interface +func (c *Client) Dial() (net.Conn, error) { + session, stream, err := c.openStreamWithReconnect() + if err != nil { + return nil, err + } + + return &quicConn{ + Orig: stream, + PseudoLocalAddr: session.LocalAddr(), + PseudoRemoteAddr: session.RemoteAddr(), + }, nil +} + func (c *Client) DialTCP(addr string) (net.Conn, error) { host, port, err := utils.SplitHostPort(addr) if err != nil { diff --git a/pkg/core/server.go b/pkg/core/server.go index 04d887194a..6f07916632 100644 --- a/pkg/core/server.go +++ b/pkg/core/server.go @@ -5,6 +5,8 @@ import ( "crypto/tls" "errors" "fmt" + "net" + "github.com/lucas-clemente/quic-go" "github.com/lunixbochs/struc" "github.com/prometheus/client_golang/prometheus" @@ -12,7 +14,6 @@ import ( "github.com/tobyxdd/hysteria/pkg/obfs" "github.com/tobyxdd/hysteria/pkg/pmtud_fix" "github.com/tobyxdd/hysteria/pkg/transport" - "net" ) type ConnectFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) @@ -42,6 +43,40 @@ type Server struct { listener quic.Listener } +type HysteriaTransport struct { + addr string + protocol string + tlsConfig *tls.Config + quicConfig *quic.Config + transport *transport.ServerTransport + sendBPS uint64 + recvBPS uint64 + congestionFactory CongestionFactory + disableUDP bool + obfuscator obfs.Obfuscator + connectFunc ConnectFunc + disconnectFunc DisconnectFunc +} + +type TransportServer struct { + transport *transport.ServerTransport + sendBPS, recvBPS uint64 + congestionFactory CongestionFactory + disableUDP bool + aclEngine *acl.Engine + + connectFunc ConnectFunc + disconnectFunc DisconnectFunc + tcpRequestFunc TCPRequestFunc + tcpErrorFunc TCPErrorFunc + udpRequestFunc UDPRequestFunc + udpErrorFunc UDPErrorFunc + + listener quic.Listener + allStreams chan *quicConn + isListening bool +} + func NewServer(addr string, protocol string, tlsConfig *tls.Config, quicConfig *quic.Config, transport *transport.ServerTransport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine, obfuscator obfs.Obfuscator, connectFunc ConnectFunc, disconnectFunc DisconnectFunc, @@ -92,6 +127,8 @@ func (s *Server) Serve() error { } } +// Close closes the listener. +// Any blocked Accept operations will be unblocked and return errors. func (s *Server) Close() error { return s.listener.Close() } @@ -173,3 +210,151 @@ func (s *Server) handleControlStream(cs quic.Connection, stream quic.Stream) ([] } return ch.Auth, ok, vb[0] == protocolVersionV2, nil } + +// Implement Pluggable Transport Server interface +func (t *HysteriaTransport) Listen() (net.Listener, error) { + listener, err := t.transport.QUICListen(t.protocol, t.addr, t.tlsConfig, t.quicConfig, t.obfuscator) + if err != nil { + return nil, err + } + s := &TransportServer{ + listener: listener, + transport: t.transport, + sendBPS: t.sendBPS, + recvBPS: t.recvBPS, + congestionFactory: t.congestionFactory, + disableUDP: t.disableUDP, + connectFunc: t.connectFunc, + disconnectFunc: t.disconnectFunc, + allStreams: make(chan *quicConn), + isListening: false, + } + + return s, nil +} + +// Addr returns the listener's network address. +func (s *TransportServer) Addr() net.Addr { + return s.listener.Addr() +} + +func (s *TransportServer) Close() error { + s.isListening = false + return s.listener.Close() +} + +func (s *TransportServer) Accept() (net.Conn, error) { + if !s.isListening { + s.isListening = true + go acceptConn(s) + } + // Return the next stream + select { + case stream := <-s.allStreams: + return stream, nil + } +} + +// An internal goroutine for accepting connections. Then for each accepted +// connection, start a goroutine for handling the control stream & accepting +// streams. Put those streams into a channel +func acceptConn(s *TransportServer) { + for { + cs, err := s.listener.Accept(context.Background()) + if err != nil { + _ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error") + return + } + go acceptStream(cs, s) + } +} + +func acceptStream(cs quic.Connection, s *TransportServer) { + // Expect the client to create a control stream to send its own information + ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout) + stream, err := cs.AcceptStream(ctx) + ctxCancel() + if err != nil { + _ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error") + return + } + // Handle the control stream + _, ok, _, err := s.handleControlStream(cs, stream) + if err != nil { + _ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error") + return + } + if !ok { + _ = cs.CloseWithError(closeErrorCodeAuth, "auth error") + return + } + // Close the control stream + stream.Close() + + for { + // Accept the next stream + stream, err = cs.AcceptStream(context.Background()) + if err != nil { + _ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error") + return + } + + conn := &quicConn{ + Orig: stream, + PseudoLocalAddr: cs.LocalAddr(), + PseudoRemoteAddr: cs.RemoteAddr(), + } + s.allStreams <- conn + } +} + +// Auth & negotiate speed +// Copy from (s *Server) handleControlStream, TODO: refactor +func (s *TransportServer) handleControlStream(cs quic.Connection, stream quic.Stream) ([]byte, bool, bool, error) { + // Check version + vb := make([]byte, 1) + _, err := stream.Read(vb) + if err != nil { + return nil, false, false, err + } + if vb[0] != protocolVersion && vb[0] != protocolVersionV2 { + return nil, false, false, fmt.Errorf("unsupported protocol version %d, expecting %d/%d", + vb[0], protocolVersionV2, protocolVersion) + } + // Parse client hello + var ch clientHello + err = struc.Unpack(stream, &ch) + if err != nil { + return nil, false, false, err + } + // Speed + if ch.Rate.SendBPS == 0 || ch.Rate.RecvBPS == 0 { + return nil, false, false, errors.New("invalid rate from client") + } + serverSendBPS, serverRecvBPS := ch.Rate.RecvBPS, ch.Rate.SendBPS + if s.sendBPS > 0 && serverSendBPS > s.sendBPS { + serverSendBPS = s.sendBPS + } + if s.recvBPS > 0 && serverRecvBPS > s.recvBPS { + serverRecvBPS = s.recvBPS + } + // Auth + ok, msg := s.connectFunc(cs.RemoteAddr(), ch.Auth, serverSendBPS, serverRecvBPS) + // Response + err = struc.Pack(stream, &serverHello{ + OK: ok, + Rate: transmissionRate{ + SendBPS: serverSendBPS, + RecvBPS: serverRecvBPS, + }, + Message: msg, + }) + if err != nil { + return nil, false, false, err + } + // Set the congestion accordingly + if ok && s.congestionFactory != nil { + cs.SetCongestionControl(s.congestionFactory(serverSendBPS)) + } + return ch.Auth, ok, vb[0] == protocolVersionV2, nil +} diff --git a/pkg/core/transport_test.go b/pkg/core/transport_test.go new file mode 100644 index 0000000000..094bf5d1eb --- /dev/null +++ b/pkg/core/transport_test.go @@ -0,0 +1,256 @@ +package core + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "io/ioutil" + "net" + "testing" + + "github.com/lucas-clemente/quic-go" + "github.com/lucas-clemente/quic-go/congestion" + "github.com/sirupsen/logrus" + hyCongestion "github.com/tobyxdd/hysteria/pkg/congestion" + "github.com/tobyxdd/hysteria/pkg/obfs" + "github.com/tobyxdd/hysteria/pkg/transport" +) + +// Configs for testing +const ( + server_addr = "localhost:2345" + protocol = "" + certFile = "../../hysteria.server.crt" + keyFile = "../../hysteria.server.key" + obfs_str = "c561508f56ed" + auth_str = "ga5438aaa690a5748eb59de8f7bedcb0" + client_up_mbps = 20 + client_down_mbps = 1000 + server_name = "www.0e6e852f62bbeb99.com" + test_request = "Here we go!" + test_response = "You got it." + customCA = "../../hysteria.ca.crt" +) + +// Default config copied from cmd/config.go +const ( + mbpsToBps = 125000 + minSpeedBPS = 16384 + + DefaultStreamReceiveWindow = 15728640 // 15 MB/s + DefaultConnectionReceiveWindow = 67108864 // 64 MB/s + DefaultMaxIncomingStreams = 1024 + + DefaultALPN = "hysteria" +) + +func TestE2E(t *testing.T) { + // Server and Client share the same obfuscator + obfuscator := obfs.NewXPlusObfuscator([]byte(obfs_str)) + signal := make(chan struct{}) + + go runServer(obfuscator, signal) + + err := runClient(obfuscator, signal) + if err != nil { + t.Fail() + } +} + +// Simulate a server +func runServer(obfuscator *obfs.XPlusObfuscator, signal chan struct{}) error { + // Load TLS server config + cer, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + fmt.Println("Cannot read server cert or key files") + return err + } + + var serverTlsConfig = &tls.Config{ + Certificates: []tls.Certificate{cer}, + MinVersion: tls.VersionTLS13, + NextProtos: []string{DefaultALPN}, + } + + // QUIC config + quicConfig := &quic.Config{ + InitialStreamReceiveWindow: DefaultStreamReceiveWindow, + MaxStreamReceiveWindow: DefaultStreamReceiveWindow, + InitialConnectionReceiveWindow: DefaultConnectionReceiveWindow, + MaxConnectionReceiveWindow: DefaultConnectionReceiveWindow, + MaxIncomingStreams: DefaultMaxIncomingStreams, // Client doesn't need this + KeepAlive: true, + DisablePathMTUDiscovery: false, + EnableDatagrams: true, + } + + // Auth + var authFunc ConnectFunc + authFunc, err = passwordAuthFunc(auth_str) + connectFunc := func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) { + ok, msg := authFunc(addr, auth, sSend, sRecv) + if !ok { + logrus.WithFields(logrus.Fields{ + "src": addr, + "msg": msg, + }).Info("Authentication failed, client rejected") + } else { + logrus.WithFields(logrus.Fields{ + "src": addr, + }).Info("Client connected") + } + return ok, msg + } + + server := &HysteriaTransport{ + addr: server_addr, + protocol: protocol, + tlsConfig: serverTlsConfig, + quicConfig: quicConfig, + transport: transport.DefaultServerTransport, + sendBPS: 0, + recvBPS: 0, + congestionFactory: congestionFactory, + disableUDP: false, + obfuscator: obfuscator, + connectFunc: connectFunc, + disconnectFunc: disconnectFunc, + } + + l, err := server.Listen() + + if err != nil { + fmt.Println("Failed to initialize server") + } + + serverBuffer := make([]byte, len(test_request)) + fmt.Println("Server up and running") + signal <- struct{}{} + + serverConn, err := l.Accept() + defer serverConn.Close() + + if err != nil { + return err + } + + fmt.Println("Server starts reading from connection") + _, err = serverConn.Read(serverBuffer) + + if err != nil { + return err + } + + s := string(serverBuffer) + if s == test_request { + fmt.Println("Server received the expected data from the client") + signal <- struct{}{} + _, err = serverConn.Write([]byte(test_response)) + fmt.Println("Server sent the response to the client") + return err + } + + return errors.New("Something is wrong") +} + +// Simulate a client +func runClient(obfuscator *obfs.XPlusObfuscator, signal chan struct{}) error { + // Load TLS client config + var clientTlsConfig = &tls.Config{ + InsecureSkipVerify: false, + MinVersion: tls.VersionTLS13, + NextProtos: []string{DefaultALPN}, + ServerName: server_name, + } + bs, err := ioutil.ReadFile(customCA) + if err != nil { + logrus.WithFields(logrus.Fields{ + "error": err, + "file": customCA, + }).Fatal("Failed to load CA") + } + cp := x509.NewCertPool() + if !cp.AppendCertsFromPEM(bs) { + logrus.WithFields(logrus.Fields{ + "file": customCA, + }).Fatal("Failed to parse CA") + } + clientTlsConfig.RootCAs = cp + + // QUIC config + quicConfig := &quic.Config{ + InitialStreamReceiveWindow: DefaultStreamReceiveWindow, + MaxStreamReceiveWindow: DefaultStreamReceiveWindow, + InitialConnectionReceiveWindow: DefaultConnectionReceiveWindow, + MaxConnectionReceiveWindow: DefaultConnectionReceiveWindow, + KeepAlive: true, + DisablePathMTUDiscovery: false, + EnableDatagrams: true, + } + + <-signal + client, err := NewClient(server_addr, protocol, []byte(auth_str), clientTlsConfig, quicConfig, + transport.DefaultClientTransport, client_up_mbps, client_down_mbps, + congestionFactory, obfuscator) + + if err != nil { + fmt.Println("Failed to initialize client") + return err + } + + clientConn, err := client.Dial() + fmt.Println("Client up and running") + defer clientConn.Close() + + if err != nil { + fmt.Println("Failed to connect to the server") + return err + } + + // write data from clientConn for server to read + _, err = clientConn.Write([]byte(test_request)) + if err != nil { + return err + } + fmt.Println("Client sent the data to the server") + + <-signal + clientBuffer := make([]byte, len(test_response)) + fmt.Println("Client starts reading from connection") + _, err = clientConn.Read(clientBuffer) + s := string(clientBuffer) + if s == test_response { + fmt.Println("Client received the expected response from the server") + return nil + } + + return err +} + +// Below are default functions copied from cmd/server.go or cmd/client.go + +// Use Hysteria custom congestion +func congestionFactory(refBPS uint64) congestion.CongestionControl { + return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) +} + +func passwordAuthFunc(pwd string) (ConnectFunc, error) { + var pwds []string + pwds = []string{pwd} + return func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) { + for _, pwd := range pwds { + if string(auth) == pwd { + return true, "Welcome" + } + } + return false, "Wrong password" + }, nil +} + +func disconnectFunc(addr net.Addr, auth []byte, err error) { + logrus.WithFields(logrus.Fields{ + "src": addr, + "error": err, + }).Info("Client disconnected") +}