Skip to content

Commit

Permalink
Close on protocol violation (#12)
Browse files Browse the repository at this point in the history
* Close connection on protocol violation
* Fix connection state transition and close sequence
* Initialize serve routine in Connect
* Close on invalid SUBACK
  • Loading branch information
at-wat authored Dec 22, 2019
1 parent 01a267e commit 1e7b1ab
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 11 deletions.
48 changes: 48 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,51 @@ func TestConnect(t *testing.T) {
}
cli.Close()
}

func TestProtocolViolation(t *testing.T) {
ca, cb := net.Pipe()
cli := &BaseClient{Transport: cb}

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
go func() {
if err := cli.Connect(ctx, "cli"); err != nil {
t.Fatalf("Unexpected error: %v", err)
}
}()

errCh := make(chan error, 10)
cli.ConnState = func(s ConnState, err error) {
if s == StateClosed {
errCh <- err
}
}

b := make([]byte, 100)
if _, err := ca.Read(b); err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// Send CONNACK.
if _, err := ca.Write([]byte{
0x20, 0x02, 0x00, 0x00,
}); err != nil {
t.Fatalf("Unexpected error: %v", err)
}

// Send SUBSCRIBE from broker.
if _, err := ca.Write([]byte{
0x80, 0x01, 0x00,
}); err != nil {
t.Fatalf("Unexpected error: %v", err)
}

select {
case err := <-errCh:
if err != ErrInvalidPacket {
t.Errorf("Expected error against invalid packet: %v, got: %v", ErrInvalidPacket, err)
}
case <-ctx.Done():
t.Error("Timeout")
}
}
9 changes: 0 additions & 9 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,6 @@ func (d *DialOptions) dial(urlStr string) (*BaseClient, error) {
default:
return nil, ErrUnsupportedProtocol
}
c.connStateUpdate(StateIdle)
c.connClosed = make(chan struct{})

go func() {
err := c.serve()
c.mu.Lock()
c.err = err
c.mu.Unlock()
}()
return c, nil
}

Expand Down
14 changes: 14 additions & 0 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package mqtt
import (
"context"
"errors"
"io"
)

type protocolLevel byte
Expand Down Expand Up @@ -35,7 +36,20 @@ func (c *BaseClient) Connect(ctx context.Context, clientID string, opts ...Conne
}
}
c.sig = &signaller{}
c.connClosed = make(chan struct{})

go func() {
err := c.serve()
if err != io.EOF && err != io.ErrUnexpectedEOF {
if errConn := c.Transport.Close(); errConn != nil && err == nil {
err = errConn
}
}
c.mu.Lock()
c.err = err
c.mu.Unlock()
c.connStateUpdate(StateClosed)
}()
payload := packString(clientID)

var flag byte
Expand Down
1 change: 0 additions & 1 deletion mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ type ConnState int
// ConnState values.
const (
StateNew ConnState = iota // initial state
StateIdle // transport layer is opened
StateActive // connected to the broker
StateClosed // connection is unexpectedly closed
StateDisconnected // connection is expectedly closed
Expand Down
8 changes: 7 additions & 1 deletion serve.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package mqtt

import (
"errors"
"io"
)

// ErrInvalidPacket means that an invalid message is arrived from the broker.
var ErrInvalidPacket = errors.New("invalid packet")

func (c *BaseClient) serve() error {
defer func() {
c.connStateUpdate(StateClosed)
close(c.connClosed)
}()
r := c.Transport
Expand Down Expand Up @@ -118,6 +121,9 @@ func (c *BaseClient) serve() error {
case c.sig.PingResp() <- pingResp:
default:
}
default:
// must close connection if the client encounted protocol violation.
return ErrInvalidPacket
}
}
}
1 change: 1 addition & 0 deletions subscribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ func (c *BaseClient) Subscribe(ctx context.Context, subs ...Subscription) error
return ctx.Err()
case subAck := <-chSubAck:
if len(subAck.Codes) != len(subs) {
c.Transport.Close()
return ErrInvalidSubAck
}
for i := 0; i < len(subAck.Codes); i++ {
Expand Down

0 comments on commit 1e7b1ab

Please sign in to comment.