Skip to content

Commit

Permalink
Validate packet according to the spec (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
at-wat authored Dec 23, 2019
1 parent a0b1cf9 commit 4929a7a
Show file tree
Hide file tree
Showing 12 changed files with 142 additions and 34 deletions.
7 changes: 5 additions & 2 deletions connack.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,12 @@ type pktConnAck struct {
Code ConnectionReturnCode
}

func (p *pktConnAck) parse(flag byte, contents []byte) *pktConnAck {
func (p *pktConnAck) parse(flag byte, contents []byte) (*pktConnAck, error) {
if flag != 0 {
return nil, ErrInvalidPacket
}
return &pktConnAck{
SessionPresent: (contents[0]&0x01 != 0),
Code: ConnectionReturnCode(contents[1]),
}
}, nil
}
28 changes: 26 additions & 2 deletions packet.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
package mqtt

import (
"errors"
"fmt"
)

// ErrInvalidRune means that the string has a rune not allowed in MQTT.
var ErrInvalidRune = errors.New("invalid rune in UTF-8 string")

// ErrInvalidPacket means that an invalid message is arrived from the broker.
var ErrInvalidPacket = errors.New("invalid packet")

// ErrInvalidPacketLength means that an invalid length of the message is arrived.
var ErrInvalidPacketLength = errors.New("invalid packet length")

type packetType byte

const (
Expand Down Expand Up @@ -126,7 +136,21 @@ func unpackUint16(b []byte) (int, uint16) {
return 2, uint16(b[0])<<8 | uint16(b[1])
}

func unpackString(b []byte) (int, string) {
func unpackString(b []byte) (int, string, error) {
if len(b) < 2 {
return 0, "", ErrInvalidPacketLength
}
nHeader, n := unpackUint16(b)
return int(n) + nHeader, string(b[nHeader : int(n)+nHeader])
if int(n)+nHeader > len(b) {
return 0, "", ErrInvalidPacketLength
}

// Validate UTF-8 runes according to MQTT-1.5.3-1 and MQTT-1.5.3-2.
rs := []rune(string(b[nHeader : int(n)+nHeader]))
for _, r := range rs {
if r == 0x0000 || (0xD800 <= r && r <= 0xDFFF) {
return 0, "", ErrInvalidRune
}
}
return int(n) + nHeader, string(rs), nil
}
7 changes: 5 additions & 2 deletions pingresp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package mqtt
type pktPingResp struct {
}

func (p *pktPingResp) parse(flag byte, contents []byte) *pktPingResp {
return p
func (p *pktPingResp) parse(flag byte, contents []byte) (*pktPingResp, error) {
if flag != 0 {
return nil, ErrInvalidPacket
}
return p, nil
}
7 changes: 5 additions & 2 deletions puback.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ type pktPubAck struct {
ID uint16
}

func (p *pktPubAck) parse(flag byte, contents []byte) *pktPubAck {
func (p *pktPubAck) parse(flag byte, contents []byte) (*pktPubAck, error) {
if flag != 0 {
return nil, ErrInvalidPacket
}
_, p.ID = unpackUint16(contents)
return p
return p, nil
}
7 changes: 5 additions & 2 deletions pubcomp.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ type pktPubComp struct {
ID uint16
}

func (p *pktPubComp) parse(flag byte, contents []byte) *pktPubComp {
func (p *pktPubComp) parse(flag byte, contents []byte) (*pktPubComp, error) {
if flag != 0 {
return nil, ErrInvalidPacket
}
_, p.ID = unpackUint16(contents)
return p
return p, nil
}
12 changes: 9 additions & 3 deletions publish.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ type pktPublish struct {
Message
}

func (p *pktPublish) parse(flag byte, contents []byte) *pktPublish {
func (p *pktPublish) parse(flag byte, contents []byte) (*pktPublish, error) {
p.Message.Dup = (publishFlag(flag) & publishFlagDup) != 0
p.Message.Retain = (publishFlag(flag) & publishFlagRetain) != 0
switch publishFlag(flag) & publishFlagQoSMask {
Expand All @@ -122,12 +122,18 @@ func (p *pktPublish) parse(flag byte, contents []byte) *pktPublish {
p.Message.QoS = QoS1
case publishFlagQoS2:
p.Message.QoS = QoS2
default:
return nil, ErrInvalidPacket
}

var n, nID int
n, p.Message.Topic = unpackString(contents)
var err error
n, p.Message.Topic, err = unpackString(contents)
if err != nil {
return nil, err
}
nID, p.Message.ID = unpackUint16(contents[n:])
p.Message.Payload = contents[n+nID:]

return p
return p, nil
}
29 changes: 29 additions & 0 deletions publish_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package mqtt

import (
"testing"
)

func TestPublish_ParseError(t *testing.T) {
cases := []struct {
flag byte
contents []byte
err error
}{
{0x00, []byte{0x00, 0x01, 0x61, 0x00, 0x00}, nil},
{0x00, []byte{0x00, 0x01, 0x00, 0x00, 0x00}, ErrInvalidRune},
{0x06, []byte{0x00, 0x01, 0x61, 0x00, 0x00}, ErrInvalidPacket},
{0x00, []byte{0x00, 0x01}, ErrInvalidPacketLength},
{0x00, []byte{0x00}, ErrInvalidPacketLength},
}

for _, c := range cases {
_, err := (&pktPublish{}).parse(c.flag, c.contents)
if err != c.err {
t.Errorf("Parsing packet with flag=%x, contents=%v expected error: %v, got: %v",
c.flag, c.contents,
c.err, err,
)
}
}
}
7 changes: 5 additions & 2 deletions pubrec.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ type pktPubRec struct {
ID uint16
}

func (p *pktPubRec) parse(flag byte, contents []byte) *pktPubRec {
func (p *pktPubRec) parse(flag byte, contents []byte) (*pktPubRec, error) {
if flag != 0 {
return nil, ErrInvalidPacket
}
_, p.ID = unpackUint16(contents)
return p
return p, nil
}
7 changes: 5 additions & 2 deletions pubrel.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ type pktPubRel struct {
ID uint16
}

func (p *pktPubRel) parse(flag byte, contents []byte) *pktPubRel {
func (p *pktPubRel) parse(flag byte, contents []byte) (*pktPubRel, error) {
if flag != 0x02 {
return nil, ErrInvalidPacket
}
_, p.ID = unpackUint16(contents)
return p
return p, nil
}
51 changes: 38 additions & 13 deletions serve.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
package mqtt

import (
"errors"
"io"
)

// ErrInvalidPacket means that an invalid message is arrived from the broker.
var ErrInvalidPacket = errors.New("invalid packet")

func (c *BaseClient) serve() error {
defer func() {
close(c.connClosed)
Expand Down Expand Up @@ -40,12 +36,20 @@ func (c *BaseClient) serve() error {

switch pktType {
case packetConnAck:
connAck, err := (&pktConnAck{}).parse(pktFlag, contents)
if err != nil {
// Client must close connection if packet is invalid.
return err
}
select {
case c.sig.ConnAck() <- (&pktConnAck{}).parse(pktFlag, contents):
case c.sig.ConnAck() <- connAck:
default:
}
case packetPublish:
publish := (&pktPublish{}).parse(pktFlag, contents)
publish, err := (&pktPublish{}).parse(pktFlag, contents)
if err != nil {
return err
}
switch publish.Message.QoS {
case QoS0:
c.mu.RLock()
Expand Down Expand Up @@ -80,23 +84,32 @@ func (c *BaseClient) serve() error {
subBuffer[publish.Message.ID] = &publish.Message
}
case packetPubAck:
pubAck := (&pktPubAck{}).parse(pktFlag, contents)
pubAck, err := (&pktPubAck{}).parse(pktFlag, contents)
if err != nil {
return err
}
if ch, ok := c.sig.PubAck(pubAck.ID); ok {
select {
case ch <- pubAck:
default:
}
}
case packetPubRec:
pubRec := (&pktPubRec{}).parse(pktFlag, contents)
pubRec, err := (&pktPubRec{}).parse(pktFlag, contents)
if err != nil {
return err
}
if ch, ok := c.sig.PubRec(pubRec.ID); ok {
select {
case ch <- pubRec:
default:
}
}
case packetPubRel:
pubRel := (&pktPubRel{}).parse(pktFlag, contents)
pubRel, err := (&pktPubRel{}).parse(pktFlag, contents)
if err != nil {
return err
}
if msg, ok := subBuffer[pubRel.ID]; ok {
// Ownership of the message is now transferred to the receiver.
c.mu.RLock()
Expand All @@ -116,31 +129,43 @@ func (c *BaseClient) serve() error {
return err
}
case packetPubComp:
pubComp := (&pktPubComp{}).parse(pktFlag, contents)
pubComp, err := (&pktPubComp{}).parse(pktFlag, contents)
if err != nil {
return err
}
if ch, ok := c.sig.PubComp(pubComp.ID); ok {
select {
case ch <- pubComp:
default:
}
}
case packetSubAck:
subAck := (&pktSubAck{}).parse(pktFlag, contents)
subAck, err := (&pktSubAck{}).parse(pktFlag, contents)
if err != nil {
return err
}
if ch, ok := c.sig.SubAck(subAck.ID); ok {
select {
case ch <- subAck:
default:
}
}
case packetUnsubAck:
unsubAck := (&pktUnsubAck{}).parse(pktFlag, contents)
unsubAck, err := (&pktUnsubAck{}).parse(pktFlag, contents)
if err != nil {
return err
}
if ch, ok := c.sig.UnsubAck(unsubAck.ID); ok {
select {
case ch <- unsubAck:
default:
}
}
case packetPingResp:
pingResp := (&pktPingResp{}).parse(pktFlag, contents)
pingResp, err := (&pktPingResp{}).parse(pktFlag, contents)
if err != nil {
return err
}
select {
case c.sig.PingResp() <- pingResp:
default:
Expand Down
7 changes: 5 additions & 2 deletions suback.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@ type pktSubAck struct {
Codes []subscribeFlag
}

func (p *pktSubAck) parse(flag byte, contents []byte) *pktSubAck {
func (p *pktSubAck) parse(flag byte, contents []byte) (*pktSubAck, error) {
if flag != 0 {
return nil, ErrInvalidPacket
}
p.ID = uint16(contents[0])<<8 | uint16(contents[1])
for _, c := range contents[2:] {
p.Codes = append(p.Codes, subscribeFlag(c))
}
return p
return p, nil
}
7 changes: 5 additions & 2 deletions unsuback.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ type pktUnsubAck struct {
ID uint16
}

func (p *pktUnsubAck) parse(flag byte, contents []byte) *pktUnsubAck {
func (p *pktUnsubAck) parse(flag byte, contents []byte) (*pktUnsubAck, error) {
if flag != 0 {
return nil, ErrInvalidPacket
}
p.ID = uint16(contents[0])<<8 | uint16(contents[1])
return p
return p, nil
}

0 comments on commit 4929a7a

Please sign in to comment.