diff --git a/client_integration_test.go b/client_integration_test.go index 9e9d923..80c1e1e 100644 --- a/client_integration_test.go +++ b/client_integration_test.go @@ -21,6 +21,7 @@ import ( "context" "crypto/tls" "fmt" + "reflect" "testing" "time" ) @@ -55,7 +56,7 @@ func ExampleClient() { if _, err := cli.Connect(ctx, "TestClient", WithCleanSession(true)); err != nil { panic(err) } - if err := cli.Subscribe(ctx, Subscription{Topic: "test/topic", QoS: QoS1}); err != nil { + if _, err := cli.Subscribe(ctx, Subscription{Topic: "test/topic", QoS: QoS1}); err != nil { panic(err) } @@ -165,9 +166,14 @@ func TestIntegration_PublishSubscribe(t *testing.T) { })) topic := "test_pubsub_" + name - if err := cli.Subscribe(ctx, Subscription{Topic: topic, QoS: qos}); err != nil { + subs, err := cli.Subscribe(ctx, Subscription{Topic: topic, QoS: qos}) + if err != nil { t.Fatalf("Unexpected error: '%v'", err) } + expectedSubs := []Subscription{{Topic: topic, QoS: qos}} + if !reflect.DeepEqual(expectedSubs, subs) { + t.Fatalf("Expected subscriptions: %v, actual: %v", expectedSubs, subs) + } if err := cli.Publish(ctx, &Message{ Topic: topic, @@ -216,9 +222,14 @@ func TestIntegration_SubscribeUnsubscribe(t *testing.T) { t.Fatalf("Unexpected error: '%v'", err) } - if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS2}); err != nil { + subs, err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS2}) + if err != nil { t.Fatalf("Unexpected error: '%v'", err) } + expectedSubs := []Subscription{{Topic: "test", QoS: QoS2}} + if !reflect.DeepEqual(expectedSubs, subs) { + t.Fatalf("Expected subscriptions: %v, actual: %v", expectedSubs, subs) + } if err := cli.Unsubscribe(ctx, "test"); err != nil { t.Fatalf("Unexpected error: '%v'", err) @@ -284,7 +295,7 @@ func BenchmarkPublishSubscribe(b *testing.B) { chReceived <- msg })) - if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS2}); err != nil { + if _, err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS2}); err != nil { b.Fatalf("Unexpected error: '%v'", err) } diff --git a/examples/mqtts-client-cert/main.go b/examples/mqtts-client-cert/main.go index 72a5db9..2e46614 100644 --- a/examples/mqtts-client-cert/main.go +++ b/examples/mqtts-client-cert/main.go @@ -102,7 +102,7 @@ func main() { ) // Subscribe two topics. - if err := cli.Subscribe(ctx, + if _, err := cli.Subscribe(ctx, mqtt.Subscription{ Topic: "test", QoS: mqtt.QoS1, diff --git a/examples/wss-presign-url/main.go b/examples/wss-presign-url/main.go index ed14659..8ece97a 100644 --- a/examples/wss-presign-url/main.go +++ b/examples/wss-presign-url/main.go @@ -80,7 +80,7 @@ func main() { }), ) - if err := cli.Subscribe(ctx, + if _, err := cli.Subscribe(ctx, mqtt.Subscription{ Topic: "stop", QoS: mqtt.QoS1, diff --git a/mock/mqtt.go b/mock/mqtt.go index 443b5c2..67affba 100644 --- a/mock/mqtt.go +++ b/mock/mqtt.go @@ -27,7 +27,7 @@ type Client struct { ConnectFn func(ctx context.Context, clientID string, opts ...mqtt.ConnectOption) (sessionPresent bool, err error) DisconnectFn func(ctx context.Context) error PublishFn func(ctx context.Context, message *mqtt.Message) error - SubscribeFn func(ctx context.Context, subs ...mqtt.Subscription) error + SubscribeFn func(ctx context.Context, subs ...mqtt.Subscription) ([]mqtt.Subscription, error) UnsubscribeFn func(ctx context.Context, subs ...string) error PingFn func(ctx context.Context) error handler mqtt.Handler @@ -58,9 +58,9 @@ func (c *Client) Publish(ctx context.Context, message *mqtt.Message) error { } // Subscribe implements mqtt.Client. -func (c *Client) Subscribe(ctx context.Context, subs ...mqtt.Subscription) error { +func (c *Client) Subscribe(ctx context.Context, subs ...mqtt.Subscription) ([]mqtt.Subscription, error) { if c.SubscribeFn == nil { - return nil + return nil, nil } return c.SubscribeFn(ctx, subs...) } diff --git a/mqtt.go b/mqtt.go index 15a0ebd..f6da0e0 100644 --- a/mqtt.go +++ b/mqtt.go @@ -41,7 +41,7 @@ type Client interface { Connect(ctx context.Context, clientID string, opts ...ConnectOption) (sessionPresent bool, err error) Disconnect(ctx context.Context) error Publish(ctx context.Context, message *Message) error - Subscribe(ctx context.Context, subs ...Subscription) error + Subscribe(ctx context.Context, subs ...Subscription) ([]Subscription, error) Unsubscribe(ctx context.Context, subs ...string) error Ping(ctx context.Context) error Handle(Handler) diff --git a/packet_test.go b/packet_test.go index c2f4131..d8f01c7 100644 --- a/packet_test.go +++ b/packet_test.go @@ -165,7 +165,7 @@ func TestPacketSendCancel(t *testing.T) { cancel() t.Run("Subscribe", func(t *testing.T) { - if err := cli.Subscribe(ctx, Subscription{Topic: "test"}); !errors.Is(err, context.Canceled) { + if _, err := cli.Subscribe(ctx, Subscription{Topic: "test"}); !errors.Is(err, context.Canceled) { t.Errorf("Expected error: '%v', got: '%v'", context.Canceled, err) } }) @@ -200,7 +200,7 @@ func TestPacketSendError(t *testing.T) { ca.Close() t.Run("Subscribe", func(t *testing.T) { - if err := cli.Subscribe(ctx, Subscription{Topic: "test"}); !errors.Is(err, io.ErrClosedPipe) { + if _, err := cli.Subscribe(ctx, Subscription{Topic: "test"}); !errors.Is(err, io.ErrClosedPipe) { t.Errorf("Expected error: '%v', got: '%v'", io.ErrClosedPipe, err) } }) @@ -240,7 +240,8 @@ func TestConnectionError(t *testing.T) { return err }, func(ctx context.Context, cli *BaseClient) error { - return cli.Subscribe(ctx, Subscription{Topic: "test"}) + _, err := cli.Subscribe(ctx, Subscription{Topic: "test"}) + return err }, func(ctx context.Context, cli *BaseClient) error { return cli.Unsubscribe(ctx, "test") diff --git a/paho/paho.go b/paho/paho.go index 78f311f..300afb4 100644 --- a/paho/paho.go +++ b/paho/paho.go @@ -291,7 +291,7 @@ func (c *pahoWrapper) Subscribe(topic string, qos byte, callback paho.MessageHan return } - token.err = cli.Subscribe( + _, token.err = cli.Subscribe( context.Background(), mqtt.Subscription{ Topic: topic, @@ -325,7 +325,7 @@ func (c *pahoWrapper) SubscribeMultiple(filters map[string]byte, callback paho.M return } - token.err = cli.Subscribe(context.Background(), subs...) + _, token.err = cli.Subscribe(context.Background(), subs...) token.release() }() return token diff --git a/paho/paho_integration_test.go b/paho/paho_integration_test.go index 41e27e6..b8a4da6 100644 --- a/paho/paho_integration_test.go +++ b/paho/paho_integration_test.go @@ -156,7 +156,7 @@ func TestIntegration_Will(t *testing.T) { t.Fatalf("Unexpected error: '%v'", err) } defer cli0.Disconnect(context.Background()) - if err := cli0.Subscribe(ctx, mqtt.Subscription{Topic: "will", QoS: mqtt.QoS1}); err != nil { + if _, err := cli0.Subscribe(ctx, mqtt.Subscription{Topic: "will", QoS: mqtt.QoS1}); err != nil { t.Fatalf("Unexpected error: '%v'", err) } cli0.Handle(mqtt.HandlerFunc(func(msg *mqtt.Message) { diff --git a/reconnclient_integration_test.go b/reconnclient_integration_test.go index f8fa17a..991fa0e 100644 --- a/reconnclient_integration_test.go +++ b/reconnclient_integration_test.go @@ -67,7 +67,7 @@ func TestIntegration_ReconnectClient(t *testing.T) { time.Sleep(time.Millisecond) cli.(*reconnectClient).cli.Close() - if err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS1}); err != nil { + if _, err := cli.Subscribe(ctx, Subscription{Topic: "test", QoS: QoS1}); err != nil { t.Fatalf("Unexpected error: '%v'", err) } if err := cli.Publish(ctx, &Message{ @@ -202,7 +202,7 @@ func TestIntegration_ReconnectClient_Resubscribe(t *testing.T) { }); err != nil { t.Fatalf("Unexpected error: '%v'", err) } - if err := cli.Subscribe(ctx, Subscription{ + if _, err := cli.Subscribe(ctx, Subscription{ Topic: "test/" + name + pktName, QoS: QoS1, }); err != nil { @@ -267,7 +267,7 @@ func TestIntegration_ReconnectClient_RetryPublish(t *testing.T) { topic := fmt.Sprintf("test/Retry_%s_%d", name, qos) - if err := cliRecv.Subscribe(ctx, Subscription{ + if _, err := cliRecv.Subscribe(ctx, Subscription{ Topic: topic, QoS: qos, }); err != nil { @@ -739,7 +739,7 @@ func TestIntegration_ReconnectClient_RepeatedDisconnect(t *testing.T) { defer mu.Unlock() received[msg.Payload[0]]++ })) - if err := cliRaw.Subscribe(ctx, Subscription{Topic: topic, QoS: qos}); err != nil { + if _, err := cliRaw.Subscribe(ctx, Subscription{Topic: topic, QoS: qos}); err != nil { t.Fatalf("Unexpected error: '%v'", err) } diff --git a/retryclient.go b/retryclient.go index b568d3e..de7e811 100644 --- a/retryclient.go +++ b/retryclient.go @@ -62,8 +62,9 @@ func (c *RetryClient) Publish(ctx context.Context, message *Message) error { // Subscribe tries to subscribe the topic and immediately return nil. // If it is not acknowledged to be subscribed, the request will be queued. -func (c *RetryClient) Subscribe(ctx context.Context, subs ...Subscription) error { - return wrapError(c.pushTask(ctx, func(ctx context.Context, cli *BaseClient) { +// First return value ([]Subscription) is always nil. +func (c *RetryClient) Subscribe(ctx context.Context, subs ...Subscription) ([]Subscription, error) { + return nil, wrapError(c.pushTask(ctx, func(ctx context.Context, cli *BaseClient) { c.subscribe(ctx, false, cli, subs...) }), "retryclient: subscribing") } @@ -110,7 +111,7 @@ func (c *RetryClient) publish(ctx context.Context, cli *BaseClient, message *Mes func (c *RetryClient) subscribe(ctx context.Context, retry bool, cli *BaseClient, subs ...Subscription) { subscribe := func(ctx context.Context, cli *BaseClient) error { subscriptions(subs).applyTo(&c.subEstablished) - if err := cli.Subscribe(ctx, subs...); err != nil { + if _, err := cli.Subscribe(ctx, subs...); err != nil { select { case <-ctx.Done(): if !retry { diff --git a/retryclient_integration_test.go b/retryclient_integration_test.go index 987b208..f4e0098 100644 --- a/retryclient_integration_test.go +++ b/retryclient_integration_test.go @@ -81,7 +81,7 @@ func TestIntegration_RetryClient_Cancel(t *testing.T) { cliRecv.Handle(HandlerFunc(func(msg *Message) { chRecv <- msg })) - if err := cliRecv.Subscribe(ctx, Subscription{Topic: "testCancel", QoS: QoS2}); err != nil { + if _, err := cliRecv.Subscribe(ctx, Subscription{Topic: "testCancel", QoS: QoS2}); err != nil { t.Fatalf("Unexpected error: '%v'", err) } @@ -151,7 +151,7 @@ func TestIntegration_RetryClient_TaskQueue(t *testing.T) { done() } })) - if err := cli.Subscribe(ctx, Subscription{Topic: "test/queue", QoS: QoS1}); err != nil { + if _, err := cli.Subscribe(ctx, Subscription{Topic: "test/queue", QoS: QoS1}); err != nil { t.Fatal(err) } time.Sleep(10 * time.Millisecond) diff --git a/subscribe.go b/subscribe.go index cb19ebc..e382df9 100644 --- a/subscribe.go +++ b/subscribe.go @@ -61,11 +61,11 @@ func (p *pktSubscribe) Pack() []byte { } // Subscribe topics. -func (c *BaseClient) Subscribe(ctx context.Context, subs ...Subscription) error { +func (c *BaseClient) Subscribe(ctx context.Context, subs ...Subscription) ([]Subscription, error) { return subscribeImpl(ctx, c, subs...) } -func subscribeImpl(ctx context.Context, c *BaseClient, subs ...Subscription) error { +func subscribeImpl(ctx context.Context, c *BaseClient, subs ...Subscription) ([]Subscription, error) { c.muConnecting.RLock() defer c.muConnecting.RUnlock() @@ -80,26 +80,27 @@ func subscribeImpl(ctx context.Context, c *BaseClient, subs ...Subscription) err c.sig.mu.Unlock() retrySubscribe := func(ctx context.Context, cli *BaseClient) error { - return subscribeImpl(ctx, cli, subs...) + _, err := subscribeImpl(ctx, cli, subs...) + return err } pkt := (&pktSubscribe{ID: id, Subscriptions: subs}).Pack() if err := c.write(pkt); err != nil { - return wrapErrorWithRetry(err, retrySubscribe, "sending SUBSCRIBE") + return nil, wrapErrorWithRetry(err, retrySubscribe, "sending SUBSCRIBE") } select { case <-c.connClosed: - return wrapErrorWithRetry(ErrClosedTransport, retrySubscribe, "waiting SUBACK") + return nil, wrapErrorWithRetry(ErrClosedTransport, retrySubscribe, "waiting SUBACK") case <-ctx.Done(): - return wrapErrorWithRetry(ctx.Err(), retrySubscribe, "waiting SUBACK") + return nil, wrapErrorWithRetry(ctx.Err(), retrySubscribe, "waiting SUBACK") case subAck := <-chSubAck: if len(subAck.Codes) != len(subs) { c.Transport.Close() - return wrapErrorf(ErrInvalidSubAck, "subscribing %d topics: %d topics in SUBACK", len(subs), len(subAck.Codes)) + return nil, wrapErrorf(ErrInvalidSubAck, "subscribing %d topics: %d topics in SUBACK", len(subs), len(subAck.Codes)) } for i := 0; i < len(subAck.Codes); i++ { subs[i].QoS = QoS(subAck.Codes[i]) } } - return nil + return subs, nil }