diff --git a/dtls/server/config.go b/dtls/server/config.go index f4859699..245b4433 100644 --- a/dtls/server/config.go +++ b/dtls/server/config.go @@ -34,8 +34,8 @@ var DefaultConfig = func() Config { } return inactivity.New(timeout, onInactive) }, - RequestMonitor: func(cc *udpClient.Conn, req *pool.Message) error { - return nil + RequestMonitor: func(cc *udpClient.Conn, req *pool.Message) (bool, error) { + return false, nil }, OnNewConn: func(cc *udpClient.Conn) { // do nothing by default diff --git a/options/commonOptions.go b/options/commonOptions.go index 736e6c38..1c09b89a 100644 --- a/options/commonOptions.go +++ b/options/commonOptions.go @@ -637,7 +637,10 @@ func (o WithRequestMonitorOpt[F]) TCPServerApply(cfg *tcpServer.Config) { } } -// WithRequestMonitor ping handler +// WithRequestMonitor enables request monitoring for the connection. +// It is called for each CoAP message received from the peer before it is processed. +// If it returns an error, the connection is closed. +// If it returns true, the message is dropped. func WithRequestMonitor[F WithRequestMonitorFunc](requestMonitor F) WithRequestMonitorOpt[F] { return WithRequestMonitorOpt[F]{ f: requestMonitor, diff --git a/tcp/client/config.go b/tcp/client/config.go index 2551a64e..b3c71857 100644 --- a/tcp/client/config.go +++ b/tcp/client/config.go @@ -20,8 +20,8 @@ var DefaultConfig = func() Config { CreateInactivityMonitor: func() InactivityMonitor { return inactivity.NewNilMonitor[*Conn]() }, - RequestMonitor: func(*Conn, *pool.Message) error { - return nil + RequestMonitor: func(*Conn, *pool.Message) (bool, error) { + return false, nil }, Dialer: &net.Dialer{Timeout: time.Second * 3}, Net: "tcp", diff --git a/tcp/client/conn.go b/tcp/client/conn.go index 21b6c330..ecc8e879 100644 --- a/tcp/client/conn.go +++ b/tcp/client/conn.go @@ -33,7 +33,7 @@ type ( EventFunc = func() GetMIDFunc = func() int32 CreateInactivityMonitorFunc = func() InactivityMonitor - RequestMonitorFunc = func(cc *Conn, req *pool.Message) error + RequestMonitorFunc = func(cc *Conn, req *pool.Message) (drop bool, err error) ) type Notifier interface { @@ -78,7 +78,10 @@ func WithInactivityMonitor(inactivityMonitor InactivityMonitor) Option { } } -// WithRequestMonitor enables request monitor for the connection. +// WithRequestMonitor enables request monitoring for the connection. +// It is called for each CoAP message received from the peer before it is processed. +// If it returns an error, the connection is closed. +// If it returns true, the message is dropped. func WithRequestMonitor(requestMonitor RequestMonitorFunc) Option { return func(opts *ConnOptions) { opts.RequestMonitor = requestMonitor @@ -104,8 +107,8 @@ func NewConnWithOpts(connection *coapNet.Conn, cfg *Config, opts ...Option) *Con return nil }, InactivityMonitor: inactivity.NewNilMonitor[*Conn](), - RequestMonitor: func(*Conn, *pool.Message) error { - return nil + RequestMonitor: func(*Conn, *pool.Message) (bool, error) { + return false, nil }, } for _, o := range opts { diff --git a/tcp/client/session.go b/tcp/client/session.go index f7ddb827..ecf58415 100644 --- a/tcp/client/session.go +++ b/tcp/client/session.go @@ -63,8 +63,8 @@ func NewSession( inactivityMonitor = inactivity.NewNilMonitor[*Conn]() } if requestMonitor == nil { - requestMonitor = func(*Conn, *pool.Message) error { - return nil + requestMonitor = func(*Conn, *pool.Message) (bool, error) { + return false, nil } } @@ -183,10 +183,15 @@ func (s *Session) processBuffer(buffer *bytes.Buffer, cc *Conn) error { buffer = seekBufferToNextMessage(buffer, read) req.SetSequence(s.Sequence()) - if err = s.requestMonitor(cc, req); err != nil { + drop, err := s.requestMonitor(cc, req) + if err != nil { s.messagePool.ReleaseMessage(req) return fmt.Errorf("request monitor: %w", err) } + if drop { + s.messagePool.ReleaseMessage(req) + continue + } s.inactivityMonitor.Notify() cc.pushToReceivedMessageQueue(req) } diff --git a/tcp/client_test.go b/tcp/client_test.go index 0294791c..405fa187 100644 --- a/tcp/client_test.go +++ b/tcp/client_test.go @@ -669,7 +669,7 @@ func TestClientKeepAliveMonitor(t *testing.T) { require.True(t, inactivityDetected.Load()) } -func TestConnRequestMonitor(t *testing.T) { +func TestConnRequestMonitorCloseConnection(t *testing.T) { l, err := coapNet.NewTCPListener("tcp", "") require.NoError(t, err) defer func() { @@ -695,11 +695,11 @@ func TestConnRequestMonitor(t *testing.T) { testEOFError := errors.New("test error") s := NewServer( options.WithMux(m), - options.WithRequestMonitor(func(c *client.Conn, req *pool.Message) error { + options.WithRequestMonitor(func(c *client.Conn, req *pool.Message) (bool, error) { if req.Code() == codes.DELETE { - return testEOFError + return false, testEOFError } - return nil + return false, nil }), options.WithErrors(func(err error) { t.Log(err) @@ -754,3 +754,81 @@ func TestConnRequestMonitor(t *testing.T) { require.Error(t, err) <-reqMonitorErr } + +func TestConnRequestMonitorDropRequest(t *testing.T) { + l, err := coapNet.NewTCPListener("tcp", "") + require.NoError(t, err) + defer func() { + errC := l.Close() + require.NoError(t, errC) + }() + var wg sync.WaitGroup + defer wg.Wait() + + m := mux.NewRouter() + + // The response counts up with every get + // so we can check if the handler is only called once per message ID + err = m.Handle("/test", mux.HandlerFunc(func(w mux.ResponseWriter, r *mux.Message) { + errH := w.SetResponse(codes.Content, message.TextPlain, nil) + require.NoError(t, errH) + })) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + s := NewServer( + options.WithMux(m), + options.WithRequestMonitor(func(c *client.Conn, req *pool.Message) (bool, error) { + if req.Code() == codes.DELETE { + t.Log("drop request") + return true, nil + } + return false, nil + }), + options.WithErrors(func(err error) { + require.NoError(t, err) + }), + options.WithOnNewConn(func(c *client.Conn) { + t.Log("new conn") + c.AddOnClose(func() { + t.Log("close conn") + }) + })) + defer s.Stop() + wg.Add(1) + go func() { + defer wg.Done() + errS := s.Serve(l) + require.NoError(t, errS) + }() + + cc, err := Dial(l.Addr().String(), + options.WithErrors(func(err error) { + t.Log(err) + }), + ) + require.NoError(t, err) + defer func() { + errC := cc.Close() + require.NoError(t, errC) + }() + + // Setup done - Run Tests + // Same request, executed twice (needs the same token) + getReq, err := cc.NewGetRequest(ctx, "/test") + getReq.SetMessageID(1) + + require.NoError(t, err) + got, err := cc.Do(getReq) + require.NoError(t, err) + require.Equal(t, codes.Content.String(), got.Code().String()) + + // New request but with DELETE code to trigger EOF error from request monitor + deleteReq, err := cc.NewDeleteRequest(ctx, "/test") + require.NoError(t, err) + deleteReq.SetMessageID(2) + _, err = cc.Do(deleteReq) + require.Error(t, err) + require.True(t, errors.Is(err, context.DeadlineExceeded)) +} diff --git a/tcp/server/config.go b/tcp/server/config.go index 48f5fb2d..25a59b07 100644 --- a/tcp/server/config.go +++ b/tcp/server/config.go @@ -41,8 +41,8 @@ var DefaultConfig = func() Config { OnNewConn: func(cc *client.Conn) { // do nothing by default }, - RequestMonitor: func(cc *client.Conn, req *pool.Message) error { - return nil + RequestMonitor: func(cc *client.Conn, req *pool.Message) (bool, error) { + return false, nil }, ConnectionCacheSize: 2 * 1024, } diff --git a/udp/client.go b/udp/client.go index 26a1bcb0..c2461181 100644 --- a/udp/client.go +++ b/udp/client.go @@ -58,11 +58,6 @@ func Client(conn *net.UDPConn, opts ...Option) *client.Conn { if cfg.MessagePool == nil { cfg.MessagePool = pool.New(0, 0) } - if cfg.RequestMonitor == nil { - cfg.RequestMonitor = func(*client.Conn, *pool.Message) error { - return nil - } - } errorsFunc := cfg.Errors cfg.Errors = func(err error) { diff --git a/udp/client/config.go b/udp/client/config.go index f42d3fbe..1d1d4a7d 100644 --- a/udp/client/config.go +++ b/udp/client/config.go @@ -21,8 +21,8 @@ var DefaultConfig = func() Config { CreateInactivityMonitor: func() InactivityMonitor { return inactivity.NewNilMonitor[*Conn]() }, - RequestMonitor: func(*Conn, *pool.Message) error { - return nil + RequestMonitor: func(*Conn, *pool.Message) (bool, error) { + return false, nil }, Dialer: &net.Dialer{Timeout: time.Second * 3}, Net: "udp", diff --git a/udp/client/conn.go b/udp/client/conn.go index 9c28e2c3..14586c24 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -38,7 +38,7 @@ type ( EventFunc = func() GetMIDFunc = func() int32 CreateInactivityMonitorFunc = func() InactivityMonitor - RequestMonitorFunc = func(cc *Conn, req *pool.Message) error + RequestMonitorFunc = func(cc *Conn, req *pool.Message) (drop bool, err error) ) type InactivityMonitor interface { @@ -209,7 +209,10 @@ func WithInactivityMonitor(inactivityMonitor InactivityMonitor) Option { } } -// WithRequestMonitor enables request monitor for the connection. +// WithRequestMonitor enables request monitoring for the connection. +// It is called for each CoAP message received from the peer before it is processed. +// If it returns an error, the connection is closed. +// If it returns true, the message is dropped. func WithRequestMonitor(requestMonitor RequestMonitorFunc) Option { return func(opts *ConnOptions) { opts.requestMonitor = requestMonitor @@ -237,8 +240,8 @@ func NewConnWithOpts(session Session, cfg *Config, opts ...Option) *Conn { return nil }, inactivityMonitor: inactivity.NewNilMonitor[*Conn](), - requestMonitor: func(*Conn, *pool.Message) error { - return nil + requestMonitor: func(*Conn, *pool.Message) (bool, error) { + return false, nil }, } for _, o := range opts { @@ -844,10 +847,15 @@ func (cc *Conn) Process(cm *coapNet.ControlMessage, datagram []byte) error { req.SetControlMessage(cm) req.SetSequence(cc.Sequence()) cc.checkMyMessageID(req) - if err = cc.requestMonitor(cc, req); err != nil { + drop, err := cc.requestMonitor(cc, req) + if err != nil { cc.ReleaseMessage(req) return fmt.Errorf("request monitor: %w", err) } + if drop { + cc.ReleaseMessage(req) + return nil + } cc.inactivityMonitor.Notify() if cc.handleSpecialMessages(req) { return nil diff --git a/udp/client/conn_test.go b/udp/client/conn_test.go index 71b1ba81..cbc0dbbe 100644 --- a/udp/client/conn_test.go +++ b/udp/client/conn_test.go @@ -834,7 +834,7 @@ func TestConnPing(t *testing.T) { require.NoError(t, err) } -func TestConnRequestMonitor(t *testing.T) { +func TestConnRequestMonitorCloseConnection(t *testing.T) { l, err := coapNet.NewListenUDP("udp", "") require.NoError(t, err) defer func() { @@ -860,11 +860,11 @@ func TestConnRequestMonitor(t *testing.T) { testEOFError := errors.New("test error") s := udp.NewServer( options.WithMux(m), - options.WithRequestMonitor(func(c *client.Conn, req *pool.Message) error { + options.WithRequestMonitor(func(c *client.Conn, req *pool.Message) (bool, error) { if req.Code() == codes.DELETE { - return testEOFError + return false, testEOFError } - return nil + return false, nil }), options.WithErrors(func(err error) { t.Log(err) @@ -921,3 +921,81 @@ func TestConnRequestMonitor(t *testing.T) { require.Error(t, err) <-reqMonitorErr } + +func TestConnRequestMonitorDropRequest(t *testing.T) { + l, err := coapNet.NewListenUDP("udp", "") + require.NoError(t, err) + defer func() { + errC := l.Close() + require.NoError(t, errC) + }() + var wg sync.WaitGroup + defer wg.Wait() + + m := mux.NewRouter() + + // The response counts up with every get + // so we can check if the handler is only called once per message ID + err = m.Handle("/test", mux.HandlerFunc(func(w mux.ResponseWriter, r *mux.Message) { + errH := w.SetResponse(codes.Content, message.TextPlain, nil) + require.NoError(t, errH) + })) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + s := udp.NewServer( + options.WithMux(m), + options.WithRequestMonitor(func(c *client.Conn, req *pool.Message) (bool, error) { + if req.Code() == codes.DELETE { + t.Log("drop request") + return true, nil + } + return false, nil + }), + options.WithErrors(func(err error) { + require.NoError(t, err) + }), + options.WithOnNewConn(func(c *client.Conn) { + t.Log("new conn") + c.AddOnClose(func() { + t.Log("close conn") + }) + })) + defer s.Stop() + wg.Add(1) + go func() { + defer wg.Done() + errS := s.Serve(l) + require.NoError(t, errS) + }() + + cc, err := udp.Dial(l.LocalAddr().String(), + options.WithErrors(func(err error) { + t.Log(err) + }), + ) + require.NoError(t, err) + defer func() { + errC := cc.Close() + require.NoError(t, errC) + }() + + // Setup done - Run Tests + // Same request, executed twice (needs the same token) + getReq, err := cc.NewGetRequest(ctx, "/test") + getReq.SetMessageID(1) + + require.NoError(t, err) + got, err := cc.Do(getReq) + require.NoError(t, err) + require.Equal(t, codes.Content.String(), got.Code().String()) + + // New request but with DELETE code to trigger EOF error from request monitor + deleteReq, err := cc.NewDeleteRequest(ctx, "/test") + require.NoError(t, err) + deleteReq.SetMessageID(2) + _, err = cc.Do(deleteReq) + require.Error(t, err) + require.True(t, errors.Is(err, context.DeadlineExceeded)) +} diff --git a/udp/server/config.go b/udp/server/config.go index 627f25b7..9a556867 100644 --- a/udp/server/config.go +++ b/udp/server/config.go @@ -37,8 +37,8 @@ var DefaultConfig = func() Config { OnNewConn: func(cc *udpClient.Conn) { // do nothing by default }, - RequestMonitor: func(cc *udpClient.Conn, req *pool.Message) error { - return nil + RequestMonitor: func(cc *udpClient.Conn, req *pool.Message) (bool, error) { + return false, nil }, TransmissionNStart: 1, TransmissionAcknowledgeTimeout: time.Second * 2,