Skip to content

Commit

Permalink
Add ReconnectClient (#18)
Browse files Browse the repository at this point in the history
* Add ReconnectClient
* Temporary implement incremental wait
* Implement auto resubscribe
* Add Client.Handle
  • Loading branch information
at-wat authored Dec 23, 2019
1 parent 58be7b1 commit 11b1b9d
Show file tree
Hide file tree
Showing 9 changed files with 284 additions and 36 deletions.
10 changes: 8 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@ import (
type BaseClient struct {
// Transport is an underlying connection. Typically net.Conn.
Transport io.ReadWriteCloser
// Handler of incoming messages.
Handler Handler
// ConnState is called if the connection state is changed.
ConnState func(ConnState, error)

handler Handler
sig *signaller
mu sync.RWMutex
connState ConnState
Expand All @@ -24,6 +23,13 @@ type BaseClient struct {
idLast uint32
}

// Handle registers the message handler.
func (c *BaseClient) Handle(handler Handler) {
c.mu.Lock()
defer c.mu.Unlock()
c.handler = handler
}

// WithUserNamePassword sets plain text auth information used in Connect.
func WithUserNamePassword(userName, password string) ConnectOption {
return func(o *ConnectOptions) error {
Expand Down
25 changes: 15 additions & 10 deletions client_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,17 @@ func ExampleClient() {
if err != nil {
panic(err)
}
baseCli.Handler = HandlerFunc(func(msg *Message) {
fmt.Printf("%s[%d]: %s", msg.Topic, int(msg.QoS), []byte(msg.Payload))
close(done)
})

// store as Client to make it easy to enable high level wrapper later
var cli Client = baseCli
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

cli.Handle(HandlerFunc(func(msg *Message) {
fmt.Printf("%s[%d]: %s", msg.Topic, int(msg.QoS), []byte(msg.Payload))
close(done)
}))

if _, err := cli.Connect(ctx, "TestClient", WithCleanSession(true)); err != nil {
panic(err)
}
Expand Down Expand Up @@ -123,9 +124,6 @@ func TestIntegration_PublishQoS2_SubscribeQoS2(t *testing.T) {
}

chReceived := make(chan *Message, 100)
cli.Handler = HandlerFunc(func(msg *Message) {
chReceived <- msg
})
cli.ConnState = func(s ConnState, err error) {
switch s {
case StateActive:
Expand All @@ -142,6 +140,10 @@ func TestIntegration_PublishQoS2_SubscribeQoS2(t *testing.T) {
t.Fatalf("Unexpected error: '%v'", err)
}

cli.Handle(HandlerFunc(func(msg *Message) {
chReceived <- msg
}))

if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS2}); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
Expand All @@ -155,6 +157,8 @@ func TestIntegration_PublishQoS2_SubscribeQoS2(t *testing.T) {
}

select {
case <-ctx.Done():
t.Fatalf("Unexpected error: '%v'", ctx.Err())
case msg, ok := <-chReceived:
if !ok {
t.Errorf("Connection closed unexpectedly")
Expand Down Expand Up @@ -238,9 +242,6 @@ func BenchmarkPublishSubscribe(b *testing.B) {
}

chReceived := make(chan *Message, 100)
cli.Handler = HandlerFunc(func(msg *Message) {
chReceived <- msg
})
cli.ConnState = func(s ConnState, err error) {
switch s {
case StateActive:
Expand All @@ -256,6 +257,10 @@ func BenchmarkPublishSubscribe(b *testing.B) {
b.Fatalf("Unexpected error: '%v'", err)
}

cli.Handle(HandlerFunc(func(msg *Message) {
chReceived <- msg
}))

if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS2}); err != nil {
b.Fatalf("Unexpected error: '%v'", err)
}
Expand Down
28 changes: 28 additions & 0 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ var ErrUnsupportedProtocol = errors.New("unsupported protocol")
// ErrClosedTransport means that the underlying connection is closed.
var ErrClosedTransport = errors.New("read/write on closed transport")

// URLDialer is a Dialer using URL string.
type URLDialer struct {
URL string
Options []DialOption
}

// Dialer is an interface to create connection.
type Dialer interface {
Dial() (ClientCloser, error)
}

// Dial creates connection using its values.
func (d *URLDialer) Dial() (ClientCloser, error) {
return Dial(d.URL, d.Options...)
}

// Dial creates MQTT client using URL string.
func Dial(urlStr string, opts ...DialOption) (*BaseClient, error) {
o := &DialOptions{
Expand Down Expand Up @@ -113,3 +129,15 @@ func (c *BaseClient) connStateUpdate(newState ConnState) {
func (c *BaseClient) Close() error {
return c.Transport.Close()
}

// Done is a channel to signal connection close.
func (c *BaseClient) Done() <-chan struct{} {
return c.connClosed
}

// Err returns connection error.
func (c *BaseClient) Err() error {
c.mu.Lock()
defer c.mu.Unlock()
return c.err
}
3 changes: 3 additions & 0 deletions mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,14 @@ type Client interface {
Subscribe(ctx context.Context, subs ...Subscription) error
Unsubscribe(ctx context.Context, subs ...string) error
Ping(ctx context.Context) error
Handle(Handler)
}

// Closer is the interface of connection closer.
type Closer interface {
Close() error
Done() <-chan struct{}
Err() error
}

// ClientCloser groups Client and Closer interface
Expand Down
72 changes: 72 additions & 0 deletions reconnclient.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package mqtt

import (
"context"
"sync"
"time"
)

type reconnectClient struct {
Client
}

// NewReconnectClient creates a MQTT client with re-connect/re-publish/re-subscribe features.
func NewReconnectClient(ctx context.Context, dialer Dialer, clientID string, opts ...ConnectOption) Client {
rc := &RetryClient{}
cli := &reconnectClient{
Client: rc,
}
done := make(chan struct{})
var doneOnce sync.Once
go func() {
clean := true
reconnWaitBase := 50 * time.Millisecond
reconnWaitMax := 10 * time.Second
reconnWait := reconnWaitBase
for {
if c, err := dialer.Dial(); err == nil {
optsCurr := append([]ConnectOption{}, opts...)
optsCurr = append(optsCurr, WithCleanSession(clean))
clean = false // Clean only first time.
reconnWait = reconnWaitBase // Reset reconnect wait.
rc.SetClient(ctx, c)

if present, err := rc.Connect(ctx, clientID, optsCurr...); err == nil {
doneOnce.Do(func() { close(done) })
if present {
rc.Resubscribe(ctx)
}
// Start keep alive.
go func() {
_ = KeepAlive(ctx, c, time.Second, time.Second)
}()
select {
case <-c.Done():
if err := c.Err(); err == nil {
// Disconnected as expected; don't restart.
return
}
case <-ctx.Done():
// User cancelled; don't restart.
return
}
}
}
select {
case <-time.After(reconnWait):
case <-ctx.Done():
// User cancelled; don't restart.
return
}
reconnWait *= 2
if reconnWait > reconnWaitMax {
reconnWait = reconnWaitMax
}
}
}()
select {
case <-done:
case <-ctx.Done():
}
return cli
}
58 changes: 58 additions & 0 deletions reconnclient_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// +build integration

package mqtt

import (
"context"
"crypto/tls"
"testing"
"time"
)

func TestIntegration_ReconnectClient(t *testing.T) {
for name, url := range urls {
t.Run(name, func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

chReceived := make(chan *Message, 100)
cli := NewReconnectClient(
ctx,
&URLDialer{
URL: url,
Options: []DialOption{
WithTLSConfig(&tls.Config{InsecureSkipVerify: true}),
},
},
"ReconnectClient",
)
cli.Handle(HandlerFunc(func(msg *Message) {
chReceived <- msg
}))

// Close underlying client.
cli.(*reconnectClient).Client.(*RetryClient).Client.(ClientCloser).Close()

if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS1}); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
if err := cli.Publish(ctx, &Message{
Topic: "test",
QoS: QoS1,
Retain: true,
Payload: []byte("message"),
}); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

time.Sleep(time.Second)

select {
case <-ctx.Done():
t.Fatalf("Unexpected error: '%v'", ctx.Err())
case <-chReceived:
}
})
}

}
Loading

0 comments on commit 11b1b9d

Please sign in to comment.