diff --git a/reconnclient.go b/reconnclient.go index fcad1b4..f57c8b7 100644 --- a/reconnclient.go +++ b/reconnclient.go @@ -61,9 +61,7 @@ func NewReconnectClient(dialer Dialer, opts ...ReconnectOption) (ReconnectClient // The function returns after establishing a first connection, which can be canceled by the context. // Once after establishing the connection, the retry loop is not affected by the context. func (c *reconnectClient) Connect(ctx context.Context, clientID string, opts ...ConnectOption) (bool, error) { - connOptions := &ConnectOptions{ - CleanSession: true, - } + connOptions := &ConnectOptions{} for _, opt := range opts { if err := opt(connOptions); err != nil { return false, err @@ -78,49 +76,48 @@ func (c *reconnectClient) Connect(ctx context.Context, clientID string, opts ... var errDial, errConnect firstError - done := make(chan struct{}) + done := make(chan bool, 1) var doneOnce sync.Once - var sessionPresent bool go func(ctx context.Context) { defer func() { close(c.done) }() - clean := connOptions.CleanSession reconnWait := c.options.ReconnectWaitBase + var initialized bool for { if baseCli, err := c.dialer.DialContext(ctx); err == nil { - optsCurr := append([]ConnectOption{}, opts...) - optsCurr = append(optsCurr, WithCleanSession(clean)) - clean = false // Clean only first time. c.RetryClient.SetClient(ctx, baseCli) - var ctxTimeout context.Context - var cancel func() + var ctxConnect context.Context + var cancelConnect func() if c.options.Timeout == 0 { - ctxTimeout, cancel = ctx, func() {} + ctxConnect, cancelConnect = ctx, func() {} } else { - ctxTimeout, cancel = context.WithTimeout(ctx, c.options.Timeout) + ctxConnect, cancelConnect = context.WithTimeout(ctx, c.options.Timeout) } - if sessionPresent, err := c.RetryClient.Connect(ctxTimeout, clientID, optsCurr...); err == nil { - cancel() + if sessionPresent, err := c.RetryClient.Connect(ctxConnect, clientID, opts...); err == nil { + cancelConnect() reconnWait = c.options.ReconnectWaitBase // Reset reconnect wait. doneOnce.Do(func() { ctx = context.Background() + done <- sessionPresent close(done) }) - if !sessionPresent { + if initialized && (!sessionPresent || c.options.AlwaysResubscribe) { c.RetryClient.Resubscribe(ctx) } c.RetryClient.Retry(ctx) + initialized = true + ctxKeepAlive, cancelKeepAlive := context.WithCancel(ctx) if c.options.PingInterval > time.Duration(0) { // Start keep alive. go func() { if err := KeepAlive( - ctx, baseCli, + ctxKeepAlive, baseCli, c.options.PingInterval, c.options.Timeout, ); err != nil { @@ -133,20 +130,23 @@ func (c *reconnectClient) Connect(ctx context.Context, clientID string, opts ... } select { case <-baseCli.Done(): + cancelKeepAlive() if err := baseCli.Err(); err == nil { // Disconnected as expected; don't restart. return } case <-ctx.Done(): + cancelKeepAlive() // User cancelled; don't restart. return case <-c.disconnected: + cancelKeepAlive() return } - } else if err != ctxTimeout.Err() { + } else if err != ctxConnect.Err() { errConnect.Store(err) // Hold first connect error excepting context cancel. } - cancel() + cancelConnect() } else if err != ctx.Err() { errDial.Store(err) // Hold first dial error excepting context cancel. } @@ -165,7 +165,8 @@ func (c *reconnectClient) Connect(ctx context.Context, clientID string, opts ... } }(ctx) select { - case <-done: + case sessionPresent := <-done: + return sessionPresent, nil case <-ctx.Done(): var actualErrs []string if err := errDial.Load(); err != nil { @@ -180,7 +181,6 @@ func (c *reconnectClient) Connect(ctx context.Context, clientID string, opts ... } return false, wrapErrorf(ctx.Err(), "establishing first connection%s", errStr) } - return sessionPresent, nil } // Disconnect from the broker. @@ -203,6 +203,7 @@ type ReconnectOptions struct { ReconnectWaitMax time.Duration PingInterval time.Duration RetryClient *RetryClient + AlwaysResubscribe bool } // ReconnectOption sets option for Connect. @@ -243,3 +244,14 @@ func WithRetryClient(cli *RetryClient) ReconnectOption { return nil } } + +// WithAlwaysResubscribe enables or disables re-subscribe on reconnect. +// Default value is false. +// This option can be used to ensure all subscriptions are restored +// even if the server is buggy. +func WithAlwaysResubscribe(always bool) ReconnectOption { + return func(o *ReconnectOptions) error { + o.AlwaysResubscribe = always + return nil + } +} diff --git a/reconnclient_integration_test.go b/reconnclient_integration_test.go index 7014e56..b529137 100644 --- a/reconnclient_integration_test.go +++ b/reconnclient_integration_test.go @@ -22,6 +22,7 @@ import ( "crypto/tls" "errors" "fmt" + "io" "sync" "sync/atomic" "testing" @@ -92,7 +93,7 @@ func TestIntegration_ReconnectClient(t *testing.T) { } } -func newCloseFilter(key byte, en bool) func([]byte) bool { +func newFilterBase(cbMsg func([]byte) bool) func([]byte) bool { var readBuf []byte return func(b []byte) (ret bool) { readBuf = append(readBuf, b...) @@ -101,9 +102,6 @@ func newCloseFilter(key byte, en bool) func([]byte) bool { if len(readBuf) == 0 { return } - if readBuf[0]&0xF0 == key { - ret = en - } var length int for i := 1; i < 5; i++ { if i >= len(readBuf) { @@ -115,7 +113,11 @@ func newCloseFilter(key byte, en bool) func([]byte) bool { break } } - if length >= len(readBuf) { + if length > len(readBuf) { + return + } + if cbMsg(readBuf[:length]) { + ret = true return } readBuf = readBuf[length:] @@ -123,6 +125,12 @@ func newCloseFilter(key byte, en bool) func([]byte) bool { } } +func newCloseFilter(key byte, en bool) func([]byte) bool { + return newFilterBase(func(msg []byte) bool { + return en && msg[0]&0xF0 == key + }) +} + func TestIntegration_ReconnectClient_Resubscribe(t *testing.T) { for name, url := range urls { url := url @@ -220,9 +228,8 @@ func TestIntegration_ReconnectClient_Resubscribe(t *testing.T) { } cli.Disconnect(ctx) - cnt := atomic.LoadInt32(&dialCnt) - if cnt < 2 { - t.Errorf("Must be dialled at least twice, dialled %d times", cnt) + if cnt := atomic.LoadInt32(&dialCnt); cnt < 2 { + t.Errorf("Must be dialed at least twice, dialed %d times", cnt) } }) } @@ -232,6 +239,153 @@ func TestIntegration_ReconnectClient_Resubscribe(t *testing.T) { } } +func TestIntegration_ReconnectClient_SessionPersistence(t *testing.T) { + for name, url := range urls { + url := url + t.Run(name, func(t *testing.T) { + for resubName, alwaysResub := range map[string]bool{ + "Always": true, + "IfNotPresent": false, + } { + alwaysResub := alwaysResub + resubName := resubName + t.Run(resubName, func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + var subCnt int32 + var dialCnt int32 + var actualConn atomic.Value + + cli, err := NewReconnectClient( + DialerFunc(func(ctx context.Context) (*BaseClient, error) { + cli, err := DialContext(ctx, url, + WithTLSConfig(&tls.Config{InsecureSkipVerify: true}), + ) + if err != nil { + return nil, err + } + atomic.AddInt32(&dialCnt, 1) + ca, cb := filteredpipe.DetectAndClosePipe( + newFilterBase(func([]byte) bool { return false }), + newFilterBase(func(msg []byte) bool { + if msg[0]&0xF0 == 0x80 { + atomic.AddInt32(&subCnt, 1) + } + return false + }), + ) + filteredpipe.Connect(ca, cli.Transport) + actualConn.Store(cli.Transport) + cli.Transport = cb + return cli, nil + }), + WithPingInterval(250*time.Millisecond), + WithTimeout(250*time.Millisecond), + WithReconnectWait(200*time.Millisecond, time.Second), + WithAlwaysResubscribe(alwaysResub), + ) + if err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + id := time.Now().UnixNano() + + chReceived := make(chan *Message, 100) + cli.Handle(HandlerFunc(func(msg *Message) { + chReceived <- msg + })) + _, err = cli.Connect( + ctx, + fmt.Sprintf("ReconnectClientSession%s-%d", name, id), + ) + if err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + topic := fmt.Sprintf("test_session/%s/%d", name, id) + if _, err := cli.Subscribe(ctx, Subscription{ + Topic: topic, + QoS: QoS2, + }); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + if err := cli.Publish(ctx, &Message{ + Topic: topic, + QoS: QoS2, + Payload: []byte{1}, + }); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + for { + select { + case <-time.After(50 * time.Millisecond): + case <-ctx.Done(): + t.Fatal("Timeout") + } + if cnt := atomic.LoadInt32(&dialCnt); cnt >= 1 { + break + } + } + select { + case <-chReceived: + case <-ctx.Done(): + t.Fatal("Timeout") + } + + actualConn.Load().(io.ReadWriteCloser).Close() + + if err := cli.Publish(ctx, &Message{ + Topic: topic, + QoS: QoS2, + Payload: []byte{2}, + }); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } + + for { + select { + case <-time.After(50 * time.Millisecond): + case <-ctx.Done(): + t.Fatal("Timeout") + } + if cnt := atomic.LoadInt32(&dialCnt); cnt >= 2 { + break + } + } + + select { + case <-chReceived: + case <-ctx.Done(): + t.Fatal("Timeout") + } + + cli.Disconnect(ctx) + + select { + case <-cli.Client().Done(): + case <-ctx.Done(): + t.Fatal("Timeout") + } + + if cnt := atomic.LoadInt32(&dialCnt); cnt != 2 { + t.Errorf("Must be dialed twice, dialed %d times", cnt) + } + if alwaysResub { + if cnt := atomic.LoadInt32(&subCnt); cnt != 2 { + t.Errorf("Must be subscribed twice, subscribed %d times", cnt) + } + } else { + if cnt := atomic.LoadInt32(&subCnt); cnt != 1 { + t.Errorf("Must be subscribed once, subscribed %d times", cnt) + } + } + }) + } + }) + } +} + func newOnOffFilter(sw *int32) func([]byte) bool { return func(b []byte) bool { s := atomic.LoadInt32(sw) @@ -678,6 +832,8 @@ func TestIntegration_ReconnectClient_RepeatedDisconnect(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + topic := fmt.Sprintf("test_%d_%s", qos, name) + cliRaw, err := DialContext( ctx, url, WithTLSConfig(&tls.Config{InsecureSkipVerify: true}), @@ -691,6 +847,16 @@ func TestIntegration_ReconnectClient_RepeatedDisconnect(t *testing.T) { ); err != nil { t.Fatalf("Unexpected error: '%v'", err) } + var mu sync.Mutex + received := make(map[byte]int) + cliRaw.Handle(HandlerFunc(func(msg *Message) { + mu.Lock() + defer mu.Unlock() + received[msg.Payload[0]]++ + })) + if _, err := cliRaw.Subscribe(ctx, Subscription{Topic: topic, QoS: qos}); err != nil { + t.Fatalf("Unexpected error: '%v'", err) + } cli, err := NewReconnectClient( &URLDialer{ @@ -716,19 +882,6 @@ func TestIntegration_ReconnectClient_RepeatedDisconnect(t *testing.T) { t.Fatalf("Unexpected error: '%v'", err) } - topic := fmt.Sprintf("test_%d_%s", qos, name) - - var mu sync.Mutex - received := make(map[byte]int) - cliRaw.Handle(HandlerFunc(func(msg *Message) { - mu.Lock() - defer mu.Unlock() - received[msg.Payload[0]]++ - })) - if _, err := cliRaw.Subscribe(ctx, Subscription{Topic: topic, QoS: qos}); err != nil { - t.Fatalf("Unexpected error: '%v'", err) - } - go func() { cli := cli.(*reconnectClient) for { @@ -757,7 +910,26 @@ func TestIntegration_ReconnectClient_RepeatedDisconnect(t *testing.T) { time.Sleep(10 * time.Millisecond) } - time.Sleep(500 * time.Millisecond) + tick := time.NewTicker(50 * time.Millisecond) + timeoutCnt := 10 + for { + <-tick.C + timeoutCnt-- + if timeoutCnt <= 0 { + break + } + mu.Lock() + var i int + for i = 0; i < testCount; i++ { + if received[byte(i)] < 1 { + break + } + } + mu.Unlock() + if i == testCount { + break + } + } mu.Lock() defer mu.Unlock()