Skip to content

Commit

Permalink
fixup! Upgrade to github.com/pion/dtls/v3
Browse files Browse the repository at this point in the history
  • Loading branch information
Danielius1922 committed Sep 12, 2024
1 parent 6867fc7 commit c46fcdd
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 22 deletions.
2 changes: 1 addition & 1 deletion dtls/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func Client(conn *dtls.Conn, opts ...udp.Option) *udpClient.Conn {
}

monitor := cfg.CreateInactivityMonitor()
l := coapNet.NewDTLSConn(conn)
l := coapNet.NewConn(conn)
session := server.NewSession(cfg.Ctx,
l,
cfg.MaxMessageSize,
Expand Down
1 change: 1 addition & 0 deletions dtls/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,5 @@ type Config struct {
TransmissionAcknowledgeTimeout time.Duration
TransmissionMaxRetransmit uint32
MTU uint16
ForceHandshakeOnConnect bool
}
9 changes: 8 additions & 1 deletion dtls/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,14 @@ func (s *Server) Serve(l Listener) error {
wg.Add(1)
inactivityMonitor := s.cfg.CreateInactivityMonitor()
requestMonitor := s.cfg.RequestMonitor
cc := s.createConn(coapNet.NewDTLSConn(rw), inactivityMonitor, requestMonitor)
dtlsConn := coapNet.NewConn(rw)
cc := s.createConn(dtlsConn, inactivityMonitor, requestMonitor)
if s.cfg.ForceHandshakeOnConnect {
err = dtlsConn.HandshakeContext(s.ctx)
if err != nil {
return err
}
}
if s.cfg.OnNewConn != nil {
s.cfg.OnNewConn(cc)
}
Expand Down
28 changes: 13 additions & 15 deletions dtls/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,29 +151,27 @@ func TestServerSetContextValueWithPKI(t *testing.T) {
require.NoError(t, errC)
}()

onNewConn := func(cc *client.Conn) {
dtlsConn, ok := cc.NetConn().(*piondtls.Conn)
assert.True(t, ok)
// set connection context certificate
state, ok := dtlsConn.ConnectionState()
assert.True(t, ok)
clientCert, errP := x509.ParseCertificate(state.PeerCertificates[0])
assert.NoError(t, errP) //nolint:testifylint
cc.SetContextValue("client-cert", clientCert)
}
handle := func(w *responsewriter.ResponseWriter[*client.Conn], r *pool.Message) {
// get certificate from connection context
var clientCert *x509.Certificate
value := r.Context().Value("client-cert")
if value == nil {
dtlsConn, ok := w.Conn().NetConn().(*piondtls.Conn)
assert.True(t, ok)
state, ok := dtlsConn.ConnectionState()
assert.True(t, ok)
var errP error
clientCert, errP = x509.ParseCertificate(state.PeerCertificates[0])
assert.NoError(t, errP) //nolint:testifylint
w.Conn().SetContextValue("client-cert", clientCert)
} else {
clientCert = r.Context().Value("client-cert").(*x509.Certificate)
}
clientCert := r.Context().Value("client-cert").(*x509.Certificate)
assert.Equal(t, clientCert.SerialNumber, clientSerial)
assert.NotNil(t, clientCert)
errH := w.SetResponse(codes.Content, message.TextPlain, bytes.NewReader([]byte("done")))
assert.NoError(t, errH)
}

sd := dtls.NewServer(options.WithHandlerFunc(handle))
sd := dtls.NewServer(options.WithHandlerFunc(handle), options.WithOnNewConn(onNewConn),
options.WithForceHandshakeOnConnect())
var wg sync.WaitGroup
defer func() {
sd.Stop()
Expand Down
12 changes: 8 additions & 4 deletions net/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package net

import (
"context"
"errors"
"fmt"
"net"
"sync"
Expand Down Expand Up @@ -34,10 +35,6 @@ func NewConn(c net.Conn) *Conn {
return &connection
}

func NewDTLSConn(c net.Conn) *Conn {
return NewConn(c)
}

// LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
func (c *Conn) LocalAddr() net.Addr {
return c.connection.LocalAddr()
Expand All @@ -61,6 +58,13 @@ func (c *Conn) Close() error {
return c.connection.Close()
}

func (c *Conn) HandshakeContext(ctx context.Context) error {
if c.handshakeContext == nil {
return errors.New("handshake not supported")
}
return c.handshakeContext(ctx)
}

func (c *Conn) handshake(ctx context.Context) error {
if c.handshakeContext != nil {
err := c.handshakeContext(ctx)
Expand Down
13 changes: 13 additions & 0 deletions options/dtlsOptions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package options

import dtlsServer "github.com/plgd-dev/go-coap/v3/dtls/server"

type ForceHandshakeOnConnectOpt struct{}

func (o ForceHandshakeOnConnectOpt) DTLSServerApply(cfg *dtlsServer.Config) {
cfg.ForceHandshakeOnConnect = true
}

func WithForceHandshakeOnConnect() ForceHandshakeOnConnectOpt {
return ForceHandshakeOnConnectOpt{}
}
2 changes: 1 addition & 1 deletion udp/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func NewConnWithOpts(session Session, cfg *Config, opts ...Option) *Conn {
messagePool: cfg.MessagePool,
numOutstandingInteraction: semaphore.NewWeighted(math.MaxInt64),
}
cc.msgID.Store(uint32(cfg.GetMID() - 0xffff/2))
cc.msgID.Store(pkgMath.CastTo[uint32](cfg.GetMID() - 0xffff/2))
cc.blockWise = cfgOpts.createBlockWise(&cc)
limitParallelRequests := limitparallelrequests.New(cfg.LimitClientParallelRequests, cfg.LimitClientEndpointParallelRequests, cc.do, cc.doObserve)
cc.observationHandler = observation.NewHandler(&cc, cfg.Handler, limitParallelRequests.Do)
Expand Down

0 comments on commit c46fcdd

Please sign in to comment.