diff --git a/client.go b/client.go index 0d90a25..712cabc 100644 --- a/client.go +++ b/client.go @@ -10,11 +10,10 @@ import ( type BaseClient struct { // Transport is an underlying connection. Typically net.Conn. Transport io.ReadWriteCloser - // Handler of incoming messages. - Handler Handler // ConnState is called if the connection state is changed. ConnState func(ConnState, error) + handler Handler sig *signaller mu sync.RWMutex connState ConnState @@ -24,6 +23,13 @@ type BaseClient struct { idLast uint32 } +// Handle registers the message handler. +func (c *BaseClient) Handle(handler Handler) { + c.mu.Lock() + defer c.mu.Unlock() + c.handler = handler +} + // WithUserNamePassword sets plain text auth information used in Connect. func WithUserNamePassword(userName, password string) ConnectOption { return func(o *ConnectOptions) error { diff --git a/client_integration_test.go b/client_integration_test.go index 8a5fa62..28a35e6 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -27,16 +27,17 @@ func ExampleClient() { if err != nil { panic(err) } - baseCli.Handler = HandlerFunc(func(msg *Message) { - fmt.Printf("%s[%d]: %s", msg.Topic, int(msg.QoS), []byte(msg.Payload)) - close(done) - }) // store as Client to make it easy to enable high level wrapper later var cli Client = baseCli ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + cli.Handle(HandlerFunc(func(msg *Message) { + fmt.Printf("%s[%d]: %s", msg.Topic, int(msg.QoS), []byte(msg.Payload)) + close(done) + })) + if _, err := cli.Connect(ctx, "TestClient", WithCleanSession(true)); err != nil { panic(err) } @@ -123,9 +124,6 @@ func TestIntegration_PublishQoS2_SubscribeQoS2(t *testing.T) { } chReceived := make(chan *Message, 100) - cli.Handler = HandlerFunc(func(msg *Message) { - chReceived <- msg - }) cli.ConnState = func(s ConnState, err error) { switch s { case StateActive: @@ -142,6 +140,10 @@ func TestIntegration_PublishQoS2_SubscribeQoS2(t *testing.T) { t.Fatalf("Unexpected error: '%v'", err) } + cli.Handle(HandlerFunc(func(msg *Message) { + chReceived <- msg + })) + if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS2}); err != nil { t.Fatalf("Unexpected error: '%v'", err) } @@ -155,6 +157,8 @@ func TestIntegration_PublishQoS2_SubscribeQoS2(t *testing.T) { } select { + case <-ctx.Done(): + t.Fatalf("Unexpected error: '%v'", ctx.Err()) case msg, ok := <-chReceived: if !ok { t.Errorf("Connection closed unexpectedly") @@ -238,9 +242,6 @@ func BenchmarkPublishSubscribe(b *testing.B) { } chReceived := make(chan *Message, 100) - cli.Handler = HandlerFunc(func(msg *Message) { - chReceived <- msg - }) cli.ConnState = func(s ConnState, err error) { switch s { case StateActive: @@ -256,6 +257,10 @@ func BenchmarkPublishSubscribe(b *testing.B) { b.Fatalf("Unexpected error: '%v'", err) } + cli.Handle(HandlerFunc(func(msg *Message) { + chReceived <- msg + })) + if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS2}); err != nil { b.Fatalf("Unexpected error: '%v'", err) } diff --git a/conn.go b/conn.go index 07c0fcd..00a6e12 100644 --- a/conn.go +++ b/conn.go @@ -16,6 +16,22 @@ var ErrUnsupportedProtocol = errors.New("unsupported protocol") // ErrClosedTransport means that the underlying connection is closed. var ErrClosedTransport = errors.New("read/write on closed transport") +// URLDialer is a Dialer using URL string. +type URLDialer struct { + URL string + Options []DialOption +} + +// Dialer is an interface to create connection. +type Dialer interface { + Dial() (ClientCloser, error) +} + +// Dial creates connection using its values. +func (d *URLDialer) Dial() (ClientCloser, error) { + return Dial(d.URL, d.Options...) +} + // Dial creates MQTT client using URL string. func Dial(urlStr string, opts ...DialOption) (*BaseClient, error) { o := &DialOptions{ @@ -113,3 +129,15 @@ func (c *BaseClient) connStateUpdate(newState ConnState) { func (c *BaseClient) Close() error { return c.Transport.Close() } + +// Done is a channel to signal connection close. +func (c *BaseClient) Done() <-chan struct{} { + return c.connClosed +} + +// Err returns connection error. +func (c *BaseClient) Err() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.err +} diff --git a/mqtt.go b/mqtt.go index 63d7e1c..80f2b09 100644 --- a/mqtt.go +++ b/mqtt.go @@ -51,11 +51,14 @@ type Client interface { Subscribe(ctx context.Context, subs ...Subscription) error Unsubscribe(ctx context.Context, subs ...string) error Ping(ctx context.Context) error + Handle(Handler) } // Closer is the interface of connection closer. type Closer interface { Close() error + Done() <-chan struct{} + Err() error } // ClientCloser groups Client and Closer interface diff --git a/reconnclient.go b/reconnclient.go new file mode 100644 index 0000000..5527809 --- /dev/null +++ b/reconnclient.go @@ -0,0 +1,72 @@ +package mqtt + +import ( + "context" + "sync" + "time" +) + +type reconnectClient struct { + Client +} + +// NewReconnectClient creates a MQTT client with re-connect/re-publish/re-subscribe features. +func NewReconnectClient(ctx context.Context, dialer Dialer, clientID string, opts ...ConnectOption) Client { + rc := &RetryClient{} + cli := &reconnectClient{ + Client: rc, + } + done := make(chan struct{}) + var doneOnce sync.Once + go func() { + clean := true + reconnWaitBase := 50 * time.Millisecond + reconnWaitMax := 10 * time.Second + reconnWait := reconnWaitBase + for { + if c, err := dialer.Dial(); err == nil { + optsCurr := append([]ConnectOption{}, opts...) + optsCurr = append(optsCurr, WithCleanSession(clean)) + clean = false // Clean only first time. + reconnWait = reconnWaitBase // Reset reconnect wait. + rc.SetClient(ctx, c) + + if present, err := rc.Connect(ctx, clientID, optsCurr...); err == nil { + doneOnce.Do(func() { close(done) }) + if present { + rc.Resubscribe(ctx) + } + // Start keep alive. + go func() { + _ = KeepAlive(ctx, c, time.Second, time.Second) + }() + select { + case <-c.Done(): + if err := c.Err(); err == nil { + // Disconnected as expected; don't restart. + return + } + case <-ctx.Done(): + // User cancelled; don't restart. + return + } + } + } + select { + case <-time.After(reconnWait): + case <-ctx.Done(): + // User cancelled; don't restart. + return + } + reconnWait *= 2 + if reconnWait > reconnWaitMax { + reconnWait = reconnWaitMax + } + } + }() + select { + case <-done: + case <-ctx.Done(): + } + return cli +} diff --git a/reconnclient_integration_test.go b/reconnclient_integration_test.go new file mode 100644 index 0000000..b8ea763 --- /dev/null +++ b/reconnclient_integration_test.go @@ -0,0 +1,58 @@ +// +build integration + +package mqtt + +import ( + "context" + "crypto/tls" + "testing" + "time" +) + +func TestIntegration_ReconnectClient(t *testing.T) { + for name, url := range urls { + t.Run(name, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + chReceived := make(chan *Message, 100) + cli := NewReconnectClient( + ctx, + &URLDialer{ + URL: url, + Options: []DialOption{ + WithTLSConfig(&tls.Config{InsecureSkipVerify: true}), + }, + }, + "ReconnectClient", + ) + cli.Handle(HandlerFunc(func(msg *Message) { + chReceived <- msg + })) + + // Close underlying client. + cli.(*reconnectClient).Client.(*RetryClient).Client.(ClientCloser).Close() + + if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS1}); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + if err := cli.Publish(ctx, &Message{ + Topic: "test", + QoS: QoS1, + Retain: true, + Payload: []byte("message"), + }); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + time.Sleep(time.Second) + + select { + case <-ctx.Done(): + t.Fatalf("Unexpected error: '%v'", ctx.Err()) + case <-chReceived: + } + }) + } + +} diff --git a/retryclient.go b/retryclient.go index f05b0a2..9ff7464 100644 --- a/retryclient.go +++ b/retryclient.go @@ -8,24 +8,49 @@ import ( // RetryClient queues unacknowledged messages and retry on reconnect. type RetryClient struct { Client - queue []*Message // unacknoledged messages - mu sync.Mutex - muMsg sync.Mutex + + pubQueue []*Message // unacknoledged messages + subQueue []Subscription // unacknoledged subscriptions + subEstablished []Subscription // acknoledged subscriptions + mu sync.Mutex + muQueue sync.Mutex + handler Handler +} + +// Handle registers the message handler. +func (c *RetryClient) Handle(handler Handler) { + c.mu.Lock() + defer c.mu.Unlock() + c.handler = handler + c.Client.Handle(handler) } // Publish tries to publish the message and immediately return nil. // If it is not acknowledged to be published, the message will be queued and // retried on the next connection. func (c *RetryClient) Publish(ctx context.Context, message *Message) error { - c.mu.Lock() - cli := c.Client - c.mu.Unlock() go func() { + c.mu.Lock() + cli := c.Client + c.mu.Unlock() c.publish(ctx, cli, message) }() return nil } +// Subscribe tries to subscribe the topic and immediately return nil. +// If it is not acknowledged to be subscribed, the request will be queued and +// retried on the next connection. +func (c *RetryClient) Subscribe(ctx context.Context, subs ...Subscription) error { + go func() { + c.mu.Lock() + cli := c.Client + c.mu.Unlock() + c.subscribe(ctx, cli, subs...) + }() + return nil +} + func (c *RetryClient) publish(ctx context.Context, cli Client, message *Message) { if err := cli.Publish(ctx, message); err != nil { select { @@ -35,34 +60,76 @@ func (c *RetryClient) publish(ctx context.Context, cli Client, message *Message) if message.QoS > QoS0 { copyMsg := *message - c.muMsg.Lock() + c.muQueue.Lock() copyMsg.Dup = true - c.queue = append(c.queue, ©Msg) - c.muMsg.Unlock() + c.pubQueue = append(c.pubQueue, ©Msg) + c.muQueue.Unlock() } } } } +func (c *RetryClient) subscribe(ctx context.Context, cli Client, subs ...Subscription) { + if err := cli.Subscribe(ctx, subs...); err != nil { + select { + case <-ctx.Done(): + // User cancelled; don't queue. + default: + c.muQueue.Lock() + c.subQueue = append(c.subQueue, subs...) + c.muQueue.Unlock() + } + } else { + c.muQueue.Lock() + c.subEstablished = append(c.subEstablished, subs...) + c.muQueue.Unlock() + } +} + // SetClient sets the new Client. // If there are any queued messages, retry to publish them. -func (c *RetryClient) SetClient(ctx context.Context, cli Client) error { +func (c *RetryClient) SetClient(ctx context.Context, cli Client) { c.mu.Lock() defer c.mu.Unlock() c.Client = cli +} + +// Connect to the broker. +func (c *RetryClient) Connect(ctx context.Context, clientID string, opts ...ConnectOption) (sessionPresent bool, err error) { + c.muQueue.Lock() + oldPubQueue := append([]*Message{}, c.pubQueue...) + oldSubQueue := append([]Subscription{}, c.subQueue...) + c.pubQueue = nil + c.subQueue = nil + c.muQueue.Unlock() - c.muMsg.Lock() - var oldQueue []*Message - copy(oldQueue, c.queue) - c.queue = nil - c.muMsg.Unlock() + c.mu.Lock() + cli := c.Client + cli.Handle(c.handler) + c.mu.Unlock() + + present, err := cli.Connect(ctx, clientID, opts...) // Retry publish. go func() { - for _, msg := range oldQueue { + if len(oldSubQueue) > 0 { + c.subscribe(ctx, cli, oldSubQueue...) + } + for _, msg := range oldPubQueue { c.publish(ctx, cli, msg) } }() + return present, err +} + +// Resubscribe subscribes all established subscriptions. +func (c *RetryClient) Resubscribe(ctx context.Context) error { + if len(c.subEstablished) > 0 { + c.mu.Lock() + cli := c.Client + c.mu.Unlock() + c.subscribe(ctx, cli, c.subEstablished...) + } return nil } diff --git a/serve.go b/serve.go index 58c64d1..9a75fcc 100644 --- a/serve.go +++ b/serve.go @@ -48,11 +48,20 @@ func (c *BaseClient) serve() error { publish := (&pktPublish{}).parse(pktFlag, contents) switch publish.Message.QoS { case QoS0: - if c.Handler != nil { - c.Handler.Serve(&publish.Message) + c.mu.RLock() + handler := c.handler + c.mu.RUnlock() + if handler != nil { + handler.Serve(&publish.Message) } case QoS1: // Ownership of the message is now transferred to the receiver. + c.mu.RLock() + handler := c.handler + c.mu.RUnlock() + if handler != nil { + handler.Serve(&publish.Message) + } pktPubAck := pack( packetPubAck.b()|packetFromClient.b(), packUint16(publish.Message.ID), @@ -60,9 +69,6 @@ func (c *BaseClient) serve() error { if err := c.write(pktPubAck); err != nil { return err } - if c.Handler != nil { - c.Handler.Serve(&publish.Message) - } case QoS2: pktPubRec := pack( packetPubRec.b()|packetFromClient.b(), @@ -93,8 +99,11 @@ func (c *BaseClient) serve() error { pubRel := (&pktPubRel{}).parse(pktFlag, contents) if msg, ok := subBuffer[pubRel.ID]; ok { // Ownership of the message is now transferred to the receiver. - if c.Handler != nil { - c.Handler.Serve(msg) + c.mu.RLock() + handler := c.handler + c.mu.RUnlock() + if handler != nil { + handler.Serve(msg) } delete(subBuffer, pubRel.ID) } diff --git a/uniqid.go b/uniqid.go index 830ba58..d011042 100644 --- a/uniqid.go +++ b/uniqid.go @@ -11,7 +11,7 @@ func init() { } func (c *BaseClient) initID() { - c.idLast = uint32(rand.Int31n(0xFFFE)) + 1 + atomic.StoreUint32(&c.idLast, uint32(rand.Int31n(0xFFFE))+1) } func (c *BaseClient) newID() uint16 {