diff --git a/message/codes/codes.go b/message/codes/codes.go index fd3568bc..5878fe87 100644 --- a/message/codes/codes.go +++ b/message/codes/codes.go @@ -4,6 +4,8 @@ import ( "errors" "fmt" "strconv" + + "github.com/plgd-dev/go-coap/v3/pkg/math" ) // A Code is an unsigned 16-bit coap code as defined in the coap spec. @@ -94,14 +96,15 @@ var strToCode = map[string]Code{ } func getMaxCodeLen() int { - // max uint32 as string binary representation: "0b" + 32 digits - max := 34 + // maxLen uint32 as string binary representation: "0b" + 32 digits + maxLen := 34 for k := range strToCode { - if len(k) > max { - max = len(k) + kLen := len(k) + if kLen > maxLen { + maxLen = kLen } } - return max + return maxLen } func init() { @@ -128,8 +131,7 @@ func (c *Code) UnmarshalJSON(b []byte) error { if ci >= _maxCode { return fmt.Errorf("invalid code: %q", ci) } - - *c = Code(ci) + *c = math.CastTo[Code](ci) return nil } diff --git a/message/encodeDecodeUint32.go b/message/encodeDecodeUint32.go index 51355c39..8f1f2416 100644 --- a/message/encodeDecodeUint32.go +++ b/message/encodeDecodeUint32.go @@ -2,6 +2,8 @@ package message import ( "encoding/binary" + + "github.com/plgd-dev/go-coap/v3/pkg/math" ) func EncodeUint32(buf []byte, value uint32) (int, error) { @@ -18,7 +20,7 @@ func EncodeUint32(buf []byte, value uint32) (int, error) { if len(buf) < 2 { return 2, ErrTooSmall } - binary.BigEndian.PutUint16(buf, uint16(value)) + binary.BigEndian.PutUint16(buf, math.CastTo[uint16](value)) return 2, nil case value <= max3ByteNumber: if len(buf) < 3 { diff --git a/message/getmid.go b/message/getmid.go index e049381e..60fa1f6f 100644 --- a/message/getmid.go +++ b/message/getmid.go @@ -7,6 +7,7 @@ import ( "sync/atomic" "time" + pkgMath "github.com/plgd-dev/go-coap/v3/pkg/math" pkgRand "github.com/plgd-dev/go-coap/v3/pkg/rand" ) @@ -16,7 +17,7 @@ var msgID = uint32(RandMID()) // GetMID generates a message id for UDP. (0 <= mid <= 65535) func GetMID() int32 { - return int32(uint16(atomic.AddUint32(&msgID, 1))) + return int32(pkgMath.CastTo[uint16](atomic.AddUint32(&msgID, 1))) } func RandMID() int32 { @@ -24,9 +25,9 @@ func RandMID() int32 { _, err := rand.Read(b) if err != nil { // fallback to cryptographically insecure pseudo-random generator - return int32(uint16(weakRng.Uint32() >> 16)) + return int32(pkgMath.CastTo[uint16](weakRng.Uint32() >> 16)) } - return int32(uint16(binary.BigEndian.Uint32(b))) + return int32(pkgMath.CastTo[uint16](binary.BigEndian.Uint32(b))) } // ValidateMID validates a message id for UDP. (0 <= mid <= 65535) diff --git a/message/option.go b/message/option.go index d8b8dba8..f62eb454 100644 --- a/message/option.go +++ b/message/option.go @@ -4,6 +4,9 @@ import ( "encoding/binary" "errors" "strconv" + + "github.com/plgd-dev/go-coap/v3/pkg/math" + "golang.org/x/exp/constraints" ) const ( @@ -236,6 +239,18 @@ func ToMediaType(v string) (MediaType, error) { return 0, errors.New("not found") } +func MediaTypeFromNumber[T constraints.Integer](v T) (MediaType, error) { + mt, err := math.SafeCastTo[MediaType](v) + if err != nil { + return MediaType(0), err + } + _, ok := mediaTypeToString[mt] + if !ok { + return MediaType(0), errors.New("invalid value") + } + return mt, nil +} + func extendOpt(opt int) (int, int) { ext := 0 if opt >= ExtendOptionByteAddend { @@ -269,7 +284,7 @@ func marshalOptionHeaderExt(buf []byte, opt, ext int) (int, error) { return 1, ErrTooSmall case ExtendOptionWordCode: if len(buf) > 1 { - binary.BigEndian.PutUint16(buf, uint16(ext)) + binary.BigEndian.PutUint16(buf, math.CastTo[uint16](ext)) return 2, nil } return 2, ErrTooSmall @@ -435,7 +450,8 @@ func (o *Option) Unmarshal(data []byte, optionDefs map[OptionID]OptionDef, optio // Skip unrecognized options (RFC7252 section 5.4.1) return len(data), nil } - if uint32(len(data)) < def.MinLen || uint32(len(data)) > def.MaxLen { + dataLen := math.CastTo[uint32](len(data)) + if dataLen < def.MinLen || dataLen > def.MaxLen { // Skip options with illegal value length (RFC7252 section 5.4.3) return len(data), nil } diff --git a/message/options.go b/message/options.go index 945c291c..02a182a9 100644 --- a/message/options.go +++ b/message/options.go @@ -2,7 +2,10 @@ package message import ( "errors" + "fmt" "strings" + + "github.com/plgd-dev/go-coap/v3/pkg/math" ) // Options Container of COAP Options, It must be always sort'ed after modification. @@ -207,7 +210,10 @@ func (options Options) GetUint32(id OptionID) (uint32, error) { // ContentFormat gets the content format of body. func (options Options) ContentFormat() (MediaType, error) { v, err := options.GetUint32(ContentFormat) - return MediaType(v), err + if err != nil { + return MediaType(0), err + } + return MediaTypeFromNumber(v) } // GetString gets the string value of the first option with the given ID. @@ -353,7 +359,10 @@ func (options Options) SetAccept(buf []byte, contentFormat MediaType) (Options, // Accept gets accept option. func (options Options) Accept() (MediaType, error) { v, err := options.GetUint32(Accept) - return MediaType(v), err + if err != nil { + return MediaType(0), err + } + return MediaTypeFromNumber(v) } // Find returns range of type options. First number is index and second number is index of next option type. @@ -576,7 +585,10 @@ func (options *Options) Unmarshal(data []byte, optionDefs map[OptionID]OptionDef } option := Option{} - oid := OptionID(prev + delta) + oid, err := math.SafeCastTo[OptionID](prev + delta) + if err != nil { + return -1, fmt.Errorf("%w: %w", ErrOptionNotFound, err) + } proc, err = option.Unmarshal(data[:length], optionDefs, oid) if err != nil { return -1, err diff --git a/message/pool/message.go b/message/pool/message.go index 59fbe9fe..5191e76e 100644 --- a/message/pool/message.go +++ b/message/pool/message.go @@ -373,7 +373,10 @@ func (r *Message) AddOptionUint32(opt message.OptionID, value uint32) { func (r *Message) ContentFormat() (message.MediaType, error) { v, err := r.GetOptionUint32(message.ContentFormat) - return message.MediaType(v), err + if err != nil { + return message.MediaType(0), err + } + return message.MediaTypeFromNumber(v) } func (r *Message) HasOption(id message.OptionID) bool { @@ -400,7 +403,10 @@ func (r *Message) SetAccept(contentFormat message.MediaType) { // Accept get's accept option. func (r *Message) Accept() (message.MediaType, error) { v, err := r.GetOptionUint32(message.Accept) - return message.MediaType(v), err + if err != nil { + return message.MediaType(0), err + } + return message.MediaTypeFromNumber(v) } func (r *Message) BodySize() (int64, error) { diff --git a/net/blockwise/blockwise.go b/net/blockwise/blockwise.go index e199d6cc..72f7958a 100644 --- a/net/blockwise/blockwise.go +++ b/net/blockwise/blockwise.go @@ -14,6 +14,7 @@ import ( "github.com/plgd-dev/go-coap/v3/message/pool" "github.com/plgd-dev/go-coap/v3/net/responsewriter" "github.com/plgd-dev/go-coap/v3/pkg/cache" + "github.com/plgd-dev/go-coap/v3/pkg/math" "golang.org/x/sync/semaphore" ) @@ -98,13 +99,13 @@ func EncodeBlockOption(szx SZX, blockNumber int64, moreBlocksFollowing bool) (ui if blockNumber > maxBlockNumber { return 0, ErrBlockNumberExceedLimit } - blockVal := uint32(blockNumber << 4) + blockVal := math.CastTo[uint32](blockNumber << 4) m := uint32(0) if moreBlocksFollowing { m = 1 } blockVal += m << 3 - blockVal += uint32(szx) + blockVal += math.CastTo[uint32](szx) return blockVal, nil } @@ -238,7 +239,11 @@ func (b *BlockWise[C]) Do(r *pool.Message, maxSzx SZX, maxMessageSize uint32, do } req := b.cloneMessage(r) defer b.cc.ReleaseMessage(req) - req.SetOptionUint32(message.Size1, uint32(payloadSize)) + payloadSizeUint32, err := math.SafeCastTo[uint32](payloadSize) + if err != nil { + return nil, fmt.Errorf("cannot set payload size: %w", err) + } + req.SetOptionUint32(message.Size1, payloadSizeUint32) block, err := EncodeBlockOption(maxSzx, 0, true) if err != nil { return nil, fmt.Errorf("cannot encode block option(%v, %v, %v) to bw request: %w", maxSzx, 0, true, err) @@ -436,9 +441,7 @@ func (b *BlockWise[C]) createSendingMessage(sendingMessage *pool.Message, maxSZX b.cc.ReleaseMessage(sendMessage) return nil, false, payloadSizeError(err) } - if szx > maxSZX { - szx = maxSZX - } + szx = getSzx(szx, maxSZX) newBufLen := bufferSize(szx, maxMessageSize) off := num * szx.Size() if blockType == message.Block1 { @@ -471,13 +474,19 @@ func (b *BlockWise[C]) createSendingMessage(sendingMessage *pool.Message, maxSZX return nil, false, fmt.Errorf("cannot read response: %w", err) } + payloadSizeUint32, err := math.SafeCastTo[uint32](payloadSize) + if err != nil { + b.cc.ReleaseMessage(sendMessage) + return nil, false, fmt.Errorf("cannot set payload size: %w", err) + } + sendMessage.SetOptionUint32(sizeType, payloadSizeUint32) + buf = buf[:readed] sendMessage.SetBody(bytes.NewReader(buf)) more = true if offSeek+int64(readed) == payloadSize { more = false } - sendMessage.SetOptionUint32(sizeType, uint32(payloadSize)) num = (offSeek) / szx.Size() block, err = EncodeBlockOption(szx, num, more) if err != nil { diff --git a/pkg/math/cast.go b/pkg/math/cast.go new file mode 100644 index 00000000..a30a844a --- /dev/null +++ b/pkg/math/cast.go @@ -0,0 +1,55 @@ +package math + +import ( + "fmt" + "log" + "unsafe" + + "golang.org/x/exp/constraints" +) + +func Max[T constraints.Integer]() T { + size := unsafe.Sizeof(T(0)) + switch any(T(0)).(type) { + case int, int8, int16, int32, int64: + return T(1<<(size*8-1) - 1) // 2^(n-1) - 1 for signed integers + case uint, uint8, uint16, uint32, uint64: + return T(1<<(size*8) - 1) // 2^n - 1 for unsigned integers + default: + panic("unsupported type") + } +} + +func Min[T constraints.Integer]() T { + size := unsafe.Sizeof(T(0)) + switch any(T(0)).(type) { + case int, int8, int16, int32, int64: + return T(int64(-1) << (size*8 - 1)) // -2^(n-1) + case uint, uint8, uint16, uint32, uint64: + return T(0) + default: + panic("unsupported type") + } +} + +func SafeCastTo[T, F constraints.Integer](from F) (T, error) { + if from > 0 && uint64(Max[T]()) < uint64(from) { + return T(0), fmt.Errorf("value(%v) exceeds the maximum value for type(%v)", from, Max[T]()) + } + if from < 0 && int64(Min[T]()) > int64(from) { + return T(0), fmt.Errorf("value(%v) exceeds the minimum value for type(%v)", from, Min[T]()) + } + return T(from), nil +} + +func CastTo[T, F constraints.Integer](from F) T { + return T(from) +} + +func MustSafeCastTo[T, F constraints.Integer](from F) T { + to, err := SafeCastTo[T](from) + if err != nil { + log.Panicf("value (%v) out of bounds for type %T", from, T(0)) + } + return to +} diff --git a/pkg/math/cast_test.go b/pkg/math/cast_test.go new file mode 100644 index 00000000..b64687d2 --- /dev/null +++ b/pkg/math/cast_test.go @@ -0,0 +1,69 @@ +package math_test + +import ( + "math" + "testing" + + pkgMath "github.com/plgd-dev/go-coap/v3/pkg/math" + "github.com/stretchr/testify/require" +) + +func TestCastToUint8(t *testing.T) { + // uint8 + _, err := pkgMath.SafeCastTo[uint8](uint8(0)) + require.NoError(t, err) + _, err = pkgMath.SafeCastTo[uint8](math.MaxUint8) + require.NoError(t, err) + // int8 + _, err = pkgMath.SafeCastTo[uint8](math.MinInt8) + require.Error(t, err) + _, err = pkgMath.SafeCastTo[uint8](int8(0)) + require.NoError(t, err) + _, err = pkgMath.SafeCastTo[uint8](math.MaxInt8) + require.NoError(t, err) + // uint64 + _, err = pkgMath.SafeCastTo[uint8](uint64(0)) + require.NoError(t, err) + _, err = pkgMath.SafeCastTo[uint8](uint64(math.MaxUint8)) + require.NoError(t, err) + _, err = pkgMath.SafeCastTo[uint8](uint64(math.MaxUint64)) + require.Error(t, err) + // int64 + _, err = pkgMath.SafeCastTo[uint8](math.MaxInt64) + require.Error(t, err) + _, err = pkgMath.SafeCastTo[uint8](int64(0)) + require.NoError(t, err) + _, err = pkgMath.SafeCastTo[uint8](math.MaxInt64) + require.Error(t, err) + _, err = pkgMath.SafeCastTo[uint8](int64(math.MaxUint8)) + require.NoError(t, err) +} + +func TestCastToInt8(t *testing.T) { + // uint8 + _, err := pkgMath.SafeCastTo[int8](uint8(0)) + require.NoError(t, err) + _, err = pkgMath.SafeCastTo[int8](math.MaxUint8) + require.Error(t, err) + // int8 + _, err = pkgMath.SafeCastTo[int8](math.MinInt8) + require.NoError(t, err) + _, err = pkgMath.SafeCastTo[int8](math.MaxInt8) + require.NoError(t, err) + // uint64 + _, err = pkgMath.SafeCastTo[int8](uint64(0)) + require.NoError(t, err) + _, err = pkgMath.SafeCastTo[int8](uint64(math.MaxInt8)) + require.NoError(t, err) + _, err = pkgMath.SafeCastTo[int8](uint64(math.MaxUint64)) + require.Error(t, err) + // int64 + _, err = pkgMath.SafeCastTo[int8](math.MaxInt64) + require.Error(t, err) + _, err = pkgMath.SafeCastTo[int8](int64(0)) + require.NoError(t, err) + _, err = pkgMath.SafeCastTo[int8](math.MaxInt64) + require.Error(t, err) + _, err = pkgMath.SafeCastTo[int8](int64(math.MaxInt8)) + require.NoError(t, err) +} diff --git a/tcp/client/session.go b/tcp/client/session.go index 66c3d2c9..f1180e4d 100644 --- a/tcp/client/session.go +++ b/tcp/client/session.go @@ -14,6 +14,7 @@ import ( "github.com/plgd-dev/go-coap/v3/message/pool" coapNet "github.com/plgd-dev/go-coap/v3/net" "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" + "github.com/plgd-dev/go-coap/v3/pkg/math" "github.com/plgd-dev/go-coap/v3/tcp/coder" "go.uber.org/atomic" ) @@ -151,11 +152,11 @@ func seekBufferToNextMessage(buffer *bytes.Buffer, msgSize int) *bytes.Buffer { trimmed := 0 for trimmed != msgSize { b := make([]byte, 4096) - max := 4096 - if msgSize-trimmed < max { - max = msgSize - trimmed + toRead := 4096 + if msgSize-trimmed < toRead { + toRead = msgSize - trimmed } - v, _ := buffer.Read(b[:max]) + v, _ := buffer.Read(b[:toRead]) trimmed += v } return buffer @@ -171,7 +172,7 @@ func (s *Session) processBuffer(buffer *bytes.Buffer, cc *Conn) error { if header.MessageLength > s.maxMessageSize { return fmt.Errorf("max message size(%v) was exceeded %v", s.maxMessageSize, header.MessageLength) } - if uint32(buffer.Len()) < header.MessageLength { + if math.CastTo[uint32](buffer.Len()) < header.MessageLength { return nil } req := s.messagePool.AcquireMessage(s.Context()) diff --git a/tcp/coder/coder.go b/tcp/coder/coder.go index 9979370c..cf87dd91 100644 --- a/tcp/coder/coder.go +++ b/tcp/coder/coder.go @@ -6,6 +6,7 @@ import ( "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/pkg/math" ) var DefaultCoder = new(Coder) @@ -46,13 +47,13 @@ func getHeader(messageLength int) (uint8, []byte) { if messageLength < MessageLength15Base { extLen := messageLength - MessageLength14Base extLenBytes := make([]byte, 2) - binary.BigEndian.PutUint16(extLenBytes, uint16(extLen)) + binary.BigEndian.PutUint16(extLenBytes, math.CastTo[uint16](extLen)) return 14, extLenBytes } if messageLength < messageMaxLen { extLen := messageLength - MessageLength15Base extLenBytes := make([]byte, 4) - binary.BigEndian.PutUint32(extLenBytes, uint32(extLen)) + binary.BigEndian.PutUint32(extLenBytes, math.CastTo[uint32](extLen)) return 15, extLenBytes } return 0, nil @@ -192,7 +193,7 @@ func (c *Coder) DecodeHeader(data []byte, h *MessageHeader) (int, error) { opLen = MessageLength15Base + int(extLen) } - h.MessageLength = hdrOff + 1 + uint32(tkl) + uint32(opLen) + h.MessageLength = hdrOff + 1 + uint32(tkl) + math.CastTo[uint32](opLen) if len(data) < 1 { return -1, message.ErrShortRead } @@ -229,12 +230,12 @@ func (c *Coder) DecodeWithHeader(data []byte, header MessageHeader, m *message.M return -1, err } data = data[proc:] - processed += uint32(proc) + processed += math.CastTo[uint32](proc) if len(data) > 0 { m.Payload = data } - processed += uint32(len(data)) + processed += math.CastTo[uint32](len(data)) m.Code = header.Code m.Token = header.Token @@ -247,7 +248,7 @@ func (c *Coder) Decode(data []byte, m *message.Message) (int, error) { if err != nil { return -1, err } - if uint32(len(data)) < header.MessageLength { + if math.CastTo[uint32](len(data)) < header.MessageLength { return -1, message.ErrShortRead } return c.DecodeWithHeader(data[header.Length:], header, m) diff --git a/udp/client/conn.go b/udp/client/conn.go index 5fc0699f..492de01f 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -23,6 +23,7 @@ import ( "github.com/plgd-dev/go-coap/v3/pkg/cache" coapErrors "github.com/plgd-dev/go-coap/v3/pkg/errors" "github.com/plgd-dev/go-coap/v3/pkg/fn" + pkgMath "github.com/plgd-dev/go-coap/v3/pkg/math" coapSync "github.com/plgd-dev/go-coap/v3/pkg/sync" "github.com/plgd-dev/go-coap/v3/udp/coder" "go.uber.org/atomic" @@ -76,7 +77,7 @@ type midElement struct { handler HandlerFunc start time.Time deadline time.Time - retransmit atomic.Int32 + retransmit atomic.Uint32 private struct { sync.Mutex @@ -93,7 +94,7 @@ func (m *midElement) ReleaseMessage(cc *Conn) { } } -func (m *midElement) IsExpired(now time.Time, maxRetransmit int32) bool { +func (m *midElement) IsExpired(now time.Time, maxRetransmit uint32) bool { if !m.deadline.IsZero() && now.After(m.deadline) { // remove element if deadline is exceeded return true @@ -167,7 +168,7 @@ type Conn struct { type Transmission struct { nStart *atomic.Uint32 acknowledgeTimeout *atomic.Duration - maxRetransmit *atomic.Int32 + maxRetransmit *atomic.Uint32 } // SetTransmissionNStart changing the nStart value will only effect requests queued after the change. The requests waiting here already before the change will get unblocked when enough weight has been released. @@ -179,7 +180,7 @@ func (t *Transmission) SetTransmissionAcknowledgeTimeout(d time.Duration) { t.acknowledgeTimeout.Store(d) } -func (t *Transmission) SetTransmissionMaxRetransmit(d int32) { +func (t *Transmission) SetTransmissionMaxRetransmit(d uint32) { t.maxRetransmit.Store(d) } @@ -252,7 +253,7 @@ func NewConnWithOpts(session Session, cfg *Config, opts ...Option) *Conn { transmission: &Transmission{ atomic.NewUint32(cfg.TransmissionNStart), atomic.NewDuration(cfg.TransmissionAcknowledgeTimeout), - atomic.NewInt32(int32(cfg.TransmissionMaxRetransmit)), + atomic.NewUint32(cfg.TransmissionMaxRetransmit), }, blockwiseSZX: cfg.BlockwiseSZX, @@ -308,7 +309,7 @@ func (cc *Conn) GetMessageID() int32 { // previous one, the receiver may mistakenly treat the incoming message as a duplicate and discard it. // Hence, by incrementing the global counter, we can ensure unique message IDs and avoid such issues. message.GetMID() - return int32(uint16(cc.msgID.Inc())) + return int32(pkgMath.CastTo[uint16](cc.msgID.Inc())) } // Close closes connection without waiting for the end of the Run function. @@ -638,13 +639,13 @@ func (cc *Conn) getResponseFromCache(mid int32, resp *pool.Message) (bool, error return false, nil } -// checkMyMessageID compare client msgID against peer messageID and if it is near < 0xffff/4 then incrase msgID. +// checkMyMessageID compare client msgID against peer messageID and if it is near < 0xffff/4 then increase msgID. // When msgIDs met it can cause issue because cache can send message to which doesn't bellows to request. func (cc *Conn) checkMyMessageID(req *pool.Message) { if req.Type() == message.Confirmable { for { oldID := cc.msgID.Load() - if uint16(req.MessageID())-uint16(cc.msgID.Load()) >= 0xffff/4 { + if pkgMath.CastTo[uint16](req.MessageID())-pkgMath.CastTo[uint16](cc.msgID.Load()) >= 0xffff/4 { return } newID := oldID + 0xffff/2 @@ -835,7 +836,7 @@ func (cc *Conn) handleSpecialMessages(r *pool.Message) bool { } func (cc *Conn) Process(cm *coapNet.ControlMessage, datagram []byte) error { - if uint32(len(datagram)) > cc.session.MaxMessageSize() { + if pkgMath.CastTo[uint32](len(datagram)) > cc.session.MaxMessageSize() { return fmt.Errorf("max message size(%v) was exceeded %v", cc.session.MaxMessageSize(), len(datagram)) } req := cc.AcquireMessage(cc.Context()) @@ -877,7 +878,7 @@ func (cc *Conn) Done() <-chan struct{} { return cc.session.Done() } -func (cc *Conn) checkMidHandlerContainer(now time.Time, maxRetransmit int32, acknowledgeTimeout time.Duration, key int32, value *midElement) { +func (cc *Conn) checkMidHandlerContainer(now time.Time, maxRetransmit uint32, acknowledgeTimeout time.Duration, key int32, value *midElement) { if value.IsExpired(now, maxRetransmit) { cc.midHandlerContainer.Delete(key) value.ReleaseMessage(cc) @@ -914,7 +915,7 @@ func (cc *Conn) CheckExpirations(now time.Time) { acknowledgeTimeout := cc.transmission.acknowledgeTimeout.Load() x := struct { now time.Time - maxRetransmit int32 + maxRetransmit uint32 acknowledgeTimeout time.Duration cc *Conn }{ diff --git a/udp/client_test.go b/udp/client_test.go index 654677d4..54c1373d 100644 --- a/udp/client_test.go +++ b/udp/client_test.go @@ -169,11 +169,11 @@ func TestConnGet(t *testing.T) { } require.NoError(t, err) require.Equal(t, tt.wantCode, got.Code()) - assert.Greater(t, got.Sequence(), uint64(0)) + require.Positive(t, got.Sequence()) if tt.wantContentFormat != nil { ct, errC := got.ContentFormat() require.NoError(t, errC) - assert.Equal(t, *tt.wantContentFormat, ct) + require.Equal(t, *tt.wantContentFormat, ct) } if tt.wantPayload != nil { buf := bytes.NewBuffer(nil) diff --git a/udp/coder/coder.go b/udp/coder/coder.go index 4b1f0d35..9c23909a 100644 --- a/udp/coder/coder.go +++ b/udp/coder/coder.go @@ -7,6 +7,7 @@ import ( "github.com/plgd-dev/go-coap/v3/message" "github.com/plgd-dev/go-coap/v3/message/codes" + "github.com/plgd-dev/go-coap/v3/pkg/math" ) var DefaultCoder = new(Coder) @@ -60,7 +61,8 @@ func (c *Coder) Encode(m message.Message, buf []byte) (int, error) { } tmpbuf := []byte{0, 0} - binary.BigEndian.PutUint16(tmpbuf, uint16(m.MessageID)) + // safe: checked by message.ValidateMID above + binary.BigEndian.PutUint16(tmpbuf, math.CastTo[uint16](m.MessageID)) buf[0] = (1 << 6) | byte(m.Type)<<4 | byte(0xf&len(m.Token)) buf[1] = byte(m.Code)