diff --git a/awscrt/mqtt5.py b/awscrt/mqtt5.py index 4f8bff98d..42bc68632 100644 --- a/awscrt/mqtt5.py +++ b/awscrt/mqtt5.py @@ -5,7 +5,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0. -from typing import Any, Callable +from typing import Any, Callable, Union import _awscrt from concurrent.futures import Future from enum import IntEnum @@ -1104,8 +1104,9 @@ class PublishPacket: message_expiry_interval_sec (int): Sent publishes - indicates the maximum amount of time allowed to elapse for message delivery before the server should instead delete the message (relative to a recipient). Received publishes - indicates the remaining amount of time (from the server's perspective) before the message would have been deleted relative to the subscribing client. If left None, indicates no expiration timeout. topic_alias (int): An integer value that is used to identify the Topic instead of using the Topic Name. On outbound publishes, this will only be used if the outbound topic aliasing behavior has been set to Manual. response_topic (str): Opaque topic string intended to assist with request/response implementations. Not internally meaningful to MQTT5 or this client. - correlation_data (Any): Opaque binary data used to correlate between publish messages, as a potential method for request-response implementation. Not internally meaningful to MQTT5. - subscription_identifiers (Sequence[int]): The subscription identifiers of all the subscriptions this message matched. + correlation_data (Optional[Union[bytes, str]]): Deprecated, use `correlation_data_bytes` instead. Opaque binary data used to correlate between publish messages, as a potential method for request-response implementation. Not internally meaningful to MQTT5. For incoming publishes, this will be a utf8 string (if correlation data exists and it's convertible to utf-8) or None (either it didn't exist, or did but wasn't convertible) + correlation_data_bytes (Optional[Union[bytes, str]]): Opaque binary data used to correlate between publish messages, as a potential method for request-response implementation. Not internally meaningful to MQTT5. For outbound publishes, this field takes priority over `correlation_data`. For incoming publishes, this will be binary data if correlation data is set, otherwise it will be None. + subscription_identifiers (Sequence[int]): The subscription identifiers of all the subscriptions this message matched. This field is ignored on outbound publishes (setting it is a protocol error). content_type (str): Property specifying the content type of the payload. Not internally meaningful to MQTT5. user_properties (Sequence[UserProperty]): List of MQTT5 user properties included with the packet. """ @@ -1117,7 +1118,9 @@ class PublishPacket: message_expiry_interval_sec: int = None topic_alias: int = None response_topic: str = None - correlation_data: Any = None # Unicode objects are converted to C strings using 'utf-8' encoding + correlation_data_bytes: 'Optional[Union[bytes, str]]' = None # binary data if correlation data exists on the packet + # Deprecated. Incoming publishes: a string if correlation data exists on the packet and is convertible to utf-8 + correlation_data: 'Optional[Union[bytes, str]]' = None subscription_identifiers: 'Sequence[int]' = None # ignore attempts to set but provide in received packets content_type: str = None user_properties: 'Sequence[UserProperty]' = None @@ -1447,8 +1450,18 @@ def _on_publish( if topic_alias_exists: publish_packet.topic_alias = topic_alias publish_packet.response_topic = response_topic - publish_packet.correlation_data = correlation_data - if publish_packet.subscription_identifiers is not None: + + # hacky workaround to maintain behavioral backwards compatibility with deprecated parameter + if correlation_data is not None: + # `correlation_data_bytes` always has the correlation data, as binary data + publish_packet.correlation_data_bytes = correlation_data + try: + # `correlation_data` contains the correlation data as a utf-8 string, if it can be converted + publish_packet.correlation_data = correlation_data.decode("utf-8") + except Exception: + pass + + if subscription_identifiers_tuples is not None: publish_packet.subscription_identifiers = [subscription_identifier for (subscription_identifier) in subscription_identifiers_tuples] publish_packet.content_type = content_type @@ -1768,7 +1781,7 @@ def __init__(self, client_options: ClientOptions): will.message_expiry_interval_sec, will.topic_alias, will.response_topic, - will.correlation_data, + will.correlation_data_bytes or will.correlation_data, will.content_type, will.user_properties, client_options.session_behavior, @@ -1865,7 +1878,7 @@ def puback(error_code, qos, reason_code, reason_string, user_properties_tuples): publish_packet.message_expiry_interval_sec, publish_packet.topic_alias, publish_packet.response_topic, - publish_packet.correlation_data, + publish_packet.correlation_data_bytes or publish_packet.correlation_data, publish_packet.content_type, publish_packet.user_properties, puback) diff --git a/source/mqtt5_client.c b/source/mqtt5_client.c index eddc5ab7f..f5f9b8eea 100644 --- a/source/mqtt5_client.c +++ b/source/mqtt5_client.c @@ -260,7 +260,7 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu result = PyObject_CallMethod( client->client_core, "_on_publish", - "(y#iOs#OiOIOHs#z#Os#O)", + "(y#iOs#OiOIOHs#y#Os#O)", /* y */ publish_packet->payload.ptr, /* # */ publish_packet->payload.len, /* i */ (int)publish_packet->qos, @@ -276,7 +276,7 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu /* H */ (unsigned short)(publish_packet->topic_alias ? *publish_packet->topic_alias : 0), /* s */ publish_packet->response_topic ? publish_packet->response_topic->ptr : NULL, /* # */ publish_packet->response_topic ? publish_packet->response_topic->len : 0, - /* z */ publish_packet->correlation_data ? publish_packet->correlation_data->ptr : NULL, + /* y */ publish_packet->correlation_data ? publish_packet->correlation_data->ptr : NULL, /* # */ publish_packet->correlation_data ? publish_packet->correlation_data->len : 0, /* O */ subscription_identifier_count > 0 ? subscription_identifier_list : Py_None, /* s */ publish_packet->content_type ? publish_packet->content_type->ptr : NULL, diff --git a/test/test_mqtt.py b/test/test_mqtt.py index 41d722e48..c76bad7fa 100644 --- a/test/test_mqtt.py +++ b/test/test_mqtt.py @@ -168,8 +168,8 @@ def test_will(self): host_name=test_input_endpoint, port=8883, will=Will(self.TEST_TOPIC, QoS.AT_LEAST_ONCE, self.TEST_MSG, False), - ping_timeout_ms=500, - keep_alive_secs=1 + ping_timeout_ms=10000, + keep_alive_secs=30 ) connection.connect().result(TIMEOUT) @@ -177,7 +177,9 @@ def test_will(self): client=client, client_id=create_client_id(), host_name=test_input_endpoint, - port=8883 + port=8883, + ping_timeout_ms=10000, + keep_alive_secs=30 ) subscriber.connect().result(TIMEOUT) @@ -193,16 +195,18 @@ def on_message(**kwargs): self.assertEqual(self.TEST_TOPIC, suback['topic']) self.assertIs(QoS.AT_LEAST_ONCE, suback['qos']) - # Disconnect the will client to send the will - time.sleep(1.1) # wait 1.1 seconds to ensure we can make another client ID connect + # wait a few seconds to ensure we can make another client ID connect (don't trigger IoT Core limit) + time.sleep(2) + + # Disconnect the will client to send the will by making another connection with the same client id disconnecter = Connection( client=client, client_id=will_client_id, host_name=test_input_endpoint, port=8883, will=Will(self.TEST_TOPIC, QoS.AT_LEAST_ONCE, self.TEST_MSG, False), - ping_timeout_ms=500, - keep_alive_secs=1 + ping_timeout_ms=10000, + keep_alive_secs=30 ) disconnecter.connect().result(TIMEOUT) diff --git a/test/test_mqtt5.py b/test/test_mqtt5.py index 6446285e6..143fd04cf 100644 --- a/test/test_mqtt5.py +++ b/test/test_mqtt5.py @@ -1230,6 +1230,98 @@ def test_operation_will(self): client2.stop() callbacks2.future_stopped.result(TIMEOUT) + def do_will_correlation_data_test(self, outbound_correlation_data_bytes, outbound_correlation_data, + expected_correlation_data_bytes, expected_correlation_data): + input_host_name = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_HOST") + input_cert = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_CERT") + input_key = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_KEY") + + client_id_publisher = create_client_id() + topic_filter = "test/MQTT5_Binding_Python_" + client_id_publisher + + payload = "TEST WILL" + payload_bytes = payload.encode("utf-8") + + will_packet = mqtt5.PublishPacket( + payload="TEST WILL", + qos=mqtt5.QoS.AT_LEAST_ONCE, + topic=topic_filter, + correlation_data_bytes=outbound_correlation_data_bytes, + correlation_data=outbound_correlation_data + ) + + tls_ctx_options = io.TlsContextOptions.create_client_with_mtls_from_path( + input_cert, + input_key + ) + + client_options1 = mqtt5.ClientOptions( + host_name=input_host_name, + port=8883 + ) + client_options1.connect_options = mqtt5.ConnectPacket(client_id=client_id_publisher, + will_delay_interval_sec=0, + will=will_packet) + client_options1.tls_ctx = io.ClientTlsContext(tls_ctx_options) + callbacks1 = Mqtt5TestCallbacks() + client1 = self._create_client(client_options=client_options1, callbacks=callbacks1) + client1.start() + callbacks1.future_connection_success.result(TIMEOUT) + + client_options2 = mqtt5.ClientOptions( + host_name=input_host_name, + port=8883 + ) + client_options2.connect_options = mqtt5.ConnectPacket(client_id=create_client_id()) + client_options2.tls_ctx = io.ClientTlsContext(tls_ctx_options) + callbacks2 = Mqtt5TestCallbacks() + client2 = self._create_client(client_options=client_options2, callbacks=callbacks2) + client2.start() + callbacks2.future_connection_success.result(TIMEOUT) + + subscriptions = [] + subscriptions.append(mqtt5.Subscription(topic_filter=topic_filter, qos=mqtt5.QoS.AT_LEAST_ONCE)) + subscribe_packet = mqtt5.SubscribePacket( + subscriptions=subscriptions) + subscribe_future = client2.subscribe(subscribe_packet=subscribe_packet) + suback_packet = subscribe_future.result(TIMEOUT) + self.assertIsInstance(suback_packet, mqtt5.SubackPacket) + + disconnect_packet = mqtt5.DisconnectPacket(reason_code=mqtt5.DisconnectReasonCode.DISCONNECT_WITH_WILL_MESSAGE) + client1.stop(disconnect_packet=disconnect_packet) + callbacks1.future_stopped.result(TIMEOUT) + + received_will = callbacks2.future_publish_received.result(TIMEOUT) + self.assertIsInstance(received_will, mqtt5.PublishPacket) + self.assertEqual(received_will.payload, payload_bytes) + self.assertEqual(received_will.correlation_data_bytes, expected_correlation_data_bytes) + self.assertEqual(received_will.correlation_data, expected_correlation_data) + + client2.stop() + callbacks2.future_stopped.result(TIMEOUT) + + def test_will_correlation_data_bytes_binary(self): + correlation_data = bytearray(os.urandom(64)) + self.do_will_correlation_data_test(correlation_data, None, correlation_data, None) + + def test_will_correlation_data_bytes_string(self): + correlation_data = "CorrelationData" + correlation_data_as_bytes = correlation_data.encode('utf-8') + self.do_correlation_data_test(correlation_data, None, correlation_data_as_bytes, correlation_data) + + def test_will_correlation_data_binary(self): + correlation_data = bytearray(os.urandom(64)) + self.do_correlation_data_test(None, correlation_data, correlation_data, None) + + def test_will_correlation_data_string(self): + correlation_data = "CorrelationData" + correlation_data_as_bytes = correlation_data.encode('utf-8') + self.do_correlation_data_test(None, correlation_data, correlation_data_as_bytes, correlation_data) + + def test_will_correlation_data_bytes_binary_precedence(self): + correlation_data = bytearray(os.urandom(64)) + self.do_correlation_data_test(correlation_data, "Ignored", correlation_data, None) + def test_operation_binary_publish(self): input_host_name = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_HOST") input_cert = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_CERT") @@ -1292,6 +1384,81 @@ def test_operation_binary_publish(self): client.stop() callbacks.future_stopped.result(TIMEOUT) + def do_correlation_data_test(self, outbound_correlation_data_bytes, outbound_correlation_data, + expected_correlation_data_bytes, expected_correlation_data): + input_host_name = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_HOST") + input_cert = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_CERT") + input_key = _get_env_variable("AWS_TEST_MQTT5_IOT_CORE_RSA_KEY") + + client_id = create_client_id() + topic_filter = "test/MQTT5_Binding_Python_" + client_id + payload = bytearray(os.urandom(256)) + + client_options = mqtt5.ClientOptions( + host_name=input_host_name, + port=8883 + ) + tls_ctx_options = io.TlsContextOptions.create_client_with_mtls_from_path( + input_cert, + input_key + ) + client_options.tls_ctx = io.ClientTlsContext(tls_ctx_options) + callbacks = Mqtt5TestCallbacks() + client = self._create_client(client_options=client_options, callbacks=callbacks) + client.start() + callbacks.future_connection_success.result(TIMEOUT) + + subscriptions = [] + subscriptions.append(mqtt5.Subscription(topic_filter=topic_filter, qos=mqtt5.QoS.AT_LEAST_ONCE)) + subscribe_packet = mqtt5.SubscribePacket( + subscriptions=subscriptions) + subscribe_future = client.subscribe(subscribe_packet=subscribe_packet) + suback_packet = subscribe_future.result(TIMEOUT) + self.assertIsInstance(suback_packet, mqtt5.SubackPacket) + + publish_packet = mqtt5.PublishPacket( + payload=payload, + topic=topic_filter, + qos=mqtt5.QoS.AT_LEAST_ONCE, + correlation_data_bytes=outbound_correlation_data_bytes, + correlation_data=outbound_correlation_data) + + publish_future = client.publish(publish_packet=publish_packet) + publish_completion_data = publish_future.result(TIMEOUT) + puback_packet = publish_completion_data.puback + self.assertIsInstance(puback_packet, mqtt5.PubackPacket) + + received_publish = callbacks.future_publish_received.result(TIMEOUT) + self.assertIsInstance(received_publish, mqtt5.PublishPacket) + self.assertEqual(received_publish.payload, payload) + self.assertEqual(received_publish.correlation_data_bytes, expected_correlation_data_bytes) + self.assertEqual(received_publish.correlation_data, expected_correlation_data) + + client.stop() + callbacks.future_stopped.result(TIMEOUT) + + def test_operation_publish_correlation_data_bytes_binary(self): + correlation_data = bytearray(os.urandom(64)) + self.do_correlation_data_test(correlation_data, None, correlation_data, None) + + def test_operation_publish_correlation_data_bytes_string(self): + correlation_data = "CorrelationData" + correlation_data_as_bytes = correlation_data.encode('utf-8') + self.do_correlation_data_test(correlation_data, None, correlation_data_as_bytes, correlation_data) + + def test_operation_publish_correlation_data_binary(self): + correlation_data = bytearray(os.urandom(64)) + self.do_correlation_data_test(None, correlation_data, correlation_data, None) + + def test_operation_publish_correlation_data_string(self): + correlation_data = "CorrelationData" + correlation_data_as_bytes = correlation_data.encode('utf-8') + self.do_correlation_data_test(None, correlation_data, correlation_data_as_bytes, correlation_data) + + def test_operation_publish_correlation_data_bytes_binary_precedence(self): + correlation_data = bytearray(os.urandom(64)) + self.do_correlation_data_test(correlation_data, "Ignored", correlation_data, None) + # ============================================================== # OPERATION ERROR TEST CASES # ==============================================================