Skip to content

Commit

Permalink
Return subscribe result (#157)
Browse files Browse the repository at this point in the history
  • Loading branch information
at-wat authored Feb 16, 2021
1 parent 4db1be6 commit 70ce0b0
Show file tree
Hide file tree
Showing 12 changed files with 47 additions and 33 deletions.
19 changes: 15 additions & 4 deletions client_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"context"
"crypto/tls"
"fmt"
"reflect"
"testing"
"time"
)
Expand Down Expand Up @@ -55,7 +56,7 @@ func ExampleClient() {
if _, err := cli.Connect(ctx, "TestClient", WithCleanSession(true)); err != nil {
panic(err)
}
if err := cli.Subscribe(ctx, Subscription{Topic: "test/topic", QoS: QoS1}); err != nil {
if _, err := cli.Subscribe(ctx, Subscription{Topic: "test/topic", QoS: QoS1}); err != nil {
panic(err)
}

Expand Down Expand Up @@ -165,9 +166,14 @@ func TestIntegration_PublishSubscribe(t *testing.T) {
}))

topic := "test_pubsub_" + name
if err := cli.Subscribe(ctx, Subscription{Topic: topic, QoS: qos}); err != nil {
subs, err := cli.Subscribe(ctx, Subscription{Topic: topic, QoS: qos})
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
expectedSubs := []Subscription{{Topic: topic, QoS: qos}}
if !reflect.DeepEqual(expectedSubs, subs) {
t.Fatalf("Expected subscriptions: %v, actual: %v", expectedSubs, subs)
}

if err := cli.Publish(ctx, &Message{
Topic: topic,
Expand Down Expand Up @@ -216,9 +222,14 @@ func TestIntegration_SubscribeUnsubscribe(t *testing.T) {
t.Fatalf("Unexpected error: '%v'", err)
}

if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS2}); err != nil {
subs, err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS2})
if err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
expectedSubs := []Subscription{{Topic: "test", QoS: QoS2}}
if !reflect.DeepEqual(expectedSubs, subs) {
t.Fatalf("Expected subscriptions: %v, actual: %v", expectedSubs, subs)
}

if err := cli.Unsubscribe(ctx, "test"); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
Expand Down Expand Up @@ -284,7 +295,7 @@ func BenchmarkPublishSubscribe(b *testing.B) {
chReceived <- msg
}))

if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS2}); err != nil {
if _, err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS2}); err != nil {
b.Fatalf("Unexpected error: '%v'", err)
}

Expand Down
2 changes: 1 addition & 1 deletion examples/mqtts-client-cert/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func main() {
)

// Subscribe two topics.
if err := cli.Subscribe(ctx,
if _, err := cli.Subscribe(ctx,
mqtt.Subscription{
Topic: "test",
QoS: mqtt.QoS1,
Expand Down
2 changes: 1 addition & 1 deletion examples/wss-presign-url/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func main() {
}),
)

if err := cli.Subscribe(ctx,
if _, err := cli.Subscribe(ctx,
mqtt.Subscription{
Topic: "stop",
QoS: mqtt.QoS1,
Expand Down
6 changes: 3 additions & 3 deletions mock/mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type Client struct {
ConnectFn func(ctx context.Context, clientID string, opts ...mqtt.ConnectOption) (sessionPresent bool, err error)
DisconnectFn func(ctx context.Context) error
PublishFn func(ctx context.Context, message *mqtt.Message) error
SubscribeFn func(ctx context.Context, subs ...mqtt.Subscription) error
SubscribeFn func(ctx context.Context, subs ...mqtt.Subscription) ([]mqtt.Subscription, error)
UnsubscribeFn func(ctx context.Context, subs ...string) error
PingFn func(ctx context.Context) error
handler mqtt.Handler
Expand Down Expand Up @@ -58,9 +58,9 @@ func (c *Client) Publish(ctx context.Context, message *mqtt.Message) error {
}

// Subscribe implements mqtt.Client.
func (c *Client) Subscribe(ctx context.Context, subs ...mqtt.Subscription) error {
func (c *Client) Subscribe(ctx context.Context, subs ...mqtt.Subscription) ([]mqtt.Subscription, error) {
if c.SubscribeFn == nil {
return nil
return nil, nil
}
return c.SubscribeFn(ctx, subs...)
}
Expand Down
2 changes: 1 addition & 1 deletion mqtt.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type Client interface {
Connect(ctx context.Context, clientID string, opts ...ConnectOption) (sessionPresent bool, err error)
Disconnect(ctx context.Context) error
Publish(ctx context.Context, message *Message) error
Subscribe(ctx context.Context, subs ...Subscription) error
Subscribe(ctx context.Context, subs ...Subscription) ([]Subscription, error)
Unsubscribe(ctx context.Context, subs ...string) error
Ping(ctx context.Context) error
Handle(Handler)
Expand Down
7 changes: 4 additions & 3 deletions packet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func TestPacketSendCancel(t *testing.T) {
cancel()

t.Run("Subscribe", func(t *testing.T) {
if err := cli.Subscribe(ctx, Subscription{Topic: "test"}); !errors.Is(err, context.Canceled) {
if _, err := cli.Subscribe(ctx, Subscription{Topic: "test"}); !errors.Is(err, context.Canceled) {
t.Errorf("Expected error: '%v', got: '%v'", context.Canceled, err)
}
})
Expand Down Expand Up @@ -200,7 +200,7 @@ func TestPacketSendError(t *testing.T) {
ca.Close()

t.Run("Subscribe", func(t *testing.T) {
if err := cli.Subscribe(ctx, Subscription{Topic: "test"}); !errors.Is(err, io.ErrClosedPipe) {
if _, err := cli.Subscribe(ctx, Subscription{Topic: "test"}); !errors.Is(err, io.ErrClosedPipe) {
t.Errorf("Expected error: '%v', got: '%v'", io.ErrClosedPipe, err)
}
})
Expand Down Expand Up @@ -240,7 +240,8 @@ func TestConnectionError(t *testing.T) {
return err
},
func(ctx context.Context, cli *BaseClient) error {
return cli.Subscribe(ctx, Subscription{Topic: "test"})
_, err := cli.Subscribe(ctx, Subscription{Topic: "test"})
return err
},
func(ctx context.Context, cli *BaseClient) error {
return cli.Unsubscribe(ctx, "test")
Expand Down
4 changes: 2 additions & 2 deletions paho/paho.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ func (c *pahoWrapper) Subscribe(topic string, qos byte, callback paho.MessageHan
return
}

token.err = cli.Subscribe(
_, token.err = cli.Subscribe(
context.Background(),
mqtt.Subscription{
Topic: topic,
Expand Down Expand Up @@ -325,7 +325,7 @@ func (c *pahoWrapper) SubscribeMultiple(filters map[string]byte, callback paho.M
return
}

token.err = cli.Subscribe(context.Background(), subs...)
_, token.err = cli.Subscribe(context.Background(), subs...)
token.release()
}()
return token
Expand Down
2 changes: 1 addition & 1 deletion paho/paho_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func TestIntegration_Will(t *testing.T) {
t.Fatalf("Unexpected error: '%v'", err)
}
defer cli0.Disconnect(context.Background())
if err := cli0.Subscribe(ctx, mqtt.Subscription{Topic: "will", QoS: mqtt.QoS1}); err != nil {
if _, err := cli0.Subscribe(ctx, mqtt.Subscription{Topic: "will", QoS: mqtt.QoS1}); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
cli0.Handle(mqtt.HandlerFunc(func(msg *mqtt.Message) {
Expand Down
8 changes: 4 additions & 4 deletions reconnclient_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestIntegration_ReconnectClient(t *testing.T) {
time.Sleep(time.Millisecond)
cli.(*reconnectClient).cli.Close()

if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS1}); err != nil {
if _, err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS1}); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
if err := cli.Publish(ctx, &Message{
Expand Down Expand Up @@ -202,7 +202,7 @@ func TestIntegration_ReconnectClient_Resubscribe(t *testing.T) {
}); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}
if err := cli.Subscribe(ctx, Subscription{
if _, err := cli.Subscribe(ctx, Subscription{
Topic: "test/" + name + pktName,
QoS: QoS1,
}); err != nil {
Expand Down Expand Up @@ -267,7 +267,7 @@ func TestIntegration_ReconnectClient_RetryPublish(t *testing.T) {

topic := fmt.Sprintf("test/Retry_%s_%d", name, qos)

if err := cliRecv.Subscribe(ctx, Subscription{
if _, err := cliRecv.Subscribe(ctx, Subscription{
Topic: topic,
QoS: qos,
}); err != nil {
Expand Down Expand Up @@ -739,7 +739,7 @@ func TestIntegration_ReconnectClient_RepeatedDisconnect(t *testing.T) {
defer mu.Unlock()
received[msg.Payload[0]]++
}))
if err := cliRaw.Subscribe(ctx, Subscription{Topic: topic, QoS: qos}); err != nil {
if _, err := cliRaw.Subscribe(ctx, Subscription{Topic: topic, QoS: qos}); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

Expand Down
7 changes: 4 additions & 3 deletions retryclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ func (c *RetryClient) Publish(ctx context.Context, message *Message) error {

// Subscribe tries to subscribe the topic and immediately return nil.
// If it is not acknowledged to be subscribed, the request will be queued.
func (c *RetryClient) Subscribe(ctx context.Context, subs ...Subscription) error {
return wrapError(c.pushTask(ctx, func(ctx context.Context, cli *BaseClient) {
// First return value ([]Subscription) is always nil.
func (c *RetryClient) Subscribe(ctx context.Context, subs ...Subscription) ([]Subscription, error) {
return nil, wrapError(c.pushTask(ctx, func(ctx context.Context, cli *BaseClient) {
c.subscribe(ctx, false, cli, subs...)
}), "retryclient: subscribing")
}
Expand Down Expand Up @@ -110,7 +111,7 @@ func (c *RetryClient) publish(ctx context.Context, cli *BaseClient, message *Mes
func (c *RetryClient) subscribe(ctx context.Context, retry bool, cli *BaseClient, subs ...Subscription) {
subscribe := func(ctx context.Context, cli *BaseClient) error {
subscriptions(subs).applyTo(&c.subEstablished)
if err := cli.Subscribe(ctx, subs...); err != nil {
if _, err := cli.Subscribe(ctx, subs...); err != nil {
select {
case <-ctx.Done():
if !retry {
Expand Down
4 changes: 2 additions & 2 deletions retryclient_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func TestIntegration_RetryClient_Cancel(t *testing.T) {
cliRecv.Handle(HandlerFunc(func(msg *Message) {
chRecv <- msg
}))
if err := cliRecv.Subscribe(ctx, Subscription{Topic: "testCancel", QoS: QoS2}); err != nil {
if _, err := cliRecv.Subscribe(ctx, Subscription{Topic: "testCancel", QoS: QoS2}); err != nil {
t.Fatalf("Unexpected error: '%v'", err)
}

Expand Down Expand Up @@ -151,7 +151,7 @@ func TestIntegration_RetryClient_TaskQueue(t *testing.T) {
done()
}
}))
if err := cli.Subscribe(ctx, Subscription{Topic: "test/queue", QoS: QoS1}); err != nil {
if _, err := cli.Subscribe(ctx, Subscription{Topic: "test/queue", QoS: QoS1}); err != nil {
t.Fatal(err)
}
time.Sleep(10 * time.Millisecond)
Expand Down
17 changes: 9 additions & 8 deletions subscribe.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,11 @@ func (p *pktSubscribe) Pack() []byte {
}

// Subscribe topics.
func (c *BaseClient) Subscribe(ctx context.Context, subs ...Subscription) error {
func (c *BaseClient) Subscribe(ctx context.Context, subs ...Subscription) ([]Subscription, error) {
return subscribeImpl(ctx, c, subs...)
}

func subscribeImpl(ctx context.Context, c *BaseClient, subs ...Subscription) error {
func subscribeImpl(ctx context.Context, c *BaseClient, subs ...Subscription) ([]Subscription, error) {
c.muConnecting.RLock()
defer c.muConnecting.RUnlock()

Expand All @@ -80,26 +80,27 @@ func subscribeImpl(ctx context.Context, c *BaseClient, subs ...Subscription) err
c.sig.mu.Unlock()

retrySubscribe := func(ctx context.Context, cli *BaseClient) error {
return subscribeImpl(ctx, cli, subs...)
_, err := subscribeImpl(ctx, cli, subs...)
return err
}

pkt := (&pktSubscribe{ID: id, Subscriptions: subs}).Pack()
if err := c.write(pkt); err != nil {
return wrapErrorWithRetry(err, retrySubscribe, "sending SUBSCRIBE")
return nil, wrapErrorWithRetry(err, retrySubscribe, "sending SUBSCRIBE")
}
select {
case <-c.connClosed:
return wrapErrorWithRetry(ErrClosedTransport, retrySubscribe, "waiting SUBACK")
return nil, wrapErrorWithRetry(ErrClosedTransport, retrySubscribe, "waiting SUBACK")
case <-ctx.Done():
return wrapErrorWithRetry(ctx.Err(), retrySubscribe, "waiting SUBACK")
return nil, wrapErrorWithRetry(ctx.Err(), retrySubscribe, "waiting SUBACK")
case subAck := <-chSubAck:
if len(subAck.Codes) != len(subs) {
c.Transport.Close()
return wrapErrorf(ErrInvalidSubAck, "subscribing %d topics: %d topics in SUBACK", len(subs), len(subAck.Codes))
return nil, wrapErrorf(ErrInvalidSubAck, "subscribing %d topics: %d topics in SUBACK", len(subs), len(subAck.Codes))
}
for i := 0; i < len(subAck.Codes); i++ {
subs[i].QoS = QoS(subAck.Codes[i])
}
}
return nil
return subs, nil
}

0 comments on commit 70ce0b0

Please sign in to comment.