diff --git a/dtls/client.go b/dtls/client.go index 99439df2..70c2ba38 100644 --- a/dtls/client.go +++ b/dtls/client.go @@ -98,7 +98,7 @@ func Client(conn *dtls.Conn, opts ...udp.Option) *udpClient.Conn { } monitor := cfg.CreateInactivityMonitor() - l := coapNet.NewConn(conn) + l := coapNet.NewDTLSConn(conn) session := server.NewSession(cfg.Ctx, l, cfg.MaxMessageSize, diff --git a/dtls/server/server.go b/dtls/server/server.go index 88211b08..be89aeb0 100644 --- a/dtls/server/server.go +++ b/dtls/server/server.go @@ -165,7 +165,7 @@ func (s *Server) Serve(l Listener) error { wg.Add(1) inactivityMonitor := s.cfg.CreateInactivityMonitor() requestMonitor := s.cfg.RequestMonitor - cc := s.createConn(coapNet.NewConn(rw), inactivityMonitor, requestMonitor) + cc := s.createConn(coapNet.NewDTLSConn(rw), inactivityMonitor, requestMonitor) if s.cfg.OnNewConn != nil { s.cfg.OnNewConn(cc) } diff --git a/dtls/server_test.go b/dtls/server_test.go index 42878692..f080f3b3 100644 --- a/dtls/server_test.go +++ b/dtls/server_test.go @@ -151,18 +151,6 @@ func TestServerSetContextValueWithPKI(t *testing.T) { require.NoError(t, errC) }() - // TODO: fix in piondtls? - // dtlsConn.ConnectionState() -> func (s *State) serialize() -> uint16(s.cipherSuite.ID()) - // s.cipherSuite is nil at the call which causes a segfault - // } - // onNewConn := func(cc *client.Conn) { - // dtlsConn, ok := cc.NetConn().(*piondtls.Conn) - // require.True(t, ok) - // // set connection context certificate - // clientCert, errP := x509.ParseCertificate(dtlsConn.ConnectionState().PeerCertificates[0]) - // require.NoError(t, errP) - // cc.SetContextValue("client-cert", clientCert) - // } handle := func(w *responsewriter.ResponseWriter[*client.Conn], r *pool.Message) { // get certificate from connection context var clientCert *x509.Certificate diff --git a/net/conn.go b/net/conn.go index 8d6faabe..a80bd93a 100644 --- a/net/conn.go +++ b/net/conn.go @@ -9,20 +9,45 @@ import ( "go.uber.org/atomic" ) +type ConnOption interface { + Apply(*ConnConfig) +} + +type ConnConfig struct { + disabledManualCloseAfterHandshake bool +} + +func (h ConnConfig) Apply(o *ConnConfig) { + o.disabledManualCloseAfterHandshake = h.disabledManualCloseAfterHandshake +} + +func WithDisabledManualCloseAfterHandshake() ConnConfig { + return ConnConfig{ + disabledManualCloseAfterHandshake: true, + } +} + // Conn is a generic stream-oriented network connection that provides Read/Write with context. // // Multiple goroutines may invoke methods on a Conn simultaneously. type Conn struct { connection net.Conn closed atomic.Bool + cfg ConnConfig handshakeContext func(ctx context.Context) error lock sync.Mutex } // NewConn creates connection over net.Conn. -func NewConn(c net.Conn) *Conn { +func NewConn(c net.Conn, opts ...ConnOption) *Conn { + cfg := ConnConfig{} + for _, o := range opts { + o.Apply(&cfg) + } + connection := Conn{ connection: c, + cfg: cfg, } if v, ok := c.(interface { @@ -34,6 +59,10 @@ func NewConn(c net.Conn) *Conn { return &connection } +func NewDTLSConn(c net.Conn) *Conn { + return NewConn(c, WithDisabledManualCloseAfterHandshake()) +} + // LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it. func (c *Conn) LocalAddr() net.Addr { return c.connection.LocalAddr() @@ -63,11 +92,14 @@ func (c *Conn) handshake(ctx context.Context) error { if err == nil { return nil } - errC := c.Close() - if errC == nil { - return err + if !c.cfg.disabledManualCloseAfterHandshake { + errC := c.Close() + if errC == nil { + return err + } + return fmt.Errorf("%v", []error{err, errC}) } - return fmt.Errorf("%v", []error{err, errC}) + return err } return nil }