diff --git a/client_test.go b/client_test.go index 77af8ef..5ec0837 100644 --- a/client_test.go +++ b/client_test.go @@ -46,3 +46,51 @@ func TestConnect(t *testing.T) { } cli.Close() } + +func TestProtocolViolation(t *testing.T) { + ca, cb := net.Pipe() + cli := &BaseClient{Transport: cb} + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + go func() { + if err := cli.Connect(ctx, "cli"); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + }() + + errCh := make(chan error, 10) + cli.ConnState = func(s ConnState, err error) { + if s == StateClosed { + errCh <- err + } + } + + b := make([]byte, 100) + if _, err := ca.Read(b); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Send CONNACK. + if _, err := ca.Write([]byte{ + 0x20, 0x02, 0x00, 0x00, + }); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Send SUBSCRIBE from broker. + if _, err := ca.Write([]byte{ + 0x80, 0x01, 0x00, + }); err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + select { + case err := <-errCh: + if err != ErrInvalidPacket { + t.Errorf("Expected error against invalid packet: %v, got: %v", ErrInvalidPacket, err) + } + case <-ctx.Done(): + t.Error("Timeout") + } +} diff --git a/conn.go b/conn.go index d424852..07c0fcd 100644 --- a/conn.go +++ b/conn.go @@ -91,15 +91,6 @@ func (d *DialOptions) dial(urlStr string) (*BaseClient, error) { default: return nil, ErrUnsupportedProtocol } - c.connStateUpdate(StateIdle) - c.connClosed = make(chan struct{}) - - go func() { - err := c.serve() - c.mu.Lock() - c.err = err - c.mu.Unlock() - }() return c, nil } diff --git a/connect.go b/connect.go index a9cbcb1..9bffc2e 100644 --- a/connect.go +++ b/connect.go @@ -3,6 +3,7 @@ package mqtt import ( "context" "errors" + "io" ) type protocolLevel byte @@ -35,7 +36,20 @@ func (c *BaseClient) Connect(ctx context.Context, clientID string, opts ...Conne } } c.sig = &signaller{} + c.connClosed = make(chan struct{}) + go func() { + err := c.serve() + if err != io.EOF && err != io.ErrUnexpectedEOF { + if errConn := c.Transport.Close(); errConn != nil && err == nil { + err = errConn + } + } + c.mu.Lock() + c.err = err + c.mu.Unlock() + c.connStateUpdate(StateClosed) + }() payload := packString(clientID) var flag byte diff --git a/mqtt.go b/mqtt.go index 91b3571..1a92ecd 100644 --- a/mqtt.go +++ b/mqtt.go @@ -72,7 +72,6 @@ type ConnState int // ConnState values. const ( StateNew ConnState = iota // initial state - StateIdle // transport layer is opened StateActive // connected to the broker StateClosed // connection is unexpectedly closed StateDisconnected // connection is expectedly closed diff --git a/serve.go b/serve.go index 833d66a..2c93a72 100644 --- a/serve.go +++ b/serve.go @@ -1,12 +1,15 @@ package mqtt import ( + "errors" "io" ) +// ErrInvalidPacket means that an invalid message is arrived from the broker. +var ErrInvalidPacket = errors.New("invalid packet") + func (c *BaseClient) serve() error { defer func() { - c.connStateUpdate(StateClosed) close(c.connClosed) }() r := c.Transport @@ -118,6 +121,9 @@ func (c *BaseClient) serve() error { case c.sig.PingResp() <- pingResp: default: } + default: + // must close connection if the client encounted protocol violation. + return ErrInvalidPacket } } } diff --git a/subscribe.go b/subscribe.go index cd72884..257ec14 100644 --- a/subscribe.go +++ b/subscribe.go @@ -60,6 +60,7 @@ func (c *BaseClient) Subscribe(ctx context.Context, subs ...Subscription) error return ctx.Err() case subAck := <-chSubAck: if len(subAck.Codes) != len(subs) { + c.Transport.Close() return ErrInvalidSubAck } for i := 0; i < len(subAck.Codes); i++ {