diff --git a/udp/client/conn.go b/udp/client/conn.go index 54d215f3..ac6a7038 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -6,6 +6,7 @@ import ( "fmt" "math" "net" + "strconv" "sync" "time" @@ -127,6 +128,62 @@ func (m *midElement) GetMessage(cc *Conn) (*pool.Message, bool, error) { return msg, true, nil } +// MessageCache is a cache of CoAP messages. +type MessageCache interface { + Load(key string, msg *pool.Message) (bool, error) + Store(key string, msg *pool.Message) error + CheckExpirations(time.Time) +} + +// messageCache is a CoAP message cache backed by an in-memory cache. +type messageCache struct { + c *cache.Cache[string, []byte] +} + +// newMessageCache constructs a new CoAP message cache. +func newMessageCache() *messageCache { + return &messageCache{ + c: cache.NewCache[string, []byte](), + } +} + +// Load loads a message from the cache if one exists with key. +func (m *messageCache) Load(key string, msg *pool.Message) (bool, error) { + cachedResp := m.c.Load(key) + if cachedResp == nil { + return false, nil + } + if rawMsg := cachedResp.Data(); len(rawMsg) > 0 { + _, err := msg.UnmarshalWithDecoder(coder.DefaultCoder, rawMsg) + if err != nil { + return false, err + } + return true, nil + } + return false, nil +} + +// Store stores a message in the cache. +func (m *messageCache) Store(key string, msg *pool.Message) error { + if _, err := msg.GetOptionUint32(message.Block2); err == nil { + // Skip caching blockwise response. + return nil + } + marshaledResp, err := msg.MarshalWithEncoder(coder.DefaultCoder) + if err != nil { + return err + } + cacheMsg := make([]byte, len(marshaledResp)) + copy(cacheMsg, marshaledResp) + m.c.LoadOrStore(key, cache.NewElement(cacheMsg, time.Now().Add(ExchangeLifetime), nil)) + return nil +} + +// CheckExpirations checks the cache for any expirations. +func (m *messageCache) CheckExpirations(now time.Time) { + m.c.CheckExpirations(now) +} + // Conn represents a virtual connection to a conceptual endpoint, to perform COAPs commands. type Conn struct { // This field needs to be the first in the struct to ensure proper word alignment on 32-bit platforms. @@ -145,7 +202,7 @@ type Conn struct { processReceivedMessage config.ProcessReceivedMessageFunc[*Conn] errors ErrorFunc - responseMsgCache *cache.Cache[string, []byte] + responseMsgCache MessageCache msgIDMutex *MutexMap tokenHandlerContainer *coapSync.Map[uint64, HandlerFunc] @@ -192,6 +249,7 @@ type ConnOptions struct { createBlockWise func(cc *Conn) *blockwise.BlockWise[*Conn] inactivityMonitor InactivityMonitor requestMonitor RequestMonitorFunc + responseMsgCache MessageCache } type Option = func(opts *ConnOptions) @@ -220,6 +278,23 @@ func WithRequestMonitor(requestMonitor RequestMonitorFunc) Option { } } +// WithResponseMessageCache sets the cache used for response messages. All +// response messages are submitted to the cache, but it is up to the cache +// implementation to determine which messages are stored and for how long. +// Caching responses enables sending the same Acknowledgment for retransmitted +// confirmable messages within an EXCHANGE_LIFETIME. It may be desirable to +// relax this behavior in some scenarios. +// https://datatracker.ietf.org/doc/html/rfc7252#section-4.5 +// The default response message cache stores all responses with an expiration of +// 247 seconds, which is EXCHANGE_LIFETIME when using default CoAP transmission +// parameters. +// https://datatracker.ietf.org/doc/html/rfc7252#section-4.8.2 +func WithResponseMessageCache(cache MessageCache) Option { + return func(opts *ConnOptions) { + opts.responseMsgCache = cache + } +} + func NewConnWithOpts(session Session, cfg *Config, opts ...Option) *Conn { if cfg.Errors == nil { cfg.Errors = func(error) { @@ -248,6 +323,10 @@ func NewConnWithOpts(session Session, cfg *Config, opts ...Option) *Conn { for _, o := range opts { o(&cfgOpts) } + // Only construct cache if one was not set via options. + if cfgOpts.responseMsgCache == nil { + cfgOpts.responseMsgCache = newMessageCache() + } cc := Conn{ session: session, transmission: &Transmission{ @@ -262,7 +341,7 @@ func NewConnWithOpts(session Session, cfg *Config, opts ...Option) *Conn { processReceivedMessage: cfg.ProcessReceivedMessage, errors: cfg.Errors, msgIDMutex: NewMutexMap(), - responseMsgCache: cache.NewCache[string, []byte](), + responseMsgCache: cfgOpts.responseMsgCache, inactivityMonitor: cfgOpts.inactivityMonitor, requestMonitor: cfgOpts.requestMonitor, messagePool: cfg.MessagePool, @@ -609,34 +688,14 @@ func (cc *Conn) Sequence() uint64 { return cc.sequence.Add(1) } -func (cc *Conn) responseMsgCacheID(msgID int32) string { - return fmt.Sprintf("resp-%v-%d", cc.RemoteAddr(), msgID) +// getResponseFromCache gets a message from the response message cache. +func (cc *Conn) getResponseFromCache(mid int32, resp *pool.Message) (bool, error) { + return cc.responseMsgCache.Load(strconv.Itoa(int(mid)), resp) } +// addResponseToCache adds a message to the response message cache. func (cc *Conn) addResponseToCache(resp *pool.Message) error { - marshaledResp, err := resp.MarshalWithEncoder(coder.DefaultCoder) - if err != nil { - return err - } - cacheMsg := make([]byte, len(marshaledResp)) - copy(cacheMsg, marshaledResp) - cc.responseMsgCache.LoadOrStore(cc.responseMsgCacheID(resp.MessageID()), cache.NewElement(cacheMsg, time.Now().Add(ExchangeLifetime), nil)) - return nil -} - -func (cc *Conn) getResponseFromCache(mid int32, resp *pool.Message) (bool, error) { - cachedResp := cc.responseMsgCache.Load(cc.responseMsgCacheID(mid)) - if cachedResp == nil { - return false, nil - } - if rawMsg := cachedResp.Data(); len(rawMsg) > 0 { - _, err := resp.UnmarshalWithDecoder(coder.DefaultCoder, rawMsg) - if err != nil { - return false, err - } - return true, nil - } - return false, nil + return cc.responseMsgCache.Store(strconv.Itoa(int(resp.MessageID())), resp) } // checkMyMessageID compare client msgID against peer messageID and if it is near < 0xffff/4 then increase msgID.