From 6867fc7aa6fe073127af18b4665a397cdc21e8f0 Mon Sep 17 00:00:00 2001 From: Daniel Adam Date: Thu, 1 Aug 2024 09:38:17 +0200 Subject: [PATCH] Upgrade to github.com/pion/dtls/v3 --- .golangci.yml | 3 - dtls/client.go | 6 +- dtls/client_test.go | 121 ++++++++---------------- dtls/example_test.go | 2 +- dtls/server/server.go | 34 +++---- dtls/server_test.go | 66 +++++++------ examples/dtls/cid/client/main.go | 7 +- examples/dtls/cid/server/main.go | 2 +- examples/dtls/pki/client/main.go | 2 +- examples/dtls/pki/server/main.go | 17 ++-- examples/dtls/psk/client/main.go | 2 +- examples/dtls/psk/server/main.go | 2 +- examples/options/server/main.go | 10 +- go.mod | 8 +- go.sum | 4 +- message/codes/codes.go | 2 +- net/conn.go | 4 + net/dtlslistener.go | 156 +++---------------------------- net/options.go | 21 ----- net/tcplistener.go | 7 +- net/tlslistener.go | 8 +- net/tlslistener_test.go | 9 +- server.go | 2 +- tcp/client_test.go | 8 +- tcp/server/server.go | 29 +++--- 25 files changed, 177 insertions(+), 355 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index 6c6b4dc1..f7b6a08c 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -5,9 +5,6 @@ linters-settings: enable: - nilness - shadow - gomoddirectives: - replace-allow-list: - - github.com/pion/dtls/v2 linters: enable: diff --git a/dtls/client.go b/dtls/client.go index 96ac267d..70c2ba38 100644 --- a/dtls/client.go +++ b/dtls/client.go @@ -4,8 +4,8 @@ import ( "fmt" "time" - "github.com/pion/dtls/v2" - dtlsnet "github.com/pion/dtls/v2/pkg/net" + "github.com/pion/dtls/v3" + dtlsnet "github.com/pion/dtls/v3/pkg/net" "github.com/plgd-dev/go-coap/v3/dtls/server" "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/codes" @@ -98,7 +98,7 @@ func Client(conn *dtls.Conn, opts ...udp.Option) *udpClient.Conn { } monitor := cfg.CreateInactivityMonitor() - l := coapNet.NewConn(conn) + l := coapNet.NewDTLSConn(conn) session := server.NewSession(cfg.Ctx, l, cfg.MaxMessageSize, diff --git a/dtls/client_test.go b/dtls/client_test.go index 60526191..d52f0b6a 100644 --- a/dtls/client_test.go +++ b/dtls/client_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - piondtls "github.com/pion/dtls/v2" + piondtls "github.com/pion/dtls/v3" "github.com/plgd-dev/go-coap/v3/dtls" "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/codes" @@ -20,7 +20,6 @@ import ( coapNet "github.com/plgd-dev/go-coap/v3/net" "github.com/plgd-dev/go-coap/v3/net/responsewriter" "github.com/plgd-dev/go-coap/v3/options" - "github.com/plgd-dev/go-coap/v3/options/config" "github.com/plgd-dev/go-coap/v3/pkg/runner/periodic" "github.com/plgd-dev/go-coap/v3/udp/client" "github.com/stretchr/testify/assert" @@ -123,7 +122,7 @@ func TestConnGet(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600) + ctx, cancel := context.WithTimeout(context.Background(), Timeout) defer cancel() got, err := cc.Get(ctx, tt.args.path, tt.args.opts...) if tt.wantErr { @@ -216,7 +215,7 @@ func TestConnGetSeparateMessage(t *testing.T) { require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600) + ctx, cancel := context.WithTimeout(context.Background(), Timeout) defer cancel() req, err := cc.NewGetRequest(ctx, "/a") @@ -340,7 +339,7 @@ func TestConnPost(t *testing.T) { require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600) + ctx, cancel := context.WithTimeout(context.Background(), Timeout) defer cancel() got, err := cc.Post(ctx, tt.args.path, tt.args.contentFormat, tt.args.payload, tt.args.opts...) if tt.wantErr { @@ -475,7 +474,7 @@ func TestConnPut(t *testing.T) { require.NoError(t, errC) }() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600) + ctx, cancel := context.WithTimeout(context.Background(), Timeout) defer cancel() got, err := cc.Put(ctx, tt.args.path, tt.args.contentFormat, tt.args.payload, tt.args.opts...) if tt.wantErr { @@ -590,7 +589,7 @@ func TestConnDelete(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600) + ctx, cancel := context.WithTimeout(context.Background(), Timeout) defer cancel() got, err := cc.Delete(ctx, tt.args.path, tt.args.opts...) if tt.wantErr { @@ -654,58 +653,12 @@ func TestConnPing(t *testing.T) { require.NoError(t, err) } -func TestConnHandeShakeFailure(t *testing.T) { - dtlsCfg := &piondtls.Config{ - PSK: func(hint []byte) ([]byte, error) { - fmt.Printf("Hint: %s \n", hint) - return []byte{0xAB, 0xC1, 0x23}, nil - }, - PSKIdentityHint: []byte("Pion DTLS Server"), - CipherSuites: []piondtls.CipherSuiteID{piondtls.TLS_PSK_WITH_AES_128_CCM_8}, - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(context.Background(), 1*time.Second) - }, - } - l, err := coapNet.NewDTLSListener("udp", "", dtlsCfg) - require.NoError(t, err) - defer func() { - errC := l.Close() - require.NoError(t, errC) - }() - var wg sync.WaitGroup - defer wg.Wait() - - s := dtls.NewServer() - defer s.Stop() - - wg.Add(1) - go func() { - defer wg.Done() - errS := s.Serve(l) - assert.NoError(t, errS) - }() - - dtlsCfgClient := &piondtls.Config{ - PSK: func(hint []byte) ([]byte, error) { - fmt.Printf("Hint: %s \n", hint) - return []byte{0xAB, 0xC1, 0x24}, nil - }, - PSKIdentityHint: []byte("Pion DTLS Client"), - CipherSuites: []piondtls.CipherSuiteID{piondtls.TLS_PSK_WITH_AES_128_CCM_8}, - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(context.Background(), 1*time.Second) - }, - } - _, err = dtls.Dial(l.Addr().String(), dtlsCfgClient) - require.Error(t, err) -} - func TestClientInactiveMonitor(t *testing.T) { var inactivityDetected atomic.Bool ctx, cancel := context.WithTimeout(context.Background(), Timeout) defer cancel() - serverCgf, clientCgf, _, err := createDTLSConfig(ctx) + serverCgf, clientCgf, _, err := createDTLSConfig() require.NoError(t, err) ld, err := coapNet.NewDTLSListener("udp4", "", serverCgf) @@ -745,7 +698,9 @@ func TestClientInactiveMonitor(t *testing.T) { serverWg.Wait() }() - cc, err := dtls.Dial(ld.Addr().String(), clientCgf, + cc, err := dtls.Dial( + ld.Addr().String(), + clientCgf, options.WithInactivityMonitor(100*time.Millisecond, func(cc *client.Conn) { require.False(t, inactivityDetected.Load()) inactivityDetected.Store(true) @@ -774,65 +729,71 @@ func TestClientInactiveMonitor(t *testing.T) { func TestClientKeepAliveMonitor(t *testing.T) { var inactivityDetected atomic.Bool - ctx, cancel := context.WithTimeout(context.Background(), Timeout) - defer cancel() - serverCgf, clientCgf, _, err := createDTLSConfig(ctx) + serverCgf, clientCgf, _, err := createDTLSConfig() require.NoError(t, err) - ld, err := coapNet.NewDTLSListener("udp4", "", serverCgf) require.NoError(t, err) + defer func() { + errC := ld.Close() + require.NoError(t, errC) + }() + + ctx, cancel := context.WithTimeout(context.Background(), Timeout) + defer cancel() - checkClose := semaphore.NewWeighted(1) - err = checkClose.Acquire(ctx, 1) + checkClose := semaphore.NewWeighted(2) + err = checkClose.Acquire(ctx, 2) require.NoError(t, err) + sd := dtls.NewServer( + options.WithOnNewConn(func(cc *client.Conn) { + cc.AddOnClose(func() { + checkClose.Release(1) + }) + }), + options.WithPeriodicRunner(periodic.New(ctx.Done(), time.Millisecond*10)), + options.WithRequestMonitor(func(_ *client.Conn, _ *pool.Message) (bool, error) { + // lets drop all messages, this will trigger keep alive because of inactivity + return true, nil + }), + ) + var serverWg sync.WaitGroup serverWg.Add(1) go func() { defer serverWg.Done() - for { - c, errA := ld.AcceptWithContext(ctx) - if errA != nil { - if errors.Is(errA, coapNet.ErrListenerIsClosed) { - return - } - } - defer c.Close() - assert.NoError(t, errA) - } + errS := sd.Serve(ld) + assert.NoError(t, errS) }() defer func() { - errC := ld.Close() - require.NoError(t, errC) + sd.Stop() + serverWg.Wait() }() cc, err := dtls.Dial( ld.Addr().String(), clientCgf, options.WithKeepAlive(3, 100*time.Millisecond, func(cc *client.Conn) { + t.Log("client - close for inactivity") require.False(t, inactivityDetected.Load()) inactivityDetected.Store(true) errC := cc.Close() require.NoError(t, errC) }), options.WithPeriodicRunner(periodic.New(ctx.Done(), time.Millisecond*10)), - options.WithReceivedMessageQueueSize(32), - options.WithProcessReceivedMessageFunc(func(req *pool.Message, cc *client.Conn, handler config.HandlerFunc[*client.Conn]) { - cc.ProcessReceivedMessageWithHandler(req, handler) - }), ) require.NoError(t, err) cc.AddOnClose(func() { + t.Log("connection is closed") checkClose.Release(1) }) // send ping to create server side connection ctxPing, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - err = cc.Ping(ctxPing) - require.Error(t, err) + _ = cc.Ping(ctxPing) - err = checkClose.Acquire(ctx, 1) + err = checkClose.Acquire(ctx, 2) require.NoError(t, err) require.True(t, inactivityDetected.Load()) } diff --git a/dtls/example_test.go b/dtls/example_test.go index c41a9601..92d10d2d 100644 --- a/dtls/example_test.go +++ b/dtls/example_test.go @@ -7,7 +7,7 @@ import ( "log" "time" - piondtls "github.com/pion/dtls/v2" + piondtls "github.com/pion/dtls/v3" "github.com/plgd-dev/go-coap/v3/dtls" "github.com/plgd-dev/go-coap/v3/net" ) diff --git a/dtls/server/server.go b/dtls/server/server.go index 1f879a67..92b95358 100644 --- a/dtls/server/server.go +++ b/dtls/server/server.go @@ -89,12 +89,20 @@ func (s *Server) checkAndSetListener(l Listener) error { s.listenMutex.Lock() defer s.listenMutex.Unlock() if s.listen != nil { - return errors.New("server already serve listener") + return errors.New("server already serves listener") } s.listen = l return nil } +func (s *Server) popListener() Listener { + s.listenMutex.Lock() + defer s.listenMutex.Unlock() + l := s.listen + s.listen = nil + return l +} + func (s *Server) checkAcceptError(err error) bool { if err == nil { return true @@ -129,16 +137,14 @@ func (s *Server) Serve(l Listener) error { if s.cfg.BlockwiseSZX > blockwise.SZX1024 { return errors.New("invalid blockwiseSZX") } + err := s.checkAndSetListener(l) if err != nil { return err } defer func() { - s.listenMutex.Lock() - defer s.listenMutex.Unlock() - s.listen = nil + s.Stop() }() - var wg sync.WaitGroup defer wg.Wait() @@ -158,11 +164,9 @@ func (s *Server) Serve(l Listener) error { continue } wg.Add(1) - var cc *udpClient.Conn inactivityMonitor := s.cfg.CreateInactivityMonitor() requestMonitor := s.cfg.RequestMonitor - - cc = s.createConn(coapNet.NewConn(rw), inactivityMonitor, requestMonitor) + cc := s.createConn(coapNet.NewDTLSConn(rw), inactivityMonitor, requestMonitor) if s.cfg.OnNewConn != nil { s.cfg.OnNewConn(cc) } @@ -176,14 +180,12 @@ func (s *Server) Serve(l Listener) error { // Stop stops server without wait of ends Serve function. func (s *Server) Stop() { s.cancel() - s.listenMutex.Lock() - l := s.listen - s.listen = nil - s.listenMutex.Unlock() - if l != nil { - if err := l.Close(); err != nil { - s.cfg.Errors(fmt.Errorf("cannot close listener: %w", err)) - } + l := s.popListener() + if l == nil { + return + } + if err := l.Close(); err != nil { + s.cfg.Errors(fmt.Errorf("cannot close listener: %w", err)) } } diff --git a/dtls/server_test.go b/dtls/server_test.go index cde36991..21fe835a 100644 --- a/dtls/server_test.go +++ b/dtls/server_test.go @@ -12,7 +12,7 @@ import ( "testing" "time" - piondtls "github.com/pion/dtls/v2" + piondtls "github.com/pion/dtls/v3" "github.com/plgd-dev/go-coap/v3/dtls" "github.com/plgd-dev/go-coap/v3/examples/dtls/pki" "github.com/plgd-dev/go-coap/v3/message" @@ -85,7 +85,7 @@ func TestServerCleanUpConns(t *testing.T) { <-cc.Done() } -func createDTLSConfig(ctx context.Context) (serverConfig *piondtls.Config, clientConfig *piondtls.Config, clientSerial *big.Int, err error) { +func createDTLSConfig() (serverConfig *piondtls.Config, clientConfig *piondtls.Config, clientSerial *big.Int, err error) { // root cert ca, rootBytes, _, caPriv, err := pki.GenerateCA() if err != nil { @@ -111,9 +111,6 @@ func createDTLSConfig(ctx context.Context) (serverConfig *piondtls.Config, clien ExtendedMasterSecret: piondtls.RequireExtendedMasterSecret, ClientCAs: certPool, ClientAuth: piondtls.RequireAndVerifyClientCert, - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, } // client cert @@ -142,9 +139,9 @@ func createDTLSConfig(ctx context.Context) (serverConfig *piondtls.Config, clien } func TestServerSetContextValueWithPKI(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600) + ctx, cancel := context.WithTimeout(context.Background(), Timeout) defer cancel() - serverCgf, clientCgf, clientSerial, err := createDTLSConfig(ctx) + serverCgf, clientCgf, clientSerial, err := createDTLSConfig() require.NoError(t, err) ld, err := coapNet.NewDTLSListener("udp4", "", serverCgf) @@ -154,46 +151,59 @@ func TestServerSetContextValueWithPKI(t *testing.T) { require.NoError(t, errC) }() - onNewConn := func(cc *client.Conn) { - dtlsConn, ok := cc.NetConn().(*piondtls.Conn) - require.True(t, ok) - // set connection context certificate - clientCert, errP := x509.ParseCertificate(dtlsConn.ConnectionState().PeerCertificates[0]) - require.NoError(t, errP) - cc.SetContextValue("client-cert", clientCert) - } handle := func(w *responsewriter.ResponseWriter[*client.Conn], r *pool.Message) { // get certificate from connection context - clientCert := r.Context().Value("client-cert").(*x509.Certificate) - require.Equal(t, clientCert.SerialNumber, clientSerial) - require.NotNil(t, clientCert) + var clientCert *x509.Certificate + value := r.Context().Value("client-cert") + if value == nil { + dtlsConn, ok := w.Conn().NetConn().(*piondtls.Conn) + assert.True(t, ok) + state, ok := dtlsConn.ConnectionState() + assert.True(t, ok) + var errP error + clientCert, errP = x509.ParseCertificate(state.PeerCertificates[0]) + assert.NoError(t, errP) //nolint:testifylint + w.Conn().SetContextValue("client-cert", clientCert) + } else { + clientCert = r.Context().Value("client-cert").(*x509.Certificate) + } + assert.Equal(t, clientCert.SerialNumber, clientSerial) + assert.NotNil(t, clientCert) errH := w.SetResponse(codes.Content, message.TextPlain, bytes.NewReader([]byte("done"))) - require.NoError(t, errH) + assert.NoError(t, errH) } - sd := dtls.NewServer(options.WithHandlerFunc(handle), options.WithOnNewConn(onNewConn)) - defer sd.Stop() + sd := dtls.NewServer(options.WithHandlerFunc(handle)) + var wg sync.WaitGroup + defer func() { + sd.Stop() + wg.Wait() + }() + wg.Add(1) go func() { + defer wg.Done() errS := sd.Serve(ld) assert.NoError(t, errS) }() cc, err := dtls.Dial(ld.Addr().String(), clientCgf) require.NoError(t, err) + defer func() { + errC := cc.Close() + require.NoError(t, errC) + <-cc.Done() + }() _, err = cc.Get(ctx, "/") require.NoError(t, err) - err = cc.Close() - require.NoError(t, err) - <-cc.Done() } func TestServerInactiveMonitor(t *testing.T) { var inactivityDetected atomic.Bool - ctx, cancel := context.WithTimeout(context.Background(), time.Second*8) + ctx, cancel := context.WithTimeout(context.Background(), Timeout) defer cancel() - serverCgf, clientCgf, _, err := createDTLSConfig(ctx) + serverCgf, clientCgf, _, err := createDTLSConfig() require.NoError(t, err) ld, err := coapNet.NewDTLSListener("udp4", "", serverCgf) @@ -266,9 +276,9 @@ func TestServerInactiveMonitor(t *testing.T) { func TestServerKeepAliveMonitor(t *testing.T) { var inactivityDetected atomic.Bool - ctx, cancel := context.WithTimeout(context.Background(), time.Second*8) + ctx, cancel := context.WithTimeout(context.Background(), Timeout) defer cancel() - serverCgf, clientCgf, _, err := createDTLSConfig(ctx) + serverCgf, clientCgf, _, err := createDTLSConfig() require.NoError(t, err) ld, err := coapNet.NewDTLSListener("udp4", "", serverCgf) diff --git a/examples/dtls/cid/client/main.go b/examples/dtls/cid/client/main.go index 8e71c47a..aac8a9d5 100644 --- a/examples/dtls/cid/client/main.go +++ b/examples/dtls/cid/client/main.go @@ -6,7 +6,7 @@ import ( "log" "net" - piondtls "github.com/pion/dtls/v2" + piondtls "github.com/pion/dtls/v3" "github.com/plgd-dev/go-coap/v3/dtls" ) @@ -49,7 +49,10 @@ func main() { log.Printf("Response payload: %+v", resp) // Export state to resume connection from another address. - state := client.ConnectionState() + state, ok := client.ConnectionState() + if !ok { + log.Fatalf("Error exporting DTLS state") + } // Setup second UDP listener on a different address. udpconn, err = net.ListenUDP("udp", nil) diff --git a/examples/dtls/cid/server/main.go b/examples/dtls/cid/server/main.go index a5564bf1..dac59e3c 100644 --- a/examples/dtls/cid/server/main.go +++ b/examples/dtls/cid/server/main.go @@ -8,7 +8,7 @@ import ( "net" "time" - piondtls "github.com/pion/dtls/v2" + piondtls "github.com/pion/dtls/v3" "github.com/plgd-dev/go-coap/v3/dtls/server" "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/codes" diff --git a/examples/dtls/pki/client/main.go b/examples/dtls/pki/client/main.go index c69fd0e6..5f2d8d94 100644 --- a/examples/dtls/pki/client/main.go +++ b/examples/dtls/pki/client/main.go @@ -7,7 +7,7 @@ import ( "os" "time" - piondtls "github.com/pion/dtls/v2" + piondtls "github.com/pion/dtls/v3" "github.com/plgd-dev/go-coap/v3/dtls" "github.com/plgd-dev/go-coap/v3/examples/dtls/pki" ) diff --git a/examples/dtls/pki/server/main.go b/examples/dtls/pki/server/main.go index b45427f2..d6eecd19 100644 --- a/examples/dtls/pki/server/main.go +++ b/examples/dtls/pki/server/main.go @@ -2,15 +2,13 @@ package main import ( "bytes" - "context" "crypto/tls" "crypto/x509" "fmt" "log" "math/big" - "time" - piondtls "github.com/pion/dtls/v2" + piondtls "github.com/pion/dtls/v3" "github.com/plgd-dev/go-coap/v3/dtls" "github.com/plgd-dev/go-coap/v3/examples/dtls/pki" "github.com/plgd-dev/go-coap/v3/message" @@ -26,7 +24,11 @@ func onNewConn(cc *client.Conn) { if !ok { log.Fatalf("invalid type %T", cc.NetConn()) } - clientCert, err := x509.ParseCertificate(dtlsConn.ConnectionState().PeerCertificates[0]) + state, ok := dtlsConn.ConnectionState() + if !ok { + log.Fatalf("cannot get connection state") + } + clientCert, err := x509.ParseCertificate(state.PeerCertificates[0]) if err != nil { log.Fatal(err) } @@ -57,7 +59,7 @@ func main() { m := mux.NewRouter() m.Handle("/a", mux.HandlerFunc(handleA)) - config, err := createServerConfig(context.Background()) + config, err := createServerConfig() if err != nil { log.Fatalln(err) return @@ -76,7 +78,7 @@ func listenAndServeDTLS(network string, addr string, config *piondtls.Config, ha return s.Serve(l) } -func createServerConfig(ctx context.Context) (*piondtls.Config, error) { +func createServerConfig() (*piondtls.Config, error) { // root cert ca, rootBytes, _, caPriv, err := pki.GenerateCA() if err != nil { @@ -102,8 +104,5 @@ func createServerConfig(ctx context.Context) (*piondtls.Config, error) { ExtendedMasterSecret: piondtls.RequireExtendedMasterSecret, ClientCAs: certPool, ClientAuth: piondtls.RequireAndVerifyClientCert, - ConnectContextMaker: func() (context.Context, func()) { - return context.WithTimeout(ctx, 30*time.Second) - }, }, nil } diff --git a/examples/dtls/psk/client/main.go b/examples/dtls/psk/client/main.go index 51cde6d9..a692e76e 100644 --- a/examples/dtls/psk/client/main.go +++ b/examples/dtls/psk/client/main.go @@ -7,7 +7,7 @@ import ( "os" "time" - piondtls "github.com/pion/dtls/v2" + piondtls "github.com/pion/dtls/v3" "github.com/plgd-dev/go-coap/v3/dtls" ) diff --git a/examples/dtls/psk/server/main.go b/examples/dtls/psk/server/main.go index aafd2b18..4c91f8a9 100644 --- a/examples/dtls/psk/server/main.go +++ b/examples/dtls/psk/server/main.go @@ -5,7 +5,7 @@ import ( "fmt" "log" - piondtls "github.com/pion/dtls/v2" + piondtls "github.com/pion/dtls/v3" coap "github.com/plgd-dev/go-coap/v3" "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/codes" diff --git a/examples/options/server/main.go b/examples/options/server/main.go index 451e5104..8543d434 100644 --- a/examples/options/server/main.go +++ b/examples/options/server/main.go @@ -7,7 +7,7 @@ import ( "fmt" "log" - piondtls "github.com/pion/dtls/v2" + piondtls "github.com/pion/dtls/v3" coap "github.com/plgd-dev/go-coap/v3" "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/codes" @@ -46,10 +46,14 @@ func handleOnNewConn(cc *udpClient.Conn) { if !ok { log.Fatalf("invalid type %T", cc.NetConn()) } - clientId := dtlsConn.ConnectionState().IdentityHint + state, ok := dtlsConn.ConnectionState() + if !ok { + log.Fatalf("cannot get connection state") + } + clientId := state.IdentityHint cc.SetContextValue("clientId", clientId) cc.AddOnClose(func() { - clientId := dtlsConn.ConnectionState().IdentityHint + clientId := state.IdentityHint log.Printf("closed connection clientId: %s", clientId) }) } diff --git a/go.mod b/go.mod index 6b1201f5..48f1a76e 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,7 @@ go 1.20 require ( github.com/dsnet/golib/memfile v1.0.0 - github.com/pion/dtls/v2 v2.2.8-0.20240701035148-45e16a098c47 - github.com/pion/transport/v3 v3.0.7 + github.com/pion/dtls/v3 v3.0.2 github.com/stretchr/testify v1.9.0 go.uber.org/atomic v1.11.0 golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 @@ -16,12 +15,9 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pion/logging v0.2.2 // indirect + github.com/pion/transport/v3 v3.0.7 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/crypto v0.26.0 // indirect golang.org/x/sys v0.24.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) - -// note: github.com/pion/dtls/v2/pkg/net package is not yet available in release branches, -// so we force to the use of the pinned master branch -replace github.com/pion/dtls/v2 => github.com/pion/dtls/v2 v2.2.8-0.20240701035148-45e16a098c47 diff --git a/go.sum b/go.sum index db014f2c..5c8dab3c 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dsnet/golib/memfile v1.0.0 h1:J9pUspY2bDCbF9o+YGwcf3uG6MdyITfh/Fk3/CaEiFs= github.com/dsnet/golib/memfile v1.0.0/go.mod h1:tXGNW9q3RwvWt1VV2qrRKlSSz0npnh12yftCSCy2T64= -github.com/pion/dtls/v2 v2.2.8-0.20240701035148-45e16a098c47 h1:WCUn5hJZLLMoOvedDEDA/OFzaYbZy7G71mQ9h5GiQ/o= -github.com/pion/dtls/v2 v2.2.8-0.20240701035148-45e16a098c47/go.mod h1:8eXNLDNOiXaHvo/wOFnFcr/yinEimCDUQ512tlOSvPo= +github.com/pion/dtls/v3 v3.0.2 h1:425DEeJ/jfuTTghhUDW0GtYZYIwwMtnKKJNMcWccTX0= +github.com/pion/dtls/v3 v3.0.2/go.mod h1:dfIXcFkKoujDQ+jtd8M6RgqKK3DuaUilm3YatAbGp5k= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/transport/v3 v3.0.7 h1:iRbMH05BzSNwhILHoBoAPxoB9xQgOaJk+591KC9P1o0= diff --git a/message/codes/codes.go b/message/codes/codes.go index 5878fe87..648a1d10 100644 --- a/message/codes/codes.go +++ b/message/codes/codes.go @@ -96,7 +96,7 @@ var strToCode = map[string]Code{ } func getMaxCodeLen() int { - // maxLen uint32 as string binary representation: "0b" + 32 digits + // max uint32 as string binary representation: "0b" + 32 digits maxLen := 34 for k := range strToCode { kLen := len(k) diff --git a/net/conn.go b/net/conn.go index 8d6faabe..4f8c6b0c 100644 --- a/net/conn.go +++ b/net/conn.go @@ -34,6 +34,10 @@ func NewConn(c net.Conn) *Conn { return &connection } +func NewDTLSConn(c net.Conn) *Conn { + return NewConn(c) +} + // LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it. func (c *Conn) LocalAddr() net.Addr { return c.connection.LocalAddr() diff --git a/net/dtlslistener.go b/net/dtlslistener.go index 4a7379b6..13c5dc58 100644 --- a/net/dtlslistener.go +++ b/net/dtlslistener.go @@ -2,145 +2,39 @@ package net import ( "context" - "errors" "fmt" "net" - "sync" - "time" - dtls "github.com/pion/dtls/v2" - dtlsnet "github.com/pion/dtls/v2/pkg/net" - "github.com/pion/dtls/v2/pkg/protocol" - "github.com/pion/dtls/v2/pkg/protocol/recordlayer" - "github.com/pion/transport/v3/udp" + dtls "github.com/pion/dtls/v3" "go.uber.org/atomic" ) -type GoPoolFunc = func(f func()) error - -var DefaultDTLSListenerConfig = DTLSListenerConfig{ - GoPool: func(f func()) error { - go f() - return nil - }, -} - -type DTLSListenerConfig struct { - GoPool GoPoolFunc -} - -type acceptedConn struct { - conn net.Conn - err error -} - // DTLSListener is a DTLS listener that provides accept with context. type DTLSListener struct { - listener net.Listener - config *dtls.Config - closed atomic.Bool - goPool GoPoolFunc - acceptedConnChan chan acceptedConn - wg sync.WaitGroup - done chan struct{} -} - -func tlsPacketFilter(packet []byte) bool { - pkts, err := recordlayer.UnpackDatagram(packet) - if err != nil || len(pkts) < 1 { - return false - } - h := &recordlayer.Header{} - if err := h.Unmarshal(pkts[0]); err != nil { - return false - } - return h.ContentType == protocol.ContentTypeHandshake + listener net.Listener + closed atomic.Bool } -// NewDTLSListener creates dtls listener. -// Known networks are "udp", "udp4" (IPv4-only), "udp6" (IPv6-only). -func NewDTLSListener(network string, addr string, dtlsCfg *dtls.Config, opts ...DTLSListenerOption) (*DTLSListener, error) { +func newNetDTLSListener(network string, addr string, dtlsCfg *dtls.Config) (net.Listener, error) { a, err := net.ResolveUDPAddr(network, addr) if err != nil { return nil, fmt.Errorf("cannot resolve address: %w", err) } - cfg := DefaultDTLSListenerConfig - for _, o := range opts { - o.ApplyDTLS(&cfg) - } - - if cfg.GoPool == nil { - return nil, errors.New("empty go pool") - } - - l := DTLSListener{ - goPool: cfg.GoPool, - config: dtlsCfg, - acceptedConnChan: make(chan acceptedConn, 256), - done: make(chan struct{}), - } - connectContextMaker := dtlsCfg.ConnectContextMaker - if connectContextMaker == nil { - connectContextMaker = func() (context.Context, func()) { - return context.WithTimeout(context.Background(), 30*time.Second) - } - } - dtlsCfg.ConnectContextMaker = func() (context.Context, func()) { - ctx, cancel := connectContextMaker() - if l.closed.Load() { - cancel() - } - return ctx, cancel - } - - lc := udp.ListenConfig{ - AcceptFilter: tlsPacketFilter, - } - l.listener, err = lc.Listen(network, a) + dtls, err := dtls.Listen(network, a, dtlsCfg) if err != nil { - return nil, err - } - l.wg.Add(1) - go l.run() - return &l, nil -} - -func (l *DTLSListener) send(conn net.Conn, err error) { - select { - case <-l.done: - case l.acceptedConnChan <- acceptedConn{ - conn: conn, - err: err, - }: + return nil, fmt.Errorf("cannot create new net dtls listener: %w", err) } + return dtls, nil } -func (l *DTLSListener) accept() error { - c, err := l.listener.Accept() - if err != nil { - l.send(nil, err) - return err - } - err = l.goPool(func() { - l.send(dtls.Server(dtlsnet.PacketConnFromConn(c), c.RemoteAddr(), l.config)) - }) +// NewDTLSListener creates dtls listener. +// Known networks are "udp", "udp4" (IPv4-only), "udp6" (IPv6-only). +func NewDTLSListener(network string, addr string, dtlsCfg *dtls.Config) (*DTLSListener, error) { + dtls, err := newNetDTLSListener(network, addr, dtlsCfg) if err != nil { - _ = c.Close() - } - return err -} - -func (l *DTLSListener) run() { - defer l.wg.Done() - for { - if l.closed.Load() { - return - } - err := l.accept() - if errors.Is(err, udp.ErrClosedListener) { - return - } + return nil, fmt.Errorf("cannot create new dtls listener: %w", err) } + return &DTLSListener{listener: dtls}, nil } // AcceptWithContext waits with context for a generic Conn. @@ -153,27 +47,7 @@ func (l *DTLSListener) AcceptWithContext(ctx context.Context) (net.Conn, error) if l.closed.Load() { return nil, ErrListenerIsClosed } - for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-l.done: - return nil, ErrListenerIsClosed - case d := <-l.acceptedConnChan: - err := d.err - if errors.Is(err, context.DeadlineExceeded) { - // we don't want to report error handshake deadline exceeded - continue - } - if errors.Is(err, udp.ErrClosedListener) { - return nil, ErrListenerIsClosed - } - if err != nil { - return nil, err - } - return d.conn, nil - } - } + return l.listener.Accept() } // Accept waits for a generic Conn. @@ -186,8 +60,6 @@ func (l *DTLSListener) Close() error { if !l.closed.CompareAndSwap(false, true) { return nil } - close(l.done) - defer l.wg.Wait() return l.listener.Close() } diff --git a/net/options.go b/net/options.go index c037ef9c..a18578ad 100644 --- a/net/options.go +++ b/net/options.go @@ -120,24 +120,3 @@ func (m MulticastInterfaceErrorOpt) applyMC(o *MulticastOptions) { func WithMulticastInterfaceError(interfaceError InterfaceError) MulticastOption { return &MulticastInterfaceErrorOpt{interfaceError: interfaceError} } - -// A DTLSListenerOption sets options such as gopool. -type DTLSListenerOption interface { - ApplyDTLS(*DTLSListenerConfig) -} - -// GoPoolOpt gopool option. -type GoPoolOpt struct { - goPool GoPoolFunc -} - -func (o GoPoolOpt) ApplyDTLS(cfg *DTLSListenerConfig) { - cfg.GoPool = o.goPool -} - -// WithGoPool sets function for managing spawning go routines -// for handling incoming request's. -// Eg: https://github.com/panjf2000/ants. -func WithGoPool(goPool GoPoolFunc) GoPoolOpt { - return GoPoolOpt{goPool: goPool} -} diff --git a/net/tcplistener.go b/net/tcplistener.go index 43cc955e..33f5404a 100644 --- a/net/tcplistener.go +++ b/net/tcplistener.go @@ -14,12 +14,11 @@ type TCPListener struct { closed atomic.Bool } -func newNetTCPListen(network string, addr string) (*net.TCPListener, error) { +func newNetTCPListener(network string, addr string) (*net.TCPListener, error) { a, err := net.ResolveTCPAddr(network, addr) if err != nil { - return nil, fmt.Errorf("cannot create new net tcp listener: %w", err) + return nil, fmt.Errorf("cannot resolve address: %w", err) } - tcp, err := net.ListenTCP(network, a) if err != nil { return nil, fmt.Errorf("cannot create new net tcp listener: %w", err) @@ -30,7 +29,7 @@ func newNetTCPListen(network string, addr string) (*net.TCPListener, error) { // NewTCPListener creates tcp listener. // Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only). func NewTCPListener(network string, addr string) (*TCPListener, error) { - tcp, err := newNetTCPListen(network, addr) + tcp, err := newNetTCPListener(network, addr) if err != nil { return nil, fmt.Errorf("cannot create new tcp listener: %w", err) } diff --git a/net/tlslistener.go b/net/tlslistener.go index df12b1a6..222e1bff 100644 --- a/net/tlslistener.go +++ b/net/tlslistener.go @@ -19,7 +19,7 @@ type TLSListener struct { // NewTLSListener creates tcp listener. // Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only). func NewTLSListener(network string, addr string, tlsCfg *tls.Config) (*TLSListener, error) { - tcp, err := newNetTCPListen(network, addr) + tcp, err := newNetTCPListener(network, addr) if err != nil { return nil, fmt.Errorf("cannot create new tls listener: %w", err) } @@ -40,11 +40,7 @@ func (l *TLSListener) AcceptWithContext(ctx context.Context) (net.Conn, error) { if l.closed.Load() { return nil, ErrListenerIsClosed } - rw, err := l.listener.Accept() - if err != nil { - return nil, err - } - return rw, nil + return l.listener.Accept() } // Accept waits for a generic Conn. diff --git a/net/tlslistener_test.go b/net/tlslistener_test.go index 5c9576c3..c36ea990 100644 --- a/net/tlslistener_test.go +++ b/net/tlslistener_test.go @@ -1,4 +1,4 @@ -package net +package net_test import ( "context" @@ -10,6 +10,7 @@ import ( "testing" "time" + coapNet "github.com/plgd-dev/go-coap/v3/net" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -82,7 +83,7 @@ func TestTLSListenerAcceptWithContext(t *testing.T) { }() config := SetTLSConfig(t) - listener, err := NewTLSListener("tcp", "127.0.0.1:", config) + listener, err := coapNet.NewTLSListener("tcp", "127.0.0.1:", config) require.NoError(t, err) defer func() { err := listener.Close() @@ -168,7 +169,7 @@ func TestTLSListenerCheckForInfinitLoop(t *testing.T) { }() config := SetTLSConfig(t) - listener, err := NewTLSListener("tcp", "127.0.0.1:", config) + listener, err := coapNet.NewTLSListener("tcp", "127.0.0.1:", config) require.NoError(t, err) defer func() { err := listener.Close() @@ -215,7 +216,7 @@ func TestTLSListenerCheckForInfinitLoop(t *testing.T) { } require.NoError(t, err) b := make([]byte, 1024) - c := NewConn(con) + c := coapNet.NewConn(con) _, err = c.ReadWithContext(context.Background(), b) require.Error(t, err) assert.Contains(t, err.Error(), "EOF") diff --git a/server.go b/server.go index bec94646..10d54a1f 100644 --- a/server.go +++ b/server.go @@ -6,7 +6,7 @@ import ( "errors" "fmt" - piondtls "github.com/pion/dtls/v2" + piondtls "github.com/pion/dtls/v3" "github.com/plgd-dev/go-coap/v3/dtls" dtlsServer "github.com/plgd-dev/go-coap/v3/dtls/server" "github.com/plgd-dev/go-coap/v3/mux" diff --git a/tcp/client_test.go b/tcp/client_test.go index 2a39c19c..5c4a1198 100644 --- a/tcp/client_test.go +++ b/tcp/client_test.go @@ -632,16 +632,16 @@ func TestClientKeepAliveMonitor(t *testing.T) { ) var serverWg sync.WaitGroup - defer func() { - sd.Stop() - serverWg.Wait() - }() serverWg.Add(1) go func() { defer serverWg.Done() errS := sd.Serve(ld) assert.NoError(t, errS) }() + defer func() { + sd.Stop() + serverWg.Wait() + }() cc, err := Dial( ld.Addr().String(), diff --git a/tcp/server/server.go b/tcp/server/server.go index 08f6e079..0a916170 100644 --- a/tcp/server/server.go +++ b/tcp/server/server.go @@ -17,11 +17,6 @@ import ( "github.com/plgd-dev/go-coap/v3/tcp/client" ) -// A Option sets options such as credentials, codec and keepalive parameters, etc. -type Option interface { - TCPServerApply(cfg *Config) -} - // Listener defined used by coap type Listener interface { Close() error @@ -36,6 +31,11 @@ type Server struct { cfg *Config } +// A Option sets options such as credentials, codec and keepalive parameters, etc. +type Option interface { + TCPServerApply(cfg *Config) +} + func New(opt ...Option) *Server { cfg := DefaultConfig for _, o := range opt { @@ -117,14 +117,7 @@ func (s *Server) checkAcceptError(err error) bool { } } -func (s *Server) serveConnection(connections *connections.Connections, rw net.Conn) { - var cc *client.Conn - inactivityMonitor := s.cfg.CreateInactivityMonitor() - requestMonitor := s.cfg.RequestMonitor - cc = s.createConn(coapNet.NewConn(rw), inactivityMonitor, requestMonitor) - if s.cfg.OnNewConn != nil { - s.cfg.OnNewConn(cc) - } +func (s *Server) serveConnection(connections *connections.Connections, cc *client.Conn) { connections.Store(cc) defer connections.Delete(cc) @@ -160,13 +153,19 @@ func (s *Server) Serve(l Listener) error { if ok := s.checkAcceptError(err); !ok { return nil } - if rw == nil { + if err != nil || rw == nil { continue } wg.Add(1) + inactivityMonitor := s.cfg.CreateInactivityMonitor() + requestMonitor := s.cfg.RequestMonitor + cc := s.createConn(coapNet.NewConn(rw), inactivityMonitor, requestMonitor) + if s.cfg.OnNewConn != nil { + s.cfg.OnNewConn(cc) + } go func() { defer wg.Done() - s.serveConnection(connections, rw) + s.serveConnection(connections, cc) }() } }