Skip to content

Commit

Permalink
fixup! Upgrade to github.com/pion/dtls/v3
Browse files Browse the repository at this point in the history
  • Loading branch information
Danielius1922 committed Aug 7, 2024
1 parent 5117fe4 commit af365e2
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 26 deletions.
2 changes: 1 addition & 1 deletion dtls/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func Client(conn *dtls.Conn, opts ...udp.Option) *udpClient.Conn {
}

monitor := cfg.CreateInactivityMonitor()
l := coapNet.NewConn(conn)
l := coapNet.NewDTLSConn(conn)
session := server.NewSession(cfg.Ctx,
l,
cfg.MaxMessageSize,
Expand Down
2 changes: 1 addition & 1 deletion dtls/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (s *Server) Serve(l Listener) error {
wg.Add(1)
inactivityMonitor := s.cfg.CreateInactivityMonitor()
requestMonitor := s.cfg.RequestMonitor
cc := s.createConn(coapNet.NewConn(rw), inactivityMonitor, requestMonitor)
cc := s.createConn(coapNet.NewDTLSConn(rw), inactivityMonitor, requestMonitor)
if s.cfg.OnNewConn != nil {
s.cfg.OnNewConn(cc)
}
Expand Down
34 changes: 15 additions & 19 deletions dtls/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,41 +151,37 @@ func TestServerSetContextValueWithPKI(t *testing.T) {
require.NoError(t, errC)
}()

onNewConn := func(cc *client.Conn) {
dtlsConn, ok := cc.NetConn().(*piondtls.Conn)
require.True(t, ok)
state, k := dtlsConn.ConnectionState()
require.True(t, k)
// set connection context certificate
clientCert, errP := x509.ParseCertificate(state.PeerCertificates[0])
require.NoError(t, errP)
cc.SetContextValue("client-cert", clientCert)
}
handle := func(w *responsewriter.ResponseWriter[*client.Conn], r *pool.Message) {
// get certificate from connection context
var clientCert *x509.Certificate
value := r.Context().Value("client-cert")
if value == nil {
dtlsConn, ok := w.Conn().NetConn().(*piondtls.Conn)
require.True(t, ok)
state, k := dtlsConn.ConnectionState()
require.True(t, k)
assert.True(t, ok)
state, ok := dtlsConn.ConnectionState()
assert.True(t, ok)
var errP error
clientCert, errP = x509.ParseCertificate(state.PeerCertificates[0])
require.NoError(t, errP)
assert.NoError(t, errP) //nolint:testifylint
w.Conn().SetContextValue("client-cert", clientCert)
} else {
clientCert = r.Context().Value("client-cert").(*x509.Certificate)
}
require.Equal(t, clientCert.SerialNumber, clientSerial)
require.NotNil(t, clientCert)
assert.Equal(t, clientCert.SerialNumber, clientSerial)
assert.NotNil(t, clientCert)
errH := w.SetResponse(codes.Content, message.TextPlain, bytes.NewReader([]byte("done")))
require.NoError(t, errH)
assert.NoError(t, errH)
}

sd := dtls.NewServer(options.WithHandlerFunc(handle), options.WithOnNewConn(onNewConn))
defer sd.Stop()
sd := dtls.NewServer(options.WithHandlerFunc(handle))
var wg sync.WaitGroup
defer func() {
sd.Stop()
wg.Wait()
}()
wg.Add(1)
go func() {
defer wg.Done()
errS := sd.Serve(ld)
assert.NoError(t, errS)
}()
Expand Down
42 changes: 37 additions & 5 deletions net/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,45 @@ import (
"go.uber.org/atomic"
)

type ConnOption interface {
Apply(*ConnConfig)
}

type ConnConfig struct {
disabledManualCloseAfterHandshake bool
}

func (h ConnConfig) Apply(o *ConnConfig) {
o.disabledManualCloseAfterHandshake = h.disabledManualCloseAfterHandshake
}

func WithDisabledManualCloseAfterHandshake() ConnConfig {
return ConnConfig{
disabledManualCloseAfterHandshake: true,
}
}

// Conn is a generic stream-oriented network connection that provides Read/Write with context.
//
// Multiple goroutines may invoke methods on a Conn simultaneously.
type Conn struct {
connection net.Conn
closed atomic.Bool
cfg ConnConfig
handshakeContext func(ctx context.Context) error
lock sync.Mutex
}

// NewConn creates connection over net.Conn.
func NewConn(c net.Conn) *Conn {
func NewConn(c net.Conn, opts ...ConnOption) *Conn {
cfg := ConnConfig{}
for _, o := range opts {
o.Apply(&cfg)
}

connection := Conn{
connection: c,
cfg: cfg,
}

if v, ok := c.(interface {
Expand All @@ -34,6 +59,10 @@ func NewConn(c net.Conn) *Conn {
return &connection
}

func NewDTLSConn(c net.Conn) *Conn {
return NewConn(c, WithDisabledManualCloseAfterHandshake())
}

// LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
func (c *Conn) LocalAddr() net.Addr {
return c.connection.LocalAddr()
Expand Down Expand Up @@ -63,11 +92,14 @@ func (c *Conn) handshake(ctx context.Context) error {
if err == nil {
return nil
}
errC := c.Close()
if errC == nil {
return err
if !c.cfg.disabledManualCloseAfterHandshake {
errC := c.Close()
if errC == nil {
return err
}
return fmt.Errorf("%v", []error{err, errC})
}
return fmt.Errorf("%v", []error{err, errC})
return err
}
return nil
}
Expand Down

0 comments on commit af365e2

Please sign in to comment.