Skip to content

Commit

Permalink
Upgrade to github.com/pion/dtls/v3
Browse files Browse the repository at this point in the history
  • Loading branch information
Danielius1922 committed Aug 7, 2024
1 parent 1285834 commit 58b1952
Show file tree
Hide file tree
Showing 20 changed files with 123 additions and 309 deletions.
3 changes: 0 additions & 3 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ linters-settings:
enable:
- nilness
- shadow
gomoddirectives:
replace-allow-list:
- github.com/pion/dtls/v2

linters:
enable:
Expand Down
4 changes: 2 additions & 2 deletions dtls/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"fmt"
"time"

"github.com/pion/dtls/v2"
dtlsnet "github.com/pion/dtls/v2/pkg/net"
"github.com/pion/dtls/v3"
dtlsnet "github.com/pion/dtls/v3/pkg/net"
"github.com/plgd-dev/go-coap/v3/dtls/server"
"github.com/plgd-dev/go-coap/v3/message"
"github.com/plgd-dev/go-coap/v3/message/codes"
Expand Down
123 changes: 43 additions & 80 deletions dtls/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
"testing"
"time"

piondtls "github.com/pion/dtls/v2"
piondtls "github.com/pion/dtls/v3"
"github.com/plgd-dev/go-coap/v3/dtls"
"github.com/plgd-dev/go-coap/v3/message"
"github.com/plgd-dev/go-coap/v3/message/codes"
Expand All @@ -20,7 +20,6 @@ import (
coapNet "github.com/plgd-dev/go-coap/v3/net"
"github.com/plgd-dev/go-coap/v3/net/responsewriter"
"github.com/plgd-dev/go-coap/v3/options"
"github.com/plgd-dev/go-coap/v3/options/config"
"github.com/plgd-dev/go-coap/v3/pkg/runner/periodic"
"github.com/plgd-dev/go-coap/v3/udp/client"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -123,7 +122,7 @@ func TestConnGet(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600)
ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
got, err := cc.Get(ctx, tt.args.path, tt.args.opts...)
if tt.wantErr {
Expand Down Expand Up @@ -216,7 +215,7 @@ func TestConnGetSeparateMessage(t *testing.T) {
require.NoError(t, errC)
}()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600)
ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()

req, err := cc.NewGetRequest(ctx, "/a")
Expand Down Expand Up @@ -340,7 +339,7 @@ func TestConnPost(t *testing.T) {
require.NoError(t, errC)
}()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600)
ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
got, err := cc.Post(ctx, tt.args.path, tt.args.contentFormat, tt.args.payload, tt.args.opts...)
if tt.wantErr {
Expand Down Expand Up @@ -475,7 +474,7 @@ func TestConnPut(t *testing.T) {
require.NoError(t, errC)
}()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600)
ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
got, err := cc.Put(ctx, tt.args.path, tt.args.contentFormat, tt.args.payload, tt.args.opts...)
if tt.wantErr {
Expand Down Expand Up @@ -590,7 +589,7 @@ func TestConnDelete(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600)
ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
got, err := cc.Delete(ctx, tt.args.path, tt.args.opts...)
if tt.wantErr {
Expand Down Expand Up @@ -654,58 +653,12 @@ func TestConnPing(t *testing.T) {
require.NoError(t, err)
}

func TestConnHandeShakeFailure(t *testing.T) {
dtlsCfg := &piondtls.Config{
PSK: func(hint []byte) ([]byte, error) {
fmt.Printf("Hint: %s \n", hint)
return []byte{0xAB, 0xC1, 0x23}, nil
},
PSKIdentityHint: []byte("Pion DTLS Server"),
CipherSuites: []piondtls.CipherSuiteID{piondtls.TLS_PSK_WITH_AES_128_CCM_8},
ConnectContextMaker: func() (context.Context, func()) {
return context.WithTimeout(context.Background(), 1*time.Second)
},
}
l, err := coapNet.NewDTLSListener("udp", "", dtlsCfg)
require.NoError(t, err)
defer func() {
errC := l.Close()
require.NoError(t, errC)
}()
var wg sync.WaitGroup
defer wg.Wait()

s := dtls.NewServer()
defer s.Stop()

wg.Add(1)
go func() {
defer wg.Done()
errS := s.Serve(l)
assert.NoError(t, errS)
}()

dtlsCfgClient := &piondtls.Config{
PSK: func(hint []byte) ([]byte, error) {
fmt.Printf("Hint: %s \n", hint)
return []byte{0xAB, 0xC1, 0x24}, nil
},
PSKIdentityHint: []byte("Pion DTLS Client"),
CipherSuites: []piondtls.CipherSuiteID{piondtls.TLS_PSK_WITH_AES_128_CCM_8},
ConnectContextMaker: func() (context.Context, func()) {
return context.WithTimeout(context.Background(), 1*time.Second)
},
}
_, err = dtls.Dial(l.Addr().String(), dtlsCfgClient)
require.Error(t, err)
}

func TestClientInactiveMonitor(t *testing.T) {
var inactivityDetected atomic.Bool

ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
serverCgf, clientCgf, _, err := createDTLSConfig(ctx)
serverCgf, clientCgf, _, err := createDTLSConfig()
require.NoError(t, err)

ld, err := coapNet.NewDTLSListener("udp4", "", serverCgf)
Expand Down Expand Up @@ -745,7 +698,9 @@ func TestClientInactiveMonitor(t *testing.T) {
serverWg.Wait()
}()

cc, err := dtls.Dial(ld.Addr().String(), clientCgf,
cc, err := dtls.Dial(
ld.Addr().String(),
clientCgf,
options.WithInactivityMonitor(100*time.Millisecond, func(cc *client.Conn) {
require.False(t, inactivityDetected.Load())
inactivityDetected.Store(true)
Expand Down Expand Up @@ -774,65 +729,73 @@ func TestClientInactiveMonitor(t *testing.T) {
func TestClientKeepAliveMonitor(t *testing.T) {
var inactivityDetected atomic.Bool

ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
serverCgf, clientCgf, _, err := createDTLSConfig(ctx)
serverCgf, clientCgf, _, err := createDTLSConfig()
require.NoError(t, err)

ld, err := coapNet.NewDTLSListener("udp4", "", serverCgf)
require.NoError(t, err)
defer func() {
errC := ld.Close()
require.NoError(t, errC)
}()

checkClose := semaphore.NewWeighted(1)
err = checkClose.Acquire(ctx, 1)
ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()

checkClose := semaphore.NewWeighted(2)
err = checkClose.Acquire(ctx, 2)
require.NoError(t, err)
sd := dtls.NewServer(
options.WithOnNewConn(func(cc *client.Conn) {
t.Log("server - new connection")
cc.AddOnClose(func() {
t.Log("server - client is closed")
checkClose.Release(1)
})
}),
options.WithPeriodicRunner(periodic.New(ctx.Done(), time.Millisecond*10)),
options.WithRequestMonitor(func(_ *client.Conn, req *pool.Message) (bool, error) {
t.Logf("server - received message: %+v\n", req)
// lets drop all messages, this will trigger keep alive because of inactivity
return true, nil
}),
)

var serverWg sync.WaitGroup
serverWg.Add(1)
go func() {
defer serverWg.Done()
for {
c, errA := ld.AcceptWithContext(ctx)
if errA != nil {
if errors.Is(errA, coapNet.ErrListenerIsClosed) {
return
}
}
defer c.Close()
assert.NoError(t, errA)
}
errS := sd.Serve(ld)
assert.NoError(t, errS)
}()
defer func() {
errC := ld.Close()
require.NoError(t, errC)
sd.Stop()
serverWg.Wait()
}()

cc, err := dtls.Dial(
ld.Addr().String(),
clientCgf,
options.WithKeepAlive(3, 100*time.Millisecond, func(cc *client.Conn) {
t.Log("client - close for inactivity")
require.False(t, inactivityDetected.Load())
inactivityDetected.Store(true)
errC := cc.Close()
require.NoError(t, errC)
}),
options.WithPeriodicRunner(periodic.New(ctx.Done(), time.Millisecond*10)),
options.WithReceivedMessageQueueSize(32),
options.WithProcessReceivedMessageFunc(func(req *pool.Message, cc *client.Conn, handler config.HandlerFunc[*client.Conn]) {
cc.ProcessReceivedMessageWithHandler(req, handler)
}),
)
require.NoError(t, err)
cc.AddOnClose(func() {
t.Log("connection is closed")
checkClose.Release(1)
})

// send ping to create server side connection
ctxPing, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
err = cc.Ping(ctxPing)
require.Error(t, err)
_ = cc.Ping(ctxPing)

err = checkClose.Acquire(ctx, 1)
err = checkClose.Acquire(ctx, 2)
require.NoError(t, err)
require.True(t, inactivityDetected.Load())
}
2 changes: 1 addition & 1 deletion dtls/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"log"
"time"

piondtls "github.com/pion/dtls/v2"
piondtls "github.com/pion/dtls/v3"
"github.com/plgd-dev/go-coap/v3/dtls"
"github.com/plgd-dev/go-coap/v3/net"
)
Expand Down
31 changes: 16 additions & 15 deletions dtls/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ func (s *Server) checkAndSetListener(l Listener) error {
return nil
}

func (s *Server) popListener() Listener {
s.listenMutex.Lock()
defer s.listenMutex.Unlock()
l := s.listen
s.listen = nil
return l
}

func (s *Server) checkAcceptError(err error) bool {
if err == nil {
return true
Expand Down Expand Up @@ -134,11 +142,8 @@ func (s *Server) Serve(l Listener) error {
return err
}
defer func() {
s.listenMutex.Lock()
defer s.listenMutex.Unlock()
s.listen = nil
s.Stop()
}()

var wg sync.WaitGroup
defer wg.Wait()

Expand All @@ -158,11 +163,9 @@ func (s *Server) Serve(l Listener) error {
continue
}
wg.Add(1)
var cc *udpClient.Conn
inactivityMonitor := s.cfg.CreateInactivityMonitor()
requestMonitor := s.cfg.RequestMonitor

cc = s.createConn(coapNet.NewConn(rw), inactivityMonitor, requestMonitor)
cc := s.createConn(coapNet.NewConn(rw), inactivityMonitor, requestMonitor)
if s.cfg.OnNewConn != nil {
s.cfg.OnNewConn(cc)
}
Expand All @@ -176,14 +179,12 @@ func (s *Server) Serve(l Listener) error {
// Stop stops server without wait of ends Serve function.
func (s *Server) Stop() {
s.cancel()
s.listenMutex.Lock()
l := s.listen
s.listen = nil
s.listenMutex.Unlock()
if l != nil {
if err := l.Close(); err != nil {
s.cfg.Errors(fmt.Errorf("cannot close listener: %w", err))
}
l := s.popListener()
if l == nil {
return
}
if err := l.Close(); err != nil {
s.cfg.Errors(fmt.Errorf("cannot close listener: %w", err))
}
}

Expand Down
23 changes: 11 additions & 12 deletions dtls/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"testing"
"time"

piondtls "github.com/pion/dtls/v2"
piondtls "github.com/pion/dtls/v3"
"github.com/plgd-dev/go-coap/v3/dtls"
"github.com/plgd-dev/go-coap/v3/examples/dtls/pki"
"github.com/plgd-dev/go-coap/v3/message"
Expand Down Expand Up @@ -85,7 +85,7 @@ func TestServerCleanUpConns(t *testing.T) {
<-cc.Done()
}

func createDTLSConfig(ctx context.Context) (serverConfig *piondtls.Config, clientConfig *piondtls.Config, clientSerial *big.Int, err error) {
func createDTLSConfig() (serverConfig *piondtls.Config, clientConfig *piondtls.Config, clientSerial *big.Int, err error) {
// root cert
ca, rootBytes, _, caPriv, err := pki.GenerateCA()
if err != nil {
Expand All @@ -111,9 +111,6 @@ func createDTLSConfig(ctx context.Context) (serverConfig *piondtls.Config, clien
ExtendedMasterSecret: piondtls.RequireExtendedMasterSecret,
ClientCAs: certPool,
ClientAuth: piondtls.RequireAndVerifyClientCert,
ConnectContextMaker: func() (context.Context, func()) {
return context.WithTimeout(ctx, 30*time.Second)
},
}

// client cert
Expand Down Expand Up @@ -142,9 +139,9 @@ func createDTLSConfig(ctx context.Context) (serverConfig *piondtls.Config, clien
}

func TestServerSetContextValueWithPKI(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3600)
ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
serverCgf, clientCgf, clientSerial, err := createDTLSConfig(ctx)
serverCgf, clientCgf, clientSerial, err := createDTLSConfig()
require.NoError(t, err)

ld, err := coapNet.NewDTLSListener("udp4", "", serverCgf)
Expand All @@ -157,8 +154,10 @@ func TestServerSetContextValueWithPKI(t *testing.T) {
onNewConn := func(cc *client.Conn) {
dtlsConn, ok := cc.NetConn().(*piondtls.Conn)
require.True(t, ok)
state, ok := dtlsConn.ConnectionState()
require.True(t, ok)
// set connection context certificate
clientCert, errP := x509.ParseCertificate(dtlsConn.ConnectionState().PeerCertificates[0])
clientCert, errP := x509.ParseCertificate(state.PeerCertificates[0])
require.NoError(t, errP)
cc.SetContextValue("client-cert", clientCert)
}
Expand Down Expand Up @@ -191,9 +190,9 @@ func TestServerSetContextValueWithPKI(t *testing.T) {
func TestServerInactiveMonitor(t *testing.T) {
var inactivityDetected atomic.Bool

ctx, cancel := context.WithTimeout(context.Background(), time.Second*8)
ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
serverCgf, clientCgf, _, err := createDTLSConfig(ctx)
serverCgf, clientCgf, _, err := createDTLSConfig()
require.NoError(t, err)

ld, err := coapNet.NewDTLSListener("udp4", "", serverCgf)
Expand Down Expand Up @@ -266,9 +265,9 @@ func TestServerInactiveMonitor(t *testing.T) {
func TestServerKeepAliveMonitor(t *testing.T) {
var inactivityDetected atomic.Bool

ctx, cancel := context.WithTimeout(context.Background(), time.Second*8)
ctx, cancel := context.WithTimeout(context.Background(), Timeout)
defer cancel()
serverCgf, clientCgf, _, err := createDTLSConfig(ctx)
serverCgf, clientCgf, _, err := createDTLSConfig()
require.NoError(t, err)

ld, err := coapNet.NewDTLSListener("udp4", "", serverCgf)
Expand Down
Loading

0 comments on commit 58b1952

Please sign in to comment.