Skip to content

Commit

Permalink
allow to drop request
Browse files Browse the repository at this point in the history
  • Loading branch information
jkralik committed Nov 27, 2023
1 parent ce59bda commit 55a61d6
Show file tree
Hide file tree
Showing 12 changed files with 206 additions and 36 deletions.
4 changes: 2 additions & 2 deletions dtls/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ var DefaultConfig = func() Config {
}
return inactivity.New(timeout, onInactive)
},
RequestMonitor: func(cc *udpClient.Conn, req *pool.Message) error {
return nil
RequestMonitor: func(cc *udpClient.Conn, req *pool.Message) (bool, error) {
return false, nil
},
OnNewConn: func(cc *udpClient.Conn) {
// do nothing by default
Expand Down
5 changes: 4 additions & 1 deletion options/commonOptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,10 @@ func (o WithRequestMonitorOpt[F]) TCPServerApply(cfg *tcpServer.Config) {
}
}

// WithRequestMonitor ping handler
// WithRequestMonitor enables request monitoring for the connection.
// It is called for each CoAP message received from the peer before it is processed.
// If it returns an error, the connection is closed.
// If it returns true, the message is dropped.
func WithRequestMonitor[F WithRequestMonitorFunc](requestMonitor F) WithRequestMonitorOpt[F] {
return WithRequestMonitorOpt[F]{
f: requestMonitor,
Expand Down
4 changes: 2 additions & 2 deletions tcp/client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ var DefaultConfig = func() Config {
CreateInactivityMonitor: func() InactivityMonitor {
return inactivity.NewNilMonitor[*Conn]()
},
RequestMonitor: func(*Conn, *pool.Message) error {
return nil
RequestMonitor: func(*Conn, *pool.Message) (bool, error) {
return false, nil
},
Dialer: &net.Dialer{Timeout: time.Second * 3},
Net: "tcp",
Expand Down
11 changes: 7 additions & 4 deletions tcp/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type (
EventFunc = func()
GetMIDFunc = func() int32
CreateInactivityMonitorFunc = func() InactivityMonitor
RequestMonitorFunc = func(cc *Conn, req *pool.Message) error
RequestMonitorFunc = func(cc *Conn, req *pool.Message) (drop bool, err error)
)

type Notifier interface {
Expand Down Expand Up @@ -78,7 +78,10 @@ func WithInactivityMonitor(inactivityMonitor InactivityMonitor) Option {
}
}

// WithRequestMonitor enables request monitor for the connection.
// WithRequestMonitor enables request monitoring for the connection.
// It is called for each CoAP message received from the peer before it is processed.
// If it returns an error, the connection is closed.
// If it returns true, the message is dropped.
func WithRequestMonitor(requestMonitor RequestMonitorFunc) Option {
return func(opts *ConnOptions) {
opts.RequestMonitor = requestMonitor
Expand All @@ -104,8 +107,8 @@ func NewConnWithOpts(connection *coapNet.Conn, cfg *Config, opts ...Option) *Con
return nil
},
InactivityMonitor: inactivity.NewNilMonitor[*Conn](),
RequestMonitor: func(*Conn, *pool.Message) error {
return nil
RequestMonitor: func(*Conn, *pool.Message) (bool, error) {
return false, nil
},
}
for _, o := range opts {
Expand Down
11 changes: 8 additions & 3 deletions tcp/client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ func NewSession(
inactivityMonitor = inactivity.NewNilMonitor[*Conn]()
}
if requestMonitor == nil {
requestMonitor = func(*Conn, *pool.Message) error {
return nil
requestMonitor = func(*Conn, *pool.Message) (bool, error) {
return false, nil
}
}

Expand Down Expand Up @@ -183,10 +183,15 @@ func (s *Session) processBuffer(buffer *bytes.Buffer, cc *Conn) error {
buffer = seekBufferToNextMessage(buffer, read)
req.SetSequence(s.Sequence())

if err = s.requestMonitor(cc, req); err != nil {
drop, err := s.requestMonitor(cc, req)
if err != nil {
s.messagePool.ReleaseMessage(req)
return fmt.Errorf("request monitor: %w", err)
}
if drop {
s.messagePool.ReleaseMessage(req)
continue
}
s.inactivityMonitor.Notify()
cc.pushToReceivedMessageQueue(req)
}
Expand Down
86 changes: 82 additions & 4 deletions tcp/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ func TestClientKeepAliveMonitor(t *testing.T) {
require.True(t, inactivityDetected.Load())
}

func TestConnRequestMonitor(t *testing.T) {
func TestConnRequestMonitorCloseConnection(t *testing.T) {
l, err := coapNet.NewTCPListener("tcp", "")
require.NoError(t, err)
defer func() {
Expand All @@ -695,11 +695,11 @@ func TestConnRequestMonitor(t *testing.T) {
testEOFError := errors.New("test error")
s := NewServer(
options.WithMux(m),
options.WithRequestMonitor(func(c *client.Conn, req *pool.Message) error {
options.WithRequestMonitor(func(c *client.Conn, req *pool.Message) (bool, error) {
if req.Code() == codes.DELETE {
return testEOFError
return false, testEOFError
}
return nil
return false, nil
}),
options.WithErrors(func(err error) {
t.Log(err)
Expand Down Expand Up @@ -754,3 +754,81 @@ func TestConnRequestMonitor(t *testing.T) {
require.Error(t, err)
<-reqMonitorErr
}

func TestConnRequestMonitorDropRequest(t *testing.T) {
l, err := coapNet.NewTCPListener("tcp", "")
require.NoError(t, err)
defer func() {
errC := l.Close()
require.NoError(t, errC)
}()
var wg sync.WaitGroup
defer wg.Wait()

m := mux.NewRouter()

// The response counts up with every get
// so we can check if the handler is only called once per message ID
err = m.Handle("/test", mux.HandlerFunc(func(w mux.ResponseWriter, r *mux.Message) {
errH := w.SetResponse(codes.Content, message.TextPlain, nil)
require.NoError(t, errH)
}))
require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
defer cancel()
s := NewServer(
options.WithMux(m),
options.WithRequestMonitor(func(c *client.Conn, req *pool.Message) (bool, error) {
if req.Code() == codes.DELETE {
t.Log("drop request")
return true, nil
}
return false, nil
}),
options.WithErrors(func(err error) {
require.NoError(t, err)
}),
options.WithOnNewConn(func(c *client.Conn) {
t.Log("new conn")
c.AddOnClose(func() {
t.Log("close conn")
})
}))
defer s.Stop()
wg.Add(1)
go func() {
defer wg.Done()
errS := s.Serve(l)
require.NoError(t, errS)
}()

cc, err := Dial(l.Addr().String(),
options.WithErrors(func(err error) {
t.Log(err)
}),
)
require.NoError(t, err)
defer func() {
errC := cc.Close()
require.NoError(t, errC)
}()

// Setup done - Run Tests
// Same request, executed twice (needs the same token)
getReq, err := cc.NewGetRequest(ctx, "/test")
getReq.SetMessageID(1)

require.NoError(t, err)
got, err := cc.Do(getReq)
require.NoError(t, err)
require.Equal(t, codes.Content.String(), got.Code().String())

// New request but with DELETE code to trigger EOF error from request monitor
deleteReq, err := cc.NewDeleteRequest(ctx, "/test")
require.NoError(t, err)
deleteReq.SetMessageID(2)
_, err = cc.Do(deleteReq)
require.Error(t, err)
require.True(t, errors.Is(err, context.DeadlineExceeded))
}
4 changes: 2 additions & 2 deletions tcp/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ var DefaultConfig = func() Config {
OnNewConn: func(cc *client.Conn) {
// do nothing by default
},
RequestMonitor: func(cc *client.Conn, req *pool.Message) error {
return nil
RequestMonitor: func(cc *client.Conn, req *pool.Message) (bool, error) {
return false, nil
},
ConnectionCacheSize: 2 * 1024,
}
Expand Down
5 changes: 0 additions & 5 deletions udp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ func Client(conn *net.UDPConn, opts ...Option) *client.Conn {
if cfg.MessagePool == nil {
cfg.MessagePool = pool.New(0, 0)
}
if cfg.RequestMonitor == nil {
cfg.RequestMonitor = func(*client.Conn, *pool.Message) error {
return nil
}
}

errorsFunc := cfg.Errors
cfg.Errors = func(err error) {
Expand Down
4 changes: 2 additions & 2 deletions udp/client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ var DefaultConfig = func() Config {
CreateInactivityMonitor: func() InactivityMonitor {
return inactivity.NewNilMonitor[*Conn]()
},
RequestMonitor: func(*Conn, *pool.Message) error {
return nil
RequestMonitor: func(*Conn, *pool.Message) (bool, error) {
return false, nil
},
Dialer: &net.Dialer{Timeout: time.Second * 3},
Net: "udp",
Expand Down
18 changes: 13 additions & 5 deletions udp/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type (
EventFunc = func()
GetMIDFunc = func() int32
CreateInactivityMonitorFunc = func() InactivityMonitor
RequestMonitorFunc = func(cc *Conn, req *pool.Message) error
RequestMonitorFunc = func(cc *Conn, req *pool.Message) (drop bool, err error)
)

type InactivityMonitor interface {
Expand Down Expand Up @@ -209,7 +209,10 @@ func WithInactivityMonitor(inactivityMonitor InactivityMonitor) Option {
}
}

// WithRequestMonitor enables request monitor for the connection.
// WithRequestMonitor enables request monitoring for the connection.
// It is called for each CoAP message received from the peer before it is processed.
// If it returns an error, the connection is closed.
// If it returns true, the message is dropped.
func WithRequestMonitor(requestMonitor RequestMonitorFunc) Option {
return func(opts *ConnOptions) {
opts.requestMonitor = requestMonitor
Expand Down Expand Up @@ -237,8 +240,8 @@ func NewConnWithOpts(session Session, cfg *Config, opts ...Option) *Conn {
return nil
},
inactivityMonitor: inactivity.NewNilMonitor[*Conn](),
requestMonitor: func(*Conn, *pool.Message) error {
return nil
requestMonitor: func(*Conn, *pool.Message) (bool, error) {
return false, nil
},
}
for _, o := range opts {
Expand Down Expand Up @@ -844,10 +847,15 @@ func (cc *Conn) Process(cm *coapNet.ControlMessage, datagram []byte) error {
req.SetControlMessage(cm)
req.SetSequence(cc.Sequence())
cc.checkMyMessageID(req)
if err = cc.requestMonitor(cc, req); err != nil {
drop, err := cc.requestMonitor(cc, req)
if err != nil {
cc.ReleaseMessage(req)
return fmt.Errorf("request monitor: %w", err)
}
if drop {
cc.ReleaseMessage(req)
return nil
}
cc.inactivityMonitor.Notify()
if cc.handleSpecialMessages(req) {
return nil
Expand Down
86 changes: 82 additions & 4 deletions udp/client/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ func TestConnPing(t *testing.T) {
require.NoError(t, err)
}

func TestConnRequestMonitor(t *testing.T) {
func TestConnRequestMonitorCloseConnection(t *testing.T) {
l, err := coapNet.NewListenUDP("udp", "")
require.NoError(t, err)
defer func() {
Expand All @@ -860,11 +860,11 @@ func TestConnRequestMonitor(t *testing.T) {
testEOFError := errors.New("test error")
s := udp.NewServer(
options.WithMux(m),
options.WithRequestMonitor(func(c *client.Conn, req *pool.Message) error {
options.WithRequestMonitor(func(c *client.Conn, req *pool.Message) (bool, error) {
if req.Code() == codes.DELETE {
return testEOFError
return false, testEOFError
}
return nil
return false, nil
}),
options.WithErrors(func(err error) {
t.Log(err)
Expand Down Expand Up @@ -921,3 +921,81 @@ func TestConnRequestMonitor(t *testing.T) {
require.Error(t, err)
<-reqMonitorErr
}

func TestConnRequestMonitorDropRequest(t *testing.T) {
l, err := coapNet.NewListenUDP("udp", "")
require.NoError(t, err)
defer func() {
errC := l.Close()
require.NoError(t, errC)
}()
var wg sync.WaitGroup
defer wg.Wait()

m := mux.NewRouter()

// The response counts up with every get
// so we can check if the handler is only called once per message ID
err = m.Handle("/test", mux.HandlerFunc(func(w mux.ResponseWriter, r *mux.Message) {
errH := w.SetResponse(codes.Content, message.TextPlain, nil)
require.NoError(t, errH)
}))
require.NoError(t, err)

ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
defer cancel()
s := udp.NewServer(
options.WithMux(m),
options.WithRequestMonitor(func(c *client.Conn, req *pool.Message) (bool, error) {
if req.Code() == codes.DELETE {
t.Log("drop request")
return true, nil
}
return false, nil
}),
options.WithErrors(func(err error) {
require.NoError(t, err)
}),
options.WithOnNewConn(func(c *client.Conn) {
t.Log("new conn")
c.AddOnClose(func() {
t.Log("close conn")
})
}))
defer s.Stop()
wg.Add(1)
go func() {
defer wg.Done()
errS := s.Serve(l)
require.NoError(t, errS)
}()

cc, err := udp.Dial(l.LocalAddr().String(),
options.WithErrors(func(err error) {
t.Log(err)
}),
)
require.NoError(t, err)
defer func() {
errC := cc.Close()
require.NoError(t, errC)
}()

// Setup done - Run Tests
// Same request, executed twice (needs the same token)
getReq, err := cc.NewGetRequest(ctx, "/test")
getReq.SetMessageID(1)

require.NoError(t, err)
got, err := cc.Do(getReq)
require.NoError(t, err)
require.Equal(t, codes.Content.String(), got.Code().String())

// New request but with DELETE code to trigger EOF error from request monitor
deleteReq, err := cc.NewDeleteRequest(ctx, "/test")
require.NoError(t, err)
deleteReq.SetMessageID(2)
_, err = cc.Do(deleteReq)
require.Error(t, err)
require.True(t, errors.Is(err, context.DeadlineExceeded))
}
4 changes: 2 additions & 2 deletions udp/server/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ var DefaultConfig = func() Config {
OnNewConn: func(cc *udpClient.Conn) {
// do nothing by default
},
RequestMonitor: func(cc *udpClient.Conn, req *pool.Message) error {
return nil
RequestMonitor: func(cc *udpClient.Conn, req *pool.Message) (bool, error) {
return false, nil
},
TransmissionNStart: 1,
TransmissionAcknowledgeTimeout: time.Second * 2,
Expand Down

0 comments on commit 55a61d6

Please sign in to comment.