Skip to content

Commit

Permalink
Add KeepAlive and RetryClient (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
at-wat authored Dec 22, 2019
1 parent 8d4924d commit 2d271ca
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 4 deletions.
6 changes: 2 additions & 4 deletions connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ const (

// Connect to the broker.
func (c *BaseClient) Connect(ctx context.Context, clientID string, opts ...ConnectOption) (sessionPresent bool, err error) {
o := &ConnectOptions{
KeepAlive: 60,
}
o := &ConnectOptions{}
for _, opt := range opts {
if err := opt(o); err != nil {
return false, err
Expand All @@ -42,7 +40,7 @@ func (c *BaseClient) Connect(ctx context.Context, clientID string, opts ...Conne
go func() {
err := c.serve()
if err != io.EOF && err != io.ErrUnexpectedEOF {
if errConn := c.Transport.Close(); errConn != nil && err == nil {
if errConn := c.Close(); errConn != nil && err == nil {
err = errConn
}
}
Expand Down
47 changes: 47 additions & 0 deletions keepalive.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package mqtt

import (
"context"
"errors"
"time"
)

// ErrKeepAliveDisabled is returned if Runned on keep alive disabled connection.
var ErrKeepAliveDisabled = errors.New("keep alive disabled")

// ErrPingTimeout is returned on ping response timeout.
var ErrPingTimeout = errors.New("ping timeout")

// KeepAlive runs keep alive loop.
// It must be called after Connect and interval must be smaller than the value
// specified by WithKeepAlive option passed to Connect.
func KeepAlive(ctx context.Context, cli ClientCloser, interval, timeout time.Duration) error {
ticker := time.NewTicker(interval)
defer ticker.Stop()

for {
<-ticker.C

ctxTo, cancel := context.WithTimeout(ctx, timeout)
if err := cli.Ping(ctxTo); err != nil {
defer cancel()
// The client should close the connection if PINGRESP is not returned.
// MQTT 3.1.1 spec. 3.1.2.10
cli.Close()

select {
case <-ctx.Done():
// Parent context cancelled.
return ctx.Err()
default:
}
select {
case <-ctxTo.Done():
return ErrPingTimeout
default:
}
return err
}
cancel()
}
}
44 changes: 44 additions & 0 deletions keepalive_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// +build integration

package mqtt

import (
"context"
"crypto/tls"
"testing"
"time"
)

func TestIntegration_KeepAlive(t *testing.T) {
for name, url := range urls {
t.Run(name, func(t *testing.T) {
cli, err := Dial(url, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()

if _, err := cli.Connect(ctx, "Client1",
WithKeepAlive(1),
); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

// Without keepalive, broker should disconnect on t=1.5s.
if err := KeepAlive(
ctx, cli,
time.Second,
500*time.Millisecond,
); err != context.DeadlineExceeded {
t.Errorf("Expected error: '%v', got: '%v'", context.DeadlineExceeded, err)

if err := cli.Disconnect(ctx); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
}
})
}

}
11 changes: 11 additions & 0 deletions mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,17 @@ type Client interface {
Ping(ctx context.Context) error
}

// Closer is the interface of connection closer.
type Closer interface {
Close() error
}

// ClientCloser groups Client and Closer interface
type ClientCloser interface {
Client
Closer
}

// HandlerFunc type is an adapter to use functions as MQTT message handler.
type HandlerFunc func(*Message)

Expand Down
68 changes: 68 additions & 0 deletions retryclient.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package mqtt

import (
"context"
"sync"
)

// RetryClient queues unacknowledged messages and retry on reconnect.
type RetryClient struct {
Client
queue []*Message // unacknoledged messages
mu sync.Mutex
muMsg sync.Mutex
}

// Publish tries to publish the message and immediately return nil.
// If it is not acknowledged to be published, the message will be queued and
// retried on the next connection.
func (c *RetryClient) Publish(ctx context.Context, message *Message) error {
c.mu.Lock()
cli := c.Client
c.mu.Unlock()
go func() {
c.publish(ctx, cli, message)
}()
return nil
}

func (c *RetryClient) publish(ctx context.Context, cli Client, message *Message) {
if err := cli.Publish(ctx, message); err != nil {
select {
case <-ctx.Done():
// User cancelled; don't queue.
default:
if message.QoS > QoS0 {
copyMsg := *message

c.muMsg.Lock()
copyMsg.Dup = true
c.queue = append(c.queue, &copyMsg)
c.muMsg.Unlock()
}
}
}
}

// SetClient sets the new Client.
// If there are any queued messages, retry to publish them.
func (c *RetryClient) SetClient(ctx context.Context, cli Client) error {
c.mu.Lock()
defer c.mu.Unlock()
c.Client = cli

c.muMsg.Lock()
var oldQueue []*Message
copy(oldQueue, c.queue)
c.queue = nil
c.muMsg.Unlock()

// Retry publish.
go func() {
for _, msg := range oldQueue {
c.publish(ctx, cli, msg)
}
}()

return nil
}
44 changes: 44 additions & 0 deletions retryclient_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// +build integration

package mqtt

import (
"context"
"crypto/tls"
"testing"
"time"
)

func TestIntegration_RetryClient(t *testing.T) {
for name, url := range urls {
t.Run(name, func(t *testing.T) {
cliBase, err := Dial(url, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}))
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

var cli RetryClient
cli.SetClient(ctx, cliBase)

if _, err := cli.Connect(ctx, "Client1"); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

if err := cli.Publish(ctx, &Message{
Topic: "test",
QoS: QoS1,
Payload: []byte("message"),
}); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

if err := cli.Disconnect(ctx); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
})
}

}

0 comments on commit 2d271ca

Please sign in to comment.